diff --git a/.gitignore b/.gitignore index 2fe57a694..0910888ed 100644 --- a/.gitignore +++ b/.gitignore @@ -29,3 +29,7 @@ docs/**/domain_config.js .ruff_cache .lycheecache + +.venv/ +build/ +*.egg-info/ \ No newline at end of file diff --git a/AGENTS.md b/AGENTS.md index 521a3dece..deb1caaeb 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -120,6 +120,8 @@ Entry point: `service/core/service.py`. Framework: FastAPI + Uvicorn + OpenTelem | `utils/job/` | `Task`, `FrontendJob`, `K8sObjectFactory`, `PodGroupTopologyBuilder` | Workflow execution framework. Task → K8s spec generation. Gang scheduling via PodGroup. Topology constraints. Backend job definitions. | | `utils/connectors/` | `ClusterConnector`, `PostgresConnector`, `RedisConnector` | K8s API wrapper, PostgreSQL operations, Redis job queue management. | | `utils/secret_manager/` | `SecretManager` | JWE-based secret encryption/decryption. MEK/UEK key management. | +| `utils/local_executor.py` | `LocalExecutor`, `run_workflow_locally` | Local Docker Compose-based workflow execution. Generates a `docker-compose.yml` from workflow specs and runs `docker compose up`, providing on-cluster container paths (`/osmo/data/output`, `/osmo/data/input/N`), real parallel execution via `depends_on`, cycle detection, DNS-addressable `{{host:taskname}}`, resume (`--from-step`), and GPU passthrough. | +| `utils/spec_includes.py` | `resolve_includes` | Helpers to resolve and merge workflow spec `includes` directives into fully composed specs. Supports recursive inclusion, cycle detection, deep-merging, and `default-values` variable expansion. | | `utils/progress_check/` | — | Liveness/progress tracking for long-running services. | | `utils/metrics/` | — | Prometheus metrics collection and export. | @@ -139,6 +141,7 @@ Entry point: `cli.py` → `main_parser.py` (argparse). 
Subcommand modules: | `login.py` | Authentication | | `pool.py`, `resources.py`, `user.py`, `credential.py`, `access_token.py`, `bucket.py`, `task.py`, `version.py` | Supporting commands | | `backend.py` | Backend cluster management | +| `local.py` | Local workflow execution via Docker (`osmo local run`) | Features: Tab completion (shtab), response formatting (`formatters.py`), spec editor (`editor.py`), PyInstaller packaging (`cli_builder.py`, `packaging/`). diff --git a/cookbook/tutorials/BUILD b/cookbook/tutorials/BUILD new file mode 100644 index 000000000..d56c526f4 --- /dev/null +++ b/cookbook/tutorials/BUILD @@ -0,0 +1,5 @@ +filegroup( + name = "tutorial_specs", + srcs = glob(["*.yaml"]), + visibility = ["//src/utils/tests:__pkg__"], +) diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 000000000..7c0b47312 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,44 @@ +[build-system] +requires = ["setuptools>=64"] +build-backend = "setuptools.build_meta" + +[project] +name = "nvidia-osmo" +version = "0.1.0" +requires-python = ">=3.10" +dependencies = [ + "pydantic>=1.10,<2", + "pyyaml>=6.0", + "requests>=2.28", + "urllib3>=1.26", + "typing_extensions>=4.0", + "boto3>=1.26", + "botocore>=1.29", + "mypy-boto3-iam>=1.26", + "mypy-boto3-s3>=1.26", + "mypy-boto3-sts>=1.26", + "azure-storage-blob>=12.14", + "azure-identity>=1.12", + "psycopg2-binary>=2.9", + "pyjwt>=2.6", + "jwcrypto>=1.5", + "jinja2>=3.1", + "pytz>=2023.3", + "texttable>=1.6", + "tqdm>=4.64", + "aiofiles>=23.0", + "kombu>=5.2", + "redis>=4.4", + "kubernetes>=24.2", + "fastapi>=0.100", + "slack_sdk>=3.20", + "shtab>=1.5", +] + +[project.scripts] +osmo = "src.cli.cli:main" + +[tool.setuptools.packages.find] +include = ["src*"] +exclude = ["src.ui*", "src.tests*"] +namespaces = true diff --git a/src/cli/BUILD b/src/cli/BUILD index 7a9b905ee..584173d0a 100755 --- a/src/cli/BUILD +++ b/src/cli/BUILD @@ -37,6 +37,7 @@ osmo_py_library( "dataset.py", "editor.py", "formatters.py", + "local.py", 
"login.py", "main_parser.py", "pool.py", @@ -73,6 +74,8 @@ osmo_py_library( "//src/lib/utils:validation", "//src/lib/utils:version", "//src/lib/utils:workflow", + "//src/utils:local_executor", + "//src/utils:spec_includes", ], ) diff --git a/src/cli/local.py b/src/cli/local.py new file mode 100644 index 000000000..d1c37fada --- /dev/null +++ b/src/cli/local.py @@ -0,0 +1,209 @@ +# pylint: disable=line-too-long +""" +SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. + +SPDX-License-Identifier: Apache-2.0 +""" + +import argparse +import os +import re +import sys + +import shtab +import yaml + +from src.utils import local_executor, spec_includes + + +def setup_parser(parser: argparse._SubParsersAction): + """Register the 'local' subcommand and its nested actions with the CLI argument parser.""" + local_parser = parser.add_parser( + 'local', + help='Run workflows locally using Docker (no Kubernetes cluster required).') + subparsers = local_parser.add_subparsers(dest='command') + subparsers.required = True + + run_parser = subparsers.add_parser( + 'run', + help='Execute a workflow spec locally using Docker containers.') + run_parser.add_argument( + '-f', '--file', + required=True, + dest='workflow_file', + help='Path to the workflow YAML spec file.').complete = shtab.FILE + run_parser.add_argument( + '--work-dir', + dest='work_dir', + default=None, + help='Directory for task inputs/outputs. 
Defaults to a temporary directory.') + run_parser.add_argument( + '--keep', + action='store_true', + default=False, + help='Keep the work directory after execution (always kept on failure).') + run_parser.add_argument( + '--docker', + dest='docker_cmd', + default='docker', + help='Docker-compatible command to use (e.g. podman). Default: docker.') + run_parser.add_argument( + '--resume', + action='store_true', + default=False, + help='Resume a previous run, skipping tasks that already completed successfully. ' + 'Requires --work-dir pointing to the previous run directory.') + run_parser.add_argument( + '--from-step', + dest='from_step', + default=None, + help='Resume from a specific task, re-running it and all downstream tasks. ' + 'Tasks upstream of the specified step are skipped if they completed ' + 'successfully. Requires --work-dir pointing to the previous run directory.') + run_parser.add_argument( + '--shm-size', + dest='shm_size', + default=None, + help='Shared memory size for GPU containers (e.g. 16g, 32g). ' + 'Defaults to 16g for tasks that request GPUs. 
' + 'PyTorch DataLoader workers require large shared memory.') + run_parser.set_defaults(func=_run_local) + + compose_parser = subparsers.add_parser( + 'compose', + help='Flatten includes and expand task refs into a single spec with a ' + 'default-values variable map (no variable substitution).') + compose_parser.add_argument( + '-f', '--file', + required=True, + dest='workflow_file', + help='Path to the workflow YAML spec file.').complete = shtab.FILE + compose_parser.add_argument( + '-o', '--output', + dest='output_file', + default=None, + help='Write the composed spec to a file instead of stdout.').complete = shtab.FILE + compose_parser.set_defaults(func=_compose) + + +def _run_local(service_client, args: argparse.Namespace): # pylint: disable=unused-argument + """Execute a workflow locally via Docker using the parsed CLI arguments.""" + try: + success = local_executor.run_workflow_locally( + spec_path=args.workflow_file, + work_dir=args.work_dir, + keep_work_dir=args.keep, + resume=args.resume, + from_step=args.from_step, + docker_cmd=args.docker_cmd, + shm_size=args.shm_size, + ) + except (ValueError, FileNotFoundError, PermissionError) as error: + print(f'Error: {error}', file=sys.stderr) + sys.exit(1) + + if not success: + sys.exit(1) + + +_ENV_REF_RE = re.compile(r'\$\{env:([^}]+)\}') + + +def _resolve_set_env_refs(value: str) -> str: + """Replace ``${env:VAR}`` patterns only when VAR is present in ``os.environ``.""" + def _replacer(match: re.Match) -> str: + env_var = match.group(1) + if env_var in os.environ: + return os.environ[env_var] + return match.group(0) + return _ENV_REF_RE.sub(_replacer, value) + + +def _compose(service_client, args: argparse.Namespace): # pylint: disable=unused-argument + """Flatten includes, resolve variables, and output a submittable spec. 
+ + When all ``${env:VAR}`` references can be resolved the output is fully + flat: no ``default-values`` section, no ``{variable}`` references — + ready to submit to the OSMO server or run locally. + + When environment variables are missing the output keeps a + ``default-values`` section with the unresolvable entries so the user + can fill them in and re-compose. + """ + unresolved_env: dict = {} + try: + abs_path = os.path.abspath(args.workflow_file) + with open(abs_path, encoding='utf-8') as f: + spec_text = f.read() + + spec_text = spec_includes.resolve_includes( + spec_text, os.path.dirname(abs_path), source_path=abs_path) + + unresolved_env = spec_includes.find_unresolved_env_variables(spec_text) + + if unresolved_env: + spec_text = _compose_with_unresolved(spec_text, unresolved_env) + else: + spec_text = _compose_fully_resolved(spec_text) + except (ValueError, FileNotFoundError, PermissionError) as error: + print(f'Error: {error}', file=sys.stderr) + sys.exit(1) + + if unresolved_env: + env_list = ', '.join( + f'${v}' for v in sorted(set(unresolved_env.values()))) + print( + f'Warning: environment variables not set: {env_list}\n' + 'Set them and re-compose, or edit the default-values section ' + 'in the output.', + file=sys.stderr) + + if args.output_file: + with open(args.output_file, 'w', encoding='utf-8') as f: + f.write(spec_text) + print(f'Composed spec written to {args.output_file}', file=sys.stderr) + else: + print(spec_text, end='') + + +def _compose_fully_resolved(spec_text: str) -> str: + """Resolve all variables and produce a submittable spec.""" + return spec_includes.resolve_default_values(spec_text) + + +def _compose_with_unresolved(spec_text: str, + unresolved_env: dict) -> str: + """Keep a ``default-values`` map for variables that cannot be resolved.""" + parsed = yaml.safe_load(spec_text) + raw_defaults = parsed.pop('default-values', None) or {} + + scalar_defaults: dict = {} + for key in sorted(raw_defaults): + value = raw_defaults[key] + if 
isinstance(value, (str, int, float, bool)): + scalar_defaults[key] = value + elif value is None: + scalar_defaults[key] = value + + for key, value in scalar_defaults.items(): + if isinstance(value, str): + scalar_defaults[key] = _resolve_set_env_refs(value) + + output: dict = {} + if scalar_defaults: + output['default-values'] = scalar_defaults + output.update(parsed) + + return yaml.safe_dump(output, default_flow_style=False, sort_keys=False) diff --git a/src/cli/main_parser.py b/src/cli/main_parser.py index 79484ee16..bd097111d 100644 --- a/src/cli/main_parser.py +++ b/src/cli/main_parser.py @@ -28,6 +28,7 @@ credential, data, dataset, + local, login, pool, profile, @@ -55,7 +56,8 @@ profile.setup_parser, pool.setup_parser, user.setup_parser, - config.setup_parser + config.setup_parser, + local.setup_parser, ) diff --git a/src/cli/workflow.py b/src/cli/workflow.py index eec06bf51..8b8a9560d 100644 --- a/src/cli/workflow.py +++ b/src/cli/workflow.py @@ -48,6 +48,7 @@ from src.lib.data import storage from src.lib.utils import (client, common, osmo_errors, paths, port_forward, priority as wf_priority, validation, workflow as workflow_utils) +from src.utils import spec_includes INTERACTIVE_COMMANDS = ['bash', 'sh', 'zsh', 'fish', 'tcsh', 'csh', 'ksh'] @@ -588,8 +589,11 @@ def parse_file_for_template(workflow_contents: str, set_variables: List[str], def _load_wf_file(workflow_path: str, set_variables: List[str], set_string_variables: List[str]) -> TemplateData: - with open(workflow_path, 'r', encoding='utf-8') as file: + abs_path = os.path.abspath(workflow_path) + with open(abs_path, 'r', encoding='utf-8') as file: full_file_text = file.read() + full_file_text = spec_includes.resolve_includes( + full_file_text, os.path.dirname(abs_path), source_path=abs_path) return parse_file_for_template(full_file_text, set_variables, set_string_variables) diff --git a/src/utils/BUILD b/src/utils/BUILD index f674edf6a..0a555a38a 100644 --- a/src/utils/BUILD +++ b/src/utils/BUILD @@ 
-98,3 +98,24 @@ osmo_py_library( ], visibility = ["//visibility:public"], ) + +osmo_py_library( + name = "spec_includes", + srcs = ["spec_includes.py"], + deps = [ + requirement("pyyaml"), + "//src/lib/utils:osmo_errors", + ], + visibility = ["//visibility:public"], +) + +osmo_py_library( + name = "local_executor", + srcs = ["local_executor.py"], + deps = [ + requirement("pyyaml"), + "//src/utils:spec_includes", + "//src/utils/job", + ], + visibility = ["//visibility:public"], +) diff --git a/src/utils/job/task.py b/src/utils/job/task.py index fc182084b..66405f8b7 100644 --- a/src/utils/job/task.py +++ b/src/utils/job/task.py @@ -598,7 +598,7 @@ class TaskSpec(pydantic.BaseModel, extra=pydantic.Extra.forbid): """ Represents the container spec in a task spec. """ name: task_common.NamePattern image: str - command: List[str] + command: List[str] = [] inputs: List[InputType] = [] outputs: List[OutputType] = [] kpis: List[TaskKPI] = [] @@ -663,12 +663,10 @@ def validate_command(cls, command: List[str], values: Dict) -> List[str]: """ Validates command. Returns the value of command if valid. - Raises: - ValueError: Containers fails validation. + An empty command list means "use the container image's default + ENTRYPOINT", which is a valid and common pattern (e.g. NRE images + with pycena_run as the built-in entrypoint). """ - name = values.get('name', '') - if not command: - raise ValueError(f'Container {name} should have at least one command.') return command @pydantic.validator('files') diff --git a/src/utils/local_executor.py b/src/utils/local_executor.py new file mode 100644 index 000000000..c33b3ceb6 --- /dev/null +++ b/src/utils/local_executor.py @@ -0,0 +1,755 @@ +""" +SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # pylint: disable=line-too-long + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. + +SPDX-License-Identifier: Apache-2.0 +""" +# Security: This module executes workflow specs by generating a +# docker-compose.yml and invoking Docker Compose via subprocess. +# Specs must come from trusted sources. Path-traversal protections +# are in place for data directories, but the spec itself is not sandboxed. + +import dataclasses +import json +import logging +import os +import re +import shutil +import subprocess +import tempfile +from collections import deque +from typing import Dict, List, Set + +import yaml + +from src.utils import spec_includes +from src.utils.job import task as task_module +from src.utils.job import workflow as workflow_module + + +logger = logging.getLogger(__name__) + +STATE_FILE_NAME = '.osmo-state.json' +COMPOSE_FILE_NAME = 'docker-compose.yml' + +OSMO_OUTPUT_PATH = '/osmo/data/output' +OSMO_INPUT_PATH_PREFIX = '/osmo/data/input' + +_OSMO_RUNTIME_TOKEN = re.compile( + r'\{\{\s*(uuid|workflow_id|output|input:[^}]+|host:[^}]+|item)\s*\}\}') +_ANY_DOUBLE_BRACE = re.compile(r'\{\{[^}]+?\}\}') + + +@dataclasses.dataclass +class TaskNode: + """A node in the workflow DAG, linking a task spec to its upstream + and downstream dependencies.""" + + name: str + spec: task_module.TaskSpec + group: str + upstream: Set[str] = dataclasses.field(default_factory=set) + downstream: Set[str] = dataclasses.field(default_factory=set) + + +@dataclasses.dataclass +class TaskResult: + """Outcome of a single task execution, capturing its exit code and output directory path.""" + + name: str + exit_code: int + output_dir: str + + +class LocalExecutor: + """ + 
Executes an OSMO workflow spec locally using Docker Compose, without Kubernetes. + + Generates a docker-compose.yml from the workflow spec and runs + ``docker compose up``, giving: + + - Correct container paths matching on-cluster behavior + (``/osmo/data/output``, ``/osmo/data/input/N``) + - Real parallel execution with native dependency ordering + via ``depends_on: condition: service_completed_successfully`` + - Cycle detection (Compose validates the DAG; also checked upfront) + - DNS-addressable service names for ``{{host:taskname}}`` + - GPU passthrough via ``deploy.resources.reservations.devices`` + + Does NOT support (raises clear errors): + - Dataset / URL inputs/outputs (require object storage) + - Credentials, checkpoints, volumeMounts (require cluster infra) + - Templated specs with Jinja (require server-side expansion; use --dry-run first) + """ + + DEFAULT_SHM_SIZE = '16g' + + _ENTRYPOINT_COMMANDS = frozenset({ + 'bash', 'sh', 'dash', 'zsh', + 'python', 'python3', 'python3.10', 'python3.11', 'python3.12', + 'perl', 'ruby', 'node', + }) + + def __init__(self, work_dir: str, keep_work_dir: bool = False, + docker_cmd: str = 'docker', + shm_size: str | None = None, + extra_volumes: List[str] | None = None): + """Initialize the executor with a work directory, cleanup preference, + and container runtime command. + + Args: + extra_volumes: Additional Docker volume mounts (``host:container`` + strings) added to every task. Useful for making host paths + (e.g. repository root, credential directories) visible inside + containers. 
+ """ + self._work_dir = os.path.abspath(work_dir) + self._keep_work_dir = keep_work_dir + self._docker_cmd = docker_cmd + self._shm_size = shm_size + self._extra_volumes = list(extra_volumes) if extra_volumes else [] + self._task_nodes: Dict[str, TaskNode] = {} + self._results: Dict[str, TaskResult] = {} + self._available_gpus: int | None = None + + def _detect_available_gpus(self) -> int: + """Query nvidia-smi to count available GPUs, caching the result for subsequent calls.""" + if self._available_gpus is not None: + return self._available_gpus + try: + result = subprocess.run( + ['nvidia-smi', '--query-gpu=index', + '--format=csv,noheader'], + capture_output=True, text=True, timeout=10, + check=False, + ) + if result.returncode == 0: + gpu_indices = [ + line.strip() + for line in result.stdout.strip().splitlines() + if line.strip() + ] + self._available_gpus = len(gpu_indices) + else: + logger.warning( + 'nvidia-smi failed (exit %d) — assuming 0 GPUs available', + result.returncode) + self._available_gpus = 0 + except FileNotFoundError: + logger.warning('nvidia-smi not found — assuming 0 GPUs available') + self._available_gpus = 0 + except subprocess.TimeoutExpired: + logger.warning('nvidia-smi timed out — assuming 0 GPUs available') + self._available_gpus = 0 + return self._available_gpus + + def load_spec(self, spec_text: str) -> workflow_module.WorkflowSpec: + """Parse raw YAML text into a validated WorkflowSpec via the versioned spec model.""" + raw = yaml.safe_load(spec_text) + versioned = workflow_module.VersionedWorkflowSpec(**raw) + return versioned.workflow + + def execute(self, spec: workflow_module.WorkflowSpec, + resume: bool = False, from_step: str | None = None) -> bool: + """Run all tasks via Docker Compose, returning True if the entire workflow succeeds.""" + self._results.clear() + self._build_dag(spec) + self._detect_cycles() + self._validate_for_local(spec) + self._setup_directories() + + if resume or from_step: + 
self._restore_completed_tasks(from_step) + + tasks_to_run = set(self._task_nodes.keys()) - set(self._results.keys()) + if not tasks_to_run: + logger.info('Workflow "%s": all tasks already completed', spec.name) + return True + + total_tasks = sum(len(g.tasks) for g in self._groups(spec)) + skipped = len(self._results) + if skipped > 0: + logger.info('Workflow "%s": resuming — %d task(s) skipped, %d remaining', + spec.name, skipped, len(tasks_to_run)) + else: + logger.info('Workflow "%s": %d task(s) across %d group(s)', + spec.name, total_tasks, len(self._groups(spec))) + + compose_config = self._generate_compose_config(spec, tasks_to_run) + compose_path = os.path.join(self._work_dir, COMPOSE_FILE_NAME) + with open(compose_path, 'w', encoding='utf-8') as f: + yaml.dump(compose_config, f, default_flow_style=False, sort_keys=False) + + logger.info('Generated %s', compose_path) + + project_name = re.sub(r'[^a-z0-9-]', '-', os.path.basename(self._work_dir).lower()) + compose_cmd = [ + self._docker_cmd, 'compose', + '-f', compose_path, + '--project-name', project_name, + 'up', + ] + + logger.info('Starting Docker Compose execution') + + try: + process = subprocess.run(compose_cmd, capture_output=False, check=False) + compose_exit_code = process.returncode + except FileNotFoundError: + logger.error( + '%s not found. Is Docker (with Compose V2) installed and in your PATH?', + self._docker_cmd) + return False + + self._collect_compose_results(compose_path, project_name, tasks_to_run) + self._save_state() + self._compose_down(compose_path, project_name) + + failed = [name for name, r in self._results.items() + if r.exit_code not in (0, -1)] + not_run = [name for name, r in self._results.items() if r.exit_code == -1] + + if failed: + logger.error('Workflow failed. 
Failed tasks: %s', ', '.join(sorted(failed))) + if not_run: + logger.error('Tasks not started (blocked by failures): %s', + ', '.join(sorted(not_run))) + return False + + unexecuted = set(self._task_nodes.keys()) - set(self._results.keys()) + if unexecuted: + logger.error( + 'Workflow "%s" stalled — tasks not completed: %s', + spec.name, ', '.join(sorted(unexecuted))) + return False + + if compose_exit_code != 0: + logger.error('Docker Compose exited with code %d', compose_exit_code) + return False + + logger.info('Workflow "%s" completed successfully', spec.name) + return True + + def _detect_cycles(self): + """Detect cycles in the task DAG using Kahn's algorithm (topological sort). + + Raises ValueError with the names of tasks involved in the cycle.""" + in_degree = {name: len(node.upstream) for name, node in self._task_nodes.items()} + queue: deque[str] = deque( + name for name, degree in in_degree.items() if degree == 0) + visited_count = 0 + + while queue: + current = queue.popleft() + visited_count += 1 + for downstream in self._task_nodes[current].downstream: + in_degree[downstream] -= 1 + if in_degree[downstream] == 0: + queue.append(downstream) + + if visited_count != len(self._task_nodes): + cycle_members = sorted( + name for name, degree in in_degree.items() if degree > 0) + raise ValueError( + f'Circular dependency detected among tasks: {", ".join(cycle_members)}') + + def _generate_compose_config(self, spec: workflow_module.WorkflowSpec, + tasks_to_run: Set[str]) -> Dict: + """Generate a docker-compose.yml configuration dict for the tasks that need to run.""" + services: Dict[str, Dict] = {} + for task_name in self._topological_order(): + if task_name not in tasks_to_run: + continue + node = self._task_nodes[task_name] + services[task_name] = self._build_service_config(node, spec, tasks_to_run) + return {'services': services} + + def _topological_order(self) -> List[str]: + """Return task names in topological order (stable, respecting insertion 
order).""" + in_degree = {name: len(node.upstream) for name, node in self._task_nodes.items()} + queue: deque[str] = deque( + name for name in self._task_nodes if in_degree[name] == 0) + order: List[str] = [] + while queue: + current = queue.popleft() + order.append(current) + for downstream in self._task_nodes[current].downstream: + in_degree[downstream] -= 1 + if in_degree[downstream] == 0: + queue.append(downstream) + return order + + @staticmethod + def _escape_compose_interpolation(text: str) -> str: + """Escape ``$`` as ``$$`` to prevent Docker Compose host-variable interpolation. + + Docker Compose expands ``$VAR`` and ``${VAR}`` from the host + environment before passing values to containers. Doubling the + dollar sign makes it a literal ``$`` that the container's shell + can then expand from the container's own environment.""" + return text.replace('$', '$$') + + def _build_service_config(self, node: TaskNode, + spec: workflow_module.WorkflowSpec, + tasks_to_run: Set[str]) -> Dict: + """Build a single Docker Compose service configuration for a task.""" + task_spec = node.spec + task_dir = os.path.join(self._work_dir, node.name) + output_dir = os.path.join(task_dir, 'output') + os.makedirs(output_dir, exist_ok=True) + + token_map = self._build_container_token_map(node) + service: Dict = {'image': task_spec.image} + + volumes = [f'{output_dir}:{OSMO_OUTPUT_PATH}'] + for index, input_source in enumerate(task_spec.inputs): + if isinstance(input_source, task_module.TaskInputOutput): + upstream_task = input_source.task + if upstream_task in self._results: + upstream_output = self._results[upstream_task].output_dir + else: + upstream_output = os.path.join(self._work_dir, upstream_task, 'output') + + container_index = f'{OSMO_INPUT_PATH_PREFIX}/{index}' + container_named = f'{OSMO_INPUT_PATH_PREFIX}/{upstream_task}' + volumes.append(f'{upstream_output}:{container_index}:ro') + if container_index != container_named: + 
volumes.append(f'{upstream_output}:{container_named}:ro') + + files_dir = os.path.join(task_dir, 'files') + os.makedirs(files_dir, exist_ok=True) + for file_spec in task_spec.files: + resolved_contents = self._substitute_tokens(file_spec.contents, token_map) + host_path = os.path.realpath(os.path.join(files_dir, file_spec.path.lstrip('/'))) + if not host_path.startswith(os.path.realpath(files_dir) + os.sep): + raise ValueError( + f'Task "{node.name}": file path "{file_spec.path}" escapes the task directory') + os.makedirs(os.path.dirname(host_path), exist_ok=True) + with open(host_path, 'w', encoding='utf-8') as f: + f.write(resolved_contents) + volumes.append(f'{host_path}:{file_spec.path}:ro') + + volumes.extend(self._extra_volumes) + service['volumes'] = volumes + + resolved_command = [ + self._escape_compose_interpolation(self._substitute_tokens(c, token_map)) + for c in task_spec.command] + resolved_args = [ + self._escape_compose_interpolation(self._substitute_tokens(a, token_map)) + for a in task_spec.args] + + if resolved_command: + first_cmd = resolved_command[0] + if first_cmd.startswith('/') or first_cmd in self._ENTRYPOINT_COMMANDS: + service['entrypoint'] = [first_cmd] + rest = resolved_command[1:] + resolved_args + else: + rest = resolved_command + resolved_args + if rest: + service['command'] = rest + elif resolved_args: + service['command'] = resolved_args + + if task_spec.environment: + service['environment'] = { + key: self._escape_compose_interpolation( + self._substitute_tokens(value, token_map)) + for key, value in task_spec.environment.items() + } + + gpu_count = self._task_gpu_count(task_spec, spec) + if gpu_count > 0: + available = self._detect_available_gpus() + effective_count = min(gpu_count, available) if available > 0 else 0 + if effective_count > 0: + service['deploy'] = { + 'resources': { + 'reservations': { + 'devices': [{ + 'driver': 'nvidia', + 'count': effective_count, + 'capabilities': ['gpu'], + }] + } + } + } + logger.info( + 
'Task "%s" requesting %d GPU(s), using %d', + node.name, gpu_count, effective_count) + else: + logger.warning( + 'Task "%s" requests %d GPU(s) but no GPUs available' + ' — running without GPU support', + node.name, gpu_count) + service['shm_size'] = self._shm_size or self.DEFAULT_SHM_SIZE + elif self._shm_size: + service['shm_size'] = self._shm_size + + depends_on: Dict[str, Dict] = {} + for upstream_task in sorted(node.upstream): + if upstream_task in tasks_to_run: + depends_on[upstream_task] = { + 'condition': 'service_completed_successfully'} + if depends_on: + service['depends_on'] = depends_on + + return service + + def _build_container_token_map(self, node: TaskNode) -> Dict[str, str]: + """Build a mapping of {{token}} keys to on-cluster container paths.""" + tokens: Dict[str, str] = { + 'output': OSMO_OUTPUT_PATH, + } + for index, input_source in enumerate(node.spec.inputs): + if isinstance(input_source, task_module.TaskInputOutput): + tokens[f'input:{input_source.task}'] = ( + f'{OSMO_INPUT_PATH_PREFIX}/{input_source.task}') + tokens[f'input:{index}'] = f'{OSMO_INPUT_PATH_PREFIX}/{index}' + + for task_name in self._task_nodes: + tokens[f'host:{task_name}'] = task_name + + return tokens + + def _collect_compose_results(self, compose_path: str, project_name: str, + tasks_to_run: Set[str]): + """Collect exit codes from Docker Compose services after execution.""" + try: + result = subprocess.run( + [self._docker_cmd, 'compose', '-f', compose_path, + '--project-name', project_name, + 'ps', '-a', '--format', 'json'], + capture_output=True, text=True, check=False, timeout=30) + + if result.returncode == 0 and result.stdout.strip(): + for info in self._parse_compose_ps_output(result.stdout): + service_name = info.get('Service', '') + if service_name in tasks_to_run and service_name not in self._results: + exit_code = info.get('ExitCode', -1) + output_dir = os.path.join( + self._work_dir, service_name, 'output') + self._results[service_name] = TaskResult( + 
name=service_name, + exit_code=exit_code, + output_dir=output_dir) + except (FileNotFoundError, subprocess.TimeoutExpired): + pass + + for task_name in tasks_to_run: + if task_name not in self._results: + output_dir = os.path.join(self._work_dir, task_name, 'output') + self._results[task_name] = TaskResult( + name=task_name, exit_code=-1, output_dir=output_dir) + + @staticmethod + def _parse_compose_ps_output(output: str) -> List[Dict]: + """Parse the JSON output from ``docker compose ps --format json``. + + Handles both a single JSON array and newline-delimited JSON objects.""" + output = output.strip() + try: + data = json.loads(output) + if isinstance(data, list): + return data + return [data] + except json.JSONDecodeError: + results: List[Dict] = [] + for line in output.splitlines(): + line = line.strip() + if not line: + continue + try: + results.append(json.loads(line)) + except json.JSONDecodeError: + continue + return results + + def _compose_down(self, compose_path: str, project_name: str): + """Clean up Docker Compose containers and networks (preserves bind-mounted data).""" + try: + subprocess.run( + [self._docker_cmd, 'compose', '-f', compose_path, + '--project-name', project_name, + 'down', '--remove-orphans'], + capture_output=True, check=False, timeout=60) + except (FileNotFoundError, subprocess.TimeoutExpired): + pass + + @property + def _state_file_path(self) -> str: + """Absolute path to the JSON state file used for resume tracking.""" + return os.path.join(self._work_dir, STATE_FILE_NAME) + + def _save_state(self): + """Persist current task results to the state file so runs can be resumed later.""" + state = { + 'tasks': { + name: {'exit_code': result.exit_code, 'output_dir': result.output_dir} + for name, result in self._results.items() + if result.exit_code != -1 + } + } + with open(self._state_file_path, 'w', encoding='utf-8') as f: + json.dump(state, f, indent=2) + + def _load_state(self) -> Dict | None: + """Load previously saved task state 
from disk, returning None if no state file exists.""" + if not os.path.exists(self._state_file_path): + return None + with open(self._state_file_path, encoding='utf-8') as f: + return json.load(f) + + def _restore_completed_tasks(self, from_step: str | None = None): + """Reload completed tasks from a previous run, optionally + invalidating from a given step onward.""" + state = self._load_state() + if state is None: + logger.info('No previous state found — starting from scratch') + return + + completed: Dict[str, Dict] = {} + for name, info in state.get('tasks', {}).items(): + if name not in self._task_nodes: + continue + if info['exit_code'] == 0 and os.path.isdir(info['output_dir']): + completed[name] = info + + if from_step: + if from_step not in self._task_nodes: + raise ValueError(f'Task "{from_step}" not found in workflow') + to_invalidate = self._get_downstream_tasks(from_step) + to_invalidate.add(from_step) + for name in to_invalidate: + completed.pop(name, None) + + for name, info in completed.items(): + self._results[name] = TaskResult( + name=name, exit_code=0, output_dir=info['output_dir']) + logger.info('Resuming: skipping completed task "%s"', name) + + def _get_downstream_tasks(self, task_name: str) -> Set[str]: + """Return all transitive downstream dependents of the given task via BFS.""" + visited: Set[str] = set() + queue = [task_name] + while queue: + current = queue.pop(0) + for downstream in self._task_nodes[current].downstream: + if downstream not in visited: + visited.add(downstream) + queue.append(downstream) + return visited + + def _groups(self, spec: workflow_module.WorkflowSpec) -> List[task_module.TaskGroupSpec]: + """Return the spec's groups, or synthesize one group per task when groups are absent.""" + if spec.groups: + return spec.groups + return [task_module.TaskGroupSpec(name=t.name, tasks=[t]) for t in spec.tasks] + + def _build_dag(self, spec: workflow_module.WorkflowSpec): + """Construct the internal DAG of TaskNodes from the 
workflow spec's + tasks and input dependencies.""" + self._task_nodes.clear() + task_to_group: Dict[str, str] = {} + + for group in self._groups(spec): + for task_spec in group.tasks: + task_to_group[task_spec.name] = group.name + self._task_nodes[task_spec.name] = TaskNode( + name=task_spec.name, + spec=task_spec, + group=group.name, + ) + + for group in self._groups(spec): + for task_spec in group.tasks: + for input_source in task_spec.inputs: + if isinstance(input_source, task_module.TaskInputOutput): + upstream_task = input_source.task + if upstream_task not in self._task_nodes: + raise ValueError( + f'Task "{task_spec.name}" depends on ' + f'unknown task "{upstream_task}"') + self._task_nodes[task_spec.name].upstream.add(upstream_task) + self._task_nodes[upstream_task].downstream.add(task_spec.name) + + # For flat task lists (no explicit groups), add implicit sequential + # dependencies so each task waits for the previous one — matching + # on-cluster behavior where tasks in a list run sequentially. 
+ if not spec.groups and spec.tasks: + task_names = [t.name for t in spec.tasks] + for i in range(1, len(task_names)): + prev, curr = task_names[i - 1], task_names[i] + self._task_nodes[curr].upstream.add(prev) + self._task_nodes[prev].downstream.add(curr) + + def _validate_for_local(self, spec: workflow_module.WorkflowSpec): + """Raise ValueError if the spec uses features unsupported + in local mode (datasets, URLs, credentials, etc.).""" + unsupported_features = [] + for group in self._groups(spec): + for task_spec in group.tasks: + for input_source in task_spec.inputs: + if isinstance(input_source, task_module.DatasetInputOutput): + unsupported_features.append( + f'Task "{task_spec.name}": dataset inputs require object storage') + elif isinstance(input_source, task_module.URLInputOutput): + unsupported_features.append( + f'Task "{task_spec.name}": URL inputs require network/storage access') + + for output in task_spec.outputs: + unsupported_output = ( + task_module.DatasetInputOutput, + task_module.URLInputOutput, + ) + if isinstance(output, unsupported_output): + unsupported_features.append( + f'Task "{task_spec.name}": dataset/URL outputs require object storage') + + if task_spec.credentials: + unsupported_features.append( + f'Task "{task_spec.name}": credentials require the OSMO secret manager') + + if task_spec.checkpoint: + unsupported_features.append( + f'Task "{task_spec.name}": checkpoints require object storage') + + if task_spec.volumeMounts: + unsupported_features.append( + f'Task "{task_spec.name}": volumeMounts require cluster-level host paths') + + if task_spec.privileged: + unsupported_features.append( + f'Task "{task_spec.name}": privileged containers ' + f'are not supported in local mode') + + if task_spec.hostNetwork: + unsupported_features.append( + f'Task "{task_spec.name}": hostNetwork is not supported in local mode') + + if unsupported_features: + raise ValueError( + 'The following features are not supported in local execution mode:\n - ' + 
+ '\n - '.join(unsupported_features)) + + def _setup_directories(self): + """Create the work directory and per-task output directories on the host filesystem.""" + os.makedirs(self._work_dir, exist_ok=True) + for task_name in self._task_nodes: + os.makedirs(os.path.join(self._work_dir, task_name, 'output'), exist_ok=True) + + def _task_gpu_count(self, task_spec: task_module.TaskSpec, + spec: workflow_module.WorkflowSpec) -> int: + """Return the number of GPUs requested by a task's resource spec, defaulting to 0.""" + resource_spec = spec.resources.get(task_spec.resource) + if resource_spec and resource_spec.gpu: + return resource_spec.gpu + return 0 + + def _substitute_tokens(self, text: str, tokens: Dict[str, str]) -> str: + """Replace all {{key}} placeholders in text with their corresponding token values.""" + for key, value in tokens.items(): + text = re.sub(r'\{\{\s*' + re.escape(key) + r'\s*\}\}', value, text) + return text + + +def check_unresolved_variables(spec_text: str): + """Raise ValueError if the spec contains ``{{ variable }}`` placeholders + that are not OSMO runtime tokens. + + OSMO runtime tokens (``{{output}}``, ``{{input:…}}``, ``{{host:…}}``, + ``{{uuid}}``, ``{{workflow_id}}``, ``{{item}}``) are left intact since + they are resolved at container runtime. Any other ``{{ }}`` pattern + indicates a template variable that was not expanded by ``default-values`` + and would pass through silently, causing subtle breakage. 
+ """ + all_braces = _ANY_DOUBLE_BRACE.findall(spec_text) + unresolved = [ + token for token in all_braces + if not _OSMO_RUNTIME_TOKEN.match(token) + ] + if unresolved: + unique = sorted(set(unresolved)) + raise ValueError( + 'Unresolved template variables found in spec (did you forget to ' + 'add them to default-values?):\n ' + + '\n '.join(unique) + + '\nHint: use "osmo workflow submit --dry-run -f " to ' + 'expand Jinja templates server-side, or add the variables to ' + 'the default-values section.') + + +def run_workflow_locally(spec_path: str, work_dir: str | None = None, + keep_work_dir: bool = False, + resume: bool = False, + from_step: str | None = None, + docker_cmd: str = 'docker', + shm_size: str | None = None, + extra_volumes: List[str] | None = None) -> bool: + """Load a workflow spec from disk and execute it locally via Docker Compose, + managing the work directory lifecycle.""" + if (resume or from_step) and work_dir is None: + raise ValueError( + '--resume and --from-step require --work-dir pointing to a previous run directory.') + + with open(spec_path, encoding='utf-8') as f: + spec_text = f.read() + + abs_path = os.path.abspath(spec_path) + spec_text = spec_includes.resolve_includes( + spec_text, os.path.dirname(abs_path), source_path=abs_path) + + unresolved_env = spec_includes.find_unresolved_env_variables(spec_text) + if unresolved_env: + lines = [f' {name} (from ${env_var})' + for name, env_var in sorted(unresolved_env.items())] + raise ValueError( + 'The following environment variables are required but not set:\n' + + '\n'.join(lines) + + '\nSet them before running, or use "osmo local compose" to ' + 'inspect what needs to be configured.') + + spec_text = spec_includes.resolve_default_values(spec_text) + + template_markers = ('{%', '{#') + if any(marker in spec_text for marker in template_markers): + raise ValueError( + 'This spec uses Jinja templates which require server-side expansion.\n' + 'Run "osmo workflow submit --dry-run -f " first 
to get the expanded spec,\n' + 'then save that output and run it locally.') + + check_unresolved_variables(spec_text) + + created_work_dir = work_dir is None + effective_work_dir: str = ( + os.path.abspath(work_dir) if work_dir is not None + else tempfile.mkdtemp(prefix='osmo-local-') + ) + if created_work_dir: + logger.info('Using temporary work directory: %s', effective_work_dir) + + executor = LocalExecutor(work_dir=effective_work_dir, keep_work_dir=keep_work_dir, + docker_cmd=docker_cmd, shm_size=shm_size, + extra_volumes=extra_volumes) + spec = executor.load_spec(spec_text) + success = executor.execute(spec, resume=resume or from_step is not None, + from_step=from_step) + + if created_work_dir and not keep_work_dir and success: + logger.info('Cleaning up work directory: %s', effective_work_dir) + shutil.rmtree(effective_work_dir, ignore_errors=True) + elif not success: + logger.info('Work directory preserved for debugging: %s', effective_work_dir) + + return success diff --git a/src/utils/spec_includes.py b/src/utils/spec_includes.py new file mode 100644 index 000000000..9abfcdc0a --- /dev/null +++ b/src/utils/spec_includes.py @@ -0,0 +1,457 @@ +""" +SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+
+SPDX-License-Identifier: Apache-2.0
+"""
+
+import copy
+import os
+import re
+from typing import Any, Dict, FrozenSet, List
+
+import yaml
+
+from src.lib.utils import osmo_errors
+
+
+# Full-string reference: a value that is exactly "{{ some.key }}".
+_VAR_REF_PATTERN = re.compile(r'^\{\{\s*([a-zA-Z_][a-zA-Z0-9_.]*)\s*\}\}$')
+# Environment indirection: ${env:VAR} anywhere in a string.
+_ENV_REF_PATTERN = re.compile(r'\$\{env:([^}]+)\}')
+# Single-brace scalar reference {var}; lookbehind/lookahead keep it from
+# matching the interior of double-brace {{var}} tokens.
+_SCALAR_REF_PATTERN = re.compile(r'(?<!\{)\{([a-zA-Z_][a-zA-Z0-9_]*)\}(?!\})')
+# NOTE(review): the three definitions below were garbled in transit and have
+# been reconstructed from their usage in resolve_default_values() and
+# _extract_and_remove_default_values() — confirm against the original source.
+# Double-brace Jinja-style reference {{var}} (group 1 is the variable name).
+_JINJA_VAR_PATTERN = re.compile(r'\{\{\s*([a-zA-Z_][a-zA-Z0-9_]*)\s*\}\}')
+# Matches the top-level "default-values:" block; group 1 captures its
+# indented (or blank) body lines so the block can be re-parsed and removed.
+_DEFAULT_VALUES_BLOCK = re.compile(
+    r'^default-values:[ \t]*\n((?:[ \t]+[^\n]*\n?|\n)*)', re.MULTILINE)
+# Sentinel distinguishing "key absent" from an explicit null value.
+_MISSING = object()
+
+
+def _is_named_dict_list(value: Any) -> bool:
+    """Return True if *value* is a non-empty list of dicts that all have a ``name`` key."""
+    if not isinstance(value, list) or len(value) == 0:
+        return False
+    return all(isinstance(item, dict) and 'name' in item for item in value)
+
+
+def _merge_named_lists(base_list: List[Dict[str, Any]],
+                       override_list: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
+    """Merge two lists of named dicts, matching items by their ``name`` field.
+
+    - Items present in both lists are deep-merged (override wins).
+    - Items only in base are kept in their original position.
+    - Items only in override are appended after all base items.
+    """
+    base_by_name: Dict[str, Dict[str, Any]] = {}
+    base_order: List[str] = []
+    for item in base_list:
+        name = item['name']
+        base_by_name[name] = item
+        base_order.append(name)
+
+    override_by_name: Dict[str, Dict[str, Any]] = {}
+    override_order: List[str] = []
+    for item in override_list:
+        name = item['name']
+        override_by_name[name] = item
+        override_order.append(name)
+
+    merged: List[Dict[str, Any]] = []
+    seen: set = set()
+
+    for name in base_order:
+        if name in override_by_name:
+            merged.append(deep_merge_dicts(base_by_name[name], override_by_name[name]))
+        else:
+            merged.append(base_by_name[name])
+        seen.add(name)
+
+    for name in override_order:
+        if name not in seen:
+            merged.append(override_by_name[name])
+
+    return merged
+
+
+def deep_merge_dicts(base: Dict[str, Any], override: Dict[str, Any]) -> Dict[str, Any]:
+    """Recursively merge two dicts where values in *override* take precedence.
+
+    - Dict values are merged recursively.
+ - Lists of dicts with a ``name`` key are merged by name (matched items + are deep-merged, unmatched items are kept/appended). + - All other types (plain lists, scalars) in *override* replace the + corresponding *base* value entirely. + """ + merged: Dict[str, Any] = {} + for key in set(base) | set(override): + if key in base and key in override: + base_val = base[key] + override_val = override[key] + if isinstance(base_val, dict) and isinstance(override_val, dict): + merged[key] = deep_merge_dicts(base_val, override_val) + elif _is_named_dict_list(base_val) and _is_named_dict_list(override_val): + merged[key] = _merge_named_lists(base_val, override_val) + else: + merged[key] = override_val + elif key in base: + merged[key] = base[key] + else: + merged[key] = override[key] + return merged + + +def _lookup_dot_path(data: Dict[str, Any], dot_path: str) -> Any: + """Navigate a nested dict via a dot-separated key path. + + Returns the value at the path, ``None`` if the key exists with a null + value, or the ``_MISSING`` sentinel if the key does not exist. + """ + current: Any = data + for segment in dot_path.split('.'): + if not isinstance(current, dict) or segment not in current: + return _MISSING + current = current[segment] + return current + + +def _expand_task_refs(tasks: List[Any], + default_values: Dict[str, Any]) -> List[Any]: + """Replace ``{{ ref }}`` strings in a task list with their dict values from *default_values*. + + - If the referenced value is a dict, it is injected as a task definition + with ``name`` set to the last segment of the reference path (unless + already present). + - If the referenced value is explicitly ``null``, the entry is removed + (the task is excluded from the workflow). + - Unresolvable references or scalar values are left unchanged for Jinja. 
+ """ + expanded: List[Any] = [] + for item in tasks: + if not isinstance(item, str): + expanded.append(item) + continue + match = _VAR_REF_PATTERN.match(item) + if match is None: + expanded.append(item) + continue + ref_path = match.group(1) + value = _lookup_dot_path(default_values, ref_path) + if value is _MISSING: + expanded.append(item) + continue + if value is None: + continue + if not isinstance(value, dict): + expanded.append(item) + continue + task_dict = copy.deepcopy(value) + if 'name' not in task_dict: + task_dict['name'] = ref_path.rsplit('.', 1)[-1] + expanded.append(task_dict) + return expanded + + +def _expand_refs_in_workflow(spec_dict: Dict[str, Any], + default_values: Dict[str, Any]) -> None: + """Expand ``{{ ref }}`` strings in workflow task and group-task lists in place.""" + workflow = spec_dict.get('workflow') + if not isinstance(workflow, dict): + return + + if 'tasks' in workflow and isinstance(workflow['tasks'], list): + workflow['tasks'] = _expand_task_refs(workflow['tasks'], default_values) + + if 'groups' in workflow and isinstance(workflow['groups'], list): + for group in workflow['groups']: + if isinstance(group, dict) and 'tasks' in group \ + and isinstance(group['tasks'], list): + group['tasks'] = _expand_task_refs(group['tasks'], default_values) + + +def resolve_includes(spec_text: str, base_directory: str, + source_path: str | None = None) -> str: + """Resolve ``includes`` directives in a workflow spec. + + Reads included files relative to *base_directory*, recursively resolves + nested includes, and deep-merges all specs. The main file's values take + precedence over included values. Diamond-shaped includes (A -> B -> D and + A -> C -> D) are allowed; true cycles are detected and rejected. + + Included files (and the main file when it uses ``includes``) must be + parseable by ``yaml.safe_load`` -- unquoted Jinja template syntax such as + ``{{ var }}`` is not supported. Quoted references like ``"{{ var }}"`` + are fine. 
+ + Task references (``"{{ key }}"`` in ``tasks`` lists) are resolved against + the merged ``default-values``. Setting a key to ``null`` in + ``default-values`` removes the corresponding task. + + Args: + spec_text: Raw YAML text of the workflow spec. + base_directory: Directory to resolve relative include paths against. + source_path: Absolute path of the file being processed, used for + cycle detection of the root file. + + Returns: + Merged YAML text with all includes resolved and the ``includes`` key + removed. If the spec has no ``includes`` key the original text is + returned unchanged. + """ + if 'includes:' not in spec_text: + return spec_text + + ancestors: FrozenSet[str] = frozenset() + if source_path is not None: + ancestors = frozenset({os.path.normpath(os.path.abspath(source_path))}) + + try: + spec_dict = yaml.safe_load(spec_text) + except yaml.YAMLError as yaml_err: + if re.search(r'^includes:', spec_text, re.MULTILINE): + raise osmo_errors.OSMOUserError( + 'Failed to parse workflow spec for includes resolution. ' + 'Specs using "includes" must be valid YAML — Jinja template ' + 'variables like {{ }} must be in quoted strings. 
' + f'Parse error: {yaml_err}') from yaml_err + return spec_text + + if not isinstance(spec_dict, dict) or 'includes' not in spec_dict: + return spec_text + + return _resolve_includes(spec_dict, base_directory, ancestors) + + +def _resolve_includes(spec_dict: Dict[str, Any], base_directory: str, + ancestors: FrozenSet[str]) -> str: + """Internal recursive include resolver operating on a parsed YAML dict.""" + includes = spec_dict.pop('includes', None) + if includes is None: + defaults = spec_dict.get('default-values', {}) + if isinstance(defaults, dict): + _expand_refs_in_workflow(spec_dict, defaults) + return yaml.safe_dump(spec_dict, default_flow_style=False, sort_keys=False) + + if not isinstance(includes, list): + raise osmo_errors.OSMOUserError( + 'The "includes" key must be a list of file paths.') + + included_dicts: List[Dict[str, Any]] = [] + + for include_path in includes: + if not isinstance(include_path, str): + raise osmo_errors.OSMOUserError( + f'Each include path must be a string, got: ' + f'{type(include_path).__name__}') + + resolved_path = os.path.normpath( + os.path.join(base_directory, include_path)) + + if resolved_path in ancestors: + raise osmo_errors.OSMOUserError( + f'Circular include detected: "{include_path}" ' + f'(resolved to {resolved_path})') + + if not os.path.isfile(resolved_path): + raise osmo_errors.OSMOUserError( + f'Included file not found: "{include_path}" ' + f'(resolved to {resolved_path})') + + with open(resolved_path, encoding='utf-8') as file_handle: + included_text = file_handle.read() + + try: + included_dict = yaml.safe_load(included_text) + except yaml.YAMLError as yaml_err: + raise osmo_errors.OSMOUserError( + f'Failed to parse included file "{include_path}": {yaml_err}') from yaml_err + + if not isinstance(included_dict, dict): + raise osmo_errors.OSMOUserError( + f'Included file "{include_path}" must be a YAML mapping ' + f'at the top level.') + + child_ancestors = ancestors | {resolved_path} + + if 'includes' in 
included_dict: + included_resolved_text = _resolve_includes( + included_dict, os.path.dirname(resolved_path), + child_ancestors) + included_dict = yaml.safe_load(included_resolved_text) + + included_dict.pop('includes', None) + included_dicts.append(included_dict) + + all_defaults: Dict[str, Any] = {} + for included in included_dicts: + all_defaults = deep_merge_dicts( + all_defaults, included.get('default-values', {})) + all_defaults = deep_merge_dicts( + all_defaults, spec_dict.get('default-values', {})) + + for included in included_dicts: + _expand_refs_in_workflow(included, all_defaults) + _expand_refs_in_workflow(spec_dict, all_defaults) + + base_dict: Dict[str, Any] = {} + for included in included_dicts: + base_dict = deep_merge_dicts(base_dict, included) + + merged = deep_merge_dicts(base_dict, spec_dict) + return yaml.safe_dump(merged, default_flow_style=False, sort_keys=False) + + +def _resolve_env_refs(text: str) -> str: + """Replace ``${env:VAR}`` patterns with their values from ``os.environ``.""" + def _replacer(match: re.Match) -> str: + return os.environ.get(match.group(1), '') + return _ENV_REF_PATTERN.sub(_replacer, text) + + +def _resolve_env_refs_recursive(obj: Any) -> Any: + """Walk *obj* and resolve ``${env:VAR}`` patterns in every string.""" + if isinstance(obj, str): + return _resolve_env_refs(obj) + if isinstance(obj, dict): + return {k: _resolve_env_refs_recursive(v) for k, v in obj.items()} + if isinstance(obj, list): + return [_resolve_env_refs_recursive(item) for item in obj] + return obj + + +def _collect_scalar_variables(default_values: Dict[str, Any]) -> Dict[str, str]: + """Return only the scalar (non-dict, non-list) entries from *default_values* as strings.""" + variables: Dict[str, str] = {} + for key, value in default_values.items(): + if isinstance(value, str): + variables[key] = value + elif isinstance(value, (int, float, bool)): + variables[key] = str(value) + return variables + + +def _resolve_nested_variables(variables: 
Dict[str, str], + max_iterations: int = 10) -> None: + """Iteratively resolve ``{ref}`` placeholders inside variable values themselves.""" + for _ in range(max_iterations): + changed = False + for key, value in list(variables.items()): + if not isinstance(value, str): + continue + def _replacer(match: re.Match) -> str: + ref = match.group(1) + return variables.get(ref, match.group(0)) + new_value = _SCALAR_REF_PATTERN.sub(_replacer, value) + if new_value != value: + variables[key] = new_value + changed = True + if not changed: + break + + +def _extract_and_remove_default_values( + spec_text: str) -> tuple[Dict[str, Any] | None, str]: + """Extract the ``default-values`` block from raw YAML text. + + Returns ``(default_values_dict, remaining_text)`` where the block has been + removed from *remaining_text*. Returns ``(None, spec_text)`` when no + ``default-values`` section is found or when it cannot be parsed. + """ + match = _DEFAULT_VALUES_BLOCK.search(spec_text) + if match is None: + return None, spec_text + + dv_section = 'default-values:\n' + match.group(1) + try: + parsed = yaml.safe_load(dv_section) + except yaml.YAMLError: + return None, spec_text + + default_values = parsed.get('default-values') if isinstance(parsed, dict) else None + if not isinstance(default_values, dict): + return None, spec_text + + remaining = spec_text[:match.start()] + spec_text[match.end():] + return default_values, remaining + + +def find_unresolved_env_variables(spec_text: str) -> Dict[str, str]: + """Find ``default-values`` variables whose ``${env:VAR}`` source is not set. + + Call this on spec text **after** ``resolve_includes`` (so that + ``default-values`` from all included files are merged) and **before** + ``resolve_default_values`` (which replaces ``${env:…}`` patterns). + + Returns: + A dict mapping *variable_name* → *env_var_name* for every scalar + entry whose value contains an ``${env:VAR}`` reference where ``VAR`` + is not present in ``os.environ``. 
+ """ + default_values, _ = _extract_and_remove_default_values(spec_text) + if default_values is None: + return {} + + unresolved: Dict[str, str] = {} + for key, value in default_values.items(): + if not isinstance(value, str): + continue + for match in _ENV_REF_PATTERN.finditer(value): + env_var = match.group(1) + if env_var not in os.environ: + unresolved[key] = env_var + return unresolved + + +def resolve_default_values(spec_text: str) -> str: + """Resolve ``default-values`` variables and ``${env:VAR}`` references. + + All processing happens at the text level so that Jinja-style ``{{var}}`` + patterns (which are invalid YAML when unquoted) can be substituted before + the spec is parsed. + + Processing steps: + + 1. Resolve ``${env:VAR}`` patterns everywhere against ``os.environ``. + 2. Extract and remove the ``default-values`` block from the raw text. + 3. Collect scalar entries into a variable map. + 4. Iteratively resolve ``{variable}`` references within the variable map + itself (handles chained references like ``local_dir: "{repo_dir}/local"``). + 5. Substitute ``{{variable}}`` (Jinja-style double-brace) and ``{variable}`` + (single-brace) patterns in the spec text for every known variable key. + OSMO runtime tokens (``{{output}}``, ``{{input:0}}``, ``{{host:…}}``, + ``{{item}}``, etc.) are left intact because their names are not keys in + ``default-values``. + 6. Return the cleaned text without the ``default-values`` section. + + If the spec has no ``default-values``, the text is returned with only + ``${env:VAR}`` references resolved. 
+ """ + spec_text = _resolve_env_refs(spec_text) + + default_values, spec_text = _extract_and_remove_default_values(spec_text) + if default_values is None: + return spec_text + + resolved_defaults: Dict[str, Any] = _resolve_env_refs_recursive(default_values) + variables = _collect_scalar_variables(resolved_defaults) + _resolve_nested_variables(variables) + + def _jinja_replacer(match: re.Match) -> str: + return variables.get(match.group(1), match.group(0)) + + def _scalar_replacer(match: re.Match) -> str: + return variables.get(match.group(1), match.group(0)) + + spec_text = _JINJA_VAR_PATTERN.sub(_jinja_replacer, spec_text) + spec_text = _SCALAR_REF_PATTERN.sub(_scalar_replacer, spec_text) + + return spec_text diff --git a/src/utils/tests/BUILD b/src/utils/tests/BUILD index a7e63113d..1e89b9b57 100644 --- a/src/utils/tests/BUILD +++ b/src/utils/tests/BUILD @@ -30,3 +30,25 @@ osmo_py_test( requirement("jwcrypto"), ] ) + +osmo_py_test( + name = "test_local_executor", + srcs = ["test_local_executor.py"], + deps = [ + "//src/cli:cli_lib", + "//src/utils:local_executor", + ], + data = [ + "//cookbook/tutorials:tutorial_specs", + ], + local = True, +) + +osmo_py_test( + name = "test_spec_includes", + srcs = ["test_spec_includes.py"], + deps = [ + "//src/lib/utils:osmo_errors", + "//src/utils:spec_includes", + ], +) diff --git a/src/utils/tests/test_local_executor.py b/src/utils/tests/test_local_executor.py new file mode 100644 index 000000000..e762a856a --- /dev/null +++ b/src/utils/tests/test_local_executor.py @@ -0,0 +1,2436 @@ +""" +SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # pylint: disable=line-too-long + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. + +SPDX-License-Identifier: Apache-2.0 +""" + +import os +import shutil +import subprocess +import tempfile +import textwrap +import unittest +from typing import Any, ClassVar, Dict + +import yaml + +from src.utils import spec_includes +from src.utils.job import task as task_module +from src.utils.local_executor import ( + OSMO_INPUT_PATH_PREFIX, + OSMO_OUTPUT_PATH, + LocalExecutor, + TaskNode, + TaskResult, + check_unresolved_variables, + run_workflow_locally, +) + + +# --------------------------------------------------------------------------- +# Helper: detect Docker + Compose availability once for the entire module +# --------------------------------------------------------------------------- +def _docker_available() -> bool: + """Return True if the Docker daemon and Compose V2 are reachable.""" + try: + docker_result = subprocess.run( + ['docker', 'info'], + capture_output=True, + timeout=10, + ) + if docker_result.returncode != 0: + return False + compose_result = subprocess.run( + ['docker', 'compose', 'version'], + capture_output=True, + timeout=10, + ) + return compose_result.returncode == 0 + except (FileNotFoundError, subprocess.TimeoutExpired): + return False + + +DOCKER_AVAILABLE = _docker_available() +SKIP_DOCKER_MSG = 'Docker with Compose V2 is not available on this machine' + + +# ============================================================================ +# Unit tests — no Docker required; exercise parsing, DAG, tokens, validation +# ============================================================================ +class TestLoadSpec(unittest.TestCase): + 
"""Verify that real OSMO YAML specs are parsed correctly via the existing Pydantic models.""" + + def test_single_task_spec(self): + """Parse a minimal single-task workflow and verify name, task count, and image.""" + spec_text = textwrap.dedent('''\ + workflow: + name: hello-osmo + tasks: + - name: hello + image: ubuntu:24.04 + command: ["echo"] + args: ["Hello from OSMO!"] + ''') + executor = LocalExecutor(work_dir='/tmp/unused') + spec = executor.load_spec(spec_text) + self.assertEqual(spec.name, 'hello-osmo') + self.assertEqual(len(spec.tasks), 1) + self.assertEqual(spec.tasks[0].name, 'hello') + self.assertEqual(spec.tasks[0].image, 'ubuntu:24.04') + + def test_serial_tasks_spec(self): + """Parse a two-task serial workflow and verify the task input dependency is resolved.""" + spec_text = textwrap.dedent('''\ + workflow: + name: serial-tasks + tasks: + - name: task1 + image: ubuntu:22.04 + command: [sh] + args: [/tmp/run.sh] + files: + - contents: | + echo "Hello from task1" + echo "data" > {{output}}/test.txt + path: /tmp/run.sh + - name: task2 + image: ubuntu:22.04 + command: [sh] + args: [/tmp/run.sh] + files: + - contents: | + cat {{input:0}}/test.txt + path: /tmp/run.sh + inputs: + - task: task1 + ''') + executor = LocalExecutor(work_dir='/tmp/unused') + spec = executor.load_spec(spec_text) + self.assertEqual(spec.name, 'serial-tasks') + self.assertEqual(len(spec.tasks), 2) + first_input = spec.tasks[1].inputs[0] + self.assertIsInstance(first_input, task_module.TaskInputOutput) + if isinstance(first_input, task_module.TaskInputOutput): + self.assertEqual(first_input.task, 'task1') + + def test_groups_spec(self): + """Parse a grouped workflow and verify group structure and the lead task flag.""" + spec_text = textwrap.dedent('''\ + workflow: + name: grouped + groups: + - name: first-group + tasks: + - name: leader + lead: true + image: ubuntu:24.04 + command: ["echo", "leader"] + - name: follower + image: ubuntu:24.04 + command: ["echo", "follower"] + ''') 
+ executor = LocalExecutor(work_dir='/tmp/unused') + spec = executor.load_spec(spec_text) + self.assertEqual(len(spec.groups), 1) + self.assertEqual(len(spec.groups[0].tasks), 2) + self.assertTrue(spec.groups[0].tasks[0].lead) + + def test_versioned_spec(self): + """Parse a spec with an explicit version field and verify it loads correctly.""" + spec_text = textwrap.dedent('''\ + version: 2 + workflow: + name: versioned + tasks: + - name: task + image: alpine:3.18 + command: ["echo", "ok"] + ''') + executor = LocalExecutor(work_dir='/tmp/unused') + spec = executor.load_spec(spec_text) + self.assertEqual(spec.name, 'versioned') + + def test_invalid_version_rejected(self): + """Reject a spec with an unsupported version number.""" + spec_text = textwrap.dedent('''\ + version: 99 + workflow: + name: bad-version + tasks: + - name: task + image: alpine:3.18 + command: ["echo", "ok"] + ''') + executor = LocalExecutor(work_dir='/tmp/unused') + with self.assertRaises(ValueError): + executor.load_spec(spec_text) + + def test_both_tasks_and_groups_rejected(self): + """Reject a spec that defines both top-level tasks and groups simultaneously.""" + spec_text = textwrap.dedent('''\ + workflow: + name: invalid + tasks: + - name: t + image: alpine:3.18 + command: ["echo"] + groups: + - name: g + tasks: + - name: t2 + image: alpine:3.18 + command: ["echo"] + ''') + executor = LocalExecutor(work_dir='/tmp/unused') + with self.assertRaises(ValueError): + executor.load_spec(spec_text) + + def test_empty_workflow_rejected(self): + """Reject a spec with no tasks or groups defined.""" + spec_text = textwrap.dedent('''\ + workflow: + name: empty + ''') + executor = LocalExecutor(work_dir='/tmp/unused') + with self.assertRaises(ValueError): + executor.load_spec(spec_text) + + def test_resources_spec_parsed(self): + """Parse a spec with resource definitions and verify cpu/memory values.""" + spec_text = textwrap.dedent('''\ + workflow: + name: with-resources + resources: + default: + cpu: 2 
+ memory: 4Gi + storage: 10Gi + tasks: + - name: task + image: ubuntu:24.04 + command: ["echo", "ok"] + ''') + executor = LocalExecutor(work_dir='/tmp/unused') + spec = executor.load_spec(spec_text) + self.assertEqual(spec.resources['default'].cpu, 2) + self.assertEqual(spec.resources['default'].memory, '4Gi') + + def test_environment_parsed(self): + """Parse a spec with environment variables and verify key-value pairs are preserved.""" + spec_text = textwrap.dedent('''\ + workflow: + name: env-test + tasks: + - name: task + image: alpine:3.18 + command: ["printenv"] + environment: + MY_VAR: hello + ANOTHER: world + ''') + executor = LocalExecutor(work_dir='/tmp/unused') + spec = executor.load_spec(spec_text) + self.assertEqual(spec.tasks[0].environment['MY_VAR'], 'hello') + self.assertEqual(spec.tasks[0].environment['ANOTHER'], 'world') + + +class TestBuildDag(unittest.TestCase): + """Verify DAG construction from task dependencies.""" + + def _make_executor(self) -> LocalExecutor: + """Create a LocalExecutor with a throwaway work directory for DAG-only tests.""" + return LocalExecutor(work_dir='/tmp/unused') + + def test_no_dependencies(self): + """Tasks in separate groups with no input dependencies have empty upstream/downstream.""" + spec_text = textwrap.dedent('''\ + workflow: + name: parallel + groups: + - name: group-a + tasks: + - name: a + lead: true + image: alpine:3.18 + command: ["echo", "a"] + - name: group-b + tasks: + - name: b + lead: true + image: alpine:3.18 + command: ["echo", "b"] + - name: group-c + tasks: + - name: c + lead: true + image: alpine:3.18 + command: ["echo", "c"] + ''') + executor = self._make_executor() + spec = executor.load_spec(spec_text) + executor._build_dag(spec) + + self.assertEqual(len(executor._task_nodes), 3) + for node in executor._task_nodes.values(): + self.assertEqual(len(node.upstream), 0) + self.assertEqual(len(node.downstream), 0) + + def test_flat_task_list_gets_implicit_sequential_deps(self): + """A flat task 
list (no groups) gets implicit sequential dependencies.""" + spec_text = textwrap.dedent('''\ + workflow: + name: sequential + tasks: + - name: a + image: alpine:3.18 + command: ["echo", "a"] + - name: b + image: alpine:3.18 + command: ["echo", "b"] + - name: c + image: alpine:3.18 + command: ["echo", "c"] + ''') + executor = self._make_executor() + spec = executor.load_spec(spec_text) + executor._build_dag(spec) + + self.assertEqual(len(executor._task_nodes), 3) + self.assertEqual(executor._task_nodes['a'].upstream, set()) + self.assertEqual(executor._task_nodes['a'].downstream, {'b'}) + self.assertEqual(executor._task_nodes['b'].upstream, {'a'}) + self.assertEqual(executor._task_nodes['b'].downstream, {'c'}) + self.assertEqual(executor._task_nodes['c'].upstream, {'b'}) + self.assertEqual(executor._task_nodes['c'].downstream, set()) + + def test_serial_chain(self): + """A three-task chain produces correct upstream/downstream links at each step.""" + spec_text = textwrap.dedent('''\ + workflow: + name: serial + tasks: + - name: first + image: alpine:3.18 + command: ["echo"] + - name: second + image: alpine:3.18 + command: ["echo"] + inputs: + - task: first + - name: third + image: alpine:3.18 + command: ["echo"] + inputs: + - task: second + ''') + executor = self._make_executor() + spec = executor.load_spec(spec_text) + executor._build_dag(spec) + + self.assertEqual(executor._task_nodes['first'].upstream, set()) + self.assertEqual(executor._task_nodes['first'].downstream, {'second'}) + self.assertEqual(executor._task_nodes['second'].upstream, {'first'}) + self.assertEqual(executor._task_nodes['second'].downstream, {'third'}) + self.assertEqual(executor._task_nodes['third'].upstream, {'second'}) + self.assertEqual(executor._task_nodes['third'].downstream, set()) + + def test_diamond_dependency(self): + """A diamond DAG (root -> left/right -> join) wires fan-out and fan-in edges correctly.""" + spec_text = textwrap.dedent('''\ + workflow: + name: diamond + tasks: + - 
name: root + image: alpine:3.18 + command: ["echo"] + - name: left + image: alpine:3.18 + command: ["echo"] + inputs: + - task: root + - name: right + image: alpine:3.18 + command: ["echo"] + inputs: + - task: root + - name: join + image: alpine:3.18 + command: ["echo"] + inputs: + - task: left + - task: right + ''') + executor = self._make_executor() + spec = executor.load_spec(spec_text) + executor._build_dag(spec) + + self.assertEqual(executor._task_nodes['root'].downstream, {'left', 'right'}) + self.assertEqual(executor._task_nodes['join'].upstream, {'left', 'right'}) + + def test_unknown_dependency_raises(self): + """Referencing a non-existent upstream task raises ValueError.""" + spec_text = textwrap.dedent('''\ + workflow: + name: broken + tasks: + - name: task1 + image: alpine:3.18 + command: ["echo"] + inputs: + - task: nonexistent + ''') + executor = self._make_executor() + spec = executor.load_spec(spec_text) + with self.assertRaises(ValueError) as context: + executor._build_dag(spec) + self.assertIn('nonexistent', str(context.exception)) + + def test_groups_with_cross_group_deps(self): + """Dependencies between tasks in different groups are wired correctly.""" + spec_text = textwrap.dedent('''\ + workflow: + name: cross-group + groups: + - name: fetch + tasks: + - name: download + lead: true + image: alpine:3.18 + command: ["echo"] + - name: process + tasks: + - name: transform + lead: true + image: alpine:3.18 + command: ["echo"] + inputs: + - task: download + ''') + executor = self._make_executor() + spec = executor.load_spec(spec_text) + executor._build_dag(spec) + + self.assertEqual(executor._task_nodes['download'].downstream, {'transform'}) + self.assertEqual(executor._task_nodes['transform'].upstream, {'download'}) + + +class TestCycleDetection(unittest.TestCase): + """Verify that circular dependencies are detected upfront.""" + + def _make_executor(self) -> LocalExecutor: + return LocalExecutor(work_dir='/tmp/unused') + + @staticmethod + def 
_stub_spec() -> task_module.TaskSpec: + """Return a minimal TaskSpec usable as a placeholder in TaskNode.""" + return task_module.TaskSpec(name='stub', image='alpine:3.18', command=['true']) + + def test_no_cycle_passes(self): + """A valid DAG passes cycle detection without error.""" + spec_text = textwrap.dedent('''\ + workflow: + name: ok + tasks: + - name: a + image: alpine:3.18 + command: ["echo"] + - name: b + image: alpine:3.18 + command: ["echo"] + inputs: + - task: a + ''') + executor = self._make_executor() + spec = executor.load_spec(spec_text) + executor._build_dag(spec) + executor._detect_cycles() + + def test_self_cycle_detected(self): + """A task depending on itself is detected as a cycle. + + Note: Pydantic's TaskSpec validation actually prevents self-references + in practice, but this tests the DAG logic directly. + """ + executor = self._make_executor() + executor._task_nodes['a'] = TaskNode( + name='a', spec=self._stub_spec(), group='g', + upstream={'a'}, downstream={'a'}) + with self.assertRaises(ValueError) as context: + executor._detect_cycles() + self.assertIn('Circular dependency', str(context.exception)) + + def test_two_node_cycle_detected(self): + """A mutual dependency between two tasks is detected.""" + executor = self._make_executor() + executor._task_nodes['a'] = TaskNode( + name='a', spec=self._stub_spec(), group='g', + upstream={'b'}, downstream={'b'}) + executor._task_nodes['b'] = TaskNode( + name='b', spec=self._stub_spec(), group='g', + upstream={'a'}, downstream={'a'}) + with self.assertRaises(ValueError) as context: + executor._detect_cycles() + self.assertIn('a', str(context.exception)) + self.assertIn('b', str(context.exception)) + + def test_indirect_cycle_detected(self): + """A three-node cycle (a -> b -> c -> a) is detected.""" + executor = self._make_executor() + executor._task_nodes['a'] = TaskNode( + name='a', spec=self._stub_spec(), group='g', + upstream={'c'}, downstream={'b'}) + executor._task_nodes['b'] = TaskNode( + 
name='b', spec=self._stub_spec(), group='g', + upstream={'a'}, downstream={'c'}) + executor._task_nodes['c'] = TaskNode( + name='c', spec=self._stub_spec(), group='g', + upstream={'b'}, downstream={'a'}) + with self.assertRaises(ValueError) as context: + executor._detect_cycles() + self.assertIn('Circular dependency', str(context.exception)) + + def test_parallel_tasks_no_cycle(self): + """Independent parallel tasks pass cycle detection.""" + spec_text = textwrap.dedent('''\ + workflow: + name: parallel + tasks: + - name: a + image: alpine:3.18 + command: ["echo"] + - name: b + image: alpine:3.18 + command: ["echo"] + ''') + executor = self._make_executor() + spec = executor.load_spec(spec_text) + executor._build_dag(spec) + executor._detect_cycles() + + +class TestSubstituteTokens(unittest.TestCase): + """Verify {{token}} placeholder replacement in command strings and file contents.""" + + def test_output_token(self): + """The {{output}} token is replaced with the on-cluster output path.""" + executor = LocalExecutor(work_dir='/tmp/unused') + tokens = {'output': OSMO_OUTPUT_PATH} + result = executor._substitute_tokens( + 'echo data > {{output}}/file.txt', tokens) + self.assertEqual(result, f'echo data > {OSMO_OUTPUT_PATH}/file.txt') + + def test_input_by_index(self): + """The {{input:N}} token is replaced with the Nth upstream input path.""" + executor = LocalExecutor(work_dir='/tmp/unused') + tokens = {'input:0': f'{OSMO_INPUT_PATH_PREFIX}/0'} + result = executor._substitute_tokens('cat {{input:0}}/data.csv', tokens) + self.assertEqual(result, f'cat {OSMO_INPUT_PATH_PREFIX}/0/data.csv') + + def test_input_by_name(self): + """The {{input:taskname}} token is replaced with the named task's input path.""" + executor = LocalExecutor(work_dir='/tmp/unused') + tokens = {'input:task1': f'{OSMO_INPUT_PATH_PREFIX}/task1'} + result = executor._substitute_tokens('cat {{ input:task1 }}/data.csv', tokens) + self.assertEqual(result, f'cat 
{OSMO_INPUT_PATH_PREFIX}/task1/data.csv') + + def test_whitespace_around_tokens(self): + """Whitespace inside {{ token }} braces is tolerated during substitution.""" + executor = LocalExecutor(work_dir='/tmp/unused') + tokens = {'output': '/out'} + result = executor._substitute_tokens('{{ output }}/file.txt', tokens) + self.assertEqual(result, '/out/file.txt') + + def test_multiple_tokens_in_one_string(self): + """Multiple distinct tokens in the same string are all replaced.""" + executor = LocalExecutor(work_dir='/tmp/unused') + tokens = {'output': OSMO_OUTPUT_PATH, 'input:0': f'{OSMO_INPUT_PATH_PREFIX}/0'} + result = executor._substitute_tokens( + 'cp {{input:0}}/src {{output}}/dst', tokens) + self.assertEqual( + result, + f'cp {OSMO_INPUT_PATH_PREFIX}/0/src {OSMO_OUTPUT_PATH}/dst') + + def test_no_tokens_unchanged(self): + """Text without any token placeholders passes through unchanged.""" + executor = LocalExecutor(work_dir='/tmp/unused') + result = executor._substitute_tokens('plain text no tokens', {}) + self.assertEqual(result, 'plain text no tokens') + + def test_host_token(self): + """The {{host:taskname}} token is replaced with the compose service name.""" + executor = LocalExecutor(work_dir='/tmp/unused') + tokens = {'host:worker': 'worker'} + result = executor._substitute_tokens( + 'http://{{ host:worker }}:8080/health', tokens) + self.assertEqual(result, 'http://worker:8080/health') + + +class TestBuildContainerTokenMap(unittest.TestCase): + """Verify that token maps use on-cluster container paths.""" + + def test_output_only(self): + """A task with no inputs produces a token map with the on-cluster output path.""" + spec_text = textwrap.dedent('''\ + workflow: + name: simple + tasks: + - name: task1 + image: alpine:3.18 + command: ["echo"] + ''') + executor = LocalExecutor(work_dir='/tmp/work') + spec = executor.load_spec(spec_text) + executor._build_dag(spec) + + node = executor._task_nodes['task1'] + tokens = 
executor._build_container_token_map(node) + self.assertEqual(tokens['output'], OSMO_OUTPUT_PATH) + self.assertIn('host:task1', tokens) + self.assertEqual(tokens['host:task1'], 'task1') + + def test_with_upstream_inputs(self): + """A task with upstream inputs gets on-cluster input paths by index and name.""" + spec_text = textwrap.dedent('''\ + workflow: + name: serial + tasks: + - name: producer + image: alpine:3.18 + command: ["echo"] + - name: consumer + image: alpine:3.18 + command: ["echo"] + inputs: + - task: producer + ''') + executor = LocalExecutor(work_dir='/tmp/work') + spec = executor.load_spec(spec_text) + executor._build_dag(spec) + + node = executor._task_nodes['consumer'] + tokens = executor._build_container_token_map(node) + + self.assertEqual(tokens['output'], OSMO_OUTPUT_PATH) + self.assertEqual(tokens['input:0'], f'{OSMO_INPUT_PATH_PREFIX}/0') + self.assertEqual(tokens['input:producer'], f'{OSMO_INPUT_PATH_PREFIX}/producer') + self.assertIn('host:producer', tokens) + self.assertIn('host:consumer', tokens) + + +class TestComposeConfigGeneration(unittest.TestCase): + """Verify that the generated Docker Compose config is correct.""" + + def setUp(self): + self.work_dir = tempfile.mkdtemp(prefix='osmo-local-compose-') + + def tearDown(self): + shutil.rmtree(self.work_dir, ignore_errors=True) + + def test_single_task_compose_config(self): + """A single-task workflow generates a valid compose config with correct volumes.""" + spec_text = textwrap.dedent('''\ + workflow: + name: simple + tasks: + - name: hello + image: alpine:3.18 + command: ["echo", "hi"] + ''') + executor = LocalExecutor(work_dir=self.work_dir) + spec = executor.load_spec(spec_text) + executor._build_dag(spec) + executor._setup_directories() + config = executor._generate_compose_config(spec, {'hello'}) + + self.assertIn('services', config) + self.assertIn('hello', config['services']) + service = config['services']['hello'] + self.assertEqual(service['image'], 'alpine:3.18') + 
self.assertNotIn('entrypoint', service) + self.assertEqual(service['command'], ['echo', 'hi']) + + output_mount = f'{self.work_dir}/hello/output:{OSMO_OUTPUT_PATH}' + self.assertIn(output_mount, service['volumes']) + + def test_serial_tasks_have_depends_on(self): + """A serial workflow generates depends_on with service_completed_successfully.""" + spec_text = textwrap.dedent('''\ + workflow: + name: serial + tasks: + - name: task1 + image: alpine:3.18 + command: ["echo"] + - name: task2 + image: alpine:3.18 + command: ["echo"] + inputs: + - task: task1 + ''') + executor = LocalExecutor(work_dir=self.work_dir) + spec = executor.load_spec(spec_text) + executor._build_dag(spec) + executor._setup_directories() + config = executor._generate_compose_config(spec, {'task1', 'task2'}) + + task2_service = config['services']['task2'] + self.assertIn('depends_on', task2_service) + self.assertEqual( + task2_service['depends_on']['task1'], + {'condition': 'service_completed_successfully'}) + + task1_service = config['services']['task1'] + self.assertNotIn('depends_on', task1_service) + + def test_parallel_tasks_no_depends_on(self): + """Independent parallel tasks (in separate groups) have no depends_on entries.""" + spec_text = textwrap.dedent('''\ + workflow: + name: parallel + groups: + - name: group-a + tasks: + - name: a + lead: true + image: alpine:3.18 + command: ["echo"] + - name: group-b + tasks: + - name: b + lead: true + image: alpine:3.18 + command: ["echo"] + ''') + executor = LocalExecutor(work_dir=self.work_dir) + spec = executor.load_spec(spec_text) + executor._build_dag(spec) + executor._setup_directories() + config = executor._generate_compose_config(spec, {'a', 'b'}) + + self.assertNotIn('depends_on', config['services']['a']) + self.assertNotIn('depends_on', config['services']['b']) + + def test_input_volumes_mount_at_cluster_paths(self): + """Input volumes mount at /osmo/data/input/N and /osmo/data/input/taskname.""" + spec_text = textwrap.dedent('''\ + 
workflow: + name: serial + tasks: + - name: producer + image: alpine:3.18 + command: ["echo"] + - name: consumer + image: alpine:3.18 + command: ["echo"] + inputs: + - task: producer + ''') + executor = LocalExecutor(work_dir=self.work_dir) + spec = executor.load_spec(spec_text) + executor._build_dag(spec) + executor._setup_directories() + config = executor._generate_compose_config(spec, {'producer', 'consumer'}) + + consumer_volumes = config['services']['consumer']['volumes'] + upstream_output = f'{self.work_dir}/producer/output' + + self.assertIn(f'{upstream_output}:{OSMO_INPUT_PATH_PREFIX}/0:ro', + consumer_volumes) + self.assertIn(f'{upstream_output}:{OSMO_INPUT_PATH_PREFIX}/producer:ro', + consumer_volumes) + + def test_environment_in_compose_config(self): + """Environment variables appear in the compose service config.""" + spec_text = textwrap.dedent('''\ + workflow: + name: env-test + tasks: + - name: task + image: alpine:3.18 + command: ["printenv"] + environment: + MY_VAR: hello + SECOND: "42" + ''') + executor = LocalExecutor(work_dir=self.work_dir) + spec = executor.load_spec(spec_text) + executor._build_dag(spec) + executor._setup_directories() + config = executor._generate_compose_config(spec, {'task'}) + + env = config['services']['task']['environment'] + self.assertEqual(env['MY_VAR'], 'hello') + self.assertEqual(env['SECOND'], '42') + + def test_files_written_and_mounted(self): + """Inline files are written to host and mounted into the compose service.""" + spec_text = textwrap.dedent('''\ + workflow: + name: files-test + tasks: + - name: task + image: alpine:3.18 + command: ["sh", "/tmp/run.sh"] + files: + - contents: | + echo "hello" + path: /tmp/run.sh + ''') + executor = LocalExecutor(work_dir=self.work_dir) + spec = executor.load_spec(spec_text) + executor._build_dag(spec) + executor._setup_directories() + config = executor._generate_compose_config(spec, {'task'}) + + task_files_dir = os.path.join(self.work_dir, 'task', 'files') + host_file = 
os.path.join(task_files_dir, 'tmp', 'run.sh') + self.assertTrue(os.path.exists(host_file)) + with open(host_file) as f: + self.assertIn('echo "hello"', f.read()) + + volumes = config['services']['task']['volumes'] + expected_mount = f'{os.path.realpath(host_file)}:/tmp/run.sh:ro' + self.assertIn(expected_mount, volumes) + + def test_resume_skips_completed_tasks(self): + """When resuming, completed tasks are excluded from the compose config.""" + spec_text = textwrap.dedent('''\ + workflow: + name: resume + tasks: + - name: task1 + image: alpine:3.18 + command: ["echo"] + - name: task2 + image: alpine:3.18 + command: ["echo"] + inputs: + - task: task1 + ''') + executor = LocalExecutor(work_dir=self.work_dir) + spec = executor.load_spec(spec_text) + executor._build_dag(spec) + executor._setup_directories() + + executor._results['task1'] = TaskResult( + name='task1', exit_code=0, + output_dir=os.path.join(self.work_dir, 'task1', 'output')) + + tasks_to_run = set(executor._task_nodes.keys()) - set(executor._results.keys()) + config = executor._generate_compose_config(spec, tasks_to_run) + + self.assertNotIn('task1', config['services']) + self.assertIn('task2', config['services']) + self.assertNotIn('depends_on', config['services']['task2']) + + +class TestValidateForLocal(unittest.TestCase): + """Verify that unsupported features are detected and rejected.""" + + def _make_executor(self) -> LocalExecutor: + """Create a LocalExecutor with a throwaway work directory for validation-only tests.""" + return LocalExecutor(work_dir='/tmp/unused') + + def test_simple_spec_passes(self): + """A spec using only task-to-task inputs passes local validation.""" + spec_text = textwrap.dedent('''\ + workflow: + name: ok + tasks: + - name: task + image: alpine:3.18 + command: ["echo", "ok"] + ''') + executor = self._make_executor() + spec = executor.load_spec(spec_text) + executor._build_dag(spec) + executor._validate_for_local(spec) + + def test_dataset_input_rejected(self): + """A 
spec with dataset inputs is rejected as unsupported in local mode.""" + spec_text = textwrap.dedent('''\ + workflow: + name: bad + tasks: + - name: task + image: ubuntu:24.04 + command: ["echo"] + inputs: + - dataset: + name: my_dataset + ''') + executor = self._make_executor() + spec = executor.load_spec(spec_text) + executor._build_dag(spec) + with self.assertRaises(ValueError) as context: + executor._validate_for_local(spec) + self.assertIn('dataset', str(context.exception)) + + def test_url_input_rejected(self): + """A spec with URL inputs is rejected as unsupported in local mode.""" + spec_text = textwrap.dedent('''\ + workflow: + name: bad + tasks: + - name: task + image: ubuntu:24.04 + command: ["echo"] + inputs: + - url: s3://my-bucket/data/ + ''') + executor = self._make_executor() + spec = executor.load_spec(spec_text) + executor._build_dag(spec) + with self.assertRaises(ValueError) as context: + executor._validate_for_local(spec) + self.assertIn('URL', str(context.exception)) + + def test_dataset_output_rejected(self): + """A spec with dataset outputs is rejected as unsupported in local mode.""" + spec_text = textwrap.dedent('''\ + workflow: + name: bad + tasks: + - name: task + image: ubuntu:24.04 + command: ["echo"] + outputs: + - dataset: + name: my_dataset + ''') + executor = self._make_executor() + spec = executor.load_spec(spec_text) + executor._build_dag(spec) + with self.assertRaises(ValueError) as context: + executor._validate_for_local(spec) + self.assertIn('dataset', str(context.exception).lower()) + + def test_url_output_rejected(self): + """A spec with URL outputs is rejected as unsupported in local mode.""" + spec_text = textwrap.dedent('''\ + workflow: + name: bad + tasks: + - name: task + image: ubuntu:24.04 + command: ["echo"] + outputs: + - url: s3://my-bucket/models/ + ''') + executor = self._make_executor() + spec = executor.load_spec(spec_text) + executor._build_dag(spec) + with self.assertRaises(ValueError) as context: + 
executor._validate_for_local(spec) + self.assertIn('object storage', str(context.exception).lower()) + + def test_multiple_unsupported_features_all_reported(self): + """All unsupported features across multiple tasks are reported in a single error.""" + spec_text = textwrap.dedent('''\ + workflow: + name: bad + tasks: + - name: task1 + image: ubuntu:24.04 + command: ["echo"] + inputs: + - url: s3://bucket/data/ + - name: task2 + image: ubuntu:24.04 + command: ["echo"] + inputs: + - dataset: + name: ds + ''') + executor = self._make_executor() + spec = executor.load_spec(spec_text) + executor._build_dag(spec) + with self.assertRaises(ValueError) as context: + executor._validate_for_local(spec) + error_message = str(context.exception) + self.assertIn('task1', error_message) + self.assertIn('task2', error_message) + + def test_task_deps_only_passes(self): + """A spec with only task-to-task dependencies passes local validation.""" + spec_text = textwrap.dedent('''\ + workflow: + name: ok + tasks: + - name: producer + image: alpine:3.18 + command: ["echo"] + - name: consumer + image: alpine:3.18 + command: ["echo"] + inputs: + - task: producer + ''') + executor = self._make_executor() + spec = executor.load_spec(spec_text) + executor._build_dag(spec) + executor._validate_for_local(spec) + + def test_files_and_env_pass(self): + """A spec using files and environment variables passes local validation.""" + spec_text = textwrap.dedent('''\ + workflow: + name: ok + tasks: + - name: task + image: alpine:3.18 + command: ["sh", "/tmp/run.sh"] + environment: + MY_VAR: hello + files: + - contents: echo hi + path: /tmp/run.sh + ''') + executor = self._make_executor() + spec = executor.load_spec(spec_text) + executor._build_dag(spec) + executor._validate_for_local(spec) + + +class TestValidateForLocalRemainingBranches(unittest.TestCase): + """Verify that _validate_for_local rejects credentials, checkpoint, volumeMounts, privileged, and hostNetwork.""" + + _UNSUPPORTED_SPECS: 
ClassVar[Dict[str, Any]] = { + 'credentials': { + 'yaml': textwrap.dedent('''\ + workflow: + name: bad + tasks: + - name: task + image: ubuntu:24.04 + command: ["echo"] + credentials: + my-secret: NGC_API_KEY + '''), + 'expected_substring': 'credentials', + }, + 'checkpoint': { + 'yaml': textwrap.dedent('''\ + workflow: + name: bad + tasks: + - name: task + image: ubuntu:24.04 + command: ["echo"] + checkpoint: + - path: /output/model + url: s3://bucket/checkpoints/ + frequency: 300 + '''), + 'expected_substring': 'checkpoint', + }, + 'volumeMounts': { + 'yaml': textwrap.dedent('''\ + workflow: + name: bad + tasks: + - name: task + image: ubuntu:24.04 + command: ["echo"] + volumeMounts: + - "/data:/data:ro" + '''), + 'expected_substring': 'volumeMounts', + }, + 'privileged': { + 'yaml': textwrap.dedent('''\ + workflow: + name: bad + tasks: + - name: task + image: ubuntu:24.04 + command: ["echo"] + privileged: true + '''), + 'expected_substring': 'privileged', + }, + 'hostNetwork': { + 'yaml': textwrap.dedent('''\ + workflow: + name: bad + tasks: + - name: task + image: ubuntu:24.04 + command: ["echo"] + hostNetwork: true + '''), + 'expected_substring': 'hostNetwork', + }, + } + + def test_unsupported_fields_rejected(self): + """Each unsupported task-level field is detected and rejected with a descriptive error.""" + for feature, case in self._UNSUPPORTED_SPECS.items(): + with self.subTest(feature=feature): + executor = LocalExecutor(work_dir='/tmp/unused') + spec = executor.load_spec(case['yaml']) + executor._build_dag(spec) + with self.assertRaises(ValueError) as context: + executor._validate_for_local(spec) + self.assertIn(case['expected_substring'], str(context.exception)) + + +class TestFilePathTraversal(unittest.TestCase): + """Verify that file paths cannot escape the task directory.""" + + def setUp(self): + """Create a temporary work directory.""" + self.work_dir = tempfile.mkdtemp(prefix='osmo-local-traversal-') + + def tearDown(self): + """Remove the 
temporary work directory.""" + shutil.rmtree(self.work_dir, ignore_errors=True) + + def test_path_traversal_rejected(self): + """A file spec with a path that escapes the task directory raises ValueError.""" + spec_text = textwrap.dedent('''\ + workflow: + name: traversal + tasks: + - name: task + image: alpine:3.18 + command: ["echo"] + files: + - contents: "malicious" + path: /../../etc/evil.conf + ''') + executor = LocalExecutor(work_dir=self.work_dir, keep_work_dir=True) + spec = executor.load_spec(spec_text) + executor._build_dag(spec) + executor._setup_directories() + with self.assertRaises(ValueError) as context: + executor._generate_compose_config(spec, {'task'}) + self.assertIn('escapes the task directory', str(context.exception)) + + def test_safe_nested_path_accepted(self): + """A file spec with a safe nested path is accepted without error.""" + spec_text = textwrap.dedent('''\ + workflow: + name: safe + tasks: + - name: task + image: alpine:3.18 + command: ["echo"] + files: + - contents: "safe" + path: /tmp/scripts/run.sh + ''') + executor = LocalExecutor(work_dir=self.work_dir, keep_work_dir=True) + spec = executor.load_spec(spec_text) + executor._build_dag(spec) + executor._setup_directories() + config = executor._generate_compose_config(spec, {'task'}) + self.assertIn('task', config['services']) + + +class TestShmSize(unittest.TestCase): + """Verify that shm_size appears in the compose config for GPU tasks.""" + + def setUp(self): + """Create a temporary work directory for shm-size tests.""" + self.work_dir = tempfile.mkdtemp(prefix='osmo-local-shm-') + + def tearDown(self): + """Remove the temporary work directory after each test.""" + shutil.rmtree(self.work_dir, ignore_errors=True) + + def test_gpu_task_gets_default_shm_size(self): + """A GPU task includes shm_size with the default value when none is specified.""" + spec_text = textwrap.dedent('''\ + workflow: + name: shm-test + resources: + gpu-resource: + gpu: 1 + tasks: + - name: train + image: 
pytorch:latest + resource: gpu-resource + command: ["python", "train.py"] + ''') + executor = LocalExecutor(work_dir=self.work_dir, keep_work_dir=True) + spec = executor.load_spec(spec_text) + executor._build_dag(spec) + executor._setup_directories() + config = executor._generate_compose_config(spec, {'train'}) + + self.assertEqual(config['services']['train']['shm_size'], '16g') + + def test_gpu_task_gets_custom_shm_size(self): + """A GPU task uses the user-specified shm_size value.""" + spec_text = textwrap.dedent('''\ + workflow: + name: shm-test + resources: + gpu-resource: + gpu: 1 + tasks: + - name: train + image: pytorch:latest + resource: gpu-resource + command: ["python", "train.py"] + ''') + executor = LocalExecutor(work_dir=self.work_dir, keep_work_dir=True, shm_size='32g') + spec = executor.load_spec(spec_text) + executor._build_dag(spec) + executor._setup_directories() + config = executor._generate_compose_config(spec, {'train'}) + + self.assertEqual(config['services']['train']['shm_size'], '32g') + + def test_non_gpu_task_has_no_default_shm_size(self): + """A CPU-only task without explicit shm_size does not include shm_size.""" + spec_text = textwrap.dedent('''\ + workflow: + name: no-gpu + tasks: + - name: preprocess + image: alpine:3.18 + command: ["echo", "ok"] + ''') + executor = LocalExecutor(work_dir=self.work_dir, keep_work_dir=True) + spec = executor.load_spec(spec_text) + executor._build_dag(spec) + executor._setup_directories() + config = executor._generate_compose_config(spec, {'preprocess'}) + + self.assertNotIn('shm_size', config['services']['preprocess']) + + def test_non_gpu_task_gets_explicit_shm_size(self): + """A CPU-only task gets shm_size when the user explicitly specifies it.""" + spec_text = textwrap.dedent('''\ + workflow: + name: no-gpu + tasks: + - name: preprocess + image: alpine:3.18 + command: ["echo", "ok"] + ''') + executor = LocalExecutor(work_dir=self.work_dir, keep_work_dir=True, shm_size='8g') + spec = 
executor.load_spec(spec_text)
        executor._build_dag(spec)
        executor._setup_directories()
        config = executor._generate_compose_config(spec, {'preprocess'})

        self.assertEqual(config['services']['preprocess']['shm_size'], '8g')


class TestJinjaTemplateDetection(unittest.TestCase):
    """Verify that specs containing Jinja template markers are rejected before execution."""

    def _write_temp_spec(self, content: str) -> str:
        """Write YAML content to a temporary file and return its path."""
        # delete=False: the file must survive close() so it can be read back by
        # path; each test removes it in its own finally block.
        f = tempfile.NamedTemporaryFile(mode='w', suffix='.yaml', delete=False)
        f.write(content)
        f.flush()
        f.close()
        return f.name

    def test_jinja_block_detected(self):
        """A spec containing {% %} Jinja block tags is rejected."""
        path = self._write_temp_spec(textwrap.dedent('''\
            workflow:
              name: {% if true %}test{% endif %}
              tasks:
                - name: task
                  image: alpine:3.18
                  command: ["echo"]
            '''))
        try:
            with self.assertRaises(ValueError) as context:
                run_workflow_locally(path)
            self.assertIn('Jinja', str(context.exception))
        finally:
            os.unlink(path)

    def test_jinja_comment_detected(self):
        """A spec containing {# #} Jinja comment tags is rejected."""
        path = self._write_temp_spec(textwrap.dedent('''\
            {# A comment #}
            workflow:
              name: test
              tasks:
                - name: task
                  image: alpine:3.18
                  command: ["echo"]
            '''))
        try:
            with self.assertRaises(ValueError) as context:
                run_workflow_locally(path)
            self.assertIn('Jinja', str(context.exception))
        finally:
            os.unlink(path)

    def test_default_values_resolved_locally(self):
        """A spec with default-values is resolved locally, not rejected."""
        path = self._write_temp_spec(textwrap.dedent('''\
            workflow:
              name: "{{experiment_name}}"
              tasks:
                - name: task
                  image: alpine:3.18
                  command: ["echo"]
            default-values:
              experiment_name: my-experiment
            '''))
        try:
            with open(path, encoding='utf-8') as f:
                spec_text = f.read()
            resolved = spec_includes.resolve_default_values(spec_text)
            self.assertNotIn('default-values', resolved)
            self.assertIn('my-experiment', resolved)
            self.assertNotIn('{{experiment_name}}', resolved)
        finally:
            os.unlink(path)


class TestUnresolvedVariables(unittest.TestCase):
    """Verify that unresolved template variables are detected."""

    def test_unresolved_variable_detected(self):
        """A bare {{ variable }} without a default-values entry is rejected."""
        spec_text = textwrap.dedent('''\
            workflow:
              name: "{{ missing_var }}"
              tasks:
                - name: task
                  image: alpine:3.18
                  command: ["echo"]
            ''')
        with self.assertRaises(ValueError) as context:
            check_unresolved_variables(spec_text)
        self.assertIn('missing_var', str(context.exception))

    def test_osmo_runtime_tokens_allowed(self):
        """OSMO runtime tokens (output, input, host, etc.) are not flagged."""
        spec_text = textwrap.dedent('''\
            workflow:
              name: ok
              tasks:
                - name: task
                  image: alpine:3.18
                  command: ["sh", "-c"]
                  args: ["echo data > {{output}}/file.txt && cat {{input:0}}/data.csv"]
                  environment:
                    HOST: "{{ host:worker }}"
            ''')
        # Must not raise: these tokens are substituted at execution time.
        check_unresolved_variables(spec_text)

    def test_mixed_resolved_and_unresolved(self):
        """OSMO tokens pass but an unresolved variable is caught."""
        spec_text = textwrap.dedent('''\
            workflow:
              name: "{{ undefined_name }}"
              tasks:
                - name: task
                  image: alpine:3.18
                  command: ["sh", "-c"]
                  args: ["echo > {{output}}/file.txt"]
            ''')
        with self.assertRaises(ValueError) as context:
            check_unresolved_variables(spec_text)
        self.assertIn('undefined_name', str(context.exception))
        self.assertNotIn('output', str(context.exception))

    def test_unresolved_variable_in_run_workflow_locally(self):
        """run_workflow_locally rejects specs with unresolved variables."""
        path = tempfile.NamedTemporaryFile(mode='w', suffix='.yaml', delete=False)
        path.write(textwrap.dedent('''\
            workflow:
              name: "{{ missing }}"
              tasks:
                - name: task
                  image: alpine:3.18
                  command: ["echo"]
            '''))
        path.flush()
        path.close()
        try:
            with self.assertRaises(ValueError) as context:
                run_workflow_locally(path.name)
            self.assertIn('Unresolved template variables', str(context.exception))
        finally:
            os.unlink(path.name)

    def test_default_values_prevents_unresolved_error(self):
        """Variables defined in default-values are resolved and don't trigger the check."""
        path = tempfile.NamedTemporaryFile(mode='w', suffix='.yaml', delete=False)
        path.write(textwrap.dedent('''\
            workflow:
              name: "{{ my_name }}"
              tasks:
                - name: task
                  image: alpine:3.18
                  command: ["echo"]
            default-values:
              my_name: resolved-name
            '''))
        path.flush()
        path.close()
        try:
            with open(path.name, encoding='utf-8') as f:
                spec_text = f.read()
            resolved = spec_includes.resolve_default_values(spec_text)
            check_unresolved_variables(resolved)
            self.assertIn('resolved-name', resolved)
        finally:
            os.unlink(path.name)


class TestIncludesWithDefaultValues(unittest.TestCase):
    """Verify that default-values from includes are resolved for local execution."""

    def setUp(self):
        self.test_dir = tempfile.mkdtemp()

    def tearDown(self):
        shutil.rmtree(self.test_dir, ignore_errors=True)

    def _write_file(self, name: str, content: str) -> str:
        """Write dedented content under the test dir and return the file path."""
        path = os.path.join(self.test_dir, name)
        os.makedirs(os.path.dirname(path), exist_ok=True)
        with open(path, 'w', encoding='utf-8') as f:
            f.write(textwrap.dedent(content))
        return path

    def test_included_default_values_stripped(self):
        """A base file's default-values should be resolved and stripped for local execution."""
        self._write_file('base.yaml', textwrap.dedent('''\
            workflow:
              name: base
              resources:
                default:
                  cpu: 4
            default-values:
              unused_var: some_value
            '''))
        main_path = self._write_file('main.yaml', textwrap.dedent('''\
            includes:
              - base.yaml
            workflow:
              name: local-test
              tasks:
                - name: hello
                  image: ubuntu:24.04
                  command: ["echo"]
                  args: ["hello"]
            '''))
        with open(main_path, encoding='utf-8') as f:
            spec_text = f.read()

        abs_path = os.path.abspath(main_path)
        spec_text = spec_includes.resolve_includes(
            spec_text, os.path.dirname(abs_path), source_path=abs_path)

        resolved_dict = yaml.safe_load(spec_text)
        self.assertIn('default-values', resolved_dict,
                      'Resolved spec should have default-values from base')

        spec_text = spec_includes.resolve_default_values(spec_text)
        self.assertNotIn('default-values', spec_text,
                         'resolve_default_values should strip the section')

    def test_included_default_values_resolved(self):
        """A spec using default-values variables from an included base is resolved locally."""
        self._write_file('base.yaml', textwrap.dedent('''\
            workflow:
              name: base
              resources:
                default:
                  cpu: 4
            '''))
        main_path = self._write_file('main.yaml', textwrap.dedent('''\
            includes:
              - base.yaml
            workflow:
              name: "{{ experiment_name }}"
              tasks:
                - name: hello
                  image: ubuntu:24.04
                  command: ["echo"]
            default-values:
              experiment_name: my-experiment
            '''))
        with open(main_path, encoding='utf-8') as f:
            spec_text = f.read()

        abs_path = os.path.abspath(main_path)
        spec_text = spec_includes.resolve_includes(
            spec_text, os.path.dirname(abs_path), source_path=abs_path)
        spec_text = spec_includes.resolve_default_values(spec_text)

        self.assertNotIn('default-values', spec_text)
        self.assertIn('my-experiment', spec_text)
        self.assertNotIn('{{ experiment_name }}', spec_text)


# ============================================================================
# Tests that exercise error paths without requiring Docker
# ============================================================================
class TestDockerNotFoundHandling(unittest.TestCase):
    """Verify graceful failure when Docker is not available (no Docker required to run)."""

    def setUp(self):
        """Create a temporary work directory."""
        self.work_dir = tempfile.mkdtemp(prefix='osmo-local-test-')

    def tearDown(self):
        """Remove the temporary work directory."""
        shutil.rmtree(self.work_dir, ignore_errors=True)

    def test_docker_not_found_graceful_failure(self):
        """Using a non-existent docker binary results in a graceful failure rather than a crash."""
        spec_text = textwrap.dedent('''\
            workflow:
              name: no-docker
              tasks:
                - name: task
                  image: alpine:3.18
                  command: ["echo", "ok"]
            ''')
        executor = LocalExecutor(
            work_dir=self.work_dir,
            keep_work_dir=True,
            docker_cmd='nonexistent-docker-binary-12345',
        )
        spec = executor.load_spec(spec_text)
        # execute() should return False (failure), not raise.
        self.assertFalse(executor.execute(spec))


class TestCookbookSpecValidation(unittest.TestCase):
    """
    Validate that cookbook specs using unsupported features are rejected
    before any container is started (no Docker required to run).
    """

    COOKBOOK_DIR = os.path.join(os.path.dirname(__file__), '..', '..', '..',
                                'cookbook', 'tutorials')

    def setUp(self):
        """Create a temporary work directory for cookbook validation tests."""
        self.work_dir = tempfile.mkdtemp(prefix='osmo-local-cookbook-')

    def tearDown(self):
        """Remove the temporary work directory after each test."""
        shutil.rmtree(self.work_dir, ignore_errors=True)

    def _run_cookbook_spec(self, filename: str) -> bool:
        """Execute a cookbook tutorial spec file through the local executor."""
        spec_path = os.path.join(self.COOKBOOK_DIR, filename)
        self.assertTrue(os.path.exists(spec_path),
                        f'Cookbook file not found: {spec_path}')
        return run_workflow_locally(
            spec_path=spec_path,
            work_dir=self.work_dir,
            keep_work_dir=True,
        )

    def test_unsupported_spec_data_download(self):
        """data_download.yaml uses URL inputs — verify it is cleanly rejected."""
        with self.assertRaises(ValueError) as context:
            self._run_cookbook_spec('data_download.yaml')
        self.assertIn('URL', str(context.exception))

    def test_unsupported_spec_data_upload(self):
        """data_upload.yaml uses URL outputs — verify it is cleanly rejected."""
        with self.assertRaises(ValueError) as context:
            self._run_cookbook_spec('data_upload.yaml')
        self.assertIn('object storage', str(context.exception).lower())

    def test_unsupported_spec_dataset_upload(self):
        """dataset_upload.yaml uses dataset outputs — verify it is cleanly rejected."""
        with self.assertRaises(ValueError) as context:
            self._run_cookbook_spec('dataset_upload.yaml')
        self.assertIn('dataset', str(context.exception).lower())

    def test_template_spec_resolved_locally(self):
        """template_hello_world.yaml uses default-values — verify variables are resolved."""
        spec_path = os.path.join(self.COOKBOOK_DIR, 'template_hello_world.yaml')
        self.assertTrue(os.path.exists(spec_path),
                        f'Cookbook file not found: {spec_path}')
        with open(spec_path, encoding='utf-8') as f:
            spec_text = f.read()
        resolved = spec_includes.resolve_default_values(spec_text)
        self.assertNotIn('default-values', resolved)
        self.assertIn('hello-osmo', resolved)
        self.assertNotIn('{{workflow_name}}', resolved)
        self.assertNotIn('{{ubuntu_version}}', resolved)
        self.assertNotIn('{{message}}', resolved)
        self.assertIn('Hello from OSMO!', resolved)


class TestComposeCommand(unittest.TestCase):
    """Verify ``osmo local compose`` resolves includes and default-values."""

    def setUp(self):
        self.test_dir = tempfile.mkdtemp()

    def tearDown(self):
        shutil.rmtree(self.test_dir, ignore_errors=True)

    def _write_file(self, name: str, content: str) -> str:
        """Write dedented content under the test dir and return the file path."""
        path = os.path.join(self.test_dir, name)
        os.makedirs(os.path.dirname(path), exist_ok=True)
        with open(path, 'w', encoding='utf-8') as f:
            f.write(textwrap.dedent(content))
        return path

    def test_compose_resolves_includes_and_defaults(self):
        """Compose flattens includes and substitutes default-values variables."""
        self._write_file('base.yaml', '''\
            workflow:
              resources:
                default:
                  cpu: 4
            ''')
        main_path = self._write_file('main.yaml', '''\
            includes:
              - base.yaml
            workflow:
              name: "{{experiment}}"
              tasks:
                - name: train
                  image: ubuntu:24.04
                  command: ["echo", "hi"]
            default-values:
              experiment: my-run
            ''')

        abs_path = os.path.abspath(main_path)
        with open(abs_path, encoding='utf-8') as f:
            spec_text = f.read()

        spec_text = spec_includes.resolve_includes(
            spec_text, os.path.dirname(abs_path), source_path=abs_path)
        spec_text = spec_includes.resolve_default_values(spec_text)

        parsed = yaml.safe_load(spec_text)
        self.assertNotIn('includes', parsed)
        self.assertNotIn('default-values', parsed)
        self.assertEqual(parsed['workflow']['name'], 'my-run')
        self.assertEqual(parsed['workflow']['resources']['default']['cpu'], 4)

    def test_compose_writes_output_file(self):
        """Compose with -o writes the flat spec to the given path."""
        main_path = self._write_file('spec.yaml', '''\
            workflow:
              name: simple
              tasks:
                - name: hello
                  image: alpine:3.18
                  command: ["echo"]
            default-values:
              unused_var: value
            ''')
        output_path = os.path.join(self.test_dir, 'composed.yaml')

        abs_path = os.path.abspath(main_path)
        with open(abs_path, encoding='utf-8') as f:
            spec_text = f.read()
        spec_text = spec_includes.resolve_includes(
            spec_text, os.path.dirname(abs_path), source_path=abs_path)
        spec_text = spec_includes.resolve_default_values(spec_text)

        with open(output_path, 'w', encoding='utf-8') as f:
            f.write(spec_text)

        with open(output_path, encoding='utf-8') as f:
            composed = yaml.safe_load(f.read())
        self.assertNotIn('default-values', composed)
        self.assertEqual(composed['workflow']['name'], 'simple')

    def test_compose_result_is_submittable(self):
        """The composed spec has no includes or default-values and is valid YAML."""
        self._write_file('shared.yaml', '''\
            workflow:
              resources:
                default:
                  cpu: 2
                  memory: 4Gi
            default-values:
              base_image: ubuntu:24.04
            ''')
        # NOTE(review): 'image' uses a single-brace {base_image} placeholder
        # while other specs in this file use {{var}} — confirm both placeholder
        # forms are handled by resolve_default_values (the assertion below
        # expects substitution to succeed).
        main_path = self._write_file('pipeline.yaml', '''\
            includes:
              - shared.yaml
            workflow:
              name: pipeline
              tasks:
                - name: step1
                  image: "{base_image}"
                  command: ["echo", "hello"]
            default-values:
              base_image: nvidia/cuda:12.0-base
            ''')

        abs_path = os.path.abspath(main_path)
        with open(abs_path, encoding='utf-8') as f:
            spec_text = f.read()
        spec_text = spec_includes.resolve_includes(
            spec_text, os.path.dirname(abs_path), source_path=abs_path)
        spec_text = spec_includes.resolve_default_values(spec_text)

        parsed = yaml.safe_load(spec_text)
        self.assertNotIn('includes', parsed)
        self.assertNotIn('default-values', parsed)
        self.assertEqual(parsed['workflow']['name'], 'pipeline')
        # The main file's default-values override the included file's.
        self.assertEqual(parsed['workflow']['tasks'][0]['image'],
                         'nvidia/cuda:12.0-base')
        self.assertEqual(parsed['workflow']['resources']['default']['cpu'], 2)


class TestSubmitIncludesResolution(unittest.TestCase):
    """Verify that _load_wf_file resolves includes before sending to server."""

    def setUp(self):
        self.test_dir = tempfile.mkdtemp()

    def tearDown(self):
        shutil.rmtree(self.test_dir, ignore_errors=True)

    def _write_file(self, name: str, content: str) -> str:
        """Write dedented content under the test dir and return the file path."""
        path = os.path.join(self.test_dir, name)
        os.makedirs(os.path.dirname(path), exist_ok=True)
        with open(path, 'w', encoding='utf-8') as f:
            f.write(textwrap.dedent(content))
        return path

    def test_load_wf_file_flattens_includes(self):
        """_load_wf_file merges included files so the server receives a flat spec."""
        from src.cli.workflow import _load_wf_file

        self._write_file('base.yaml', '''\
            workflow:
              resources:
                default:
                  cpu: 8
            ''')
        main_path = self._write_file('main.yaml', '''\
            includes:
              - base.yaml
            workflow:
              name: test-wf
              tasks:
                - name: task1
                  image: alpine:3.18
                  command: ["echo"]
            ''')

        template_data = _load_wf_file(main_path, [], [])
        parsed = yaml.safe_load(template_data.file)

        self.assertNotIn('includes', parsed)
        self.assertEqual(parsed['workflow']['resources']['default']['cpu'], 8)
        self.assertEqual(parsed['workflow']['name'], 'test-wf')

    def test_load_wf_file_preserves_template_markers(self):
        """Includes are resolved but Jinja/default-values markers are preserved for the server."""
        from src.cli.workflow import _load_wf_file

        self._write_file('base.yaml', '''\
            workflow:
              resources:
                default:
                  cpu: 4
            ''')
        main_path = self._write_file('main.yaml', '''\
            includes:
              - base.yaml
            workflow:
              name: "{{experiment}}"
              tasks:
                - name: task1
                  image: alpine:3.18
                  command: ["echo"]
            default-values:
              experiment: my-exp
            ''')

        template_data = _load_wf_file(main_path, [], [])
        # Template expansion is the server's job: markers must survive.
        self.assertTrue(template_data.is_templated)
        self.assertNotIn('includes', template_data.file)
        self.assertIn('default-values', template_data.file)
        self.assertIn('{{experiment}}', template_data.file)


class TestRunWorkflowLocallyErrors(unittest.TestCase):
    """Test error handling in run_workflow_locally() that does not require Docker."""

    def test_nonexistent_file_raises(self):
        """Passing a non-existent spec file path raises FileNotFoundError."""
        with self.assertRaises(FileNotFoundError):
            run_workflow_locally(spec_path='/nonexistent/path/spec.yaml')

    def test_unset_env_var_raises_with_helpful_message(self):
        """Specs referencing unset ${env:VAR} in default-values raise with env var name."""
        path = tempfile.NamedTemporaryFile(mode='w', suffix='.yaml', delete=False)
        path.write(textwrap.dedent('''\
            default-values:
              root: "${env:__OSMO_TEST_NEVER_SET_12345__}"
            workflow:
              name: test
              tasks:
                - name: task
                  image: alpine:3.18
                  command: ["echo"]
                  args: ["{root}/data"]
            '''))
        path.flush()
        path.close()
        try:
            with self.assertRaises(ValueError) as context:
                run_workflow_locally(path.name)
            self.assertIn('__OSMO_TEST_NEVER_SET_12345__', str(context.exception))
            self.assertIn('not set', str(context.exception))
        finally:
            os.unlink(path.name)


# ============================================================================
# Integration tests — require Docker + Compose; test actual container execution
# 
# ============================================================================
@unittest.skipUnless(DOCKER_AVAILABLE, SKIP_DOCKER_MSG)
class TestDockerExecution(unittest.TestCase):
    """
    Integration tests that run real OSMO workflow specs through the local executor
    using Docker Compose. Each test uses a spec that would normally run on a Kubernetes cluster.
    """

    def setUp(self):
        """Create a temporary work directory for each Docker execution test."""
        self.work_dir = tempfile.mkdtemp(prefix='osmo-local-test-')

    def tearDown(self):
        """Remove the temporary work directory after each test."""
        shutil.rmtree(self.work_dir, ignore_errors=True)

    def _execute_spec(self, spec_text: str) -> bool:
        """Parse and execute a workflow spec string, returning the success status."""
        # keep_work_dir=True so tests can inspect task outputs under work_dir
        # after execution; tearDown removes the directory anyway.
        executor = LocalExecutor(work_dir=self.work_dir, keep_work_dir=True)
        spec = executor.load_spec(spec_text)
        return executor.execute(spec)

    # ---- Single task tests ----

    def test_hello_world(self):
        """Run a minimal single-task workflow that echoes a message."""
        spec_text = textwrap.dedent('''\
            workflow:
              name: hello-osmo
              tasks:
                - name: hello
                  image: alpine:3.18
                  command: ["echo", "Hello from OSMO!"]
            ''')
        self.assertTrue(self._execute_spec(spec_text))

    def test_single_task_with_args(self):
        """Run a task with separate command and args fields."""
        spec_text = textwrap.dedent('''\
            workflow:
              name: args-test
              tasks:
                - name: task
                  image: alpine:3.18
                  command: ["echo"]
                  args: ["argument1", "argument2"]
            ''')
        self.assertTrue(self._execute_spec(spec_text))

    def test_task_failure_returns_false(self):
        """A task that exits with a non-zero code causes execute() to return False."""
        spec_text = textwrap.dedent('''\
            workflow:
              name: will-fail
              tasks:
                - name: failing-task
                  image: alpine:3.18
                  command: ["sh", "-c", "exit 42"]
            ''')
        self.assertFalse(self._execute_spec(spec_text))

    # ---- Environment variable tests ----

    def test_environment_variables(self):
        """Environment variables declared in the spec are passed to the Docker container."""
        spec_text = textwrap.dedent('''\
            workflow:
              name: env-test
              tasks:
                - name: check-env
                  image: alpine:3.18
                  command: ["sh", "-c"]
                  args: ["test \\"$MY_VAR\\" = \\"hello_world\\" && test \\"$SECOND\\" = \\"42\\""]
                  environment:
                    MY_VAR: hello_world
                    SECOND: "42"
            ''')
        self.assertTrue(self._execute_spec(spec_text))

    # ---- Files mount tests ----

    def test_inline_file_mounted(self):
        """An inline file declared in the spec is mounted and executable inside the container."""
        spec_text = textwrap.dedent('''\
            workflow:
              name: files-test
              tasks:
                - name: check-file
                  image: alpine:3.18
                  command: ["sh", "/tmp/run.sh"]
                  files:
                    - contents: |
                        echo "script ran successfully"
                      path: /tmp/run.sh
            ''')
        self.assertTrue(self._execute_spec(spec_text))

    def test_multiple_files_mounted(self):
        """Multiple inline files at different paths are all mounted into the container."""
        spec_text = textwrap.dedent('''\
            workflow:
              name: multi-files
              tasks:
                - name: check-files
                  image: alpine:3.18
                  command: ["sh", "-c"]
                  args: ["cat /tmp/config.txt && sh /scripts/run.sh"]
                  files:
                    - contents: "key=value"
                      path: /tmp/config.txt
                    - contents: |
                        echo "second script ok"
                      path: /scripts/run.sh
            ''')
        self.assertTrue(self._execute_spec(spec_text))

    # ---- Data output tests ----

    def test_output_directory_writable(self):
        """The {{output}} directory is writable from inside the container and persists on the host."""
        spec_text = textwrap.dedent('''\
            workflow:
              name: output-test
              tasks:
                - name: write-output
                  image: alpine:3.18
                  command: ["sh", "-c"]
                  args: ["echo 'payload' > {{output}}/result.txt"]
            ''')
        self.assertTrue(self._execute_spec(spec_text))
        # The container's {{output}} maps to <work_dir>/<task>/output on the host.
        output_file = os.path.join(self.work_dir, 'write-output', 'output', 'result.txt')
        self.assertTrue(os.path.exists(output_file))
        with open(output_file) as f:
            self.assertEqual(f.read().strip(), 'payload')

    # ---- Serial data flow tests ----

    def test_serial_data_flow_two_tasks(self):
        """Data written to {{output}} by a producer is readable via {{input:0}} by the consumer."""
        spec_text = textwrap.dedent('''\
            workflow:
              name: serial-data
              tasks:
                - name: producer
                  image: alpine:3.18
                  command: ["sh", "-c"]
                  args: ["echo 'from_producer' > {{output}}/data.txt"]
                - name: consumer
                  image: alpine:3.18
                  command: ["sh", "-c"]
                  args: ["cat {{input:0}}/data.txt > {{output}}/received.txt"]
                  inputs:
                    - task: producer
            ''')
        self.assertTrue(self._execute_spec(spec_text))
        received = os.path.join(self.work_dir, 'consumer', 'output', 'received.txt')
        self.assertTrue(os.path.exists(received))
        with open(received) as f:
            self.assertEqual(f.read().strip(), 'from_producer')

    def test_serial_chain_three_tasks(self):
        """Mimics cookbook/tutorials/serial_workflow.yaml"""
        spec_text = textwrap.dedent('''\
            workflow:
              name: serial-chain
              tasks:
                - name: task1
                  image: alpine:3.18
                  command: ["sh", "-c"]
                  args: ["echo 'task1_data' > {{output}}/result.txt"]

                - name: task2
                  image: alpine:3.18
                  command: ["sh", "-c"]
                  args:
                    - |
                      cat {{input:0}}/result.txt > {{output}}/result.txt
                      echo '_plus_task2' >> {{output}}/result.txt
                  inputs:
                    - task: task1

                - name: task3
                  image: alpine:3.18
                  command: ["sh", "-c"]
                  args:
                    - |
                      cat {{input:0}}/result.txt > {{output}}/final.txt
                      cat {{input:1}}/result.txt >> {{output}}/final.txt
                  inputs:
                    - task: task1
                    - task: task2
            ''')
        self.assertTrue(self._execute_spec(spec_text))
        final = os.path.join(self.work_dir, 'task3', 'output', 'final.txt')
        with open(final) as f:
            content = f.read()
            self.assertIn('task1_data', content)
            self.assertIn('_plus_task2', content)

    # ---- Parallel execution tests ----

    def test_parallel_independent_tasks(self):
        """Independent tasks with no dependencies all execute and produce their respective outputs."""
        spec_text = textwrap.dedent('''\
            workflow:
              name: parallel-tasks
              tasks:
                - name: task-a
                  image: alpine:3.18
                  command: ["sh", "-c"]
                  args: ["echo 'a' > {{output}}/marker.txt"]
                - name: task-b
                  image: alpine:3.18
                  command: ["sh", "-c"]
                  args: ["echo 'b' > {{output}}/marker.txt"]
                - name: task-c
                  image: alpine:3.18
                  command: ["sh", "-c"]
                  args: ["echo 'c' > {{output}}/marker.txt"]
            ''')
        self.assertTrue(self._execute_spec(spec_text))
        for task_name, expected in [('task-a', 'a'), ('task-b', 'b'), ('task-c', 'c')]:
            marker = os.path.join(self.work_dir, task_name, 'output', 'marker.txt')
            with open(marker) as f:
                self.assertEqual(f.read().strip(), expected)

    # ---- Diamond DAG tests ----

    def test_diamond_dag(self):
        """A diamond-shaped DAG executes correctly with fan-out and fan-in data flow."""
        spec_text = textwrap.dedent('''\
            workflow:
              name: diamond
              tasks:
                - name: root
                  image: alpine:3.18
                  command: ["sh", "-c"]
                  args: ["echo 'root_data' > {{output}}/base.txt"]
                - name: left
                  image: alpine:3.18
                  command: ["sh", "-c"]
                  args: ["echo 'left:' > {{output}}/result.txt && cat {{input:0}}/base.txt >> {{output}}/result.txt"]
                  inputs:
                    - task: root
                - name: right
                  image: alpine:3.18
                  command: ["sh", "-c"]
                  args: ["echo 'right:' > {{output}}/result.txt && cat {{input:0}}/base.txt >> {{output}}/result.txt"]
                  inputs:
                    - task: root
                - name: join
                  image: alpine:3.18
                  command: ["sh", "-c"]
                  args: ["cat {{input:0}}/result.txt > {{output}}/final.txt && cat {{input:1}}/result.txt >> {{output}}/final.txt"]
                  inputs:
                    - task: left
                    - task: right
            ''')
        self.assertTrue(self._execute_spec(spec_text))
        final = os.path.join(self.work_dir, 'join', 'output', 'final.txt')
        with open(final) as f:
            content = f.read()
            self.assertIn('left:', content)
            self.assertIn('right:', content)
            self.assertIn('root_data', content)

    # ---- Failure propagation tests ----

    def test_failure_cancels_downstream(self):
        """A failed task prevents its downstream dependent from running."""
        spec_text = textwrap.dedent('''\
            workflow:
              name: fail-chain
              tasks:
                - name: failing
                  image: alpine:3.18
                  command: ["sh", "-c", "exit 1"]
                - name: should-not-run
                  image: alpine:3.18
                  command: ["sh", "-c", "echo 'oops' > {{output}}/should_not_exist.txt"]
                  inputs:
                    - task: failing
            ''')
        self.assertFalse(self._execute_spec(spec_text))
        output_file = os.path.join(self.work_dir, 'should-not-run', 'output', 'should_not_exist.txt')
        self.assertFalse(os.path.exists(output_file))

    def test_parallel_failure_does_not_affect_independent_branch(self):
        """When one branch of a parallel DAG fails, the executor stops with overall failure."""
        spec_text = textwrap.dedent('''\
            workflow:
              name: partial-fail
              tasks:
                - name: root
                  image: alpine:3.18
                  command: ["sh", "-c"]
                  args: ["echo ok > {{output}}/data.txt"]
                - name: fail-branch
                  image: alpine:3.18
                  command: ["sh", "-c", "exit 1"]
                  inputs:
                    - task: root
                - name: ok-branch
                  image: alpine:3.18
                  command: ["sh", "-c"]
                  args: ["cat {{input:0}}/data.txt > {{output}}/received.txt"]
                  inputs:
                    - task: root
            ''')
        result = self._execute_spec(spec_text)
        # Overall run is a failure even though ok-branch could have succeeded.
        self.assertFalse(result)

    # ---- Groups (ganged tasks) tests ----

    def test_group_with_single_task(self):
        """A group containing a single lead task executes and produces output."""
        spec_text = textwrap.dedent('''\
            workflow:
              name: single-group
              groups:
                - name: my-group
                  tasks:
                    - name: leader
                      lead: true
                      image: alpine:3.18
                      command: ["sh", "-c"]
                      args: ["echo 'group_ok' > {{output}}/marker.txt"]
            ''')
        self.assertTrue(self._execute_spec(spec_text))
        marker = os.path.join(self.work_dir, 'leader', 'output', 'marker.txt')
        with open(marker) as f:
            self.assertEqual(f.read().strip(), 'group_ok')

    def test_groups_with_data_flow(self):
        """Mimics cookbook/tutorials/combination_workflow_simple.yaml structure."""
        spec_text = textwrap.dedent('''\
            workflow:
              name: data-pipeline
              groups:
                - name: prepare-data
                  tasks:
                    - name: generate-dataset
                      lead: true
                      image: alpine:3.18
                      command: ["sh", "-c"]
                      args:
                        - |
                          mkdir -p {{output}}/data
                          for i in 1 2 3; do echo "sample_$i" >> {{output}}/data/dataset.csv; done
                - name: train-models
                  tasks:
                    - name: train-model
                      lead: true
                      image: alpine:3.18
                      command: ["sh", "-c"]
                      args:
                        - |
                          wc -l {{input:0}}/data/dataset.csv > {{output}}/line_count.txt
                      inputs:
                        - task: generate-dataset
            ''')
        self.assertTrue(self._execute_spec(spec_text))
        line_count_file = os.path.join(self.work_dir, 'train-model', 'output', 'line_count.txt')
        with open(line_count_file) as f:
            content = f.read()
        self.assertIn('3', content)

    # ---- Input by task name tests ----

    def test_input_by_task_name(self):
        """The {{input:taskname}} token resolves to the named upstream task's output directory."""
        spec_text = textwrap.dedent('''\
            workflow:
              name: named-input
              tasks:
                - name: producer
                  image: alpine:3.18
                  command: ["sh", "-c"]
                  args: ["echo 'named_data' > {{output}}/out.txt"]
                - name: consumer
                  image: alpine:3.18
                  command: ["sh", "-c"]
                  args: ["cat {{input:producer}}/out.txt > {{output}}/received.txt"]
                  inputs:
                    - task: producer
            ''')
        self.assertTrue(self._execute_spec(spec_text))
        received = os.path.join(self.work_dir, 'consumer', 'output', 'received.txt')
        with open(received) as f:
            self.assertEqual(f.read().strip(), 'named_data')

    # ---- Files with token substitution ----

    def test_file_contents_with_token_substitution(self):
        """Mimics cookbook/tutorials/serial_workflow.yaml pattern of inline scripts with tokens."""
        spec_text = textwrap.dedent('''\
            workflow:
              name: file-tokens
              tasks:
                - name: writer
                  image: alpine:3.18
                  command: ["sh", "/tmp/run.sh"]
                  files:
                    - contents: |
                        echo "writing output"
                        echo "file_data" > {{output}}/result.txt
                      path: /tmp/run.sh
                - name: reader
                  image: alpine:3.18
                  command: ["sh", "/tmp/run.sh"]
                  files:
                    - contents: |
                        cat {{input:0}}/result.txt > {{output}}/received.txt
                      path: /tmp/run.sh
                  inputs:
                    - task: writer
            ''')
        self.assertTrue(self._execute_spec(spec_text))
        received = os.path.join(self.work_dir, 'reader', 'output', 'received.txt')
        with open(received) as f:
            self.assertEqual(f.read().strip(), 'file_data')

    # ---- Resource spec ignored gracefully ----

    def test_resources_ignored_gracefully(self):
        """Resource specs are K8s-specific; local executor should accept and ignore them."""
        spec_text = textwrap.dedent('''\
            workflow:
              name: with-resources
              resources:
                default:
                  cpu: 2
                  memory: 4Gi
                  storage: 10Gi
              tasks:
                - name: task
                  image: alpine:3.18
                  command: ["echo", "ok"]
            ''')
        self.assertTrue(self._execute_spec(spec_text))

    # ---- Alternative container runtime ----

    def test_custom_docker_command(self):
        """An explicitly specified docker command is used to run the container."""
        spec_text = textwrap.dedent('''\
            workflow:
              name: custom-cmd
              tasks:
                - name: task
                  image: alpine:3.18
                  command: ["echo", "ok"]
            ''')
        executor = LocalExecutor(
            work_dir=self.work_dir,
            keep_work_dir=True,
            docker_cmd='docker',
        )
        spec = executor.load_spec(spec_text)
        self.assertTrue(executor.execute(spec))

    # ---- Compose file is generated ----

    def test_compose_file_generated(self):
        """Executing a workflow generates a docker-compose.yml in the work directory."""
        spec_text = textwrap.dedent('''\
            workflow:
              name: compose-check
              tasks:
                - name: task
                  image: alpine:3.18
                  command: ["echo", "ok"]
            ''')
        self.assertTrue(self._execute_spec(spec_text))
        compose_path = os.path.join(self.work_dir, 'docker-compose.yml')
        self.assertTrue(os.path.exists(compose_path))
        with open(compose_path) as f:
            config = yaml.safe_load(f.read())
        self.assertIn('services', config)
        self.assertIn('task', config['services'])


# ============================================================================
# Integration tests
using actual cookbook spec files from the repo +# ============================================================================ +@unittest.skipUnless(DOCKER_AVAILABLE, SKIP_DOCKER_MSG) +class TestCookbookSpecs(unittest.TestCase): + """ + Run real OSMO cookbook YAML specs that are designed for Kubernetes clusters, + and verify they execute successfully in the local Docker executor. + """ + + COOKBOOK_DIR = os.path.join(os.path.dirname(__file__), '..', '..', '..', + 'cookbook', 'tutorials') + + def setUp(self): + """Create a temporary work directory for cookbook spec tests.""" + self.work_dir = tempfile.mkdtemp(prefix='osmo-local-cookbook-') + + def tearDown(self): + """Remove the temporary work directory after each cookbook test.""" + shutil.rmtree(self.work_dir, ignore_errors=True) + + def _run_cookbook_spec(self, filename: str) -> bool: + """Execute a cookbook tutorial spec file through the local executor.""" + spec_path = os.path.join(self.COOKBOOK_DIR, filename) + self.assertTrue(os.path.exists(spec_path), + f'Cookbook file not found: {spec_path}') + return run_workflow_locally( + spec_path=spec_path, + work_dir=self.work_dir, + keep_work_dir=True, + ) + + def test_hello_world_yaml(self): + """Execute the hello_world.yaml cookbook tutorial spec.""" + self.assertTrue(self._run_cookbook_spec('hello_world.yaml')) + + def test_parallel_tasks_yaml(self): + """Execute the parallel_tasks.yaml cookbook tutorial spec.""" + self.assertTrue(self._run_cookbook_spec('parallel_tasks.yaml')) + + def test_serial_workflow_yaml(self): + """Execute the serial_workflow.yaml cookbook tutorial spec.""" + self.assertTrue(self._run_cookbook_spec('serial_workflow.yaml')) + + def test_resources_basic_yaml(self): + """Execute the resources_basic.yaml cookbook tutorial spec.""" + self.assertTrue(self._run_cookbook_spec('resources_basic.yaml')) + + def test_combination_workflow_simple_yaml(self): + """ + The combination_workflow_simple.yaml has a 'sleep 120' in transform-a. 
+ We skip it here because it would take 2+ minutes; a trimmed version + of the same structure is tested in TestDockerExecution.test_groups_with_data_flow. + """ + self.skipTest('Contains sleep 120; covered by test_groups_with_data_flow') + + +# ============================================================================ +# run_workflow_locally() integration tests +# ============================================================================ +@unittest.skipUnless(DOCKER_AVAILABLE, SKIP_DOCKER_MSG) +class TestRunWorkflowLocally(unittest.TestCase): + """Test the top-level run_workflow_locally() convenience function.""" + + def setUp(self): + """Create a temporary work directory for run_workflow_locally tests.""" + self.work_dir = tempfile.mkdtemp(prefix='osmo-local-func-') + + def tearDown(self): + """Remove the temporary work directory after each test.""" + shutil.rmtree(self.work_dir, ignore_errors=True) + + def test_caller_supplied_work_dir_preserved_on_success(self): + """A caller-supplied work_dir is never deleted, even with keep_work_dir=False.""" + work_dir = tempfile.mkdtemp(prefix='osmo-local-cleanup-') + with tempfile.NamedTemporaryFile(mode='w', suffix='.yaml', delete=False) as f: + f.write(textwrap.dedent('''\ + workflow: + name: cleanup-test + tasks: + - name: task + image: alpine:3.18 + command: ["echo", "ok"] + ''')) + spec_path = f.name + try: + result = run_workflow_locally( + spec_path=spec_path, + work_dir=work_dir, + keep_work_dir=False, + ) + self.assertTrue(result) + self.assertTrue(os.path.exists(work_dir)) + finally: + os.unlink(spec_path) + if os.path.exists(work_dir): + shutil.rmtree(work_dir, ignore_errors=True) + + def test_failure_preserves_work_dir(self): + """On failure, the work directory is preserved for debugging regardless of the keep flag.""" + with tempfile.NamedTemporaryFile(mode='w', suffix='.yaml', delete=False) as f: + f.write(textwrap.dedent('''\ + workflow: + name: fail-test + tasks: + - name: task + image: alpine:3.18 + 
command: ["sh", "-c", "exit 1"] + ''')) + spec_path = f.name + try: + result = run_workflow_locally( + spec_path=spec_path, + work_dir=self.work_dir, + keep_work_dir=False, + ) + self.assertFalse(result) + self.assertTrue(os.path.exists(self.work_dir)) + finally: + os.unlink(spec_path) + + def test_keep_flag_preserves_on_success(self): + """With keep_work_dir=True, the work directory is preserved even on success.""" + with tempfile.NamedTemporaryFile(mode='w', suffix='.yaml', delete=False) as f: + f.write(textwrap.dedent('''\ + workflow: + name: keep-test + tasks: + - name: task + image: alpine:3.18 + command: ["echo", "ok"] + ''')) + spec_path = f.name + try: + result = run_workflow_locally( + spec_path=spec_path, + work_dir=self.work_dir, + keep_work_dir=True, + ) + self.assertTrue(result) + self.assertTrue(os.path.exists(self.work_dir)) + finally: + os.unlink(spec_path) + + +if __name__ == '__main__': + unittest.main() diff --git a/src/utils/tests/test_spec_includes.py b/src/utils/tests/test_spec_includes.py new file mode 100644 index 000000000..91230d950 --- /dev/null +++ b/src/utils/tests/test_spec_includes.py @@ -0,0 +1,1099 @@ +""" +SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+ +SPDX-License-Identifier: Apache-2.0 +""" +import os +import shutil +import tempfile +import textwrap +import unittest +from typing import Any, Dict + +import yaml + +from src.lib.utils import osmo_errors +from src.utils.spec_includes import ( + deep_merge_dicts, find_unresolved_env_variables, resolve_default_values, + resolve_includes, +) + + +class DeepMergeDictsTests(unittest.TestCase): + """Unit tests for deep_merge_dicts.""" + + def test_disjoint_keys(self): + result = deep_merge_dicts({'a': 1}, {'b': 2}) + self.assertEqual(result, {'a': 1, 'b': 2}) + + def test_override_scalar(self): + result = deep_merge_dicts({'a': 1}, {'a': 99}) + self.assertEqual(result, {'a': 99}) + + def test_nested_dict_merge(self): + base = {'a': {'x': 1, 'y': 2}, 'b': 3} + override = {'a': {'y': 99, 'z': 100}} + result = deep_merge_dicts(base, override) + self.assertEqual(result, {'a': {'x': 1, 'y': 99, 'z': 100}, 'b': 3}) + + def test_plain_list_replacement(self): + base = {'items': [1, 2, 3]} + override = {'items': [4, 5]} + result = deep_merge_dicts(base, override) + self.assertEqual(result, {'items': [4, 5]}) + + def test_named_list_merge_disjoint(self): + base = {'tasks': [{'name': 'a', 'image': 'img-a'}]} + override = {'tasks': [{'name': 'b', 'image': 'img-b'}]} + result = deep_merge_dicts(base, override) + self.assertEqual(result['tasks'], [ + {'name': 'a', 'image': 'img-a'}, + {'name': 'b', 'image': 'img-b'}, + ]) + + def test_named_list_merge_override_existing(self): + base = {'tasks': [ + {'name': 'train', 'image': 'train:v1', 'command': ['python3']}, + {'name': 'eval', 'image': 'eval:v1'}, + ]} + override = {'tasks': [ + {'name': 'train', 'image': 'train:v2'}, + ]} + result = deep_merge_dicts(base, override) + self.assertEqual(result['tasks'], [ + {'name': 'train', 'image': 'train:v2', 'command': ['python3']}, + {'name': 'eval', 'image': 'eval:v1'}, + ]) + + def test_named_list_preserves_base_order_appends_new(self): + base = {'tasks': [ + {'name': 'first', 'val': 1}, + 
{'name': 'second', 'val': 2}, + ]} + override = {'tasks': [ + {'name': 'third', 'val': 3}, + {'name': 'first', 'val': 10}, + ]} + result = deep_merge_dicts(base, override) + names = [t['name'] for t in result['tasks']] + self.assertEqual(names, ['first', 'second', 'third']) + self.assertEqual(result['tasks'][0]['val'], 10) + + def test_named_list_empty_base(self): + base: Dict[str, Any] = {'tasks': []} + override: Dict[str, Any] = {'tasks': [{'name': 'a', 'image': 'img'}]} + result = deep_merge_dicts(base, override) + self.assertEqual(result['tasks'], [{'name': 'a', 'image': 'img'}]) + + def test_named_list_empty_override_clears(self): + base: Dict[str, Any] = {'tasks': [{'name': 'a', 'image': 'img'}]} + override: Dict[str, Any] = {'tasks': []} + result = deep_merge_dicts(base, override) + self.assertEqual(result['tasks'], []) + + def test_mixed_list_without_name_key_replaced(self): + base = {'args': [{'cmd': 'echo'}, {'cmd': 'ls'}]} + override = {'args': [{'cmd': 'cat'}]} + result = deep_merge_dicts(base, override) + self.assertEqual(result['args'], [{'cmd': 'cat'}]) + + def test_override_dict_with_scalar(self): + result = deep_merge_dicts({'a': {'nested': 1}}, {'a': 'flat'}) + self.assertEqual(result, {'a': 'flat'}) + + def test_override_scalar_with_dict(self): + result = deep_merge_dicts({'a': 'flat'}, {'a': {'nested': 1}}) + self.assertEqual(result, {'a': {'nested': 1}}) + + def test_empty_base(self): + result = deep_merge_dicts({}, {'a': 1}) + self.assertEqual(result, {'a': 1}) + + def test_empty_override(self): + result = deep_merge_dicts({'a': 1}, {}) + self.assertEqual(result, {'a': 1}) + + def test_deeply_nested(self): + base = {'l1': {'l2': {'l3': {'val': 'base', 'keep': True}}}} + override = {'l1': {'l2': {'l3': {'val': 'override'}}}} + result = deep_merge_dicts(base, override) + self.assertEqual(result, {'l1': {'l2': {'l3': {'val': 'override', 'keep': True}}}}) + + +class ResolveIncludesTests(unittest.TestCase): + """Unit tests for resolve_includes.""" 
+ + def setUp(self): + self.test_dir = tempfile.mkdtemp() + + def tearDown(self): + shutil.rmtree(self.test_dir) + + def _write_file(self, relative_path: str, content: str) -> str: + full_path = os.path.join(self.test_dir, relative_path) + os.makedirs(os.path.dirname(full_path), exist_ok=True) + with open(full_path, 'w', encoding='utf-8') as file_handle: + file_handle.write(textwrap.dedent(content)) + return full_path + + def test_no_includes_returns_original_text(self): + spec = 'workflow:\n name: test\n' + result = resolve_includes(spec, self.test_dir) + self.assertEqual(result, spec) + + def test_simple_include_merges_workflow(self): + self._write_file('base.yaml', '''\ + workflow: + name: base + resources: + default: + cpu: 8 + gpu: 1 + ''') + spec = textwrap.dedent('''\ + includes: + - base.yaml + workflow: + name: child + tasks: + - name: task1 + image: ubuntu + ''') + result = resolve_includes(spec, self.test_dir) + parsed = yaml.safe_load(result) + + self.assertEqual(parsed['workflow']['name'], 'child') + self.assertEqual(parsed['workflow']['resources']['default']['cpu'], 8) + self.assertEqual(parsed['workflow']['resources']['default']['gpu'], 1) + self.assertEqual(len(parsed['workflow']['tasks']), 1) + self.assertNotIn('includes', parsed) + + def test_default_values_merged(self): + self._write_file('base.yaml', '''\ + default-values: + var1: base_val1 + var2: base_val2 + workflow: + name: base + ''') + spec = textwrap.dedent('''\ + includes: + - base.yaml + default-values: + var2: child_val2 + var3: child_val3 + workflow: + name: child + ''') + result = resolve_includes(spec, self.test_dir) + parsed = yaml.safe_load(result) + + self.assertEqual(parsed['default-values'], { + 'var1': 'base_val1', + 'var2': 'child_val2', + 'var3': 'child_val3', + }) + + def test_main_file_overrides_included_values(self): + self._write_file('base.yaml', '''\ + workflow: + name: base + resources: + default: + cpu: 4 + gpu: 1 + ''') + spec = textwrap.dedent('''\ + includes: + - 
base.yaml + workflow: + name: override + resources: + default: + cpu: 16 + ''') + result = resolve_includes(spec, self.test_dir) + parsed = yaml.safe_load(result) + + self.assertEqual(parsed['workflow']['name'], 'override') + self.assertEqual(parsed['workflow']['resources']['default']['cpu'], 16) + self.assertEqual(parsed['workflow']['resources']['default']['gpu'], 1) + + def test_multiple_includes_merged_in_order(self): + self._write_file('first.yaml', '''\ + workflow: + name: first + resources: + default: + cpu: 2 + ''') + self._write_file('second.yaml', '''\ + workflow: + name: second + resources: + default: + cpu: 8 + memory: 32Gi + ''') + spec = textwrap.dedent('''\ + includes: + - first.yaml + - second.yaml + workflow: + name: main + ''') + result = resolve_includes(spec, self.test_dir) + parsed = yaml.safe_load(result) + + self.assertEqual(parsed['workflow']['name'], 'main') + self.assertEqual(parsed['workflow']['resources']['default']['cpu'], 8) + self.assertEqual(parsed['workflow']['resources']['default']['memory'], '32Gi') + + def test_nested_includes(self): + self._write_file('grandparent.yaml', '''\ + workflow: + name: grandparent + resources: + default: + cpu: 4 + ''') + self._write_file('parent.yaml', '''\ + includes: + - grandparent.yaml + workflow: + name: parent + resources: + default: + gpu: 2 + ''') + spec = textwrap.dedent('''\ + includes: + - parent.yaml + workflow: + name: child + tasks: + - name: task1 + image: ubuntu + ''') + result = resolve_includes(spec, self.test_dir) + parsed = yaml.safe_load(result) + + self.assertEqual(parsed['workflow']['name'], 'child') + self.assertEqual(parsed['workflow']['resources']['default']['cpu'], 4) + self.assertEqual(parsed['workflow']['resources']['default']['gpu'], 2) + self.assertEqual(len(parsed['workflow']['tasks']), 1) + + def test_diamond_includes(self): + self._write_file('shared.yaml', '''\ + workflow: + name: shared + resources: + default: + cpu: 4 + ''') + self._write_file('branch_a.yaml', '''\ 
+ includes: + - shared.yaml + workflow: + name: branch-a + ''') + self._write_file('branch_b.yaml', '''\ + includes: + - shared.yaml + workflow: + name: branch-b + resources: + default: + memory: 16Gi + ''') + spec = textwrap.dedent('''\ + includes: + - branch_a.yaml + - branch_b.yaml + workflow: + name: root + ''') + result = resolve_includes(spec, self.test_dir) + parsed = yaml.safe_load(result) + + self.assertEqual(parsed['workflow']['name'], 'root') + self.assertEqual(parsed['workflow']['resources']['default']['cpu'], 4) + self.assertEqual(parsed['workflow']['resources']['default']['memory'], '16Gi') + + def test_circular_include_raises(self): + self._write_file('a.yaml', '''\ + includes: + - b.yaml + workflow: + name: a + ''') + self._write_file('b.yaml', '''\ + includes: + - a.yaml + workflow: + name: b + ''') + spec = textwrap.dedent('''\ + includes: + - a.yaml + workflow: + name: root + ''') + root_path = os.path.join(self.test_dir, 'root.yaml') + with self.assertRaises(osmo_errors.OSMOUserError) as context: + resolve_includes(spec, self.test_dir, source_path=root_path) + self.assertIn('Circular', str(context.exception)) + + def test_self_include_raises(self): + main_path = self._write_file('self.yaml', '''\ + includes: + - self.yaml + workflow: + name: self-ref + ''') + with open(main_path, encoding='utf-8') as file_handle: + spec = file_handle.read() + with self.assertRaises(osmo_errors.OSMOUserError) as context: + resolve_includes(spec, self.test_dir, source_path=main_path) + self.assertIn('Circular', str(context.exception)) + + def test_missing_include_file_raises(self): + spec = textwrap.dedent('''\ + includes: + - nonexistent.yaml + workflow: + name: test + ''') + with self.assertRaises(osmo_errors.OSMOUserError) as context: + resolve_includes(spec, self.test_dir) + self.assertIn('not found', str(context.exception)) + + def test_includes_not_a_list_raises(self): + spec = textwrap.dedent('''\ + includes: base.yaml + workflow: + name: test + ''') + with 
self.assertRaises(osmo_errors.OSMOUserError) as context: + resolve_includes(spec, self.test_dir) + self.assertIn('list', str(context.exception)) + + def test_include_path_not_string_raises(self): + spec = textwrap.dedent('''\ + includes: + - 42 + workflow: + name: test + ''') + with self.assertRaises(osmo_errors.OSMOUserError) as context: + resolve_includes(spec, self.test_dir) + self.assertIn('string', str(context.exception)) + + def test_included_file_not_mapping_raises(self): + self._write_file('list.yaml', '- item1\n- item2\n') + spec = textwrap.dedent('''\ + includes: + - list.yaml + workflow: + name: test + ''') + with self.assertRaises(osmo_errors.OSMOUserError) as context: + resolve_includes(spec, self.test_dir) + self.assertIn('mapping', str(context.exception)) + + def test_relative_paths_in_subdirectories(self): + self._write_file('bases/common.yaml', '''\ + workflow: + name: common + resources: + default: + cpu: 8 + ''') + spec = textwrap.dedent('''\ + includes: + - bases/common.yaml + workflow: + name: main + ''') + result = resolve_includes(spec, self.test_dir) + parsed = yaml.safe_load(result) + self.assertEqual(parsed['workflow']['resources']['default']['cpu'], 8) + + def test_version_preserved_from_main(self): + self._write_file('base.yaml', '''\ + version: 2 + workflow: + name: base + ''') + spec = textwrap.dedent('''\ + includes: + - base.yaml + version: 2 + workflow: + name: child + ''') + result = resolve_includes(spec, self.test_dir) + parsed = yaml.safe_load(result) + self.assertEqual(parsed['version'], 2) + self.assertEqual(parsed['workflow']['name'], 'child') + + def test_quoted_jinja_variables_preserved(self): + self._write_file('base.yaml', '''\ + workflow: + name: base + resources: + default: + cpu: 8 + ''') + spec = textwrap.dedent('''\ + includes: + - base.yaml + workflow: + name: "{{ workflow_name }}" + tasks: + - name: task1 + image: "my-image:{{ tag }}" + default-values: + workflow_name: my-wf + tag: latest + ''') + result = 
resolve_includes(spec, self.test_dir) + parsed = yaml.safe_load(result) + + self.assertEqual(parsed['workflow']['name'], '{{ workflow_name }}') + self.assertEqual(parsed['workflow']['tasks'][0]['image'], 'my-image:{{ tag }}') + self.assertEqual(parsed['workflow']['resources']['default']['cpu'], 8) + + def test_includes_substring_in_value_ignored(self): + spec = textwrap.dedent('''\ + workflow: + name: test + tasks: + - name: task1 + image: ubuntu + command: ["echo", "this includes: some text"] + ''') + result = resolve_includes(spec, self.test_dir) + parsed = yaml.safe_load(result) + self.assertEqual(parsed['workflow']['name'], 'test') + + def test_tasks_composed_from_multiple_includes(self): + self._write_file('tasks/preprocess.yaml', '''\ + workflow: + tasks: + - name: preprocess + image: preprocess:v1 + command: ["python3", "preprocess.py"] + ''') + self._write_file('tasks/train.yaml', '''\ + workflow: + tasks: + - name: train + image: train:v1 + command: ["python3", "train.py"] + ''') + self._write_file('tasks/evaluate.yaml', '''\ + workflow: + tasks: + - name: evaluate + image: evaluate:v1 + command: ["python3", "evaluate.py"] + ''') + spec = textwrap.dedent('''\ + includes: + - tasks/preprocess.yaml + - tasks/train.yaml + - tasks/evaluate.yaml + workflow: + name: full-pipeline + ''') + result = resolve_includes(spec, self.test_dir) + parsed = yaml.safe_load(result) + + task_names = [t['name'] for t in parsed['workflow']['tasks']] + self.assertEqual(task_names, ['preprocess', 'train', 'evaluate']) + self.assertEqual(parsed['workflow']['name'], 'full-pipeline') + + def test_task_override_from_main_file(self): + self._write_file('base_tasks.yaml', '''\ + workflow: + tasks: + - name: preprocess + image: preprocess:v1 + command: ["python3", "preprocess.py"] + - name: train + image: train:v1 + command: ["python3", "train.py"] + ''') + spec = textwrap.dedent('''\ + includes: + - base_tasks.yaml + workflow: + name: my-pipeline + tasks: + - name: train + image: 
train:v2 + ''') + result = resolve_includes(spec, self.test_dir) + parsed = yaml.safe_load(result) + + task_names = [t['name'] for t in parsed['workflow']['tasks']] + self.assertEqual(task_names, ['preprocess', 'train']) + self.assertEqual(parsed['workflow']['tasks'][0]['image'], 'preprocess:v1') + self.assertEqual(parsed['workflow']['tasks'][1]['image'], 'train:v2') + self.assertEqual(parsed['workflow']['tasks'][1]['command'], ['python3', 'train.py']) + + def test_tasks_composed_with_shared_resources(self): + self._write_file('base.yaml', '''\ + workflow: + name: base + resources: + default: + cpu: 8 + gpu: 1 + timeout: + execution: 3600 + ''') + self._write_file('tasks/step_a.yaml', '''\ + workflow: + tasks: + - name: step-a + image: step-a:latest + command: ["run"] + ''') + self._write_file('tasks/step_b.yaml', '''\ + workflow: + tasks: + - name: step-b + image: step-b:latest + command: ["run"] + ''') + spec = textwrap.dedent('''\ + includes: + - base.yaml + - tasks/step_a.yaml + - tasks/step_b.yaml + workflow: + name: composed-pipeline + ''') + result = resolve_includes(spec, self.test_dir) + parsed = yaml.safe_load(result) + + self.assertEqual(parsed['workflow']['name'], 'composed-pipeline') + self.assertEqual(parsed['workflow']['resources']['default']['cpu'], 8) + task_names = [t['name'] for t in parsed['workflow']['tasks']] + self.assertEqual(task_names, ['step-a', 'step-b']) + + +class VariableReferenceTests(unittest.TestCase): + """Unit tests for {{ ref }} task variable references in includes.""" + + def setUp(self): + self.test_dir = tempfile.mkdtemp() + + def tearDown(self): + shutil.rmtree(self.test_dir) + + def _write_file(self, relative_path: str, content: str) -> str: + full_path = os.path.join(self.test_dir, relative_path) + os.makedirs(os.path.dirname(full_path), exist_ok=True) + with open(full_path, 'w', encoding='utf-8') as file_handle: + file_handle.write(textwrap.dedent(content)) + return full_path + + def 
test_basic_task_ref_from_default_values(self): + self._write_file('base.yaml', '''\ + default-values: + preprocess: + image: preprocess:v1 + command: ["python3", "preprocess.py"] + train: + image: train:v1 + command: ["python3", "train.py"] + ''') + spec = textwrap.dedent('''\ + includes: + - base.yaml + workflow: + name: my-pipeline + tasks: + - "{{ preprocess }}" + - "{{ train }}" + ''') + result = resolve_includes(spec, self.test_dir) + parsed = yaml.safe_load(result) + + task_names = [t['name'] for t in parsed['workflow']['tasks']] + self.assertEqual(task_names, ['preprocess', 'train']) + self.assertEqual(parsed['workflow']['tasks'][0]['image'], 'preprocess:v1') + self.assertEqual(parsed['workflow']['tasks'][1]['command'], ['python3', 'train.py']) + + def test_dot_path_ref(self): + self._write_file('base.yaml', '''\ + default-values: + task_library: + preprocess: + image: preprocess:v1 + command: ["python3", "preprocess.py"] + ''') + spec = textwrap.dedent('''\ + includes: + - base.yaml + workflow: + name: pipeline + tasks: + - "{{ task_library.preprocess }}" + ''') + result = resolve_includes(spec, self.test_dir) + parsed = yaml.safe_load(result) + + task = parsed['workflow']['tasks'][0] + self.assertEqual(task['name'], 'preprocess') + self.assertEqual(task['image'], 'preprocess:v1') + + def test_ref_preserves_explicit_name(self): + self._write_file('base.yaml', '''\ + default-values: + my_task: + name: custom-name + image: img:v1 + ''') + spec = textwrap.dedent('''\ + includes: + - base.yaml + workflow: + name: test + tasks: + - "{{ my_task }}" + ''') + result = resolve_includes(spec, self.test_dir) + parsed = yaml.safe_load(result) + self.assertEqual(parsed['workflow']['tasks'][0]['name'], 'custom-name') + + def test_ref_with_named_merge_override(self): + self._write_file('base.yaml', '''\ + default-values: + preprocess: + image: preprocess:v1 + command: ["python3", "preprocess.py"] + train: + image: train:v1 + command: ["python3", "train.py"] + workflow: + 
tasks: + - "{{ preprocess }}" + - "{{ train }}" + ''') + spec = textwrap.dedent('''\ + includes: + - base.yaml + workflow: + name: my-pipeline + tasks: + - name: train + image: train:v2 + ''') + result = resolve_includes(spec, self.test_dir) + parsed = yaml.safe_load(result) + + task_names = [t['name'] for t in parsed['workflow']['tasks']] + self.assertEqual(task_names, ['preprocess', 'train']) + self.assertEqual(parsed['workflow']['tasks'][0]['image'], 'preprocess:v1') + self.assertEqual(parsed['workflow']['tasks'][1]['image'], 'train:v2') + self.assertEqual(parsed['workflow']['tasks'][1]['command'], ['python3', 'train.py']) + + def test_cross_file_ref(self): + self._write_file('tasks.yaml', '''\ + default-values: + my_task: + image: worker:v1 + command: ["run"] + ''') + spec = textwrap.dedent('''\ + includes: + - tasks.yaml + workflow: + name: pipeline + tasks: + - "{{ my_task }}" + ''') + result = resolve_includes(spec, self.test_dir) + parsed = yaml.safe_load(result) + + self.assertEqual(parsed['workflow']['tasks'][0]['name'], 'my_task') + self.assertEqual(parsed['workflow']['tasks'][0]['image'], 'worker:v1') + + def test_unresolvable_ref_left_for_jinja(self): + self._write_file('base.yaml', '''\ + workflow: + name: base + ''') + spec = textwrap.dedent('''\ + includes: + - base.yaml + workflow: + name: test + tasks: + - "{{ nonexistent }}" + ''') + result = resolve_includes(spec, self.test_dir) + parsed = yaml.safe_load(result) + self.assertEqual(parsed['workflow']['tasks'], ['{{ nonexistent }}']) + + def test_scalar_ref_left_for_jinja(self): + self._write_file('base.yaml', '''\ + default-values: + my_image: ubuntu:24.04 + ''') + spec = textwrap.dedent('''\ + includes: + - base.yaml + workflow: + name: test + tasks: + - "{{ my_image }}" + ''') + result = resolve_includes(spec, self.test_dir) + parsed = yaml.safe_load(result) + self.assertEqual(parsed['workflow']['tasks'], ['{{ my_image }}']) + + def test_ref_mixed_with_inline_tasks(self): + 
self._write_file('base.yaml', '''\ + default-values: + preprocess: + image: preprocess:v1 + command: ["python3", "preprocess.py"] + ''') + spec = textwrap.dedent('''\ + includes: + - base.yaml + workflow: + name: pipeline + tasks: + - "{{ preprocess }}" + - name: custom-task + image: custom:v1 + command: ["bash", "run.sh"] + ''') + result = resolve_includes(spec, self.test_dir) + parsed = yaml.safe_load(result) + + task_names = [t['name'] for t in parsed['workflow']['tasks']] + self.assertEqual(task_names, ['preprocess', 'custom-task']) + + def test_ref_in_group_tasks(self): + self._write_file('base.yaml', '''\ + default-values: + server: + image: server:v1 + command: ["serve"] + lead: true + client: + image: client:v1 + command: ["connect"] + ''') + spec = textwrap.dedent('''\ + includes: + - base.yaml + workflow: + name: grouped + groups: + - name: my-group + tasks: + - "{{ server }}" + - "{{ client }}" + ''') + result = resolve_includes(spec, self.test_dir) + parsed = yaml.safe_load(result) + + group = parsed['workflow']['groups'][0] + task_names = [t['name'] for t in group['tasks']] + self.assertEqual(task_names, ['server', 'client']) + self.assertTrue(group['tasks'][0]['lead']) + + def test_null_removes_task(self): + self._write_file('base.yaml', '''\ + default-values: + preprocess: + image: preprocess:v1 + command: ["python3", "preprocess.py"] + train: + image: train:v1 + command: ["python3", "train.py"] + workflow: + tasks: + - "{{ preprocess }}" + - "{{ train }}" + ''') + spec = textwrap.dedent('''\ + includes: + - base.yaml + default-values: + train: null + workflow: + name: preprocess-only + ''') + result = resolve_includes(spec, self.test_dir) + parsed = yaml.safe_load(result) + + task_names = [t['name'] for t in parsed['workflow']['tasks']] + self.assertEqual(task_names, ['preprocess']) + + def test_null_removes_from_group_tasks(self): + self._write_file('base.yaml', '''\ + default-values: + server: + image: server:v1 + lead: true + client: + image: 
client:v1 + workflow: + groups: + - name: my-group + tasks: + - "{{ server }}" + - "{{ client }}" + ''') + spec = textwrap.dedent('''\ + includes: + - base.yaml + default-values: + client: null + workflow: + name: server-only + ''') + result = resolve_includes(spec, self.test_dir) + parsed = yaml.safe_load(result) + + group = parsed['workflow']['groups'][0] + task_names = [t['name'] for t in group['tasks']] + self.assertEqual(task_names, ['server']) + + def test_null_with_multiple_includes(self): + self._write_file('tasks_a.yaml', '''\ + default-values: + task_a: + image: a:v1 + command: ["run_a"] + ''') + self._write_file('tasks_b.yaml', '''\ + default-values: + task_b: + image: b:v1 + command: ["run_b"] + task_c: + image: c:v1 + command: ["run_c"] + ''') + spec = textwrap.dedent('''\ + includes: + - tasks_a.yaml + - tasks_b.yaml + default-values: + task_b: null + workflow: + name: selective + tasks: + - "{{ task_a }}" + - "{{ task_b }}" + - "{{ task_c }}" + ''') + result = resolve_includes(spec, self.test_dir) + parsed = yaml.safe_load(result) + + task_names = [t['name'] for t in parsed['workflow']['tasks']] + self.assertEqual(task_names, ['task_a', 'task_c']) + + +class ResolveDefaultValuesTests(unittest.TestCase): + """Unit tests for resolve_default_values.""" + + def test_basic_substitution(self): + spec = textwrap.dedent('''\ + default-values: + greeting: hello + workflow: + name: "{{greeting}}-world" + ''') + result = resolve_default_values(spec) + self.assertNotIn('default-values', result) + self.assertIn('hello-world', result) + + def test_single_brace_substitution(self): + spec = textwrap.dedent('''\ + default-values: + base_dir: /data + workflow: + name: test + tasks: + - name: task1 + args: ["{base_dir}/output"] + ''') + result = resolve_default_values(spec) + self.assertIn('/data/output', result) + + def test_nested_variable_resolution(self): + spec = textwrap.dedent('''\ + default-values: + root: /opt + sub_dir: "{root}/app" + data_dir: 
"{sub_dir}/data" + workflow: + name: test + tasks: + - name: task1 + args: ["{data_dir}/file.txt"] + ''') + result = resolve_default_values(spec) + self.assertIn('/opt/app/data/file.txt', result) + self.assertNotIn('{root}', result) + self.assertNotIn('{sub_dir}', result) + + def test_none_value_leaves_references_unresolved(self): + """A None value in default-values leaves {var} references in the output.""" + spec = textwrap.dedent('''\ + default-values: + missing_var: + sub_dir: "{missing_var}/app" + workflow: + name: test + tasks: + - name: task1 + args: ["{sub_dir}/data"] + ''') + result = resolve_default_values(spec) + self.assertIn('{missing_var}', result) + + def test_env_ref_unset_resolves_to_empty(self): + """An unset ${env:VAR} in a quoted YAML value resolves to empty string.""" + spec = textwrap.dedent('''\ + default-values: + root: "${env:__OSMO_TEST_UNSET_VAR_12345__}" + local_dir: "{root}/local" + workflow: + name: test + tasks: + - name: task1 + args: ["{local_dir}/output"] + ''') + result = resolve_default_values(spec) + self.assertIn('/local/output', result) + self.assertNotIn('${env:', result) + + def test_env_ref_set_resolves(self): + """A set ${env:VAR} resolves correctly through variable chain.""" + original = os.environ.get('__OSMO_TEST_VAR__') + os.environ['__OSMO_TEST_VAR__'] = '/my/path' + try: + spec = textwrap.dedent('''\ + default-values: + root: "${env:__OSMO_TEST_VAR__}" + sub: "{root}/data" + workflow: + name: test + tasks: + - name: task1 + args: ["{sub}/file"] + ''') + result = resolve_default_values(spec) + self.assertIn('/my/path/data/file', result) + self.assertNotIn('{root}', result) + self.assertNotIn('${env:', result) + finally: + if original is None: + del os.environ['__OSMO_TEST_VAR__'] + else: + os.environ['__OSMO_TEST_VAR__'] = original + + def test_no_default_values_returns_text(self): + spec = 'workflow:\n name: test\n' + result = resolve_default_values(spec) + self.assertEqual(result, spec) + + +class 
FindUnresolvedEnvVariablesTests(unittest.TestCase): + """Unit tests for find_unresolved_env_variables.""" + + def test_no_default_values(self): + spec = 'workflow:\n name: test\n' + self.assertEqual(find_unresolved_env_variables(spec), {}) + + def test_no_env_refs(self): + spec = textwrap.dedent('''\ + default-values: + greeting: hello + workflow: + name: test + ''') + self.assertEqual(find_unresolved_env_variables(spec), {}) + + def test_set_env_var_not_reported(self): + original = os.environ.get('__OSMO_TEST_SET_VAR__') + os.environ['__OSMO_TEST_SET_VAR__'] = '/some/path' + try: + spec = textwrap.dedent('''\ + default-values: + root: "${env:__OSMO_TEST_SET_VAR__}" + workflow: + name: test + ''') + self.assertEqual(find_unresolved_env_variables(spec), {}) + finally: + if original is None: + del os.environ['__OSMO_TEST_SET_VAR__'] + else: + os.environ['__OSMO_TEST_SET_VAR__'] = original + + def test_unset_env_var_reported(self): + spec = textwrap.dedent('''\ + default-values: + root: "${env:__OSMO_TEST_DEFINITELY_UNSET__}" + workflow: + name: test + ''') + result = find_unresolved_env_variables(spec) + self.assertEqual(result, {'root': '__OSMO_TEST_DEFINITELY_UNSET__'}) + + def test_multiple_unset_env_vars(self): + spec = textwrap.dedent('''\ + default-values: + var_a: "${env:__OSMO_UNSET_A__}" + var_b: "${env:__OSMO_UNSET_B__}" + var_c: "no-env-ref" + workflow: + name: test + ''') + result = find_unresolved_env_variables(spec) + self.assertEqual(result, { + 'var_a': '__OSMO_UNSET_A__', + 'var_b': '__OSMO_UNSET_B__', + }) + + def test_dict_entries_ignored(self): + spec = textwrap.dedent('''\ + default-values: + my_task: + name: task1 + image: ubuntu + root: "${env:__OSMO_UNSET_X__}" + workflow: + name: test + ''') + result = find_unresolved_env_variables(spec) + self.assertEqual(result, {'root': '__OSMO_UNSET_X__'}) + + +if __name__ == '__main__': + unittest.main()