import gradio as gr
import torch
from transformers.utils import logging
from googletrans import Translator
from transformers import ViltProcessor, ViltForQuestionAnswering, T5Tokenizer, T5ForConditionalGeneration
logging.set_verbosity_info()
logger = logging.get_logger("transformers")
# Translation
def google_translate(question, dest):
    """Translate text into the target language with googletrans and return
    the translated text plus the auto-detected source language code."""
    translator = Translator()
    translation = translator.translate(question, dest=dest)
    logger.info("Translation text: " + translation.text)
    logger.info("Translation src: " + translation.src)
    return translation.text, translation.src
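# Example (hypothetical values): google_translate("Qu'est-ce que c'est ?", dest="en")
# would return roughly ("What is this?", "fr"); googletrans detects the source
# language automatically and exposes it as translation.src.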
# Load ViLT
vilt_processor = ViltProcessor.from_pretrained("dandelin/vilt-b32-finetuned-vqa")
vilt_model = ViltForQuestionAnswering.from_pretrained("dandelin/vilt-b32-finetuned-vqa")
def vilt_vqa(image, question):
    """Answer an English question about an image with ViLT."""
    inputs = vilt_processor(image, question, return_tensors="pt")
    with torch.no_grad():
        outputs = vilt_model(**inputs)
    logits = outputs.logits
    idx = logits.argmax(-1).item()
    answer = vilt_model.config.id2label[idx]
    logger.info("ViLT: " + answer)
    return answer
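# Note: this checkpoint treats VQA as classification over a fixed answer
# vocabulary (the config's id2label map), so answers are short labels such as
# "yes", "2", or "cat"; the FLAN-T5 step below expands them into sentences.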
# Load FLAN-T5
t5_tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-large")
t5_model = T5ForConditionalGeneration.from_pretrained("google/flan-t5-large", device_map="auto")
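# Note: device_map="auto" relies on the accelerate package and places the
# model on whatever devices are available (GPU if present, otherwise CPU).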
def flan_t5_complete_sentence(question, answer):
    """Rewrite a (question, short answer) pair as a complete sentence with FLAN-T5."""
    input_text = f"A question: {question} An answer: {answer}. Based on these, answer the question with a complete sentence without extra information."
    logger.info("T5 input: " + input_text)
    # Keep inputs on the model's device; device_map="auto" may have placed
    # the model on a GPU when one is available.
    inputs = t5_tokenizer(input_text, return_tensors="pt").to(t5_model.device)
    outputs = t5_model.generate(**inputs, max_length=50)
    result_sentence = t5_tokenizer.batch_decode(outputs, skip_special_tokens=True)
    logger.info("T5 output: " + str(result_sentence))
    return result_sentence
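# Example (illustrative only): for ("What is this?", "apple") the returned list
# might look like ["It is an apple."]; batch_decode yields one string per
# generated sequence, which is why callers below take element [0].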
# Main function
def vqa_main(image, question):
    """Translate the question to English, answer it with ViLT, expand the
    answer with FLAN-T5, then translate the result back to the source language."""
    en_question, question_src_lang = google_translate(question, dest='en')
    vqa_answer = vilt_vqa(image, en_question)
    llm_answer = flan_t5_complete_sentence(en_question, vqa_answer)[0]
    final_answer, answer_src_lang = google_translate(llm_answer, dest=question_src_lang)
    logger.info("Final Answer: " + final_answer)
    return final_answer
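# Illustrative trace (hypothetical values, not a recorded run) for cats.jpg
# and the French question "Que font les chats ?":
#   google_translate -> ("What are the cats doing?", "fr")
#   vilt_vqa         -> "laying down"  (a short label from ViLT's vocabulary)
#   flan_t5          -> "The cats are laying down."
#   google_translate -> the sentence rendered back into French for the user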
# Home page text
title = "Interactive demo: Multilingual VQA"
description = """
Upload an image and type a question, then click 'submit'; or click one of the examples to load them.
Note: you may see "Error" in the output box due to a Gradio incompatibility. If that happens, click "Logs" above to retrieve the output for your request.
"""
article = """
107 supported languages: 'afrikaans', 'albanian', 'amharic', 'arabic', 'armenian', 'azerbaijani', 'basque', 'belarusian', 'bengali', 'bosnian', 'bulgarian', 'catalan', 'cebuano', 'chichewa', 'chinese (simplified)', 'chinese (traditional)', 'corsican', 'croatian', 'czech', 'danish', 'dutch', 'english', 'esperanto', 'estonian', 'filipino', 'finnish', 'french', 'frisian', 'galician', 'georgian', 'german', 'greek', 'gujarati', 'haitian creole', 'hausa', 'hawaiian', 'hebrew', 'hindi', 'hmong', 'hungarian', 'icelandic', 'igbo', 'indonesian', 'irish', 'italian', 'japanese', 'javanese', 'kannada', 'kazakh', 'khmer', 'korean', 'kurdish (kurmanji)', 'kyrgyz', 'lao', 'latin', 'latvian', 'lithuanian', 'luxembourgish', 'macedonian', 'malagasy', 'malay', 'malayalam', 'maltese', 'maori', 'marathi', 'mongolian', 'myanmar (burmese)', 'nepali', 'norwegian', 'odia', 'pashto', 'persian', 'polish', 'portuguese', 'punjabi', 'romanian', 'russian', 'samoan', 'scots gaelic', 'serbian', 'sesotho', 'shona', 'sindhi', 'sinhala', 'slovak', 'slovenian', 'somali', 'spanish', 'sundanese', 'swahili', 'swedish', 'tajik', 'tamil', 'telugu', 'thai', 'turkish', 'ukrainian', 'urdu', 'uyghur', 'uzbek', 'vietnamese', 'welsh', 'xhosa', 'yiddish', 'yoruba', 'zulu'
"""
# Load example images
torch.hub.download_url_to_file('http://farm3.staticflickr.com/2710/4520550856_7a9f9ea59d_z.jpg', 'apple.jpg')
torch.hub.download_url_to_file('http://images.cocodataset.org/val2017/000000039769.jpg', 'cats.jpg')
# Define home page variables
# Gradio 3+ component API (the gr.inputs/gr.outputs namespaces are deprecated)
image = gr.Image(type="pil")
question = gr.Textbox(label="Question")
answer = gr.Textbox(label="Predicted answer")
examples = [["apple.jpg", "Qu'est-ce que c'est dans ma main?"], ["cats.jpg", "What are the cats doing?"]]
demo = gr.Interface(fn=vqa_main,
                    inputs=[image, question],
                    outputs=answer,
                    examples=examples,
                    title=title,
                    description=description,
                    article=article)
demo.queue()  # enable_queue on Interface is deprecated; enable the queue here instead
demo.launch(debug=True, show_error=True)