diff --git a/trackers/core/reid/model.py b/trackers/core/reid/model.py index bbd7aacf..de84df76 100644 --- a/trackers/core/reid/model.py +++ b/trackers/core/reid/model.py @@ -67,18 +67,25 @@ def _initialize_reid_model_from_timm( return cls(model, device, transforms, model_metadata) -def _initialize_reid_model_from_checkpoint(cls, checkpoint_path: str): - state_dict, config = load_safetensors_checkpoint(checkpoint_path) +def _initialize_reid_model_from_checkpoint(cls, checkpoint_path: str, config_path: str): + state_dict, config = load_safetensors_checkpoint(checkpoint_path, config_path) + model_name = config.get("architecture") + if model_name is None: + raise ValueError( + f"The config at {config_path} is missing the 'architecture' key." + ) + init_kwargs = {} + init_kwargs["pretrained"] = False reid_model_instance = _initialize_reid_model_from_timm( - cls, **config["model_metadata"] + cls, model_name_or_checkpoint_path=model_name, device="auto", **init_kwargs ) - if config["projection_dimension"]: + if config.get("projection_dimension"): reid_model_instance._add_projection_layer( - projection_dimension=config["projection_dimension"] + projection_dimension=config.get("projection_dimension") ) for k, v in state_dict.items(): - state_dict[k].to(reid_model_instance.device) - reid_model_instance.backbone_model.load_state_dict(state_dict) + state_dict[k] = v.to(reid_model_instance.device) + reid_model_instance.backbone_model.load_state_dict(state_dict, strict=False) return reid_model_instance @@ -122,6 +129,7 @@ def __init__( def from_timm( cls, model_name_or_checkpoint_path: str, + config_path: Optional[str] = None, device: Optional[str] = "auto", get_pooled_features: bool = True, **kwargs, @@ -134,6 +142,8 @@ def from_timm( model_name_or_checkpoint_path (str): Name of the timm model to use or path to a safetensors checkpoint. If the exact model name is not found, the closest match from `timm.list_models` will be used. 
+ config_path (Optional[str]): Path to the JSON config file for + the local safetensors checkpoint; required to load one. device (str): Device to run the model on. get_pooled_features (bool): Whether to get the pooled features from the model or not. @@ -143,9 +153,13 @@ from_timm( Returns: ReIDModel: A new instance of `ReIDModel`. """ - if os.path.exists(model_name_or_checkpoint_path): + if ( + config_path is not None + and os.path.exists(model_name_or_checkpoint_path) + and os.path.exists(config_path) + ): return _initialize_reid_model_from_checkpoint( - cls, model_name_or_checkpoint_path, config_path ) else: return _initialize_reid_model_from_timm( diff --git a/trackers/utils/torch_utils.py b/trackers/utils/torch_utils.py index 2d0c4b71..45b36e05 100644 --- a/trackers/utils/torch_utils.py +++ b/trackers/utils/torch_utils.py @@ -59,28 +59,31 @@ def parse_device_spec(device_spec: Union[str, torch.device]) -> torch.device: def load_safetensors_checkpoint( - checkpoint_path: str, device: str = "cpu" + checkpoint_path: str, + config_path: str, + device: str = "cpu", ) -> Tuple[dict[str, torch.Tensor], dict[str, Any]]: """ - Load a safetensors checkpoint into a dictionary of tensors and a dictionary - of metadata. + Load a safetensors checkpoint into a dictionary of tensors and a + separate JSON config file. Args: checkpoint_path (str): The path to the safetensors checkpoint. + config_path (str): The path to the JSON config file. device (str): The device to load the checkpoint on. Returns: - Tuple[dict[str, torch.Tensor], dict[str, Any]]: A tuple containing the - state_dict and the config. 
+ Tuple[dict[str, torch.Tensor], dict[str, Any]]: A tuple containing + the model state_dict and the parsed JSON config. """ state_dict = {} with safe_open(checkpoint_path, framework="pt", device=device) as f: for key in f.keys(): state_dict[key] = f.get_tensor(key) - metadata = f.metadata() - config = json.loads(metadata["config"]) if "config" in metadata else {} - model_metadata = config.pop("model_metadata") if "model_metadata" in config else {} - if "kwargs" in model_metadata: + with open(config_path, "r") as f: + config = json.load(f) + model_metadata = config.pop("model_metadata", {}) + if isinstance(model_metadata, dict) and "kwargs" in model_metadata: kwargs = model_metadata.pop("kwargs") model_metadata = {**kwargs, **model_metadata} config["model_metadata"] = model_metadata