def _cuda_device_is_usable() -> bool:
  """Probes whether a tensor can really be placed on the CUDA device.

  ``torch.cuda.is_available()`` alone is not trusted: some environments report
  CUDA as available yet fail on first real use because a driver is missing or
  inaccessible. To catch that case, a one-element tensor is allocated on
  ``'cuda'`` as a smoke test.

  Returns:
    True if ``torch.cuda.is_available()`` is true and the probe allocation
    succeeds. False if CUDA is reported unavailable, or if the allocation
    raises; in the latter case a warning with the traceback is logged.
  """
  if torch.cuda.is_available():
    try:
      # The allocation is the actual test; the tensor itself is discarded.
      torch.empty(1, device='cuda')
    except Exception:  # pylint: disable=broad-except
      # Any failure here means the device cannot be used, whatever its cause.
      logging.warning("CUDA probe failed", exc_info=True)
    else:
      return True
  return False
" "Switching to CPU.") diff --git a/sdks/python/apache_beam/ml/rag/ingestion/milvus_search_it_test.py b/sdks/python/apache_beam/ml/rag/ingestion/milvus_search_it_test.py index b6e5083ea728..8c5dcaa88b3f 100644 --- a/sdks/python/apache_beam/ml/rag/ingestion/milvus_search_it_test.py +++ b/sdks/python/apache_beam/ml/rag/ingestion/milvus_search_it_test.py @@ -204,8 +204,8 @@ def create_client(): self._test_client = retry_with_backoff( create_client, - max_retries=3, - retry_delay=1.0, + max_retries=5, + retry_delay=2.0, operation_name="Test Milvus client connection", exception_types=(MilvusException, )) diff --git a/sdks/python/apache_beam/ml/rag/test_utils.py b/sdks/python/apache_beam/ml/rag/test_utils.py index f4acb105892c..b200176d4445 100644 --- a/sdks/python/apache_beam/ml/rag/test_utils.py +++ b/sdks/python/apache_beam/ml/rag/test_utils.py @@ -204,8 +204,8 @@ def create_client(): client = retry_with_backoff( create_client, - max_retries=3, - retry_delay=1.0, + max_retries=5, + retry_delay=2.0, operation_name="Test Milvus client connection", exception_types=(MilvusException, ))