From d93d9d33e387d5ed5b72031494741e12bdf749e9 Mon Sep 17 00:00:00 2001 From: Ethan Yu Date: Mon, 30 Mar 2026 17:33:07 -0700 Subject: [PATCH] Add unit test coverage for utils/job module Add 6 new test files covering previously untested code paths across common.py, backend_job_defs.py, task.py, workflow.py, kb_objects.py, kb_methods.py, jobs_base.py, app.py, and task_io.py. Tests cover model validation, enum methods, utility functions, and K8s object factory logic. Co-Authored-By: Claude Opus 4.6 (1M context) --- src/utils/job/tests/BUILD | 54 +++ src/utils/job/tests/test_backend_job_defs.py | 119 +++++ src/utils/job/tests/test_common_utils.py | 111 +++++ src/utils/job/tests/test_jobs_base_unit.py | 86 ++++ src/utils/job/tests/test_kb_objects_unit.py | 289 ++++++++++++ src/utils/job/tests/test_task_models.py | 467 +++++++++++++++++++ src/utils/job/tests/test_workflow_models.py | 223 +++++++++ 7 files changed, 1349 insertions(+) create mode 100644 src/utils/job/tests/test_backend_job_defs.py create mode 100644 src/utils/job/tests/test_common_utils.py create mode 100644 src/utils/job/tests/test_jobs_base_unit.py create mode 100644 src/utils/job/tests/test_kb_objects_unit.py create mode 100644 src/utils/job/tests/test_task_models.py create mode 100644 src/utils/job/tests/test_workflow_models.py diff --git a/src/utils/job/tests/BUILD b/src/utils/job/tests/BUILD index df70621ed..7a12be69d 100644 --- a/src/utils/job/tests/BUILD +++ b/src/utils/job/tests/BUILD @@ -106,3 +106,57 @@ osmo_py_test( size = "large", tags = ["requires-network"], ) + +py_test( + name = "test_common_utils", + srcs = ["test_common_utils.py"], + deps = [ + "//src/lib/data/storage", + "//src/lib/utils:osmo_errors", + "//src/utils/job", + ], +) + +py_test( + name = "test_backend_job_defs", + srcs = ["test_backend_job_defs.py"], + deps = [ + "//src/utils/job:backend_job", + ], +) + +py_test( + name = "test_task_models", + srcs = ["test_task_models.py"], + deps = [ + "//src/lib/utils:osmo_errors", + 
"//src/utils/job", + ], +) + +py_test( + name = "test_workflow_models", + srcs = ["test_workflow_models.py"], + deps = [ + "//src/utils/connectors", + "//src/utils/job", + ], +) + +py_test( + name = "test_kb_objects_unit", + srcs = ["test_kb_objects_unit.py"], + deps = [ + "//src/lib/utils:priority", + "//src/utils/job", + "//src/utils/job:backend_job", + ], +) + +py_test( + name = "test_jobs_base_unit", + srcs = ["test_jobs_base_unit.py"], + deps = [ + "//src/utils/job:jobs_base", + ], +) diff --git a/src/utils/job/tests/test_backend_job_defs.py b/src/utils/job/tests/test_backend_job_defs.py new file mode 100644 index 000000000..88b817d61 --- /dev/null +++ b/src/utils/job/tests/test_backend_job_defs.py @@ -0,0 +1,119 @@ +""" +SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+ +SPDX-License-Identifier: Apache-2.0 +""" +import unittest + +from src.utils.job import backend_job_defs + + +class EffectiveApiVersionTest(unittest.TestCase): + def test_generic_api_preferred(self): + spec = backend_job_defs.BackendCleanupSpec( + labels={'app': 'test'}, + generic_api=backend_job_defs.BackendGenericApi( + api_version='scheduling.run.ai/v2', kind='Queue'), + custom_api=backend_job_defs.BackendCustomApi( + api_major='old.api', api_minor='v1', path='queues')) + self.assertEqual(spec.effective_api_version, 'scheduling.run.ai/v2') + + def test_custom_api_fallback(self): + spec = backend_job_defs.BackendCleanupSpec( + labels={'app': 'test'}, + custom_api=backend_job_defs.BackendCustomApi( + api_major='scheduling.run.ai', api_minor='v2alpha2', path='podgroups')) + self.assertEqual(spec.effective_api_version, 'scheduling.run.ai/v2alpha2') + + def test_default_v1_when_no_api(self): + spec = backend_job_defs.BackendCleanupSpec(labels={'app': 'test'}) + self.assertEqual(spec.effective_api_version, 'v1') + + +class EffectiveKindTest(unittest.TestCase): + def test_generic_api_preferred(self): + spec = backend_job_defs.BackendCleanupSpec( + labels={'app': 'test'}, + resource_type='Pod', + generic_api=backend_job_defs.BackendGenericApi( + api_version='v1', kind='Service')) + self.assertEqual(spec.effective_kind, 'Service') + + def test_resource_type_fallback(self): + spec = backend_job_defs.BackendCleanupSpec( + labels={'app': 'test'}, + resource_type='Pod') + self.assertEqual(spec.effective_kind, 'Pod') + + def test_none_when_nothing_set(self): + spec = backend_job_defs.BackendCleanupSpec(labels={'app': 'test'}) + self.assertIsNone(spec.effective_kind) + + +class K8sSelectorTest(unittest.TestCase): + def test_single_label(self): + spec = backend_job_defs.BackendCleanupSpec(labels={'app': 'osmo'}) + self.assertEqual(spec.k8s_selector, 'app=osmo') + + def test_multiple_labels(self): + spec = backend_job_defs.BackendCleanupSpec( + labels={'app': 'osmo', 'env': 
'prod'}) + parts = spec.k8s_selector.split(',') + self.assertEqual(len(parts), 2) + self.assertIn('app=osmo', parts) + self.assertIn('env=prod', parts) + + +class BackendCreateGroupMixinTest(unittest.TestCase): + def test_default_values(self): + mixin = backend_job_defs.BackendCreateGroupMixin( + group_name='group1', k8s_resources=[{'kind': 'Pod'}]) + self.assertEqual(mixin.backend_k8s_timeout, 60) + self.assertEqual(mixin.scheduler_settings, {}) + + +class BackendGenericApiTest(unittest.TestCase): + def test_creation(self): + api = backend_job_defs.BackendGenericApi(api_version='v1', kind='Pod') + self.assertEqual(api.api_version, 'v1') + self.assertEqual(api.kind, 'Pod') + + +class BackendCustomApiTest(unittest.TestCase): + def test_creation(self): + api = backend_job_defs.BackendCustomApi( + api_major='scheduling.run.ai', api_minor='v2alpha2', path='podgroups') + self.assertEqual(api.api_major, 'scheduling.run.ai') + self.assertEqual(api.api_minor, 'v2alpha2') + self.assertEqual(api.path, 'podgroups') + + +class BackendSynchronizeQueuesMixinTest(unittest.TestCase): + def test_default_immutable_kinds(self): + mixin = backend_job_defs.BackendSynchronizeQueuesMixin( + cleanup_specs=[], + k8s_resources=[]) + self.assertEqual(mixin.immutable_kinds, []) + + def test_with_immutable_kinds(self): + mixin = backend_job_defs.BackendSynchronizeQueuesMixin( + cleanup_specs=[], + k8s_resources=[], + immutable_kinds=['Topology']) + self.assertEqual(mixin.immutable_kinds, ['Topology']) + + +if __name__ == '__main__': + unittest.main() diff --git a/src/utils/job/tests/test_common_utils.py b/src/utils/job/tests/test_common_utils.py new file mode 100644 index 000000000..c3bc3a3bc --- /dev/null +++ b/src/utils/job/tests/test_common_utils.py @@ -0,0 +1,111 @@ +""" +SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. + +SPDX-License-Identifier: Apache-2.0 +""" +import datetime +import os +import unittest + +from src.lib.data.storage.backends.common import StoragePath +from src.lib.utils import osmo_errors +from src.utils.job import common + + +class GetWorkflowLogsPathTest(unittest.TestCase): + def test_returns_joined_path(self): + result = common.get_workflow_logs_path('wf-123', 'output.log') + self.assertEqual(result, os.path.join('wf-123', 'output.log')) + + def test_different_filenames(self): + result = common.get_workflow_logs_path('my-workflow', 'error.log') + self.assertEqual(result, os.path.join('my-workflow', 'error.log')) + + +class GetWorkflowAppPathTest(unittest.TestCase): + def test_without_prefix(self): + path_params = StoragePath( + scheme='s3', host='bucket', endpoint_url='http://s3', + container='bucket', region='us-east-1', prefix='') + result = common.get_workflow_app_path('app-uuid-123', 1, path_params) + self.assertEqual(result, os.path.join('app-uuid-123', '1', 'workflow_app.txt')) + + def test_with_prefix(self): + path_params = StoragePath( + scheme='s3', host='bucket', endpoint_url='http://s3', + container='bucket', region='us-east-1', prefix='my-prefix') + result = common.get_workflow_app_path('app-uuid-123', 2, path_params) + self.assertEqual( + result, + os.path.join('my-prefix', 'app-uuid-123', '2', 'workflow_app.txt')) + + def test_version_is_stringified(self): + path_params = StoragePath( + scheme='s3', host='bucket', 
endpoint_url='http://s3', + container='bucket', region='us-east-1', prefix='') + result = common.get_workflow_app_path('uuid', 42, path_params) + self.assertIn('42', result) + + +class CalculateTotalTimeoutTest(unittest.TestCase): + def test_returns_sum_of_timeouts(self): + queue = datetime.timedelta(seconds=300) + execution = datetime.timedelta(seconds=600) + result = common.calculate_total_timeout('wf-1', queue, execution) + self.assertEqual(result, 900) + + def test_raises_without_exec_timeout(self): + with self.assertRaises(osmo_errors.OSMODatabaseError): + common.calculate_total_timeout( + 'wf-1', queue_timeout=datetime.timedelta(seconds=300)) + + def test_raises_without_queue_timeout(self): + with self.assertRaises(osmo_errors.OSMODatabaseError): + common.calculate_total_timeout( + 'wf-1', exec_timeout=datetime.timedelta(seconds=300)) + + def test_raises_with_both_none(self): + with self.assertRaises(osmo_errors.OSMODatabaseError): + common.calculate_total_timeout('wf-1') + + def test_truncates_to_int(self): + queue = datetime.timedelta(seconds=1.7) + execution = datetime.timedelta(seconds=2.9) + result = common.calculate_total_timeout('wf-1', queue, execution) + self.assertEqual(result, int(1.7) + int(2.9)) + + +class BarrierKeyTest(unittest.TestCase): + def test_format(self): + result = common.barrier_key('wf-1', 'group-a', 'sync') + self.assertEqual(result, 'client-connections:wf-1:group-a:barrier-sync') + + def test_different_inputs(self): + result = common.barrier_key('workflow', 'train', 'ready') + self.assertEqual(result, 'client-connections:workflow:train:barrier-ready') + + +class WorkflowPluginsTest(unittest.TestCase): + def test_default_rsync_false(self): + plugins = common.WorkflowPlugins() + self.assertFalse(plugins.rsync) + + def test_rsync_enabled(self): + plugins = common.WorkflowPlugins(rsync=True) + self.assertTrue(plugins.rsync) + + +if __name__ == '__main__': + unittest.main() diff --git a/src/utils/job/tests/test_jobs_base_unit.py 
b/src/utils/job/tests/test_jobs_base_unit.py new file mode 100644 index 000000000..6ab9015d9 --- /dev/null +++ b/src/utils/job/tests/test_jobs_base_unit.py @@ -0,0 +1,86 @@ +""" +SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. + +SPDX-License-Identifier: Apache-2.0 +""" +import datetime +import unittest +from unittest import mock + +from src.utils.job import jobs_base + + +# --------------------------------------------------------------------------- +# JobStatus +# --------------------------------------------------------------------------- +class JobStatusTest(unittest.TestCase): + def test_values(self): + self.assertEqual(jobs_base.JobStatus.SUCCESS.value, 'SUCCESS') + self.assertEqual(jobs_base.JobStatus.FAILED_RETRY.value, 'FAILED_RETRY') + self.assertEqual(jobs_base.JobStatus.FAILED_NO_RETRY.value, 'FAILED_NO_RETRY') + + +# --------------------------------------------------------------------------- +# JobResult +# --------------------------------------------------------------------------- +class JobResultTest(unittest.TestCase): + def test_success_result(self): + result = jobs_base.JobResult(status=jobs_base.JobStatus.SUCCESS, message=None) + self.assertFalse(result.retry) + + def test_retry_result(self): + result = jobs_base.JobResult( + status=jobs_base.JobStatus.FAILED_RETRY, message='network error') + self.assertTrue(result.retry) + + def test_no_retry_result(self): + result = 
jobs_base.JobResult( + status=jobs_base.JobStatus.FAILED_NO_RETRY, message='bad input') + self.assertFalse(result.retry) + + def test_str_with_message(self): + result = jobs_base.JobResult( + status=jobs_base.JobStatus.FAILED_RETRY, message='timeout') + self.assertEqual(str(result), 'FAILED_RETRY: timeout') + + def test_str_without_message(self): + result = jobs_base.JobResult( + status=jobs_base.JobStatus.SUCCESS, message=None) + self.assertEqual(str(result), 'SUCCESS') + + +# --------------------------------------------------------------------------- +# update_progress_writer +# --------------------------------------------------------------------------- +class UpdateProgressWriterTest(unittest.TestCase): + def test_does_not_report_when_time_not_elapsed(self): + writer = mock.MagicMock() + last = datetime.datetime.now() + freq = datetime.timedelta(hours=1) + result = jobs_base.update_progress_writer(writer, last, freq) + writer.report_progress.assert_not_called() + self.assertEqual(result, last) + + +# --------------------------------------------------------------------------- +# UNIQUE_JOB_TTL constant +# --------------------------------------------------------------------------- +class UniqueJobTtlTest(unittest.TestCase): + def test_value(self): + self.assertEqual(jobs_base.UNIQUE_JOB_TTL, 5 * 24 * 60 * 60) + + +if __name__ == '__main__': + unittest.main() diff --git a/src/utils/job/tests/test_kb_objects_unit.py b/src/utils/job/tests/test_kb_objects_unit.py new file mode 100644 index 000000000..a427e75bc --- /dev/null +++ b/src/utils/job/tests/test_kb_objects_unit.py @@ -0,0 +1,289 @@ +""" +SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. + +SPDX-License-Identifier: Apache-2.0 +""" +import base64 +import json +import unittest +from unittest import mock + +from src.lib.utils import priority as wf_priority +from src.utils.job import backend_job_defs, kb_objects + + +# --------------------------------------------------------------------------- +# k8s_name +# --------------------------------------------------------------------------- +class K8sNameTest(unittest.TestCase): + def test_lowercase(self): + self.assertEqual(kb_objects.k8s_name('MyTask'), 'mytask') + + def test_underscores_to_hyphens(self): + self.assertEqual(kb_objects.k8s_name('my_task_name'), 'my-task-name') + + def test_already_valid(self): + self.assertEqual(kb_objects.k8s_name('valid-name'), 'valid-name') + + def test_mixed(self): + self.assertEqual(kb_objects.k8s_name('My_Task_Name'), 'my-task-name') + + +# --------------------------------------------------------------------------- +# construct_pod_name +# --------------------------------------------------------------------------- +class ConstructPodNameTest(unittest.TestCase): + def test_format(self): + wf_uuid = 'abcdef1234567890extra' + task_uuid = '0987654321fedcbaextra' + result = kb_objects.construct_pod_name(wf_uuid, task_uuid) + self.assertEqual(result, 'abcdef1234567890-0987654321fedcba') + + def test_truncation(self): + result = kb_objects.construct_pod_name('a' * 40, 'b' * 40) + parts = result.split('-') + self.assertEqual(len(parts), 2) + self.assertEqual(len(parts[0]), 16) + self.assertEqual(len(parts[1]), 16) + + +# 
--------------------------------------------------------------------------- +# K8sObjectFactory +# --------------------------------------------------------------------------- +class K8sObjectFactoryTest(unittest.TestCase): + def setUp(self): + self.factory = kb_objects.K8sObjectFactory(scheduler_name='default-scheduler') + + def test_create_secret(self): + secret = self.factory.create_secret( + name='my-secret', + labels={'app': 'osmo'}, + data={'key': 'dmFsdWU='}, + string_data={'plain': 'value'}) + self.assertEqual(secret['kind'], 'Secret') + self.assertEqual(secret['metadata']['name'], 'my-secret') + self.assertEqual(secret['metadata']['labels']['app'], 'osmo') + self.assertEqual(secret['data']['key'], 'dmFsdWU=') + self.assertEqual(secret['stringData']['plain'], 'value') + self.assertEqual(secret['type'], 'Opaque') + + def test_create_secret_custom_type(self): + secret = self.factory.create_secret( + name='tls-secret', labels={}, data={}, string_data={}, + secret_type='kubernetes.io/tls') + self.assertEqual(secret['type'], 'kubernetes.io/tls') + + def test_create_headless_service(self): + svc = self.factory.create_headless_service( + name='my-group', labels={'osmo.workflow': 'wf1'}) + self.assertEqual(svc['kind'], 'Service') + self.assertEqual(svc['spec']['clusterIP'], 'None') + self.assertEqual(svc['spec']['selector']['osmo.workflow'], 'wf1') + + def test_create_config_map(self): + cm = self.factory.create_config_map( + name='my-config', labels={'app': 'osmo'}, + data={'key1': 'val1', 'key2': 'val2'}) + self.assertEqual(cm['kind'], 'ConfigMap') + self.assertEqual(cm['metadata']['name'], 'my-config') + self.assertEqual(cm['data']['key1'], 'val1') + + def test_create_image_secret(self): + cred = {'registry.io': {'username': 'user', 'password': 'pass'}} + secret = self.factory.create_image_secret( + secret_name='reg-secret', labels={'app': 'osmo'}, cred=cred) + self.assertEqual(secret['kind'], 'Secret') + self.assertEqual(secret['type'], 
'kubernetes.io/dockerconfigjson') + decoded = json.loads( + base64.b64decode(secret['data']['.dockerconfigjson']).decode('utf-8')) + self.assertEqual(decoded['auths']['registry.io']['username'], 'user') + + def test_priority_supported(self): + self.assertFalse(self.factory.priority_supported()) + + def test_topology_supported(self): + self.assertFalse(self.factory.topology_supported()) + + def test_retry_allowed(self): + self.assertTrue(self.factory.retry_allowed()) + + def test_list_scheduler_resources_spec_empty(self): + backend = mock.MagicMock() + self.assertEqual(self.factory.list_scheduler_resources_spec(backend), []) + + def test_list_immutable_scheduler_resources_empty(self): + self.assertEqual(self.factory.list_immutable_scheduler_resources(), []) + + def test_get_scheduler_resources_spec_empty(self): + backend = mock.MagicMock() + self.assertEqual(self.factory.get_scheduler_resources_spec(backend, []), []) + + def test_get_group_cleanup_specs(self): + labels = {'app': 'osmo'} + specs = self.factory.get_group_cleanup_specs(labels) + self.assertEqual(len(specs), 1) + self.assertEqual(specs[0].resource_type, 'Pod') + self.assertEqual(specs[0].labels, labels) + + def test_get_error_log_specs(self): + labels = {'app': 'osmo'} + spec = self.factory.get_error_log_specs(labels) + self.assertEqual(spec.resource_type, 'Pod') + self.assertEqual(spec.labels, labels) + + def test_update_pod_k8s_resource(self): + pod: dict = {'spec': {}} + self.factory.update_pod_k8s_resource( + pod, 'group-uuid', 'pool1', wf_priority.WorkflowPriority.NORMAL) + self.assertEqual(pod['spec']['schedulerName'], 'default-scheduler') + + def test_create_group_k8s_resources(self): + pods: list = [{'spec': {}, 'metadata': {'labels': {}}}] + result = self.factory.create_group_k8s_resources( + group_uuid='g1', pods=pods, labels={}, pool_name='pool1', + priority=wf_priority.WorkflowPriority.NORMAL, + topology_keys=[], task_infos=[]) + self.assertEqual(len(result), 1) + 
self.assertEqual(result[0]['spec']['schedulerName'], 'default-scheduler') + + +# --------------------------------------------------------------------------- +# FileMount +# --------------------------------------------------------------------------- +class FileMountTest(unittest.TestCase): + def setUp(self): + self.factory = kb_objects.K8sObjectFactory(scheduler_name='default') + + def test_creation_and_digest(self): + fm = kb_objects.FileMount( + group_uid='group-uuid-1234567890', + path='/home/user/config.yaml', + content='key: value', + k8s_factory=self.factory) + self.assertNotEqual(fm.digest, '') + self.assertTrue(fm.name.startswith('osmo-')) + + def test_digest_cannot_be_set(self): + with self.assertRaises(Exception): + kb_objects.FileMount( + group_uid='group-uuid', + path='/home/test.txt', + content='data', + digest='custom-digest', + k8s_factory=self.factory) + + def test_custom_digest(self): + fm = kb_objects.FileMount( + group_uid='group-uuid-1234567890', + path='/test/file.txt', + content='data', + k8s_factory=self.factory) + original_digest = fm.digest + fm.custom_digest('my-hash-string') + self.assertNotEqual(fm.digest, original_digest) + + def test_name_uses_truncated_group_uid(self): + fm = kb_objects.FileMount( + group_uid='abcdefghijklmnopqrstuvwxyz', + path='/test.txt', content='data', + k8s_factory=self.factory) + self.assertIn('abcdefghijklmnop', fm.name) + + def test_volume(self): + fm = kb_objects.FileMount( + group_uid='group-uuid-1234567890', + path='/test/file.txt', content='data', + k8s_factory=self.factory) + vol = fm.volume() + self.assertEqual(vol['name'], fm.name) + self.assertEqual(vol['secret']['secretName'], fm.name) + + def test_volume_mount(self): + fm = kb_objects.FileMount( + group_uid='group-uuid-1234567890', + path='/test/file.txt', content='data', + k8s_factory=self.factory) + vm = fm.volume_mount() + self.assertEqual(vm['mountPath'], '/test/file.txt') + self.assertEqual(vm['subPath'], 'file.txt') + 
self.assertEqual(vm['name'], fm.name) + + def test_secret(self): + fm = kb_objects.FileMount( + group_uid='group-uuid-1234567890', + path='/test/file.txt', content='data', + k8s_factory=self.factory) + labels = {'app': 'osmo'} + secret = fm.secret(labels) + self.assertEqual(secret['kind'], 'Secret') + self.assertIn('file.txt', secret['data']) + + def test_different_content_different_digest(self): + fm1 = kb_objects.FileMount( + group_uid='group', path='/test.txt', content='data1', + k8s_factory=self.factory) + fm2 = kb_objects.FileMount( + group_uid='group', path='/test.txt', content='data2', + k8s_factory=self.factory) + self.assertNotEqual(fm1.digest, fm2.digest) + + +# --------------------------------------------------------------------------- +# get_k8s_object_factory +# --------------------------------------------------------------------------- +class GetK8sObjectFactoryTest(unittest.TestCase): + def test_unsupported_scheduler_raises(self): + backend = mock.MagicMock() + backend.scheduler_settings.scheduler_type = 'UNSUPPORTED' + with self.assertRaises(Exception): + kb_objects.get_k8s_object_factory(backend) + + +# --------------------------------------------------------------------------- +# kb_methods (CustomObject stubs) +# --------------------------------------------------------------------------- +class KbMethodsStubTest(unittest.TestCase): + def test_custom_object_metadata_stub(self): + from src.utils.job.kb_methods import CustomObjectMetadataStub + stub = CustomObjectMetadataStub(name='test-pod') + self.assertEqual(stub.name, 'test-pod') + + def test_custom_object_stub(self): + from src.utils.job.kb_methods import CustomObjectStub, CustomObjectMetadataStub + meta = CustomObjectMetadataStub(name='pod-1') + stub = CustomObjectStub(metadata=meta) + self.assertEqual(stub.metadata.name, 'pod-1') + + def test_custom_object_list_stub(self): + from src.utils.job.kb_methods import ( + CustomObjectListStub, CustomObjectStub, CustomObjectMetadataStub) + items = [ 
+ CustomObjectStub(metadata=CustomObjectMetadataStub(name='a')), + CustomObjectStub(metadata=CustomObjectMetadataStub(name='b')), + ] + stub = CustomObjectListStub(items=items) + self.assertEqual(len(stub.items), 2) + self.assertEqual(stub.items[0].metadata.name, 'a') + + def test_kb_methods_factory_raises_for_none_kind(self): + from unittest import mock + from src.utils.job import kb_methods + spec = backend_job_defs.BackendCleanupSpec(labels={'app': 'test'}) + with self.assertRaises(ValueError): + kb_methods.kb_methods_factory(mock.MagicMock(), spec) + + +if __name__ == '__main__': + unittest.main() diff --git a/src/utils/job/tests/test_task_models.py b/src/utils/job/tests/test_task_models.py new file mode 100644 index 000000000..4d7218d7f --- /dev/null +++ b/src/utils/job/tests/test_task_models.py @@ -0,0 +1,467 @@ +""" +SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+ +SPDX-License-Identifier: Apache-2.0 +""" +import copy +import datetime +import unittest + +from src.lib.utils import osmo_errors +from src.utils.job import common as task_common, task +from src.utils.job.app import AppStatus + + +# --------------------------------------------------------------------------- +# AppStatus +# --------------------------------------------------------------------------- +class AppStatusDeletedTest(unittest.TestCase): + def test_deleted_states(self): + self.assertTrue(AppStatus.DELETED.deleted()) + self.assertTrue(AppStatus.PENDING_DELETE.deleted()) + + def test_non_deleted_states(self): + self.assertFalse(AppStatus.PENDING.deleted()) + self.assertFalse(AppStatus.READY.deleted()) + + +# --------------------------------------------------------------------------- +# ExitAction / ExitCode enums +# --------------------------------------------------------------------------- +class ExitActionTest(unittest.TestCase): + def test_values(self): + self.assertEqual(task.ExitAction.COMPLETED.value, 'COMPLETE') + self.assertEqual(task.ExitAction.FAILED.value, 'FAIL') + self.assertEqual(task.ExitAction.RESCHEDULED.value, 'RESCHEDULE') + + +class ExitCodeTest(unittest.TestCase): + def test_known_codes(self): + self.assertEqual(task.ExitCode.FAILED_PREFLIGHT.value, 1001) + self.assertEqual(task.ExitCode.FAILED_UPSTREAM.value, 3000) + self.assertEqual(task.ExitCode.FAILED_UNKNOWN.value, 4000) + + +# --------------------------------------------------------------------------- +# TaskGroupStatus +# --------------------------------------------------------------------------- +class TaskGroupStatusMethodsTest(unittest.TestCase): + def test_finished_for_completed(self): + self.assertTrue(task.TaskGroupStatus.COMPLETED.finished()) + + def test_finished_for_rescheduled(self): + self.assertTrue(task.TaskGroupStatus.RESCHEDULED.finished()) + + def test_finished_for_failed(self): + self.assertTrue(task.TaskGroupStatus.FAILED.finished()) + + def 
test_not_finished_for_running(self): + self.assertFalse(task.TaskGroupStatus.RUNNING.finished()) + + def test_group_finished_for_completed(self): + self.assertTrue(task.TaskGroupStatus.COMPLETED.group_finished()) + + def test_group_finished_for_failed(self): + self.assertTrue(task.TaskGroupStatus.FAILED.group_finished()) + + def test_group_finished_false_for_rescheduled(self): + self.assertFalse(task.TaskGroupStatus.RESCHEDULED.group_finished()) + + def test_failed(self): + for status in task.TaskGroupStatus: + if status.name.startswith('FAILED'): + self.assertTrue(status.failed(), f'{status.name} should be failed') + else: + self.assertFalse(status.failed(), f'{status.name} should not be failed') + + def test_prescheduling(self): + prescheduling = {task.TaskGroupStatus.SUBMITTING, + task.TaskGroupStatus.WAITING, + task.TaskGroupStatus.PROCESSING} + for status in task.TaskGroupStatus: + self.assertEqual(status.prescheduling(), status in prescheduling, + f'{status.name}') + + def test_in_queue(self): + in_queue = {task.TaskGroupStatus.SUBMITTING, + task.TaskGroupStatus.WAITING, + task.TaskGroupStatus.PROCESSING, + task.TaskGroupStatus.SCHEDULING} + for status in task.TaskGroupStatus: + self.assertEqual(status.in_queue(), status in in_queue, + f'{status.name}') + + def test_prerunning(self): + prerunning = {task.TaskGroupStatus.SUBMITTING, + task.TaskGroupStatus.WAITING, + task.TaskGroupStatus.PROCESSING, + task.TaskGroupStatus.SCHEDULING, + task.TaskGroupStatus.INITIALIZING} + for status in task.TaskGroupStatus: + self.assertEqual(status.prerunning(), status in prerunning, + f'{status.name}') + + def test_canceled(self): + canceled = {task.TaskGroupStatus.FAILED_CANCELED, + task.TaskGroupStatus.FAILED_EXEC_TIMEOUT, + task.TaskGroupStatus.FAILED_QUEUE_TIMEOUT} + for status in task.TaskGroupStatus: + self.assertEqual(status.canceled(), status in canceled, + f'{status.name}') + + def test_server_errored(self): + server_errored = 
{task.TaskGroupStatus.FAILED_SERVER_ERROR, + task.TaskGroupStatus.FAILED_EVICTED, + task.TaskGroupStatus.FAILED_START_ERROR, + task.TaskGroupStatus.FAILED_IMAGE_PULL} + for status in task.TaskGroupStatus: + self.assertEqual(status.server_errored(), status in server_errored, + f'{status.name}') + + def test_has_error_logs_for_rescheduled(self): + self.assertTrue(task.TaskGroupStatus.RESCHEDULED.has_error_logs()) + + def test_has_error_logs_for_regular_failure(self): + self.assertTrue(task.TaskGroupStatus.FAILED.has_error_logs()) + self.assertTrue(task.TaskGroupStatus.FAILED_BACKEND_ERROR.has_error_logs()) + self.assertTrue(task.TaskGroupStatus.FAILED_PREEMPTED.has_error_logs()) + + def test_no_error_logs_for_server_errored(self): + self.assertFalse(task.TaskGroupStatus.FAILED_SERVER_ERROR.has_error_logs()) + self.assertFalse(task.TaskGroupStatus.FAILED_EVICTED.has_error_logs()) + self.assertFalse(task.TaskGroupStatus.FAILED_IMAGE_PULL.has_error_logs()) + + def test_no_error_logs_for_canceled(self): + self.assertFalse(task.TaskGroupStatus.FAILED_CANCELED.has_error_logs()) + self.assertFalse(task.TaskGroupStatus.FAILED_EXEC_TIMEOUT.has_error_logs()) + self.assertFalse(task.TaskGroupStatus.FAILED_QUEUE_TIMEOUT.has_error_logs()) + + def test_no_error_logs_for_upstream(self): + self.assertFalse(task.TaskGroupStatus.FAILED_UPSTREAM.has_error_logs()) + + def test_no_error_logs_for_running(self): + self.assertFalse(task.TaskGroupStatus.RUNNING.has_error_logs()) + + def test_backend_states(self): + states = task.TaskGroupStatus.backend_states() + self.assertIn('SCHEDULING', states) + self.assertIn('RUNNING', states) + + def test_get_alive_statuses(self): + alive = task.TaskGroupStatus.get_alive_statuses() + self.assertIn(task.TaskGroupStatus.SUBMITTING, alive) + self.assertIn(task.TaskGroupStatus.RUNNING, alive) + self.assertNotIn(task.TaskGroupStatus.COMPLETED, alive) + self.assertNotIn(task.TaskGroupStatus.FAILED, alive) + + +# 
# ---------------------------------------------------------------------------
# create_login_dict
# ---------------------------------------------------------------------------
# NOTE(review): hoisted from the two FileTest methods below — PEP 8 puts
# imports at file scope, and the module was previously imported twice.
import base64


class CreateLoginDictTest(unittest.TestCase):
    """Tests for task.create_login_dict (token-login vs. dev-login payloads)."""

    def test_with_token(self):
        result = task.create_login_dict(
            user='testuser', url='https://api.example.com',
            token='my-token', refresh_endpoint='/refresh',
            refresh_token='refresh-abc')
        # Token login: every token-related field must round-trip into the dict.
        self.assertEqual(result['token_login']['id_token'], 'my-token')
        self.assertEqual(result['token_login']['refresh_url'], '/refresh')
        self.assertEqual(result['token_login']['refresh_token'], 'refresh-abc')
        self.assertEqual(result['url'], 'https://api.example.com')
        self.assertTrue(result['osmo_token'])
        self.assertEqual(result['username'], 'testuser')

    def test_without_token(self):
        result = task.create_login_dict(
            user='devuser', url='https://dev.example.com')
        # No token given: payload switches to the dev_login shape.
        self.assertEqual(result['dev_login']['username'], 'devuser')
        self.assertEqual(result['url'], 'https://dev.example.com')
        self.assertNotIn('token_login', result)


# ---------------------------------------------------------------------------
# shorten_name_to_fit_kb
# ---------------------------------------------------------------------------
class ShortenNameToFitKbTest(unittest.TestCase):
    """Tests for task.shorten_name_to_fit_kb (Kubernetes 63-char name limit)."""

    def test_short_name_unchanged(self):
        self.assertEqual(task.shorten_name_to_fit_kb('short'), 'short')

    def test_exactly_63_chars(self):
        # Boundary case: exactly 63 characters is already a legal name.
        name = 'a' * 63
        self.assertEqual(task.shorten_name_to_fit_kb(name), name)

    def test_truncates_to_63(self):
        name = 'a' * 100
        self.assertEqual(len(task.shorten_name_to_fit_kb(name)), 63)

    def test_strips_trailing_hyphens(self):
        # Truncation at 63 lands inside the '---' run; the result must not
        # end with a hyphen.
        name = 'a' * 60 + '---' + 'b' * 10
        result = task.shorten_name_to_fit_kb(name)
        self.assertTrue(len(result) <= 63)
        self.assertFalse(result.endswith('-'))

    def test_strips_trailing_underscores(self):
        name = 'a' * 60 + '___' + 'b' * 10
        result = task.shorten_name_to_fit_kb(name)
        self.assertTrue(len(result) <= 63)
        self.assertFalse(result.endswith('_'))


# ---------------------------------------------------------------------------
# _encode_hstore / decode_hstore
# ---------------------------------------------------------------------------
class HstoreTest(unittest.TestCase):
    """Round-trip tests for the Postgres hstore encode/decode helpers."""

    def test_encode_single_task(self):
        result = task._encode_hstore({'taskA'})
        self.assertEqual(result, '"taskA" => "NULL"')

    def test_encode_multiple_tasks(self):
        # Set iteration order is unspecified, so assert membership only.
        result = task._encode_hstore({'taskA', 'taskB'})
        self.assertIn('"taskA" => "NULL"', result)
        self.assertIn('"taskB" => "NULL"', result)

    def test_decode_single_task(self):
        encoded = '"taskA"=>"NULL"'
        self.assertEqual(task.decode_hstore(encoded), {'taskA'})

    def test_decode_multiple_tasks(self):
        encoded = '"taskA"=>"NULL", "taskB"=>"NULL"'
        self.assertEqual(task.decode_hstore(encoded), {'taskA', 'taskB'})

    def test_decode_db_format(self):
        # Real database output has no space after the comma; must still parse.
        db_output = '"taskA"=>"NULL","taskB"=>"NULL"'
        decoded = task.decode_hstore(db_output)
        self.assertEqual(decoded, {'taskA', 'taskB'})


# ---------------------------------------------------------------------------
# TaskInputOutput
# ---------------------------------------------------------------------------
class TaskInputOutputTest(unittest.TestCase):
    """Tests for task.TaskInputOutput parsing and validation."""

    def test_simple_task_name(self):
        tio = task.TaskInputOutput(task='myTask')
        self.assertEqual(tio.task, 'myTask')
        self.assertFalse(tio.is_from_previous_workflow())

    def test_previous_workflow_task(self):
        # A 'workflow:task' reference points at a previous workflow.
        tio = task.TaskInputOutput(task='prevWorkflow:taskName')
        self.assertTrue(tio.is_from_previous_workflow())

    def test_parsed_workflow_info_simple(self):
        tio = task.TaskInputOutput(task='taskA')
        first, second = tio.parsed_workflow_info()
        self.assertEqual(first, 'taskA')
        self.assertIsNone(second)

    def test_parsed_workflow_info_with_workflow(self):
        tio = task.TaskInputOutput(task='wf1:taskA')
        first, second = tio.parsed_workflow_info()
        self.assertEqual(first, 'wf1')
        self.assertEqual(second, 'taskA')

    def test_valid_regex(self):
        tio = task.TaskInputOutput(task='taskA', regex=r'.*\.txt')
        self.assertEqual(tio.regex, r'.*\.txt')

    def test_invalid_regex_raises(self):
        # NOTE(review): Exception is broad — narrow to the model's actual
        # validation error type if/when it is importable here.
        with self.assertRaises(Exception):
            task.TaskInputOutput(task='taskA', regex='[invalid')

    def test_empty_regex_allowed(self):
        tio = task.TaskInputOutput(task='taskA', regex='')
        self.assertEqual(tio.regex, '')

    def test_hash(self):
        tio1 = task.TaskInputOutput(task='taskA')
        tio2 = task.TaskInputOutput(task='taskA')
        self.assertEqual(hash(tio1), hash(tio2))

    def test_different_hash(self):
        tio1 = task.TaskInputOutput(task='taskA')
        tio2 = task.TaskInputOutput(task='taskB')
        self.assertNotEqual(hash(tio1), hash(tio2))


# ---------------------------------------------------------------------------
# URLInputOutput
# ---------------------------------------------------------------------------
class URLInputOutputTest(unittest.TestCase):
    """Tests for task.URLInputOutput creation, regex validation, and hashing."""

    def test_creation(self):
        uio = task.URLInputOutput(url='s3://bucket/path')
        self.assertEqual(uio.url, 's3://bucket/path')

    def test_valid_regex(self):
        uio = task.URLInputOutput(url='s3://b', regex=r'.*\.log')
        self.assertEqual(uio.regex, r'.*\.log')

    def test_invalid_regex_raises(self):
        with self.assertRaises(Exception):
            task.URLInputOutput(url='s3://b', regex='[bad')

    def test_empty_regex(self):
        uio = task.URLInputOutput(url='s3://b', regex='')
        self.assertEqual(uio.regex, '')

    def test_hash(self):
        u1 = task.URLInputOutput(url='s3://a')
        u2 = task.URLInputOutput(url='s3://a')
        self.assertEqual(hash(u1), hash(u2))


# ---------------------------------------------------------------------------
# File
# ---------------------------------------------------------------------------
class FileTest(unittest.TestCase):
    """Tests for task.File path validation and base64 content encoding."""

    def test_valid_path(self):
        f = task.File(path='/home/user/test.txt', contents='hello')
        self.assertEqual(f.path, '/home/user/test.txt')

    def test_output_path_allowed(self):
        # The /osmo/data/output subtree is an allowed exception.
        f = task.File(path='/osmo/data/output/meta.yaml', contents='data')
        self.assertEqual(f.path, '/osmo/data/output/meta.yaml')

    def test_osmo_path_rejected(self):
        with self.assertRaises(Exception):
            task.File(path='/osmo/forbidden', contents='data')

    def test_empty_path_rejected(self):
        # '/' carries no filename component and must be rejected.
        with self.assertRaises(Exception):
            task.File(path='/', contents='data')

    def test_encoded_contents_plain(self):
        # Plain contents are base64-encoded on the way out.
        f = task.File(path='/home/test.txt', contents='hello world')
        decoded = base64.b64decode(f.encoded_contents()).decode('utf-8')
        self.assertEqual(decoded, 'hello world')

    def test_encoded_contents_already_base64(self):
        # Pre-encoded contents (base64=True) must pass through unchanged.
        original = base64.b64encode(b'binary data').decode('utf-8')
        f = task.File(path='/home/test.bin', contents=original, base64=True)
        self.assertEqual(f.encoded_contents(), original)


# ---------------------------------------------------------------------------
# CheckpointSpec
# ---------------------------------------------------------------------------
class CheckpointSpecTest(unittest.TestCase):
    """Tests for task.CheckpointSpec frequency coercion and regex validation."""

    def test_frequency_from_int(self):
        spec = task.CheckpointSpec(
            path='/data', url='s3://bucket/ckpt', frequency=300)
        self.assertEqual(spec.frequency, datetime.timedelta(seconds=300))

    def test_frequency_from_float(self):
        spec = task.CheckpointSpec(
            path='/data', url='s3://bucket/ckpt', frequency=60.5)
        self.assertEqual(spec.frequency, datetime.timedelta(seconds=60.5))

    def test_frequency_from_timedelta(self):
        # A timedelta must be accepted verbatim, not re-coerced.
        td = datetime.timedelta(minutes=10)
        spec = task.CheckpointSpec(
            path='/data', url='s3://bucket/ckpt', frequency=td)
        self.assertEqual(spec.frequency, td)

    def test_valid_regex(self):
        spec = task.CheckpointSpec(
            path='/data', url='s3://bucket/ckpt', frequency=60,
            regex=r'ckpt-\d+')
        self.assertEqual(spec.regex, r'ckpt-\d+')

    def test_invalid_regex_raises(self):
        with self.assertRaises(Exception):
            task.CheckpointSpec(
                path='/data', url='s3://bucket/ckpt', frequency=60,
                regex='[bad')


# ---------------------------------------------------------------------------
# TaskKPI
# ---------------------------------------------------------------------------
class TaskKPITest(unittest.TestCase):
    """Smoke test for task.TaskKPI construction."""

    def test_creation(self):
        kpi = task.TaskKPI(index='loss', path='/metrics/loss.json')
        self.assertEqual(kpi.index, 'loss')
        self.assertEqual(kpi.path, '/metrics/loss.json')


# ---------------------------------------------------------------------------
# TaskSpec validation
# ---------------------------------------------------------------------------
class TaskSpecValidationTest(unittest.TestCase):
    """Tests for task.TaskSpec model-level validation rules."""

    def test_osmo_ctrl_name_rejected(self):
        # 'osmo-ctrl' is a reserved task name.
        with self.assertRaises(Exception):
            task.TaskSpec(name='osmo-ctrl', image='ubuntu', command=['echo'])

    def test_empty_command_rejected(self):
        with self.assertRaises(Exception):
            task.TaskSpec(name='myTask', image='ubuntu', command=[])

    def test_duplicate_file_paths_rejected(self):
        # Two File entries with the same path are ambiguous and must fail.
        with self.assertRaises(Exception):
            task.TaskSpec(
                name='myTask', image='ubuntu', command=['run'],
                files=[
                    task.File(path='/home/a.txt', contents='1'),
                    task.File(path='/home/a.txt', contents='2'),
                ])

    def test_valid_task_spec(self):
        spec = task.TaskSpec(name='myTask', image='ubuntu', command=['echo', 'hi'])
        self.assertEqual(spec.name, 'myTask')
        self.assertEqual(spec.image, 'ubuntu')


# ---------------------------------------------------------------------------
# render_group_templates
# ---------------------------------------------------------------------------
class RenderGroupTemplatesTest(unittest.TestCase):
    """Tests for task.render_group_templates label injection and namespacing."""

    def test_injects_labels(self):
        templates = [{
            'apiVersion': 'v1',
            'kind': 'ConfigMap',
            'metadata': {'name': 'test'},
        }]
        labels = {'app': 'osmo'}
        result = task.render_group_templates(templates, {}, labels)
        self.assertEqual(len(result), 1)
        self.assertEqual(result[0]['metadata']['labels']['app'], 'osmo')

    def test_strips_namespace(self):
        # Any namespace baked into a template must be removed so the
        # scheduler controls placement.
        templates = [{
            'metadata': {'name': 'test', 'namespace': 'old-ns'},
        }]
        result = task.render_group_templates(templates, {}, {})
        self.assertNotIn('namespace', result[0]['metadata'])

    def test_does_not_modify_original(self):
        # Rendering must be side-effect free on the caller's templates.
        templates = [{'metadata': {'name': 'test'}}]
        original = copy.deepcopy(templates)
        task.render_group_templates(templates, {}, {'key': 'val'})
        self.assertEqual(templates, original)


# ---------------------------------------------------------------------------
# DownloadTypeMetrics (task_io.py)
# ---------------------------------------------------------------------------
class DownloadTypeMetricsTest(unittest.TestCase):
    """Tests for the DownloadTypeMetrics enum values."""

    def test_values(self):
        # Imported locally: presumably to avoid a module-level dependency on
        # task_io for the rest of this file — TODO confirm.
        from src.utils.job.task_io import DownloadTypeMetrics
        self.assertEqual(DownloadTypeMetrics.DOWNLOAD.value, 'download')
        self.assertEqual(DownloadTypeMetrics.MOUNTPOINT.value, 'mountpoint-s3')
        self.assertEqual(DownloadTypeMetrics.NOT_APPLICABLE.value, 'N/A')


if __name__ == '__main__':
    unittest.main()

# --- patch metadata preserved verbatim from the original diff (next file) ---
# diff --git a/src/utils/job/tests/test_workflow_models.py b/src/utils/job/tests/test_workflow_models.py
# new file mode 100644
# index 000000000..e59e9cab4
# --- /dev/null
# +++ b/src/utils/job/tests/test_workflow_models.py
# @@ -0,0 +1,223 @@
# +"""
# +SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# +
# +Licensed under the Apache License, Version 2.0 (the "License");
# +you may not use this file except in compliance with the License.
# +You may obtain a copy of the License at
# +
# +http://www.apache.org/licenses/LICENSE-2.0
# +
# +Unless required by applicable law or agreed to in writing, software
# +distributed under the License is distributed on an "AS IS" BASIS,
# +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# +See the License for the specific language governing permissions and
# +limitations under the License.
# --- tail of the Apache-2.0 license header for this file (from the diff) ---
# SPDX-License-Identifier: Apache-2.0
import datetime
import unittest

import yaml

from src.utils import connectors
from src.utils.job import workflow


# ---------------------------------------------------------------------------
# WorkflowStatus
# ---------------------------------------------------------------------------
class WorkflowStatusAliveTest(unittest.TestCase):
    """Tests for WorkflowStatus lifecycle predicates (alive/finished/failed)."""

    def test_alive_statuses(self):
        for status in workflow.WorkflowStatus.get_alive_statuses():
            self.assertTrue(status.alive(), f'{status.name} should be alive')
            self.assertFalse(status.finished(), f'{status.name} should not be finished')

    def test_completed_is_finished(self):
        self.assertTrue(workflow.WorkflowStatus.COMPLETED.finished())
        self.assertFalse(workflow.WorkflowStatus.COMPLETED.alive())

    def test_failed_statuses(self):
        # Every FAILED_* variant must be failed, finished, and not alive.
        failed_statuses = [
            workflow.WorkflowStatus.FAILED,
            workflow.WorkflowStatus.FAILED_SUBMISSION,
            workflow.WorkflowStatus.FAILED_SERVER_ERROR,
            workflow.WorkflowStatus.FAILED_EXEC_TIMEOUT,
            workflow.WorkflowStatus.FAILED_QUEUE_TIMEOUT,
            workflow.WorkflowStatus.FAILED_CANCELED,
            workflow.WorkflowStatus.FAILED_BACKEND_ERROR,
            workflow.WorkflowStatus.FAILED_IMAGE_PULL,
            workflow.WorkflowStatus.FAILED_EVICTED,
            workflow.WorkflowStatus.FAILED_START_ERROR,
            workflow.WorkflowStatus.FAILED_START_TIMEOUT,
            workflow.WorkflowStatus.FAILED_PREEMPTED,
        ]
        for status in failed_statuses:
            self.assertTrue(status.failed(), f'{status.name} should be failed')
            self.assertTrue(status.finished(), f'{status.name} should be finished')
            self.assertFalse(status.alive(), f'{status.name} should not be alive')

    def test_completed_is_not_failed(self):
        self.assertFalse(workflow.WorkflowStatus.COMPLETED.failed())

    def test_alive_statuses_list(self):
        alive = workflow.WorkflowStatus.get_alive_statuses()
        # Exact-membership assertion: equivalent to the previous three
        # assertIn calls plus a length check, with clearer failure output.
        self.assertCountEqual(
            alive,
            [workflow.WorkflowStatus.PENDING,
             workflow.WorkflowStatus.RUNNING,
             workflow.WorkflowStatus.WAITING])


# ---------------------------------------------------------------------------
# action_queue_name
# ---------------------------------------------------------------------------
class ActionQueueNameTest(unittest.TestCase):
    """Tests for the action queue name format."""

    def test_format(self):
        result = workflow.action_queue_name('wf-1', 'train', 0)
        self.assertEqual(result, 'client-connections:wf-1:train:0')

    def test_with_retry_id(self):
        result = workflow.action_queue_name('wf-2', 'eval', 3)
        self.assertEqual(result, 'client-connections:wf-2:eval:3')


# ---------------------------------------------------------------------------
# TimeoutSpec
# ---------------------------------------------------------------------------
class TimeoutSpecValidationTest(unittest.TestCase):
    """Tests for TimeoutSpec coercion of int/float/None/timedelta inputs."""

    def test_from_int(self):
        spec = workflow.TimeoutSpec(exec_timeout=300, queue_timeout=60)
        self.assertEqual(spec.exec_timeout, datetime.timedelta(seconds=300))
        self.assertEqual(spec.queue_timeout, datetime.timedelta(seconds=60))

    def test_from_float(self):
        spec = workflow.TimeoutSpec(exec_timeout=60.5)
        self.assertEqual(spec.exec_timeout, datetime.timedelta(seconds=60.5))

    def test_from_none(self):
        # None means "no timeout" and must be preserved, not coerced.
        spec = workflow.TimeoutSpec(exec_timeout=None)
        self.assertIsNone(spec.exec_timeout)

    def test_from_timedelta(self):
        td = datetime.timedelta(hours=1)
        spec = workflow.TimeoutSpec(exec_timeout=td)
        self.assertEqual(spec.exec_timeout, td)


# ---------------------------------------------------------------------------
# split_assertion_rules
# ---------------------------------------------------------------------------
class SplitAssertionRulesTest(unittest.TestCase):
    """Tests for partitioning resource assertions into static vs. K8-backed."""

    def test_static_assertion(self):
        assertion = connectors.ResourceAssertion(
            operator='GT',
            left_operand='{{USER_CPU}}',
            right_operand='0',
            assert_message='CPU must be > 0')
        static, k8 = workflow.split_assertion_rules([assertion])
        self.assertEqual(len(static), 1)
        self.assertEqual(len(k8), 0)

    def test_k8_assertion_left_operand(self):
        # A {{K8_*}} template on either operand routes to the k8 bucket.
        assertion = connectors.ResourceAssertion(
            operator='GE',
            left_operand='{{K8_GPU}}',
            right_operand='{{USER_GPU}}',
            assert_message='K8 GPU must be >= user GPU')
        static, k8 = workflow.split_assertion_rules([assertion])
        self.assertEqual(len(static), 0)
        self.assertEqual(len(k8), 1)

    def test_k8_assertion_right_operand(self):
        assertion = connectors.ResourceAssertion(
            operator='LE',
            left_operand='{{USER_MEMORY}}',
            right_operand='{{K8_MEMORY}}',
            assert_message='User memory <= k8 memory')
        static, k8 = workflow.split_assertion_rules([assertion])
        self.assertEqual(len(static), 0)
        self.assertEqual(len(k8), 1)

    def test_mixed_assertions(self):
        static_assertion = connectors.ResourceAssertion(
            operator='GT',
            left_operand='{{USER_CPU}}',
            right_operand='0',
            assert_message='CPU > 0')
        k8_assertion = connectors.ResourceAssertion(
            operator='GE',
            left_operand='{{K8_GPU}}',
            right_operand='{{USER_GPU}}',
            assert_message='K8 GPU >= user GPU')
        static, k8 = workflow.split_assertion_rules(
            [static_assertion, k8_assertion])
        self.assertEqual(len(static), 1)
        self.assertEqual(len(k8), 1)

    def test_empty_list(self):
        static, k8 = workflow.split_assertion_rules([])
        self.assertEqual(len(static), 0)
        self.assertEqual(len(k8), 0)


# ---------------------------------------------------------------------------
# VersionedWorkflowSpec
# ---------------------------------------------------------------------------
class VersionedWorkflowSpecTest(unittest.TestCase):
    """Tests for VersionedWorkflowSpec version handling."""

    def test_unsupported_version_raises(self):
        with self.assertRaises(Exception):
            workflow.VersionedWorkflowSpec(
                version=1,
                workflow=workflow.WorkflowSpec(
                    name='test',
                    tasks=[{'name': 'a', 'command': ['echo'], 'image': 'ubuntu'}]))

    def test_default_version_is_2(self):
        spec_dict = yaml.safe_load('''
            workflow:
              name: test
              tasks:
                - name: task1
                  command: ['echo']
                  image: ubuntu
        ''')
        versioned = workflow.VersionedWorkflowSpec(**spec_dict)
        self.assertEqual(versioned.version, 2)


# ---------------------------------------------------------------------------
# WorkflowSpec validation
# ---------------------------------------------------------------------------
class WorkflowSpecValidationTest(unittest.TestCase):
    """Tests for WorkflowSpec's tasks-xor-groups validation."""

    def test_no_tasks_or_groups_raises(self):
        with self.assertRaises(Exception):
            workflow.WorkflowSpec(name='test')

    def test_both_tasks_and_groups_raises(self):
        # tasks and groups are mutually exclusive on a workflow spec.
        with self.assertRaises(Exception):
            workflow.WorkflowSpec(
                name='test',
                tasks=[{'name': 'a', 'command': ['echo'], 'image': 'ubuntu'}],
                groups=[{'name': 'g', 'tasks': [
                    {'name': 'b', 'command': ['echo'], 'image': 'ubuntu'}
                ]}])


# ---------------------------------------------------------------------------
# ResourceValidationResult
# ---------------------------------------------------------------------------
class ResourceValidationResultTest(unittest.TestCase):
    """Tests for ResourceValidationResult defaults and fields."""

    def test_passed(self):
        result = workflow.ResourceValidationResult(passed=True)
        self.assertTrue(result.passed)
        # logs defaults to the empty string when validation passes.
        self.assertEqual(result.logs, '')

    def test_failed_with_logs(self):
        result = workflow.ResourceValidationResult(
            passed=False, logs='GPU count too low')
        self.assertFalse(result.passed)
        self.assertEqual(result.logs, 'GPU count too low')


if __name__ == '__main__':
    unittest.main()