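"""Gradio app: BLIP image captioning, BLIP visual question answering, and
MarianMT translation of generated captions into eight target languages."""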
from PIL import Image
from transformers import BlipProcessor, BlipForConditionalGeneration, BlipForQuestionAnswering, MarianMTModel, MarianTokenizer
import gradio as gr
import torch
import warnings
warnings.filterwarnings("ignore")
# Load BLIP models
captioning_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
captioning_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")
vqa_processor = BlipProcessor.from_pretrained("Salesforce/blip-vqa-base")
vqa_model = BlipForQuestionAnswering.from_pretrained("Salesforce/blip-vqa-base")
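# from_pretrained downloads and caches each checkpoint on first use, so the
# first launch needs network access and may take a while.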
# Dictionary to store translation models and tokenizers for different languages
translation_models = {
"Spanish": 'Helsinki-NLP/opus-mt-en-es',
"German": 'Helsinki-NLP/opus-mt-en-de',
"Chinese": 'Helsinki-NLP/opus-mt-en-zh',
"Japanese": 'Helsinki-NLP/opus-mt-en-ja',
"Russian": 'Helsinki-NLP/opus-mt-en-ru',
"Arabic": 'Helsinki-NLP/opus-mt-en-ar',
"Hindi": 'Helsinki-NLP/opus-mt-en-hi',
"Urdu": 'Helsinki-NLP/opus-mt-en-ur'
}
# Load translation models and tokenizers
loaded_translation_models = {}
loaded_translation_tokenizers = {}
for lang, model_name in translation_models.items():
    try:
        loaded_translation_models[lang] = MarianMTModel.from_pretrained(model_name)
        loaded_translation_tokenizers[lang] = MarianTokenizer.from_pretrained(model_name)
        print(f"Successfully loaded translation model for {lang}")
    except Exception as e:
        print(f"Error loading model for {lang}: {e}")
# Captioning function
def caption(image):
    """Generate an English caption for a PIL image with BLIP."""
    image = image.convert("RGB")
    inputs = captioning_processor(image, return_tensors="pt")
    out = captioning_model.generate(**inputs)
    return captioning_processor.decode(out[0], skip_special_tokens=True)
# Visual Question Answering function
def qna(image, question):
    """Answer a free-form question about a PIL image with BLIP VQA."""
    image = image.convert("RGB")
    inputs = vqa_processor(image, question, return_tensors="pt")
    out = vqa_model.generate(**inputs)
    return vqa_processor.decode(out[0], skip_special_tokens=True)
# Translation function
def translate_text(text, target_lang="Spanish"):
    """Translate English text into target_lang with the matching MarianMT model."""
    mt_model = loaded_translation_models.get(target_lang)
    mt_tokenizer = loaded_translation_tokenizers.get(target_lang)
    if mt_model is None or mt_tokenizer is None:
        return f"Translation model for {target_lang} is not available."
    inputs = mt_tokenizer(text, return_tensors="pt")
    translated = mt_model.generate(**inputs)
    return mt_tokenizer.decode(translated[0], skip_special_tokens=True)
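# Quick sanity check (hypothetical snippet, not part of the app's UI flow):
#   print(translate_text("a dog running on the beach", "German"))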
# Combined Captioning and Translation function
def caption_and_translate(image, target_lang="Spanish"):
    """Caption an image in English, then translate the caption into target_lang."""
    caption_text = caption(image)
    print(f"Generated caption: {caption_text}")
    translated_caption = translate_text(caption_text, target_lang)
    print(f"Translated caption: {translated_caption}")
    return caption_text, translated_caption
# Create Gradio interfaces
interface1 = gr.Interface(fn=caption,
                          inputs=gr.components.Image(type="pil"),
                          outputs=gr.components.Textbox(label="Generated Caption by BLIP"),
                          description="BLIP Image Captioning")
interface2 = gr.Interface(fn=qna,
                          inputs=[gr.components.Image(type="pil"), gr.components.Textbox(label="Question")],
                          outputs=gr.components.Textbox(label="Answer generated by BLIP"),
                          description="BLIP Visual Question Answering")
interface3 = gr.Interface(fn=caption_and_translate,
                          inputs=[gr.components.Image(type="pil"),
                                  gr.components.Dropdown(label="Target Language",
                                                         choices=list(translation_models))],
                          outputs=[gr.components.Textbox(label="Generated Caption"),
                                   gr.components.Textbox(label="Translated Caption")],
                          description="Image Captioning and Translation")
title = "Automated Image Captioning and Visual QnA Engine"
final_interface = gr.TabbedInterface([interface1, interface2, interface3],
                                     ["Captioning", "Visual QnA", "Captioning and Translation"],
                                     title=title, theme=gr.themes.Soft())

if __name__ == "__main__":
    final_interface.launch(inbrowser=True)