diff --git a/src/auth/src/access_boundary.rs b/src/auth/src/access_boundary.rs index 60ae1dae39..5471f72a97 100644 --- a/src/auth/src/access_boundary.rs +++ b/src/auth/src/access_boundary.rs @@ -17,7 +17,6 @@ use crate::credentials::EntityTag; use crate::credentials::{ AccessToken, AccessTokenCredentialsProvider, CacheableResource, CredentialsProvider, dynamic, }; -use crate::errors::CredentialsError; use crate::mds::client::Client as MDSClient; use crate::{Result, errors}; use google_cloud_gax::Result as GaxResult; @@ -30,7 +29,6 @@ use google_cloud_gax::retry_throttler::{AdaptiveThrottler, RetryThrottlerArg}; use http::{Extensions, HeaderMap, HeaderValue}; use reqwest::Client; use std::clone::Clone; -use std::error::Error; use std::fmt::Debug; use std::sync::{Arc, OnceLock}; use tokio::sync::{Mutex, watch}; @@ -493,13 +491,10 @@ where T: dynamic::AccessTokenCredentialsProvider + Send + Sync + 'static, { async fn fetch(self) -> Result> { - let resp = self.fetch_with_retry().await.map_err(|e| { - let is_transient = e - .source() - .and_then(|s| s.downcast_ref::()) - .is_some_and(|cred_error| cred_error.is_transient()); - CredentialsError::new(is_transient, "failed to fetch access boundary", e) - })?; + let resp = self + .fetch_with_retry() + .await + .map_err(|e| crate::errors::from_gax_error(e, "failed to fetch access boundary"))?; if !resp.encoded_locations.is_empty() { return Ok(Some(resp.encoded_locations)); @@ -651,6 +646,7 @@ pub(crate) mod tests { use super::*; use crate::credentials::tests::{get_access_boundary_from_headers, get_token_from_headers}; use crate::credentials::{AccessToken, EntityTag}; + use crate::errors::CredentialsError; use google_cloud_gax::exponential_backoff::ExponentialBackoffBuilder; use http::header::{AUTHORIZATION, HeaderValue}; use http::{Extensions, HeaderMap}; @@ -885,7 +881,7 @@ pub(crate) mod tests { let result = client.fetch().await; let err = result.unwrap_err(); - assert!(!err.is_transient(), "{err:?}"); + assert!(err.is_transient(), "{err:?}"); } #[tokio::test] diff --git a/src/auth/src/credentials.rs b/src/auth/src/credentials.rs index 022f002678..450b4f37a7 100644 --- a/src/auth/src/credentials.rs +++ b/src/auth/src/credentials.rs @@ -850,6 +850,7 @@ pub mod testing { pub(crate) mod tests { use super::*; use crate::constants::TRUST_BOUNDARY_HEADER; + use crate::errors::is_gax_error_retryable; use base64::Engine; use google_cloud_gax::backoff_policy::BackoffPolicy; use google_cloud_gax::retry_policy::RetryPolicy; @@ -871,14 +872,15 @@ pub(crate) mod tests { pub(crate) fn find_source_error<'a, T: Error + 'static>( error: &'a (dyn Error + 'static), ) -> Option<&'a T> { - let mut source = error.source(); + let mut last_err = None; + let mut source = Some(error); while let Some(err) = source { if let Some(target_err) = err.downcast_ref::() { - return Some(target_err); + last_err = Some(target_err); } source = err.source(); } - None + last_err } mock! { @@ -921,11 +923,8 @@ pub(crate) mod tests { if state.attempt_count >= attempts as u32 { return RetryResult::Exhausted(error); } - let is_transient = error - .source() - .and_then(|e| e.downcast_ref::()) - .is_some_and(|ce| ce.is_transient()); - if is_transient { + let is_retryable = is_gax_error_retryable(&error); + if is_retryable { RetryResult::Continue(error) } else { RetryResult::Permanent(error) diff --git a/src/auth/src/credentials/idtoken/mds.rs b/src/auth/src/credentials/idtoken/mds.rs index 005edcbf4d..3e87ccdb8e 100644 --- a/src/auth/src/credentials/idtoken/mds.rs +++ b/src/auth/src/credentials/idtoken/mds.rs @@ -377,9 +377,9 @@ mod tests { .build()?; let err = creds.id_token().await.unwrap_err(); - let source = find_source_error::(&err); + let source = find_source_error::(&err); assert!( - matches!(source, Some(e) if e.status() == Some(StatusCode::UNAUTHORIZED)), + matches!(source, Some(e) if e.http_status_code() == Some(StatusCode::UNAUTHORIZED.into())), "{err:?}" ); @@ -494,9 +494,9 @@ mod tests { .build()?; let err = creds.id_token().await.unwrap_err(); - let source = find_source_error::(&err); + let source = find_source_error::(&err); assert!( - matches!(source, Some(e) if e.status() == Some(StatusCode::SERVICE_UNAVAILABLE)), + matches!(source, Some(e) if e.http_status_code() == Some(StatusCode::SERVICE_UNAVAILABLE.into())), "{err:?}" ); Ok(()) diff --git a/src/auth/src/credentials/mds.rs b/src/auth/src/credentials/mds.rs index f43de90c8a..366d906f2e 100644 --- a/src/auth/src/credentials/mds.rs +++ b/src/auth/src/credentials/mds.rs @@ -716,11 +716,8 @@ mod tests { return Ok(()); }; - let original_err = find_source_error::(&err).unwrap(); - assert!( - original_err.to_string().contains("application-default"), - "display={err}, debug={err:?}" - ); + let fmt = format!("{err:?}"); + assert!(fmt.contains("application-default"), "{fmt}"); Ok(()) } @@ -1009,9 +1006,9 @@ mod tests { let err = mdsc.headers(Extensions::new()).await.unwrap_err(); let original_err = find_source_error::(&err).unwrap(); assert!(original_err.is_transient()); - let source = find_source_error::(&err); + let source = find_source_error::(&err); assert!( - matches!(source, Some(e) if e.status() == Some(StatusCode::SERVICE_UNAVAILABLE)), + matches!(source, Some(e) if e.http_status_code() == Some(StatusCode::SERVICE_UNAVAILABLE.into())), "{err:?}" ); @@ -1040,9 +1037,9 @@ mod tests { let err = mdsc.headers(Extensions::new()).await.unwrap_err(); let original_err = find_source_error::(&err).unwrap(); assert!(!original_err.is_transient()); - let source = find_source_error::(&err); + let source = find_source_error::(&err); assert!( - matches!(source, Some(e) if e.status() == Some(StatusCode::UNAUTHORIZED)), + matches!(source, Some(e) if e.http_status_code() == Some(StatusCode::UNAUTHORIZED.into())), "{err:?}" ); diff --git a/src/auth/src/errors.rs b/src/auth/src/errors.rs index 42c2719184..637613101a 100644 --- a/src/auth/src/errors.rs +++ b/src/auth/src/errors.rs @@ -73,6 +73,11 @@ impl SubjectTokenProviderError for CredentialsError { } } +pub(crate) fn from_gax_error(err: google_cloud_gax::error::Error, msg: &str) -> CredentialsError { + let transient = is_gax_error_retryable(&err); + CredentialsError::new(transient, msg, err) +} + pub(crate) fn from_http_error(err: reqwest::Error, msg: &str) -> CredentialsError { let transient = self::is_retryable(&err); CredentialsError::new(transient, msg, err) @@ -99,6 +104,29 @@ pub(crate) fn non_retryable_from_str>(message: T) -> Credentials CredentialsError::from_msg(false, message) } +pub(crate) fn is_gax_error_retryable(err: &google_cloud_gax::error::Error) -> bool { + if let Some(code) = err.http_status_code() { + if let Ok(status_code) = StatusCode::from_u16(code) { + if is_retryable_code(status_code) { + return true; + } + } + } + + match err.source() { + Some(s) => { + if let Some(cred_err) = s.downcast_ref::() { + cred_err.is_transient() + } else if let Some(req_err) = s.downcast_ref::() { + is_retryable(req_err) + } else { + false + } + } + None => false, + } +} + fn is_retryable(err: &reqwest::Error) -> bool { // Connection errors are transient more often than not. A bad configuration // can point to a non-existing service, and that will never recover. diff --git a/src/auth/src/mds/client.rs b/src/auth/src/mds/client.rs index 6143db6e09..d369e61ae2 100644 --- a/src/auth/src/mds/client.rs +++ b/src/auth/src/mds/client.rs @@ -12,9 +12,17 @@ // See the License for the specific language governing permissions and // limitations under the License. -use crate::errors::{self, CredentialsError}; +use crate::errors::CredentialsError; use crate::token::Token; +use google_cloud_gax::backoff_policy::BackoffPolicyArg; +use google_cloud_gax::exponential_backoff::ExponentialBackoff; +use google_cloud_gax::retry_loop_internal::retry_loop; +use google_cloud_gax::retry_policy::{NeverRetry, RetryPolicyArg}; +use google_cloud_gax::retry_throttler::{ + AdaptiveThrottler, RetryThrottlerArg, SharedRetryThrottler, +}; use reqwest::{Client as ReqwestClient, RequestBuilder}; +use std::sync::{Arc, Mutex}; use std::time::Duration; use tokio::time::Instant; @@ -105,6 +113,7 @@ impl Client { pub(crate) fn universe_domain(&self) -> UniverseDomainRequest { UniverseDomainRequest { client: self.clone(), + retry_config: RetryConfig::default(), } } @@ -112,30 +121,90 @@ impl Client { &self, request: reqwest::RequestBuilder, error_message: &'static str, + retry_config: RetryConfig, ) -> crate::Result { - let response = request - .send() - .await - .map_err(|e| errors::from_http_error(e, error_message))?; - - let response = Self::check_response_status(response, error_message).await?; - - Ok(response) + let sleep = async |d| tokio::time::sleep(d).await; + + let error_message_str = error_message.to_string().clone(); + retry_loop( + async move |_| { + let req = request + .try_clone() + .expect("client libraries only create builders where `try_clone()` succeeds"); + let response = req + .send() + .await + .map_err(google_cloud_gax::error::Error::io)?; + + let response = + Self::check_response_status(response, error_message_str.clone()).await?; + + Ok(response) + }, + sleep, + true, // GET requests are idempotent + retry_config.retry_throttler, + retry_config.retry_policy.into(), + retry_config.backoff_policy.into(), + ) + .await + .map_err(|e| crate::errors::from_gax_error(e, error_message)) } async fn check_response_status( response: reqwest::Response, - error_message: &str, - ) -> crate::Result { - if !response.status().is_success() { - let err = errors::from_http_response(response, error_message).await; - Err(err) - } else { - Ok(response) + error_message: String, + ) -> Result { + let status = response.status(); + if !status.is_success() { + let err_headers = response.headers().clone(); + let err_payload = response + .bytes() + .await + .map_err(|e| google_cloud_gax::error::Error::transport(err_headers.clone(), e))?; + return Err(google_cloud_gax::error::Error::http( + status.as_u16(), + err_headers, + format!("{error_message} :{err_payload:?}").into(), + )); + } + Ok(response) + } +} +#[derive(Clone)] +struct RetryConfig { + retry_policy: RetryPolicyArg, + backoff_policy: BackoffPolicyArg, + retry_throttler: SharedRetryThrottler, +} + +impl Default for RetryConfig { + fn default() -> Self { + Self { + retry_policy: NeverRetry.into(), + backoff_policy: ExponentialBackoff::default().into(), + retry_throttler: Arc::new(Mutex::new(AdaptiveThrottler::default())), } } } +impl RetryConfig { + fn with_retry_policy(mut self, retry_policy: RetryPolicyArg) -> Self { + self.retry_policy = retry_policy; + self + } + + fn with_backoff_policy(mut self, backoff_policy: BackoffPolicyArg) -> Self { + self.backoff_policy = backoff_policy; + self + } + + fn with_retry_throttler(mut self, retry_throttler: RetryThrottlerArg) -> Self { + self.retry_throttler = retry_throttler.into(); + self + } +} + #[derive(Clone)] pub(crate) struct AccessTokenRequest { client: Client, @@ -160,7 +229,10 @@ impl AccessTokenRequest { // running on MDS environments and not useful if there is no MDS. We will mark the error // as retryable and let the retry policy determine whether to retry or not. Whenever we // define a default retry policy, we can skip retrying this case. - let response = self.client.send(request, error_message).await?; + let response = self + .client + .send(request, error_message, RetryConfig::default()) + .await?; let response = response.json::().await.map_err(|e| { // Decoding errors are not transient. Typically they indicate a badly @@ -205,7 +277,10 @@ impl IdTokenRequest { }); let error_message = "failed to fetch id token"; - let response = self.client.send(request, error_message).await?; + let response = self + .client + .send(request, error_message, RetryConfig::default()) + .await?; let token = response .text() @@ -227,7 +302,10 @@ impl EmailRequest { let request = self.client.get(&path); let error_message = "failed to fetch email"; - let response = self.client.send(request, error_message).await?; + let response = self + .client + .send(request, error_message, RetryConfig::default()) + .await?; let email = response .text() @@ -242,16 +320,38 @@ impl EmailRequest { #[allow(dead_code)] pub(crate) struct UniverseDomainRequest { client: Client, + retry_config: RetryConfig, } impl UniverseDomainRequest { + #[allow(dead_code)] + pub(crate) fn with_retry_policy(mut self, retry_policy: RetryPolicyArg) -> Self { + self.retry_config = self.retry_config.with_retry_policy(retry_policy); + self + } + + #[allow(dead_code)] + pub(crate) fn with_backoff_policy(mut self, backoff_policy: BackoffPolicyArg) -> Self { + self.retry_config = self.retry_config.with_backoff_policy(backoff_policy); + self + } + + #[allow(dead_code)] + pub(crate) fn with_retry_throttler(mut self, retry_throttler: RetryThrottlerArg) -> Self { + self.retry_config = self.retry_config.with_retry_throttler(retry_throttler); + self + } + #[allow(dead_code)] pub(crate) async fn send(self) -> crate::Result { let path = super::MDS_UNIVERSE_DOMAIN_URI; let request = self.client.get(path); let error_message = "failed to fetch universe domain"; - let response = self.client.send(request, error_message).await?; + let response = self + .client + .send(request, error_message, self.retry_config) + .await?; let universe_domain = response .text() @@ -266,6 +366,8 @@ impl UniverseDomainRequest { mod tests { use super::*; use crate::mds::{MDS_DEFAULT_URI, MDS_UNIVERSE_DOMAIN_URI}; + use google_cloud_gax::exponential_backoff::ExponentialBackoffBuilder; + use google_cloud_gax::retry_policy::{AlwaysRetry, RetryPolicyExt}; use httptest::{Expectation, Server, matchers::*, responders::*}; use scoped_env::ScopedEnv; use serial_test::{parallel, serial}; @@ -479,4 +581,76 @@ mod tests { let client = Client::new(Some("http://custom.endpoint".to_string())); assert_eq!(client.endpoint, "http://env.priority.host"); } + + #[tokio::test] + #[parallel] + async fn test_universe_domain_retry_success() { + let server = Server::run(); + let client = Client::new(Some(format!("http://{}", server.addr()))); + + // First request fails, second succeeds + let responses: Vec> = vec![ + Box::new(status_code(500)), + Box::new(status_code(200).body("my-universe-domain.com")), + ]; + server.expect( + Expectation::matching(all_of![ + request::method("GET"), + request::path(MDS_UNIVERSE_DOMAIN_URI), + ]) + .times(2) + .respond_with(cycle(responses)), + ); + + let retry_policy = AlwaysRetry.with_attempt_limit(2); + let backoff_policy = ExponentialBackoffBuilder::new() + .with_initial_delay(Duration::from_millis(1)) + .with_maximum_delay(Duration::from_millis(1)) + .build() + .unwrap(); + + let domain = client + .universe_domain() + .with_retry_policy(retry_policy.into()) + .with_backoff_policy(backoff_policy.into()) + .send() + .await + .unwrap(); + + assert_eq!(domain, "my-universe-domain.com"); + } + + #[tokio::test] + #[parallel] + async fn test_universe_domain_retry_failure() { + let server = Server::run(); + let client = Client::new(Some(format!("http://{}", server.addr()))); + + // All requests fail + server.expect( + Expectation::matching(all_of![ + request::method("GET"), + request::path(MDS_UNIVERSE_DOMAIN_URI), + ]) + .times(2) + .respond_with(status_code(500)), + ); + + let retry_policy = AlwaysRetry.with_attempt_limit(2); + let backoff_policy = ExponentialBackoffBuilder::new() + .with_initial_delay(Duration::from_millis(1)) + .with_maximum_delay(Duration::from_millis(1)) + .build() + .unwrap(); + + let err = client + .universe_domain() + .with_retry_policy(retry_policy.into()) + .with_backoff_policy(backoff_policy.into()) + .send() + .await + .unwrap_err(); + + assert!(err.to_string().contains("failed to fetch universe domain")); + } }