diff --git a/src/lerobot/envs/factory.py b/src/lerobot/envs/factory.py index 1c59ccb7dd..cf48fb9a88 100644 --- a/src/lerobot/envs/factory.py +++ b/src/lerobot/envs/factory.py @@ -14,6 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import importlib +from functools import partial from typing import Any import gymnasium as gym @@ -163,7 +164,9 @@ def make_env( if n_envs < 1: raise ValueError("`n_envs` must be at least 1") - env_cls = gym.vector.AsyncVectorEnv if use_async_envs else gym.vector.SyncVectorEnv + env_cls = ( + partial(gym.vector.AsyncVectorEnv, context="spawn") if use_async_envs else gym.vector.SyncVectorEnv + ) if "libero" in cfg.type: from lerobot.envs.libero import create_libero_envs diff --git a/src/lerobot/envs/utils.py b/src/lerobot/envs/utils.py index fd17a67621..c53e1a0844 100644 --- a/src/lerobot/envs/utils.py +++ b/src/lerobot/envs/utils.py @@ -130,7 +130,21 @@ def env_to_policy_features(env_cfg: EnvConfig) -> dict[str, PolicyFeature]: return policy_features +def _has_env_attr(env: gym.vector.VectorEnv, attr: str) -> bool: + """Check if sub-environments have an attribute, compatible with sync and async vector envs.""" + if hasattr(env, "envs"): + return hasattr(env.envs[0], attr) + try: + env.call(attr) + return True + except Exception: + return False + + def are_all_envs_same_type(env: gym.vector.VectorEnv) -> bool: + if not hasattr(env, "envs"): + # AsyncVectorEnv: cannot inspect subprocess env types directly + return True first_type = type(env.envs[0]) # Get type of first env return all(type(e) is first_type for e in env.envs) # Fast type check @@ -139,7 +153,7 @@ def check_env_attributes_and_types(env: gym.vector.VectorEnv) -> None: with warnings.catch_warnings(): warnings.simplefilter("once", UserWarning) # Apply filter only in this function - if not (hasattr(env.envs[0], "task_description") and hasattr(env.envs[0], "task")): + if not (_has_env_attr(env, "task_description") or _has_env_attr(env, "task")): warnings.warn( "The environment does not have 'task_description' and 'task'. Some policies require these features.", UserWarning, @@ -155,7 +169,7 @@ def check_env_attributes_and_types(env: gym.vector.VectorEnv) -> None: def add_envs_task(env: gym.vector.VectorEnv, observation: RobotObservation) -> RobotObservation: """Adds task feature to the observation dict with respect to the first environment attribute.""" - if hasattr(env.envs[0], "task_description"): + if _has_env_attr(env, "task_description"): task_result = env.call("task_description") if isinstance(task_result, tuple): @@ -167,7 +181,7 @@ def add_envs_task(env: gym.vector.VectorEnv, observation: RobotObservation) -> R raise TypeError("All items in task_description result must be strings") observation["task"] = task_result - elif hasattr(env.envs[0], "task"): + elif _has_env_attr(env, "task"): task_result = env.call("task") if isinstance(task_result, tuple):