diff --git a/src/lerobot/datasets/dataset_reader.py b/src/lerobot/datasets/dataset_reader.py index 3720a50847..ec40bcd81e 100644 --- a/src/lerobot/datasets/dataset_reader.py +++ b/src/lerobot/datasets/dataset_reader.py @@ -72,7 +72,8 @@ def __init__( self.episodes = episodes self._tolerance_s = tolerance_s self._video_backend = video_backend - self._image_transforms = image_transforms + self._image_transforms = None + self.set_image_transforms(image_transforms) self.hf_dataset: datasets.Dataset | None = None self._absolute_to_relative_idx: dict[int, int] | None = None @@ -83,6 +84,16 @@ def __init__( check_delta_timestamps(delta_timestamps, meta.fps, tolerance_s) self.delta_indices = get_delta_indices(delta_timestamps, meta.fps) + def set_image_transforms(self, image_transforms: Callable | None) -> None: + """Replace the transform applied to visual observations.""" + if image_transforms is not None and not callable(image_transforms): + raise TypeError("image_transforms must be callable or None.") + self._image_transforms = image_transforms + + def clear_image_transforms(self) -> None: + """Remove the transform applied to visual observations.""" + self.set_image_transforms(None) + def try_load(self) -> bool: """Attempt to load from local cache. Returns True if data is sufficient.""" try: diff --git a/src/lerobot/datasets/lerobot_dataset.py b/src/lerobot/datasets/lerobot_dataset.py index 1725046f23..906a86c4ac 100644 --- a/src/lerobot/datasets/lerobot_dataset.py +++ b/src/lerobot/datasets/lerobot_dataset.py @@ -194,8 +194,9 @@ def __init__( super().__init__() self.repo_id = repo_id self._requested_root = Path(root) if root else None - self.reader = None - self.set_image_transforms(image_transforms) + if image_transforms is not None and not callable(image_transforms): + raise TypeError("image_transforms must be callable or None.") + self.image_transforms = image_transforms self.delta_timestamps = delta_timestamps self.episodes = episodes self.tolerance_s = tolerance_s @@ -484,7 +485,7 @@ def set_image_transforms(self, image_transforms: Callable | None) -> None: raise TypeError("image_transforms must be callable or None.") self.image_transforms = image_transforms if self.reader is not None: - self.reader._image_transforms = image_transforms + self.reader.set_image_transforms(image_transforms) def clear_image_transforms(self) -> None: """Remove the transform applied to visual observations.""" diff --git a/tests/datasets/test_dataset_reader.py b/tests/datasets/test_dataset_reader.py index 4c8a8b23f8..48d2af7b27 100644 --- a/tests/datasets/test_dataset_reader.py +++ b/tests/datasets/test_dataset_reader.py @@ -143,6 +143,32 @@ def sentinel_transform(img): assert transform_called["count"] >= 1 +def test_set_image_transforms_updates_reader_behavior(tmp_path, lerobot_dataset_factory): + """Reader setters update and clear visual transforms without replacing the reader.""" + transform_called = {"count": 0} + + def sentinel_transform(img): + transform_called["count"] += 1 + return img + + dataset = lerobot_dataset_factory( + root=tmp_path / "ds", + total_episodes=1, + total_frames=5, + use_videos=False, + ) + reader = dataset.reader + + reader.set_image_transforms(sentinel_transform) + reader.get_item(0) + assert transform_called["count"] >= 1 + + calls_after_set = transform_called["count"] + reader.clear_image_transforms() + reader.get_item(0) + assert transform_called["count"] == calls_after_set + + # ── File paths ───────────────────────────────────────────────────────