| import asyncio |
| import logging |
| from typing import Dict, Any |
|
|
| from fastapi import HTTPException, UploadFile, status, Depends |
| from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials |
| from config import Config |
|
|
| from .rag_pipeline import route_and_process_query, add_document_to_rag, check_system_health |
| from .document_handler import extract_text_from_file |
|
|
| |
# Module-wide logging setup; every handler below logs through ``logger``.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Bearer-token scheme that ``verify_token`` receives through ``Depends``.
security = HTTPBearer()

# Request limits and accepted upload content types, all sourced from Config.
# MAX_FILE_SIZE is compared against the byte length of an uploaded file and
# MAX_QUERY_LENGTH against the character length of a query string.
MAX_QUERY_LENGTH = Config.RAG_MAX_QUERY_LENGTH
MAX_FILE_SIZE = Config.RAG_MAX_FILE_SIZE
SUPPORTED_CONTENT_TYPES = Config.RAG_SUPPORTED_CONTENT_TYPES
|
|
async def verify_token(credentials: HTTPAuthorizationCredentials = Depends(security)):
    """Validate the Bearer token supplied in the Authorization header.

    Args:
        credentials: Parsed Authorization header, injected by the
            ``security`` (``HTTPBearer``) dependency.

    Returns:
        The presented token string when it matches ``Config.SECRET_TOKEN``.

    Raises:
        HTTPException: 500 when no expected token is configured, 403 when
            the presented token does not match the expected one.
    """
    # Function-scope import so this block does not depend on a new
    # top-of-file import.
    import secrets

    token = credentials.credentials
    expected_token = Config.SECRET_TOKEN

    if not expected_token:
        logger.error("MY_SECRET_TOKEN not configured")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Server configuration error",
        )

    # Compare in constant time so response latency does not reveal how many
    # leading characters of a guess were correct. Both sides are encoded to
    # bytes because compare_digest() rejects non-ASCII ``str`` arguments.
    # NOTE(review): assumes Config.SECRET_TOKEN is a str - confirm in Config.
    token_matches = secrets.compare_digest(
        token.encode("utf-8"),
        expected_token.encode("utf-8"),
    )

    if not token_matches:
        # Deliberately log nothing from the presented token: a near-miss of
        # the real secret would otherwise end up in the log files.
        logger.warning("Invalid token attempt")
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Invalid or expired token",
        )
    return token
|
|
async def handle_rag_query(query: str) -> Dict[str, Any]:
    """Validate a query, run it through the RAG pipeline and return the result.

    Args:
        query: Raw query text supplied by the caller.

    Returns:
        Whatever ``route_and_process_query`` returns for the query.

    Raises:
        HTTPException: 400 when the query is empty/blank or longer than
            ``MAX_QUERY_LENGTH`` characters; 500 when processing fails with
            an unexpected error. An ``HTTPException`` raised while
            processing is propagated unchanged.
    """
    if not query or not query.strip():
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Query cannot be empty",
        )

    if len(query) > MAX_QUERY_LENGTH:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"Query too long. Please limit to {MAX_QUERY_LENGTH} characters.",
        )

    try:
        # Only a short prefix is logged to keep log lines bounded.
        logger.info("Processing query: %s...", query[:50])

        # The pipeline call is synchronous; run it in a worker thread so the
        # event loop is not blocked while it executes.
        response = await asyncio.to_thread(route_and_process_query, query)

        logger.info(
            "Query processed successfully. Route: %s",
            response.get('route', 'Unknown'),
        )
        return response

    except HTTPException:
        # Keep deliberate HTTP errors intact instead of masking them as a
        # generic 500, matching the other handlers in this module.
        raise
    except Exception as e:
        # logger.exception records the traceback alongside the message.
        logger.exception("Error processing query: %s", e)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Error processing your query. Please try again.",
        ) from e
|
|
# Minimum number of non-whitespace-trimmed characters an extracted document
# must contain before it is added to the knowledge base.
_MIN_EXTRACTED_TEXT_LENGTH = 50


async def handle_document_upload(file: UploadFile) -> Dict[str, Any]:
    """Validate an uploaded file, extract its text and add it to the RAG store.

    Args:
        file: The uploaded file object received from the request.

    Returns:
        A dict with the keys ``message``, ``filename``, ``text_length``
        (length of the extracted text) and ``content_type``.

    Raises:
        HTTPException: 400 when no filename is present, when the extracted
            text is empty, or when it is too short; 415 when the content
            type is not in ``SUPPORTED_CONTENT_TYPES``; 413 when the file
            exceeds ``MAX_FILE_SIZE`` bytes; 500 when the document cannot be
            added or an unexpected error occurs.
    """
    if not file.filename:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="No file provided",
        )

    if file.content_type not in SUPPORTED_CONTENT_TYPES:
        raise HTTPException(
            status_code=status.HTTP_415_UNSUPPORTED_MEDIA_TYPE,
            detail=f"Unsupported file type: {file.content_type}. "
                   f"Supported types: {', '.join(SUPPORTED_CONTENT_TYPES)}",
        )

    # Read at most one byte beyond the limit: enough to detect an oversized
    # upload without pulling an arbitrarily large file into memory.
    contents = await file.read(int(MAX_FILE_SIZE) + 1)
    file_size = len(contents)
    if file_size > MAX_FILE_SIZE:
        raise HTTPException(
            status_code=status.HTTP_413_REQUEST_ENTITY_TOO_LARGE,
            detail=f"File too large. Maximum size: {MAX_FILE_SIZE / (1024*1024):.1f}MB",
        )
    # Only the size is needed from here on; drop the buffered bytes.
    del contents

    # Rewind so the extractor reads the file from the beginning.
    await file.seek(0)

    try:
        logger.info("Processing file upload: %s", file.filename)

        text = await extract_text_from_file(file)

        if not text or not text.strip():
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="The file appears to be empty or could not be read.",
            )

        # Measure the stripped text so whitespace padding cannot satisfy the
        # minimum-length requirement.
        if len(text.strip()) < _MIN_EXTRACTED_TEXT_LENGTH:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="The extracted text is too short to be meaningful.",
            )

        # The indexing call is synchronous; run it in a worker thread so the
        # event loop is not blocked while it executes.
        success = await asyncio.to_thread(
            add_document_to_rag,
            text,
            {
                "source": file.filename,
                "content_type": file.content_type,
                "size": file_size,
            },
        )

        if not success:
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail="Failed to add document to the knowledge base",
            )

        logger.info("Successfully processed file: %s", file.filename)

        return {
            "message": f"Successfully uploaded and processed '{file.filename}'. "
                       f"It is now available for querying.",
            "filename": file.filename,
            "text_length": len(text),
            "content_type": file.content_type,
        }

    except HTTPException:
        # Validation and storage errors raised above already carry the right
        # status code; pass them through untouched.
        raise
    except Exception as e:
        # logger.exception records the traceback alongside the message.
        logger.exception("Error processing file %s: %s", file.filename, e)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Error processing the file. Please try again.",
        ) from e
|
|
async def handle_health_check() -> Dict[str, Any]:
    """Run the system health check and return its report.

    Returns:
        The value returned by ``check_system_health`` when its ``"status"``
        entry is anything other than ``"unhealthy"``.

    Raises:
        HTTPException: 503 when the report's status is ``"unhealthy"`` or
            when the check itself fails with an unexpected error. An
            ``HTTPException`` raised by the check is propagated unchanged.
    """
    try:
        # The check is synchronous, so it is executed in a worker thread.
        report = await asyncio.to_thread(check_system_health)
        # Looked up inside the try so a missing "status" key is reported as
        # a failed health check.
        is_unhealthy = report["status"] == "unhealthy"
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Health check failed: {e}")
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="Health check failed"
        )

    if is_unhealthy:
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="Service is currently unhealthy"
        )

    return report