Update app.py

app.py CHANGED
@@ -22,6 +22,11 @@ import google.generativeai as genai
 from spellchecker import SpellChecker
 import nltk
 from nltk.tokenize import sent_tokenize
+from dotenv import load_dotenv
+import shutil
+
+# Load environment variables
+load_dotenv()
 
 # Configure logging
 logging.basicConfig(
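A note on the new imports: `load_dotenv()` only affects `os.getenv()` calls made after it runs, so it must execute before the configuration reads below. A minimal sketch of the intended flow, assuming python-dotenv is installed and a `.env` file with a `GEMINI_API_KEY` entry sits in the working directory:

```python
# Sketch only: load .env into os.environ, then read from it.
import os
from dotenv import load_dotenv

load_dotenv()                          # parses .env and populates os.environ
api_key = os.getenv("GEMINI_API_KEY")  # None if the variable is absent
print("key loaded" if api_key else "key missing")
```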
@@ -31,7 +36,7 @@ logging.basicConfig(
 logger = logging.getLogger("cosmic_ai")
 
 # Set a custom NLTK data directory
-nltk_data_dir = os.getenv('NLTK_DATA_DIR', '/
+nltk_data_dir = os.getenv('NLTK_DATA_DIR', '/cache/nltk_data')
 os.makedirs(nltk_data_dir, exist_ok=True)
 nltk.data.path.append(nltk_data_dir)
 
@@ -60,7 +65,9 @@ app.mount("/static", StaticFiles(directory="static"), name="static")
 app.mount("/images", StaticFiles(directory="images"), name="images")
 
 # Gemini API Configuration
-API_KEY =
+API_KEY = os.getenv('GEMINI_API_KEY')
+if not API_KEY:
+    raise ValueError("GEMINI_API_KEY environment variable is not set")
 genai.configure(api_key=API_KEY)
 
 # Model configurations
@@ -101,6 +108,13 @@ spell = SpellChecker()
 def load_model(task: str, model_name: str = None):
     """Cached model loader with proper task names and error handling"""
     try:
+        cache_dir = os.getenv('HF_HOME', '/cache/huggingface')
+        if not os.path.exists(cache_dir):
+            os.makedirs(cache_dir, exist_ok=True)
+        elif not os.access(cache_dir, os.W_OK):
+            logger.warning(f"Cache directory {cache_dir} is not writable. Attempting to clear cache.")
+            shutil.rmtree(cache_dir, ignore_errors=True)
+            os.makedirs(cache_dir, exist_ok=True)
         logger.info(f"Loading model for task: {task}, model: {model_name or MODELS.get(task)}")
         start_time = time.time()
 
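One caveat on the writability check added here: `os.access(cache_dir, os.W_OK)` can report stale results on some container mounts and network filesystems. A hedged alternative is to probe with a throwaway temp file; `ensure_writable_dir` below is a hypothetical helper sketch, not part of the app:

```python
# Sketch: verify writability by actually writing, not by permission bits.
import os
import tempfile

def ensure_writable_dir(path: str) -> bool:
    """Create `path` if needed and return True only if it accepts writes."""
    os.makedirs(path, exist_ok=True)
    try:
        with tempfile.NamedTemporaryFile(dir=path):  # deleted on close
            return True
    except OSError:
        return False
```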
@@ -110,8 +124,8 @@ def load_model(task: str, model_name: str = None):
             return genai.GenerativeModel(model_to_load)
 
         if task == "visual-qa":
-            processor = ViltProcessor.from_pretrained(model_to_load)
-            model = ViltForQuestionAnswering.from_pretrained(model_to_load)
+            processor = ViltProcessor.from_pretrained(model_to_load, cache_dir=cache_dir)
+            model = ViltForQuestionAnswering.from_pretrained(model_to_load, cache_dir=cache_dir)
             device = "cuda" if torch.cuda.is_available() else "cpu"
             model.to(device)
 
@@ -130,8 +144,11 @@ def load_model(task: str, model_name: str = None):
 
             return vqa_function
 
-
-
+        return pipeline(
+            task if task != "file-qa" else "question-answering",
+            model=model_to_load,
+            cache_dir=cache_dir
+        )
 
     except Exception as e:
         logger.error(f"Model load failed: {str(e)}")
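Depending on the installed transformers version, `pipeline()` may not forward a bare `cache_dir=` argument to the underlying `from_pretrained()` call; the documented route is `model_kwargs`. A sketch of that variant, with an illustrative checkpoint rather than the app's configured one:

```python
from transformers import pipeline

# Sketch: route the cache location through model_kwargs, which pipeline()
# forwards to the model's from_pretrained() call.
qa = pipeline(
    "question-answering",
    model="distilbert-base-cased-distilled-squad",  # illustrative checkpoint
    model_kwargs={"cache_dir": "/cache/huggingface"},
)
```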
@@ -171,6 +188,7 @@ def translate_text(text: str, target_language: str):
     lang_code = SUPPORTED_LANGUAGES[target_lang]
 
     if translation_model is None or translation_tokenizer is None:
+        # Debugger note: "cannot access local variable 'lang_code' before it was used"
         raise Exception("Translation model not initialized")
 
     match = re.search(r'how to say\s+(.+?)\s+in\s+(\w+)', text.lower())
@@ -191,7 +209,11 @@ def translate_text(text: str, target_language: str):
         num_beams=1,
         early_stopping=True
     )
-    translated_text = translation_tokenizer.batch_decode(
+    translated_text = translation_tokenizer.batch_decode(
+        generated_tokens,
+        skip_special_tokens=True,
+        clean_up_tokenization_spaces=False
+    )[0]
     logger.info(f"Translation took {time.time() - start_time:.2f} seconds")
 
     return translated_text
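For context on the `batch_decode` call completed above, this is the standard M2M100 round trip: set the source language on the tokenizer, force the target language's BOS token during generation, then decode. A self-contained sketch using the public facebook/m2m100_418M checkpoint (the app's MODELS["translation"] entry may differ):

```python
from transformers import M2M100ForConditionalGeneration, M2M100Tokenizer

model = M2M100ForConditionalGeneration.from_pretrained("facebook/m2m100_418M")
tokenizer = M2M100Tokenizer.from_pretrained("facebook/m2m100_418M")

tokenizer.src_lang = "en"  # language of the input text
encoded = tokenizer("Hello, world!", return_tensors="pt")
generated_tokens = model.generate(
    **encoded,
    forced_bos_token_id=tokenizer.get_lang_id("fr"),  # decode into French
)
# skip_special_tokens drops language and padding markers from the output
print(tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0])
```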
@@ -208,7 +230,6 @@ def detect_intent(text: str = None, file: UploadFile = None) -> tuple[str, str]:
     text_lower = text.lower()
     filename = file.filename.lower() if file.filename else ""
 
-    # Check for file translation intent
     translate_patterns = [
         r'translate.*to\s+\[?([a-zA-Z]+)\]?:?\s*(.*)',
         r'convert.*to\s+\[?([a-zA-Z]+)\]?:?\s*(.*)',
@@ -222,7 +243,6 @@ def detect_intent(text: str = None, file: UploadFile = None) -> tuple[str, str]:
             target_language = potential_lang.capitalize()
             return "file-translate", target_language
 
-        # Image-related intents
         content_type = file.content_type.lower() if file.content_type else ""
         if content_type.startswith('image/') and text:
             if "what’s this" in text_lower or "does this fly" in text_lower or ("fly" in text_lower and any(q in text_lower for q in ['does', 'can', 'will'])):
@@ -232,7 +252,6 @@ def detect_intent(text: str = None, file: UploadFile = None) -> tuple[str, str]:
         if "generate a caption" in text_lower or "caption" in text_lower:
             return "image-to-text", target_language
 
-        # File-related intents
         if filename.endswith(('.xlsx', '.xls', '.csv')):
             return "visualize", target_language
         elif filename.endswith(('.pdf', '.docx', '.doc', '.txt', '.rtf')):
@@ -248,7 +267,6 @@ def detect_intent(text: str = None, file: UploadFile = None) -> tuple[str, str]:
     if any(keyword in text_lower for keyword in ['chat', 'talk', 'converse', 'ask gemini']):
         return "chatbot", target_language
 
-    # Text translation intent
     translate_patterns = [
         r'translate.*to\s+\[?([a-zA-Z]+)\]?:?\s*(.*)',
         r'convert.*to\s+\[?([a-zA-Z]+)\]?:?\s*(.*)',
@@ -364,7 +382,11 @@ async def process_input(
                 max_length=512,
                 num_beams=1
             )
-            translations[lang] = translation_tokenizer.batch_decode(
+            translations[lang] = translation_tokenizer.batch_decode(
+                generated_tokens,
+                skip_special_tokens=True,
+                clean_up_tokenization_spaces=False
+            )[0]
         response = "\n".join(f"{lang.capitalize()}: {translations[lang]}" for lang in translations)
         logger.info(f"Translated to all supported languages: {', '.join(translations.keys())}")
         return {"response": response, "type": "translation"}
@@ -382,7 +404,6 @@ async def process_input(
         if not content.strip():
             raise HTTPException(status_code=400, detail="No text could be extracted from the file")
 
-        # Split content into chunks to handle large files
         max_chunk_size = 512
         chunks = [content[i:i+max_chunk_size] for i in range(0, len(content), max_chunk_size)]
         translated_chunks = []
@@ -511,15 +532,12 @@ async def process_input(
         if not answer.endswith(('.', '!', '?')):
             answer += '.'
 
-        # Check if the question asks for a specific, factual detail like color
         factual_questions = ['color', 'size', 'number', 'how many', 'what is the']
         is_factual = any(keyword in question.lower() for keyword in factual_questions)
 
         if is_factual:
-            # Return the raw VQA answer for factual questions
             final_answer = answer
         else:
-            # Apply cosmic tone for non-factual, open-ended questions
             chatbot = load_model("chatbot")
             if "fly" in question.lower():
                 final_answer = chatbot.generate_content(f"Make this fun and spacey: {answer}").text.strip()
@@ -570,7 +588,22 @@ async def process_input(
         if not content.strip():
             raise HTTPException(status_code=400, detail="No text could be extracted from the file")
 
-
+        try:
+            qa_pipeline = load_model("file-qa")
+        except Exception as e:
+            logger.warning(f"File-QA model failed: {str(e)}. Falling back to Gemini.")
+            question = text.strip()
+            if not question.endswith('?'):
+                question += '?'
+            response = get_gemini_response(f"Answer this question based on the following text: {content}\nQuestion: {question}")
+            return {
+                "response": response,
+                "type": "file_qa",
+                "additional_data": {
+                    "question": text,
+                    "file_name": file.filename
+                }
+            }
 
         question = text.strip()
         if not question.endswith('?'):
@@ -800,7 +833,7 @@ async def startup_event():
 
     async def load_model_with_timeout(task):
         try:
-            await asyncio.wait_for(asyncio.to_thread(load_model, task), timeout=
+            await asyncio.wait_for(asyncio.to_thread(load_model, task), timeout=120.0)
             logger.info(f"Successfully loaded {task} model")
         except asyncio.TimeoutError:
             logger.warning(f"Timeout loading {task} model - will load on demand")
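The pattern completed above keeps a blocking model load off the event loop with `asyncio.to_thread` and caps it with `asyncio.wait_for`. A minimal standalone sketch with a stand-in loader (requires Python 3.9+ for `to_thread`):

```python
import asyncio
import time

def slow_load() -> str:
    time.sleep(2)  # stand-in for a blocking model download
    return "model"

async def main() -> None:
    try:
        # Run the blocking call in a worker thread; give up after 120 s.
        model = await asyncio.wait_for(asyncio.to_thread(slow_load), timeout=120.0)
        print(f"loaded: {model}")
    except asyncio.TimeoutError:
        print("timed out - will load on demand")

asyncio.run(main())
```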
@@ -809,8 +842,8 @@ async def startup_event():
 
     try:
         model_name = MODELS["translation"]
-        translation_model = M2M100ForConditionalGeneration.from_pretrained(model_name)
-        translation_tokenizer = M2M100Tokenizer.from_pretrained(model_name)
+        translation_model = M2M100ForConditionalGeneration.from_pretrained(model_name, cache_dir=os.getenv('HF_HOME'))
+        translation_tokenizer = M2M100Tokenizer.from_pretrained(model_name, cache_dir=os.getenv('HF_HOME'))
         device = "cuda" if torch.cuda.is_available() else "cpu"
         translation_model.to(device)
         logger.info("Translation model pre-loaded successfully")