Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
211 changes: 210 additions & 1 deletion src/auth/src/mds/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,17 @@

use crate::errors::{self, CredentialsError};
use crate::token::Token;
use google_cloud_gax::backoff_policy::BackoffPolicy;
use google_cloud_gax::backoff_policy::BackoffPolicyArg;
use google_cloud_gax::exponential_backoff::ExponentialBackoff;
use google_cloud_gax::retry_loop_internal::retry_loop;
use google_cloud_gax::retry_policy::RetryPolicyArg;
use google_cloud_gax::retry_policy::{Aip194Strict, RetryPolicy, RetryPolicyExt};
use google_cloud_gax::retry_throttler::{
AdaptiveThrottler, RetryThrottlerArg, SharedRetryThrottler,
};
use reqwest::{Client as ReqwestClient, RequestBuilder};
use std::sync::{Arc, Mutex};
use std::time::Duration;
use tokio::time::Instant;

Expand Down Expand Up @@ -105,6 +115,7 @@ impl Client {
pub(crate) fn universe_domain(&self) -> UniverseDomainRequest {
UniverseDomainRequest {
client: self.clone(),
retry_config: RetryConfig::default(),
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what about creating the policies once, and storing them on the client?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we want to have different policies per call, because the MDS Client is reused on the MDS Credentials provider. For calls to fetch tokens we don't want retries, because the TokenCache already handles retries. We only want retries here for universe_domain calls

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What I want is to allocate the policies once.

If we only make one attempt to get the universe_domain per credentials, then it's fine to do it on that call.

}
}

Expand All @@ -123,6 +134,55 @@ impl Client {
Ok(response)
}

async fn send_with_retry(
&self,
request: reqwest::RequestBuilder,
error_message: &'static str,
retry_config: RetryConfig,
) -> crate::Result<reqwest::Response> {
let sleep = async |d| tokio::time::sleep(d).await;

if !retry_config.has_retry_config() {
return self.send(request, error_message).await;
}
Comment on lines +145 to +147
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: This doesn't seem like a special case. I think we should always use a retry loop, and the default settings are policies that don't retry. That is less complexity IMO.


let (retry_policy, backoff_policy, retry_throttler) = retry_config.build();

retry_loop(
async move |_| {
let req = request
.try_clone()
.expect("client libraries only create builders where `try_clone()` succeeds");
let response = req
.send()
.await
.map_err(google_cloud_gax::error::Error::io)?;

let status = response.status();
if !status.is_success() {
let err_headers = response.headers().clone();
let err_payload = response.bytes().await.map_err(|e| {
google_cloud_gax::error::Error::transport(err_headers.clone(), e)
})?;
return Err(google_cloud_gax::error::Error::http(
status.as_u16(),
err_headers,
err_payload,
));
}
Comment on lines +153 to +172
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why are we duping the send(...) code? can we reuse it?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

send(...) maps errors to Credentials Errors and is simpler (we don't need to clone requests). In the retry loop we need to map to gax errors and clone requests. I tried to share code between the two, but the result did not look any nicer. We could make send(...) return gax errors directly and map them to credentials errors later, but we could lose some information, as gax errors are more specific.


Ok(response)
},
sleep,
true, // GET requests are idempotent
retry_throttler,
retry_policy,
backoff_policy,
)
.await
.map_err(|e| errors::CredentialsError::new(false, error_message, e))
}

async fn check_response_status(
response: reqwest::Response,
error_message: &str,
Expand All @@ -135,6 +195,59 @@ impl Client {
}
}
}
/// Optional retry policies applied to a single MDS call.
///
/// All fields default to `None` (no retries). The MDS client is shared
/// across request types, and only some calls opt in to retrying by
/// supplying policies here.
#[derive(Clone, Default)]
struct RetryConfig {
    // Decides whether a failed attempt should be retried.
    retry_policy: Option<RetryPolicyArg>,
    // Controls how long to wait between attempts.
    backoff_policy: Option<BackoffPolicyArg>,
    // Throttles retries when the service appears overloaded.
    retry_throttler: Option<RetryThrottlerArg>,
}

impl RetryConfig {
    /// Returns a copy of this configuration with the retry policy set.
    fn with_retry_policy(self, retry_policy: RetryPolicyArg) -> Self {
        Self {
            retry_policy: Some(retry_policy),
            ..self
        }
    }

    /// Returns a copy of this configuration with the backoff policy set.
    fn with_backoff_policy(self, backoff_policy: BackoffPolicyArg) -> Self {
        Self {
            backoff_policy: Some(backoff_policy),
            ..self
        }
    }

    /// Returns a copy of this configuration with the retry throttler set.
    fn with_retry_throttler(self, retry_throttler: RetryThrottlerArg) -> Self {
        Self {
            retry_throttler: Some(retry_throttler),
            ..self
        }
    }

    /// True when any retry setting was explicitly configured.
    fn has_retry_config(&self) -> bool {
        !(self.retry_policy.is_none()
            && self.backoff_policy.is_none()
            && self.retry_throttler.is_none())
    }

    /// Resolves the configured policies, filling in defaults where none was
    /// set: exponential backoff, an adaptive throttler, and an AIP-194
    /// strict retry policy capped at 60 seconds.
    fn build(
        self,
    ) -> (
        Arc<dyn RetryPolicy>,
        Arc<dyn BackoffPolicy>,
        SharedRetryThrottler,
    ) {
        let backoff_policy: Arc<dyn BackoffPolicy> = if let Some(p) = self.backoff_policy {
            p.into()
        } else {
            Arc::new(ExponentialBackoff::default())
        };
        let retry_throttler: SharedRetryThrottler = if let Some(t) = self.retry_throttler {
            t.into()
        } else {
            Arc::new(Mutex::new(AdaptiveThrottler::default()))
        };
        let retry_policy: Arc<dyn RetryPolicy> = match self.retry_policy {
            Some(p) => p.into(),
            None => {
                // Same default as the original: AIP-194 strict with a 60s
                // overall time limit, routed through `RetryPolicyArg`.
                let default: RetryPolicyArg =
                    Aip194Strict.with_time_limit(Duration::from_secs(60)).into();
                default.into()
            }
        };

        (retry_policy, backoff_policy, retry_throttler)
    }
}

#[derive(Clone)]
pub(crate) struct AccessTokenRequest {
Expand Down Expand Up @@ -242,16 +355,38 @@ impl EmailRequest {
/// Builder for fetching the universe domain from the metadata server.
#[allow(dead_code)]
pub(crate) struct UniverseDomainRequest {
    // The MDS client used to issue the request.
    client: Client,
    // Per-call retry settings; defaults to no retries.
    retry_config: RetryConfig,
}

impl UniverseDomainRequest {
#[allow(dead_code)]
pub(crate) fn with_retry_policy(mut self, retry_policy: RetryPolicyArg) -> Self {
self.retry_config = self.retry_config.with_retry_policy(retry_policy);
self
}

#[allow(dead_code)]
pub(crate) fn with_backoff_policy(mut self, backoff_policy: BackoffPolicyArg) -> Self {
self.retry_config = self.retry_config.with_backoff_policy(backoff_policy);
self
}

#[allow(dead_code)]
pub(crate) fn with_retry_throttler(mut self, retry_throttler: RetryThrottlerArg) -> Self {
self.retry_config = self.retry_config.with_retry_throttler(retry_throttler);
self
}

#[allow(dead_code)]
pub(crate) async fn send(self) -> crate::Result<String> {
let path = super::MDS_UNIVERSE_DOMAIN_URI;
let request = self.client.get(path);
let error_message = "failed to fetch universe domain";

let response = self.client.send(request, error_message).await?;
let response = self
.client
.send_with_retry(request, error_message, self.retry_config)
.await?;

let universe_domain = response
.text()
Expand All @@ -266,6 +401,8 @@ impl UniverseDomainRequest {
mod tests {
use super::*;
use crate::mds::{MDS_DEFAULT_URI, MDS_UNIVERSE_DOMAIN_URI};
use google_cloud_gax::exponential_backoff::ExponentialBackoffBuilder;
use google_cloud_gax::retry_policy::AlwaysRetry;
use httptest::{Expectation, Server, matchers::*, responders::*};
use scoped_env::ScopedEnv;
use serial_test::{parallel, serial};
Expand Down Expand Up @@ -479,4 +616,76 @@ mod tests {
let client = Client::new(Some("http://custom.endpoint".to_string()));
assert_eq!(client.endpoint, "http://env.priority.host");
}

#[tokio::test]
#[parallel]
async fn test_universe_domain_retry_success() {
    let server = Server::run();
    let client = Client::new(Some(format!("http://{}", server.addr())));

    // The first attempt returns a 500; the single retry then succeeds.
    server.expect(
        Expectation::matching(all_of![
            request::method("GET"),
            request::path(MDS_UNIVERSE_DOMAIN_URI),
        ])
        .times(2)
        .respond_with(cycle(vec![
            Box::new(status_code(500)) as Box<dyn Responder>,
            Box::new(status_code(200).body("my-universe-domain.com")),
        ])),
    );

    // Tiny fixed backoff keeps the test fast.
    let backoff = ExponentialBackoffBuilder::new()
        .with_initial_delay(Duration::from_millis(1))
        .with_maximum_delay(Duration::from_millis(1))
        .build()
        .unwrap();

    let domain = client
        .universe_domain()
        .with_retry_policy(AlwaysRetry.with_attempt_limit(2).into())
        .with_backoff_policy(backoff.into())
        .send()
        .await
        .unwrap();

    assert_eq!(domain, "my-universe-domain.com");
}

#[tokio::test]
#[parallel]
async fn test_universe_domain_retry_failure() {
    let server = Server::run();
    let client = Client::new(Some(format!("http://{}", server.addr())));

    // Every attempt returns a 500, so the retry budget is exhausted.
    server.expect(
        Expectation::matching(all_of![
            request::method("GET"),
            request::path(MDS_UNIVERSE_DOMAIN_URI),
        ])
        .times(2)
        .respond_with(status_code(500)),
    );

    // Tiny fixed backoff keeps the test fast.
    let backoff = ExponentialBackoffBuilder::new()
        .with_initial_delay(Duration::from_millis(1))
        .with_maximum_delay(Duration::from_millis(1))
        .build()
        .unwrap();

    let err = client
        .universe_domain()
        .with_retry_policy(AlwaysRetry.with_attempt_limit(2).into())
        .with_backoff_policy(backoff.into())
        .send()
        .await
        .unwrap_err();

    assert!(err.to_string().contains("failed to fetch universe domain"));
}
}
Loading