diff --git a/pyproject.toml b/pyproject.toml index c662c6ee..fe106467 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -45,6 +45,7 @@ dependencies = [ [project.optional-dependencies] detection = ["inference-models>=0.19.0"] +tune = ["optuna>=3.0.0"] [project.scripts] trackers = "trackers.scripts.__main__:main" diff --git a/test/core/test_registration.py b/test/core/test_registration.py index 7986649d..cd047d9b 100644 --- a/test/core/test_registration.py +++ b/test/core/test_registration.py @@ -4,7 +4,8 @@ # Licensed under the Apache License, Version 2.0 [see LICENSE for details] # ------------------------------------------------------------------------ -from typing import Any +import inspect +from typing import Any, ClassVar import pytest @@ -268,6 +269,85 @@ def test_tracker_params_have_descriptions(self, tracker_id: str) -> None: assert has_descriptions +class TestSearchSpaceValidation: + """Tests for search_space ClassVar validation in __init_subclass__.""" + + def test_search_space_keys_match_init_params(self) -> None: + """ByteTrack and SORT search_space keys are valid __init__ parameters.""" + from trackers import ByteTrackTracker, SORTTracker + + for tracker_cls in (ByteTrackTracker, SORTTracker): + init_params = set(inspect.signature(tracker_cls.__init__).parameters) - { + "self" + } + for key in tracker_cls.search_space: + assert key in init_params, ( + f"{tracker_cls.__name__}.search_space has invalid key: {key}" + ) + + def test_search_space_invalid_key_raises_value_error(self) -> None: + """A tracker with search_space key not in __init__ raises ValueError.""" + + with pytest.raises(ValueError, match=r"search_space key .* is not a parameter"): + + class BadTracker(BaseTracker): + tracker_id = "bad" + search_space: ClassVar[dict[str, dict]] = { + "nonexistent_param": { + "type": "uniform", + "range": [0, 1], + } + } + + def __init__(self) -> None: + pass + + def update(self, detections: Any) -> Any: + return detections + + def reset(self) -> None: + pass + + def test_tracker_without_search_space_works(self) -> None: + """Trackers without search_space are still valid.""" + + class MinimalTracker(BaseTracker): + tracker_id = "minimal" + + def __init__(self) -> None: + pass + + def update(self, detections: Any) -> Any: + return detections + + def reset(self) -> None: + pass + + assert ( + not hasattr(MinimalTracker, "search_space") + or getattr(MinimalTracker, "search_space", None) is None + ) + assert "minimal" in BaseTracker._registered_trackers() + + def test_tracker_with_empty_search_space_works(self) -> None: + """Trackers with empty search_space skip validation.""" + + class EmptySpaceTracker(BaseTracker): + tracker_id = "empty_space" + search_space: ClassVar[dict[str, dict]] = {} + + def __init__(self, x: int = 1) -> None: + pass + + def update(self, detections: Any) -> Any: + return detections + + def reset(self) -> None: + pass + + assert "empty_space" in BaseTracker._registered_trackers() + + class TestTrackerInstantiation: @pytest.mark.parametrize("tracker_id", ["bytetrack", "sort"]) def test_instantiate_with_defaults(self, tracker_id: str) -> None: diff --git a/test/tune/__init__.py b/test/tune/__init__.py new file mode 100644 index 00000000..57226e88 --- /dev/null +++ b/test/tune/__init__.py @@ -0,0 +1,5 @@ +# ------------------------------------------------------------------------ +# Trackers +# Copyright (c) 2026 Roboflow. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------ diff --git a/test/tune/test_tuner.py b/test/tune/test_tuner.py new file mode 100644 index 00000000..3c72683b --- /dev/null +++ b/test/tune/test_tuner.py @@ -0,0 +1,218 @@ +# ------------------------------------------------------------------------ +# Trackers +# Copyright (c) 2026 Roboflow. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------ + +from __future__ import annotations + +from pathlib import Path +from unittest.mock import patch + +import pytest + +from trackers.eval.results import ( + BenchmarkResult, + CLEARMetrics, + HOTAMetrics, + IdentityMetrics, + SequenceResult, +) +from trackers.tune.tuner import Tuner, _extract_metric + +optuna = pytest.importorskip("optuna") + + +_MOT_LINE = "1,-1,10,20,100,80,0.9,1\n" + + +def _make_benchmark_result( + mota: float = 0.75, + hota: float | None = None, + idf1: float | None = None, +) -> BenchmarkResult: + clear = CLEARMetrics( + MOTA=mota, + MOTP=0.8, + MODA=0.76, + CLR_Re=0.8, + CLR_Pr=0.9, + MTR=0.7, + PTR=0.2, + MLR=0.1, + sMOTA=0.72, + CLR_TP=100, + CLR_FN=20, + CLR_FP=10, + IDSW=5, + MT=7, + PT=2, + ML=1, + Frag=3, + ) + hota_metrics = ( + HOTAMetrics( + HOTA=hota, + DetA=0.7, + AssA=0.65, + DetRe=0.72, + DetPr=0.85, + AssRe=0.68, + AssPr=0.9, + LocA=0.78, + OWTA=0.69, + HOTA_TP=1000, + HOTA_FN=300, + HOTA_FP=200, + ) + if hota is not None + else None + ) + identity_metrics = ( + IdentityMetrics(IDF1=idf1, IDR=0.7, IDP=0.8, IDTP=90, IDFN=15, IDFP=10) + if idf1 is not None + else None + ) + return BenchmarkResult( + sequences={}, + aggregate=SequenceResult( + sequence="COMBINED", + CLEAR=clear, + HOTA=hota_metrics, + Identity=identity_metrics, + ), + ) + + +def _setup_dirs(tmp_path: Path) -> tuple[Path, Path]: + gt_dir = tmp_path / "gt" + gt_dir.mkdir() + (gt_dir / "seq1.txt").write_text(_MOT_LINE) + + det_dir = tmp_path / "det" + det_dir.mkdir() + (det_dir / "seq1.txt").write_text(_MOT_LINE) + + return gt_dir, det_dir + + +@pytest.mark.parametrize( + "metric,make_kwargs,expected", + [ + ("MOTA", {"mota": 0.75}, 0.75), + ("HOTA", {"hota": 0.62}, 0.62), + ("IDF1", {"idf1": 0.71}, 0.71), + ], +) +def test_extract_metric(metric: str, make_kwargs: dict, expected: float) -> None: + result = _make_benchmark_result(**make_kwargs) + assert _extract_metric(result, metric) == pytest.approx(expected) + + +@pytest.mark.parametrize("metric", ["NONEXISTENT", "HOTA"]) +def test_extract_metric_raises(metric: str) -> None: + # Default result has no HOTA/Identity families — both should raise + result = _make_benchmark_result() + with pytest.raises(ValueError, match=r"not found in BenchmarkResult\.aggregate"): + _extract_metric(result, metric) + + +class TestTunerInit: + def test_raises_for_unknown_tracker(self, tmp_path: Path) -> None: + gt_dir, det_dir = _setup_dirs(tmp_path) + with pytest.raises(ValueError, match=r"not registered"): + Tuner("nonexistent_tracker", gt_dir, det_dir) + + def test_raises_for_tracker_without_search_space(self, tmp_path: Path) -> None: + from trackers.core.base import BaseTracker + + class _NoSearchSpaceTracker(BaseTracker): + tracker_id = "_test_no_ss" + search_space = None + + def update(self, detections): # type: ignore[override] + return detections + + def reset(self) -> None: + pass + + gt_dir, det_dir = _setup_dirs(tmp_path) + try: + with pytest.raises(ValueError, match=r"does not define a search_space"): + Tuner("_test_no_ss", gt_dir, det_dir) + finally: + BaseTracker._registry.pop("_test_no_ss", None) + + def test_raises_when_no_sequences_found(self, tmp_path: Path) -> None: + gt_dir = tmp_path / "gt" + gt_dir.mkdir() + det_dir = tmp_path / "det" + det_dir.mkdir() + with pytest.raises(ValueError, match=r"No sequences found"): + Tuner("bytetrack", gt_dir, det_dir) + + def test_valid_init_stores_attributes(self, tmp_path: Path) -> None: + gt_dir, det_dir = _setup_dirs(tmp_path) + tuner = Tuner("bytetrack", gt_dir, det_dir, n_trials=10) + assert tuner._tracker_id == "bytetrack" + assert tuner._objective_metric == "MOTA" + assert tuner._metrics == ["CLEAR"] + assert tuner._n_trials == 10 + assert tuner._sequences == ["seq1"] + + def test_seqmap_filters_sequences(self, tmp_path: Path) -> None: + gt_dir, det_dir = _setup_dirs(tmp_path) + (det_dir / "seq2.txt").write_text(_MOT_LINE) + seqmap = tmp_path / "seqmap.txt" + seqmap.write_text("seq1\n") + tuner = Tuner("bytetrack", gt_dir, det_dir, seqmap=seqmap) + assert tuner._sequences == ["seq1"] + + +class TestTunerRun: + def test_run_returns_dict_with_search_space_keys(self, tmp_path: Path) -> None: + from trackers import ByteTrackTracker + + assert ByteTrackTracker.search_space is not None + expected_keys = set(ByteTrackTracker.search_space.keys()) + gt_dir, det_dir = _setup_dirs(tmp_path) + + with ( + patch( + "trackers.tune.tuner.evaluate_mot_sequences", + return_value=_make_benchmark_result(), + ), + patch("trackers.tune.tuner._run_tracker_on_detections"), + ): + tuner = Tuner("bytetrack", gt_dir, det_dir, n_trials=2) + best = tuner.run() + + assert isinstance(best, dict) + assert set(best.keys()) == expected_keys + + def test_run_calls_tracker_reset_per_sequence(self, tmp_path: Path) -> None: + """reset() must be called once per sequence per trial.""" + from trackers import SORTTracker + + reset_calls: list[int] = [] + gt_dir, det_dir = _setup_dirs(tmp_path) + (det_dir / "seq2.txt").write_text(_MOT_LINE) # two sequences → two resets + + original_reset = SORTTracker.reset + + def _counting_reset(self_tracker: SORTTracker) -> None: + reset_calls.append(1) + original_reset(self_tracker) + + with ( + patch( + "trackers.tune.tuner.evaluate_mot_sequences", + return_value=_make_benchmark_result(), + ), + patch("trackers.tune.tuner._run_tracker_on_detections"), + patch.object(SORTTracker, "reset", _counting_reset), + ): + tuner = Tuner("sort", gt_dir, det_dir, n_trials=1) + tuner.run() + + assert len(reset_calls) == 2 # 1 trial * 2 sequences diff --git a/trackers/core/base.py b/trackers/core/base.py index a56d2c5a..91c13923 100644 --- a/trackers/core/base.py +++ b/trackers/core/base.py @@ -222,18 +222,40 @@ class BaseTracker(ABC): Subclasses that define `tracker_id` are automatically registered and become discoverable. Parameter metadata is extracted from __init__ for CLI integration. + Attributes: + tracker_id: Unique identifier for the tracker. Subclasses must define + this to be registered. + search_space: Hyperparameter search space for tuning. Each key must + match an `__init__` parameter. Values are dicts with `type` + (`"randint"` or `"uniform"`) and `range` (`[low, high]`). """ _registry: ClassVar[dict[str, TrackerInfo]] = {} tracker_id: ClassVar[str | None] = None + search_space: ClassVar[dict[str, dict] | None] = None def __init_subclass__(cls, **kwargs: Any) -> None: """Register subclass in the tracker registry if it defines tracker_id. Extracts parameter metadata from __init__ at class definition time. + Validates search_space (if present) against __init__ parameters. """ super().__init_subclass__(**kwargs) + # Validate search_space keys match __init__ parameters (search_space optional) + search_space = getattr(cls, "search_space", None) + if search_space is not None and len(search_space) > 0: + init_params = { + n for n in inspect.signature(cls.__init__).parameters if n != "self" + } + for key in search_space: + if key not in init_params: + raise ValueError( + f"{cls.__name__}: search_space key {key!r} is not a " + f"parameter of __init__. " + f"Valid parameters: {sorted(init_params)}" + ) + tracker_id = getattr(cls, "tracker_id", None) if tracker_id is not None: BaseTracker._registry[tracker_id] = TrackerInfo( diff --git a/trackers/core/bytetrack/tracker.py b/trackers/core/bytetrack/tracker.py index 3792266c..b7915e1b 100644 --- a/trackers/core/bytetrack/tracker.py +++ b/trackers/core/bytetrack/tracker.py @@ -4,6 +4,8 @@ # Licensed under the Apache License, Version 2.0 [see LICENSE for details] # ------------------------------------------------------------------------ +from typing import ClassVar + import numpy as np import supervision as sv from scipy.optimize import linear_sum_assignment @@ -57,6 +59,14 @@ class ByteTrackTracker(BaseTracker): tracker_id = "bytetrack" + search_space: ClassVar[dict[str, dict]] = { + "lost_track_buffer": {"type": "randint", "range": [10, 91]}, + "track_activation_threshold": {"type": "uniform", "range": [0.1, 0.9]}, + "minimum_iou_threshold": {"type": "uniform", "range": [0.05, 0.7]}, + "high_conf_det_threshold": {"type": "uniform", "range": [0.3, 0.8]}, + "minimum_consecutive_frames": {"type": "randint", "range": [1, 4]}, + } + def __init__( self, lost_track_buffer: int = 30, diff --git a/trackers/core/ocsort/tracker.py b/trackers/core/ocsort/tracker.py index f84fa7b6..891bcc7f 100644 --- a/trackers/core/ocsort/tracker.py +++ b/trackers/core/ocsort/tracker.py @@ -4,6 +4,8 @@ # Licensed under the Apache License, Version 2.0 [see LICENSE for details] # ------------------------------------------------------------------------ +from typing import ClassVar + import numpy as np import supervision as sv from scipy.optimize import linear_sum_assignment @@ -61,6 +63,15 @@ class OCSORTTracker(BaseTracker): tracker_id = "ocsort" + search_space: ClassVar[dict[str, dict]] = { + "lost_track_buffer": {"type": "randint", "range": [10, 61]}, + "minimum_iou_threshold": {"type": "uniform", "range": [0.1, 0.5]}, + "minimum_consecutive_frames": {"type": "randint", "range": [3, 6]}, + "direction_consistency_weight": {"type": "uniform", "range": [0.0, 0.5]}, + "high_conf_det_threshold": {"type": "uniform", "range": [0.4, 0.8]}, + "delta_t": {"type": "randint", "range": [1, 4]}, + } + def __init__( self, lost_track_buffer: int = 30, diff --git a/trackers/core/sort/tracker.py b/trackers/core/sort/tracker.py index 5e0f2a5f..2470e83b 100644 --- a/trackers/core/sort/tracker.py +++ b/trackers/core/sort/tracker.py @@ -4,6 +4,8 @@ # Licensed under the Apache License, Version 2.0 [see LICENSE for details] # ------------------------------------------------------------------------ +from typing import ClassVar + import numpy as np import supervision as sv from scipy.optimize import linear_sum_assignment @@ -55,6 +57,13 @@ class SORTTracker(BaseTracker): tracker_id = "sort" + search_space: ClassVar[dict[str, dict]] = { + "lost_track_buffer": {"type": "randint", "range": [10, 91]}, + "track_activation_threshold": {"type": "uniform", "range": [0.1, 0.9]}, + "minimum_consecutive_frames": {"type": "randint", "range": [1, 4]}, + "minimum_iou_threshold": {"type": "uniform", "range": [0.05, 0.7]}, + } + def __init__( self, lost_track_buffer: int = 30, diff --git a/trackers/tune/__init__.py b/trackers/tune/__init__.py new file mode 100644 index 00000000..5f233ccf --- /dev/null +++ b/trackers/tune/__init__.py @@ -0,0 +1,11 @@ +# ------------------------------------------------------------------------ +# Trackers +# Copyright (c) 2026 Roboflow. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------ + +"""Hyperparameter tuning utilities for MOT trackers.""" + +from trackers.tune.tuner import Tuner + +__all__ = ["Tuner"] diff --git a/trackers/tune/tuner.py b/trackers/tune/tuner.py new file mode 100644 index 00000000..dd4cd8db --- /dev/null +++ b/trackers/tune/tuner.py @@ -0,0 +1,247 @@ +# ------------------------------------------------------------------------ +# Trackers +# Copyright (c) 2026 Roboflow. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------ + +"""Optuna-based hyperparameter tuner for registered MOT trackers.""" + +from __future__ import annotations + +import tempfile +from pathlib import Path +from typing import TYPE_CHECKING, Any + +import supervision as sv + +from trackers.core.base import BaseTracker +from trackers.eval.evaluate import evaluate_mot_sequences +from trackers.eval.results import BenchmarkResult +from trackers.io.mot import _load_mot_file, _mot_frame_to_detections, _MOTOutput + +if TYPE_CHECKING: + import optuna + + +class Tuner: + """Wraps Optuna to tune hyperparameters of a registered MOT tracker. + + Uses the tracker's ``search_space`` ClassVar to define the Optuna parameter + distributions. For each trial the tuner instantiates the tracker with the + sampled parameters, runs it frame-by-frame over every sequence using + pre-computed detections from ``detections_dir``, and evaluates the + predictions with ``evaluate_mot_sequences``. + + Args: + tracker_id: Registered tracker identifier (e.g. ``"bytetrack"``). + gt_dir: Directory of ground-truth MOT files. + detections_dir: Directory of pre-computed detection files in MOT flat + format (one ``{seq}.txt`` per sequence). + metrics: Metric families to compute. Supported values are + ``["CLEAR", "HOTA", "Identity"]``. Defaults to ``["CLEAR"]``. + objective: Scalar metric field to maximise (e.g. ``"MOTA"``, + ``"HOTA"``, ``"IDF1"``). Defaults to ``"MOTA"``. + n_trials: Number of Optuna trials to run. Defaults to ``100``. + threshold: IoU threshold forwarded to ``evaluate_mot_sequences``. + Defaults to ``0.5``. + seqmap: Optional path to a sequence map file. When provided only the + listed sequences are evaluated. + + Examples: + >>> from trackers.tune import Tuner # doctest: +SKIP + >>> + >>> tuner = Tuner( # doctest: +SKIP + ... tracker_id="bytetrack", + ... gt_dir="data/gt/", + ... detections_dir="data/det/", + ... n_trials=50, + ... ) + >>> best_params = tuner.run() # doctest: +SKIP + """ + + def __init__( + self, + tracker_id: str, + gt_dir: str | Path, + detections_dir: str | Path, + metrics: list[str] | None = None, + objective: str = "MOTA", + n_trials: int = 100, + threshold: float = 0.5, + seqmap: str | Path | None = None, + ) -> None: + try: + import optuna as _optuna + + self._optuna = _optuna + except ImportError as exc: + raise ImportError( + "Error: optuna is required for hyperparameter tuning. " + "Install it with: pip install 'trackers[tune]'" + ) from exc + + tracker_info = BaseTracker._lookup_tracker(tracker_id) + if tracker_info is None: + raise ValueError( + f"Tracker {tracker_id!r} is not registered. " + f"Available trackers: {BaseTracker._registered_trackers()}" + ) + + search_space = tracker_info.tracker_class.search_space + if not search_space: + raise ValueError( + f"Tracker {tracker_id!r} does not define a search_space. " + "Add a search_space ClassVar to enable tuning." + ) + + self._tracker_id = tracker_id + self._tracker_info = tracker_info + self._search_space: dict[str, dict] = search_space + self._gt_dir = Path(gt_dir) + self._detections_dir = Path(detections_dir) + self._metrics = metrics or ["CLEAR"] + self._objective_metric = objective + self._n_trials = n_trials + self._threshold = threshold + self._sequences = _discover_sequences(self._detections_dir, seqmap) + + if not self._sequences: + raise ValueError(f"No sequences found in {self._detections_dir}") + + def _objective(self, trial: optuna.Trial) -> float: + """Sample hyperparameters, run tracker over all sequences, return metric. + + Args: + trial: Optuna trial used to sample parameter values. + + Returns: + Scalar metric value for this trial. + """ + params: dict[str, Any] = {} + for name, spec in self._search_space.items(): + low, high = spec["range"] + if spec["type"] == "randint": + params[name] = trial.suggest_int(name, low, high) + else: + params[name] = trial.suggest_float(name, low, high) + + # Start from __init__ defaults and override with sampled params + kwargs: dict[str, Any] = { + n: p.default_value for n, p in self._tracker_info.parameters.items() + } + kwargs.update(params) + tracker = self._tracker_info.tracker_class(**kwargs) + + with tempfile.TemporaryDirectory() as tmp_dir: + output_dir = Path(tmp_dir) + for seq_name in self._sequences: + tracker.reset() + det_path = self._detections_dir / f"{seq_name}.txt" + pred_path = output_dir / f"{seq_name}.txt" + _run_tracker_on_detections(tracker, det_path, pred_path) + + result: BenchmarkResult = evaluate_mot_sequences( + gt_dir=self._gt_dir, + tracker_dir=output_dir, + metrics=self._metrics, + threshold=self._threshold, + ) + + return _extract_metric(result, self._objective_metric) + + def run(self) -> dict[str, Any]: + """Create an Optuna study, run trials, and return the best parameters. + + Returns: + Dictionary mapping each ``search_space`` parameter name to its + best value found across all trials. + """ + self.study = self._optuna.create_study( + direction="maximize", + study_name=f"trackers-tune-{self._tracker_id}", + ) + self.study.optimize(self._objective, n_trials=self._n_trials) + return dict(self.study.best_params) + + +def _discover_sequences( + detections_dir: str | Path, + seqmap: str | Path | None, +) -> list[str]: + """Return the list of sequence names to tune over. + + Reads sequence names from ``seqmap`` when provided, otherwise discovers + them by globbing ``*.txt`` files in ``detections_dir``. + + Args: + detections_dir: Directory containing ``{seq}.txt`` detection files. + seqmap: Optional sequence map file. Each non-comment, non-empty line + is treated as a sequence name. + + Returns: + Sorted list of sequence names. + """ + detections_dir = Path(detections_dir) + if seqmap is not None: + lines = Path(seqmap).read_text().splitlines() + return [ + ln.strip() + for ln in lines + if ln.strip() and not ln.startswith("#") and ln.strip().lower() != "name" + ] + return sorted(p.stem for p in detections_dir.glob("*.txt")) + + +def _run_tracker_on_detections( + tracker: BaseTracker, + det_path: Path, + pred_path: Path, +) -> None: + """Run a tracker on a MOT detection file and write predictions. + + Iterates every frame from 1 to the last frame in the detection file, + feeding ``sv.Detections.empty()`` for frames with no detections so the + tracker can age and prune its internal state correctly. + + Args: + tracker: Tracker instance already reset for this sequence. + det_path: Path to the MOT-format detection file. + pred_path: Destination path for the MOT-format prediction file. + """ + det_data = _load_mot_file(det_path) + max_frame = max(det_data.keys()) + + with _MOTOutput(pred_path) as mot_out: + for frame_idx in range(1, max_frame + 1): + if frame_idx in det_data: + dets = _mot_frame_to_detections(det_data[frame_idx]) + else: + dets = sv.Detections.empty() + tracked = tracker.update(dets) + mot_out.write(frame_idx, tracked) + + +def _extract_metric(result: BenchmarkResult, metric: str) -> float: + """Extract a scalar metric value from ``BenchmarkResult.aggregate``. + + Searches CLEAR, HOTA, and Identity metrics in order. + + Args: + result: Benchmark result returned by ``evaluate_mot_sequences``. + metric: Field name to extract (e.g. ``"MOTA"``, ``"HOTA"``, + ``"IDF1"``). + + Returns: + The metric value as a float.""" + agg = result.aggregate + for metrics_obj in (agg.CLEAR, agg.HOTA, agg.Identity): + if metrics_obj is None: + continue + value = metrics_obj.to_dict().get(metric) + if value is not None: + return float(value) + + raise ValueError( + f"Metric {metric!r} not found in BenchmarkResult.aggregate. " + "Ensure the corresponding metric family is included in `metrics`." + )