diff --git a/runpod/cli/groups/config/functions.py b/runpod/cli/groups/config/functions.py index ac5043b8..72e9b738 100644 --- a/runpod/cli/groups/config/functions.py +++ b/runpod/cli/groups/config/functions.py @@ -6,9 +6,11 @@ """ import os +import tempfile from pathlib import Path import tomli as toml +import tomlkit CREDENTIAL_FILE = os.path.expanduser("~/.runpod/config.toml") @@ -16,7 +18,7 @@ def set_credentials(api_key: str, profile: str = "default", overwrite=False) -> None: """ Sets the user's credentials in ~/.runpod/config.toml - If profile already exists user must use `update_credentials` instead. + If profile already exists user must pass overwrite=True. Args: api_key (str): The user's API key. @@ -27,23 +29,37 @@ def set_credentials(api_key: str, profile: str = "default", overwrite=False) -> [default] api_key = "RUNPOD_API_KEY" """ - os.makedirs(os.path.dirname(CREDENTIAL_FILE), exist_ok=True) + cred_dir = os.path.dirname(CREDENTIAL_FILE) + os.makedirs(cred_dir, exist_ok=True) Path(CREDENTIAL_FILE).touch(exist_ok=True) - if not overwrite: + with open(CREDENTIAL_FILE, "r", encoding="UTF-8") as cred_file: try: - with open(CREDENTIAL_FILE, "rb") as cred_file: - existing = toml.load(cred_file) - except (TypeError, ValueError): - existing = {} - if profile in existing: + content = cred_file.read() + config = ( + tomlkit.parse(content) + if content.strip() + else tomlkit.document() + ) + except tomlkit.exceptions.ParseError as exc: + raise ValueError("~/.runpod/config.toml is not a valid TOML file.") from exc + + if not overwrite: + if profile in config: raise ValueError( - "Profile already exists. Use `update_credentials` instead." + "Profile already exists. Use set_credentials(overwrite=True) to update." ) - with open(CREDENTIAL_FILE, "w", encoding="UTF-8") as cred_file: - cred_file.write("[" + profile + "]\n") - cred_file.write('api_key = "' + api_key + '"\n') + config[profile] = {"api_key": api_key} + + fd, tmp_path = tempfile.mkstemp(dir=cred_dir, suffix=".toml") + try: + with os.fdopen(fd, "w", encoding="UTF-8") as tmp_file: + tomlkit.dump(config, tmp_file) + os.replace(tmp_path, CREDENTIAL_FILE) + except BaseException: + os.unlink(tmp_path) + raise def check_credentials(profile: str = "default"): diff --git a/tests/test_cli/test_cli_groups/test_config_functions.py b/tests/test_cli/test_cli_groups/test_config_functions.py index 192b2cca..abb5b118 100644 --- a/tests/test_cli/test_cli_groups/test_config_functions.py +++ b/tests/test_cli/test_cli_groups/test_config_functions.py @@ -14,26 +14,70 @@ class TestConfig(unittest.TestCase): def setUp(self) -> None: self.sample_credentials = "[default]\n" 'api_key = "RUNPOD_API_KEY"\n' - @patch("runpod.cli.groups.config.functions.toml.load") - @patch("builtins.open", new_callable=mock_open()) - def test_set_credentials(self, mock_file, mock_toml_load): + @patch("runpod.cli.groups.config.functions.os.replace") + @patch("runpod.cli.groups.config.functions.os.unlink") + @patch("runpod.cli.groups.config.functions.os.fdopen", new_callable=mock_open) + @patch("runpod.cli.groups.config.functions.tempfile.mkstemp") + @patch("runpod.cli.groups.config.functions.tomlkit.dump") + @patch("runpod.cli.groups.config.functions.tomlkit.document") + @patch("runpod.cli.groups.config.functions.Path.touch") + @patch("runpod.cli.groups.config.functions.os.makedirs") + @patch("builtins.open", new_callable=mock_open, read_data="") + def test_set_credentials( + self, mock_file, _mock_makedirs, _mock_touch, mock_document, + mock_dump, mock_mkstemp, _mock_fdopen, _mock_unlink, _mock_replace, + ): """ Tests the set_credentials function. """ - mock_toml_load.return_value = "" + mock_mkstemp.return_value = (99, "/tmp/cred.toml") + mock_document.side_effect = [{}, {"default": True}] functions.set_credentials("RUNPOD_API_KEY") - mock_file.assert_called_with(functions.CREDENTIAL_FILE, "w", encoding="UTF-8") + assert any( + call.args[0] == functions.CREDENTIAL_FILE + and call.args[1] == "r" + and call.kwargs.get("encoding") == "UTF-8" + for call in mock_file.call_args_list + ) + assert mock_dump.called with self.assertRaises(ValueError) as context: - mock_toml_load.return_value = {"default": True} functions.set_credentials("RUNPOD_API_KEY") self.assertEqual( str(context.exception), - "Profile already exists. Use `update_credentials` instead.", + "Profile already exists. Use set_credentials(overwrite=True) to update.", ) + @patch("runpod.cli.groups.config.functions.os.replace") + @patch("runpod.cli.groups.config.functions.os.unlink") + @patch("runpod.cli.groups.config.functions.os.fdopen", new_callable=mock_open) + @patch("runpod.cli.groups.config.functions.tempfile.mkstemp") + @patch("runpod.cli.groups.config.functions.tomlkit.dump") + @patch("runpod.cli.groups.config.functions.Path.touch") + @patch("runpod.cli.groups.config.functions.os.makedirs") + @patch( + "builtins.open", + new_callable=mock_open, + read_data='[default]\napi_key = "EXISTING_KEY"\n\n[profile1]\napi_key = "KEY1"\n', + ) + def test_set_credentials_preserves_existing_profiles( + self, _mock_file, _mock_makedirs, _mock_touch, mock_dump, + mock_mkstemp, _mock_fdopen, _mock_unlink, _mock_replace, + ): + """Adding a new profile must preserve all existing profiles.""" + mock_mkstemp.return_value = (99, "/tmp/cred.toml") + functions.set_credentials("NEW_KEY", profile="profile2") + + dumped_config = mock_dump.call_args[0][0] + assert "default" in dumped_config + assert dumped_config["default"]["api_key"] == "EXISTING_KEY" + assert "profile1" in dumped_config + assert dumped_config["profile1"]["api_key"] == "KEY1" + assert "profile2" in dumped_config + assert dumped_config["profile2"]["api_key"] == "NEW_KEY" + @patch("builtins.open", new_callable=mock_open()) @patch("runpod.cli.groups.config.functions.toml.load") @patch("runpod.cli.groups.config.functions.os.path.exists") @@ -124,26 +168,62 @@ def test_get_credentials_type_error( result = functions.get_credentials("default") assert result is None + @patch("runpod.cli.groups.config.functions.os.replace") + @patch("runpod.cli.groups.config.functions.os.unlink") + @patch("runpod.cli.groups.config.functions.os.fdopen", new_callable=mock_open) + @patch("runpod.cli.groups.config.functions.tempfile.mkstemp") + @patch("runpod.cli.groups.config.functions.tomlkit.dump") @patch("runpod.cli.groups.config.functions.Path.touch") @patch("runpod.cli.groups.config.functions.os.makedirs") - @patch("runpod.cli.groups.config.functions.toml.load") - @patch("builtins.open", new_callable=mock_open()) - def test_set_credentials_corrupted_toml_allows_overwrite( - self, _mock_file, mock_toml_load, _mock_makedirs, _mock_touch + @patch( + "builtins.open", + new_callable=mock_open, + read_data='[default]\napi_key = "OLD_KEY"\n', + ) + def test_set_credentials_overwrite_replaces_existing_profile( + self, _mock_file, _mock_makedirs, _mock_touch, mock_dump, + mock_mkstemp, _mock_fdopen, _mock_unlink, _mock_replace, ): - """set_credentials with overwrite=True ignores corrupted existing file.""" - mock_toml_load.side_effect = ValueError("Invalid TOML") - # overwrite=True skips the toml.load check entirely - functions.set_credentials("NEW_KEY", overwrite=True) - + """overwrite=True replaces an existing profile's api_key.""" + mock_mkstemp.return_value = (99, "/tmp/cred.toml") + functions.set_credentials("NEW_KEY", profile="default", overwrite=True) + + dumped_config = mock_dump.call_args[0][0] + assert dumped_config["default"]["api_key"] == "NEW_KEY" + + @patch("runpod.cli.groups.config.functions.os.replace") + @patch("runpod.cli.groups.config.functions.os.unlink") + @patch("runpod.cli.groups.config.functions.os.fdopen", new_callable=mock_open) + @patch("runpod.cli.groups.config.functions.tempfile.mkstemp") + @patch("runpod.cli.groups.config.functions.tomlkit.dump", side_effect=OSError("disk full")) @patch("runpod.cli.groups.config.functions.Path.touch") @patch("runpod.cli.groups.config.functions.os.makedirs") - @patch("runpod.cli.groups.config.functions.toml.load") - @patch("builtins.open", new_callable=mock_open()) - def test_set_credentials_corrupted_toml_no_overwrite( - self, _mock_file, mock_toml_load, _mock_makedirs, _mock_touch + @patch("builtins.open", new_callable=mock_open, read_data="") + def test_set_credentials_cleans_up_temp_on_dump_failure( + self, _mock_file, _mock_makedirs, _mock_touch, _mock_dump, + mock_mkstemp, _mock_fdopen, mock_unlink, mock_replace, + ): + """Temp file is removed and original config untouched when dump fails.""" + mock_mkstemp.return_value = (99, "/tmp/cred.toml") + with self.assertRaises(OSError): + functions.set_credentials("KEY") + + mock_unlink.assert_called_once_with("/tmp/cred.toml") + mock_replace.assert_not_called() + + @patch("runpod.cli.groups.config.functions.os.replace") + @patch("runpod.cli.groups.config.functions.os.unlink") + @patch("runpod.cli.groups.config.functions.tempfile.mkstemp") + @patch("runpod.cli.groups.config.functions.tomlkit.dump") + @patch("runpod.cli.groups.config.functions.Path.touch") + @patch("runpod.cli.groups.config.functions.os.makedirs") + @patch("builtins.open", new_callable=mock_open, read_data="not valid toml {{{") + def test_set_credentials_corrupted_toml_raises( + self, _mock_file, _mock_makedirs, _mock_touch, _mock_dump, + _mock_mkstemp, _mock_unlink, _mock_replace, ): - """set_credentials without overwrite treats corrupted file as empty.""" - mock_toml_load.side_effect = ValueError("Invalid TOML") - # Should not raise — corrupted file is treated as having no profiles - functions.set_credentials("NEW_KEY", overwrite=False) + """set_credentials raises ValueError on corrupted TOML regardless of overwrite.""" + with self.assertRaises(ValueError): + functions.set_credentials("NEW_KEY", overwrite=True) + with self.assertRaises(ValueError): + functions.set_credentials("NEW_KEY", overwrite=False) diff --git a/tests/test_shared/test_auth.py b/tests/test_shared/test_auth.py index a70292d3..00f85848 100644 --- a/tests/test_shared/test_auth.py +++ b/tests/test_shared/test_auth.py @@ -17,7 +17,8 @@ class TestAPIKey(unittest.TestCase): """ @patch("builtins.open", new_callable=mock_open, read_data=CREDENTIALS) - def test_use_file_credentials(self, mock_file): + @patch("runpod.cli.groups.config.functions.os.path.exists", return_value=True) + def test_use_file_credentials(self, _mock_exists, mock_file): """ Test that the API key is read from the credentials file """