Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 81 additions & 1 deletion test/core/test_registration.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------

from typing import Any
import inspect
from typing import Any, ClassVar

import pytest

Expand Down Expand Up @@ -268,6 +269,85 @@ def test_tracker_params_have_descriptions(self, tracker_id: str) -> None:
assert has_descriptions


class TestSearchSpaceValidation:
"""Tests for search_space ClassVar validation in __init_subclass__."""

def test_search_space_keys_match_init_params(self) -> None:
"""ByteTrack and SORT search_space keys are valid __init__ parameters."""
from trackers import ByteTrackTracker, SORTTracker

for tracker_cls in (ByteTrackTracker, SORTTracker):
init_params = set(inspect.signature(tracker_cls.__init__).parameters) - {
"self"
}
for key in tracker_cls.search_space:
assert key in init_params, (
f"{tracker_cls.__name__}.search_space has invalid key: {key}"
)

def test_search_space_invalid_key_raises_value_error(self) -> None:
"""A tracker with search_space key not in __init__ raises ValueError."""

with pytest.raises(ValueError, match=r"search_space key .* is not a parameter"):

class BadTracker(BaseTracker):
tracker_id = "bad"
search_space: ClassVar[dict[str, dict]] = {
"nonexistent_param": {
"type": "uniform",
"range": [0, 1],
}
}

def __init__(self) -> None:
pass

def update(self, detections: Any) -> Any:
return detections

def reset(self) -> None:
pass

def test_tracker_without_search_space_works(self) -> None:
"""Trackers without search_space are still valid."""

class MinimalTracker(BaseTracker):
tracker_id = "minimal"

def __init__(self) -> None:
pass

def update(self, detections: Any) -> Any:
return detections

def reset(self) -> None:
pass

assert (
not hasattr(MinimalTracker, "search_space")
or getattr(MinimalTracker, "search_space", None) is None
)
assert "minimal" in BaseTracker._registered_trackers()

def test_tracker_with_empty_search_space_works(self) -> None:
"""Trackers with empty search_space skip validation."""

class EmptySpaceTracker(BaseTracker):
tracker_id = "empty_space"
search_space: ClassVar[dict[str, dict]] = {}

def __init__(self, x: int = 1) -> None:
pass

def update(self, detections: Any) -> Any:
return detections

def reset(self) -> None:
pass

assert "empty_space" in BaseTracker._registered_trackers()


class TestTrackerInstantiation:
@pytest.mark.parametrize("tracker_id", ["bytetrack", "sort"])
def test_instantiate_with_defaults(self, tracker_id: str) -> None:
Expand Down
22 changes: 22 additions & 0 deletions trackers/core/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,18 +222,40 @@ class BaseTracker(ABC):
Subclasses that define `tracker_id` are automatically registered and
become discoverable. Parameter metadata is extracted from __init__ for
CLI integration.
Attributes:
tracker_id: Unique identifier for the tracker. Subclasses must define
this to be registered.
search_space: Hyperparameter search space for tuning. Each key must
match an `__init__` parameter. Values are dicts with `type`
(`"randint"` or `"uniform"`) and `range` (`[low, high]`).
"""

_registry: ClassVar[dict[str, TrackerInfo]] = {}
tracker_id: ClassVar[str | None] = None
search_space: ClassVar[dict[str, dict] | None] = None

def __init_subclass__(cls, **kwargs: Any) -> None:
"""Register subclass in the tracker registry if it defines tracker_id.

Extracts parameter metadata from __init__ at class definition time.
Validates search_space (if present) against __init__ parameters.
"""
super().__init_subclass__(**kwargs)

# Validate search_space keys match __init__ parameters (search_space optional)
search_space = getattr(cls, "search_space", None)
if search_space is not None and len(search_space) > 0:
init_params = {
n for n in inspect.signature(cls.__init__).parameters if n != "self"
}
for key in search_space:
if key not in init_params:
raise ValueError(
f"{cls.__name__}: search_space key {key!r} is not a "
f"parameter of __init__. "
f"Valid parameters: {sorted(init_params)}"
)

tracker_id = getattr(cls, "tracker_id", None)
if tracker_id is not None:
BaseTracker._registry[tracker_id] = TrackerInfo(
Expand Down
10 changes: 10 additions & 0 deletions trackers/core/bytetrack/tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------

from typing import ClassVar

import numpy as np
import supervision as sv
from scipy.optimize import linear_sum_assignment
Expand Down Expand Up @@ -57,6 +59,14 @@ class ByteTrackTracker(BaseTracker):

tracker_id = "bytetrack"

search_space: ClassVar[dict[str, dict]] = {
"lost_track_buffer": {"type": "randint", "range": [10, 91]},
"track_activation_threshold": {"type": "uniform", "range": [0.1, 0.9]},
"minimum_iou_threshold": {"type": "uniform", "range": [0.05, 0.7]},
"high_conf_det_threshold": {"type": "uniform", "range": [0.3, 0.8]},
"minimum_consecutive_frames": {"type": "randint", "range": [1, 4]},
}

def __init__(
self,
lost_track_buffer: int = 30,
Expand Down
9 changes: 9 additions & 0 deletions trackers/core/sort/tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------

from typing import ClassVar

import numpy as np
import supervision as sv
from scipy.optimize import linear_sum_assignment
Expand Down Expand Up @@ -55,6 +57,13 @@ class SORTTracker(BaseTracker):

tracker_id = "sort"

search_space: ClassVar[dict[str, dict]] = {
"lost_track_buffer": {"type": "randint", "range": [10, 91]},
"track_activation_threshold": {"type": "uniform", "range": [0.1, 0.9]},
"minimum_consecutive_frames": {"type": "randint", "range": [1, 4]},
"minimum_iou_threshold": {"type": "uniform", "range": [0.05, 0.7]},
}

def __init__(
self,
lost_track_buffer: int = 30,
Expand Down