Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ dependencies = [

[project.optional-dependencies]
detection = ["inference-models>=0.19.0"]
tune = ["optuna>=3.0.0"]

[project.scripts]
trackers = "trackers.scripts.__main__:main"
Expand Down
82 changes: 81 additions & 1 deletion test/core/test_registration.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------

from typing import Any
import inspect
from typing import Any, ClassVar

import pytest

Expand Down Expand Up @@ -268,6 +269,85 @@ def test_tracker_params_have_descriptions(self, tracker_id: str) -> None:
assert has_descriptions


class TestSearchSpaceValidation:
"""Tests for search_space ClassVar validation in __init_subclass__."""

def test_search_space_keys_match_init_params(self) -> None:
"""ByteTrack and SORT search_space keys are valid __init__ parameters."""
from trackers import ByteTrackTracker, SORTTracker

for tracker_cls in (ByteTrackTracker, SORTTracker):
init_params = set(inspect.signature(tracker_cls.__init__).parameters) - {
"self"
}
for key in tracker_cls.search_space:
assert key in init_params, (
f"{tracker_cls.__name__}.search_space has invalid key: {key}"
)

def test_search_space_invalid_key_raises_value_error(self) -> None:
"""A tracker with search_space key not in __init__ raises ValueError."""

with pytest.raises(ValueError, match=r"search_space key .* is not a parameter"):

class BadTracker(BaseTracker):
tracker_id = "bad"
search_space: ClassVar[dict[str, dict]] = {
"nonexistent_param": {
"type": "uniform",
"range": [0, 1],
}
}

def __init__(self) -> None:
pass

def update(self, detections: Any) -> Any:
return detections

def reset(self) -> None:
pass

def test_tracker_without_search_space_works(self) -> None:
"""Trackers without search_space are still valid."""

class MinimalTracker(BaseTracker):
tracker_id = "minimal"

def __init__(self) -> None:
pass

def update(self, detections: Any) -> Any:
return detections

def reset(self) -> None:
pass

assert (
not hasattr(MinimalTracker, "search_space")
or getattr(MinimalTracker, "search_space", None) is None
)
assert "minimal" in BaseTracker._registered_trackers()

def test_tracker_with_empty_search_space_works(self) -> None:
"""Trackers with empty search_space skip validation."""

class EmptySpaceTracker(BaseTracker):
tracker_id = "empty_space"
search_space: ClassVar[dict[str, dict]] = {}

def __init__(self, x: int = 1) -> None:
pass

def update(self, detections: Any) -> Any:
return detections

def reset(self) -> None:
pass

assert "empty_space" in BaseTracker._registered_trackers()


class TestTrackerInstantiation:
@pytest.mark.parametrize("tracker_id", ["bytetrack", "sort"])
def test_instantiate_with_defaults(self, tracker_id: str) -> None:
Expand Down
5 changes: 5 additions & 0 deletions test/tune/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# ------------------------------------------------------------------------
# Trackers
# Copyright (c) 2026 Roboflow. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------
218 changes: 218 additions & 0 deletions test/tune/test_tuner.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,218 @@
# ------------------------------------------------------------------------
# Trackers
# Copyright (c) 2026 Roboflow. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------

from __future__ import annotations

from pathlib import Path
from unittest.mock import patch

import pytest

from trackers.eval.results import (
BenchmarkResult,
CLEARMetrics,
HOTAMetrics,
IdentityMetrics,
SequenceResult,
)
from trackers.tune.tuner import Tuner, _extract_metric

optuna = pytest.importorskip("optuna")


_MOT_LINE = "1,-1,10,20,100,80,0.9,1\n"


def _make_benchmark_result(
mota: float = 0.75,
hota: float | None = None,
idf1: float | None = None,
) -> BenchmarkResult:
clear = CLEARMetrics(
MOTA=mota,
MOTP=0.8,
MODA=0.76,
CLR_Re=0.8,
CLR_Pr=0.9,
MTR=0.7,
PTR=0.2,
MLR=0.1,
sMOTA=0.72,
CLR_TP=100,
CLR_FN=20,
CLR_FP=10,
IDSW=5,
MT=7,
PT=2,
ML=1,
Frag=3,
)
hota_metrics = (
HOTAMetrics(
HOTA=hota,
DetA=0.7,
AssA=0.65,
DetRe=0.72,
DetPr=0.85,
AssRe=0.68,
AssPr=0.9,
LocA=0.78,
OWTA=0.69,
HOTA_TP=1000,
HOTA_FN=300,
HOTA_FP=200,
)
if hota is not None
else None
)
identity_metrics = (
IdentityMetrics(IDF1=idf1, IDR=0.7, IDP=0.8, IDTP=90, IDFN=15, IDFP=10)
if idf1 is not None
else None
)
return BenchmarkResult(
sequences={},
aggregate=SequenceResult(
sequence="COMBINED",
CLEAR=clear,
HOTA=hota_metrics,
Identity=identity_metrics,
),
)


def _setup_dirs(tmp_path: Path) -> tuple[Path, Path]:
gt_dir = tmp_path / "gt"
gt_dir.mkdir()
(gt_dir / "seq1.txt").write_text(_MOT_LINE)

det_dir = tmp_path / "det"
det_dir.mkdir()
(det_dir / "seq1.txt").write_text(_MOT_LINE)

return gt_dir, det_dir


@pytest.mark.parametrize(
"metric,make_kwargs,expected",
[
("MOTA", {"mota": 0.75}, 0.75),
("HOTA", {"hota": 0.62}, 0.62),
("IDF1", {"idf1": 0.71}, 0.71),
],
)
def test_extract_metric(metric: str, make_kwargs: dict, expected: float) -> None:
result = _make_benchmark_result(**make_kwargs)
assert _extract_metric(result, metric) == pytest.approx(expected)


@pytest.mark.parametrize("metric", ["NONEXISTENT", "HOTA"])
def test_extract_metric_raises(metric: str) -> None:
# Default result has no HOTA/Identity families — both should raise
result = _make_benchmark_result()
with pytest.raises(ValueError, match=r"not found in BenchmarkResult\.aggregate"):
_extract_metric(result, metric)


class TestTunerInit:
def test_raises_for_unknown_tracker(self, tmp_path: Path) -> None:
gt_dir, det_dir = _setup_dirs(tmp_path)
with pytest.raises(ValueError, match=r"not registered"):
Tuner("nonexistent_tracker", gt_dir, det_dir)

def test_raises_for_tracker_without_search_space(self, tmp_path: Path) -> None:
from trackers.core.base import BaseTracker

class _NoSearchSpaceTracker(BaseTracker):
tracker_id = "_test_no_ss"
search_space = None

def update(self, detections): # type: ignore[override]
return detections

def reset(self) -> None:
pass

gt_dir, det_dir = _setup_dirs(tmp_path)
try:
with pytest.raises(ValueError, match=r"does not define a search_space"):
Tuner("_test_no_ss", gt_dir, det_dir)
finally:
BaseTracker._registry.pop("_test_no_ss", None)

def test_raises_when_no_sequences_found(self, tmp_path: Path) -> None:
gt_dir = tmp_path / "gt"
gt_dir.mkdir()
det_dir = tmp_path / "det"
det_dir.mkdir()
with pytest.raises(ValueError, match=r"No sequences found"):
Tuner("bytetrack", gt_dir, det_dir)

def test_valid_init_stores_attributes(self, tmp_path: Path) -> None:
gt_dir, det_dir = _setup_dirs(tmp_path)
tuner = Tuner("bytetrack", gt_dir, det_dir, n_trials=10)
assert tuner._tracker_id == "bytetrack"
assert tuner._objective_metric == "MOTA"
assert tuner._metrics == ["CLEAR"]
assert tuner._n_trials == 10
assert tuner._sequences == ["seq1"]

def test_seqmap_filters_sequences(self, tmp_path: Path) -> None:
gt_dir, det_dir = _setup_dirs(tmp_path)
(det_dir / "seq2.txt").write_text(_MOT_LINE)
seqmap = tmp_path / "seqmap.txt"
seqmap.write_text("seq1\n")
tuner = Tuner("bytetrack", gt_dir, det_dir, seqmap=seqmap)
assert tuner._sequences == ["seq1"]


class TestTunerRun:
def test_run_returns_dict_with_search_space_keys(self, tmp_path: Path) -> None:
from trackers import ByteTrackTracker

assert ByteTrackTracker.search_space is not None
expected_keys = set(ByteTrackTracker.search_space.keys())
gt_dir, det_dir = _setup_dirs(tmp_path)

with (
patch(
"trackers.tune.tuner.evaluate_mot_sequences",
return_value=_make_benchmark_result(),
),
patch("trackers.tune.tuner._run_tracker_on_detections"),
):
tuner = Tuner("bytetrack", gt_dir, det_dir, n_trials=2)
best = tuner.run()

assert isinstance(best, dict)
assert set(best.keys()) == expected_keys

def test_run_calls_tracker_reset_per_sequence(self, tmp_path: Path) -> None:
"""reset() must be called once per sequence per trial."""
from trackers import SORTTracker

reset_calls: list[int] = []
gt_dir, det_dir = _setup_dirs(tmp_path)
(det_dir / "seq2.txt").write_text(_MOT_LINE) # two sequences → two resets

original_reset = SORTTracker.reset

def _counting_reset(self_tracker: SORTTracker) -> None:
reset_calls.append(1)
original_reset(self_tracker)

with (
patch(
"trackers.tune.tuner.evaluate_mot_sequences",
return_value=_make_benchmark_result(),
),
patch("trackers.tune.tuner._run_tracker_on_detections"),
patch.object(SORTTracker, "reset", _counting_reset),
):
tuner = Tuner("sort", gt_dir, det_dir, n_trials=1)
tuner.run()

assert len(reset_calls) == 2 # 1 trial * 2 sequences
22 changes: 22 additions & 0 deletions trackers/core/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,18 +222,40 @@ class BaseTracker(ABC):
Subclasses that define `tracker_id` are automatically registered and
become discoverable. Parameter metadata is extracted from __init__ for
CLI integration.
Attributes:
tracker_id: Unique identifier for the tracker. Subclasses must define
this to be registered.
search_space: Hyperparameter search space for tuning. Each key must
match an `__init__` parameter. Values are dicts with `type`
(`"randint"` or `"uniform"`) and `range` (`[low, high]`).
"""

_registry: ClassVar[dict[str, TrackerInfo]] = {}
tracker_id: ClassVar[str | None] = None
search_space: ClassVar[dict[str, dict] | None] = None

def __init_subclass__(cls, **kwargs: Any) -> None:
"""Register subclass in the tracker registry if it defines tracker_id.

Extracts parameter metadata from __init__ at class definition time.
Validates search_space (if present) against __init__ parameters.
"""
super().__init_subclass__(**kwargs)

# Validate search_space keys match __init__ parameters (search_space optional)
search_space = getattr(cls, "search_space", None)
if search_space is not None and len(search_space) > 0:
init_params = {
n for n in inspect.signature(cls.__init__).parameters if n != "self"
}
for key in search_space:
if key not in init_params:
raise ValueError(
f"{cls.__name__}: search_space key {key!r} is not a "
f"parameter of __init__. "
f"Valid parameters: {sorted(init_params)}"
)

tracker_id = getattr(cls, "tracker_id", None)
if tracker_id is not None:
BaseTracker._registry[tracker_id] = TrackerInfo(
Expand Down
10 changes: 10 additions & 0 deletions trackers/core/bytetrack/tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------

from typing import ClassVar

import numpy as np
import supervision as sv
from scipy.optimize import linear_sum_assignment
Expand Down Expand Up @@ -57,6 +59,14 @@ class ByteTrackTracker(BaseTracker):

tracker_id = "bytetrack"

search_space: ClassVar[dict[str, dict]] = {
"lost_track_buffer": {"type": "randint", "range": [10, 91]},
"track_activation_threshold": {"type": "uniform", "range": [0.1, 0.9]},
"minimum_iou_threshold": {"type": "uniform", "range": [0.05, 0.7]},
"high_conf_det_threshold": {"type": "uniform", "range": [0.3, 0.8]},
"minimum_consecutive_frames": {"type": "randint", "range": [1, 4]},
}

def __init__(
self,
lost_track_buffer: int = 30,
Expand Down
Loading
Loading