import gradio as gr
import torch
from transformers.utils import logging
from googletrans import Translator
from transformers import ViltProcessor, ViltForQuestionAnswering, T5Tokenizer, T5ForConditionalGeneration

logging.set_verbosity_info()
logger = logging.get_logger("transformers")
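# Pipeline overview: the question is translated to English with googletrans,
# answered by ViLT (which treats VQA as classification over a fixed answer
# vocabulary), expanded into a full sentence by FLAN-T5, and finally translated
# back into the language the question was asked in.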
# Translation
def google_translate(question, dest):
    """Translate text with googletrans; returns (translated_text, detected_source_language)."""
    translator = Translator()
    translation = translator.translate(question, dest=dest)
    logger.info("Translation text: " + translation.text)
    logger.info("Translation src: " + translation.src)
    return translation.text, translation.src
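# Note: googletrans auto-detects the source language when src is not given;
# translation.src is reused later to translate the answer back.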
# Load ViLT
vilt_processor = ViltProcessor.from_pretrained("dandelin/vilt-b32-finetuned-vqa")
vilt_model = ViltForQuestionAnswering.from_pretrained("dandelin/vilt-b32-finetuned-vqa")

def vilt_vqa(image, question):
    """Answer an English question about an image with ViLT fine-tuned on VQAv2."""
    inputs = vilt_processor(image, question, return_tensors="pt")
    with torch.no_grad():
        outputs = vilt_model(**inputs)
    logits = outputs.logits
    idx = logits.argmax(-1).item()
    answer = vilt_model.config.id2label[idx]
    logger.info("ViLT: " + answer)
    return answer
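# Usage sketch (assumes cats.jpg has been downloaded, as done further below):
#   from PIL import Image
#   vilt_vqa(Image.open("cats.jpg"), "How many cats are there?")  # -> e.g. "2"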
# Load FLAN-T5 (device_map="auto" requires the accelerate package)
t5_tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-large")
t5_model = T5ForConditionalGeneration.from_pretrained("google/flan-t5-large", device_map="auto")

def flan_t5_complete_sentence(question, answer):
    """Rewrite a short VQA answer as a complete sentence; returns a list of decoded strings."""
    input_text = f"A question: {question} An answer: {answer}. Based on these, answer the question with a complete sentence without extra information."
    logger.info("T5 input: " + input_text)
    # Move the tokenized inputs to the model's device in case it was placed on GPU
    inputs = t5_tokenizer(input_text, return_tensors="pt").to(t5_model.device)
    outputs = t5_model.generate(**inputs, max_length=50)
    result_sentence = t5_tokenizer.batch_decode(outputs, skip_special_tokens=True)
    logger.info("T5 output: " + str(result_sentence))
    return result_sentence
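# Note: generate() runs greedy decoding here (no sampling arguments), and
# max_length=50 caps the output length; batch_decode returns a list, so the
# caller below takes element [0].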
# Main function: translate the question to English, run VQA, expand the short
# answer into a sentence, then translate it back to the question's language
def vqa_main(image, question):
    en_question, question_src_lang = google_translate(question, dest='en')
    vqa_answer = vilt_vqa(image, en_question)
    llm_answer = flan_t5_complete_sentence(en_question, vqa_answer)[0]
    final_answer, answer_src_lang = google_translate(llm_answer, dest=question_src_lang)
    logger.info("Final Answer: " + final_answer)
    return final_answer
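# End-to-end sketch (requires network access for googletrans and the example
# images downloaded below):
#   from PIL import Image
#   vqa_main(Image.open("apple.jpg"), "Qu'est-ce que c'est dans ma main?")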
# Home page text
title = "Interactive demo: Multilingual VQA"
description = "Demo for Multilingual VQA. Upload an image, type a question, click 'submit', or click one of the examples to load them."
article = "article goes here"

# Load example images
torch.hub.download_url_to_file('http://farm3.staticflickr.com/2710/4520550856_7a9f9ea59d_z.jpg', 'apple.jpg')
torch.hub.download_url_to_file('http://images.cocodataset.org/val2017/000000039769.jpg', 'cats.jpg')
# Define home page variables (gr.inputs/gr.outputs were removed in newer Gradio
# releases; the bare component classes below are the current equivalents)
image = gr.Image(type="pil")
question = gr.Textbox(label="Question")
answer = gr.Textbox(label="Predicted answer")
examples = [["apple.jpg", "Qu'est-ce que c'est dans ma main?"], ["cats.jpg", "What are the cats doing?"]]

interface = gr.Interface(fn=vqa_main,
                         inputs=[image, question],
                         outputs=answer,
                         examples=examples,
                         title=title,
                         description=description,
                         article=article)
interface.queue()  # replaces the removed enable_queue=True argument
interface.launch(debug=True, show_error=True)