# BLIPsinki2 / app.py
# Author: sophiaaez
# Revision 85ef56a ("Update app.py")
import functools

import requests
import torch
from PIL import Image
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode
# Run the captioning model on CUDA when PyTorch reports it is available,
# otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
import numpy as np
from transformers import pipeline
import gradio as gr
from models.blip import blip_decoder
# Side length in pixels of the square input fed to the BLIP model. It is used
# both by the preprocessing transform and by the model constructor below.
image_size = 384

# Preprocessing applied to every input image:
# 1. bicubic resize to image_size x image_size,
# 2. conversion to a tensor,
# 3. per-channel normalization with the three mean/std values below.
# These values presumably match the statistics the pretrained checkpoint was
# trained with; verify against the BLIP repository before changing them.
transform = transforms.Compose([
    transforms.Resize((image_size, image_size), interpolation=InterpolationMode.BICUBIC),
    transforms.ToTensor(),
    transforms.Normalize(
        (0.48145466, 0.4578275, 0.40821073),
        (0.26862954, 0.26130258, 0.27577711),
    ),
])

# Large-ViT BLIP captioning checkpoint. blip_decoder receives this URL as
# `pretrained` (presumably it downloads and loads the weights; see models.blip).
model_url = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large_caption.pth'
# Pass image_size instead of repeating the literal 384, so the model config and
# the preprocessing transform cannot silently disagree.
model = blip_decoder(pretrained=model_url, image_size=image_size, vit='large')
model.eval()  # inference only: put the model in eval mode
model = model.to(device)
# Helsinki-NLP Opus-MT English->X checkpoints, keyed by the language names
# offered in the UI. English has no entry: the caption is returned untranslated.
# 'Ukranian' is the spelling used by the UI choices; the correct spelling
# 'Ukrainian' is accepted as an alias.
_TRANSLATION_MODELS = {
    'German': "Helsinki-NLP/opus-mt-en-de",
    'French': "Helsinki-NLP/opus-mt-en-fr",
    'Spanish': "Helsinki-NLP/opus-mt-en-es",
    'Chinese': "Helsinki-NLP/opus-mt-en-zh",
    'Ukranian': "Helsinki-NLP/opus-mt-en-uk",
    'Ukrainian': "Helsinki-NLP/opus-mt-en-uk",
    'Swedish': "Helsinki-NLP/opus-mt-en-sv",
    'Arabic': "Helsinki-NLP/opus-mt-en-ar",
    'Italian': "Helsinki-NLP/opus-mt-en-it",
    'Hindi': "Helsinki-NLP/opus-mt-en-hi",
}


def getModelPath(language):
    """Return the translation model id for *language*, or None for English.

    Args:
        language: Display name of the target language, e.g. 'German'.

    Returns:
        The Hugging Face model id of the English->*language* Opus-MT model,
        or None when *language* is 'English' (no translation needed).

    Raises:
        ValueError: if *language* is not a supported language name.
    """
    if language == 'English':
        return None
    try:
        return _TRANSLATION_MODELS[language]
    except KeyError:
        # Previously an unknown name fell through every branch and raised an
        # obscure UnboundLocalError; report the real problem instead.
        raise ValueError(f"Unsupported language: {language!r}") from None
@functools.lru_cache(maxsize=None)
def _get_translator(model_path):
    """Build (once) and return the translation pipeline for *model_path*.

    Constructing a transformers pipeline loads model weights, which is far too
    slow to repeat on every request. The cache key space is bounded by the
    model ids getModelPath can return.
    """
    return pipeline("translation", model=model_path)


def inference(input_img, strategy, language):
    """Caption *input_img* with BLIP and optionally translate the caption.

    Args:
        input_img: PIL image from the Gradio image input.
        strategy: "Beam search" selects deterministic beam search (3 beams);
            any other value uses nucleus sampling with top_p=0.9.
        language: Target language name, resolved with getModelPath. For
            'English' the caption is returned untranslated.

    Returns:
        The (possibly translated) caption as a str.

    Raises:
        ValueError: if no image was supplied.
    """
    if input_img is None:
        raise ValueError("Please provide an input image.")
    # The transform normalizes exactly three channels, so convert RGBA /
    # grayscale inputs to RGB first instead of failing inside Normalize.
    image = transform(input_img.convert("RGB")).unsqueeze(0).to(device)
    with torch.no_grad():
        if strategy == "Beam search":
            cap = model.generate(image, sample=False, num_beams=3, max_length=20, min_length=5)
        else:
            cap = model.generate(image, sample=True, top_p=0.9, max_length=20, min_length=5)
    # generate() output is indexed with [0]; presumably one caption per image
    # in the batch of one (confirm against models.blip).
    caption = cap[0]
    modelpath = getModelPath(language)
    if not modelpath:
        return str(caption)
    # Reuse the cached pipeline rather than reloading the model per request.
    translated = _get_translator(modelpath)(caption)
    return str(translated[0]['translation_text'])
# Gradio UI wiring. The leftover debug print("HI") that used to sit here has
# been removed.
description = "A pipeline of BLIP image captioning and Helsinki translation in order to generate image captions in a language of your choice either with beam search (deterministic) or nucleus sampling (stochastic). Enjoy! Is the language you want to use missing? Let me know and I'll integrate it."

# NOTE(review): gr.inputs / gr.outputs are the legacy Gradio namespaces
# (removed in Gradio 4); migrating requires confirming the pinned Gradio version.
inputs_ = [
    gr.inputs.Image(type='pil', label="Input Image"),
    gr.inputs.Radio(
        choices=['Beam search', 'Nucleus sampling'],
        type="value",
        default="Nucleus sampling",
        label="Mode",
    ),
    # These names are passed to getModelPath; keep them in sync with it
    # ('Ukranian' is the existing spelling it recognizes).
    gr.inputs.Radio(
        choices=['English', 'German', 'French', 'Spanish', 'Chinese', 'Ukranian', 'Swedish', 'Arabic', 'Italian', 'Hindi'],
        type="value",
        default='German',
        label="Language",
    ),
]
outputs_ = gr.outputs.Textbox(label="Output")

iface = gr.Interface(inference, inputs_, outputs_, description=description)
# show_error=True displays exceptions raised by `inference` in the web UI.
iface.launch(debug=True, show_error=True)