import gradio as gr
import torch
from transformers.utils import logging

from googletrans import Translator
from transformers import ViltProcessor, ViltForQuestionAnswering, T5Tokenizer, T5ForConditionalGeneration
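
# Pipeline: translate the question to English (googletrans), answer it about the
# image with ViLT, expand the short answer into a full sentence with FLAN-T5, and
# translate that sentence back into the question's original language.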

logging.set_verbosity_info()
logger = logging.get_logger("transformers")

# Translation 
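# googletrans auto-detects the source language; the detected language code is
# returned along with the translation so the final answer can be translated back.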
def google_translate(question, dest):
    translator = Translator()
    translation = translator.translate(question, dest=dest)
    logger.info("Translation text: " + translation.text)
    logger.info("Translation src: " + translation.src)
    return (translation.text, translation.src)

# Load ViLT
vilt_processor = ViltProcessor.from_pretrained("dandelin/vilt-b32-finetuned-vqa")
vilt_model = ViltForQuestionAnswering.from_pretrained("dandelin/vilt-b32-finetuned-vqa")

def vilt_vqa(image, question):
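    # ViLT treats VQA as classification over a fixed answer vocabulary: the image
    # and the (English) question are jointly encoded, and the argmax over the
    # logits is mapped back to an answer string via config.id2label.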
    inputs = vilt_processor(image, question, return_tensors="pt")
    with torch.no_grad():
        outputs = vilt_model(**inputs)
    logits = outputs.logits
    idx = logits.argmax(-1).item()
    answer = vilt_model.config.id2label[idx]
    logger.info("ViLT: " + answer)
    return answer

# Load FLAN-T5
t5_tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-large")
t5_model = T5ForConditionalGeneration.from_pretrained("google/flan-t5-large", device_map="auto")

def flan_t5_complete_sentence(question, answer):
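    # FLAN-T5 rewrites ViLT's short answer (typically a single word or phrase)
    # into a complete sentence, prompted with both the question and the answer.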
    input_text = f"A question: {question} An answer: {answer}. Based on these, answer the question with a complete sentence without extra information."
    logger.info("T5 input: " + input_text)
    inputs = t5_tokenizer(input_text, return_tensors="pt").to(t5_model.device)  # keep tensors on the model's device
    outputs = t5_model.generate(**inputs, max_length=50)
    result_sentence = t5_tokenizer.batch_decode(outputs, skip_special_tokens=True)
    logger.info("T5 output: " + str(result_sentence))
    return result_sentence

# Main function
def vqa_main(image, question):
    en_question, question_src_lang = google_translate(question, dest='en')
    vqa_answer = vilt_vqa(image, en_question)
    llm_answer = flan_t5_complete_sentence(en_question, vqa_answer)[0]
    final_answer, answer_src_lang = google_translate(llm_answer, dest=question_src_lang)
    logger.info("Final Answer: " + final_answer)
    return final_answer
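
# Example of a direct call outside the Gradio UI (hypothetical; assumes the
# example image below has already been downloaded):
#   from PIL import Image
#   print(vqa_main(Image.open("apple.jpg"), "Qu'est-ce que c'est dans ma main?"))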
    
# Home page text
title = "Interactive demo: Multilingual VQA"
description = """
Upload an image, type a question, and click 'Submit', or click one of the examples to load them.

Note: You may see "Error" displayed in the output due to a Gradio incompatibility. If that happens, please click "Logs" above to retrieve the output for your request.

"""
article = """
107 supported languages: 'afrikaans', 'albanian', 'amharic', 'arabic', 'armenian', 'azerbaijani', 'basque', 'belarusian', 'bengali', 'bosnian', 'bulgarian', 'catalan', 'cebuano', 'chichewa', 'chinese (simplified)', 'chinese (traditional)', 'corsican', 'croatian', 'czech', 'danish', 'dutch', 'english', 'esperanto', 'estonian', 'filipino', 'finnish', 'french', 'frisian', 'galician', 'georgian', 'german', 'greek', 'gujarati', 'haitian creole', 'hausa', 'hawaiian', 'hebrew', 'hebrew', 'hindi', 'hmong', 'hungarian', 'icelandic', 'igbo', 'indonesian', 'irish', 'italian', 'japanese', 'javanese', 'kannada', 'kazakh', 'khmer', 'korean', 'kurdish (kurmanji)', 'kyrgyz', 'lao', 'latin', 'latvian', 'lithuanian', 'luxembourgish', 'macedonian', 'malagasy', 'malay', 'malayalam', 'maltese', 'maori', 'marathi', 'mongolian', 'myanmar (burmese)', 'nepali', 'norwegian', 'odia', 'pashto', 'persian', 'polish', 'portuguese', 'punjabi', 'romanian', 'russian', 'samoan', 'scots gaelic', 'serbian', 'sesotho', 'shona', 'sindhi', 'sinhala', 'slovak', 'slovenian', 'somali', 'spanish', 'sundanese', 'swahili', 'swedish', 'tajik', 'tamil', 'telugu', 'thai', 'turkish', 'ukrainian', 'urdu', 'uyghur', 'uzbek', 'vietnamese', 'welsh', 'xhosa', 'yiddish', 'yoruba', 'zulu'
"""

# Load example images 
torch.hub.download_url_to_file('http://farm3.staticflickr.com/2710/4520550856_7a9f9ea59d_z.jpg', 'apple.jpg')
torch.hub.download_url_to_file('http://images.cocodataset.org/val2017/000000039769.jpg', 'cats.jpg')

# Define home page variables 
image = gr.Image(type="pil")
question = gr.Textbox(label="Question")
answer = gr.Textbox(label="Predicted answer")
examples = [["apple.jpg", "Qu'est-ce que c'est dans ma main?"], ["cats.jpg", "What are the cats doing?"]]

demo = gr.Interface(fn=vqa_main,
                    inputs=[image, question],
                    outputs=answer,
                    examples=examples,
                    title=title,
                    description=description,
                    article=article)
demo.queue().launch(debug=True, show_error=True)