Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions rfdetr/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ class ModelConfig(BaseModel):
amp: bool = True
num_classes: int = 90
pretrain_weights: Optional[str] = None
onnx_path: Optional[str] = None
device: Literal["cpu", "cuda", "mps"] = DEVICE
resolution: int = 560
group_detr: int = 13
Expand Down
41 changes: 34 additions & 7 deletions rfdetr/detr.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,21 @@ class RFDETR:

def __init__(self, **kwargs):
self.model_config = self.get_model_config(**kwargs)

self.onnx_model = None
if self.model_config.onnx_path:
import onnxruntime as ort

if self.model_config.device == "cuda":
providers = [
"CUDAExecutionProvider",
# optionally add device id ("CUDAExecutionProvider", {"device_id": cuda_device_id})
"CPUExecutionProvider",
]
else:
providers = ["CPUExecutionProvider"]
self.onnx_model = ort.InferenceSession(self.model_config.onnx_path, providers=providers)

self.maybe_download_pretrain_weights()
self.model = self.get_model(self.model_config)
self.callbacks = defaultdict(list)
Expand All @@ -41,7 +56,7 @@ def get_model_config(self, **kwargs):
def train(self, **kwargs):
config = self.get_train_config(**kwargs)
self.train_from_config(config, **kwargs)

def export(self, **kwargs):
self.model.export(**kwargs)

Expand All @@ -59,14 +74,13 @@ def train_from_config(self, config: TrainConfig, **kwargs):
f"reinitializing your detection head with {num_classes} classes."
)
self.model.reinitialize_detection_head(num_classes)



train_config = config.dict()
model_config = self.model_config.dict()
model_config.pop("num_classes")
if "class_names" in model_config:
model_config.pop("class_names")

if "class_names" in train_config and train_config["class_names"] is None:
train_config["class_names"] = class_names

Expand All @@ -75,7 +89,7 @@ def train_from_config(self, config: TrainConfig, **kwargs):
model_config.pop(k)
if k in kwargs:
kwargs.pop(k)

all_kwargs = {**model_config, **train_config, **kwargs, "num_classes": num_classes}

metrics_plot_sink = MetricsPlotSink(output_dir=config.output_dir)
Expand Down Expand Up @@ -120,7 +134,8 @@ def get_model(self, config: ModelConfig):

def predict(
self,
images: Union[str, Image.Image, np.ndarray, torch.Tensor, List[Union[str, np.ndarray, Image.Image, torch.Tensor]]],
images: Union[
str, Image.Image, np.ndarray, torch.Tensor, List[Union[str, np.ndarray, Image.Image, torch.Tensor]]],
threshold: float = 0.5,
**kwargs,
) -> Union[sv.Detections, List[sv.Detections]]:
Expand Down Expand Up @@ -186,7 +201,18 @@ def predict(
batch_tensor = torch.stack(processed_images)

with torch.inference_mode():
predictions = self.model.model(batch_tensor)
if self.onnx_model:
image_np = batch_tensor.cpu().float().numpy()

input_name = self.onnx_model.get_inputs()[0].name
outputs = self.onnx_model.run(None, {input_name: image_np})

predictions = {
"pred_boxes": torch.tensor(outputs[0]).to(self.model.device),
"pred_logits": torch.tensor(outputs[1]).to(self.model.device),
}
else:
predictions = self.model.model(batch_tensor)
target_sizes = torch.tensor(orig_sizes, device=self.model.device)
results = self.model.postprocessors["bbox"](predictions, target_sizes=target_sizes)

Expand Down Expand Up @@ -218,6 +244,7 @@ def get_model_config(self, **kwargs):
def get_train_config(self, **kwargs):
return TrainConfig(**kwargs)


class RFDETRLarge(RFDETR):
def get_model_config(self, **kwargs):
return RFDETRLargeConfig(**kwargs)
Expand Down