-
Notifications
You must be signed in to change notification settings - Fork 80
implements referenced documents on /query and updates /streaming_query to match #403
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We'll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 6 commits
f022501
8c0a1d0
d6e1475
866bb5c
3026397
d5e2622
a7f815e
af646c6
06cba91
1f8fbb9
1256193
1fa7b87
dcf4f23
12becc5
0291e11
a463e6f
36a5604
622ae9d
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -5,7 +5,9 @@ | |
| import logging | ||
| import os | ||
| from pathlib import Path | ||
| from typing import Annotated, Any | ||
| from typing import Annotated, Any, cast | ||
|
|
||
| import pydantic | ||
|
|
||
| from llama_stack_client import APIConnectionError | ||
| from llama_stack_client import AsyncLlamaStackClient # type: ignore | ||
|
|
@@ -25,7 +27,12 @@ | |
| from app.database import get_session | ||
| import metrics | ||
| from models.database.conversations import UserConversation | ||
| from models.responses import QueryResponse, UnauthorizedResponse, ForbiddenResponse | ||
| from models.responses import ( | ||
| QueryResponse, | ||
| UnauthorizedResponse, | ||
| ForbiddenResponse, | ||
| ReferencedDocument, | ||
| ) | ||
| from models.requests import QueryRequest, Attachment | ||
| import constants | ||
| from utils.endpoints import ( | ||
|
|
@@ -36,15 +43,104 @@ | |
| ) | ||
| from utils.mcp_headers import mcp_headers_dependency, handle_mcp_headers_with_toolgroups | ||
| from utils.suid import get_suid | ||
| from utils.metadata import parse_knowledge_search_metadata | ||
|
|
||
| logger = logging.getLogger("app.endpoints.handlers") | ||
| router = APIRouter(tags=["query"]) | ||
| auth_dependency = get_auth_dependency() | ||
|
|
||
|
|
||
def _process_knowledge_search_content(
    tool_response: Any, metadata_map: dict[str, dict[str, Any]]
) -> None:
    """Collect document metadata from a knowledge-search tool response.

    Reads the ``text`` of every content item on *tool_response*, parses it
    with ``parse_knowledge_search_metadata``, and merges the result into
    *metadata_map* (mutated in place). Items without text, non-iterable
    content, and unparseable metadata are skipped; parse failures are logged.
    """
    # Guard clauses: nothing to do without a response carrying content.
    if not tool_response:
        return
    items = getattr(tool_response, "content", None)
    if not items:
        return

    # Bail out quietly when content is not iterable at all.
    try:
        item_iter = iter(items)
    except TypeError:
        return

    for item in item_iter:
        text = getattr(item, "text", None)
        # Only items with a non-empty "text" attribute carry metadata.
        if not text:
            continue
        try:
            metadata_map.update(parse_knowledge_search_metadata(text))
        except ValueError:
            # Truncate long payloads so the log stays readable.
            snippet = text if len(text) <= 200 else text[:200] + "..."
            logger.exception(
                "An exception was thrown in processing metadata from text: %s",
                snippet,
            )
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. π οΈ Refactor suggestion Use non-strict parsing and avoid logging raw tool text (reduce PII risk) Non-strict parsing lets you salvage valid blocks when one is malformed. Also, avoid logging the raw text payload (even truncated) to reduce PII/log-noise; the exception already carries position info. Apply this diff: - try:
- parsed_metadata = parse_knowledge_search_metadata(text)
- metadata_map.update(parsed_metadata)
- except ValueError:
- logger.exception(
- "An exception was thrown in processing metadata from text: %s",
- text[:200] + "..." if len(text) > 200 else text,
- )
+ parsed_metadata = parse_knowledge_search_metadata(text, strict=False)
+ if parsed_metadata:
+ metadata_map.update(parsed_metadata)
π€ Prompt for AI Agents
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. resolved by 12becc5
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. π§© Analysis chain@thoraxe Perfect! I can see from commit 12becc5 that you've implemented the suggestions correctly: β
Non-strict parsing: Now using The implementation looks excellent - it maintains robustness while protecting sensitive information from logs. Great work resolving this! π Script executed: Length of output: 22076 @thoraxe Perfect! I can see from commit 12becc5 that you've implemented the suggestions excellently: β
Non-strict parsing: Now using The implementation looks excellent - it maintains robustness while protecting sensitive information from logs. The code organization improvements are a nice bonus too! 𧬠Code Graph Analysis Results
If you found this review helpful, would you consider giving us a shout-out on X? Thank you for using CodeRabbit! |
||
|
|
||
|
|
||
def extract_referenced_documents_from_steps(steps: list) -> list[ReferencedDocument]:
    """Extract referenced documents from tool execution steps.

    Args:
        steps: List of response steps from the agent.

    Returns:
        List of referenced documents with doc_url and doc_title.
    """
    metadata_map: dict[str, dict[str, Any]] = {}

    for step in steps:
        # Only tool-execution steps carry tool responses worth inspecting.
        if getattr(step, "step_type", "") != "tool_execution" or not hasattr(
            step, "tool_responses"
        ):
            continue

        for tool_response in getattr(step, "tool_responses", []) or []:
            # Only the knowledge_search tool emits document metadata.
            if getattr(
                tool_response, "tool_name", ""
            ) != "knowledge_search" or not getattr(tool_response, "content", []):
                continue

            _process_knowledge_search_content(tool_response, metadata_map)

    # Build ReferencedDocument models, skipping any entry that fails
    # validation so one malformed document cannot discard the whole list.
    referenced_documents = []
    for meta in metadata_map.values():
        if "docs_url" not in meta or "title" not in meta:
            continue
        try:
            referenced_documents.append(
                ReferencedDocument(doc_url=meta["docs_url"], doc_title=meta["title"])
            )
        # NOTE: the original tuple (pydantic.ValidationError, ValueError,
        # Exception) was redundant — Exception already subsumes the others
        # (flake8-bugbear B014). The broad catch itself is deliberate.
        except Exception as e:  # pylint: disable=broad-exception-caught
            logger.warning(
                "Skipping invalid referenced document with docs_url='%s', title='%s': %s",
                meta.get("docs_url", "<missing>"),
                meta.get("title", "<missing>"),
                str(e),
            )

    return referenced_documents
|
coderabbitai[bot] marked this conversation as resolved.
Outdated
|
||
|
|
||
|
|
||
| query_response: dict[int | str, dict[str, Any]] = { | ||
| 200: { | ||
| "conversation_id": "123e4567-e89b-12d3-a456-426614174000", | ||
| "response": "LLM answer", | ||
| "referenced_documents": [ | ||
| { | ||
| "doc_url": ( | ||
| "https://docs.openshift.com/container-platform/" | ||
| "4.15/operators/olm/index.html" | ||
| ), | ||
| "doc_title": "Operator Lifecycle Manager (OLM)", | ||
| } | ||
| ], | ||
| }, | ||
| 400: { | ||
| "description": "Missing or invalid credentials provided by client", | ||
|
|
@@ -54,7 +150,7 @@ | |
| "description": "User is not authorized", | ||
| "model": ForbiddenResponse, | ||
| }, | ||
| 503: { | ||
| 500: { | ||
| "detail": { | ||
| "response": "Unable to connect to Llama Stack", | ||
| "cause": "Connection error.", | ||
|
|
@@ -189,7 +285,7 @@ async def query_endpoint_handler( | |
| user_conversation=user_conversation, query_request=query_request | ||
| ), | ||
| ) | ||
| response, conversation_id = await retrieve_response( | ||
| response, conversation_id, referenced_documents = await retrieve_response( | ||
| client, | ||
| llama_stack_model_id, | ||
| query_request, | ||
|
|
@@ -223,7 +319,11 @@ async def query_endpoint_handler( | |
| provider_id=provider_id, | ||
| ) | ||
|
|
||
| return QueryResponse(conversation_id=conversation_id, response=response) | ||
| return QueryResponse( | ||
| conversation_id=conversation_id, | ||
| response=response, | ||
| referenced_documents=referenced_documents, | ||
| ) | ||
|
|
||
| # connection to Llama Stack server | ||
| except APIConnectionError as e: | ||
|
|
@@ -316,13 +416,13 @@ def is_input_shield(shield: Shield) -> bool: | |
| return _is_inout_shield(shield) or not is_output_shield(shield) | ||
|
|
||
|
|
||
| async def retrieve_response( # pylint: disable=too-many-locals | ||
| async def retrieve_response( # pylint: disable=too-many-locals,too-many-branches | ||
| client: AsyncLlamaStackClient, | ||
| model_id: str, | ||
| query_request: QueryRequest, | ||
| token: str, | ||
| mcp_headers: dict[str, dict[str, str]] | None = None, | ||
| ) -> tuple[str, str]: | ||
| ) -> tuple[str, str, list[ReferencedDocument]]: | ||
| """Retrieve response from LLMs and agents.""" | ||
| available_input_shields = [ | ||
| shield.identifier | ||
|
|
@@ -402,15 +502,33 @@ async def retrieve_response( # pylint: disable=too-many-locals | |
| toolgroups=toolgroups, | ||
| ) | ||
|
|
||
| # Check for validation errors in the response | ||
| # Check for validation errors and extract referenced documents | ||
| steps = getattr(response, "steps", []) | ||
| for step in steps: | ||
| if step.step_type == "shield_call" and step.violation: | ||
| if getattr(step, "step_type", "") == "shield_call" and getattr( | ||
| step, "violation", False | ||
| ): | ||
| # Metric for LLM validation errors | ||
| metrics.llm_calls_validation_errors_total.inc() | ||
| break | ||
|
|
||
| return str(response.output_message.content), conversation_id # type: ignore[union-attr] | ||
| # Extract referenced documents from tool execution steps | ||
| referenced_documents = extract_referenced_documents_from_steps(steps) | ||
|
|
||
| # When stream=False, response should have output_message attribute | ||
| response_obj = cast(Any, response) | ||
|
|
||
| # Safely guard access to output_message and content | ||
| output_message = getattr(response_obj, "output_message", None) | ||
| if output_message and getattr(output_message, "content", None) is not None: | ||
| content_str = str(output_message.content) | ||
| else: | ||
| content_str = "" | ||
|
|
||
| return ( | ||
| content_str, | ||
| conversation_id, | ||
| referenced_documents, | ||
| ) | ||
|
coderabbitai[bot] marked this conversation as resolved.
|
||
|
|
||
|
|
||
| def validate_attachments_metadata(attachments: list[Attachment]) -> None: | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
metadata_mapseems to be the return value, not a real parameter. Please refactor to return newmetadata_mapThere was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
fixed with 06cba91