diff --git a/src/auth/src/access_boundary.rs b/src/auth/src/access_boundary.rs index e27a3ac9e7..154cfb4a27 100644 --- a/src/auth/src/access_boundary.rs +++ b/src/auth/src/access_boundary.rs @@ -18,6 +18,7 @@ use crate::credentials::{ AccessToken, AccessTokenCredentialsProvider, CacheableResource, CredentialsProvider, dynamic, }; use crate::errors::CredentialsError; +use crate::io::{HttpRequest, SharedHttpClientProvider}; use crate::mds::client::Client as MDSClient; use crate::{Result, errors}; use google_cloud_gax::Result as GaxResult; @@ -28,7 +29,6 @@ use google_cloud_gax::retry_loop_internal::retry_loop; use google_cloud_gax::retry_policy::{Aip194Strict, RetryPolicy, RetryPolicyExt}; use google_cloud_gax::retry_throttler::{AdaptiveThrottler, RetryThrottlerArg}; use http::{Extensions, HeaderMap, HeaderValue}; -use reqwest::Client; use std::clone::Clone; use std::error::Error; use std::fmt::Debug; @@ -235,11 +235,16 @@ impl CredentialsWithAccessBoundary where T: dynamic::AccessTokenCredentialsProvider + 'static, { - pub(crate) fn new(credentials: T, access_boundary_url: Option) -> Self { + pub(crate) fn new( + credentials: T, + access_boundary_url: Option, + http: SharedHttpClientProvider, + ) -> Self { let credentials = Arc::new(credentials); let provider = IAMAccessBoundaryProvider { credentials: credentials.clone(), url: access_boundary_url, + http, }; let access_boundary = Arc::new(AccessBoundary::new(provider)); Self { @@ -253,6 +258,7 @@ where credentials: T, mds_client: MDSClient, iam_endpoint_override: Option, + http: SharedHttpClientProvider, ) -> Self { let credentials = Arc::new(credentials); let provider = MDSAccessBoundaryProvider { @@ -260,6 +266,7 @@ where mds_client, iam_endpoint_override, url: OnceLock::new(), + http, }; let access_boundary = Arc::new(AccessBoundary::new(provider)); Self { @@ -403,6 +410,7 @@ where { credentials: Arc, url: Option, + http: SharedHttpClientProvider, } #[async_trait::async_trait] @@ -413,7 +421,11 @@ where async fn fetch_access_boundary(&self) -> Result> { match self.url.as_ref() { Some(url) => { - let client = AccessBoundaryClient::new(self.credentials.clone(), url.clone()); + let client = AccessBoundaryClient::new( + self.credentials.clone(), + url.clone(), + SharedHttpClientProvider::clone(&self.http), + ); client.fetch().await } None => Ok(None), // No URL means no access boundary @@ -431,6 +443,7 @@ where mds_client: MDSClient, iam_endpoint_override: Option, url: OnceLock, + http: SharedHttpClientProvider, } #[async_trait::async_trait] @@ -449,7 +462,11 @@ where } let url = self.url.get().unwrap().to_string(); - let client = AccessBoundaryClient::new(self.credentials.clone(), url); + let client = AccessBoundaryClient::new( + self.credentials.clone(), + url, + SharedHttpClientProvider::clone(&self.http), + ); client.fetch().await } } @@ -472,10 +489,11 @@ struct AccessBoundaryClient { url: String, retry_policy: Arc, backoff_policy: Arc, + http: SharedHttpClientProvider, } impl AccessBoundaryClient { - fn new(credentials: Arc, url: String) -> Self { + fn new(credentials: Arc, url: String, http: SharedHttpClientProvider) -> Self { let retry_policy = Aip194Strict.with_time_limit(Duration::from_secs(60)); let backoff_policy = ExponentialBackoff::default(); @@ -484,6 +502,7 @@ impl AccessBoundaryClient { url, retry_policy: Arc::new(retry_policy), backoff_policy: Arc::new(backoff_policy), + http, } } } @@ -509,19 +528,19 @@ where } async fn fetch_with_retry(self) -> GaxResult { - let client = Client::new(); let sleep = async |d| tokio::time::sleep(d).await; let retry_throttler: RetryThrottlerArg = AdaptiveThrottler::default().into(); let creds = self.credentials; let url = self.url; + let http = self.http; let inner = async move |d| { let headers = creds .headers(Extensions::new()) .await .map_err(GaxError::authentication)?; - let attempt = self::fetch_access_boundary_call(&client, &url, headers); + let attempt = self::fetch_access_boundary_call(&http, &url, headers); match d { Some(timeout) => match tokio::time::timeout(timeout, attempt).await { Ok(r) => r, @@ -544,7 +563,7 @@ where } async fn fetch_access_boundary_call( - client: &Client, + http: &SharedHttpClientProvider, url: &str, headers: CacheableResource, ) -> GaxResult { @@ -555,24 +574,19 @@ async fn fetch_access_boundary_call( } }; - let resp = client - .get(url) - .headers(headers) - .send() - .await - .map_err(GaxError::io)?; - - let status = resp.status(); - if !status.is_success() { - let err_headers = resp.headers().clone(); - let err_payload = resp - .bytes() - .await - .map_err(|e| GaxError::transport(err_headers.clone(), e))?; - return Err(GaxError::http(status.as_u16(), err_headers, err_payload)); + let request = HttpRequest::get(url).headers_from_map(&headers); + + let response = http.execute(request).await.map_err(GaxError::io)?; + + if !response.is_success() { + return Err(GaxError::http( + response.status.as_u16(), + response.headers, + response.body.into(), + )); } - resp.json().await.map_err(GaxError::io) + response.json().map_err(GaxError::io) } async fn refresh_task(provider: T, tx_header: watch::Sender<(Option, EntityTag)>) @@ -750,7 +764,11 @@ pub(crate) mod tests { }); let url = server.url("/allowedLocations").to_string(); - let creds = CredentialsWithAccessBoundary::new(mock, Some(url)); + let creds = CredentialsWithAccessBoundary::new( + mock, + Some(url), + SharedHttpClientProvider::default(), + ); // wait for the background task to fetch the access boundary. creds.wait_for_boundary().await; @@ -804,9 +822,18 @@ pub(crate) mod tests { }) }); let endpoint = server.url("").to_string().trim_end_matches('/').to_string(); - let mds_client = MDSClient::new(Some(endpoint.clone())); + let mds_client = MDSClient::new( + Some(endpoint.clone()), + SharedEnvProvider::default(), + SharedHttpClientProvider::default(), + ); - let creds = CredentialsWithAccessBoundary::new_for_mds(mock, mds_client, Some(endpoint)); + let creds = CredentialsWithAccessBoundary::new_for_mds( + mock, + mds_client, + Some(endpoint), + SharedHttpClientProvider::default(), + ); // wait for the background task to fetch the access boundary. creds.wait_for_boundary().await; @@ -849,7 +876,8 @@ pub(crate) mod tests { }) }); let url = server.url("/allowedLocations").to_string(); - let client = AccessBoundaryClient::new(Arc::new(mock), url); + let client = + AccessBoundaryClient::new(Arc::new(mock), url, SharedHttpClientProvider::default()); let val = client.fetch().await?; assert!(val.is_none(), "{val:?}"); @@ -879,7 +907,8 @@ pub(crate) mod tests { }); let url = server.url("/allowedLocations").to_string(); - let mut client = AccessBoundaryClient::new(Arc::new(mock), url); + let mut client = + AccessBoundaryClient::new(Arc::new(mock), url, SharedHttpClientProvider::default()); client.retry_policy = Arc::new(Aip194Strict.with_attempt_limit(3)); client.backoff_policy = Arc::new(test_backoff_policy()); @@ -899,7 +928,11 @@ pub(crate) mod tests { )) }); - let client = AccessBoundaryClient::new(Arc::new(mock), "http://localhost".to_string()); + let client = AccessBoundaryClient::new( + Arc::new(mock), + "http://localhost".to_string(), + SharedHttpClientProvider::default(), + ); let err = client.fetch().await.unwrap_err(); assert!(!err.is_transient(), "{err:?}"); } @@ -918,7 +951,8 @@ pub(crate) mod tests { data: headers, }) }); - let creds = CredentialsWithAccessBoundary::new(mock, None); + let creds = + CredentialsWithAccessBoundary::new(mock, None, SharedHttpClientProvider::default()); let cached_headers = creds.headers(Extensions::new()).await?; let token = get_token_from_headers(cached_headers.clone()); @@ -1240,7 +1274,8 @@ pub(crate) mod tests { }); let url = server.url("/allowedLocations").to_string(); - let mut client = AccessBoundaryClient::new(Arc::new(mock), url); + let mut client = + AccessBoundaryClient::new(Arc::new(mock), url, SharedHttpClientProvider::default()); client.backoff_policy = Arc::new(test_backoff_policy()); let val = client.fetch().await?; assert_eq!(val.as_deref(), Some("0x123")); diff --git a/src/auth/src/credentials.rs b/src/auth/src/credentials.rs index 0d23ad691c..abf45b3d8a 100644 --- a/src/auth/src/credentials.rs +++ b/src/auth/src/credentials.rs @@ -19,6 +19,10 @@ use crate::build_errors::Error as BuilderError; use crate::constants::GOOGLE_CLOUD_QUOTA_PROJECT_VAR; use crate::errors::{self, CredentialsError}; +use crate::io::{ + EnvProvider, FsProvider, HttpClientProvider, IoConfig, SharedEnvProvider, SharedFsProvider, + SharedHttpClientProvider, +}; use crate::token::Token; use crate::{BuildResult, Result}; use http::{Extensions, HeaderMap}; @@ -429,6 +433,7 @@ pub(crate) mod dynamic { pub struct Builder { quota_project_id: Option, scopes: Option>, + providers: IoConfig, } impl Default for Builder { @@ -447,6 +452,7 @@ impl Default for Builder { Self { quota_project_id: None, scopes: None, + providers: IoConfig::default(), } } } @@ -506,6 +512,39 @@ impl Builder { self } + /// Sets a custom environment variable provider. + /// + /// When set, the auth crate will use this provider for all environment + /// variable lookups during credential construction instead of reading + /// from the process environment directly. + pub fn with_env_provider(mut self, provider: impl EnvProvider + 'static) -> Self { + self.providers.env = SharedEnvProvider::new(provider); + self + } + + /// Sets a custom filesystem provider. + /// + /// When set, the auth crate will use this provider for all file read + /// operations during credential construction instead of reading from + /// the real filesystem directly. + pub fn with_fs_provider(mut self, provider: impl FsProvider + 'static) -> Self { + self.providers.fs = SharedFsProvider::new(provider); + self + } + + /// Sets a custom HTTP client provider. + /// + /// When set, the auth crate will use this provider for all HTTP + /// requests during credential construction and token retrieval instead + /// of using `reqwest::Client` directly. + pub fn with_http_client_provider( + mut self, + provider: impl HttpClientProvider + 'static, + ) -> Self { + self.providers.http = SharedHttpClientProvider::new(provider); + self + } + /// Returns a [Credentials] instance with the configured settings. /// /// # Errors @@ -546,16 +585,18 @@ impl Builder { /// /// [application-default credentials]: https://cloud.google.com/docs/authentication/application-default-credentials pub fn build_access_token_credentials(self) -> BuildResult { - let json_data = match load_adc()? { + let json_data = match load_adc(&self.providers.env, &self.providers.fs)? { AdcContents::Contents(contents) => { Some(serde_json::from_str(&contents).map_err(BuilderError::parsing)?) } AdcContents::FallbackToMds => None, }; - let quota_project_id = std::env::var(GOOGLE_CLOUD_QUOTA_PROJECT_VAR) - .ok() + let quota_project_id = self + .providers + .env + .var(GOOGLE_CLOUD_QUOTA_PROJECT_VAR) .or(self.quota_project_id); - build_credentials(json_data, quota_project_id, self.scopes) + build_credentials(json_data, quota_project_id, self.scopes, self.providers) } /// Returns a [crate::signer::Signer] instance with the configured settings. @@ -576,16 +617,18 @@ impl Builder { /// # Ok(()) } /// ``` pub fn build_signer(self) -> BuildResult { - let json_data = match load_adc()? { + let json_data = match load_adc(&self.providers.env, &self.providers.fs)? { AdcContents::Contents(contents) => { Some(serde_json::from_str(&contents).map_err(BuilderError::parsing)?) } AdcContents::FallbackToMds => None, }; - let quota_project_id = std::env::var(GOOGLE_CLOUD_QUOTA_PROJECT_VAR) - .ok() + let quota_project_id = self + .providers + .env + .var(GOOGLE_CLOUD_QUOTA_PROJECT_VAR) .or(self.quota_project_id); - build_signer(json_data, quota_project_id, self.scopes) + build_signer(json_data, quota_project_id, self.scopes, self.providers) } } @@ -596,7 +639,7 @@ enum AdcPath { } #[derive(Debug, PartialEq)] -enum AdcContents { +pub(crate) enum AdcContents { Contents(String), FallbackToMds, } @@ -616,12 +659,13 @@ fn extract_credential_type(json: &Value) -> BuildResult<&str> { /// `mds::Builder`, `service_account::Builder`, etc.) before calling `.build()`. /// It helps avoid repetitive code in the `build_credentials` function. macro_rules! config_builder { - ($builder_instance:expr, $quota_project_id_option:expr, $scopes_option:expr, $apply_scopes_closure:expr) => {{ + ($builder_instance:expr, $quota_project_id_option:expr, $scopes_option:expr, $apply_scopes_closure:expr, $providers:expr) => {{ let builder = config_common_builder!( $builder_instance, $quota_project_id_option, $scopes_option, - $apply_scopes_closure + $apply_scopes_closure, + $providers ); builder.build_access_token_credentials() }}; @@ -630,20 +674,25 @@ macro_rules! config_builder { /// Applies common optional configurations (quota project ID, scopes) to a /// specific credential builder instance and then return a signer for it. macro_rules! config_signer { - ($builder_instance:expr, $quota_project_id_option:expr, $scopes_option:expr, $apply_scopes_closure:expr) => {{ + ($builder_instance:expr, $quota_project_id_option:expr, $scopes_option:expr, $apply_scopes_closure:expr, $providers:expr) => {{ let builder = config_common_builder!( $builder_instance, $quota_project_id_option, $scopes_option, - $apply_scopes_closure + $apply_scopes_closure, + $providers ); builder.build_signer() }}; } macro_rules! config_common_builder { - ($builder_instance:expr, $quota_project_id_option:expr, $scopes_option:expr, $apply_scopes_closure:expr) => {{ - let builder = $builder_instance; + ($builder_instance:expr, $quota_project_id_option:expr, $scopes_option:expr, $apply_scopes_closure:expr, $providers:expr) => {{ + let providers = $providers; + let builder = $builder_instance + .with_env_provider(SharedEnvProvider::clone(&providers.env)) + .with_fs_provider(SharedFsProvider::clone(&providers.fs)) + .with_http_client_provider(SharedHttpClientProvider::clone(&providers.http)); let builder = $quota_project_id_option .into_iter() .fold(builder, |b, qp| b.with_quota_project_id(qp)); @@ -660,13 +709,15 @@ fn build_credentials( json: Option, quota_project_id: Option, scopes: Option>, + providers: IoConfig, ) -> BuildResult { match json { None => config_builder!( mds::Builder::from_adc(), quota_project_id, scopes, - |b: mds::Builder, s: Vec| b.with_scopes(s) + |b: mds::Builder, s: Vec| b.with_scopes(s), + providers ), Some(json) => { let cred_type = extract_credential_type(&json)?; @@ -676,7 +727,8 @@ fn build_credentials( user_account::Builder::new(json), quota_project_id, scopes, - |b: user_account::Builder, s: Vec| b.with_scopes(s) + |b: user_account::Builder, s: Vec| b.with_scopes(s), + providers ) } "service_account" => config_builder!( @@ -684,21 +736,24 @@ fn build_credentials( quota_project_id, scopes, |b: service_account::Builder, s: Vec| b - .with_access_specifier(service_account::AccessSpecifier::from_scopes(s)) + .with_access_specifier(service_account::AccessSpecifier::from_scopes(s)), + providers ), "impersonated_service_account" => { config_builder!( impersonated::Builder::new(json), quota_project_id, scopes, - |b: impersonated::Builder, s: Vec| b.with_scopes(s) + |b: impersonated::Builder, s: Vec| b.with_scopes(s), + providers ) } "external_account" => config_builder!( external_account::Builder::new(json), quota_project_id, scopes, - |b: external_account::Builder, s: Vec| b.with_scopes(s) + |b: external_account::Builder, s: Vec| b.with_scopes(s), + providers ), _ => Err(BuilderError::unknown_type(cred_type)), } @@ -710,13 +765,15 @@ fn build_signer( json: Option, quota_project_id: Option, scopes: Option>, + providers: IoConfig, ) -> BuildResult { match json { None => config_signer!( mds::Builder::from_adc(), quota_project_id, scopes, - |b: mds::Builder, s: Vec| b.with_scopes(s) + |b: mds::Builder, s: Vec| b.with_scopes(s), + providers ), Some(json) => { let cred_type = extract_credential_type(&json)?; @@ -729,14 +786,16 @@ fn build_signer( quota_project_id, scopes, |b: service_account::Builder, s: Vec| b - .with_access_specifier(service_account::AccessSpecifier::from_scopes(s)) + .with_access_specifier(service_account::AccessSpecifier::from_scopes(s)), + providers ), "impersonated_service_account" => { config_signer!( impersonated::Builder::new(json), quota_project_id, scopes, - |b: impersonated::Builder, s: Vec| b.with_scopes(s) + |b: impersonated::Builder, s: Vec| b.with_scopes(s), + providers ) } "external_account" => Err(BuilderError::not_supported( @@ -759,39 +818,37 @@ fn path_not_found(path: String) -> BuilderError { )) } -fn load_adc() -> BuildResult { - match adc_path() { +pub(crate) fn load_adc(env: &SharedEnvProvider, fs: &SharedFsProvider) -> BuildResult { + match adc_path(env) { None => Ok(AdcContents::FallbackToMds), - Some(AdcPath::FromEnv(path)) => match std::fs::read_to_string(&path) { + Some(AdcPath::FromEnv(path)) => match fs.read_to_string(&path) { Ok(contents) => Ok(AdcContents::Contents(contents)), Err(e) if e.kind() == std::io::ErrorKind::NotFound => Err(path_not_found(path)), Err(e) => Err(BuilderError::loading(e)), }, - Some(AdcPath::WellKnown(path)) => match std::fs::read_to_string(path) { + Some(AdcPath::WellKnown(path)) => match fs.read_to_string(&path) { Ok(contents) => Ok(AdcContents::Contents(contents)), Err(e) if e.kind() == std::io::ErrorKind::NotFound => Ok(AdcContents::FallbackToMds), Err(e) => Err(BuilderError::loading(e)), }, } } - /// The path to Application Default Credentials (ADC), as specified in [AIP-4110]. /// /// [AIP-4110]: https://google.aip.dev/auth/4110 -fn adc_path() -> Option { - if let Ok(path) = std::env::var("GOOGLE_APPLICATION_CREDENTIALS") { +fn adc_path(env: &SharedEnvProvider) -> Option { + if let Some(path) = env.var("GOOGLE_APPLICATION_CREDENTIALS") { return Some(AdcPath::FromEnv(path)); } - Some(AdcPath::WellKnown(adc_well_known_path()?)) + Some(AdcPath::WellKnown(adc_well_known_path(env)?)) } /// The well-known path to ADC on Windows, as specified in [AIP-4113]. /// /// [AIP-4113]: https://google.aip.dev/auth/4113 #[cfg(target_os = "windows")] -fn adc_well_known_path() -> Option { - std::env::var("APPDATA") - .ok() +fn adc_well_known_path(env: &SharedEnvProvider) -> Option { + env.var("APPDATA") .map(|root| root + "/gcloud/application_default_credentials.json") } @@ -799,9 +856,8 @@ fn adc_well_known_path() -> Option { /// /// [AIP-4113]: https://google.aip.dev/auth/4113 #[cfg(not(target_os = "windows"))] -fn adc_well_known_path() -> Option { - std::env::var("HOME") - .ok() +fn adc_well_known_path(env: &SharedEnvProvider) -> Option { + env.var("HOME") .map(|root| root + "/.config/gcloud/application_default_credentials.json") } @@ -1050,12 +1106,13 @@ pub(crate) mod tests { fn adc_well_known_path_windows() { let _creds = ScopedEnv::remove("GOOGLE_APPLICATION_CREDENTIALS"); let _appdata = ScopedEnv::set("APPDATA", "C:/Users/foo"); + let env = SharedEnvProvider::default(); assert_eq!( - adc_well_known_path(), + adc_well_known_path(&env), Some("C:/Users/foo/gcloud/application_default_credentials.json".to_string()) ); assert_eq!( - adc_path(), + adc_path(&env), Some(AdcPath::WellKnown( "C:/Users/foo/gcloud/application_default_credentials.json".to_string() )) @@ -1068,8 +1125,9 @@ pub(crate) mod tests { fn adc_well_known_path_windows_no_appdata() { let _creds = ScopedEnv::remove("GOOGLE_APPLICATION_CREDENTIALS"); let _appdata = ScopedEnv::remove("APPDATA"); - assert_eq!(adc_well_known_path(), None); - assert_eq!(adc_path(), None); + let env = SharedEnvProvider::default(); + assert_eq!(adc_well_known_path(&env), None); + assert_eq!(adc_path(&env), None); } #[cfg(not(target_os = "windows"))] @@ -1078,12 +1136,13 @@ pub(crate) mod tests { fn adc_well_known_path_posix() { let _creds = ScopedEnv::remove("GOOGLE_APPLICATION_CREDENTIALS"); let _home = ScopedEnv::set("HOME", "/home/foo"); + let env = SharedEnvProvider::default(); assert_eq!( - adc_well_known_path(), + adc_well_known_path(&env), Some("/home/foo/.config/gcloud/application_default_credentials.json".to_string()) ); assert_eq!( - adc_path(), + adc_path(&env), Some(AdcPath::WellKnown( "/home/foo/.config/gcloud/application_default_credentials.json".to_string() )) @@ -1096,8 +1155,9 @@ pub(crate) mod tests { fn adc_well_known_path_posix_no_home() { let _creds = ScopedEnv::remove("GOOGLE_APPLICATION_CREDENTIALS"); let _appdata = ScopedEnv::remove("HOME"); - assert_eq!(adc_well_known_path(), None); - assert_eq!(adc_path(), None); + let env = SharedEnvProvider::default(); + assert_eq!(adc_well_known_path(&env), None); + assert_eq!(adc_path(&env), None); } #[test] @@ -1107,56 +1167,65 @@ pub(crate) mod tests { "GOOGLE_APPLICATION_CREDENTIALS", "/usr/bar/application_default_credentials.json", ); + let env = SharedEnvProvider::default(); assert_eq!( - adc_path(), + adc_path(&env), Some(AdcPath::FromEnv( "/usr/bar/application_default_credentials.json".to_string() )) ); } - #[test] + #[tokio::test] #[serial_test::serial] - fn load_adc_no_well_known_path_fallback_to_mds() { + async fn load_adc_no_well_known_path_fallback_to_mds() { let _e1 = ScopedEnv::remove("GOOGLE_APPLICATION_CREDENTIALS"); let _e2 = ScopedEnv::remove("HOME"); // For posix let _e3 = ScopedEnv::remove("APPDATA"); // For windows - assert_eq!(load_adc().unwrap(), AdcContents::FallbackToMds); + let env = SharedEnvProvider::default(); + let fs = SharedFsProvider::default(); + assert_eq!(load_adc(&env, &fs).unwrap(), AdcContents::FallbackToMds); } - #[test] + #[tokio::test] #[serial_test::serial] - fn load_adc_no_file_at_well_known_path_fallback_to_mds() { + async fn load_adc_no_file_at_well_known_path_fallback_to_mds() { // Create a new temp directory. There is not an ADC file in here. let dir = tempfile::TempDir::new().unwrap(); let path = dir.path().to_str().unwrap(); let _e1 = ScopedEnv::remove("GOOGLE_APPLICATION_CREDENTIALS"); let _e2 = ScopedEnv::set("HOME", path); // For posix let _e3 = ScopedEnv::set("APPDATA", path); // For windows - assert_eq!(load_adc().unwrap(), AdcContents::FallbackToMds); + let env = SharedEnvProvider::default(); + let fs = SharedFsProvider::default(); + assert_eq!(load_adc(&env, &fs).unwrap(), AdcContents::FallbackToMds); } - #[test] + #[tokio::test] #[serial_test::serial] - fn load_adc_no_file_at_env_is_error() { + async fn load_adc_no_file_at_env_is_error() { let _e = ScopedEnv::set("GOOGLE_APPLICATION_CREDENTIALS", "file-does-not-exist.json"); - let err = load_adc().unwrap_err(); + let env = SharedEnvProvider::default(); + let fs = SharedFsProvider::default(); + let err = load_adc(&env, &fs).unwrap_err(); assert!(err.is_loading(), "{err:?}"); let msg = format!("{err:?}"); assert!(msg.contains("file-does-not-exist.json"), "{err:?}"); assert!(msg.contains("GOOGLE_APPLICATION_CREDENTIALS"), "{err:?}"); } - #[test] + #[tokio::test] #[serial_test::serial] - fn load_adc_success() { + async fn load_adc_success() { let file = tempfile::NamedTempFile::new().unwrap(); let path = file.into_temp_path(); std::fs::write(&path, "contents").expect("Unable to write to temporary file."); let _e = ScopedEnv::set("GOOGLE_APPLICATION_CREDENTIALS", path.to_str().unwrap()); + let env = SharedEnvProvider::default(); + let fs = SharedFsProvider::default(); assert_eq!( - load_adc().unwrap(), + load_adc(&env, &fs).unwrap(), AdcContents::Contents("contents".to_string()) ); } diff --git a/src/auth/src/credentials/external_account.rs b/src/auth/src/credentials/external_account.rs index 9c657f8bd7..f7e427597c 100644 --- a/src/auth/src/credentials/external_account.rs +++ b/src/auth/src/credentials/external_account.rs @@ -122,6 +122,10 @@ use crate::credentials::subject_token::dynamic; use crate::credentials::{AccessToken, AccessTokenCredentials}; use crate::errors::non_retryable; use crate::headers_util::AuthHeadersBuilder; +use crate::io::{ + EnvProvider, FsProvider, HttpClientProvider, IoConfig, SharedEnvProvider, SharedFsProvider, + SharedHttpClientProvider, +}; use crate::retry::Builder as RetryTokenProviderBuilder; use crate::token::{CachedTokenProvider, Token, TokenProvider}; use crate::token_cache::TokenCache; @@ -219,50 +223,44 @@ impl From for ExternalAccountConfig { if scope.is_empty() { scope.push(DEFAULT_SCOPE.to_string()); } - Self { - audience: config.audience.clone(), - client_id: config.client_id, - client_secret: config.client_secret, - subject_token_type: config.subject_token_type, - token_url: config.token_url, - service_account_impersonation_url: config.service_account_impersonation_url, - credential_source: CredentialSource::from_file( - config.credential_source, - &config.audience, - ), - scopes: scope, - workforce_pool_user_project: config.workforce_pool_user_project, - } - } -} - -impl CredentialSource { - fn from_file(source: CredentialSourceFile, audience: &str) -> Self { - match source { + let credential_source = match config.credential_source { CredentialSourceFile::Url { url, headers, format, - } => Self::Url(UrlSourcedCredentials::new(url, headers, format)), + } => CredentialSource::Url { + url, + headers, + format, + }, CredentialSourceFile::Executable { executable } => { - Self::Executable(ExecutableSourcedCredentials::new(executable)) - } - CredentialSourceFile::File { file, format } => { - Self::File(FileSourcedCredentials::new(file, format)) + CredentialSource::Executable(executable) } + CredentialSourceFile::File { file, format } => CredentialSource::File { file, format }, CredentialSourceFile::Aws { region_url, url, regional_cred_verification_url, imdsv2_session_token_url, .. - } => Self::Aws(AwsSourcedCredentials::new( + } => CredentialSource::Aws { region_url, url, regional_cred_verification_url, imdsv2_session_token_url, - audience.to_string(), - )), + audience: config.audience.clone(), + }, + }; + Self { + audience: config.audience, + client_id: config.client_id, + client_secret: config.client_secret, + subject_token_type: config.subject_token_type, + token_url: config.token_url, + service_account_impersonation_url: config.service_account_impersonation_url, + credential_source, + scopes: scope, + workforce_pool_user_project: config.workforce_pool_user_project, } } } @@ -375,10 +373,23 @@ impl ExternalAccountConfigBuilder { #[derive(Debug, Clone)] enum CredentialSource { - Url(UrlSourcedCredentials), - Executable(ExecutableSourcedCredentials), - File(FileSourcedCredentials), - Aws(AwsSourcedCredentials), + Url { + url: String, + headers: Option>, + format: Option, + }, + Executable(ExecutableConfig), + File { + file: String, + format: Option, + }, + Aws { + region_url: Option, + url: Option, + regional_cred_verification_url: Option, + imdsv2_session_token_url: Option, + audience: String, + }, Programmatic(ProgrammaticSourcedCredentials), } @@ -387,24 +398,65 @@ impl ExternalAccountConfig { self, quota_project_id: Option, retry_builder: RetryTokenProviderBuilder, + http: SharedHttpClientProvider, + fs: SharedFsProvider, + env: SharedEnvProvider, ) -> ExternalAccountCredentials { let config = self.clone(); match self.credential_source { - CredentialSource::Url(source) => { - Self::make_credentials_from_source(source, config, quota_project_id, retry_builder) - } - CredentialSource::Executable(source) => { - Self::make_credentials_from_source(source, config, quota_project_id, retry_builder) - } - CredentialSource::Programmatic(source) => { - Self::make_credentials_from_source(source, config, quota_project_id, retry_builder) - } - CredentialSource::File(source) => { - Self::make_credentials_from_source(source, config, quota_project_id, retry_builder) - } - CredentialSource::Aws(source) => { - Self::make_credentials_from_source(source, config, quota_project_id, retry_builder) - } + CredentialSource::Url { + url, + headers, + format, + } => Self::make_credentials_from_source( + UrlSourcedCredentials::new(url, headers, format, http.clone()), + config, + quota_project_id, + retry_builder, + http, + ), + CredentialSource::Executable(executable) => Self::make_credentials_from_source( + ExecutableSourcedCredentials::new(executable, env, fs), + config, + quota_project_id, + retry_builder, + http, + ), + CredentialSource::Programmatic(source) => Self::make_credentials_from_source( + source, + config, + quota_project_id, + retry_builder, + http, + ), + CredentialSource::File { file, format } => Self::make_credentials_from_source( + FileSourcedCredentials::new(file, format, fs), + config, + quota_project_id, + retry_builder, + http, + ), + CredentialSource::Aws { + region_url, + url, + regional_cred_verification_url, + imdsv2_session_token_url, + audience, + } => Self::make_credentials_from_source( + AwsSourcedCredentials::new( + region_url, + url, + regional_cred_verification_url, + imdsv2_session_token_url, + audience, + env, + http.clone(), + ), + config, + quota_project_id, + retry_builder, + http, + ), } } @@ -413,6 +465,7 @@ impl ExternalAccountConfig { config: ExternalAccountConfig, quota_project_id: Option, retry_builder: RetryTokenProviderBuilder, + http: SharedHttpClientProvider, ) -> ExternalAccountCredentials where T: dynamic::SubjectTokenProvider + 'static, @@ -420,6 +473,7 @@ impl ExternalAccountConfig { let token_provider = ExternalAccountTokenProvider { subject_token_provider, config, + http, }; let token_provider_with_retry = retry_builder.build(token_provider); let cache = TokenCache::new(token_provider_with_retry); @@ -437,6 +491,7 @@ where { subject_token_provider: T, config: ExternalAccountConfig, + http: SharedHttpClientProvider, } #[async_trait::async_trait] @@ -489,7 +544,7 @@ where ..ExchangeTokenRequest::default() }; - let token_res = STSHandler::exchange_token(req).await?; + let token_res = STSHandler::exchange_token(req, &self.http).await?; if let Some(impersonation_url) = &self.config.service_account_impersonation_url { let mut headers = HeaderMap::new(); @@ -505,6 +560,7 @@ where user_scopes, impersonated::DEFAULT_LIFETIME, impersonation_url, + &self.http, ) .await; } @@ -577,6 +633,7 @@ pub struct Builder { scopes: Option>, retry_builder: RetryTokenProviderBuilder, iam_endpoint_override: Option, + providers: IoConfig, } impl Builder { @@ -590,6 +647,7 @@ impl Builder { scopes: None, retry_builder: RetryTokenProviderBuilder::default(), iam_endpoint_override: None, + providers: IoConfig::default(), } } @@ -709,6 +767,39 @@ impl Builder { self } + /// Sets a custom environment variable provider. + /// + /// When set, the auth crate will use this provider for all environment + /// variable lookups during external account credential construction + /// instead of reading from the process environment directly. + pub fn with_env_provider(mut self, provider: impl EnvProvider + 'static) -> Self { + self.providers.env = SharedEnvProvider::new(provider); + self + } + + /// Sets a custom filesystem provider. + /// + /// When set, the auth crate will use this provider for all file read + /// operations during external account credential construction instead + /// of reading from the real filesystem directly. + pub fn with_fs_provider(mut self, provider: impl FsProvider + 'static) -> Self { + self.providers.fs = SharedFsProvider::new(provider); + self + } + + /// Sets a custom HTTP client provider. + /// + /// When set, the auth crate will use this provider for all HTTP + /// requests during external account credential construction and token + /// retrieval instead of using `reqwest::Client` directly. + pub fn with_http_client_provider( + mut self, + provider: impl HttpClientProvider + 'static, + ) -> Self { + self.providers.http = SharedHttpClientProvider::new(provider); + self + } + /// Returns a [Credentials] instance with the configured settings. /// /// # Errors @@ -766,11 +857,21 @@ impl Builder { let access_boundary_url = external_account_lookup_url(&config.audience, self.iam_endpoint_override.as_deref()); - let creds = config.make_credentials(self.quota_project_id, self.retry_builder); + let http = SharedHttpClientProvider::clone(&self.providers.http); + let fs = SharedFsProvider::clone(&self.providers.fs); + let env = SharedEnvProvider::clone(&self.providers.env); + let creds = config.make_credentials( + self.quota_project_id, + self.retry_builder, + http.clone(), + fs, + env, + ); Ok(CredentialsWithAccessBoundary::new( creds, access_boundary_url, + http, )) } } @@ -826,6 +927,7 @@ pub struct ProgrammaticBuilder { quota_project_id: Option, config: ExternalAccountConfigBuilder, retry_builder: RetryTokenProviderBuilder, + providers: IoConfig, } impl ProgrammaticBuilder { @@ -871,6 +973,7 @@ impl ProgrammaticBuilder { quota_project_id: None, config, retry_builder: RetryTokenProviderBuilder::default(), + providers: IoConfig::default(), } } @@ -1369,6 +1472,39 @@ impl ProgrammaticBuilder { self } + /// Sets a custom environment variable provider. + /// + /// When set, the auth crate will use this provider for all environment + /// variable lookups during external account credential construction + /// instead of reading from the process environment directly. + pub fn with_env_provider(mut self, provider: impl EnvProvider + 'static) -> Self { + self.providers.env = SharedEnvProvider::new(provider); + self + } + + /// Sets a custom filesystem provider. + /// + /// When set, the auth crate will use this provider for all file read + /// operations during external account credential construction instead + /// of reading from the real filesystem directly. + pub fn with_fs_provider(mut self, provider: impl FsProvider + 'static) -> Self { + self.providers.fs = SharedFsProvider::new(provider); + self + } + + /// Sets a custom HTTP client provider. + /// + /// When set, the auth crate will use this provider for all HTTP + /// requests during external account credential construction and token + /// retrieval instead of using `reqwest::Client` directly. + pub fn with_http_client_provider( + mut self, + provider: impl HttpClientProvider + 'static, + ) -> Self { + self.providers.http = SharedHttpClientProvider::new(provider); + self + } + /// Returns a [Credentials] instance with the configured settings. /// /// # Errors @@ -1376,8 +1512,11 @@ impl ProgrammaticBuilder { /// Returns a [BuilderError] if any of the required fields (such as /// `audience` or `subject_token_type`) have not been set. pub fn build(self) -> BuildResult { + let http = SharedHttpClientProvider::clone(&self.providers.http); + let fs = SharedFsProvider::clone(&self.providers.fs); + let env = SharedEnvProvider::clone(&self.providers.env); let (config, quota_project_id, retry_builder) = self.build_components()?; - let creds = config.make_credentials(quota_project_id, retry_builder); + let creds = config.make_credentials(quota_project_id, retry_builder, http, fs, env); Ok(Credentials { inner: Arc::new(creds), }) @@ -1395,6 +1534,7 @@ impl ProgrammaticBuilder { quota_project_id, config, retry_builder, + providers: _, } = self; let mut config_builder = config; @@ -1536,14 +1676,25 @@ mod tests { let source = config.credential_source; match source { - CredentialSource::Url(source) => { - assert_eq!(source.url, "https://example.com/token"); + CredentialSource::Url { + url, + headers, + format, + } => { + assert_eq!(url, "https://example.com/token"); assert_eq!( - source.headers, - HashMap::from([("Metadata".to_string(), "True".to_string()),]), + headers, + Some(HashMap::from([( + "Metadata".to_string(), + "True".to_string() + ),])), + ); + let fmt = format.unwrap(); + assert_eq!(fmt.format_type, "json"); + assert_eq!( + fmt.subject_token_field_name, + Some("access_token".to_string()) ); - assert_eq!(source.format, "json"); - assert_eq!(source.subject_token_field_name, "access_token"); } _ => { unreachable!("expected Url Sourced credential") @@ -1573,11 +1724,10 @@ mod tests { let source = config.credential_source; match source { - CredentialSource::Executable(source) => { - assert_eq!(source.command, "cat"); - assert_eq!(source.args, vec!["/some/file"]); - assert_eq!(source.output_file.as_deref(), Some("/some/file")); - assert_eq!(source.timeout, Duration::from_secs(5)); + CredentialSource::Executable(executable) => { + assert_eq!(executable.command, "cat /some/file"); + assert_eq!(executable.output_file.as_deref(), Some("/some/file")); + assert_eq!(executable.timeout_millis, Some(5000)); } _ => { unreachable!("expected Executable Sourced credential") @@ -1607,10 +1757,11 @@ mod tests { let source = config.credential_source; match source { - CredentialSource::File(source) => { - assert_eq!(source.file, "/foo/bar"); - assert_eq!(source.format, "json"); - assert_eq!(source.subject_token_field_name, "token"); + CredentialSource::File { file, format } => { + assert_eq!(file, "/foo/bar"); + let fmt = format.unwrap(); + assert_eq!(fmt.format_type, "json"); + assert_eq!(fmt.subject_token_field_name, Some("token".to_string())); } _ => { unreachable!("expected File Sourced credential") @@ -2393,10 +2544,9 @@ mod tests { let config: ExternalAccountConfig = file.into(); match config.credential_source { - CredentialSource::File(source) => { - assert_eq!(source.file, "/var/run/service-account/token"); - assert_eq!(source.format, "text"); // Default format - assert_eq!(source.subject_token_field_name, ""); // Default empty + CredentialSource::File { file, format } => { + assert_eq!(file, "/var/run/service-account/token"); + assert!(format.is_none()); // No format specified } _ => { unreachable!("expected File sourced credential") @@ -2426,10 +2576,11 @@ mod tests { let config: ExternalAccountConfig = file.into(); match config.credential_source { - CredentialSource::File(source) => { - assert_eq!(source.file, "/var/run/service-account/token"); - assert_eq!(source.format, "text"); - assert_eq!(source.subject_token_field_name, ""); // Empty for text format + CredentialSource::File { file, format } => { + assert_eq!(file, "/var/run/service-account/token"); + let fmt = format.unwrap(); + assert_eq!(fmt.format_type, "text"); + assert_eq!(fmt.subject_token_field_name, None); // No field name for text format } _ => { unreachable!("expected File sourced credential") @@ -2463,23 +2614,28 @@ mod tests { let config: ExternalAccountConfig = file.into(); match config.credential_source { - CredentialSource::Aws(source) => { + CredentialSource::Aws { + region_url, + regional_cred_verification_url, + imdsv2_session_token_url, + .. + } => { assert_eq!( - source.region_url, + region_url, Some( "http://169.254.169.254/latest/meta-data/placement/availability-zone" .to_string() ) ); assert_eq!( - source.regional_cred_verification_url, + regional_cred_verification_url, Some( "https://sts.{region}.amazonaws.com?Action=GetCallerIdentity&Version=2011-06-15" .to_string() ) ); assert_eq!( - source.imdsv2_session_token_url, + imdsv2_session_token_url, Some("http://169.254.169.254/latest/api/token".to_string()) ); } diff --git a/src/auth/src/credentials/external_account_sources/aws_sourced.rs b/src/auth/src/credentials/external_account_sources/aws_sourced.rs index 394bf35970..90ee066450 100644 --- a/src/auth/src/credentials/external_account_sources/aws_sourced.rs +++ b/src/auth/src/credentials/external_account_sources/aws_sourced.rs @@ -17,12 +17,11 @@ use crate::{ credentials::subject_token::{ Builder as SubjectTokenBuilder, SubjectToken, SubjectTokenProvider, }, - errors, + io::{HttpRequest, SharedEnvProvider, SharedHttpClientProvider}, }; use chrono::Utc; use google_cloud_gax::error::CredentialsError; use hmac::{Hmac, Mac}; -use reqwest::{Client, Response}; use serde::{Deserialize, Serialize}; use sha2::{Digest, Sha256}; use std::collections::BTreeMap; @@ -68,6 +67,10 @@ pub(crate) struct AwsSourcedCredentials { pub imdsv2_session_token_url: Option, /// The audience for the x-goog-cloud-target-resource header. pub audience: String, + /// The environment variable provider. + pub env: SharedEnvProvider, + /// The HTTP client provider. + pub http: SharedHttpClientProvider, } impl AwsSourcedCredentials { @@ -77,6 +80,8 @@ impl AwsSourcedCredentials { regional_cred_verification_url: Option, imdsv2_session_token_url: Option, audience: String, + env: SharedEnvProvider, + http: SharedHttpClientProvider, ) -> Self { Self { region_url, @@ -84,6 +89,8 @@ impl AwsSourcedCredentials { regional_cred_verification_url, imdsv2_session_token_url, audience, + env, + http, } } } @@ -116,16 +123,10 @@ impl SubjectTokenProvider for AwsSourcedCredentials { type Error = CredentialsError; async fn subject_token(&self) -> Result { - let client = Client::new(); + let imdsv2_token = self.resolve_imdsv2_token().await?; - let imdsv2_token = self.resolve_imdsv2_token(&client).await?; - - let region = self - .resolve_region(&client, imdsv2_token.as_deref()) - .await?; - let creds = self - .resolve_credentials(&client, imdsv2_token.as_deref()) - .await?; + let region = self.resolve_region(imdsv2_token.as_deref()).await?; + let creds = self.resolve_credentials(imdsv2_token.as_deref()).await?; let now = Utc::now(); let amz_date = now.format("%Y%m%dT%H%M%SZ").to_string(); @@ -287,25 +288,27 @@ fn parse_region_from_zone(zone: &str) -> Option<&str> { } impl AwsSourcedCredentials { - async fn resolve_imdsv2_token(&self, client: &Client) -> Result> { + async fn resolve_imdsv2_token(&self) -> Result> { if let Some(url) = &self.imdsv2_session_token_url { - let response = client - .put(url) - .header(IMDSV2_TOKEN_TTL_HEADER, IMDSV2_DEFAULT_TOKEN_TTL_SECONDS) - .send() + let request = HttpRequest::put(url) + .header(IMDSV2_TOKEN_TTL_HEADER, IMDSV2_DEFAULT_TOKEN_TTL_SECONDS); + + let response = self + .http + .execute(request) .await - .map_err(|e| errors::from_http_error(e, MSG))?; + .map_err(|e| crate::errors::from_http_error(e, MSG))?; - if !response.status().is_success() { - return Err( - errors::from_http_response(response, "failed to resolve IMDSv2 token").await, - ); + if !response.is_success() { + return Err(crate::errors::from_http_response( + &response, + "failed to resolve IMDSv2 token", + )); } let token = response .text() - .await - .map_err(|e| errors::from_http_error(e, "failed to read IMDSv2 token body"))?; + .map_err(|e| CredentialsError::from_source(false, e))?; return Ok(Some(token)); } @@ -314,44 +317,42 @@ impl AwsSourcedCredentials { async fn get_with_imdsv2_token( &self, - client: &Client, url: &str, imdsv2_token: Option<&str>, error_msg: &str, - ) -> Result { - let request = client.get(url); - let request = if let Some(token) = imdsv2_token { - request.header(IMDSV2_TOKEN_HEADER, token) - } else { - request - }; - let response = request - .send() + ) -> Result> { + let mut request = HttpRequest::get(url); + if let Some(token) = imdsv2_token { + request = request.header(IMDSV2_TOKEN_HEADER, token); + } + + let response = self + .http + .execute(request) .await - .map_err(|e| errors::from_http_error(e, MSG))?; - if !response.status().is_success() { - return Err(errors::from_http_response(response, error_msg).await); + .map_err(|e| crate::errors::from_http_error(e, MSG))?; + + if !response.is_success() { + return Err(crate::errors::from_http_response(&response, error_msg)); } - Ok(response) + Ok(response.body) } - async fn resolve_region(&self, client: &Client, imdsv2_token: Option<&str>) -> Result { - if let Ok(region) = std::env::var(AWS_REGION) { + async fn resolve_region(&self, imdsv2_token: Option<&str>) -> Result { + if let Some(region) = self.env.var(AWS_REGION) { return Ok(region); } - if let Ok(region) = std::env::var(AWS_DEFAULT_REGION) { + if let Some(region) = self.env.var(AWS_DEFAULT_REGION) { return Ok(region); } if let Some(url) = &self.region_url { - let response = self - .get_with_imdsv2_token(client, url, imdsv2_token, "could not resolve AWS region") + let body = self + .get_with_imdsv2_token(url, imdsv2_token, "could not resolve AWS region") .await?; - let zone = response - .text() - .await - .map_err(|e| errors::from_http_error(e, "failed to read AWS region body"))?; + let zone = + String::from_utf8(body).map_err(|e| CredentialsError::from_source(false, e))?; if let Some(region) = parse_region_from_zone(&zone) { return Ok(region.to_string()); } @@ -362,20 +363,14 @@ impl AwsSourcedCredentials { )) } - async fn resolve_role_name( - &self, - client: &Client, - imdsv2_token: Option<&str>, - ) -> Result { + async fn resolve_role_name(&self, imdsv2_token: Option<&str>) -> Result { if let Some(url) = &self.role_url { - let response = self - .get_with_imdsv2_token(client, url, imdsv2_token, "could not resolve AWS role name") + let body = self + .get_with_imdsv2_token(url, imdsv2_token, "could not resolve AWS role name") .await?; - let role_name = response - .text() - .await - .map_err(|e| errors::from_http_error(e, "failed to read AWS role name body"))?; + let role_name = + String::from_utf8(body).map_err(|e| CredentialsError::from_source(false, e))?; return Ok(role_name.trim().to_string()); } @@ -387,25 +382,17 @@ impl AwsSourcedCredentials { async fn resolve_role_credentials( &self, - client: &Client, role_name: &str, imdsv2_token: Option<&str>, ) -> Result { if let Some(url) = &self.role_url { let role_url = format!("{}/{}", url.trim_end_matches('/'), role_name.trim()); - let response = self - .get_with_imdsv2_token( - client, - &role_url, - imdsv2_token, - "could not resolve AWS credentials", - ) + let body = self + .get_with_imdsv2_token(&role_url, imdsv2_token, "could not resolve AWS credentials") .await?; - let creds = response - .json() - .await - .map_err(|e| errors::from_http_error(e, "failed to parse AWS credentials JSON"))?; + let creds: AwsSecurityCredentials = serde_json::from_slice(&body) + .map_err(|e| CredentialsError::from_source(e.is_io(), e))?; return Ok(creds); } Err(CredentialsError::from_msg( @@ -416,23 +403,22 @@ impl AwsSourcedCredentials { async fn resolve_credentials( &self, - client: &Client, imdsv2_token: Option<&str>, ) -> Result { - if let (Ok(ak), Ok(sk)) = ( - std::env::var(AWS_ACCESS_KEY_ID), - std::env::var(AWS_SECRET_ACCESS_KEY), + if let (Some(ak), Some(sk)) = ( + self.env.var(AWS_ACCESS_KEY_ID), + self.env.var(AWS_SECRET_ACCESS_KEY), ) { return Ok(AwsSecurityCredentials { access_key_id: ak, secret_access_key: sk, - token: std::env::var(AWS_SESSION_TOKEN).ok(), + token: self.env.var(AWS_SESSION_TOKEN), }); } - let role_name = self.resolve_role_name(client, imdsv2_token).await?; + let role_name = self.resolve_role_name(imdsv2_token).await?; let role_credentials = self - .resolve_role_credentials(client, &role_name, imdsv2_token) + .resolve_role_credentials(&role_name, imdsv2_token) .await?; Ok(role_credentials) @@ -488,13 +474,10 @@ mod tests { Some("sts.{region}.amazonaws.com".into()), None, "aud".into(), + SharedEnvProvider::default(), + SharedHttpClientProvider::default(), ); - let client = Client::new(); - assert_eq!( - creds.resolve_region(&client, None).await?, - "us-west-2", - "{creds:?}" - ); + assert_eq!(creds.resolve_region(None).await?, "us-west-2", "{creds:?}"); Ok(()) } @@ -513,13 +496,10 @@ mod tests { Some("sts.{region}.amazonaws.com".into()), None, "aud".into(), + SharedEnvProvider::default(), + SharedHttpClientProvider::default(), ); - let client = Client::new(); - assert_eq!( - creds.resolve_region(&client, None).await?, - "us-east-1", - "{creds:?}" - ); + assert_eq!(creds.resolve_region(None).await?, "us-east-1", "{creds:?}"); Ok(()) } @@ -535,9 +515,10 @@ mod tests { Some("sts.{region}.amazonaws.com".into()), None, "aud".into(), + SharedEnvProvider::default(), + SharedHttpClientProvider::default(), ); - let client = Client::new(); - let resolved = creds.resolve_credentials(&client, None).await?; + let resolved = creds.resolve_credentials(None).await?; assert_eq!(resolved.access_key_id, "ACCESS_KEY_ID", "{resolved:?}"); assert_eq!(resolved.secret_access_key, "SECRET", "{resolved:?}"); Ok(()) @@ -570,9 +551,10 @@ mod tests { Some("sts.{region}.amazonaws.com".into()), None, "aud".into(), + SharedEnvProvider::default(), + SharedHttpClientProvider::default(), ); - let client = Client::new(); - let resolved = creds.resolve_credentials(&client, None).await?; + let resolved = creds.resolve_credentials(None).await?; assert_eq!(resolved.access_key_id, "ACCESS_KEY_ID_IMDS", "{resolved:?}"); assert_eq!(resolved.secret_access_key, "SECRET_IMDS", "{resolved:?}"); assert_eq!( @@ -605,9 +587,10 @@ mod tests { Some("sts.{region}.amazonaws.com".into()), Some(server.url("/token").to_string()), "aud".into(), + SharedEnvProvider::default(), + SharedHttpClientProvider::default(), ); - let client = Client::new(); - let token = creds.resolve_imdsv2_token(&client).await?; + let token = creds.resolve_imdsv2_token().await?; assert_eq!(token, Some("test-token".to_string()), "{token:?}"); Ok(()) } @@ -666,6 +649,8 @@ mod tests { Some("sts.{region}.amazonaws.com".into()), Some(server.url("/token").to_string()), "another_audience".into(), + SharedEnvProvider::default(), + SharedHttpClientProvider::default(), ); let subject_token = creds.subject_token().await?; @@ -729,6 +714,8 @@ mod tests { Some("sts.{region}.amazonaws.com".into()), None, "some_audience".into(), + SharedEnvProvider::default(), + SharedHttpClientProvider::default(), ); let subject_token = creds.subject_token().await?; diff --git a/src/auth/src/credentials/external_account_sources/executable_sourced.rs b/src/auth/src/credentials/external_account_sources/executable_sourced.rs index 4aef4250c7..f117f136be 100644 --- a/src/auth/src/credentials/external_account_sources/executable_sourced.rs +++ b/src/auth/src/credentials/external_account_sources/executable_sourced.rs @@ -19,6 +19,7 @@ use crate::{ credentials::subject_token::{ Builder as SubjectTokenBuilder, SubjectToken, SubjectTokenProvider, }, + io::{SharedEnvProvider, SharedFsProvider}, }; use google_cloud_gax::error::CredentialsError; use serde::{Deserialize, Serialize}; @@ -34,6 +35,10 @@ pub(crate) struct ExecutableSourcedCredentials { pub args: Vec, pub timeout: Duration, pub output_file: Option, + #[serde(skip)] + pub env: SharedEnvProvider, + #[serde(skip)] + pub fs: SharedFsProvider, } /// Executable command should adere to this format. @@ -101,13 +106,14 @@ impl SubjectTokenProvider for ExecutableSourcedCredentials { async fn subject_token(&self) -> Result { if let Some(output_file) = self.output_file.clone() { - let token = Self::from_output_file(output_file).await; + let token = self.read_output_file(output_file).await; if let Ok(token) = token { return Ok(SubjectTokenBuilder::new(token).build()); } } - let token = - Self::from_command(self.command.clone(), self.args.clone(), self.timeout).await?; + let token = self + .run_command(self.command.clone(), self.args.clone(), self.timeout) + .await?; if token.is_empty() { let msg = format!("{MSG}, subject token is empty"); return Err(CredentialsError::from_msg(false, msg)); @@ -118,7 +124,11 @@ impl SubjectTokenProvider for ExecutableSourcedCredentials { } impl ExecutableSourcedCredentials { - pub(crate) fn new(executable: ExecutableConfig) -> Self { + pub(crate) fn new( + executable: ExecutableConfig, + env: SharedEnvProvider, + fs: SharedFsProvider, + ) -> Self { let (command, args) = Self::split_command(executable.command); let timeout = match executable.timeout_millis { Some(timeout) => Duration::from_millis(timeout.into()), @@ -130,11 +140,15 @@ impl ExecutableSourcedCredentials { args, timeout, output_file, + env, + fs, } } - async fn from_output_file(output_file: String) -> Result { - let content = std::fs::read_to_string(output_file) + async fn read_output_file(&self, output_file: String) -> Result { + let content = self + .fs + .read_to_string(&output_file) .map_err(|e| CredentialsError::from_source(false, e))?; Self::parse_token(content) @@ -142,10 +156,16 @@ impl ExecutableSourcedCredentials { /// See details on security reason on [executable sourced credentials]. /// [executable sourced credentials]: https://google.aip.dev/auth/4117#determining-the-subject-token-in-executable-sourced-credentials - async fn from_command(command: String, args: Vec, timeout: Duration) -> Result { + async fn run_command( + &self, + command: String, + args: Vec, + timeout: Duration, + ) -> Result { // For security reasons, we need our consumers to set this environment variable to allow executables to be run. - let allow_executable = std::env::var(ALLOW_EXECUTABLE_ENV) - .ok() + let allow_executable = self + .env + .var(ALLOW_EXECUTABLE_ENV) .unwrap_or("0".to_string()); if allow_executable != "1" { return Err(CredentialsError::from_msg( @@ -283,7 +303,11 @@ mod tests { command: format!("cat {}", path.to_str().unwrap()), ..ExecutableConfig::default() }; - let token_provider = ExecutableSourcedCredentials::new(config); + let token_provider = ExecutableSourcedCredentials::new( + config, + SharedEnvProvider::default(), + SharedFsProvider::default(), + ); let resp = token_provider.subject_token().await?; assert_eq!(resp.token, "an_example_token".to_string()); @@ -311,7 +335,11 @@ mod tests { command: format!("cat {path}"), ..ExecutableConfig::default() }; - let token_provider = ExecutableSourcedCredentials::new(config); + let token_provider = ExecutableSourcedCredentials::new( + config, + SharedEnvProvider::default(), + SharedFsProvider::default(), + ); let err = token_provider .subject_token() .await @@ -345,7 +373,11 @@ mod tests { command: "do nothing".to_string(), ..ExecutableConfig::default() }; - let token_provider = ExecutableSourcedCredentials::new(config); + let token_provider = ExecutableSourcedCredentials::new( + config, + SharedEnvProvider::default(), + SharedFsProvider::default(), + ); let resp = token_provider.subject_token().await?; assert_eq!(resp.token, "an_example_token"); @@ -391,7 +423,11 @@ mod tests { command: format!("cat {}", valid_path.to_str().unwrap()), ..ExecutableConfig::default() }; - let token_provider = ExecutableSourcedCredentials::new(config); + let token_provider = ExecutableSourcedCredentials::new( + config, + SharedEnvProvider::default(), + SharedFsProvider::default(), + ); let resp = token_provider.subject_token().await?; assert_eq!(resp.token, "a_valid_token"); @@ -488,7 +524,11 @@ done"; timeout_millis: Some(1000), ..ExecutableConfig::default() }; - let token_provider = ExecutableSourcedCredentials::new(config); + let token_provider = ExecutableSourcedCredentials::new( + config, + SharedEnvProvider::default(), + SharedFsProvider::default(), + ); let err = token_provider .subject_token() .await diff --git a/src/auth/src/credentials/external_account_sources/file_sourced.rs b/src/auth/src/credentials/external_account_sources/file_sourced.rs index 531a5677d4..a613dcf067 100644 --- a/src/auth/src/credentials/external_account_sources/file_sourced.rs +++ b/src/auth/src/credentials/external_account_sources/file_sourced.rs @@ -22,6 +22,7 @@ use crate::{ credentials::subject_token::{ Builder as SubjectTokenBuilder, SubjectToken, SubjectTokenProvider, }, + io::SharedFsProvider, }; #[derive(Serialize, Deserialize, Debug, Clone)] @@ -29,10 +30,16 @@ pub(crate) struct FileSourcedCredentials { pub file: String, pub format: String, pub subject_token_field_name: String, + #[serde(skip)] + pub fs: SharedFsProvider, } impl FileSourcedCredentials { - pub(crate) fn new(file: String, format_source: Option) -> Self { + pub(crate) fn new( + file: String, + format_source: Option, + fs: SharedFsProvider, + ) -> Self { let (format, subject_token_field_name) = format_source .map(|f| { ( @@ -45,6 +52,7 @@ impl FileSourcedCredentials { file, format, subject_token_field_name, + fs, } } } @@ -54,7 +62,9 @@ const JSON_FORMAT_TYPE: &str = "json"; impl SubjectTokenProvider for FileSourcedCredentials { type Error = CredentialsError; async fn subject_token(&self) -> Result { - let content = std::fs::read_to_string(&self.file) + let content = self + .fs + .read_to_string(&self.file) .map_err(|e| CredentialsError::from_source(false, e))?; match self.format.as_str() { @@ -102,6 +112,7 @@ mod tests { file: file.path().to_str().unwrap().to_string(), format: "text".into(), subject_token_field_name: "".into(), + fs: SharedFsProvider::default(), }; let resp = token_provider.subject_token().await?; assert_eq!(resp.token, "an_example_token".to_string()); @@ -119,6 +130,7 @@ mod tests { file: file.path().to_str().unwrap().to_string(), format: "json".into(), subject_token_field_name: "access_token".into(), + fs: SharedFsProvider::default(), }; let resp = token_provider.subject_token().await?; assert_eq!(resp.token, "an_example_token".to_string()); @@ -136,6 +148,7 @@ mod tests { file: file.path().to_str().unwrap().to_string(), format: "json".into(), subject_token_field_name: "access_token".into(), + fs: SharedFsProvider::default(), }; let err = token_provider .subject_token() @@ -158,6 +171,7 @@ mod tests { file: "/path/to/non/existent/file".to_string(), format: "text".into(), subject_token_field_name: "".into(), + fs: SharedFsProvider::default(), }; let err = token_provider .subject_token() @@ -175,6 +189,7 @@ mod tests { file: file.path().to_str().unwrap().to_string(), format: "text".into(), subject_token_field_name: "".into(), + fs: SharedFsProvider::default(), }; let resp = token_provider.subject_token().await?; assert_eq!(resp.token, "".to_string()); @@ -188,6 +203,7 @@ mod tests { file: file.path().to_str().unwrap().to_string(), format: "json".into(), subject_token_field_name: "access_token".into(), + fs: SharedFsProvider::default(), }; let err = token_provider .subject_token() diff --git a/src/auth/src/credentials/external_account_sources/url_sourced.rs b/src/auth/src/credentials/external_account_sources/url_sourced.rs index 3ad65f04e2..4ff9ffd44d 100644 --- a/src/auth/src/credentials/external_account_sources/url_sourced.rs +++ b/src/auth/src/credentials/external_account_sources/url_sourced.rs @@ -13,10 +13,9 @@ // limitations under the License. use google_cloud_gax::error::CredentialsError; -use reqwest::Client; use serde::{Deserialize, Serialize}; use serde_json::Value; -use std::{collections::HashMap, time::Duration}; +use std::collections::HashMap; use crate::{ Result, @@ -24,7 +23,7 @@ use crate::{ credentials::subject_token::{ Builder as SubjectTokenBuilder, SubjectToken, SubjectTokenProvider, }, - errors, + io::{HttpRequest, SharedHttpClientProvider}, }; #[derive(Serialize, Deserialize, Debug, Clone)] @@ -33,6 +32,8 @@ pub(crate) struct UrlSourcedCredentials { pub headers: HashMap, pub format: String, pub subject_token_field_name: String, + #[serde(skip)] + pub http: SharedHttpClientProvider, } impl UrlSourcedCredentials { @@ -40,6 +41,7 @@ impl UrlSourcedCredentials { url: String, headers: Option>, format_source: Option, + http: SharedHttpClientProvider, ) -> Self { let (format, subject_token_field_name) = format_source .map(|f| { @@ -54,6 +56,7 @@ impl UrlSourcedCredentials { headers: headers.unwrap_or_default(), format, subject_token_field_name, + http, } } } @@ -64,31 +67,24 @@ const JSON_FORMAT_TYPE: &str = "json"; impl SubjectTokenProvider for UrlSourcedCredentials { type Error = CredentialsError; async fn subject_token(&self) -> Result { - let client = Client::builder() - .timeout(Duration::from_secs(10)) - .build() - .unwrap(); - - let request = client.get(self.url.clone()); - let request = self - .headers - .iter() - .fold(request, |r, (k, v)| r.header(k.as_str(), v.as_str())); - - let response = request - .send() + let mut request = HttpRequest::get(&self.url); + for (k, v) in &self.headers { + request = request.header(k, v); + } + + let response = self + .http + .execute(request) .await - .map_err(|e| errors::from_http_error(e, MSG))?; + .map_err(|e| crate::errors::from_http_error(e, MSG))?; - if !response.status().is_success() { - let err = errors::from_http_response(response, MSG).await; - return Err(err); + if !response.is_success() { + return Err(crate::errors::from_http_response(&response, MSG)); } - let response_text = response.text().await.map_err(|e| { - let retryable = !e.is_body(); - CredentialsError::from_source(retryable, e) - })?; + let response_text = response + .text() + .map_err(|e| CredentialsError::from_source(false, e))?; match self.format.as_str() { JSON_FORMAT_TYPE => { @@ -144,6 +140,7 @@ mod tests { format: "json".into(), subject_token_field_name: "access_token".into(), headers: HashMap::from([("Metadata".to_string(), "True".to_string())]), + http: SharedHttpClientProvider::default(), }; let resp = token_provider.subject_token().await?; @@ -168,6 +165,7 @@ mod tests { format: "text".into(), subject_token_field_name: "".into(), headers: HashMap::new(), + http: SharedHttpClientProvider::default(), }; let resp = token_provider.subject_token().await?; @@ -198,6 +196,7 @@ mod tests { format: "json".into(), subject_token_field_name: "access_token".into(), headers: HashMap::from([("Metadata".to_string(), "True".to_string())]), + http: SharedHttpClientProvider::default(), }; let err = token_provider diff --git a/src/auth/src/credentials/idtoken.rs b/src/auth/src/credentials/idtoken.rs index 12cd271f99..3e0fed74d1 100644 --- a/src/auth/src/credentials/idtoken.rs +++ b/src/auth/src/credentials/idtoken.rs @@ -76,6 +76,10 @@ use crate::build_errors::Error as BuilderError; use crate::credentials::{AdcContents, CredentialsError, extract_credential_type, load_adc}; +use crate::io::{ + EnvProvider, FsProvider, HttpClientProvider, IoConfig, SharedEnvProvider, SharedFsProvider, + SharedHttpClientProvider, +}; use crate::token::Token; use crate::{BuildResult, Result}; use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD}; @@ -181,6 +185,7 @@ pub(crate) mod dynamic { pub struct Builder { target_audience: String, include_email: bool, + providers: IoConfig, } impl Builder { @@ -195,6 +200,7 @@ impl Builder { Self { target_audience: target_audience.into(), include_email: false, + providers: IoConfig::default(), } } @@ -208,6 +214,39 @@ impl Builder { self } + /// Sets a custom environment variable provider. + /// + /// When set, the auth crate will use this provider for all environment + /// variable lookups during ID token credential construction instead of + /// reading from the process environment directly. + pub fn with_env_provider(mut self, provider: impl EnvProvider + 'static) -> Self { + self.providers.env = SharedEnvProvider::new(provider); + self + } + + /// Sets a custom filesystem provider. + /// + /// When set, the auth crate will use this provider for all file read + /// operations during ID token credential construction instead of reading + /// from the real filesystem directly. + pub fn with_fs_provider(mut self, provider: impl FsProvider + 'static) -> Self { + self.providers.fs = SharedFsProvider::new(provider); + self + } + + /// Sets a custom HTTP client provider. + /// + /// When set, the auth crate will use this provider for all HTTP + /// requests during ID token credential construction and token retrieval + /// instead of using `reqwest::Client` directly. + pub fn with_http_client_provider( + mut self, + provider: impl HttpClientProvider + 'static, + ) -> Self { + self.providers.http = SharedHttpClientProvider::new(provider); + self + } + /// Returns a [IDTokenCredentials] instance with the configured settings. /// /// # Errors @@ -219,14 +258,20 @@ impl Builder { /// /// [application-default credentials]: https://cloud.google.com/docs/authentication/application-default-credentials pub fn build(self) -> BuildResult { - let json_data = match load_adc()? { + let providers = self.providers; + let json_data = match load_adc(&providers.env, &providers.fs)? { AdcContents::Contents(contents) => { Some(serde_json::from_str(&contents).map_err(BuilderError::parsing)?) } AdcContents::FallbackToMds => None, }; - build_id_token_credentials(self.target_audience, self.include_email, json_data) + build_id_token_credentials( + self.target_audience, + self.include_email, + json_data, + providers, + ) } } enum IDTokenBuilder { @@ -239,8 +284,9 @@ fn build_id_token_credentials( audience: String, include_email: bool, json: Option, + providers: IoConfig, ) -> BuildResult { - let builder = build_id_token_credentials_internal(audience, include_email, json)?; + let builder = build_id_token_credentials_internal(audience, include_email, json, &providers)?; match builder { IDTokenBuilder::Mds(builder) => builder.build(), IDTokenBuilder::ServiceAccount(builder) => builder.build(), @@ -252,6 +298,7 @@ fn build_id_token_credentials_internal( audience: String, include_email: bool, json: Option, + providers: &IoConfig, ) -> BuildResult { match json { None => { @@ -262,7 +309,11 @@ fn build_id_token_credentials_internal( mds::Format::Standard }; Ok(IDTokenBuilder::Mds( - mds::Builder::new(audience).with_format(format), + mds::Builder::new(audience) + .with_format(format) + .with_env_provider(SharedEnvProvider::clone(&providers.env)) + .with_fs_provider(SharedFsProvider::clone(&providers.fs)) + .with_http_client_provider(SharedHttpClientProvider::clone(&providers.http)), )) } Some(json) => { @@ -275,7 +326,10 @@ fn build_id_token_credentials_internal( service_account::Builder::new(audience, json), )), "impersonated_service_account" => { - let builder = impersonated::Builder::new(audience, json); + let builder = impersonated::Builder::new(audience, json) + .with_http_client_provider(SharedHttpClientProvider::clone( + &providers.http, + )); let builder = if include_email { builder.with_include_email() } else { @@ -465,7 +519,7 @@ pub(crate) mod tests { "refresh_token": "test_refresh_token", }); - let result = build_id_token_credentials(audience, false, Some(json)); + let result = build_id_token_credentials(audience, false, Some(json), IoConfig::default()); assert!(result.is_err(), "{result:?}"); let err = result.unwrap_err(); assert!(err.is_not_supported()); @@ -493,7 +547,7 @@ pub(crate) mod tests { } }); - let result = build_id_token_credentials(audience, false, Some(json)); + let result = build_id_token_credentials(audience, false, Some(json), IoConfig::default()); assert!(result.is_err(), "{result:?}"); let err = result.unwrap_err(); assert!(err.is_not_supported()); @@ -509,7 +563,7 @@ pub(crate) mod tests { "type": "unknown_credential_type", }); - let result = build_id_token_credentials(audience, false, Some(json)); + let result = build_id_token_credentials(audience, false, Some(json), IoConfig::default()); assert!(result.is_err(), "{result:?}"); let err = result.unwrap_err(); assert!(err.is_unknown_type()); @@ -523,14 +577,24 @@ pub(crate) mod tests { let audience = "test_audience".to_string(); // Test with include_email = true and no source credentials (MDS Fallback) - let creds = build_id_token_credentials_internal(audience.clone(), true, None)?; + let creds = build_id_token_credentials_internal( + audience.clone(), + true, + None, + &IoConfig::default(), + )?; assert!(matches!(creds, IDTokenBuilder::Mds(_))); if let IDTokenBuilder::Mds(builder) = creds { assert!(matches!(builder.format, Some(Format::Full))); } // Test with include_email = false and no source credentials (MDS Fallback) - let creds = build_id_token_credentials_internal(audience.clone(), false, None)?; + let creds = build_id_token_credentials_internal( + audience.clone(), + false, + None, + &IoConfig::default(), + )?; assert!(matches!(creds, IDTokenBuilder::Mds(_))); if let IDTokenBuilder::Mds(builder) = creds { assert!(matches!(builder.format, Some(Format::Standard))); @@ -561,15 +625,24 @@ pub(crate) mod tests { }); // Test with include_email = true and impersonated source credentials - let creds = - build_id_token_credentials_internal(audience.clone(), true, Some(json.clone()))?; + let creds = build_id_token_credentials_internal( + audience.clone(), + true, + Some(json.clone()), + &IoConfig::default(), + )?; assert!(matches!(creds, IDTokenBuilder::Impersonated(_))); if let IDTokenBuilder::Impersonated(builder) = creds { assert_eq!(builder.include_email, Some(true)); } // Test with include_email = false and impersonated source credentials - let creds = build_id_token_credentials_internal(audience.clone(), false, Some(json))?; + let creds = build_id_token_credentials_internal( + audience.clone(), + false, + Some(json), + &IoConfig::default(), + )?; assert!(matches!(creds, IDTokenBuilder::Impersonated(_))); if let IDTokenBuilder::Impersonated(builder) = creds { assert_eq!(builder.include_email, None); diff --git a/src/auth/src/credentials/idtoken/impersonated.rs b/src/auth/src/credentials/idtoken/impersonated.rs index 49e408111b..6b24b968d6 100644 --- a/src/auth/src/credentials/idtoken/impersonated.rs +++ b/src/auth/src/credentials/idtoken/impersonated.rs @@ -75,8 +75,8 @@ use crate::{ build_components_from_json, }, }, - errors, headers_util::{self, ID_TOKEN_REQUEST_TYPE, metrics_header_value}, + io::{HttpClientProvider, HttpRequest, IoConfig, SharedHttpClientProvider}, retry::Builder as RetryTokenProviderBuilder, token::{CachedTokenProvider, Token, TokenProvider}, token_cache::TokenCache, @@ -87,7 +87,6 @@ use google_cloud_gax::error::CredentialsError; use google_cloud_gax::retry_policy::RetryPolicyArg; use google_cloud_gax::retry_throttler::RetryThrottlerArg; use http::{Extensions, HeaderMap}; -use reqwest::Client; use serde_json::Value; use std::sync::Arc; @@ -119,6 +118,7 @@ pub struct Builder { target_audience: String, service_account_impersonation_url: Option, retry_builder: RetryTokenProviderBuilder, + http: SharedHttpClientProvider, } impl Builder { @@ -137,6 +137,7 @@ impl Builder { target_audience: target_audience.into(), service_account_impersonation_url: None, retry_builder: RetryTokenProviderBuilder::default(), + http: SharedHttpClientProvider::default(), } } @@ -173,6 +174,7 @@ impl Builder { target_principal.into() )), retry_builder: RetryTokenProviderBuilder::default(), + http: SharedHttpClientProvider::default(), } } @@ -306,6 +308,19 @@ impl Builder { self } + /// Sets a custom HTTP client provider. + /// + /// When set, the auth crate will use this provider for all HTTP + /// requests during ID token retrieval instead of using + /// `reqwest::Client` directly. + pub fn with_http_client_provider( + mut self, + provider: impl HttpClientProvider + 'static, + ) -> Self { + self.http = SharedHttpClientProvider::new(provider); + self + } + /// Returns a [Credentials] instance with the configured settings. /// /// # Errors @@ -322,7 +337,7 @@ impl Builder { pub fn build(self) -> BuildResult { let components = match self.source { BuilderSource::FromJson(json) => { - let mut components = build_components_from_json(json)?; + let mut components = build_components_from_json(json, IoConfig::default())?; components.service_account_impersonation_url = components .service_account_impersonation_url .replace("generateAccessToken", "generateIdToken"); @@ -341,6 +356,7 @@ impl Builder { delegates: self.delegates.or(components.delegates), include_email: self.include_email, target_audience: self.target_audience, + http: self.http, }; let token_provider = self.retry_builder.build(token_provider); Ok(IDTokenCredentials { @@ -382,6 +398,7 @@ pub(crate) struct ImpersonatedTokenProvider { pub(crate) delegates: Option>, pub(crate) target_audience: String, pub(crate) include_email: Option, + pub(crate) http: SharedHttpClientProvider, } #[derive(serde::Serialize, serde::Deserialize, Debug, PartialEq)] @@ -399,40 +416,35 @@ async fn generate_id_token( audience: String, include_email: Option, service_account_impersonation_url: &str, + http: &SharedHttpClientProvider, ) -> Result { - let client = Client::new(); - let body = GenerateIdTokenRequest { audience, delegates, include_email, }; + let body = serde_json::to_vec(&body).map_err(|e| CredentialsError::from_source(false, e))?; - let response = client - .post(service_account_impersonation_url) - .header("Content-Type", "application/json") + let request = HttpRequest::post(service_account_impersonation_url) + .json(body) .header( headers_util::X_GOOG_API_CLIENT, metrics_header_value(ID_TOKEN_REQUEST_TYPE, IMPERSONATED_CREDENTIAL_TYPE), ) - .headers(source_headers) - .json(&body) - .send() + .headers_from_map(&source_headers); + + let response = http + .execute(request) .await - .map_err(|e| errors::from_http_error(e, MSG))?; + .map_err(|e| crate::errors::from_http_error(e, MSG))?; - if !response.status().is_success() { - let err = errors::from_http_response(response, MSG).await; - return Err(err); + if !response.is_success() { + return Err(crate::errors::from_http_response(&response, MSG)); } - let token_response = response - .json::() - .await - .map_err(|e| { - let retryable = !e.is_decode(); - CredentialsError::from_source(retryable, e) - })?; + let token_response: GenerateIdTokenResponse = response + .json() + .map_err(|e| CredentialsError::from_source(e.is_io(), e))?; parse_id_token_from_str(token_response.token) } @@ -454,6 +466,7 @@ impl TokenProvider for ImpersonatedTokenProvider { self.target_audience.clone(), self.include_email, &self.service_account_impersonation_url, + &self.http, ) .await } diff --git a/src/auth/src/credentials/idtoken/mds.rs b/src/auth/src/credentials/idtoken/mds.rs index b06bc14a22..e7722fc061 100644 --- a/src/auth/src/credentials/idtoken/mds.rs +++ b/src/auth/src/credentials/idtoken/mds.rs @@ -66,6 +66,10 @@ use crate::Result; use crate::credentials::CacheableResource; use crate::errors::CredentialsError; +use crate::io::{ + EnvProvider, FsProvider, HttpClientProvider, IoConfig, SharedEnvProvider, SharedFsProvider, + SharedHttpClientProvider, +}; use crate::mds::client::Client as MDSClient; use crate::retry::{Builder as RetryTokenProviderBuilder, TokenProviderWithRetry}; use crate::token::{CachedTokenProvider, Token, TokenProvider}; @@ -135,6 +139,7 @@ pub struct Builder { licenses: Option, target_audience: String, retry_builder: RetryTokenProviderBuilder, + providers: IoConfig, } impl Builder { @@ -150,6 +155,7 @@ impl Builder { licenses: None, target_audience: target_audience.into(), retry_builder: RetryTokenProviderBuilder::default(), + providers: IoConfig::default(), } } @@ -291,8 +297,41 @@ impl Builder { self } + /// Sets a custom environment variable provider. + /// + /// When set, the auth crate will use this provider for all environment + /// variable lookups during ID token credential construction instead of + /// reading from the process environment directly. + pub fn with_env_provider(mut self, provider: impl EnvProvider + 'static) -> Self { + self.providers.env = SharedEnvProvider::new(provider); + self + } + + /// Sets a custom filesystem provider. + /// + /// When set, the auth crate will use this provider for all file read + /// operations during ID token credential construction instead of reading + /// from the real filesystem directly. + pub fn with_fs_provider(mut self, provider: impl FsProvider + 'static) -> Self { + self.providers.fs = SharedFsProvider::new(provider); + self + } + + /// Sets a custom HTTP client provider. + /// + /// When set, the auth crate will use this provider for all HTTP + /// requests during ID token credential construction and token retrieval + /// instead of using `reqwest::Client` directly. + pub fn with_http_client_provider( + mut self, + provider: impl HttpClientProvider + 'static, + ) -> Self { + self.providers.http = SharedHttpClientProvider::new(provider); + self + } + fn build_token_provider(self) -> TokenProviderWithRetry { - let client = MDSClient::new(self.endpoint); + let client = MDSClient::new(self.endpoint, self.providers.env, self.providers.http); let tp = MDSTokenProvider { format: self.format, licenses: self.licenses, @@ -340,15 +379,13 @@ mod tests { use super::*; use crate::credentials::idtoken::tests::generate_test_id_token; use crate::credentials::tests::{ - find_source_error, get_mock_auth_retry_policy, get_mock_backoff_policy, - get_mock_retry_throttler, + get_mock_auth_retry_policy, get_mock_backoff_policy, get_mock_retry_throttler, }; use crate::mds::{GCE_METADATA_HOST_ENV_VAR, MDS_DEFAULT_URI}; use httptest::cycle; use httptest::matchers::{all_of, contains, request, url_decoded}; use httptest::responders::status_code; use httptest::{Expectation, Server}; - use reqwest::StatusCode; use scoped_env::ScopedEnv; use serial_test::{parallel, serial}; use test_case::test_case; @@ -377,11 +414,8 @@ mod tests { .build()?; let err = creds.id_token().await.unwrap_err(); - let source = find_source_error::(&err); - assert!( - matches!(source, Some(e) if e.status() == Some(StatusCode::UNAUTHORIZED)), - "{err:?}" - ); + assert!(!err.is_transient(), "{err:?}"); + assert!(format!("{err:?}").contains("401"), "{err:?}"); Ok(()) } @@ -494,11 +528,8 @@ mod tests { .build()?; let err = creds.id_token().await.unwrap_err(); - let source = find_source_error::(&err); - assert!( - matches!(source, Some(e) if e.status() == Some(StatusCode::SERVICE_UNAVAILABLE)), - "{err:?}" - ); + assert!(err.is_transient(), "{err:?}"); + assert!(format!("{err:?}").contains("503"), "{err:?}"); Ok(()) } diff --git a/src/auth/src/credentials/idtoken/service_account.rs b/src/auth/src/credentials/idtoken/service_account.rs index e08976c659..cd60046960 100644 --- a/src/auth/src/credentials/idtoken/service_account.rs +++ b/src/auth/src/credentials/idtoken/service_account.rs @@ -68,13 +68,14 @@ use crate::credentials::CacheableResource; use crate::credentials::idtoken::dynamic::IDTokenCredentialsProvider; use crate::credentials::idtoken::parse_id_token_from_str; use crate::credentials::service_account::{ServiceAccountKey, ServiceAccountTokenGenerator}; +use crate::errors; +use crate::io::{HttpClientProvider, HttpRequest, SharedHttpClientProvider}; use crate::token::{CachedTokenProvider, Token, TokenProvider}; use crate::token_cache::TokenCache; use crate::{BuildResult, credentials::idtoken::IDTokenCredentials}; use async_trait::async_trait; use google_cloud_gax::error::CredentialsError; use http::Extensions; -use reqwest::Client; use serde_json::Value; use std::sync::Arc; @@ -108,6 +109,7 @@ struct ServiceAccountTokenProvider { audience: String, target_audience: String, token_server_url: String, + http: SharedHttpClientProvider, } #[derive(serde::Deserialize)] @@ -115,6 +117,8 @@ struct IdTokenResponse { id_token: String, } +const MSG: &str = "failed to exchange id token"; + #[async_trait] impl TokenProvider for ServiceAccountTokenProvider { async fn token(&self) -> Result { @@ -128,26 +132,27 @@ impl TokenProvider for ServiceAccountTokenProvider { ); let assertion = tg.generate()?; - let client = Client::new(); - let request = client.post(&self.token_server_url).form(&[ - ("grant_type", JWT_BEARER_GRANT_TYPE.to_string()), - ("assertion", assertion), - ]); + let body = url::form_urlencoded::Serializer::new(String::new()) + .append_pair("grant_type", JWT_BEARER_GRANT_TYPE) + .append_pair("assertion", &assertion) + .finish() + .into_bytes(); + + let request = HttpRequest::post(&self.token_server_url).form(body); - let response = request - .send() + let response = self + .http + .execute(request) .await - .map_err(|e| crate::errors::from_http_error(e, "failed to exchange id token"))?; + .map_err(|e| errors::from_http_error(e, MSG))?; - if !response.status().is_success() { - let err = crate::errors::from_http_response(response, "failed to fetch id token").await; - return Err(err); + if !response.is_success() { + return Err(errors::from_http_response(&response, MSG)); } let token_res: IdTokenResponse = response .json() - .await - .map_err(|e| CredentialsError::from_source(!e.is_decode(), e))?; + .map_err(|e| CredentialsError::from_source(e.is_io(), e))?; parse_id_token_from_str(token_res.id_token) } @@ -159,6 +164,7 @@ pub struct Builder { service_account_key: Value, target_audience: String, token_server_url: String, + http: SharedHttpClientProvider, } impl Builder { @@ -170,6 +176,7 @@ impl Builder { service_account_key, target_audience: target_audience.into(), token_server_url: OAUTH2_TOKEN_SERVER_URL.to_string(), + http: SharedHttpClientProvider::default(), } } @@ -179,6 +186,15 @@ impl Builder { self } + /// Sets a custom HTTP client provider. + pub fn with_http_client_provider( + mut self, + provider: impl HttpClientProvider + 'static, + ) -> Self { + self.http = SharedHttpClientProvider::new(provider); + self + } + fn build_token_provider( self, target_audience: String, @@ -191,6 +207,7 @@ impl Builder { audience: OAUTH2_TOKEN_SERVER_URL.to_string(), target_audience, token_server_url: self.token_server_url, + http: self.http, }) } diff --git a/src/auth/src/credentials/idtoken/user_account.rs b/src/auth/src/credentials/idtoken/user_account.rs index a68d0e0056..dd789d38ac 100644 --- a/src/auth/src/credentials/idtoken/user_account.rs +++ b/src/auth/src/credentials/idtoken/user_account.rs @@ -58,6 +58,7 @@ use crate::build_errors::Error as BuilderError; use crate::credentials::CacheableResource; use crate::credentials::user_account::UserTokenProvider; +use crate::io::{HttpClientProvider, SharedHttpClientProvider}; use crate::retry::Builder as RetryTokenProviderBuilder; use crate::token::CachedTokenProvider; use crate::token_cache::TokenCache; @@ -121,6 +122,7 @@ pub struct Builder { authorized_user: Value, token_uri: Option, retry_builder: RetryTokenProviderBuilder, + http: SharedHttpClientProvider, } impl Builder { @@ -136,6 +138,7 @@ impl Builder { authorized_user, token_uri: None, retry_builder: RetryTokenProviderBuilder::default(), + http: SharedHttpClientProvider::default(), } } @@ -234,6 +237,19 @@ impl Builder { self } + /// Sets a custom HTTP client provider. + /// + /// When set, the auth crate will use this provider for all HTTP + /// requests during ID token retrieval instead of using + /// `reqwest::Client` directly. + pub fn with_http_client_provider( + mut self, + provider: impl HttpClientProvider + 'static, + ) -> Self { + self.http = SharedHttpClientProvider::new(provider); + self + } + fn build_token_provider(&self) -> BuildResult { let authorized_user = serde_json::from_value::(self.authorized_user.clone()) @@ -241,6 +257,7 @@ impl Builder { Ok(UserTokenProvider::new_id_token_provider( authorized_user, self.token_uri.clone(), + SharedHttpClientProvider::clone(&self.http), )) } @@ -283,7 +300,7 @@ mod tests { use crate::credentials::user_account::{ Oauth2RefreshRequest, Oauth2RefreshResponse, RefreshGrantType, }; - use http::StatusCode; + use crate::errors::CredentialsError; use httptest::cycle; use httptest::matchers::{all_of, json_decoded, request}; use httptest::responders::{json_encoded, status_code}; @@ -396,11 +413,9 @@ mod tests { let err = creds.id_token().await.unwrap_err(); assert!(!err.is_transient()); - let source = find_source_error::(&err); - assert!( - matches!(source, Some(e) if e.status() == Some(StatusCode::UNAUTHORIZED)), - "{err:?}" - ); + let source = find_source_error::(&err); + assert!(source.is_some(), "{err:?}"); + assert!(!source.unwrap().is_transient(), "{err:?}"); Ok(()) } diff --git a/src/auth/src/credentials/impersonated.rs b/src/auth/src/credentials/impersonated.rs index 8b4cfe22db..5f127fca8f 100644 --- a/src/auth/src/credentials/impersonated.rs +++ b/src/auth/src/credentials/impersonated.rs @@ -101,6 +101,10 @@ use crate::errors::{self, CredentialsError}; use crate::headers_util::{ self, ACCESS_TOKEN_REQUEST_TYPE, AuthHeadersBuilder, metrics_header_value, }; +use crate::io::{ + EnvProvider, FsProvider, HttpClientProvider, HttpRequest, IoConfig, SharedEnvProvider, + SharedFsProvider, SharedHttpClientProvider, +}; use crate::retry::{Builder as RetryTokenProviderBuilder, TokenProviderWithRetry}; use crate::token::{CachedTokenProvider, Token, TokenProvider}; use crate::token_cache::TokenCache; @@ -110,7 +114,6 @@ use google_cloud_gax::backoff_policy::BackoffPolicyArg; use google_cloud_gax::retry_policy::RetryPolicyArg; use google_cloud_gax::retry_throttler::RetryThrottlerArg; use http::{Extensions, HeaderMap}; -use reqwest::Client; use serde_json::Value; use std::fmt::Debug; use std::sync::Arc; @@ -157,6 +160,7 @@ pub struct Builder { retry_builder: RetryTokenProviderBuilder, iam_endpoint_override: Option, is_access_boundary_enabled: bool, + providers: IoConfig, } impl Builder { @@ -178,6 +182,7 @@ impl Builder { retry_builder: RetryTokenProviderBuilder::default(), iam_endpoint_override: None, is_access_boundary_enabled: true, + providers: IoConfig::default(), } } @@ -208,6 +213,7 @@ impl Builder { retry_builder: RetryTokenProviderBuilder::default(), iam_endpoint_override: None, is_access_boundary_enabled: true, + providers: IoConfig::default(), } } @@ -410,6 +416,39 @@ impl Builder { self } + /// Sets a custom environment variable provider. + /// + /// When set, the auth crate will use this provider for all environment + /// variable lookups during impersonated credential construction instead + /// of reading from the process environment directly. + pub fn with_env_provider(mut self, provider: impl EnvProvider + 'static) -> Self { + self.providers.env = SharedEnvProvider::new(provider); + self + } + + /// Sets a custom filesystem provider. + /// + /// When set, the auth crate will use this provider for all file read + /// operations during impersonated credential construction instead of + /// reading from the real filesystem directly. + pub fn with_fs_provider(mut self, provider: impl FsProvider + 'static) -> Self { + self.providers.fs = SharedFsProvider::new(provider); + self + } + + /// Sets a custom HTTP client provider. + /// + /// When set, the auth crate will use this provider for all HTTP + /// requests during impersonated credential construction and token + /// retrieval instead of using `reqwest::Client` directly. + pub fn with_http_client_provider( + mut self, + provider: impl HttpClientProvider + 'static, + ) -> Self { + self.providers.http = SharedHttpClientProvider::new(provider); + self + } + /// Returns a [Credentials] instance with the configured settings. /// /// # Errors @@ -493,6 +532,7 @@ impl Builder { let service_account_impersonation_url = self.resolve_impersonation_url()?; let client_email = extract_client_email(&service_account_impersonation_url)?; let iam_endpoint_override = self.iam_endpoint_override.clone(); + let http = SharedHttpClientProvider::clone(&self.providers.http); let (token_provider, quota_project_id) = self.build_components()?; let access_boundary_url = crate::access_boundary::service_account_lookup_url( &client_email, @@ -510,6 +550,7 @@ impl Builder { Ok(CredentialsWithAccessBoundary::new( creds, Some(access_boundary_url), + http, )) } @@ -547,6 +588,7 @@ impl Builder { /// [IAM signBlob API]: https://cloud.google.com/iam/docs/reference/credentials/rest/v1/projects.serviceAccounts/signBlob pub fn build_signer(self) -> BuildResult { let iam_endpoint = self.iam_endpoint_override.clone(); + let http = SharedHttpClientProvider::clone(&self.providers.http); let source = self.source.clone(); if let BuilderSource::FromJson(json) = source { // try to build service account signer from json @@ -558,7 +600,7 @@ impl Builder { let service_account_impersonation_url = self.resolve_impersonation_url()?; let client_email = extract_client_email(&service_account_impersonation_url)?; let creds = self.build()?; - let signer = crate::signer::iam::IamSigner::new(client_email, creds, iam_endpoint); + let signer = crate::signer::iam::IamSigner::new(client_email, creds, iam_endpoint, http); Ok(crate::signer::Signer { inner: Arc::new(signer), }) @@ -570,8 +612,11 @@ impl Builder { TokenProviderWithRetry, Option, )> { + let providers = IoConfig::clone(&self.providers); let components = match self.source { - BuilderSource::FromJson(json) => build_components_from_json(json)?, + BuilderSource::FromJson(json) => { + build_components_from_json(json, IoConfig::clone(&providers))? + } BuilderSource::FromCredentials(source_credentials) => { build_components_from_credentials( source_credentials, @@ -594,6 +639,7 @@ impl Builder { delegates, scopes, lifetime: self.lifetime.unwrap_or(DEFAULT_LIFETIME), + http: providers.http, }; let token_provider = self.retry_builder.build(token_provider); Ok((token_provider, quota_project_id)) @@ -630,6 +676,7 @@ fn config_from_json(json: Value) -> BuildResult { pub(crate) fn build_components_from_json( json: Value, + providers: IoConfig, ) -> BuildResult { let config = config_from_json(json)?; @@ -646,7 +693,8 @@ pub(crate) fn build_components_from_json( // the quota project and they typically need different scopes. // If user does want some specific scopes or quota, they can build using the // from_source_credentials method. - let source_credentials = build_credentials(Some(config.source_credentials), None, None)?.into(); + let source_credentials = + build_credentials(Some(config.source_credentials), None, None, providers)?.into(); Ok(ImpersonatedCredentialComponents { source_credentials, @@ -761,6 +809,7 @@ struct ImpersonatedTokenProvider { delegates: Option>, scopes: Vec, lifetime: Duration, + http: SharedHttpClientProvider, } impl Debug for ImpersonatedTokenProvider { @@ -774,6 +823,7 @@ impl Debug for ImpersonatedTokenProvider { .field("delegates", &self.delegates) .field("scopes", &self.scopes) .field("lifetime", &self.lifetime) + .field("http", &self.http) .finish() } } @@ -792,39 +842,35 @@ pub(crate) async fn generate_access_token( scopes: Vec, lifetime: Duration, service_account_impersonation_url: &str, + http: &SharedHttpClientProvider, ) -> Result { - let client = Client::new(); let body = GenerateAccessTokenRequest { delegates, scope: scopes, lifetime: format!("{}s", lifetime.as_secs_f64()), }; + let body = serde_json::to_vec(&body).map_err(|e| CredentialsError::from_source(false, e))?; - let response = client - .post(service_account_impersonation_url) - .header("Content-Type", "application/json") + let request = HttpRequest::post(service_account_impersonation_url) + .json(body) .header( headers_util::X_GOOG_API_CLIENT, metrics_header_value(ACCESS_TOKEN_REQUEST_TYPE, IMPERSONATED_CREDENTIAL_TYPE), ) - .headers(source_headers) - .json(&body) - .send() + .headers_from_map(&source_headers); + + let response = http + .execute(request) .await .map_err(|e| errors::from_http_error(e, MSG))?; - if !response.status().is_success() { - let err = errors::from_http_response(response, MSG).await; - return Err(err); + if !response.is_success() { + return Err(errors::from_http_response(&response, MSG)); } - let token_response = response - .json::() - .await - .map_err(|e| { - let retryable = !e.is_decode(); - CredentialsError::from_source(retryable, e) - })?; + let token_response: GenerateAccessTokenResponse = response + .json() + .map_err(|e| CredentialsError::from_source(e.is_io(), e))?; let parsed_dt = OffsetDateTime::parse( &token_response.expire_time, @@ -860,6 +906,7 @@ impl TokenProvider for ImpersonatedTokenProvider { self.scopes.clone(), self.lifetime, &self.service_account_impersonation_url, + &self.http, ) .await } @@ -915,6 +962,7 @@ mod tests { let mut headers = HeaderMap::new(); headers.insert("authorization", "Bearer test-token".parse().unwrap()); + let http = SharedHttpClientProvider::default(); let token = generate_access_token( headers, None, @@ -923,6 +971,7 @@ mod tests { &server .url("/v1/projects/-/serviceAccounts/test-principal:generateAccessToken") .to_string(), + &http, ) .await?; @@ -947,6 +996,7 @@ mod tests { let mut headers = HeaderMap::new(); headers.insert("authorization", "Bearer test-token".parse().unwrap()); + let http = SharedHttpClientProvider::default(); let err = generate_access_token( headers, None, @@ -955,6 +1005,7 @@ mod tests { &server .url("/v1/projects/-/serviceAccounts/test-principal:generateAccessToken") .to_string(), + &http, ) .await .unwrap_err(); @@ -975,6 +1026,7 @@ mod tests { .respond_with(status_code(401)), ); + let http = SharedHttpClientProvider::default(); let err = generate_access_token( HeaderMap::new(), None, @@ -983,6 +1035,7 @@ mod tests { &server .url("/v1/projects/-/serviceAccounts/test-principal:generateAccessToken") .to_string(), + &http, ) .await .unwrap_err(); @@ -1285,6 +1338,7 @@ mod tests { delegates: Some(vec!["delegate1".to_string()]), scopes: vec!["scope1".to_string()], lifetime: Duration::from_secs(3600), + http: SharedHttpClientProvider::default(), }; let fmt = format!("{expected:?}"); assert!(fmt.contains("UserCredentials"), "{fmt}"); @@ -1373,6 +1427,7 @@ mod tests { delegates: Some(vec!["delegate1".to_string()]), scopes: vec!["scope1".to_string()], lifetime: DEFAULT_LIFETIME, + http: SharedHttpClientProvider::default(), }; let err = token_provider.token().await.unwrap_err(); diff --git a/src/auth/src/credentials/internal/jwk_client.rs b/src/auth/src/credentials/internal/jwk_client.rs index bfb02673ea..b3e3b70691 100644 --- a/src/auth/src/credentials/internal/jwk_client.rs +++ b/src/auth/src/credentials/internal/jwk_client.rs @@ -14,6 +14,7 @@ use crate::Result; use crate::errors::CredentialsError; +use crate::io::{HttpClientProvider, HttpRequest, SharedHttpClientProvider}; use jsonwebtoken::{Algorithm, DecodingKey, jwk::JwkSet}; use std::{ collections::HashMap, @@ -36,6 +37,7 @@ struct CacheEntry { pub struct JwkClient { cache: Arc>>, // KeyID -> Certificate ttl: Duration, + http: SharedHttpClientProvider, } impl JwkClient { @@ -43,14 +45,21 @@ impl JwkClient { Self { cache: Arc::new(RwLock::new(HashMap::new())), ttl: CACHE_TTL, + http: SharedHttpClientProvider::default(), } } + pub fn with_http_client_provider(mut self, http: impl HttpClientProvider + 'static) -> Self { + self.http = SharedHttpClientProvider::new(http); + self + } + #[cfg(test)] fn with_ttl(ttl: Duration) -> Self { Self { cache: Arc::new(RwLock::new(HashMap::new())), ttl, + http: SharedHttpClientProvider::default(), } } @@ -105,23 +114,25 @@ impl JwkClient { } async fn fetch_certs(&self, jwks_url: String) -> Result { - let client = reqwest::Client::new(); + let request = HttpRequest::get(jwks_url); + // TODO(#3592): add retries - let response = client - .get(jwks_url) - .send() + let response = self + .http + .execute(request) .await .map_err(|e| crate::errors::from_http_error(e, "failed to fetch JWK set"))?; - if !response.status().is_success() { - let err = crate::errors::from_http_response(response, "failed to fetch JWK set").await; - return Err(err); + if !response.is_success() { + return Err(crate::errors::from_http_response( + &response, + "failed to fetch JWK set", + )); } let jwk_set: JwkSet = response .json() - .await - .map_err(|e| CredentialsError::new(!e.is_decode(), "failed to parse JWK set", e))?; + .map_err(|e| CredentialsError::new(e.is_io(), "failed to parse JWK set", e))?; Ok(jwk_set) } diff --git a/src/auth/src/credentials/internal/sts_exchange.rs b/src/auth/src/credentials/internal/sts_exchange.rs index b6390bcaf7..15e741aed8 100644 --- a/src/auth/src/credentials/internal/sts_exchange.rs +++ b/src/auth/src/credentials/internal/sts_exchange.rs @@ -14,7 +14,8 @@ use crate::{ constants::{ACCESS_TOKEN_TYPE, TOKEN_EXCHANGE_GRANT_TYPE}, - credentials::errors::{self, CredentialsError}, + credentials::errors::CredentialsError, + io::{HttpRequest, SharedHttpClientProvider}, }; use base64::Engine; use serde::Deserialize; @@ -28,7 +29,10 @@ pub struct STSHandler {} impl STSHandler { /// Performs an oauth2 token exchange with the provided [ExchangeTokenRequest] information. - pub(crate) async fn exchange_token(req: ExchangeTokenRequest) -> Result { + pub(crate) async fn exchange_token( + req: ExchangeTokenRequest, + http: &SharedHttpClientProvider, + ) -> Result { let mut params = HashMap::new(); params.insert("grant_type", TOKEN_EXCHANGE_GRANT_TYPE.to_string()); @@ -60,7 +64,7 @@ impl STSHandler { } } - Self::execute(req.url, req.authentication, req.headers, params).await + Self::execute(req.url, req.authentication, req.headers, params, http).await } /// Execute http request and token exchange @@ -69,29 +73,43 @@ impl STSHandler { client_auth: ClientAuthentication, headers: http::HeaderMap, params: HashMap<&str, String>, + http: &SharedHttpClientProvider, ) -> Result { - let client = reqwest::Client::new(); - + // Inject client authentication into headers let mut headers = headers.clone(); client_auth.inject_auth(&mut headers)?; - let res = client - .post(url) - .form(¶ms) - .headers(headers) - .send() - .await - .map_err(|e| errors::from_http_error(e, MSG))?; + // Build form-encoded body + let body = url::form_urlencoded::Serializer::new(String::new()) + .extend_pairs(params.iter()) + .finish() + .into_bytes(); + + // Start with form-encoded request (sets content-type automatically) + let mut request = HttpRequest::post(url).form(body); - let status = res.status(); - if !status.is_success() { - let err = errors::from_http_response(res, MSG).await; - return Err(err); + // Copy headers from the HeaderMap, but skip content-type since .form() already set it + for (name, value) in headers.iter() { + if name.as_str().eq_ignore_ascii_case("content-type") { + continue; + } + if let Ok(v) = value.to_str() { + request = request.header(name.as_str(), v); + } } - let token_res = res - .json::() + + let response = http + .execute(request) .await - .map_err(|err| CredentialsError::from_source(false, err))?; + .map_err(|e| crate::errors::from_http_error(e, MSG))?; + + if !response.is_success() { + return Err(crate::errors::from_http_response(&response, MSG)); + } + + let token_res: TokenResponse = response + .json() + .map_err(|err| CredentialsError::from_source(err.is_io(), err))?; Ok(token_res) } } @@ -157,10 +175,8 @@ pub struct ExchangeTokenRequest { mod tests { use super::*; use crate::constants::{DEFAULT_SCOPE, JWT_TOKEN_TYPE}; - use http::StatusCode; use httptest::{Expectation, Server, matchers::*, responders::*}; use serde_json::json; - use std::error::Error as _; type TestResult = std::result::Result<(), Box>; @@ -227,12 +243,14 @@ mod tests { headers, authentication, audience: Some("32555940559.apps.googleusercontent.com".to_string()), + scope: [DEFAULT_SCOPE.to_string()].to_vec(), subject_token: "an_example_token".to_string(), subject_token_type: JWT_TOKEN_TYPE.to_string(), ..ExchangeTokenRequest::default() }; - let resp = STSHandler::exchange_token(token_req).await?; + let http = SharedHttpClientProvider::default(); + let resp = STSHandler::exchange_token(token_req, &http).await?; assert_eq!( resp, @@ -302,20 +320,16 @@ mod tests { subject_token_type: JWT_TOKEN_TYPE.to_string(), ..ExchangeTokenRequest::default() }; - let err = STSHandler::exchange_token(token_req).await.unwrap_err(); + let http = SharedHttpClientProvider::default(); + let err = STSHandler::exchange_token(token_req, &http) + .await + .unwrap_err(); assert!(!err.is_transient(), "{err:?}"); assert!(err.to_string().contains(MSG), "{err}, debug={err:?}"); assert!( err.to_string().contains("bad request"), "{err}, debug={err:?}" ); - let source = err - .source() - .and_then(|e| e.downcast_ref::()); - assert!( - matches!(source, Some(e) if e.status() == Some(StatusCode::BAD_REQUEST)), - "{err:?}" - ); Ok(()) } diff --git a/src/auth/src/credentials/mds.rs b/src/auth/src/credentials/mds.rs index d2e183d6af..00f9932677 100644 --- a/src/auth/src/credentials/mds.rs +++ b/src/auth/src/credentials/mds.rs @@ -76,6 +76,10 @@ use crate::access_boundary::CredentialsWithAccessBoundary; use crate::credentials::dynamic::{AccessTokenCredentialsProvider, CredentialsProvider}; use crate::credentials::{AccessToken, AccessTokenCredentials, CacheableResource, Credentials}; use crate::headers_util::AuthHeadersBuilder; +use crate::io::{ + EnvProvider, FsProvider, HttpClientProvider, IoConfig, SharedEnvProvider, SharedFsProvider, + SharedHttpClientProvider, +}; use crate::mds::client::Client as MDSClient; use crate::retry::{Builder as RetryTokenProviderBuilder, TokenProviderWithRetry}; use crate::token::{CachedTokenProvider, Token, TokenProvider}; @@ -128,6 +132,7 @@ pub struct Builder { retry_builder: RetryTokenProviderBuilder, iam_endpoint_override: Option, is_access_boundary_enabled: bool, + providers: IoConfig, } impl Default for Builder { @@ -140,6 +145,7 @@ impl Default for Builder { retry_builder: RetryTokenProviderBuilder::default(), iam_endpoint_override: None, is_access_boundary_enabled: true, + providers: IoConfig::default(), } } } @@ -258,6 +264,39 @@ impl Builder { self } + /// Sets a custom environment variable provider. + /// + /// When set, the auth crate will use this provider for all environment + /// variable lookups during MDS credential construction instead of reading + /// from the process environment directly. + pub fn with_env_provider(mut self, provider: impl EnvProvider + 'static) -> Self { + self.providers.env = SharedEnvProvider::new(provider); + self + } + + /// Sets a custom filesystem provider. + /// + /// When set, the auth crate will use this provider for all file read + /// operations during MDS credential construction instead of reading from + /// the real filesystem directly. + pub fn with_fs_provider(mut self, provider: impl FsProvider + 'static) -> Self { + self.providers.fs = SharedFsProvider::new(provider); + self + } + + /// Sets a custom HTTP client provider. + /// + /// When set, the auth crate will use this provider for all HTTP + /// requests during MDS credential construction and token retrieval instead + /// of using `reqwest::Client` directly. + pub fn with_http_client_provider( + mut self, + provider: impl HttpClientProvider + 'static, + ) -> Self { + self.providers.http = SharedHttpClientProvider::new(provider); + self + } + #[cfg(test)] fn maybe_iam_endpoint_override(mut self, iam_endpoint_override: Option) -> Self { self.iam_endpoint_override = iam_endpoint_override; @@ -283,6 +322,8 @@ impl Builder { .endpoint(self.endpoint) .maybe_scopes(self.scopes) .created_by_adc(self.created_by_adc) + .env_provider(self.providers.env) + .http_provider(self.providers.http) .build(); self.retry_builder.build(tp) } @@ -316,7 +357,12 @@ impl Builder { ) -> BuildResult>> { let iam_endpoint = self.iam_endpoint_override.clone(); let is_access_boundary_enabled = self.is_access_boundary_enabled; - let mds_client = MDSClient::new(self.endpoint.clone()); + let http = SharedHttpClientProvider::clone(&self.providers.http); + let mds_client = MDSClient::new( + self.endpoint.clone(), + SharedEnvProvider::clone(&self.providers.env), + SharedHttpClientProvider::clone(&self.providers.http), + ); let mdsc = MDSCredentials { quota_project_id: self.quota_project_id.clone(), token_provider: TokenCache::new(self.build_token_provider()), @@ -328,6 +374,7 @@ impl Builder { mdsc, mds_client, iam_endpoint, + http, )) } @@ -348,10 +395,15 @@ impl Builder { /// /// [IAM signBlob API]: https://cloud.google.com/iam/docs/reference/credentials/rest/v1/projects.serviceAccounts/signBlob pub fn build_signer(self) -> BuildResult { - let client = MDSClient::new(self.endpoint.clone()); + let client = MDSClient::new( + self.endpoint.clone(), + SharedEnvProvider::clone(&self.providers.env), + SharedHttpClientProvider::clone(&self.providers.http), + ); let iam_endpoint = self.iam_endpoint_override.clone(); + let http = SharedHttpClientProvider::clone(&self.providers.http); let credentials = self.build()?; - let signing_provider = crate::signer::mds::MDSSigner::new(client, credentials); + let signing_provider = crate::signer::mds::MDSSigner::new(client, credentials, http); let signing_provider = iam_endpoint .iter() .fold(signing_provider, |signing_provider, endpoint| { @@ -393,12 +445,18 @@ struct MDSAccessTokenProviderBuilder { scopes: Option>, endpoint: Option, created_by_adc: bool, + env: Option, + http: Option, } impl MDSAccessTokenProviderBuilder { fn build(self) -> MDSAccessTokenProvider { MDSAccessTokenProvider { - client: MDSClient::new(self.endpoint), + client: MDSClient::new( + self.endpoint, + self.env.unwrap_or_default(), + self.http.unwrap_or_default(), + ), scopes: self.scopes, created_by_adc: self.created_by_adc, } @@ -421,6 +479,16 @@ impl MDSAccessTokenProviderBuilder { self.created_by_adc = v; self } + + fn env_provider(mut self, v: SharedEnvProvider) -> Self { + self.env = Some(v); + self + } + + fn http_provider(mut self, v: SharedHttpClientProvider) -> Self { + self.http = Some(v); + self + } } #[derive(Debug, Clone)] @@ -487,7 +555,6 @@ mod tests { use httptest::matchers::{all_of, contains, request, url_decoded}; use httptest::responders::{json_encoded, status_code}; use httptest::{Expectation, Server}; - use reqwest::StatusCode; use scoped_env::ScopedEnv; use serde_json::json; use serial_test::{parallel, serial}; @@ -1006,13 +1073,8 @@ mod tests { .without_access_boundary() .build()?; let err = mdsc.headers(Extensions::new()).await.unwrap_err(); - let original_err = find_source_error::(&err).unwrap(); - assert!(original_err.is_transient()); - let source = find_source_error::(&err); - assert!( - matches!(source, Some(e) if e.status() == Some(StatusCode::SERVICE_UNAVAILABLE)), - "{err:?}" - ); + assert!(err.is_transient(), "{err:?}"); + assert!(format!("{err:?}").contains("503"), "{err:?}"); Ok(()) } @@ -1037,13 +1099,8 @@ mod tests { .build()?; let err = mdsc.headers(Extensions::new()).await.unwrap_err(); - let original_err = find_source_error::(&err).unwrap(); - assert!(!original_err.is_transient()); - let source = find_source_error::(&err); - assert!( - matches!(source, Some(e) if e.status() == Some(StatusCode::UNAUTHORIZED)), - "{err:?}" - ); + assert!(!err.is_transient(), "{err:?}"); + assert!(format!("{err:?}").contains("401"), "{err:?}"); Ok(()) } diff --git a/src/auth/src/credentials/service_account.rs b/src/auth/src/credentials/service_account.rs index 1d90520bbc..e591d13b00 100644 --- a/src/auth/src/credentials/service_account.rs +++ b/src/auth/src/credentials/service_account.rs @@ -78,6 +78,10 @@ use crate::credentials::dynamic::{AccessTokenCredentialsProvider, CredentialsPro use crate::credentials::{AccessToken, AccessTokenCredentials, CacheableResource, Credentials}; use crate::errors::{self}; use crate::headers_util::AuthHeadersBuilder; +use crate::io::{ + EnvProvider, FsProvider, HttpClientProvider, IoConfig, SharedEnvProvider, SharedFsProvider, + SharedHttpClientProvider, +}; use crate::token::{CachedTokenProvider, Token, TokenProvider}; use crate::token_cache::TokenCache; use crate::{BuildResult, Result}; @@ -208,6 +212,7 @@ pub struct Builder { access_specifier: AccessSpecifier, quota_project_id: Option, iam_endpoint_override: Option, + providers: IoConfig, } impl Builder { @@ -223,6 +228,7 @@ impl Builder { access_specifier: AccessSpecifier::Scopes([DEFAULT_SCOPE].map(str::to_string).to_vec()), quota_project_id: None, iam_endpoint_override: None, + providers: IoConfig::default(), } } @@ -271,6 +277,43 @@ impl Builder { self } + /// Sets a custom environment variable provider. + /// + /// When set, the auth crate will use this provider for all environment + /// variable lookups during service account credential construction instead + /// of reading from the process environment directly. + pub fn with_env_provider(mut self, provider: impl EnvProvider + 'static) -> Self { + self.providers.env = SharedEnvProvider::new(provider); + self + } + + /// Sets a custom filesystem provider. + /// + /// When set, the auth crate will use this provider for all file read + /// operations during service account credential construction instead of + /// reading from the real filesystem directly. + pub fn with_fs_provider(mut self, provider: impl FsProvider + 'static) -> Self { + self.providers.fs = SharedFsProvider::new(provider); + self + } + + /// Sets a custom HTTP client provider. + /// + /// When set, the auth crate will use this provider for all HTTP + /// requests during service account credential construction and token + /// retrieval instead of using `reqwest::Client` directly. + /// + /// Service account tokens are generated locally via JWT signing, so the + /// HTTP provider is mainly used for access boundary lookups and IAM + /// signer operations. + pub fn with_http_client_provider( + mut self, + provider: impl HttpClientProvider + 'static, + ) -> Self { + self.providers.http = SharedHttpClientProvider::new(provider); + self + } + fn build_token_provider(self) -> BuildResult { let service_account_key = serde_json::from_value::(self.service_account_key) @@ -343,6 +386,7 @@ impl Builder { ) -> BuildResult>> { let iam_endpoint = self.iam_endpoint_override.clone(); let quota_project_id = self.quota_project_id.clone(); + let http = SharedHttpClientProvider::clone(&self.providers.http); let token_provider = self.build_token_provider()?; let client_email = token_provider.service_account_key.client_email.clone(); let access_boundary_url = crate::access_boundary::service_account_lookup_url( @@ -357,6 +401,7 @@ impl Builder { Ok(CredentialsWithAccessBoundary::new( creds, Some(access_boundary_url), + http, )) } diff --git a/src/auth/src/credentials/user_account.rs b/src/auth/src/credentials/user_account.rs index 631d709f47..c14f3fa894 100644 --- a/src/auth/src/credentials/user_account.rs +++ b/src/auth/src/credentials/user_account.rs @@ -98,6 +98,10 @@ use crate::credentials::dynamic::{AccessTokenCredentialsProvider, CredentialsPro use crate::credentials::{AccessToken, AccessTokenCredentials, CacheableResource, Credentials}; use crate::errors::{self, CredentialsError}; use crate::headers_util::AuthHeadersBuilder; +use crate::io::{ + EnvProvider, FsProvider, HttpClientProvider, HttpRequest, IoConfig, SharedEnvProvider, + SharedFsProvider, SharedHttpClientProvider, +}; use crate::retry::Builder as RetryTokenProviderBuilder; use crate::token::{CachedTokenProvider, Token, TokenProvider}; use crate::token_cache::TokenCache; @@ -105,9 +109,7 @@ use crate::{BuildResult, Result}; use google_cloud_gax::backoff_policy::BackoffPolicyArg; use google_cloud_gax::retry_policy::RetryPolicyArg; use google_cloud_gax::retry_throttler::RetryThrottlerArg; -use http::header::CONTENT_TYPE; -use http::{Extensions, HeaderMap, HeaderValue}; -use reqwest::{Client, Method}; +use http::{Extensions, HeaderMap}; use serde_json::Value; use std::sync::Arc; use tokio::time::{Duration, Instant}; @@ -128,6 +130,7 @@ pub struct Builder { quota_project_id: Option, token_uri: Option, retry_builder: RetryTokenProviderBuilder, + providers: IoConfig, } impl Builder { @@ -144,6 +147,7 @@ impl Builder { quota_project_id: None, token_uri: None, retry_builder: RetryTokenProviderBuilder::default(), + providers: IoConfig::default(), } } @@ -300,6 +304,39 @@ impl Builder { self } + /// Sets a custom environment variable provider. + /// + /// When set, the auth crate will use this provider for all environment + /// variable lookups during user account credential construction instead + /// of reading from the process environment directly. + pub fn with_env_provider(mut self, provider: impl EnvProvider + 'static) -> Self { + self.providers.env = SharedEnvProvider::new(provider); + self + } + + /// Sets a custom filesystem provider. + /// + /// When set, the auth crate will use this provider for all file read + /// operations during user account credential construction instead of + /// reading from the real filesystem directly. + pub fn with_fs_provider(mut self, provider: impl FsProvider + 'static) -> Self { + self.providers.fs = SharedFsProvider::new(provider); + self + } + + /// Sets a custom HTTP client provider. + /// + /// When set, the auth crate will use this provider for all HTTP + /// requests during user account credential construction and token + /// retrieval instead of using `reqwest::Client` directly. + pub fn with_http_client_provider( + mut self, + provider: impl HttpClientProvider + 'static, + ) -> Self { + self.providers.http = SharedHttpClientProvider::new(provider); + self + } + /// Returns a [Credentials] instance with the configured settings. /// /// # Errors @@ -366,6 +403,7 @@ impl Builder { endpoint, scopes: self.scopes.map(|scopes| scopes.join(" ")), source: UserTokenSource::AccessToken, + http: self.providers.http, }; let token_provider = TokenCache::new(self.retry_builder.build(token_provider)); @@ -379,7 +417,6 @@ impl Builder { } } -#[derive(PartialEq)] pub(crate) struct UserTokenProvider { client_id: String, client_secret: String, @@ -387,6 +424,21 @@ pub(crate) struct UserTokenProvider { endpoint: String, scopes: Option, source: UserTokenSource, + pub(crate) http: SharedHttpClientProvider, +} + +// PartialEq is implemented manually because SharedHttpClientProvider +// (Arc) does not implement PartialEq. The http field is excluded +// from comparison — it does not affect logical equality of the token provider. +impl PartialEq for UserTokenProvider { + fn eq(&self, other: &Self) -> bool { + self.client_id == other.client_id + && self.client_secret == other.client_secret + && self.refresh_token == other.refresh_token + && self.endpoint == other.endpoint + && self.scopes == other.scopes + && self.source == other.source + } } #[derive(PartialEq)] @@ -404,6 +456,7 @@ impl std::fmt::Debug for UserTokenProvider { .field("refresh_token", &"[censored]") .field("endpoint", &self.endpoint) .field("scopes", &self.scopes) + .field("http", &self.http) .finish() } } @@ -413,6 +466,7 @@ impl UserTokenProvider { pub(crate) fn new_id_token_provider( authorized_user: AuthorizedUser, token_uri: Option, + http: SharedHttpClientProvider, ) -> UserTokenProvider { let endpoint = token_uri .or(authorized_user.token_uri) @@ -424,6 +478,7 @@ impl UserTokenProvider { endpoint, source: UserTokenSource::IdToken, scopes: None, + http, } } } @@ -431,9 +486,7 @@ impl UserTokenProvider { #[async_trait::async_trait] impl TokenProvider for UserTokenProvider { async fn token(&self) -> Result { - let client = Client::new(); - - // Make the request + // Build the request body let req = Oauth2RefreshRequest { grant_type: RefreshGrantType::RefreshToken, client_id: self.client_id.clone(), @@ -441,25 +494,25 @@ impl TokenProvider for UserTokenProvider { refresh_token: self.refresh_token.clone(), scopes: self.scopes.clone(), }; - let header = HeaderValue::from_static("application/json"); - let builder = client - .request(Method::POST, self.endpoint.as_str()) - .header(CONTENT_TYPE, header) - .json(&req); - let resp = builder - .send() + let body = serde_json::to_vec(&req).map_err(|e| CredentialsError::from_source(false, e))?; + + let request = HttpRequest::post(&self.endpoint).json(body); + + // Execute the request via the HTTP provider + let response = self + .http + .execute(request) .await .map_err(|e| errors::from_http_error(e, MSG))?; // Process the response - if !resp.status().is_success() { - let err = errors::from_http_response(resp, MSG).await; - return Err(err); + if !response.is_success() { + return Err(errors::from_http_response(&response, MSG)); } - let response = resp.json::().await.map_err(|e| { - let retryable = !e.is_decode(); - CredentialsError::from_source(retryable, e) - })?; + + let response: Oauth2RefreshResponse = response + .json() + .map_err(|e| CredentialsError::from_source(e.is_io(), e))?; let token = match self.source { UserTokenSource::AccessToken => Ok(response.access_token), @@ -571,9 +624,10 @@ mod tests { get_token_type_from_headers, }; use crate::credentials::{DEFAULT_UNIVERSE_DOMAIN, QUOTA_PROJECT_KEY}; + use crate::errors; use crate::errors::CredentialsError; use crate::token::tests::MockTokenProvider; - use http::StatusCode; + use http::HeaderValue; use http::header::AUTHORIZATION; use httptest::cycle; use httptest::matchers::{all_of, json_decoded, request}; @@ -680,6 +734,7 @@ mod tests { endpoint: OAUTH2_TOKEN_SERVER_URL.to_string(), scopes: Some("https://www.googleapis.com/auth/pubsub".to_string()), source: UserTokenSource::AccessToken, + http: SharedHttpClientProvider::default(), }; let fmt = format!("{expected:?}"); assert!(fmt.contains("test-client-id"), "{fmt}"); @@ -960,6 +1015,7 @@ mod tests { endpoint: server.url("/token").to_string(), scopes: Some("scope1 scope2".to_string()), source: UserTokenSource::AccessToken, + http: SharedHttpClientProvider::default(), }; let now = Instant::now(); let token = tp.token().await?; @@ -1266,12 +1322,6 @@ mod tests { let original_err = find_source_error::(&err).unwrap(); assert!(original_err.is_transient()); - let source = find_source_error::(&err); - assert!( - matches!(source, Some(e) if e.status() == Some(StatusCode::SERVICE_UNAVAILABLE)), - "{err:?}" - ); - Ok(()) } @@ -1294,12 +1344,6 @@ mod tests { let original_err = find_source_error::(&err).unwrap(); assert!(!original_err.is_transient()); - let source = find_source_error::(&err); - assert!( - matches!(source, Some(e) if e.status() == Some(StatusCode::UNAUTHORIZED)), - "{err:?}" - ); - Ok(()) } diff --git a/src/auth/src/errors.rs b/src/auth/src/errors.rs index 42c2719184..24b91200c4 100644 --- a/src/auth/src/errors.rs +++ b/src/auth/src/errors.rs @@ -14,7 +14,7 @@ //! Common errors generated by the components in this crate. -use http::StatusCode; +use crate::io::HttpResponse; use std::error::Error; pub use google_cloud_gax::error::CredentialsError; @@ -73,23 +73,62 @@ impl SubjectTokenProviderError for CredentialsError { } } -pub(crate) fn from_http_error(err: reqwest::Error, msg: &str) -> CredentialsError { - let transient = self::is_retryable(&err); - CredentialsError::new(transient, msg, err) +/// Wraps a `Box` from an I/O provider (e.g. `HttpClientProvider`) +/// so it can be used as a source in `CredentialsError::from_source` (which +/// requires `Error + Sized`). +#[derive(Debug)] +struct TransportError { + message: String, + source: Box, } -pub(crate) async fn from_http_response(response: reqwest::Response, msg: &str) -> CredentialsError { - let err = response - .error_for_status_ref() - .expect_err("this function is only called on errors"); - let body = response.text().await; - let transient = crate::errors::is_retryable(&err); - match body { - Err(e) => CredentialsError::new(transient, msg, e), - Ok(b) => CredentialsError::new(transient, format!("{msg}, body=<{b}>"), err), +impl std::fmt::Display for TransportError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.message) } } +impl Error for TransportError { + fn source(&self) -> Option<&(dyn Error + 'static)> { + Some(self.source.as_ref()) + } +} + +/// Creates a `CredentialsError` from a failed HTTP provider call. +/// +/// Use this when `HttpClientProvider::execute()` returns `Err(...)`, +/// meaning the request failed at the transport level (connection refused, +/// timeout, DNS resolution, TLS handshake, etc.) and no HTTP response was +/// received. These errors are marked as transient because transport-level +/// failures are typically recoverable. Non-transient cases (e.g., +/// misconfigured endpoints) will be bounded by the caller's retry policy. +/// +/// For errors from a received HTTP response (non-2xx status), use +/// [`from_http_response`] instead, which checks the status code to +/// determine retryability. +pub(crate) fn from_http_error(source: Box, msg: &str) -> CredentialsError { + CredentialsError::from_source( + true, + TransportError { + message: msg.to_string(), + source, + }, + ) +} + +/// Creates a `CredentialsError` from a non-success HTTP response. +/// +/// Use this when `HttpClientProvider::execute()` returns `Ok(response)` +/// but `response.is_success()` is false. +pub(crate) fn from_http_response(response: &HttpResponse, msg: &str) -> CredentialsError { + let status = response.status; + let body = String::from_utf8_lossy(&response.body); + CredentialsError::from_msg( + response.is_retryable(), + format!("{msg}, status={status}, body=<{body}>"), + ) +} + /// A helper to create a non-retryable error. pub(crate) fn non_retryable(source: T) -> CredentialsError { CredentialsError::from_source(false, source) @@ -99,54 +138,10 @@ pub(crate) fn non_retryable_from_str>(message: T) -> Credentials CredentialsError::from_msg(false, message) } -fn is_retryable(err: &reqwest::Error) -> bool { - // Connection errors are transient more often than not. A bad configuration - // can point to a non-existing service, and that will never recover. - // However: (1) we expect this to be rare, and (2) this is what limiting - // retry policies and backoff policies handle. - if err.is_connect() { - return true; - } - match err.status() { - Some(code) => is_retryable_code(code), - None => false, - } -} - -fn is_retryable_code(code: StatusCode) -> bool { - match code { - // Internal server errors do not indicate that there is anything wrong - // with our request, so we retry them. - StatusCode::INTERNAL_SERVER_ERROR - | StatusCode::SERVICE_UNAVAILABLE - | StatusCode::REQUEST_TIMEOUT - | StatusCode::TOO_MANY_REQUESTS => true, - _ => false, - } -} - #[cfg(test)] mod tests { use super::*; use std::num::ParseIntError; - use test_case::test_case; - - #[test_case(StatusCode::INTERNAL_SERVER_ERROR)] - #[test_case(StatusCode::SERVICE_UNAVAILABLE)] - #[test_case(StatusCode::REQUEST_TIMEOUT)] - #[test_case(StatusCode::TOO_MANY_REQUESTS)] - fn retryable(c: StatusCode) { - assert!(is_retryable_code(c)); - } - - #[test_case(StatusCode::NOT_FOUND)] - #[test_case(StatusCode::UNAUTHORIZED)] - #[test_case(StatusCode::BAD_REQUEST)] - #[test_case(StatusCode::BAD_GATEWAY)] - #[test_case(StatusCode::PRECONDITION_FAILED)] - fn non_retryable(c: StatusCode) { - assert!(!is_retryable_code(c)); - } #[test] fn helpers() { diff --git a/src/auth/src/io.rs b/src/auth/src/io.rs new file mode 100644 index 0000000000..868c8f9e3c --- /dev/null +++ b/src/auth/src/io.rs @@ -0,0 +1,658 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//! Trait-based abstractions for external I/O in the auth crate. +//! +//! This module defines provider traits for environment variable reads, +//! filesystem reads, and HTTP requests. Users can implement these traits +//! to control how the auth crate performs I/O operations. + +use std::future::Future; +use std::panic::{RefUnwindSafe, UnwindSafe}; +use std::sync::Arc; + +/// Abstracts environment variable reads. +/// +/// The default implementation delegates to `std::env::var`. +/// Implement this trait to control how the auth crate resolves +/// environment variables like `GOOGLE_APPLICATION_CREDENTIALS`. +pub trait EnvProvider: Send + Sync + UnwindSafe + RefUnwindSafe + std::fmt::Debug { + /// Reads an environment variable by name. + /// Returns `None` if the variable is not set. + fn var(&self, name: &str) -> Option; +} + +/// Default implementation using `std::env::var`. +#[derive(Debug, Clone)] +pub struct DefaultEnvProvider; + +impl EnvProvider for DefaultEnvProvider { + fn var(&self, name: &str) -> Option { + std::env::var(name).ok() + } +} + +/// Abstracts filesystem read operations. +/// +/// The default implementation delegates to `std::fs::read_to_string`. +/// Implement this trait to control how the auth crate loads files +/// such as ADC credential files. +pub trait FsProvider: Send + Sync + UnwindSafe + RefUnwindSafe + std::fmt::Debug { + /// Reads the entire contents of a file as a string. + fn read_to_string(&self, path: &str) -> std::io::Result; +} + +/// Default implementation using `std::fs::read_to_string`. +#[derive(Debug, Clone)] +pub struct DefaultFsProvider; + +impl FsProvider for DefaultFsProvider { + fn read_to_string(&self, path: &str) -> std::io::Result { + std::fs::read_to_string(path) + } +} + +/// HTTP method for auth requests. +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum HttpMethod { + Get, + Post, + Put, +} + +/// A simple HTTP request used by the auth crate. +#[derive(Debug, Clone)] +pub struct HttpRequest { + pub method: HttpMethod, + pub url: String, + pub headers: Vec<(String, String)>, + pub body: Vec, + pub query_params: Vec<(String, String)>, +} + +impl HttpRequest { + /// Creates a GET request to the given URL. + pub fn get(url: impl Into) -> Self { + Self { + method: HttpMethod::Get, + url: url.into(), + headers: Vec::new(), + body: Vec::new(), + query_params: Vec::new(), + } + } + + /// Creates a POST request to the given URL. + pub fn post(url: impl Into) -> Self { + Self { + method: HttpMethod::Post, + url: url.into(), + headers: Vec::new(), + body: Vec::new(), + query_params: Vec::new(), + } + } + + /// Creates a PUT request to the given URL. + pub fn put(url: impl Into) -> Self { + Self { + method: HttpMethod::Put, + url: url.into(), + headers: Vec::new(), + body: Vec::new(), + query_params: Vec::new(), + } + } + + /// Adds a header to the request. + pub fn header(mut self, name: impl Into, value: impl Into) -> Self { + self.headers.push((name.into(), value.into())); + self + } + + /// Adds a query parameter to the request. + pub fn query(mut self, name: impl Into, value: impl Into) -> Self { + self.query_params.push((name.into(), value.into())); + self + } + + /// Sets the request body. + pub fn body(mut self, body: impl Into>) -> Self { + self.body = body.into(); + self + } + + /// Sets a JSON body and adds the `content-type: application/json` header. + pub fn json(self, body: impl Into>) -> Self { + self.header("content-type", "application/json").body(body) + } + + /// Sets a form-encoded body and adds the `content-type: application/x-www-form-urlencoded` header. + pub fn form(self, body: impl Into>) -> Self { + self.header("content-type", "application/x-www-form-urlencoded") + .body(body) + } + + /// Copies headers from an `http::HeaderMap` into this request. + pub fn headers_from_map(mut self, map: &http::HeaderMap) -> Self { + for (name, value) in map.iter() { + if let Ok(v) = value.to_str() { + self.headers.push((name.to_string(), v.to_string())); + } + } + self + } +} + +/// A simple HTTP response returned by the auth crate. +#[derive(Debug, Clone)] +pub struct HttpResponse { + pub status: http::StatusCode, + pub headers: http::HeaderMap, + pub body: Vec, +} + +impl HttpResponse { + /// Returns `true` if the status code indicates success (2xx). + pub fn is_success(&self) -> bool { + self.status.is_success() + } + + /// Returns `true` if the status code indicates a transient/retryable error. + pub fn is_retryable(&self) -> bool { + matches!( + self.status, + http::StatusCode::REQUEST_TIMEOUT + | http::StatusCode::TOO_MANY_REQUESTS + | http::StatusCode::INTERNAL_SERVER_ERROR + | http::StatusCode::SERVICE_UNAVAILABLE + ) + } + + /// Deserializes the response body as JSON into the given type. + pub fn json(&self) -> serde_json::Result { + serde_json::from_slice(&self.body) + } + + /// Returns the response body as a UTF-8 string. + pub fn text(&self) -> std::result::Result { + String::from_utf8(self.body.clone()) + } +} + +/// Abstracts HTTP request execution. +/// +/// The default implementation delegates to `reqwest::Client`. +/// Implement this trait to control how the auth crate makes HTTP +/// requests for token exchanges, metadata service calls, etc. +pub trait HttpClientProvider: Send + Sync + UnwindSafe + RefUnwindSafe + std::fmt::Debug { + /// Executes an HTTP request and returns the response. + fn execute( + &self, + request: HttpRequest, + ) -> impl Future>> + Send; +} + +/// Default implementation using `reqwest::Client`. +#[derive(Debug, Clone)] +pub struct DefaultHttpClientProvider { + client: reqwest::Client, +} + +// SAFETY: `reqwest::Client` is internally `Arc` — it is +// effectively immutable once constructed and safe to use across +// `catch_unwind` boundaries. +impl UnwindSafe for DefaultHttpClientProvider {} +impl RefUnwindSafe for DefaultHttpClientProvider {} + +impl Default for DefaultHttpClientProvider { + fn default() -> Self { + Self { + client: reqwest::Client::new(), + } + } +} + +impl HttpClientProvider for DefaultHttpClientProvider { + async fn execute( + &self, + request: HttpRequest, + ) -> Result> { + let method = match request.method { + HttpMethod::Get => reqwest::Method::GET, + HttpMethod::Post => reqwest::Method::POST, + HttpMethod::Put => reqwest::Method::PUT, + }; + + let mut builder = self.client.request(method, &request.url); + if !request.query_params.is_empty() { + builder = builder.query(&request.query_params); + } + for (name, value) in &request.headers { + builder = builder.header(name.as_str(), value.as_str()); + } + if !request.body.is_empty() { + builder = builder.body(request.body); + } + + let resp = builder.send().await?; + + let status = resp.status(); + let headers = resp.headers().clone(); + let body = resp.bytes().await?.to_vec(); + + Ok(HttpResponse { + status, + headers, + body, + }) + } +} + +/// A shared, cloneable environment variable provider. +/// +/// Wraps an `Arc`. Construct via [`SharedEnvProvider::new`]. +#[derive(Clone, Debug)] +pub struct SharedEnvProvider(Arc); + +impl SharedEnvProvider { + /// Creates a new shared provider from any type implementing [`EnvProvider`]. + pub fn new(provider: P) -> Self { + Self(Arc::new(provider)) + } + + /// Reads an environment variable by name, delegating to the inner provider. + pub(crate) fn var(&self, name: &str) -> Option { + self.0.var(name) + } +} + +impl EnvProvider for SharedEnvProvider { + fn var(&self, name: &str) -> Option { + self.0.var(name) + } +} + +impl Default for SharedEnvProvider { + fn default() -> Self { + Self::new(DefaultEnvProvider) + } +} + +/// A shared, cloneable filesystem provider. +/// +/// Wraps an `Arc`. Construct via [`SharedFsProvider::new`]. +#[derive(Clone, Debug)] +pub struct SharedFsProvider(Arc); + +impl SharedFsProvider { + /// Creates a new shared provider from any type implementing [`FsProvider`]. + pub fn new(provider: P) -> Self { + Self(Arc::new(provider)) + } + + /// Reads the entire contents of a file as a string, delegating to the inner provider. + pub(crate) fn read_to_string(&self, path: &str) -> std::io::Result { + self.0.read_to_string(path) + } +} + +impl FsProvider for SharedFsProvider { + fn read_to_string(&self, path: &str) -> std::io::Result { + self.0.read_to_string(path) + } +} + +impl Default for SharedFsProvider { + fn default() -> Self { + Self::new(DefaultFsProvider) + } +} + +/// A shared, cloneable HTTP client provider. +/// +/// Wraps an `Arc`. Construct via [`SharedHttpClientProvider::new`]. +#[derive(Clone, Debug)] +pub struct SharedHttpClientProvider(Arc); + +impl SharedHttpClientProvider { + /// Creates a new shared provider from any type implementing [`HttpClientProvider`]. + pub fn new(provider: P) -> Self { + Self(Arc::new(provider)) + } + + /// Executes an HTTP request, delegating to the inner provider. + pub(crate) async fn execute( + &self, + request: HttpRequest, + ) -> Result> { + self.0.execute(request).await + } +} + +impl HttpClientProvider for SharedHttpClientProvider { + async fn execute( + &self, + request: HttpRequest, + ) -> Result> { + self.0.execute(request).await + } +} + +impl Default for SharedHttpClientProvider { + fn default() -> Self { + Self::new(DefaultHttpClientProvider::default()) + } +} + +/// Holds the I/O provider configuration for credential construction. +/// +/// Passed through the credential construction chain so that all components +/// use the same set of providers. +#[derive(Clone, Debug, Default)] +pub(crate) struct IoConfig { + pub(crate) env: SharedEnvProvider, + pub(crate) fs: SharedFsProvider, + pub(crate) http: SharedHttpClientProvider, +} + +#[cfg(test)] +mod tests { + use super::*; + + // -- HttpRequest builder tests -- + + #[test] + fn http_request_get() { + let req = HttpRequest::get("https://example.com"); + assert_eq!(req.method, HttpMethod::Get); + assert_eq!(req.url, "https://example.com"); + assert!(req.headers.is_empty()); + assert!(req.body.is_empty()); + assert!(req.query_params.is_empty()); + } + + #[test] + fn http_request_post() { + let req = HttpRequest::post("https://example.com"); + assert_eq!(req.method, HttpMethod::Post); + } + + #[test] + fn http_request_put() { + let req = HttpRequest::put("https://example.com"); + assert_eq!(req.method, HttpMethod::Put); + } + + #[test] + fn http_request_header() { + let req = HttpRequest::get("https://example.com") + .header("authorization", "Bearer tok") + .header("x-custom", "val"); + assert_eq!(req.headers.len(), 2); + assert_eq!( + req.headers[0], + ("authorization".into(), "Bearer tok".into()) + ); + assert_eq!(req.headers[1], ("x-custom".into(), "val".into())); + } + + #[test] + fn http_request_query() { + let req = HttpRequest::get("https://example.com") + .query("key", "value") + .query("a", "b"); + assert_eq!(req.query_params.len(), 2); + assert_eq!(req.query_params[0], ("key".into(), "value".into())); + } + + #[test] + fn http_request_body() { + let req = HttpRequest::post("https://example.com").body(b"hello".to_vec()); + assert_eq!(req.body, b"hello"); + } + + #[test] + fn http_request_json() { + let req = HttpRequest::post("https://example.com").json(b"{}".to_vec()); + assert_eq!(req.body, b"{}"); + assert!( + req.headers + .iter() + .any(|(k, v)| k == "content-type" && v == "application/json") + ); + } + + #[test] + fn http_request_form() { + let req = HttpRequest::post("https://example.com").form(b"k=v".to_vec()); + assert_eq!(req.body, b"k=v"); + assert!( + req.headers + .iter() + .any(|(k, v)| k == "content-type" && v == "application/x-www-form-urlencoded") + ); + } + + #[test] + fn http_request_headers_from_map() { + let mut map = http::HeaderMap::new(); + map.insert("x-foo", http::HeaderValue::from_static("bar")); + let req = HttpRequest::get("https://example.com").headers_from_map(&map); + assert_eq!(req.headers, vec![("x-foo".into(), "bar".into())]); + } + + // -- HttpResponse tests -- + + fn make_response(status: u16, body: &[u8]) -> HttpResponse { + HttpResponse { + status: http::StatusCode::from_u16(status).unwrap(), + headers: http::HeaderMap::new(), + body: body.to_vec(), + } + } + + #[test] + fn http_response_is_success() { + assert!(make_response(200, b"").is_success()); + assert!(make_response(201, b"").is_success()); + assert!(!make_response(400, b"").is_success()); + assert!(!make_response(500, b"").is_success()); + } + + #[test] + fn http_response_is_retryable() { + assert!(make_response(408, b"").is_retryable()); + assert!(make_response(429, b"").is_retryable()); + assert!(make_response(500, b"").is_retryable()); + assert!(make_response(503, b"").is_retryable()); + assert!(!make_response(200, b"").is_retryable()); + assert!(!make_response(400, b"").is_retryable()); + assert!(!make_response(404, b"").is_retryable()); + assert!(!make_response(502, b"").is_retryable()); + } + + #[test] + fn http_response_json() { + let resp = make_response(200, br#"{"key":"value"}"#); + let parsed: serde_json::Value = resp.json().unwrap(); + assert_eq!(parsed["key"], "value"); + } + + #[test] + fn http_response_json_error() { + let resp = make_response(200, b"not json"); + let result: serde_json::Result = resp.json(); + assert!(result.is_err()); + } + + #[test] + fn http_response_text() { + let resp = make_response(200, b"hello world"); + assert_eq!(resp.text().unwrap(), "hello world"); + } + + #[test] + fn http_response_text_invalid_utf8() { + let resp = make_response(200, &[0xff, 0xfe]); + assert!(resp.text().is_err()); + } + + // -- Custom provider tests -- + + #[derive(Debug)] + struct FakeEnv; + impl EnvProvider for FakeEnv { + fn var(&self, name: &str) -> Option { + match name { + "TEST_KEY" => Some("test_value".into()), + _ => None, + } + } + } + + #[derive(Debug)] + struct FakeFs; + impl FsProvider for FakeFs { + fn read_to_string(&self, path: &str) -> std::io::Result { + if path == "/fake/file.txt" { + Ok("fake contents".into()) + } else { + Err(std::io::Error::new( + std::io::ErrorKind::NotFound, + "not found", + )) + } + } + } + + #[test] + fn shared_env_provider_delegates() { + let shared = SharedEnvProvider::new(FakeEnv); + assert_eq!(shared.var("TEST_KEY"), Some("test_value".into())); + assert_eq!(shared.var("MISSING"), None); + } + + #[test] + fn shared_env_provider_implements_trait() { + let shared = SharedEnvProvider::new(FakeEnv); + let as_trait: &dyn EnvProvider = &shared; + assert_eq!(as_trait.var("TEST_KEY"), Some("test_value".into())); + } + + #[test] + fn shared_env_provider_clone() { + let shared = SharedEnvProvider::new(FakeEnv); + let cloned = shared.clone(); + assert_eq!(cloned.var("TEST_KEY"), Some("test_value".into())); + } + + #[test] + fn shared_fs_provider_delegates() { + let shared = SharedFsProvider::new(FakeFs); + assert_eq!( + shared.read_to_string("/fake/file.txt").unwrap(), + "fake contents" + ); + assert!(shared.read_to_string("/missing").is_err()); + } + + #[test] + fn shared_fs_provider_implements_trait() { + let shared = SharedFsProvider::new(FakeFs); + let as_trait: &dyn FsProvider = &shared; + assert_eq!( + as_trait.read_to_string("/fake/file.txt").unwrap(), + "fake contents" + ); + } + + #[derive(Debug)] + struct FakeHttp; + impl HttpClientProvider for FakeHttp { + async fn execute( + &self, + _request: HttpRequest, + ) -> Result> { + Ok(make_response(200, b"ok")) + } + } + + #[tokio::test] + async fn shared_http_provider_delegates() { + let shared = SharedHttpClientProvider::new(FakeHttp); + let resp = shared + .execute(HttpRequest::get("https://example.com")) + .await + .unwrap(); + assert_eq!(resp.status, http::StatusCode::OK); + assert_eq!(resp.body, b"ok"); + } + + #[tokio::test] + async fn shared_http_provider_implements_trait() { + let shared = SharedHttpClientProvider::new(FakeHttp); + let resp = HttpClientProvider::execute(&shared, HttpRequest::get("https://example.com")) + .await + .unwrap(); + assert_eq!(resp.status, http::StatusCode::OK); + } + + // -- IoConfig defaults -- + + #[test] + fn io_config_default() { + let config = IoConfig::default(); + // Default env provider reads real env vars — just verify it doesn't panic. + let _ = config.env.var("PATH"); + // Default fs provider reads real files — just verify a missing file errors. + assert!(config.fs.read_to_string("/nonexistent-path-12345").is_err()); + } +} + +pub(crate) mod dynamic { + use std::panic::{RefUnwindSafe, UnwindSafe}; + + /// A dyn-compatible, crate-private version of `HttpClientProvider`. + /// + /// The public `HttpClientProvider` uses RPITIT (`-> impl Future<...>`) + /// which is not dyn-compatible. This trait uses `#[async_trait]` to + /// produce a boxed future, enabling storage as `Arc`. + /// + /// `EnvProvider` and `FsProvider` are synchronous and already + /// dyn-compatible, so they don't need dynamic counterparts. + #[async_trait::async_trait] + pub trait HttpClientProvider: + Send + Sync + UnwindSafe + RefUnwindSafe + std::fmt::Debug + { + async fn execute( + &self, + request: super::HttpRequest, + ) -> Result>; + } + + /// The public HttpClientProvider implements the dyn-compatible HttpClientProvider. + #[async_trait::async_trait] + impl HttpClientProvider for T + where + T: super::HttpClientProvider + Send + Sync, + { + async fn execute( + &self, + request: super::HttpRequest, + ) -> Result> { + T::execute(self, request).await + } + } +} diff --git a/src/auth/src/lib.rs b/src/auth/src/lib.rs index 60e2b7bc69..ba04452698 100644 --- a/src/auth/src/lib.rs +++ b/src/auth/src/lib.rs @@ -58,6 +58,7 @@ pub(crate) mod constants; pub mod credentials; pub mod errors; pub(crate) mod headers_util; +pub mod io; pub(crate) mod mds; pub(crate) mod retry; pub mod signer; diff --git a/src/auth/src/mds/client.rs b/src/auth/src/mds/client.rs index 11edda97d5..99b74f0dfb 100644 --- a/src/auth/src/mds/client.rs +++ b/src/auth/src/mds/client.rs @@ -12,9 +12,9 @@ // See the License for the specific language governing permissions and // limitations under the License. -use crate::errors::{self, CredentialsError}; +use crate::errors::CredentialsError; +use crate::io::{HttpRequest, HttpResponse, SharedEnvProvider, SharedHttpClientProvider}; use crate::token::Token; -use reqwest::{Client as ReqwestClient, RequestBuilder}; use std::time::Duration; use tokio::time::Instant; @@ -24,6 +24,7 @@ pub(crate) struct Client { endpoint: String, /// True if the endpoint was NOT overridden by env var or constructor arg. pub(crate) is_default_endpoint: bool, + http: SharedHttpClientProvider, } #[derive(Clone, Debug, PartialEq, serde::Deserialize, serde::Serialize)] @@ -36,19 +37,27 @@ pub(crate) struct MDSTokenResponse { impl Client { /// Creates a new client for the Metadata Service. - pub(crate) fn new(endpoint_override: Option) -> Self { - let (endpoint, is_default_endpoint) = Self::resolve_endpoint(endpoint_override); + pub(crate) fn new( + endpoint_override: Option, + env: SharedEnvProvider, + http: SharedHttpClientProvider, + ) -> Self { + let (endpoint, is_default_endpoint) = Self::resolve_endpoint(endpoint_override, &env); let endpoint = endpoint.trim_end_matches('/').to_string(); Self { endpoint, is_default_endpoint, + http, } } /// Determine the endpoint and whether it was overridden - fn resolve_endpoint(endpoint_override: Option) -> (String, bool) { - if let Ok(host) = std::env::var(super::GCE_METADATA_HOST_ENV_VAR) { + fn resolve_endpoint( + endpoint_override: Option, + env: &SharedEnvProvider, + ) -> (String, bool) { + if let Some(host) = env.var(super::GCE_METADATA_HOST_ENV_VAR) { // Check GCE_METADATA_HOST environment variable first (format!("http://{host}"), false) } else if let Some(e) = endpoint_override { @@ -61,11 +70,24 @@ impl Client { } /// Creates a GET request to the MDS service with the correct headers. - fn get(&self, path: &str) -> RequestBuilder { + fn get(&self, path: &str) -> HttpRequest { let url = format!("{}{}", self.endpoint, path); - ReqwestClient::new() - .get(url) - .header(super::METADATA_FLAVOR, super::METADATA_FLAVOR_VALUE) + HttpRequest::get(url).header(super::METADATA_FLAVOR, super::METADATA_FLAVOR_VALUE) + } + + /// Executes an HTTP request via the provider and maps transport errors. + async fn execute( + &self, + request: HttpRequest, + error_message: &str, + ) -> crate::Result { + let response = self + .http + .execute(request) + .await + .map_err(|e| crate::errors::from_http_error(e, error_message))?; + + Self::check_response_status(response, error_message) } /// Fetches an access token for the default service account. @@ -78,7 +100,7 @@ impl Client { let scopes = scopes.as_ref().map(|v| v.join(",")); let request = scopes .into_iter() - .fold(request, |r, s| r.query(&[("scopes", s)])); + .fold(request, |r, s| r.query("scopes", s)); let error_message = "failed to fetch access token"; @@ -86,24 +108,19 @@ impl Client { // running on MDS environments and not useful if there is no MDS. We will mark the error // as retryable and let the retry policy determine whether to retry or not. Whenever we // define a default retry policy, we can skip retrying this case. - let response = request - .send() - .await - .map_err(|e| errors::from_http_error(e, error_message))?; + let response = self.execute(request, error_message).await?; - let response = Self::check_response_status(response, error_message).await?; - - let response = response.json::().await.map_err(|e| { + let mds_response: MDSTokenResponse = response.json().map_err(|e| { // Decoding errors are not transient. Typically they indicate a badly // configured MDS endpoint, or DNS redirecting the request to a random // server, e.g., ISPs that redirect unknown services to HTTP. - CredentialsError::from_source(!e.is_decode(), e) + CredentialsError::from_source(e.is_io(), e) })?; Ok(Token { - token: response.access_token, - token_type: response.token_type, - expires_at: response + token: mds_response.access_token, + token_type: mds_response.token_type, + expires_at: mds_response .expires_in .map(|d| Instant::now() + Duration::from_secs(d)), metadata: None, @@ -120,27 +137,17 @@ impl Client { licenses: Option, ) -> crate::Result { let path = format!("{}/identity", super::MDS_DEFAULT_URI); - let request = self.get(&path).query(&[("audience", target_audience)]); - let request = format.iter().fold(request, |builder, format| { - builder.query(&[("format", format)]) - }); - let request = licenses.iter().fold(request, |builder, licenses| { - builder.query(&[("licenses", licenses)]) - }); + let request = self.get(&path).query("audience", target_audience); + let request = format.iter().fold(request, |r, f| r.query("format", f)); + let request = licenses.iter().fold(request, |r, l| r.query("licenses", l)); let error_message = "failed to fetch id token"; - let response = request - .send() - .await - .map_err(|e| errors::from_http_error(e, error_message))?; - - let response = Self::check_response_status(response, error_message).await?; + let response = self.execute(request, error_message).await?; let token = response .text() - .await - .map_err(|e| CredentialsError::from_source(!e.is_decode(), e))?; + .map_err(|e| CredentialsError::from_source(false, e))?; Ok(token) } @@ -151,28 +158,21 @@ impl Client { let request = self.get(&path); let error_message = "failed to fetch email"; - let response = request - .send() - .await - .map_err(|e| errors::from_http_error(e, error_message))?; - - let response = Self::check_response_status(response, error_message).await?; + let response = self.execute(request, error_message).await?; let email = response .text() - .await - .map_err(|e| CredentialsError::from_source(!e.is_decode(), e))?; + .map_err(|e| CredentialsError::from_source(false, e))?; Ok(email) } - async fn check_response_status( - response: reqwest::Response, + fn check_response_status( + response: HttpResponse, error_message: &str, - ) -> crate::Result { - if !response.status().is_success() { - let err = errors::from_http_response(response, error_message).await; - Err(err) + ) -> crate::Result { + if !response.is_success() { + Err(crate::errors::from_http_response(&response, error_message)) } else { Ok(response) } @@ -182,16 +182,25 @@ impl Client { #[cfg(test)] mod tests { use super::*; + use crate::io::{SharedEnvProvider, SharedHttpClientProvider}; use crate::mds::MDS_DEFAULT_URI; use httptest::{Expectation, Server, matchers::*, responders::*}; use scoped_env::ScopedEnv; use serial_test::{parallel, serial}; + fn new_test_client(endpoint: Option) -> Client { + Client::new( + endpoint, + SharedEnvProvider::default(), + SharedHttpClientProvider::default(), + ) + } + #[tokio::test] #[parallel] async fn test_access_token_success() { let server = Server::run(); - let client = Client::new(Some(format!("http://{}", server.addr()))); + let client = new_test_client(Some(format!("http://{}", server.addr()))); let response = MDSTokenResponse { access_token: "test-token".to_string(), expires_in: Some(3600), @@ -226,7 +235,7 @@ mod tests { #[parallel] async fn test_access_token_failure() { let server = Server::run(); - let client = Client::new(Some(format!("http://{}", server.addr()))); + let client = new_test_client(Some(format!("http://{}", server.addr()))); server.expect( Expectation::matching(all_of![ @@ -245,7 +254,7 @@ mod tests { #[cfg(feature = "idtoken")] async fn test_id_token_success() { let server = Server::run(); - let client = Client::new(Some(format!("http://{}", server.addr()))); + let client = new_test_client(Some(format!("http://{}", server.addr()))); server.expect( Expectation::matching(all_of![ @@ -274,7 +283,7 @@ mod tests { #[cfg(feature = "idtoken")] async fn test_id_token_failure() { let server = Server::run(); - let client = Client::new(Some(format!("http://{}", server.addr()))); + let client = new_test_client(Some(format!("http://{}", server.addr()))); server.expect( Expectation::matching(all_of![ @@ -292,7 +301,7 @@ mod tests { #[parallel] async fn test_email_success() { let server = Server::run(); - let client = Client::new(Some(format!("http://{}", server.addr()))); + let client = new_test_client(Some(format!("http://{}", server.addr()))); server.expect( Expectation::matching(all_of![ @@ -310,7 +319,7 @@ mod tests { #[parallel] async fn test_email_failure() { let server = Server::run(); - let client = Client::new(Some(format!("http://{}", server.addr()))); + let client = new_test_client(Some(format!("http://{}", server.addr()))); server.expect( Expectation::matching(all_of![ @@ -327,14 +336,14 @@ mod tests { #[test] #[parallel] fn test_resolve_endpoint_default() { - let client = Client::new(None); + let client = new_test_client(None); assert_eq!(client.endpoint, "http://metadata.google.internal"); } #[test] #[parallel] fn test_resolve_endpoint_override() { - let client = Client::new(Some("http://custom.endpoint".to_string())); + let client = new_test_client(Some("http://custom.endpoint".to_string())); assert_eq!(client.endpoint, "http://custom.endpoint"); } @@ -342,7 +351,7 @@ mod tests { #[serial] fn test_resolve_endpoint_env_var() { let _s = ScopedEnv::set(super::super::GCE_METADATA_HOST_ENV_VAR, "env.var.host"); - let client = Client::new(None); + let client = new_test_client(None); assert_eq!(client.endpoint, "http://env.var.host"); } @@ -351,7 +360,7 @@ mod tests { fn test_resolve_endpoint_priority() { let _s = ScopedEnv::set(super::super::GCE_METADATA_HOST_ENV_VAR, "env.priority.host"); // Env var should take precedence over constructor argument - let client = Client::new(Some("http://custom.endpoint".to_string())); + let client = new_test_client(Some("http://custom.endpoint".to_string())); assert_eq!(client.endpoint, "http://env.priority.host"); } } diff --git a/src/auth/src/signer/iam.rs b/src/auth/src/signer/iam.rs index cb53205990..7765f94f65 100644 --- a/src/auth/src/signer/iam.rs +++ b/src/auth/src/signer/iam.rs @@ -13,6 +13,7 @@ // limitations under the License. use crate::credentials::{CacheableResource, Credentials}; +use crate::io::{HttpRequest, SharedHttpClientProvider}; use crate::signer::{Result, SigningError, dynamic::SigningProvider}; use google_cloud_gax::backoff_policy::BackoffPolicy; use google_cloud_gax::exponential_backoff::ExponentialBackoff; @@ -22,7 +23,6 @@ use google_cloud_gax::retry_throttler::{ AdaptiveThrottler, RetryThrottlerArg, SharedRetryThrottler, }; use http::{Extensions, HeaderMap}; -use reqwest::Client; use std::sync::Arc; use std::time::Duration; @@ -35,7 +35,7 @@ pub(crate) struct IamSigner { client_email: String, inner: Credentials, endpoint: String, - client: Client, + http: SharedHttpClientProvider, retry_policy: Arc, backoff_policy: Arc, } @@ -52,14 +52,19 @@ struct SignBlobResponse { } impl IamSigner { - pub(crate) fn new(client_email: String, inner: Credentials, endpoint: Option) -> Self { + pub(crate) fn new( + client_email: String, + inner: Credentials, + endpoint: Option, + http: SharedHttpClientProvider, + ) -> Self { let retry_policy = Aip194Strict.with_time_limit(Duration::from_secs(60)); let backoff_policy = ExponentialBackoff::default(); Self { client_email, inner, endpoint: endpoint.unwrap_or("https://iamcredentials.googleapis.com".to_string()), - client: Client::new(), + http, retry_policy: Arc::new(retry_policy), backoff_policy: Arc::new(backoff_policy), } @@ -85,7 +90,7 @@ impl SigningProvider for IamSigner { ); let response = sign_blob_call_with_retry( self.inner.clone(), - self.client.clone(), + SharedHttpClientProvider::clone(&self.http), url, body, self.retry_policy.clone(), @@ -93,15 +98,14 @@ impl SigningProvider for IamSigner { ) .await?; - if !response.status().is_success() { - let err_text = response.text().await.map_err(SigningError::transport)?; - return Err(SigningError::transport(format!("err status: {err_text:?}"))); + if !response.is_success() { + let body_text = String::from_utf8_lossy(&response.body); + return Err(SigningError::transport(format!( + "err status: {body_text:?}" + ))); } - let res = response - .json::() - .await - .map_err(SigningError::transport)?; + let res: SignBlobResponse = response.json().map_err(SigningError::transport)?; let signature = BASE64_STANDARD .decode(res.signed_blob) @@ -113,12 +117,12 @@ impl SigningProvider for IamSigner { async fn sign_blob_call_with_retry( credentials: Credentials, - client: Client, + http: SharedHttpClientProvider, url: String, body: SignBlobRequest, retry_policy: Arc, backoff_policy: Arc, -) -> Result { +) -> Result { let sleep = async |d| tokio::time::sleep(d).await; let retry_throttler: RetryThrottlerArg = AdaptiveThrottler::default().into(); @@ -131,7 +135,7 @@ async fn sign_blob_call_with_retry( .await .map_err(google_cloud_gax::error::Error::authentication)?; - sign_blob_call(&client, &url, source_headers, body.clone()).await + sign_blob_call(&http, &url, source_headers, body.clone()).await }, sleep, true, // signBlob is idempotent @@ -144,11 +148,11 @@ async fn sign_blob_call_with_retry( } async fn sign_blob_call( - client: &Client, + http: &SharedHttpClientProvider, url: &str, source_headers: CacheableResource, body: SignBlobRequest, -) -> google_cloud_gax::Result { +) -> google_cloud_gax::Result { let source_headers = match source_headers { CacheableResource::New { data, .. } => data, CacheableResource::NotModified => { @@ -156,26 +160,22 @@ async fn sign_blob_call( } }; - let response = client - .post(url) - .header("Content-Type", "application/json") - .headers(source_headers.clone()) - .json(&body) - .send() + let json_body = serde_json::to_vec(&body).map_err(google_cloud_gax::error::Error::io)?; + + let request = HttpRequest::post(url) + .json(json_body) + .headers_from_map(&source_headers); + + let response = http + .execute(request) .await .map_err(google_cloud_gax::error::Error::io)?; - let status = response.status(); - if !status.is_success() { - let err_headers = response.headers().clone(); - let err_payload = response - .bytes() - .await - .map_err(|e| google_cloud_gax::error::Error::transport(err_headers.clone(), e))?; + if !response.is_success() { return Err(google_cloud_gax::error::Error::http( - status.as_u16(), - err_headers, - err_payload, + response.status.as_u16(), + response.headers, + response.body.into(), )); } @@ -244,7 +244,12 @@ mod tests { }); let creds = Credentials::from(mock); - let signer = IamSigner::new("test@example.com".to_string(), creds, Some(endpoint)); + let signer = IamSigner::new( + "test@example.com".to_string(), + creds, + Some(endpoint), + SharedHttpClientProvider::default(), + ); let signature = signer.sign(b"test").await.unwrap(); assert_eq!(signature.as_ref(), b"signed_blob"); @@ -257,7 +262,12 @@ mod tests { let mock = MockCredentials::new(); let creds = Credentials::from(mock); - let signer = IamSigner::new("test@example.com".to_string(), creds, None); + let signer = IamSigner::new( + "test@example.com".to_string(), + creds, + None, + SharedHttpClientProvider::default(), + ); let client_email = signer.client_email().await.unwrap(); assert_eq!(client_email, "test@example.com"); @@ -285,7 +295,12 @@ mod tests { }); let creds = Credentials::from(mock); - let signer = IamSigner::new("test@example.com".to_string(), creds, Some(endpoint)); + let signer = IamSigner::new( + "test@example.com".to_string(), + creds, + Some(endpoint), + SharedHttpClientProvider::default(), + ); let err = signer.sign(b"test").await.unwrap_err(); assert!(err.is_transport()); @@ -297,19 +312,13 @@ mod tests { async fn test_iam_sign_retry() -> TestResult { let server = Server::run(); let signed_blob = BASE64_STANDARD.encode("signed_blob"); - let invalid_res = http::Response::builder() - .version(http::Version::HTTP_3) // unsupported version - .status(204) - .body(Vec::new()) - .unwrap(); server.expect( Expectation::matching(all_of![request::method_path( "POST", "/v1/projects/-/serviceAccounts/test@example.com:signBlob" ),]) - .times(3) + .times(2) .respond_with(cycle![ - invalid_res, // forces i/o error status_code(503).body("try-again"), json_encoded(json!({ "signedBlob": signed_blob, @@ -327,7 +336,12 @@ mod tests { }); let creds = Credentials::from(mock); - let mut signer = IamSigner::new("test@example.com".to_string(), creds, Some(endpoint)); + let mut signer = IamSigner::new( + "test@example.com".to_string(), + creds, + Some(endpoint), + SharedHttpClientProvider::default(), + ); signer.backoff_policy = Arc::new(test_backoff_policy()); let signature = signer.sign(b"test").await.unwrap(); diff --git a/src/auth/src/signer/mds.rs b/src/auth/src/signer/mds.rs index 430dd0a42a..3de5828f13 100644 --- a/src/auth/src/signer/mds.rs +++ b/src/auth/src/signer/mds.rs @@ -13,6 +13,7 @@ // limitations under the License. use crate::credentials::Credentials; +use crate::io::SharedHttpClientProvider; use crate::mds::client::Client as MDSClient; use crate::signer::{Result, SigningError, dynamic::SigningProvider}; use std::sync::OnceLock; @@ -25,15 +26,21 @@ pub(crate) struct MDSSigner { iam_endpoint_override: Option, client_email: OnceLock, inner: Credentials, + http: SharedHttpClientProvider, } impl MDSSigner { - pub(crate) fn new(client: MDSClient, inner: Credentials) -> Self { + pub(crate) fn new( + client: MDSClient, + inner: Credentials, + http: SharedHttpClientProvider, + ) -> Self { Self { client, client_email: OnceLock::new(), inner, iam_endpoint_override: None, + http, } } @@ -65,6 +72,7 @@ impl SigningProvider for MDSSigner { client_email, self.inner.clone(), self.iam_endpoint_override.clone(), + SharedHttpClientProvider::clone(&self.http), ); signer.sign(content).await @@ -82,6 +90,7 @@ mod tests { use super::*; use crate::credentials::{CacheableResource, Credentials, CredentialsProvider, EntityTag}; use crate::errors::CredentialsError; + use crate::io::{SharedEnvProvider, SharedHttpClientProvider}; use crate::mds::MDS_DEFAULT_URI; use base64::{Engine, prelude::BASE64_STANDARD}; use http::header::{HeaderName, HeaderValue}; @@ -113,8 +122,12 @@ mod tests { ); let mock = MockCredentials::new(); let creds = Credentials::from(mock); - let client = MDSClient::new(Some(format!("http://{}", server.addr()))); - let signer = MDSSigner::new(client, creds); + let client = MDSClient::new( + Some(format!("http://{}", server.addr())), + SharedEnvProvider::default(), + SharedHttpClientProvider::default(), + ); + let signer = MDSSigner::new(client, creds, SharedHttpClientProvider::default()); let client_email = signer.client_email().await?; assert_eq!(client_email, "test-client-email"); @@ -159,8 +172,12 @@ mod tests { let creds = Credentials::from(mock); let endpoint = server.url("").to_string().trim_end_matches('/').to_string(); - let client = MDSClient::new(Some(endpoint.clone())); - let mut signer = MDSSigner::new(client, creds); + let client = MDSClient::new( + Some(endpoint.clone()), + SharedEnvProvider::default(), + SharedHttpClientProvider::default(), + ); + let mut signer = MDSSigner::new(client, creds, SharedHttpClientProvider::default()); signer.iam_endpoint_override = Some(endpoint); let client_email = signer.client_email().await?;