from fastapi import FastAPI, UploadFile, File, Form, HTTPException, Request
from fastapi.staticfiles import StaticFiles
from fastapi.responses import RedirectResponse, JSONResponse, HTMLResponse
from transformers import pipeline, ViltProcessor, ViltForQuestionAnswering, M2M100ForConditionalGeneration, M2M100Tokenizer
from typing import Optional, Dict, Any, List
import logging
import time
import os
import io
import json
import re
from PIL import Image
from docx import Document
import fitz  # PyMuPDF
import pandas as pd
from functools import lru_cache
import torch
import numpy as np
from pydantic import BaseModel
import asyncio
import google.generativeai as genai

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger("cosmic_ai")

# Create the upload directory if it doesn't exist
upload_dir = os.getenv('UPLOAD_DIR', '/tmp/uploads')
os.makedirs(upload_dir, exist_ok=True)

app = FastAPI(
    title="Cosmic AI Assistant",
    description="An advanced AI assistant with space-themed interface and translation features",
    version="2.0.0"
)

# Mount static files
app.mount("/static", StaticFiles(directory="static"), name="static")
# Gemini API Configuration | |
API_KEY = "AIzaSyCwmgD8KxzWiuivtySNtcZF_rfTvx9s9sY" # Replace with your actual API key | |
genai.configure(api_key=API_KEY) | |

# Model configurations
MODELS = {
    "summarization": "sshleifer/distilbart-cnn-12-6",
    "image-to-text": "Salesforce/blip-image-captioning-large",
    "visual-qa": "dandelin/vilt-b32-finetuned-vqa",
    "chatbot": "gemini-1.5-pro",  # Handles both chat and text generation
    "translation": "facebook/m2m100_418M"
}

# Supported languages for translation
SUPPORTED_LANGUAGES = {
    "english": "en",
    "french": "fr",
    "german": "de",
    "spanish": "es",
    "italian": "it",
    "russian": "ru",
    "chinese": "zh",
    "japanese": "ja",
    "arabic": "ar",
    "hindi": "hi",
    "portuguese": "pt",
    "korean": "ko"
}

# Global variables for the pre-loaded translation model
translation_model = None
translation_tokenizer = None

# Cache model loading (excluding translation, which is pre-loaded at startup)
@lru_cache(maxsize=None)
def load_model(task: str, model_name: str = None):
    """Cached model loader with proper task names and error handling"""
    try:
        logger.info(f"Loading model for task: {task}, model: {model_name or MODELS.get(task)}")
        start_time = time.time()
        model_to_load = model_name or MODELS.get(task)
        if task == "chatbot":  # Gemini handles both chat and text generation
            return genai.GenerativeModel(model_to_load)
        if task == "visual-qa":
            processor = ViltProcessor.from_pretrained(model_to_load)
            model = ViltForQuestionAnswering.from_pretrained(model_to_load)
            device = "cuda" if torch.cuda.is_available() else "cpu"
            model.to(device)

            def vqa_function(image, question, **generate_kwargs):
                if image.mode != "RGB":
                    image = image.convert("RGB")
                inputs = processor(image, question, return_tensors="pt").to(device)
                logger.info(f"VQA inputs - question: {question}, image size: {image.size}")
                with torch.no_grad():
                    outputs = model(**inputs)
                logits = outputs.logits
                idx = logits.argmax(-1).item()
                answer = model.config.id2label[idx]
                logger.info(f"VQA raw output: {answer}")
                return answer

            return vqa_function
        return pipeline(task, model=model_to_load)
    except Exception as e:
        logger.error(f"Model load failed: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Model loading failed: {task} - {str(e)}")
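
# Usage sketch (illustrative; runs inside this app's process):
#   summarizer = load_model("summarization")   # cached by lru_cache after the first call
#   print(summarizer("A long article ...")[0]["summary_text"])
#   vqa = load_model("visual-qa")
#   vqa(Image.open("photo.jpg"), "What color is the car?")   # returns a short answer string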


def get_gemini_response(user_input: str, is_generation: bool = False):
    """Generate a response with Gemini for both chat and text generation"""
    if not user_input:
        return "Please provide some input."
    try:
        chatbot = load_model("chatbot")
        if is_generation:
            prompt = f"Generate creative text based on this prompt: {user_input}"
        else:
            prompt = user_input
        response = chatbot.generate_content(prompt)
        return response.text.strip()
    except Exception as e:
        return f"Error: {str(e)}"


def translate_text(text: str, target_language: str):
    """Translate text to any supported target language using the pre-loaded M2M100 model"""
    if not text:
        return "Please provide text to translate."
    try:
        global translation_model, translation_tokenizer
        target_lang = target_language.lower()
        if target_lang not in SUPPORTED_LANGUAGES:
            similar = [lang for lang in SUPPORTED_LANGUAGES if target_lang in lang or lang in target_lang]
            if similar:
                target_lang = similar[0]
            else:
                return f"Language '{target_language}' not supported. Available languages: {', '.join(SUPPORTED_LANGUAGES.keys())}"
        lang_code = SUPPORTED_LANGUAGES[target_lang]
        if translation_model is None or translation_tokenizer is None:
            raise Exception("Translation model not initialized")
        # Extract the phrase to translate from common request phrasings
        match = re.search(r'how to say\s+(.+?)\s+in\s+(\w+)', text.lower())
        if match:
            text_to_translate = match.group(1)
        else:
            content_match = re.search(r'(?:translate|convert).*to\s+[a-zA-Z]+\s*[:\s]*(.+)', text, re.IGNORECASE)
            text_to_translate = content_match.group(1) if content_match else text
        translation_tokenizer.src_lang = "en"  # source text is assumed to be English
        encoded = translation_tokenizer(text_to_translate, return_tensors="pt", padding=True, truncation=True).to(translation_model.device)
        start_time = time.time()
        # Greedy decoding (num_beams=1) for speed; early_stopping only applies to beam search
        generated_tokens = translation_model.generate(
            **encoded,
            forced_bos_token_id=translation_tokenizer.get_lang_id(lang_code),
            max_length=512,
            num_beams=1
        )
        translated_text = translation_tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]
        logger.info(f"Translation took {time.time() - start_time:.2f} seconds")
        return translated_text
    except Exception as e:
        logger.error(f"Translation error: {str(e)}", exc_info=True)
        return f"Translation error: {str(e)}"


def detect_intent(text: str = None, file: UploadFile = None) -> tuple[str, str]:
    """Enhanced intent detection with dynamic translation support"""
    target_language = "English"  # Default
    if file:
        content_type = file.content_type.lower() if file.content_type else ""
        filename = file.filename.lower() if file.filename else ""
        # Catch "what's this" (curly apostrophe) and "does this fly" first for images
        if content_type.startswith('image/') and text:
            text_lower = text.lower()
            if "what’s this" in text_lower:
                return "visual-qa", target_language
            if "does this fly" in text_lower:
                return "visual-qa", target_language
            # Broaden "fly" questions for VQA
            if "fly" in text_lower and any(q in text_lower for q in ['does', 'can', 'will']):
                return "visual-qa", target_language
        if content_type.startswith('image/'):
            if text and any(q in text.lower() for q in ['what is', 'what\'s', 'describe', 'tell me about', 'explain', 'how many', 'what color', 'is there', 'are they', 'does the']):
                return "visual-qa", target_language
            return "image-to-text", target_language
        elif filename.endswith(('.xlsx', '.xls', '.csv')):
            return "visualize", target_language
        elif filename.endswith(('.pdf', '.docx', '.doc', '.txt', '.rtf')):
            return "summarize", target_language
    if not text:
        return "chatbot", target_language
    text_lower = text.lower()
    if any(keyword in text_lower for keyword in ['chat', 'talk', 'converse', 'ask gemini']):
        return "chatbot", target_language
    translate_patterns = [
        r'translate.*to\s+\[?([a-zA-Z]+)\]?:?\s*(.*)',
        r'convert.*to\s+\[?([a-zA-Z]+)\]?:?\s*(.*)',
        r'how to say.*in\s+\[?([a-zA-Z]+)\]?:?\s*(.*)'
    ]
    for pattern in translate_patterns:
        translate_match = re.search(pattern, text_lower)
        if translate_match:
            potential_lang = translate_match.group(1).lower()
            if potential_lang in SUPPORTED_LANGUAGES:
                target_language = potential_lang.capitalize()
                return "translate", target_language
            else:
                logger.warning(f"Invalid language detected: {potential_lang}")
                return "chatbot", target_language
    vqa_patterns = [
        r'how (many|much)',
        r'what (color|size|position|shape)',
        r'is (there|that|this) (a|an)',
        r'are (they|there) (any|some)',
        r'does (the|this) (image|picture) (show|contain)'
    ]
    if any(re.search(pattern, text_lower) for pattern in vqa_patterns):
        return "visual-qa", target_language
    summarization_patterns = [
        r'\b(summar(y|ize|ise)|brief( overview)?)\b',
        r'\b(long article|text|document)\b',
        r'\bcan you (summar|brief|condense)\b',
        r'\b(short summary|brief explanation)\b',
        r'\b(overview|main points|key ideas)\b',
        r'\b(tl;?dr|too long didn\'?t read)\b'
    ]
    if any(re.search(pattern, text_lower) for pattern in summarization_patterns):
        return "summarize", target_language
    generation_patterns = [
        r'\b(write|generate|create|compose)\b',
        r'\b(story|poem|essay|text|content)\b'
    ]
    if any(re.search(pattern, text_lower) for pattern in generation_patterns):
        return "text-generation", target_language
    if len(text) > 100:
        return "summarize", target_language
    # Image files are fully handled in the `if file:` block above, so no
    # image-specific checks are needed here.
    return "chatbot", target_language


class ProcessResponse(BaseModel):
    response: str
    type: str
    additional_data: Optional[Dict[str, Any]] = None


@app.get("/chatbot")  # route path assumed
async def chatbot_interface():
    """Redirect to the static index.html file for the chatbot interface"""
    return RedirectResponse(url="/static/index.html")


@app.post("/chat")  # route path assumed
async def chat_endpoint(data: dict):
    message = data.get("message", "")
    if not message:
        raise HTTPException(status_code=400, detail="No message provided")
    try:
        response = get_gemini_response(message)
        return {"response": response}
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Chat error: {str(e)}")


@app.post("/process")  # route path assumed
async def process_input(
    request: Request,
    text: str = Form(None),
    file: UploadFile = File(None)
):
    """Enhanced unified endpoint with dynamic translation"""
    start_time = time.time()
    client_ip = request.client.host
    logger.info(f"Request from {client_ip}: text={text[:50] + '...' if text and len(text) > 50 else text}, file={file.filename if file else None}")
    intent, target_language = detect_intent(text, file)
    logger.info(f"Detected intent: {intent}, target_language: {target_language}")
    try:
        if intent == "chatbot":
            response = get_gemini_response(text)
            return {"response": response, "type": "chat"}
        elif intent == "translate":
            content = await extract_text_from_file(file) if file else text
            if "all languages" in text.lower():
                translations = {}
                phrase_to_translate = "I want to explore the stars" if "I want to explore the stars" in text else content
                for lang, code in SUPPORTED_LANGUAGES.items():
                    translation_tokenizer.src_lang = "en"
                    encoded = translation_tokenizer(phrase_to_translate, return_tensors="pt").to(translation_model.device)
                    generated_tokens = translation_model.generate(
                        **encoded,
                        forced_bos_token_id=translation_tokenizer.get_lang_id(code),
                        max_length=512,
                        num_beams=1
                    )
                    translations[lang] = translation_tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]
                response = "\n".join(f"{lang.capitalize()}: {translations[lang]}" for lang in translations)
                logger.info(f"Translated to all supported languages: {', '.join(translations.keys())}")
                return {"response": response, "type": "translation"}
            else:
                translated_text = translate_text(content, target_language)
                return {"response": translated_text, "type": "translation"}
        elif intent == "summarize":
            content = await extract_text_from_file(file) if file else text
            summarizer = load_model("summarization")
            content_length = len(content.split())
            max_len = max(30, min(150, content_length // 2))
            min_len = max(15, min(30, max_len // 2))
            if len(content) > 1024:
                chunks = [content[i:i + 1024] for i in range(0, len(content), 1024)]
                summaries = []
                for chunk in chunks[:3]:
                    summary = summarizer(
                        chunk,
                        max_length=max_len,
                        min_length=min_len,
                        do_sample=False,
                        truncation=True
                    )
                    summaries.append(summary[0]['summary_text'])
                final_summary = " ".join(summaries)
            else:
                summary = summarizer(
                    content,
                    max_length=max_len,
                    min_length=min_len,
                    do_sample=False,
                    truncation=True
                )
                final_summary = summary[0]['summary_text']
            final_summary = re.sub(r'\s+', ' ', final_summary).strip()
            return {"response": final_summary, "type": "summary"}
elif intent == "image-to-text": | |
if not file or not file.content_type.startswith('image/'): | |
raise HTTPException(status_code=400, detail="An image file is required") | |
image = Image.open(io.BytesIO(await file.read())) | |
captioner = load_model("image-to-text") | |
caption = captioner(image, max_new_tokens=50) | |
return {"response": caption[0]['generated_text'], "type": "caption"} | |
elif intent == "visual-qa": | |
if not file or not file.content_type.startswith('image/'): | |
raise HTTPException(status_code=400, detail="An image file is required") | |
if not text: | |
raise HTTPException(status_code=400, detail="A question is required for VQA") | |
image = Image.open(io.BytesIO(await file.read())).convert("RGB") | |
vqa_pipeline = load_model("visual-qa") | |
question = text.strip() | |
if not question.endswith('?'): | |
question += '?' | |
answer = vqa_pipeline( | |
image=image, | |
question=question | |
) | |
answer = answer.strip() | |
if not answer or answer.lower() == question.lower(): | |
logger.warning(f"VQA failed to generate a meaningful answer: {answer}") | |
answer = "I couldn't determine the answer from the image." | |
else: | |
answer = answer.capitalize() | |
if not answer.endswith(('.', '!', '?')): | |
answer += '.' | |
chatbot = load_model("chatbot") | |
if "fly" in question.lower(): | |
answer = chatbot.generate_content(f"Make this fun and spacey: {answer}").text.strip() | |
else: | |
answer = chatbot.generate_content(f"Make this cosmic and poetic: {answer}").text.strip() | |
logger.info(f"Final VQA answer: {answer}") | |
return { | |
"response": answer, | |
"type": "visual_qa", | |
"additional_data": { | |
"question": text, | |
"image_size": f"{image.width}x{image.height}" | |
} | |
} | |
elif intent == "visualize": | |
if not file: | |
raise HTTPException(status_code=400, detail="An Excel file is required") | |
file_content = await file.read() | |
if file.filename.endswith('.csv'): | |
df = pd.read_csv(io.BytesIO(file_content)) | |
else: | |
df = pd.read_excel(io.BytesIO(file_content)) | |
code = generate_visualization_code(df, text) | |
stats = df.describe().to_string() | |
response = f"Stats:\n{stats}\n\nChart Code:\n{code}" | |
return {"response": response, "type": "visualization_code"} | |
elif intent == "text-generation": | |
response = get_gemini_response(text, is_generation=True) | |
lines = response.split(". ") | |
formatted_poem = "\n".join(line.strip() + ("." if not line.endswith(".") else "") for line in lines if line) | |
return {"response": formatted_poem, "type": "generated_text"} | |
else: | |
response = get_gemini_response(text or "Hello! How can I assist you?") | |
return {"response": response, "type": "chat"} | |
except Exception as e: | |
logger.error(f"Processing error: {str(e)}", exc_info=True) | |
raise HTTPException(status_code=500, detail=str(e)) | |
finally: | |
process_time = time.time() - start_time | |
logger.info(f"Request processed in {process_time:.2f} seconds") | |


async def extract_text_from_file(file: UploadFile) -> str:
    """Enhanced text extraction with multiple fallbacks"""
    if not file:
        return ""
    content = await file.read()
    filename = file.filename.lower()
    try:
        if filename.endswith('.pdf'):
            try:
                doc = fitz.open(stream=content, filetype="pdf")
                if doc.is_encrypted:
                    return "PDF is encrypted and cannot be read"
                text = ""
                for page in doc:
                    text += page.get_text()
                return text
            except Exception as pdf_error:
                logger.warning(f"PyMuPDF failed: {str(pdf_error)}. Trying pdfminer.six...")
                from pdfminer.high_level import extract_text
                from io import BytesIO
                return extract_text(BytesIO(content))
        elif filename.endswith(('.docx', '.doc')):
            doc = Document(io.BytesIO(content))
            return "\n".join(para.text for para in doc.paragraphs)
        elif filename.endswith('.txt'):
            return content.decode('utf-8', errors='replace')
        elif filename.endswith('.rtf'):
            # Crude RTF handling: strip control words and braces
            text = content.decode('utf-8', errors='replace')
            text = re.sub(r'\\[a-z]+', ' ', text)
            text = re.sub(r'\{|\}|\\', '', text)
            return text
        else:
            raise HTTPException(status_code=400, detail=f"Unsupported file format: {filename}")
    except Exception as e:
        logger.error(f"File extraction error: {str(e)}", exc_info=True)
        raise HTTPException(
            status_code=500,
            detail=f"Error extracting text: {str(e)}. Supported formats: PDF, DOCX, TXT, RTF"
        )
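
# Usage sketch (illustrative; `uploaded_pdf` stands for a FastAPI UploadFile):
#   text = await extract_text_from_file(uploaded_pdf)   # returns the extracted plain text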


def generate_visualization_code(df: pd.DataFrame, request: str = None) -> str:
    """Generate visualization code based on data analysis"""
    num_rows, num_cols = df.shape
    numeric_cols = df.select_dtypes(include=[np.number]).columns.tolist()
    categorical_cols = df.select_dtypes(include=['object']).columns.tolist()
    # Datetime columns: native datetime dtype, or object columns fully parseable as dates
    date_cols = [col for col in df.columns if df[col].dtype == 'datetime64[ns]' or
                 (df[col].dtype == object and pd.to_datetime(df[col], errors='coerce').notna().all())]
    request_lower = request.lower() if request else ""
    if len(numeric_cols) >= 2 and ("scatter" in request_lower or "correlation" in request_lower):
        x_col = numeric_cols[0]
        y_col = numeric_cols[1]
        return f"""import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

df = pd.read_excel('data.xlsx')
plt.figure(figsize=(10, 6))
sns.regplot(x='{x_col}', y='{y_col}', data=df, scatter_kws={{'alpha': 0.6}})
plt.title('Correlation between {x_col} and {y_col}')
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.savefig('correlation_plot.png')
plt.show()

correlation = df['{x_col}'].corr(df['{y_col}'])
print(f"Correlation coefficient: {{correlation:.4f}}")"""
    elif len(numeric_cols) >= 1 and len(categorical_cols) >= 1 and ("bar" in request_lower or "comparison" in request_lower):
        cat_col = categorical_cols[0]
        num_col = numeric_cols[0]
        return f"""import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

df = pd.read_excel('data.xlsx')
plt.figure(figsize=(12, 7))
ax = sns.barplot(x='{cat_col}', y='{num_col}', data=df, palette='viridis')
for p in ax.patches:
    ax.annotate(f'{{p.get_height():.1f}}',
                (p.get_x() + p.get_width() / 2., p.get_height()),
                ha='center', va='bottom', fontsize=10, color='black', xytext=(0, 5),
                textcoords='offset points')
plt.title('Comparison of {num_col} by {cat_col}', fontsize=15)
plt.xlabel('{cat_col}', fontsize=12)
plt.ylabel('{num_col}', fontsize=12)
plt.xticks(rotation=45, ha='right')
plt.grid(axis='y', alpha=0.3)
plt.tight_layout()
plt.savefig('comparison_chart.png')
plt.show()"""
    elif len(numeric_cols) >= 1 and ("distribution" in request_lower or "histogram" in request_lower):
        num_col = numeric_cols[0]
        return f"""import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

df = pd.read_excel('data.xlsx')
plt.figure(figsize=(10, 6))
sns.histplot(df['{num_col}'], kde=True, bins=20, color='purple')
plt.title('Distribution of {num_col}', fontsize=15)
plt.xlabel('{num_col}', fontsize=12)
plt.ylabel('Frequency', fontsize=12)
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.savefig('distribution_plot.png')
plt.show()

print(df['{num_col}'].describe())"""
    else:
        return f"""import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np

df = pd.read_excel('data.xlsx')
print("Descriptive statistics:")
print(df.describe())

fig, axes = plt.subplots(2, 2, figsize=(15, 12))
numeric_df = df.select_dtypes(include=[np.number])
if not numeric_df.empty and numeric_df.shape[1] > 1:
    sns.heatmap(numeric_df.corr(), annot=True, cmap='coolwarm', fmt='.2f', ax=axes[0, 0])
    axes[0, 0].set_title('Correlation Matrix')
if not numeric_df.empty:
    for i, col in enumerate(numeric_df.columns[:1]):
        sns.histplot(df[col], kde=True, ax=axes[0, 1], color='purple')
        axes[0, 1].set_title(f'Distribution of {{col}}')
        axes[0, 1].set_xlabel(col)
        axes[0, 1].set_ylabel('Frequency')
categorical_cols = df.select_dtypes(include=['object']).columns
if len(categorical_cols) > 0 and not numeric_df.empty:
    cat_col = categorical_cols[0]
    num_col = numeric_df.columns[0]
    sns.barplot(x=cat_col, y=num_col, data=df, ax=axes[1, 0], palette='viridis')
    axes[1, 0].set_title(f'{{num_col}} by {{cat_col}}')
    axes[1, 0].set_xticklabels(axes[1, 0].get_xticklabels(), rotation=45, ha='right')
if not numeric_df.empty and len(categorical_cols) > 0:
    cat_col = categorical_cols[0]
    num_col = numeric_df.columns[0]
    sns.boxplot(x=cat_col, y=num_col, data=df, ax=axes[1, 1], palette='Set3')
    axes[1, 1].set_title(f'Distribution of {{num_col}} by {{cat_col}}')
    axes[1, 1].set_xticklabels(axes[1, 1].get_xticklabels(), rotation=45, ha='right')
plt.tight_layout()
plt.savefig('dashboard.png')
plt.show()"""
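
# Usage sketch (illustrative):
#   df = pd.DataFrame({"team": ["a", "b"], "score": [1, 2]})
#   print(generate_visualization_code(df, "bar comparison"))  # emits a matplotlib/seaborn script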


@app.get("/")
async def home():
    """Redirect to the static index.html file"""
    return RedirectResponse(url="/static/index.html")


@app.get("/health")  # route path assumed
async def health_check():
    """Health check endpoint"""
    return {"status": "healthy", "version": "2.0.0"}


@app.get("/models")  # route path assumed
async def list_models():
    """List available models"""
    return {"models": MODELS}


@app.on_event("startup")
async def startup_event():
    """Pre-load models at startup with timeout"""
    global translation_model, translation_tokenizer
    logger.info("Starting model pre-loading...")

    async def load_model_with_timeout(task):
        try:
            await asyncio.wait_for(asyncio.to_thread(load_model, task), timeout=60.0)
            logger.info(f"Successfully loaded {task} model")
        except asyncio.TimeoutError:
            logger.warning(f"Timeout loading {task} model - will load on demand")
        except Exception as e:
            logger.error(f"Error pre-loading {task}: {str(e)}")

    try:
        model_name = MODELS["translation"]
        translation_model = M2M100ForConditionalGeneration.from_pretrained(model_name)
        translation_tokenizer = M2M100Tokenizer.from_pretrained(model_name)
        device = "cuda" if torch.cuda.is_available() else "cpu"
        translation_model.to(device)
        logger.info("Translation model pre-loaded successfully")
    except Exception as e:
        logger.error(f"Error pre-loading translation model: {str(e)}")

    await asyncio.gather(
        load_model_with_timeout("summarization"),
        load_model_with_timeout("image-to-text"),
        load_model_with_timeout("visual-qa"),
        load_model_with_timeout("chatbot")
    )


if __name__ == "__main__":
    import uvicorn
    uvicorn.run("app:app", host="0.0.0.0", port=7860, reload=True)