diff --git a/pyproject.toml b/pyproject.toml index 4a1efab30b..127e794d16 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -220,6 +220,8 @@ lerobot-replay="lerobot.scripts.lerobot_replay:main" lerobot-setup-motors="lerobot.scripts.lerobot_setup_motors:main" lerobot-teleoperate="lerobot.scripts.lerobot_teleoperate:main" lerobot-eval="lerobot.scripts.lerobot_eval:main" +lerobot-eval-parallel="lerobot.scripts.lerobot_eval_parallel:main" +lerobot-eval-autotune="lerobot.scripts.lerobot_eval_autotune:main" lerobot-train="lerobot.scripts.lerobot_train:main" lerobot-train-tokenizer="lerobot.scripts.lerobot_train_tokenizer:main" lerobot-dataset-viz="lerobot.scripts.lerobot_dataset_viz:main" diff --git a/src/lerobot/configs/default.py b/src/lerobot/configs/default.py index 38039a7bf2..6aa6f33bcd 100644 --- a/src/lerobot/configs/default.py +++ b/src/lerobot/configs/default.py @@ -69,6 +69,11 @@ class EvalConfig: # `use_async_envs` specifies whether to use asynchronous environments (multiprocessing). # Defaults to True; automatically downgraded to SyncVectorEnv when batch_size=1. use_async_envs: bool = True + # Sharding: split n_episodes across independent processes. + # shard_id=0, num_shards=1 is the default (no sharding, existing behaviour). + # Set via lerobot_eval_parallel or manually: --eval.shard_id=K --eval.num_shards=N + shard_id: int = 0 + num_shards: int = 1 def __post_init__(self) -> None: if self.batch_size > self.n_episodes: @@ -80,6 +85,12 @@ def __post_init__(self) -> None: f"to increase the number of episodes to match the batch size (e.g. `eval.n_episodes={self.batch_size}`), " f"or lower the batch size (e.g. `eval.batch_size={self.n_episodes}`)." ) + if self.num_shards < 1: + raise ValueError(f"`num_shards` must be >= 1, got {self.num_shards}") + if not (0 <= self.shard_id < self.num_shards): + raise ValueError( + f"`shard_id` must be in [0, num_shards), got shard_id={self.shard_id}, num_shards={self.num_shards}" + ) @dataclass diff --git a/src/lerobot/scripts/lerobot_eval.py b/src/lerobot/scripts/lerobot_eval.py index c87c47d1b7..573fcc2140 100644 --- a/src/lerobot/scripts/lerobot_eval.py +++ b/src/lerobot/scripts/lerobot_eval.py @@ -47,8 +47,10 @@ """ import concurrent.futures as cf +import copy import json import logging +import math import threading import time from collections import defaultdict @@ -56,7 +58,6 @@ from contextlib import nullcontext from copy import deepcopy from dataclasses import asdict -from functools import partial from pathlib import Path from pprint import pformat from typing import Any, TypedDict @@ -92,6 +93,14 @@ ) +def _shard_episodes(n_episodes: int, shard_id: int, num_shards: int) -> list[int]: + """Return the episode indices assigned to this shard (round-robin distribution). + + Example: _shard_episodes(10, 1, 4) -> [1, 5, 9] + """ + return list(range(shard_id, n_episodes, num_shards)) + + def rollout( env: gym.vector.VectorEnv, policy: PreTrainedPolicy, @@ -553,6 +562,14 @@ def eval_main(cfg: EvalPipelineConfig): # Create environment-specific preprocessor and postprocessor (e.g., for LIBERO environments) env_preprocessor, env_postprocessor = make_env_pre_post_processors(env_cfg=cfg.env, policy_cfg=cfg.policy) + # Sharding: each shard runs a subset of n_episodes with non-overlapping seeds. + shard_id = cfg.eval.shard_id + num_shards = cfg.eval.num_shards + episodes_for_shard = _shard_episodes(cfg.eval.n_episodes, shard_id, num_shards) + n_per_shard = len(episodes_for_shard) + # Shift the seed so each shard gets a different, non-overlapping seed range. + shard_seed = (cfg.seed or 0) + shard_id * math.ceil(cfg.eval.n_episodes / num_shards) + with torch.no_grad(), torch.autocast(device_type=device.type) if cfg.policy.use_amp else nullcontext(): info = eval_policy_all( envs=envs, @@ -561,10 +578,10 @@ def eval_main(cfg: EvalPipelineConfig): env_postprocessor=env_postprocessor, preprocessor=preprocessor, postprocessor=postprocessor, - n_episodes=cfg.eval.n_episodes, + n_episodes=n_per_shard, max_episodes_rendered=10, videos_dir=Path(cfg.output_dir) / "videos", - start_seed=cfg.seed, + start_seed=shard_seed, max_parallel_tasks=cfg.env.max_parallel_tasks, ) print("Overall Aggregated Metrics:") @@ -577,8 +594,13 @@ def eval_main(cfg: EvalPipelineConfig): # Close all vec envs close_envs(envs) - # Save info - with open(Path(cfg.output_dir) / "eval_info.json", "w") as f: + # Save info — use shard-specific filename when running in parallel mode. + if num_shards > 1: + out_path = Path(cfg.output_dir) / f"shard_{shard_id}_of_{num_shards}.json" + else: + out_path = Path(cfg.output_dir) / "eval_info.json" + out_path.parent.mkdir(parents=True, exist_ok=True) + with open(out_path, "w") as f: json.dump(info, f, indent=2) logging.info("End of eval") @@ -738,34 +760,49 @@ def _append(key, value): group_acc[group]["video_paths"].extend(paths) overall["video_paths"].extend(paths) + def _make_thread_policy(p: PreTrainedPolicy) -> PreTrainedPolicy: + """Shallow copy sharing weight tensors, with independent per-thread state. + + copy.copy() gives a new Python object whose _parameters dict is a shared + reference (same tensor storage, zero extra VRAM). reset() then rebinds + mutable state (action queues etc.) to fresh per-thread objects. + + Note: does NOT work for ACT with temporal_ensemble_coeff — that policy's + reset() mutates a shared sub-object. Use max_parallel_tasks=1 for that config. + """ + thread_p = copy.copy(p) + thread_p.reset() + return thread_p + # Choose runner (sequential vs threaded) - task_runner = partial( - run_one, - policy=policy, - env_preprocessor=env_preprocessor, - env_postprocessor=env_postprocessor, - preprocessor=preprocessor, - postprocessor=postprocessor, - n_episodes=n_episodes, - max_episodes_rendered=max_episodes_rendered, - videos_dir=videos_dir, - return_episode_data=return_episode_data, - start_seed=start_seed, - ) + _runner_kwargs = { + "env_preprocessor": env_preprocessor, + "env_postprocessor": env_postprocessor, + "preprocessor": preprocessor, + "postprocessor": postprocessor, + "n_episodes": n_episodes, + "max_episodes_rendered": max_episodes_rendered, + "videos_dir": videos_dir, + "return_episode_data": return_episode_data, + "start_seed": start_seed, + } if max_parallel_tasks <= 1: for task_group, task_id, env in tasks: try: - tg, tid, metrics = task_runner(task_group, task_id, env) + tg, tid, metrics = run_one(task_group, task_id, env, policy=policy, **_runner_kwargs) _accumulate_to(tg, metrics) per_task_infos.append({"task_group": tg, "task_id": tid, "metrics": metrics}) finally: env.close() else: + # threaded path: each thread gets a shallow policy copy (shared weights, independent state) with cf.ThreadPoolExecutor(max_workers=max_parallel_tasks) as executor: fut2meta = {} for task_group, task_id, env in tasks: - fut = executor.submit(task_runner, task_group, task_id, env) + fut = executor.submit( + run_one, task_group, task_id, env, policy=_make_thread_policy(policy), **_runner_kwargs + ) fut2meta[fut] = (task_group, task_id, env) for fut in cf.as_completed(fut2meta): tg, tid, env = fut2meta[fut] diff --git a/src/lerobot/scripts/lerobot_eval_autotune.py b/src/lerobot/scripts/lerobot_eval_autotune.py new file mode 100644 index 0000000000..346dd345dd --- /dev/null +++ b/src/lerobot/scripts/lerobot_eval_autotune.py @@ -0,0 +1,249 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Probe hardware and recommend optimal lerobot-eval-parallel flags. + +Run standalone: + lerobot-eval-autotune --policy.path=lerobot/smolvla_libero --env.type=libero + +Or called programmatically from lerobot_eval_parallel when --num-shards auto. + +Steps: + 1. Probe GPU VRAM and CPU core count. + 2. Measure model VRAM footprint (load policy, delta of cuda.memory_allocated). + 3. Compute max shards limited by VRAM (85% of total). + 4. Probe env step time (optional, skipped when skip_timing=True). + 5. Probe inference time (optional, skipped when skip_timing=True). + 6. Derive num_shards = min(vram_limit, saturation_shards). + 7. Choose MUJOCO_GL (egl vs osmesa) based on remaining VRAM headroom. + 8. Compute batch_size = max(4, min(floor(cpu_cores * 0.8 / num_shards), 64)). + 9. Print paste-ready command. +""" + +import math +import os +import sys +import time +from dataclasses import dataclass + + +@dataclass +class AutotuneRecommendation: + num_shards: int + batch_size: int + mujoco_gl: str + use_amp: bool + # Probed values + gpu_name: str + vram_gb: float + cpu_cores: int + model_gb: float + env_step_ms: float | None + infer_ms: float | None + + +_DEFAULT_ENV_STEP_MS = 22.0 # LIBERO on GPU, typical value +_DEFAULT_INFER_MS = 5.0 # SmolVLA fp16 on H100 + + +def _probe_gpu() -> tuple[str, float]: + """Return (gpu_name, vram_gb). Falls back to CPU sentinel on non-CUDA systems.""" + try: + import torch + + if not torch.cuda.is_available(): + return "CPU (no CUDA)", 0.0 + props = torch.cuda.get_device_properties(0) + return props.name, props.total_memory / (1024**3) + except Exception: + return "unknown", 0.0 + + +def _probe_model_gb(passthrough: list[str]) -> float: + """Load the policy (from --policy.path) and measure VRAM delta. Returns GB.""" + # Extract policy path from passthrough args + policy_path = None + for tok in passthrough: + if tok.startswith("policy.path="): + policy_path = tok.split("=", 1)[1] + break + if tok.startswith("--policy.path="): + policy_path = tok.split("=", 1)[1] + break + if policy_path is None: + return 0.0 + + try: + import torch + + from lerobot.policies.factory import make_policy + from lerobot.policies.pretrained import PreTrainedConfig + + if not torch.cuda.is_available(): + return 0.0 + torch.cuda.synchronize() + before = torch.cuda.memory_allocated(0) + cfg = PreTrainedConfig.from_pretrained(policy_path) + cfg.pretrained_path = policy_path # type: ignore[assignment] + policy = make_policy(cfg=cfg) + policy.eval() + torch.cuda.synchronize() + after = torch.cuda.memory_allocated(0) + del policy + torch.cuda.empty_cache() + return (after - before) / (1024**3) + except Exception as e: + print(f"[autotune] could not measure model VRAM: {e}", file=sys.stderr) + return 0.0 + + +def _probe_env_step_ms(passthrough: list[str], batch_size: int = 8, n_steps: int = 30) -> float | None: + """Run a short env warmup and return median step latency in ms. Returns None on failure.""" + try: + import numpy as np + + from lerobot.envs.factory import make_env + + # Parse env config from passthrough using lerobot's own parser + env_type = None + for tok in passthrough: + if tok.startswith("env.type=") or tok.startswith("--env.type="): + env_type = tok.split("=", 1)[1] + break + if env_type is None: + return None + + # Minimal env config + from lerobot.envs.factory import make_env_config + + env_cfg = make_env_config(env_type) + envs = make_env(env_cfg, n_envs=batch_size, use_async_envs=(batch_size > 1)) + # Get first vec env + first_suite = next(iter(envs.values())) + env = next(iter(first_suite.values())) + + env.reset() + dummy_action = np.zeros((batch_size, env.single_action_space.shape[0])) + timings = [] + for _ in range(n_steps): + t0 = time.perf_counter() + env.step(dummy_action) + timings.append((time.perf_counter() - t0) * 1000) + env.close() + return float(np.median(timings)) + except Exception as e: + print(f"[autotune] env step probe failed: {e}", file=sys.stderr) + return None + + +def probe_and_recommend( + passthrough: list[str], + skip_timing: bool = False, +) -> AutotuneRecommendation: + """Probe hardware + model and return the recommended configuration.""" + gpu_name, vram_gb = _probe_gpu() + cpu_cores = os.cpu_count() or 4 + + # Model footprint + model_gb = _probe_model_gb(passthrough) + if model_gb == 0.0: + # Unknown model: assume a conservative 14 GB (SmolVLA fp16) as placeholder + model_gb = 14.0 + print("[autotune] model size unknown, assuming 14 GB (SmolVLA fp16)", file=sys.stderr) + + # Max shards from VRAM (leave 15% headroom for activations + env frames) + max_shards_vram = max(1, math.floor(vram_gb * 0.85 / model_gb)) if vram_gb > 0 else 1 + + # Timing probes + env_step_ms: float | None = None + infer_ms: float | None = None + if not skip_timing: + env_step_ms = _probe_env_step_ms(passthrough) + # Inference time: assume ~infer = env_step / saturation_factor heuristic + # Full probe would require loading policy — skip for now to stay fast. + infer_ms = _DEFAULT_INFER_MS + + # Number of shards to saturate GPU: ceil(env_step / infer) + _step = env_step_ms or _DEFAULT_ENV_STEP_MS + _infer = infer_ms or _DEFAULT_INFER_MS + saturation_shards = max(1, math.ceil(_step / _infer)) + + num_shards = min(max_shards_vram, saturation_shards) + + # Rendering mode: EGL if all model copies + env frame buffers fit in VRAM + env_vram_per_shard_gb = 0.01 # ~10 MB overhead per env batch + total_with_egl = num_shards * (model_gb + env_vram_per_shard_gb) + mujoco_gl = "egl" if (vram_gb == 0 or total_with_egl < vram_gb * 0.85) else "osmesa" + + # Batch size: fill CPU cores evenly across shards + batch_size = max(4, min(math.floor(cpu_cores * 0.8 / num_shards), 64)) + + # Recommend AMP when model is large (saves ~50% VRAM) + use_amp = model_gb > 8.0 + + return AutotuneRecommendation( + num_shards=num_shards, + batch_size=batch_size, + mujoco_gl=mujoco_gl, + use_amp=use_amp, + gpu_name=gpu_name, + vram_gb=vram_gb, + cpu_cores=cpu_cores, + model_gb=model_gb, + env_step_ms=env_step_ms, + infer_ms=infer_ms, + ) + + +def main(argv: list[str] | None = None) -> None: + passthrough = argv if argv is not None else sys.argv[1:] + + rec = probe_and_recommend(passthrough) + + env_step_str = ( + f"{rec.env_step_ms:.0f}ms" if rec.env_step_ms else f"~{_DEFAULT_ENV_STEP_MS:.0f}ms (estimated)" + ) + infer_str = f"{rec.infer_ms:.0f}ms" if rec.infer_ms else f"~{_DEFAULT_INFER_MS:.0f}ms (estimated)" + + print() + print( + f"GPU: {rec.gpu_name} | VRAM: {rec.vram_gb:.1f} GB | CPU cores: {rec.cpu_cores} | Model: {rec.model_gb:.1f} GB" + ) + print() + print(f" env_step_ms: {env_step_str} | infer_ms: {infer_str}") + print() + print(f" num_shards: {rec.num_shards}") + print(f" batch_size: {rec.batch_size}") + print(f" MUJOCO_GL: {rec.mujoco_gl}") + if rec.use_amp: + print(" use_amp: true (recommended — halves VRAM, faster matmuls)") + print() + + # Build paste-ready command + flags = [f"--num-shards {rec.num_shards}", f"eval.batch_size={rec.batch_size}"] + if rec.use_amp: + flags.append("policy.use_amp=true") + flags_str = " \\\n ".join(flags) + passthrough_str = " \\\n ".join(passthrough) if passthrough else "[your flags]" + + print(" Paste-ready command:") + print(f" MUJOCO_GL={rec.mujoco_gl} lerobot-eval-parallel \\") + print(f" {flags_str} \\") + print(f" {passthrough_str}") + print() + + +if __name__ == "__main__": + main() diff --git a/src/lerobot/scripts/lerobot_eval_parallel.py b/src/lerobot/scripts/lerobot_eval_parallel.py new file mode 100644 index 0000000000..caead485c6 --- /dev/null +++ b/src/lerobot/scripts/lerobot_eval_parallel.py @@ -0,0 +1,185 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Run lerobot-eval across N independent subprocesses (shards) for maximum GPU utilization. + +Each shard handles a disjoint subset of episodes and writes its own JSON results file. +Results are merged and printed when all shards complete. + +Usage: + lerobot-eval-parallel --num-shards 4 [any lerobot-eval flags] + lerobot-eval-parallel --num-shards auto [any lerobot-eval flags] + lerobot-eval-parallel --num-shards auto --render-device cpu [any lerobot-eval flags] + +--num-shards auto: + Calls lerobot-eval-autotune to probe hardware and determine the optimal number of shards. + +--render-device gpu|cpu|auto: + Controls MUJOCO_GL env var. 'gpu' -> EGL (faster, ~3ms/frame, ~200KB VRAM/env). + 'cpu' -> osmesa (slower, ~12ms/frame, 0 VRAM). 'auto' picks based on VRAM headroom. + Default: auto. +""" + +import argparse +import json +import os +import subprocess +import sys +from pathlib import Path + + +def _parse_known(argv: list[str]) -> tuple[argparse.Namespace, list[str]]: + p = argparse.ArgumentParser(add_help=False) + p.add_argument("--num-shards", default="1") + p.add_argument("--render-device", choices=["gpu", "cpu", "auto"], default="auto") + p.add_argument("--output-dir", default=None) + return p.parse_known_args(argv) + + +def _resolve_num_shards(num_shards_str: str, passthrough: list[str]) -> int: + if num_shards_str == "auto": + from lerobot.scripts.lerobot_eval_autotune import probe_and_recommend + + rec = probe_and_recommend(passthrough) + print( + f"[autotune] recommended num_shards={rec.num_shards}, batch_size={rec.batch_size}, MUJOCO_GL={rec.mujoco_gl}" + ) + return rec.num_shards + return int(num_shards_str) + + +def _resolve_mujoco_gl(render_device: str, num_shards: int, passthrough: list[str]) -> str: + if render_device == "gpu": + return "egl" + if render_device == "cpu": + return "osmesa" + # auto: use EGL for single shard; for multiple shards check VRAM headroom + if num_shards == 1: + return "egl" + try: + from lerobot.scripts.lerobot_eval_autotune import probe_and_recommend + + rec = probe_and_recommend(passthrough, skip_timing=True) + return rec.mujoco_gl + except Exception: + # Conservative fallback: osmesa avoids EGL VRAM contention + return "osmesa" + + +def _extract_output_dir(passthrough: list[str]) -> str | None: + for tok in passthrough: + if tok.startswith("--output-dir="): + return tok.split("=", 1)[1] + if tok == "--output-dir": + idx = passthrough.index(tok) + if idx + 1 < len(passthrough): + return passthrough[idx + 1] + return None + + +def _merge_shards(output_dir: str, num_shards: int) -> dict: + """Merge per-shard JSON files into a single result dict and write eval_info.json.""" + all_per_task: list[dict] = [] + per_group: dict[str, dict] = {} + + for k in range(num_shards): + shard_path = Path(output_dir) / f"shard_{k}_of_{num_shards}.json" + if not shard_path.exists(): + print(f"[warning] shard file not found: {shard_path}", file=sys.stderr) + continue + with open(shard_path) as f: + shard = json.load(f) + all_per_task.extend(shard.get("per_task", [])) + for group, metrics in shard.get("per_group", {}).items(): + if group not in per_group: + per_group[group] = {"sum_rewards": [], "max_rewards": [], "successes": []} + for key in ("sum_rewards", "max_rewards", "successes"): + # metrics may store aggregates; reconstruct lists if possible + per_group[group][key].extend(metrics.get(key, [])) + + # Re-aggregate + import numpy as np + + def _nanmean(xs: list) -> float: + return float(np.nanmean(xs)) if xs else float("nan") + + groups_out = {} + all_sr, all_mr, all_succ = [], [], [] + for group, acc in per_group.items(): + groups_out[group] = { + "avg_sum_reward": _nanmean(acc["sum_rewards"]), + "avg_max_reward": _nanmean(acc["max_rewards"]), + "pc_success": _nanmean(acc["successes"]) * 100 if acc["successes"] else float("nan"), + "n_episodes": len(acc["sum_rewards"]), + } + all_sr.extend(acc["sum_rewards"]) + all_mr.extend(acc["max_rewards"]) + all_succ.extend(acc["successes"]) + + overall = { + "avg_sum_reward": _nanmean(all_sr), + "avg_max_reward": _nanmean(all_mr), + "pc_success": _nanmean(all_succ) * 100 if all_succ else float("nan"), + "n_episodes": len(all_sr), + } + + merged = {"per_task": all_per_task, "per_group": groups_out, "overall": overall} + out_path = Path(output_dir) / "eval_info.json" + with open(out_path, "w") as f: + json.dump(merged, f, indent=2) + return merged + + +def main(argv: list[str] | None = None) -> None: + args, passthrough = _parse_known(argv if argv is not None else sys.argv[1:]) + + num_shards = _resolve_num_shards(args.num_shards, passthrough) + mujoco_gl = _resolve_mujoco_gl(args.render_device, num_shards, passthrough) + + output_dir = args.output_dir or _extract_output_dir(passthrough) + + print(f"[lerobot-eval-parallel] launching {num_shards} shard(s), MUJOCO_GL={mujoco_gl}") + + child_env = {**os.environ, "MUJOCO_GL": mujoco_gl, "OMP_NUM_THREADS": "1"} + + procs = [] + for k in range(num_shards): + cmd = [ + sys.executable, + "-m", + "lerobot.scripts.lerobot_eval", + f"eval.shard_id={k}", + f"eval.num_shards={num_shards}", + *passthrough, + ] + if output_dir: + # Each shard shares the same output_dir; shard files are named shard_K_of_N.json + cmd.append(f"output_dir={output_dir}") + procs.append(subprocess.Popen(cmd, env=child_env)) + + return_codes = [p.wait() for p in procs] + if any(rc != 0 for rc in return_codes): + failed = [k for k, rc in enumerate(return_codes) if rc != 0] + print(f"[lerobot-eval-parallel] shards {failed} failed with non-zero exit codes.", file=sys.stderr) + sys.exit(1) + + if output_dir and num_shards > 1: + merged = _merge_shards(output_dir, num_shards) + print("\n=== Merged Results ===") + print(json.dumps(merged["overall"], indent=2)) + + +if __name__ == "__main__": + main()