From 3e171e0aebce347a1caa04a2166b86463cb18998 Mon Sep 17 00:00:00 2001 From: Omkar Kabde Date: Tue, 24 Feb 2026 16:58:02 +0530 Subject: [PATCH] Add search spaces to trackers (#264) * add search spaces in trackers * add validation * add tests * update search spaces and fixes * remove unnecessary tests * fix errors --- test/core/test_registration.py | 82 +++++++++++++++++++++++++++++- trackers/core/base.py | 22 ++++++++ trackers/core/bytetrack/tracker.py | 10 +++- trackers/core/sort/tracker.py | 9 ++++ 4 files changed, 121 insertions(+), 2 deletions(-) diff --git a/test/core/test_registration.py b/test/core/test_registration.py index 7986649d..cd047d9b 100644 --- a/test/core/test_registration.py +++ b/test/core/test_registration.py @@ -4,7 +4,8 @@ # Licensed under the Apache License, Version 2.0 [see LICENSE for details] # ------------------------------------------------------------------------ -from typing import Any +import inspect +from typing import Any, ClassVar import pytest @@ -268,6 +269,85 @@ def test_tracker_params_have_descriptions(self, tracker_id: str) -> None: assert has_descriptions +class TestSearchSpaceValidation: + """Tests for search_space ClassVar validation in __init_subclass__.""" + + def test_search_space_keys_match_init_params(self) -> None: + """ByteTrack and SORT search_space keys are valid __init__ parameters.""" + from trackers import ByteTrackTracker, SORTTracker + + for tracker_cls in (ByteTrackTracker, SORTTracker): + init_params = set(inspect.signature(tracker_cls.__init__).parameters) - { + "self" + } + for key in tracker_cls.search_space: + assert key in init_params, ( + f"{tracker_cls.__name__}.search_space has invalid key: {key}" + ) + + def test_search_space_invalid_key_raises_value_error(self) -> None: + """A tracker with search_space key not in __init__ raises ValueError.""" + + with pytest.raises(ValueError, match=r"search_space key .* is not a parameter"): + + class BadTracker(BaseTracker): + tracker_id = "bad" + search_space: ClassVar[dict[str, dict]] = { + "nonexistent_param": { + "type": "uniform", + "range": [0, 1], + } + } + + def __init__(self) -> None: + pass + + def update(self, detections: Any) -> Any: + return detections + + def reset(self) -> None: + pass + + def test_tracker_without_search_space_works(self) -> None: + """Trackers without search_space are still valid.""" + + class MinimalTracker(BaseTracker): + tracker_id = "minimal" + + def __init__(self) -> None: + pass + + def update(self, detections: Any) -> Any: + return detections + + def reset(self) -> None: + pass + + assert ( + not hasattr(MinimalTracker, "search_space") + or getattr(MinimalTracker, "search_space", None) is None + ) + assert "minimal" in BaseTracker._registered_trackers() + + def test_tracker_with_empty_search_space_works(self) -> None: + """Trackers with empty search_space skip validation.""" + + class EmptySpaceTracker(BaseTracker): + tracker_id = "empty_space" + search_space: ClassVar[dict[str, dict]] = {} + + def __init__(self, x: int = 1) -> None: + pass + + def update(self, detections: Any) -> Any: + return detections + + def reset(self) -> None: + pass + + assert "empty_space" in BaseTracker._registered_trackers() + + class TestTrackerInstantiation: @pytest.mark.parametrize("tracker_id", ["bytetrack", "sort"]) def test_instantiate_with_defaults(self, tracker_id: str) -> None: diff --git a/trackers/core/base.py b/trackers/core/base.py index a56d2c5a..91c13923 100644 --- a/trackers/core/base.py +++ b/trackers/core/base.py @@ -222,18 +222,40 @@ class BaseTracker(ABC): Subclasses that define `tracker_id` are automatically registered and become discoverable. Parameter metadata is extracted from __init__ for CLI integration. + Attributes: + tracker_id: Unique identifier for the tracker. Subclasses must define + this to be registered. + search_space: Hyperparameter search space for tuning. Each key must + match an `__init__` parameter. Values are dicts with `type` + (`"randint"` or `"uniform"`) and `range` (`[low, high]`). """ _registry: ClassVar[dict[str, TrackerInfo]] = {} tracker_id: ClassVar[str | None] = None + search_space: ClassVar[dict[str, dict] | None] = None def __init_subclass__(cls, **kwargs: Any) -> None: """Register subclass in the tracker registry if it defines tracker_id. Extracts parameter metadata from __init__ at class definition time. + Validates search_space (if present) against __init__ parameters. """ super().__init_subclass__(**kwargs) + # Validate search_space keys match __init__ parameters (search_space optional) + search_space = getattr(cls, "search_space", None) + if search_space is not None and len(search_space) > 0: + init_params = { + n for n in inspect.signature(cls.__init__).parameters if n != "self" + } + for key in search_space: + if key not in init_params: + raise ValueError( + f"{cls.__name__}: search_space key {key!r} is not a " + f"parameter of __init__. " + f"Valid parameters: {sorted(init_params)}" + ) + tracker_id = getattr(cls, "tracker_id", None) if tracker_id is not None: BaseTracker._registry[tracker_id] = TrackerInfo( diff --git a/trackers/core/bytetrack/tracker.py b/trackers/core/bytetrack/tracker.py index a35ecf91..8f2aa833 100644 --- a/trackers/core/bytetrack/tracker.py +++ b/trackers/core/bytetrack/tracker.py @@ -5,7 +5,7 @@ # ------------------------------------------------------------------------ from copy import deepcopy -from typing import cast +from typing import ClassVar, cast import numpy as np import supervision as sv @@ -45,6 +45,14 @@ class ByteTrackTracker(BaseTracker): tracker_id = "bytetrack" + search_space: ClassVar[dict[str, dict]] = { + "lost_track_buffer": {"type": "randint", "range": [10, 91]}, + "track_activation_threshold": {"type": "uniform", "range": [0.1, 0.9]}, + "minimum_iou_threshold": {"type": "uniform", "range": [0.05, 0.7]}, + "high_conf_det_threshold": {"type": "uniform", "range": [0.3, 0.8]}, + "minimum_consecutive_frames": {"type": "randint", "range": [1, 4]}, + } + def __init__( self, lost_track_buffer: int = 30, diff --git a/trackers/core/sort/tracker.py b/trackers/core/sort/tracker.py index 6e956a92..42e83fe5 100644 --- a/trackers/core/sort/tracker.py +++ b/trackers/core/sort/tracker.py @@ -4,6 +4,8 @@ # Licensed under the Apache License, Version 2.0 [see LICENSE for details] # ------------------------------------------------------------------------ +from typing import ClassVar + import numpy as np import supervision as sv from scipy.optimize import linear_sum_assignment @@ -41,6 +43,13 @@ class SORTTracker(BaseTracker): tracker_id = "sort" + search_space: ClassVar[dict[str, dict]] = { + "lost_track_buffer": {"type": "randint", "range": [10, 91]}, + "track_activation_threshold": {"type": "uniform", "range": [0.1, 0.9]}, + "minimum_consecutive_frames": {"type": "randint", "range": [1, 4]}, + "minimum_iou_threshold": {"type": "uniform", "range": [0.05, 0.7]}, + } + def __init__( self, lost_track_buffer: int = 30,