import gradio as gr
from transformers import ViltProcessor, ViltForQuestionAnswering, AutoModelForSeq2SeqLM, AutoTokenizer
import torch

import httpcore
# Compatibility shim: googletrans 3.x references httpcore.SyncHTTPTransport,
# which newer httpcore releases removed; the attribute only needs to exist
# for googletrans's annotations to evaluate, so any placeholder value works.
setattr(httpcore, 'SyncHTTPTransport', 'AsyncHTTPProxy')

from googletrans import Translator
from googletrans import LANGCODES
import re

# Language names accepted in questions. LANGCODES maps language name -> code;
# keeping only the first word collapses multiword names such as
# 'chinese (simplified)' down to 'chinese'.
acceptable_languages = set(L.split()[0] for L in LANGCODES)
acceptable_languages.add("mandarin")
acceptable_languages.add("cantonese")

# Translate text to `dest`, returning (translated_text, detected_source_lang)
def google_translate(question, dest):
    translator = Translator()
    translation = translator.translate(question, dest=dest)
    print("Translation text: " + translation.text)
    print("Translation src: " + translation.src)
    return (translation.text, translation.src)
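# Illustrative only (live network call, so output may vary):
#   google_translate("Qu'est-ce que c'est dans ma main?", dest="en")
#   -> ("What is in my hand?", "fr")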

# Map an accepted language name to its googletrans language code
def lang_code_match(acceptable_lang):
    # Special-case the Chinese languages, whose names don't appear in LANGCODES
    if acceptable_lang == 'mandarin':
        return 'zh-cn'
    elif acceptable_lang == 'cantonese' or acceptable_lang == 'chinese':
        return 'zh-tw'
    # Default: look the name up directly
    else:
        return LANGCODES[acceptable_lang]
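# e.g. lang_code_match('french') -> 'fr'; lang_code_match('mandarin') -> 'zh-cn'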
    
# Find the destination language mentioned in the question
def find_dest_language(sentence, src_lang):
    pattern = r'\b(' + '|'.join(acceptable_languages) + r')\b'
    match = re.search(pattern, sentence, flags=re.IGNORECASE)
    if match:
        lang_code = lang_code_match(match.group(0).lower())
        print("Destination lang: " + lang_code)
        return lang_code
    else:
        print("Destination lang:" + src_lang)
        return src_lang
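# e.g. find_dest_language("what are the monkeys doing in french?", "en") -> 'fr';
# when no language is named, the answer falls back to the question's source language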

# Remove the destination-language phrase from the question
def remove_language_phrase(sentence):
    # Remove "in [acceptable_languages]" or "[acceptable_languages]" and any non-closing punctuation around it
    pattern = r'(\b(in\s)?(' + '|'.join(acceptable_languages) + r')\b)[\s,;:.!?]*'
    cleaned_sentence = re.sub(pattern, '', sentence, flags=re.IGNORECASE).strip()
    print("Language Phrase Removed: " + cleaned_sentence)
    return cleaned_sentence
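# e.g. "What are the monkeys doing in French?" -> "What are the monkeys doing"
# (the trailing "?" is consumed along with the language phrase)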

def vqa(image, text):
    encoding = vqa_processor(image, text, return_tensors="pt")
    with torch.no_grad():
        outputs = vqa_model(**encoding)

    logits = outputs.logits
    idx = logits.argmax(-1).item()
    predicted_answer = vqa_model.config.id2label[idx]

    return predicted_answer
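# ViLT treats VQA as classification over a fixed vocabulary of short answers,
# so predictions are single words or short phrases such as "eating" or "2".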


def llm(cleaned_sentence, vqa_answer):
    # Prepare the input prompt
    prompt = (
        f"A question: {cleaned_sentence}\n"
        f"An answer: {vqa_answer}.\n"
        f"Based on these, answer the question with a complete sentence without extra information."
    )
    inputs = flan_tokenizer(prompt, return_tensors="pt")
    outputs = flan_model.generate(**inputs, max_length=50)
    response = flan_tokenizer.decode(outputs[0], skip_special_tokens=True)
    print("T5 prompt: " + prompt)
    print("T5 response: " + response)
    return response 
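# Illustrative only (generation may vary):
#   llm("What are the monkeys doing", "eating banana")
#   might return "The monkeys are eating bananas."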


# Fetch sample images for the Gradio examples
torch.hub.download_url_to_file('https://media.istockphoto.com/id/1174602891/photo/two-monkeys-mom-and-cub-eat-bananas.jpg?s=612x612&w=0&k=20&c=r7VXi9d1wHhyq3iAk9D2Z3yTZiOJMlLNtjdVRBEjG7g=', 'monkeys.jpg')
torch.hub.download_url_to_file('https://img.freepik.com/premium-photo/man-holds-apple-his-hands_198067-740023.jpg', 'apple.jpg')

# ViLT fine-tuned on VQAv2: answers a question about an image with a short label
vqa_processor = ViltProcessor.from_pretrained("dandelin/vilt-b32-finetuned-vqa")
vqa_model = ViltForQuestionAnswering.from_pretrained("dandelin/vilt-b32-finetuned-vqa")

# FLAN-T5 rewrites the short VQA label into a complete sentence
flan_tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-large")
flan_model = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-large")



# End-to-end pipeline: translate the question to English, detect the requested
# answer language, strip the language phrase from the question, run VQA, have
# FLAN-T5 phrase a full-sentence answer, then translate it back.
def main(image, text):
    en_question, question_src_lang = google_translate(text, dest='en')
    dest_lang = find_dest_language(en_question, question_src_lang)
    cleaned_sentence = remove_language_phrase(en_question)
    vqa_answer = vqa(image, cleaned_sentence)
    llm_answer = llm(cleaned_sentence, vqa_answer)
    final_answer, _ = google_translate(llm_answer, dest=dest_lang)
    print("Final answer: ", final_answer)
    return vqa_answer, final_answer
    
   
image = gr.Image(type="pil")
question = gr.Textbox(label="Question")
vqa_answer_box = gr.Textbox(label="VQA answer")
answer = gr.Textbox(label="Predicted answer")
examples = [
    ["monkeys.jpg", "What are the monkeys doing in French?"],
    ["apple.jpg", "Qu'est-ce que c'est dans ma main en anglais?"],
    ["monkeys.jpg", "In Uyghur, tell me how many monkeys are here?"],
    ["apple.jpg", "What color is this fruit? Chinese."],
]

title = "Cross-lingual VQA"
description = """
Visual Question Answering (VQA) across Languages

Input an image and ask a question about it. To get the answer in another language, name that language in the question (e.g. "in French").

Click on the examples below to see sample inputs and outputs. 

"""
article = """
Supports questions and answers in the following languages: 
['afrikaans', 'albanian', 'amharic', 'arabic', 'armenian', 'azerbaijani', 'basque', 'belarusian', 'bengali', 'bosnian', 'bulgarian', 'catalan', 'cebuano', 'chichewa', 'chinese (simplified)', 'chinese (traditional)', 'corsican', 'croatian', 'czech', 'danish', 'dutch', 'english', 'esperanto', 'estonian', 'filipino', 'finnish', 'french', 'frisian', 'galician', 'georgian', 'german', 'greek', 'gujarati', 'haitian creole', 'hausa', 'hawaiian', 'hebrew', 'hindi', 'hmong', 'hungarian', 'icelandic', 'igbo', 'indonesian', 'irish', 'italian', 'japanese', 'javanese', 'kannada', 'kazakh', 'khmer', 'korean', 'kurdish (kurmanji)', 'kyrgyz', 'lao', 'latin', 'latvian', 'lithuanian', 'luxembourgish', 'macedonian', 'malagasy', 'malay', 'malayalam', 'maltese', 'maori', 'marathi', 'mongolian', 'myanmar (burmese)', 'nepali', 'norwegian', 'odia', 'pashto', 'persian', 'polish', 'portuguese', 'punjabi', 'romanian', 'russian', 'samoan', 'scots gaelic', 'serbian', 'sesotho', 'shona', 'sindhi', 'sinhala', 'slovak', 'slovenian', 'somali', 'spanish', 'sundanese', 'swahili', 'swedish', 'tajik', 'tamil', 'telugu', 'thai', 'turkish', 'ukrainian', 'urdu', 'uyghur', 'uzbek', 'vietnamese', 'welsh', 'xhosa', 'yiddish', 'yoruba', 'zulu']
"""

interface = gr.Interface(fn=main, 
                         inputs=[image, question], 
                         outputs=[vqa_answer_box, answer], 
                         examples=examples, 
                         title=title,
                         description=description,
                         article=article)
interface.launch(debug=True)