Hadiil committed on
Commit ead791f · verified · 1 Parent(s): 53b0ba4

Update app.py

Files changed (1)
  1. app.py +53 -20
app.py CHANGED
@@ -22,6 +22,11 @@ import google.generativeai as genai
22
  from spellchecker import SpellChecker
23
  import nltk
24
  from nltk.tokenize import sent_tokenize
25
 
26
  # Configure logging
27
  logging.basicConfig(
@@ -31,7 +36,7 @@ logging.basicConfig(
31
  logger = logging.getLogger("cosmic_ai")
32
 
33
  # Set a custom NLTK data directory
34
- nltk_data_dir = os.getenv('NLTK_DATA_DIR', '/tmp/nltk_data')
35
  os.makedirs(nltk_data_dir, exist_ok=True)
36
  nltk.data.path.append(nltk_data_dir)
37
 
@@ -60,7 +65,9 @@ app.mount("/static", StaticFiles(directory="static"), name="static")
60
  app.mount("/images", StaticFiles(directory="images"), name="images")
61
 
62
  # Gemini API Configuration
63
- API_KEY = "AIzaSyDtLhhmXpy8ubSGb84ImaxM_ywlL0l_8bo" # Replace with your actual API key
64
  genai.configure(api_key=API_KEY)
65
 
66
  # Model configurations
@@ -101,6 +108,13 @@ spell = SpellChecker()
101
  def load_model(task: str, model_name: str = None):
102
  """Cached model loader with proper task names and error handling"""
103
  try:
104
  logger.info(f"Loading model for task: {task}, model: {model_name or MODELS.get(task)}")
105
  start_time = time.time()
106
 
@@ -110,8 +124,8 @@ def load_model(task: str, model_name: str = None):
110
  return genai.GenerativeModel(model_to_load)
111
 
112
  if task == "visual-qa":
113
- processor = ViltProcessor.from_pretrained(model_to_load)
114
- model = ViltForQuestionAnswering.from_pretrained(model_to_load)
115
  device = "cuda" if torch.cuda.is_available() else "cpu"
116
  model.to(device)
117
 
@@ -130,8 +144,11 @@ def load_model(task: str, model_name: str = None):
130
 
131
  return vqa_function
132
 
133
- # Use pipeline for summarization, image-to-text, and file-qa
134
- return pipeline(task if task != "file-qa" else "question-answering", model=model_to_load)
135
 
136
  except Exception as e:
137
  logger.error(f"Model load failed: {str(e)}")
@@ -171,6 +188,7 @@ def translate_text(text: str, target_language: str):
171
  lang_code = SUPPORTED_LANGUAGES[target_lang]
172
 
173
  if translation_model is None or translation_tokenizer is None:
174
  raise Exception("Translation model not initialized")
175
 
176
  match = re.search(r'how to say\s+(.+?)\s+in\s+(\w+)', text.lower())
@@ -191,7 +209,11 @@ def translate_text(text: str, target_language: str):
191
  num_beams=1,
192
  early_stopping=True
193
  )
194
- translated_text = translation_tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]
195
  logger.info(f"Translation took {time.time() - start_time:.2f} seconds")
196
 
197
  return translated_text
@@ -208,7 +230,6 @@ def detect_intent(text: str = None, file: UploadFile = None) -> tuple[str, str]:
208
  text_lower = text.lower()
209
  filename = file.filename.lower() if file.filename else ""
210
 
211
- # Check for file translation intent
212
  translate_patterns = [
213
  r'translate.*to\s+\[?([a-zA-Z]+)\]?:?\s*(.*)',
214
  r'convert.*to\s+\[?([a-zA-Z]+)\]?:?\s*(.*)',
@@ -222,7 +243,6 @@ def detect_intent(text: str = None, file: UploadFile = None) -> tuple[str, str]:
222
  target_language = potential_lang.capitalize()
223
  return "file-translate", target_language
224
 
225
- # Image-related intents
226
  content_type = file.content_type.lower() if file.content_type else ""
227
  if content_type.startswith('image/') and text:
228
  if "what’s this" in text_lower or "does this fly" in text_lower or ("fly" in text_lower and any(q in text_lower for q in ['does', 'can', 'will'])):
@@ -232,7 +252,6 @@ def detect_intent(text: str = None, file: UploadFile = None) -> tuple[str, str]:
232
  if "generate a caption" in text_lower or "caption" in text_lower:
233
  return "image-to-text", target_language
234
 
235
- # File-related intents
236
  if filename.endswith(('.xlsx', '.xls', '.csv')):
237
  return "visualize", target_language
238
  elif filename.endswith(('.pdf', '.docx', '.doc', '.txt', '.rtf')):
@@ -248,7 +267,6 @@ def detect_intent(text: str = None, file: UploadFile = None) -> tuple[str, str]:
248
  if any(keyword in text_lower for keyword in ['chat', 'talk', 'converse', 'ask gemini']):
249
  return "chatbot", target_language
250
 
251
- # Text translation intent
252
  translate_patterns = [
253
  r'translate.*to\s+\[?([a-zA-Z]+)\]?:?\s*(.*)',
254
  r'convert.*to\s+\[?([a-zA-Z]+)\]?:?\s*(.*)',
@@ -364,7 +382,11 @@ async def process_input(
364
  max_length=512,
365
  num_beams=1
366
  )
367
- translations[lang] = translation_tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]
368
  response = "\n".join(f"{lang.capitalize()}: {translations[lang]}" for lang in translations)
369
  logger.info(f"Translated to all supported languages: {', '.join(translations.keys())}")
370
  return {"response": response, "type": "translation"}
@@ -382,7 +404,6 @@ async def process_input(
382
  if not content.strip():
383
  raise HTTPException(status_code=400, detail="No text could be extracted from the file")
384
 
385
- # Split content into chunks to handle large files
386
  max_chunk_size = 512
387
  chunks = [content[i:i+max_chunk_size] for i in range(0, len(content), max_chunk_size)]
388
  translated_chunks = []
@@ -511,15 +532,12 @@ async def process_input(
511
  if not answer.endswith(('.', '!', '?')):
512
  answer += '.'
513
 
514
- # Check if the question asks for a specific, factual detail like color
515
  factual_questions = ['color', 'size', 'number', 'how many', 'what is the']
516
  is_factual = any(keyword in question.lower() for keyword in factual_questions)
517
 
518
  if is_factual:
519
- # Return the raw VQA answer for factual questions
520
  final_answer = answer
521
  else:
522
- # Apply cosmic tone for non-factual, open-ended questions
523
  chatbot = load_model("chatbot")
524
  if "fly" in question.lower():
525
  final_answer = chatbot.generate_content(f"Make this fun and spacey: {answer}").text.strip()
@@ -570,7 +588,22 @@ async def process_input(
570
  if not content.strip():
571
  raise HTTPException(status_code=400, detail="No text could be extracted from the file")
572
 
573
- qa_pipeline = load_model("file-qa")
574
 
575
  question = text.strip()
576
  if not question.endswith('?'):
@@ -800,7 +833,7 @@ async def startup_event():
800
 
801
  async def load_model_with_timeout(task):
802
  try:
803
- await asyncio.wait_for(asyncio.to_thread(load_model, task), timeout=60.0)
804
  logger.info(f"Successfully loaded {task} model")
805
  except asyncio.TimeoutError:
806
  logger.warning(f"Timeout loading {task} model - will load on demand")
@@ -809,8 +842,8 @@ async def startup_event():
809
 
810
  try:
811
  model_name = MODELS["translation"]
812
- translation_model = M2M100ForConditionalGeneration.from_pretrained(model_name)
813
- translation_tokenizer = M2M100Tokenizer.from_pretrained(model_name)
814
  device = "cuda" if torch.cuda.is_available() else "cpu"
815
  translation_model.to(device)
816
  logger.info("Translation model pre-loaded successfully")
 
22
  from spellchecker import SpellChecker
23
  import nltk
24
  from nltk.tokenize import sent_tokenize
25
+ from dotenv import load_dotenv
26
+ import shutil
27
+
28
+ # Load environment variables
29
+ load_dotenv()
30
 
31
  # Configure logging
32
  logging.basicConfig(
 
36
  logger = logging.getLogger("cosmic_ai")
37
 
38
  # Set a custom NLTK data directory
39
+ nltk_data_dir = os.getenv('NLTK_DATA_DIR', '/cache/nltk_data')
40
  os.makedirs(nltk_data_dir, exist_ok=True)
41
  nltk.data.path.append(nltk_data_dir)
42
 
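For reference, a minimal sketch of how these settings resolve once load_dotenv() runs; the variable names are the ones this commit introduces, and the .env values are placeholders:

import os
from dotenv import load_dotenv

load_dotenv()  # reads key=value pairs from a local .env file, if present
# Hypothetical .env contents:
#   GEMINI_API_KEY=<your-key>
#   NLTK_DATA_DIR=/cache/nltk_data
#   HF_HOME=/cache/huggingface
api_key = os.getenv('GEMINI_API_KEY')                      # None when unset
nltk_dir = os.getenv('NLTK_DATA_DIR', '/cache/nltk_data')  # falls back to the default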
 
65
  app.mount("/images", StaticFiles(directory="images"), name="images")
66
 
67
  # Gemini API Configuration
68
+ API_KEY = os.getenv('GEMINI_API_KEY')
69
+ if not API_KEY:
70
+ raise ValueError("GEMINI_API_KEY environment variable is not set")
71
  genai.configure(api_key=API_KEY)
72
 
73
  # Model configurations
 
108
  def load_model(task: str, model_name: str = None):
109
  """Cached model loader with proper task names and error handling"""
110
  try:
111
+ cache_dir = os.getenv('HF_HOME', '/cache/huggingface')
112
+ if not os.path.exists(cache_dir):
113
+ os.makedirs(cache_dir, exist_ok=True)
114
+ elif not os.access(cache_dir, os.W_OK):
115
+ logger.warning(f"Cache directory {cache_dir} is not writable. Attempting to clear cache.")
116
+ shutil.rmtree(cache_dir, ignore_errors=True)
117
+ os.makedirs(cache_dir, exist_ok=True)
118
  logger.info(f"Loading model for task: {task}, model: {model_name or MODELS.get(task)}")
119
  start_time = time.time()
120
 
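The cache-directory guard above, isolated as a standalone sketch; if the directory itself is not writable, removing its contents will usually fail as well, so ignore_errors=True keeps the clear best-effort:

import os
import shutil

cache_dir = os.getenv('HF_HOME', '/cache/huggingface')
os.makedirs(cache_dir, exist_ok=True)             # create if missing
if not os.access(cache_dir, os.W_OK):             # exists but not writable
    shutil.rmtree(cache_dir, ignore_errors=True)  # best-effort clear
    os.makedirs(cache_dir, exist_ok=True)         # recreate fresh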
 
124
  return genai.GenerativeModel(model_to_load)
125
 
126
  if task == "visual-qa":
127
+ processor = ViltProcessor.from_pretrained(model_to_load, cache_dir=cache_dir)
128
+ model = ViltForQuestionAnswering.from_pretrained(model_to_load, cache_dir=cache_dir)
129
  device = "cuda" if torch.cuda.is_available() else "cpu"
130
  model.to(device)
131
 
 
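For context, the standard ViLT question-answering flow that these two lines feed into looks roughly like this; the model id and image path are placeholders, assuming a VQA-finetuned checkpoint such as dandelin/vilt-b32-finetuned-vqa:

import torch
from PIL import Image
from transformers import ViltProcessor, ViltForQuestionAnswering

processor = ViltProcessor.from_pretrained("dandelin/vilt-b32-finetuned-vqa")
model = ViltForQuestionAnswering.from_pretrained("dandelin/vilt-b32-finetuned-vqa")
image = Image.open("example.jpg")                          # placeholder image
inputs = processor(image, "What color is the cat?", return_tensors="pt")
with torch.no_grad():
    logits = model(**inputs).logits
answer = model.config.id2label[logits.argmax(-1).item()]   # e.g. "black"
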
144
 
145
  return vqa_function
146
 
147
+ return pipeline(
148
+ task if task != "file-qa" else "question-answering",
149
+ model=model_to_load,
150
+ model_kwargs={"cache_dir": cache_dir}
151
+ )
152
 
153
  except Exception as e:
154
  logger.error(f"Model load failed: {str(e)}")
 
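Worth noting: transformers.pipeline() does not take cache_dir as a top-level argument; download options for the underlying from_pretrained call are forwarded through model_kwargs, which is why the call above is written that way. A minimal sketch with a placeholder model id:

from transformers import pipeline

qa = pipeline(
    "question-answering",
    model="distilbert-base-cased-distilled-squad",    # placeholder model id
    model_kwargs={"cache_dir": "/cache/huggingface"}  # passed to from_pretrained
)
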
188
  lang_code = SUPPORTED_LANGUAGES[target_lang]
189
 
190
  if translation_model is None or translation_tokenizer is None:
192
  raise Exception("Translation model not initialized")
193
 
194
  match = re.search(r'how to say\s+(.+?)\s+in\s+(\w+)', text.lower())
 
209
  num_beams=1,
210
  early_stopping=True
211
  )
212
+ translated_text = translation_tokenizer.batch_decode(
213
+ generated_tokens,
214
+ skip_special_tokens=True,
215
+ clean_up_tokenization_spaces=False
216
+ )[0]
217
  logger.info(f"Translation took {time.time() - start_time:.2f} seconds")
218
 
219
  return translated_text
 
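Passing clean_up_tokenization_spaces=False pins the decoding behavior explicitly instead of relying on the library default. A self-contained round trip using the same decode call; the checkpoint and language pair are assumptions:

from transformers import M2M100ForConditionalGeneration, M2M100Tokenizer

name = "facebook/m2m100_418M"                      # assumed checkpoint
tokenizer = M2M100Tokenizer.from_pretrained(name)
model = M2M100ForConditionalGeneration.from_pretrained(name)

tokenizer.src_lang = "en"
encoded = tokenizer("Hello, world", return_tensors="pt")
generated = model.generate(**encoded, forced_bos_token_id=tokenizer.get_lang_id("fr"))
text = tokenizer.batch_decode(
    generated,
    skip_special_tokens=True,
    clean_up_tokenization_spaces=False,
)[0]
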
230
  text_lower = text.lower()
231
  filename = file.filename.lower() if file.filename else ""
232
 
 
233
  translate_patterns = [
234
  r'translate.*to\s+\[?([a-zA-Z]+)\]?:?\s*(.*)',
235
  r'convert.*to\s+\[?([a-zA-Z]+)\]?:?\s*(.*)',
 
243
  target_language = potential_lang.capitalize()
244
  return "file-translate", target_language
245
 
 
246
  content_type = file.content_type.lower() if file.content_type else ""
247
  if content_type.startswith('image/') and text:
248
  if "what’s this" in text_lower or "does this fly" in text_lower or ("fly" in text_lower and any(q in text_lower for q in ['does', 'can', 'will'])):
 
252
  if "generate a caption" in text_lower or "caption" in text_lower:
253
  return "image-to-text", target_language
254
 
 
255
  if filename.endswith(('.xlsx', '.xls', '.csv')):
256
  return "visualize", target_language
257
  elif filename.endswith(('.pdf', '.docx', '.doc', '.txt', '.rtf')):
 
267
  if any(keyword in text_lower for keyword in ['chat', 'talk', 'converse', 'ask gemini']):
268
  return "chatbot", target_language
269
 
 
270
  translate_patterns = [
271
  r'translate.*to\s+\[?([a-zA-Z]+)\]?:?\s*(.*)',
272
  r'convert.*to\s+\[?([a-zA-Z]+)\]?:?\s*(.*)',
 
382
  max_length=512,
383
  num_beams=1
384
  )
385
+ translations[lang] = translation_tokenizer.batch_decode(
386
+ generated_tokens,
387
+ skip_special_tokens=True,
388
+ clean_up_tokenization_spaces=False
389
+ )[0]
390
  response = "\n".join(f"{lang.capitalize()}: {translations[lang]}" for lang in translations)
391
  logger.info(f"Translated to all supported languages: {', '.join(translations.keys())}")
392
  return {"response": response, "type": "translation"}
 
404
  if not content.strip():
405
  raise HTTPException(status_code=400, detail="No text could be extracted from the file")
406
 
 
407
  max_chunk_size = 512
408
  chunks = [content[i:i+max_chunk_size] for i in range(0, len(content), max_chunk_size)]
409
  translated_chunks = []
 
532
  if not answer.endswith(('.', '!', '?')):
533
  answer += '.'
534
 
 
535
  factual_questions = ['color', 'size', 'number', 'how many', 'what is the']
536
  is_factual = any(keyword in question.lower() for keyword in factual_questions)
537
 
538
  if is_factual:
 
539
  final_answer = answer
540
  else:
 
541
  chatbot = load_model("chatbot")
542
  if "fly" in question.lower():
543
  final_answer = chatbot.generate_content(f"Make this fun and spacey: {answer}").text.strip()
 
588
  if not content.strip():
589
  raise HTTPException(status_code=400, detail="No text could be extracted from the file")
590
 
591
+ try:
592
+ qa_pipeline = load_model("file-qa")
593
+ except Exception as e:
594
+ logger.warning(f"File-QA model failed: {str(e)}. Falling back to Gemini.")
595
+ question = text.strip()
596
+ if not question.endswith('?'):
597
+ question += '?'
598
+ response = get_gemini_response(f"Answer this question based on the following text: {content}\nQuestion: {question}")
599
+ return {
600
+ "response": response,
601
+ "type": "file_qa",
602
+ "additional_data": {
603
+ "question": text,
604
+ "file_name": file.filename
605
+ }
606
+ }
607
 
608
  question = text.strip()
609
  if not question.endswith('?'):
 
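The pattern in this hunk (try the local pipeline, fall back to Gemini on failure) condensed into a sketch using this file's own load_model and get_gemini_response helpers:

def answer_with_fallback(question: str, content: str) -> str:
    # Sketch only: mirrors the try/except flow above.
    try:
        qa = load_model("file-qa")  # local HF question-answering pipeline
        return qa(question=question, context=content)["answer"]
    except Exception:
        # Hosted fallback, same prompt shape as above.
        return get_gemini_response(
            f"Answer this question based on the following text: {content}\n"
            f"Question: {question}"
        )
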
833
 
834
  async def load_model_with_timeout(task):
835
  try:
836
+ await asyncio.wait_for(asyncio.to_thread(load_model, task), timeout=120.0)
837
  logger.info(f"Successfully loaded {task} model")
838
  except asyncio.TimeoutError:
839
  logger.warning(f"Timeout loading {task} model - will load on demand")
 
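The startup pattern in this hunk, isolated: run the blocking loader in a worker thread and bound it, so a slow download degrades to lazy loading instead of hanging startup. A sketch, assuming load_model from above:

import asyncio

async def load_model_with_timeout(task: str, seconds: float = 120.0) -> None:
    try:
        # asyncio.to_thread keeps the blocking load off the event loop
        await asyncio.wait_for(asyncio.to_thread(load_model, task), timeout=seconds)
    except asyncio.TimeoutError:
        # give up pre-loading; the model loads on first use instead
        pass
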
842
 
843
  try:
844
  model_name = MODELS["translation"]
845
+ translation_model = M2M100ForConditionalGeneration.from_pretrained(model_name, cache_dir=os.getenv('HF_HOME'))
846
+ translation_tokenizer = M2M100Tokenizer.from_pretrained(model_name, cache_dir=os.getenv('HF_HOME'))
847
  device = "cuda" if torch.cuda.is_available() else "cpu"
848
  translation_model.to(device)
849
  logger.info("Translation model pre-loaded successfully")