diff --git a/rfdetr/config.py b/rfdetr/config.py index 12f7fd7b5..6efd9d0b4 100644 --- a/rfdetr/config.py +++ b/rfdetr/config.py @@ -4,14 +4,39 @@ # Licensed under the Apache License, Version 2.0 [see LICENSE for details] # ------------------------------------------------------------------------ +from pydantic import BaseModel, field_validator, model_validator +from pydantic_core.core_schema import ValidationInfo # for field_validator(info) +from typing import List, Optional, Literal +import os, torch -from pydantic import BaseModel -from typing import List, Optional, Literal, Type -import torch DEVICE = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu" +# centralize all supported encoder names (add dinov3). +EncoderName = Literal[ + "dinov2_windowed_small", + "dinov2_windowed_base", + "dinov3_small", + "dinov3_base", + "dinov3_large", +] + +def _encoder_default(): + """Default encoder name for the model config.""" + # default to v2 unless explicitly overridden by env + val = os.getenv("RFD_ENCODER", "").strip() or "dinov2_windowed_small" + + # guardrail: only accept known names + allowed = { + "dinov2_windowed_small","dinov2_windowed_base", + "dinov3_small","dinov3_base","dinov3_large" + } + return val if val in allowed else "dinov2_windowed_small" + class ModelConfig(BaseModel): - encoder: Literal["dinov2_windowed_small", "dinov2_windowed_base"] + """Base configuration for RF-DETR models.""" + # WAS: only dinov2_windowed_*; NOW: include dinov3_* as drop-in options + encoder: EncoderName = _encoder_default() + out_feature_indexes: List[int] dec_layers: int two_stage: bool = True @@ -33,15 +58,82 @@ class ModelConfig(BaseModel): group_detr: int = 13 gradient_checkpointing: bool = False positional_encoding_size: int + # used only when encoder startswith("dinov3") + dinov3_repo_dir: Optional[str] = None # e.g., r"D:\repos\dinov3" + dinov3_weights_path: Optional[str] = None # e.g., r"C:\models\dinov3-vitb16.pth" + 
dinov3_hf_token: Optional[str] = None # or rely on HUGGINGFACE_HUB_TOKEN + dinov3_prefer_hf: bool = True # try HF first, then hub fallback + # force /16 for v3 + @field_validator("patch_size", mode="after") + def _coerce_patch_for_dinov3(cls, v, info: ValidationInfo): + """Ensure patch size is 16 for DINOv3 encoders.""" + enc = str(info.data.get("encoder", "")) + return 16 if enc.startswith("dinov3") else v + + # keep pos-encoding grid consistent with resolution / patch + @field_validator("positional_encoding_size", mode="after") + def _sync_pos_enc_with_resolution(cls, v, info: ValidationInfo): + """Sync positional encoding size with resolution and patch size.""" + values = info.data or {} + res, ps = values.get("resolution"), values.get("patch_size") + return max(1, res // ps) if (res and ps) else v + + # env fallbacks for local repo/weights when *not* preferring HF + @field_validator("dinov3_repo_dir", "dinov3_weights_path", mode="after") + def _fallback_to_env(cls, v, info: ValidationInfo): + """Fallback to environment variables if not set.""" + values = info.data or {} + if (not v) and str(values.get("encoder","")).startswith("dinov3") and not values.get("dinov3_prefer_hf", True): + env_map = {"dinov3_repo_dir": "DINOV3_REPO", "dinov3_weights_path": "DINOV3_WEIGHTS"} + env_key = env_map[info.field_name] + return os.getenv(env_key, v) + return v + + # neutralize windowing for v3 (avoid accidental asserts downstream) + @field_validator("num_windows", mode="after") + def _neutralize_windows_for_dinov3(cls, v, info: ValidationInfo): + """Neutralize windowing for DINOv3 encoders.""" + enc = str((info.data or {}).get("encoder","")) + return 1 if enc.startswith("dinov3") else v + + # auto-fit out_feature_indexes to avoid projector shape mismatches + @field_validator("out_feature_indexes", mode="after") + def _coerce_out_feats_for_backbone(cls, v, info: ValidationInfo): + """Ensure out_feature_indexes are compatible with the encoder.""" + enc = str((info.data or 
{}).get("encoder","")) + if enc.startswith("dinov3"): + # DINOv3 path: default to fewer, stable high-level features + return v if len(v) in (2,) else [8, 11] + return v + + # Final safety net: once the whole model is built, normalize settings for DINOv3. + @model_validator(mode="after") + def _final_autofix_for_dinov3(self): + """Final adjustments after model construction.""" + enc = str(self.encoder) + if enc.startswith("dinov3"): + # enforce /16 patch + matching pos-enc grid + self.patch_size = 16 + if self.resolution: + self.positional_encoding_size = max(1, self.resolution // self.patch_size) + # windowing is a no-op for v3 + self.num_windows = 1 + # most important: use 2 high-level features to match projector weights across v2/v3 + if len(self.out_feature_indexes) != 2: + self.out_feature_indexes = [8, 11] + return self + class RFDETRBaseConfig(ModelConfig): """ The configuration for an RF-DETR Base model. """ - encoder: Literal["dinov2_windowed_small", "dinov2_windowed_base"] = "dinov2_windowed_small" + # Allow choosing dinov3_* without changing call sites + encoder: EncoderName = _encoder_default() + print("Using RFDETRBaseConfig with encoder:", encoder) hidden_dim: int = 256 - patch_size: int = 14 - num_windows: int = 4 + patch_size: int = 14 # will auto-become 16 if encoder startswith("dinov3") + num_windows: int = 4 # ignored by DINOv3 branch dec_layers: int = 3 sa_nheads: int = 8 ca_nheads: int = 16 @@ -49,16 +141,18 @@ class RFDETRBaseConfig(ModelConfig): num_queries: int = 300 num_select: int = 300 projector_scale: List[Literal["P3", "P4", "P5"]] = ["P4"] - out_feature_indexes: List[int] = [2, 5, 8, 11] + out_feature_indexes: List[int] = [2, 4, 5, 9] pretrain_weights: Optional[str] = "rf-detr-base.pth" - resolution: int = 560 - positional_encoding_size: int = 37 + #resolution: int = 504 # 560//16=35 when dinov3_* is used + resolution: int = 512 # 512//16=32 → pos grid auto=32 for both v2/v3 + positional_encoding_size: int = 36 # will auto-sync to 
resolution//patch_size + class RFDETRLargeConfig(RFDETRBaseConfig): """ The configuration for an RF-DETR Large model. """ - encoder: Literal["dinov2_windowed_small", "dinov2_windowed_base"] = "dinov2_windowed_base" + encoder: EncoderName = "dinov2_windowed_base" hidden_dim: int = 384 sa_nheads: int = 12 ca_nheads: int = 24 @@ -66,6 +160,7 @@ class RFDETRLargeConfig(RFDETRBaseConfig): projector_scale: List[Literal["P3", "P4", "P5"]] = ["P3", "P5"] pretrain_weights: Optional[str] = "rf-detr-large.pth" + class RFDETRNanoConfig(RFDETRBaseConfig): """ The configuration for an RF-DETR Nano model. @@ -74,10 +169,11 @@ class RFDETRNanoConfig(RFDETRBaseConfig): num_windows: int = 2 dec_layers: int = 2 patch_size: int = 16 - resolution: int = 384 + resolution: int = 384 # 384//16=24 → pos grid auto=24 for both v2/v3 positional_encoding_size: int = 24 pretrain_weights: Optional[str] = "rf-detr-nano.pth" + class RFDETRSmallConfig(RFDETRBaseConfig): """ The configuration for an RF-DETR Small model. @@ -86,10 +182,11 @@ class RFDETRSmallConfig(RFDETRBaseConfig): num_windows: int = 2 dec_layers: int = 3 patch_size: int = 16 - resolution: int = 512 + resolution: int = 512 # 512//16=32 → pos grid auto=32 positional_encoding_size: int = 32 pretrain_weights: Optional[str] = "rf-detr-small.pth" + class RFDETRMediumConfig(RFDETRBaseConfig): """ The configuration for an RF-DETR Medium model. 
@@ -98,10 +195,12 @@ class RFDETRMediumConfig(RFDETRBaseConfig): num_windows: int = 2 dec_layers: int = 4 patch_size: int = 16 - resolution: int = 576 + #resolution: int = 504 # 576//16=36 → pos grid auto=36 + resolution: int = 512 positional_encoding_size: int = 36 pretrain_weights: Optional[str] = "rf-detr-medium.pth" + class TrainConfig(BaseModel): lr: float = 1e-4 lr_encoder: float = 1.5e-4 diff --git a/rfdetr/engine.py b/rfdetr/engine.py index 31e68cae3..57822d44f 100644 --- a/rfdetr/engine.py +++ b/rfdetr/engine.py @@ -21,7 +21,6 @@ import sys from typing import Iterable import random - import torch import torch.nn.functional as F @@ -39,12 +38,16 @@ from rfdetr.util.misc import NestedTensor import numpy as np + def get_autocast_args(args): + """Return autocast arguments based on the DEPRECATED_AMP flag and args.""" + use_cuda = torch.cuda.is_available() + enabled = bool(getattr(args, "amp", False) and use_cuda) if DEPRECATED_AMP: - return {'enabled': args.amp, 'dtype': torch.bfloat16} + return {"enabled": enabled, "dtype": torch.bfloat16} else: - return {'device_type': 'cuda', 'enabled': args.amp, 'dtype': torch.bfloat16} - + # only use CUDA autocast when CUDA exists + return {"device_type": "cuda", "enabled": enabled, "dtype": torch.bfloat16} def train_one_epoch( model: torch.nn.Module, @@ -75,11 +78,11 @@ def train_one_epoch( print("Grad accum steps: ", args.grad_accum_steps) print("Total batch size: ", batch_size * utils.get_world_size()) - # Add gradient scaler for AMP + use_amp = bool(getattr(args, "amp", False) and torch.cuda.is_available()) if DEPRECATED_AMP: - scaler = GradScaler(enabled=args.amp) + scaler = GradScaler(enabled=use_amp) else: - scaler = GradScaler('cuda', enabled=args.amp) + scaler = GradScaler("cuda", enabled=use_amp) optimizer.zero_grad() assert batch_size % args.grad_accum_steps == 0 @@ -113,7 +116,9 @@ def train_one_epoch( scales = compute_multi_scale_scales(args.resolution, args.expanded_scales, args.patch_size, args.num_windows) 
random.seed(it) scale = random.choice(scales) - with torch.inference_mode(): + # DO NOT use torch.inference_mode() here; it creates inference tensors that + # are incompatible with subsequent training operations, so use torch.no_grad(). + with torch.no_grad(): samples.tensors = F.interpolate(samples.tensors, size=scale, mode='bilinear', align_corners=False) samples.mask = F.interpolate(samples.mask.unsqueeze(1).float(), size=scale, mode='nearest').squeeze(1).bool() @@ -124,16 +129,17 @@ def train_one_epoch( new_samples = NestedTensor(new_samples_tensors, samples.mask[start_idx:final_idx]) new_samples = new_samples.to(device) new_targets = [{k: v.to(device) for k, v in t.items()} for t in targets[start_idx:final_idx]] - - with autocast(**get_autocast_args(args)): - outputs = model(new_samples, new_targets) - loss_dict = criterion(outputs, new_targets) - weight_dict = criterion.weight_dict - losses = sum( - (1 / args.grad_accum_steps) * loss_dict[k] * weight_dict[k] - for k in loss_dict.keys() - if k in weight_dict - ) + torch.set_grad_enabled(True) # safety + with torch.inference_mode(False): + with autocast(**get_autocast_args(args)): + outputs = model(new_samples, new_targets) + loss_dict = criterion(outputs, new_targets) + weight_dict = criterion.weight_dict + losses = sum( + (1 / args.grad_accum_steps) * loss_dict[k] * weight_dict[k] + for k in loss_dict.keys() + if k in weight_dict + ) scaler.scale(losses).backward() diff --git a/rfdetr/inference_test.py b/rfdetr/inference_test.py new file mode 100644 index 000000000..d790951ea --- /dev/null +++ b/rfdetr/inference_test.py @@ -0,0 +1,118 @@ + +# ------------------------------------------------------------------------ +# RF-DETR +# Copyright (c) 2025 Roboflow. All Rights Reserved. 
+# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------ +# Modified from LW-DETR (https://github.com/Atten4Vis/LW-DETR) +# Copyright (c) 2024 Baidu. All Rights Reserved. +# ------------------------------------------------------------------------ +# Modified from Conditional DETR (https://github.com/Atten4Vis/ConditionalDETR) +# Copyright (c) 2021 Microsoft. All Rights Reserved. +# ------------------------------------------------------------------------ +# Modified from DETR (https://github.com/facebookresearch/detr) +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +# ------------------------------------------------------------------------ + +""" Demo Inference script for RF-DETR with easy switch between DINOv2 and DINOv3(local repo).""" +import os +import io +import requests +import supervision as sv +from PIL import Image + +ENCODER_ALIASES = { + "dinov2": "dinov2_windowed_small", + "v2": "dinov2_windowed_small", + "dinov2_small": "dinov2_windowed_small", + "dinov2_base": "dinov2_windowed_base", + "dinov3": "dinov3_base", + "v3": "dinov3_base", +} + +VALID_ENCODERS = { + "dinov2_windowed_small", + "dinov2_windowed_base", + "dinov3_small", + "dinov3_base", + "dinov3_large", +} + +def resolve_encoder(enc_str: str) -> str: + """Resolve the encoder string to a valid encoder name. + Args: + enc_str (str): The encoder string to resolve. + + Returns: + str: The resolved encoder name. + Examples: + resolve_encoder("v2") # returns "dinov2_windowed_small" + """ + enc_str = enc_str.strip().lower() + enc = ENCODER_ALIASES.get(enc_str, enc_str) + if enc not in VALID_ENCODERS: + raise ValueError(f"Unknown encoder '{enc_str}'. Valid: {sorted(list(VALID_ENCODERS))}") + return enc + +## If using DINOv3, ensure you have the local repo and weights set up. 
+def main(encoder:str = "v3", repo_dir: str = "./dinov3_repo", dino_v_weights_path: str = "./dinov3_weights.pth"): + """Main function to run the inference demo. + + Args: + encoder (str): The encoder to use, e.g., "v2", "v3", or exact name. + repo_dir (str): Path to the local DINOv3 repository. + dino_v_weights_path (str): Path to the DINOv3 weights file. + + Returns: + None + + Examples: + main(encoder="v2") + main(encoder="v3", repo_dir="D:/repos/dinov3", dino_v_weights_path="D:/repos/dinov3/dinov3_vitb16_pretrain_lvd1689m-73cec8be.pth") + """ + encoder = resolve_encoder(encoder) + + if encoder.startswith("dinov3"): + print("Using DINOv3 encoder:", encoder) + # Set the environment variables for DINOv3 repo and weights + os.environ["DINOV3_REPO"] = repo_dir + os.environ["DINOV3_WEIGHTS"] = dino_v_weights_path + elif encoder.startswith("dinov2"): + print("Using DINOv2 encoder:", encoder) + + # Set env *before* importing your package (your Pydantic defaults read env) + os.environ["RFD_ENCODER"] = encoder + + + # Optional: ensure we don't try Hugging Face first (use Hub fallback). + # If you DO have HF access+token and want to prefer HF, just remove this line. 
+ os.environ.pop("HUGGINGFACE_HUB_TOKEN", None) + + from rfdetr import RFDETRMedium, RFDETRBase + from rfdetr.util.coco_classes import COCO_CLASSES + + model = RFDETRMedium() # encoder resolved from the RFD_ENCODER env var set above + model.optimize_for_inference() + + while True: + url = input("Enter image URL (or 'exit' to quit): ") + if url.lower() == 'exit': + break + #url = "https://media.roboflow.com/notebooks/examples/dog-2.jpeg" + image = Image.open(io.BytesIO(requests.get(url).content)) + + detections = model.predict(image, threshold=0.5) + + labels = [ + f"{COCO_CLASSES[class_id]} {confidence:.2f}" + for class_id, confidence in zip(detections.class_id, detections.confidence) + ] + + annotated_image = image.copy() + annotated_image = sv.BoxAnnotator().annotate(annotated_image, detections) + annotated_image = sv.LabelAnnotator().annotate(annotated_image, detections, labels) + + sv.plot_image(annotated_image) + +if __name__ == "__main__": + main(encoder="v2", repo_dir="../dinov3", dino_v_weights_path="../dinov3/dinov3_vitb16_pretrain_lvd1689m-73cec8be.pth") \ No newline at end of file diff --git a/rfdetr/models/backbone/backbone.py b/rfdetr/models/backbone/backbone.py index 94b1f774c..cca09e94f 100644 --- a/rfdetr/models/backbone/backbone.py +++ b/rfdetr/models/backbone/backbone.py @@ -29,6 +29,7 @@ from rfdetr.models.backbone.base import BackboneBase from rfdetr.models.backbone.projector import MultiScaleProjector from rfdetr.models.backbone.dinov2 import DinoV2 +from rfdetr.models.backbone.dinov3 import DinoV3 __all__ = ["Backbone"] @@ -53,38 +54,60 @@ def __init__(self, load_dinov2_weights: bool = True, patch_size: int = 14, num_windows: int = 4, - positional_encoding_size: bool = False, + #positional_encoding_size: bool = False, + positional_encoding_size: int = 0, + # optional DINOv3 loading knobs (HF/Hub) ---- + dinov3_repo_dir: str | None = None, + dinov3_weights_path: str | None = None, + dinov3_hf_token: str | None = None, + dinov3_prefer_hf: bool = 
True, ): super().__init__() - # an example name here would be "dinov2_base" or "dinov2_registers_windowed_base" - # if "registers" is in the name, then use_registers is set to True, otherwise it is set to False - # similarly, if "windowed" is in the name, then use_windowed_attn is set to True, otherwise it is set to False - # the last part of the name should be the size - # and the start should be dinov2 + + # Accept either "dinov2_*" (existing) or "dinov3_*" (new). name_parts = name.split("_") - assert name_parts[0] == "dinov2" + family = name_parts[0] size = name_parts[-1] - use_registers = False - if "registers" in name_parts: - use_registers = True - name_parts.remove("registers") - use_windowed_attn = False - if "windowed" in name_parts: - use_windowed_attn = True - name_parts.remove("windowed") - assert len(name_parts) == 2, "name should be dinov2, then either registers, windowed, both, or none, then the size" - self.encoder = DinoV2( - size=name_parts[-1], - out_feature_indexes=out_feature_indexes, - shape=target_shape, - use_registers=use_registers, - use_windowed_attn=use_windowed_attn, - gradient_checkpointing=gradient_checkpointing, - load_dinov2_weights=load_dinov2_weights, - patch_size=patch_size, - num_windows=num_windows, - positional_encoding_size=positional_encoding_size, - ) + + if family == "dinov2": + # Existing semantics: optional "registers" and/or "windowed" tokens. 
+ use_registers = False + if "registers" in name_parts: + use_registers = True + name_parts.remove("registers") + use_windowed_attn = False + if "windowed" in name_parts: + use_windowed_attn = True + name_parts.remove("windowed") + assert len(name_parts) == 2, "name should be dinov2, then either registers, windowed, both, or none, then the size" + self.encoder = DinoV2( + size=size, + out_feature_indexes=out_feature_indexes, + shape=target_shape, + use_registers=use_registers, + use_windowed_attn=use_windowed_attn, + gradient_checkpointing=gradient_checkpointing, + load_dinov2_weights=load_dinov2_weights, + patch_size=patch_size, + num_windows=num_windows, + positional_encoding_size=positional_encoding_size, + ) + elif family == "dinov3": + # new DINOv3 branch (no registers/windowing here) + self.encoder = DinoV3( + size=size, + out_feature_indexes=out_feature_indexes, + shape=target_shape, + patch_size=patch_size if patch_size else 16, + # reuse your existing flag for "load pretrained?" to avoid config churn + load_dinov3_weights=load_dinov2_weights, + hf_token=dinov3_hf_token, + repo_dir=dinov3_repo_dir, + weights=dinov3_weights_path, + prefer_hf=dinov3_prefer_hf, + ) + else: + raise AssertionError(f"Backbone name must start with 'dinov2' or 'dinov3', got: {family}") # build encoder + projector as backbone module if freeze_encoder: for param in self.encoder.parameters(): diff --git a/rfdetr/models/backbone/dinov3.py b/rfdetr/models/backbone/dinov3.py new file mode 100644 index 000000000..1265e6e12 --- /dev/null +++ b/rfdetr/models/backbone/dinov3.py @@ -0,0 +1,230 @@ +# ------------------------------------------------------------------------ +# Note: This file is an original wrapper that *loads* DINOv3 models. +# Using DINOv3 code/weights is subject to that license’s restrictions. 
+# DINOv3 itself is licensed under the DINOv3 License; see https://github.com/facebookresearch/dinov3/tree/main?tab=License-1-ov-file +# ------------------------------------------------------------------------ + +from typing import Sequence, Optional +import os +import torch +import torch.nn as nn + +# HF ids for ViT (/16) +_HF_IDS = { + "small": "facebook/dinov3-vits16-pretrain-lvd1689m", + "base": "facebook/dinov3-vitb16-pretrain-lvd1689m", + "large": "facebook/dinov3-vitl16-pretrain-lvd1689m", +} +# channels per size (for projector wiring) +_SIZE2WIDTH = {"small": 384, "base": 768, "large": 1024} +# torch.hub entrypoints in the official repo +_HUB_ENTRY = {"small": "dinov3_vits16", "base": "dinov3_vitb16", "large": "dinov3_vitl16"} + +class DinoV3(nn.Module): + """ + RF-DETR-facing DINOv3 wrapper: + - forward(x) -> List[B, C, H/16, W/16] (one per selected layer) + - _out_feature_channels: List[int] + - export(): no-op (kept for parity) + """ + + def __init__( + self, + shape: Sequence[int] = (640, 640), + out_feature_indexes: Sequence[int] = (2, 5, 8, 11), + size: str = "base", + patch_size: int = 16, + # new knobs: + load_dinov3_weights: bool = True, + hf_token: Optional[str] = None, + repo_dir: Optional[str] = None, # path to local dinov3 clone + weights: Optional[str] = None, # path or URL to *.pth (hub) + pretrained_name: Optional[str] = None, + **__, + ): + """ + A DINOv3 wrapper for RF-DETR. + + Args: + shape (Sequence[int]): Input image shape (H, W). + out_feature_indexes (Sequence[int]): Layer indexes to return. + size (str): DINOv3 model size: "small", "base", or "large". + patch_size (int): Patch size for the model. + load_dinov3_weights (bool): If True, load DINOv3 weights from HF or hub. + hf_token (Optional[str]): Hugging Face token for private models. + repo_dir (Optional[str]): Path to the local DINOv3 repository. + weights (Optional[str]): Path to the DINOv3 weights file. + pretrained_name (Optional[str]): Pretrained model name for HF. 
+ """ + + super().__init__() + assert size in _HF_IDS, f"size must be one of {list(_HF_IDS)}, got {size}" + + self.shape = tuple(shape) + self.patch_size = int(patch_size) + self.model_patch = 16 + self.num_register_tokens = 0 + self.hidden_size = _SIZE2WIDTH[size] + self._out_feature_channels = [self.hidden_size] * len(out_feature_indexes) + self.out_feature_indexes = list(out_feature_indexes) + + # Allow env overrides (so you don't have to touch code) + repo_dir = repo_dir or os.getenv("DINOV3_REPO") + weights = weights or os.getenv("DINOV3_WEIGHTS") + hub_id = pretrained_name or _HF_IDS[size] + + # 1) Try HF if weights are allowed and a token is available + use_hf = load_dinov3_weights and (hf_token or os.getenv("HUGGINGFACE_HUB_TOKEN")) + if use_hf: + from transformers import AutoModel + self.encoder = AutoModel.from_pretrained( + hub_id, + token=True if hf_token is None else hf_token, + output_hidden_states=True, + return_dict=True, + ) + # pick config info when available + cfg = self.encoder.config + self.hidden_size = int(getattr(cfg, "hidden_size", self.hidden_size)) + self.num_register_tokens = int(getattr(cfg, "num_register_tokens", 0)) + self.model_patch = int(getattr(cfg, "patch_size", self.model_patch)) + else: + # 2) Fallback to PyTorch Hub (local repo + weights path or URL) + if not (repo_dir and weights): + raise RuntimeError( + "HF access unavailable/gated. Set DINOV3_REPO and DINOV3_WEIGHTS (or pass repo_dir/weights) " + "to load via torch.hub as per the DINOv3 README." 
+ ) + entry = _HUB_ENTRY[size] + self.encoder = torch.hub.load(repo_dir, entry, source="local", weights=weights) + # best-effort introspection (these attrs may or may not exist on hub module) + self.num_register_tokens = int(getattr(self.encoder, "num_register_tokens", 0)) + self.model_patch = int(getattr(self.encoder, "patch_size", self.model_patch)) + + def export(self): # parity with dinov2 path + pass + + def _tokens_to_map(self, hidden_state: torch.Tensor, B: int, H: int, W: int) -> torch.Tensor: + """ + Accepts either: + - [B, 1+R+HW, C] (CLS + register + patch tokens) + - [B, HW, C] (no special tokens) + Returns: + - [B, C, H/ps, W/ps] + """ + ps = self.model_patch + assert H % ps == 0 and W % ps == 0, f"Input must be divisible by patch size {ps}, got {(H, W)}" + hp, wp = H // ps, W // ps + C = hidden_state.shape[-1] + + if hidden_state.dim() == 2: + # e.g., [HW, C] (no batch) -> try to recover batch + seq = hidden_state.shape[0] + assert seq % B == 0, f"Cannot infer batch from tokens of shape {hidden_state.shape} with B={B}" + hidden_state = hidden_state.view(B, seq // B, C) + + assert hidden_state.dim() == 3, f"Expected tokens [B, S, C], got {tuple(hidden_state.shape)}" + S = hidden_state.shape[1] + expected_hw = hp * wp + + if S == expected_hw: + seq = hidden_state # already patch tokens + elif S >= expected_hw + 1 + self.num_register_tokens: + # drop CLS + registers, then take the last expected_hw tokens + seq = hidden_state[:, 1 + self.num_register_tokens :, :] + seq = seq[:, -expected_hw:, :] + else: + # unknown extra tokens count; take the last expected_hw tokens + seq = hidden_state[:, -expected_hw:, :] + + return seq.view(B, hp, wp, C).permute(0, 3, 1, 2).contiguous() + + def forward(self, x: torch.Tensor): + B, _, H, W = x.shape + + # --- HF path (AutoModel) ----------------------------- + # Heuristic: HF models have .config and accept pixel_values=... 
+ if hasattr(self.encoder, "config"): + out = self.encoder(pixel_values=x, output_hidden_states=True, return_dict=True) + hs = out.hidden_states # tuple/list: embeddings + each layer + feats = [self._tokens_to_map(hs[i], B, H, W) for i in self.out_feature_indexes] + return feats + + # --- Hub path: try get_intermediate_layers ----------- + if hasattr(self.encoder, "get_intermediate_layers"): + try: + max_idx = max(self.out_feature_indexes) + hs_list = self.encoder.get_intermediate_layers(x, n=max_idx + 1, reshape=False) + proc = [] + for h in hs_list: + # some impls return (tokens, cls) tuples + if isinstance(h, (list, tuple)): + h = h[0] + proc.append(h) + feats = [self._tokens_to_map(proc[i], B, H, W) for i in self.out_feature_indexes] + return feats + except Exception: + pass # fall through to plain forward + + # --- Hub path: try forward_features ------------------ + if hasattr(self.encoder, "forward_features"): + try: + ff = self.encoder.forward_features(x) + # Common patterns: dict with patch tokens or tensors + if isinstance(ff, dict): + cand = ( + ff.get("x_norm_patchtokens", None) + or ff.get("patch_tokens", None) + or ff.get("tokens", None) + or next((t for t in ff.values() if torch.is_tensor(t) and t.dim() >= 2), None) + ) + else: + cand = ff + if cand is not None: + if torch.is_tensor(cand) and cand.dim() == 4: + # Already a spatial map [B, C, Hp, Wp] + # Repeat to match requested out_feature_indexes count + C = cand.shape[1] + if C != self.hidden_size: + self.hidden_size = int(C) + self._out_feature_channels = [self.hidden_size] * len(self.out_feature_indexes) + return [cand for _ in self.out_feature_indexes] + # Otherwise assume tokens + tokens = cand + # If [HW, C] or [B*HW, C], _tokens_to_map will handle reshape + feats = [self._tokens_to_map(tokens, B, H, W) for _ in self.out_feature_indexes] + return feats + except Exception: + pass # fall through + + # --- Plain forward fallback -------------------------- + out = self.encoder(x) + # Normalize to 
a tensor + if isinstance(out, (list, tuple)): + # pick first tensor-like item + out = next((t for t in out if torch.is_tensor(t)), out[0]) + if not torch.is_tensor(out): + raise RuntimeError( + "DINOv3 hub module returned an unsupported output type for tracing. " + "Prefer the HF path or a hub build exposing intermediate layers." + ) + + # Case A: already spatial map [B, C, Hp, Wp] + if out.dim() == 4: + C = out.shape[1] + if C != self.hidden_size: + self.hidden_size = int(C) + self._out_feature_channels = [self.hidden_size] * len(self.out_feature_indexes) + return [out for _ in self.out_feature_indexes] + + # Case B/C: tokens [B, S, C] or [S, C] (batchless) + tokens = out + if tokens.dim() == 2: + # Let _tokens_to_map do the reshaping using B + feats = [self._tokens_to_map(tokens, B, H, W) for _ in self.out_feature_indexes] + return feats + elif tokens.dim() == 3: + feats = [self._tokens_to_map(tokens, B, H, W) for _ in self.out_feature_indexes] + return feats + + raise RuntimeError(f"Unsupported hub output shape: {tuple(out.shape)}") diff --git a/rfdetr/models/backbone/dinov3_configs/dinov3_base.json b/rfdetr/models/backbone/dinov3_configs/dinov3_base.json new file mode 100644 index 000000000..e5a567814 --- /dev/null +++ b/rfdetr/models/backbone/dinov3_configs/dinov3_base.json @@ -0,0 +1,26 @@ +{ + "architectures": ["DINOv3ViTModel"], + "model_type": "dinov3_vit", + "hidden_size": 768, + "num_hidden_layers": 12, + "num_attention_heads": 12, + "mlp_ratio": 4, + "hidden_act": "gelu", + "hidden_dropout_prob": 0.0, + "attention_probs_dropout_prob": 0.0, + "initializer_range": 0.02, + "layer_norm_eps": 1e-05, + "image_size": 512, + "patch_size": 16, + "num_channels": 3, + "query_bias": true, + "key_bias": false, + "value_bias": true, + "proj_bias": true, + "mlp_bias": true, + "layerscale_value": 1.0, + "drop_path_rate": 0.0, + "use_gated_mlp": false, + "num_register_tokens": 4, + "torch_dtype": "float32" +} diff --git 
a/rfdetr/models/backbone/dinov3_configs/dinov3_large.json b/rfdetr/models/backbone/dinov3_configs/dinov3_large.json new file mode 100644 index 000000000..7325e6928 --- /dev/null +++ b/rfdetr/models/backbone/dinov3_configs/dinov3_large.json @@ -0,0 +1,26 @@ +{ + "architectures": ["DINOv3ViTModel"], + "model_type": "dinov3_vit", + "hidden_size": 1024, + "num_hidden_layers": 24, + "num_attention_heads": 16, + "mlp_ratio": 4, + "hidden_act": "gelu", + "hidden_dropout_prob": 0.0, + "attention_probs_dropout_prob": 0.0, + "initializer_range": 0.02, + "layer_norm_eps": 1e-05, + "image_size": 512, + "patch_size": 16, + "num_channels": 3, + "query_bias": true, + "key_bias": false, + "value_bias": true, + "proj_bias": true, + "mlp_bias": true, + "layerscale_value": 1.0, + "drop_path_rate": 0.0, + "use_gated_mlp": false, + "num_register_tokens": 4, + "torch_dtype": "float32" +} diff --git a/rfdetr/models/backbone/dinov3_configs/dinov3_small.json b/rfdetr/models/backbone/dinov3_configs/dinov3_small.json new file mode 100644 index 000000000..435309161 --- /dev/null +++ b/rfdetr/models/backbone/dinov3_configs/dinov3_small.json @@ -0,0 +1,26 @@ +{ + "architectures": ["DINOv3ViTModel"], + "model_type": "dinov3_vit", + "hidden_size": 384, + "num_hidden_layers": 12, + "num_attention_heads": 6, + "mlp_ratio": 4, + "hidden_act": "gelu", + "hidden_dropout_prob": 0.0, + "attention_probs_dropout_prob": 0.0, + "initializer_range": 0.02, + "layer_norm_eps": 1e-05, + "image_size": 512, + "patch_size": 16, + "num_channels": 3, + "query_bias": true, + "key_bias": false, + "value_bias": true, + "proj_bias": true, + "mlp_bias": true, + "layerscale_value": 1.0, + "drop_path_rate": 0.0, + "use_gated_mlp": false, + "num_register_tokens": 4, + "torch_dtype": "float32" +} diff --git a/rfdetr/train_v2_or_v3.py b/rfdetr/train_v2_or_v3.py new file mode 100644 index 000000000..f65eb7b4a --- /dev/null +++ b/rfdetr/train_v2_or_v3.py @@ -0,0 +1,113 @@ +# train_single.py +import argparse +import os + 
+ENCODER_ALIASES = { + "dinov2": "dinov2_windowed_small", + "v2": "dinov2_windowed_small", + "dinov2_small": "dinov2_windowed_small", + "dinov2_base": "dinov2_windowed_base", + "dinov3": "dinov3_base", + "v3": "dinov3_base", +} + +VALID_ENCODERS = { + "dinov2_windowed_small", + "dinov2_windowed_base", + "dinov3_small", + "dinov3_base", + "dinov3_large", +} + +def parse_args(): + ap = argparse.ArgumentParser("RF-DETR Medium single-run trainer (v2 or v3)") + ap.add_argument("--data", required=True, help="Dataset root with train/valid/test") + ap.add_argument("--out", default="./runs", help="Root output dir for TB & checkpoints") + ap.add_argument("--epochs", type=int, default=20) + ap.add_argument("--bs", type=int, default=8, help="Batch size per iteration") # TODO: --bs is not actually applied by train(); investigate why + ap.add_argument("--workers", type=int, default=None, + help="DataLoader workers (default: 0 on Windows, else 2)") + ap.add_argument("--encoder", default="dinov2", + help=("dinov2|v2|dinov2_small|dinov2_base|dinov3|v3|" + "dinov3_small|dinov3_base|dinov3_large or exact name")) + ap.add_argument("--name", default=None, help="Optional run name (subdir under --out)") + + # Optional local DINOv3 assets + ap.add_argument("--dinov3-repo", default=None, help="Local DINOv3 repo (sets DINOV3_REPO)") + ap.add_argument("--dinov3-weights", default=None, help="Path to DINOv3 .pth (sets DINOV3_WEIGHTS)") + return ap.parse_args() + +def resolve_encoder(enc_str: str) -> str: + enc_str = enc_str.strip().lower() + enc = ENCODER_ALIASES.get(enc_str, enc_str) + if enc not in VALID_ENCODERS: + raise ValueError(f"Unknown encoder '{enc_str}'. 
Valid: {sorted(list(VALID_ENCODERS))}") + return enc + +def main(): + args = parse_args() + + # Safer default on Windows for DataLoader workers + if args.workers is None: + import platform + args.workers = 0 if platform.system() == "Windows" else 2 + + encoder_name = resolve_encoder(args.encoder) + + # Set env *before* importing your package (your Pydantic defaults read env) + os.environ["RFD_ENCODER"] = encoder_name + if args.dinov3_repo: + os.environ["DINOV3_REPO"] = args.dinov3_repo + if args.dinov3_weights: + os.environ["DINOV3_WEIGHTS"] = args.dinov3_weights + + # Now import project code + from rfdetr import RFDETRBase, RFDETRMedium + from rfdetr.config import RFDETRBaseConfig, RFDETRMediumConfig # to override config cleanly + + # Two thin wrappers so we can control encoder and pretrain at construction time + class RFDETRBaseV2(RFDETRMedium): + def get_model_config(self, **kwargs): + # keep RF-DETR pretrain (default) for v2 + return RFDETRMediumConfig(encoder="dinov2_windowed_small", **kwargs) + + class RFDETRBaseV3(RFDETRMedium): + def get_model_config(self, **kwargs): + # IMPORTANT: disable RF-DETR pretrain for v3 to avoid shape mismatches + return RFDETRMediumConfig(encoder="dinov3_base", pretrain_weights=None, **kwargs) + + # Output dir (separate subdirs so TB shows two runs side-by-side) + if args.name: + run_name = args.name + else: + tag = "DINOv2" if encoder_name.startswith("dinov2") else "DINOv3" + run_name = f"{tag}_Base" + out_dir = os.path.join(args.out, run_name) + os.makedirs(out_dir, exist_ok=True) + + # Build and train + ModelCls = RFDETRBaseV3 if encoder_name.startswith("dinov3") else RFDETRBaseV2 + model = ModelCls() + + print(f"\n=== Training RF-DETR Base with encoder: {encoder_name} ===") + print(f"Dataset: {args.data}") + print(f"Output : {out_dir}") + + train_kwargs = dict( + dataset_dir=args.data, + output_dir=out_dir, + epochs=args.epochs, + batch_size=args.bs, #TODO #Not actually applying, not worked out why + num_workers=args.workers, 
+ tensorboard=True, + run_test=True, + ) + # NOTE: train() expects kwargs, not a TrainConfig instance + model.train(**train_kwargs) + + print("\nDone. View in TensorBoard with:") + print(f" tensorboard --logdir {args.out}") + print("Open http://127.0.0.1:6006") + +if __name__ == "__main__": + main()