diff --git a/Cargo.lock b/Cargo.lock index ec1f0475c9..5207f37093 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5911,10 +5911,16 @@ dependencies = [ "google-cloud-lro", "google-cloud-spanner", "google-cloud-test-utils", + "google-cloud-wkt", "prost-types", + "rand 0.10.0", "reqwest 0.13.2", "serde_json", + "spanner-grpc-mock", + "time", "tokio", + "tokio-stream", + "tonic", "tracing", ] diff --git a/deny.toml b/deny.toml index fa61a89904..67b8359d88 100644 --- a/deny.toml +++ b/deny.toml @@ -115,6 +115,7 @@ wrappers = [ # Use in tests is fine. "grpc-server", "integration-tests-o11y", + "integration-tests-spanner", "pubsub-grpc-mock", "spanner-grpc-mock", "storage-grpc-mock", diff --git a/src/spanner/grpc-mock/src/lib.rs b/src/spanner/grpc-mock/src/lib.rs index 72b19b07bd..10a3fb74c6 100644 --- a/src/spanner/grpc-mock/src/lib.rs +++ b/src/spanner/grpc-mock/src/lib.rs @@ -64,6 +64,7 @@ pub mod google { include!("generated/protos/google.rpc.rs"); } pub mod spanner { + #[allow(rustdoc::broken_intra_doc_links, rustdoc::bare_urls)] pub mod v1 { include!("generated/protos/google.spanner.v1.rs"); } diff --git a/src/spanner/src/batch_read_only_transaction.rs b/src/spanner/src/batch_read_only_transaction.rs index 4617ef9e2d..313f291d22 100644 --- a/src/spanner/src/batch_read_only_transaction.rs +++ b/src/spanner/src/batch_read_only_transaction.rs @@ -16,7 +16,7 @@ use crate::database_client::DatabaseClient; use crate::model::PartitionOptions; use crate::precommit::PrecommitTokenTracker; use crate::read_only_transaction::{ - MultiUseReadOnlyTransaction, MultiUseReadOnlyTransactionBuilder, + MultiUseReadOnlyTransaction, MultiUseReadOnlyTransactionBuilder, ReadContextTransactionSelector, }; use crate::result_set::{ResultSet, StreamOperation}; use crate::statement::Statement; @@ -44,7 +44,8 @@ pub struct BatchReadOnlyTransactionBuilder { impl BatchReadOnlyTransactionBuilder { pub(crate) fn new(client: DatabaseClient) -> Self { Self { - inner: MultiUseReadOnlyTransactionBuilder::new(client), + inner: MultiUseReadOnlyTransactionBuilder::new(client) + .with_explicit_begin_transaction(true), } } @@ -143,12 +144,13 @@ impl BatchReadOnlyTransaction { statement: T, options: PartitionOptions, ) -> crate::Result> { + let selector = self.inner.context.transaction_selector.selector().await?; let statement = statement.into(); let request = statement .clone() .into_partition_query_request() .set_session(self.inner.context.client.session.name.clone()) - .set_transaction(self.inner.context.transaction_selector.clone()) + .set_transaction(selector.clone()) .set_partition_options(options); let response = self @@ -165,7 +167,7 @@ impl BatchReadOnlyTransaction { .map(|p| Partition { inner: PartitionedOperation::Query { partition_token: p.partition_token, - transaction_selector: self.inner.context.transaction_selector.clone(), + transaction_selector: selector.clone(), session_name: self.inner.context.client.session.name.clone(), statement: statement.clone(), }, @@ -198,12 +200,13 @@ impl BatchReadOnlyTransaction { read: T, options: PartitionOptions, ) -> crate::Result> { + let selector = self.inner.context.transaction_selector.selector().await?; let read = read.into(); let request = read .clone() .into_partition_read_request() .set_session(self.inner.context.client.session.name.clone()) - .set_transaction(self.inner.context.transaction_selector.clone()) + .set_transaction(selector.clone()) .set_partition_options(options); let response = self @@ -220,7 +223,7 @@ impl BatchReadOnlyTransaction { .map(|p| Partition { inner: PartitionedOperation::Read { partition_token: p.partition_token, - transaction_selector: self.inner.context.transaction_selector.clone(), + transaction_selector: selector.clone(), session_name: self.inner.context.client.session.name.clone(), read_request: read.clone(), }, @@ -344,6 +347,10 @@ impl Partition { Ok(ResultSet::new( stream, + Some(ReadContextTransactionSelector::Fixed( + transaction_selector.clone(), + None, + )), PrecommitTokenTracker::new_noop(), client.clone(), StreamOperation::Query(request), @@ -373,6 +380,10 @@ impl Partition { Ok(ResultSet::new( stream, + Some(ReadContextTransactionSelector::Fixed( + transaction_selector.clone(), + None, + )), PrecommitTokenTracker::new_noop(), client.clone(), StreamOperation::Read(request), diff --git a/src/spanner/src/read_only_transaction.rs b/src/spanner/src/read_only_transaction.rs index 3f0df51a5c..029785c569 100644 --- a/src/spanner/src/read_only_transaction.rs +++ b/src/spanner/src/read_only_transaction.rs @@ -19,6 +19,9 @@ use crate::precommit::PrecommitTokenTracker; use crate::result_set::{ResultSet, StreamOperation}; use crate::statement::Statement; use crate::timestamp_bound::TimestampBound; +use crate::transaction_retry_policy::is_aborted; +use std::sync::{Arc, Mutex}; +use tokio::sync::Notify; /// A builder for [SingleUseReadOnlyTransaction]. /// @@ -91,7 +94,10 @@ impl SingleUseReadOnlyTransactionBuilder { SingleUseReadOnlyTransaction { context: ReadContext { client: self.client, - transaction_selector, + transaction_selector: ReadContextTransactionSelector::Fixed( + transaction_selector, + None, + ), precommit_token_tracker: PrecommitTokenTracker::new_noop(), transaction_tag: None, }, @@ -204,6 +210,7 @@ impl SingleUseReadOnlyTransaction { pub struct MultiUseReadOnlyTransactionBuilder { client: DatabaseClient, timestamp_bound: Option, + explicit_begin: bool, } impl MultiUseReadOnlyTransactionBuilder { @@ -211,9 +218,44 @@ impl MultiUseReadOnlyTransactionBuilder { Self { client, timestamp_bound: None, + explicit_begin: false, } } + /// Sets whether the transaction should be explicitly started using a `BeginTransaction` RPC. + /// + /// # Example + /// ``` + /// # use google_cloud_spanner::client::Spanner; + /// # use google_cloud_spanner::client::Statement; + /// # async fn set_explicit_begin(spanner: Spanner) -> Result<(), google_cloud_spanner::Error> { + /// let db_client = spanner.database_client("projects/p/instances/i/databases/d").build().await?; + /// let transaction = db_client.read_only_transaction().with_explicit_begin_transaction(true).build().await?; + /// let statement = Statement::builder("SELECT * FROM users").build(); + /// let result_set = transaction.execute_query(statement).await?; + /// # Ok(()) + /// # } + /// ``` + /// + /// By default, the Spanner client will inline the `BeginTransaction` call with the first query + /// in the transaction. This reduces the number of round-trips to Spanner that are needed for a + /// transaction. Setting this option to `true` can be beneficial for specific transaction shapes: + /// + /// 1. When the transaction executes multiple parallel queries at the start of the transaction. + /// Only one query can include a `BeginTransaction` option, and all other queries must wait for + /// the first query to return the first result before they can proceed to execute. A + /// `BeginTransaction` RPC will quickly return a transaction ID and allow all queries to start + /// execution in parallel once the transaction ID has been returned. + /// 2. When the first query in the transaction could fail. If the query fails, then it will also + /// not start a transaction and return a transaction ID. The transaction will then fall back to + /// executing a `BeginTransaction` RPC and retry the first query. + /// + /// Default is `false` (inline begin). + pub fn with_explicit_begin_transaction(mut self, explicit: bool) -> Self { + self.explicit_begin = explicit; + self + } + /// Sets the timestamp bound for the read-only transaction. /// /// # Example @@ -231,6 +273,20 @@ impl MultiUseReadOnlyTransactionBuilder { self } + async fn begin( + &self, + options: TransactionOptions, + ) -> crate::Result { + let response = execute_begin_transaction(&self.client, options).await?; + + let transaction_selector = crate::model::TransactionSelector::default().set_id(response.id); + + Ok(ReadContextTransactionSelector::Fixed( + transaction_selector, + response.read_timestamp, + )) + } + /// Builds the [MultiUseReadOnlyTransaction] and starts the transaction /// by calling the `BeginTransaction` RPC. /// @@ -245,30 +301,27 @@ impl MultiUseReadOnlyTransactionBuilder { /// ``` pub async fn build(self) -> crate::Result { let read_only = ReadOnly::default().set_return_read_timestamp(true); - let read_only = match self.timestamp_bound { - Some(b) => read_only.set_timestamp_bound(b.0), + let read_only = match self.timestamp_bound.as_ref() { + Some(b) => read_only.set_timestamp_bound(b.0.clone()), None => read_only.set_strong(true), }; - let request = crate::model::BeginTransactionRequest::default() - .set_session(self.client.session.name.clone()) - .set_options(TransactionOptions::default().set_read_only(read_only)); + let options = TransactionOptions::default().set_read_only(read_only); - // TODO(#4972): make request options configurable - let response = self - .client - .spanner - .begin_transaction(request, crate::RequestOptions::default()) - .await?; + let selector = if self.explicit_begin { + self.begin(options).await? + } else { + ReadContextTransactionSelector::Lazy(Arc::new(Mutex::new( + TransactionState::NotStarted(options), + ))) + }; - let transaction_selector = crate::model::TransactionSelector::default().set_id(response.id); Ok(MultiUseReadOnlyTransaction { context: ReadContext { client: self.client, - transaction_selector, + transaction_selector: selector, precommit_token_tracker: PrecommitTokenTracker::new_noop(), transaction_tag: None, }, - read_timestamp: response.read_timestamp, }) } } @@ -297,13 +350,12 @@ impl MultiUseReadOnlyTransactionBuilder { #[derive(Debug)] pub struct MultiUseReadOnlyTransaction { pub(crate) context: ReadContext, - pub(crate) read_timestamp: Option, } impl MultiUseReadOnlyTransaction { /// Returns the read timestamp chosen for the transaction. pub fn read_timestamp(&self) -> Option { - self.read_timestamp + self.context.transaction_selector.read_timestamp() } /// Executes a query using this transaction. @@ -370,10 +422,254 @@ impl MultiUseReadOnlyTransaction { } } +/// Executes an explicit `BeginTransaction` RPC on Spanner. +async fn execute_begin_transaction( + client: &crate::database_client::DatabaseClient, + options: crate::model::TransactionOptions, +) -> crate::Result { + let request = crate::model::BeginTransactionRequest::default() + .set_session(client.session.name.clone()) + .set_options(options); + + // TODO(#4972): make request options configurable + client + .spanner + .begin_transaction(request, crate::RequestOptions::default()) + .await +} + +#[derive(Clone, Debug)] +pub(crate) enum ReadContextTransactionSelector { + Fixed(crate::model::TransactionSelector, Option), + Lazy(Arc>), +} + +#[derive(Clone, Debug)] +pub(crate) enum TransactionState { + NotStarted(crate::model::TransactionOptions), + Starting(crate::model::TransactionOptions, Arc), + Started(crate::model::TransactionSelector, Option), + Failed(Arc), +} + +enum SelectorStatus { + Ready(crate::model::TransactionSelector), + Wait(std::sync::Arc), +} + +impl ReadContextTransactionSelector { + pub(crate) async fn selector(&self) -> crate::Result { + match self { + Self::Fixed(selector, _) => Ok(selector.clone()), + Self::Lazy(_) => loop { + match self.poll_selector_status()? { + SelectorStatus::Ready(selector) => return Ok(selector), + SelectorStatus::Wait(notify) => notify.notified().await, + } + }, + } + } + + /// Inspects the current lazy selector state returning whether it is ready, + /// failed, or needs to wait for the transaction to start. + fn poll_selector_status(&self) -> crate::Result { + let Self::Lazy(lazy) = self else { + unreachable!("poll_selector_status called on non-Lazy selector"); + }; + let mut guard = lazy.lock().expect("transaction state mutex poisoned"); + + // Fast path: Transaction is already started. + if let TransactionState::Started(selector, _) = &*guard { + return Ok(SelectorStatus::Ready(selector.clone())); + } + + // If the transaction has not started, extract options and proceed to transition. + let pending_options = if let TransactionState::NotStarted(options) = &*guard { + Some(options.clone()) + } else { + None + }; + if let Some(options) = pending_options { + let notify = Arc::new(Notify::new()); + *guard = TransactionState::Starting(options.clone(), Arc::clone(¬ify)); + return Ok(SelectorStatus::Ready( + crate::model::TransactionSelector::default().set_begin(options), + )); + } + + // Handle other states: yield error or wait. + match &*guard { + // Note: Failed will only be reached if the following happens: + // 1. The first query fails and the transaction falls back to an explicit BeginTransaction RPC. + // 2. The BeginTransaction RPC fails. This is the error that will be returned to all the waiting queries. + TransactionState::Failed(err) => { + let error = if let Some(status) = err.status() { + crate::Error::service(status.clone()) + } else { + crate::error::internal_error(format!("Transaction failed to start: {}", err)) + }; + Err(error) + } + // Transaction is starting. Wait until a transaction ID is returned. + TransactionState::Starting(_, notify) => Ok(SelectorStatus::Wait(Arc::clone(notify))), + TransactionState::Started(_, _) | TransactionState::NotStarted(_) => unreachable!(), + } + } + + /// Explicitly begins a transaction if the transaction selector is a `Lazy` + /// selector and the transaction has not yet been started. This is used by + /// the client to force the start of a transaction if the first statement + /// failed. + pub(crate) async fn begin_explicitly( + &self, + client: &crate::database_client::DatabaseClient, + ) -> crate::Result<()> { + let Self::Lazy(lazy) = self else { + return Ok(()); + }; + + let (options, notify_opt) = { + let guard = lazy.lock().expect("transaction state mutex poisoned"); + match &*guard { + // This should never happen in the current implementation. + TransactionState::NotStarted(_) => { + return Err(crate::error::internal_error( + "explicit begin with NotStarted state is currently unsupported", + )); + } + TransactionState::Starting(options, notify) => { + (options.clone(), Some(Arc::clone(notify))) + } + TransactionState::Started(_, _) | TransactionState::Failed(_) => return Ok(()), + } + }; + + let response = match execute_begin_transaction(client, options).await { + Ok(r) => r, + Err(e) => { + let mut guard = lazy.lock().expect("transaction state mutex poisoned"); + let error = Arc::new(e); + *guard = TransactionState::Failed(Arc::clone(&error)); + // Release the lock and notify all the waiting queries that + // the transaction has failed. + drop(guard); + if let Some(notify) = notify_opt { + notify.notify_waiters(); + } + + let return_error = if let Some(status) = error.status() { + crate::Error::service(status.clone()) + } else { + crate::error::internal_error(format!("Transaction failed to start: {}", error)) + }; + return Err(return_error); + } + }; + + self.update(response.id, response.read_timestamp)?; + + Ok(()) + } + + pub(crate) fn update( + &self, + id: bytes::Bytes, + timestamp: Option, + ) -> crate::Result<()> { + let Self::Lazy(lazy) = self else { + return Ok(()); + }; + let mut guard = lazy.lock().expect("transaction state mutex poisoned"); + + if matches!( + &*guard, + TransactionState::NotStarted(_) | TransactionState::Starting(_, _) + ) { + let previous_state = std::mem::replace( + &mut *guard, + TransactionState::Started( + crate::model::TransactionSelector::default().set_id(id), + timestamp, + ), + ); + drop(guard); + + // Notify all queries that are waiting for the transaction. + if let TransactionState::Starting(_, notify) = previous_state { + notify.notify_waiters(); + } + Ok(()) + } else { + Err(crate::error::internal_error( + "got a transaction id for an already Started or Failed transaction", + )) + } + } + + /// Returns the transaction ID if it is already available, without waiting. + /// + /// This method inspects the selector and returns the transaction ID if the + /// transaction has already started. It returns `None` if the transaction + /// has not yet started or is in a state without an ID. + pub(crate) fn get_id_no_wait(&self) -> Option { + use crate::generated::gapic_dataplane::model::transaction_selector::Selector; + match self { + Self::Fixed(selector, _) => { + if let Some(Selector::Id(id)) = &selector.selector { + return Some(id.clone()); + } + } + Self::Lazy(lazy) => { + let guard = lazy.lock().expect("transaction state mutex poisoned"); + if let TransactionState::Started(selector, _) = &*guard { + if let Some(Selector::Id(id)) = &selector.selector { + return Some(id.clone()); + } + } + } + } + None + } + + /// Resets the selector state from `Starting` back to `NotStarted`. + /// + /// This is used during stream resume fallbacks when the first query stream + /// fails before yielding a transaction ID. It unlocks any parked waiters + /// allowing them (or the retry attempt) to include the begin option again. + pub(crate) fn maybe_reset_starting(&self) { + let Self::Lazy(lazy) = self else { + return; + }; + + let mut guard = lazy.lock().expect("transaction state mutex poisoned"); + if let TransactionState::Starting(options, notify) = &*guard { + let options = options.clone(); + let notify = Arc::clone(notify); + *guard = TransactionState::NotStarted(options); + drop(guard); + notify.notify_waiters(); + } + } + + pub(crate) fn read_timestamp(&self) -> Option { + match self { + Self::Fixed(_, timestamp) => *timestamp, + Self::Lazy(lazy) => { + let guard = lazy.lock().expect("transaction state mutex poisoned"); + if let TransactionState::Started(_, timestamp) = &*guard { + *timestamp + } else { + None + } + } + } + } +} + #[derive(Clone, Debug)] pub(crate) struct ReadContext { pub(crate) client: DatabaseClient, - pub(crate) transaction_selector: crate::model::TransactionSelector, + pub(crate) transaction_selector: ReadContextTransactionSelector, pub(crate) precommit_token_tracker: PrecommitTokenTracker, pub(crate) transaction_tag: Option, } @@ -397,6 +693,67 @@ impl ReadContext { options } + /// Attempts to execute an explicit `begin_transaction` RPC if the current transaction + /// selector is still in the `Lazy(NotStarted)` state. This is used as a + /// fallback mechanism when an initial implicit begin attempt failed. + async fn begin_explicitly_if_not_started(&self) -> crate::Result { + let ReadContextTransactionSelector::Lazy(lazy) = &self.transaction_selector else { + return Ok(false); + }; + let is_started = matches!(&*lazy.lock().unwrap(), TransactionState::Started(_, _)); + if is_started { + return Ok(false); + } + + self.transaction_selector + .begin_explicitly(&self.client) + .await?; + Ok(true) + } +} + +/// Helper macro to execute a streaming SQL or streaming read RPC with retry logic. +macro_rules! execute_stream_with_retry { + ($self:expr, $request:ident, $rpc_method:ident, $operation_variant:path) => {{ + let stream = match $self + .client + .spanner + // TODO(#4972): make request options configurable + .$rpc_method($request.clone(), crate::RequestOptions::default()) + .send() + .await + { + Ok(s) => s, + Err(e) => { + if is_aborted(&e) { + return Err(e); + } + if $self.begin_explicitly_if_not_started().await? { + $request.transaction = Some($self.transaction_selector.selector().await?); + $self + .client + .spanner + // TODO(#4972): make request options configurable + .$rpc_method($request.clone(), crate::RequestOptions::default()) + .send() + .await? + } else { + return Err(e); + } + } + }; + + Ok(ResultSet::new( + stream, + Some($self.transaction_selector.clone()), + $self.precommit_token_tracker.clone(), + $self.client.clone(), + $operation_variant($request), + )) + }}; +} + +impl ReadContext { pub(crate) async fn execute_query>( &self, statement: T, @@ -405,23 +762,10 @@ impl ReadContext { .into() .into_request() .set_session(self.client.session.name.clone()) - .set_transaction(self.transaction_selector.clone()); + .set_transaction(self.transaction_selector.selector().await?); request.request_options = self.amend_request_options(request.request_options); - let stream = self - .client - .spanner - // TODO(#4972): make request options configurable - .execute_streaming_sql(request.clone(), crate::RequestOptions::default()) - .send() - .await?; - - Ok(ResultSet::new( - stream, - self.precommit_token_tracker.clone(), - self.client.clone(), - StreamOperation::Query(request), - )) + execute_stream_with_retry!(self, request, execute_streaming_sql, StreamOperation::Query) } pub(crate) async fn execute_read>( @@ -432,29 +776,26 @@ impl ReadContext { .into() .into_request() .set_session(self.client.session.name.clone()) - .set_transaction(self.transaction_selector.clone()); + .set_transaction(self.transaction_selector.selector().await?); request.request_options = self.amend_request_options(request.request_options); - let stream = self - .client - .spanner - // TODO(#4972): make request options configurable - .streaming_read(request.clone(), crate::RequestOptions::default()) - .send() - .await?; - - Ok(ResultSet::new( - stream, - self.precommit_token_tracker.clone(), - self.client.clone(), - StreamOperation::Read(request), - )) + execute_stream_with_retry!(self, request, streaming_read, StreamOperation::Read) } } #[cfg(test)] pub(crate) mod tests { use super::*; + use crate::client::Statement; + use crate::result_set::tests::string_val; + use crate::value::Value; + use gaxi::grpc::tonic::{self, Code, Response, Status}; + use mock_v1::transaction_selector::Selector; + use spanner_grpc_mock::google::spanner::v1 as mock_v1; + use std::pin::Pin; + use std::sync::Arc; + use std::task::{Context, Poll}; + use tokio::sync::{Barrier, Mutex, Notify, mpsc}; #[test] fn auto_traits() { @@ -468,12 +809,10 @@ pub(crate) mod tests { pub(crate) fn create_session_mock() -> spanner_grpc_mock::MockSpanner { let mut mock = spanner_grpc_mock::MockSpanner::new(); mock.expect_create_session().once().returning(|_| { - Ok(gaxi::grpc::tonic::Response::new( - spanner_grpc_mock::google::spanner::v1::Session { - name: "projects/p/instances/i/databases/d/sessions/123".to_string(), - ..Default::default() - }, - )) + Ok(Response::new(mock_v1::Session { + name: "projects/p/instances/i/databases/d/sessions/123".to_string(), + ..Default::default() + })) }); mock } @@ -502,6 +841,7 @@ pub(crate) mod tests { let (address, server) = spanner_grpc_mock::start("0.0.0.0:0", mock) .await .expect("Failed to start mock server"); + let spanner = Spanner::builder() .with_endpoint(address) .with_credentials(Anonymous::new().build()) @@ -525,9 +865,13 @@ pub(crate) mod tests { let (db_client, _server) = setup_db_client(mock).await; let tx = db_client.single_use().build(); - let ro = tx + let selector = tx .context .transaction_selector + .selector() + .await + .expect("Failed to get selector"); + let ro = selector .single_use() .expect("Expected SingleUse selector") .read_only() @@ -543,9 +887,13 @@ pub(crate) mod tests { std::time::Duration::from_secs(10), )) .build(); - let ro2 = tx2 + let selector = tx2 .context .transaction_selector + .selector() + .await + .expect("Failed to get selector"); + let ro2 = selector .single_use() .expect("Expected SingleUse selector") .read_only() @@ -576,9 +924,9 @@ pub(crate) mod tests { ); assert_eq!(req.sql, "SELECT 1"); - Ok(gaxi::grpc::tonic::Response::new(Box::pin( - tokio_stream::iter(vec![Ok(setup_select1())]), - ))) + Ok(tonic::Response::new(Box::pin(tokio_stream::iter(vec![ + Ok(setup_select1()), + ])))) }); let (db_client, _server) = setup_db_client(mock).await; @@ -610,7 +958,7 @@ pub(crate) mod tests { req.session, "projects/p/instances/i/databases/d/sessions/123" ); - Ok(gaxi::grpc::tonic::Response::new(mock_v1::Transaction { + Ok(tonic::Response::new(mock_v1::Transaction { id: vec![1, 2, 3], // prost_types::Timestamp fields need to be explicitly set because default is 0 for both read_timestamp: Some(prost_types::Timestamp { @@ -637,15 +985,16 @@ pub(crate) mod tests { mock_v1::transaction_selector::Selector::Id(vec![1, 2, 3]) ); - Ok(gaxi::grpc::tonic::Response::new(Box::pin( - tokio_stream::iter(vec![Ok(setup_select1())]), - ))) + Ok(tonic::Response::new(Box::pin(tokio_stream::iter(vec![ + Ok(setup_select1()), + ])))) }); let (db_client, _server) = setup_db_client(mock).await; let tx = db_client .read_only_transaction() + .with_explicit_begin_transaction(true) .build() .await .expect("Failed to start tx"); @@ -670,6 +1019,102 @@ pub(crate) mod tests { } } + #[tokio::test] + async fn execute_multi_query_inline_begin() -> anyhow::Result<()> { + use super::super::result_set::tests::string_val; + use crate::client::Statement; + use crate::value::Value; + use spanner_grpc_mock::google::spanner::v1 as mock_v1; + + let mut mock = create_session_mock(); + + // No explicit begin_transaction should be called. + mock.expect_begin_transaction().never(); + + let mut seq = mockall::Sequence::new(); + + mock.expect_execute_streaming_sql() + .times(1) + .in_sequence(&mut seq) + .returning(move |req| { + let req = req.into_inner(); + assert_eq!( + req.session, + "projects/p/instances/i/databases/d/sessions/123" + ); + + // First call: Should have Selector::Begin + match req.transaction.unwrap().selector.unwrap() { + mock_v1::transaction_selector::Selector::Begin(_) => {} + _ => panic!("Expected Selector::Begin"), + } + let mut rs = setup_select1(); + rs.metadata.as_mut().unwrap().transaction = Some(mock_v1::Transaction { + id: vec![4, 5, 6], + read_timestamp: Some(prost_types::Timestamp { + seconds: 987654321, + nanos: 0, + }), + ..Default::default() + }); + Ok(tonic::Response::new(Box::pin(tokio_stream::iter(vec![ + Ok(rs), + ])))) + }); + + mock.expect_execute_streaming_sql() + .times(1) + .in_sequence(&mut seq) + .returning(move |req| { + let req = req.into_inner(); + // Second call: Should have Selector::Id using the ID returned in the first call + match req.transaction.unwrap().selector.unwrap() { + mock_v1::transaction_selector::Selector::Id(id) => { + assert_eq!(id, vec![4, 5, 6]); + } + _ => panic!("Expected Selector::Id"), + } + Ok(tonic::Response::new(Box::pin(tokio_stream::iter(vec![ + Ok(setup_select1()), + ])))) + }); + + let (db_client, _server) = setup_db_client(mock).await; + + let tx = db_client + .read_only_transaction() + .with_explicit_begin_transaction(false) + .build() + .await?; + + // The read timestamp is not available until the first query is executed. + assert!(tx.read_timestamp().is_none()); + + for i in 0..2 { + let mut rs = tx + .execute_query(Statement::builder("SELECT 1").build()) + .await?; + + let row = rs.next().await.expect("Expected a row")?; + assert_eq!(row.raw_values(), [Value(string_val("1"))]); + + let result = rs.next().await; + assert!(result.is_none(), "Expected None, got {result:?}"); + + if i == 0 { + // Read timestamp becomes available. + assert_eq!( + tx.read_timestamp() + .expect("Expected read timestamp") + .seconds(), + 987654321 + ); + } + } + + Ok(()) + } + #[tokio::test] async fn execute_single_read() { use super::super::result_set::tests::string_val; @@ -687,9 +1132,9 @@ pub(crate) mod tests { assert_eq!(req.table, "Users"); assert_eq!(req.columns, vec!["Id".to_string(), "Name".to_string()]); - Ok(gaxi::grpc::tonic::Response::new(Box::pin( - tokio_stream::iter(vec![Ok(setup_select1())]), - ))) + Ok(tonic::Response::new(Box::pin(tokio_stream::iter(vec![ + Ok(setup_select1()), + ])))) }); let (db_client, _server) = setup_db_client(mock).await; @@ -705,4 +1150,1126 @@ pub(crate) mod tests { let result = rs.next().await; assert!(result.is_none(), "expected None, got {result:?}"); } + + #[tokio::test] + async fn execute_multi_read() -> anyhow::Result<()> { + use super::super::result_set::tests::string_val; + use crate::client::{KeySet, ReadRequest}; + use crate::value::Value; + use spanner_grpc_mock::google::spanner::v1 as mock_v1; + + let mut mock = create_session_mock(); + + // No explicit begin_transaction should be called. + mock.expect_begin_transaction().never(); + + let mut seq = mockall::Sequence::new(); + + mock.expect_streaming_read() + .times(1) + .in_sequence(&mut seq) + .returning(move |req| { + let req = req.into_inner(); + assert_eq!( + req.session, + "projects/p/instances/i/databases/d/sessions/123" + ); + + // First call: Should have Selector::Begin + match req.transaction.unwrap().selector.unwrap() { + mock_v1::transaction_selector::Selector::Begin(_) => {} + _ => panic!("Expected Selector::Begin"), + } + let mut rs = setup_select1(); + rs.metadata.as_mut().unwrap().transaction = Some(mock_v1::Transaction { + id: vec![4, 5, 6], + read_timestamp: Some(prost_types::Timestamp { + seconds: 987654321, + nanos: 0, + }), + ..Default::default() + }); + Ok(gaxi::grpc::tonic::Response::new(Box::pin( + tokio_stream::iter(vec![Ok(rs)]), + ))) + }); + + mock.expect_streaming_read() + .times(1) + .in_sequence(&mut seq) + .returning(move |req| { + let req = req.into_inner(); + // Second call: Should have Selector::Id using the ID returned in the first call + match req.transaction.unwrap().selector.unwrap() { + mock_v1::transaction_selector::Selector::Id(id) => { + assert_eq!(id, vec![4, 5, 6]); + } + _ => panic!("Expected Selector::Id"), + } + Ok(gaxi::grpc::tonic::Response::new(Box::pin( + tokio_stream::iter(vec![Ok(setup_select1())]), + ))) + }); + + let (db_client, _server) = setup_db_client(mock).await; + + let tx = db_client + .read_only_transaction() + .with_explicit_begin_transaction(false) + .build() + .await?; + + // The read timestamp is not available until the first query is executed. + assert!(tx.read_timestamp().is_none()); + + for i in 0..2 { + let read = ReadRequest::builder("Users", vec!["Id", "Name"]) + .with_keys(KeySet::all()) + .build(); + let mut rs = tx.execute_read(read).await?; + + let row = rs.next().await.expect("Expected a row")?; + assert_eq!(row.raw_values(), [Value(string_val("1"))]); + + let result = rs.next().await; + assert!(result.is_none(), "Expected None, got {result:?}"); + + if i == 0 { + // Read timestamp becomes available. + assert_eq!( + tx.read_timestamp() + .expect("Expected read timestamp") + .seconds(), + 987654321 + ); + } + } + + Ok(()) + } + + #[tokio::test] + async fn inline_begin_failure_retry_success() -> anyhow::Result<()> { + use crate::value::Value; + use gaxi::grpc::tonic::Status; + use tonic::Response; + + let mut mock = create_session_mock(); + let mut seq = mockall::Sequence::new(); + + // 1. Initial query fails + mock.expect_execute_streaming_sql() + .times(1) + .in_sequence(&mut seq) + .returning(|_| Err(Status::internal("Internal error"))); + + // 2. Explicit begin transaction succeeds + mock.expect_begin_transaction() + .times(1) + .in_sequence(&mut seq) + .returning(|req| { + let req = req.into_inner(); + assert_eq!( + req.session, + "projects/p/instances/i/databases/d/sessions/123" + ); + // Return a transaction with ID + Ok(Response::new(mock_v1::Transaction { + id: vec![7, 8, 9], + read_timestamp: Some(prost_types::Timestamp { + seconds: 123456789, + nanos: 0, + }), + ..Default::default() + })) + }); + + // 3. Retry of the query succeeds + mock.expect_execute_streaming_sql() + .times(1) + .in_sequence(&mut seq) + .returning(|req| { + let req = req.into_inner(); + // Ensure it uses the new transaction ID + match req.transaction.unwrap().selector.unwrap() { + mock_v1::transaction_selector::Selector::Id(id) => { + assert_eq!(id, vec![7, 8, 9]); + } + _ => panic!("Expected Selector::Id"), + } + Ok(Response::new(Box::pin(tokio_stream::iter(vec![Ok( + setup_select1(), + )])))) + }); + + let (db_client, _server) = setup_db_client(mock).await; + let tx = db_client + .read_only_transaction() + .with_explicit_begin_transaction(false) + .build() + .await?; + + let mut rs = tx + .execute_query(Statement::builder("SELECT 1").build()) + .await?; + + let row = rs + .next() + .await + .ok_or_else(|| anyhow::anyhow!("Expected a row but stream cleanly exhausted"))??; + assert_eq!( + row.raw_values(), + [Value(string_val("1"))], + "The parsed row value safely matched the underlying stream chunk" + ); + + Ok(()) + } + + #[tokio::test] + async fn inline_begin_failure_retry_failure() -> anyhow::Result<()> { + use gaxi::grpc::tonic::Status; + use tonic::Response; + + let mut mock = create_session_mock(); + let mut seq = mockall::Sequence::new(); + + // 1. Initial query fails + mock.expect_execute_streaming_sql() + .times(1) + .in_sequence(&mut seq) + .returning(|_| Err(Status::internal("Internal error first"))); + + // 2. Explicit begin transaction succeeds + mock.expect_begin_transaction() + .times(1) + .in_sequence(&mut seq) + .returning(|_| { + Ok(Response::new(mock_v1::Transaction { + id: vec![7, 8, 9], + read_timestamp: Some(prost_types::Timestamp { + seconds: 123456789, + nanos: 0, + }), + ..Default::default() + })) + }); + + // 3. Retry of the query fails again + mock.expect_execute_streaming_sql() + .times(1) + .in_sequence(&mut seq) + .returning(|_| Err(Status::internal("Internal error second"))); + + let (db_client, _server) = setup_db_client(mock).await; + let tx = db_client + .read_only_transaction() + .with_explicit_begin_transaction(false) + .build() + .await?; + + let rs_result = tx + .execute_query(Statement::builder("SELECT 1").build()) + .await; + + assert!( + rs_result.is_err(), + "The failed execution bubbled upwards securely" + ); + let err_str = rs_result.unwrap_err().to_string(); + assert!( + err_str.contains("Internal error second"), + "Secondary error message accurately propagates: {}", + err_str + ); + + Ok(()) + } + + #[tokio::test] + async fn inline_begin_failure_fallback_rpc_fails() -> anyhow::Result<()> { + use gaxi::grpc::tonic::Status; + + let mut mock = create_session_mock(); + let mut seq = mockall::Sequence::new(); + + // 1. Initial query fails + mock.expect_execute_streaming_sql() + .times(1) + .in_sequence(&mut seq) + .returning(|_| Err(Status::internal("Internal error query"))); + + // 2. Explicit begin transaction fails + mock.expect_begin_transaction() + .times(1) + .in_sequence(&mut seq) + .returning(|_| Err(Status::internal("Internal error begin tx"))); + + let (db_client, _server) = setup_db_client(mock).await; + let tx = db_client + .read_only_transaction() + .with_explicit_begin_transaction(false) + .build() + .await?; + + let rs_result = tx + .execute_query(Statement::builder("SELECT 1").build()) + .await; + + assert!( + rs_result.is_err(), + "The explicitly errored fallback boot securely propagated outwards" + ); + let err_str = rs_result.unwrap_err().to_string(); + assert!( + err_str.contains("Internal error begin tx"), + "Natively propagated specific BeginTx bounds: {}", + err_str + ); + + Ok(()) + } + + #[tokio::test] + async fn inline_begin_read_failure_retry_success() -> anyhow::Result<()> { + use crate::client::{KeySet, ReadRequest}; + use crate::value::Value; + use gaxi::grpc::tonic::Status; + use tonic::Response; + + let mut mock = create_session_mock(); + let mut seq = mockall::Sequence::new(); + + // 1. Initial read fails + mock.expect_streaming_read() + .times(1) + .in_sequence(&mut seq) + .returning(|_| Err(Status::internal("Internal error"))); + + // 2. Explicit begin transaction succeeds + mock.expect_begin_transaction() + .times(1) + .in_sequence(&mut seq) + .returning(|_| { + Ok(Response::new(mock_v1::Transaction { + id: vec![7, 8, 9], + read_timestamp: None, + ..Default::default() + })) + }); + + // 3. Retry of the read succeeds + mock.expect_streaming_read() + .times(1) + .in_sequence(&mut seq) + .returning(|req| { + let req = req.into_inner(); + // Ensure it uses the new transaction ID + match req.transaction.unwrap().selector.unwrap() { + mock_v1::transaction_selector::Selector::Id(id) => { + assert_eq!(id, vec![7, 8, 9]); + } + _ => panic!("Expected Selector::Id"), + } + Ok(Response::new(Box::pin(tokio_stream::iter(vec![Ok( + setup_select1(), + )])))) + }); + + let (db_client, _server) = setup_db_client(mock).await; + let tx = db_client + .read_only_transaction() + .with_explicit_begin_transaction(false) + .build() + .await?; + + let read = ReadRequest::builder("Users", vec!["Id", "Name"]) + .with_keys(KeySet::all()) + .build(); + let mut rs = tx.execute_read(read).await?; + + let row = rs + .next() + .await + .ok_or_else(|| anyhow::anyhow!("Expected a row uniquely returned"))??; + assert_eq!( + row.raw_values(), + [Value(string_val("1"))], + "The macro correctly unpacked read arrays seamlessly" + ); + + Ok(()) + } + + #[tokio::test] + async fn single_use_query_send_error_returns_immediately() -> anyhow::Result<()> { + use crate::client::Statement; + use gaxi::grpc::tonic::Status; + + let mut mock = create_session_mock(); + + mock.expect_execute_streaming_sql() + .times(1) + .returning(|_| Err(Status::internal("Internal error single use query"))); + + mock.expect_begin_transaction().never(); + + let (db_client, _server) = setup_db_client(mock).await; + // single_use creates a Fixed selector + let tx = db_client.single_use().build(); + + let rs_result = tx + .execute_query(Statement::builder("SELECT 1").build()) + .await; + + assert!(rs_result.is_err()); + let err_str = rs_result.unwrap_err().to_string(); + assert!(err_str.contains("Internal error single use query")); + + Ok(()) + } + + #[tokio::test] + async fn inline_begin_already_started_query_send_error_returns_immediately() + -> anyhow::Result<()> { + use crate::client::Statement; + use gaxi::grpc::tonic::Status; + use spanner_grpc_mock::google::spanner::v1 as mock_v1; + + let mut mock = create_session_mock(); + let mut seq = mockall::Sequence::new(); + + mock.expect_begin_transaction().never(); + + // 1. First query executes successfully and implicitly starts the transaction. + mock.expect_execute_streaming_sql() + .times(1) + .in_sequence(&mut seq) + .returning(move |_req| { + let mut rs = setup_select1(); + rs.metadata.as_mut().unwrap().transaction = Some(mock_v1::Transaction { + id: vec![4, 5, 6], + read_timestamp: None, + ..Default::default() + }); + Ok(tonic::Response::new(Box::pin(tokio_stream::iter(vec![ + Ok(rs), + ])))) + }); + + // 2. Second query fails immediately upon send() + mock.expect_execute_streaming_sql() + .times(1) + .in_sequence(&mut seq) + .returning(|_| Err(Status::internal("Internal error second query"))); + + let (db_client, _server) = setup_db_client(mock).await; + + let tx = db_client + .read_only_transaction() + .with_explicit_begin_transaction(false) + .build() + .await?; + + // Run first query (starts tx) + let mut rs = tx + .execute_query(Statement::builder("SELECT 1").build()) + .await?; + let _ = rs.next().await.expect("has row")?; + + // Run second query (fails) + let rs_result = tx + .execute_query(Statement::builder("SELECT 2").build()) + .await; + + assert!(rs_result.is_err()); + let err_str = rs_result.unwrap_err().to_string(); + assert!(err_str.contains("Internal error second query")); + + Ok(()) + } + + /// A wrapper that implements `tokio_stream::Stream` for a `mpsc::Receiver`. + /// Useful in mock setups to yield controlled streaming test responses. + struct ReceiverStream(mpsc::Receiver); + impl tokio_stream::Stream for ReceiverStream { + type Item = T; + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.0.poll_recv(cx) + } + } + + #[tokio::test] + async fn execute_concurrent_queries_inline_begin() -> anyhow::Result<()> { + let mut mock = create_session_mock(); + mock.expect_begin_transaction().never(); + + let mut seq = mockall::Sequence::new(); + let (tx_sender, rx_receiver) = mpsc::channel(1); + let rx_receiver = Arc::new(Mutex::new(Some(rx_receiver))); + + let task1_ready = Arc::new(Notify::new()); + let task1_ready_clone = Arc::clone(&task1_ready); + let tasks_started = Arc::new(Barrier::new(3)); + + // 1. First query: should include Selector::Begin + mock.expect_execute_streaming_sql() + .times(1) + .in_sequence(&mut seq) + .returning(move |req| { + task1_ready_clone.notify_one(); + let req = req.into_inner(); + match req.transaction.unwrap().selector.unwrap() { + Selector::Begin(_) => {} + _ => panic!("Expected Selector::Begin for first query"), + } + let rx = rx_receiver + .try_lock() + .expect("mutex poisoned") + .take() + .unwrap(); + Ok(Response::new(Box::pin(ReceiverStream(rx)))) + }); + + // 2. The other queries: should include populated Selector::Id + mock.expect_execute_streaming_sql() + .times(2) + .in_sequence(&mut seq) + .returning(move |req| { + let req = req.into_inner(); + match req.transaction.unwrap().selector.unwrap() { + Selector::Id(id) => { + assert_eq!(id, vec![4, 5, 6]); + } + _ => panic!("Expected Selector::Id for other queries"), + } + + Ok(Response::new(Box::pin(tokio_stream::iter(vec![Ok( + setup_select1(), + )])))) + }); + + let (db_client, _server) = setup_db_client(mock).await; + let tx = db_client + .read_only_transaction() + .with_explicit_begin_transaction(false) + .build() + .await?; + let tx = Arc::new(tx); + + // Spawn 3 concurrent queries. + // Task 1 launches first and executes the first query. + let tx1 = Arc::clone(&tx); + let handle1 = tokio::spawn(async move { + let mut rs = tx1 + .execute_query(Statement::builder("SELECT 1").build()) + .await?; + // Read the first result to get the transaction ID. + let _ = rs.next().await; + Ok::<_, crate::Error>(rs) + }); + + // Wait for Task 1 to reach the mock server. + task1_ready.notified().await; + + let tx2 = Arc::clone(&tx); + let tasks_started2 = Arc::clone(&tasks_started); + let handle2 = tokio::spawn(async move { + tasks_started2.wait().await; + tx2.execute_query(Statement::builder("SELECT 1").build()) + .await + }); + + let tx3 = Arc::clone(&tx); + let tasks_started3 = Arc::clone(&tasks_started); + let handle3 = tokio::spawn(async move { + tasks_started3.wait().await; + tx3.execute_query(Statement::builder("SELECT 1").build()) + .await + }); + + // Ensure both Tasks 2 and 3 have reached the barrier before proceeding. + tasks_started.wait().await; + + // Flush the scheduler on this single-threaded executor. + // This guarantees that Tasks 2 & 3 run until they both hit the internal + // selector Notify latch and become suspended. + tokio::task::yield_now().await; + + // Provide the first result (including the transaction ID) to Task 1. + // This transitions the selector to 'Started' and unblocks Tasks 2 and 3. + let mut rs = setup_select1(); + rs.metadata.as_mut().unwrap().transaction = Some(mock_v1::Transaction { + id: vec![4, 5, 6], + read_timestamp: Some(prost_types::Timestamp { + seconds: 987654321, + nanos: 0, + }), + ..Default::default() + }); + tx_sender.send(Ok(rs)).await.expect("channel broken"); + drop(tx_sender); + + // Collect all results + let mut rs1 = handle1.await??; + let mut rs2 = handle2.await??; + let mut rs3 = handle3.await??; + + // Verify the query results + assert!(rs1.next().await.is_none()); + + let row2 = rs2.next().await.expect("Expected a row")?; + assert_eq!(row2.raw_values(), [Value(string_val("1"))]); + assert!(rs2.next().await.is_none()); + + let row3 = rs3.next().await.expect("Expected a row")?; + assert_eq!(row3.raw_values(), [Value(string_val("1"))]); + assert!(rs3.next().await.is_none()); + + // Verify that the read timestamp was populated + assert_eq!(tx.read_timestamp().unwrap().seconds(), 987654321); + + Ok(()) + } + + #[tokio::test] + async fn execute_concurrent_queries_inline_begin_failed_cascade() -> anyhow::Result<()> { + let mut mock = create_session_mock(); + let mut seq = mockall::Sequence::new(); + + let (tx_sender, rx_receiver) = mpsc::channel(1); + let rx_receiver = Arc::new(Mutex::new(Some(rx_receiver))); + + let task1_ready = Arc::new(Notify::new()); + let task1_ready_clone = Arc::clone(&task1_ready); + let tasks_started = Arc::new(Barrier::new(3)); + + // 1. Return a stream connected to tx_sender. + // We will use tx_sender later in the test to inject a failed first chunk. + mock.expect_execute_streaming_sql() + .times(1) + .in_sequence(&mut seq) + .returning(move |_req| { + task1_ready_clone.notify_one(); + let rx = rx_receiver + .try_lock() + .expect("mutex poisoned") + .take() + .unwrap(); + Ok(tonic::Response::new(Box::pin(ReceiverStream(rx)))) + }); + + // 2. Fallback BeginTransaction RPC fails + mock.expect_begin_transaction() + .times(1) + .in_sequence(&mut seq) + .returning(|_| { + Err(gaxi::grpc::tonic::Status::internal( + "Fallback BeginTransaction failed", + )) + }); + + // The other queries will never be executed. + mock.expect_execute_streaming_sql().times(0).returning(|_| { + panic!("Other queries should not launch after failure to start the transaction") + }); + + let (db_client, _server) = setup_db_client(mock).await; + let tx = db_client + .read_only_transaction() + .with_explicit_begin_transaction(false) + .build() + .await?; + let tx = Arc::new(tx); + + // Spawn 3 concurrent queries. + let tx1 = Arc::clone(&tx); + let handle1 = tokio::spawn(async move { + let mut rs = tx1 + .execute_query(Statement::builder("SELECT 1").build()) + .await?; + rs.next().await.ok_or_else(|| { + crate::error::internal_error("stream exhausted (this should never happen)") + })??; + Ok::<_, crate::Error>(rs) + }); + + // Wait for Task 1 to reach the mock and transition the selector to Starting. + task1_ready.notified().await; + + let tx2 = Arc::clone(&tx); + let tasks_started2 = Arc::clone(&tasks_started); + let handle2 = tokio::spawn(async move { + tasks_started2.wait().await; + tx2.execute_query(Statement::builder("SELECT 1").build()) + .await + }); + + let tx3 = Arc::clone(&tx); + let tasks_started3 = Arc::clone(&tasks_started); + let handle3 = tokio::spawn(async move { + tasks_started3.wait().await; + tx3.execute_query(Statement::builder("SELECT 1").build()) + .await + }); + + // Ensure both Tasks 2 and 3 have reached the barrier before proceeding. + tasks_started.wait().await; + + // Flush the scheduler on this single-threaded executor. + // This guarantees that Tasks 2 & 3 run until they both hit the internal + // selector Notify latch and become suspended. + tokio::task::yield_now().await; + + // Push error to channel failing first query stream! + tx_sender + .send(Err(gaxi::grpc::tonic::Status::internal( + "Mocked boot failed", + ))) + .await + .expect("channel broken"); + drop(tx_sender); + + // Collect all results - all should fail with identical cached error! + let err1 = handle1.await?.unwrap_err().to_string(); + let err2 = handle2.await?.unwrap_err().to_string(); + let err3 = handle3.await?.unwrap_err().to_string(); + + assert!( + err1.contains("Fallback BeginTransaction failed"), + "err1: {}", + err1 + ); + assert!( + err2.contains("Fallback BeginTransaction failed"), + "err2: {}", + err2 + ); + assert!( + err3.contains("Fallback BeginTransaction failed"), + "err3: {}", + err3 + ); + + Ok(()) + } + + #[tokio::test] + async fn execute_concurrent_queries_inline_begin_stream_restart_deadlock_prevention() + -> crate::Result<()> { + let mut mock = create_session_mock(); + mock.expect_begin_transaction().never(); + + let mut seq = mockall::Sequence::new(); + + let (tx_sender, rx_receiver) = mpsc::channel(1); + let rx_receiver = Arc::new(Mutex::new(Some(rx_receiver))); + + let task1_ready = Arc::new(Notify::new()); + let task1_ready_clone = Arc::clone(&task1_ready); + let tasks_started = Arc::new(Barrier::new(3)); + + // 1. Task 1 initial query: Return a stream connected to tx_sender for error injection. + mock.expect_execute_streaming_sql() + .times(1) + .in_sequence(&mut seq) + .returning(move |req| { + let req = req.into_inner(); + // Return a stream connected to tx_sender. + // We will use tx_sender later in the test to inject a transient error. + task1_ready_clone.notify_one(); + match req.transaction.unwrap().selector.unwrap() { + Selector::Begin(_) => {} + _ => panic!("Expected Selector::Begin for first query"), + } + let rx = rx_receiver + .try_lock() + .expect("mutex poisoned") + .take() + .unwrap(); + Ok(Response::new(Box::pin(ReceiverStream(rx)))) + }); + + // 2. Task 1 restart query: should include Selector::Begin, since + // it failed with a transient error. + mock.expect_execute_streaming_sql() + .times(1) + .in_sequence(&mut seq) + .returning(move |req| { + let req = req.into_inner(); + match req.transaction.unwrap().selector.unwrap() { + Selector::Begin(_) => { + let mut rs = setup_select1(); + rs.metadata.as_mut().unwrap().transaction = Some(mock_v1::Transaction { + id: vec![4, 5, 6], + ..Default::default() + }); + Ok(Response::new(Box::pin(tokio_stream::iter(vec![Ok(rs)])))) + } + _ => panic!("Expected Selector::Begin for stream restart query"), + } + }); + + // 3. Tasks 2 & 3: should include populated Selector::Id + mock.expect_execute_streaming_sql() + .times(2) + .in_sequence(&mut seq) + .returning(move |req| { + let req = req.into_inner(); + match req.transaction.unwrap().selector.unwrap() { + Selector::Id(id) => { + assert_eq!(id, vec![4, 5, 6]); + Ok(Response::new(Box::pin(tokio_stream::iter(vec![Ok( + setup_select1(), + )])))) + } + _ => panic!("Expected Selector::Id for concurrent queries"), + } + }); + + let (db_client, _server) = setup_db_client(mock).await; + let tx = db_client + .read_only_transaction() + .with_explicit_begin_transaction(false) + .build() + .await?; + let tx = Arc::new(tx); + + let handle1_tx = Arc::clone(&tx); + let handle1 = tokio::spawn(async move { + let mut rs = handle1_tx + .execute_query(Statement::builder("SELECT 1").build()) + .await?; + let _ = rs.next().await.ok_or_else(|| { + crate::error::internal_error("stream exhausted (this should never happen)") + })??; + Ok::<_, crate::Error>(rs) + }); + + // Wait for Task 1 to reach the mock and transition the selector to Starting. + task1_ready.notified().await; + + let handle2_tx = Arc::clone(&tx); + let tasks_started2 = Arc::clone(&tasks_started); + let handle2 = tokio::spawn(async move { + tasks_started2.wait().await; + let mut rs = handle2_tx + .execute_query(Statement::builder("SELECT 1").build()) + .await?; + let _ = rs.next().await.ok_or_else(|| { + crate::error::internal_error("stream exhausted (this should never happen)") + })??; + Ok::<_, crate::Error>(rs) + }); + + let handle3_tx = Arc::clone(&tx); + let tasks_started3 = Arc::clone(&tasks_started); + let handle3 = tokio::spawn(async move { + tasks_started3.wait().await; + let mut rs = handle3_tx + .execute_query(Statement::builder("SELECT 1").build()) + .await?; + let _ = rs.next().await.ok_or_else(|| { + crate::error::internal_error("stream exhausted (this should never happen)") + })??; + Ok::<_, crate::Error>(rs) + }); + + // Ensure both Tasks 2 and 3 have reached the barrier before proceeding. + tasks_started.wait().await; + + // Flush the scheduler on this single-threaded executor. + // This guarantees that Tasks 2 & 3 run until they both hit the internal + // selector Notify latch and become suspended. + tokio::task::yield_now().await; + + let grpc_status = Status::new(gaxi::grpc::tonic::Code::Unavailable, "transient error"); + tx_sender.send(Err(grpc_status)).await.expect("send failed"); + drop(tx_sender); + + // Collect and verify all results. + // handle.await returns Result, JoinError>. + // The first ? handles the potential JoinError (panic in the task), + // and the second ? handles the Spanner error. + let mut rs1 = handle1.await.expect("Task 1 panicked")?; + let mut rs2 = handle2.await.expect("Task 2 panicked")?; + let mut rs3 = handle3.await.expect("Task 3 panicked")?; + + // Verify that all results have been exhausted. + // (The tasks themselves already successfully read the first row). + assert!(rs1.next().await.is_none(), "Stream 1 should be exhausted"); + assert!(rs2.next().await.is_none(), "Stream 2 should be exhausted"); + assert!(rs3.next().await.is_none(), "Stream 3 should be exhausted"); + + Ok(()) + } + + #[tokio::test] + async fn execute_concurrent_queries_late_arrival_failure() -> anyhow::Result<()> { + let mut mock = create_session_mock(); + let mut seq = mockall::Sequence::new(); + + // 1. Initial query fails. + mock.expect_execute_streaming_sql() + .times(1) + .in_sequence(&mut seq) + .returning(|req| { + let req = req.into_inner(); + match req.transaction.unwrap().selector.unwrap() { + Selector::Begin(_) => {} + _ => panic!("Expected Selector::Begin for first query"), + } + Err(Status::internal("Initial inline-begin failed")) + }); + + // 2. Fallback BeginTransaction RPC also fails. + mock.expect_begin_transaction() + .times(1) + .in_sequence(&mut seq) + .returning(|_| Err(Status::internal("Fallback BeginTransaction failed"))); + + // Any further attempts would panic because we haven't mocked them. + mock.expect_execute_streaming_sql().never(); + + let (db_client, _server) = setup_db_client(mock).await; + let tx = db_client + .read_only_transaction() + .with_explicit_begin_transaction(false) + .build() + .await?; + + // First query: triggers the failure and transitions the state to Failed. + let err1 = tx + .execute_query(Statement::builder("SELECT 1").build()) + .await + .expect_err("First query should fail"); + assert!( + err1.to_string() + .contains("Fallback BeginTransaction failed") + ); + + // Second query: starts AFTER the failure is already cached. + // It should immediately return the same error without invoking the mock server. + let err2 = tx + .execute_query(Statement::builder("SELECT 1").build()) + .await + .expect_err("Late query should fail immediately"); + assert!( + err2.to_string() + .contains("Fallback BeginTransaction failed") + ); + + Ok(()) + } + + #[tokio::test] + async fn execute_concurrent_reads_inline_begin() -> anyhow::Result<()> { + use crate::client::{KeySet, ReadRequest}; + let mut mock = create_session_mock(); + mock.expect_begin_transaction().never(); + + let mut seq = mockall::Sequence::new(); + let (tx_sender, rx_receiver) = mpsc::channel(1); + let rx_receiver = Arc::new(Mutex::new(Some(rx_receiver))); + + let task1_ready = Arc::new(Notify::new()); + let task1_ready_clone = Arc::clone(&task1_ready); + let tasks_started = Arc::new(Barrier::new(3)); + + // 1. First read: should include Selector::Begin + mock.expect_streaming_read() + .times(1) + .in_sequence(&mut seq) + .returning(move |req| { + task1_ready_clone.notify_one(); + let req = req.into_inner(); + match req.transaction.unwrap().selector.unwrap() { + mock_v1::transaction_selector::Selector::Begin(_) => {} + _ => panic!("Expected Selector::Begin for first read"), + } + + let rx = rx_receiver + .try_lock() + .expect("mutex poisoned") + .take() + .unwrap(); + Ok(Response::new(Box::pin(ReceiverStream(rx)))) + }); + + // 2. The other reads: should include populated Selector::Id + mock.expect_streaming_read() + .times(2) + .in_sequence(&mut seq) + .returning(move |req| { + let req = req.into_inner(); + match req.transaction.unwrap().selector.unwrap() { + mock_v1::transaction_selector::Selector::Id(id) => { + assert_eq!(id, vec![4, 5, 6]); + } + _ => panic!("Expected Selector::Id for other reads"), + } + + Ok(Response::new(Box::pin(tokio_stream::iter(vec![Ok( + setup_select1(), + )])))) + }); + + let (db_client, _server) = setup_db_client(mock).await; + let tx = db_client + .read_only_transaction() + .with_explicit_begin_transaction(false) + .build() + .await?; + let tx = Arc::new(tx); + + let read_req = ReadRequest::builder("Table", vec!["Col"]) + .with_keys(KeySet::all()) + .build(); + + // Spawn 3 concurrent reads. + let tx1 = Arc::clone(&tx); + let read1 = read_req.clone(); + let handle1 = tokio::spawn(async move { + let mut rs = tx1.execute_read(read1).await?; + let _ = rs.next().await; + Ok::<_, crate::Error>(rs) + }); + + task1_ready.notified().await; + + let tx2 = Arc::clone(&tx); + let read2 = read_req.clone(); + let tasks_started2 = Arc::clone(&tasks_started); + let handle2 = tokio::spawn(async move { + tasks_started2.wait().await; + let mut rs = tx2.execute_read(read2).await?; + let _ = rs.next().await; + Ok::<_, crate::Error>(rs) + }); + + let tx3 = Arc::clone(&tx); + let read3 = read_req.clone(); + let tasks_started3 = Arc::clone(&tasks_started); + let handle3 = tokio::spawn(async move { + tasks_started3.wait().await; + let mut rs = tx3.execute_read(read3).await?; + let _ = rs.next().await; + Ok::<_, crate::Error>(rs) + }); + + tasks_started.wait().await; + tokio::task::yield_now().await; + + // Provide the transaction ID. + let mut rs = setup_select1(); + rs.metadata.as_mut().unwrap().transaction = Some(mock_v1::Transaction { + id: vec![4, 5, 6], + ..Default::default() + }); + tx_sender.send(Ok(rs)).await.expect("send failed"); + drop(tx_sender); + + let mut rs1 = handle1.await.expect("Task 1 panicked")?; + let mut rs2 = handle2.await.expect("Task 2 panicked")?; + let mut rs3 = handle3.await.expect("Task 3 panicked")?; + + assert!(rs1.next().await.is_none()); + assert!(rs2.next().await.is_none()); + assert!(rs3.next().await.is_none()); + + Ok(()) + } + + #[tokio::test] + async fn execute_inline_begin_idempotent_update() -> anyhow::Result<()> { + let (db_client, _server) = setup_db_client(create_session_mock()).await; + // Access internal state for unit testing. + let tx = db_client + .read_only_transaction() + .with_explicit_begin_transaction(false) + .build() + .await?; + + let id1 = bytes::Bytes::from_static(b"tx1"); + let id2 = bytes::Bytes::from_static(b"tx2"); + + // 1. Initial update. + tx.context.transaction_selector.update(id1.clone(), None)?; + assert_eq!( + tx.context + .transaction_selector + .selector() + .await? + .id() + .unwrap(), + &id1 + ); + + // 2. Redundant update with same ID should result in an error. + // The implementation explicitly prevents redundant updates to ensure state consistency. + let err1 = tx + .context + .transaction_selector + .update(id1.clone(), None) + .expect_err("Redundant update should fail"); + assert!(err1.to_string().contains("already Started or Failed")); + + // 3. Update with DIFFERENT ID after already Started should also fail. + let err2 = tx + .context + .transaction_selector + .update(id2, None) + .expect_err("Update after Started should fail"); + assert!(err2.to_string().contains("already Started or Failed")); + + Ok(()) + } + + #[tokio::test] + async fn execute_inline_begin_with_transient_failure() -> anyhow::Result<()> { + let mut mock = create_session_mock(); + let mut seq = mockall::Sequence::new(); + + // 1. First attempt fails transiently. + mock.expect_execute_streaming_sql() + .times(1) + .in_sequence(&mut seq) + .returning(|_| Err(Status::new(Code::Unavailable, "Transient 1"))); + + // 2. Fallback BeginTransaction succeeds. + mock.expect_begin_transaction() + .times(1) + .in_sequence(&mut seq) + .returning(|_| { + Ok(Response::new(mock_v1::Transaction { + id: vec![7, 8, 9], + ..Default::default() + })) + }); + + // 3. The manual retry of the query (which happens after explicit begin fallback). + mock.expect_execute_streaming_sql() + .times(1) + .in_sequence(&mut seq) + .returning(|_| { + Ok(Response::new(Box::pin(tokio_stream::iter(vec![Ok( + setup_select1(), + )])))) + }); + + let (db_client, _server) = setup_db_client(mock).await; + let tx = db_client + .read_only_transaction() + .with_explicit_begin_transaction(false) + .build() + .await?; + + let mut rs = tx + .execute_query(Statement::builder("SELECT 1").build()) + .await?; + assert!(rs.next().await.is_some()); + assert!(rs.next().await.is_none()); + + Ok(()) + } } diff --git a/src/spanner/src/read_write_transaction.rs b/src/spanner/src/read_write_transaction.rs index 9a84b1bb87..78ef3ab5d4 100644 --- a/src/spanner/src/read_write_transaction.rs +++ b/src/spanner/src/read_write_transaction.rs @@ -33,7 +33,9 @@ use crate::precommit::PrecommitTokenTracker; use crate::read_only_transaction::ReadContext; use crate::result_set::ResultSet; use crate::statement::Statement; +use crate::transaction_retry_policy::is_aborted; use std::sync::Arc; +use std::sync::Mutex; use std::sync::atomic::{AtomicI64, Ordering}; /// A builder for [ReadWriteTransaction]. @@ -42,6 +44,7 @@ pub(crate) struct ReadWriteTransactionBuilder { client: DatabaseClient, options: TransactionOptions, transaction_tag: Option, + explicit_begin: bool, } impl ReadWriteTransactionBuilder { @@ -50,6 +53,7 @@ impl ReadWriteTransactionBuilder { client, options: TransactionOptions::default().set_read_write(ReadWrite::default()), transaction_tag: None, + explicit_begin: false, } } @@ -83,24 +87,58 @@ impl ReadWriteTransactionBuilder { self } - pub(crate) async fn begin_transaction(&self) -> crate::Result { - let mut request = BeginTransactionRequest::default() - .set_session(self.client.session.name.clone()) - .set_options(self.options.clone()); - if let Some(tag) = &self.transaction_tag { - request = request.set_request_options( - crate::model::RequestOptions::default().set_transaction_tag(tag.clone()), - ); - } + /// Sets whether the transaction should be explicitly started using a `BeginTransaction` RPC. + /// + /// By default, the Spanner client will inline the `BeginTransaction` call with the first query + /// or DML statement in the transaction. This reduces the number of round-trips to Spanner that + /// are needed for a transaction. Setting this option to `true` can be beneficial for specific + /// transaction shapes: + /// + /// 1. When the transaction executes multiple parallel queries at the start of the transaction. + /// Only one query can include a `BeginTransaction` option, and all other queries must wait for + /// the first query to return the first result before they can proceed to execute. A + /// `BeginTransaction` RPC will quickly return a transaction ID and allow all queries to start + /// execution in parallel once the transaction ID has been returned. + /// 2. When the first statement in the transaction could fail. If the statement fails, then it + /// will also not start a transaction and return a transaction ID. The transaction will then + /// fall back to executing a `BeginTransaction` RPC and retry the first statement. + /// + /// Default is `false` (inline begin). + pub fn with_explicit_begin_transaction(mut self, explicit: bool) -> Self { + self.explicit_begin = explicit; + self + } - // TODO(#4972): make request options configurable - let response = self - .client - .spanner - .begin_transaction(request, RequestOptions::default()) - .await?; + pub(crate) async fn build(&self) -> crate::Result { + let transaction_selector = if self.explicit_begin { + let mut request = BeginTransactionRequest::default() + .set_session(self.client.session.name.clone()) + .set_options(self.options.clone()); + if let Some(tag) = &self.transaction_tag { + request = request.set_request_options( + crate::model::RequestOptions::default().set_transaction_tag(tag.clone()), + ); + } + + // TODO(#4972): make request options configurable + let response = self + .client + .spanner + .begin_transaction(request, RequestOptions::default()) + .await?; + + crate::read_only_transaction::ReadContextTransactionSelector::Fixed( + TransactionSelector::default().set_id(response.id), + None, + ) + } else { + crate::read_only_transaction::ReadContextTransactionSelector::Lazy(Arc::new( + Mutex::new(crate::read_only_transaction::TransactionState::NotStarted( + self.options.clone(), + )), + )) + }; - let transaction_selector = TransactionSelector::default().set_id(response.id); Ok(ReadWriteTransaction { context: ReadContext { client: self.client.clone(), @@ -120,6 +158,64 @@ pub struct ReadWriteTransaction { seqno: Arc, } +/// Helper macro to execute a DML or BatchDML RPC with retry logic if the +/// request included a BeginTransaction option. +macro_rules! execute_with_retry { + ($self:expr, $request:ident, $rpc_method:ident, $extract_id:expr) => {{ + let is_starting = matches!( + $request + .transaction + .as_ref() + .and_then(|t| t.selector.as_ref()), + Some(Selector::Begin(_)) + ); + + let response_result = $self + .context + .client + .spanner + .$rpc_method($request.clone(), RequestOptions::default()) + .await; + + let response = match response_result { + Ok(response) => { + if is_starting { + let id = $extract_id(&response).ok_or_else(|| { + crate::error::internal_error("Transaction ID was not returned by Spanner") + })?; + $self.context.transaction_selector.update(id, None)?; + } + response + } + Err(error) => { + if !is_starting { + return Err(error); + } + if is_aborted(&error) { + return Err(error); + } + + $self + .context + .transaction_selector + .begin_explicitly(&$self.context.client) + .await?; + + $request.transaction = Some($self.context.transaction_selector.selector().await?); + + $self + .context + .client + .spanner + .$rpc_method($request.clone(), RequestOptions::default()) + .await? + } + }; + + response + }}; +} + impl ReadWriteTransaction { /// Executes a query using this transaction. pub async fn execute_query>( @@ -144,16 +240,23 @@ impl ReadWriteTransaction { .into() .into_request() .set_session(self.context.client.session.name.clone()) - .set_transaction(self.context.transaction_selector.clone()) + .set_transaction(self.context.transaction_selector.selector().await?) .set_seqno(seqno); request.request_options = self.context.amend_request_options(request.request_options); - let response = self - .context - .client - .spanner - .execute_sql(request, RequestOptions::default()) - .await?; + let response = execute_with_retry!( + self, + request, + execute_sql, + |response: &crate::model::ResultSet| { + response + .metadata + .as_ref() + .and_then(|md| md.transaction.as_ref()) + .map(|t| t.id.clone()) + } + ); + self.context .precommit_token_tracker .update(response.precommit_token); @@ -237,41 +340,43 @@ impl ReadWriteTransaction { pub async fn execute_batch_update(&self, batch: BatchDml) -> crate::Result> { let seqno = self.seqno.fetch_add(1, Ordering::SeqCst); - let statements: Vec = batch - .statements + let BatchDml { + statements, + request_options, + } = batch; + let statements: Vec = statements .into_iter() .map(|stmt: crate::statement::Statement| stmt.into_batch_statement()) .collect(); - let request = ExecuteBatchDmlRequest::default() + let mut request = ExecuteBatchDmlRequest::default() .set_session(self.context.client.session.name.clone()) - .set_transaction(self.context.transaction_selector.clone()) + .set_transaction(self.context.transaction_selector.selector().await?) .set_seqno(seqno) .set_statements(statements) - .set_or_clear_request_options( - self.context.amend_request_options(batch.request_options), - ); - - let response_result = self - .context - .client - .spanner - .execute_batch_dml(request, RequestOptions::default()) - .await; + .set_or_clear_request_options(self.context.amend_request_options(request_options)); - match response_result { - Ok(response) => { - self.context - .precommit_token_tracker - .update(response.precommit_token.clone()); - crate::batch_dml::process_response(response) + let response = execute_with_retry!( + self, + request, + execute_batch_dml, + |response: &crate::model::ExecuteBatchDmlResponse| { + response + .result_sets + .first() + .and_then(|rs| rs.metadata.as_ref()) + .and_then(|md| md.transaction.as_ref()) + .map(|t| t.id.clone()) } - Err(e) => Err(e), - } + ); + self.context + .precommit_token_tracker + .update(response.precommit_token.clone()); + crate::batch_dml::process_response(response) } - pub(crate) fn transaction_id(&self) -> crate::Result { - match &self.context.transaction_selector.selector { + pub(crate) async fn transaction_id(&self) -> crate::Result { + match &self.context.transaction_selector.selector().await?.selector { Some(Selector::Id(id)) => Ok(id.clone()), _ => Err(internal_error("Transaction ID is missing")), } @@ -279,7 +384,7 @@ impl ReadWriteTransaction { /// Commits the transaction. pub(crate) async fn commit(self) -> crate::Result { - let transaction_id = self.transaction_id()?; + let transaction_id = self.transaction_id().await?; let precommit_token = self.context.precommit_token_tracker.get(); let request = CommitRequest::default() .set_session(self.context.client.session.name.clone()) @@ -319,7 +424,7 @@ impl ReadWriteTransaction { /// Rolls back the transaction. pub(crate) async fn rollback(self) -> crate::Result<()> { - let transaction_id = self.transaction_id()?; + let transaction_id = self.transaction_id().await?; let request = RollbackRequest::default() .set_session(self.context.client.session.name.clone()) @@ -343,6 +448,9 @@ mod tests { use gaxi::grpc::tonic; use spanner_grpc_mock::google::spanner::v1; use std::fmt::Debug; + use v1::result_set_stats::RowCount; + use v1::transaction_options::Mode; + use v1::transaction_selector::Selector; #[test] fn auto_traits() { @@ -351,28 +459,61 @@ mod tests { } #[tokio::test] - async fn read_write_transaction_commit_retry() { + async fn read_write_transaction_commit_retry_explicit() -> anyhow::Result<()> { + run_read_write_transaction_commit_retry(true).await + } + + #[tokio::test] + async fn read_write_transaction_commit_retry_inline() -> anyhow::Result<()> { + run_read_write_transaction_commit_retry(false).await + } + + async fn run_read_write_transaction_commit_retry(explicit_begin: bool) -> anyhow::Result<()> { let mut mock = create_session_mock(); - mock.expect_begin_transaction().once().returning(|req| { - let req = req.into_inner(); - assert_eq!( - req.session, - "projects/p/instances/i/databases/d/sessions/123" - ); - Ok(tonic::Response::new(v1::Transaction { - id: vec![0, 0, 7], - ..Default::default() - })) - }); + if explicit_begin { + mock.expect_begin_transaction().once().returning(|req| { + let req = req.into_inner(); + assert_eq!( + req.session, + "projects/p/instances/i/databases/d/sessions/123" + ); + Ok(tonic::Response::new(v1::Transaction { + id: vec![0, 0, 7], + ..Default::default() + })) + }); + } // execute_update returns a precommit token. - mock.expect_execute_sql().once().returning(|req| { + mock.expect_execute_sql().once().returning(move |req| { let req = req.into_inner(); assert_eq!(req.sql, "UPDATE Users SET Name = 'Bob' WHERE Id = 1"); + + if !explicit_begin { + let transaction = req + .transaction + .as_ref() + .expect("transaction options required for inline begin"); + let selector = transaction.selector.as_ref().expect("selector required"); + assert!(matches!(selector, Selector::Begin(_))); + } + + let mut metadata = v1::ResultSetMetadata { + row_type: Some(v1::StructType { fields: vec![] }), + ..Default::default() + }; + if !explicit_begin { + metadata.transaction = Some(v1::Transaction { + id: vec![0, 0, 7], + ..Default::default() + }); + } + Ok(tonic::Response::new(v1::ResultSet { + metadata: Some(metadata), stats: Some(v1::ResultSetStats { - row_count: Some(v1::result_set_stats::RowCount::RowCountExact(1)), + row_count: Some(RowCount::RowCountExact(1)), ..Default::default() }), precommit_token: Some(v1::MultiplexedSessionPrecommitToken { @@ -383,94 +524,135 @@ mod tests { })) }); + let mut seq = mockall::Sequence::new(); + // Simulate that commit returns a precommit token in the response. // This would normally not happen, but we test it here to verify // that the commit is retried. - mock.expect_commit().once().returning(|req| { - let req = req.into_inner(); - assert_eq!( - req.precommit_token, - Some(v1::MultiplexedSessionPrecommitToken { - precommit_token: vec![101], - seq_num: 1, - }) - ); - Ok(tonic::Response::new(v1::CommitResponse { - commit_timestamp: Some(prost_types::Timestamp { - seconds: 1000, - nanos: 0, - }), - multiplexed_session_retry: Some( - v1::commit_response::MultiplexedSessionRetry::PrecommitToken( - v1::MultiplexedSessionPrecommitToken { - precommit_token: vec![202], - seq_num: 2, - }, + mock.expect_commit() + .once() + .in_sequence(&mut seq) + .returning(|req| { + let req = req.into_inner(); + assert_eq!( + req.precommit_token, + Some(v1::MultiplexedSessionPrecommitToken { + precommit_token: vec![101], + seq_num: 1, + }) + ); + Ok(tonic::Response::new(v1::CommitResponse { + commit_timestamp: Some(prost_types::Timestamp { + seconds: 1000, + nanos: 0, + }), + multiplexed_session_retry: Some( + v1::commit_response::MultiplexedSessionRetry::PrecommitToken( + v1::MultiplexedSessionPrecommitToken { + precommit_token: vec![202], + seq_num: 2, + }, + ), ), - ), - ..Default::default() - })) - }); + ..Default::default() + })) + }); // Second commit retry is automatically issued with the new token - mock.expect_commit().once().returning(|req| { - let req = req.into_inner(); - assert_eq!( - req.precommit_token, - Some(v1::MultiplexedSessionPrecommitToken { - precommit_token: vec![202], - seq_num: 2, - }) - ); - Ok(tonic::Response::new(v1::CommitResponse { - commit_timestamp: Some(prost_types::Timestamp { - seconds: 1001, - nanos: 0, - }), - ..Default::default() - })) - }); + mock.expect_commit() + .once() + .in_sequence(&mut seq) + .returning(|req| { + let req = req.into_inner(); + assert_eq!( + req.precommit_token, + Some(v1::MultiplexedSessionPrecommitToken { + precommit_token: vec![202], + seq_num: 2, + }) + ); + Ok(tonic::Response::new(v1::CommitResponse { + commit_timestamp: Some(prost_types::Timestamp { + seconds: 1001, + nanos: 0, + }), + ..Default::default() + })) + }); let (db_client, _server) = setup_db_client(mock).await; let tx = ReadWriteTransactionBuilder::new(db_client.clone()) - .begin_transaction() - .await - .expect("Failed to build transaction"); + .with_explicit_begin_transaction(explicit_begin) + .build() + .await?; let count = tx .execute_update("UPDATE Users SET Name = 'Bob' WHERE Id = 1") - .await - .unwrap(); + .await?; assert_eq!(count, 1); - let timestamp = tx.commit().await.unwrap(); + let timestamp = tx.commit().await?; assert_eq!(timestamp.seconds(), 1001); + Ok(()) } #[tokio::test] - async fn read_write_transaction_execute_update() { + async fn read_write_transaction_execute_update_explicit() { + run_read_write_transaction_execute_update(true).await; + } + + #[tokio::test] + async fn read_write_transaction_execute_update_inline() { + run_read_write_transaction_execute_update(false).await; + } + + async fn run_read_write_transaction_execute_update(explicit_begin: bool) { let mut mock = create_session_mock(); - mock.expect_begin_transaction().once().returning(|req| { - let req = req.into_inner(); - assert_eq!( - req.session, - "projects/p/instances/i/databases/d/sessions/123" - ); - Ok(tonic::Response::new(v1::Transaction { - id: vec![1, 2, 3], - ..Default::default() - })) - }); + if explicit_begin { + mock.expect_begin_transaction().once().returning(|req| { + let req = req.into_inner(); + assert_eq!( + req.session, + "projects/p/instances/i/databases/d/sessions/123" + ); + Ok(tonic::Response::new(v1::Transaction { + id: vec![1, 2, 3], + ..Default::default() + })) + }); + } - mock.expect_execute_sql().once().returning(|req| { + mock.expect_execute_sql().once().returning(move |req| { let req = req.into_inner(); assert_eq!(req.sql, "UPDATE Users SET Name = 'Alice' WHERE Id = 1"); assert_eq!(req.seqno, 1); + + if !explicit_begin { + let transaction = req + .transaction + .as_ref() + .expect("transaction options required for inline begin"); + let selector = transaction.selector.as_ref().expect("selector required"); + assert!(matches!(selector, Selector::Begin(_))); + } + + let mut metadata = v1::ResultSetMetadata { + row_type: Some(v1::StructType { fields: vec![] }), + ..Default::default() + }; + if !explicit_begin { + metadata.transaction = Some(v1::Transaction { + id: vec![1, 2, 3], + ..Default::default() + }); + } + Ok(tonic::Response::new(v1::ResultSet { + metadata: Some(metadata), stats: Some(v1::ResultSetStats { - row_count: Some(v1::result_set_stats::RowCount::RowCountExact(1)), + row_count: Some(RowCount::RowCountExact(1)), ..Default::default() }), ..Default::default() @@ -501,7 +683,8 @@ mod tests { let (db_client, _server) = setup_db_client(mock).await; let tx = ReadWriteTransactionBuilder::new(db_client.clone()) - .begin_transaction() + .with_explicit_begin_transaction(explicit_begin) + .build() .await .expect("Failed to build transaction"); let count = tx @@ -515,20 +698,55 @@ mod tests { } #[tokio::test] - async fn read_write_transaction_execute_update_invalid_stats() { + async fn read_write_transaction_execute_update_invalid_stats_explicit() -> anyhow::Result<()> { + run_read_write_transaction_execute_update_invalid_stats(true).await + } + + #[tokio::test] + async fn read_write_transaction_execute_update_invalid_stats_inline() -> anyhow::Result<()> { + run_read_write_transaction_execute_update_invalid_stats(false).await + } + + async fn run_read_write_transaction_execute_update_invalid_stats( + explicit_begin: bool, + ) -> anyhow::Result<()> { let mut mock = create_session_mock(); - mock.expect_begin_transaction().once().returning(|_| { - Ok(tonic::Response::new(v1::Transaction { - id: vec![1, 2, 3], + if explicit_begin { + mock.expect_begin_transaction().once().returning(|_| { + Ok(tonic::Response::new(v1::Transaction { + id: vec![1, 2, 3], + ..Default::default() + })) + }); + } + + mock.expect_execute_sql().once().returning(move |req| { + let req = req.into_inner(); + if !explicit_begin { + let transaction = req + .transaction + .as_ref() + .expect("transaction options required for inline begin"); + let selector = transaction.selector.as_ref().expect("selector required"); + assert!(matches!(selector, Selector::Begin(_))); + } + + let mut metadata = v1::ResultSetMetadata { + row_type: Some(v1::StructType { fields: vec![] }), ..Default::default() - })) - }); + }; + if !explicit_begin { + metadata.transaction = Some(v1::Transaction { + id: vec![1, 2, 3], + ..Default::default() + }); + } - mock.expect_execute_sql().once().returning(|_| { Ok(tonic::Response::new(v1::ResultSet { + metadata: Some(metadata), stats: Some(v1::ResultSetStats { - row_count: Some(v1::result_set_stats::RowCount::RowCountLowerBound(1)), + row_count: Some(RowCount::RowCountLowerBound(1)), ..Default::default() }), ..Default::default() @@ -538,9 +756,9 @@ mod tests { let (db_client, _server) = setup_db_client(mock).await; let tx = ReadWriteTransactionBuilder::new(db_client.clone()) - .begin_transaction() - .await - .expect("Failed to build transaction"); + .with_explicit_begin_transaction(explicit_begin) + .build() + .await?; let result = tx .execute_update("UPDATE Users SET Name = 'Alice' WHERE Id = 1") @@ -552,98 +770,178 @@ mod tests { "Error did not contain expected message: {:?}", err ); + Ok(()) + } + + #[tokio::test] + async fn read_write_transaction_rollback_explicit() -> anyhow::Result<()> { + run_read_write_transaction_rollback(true).await } #[tokio::test] - async fn read_write_transaction_rollback() { + async fn read_write_transaction_rollback_inline() -> anyhow::Result<()> { + run_read_write_transaction_rollback(false).await + } + + async fn run_read_write_transaction_rollback(explicit_begin: bool) -> anyhow::Result<()> { let mut mock = create_session_mock(); - mock.expect_begin_transaction().once().returning(|req| { - let req = req.into_inner(); - assert_eq!( - req.session, - "projects/p/instances/i/databases/d/sessions/123" - ); - Ok(tonic::Response::new(v1::Transaction { - id: vec![9, 9, 9], - ..Default::default() - })) - }); + let transaction_id = vec![9, 9, 9]; + + if explicit_begin { + let id = transaction_id.clone(); + mock.expect_begin_transaction().once().returning(move |_| { + Ok(tonic::Response::new(v1::Transaction { + id: id.clone(), + ..Default::default() + })) + }); + } else { + let id = transaction_id.clone(); + mock.expect_execute_sql().once().returning(move |req| { + let req = req.into_inner(); + let transaction = req + .transaction + .as_ref() + .expect("transaction options required for inline begin"); + let selector = transaction.selector.as_ref().expect("selector required"); + assert!(matches!(selector, Selector::Begin(_))); + + Ok(tonic::Response::new(v1::ResultSet { + metadata: Some(v1::ResultSetMetadata { + transaction: Some(v1::Transaction { + id: id.clone(), + ..Default::default() + }), + ..Default::default() + }), + stats: Some(v1::ResultSetStats { + row_count: Some(RowCount::RowCountExact(1)), + ..Default::default() + }), + ..Default::default() + })) + }); + } - mock.expect_rollback().once().returning(|req| { + let id = transaction_id.clone(); + mock.expect_rollback().once().returning(move |req| { let req = req.into_inner(); assert_eq!( req.session, "projects/p/instances/i/databases/d/sessions/123" ); - assert_eq!(req.transaction_id, vec![9, 9, 9]); + assert_eq!(req.transaction_id, id); Ok(tonic::Response::new(())) }); let (db_client, _server) = setup_db_client(mock).await; let tx = ReadWriteTransactionBuilder::new(db_client.clone()) - .begin_transaction() - .await - .expect("Failed to build transaction"); + .with_explicit_begin_transaction(explicit_begin) + .build() + .await?; + + if !explicit_begin { + tx.execute_update("UPDATE Users SET Name = 'Alice' WHERE Id = 1") + .await + .expect("Failed to execute update"); + } + + tx.rollback().await?; + Ok(()) + } - tx.rollback().await.expect("Failed to rollback"); + #[tokio::test] + async fn read_write_transaction_execute_batch_update_explicit() -> anyhow::Result<()> { + run_read_write_transaction_execute_batch_update(true).await } #[tokio::test] - async fn read_write_transaction_execute_batch_update() -> anyhow::Result<()> { + async fn read_write_transaction_execute_batch_update_inline() -> anyhow::Result<()> { + run_read_write_transaction_execute_batch_update(false).await + } + + async fn run_read_write_transaction_execute_batch_update( + explicit_begin: bool, + ) -> anyhow::Result<()> { let mut mock = create_session_mock(); - mock.expect_begin_transaction().once().returning(|_| { - Ok(tonic::Response::new(v1::Transaction { - id: vec![4, 5, 6], - ..Default::default() - })) - }); + if explicit_begin { + mock.expect_begin_transaction().once().returning(|_| { + Ok(tonic::Response::new(v1::Transaction { + id: vec![4, 5, 6], + ..Default::default() + })) + }); + } - mock.expect_execute_batch_dml().once().returning(|req| { - let req = req.into_inner(); - assert_eq!(req.statements.len(), 2); - assert_eq!( - req.statements[0].sql, - "UPDATE Users SET Name = 'Alice' WHERE Id = 1" - ); - assert_eq!( - req.statements[1].sql, - "UPDATE Users SET Name = 'Bob' WHERE Id = 2" - ); + mock.expect_execute_batch_dml() + .once() + .returning(move |req| { + let req = req.into_inner(); + assert_eq!(req.statements.len(), 2); + assert_eq!( + req.statements[0].sql, + "UPDATE Users SET Name = 'Alice' WHERE Id = 1" + ); + assert_eq!( + req.statements[1].sql, + "UPDATE Users SET Name = 'Bob' WHERE Id = 2" + ); + + if !explicit_begin { + let selector = req + .transaction + .expect("missing transaction selector") + .selector + .expect("missing selector"); + assert!(matches!(selector, Selector::Begin(_))); + } - Ok(tonic::Response::new(v1::ExecuteBatchDmlResponse { - result_sets: vec![ - v1::ResultSet { - stats: Some(v1::ResultSetStats { - row_count: Some(v1::result_set_stats::RowCount::RowCountExact(1)), - ..Default::default() - }), - ..Default::default() - }, - v1::ResultSet { - stats: Some(v1::ResultSetStats { - row_count: Some(v1::result_set_stats::RowCount::RowCountExact(1)), - ..Default::default() - }), + let mut metadata = v1::ResultSetMetadata { + ..Default::default() + }; + if !explicit_begin { + metadata.transaction = Some(v1::Transaction { + id: vec![4, 5, 6], ..Default::default() - }, - ], - status: Some(spanner_grpc_mock::google::rpc::Status { - code: 0, - message: "OK".into(), - details: vec![], - }), - ..Default::default() - })) - }); - - let (db_client, _server) = setup_db_client(mock).await; + }); + } - let tx = ReadWriteTransactionBuilder::new(db_client) - .begin_transaction() - .await?; + Ok(tonic::Response::new(v1::ExecuteBatchDmlResponse { + result_sets: vec![ + v1::ResultSet { + metadata: Some(metadata), + stats: Some(v1::ResultSetStats { + row_count: Some(RowCount::RowCountExact(1)), + ..Default::default() + }), + ..Default::default() + }, + v1::ResultSet { + stats: Some(v1::ResultSetStats { + row_count: Some(RowCount::RowCountExact(1)), + ..Default::default() + }), + ..Default::default() + }, + ], + status: Some(spanner_grpc_mock::google::rpc::Status { + code: 0, + message: "OK".into(), + details: vec![], + }), + ..Default::default() + })) + }); + + let (db_client, _server) = setup_db_client(mock).await; + + let tx = ReadWriteTransactionBuilder::new(db_client) + .with_explicit_begin_transaction(explicit_begin) + .build() + .await?; let batch = BatchDml::builder() .add_statement("UPDATE Users SET Name = 'Alice' WHERE Id = 1") @@ -656,38 +954,77 @@ mod tests { } #[tokio::test] - async fn read_write_transaction_execute_batch_update_partial_failure() -> anyhow::Result<()> { + async fn read_write_transaction_execute_batch_update_partial_failure_explicit() + -> anyhow::Result<()> { + run_read_write_transaction_execute_batch_update_partial_failure(true).await + } + + #[tokio::test] + async fn read_write_transaction_execute_batch_update_partial_failure_inline() + -> anyhow::Result<()> { + run_read_write_transaction_execute_batch_update_partial_failure(false).await + } + + async fn run_read_write_transaction_execute_batch_update_partial_failure( + explicit_begin: bool, + ) -> anyhow::Result<()> { let mut mock = create_session_mock(); - mock.expect_begin_transaction().once().returning(|_| { - Ok(tonic::Response::new(v1::Transaction { - id: vec![7, 8, 9], - ..Default::default() - })) - }); + if explicit_begin { + mock.expect_begin_transaction().once().returning(|_| { + Ok(tonic::Response::new(v1::Transaction { + id: vec![7, 8, 9], + ..Default::default() + })) + }); + } - mock.expect_execute_batch_dml().once().returning(|_| { - Ok(tonic::Response::new(v1::ExecuteBatchDmlResponse { - result_sets: vec![v1::ResultSet { - stats: Some(v1::ResultSetStats { - row_count: Some(v1::result_set_stats::RowCount::RowCountExact(1)), + mock.expect_execute_batch_dml() + .once() + .returning(move |req| { + let req = req.into_inner(); + if !explicit_begin { + let selector = req + .transaction + .expect("missing transaction selector") + .selector + .expect("missing selector"); + assert!(matches!(selector, Selector::Begin(_))); + } + + let mut metadata = v1::ResultSetMetadata { + ..Default::default() + }; + if !explicit_begin { + metadata.transaction = Some(v1::Transaction { + id: vec![7, 8, 9], ..Default::default() + }); + } + + Ok(tonic::Response::new(v1::ExecuteBatchDmlResponse { + result_sets: vec![v1::ResultSet { + metadata: Some(metadata), + stats: Some(v1::ResultSetStats { + row_count: Some(RowCount::RowCountExact(1)), + ..Default::default() + }), + ..Default::default() + }], + status: Some(spanner_grpc_mock::google::rpc::Status { + code: gaxi::grpc::tonic::Code::AlreadyExists as i32, + message: "row already exists".into(), + details: vec![], }), ..Default::default() - }], - status: Some(spanner_grpc_mock::google::rpc::Status { - code: gaxi::grpc::tonic::Code::AlreadyExists as i32, - message: "row already exists".into(), - details: vec![], - }), - ..Default::default() - })) - }); + })) + }); let (db_client, _server) = setup_db_client(mock).await; let tx = ReadWriteTransactionBuilder::new(db_client) - .begin_transaction() + .with_explicit_begin_transaction(explicit_begin) + .build() .await?; let batch = BatchDml::builder() @@ -711,94 +1048,236 @@ mod tests { } #[tokio::test] - async fn read_write_transaction_execute_multiple_updates() { + async fn read_write_transaction_execute_multiple_updates_explicit() -> anyhow::Result<()> { + run_read_write_transaction_execute_multiple_updates(true).await + } + + #[tokio::test] + async fn read_write_transaction_execute_multiple_updates_inline() -> anyhow::Result<()> { + run_read_write_transaction_execute_multiple_updates(false).await + } + + async fn run_read_write_transaction_execute_multiple_updates( + explicit_begin: bool, + ) -> anyhow::Result<()> { let mut mock = create_session_mock(); - mock.expect_begin_transaction().once().returning(|req| { - let req = req.into_inner(); - assert_eq!( - req.session, - "projects/p/instances/i/databases/d/sessions/123" - ); - Ok(tonic::Response::new(v1::Transaction { - id: vec![4, 5, 6], - ..Default::default() - })) - }); + if explicit_begin { + mock.expect_begin_transaction().once().returning(|req| { + let req = req.into_inner(); + assert_eq!( + req.session, + "projects/p/instances/i/databases/d/sessions/123" + ); + Ok(tonic::Response::new(v1::Transaction { + id: vec![4, 5, 6], + ..Default::default() + })) + }); + } - let counter = Arc::new(AtomicI64::new(1)); - mock.expect_execute_sql().times(3).returning(move |req| { - let req = req.into_inner(); - assert_eq!(req.sql, "UPDATE Users SET Name = 'Alice' WHERE Id = 1"); - let c = counter.fetch_add(1, Ordering::SeqCst); - assert_eq!(req.seqno, c); + let mut seq = mockall::Sequence::new(); - Ok(tonic::Response::new(v1::ResultSet { - stats: Some(v1::ResultSetStats { - row_count: Some(v1::result_set_stats::RowCount::RowCountExact(1)), + // First update + mock.expect_execute_sql() + .once() + .in_sequence(&mut seq) + .returning(move |req| { + let req = req.into_inner(); + assert_eq!(req.sql, "UPDATE Users SET Name = 'Alice' WHERE Id = 1"); + assert_eq!(req.seqno, 1); + + let mut metadata = v1::ResultSetMetadata { ..Default::default() - }), - ..Default::default() - })) - }); + }; + + if !explicit_begin { + let selector = req + .transaction + .expect("missing transaction selector") + .selector + .expect("missing selector"); + assert!(matches!(selector, Selector::Begin(_))); + metadata.transaction = Some(v1::Transaction { + id: vec![4, 5, 6], + ..Default::default() + }); + } else { + let selector = req + .transaction + .expect("missing transaction selector") + .selector + .expect("missing selector"); + match selector { + Selector::Id(id) => { + assert_eq!(id, vec![4, 5, 6]); + } + _ => panic!("Expected Selector::Id"), + } + } + + Ok(tonic::Response::new(v1::ResultSet { + metadata: Some(metadata), + stats: Some(v1::ResultSetStats { + row_count: Some(RowCount::RowCountExact(1)), + ..Default::default() + }), + ..Default::default() + })) + }); + + // Second update + mock.expect_execute_sql() + .once() + .in_sequence(&mut seq) + .returning(move |req| { + let req = req.into_inner(); + assert_eq!(req.sql, "UPDATE Users SET Name = 'Alice' WHERE Id = 1"); + assert_eq!(req.seqno, 2); + + let selector = req + .transaction + .expect("missing transaction selector") + .selector + .expect("missing selector"); + match selector { + Selector::Id(id) => { + assert_eq!(id, vec![4, 5, 6]); + } + _ => panic!("Expected Selector::Id"), + } + + Ok(tonic::Response::new(v1::ResultSet { + stats: Some(v1::ResultSetStats { + row_count: Some(RowCount::RowCountExact(1)), + ..Default::default() + }), + ..Default::default() + })) + }); + + // Third update + mock.expect_execute_sql() + .once() + .in_sequence(&mut seq) + .returning(move |req| { + let req = req.into_inner(); + assert_eq!(req.sql, "UPDATE Users SET Name = 'Alice' WHERE Id = 1"); + assert_eq!(req.seqno, 3); + + let selector = req + .transaction + .expect("missing transaction selector") + .selector + .expect("missing selector"); + match selector { + Selector::Id(id) => { + assert_eq!(id, vec![4, 5, 6]); + } + _ => panic!("Expected Selector::Id"), + } + + Ok(tonic::Response::new(v1::ResultSet { + stats: Some(v1::ResultSetStats { + row_count: Some(RowCount::RowCountExact(1)), + ..Default::default() + }), + ..Default::default() + })) + }); let (db_client, _server) = setup_db_client(mock).await; let tx = ReadWriteTransactionBuilder::new(db_client.clone()) - .begin_transaction() - .await - .expect("Failed to build transaction"); + .with_explicit_begin_transaction(explicit_begin) + .build() + .await?; for i in 1..=3 { let count = tx .execute_update("UPDATE Users SET Name = 'Alice' WHERE Id = 1") .await - .unwrap_or_else(|_| panic!("Failed to execute update {}", i)); + .map_err(|e| anyhow::anyhow!("Failed to execute update {}: {:?}", i, e))?; assert_eq!(count, 1); } + Ok(()) + } + + #[tokio::test] + async fn read_write_transaction_execute_query_explicit() -> anyhow::Result<()> { + run_read_write_transaction_execute_query(true).await } #[tokio::test] - async fn read_write_transaction_execute_query() { + async fn read_write_transaction_execute_query_inline() -> anyhow::Result<()> { + run_read_write_transaction_execute_query(false).await + } + + async fn run_read_write_transaction_execute_query(explicit_begin: bool) -> anyhow::Result<()> { use crate::client::Statement; let mut mock = create_session_mock(); - mock.expect_begin_transaction().once().returning(|req| { - let req = req.into_inner(); - assert_eq!( - req.session, - "projects/p/instances/i/databases/d/sessions/123" - ); - Ok(tonic::Response::new(v1::Transaction { - id: vec![7, 8, 9], - ..Default::default() - })) - }); + if explicit_begin { + mock.expect_begin_transaction().once().returning(|req| { + let req = req.into_inner(); + assert_eq!( + req.session, + "projects/p/instances/i/databases/d/sessions/123" + ); + Ok(tonic::Response::new(v1::Transaction { + id: vec![7, 8, 9], + ..Default::default() + })) + }); + } - mock.expect_execute_streaming_sql().once().returning(|req| { + mock.expect_execute_streaming_sql().once().returning(move |req| { let req = req.into_inner(); assert_eq!(req.sql, "SELECT 1"); // Queries do not need to include a sequence number. assert_eq!(req.seqno, 0); - assert_eq!( - req.transaction, - Some(v1::TransactionSelector { - selector: Some(v1::transaction_selector::Selector::Id(vec![7, 8, 9])) - }) - ); + if !explicit_begin { + let transaction = req.transaction.as_ref().expect("transaction options required for inline begin"); + let selector = transaction.selector.as_ref().expect("selector required"); + assert!(matches!(selector, Selector::Begin(_))); + } else { + assert_eq!( + req.transaction, + Some(v1::TransactionSelector { + selector: Some(Selector::Id(vec![7, 8, 9])) + }) + ); + } type StreamType = ::ExecuteStreamingSqlStream; - let stream: tokio_stream::Empty> = tokio_stream::empty(); + + let mut metadata = v1::ResultSetMetadata { + row_type: Some(v1::StructType { fields: vec![] }), + ..Default::default() + }; + if !explicit_begin { + metadata.transaction = Some(v1::Transaction { + id: vec![7, 8, 9], + ..Default::default() + }); + } + + let first_response = v1::PartialResultSet { + metadata: Some(metadata), + ..Default::default() + }; + + let stream = tokio_stream::iter(vec![Ok(first_response)]); Ok(tonic::Response::new(Box::pin(stream) as StreamType)) }); let (db_client, _server) = setup_db_client(mock).await; let tx = ReadWriteTransactionBuilder::new(db_client.clone()) - .begin_transaction() - .await - .expect("Failed to build transaction"); + .with_explicit_begin_transaction(explicit_begin) + .build() + .await?; let mut rs = tx .execute_query(Statement::builder("SELECT 1").build()) @@ -807,6 +1286,7 @@ mod tests { let result = rs.next().await; assert!(result.is_none(), "expected None, got empty stream"); + Ok(()) } #[tokio::test] @@ -823,7 +1303,7 @@ mod tests { let options = req.options.expect("missing transaction options"); let mode = options.mode.expect("missing mode"); match mode { - v1::transaction_options::Mode::ReadWrite(rw) => { + Mode::ReadWrite(rw) => { assert_eq!( rw.read_lock_mode, v1::transaction_options::read_write::ReadLockMode::Pessimistic as i32 @@ -848,44 +1328,148 @@ mod tests { let _tx = ReadWriteTransactionBuilder::new(db_client.clone()) .with_isolation_level(IsolationLevel::Serializable) .with_read_lock_mode(ReadLockMode::Pessimistic) - .begin_transaction() + .build() .await .expect("Failed to build transaction"); } #[tokio::test] - async fn read_write_transaction_tracks_highest_precommit_token() { + async fn read_write_transaction_tracks_highest_precommit_token_explicit() -> anyhow::Result<()> + { + run_read_write_transaction_tracks_highest_precommit_token(true).await + } + + #[tokio::test] + async fn read_write_transaction_tracks_highest_precommit_token_inline() -> anyhow::Result<()> { + run_read_write_transaction_tracks_highest_precommit_token(false).await + } + + async fn run_read_write_transaction_tracks_highest_precommit_token( + explicit_begin: bool, + ) -> anyhow::Result<()> { let mut mock = create_session_mock(); - mock.expect_begin_transaction().once().returning(|_| { - Ok(tonic::Response::new(v1::Transaction { - id: vec![4, 2], - ..Default::default() - })) - }); + if explicit_begin { + mock.expect_begin_transaction().once().returning(|_| { + Ok(tonic::Response::new(v1::Transaction { + id: vec![4, 2], + ..Default::default() + })) + }); + } - // 3 sequential updates returning tokens [seq 2, seq 5, seq 3] - let tokens_iter = vec![2, 5, 3].into_iter(); - let counter_mutex = std::sync::Mutex::new(tokens_iter); + let mut seq = mockall::Sequence::new(); - mock.expect_execute_sql().times(3).returning(move |_req| { - let seq = counter_mutex - .lock() - .expect("Failed to lock mutex") - .next() - .expect("Failed to get next token"); - Ok(tonic::Response::new(v1::ResultSet { - stats: Some(v1::ResultSetStats { - row_count: Some(v1::result_set_stats::RowCount::RowCountExact(1)), + // First update + mock.expect_execute_sql() + .once() + .in_sequence(&mut seq) + .returning(move |req| { + let req = req.into_inner(); + let mut metadata = v1::ResultSetMetadata { ..Default::default() - }), - precommit_token: Some(v1::MultiplexedSessionPrecommitToken { - precommit_token: vec![seq as u8], - seq_num: seq, - }), - ..Default::default() - })) - }); + }; + + if !explicit_begin { + let selector = req + .transaction + .expect("missing transaction selector") + .selector + .expect("missing selector"); + assert!(matches!(selector, Selector::Begin(_))); + metadata.transaction = Some(v1::Transaction { + id: vec![4, 2], + ..Default::default() + }); + } else { + let selector = req + .transaction + .expect("missing transaction selector") + .selector + .expect("missing selector"); + match selector { + Selector::Id(id) => { + assert_eq!(id, vec![4, 2]); + } + _ => panic!("Expected Selector::Id"), + } + } + + Ok(tonic::Response::new(v1::ResultSet { + metadata: Some(metadata), + stats: Some(v1::ResultSetStats { + row_count: Some(RowCount::RowCountExact(1)), + ..Default::default() + }), + precommit_token: Some(v1::MultiplexedSessionPrecommitToken { + precommit_token: vec![2], + seq_num: 2, + }), + ..Default::default() + })) + }); + + // Second update + mock.expect_execute_sql() + .once() + .in_sequence(&mut seq) + .returning(move |req| { + let req = req.into_inner(); + let selector = req + .transaction + .expect("missing transaction selector") + .selector + .expect("missing selector"); + match selector { + Selector::Id(id) => { + assert_eq!(id, vec![4, 2]); + } + _ => panic!("Expected Selector::Id"), + } + + Ok(tonic::Response::new(v1::ResultSet { + stats: Some(v1::ResultSetStats { + row_count: Some(RowCount::RowCountExact(1)), + ..Default::default() + }), + precommit_token: Some(v1::MultiplexedSessionPrecommitToken { + precommit_token: vec![5], + seq_num: 5, + }), + ..Default::default() + })) + }); + + // Third update + mock.expect_execute_sql() + .once() + .in_sequence(&mut seq) + .returning(move |req| { + let req = req.into_inner(); + let selector = req + .transaction + .expect("missing transaction selector") + .selector + .expect("missing selector"); + match selector { + Selector::Id(id) => { + assert_eq!(id, vec![4, 2]); + } + _ => panic!("Expected Selector::Id"), + } + + Ok(tonic::Response::new(v1::ResultSet { + stats: Some(v1::ResultSetStats { + row_count: Some(RowCount::RowCountExact(1)), + ..Default::default() + }), + precommit_token: Some(v1::MultiplexedSessionPrecommitToken { + precommit_token: vec![3], + seq_num: 3, + }), + ..Default::default() + })) + }); // Commit should only use the highest token (seq 5) mock.expect_commit().once().returning(|req| { @@ -908,9 +1492,9 @@ mod tests { let (db_client, _server) = setup_db_client(mock).await; let tx = ReadWriteTransactionBuilder::new(db_client.clone()) - .begin_transaction() - .await - .expect("Failed to build transaction"); + .with_explicit_begin_transaction(explicit_begin) + .build() + .await?; for _ in 0..3 { tx.execute_update("UPDATE Y") @@ -919,74 +1503,424 @@ mod tests { } let ts = tx.commit().await.expect("Failed to commit transaction"); assert_eq!(ts.seconds(), 12345); + Ok(()) + } + + #[tokio::test] + async fn read_write_transaction_commit_retry_exactly_once_explicit() -> anyhow::Result<()> { + run_read_write_transaction_commit_retry_exactly_once(true).await } #[tokio::test] - async fn read_write_transaction_commit_retry_exactly_once() { + async fn read_write_transaction_commit_retry_exactly_once_inline() -> anyhow::Result<()> { + run_read_write_transaction_commit_retry_exactly_once(false).await + } + + async fn run_read_write_transaction_commit_retry_exactly_once( + explicit_begin: bool, + ) -> anyhow::Result<()> { let mut mock = create_session_mock(); - mock.expect_begin_transaction().once().returning(|_| { - Ok(tonic::Response::new(v1::Transaction { - id: vec![7, 7], + let transaction_id = vec![7, 7]; + + if explicit_begin { + let id = transaction_id.clone(); + mock.expect_begin_transaction().once().returning(move |_| { + Ok(tonic::Response::new(v1::Transaction { + id: id.clone(), + ..Default::default() + })) + }); + } else { + let id = transaction_id.clone(); + mock.expect_execute_sql().once().returning(move |req| { + let req = req.into_inner(); + let transaction = req + .transaction + .as_ref() + .expect("transaction options required for inline begin"); + let selector = transaction.selector.as_ref().expect("selector required"); + assert!(matches!(selector, Selector::Begin(_))); + + Ok(tonic::Response::new(v1::ResultSet { + metadata: Some(v1::ResultSetMetadata { + transaction: Some(v1::Transaction { + id: id.clone(), + ..Default::default() + }), + ..Default::default() + }), + stats: Some(v1::ResultSetStats { + row_count: Some(RowCount::RowCountExact(1)), + ..Default::default() + }), + ..Default::default() + })) + }); + } + + let mut seq = mockall::Sequence::new(); + + // Initial commit returns a retry token (seq 2) + mock.expect_commit() + .once() + .in_sequence(&mut seq) + .returning(|_| { + Ok(tonic::Response::new(v1::CommitResponse { + commit_timestamp: Some(prost_types::Timestamp { + seconds: 1000, + nanos: 0, + }), + multiplexed_session_retry: Some( + v1::commit_response::MultiplexedSessionRetry::PrecommitToken( + v1::MultiplexedSessionPrecommitToken { + precommit_token: vec![2], + seq_num: 2, + }, + ), + ), + ..Default::default() + })) + }); + + // Retry commit returns another retry token (seq 3). + // The library should not retry multiple times. + mock.expect_commit() + .once() + .in_sequence(&mut seq) + .returning(|req| { + let req = req.into_inner(); + assert_eq!( + req.precommit_token + .as_ref() + .expect("Missing precommit token in retry req") + .seq_num, + 2 + ); + + Ok(tonic::Response::new(v1::CommitResponse { + commit_timestamp: Some(prost_types::Timestamp { + seconds: 9999, + nanos: 0, + }), + multiplexed_session_retry: Some( + v1::commit_response::MultiplexedSessionRetry::PrecommitToken( + v1::MultiplexedSessionPrecommitToken { + precommit_token: vec![3], + seq_num: 3, + }, + ), + ), + ..Default::default() + })) + }); + + let (db_client, _server) = setup_db_client(mock).await; + let tx = ReadWriteTransactionBuilder::new(db_client.clone()) + .with_explicit_begin_transaction(explicit_begin) + .build() + .await?; + + if !explicit_begin { + tx.execute_update("UPDATE Users SET Name = 'Alice' WHERE Id = 1") + .await?; + } + + let ts = tx.commit().await.expect("Failed to commit transaction"); + assert_eq!(ts.seconds(), 9999); + Ok(()) + } + + #[tokio::test] + async fn read_write_transaction_execute_update_inline_begin() { + let mut mock = create_session_mock(); + + mock.expect_execute_sql().once().returning(|req| { + let req = req.into_inner(); + assert_eq!(req.sql, "UPDATE Users SET Name = 'Alice' WHERE Id = 1"); + assert_eq!(req.seqno, 1); + + let selector = req + .transaction + .expect("missing transaction selector") + .selector + .expect("missing selector"); + match selector { + Selector::Begin(options) => { + assert!(options.mode.is_some()); + } + _ => panic!("Expected Selector::Begin"), + } + + Ok(tonic::Response::new(v1::ResultSet { + metadata: Some(v1::ResultSetMetadata { + transaction: Some(v1::Transaction { + id: vec![7, 8, 9], + ..Default::default() + }), + ..Default::default() + }), + stats: Some(v1::ResultSetStats { + row_count: Some(RowCount::RowCountExact(1)), + ..Default::default() + }), ..Default::default() })) }); - // Initial commit returns a retry token (seq 2) - mock.expect_commit().once().returning(|_| { + mock.expect_commit().once().returning(|req| { + let req = req.into_inner(); + match req.transaction.expect("missing transaction") { + v1::commit_request::Transaction::TransactionId(id) => { + assert_eq!(id, vec![7, 8, 9]); + } + _ => panic!("Expected TransactionId"), + } Ok(tonic::Response::new(v1::CommitResponse { commit_timestamp: Some(prost_types::Timestamp { - seconds: 1000, + seconds: 123456789, nanos: 0, }), - multiplexed_session_retry: Some( - v1::commit_response::MultiplexedSessionRetry::PrecommitToken( - v1::MultiplexedSessionPrecommitToken { - precommit_token: vec![2], - seq_num: 2, - }, - ), - ), ..Default::default() })) }); - // Retry commit returns another retry token (seq 3). - // The library should not retry multiple times. - mock.expect_commit().once().returning(|req| { + let (db_client, _server) = setup_db_client(mock).await; + + let tx = ReadWriteTransactionBuilder::new(db_client.clone()) + .build() + .await + .expect("Failed to build transaction"); + + let count = tx + .execute_update("UPDATE Users SET Name = 'Alice' WHERE Id = 1") + .await + .expect("Failed to execute update"); + assert_eq!(count, 1); + + let ts = tx.commit().await.expect("Failed to commit"); + assert_eq!(ts.seconds(), 123456789); + } + + #[tokio::test] + async fn read_write_transaction_execute_batch_update_inline_begin() -> anyhow::Result<()> { + let mut mock = create_session_mock(); + + mock.expect_execute_batch_dml().once().returning(|req| { let req = req.into_inner(); - assert_eq!( - req.precommit_token - .as_ref() - .expect("Missing precommit token in retry req") - .seq_num, - 2 - ); + assert_eq!(req.statements.len(), 1); + + let selector = req + .transaction + .expect("missing transaction selector") + .selector + .expect("missing selector"); + match selector { + Selector::Begin(options) => { + assert!(options.mode.is_some()); + } + _ => panic!("Expected Selector::Begin"), + } + + Ok(tonic::Response::new(v1::ExecuteBatchDmlResponse { + result_sets: vec![v1::ResultSet { + metadata: Some(v1::ResultSetMetadata { + transaction: Some(v1::Transaction { + id: vec![4, 5, 6], + ..Default::default() + }), + ..Default::default() + }), + stats: Some(v1::ResultSetStats { + row_count: Some(RowCount::RowCountExact(1)), + ..Default::default() + }), + ..Default::default() + }], + status: Some(spanner_grpc_mock::google::rpc::Status { + code: 0, + message: "OK".into(), + details: vec![], + }), + ..Default::default() + })) + }); + mock.expect_commit().once().returning(|req| { + let req = req.into_inner(); + match req.transaction.expect("missing transaction") { + v1::commit_request::Transaction::TransactionId(id) => { + assert_eq!(id, vec![4, 5, 6]); + } + _ => panic!("Expected TransactionId"), + } Ok(tonic::Response::new(v1::CommitResponse { commit_timestamp: Some(prost_types::Timestamp { - seconds: 9999, + seconds: 123456789, nanos: 0, }), - multiplexed_session_retry: Some( - v1::commit_response::MultiplexedSessionRetry::PrecommitToken( - v1::MultiplexedSessionPrecommitToken { - precommit_token: vec![3], - seq_num: 3, - }, - ), - ), ..Default::default() })) }); let (db_client, _server) = setup_db_client(mock).await; + + let tx = ReadWriteTransactionBuilder::new(db_client).build().await?; + + let batch = + BatchDml::builder().add_statement("UPDATE Users SET Name = 'Alice' WHERE Id = 1"); + + let counts = tx.execute_batch_update(batch.build()).await?; + + assert_eq!(counts, vec![1]); + + let ts = tx.commit().await?; + assert_eq!(ts.seconds(), 123456789); + + Ok(()) + } + + #[tokio::test] + async fn read_write_transaction_execute_update_fallback() { + let mut mock = create_session_mock(); + + // 1. First DML attempt fails! + mock.expect_execute_sql().once().returning(|req| { + let req = req.into_inner(); + assert_eq!(req.sql, "UPDATE Users SET Name = 'Alice' WHERE Id = 1"); + + let selector = req + .transaction + .expect("missing transaction selector") + .selector + .expect("missing selector"); + match selector { + Selector::Begin(_) => {} + _ => panic!("Expected Selector::Begin"), + } + + Err(tonic::Status::new(tonic::Code::Internal, "internal error")) + }); + + // 2. Client falls back to explicit BeginTransaction! + mock.expect_begin_transaction().once().returning(|_| { + Ok(tonic::Response::new(v1::Transaction { + id: vec![7, 8, 9], + ..Default::default() + })) + }); + + // 3. Client retries DML with new ID! + mock.expect_execute_sql().once().returning(|req| { + let req = req.into_inner(); + assert_eq!(req.sql, "UPDATE Users SET Name = 'Alice' WHERE Id = 1"); + + let selector = req + .transaction + .expect("missing transaction selector") + .selector + .expect("missing selector"); + match selector { + Selector::Id(id) => { + assert_eq!(id, vec![7, 8, 9]); + } + _ => panic!("Expected Selector::Id"), + } + + Ok(tonic::Response::new(v1::ResultSet { + stats: Some(v1::ResultSetStats { + row_count: Some(RowCount::RowCountExact(1)), + ..Default::default() + }), + ..Default::default() + })) + }); + + let (db_client, _server) = setup_db_client(mock).await; + let tx = ReadWriteTransactionBuilder::new(db_client.clone()) - .begin_transaction() + .build() .await .expect("Failed to build transaction"); - let ts = tx.commit().await.expect("Failed to commit transaction"); - assert_eq!(ts.seconds(), 9999); + let count = tx + .execute_update("UPDATE Users SET Name = 'Alice' WHERE Id = 1") + .await + .expect("Failed to execute update after fallback"); + assert_eq!(count, 1); + } + + #[tokio::test] + async fn read_write_transaction_execute_batch_update_fallback() -> anyhow::Result<()> { + let mut mock = create_session_mock(); + + // 1. First Batch DML attempt fails! + mock.expect_execute_batch_dml().once().returning(|req| { + let req = req.into_inner(); + let selector = req + .transaction + .expect("missing transaction selector") + .selector + .expect("missing selector"); + match selector { + Selector::Begin(_) => {} + _ => panic!("Expected Selector::Begin"), + } + + Err(tonic::Status::new(tonic::Code::Internal, "internal error")) + }); + + // 2. Client falls back to explicit BeginTransaction! + mock.expect_begin_transaction().once().returning(|_| { + Ok(tonic::Response::new(v1::Transaction { + id: vec![4, 5, 6], + ..Default::default() + })) + }); + + // 3. Client retries Batch DML with new ID! + mock.expect_execute_batch_dml().once().returning(|req| { + let req = req.into_inner(); + let selector = req + .transaction + .expect("missing transaction selector") + .selector + .expect("missing selector"); + match selector { + Selector::Id(id) => { + assert_eq!(id, vec![4, 5, 6]); + } + _ => panic!("Expected Selector::Id"), + } + + Ok(tonic::Response::new(v1::ExecuteBatchDmlResponse { + result_sets: vec![v1::ResultSet { + stats: Some(v1::ResultSetStats { + row_count: Some(RowCount::RowCountExact(1)), + ..Default::default() + }), + ..Default::default() + }], + status: Some(spanner_grpc_mock::google::rpc::Status { + code: 0, + message: "OK".into(), + details: vec![], + }), + ..Default::default() + })) + }); + + let (db_client, _server) = setup_db_client(mock).await; + + let tx = ReadWriteTransactionBuilder::new(db_client).build().await?; + + let batch = + BatchDml::builder().add_statement("UPDATE Users SET Name = 'Alice' WHERE Id = 1"); + + let counts = tx.execute_batch_update(batch.build()).await?; + + assert_eq!(counts, vec![1]); + + Ok(()) } } diff --git a/src/spanner/src/result_set.rs b/src/spanner/src/result_set.rs index cfe0397cae..02a2333ec4 100644 --- a/src/spanner/src/result_set.rs +++ b/src/spanner/src/result_set.rs @@ -16,6 +16,7 @@ use crate::database_client::DatabaseClient; use crate::error::internal_error; use crate::google::spanner::v1::PartialResultSet; use crate::precommit::PrecommitTokenTracker; +use crate::read_only_transaction::ReadContextTransactionSelector; use crate::result_set_metadata::ResultSetMetadata; use crate::row::Row; use crate::server_streaming::stream::PartialResultSetStream; @@ -24,6 +25,10 @@ use gaxi::prost::FromProto; use google_cloud_gax::error::rpc::Code; use std::collections::VecDeque; use std::mem::take; +use std::sync::Arc; +use std::sync::atomic::{AtomicUsize, Ordering}; +use tokio::sync::mpsc; +use tokio::sync::watch; #[cfg(feature = "unstable-stream")] use futures::Stream; @@ -43,11 +48,20 @@ use futures::Stream; /// ``` #[derive(Debug)] pub struct ResultSet { + receiver: mpsc::Receiver>, + metadata: watch::Receiver>, + // This field is only modified in tests to set a small buffer size. + #[allow(dead_code)] + max_buffered_partial_result_sets: Arc, +} + +#[derive(Debug)] +struct ResultSetWorker { stream: PartialResultSetStream, buffered_values: Vec, chunked: bool, ready_rows: VecDeque, - metadata: Option, + metadata: watch::Sender>, precommit_token_tracker: PrecommitTokenTracker, // Fields for retries and buffering of a stream of PartialResultSets. @@ -56,8 +70,9 @@ pub struct ResultSet { last_resume_token: Bytes, partial_result_sets_buffer: VecDeque, safe_to_retry: bool, - max_buffered_partial_result_sets: usize, + max_buffered_partial_result_sets: Arc, retry_count: usize, + transaction_selector: Option, } /// Errors that can occur when interacting with a [`ResultSet`]. @@ -84,24 +99,38 @@ impl ResultSet { /// Creates a new result set. pub(crate) fn new( stream: PartialResultSetStream, + transaction_selector: Option, precommit_token_tracker: PrecommitTokenTracker, client: DatabaseClient, operation: StreamOperation, ) -> Self { - Self { + let (sender, receiver) = mpsc::channel(4); + let (metadata_sender, metadata_receiver) = watch::channel(None); + let max_buffered_partial_result_sets = + Arc::new(AtomicUsize::new(MAX_BUFFERED_PARTIAL_RESULT_SETS)); + + let mut worker = ResultSetWorker::new( stream, - buffered_values: Vec::new(), - chunked: false, - ready_rows: VecDeque::new(), - metadata: None, + transaction_selector, precommit_token_tracker, client, operation, - last_resume_token: Bytes::new(), - partial_result_sets_buffer: VecDeque::new(), - safe_to_retry: true, - max_buffered_partial_result_sets: MAX_BUFFERED_PARTIAL_RESULT_SETS, - retry_count: 0, + metadata_sender, + Arc::clone(&max_buffered_partial_result_sets), + ); + + tokio::spawn(async move { + while let Some(row) = worker.next().await { + if sender.send(row).await.is_err() { + break; // Receiver dropped + } + } + }); + + Self { + receiver, + metadata: metadata_receiver, + max_buffered_partial_result_sets, } } @@ -110,21 +139,29 @@ impl ResultSet { /// # Example /// ``` /// # use google_cloud_spanner::client::{ResultSet, Row}; - /// # async fn fetch_metadata(mut rs: ResultSet) -> Result<(), Box> { - /// if let Some(row) = rs.next().await.transpose()? { - /// let metadata = rs.metadata()?; - /// for column in metadata.column_names() { - /// println!("Column name: {}", column); - /// } + /// # async fn fetch_metadata(mut result_set: ResultSet) -> Result<(), Box> { + /// let metadata = result_set.metadata().await?; + /// for column in metadata.column_names() { + /// println!("Column name: {}", column); /// } /// # Ok(()) /// # } /// ``` /// - /// The metadata is only available after the first call to [`next`](Self::next). - /// If called before the first `next()` call, it returns a [`ResultSetError::MetadataNotAvailable`] error. - pub fn metadata(&self) -> Result { - self.metadata + /// This method blocks until the metadata is available, which is after the + /// first chunk is received from the server. If the stream ends or fails + /// before metadata is available, it returns [`ResultSetError::MetadataNotAvailable`]. + pub async fn metadata(&self) -> Result { + let mut receiver = self.metadata.clone(); + if let Some(metadata) = &*receiver.borrow() { + return Ok(metadata.clone()); + } + receiver + .changed() + .await + .map_err(|_| ResultSetError::MetadataNotAvailable)?; + receiver + .borrow() .clone() .ok_or(ResultSetError::MetadataNotAvailable) } @@ -144,6 +181,73 @@ impl ResultSet { /// /// Returns `None` when all rows have been retrieved. pub async fn next(&mut self) -> Option> { + self.receiver.recv().await + } + + /// Converts the [`ResultSet`] into a [`Stream`]. + /// + /// # Example + /// + /// ``` + /// # use google_cloud_spanner::client::ResultSet; + /// # use futures::TryStreamExt; + /// # use std::future::ready; + /// # async fn example(result_set: ResultSet) -> Result<(), google_cloud_spanner::Error> { + /// let rows: Vec<_> = result_set + /// .into_stream() + /// .try_filter(|row| { + /// let id = row.get::("Id"); + /// ready(id == "id1") + /// }) + /// .try_collect() + /// .await?; + /// # Ok(()) + /// # } + /// ``` + /// + /// This consumes the [`ResultSet`] and returns a stream of rows. + #[cfg(feature = "unstable-stream")] + pub fn into_stream(self) -> impl Stream> + Unpin { + use futures::stream::unfold; + Box::pin(unfold(self, |mut result_set| async move { + result_set.next().await.map(|row| (row, result_set)) + })) + } +} + +impl ResultSetWorker { + /// Creates a new result set worker. + pub(crate) fn new( + stream: PartialResultSetStream, + transaction_selector: Option, + precommit_token_tracker: PrecommitTokenTracker, + client: DatabaseClient, + operation: StreamOperation, + metadata: watch::Sender>, + max_buffered_partial_result_sets: Arc, + ) -> Self { + Self { + stream, + buffered_values: Vec::new(), + chunked: false, + ready_rows: VecDeque::new(), + metadata, + precommit_token_tracker, + client, + operation, + last_resume_token: Bytes::new(), + partial_result_sets_buffer: VecDeque::new(), + safe_to_retry: true, + max_buffered_partial_result_sets, + retry_count: 0, + transaction_selector, + } + } + + /// Fetches the next row from the result set. + /// + /// Returns `None` when all rows have been retrieved. + pub(crate) async fn next(&mut self) -> Option> { if let Some(row) = self.ready_rows.pop_front() { return Some(Ok(row)); } @@ -206,7 +310,11 @@ impl ResultSet { // The PartialResultSet did not have a resume_token. Buffer the result // and continue with the next PartialResultSet, unless the buffer is full. - if self.partial_result_sets_buffer.len() >= self.max_buffered_partial_result_sets { + if self.partial_result_sets_buffer.len() + >= self + .max_buffered_partial_result_sets + .load(Ordering::Relaxed) + { // Mark this stream as 'unsafe to retry', meaning that any transient error // that we see will not be retried. We will instead propagate the error. self.safe_to_retry = false; @@ -229,7 +337,29 @@ impl ResultSet { return Ok(()); } - Err(e) + // Check if this stream included an inlined BeginTransaction option + // and has not yet returned a transaction ID. If so, we explicitly + // begin the transaction and restart the stream. + let Some(ReadContextTransactionSelector::Lazy(lazy)) = &self.transaction_selector else { + return Err(e); + }; + let is_started = matches!( + &*lazy.lock().unwrap(), + crate::read_only_transaction::TransactionState::Started(_, _) + ); + if is_started { + return Err(e); + } + + self.transaction_selector + .as_ref() + .unwrap() + .begin_explicitly(&self.client) + .await?; + + self.partial_result_sets_buffer.clear(); + self.restart_stream().await?; + Ok(()) } fn handle_stream_end(&mut self) -> crate::Result> { @@ -264,25 +394,57 @@ impl ResultSet { &mut self, partial_result_set: PartialResultSet, ) -> crate::Result<()> { - match (self.metadata.as_ref(), partial_result_set.metadata) { - (Some(_), None) => {} - (None, None) => { - return Err(internal_error( - "First PartialResultSet did not contain metadata", - )); - } - (Some(_), Some(_)) => { - return Err(internal_error("Additional metadata after first result set")); + let update_selector = { + let metadata_ref = self.metadata.borrow(); + match (&*metadata_ref, partial_result_set.metadata) { + (Some(_), None) => None, + (None, None) => { + return Err(internal_error( + "First PartialResultSet did not contain metadata", + )); + } + (Some(_), Some(_)) => { + return Err(internal_error("Additional metadata after first result set")); + } + (None, Some(mut m)) => { + let transaction = m.transaction.take(); + Some((ResultSetMetadata::new(Some(m)), transaction)) + } } - (None, Some(m)) => { - self.metadata = Some(ResultSetMetadata::new(Some(m))); + }; + + if let Some((metadata, transaction)) = update_selector { + self.metadata + .send(Some(metadata)) + .map_err(|_| internal_error("Failed to send metadata"))?; + + if let Some(selector) = &self.transaction_selector { + if let Some(transaction) = transaction { + selector.update( + transaction.id, + transaction + .read_timestamp + .and_then(|t| wkt::Timestamp::new(t.seconds, t.nanos).ok()), + )?; + } else if let ReadContextTransactionSelector::Lazy(lazy) = selector { + let is_started = matches!( + &*lazy.lock().expect("transaction state mutex poisoned"), + crate::read_only_transaction::TransactionState::Started(_, _) + ); + if !is_started { + return Err(internal_error( + "Spanner failed to return a transaction ID for a query that included a BeginTransaction option", + )); + } + } } } if partial_result_set.values.is_empty() { return Ok(()); } - let metadata = self.metadata.as_ref().unwrap(); + + let metadata = self.metadata.borrow().as_ref().unwrap().clone(); if metadata.column_types.is_empty() { return Err(internal_error( "PartialResultSet contained values but no column metadata was provided", @@ -321,9 +483,23 @@ impl ResultSet { } async fn restart_stream(&mut self) -> crate::Result<()> { + if let Some(s) = &self.transaction_selector { + s.maybe_reset_starting(); + } + + // Get the latest transaction selector for this transaction. + let transaction_selector = if let Some(s) = &self.transaction_selector { + Some(s.selector().await?) + } else { + None + }; + match &mut self.operation { StreamOperation::Query(req) => { req.resume_token = self.last_resume_token.clone(); + req.transaction = transaction_selector + .clone() + .or_else(|| req.transaction.take()); let stream = self .client .spanner @@ -334,6 +510,9 @@ impl ResultSet { } StreamOperation::Read(req) => { req.resume_token = self.last_resume_token.clone(); + req.transaction = transaction_selector + .clone() + .or_else(|| req.transaction.take()); let stream = self .client .spanner @@ -354,36 +533,6 @@ impl ResultSet { e.status() .is_some_and(|status| status.code == Code::Unavailable) } - - /// Converts the [`ResultSet`] into a [`Stream`]. - /// - /// # Example - /// - /// ``` - /// # use google_cloud_spanner::client::ResultSet; - /// # use futures::TryStreamExt; - /// # use std::future::ready; - /// # async fn example(result_set: ResultSet) -> Result<(), google_cloud_spanner::Error> { - /// let rows: Vec<_> = result_set - /// .into_stream() - /// .try_filter(|row| { - /// let id = row.get::("Id"); - /// ready(id == "id1") - /// }) - /// .try_collect() - /// .await?; - /// # Ok(()) - /// # } - /// ``` - /// - /// This consumes the [`ResultSet`] and returns a stream of rows. - #[cfg(feature = "unstable-stream")] - pub fn into_stream(self) -> impl Stream> + Unpin { - use futures::stream::unfold; - Box::pin(unfold(self, |mut result_set| async move { - result_set.next().await.map(|row| (row, result_set)) - })) - } } /// Merges two values from successive `PartialResultSet`s into a single value. @@ -441,7 +590,8 @@ fn merge_values(target: &mut prost_types::Value, source: prost_types::Value) -> #[cfg(test)] impl ResultSet { pub(crate) fn set_max_buffered_partial_result_sets(&mut self, limit: usize) { - self.max_buffered_partial_result_sets = limit; + self.max_buffered_partial_result_sets + .store(limit, Ordering::Relaxed); } } @@ -450,6 +600,7 @@ pub(crate) mod tests { use super::*; use crate::client::Spanner; use gaxi::grpc::tonic::Response; + use google_cloud_auth::credentials::anonymous::Builder as Anonymous; use prost_types::Value; use spanner_grpc_mock::MockSpanner; use spanner_grpc_mock::google::spanner::v1::spanner_server::Spanner as SpannerTrait; @@ -513,7 +664,7 @@ pub(crate) mod tests { let client: Spanner = Spanner::builder() .with_endpoint(address) - .with_credentials(google_cloud_auth::credentials::anonymous::Builder::new().build()) + .with_credentials(Anonymous::new().build()) .build() .await .expect("Failed to build client"); @@ -549,6 +700,31 @@ pub(crate) mod tests { assert!(next.is_none()); } + #[tokio::test] + async fn test_result_set_metadata() -> anyhow::Result<()> { + let mut rs = run_mock_query(vec![PartialResultSet { + metadata: metadata(2), + values: vec![string_val("a"), string_val("b")], + last: true, + ..Default::default() + }]) + .await; + + // Called before next() -> blocks and returns metadata + let meta = rs.metadata().await; + assert!(meta.is_ok()); + let meta = meta.unwrap(); + assert_eq!( + meta.column_names(), + &["col0".to_string(), "col1".to_string()] + ); + + // Now consume the row + let _next = rs.next().await.expect("Expected a row")?; + + Ok(()) + } + #[tokio::test] async fn test_result_set_handle_partial_result_set_error() -> anyhow::Result<()> { let mut rs = run_mock_query(vec![PartialResultSet { @@ -571,6 +747,34 @@ pub(crate) mod tests { Ok(()) } + #[tokio::test] + async fn test_result_set_handle_partial_result_set_error_immediate() -> anyhow::Result<()> { + let mut rs = run_mock_query(vec![ + PartialResultSet { + values: vec![string_val("row1")], + ..Default::default() + }, + PartialResultSet { + resume_token: b"token".to_vec(), + ..Default::default() + }, + ]) + .await; + + let res = rs.next().await; + assert!(res.is_some(), "Expected an error but got None"); + let res = res.expect("Expected some response but got None"); + assert!(res.is_err(), "Expected an error but got Ok"); + let err_str = res.expect_err("Expected should be an error").to_string(); + assert!( + err_str.contains("First PartialResultSet did not contain metadata"), + "Expected error to contain 'First PartialResultSet did not contain metadata', but got '{}'", + err_str + ); + + Ok(()) + } + #[tokio::test] async fn test_result_set_stream_ended_with_chunked_value() -> anyhow::Result<()> { let mut rs = run_mock_query(vec![PartialResultSet { @@ -710,7 +914,7 @@ pub(crate) mod tests { let client: Spanner = Spanner::builder() .with_endpoint(address) - .with_credentials(google_cloud_auth::credentials::anonymous::Builder::new().build()) + .with_credentials(Anonymous::new().build()) .build() .await?; @@ -981,21 +1185,62 @@ pub(crate) mod tests { } #[tokio::test] - async fn test_result_set_precommit_token_tracked() { - let mut rs = run_mock_query(vec![PartialResultSet { - metadata: metadata(1), - precommit_token: Some( - spanner_grpc_mock::google::spanner::v1::MultiplexedSessionPrecommitToken { - precommit_token: b"test_token".to_vec(), - seq_num: 99, - }, - ), - ..Default::default() - }]) - .await; + async fn test_result_set_precommit_token_tracked() -> anyhow::Result<()> { + let mut mock = MockSpanner::new(); + mock.expect_execute_streaming_sql() + .returning(move |_request| { + let stream = tokio_stream::iter(vec![Ok(PartialResultSet { + metadata: metadata(1), + precommit_token: Some( + spanner_grpc_mock::google::spanner::v1::MultiplexedSessionPrecommitToken { + precommit_token: b"test_token".to_vec(), + seq_num: 99, + }, + ), + ..Default::default() + })]); + Ok(Response::new( + Box::pin(stream) as ::ExecuteStreamingSqlStream, + )) + }); + + mock.expect_create_session().returning(|_| { + Ok(Response::new(Session { + name: "session".to_string(), + multiplexed: true, + ..Default::default() + })) + }); + + let (address, _server) = start("127.0.0.1:0", mock).await?; + + let client: Spanner = Spanner::builder() + .with_endpoint(address) + .with_credentials(Anonymous::new().build()) + .build() + .await?; + + let db_client = client.database_client("db").build().await?; - // Force tracking mode since run_mock_query uses a ReadOnly transaction (NoOp). - rs.precommit_token_tracker = PrecommitTokenTracker::new(); + let req = crate::model::ExecuteSqlRequest::default() + .set_session(db_client.session.name.clone()) + .set_sql("SELECT 1".to_string()); + + let stream = db_client + .spanner + .execute_streaming_sql(req.clone(), crate::RequestOptions::default()) + .send() + .await?; + + let tracker = PrecommitTokenTracker::new(); // Track mode! + + let mut rs = ResultSet::new( + stream, + None, + tracker.clone(), + db_client.clone(), + StreamOperation::Query(req), + ); // Read a row to trigger precommit token extraction assert!( @@ -1004,12 +1249,11 @@ pub(crate) mod tests { ); // Validate the tracker correctly intercepted and preserved the token - let token = rs - .precommit_token_tracker - .get() - .expect("token should be tracked"); + let token = tracker.get().expect("token should be tracked"); assert_eq!(token.seq_num, 99); assert_eq!(token.precommit_token, bytes::Bytes::from("test_token")); + + Ok(()) } #[tokio::test] @@ -1066,7 +1310,7 @@ pub(crate) mod tests { let client: Spanner = Spanner::builder() .with_endpoint(address) - .with_credentials(google_cloud_auth::credentials::anonymous::Builder::new().build()) + .with_credentials(Anonymous::new().build()) .build() .await?; @@ -1122,7 +1366,7 @@ pub(crate) mod tests { let client: Spanner = Spanner::builder() .with_endpoint(address) - .with_credentials(google_cloud_auth::credentials::anonymous::Builder::new().build()) + .with_credentials(Anonymous::new().build()) .build() .await?; @@ -1192,7 +1436,7 @@ pub(crate) mod tests { let client: Spanner = Spanner::builder() .with_endpoint(address) - .with_credentials(google_cloud_auth::credentials::anonymous::Builder::new().build()) + .with_credentials(Anonymous::new().build()) .build() .await?; @@ -1281,7 +1525,7 @@ pub(crate) mod tests { let client: Spanner = Spanner::builder() .with_endpoint(address) - .with_credentials(google_cloud_auth::credentials::anonymous::Builder::new().build()) + .with_credentials(Anonymous::new().build()) .build() .await?; @@ -1360,7 +1604,7 @@ pub(crate) mod tests { let client: Spanner = Spanner::builder() .with_endpoint(address) - .with_credentials(google_cloud_auth::credentials::anonymous::Builder::new().build()) + .with_credentials(Anonymous::new().build()) .build() .await?; @@ -1412,7 +1656,7 @@ pub(crate) mod tests { let client: Spanner = Spanner::builder() .with_endpoint(address) - .with_credentials(google_cloud_auth::credentials::anonymous::Builder::new().build()) + .with_credentials(Anonymous::new().build()) .build() .await?; @@ -1433,4 +1677,572 @@ pub(crate) mod tests { Ok(()) } + + #[tokio::test] + async fn result_set_inline_begin_stream_error_fallback() -> anyhow::Result<()> { + use gaxi::grpc::tonic::Response; + use gaxi::grpc::tonic::Status; + use spanner_grpc_mock::MockSpanner; + use spanner_grpc_mock::google::spanner::v1 as mock_v1; + use spanner_grpc_mock::start; + + let mut mock = MockSpanner::new(); + let mut seq = mockall::Sequence::new(); + + // 1. Stream yields an error on the first chunk before returning transaction metadata. + // E.g., INVALID_ARGUMENT because the query is malformed. + mock.expect_execute_streaming_sql() + .times(1) + .in_sequence(&mut seq) + .returning(|_request| { + let stream = + tokio_stream::iter(vec![Err(Status::invalid_argument("Invalid query"))]); + Ok(Response::new( + Box::pin(stream) as ::ExecuteStreamingSqlStream, + )) + }); + + // 2. The explicit BeginTransaction fallback gets triggered. + mock.expect_begin_transaction() + .times(1) + .in_sequence(&mut seq) + .returning(|_| { + Ok(Response::new(mock_v1::Transaction { + id: vec![7, 8, 9], + read_timestamp: Some(prost_types::Timestamp { + seconds: 123456789, + nanos: 0, + }), + ..Default::default() + })) + }); + + // 3. The ResultSet gracefully restarts the stream using the transaction ID returned by BeginTransaction. + mock.expect_execute_streaming_sql() + .times(1) + .in_sequence(&mut seq) + .returning(|req| { + let req = req.into_inner(); + // Ensure the explicitly yielded ID is routed into the new stream transaction selector + match req.transaction.unwrap().selector.unwrap() { + mock_v1::transaction_selector::Selector::Id(id) => { + assert_eq!(id, vec![7, 8, 9]); + } + _ => panic!("Expected Selector::Id"), + } + + let stream = tokio_stream::iter(vec![Ok(PartialResultSet { + metadata: metadata(1), + values: vec![string_val("1")], + ..Default::default() + })]); + Ok(Response::new( + Box::pin(stream) as ::ExecuteStreamingSqlStream, + )) + }); + + mock.expect_create_session().returning(|_| { + Ok(Response::new(Session { + name: "session".to_string(), + multiplexed: true, + ..Default::default() + })) + }); + + let (address, _server) = start("127.0.0.1:0", mock).await?; + + let client: Spanner = Spanner::builder() + .with_endpoint(address) + .with_credentials(Anonymous::new().build()) + .build() + .await?; + + let db_client = client.database_client("db").build().await?; + + let tx = db_client + .read_only_transaction() + .with_explicit_begin_transaction(false) + .build() + .await?; + let mut rs = tx.execute_query("SELECT 1").await?; + + let row1 = rs.next().await.ok_or_else(|| { + anyhow::anyhow!("Expected row returned successfully despite stream breaking") + })??; + assert_eq!( + row1.raw_values()[0].0, + string_val("1"), + "Verify the returned stream successfully resumed with the correct payload" + ); + + Ok(()) + } + + #[tokio::test] + async fn result_set_retry_inline_begin_transient_error() -> anyhow::Result<()> { + use gaxi::grpc::tonic::Response; + use gaxi::grpc::tonic::Status; + use spanner_grpc_mock::MockSpanner; + use spanner_grpc_mock::google::spanner::v1 as mock_v1; + use spanner_grpc_mock::start; + + let mut mock = MockSpanner::new(); + let mut seq = mockall::Sequence::new(); + + // 1. Initial stream throws UNAVAILABLE before metadata. + mock.expect_execute_streaming_sql() + .times(1) + .in_sequence(&mut seq) + .returning(|_request| { + let stream = + tokio_stream::iter(vec![Err(Status::unavailable("Transient network issue"))]); + Ok(Response::new( + Box::pin(stream) as ::ExecuteStreamingSqlStream, + )) + }); + + // 2. We retry the stream since it was a transient error. + // The retry should use the same transaction selector as the original request. + mock.expect_execute_streaming_sql() + .times(1) + .in_sequence(&mut seq) + .returning(|req| { + let req = req.into_inner(); + match req.transaction.unwrap().selector.unwrap() { + mock_v1::transaction_selector::Selector::Begin(_) => {} + _ => panic!("Expected Selector::Begin on stream retry"), + } + + let mut meta = metadata(1).unwrap(); + meta.transaction = Some(mock_v1::Transaction { + id: vec![7, 8, 9], + read_timestamp: None, + ..Default::default() + }); + + let stream = tokio_stream::iter(vec![Ok(PartialResultSet { + metadata: Some(meta), + values: vec![string_val("1")], + ..Default::default() + })]); + Ok(Response::new( + Box::pin(stream) as ::ExecuteStreamingSqlStream, + )) + }); + + mock.expect_create_session().returning(|_| { + Ok(Response::new(Session { + name: "session".to_string(), + multiplexed: true, + ..Default::default() + })) + }); + + let (address, _server) = start("127.0.0.1:0", mock).await?; + + let client: Spanner = Spanner::builder() + .with_endpoint(address) + .with_credentials(Anonymous::new().build()) + .build() + .await?; + + let db_client = client.database_client("db").build().await?; + + let tx = db_client + .read_only_transaction() + .with_explicit_begin_transaction(false) + .build() + .await?; + let mut rs = tx.execute_query("SELECT 1").await?; + + let row1 = rs + .next() + .await + .ok_or_else(|| anyhow::anyhow!("Expected stream to recover safely"))??; + assert_eq!( + row1.raw_values()[0].0, + string_val("1"), + "Verify resumed stream returns data" + ); + + Ok(()) + } + + #[tokio::test] + async fn result_set_retry_inline_begin_id_recovered() -> anyhow::Result<()> { + use gaxi::grpc::tonic::Response; + use gaxi::grpc::tonic::Status; + use spanner_grpc_mock::MockSpanner; + use spanner_grpc_mock::google::spanner::v1 as mock_v1; + use spanner_grpc_mock::start; + + let mut mock = MockSpanner::new(); + let mut seq = mockall::Sequence::new(); + + // 1. Stream successfully returns metadata chunk then throws UNAVAILABLE on chunk 2. + mock.expect_execute_streaming_sql() + .times(1) + .in_sequence(&mut seq) + .returning(|_request| { + let mut meta = metadata(1).unwrap(); + meta.transaction = Some(mock_v1::Transaction { + id: vec![7, 8, 9], + read_timestamp: None, + ..Default::default() + }); + let stream = tokio_stream::iter(vec![ + Ok(PartialResultSet { + metadata: Some(meta), + values: vec![string_val("1")], + resume_token: b"token1".to_vec(), + ..Default::default() + }), + Err(Status::unavailable("Transient mid-stream network issue")), + ]); + Ok(Response::new( + Box::pin(stream) as ::ExecuteStreamingSqlStream, + )) + }); + + // 2. Stream resumes using Selector::Id. + mock.expect_execute_streaming_sql() + .times(1) + .in_sequence(&mut seq) + .returning(|req| { + let req = req.into_inner(); + match req.transaction.unwrap().selector.unwrap() { + mock_v1::transaction_selector::Selector::Id(id) => { + assert_eq!(id, vec![7, 8, 9]); + } + _ => panic!("Expected Selector::Id on stream retry"), + } + + let stream = tokio_stream::iter(vec![Ok(PartialResultSet { + values: vec![string_val("2")], + ..Default::default() + })]); + Ok(Response::new( + Box::pin(stream) as ::ExecuteStreamingSqlStream, + )) + }); + + mock.expect_create_session().returning(|_| { + Ok(Response::new(Session { + name: "session".to_string(), + multiplexed: true, + ..Default::default() + })) + }); + + let (address, _server) = start("127.0.0.1:0", mock).await?; + + let client: Spanner = Spanner::builder() + .with_endpoint(address) + .with_credentials(Anonymous::new().build()) + .build() + .await?; + + let db_client = client.database_client("db").build().await?; + + let tx = db_client + .read_only_transaction() + .with_explicit_begin_transaction(false) + .build() + .await?; + let mut rs = tx.execute_query("SELECT 1").await?; + + let row1 = rs + .next() + .await + .ok_or_else(|| anyhow::anyhow!("Expected stream row1 extracted"))??; + assert_eq!( + row1.raw_values()[0].0, + string_val("1"), + "Verified chunk 1 payload" + ); + let row2 = rs + .next() + .await + .ok_or_else(|| anyhow::anyhow!("Expected stream row2 recovered"))??; + assert_eq!( + row2.raw_values()[0].0, + string_val("2"), + "Verified chunk 2 reboot dynamically intercepted ID bounds correctly" + ); + + Ok(()) + } + + #[tokio::test] + async fn result_set_inline_begin_metadata_missing_transaction_fails() -> anyhow::Result<()> { + use gaxi::grpc::tonic::Response; + use spanner_grpc_mock::MockSpanner; + use spanner_grpc_mock::start; + + let mut mock = MockSpanner::new(); + let mut seq = mockall::Sequence::new(); + + // 1. Initial stream successfully returns metadata chunk but completely lacks the `Transaction` entity. + mock.expect_execute_streaming_sql() + .times(1) + .in_sequence(&mut seq) + .returning(|_request| { + let stream = tokio_stream::iter(vec![Ok(PartialResultSet { + metadata: metadata(1), // Missing `.transaction` natively + values: vec![string_val("1")], + ..Default::default() + })]); + Ok(Response::new( + Box::pin(stream) as ::ExecuteStreamingSqlStream, + )) + }); + + mock.expect_create_session().returning(|_| { + Ok(Response::new(Session { + name: "session".to_string(), + multiplexed: true, + ..Default::default() + })) + }); + + let (address, _server) = start("127.0.0.1:0", mock).await?; + + let client: Spanner = Spanner::builder() + .with_endpoint(address) + .with_credentials(Anonymous::new().build()) + .build() + .await?; + + let db_client = client.database_client("db").build().await?; + + // Use explicitly deferred Lazy begin transaction! + let tx = db_client + .read_only_transaction() + .with_explicit_begin_transaction(false) + .build() + .await?; + let mut rs = tx.execute_query("SELECT 1").await?; + + let rs_result = rs + .next() + .await + .ok_or_else(|| anyhow::anyhow!("Expected explicit crash bound properly"))?; + assert!( + rs_result.is_err(), + "Securely aborted when metadata failed to package internal bounds properly" + ); + + let err_str = rs_result.unwrap_err().to_string(); + assert!( + err_str.contains("failed to return a transaction ID"), + "Caught implicit gap boundary: {}", + err_str + ); + + Ok(()) + } + + #[tokio::test] + async fn test_lazy_begin_deadlock_fixed() -> anyhow::Result<()> { + use gaxi::grpc::tonic::Response; + use spanner_grpc_mock::MockSpanner; + use spanner_grpc_mock::google::spanner::v1 as mock_v1; + use spanner_grpc_mock::start; + + let mut mock = MockSpanner::new(); + let mut seq = mockall::Sequence::new(); + + // Setup mock to return metadata with transaction ID on first query. + mock.expect_execute_streaming_sql() + .times(1) + .in_sequence(&mut seq) + .returning(|_request| { + let mut meta = metadata(1).expect("failed to create metadata"); + meta.transaction = Some(mock_v1::Transaction { + id: b"lazy_tx_id".to_vec(), + ..Default::default() + }); + let stream = tokio_stream::iter(vec![Ok(PartialResultSet { + metadata: Some(meta), + values: vec![string_val("1")], + ..Default::default() + })]); + Ok(Response::new( + Box::pin(stream) as ::ExecuteStreamingSqlStream, + )) + }); + + // Mock call for second query which must carry the returned transaction ID + mock.expect_execute_streaming_sql() + .times(1) + .in_sequence(&mut seq) + .returning(|req| { + let req = req.into_inner(); + let selector = req + .transaction + .expect("missing transaction component") + .selector + .expect("missing selector component"); + + match selector { + mock_v1::transaction_selector::Selector::Id(id) => { + assert_eq!(id, b"lazy_tx_id".to_vec()); + } + _ => panic!("Expected Selector::Id"), + } + + let stream = tokio_stream::iter(vec![Ok(PartialResultSet { + metadata: metadata(1), + values: vec![string_val("2")], + ..Default::default() + })]); + Ok(Response::new( + Box::pin(stream) as ::ExecuteStreamingSqlStream, + )) + }); + + mock.expect_create_session().returning(|_| { + Ok(Response::new(Session { + name: "session".to_string(), + multiplexed: true, + ..Default::default() + })) + }); + + let (address, _server) = start("127.0.0.1:0", mock).await?; + + let client: Spanner = Spanner::builder() + .with_endpoint(address) + .with_credentials(Anonymous::new().build()) + .build() + .await?; + + let db_client = client.database_client("db").build().await?; + + // Use inline begin transaction + let tx = db_client + .read_only_transaction() + .with_explicit_begin_transaction(false) + .build() + .await?; + + // Execute query but DO NOT call rs.next() + let _rs = tx.execute_query("SELECT 1").await?; + + // Execute second query against same transaction + let mut rs2 = tx.execute_query("SELECT 2").await?; + + // Assert it does not hang and yielded elements properly + let row2 = rs2.next().await; + assert!( + row2.is_some(), + "Implicit deadlock encountered; query 2 stalled!" + ); + + Ok(()) + } + + #[tokio::test] + async fn test_result_set_metadata_not_available() -> anyhow::Result<()> { + use gaxi::grpc::tonic::Response; + use gaxi::grpc::tonic::Status; + use spanner_grpc_mock::MockSpanner; + use spanner_grpc_mock::start; + + let mut mock = MockSpanner::new(); + + // Setup mock to return a stream that fails immediately. + mock.expect_execute_streaming_sql().returning(|_request| { + let stream = tokio_stream::iter(vec![Err(Status::internal("Internal error"))]); + Ok(Response::new( + Box::pin(stream) as ::ExecuteStreamingSqlStream, + )) + }); + + mock.expect_create_session().returning(|_| { + use spanner_grpc_mock::google::spanner::v1::Session; + Ok(Response::new(Session { + name: "session".to_string(), + multiplexed: true, + ..Default::default() + })) + }); + + let (address, _server) = start("127.0.0.1:0", mock).await?; + + let client: Spanner = Spanner::builder() + .with_endpoint(address) + .with_credentials(Anonymous::new().build()) + .build() + .await?; + + let db_client = client.database_client("db").build().await?; + let tx = db_client.single_use().build(); + + let rs = tx.execute_query("SELECT 1").await?; + + // Call metadata() immediately. It should fail because the stream ends without metadata. + let result = rs.metadata().await; + assert!(result.is_err(), "Expected error but got Ok"); + assert!( + matches!(result.unwrap_err(), ResultSetError::MetadataNotAvailable), + "Expected MetadataNotAvailable error" + ); + + Ok(()) + } + + #[tokio::test] + async fn test_result_set_metadata_available_before_next() -> anyhow::Result<()> { + use gaxi::grpc::tonic::Response; + use spanner_grpc_mock::MockSpanner; + use spanner_grpc_mock::start; + + let mut mock = MockSpanner::new(); + + // Setup mock to return metadata in first chunk. + mock.expect_execute_streaming_sql().returning(|_request| { + let stream = tokio_stream::iter(vec![Ok(PartialResultSet { + metadata: metadata(1), + values: vec![string_val("1")], + ..Default::default() + })]); + Ok(Response::new( + Box::pin(stream) as ::ExecuteStreamingSqlStream, + )) + }); + + mock.expect_create_session().returning(|_| { + use spanner_grpc_mock::google::spanner::v1::Session; + Ok(Response::new(Session { + name: "session".to_string(), + multiplexed: true, + ..Default::default() + })) + }); + + let (address, _server) = start("127.0.0.1:0", mock).await?; + + let client: Spanner = Spanner::builder() + .with_endpoint(address) + .with_credentials(Anonymous::new().build()) + .build() + .await?; + + let db_client = client.database_client("db").build().await?; + let tx = db_client.single_use().build(); + + let mut rs = tx.execute_query("SELECT 1").await?; + + // Call metadata() BEFORE next(). It should succeed. + let metadata = rs.metadata().await?; + assert_eq!(metadata.column_names().len(), 1); + assert_eq!(metadata.column_names()[0], "col0"); + + // Now consume the row + let row = rs.next().await; + assert!(row.is_some()); + + Ok(()) + } } diff --git a/src/spanner/src/result_set_metadata.rs b/src/spanner/src/result_set_metadata.rs index 2a6cd2bf1e..3e48cf2b09 100644 --- a/src/spanner/src/result_set_metadata.rs +++ b/src/spanner/src/result_set_metadata.rs @@ -26,9 +26,7 @@ use std::sync::Arc; /// let tx = db.single_use().build(); /// let mut rs = tx.execute_query(Statement::builder("SELECT 1 AS Number").build()).await?; /// -/// // Metadata is available after the first `next` call -/// let _ = rs.next().await.transpose()?; -/// let metadata = rs.metadata()?; +/// let metadata = rs.metadata().await?; /// /// for (name, type_) in metadata.column_names().iter().zip(metadata.column_types().iter()) { /// println!("Column: {} has type: {:?}", name, type_.code()); diff --git a/src/spanner/src/transaction_runner.rs b/src/spanner/src/transaction_runner.rs index e6f5b4da6d..64b10d1e4c 100644 --- a/src/spanner/src/transaction_runner.rs +++ b/src/spanner/src/transaction_runner.rs @@ -19,6 +19,7 @@ use crate::read_write_transaction::{ReadWriteTransaction, ReadWriteTransactionBu use crate::transaction_retry_policy::{ BasicTransactionRetryPolicy, TransactionRetryPolicy, backoff_if_aborted, is_aborted, }; +use std::sync::Arc; /// A builder for a [TransactionRunner] for a read/write transaction. /// @@ -151,6 +152,42 @@ impl TransactionRunnerBuilder { self } + /// Sets whether the transaction should be explicitly started using a `BeginTransaction` RPC. + /// + /// # Example + /// ``` + /// # use google_cloud_spanner::client::Spanner; + /// # async fn run(client: Spanner) -> Result<(), google_cloud_spanner::Error> { + /// let db_client = client.database_client("projects/p/instances/i/databases/d").build().await?; + /// let runner = db_client + /// .read_write_transaction() + /// .with_explicit_begin_transaction(true) + /// .build() + /// .await?; + /// # Ok(()) + /// # } + /// ``` + /// + /// By default, the Spanner client will inline the `BeginTransaction` call with the first query + /// or DML statement in the transaction. This reduces the number of round-trips to Spanner that + /// are needed for a transaction. Setting this option to `true` can be beneficial for specific + /// transaction shapes: + /// + /// 1. When the transaction executes multiple parallel queries at the start of the transaction. + /// Only one query can include a `BeginTransaction` option, and all other queries must wait for + /// the first query to return the first result before they can proceed to execute. A + /// `BeginTransaction` RPC will quickly return a transaction ID and allow all queries to start + /// execution in parallel once the transaction ID has been returned. + /// 2. When the first statement in the transaction could fail. If the statement fails, then it + /// will also not start a transaction and return a transaction ID. The transaction will then + /// fall back to executing a `BeginTransaction` RPC and retry the first statement. + /// + /// Default is `false` (inline begin). + pub fn with_explicit_begin_transaction(mut self, explicit: bool) -> Self { + self.builder = self.builder.with_explicit_begin_transaction(explicit); + self + } + /// Builds a [TransactionRunner] for a read/write transaction. /// /// # Example @@ -220,36 +257,40 @@ impl TransactionRunner { loop { attempts += 1; - let mut current_tx_id = None; + let shared_tx_id = Arc::new(std::sync::Mutex::new(None)); let attempt_result = async { - let transaction = self.builder.begin_transaction().await?; - current_tx_id = transaction.transaction_id().ok(); + let transaction = self.builder.build().await?; let result = match work(transaction.clone()).await { Ok(res) => res, - Err(e) => { + Err(error) => { + let id = transaction.context.transaction_selector.get_id_no_wait(); // Rollback if the closure failed and it was not an Aborted error. - if !is_aborted(&e) { + if !is_aborted(&error) { let _ = transaction.rollback().await; } - return Err(e); + *shared_tx_id.lock().unwrap() = id; + return Err(error); } }; + let id = transaction.context.transaction_selector.get_id_no_wait(); + *shared_tx_id.lock().unwrap() = id; transaction.commit().await?; - Ok::(result) + Ok(result) } .await; match attempt_result { Ok(res) => return Ok(res), - Err(e) => { - if is_aborted(&e) { + Err(error) => { + if is_aborted(&error) { + let current_tx_id = shared_tx_id.lock().unwrap().clone(); self.builder = self.builder.with_previous_transaction_id(current_tx_id); } backoff_if_aborted( - e, + error, attempts, start_time.elapsed(), self.retry_policy.as_ref(), @@ -293,9 +334,11 @@ mod tests { async fn execute_test_runner( mock: spanner_grpc_mock::MockSpanner, + explicit_begin: bool, ) -> Result { let (db_client, _server) = setup_db_client(mock).await; let runner = TransactionRunnerBuilder::new(db_client) + .with_explicit_begin_transaction(explicit_begin) .build() .await .unwrap(); @@ -336,16 +379,57 @@ mod tests { } #[tokio::test] - async fn run_success() { + async fn execute_run_success_explicit() { + run_success(true).await; + } + + #[tokio::test] + async fn execute_run_success_inline() { + run_success(false).await; + } + + async fn run_success(explicit_begin: bool) { let mut mock = create_session_mock(); - expect_begin_transaction(&mut mock, 1, vec![1, 2, 3]); + if explicit_begin { + expect_begin_transaction(&mut mock, 1, vec![1, 2, 3]); + } - mock.expect_execute_sql().once().returning(|req| { + mock.expect_execute_sql().once().returning(move |req| { let req = req.into_inner(); assert_eq!(req.sql, "UPDATE Users SET active = true"); assert_eq!(req.seqno, 1); - row_count_exact_response(1) + + if !explicit_begin { + let transaction = req + .transaction + .as_ref() + .expect("transaction options required for inline begin"); + let selector = transaction.selector.as_ref().expect("selector required"); + assert!(matches!( + selector, + v1::transaction_selector::Selector::Begin(_) + )); + } + + let mut metadata = v1::ResultSetMetadata { + ..Default::default() + }; + if !explicit_begin { + metadata.transaction = Some(v1::Transaction { + id: vec![1, 2, 3], + ..Default::default() + }); + } + + Ok(tonic::Response::new(v1::ResultSet { + metadata: Some(metadata), + stats: Some(v1::ResultSetStats { + row_count: Some(v1::result_set_stats::RowCount::RowCountExact(1)), + ..Default::default() + }), + ..Default::default() + })) }); mock.expect_commit().once().returning(|req| { @@ -359,93 +443,563 @@ mod tests { commit_response() }); - let res = execute_test_runner(mock).await.unwrap(); + let res = execute_test_runner(mock, explicit_begin).await.unwrap(); assert_eq!(res, 1); } #[tokio::test] - async fn run_with_aborted_retry() -> anyhow::Result<()> { + async fn execute_run_with_aborted_retry_explicit() -> anyhow::Result<()> { + run_with_aborted_retry(true).await + } + + #[tokio::test] + async fn execute_run_with_aborted_retry_inline() -> anyhow::Result<()> { + run_with_aborted_retry(false).await + } + + async fn run_with_aborted_retry(explicit_begin: bool) -> anyhow::Result<()> { let mut mock = create_session_mock(); let mut seq = mockall::Sequence::new(); - mock.expect_begin_transaction() + if explicit_begin { + mock.expect_begin_transaction() + .once() + .in_sequence(&mut seq) + .returning(move |req| { + let req = req.into_inner(); + assert_eq!( + req.session, + "projects/p/instances/i/databases/d/sessions/123" + ); + Ok(tonic::Response::new(v1::Transaction { + id: vec![9, 9, 9], + ..Default::default() + })) + }); + } + + if !explicit_begin { + // Attempt 1: execute_sql fails with Aborted + mock.expect_execute_sql() + .once() + .in_sequence(&mut seq) + .returning(move |req| { + let req = req.into_inner(); + let transaction = req + .transaction + .as_ref() + .expect("transaction options required for inline begin"); + let selector = transaction.selector.as_ref().expect("selector required"); + assert!(matches!( + selector, + v1::transaction_selector::Selector::Begin(_) + )); + + Err(create_aborted_status(std::time::Duration::from_nanos(1))) + }); + } else { + mock.expect_execute_sql() + .once() + .in_sequence(&mut seq) + .returning(move |_req| { + Err(create_aborted_status(std::time::Duration::from_nanos(1))) + }); + } + + if explicit_begin { + mock.expect_begin_transaction() + .once() + .in_sequence(&mut seq) + .returning(move |req| { + let req = req.into_inner(); + assert_eq!(req.session, "projects/p/instances/i/databases/d/sessions/123"); + + let options = req.options.as_ref().expect("options required on retry"); + let read_write = options.mode.as_ref().expect("mode required on retry"); + match read_write { + Mode::ReadWrite(rw) => { + assert_eq!(rw.multiplexed_session_previous_transaction_id, vec![9, 9, 9], "previous_transaction_id should be set to the ID of the aborted transaction"); + } + _ => panic!("Expected ReadWrite mode"), + } + + Ok(tonic::Response::new(v1::Transaction { + id: vec![8, 8, 8], + ..Default::default() + })) + }); + } + + // Attempt 2 (retry of closure) + mock.expect_execute_sql() .once() .in_sequence(&mut seq) .returning(move |req| { - let req = req.into_inner(); - assert_eq!( - req.session, - "projects/p/instances/i/databases/d/sessions/123" - ); - Ok(tonic::Response::new(v1::Transaction { - id: vec![9, 9, 9], - ..Default::default() - })) + if !explicit_begin { + let req = req.into_inner(); + let transaction = req.transaction.as_ref().expect("transaction options required for inline begin"); + let selector = transaction.selector.as_ref().expect("selector required"); + assert!(matches!(selector, v1::transaction_selector::Selector::Begin(_))); + + let options = match selector { + v1::transaction_selector::Selector::Begin(o) => o, + _ => panic!("Expected Begin"), + }; + let read_write = options.mode.as_ref().expect("mode required"); + match read_write { + Mode::ReadWrite(rw) => { + assert!(rw.multiplexed_session_previous_transaction_id.is_empty(), "previous_transaction_id should NOT be set because the first attempt failed before getting an ID"); + } + _ => panic!("Expected ReadWrite"), + } + + let mut metadata = v1::ResultSetMetadata { ..Default::default() }; + metadata.transaction = Some(v1::Transaction { id: vec![8, 8, 8], ..Default::default() }); + return Ok(tonic::Response::new(v1::ResultSet { + metadata: Some(metadata), + stats: Some(v1::ResultSetStats { + row_count: Some(v1::result_set_stats::RowCount::RowCountExact(5)), + ..Default::default() + }), + ..Default::default() + })); + } + row_count_exact_response(5) }); - mock.expect_execute_sql() + mock.expect_commit() + .once() + .returning(|_req| commit_response()); + + let res = execute_test_runner(mock, explicit_begin) + .await + .expect("runner should succeed"); + assert_eq!(res, 5); + Ok(()) + } + + #[tokio::test] + async fn execute_run_query_with_aborted_retry_explicit() -> anyhow::Result<()> { + run_query_with_aborted_retry(true).await + } + + #[tokio::test] + async fn execute_run_query_with_aborted_retry_inline() -> anyhow::Result<()> { + run_query_with_aborted_retry(false).await + } + + async fn run_query_with_aborted_retry(explicit_begin: bool) -> anyhow::Result<()> { + let mut mock = create_session_mock(); + let mut seq = mockall::Sequence::new(); + + if explicit_begin { + mock.expect_begin_transaction() + .once() + .in_sequence(&mut seq) + .returning(move |_| { + Ok(tonic::Response::new(v1::Transaction { + id: vec![9, 9, 9], + ..Default::default() + })) + }); + } + + if !explicit_begin { + // Attempt 1: execute_streaming_sql fails with Aborted + mock.expect_execute_streaming_sql() + .once() + .in_sequence(&mut seq) + .returning(move |req| { + let req = req.into_inner(); + let transaction = req + .transaction + .as_ref() + .expect("transaction options required for inline begin"); + let selector = transaction.selector.as_ref().expect("selector required"); + assert!(matches!( + selector, + v1::transaction_selector::Selector::Begin(_) + )); + + Err(tonic::Status::new(tonic::Code::Aborted, "aborted")) + }); + } else { + mock.expect_execute_streaming_sql() + .once() + .in_sequence(&mut seq) + .returning(move |_req| Err(tonic::Status::new(tonic::Code::Aborted, "aborted"))); + } + + if explicit_begin { + mock.expect_begin_transaction() + .once() + .in_sequence(&mut seq) + .returning(move |req| { + let req = req.into_inner(); + let options = req.options.as_ref().expect("options required on retry"); + let read_write = options.mode.as_ref().expect("mode required on retry"); + match read_write { + Mode::ReadWrite(rw) => { + assert_eq!( + rw.multiplexed_session_previous_transaction_id, + vec![9, 9, 9] + ); + } + _ => panic!("Expected ReadWrite mode"), + } + + Ok(tonic::Response::new(v1::Transaction { + id: vec![8, 8, 8], + ..Default::default() + })) + }); + } + + // Attempt 2 (retry of closure) + mock.expect_execute_streaming_sql() .once() .in_sequence(&mut seq) - .returning(move |_req| Err(create_aborted_status(std::time::Duration::from_nanos(1)))); + .returning(move |req| { + if !explicit_begin { + let req = req.into_inner(); + let transaction = req + .transaction + .as_ref() + .expect("transaction options required for inline begin"); + let selector = transaction.selector.as_ref().expect("selector required"); + assert!(matches!( + selector, + v1::transaction_selector::Selector::Begin(_) + )); + + let options = match selector { + v1::transaction_selector::Selector::Begin(o) => o, + _ => panic!("Expected Begin"), + }; + let read_write = options.mode.as_ref().expect("mode required"); + match read_write { + Mode::ReadWrite(rw) => { + assert!(rw.multiplexed_session_previous_transaction_id.is_empty()); + } + _ => panic!("Expected ReadWrite"), + } + } - mock.expect_begin_transaction() + let mut rs = v1::PartialResultSet { + metadata: Some(v1::ResultSetMetadata { + row_type: Some(v1::StructType { + fields: vec![Default::default()], + }), + ..Default::default() + }), + values: vec![prost_types::Value { + kind: Some(prost_types::value::Kind::StringValue("1".to_string())), + }], + last: true, + ..Default::default() + }; + + if !explicit_begin { + rs.metadata.as_mut().unwrap().transaction = Some(v1::Transaction { + id: vec![8, 8, 8], + ..Default::default() + }); + } + + Ok(tonic::Response::new(Box::pin(tokio_stream::iter(vec![ + Ok(rs), + ])))) + }); + + mock.expect_commit() + .once() + .returning(|_req| commit_response()); + + let (db_client, _server) = setup_db_client(mock).await; + let runner = TransactionRunnerBuilder::new(db_client) + .with_explicit_begin_transaction(explicit_begin) + .build() + .await?; + + let mut attempt_counter = 0; + let res = runner + .run(async |tx| { + attempt_counter += 1; + let mut rs = tx.execute_query("SELECT 1").await?; + let row = rs.next().await.expect("has row").expect("has valid row"); + Ok(row.raw_values()[0].as_string().to_string()) + }) + .await?; + + assert_eq!(res, "1"); + assert_eq!(attempt_counter, 2); + Ok(()) + } + + #[tokio::test] + async fn execute_run_query_stream_with_aborted_retry_explicit() -> anyhow::Result<()> { + run_query_stream_with_aborted_retry(true).await + } + + #[tokio::test] + async fn execute_run_query_stream_with_aborted_retry_inline() -> anyhow::Result<()> { + run_query_stream_with_aborted_retry(false).await + } + + async fn run_query_stream_with_aborted_retry(explicit_begin: bool) -> anyhow::Result<()> { + let mut mock = create_session_mock(); + let mut seq = mockall::Sequence::new(); + + let tx_id_1 = vec![9, 9, 9]; + let tx_id_2 = vec![8, 8, 8]; + + let tx_id_1_c1 = tx_id_1.clone(); + if explicit_begin { + mock.expect_begin_transaction() + .once() + .in_sequence(&mut seq) + .returning(move |_| { + Ok(tonic::Response::new(v1::Transaction { + id: tx_id_1_c1.clone(), + ..Default::default() + })) + }); + } + + let tx_id_1_c2 = tx_id_1.clone(); + mock.expect_execute_streaming_sql() .once() .in_sequence(&mut seq) .returning(move |req| { let req = req.into_inner(); - assert_eq!(req.session, "projects/p/instances/i/databases/d/sessions/123"); - - let options = req.options.as_ref().expect("options required on retry"); - let read_write = options.mode.as_ref().expect("mode required on retry"); - match read_write { - Mode::ReadWrite(rw) => { - assert_eq!(rw.multiplexed_session_previous_transaction_id, vec![9, 9, 9], "previous_transaction_id should be set to the ID of the aborted transaction"); - } - _ => panic!("Expected ReadWrite mode"), + if !explicit_begin { + let transaction = req + .transaction + .as_ref() + .expect("transaction options required for inline begin"); + let selector = transaction.selector.as_ref().expect("selector required"); + assert!(matches!( + selector, + v1::transaction_selector::Selector::Begin(_) + )); } - Ok(tonic::Response::new(v1::Transaction { - id: vec![8, 8, 8], + let mut rs = v1::PartialResultSet { + metadata: Some(v1::ResultSetMetadata { + row_type: Some(v1::StructType { + fields: vec![Default::default()], + }), + ..Default::default() + }), + values: vec![prost_types::Value { + kind: Some(prost_types::value::Kind::StringValue("1".to_string())), + }], + resume_token: b"token1".to_vec(), ..Default::default() - })) + }; + + if !explicit_begin { + rs.metadata.as_mut().unwrap().transaction = Some(v1::Transaction { + id: tx_id_1_c2.clone(), + ..Default::default() + }); + } + + let stream = tokio_stream::iter(vec![ + Ok(rs), + Err(tonic::Status::new(tonic::Code::Aborted, "aborted")), + ]); + Ok(tonic::Response::new(Box::pin(stream))) }); - mock.expect_execute_sql() + let tx_id_1_c3 = tx_id_1.clone(); + let tx_id_2_c3 = tx_id_2.clone(); + if explicit_begin { + mock.expect_begin_transaction() + .once() + .in_sequence(&mut seq) + .returning(move |req| { + let req = req.into_inner(); + let options = req.options.as_ref().expect("options required on retry"); + let read_write = options.mode.as_ref().expect("mode required on retry"); + match read_write { + Mode::ReadWrite(rw) => { + assert_eq!(rw.multiplexed_session_previous_transaction_id, tx_id_1_c3); + } + _ => panic!("Expected ReadWrite mode"), + } + + Ok(tonic::Response::new(v1::Transaction { + id: tx_id_2_c3.clone(), + ..Default::default() + })) + }); + } + + let tx_id_1_c4 = tx_id_1.clone(); + let tx_id_2_c4 = tx_id_2.clone(); + mock.expect_execute_streaming_sql() .once() .in_sequence(&mut seq) - .returning(move |_req| row_count_exact_response(5)); + .returning(move |req| { + let req = req.into_inner(); + if !explicit_begin { + let transaction = req + .transaction + .as_ref() + .expect("transaction options required for inline begin"); + let selector = transaction.selector.as_ref().expect("selector required"); + let options = match selector { + v1::transaction_selector::Selector::Begin(o) => o, + _ => panic!("Expected Begin"), + }; + let read_write = options.mode.as_ref().expect("mode required"); + match read_write { + Mode::ReadWrite(rw) => { + assert_eq!(rw.multiplexed_session_previous_transaction_id, tx_id_1_c4); + } + _ => panic!("Expected ReadWrite"), + } + } + + let mut rs = v1::PartialResultSet { + metadata: Some(v1::ResultSetMetadata { + row_type: Some(v1::StructType { + fields: vec![Default::default()], + }), + ..Default::default() + }), + values: vec![prost_types::Value { + kind: Some(prost_types::value::Kind::StringValue("2".to_string())), + }], + last: true, + ..Default::default() + }; + + if !explicit_begin { + rs.metadata.as_mut().unwrap().transaction = Some(v1::Transaction { + id: tx_id_2_c4.clone(), + ..Default::default() + }); + } + + Ok(tonic::Response::new(Box::pin(tokio_stream::iter(vec![ + Ok(rs), + ])))) + }); mock.expect_commit() .once() .returning(|_req| commit_response()); - let res = execute_test_runner(mock) - .await - .expect("runner should succeed"); - assert_eq!(res, 5); + let (db_client, _server) = setup_db_client(mock).await; + let runner = TransactionRunnerBuilder::new(db_client) + .with_explicit_begin_transaction(explicit_begin) + .build() + .await?; + + let mut attempt_counter = 0; + let res = runner + .run(async |tx| { + attempt_counter += 1; + let mut rs = tx.execute_query("SELECT 1").await?; + let mut rows = vec![]; + while let Some(row_res) = rs.next().await { + rows.push(row_res?); + } + Ok(rows) + }) + .await?; + + assert_eq!(attempt_counter, 2); + assert_eq!(res.len(), 1); + assert_eq!(res[0].raw_values()[0].as_string(), "2"); Ok(()) } #[tokio::test] - async fn run_with_non_aborted_error() { + async fn execute_run_with_non_aborted_error_explicit() { + run_with_non_aborted_error(true).await; + } + + #[tokio::test] + async fn execute_run_with_non_aborted_error_inline() { + run_with_non_aborted_error(false).await; + } + + async fn run_with_non_aborted_error(explicit_begin: bool) { let mut mock = create_session_mock(); + let mut seq = mockall::Sequence::new(); - expect_begin_transaction(&mut mock, 1, vec![9, 9, 9]); + if explicit_begin { + expect_begin_transaction(&mut mock, 1, vec![9, 9, 9]); + } - // Let execute_sql return an error to trigger a rollback. - mock.expect_execute_sql().once().returning(move |_req| { - Err(tonic::Status::new( - tonic::Code::PermissionDenied, - "permission denied", - )) - }); + if !explicit_begin { + // First execute_sql fails + mock.expect_execute_sql() + .once() + .in_sequence(&mut seq) + .returning(move |req| { + let req = req.into_inner(); + let transaction = req + .transaction + .as_ref() + .expect("transaction options required for inline begin"); + let selector = transaction.selector.as_ref().expect("selector required"); + assert!(matches!( + selector, + v1::transaction_selector::Selector::Begin(_) + )); + + Err(tonic::Status::new( + tonic::Code::PermissionDenied, + "permission denied", + )) + }); + + // Falls back to begin_transaction + mock.expect_begin_transaction() + .once() + .in_sequence(&mut seq) + .returning(move |_| { + Ok(tonic::Response::new(v1::Transaction { + id: vec![9, 9, 9], + ..Default::default() + })) + }); + + // Retries execute_sql and fails again + mock.expect_execute_sql() + .once() + .in_sequence(&mut seq) + .returning(move |req| { + let req = req.into_inner(); + let transaction = req.transaction.as_ref().expect("transaction required"); + let selector = transaction.selector.as_ref().expect("selector required"); + assert!(matches!(selector, v1::transaction_selector::Selector::Id(id) if id == &vec![9_u8, 9, 9])); + + Err(tonic::Status::new( + tonic::Code::PermissionDenied, + "permission denied", + )) + }); + } else { + // Let execute_sql return an error to trigger a rollback. + mock.expect_execute_sql().once().returning(move |_req| { + Err(tonic::Status::new( + tonic::Code::PermissionDenied, + "permission denied", + )) + }); + } // Must explicitly trigger rollback mock.expect_rollback() .once() .returning(|_req| Ok(tonic::Response::new(()))); - let res = execute_test_runner(mock).await; + let res = execute_test_runner(mock, explicit_begin).await; assert!(res.is_err()); let err = res.unwrap_err(); @@ -460,25 +1014,88 @@ mod tests { } #[tokio::test] - async fn run_with_non_aborted_error_and_rollback_fails() { + async fn execute_run_with_non_aborted_error_and_rollback_fails_explicit() { + run_with_non_aborted_error_and_rollback_fails(true).await; + } + + #[tokio::test] + async fn execute_run_with_non_aborted_error_and_rollback_fails_inline() { + run_with_non_aborted_error_and_rollback_fails(false).await; + } + + async fn run_with_non_aborted_error_and_rollback_fails(explicit_begin: bool) { let mut mock = create_session_mock(); + let mut seq = mockall::Sequence::new(); - expect_begin_transaction(&mut mock, 1, vec![9, 9, 9]); + if explicit_begin { + expect_begin_transaction(&mut mock, 1, vec![9, 9, 9]); + } - // Let execute_sql return an error to trigger a rollback. - mock.expect_execute_sql().once().returning(move |_req| { - Err(tonic::Status::new( - tonic::Code::PermissionDenied, - "permission denied", - )) - }); + if !explicit_begin { + // First execute_sql fails + mock.expect_execute_sql() + .once() + .in_sequence(&mut seq) + .returning(move |req| { + let req = req.into_inner(); + let transaction = req + .transaction + .as_ref() + .expect("transaction options required for inline begin"); + let selector = transaction.selector.as_ref().expect("selector required"); + assert!(matches!( + selector, + v1::transaction_selector::Selector::Begin(_) + )); + + Err(tonic::Status::new( + tonic::Code::PermissionDenied, + "permission denied", + )) + }); + + // Falls back to begin_transaction + mock.expect_begin_transaction() + .once() + .in_sequence(&mut seq) + .returning(move |_| { + Ok(tonic::Response::new(v1::Transaction { + id: vec![9, 9, 9], + ..Default::default() + })) + }); + + // Retries execute_sql and fails again + mock.expect_execute_sql() + .once() + .in_sequence(&mut seq) + .returning(move |req| { + let req = req.into_inner(); + let transaction = req.transaction.as_ref().expect("transaction required"); + let selector = transaction.selector.as_ref().expect("selector required"); + assert!(matches!(selector, v1::transaction_selector::Selector::Id(id) if id == &vec![9_u8, 9, 9])); + + Err(tonic::Status::new( + tonic::Code::PermissionDenied, + "permission denied", + )) + }); + } else { + // Let execute_sql return an error to trigger a rollback. + mock.expect_execute_sql().once().returning(move |_req| { + Err(tonic::Status::new( + tonic::Code::PermissionDenied, + "permission denied", + )) + }); + } // Force the rollback itself to fail as well mock.expect_rollback() .once() .returning(|_req| Err(tonic::Status::new(tonic::Code::Internal, "rollback failed"))); - let res = execute_test_runner(mock).await; + let res = execute_test_runner(mock, explicit_begin).await; // Verify the user unequivocally receives the PRIMARY original error assert!(res.is_err()); @@ -494,38 +1111,132 @@ mod tests { } #[tokio::test] - async fn run_commit_aborted_retry() { - let mut mock = create_session_mock(); + async fn execute_run_commit_aborted_retry_explicit() { + run_commit_aborted_retry(true).await; + } - expect_begin_transaction(&mut mock, 2, vec![9, 9, 9]); + #[tokio::test] + async fn execute_run_commit_aborted_retry_inline() { + run_commit_aborted_retry(false).await; + } - mock.expect_execute_sql() - .times(2) - .returning(|_req| row_count_exact_response(5)); + async fn run_commit_aborted_retry(explicit_begin: bool) { + let mut mock = create_session_mock(); + + if explicit_begin { + expect_begin_transaction(&mut mock, 2, vec![9, 9, 9]); + } let mut attempt = 0; + mock.expect_execute_sql().times(2).returning(move |req| { + if !explicit_begin { + let req = req.into_inner(); + let transaction = req + .transaction + .as_ref() + .expect("transaction options required for inline begin"); + let selector = transaction.selector.as_ref().expect("selector required"); + assert!(matches!( + selector, + v1::transaction_selector::Selector::Begin(_) + )); + + attempt += 1; + if attempt == 2 { + let options = match selector { + v1::transaction_selector::Selector::Begin(o) => o, + _ => panic!("Expected Begin"), + }; + let read_write = options.mode.as_ref().expect("mode required"); + match read_write { + Mode::ReadWrite(rw) => { + assert_eq!( + rw.multiplexed_session_previous_transaction_id, + vec![9, 9, 9] + ); + } + _ => panic!("Expected ReadWrite"), + } + } + + let mut metadata = v1::ResultSetMetadata { + ..Default::default() + }; + metadata.transaction = Some(v1::Transaction { + id: vec![9, 9, 9], + ..Default::default() + }); + + return Ok(tonic::Response::new(v1::ResultSet { + metadata: Some(metadata), + stats: Some(v1::ResultSetStats { + row_count: Some(v1::result_set_stats::RowCount::RowCountExact(5)), + ..Default::default() + }), + ..Default::default() + })); + } + row_count_exact_response(5) + }); + + let mut commit_attempt = 0; mock.expect_commit().times(2).returning(move |_req| { - attempt += 1; - if attempt == 1 { + commit_attempt += 1; + if commit_attempt == 1 { Err(create_aborted_status(std::time::Duration::from_nanos(1))) } else { commit_response() } }); - let res = execute_test_runner(mock).await.unwrap(); + let res = execute_test_runner(mock, explicit_begin).await.unwrap(); assert_eq!(res, 5); } #[tokio::test] - async fn run_begin_transaction_fails() { + async fn execute_run_begin_transaction_fails_explicit() { + run_begin_transaction_fails(true).await; + } + + #[tokio::test] + async fn execute_run_begin_transaction_fails_inline() { + run_begin_transaction_fails(false).await; + } + + async fn run_begin_transaction_fails(explicit_begin: bool) { let mut mock = create_session_mock(); + let mut seq = mockall::Sequence::new(); - mock.expect_begin_transaction() - .once() - .returning(|_req| Err(tonic::Status::new(tonic::Code::Internal, "internal error"))); + if explicit_begin { + mock.expect_begin_transaction() + .once() + .returning(|_req| Err(tonic::Status::new(tonic::Code::Internal, "internal error"))); + } else { + mock.expect_execute_sql() + .once() + .in_sequence(&mut seq) + .returning(move |req| { + let req = req.into_inner(); + let transaction = req + .transaction + .as_ref() + .expect("transaction options required for inline begin"); + let selector = transaction.selector.as_ref().expect("selector required"); + assert!(matches!( + selector, + v1::transaction_selector::Selector::Begin(_) + )); + + Err(tonic::Status::new(tonic::Code::Internal, "internal error")) + }); + + mock.expect_begin_transaction() + .once() + .in_sequence(&mut seq) + .returning(|_req| Err(tonic::Status::new(tonic::Code::Internal, "internal error"))); + } - let res = execute_test_runner(mock).await; + let res = execute_test_runner(mock, explicit_begin).await; assert!(res.is_err()); let err = res.unwrap_err(); @@ -559,7 +1270,16 @@ mod tests { } #[tokio::test] - async fn run_batch_dml_aborted_retry() { + async fn execute_run_batch_dml_aborted_retry_explicit() { + run_batch_dml_aborted_retry(true).await; + } + + #[tokio::test] + async fn execute_run_batch_dml_aborted_retry_inline() { + run_batch_dml_aborted_retry(false).await; + } + + async fn run_batch_dml_aborted_retry(explicit_begin: bool) { use crate::batch_dml::BatchDml; use crate::statement::Statement; use gaxi::grpc::tonic::Code; @@ -568,13 +1288,28 @@ mod tests { let mut mock = create_session_mock(); - expect_begin_transaction(&mut mock, 2, vec![9, 9, 9]); + if explicit_begin { + expect_begin_transaction(&mut mock, 2, vec![9, 9, 9]); + } let mut seq = mockall::Sequence::new(); mock.expect_execute_batch_dml() .once() .in_sequence(&mut seq) - .returning(move |_req| { + .returning(move |req| { + if !explicit_begin { + let req = req.into_inner(); + let selector = req + .transaction + .expect("missing transaction selector") + .selector + .expect("missing selector"); + assert!(matches!( + selector, + v1::transaction_selector::Selector::Begin(_) + )); + } + // Return a successful response but with an embedded aborted status. let status = Status { code: Code::Aborted as i32, @@ -582,8 +1317,19 @@ mod tests { ..Default::default() }; + let mut metadata = v1::ResultSetMetadata { + ..Default::default() + }; + if !explicit_begin { + metadata.transaction = Some(v1::Transaction { + id: vec![9, 9, 9], + ..Default::default() + }); + } + Ok(tonic::Response::new(v1::ExecuteBatchDmlResponse { result_sets: vec![v1::ResultSet { + metadata: Some(metadata), stats: Some(v1::ResultSetStats { row_count: Some(RowCount::RowCountExact(1)), ..Default::default() @@ -597,10 +1343,34 @@ mod tests { mock.expect_execute_batch_dml() .once() .in_sequence(&mut seq) - .returning(move |_req| { + .returning(move |req| { + if !explicit_begin { + let req = req.into_inner(); + let selector = req + .transaction + .expect("missing transaction selector") + .selector + .expect("missing selector"); + assert!(matches!( + selector, + v1::transaction_selector::Selector::Begin(_) + )); + } + + let mut metadata = v1::ResultSetMetadata { + ..Default::default() + }; + if !explicit_begin { + metadata.transaction = Some(v1::Transaction { + id: vec![9, 9, 9], + ..Default::default() + }); + } + // Return success after the retry. Ok(tonic::Response::new(v1::ExecuteBatchDmlResponse { result_sets: vec![v1::ResultSet { + metadata: Some(metadata), stats: Some(v1::ResultSetStats { row_count: Some(RowCount::RowCountExact(5)), ..Default::default() @@ -617,6 +1387,7 @@ mod tests { let (db_client, _) = setup_db_client(mock).await; let runner = TransactionRunnerBuilder::new(db_client) + .with_explicit_begin_transaction(explicit_begin) .build() .await .expect("failed to build TransactionRunner"); @@ -640,12 +1411,38 @@ mod tests { } #[tokio::test] - async fn run_with_transaction_tag() -> anyhow::Result<()> { + async fn execute_run_with_transaction_tag_explicit() -> anyhow::Result<()> { + run_with_transaction_tag(true).await + } + + #[tokio::test] + async fn execute_run_with_transaction_tag_inline() -> anyhow::Result<()> { + run_with_transaction_tag(false).await + } + + async fn run_with_transaction_tag(explicit_begin: bool) -> anyhow::Result<()> { let mut mock = create_session_mock(); - mock.expect_begin_transaction().once().returning(|req| { + if explicit_begin { + mock.expect_begin_transaction().once().returning(|req| { + let req = req.into_inner(); + // Check if the transaction tag is correctly propagated. + assert_eq!( + req.request_options + .expect("Missing request_options") + .transaction_tag, + "my-test-tag" + ); + + Ok(tonic::Response::new(v1::Transaction { + id: vec![9, 9, 9], + ..Default::default() + })) + }); + } + + mock.expect_execute_sql().once().returning(move |req| { let req = req.into_inner(); - // Check if the transaction tag is correctly propagated. assert_eq!( req.request_options .expect("Missing request_options") @@ -653,21 +1450,36 @@ mod tests { "my-test-tag" ); - Ok(tonic::Response::new(v1::Transaction { - id: vec![9, 9, 9], + if !explicit_begin { + let transaction = req + .transaction + .as_ref() + .expect("transaction options required for inline begin"); + let selector = transaction.selector.as_ref().expect("selector required"); + assert!(matches!( + selector, + v1::transaction_selector::Selector::Begin(_) + )); + } + + let mut metadata = v1::ResultSetMetadata { ..Default::default() - })) - }); + }; + if !explicit_begin { + metadata.transaction = Some(v1::Transaction { + id: vec![9, 9, 9], + ..Default::default() + }); + } - mock.expect_execute_sql().once().returning(|req| { - let req = req.into_inner(); - assert_eq!( - req.request_options - .expect("Missing request_options") - .transaction_tag, - "my-test-tag" - ); - row_count_exact_response(5) + Ok(tonic::Response::new(v1::ResultSet { + metadata: Some(metadata), + stats: Some(v1::ResultSetStats { + row_count: Some(v1::result_set_stats::RowCount::RowCountExact(5)), + ..Default::default() + }), + ..Default::default() + })) }); mock.expect_commit().once().returning(|req| { @@ -684,6 +1496,7 @@ mod tests { let (db_client, _server) = setup_db_client(mock).await; let runner = TransactionRunnerBuilder::new(db_client) + .with_explicit_begin_transaction(explicit_begin) .with_transaction_tag("my-test-tag") .build() .await?; diff --git a/tests/spanner/Cargo.toml b/tests/spanner/Cargo.toml index 461c18d110..2869708160 100644 --- a/tests/spanner/Cargo.toml +++ b/tests/spanner/Cargo.toml @@ -33,10 +33,16 @@ google-cloud-gax = { workspace = true } google-cloud-lro = { workspace = true } google-cloud-spanner = { workspace = true, features = ["unstable-stream"] } google-cloud-test-utils = { workspace = true } +google-cloud-wkt = { workspace = true } prost-types.workspace = true +rand = { workspace = true } reqwest = { workspace = true, features = ["json"] } serde_json = { workspace = true } +spanner-grpc-mock = { path = "../../src/spanner/grpc-mock" } +time = { workspace = true } tokio = { workspace = true, features = ["sync"] } +tokio-stream = { workspace = true } +tonic = { workspace = true } tracing.workspace = true [lints] diff --git a/tests/spanner/src/client.rs b/tests/spanner/src/client.rs index 7b05ecf3c5..3516bc011c 100644 --- a/tests/spanner/src/client.rs +++ b/tests/spanner/src/client.rs @@ -14,6 +14,8 @@ use google_cloud_spanner::client::{KeySet, Mutation, Spanner}; use google_cloud_test_utils::resource_names::LowercaseAlphanumeric; +use std::time::Duration; +use tokio::time::sleep; const PROJECT_ID: &str = "test-project"; const INSTANCE_ID: &str = "test-instance"; @@ -40,7 +42,7 @@ pub async fn wait_for_emulator(endpoint: &str) { static PROVISION_EMULATOR: tokio::sync::OnceCell<()> = tokio::sync::OnceCell::const_new(); static DATABASE_ID: tokio::sync::OnceCell = tokio::sync::OnceCell::const_new(); -async fn get_database_id() -> &'static str { +pub async fn get_database_id() -> &'static str { DATABASE_ID .get_or_init(|| async { std::env::var("SPANNER_EMULATOR_TEST_DB") @@ -59,16 +61,19 @@ pub async fn provision_emulator(endpoint: &str) { .await; } +pub fn get_emulator_rest_endpoint(grpc_endpoint: &str) -> String { + let rest_endpoint = std::env::var("SPANNER_EMULATOR_REST_HOST") + .unwrap_or_else(|_| grpc_endpoint.replace("9010", "9020")); + if rest_endpoint.starts_with("http://") || rest_endpoint.starts_with("https://") { + rest_endpoint + } else { + format!("http://{}", rest_endpoint) + } +} + async fn do_provision_emulator(endpoint: &str) { // TODO(#4973): Re-write this to use the admin clients once those also support the Emulator. - let rest_endpoint = std::env::var("SPANNER_EMULATOR_REST_HOST") - .unwrap_or_else(|_| endpoint.replace("9010", "9020")); - let rest_endpoint = - if rest_endpoint.starts_with("http://") || rest_endpoint.starts_with("https://") { - rest_endpoint - } else { - format!("http://{}", rest_endpoint) - }; + let rest_endpoint = get_emulator_rest_endpoint(endpoint); let client = reqwest::Client::new(); // Create a test instance and ignore any ALREADY_EXISTS errors. @@ -196,3 +201,50 @@ pub async fn create_database_client() -> Option anyhow::Result<()> { + let emulator_host = get_emulator_host().expect("SPANNER_EMULATOR_HOST must be set"); + let rest_endpoint = get_emulator_rest_endpoint(&emulator_host); + let db_path = format!( + "projects/{}/instances/{}/databases/{}", + PROJECT_ID, + INSTANCE_ID, + get_database_id().await + ); + let url = format!("{}/v1/{}/ddl", rest_endpoint, db_path); + let client = reqwest::Client::new(); + let payload = serde_json::json!({ + "statements": [statement] + }); + + let mut attempts = 0; + const MAX_ATTEMPTS: u32 = 25; + + loop { + attempts += 1; + let res = client.patch(&url).json(&payload).send().await?; + + let status = res.status(); + let text = res.text().await?; + + if status.is_success() { + return Ok(()); + } + + // Check if the error is the specific one we want to retry. + // Code 9 is FailedPrecondition. + if text.contains("\"code\":9") && text.contains("Schema change operation rejected") { + if attempts >= MAX_ATTEMPTS { + anyhow::bail!( + "Failed to update DDL after {} attempts. Last error: {}", + attempts, + text + ); + } + sleep(Duration::from_millis(100)).await; + continue; + } + + anyhow::bail!("Failed to update DDL: status={}, body={}", status, text); + } +} diff --git a/tests/spanner/src/concurrent_inline_begin.rs b/tests/spanner/src/concurrent_inline_begin.rs new file mode 100644 index 0000000000..c191a4bd73 --- /dev/null +++ b/tests/spanner/src/concurrent_inline_begin.rs @@ -0,0 +1,264 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use crate::client::{get_database_id, get_emulator_host, provision_emulator, update_database_ddl}; +use crate::test_proxy::{InterceptedSpanner, SpannerInterceptor}; +use futures::stream::{self, StreamExt}; +use google_cloud_spanner::client::{ResultSet, Row, Spanner, TimestampBound}; +use google_cloud_test_utils::resource_names::LowercaseAlphanumeric; +use spanner_grpc_mock::google::spanner::v1 as spanner_v1; +use spanner_grpc_mock::google::spanner::v1::spanner_client::SpannerClient; +use std::collections::HashMap; +use std::sync::Arc; +use time::OffsetDateTime; +use tokio::net::TcpListener; +use tokio::sync::{Barrier, Mutex}; +use tokio_stream::wrappers::TcpListenerStream; +use tonic::transport::{Channel, Server}; + +/// An interceptor that injects transient (Unavailable) and permanent (Internal) failures +/// into streaming SQL responses for specific query patterns. +pub struct ConcurrentFaultInterceptor { + emulator_client: SpannerClient, + /// Tracks failure counts to allow transient recovery. + failure_counts: Arc>>, +} + +impl ConcurrentFaultInterceptor { + pub fn new(emulator_client: SpannerClient) -> Self { + Self { + emulator_client, + failure_counts: Arc::new(Mutex::new(HashMap::new())), + } + } +} + +#[tonic::async_trait] +impl SpannerInterceptor for ConcurrentFaultInterceptor { + fn emulator_client(&self) -> SpannerClient { + self.emulator_client.clone() + } + + async fn execute_streaming_sql( + &self, + request: tonic::Request, + ) -> std::result::Result< + tonic::Response, + tonic::Status, + > { + let sql = request.get_ref().sql.clone(); + + // Emulates a transient stream failure. + if sql.starts_with("SELECT 'Transient-") { + let mut counts = self.failure_counts.lock().await; + let count = counts.entry(sql.clone()).or_insert(0); + if *count == 0 { + *count += 1; + // Return a stream that fails immediately with Unavailable. + let stream = stream::once(async { + Err(tonic::Status::unavailable("Transient stream failure")) + }); + return Ok(tonic::Response::new(stream.boxed())); + } + // Second attempt succeeds (fall through to emulator). + } + + // Emulates a permanent stream failure. + if sql == "SELECT 'Permanent'" { + // Returns a stream that always fails with an Internal error. + let stream = + stream::once(async { Err(tonic::Status::internal("Permanent stream failure")) }); + return Ok(tonic::Response::new(stream.boxed())); + } + + // Forward other queries to the emulator. + let res = self + .emulator_client() + .execute_streaming_sql(request) + .await?; + let (metadata, stream, extensions) = res.into_parts(); + Ok(tonic::Response::from_parts( + metadata, + stream.boxed(), + extensions, + )) + } +} + +pub async fn test_concurrent_inline_begin_with_snapshot_consistency() -> anyhow::Result<()> { + let emulator_host = match get_emulator_host() { + Some(host) => host, + None => return Ok(()), + }; + provision_emulator(&emulator_host).await; + let db_id = get_database_id().await; + let db_path = format!( + "projects/test-project/instances/test-instance/databases/{}", + db_id + ); + + // 1. Setup Table 1 (Exists at snapshot time) + let suffix = LowercaseAlphanumeric.random_string(6); + let table_success = format!("TableSuccess_{}", suffix); + let table_not_found = format!("TableNotFound_{}", suffix); + + let statement = format!("CREATE TABLE {} (Id INT64) PRIMARY KEY (Id)", table_success); + update_database_ddl(statement).await?; + + // 2. Capture snapshot time. + let spanner = Spanner::builder() + .with_endpoint(format!("http://{}", emulator_host)) + .build() + .await?; + let db_client = spanner.database_client(&db_path).build().await?; + + let mut rs: ResultSet = db_client + .single_use() + .build() + .execute_query("SELECT CURRENT_TIMESTAMP") + .await?; + let row: Row = rs.next().await.unwrap().unwrap(); + let snapshot_time: OffsetDateTime = row.try_get(0)?; + + // 3. Setup Table 2 (Does NOT exist at snapshot time) + let statement = format!( + "CREATE TABLE {} (Id INT64) PRIMARY KEY (Id)", + table_not_found + ); + update_database_ddl(statement).await?; + + // 4. Start the Intercepted Server + let listener = TcpListener::bind("127.0.0.1:0").await?; + let local_addr = listener.local_addr()?; + let emulator_channel = Channel::from_shared(format!("http://{}", emulator_host))? + .connect() + .await?; + let interceptor = ConcurrentFaultInterceptor::new(SpannerClient::new(emulator_channel)); + let service = InterceptedSpanner(interceptor); + + tokio::spawn(async move { + Server::builder() + .add_service(spanner_v1::spanner_server::SpannerServer::new(service)) + .serve_with_incoming(TcpListenerStream::new(listener)) + .await + .expect("Server failed"); + }); + + // 5. Build Client pointing to Interceptor + let intercepted_spanner = Spanner::builder() + .with_endpoint(format!("http://{}", local_addr)) + .build() + .await?; + let intercepted_db = intercepted_spanner + .database_client(&db_path) + .build() + .await?; + + // 6. Spawn 20 tasks with random workloads + let tx = intercepted_db + .read_only_transaction() + .with_timestamp_bound(TimestampBound::read_timestamp(snapshot_time)) + .with_explicit_begin_transaction(false) + .build() + .await?; + let tx = Arc::new(tx); + let barrier = Arc::new(Barrier::new(20)); + let mut handles = Vec::new(); + + for i in 0..20 { + let role = rand::random_range(0..4); + let tx = Arc::clone(&tx); + let barrier = Arc::clone(&barrier); + let table_success = table_success.clone(); + let table_not_found = table_not_found.clone(); + + handles.push(tokio::spawn(async move { + barrier.wait().await; + match role { + 0 => { + // Success + let mut result_set: ResultSet = tx + .execute_query(format!("SELECT * FROM {}", table_success)) + .await?; + while let Some(row) = result_set.next().await { + row?; + } + Ok::<_, anyhow::Error>(format!("Task {} Success: OK", i)) + } + 1 => { + // Table not found + let res: Result = tx + .execute_query(format!("SELECT * FROM {}", table_not_found)) + .await; + match res { + Err(e) + if e.to_string().contains("not found") + || e.to_string().contains("NotFound") => + { + Ok(format!("Task {} NotFound: OK", i)) + } + Ok(_) => anyhow::bail!("Task {} expected NotFound but got Success", i), + Err(e) => anyhow::bail!("Task {} expected NotFound but got: {:?}", i, e), + } + } + 2 => { + // Transient stream error. This will trigger a retry of the stream. + let sql = format!("SELECT 'Transient-{}'", i); + let mut result_set: ResultSet = tx.execute_query(sql).await?; + while let Some(row) = result_set.next().await { + row?; + } + Ok(format!("Task {} Transient: OK", i)) + } + 3 => { + // Permanent stream error. + let result_set_res: Result = + tx.execute_query("SELECT 'Permanent'").await; + let mut result_set = match result_set_res { + Ok(rs) => rs, + Err(e) => anyhow::bail!( + "Task {} expected successful RPC initiation but got: {:?}", + i, + e + ), + }; + + let next = result_set.next().await; + match next { + Some(Err(e)) + if e.to_string().contains("Permanent") + || e.to_string().contains("Internal") => + { + Ok(format!("Task {} Permanent: OK", i)) + } + Some(Ok(_)) => { + anyhow::bail!("Task {} expected Permanent error but got a valid row", i) + } + _ => anyhow::bail!( + "Task {} expected Permanent error but succeeded or got empty results", + i + ), + } + } + _ => unreachable!(), + } + })); + } + + for handle in handles { + handle.await??; + } + + Ok(()) +} diff --git a/tests/spanner/src/lib.rs b/tests/spanner/src/lib.rs index 2be88b7a18..ad0413d66a 100644 --- a/tests/spanner/src/lib.rs +++ b/tests/spanner/src/lib.rs @@ -14,8 +14,10 @@ pub mod batch_read_only_transaction; pub mod client; +pub mod concurrent_inline_begin; pub mod partitioned_dml; pub mod query; pub mod read; pub mod read_write_transaction; +pub mod test_proxy; pub mod write; diff --git a/tests/spanner/src/query.rs b/tests/spanner/src/query.rs index e38b51b989..badb161f37 100644 --- a/tests/spanner/src/query.rs +++ b/tests/spanner/src/query.rs @@ -12,7 +12,18 @@ // See the License for the specific language governing permissions and // limitations under the License. -use google_cloud_spanner::client::{DatabaseClient, Kind, Statement}; +use crate::client::{get_database_id, get_emulator_host}; +use crate::test_proxy::{InterceptedSpanner, SpannerInterceptor}; +use google_cloud_spanner::client::{DatabaseClient, Kind, Spanner, Statement}; +use google_cloud_test_utils::resource_names::LowercaseAlphanumeric; +use spanner_grpc_mock::google::spanner::v1 as spanner_v1; +use spanner_grpc_mock::google::spanner::v1::spanner_client::SpannerClient; +use spanner_grpc_mock::google::spanner::v1::spanner_server::SpannerServer; +use std::sync::Arc; +use tokio::net::TcpListener; +use tokio::sync::Notify; +use tokio_stream::wrappers::TcpListenerStream; +use tonic::transport::{Channel, Server}; pub async fn simple_query(db_client: &DatabaseClient) -> anyhow::Result<()> { let rot = db_client.single_use().build(); @@ -147,14 +158,14 @@ pub async fn result_set_metadata(db_client: &DatabaseClient) -> anyhow::Result<( // 1. Simple normal query let sql = "SELECT 1 as num, 'Alice' as name"; - let mut rs = rot.execute_query(Statement::builder(sql).build()).await?; + let mut result_set = rot.execute_query(Statement::builder(sql).build()).await?; - assert!(rs.next().await.transpose()?.is_some()); - let metadata = rs.metadata()?; + let metadata = result_set.metadata().await?; assert_eq!( metadata.column_names(), &["num".to_string(), "name".to_string()] ); + assert!(result_set.next().await.transpose()?.is_some()); // 2. Query that returns zero rows let sql_zero_rows = r#" @@ -163,25 +174,25 @@ pub async fn result_set_metadata(db_client: &DatabaseClient) -> anyhow::Result<( ) SELECT num, name FROM Data WHERE 1=0 "#; - let mut rs_zero_rows = rot + let mut result_set_zero_rows = rot .execute_query(Statement::builder(sql_zero_rows).build()) .await?; - assert!(rs_zero_rows.next().await.transpose()?.is_none()); - let metadata_zero_rows = rs_zero_rows.metadata()?; + let metadata_zero_rows = result_set_zero_rows.metadata().await?; assert_eq!( metadata_zero_rows.column_names(), &["num".to_string(), "name".to_string()] ); + assert!(result_set_zero_rows.next().await.transpose()?.is_none()); // 3. Query with duplicate aliases let sql_dup = "SELECT 1 as dup, 2 as dup"; - let mut rs_dup = rot + let mut result_set_dup = rot .execute_query(Statement::builder(sql_dup).build()) .await?; - let row_dup = rs_dup.next().await.transpose()?.unwrap(); - let metadata_dup = rs_dup.metadata()?; + let row_dup = result_set_dup.next().await.transpose()?.unwrap(); + let metadata_dup = result_set_dup.metadata().await?; assert_eq!( metadata_dup.column_names(), &["dup".to_string(), "dup".to_string()] @@ -194,17 +205,40 @@ pub async fn result_set_metadata(db_client: &DatabaseClient) -> anyhow::Result<( } pub async fn multi_use_read_only_transaction(db_client: &DatabaseClient) -> anyhow::Result<()> { + for explicit_begin in [false, true] { + test_multi_use_read_only_transaction(db_client, explicit_begin).await?; + } + Ok(()) +} + +async fn test_multi_use_read_only_transaction( + db_client: &DatabaseClient, + explicit_begin: bool, +) -> anyhow::Result<()> { // Start a multi-use read-only transaction. - let tx = db_client.read_only_transaction().build().await?; + let tx = db_client + .read_only_transaction() + .with_explicit_begin_transaction(explicit_begin) + .build() + .await?; - // Expect a read timestamp to have been chosen. - assert!(tx.read_timestamp().is_some()); + if explicit_begin { + // Expect a read timestamp to have been chosen immediately. + assert!(tx.read_timestamp().is_some()); + } else { + // Expect a read timestamp to NOT have been chosen yet. + assert!(tx.read_timestamp().is_none()); + } // Execute the first query. let mut rs1 = tx .execute_query(Statement::builder("SELECT 1 AS col_int").build()) .await?; let row1 = rs1.next().await.transpose()?.expect("should yield a row"); + + // The read timestamp is now always available. + assert!(tx.read_timestamp().is_some()); + let val1 = row1.raw_values()[0].as_string(); assert_eq!(val1, "1"); let next1 = rs1.next().await.transpose()?; @@ -223,6 +257,45 @@ pub async fn multi_use_read_only_transaction(db_client: &DatabaseClient) -> anyh Ok(()) } +pub async fn multi_use_read_only_transaction_invalid_query_fallback( + db_client: &DatabaseClient, +) -> anyhow::Result<()> { + // Start a multi-use read-only transaction with implicit begin. + let tx = db_client + .read_only_transaction() + .with_explicit_begin_transaction(false) + .build() + .await?; + + // Expect a read timestamp to NOT have been chosen yet. + assert!(tx.read_timestamp().is_none()); + + // Execute the first query with invalid syntax. + let rs_result = tx + .execute_query(Statement::builder("SELECT * FROM NonExistentTable").build()) + .await; + + assert!( + rs_result.is_err(), + "Expected an error from an invalid query" + ); + + // The read timestamp should now be available because the transaction + // fell back to an explicit BeginTransaction. + assert!(tx.read_timestamp().is_some()); + + // It should be possible to use the transaction. + let mut rs2 = tx + .execute_query(Statement::builder("SELECT 2 AS col_int").build()) + .await?; + + let row2 = rs2.next().await.transpose()?.expect("should yield a row"); + let val2 = row2.raw_values()[0].as_string(); + assert_eq!(val2, "2"); + + Ok(()) +} + fn verify_null_row(row: &google_cloud_spanner::client::Row) { let raw_values = row.raw_values(); assert_eq!(raw_values.len(), 20, "Row should have exactly 20 columns"); @@ -346,3 +419,138 @@ fn verify_row_2(row: &google_cloud_spanner::client::Row) { "2026-03-11T16:20:00Z" ); } + +struct DelayedBeginProxy { + emulator_client: SpannerClient, + latch: Arc, + begin_transaction_entered_latch: Arc, +} + +#[tonic::async_trait] +impl SpannerInterceptor for DelayedBeginProxy { + fn emulator_client(&self) -> SpannerClient { + self.emulator_client.clone() + } + + async fn begin_transaction( + &self, + request: tonic::Request, + ) -> std::result::Result, tonic::Status> { + self.begin_transaction_entered_latch.notify_one(); + self.latch.notified().await; + self.emulator_client().begin_transaction(request).await + } +} + +// This test verifies that the client correctly falls back to `BeginTransaction` when the +// first statement in a transaction fails. It also shows that the statement is retried and +// could (theoretically) succeed during this retry. It achieves this by doing the following: +// 1. It uses a proxy that allows it to intercept the RPCs that are being sent to Spanner. +// 2. It creates a read-only transaction that uses inline-begin-transaction. +// 3. It executes a query that tries to read from a table that does not exist. +// 4. As the first statement in the transaction fails, the client falls back to using +// an explicit BeginTransaction RPC. +// 5. The proxy blocks this BeginTransaction RPC, and in the meantime the test creates +// the missing table. +// 6. The proxy unblocks the BeginTransaction RPC. +// 7. The statement is retried and succeeds. The test never sees the error. +// +// This test might seem like an extreme corner case for a read-only transaction like this. +// However, for read/write transactions, similar types of failures are more likely to occur, +// for example if a transaction tries to insert a row that violates the primary key. Another +// transaction could delete the row in the time between the first attempt failed, and the +// BeginTransaction RPC has been executed. +pub async fn inline_begin_fallback(_db_client: &DatabaseClient) -> anyhow::Result<()> { + let emulator_host = get_emulator_host().expect("SPANNER_EMULATOR_HOST must be set"); + let latch = Arc::new(Notify::new()); + let begin_transaction_entered_latch = Arc::new(Notify::new()); + + // Create a raw gRPC client that connects to the Spanner Emulator. + // This will be used by the proxy server to forward requests to the Emulator. + let endpoint = Channel::from_shared(format!("http://{}", emulator_host))? + .connect() + .await?; + let raw_client = SpannerClient::new(endpoint); + + // Create a local TCP listener to bind our proxy server to. + let listener = TcpListener::bind("127.0.0.1:0").await?; + let local_addr = listener.local_addr()?; + let proxy_address = format!("{}:{}", local_addr.ip(), local_addr.port()); + + let proxy = DelayedBeginProxy { + emulator_client: raw_client, + latch: Arc::clone(&latch), + begin_transaction_entered_latch: Arc::clone(&begin_transaction_entered_latch), + }; + + let _server_handle = tokio::spawn(async move { + let stream = TcpListenerStream::new(listener); + Server::builder() + .add_service(SpannerServer::new(InterceptedSpanner(proxy))) + .serve_with_incoming(stream) + .await + .expect("Proxy server failed"); + }); + + // We build the Spanner DatabaseClient pointing directly to our proxy address over HTTP. + let proxy_db_client = Spanner::builder() + .with_endpoint(format!("http://{}", proxy_address)) + .build() + .await? + .database_client(format!( + "projects/test-project/instances/test-instance/databases/{}", + get_database_id().await + )) + .build() + .await?; + + let tx = proxy_db_client + .read_only_transaction() + .with_explicit_begin_transaction(false) + .build() + .await?; + + let table_name = LowercaseAlphanumeric.random_string(10); + let table_name = format!("LateLoadedTable_{}", table_name); + + // Create a task that tries to query the table before it exists. + // This will initially fail, and the client will fall back to using + // an explicit BeginTransaction RPC. The table will then be created + // BEFORE the BeginTransaction RPC is executed, which will cause the + // query to succeed when it is retried using the transaction ID that + // was returned by BeginTransaction. This task will never see the + // initial error, and instead it will seem like the query simply + // succeeded. + let query_task = tokio::spawn({ + let table_name = table_name.clone(); + async move { + let stmt = Statement::builder(format!("SELECT * FROM {}", table_name)).build(); + let mut rs = tx.execute_query(stmt).await?; + let _ = rs.next().await; + Ok::<_, anyhow::Error>(tx) + } + }); + + // Wait until the query task above has been executed and has triggered an + // explicit BeginTransaction RPC. The BeginTransaction RPC is blocked until + // `latch` is notified. + begin_transaction_entered_latch.notified().await; + + // Create the table on the emulator while the BeginTransaction RPC is blocked. + let statement = format!("CREATE TABLE {} (Id INT64) PRIMARY KEY (Id)", table_name); + crate::client::update_database_ddl(statement).await?; + + // Unblock the BeginTransaction RPC. + latch.notify_one(); + + // Wait for the query task to complete. It should succeed and never see + // the initial error. + let tx = query_task.await??; + + assert!( + tx.read_timestamp().is_some(), + "The transaction should have a read timestamp" + ); + + Ok(()) +} diff --git a/tests/spanner/src/test_proxy.rs b/tests/spanner/src/test_proxy.rs new file mode 100644 index 0000000000..f7a07d881f --- /dev/null +++ b/tests/spanner/src/test_proxy.rs @@ -0,0 +1,303 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use futures::stream::{BoxStream, StreamExt}; +use spanner_grpc_mock::google::spanner::v1 as spanner_v1; +use spanner_grpc_mock::google::spanner::v1::spanner_client::SpannerClient; + +pub type ExecuteStreamingSqlStream = + BoxStream<'static, std::result::Result>; + +#[tonic::async_trait] +pub trait SpannerInterceptor: Send + Sync + 'static { + fn emulator_client(&self) -> SpannerClient; + + async fn create_session( + &self, + request: tonic::Request, + ) -> std::result::Result, tonic::Status> { + self.emulator_client().create_session(request).await + } + + async fn batch_create_sessions( + &self, + request: tonic::Request, + ) -> std::result::Result, tonic::Status> + { + self.emulator_client().batch_create_sessions(request).await + } + + async fn get_session( + &self, + request: tonic::Request, + ) -> std::result::Result, tonic::Status> { + self.emulator_client().get_session(request).await + } + + async fn list_sessions( + &self, + request: tonic::Request, + ) -> std::result::Result, tonic::Status> { + self.emulator_client().list_sessions(request).await + } + + async fn delete_session( + &self, + request: tonic::Request, + ) -> std::result::Result, tonic::Status> { + self.emulator_client().delete_session(request).await + } + + async fn execute_sql( + &self, + request: tonic::Request, + ) -> std::result::Result, tonic::Status> { + self.emulator_client().execute_sql(request).await + } + + async fn execute_streaming_sql( + &self, + request: tonic::Request, + ) -> std::result::Result< + tonic::Response< + BoxStream<'static, std::result::Result>, + >, + tonic::Status, + > { + let res = self + .emulator_client() + .execute_streaming_sql(request) + .await?; + let (metadata, stream, extensions) = res.into_parts(); + Ok(tonic::Response::from_parts( + metadata, + stream.boxed(), + extensions, + )) + } + + async fn execute_batch_dml( + &self, + request: tonic::Request, + ) -> std::result::Result, tonic::Status> + { + self.emulator_client().execute_batch_dml(request).await + } + + async fn read( + &self, + request: tonic::Request, + ) -> std::result::Result, tonic::Status> { + self.emulator_client().read(request).await + } + + async fn streaming_read( + &self, + request: tonic::Request, + ) -> std::result::Result< + tonic::Response< + BoxStream<'static, std::result::Result>, + >, + tonic::Status, + > { + let res = self.emulator_client().streaming_read(request).await?; + let (metadata, stream, extensions) = res.into_parts(); + Ok(tonic::Response::from_parts( + metadata, + stream.boxed(), + extensions, + )) + } + + async fn begin_transaction( + &self, + request: tonic::Request, + ) -> std::result::Result, tonic::Status> { + self.emulator_client().begin_transaction(request).await + } + + async fn commit( + &self, + request: tonic::Request, + ) -> std::result::Result, tonic::Status> { + self.emulator_client().commit(request).await + } + + async fn rollback( + &self, + request: tonic::Request, + ) -> std::result::Result, tonic::Status> { + self.emulator_client().rollback(request).await + } + + async fn partition_query( + &self, + request: tonic::Request, + ) -> std::result::Result, tonic::Status> { + self.emulator_client().partition_query(request).await + } + + async fn partition_read( + &self, + request: tonic::Request, + ) -> std::result::Result, tonic::Status> { + self.emulator_client().partition_read(request).await + } + + async fn batch_write( + &self, + request: tonic::Request, + ) -> std::result::Result< + tonic::Response< + BoxStream<'static, std::result::Result>, + >, + tonic::Status, + > { + let res = self.emulator_client().batch_write(request).await?; + let (metadata, stream, extensions) = res.into_parts(); + Ok(tonic::Response::from_parts( + metadata, + stream.boxed(), + extensions, + )) + } +} + +pub struct InterceptedSpanner(pub T); + +#[tonic::async_trait] +impl spanner_v1::spanner_server::Spanner for InterceptedSpanner { + async fn create_session( + &self, + request: tonic::Request, + ) -> std::result::Result, tonic::Status> { + self.0.create_session(request).await + } + + async fn batch_create_sessions( + &self, + request: tonic::Request, + ) -> std::result::Result, tonic::Status> + { + self.0.batch_create_sessions(request).await + } + + async fn get_session( + &self, + request: tonic::Request, + ) -> std::result::Result, tonic::Status> { + self.0.get_session(request).await + } + + async fn list_sessions( + &self, + request: tonic::Request, + ) -> std::result::Result, tonic::Status> { + self.0.list_sessions(request).await + } + + async fn delete_session( + &self, + request: tonic::Request, + ) -> std::result::Result, tonic::Status> { + self.0.delete_session(request).await + } + + async fn execute_sql( + &self, + request: tonic::Request, + ) -> std::result::Result, tonic::Status> { + self.0.execute_sql(request).await + } + + type ExecuteStreamingSqlStream = + BoxStream<'static, std::result::Result>; + + async fn execute_streaming_sql( + &self, + request: tonic::Request, + ) -> std::result::Result, tonic::Status> { + self.0.execute_streaming_sql(request).await + } + + async fn execute_batch_dml( + &self, + request: tonic::Request, + ) -> std::result::Result, tonic::Status> + { + self.0.execute_batch_dml(request).await + } + + async fn read( + &self, + request: tonic::Request, + ) -> std::result::Result, tonic::Status> { + self.0.read(request).await + } + + type StreamingReadStream = + BoxStream<'static, std::result::Result>; + + async fn streaming_read( + &self, + request: tonic::Request, + ) -> std::result::Result, tonic::Status> { + self.0.streaming_read(request).await + } + + async fn begin_transaction( + &self, + request: tonic::Request, + ) -> std::result::Result, tonic::Status> { + self.0.begin_transaction(request).await + } + + async fn commit( + &self, + request: tonic::Request, + ) -> std::result::Result, tonic::Status> { + self.0.commit(request).await + } + + async fn rollback( + &self, + request: tonic::Request, + ) -> std::result::Result, tonic::Status> { + self.0.rollback(request).await + } + + async fn partition_query( + &self, + request: tonic::Request, + ) -> std::result::Result, tonic::Status> { + self.0.partition_query(request).await + } + + async fn partition_read( + &self, + request: tonic::Request, + ) -> std::result::Result, tonic::Status> { + self.0.partition_read(request).await + } + + type BatchWriteStream = + BoxStream<'static, std::result::Result>; + + async fn batch_write( + &self, + request: tonic::Request, + ) -> std::result::Result, tonic::Status> { + self.0.batch_write(request).await + } +} diff --git a/tests/spanner/src/write.rs b/tests/spanner/src/write.rs index cd2f9ae485..f0007df60b 100644 --- a/tests/spanner/src/write.rs +++ b/tests/spanner/src/write.rs @@ -526,6 +526,7 @@ async fn write_internal( let metadata = rs .metadata() + .await .expect("result set metadata is unexpectedly missing"); let column_count = metadata.column_names().len(); assert_eq!(row2.raw_values().len(), column_count); diff --git a/tests/spanner/tests/driver.rs b/tests/spanner/tests/driver.rs index d27cc10dc5..ff90e2dd92 100644 --- a/tests/spanner/tests/driver.rs +++ b/tests/spanner/tests/driver.rs @@ -26,6 +26,11 @@ mod spanner { integration_tests_spanner::query::query_with_parameters(&db_client).await?; integration_tests_spanner::query::result_set_metadata(&db_client).await?; integration_tests_spanner::query::multi_use_read_only_transaction(&db_client).await?; + integration_tests_spanner::query::multi_use_read_only_transaction_invalid_query_fallback( + &db_client, + ) + .await?; + integration_tests_spanner::query::inline_begin_fallback(&db_client).await?; Ok(()) } @@ -110,4 +115,9 @@ mod spanner { Ok(()) } + + #[tokio::test] + async fn run_concurrent_inline_begin_tests() -> anyhow::Result<()> { + integration_tests_spanner::concurrent_inline_begin::test_concurrent_inline_begin_with_snapshot_consistency().await + } }