Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 54 additions & 0 deletions src/utils/job/tests/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -106,3 +106,57 @@ osmo_py_test(
size = "large",
tags = ["requires-network"],
)

py_test(
name = "test_common_utils",
srcs = ["test_common_utils.py"],
deps = [
"//src/lib/data/storage",
"//src/lib/utils:osmo_errors",
"//src/utils/job",
],
)

py_test(
name = "test_backend_job_defs",
srcs = ["test_backend_job_defs.py"],
deps = [
"//src/utils/job:backend_job",
],
)

py_test(
name = "test_task_models",
srcs = ["test_task_models.py"],
deps = [
"//src/lib/utils:osmo_errors",
"//src/utils/job",
],
)

py_test(
name = "test_workflow_models",
srcs = ["test_workflow_models.py"],
deps = [
"//src/utils/connectors",
"//src/utils/job",
],
)

py_test(
name = "test_kb_objects_unit",
srcs = ["test_kb_objects_unit.py"],
deps = [
"//src/lib/utils:priority",
"//src/utils/job",
"//src/utils/job:backend_job",
],
)

py_test(
name = "test_jobs_base_unit",
srcs = ["test_jobs_base_unit.py"],
deps = [
"//src/utils/job:jobs_base",
],
)
119 changes: 119 additions & 0 deletions src/utils/job/tests/test_backend_job_defs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
"""
SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

SPDX-License-Identifier: Apache-2.0
"""
import unittest

from src.utils.job import backend_job_defs


class EffectiveApiVersionTest(unittest.TestCase):
def test_generic_api_preferred(self):
spec = backend_job_defs.BackendCleanupSpec(
labels={'app': 'test'},
generic_api=backend_job_defs.BackendGenericApi(
api_version='scheduling.run.ai/v2', kind='Queue'),
custom_api=backend_job_defs.BackendCustomApi(
api_major='old.api', api_minor='v1', path='queues'))
self.assertEqual(spec.effective_api_version, 'scheduling.run.ai/v2')

def test_custom_api_fallback(self):
spec = backend_job_defs.BackendCleanupSpec(
labels={'app': 'test'},
custom_api=backend_job_defs.BackendCustomApi(
api_major='scheduling.run.ai', api_minor='v2alpha2', path='podgroups'))
self.assertEqual(spec.effective_api_version, 'scheduling.run.ai/v2alpha2')

def test_default_v1_when_no_api(self):
spec = backend_job_defs.BackendCleanupSpec(labels={'app': 'test'})
self.assertEqual(spec.effective_api_version, 'v1')


class EffectiveKindTest(unittest.TestCase):
def test_generic_api_preferred(self):
spec = backend_job_defs.BackendCleanupSpec(
labels={'app': 'test'},
resource_type='Pod',
generic_api=backend_job_defs.BackendGenericApi(
api_version='v1', kind='Service'))
self.assertEqual(spec.effective_kind, 'Service')

def test_resource_type_fallback(self):
spec = backend_job_defs.BackendCleanupSpec(
labels={'app': 'test'},
resource_type='Pod')
self.assertEqual(spec.effective_kind, 'Pod')

def test_none_when_nothing_set(self):
spec = backend_job_defs.BackendCleanupSpec(labels={'app': 'test'})
self.assertIsNone(spec.effective_kind)


class K8sSelectorTest(unittest.TestCase):
def test_single_label(self):
spec = backend_job_defs.BackendCleanupSpec(labels={'app': 'osmo'})
self.assertEqual(spec.k8s_selector, 'app=osmo')

def test_multiple_labels(self):
spec = backend_job_defs.BackendCleanupSpec(
labels={'app': 'osmo', 'env': 'prod'})
parts = spec.k8s_selector.split(',')
self.assertEqual(len(parts), 2)
self.assertIn('app=osmo', parts)
self.assertIn('env=prod', parts)


class BackendCreateGroupMixinTest(unittest.TestCase):
def test_default_values(self):
mixin = backend_job_defs.BackendCreateGroupMixin(
group_name='group1', k8s_resources=[{'kind': 'Pod'}])
self.assertEqual(mixin.backend_k8s_timeout, 60)
self.assertEqual(mixin.scheduler_settings, {})


class BackendGenericApiTest(unittest.TestCase):
def test_creation(self):
api = backend_job_defs.BackendGenericApi(api_version='v1', kind='Pod')
self.assertEqual(api.api_version, 'v1')
self.assertEqual(api.kind, 'Pod')


class BackendCustomApiTest(unittest.TestCase):
def test_creation(self):
api = backend_job_defs.BackendCustomApi(
api_major='scheduling.run.ai', api_minor='v2alpha2', path='podgroups')
self.assertEqual(api.api_major, 'scheduling.run.ai')
self.assertEqual(api.api_minor, 'v2alpha2')
self.assertEqual(api.path, 'podgroups')


class BackendSynchronizeQueuesMixinTest(unittest.TestCase):
def test_default_immutable_kinds(self):
mixin = backend_job_defs.BackendSynchronizeQueuesMixin(
cleanup_specs=[],
k8s_resources=[])
self.assertEqual(mixin.immutable_kinds, [])

def test_with_immutable_kinds(self):
mixin = backend_job_defs.BackendSynchronizeQueuesMixin(
cleanup_specs=[],
k8s_resources=[],
immutable_kinds=['Topology'])
self.assertEqual(mixin.immutable_kinds, ['Topology'])


if __name__ == '__main__':
unittest.main()
111 changes: 111 additions & 0 deletions src/utils/job/tests/test_common_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
"""
SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

SPDX-License-Identifier: Apache-2.0
"""
import datetime
import os
import unittest

from src.lib.data.storage.backends.common import StoragePath
from src.lib.utils import osmo_errors
from src.utils.job import common


class GetWorkflowLogsPathTest(unittest.TestCase):
def test_returns_joined_path(self):
result = common.get_workflow_logs_path('wf-123', 'output.log')
self.assertEqual(result, os.path.join('wf-123', 'output.log'))

def test_different_filenames(self):
result = common.get_workflow_logs_path('my-workflow', 'error.log')
self.assertEqual(result, os.path.join('my-workflow', 'error.log'))


class GetWorkflowAppPathTest(unittest.TestCase):
def test_without_prefix(self):
path_params = StoragePath(
scheme='s3', host='bucket', endpoint_url='http://s3',
container='bucket', region='us-east-1', prefix='')
result = common.get_workflow_app_path('app-uuid-123', 1, path_params)
self.assertEqual(result, os.path.join('app-uuid-123', '1', 'workflow_app.txt'))

def test_with_prefix(self):
path_params = StoragePath(
scheme='s3', host='bucket', endpoint_url='http://s3',
container='bucket', region='us-east-1', prefix='my-prefix')
result = common.get_workflow_app_path('app-uuid-123', 2, path_params)
self.assertEqual(
result,
os.path.join('my-prefix', 'app-uuid-123', '2', 'workflow_app.txt'))

def test_version_is_stringified(self):
path_params = StoragePath(
scheme='s3', host='bucket', endpoint_url='http://s3',
container='bucket', region='us-east-1', prefix='')
result = common.get_workflow_app_path('uuid', 42, path_params)
self.assertIn('42', result)


class CalculateTotalTimeoutTest(unittest.TestCase):
def test_returns_sum_of_timeouts(self):
queue = datetime.timedelta(seconds=300)
execution = datetime.timedelta(seconds=600)
result = common.calculate_total_timeout('wf-1', queue, execution)
self.assertEqual(result, 900)

def test_raises_without_exec_timeout(self):
with self.assertRaises(osmo_errors.OSMODatabaseError):
common.calculate_total_timeout(
'wf-1', queue_timeout=datetime.timedelta(seconds=300))

def test_raises_without_queue_timeout(self):
with self.assertRaises(osmo_errors.OSMODatabaseError):
common.calculate_total_timeout(
'wf-1', exec_timeout=datetime.timedelta(seconds=300))

def test_raises_with_both_none(self):
with self.assertRaises(osmo_errors.OSMODatabaseError):
common.calculate_total_timeout('wf-1')

def test_truncates_to_int(self):
queue = datetime.timedelta(seconds=1.7)
execution = datetime.timedelta(seconds=2.9)
result = common.calculate_total_timeout('wf-1', queue, execution)
self.assertEqual(result, int(1.7) + int(2.9))


class BarrierKeyTest(unittest.TestCase):
def test_format(self):
result = common.barrier_key('wf-1', 'group-a', 'sync')
self.assertEqual(result, 'client-connections:wf-1:group-a:barrier-sync')

def test_different_inputs(self):
result = common.barrier_key('workflow', 'train', 'ready')
self.assertEqual(result, 'client-connections:workflow:train:barrier-ready')


class WorkflowPluginsTest(unittest.TestCase):
def test_default_rsync_false(self):
plugins = common.WorkflowPlugins()
self.assertFalse(plugins.rsync)

def test_rsync_enabled(self):
plugins = common.WorkflowPlugins(rsync=True)
self.assertTrue(plugins.rsync)


if __name__ == '__main__':
unittest.main()
86 changes: 86 additions & 0 deletions src/utils/job/tests/test_jobs_base_unit.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
"""
SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

SPDX-License-Identifier: Apache-2.0
"""
import datetime
import unittest
from unittest import mock

from src.utils.job import jobs_base


# ---------------------------------------------------------------------------
# JobStatus
# ---------------------------------------------------------------------------
class JobStatusTest(unittest.TestCase):
def test_values(self):
self.assertEqual(jobs_base.JobStatus.SUCCESS.value, 'SUCCESS')
self.assertEqual(jobs_base.JobStatus.FAILED_RETRY.value, 'FAILED_RETRY')
self.assertEqual(jobs_base.JobStatus.FAILED_NO_RETRY.value, 'FAILED_NO_RETRY')


# ---------------------------------------------------------------------------
# JobResult
# ---------------------------------------------------------------------------
class JobResultTest(unittest.TestCase):
def test_success_result(self):
result = jobs_base.JobResult(status=jobs_base.JobStatus.SUCCESS, message=None)
self.assertFalse(result.retry)

def test_retry_result(self):
result = jobs_base.JobResult(
status=jobs_base.JobStatus.FAILED_RETRY, message='network error')
self.assertTrue(result.retry)

def test_no_retry_result(self):
result = jobs_base.JobResult(
status=jobs_base.JobStatus.FAILED_NO_RETRY, message='bad input')
self.assertFalse(result.retry)

def test_str_with_message(self):
result = jobs_base.JobResult(
status=jobs_base.JobStatus.FAILED_RETRY, message='timeout')
self.assertEqual(str(result), 'FAILED_RETRY: timeout')

def test_str_without_message(self):
result = jobs_base.JobResult(
status=jobs_base.JobStatus.SUCCESS, message=None)
self.assertEqual(str(result), 'SUCCESS')


# ---------------------------------------------------------------------------
# update_progress_writer
# ---------------------------------------------------------------------------
class UpdateProgressWriterTest(unittest.TestCase):
def test_does_not_report_when_time_not_elapsed(self):
writer = mock.MagicMock()
last = datetime.datetime.now()
freq = datetime.timedelta(hours=1)
result = jobs_base.update_progress_writer(writer, last, freq)
writer.report_progress.assert_not_called()
self.assertEqual(result, last)


# ---------------------------------------------------------------------------
# UNIQUE_JOB_TTL constant
# ---------------------------------------------------------------------------
class UniqueJobTtlTest(unittest.TestCase):
def test_value(self):
self.assertEqual(jobs_base.UNIQUE_JOB_TTL, 5 * 24 * 60 * 60)


if __name__ == '__main__':
unittest.main()
Loading
Loading