diff --git a/src/auth/src/credentials/mds.rs b/src/auth/src/credentials/mds.rs index 098fa51f4a..ae2da8e06d 100644 --- a/src/auth/src/credentials/mds.rs +++ b/src/auth/src/credentials/mds.rs @@ -88,7 +88,7 @@ use google_cloud_gax::retry_policy::RetryPolicyArg; use google_cloud_gax::retry_throttler::RetryThrottlerArg; use http::{Extensions, HeaderMap}; use std::default::Default; -use std::sync::Arc; +use std::sync::{Arc, OnceLock}; // TODO(#2235) - Improve this message by talking about retries when really running with MDS const MDS_NOT_FOUND_ERROR: &str = concat!( @@ -105,7 +105,10 @@ where T: CachedTokenProvider, { quota_project_id: Option, + universe_domain_override: Option, + universe_domain: OnceLock>, token_provider: T, + mds_client: MDSClient, } /// Creates [Credentials] instances backed by the [Metadata Service]. @@ -123,6 +126,7 @@ where pub struct Builder { endpoint: Option, quota_project_id: Option, + universe_domain: Option, scopes: Option>, created_by_adc: bool, retry_builder: RetryTokenProviderBuilder, @@ -135,6 +139,7 @@ impl Default for Builder { Self { endpoint: None, quota_project_id: None, + universe_domain: None, scopes: None, created_by_adc: false, retry_builder: RetryTokenProviderBuilder::default(), @@ -177,6 +182,15 @@ impl Builder { self } + /// Sets the Google Cloud universe domain for these credentials. + /// + /// Any value provided here overrides a `universe_domain` value from the input service account JSON. + #[allow(dead_code)] + pub(crate) fn with_universe_domain>(mut self, universe_domain: S) -> Self { + self.universe_domain = Some(universe_domain.into()); + self + } + /// Sets the [scopes] for this credentials. /// /// Metadata server issues tokens based on the requested scopes. @@ -319,7 +333,10 @@ impl Builder { let mds_client = MDSClient::new(self.endpoint.clone()); let mdsc = MDSCredentials { quota_project_id: self.quota_project_id.clone(), + universe_domain_override: self.universe_domain.clone(), + universe_domain: OnceLock::new(), token_provider: TokenCache::new(self.build_token_provider()), + mds_client: mds_client.clone(), }; if !is_access_boundary_enabled { return Ok(CredentialsWithAccessBoundary::new_no_op(mdsc)); @@ -375,6 +392,33 @@ where .maybe_quota_project_id(self.quota_project_id.as_deref()) .build() } + + async fn universe_domain(&self) -> Option { + if let Some(ud) = &self.universe_domain_override { + return Some(ud.clone()); + } + if let Some(ud) = self.universe_domain.get() { + return ud.clone(); + } + + // No overrides and no cache. Try to fetch from MDS. + let response = self.mds_client.universe_domain().send().await; + match response { + Ok(universe_domain) => { + let _ = self.universe_domain.set(Some(universe_domain.clone())); + Some(universe_domain) + } + Err(e) => { + if !e.is_transient() { + // Only cache None if the error is permanent (e.g., 404 on GDU) + let _ = self.universe_domain.set(None); + } + // Return None but do not cache it if it's transient, + // allowing subsequent calls to retry or try again. + None + } + } + } } #[async_trait::async_trait] @@ -469,7 +513,6 @@ impl TokenProvider for MDSAccessTokenProvider { #[cfg(test)] mod tests { use super::*; - use crate::constants::DEFAULT_UNIVERSE_DOMAIN; use crate::credentials::QUOTA_PROJECT_KEY; use crate::credentials::tests::{ find_source_error, get_headers_from_cache, get_mock_auth_retry_policy, @@ -479,8 +522,11 @@ mod tests { use crate::errors; use crate::errors::CredentialsError; use crate::mds::client::MDSTokenResponse; - use crate::mds::{GCE_METADATA_HOST_ENV_VAR, MDS_DEFAULT_URI, METADATA_ROOT}; + use crate::mds::{ + GCE_METADATA_HOST_ENV_VAR, MDS_DEFAULT_URI, MDS_UNIVERSE_DOMAIN_URI, METADATA_ROOT, + }; use crate::token::tests::MockTokenProvider; + use crate::token_cache::TokenCache; use base64::{Engine, prelude::BASE64_STANDARD}; use http::HeaderValue; use http::header::AUTHORIZATION; @@ -611,6 +657,9 @@ mod tests { let mdsc = MDSCredentials { quota_project_id: None, token_provider: TokenCache::new(mock), + universe_domain_override: None, + universe_domain: OnceLock::new(), + mds_client: MDSClient::new(None), }; let mut extensions = Extensions::new(); @@ -672,6 +721,9 @@ mod tests { let mdsc = MDSCredentials { quota_project_id: None, token_provider: TokenCache::new(mock), + universe_domain_override: None, + universe_domain: OnceLock::new(), + mds_client: MDSClient::new(None), }; let result = mdsc.headers(Extensions::new()).await; assert!(result.is_err(), "{result:?}"); @@ -846,7 +898,7 @@ mod tests { #[tokio::test(flavor = "multi_thread", worker_threads = 2)] #[parallel] async fn token_caching() -> TestResult { - let mut server = Server::run(); + let server = Server::run(); let scopes = vec!["scope1".to_string()]; let response = MDSTokenResponse { access_token: "test-access-token".to_string(), @@ -878,9 +930,6 @@ mod tests { "test-access-token" ); - // validate that the inner token provider is called only once - server.verify_and_clear(); - Ok(()) } @@ -1077,9 +1126,132 @@ mod tests { #[tokio::test] #[parallel] - async fn get_default_universe_domain_success() -> TestResult { - let universe_domain_response = Builder::default().build()?.universe_domain().await.unwrap(); - assert_eq!(universe_domain_response, DEFAULT_UNIVERSE_DOMAIN); + async fn get_default_universe_domain() -> TestResult { + let server = Server::run(); + server.expect( + Expectation::matching(all_of![request::path(MDS_UNIVERSE_DOMAIN_URI),]) + .respond_with(status_code(404)), + ); + + let mut mock = MockTokenProvider::new(); + mock.expect_token() + .returning(|| Err(crate::errors::non_retryable_from_str("fail"))); + + let creds = MDSCredentials { + quota_project_id: None, + universe_domain_override: None, + universe_domain: OnceLock::new(), + token_provider: TokenCache::new(mock), + mds_client: crate::mds::client::Client::new(Some(format!("http://{}", server.addr()))), + }; + + let universe_domain = creds.universe_domain().await; + assert!(universe_domain.is_none()); + Ok(()) + } + + #[tokio::test] + #[parallel] + async fn get_universe_domain_override() -> TestResult { + let creds = Builder::default() + .with_universe_domain("my-universe-domain.com") + .without_access_boundary() + .build()?; + let universe_domain = creds.universe_domain().await; + assert_eq!(universe_domain.as_deref(), Some("my-universe-domain.com")); + Ok(()) + } + + #[tokio::test] + #[parallel] + async fn get_universe_domain_from_mds() -> TestResult { + let server = Server::run(); + server.expect( + Expectation::matching(all_of![request::path(MDS_UNIVERSE_DOMAIN_URI),]) + .respond_with(status_code(200).body("my-universe-domain.com")), + ); + + let mut mock = MockTokenProvider::new(); + mock.expect_token() + .returning(|| Err(crate::errors::non_retryable_from_str("fail"))); + + let creds = MDSCredentials { + quota_project_id: None, + universe_domain_override: None, + universe_domain: OnceLock::new(), + token_provider: TokenCache::new(mock), + mds_client: crate::mds::client::Client::new(Some(format!("http://{}", server.addr()))), + }; + let universe_domain = creds.universe_domain().await; + assert_eq!(universe_domain.as_deref(), Some("my-universe-domain.com")); + Ok(()) + } + + #[tokio::test] + #[parallel] + async fn get_universe_domain_caching() -> TestResult { + let server = Server::run(); + server.expect( + Expectation::matching(all_of![request::path(MDS_UNIVERSE_DOMAIN_URI),]) + .times(2) + .respond_with(cycle![ + status_code(503).body("transient error"), + status_code(200).body("my-universe-domain.com"), + ]), + ); + + let mut mock = MockTokenProvider::new(); + mock.expect_token() + .returning(|| Err(crate::errors::non_retryable_from_str("fail"))); + + let creds = MDSCredentials { + quota_project_id: None, + universe_domain_override: None, + universe_domain: OnceLock::new(), + token_provider: TokenCache::new(mock), + mds_client: crate::mds::client::Client::new(Some(format!("http://{}", server.addr()))), + }; + + let universe_domain = creds.universe_domain().await; + assert_eq!(universe_domain, None); + + let universe_domain = creds.universe_domain().await; + assert_eq!(universe_domain.as_deref(), Some("my-universe-domain.com")); + + let universe_domain = creds.universe_domain().await; + assert_eq!(universe_domain.as_deref(), Some("my-universe-domain.com")); + + Ok(()) + } + + #[tokio::test] + #[parallel] + async fn get_universe_domain_caching_permanent_error() -> TestResult { + let server = Server::run(); + server.expect( + Expectation::matching(all_of![request::path(MDS_UNIVERSE_DOMAIN_URI),]) + .times(1) + .respond_with(status_code(404).body("permanent error")), + ); + + let mut mock = MockTokenProvider::new(); + mock.expect_token() + .returning(|| Err(crate::errors::non_retryable_from_str("fail"))); + + let creds = MDSCredentials { + quota_project_id: None, + universe_domain_override: None, + universe_domain: OnceLock::new(), + token_provider: TokenCache::new(mock), + mds_client: crate::mds::client::Client::new(Some(format!("http://{}", server.addr()))), + }; + + let universe_domain = creds.universe_domain().await; + assert_eq!(universe_domain, None); + + let universe_domain = creds.universe_domain().await; + assert_eq!(universe_domain, None); + Ok(()) } @@ -1134,6 +1306,7 @@ mod tests { #[cfg(google_cloud_unstable_trusted_boundaries)] async fn e2e_access_boundary() -> TestResult { use crate::credentials::tests::get_access_boundary_from_headers; + use crate::mds::MDS_UNIVERSE_DOMAIN_URI; let server = Server::run(); server.expect( @@ -1148,6 +1321,10 @@ mod tests { Expectation::matching(all_of![request::path(format!("{MDS_DEFAULT_URI}/email")),]) .respond_with(status_code(200).body("test-client-email")), ); + server.expect( + Expectation::matching(all_of![request::path(MDS_UNIVERSE_DOMAIN_URI),]) + .respond_with(status_code(404)), + ); server.expect( Expectation::matching(all_of![ request::method_path(