File size: 7,840 Bytes
aca743e
8e8c9ba
aca743e
02149a3
4437cf3
02149a3
4437cf3
 
7e69a2f
9c9a73e
 
 
 
 
 
7e69a2f
02149a3
 
aca743e
7e69a2f
 
 
 
 
 
 
 
9c9a73e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dc1676e
 
 
775f1ae
aca743e
087ba34
 
 
 
 
 
 
30a0050
 
 
 
 
087ba34
aca743e
775f1ae
d9f146b
 
775f1ae
 
9c9a73e
 
 
 
 
 
 
47a8d15
9c9a73e
 
 
 
 
 
 
 
47a8d15
 
 
17bc88c
9c9a73e
 
087ba34
775f1ae
 
 
7e69a2f
9c9a73e
647ec6f
 
 
9c9a73e
 
 
775f1ae
 
bf9d6b8
15c7e76
 
 
fb014ca
043b7ee
15c7e76
 
26eb6b9
043b7ee
775f1ae
 
7e69a2f
775f1ae
05c3988
775f1ae
 
ce05662
 
 
05c3988
fb014ca
 
 
05c3988
5106646
 
fb014ca
aca743e
b1438bb
aca743e
d68f037
aca743e
 
 
2c4a4e3
b1438bb
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
import gradio as gr
import re
import torch
from transformers.utils import logging
from transformers import ViltProcessor, ViltForQuestionAnswering, T5Tokenizer, T5ForConditionalGeneration

import httpcore
setattr(httpcore, 'SyncHTTPTransport', 'AsyncHTTPProxy')  # set SyncHTTPTransport attribute for googletrans dependency
from googletrans import Translator
from googletrans import LANGCODES

# Keywords that may name a requested answer language: the first
# whitespace-separated word of every googletrans LANGCODES key, plus the
# two Chinese aliases that lang_code_match special-cases.
acceptable_languages = {lang_name.split()[0] for lang_name in LANGCODES}
acceptable_languages.update(("mandarin", "cantonese"))

# Raise the transformers logging verbosity to INFO and reuse the
# "transformers" logger, so the logger.info(...) calls throughout this
# script (translations, model outputs, final answer) are emitted.
logging.set_verbosity_info()
logger = logging.get_logger("transformers")

# Translation 
def google_translate(question, dest):
    """Translate *question* into language *dest* using googletrans.

    Returns a ``(translated_text, source_language)`` tuple, where
    ``source_language`` is the ``src`` googletrans reports for the input.
    """
    result = Translator().translate(question, dest=dest)
    text, src = result.text, result.src
    logger.info("Translation text: " + text)
    logger.info("Translation src: " + src)
    return text, src

# Lang to lang_code mapping 
def lang_code_match(accaptable_lang):
    """Map a lower-case language keyword to a googletrans language code.

    Args:
        accaptable_lang: a lower-case entry of ``acceptable_languages``, i.e.
            the first word of a ``LANGCODES`` key, or 'mandarin'/'cantonese'.

    Returns:
        The language code to pass to googletrans as ``dest``.

    Raises:
        KeyError: if the keyword matches no ``LANGCODES`` entry.
    """
    # Chinese aliases are mapped explicitly.
    if accaptable_lang == 'mandarin':
        return 'zh-cn'
    if accaptable_lang in ('cantonese', 'chinese'):
        return 'zh-tw'
    if accaptable_lang in LANGCODES:
        return LANGCODES[accaptable_lang]
    # acceptable_languages keeps only the first word of each LANGCODES key,
    # so multi-word names (e.g. "Haitian Creole", "Scots Gaelic") never match
    # a key directly; recover the full key by comparing first words.
    for lang_name, code in LANGCODES.items():
        if lang_name.split()[0].lower() == accaptable_lang:
            return code
    raise KeyError(accaptable_lang)
    
# Find destination language
def find_dest_language(sentence, src_lang):
    """Pick the language code the final answer should be translated into.

    Searches *sentence* case-insensitively, on word boundaries, for any entry
    of ``acceptable_languages``. The leftmost mention is mapped to a code via
    ``lang_code_match``; if nothing is mentioned, *src_lang* is returned.
    """
    alternatives = '|'.join(acceptable_languages)
    found = re.search(r'\b(' + alternatives + r')\b', sentence, flags=re.IGNORECASE)
    if not found:
        logger.info("Destination lang:" + src_lang)
        return src_lang
    lang_code = lang_code_match(found.group(0).lower())
    logger.info("Destination lang: " + lang_code)
    return lang_code

# Remove destination language context
def remove_language_phrase(sentence):
    """Strip answer-language requests from *sentence*.

    Every occurrence of a language keyword (optionally preceded by "in ")
    is removed, together with any whitespace and , ; : . ! ? characters that
    directly follow it. The result is stripped of surrounding whitespace.
    """
    languages = '|'.join(acceptable_languages)
    pattern = r'(\b(in\s)?(' + languages + r')\b)[\s,;:.!?]*'
    result = re.sub(pattern, '', sentence, flags=re.IGNORECASE)
    result = result.strip()
    logger.info("Language Phrase Removed: " + result)
    return result


# Load Vilt 
# ViLT processor and VQA model, both loaded from the same
# "dandelin/vilt-b32-finetuned-vqa" checkpoint (processor first, then model).
vilt_processor, vilt_model = (
    ViltProcessor.from_pretrained("dandelin/vilt-b32-finetuned-vqa"),
    ViltForQuestionAnswering.from_pretrained("dandelin/vilt-b32-finetuned-vqa"),
)

def vilt_vqa(image, question, top_k=10):
    """Answer *question* about *image* with the ViLT VQA model.

    Args:
        image: image handed to ``vilt_processor`` (the Gradio input in this
            script is declared with ``type="pil"``).
        question: English question text.
        top_k: number of highest-scoring answers to log for debugging.

    Returns:
        The label of the single highest-scoring answer.
    """
    inputs = vilt_processor(image, question, return_tensors="pt")
    with torch.no_grad():
        outputs = vilt_model(**inputs)
    logits = outputs.logits
    id2label = vilt_model.config.id2label

    answer = id2label[logits.argmax(-1).item()]
    logger.info("ViLT: " + answer)

    # Log the runner-up answers as well; clamp k so topk never asks for more
    # entries than there are answer labels.
    k = min(top_k, logits.shape[-1])
    top_indices = torch.topk(logits, k, dim=-1).indices[0]
    topk_answers = [id2label[i.item()] for i in top_indices]
    logger.info("ViLT top " + str(k) + " answers: " + str(topk_answers))
    return answer

# Load FLAN-T5
# FLAN-T5 (large) tokenizer and model. device_map="auto" lets transformers
# decide where the weights are placed, so the model is not guaranteed to be
# on the CPU; inputs should be sent to t5_model.device before generating.
t5_tokenizer, t5_model = (
    T5Tokenizer.from_pretrained("google/flan-t5-large"),
    T5ForConditionalGeneration.from_pretrained("google/flan-t5-large", device_map="auto"),
)

def flan_t5_complete_sentence(question, answer, max_length=50):
    """Rephrase a short VQA answer as a complete sentence using FLAN-T5.

    Args:
        question: the (English) question that was asked.
        answer: the short answer produced by the VQA model.
        max_length: maximum sequence length passed to ``generate``.

    Returns:
        The decoded FLAN-T5 output with special tokens removed.
    """
    input_text = f"A question: {question} An answer: {answer}. Based on these, answer the question with a complete sentence without extra information."
    logger.info("T5 input: " + input_text)
    # t5_model is loaded with device_map="auto" and may not live on the CPU,
    # while the tokenizer returns CPU tensors; move them to the model's device
    # so generate() doesn't fail on a GPU runtime.
    inputs = t5_tokenizer(input_text, return_tensors="pt").to(t5_model.device)
    outputs = t5_model.generate(**inputs, max_length=max_length)
    # batch_decode yields one string per generated sequence; there is one here.
    result_sentence = t5_tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
    logger.info("T5 output: " + result_sentence)
    return result_sentence

# Main function
def vqa_main(image, question):
    """Answer *question* about *image* in the language the user asked for.

    Pipeline:
    1. Translate the question to English, keeping its detected source language.
    2. Detect any requested answer language and strip that request.
    3. Get a short answer from ViLT.
    4. Expand it into a sentence with FLAN-T5.
    5. Translate the sentence into the destination language.
    """
    en_question, question_src_lang = google_translate(question, dest='en')
    dest_lang = find_dest_language(en_question, question_src_lang)
    cleaned_question = remove_language_phrase(en_question)
    short_answer = vilt_vqa(image, cleaned_question)
    sentence = flan_t5_complete_sentence(cleaned_question, short_answer)
    # Only the translated text is needed; the reported source language is not.
    final_answer = google_translate(sentence, dest=dest_lang)[0]
    logger.info("Final Answer: " + final_answer)
    return final_answer
    
# Home page text
# Page heading and the instructions shown above the demo.
title = "Interactive demo: Cross-Lingual VQA"
description = """
Upload an image, type a question, and click 'submit', or click one of the examples to load it.

Note: This web demo runs on a CPU, so producing an answer may take a few minutes at times. For better performance, please consider migrating to your own space and upgrading it to a GPU runtime.

"""
# Footer text listing the supported answer languages (presumably the set
# googletrans can translate into — confirm against LANGCODES).
# NOTE(review): the heading says 107 languages but the list names 106 —
# confirm which figure is correct.
article = """
Supported 107 Languages: Afrikaans, Albanian, Amharic, Arabic, Armenian, Azerbaijani, Basque, Belarusian, Bengali, Bosnian, Bulgarian, Catalan, Cebuano, Chichewa, Chinese (Simplified), Chinese (Traditional), Corsican, Croatian, Czech, Danish, Dutch, English, Esperanto, Estonian, Filipino, Finnish, French, Frisian, Galician, Georgian, German, Greek, Gujarati, Haitian Creole, Hausa, Hawaiian, Hebrew, Hindi, Hmong, Hungarian, Icelandic, Igbo, Indonesian, Irish, Italian, Japanese, Javanese, Kannada, Kazakh, Khmer, Korean, Kurdish (Kurmanji), Kyrgyz, Lao, Latin, Latvian, Lithuanian, Luxembourgish, Macedonian, Malagasy, Malay, Malayalam, Maltese, Maori, Marathi, Mongolian, Myanmar (Burmese), Nepali, Norwegian, Odia, Pashto, Persian, Polish, Portuguese, Punjabi, Romanian, Russian, Samoan, Scots Gaelic, Serbian, Sesotho, Shona, Sindhi, Sinhala, Slovak, Slovenian, Somali, Spanish, Sundanese, Swahili, Swedish, Tajik, Tamil, Telugu, Thai, Turkish, Ukrainian, Urdu, Uyghur, Uzbek, Vietnamese, Welsh, Xhosa, Yiddish, Yoruba, Zulu
"""

# Load example images 
# Fetch the example images used by `examples` into the working directory,
# as (source URL, local filename) pairs.
for _example_url, _example_file in (
    ('http://farm3.staticflickr.com/2710/4520550856_7a9f9ea59d_z.jpg', 'apple.jpg'),
    ('http://images.cocodataset.org/val2017/000000039769.jpg', 'cats.jpg'),
    ('https://media.istockphoto.com/id/1174602891/photo/two-monkeys-mom-and-cub-eat-bananas.jpg?s=612x612&w=0&k=20&c=r7VXi9d1wHhyq3iAk9D2Z3yTZiOJMlLNtjdVRBEjG7g=', 'monkey.jpg'),
):
    torch.hub.download_url_to_file(_example_url, _example_file)

# Define home page variables 
# Gradio components: an image and a question as inputs, and a labelled
# textbox for the (translated) answer.
image = gr.Image(type="pil")
question = gr.Textbox(label="Question")
answer = gr.Textbox(label="Predicted answer")
# Each example is [downloaded image file, question]; the questions cover
# different ways of asking for the answer language.
examples = [
    ["apple.jpg", "Qu'est-ce que j'ai dans la main en anglais?"],
    ["cats.jpg", "How many cats are here?"],
    ["monkey.jpg", "In Korean, what are these animals called?"],
    ["apple.jpg", "What color is this? Answer in Uyghur."],
    ["cats.jpg", "What are the cats doing in German?"],
    ["monkey.jpg", "Maymunlar ne yiyor, Çince cevap ver."],
]

# Route the result into the labelled `answer` textbox; a bare "text" output
# would leave that component unused and drop its "Predicted answer" label.
demo = gr.Interface(
    fn=vqa_main,
    inputs=[image, question],
    outputs=answer,
    examples=examples,
    title=title,
    description=description,
    article=article,
)
demo.launch(debug=True, show_error=True)