T0X1N's picture
refactor: resolve pylint errors and warnings across api and src routers
7caf4dc
"""
MediGuard AI — Ask Router
Free-form medical Q&A powered by the agentic RAG pipeline.
Supports both synchronous and SSE streaming responses.
"""
from __future__ import annotations
import asyncio
import json
import logging
import time
import uuid
from collections.abc import AsyncGenerator
from fastapi import APIRouter, HTTPException, Request
from fastapi.responses import StreamingResponse
from src.schemas.schemas import AskRequest, AskResponse, FeedbackRequest, FeedbackResponse
logger = logging.getLogger(__name__)
router = APIRouter(tags=["ask"])
@router.post("/ask", response_model=AskResponse)
async def ask_medical_question(body: AskRequest, request: Request):
"""Answer a free-form medical question via agentic RAG."""
rag_service = getattr(request.app.state, "rag_service", None)
if rag_service is None:
raise HTTPException(status_code=503, detail="RAG service unavailable")
request_id = f"req_{uuid.uuid4().hex[:12]}"
t0 = time.time()
try:
result = rag_service.ask(
query=body.question,
biomarkers=body.biomarkers,
patient_context=body.patient_context or "",
)
except Exception as exc:
logger.exception("Agentic RAG failed: %s", exc)
raise HTTPException(status_code=500, detail=f"RAG pipeline error: {exc}") from exc
elapsed = (time.time() - t0) * 1000
return AskResponse(
status="success",
request_id=request_id,
question=body.question,
answer=result.get("final_answer", ""),
guardrail_score=result.get("guardrail_score"),
documents_retrieved=len(result.get("retrieved_documents", [])),
documents_relevant=len(result.get("relevant_documents", [])),
processing_time_ms=round(elapsed, 1),
)
# ---------------------------------------------------------------------------
# SSE Streaming Endpoint
# ---------------------------------------------------------------------------
async def _stream_rag_response(
rag_service,
question: str,
biomarkers: dict | None,
patient_context: str,
request_id: str,
) -> AsyncGenerator[str, None]:
"""
Generate Server-Sent Events for streaming RAG responses.
Event types:
- status: Pipeline stage updates
- token: Individual response tokens
- metadata: Retrieval/grading info
- done: Final completion signal
- error: Error information
"""
t0 = time.time()
try:
# Send initial status
yield f"event: status\ndata: {json.dumps({'stage': 'guardrail', 'message': 'Validating query...'})}\n\n"
await asyncio.sleep(0) # Allow event loop to flush
# Run the RAG pipeline (synchronous, but we yield progress)
loop = asyncio.get_running_loop()
result = await loop.run_in_executor(
None,
lambda: rag_service.ask(
query=question,
biomarkers=biomarkers,
patient_context=patient_context,
),
)
# Send retrieval metadata
yield f"event: metadata\ndata: {json.dumps({'documents_retrieved': len(result.get('retrieved_documents', [])), 'documents_relevant': len(result.get('relevant_documents', [])), 'guardrail_score': result.get('guardrail_score')})}\n\n"
await asyncio.sleep(0)
# Stream the answer token by token for smooth UI
answer = result.get("final_answer", "")
if answer:
yield f"event: status\ndata: {json.dumps({'stage': 'generating', 'message': 'Generating response...'})}\n\n"
# Simulate streaming by chunking the response
words = answer.split()
chunk_size = 3 # Send 3 words at a time
for i in range(0, len(words), chunk_size):
chunk = " ".join(words[i : i + chunk_size])
if i + chunk_size < len(words):
chunk += " "
yield f"event: token\ndata: {json.dumps({'text': chunk})}\n\n"
await asyncio.sleep(0.02) # Small delay for visual streaming effect
# Send completion
elapsed = (time.time() - t0) * 1000
yield f"event: done\ndata: {json.dumps({'request_id': request_id, 'processing_time_ms': round(elapsed, 1), 'status': 'success'})}\n\n"
except Exception as exc:
logger.exception("Streaming RAG failed: %s", exc)
yield f"event: error\ndata: {json.dumps({'error': str(exc), 'request_id': request_id})}\n\n"
@router.post("/ask/stream")
async def ask_medical_question_stream(body: AskRequest, request: Request):
"""
Stream a medical Q&A response via Server-Sent Events (SSE).
Events:
- `status`: Pipeline stage updates (guardrail, retrieve, grade, generate)
- `token`: Individual response tokens for real-time display
- `metadata`: Retrieval statistics (documents found, relevance scores)
- `done`: Completion signal with timing info
- `error`: Error details if something fails
Example client code (JavaScript):
```javascript
const eventSource = new EventSource('/ask/stream', {
method: 'POST',
body: JSON.stringify({ question: 'What causes high glucose?' })
});
eventSource.addEventListener('token', (e) => {
const data = JSON.parse(e.data);
document.getElementById('response').innerHTML += data.text;
});
```
"""
rag_service = getattr(request.app.state, "rag_service", None)
if rag_service is None:
raise HTTPException(status_code=503, detail="RAG service unavailable")
request_id = f"req_{uuid.uuid4().hex[:12]}"
return StreamingResponse(
_stream_rag_response(
rag_service,
body.question,
body.biomarkers,
body.patient_context or "",
request_id,
),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"X-Request-ID": request_id,
},
)
@router.post("/feedback", response_model=FeedbackResponse)
async def submit_feedback(body: FeedbackRequest, request: Request):
"""Submit user feedback for an analysis or RAG response."""
tracer = getattr(request.app.state, "tracer", None)
if tracer:
tracer.score(trace_id=body.request_id, name="user-feedback", value=body.score, comment=body.comment)
return FeedbackResponse(request_id=body.request_id)