diff --git a/src/auth/src/signer/iam.rs b/src/auth/src/signer/iam.rs index cb53205990..b04f1fbd30 100644 --- a/src/auth/src/signer/iam.rs +++ b/src/auth/src/signer/iam.rs @@ -34,7 +34,7 @@ use std::time::Duration; pub(crate) struct IamSigner { client_email: String, inner: Credentials, - endpoint: String, + iam_endpoint_override: Option, client: Client, retry_policy: Arc, backoff_policy: Arc, @@ -52,18 +52,36 @@ struct SignBlobResponse { } impl IamSigner { - pub(crate) fn new(client_email: String, inner: Credentials, endpoint: Option) -> Self { + pub(crate) fn new( + client_email: String, + inner: Credentials, + iam_endpoint_override: Option, + ) -> Self { let retry_policy = Aip194Strict.with_time_limit(Duration::from_secs(60)); let backoff_policy = ExponentialBackoff::default(); Self { client_email, inner, - endpoint: endpoint.unwrap_or("https://iamcredentials.googleapis.com".to_string()), + iam_endpoint_override, client: Client::new(), retry_policy: Arc::new(retry_policy), backoff_policy: Arc::new(backoff_policy), } } + + async fn sign_blob_url(&self) -> String { + let endpoint = match self.iam_endpoint_override.as_ref() { + Some(endpoint) => endpoint.clone(), + None => { + let universe_domain = crate::universe_domain::resolve(&self.inner).await; + format!("https://iamcredentials.{universe_domain}") + } + }; + format!( + "{}/v1/projects/-/serviceAccounts/{}:signBlob", + endpoint, self.client_email + ) + } } #[async_trait::async_trait] @@ -78,11 +96,7 @@ impl SigningProvider for IamSigner { let payload = BASE64_STANDARD.encode(content); let body = SignBlobRequest { payload }; - let client_email = self.client_email.clone(); - let url = format!( - "{}/v1/projects/-/serviceAccounts/{client_email}:signBlob", - self.endpoint - ); + let url = self.sign_blob_url().await; let response = sign_blob_call_with_retry( self.inner.clone(), self.client.clone(), @@ -195,6 +209,7 @@ mod tests { use httptest::responders::{json_encoded, status_code}; use httptest::{Expectation, Server}; use serde_json::json; + use test_case::test_case; use tokio::time::Duration; type TestResult = anyhow::Result<()>; @@ -336,6 +351,56 @@ mod tests { Ok(()) } + #[test_case(None ; "no custom universe domain")] + #[test_case(Some("my-custom-universe.com".to_string()) ; "with custom universe domain")] + #[tokio::test] + async fn test_sign_blob_url_with_override(universe_domain: Option) -> TestResult { + let mut mock = MockCredentials::new(); + mock.expect_universe_domain() + .returning(move || universe_domain.clone()); + let creds = Credentials::from(mock); + let signer = IamSigner::new( + "test@example.com".to_string(), + creds, + Some("http://example.com".to_string()), + ); + let url = signer.sign_blob_url().await; + assert_eq!( + url, + "http://example.com/v1/projects/-/serviceAccounts/test@example.com:signBlob" + ); + Ok(()) + } + + #[tokio::test] + async fn test_sign_blob_url_default_universe() -> TestResult { + let mut mock = MockCredentials::new(); + mock.expect_universe_domain().returning(|| None); + let creds = Credentials::from(mock); + let signer = IamSigner::new("test@example.com".to_string(), creds, None); + let url = signer.sign_blob_url().await; + assert_eq!( + url, + "https://iamcredentials.googleapis.com/v1/projects/-/serviceAccounts/test@example.com:signBlob" + ); + Ok(()) + } + + #[tokio::test] + async fn test_sign_blob_url_custom_universe() -> TestResult { + let mut mock = MockCredentials::new(); + mock.expect_universe_domain() + .returning(|| Some("my-custom-universe.com".to_string())); + let creds = Credentials::from(mock); + let signer = IamSigner::new("test@example.com".to_string(), creds, None); + let url = signer.sign_blob_url().await; + assert_eq!( + url, + "https://iamcredentials.my-custom-universe.com/v1/projects/-/serviceAccounts/test@example.com:signBlob" + ); + Ok(()) + } + fn test_backoff_policy() -> ExponentialBackoff { use google_cloud_gax::exponential_backoff::ExponentialBackoffBuilder; ExponentialBackoffBuilder::new() diff --git a/src/auth/src/signer/mds.rs b/src/auth/src/signer/mds.rs index f9df86fef4..ce278f9dac 100644 --- a/src/auth/src/signer/mds.rs +++ b/src/auth/src/signer/mds.rs @@ -94,6 +94,7 @@ mod tests { use httptest::responders::{json_encoded, status_code}; use httptest::{Expectation, Server}; use serde_json::json; + use serial_test::serial; type TestResult = anyhow::Result<()>; @@ -131,6 +132,7 @@ mod tests { } #[tokio::test] + #[serial] async fn test_sign() -> TestResult { let server = Server::run(); server.expect(