"""FastAPI service exposing the HelpMate chat bot over HTTP.

Endpoints:
    GET  /init           -- start a conversation with the intro message
    POST /chat           -- answer a text message
    POST /process-voice  -- transcribe an uploaded audio file and answer it
    POST /report         -- accept feedback about a response
    POST /reset          -- reset the conversation to the intro message

Every endpoint requires the ``X-API-Key`` header to match the ``API_KEY``
environment variable.
"""
import os
import re
import secrets
import wave
from io import BytesIO
from typing import List, Optional

import google.generativeai as genai
import speech_recognition as sr
from dotenv import load_dotenv
from fastapi import FastAPI, Request, Depends, HTTPException, Header, File, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel

from helpmate_ai import (
    get_system_msg,
    retreive_results,
    rerank_with_cross_encoder,
    generate_response,
    intro_message,
)

# Load environment variables
load_dotenv()
gemini_api_key = os.getenv("GEMINI_API_KEY")
genai.configure(api_key=gemini_api_key)

# Secret API key clients must present; read from the environment.
API_KEY = os.getenv("API_KEY")

# Initialize FastAPI app
app = FastAPI()

# CORS is currently disabled; uncomment to enable it.
# app.add_middleware(
#     CORSMiddleware,
#     allow_origins=["*"],
#     allow_credentials=True,
#     allow_methods=["*"],
#     allow_headers=["*"],
# )


# Pydantic models for request/response validation
class Message(BaseModel):
    """One conversation turn; ``role`` is "user" or "bot" in this module."""

    role: str
    content: str


class ChatRequest(BaseModel):
    """Request body for ``/chat``."""

    message: str


class ChatResponse(BaseModel):
    """Latest bot reply together with the full conversation so far."""

    response: str
    conversation: List[Message]


class Report(BaseModel):
    """Request body for ``/report``."""

    response: str
    message: str
    timestamp: str


# NOTE: the conversation history is a single module-level list, so it is
# shared by every client that talks to this process.
conversation_bot = []
conversation = get_system_msg()
model = genai.GenerativeModel("gemini-1.5-flash", system_instruction=conversation)

# Initialize speech recognizer
recognizer = sr.Recognizer()


async def verify_api_key(x_api_key: str = Header(...)):
    """Dependency that rejects requests whose ``X-API-Key`` header is wrong.

    Raises:
        HTTPException: 403 when the header does not match ``API_KEY`` or
            when no ``API_KEY`` is configured.
    """
    # Compare as bytes in constant time: avoids leaking key contents through
    # timing and lets compare_digest accept non-ASCII header values.
    if API_KEY is None or not secrets.compare_digest(
        x_api_key.encode("utf-8"), API_KEY.encode("utf-8")
    ):
        raise HTTPException(status_code=403, detail="Unauthorized")


def get_gemini_completions(conversation: str) -> str:
    """Send *conversation* to the Gemini model and return the reply text."""
    response = model.generate_content(conversation)
    return response.text


def _answer_query(query: str) -> str:
    """Run *query* through retrieval, reranking, prompt building and Gemini."""
    results_df = retreive_results(query)
    top_docs = rerank_with_cross_encoder(query, results_df)
    messages = generate_response(query, top_docs)
    return get_gemini_completions(messages)


# Initialize conversation endpoint
@app.get("/init", response_model=ChatResponse, dependencies=[Depends(verify_api_key)])
async def initialize_chat():
    """Start a fresh conversation containing only the intro message."""
    global conversation_bot
    conversation_bot = [Message(role="bot", content=intro_message)]
    return ChatResponse(
        response=intro_message,
        conversation=conversation_bot,
    )


# Chat endpoint
@app.post("/chat", response_model=ChatResponse, dependencies=[Depends(verify_api_key)])
async def chat(request: ChatRequest):
    """Answer a user message and append both turns to the conversation."""
    global conversation_bot

    # Add user message to conversation
    conversation_bot.append(Message(role="user", content=request.message))

    # Generate response
    response_assistant = _answer_query(request.message)

    # Add bot response to conversation
    conversation_bot.append(Message(role="bot", content=response_assistant))

    return ChatResponse(
        response=response_assistant,
        conversation=conversation_bot,
    )


# Voice processing endpoint
# The API-key check is declared on the decorator so FastAPI runs it before
# the handler is invoked.
@app.post("/process-voice", dependencies=[Depends(verify_api_key)])
async def process_voice(audio_file: UploadFile = File(...)):
    """Transcribe an uploaded audio file and answer the transcribed text.

    Returns:
        ``{"transcribed_text": ..., "response": ...}`` on success, or
        ``{"error": ...}`` when any step fails. The conversation history is
        not updated by this endpoint.
    """
    try:
        # Read the upload into memory; no format conversion is performed, so
        # the file must already be something sr.AudioFile can open.
        contents = await audio_file.read()
        with sr.AudioFile(BytesIO(contents)) as source:
            audio = recognizer.record(source)

        # Perform speech recognition
        text = recognizer.recognize_google(audio)

        # Process the text through the chat pipeline
        response_assistant = _answer_query(text)
    except Exception as e:
        # Deliberate best-effort boundary: callers receive an error payload
        # rather than a failed request.
        return {"error": f"Error processing voice input: {str(e)}"}

    return {
        "transcribed_text": text,
        "response": response_assistant,
    }


@app.post("/report", dependencies=[Depends(verify_api_key)])
async def handle_feedback(request: Report):
    """Accept a feedback report about a bot response."""
    # TODO: persist the feedback (e.g. in a database); for now the report is
    # validated and acknowledged without being stored.
    return {"status": "success"}


# Reset conversation endpoint
@app.post("/reset", dependencies=[Depends(verify_api_key)])
async def reset_conversation():
    """Discard the conversation history and restart it with the intro message."""
    global conversation_bot
    conversation_bot = [Message(role="bot", content=intro_message)]
    return {"status": "success", "message": "Conversation reset"}


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)