Pull request overview
Adds native MLX-based segmentation inference so RF-DETR segmentation models can run via optimize_for_inference(backend="mlx") on Apple Silicon, extending the existing MLX detection backend.
Changes:
- Introduces an MLX segmentation inference path (`MLXSegInferenceModel`) including an `MLXSegHead`, seg-weight conversion, and decoder intermediate outputs.
- Adds MLX routing + prediction support in `RFDETR.optimize_for_inference()` / `RFDETR.predict()`.
- Adds MLX-only test coverage and a new `mlx` optional dependency extra.
Reviewed changes
Copilot reviewed 10 out of 11 changed files in this pull request and generated 3 comments.
| File | Description |
|---|---|
| `src/rfdetr/detr.py` | Adds `backend="mlx"` routing, MLX model caching, and an MLX prediction path. |
| `src/rfdetr/mlx/__init__.py` | Exposes MLX builders for detection vs segmentation inference models. |
| `src/rfdetr/mlx/inference.py` | Implements compiled MLX detection + segmentation inference and postprocessing. |
| `src/rfdetr/mlx/decoder.py` | Adds `return_intermediate` support for segmentation features. |
| `src/rfdetr/mlx/seg_head.py` | Adds MLX segmentation head implementation + builder from converted weights. |
| `src/rfdetr/mlx/convert_weights.py` | Adds segmentation-head weight extraction/remapping and conv transposition support. |
| `src/rfdetr/mlx/backbone.py` | Adds MLX DINOv2 backbone implementation used by the MLX inference pipeline. |
| `tests/models/test_mlx_inference.py` | Adds MLX detection backend tests and routing tests for seg vs det builds. |
| `tests/models/test_mlx_seg_inference.py` | Adds MLX segmentation backend tests (seg weights, intermediates, seg head, postprocess). |
| `pyproject.toml` | Adds `mlx` optional extra and registers an `mlx` pytest marker. |
| `.gitignore` | Ignores additional local artifacts (e.g., `.pth`, scratch/demo outputs). |
Codecov Report
❌ Your patch check has failed because the patch coverage (3%) is below the target coverage (95%). You can increase the patch coverage or adjust the target coverage.
Additional details and impacted files
@@ Coverage Diff @@
## develop #767 +/- ##
=======================================
- Coverage 77% 69% -8%
=======================================
Files 97 103 +6
Lines 7538 8426 +888
=======================================
+ Hits 5801 5830 +29
- Misses 1737 2596 +859
a9f2ed9 to 7d4a6dc
7d4a6dc to 82f67e4
…r torch annotation
40a7690 to ede91f3
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
… postprocess
Addresses review comment by @Copilot (PR roboflow#767). `np.argpartition(-flat, num_select)` raises `ValueError` when `num_select` equals `flat.size`; changed `kth` to `num_select - 1` to match the detection postprocess. Added a `num_select == 0` early exit (with a correctly shaped empty masks array) for parity with the detection path.
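The `kth` fix described in this commit can be illustrated with a small standalone sketch. The helper name `top_k_indices` and the final sort are assumptions made for illustration; this is not the code from `inference.py`.

```python
import numpy as np


def top_k_indices(scores: np.ndarray, num_select: int) -> np.ndarray:
    """Return indices of the num_select highest scores, best first.

    Hypothetical helper; only the argpartition call mirrors the commit message.
    """
    flat = scores.ravel()
    num_select = min(num_select, flat.size)
    if num_select == 0:
        # Early exit: nothing to select, return an empty index array.
        return np.empty(0, dtype=np.intp)
    # kth is zero-based, so num_select - 1 stays in bounds even when
    # num_select == flat.size (kth == flat.size would raise ValueError).
    part = np.argpartition(-flat, num_select - 1)[:num_select]
    # argpartition does not order the selected block, so sort it explicitly.
    return part[np.argsort(-flat[part])]
```

Calling it with `num_select` equal to the number of scores exercises exactly the boundary case that previously raised.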
…dcoding
Addresses review finding [HIGH] by @maintainer-review (PR roboflow#767). `set(range(3, depth, 3))` = {3, 6, 9} is correct for Nano/Small/Medium/Large (`out_feature_indexes=[3,6,9,12]`) but silently wrong for `RFDETRBaseConfig` and `RFDETRLargeDeprecatedConfig` (`out_feature_indexes=[2,5,8,11]`), where the PyTorch backbone runs full attention at {2, 5, 8, 11}.
Addresses review finding [MEDIUM] by @maintainer-review (PR roboflow#767) backbone.py:interpolate_pos_embed() imports scipy.ndimage.zoom; scipy was only in [train], so pip install 'rfdetr[mlx]' users got a confusing ModuleNotFoundError at inference time when pos-embed interpolation ran.
Addresses review findings [MEDIUM] by @maintainer-review (PR roboflow#767):
- `optimize_for_inference(backend='typo')` previously fell through to PyTorch silently; now raises `ValueError` with the supported options listed
- `predict(shape=...)` was silently ignored when `backend='mlx'`; now raises `NotImplementedError` with a workaround hint
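A minimal sketch of the fail-loudly behaviour described above. The function names, the `SUPPORTED_BACKENDS` tuple, and the error wording are all assumptions; the actual checks live inside `RFDETR.optimize_for_inference()` and `RFDETR.predict()` and may differ.

```python
from typing import Optional, Tuple

# Assumed set of backend names, for illustration only.
SUPPORTED_BACKENDS = ("torch", "mlx")


def validate_backend(backend: str) -> str:
    """Reject unknown backend names instead of silently falling back."""
    if backend not in SUPPORTED_BACKENDS:
        raise ValueError(
            f"Unsupported backend {backend!r}; expected one of {SUPPORTED_BACKENDS}"
        )
    return backend


def check_predict_shape(backend: str, shape: Optional[Tuple[int, int]]) -> None:
    """Refuse an argument the MLX path would otherwise ignore."""
    if backend == "mlx" and shape is not None:
        raise NotImplementedError("shape=... is not supported with backend='mlx'")
```

Raising early trades a little strictness for errors that point at the actual mistake, rather than results that quietly come from a different backend.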
…ection
Addresses review finding [MEDIUM] by @maintainer-review (PR roboflow#767). MLX source files import `mlx.core` at module level, which is unavailable on non-Darwin hosts. Previously `--doctest-modules` was globally disabled as a workaround, silently killing doctest coverage for all other modules. Add a root `conftest.py` with `collect_ignore_glob` to skip `src/rfdetr/mlx/*.py` during collection and restore the `--doctest-modules` flag.
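A root `conftest.py` along these lines would produce the described effect. Making the ignore conditional on whether `mlx` is importable is an assumption; the commit may ignore the files unconditionally.

```python
# conftest.py at the repository root (sketch, not the PR's actual file)
import importlib.util

# pytest reads this module-level list and skips matching paths during
# collection, so --doctest-modules never imports them.
collect_ignore_glob = []

# Assumption: only skip the MLX sources when mlx itself cannot be imported.
if importlib.util.find_spec("mlx") is None:
    collect_ignore_glob.append("src/rfdetr/mlx/*.py")
```

Because the skip is scoped to one directory, doctests in every other module are collected again.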
Addresses review finding [LOW] by @maintainer-review (PR roboflow#767) _optimized_half was set to False in remove_optimized_model() but never read anywhere in the codebase; dead code.
Addresses review finding [LOW] by @maintainer-review (PR roboflow#767) pytest.mark.mlx was registered in pyproject.toml but never applied to the test files, so 'pytest -m mlx' silently matched nothing. Add a module-level pytestmark so '-m mlx' correctly selects all MLX tests.
…/rf-detr into feat/mlx-inference
…mage.zoom
Addresses review finding [HIGH] by @maintainer-review (PR roboflow#767). The original 100-iteration sequential `cv2.resize` loop dominated inference time (~2-5ms per mask on CPU). `scipy.ndimage.zoom` operates on the full (N, H, W) array in one call (`order=1` = bilinear), cutting the resize step to a single operation. Also drops the undeclared `cv2` dependency (available transitively via supervision, but not listed in `[mlx]` extras).
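The batched resize can be sketched as follows. The helper name `resize_masks` and the empty-input guard are assumptions; only the single `scipy.ndimage.zoom` call with `order=1` over the (N, H, W) array comes from the commit message.

```python
import numpy as np
from scipy.ndimage import zoom


def resize_masks(masks: np.ndarray, out_h: int, out_w: int) -> np.ndarray:
    """Resize a stack of masks (N, H, W) to (N, out_h, out_w) in one call."""
    n, h, w = masks.shape
    if n == 0:
        # Keep the output correctly shaped when there are no masks.
        return np.zeros((0, out_h, out_w), dtype=masks.dtype)
    # Factor 1.0 on the first axis leaves the mask dimension untouched;
    # order=1 selects linear interpolation along the two spatial axes.
    return zoom(masks, (1.0, out_h / h, out_w / w), order=1)
```

The interpolated values are not guaranteed to match `cv2.resize` pixel for pixel, since the two libraries align sample coordinates differently.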
…attn_layers
Addresses review findings [HIGH]/[MEDIUM] by @maintainer-review (PR roboflow#767):
- `test_returns_false_when_mlx_not_installed`: the assertion was `check() is False or sys.platform != "darwin"`, which is vacuously true on every non-Darwin CI runner; removed the short-circuit so the mock is actually validated on all platforms
- `backbone.py`: added a comment explaining why `full_attn_layers = set(feature_indices)` matches PyTorch's global-attention schedule (`out_feature_indexes` excludes those layers from windowed attention; `feature_indices` are the 0-indexed form)
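The vacuous assertion is easy to reproduce in isolation. In this sketch `is_mlx_available` is a hypothetical stand-in for the real availability check, and patching `importlib.util.find_spec` is an assumed way of simulating a missing `mlx` install.

```python
import importlib.util
from unittest import mock


def is_mlx_available() -> bool:
    """Hypothetical availability check: True only if `mlx` can be located."""
    return importlib.util.find_spec("mlx") is not None


def test_returns_false_when_mlx_not_installed() -> None:
    # Simulate a host where the mlx package cannot be found.
    with mock.patch("importlib.util.find_spec", return_value=None):
        result = is_mlx_available()
    # Weak form: `result is False or sys.platform != "darwin"` passes on any
    # non-Darwin runner no matter what `result` is.
    # Strong form: the mocked outcome is checked on every platform.
    assert result is False
```

With the platform clause removed, a regression in the check fails the test on Linux CI as well as on macOS.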
Consolidates segmentation inference tests into `test_mlx_inference.py` for unified MLX test coverage. Removes the now-redundant `test_mlx_seg_inference.py`. Updates file-level docstrings to include segmentation tests.
What does this PR do?
Adds native MLX segmentation inference so `RFDETRSegNano/Small/Medium/Large` work with `optimize_for_inference(backend="mlx")` on Apple Silicon.
- `MLXSegInferenceModel` with compiled FP16 forward pass (backbone → decoder with intermediates → seg head → masks)
- `SegHead` module (depthwise conv blocks + einsum mask generation)
- Weight conversion (`convert_seg_weights`) with key remapping
- Decoder `return_intermediate` support to expose spatial features and per-layer hidden states
- Routing: `segmentation_head=True` in config → `MLXSegInferenceModel`, otherwise → `MLXInferenceModel`
- Results returned as `sv.Detections(mask=...)`

Related Issue(s): None
Type of Change
Testing
Test details:
- `test_mlx_seg_inference.py`: weight conversion key remapping, conv transposition, num_blocks auto-detection, decoder `return_intermediate` shapes, `SegHead` output shapes, `build_seg_head` from dict, postprocess output keys/shapes/values
- `test_mlx_inference.py`: verifies seg config routes to `MLXSegInferenceModel` and det config routes to `MLXInferenceModel`
- Test suite: `pytest src/ tests/ -n 2 -m "not gpu"`
- `pre-commit run --all-files` clean
- `RFDETRSegLarge`: correct person/couch/remote detections with masks via MLX backend

Checklist
Additional Context
Tested on Apple M4 Pro. The segmentation pipeline reuses the same backbone and decoder as the detection MLX backend, adding only the seg head and `return_intermediate` plumbing. No changes to existing detection inference behavior.
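For readers unfamiliar with the einsum mask generation mentioned for the seg head, here is a rough NumPy sketch of the idea. The tensor layouts, (B, Q, C) for per-query embeddings and (B, C, H, W) for spatial features, and the function name are assumptions; the MLX `SegHead` may arrange its inputs differently and applies depthwise conv blocks beforehand.

```python
import numpy as np


def generate_mask_logits(query_embed: np.ndarray, spatial_feat: np.ndarray) -> np.ndarray:
    """Dot each query embedding with the feature vector at every pixel.

    query_embed:  (B, Q, C) one embedding per object query (assumed layout)
    spatial_feat: (B, C, H, W) spatial feature map (assumed layout)
    returns:      (B, Q, H, W) one mask logit map per query
    """
    return np.einsum("bqc,bchw->bqhw", query_embed, spatial_feat)


# Usage: 2 images, 5 queries, 8 channels, 16x16 feature map.
rng = np.random.default_rng(0)
logits = generate_mask_logits(
    rng.standard_normal((2, 5, 8)), rng.standard_normal((2, 8, 16, 16))
)
```

Each output pixel is a single dot product over the channel axis, so the whole mask stack comes from one contraction rather than a loop over queries.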