Spaces:
Running
Running
Irpan
committed on
Commit
·
13e41f6
1
Parent(s):
97db7c3
app.py
CHANGED
@@ -3,15 +3,59 @@ from transformers import ViltProcessor, ViltForQuestionAnswering
|
|
3 |
import torch
|
4 |
from googletrans import Translator
|
5 |
from googletrans import LANGCODES
|
|
|
6 |
|
7 |
torch.hub.download_url_to_file('http://images.cocodataset.org/val2017/000000039769.jpg', 'cats.jpg')
|
8 |
|
9 |
processor = ViltProcessor.from_pretrained("dandelin/vilt-b32-finetuned-vqa")
|
10 |
model = ViltForQuestionAnswering.from_pretrained("dandelin/vilt-b32-finetuned-vqa")
|
11 |
|
12 |
-
|
13 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
14 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
15 |
# forward pass
|
16 |
with torch.no_grad():
|
17 |
outputs = model(**encoding)
|
@@ -21,6 +65,14 @@ def answer_question(image, text):
|
|
21 |
predicted_answer = model.config.id2label[idx]
|
22 |
|
23 |
return predicted_answer
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
24 |
|
25 |
image = gr.inputs.Image(type="pil")
|
26 |
question = gr.inputs.Textbox(label="Question")
|
@@ -30,7 +82,7 @@ examples = [["cats.jpg", "How many cats are there, in French?"]]
|
|
30 |
title = "Cross-lingual VQA"
|
31 |
description = "ViLT (Vision and Language Transformer), fine-tuned on VQAv2 "
|
32 |
|
33 |
-
interface = gr.Interface(fn=
|
34 |
inputs=[image, question],
|
35 |
outputs=answer,
|
36 |
examples=examples,
|
|
|
3 |
import torch
|
4 |
from googletrans import Translator
|
5 |
from googletrans import LANGCODES
|
6 |
+
import re
|
7 |
|
8 |
torch.hub.download_url_to_file('http://images.cocodataset.org/val2017/000000039769.jpg', 'cats.jpg')
|
9 |
|
10 |
processor = ViltProcessor.from_pretrained("dandelin/vilt-b32-finetuned-vqa")
|
11 |
model = ViltForQuestionAnswering.from_pretrained("dandelin/vilt-b32-finetuned-vqa")
|
12 |
|
13 |
+
# Languages a user may name inside a question. googletrans' LANGCODES maps
# full language names to codes; keep only the first word of each name
# (e.g. "chinese (simplified)" -> "chinese") so a plain mention can match.
acceptable_languages = {name.split()[0] for name in LANGCODES}
# Chinese variants users commonly say but LANGCODES does not list directly.
acceptable_languages.update(("mandarin", "cantonese"))
|
17 |
+
|
18 |
+
# Translation
def google_translate(question, dest):
    """Translate *question* into language code *dest* via googletrans.

    Returns a ``(translated_text, detected_source_lang_code)`` tuple.
    Performs a network call to the Google Translate service.
    """
    result = Translator().translate(question, dest=dest)
    print("Translation text: " + result.text)
    print("Translation src: " + result.src)
    return (result.text, result.src)
25 |
+
|
26 |
+
# Lang to lang_code mapping
def lang_code_match(accaptable_lang):
    """Map a lowercase language name to a googletrans language code.

    `accaptable_lang` (sic — parameter name kept for compatibility) is a
    lowercase entry from `acceptable_languages`.

    Raises KeyError if the name cannot be resolved at all.
    """
    # Exception for chinese langs: bare "chinese" is ambiguous and is not a
    # LANGCODES key, so it is mapped explicitly alongside the added variants.
    if accaptable_lang == 'mandarin':
        return 'zh-cn'
    elif accaptable_lang == 'cantonese' or accaptable_lang == 'chinese':
        return 'zh-tw'
    # Default: exact LANGCODES lookup.
    try:
        return LANGCODES[accaptable_lang]
    except KeyError:
        # Bug fix: acceptable_languages keeps only the FIRST word of each
        # LANGCODES name (e.g. "kurdish (kurmanji)" -> "kurdish"), so an
        # exact lookup can fail for multi-word names. Fall back to matching
        # on the first word of each LANGCODES key before giving up.
        for name, code in LANGCODES.items():
            if name.split()[0] == accaptable_lang:
                return code
        raise
|
36 |
|
37 |
+
# Find destination language
def find_dest_language(sentence, src_lang):
    """Return the language code the answer should be given in.

    Scans *sentence* (the English-translated question) for a mention of any
    name in `acceptable_languages`; if found, resolves it to a language code
    via `lang_code_match`. Otherwise falls back to *src_lang*, the language
    the question was originally asked in.
    """
    pattern = r'\b(' + '|'.join(acceptable_languages) + r')\b'
    match = re.search(pattern, sentence, flags=re.IGNORECASE)
    if match:
        lang_code = lang_code_match(match.group(0).lower())
        print("Destination lang: " + lang_code)
        return lang_code
    else:
        # Fix: added the missing space after the colon so this log line
        # matches the format of the branch above.
        print("Destination lang: " + src_lang)
        return src_lang
|
48 |
+
|
49 |
+
# Remove destination language context
def remove_language_phrase(sentence):
    """Strip a language mention from *sentence* and return the result.

    Removes "in [acceptable_languages]" or a bare "[acceptable_languages]"
    word, plus any non-closing punctuation trailing it, so the VQA model
    only sees the actual question.
    """
    lang_alternatives = '|'.join(acceptable_languages)
    pattern = r'(\b(in\s)?(' + lang_alternatives + r')\b)[\s,;:.!?]*'
    cleaned_sentence = re.sub(pattern, '', sentence, flags=re.IGNORECASE).strip()
    print("Language Phrase Removed: " + cleaned_sentence)
    return cleaned_sentence
|
56 |
+
|
57 |
+
def vqa(image, text):
|
58 |
+
encoding = processor(image, text, return_tensors="pt")
|
59 |
# forward pass
|
60 |
with torch.no_grad():
|
61 |
outputs = model(**encoding)
|
|
|
65 |
predicted_answer = model.config.id2label[idx]
|
66 |
|
67 |
return predicted_answer
|
68 |
+
|
69 |
+
def main(image, text):
    """Gradio entry point: answer the question *text* about *image*.

    The question may be in any googletrans-supported language; it is first
    translated to English before being passed to the VQA model.
    """
    # Normalise the question to English, keeping the detected source language.
    en_question, question_src_lang = google_translate(text, dest='en')
    dest_lang = find_dest_language(en_question, question_src_lang)
    # NOTE(review): dest_lang is computed (and logged) but never used — the
    # answer is returned untranslated. Presumably it was meant for
    # translating vqa_answer back; confirm intended behavior.
    cleaned_sentence = remove_language_phrase(en_question)
    vqa_answer = vqa(image, cleaned_sentence)
    return vqa_answer
|
75 |
+
|
76 |
|
77 |
image = gr.inputs.Image(type="pil")
|
78 |
question = gr.inputs.Textbox(label="Question")
|
|
|
82 |
title = "Cross-lingual VQA"
|
83 |
description = "ViLT (Vision and Language Transformer), fine-tuned on VQAv2 "
|
84 |
|
85 |
+
interface = gr.Interface(fn=main,
|
86 |
inputs=[image, question],
|
87 |
outputs=answer,
|
88 |
examples=examples,
|