diff --git a/src/aws-dataprocessing-mcp-server/awslabs/aws_dataprocessing_mcp_server/core/glue_data_catalog/data_catalog_handler.py b/src/aws-dataprocessing-mcp-server/awslabs/aws_dataprocessing_mcp_server/core/glue_data_catalog/data_catalog_handler.py index 4330ac1077..460dadf35a 100644 --- a/src/aws-dataprocessing-mcp-server/awslabs/aws_dataprocessing_mcp_server/core/glue_data_catalog/data_catalog_handler.py +++ b/src/aws-dataprocessing-mcp-server/awslabs/aws_dataprocessing_mcp_server/core/glue_data_catalog/data_catalog_handler.py @@ -21,21 +21,31 @@ import json from awslabs.aws_dataprocessing_mcp_server.models.data_catalog_models import ( + BatchDeleteConnectionData, ConnectionSummary, + ConnectionTypeBrief, CreateCatalogData, CreateConnectionData, CreatePartitionData, DeleteCatalogData, DeleteConnectionData, DeletePartitionData, + DescribeConnectionTypeData, + DescribeEntityData, + EntitySummary, + FieldSummary, GetCatalogData, GetConnectionData, + GetEntityRecordsData, GetPartitionData, ImportCatalogData, ListCatalogsData, + ListConnectionTypesData, ListConnectionsData, + ListEntitiesData, ListPartitionsData, PartitionSummary, + TestConnectionData, UpdateConnectionData, UpdatePartitionData, ) @@ -515,6 +525,536 @@ async def update_connection( content=[TextContent(type='text', text=error_message)], ) + async def test_connection( + self, + ctx: Context, + connection_name: Optional[str] = None, + catalog_id: Optional[str] = None, + test_connection_input: Optional[Dict[str, Any]] = None, + ) -> CallToolResult: + """Test a connection to validate service credentials. + + Tests a connection to a service to validate the service credentials. + You can either provide an existing connection name or a TestConnectionInput + for testing a non-existing connection input. + + Args: + ctx: MCP context containing request information + connection_name: Optional name of an existing connection to test + catalog_id: Optional catalog ID (defaults to AWS account ID) + test_connection_input: Optional TestConnectionInput for testing without an existing connection + + Returns: + TestConnectionResponse with the result of the operation + """ + try: + kwargs: Dict[str, Any] = {} + if connection_name: + kwargs['ConnectionName'] = connection_name + if catalog_id: + kwargs['CatalogId'] = catalog_id + if test_connection_input: + kwargs['TestConnectionInput'] = test_connection_input + + self.glue_client.test_connection(**kwargs) + + log_with_request_id( + ctx, + LogLevel.INFO, + f'Successfully tested connection{" for: " + connection_name if connection_name else ""}', + ) + + success_msg = f'Successfully tested connection{" for: " + connection_name if connection_name else ""}. The connection credentials are valid.' + data = TestConnectionData( + connection_name=connection_name, + catalog_id=catalog_id or '', + ) + + return CallToolResult( + isError=False, + content=[ + TextContent(type='text', text=success_msg), + TextContent(type='text', text=json.dumps(data.model_dump())), + ], + ) + + except ClientError as e: + error_code = e.response['Error']['Code'] + error_message = f'Failed to test connection{" " + connection_name if connection_name else ""}: {error_code} - {e.response["Error"]["Message"]}' + log_with_request_id(ctx, LogLevel.ERROR, error_message) + + return CallToolResult( + isError=True, + content=[TextContent(type='text', text=error_message)], + ) + + async def batch_delete_connection( + self, + ctx: Context, + connection_name_list: List[str], + catalog_id: Optional[str] = None, + ) -> CallToolResult: + """Delete a list of connections from the Data Catalog. + + Deletes multiple connections in a single call. Only connections managed + by the MCP server (with appropriate tags) will be deleted. + + Args: + ctx: MCP context containing request information + connection_name_list: List of connection names to delete + catalog_id: Optional catalog ID (defaults to AWS account ID) + + Returns: + BatchDeleteConnectionResponse with succeeded and failed deletions + """ + try: + # Verify each connection is MCP-managed before batch delete + region = AwsHelper.get_aws_region() + account_id = catalog_id or AwsHelper.get_aws_account_id() + partition = AwsHelper.get_aws_partition() + + non_managed = [] + for name in connection_name_list: + try: + connection_arn = f'arn:{partition}:glue:{region}:{account_id}:connection/{name}' + if not AwsHelper.is_resource_mcp_managed(self.glue_client, connection_arn): + non_managed.append(name) + except ClientError: + non_managed.append(name) + + if non_managed: + error_message = f'Cannot batch delete - the following connections are not managed by the MCP server: {", ".join(non_managed)}' + log_with_request_id(ctx, LogLevel.ERROR, error_message) + return CallToolResult( + isError=True, + content=[TextContent(type='text', text=error_message)], + ) + + kwargs: Dict[str, Any] = {'ConnectionNameList': connection_name_list} + if catalog_id: + kwargs['CatalogId'] = catalog_id + + response = self.glue_client.batch_delete_connection(**kwargs) + succeeded = response.get('Succeeded', []) + errors = response.get('Errors', {}) + + log_with_request_id( + ctx, + LogLevel.INFO, + f'Batch delete connections: {len(succeeded)} succeeded, {len(errors)} failed', + ) + + success_msg = f'Batch delete connections: {len(succeeded)} succeeded, {len(errors)} failed' + data = BatchDeleteConnectionData( + succeeded=succeeded, + errors=errors, + catalog_id=catalog_id or '', + ) + + return CallToolResult( + isError=False, + content=[ + TextContent(type='text', text=success_msg), + TextContent(type='text', text=json.dumps(data.model_dump())), + ], + ) + + except ClientError as e: + error_code = e.response['Error']['Code'] + error_message = f'Failed to batch delete connections: {error_code} - {e.response["Error"]["Message"]}' + log_with_request_id(ctx, LogLevel.ERROR, error_message) + + return CallToolResult( + isError=True, + content=[TextContent(type='text', text=error_message)], + ) + + async def describe_connection_type( + self, + ctx: Context, + connection_type: str, + ) -> CallToolResult: + """Describe a connection type with full details of supported options. + + Provides full details of the supported options for a given connection type + in AWS Glue, including properties, authentication, and compute environments. + + Args: + ctx: MCP context containing request information + connection_type: The name of the connection type to describe + + Returns: + DescribeConnectionTypeResponse with the connection type details + """ + try: + response = self.glue_client.describe_connection_type( + ConnectionType=connection_type + ) + + log_with_request_id( + ctx, + LogLevel.INFO, + f'Successfully described connection type: {connection_type}', + ) + + success_msg = f'Successfully described connection type: {connection_type}' + data = DescribeConnectionTypeData( + connection_type=response.get('ConnectionType', connection_type), + description=response.get('Description'), + capabilities=response.get('Capabilities'), + connection_properties=response.get('ConnectionProperties'), + connection_options=response.get('ConnectionOptions'), + authentication_configuration=response.get('AuthenticationConfiguration'), + compute_environment_configurations=response.get('ComputeEnvironmentConfigurations'), + physical_connection_requirements=response.get('PhysicalConnectionRequirements'), + athena_connection_properties=response.get('AthenaConnectionProperties'), + python_connection_properties=response.get('PythonConnectionProperties'), + spark_connection_properties=response.get('SparkConnectionProperties'), + ) + + return CallToolResult( + isError=False, + content=[ + TextContent(type='text', text=success_msg), + TextContent(type='text', text=json.dumps(data.model_dump())), + ], + ) + + except ClientError as e: + error_code = e.response['Error']['Code'] + error_message = f'Failed to describe connection type {connection_type}: {error_code} - {e.response["Error"]["Message"]}' + log_with_request_id(ctx, LogLevel.ERROR, error_message) + + return CallToolResult( + isError=True, + content=[TextContent(type='text', text=error_message)], + ) + + async def list_connection_types( + self, + ctx: Context, + max_results: Optional[int] = None, + next_token: Optional[str] = None, + ) -> CallToolResult: + """List available connection types in AWS Glue. + + Provides a discovery mechanism to learn available connection types. + The response contains a list of connection types with high-level details. + + Args: + ctx: MCP context containing request information + max_results: Optional maximum number of results to return + next_token: Optional pagination token + + Returns: + ListConnectionTypesResponse with the list of connection types + """ + try: + kwargs: Dict[str, Any] = {} + if max_results: + kwargs['MaxResults'] = max_results + if next_token: + kwargs['NextToken'] = next_token + + response = self.glue_client.list_connection_types(**kwargs) + connection_types = response.get('ConnectionTypes', []) + next_token_response = response.get('NextToken') + + log_with_request_id( + ctx, + LogLevel.INFO, + f'Successfully listed {len(connection_types)} connection types', + ) + + success_msg = f'Successfully listed {len(connection_types)} connection types' + data = ListConnectionTypesData( + connection_types=[ + ConnectionTypeBrief( + connection_type=ct.get('ConnectionType'), + display_name=ct.get('DisplayName'), + vendor=ct.get('Vendor'), + description=ct.get('Description'), + ) + for ct in connection_types + ], + count=len(connection_types), + next_token=next_token_response, + ) + + return CallToolResult( + isError=False, + content=[ + TextContent(type='text', text=success_msg), + TextContent(type='text', text=json.dumps(data.model_dump())), + ], + ) + + except ClientError as e: + error_code = e.response['Error']['Code'] + error_message = f'Failed to list connection types: {error_code} - {e.response["Error"]["Message"]}' + log_with_request_id(ctx, LogLevel.ERROR, error_message) + + return CallToolResult( + isError=True, + content=[TextContent(type='text', text=error_message)], + ) + + async def list_entities( + self, + ctx: Context, + connection_name: str, + catalog_id: Optional[str] = None, + parent_entity_name: Optional[str] = None, + next_token: Optional[str] = None, + data_store_api_version: Optional[str] = None, + ) -> CallToolResult: + """List entities available for a connection. + + Returns the available entities supported by the connection type. + For example, databases, schemas, or tables for Amazon Redshift, + or SObjects for Salesforce. + + Args: + ctx: MCP context containing request information + connection_name: Name of the connection to list entities for + catalog_id: Optional catalog ID + parent_entity_name: Optional parent entity name for listing children + next_token: Optional pagination token + data_store_api_version: Optional API version of the SaaS connector + + Returns: + ListEntitiesResponse with the list of entities + """ + try: + kwargs: Dict[str, Any] = {'ConnectionName': connection_name} + if catalog_id: + kwargs['CatalogId'] = catalog_id + if parent_entity_name: + kwargs['ParentEntityName'] = parent_entity_name + if next_token: + kwargs['NextToken'] = next_token + if data_store_api_version: + kwargs['DataStoreApiVersion'] = data_store_api_version + + response = self.glue_client.list_entities(**kwargs) + entities = response.get('Entities', []) + next_token_response = response.get('NextToken') + + log_with_request_id( + ctx, + LogLevel.INFO, + f'Successfully listed {len(entities)} entities for connection: {connection_name}', + ) + + success_msg = f'Successfully listed {len(entities)} entities for connection: {connection_name}' + data = ListEntitiesData( + entities=[ + EntitySummary( + entity_name=e.get('EntityName'), + label=e.get('Label'), + is_parent_entity=e.get('IsParentEntity'), + description=e.get('Description'), + category=e.get('Category'), + ) + for e in entities + ], + count=len(entities), + next_token=next_token_response, + ) + + return CallToolResult( + isError=False, + content=[ + TextContent(type='text', text=success_msg), + TextContent(type='text', text=json.dumps(data.model_dump())), + ], + ) + + except ClientError as e: + error_code = e.response['Error']['Code'] + error_message = f'Failed to list entities for connection {connection_name}: {error_code} - {e.response["Error"]["Message"]}' + log_with_request_id(ctx, LogLevel.ERROR, error_message) + + return CallToolResult( + isError=True, + content=[TextContent(type='text', text=error_message)], + ) + + async def describe_entity( + self, + ctx: Context, + connection_name: str, + entity_name: str, + catalog_id: Optional[str] = None, + next_token: Optional[str] = None, + data_store_api_version: Optional[str] = None, + ) -> CallToolResult: + """Describe an entity's fields for a connection. + + Provides details regarding the entity used with the connection type, + with a description of the data model for each field in the selected entity. + + Args: + ctx: MCP context containing request information + connection_name: Name of the connection + entity_name: Name of the entity to describe + catalog_id: Optional catalog ID + next_token: Optional pagination token + data_store_api_version: Optional API version of the data store + + Returns: + DescribeEntityResponse with the entity field details + """ + try: + kwargs: Dict[str, Any] = { + 'ConnectionName': connection_name, + 'EntityName': entity_name, + } + if catalog_id: + kwargs['CatalogId'] = catalog_id + if next_token: + kwargs['NextToken'] = next_token + if data_store_api_version: + kwargs['DataStoreApiVersion'] = data_store_api_version + + response = self.glue_client.describe_entity(**kwargs) + fields = response.get('Fields', []) + next_token_response = response.get('NextToken') + + log_with_request_id( + ctx, + LogLevel.INFO, + f'Successfully described entity {entity_name} for connection: {connection_name}', + ) + + success_msg = f'Successfully described entity {entity_name} with {len(fields)} fields' + data = DescribeEntityData( + fields=[ + FieldSummary( + field_name=f.get('FieldName'), + label=f.get('Label'), + description=f.get('Description'), + field_type=f.get('FieldType'), + is_primary_key=f.get('IsPrimaryKey'), + is_nullable=f.get('IsNullable'), + is_filterable=f.get('IsFilterable'), + is_partitionable=f.get('IsPartitionable'), + is_retrievable=f.get('IsRetrievable'), + is_createable=f.get('IsCreateable'), + is_updateable=f.get('IsUpdateable'), + is_upsertable=f.get('IsUpsertable'), + ) + for f in fields + ], + count=len(fields), + next_token=next_token_response, + ) + + return CallToolResult( + isError=False, + content=[ + TextContent(type='text', text=success_msg), + TextContent(type='text', text=json.dumps(data.model_dump())), + ], + ) + + except ClientError as e: + error_code = e.response['Error']['Code'] + error_message = f'Failed to describe entity {entity_name} for connection {connection_name}: {error_code} - {e.response["Error"]["Message"]}' + log_with_request_id(ctx, LogLevel.ERROR, error_message) + + return CallToolResult( + isError=True, + content=[TextContent(type='text', text=error_message)], + ) + + async def get_entity_records( + self, + ctx: Context, + connection_name: str, + entity_name: str, + limit: int, + catalog_id: Optional[str] = None, + next_token: Optional[str] = None, + data_store_api_version: Optional[str] = None, + connection_options: Optional[Dict[str, str]] = None, + filter_predicate: Optional[str] = None, + selected_fields: Optional[List[str]] = None, + ) -> CallToolResult: + """Get entity records (preview data) from a connection. + + Queries preview data from a given connection type or from a native + Amazon S3 based AWS Glue Data Catalog. Returns records as JSON blobs. + + Args: + ctx: MCP context containing request information + connection_name: Name of the connection + entity_name: Name of the entity to query + limit: Maximum number of records to fetch (1-1000) + catalog_id: Optional catalog ID + next_token: Optional pagination token + data_store_api_version: Optional API version of the SaaS connector + connection_options: Optional connector options for querying data + filter_predicate: Optional filter predicate for the query + selected_fields: Optional list of fields to fetch + + Returns: + GetEntityRecordsResponse with the entity records + """ + try: + kwargs: Dict[str, Any] = { + 'ConnectionName': connection_name, + 'EntityName': entity_name, + 'Limit': limit, + } + if catalog_id: + kwargs['CatalogId'] = catalog_id + if next_token: + kwargs['NextToken'] = next_token + if data_store_api_version: + kwargs['DataStoreApiVersion'] = data_store_api_version + if connection_options: + kwargs['ConnectionOptions'] = connection_options + if filter_predicate: + kwargs['FilterPredicate'] = filter_predicate + if selected_fields: + kwargs['SelectedFields'] = selected_fields + + response = self.glue_client.get_entity_records(**kwargs) + records = response.get('Records', []) + next_token_response = response.get('NextToken') + + log_with_request_id( + ctx, + LogLevel.INFO, + f'Successfully retrieved {len(records)} records for entity {entity_name}', + ) + + success_msg = f'Successfully retrieved {len(records)} records for entity {entity_name}' + data = GetEntityRecordsData( + records=records, + count=len(records), + next_token=next_token_response, + ) + + return CallToolResult( + isError=False, + content=[ + TextContent(type='text', text=success_msg), + TextContent(type='text', text=json.dumps(data.model_dump())), + ], + ) + + except ClientError as e: + error_code = e.response['Error']['Code'] + error_message = f'Failed to get entity records for {entity_name}: {error_code} - {e.response["Error"]["Message"]}' + log_with_request_id(ctx, LogLevel.ERROR, error_message) + + return CallToolResult( + isError=True, + content=[TextContent(type='text', text=error_message)], + ) + async def create_partition( self, ctx: Context, diff --git a/src/aws-dataprocessing-mcp-server/awslabs/aws_dataprocessing_mcp_server/handlers/glue/data_catalog_handler.py b/src/aws-dataprocessing-mcp-server/awslabs/aws_dataprocessing_mcp_server/handlers/glue/data_catalog_handler.py index d6f5d75d8a..d3df731343 100644 --- a/src/aws-dataprocessing-mcp-server/awslabs/aws_dataprocessing_mcp_server/handlers/glue/data_catalog_handler.py +++ b/src/aws-dataprocessing-mcp-server/awslabs/aws_dataprocessing_mcp_server/handlers/glue/data_catalog_handler.py @@ -65,6 +65,12 @@ def __init__(self, mcp, allow_write: bool = False, allow_sensitive_data_access: self.mcp.tool(name='manage_aws_glue_connections')( self.manage_aws_glue_data_catalog_connections ) + self.mcp.tool(name='manage_aws_glue_connection_types')( + self.manage_aws_glue_connection_types + ) + self.mcp.tool(name='manage_aws_glue_connection_metadata')( + self.manage_aws_glue_connection_metadata + ) self.mcp.tool(name='manage_aws_glue_partitions')( self.manage_aws_glue_data_catalog_partitions ) @@ -441,13 +447,13 @@ async def manage_aws_glue_data_catalog_connections( operation: Annotated[ str, Field( - description='Operation to perform: create-connection, delete-connection, get-connection, list-connections, or update-connection. Choose "get-connection" or "list-connections" for read-only operations.', + description='Operation to perform: create-connection, delete-connection, get-connection, list-connections, update-connection, test-connection, or batch-delete-connection. Choose "get-connection" or "list-connections" for read-only operations.', ), ], connection_name: Annotated[ Optional[str], Field( - description='Name of the connection (required for create-connection, delete-connection, get-connection, and update-connection operations).', + description='Name of the connection (required for create-connection, delete-connection, get-connection, update-connection, and test-connection operations).', ), ] = None, connection_input: Annotated[ @@ -476,6 +482,18 @@ async def manage_aws_glue_data_catalog_connections( description='Flag to retrieve the connection metadata without returning the password(for get-connection and list-connections operation).', ), ] = True, + test_connection_input: Annotated[ + Optional[Dict[str, Any]], + Field( + description='TestConnectionInput for testing a non-existing connection (for test-connection operation). Provide either connection_name or test_connection_input.', + ), + ] = None, + connection_name_list: Annotated[ + Optional[List[str]], + Field( + description='List of connection names to delete (required for batch-delete-connection operation, max 25).', + ), + ] = None, ) -> CallToolResult: """Manage AWS Glue Data Catalog connections with both read and write operations. @@ -485,7 +503,7 @@ async def manage_aws_glue_data_catalog_connections( to connect to external data sources. ## Requirements - - The server must be run with the `--allow-write` flag for create, update, and delete operations + - The server must be run with the `--allow-write` flag for create, update, delete, test, and batch-delete operations - Appropriate AWS permissions for Glue Data Catalog operations - Connection properties must be valid for the connection type @@ -495,11 +513,14 @@ async def manage_aws_glue_data_catalog_connections( - **get-connection**: Retrieve detailed information about a specific connection - **list-connections**: List all connections - **update-connection**: Update an existing connection's properties + - **test-connection**: Test a connection to validate service credentials + - **batch-delete-connection**: Delete multiple connections in a single call ## Usage Tips - Connection names must be unique within your catalog - Connection input should include ConnectionType and ConnectionProperties - Use get or list operations to check existing connections before creating + - For test-connection, provide either connection_name (existing) or test_connection_input (new) Args: ctx: MCP context @@ -510,6 +531,8 @@ async def manage_aws_glue_data_catalog_connections( max_results: Maximum results to return next_token: A continuation string token, if this is a continuation call hide_password: The boolean flag to control connection password in return value for get-connection and list-connections operation + test_connection_input: TestConnectionInput for test-connection operation + connection_name_list: List of connection names for batch-delete-connection operation Returns: Union of response types specific to the operation performed @@ -520,8 +543,10 @@ async def manage_aws_glue_data_catalog_connections( 'get-connection', 'list-connections', 'update-connection', + 'test-connection', + 'batch-delete-connection', ]: - error_message = f'Invalid operation: {operation}. Must be one of: create-connection, delete-connection, get-connection, list-connections, update-connection' + error_message = f'Invalid operation: {operation}. Must be one of: create-connection, delete-connection, get-connection, list-connections, update-connection, test-connection, batch-delete-connection' log_with_request_id(ctx, LogLevel.ERROR, error_message) return CallToolResult( isError=True, @@ -590,8 +615,32 @@ async def manage_aws_glue_data_catalog_connections( connection_input=connection_input, catalog_id=catalog_id, ) + + elif operation == 'test-connection': + if connection_name is None and test_connection_input is None: + raise ValueError( + 'Either connection_name or test_connection_input is required for test-connection operation' + ) + return await self.data_catalog_manager.test_connection( + ctx=ctx, + connection_name=connection_name, + catalog_id=catalog_id, + test_connection_input=test_connection_input, + ) + + elif operation == 'batch-delete-connection': + if connection_name_list is None or len(connection_name_list) == 0: + raise ValueError( + 'connection_name_list is required for batch-delete-connection operation' + ) + return await self.data_catalog_manager.batch_delete_connection( + ctx=ctx, + connection_name_list=connection_name_list, + catalog_id=catalog_id, + ) + else: - error_message = f'Invalid operation: {operation}. Must be one of: create-connection, delete-connection, get-connection, list-connections, update-connection' + error_message = f'Invalid operation: {operation}. Must be one of: create-connection, delete-connection, get-connection, list-connections, update-connection, test-connection, batch-delete-connection' log_with_request_id(ctx, LogLevel.ERROR, error_message) return CallToolResult( isError=True, @@ -609,6 +658,311 @@ async def manage_aws_glue_data_catalog_connections( content=[TextContent(type='text', text=error_message)], ) + async def manage_aws_glue_connection_types( + self, + ctx: Context, + operation: Annotated[ + str, + Field( + description='Operation to perform: describe-connection-type, list-connection-types. Both are read-only operations.', + ), + ], + connection_type: Annotated[ + Optional[str], + Field( + description='The name of the connection type to describe (required for describe-connection-type operation, e.g. JDBC, KAFKA, SALESFORCE).', + ), + ] = None, + max_results: Annotated[ + Optional[int], + Field(description='Maximum number of results to return for list-connection-types operation.'), + ] = None, + next_token: Annotated[ + Optional[str], + Field(description='A continuation token, if this is a continuation call.'), + ] = None, + ) -> CallToolResult: + """Discover and describe AWS Glue connection types. + + This tool provides operations for discovering available connection types in AWS Glue + and getting detailed information about specific connection types, including their + supported properties, authentication methods, and compute environments. + + ## Operations + - **describe-connection-type**: Get full details of a specific connection type including properties, auth config, and compute environments + - **list-connection-types**: List all available connection types with brief descriptions + + ## Example + ```python + # List all available connection types + manage_aws_glue_connection_types(operation='list-connection-types') + + # Describe a specific connection type + manage_aws_glue_connection_types(operation='describe-connection-type', connection_type='JDBC') + ``` + + Args: + ctx: MCP context + operation: Operation to perform + connection_type: Name of the connection type (for describe-connection-type) + max_results: Maximum results to return + next_token: Pagination token + + Returns: + Union of response types specific to the operation performed + """ + if operation not in ['describe-connection-type', 'list-connection-types']: + error_message = f'Invalid operation: {operation}. Must be one of: describe-connection-type, list-connection-types' + log_with_request_id(ctx, LogLevel.ERROR, error_message) + return CallToolResult( + isError=True, + content=[TextContent(type='text', text=error_message)], + ) + + try: + if operation == 'describe-connection-type': + if connection_type is None: + raise ValueError( + 'connection_type is required for describe-connection-type operation' + ) + return await self.data_catalog_manager.describe_connection_type( + ctx=ctx, + connection_type=connection_type, + ) + + elif operation == 'list-connection-types': + return await self.data_catalog_manager.list_connection_types( + ctx=ctx, + max_results=max_results, + next_token=next_token, + ) + + else: + error_message = f'Invalid operation: {operation}. Must be one of: describe-connection-type, list-connection-types' + log_with_request_id(ctx, LogLevel.ERROR, error_message) + return CallToolResult( + isError=True, + content=[TextContent(type='text', text=error_message)], + ) + + except ValueError as e: + log_with_request_id(ctx, LogLevel.ERROR, f'Parameter validation error: {str(e)}') + raise + except Exception as e: + error_message = f'Error in manage_aws_glue_connection_types: {str(e)}' + log_with_request_id(ctx, LogLevel.ERROR, error_message) + return CallToolResult( + isError=True, + content=[TextContent(type='text', text=error_message)], + ) + + async def manage_aws_glue_connection_metadata( + self, + ctx: Context, + operation: Annotated[ + str, + Field( + description='Operation to perform: list-entities, describe-entity, get-entity-records. Choose "list-entities" or "describe-entity" for metadata-only operations. "get-entity-records" requires --allow-sensitive-data-access flag.', + ), + ], + connection_name: Annotated[ + str, + Field( + description='Name of the connection that has required credentials to query the connection type (required for all operations).', + ), + ], + entity_name: Annotated[ + Optional[str], + Field( + description='Name of the entity (required for describe-entity and get-entity-records operations).', + ), + ] = None, + catalog_id: Annotated[ + Optional[str], + Field( + description='Catalog ID (optional, defaults to AWS account ID).', + ), + ] = None, + parent_entity_name: Annotated[ + Optional[str], + Field( + description='Name of the parent entity for listing child entities (for list-entities operation).', + ), + ] = None, + next_token: Annotated[ + Optional[str], + Field(description='A continuation token, if this is a continuation call.'), + ] = None, + data_store_api_version: Annotated[ + Optional[str], + Field( + description='The API version of the SaaS connector.', + ), + ] = None, + limit: Annotated[ + Optional[int], + Field( + description='Maximum number of records to fetch (1-1000, required for get-entity-records operation).', + ), + ] = None, + connection_options: Annotated[ + Optional[Dict[str, str]], + Field( + description='Connector options required to query the data (for get-entity-records operation).', + ), + ] = None, + filter_predicate: Annotated[ + Optional[str], + Field( + description='A filter predicate to apply in the query request (for get-entity-records operation).', + ), + ] = None, + selected_fields: Annotated[ + Optional[List[str]], + Field( + description='List of fields to fetch as part of preview data (for get-entity-records operation).', + ), + ] = None, + ) -> CallToolResult: + """Access connection metadata and preview entity data from AWS Glue connections. + + This tool provides operations for discovering entities available through a connection, + describing entity schemas, and previewing entity data. Useful for exploring data sources + connected via AWS Glue connections such as SaaS applications, databases, and other data stores. + + ## Requirements + - The server must be run with the `--allow-sensitive-data-access` flag for get-entity-records operation + - Appropriate AWS permissions for Glue connection metadata operations + - A valid connection with credentials must exist + + ## Operations + - **list-entities**: List available entities (e.g., tables, SObjects) for a connection + - **describe-entity**: Get the schema/field details for a specific entity + - **get-entity-records**: Preview data records from an entity (requires sensitive data access) + + ## Example + ```python + # List entities for a Salesforce connection + manage_aws_glue_connection_metadata( + operation='list-entities', + connection_name='my-salesforce-connection', + ) + + # Describe the Account entity + manage_aws_glue_connection_metadata( + operation='describe-entity', + connection_name='my-salesforce-connection', + entity_name='Account', + ) + + # Preview records from the Account entity + manage_aws_glue_connection_metadata( + operation='get-entity-records', + connection_name='my-salesforce-connection', + entity_name='Account', + limit=10, + ) + ``` + + Args: + ctx: MCP context + operation: Operation to perform + connection_name: Name of the connection + entity_name: Name of the entity + catalog_id: Catalog ID + parent_entity_name: Parent entity name for listing children + next_token: Pagination token + data_store_api_version: API version of the SaaS connector + limit: Maximum number of records to fetch + connection_options: Connector options for querying data + filter_predicate: Filter predicate for the query + selected_fields: List of fields to fetch + + Returns: + Union of response types specific to the operation performed + """ + if operation not in ['list-entities', 'describe-entity', 'get-entity-records']: + error_message = f'Invalid operation: {operation}. Must be one of: list-entities, describe-entity, get-entity-records' + log_with_request_id(ctx, LogLevel.ERROR, error_message) + return CallToolResult( + isError=True, + content=[TextContent(type='text', text=error_message)], + ) + + try: + if operation == 'get-entity-records' and not self.allow_sensitive_data_access: + error_message = 'Operation get-entity-records requires --allow-sensitive-data-access flag to be enabled' + log_with_request_id(ctx, LogLevel.ERROR, error_message) + return CallToolResult( + isError=True, + content=[TextContent(type='text', text=error_message)], + ) + + if operation == 'list-entities': + return await self.data_catalog_manager.list_entities( + ctx=ctx, + connection_name=connection_name, + catalog_id=catalog_id, + parent_entity_name=parent_entity_name, + next_token=next_token, + data_store_api_version=data_store_api_version, + ) + + elif operation == 'describe-entity': + if entity_name is None: + raise ValueError( + 'entity_name is required for describe-entity operation' + ) + return await self.data_catalog_manager.describe_entity( + ctx=ctx, + connection_name=connection_name, + entity_name=entity_name, + catalog_id=catalog_id, + next_token=next_token, + data_store_api_version=data_store_api_version, + ) + + elif operation == 'get-entity-records': + if entity_name is None: + raise ValueError( + 'entity_name is required for get-entity-records operation' + ) + if limit is None: + raise ValueError( + 'limit is required for get-entity-records operation' + ) + return await self.data_catalog_manager.get_entity_records( + ctx=ctx, + connection_name=connection_name, + entity_name=entity_name, + limit=limit, + catalog_id=catalog_id, + next_token=next_token, + data_store_api_version=data_store_api_version, + connection_options=connection_options, + filter_predicate=filter_predicate, + selected_fields=selected_fields, + ) + + else: + error_message = f'Invalid operation: {operation}. Must be one of: list-entities, describe-entity, get-entity-records' + log_with_request_id(ctx, LogLevel.ERROR, error_message) + return CallToolResult( + isError=True, + content=[TextContent(type='text', text=error_message)], + ) + + except ValueError as e: + log_with_request_id(ctx, LogLevel.ERROR, f'Parameter validation error: {str(e)}') + raise + except Exception as e: + error_message = f'Error in manage_aws_glue_connection_metadata: {str(e)}' + log_with_request_id(ctx, LogLevel.ERROR, error_message) + return CallToolResult( + isError=True, + content=[TextContent(type='text', text=error_message)], + ) + async def manage_aws_glue_data_catalog_partitions( self, ctx: Context, diff --git a/src/aws-dataprocessing-mcp-server/awslabs/aws_dataprocessing_mcp_server/models/data_catalog_models.py b/src/aws-dataprocessing-mcp-server/awslabs/aws_dataprocessing_mcp_server/models/data_catalog_models.py index f10c2d64bf..c4bd18fb67 100644 --- a/src/aws-dataprocessing-mcp-server/awslabs/aws_dataprocessing_mcp_server/models/data_catalog_models.py +++ b/src/aws-dataprocessing-mcp-server/awslabs/aws_dataprocessing_mcp_server/models/data_catalog_models.py @@ -273,6 +273,112 @@ class UpdateConnectionData(BaseModel): operation: str = Field(default='update', description='Operation performed') +class TestConnectionData(BaseModel): + """Data model for test connection operation.""" + + connection_name: Optional[str] = Field(None, description='Name of the tested connection') + catalog_id: Optional[str] = Field(None, description='Catalog ID containing the connection') + operation: str = Field(default='test-connection', description='Operation performed') + + +class BatchDeleteConnectionData(BaseModel): + """Data model for batch delete connection operation.""" + + succeeded: List[str] = Field(default_factory=list, description='Connections successfully deleted') + errors: Dict[str, Any] = Field(default_factory=dict, description='Connections that failed to delete with error details') + catalog_id: Optional[str] = Field(None, description='Catalog ID containing the connections') + operation: str = Field(default='batch-delete-connection', description='Operation performed') + + +class DescribeConnectionTypeData(BaseModel): + """Data model for describe connection type operation.""" + + connection_type: str = Field(..., description='Name of the connection type') + description: Optional[str] = Field(None, description='Description of the connection type') + capabilities: Optional[Dict[str, Any]] = Field(None, description='Supported capabilities') + connection_properties: Optional[Dict[str, Any]] = Field(None, description='Common connection properties') + connection_options: Optional[Dict[str, Any]] = Field(None, description='Connection options for Spark ETL') + authentication_configuration: Optional[Dict[str, Any]] = Field(None, description='Authentication configuration') + compute_environment_configurations: Optional[Dict[str, Any]] = Field(None, description='Supported compute environments') + physical_connection_requirements: Optional[Dict[str, Any]] = Field(None, description='Physical connection requirements') + athena_connection_properties: Optional[Dict[str, Any]] = Field(None, description='Athena-specific properties') + python_connection_properties: Optional[Dict[str, Any]] = Field(None, description='Python-specific properties') + spark_connection_properties: Optional[Dict[str, Any]] = Field(None, description='Spark-specific properties') + operation: str = Field(default='describe-connection-type', description='Operation performed') + + +class ConnectionTypeBrief(BaseModel): + """Summary model for a connection type.""" + + connection_type: Optional[str] = Field(None, description='Name of the connection type') + display_name: Optional[str] = Field(None, description='Human-readable display name') + vendor: Optional[str] = Field(None, description='Vendor name') + description: Optional[str] = Field(None, description='Description of the connection type') + + +class ListConnectionTypesData(BaseModel): + """Data model for list connection types operation.""" + + connection_types: List[ConnectionTypeBrief] = Field(default_factory=list, description='List of connection types') + count: int = Field(0, description='Number of connection types returned') + next_token: Optional[str] = Field(None, description='Token for pagination') + operation: str = Field(default='list-connection-types', description='Operation performed') + + +class EntitySummary(BaseModel): + """Summary model for a connection entity.""" + + entity_name: Optional[str] = Field(None, description='Name of the entity') + label: Optional[str] = Field(None, description='Label for the entity') + is_parent_entity: Optional[bool] = Field(None, description='Whether entity has sub-objects') + description: Optional[str] = Field(None, description='Description of the entity') + category: Optional[str] = Field(None, description='Category of the entity') + + +class ListEntitiesData(BaseModel): + """Data model for list entities operation.""" + + entities: List[EntitySummary] = Field(default_factory=list, description='List of entities') + count: int = Field(0, description='Number of entities returned') + next_token: Optional[str] = Field(None, description='Token for pagination') + operation: str = Field(default='list-entities', description='Operation performed') + + +class FieldSummary(BaseModel): + """Summary model for an entity field.""" + + field_name: Optional[str] = Field(None, description='Unique identifier for the field') + label: Optional[str] = Field(None, description='Readable label for the field') + description: Optional[str] = Field(None, description='Description of the field') + field_type: Optional[str] = Field(None, description='Data type of the field') + is_primary_key: Optional[bool] = Field(None, description='Whether field is a primary key') + is_nullable: Optional[bool] = Field(None, description='Whether field is nullable') + is_filterable: Optional[bool] = Field(None, description='Whether field can be used in filters') + is_partitionable: Optional[bool] = Field(None, description='Whether field can be used for partitioning') + is_retrievable: Optional[bool] = Field(None, description='Whether field can be retrieved') + is_createable: Optional[bool] = Field(None, description='Whether field can be created') + is_updateable: Optional[bool] = Field(None, description='Whether field can be updated') + is_upsertable: Optional[bool] = Field(None, description='Whether field can be upserted') + + +class DescribeEntityData(BaseModel): + """Data model for describe entity operation.""" + + fields: List[FieldSummary] = Field(default_factory=list, description='List of entity fields') + count: int = Field(0, description='Number of fields returned') + next_token: Optional[str] = Field(None, description='Token for pagination') + operation: str = Field(default='describe-entity', description='Operation performed') + + +class GetEntityRecordsData(BaseModel): + """Data model for get entity records operation.""" + + records: List[Dict[str, Any]] = Field(default_factory=list, description='List of entity records') + count: int = Field(0, description='Number of records returned') + next_token: Optional[str] = Field(None, description='Token for pagination') + operation: str = Field(default='get-entity-records', description='Operation performed') + + # Partition Data Models class CreatePartitionData(BaseModel): """Data model for create partition operation.""" diff --git a/src/aws-dataprocessing-mcp-server/awslabs/aws_dataprocessing_mcp_server/server.py b/src/aws-dataprocessing-mcp-server/awslabs/aws_dataprocessing_mcp_server/server.py index 3c73a29e32..dae542d024 100644 --- a/src/aws-dataprocessing-mcp-server/awslabs/aws_dataprocessing_mcp_server/server.py +++ b/src/aws-dataprocessing-mcp-server/awslabs/aws_dataprocessing_mcp_server/server.py @@ -126,6 +126,21 @@ 2. Delete a table: `manage_aws_glue_tables(operation='delete-table', database_name='my-database', table_name='my-table')` 3. Delete a connection: `manage_aws_glue_connections(operation='delete-connection', connection_name='my-connection')` 4. Delete a database: `manage_aws_glue_databases(operation='delete-database', database_name='my-database')` +5. Batch delete connections: `manage_aws_glue_connections(operation='batch-delete-connection', connection_name_list=['conn-1', 'conn-2'])` + +### Testing Connections +1. Test an existing connection: `manage_aws_glue_connections(operation='test-connection', connection_name='my-connection')` +2. Test a new connection input: `manage_aws_glue_connections(operation='test-connection', test_connection_input={'ConnectionType': 'JDBC', 'ConnectionProperties': {...}})` + +### Discovering Connection Types +1. List all available connection types: `manage_aws_glue_connection_types(operation='list-connection-types')` +2. Describe a specific connection type: `manage_aws_glue_connection_types(operation='describe-connection-type', connection_type='JDBC')` + +### Exploring Connection Metadata and Entity Data +1. List entities for a connection: `manage_aws_glue_connection_metadata(operation='list-entities', connection_name='my-connection')` +2. List child entities: `manage_aws_glue_connection_metadata(operation='list-entities', connection_name='my-connection', parent_entity_name='my-database')` +3. Describe an entity's schema: `manage_aws_glue_connection_metadata(operation='describe-entity', connection_name='my-connection', entity_name='my-table')` +4. Preview entity records: `manage_aws_glue_connection_metadata(operation='get-entity-records', connection_name='my-connection', entity_name='my-table', limit=10)` ### Setup EMR EC2 Cluster diff --git a/src/aws-dataprocessing-mcp-server/tests/core/glue_data_catalog/test_data_catalog_handler.py b/src/aws-dataprocessing-mcp-server/tests/core/glue_data_catalog/test_data_catalog_handler.py index 3bb8c43fca..3e22feafe8 100644 --- a/src/aws-dataprocessing-mcp-server/tests/core/glue_data_catalog/test_data_catalog_handler.py +++ b/src/aws-dataprocessing-mcp-server/tests/core/glue_data_catalog/test_data_catalog_handler.py @@ -2017,3 +2017,526 @@ async def test_update_partition_with_new_parameters(self, manager, mock_ctx, moc assert result.isError is False assert len(result.content) >= 1 assert hasattr(result.content[0], 'text') + + # ==================== TestConnection Tests ==================== + + @pytest.mark.asyncio + async def test_test_connection_success(self, manager, mock_ctx, mock_glue_client): + """Test that test_connection returns a successful response.""" + mock_glue_client.test_connection.return_value = {} + + result = await manager.test_connection( + mock_ctx, + connection_name='test-connection', + catalog_id='123456789012', + ) + + mock_glue_client.test_connection.assert_called_once_with( + ConnectionName='test-connection', + CatalogId='123456789012', + ) + assert isinstance(result, CallToolResult) + assert result.isError is False + assert 'Successfully tested connection' in result.content[0].text + + @pytest.mark.asyncio + async def test_test_connection_with_test_input(self, manager, mock_ctx, mock_glue_client): + """Test that test_connection works with TestConnectionInput.""" + mock_glue_client.test_connection.return_value = {} + test_input = { + 'ConnectionType': 'JDBC', + 'ConnectionProperties': { + 'JDBC_CONNECTION_URL': 'jdbc:mysql://localhost:3306/test', + }, + } + + result = await manager.test_connection( + mock_ctx, + test_connection_input=test_input, + ) + + mock_glue_client.test_connection.assert_called_once_with( + TestConnectionInput=test_input, + ) + assert isinstance(result, CallToolResult) + assert result.isError is False + + @pytest.mark.asyncio + async def test_test_connection_error(self, manager, mock_ctx, mock_glue_client): + """Test that test_connection handles errors correctly.""" + mock_glue_client.test_connection.side_effect = ClientError( + {'Error': {'Code': 'InvalidInputException', 'Message': 'Invalid connection'}}, + 'TestConnection', + ) + + result = await manager.test_connection( + mock_ctx, + connection_name='bad-connection', + ) + + assert isinstance(result, CallToolResult) + assert result.isError is True + assert 'Failed to test connection' in result.content[0].text + + # ==================== BatchDeleteConnection Tests ==================== + + @pytest.mark.asyncio + async def test_batch_delete_connection_success(self, manager, mock_ctx, mock_glue_client): + """Test that batch_delete_connection returns a successful response.""" + mock_glue_client.batch_delete_connection.return_value = { + 'Succeeded': ['conn-1', 'conn-2'], + 'Errors': {}, + } + + with ( + patch( + 'awslabs.aws_dataprocessing_mcp_server.utils.aws_helper.AwsHelper.is_resource_mcp_managed', + return_value=True, + ), + patch( + 'awslabs.aws_dataprocessing_mcp_server.utils.aws_helper.AwsHelper.get_aws_region', + return_value='us-east-1', + ), + patch( + 'awslabs.aws_dataprocessing_mcp_server.utils.aws_helper.AwsHelper.get_aws_account_id', + return_value='123456789012', + ), + patch( + 'awslabs.aws_dataprocessing_mcp_server.utils.aws_helper.AwsHelper.get_aws_partition', + return_value='aws', + ), + ): + result = await manager.batch_delete_connection( + mock_ctx, + connection_name_list=['conn-1', 'conn-2'], + ) + + mock_glue_client.batch_delete_connection.assert_called_once_with( + ConnectionNameList=['conn-1', 'conn-2'], + ) + assert isinstance(result, CallToolResult) + assert result.isError is False + assert '2 succeeded' in result.content[0].text + + @pytest.mark.asyncio + async def test_batch_delete_connection_not_mcp_managed(self, manager, mock_ctx, mock_glue_client): + """Test that batch_delete_connection rejects non-MCP-managed connections.""" + with ( + patch( + 'awslabs.aws_dataprocessing_mcp_server.utils.aws_helper.AwsHelper.is_resource_mcp_managed', + return_value=False, + ), + patch( + 'awslabs.aws_dataprocessing_mcp_server.utils.aws_helper.AwsHelper.get_aws_region', + return_value='us-east-1', + ), + patch( + 'awslabs.aws_dataprocessing_mcp_server.utils.aws_helper.AwsHelper.get_aws_account_id', + return_value='123456789012', + ), + patch( + 'awslabs.aws_dataprocessing_mcp_server.utils.aws_helper.AwsHelper.get_aws_partition', + return_value='aws', + ), + ): + result = await manager.batch_delete_connection( + mock_ctx, + connection_name_list=['conn-1'], + ) + + assert isinstance(result, CallToolResult) + assert result.isError is True + assert 'not managed by the MCP server' in result.content[0].text + + @pytest.mark.asyncio + async def test_batch_delete_connection_error(self, manager, mock_ctx, mock_glue_client): + """Test that batch_delete_connection handles API errors.""" + with ( + patch( + 'awslabs.aws_dataprocessing_mcp_server.utils.aws_helper.AwsHelper.is_resource_mcp_managed', + return_value=True, + ), + patch( + 'awslabs.aws_dataprocessing_mcp_server.utils.aws_helper.AwsHelper.get_aws_region', + return_value='us-east-1', + ), + patch( + 'awslabs.aws_dataprocessing_mcp_server.utils.aws_helper.AwsHelper.get_aws_account_id', + return_value='123456789012', + ), + patch( + 'awslabs.aws_dataprocessing_mcp_server.utils.aws_helper.AwsHelper.get_aws_partition', + return_value='aws', + ), + ): + mock_glue_client.batch_delete_connection.side_effect = ClientError( + {'Error': {'Code': 'InternalServiceException', 'Message': 'Internal error'}}, + 'BatchDeleteConnection', + ) + result = await manager.batch_delete_connection( + mock_ctx, + connection_name_list=['conn-1'], + ) + + assert isinstance(result, CallToolResult) + assert result.isError is True + assert 'Failed to batch delete connections' in result.content[0].text + + # ==================== DescribeConnectionType Tests ==================== + + @pytest.mark.asyncio + async def test_describe_connection_type_success(self, manager, mock_ctx, mock_glue_client): + """Test that describe_connection_type returns a successful response.""" + mock_glue_client.describe_connection_type.return_value = { + 'ConnectionType': 'JDBC', + 'Description': 'JDBC connection type', + 'Capabilities': {'SupportedAuthenticationTypes': ['BASIC']}, + 'ConnectionProperties': {'HOST': {'Description': 'Host name'}}, + 'AuthenticationConfiguration': {'AuthenticationType': 'BASIC'}, + } + + result = await manager.describe_connection_type( + mock_ctx, + connection_type='JDBC', + ) + + mock_glue_client.describe_connection_type.assert_called_once_with( + ConnectionType='JDBC', + ) + assert isinstance(result, CallToolResult) + assert result.isError is False + assert 'Successfully described connection type: JDBC' in result.content[0].text + + @pytest.mark.asyncio + async def test_describe_connection_type_error(self, manager, mock_ctx, mock_glue_client): + """Test that describe_connection_type handles errors correctly.""" + mock_glue_client.describe_connection_type.side_effect = ClientError( + {'Error': {'Code': 'InvalidInputException', 'Message': 'Unknown type'}}, + 'DescribeConnectionType', + ) + + result = await manager.describe_connection_type( + mock_ctx, + connection_type='UNKNOWN', + ) + + assert isinstance(result, CallToolResult) + assert result.isError is True + assert 'Failed to describe connection type' in result.content[0].text + + # ==================== ListConnectionTypes Tests ==================== + + @pytest.mark.asyncio + async def test_list_connection_types_success(self, manager, mock_ctx, mock_glue_client): + """Test that list_connection_types returns a successful response.""" + mock_glue_client.list_connection_types.return_value = { + 'ConnectionTypes': [ + { + 'ConnectionType': 'JDBC', + 'DisplayName': 'JDBC', + 'Vendor': 'AWS', + 'Description': 'JDBC connections', + }, + { + 'ConnectionType': 'KAFKA', + 'DisplayName': 'Apache Kafka', + 'Vendor': 'Apache', + 'Description': 'Kafka connections', + }, + ], + 'NextToken': None, + } + + result = await manager.list_connection_types(mock_ctx) + + mock_glue_client.list_connection_types.assert_called_once_with() + assert isinstance(result, CallToolResult) + assert result.isError is False + assert 'Successfully listed 2 connection types' in result.content[0].text + + @pytest.mark.asyncio + async def test_list_connection_types_with_pagination(self, manager, mock_ctx, mock_glue_client): + """Test that list_connection_types handles pagination parameters.""" + mock_glue_client.list_connection_types.return_value = { + 'ConnectionTypes': [{'ConnectionType': 'JDBC'}], + 'NextToken': 'next-page-token', + } + + result = await manager.list_connection_types( + mock_ctx, + max_results=1, + next_token='prev-token', + ) + + mock_glue_client.list_connection_types.assert_called_once_with( + MaxResults=1, + NextToken='prev-token', + ) + assert isinstance(result, CallToolResult) + assert result.isError is False + + @pytest.mark.asyncio + async def test_list_connection_types_error(self, manager, mock_ctx, mock_glue_client): + """Test that list_connection_types handles errors correctly.""" + mock_glue_client.list_connection_types.side_effect = ClientError( + {'Error': {'Code': 'InternalServiceException', 'Message': 'Internal error'}}, + 'ListConnectionTypes', + ) + + result = await manager.list_connection_types(mock_ctx) + + assert isinstance(result, CallToolResult) + assert result.isError is True + assert 'Failed to list connection types' in result.content[0].text + + # ==================== ListEntities Tests ==================== + + @pytest.mark.asyncio + async def test_list_entities_success(self, manager, mock_ctx, mock_glue_client): + """Test that list_entities returns a successful response.""" + mock_glue_client.list_entities.return_value = { + 'Entities': [ + { + 'EntityName': 'Account', + 'Label': 'Account', + 'IsParentEntity': False, + 'Description': 'Salesforce Account object', + 'Category': 'SObject', + }, + { + 'EntityName': 'Contact', + 'Label': 'Contact', + 'IsParentEntity': False, + 'Description': 'Salesforce Contact object', + 'Category': 'SObject', + }, + ], + 'NextToken': None, + } + + result = await manager.list_entities( + mock_ctx, + connection_name='my-salesforce-conn', + ) + + mock_glue_client.list_entities.assert_called_once_with( + ConnectionName='my-salesforce-conn', + ) + assert isinstance(result, CallToolResult) + assert result.isError is False + assert 'Successfully listed 2 entities' in result.content[0].text + + @pytest.mark.asyncio + async def test_list_entities_with_all_parameters(self, manager, mock_ctx, mock_glue_client): + """Test that list_entities handles all optional parameters.""" + mock_glue_client.list_entities.return_value = { + 'Entities': [], + 'NextToken': None, + } + + result = await manager.list_entities( + mock_ctx, + connection_name='my-conn', + catalog_id='123456789012', + parent_entity_name='my-database', + next_token='token-123', + data_store_api_version='v1', + ) + + mock_glue_client.list_entities.assert_called_once_with( + ConnectionName='my-conn', + CatalogId='123456789012', + ParentEntityName='my-database', + NextToken='token-123', + DataStoreApiVersion='v1', + ) + assert isinstance(result, CallToolResult) + assert result.isError is False + + @pytest.mark.asyncio + async def test_list_entities_error(self, manager, mock_ctx, mock_glue_client): + """Test that list_entities handles errors correctly.""" + mock_glue_client.list_entities.side_effect = ClientError( + {'Error': {'Code': 'EntityNotFoundException', 'Message': 'Connection not found'}}, + 'ListEntities', + ) + + result = await manager.list_entities( + mock_ctx, + connection_name='bad-conn', + ) + + assert isinstance(result, CallToolResult) + assert result.isError is True + assert 'Failed to list entities' in result.content[0].text + + # ==================== DescribeEntity Tests ==================== + + @pytest.mark.asyncio + async def test_describe_entity_success(self, manager, mock_ctx, mock_glue_client): + """Test that describe_entity returns a successful response.""" + mock_glue_client.describe_entity.return_value = { + 'Fields': [ + { + 'FieldName': 'Id', + 'Label': 'ID', + 'FieldType': 'STRING', + 'IsPrimaryKey': True, + 'IsNullable': False, + 'IsFilterable': True, + 'IsRetrievable': True, + }, + { + 'FieldName': 'Name', + 'Label': 'Name', + 'FieldType': 'STRING', + 'IsPrimaryKey': False, + 'IsNullable': True, + 'IsFilterable': True, + 'IsRetrievable': True, + }, + ], + 'NextToken': None, + } + + result = await manager.describe_entity( + mock_ctx, + connection_name='my-conn', + entity_name='Account', + ) + + mock_glue_client.describe_entity.assert_called_once_with( + ConnectionName='my-conn', + EntityName='Account', + ) + assert isinstance(result, CallToolResult) + assert result.isError is False + assert '2 fields' in result.content[0].text + + @pytest.mark.asyncio + async def test_describe_entity_with_all_parameters(self, manager, mock_ctx, mock_glue_client): + """Test that describe_entity handles all optional parameters.""" + mock_glue_client.describe_entity.return_value = { + 'Fields': [], + 'NextToken': None, + } + + result = await manager.describe_entity( + mock_ctx, + connection_name='my-conn', + entity_name='Account', + catalog_id='123456789012', + next_token='token-123', + data_store_api_version='v1', + ) + + mock_glue_client.describe_entity.assert_called_once_with( + ConnectionName='my-conn', + EntityName='Account', + CatalogId='123456789012', + NextToken='token-123', + DataStoreApiVersion='v1', + ) + assert isinstance(result, CallToolResult) + assert result.isError is False + + @pytest.mark.asyncio + async def test_describe_entity_error(self, manager, mock_ctx, mock_glue_client): + """Test that describe_entity handles errors correctly.""" + mock_glue_client.describe_entity.side_effect = ClientError( + {'Error': {'Code': 'EntityNotFoundException', 'Message': 'Entity not found'}}, + 'DescribeEntity', + ) + + result = await manager.describe_entity( + mock_ctx, + connection_name='my-conn', + entity_name='BadEntity', + ) + + assert isinstance(result, CallToolResult) + assert result.isError is True + assert 'Failed to describe entity' in result.content[0].text + + # ==================== GetEntityRecords Tests ==================== + + @pytest.mark.asyncio + async def test_get_entity_records_success(self, manager, mock_ctx, mock_glue_client): + """Test that get_entity_records returns a successful response.""" + mock_glue_client.get_entity_records.return_value = { + 'Records': [ + {'Id': '001', 'Name': 'Acme Corp'}, + {'Id': '002', 'Name': 'Globex Inc'}, + ], + 'NextToken': None, + } + + result = await manager.get_entity_records( + mock_ctx, + connection_name='my-conn', + entity_name='Account', + limit=10, + ) + + mock_glue_client.get_entity_records.assert_called_once_with( + ConnectionName='my-conn', + EntityName='Account', + Limit=10, + ) + assert isinstance(result, CallToolResult) + assert result.isError is False + assert 'Successfully retrieved 2 records' in result.content[0].text + + @pytest.mark.asyncio + async def test_get_entity_records_with_all_parameters(self, manager, mock_ctx, mock_glue_client): + """Test that get_entity_records handles all optional parameters.""" + mock_glue_client.get_entity_records.return_value = { + 'Records': [], + 'NextToken': None, + } + + result = await manager.get_entity_records( + mock_ctx, + connection_name='my-conn', + entity_name='Account', + limit=5, + catalog_id='123456789012', + next_token='token-123', + data_store_api_version='v1', + connection_options={'key': 'value'}, + filter_predicate="Name = 'Acme'", + selected_fields=['Id', 'Name'], + ) + + mock_glue_client.get_entity_records.assert_called_once_with( + ConnectionName='my-conn', + EntityName='Account', + Limit=5, + CatalogId='123456789012', + NextToken='token-123', + DataStoreApiVersion='v1', + ConnectionOptions={'key': 'value'}, + FilterPredicate="Name = 'Acme'", + SelectedFields=['Id', 'Name'], + ) + assert isinstance(result, CallToolResult) + assert result.isError is False + + @pytest.mark.asyncio + async def test_get_entity_records_error(self, manager, mock_ctx, mock_glue_client): + """Test that get_entity_records handles errors correctly.""" + mock_glue_client.get_entity_records.side_effect = ClientError( + {'Error': {'Code': 'ValidationException', 'Message': 'Invalid request'}}, + 'GetEntityRecords', + ) + + result = await manager.get_entity_records( + mock_ctx, + connection_name='my-conn', + entity_name='Account', + limit=10, + ) + + assert isinstance(result, CallToolResult) + assert result.isError is True + assert 'Failed to get entity records' in result.content[0].text diff --git a/src/aws-dataprocessing-mcp-server/tests/handlers/glue/test_data_catalog_handler.py b/src/aws-dataprocessing-mcp-server/tests/handlers/glue/test_data_catalog_handler.py index 5aae24a135..d02196077b 100644 --- a/src/aws-dataprocessing-mcp-server/tests/handlers/glue/test_data_catalog_handler.py +++ b/src/aws-dataprocessing-mcp-server/tests/handlers/glue/test_data_catalog_handler.py @@ -118,7 +118,7 @@ def test_initialization(self, mock_mcp): assert handler.allow_sensitive_data_access is False # Verify that the tools were registered - assert mock_mcp.tool.call_count == 5 + assert mock_mcp.tool.call_count == 7 # Get all call args call_args_list = mock_mcp.tool.call_args_list @@ -130,6 +130,8 @@ def test_initialization(self, mock_mcp): assert 'manage_aws_glue_databases' in tool_names assert 'manage_aws_glue_tables' in tool_names assert 'manage_aws_glue_connections' in tool_names + assert 'manage_aws_glue_connection_types' in tool_names + assert 'manage_aws_glue_connection_metadata' in tool_names assert 'manage_aws_glue_partitions' in tool_names assert 'manage_aws_glue_catalog' in tool_names @@ -3736,3 +3738,420 @@ async def test_manage_aws_glue_data_catalog_connections_get_with_all_parameters( # Verify that the result is the expected response assert result == expected_response + + # ==================== New Fixtures ==================== + + @pytest.fixture + def handler_with_sensitive_data_access( + self, mock_mcp, mock_database_manager, mock_table_manager, mock_catalog_manager + ): + """Create a GlueDataCatalogHandler instance with sensitive data access enabled.""" + with ( + patch( + 'awslabs.aws_dataprocessing_mcp_server.handlers.glue.data_catalog_handler.DataCatalogDatabaseManager', + return_value=mock_database_manager, + ), + patch( + 'awslabs.aws_dataprocessing_mcp_server.handlers.glue.data_catalog_handler.DataCatalogTableManager', + return_value=mock_table_manager, + ), + patch( + 'awslabs.aws_dataprocessing_mcp_server.handlers.glue.data_catalog_handler.DataCatalogManager', + return_value=mock_catalog_manager, + ), + ): + handler = GlueDataCatalogHandler( + mock_mcp, allow_write=True, allow_sensitive_data_access=True + ) + handler.data_catalog_database_manager = mock_database_manager + handler.data_catalog_table_manager = mock_table_manager + handler.data_catalog_manager = mock_catalog_manager + return handler + + # ==================== TestConnection Handler Tests ==================== + + @pytest.mark.asyncio + async def test_manage_connections_test_connection_no_write_access(self, handler, mock_ctx): + """Test that test-connection is not allowed without write access.""" + result = await handler.manage_aws_glue_data_catalog_connections( + mock_ctx, + operation='test-connection', + connection_name='test-conn', + ) + assert result.isError is True + assert 'not allowed without write access' in result.content[0].text + + @pytest.mark.asyncio + async def test_manage_connections_test_connection_with_write_access( + self, handler_with_write_access, mock_ctx, mock_catalog_manager + ): + """Test that test-connection works with write access.""" + expected_response = MagicMock() + expected_response.isError = False + expected_response.content = [] + mock_catalog_manager.test_connection.return_value = expected_response + + result = await handler_with_write_access.manage_aws_glue_data_catalog_connections( + mock_ctx, + operation='test-connection', + connection_name='test-conn', + ) + + mock_catalog_manager.test_connection.assert_called_once() + assert result == expected_response + + @pytest.mark.asyncio + async def test_manage_connections_test_connection_missing_params( + self, handler_with_write_access, mock_ctx + ): + """Test that test-connection requires connection_name or test_connection_input.""" + with pytest.raises(ValueError, match='Either connection_name or test_connection_input'): + await handler_with_write_access.manage_aws_glue_data_catalog_connections( + mock_ctx, + operation='test-connection', + ) + + # ==================== BatchDeleteConnection Handler Tests ==================== + + @pytest.mark.asyncio + async def test_manage_connections_batch_delete_no_write_access(self, handler, mock_ctx): + """Test that batch-delete-connection is not allowed without write access.""" + result = await handler.manage_aws_glue_data_catalog_connections( + mock_ctx, + operation='batch-delete-connection', + connection_name_list=['conn-1', 'conn-2'], + ) + assert result.isError is True + assert 'not allowed without write access' in result.content[0].text + + @pytest.mark.asyncio + async def test_manage_connections_batch_delete_with_write_access( + self, handler_with_write_access, mock_ctx, mock_catalog_manager + ): + """Test that batch-delete-connection works with write access.""" + expected_response = MagicMock() + expected_response.isError = False + expected_response.content = [] + mock_catalog_manager.batch_delete_connection.return_value = expected_response + + result = await handler_with_write_access.manage_aws_glue_data_catalog_connections( + mock_ctx, + operation='batch-delete-connection', + connection_name_list=['conn-1', 'conn-2'], + ) + + mock_catalog_manager.batch_delete_connection.assert_called_once() + assert result == expected_response + + @pytest.mark.asyncio + async def test_manage_connections_batch_delete_missing_list( + self, handler_with_write_access, mock_ctx + ): + """Test that batch-delete-connection requires connection_name_list.""" + with pytest.raises(ValueError, match='connection_name_list is required'): + await handler_with_write_access.manage_aws_glue_data_catalog_connections( + mock_ctx, + operation='batch-delete-connection', + ) + + @pytest.mark.asyncio + async def test_manage_connections_batch_delete_empty_list( + self, handler_with_write_access, mock_ctx + ): + """Test that batch-delete-connection rejects empty list.""" + with pytest.raises(ValueError, match='connection_name_list is required'): + await handler_with_write_access.manage_aws_glue_data_catalog_connections( + mock_ctx, + operation='batch-delete-connection', + connection_name_list=[], + ) + + @pytest.mark.asyncio + async def test_manage_connections_invalid_operation(self, handler, mock_ctx): + """Test that invalid connection operations are rejected.""" + result = await handler.manage_aws_glue_data_catalog_connections( + mock_ctx, + operation='invalid-op', + ) + assert result.isError is True + assert 'Invalid operation' in result.content[0].text + + # ==================== Connection Types Handler Tests ==================== + + @pytest.mark.asyncio + async def test_manage_connection_types_describe_success( + self, handler, mock_ctx, mock_catalog_manager + ): + """Test describe-connection-type operation.""" + expected_response = MagicMock() + expected_response.isError = False + expected_response.content = [] + mock_catalog_manager.describe_connection_type.return_value = expected_response + + result = await handler.manage_aws_glue_connection_types( + mock_ctx, + operation='describe-connection-type', + connection_type='JDBC', + ) + + mock_catalog_manager.describe_connection_type.assert_called_once_with( + ctx=mock_ctx, + connection_type='JDBC', + ) + assert result == expected_response + + @pytest.mark.asyncio + async def test_manage_connection_types_describe_missing_type(self, handler, mock_ctx): + """Test describe-connection-type requires connection_type.""" + with pytest.raises(ValueError, match='connection_type is required'): + await handler.manage_aws_glue_connection_types( + mock_ctx, + operation='describe-connection-type', + ) + + @pytest.mark.asyncio + async def test_manage_connection_types_list_success( + self, handler, mock_ctx, mock_catalog_manager + ): + """Test list-connection-types operation.""" + expected_response = MagicMock() + expected_response.isError = False + expected_response.content = [] + mock_catalog_manager.list_connection_types.return_value = expected_response + + result = await handler.manage_aws_glue_connection_types( + mock_ctx, + operation='list-connection-types', + max_results=10, + ) + + mock_catalog_manager.list_connection_types.assert_called_once_with( + ctx=mock_ctx, + max_results=10, + next_token=ANY, + ) + assert result == expected_response + + @pytest.mark.asyncio + async def test_manage_connection_types_invalid_operation(self, handler, mock_ctx): + """Test that invalid connection type operations are rejected.""" + result = await handler.manage_aws_glue_connection_types( + mock_ctx, + operation='invalid-op', + ) + assert result.isError is True + assert 'Invalid operation' in result.content[0].text + + @pytest.mark.asyncio + async def test_manage_connection_types_exception_handling( + self, handler, mock_ctx, mock_catalog_manager + ): + """Test that connection types handler catches general exceptions.""" + mock_catalog_manager.list_connection_types.side_effect = Exception('Unexpected error') + + result = await handler.manage_aws_glue_connection_types( + mock_ctx, + operation='list-connection-types', + ) + + assert result.isError is True + assert 'Error in manage_aws_glue_connection_types' in result.content[0].text + + # ==================== Connection Metadata Handler Tests ==================== + + @pytest.mark.asyncio + async def test_manage_connection_metadata_list_entities_success( + self, handler, mock_ctx, mock_catalog_manager + ): + """Test list-entities operation.""" + expected_response = MagicMock() + expected_response.isError = False + expected_response.content = [] + mock_catalog_manager.list_entities.return_value = expected_response + + result = await handler.manage_aws_glue_connection_metadata( + mock_ctx, + operation='list-entities', + connection_name='my-conn', + ) + + mock_catalog_manager.list_entities.assert_called_once() + assert result == expected_response + + @pytest.mark.asyncio + async def test_manage_connection_metadata_list_entities_with_parent( + self, handler, mock_ctx, mock_catalog_manager + ): + """Test list-entities with parent_entity_name.""" + expected_response = MagicMock() + expected_response.isError = False + expected_response.content = [] + mock_catalog_manager.list_entities.return_value = expected_response + + result = await handler.manage_aws_glue_connection_metadata( + mock_ctx, + operation='list-entities', + connection_name='my-conn', + parent_entity_name='my-database', + catalog_id='123456789012', + ) + + mock_catalog_manager.list_entities.assert_called_once() + call_kwargs = mock_catalog_manager.list_entities.call_args[1] + assert call_kwargs['connection_name'] == 'my-conn' + assert call_kwargs['parent_entity_name'] == 'my-database' + assert call_kwargs['catalog_id'] == '123456789012' + + @pytest.mark.asyncio + async def test_manage_connection_metadata_describe_entity_success( + self, handler, mock_ctx, mock_catalog_manager + ): + """Test describe-entity operation.""" + expected_response = MagicMock() + expected_response.isError = False + expected_response.content = [] + mock_catalog_manager.describe_entity.return_value = expected_response + + result = await handler.manage_aws_glue_connection_metadata( + mock_ctx, + operation='describe-entity', + connection_name='my-conn', + entity_name='Account', + ) + + mock_catalog_manager.describe_entity.assert_called_once() + assert result == expected_response + + @pytest.mark.asyncio + async def test_manage_connection_metadata_describe_entity_missing_name( + self, handler, mock_ctx + ): + """Test describe-entity requires entity_name.""" + with pytest.raises(ValueError, match='entity_name is required'): + await handler.manage_aws_glue_connection_metadata( + mock_ctx, + operation='describe-entity', + connection_name='my-conn', + ) + + @pytest.mark.asyncio + async def test_manage_connection_metadata_get_records_no_sensitive_access( + self, handler, mock_ctx + ): + """Test get-entity-records requires sensitive data access flag.""" + result = await handler.manage_aws_glue_connection_metadata( + mock_ctx, + operation='get-entity-records', + connection_name='my-conn', + entity_name='Account', + limit=10, + ) + assert result.isError is True + assert 'allow-sensitive-data-access' in result.content[0].text + + @pytest.mark.asyncio + async def test_manage_connection_metadata_get_records_with_sensitive_access( + self, handler_with_sensitive_data_access, mock_ctx, mock_catalog_manager + ): + """Test get-entity-records works with sensitive data access.""" + expected_response = MagicMock() + expected_response.isError = False + expected_response.content = [] + mock_catalog_manager.get_entity_records.return_value = expected_response + + result = await handler_with_sensitive_data_access.manage_aws_glue_connection_metadata( + mock_ctx, + operation='get-entity-records', + connection_name='my-conn', + entity_name='Account', + limit=10, + ) + + mock_catalog_manager.get_entity_records.assert_called_once() + assert result == expected_response + + @pytest.mark.asyncio + async def test_manage_connection_metadata_get_records_missing_entity_name( + self, handler_with_sensitive_data_access, mock_ctx + ): + """Test get-entity-records requires entity_name.""" + with pytest.raises(ValueError, match='entity_name is required'): + await handler_with_sensitive_data_access.manage_aws_glue_connection_metadata( + mock_ctx, + operation='get-entity-records', + connection_name='my-conn', + limit=10, + ) + + @pytest.mark.asyncio + async def test_manage_connection_metadata_get_records_missing_limit( + self, handler_with_sensitive_data_access, mock_ctx + ): + """Test get-entity-records requires limit.""" + with pytest.raises(ValueError, match='limit is required'): + await handler_with_sensitive_data_access.manage_aws_glue_connection_metadata( + mock_ctx, + operation='get-entity-records', + connection_name='my-conn', + entity_name='Account', + ) + + @pytest.mark.asyncio + async def test_manage_connection_metadata_get_records_with_all_params( + self, handler_with_sensitive_data_access, mock_ctx, mock_catalog_manager + ): + """Test get-entity-records with all optional parameters.""" + expected_response = MagicMock() + expected_response.isError = False + expected_response.content = [] + mock_catalog_manager.get_entity_records.return_value = expected_response + + result = await handler_with_sensitive_data_access.manage_aws_glue_connection_metadata( + mock_ctx, + operation='get-entity-records', + connection_name='my-conn', + entity_name='Account', + limit=5, + catalog_id='123456789012', + connection_options={'key': 'value'}, + filter_predicate="Name = 'Acme'", + selected_fields=['Id', 'Name'], + data_store_api_version='v1', + ) + + call_kwargs = mock_catalog_manager.get_entity_records.call_args[1] + assert call_kwargs['connection_name'] == 'my-conn' + assert call_kwargs['entity_name'] == 'Account' + assert call_kwargs['limit'] == 5 + assert call_kwargs['catalog_id'] == '123456789012' + assert call_kwargs['connection_options'] == {'key': 'value'} + assert call_kwargs['filter_predicate'] == "Name = 'Acme'" + assert call_kwargs['selected_fields'] == ['Id', 'Name'] + assert result == expected_response + + @pytest.mark.asyncio + async def test_manage_connection_metadata_invalid_operation(self, handler, mock_ctx): + """Test that invalid metadata operations are rejected.""" + result = await handler.manage_aws_glue_connection_metadata( + mock_ctx, + operation='invalid-op', + connection_name='my-conn', + ) + assert result.isError is True + assert 'Invalid operation' in result.content[0].text + + @pytest.mark.asyncio + async def test_manage_connection_metadata_exception_handling( + self, handler, mock_ctx, mock_catalog_manager + ): + """Test that metadata handler catches general exceptions.""" + mock_catalog_manager.list_entities.side_effect = Exception('Unexpected error') + + result = await handler.manage_aws_glue_connection_metadata( + mock_ctx, + operation='list-entities', + connection_name='my-conn', + ) + + assert result.isError is True + assert 'Error in manage_aws_glue_connection_metadata' in result.content[0].text