diff --git a/rfdetr/config.py b/rfdetr/config.py index da5745639..2e36c1b60 100644 --- a/rfdetr/config.py +++ b/rfdetr/config.py @@ -26,6 +26,7 @@ class ModelConfig(BaseModel): amp: bool = True num_classes: int = 90 pretrain_weights: Optional[str] = None + onnx_path: Optional[str] = None device: Literal["cpu", "cuda", "mps"] = DEVICE resolution: int = 560 group_detr: int = 13 diff --git a/rfdetr/detr.py b/rfdetr/detr.py index 73fff74f5..66eb8eec6 100644 --- a/rfdetr/detr.py +++ b/rfdetr/detr.py @@ -28,6 +28,21 @@ class RFDETR: def __init__(self, **kwargs): self.model_config = self.get_model_config(**kwargs) + + self.onnx_model = None + if self.model_config.onnx_path: + import onnxruntime as ort + + if self.model_config.device == "cuda": + providers = [ + "CUDAExecutionProvider", + # optionally add device id ("CUDAExecutionProvider", {"device_id": cuda_device_id}) + "CPUExecutionProvider", + ] + else: + providers = ["CPUExecutionProvider"] + self.onnx_model = ort.InferenceSession(self.model_config.onnx_path, providers=providers) + self.maybe_download_pretrain_weights() self.model = self.get_model(self.model_config) self.callbacks = defaultdict(list) @@ -41,7 +56,7 @@ def get_model_config(self, **kwargs): def train(self, **kwargs): config = self.get_train_config(**kwargs) self.train_from_config(config, **kwargs) - + def export(self, **kwargs): self.model.export(**kwargs) @@ -59,14 +74,13 @@ def train_from_config(self, config: TrainConfig, **kwargs): f"reinitializing your detection head with {num_classes} classes." ) self.model.reinitialize_detection_head(num_classes) - - + train_config = config.dict() model_config = self.model_config.dict() model_config.pop("num_classes") if "class_names" in model_config: model_config.pop("class_names") - + if "class_names" in train_config and train_config["class_names"] is None: train_config["class_names"] = class_names @@ -75,7 +89,7 @@ def train_from_config(self, config: TrainConfig, **kwargs): model_config.pop(k) if k in kwargs: kwargs.pop(k) - + all_kwargs = {**model_config, **train_config, **kwargs, "num_classes": num_classes} metrics_plot_sink = MetricsPlotSink(output_dir=config.output_dir) @@ -120,7 +134,8 @@ def get_model(self, config: ModelConfig): def predict( self, - images: Union[str, Image.Image, np.ndarray, torch.Tensor, List[Union[str, np.ndarray, Image.Image, torch.Tensor]]], + images: Union[ + str, Image.Image, np.ndarray, torch.Tensor, List[Union[str, np.ndarray, Image.Image, torch.Tensor]]], threshold: float = 0.5, **kwargs, ) -> Union[sv.Detections, List[sv.Detections]]: @@ -186,7 +201,18 @@ def predict( batch_tensor = torch.stack(processed_images) with torch.inference_mode(): - predictions = self.model.model(batch_tensor) + if self.onnx_model: + image_np = batch_tensor.cpu().float().numpy() + + input_name = self.onnx_model.get_inputs()[0].name + outputs = self.onnx_model.run(None, {input_name: image_np}) + + predictions = { + "pred_boxes": torch.tensor(outputs[0]).to(self.model.device), + "pred_logits": torch.tensor(outputs[1]).to(self.model.device), + } + else: + predictions = self.model.model(batch_tensor) target_sizes = torch.tensor(orig_sizes, device=self.model.device) results = self.model.postprocessors["bbox"](predictions, target_sizes=target_sizes) @@ -218,6 +244,7 @@ def get_model_config(self, **kwargs): def get_train_config(self, **kwargs): return TrainConfig(**kwargs) + class RFDETRLarge(RFDETR): def get_model_config(self, **kwargs): return RFDETRLargeConfig(**kwargs)