Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -18,6 +18,7 @@ import torch
|
|
| 18 |
import numpy as np
|
| 19 |
from pydantic import BaseModel
|
| 20 |
import asyncio
|
|
|
|
| 21 |
from spellchecker import SpellChecker
|
| 22 |
import nltk
|
| 23 |
from nltk.tokenize import sent_tokenize
|
|
@@ -42,13 +43,13 @@ except Exception as e:
|
|
| 42 |
logger.error(f"Error verifying NLTK punkt_tab: {str(e)}")
|
| 43 |
raise Exception(f"Failed to verify NLTK punkt_tab: {str(e)}")
|
| 44 |
|
| 45 |
-
# Create
|
| 46 |
upload_dir = os.getenv('UPLOAD_DIR', '/tmp/uploads')
|
| 47 |
os.makedirs(upload_dir, exist_ok=True)
|
| 48 |
|
| 49 |
app = FastAPI(
|
| 50 |
title="Cosmic AI Assistant",
|
| 51 |
-
description="An advanced AI assistant with space-themed interface, translation,
|
| 52 |
version="2.0.0"
|
| 53 |
)
|
| 54 |
|
|
@@ -58,11 +59,16 @@ app.mount("/static", StaticFiles(directory="static"), name="static")
|
|
| 58 |
# Mount images directory
|
| 59 |
app.mount("/images", StaticFiles(directory="images"), name="images")
|
| 60 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 61 |
# Model configurations
|
| 62 |
MODELS = {
|
| 63 |
"summarization": "sshleifer/distilbart-cnn-12-6",
|
| 64 |
"image-to-text": "Salesforce/blip-image-captioning-large",
|
| 65 |
"visual-qa": "dandelin/vilt-b32-finetuned-vqa",
|
|
|
|
| 66 |
"translation": "facebook/m2m100_418M",
|
| 67 |
"file-qa": "distilbert-base-cased-distilled-squad"
|
| 68 |
}
|
|
@@ -90,7 +96,7 @@ translation_tokenizer = None
|
|
| 90 |
# Initialize spell checker
|
| 91 |
spell = SpellChecker()
|
| 92 |
|
| 93 |
-
# Cache for model loading
|
| 94 |
@lru_cache(maxsize=8)
|
| 95 |
def load_model(task: str, model_name: str = None):
|
| 96 |
"""Cached model loader with proper task names and error handling"""
|
|
@@ -100,6 +106,9 @@ def load_model(task: str, model_name: str = None):
|
|
| 100 |
|
| 101 |
model_to_load = model_name or MODELS.get(task)
|
| 102 |
|
|
|
|
|
|
|
|
|
|
| 103 |
if task == "visual-qa":
|
| 104 |
processor = ViltProcessor.from_pretrained(model_to_load)
|
| 105 |
model = ViltForQuestionAnswering.from_pretrained(model_to_load)
|
|
@@ -128,6 +137,21 @@ def load_model(task: str, model_name: str = None):
|
|
| 128 |
logger.error(f"Model load failed: {str(e)}")
|
| 129 |
raise HTTPException(status_code=500, detail=f"Model loading failed: {task} - {str(e)}")
|
| 130 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 131 |
def translate_text(text: str, target_language: str):
|
| 132 |
"""Translate text to any target language using pre-loaded M2M100 model"""
|
| 133 |
if not text:
|
|
@@ -217,10 +241,13 @@ def detect_intent(text: str = None, file: UploadFile = None) -> tuple[str, str]:
|
|
| 217 |
return "summarize", target_language
|
| 218 |
|
| 219 |
if not text:
|
| 220 |
-
return "
|
| 221 |
|
| 222 |
text_lower = text.lower()
|
| 223 |
|
|
|
|
|
|
|
|
|
|
| 224 |
# Text translation intent
|
| 225 |
translate_patterns = [
|
| 226 |
r'translate.*to\s+\[?([a-zA-Z]+)\]?:?\s*(.*)',
|
|
@@ -237,7 +264,7 @@ def detect_intent(text: str = None, file: UploadFile = None) -> tuple[str, str]:
|
|
| 237 |
return "translate", target_language
|
| 238 |
else:
|
| 239 |
logger.warning(f"Invalid language detected: {potential_lang}")
|
| 240 |
-
return "
|
| 241 |
|
| 242 |
vqa_patterns = [
|
| 243 |
r'how (many|much)',
|
|
@@ -273,7 +300,7 @@ def detect_intent(text: str = None, file: UploadFile = None) -> tuple[str, str]:
|
|
| 273 |
if len(text) > 100:
|
| 274 |
return "summarize", target_language
|
| 275 |
|
| 276 |
-
return "
|
| 277 |
|
| 278 |
def preprocess_text(text: str) -> str:
|
| 279 |
"""Correct spelling errors and improve text readability."""
|
|
@@ -288,13 +315,29 @@ class ProcessResponse(BaseModel):
|
|
| 288 |
type: str
|
| 289 |
additional_data: Optional[Dict[str, Any]] = None
|
| 290 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 291 |
@app.post("/process", response_model=ProcessResponse)
|
| 292 |
async def process_input(
|
| 293 |
request: Request,
|
| 294 |
text: str = Form(None),
|
| 295 |
file: UploadFile = File(None)
|
| 296 |
):
|
| 297 |
-
"""Enhanced unified endpoint
|
| 298 |
start_time = time.time()
|
| 299 |
client_ip = request.client.host
|
| 300 |
logger.info(f"Request from {client_ip}: text={text[:50] + '...' if text and len(text) > 50 else text}, file={file.filename if file else None}")
|
|
@@ -303,7 +346,11 @@ async def process_input(
|
|
| 303 |
logger.info(f"Detected intent: {intent}, target_language: {target_language}")
|
| 304 |
|
| 305 |
try:
|
| 306 |
-
if intent == "
|
|
|
|
|
|
|
|
|
|
|
|
|
| 307 |
content = await extract_text_from_file(file) if file else text
|
| 308 |
if "all languages" in text.lower():
|
| 309 |
translations = {}
|
|
@@ -401,6 +448,12 @@ async def process_input(
|
|
| 401 |
final_summary = summary[0]['summary_text']
|
| 402 |
|
| 403 |
final_summary = re.sub(r'\s+', ' ', final_summary).strip()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 404 |
if not final_summary.endswith(('.', '!', '?')):
|
| 405 |
final_summary += '.'
|
| 406 |
|
|
@@ -409,7 +462,10 @@ async def process_input(
|
|
| 409 |
|
| 410 |
except Exception as e:
|
| 411 |
logger.error(f"Summarization error: {str(e)}")
|
| 412 |
-
|
|
|
|
|
|
|
|
|
|
| 413 |
|
| 414 |
elif intent == "image-to-text":
|
| 415 |
if not file or not file.content_type.startswith('image/'):
|
|
@@ -441,7 +497,10 @@ async def process_input(
|
|
| 441 |
if not question.endswith('?'):
|
| 442 |
question += '?'
|
| 443 |
|
| 444 |
-
answer = vqa_pipeline(
|
|
|
|
|
|
|
|
|
|
| 445 |
|
| 446 |
answer = answer.strip()
|
| 447 |
if not answer or answer.lower() == question.lower():
|
|
@@ -452,10 +511,25 @@ async def process_input(
|
|
| 452 |
if not answer.endswith(('.', '!', '?')):
|
| 453 |
answer += '.'
|
| 454 |
|
| 455 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 456 |
|
| 457 |
return {
|
| 458 |
-
"response":
|
| 459 |
"type": "visual_qa",
|
| 460 |
"additional_data": {
|
| 461 |
"question": text,
|
|
@@ -481,11 +555,10 @@ async def process_input(
|
|
| 481 |
return {"response": response, "type": "visualization_code"}
|
| 482 |
|
| 483 |
elif intent == "text-generation":
|
| 484 |
-
|
| 485 |
-
response = f"Generated text based on '{text}': This is a simulated creative text."
|
| 486 |
lines = response.split(". ")
|
| 487 |
-
|
| 488 |
-
return {"response":
|
| 489 |
|
| 490 |
elif intent == "file-qa":
|
| 491 |
if not file or not file.filename.lower().endswith(('.pdf', '.docx', '.doc', '.txt', '.rtf')):
|
|
@@ -522,10 +595,17 @@ async def process_input(
|
|
| 522 |
if not best_answer.endswith(('.', '!', '?')):
|
| 523 |
best_answer += '.'
|
| 524 |
|
| 525 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 526 |
|
| 527 |
return {
|
| 528 |
-
"response":
|
| 529 |
"type": "file_qa",
|
| 530 |
"additional_data": {
|
| 531 |
"question": text,
|
|
@@ -534,7 +614,8 @@ async def process_input(
|
|
| 534 |
}
|
| 535 |
|
| 536 |
else:
|
| 537 |
-
|
|
|
|
| 538 |
|
| 539 |
except Exception as e:
|
| 540 |
logger.error(f"Processing error: {str(e)}", exc_info=True)
|
|
@@ -740,6 +821,7 @@ async def startup_event():
|
|
| 740 |
load_model_with_timeout("summarization"),
|
| 741 |
load_model_with_timeout("image-to-text"),
|
| 742 |
load_model_with_timeout("visual-qa"),
|
|
|
|
| 743 |
load_model_with_timeout("file-qa")
|
| 744 |
)
|
| 745 |
|
|
|
|
| 18 |
import numpy as np
|
| 19 |
from pydantic import BaseModel
|
| 20 |
import asyncio
|
| 21 |
+
import google.generativeai as genai
|
| 22 |
from spellchecker import SpellChecker
|
| 23 |
import nltk
|
| 24 |
from nltk.tokenize import sent_tokenize
|
|
|
|
| 43 |
logger.error(f"Error verifying NLTK punkt_tab: {str(e)}")
|
| 44 |
raise Exception(f"Failed to verify NLTK punkt_tab: {str(e)}")
|
| 45 |
|
| 46 |
+
# Create app directory if it doesn't exist
|
| 47 |
upload_dir = os.getenv('UPLOAD_DIR', '/tmp/uploads')
|
| 48 |
os.makedirs(upload_dir, exist_ok=True)
|
| 49 |
|
| 50 |
app = FastAPI(
|
| 51 |
title="Cosmic AI Assistant",
|
| 52 |
+
description="An advanced AI assistant with space-themed interface, translation, and file question-answering features",
|
| 53 |
version="2.0.0"
|
| 54 |
)
|
| 55 |
|
|
|
|
| 59 |
# Mount images directory
|
| 60 |
app.mount("/images", StaticFiles(directory="images"), name="images")
|
| 61 |
|
| 62 |
+
# Gemini API Configuration
|
| 63 |
+
API_KEY = "AIzaSyDtLhhmXpy8ubSGb84ImaxM_ywlL0l_8bo" # Replace with your actual API key
|
| 64 |
+
genai.configure(api_key=API_KEY)
|
| 65 |
+
|
| 66 |
# Model configurations
|
| 67 |
MODELS = {
|
| 68 |
"summarization": "sshleifer/distilbart-cnn-12-6",
|
| 69 |
"image-to-text": "Salesforce/blip-image-captioning-large",
|
| 70 |
"visual-qa": "dandelin/vilt-b32-finetuned-vqa",
|
| 71 |
+
"chatbot": "gemini-1.5-pro",
|
| 72 |
"translation": "facebook/m2m100_418M",
|
| 73 |
"file-qa": "distilbert-base-cased-distilled-squad"
|
| 74 |
}
|
|
|
|
| 96 |
# Initialize spell checker
|
| 97 |
spell = SpellChecker()
|
| 98 |
|
| 99 |
+
# Cache for model loading (excluding translation)
|
| 100 |
@lru_cache(maxsize=8)
|
| 101 |
def load_model(task: str, model_name: str = None):
|
| 102 |
"""Cached model loader with proper task names and error handling"""
|
|
|
|
| 106 |
|
| 107 |
model_to_load = model_name or MODELS.get(task)
|
| 108 |
|
| 109 |
+
if task == "chatbot":
|
| 110 |
+
return genai.GenerativeModel(model_to_load)
|
| 111 |
+
|
| 112 |
if task == "visual-qa":
|
| 113 |
processor = ViltProcessor.from_pretrained(model_to_load)
|
| 114 |
model = ViltForQuestionAnswering.from_pretrained(model_to_load)
|
|
|
|
| 137 |
logger.error(f"Model load failed: {str(e)}")
|
| 138 |
raise HTTPException(status_code=500, detail=f"Model loading failed: {task} - {str(e)}")
|
| 139 |
|
| 140 |
+
def get_gemini_response(user_input: str, is_generation: bool = False):
|
| 141 |
+
"""Function to generate response with Gemini for both chat and text generation"""
|
| 142 |
+
if not user_input:
|
| 143 |
+
return "Please provide some input."
|
| 144 |
+
try:
|
| 145 |
+
chatbot = load_model("chatbot")
|
| 146 |
+
if is_generation:
|
| 147 |
+
prompt = f"Generate creative text based on this prompt: {user_input}"
|
| 148 |
+
else:
|
| 149 |
+
prompt = user_input
|
| 150 |
+
response = chatbot.generate_content(prompt)
|
| 151 |
+
return response.text.strip()
|
| 152 |
+
except Exception as e:
|
| 153 |
+
return f"Error: {str(e)}"
|
| 154 |
+
|
| 155 |
def translate_text(text: str, target_language: str):
|
| 156 |
"""Translate text to any target language using pre-loaded M2M100 model"""
|
| 157 |
if not text:
|
|
|
|
| 241 |
return "summarize", target_language
|
| 242 |
|
| 243 |
if not text:
|
| 244 |
+
return "chatbot", target_language
|
| 245 |
|
| 246 |
text_lower = text.lower()
|
| 247 |
|
| 248 |
+
if any(keyword in text_lower for keyword in ['chat', 'talk', 'converse', 'ask gemini']):
|
| 249 |
+
return "chatbot", target_language
|
| 250 |
+
|
| 251 |
# Text translation intent
|
| 252 |
translate_patterns = [
|
| 253 |
r'translate.*to\s+\[?([a-zA-Z]+)\]?:?\s*(.*)',
|
|
|
|
| 264 |
return "translate", target_language
|
| 265 |
else:
|
| 266 |
logger.warning(f"Invalid language detected: {potential_lang}")
|
| 267 |
+
return "chatbot", target_language
|
| 268 |
|
| 269 |
vqa_patterns = [
|
| 270 |
r'how (many|much)',
|
|
|
|
| 300 |
if len(text) > 100:
|
| 301 |
return "summarize", target_language
|
| 302 |
|
| 303 |
+
return "chatbot", target_language
|
| 304 |
|
| 305 |
def preprocess_text(text: str) -> str:
|
| 306 |
"""Correct spelling errors and improve text readability."""
|
|
|
|
| 315 |
type: str
|
| 316 |
additional_data: Optional[Dict[str, Any]] = None
|
| 317 |
|
| 318 |
+
@app.get("/chatbot")
|
| 319 |
+
async def chatbot_interface():
|
| 320 |
+
"""Redirect to the static index.html file for the chatbot interface"""
|
| 321 |
+
return RedirectResponse(url="/static/index.html")
|
| 322 |
+
|
| 323 |
+
@app.post("/chat")
|
| 324 |
+
async def chat_endpoint(data: dict):
|
| 325 |
+
message = data.get("message", "")
|
| 326 |
+
if not message:
|
| 327 |
+
raise HTTPException(status_code=400, detail="No message provided")
|
| 328 |
+
try:
|
| 329 |
+
response = get_gemini_response(message)
|
| 330 |
+
return {"response": response}
|
| 331 |
+
except Exception as e:
|
| 332 |
+
raise HTTPException(status_code=500, detail=f"Chat error: {str(e)}")
|
| 333 |
+
|
| 334 |
@app.post("/process", response_model=ProcessResponse)
|
| 335 |
async def process_input(
|
| 336 |
request: Request,
|
| 337 |
text: str = Form(None),
|
| 338 |
file: UploadFile = File(None)
|
| 339 |
):
|
| 340 |
+
"""Enhanced unified endpoint with dynamic translation and file translation"""
|
| 341 |
start_time = time.time()
|
| 342 |
client_ip = request.client.host
|
| 343 |
logger.info(f"Request from {client_ip}: text={text[:50] + '...' if text and len(text) > 50 else text}, file={file.filename if file else None}")
|
|
|
|
| 346 |
logger.info(f"Detected intent: {intent}, target_language: {target_language}")
|
| 347 |
|
| 348 |
try:
|
| 349 |
+
if intent == "chatbot":
|
| 350 |
+
response = get_gemini_response(text)
|
| 351 |
+
return {"response": response, "type": "chat"}
|
| 352 |
+
|
| 353 |
+
elif intent == "translate":
|
| 354 |
content = await extract_text_from_file(file) if file else text
|
| 355 |
if "all languages" in text.lower():
|
| 356 |
translations = {}
|
|
|
|
| 448 |
final_summary = summary[0]['summary_text']
|
| 449 |
|
| 450 |
final_summary = re.sub(r'\s+', ' ', final_summary).strip()
|
| 451 |
+
if not final_summary or final_summary.lower().startswith(content.lower()[:30]):
|
| 452 |
+
logger.warning("Summarizer produced inadequate output, falling back to Gemini")
|
| 453 |
+
final_summary = get_gemini_response(
|
| 454 |
+
f"Summarize this text in a concise and meaningful way: {content}"
|
| 455 |
+
)
|
| 456 |
+
|
| 457 |
if not final_summary.endswith(('.', '!', '?')):
|
| 458 |
final_summary += '.'
|
| 459 |
|
|
|
|
| 462 |
|
| 463 |
except Exception as e:
|
| 464 |
logger.error(f"Summarization error: {str(e)}")
|
| 465 |
+
final_summary = get_gemini_response(
|
| 466 |
+
f"Summarize this text in a concise and meaningful way: {content}"
|
| 467 |
+
)
|
| 468 |
+
return {"response": final_summary, "type": "summary", "message": "Text was preprocessed to correct spelling errors"}
|
| 469 |
|
| 470 |
elif intent == "image-to-text":
|
| 471 |
if not file or not file.content_type.startswith('image/'):
|
|
|
|
| 497 |
if not question.endswith('?'):
|
| 498 |
question += '?'
|
| 499 |
|
| 500 |
+
answer = vqa_pipeline(
|
| 501 |
+
image=image,
|
| 502 |
+
question=question
|
| 503 |
+
)
|
| 504 |
|
| 505 |
answer = answer.strip()
|
| 506 |
if not answer or answer.lower() == question.lower():
|
|
|
|
| 511 |
if not answer.endswith(('.', '!', '?')):
|
| 512 |
answer += '.'
|
| 513 |
|
| 514 |
+
# Check if the question asks for a specific, factual detail like color
|
| 515 |
+
factual_questions = ['color', 'size', 'number', 'how many', 'what is the']
|
| 516 |
+
is_factual = any(keyword in question.lower() for keyword in factual_questions)
|
| 517 |
+
|
| 518 |
+
if is_factual:
|
| 519 |
+
# Return the raw VQA answer for factual questions
|
| 520 |
+
final_answer = answer
|
| 521 |
+
else:
|
| 522 |
+
# Apply cosmic tone for non-factual, open-ended questions
|
| 523 |
+
chatbot = load_model("chatbot")
|
| 524 |
+
if "fly" in question.lower():
|
| 525 |
+
final_answer = chatbot.generate_content(f"Make this fun and spacey: {answer}").text.strip()
|
| 526 |
+
else:
|
| 527 |
+
final_answer = chatbot.generate_content(f"Make this cosmic and poetic: {answer}").text.strip()
|
| 528 |
+
|
| 529 |
+
logger.info(f"Final VQA answer: {final_answer}")
|
| 530 |
|
| 531 |
return {
|
| 532 |
+
"response": final_answer,
|
| 533 |
"type": "visual_qa",
|
| 534 |
"additional_data": {
|
| 535 |
"question": text,
|
|
|
|
| 555 |
return {"response": response, "type": "visualization_code"}
|
| 556 |
|
| 557 |
elif intent == "text-generation":
|
| 558 |
+
response = get_gemini_response(text, is_generation=True)
|
|
|
|
| 559 |
lines = response.split(". ")
|
| 560 |
+
formatted_poem = "\n".join(line.strip() + ("." if not line.endswith(".") else "") for line in lines if line)
|
| 561 |
+
return {"response": formatted_poem, "type": "generated_text"}
|
| 562 |
|
| 563 |
elif intent == "file-qa":
|
| 564 |
if not file or not file.filename.lower().endswith(('.pdf', '.docx', '.doc', '.txt', '.rtf')):
|
|
|
|
| 595 |
if not best_answer.endswith(('.', '!', '?')):
|
| 596 |
best_answer += '.'
|
| 597 |
|
| 598 |
+
try:
|
| 599 |
+
chatbot = load_model("chatbot")
|
| 600 |
+
final_answer = chatbot.generate_content(f"Make this cosmic and poetic: {best_answer}").text.strip()
|
| 601 |
+
except Exception as e:
|
| 602 |
+
logger.warning(f"Failed to add cosmic tone: {str(e)}. Using raw answer.")
|
| 603 |
+
final_answer = best_answer
|
| 604 |
+
|
| 605 |
+
logger.info(f"File QA answer: {final_answer}")
|
| 606 |
|
| 607 |
return {
|
| 608 |
+
"response": final_answer,
|
| 609 |
"type": "file_qa",
|
| 610 |
"additional_data": {
|
| 611 |
"question": text,
|
|
|
|
| 614 |
}
|
| 615 |
|
| 616 |
else:
|
| 617 |
+
response = get_gemini_response(text or "Hello! How can I assist you?")
|
| 618 |
+
return {"response": response, "type": "chat"}
|
| 619 |
|
| 620 |
except Exception as e:
|
| 621 |
logger.error(f"Processing error: {str(e)}", exc_info=True)
|
|
|
|
| 821 |
load_model_with_timeout("summarization"),
|
| 822 |
load_model_with_timeout("image-to-text"),
|
| 823 |
load_model_with_timeout("visual-qa"),
|
| 824 |
+
load_model_with_timeout("chatbot"),
|
| 825 |
load_model_with_timeout("file-qa")
|
| 826 |
)
|
| 827 |
|