Irpan committed
Commit 13e41f6 · 1 Parent(s): 97db7c3
Files changed (1):
  1. app.py +55 -3
app.py CHANGED
@@ -3,15 +3,59 @@ from transformers import ViltProcessor, ViltForQuestionAnswering
 import torch
 from googletrans import Translator
 from googletrans import LANGCODES
+import re
 
 torch.hub.download_url_to_file('http://images.cocodataset.org/val2017/000000039769.jpg', 'cats.jpg')
 
 processor = ViltProcessor.from_pretrained("dandelin/vilt-b32-finetuned-vqa")
 model = ViltForQuestionAnswering.from_pretrained("dandelin/vilt-b32-finetuned-vqa")
 
-def answer_question(image, text):
-    encoding = processor(image, text, return_tensors="pt")
+# Set of acceptable language names (first word of each googletrans language name)
+acceptable_languages = set(L.split()[0] for L in LANGCODES)
+acceptable_languages.add("mandarin")
+acceptable_languages.add("cantonese")
+
+# Translate with googletrans; returns (translated text, detected source language)
+def google_translate(question, dest):
+    translator = Translator()
+    translation = translator.translate(question, dest=dest)
+    print("Translation text: " + translation.text)
+    print("Translation src: " + translation.src)
+    return (translation.text, translation.src)
+
+# Map a language name to a googletrans language code
+def lang_code_match(acceptable_lang):
+    # Exceptions for the Chinese languages
+    if acceptable_lang == 'mandarin':
+        return 'zh-cn'
+    elif acceptable_lang == 'cantonese' or acceptable_lang == 'chinese':
+        return 'zh-tw'
+    # Default
+    else:
+        return LANGCODES[acceptable_lang]
 
+# Find the destination language named in the question
+def find_dest_language(sentence, src_lang):
+    pattern = r'\b(' + '|'.join(acceptable_languages) + r')\b'
+    match = re.search(pattern, sentence, flags=re.IGNORECASE)
+    if match:
+        lang_code = lang_code_match(match.group(0).lower())
+        print("Destination lang: " + lang_code)
+        return lang_code
+    else:
+        print("Destination lang: " + src_lang)
+        return src_lang
+
+# Remove the destination-language phrase from the question
+def remove_language_phrase(sentence):
+    # Remove "in [acceptable_language]" or "[acceptable_language]" plus any trailing punctuation
+    pattern = r'(\b(in\s)?(' + '|'.join(acceptable_languages) + r')\b)[\s,;:.!?]*'
+    cleaned_sentence = re.sub(pattern, '', sentence, flags=re.IGNORECASE).strip()
+    print("Language Phrase Removed: " + cleaned_sentence)
+    return cleaned_sentence
+
+def vqa(image, text):
+    encoding = processor(image, text, return_tensors="pt")
     # forward pass
     with torch.no_grad():
         outputs = model(**encoding)
@@ -21,6 +65,14 @@ def answer_question(image, text):
     predicted_answer = model.config.id2label[idx]
 
     return predicted_answer
+
+def main(image, text):
+    en_question, question_src_lang = google_translate(text, dest='en')
+    dest_lang = find_dest_language(en_question, question_src_lang)
+    cleaned_sentence = remove_language_phrase(en_question)
+    vqa_answer = vqa(image, cleaned_sentence)
+    return vqa_answer
+
 
 image = gr.inputs.Image(type="pil")
 question = gr.inputs.Textbox(label="Question")
@@ -30,7 +82,7 @@ examples = [["cats.jpg", "How many cats are there, in French?"]]
 title = "Cross-lingual VQA"
 description = "ViLT (Vision and Language Transformer), fine-tuned on VQAv2 "
 
-interface = gr.Interface(fn=answer_question,
+interface = gr.Interface(fn=main,
                          inputs=[image, question],
                          outputs=answer,
                          examples=examples,
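
A note on the elided hunk context: the three lines of vqa() between the forward pass and the id2label lookup (old lines 18-20, new lines 62-64) fall outside the context shown above. For dandelin/vilt-b32-finetuned-vqa, the usual decoding per the model card is an argmax over the answer logits; the following is a sketch of that step, not the file's literal contents:

    # Standard ViLT-VQA decoding (sketch; not the literal elided lines)
    logits = outputs.logits
    idx = logits.argmax(-1).item()
    predicted_answer = model.config.id2label[idx]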
 
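For the shipped example question, the two parsing helpers behave as follows. This is a quick offline check (no Google Translate call, so no network needed), assuming the definitions from app.py are in scope; the expected values are worked out from the regexes, not taken from the commit:

    # Offline check of the language-parsing helpers defined above.
    question = "How many cats are there, in French?"
    dest = find_dest_language(question, src_lang="en")
    assert dest == "fr"  # "French" matches, and LANGCODES['french'] == 'fr'
    cleaned = remove_language_phrase(question)
    # The pattern strips "in French" plus trailing punctuation only, so the
    # comma before the phrase survives:
    assert cleaned == "How many cats are there,"

Note the leftover comma: the regex removes punctuation after the language phrase but not before it, so the shipped example reaches vqa() with a trailing comma.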
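
One loose end in this commit: main() computes dest_lang but never uses it, so the answer comes back in English even when the question asks for French. A minimal sketch of the presumable follow-up, translating the answer into dest_lang before returning it (hypothetical, not part of this commit; it reuses only helpers defined above):

    # Hypothetical extension (not in this commit): return the VQA answer in
    # the language the user asked for.
    def main(image, text):
        en_question, question_src_lang = google_translate(text, dest='en')
        dest_lang = find_dest_language(en_question, question_src_lang)
        cleaned_sentence = remove_language_phrase(en_question)
        vqa_answer = vqa(image, cleaned_sentence)
        translated_answer, _ = google_translate(vqa_answer, dest=dest_lang)
        return translated_answer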