Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 12 additions & 1 deletion src/lerobot/datasets/dataset_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,8 @@ def __init__(
self.episodes = episodes
self._tolerance_s = tolerance_s
self._video_backend = video_backend
self._image_transforms = image_transforms
self._image_transforms = None
self.set_image_transforms(image_transforms)

self.hf_dataset: datasets.Dataset | None = None
self._absolute_to_relative_idx: dict[int, int] | None = None
Expand All @@ -83,6 +84,16 @@ def __init__(
check_delta_timestamps(delta_timestamps, meta.fps, tolerance_s)
self.delta_indices = get_delta_indices(delta_timestamps, meta.fps)

def set_image_transforms(self, image_transforms: Callable | None) -> None:
"""Replace the transform applied to visual observations."""
if image_transforms is not None and not callable(image_transforms):
raise TypeError("image_transforms must be callable or None.")
self._image_transforms = image_transforms

def clear_image_transforms(self) -> None:
"""Remove the transform applied to visual observations."""
self.set_image_transforms(None)

def try_load(self) -> bool:
"""Attempt to load from local cache. Returns True if data is sufficient."""
try:
Expand Down
7 changes: 4 additions & 3 deletions src/lerobot/datasets/lerobot_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,8 +194,9 @@ def __init__(
super().__init__()
self.repo_id = repo_id
self._requested_root = Path(root) if root else None
self.reader = None
self.set_image_transforms(image_transforms)
if image_transforms is not None and not callable(image_transforms):
raise TypeError("image_transforms must be callable or None.")
self.image_transforms = image_transforms
self.delta_timestamps = delta_timestamps
self.episodes = episodes
self.tolerance_s = tolerance_s
Expand Down Expand Up @@ -484,7 +485,7 @@ def set_image_transforms(self, image_transforms: Callable | None) -> None:
raise TypeError("image_transforms must be callable or None.")
self.image_transforms = image_transforms
if self.reader is not None:
self.reader._image_transforms = image_transforms
self.reader.set_image_transforms(image_transforms)

def clear_image_transforms(self) -> None:
"""Remove the transform applied to visual observations."""
Expand Down
26 changes: 26 additions & 0 deletions tests/datasets/test_dataset_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,32 @@ def sentinel_transform(img):
assert transform_called["count"] >= 1


def test_set_image_transforms_updates_reader_behavior(tmp_path, lerobot_dataset_factory):
"""Reader setters update and clear visual transforms without replacing the reader."""
transform_called = {"count": 0}

def sentinel_transform(img):
transform_called["count"] += 1
return img

dataset = lerobot_dataset_factory(
root=tmp_path / "ds",
total_episodes=1,
total_frames=5,
use_videos=False,
)
reader = dataset.reader

reader.set_image_transforms(sentinel_transform)
reader.get_item(0)
assert transform_called["count"] >= 1

calls_after_set = transform_called["count"]
reader.clear_image_transforms()
reader.get_item(0)
assert transform_called["count"] == calls_after_set


# ── File paths ───────────────────────────────────────────────────────


Expand Down
Loading