Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 28 additions & 12 deletions runpod/cli/groups/config/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,19 @@
"""

import os
import tempfile
from pathlib import Path

import tomli as toml
import tomlkit

CREDENTIAL_FILE = os.path.expanduser("~/.runpod/config.toml")


def set_credentials(api_key: str, profile: str = "default", overwrite=False) -> None:
"""
Sets the user's credentials in ~/.runpod/config.toml
If profile already exists user must use `update_credentials` instead.
If profile already exists user must pass overwrite=True.

Args:
api_key (str): The user's API key.
Expand All @@ -27,23 +29,37 @@ def set_credentials(api_key: str, profile: str = "default", overwrite=False) ->
[default]
api_key = "RUNPOD_API_KEY"
"""
os.makedirs(os.path.dirname(CREDENTIAL_FILE), exist_ok=True)
cred_dir = os.path.dirname(CREDENTIAL_FILE)
os.makedirs(cred_dir, exist_ok=True)
Path(CREDENTIAL_FILE).touch(exist_ok=True)

if not overwrite:
with open(CREDENTIAL_FILE, "r", encoding="UTF-8") as cred_file:
try:
with open(CREDENTIAL_FILE, "rb") as cred_file:
existing = toml.load(cred_file)
except (TypeError, ValueError):
existing = {}
if profile in existing:
content = cred_file.read()
config = (
tomlkit.parse(content)
if content.strip()
else tomlkit.document()
)
except tomlkit.exceptions.ParseError as exc:
raise ValueError("~/.runpod/config.toml is not a valid TOML file.") from exc

if not overwrite:
if profile in config:
raise ValueError(
"Profile already exists. Use `update_credentials` instead."
"Profile already exists. Use set_credentials(overwrite=True) to update."
)

with open(CREDENTIAL_FILE, "w", encoding="UTF-8") as cred_file:
cred_file.write("[" + profile + "]\n")
cred_file.write('api_key = "' + api_key + '"\n')
config[profile] = {"api_key": api_key}

fd, tmp_path = tempfile.mkstemp(dir=cred_dir, suffix=".toml")
try:
with os.fdopen(fd, "w", encoding="UTF-8") as tmp_file:
tomlkit.dump(config, tmp_file)
os.replace(tmp_path, CREDENTIAL_FILE)
except BaseException:
os.unlink(tmp_path)
raise


def check_credentials(profile: str = "default"):
Expand Down
128 changes: 104 additions & 24 deletions tests/test_cli/test_cli_groups/test_config_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,26 +14,70 @@ class TestConfig(unittest.TestCase):
def setUp(self) -> None:
self.sample_credentials = "[default]\n" 'api_key = "RUNPOD_API_KEY"\n'

@patch("runpod.cli.groups.config.functions.toml.load")
@patch("builtins.open", new_callable=mock_open())
def test_set_credentials(self, mock_file, mock_toml_load):
@patch("runpod.cli.groups.config.functions.os.replace")
@patch("runpod.cli.groups.config.functions.os.unlink")
@patch("runpod.cli.groups.config.functions.os.fdopen", new_callable=mock_open)
@patch("runpod.cli.groups.config.functions.tempfile.mkstemp")
@patch("runpod.cli.groups.config.functions.tomlkit.dump")
@patch("runpod.cli.groups.config.functions.tomlkit.document")
@patch("runpod.cli.groups.config.functions.Path.touch")
@patch("runpod.cli.groups.config.functions.os.makedirs")
@patch("builtins.open", new_callable=mock_open, read_data="")
def test_set_credentials(
self, mock_file, _mock_makedirs, _mock_touch, mock_document,
mock_dump, mock_mkstemp, _mock_fdopen, _mock_unlink, _mock_replace,
):
"""
Tests the set_credentials function.
"""
mock_toml_load.return_value = ""
mock_mkstemp.return_value = (99, "/tmp/cred.toml")
mock_document.side_effect = [{}, {"default": True}]
functions.set_credentials("RUNPOD_API_KEY")

mock_file.assert_called_with(functions.CREDENTIAL_FILE, "w", encoding="UTF-8")
assert any(
call.args[0] == functions.CREDENTIAL_FILE
and call.args[1] == "r"
and call.kwargs.get("encoding") == "UTF-8"
for call in mock_file.call_args_list
)
assert mock_dump.called

with self.assertRaises(ValueError) as context:
mock_toml_load.return_value = {"default": True}
functions.set_credentials("RUNPOD_API_KEY")

self.assertEqual(
str(context.exception),
"Profile already exists. Use `update_credentials` instead.",
"Profile already exists. Use set_credentials(overwrite=True) to update.",
)

@patch("runpod.cli.groups.config.functions.os.replace")
@patch("runpod.cli.groups.config.functions.os.unlink")
@patch("runpod.cli.groups.config.functions.os.fdopen", new_callable=mock_open)
@patch("runpod.cli.groups.config.functions.tempfile.mkstemp")
@patch("runpod.cli.groups.config.functions.tomlkit.dump")
@patch("runpod.cli.groups.config.functions.Path.touch")
@patch("runpod.cli.groups.config.functions.os.makedirs")
@patch(
"builtins.open",
new_callable=mock_open,
read_data='[default]\napi_key = "EXISTING_KEY"\n\n[profile1]\napi_key = "KEY1"\n',
)
def test_set_credentials_preserves_existing_profiles(
self, _mock_file, _mock_makedirs, _mock_touch, mock_dump,
mock_mkstemp, _mock_fdopen, _mock_unlink, _mock_replace,
):
"""Adding a new profile must preserve all existing profiles."""
mock_mkstemp.return_value = (99, "/tmp/cred.toml")
functions.set_credentials("NEW_KEY", profile="profile2")

dumped_config = mock_dump.call_args[0][0]
assert "default" in dumped_config
assert dumped_config["default"]["api_key"] == "EXISTING_KEY"
assert "profile1" in dumped_config
assert dumped_config["profile1"]["api_key"] == "KEY1"
assert "profile2" in dumped_config
assert dumped_config["profile2"]["api_key"] == "NEW_KEY"

@patch("builtins.open", new_callable=mock_open())
@patch("runpod.cli.groups.config.functions.toml.load")
@patch("runpod.cli.groups.config.functions.os.path.exists")
Expand Down Expand Up @@ -124,26 +168,62 @@ def test_get_credentials_type_error(
result = functions.get_credentials("default")
assert result is None

@patch("runpod.cli.groups.config.functions.os.replace")
@patch("runpod.cli.groups.config.functions.os.unlink")
@patch("runpod.cli.groups.config.functions.os.fdopen", new_callable=mock_open)
@patch("runpod.cli.groups.config.functions.tempfile.mkstemp")
@patch("runpod.cli.groups.config.functions.tomlkit.dump")
@patch("runpod.cli.groups.config.functions.Path.touch")
@patch("runpod.cli.groups.config.functions.os.makedirs")
@patch("runpod.cli.groups.config.functions.toml.load")
@patch("builtins.open", new_callable=mock_open())
def test_set_credentials_corrupted_toml_allows_overwrite(
self, _mock_file, mock_toml_load, _mock_makedirs, _mock_touch
@patch(
"builtins.open",
new_callable=mock_open,
read_data='[default]\napi_key = "OLD_KEY"\n',
)
def test_set_credentials_overwrite_replaces_existing_profile(
self, _mock_file, _mock_makedirs, _mock_touch, mock_dump,
mock_mkstemp, _mock_fdopen, _mock_unlink, _mock_replace,
):
"""set_credentials with overwrite=True ignores corrupted existing file."""
mock_toml_load.side_effect = ValueError("Invalid TOML")
# overwrite=True skips the toml.load check entirely
functions.set_credentials("NEW_KEY", overwrite=True)

"""overwrite=True replaces an existing profile's api_key."""
mock_mkstemp.return_value = (99, "/tmp/cred.toml")
functions.set_credentials("NEW_KEY", profile="default", overwrite=True)

dumped_config = mock_dump.call_args[0][0]
assert dumped_config["default"]["api_key"] == "NEW_KEY"

@patch("runpod.cli.groups.config.functions.os.replace")
@patch("runpod.cli.groups.config.functions.os.unlink")
@patch("runpod.cli.groups.config.functions.os.fdopen", new_callable=mock_open)
@patch("runpod.cli.groups.config.functions.tempfile.mkstemp")
@patch("runpod.cli.groups.config.functions.tomlkit.dump", side_effect=OSError("disk full"))
@patch("runpod.cli.groups.config.functions.Path.touch")
@patch("runpod.cli.groups.config.functions.os.makedirs")
@patch("runpod.cli.groups.config.functions.toml.load")
@patch("builtins.open", new_callable=mock_open())
def test_set_credentials_corrupted_toml_no_overwrite(
self, _mock_file, mock_toml_load, _mock_makedirs, _mock_touch
@patch("builtins.open", new_callable=mock_open, read_data="")
def test_set_credentials_cleans_up_temp_on_dump_failure(
self, _mock_file, _mock_makedirs, _mock_touch, _mock_dump,
mock_mkstemp, _mock_fdopen, mock_unlink, mock_replace,
):
"""Temp file is removed and original config untouched when dump fails."""
mock_mkstemp.return_value = (99, "/tmp/cred.toml")
with self.assertRaises(OSError):
functions.set_credentials("KEY")

mock_unlink.assert_called_once_with("/tmp/cred.toml")
mock_replace.assert_not_called()

@patch("runpod.cli.groups.config.functions.os.replace")
@patch("runpod.cli.groups.config.functions.os.unlink")
@patch("runpod.cli.groups.config.functions.tempfile.mkstemp")
@patch("runpod.cli.groups.config.functions.tomlkit.dump")
@patch("runpod.cli.groups.config.functions.Path.touch")
@patch("runpod.cli.groups.config.functions.os.makedirs")
@patch("builtins.open", new_callable=mock_open, read_data="not valid toml {{{")
def test_set_credentials_corrupted_toml_raises(
self, _mock_file, _mock_makedirs, _mock_touch, _mock_dump,
_mock_mkstemp, _mock_unlink, _mock_replace,
):
"""set_credentials without overwrite treats corrupted file as empty."""
mock_toml_load.side_effect = ValueError("Invalid TOML")
# Should not raise — corrupted file is treated as having no profiles
functions.set_credentials("NEW_KEY", overwrite=False)
"""set_credentials raises ValueError on corrupted TOML regardless of overwrite."""
with self.assertRaises(ValueError):
functions.set_credentials("NEW_KEY", overwrite=True)
with self.assertRaises(ValueError):
functions.set_credentials("NEW_KEY", overwrite=False)
3 changes: 2 additions & 1 deletion tests/test_shared/test_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@ class TestAPIKey(unittest.TestCase):
"""

@patch("builtins.open", new_callable=mock_open, read_data=CREDENTIALS)
def test_use_file_credentials(self, mock_file):
@patch("runpod.cli.groups.config.functions.os.path.exists", return_value=True)
def test_use_file_credentials(self, _mock_exists, mock_file):
"""
Test that the API key is read from the credentials file
"""
Expand Down
Loading