from fastapi import FastAPI, UploadFile, File, Form, HTTPException, Request
from fastapi.staticfiles import StaticFiles
from fastapi.responses import RedirectResponse, JSONResponse
from transformers import pipeline, ViltProcessor, ViltForQuestionAnswering, M2M100ForConditionalGeneration, M2M100Tokenizer
from typing import Optional, Dict, Any
import logging
import time
import os
import io
import re
from PIL import Image
from docx import Document
import fitz  # PyMuPDF
import pandas as pd
from functools import lru_cache
import torch
import numpy as np
from pydantic import BaseModel
import asyncio
import google.generativeai as genai
from spellchecker import SpellChecker
import nltk
from nltk.tokenize import sent_tokenize
from dotenv import load_dotenv
import shutil

# Load environment variables
load_dotenv()

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger("cosmic_ai")

# Set a custom NLTK data directory
nltk_data_dir = os.getenv('NLTK_DATA_DIR', '/cache/nltk_data')
os.makedirs(nltk_data_dir, exist_ok=True)
nltk.data.path.append(nltk_data_dir)

# Download punkt_tab data if not already present
try:
    nltk.download('punkt_tab', download_dir=nltk_data_dir, quiet=True, raise_on_error=True)
    logger.info(f"NLTK punkt_tab verified in {nltk_data_dir}")
except Exception as e:
    logger.error(f"Error verifying NLTK punkt_tab: {str(e)}")
    raise Exception(f"Failed to verify NLTK punkt_tab: {str(e)}")

# Create upload directory if it doesn't exist
upload_dir = os.getenv('UPLOAD_DIR', '/tmp/uploads')
os.makedirs(upload_dir, exist_ok=True)

app = FastAPI(
    title="VION IA Assistant",
    description="An advanced AI assistant with space-themed interface, translation, and file question-answering features",
    version="2.0.0"
)

# Mount static files
app.mount("/static", StaticFiles(directory="static"), name="static")
app.mount("/images", StaticFiles(directory="images"), name="images")

# Gemini API Configuration
API_KEY = os.getenv('GEMINI_API_KEY')
GEMINI_AVAILABLE = True
if not API_KEY:
    logger.warning("GEMINI_API_KEY not set. Gemini-dependent features (chatbot, file QA fallback, summarization fallback, text generation) will be disabled.")
    GEMINI_AVAILABLE = False
else:
    try:
        genai.configure(api_key=API_KEY)
        logger.info("Gemini API configured successfully")
    except Exception as e:
        logger.error(f"Failed to configure Gemini API: {str(e)}. Disabling Gemini features.")
        GEMINI_AVAILABLE = False

# Model configurations
MODELS = {
    "summarization": "sshleifer/distilbart-cnn-12-6",
    "image-to-text": "Salesforce/blip-image-captioning-large",
    "visual-qa": "dandelin/vilt-b32-finetuned-vqa",
    "chatbot": "gemini-1.5-pro",
    "translation": "facebook/m2m100_418M",
    "file-qa": "distilbert-base-cased-distilled-squad"
}

# Supported languages for translation
SUPPORTED_LANGUAGES = {
    "english": "en",
    "french": "fr",
    "german": "de",
    "spanish": "es",
    "italian": "it",
    "russian": "ru",
    "chinese": "zh",
    "japanese": "ja",
    "arabic": "ar",
    "hindi": "hi",
    "portuguese": "pt",
    "korean": "ko"
}

# Global variables for pre-loaded translation model
translation_model = None
translation_tokenizer = None

# Initialize spell checker
spell = SpellChecker()

# Cache for model loading (excluding translation)
@lru_cache(maxsize=None)
def load_model(task: str, model_name: str = None):
    """Cached model loader with proper task names and error handling"""
    try:
        cache_dir = os.getenv('HF_HOME', '/cache/huggingface')
        if not os.path.exists(cache_dir):
            os.makedirs(cache_dir, exist_ok=True)
        elif not os.access(cache_dir, os.W_OK):
            logger.warning(f"Cache directory {cache_dir} is not writable. Attempting to clear cache.")
            shutil.rmtree(cache_dir, ignore_errors=True)
            os.makedirs(cache_dir, exist_ok=True)

        logger.info(f"Loading model for task: {task}, model: {model_name or MODELS.get(task)}")
        start_time = time.time()
        model_to_load = model_name or MODELS.get(task)

        if task == "chatbot":
            if not GEMINI_AVAILABLE:
                logger.warning("Gemini not available. Returning None for chatbot task.")
                return None
            return genai.GenerativeModel(model_to_load)

        if task == "visual-qa":
            processor = ViltProcessor.from_pretrained(model_to_load, cache_dir=cache_dir)
            model = ViltForQuestionAnswering.from_pretrained(model_to_load, cache_dir=cache_dir)
            device = "cuda" if torch.cuda.is_available() else "cpu"
            model.to(device)

            def vqa_function(image, question, **generate_kwargs):
                if image.mode != "RGB":
                    image = image.convert("RGB")
                inputs = processor(image, question, return_tensors="pt").to(device)
                logger.info(f"VQA inputs - question: {question}, image size: {image.size}")
                with torch.no_grad():
                    outputs = model(**inputs)
                logits = outputs.logits
                idx = logits.argmax(-1).item()
                answer = model.config.id2label[idx]
                logger.info(f"VQA raw output: {answer}")
                return answer

            return vqa_function

        return pipeline(
            task if task != "file-qa" else "question-answering",
            model=model_to_load,
            cache_dir=cache_dir
        )
    except Exception as e:
        logger.error(f"Model load failed: {task} - {str(e)}")
        if task == "file-qa":
            logger.warning("Falling back to Gemini for file-qa due to model load failure")
            return None
        raise HTTPException(status_code=500, detail=f"Model loading failed: {task} - {str(e)}")
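
# Usage sketch (illustrative only; these calls are not made at import time and the
# example inputs are assumptions): for most tasks load_model returns a transformers
# pipeline, for "visual-qa" a plain callable, and for "chatbot" a Gemini
# GenerativeModel (or None when Gemini is unavailable).
#
#   summarizer = load_model("summarization")
#   summary = summarizer("Some long passage ...", max_length=60, min_length=20)[0]["summary_text"]
#
#   vqa = load_model("visual-qa")
#   answer = vqa(Image.open("photo.jpg"), "What color is the car?")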

def get_gemini_response(user_input: str, is_generation: bool = False):
    """Function to generate response with Gemini for both chat and text generation"""
    if not GEMINI_AVAILABLE:
        return "Error: Gemini API is not available. Please contact the administrator."
    if not user_input:
        return "Please provide some input."
    try:
        chatbot = load_model("chatbot")
        if not chatbot:
            return "Error: Gemini API is not available."
        if is_generation:
            prompt = f"Generate creative text based on this prompt: {user_input}"
        else:
            prompt = user_input
        response = chatbot.generate_content(prompt)
        return response.text.strip()
    except Exception as e:
        return f"Error: {str(e)}"

def translate_text(text: str, target_language: str):
    """Translate text to any target language using pre-loaded M2M100 model"""
    if not text:
        return "Please provide text to translate."
    try:
        global translation_model, translation_tokenizer
        target_lang = target_language.lower()
        if target_lang not in SUPPORTED_LANGUAGES:
            similar = [lang for lang in SUPPORTED_LANGUAGES if target_lang in lang or lang in target_lang]
            if similar:
                target_lang = similar[0]
            else:
                return f"Language '{target_language}' not supported. Available languages: {', '.join(SUPPORTED_LANGUAGES.keys())}"
        lang_code = SUPPORTED_LANGUAGES[target_lang]
        if translation_model is None or translation_tokenizer is None:
            raise Exception("Translation model not initialized")
        match = re.search(r'how to say\s+(.+?)\s+in\s+(\w+)', text.lower())
        if match:
            text_to_translate = match.group(1)
        else:
            content_match = re.search(r'(?:translate|convert).*to\s+[a-zA-Z]+\s*[:\s]*(.+)', text, re.IGNORECASE)
            text_to_translate = content_match.group(1) if content_match else text
        translation_tokenizer.src_lang = "en"
        encoded = translation_tokenizer(text_to_translate, return_tensors="pt", padding=True, truncation=True).to(translation_model.device)
        start_time = time.time()
        generated_tokens = translation_model.generate(
            **encoded,
            forced_bos_token_id=translation_tokenizer.get_lang_id(lang_code),
            max_length=512,
            num_beams=1,
            early_stopping=True
        )
        translated_text = translation_tokenizer.batch_decode(
            generated_tokens,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False
        )[0]
        logger.info(f"Translation took {time.time() - start_time:.2f} seconds")
        return translated_text
    except Exception as e:
        logger.error(f"Translation error: {str(e)}")
        return f"Translation error: {str(e)}"

def detect_intent(text: str = None, file: UploadFile = None) -> tuple[str, str]:
    """Enhanced intent detection with dynamic translation and file translation support"""
    target_language = "English"  # Default

    if file and text:
        text_lower = text.lower()
        filename = file.filename.lower() if file.filename else ""
        translate_patterns = [
            r'translate.*to\s+\[?([a-zA-Z]+)\]?:?\s*(.*)',
            r'convert.*to\s+\[?([a-zA-Z]+)\]?:?\s*(.*)',
            r'how to say.*in\s+\[?([a-zA-Z]+)\]?:?\s*(.*)'
        ]
        for pattern in translate_patterns:
            translate_match = re.search(pattern, text_lower)
            if translate_match and filename.endswith(('.pdf', '.docx', '.txt', '.rtf')):
                potential_lang = translate_match.group(1).lower()
                if potential_lang in SUPPORTED_LANGUAGES:
                    target_language = potential_lang.capitalize()
                    return "file-translate", target_language
        content_type = file.content_type.lower() if file.content_type else ""
        if content_type.startswith('image/') and text:
            if "what’s this" in text_lower or "does this fly" in text_lower or ("fly" in text_lower and any(q in text_lower for q in ['does', 'can', 'will'])):
                return "visual-qa", target_language
            if any(q in text_lower for q in ['what is', 'what\'s', 'describe', 'tell me about', 'explain', 'how many', 'what color', 'is there', 'are they', 'does the']):
                return "visual-qa", target_language
            if "generate a caption" in text_lower or "caption" in text_lower:
                return "image-to-text", target_language
        if filename.endswith(('.xlsx', '.xls', '.csv')):
            return "visualize", target_language
        elif filename.endswith(('.pdf', '.docx', '.doc', '.txt', '.rtf')):
            if any(q in text_lower for q in ['what is', 'who is', 'where', 'when', 'why', 'how', 'what are', 'who are']):
                return "file-qa", target_language
            return "summarize", target_language

    if not text:
        return "chatbot", target_language

    text_lower = text.lower()
    if any(keyword in text_lower for keyword in ['chat', 'talk', 'converse', 'ask gemini']):
        return "chatbot", target_language

    translate_patterns = [
        r'translate.*to\s+\[?([a-zA-Z]+)\]?:?\s*(.*)',
        r'convert.*to\s+\[?([a-zA-Z]+)\]?:?\s*(.*)',
        r'how to say.*in\s+\[?([a-zA-Z]+)\]?:?\s*(.*)'
    ]
    for pattern in translate_patterns:
        translate_match = re.search(pattern, text_lower)
        if translate_match:
            potential_lang = translate_match.group(1).lower()
            if potential_lang in SUPPORTED_LANGUAGES:
                target_language = potential_lang.capitalize()
                return "translate", target_language
            else:
                logger.warning(f"Invalid language detected: {potential_lang}")
                return "chatbot", target_language

    vqa_patterns = [
        r'how (many|much)',
        r'what (color|size|position|shape)',
        r'is (there|that|this) (a|an)',
        r'are (they|there) (any|some)',
        r'does (the|this) (image|picture) (show|contain)'
    ]
    if any(re.search(pattern, text_lower) for pattern in vqa_patterns):
        return "visual-qa", target_language

    summarization_patterns = [
        r'\b(summar(y|ize|ise)|brief( overview)?)\b',
        r'\b(long article|text|document)\b',
        r'\bcan you (summar|brief|condense)\b',
        r'\b(short summary|brief explanation)\b',
        r'\b(overview|main points|key ideas)\b',
        r'\b(tl;?dr|too long didn\'?t read)\b'
    ]
    if any(re.search(pattern, text_lower) for pattern in summarization_patterns):
        return "summarize", target_language

    generation_patterns = [
        r'\b(write|generate|create|compose)\b',
        r'\b(story|poem|essay|text|content)\b'
    ]
    if any(re.search(pattern, text_lower) for pattern in generation_patterns):
        return "text-generation", target_language

    if len(text) > 100:
        return "summarize", target_language

    return "chatbot", target_language

def preprocess_text(text: str) -> str:
    """Correct spelling errors and improve text readability."""
    words = text.split()
    corrected_words = [spell.correction(word) if spell.correction(word) else word for word in words]
    corrected_text = " ".join(corrected_words)
    sentences = sent_tokenize(corrected_text)
    return ". ".join(sentence.capitalize() for sentence in sentences) + (". " if sentences else "")

class ProcessResponse(BaseModel):
    response: str
    type: str
    additional_data: Optional[Dict[str, Any]] = None


async def chatbot_interface():
    """Redirect to the static index.html file for the chatbot interface"""
    return RedirectResponse(url="/static/index.html")


async def chat_endpoint(data: dict):
    message = data.get("message", "")
    if not message:
        raise HTTPException(status_code=400, detail="No message provided")
    try:
        response = get_gemini_response(message)
        return {"response": response}
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Chat error: {str(e)}")

async def process_input(
    request: Request,
    text: str = Form(None),
    file: UploadFile = File(None)
):
    """Enhanced unified endpoint with dynamic translation and file translation"""
    start_time = time.time()
    client_ip = request.client.host
    logger.info(f"Request from {client_ip}: text={text[:50] + '...' if text and len(text) > 50 else text}, file={file.filename if file else None}")

    intent, target_language = detect_intent(text, file)
    logger.info(f"Detected intent: {intent}, target_language: {target_language}")

    try:
        if intent == "chatbot":
            response = get_gemini_response(text)
            return {"response": response, "type": "chat"}

        elif intent == "translate":
            content = await extract_text_from_file(file) if file else text
            if "all languages" in text.lower():
                translations = {}
                phrase_to_translate = "I want to explore the stars" if "I want to explore the stars" in text else content
                for lang, code in SUPPORTED_LANGUAGES.items():
                    translation_tokenizer.src_lang = "en"
                    encoded = translation_tokenizer(phrase_to_translate, return_tensors="pt").to(translation_model.device)
                    generated_tokens = translation_model.generate(
                        **encoded,
                        forced_bos_token_id=translation_tokenizer.get_lang_id(code),
                        max_length=512,
                        num_beams=1
                    )
                    translations[lang] = translation_tokenizer.batch_decode(
                        generated_tokens,
                        skip_special_tokens=True,
                        clean_up_tokenization_spaces=False
                    )[0]
                response = "\n".join(f"{lang.capitalize()}: {translations[lang]}" for lang in translations)
                logger.info(f"Translated to all supported languages: {', '.join(translations.keys())}")
                return {"response": response, "type": "translation"}
            else:
                translated_text = translate_text(content, target_language)
                return {"response": translated_text, "type": "translation"}

        elif intent == "file-translate":
            if not file or not file.filename.lower().endswith(('.pdf', '.docx', '.txt', '.rtf')):
                raise HTTPException(status_code=400, detail="A text-based file (PDF, DOCX, TXT, RTF) is required")
            if not text:
                raise HTTPException(status_code=400, detail="Please specify a target language for translation")
            content = await extract_text_from_file(file)
            if not content.strip():
                raise HTTPException(status_code=400, detail="No text could be extracted from the file")
            max_chunk_size = 512
            chunks = [content[i:i+max_chunk_size] for i in range(0, len(content), max_chunk_size)]
            translated_chunks = []
            for chunk in chunks:
                translated_chunk = translate_text(chunk, target_language)
                translated_chunks.append(translated_chunk)
            translated_text = " ".join(translated_chunks)
            translated_text = translated_text.strip().capitalize()
            if not translated_text.endswith(('.', '!', '?')):
                translated_text += '.'
            logger.info(f"File translated to {target_language}: {translated_text[:100]}...")
            return {
                "response": translated_text,
                "type": "file_translation",
                "additional_data": {
                    "file_name": file.filename,
                    "target_language": target_language
                }
            }

        elif intent == "summarize":
            content = await extract_text_from_file(file) if file else text
            if not content.strip():
                raise HTTPException(status_code=400, detail="No content to summarize")
            content = preprocess_text(content)
            logger.info(f"Preprocessed content: {content[:100]}...")
            summarizer = load_model("summarization")
            content_length = len(content.split())
            max_len = max(50, min(200, content_length))
            min_len = max(20, min(50, content_length // 3))
            try:
                if len(content) > 1024:
                    chunks = [content[i:i+1024] for i in range(0, len(content), 1024)]
                    summaries = []
                    for chunk in chunks[:3]:
                        summary = summarizer(
                            chunk,
                            max_length=max_len,
                            min_length=min_len,
                            do_sample=False,
                            truncation=True
                        )
                        summaries.append(summary[0]['summary_text'])
                    final_summary = " ".join(summaries)
                else:
                    summary = summarizer(
                        content,
                        max_length=max_len,
                        min_length=min_len,
                        do_sample=False,
                        truncation=True
                    )
                    final_summary = summary[0]['summary_text']
                final_summary = re.sub(r'\s+', ' ', final_summary).strip()
                if not final_summary or final_summary.lower().startswith(content.lower()[:30]):
                    logger.warning("Summarizer produced inadequate output, falling back to Gemini")
                    if GEMINI_AVAILABLE:
                        final_summary = get_gemini_response(
                            f"Summarize this text in a concise and meaningful way: {content}"
                        )
                    else:
                        final_summary = "Summarization fallback unavailable without Gemini API."
                if not final_summary.endswith(('.', '!', '?')):
                    final_summary += '.'
                logger.info(f"Generated summary: {final_summary}")
                return {"response": final_summary, "type": "summary", "message": "Text was preprocessed to correct spelling errors"}
            except Exception as e:
                logger.error(f"Summarization error: {str(e)}")
                if GEMINI_AVAILABLE:
                    final_summary = get_gemini_response(
                        f"Summarize this text in a concise and meaningful way: {content}"
                    )
                else:
                    final_summary = "Summarization failed and Gemini fallback is unavailable."
                return {"response": final_summary, "type": "summary", "message": "Text was preprocessed to correct spelling errors"}

        elif intent == "image-to-text":
            if not file or not file.content_type.startswith('image/'):
                raise HTTPException(status_code=400, detail="An image file is required")
            image = Image.open(io.BytesIO(await file.read()))
            captioner = load_model("image-to-text")
            caption = captioner(image, max_new_tokens=50)
            return {
                "response": caption[0]['generated_text'],
                "type": "caption",
                "additional_data": {
                    "image_size": f"{image.width}x{image.height}"
                }
            }

        elif intent == "visual-qa":
            if not file or not file.content_type.startswith('image/'):
                raise HTTPException(status_code=400, detail="An image file is required")
            if not text:
                raise HTTPException(status_code=400, detail="A question is required for VQA")
            image = Image.open(io.BytesIO(await file.read())).convert("RGB")
            vqa_pipeline = load_model("visual-qa")
            question = text.strip()
            if not question.endswith('?'):
                question += '?'
            answer = vqa_pipeline(
                image=image,
                question=question
            )
            answer = answer.strip()
            if not answer or answer.lower() == question.lower():
                logger.warning(f"VQA failed to generate a meaningful answer: {answer}")
                answer = "I couldn't determine the answer from the image."
            else:
                answer = answer.capitalize()
                if not answer.endswith(('.', '!', '?')):
                    answer += '.'
            factual_questions = ['color', 'size', 'number', 'how many', 'what is the']
            is_factual = any(keyword in question.lower() for keyword in factual_questions)
            if is_factual:
                final_answer = answer
            else:
                if GEMINI_AVAILABLE:
                    chatbot = load_model("chatbot")
                    if "fly" in question.lower():
                        final_answer = chatbot.generate_content(f"Make this fun and spacey: {answer}").text.strip()
                    else:
                        final_answer = chatbot.generate_content(f"Make this cosmic and poetic: {answer}").text.strip()
                else:
                    final_answer = answer
                    logger.warning("Gemini unavailable for enhancing VQA answer")
            logger.info(f"Final VQA answer: {final_answer}")
            return {
                "response": final_answer,
                "type": "visual_qa",
                "additional_data": {
                    "question": text,
                    "image_size": f"{image.width}x{image.height}"
                }
            }

        elif intent == "visualize":
            if not file:
                raise HTTPException(status_code=400, detail="An Excel file is required")
            file_content = await file.read()
            if file.filename.endswith('.csv'):
                df = pd.read_csv(io.BytesIO(file_content))
            else:
                df = pd.read_excel(io.BytesIO(file_content))
            code = generate_visualization_code(df, text)
            stats = df.describe().to_string()
            response = f"Stats:\n{stats}\n\nChart Code:\n{code}"
            return {"response": response, "type": "visualization_code"}

        elif intent == "text-generation":
            if GEMINI_AVAILABLE:
                response = get_gemini_response(text, is_generation=True)
                lines = response.split(". ")
                formatted_poem = "\n".join(line.strip() + ("." if not line.endswith(".") else "") for line in lines if line)
            else:
                response = "Text generation is unavailable without Gemini API."
                formatted_poem = response
            return {"response": formatted_poem, "type": "generated_text"}

        elif intent == "file-qa":
            if not file or not file.filename.lower().endswith(('.pdf', '.docx', '.doc', '.txt', '.rtf')):
                raise HTTPException(status_code=400, detail="A text-based file (PDF, DOCX, TXT, RTF) is required")
            if not text:
                raise HTTPException(status_code=400, detail="A question about the file is required")
            content = await extract_text_from_file(file)
            if not content.strip():
                raise HTTPException(status_code=400, detail="No text could be extracted from the file")
            qa_pipeline = load_model("file-qa")
            if qa_pipeline is None:
                logger.info("Using Gemini fallback for file-qa")
                if not GEMINI_AVAILABLE:
                    return {
                        "response": "File QA is unavailable without Gemini API or a working QA model.",
                        "type": "file_qa",
                        "additional_data": {
                            "question": text,
                            "file_name": file.filename
                        }
                    }
                question = text.strip()
                if not question.endswith('?'):
                    question += '?'
                response = get_gemini_response(f"Answer this question based on the following text: {content}\nQuestion: {question}")
                return {
                    "response": response,
                    "type": "file_qa",
                    "additional_data": {
                        "question": text,
                        "file_name": file.filename
                    }
                }
            question = text.strip()
            if not question.endswith('?'):
                question += '?'
            if len(content) > 512:
                chunks = [content[i:i+512] for i in range(0, len(content), 512)]
                answers = []
                for chunk in chunks[:3]:
                    result = qa_pipeline(question=question, context=chunk)
                    if result['score'] > 0.1:
                        answers.append((result['answer'], result['score']))
                if answers:
                    best_answer = max(answers, key=lambda x: x[1])[0]
                else:
                    best_answer = "I couldn't find a clear answer in the document."
            else:
                result = qa_pipeline(question=question, context=content)
                best_answer = result['answer'] if result['score'] > 0.1 else "I couldn't find a clear answer in the document."
            best_answer = best_answer.strip().capitalize()
            if not best_answer.endswith(('.', '!', '?')):
                best_answer += '.'
            try:
                if GEMINI_AVAILABLE:
                    chatbot = load_model("chatbot")
                    final_answer = chatbot.generate_content(f"Make this cosmic and poetic: {best_answer}").text.strip()
                else:
                    final_answer = best_answer
                    logger.warning("Gemini unavailable for enhancing file QA answer")
            except Exception as e:
                logger.warning(f"Failed to add cosmic tone: {str(e)}. Using raw answer.")
                final_answer = best_answer
            logger.info(f"File QA answer: {final_answer}")
            return {
                "response": final_answer,
                "type": "file_qa",
                "additional_data": {
                    "question": text,
                    "file_name": file.filename
                }
            }

        else:
            response = get_gemini_response(text or "Hello! How can I assist you?")
            return {"response": response, "type": "chat"}

    except Exception as e:
        logger.error(f"Processing error: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
    finally:
        process_time = time.time() - start_time
        logger.info(f"Request processed in {process_time:.2f} seconds")

async def extract_text_from_file(file: UploadFile) -> str:
    """Enhanced text extraction with multiple fallbacks"""
    if not file:
        return ""
    content = await file.read()
    filename = file.filename.lower()
    try:
        if filename.endswith('.pdf'):
            try:
                doc = fitz.open(stream=content, filetype="pdf")
                if doc.is_encrypted:
                    return "PDF is encrypted and cannot be read"
                text = ""
                for page in doc:
                    text += page.get_text()
                return text
            except Exception as pdf_error:
                logger.warning(f"PyMuPDF failed: {str(pdf_error)}. Trying pdfminer.six...")
                from pdfminer.high_level import extract_text
                from io import BytesIO
                return extract_text(BytesIO(content))
        elif filename.endswith(('.docx', '.doc')):
            doc = Document(io.BytesIO(content))
            return "\n".join(para.text for para in doc.paragraphs)
        elif filename.endswith('.txt'):
            return content.decode('utf-8', errors='replace')
        elif filename.endswith('.rtf'):
            text = content.decode('utf-8', errors='replace')
            text = re.sub(r'\\[a-z]+', ' ', text)
            text = re.sub(r'\{|\}|\\', '', text)
            return text
        else:
            raise HTTPException(status_code=400, detail=f"Unsupported file format: {filename}")
    except Exception as e:
        logger.error(f"File extraction error: {str(e)}")
        raise HTTPException(
            status_code=500,
            detail=f"Error extracting text: {str(e)}. Supported formats: PDF, DOCX, TXT, RTF"
        )
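
# Note: the PDF fallback above imports pdfminer.high_level lazily, so pdfminer.six
# must be installed for that path to work; if it is missing, the outer except turns
# the extraction failure into an HTTP 500.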

def generate_visualization_code(df: pd.DataFrame, request: str = None) -> str:
    """Generate visualization code based on data analysis"""
    num_rows, num_cols = df.shape
    numeric_cols = df.select_dtypes(include=[np.number]).columns.tolist()
    categorical_cols = df.select_dtypes(include=['object']).columns.tolist()
    date_cols = [col for col in df.columns if df[col].dtype == 'datetime64[ns]' or
                 (isinstance(df[col].dtype, np.dtype) and pd.to_datetime(df[col], errors='coerce').notna().all())]
    if request:
        request_lower = request.lower()
    else:
        request_lower = ""

    if len(numeric_cols) >= 2 and ("scatter" in request_lower or "correlation" in request_lower):
        x_col = numeric_cols[0]
        y_col = numeric_cols[1]
        return f"""import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

df = pd.read_excel('data.xlsx')
plt.figure(figsize=(10, 6))
sns.regplot(x='{x_col}', y='{y_col}', data=df, scatter_kws={{'alpha': 0.6}})
plt.title('Correlation between {x_col} and {y_col}')
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.savefig('correlation_plot.png')
plt.show()

correlation = df['{x_col}'].corr(df['{y_col}'])
print(f"Correlation coefficient: {{correlation:.4f}}")"""
    elif len(numeric_cols) >= 1 and len(categorical_cols) >= 1 and ("bar" in request_lower or "comparison" in request_lower):
        cat_col = categorical_cols[0]
        num_col = numeric_cols[0]
        return f"""import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

df = pd.read_excel('data.xlsx')
plt.figure(figsize=(12, 7))
ax = sns.barplot(x='{cat_col}', y='{num_col}', data=df, palette='viridis')
for p in ax.patches:
    ax.annotate(f'{{p.get_height():.1f}}',
                (p.get_x() + p.get_width() / 2., p.get_height()),
                ha='center', va='bottom', fontsize=10, color='black', xytext=(0, 5),
                textcoords='offset points')
plt.title('Comparison of {num_col} by {cat_col}', fontsize=15)
plt.xlabel('{cat_col}', fontsize=12)
plt.ylabel('{num_col}', fontsize=12)
plt.xticks(rotation=45, ha='right')
plt.grid(axis='y', alpha=0.3)
plt.tight_layout()
plt.savefig('comparison_chart.png')
plt.show()"""
    elif len(numeric_cols) >= 1 and ("distribution" in request_lower or "histogram" in request_lower):
        num_col = numeric_cols[0]
        return f"""import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

df = pd.read_excel('data.xlsx')
plt.figure(figsize=(10, 6))
sns.histplot(df['{num_col}'], kde=True, bins=20, color='purple')
plt.title('Distribution of {num_col}', fontsize=15)
plt.xlabel('{num_col}', fontsize=12)
plt.ylabel('Frequency', fontsize=12)
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.savefig('distribution_plot.png')
plt.show()

print(df['{num_col}'].describe())"""
    else:
        return f"""import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np

df = pd.read_excel('data.xlsx')
print("Descriptive statistics:")
print(df.describe())

fig, axes = plt.subplots(2, 2, figsize=(15, 12))
numeric_df = df.select_dtypes(include=[np.number])
if not numeric_df.empty and numeric_df.shape[1] > 1:
    sns.heatmap(numeric_df.corr(), annot=True, cmap='coolwarm', fmt='.2f', ax=axes[0, 0])
    axes[0, 0].set_title('Correlation Matrix')
if not numeric_df.empty:
    for i, col in enumerate(numeric_df.columns[:1]):
        sns.histplot(df[col], kde=True, ax=axes[0, 1], color='purple')
        axes[0, 1].set_title(f'Distribution of {{col}}')
        axes[0, 1].set_xlabel(col)
        axes[0, 1].set_ylabel('Frequency')
categorical_cols = df.select_dtypes(include=['object']).columns
if len(categorical_cols) > 0 and not numeric_df.empty:
    cat_col = categorical_cols[0]
    num_col = numeric_df.columns[0]
    sns.barplot(x=cat_col, y=num_col, data=df, ax=axes[1, 0], palette='viridis')
    axes[1, 0].set_title(f'{{num_col}} by {{cat_col}}')
    axes[1, 0].set_xticklabels(axes[1, 0].get_xticklabels(), rotation=45, ha='right')
if not numeric_df.empty and len(categorical_cols) > 0:
    cat_col = categorical_cols[0]
    num_col = numeric_df.columns[0]
    sns.boxplot(x=cat_col, y=num_col, data=df, ax=axes[1, 1], palette='Set3')
    axes[1, 1].set_title(f'Distribution of {{num_col}} by {{cat_col}}')
    axes[1, 1].set_xticklabels(axes[1, 1].get_xticklabels(), rotation=45, ha='right')
plt.tight_layout()
plt.savefig('dashboard.png')
plt.show()"""

async def home():
    """Redirect to the static index.html file"""
    return RedirectResponse(url="/static/index.html")


async def health_check():
    """Health check endpoint"""
    return {
        "status": "healthy",
        "version": "2.0.0",
        "gemini_available": GEMINI_AVAILABLE
    }


async def list_models():
    """List available models"""
    available_models = MODELS.copy()
    if not GEMINI_AVAILABLE:
        available_models["chatbot"] = "disabled (Gemini API unavailable)"
    return {"models": available_models}

@app.on_event("startup")
async def startup_event():
    """Pre-load models at startup with timeout"""
    global translation_model, translation_tokenizer
    logger.info("Starting model pre-loading...")

    async def load_model_with_timeout(task):
        try:
            await asyncio.wait_for(asyncio.to_thread(load_model, task), timeout=120.0)
            logger.info(f"Successfully loaded {task} model")
        except asyncio.TimeoutError:
            logger.warning(f"Timeout loading {task} model - will load on demand")
        except Exception as e:
            logger.error(f"Error pre-loading {task}: {str(e)}")

    try:
        model_name = MODELS["translation"]
        translation_model = M2M100ForConditionalGeneration.from_pretrained(model_name, cache_dir=os.getenv('HF_HOME'))
        translation_tokenizer = M2M100Tokenizer.from_pretrained(model_name, cache_dir=os.getenv('HF_HOME'))
        device = "cuda" if torch.cuda.is_available() else "cpu"
        translation_model.to(device)
        logger.info("Translation model pre-loaded successfully")
    except Exception as e:
        logger.error(f"Error pre-loading translation model: {str(e)}")

    tasks = ["summarization", "image-to-text", "visual-qa", "file-qa"]
    if GEMINI_AVAILABLE:
        tasks.append("chatbot")
    await asyncio.gather(*(load_model_with_timeout(task) for task in tasks))

if __name__ == "__main__":
    import uvicorn
    uvicorn.run("app:app", host="0.0.0.0", port=7860, reload=True)