Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
99 changes: 67 additions & 32 deletions src/auth/src/access_boundary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ use crate::credentials::{
AccessToken, AccessTokenCredentialsProvider, CacheableResource, CredentialsProvider, dynamic,
};
use crate::errors::CredentialsError;
use crate::io::{HttpRequest, SharedHttpClientProvider};
use crate::mds::client::Client as MDSClient;
use crate::{Result, errors};
use google_cloud_gax::Result as GaxResult;
Expand All @@ -28,7 +29,6 @@ use google_cloud_gax::retry_loop_internal::retry_loop;
use google_cloud_gax::retry_policy::{Aip194Strict, RetryPolicy, RetryPolicyExt};
use google_cloud_gax::retry_throttler::{AdaptiveThrottler, RetryThrottlerArg};
use http::{Extensions, HeaderMap, HeaderValue};
use reqwest::Client;
use std::clone::Clone;
use std::error::Error;
use std::fmt::Debug;
Expand Down Expand Up @@ -235,11 +235,16 @@ impl<T> CredentialsWithAccessBoundary<T>
where
T: dynamic::AccessTokenCredentialsProvider + 'static,
{
pub(crate) fn new(credentials: T, access_boundary_url: Option<String>) -> Self {
pub(crate) fn new(
credentials: T,
access_boundary_url: Option<String>,
http: SharedHttpClientProvider,
) -> Self {
let credentials = Arc::new(credentials);
let provider = IAMAccessBoundaryProvider {
credentials: credentials.clone(),
url: access_boundary_url,
http,
};
let access_boundary = Arc::new(AccessBoundary::new(provider));
Self {
Expand All @@ -253,13 +258,15 @@ where
credentials: T,
mds_client: MDSClient,
iam_endpoint_override: Option<String>,
http: SharedHttpClientProvider,
) -> Self {
let credentials = Arc::new(credentials);
let provider = MDSAccessBoundaryProvider {
credentials: credentials.clone(),
mds_client,
iam_endpoint_override,
url: OnceLock::new(),
http,
};
let access_boundary = Arc::new(AccessBoundary::new(provider));
Self {
Expand Down Expand Up @@ -403,6 +410,7 @@ where
{
credentials: Arc<T>,
url: Option<String>,
http: SharedHttpClientProvider,
}

#[async_trait::async_trait]
Expand All @@ -413,7 +421,11 @@ where
async fn fetch_access_boundary(&self) -> Result<Option<String>> {
match self.url.as_ref() {
Some(url) => {
let client = AccessBoundaryClient::new(self.credentials.clone(), url.clone());
let client = AccessBoundaryClient::new(
self.credentials.clone(),
url.clone(),
SharedHttpClientProvider::clone(&self.http),
);
client.fetch().await
}
None => Ok(None), // No URL means no access boundary
Expand All @@ -431,6 +443,7 @@ where
mds_client: MDSClient,
iam_endpoint_override: Option<String>,
url: OnceLock<String>,
http: SharedHttpClientProvider,
}

#[async_trait::async_trait]
Expand All @@ -449,7 +462,11 @@ where
}

let url = self.url.get().unwrap().to_string();
let client = AccessBoundaryClient::new(self.credentials.clone(), url);
let client = AccessBoundaryClient::new(
self.credentials.clone(),
url,
SharedHttpClientProvider::clone(&self.http),
);
client.fetch().await
}
}
Expand All @@ -472,10 +489,11 @@ struct AccessBoundaryClient<T> {
url: String,
retry_policy: Arc<dyn RetryPolicy>,
backoff_policy: Arc<dyn BackoffPolicy>,
http: SharedHttpClientProvider,
}

impl<T> AccessBoundaryClient<T> {
fn new(credentials: Arc<T>, url: String) -> Self {
fn new(credentials: Arc<T>, url: String, http: SharedHttpClientProvider) -> Self {
let retry_policy = Aip194Strict.with_time_limit(Duration::from_secs(60));
let backoff_policy = ExponentialBackoff::default();

Expand All @@ -484,6 +502,7 @@ impl<T> AccessBoundaryClient<T> {
url,
retry_policy: Arc::new(retry_policy),
backoff_policy: Arc::new(backoff_policy),
http,
}
}
}
Expand All @@ -509,19 +528,19 @@ where
}

async fn fetch_with_retry(self) -> GaxResult<AllowedLocationsResponse> {
let client = Client::new();
let sleep = async |d| tokio::time::sleep(d).await;

let retry_throttler: RetryThrottlerArg = AdaptiveThrottler::default().into();
let creds = self.credentials;
let url = self.url;
let http = self.http;
let inner = async move |d| {
let headers = creds
.headers(Extensions::new())
.await
.map_err(GaxError::authentication)?;

let attempt = self::fetch_access_boundary_call(&client, &url, headers);
let attempt = self::fetch_access_boundary_call(&http, &url, headers);
match d {
Some(timeout) => match tokio::time::timeout(timeout, attempt).await {
Ok(r) => r,
Expand All @@ -544,7 +563,7 @@ where
}

async fn fetch_access_boundary_call(
client: &Client,
http: &SharedHttpClientProvider,
url: &str,
headers: CacheableResource<HeaderMap>,
) -> GaxResult<AllowedLocationsResponse> {
Expand All @@ -555,24 +574,19 @@ async fn fetch_access_boundary_call(
}
};

let resp = client
.get(url)
.headers(headers)
.send()
.await
.map_err(GaxError::io)?;

let status = resp.status();
if !status.is_success() {
let err_headers = resp.headers().clone();
let err_payload = resp
.bytes()
.await
.map_err(|e| GaxError::transport(err_headers.clone(), e))?;
return Err(GaxError::http(status.as_u16(), err_headers, err_payload));
let request = HttpRequest::get(url).headers_from_map(&headers);

let response = http.execute(request).await.map_err(GaxError::io)?;

if !response.is_success() {
return Err(GaxError::http(
response.status.as_u16(),
response.headers,
response.body.into(),
));
}

resp.json().await.map_err(GaxError::io)
response.json().map_err(GaxError::io)
}

async fn refresh_task<T>(provider: T, tx_header: watch::Sender<(Option<BoundaryValue>, EntityTag)>)
Expand Down Expand Up @@ -750,7 +764,11 @@ pub(crate) mod tests {
});
let url = server.url("/allowedLocations").to_string();

let creds = CredentialsWithAccessBoundary::new(mock, Some(url));
let creds = CredentialsWithAccessBoundary::new(
mock,
Some(url),
SharedHttpClientProvider::default(),
);

// wait for the background task to fetch the access boundary.
creds.wait_for_boundary().await;
Expand Down Expand Up @@ -804,9 +822,18 @@ pub(crate) mod tests {
})
});
let endpoint = server.url("").to_string().trim_end_matches('/').to_string();
let mds_client = MDSClient::new(Some(endpoint.clone()));
let mds_client = MDSClient::new(
Some(endpoint.clone()),
SharedEnvProvider::default(),
SharedHttpClientProvider::default(),
);

let creds = CredentialsWithAccessBoundary::new_for_mds(mock, mds_client, Some(endpoint));
let creds = CredentialsWithAccessBoundary::new_for_mds(
mock,
mds_client,
Some(endpoint),
SharedHttpClientProvider::default(),
);

// wait for the background task to fetch the access boundary.
creds.wait_for_boundary().await;
Expand Down Expand Up @@ -849,7 +876,8 @@ pub(crate) mod tests {
})
});
let url = server.url("/allowedLocations").to_string();
let client = AccessBoundaryClient::new(Arc::new(mock), url);
let client =
AccessBoundaryClient::new(Arc::new(mock), url, SharedHttpClientProvider::default());
let val = client.fetch().await?;
assert!(val.is_none(), "{val:?}");

Expand Down Expand Up @@ -879,7 +907,8 @@ pub(crate) mod tests {
});

let url = server.url("/allowedLocations").to_string();
let mut client = AccessBoundaryClient::new(Arc::new(mock), url);
let mut client =
AccessBoundaryClient::new(Arc::new(mock), url, SharedHttpClientProvider::default());
client.retry_policy = Arc::new(Aip194Strict.with_attempt_limit(3));
client.backoff_policy = Arc::new(test_backoff_policy());

Expand All @@ -899,7 +928,11 @@ pub(crate) mod tests {
))
});

let client = AccessBoundaryClient::new(Arc::new(mock), "http://localhost".to_string());
let client = AccessBoundaryClient::new(
Arc::new(mock),
"http://localhost".to_string(),
SharedHttpClientProvider::default(),
);
let err = client.fetch().await.unwrap_err();
assert!(!err.is_transient(), "{err:?}");
}
Expand All @@ -918,7 +951,8 @@ pub(crate) mod tests {
data: headers,
})
});
let creds = CredentialsWithAccessBoundary::new(mock, None);
let creds =
CredentialsWithAccessBoundary::new(mock, None, SharedHttpClientProvider::default());

let cached_headers = creds.headers(Extensions::new()).await?;
let token = get_token_from_headers(cached_headers.clone());
Expand Down Expand Up @@ -1240,7 +1274,8 @@ pub(crate) mod tests {
});

let url = server.url("/allowedLocations").to_string();
let mut client = AccessBoundaryClient::new(Arc::new(mock), url);
let mut client =
AccessBoundaryClient::new(Arc::new(mock), url, SharedHttpClientProvider::default());
client.backoff_policy = Arc::new(test_backoff_policy());
let val = client.fetch().await?;
assert_eq!(val.as_deref(), Some("0x123"));
Expand Down
Loading
Loading