diff --git a/rfdetr/config.py b/rfdetr/config.py
index c2409c2d8..29c3aed9c 100644
--- a/rfdetr/config.py
+++ b/rfdetr/config.py
@@ -18,6 +18,7 @@ class ModelConfig(BaseModel):
     layer_norm: bool = True
     amp: bool = True
     num_classes: int = 90
+    pretrain_save_file: Optional[str] = 'model.pth'
     pretrain_weights: Optional[str] = None
     device: Literal["cpu", "cuda", "mps"] = DEVICE
     resolution: int = 560
@@ -34,7 +35,7 @@ class RFDETRBaseConfig(ModelConfig):
     num_select: int = 300
     projector_scale: List[Literal["P3", "P4", "P5"]] = ["P4"]
    out_feature_indexes: List[int] = [2, 5, 8, 11]
-    pretrain_weights: Optional[str] = "rf-detr-base.pth"
+    pretrain_weights: Optional[str] = "rfdetr_base"
 
 class RFDETRLargeConfig(RFDETRBaseConfig):
     encoder: Literal["dinov2_windowed_small", "dinov2_windowed_base"] = "dinov2_windowed_base"
@@ -43,7 +44,7 @@ class RFDETRLargeConfig(RFDETRBaseConfig):
     ca_nheads: int = 24
     dec_n_points: int = 4
     projector_scale: List[Literal["P3", "P4", "P5"]] = ["P3", "P5"]
-    pretrain_weights: Optional[str] = "rf-detr-large.pth"
+    pretrain_weights: Optional[str] = "rfdetr_large"
 
 class TrainConfig(BaseModel):
     lr: float = 1e-4
diff --git a/rfdetr/detr.py b/rfdetr/detr.py
index 13121a4f9..3c3c79b46 100644
--- a/rfdetr/detr.py
+++ b/rfdetr/detr.py
@@ -26,7 +26,8 @@ def __init__(self, **kwargs):
         self.callbacks = defaultdict(list)
 
     def maybe_download_pretrain_weights(self):
-        download_pretrain_weights(self.model_config.pretrain_weights)
+        download_pretrain_weights(self.model_config.pretrain_weights, self.model_config.pretrain_save_file)
+
 
     def get_model_config(self, **kwargs):
         return ModelConfig(**kwargs)
diff --git a/rfdetr/main.py b/rfdetr/main.py
index 10f5a0b08..110fceed1 100644
--- a/rfdetr/main.py
+++ b/rfdetr/main.py
@@ -52,22 +52,32 @@
 logger = getLogger(__name__)
 
 HOSTED_MODELS = {
-    "rf-detr-base.pth": "https://storage.googleapis.com/rfdetr/rf-detr-base-coco.pth",
+    "rfdetr_base": "https://storage.googleapis.com/rfdetr/rf-detr-base-coco.pth",
     # below is a less converged model that may be better for finetuning but worse for inference
-    "rf-detr-base-2.pth": "https://storage.googleapis.com/rfdetr/rf-detr-base-2.pth",
-    "rf-detr-large.pth": "https://storage.googleapis.com/rfdetr/rf-detr-large.pth"
+    "rfdetr_base2": "https://storage.googleapis.com/rfdetr/rf-detr-base-2.pth",
+    "rfdetr_large": "https://storage.googleapis.com/rfdetr/rf-detr-large.pth"
 }
 
-def download_pretrain_weights(pretrain_weights: str, redownload=False):
-    if pretrain_weights in HOSTED_MODELS:
-        if redownload or not os.path.exists(pretrain_weights):
-            logger.info(
-                f"Downloading pretrained weights for {pretrain_weights}"
-            )
-            download_file(
-                HOSTED_MODELS[pretrain_weights],
-                pretrain_weights,
-            )
+def download_pretrain_weights(model_type: str, output_path: str, redownload: bool = False):
+    # No pretrained weights requested: keep the no-pretrain path a no-op,
+    # matching the previous behavior when pretrain_weights was None.
+    if model_type is None:
+        return
+    if model_type not in HOSTED_MODELS:
+        raise ValueError(f"Unknown model type '{model_type}'. Valid options are: {list(HOSTED_MODELS.keys())}")
+
+    if redownload or not os.path.exists(output_path):
+        # Create parent directory only if there is one
+        output_dir = os.path.dirname(output_path)
+        if output_dir:
+            os.makedirs(output_dir, exist_ok=True)
+
+        logger.info(
+            f"Downloading pretrained weights for {model_type}"
+        )
+        download_file(
+            HOSTED_MODELS[model_type],
+            output_path,
+        )
 
 class Model:
     def __init__(self, **kwargs):
@@ -75,16 +85,16 @@ def __init__(self, **kwargs):
         self.resolution = args.resolution
         self.model = build_model(args)
         self.device = torch.device(args.device)
         if args.pretrain_weights is not None:
             print("Loading pretrain weights")
             try:
-                checkpoint = torch.load(args.pretrain_weights, map_location='cpu', weights_only=False)
+                checkpoint = torch.load(args.pretrain_save_file, map_location='cpu', weights_only=False)
             except Exception as e:
                 print(f"Failed to load pretrain weights: {e}")
                 # re-download weights if they are corrupted
                 print("Failed to load pretrain weights, re-downloading")
-                download_pretrain_weights(args.pretrain_weights, redownload=True)
-                checkpoint = torch.load(args.pretrain_weights, map_location='cpu', weights_only=False)
+                download_pretrain_weights(args.pretrain_weights, args.pretrain_save_file, redownload=True)
+                checkpoint = torch.load(args.pretrain_save_file, map_location='cpu', weights_only=False)
 
             checkpoint_num_classes = checkpoint['model']['class_embed.bias'].shape[0]
             if checkpoint_num_classes != args.num_classes + 1:
@@ -541,6 +551,7 @@ def export(self, output_dir="output", infer_dir=None, simplify=False, backbone_
             "cutoff_epoch",
             "pretrained_encoder",
             "pretrain_weights",
+            "pretrain_save_file",
             "pretrain_exclude_keys",
             "pretrain_keys_modify_to_load",
             "freeze_florence",
@@ -630,6 +641,8 @@ def get_args_parser():
     parser.add_argument('--pretrained_encoder', type=str, default=None,
                         help="Path to the pretrained encoder.")
     parser.add_argument('--pretrain_weights', type=str, default=None,
-                        help="Path to the pretrained model.")
+                        help="Hosted model type to download (a key of HOSTED_MODELS).")
+    parser.add_argument('--pretrain_save_file', type=str, default='model.pth',
+                        help="Local path where the pretrained weights are saved and loaded from.")
     parser.add_argument('--pretrain_exclude_keys', type=str, default=None, nargs='+',
                         help="Keys you do not want to load.")
@@ -806,6 +819,7 @@ def populate_args(
     # Model parameters
     pretrained_encoder=None,
     pretrain_weights=None,
+    pretrain_save_file='model.pth',
     pretrain_exclude_keys=None,
     pretrain_keys_modify_to_load=None,
     pretrained_distiller=None,
@@ -924,6 +938,7 @@ def populate_args(
         cutoff_epoch=cutoff_epoch,
         pretrained_encoder=pretrained_encoder,
         pretrain_weights=pretrain_weights,
+        pretrain_save_file=pretrain_save_file,
         pretrain_exclude_keys=pretrain_exclude_keys,
         pretrain_keys_modify_to_load=pretrain_keys_modify_to_load,
         pretrained_distiller=pretrained_distiller,