Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
d73824a
Add native Apple Silicon inference via MLX backend
RacineAI-comp Feb 26, 2026
cffe7a0
Add MLX segmentation backend (RFDETRSeg* via optimize_for_inference)
RacineAI-comp Mar 2, 2026
e5bd7a1
Add MLX segmentation backend (RFDETRSeg* via optimize_for_inference)
RacineAI-comp Mar 2, 2026
65a5d7a
fix test warnings (#761)
omkar-334 Mar 3, 2026
82f67e4
fix: revert --doctest-modules (breaks CI on non-MLX platforms)
RacineAI-comp Mar 3, 2026
ede91f3
fix: remove stale train_from_config from rebase, add TYPE_CHECKING fo…
RacineAI-comp Mar 14, 2026
849eeac
Merge branch 'develop' into feat/mlx-inference
RacineAI-comp Mar 16, 2026
63f30df
Merge branch 'develop' into feat/mlx-inference
RacineAI-comp Mar 17, 2026
15f660d
Merge branch 'develop' into feat/mlx-inference
RacineAI-comp Mar 18, 2026
37f0687
Merge branch 'develop' into feat/mlx-inference
RacineAI-comp Mar 19, 2026
d3daf3c
Apply suggestions from code review
Borda Mar 19, 2026
7916b0f
Merge branch 'develop' into feat/mlx-inference
RacineAI-comp Mar 19, 2026
167ab24
Update decoder.py
RacineAI-comp Mar 19, 2026
505a009
Merge branch 'develop' into feat/mlx-inference
RacineAI-comp Mar 23, 2026
47d77a9
Merge branch 'develop' into feat/mlx-inference
RacineAI-comp Mar 24, 2026
8be5694
Merge branch 'develop' into feat/mlx-inference
Borda Mar 25, 2026
66f30bb
fix(mlx): guard num_select==0 and use correct argpartition kth in seg…
Borda Mar 25, 2026
b47a72b
Merge branch 'develop' into feat/mlx-inference
RacineAI-comp Mar 25, 2026
6cdefae
fix(mlx): derive full_attn_layers from feature_indices instead of har…
Borda Mar 25, 2026
a4e1ed9
fix(deps): add scipy to [mlx] optional extra
Borda Mar 25, 2026
3addaf3
fix(detr): raise on unsupported backend= and shape= with MLX
Borda Mar 25, 2026
5ecf4ab
fix(tests): restore --doctest-modules; exclude mlx/ from doctest coll…
Borda Mar 25, 2026
0b2401d
refactor(detr): remove dead _optimized_half attribute
Borda Mar 25, 2026
412be9c
fix(tests): apply pytest.mark.mlx to MLX test modules
Borda Mar 25, 2026
aa34d81
lint: auto-fix violations after resolve cycle
Borda Mar 25, 2026
397c14c
Merge branch 'feat/mlx-inference' of https://github.com/RacineAI-comp…
Borda Mar 25, 2026
72b9944
perf(mlx): replace per-mask cv2.resize loop with vectorised scipy.ndi…
Borda Mar 25, 2026
8443c13
fix(tests): remove vacuous is_mlx_available assertion; document full_…
Borda Mar 25, 2026
b0a789e
migrate(tests): merge MLX segmentation tests into test_mlx_inference.py
Borda Mar 25, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -179,9 +179,19 @@ cython_debug/

# model artifacts
rf-detr*
*.pth
output/*

train_test.py
run_demo.py
demo_live.py
examples/

# macOS
.DS_Store

# demo/scratch outputs
results/

# test artifacts
test_visualizations/
Expand Down
9 changes: 9 additions & 0 deletions conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
# ------------------------------------------------------------------------
# RF-DETR
# Copyright (c) 2025 Roboflow. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------

# The modules under src/rfdetr/mlx import mlx.core unconditionally at module
# level, and MLX only ships for macOS/Darwin. Keep them out of the
# --doctest-modules collection so doctest runs pass on other platforms.
collect_ignore_glob = ["src/rfdetr/mlx/*.py"]
7 changes: 6 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,10 @@ visual = [
]
cli = ["jsonargparse[signatures]>=4.27.7"]
plus = ["rfdetr_plus>=1.0.1, <2.0.0"]
mlx = [
"mlx>=0.22.0; sys_platform == 'darwin'",
"scipy",
]

[dependency-groups]
# TODO: Temporary: GPU runner only has CUDA 12.8 driver (12080); torch>=2.11 requires a newer driver.
Expand Down Expand Up @@ -163,11 +167,12 @@ ignore = [
addopts = [
"-v", # verbose output
"--color=yes", # colored output
"--doctest-modules", # run doctests from all modules
"--doctest-modules",
]
pythonpath = ["src"]
markers = [
"gpu: tests that require GPU or are slow on CPU",
"mlx: tests that require MLX (macOS Apple Silicon only)",
"flaky: tests that may fail due to nondeterminism"
]
filterwarnings = [
Expand Down
125 changes: 124 additions & 1 deletion src/rfdetr/detr.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,8 @@ def __init__(self, **kwargs):
self._optimized_batch_size = None
self._optimized_resolution = None
self._optimized_dtype = None
self._inference_backend = None
self._mlx_model = None

def maybe_download_pretrain_weights(self):
"""
Expand Down Expand Up @@ -226,15 +228,42 @@ def train(self, **kwargs):
if dataset_class_names is not None:
self.model.class_names = dataset_class_names

def optimize_for_inference(self, compile=True, batch_size=1, dtype=torch.float32):
def optimize_for_inference(self, compile=True, batch_size=1, dtype=torch.float32, backend: str = "pytorch"):
"""Optimize the model for inference.

Args:
compile: Whether to JIT-compile the PyTorch model.
batch_size: Batch size for JIT compilation.
dtype: Data type for PyTorch inference.
backend: Inference backend. "pytorch" for JIT-traced PyTorch,
"mlx" for native Apple Silicon inference via MLX (macOS only).
"""
self.remove_optimized_model()

if backend == "mlx":
if getattr(self.model_config, "segmentation_head", False):
from rfdetr.mlx import build_mlx_seg_inference

self._mlx_model = build_mlx_seg_inference(self.model_config, self.model)
else:
from rfdetr.mlx import build_mlx_inference

self._mlx_model = build_mlx_inference(self.model_config, self.model)
self._is_optimized_for_inference = True
self._inference_backend = "mlx"
self._optimized_resolution = self.model.resolution
return

if backend != "pytorch":
raise ValueError(f"Unknown inference backend {backend!r}. Expected 'pytorch' or 'mlx'.")

self.model.inference_model = deepcopy(self.model.model)
self.model.inference_model.eval()
self.model.inference_model.export()

self._optimized_resolution = self.model.resolution
self._is_optimized_for_inference = True
self._inference_backend = "pytorch"

self.model.inference_model = self.model.inference_model.to(dtype=dtype)
self._optimized_dtype = dtype
Expand All @@ -255,6 +284,8 @@ def remove_optimized_model(self):
self._optimized_has_been_compiled = False
self._optimized_batch_size = None
self._optimized_resolution = None
self._inference_backend = None
self._mlx_model = None
self._optimized_dtype = None

@deprecated(
Expand Down Expand Up @@ -501,6 +532,13 @@ def predict(
(e.g. ``float``) or is a ``bool``, if either dimension is zero or
negative, or if either dimension is not divisible by 14.
"""
if self._inference_backend == "mlx":
if shape is not None:
raise NotImplementedError(
"'shape' is not supported with backend='mlx'. "
"Resize the input image before calling predict() instead."
)
return self._predict_mlx(images, threshold)
import supervision as sv

if shape is not None:
Expand Down Expand Up @@ -649,6 +687,91 @@ def predict(

return detections_list if len(detections_list) > 1 else detections_list[0]

def _predict_mlx(
self,
images: Union[
str, Image.Image, np.ndarray, torch.Tensor, List[Union[str, np.ndarray, Image.Image, torch.Tensor]]
],
threshold: float = 0.5,
) -> Union[sv.Detections, List[sv.Detections]]:
"""Run inference through the MLX backend.

Accepts the same input types as predict(). Preprocesses images on CPU
(load + resize), then passes uint8 arrays to the compiled MLX pipeline.

Args:
images: Single image or list of images.
threshold: Confidence threshold.

Returns:
Detection results as sv.Detections.
"""
import mlx.core as mx

if not isinstance(images, list):
images = [images]

orig_sizes = []
uint8_arrays = []
resolution = self._mlx_model.resolution

for img in images:
if isinstance(img, str):
if img.startswith("http"):
img = requests.get(img, stream=True).raw
img = Image.open(img)

if isinstance(img, torch.Tensor):
# Convert CHW float [0,1] tensor to HWC uint8
img = (img.permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)
elif isinstance(img, Image.Image):
img = np.array(img)

if img.ndim == 2:
img = np.stack([img] * 3, axis=-1)
elif img.shape[2] == 4:
img = img[:, :, :3]

orig_sizes.append((img.shape[0], img.shape[1]))

# Resize to model resolution using PIL (fast CPU resize)
pil_img = Image.fromarray(img)
pil_img = pil_img.resize((resolution, resolution), Image.BILINEAR)
uint8_arrays.append(np.array(pil_img))

# Stack into batch and run MLX inference
batch = np.stack(uint8_arrays)
x = mx.array(batch)
outputs = self._mlx_model.forward(x)

# Postprocess to numpy
results = self._mlx_model.postprocess(outputs, orig_sizes)

detections_list = []
for result in results:
scores = result["scores"]
labels = result["labels"]
boxes = result["boxes"]

keep = scores > threshold
scores = scores[keep]
labels = labels[keep]
boxes = boxes[keep]

mask = None
if "masks" in result:
mask = (result["masks"][keep] > 0.5).astype(bool)

detections = sv.Detections(
xyxy=boxes.astype(np.float32),
confidence=scores.astype(np.float32),
class_id=labels.astype(np.intp),
mask=mask,
)
detections_list.append(detections)

return detections_list if len(detections_list) > 1 else detections_list[0]

def deploy_to_roboflow(
self,
workspace: str,
Expand Down
101 changes: 101 additions & 0 deletions src/rfdetr/mlx/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
# ------------------------------------------------------------------------
# RF-DETR
# Copyright (c) 2025 Roboflow. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------

"""MLX backend for RF-DETR inference on Apple Silicon.

Provides native Metal-accelerated inference using MLX, achieving up to 6x
speedup over PyTorch MPS on Apple Silicon hardware (M1-M4).

Usage::

from rfdetr import RFDETRNano

model = RFDETRNano()
model.optimize_for_inference(backend="mlx")
detections = model.predict(image)
"""

from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
from rfdetr.mlx.inference import MLXInferenceModel, MLXSegInferenceModel


def is_mlx_available() -> bool:
    """Check whether MLX is available on this system.

    Returns:
        True if running on macOS with MLX installed, False otherwise.
    """
    import platform

    # Check the platform first so non-Darwin systems never pay the cost of
    # importing mlx.core (which can only succeed on macOS anyway).
    if platform.system() != "Darwin":
        return False

    try:
        import mlx.core  # noqa: F401
    except ImportError:
        return False
    return True


def build_mlx_inference(
    model_config: object,
    pytorch_model: object,
) -> "MLXInferenceModel":
    """Create a compiled MLX inference model from a PyTorch RF-DETR model.

    PyTorch weights are converted to MLX format, the MLX model graph is
    constructed, cast to FP16, and the whole forward pass is compiled for
    Metal execution.

    Args:
        model_config: RF-DETR model configuration (e.g., RFDETRNanoConfig).
        pytorch_model: The rfdetr.main.Model instance with loaded weights.

    Returns:
        Compiled MLX inference model ready for predict() calls.

    Raises:
        RuntimeError: If MLX is not available on this system.
    """
    if is_mlx_available():
        from rfdetr.mlx.inference import MLXInferenceModel

        return MLXInferenceModel.from_pytorch(model_config, pytorch_model)

    raise RuntimeError(
        "MLX is not available. MLX requires macOS on Apple Silicon. Install with: pip install 'rfdetr[mlx]'"
    )


def build_mlx_seg_inference(
    model_config: object,
    pytorch_model: object,
) -> "MLXSegInferenceModel":
    """Create a compiled MLX segmentation inference model from a PyTorch RF-DETR seg model.

    PyTorch weights (backbone, decoder, and segmentation head) are converted
    to MLX format, the MLX model graph is constructed, cast to FP16, and the
    whole forward pass is compiled for Metal execution.

    Args:
        model_config: RF-DETR segmentation model configuration (e.g., RFDETRSegNanoConfig).
        pytorch_model: The rfdetr.main.Model instance with loaded weights.

    Returns:
        Compiled MLX segmentation inference model ready for predict() calls.

    Raises:
        RuntimeError: If MLX is not available on this system.
    """
    if is_mlx_available():
        from rfdetr.mlx.inference import MLXSegInferenceModel

        return MLXSegInferenceModel.from_pytorch(model_config, pytorch_model)

    raise RuntimeError(
        "MLX is not available. MLX requires macOS on Apple Silicon. Install with: pip install 'rfdetr[mlx]'"
    )
Loading
Loading