Hadiil committed on
Commit
e047fcf
·
verified ·
1 Parent(s): 4f55411

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +101 -19
app.py CHANGED
@@ -18,6 +18,7 @@ import torch
18
  import numpy as np
19
  from pydantic import BaseModel
20
  import asyncio
 
21
  from spellchecker import SpellChecker
22
  import nltk
23
  from nltk.tokenize import sent_tokenize
@@ -42,13 +43,13 @@ except Exception as e:
42
  logger.error(f"Error verifying NLTK punkt_tab: {str(e)}")
43
  raise Exception(f"Failed to verify NLTK punkt_tab: {str(e)}")
44
 
45
- # Create upload directory if it doesn't exist
46
  upload_dir = os.getenv('UPLOAD_DIR', '/tmp/uploads')
47
  os.makedirs(upload_dir, exist_ok=True)
48
 
49
  app = FastAPI(
50
  title="Cosmic AI Assistant",
51
- description="An advanced AI assistant with space-themed interface, translation, summarization, image analysis, and file question-answering features",
52
  version="2.0.0"
53
  )
54
 
@@ -58,11 +59,16 @@ app.mount("/static", StaticFiles(directory="static"), name="static")
58
  # Mount images directory
59
  app.mount("/images", StaticFiles(directory="images"), name="images")
60
 
 
 
 
 
61
  # Model configurations
62
  MODELS = {
63
  "summarization": "sshleifer/distilbart-cnn-12-6",
64
  "image-to-text": "Salesforce/blip-image-captioning-large",
65
  "visual-qa": "dandelin/vilt-b32-finetuned-vqa",
 
66
  "translation": "facebook/m2m100_418M",
67
  "file-qa": "distilbert-base-cased-distilled-squad"
68
  }
@@ -90,7 +96,7 @@ translation_tokenizer = None
90
  # Initialize spell checker
91
  spell = SpellChecker()
92
 
93
- # Cache for model loading
94
  @lru_cache(maxsize=8)
95
  def load_model(task: str, model_name: str = None):
96
  """Cached model loader with proper task names and error handling"""
@@ -100,6 +106,9 @@ def load_model(task: str, model_name: str = None):
100
 
101
  model_to_load = model_name or MODELS.get(task)
102
 
 
 
 
103
  if task == "visual-qa":
104
  processor = ViltProcessor.from_pretrained(model_to_load)
105
  model = ViltForQuestionAnswering.from_pretrained(model_to_load)
@@ -128,6 +137,21 @@ def load_model(task: str, model_name: str = None):
128
  logger.error(f"Model load failed: {str(e)}")
129
  raise HTTPException(status_code=500, detail=f"Model loading failed: {task} - {str(e)}")
130
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
131
  def translate_text(text: str, target_language: str):
132
  """Translate text to any target language using pre-loaded M2M100 model"""
133
  if not text:
@@ -217,10 +241,13 @@ def detect_intent(text: str = None, file: UploadFile = None) -> tuple[str, str]:
217
  return "summarize", target_language
218
 
219
  if not text:
220
- return "summarize", target_language
221
 
222
  text_lower = text.lower()
223
 
 
 
 
224
  # Text translation intent
225
  translate_patterns = [
226
  r'translate.*to\s+\[?([a-zA-Z]+)\]?:?\s*(.*)',
@@ -237,7 +264,7 @@ def detect_intent(text: str = None, file: UploadFile = None) -> tuple[str, str]:
237
  return "translate", target_language
238
  else:
239
  logger.warning(f"Invalid language detected: {potential_lang}")
240
- return "summarize", target_language
241
 
242
  vqa_patterns = [
243
  r'how (many|much)',
@@ -273,7 +300,7 @@ def detect_intent(text: str = None, file: UploadFile = None) -> tuple[str, str]:
273
  if len(text) > 100:
274
  return "summarize", target_language
275
 
276
- return "summarize", target_language
277
 
278
  def preprocess_text(text: str) -> str:
279
  """Correct spelling errors and improve text readability."""
@@ -288,13 +315,29 @@ class ProcessResponse(BaseModel):
288
  type: str
289
  additional_data: Optional[Dict[str, Any]] = None
290
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
291
  @app.post("/process", response_model=ProcessResponse)
292
  async def process_input(
293
  request: Request,
294
  text: str = Form(None),
295
  file: UploadFile = File(None)
296
  ):
297
- """Enhanced unified endpoint for summarization, translation, image analysis, and file QA"""
298
  start_time = time.time()
299
  client_ip = request.client.host
300
  logger.info(f"Request from {client_ip}: text={text[:50] + '...' if text and len(text) > 50 else text}, file={file.filename if file else None}")
@@ -303,7 +346,11 @@ async def process_input(
303
  logger.info(f"Detected intent: {intent}, target_language: {target_language}")
304
 
305
  try:
306
- if intent == "translate":
 
 
 
 
307
  content = await extract_text_from_file(file) if file else text
308
  if "all languages" in text.lower():
309
  translations = {}
@@ -401,6 +448,12 @@ async def process_input(
401
  final_summary = summary[0]['summary_text']
402
 
403
  final_summary = re.sub(r'\s+', ' ', final_summary).strip()
 
 
 
 
 
 
404
  if not final_summary.endswith(('.', '!', '?')):
405
  final_summary += '.'
406
 
@@ -409,7 +462,10 @@ async def process_input(
409
 
410
  except Exception as e:
411
  logger.error(f"Summarization error: {str(e)}")
412
- raise HTTPException(status_code=500, detail=f"Summarization error: {str(e)}")
 
 
 
413
 
414
  elif intent == "image-to-text":
415
  if not file or not file.content_type.startswith('image/'):
@@ -441,7 +497,10 @@ async def process_input(
441
  if not question.endswith('?'):
442
  question += '?'
443
 
444
- answer = vqa_pipeline(image=image, question=question)
 
 
 
445
 
446
  answer = answer.strip()
447
  if not answer or answer.lower() == question.lower():
@@ -452,10 +511,25 @@ async def process_input(
452
  if not answer.endswith(('.', '!', '?')):
453
  answer += '.'
454
 
455
- logger.info(f"Final VQA answer: {answer}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
456
 
457
  return {
458
- "response": answer,
459
  "type": "visual_qa",
460
  "additional_data": {
461
  "question": text,
@@ -481,11 +555,10 @@ async def process_input(
481
  return {"response": response, "type": "visualization_code"}
482
 
483
  elif intent == "text-generation":
484
- # Simulate text generation without Gemini
485
- response = f"Generated text based on '{text}': This is a simulated creative text."
486
  lines = response.split(". ")
487
- formatted_text = "\n".join(line.strip() + ("." if not line.endswith(".") else "") for line in lines if line)
488
- return {"response": formatted_text, "type": "generated_text"}
489
 
490
  elif intent == "file-qa":
491
  if not file or not file.filename.lower().endswith(('.pdf', '.docx', '.doc', '.txt', '.rtf')):
@@ -522,10 +595,17 @@ async def process_input(
522
  if not best_answer.endswith(('.', '!', '?')):
523
  best_answer += '.'
524
 
525
- logger.info(f"File QA answer: {best_answer}")
 
 
 
 
 
 
 
526
 
527
  return {
528
- "response": best_answer,
529
  "type": "file_qa",
530
  "additional_data": {
531
  "question": text,
@@ -534,7 +614,8 @@ async def process_input(
534
  }
535
 
536
  else:
537
- raise HTTPException(status_code=400, detail="Invalid intent detected")
 
538
 
539
  except Exception as e:
540
  logger.error(f"Processing error: {str(e)}", exc_info=True)
@@ -740,6 +821,7 @@ async def startup_event():
740
  load_model_with_timeout("summarization"),
741
  load_model_with_timeout("image-to-text"),
742
  load_model_with_timeout("visual-qa"),
 
743
  load_model_with_timeout("file-qa")
744
  )
745
 
 
18
  import numpy as np
19
  from pydantic import BaseModel
20
  import asyncio
21
+ import google.generativeai as genai
22
  from spellchecker import SpellChecker
23
  import nltk
24
  from nltk.tokenize import sent_tokenize
 
43
  logger.error(f"Error verifying NLTK punkt_tab: {str(e)}")
44
  raise Exception(f"Failed to verify NLTK punkt_tab: {str(e)}")
45
 
46
+ # Create upload directory if it doesn't exist
47
  upload_dir = os.getenv('UPLOAD_DIR', '/tmp/uploads')
48
  os.makedirs(upload_dir, exist_ok=True)
49
 
50
  app = FastAPI(
51
  title="Cosmic AI Assistant",
52
+ description="An advanced AI assistant with space-themed interface, translation, and file question-answering features",
53
  version="2.0.0"
54
  )
55
 
 
59
  # Mount images directory
60
  app.mount("/images", StaticFiles(directory="images"), name="images")
61
 
62
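+ # Gemini API configuration. NOTE(review): the key assigned below is hard-coded in source control; read it from an environment variable instead and rotate the committed key.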
63
+ API_KEY = "AIzaSyDtLhhmXpy8ubSGb84ImaxM_ywlL0l_8bo" # Replace with your actual API key
64
+ genai.configure(api_key=API_KEY)
65
+
66
  # Model configurations
67
  MODELS = {
68
  "summarization": "sshleifer/distilbart-cnn-12-6",
69
  "image-to-text": "Salesforce/blip-image-captioning-large",
70
  "visual-qa": "dandelin/vilt-b32-finetuned-vqa",
71
+ "chatbot": "gemini-1.5-pro",
72
  "translation": "facebook/m2m100_418M",
73
  "file-qa": "distilbert-base-cased-distilled-squad"
74
  }
 
96
  # Initialize spell checker
97
  spell = SpellChecker()
98
 
99
+ # Cache for model loading (excluding translation)
100
  @lru_cache(maxsize=8)
101
  def load_model(task: str, model_name: str = None):
102
  """Cached model loader with proper task names and error handling"""
 
106
 
107
  model_to_load = model_name or MODELS.get(task)
108
 
109
+ if task == "chatbot":
110
+ return genai.GenerativeModel(model_to_load)
111
+
112
  if task == "visual-qa":
113
  processor = ViltProcessor.from_pretrained(model_to_load)
114
  model = ViltForQuestionAnswering.from_pretrained(model_to_load)
 
137
  logger.error(f"Model load failed: {str(e)}")
138
  raise HTTPException(status_code=500, detail=f"Model loading failed: {task} - {str(e)}")
139
 
140
+ def get_gemini_response(user_input: str, is_generation: bool = False):
141
+ """Function to generate response with Gemini for both chat and text generation"""
142
+ if not user_input:
143
+ return "Please provide some input."
144
+ try:
145
+ chatbot = load_model("chatbot")
146
+ if is_generation:
147
+ prompt = f"Generate creative text based on this prompt: {user_input}"
148
+ else:
149
+ prompt = user_input
150
+ response = chatbot.generate_content(prompt)
151
+ return response.text.strip()
152
+ except Exception as e:
153
+ return f"Error: {str(e)}"
154
+
155
  def translate_text(text: str, target_language: str):
156
  """Translate text to any target language using pre-loaded M2M100 model"""
157
  if not text:
 
241
  return "summarize", target_language
242
 
243
  if not text:
244
+ return "chatbot", target_language
245
 
246
  text_lower = text.lower()
247
 
248
+ if any(keyword in text_lower for keyword in ['chat', 'talk', 'converse', 'ask gemini']):
249
+ return "chatbot", target_language
250
+
251
  # Text translation intent
252
  translate_patterns = [
253
  r'translate.*to\s+\[?([a-zA-Z]+)\]?:?\s*(.*)',
 
264
  return "translate", target_language
265
  else:
266
  logger.warning(f"Invalid language detected: {potential_lang}")
267
+ return "chatbot", target_language
268
 
269
  vqa_patterns = [
270
  r'how (many|much)',
 
300
  if len(text) > 100:
301
  return "summarize", target_language
302
 
303
+ return "chatbot", target_language
304
 
305
  def preprocess_text(text: str) -> str:
306
  """Correct spelling errors and improve text readability."""
 
315
  type: str
316
  additional_data: Optional[Dict[str, Any]] = None
317
 
318
+ @app.get("/chatbot")
319
+ async def chatbot_interface():
320
+ """Redirect to the static index.html file for the chatbot interface"""
321
+ return RedirectResponse(url="/static/index.html")
322
+
323
+ @app.post("/chat")
324
+ async def chat_endpoint(data: dict):
325
+ message = data.get("message", "")
326
+ if not message:
327
+ raise HTTPException(status_code=400, detail="No message provided")
328
+ try:
329
+ response = get_gemini_response(message)
330
+ return {"response": response}
331
+ except Exception as e:
332
+ raise HTTPException(status_code=500, detail=f"Chat error: {str(e)}")
333
+
334
  @app.post("/process", response_model=ProcessResponse)
335
  async def process_input(
336
  request: Request,
337
  text: str = Form(None),
338
  file: UploadFile = File(None)
339
  ):
340
+ """Enhanced unified endpoint with dynamic translation and file translation"""
341
  start_time = time.time()
342
  client_ip = request.client.host
343
  logger.info(f"Request from {client_ip}: text={text[:50] + '...' if text and len(text) > 50 else text}, file={file.filename if file else None}")
 
346
  logger.info(f"Detected intent: {intent}, target_language: {target_language}")
347
 
348
  try:
349
+ if intent == "chatbot":
350
+ response = get_gemini_response(text)
351
+ return {"response": response, "type": "chat"}
352
+
353
+ elif intent == "translate":
354
  content = await extract_text_from_file(file) if file else text
355
  if "all languages" in text.lower():
356
  translations = {}
 
448
  final_summary = summary[0]['summary_text']
449
 
450
  final_summary = re.sub(r'\s+', ' ', final_summary).strip()
451
+ if not final_summary or final_summary.lower().startswith(content.lower()[:30]):
452
+ logger.warning("Summarizer produced inadequate output, falling back to Gemini")
453
+ final_summary = get_gemini_response(
454
+ f"Summarize this text in a concise and meaningful way: {content}"
455
+ )
456
+
457
  if not final_summary.endswith(('.', '!', '?')):
458
  final_summary += '.'
459
 
 
462
 
463
  except Exception as e:
464
  logger.error(f"Summarization error: {str(e)}")
465
+ final_summary = get_gemini_response(
466
+ f"Summarize this text in a concise and meaningful way: {content}"
467
+ )
468
+ return {"response": final_summary, "type": "summary", "message": "Text was preprocessed to correct spelling errors"}
469
 
470
  elif intent == "image-to-text":
471
  if not file or not file.content_type.startswith('image/'):
 
497
  if not question.endswith('?'):
498
  question += '?'
499
 
500
+ answer = vqa_pipeline(
501
+ image=image,
502
+ question=question
503
+ )
504
 
505
  answer = answer.strip()
506
  if not answer or answer.lower() == question.lower():
 
511
  if not answer.endswith(('.', '!', '?')):
512
  answer += '.'
513
 
514
+ # Check if the question asks for a specific, factual detail like color
515
+ factual_questions = ['color', 'size', 'number', 'how many', 'what is the']
516
+ is_factual = any(keyword in question.lower() for keyword in factual_questions)
517
+
518
+ if is_factual:
519
+ # Return the raw VQA answer for factual questions
520
+ final_answer = answer
521
+ else:
522
+ # Apply cosmic tone for non-factual, open-ended questions
523
+ chatbot = load_model("chatbot")
524
+ if "fly" in question.lower():
525
+ final_answer = chatbot.generate_content(f"Make this fun and spacey: {answer}").text.strip()
526
+ else:
527
+ final_answer = chatbot.generate_content(f"Make this cosmic and poetic: {answer}").text.strip()
528
+
529
+ logger.info(f"Final VQA answer: {final_answer}")
530
 
531
  return {
532
+ "response": final_answer,
533
  "type": "visual_qa",
534
  "additional_data": {
535
  "question": text,
 
555
  return {"response": response, "type": "visualization_code"}
556
 
557
  elif intent == "text-generation":
558
+ response = get_gemini_response(text, is_generation=True)
 
559
  lines = response.split(". ")
560
+ formatted_poem = "\n".join(line.strip() + ("." if not line.endswith(".") else "") for line in lines if line)
561
+ return {"response": formatted_poem, "type": "generated_text"}
562
 
563
  elif intent == "file-qa":
564
  if not file or not file.filename.lower().endswith(('.pdf', '.docx', '.doc', '.txt', '.rtf')):
 
595
  if not best_answer.endswith(('.', '!', '?')):
596
  best_answer += '.'
597
 
598
+ try:
599
+ chatbot = load_model("chatbot")
600
+ final_answer = chatbot.generate_content(f"Make this cosmic and poetic: {best_answer}").text.strip()
601
+ except Exception as e:
602
+ logger.warning(f"Failed to add cosmic tone: {str(e)}. Using raw answer.")
603
+ final_answer = best_answer
604
+
605
+ logger.info(f"File QA answer: {final_answer}")
606
 
607
  return {
608
+ "response": final_answer,
609
  "type": "file_qa",
610
  "additional_data": {
611
  "question": text,
 
614
  }
615
 
616
  else:
617
+ response = get_gemini_response(text or "Hello! How can I assist you?")
618
+ return {"response": response, "type": "chat"}
619
 
620
  except Exception as e:
621
  logger.error(f"Processing error: {str(e)}", exc_info=True)
 
821
  load_model_with_timeout("summarization"),
822
  load_model_with_timeout("image-to-text"),
823
  load_model_with_timeout("visual-qa"),
824
+ load_model_with_timeout("chatbot"),
825
  load_model_with_timeout("file-qa")
826
  )
827