import spaces
import sys
from uuid import UUID
from fastapi import FastAPI, HTTPException
from transformers import pipeline
import torch
import gradio as gr
from Supabase import initSupabase, updateSupabaseChatHistory, updateSupabaseChatStatus
from supabase import Client
from config import MODEL_CONFIG
from typing import Dict, Any, List
from api_schemas import API_RESPONSES
from VectorDB import initRAG, search_docs
from pydantic import BaseModel
import uvicorn

# Use the spaces.GPU decorator for GPU operations
@spaces.GPU
def initialize_gpu():
    zero = torch.Tensor([0]).cuda()
    print(f"GPU initialized: {zero.device}")
    return zero.device

# Device selection with proper error handling
try:
    # For Hugging Face Spaces, let spaces handle GPU initialization
    if "spaces" in sys.modules:
        print("Running in Hugging Face Spaces, using spaces.GPU for device management")
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    else:
        # Local development path
        if torch.backends.mps.is_available():
            device = torch.device("mps")
        elif torch.cuda.is_available():
            device = torch.device("cuda")
        else:
            device = torch.device("cpu")
    print(f"Using device: {device}")
except Exception as e:
    print(f"Error detecting device, falling back to CPU: {str(e)}")
    device = torch.device("cpu")

# Initialize GPU within application context to avoid blocking
if device.type == "cuda":
    # Only initialize GPU if we've selected CUDA
    initialize_gpu()

# Continue with the rest of the initialization
initRAG(device)
supabase: Client = initSupabase()
# print(search_docs("how much employment in manchester"))

# Initialize the LLM
try:
    pipe = pipeline(
        "text-generation",
        model=MODEL_CONFIG["model_name"],
        max_new_tokens=MODEL_CONFIG["max_new_tokens"],
        temperature=0.3,
        do_sample=True,  # Allow sampling to generate diverse responses; more conversational and human-like
        top_k=50,  # Limit the top-k tokens to sample from
        top_p=0.95,  # Limit the cumulative probability distribution for sampling
        trust_remote_code=True,
    )
except Exception as e:
    print(f"Error loading model: {str(e)}")
    raise RuntimeError("Failed to initialize the model")
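
# For reference: with chat-style input (a list of {"role", "content"} dicts), the
# text-generation pipeline returns the full conversation with the new assistant turn
# appended, roughly:
#   [{"generated_text": [{"role": "system", ...}, ..., {"role": "assistant", "content": "..."}]}]
# The handlers below rely on this shape when slicing off the system prompt.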

# Define the system prompt that sets the behavior and role of the LLM
SYSTEM_PROMPT = """Your name is SophiaAI.
You are a friendly and empathetic assistant designed to empower refugee women and help with any questions.
You should always be friendly. On your first message introduce yourself with your name. Use emoji in all of your responses to be relatable. You may consider 😊🙏🤗🤝
Make sure to always reply to the user in the language they are speaking.
Ensure you have all the information you need before answering a question. Don't make anything up or guess.
Once you have answered a question, you should check if the user would like more detail on a specific area.
Please provide structured lists with markdown-style links:
- **Charity Name**: Description [Visit their website](https://example.com)
Provide spacing between each item in the list or paragraph.
"""

# Serve the API docs as our landing page
app = FastAPI(
    docs_url="/",
    title="SophiaAi - 21312701",
    version="1",
    description=(
        "SophiaAi is a chatbot created for a university final project.\n"
        "Designed to empower refugee women, it connects a RAG pipeline containing resources "
        "that support refugees to a fine-tuned LLM."
    ),
)
print("App Startup Complete!")


class ChatRequest(BaseModel):
    conversationHistory: List[Dict[str, str]]
    chatID: UUID

    model_config = {
        "json_schema_extra": {
            "example": {
                "conversationHistory": [
                    {"role": "user", "content": "hi"},
                    {"role": "assistant", "content": "Hello! How can I assist you today?"},
                    {"role": "user", "content": "whats the weather in MCR"},
                ],
                "chatID": "123e4567-e89b-12d3-a456-426614174000",
            }
        }
    }


# Route path assumed here; adjust to match the original API schema
@app.post("/generateFromChatHistory", responses=API_RESPONSES)
async def generateFromChatHistory(input: ChatRequest):
    # Use the GPU-enabled function inside this async function
    return _gpu_generate_response(input)


# Create a GPU-decorated synchronous function
@spaces.GPU
def _gpu_generate_response(input: ChatRequest):
    """GPU-enabled function to handle the model generation"""
    # Input validation
    if not input.conversationHistory:
        raise HTTPException(status_code=400, detail="Conversation history cannot be empty")
    if len(input.conversationHistory) > MODEL_CONFIG["max_conversation_history_size"]:  # Arbitrary limit to avoid overloading the LLM, adjust as needed
        raise HTTPException(status_code=400, detail="Conversation history too long")
    try:
        # Map the conversation history, prefixed with the system prompt
        content = [
            {
                "role": "system",
                "content": SYSTEM_PROMPT,
            }
        ]
        content.extend(
            {"role": message["role"], "content": message["content"]}
            for message in input.conversationHistory
        )
        updateSupabaseChatHistory(content[1:], input.chatID, supabase, True)  # Update Supabase

        # Retrieve RAG results for the user's most recent question
        LastQuestion = input.conversationHistory[-1]["content"]
        RAG_Results = search_docs(LastQuestion, 3)
        RagPrompt = f"""_RAG_
Use the following information to assist in answering the user's most recent question. Do not make anything up or guess.
Relevant information retrieved: {RAG_Results}
If you don't know, simply let the user know, or ask for more detail. The user has not seen this message, it is for your reference only."""

        # Append the RAG results as an additional user message
        rag_message = {
            "role": "user",
            "content": RagPrompt
        }
        content.append(rag_message)
        # print(content)

        # Generate the response
        output = pipe(content, num_return_sequences=1, max_new_tokens=MODEL_CONFIG["max_new_tokens"])
        generated_text = output[0]["generated_text"]  # The full conversation, including the newly generated message
        generated_text.pop(0)  # Remove the system prompt from the generated text
        updateSupabaseChatHistory(generated_text, input.chatID, supabase)  # Update Supabase
        return {
            "status": "success",
            "generated_text": generated_text  # generated_text[-1] would return only the newly generated response
        }
    except Exception as e:
        updateSupabaseChatStatus(False, input.chatID, supabase)  # Notify the database that the chat is no longer being processed
        raise HTTPException(
            status_code=500, detail=f"Error generating response: {str(e)}"
        ) from e
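
# Minimal usage sketch (assumes the chatID already exists in Supabase):
#   req = ChatRequest(
#       conversationHistory=[{"role": "user", "content": "hi"}],
#       chatID=UUID("123e4567-e89b-12d3-a456-426614174000"),
#   )
#   result = _gpu_generate_response(req)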


# Route path assumed here; adjust to match the original API schema
@app.get("/search_rag", responses=API_RESPONSES)
async def search_rag(query: str, limit: int = 3):
""" | |
Search the RAG system directly with a query | |
Args: | |
query (str): The search query | |
limit (int): Maximum number of results to return (default: 3 | |
Returns: | |
Dict: Search results with relevant document | |
Raises: | |
HTTPException: If the query is invalid or search fails | |
""" | |
    # Input validation
    if not query or not query.strip():
        raise HTTPException(status_code=400, detail="Search query cannot be empty")
    if len(query) > 1000:  # Arbitrary limit
        raise HTTPException(status_code=400, detail="Query text too long")
    try:
        # Get results from the vector database
        results = search_docs(query, limit)
        return {
            "status": "success",
            "results": results
        }
    except Exception as e:
        raise HTTPException(
            status_code=500, detail=f"Error searching documents: {str(e)}"
        ) from e
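
# Example request (route path assumed above):
#   curl "http://localhost:7860/search_rag?query=housing+support&limit=3"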


# Route path assumed here; adjust to match the original API schema
@app.post("/generateSingleResponse", responses=API_RESPONSES)
async def generateSingleResponse(input: str):
""" | |
Generate AI responses. | |
Args: | |
input (str): The user's question or prompt | |
Returns: | |
Dict[str, str]: Structured response containing the generated text | |
Raises: | |
HTTPException: If input is invalid or generation fails | |
""" | |
    # Input validation
    if not input or not input.strip():
        raise HTTPException(status_code=400, detail="Input text cannot be empty")
    if len(input) > 1000:  # Arbitrary limit, adjust as needed
        raise HTTPException(status_code=400, detail="Input text too long")

    # Search the vector database for the user's input
    RAG_Results = search_docs(input, 3)
    # print(RAG_Results)
    combined_input = f"""
Here is the user's question: {input}.
Use the following information to assist in answering the user's question. Do not make anything up or guess.
If you don't know, simply let the user know.
{RAG_Results}
"""
    try:
        # Combine the system prompt with the user input
        content = [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": combined_input},
        ]
        # Generate the response
        output = pipe(content, num_return_sequences=1, max_new_tokens=MODEL_CONFIG["max_new_tokens"])
        # Extract the conversation text from the output
        generated_text = output[0]["generated_text"]
        print(generated_text)
        # Structure the response
        return {
            "status": "success",
            "generated_text": generated_text[-1],  # Return only the newly generated assistant message
        }
    except Exception as e:
        raise HTTPException(
            status_code=500, detail=f"Error generating response: {str(e)}"
        ) from e


@app.get("/status")
async def status():
""" | |
Check the service status | |
""" | |
return {"status": "success", "message": "Service is running"} | |


# Gradio Interface
def chatbot_interface(user_input, history):
    return f"You said: {user_input}"  # Replace with actual chatbot logic


demo = gr.ChatInterface(chatbot_interface)

# Mount Gradio app on FastAPI
app = gr.mount_gradio_app(app, demo, path="/")
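# Note: mounting Gradio at "/" overlaps with docs_url="/" configured above; if the API
# docs should remain the landing page, consider mounting Gradio at a different sub-path
# (e.g. "/ui", a hypothetical choice).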

if __name__ == "__main__":
    # Run FastAPI on the port Spaces expects (7860)
    uvicorn.run(app, host="0.0.0.0", port=7860)