diff --git a/.github/workflows/functional-tests.yml b/.github/workflows/functional-tests.yml new file mode 100644 index 000000000..58ef9f59b --- /dev/null +++ b/.github/workflows/functional-tests.yml @@ -0,0 +1,40 @@ +name: ๐Ÿงช RF-DETR Functional Tests + +on: + pull_request: + branches: [main, develop] + push: + branches: [develop] + +permissions: + contents: read + checks: write + +jobs: + functional-tests: + name: Run functional test suite + runs-on: ubuntu-latest + timeout-minutes: 20 + strategy: + matrix: + python-version: ["3.10"] + + steps: + - name: ๐Ÿ“ฅ Checkout repository + uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + with: + fetch-depth: 0 + + - name: ๐Ÿ Install uv and set Python ${{ matrix.python-version }} + uses: astral-sh/setup-uv@bd01e18f51369d5a26f1651c3cb451d3417e3bba # v6.3.1 + with: + python-version: ${{ matrix.python-version }} + activate-environment: true + + - name: ๐Ÿ—๏ธ Install dependencies + run: | + uv pip install . + + - name: โœ… Run functional tests + run: | + uv run python functional_testing.py diff --git a/README.md b/README.md index 5f5e5cd2a..c0ccbdba6 100644 --- a/README.md +++ b/README.md @@ -30,6 +30,22 @@ On image segmentation, RF-DETR Seg (Preview) is 3x faster and more accurate than - `2025/04/03`: We release early stopping, gradient checkpointing, metrics saving, training resume, TensorBoard and W&B logging support. - `2025/03/20`: We release RF-DETR real-time object detection model. **Code and checkpoint for RF-DETR-large and RF-DETR-base are available.** +## Advanced Features (Experimental) + +The following features are available when installing RF-DETR from source and enable more advanced research and customization use cases: + +- **IoU-aware query selection and adaptive query allocation** (detection) + - Optional improvements to query initialization inside the transformer. + - Controlled via `ModelConfig` / CLI flags: `use_iou_aware_query`, `adaptive_query_allocation`. 
+ +- **Enhanced segmentation head with mask quality scoring** (segmentation) + - New head that adds mask quality prediction and dynamic refinement. + - Controlled via `ModelConfig` / CLI flags: `enhanced_segmentation`, `mask_quality_prediction`, `dynamic_mask_refinement`. + +- **Advanced data augmentations** (training) + - Mosaic, MixUp, and Copy-Paste augmentations implemented in `rfdetr.datasets.advanced_augmentations`. + - Can be plugged into custom training scripts for stronger data augmentation pipelines. + ## Results RF-DETR achieves state-of-the-art performance on both the Microsoft COCO and the RF100-VL benchmarks. @@ -180,6 +196,16 @@ You can fine-tune an RF-DETR Nano, Small, Medium, and Base model with a custom d Visit our [documentation website](https://rfdetr.roboflow.com) to learn more about how to use RF-DETR. +### Testing and CI + +For contributors, a small functional test suite is provided to validate core enhancements (IoU-aware queries, adaptive query allocation, and enhanced segmentation head): + +```bash +python functional_testing.py +``` + +GitHub Actions runs these tests automatically via the `functional-tests` workflow on pull requests targeting the main development branches. + ## License Both the code and the weights pretrained on the COCO dataset are released under the [Apache 2.0 license](https://github.com/roboflow/r-flow/blob/main/LICENSE). diff --git a/functional_testing.py b/functional_testing.py new file mode 100644 index 000000000..752918f06 --- /dev/null +++ b/functional_testing.py @@ -0,0 +1,198 @@ +#!/usr/bin/env python3 +"""Test script for RF-DETR enhancements. + +Validates IoU-aware query selection, adaptive query allocation, +and enhanced segmentation head integration. 
+""" + +import torch +import numpy as np +from PIL import Image +import sys +import os + +# Add rfdetr to path +sys.path.append(os.path.dirname(os.path.abspath(__file__))) + +from rfdetr.models.iou_aware_query_selector import IoUAwareQuerySelector, AdaptiveQueryAllocator +from rfdetr.models.enhanced_segmentation_head import EnhancedSegmentationHead, AdaptiveMaskLoss + + +def test_iou_aware_query_selector(): + """Test IoU-aware query selector.""" + print("Testing IoU-aware query selector...") + + # Create dummy data + batch_size = 2 + num_queries = 300 + feature_dim = 256 + num_memory = 1000 + + # Initialize selector + selector = IoUAwareQuerySelector( + d_model=feature_dim, + num_queries=num_queries + ) + + # Create dummy inputs + memory = torch.randn(batch_size, num_memory, feature_dim) + spatial_shapes = torch.tensor([[32, 32], [16, 16], [8, 8], [4, 4]]) + level_start_index = torch.tensor([0, 1024, 1536, 1792]) + reference_points = torch.rand(batch_size, num_queries, 4) + + # Forward pass + try: + selected_features, scores = selector(memory, spatial_shapes, level_start_index, reference_points) + assert selected_features.shape == (batch_size, num_queries, feature_dim) + assert scores.shape == (batch_size, num_queries, 1) + print("โœ“ IoU-aware query selector test passed") + return True + except Exception as e: + print(f"โœ— IoU-aware query selector test failed: {e}") + return False + + +def test_adaptive_query_allocator(): + """Test adaptive query allocator.""" + print("Testing adaptive query allocator...") + + # Create dummy data + batch_size = 2 + num_queries = 300 + feature_dim = 256 + num_memory = 1000 + + # Initialize allocator + allocator = AdaptiveQueryAllocator(base_queries=num_queries) + + # Create dummy input + memory = torch.randn(batch_size, num_memory, feature_dim) + + # Forward pass + try: + allocated_queries = allocator(memory) + assert isinstance(allocated_queries, int) + assert 100 <= allocated_queries <= 600 # Should be within min/max range + 
print(f"โœ“ Adaptive query allocator test passed (allocated {allocated_queries} queries)") + return True + except Exception as e: + print(f"โœ— Adaptive query allocator test failed: {e}") + return False + + +def test_enhanced_segmentation_head(): + """Test enhanced segmentation head.""" + print("Testing enhanced segmentation head...") + + # Create dummy data + batch_size = 2 + num_queries = 100 + feature_dim = 256 + image_size = (512, 512) + + # Initialize enhanced segmentation head + seg_head = EnhancedSegmentationHead( + feature_dim=feature_dim, + num_layers=3, + use_quality_prediction=True, + use_dynamic_refinement=True + ) + + # Create dummy inputs + spatial_features = torch.randn(batch_size, feature_dim, 64, 64) + query_features = [torch.randn(batch_size, num_queries, feature_dim) for _ in range(3)] + bbox_features = torch.rand(batch_size, num_queries, 4) + + # Forward pass + try: + mask_logits, quality_scores = seg_head( + spatial_features, query_features, image_size, bbox_features + ) + + assert len(mask_logits) == 3 # Should have 3 layers + assert mask_logits[-1].shape == (batch_size, num_queries, 128, 128) + assert quality_scores is not None + assert quality_scores.shape == (batch_size, num_queries, 1) + print("โœ“ Enhanced segmentation head test passed") + return True + except Exception as e: + print(f"โœ— Enhanced segmentation head test failed: {e}") + return False + + + + +def test_integration(): + """Test integration with existing RF-DETR components.""" + print("Testing integration...") + + try: + # Test imports work correctly + from rfdetr.models.transformer import Transformer + from rfdetr.config import ModelConfig + + # Test configuration with new features + config = ModelConfig( + encoder="dinov2_windowed_small", + out_feature_indexes=[2, 5, 8, 11], + dec_layers=3, + projector_scale=["P3", "P4", "P5"], + hidden_dim=256, + patch_size=14, + num_windows=4, + sa_nheads=8, + ca_nheads=8, + dec_n_points=4, + resolution=640, + 
positional_encoding_size=10000,
+            use_iou_aware_query=True,
+            adaptive_query_allocation=True,
+            enhanced_segmentation=True,
+            mask_quality_prediction=True,
+            dynamic_mask_refinement=True
+        )
+
+        assert config.use_iou_aware_query
+        assert config.adaptive_query_allocation
+        assert config.enhanced_segmentation
+
+        print("✓ Integration test passed")
+        return True
+    except Exception as e:
+        print(f"✗ Integration test failed: {e}")
+        return False
+
+
+def main():
+    """Run all tests."""
+    print("Running RF-DETR Enhancement Tests\n")
+    print("=" * 50)
+
+    tests = [
+        test_iou_aware_query_selector,
+        test_adaptive_query_allocator,
+        test_enhanced_segmentation_head,
+        test_integration,
+    ]
+
+    passed = 0
+    total = len(tests)
+
+    for test in tests:
+        if test():
+            passed += 1
+        print()
+
+    print("=" * 50)
+    print(f"Tests passed: {passed}/{total}")
+
+    if passed == total:
+        print("🎉 All tests passed! Ready for PR.")
+        return 0
+    else:
+        print("❌ Some tests failed. Please fix issues before creating PR.")
+        return 1
+
+
+if __name__ == "__main__":
+    sys.exit(main())  # sys.exit, not the site-dependent `exit` builtin
diff --git a/rfdetr/config.py b/rfdetr/config.py
index 8999cdf24..fa8488123 100644
--- a/rfdetr/config.py
+++ b/rfdetr/config.py
@@ -37,6 +37,11 @@ class ModelConfig(BaseModel):
     cls_loss_coef: float = 1.0
     segmentation_head: bool = False
     mask_downsample_ratio: int = 4
+    use_iou_aware_query: bool = False
+    adaptive_query_allocation: bool = False
+    enhanced_segmentation: bool = False
+    mask_quality_prediction: bool = True
+    dynamic_mask_refinement: bool = True
 
 
 class RFDETRBaseConfig(ModelConfig):
diff --git a/rfdetr/datasets/advanced_augmentations.py b/rfdetr/datasets/advanced_augmentations.py
new file mode 100644
index 000000000..415635f91
--- /dev/null
+++ b/rfdetr/datasets/advanced_augmentations.py
@@ -0,0 +1,492 @@
+# ------------------------------------------------------------------------
+# RF-DETR
+# Copyright (c) 2025 Roboflow. All Rights Reserved.
+# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+# ------------------------------------------------------------------------
+
+"""
+Advanced Data Augmentation Pipeline for RF-DETR
+This module implements advanced augmentation techniques including Mosaic, MixUp,
+and Copy-Paste augmentations specifically designed for DETR-based models.
+"""
+
+import random
+import numpy as np
+import torch
+import torchvision.transforms as T  # needed by AdvancedAugmentationPipeline / create_advanced_transforms
+import torchvision.transforms.functional as F
+from PIL import Image, ImageFilter, ImageEnhance
+from typing import Dict, List, Tuple, Optional, Union
+
+from rfdetr.util.box_ops import box_xyxy_to_cxcywh, box_cxcywh_to_xyxy
+
+
+class MosaicAugmentation:
+    """
+    Mosaic augmentation adapted for DETR models.
+    Combines 4 images into one mosaic image.
+    """
+
+    def __init__(self, prob: float = 0.5):
+        """
+        Initialize Mosaic augmentation.
+
+        Args:
+            prob: Probability of applying mosaic augmentation
+        """
+        self.prob = prob
+
+    def __call__(
+        self,
+        images: List[Image.Image],
+        targets: List[Dict]
+    ) -> Tuple[Image.Image, Dict]:
+        """
+        Apply mosaic augmentation to a batch of images.
+ + Args: + images: List of 4 PIL images + targets: List of corresponding target dictionaries + + Returns: + mosaic_image: Mosaic augmented image + mosaic_target: Combined target dictionary + """ + if random.random() > self.prob or len(images) < 4: + return images[0], targets[0] + + # Get dimensions + img0 = images[0] + w, h = img0.size + + # Create mosaic canvas + mosaic_img = Image.new('RGB', (w * 2, h * 2)) + + # Calculate split points + xc = random.randint(int(w * 0.5), int(w * 1.5)) + yc = random.randint(int(h * 0.5), int(h * 1.5)) + + # Positions for 4 images + positions = [ + (0, 0, xc, yc), # top-left + (xc, 0, w * 2, yc), # top-right + (0, yc, xc, h * 2), # bottom-left + (xc, yc, w * 2, h * 2) # bottom-right + ] + + mosaic_target = { + 'boxes': [], + 'labels': [], + 'area': [], + 'iscrowd': [], + 'orig_size': torch.tensor([h * 2, w * 2]), + 'size': torch.tensor([h * 2, w * 2]) + } + + # Process each image + for i, (img, target) in enumerate(zip(images[:4], targets[:4])): + x1, y1, x2, y2 = positions[i] + + # Resize and paste image + img_resized = img.resize((x2 - x1, y2 - y1), Image.BILINEAR) + mosaic_img.paste(img_resized, (x1, y1)) + + # Adjust bounding boxes + if 'boxes' in target: + boxes = target['boxes'] + # Convert from normalized to absolute coordinates + orig_w, orig_h = img.size + boxes_abs = boxes.clone() + boxes_abs[:, [0, 2]] *= orig_w + boxes_abs[:, [1, 3]] *= orig_h + + # Adjust for mosaic position + boxes_abs[:, [0, 2]] += x1 + boxes_abs[:, [1, 3]] += y1 + + # Normalize to mosaic canvas + boxes_abs[:, [0, 2]] /= (w * 2) + boxes_abs[:, [1, 3]] /= (h * 2) + + # Filter boxes that are within bounds + mask = (boxes_abs[:, 0] < 1) & (boxes_abs[:, 2] > 0) & \ + (boxes_abs[:, 1] < 1) & (boxes_abs[:, 3] > 0) + + if mask.any(): + mosaic_target['boxes'].extend(boxes_abs[mask]) + mosaic_target['labels'].extend(target['labels'][mask]) + + # Calculate area + area = (boxes_abs[mask, 2] - boxes_abs[mask, 0]) * \ + (boxes_abs[mask, 3] - boxes_abs[mask, 1]) 
+ mosaic_target['area'].extend(area) + + if 'iscrowd' in target: + mosaic_target['iscrowd'].extend( + [target['iscrowd'][j] for j in range(len(target['iscrowd'])) if mask[j]] + ) + else: + mosaic_target['iscrowd'].extend([0] * mask.sum().item()) + + # Convert to tensors + if mosaic_target['boxes']: + mosaic_target['boxes'] = torch.stack(mosaic_target['boxes']) + mosaic_target['labels'] = torch.stack(mosaic_target['labels']) + mosaic_target['area'] = torch.stack(mosaic_target['area']) + mosaic_target['iscrowd'] = torch.tensor(mosaic_target['iscrowd']) + else: + # Empty target + mosaic_target['boxes'] = torch.empty((0, 4)) + mosaic_target['labels'] = torch.empty((0,), dtype=torch.long) + mosaic_target['area'] = torch.empty((0,)) + mosaic_target['iscrowd'] = torch.empty((0,), dtype=torch.long) + + return mosaic_img, mosaic_target + + +class MixUpAugmentation: + """ + MixUp augmentation for object detection. + Blends two images and their targets. + """ + + def __init__(self, prob: float = 0.5, alpha: float = 1.0): + """ + Initialize MixUp augmentation. + + Args: + prob: Probability of applying mixup + alpha: Alpha parameter for beta distribution + """ + self.prob = prob + self.alpha = alpha + + def __call__( + self, + img1: Image.Image, + target1: Dict, + img2: Image.Image, + target2: Dict + ) -> Tuple[Image.Image, Dict]: + """ + Apply mixup augmentation to two images. 
+ + Args: + img1: First image + target1: First target + img2: Second image + target2: Second target + + Returns: + mixed_image: Mixup augmented image + mixed_target: Combined target dictionary + """ + if random.random() > self.prob: + return img1, target1 + + # Sample mixing coefficient + lam = np.random.beta(self.alpha, self.alpha) + + # Convert to tensors + img1_tensor = F.to_tensor(img1) + img2_tensor = F.to_tensor(img2) + + # Ensure same size + if img1_tensor.shape != img2_tensor.shape: + img2_tensor = F.resize(img2_tensor, img1_tensor.shape[-2:]) + + # Mix images + mixed_tensor = lam * img1_tensor + (1 - lam) * img2_tensor + mixed_image = F.to_pil_image(mixed_tensor) + + # Combine targets + mixed_target = { + 'boxes': torch.cat([target1['boxes'], target2['boxes']], dim=0), + 'labels': torch.cat([target1['labels'], target2['labels']], dim=0), + 'area': torch.cat([target1['area'], target2['area']], dim=0), + 'iscrowd': torch.cat([target1.get('iscrowd', torch.zeros_like(target1['labels'])), + target2.get('iscrowd', torch.zeros_like(target2['labels']))], dim=0), + 'orig_size': target1['orig_size'], + 'size': target1['size'] + } + + return mixed_image, mixed_target + + +class CopyPasteAugmentation: + """ + Copy-Paste augmentation for instance segmentation. + Copies objects from one image to another. + """ + + def __init__(self, prob: float = 0.5, max_objects: int = 5): + """ + Initialize Copy-Paste augmentation. + + Args: + prob: Probability of applying copy-paste + max_objects: Maximum number of objects to copy + """ + self.prob = prob + self.max_objects = max_objects + + def __call__( + self, + img1: Image.Image, + target1: Dict, + img2: Image.Image, + target2: Dict + ) -> Tuple[Image.Image, Dict]: + """ + Apply copy-paste augmentation. 
+ + Args: + img1: Target image (where objects will be pasted) + target1: Target image target + img2: Source image (where objects will be copied from) + target2: Source image target + + Returns: + result_image: Image with pasted objects + result_target: Updated target dictionary + """ + if random.random() > self.prob or 'boxes' not in target2: + return img1, target1 + + # Convert to numpy arrays + img1_np = np.array(img1) + img2_np = np.array(img2) + + # Select random objects to copy + num_objects = min(len(target2['boxes']), self.max_objects) + if num_objects == 0: + return img1, target1 + + indices = random.sample(range(len(target2['boxes'])), num_objects) + + result_target = target1.copy() + pasted_boxes = [] + pasted_labels = [] + pasted_areas = [] + + for idx in indices: + box = target2['boxes'][idx] + label = target2['labels'][idx] + + # Convert to absolute coordinates + h1, w1 = img1_np.shape[:2] + h2, w2 = img2_np.shape[:2] + + box_abs = box.clone() + box_abs[[0, 2]] *= w2 + box_abs[[1, 3]] *= h2 + + x1, y1, x2, y2 = box_abs.int().tolist() + + # Ensure coordinates are within bounds + x1, y1 = max(0, x1), max(0, y1) + x2, y2 = min(w2, x2), min(h2, y2) + + if x2 <= x1 or y2 <= y1: + continue + + # Extract object + obj_mask = np.zeros((h2, w2), dtype=np.uint8) + obj_mask[y1:y2, x1:x2] = 255 + + # Find random position in target image + obj_w, obj_h = x2 - x1, y2 - y1 + if obj_w > w1 or obj_h > h1: + continue + + paste_x = random.randint(0, w1 - obj_w) + paste_y = random.randint(0, h1 - obj_h) + + # Simple paste (without sophisticated blending) + try: + obj_region = img2_np[y1:y2, x1:x2] + img1_np[paste_y:paste_y+obj_h, paste_x:paste_x+obj_w] = obj_region + + # Add to target + new_box = torch.tensor([ + paste_x / w1, + paste_y / h1, + (paste_x + obj_w) / w1, + (paste_y + obj_h) / h1 + ]) + + pasted_boxes.append(new_box) + pasted_labels.append(label) + pasted_areas.append((new_box[2] - new_box[0]) * (new_box[3] - new_box[1])) + + except Exception: + continue + + # 
Update target + if pasted_boxes: + result_target['boxes'] = torch.cat([ + target1['boxes'], + torch.stack(pasted_boxes) + ], dim=0) + result_target['labels'] = torch.cat([ + target1['labels'], + torch.stack(pasted_labels) + ], dim=0) + result_target['area'] = torch.cat([ + target1['area'], + torch.stack(pasted_areas) + ], dim=0) + + iscrowd1 = target1.get('iscrowd', torch.zeros_like(target1['labels'])) + result_target['iscrowd'] = torch.cat([ + iscrowd1, + torch.zeros(len(pasted_labels), dtype=torch.long) + ], dim=0) + + result_image = Image.fromarray(img1_np) + return result_image, result_target + + +class AdvancedAugmentationPipeline: + """ + Comprehensive augmentation pipeline combining multiple techniques. + """ + + def __init__( + self, + mosaic_prob: float = 0.5, + mixup_prob: float = 0.3, + copypaste_prob: float = 0.3, + color_jitter: float = 0.2, + gaussian_blur: float = 0.1, + normalize_mean: Tuple[float, float, float] = (0.485, 0.456, 0.406), + normalize_std: Tuple[float, float, float] = (0.229, 0.224, 0.225) + ): + """ + Initialize advanced augmentation pipeline. 
+ + Args: + mosaic_prob: Probability of mosaic augmentation + mixup_prob: Probability of mixup augmentation + copypaste_prob: Probability of copy-paste augmentation + color_jitter: Color jitter strength + gaussian_blur: Gaussian blur probability + normalize_mean: Normalization mean values + normalize_std: Normalization std values + """ + self.mosaic = MosaicAugmentation(prob=mosaic_prob) + self.mixup = MixUpAugmentation(prob=mixup_prob) + self.copypaste = CopyPasteAugmentation(prob=copypaste_prob) + + self.color_jitter = T.ColorJitter( + brightness=color_jitter, + contrast=color_jitter, + saturation=color_jitter, + hue=color_jitter * 0.1 + ) + + self.gaussian_blur_prob = gaussian_blur + self.normalize = T.Normalize(mean=normalize_mean, std=normalize_std) + + def __call__( + self, + image: Image.Image, + target: Dict, + additional_images: Optional[List[Image.Image]] = None, + additional_targets: Optional[List[Dict]] = None + ) -> Tuple[torch.Tensor, Dict]: + """ + Apply advanced augmentation pipeline. 
+
+        Args:
+            image: Input image
+            target: Target dictionary
+            additional_images: Additional images for complex augmentations
+            additional_targets: Additional targets for complex augmentations
+
+        Returns:
+            augmented_tensor: Augmented image tensor
+            augmented_target: Augmented target dictionary
+        """
+        # Apply color jitter
+        if random.random() < 0.8:
+            image = self.color_jitter(image)
+
+        # Apply Gaussian blur
+        if random.random() < self.gaussian_blur_prob:
+            image = image.filter(ImageFilter.GaussianBlur(radius=random.uniform(0.1, 2.0)))
+
+        # Apply complex augmentations if additional data is available
+        if additional_images and additional_targets:
+            # Try mosaic augmentation
+            if len(additional_images) >= 3:
+                mosaic_images = [image] + additional_images[:3]
+                mosaic_targets = [target] + additional_targets[:3]
+                image, target = self.mosaic(mosaic_images, mosaic_targets)
+
+            # Coin flip between mixup and copy-paste: the original duplicate
+            elif random.random() < 0.5:
+                image, target = self.mixup(image, target, additional_images[0], additional_targets[0])
+
+            # `elif len(...) >= 1` condition made this branch unreachable.
+            else:
+                image, target = self.copypaste(image, target, additional_images[0], additional_targets[0])
+
+        # Convert to tensor and normalize
+        tensor = F.to_tensor(image)
+        tensor = self.normalize(tensor)
+
+        return tensor, target
+
+
+def create_advanced_transforms(
+    train: bool = True,
+    image_size: Tuple[int, int] = (640, 640),
+    **augmentation_kwargs
+) -> T.Compose:
+    """
+    Create advanced transformation pipeline.
+ + Args: + train: Whether to apply training augmentations + image_size: Target image size + **augmentation_kwargs: Additional augmentation parameters + + Returns: + transforms: Composed transformation pipeline + """ + transforms_list = [] + + if train: + # Advanced augmentations + advanced_aug = AdvancedAugmentationPipeline(**augmentation_kwargs) + + def advanced_aug_wrapper(sample): + image = sample['image'] + target = sample['target'] + + # Handle additional images if available + additional_images = sample.get('additional_images') + additional_targets = sample.get('additional_targets') + + tensor, target = advanced_aug(image, target, additional_images, additional_targets) + + return { + 'image': tensor, + 'target': target + } + + transforms_list.append(advanced_aug_wrapper) + + # Resize + transforms_list.append(T.Resize(image_size)) + + # Convert to tensor (if not already done) + if not train: + transforms_list.append(T.ToTensor()) + transforms_list.append(T.Normalize( + mean=(0.485, 0.456, 0.406), + std=(0.229, 0.224, 0.225) + )) + + return T.Compose(transforms_list) diff --git a/rfdetr/models/enhanced_segmentation_head.py b/rfdetr/models/enhanced_segmentation_head.py new file mode 100644 index 000000000..836c7124d --- /dev/null +++ b/rfdetr/models/enhanced_segmentation_head.py @@ -0,0 +1,378 @@ +# ------------------------------------------------------------------------ +# RF-DETR +# Copyright (c) 2025 Roboflow. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------ + +""" +Enhanced Segmentation Head with Mask Quality Scoring +This module implements an improved segmentation head with mask quality prediction +and dynamic refinement for better instance segmentation performance. 
+"""
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from typing import List, Tuple, Optional
+
+
+class MaskQualityPredictor(nn.Module):
+    """
+    Predict mask quality scores for better mask ranking and selection.
+    """
+
+    def __init__(self, feature_dim: int = 256, hidden_dim: int = 128):
+        super().__init__()
+        # Fixed-size projection for (x1, y1, x2, y2) box features so box
+        # geometry can be fused into the mask-feature space up front,
+        # instead of resizing the MLP's first layer at forward time.
+        self.bbox_proj = nn.Linear(4, feature_dim)
+        self.predictor = nn.Sequential(
+            nn.Linear(feature_dim, hidden_dim),
+            nn.ReLU(inplace=True),
+            nn.Dropout(0.1),
+            nn.Linear(hidden_dim, hidden_dim // 2),
+            nn.ReLU(inplace=True),
+            nn.Dropout(0.1),
+            nn.Linear(hidden_dim // 2, 1),
+            nn.Sigmoid()
+        )
+
+    def forward(self, mask_features: torch.Tensor, bbox_features: Optional[torch.Tensor] = None) -> torch.Tensor:
+        """
+        Predict mask quality scores.
+
+        Args:
+            mask_features: Mask features (B, N, C)
+            bbox_features: Optional bbox features (B, N, 4)
+
+        Returns:
+            quality_scores: Quality scores (B, N, 1)
+        """
+        if bbox_features is not None:
+            # Fuse box geometry additively. The previous implementation
+            # recreated self.predictor[0] inside forward(), which discarded
+            # any trained weights and broke device/dtype placement.
+            mask_features = mask_features + self.bbox_proj(bbox_features)
+        return self.predictor(mask_features)
+
+
+class DynamicMaskRefiner(nn.Module):
+    """
+    Dynamically refine masks using attention mechanisms.
+ """ + + def __init__(self, feature_dim: int = 256, num_heads: int = 8): + super().__init__() + self.feature_dim = feature_dim + self.num_heads = num_heads + self.head_dim = feature_dim // num_heads + + assert feature_dim % num_heads == 0, "feature_dim must be divisible by num_heads" + + self.query_proj = nn.Linear(feature_dim, feature_dim) + self.key_proj = nn.Linear(feature_dim, feature_dim) + self.value_proj = nn.Linear(feature_dim, feature_dim) + self.out_proj = nn.Linear(feature_dim, feature_dim) + + self.norm1 = nn.LayerNorm(feature_dim) + self.norm2 = nn.LayerNorm(feature_dim) + + self.ffn = nn.Sequential( + nn.Linear(feature_dim, feature_dim * 4), + nn.ReLU(inplace=True), + nn.Dropout(0.1), + nn.Linear(feature_dim * 4, feature_dim) + ) + + def forward(self, query_features: torch.Tensor, context_features: torch.Tensor) -> torch.Tensor: + """ + Refine mask features using attention. + + Args: + query_features: Query features (B, N, C) + context_features: Context features (B, H*W, C) + + Returns: + refined_features: Refined features (B, N, C) + """ + B, N, C = query_features.shape + H_W = context_features.shape[1] + + # Multi-head attention + q = self.query_proj(query_features) # (B, N, C) + k = self.key_proj(context_features) # (B, H_W, C) + v = self.value_proj(context_features) # (B, H_W, C) + + # Reshape for multi-head attention + q = q.view(B, N, self.num_heads, self.head_dim).transpose(1, 2) # (B, num_heads, N, head_dim) + k = k.view(B, H_W, self.num_heads, self.head_dim).transpose(1, 2) # (B, num_heads, H_W, head_dim) + v = v.view(B, H_W, self.num_heads, self.head_dim).transpose(1, 2) # (B, num_heads, H_W, head_dim) + + # Attention + attn_scores = torch.matmul(q, k.transpose(-2, -1)) / (self.head_dim ** 0.5) # (B, num_heads, N, H_W) + attn_weights = F.softmax(attn_scores, dim=-1) + + attn_output = torch.matmul(attn_weights, v) # (B, num_heads, N, head_dim) + attn_output = attn_output.transpose(1, 2).contiguous().view(B, N, C) # (B, N, C) + + # Output 
projection + refined = self.out_proj(attn_output) + + # Residual connection and layer norm + refined = self.norm1(refined + query_features) + + # Feed-forward network + refined = self.norm2(refined + self.ffn(refined)) + + return refined + + +class EnhancedSegmentationHead(nn.Module): + """ + Enhanced segmentation head with mask quality scoring and dynamic refinement. + """ + + def __init__( + self, + feature_dim: int = 256, + num_layers: int = 3, + downsample_ratio: int = 4, + use_quality_prediction: bool = True, + use_dynamic_refinement: bool = True, + num_refinement_heads: int = 8 + ): + super().__init__() + self.feature_dim = feature_dim + self.num_layers = num_layers + self.downsample_ratio = downsample_ratio + self.use_quality_prediction = use_quality_prediction + self.use_dynamic_refinement = use_dynamic_refinement + + # Mask quality predictor + if use_quality_prediction: + self.quality_predictor = MaskQualityPredictor(feature_dim) + + # Dynamic mask refiner + if use_dynamic_refinement: + self.mask_refiner = DynamicMaskRefiner(feature_dim, num_refinement_heads) + + # Enhanced mask generation layers + self.mask_layers = nn.ModuleList([ + nn.Sequential( + nn.Conv2d(feature_dim, feature_dim, 3, padding=1), + nn.GroupNorm(8, feature_dim), + nn.ReLU(inplace=True), + nn.Conv2d(feature_dim, feature_dim, 3, padding=1), + nn.GroupNorm(8, feature_dim), + nn.ReLU(inplace=True) + ) for _ in range(num_layers) + ]) + + # Feature projection for mask generation + self.spatial_proj = nn.Conv2d(feature_dim, feature_dim, 1) + self.query_proj = nn.Linear(feature_dim, feature_dim) + + # Final mask prediction + self.mask_predictor = nn.Conv2d(feature_dim, 1, 1) + + self._export = False + + def export(self): + """Export mode for deployment.""" + self._export = True + self._forward_origin = self.forward + self.forward = self.forward_export + + def forward( + self, + spatial_features: torch.Tensor, + query_features: List[torch.Tensor], + image_size: Tuple[int, int], + bbox_features: 
Optional[torch.Tensor] = None + ) -> Tuple[List[torch.Tensor], Optional[torch.Tensor]]: + """ + Forward pass of enhanced segmentation head. + + Args: + spatial_features: Spatial features (B, C, H, W) + query_features: Query features from decoder layers [(B, N, C)] + image_size: Original image size (H, W) + bbox_features: Optional bbox features for quality prediction (B, N, 4) + + Returns: + mask_logits: List of mask logits [(B, N, H', W')] + quality_scores: Optional quality scores (B, N, 1) + """ + target_size = (image_size[0] // self.downsample_ratio, image_size[1] // self.downsample_ratio) + spatial_features = F.interpolate(spatial_features, size=target_size, mode='bilinear', align_corners=False) + + # Apply enhanced mask layers + refined_spatial = spatial_features + for layer in self.mask_layers: + refined_spatial = layer(refined_spatial) + + # Project spatial features + spatial_proj = self.spatial_proj(refined_spatial) + + # Process query features + mask_logits = [] + quality_scores = None + + for i, qf in enumerate(query_features): + # Project query features + qf_proj = self.query_proj(qf) + + # Apply dynamic refinement if enabled and not the first layer + if self.use_dynamic_refinement and i > 0: + # Flatten spatial features for attention + B, C, H, W = spatial_proj.shape + spatial_flat = spatial_proj.view(B, C, H * W).transpose(1, 2) # (B, H*W, C) + + # Refine query features + qf_refined = self.mask_refiner(qf_proj, spatial_flat) + else: + qf_refined = qf_proj + + # Generate mask logits + mask_logit = torch.einsum('bchw,bnc->bnhw', spatial_proj, qf_refined) + mask_logits.append(mask_logit) + + # Predict mask quality if enabled + if self.use_quality_prediction and bbox_features is not None: + # Use the last layer's query features for quality prediction + quality_scores = self.quality_predictor(query_features[-1], bbox_features) + + return mask_logits, quality_scores + + def forward_export( + self, + spatial_features: torch.Tensor, + query_features: 
List[torch.Tensor], + image_size: Tuple[int, int], + bbox_features: Optional[torch.Tensor] = None + ) -> Tuple[List[torch.Tensor], Optional[torch.Tensor]]: + """ + Export-friendly forward pass. + """ + # Simplified version for export - only processes first query feature + target_size = (image_size[0] // self.downsample_ratio, image_size[1] // self.downsample_ratio) + spatial_features = F.interpolate(spatial_features, size=target_size, mode='bilinear', align_corners=False) + + # Apply enhanced mask layers + refined_spatial = spatial_features + for layer in self.mask_layers: + refined_spatial = layer(refined_spatial) + + # Project spatial features + spatial_proj = self.spatial_proj(refined_spatial) + + # Process only the first query feature + qf = self.query_proj(query_features[0]) + + # Generate mask logits + mask_logit = torch.einsum('bchw,bnc->bnhw', spatial_proj, qf) + + # Predict mask quality if enabled + quality_scores = None + if self.use_quality_prediction and bbox_features is not None: + quality_scores = self.quality_predictor(qf, bbox_features) + + return [mask_logit], quality_scores + + +class AdaptiveMaskLoss(nn.Module): + """ + Adaptive mask loss that considers mask quality. + """ + + def __init__(self, dice_weight: float = 1.0, ce_weight: float = 1.0, quality_weight: float = 0.1): + super().__init__() + self.dice_weight = dice_weight + self.ce_weight = ce_weight + self.quality_weight = quality_weight + + def dice_loss(self, inputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor: + """ + Compute dice loss. + """ + inputs = inputs.sigmoid() + inputs = inputs.flatten(1) + targets = targets.flatten(1) + + intersection = (inputs * targets).sum(dim=1) + cardinality = (inputs + targets).sum(dim=1) + + dice_loss = 1 - (2. * intersection + 1e-6) / (cardinality + 1e-6) + return dice_loss.mean() + + def ce_loss(self, inputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor: + """ + Compute cross-entropy loss. 
+ """ + return F.binary_cross_entropy_with_logits(inputs, targets) + + def quality_loss(self, quality_scores: torch.Tensor, dice_scores: torch.Tensor) -> torch.Tensor: + """ + Compute quality prediction loss. + """ + return F.mse_loss(quality_scores.squeeze(-1), dice_scores) + + def forward( + self, + mask_logits: List[torch.Tensor], + targets: torch.Tensor, + quality_scores: Optional[torch.Tensor] = None + ) -> Tuple[torch.Tensor, dict]: + """ + Compute adaptive mask loss. + + Args: + mask_logits: List of mask logits [(B, N, H, W)] + targets: Target masks (B, N, H, W) + quality_scores: Optional quality scores (B, N, 1) + + Returns: + total_loss: Combined loss + loss_dict: Individual loss components + """ + losses = {} + + # Use the last layer's predictions + pred_masks = mask_logits[-1] + + # Dice loss + dice_loss = self.dice_loss(pred_masks, targets) + losses['dice_loss'] = dice_loss + + # Cross-entropy loss + ce_loss = self.ce_loss(pred_masks, targets) + losses['ce_loss'] = ce_loss + + # Quality loss if quality scores are available + quality_loss = torch.tensor(0.0, device=pred_masks.device) + if quality_scores is not None: + # Calculate dice scores for each mask + with torch.no_grad(): + pred_masks_sigmoid = pred_masks.sigmoid() + dice_scores = [] + for i in range(pred_masks.shape[1]): # iterate over masks + pred_mask = pred_masks_sigmoid[:, i].flatten(1) + target_mask = targets[:, i].flatten(1) + intersection = (pred_mask * target_mask).sum(dim=1) + cardinality = (pred_mask + target_mask).sum(dim=1) + dice_score = (2. 
* intersection + 1e-6) / (cardinality + 1e-6) + dice_scores.append(dice_score) + dice_scores = torch.stack(dice_scores, dim=1) # (B, N) + + quality_loss = self.quality_loss(quality_scores, dice_scores) + losses['quality_loss'] = quality_loss + + # Combine losses + total_loss = ( + self.dice_weight * dice_loss + + self.ce_weight * ce_loss + + self.quality_weight * quality_loss + ) + + return total_loss, losses diff --git a/rfdetr/models/iou_aware_query_selector.py b/rfdetr/models/iou_aware_query_selector.py new file mode 100644 index 000000000..c0da4706f --- /dev/null +++ b/rfdetr/models/iou_aware_query_selector.py @@ -0,0 +1,318 @@ +# ------------------------------------------------------------------------ +# RF-DETR +# Copyright (c) 2025 Roboflow. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------ + +""" +IoU-aware Query Selection for RF-DETR +This module implements IoU-aware query selection inspired by RT-DETR v2 +to improve object query initialization and focus on relevant objects. +""" + +import torch +import torch.nn as nn +import torch.nn.functional as F +from typing import Optional, Tuple + + +class IoUAwareQuerySelector(nn.Module): + """ + IoU-aware query selection mechanism for improved object query initialization. + + This module selects the most relevant queries based on IoU scores between + predicted boxes and reference points, improving detection accuracy especially + for small objects and reducing false positives. + """ + + def __init__( + self, + d_model: int = 256, + num_queries: int = 300, + num_layers: int = 3, + dropout: float = 0.1, + use_aware_score: bool = True, + alpha: float = 0.5 # balance between classification and localization + ): + """ + Initialize IoU-aware query selector. 
+ + Args: + d_model: Feature dimension + num_queries: Number of object queries + num_layers: Number of MLP layers + dropout: Dropout rate + use_aware_score: Whether to use IoU-aware scoring + alpha: Balance factor between classification and localization scores + """ + super().__init__() + self.d_model = d_model + self.num_queries = num_queries + self.use_aware_score = use_aware_score + self.alpha = alpha + + # Classification confidence prediction + self.classifier = nn.Sequential( + nn.Linear(d_model, d_model), + nn.ReLU(), + nn.Dropout(dropout), + nn.Linear(d_model, 1) # binary classification (object vs background) + ) + + # Bounding box regression + self.bbox_regressor = nn.Sequential( + nn.Linear(d_model, d_model), + nn.ReLU(), + nn.Dropout(dropout), + nn.Linear(d_model, 4) # (cx, cy, w, h) + ) + + # IoU prediction network + self.iou_predictor = nn.Sequential( + nn.Linear(d_model + 4, d_model), # features + bbox coords + nn.ReLU(), + nn.Dropout(dropout), + nn.Linear(d_model, 1), + nn.Sigmoid() + ) + + # Query feature enhancement + self.query_enhancer = nn.Sequential( + nn.Linear(d_model, d_model), + nn.LayerNorm(d_model), + nn.ReLU(), + nn.Dropout(dropout), + nn.Linear(d_model, d_model) + ) + + self._init_weights() + + def _init_weights(self): + """Initialize weights""" + for m in self.modules(): + if isinstance(m, nn.Linear): + nn.init.xavier_uniform_(m.weight) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + def forward( + self, + memory: torch.Tensor, + spatial_shapes: torch.Tensor, + level_start_index: torch.Tensor, + reference_points: Optional[torch.Tensor] = None + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Forward pass for IoU-aware query selection. 
+ + Args: + memory: Feature memory from encoder (bs, N, d_model) + spatial_shapes: Spatial shapes of feature levels (num_levels, 2) + level_start_index: Start index for each level (num_levels,) + reference_points: Reference points for queries (bs, num_queries, 4) + + Returns: + selected_features: Selected query features (bs, num_queries, d_model) + selection_scores: Selection scores (bs, N, 1) + """ + bs, N, d_model = memory.shape + + # Predict classification confidence + cls_scores = self.classifier(memory) # (bs, N, 1) + cls_scores = torch.sigmoid(cls_scores) + + # Predict bounding boxes + bbox_deltas = self.bbox_regressor(memory) # (bs, N, 4) + bbox_deltas = torch.sigmoid(bbox_deltas) # normalize to [0, 1] + + # Calculate IoU-aware scores + if self.use_aware_score and reference_points is not None: + # Use only the first num_queries reference points for memory positions + if reference_points.shape[1] < N: + # If we have fewer reference points than memory positions, repeat the reference points + repeat_times = (N + reference_points.shape[1] - 1) // reference_points.shape[1] + ref_points_expanded = reference_points.repeat(1, repeat_times, 1)[:, :N, :] + else: + # Use the first N reference points + ref_points_expanded = reference_points[:, :N, :] + + # Calculate IoU between predicted boxes and reference points + iou_scores = self._calculate_iou_aware_score( + bbox_deltas, ref_points_expanded, memory + ) + + # Combine classification and IoU scores + combined_scores = ( + self.alpha * cls_scores + + (1 - self.alpha) * iou_scores + ) + else: + combined_scores = cls_scores + + # Select top-K queries + top_k_scores, top_k_indices = torch.topk( + combined_scores.squeeze(-1), + k=min(self.num_queries, N), + dim=-1 + ) # (bs, num_queries) + + # Gather selected features + selected_features = torch.gather( + memory, + 1, + top_k_indices.unsqueeze(-1).expand(-1, -1, d_model) + ) # (bs, num_queries, d_model) + + # Enhance selected query features + enhanced_features = 
self.query_enhancer(selected_features) + enhanced_features = enhanced_features + selected_features # residual connection + + return enhanced_features, top_k_scores.unsqueeze(-1) + + def _calculate_iou_aware_score( + self, + pred_boxes: torch.Tensor, + ref_boxes: torch.Tensor, + features: torch.Tensor + ) -> torch.Tensor: + """ + Calculate IoU-aware scores using predicted boxes and reference points. + + Args: + pred_boxes: Predicted boxes (bs, N, 4) in (cx, cy, w, h) format + ref_boxes: Reference boxes (bs, N, 4) in (cx, cy, w, h) format + features: Feature vectors (bs, N, d_model) + + Returns: + iou_scores: IoU-aware scores (bs, N, 1) + """ + # Convert to (x1, y1, x2, y2) format for IoU calculation + pred_boxes_xyxy = self._cxcywh_to_xyxy(pred_boxes) + ref_boxes_xyxy = self._cxcywh_to_xyxy(ref_boxes) + + # Calculate IoU + iou = self._calculate_iou(pred_boxes_xyxy, ref_boxes_xyxy) # (bs, N) + + # Predict IoU using features and box coordinates + box_features = torch.cat([features, pred_boxes], dim=-1) # (bs, N, d_model + 4) + predicted_iou = self.iou_predictor(box_features).squeeze(-1) # (bs, N) + + # Combine geometric IoU with predicted IoU + combined_iou = 0.7 * iou + 0.3 * predicted_iou + + return combined_iou.unsqueeze(-1) + + def _cxcywh_to_xyxy(self, boxes: torch.Tensor) -> torch.Tensor: + """Convert from (cx, cy, w, h) to (x1, y1, x2, y2) format""" + cx, cy, w, h = boxes.unbind(-1) + x1 = cx - 0.5 * w + y1 = cy - 0.5 * h + x2 = cx + 0.5 * w + y2 = cy + 0.5 * h + return torch.stack([x1, y1, x2, y2], dim=-1) + + def _calculate_iou( + self, + boxes1: torch.Tensor, + boxes2: torch.Tensor + ) -> torch.Tensor: + """ + Calculate IoU between two sets of boxes. 
+
+        Args:
+            boxes1: First set of boxes (bs, N, 4) in (x1, y1, x2, y2) format
+            boxes2: Second set of boxes (bs, N, 4) in (x1, y1, x2, y2) format
+
+        Returns:
+            iou: IoU values (bs, N)
+        """
+        # Intersection
+        inter_x1 = torch.max(boxes1[..., 0], boxes2[..., 0])
+        inter_y1 = torch.max(boxes1[..., 1], boxes2[..., 1])
+        inter_x2 = torch.min(boxes1[..., 2], boxes2[..., 2])
+        inter_y2 = torch.min(boxes1[..., 3], boxes2[..., 3])
+
+        inter_area = torch.clamp(inter_x2 - inter_x1, min=0) * \
+                     torch.clamp(inter_y2 - inter_y1, min=0)
+
+        # Union
+        area1 = (boxes1[..., 2] - boxes1[..., 0]) * (boxes1[..., 3] - boxes1[..., 1])
+        area2 = (boxes2[..., 2] - boxes2[..., 0]) * (boxes2[..., 3] - boxes2[..., 1])
+        union_area = area1 + area2 - inter_area
+
+        # IoU
+        iou = inter_area / torch.clamp(union_area, min=1e-7)
+
+        return iou
+
+
+class AdaptiveQueryAllocator(nn.Module):
+    """
+    Adaptive query allocation based on image complexity.
+
+    This module dynamically adjusts the number of queries based on
+    image complexity to improve efficiency and accuracy.
+    """
+
+    def __init__(
+        self,
+        base_queries: int = 300,
+        max_queries: int = 600,
+        min_queries: int = 100,
+        complexity_threshold: float = 0.5
+    ):
+        """
+        Initialize adaptive query allocator.
+
+        Args:
+            base_queries: Base number of queries
+            max_queries: Maximum number of queries
+            min_queries: Minimum number of queries
+            complexity_threshold: Threshold for increasing queries
+        """
+        super().__init__()
+        self.base_queries = base_queries
+        self.max_queries = max_queries
+        self.min_queries = min_queries
+        self.complexity_threshold = complexity_threshold
+
+        # Complexity estimator
+        self.complexity_estimator = nn.Sequential(
+            nn.Linear(256, 128),  # NOTE(review): hard-codes d_model=256 -- breaks if hidden_dim differs; pass d_model into __init__
+            nn.ReLU(),
+            nn.Linear(128, 1),
+            nn.Sigmoid()
+        )
+
+    def forward(self, memory: torch.Tensor) -> int:
+        """
+        Estimate image complexity and determine optimal number of queries.
+ + Args: + memory: Feature memory (bs, N, d_model) + + Returns: + num_queries: Optimal number of queries + """ + # Average pool features to get global representation + global_features = memory.mean(dim=1) # (bs, d_model) + + # Estimate complexity + complexity = self.complexity_estimator(global_features).mean() # scalar + + # Adaptive query allocation + if complexity > self.complexity_threshold: + num_queries = min( + int(self.base_queries * (1 + complexity)), + self.max_queries + ) + else: + num_queries = max( + int(self.base_queries * complexity), + self.min_queries + ) + + return num_queries diff --git a/rfdetr/models/lwdetr.py b/rfdetr/models/lwdetr.py index 95ae7e46b..f0ac6153d 100644 --- a/rfdetr/models/lwdetr.py +++ b/rfdetr/models/lwdetr.py @@ -35,6 +35,7 @@ from rfdetr.models.matcher import build_matcher from rfdetr.models.transformer import build_transformer from rfdetr.models.segmentation_head import SegmentationHead, get_uncertain_point_coords_with_randomness, point_sample +from rfdetr.models.enhanced_segmentation_head import EnhancedSegmentationHead, AdaptiveMaskLoss class LWDETR(nn.Module): """ This is the Group DETR v3 module that performs object detection """ @@ -812,7 +813,17 @@ def build_model(args): args.num_feature_levels = len(args.projector_scale) transformer = build_transformer(args) - segmentation_head = SegmentationHead(args.hidden_dim, args.dec_layers, downsample_ratio=args.mask_downsample_ratio) if args.segmentation_head else None + # Build enhanced segmentation head if enabled + if args.segmentation_head and getattr(args, 'enhanced_segmentation', False): + segmentation_head = EnhancedSegmentationHead( + feature_dim=args.hidden_dim, + num_layers=args.dec_layers, + downsample_ratio=args.mask_downsample_ratio, + use_quality_prediction=getattr(args, 'mask_quality_prediction', True), + use_dynamic_refinement=getattr(args, 'dynamic_mask_refinement', True) + ) + else: + segmentation_head = SegmentationHead(args.hidden_dim, args.dec_layers, 
downsample_ratio=args.mask_downsample_ratio) if args.segmentation_head else None model = LWDETR( backbone, diff --git a/rfdetr/models/transformer.py b/rfdetr/models/transformer.py index 343c6fcdc..3a2270426 100644 --- a/rfdetr/models/transformer.py +++ b/rfdetr/models/transformer.py @@ -24,6 +24,7 @@ from torch import nn, Tensor from rfdetr.models.ops.modules import MSDeformAttn +from rfdetr.models.iou_aware_query_selector import IoUAwareQuerySelector, AdaptiveQueryAllocator class MLP(nn.Module): """ Very simple multi-layer perceptron (also called FFN)""" @@ -136,7 +137,9 @@ def __init__(self, d_model=512, sa_nhead=8, ca_nhead=8, num_queries=300, num_feature_levels=4, dec_n_points=4, lite_refpoint_refine=False, decoder_norm_type='LN', - bbox_reparam=False): + bbox_reparam=False, + use_iou_aware_query=False, + adaptive_query_allocation=False): super().__init__() self.encoder = None @@ -165,6 +168,22 @@ def __init__(self, d_model=512, sa_nhead=8, ca_nhead=8, num_queries=300, self.enc_output = nn.ModuleList([nn.Linear(d_model, d_model) for _ in range(group_detr)]) self.enc_output_norm = nn.ModuleList([nn.LayerNorm(d_model) for _ in range(group_detr)]) + # Initialize IoU-aware query selector + self.use_iou_aware_query = use_iou_aware_query + self.adaptive_query_allocation = adaptive_query_allocation + + if use_iou_aware_query: + self.iou_query_selector = IoUAwareQuerySelector( + d_model=d_model, + num_queries=num_queries, + dropout=dropout + ) + + if adaptive_query_allocation: + self.adaptive_allocator = AdaptiveQueryAllocator( + base_queries=num_queries + ) + self._reset_parameters() self.num_queries = num_queries @@ -264,6 +283,45 @@ def forward(self, srcs, masks, pos_embeds, refpoint_embed, query_feat): memory_ts = torch.cat(memory_ts, dim=1)#.transpose(0, 1) boxes_ts = torch.cat(boxes_ts, dim=1)#.transpose(0, 1) + # Apply adaptive query allocation if enabled + if self.adaptive_query_allocation and not self.two_stage: + # Dynamically adjust number of queries based 
on image complexity
+            adaptive_num_queries = self.adaptive_allocator(memory)
+            if adaptive_num_queries != self.num_queries:
+                # Adjust query embeddings based on adaptive allocation
+                if adaptive_num_queries < self.num_queries:
+                    # Reduce queries
+                    query_feat = query_feat[:adaptive_num_queries]
+                    refpoint_embed = refpoint_embed[:adaptive_num_queries]
+                elif adaptive_num_queries > self.num_queries:
+                    # Increase queries (pad with learned embeddings)
+                    padding_size = adaptive_num_queries - self.num_queries
+                    query_padding = query_feat[:padding_size]  # reuse first few embeddings
+                    refpoint_padding = refpoint_embed[:padding_size]
+                    query_feat = torch.cat([query_feat, query_padding], dim=0)
+                    refpoint_embed = torch.cat([refpoint_embed, refpoint_padding], dim=0)
+
+        # Apply IoU-aware query selection if enabled
+        if self.use_iou_aware_query and not self.two_stage:
+            # Use IoU-aware query selection to improve query initialization
+            selected_memory, selection_scores = self.iou_query_selector(
+                memory=memory,
+                spatial_shapes=spatial_shapes,
+                level_start_index=level_start_index,
+                reference_points=refpoint_embed.unsqueeze(0).repeat(bs, 1, 1)
+            )
+
+            # Update memory with selected features
+            memory = torch.cat([memory, selected_memory], dim=1)
+
+            # Update spatial shapes and level start index
+            new_spatial_shape = torch.tensor([selected_memory.shape[1], 1], device=spatial_shapes.device)  # actual selected count: selector's topk returns min(num_queries, N) entries
+            spatial_shapes = torch.cat([spatial_shapes, new_spatial_shape.unsqueeze(0)], dim=0)
+            level_start_index = torch.cat([
+                level_start_index,
+                torch.tensor([spatial_shapes[:-1].prod(1).sum()], device=level_start_index.device)
+            ], dim=0)
+
         if self.dec_layers > 0:
             tgt = query_feat.unsqueeze(0).repeat(bs, 1, 1)
             refpoint_embed = refpoint_embed.unsqueeze(0).repeat(bs, 1, 1)
@@ -577,6 +635,8 @@
         lite_refpoint_refine=args.lite_refpoint_refine,
         decoder_norm_type=args.decoder_norm,
         bbox_reparam=args.bbox_reparam,
+        use_iou_aware_query=getattr(args, 'use_iou_aware_query', False),
+
adaptive_query_allocation=getattr(args, 'adaptive_query_allocation', False), )