diff --git a/src/openai/cli/_api/chat/__init__.py b/src/openai/cli/_api/chat/__init__.py index 87d971630a..5260cb725e 100644 --- a/src/openai/cli/_api/chat/__init__.py +++ b/src/openai/cli/_api/chat/__init__.py @@ -3,7 +3,7 @@ from typing import TYPE_CHECKING from argparse import ArgumentParser -from . import completions +from . import completions, fine_tune if TYPE_CHECKING: from argparse import _SubParsersAction @@ -11,3 +11,4 @@ def register(subparser: _SubParsersAction[ArgumentParser]) -> None: completions.register(subparser) + fine_tune.register(subparser) diff --git a/src/openai/cli/_api/chat/fine_tune.py b/src/openai/cli/_api/chat/fine_tune.py new file mode 100644 index 0000000000..adc1edcc54 --- /dev/null +++ b/src/openai/cli/_api/chat/fine_tune.py @@ -0,0 +1,116 @@ +from __future__ import annotations + +import json +from typing import TYPE_CHECKING +from argparse import ArgumentParser + +from ..._models import BaseModel +from ..._utils import get_client, print_model +from ...._types import Omittable, omit +from ...._utils import is_given +from ....pagination import SyncCursorPage +from ....types.fine_tuning import FineTuningJob + +if TYPE_CHECKING: + from argparse import _SubParsersAction + + +def register(subparser: _SubParsersAction[ArgumentParser]) -> None: + sub = subparser.add_parser("chat.fine_tune.create") + sub.add_argument( + "-m", + "--model", + help="The chat model to fine-tune.", + required=True, + ) + sub.add_argument( + "-F", + "--training-file", + help="The training file to fine-tune the model on.", + required=True, + ) + sub.add_argument( + "-H", + "--hyperparameters", + help="JSON string of hyperparameters to use for fine-tuning.", + type=str, + ) + sub.add_argument( + "-s", + "--suffix", + help="A suffix to add to the fine-tuned model name.", + ) + sub.add_argument( + "-V", + "--validation-file", + help="The validation file to use for fine-tuning.", + ) + sub.set_defaults(func=CLIChatFineTune.create, args_model=CLIChatFineTuneCreateArgs) + + sub = subparser.add_parser("chat.fine_tune.list") + sub.add_argument( + "-a", + "--after", + help="Identifier for the last job from the previous pagination request.", + ) + sub.add_argument( + "-l", + "--limit", + help="Number of chat fine-tuning jobs to retrieve.", + type=int, + ) + sub.set_defaults(func=CLIChatFineTune.list, args_model=CLIChatFineTuneListArgs) + + sub = subparser.add_parser("chat.fine_tune.apply") + sub.add_argument( + "-i", + "--id", + help="The ID of the chat fine-tuning job to apply.", + required=True, + ) + sub.set_defaults(func=CLIChatFineTune.apply, args_model=CLIChatFineTuneApplyArgs) + + +class CLIChatFineTuneCreateArgs(BaseModel): + model: str + training_file: str + hyperparameters: Omittable[str] = omit + suffix: Omittable[str] = omit + validation_file: Omittable[str] = omit + + +class CLIChatFineTuneListArgs(BaseModel): + after: Omittable[str] = omit + limit: Omittable[int] = omit + + +class CLIChatFineTuneApplyArgs(BaseModel): + id: str + + +class CLIChatFineTune: + @staticmethod + def create(args: CLIChatFineTuneCreateArgs) -> None: + hyperparameters = json.loads(str(args.hyperparameters)) if is_given(args.hyperparameters) else omit + fine_tuning_job: FineTuningJob = get_client().fine_tuning.jobs.create( + model=args.model, + training_file=args.training_file, + hyperparameters=hyperparameters, + suffix=args.suffix, + validation_file=args.validation_file, + ) + print_model(fine_tuning_job) + + @staticmethod + def list(args: CLIChatFineTuneListArgs) -> None: + fine_tuning_jobs: SyncCursorPage[FineTuningJob] = get_client().fine_tuning.jobs.list( + after=args.after or omit, + limit=args.limit or omit, + ) + print_model(fine_tuning_jobs) + + @staticmethod + def apply(args: CLIChatFineTuneApplyArgs) -> None: + # `apply` is a CLI convenience alias to the existing resume operation. + fine_tuning_job: FineTuningJob = get_client().fine_tuning.jobs.resume(fine_tuning_job_id=args.id) + print_model(fine_tuning_job) diff --git a/tests/cli/test_chat_fine_tune_cli.py b/tests/cli/test_chat_fine_tune_cli.py new file mode 100644 index 0000000000..75ef4819c6 --- /dev/null +++ b/tests/cli/test_chat_fine_tune_cli.py @@ -0,0 +1,100 @@ +from __future__ import annotations + +from types import SimpleNamespace +from unittest.mock import MagicMock + +import pytest + +from openai._types import omit +from openai.cli import _cli +from openai.cli._errors import CLIError +from openai.cli._api.chat import fine_tune + + +def test_chat_fine_tune_list_command_is_registered() -> None: + parser = _cli._build_parser() + parsed = parser.parse_args(["api", "chat.fine_tune.list"]) + + assert parsed.func == fine_tune.CLIChatFineTune.list + + +def test_chat_fine_tune_create_requires_training_file() -> None: + parser = _cli._build_parser() + + with pytest.raises(SystemExit): + parser.parse_args(["api", "chat.fine_tune.create", "--model", "gpt-4o-mini"]) + + +def test_chat_fine_tune_create_calls_fine_tuning_jobs_create(monkeypatch: pytest.MonkeyPatch) -> None: + client = MagicMock() + created_job = MagicMock() + client.fine_tuning.jobs.create.return_value = created_job + printed: list[object] = [] + + monkeypatch.setattr(fine_tune, "get_client", lambda: client) + monkeypatch.setattr(fine_tune, "print_model", lambda model: printed.append(model)) + + args = fine_tune.CLIChatFineTuneCreateArgs( + model="gpt-4o-mini", + training_file="file-123", + hyperparameters='{"n_epochs": 2}', + ) + fine_tune.CLIChatFineTune.create(args) + + client.fine_tuning.jobs.create.assert_called_once_with( + model="gpt-4o-mini", + training_file="file-123", + hyperparameters={"n_epochs": 2}, + suffix=omit, + validation_file=omit, + ) + assert printed == [created_job] + + +def test_chat_fine_tune_list_calls_fine_tuning_jobs_list(monkeypatch: pytest.MonkeyPatch) -> None: + client = MagicMock() + listed_jobs = MagicMock() + client.fine_tuning.jobs.list.return_value = listed_jobs + printed: list[object] = [] + + monkeypatch.setattr(fine_tune, "get_client", lambda: client) + monkeypatch.setattr(fine_tune, "print_model", lambda model: printed.append(model)) + + args = fine_tune.CLIChatFineTuneListArgs() + fine_tune.CLIChatFineTune.list(args) + + client.fine_tuning.jobs.list.assert_called_once_with(after=omit, limit=omit) + assert printed == [listed_jobs] + + +def test_chat_fine_tune_apply_aliases_to_resume(monkeypatch: pytest.MonkeyPatch) -> None: + client = MagicMock() + resumed_job = MagicMock() + client.fine_tuning.jobs.resume.return_value = resumed_job + printed: list[object] = [] + + monkeypatch.setattr(fine_tune, "get_client", lambda: client) + monkeypatch.setattr(fine_tune, "print_model", lambda model: printed.append(model)) + + fine_tune.CLIChatFineTune.apply(fine_tune.CLIChatFineTuneApplyArgs(id="ftjob-123")) + + client.fine_tuning.jobs.resume.assert_called_once_with(fine_tuning_job_id="ftjob-123") + assert printed == [resumed_job] + + +def test_cli_main_returns_error_code_when_chat_fine_tune_handler_raises( + monkeypatch: pytest.MonkeyPatch, capsys: pytest.CaptureFixture[str] +) -> None: + monkeypatch.setattr(_cli.sys, "argv", ["openai", "api", "chat.fine_tune.list"]) + + def raise_error(_: object) -> None: + raise CLIError("boom") + + monkeypatch.setattr(fine_tune.CLIChatFineTune, "list", staticmethod(raise_error)) + + exit_code = _cli.main() + captured = capsys.readouterr() + + assert exit_code == 1 + assert "Error:" in captured.err + assert "boom" in captured.err