# NOTE: The "Spaces: / Build error" lines that preceded this script were
# hosting-platform build-log output accidentally captured into the source,
# not Python code; the deprecated gr.inputs/gr.outputs API used further down
# is the likely cause of those build failures.
# Third-party dependencies: Gradio for the web UI, torch for inference,
# HF transformers for the ViLT and FLAN-T5 models, and googletrans for
# translating questions/answers to and from English.
import gradio as gr
import torch
from transformers.utils import logging
from googletrans import Translator
from transformers import ViltProcessor, ViltForQuestionAnswering, T5Tokenizer, T5ForConditionalGeneration

# Emit INFO-level transformers logs and reuse that logger for app messages
# (visible in the hosted "Logs" tab referenced by the description below).
logging.set_verbosity_info()
logger = logging.get_logger("transformers")
# Translation | |
# Translation
def google_translate(question, dest):
    """Translate ``question`` into the ``dest`` language via googletrans.

    Returns a ``(translated_text, detected_source_language)`` tuple so the
    caller can translate the final answer back into the original language.
    """
    result = Translator().translate(question, dest=dest)
    logger.info("Translation text: " + result.text)
    logger.info("Translation src: " + result.src)
    return (result.text, result.src)
# Load Vilt | |
# Load Vilt
# ViLT (Vision-and-Language Transformer) fine-tuned for VQA: classifies an
# (image, question) pair into one of the model's fixed answer labels.
vilt_processor = ViltProcessor.from_pretrained("dandelin/vilt-b32-finetuned-vqa")
vilt_model = ViltForQuestionAnswering.from_pretrained("dandelin/vilt-b32-finetuned-vqa")
def vilt_vqa(image, question):
    """Run ViLT visual question answering and return the top answer label.

    ``question`` is expected to be English (callers translate first).
    """
    encoding = vilt_processor(image, question, return_tensors="pt")
    with torch.no_grad():
        logits = vilt_model(**encoding).logits
    # Highest-scoring class index mapped back to its human-readable label.
    best_idx = logits.argmax(-1).item()
    answer = vilt_model.config.id2label[best_idx]
    logger.info("ViLT: " + answer)
    return answer
# Load FLAN-T5 | |
# Load FLAN-T5
# FLAN-T5 rewrites ViLT's one-word answer into a complete sentence;
# device_map="auto" lets accelerate place the large model across devices.
t5_tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-large")
t5_model = T5ForConditionalGeneration.from_pretrained("google/flan-t5-large", device_map="auto")
def flan_t5_complete_sentence(question, answer):
    """Ask FLAN-T5 to phrase ``answer`` to ``question`` as a full sentence.

    Returns the list produced by ``batch_decode`` (callers index ``[0]``).
    """
    input_text = f"A question: {question} An answer: {answer}. Based on these, answer the question with a complete sentence without extra information."
    logger.info("T5 input: " + input_text)
    encoded = t5_tokenizer(input_text, return_tensors="pt")
    generated = t5_model.generate(**encoded, max_length=50)
    result_sentence = t5_tokenizer.batch_decode(generated, skip_special_tokens=True)
    logger.info("T5 output: " + str(result_sentence))
    return result_sentence
# Main function | |
# Main function
def vqa_main(image, question):
    """Full pipeline: translate question to English, run VQA, expand the
    answer into a sentence, then translate back to the question's language."""
    en_question, question_src_lang = google_translate(question, dest='en')
    vqa_answer = vilt_vqa(image, en_question)
    llm_answer = flan_t5_complete_sentence(en_question, vqa_answer)[0]
    final_answer, _ = google_translate(llm_answer, dest=question_src_lang)
    logger.info("Final Answer: " + final_answer)
    return final_answer
# Home page text | |
# Home page text
title = "Interactive demo: Multilingual VQA"
description = """
Upload an image, type a question, click 'submit', or click one of the examples to load them.
Note: You may see "Error" displayed for output due to Gradio's incompatibility, if that is the case, please click "Logs" above to obtain the output for your request.
"""
# NOTE(review): 'hebrew' appears twice in this list, so only 106 unique
# languages are actually named despite the "107 Languages" heading — verify
# against googletrans.LANGUAGES and fix the duplicate.
article = """
Supported 107 Languages: 'afrikaans', 'albanian', 'amharic', 'arabic', 'armenian', 'azerbaijani', 'basque', 'belarusian', 'bengali', 'bosnian', 'bulgarian', 'catalan', 'cebuano', 'chichewa', 'chinese (simplified)', 'chinese (traditional)', 'corsican', 'croatian', 'czech', 'danish', 'dutch', 'english', 'esperanto', 'estonian', 'filipino', 'finnish', 'french', 'frisian', 'galician', 'georgian', 'german', 'greek', 'gujarati', 'haitian creole', 'hausa', 'hawaiian', 'hebrew', 'hebrew', 'hindi', 'hmong', 'hungarian', 'icelandic', 'igbo', 'indonesian', 'irish', 'italian', 'japanese', 'javanese', 'kannada', 'kazakh', 'khmer', 'korean', 'kurdish (kurmanji)', 'kyrgyz', 'lao', 'latin', 'latvian', 'lithuanian', 'luxembourgish', 'macedonian', 'malagasy', 'malay', 'malayalam', 'maltese', 'maori', 'marathi', 'mongolian', 'myanmar (burmese)', 'nepali', 'norwegian', 'odia', 'pashto', 'persian', 'polish', 'portuguese', 'punjabi', 'romanian', 'russian', 'samoan', 'scots gaelic', 'serbian', 'sesotho', 'shona', 'sindhi', 'sinhala', 'slovak', 'slovenian', 'somali', 'spanish', 'sundanese', 'swahili', 'swedish', 'tajik', 'tamil', 'telugu', 'thai', 'turkish', 'ukrainian', 'urdu', 'uyghur', 'uzbek', 'vietnamese', 'welsh', 'xhosa', 'yiddish', 'yoruba', 'zulu'
"""
# Load example images | |
# Load example images
# Fetched at startup into the working directory so the Interface examples
# below can reference them by filename.
# NOTE(review): plain-http flickr/coco URLs — confirm they are still live.
torch.hub.download_url_to_file('http://farm3.staticflickr.com/2710/4520550856_7a9f9ea59d_z.jpg', 'apple.jpg')
torch.hub.download_url_to_file('http://images.cocodataset.org/val2017/000000039769.jpg', 'cats.jpg')
# Define home page variables | |
# Define home page variables.
# Fix: the gr.inputs / gr.outputs namespaces and the Interface(enable_queue=...)
# keyword were removed in Gradio 3.x — the likely source of the "Build error"
# noted at the top of this file. Use the top-level component classes and
# Blocks.queue() instead. The previously-unused `answer` component is now
# wired in as the output (labelled "Predicted answer" instead of bare "text").
image = gr.Image(type="pil")
question = gr.Textbox(label="Question")
answer = gr.Textbox(label="Predicted answer")
examples = [["apple.jpg", "Qu'est-ce que c'est dans ma main?"], ["cats.jpg", "What are the cats doing?"]]

demo = gr.Interface(fn=vqa_main,
                    inputs=[image, question],
                    outputs=answer,
                    examples=examples,
                    title=title,
                    description=description,
                    article=article)
demo.queue()
demo.launch(debug=True, show_error=True)