"""FastAPI service exposing document search and RAG response endpoints.

At import time the module configures the Hugging Face cache directory, logs in
to the Hugging Face Hub, loads the sentence encoder and builds the Qdrant
searcher. Two POST endpoints are registered on ``app``:

* ``/api/search-documents`` - run a search for the authenticated user.
* ``/api/generate-rag-response`` - run a search and pass the hits to
  ``generate_rag_response``.
"""
import logging
import os

# huggingface_hub (and the libraries built on top of it) resolve their cache
# paths from HF_HOME when they are first imported, so the variable has to be
# set *before* those imports for the override to take effect.
os.environ["HF_HOME"] = "/tmp/huggingface_cache"

# Ensure the cache directory exists (exist_ok avoids a check-then-create race).
os.makedirs(os.environ["HF_HOME"], exist_ok=True)

from dotenv import load_dotenv  # noqa: E402
from fastapi import Depends, FastAPI, HTTPException  # noqa: E402
from huggingface_hub import login  # noqa: E402
from pydantic import BaseModel  # noqa: E402
from sentence_transformers import SentenceTransformer  # noqa: E402

from services.qdrant_searcher import QdrantSearcher  # noqa: E402
from services.openai_service import generate_rag_response  # noqa: E402
from utils.auth import token_required  # noqa: E402

load_dotenv()  # Load environment variables from .env file

app = FastAPI()

# Setup logging
logging.basicConfig(level=logging.INFO)

# Name of the collection queried by both endpoints.
COLLECTION_NAME = "my_embeddings"

# Load Hugging Face token from environment variable
huggingface_token = os.getenv('HUGGINGFACE_HUB_TOKEN')

if huggingface_token:
    login(token=huggingface_token, add_to_git_credential=True)
else:
    raise ValueError("Hugging Face token is not set. Please set the HUGGINGFACE_HUB_TOKEN environment variable.")

# Initialize the Qdrant searcher
qdrant_url = os.getenv('QDRANT_URL')
access_token = os.getenv('QDRANT_ACCESS_TOKEN')

# NOTE(review): trust_remote_code=True permits executing code shipped with the
# model repository; confirm this model actually needs it before keeping it.
encoder = SentenceTransformer('paraphrase-MiniLM-L6-v2', trust_remote_code=True)  # Replace with your actual encoder
searcher = QdrantSearcher(encoder, qdrant_url, access_token)


# Request body models
class SearchDocumentsRequest(BaseModel):
    """Request body for ``/api/search-documents``."""

    query: str
    limit: int = 3


class GenerateRAGRequest(BaseModel):
    """Request body for ``/api/generate-rag-response``."""

    search_query: str


def _require_identity(credentials: tuple) -> tuple:
    """Unpack ``(customer_id, user_id)`` and reject requests missing either.

    Args:
        credentials: The two-item value produced by the ``token_required``
            dependency.

    Returns:
        The ``(customer_id, user_id)`` pair.

    Raises:
        HTTPException: 401 when either identifier is falsy.
    """
    customer_id, user_id = credentials

    if not customer_id or not user_id:
        logging.error("Failed to extract customer_id or user_id from the JWT token.")
        raise HTTPException(status_code=401, detail="Invalid token: missing customer_id or user_id")

    return customer_id, user_id


@app.post("/api/search-documents")
async def search_documents(
    body: SearchDocumentsRequest,
    credentials: tuple = Depends(token_required),
):
    """Search the collection for ``body.query`` on behalf of the caller.

    Returns:
        The hits returned by ``searcher.search_documents``.

    Raises:
        HTTPException: 401 when the token carries no customer or user id;
            500 when the searcher reports an error or raises.
    """
    _, user_id = _require_identity(credentials)

    logging.info("Received request to search documents")
    # NOTE(review): the searcher is called synchronously inside an async
    # handler; if it blocks on network or model work it will stall the event
    # loop - consider running it in a threadpool.
    try:
        hits, error = searcher.search_documents(COLLECTION_NAME, body.query, user_id, body.limit)

        if error:
            logging.error("Search documents error: %s", error)
            raise HTTPException(status_code=500, detail=error)

        return hits
    except HTTPException:
        # Already a well-formed HTTP error; do not re-wrap it as "unexpected".
        raise
    except Exception as e:
        logging.exception("Unexpected error: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e


@app.post("/api/generate-rag-response")
async def generate_rag_response_api(
    body: GenerateRAGRequest,
    credentials: tuple = Depends(token_required),
):
    """Search for ``body.search_query`` and generate a RAG response from the hits.

    Returns:
        A dict of the form ``{"response": <generated response>}``.

    Raises:
        HTTPException: 401 when the token carries no customer or user id;
            500 when the search or the generation step reports an error or
            raises.
    """
    _, user_id = _require_identity(credentials)

    logging.info("Received request to generate RAG response")
    # NOTE(review): both calls below are synchronous inside an async handler;
    # see the same note on search_documents.
    try:
        hits, error = searcher.search_documents(COLLECTION_NAME, body.search_query, user_id)

        if error:
            logging.error("Search documents error: %s", error)
            raise HTTPException(status_code=500, detail=error)

        response, error = generate_rag_response(hits, body.search_query)

        if error:
            logging.error("Generate RAG response error: %s", error)
            raise HTTPException(status_code=500, detail=error)

        return {"response": response}
    except HTTPException:
        # Already a well-formed HTTP error; do not re-wrap it as "unexpected".
        raise
    except Exception as e:
        logging.exception("Unexpected error: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e


if __name__ == '__main__':
    import uvicorn

    uvicorn.run(app, host='0.0.0.0', port=8000)