import gradio as gr
import torch
from transformers.utils import logging
from googletrans import Translator
from transformers import ViltProcessor, ViltForQuestionAnswering, T5Tokenizer, T5ForConditionalGeneration
logging.set_verbosity_info()
logger = logging.get_logger("transformers")
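# Intermediate translation, VQA, and T5 results are logged at INFO level for debugging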
# Translation helper (googletrans): translate text and report the detected source language
def google_translate(question, dest):
    translator = Translator()
    translation = translator.translate(question, dest=dest)
    logger.info("Translation text: " + translation.text)
    logger.info("Translation src: " + translation.src)
    return translation.text, translation.src
# Load ViLT fine-tuned for visual question answering
vilt_processor = ViltProcessor.from_pretrained("dandelin/vilt-b32-finetuned-vqa")
vilt_model = ViltForQuestionAnswering.from_pretrained("dandelin/vilt-b32-finetuned-vqa")

def vilt_vqa(image, question):
    # Encode the image/question pair and return the highest-scoring answer label
    inputs = vilt_processor(image, question, return_tensors="pt")
    with torch.no_grad():
        outputs = vilt_model(**inputs)
    logits = outputs.logits
    idx = logits.argmax(-1).item()
    answer = vilt_model.config.id2label[idx]
    logger.info("ViLT: " + answer)
    return answer
# Load FLAN-T5 to rewrite the short VQA label into a full-sentence answer
t5_tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-large")
t5_model = T5ForConditionalGeneration.from_pretrained("google/flan-t5-large", device_map="auto")

def flan_t5_complete_sentence(question, answer):
    input_text = f"A question: {question} An answer: {answer}. Based on these, answer the question with a complete sentence without extra information."
    logger.info("T5 input: " + input_text)
    # Keep the inputs on the same device the model was dispatched to by device_map="auto"
    inputs = t5_tokenizer(input_text, return_tensors="pt").to(t5_model.device)
    outputs = t5_model.generate(**inputs, max_length=50)
    result_sentence = t5_tokenizer.batch_decode(outputs, skip_special_tokens=True)
    logger.info("T5 output: " + str(result_sentence))
    return result_sentence
# Main pipeline: translate the question to English, run VQA, expand to a sentence, translate back
def vqa_main(image, question):
    en_question, question_src_lang = google_translate(question, dest='en')
    vqa_answer = vilt_vqa(image, en_question)
    llm_answer = flan_t5_complete_sentence(en_question, vqa_answer)[0]
    final_answer, answer_src_lang = google_translate(llm_answer, dest=question_src_lang)
    logger.info("Final Answer: " + final_answer)
    return final_answer
# Home page text
title = "Interactive demo: Multilingual VQA"
description = "Demo for multilingual visual question answering. Upload an image and type a question in any language, then click 'Submit', or click one of the examples to load them."
article = "article goes here"
# Load example images
torch.hub.download_url_to_file('http://farm3.staticflickr.com/2710/4520550856_7a9f9ea59d_z.jpg', 'apple.jpg')
torch.hub.download_url_to_file('http://images.cocodataset.org/val2017/000000039769.jpg', 'cats.jpg')
# Define home page variables
# Note: gr.inputs / gr.outputs is the legacy Gradio component namespace; newer Gradio releases use gr.Image / gr.Textbox directly
image = gr.inputs.Image(type="pil")
question = gr.inputs.Textbox(label="Question")
answer = gr.outputs.Textbox(label="Predicted answer")
examples = [["apple.jpg", "Qu'est-ce que c'est dans ma main?"], ["cats.jpg", "What are the cats doing?"]]

interface = gr.Interface(fn=vqa_main,
                         inputs=[image, question],
                         outputs=answer,
                         examples=examples,
                         title=title,
                         description=description,
                         article=article,
                         enable_queue=True)
interface.launch(debug=True, show_error=True)