File size: 3,452 Bytes
19aaa42
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
import logging
import os
import sys
from contextlib import asynccontextmanager
from typing import Optional, List, Dict

import uvicorn
from fastapi import FastAPI, HTTPException
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from pydantic import BaseModel

# Process-wide logging setup: root handler at INFO level, each record prefixed
# with its timestamp, logger name and level.
logging.basicConfig(
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
# Logger shared by the lifespan hook and the request handlers in this module.
logger = logging.getLogger(__name__)

# Put the repository root on the import path so the ``src.groq_medical_rag``
# package can be imported. The root is the directory two levels above the one
# containing this file (three ``dirname`` steps from its absolute path).
_this_file = os.path.abspath(__file__)
_project_root = os.path.dirname(os.path.dirname(os.path.dirname(_this_file)))
sys.path.append(_project_root)

from src.groq_medical_rag import GroqMedicalRAG, MedicalResponse

# --- Globals ---
# Shared RAG system instance. Assigned by the lifespan hook at startup; left as
# None when initialization fails, which the endpoints report as "offline".
rag_system: Optional[GroqMedicalRAG] = None

# --- Lifespan Management ---
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Create the shared RAG system at startup and log at shutdown.

    FastAPI runs the code before ``yield`` when the application starts and the
    code after it when the application stops. A failure to build the RAG
    system is deliberately non-fatal: it is logged (with traceback) and the
    module-level ``rag_system`` is left as ``None``, so the server still starts
    and the endpoints can report the RAG system as unavailable.

    Args:
        app: The FastAPI application (not used; required by the lifespan
            protocol).
    """
    global rag_system
    logger.info("🚀 Initializing RAG system...")
    try:
        rag_system = GroqMedicalRAG()
    except Exception as e:
        # Broad catch is intentional: any startup error degrades to "offline"
        # instead of crashing the server. logger.exception keeps the traceback,
        # which is essential for diagnosing why initialization failed.
        logger.exception("❌ CRITICAL: Failed to initialize RAG system: %s", e)
        # Degraded mode is worth a warning, not routine info.
        logger.warning("✅ API server is running, but RAG functionality will be disabled.")
        rag_system = None
    else:
        logger.info("✅ RAG system initialized successfully.")
    yield
    logger.info("👋 Shutting down...")

# --- FastAPI App ---
app = FastAPI(title="Clinical Assistant API", lifespan=lifespan)

# --- CORS Middleware ---
# Allowed origins come from the comma-separated CORS_ALLOW_ORIGINS environment
# variable (e.g. "https://app.example.com,https://admin.example.com"). When it
# is unset or empty, every origin is allowed ("*"), exactly as before.
# NOTE(review): a wildcard origin combined with allow_credentials=True is
# risky. Browsers reject a literal "*" on credentialed responses, so the
# middleware may echo back the caller's origin instead, letting any site make
# credentialed requests. Set CORS_ALLOW_ORIGINS explicitly in production, and
# confirm the exact behaviour against the Starlette CORSMiddleware docs.
_cors_origins = [
    origin.strip()
    for origin in os.environ.get("CORS_ALLOW_ORIGINS", "").split(",")
    if origin.strip()
] or ["*"]

app.add_middleware(
    CORSMiddleware,
    allow_origins=_cors_origins,
    allow_credentials=True,
    allow_methods=["*"],  # Allows all methods
    allow_headers=["*"],  # Allows all headers
)

# --- Pydantic Models ---
class QueryRequest(BaseModel):
    """Request body for ``POST /query``."""

    # The user's clinical question.
    query: str
    # Optional earlier conversation turns, each a str -> str mapping; handed to
    # the RAG system's query() unchanged. Presumably {"role", "content"} pairs —
    # NOTE(review): confirm the expected keys against GroqMedicalRAG.query.
    history: Optional[List[Dict[str, str]]] = None

class QueryResponse(BaseModel):
    """Successful response body for ``POST /query``."""

    # The answer text produced by the RAG system.
    response: str

class HealthResponse(BaseModel):
    """Response body for ``GET /health``."""

    # Overall API status string (the health endpoint sets "healthy").
    status: str
    # Availability of the RAG backend: "initialized" or "offline".
    rag_system_status: str
    # API version reported to clients.
    version: str = "1.0.0"

# --- API Endpoints ---
@app.get("/health", response_model=HealthResponse)
async def health_check():
    """Check the health status of the API and its components.

    Returns:
        ``HealthResponse`` whose ``status`` is always ``"healthy"`` (the process
        is serving requests) and whose ``rag_system_status`` is
        ``"initialized"`` when the RAG system was created at startup, or
        ``"offline"`` when initialization failed.
    """
    # Test for the None sentinel explicitly rather than relying on the RAG
    # object's truthiness (PEP 8: use "is None" for singletons).
    rag_status = "offline" if rag_system is None else "initialized"
    return HealthResponse(status="healthy", rag_system_status=rag_status)

@app.post("/query", response_model=QueryResponse)
async def process_query(query_request: QueryRequest):
    """Processes a clinical query and returns an evidence-based answer.

    Args:
        query_request: The query text plus optional prior conversation history.

    Returns:
        ``QueryResponse`` carrying the RAG system's answer on success.
        Otherwise a ``JSONResponse`` with status 503 when the RAG system is
        offline (failed to initialize), or 500 when answering the query raised.
    """
    # NOTE(review): this logs the first 50 characters of user-supplied
    # clinical text; confirm that is acceptable for your data-handling policy.
    logger.info("Received query: %s...", query_request.query[:50])

    if rag_system is None:
        logger.warning("Query received but RAG system is offline")
        return JSONResponse(
            status_code=503,
            content={
                "response": "Sorry, the clinical assistant is currently offline due to a connection issue. Please try again later."
            },
        )

    try:
        # rag_system.query() is synchronous (its return value is used directly),
        # so calling it inline would block the event loop and stall every other
        # request, including /health, for the duration of the call. Run it in
        # the worker threadpool instead.
        # NOTE(review): this lets queries execute concurrently in separate
        # threads — confirm GroqMedicalRAG.query is thread-safe.
        medical_response = await run_in_threadpool(
            rag_system.query,
            query=query_request.query,
            history=query_request.history,
        )
        logger.info("Query processed successfully")
        return QueryResponse(response=medical_response.answer)
    except Exception as e:
        # Top-level request boundary: log with traceback, answer with a 500.
        logger.exception("Error processing query: %s", e)
        return JSONResponse(
            status_code=500,
            content={
                "response": "An error occurred while processing your query. Please try again.",
                # NOTE(review): this exposes raw internal exception text to any
                # client; consider dropping it once no caller depends on it.
                "error": str(e),
            },
        )

# --- Main ---
if __name__ == "__main__":
    # Development entry point. HOST, PORT and RELOAD environment variables
    # override the bind address, port and auto-reload; the defaults
    # (0.0.0.0:8000, reload on) match the previous hard-coded values.
    # NOTE(review): the "main:app" import string assumes this file is
    # importable as module ``main`` from the working directory — confirm.
    uvicorn.run(
        "main:app",
        host=os.environ.get("HOST", "0.0.0.0"),
        port=int(os.environ.get("PORT", "8000")),
        reload=os.environ.get("RELOAD", "true").strip().lower() in {"1", "true", "yes", "on"},
    )