From e5948f402d54b80b4a641d60f27baabdb8f3a0b3 Mon Sep 17 00:00:00 2001 From: aIbrahiim Date: Tue, 7 Apr 2026 12:22:02 +0200 Subject: [PATCH 1/2] Fix ML flakes --- .../ml/inference/pytorch_inference.py | 16 +++++++++++++++- .../ml/rag/ingestion/milvus_search_it_test.py | 4 ++-- sdks/python/apache_beam/ml/rag/test_utils.py | 4 ++-- 3 files changed, 19 insertions(+), 5 deletions(-) diff --git a/sdks/python/apache_beam/ml/inference/pytorch_inference.py b/sdks/python/apache_beam/ml/inference/pytorch_inference.py index 8dc4b5c43778..be14c40b70fa 100644 --- a/sdks/python/apache_beam/ml/inference/pytorch_inference.py +++ b/sdks/python/apache_beam/ml/inference/pytorch_inference.py @@ -56,6 +56,20 @@ Iterable[PredictionResult]] +def _cuda_device_is_usable() -> bool: + """Returns True only when CUDA can actually allocate tensors.""" + if not torch.cuda.is_available(): + return False + try: + # Some environments report CUDA available but fail at first real use + # because a driver is missing or inaccessible. + torch.empty(1, device='cuda') + return True + except Exception: # pylint: disable=broad-except + logging.debug("CUDA probe failed", exc_info=True) + return False + + def _validate_constructor_args( state_dict_path, model_class, torch_script_model_path): message = ( @@ -86,7 +100,7 @@ def _load_model( model_params: Optional[dict[str, Any]], torch_script_model_path: Optional[str], load_model_args: Optional[dict[str, Any]]): - if device == torch.device('cuda') and not torch.cuda.is_available(): + if device == torch.device('cuda') and not _cuda_device_is_usable(): logging.warning( "Model handler specified a 'GPU' device, but GPUs are not available. 
" "Switching to CPU.") diff --git a/sdks/python/apache_beam/ml/rag/ingestion/milvus_search_it_test.py b/sdks/python/apache_beam/ml/rag/ingestion/milvus_search_it_test.py index b6e5083ea728..8c5dcaa88b3f 100644 --- a/sdks/python/apache_beam/ml/rag/ingestion/milvus_search_it_test.py +++ b/sdks/python/apache_beam/ml/rag/ingestion/milvus_search_it_test.py @@ -204,8 +204,8 @@ def create_client(): self._test_client = retry_with_backoff( create_client, - max_retries=3, - retry_delay=1.0, + max_retries=5, + retry_delay=2.0, operation_name="Test Milvus client connection", exception_types=(MilvusException, )) diff --git a/sdks/python/apache_beam/ml/rag/test_utils.py b/sdks/python/apache_beam/ml/rag/test_utils.py index f4acb105892c..b200176d4445 100644 --- a/sdks/python/apache_beam/ml/rag/test_utils.py +++ b/sdks/python/apache_beam/ml/rag/test_utils.py @@ -204,8 +204,8 @@ def create_client(): client = retry_with_backoff( create_client, - max_retries=3, - retry_delay=1.0, + max_retries=5, + retry_delay=2.0, operation_name="Test Milvus client connection", exception_types=(MilvusException, )) From e1d69287c7849e3c31bf50cc76da2e994f755e17 Mon Sep 17 00:00:00 2001 From: aIbrahiim Date: Tue, 7 Apr 2026 22:09:00 +0200 Subject: [PATCH 2/2] changed Log CUDA probe failures to warning --- sdks/python/apache_beam/ml/inference/pytorch_inference.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sdks/python/apache_beam/ml/inference/pytorch_inference.py b/sdks/python/apache_beam/ml/inference/pytorch_inference.py index be14c40b70fa..5c6b41568c54 100644 --- a/sdks/python/apache_beam/ml/inference/pytorch_inference.py +++ b/sdks/python/apache_beam/ml/inference/pytorch_inference.py @@ -66,7 +66,7 @@ def _cuda_device_is_usable() -> bool: torch.empty(1, device='cuda') return True except Exception: # pylint: disable=broad-except - logging.debug("CUDA probe failed", exc_info=True) + logging.warning("CUDA probe failed", exc_info=True) return False