Hadiil committed
Commit 252f82c · verified · 1 Parent(s): ebb75cd

Update app.py

Files changed (1):
  1. app.py +616 -574

app.py CHANGED
@@ -2,7 +2,7 @@ from fastapi import FastAPI, UploadFile, File, Form, HTTPException, Request
2
  from fastapi.staticfiles import StaticFiles
3
  from fastapi.responses import RedirectResponse, JSONResponse, HTMLResponse
4
  from transformers import pipeline, ViltProcessor, ViltForQuestionAnswering, M2M100ForConditionalGeneration, M2M100Tokenizer
5
- from typing import Optional, Dict, Any, List, Union
6
  import logging
7
  import time
8
  import os
@@ -19,12 +19,6 @@ import numpy as np
19
  from pydantic import BaseModel
20
  import asyncio
21
  import google.generativeai as genai
22
- import magic # For MIME type detection
23
- import datetime
24
- import matplotlib
25
- matplotlib.use('Agg') # Set non-interactive backend
26
- import matplotlib.pyplot as plt
27
- import seaborn as sns
28
 
29
  # Configure logging
30
  logging.basicConfig(
@@ -33,42 +27,28 @@ logging.basicConfig(
33
  )
34
  logger = logging.getLogger("cosmic_ai")
35
 
36
- # Initialize FastAPI app
37
- app = FastAPI(title="Cosmic AI Assistant", version="2.1.0")
38
- app.mount("/static", StaticFiles(directory="static"), name="static")
39
- app.mount("/images", StaticFiles(directory="images"), name="images")
40
 
41
- # Ensure directories exist
42
- UPLOAD_DIR = os.getenv("UPLOAD_DIR", "/app/uploads")
43
- CACHE_DIR = os.getenv("TRANSFORMERS_CACHE", "/app/cache")
44
- IMAGES_DIR = "/app/images"
 
45
 
46
- os.makedirs(CACHE_DIR, exist_ok=True)
47
- os.makedirs(UPLOAD_DIR, exist_ok=True)
48
- os.makedirs(IMAGES_DIR, exist_ok=True)
49
 
50
- # Configure Gemini
51
- API_KEY = os.getenv("GOOGLE_API_KEY", "AIzaSyCwmgD8KxzWiuivtySNtcZF_rfTvx9s9sY")
52
- genai.configure(api_key=API_KEY)
53
 
54
- # Language mapping for translation
55
- LANGUAGE_MAPPING = {
56
- "english": "en",
57
- "french": "fr",
58
- "spanish": "es",
59
- "german": "de",
60
- "italian": "it",
61
- "portuguese": "pt",
62
- "russian": "ru",
63
- "chinese": "zh",
64
- "japanese": "ja",
65
- "korean": "ko",
66
- "arabic": "ar",
67
- "hindi": "hi"
68
- }
69
 
70
- # Inverse language mapping for reference
71
- LANGUAGE_CODE_TO_NAME = {v: k.title() for k, v in LANGUAGE_MAPPING.items()}
 
72
 
73
  # Model configurations
74
  MODELS = {
@@ -77,148 +57,42 @@ MODELS = {
77
  "visual-qa": "dandelin/vilt-b32-finetuned-vqa",
78
  "chatbot": "gemini-1.5-pro",
79
  "translation": "facebook/m2m100_418M",
80
- "question-answering": "distilbert-base-cased-distilled-squad",
81
- "generate": "gemini-1.5-pro"
82
  }
83
 
84
- # Response model
85
- class ProcessResponse(BaseModel):
86
- response: str
87
- type: str
88
- additional_data: Optional[Dict[str, Any]] = None
 
89
 
90
- # Intent detection with improved pattern matching and language detection
91
- def detect_intent(text: Optional[str], file: Optional[UploadFile]) -> tuple:
92
- """
93
- Detect user intent and target language from input
94
- Returns a tuple of (intent, target_language)
95
- """
96
- if not text and not file:
97
- return "unknown", "en"
98
-
99
- text_lower = text.lower() if text else ""
100
-
101
- # File-based intent detection
102
- if file:
103
- mime_type = file.content_type.lower() if hasattr(file, 'content_type') else ""
104
- filename_lower = file.filename.lower() if hasattr(file, 'filename') else ""
105
-
106
- # Image processing
107
- if mime_type.startswith('image/'):
108
- # Check if there's a specific question about the image
109
- if text and any(phrase in text_lower for phrase in [
110
- "what is", "how many", "does this", "is there", "can you see",
111
- "what color", "identify", "explain"
112
- ]):
113
- return "visual-qa", "en"
114
- else:
115
- # Just caption the image if no specific question
116
- return "image-to-text", "en"
117
-
118
- # Data visualization for spreadsheets
119
- elif any(mime_type.startswith(mt) for mt in ['text/csv', 'application/vnd.ms-excel', 'application/vnd.openxmlformats-officedocument.spreadsheetml']) or \
120
- any(filename_lower.endswith(ext) for ext in ['.csv', '.xls', '.xlsx']):
121
- return "visualize", "en"
122
-
123
- # Document processing
124
- elif any(mime_type.startswith(mt) for mt in [
125
- 'application/pdf',
126
- 'application/vnd.openxmlformats-officedocument.wordprocessingml.document',
127
- 'application/msword',
128
- 'text/plain',
129
- 'application/rtf',
130
- 'text/rtf'
131
- ]) or any(filename_lower.endswith(ext) for ext in ['.pdf', '.docx', '.doc', '.txt', '.rtf']):
132
- # If there's a specific question about the document
133
- if text and ("?" in text or any(word in text_lower for word in ["what", "who", "how", "when", "where", "why", "which", "find", "search"])):
134
- return "file-qa", "en"
135
- # If translation is requested
136
- elif text and any(keyword in text_lower for keyword in ["translate", "translation", "convert to"]):
137
- # Extract target language
138
- target_lang = "en" # Default to English
139
-
140
- # Check for language specification patterns
141
- lang_pattern = r"to\s+(\w+)"
142
- lang_match = re.search(lang_pattern, text_lower)
143
-
144
- if lang_match:
145
- lang_name = lang_match.group(1).lower()
146
- if lang_name in LANGUAGE_MAPPING:
147
- target_lang = LANGUAGE_MAPPING[lang_name]
148
- # Check if it's a direct language code
149
- elif lang_name in LANGUAGE_MAPPING.values():
150
- target_lang = lang_name
151
-
152
- return "translate", target_lang
153
- # Default to summarization for documents without specific instructions
154
- else:
155
- return "summarize", "en"
156
-
157
- # Text-based intent detection (no file)
158
-
159
- # Translation intent
160
- if any(keyword in text_lower for keyword in ["translate", "translation", "say in", "how to say"]):
161
- # Try to extract target language
162
- target_lang = "en" # Default
163
-
164
- # Check for language specification patterns
165
- lang_patterns = [
166
- r"to\s+(\w+)",
167
- r"in\s+(\w+)",
168
- r"into\s+(\w+)"
169
- ]
170
-
171
- for pattern in lang_patterns:
172
- lang_match = re.search(pattern, text_lower)
173
- if lang_match:
174
- lang_name = lang_match.group(1).lower()
175
- if lang_name in LANGUAGE_MAPPING:
176
- target_lang = LANGUAGE_MAPPING[lang_name]
177
- break
178
- # Check if it's a direct language code
179
- elif lang_name in LANGUAGE_MAPPING.values():
180
- target_lang = lang_name
181
- break
182
-
183
- # Check for "all languages" request
184
- if "all languages" in text_lower or "all supported languages" in text_lower:
185
- target_lang = "all"
186
-
187
- return "translate", target_lang
188
-
189
- # Summarization intent
190
- elif any(keyword in text_lower for keyword in [
191
- "summarize", "summary", "overview", "brief", "condense", "shorten", "tldr"
192
- ]) or (len(text) > 500 and not any(keyword in text_lower for keyword in ["write", "generate", "create"])):
193
- return "summarize", "en"
194
-
195
- # Text generation intent (creative writing)
196
- elif any(keyword in text_lower for keyword in [
197
- "write", "generate", "create", "compose", "draft", "story", "poem", "essay",
198
- "script", "letter", "email", "article", "blog"
199
- ]):
200
- return "generate", "en"
201
-
202
- # Default to chat
203
- return "chat", "en"
204
 
205
- # Model loading with caching
206
  @lru_cache(maxsize=8)
207
  def load_model(task: str, model_name: str = None):
208
  """Cached model loader with proper task names and error handling"""
209
  try:
210
- logger.info(f"Loading model for task: {task}, model: {model_name or MODELS.get(task, 'unknown')}")
211
  start_time = time.time()
212
- model_to_load = model_name or MODELS.get(task)
213
 
214
- if not model_to_load:
215
- raise ValueError(f"No model configured for task: {task}")
216
 
217
- # Gemini models
218
- if task == "chatbot" or task == "generate":
219
  return genai.GenerativeModel(model_to_load)
220
-
221
- # Visual Question Answering
222
  if task == "visual-qa":
223
  processor = ViltProcessor.from_pretrained(model_to_load)
224
  model = ViltForQuestionAnswering.from_pretrained(model_to_load)
@@ -228,453 +102,621 @@ def load_model(task: str, model_name: str = None):
228
  def vqa_function(image, question, **generate_kwargs):
229
  if image.mode != "RGB":
230
  image = image.convert("RGB")
231
-
232
  inputs = processor(image, question, return_tensors="pt").to(device)
233
  logger.info(f"VQA inputs - question: {question}, image size: {image.size}")
234
-
235
  with torch.no_grad():
236
  outputs = model(**inputs)
237
- logits = outputs.logits
238
- idx = logits.argmax(-1).item()
239
- answer = model.config.id2label[idx]
240
-
241
  logger.info(f"VQA raw output: {answer}")
242
  return answer
243
 
244
  return vqa_function
245
 
246
- # For most transformer models, use the standard pipeline
247
- try:
248
- if task == "translation":
249
- # For translation, return both tokenizer and model
250
- tokenizer = M2M100Tokenizer.from_pretrained(model_to_load)
251
- model = M2M100ForConditionalGeneration.from_pretrained(model_to_load)
252
- return tokenizer, model
253
- else:
254
- # Map task names to transformers pipeline tasks
255
- task_mapping = {
256
- "summarization": "summarization",
257
- "question-answering": "question-answering",
258
- "image-to-text": "image-to-text"
259
- }
260
-
261
- pipeline_task = task_mapping.get(task, task)
262
- return pipeline(pipeline_task, model=model_to_load)
263
- except Exception as e:
264
- logger.error(f"Pipeline creation failed for {task}: {str(e)}")
265
- raise
266
-
267
- logger.info(f"Model loaded in {time.time() - start_time:.2f}s")
268
 
269
  except Exception as e:
270
  logger.error(f"Model load failed: {str(e)}")
271
  raise HTTPException(status_code=500, detail=f"Model loading failed: {task} - {str(e)}")
272
 
273
- # File text extraction with improved error handling and multiple format support
274
- async def extract_text_from_file(file: UploadFile) -> str:
275
- """Extract text from uploaded file (PDF, DOCX, TXT, RTF)"""
276
- filename = file.filename.lower() if hasattr(file, 'filename') else "unknown"
277
- content = await file.read()
278
-
279
- # Use Python-magic to detect MIME type
280
  try:
281
- mime = magic.Magic(mime=True)
282
- mime_type = mime.from_buffer(content)
 
283
  except Exception as e:
284
- logger.warning(f"MIME detection failed: {str(e)}, using content_type")
285
- mime_type = file.content_type if hasattr(file, 'content_type') else "application/octet-stream"
286
-
287
- logger.info(f"Processing file: {filename}, size: {len(content)} bytes, MIME type: {mime_type}")
 
 
288
 
289
  try:
290
- # PDF processing with fallback mechanisms
291
- if mime_type == 'application/pdf' or filename.endswith('.pdf'):
292
- try:
293
- doc = fitz.open(stream=content, filetype="pdf")
294
- text = ""
295
- for page in doc:
296
- text += page.get_text()
297
- doc.close()
298
-
299
- if not text.strip():
300
- logger.warning(f"No text extracted from PDF: {filename}, attempting OCR fallback")
301
- raise ValueError("No text could be extracted from the PDF")
302
-
303
- return text
304
- except Exception as e:
305
- logger.error(f"PyMuPDF failed for {filename}: {str(e)}")
306
- # Could implement PDF OCR fallback here if needed
307
- raise HTTPException(status_code=400, detail=f"Could not extract text from PDF: {str(e)}")
308
 
309
- # Word document processing
310
- elif any(mime_type.startswith(mt) for mt in [
311
- 'application/vnd.openxmlformats-officedocument.wordprocessingml.document',
312
- 'application/msword'
313
- ]) or filename.endswith(('.docx', '.doc')):
314
- doc = Document(io.BytesIO(content))
315
- text = "\n".join([para.text for para in doc.paragraphs])
316
-
317
- if not text.strip():
318
- logger.warning(f"No text extracted from DOCX: {filename}")
319
- raise ValueError("No text could be extracted from the document")
320
-
321
- return text
322
 
323
- # Plain text and RTF processing
324
- elif mime_type in ['text/plain', 'text/rtf', 'application/rtf'] or filename.endswith(('.txt', '.rtf')):
325
- try:
326
- # Try UTF-8 first
327
- text = content.decode('utf-8', errors='ignore')
328
-
329
- # For RTF, do basic cleanup of markup
330
- if mime_type in ['text/rtf', 'application/rtf'] or filename.endswith('.rtf'):
331
- # Very basic RTF cleaning (would need a proper RTF parser for better results)
332
- text = re.sub(r'\\[a-zA-Z]+', ' ', text) # Remove RTF commands
333
- text = re.sub(r'[{}]', '', text) # Remove braces
334
- text = re.sub(r'\\[0-9]+', '', text) # Remove numeric commands
335
-
336
- if not text.strip():
337
- logger.warning(f"No text extracted from text file: {filename}")
338
- raise ValueError("No text could be extracted from the text file")
339
-
340
- return text
341
- except UnicodeDecodeError:
342
- # Fallback to latin-1 if UTF-8 fails
343
- text = content.decode('latin-1', errors='ignore')
344
- return text
345
 
346
  else:
347
- logger.error(f"Unsupported file type: {mime_type} for {filename}")
348
- raise HTTPException(
349
- status_code=400,
350
- detail=f"Unsupported file type: {mime_type}. Please upload a PDF, DOCX, TXT, or RTF file"
351
- )
352
-
353
- except HTTPException:
354
- # Re-raise HTTP exceptions
355
- raise
356
- except Exception as e:
357
- logger.error(f"Text extraction failed for {filename}: {str(e)}")
358
- raise HTTPException(status_code=400, detail=f"Text extraction failed: {str(e)}")
359
-
360
- # Data visualization with enhanced options and error handling
361
- def generate_visualization_code(df: pd.DataFrame, visualization_type: str = None) -> tuple:
362
- """
363
- Generate visualization based on data analysis and save to static file
364
- Returns tuple of (image_path, description)
365
- """
366
- try:
367
- # Basic data analysis
368
- num_rows, num_cols = df.shape
369
- numeric_cols = df.select_dtypes(include=[np.number]).columns.tolist()
370
- categorical_cols = df.select_dtypes(include=['object']).columns.tolist()
371
 
372
- # Check for datetime columns
373
- date_cols = []
374
- for col in df.columns:
375
- if pd.api.types.is_datetime64_any_dtype(df[col]):
376
- date_cols.append(col)
377
- elif df[col].dtype == 'object':
378
- # Try to convert to datetime
379
- try:
380
- pd.to_datetime(df[col], errors='raise')
381
- date_cols.append(col)
382
- except (ValueError, TypeError):
383
- pass
384
 
385
- # Generate stats summary
386
- stats_summary = df.describe().to_string()
 
387
 
388
- # File path for saving
389
- timestamp = int(time.time())
390
- img_filename = f"viz_{timestamp}.png"
391
- img_path = os.path.join(IMAGES_DIR, img_filename)
392
 
393
- # Apply visualization based on type
394
- if visualization_type and visualization_type.lower() in ['scatter', 'correlation']:
395
- if len(numeric_cols) < 2:
396
- raise ValueError("Need at least 2 numeric columns for a scatter plot")
397
-
398
- plt.figure(figsize=(10, 6))
399
- x_col, y_col = numeric_cols[0], numeric_cols[1]
400
-
401
- # Create enhanced scatter plot
402
- sns.scatterplot(data=df, x=x_col, y=y_col, hue=categorical_cols[0] if categorical_cols else None)
403
- plt.title(f'Correlation between {x_col} and {y_col}')
404
- plt.xlabel(x_col)
405
- plt.ylabel(y_col)
406
- plt.grid(True, alpha=0.3)
407
-
408
- # Add regression line
409
- sns.regplot(x=x_col, y=y_col, data=df, scatter=False, line_kws={"color": "red"})
410
-
411
- # Add correlation coefficient as text
412
- corr = df[x_col].corr(df[y_col])
413
- plt.annotate(f"Correlation: {corr:.2f}",
414
- xy=(0.05, 0.95),
415
- xycoords='axes fraction',
416
- fontsize=12,
417
- bbox=dict(boxstyle="round,pad=0.3", fc="white", ec="gray", alpha=0.8))
418
-
419
- plt.tight_layout()
420
- plt.savefig(img_path)
421
- plt.close()
422
-
423
- description = f"Scatter plot showing correlation between {x_col} and {y_col}. Correlation coefficient: {corr:.4f}"
424
-
425
- elif visualization_type and visualization_type.lower() in ['bar', 'barplot', 'barchart']:
426
- if len(categorical_cols) < 1 or len(numeric_cols) < 1:
427
- raise ValueError("Need at least 1 categorical and 1 numeric column for a bar chart")
428
-
429
- plt.figure(figsize=(12, 7))
430
- cat_col = categorical_cols[0]
431
- num_col = numeric_cols[0]
432
-
433
- # Get top categories if too many
434
- if df[cat_col].nunique() > 10:
435
- top_cats = df.groupby(cat_col)[num_col].sum().nlargest(10).index
436
- df_plot = df[df[cat_col].isin(top_cats)]
437
- title_suffix = " (top 10 categories)"
438
- else:
439
- df_plot = df
440
- title_suffix = ""
441
-
442
- # Create bar chart
443
- ax = sns.barplot(x=cat_col, y=num_col, data=df_plot, palette='viridis')
444
-
445
- # Add value labels on top of bars
446
- for p in ax.patches:
447
- ax.annotate(f'{p.get_height():.1f}',
448
- (p.get_x() + p.get_width() / 2., p.get_height()),
449
- ha='center', va='bottom',
450
- fontsize=9, color='black',
451
- xytext=(0, 5), textcoords='offset points')
452
-
453
- plt.title(f'Comparison of {num_col} by {cat_col}{title_suffix}', fontsize=14)
454
- plt.xlabel(cat_col, fontsize=12)
455
- plt.ylabel(num_col, fontsize=12)
456
- plt.xticks(rotation=45, ha='right')
457
- plt.grid(axis='y', alpha=0.3)
458
- plt.tight_layout()
459
- plt.savefig(img_path)
460
- plt.close()
461
-
462
- description = f"Bar chart comparing {num_col} across different {cat_col} categories"
463
-
464
- elif visualization_type and visualization_type.lower() in ['histogram', 'distribution']:
465
- if len(numeric_cols) < 1:
466
- raise ValueError("Need at least 1 numeric column for a histogram")
467
-
468
- plt.figure(figsize=(10, 6))
469
- num_col = numeric_cols[0]
470
-
471
- # Create histogram with KDE
472
- sns.histplot(df[num_col], kde=True, bins=20, color='purple')
473
-
474
- # Add mean and median lines
475
- mean_val = df[num_col].mean()
476
- median_val = df[num_col].median()
477
-
478
- plt.axvline(mean_val, color='red', linestyle='--', linewidth=1.5, label=f'Mean: {mean_val:.2f}')
479
- plt.axvline(median_val, color='green', linestyle='-.', linewidth=1.5, label=f'Median: {median_val:.2f}')
480
-
481
- plt.title(f'Distribution of {num_col}', fontsize=14)
482
- plt.xlabel(num_col, fontsize=12)
483
- plt.ylabel('Frequency', fontsize=12)
484
- plt.legend()
485
- plt.grid(True, alpha=0.3)
486
- plt.tight_layout()
487
- plt.savefig(img_path)
488
- plt.close()
489
-
490
- # Get descriptive stats for the column
491
- desc_stats = df[num_col].describe()
492
-
493
- description = (f"Histogram showing distribution of {num_col}\n"
494
- f"Mean: {desc_stats['mean']:.2f}, Median: {median_val:.2f}\n"
495
- f"Min: {desc_stats['min']:.2f}, Max: {desc_stats['max']:.2f}\n"
496
- f"Std Dev: {desc_stats['std']:.2f}")
497
-
498
- else: # Default dashboard with multiple plots
499
- # Create dashboard with multiple plots
500
- fig, axes = plt.subplots(2, 2, figsize=(16, 12))
501
- fig.suptitle('Data Dashboard', fontsize=16)
502
-
503
- # Plot 1: Correlation matrix (top-left)
504
- if len(numeric_cols) > 1:
505
- corr_matrix = df[numeric_cols].corr()
506
- sns.heatmap(corr_matrix, annot=True, cmap='coolwarm', fmt='.2f', ax=axes[0, 0])
507
- axes[0, 0].set_title('Correlation Matrix')
508
  else:
509
- axes[0, 0].text(0.5, 0.5, 'Not enough numeric columns for correlation matrix',
510
- ha='center', va='center', fontsize=12)
511
- axes[0, 0].axis('off')
512
-
513
- # Plot 2: Distribution (top-right)
514
- if numeric_cols:
515
- num_col = numeric_cols[0]
516
- sns.histplot(df[num_col], kde=True, ax=axes[0, 1], color='purple')
517
- axes[0, 1].set_title(f'Distribution of {num_col}')
518
- axes[0, 1].set_xlabel(num_col)
519
- axes[0, 1].set_ylabel('Frequency')
 
520
  else:
521
- axes[0, 1].text(0.5, 0.5, 'No numeric columns for histogram',
522
- ha='center', va='center', fontsize=12)
523
- axes[0, 1].axis('off')
524
-
525
- # Plot 3: Bar chart (bottom-left)
526
- if categorical_cols and numeric_cols:
527
- cat_col = categorical_cols[0]
528
- num_col = numeric_cols[0]
 
529
 
530
- # Limit to top categories if too many
531
- if df[cat_col].nunique() > 8:
532
- top_cats = df.groupby(cat_col)[num_col].sum().nlargest(8).index
533
- df_plot = df[df[cat_col].isin(top_cats)]
534
- title_suffix = " (top 8)"
535
- else:
536
- df_plot = df
537
- title_suffix = ""
 
538
 
539
- sns.barplot(x=cat_col, y=num_col, data=df_plot, ax=axes[1, 0], palette='viridis')
540
- axes[1, 0].set_title(f'{num_col} by {cat_col}{title_suffix}')
541
- axes[1, 0].set_xticklabels(axes[1, 0].get_xticklabels(), rotation=45, ha='right')
542
  else:
543
- axes[1, 0].text(0.5, 0.5, 'Need both categorical and numeric columns for bar chart',
544
- ha='center', va='center', fontsize=12)
545
- axes[1, 0].axis('off')
546
-
547
- # Plot 4: Box plot (bottom-right)
548
- if categorical_cols and numeric_cols:
549
- cat_col = categorical_cols[0]
550
- num_col = numeric_cols[0]
551
-
552
- # Limit to top categories if too many
553
- if df[cat_col].nunique() > 8:
554
- top_cats = df.groupby(cat_col)[num_col].sum().nlargest(8).index
555
- df_plot = df[df[cat_col].isin(top_cats)]
556
- title_suffix = " (top 8)"
 
557
  else:
558
- df_plot = df
559
- title_suffix = ""
560
-
561
- sns.boxplot(x=cat_col, y=num_col, data=df_plot, ax=axes[1, 1], palette='Set3')
562
- axes[1, 1].set_title(f'Distribution of {num_col} by {cat_col}{title_suffix}')
563
- axes[1, 1].set_xticklabels(axes[1, 1].get_xticklabels(), rotation=45, ha='right')
564
  else:
565
- axes[1, 1].text(0.5, 0.5, 'Need both categorical and numeric columns for box plot',
566
- ha='center', va='center', fontsize=12)
567
- axes[1, 1].axis('off')
568
 
569
- plt.tight_layout(rect=[0, 0, 1, 0.97]) # Adjust layout to make room for suptitle
570
- plt.savefig(img_path)
571
- plt.close()
572
 
573
- # Generate description with data summary
574
- description = (f"Data Dashboard Summary:\n"
575
- f"Dataset dimensions: {num_rows} rows × {num_cols} columns\n"
576
- f"Numeric columns: {', '.join(numeric_cols[:5])}{'...' if len(numeric_cols) > 5 else ''}\n"
577
- f"Categorical columns: {', '.join(categorical_cols[:5])}{'...' if len(categorical_cols) > 5 else ''}")
578
 
579
- return f"/images/{img_filename}", description
580
-
581
- except Exception as e:
582
- logger.error(f"Visualization generation failed: {str(e)}")
583
- raise ValueError(f"Could not generate visualization: {str(e)}")
584
-
585
- # Enhanced translation with multiple language support
586
- async def translate_text(text: str, target_lang: str = "en") -> Union[str, Dict[str, str]]:
587
- """
588
- Translate text to target language or multiple languages
589
- If target_lang is "all", returns dict of language:translation
590
- """
591
- try:
592
- tokenizer, model = load_model("translation")
593
-
594
- # If requesting translation to all supported languages
595
- if target_lang == "all":
596
- results = {}
597
- for lang_code in LANGUAGE_MAPPING.values():
598
- try:
599
- tokenizer.src_lang = "en" # Assuming source is English
600
- tokenizer.tgt_lang = lang_code
601
- encoded = tokenizer(text, return_tensors="pt")
602
- generated_tokens = model.generate(
603
- **encoded,
604
- forced_bos_token_id=tokenizer.get_lang_id(lang_code),
605
- max_length=512
606
- )
607
- translation = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]
608
- results[LANGUAGE_CODE_TO_NAME.get(lang_code, lang_code)] = translation
609
- except Exception as lang_error:
610
- logger.error(f"Translation to {lang_code} failed: {str(lang_error)}")
611
- results[LANGUAGE_CODE_TO_NAME.get(lang_code, lang_code)] = f"Translation failed: {str(lang_error)}"
612
 
613
- return results
614
  else:
615
- # Single language translation
616
- tokenizer.src_lang = "en" # Assuming source is English
617
-
618
- # Check if target_lang is valid
619
- if target_lang not in LANGUAGE_MAPPING.values():
620
- # Try to find it in the values
621
- for lang_name, lang_code in LANGUAGE_MAPPING.items():
622
- if target_lang.lower() == lang_name:
623
- target_lang = lang_code
624
- break
625
- else:
626
- logger.warning(f"Unsupported target language: {target_lang}, defaulting to English")
627
- return text # Return original text if language not supported
628
-
629
- tokenizer.tgt_lang = target_lang
630
- encoded = tokenizer(text, return_tensors="pt")
631
- generated_tokens = model.generate(
632
- **encoded,
633
- forced_bos_token_id=tokenizer.get_lang_id(target_lang),
634
- max_length=512
635
- )
636
- translation = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]
637
- return translation
638
 
639
  except Exception as e:
640
- logger.error(f"Translation failed: {str(e)}")
641
- raise HTTPException(status_code=500, detail=f"Translation failed: {str(e)}")
 
642
 
643
- # Creative text generation with enhanced Gemini capabilities
644
- async def generate_creative_text(prompt: str) -> str:
645
- """Generate creative text content using Gemini model"""
646
  try:
647
- chatbot = load_model("generate")
648
-
649
- # Identify content type from prompt
650
- content_type = "story" # Default
651
- if "poem" in prompt.lower() or "poetry" in prompt.lower():
652
- content_type = "poem"
653
- elif "essay" in prompt.lower():
654
- content_type = "essay"
655
- elif "article" in prompt.lower() or "blog" in prompt.lower():
656
- content_type = "article"
657
- elif "letter" in prompt.lower() or "email" in prompt.lower():
658
- content_type = "letter"
659
-
660
- # Create an enhanced prompt with formatting instructions
661
- enhanced_prompt = f"Generate a creative {content_type} based on this prompt: '{prompt}'. Please follow these guidelines: Create engaging, original content with proper structure. Use vivid language and appropriate tone. Format the output with proper paragraphs and line breaks. If generating a poem, use appropriate stanza structure. Include a cosmic or space theme if appropriate."
662
-
663
- # Generate the content
664
- generation_config = {
665
- "temperature": 0.8,
666
- "top_p": 0.95,
667
- "top_k": 40,
668
- "max_output_tokens": 1024
669
- }
670
-
671
- # Call the Gemini model
672
- response = chatbot.generate_content(
673
- enhanced_prompt,
674
- generation_config=generation_config
 
675
  )
676
- return response.text
677
-
 
678
  except Exception as e:
679
- logger.error(f"Text generation failed: {str(e)}")
680
- raise HTTPException(status_code=500, detail=f"Text generation failed: {str(e)}")
 
2
  from fastapi.staticfiles import StaticFiles
3
  from fastapi.responses import RedirectResponse, JSONResponse, HTMLResponse
4
  from transformers import pipeline, ViltProcessor, ViltForQuestionAnswering, M2M100ForConditionalGeneration, M2M100Tokenizer
5
+ from typing import Optional, Dict, Any, List
6
  import logging
7
  import time
8
  import os
 
19
  from pydantic import BaseModel
20
  import asyncio
21
  import google.generativeai as genai
 
22
 
23
  # Configure logging
24
  logging.basicConfig(
 
27
  )
28
  logger = logging.getLogger("cosmic_ai")
29
 
30
+ # Create the upload directory if it doesn't exist
31
+ upload_dir = os.getenv('UPLOAD_DIR', '/tmp/uploads')
32
+ os.makedirs(upload_dir, exist_ok=True)
 
33
 
34
+ app = FastAPI(
35
+ title="Cosmic AI Assistant",
36
+ description="An advanced AI assistant with space-themed interface, translation, and file question-answering features",
37
+ version="2.0.0"
38
+ )
39
 
40
+ # Mount static files
41
+ app.mount("/static", StaticFiles(directory="static"), name="static")
 
42
 
43
+ # Mount videos directory
44
+ app.mount("/videos", StaticFiles(directory="videos"), name="videos")
 
45
 
46
+ # Mount images directory
47
+ app.mount("/images", StaticFiles(directory="images"), name="images")
 
48
 
49
+ # Gemini API Configuration
50
+ API_KEY = "AIzaSyCwmgD8KxzWiuivtySNtcZF_rfTvx9s9sY" # Replace with your actual API key
51
+ genai.configure(api_key=API_KEY)
52
 
53
  # Model configurations
54
  MODELS = {
 
57
  "visual-qa": "dandelin/vilt-b32-finetuned-vqa",
58
  "chatbot": "gemini-1.5-pro",
59
  "translation": "facebook/m2m100_418M",
60
+ "file-qa": "distilbert-base-cased-distilled-squad" # New model for file QA
 
61
  }
62
 
63
+ # Supported languages for translation
64
+ SUPPORTED_LANGUAGES = {
65
+ "english": "en",
66
+ "french": "fr",
67
+ "german": "de",
68
+ "spanish": "es",
69
+ "italian": "it",
70
+ "russian": "ru",
71
+ "chinese": "zh",
72
+ "japanese": "ja",
73
+ "arabic": "ar",
74
+ "hindi": "hi",
75
+ "portuguese": "pt",
76
+ "korean": "ko"
77
+ }
78
 
79
+ # Global variables for pre-loaded translation model
80
+ translation_model = None
81
+ translation_tokenizer = None
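For reference, a minimal standalone sketch of how a pre-loaded M2M100 pair like the one above is driven for a single translation (model name taken from the MODELS table; the example phrase and target language are illustrative):

    # Minimal M2M100 usage sketch; assumes transformers and torch are installed.
    from transformers import M2M100ForConditionalGeneration, M2M100Tokenizer

    tokenizer = M2M100Tokenizer.from_pretrained("facebook/m2m100_418M")
    model = M2M100ForConditionalGeneration.from_pretrained("facebook/m2m100_418M")

    tokenizer.src_lang = "en"  # source language of the input text
    encoded = tokenizer("I want to explore the stars", return_tensors="pt")
    tokens = model.generate(
        **encoded,
        forced_bos_token_id=tokenizer.get_lang_id("fr"),  # force French output
        max_length=512,
    )
    print(tokenizer.batch_decode(tokens, skip_special_tokens=True)[0])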
 
82
 
83
+ # Cache for model loading (excluding translation)
84
  @lru_cache(maxsize=8)
85
  def load_model(task: str, model_name: str = None):
86
  """Cached model loader with proper task names and error handling"""
87
  try:
88
+ logger.info(f"Loading model for task: {task}, model: {model_name or MODELS.get(task)}")
89
  start_time = time.time()
 
90
 
91
+ model_to_load = model_name or MODELS.get(task)
 
92
 
93
+ if task == "chatbot":
 
94
  return genai.GenerativeModel(model_to_load)
95
+
 
96
  if task == "visual-qa":
97
  processor = ViltProcessor.from_pretrained(model_to_load)
98
  model = ViltForQuestionAnswering.from_pretrained(model_to_load)
 
102
  def vqa_function(image, question, **generate_kwargs):
103
  if image.mode != "RGB":
104
  image = image.convert("RGB")
 
105
  inputs = processor(image, question, return_tensors="pt").to(device)
106
  logger.info(f"VQA inputs - question: {question}, image size: {image.size}")
 
107
  with torch.no_grad():
108
  outputs = model(**inputs)
109
+ logits = outputs.logits
110
+ idx = logits.argmax(-1).item()
111
+ answer = model.config.id2label[idx]
 
112
  logger.info(f"VQA raw output: {answer}")
113
  return answer
114
 
115
  return vqa_function
116
 
117
+ return pipeline(task, model=model_to_load)
 
118
 
119
  except Exception as e:
120
  logger.error(f"Model load failed: {str(e)}")
121
  raise HTTPException(status_code=500, detail=f"Model loading failed: {task} - {str(e)}")
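Worth noting: functools.lru_cache keys on the argument tuple, so a second call with the same (task, model_name) returns the already-built object instead of reloading it. A tiny sketch of that behavior (hypothetical loader, illustrative only):

    from functools import lru_cache

    @lru_cache(maxsize=8)
    def expensive_load(task: str, model_name: str = None):
        print(f"loading {task}...")  # printed only on a cache miss
        return {"task": task}        # stand-in for a real model object

    a = expensive_load("summarization")
    b = expensive_load("summarization")  # cache hit: no second "loading" line
    assert a is b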
122
 
123
+ def get_gemini_response(user_input: str, is_generation: bool = False):
124
+ """Function to generate response with Gemini for both chat and text generation"""
125
+ if not user_input:
126
+ return "Please provide some input."
 
 
 
127
  try:
128
+ chatbot = load_model("chatbot")
129
+ if is_generation:
130
+ prompt = f"Generate creative text based on this prompt: {user_input}"
131
+ else:
132
+ prompt = user_input
133
+ response = chatbot.generate_content(prompt)
134
+ return response.text.strip()
135
  except Exception as e:
136
+ return f"Error: {str(e)}"
137
+
138
+ def translate_text(text: str, target_language: str):
139
+ """Translate text to any target language using pre-loaded M2M100 model"""
140
+ if not text:
141
+ return "Please provide text to translate."
142
 
143
  try:
144
+ global translation_model, translation_tokenizer
 
145
 
146
+ target_lang = target_language.lower()
147
+ if target_lang not in SUPPORTED_LANGUAGES:
148
+ similar = [lang for lang in SUPPORTED_LANGUAGES if target_lang in lang or lang in target_lang]
149
+ if similar:
150
+ target_lang = similar[0]
151
+ else:
152
+ return f"Language '{target_language}' not supported. Available languages: {', '.join(SUPPORTED_LANGUAGES.keys())}"
 
153
 
154
+ lang_code = SUPPORTED_LANGUAGES[target_lang]
155
+
156
+ if translation_model is None or translation_tokenizer is None:
157
+ raise Exception("Translation model not initialized")
 
158
 
159
+ match = re.search(r'how to say\s+(.+?)\s+in\s+(\w+)', text.lower())
160
+ if match:
161
+ text_to_translate = match.group(1)
162
  else:
163
+ content_match = re.search(r'(?:translate|convert).*to\s+[a-zA-Z]+\s*[:\s]*(.+)', text, re.IGNORECASE)
164
+ text_to_translate = content_match.group(1) if content_match else text
 
165
 
166
+ translation_tokenizer.src_lang = "en"
167
+ encoded = translation_tokenizer(text_to_translate, return_tensors="pt", padding=True, truncation=True).to(translation_model.device)
 
168
 
169
+ start_time = time.time()
170
+ generated_tokens = translation_model.generate(
171
+ **encoded,
172
+ forced_bos_token_id=translation_tokenizer.get_lang_id(lang_code),
173
+ max_length=512,
174
+ num_beams=1,
175
+ early_stopping=True
176
+ )
177
+ translated_text = translation_tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]
178
+ logger.info(f"Translation took {time.time() - start_time:.2f} seconds")
179
 
180
+ return translated_text
 
 
 
181
 
182
+ except Exception as e:
183
+ logger.error(f"Translation error: {str(e)}", exc_info=True)
184
+ return f"Translation error: {str(e)}"
185
+
186
+ def detect_intent(text: str = None, file: UploadFile = None) -> tuple[str, str]:
187
+ """Enhanced intent detection with dynamic translation and file QA support"""
188
+ target_language = "English" # Default
189
+
190
+ if file:
191
+ content_type = file.content_type.lower() if file.content_type else ""
192
+ filename = file.filename.lower() if file.filename else ""
193
+
194
+ if content_type.startswith('image/') and text:
195
+ text_lower = text.lower()
196
+ if "what’s this" in text_lower:
197
+ return "visual-qa", target_language
198
+ if "does this fly" in text_lower:
199
+ return "visual-qa", target_language
200
+ if "fly" in text_lower and any(q in text_lower for q in ['does', 'can', 'will']):
201
+ return "visual-qa", target_language
202
+
203
+ if content_type.startswith('image/'):
204
+ if text and any(q in text.lower() for q in ['what is', 'what\'s', 'describe', 'tell me about', 'explain','how many', 'what color', 'is there', 'are they', 'does the']):
205
+ return "visual-qa", target_language
206
+ return "image-to-text", target_language
207
+ elif filename.endswith(('.xlsx', '.xls', '.csv')):
208
+ return "visualize", target_language
209
+ elif filename.endswith(('.pdf', '.docx', '.doc', '.txt', '.rtf')):
210
+ if text and any(q in text.lower() for q in ['what is', 'who is', 'where', 'when', 'why', 'how', 'what are', 'who are']):
211
+ return "file-qa", target_language # New intent for file QA
212
+ return "summarize", target_language
213
+
214
+ if not text:
215
+ return "chatbot", target_language
216
+
217
+ text_lower = text.lower()
218
+
219
+ if any(keyword in text_lower for keyword in ['chat', 'talk', 'converse', 'ask gemini']):
220
+ return "chatbot", target_language
221
+
222
+ translate_patterns = [
223
+ r'translate.*to\s+\[?([a-zA-Z]+)\]?:?\s*(.*)',
224
+ r'convert.*to\s+\[?([a-zA-Z]+)\]?:?\s*(.*)',
225
+ r'how to say.*in\s+\[?([a-zA-Z]+)\]?:?\s*(.*)'
226
+ ]
227
+
228
+ for pattern in translate_patterns:
229
+ translate_match = re.search(pattern, text_lower)
230
+ if translate_match:
231
+ potential_lang = translate_match.group(1).lower()
232
+ if potential_lang in SUPPORTED_LANGUAGES:
233
+ target_language = potential_lang.capitalize()
234
+ return "translate", target_language
 
235
  else:
236
+ logger.warning(f"Invalid language detected: {potential_lang}")
237
+ return "chatbot", target_language
238
+
239
+ vqa_patterns = [
240
+ r'how (many|much)',
241
+ r'what (color|size|position|shape)',
242
+ r'is (there|that|this) (a|an)',
243
+ r'are (they|there) (any|some)',
244
+ r'does (the|this) (image|picture) (show|contain)'
245
+ ]
246
+
247
+ if any(re.search(pattern, text_lower) for pattern in vqa_patterns):
248
+ return "visual-qa", target_language
249
+
250
+ summarization_patterns = [
251
+ r'\b(summar(y|ize|ise)|brief( overview)?)\b',
252
+ r'\b(long article|text|document)\b',
253
+ r'\bcan you (summar|brief|condense)\b',
254
+ r'\b(short summary|brief explanation)\b',
255
+ r'\b(overview|main points|key ideas)\b',
256
+ r'\b(tl;?dr|too long didn\'?t read)\b'
257
+ ]
258
+
259
+ if any(re.search(pattern, text_lower) for pattern in summarization_patterns):
260
+ return "summarize", target_language
261
+
262
+ generation_patterns = [
263
+ r'\b(write|generate|create|compose)\b',
264
+ r'\b(story|poem|essay|text|content)\b'
265
+ ]
266
+
267
+ if any(re.search(pattern, text_lower) for pattern in generation_patterns):
268
+ return "text-generation", target_language
269
+
270
+ if len(text) > 100:
271
+ return "summarize", target_language
272
+
273
+ if file and file.content_type and file.content_type.startswith('image/'):
274
+ if text and "what’s this" in text_lower:
275
+ return "visual-qa", target_language
276
+ if text and any(q in text_lower for q in ['does this', 'is this', 'can this']):
277
+ return "visual-qa", target_language
278
+
279
+ return "chatbot", target_language
280
+
281
+ class ProcessResponse(BaseModel):
282
+ response: str
283
+ type: str
284
+ additional_data: Optional[Dict[str, Any]] = None
285
+
286
+ @app.get("/chatbot")
287
+ async def chatbot_interface():
288
+ """Redirect to the static index.html file for the chatbot interface"""
289
+ return RedirectResponse(url="/static/index.html")
290
+
291
+ @app.post("/chat")
292
+ async def chat_endpoint(data: dict):
293
+ message = data.get("message", "")
294
+ if not message:
295
+ raise HTTPException(status_code=400, detail="No message provided")
296
+ try:
297
+ response = get_gemini_response(message)
298
+ return {"response": response}
299
+ except Exception as e:
300
+ raise HTTPException(status_code=500, detail=f"Chat error: {str(e)}")
301
+
302
+ @app.post("/process", response_model=ProcessResponse)
303
+ async def process_input(
304
+ request: Request,
305
+ text: str = Form(None),
306
+ file: UploadFile = File(None)
307
+ ):
308
+ """Enhanced unified endpoint with dynamic translation and file QA"""
309
+ start_time = time.time()
310
+ client_ip = request.client.host
311
+ logger.info(f"Request from {client_ip}: text={text[:50] + '...' if text and len(text) > 50 else text}, file={file.filename if file else None}")
312
+
313
+ intent, target_language = detect_intent(text, file)
314
+ logger.info(f"Detected intent: {intent}, target_language: {target_language}")
315
+
316
+ try:
317
+ if intent == "chatbot":
318
+ response = get_gemini_response(text)
319
+ return {"response": response, "type": "chat"}
320
+
321
+ elif intent == "translate":
322
+ content = await extract_text_from_file(file) if file else text
323
+ if "all languages" in text.lower():
324
+ translations = {}
325
+ phrase_to_translate = "I want to explore the stars" if "I want to explore the stars" in text else content
326
+ for lang, code in SUPPORTED_LANGUAGES.items():
327
+ translation_tokenizer.src_lang = "en"
328
+ encoded = translation_tokenizer(phrase_to_translate, return_tensors="pt").to(translation_model.device)
329
+ generated_tokens = translation_model.generate(
330
+ **encoded,
331
+ forced_bos_token_id=translation_tokenizer.get_lang_id(code),
332
+ max_length=512,
333
+ num_beams=1
334
+ )
335
+ translations[lang] = translation_tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]
336
+ response = "\n".join(f"{lang.capitalize()}: {translations[lang]}" for lang in translations)
337
+ logger.info(f"Translated to all supported languages: {', '.join(translations.keys())}")
338
+ return {"response": response, "type": "translation"}
339
  else:
340
+ translated_text = translate_text(content, target_language)
341
+ return {"response": translated_text, "type": "translation"}
342
+
343
+ elif intent == "summarize":
344
+ content = await extract_text_from_file(file) if file else text
345
+ summarizer = load_model("summarization")
346
+
347
+ content_length = len(content.split())
348
+ max_len = max(30, min(150, content_length//2))
349
+ min_len = max(15, min(30, max_len//2))
350
+
351
+ if len(content) > 1024:
352
+ chunks = [content[i:i+1024] for i in range(0, len(content), 1024)]
353
+ summaries = []
354
 
355
+ for chunk in chunks[:3]:
356
+ summary = summarizer(
357
+ chunk,
358
+ max_length=max_len,
359
+ min_length=min_len,
360
+ do_sample=False,
361
+ truncation=True
362
+ )
363
+ summaries.append(summary[0]['summary_text'])
364
 
365
+ final_summary = " ".join(summaries)
 
 
366
  else:
367
+ summary = summarizer(
368
+ content,
369
+ max_length=max_len,
370
+ min_length=min_len,
371
+ do_sample=False,
372
+ truncation=True
373
+ )
374
+ final_summary = summary[0]['summary_text']
375
+
376
+ final_summary = re.sub(r'\s+', ' ', final_summary).strip()
377
+ return {"response": final_summary, "type": "summary"}
378
+
379
+ elif intent == "image-to-text":
380
+ if not file or not file.content_type.startswith('image/'):
381
+ raise HTTPException(status_code=400, detail="An image file is required")
382
+
383
+ image = Image.open(io.BytesIO(await file.read()))
384
+ captioner = load_model("image-to-text")
385
+
386
+ caption = captioner(image, max_new_tokens=50)
387
+
388
+ return {"response": caption[0]['generated_text'], "type": "caption"}
389
+
390
+ elif intent == "visual-qa":
391
+ if not file or not file.content_type.startswith('image/'):
392
+ raise HTTPException(status_code=400, detail="An image file is required")
393
+ if not text:
394
+ raise HTTPException(status_code=400, detail="A question is required for VQA")
395
+
396
+ image = Image.open(io.BytesIO(await file.read())).convert("RGB")
397
+ vqa_pipeline = load_model("visual-qa")
398
+
399
+ question = text.strip()
400
+ if not question.endswith('?'):
401
+ question += '?'
402
+
403
+ answer = vqa_pipeline(
404
+ image=image,
405
+ question=question
406
+ )
407
+
408
+ answer = answer.strip()
409
+ if not answer or answer.lower() == question.lower():
410
+ logger.warning(f"VQA failed to generate a meaningful answer: {answer}")
411
+ answer = "I couldn't determine the answer from the image."
412
+ else:
413
+ answer = answer.capitalize()
414
+ if not answer.endswith(('.', '!', '?')):
415
+ answer += '.'
416
+ chatbot = load_model("chatbot")
417
+ if "fly" in question.lower():
418
+ answer = chatbot.generate_content(f"Make this fun and spacey: {answer}").text.strip()
419
+ else:
420
+ answer = chatbot.generate_content(f"Make this cosmic and poetic: {answer}").text.strip()
421
+
422
+ logger.info(f"Final VQA answer: {answer}")
423
+
424
+ return {
425
+ "response": answer,
426
+ "type": "visual_qa",
427
+ "additional_data": {
428
+ "question": text,
429
+ "image_size": f"{image.width}x{image.height}"
430
+ }
431
+ }
432
+
433
+ elif intent == "visualize":
434
+ if not file:
435
+ raise HTTPException(status_code=400, detail="An Excel file is required")
436
+
437
+ file_content = await file.read()
438
+
439
+ if file.filename.endswith('.csv'):
440
+ df = pd.read_csv(io.BytesIO(file_content))
441
+ else:
442
+ df = pd.read_excel(io.BytesIO(file_content))
443
+
444
+ code = generate_visualization_code(df, text)
445
+ stats = df.describe().to_string()
446
+ response = f"Stats:\n{stats}\n\nChart Code:\n{code}"
447
+
448
+ return {"response": response, "type": "visualization_code"}
449
+
450
+ elif intent == "text-generation":
451
+ response = get_gemini_response(text, is_generation=True)
452
+ lines = response.split(". ")
453
+ formatted_poem = "\n".join(line.strip() + ("." if not line.endswith(".") else "") for line in lines if line)
454
+ return {"response": formatted_poem, "type": "generated_text"}
455
+
456
+ elif intent == "file-qa":
457
+ if not file or not file.filename.lower().endswith(('.pdf', '.docx', '.doc', '.txt', '.rtf')):
458
+ raise HTTPException(status_code=400, detail="A text-based file (PDF, DOCX, TXT, RTF) is required")
459
+ if not text:
460
+ raise HTTPException(status_code=400, detail="A question about the file is required")
461
+
462
+ content = await extract_text_from_file(file)
463
+ if not content.strip():
464
+ raise HTTPException(status_code=400, detail="No text could be extracted from the file")
465
+
466
+ qa_pipeline = load_model("file-qa")
467
+
468
+ question = text.strip()
469
+ if not question.endswith('?'):
470
+ question += '?'
471
+
472
+ # Chunk content if too long (model context limit ~512 tokens)
473
+ if len(content) > 1024:
474
+ chunks = [content[i:i+1024] for i in range(0, len(content), 1024)]
475
+ answers = []
476
+ for chunk in chunks[:3]: # Limit to 3 chunks to avoid excessive processing
477
+ result = qa_pipeline(question=question, context=chunk)
478
+ if result['score'] > 0.1: # Only include high-confidence answers
479
+ answers.append((result['answer'], result['score']))
480
+ if answers:
481
+ # Select the answer with the highest confidence score
482
+ best_answer = max(answers, key=lambda x: x[1])[0]
483
  else:
484
+ best_answer = "I couldn't find a clear answer in the document."
 
485
  else:
486
+ result = qa_pipeline(question=question, context=content)
487
+ best_answer = result['answer'] if result['score'] > 0.1 else "I couldn't find a clear answer in the document."
 
488
 
489
+ best_answer = best_answer.strip().capitalize()
490
+ if not best_answer.endswith(('.', '!', '?')):
491
+ best_answer += '.'
492
 
493
+ # Add cosmic tone
494
+ chatbot = load_model("chatbot")
495
+ final_answer = chatbot.generate_content(f"Make this cosmic and poetic: {best_answer}").text.strip()
 
 
496
 
497
+ logger.info(f"File QA answer: {final_answer}")
498
+
499
+ return {
500
+ "response": final_answer,
501
+ "type": "file_qa",
502
+ "additional_data": {
503
+ "question": text,
504
+ "file_name": file.filename
505
+ }
506
+ }
 
507
 
 
508
  else:
509
+ response = get_gemini_response(text or "Hello! How can I assist you?")
510
+ return {"response": response, "type": "chat"}
 
511
 
512
  except Exception as e:
513
+ logger.error(f"Processing error: {str(e)}", exc_info=True)
514
+ raise HTTPException(status_code=500, detail=str(e))
515
+ finally:
516
+ process_time = time.time() - start_time
517
+ logger.info(f"Request processed in {process_time:.2f} seconds")
518
+
519
+ async def extract_text_from_file(file: UploadFile) -> str:
520
+ """Enhanced text extraction with multiple fallbacks"""
521
+ if not file:
522
+ return ""
523
+
524
+ content = await file.read()
525
+ filename = file.filename.lower()
526
 
527
  try:
528
+ if filename.endswith('.pdf'):
529
+ try:
530
+ doc = fitz.open(stream=content, filetype="pdf")
531
+ if doc.is_encrypted:
532
+ return "PDF is encrypted and cannot be read"
533
+ text = ""
534
+ for page in doc:
535
+ text += page.get_text()
536
+ return text
537
+ except Exception as pdf_error:
538
+ logger.warning(f"PyMuPDF failed: {str(pdf_error)}. Trying pdfminer.six...")
539
+ from pdfminer.high_level import extract_text
540
+ from io import BytesIO
541
+ return extract_text(BytesIO(content))
542
+
543
+ elif filename.endswith(('.docx', '.doc')):
544
+ doc = Document(io.BytesIO(content))
545
+ return "\n".join(para.text for para in doc.paragraphs)
546
+
547
+ elif filename.endswith('.txt'):
548
+ return content.decode('utf-8', errors='replace')
549
+
550
+ elif filename.endswith('.rtf'):
551
+ text = content.decode('utf-8', errors='replace')
552
+ text = re.sub(r'\\[a-z]+', ' ', text)
553
+ text = re.sub(r'\{|\}|\\', '', text)
554
+ return text
555
+
556
+ else:
557
+ raise HTTPException(status_code=400, detail=f"Unsupported file format: {filename}")
558
+
559
+ except Exception as e:
560
+ logger.error(f"File extraction error: {str(e)}", exc_info=True)
561
+ raise HTTPException(
562
+ status_code=500,
563
+ detail=f"Error extracting text: {str(e)}. Supported formats: PDF, DOCX, TXT, RTF"
564
  )
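The RTF branch above uses two regex passes for a rough cleanup; on a toy fragment they behave like this (illustrative, and deliberately basic):

    import re

    raw = r"{\rtf1\ansi Hello {\b cosmic} world}"
    text = re.sub(r'\\[a-z]+', ' ', raw)   # drop control words such as \ansi, \b
    text = re.sub(r'\{|\}|\\', '', text)   # drop braces and stray backslashes
    print(text)  # roughly "1 Hello cosmic world" with stray spaces; the digit from \rtf1 survives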
565
+
566
+ def generate_visualization_code(df: pd.DataFrame, request: str = None) -> str:
567
+ """Generate visualization code based on data analysis"""
568
+ num_rows, num_cols = df.shape
569
+ numeric_cols = df.select_dtypes(include=[np.number]).columns.tolist()
570
+ categorical_cols = df.select_dtypes(include=['object']).columns.tolist()
571
+ date_cols = [col for col in df.columns if df[col].dtype == 'datetime64[ns]' or
572
+ (isinstance(df[col].dtype, object) and pd.to_datetime(df[col], errors='coerce').notna().all())]
573
+
574
+ if request:
575
+ request_lower = request.lower()
576
+ else:
577
+ request_lower = ""
578
+
579
+ if len(numeric_cols) >= 2 and ("scatter" in request_lower or "correlation" in request_lower):
580
+ x_col = numeric_cols[0]
581
+ y_col = numeric_cols[1]
582
+ return f"""import pandas as pd
583
+ import matplotlib.pyplot as plt
584
+ import seaborn as sns
585
+ df = pd.read_excel('data.xlsx')
586
+ plt.figure(figsize=(10, 6))
587
+ sns.regplot(x='{x_col}', y='{y_col}', data=df, scatter_kws={{'alpha': 0.6}})
588
+ plt.title('Correlation between {x_col} and {y_col}')
589
+ plt.grid(True, alpha=0.3)
590
+ plt.tight_layout()
591
+ plt.savefig('correlation_plot.png')
592
+ plt.show()
593
+ correlation = df['{x_col}'].corr(df['{y_col}'])
594
+ print(f"Correlation coefficient: {{correlation:.4f}}")"""
595
+
596
+ elif len(numeric_cols) >= 1 and len(categorical_cols) >= 1 and ("bar" in request_lower or "comparison" in request_lower):
597
+ cat_col = categorical_cols[0]
598
+ num_col = numeric_cols[0]
599
+ return f"""import pandas as pd
600
+ import matplotlib.pyplot as plt
601
+ import seaborn as sns
602
+ df = pd.read_excel('data.xlsx')
603
+ plt.figure(figsize=(12, 7))
604
+ ax = sns.barplot(x='{cat_col}', y='{num_col}', data=df, palette='viridis')
605
+ for p in ax.patches:
606
+ ax.annotate(f'{{p.get_height():.1f}}',
607
+ (p.get_x() + p.get_width() / 2., p.get_height()),
608
+ ha='center', va='bottom', fontsize=10, color='black', xytext=(0, 5),
609
+ textcoords='offset points')
610
+ plt.title('Comparison of {num_col} by {cat_col}', fontsize=15)
611
+ plt.xlabel('{cat_col}', fontsize=12)
612
+ plt.ylabel('{num_col}', fontsize=12)
613
+ plt.xticks(rotation=45, ha='right')
614
+ plt.grid(axis='y', alpha=0.3)
615
+ plt.tight_layout()
616
+ plt.savefig('comparison_chart.png')
617
+ plt.show()"""
618
+
619
+ elif len(numeric_cols) >= 1 and ("distribution" in request_lower or "histogram" in request_lower):
620
+ num_col = numeric_cols[0]
621
+ return f"""import pandas as pd
622
+ import matplotlib.pyplot as plt
623
+ import seaborn as sns
624
+ df = pd.read_excel('data.xlsx')
625
+ plt.figure(figsize=(10, 6))
626
+ sns.histplot(df['{num_col}'], kde=True, bins=20, color='purple')
627
+ plt.title('Distribution of {num_col}', fontsize=15)
628
+ plt.xlabel('{num_col}', fontsize=12)
629
+ plt.ylabel('Frequency', fontsize=12)
630
+ plt.grid(True, alpha=0.3)
631
+ plt.tight_layout()
632
+ plt.savefig('distribution_plot.png')
633
+ plt.show()
634
+ print(df['{num_col}'].describe())"""
635
+
636
+ else:
637
+ return f"""import pandas as pd
638
+ import matplotlib.pyplot as plt
639
+ import seaborn as sns
640
+ import numpy as np
641
+ df = pd.read_excel('data.xlsx')
642
+ print("Descriptive statistics:")
643
+ print(df.describe())
644
+ fig, axes = plt.subplots(2, 2, figsize=(15, 12))
645
+ numeric_df = df.select_dtypes(include=[np.number])
646
+ if not numeric_df.empty and numeric_df.shape[1] > 1:
647
+ sns.heatmap(numeric_df.corr(), annot=True, cmap='coolwarm', fmt='.2f', ax=axes[0, 0])
648
+ axes[0, 0].set_title('Correlation Matrix')
649
+ if not numeric_df.empty:
650
+ for i, col in enumerate(numeric_df.columns[:1]):
651
+ sns.histplot(df[col], kde=True, ax=axes[0, 1], color='purple')
652
+ axes[0, 1].set_title(f'Distribution of {{col}}')
653
+ axes[0, 1].set_xlabel(col)
654
+ axes[0, 1].set_ylabel('Frequency')
655
+ categorical_cols = df.select_dtypes(include=['object']).columns
656
+ if len(categorical_cols) > 0 and not numeric_df.empty:
657
+ cat_col = categorical_cols[0]
658
+ num_col = numeric_df.columns[0]
659
+ sns.barplot(x=cat_col, y=num_col, data=df, ax=axes[1, 0], palette='viridis')
660
+ axes[1, 0].set_title(f'{{num_col}} by {{cat_col}}')
661
+ axes[1, 0].set_xticklabels(axes[1, 0].get_xticklabels(), rotation=45, ha='right')
662
+ if not numeric_df.empty and len(categorical_cols) > 0:
663
+ cat_col = categorical_cols[0]
664
+ num_col = numeric_df.columns[0]
665
+ sns.boxplot(x=cat_col, y=num_col, data=df, ax=axes[1, 1], palette='Set3')
666
+ axes[1, 1].set_title(f'Distribution of {{num_col}} by {{cat_col}}')
667
+ axes[1, 1].set_xticklabels(axes[1, 1].get_xticklabels(), rotation=45, ha='right')
668
+ plt.tight_layout()
669
+ plt.savefig('dashboard.png')
670
+ plt.show()"""
671
+
672
+ @app.get("/", include_in_schema=False)
673
+ async def home():
674
+ """Redirect to the static index.html file"""
675
+ return RedirectResponse(url="/static/index.html")
676
+
677
+ @app.get("/health", include_in_schema=True)
678
+ async def health_check():
679
+ """Health check endpoint"""
680
+ return {"status": "healthy", "version": "2.0.0"}
681
+
682
+ @app.get("/models", include_in_schema=True)
683
+ async def list_models():
684
+ """List available models"""
685
+ return {"models": MODELS}
686
+
687
+ @app.on_event("startup")
688
+ async def startup_event():
689
+ """Pre-load models at startup with timeout"""
690
+ global translation_model, translation_tokenizer
691
+ logger.info("Starting model pre-loading...")
692
+
693
+ async def load_model_with_timeout(task):
694
+ try:
695
+ await asyncio.wait_for(asyncio.to_thread(load_model, task), timeout=60.0)
696
+ logger.info(f"Successfully loaded {task} model")
697
+ except asyncio.TimeoutError:
698
+ logger.warning(f"Timeout loading {task} model - will load on demand")
699
+ except Exception as e:
700
+ logger.error(f"Error pre-loading {task}: {str(e)}")
701
+
702
+ try:
703
+ model_name = MODELS["translation"]
704
+ translation_model = M2M100ForConditionalGeneration.from_pretrained(model_name)
705
+ translation_tokenizer = M2M100Tokenizer.from_pretrained(model_name)
706
+ device = "cuda" if torch.cuda.is_available() else "cpu"
707
+ translation_model.to(device)
708
+ logger.info("Translation model pre-loaded successfully")
709
  except Exception as e:
710
+ logger.error(f"Error pre-loading translation model: {str(e)}")
711
+
712
+ await asyncio.gather(
713
+ load_model_with_timeout("summarization"),
714
+ load_model_with_timeout("image-to-text"),
715
+ load_model_with_timeout("visual-qa"),
716
+ load_model_with_timeout("chatbot"),
717
+ load_model_with_timeout("file-qa") # Pre-load file QA model
718
+ )
719
+
720
+ if __name__ == "__main__":
721
+ import uvicorn
722
+ uvicorn.run("app:app", host="0.0.0.0", port=7860, reload=True)