from fastapi import FastAPI, Request, Depends, HTTPException, Header, File, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from typing import List, Optional
from helpmate_ai import get_system_msg, retreive_results, rerank_with_cross_encoder, generate_response, intro_message
import google.generativeai as genai
import os
from dotenv import load_dotenv
import re

import speech_recognition as sr
from io import BytesIO


# Load environment variables
load_dotenv()
gemini_api_key = os.getenv("GEMINI_API_KEY")
genai.configure(api_key=gemini_api_key)

# Define a secret API key (use environment variables in production)
API_KEY = os.getenv("API_KEY")

# Initialize FastAPI app
app = FastAPI()

# # Enable CORS
# app.add_middleware(
#     CORSMiddleware,
#     allow_origins=["*"],
#     allow_credentials=True,
#     allow_methods=["*"],
#     allow_headers=["*"],
# )
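# NOTE: if a browser-based frontend on a different origin needs to call this API,
# the CORS middleware above has to be re-enabled; allow_origins=["*"] is a
# development convenience, and a deployed setup would normally list the
# frontend's origin explicitly.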

# Pydantic models for request/response validation
class Message(BaseModel):
    role: str
    content: str

class ChatRequest(BaseModel):
    message: str

class ChatResponse(BaseModel):
    response: str
    conversation: List[Message]

class Report(BaseModel):
    response: str
    message: str
    timestamp: str
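
# Illustrative payloads for the models above (field values are hypothetical):
#   ChatRequest:  {"message": "What does my policy cover?"}
#   ChatResponse: {"response": "...", "conversation": [{"role": "user", "content": "..."}]}
#   Report:       {"response": "...", "message": "...", "timestamp": "2025-01-01T00:00:00Z"}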

# Initialize conversation and model
conversation_bot = []
conversation = get_system_msg()
model = genai.GenerativeModel("gemini-1.5-flash", system_instruction=conversation)

# Initialize speech recognizer
recognizer = sr.Recognizer()

# Dependency to check the API key
async def verify_api_key(x_api_key: str = Header(...)):
    if x_api_key != API_KEY:
        raise HTTPException(status_code=403, detail="Unauthorized")
    
def get_gemini_completions(conversation: str) -> str:
    response = model.generate_content(conversation)
    return response.text

# @app.get("/secure-endpoint", dependencies=[Depends(verify_api_key)])
# async def secure_endpoint():
#     return {"message": "Access granted!"}
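
# Endpoints guarded by verify_api_key expect the key in an "x-api-key" request
# header (FastAPI maps the x_api_key parameter to that header name). Example
# call against a local server, with a placeholder key value:
#   curl -H "x-api-key: <API_KEY>" http://localhost:8000/init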

# Initialize conversation endpoint
@app.get("/init", response_model=ChatResponse, dependencies=[Depends(verify_api_key)])
async def initialize_chat():
    global conversation_bot

    # conversation = "Hi" 
    # introduction = get_gemini_completions(conversation)
    conversation_bot = [Message(role="bot", content=intro_message)]
    return ChatResponse(
        response=intro_message,
        conversation=conversation_bot
    )

# Chat endpoint
@app.post("/chat", response_model=ChatResponse, dependencies=[Depends(verify_api_key)])
async def chat(request: ChatRequest):
    global conversation_bot
    
    # Add user message to conversation
    user_message = Message(role="user", content=request.message)
    conversation_bot.append(user_message)
    
    # Generate response
    results_df = retreive_results(request.message)
    top_docs = rerank_with_cross_encoder(request.message, results_df)
    messages = generate_response(request.message, top_docs)
    response_assistant = get_gemini_completions(messages)
    # formatted_response = format_rag_response(response_assistant)
    
    # Add bot response to conversation
    bot_message = Message(role="bot", content=response_assistant)
    conversation_bot.append(bot_message)
    
    return ChatResponse(
        response=response_assistant,
        conversation=conversation_bot
    )
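
# Example /chat request (illustrative; assumes the local server started by the
# __main__ block below and a placeholder API key):
#   curl -X POST http://localhost:8000/chat \
#        -H "x-api-key: <API_KEY>" -H "Content-Type: application/json" \
#        -d '{"message": "Hello"}'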

# Voice processing endpoint
@app.post("/process-voice")
async def process_voice(audio_file: UploadFile = File(...), dependencies=[Depends(verify_api_key)]):
# async def process_voice(name: str):

    try:
        # Read the audio file
        contents = await audio_file.read()
        audio_data = BytesIO(contents)

    
        # Convert audio to wav format for speech recognition
        with sr.AudioFile(audio_data) as source:
            audio = recognizer.record(source)
       
        # Perform speech recognition
        text = recognizer.recognize_google(audio)
        # print(text)
        
        # Process the text through the chat pipeline
        results_df = retreive_results(text)
        top_docs = rerank_with_cross_encoder(text, results_df)
        messages = generate_response(text, top_docs)
        response_assistant = get_gemini_completions(messages)
        
        return {
            "transcribed_text": text,
            "response": response_assistant
        }
    
    except Exception as e:
        return {"error": f"Error processing voice input: {str(e)}"}

@app.post("/report")
async def handle_feedback(
    request: Report,
    dependencies=[Depends(verify_api_key)]
):
    # if x_api_key != VALID_API_KEY:
    #     raise HTTPException(status_code=403, detail="Invalid API key")
        
    # Here you can store the feedback in your database
    # For example:
    # await db.store_feedback(message, is_positive)
    
    return {"status": "success"}
    
# Reset conversation endpoint
@app.post("/reset", dependencies=[Depends(verify_api_key)])
async def reset_conversation():
    global conversation_bot, conversation
    conversation_bot = []
    # conversation = "Hi"
    # introduction = get_gemini_completions(conversation)
    conversation_bot.append(Message(role="bot", content=intro_message))
    return {"status": "success", "message": "Conversation reset"}

if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)
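
# Alternatively, run without the __main__ block (assuming this file is saved as
# main.py; adjust the module name to match the actual filename):
#   uvicorn main:app --host 0.0.0.0 --port 8000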