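# SophiaAI Space entrypoint: a FastAPI service (with a Gradio UI mounted on it) that serves a
# RAG-assisted chat model and persists conversations to Supabase.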
import spaces
import sys
from uuid import UUID
from fastapi import FastAPI, HTTPException
from transformers import pipeline
import torch
import gradio as gr
from Supabase import initSupabase, updateSupabaseChatHistory, updateSupabaseChatStatus
from supabase import Client
from config import MODEL_CONFIG
from typing import Dict, Any, List
from api_schemas import API_RESPONSES
from VectorDB import *
from pydantic import BaseModel
import uvicorn
# Use spaces.GPU decorator for GPU operations
@spaces.GPU
def initialize_gpu():
zero = torch.Tensor([0]).cuda()
print(f"GPU initialized: {zero.device}")
return zero.device
# Device selection with proper error handling
try:
# For Hugging Face Spaces, let spaces handle GPU initialization
if "spaces" in sys.modules:
print("Running in Hugging Face Spaces, using spaces.GPU for device management")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
else:
# Local development path
if torch.backends.mps.is_available():
device = torch.device("mps")
elif torch.cuda.is_available():
device = torch.device("cuda")
else:
device = torch.device("cpu")
print(f"Using device: {device}")
except Exception as e:
print(f"Error detecting device, falling back to CPU: {str(e)}")
device = torch.device("cpu")
# Initialize GPU within application context to avoid blocking
if device.type == "cuda":
# Only initialize GPU if we've selected CUDA
initialize_gpu()
# Continue with the rest of your initialization
initRAG(device)
supabase: Client = initSupabase()
# print(search_docs("how much employment in manchester"))
# Initialize the LLM
try:
pipe = pipeline(
"text-generation",
model=MODEL_CONFIG["model_name"],
max_new_tokens=MODEL_CONFIG["max_new_tokens"],
temperature=0.3,
do_sample=True, # Allow sampling to generate diverse responses. More conversational and human-like
top_k=50, # Limit the top-k tokens to sample from
top_p=0.95, # Limit the cumulative probability distribution for sampling
trust_remote_code=True
)
except Exception as e:
print(f"Error loading model: {str(e)}")
raise RuntimeError("Failed to initialize the model")
# Define the system prompt that sets the behavior and role of the LLM
SYSTEM_PROMPT = """Your name is SophiaAI.
You are a friendly and empathetic assistant designed to empower refugee women and help with any questions.
You should always be friendly. On your first message, introduce yourself with your name. Use emoji in all of your responses to be relatable.
Make sure to always reply to the user in the language they are speaking.
Ensure you have all the information you need before answering a question. Don't make anything up or guess.
Once you have answered a question, you should check if the user would like more detail on a specific area.
Please provide structured lists with markdown-style links:
- **Charity Name**: Description [Visit their website](https://example.com)
Provide spacing between each item in the list or paragraph.
"""
# Serve the API docs as our landing page
app = FastAPI(docs_url="/", title="SophiaAi - 21312701", version="1", description="SophiaAi is a chatbot created for a university final project.\nDesigned to empower refugee women, it connects a fine-tuned LLM to a RAG pipeline containing supporting resources.")
print("App Startup Complete!")
class ChatRequest(BaseModel):
conversationHistory: List[Dict[str, str]]
chatID: UUID
model_config = {
"json_schema_extra": {
"example": {
"conversationHistory": [
{
"role": "user",
"content": "hi"
},
{
"role": "assistant",
"content": "Hello! How can I assist you today?"
},
{
"role": "user",
"content": "whats the weather in MCR"
}
],
"chatID": "123e4567-e89b-12d3-a456-426614174000"
}
}
}
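# Main chat endpoint: stores the incoming history in Supabase, retrieves RAG context for the
# latest user message, generates a reply, and writes the updated conversation back to Supabase.
# Example request (hypothetical localhost URL; the Space listens on port 7860):
#   curl -X POST http://localhost:7860/generateFromChatHistory \
#     -H "Content-Type: application/json" \
#     -d '{"conversationHistory": [{"role": "user", "content": "hi"}], "chatID": "123e4567-e89b-12d3-a456-426614174000"}'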
@app.post(
"/generateFromChatHistory",
responses={
200: {
"description": "Successful response",
"content": {
"application/json": {
"example": {
"status": "success",
"generated_text": {
"role": "assistant",
"content": "I don't have real-time weather data for Manchester. To get accurate information, please check a weather service like BBC Weather or the Met Office website."
}
}
}
}
},
400: API_RESPONSES[400],
500: API_RESPONSES[500]
}
)
async def generateFromChatHistory(input: ChatRequest):
# Use the GPU-enabled function inside this async function
return _gpu_generate_response(input)
# Create a GPU-decorated synchronous function
@spaces.GPU
def _gpu_generate_response(input: ChatRequest):
"""GPU-enabled function to handle the model generation"""
# Input validation
if not input.conversationHistory or len(input.conversationHistory) == 0:
raise HTTPException(status_code=400, detail="Conversation history cannot be empty")
if len(input.conversationHistory) > MODEL_CONFIG["max_conversation_history_size"]: # Arbitrary limit to avoid overloading LLM, adjust as needed
raise HTTPException(status_code=400, detail="Conversation history too long")
try:
# Map Conversation history
content = [
{
"role": "system",
"content": SYSTEM_PROMPT,
}
]
content.extend(
{"role": message["role"], "content": message["content"]}
for message in input.conversationHistory
)
updateSupabaseChatHistory(content[1:], input.chatID, supabase, True) # Update supabase
# Combine system prompt with user input
        LastQuestion = input.conversationHistory[-1]["content"]  # User's most recent question
        RAG_Results = search_docs(LastQuestion, 3)  # Retrieve RAG context from the vector database
RagPrompt = f"""_RAG_
Use the following information to assist in answering the user's most recent question. Do not make anything up or guess.
Relevant information retrieved: {RAG_Results}
If you don't know, simply let the user know, or ask for more detail. The user has not seen this message; it is for your reference only."""
# Append RAG results with a dedicated role
rag_message = {
"role": "user",
"content": RagPrompt
}
content.append(rag_message)
# print(content)
# Generate response
output = pipe(content, num_return_sequences=1, max_new_tokens=MODEL_CONFIG["max_new_tokens"])
generated_text = output[0]["generated_text"] # Get the entire conversation history including new generated item
generated_text.pop(0) # Remove the system prompt from the generated text
updateSupabaseChatHistory(generated_text, input.chatID, supabase)# Update supabase
return {
"status": "success",
"generated_text": generated_text # generated_text[-1], # return only the input prompt and the generated response
}
except Exception as e:
        updateSupabaseChatStatus(False, input.chatID, supabase)  # Notify the database that the chat is no longer being processed
raise HTTPException(
status_code=500, detail=f"Error generating response: {str(e)}"
) from e
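# Test endpoint: query the vector store directly, bypassing the LLM.
# Example (hypothetical localhost URL): curl "http://localhost:7860/test-searchRAG?query=employment%20in%20manchester&limit=3"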
@app.get(
"/test-searchRAG",
responses={
200: {
"description": "Successful RAG search results",
"content": {
"application/json": {
"example": {
"status": "success",
"results": [
{"content": "Example content 1", "metadata": {"source": "doc1.pdf"}},
{"content": "Example content 2", "metadata": {"source": "doc2.pdf"}}
]
}
}
}
},
400: API_RESPONSES[400],
500: API_RESPONSES[500]
}
)
async def search_rag(query: str, limit: int = 3):
"""
Search the RAG system directly with a query
Args:
query (str): The search query
        limit (int): Maximum number of results to return (default: 3)
    Returns:
        Dict: Search results with the relevant documents
Raises:
HTTPException: If the query is invalid or search fails
"""
# Input validation
if not query or not query.strip():
raise HTTPException(status_code=400, detail="Search query cannot be empty")
if len(query) > 1000: # Arbitrary limit
raise HTTPException(status_code=400, detail="Query text too long")
try:
# Get results from vector database
results = search_docs(query, limit)
return {
"status": "success",
"results": results
}
except Exception as e:
raise HTTPException(
status_code=500, detail=f"Error searching documents: {str(e)}"
) from e
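# Test endpoint: single-turn generation with RAG context, without persisting any chat history.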
@app.get(
"/test-generateSingleResponse",
responses={
200: {
"description": "Successful response",
"content": {
"application/json": {
"example": {
"status": "success",
"generated_text": [
{
"role": "user",
"content": "hey"
},
{
"role": "assistant",
"content": "Hello! How can I assist you today? Is there something specific you'd like to talk about or learn more about?"
}
]
}
}
}
},
400: API_RESPONSES[400],
500: API_RESPONSES[500]
}
)
async def generateSingleResponse(input: str):
"""
Generate AI responses.
Args:
input (str): The user's question or prompt
Returns:
Dict[str, str]: Structured response containing the generated text
Raises:
HTTPException: If input is invalid or generation fails
"""
# Input validation
if not input or not input.strip():
raise HTTPException(status_code=400, detail="Input text cannot be empty")
if len(input) > 1000: # Arbitrary limit, adjust as needed
raise HTTPException(status_code=400, detail="Input text too long")
# search Vector Database for user input.
RAG_Results = search_docs(input, 3)
# print(RAG_Results)
combined_input = f"""
    Here is the user's question: {input}.
    Use the following information to assist in answering the user's question. Do not make anything up or guess.
If you don't know, simply let the user know.
{RAG_Results}
"""
try:
# Combine system prompt with user input
content = [
{"role": "system", "content": SYSTEM_PROMPT},
{"role": "user", "content": combined_input},
]
# Generate response
output = pipe(content, num_return_sequences=1, max_new_tokens=MODEL_CONFIG["max_new_tokens"])
# Extract the conversation text from the output
generated_text = output[0]["generated_text"]
print(generated_text)
        # Structure the response
return {
"status": "success",
"generated_text": generated_text[-1], # return only the input prompt and the generated response
}
except Exception as e:
raise HTTPException(
status_code=500, detail=f"Error generating response: {str(e)}"
) from e
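# Simple health-check endpoint.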
@app.get(
"/status",
responses={
200: {
"description": "Successful response",
"content": {
"application/json": {
"example": {
"status": "success",
"message": "Service is running"
}
}
}
}
}
)
async def status():
"""
Check the service status
"""
return {"status": "success", "message": "Service is running"}
# Gradio Interface
def chatbot_interface(user_input, history):
return f"You said: {user_input}" # Replace with actual chatbot logic
demo = gr.ChatInterface(chatbot_interface)
# Mount Gradio app on FastAPI
app = gr.mount_gradio_app(app, demo, path="/")
if __name__ == "__main__":
    # Run FastAPI on the port Spaces expects (7860); uvicorn is already imported at the top of the file
uvicorn.run(app, host="0.0.0.0", port=7860)