Spaces:
Running
Running
File size: 6,179 Bytes
64f507f 190650c 64f507f b063473 97db7c3 13e41f6 64f507f 13e41f6 64f507f 13e41f6 290c136 64f507f 290c136 64f507f 290c136 64f507f 13e41f6 190650c 7c308d7 190650c 72a284b 190650c 13e41f6 190650c db445c5 72a284b 13e41f6 64f507f 290c136 7c308d7 72a284b db445c5 ea5126d 7c308d7 64f507f 9d3623a ea5126d 0b48030 9d3623a 7c308d7 6a4a82d 7c308d7 64f507f 13e41f6 64f507f 7c308d7 64f507f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 |
import gradio as gr
from transformers import ViltProcessor, ViltForQuestionAnswering, AutoModelForSeq2SeqLM, AutoTokenizer
import torch
import httpcore
# HACK: googletrans 3.x imports httpcore.SyncHTTPTransport, which newer httpcore
# releases removed; stubbing the attribute lets the import below succeed.
# NOTE(review): the stub is a plain string, which only works because googletrans
# never instantiates it — confirm when upgrading googletrans.
setattr(httpcore, 'SyncHTTPTransport', 'AsyncHTTPProxy')
from googletrans import Translator
from googletrans import LANGCODES
import re
# Language names recognisable inside a question: the first word of every
# googletrans language name, plus two Chinese variants googletrans lacks.
acceptable_languages = {name.split()[0] for name in LANGCODES}
acceptable_languages.update(("mandarin", "cantonese"))
# Translation
def google_translate(question, dest):
    """Translate `question` into language code `dest`.

    Returns a (translated_text, detected_source_lang) tuple.
    """
    result = Translator().translate(question, dest=dest)
    print("Translation text: " + result.text)
    print("Translation src: " + result.src)
    return (result.text, result.src)
# Lang to lang_code mapping
def lang_code_match(accaptable_lang):
    """Map a lowercase language name to a googletrans language code.

    `accaptable_lang` is a single word taken from `acceptable_languages`,
    which holds only the FIRST word of each googletrans language name
    (e.g. "haitian" from "haitian creole"). A plain LANGCODES lookup would
    raise KeyError for every multi-word name, so we fall back to matching
    on the first word of each key.

    Raises KeyError if no language matches.
    """
    # Exceptions for Chinese variants, which googletrans spells differently.
    if accaptable_lang == 'mandarin':
        return 'zh-cn'
    elif accaptable_lang == 'cantonese' or accaptable_lang == 'chinese':
        return 'zh-tw'
    # Exact key: covers all single-word language names.
    if accaptable_lang in LANGCODES:
        return LANGCODES[accaptable_lang]
    # First-word fallback for multi-word names ("scots gaelic", "myanmar (burmese)", ...).
    for name, code in LANGCODES.items():
        if name.split()[0] == accaptable_lang:
            return code
    raise KeyError(accaptable_lang)
# Find destination language
def find_dest_language(sentence, src_lang):
    """Return the language code of a language named in `sentence`;
    default to `src_lang` when none is mentioned."""
    name_pattern = r'\b(' + '|'.join(acceptable_languages) + r')\b'
    hit = re.search(name_pattern, sentence, flags=re.IGNORECASE)
    if not hit:
        print("Destination lang:" + src_lang)
        return src_lang
    code = lang_code_match(hit.group(0).lower())
    print("Destination lang: " + code)
    return code
# Remove destination language context
def remove_language_phrase(sentence):
    """Strip an "in <language>" or bare "<language>" phrase (plus any
    surrounding punctuation) from `sentence` and return the result."""
    names = '|'.join(acceptable_languages)
    phrase_re = r'(\b(in\s)?(' + names + r')\b)[\s,;:.!?]*'
    cleaned = re.sub(phrase_re, '', sentence, flags=re.IGNORECASE).strip()
    print("Language Phrase Removed: " + cleaned)
    return cleaned
def vqa(image, text):
    """Answer `text` about `image` with ViLT; return the top answer label."""
    inputs = vqa_processor(image, text, return_tensors="pt")
    with torch.no_grad():
        logits = vqa_model(**inputs).logits
    best = logits.argmax(-1).item()
    return vqa_model.config.id2label[best]
def llm(cleaned_sentence, vqa_answer):
    """Use FLAN-T5 to turn the short VQA answer into a complete sentence."""
    # Prepare the input prompt
    prompt = (
        f"A question: {cleaned_sentence}\n"
        f"An answer: {vqa_answer}.\n"
        f"Based on these, answer the question with a complete sentence without extra information."
    )
    tokens = flan_tokenizer(prompt, return_tensors="pt")
    generated = flan_model.generate(**tokens, max_length=50)
    response = flan_tokenizer.decode(generated[0], skip_special_tokens=True)
    print("T5 prompt: " + prompt)
    print("T5 response: " + response)
    return response
# Fetch the two example images used by the Gradio demo below.
torch.hub.download_url_to_file('https://media.istockphoto.com/id/1174602891/photo/two-monkeys-mom-and-cub-eat-bananas.jpg?s=612x612&w=0&k=20&c=r7VXi9d1wHhyq3iAk9D2Z3yTZiOJMlLNtjdVRBEjG7g=', 'monkeys.jpg')
torch.hub.download_url_to_file('https://img.freepik.com/premium-photo/man-holds-apple-his-hands_198067-740023.jpg', 'apple.jpg')
# ViLT fine-tuned for visual question answering (English-only).
vqa_processor = ViltProcessor.from_pretrained("dandelin/vilt-b32-finetuned-vqa")
vqa_model = ViltForQuestionAnswering.from_pretrained("dandelin/vilt-b32-finetuned-vqa")
# FLAN-T5 rephrases the one-word VQA answer as a full sentence (see llm()).
flan_tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-large")
flan_model = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-large")
def main(image, text):
    """Full pipeline: translate the question to English, answer it with VQA,
    phrase the answer with FLAN-T5, then translate the sentence into the
    language requested in the question (or the question's own language)."""
    question_en, src_lang = google_translate(text, dest='en')
    target_lang = find_dest_language(question_en, src_lang)
    question_clean = remove_language_phrase(question_en)
    short_answer = vqa(image, question_clean)
    sentence = llm(question_clean, short_answer)
    translated, _ = google_translate(sentence, dest=target_lang)
    print("Final answer: ", translated)
    return short_answer, translated
image = gr.Image(type="pil")
question = gr.Textbox(label="Question")
# main() returns TWO values (raw English VQA answer, translated full-sentence
# answer); the original single `outputs=answer` component made Gradio fail with
# a too-many-output-values error, so the interface needs two output boxes.
answer = gr.Textbox(label="Predicted answer")
final_answer = gr.Textbox(label="Final answer")
examples = [
    ["monkeys.jpg", "What are the monkeys doing in French?"],
    ["apple.jpg", "Qu'est-ce que c'est dans ma main en anglais?"],
    ["monkeys.jpg", "In Uyghur, tell me how many monkeys are here?"],
    ["apple.jpg", "What color is this fruit? Chinese."]
]
title = "Cross-lingual VQA"
description = """
Visual Question Answering (VQA) across Languages
Input an image and ask a question regarding the image.
Click on the examples below to see sample inputs and outputs.
"""
article = """
Supports questions regarding the following languages:
['afrikaans', 'albanian', 'amharic', 'arabic', 'armenian', 'azerbaijani', 'basque', 'belarusian', 'bengali', 'bosnian', 'bulgarian', 'catalan', 'cebuano', 'chichewa', 'chinese (simplified)', 'chinese (traditional)', 'corsican', 'croatian', 'czech', 'danish', 'dutch', 'english', 'esperanto', 'estonian', 'filipino', 'finnish', 'french', 'frisian', 'galician', 'georgian', 'german', 'greek', 'gujarati', 'haitian creole', 'hausa', 'hawaiian', 'hebrew', 'hebrew', 'hindi', 'hmong', 'hungarian', 'icelandic', 'igbo', 'indonesian', 'irish', 'italian', 'japanese', 'javanese', 'kannada', 'kazakh', 'khmer', 'korean', 'kurdish (kurmanji)', 'kyrgyz', 'lao', 'latin', 'latvian', 'lithuanian', 'luxembourgish', 'macedonian', 'malagasy', 'malay', 'malayalam', 'maltese', 'maori', 'marathi', 'mongolian', 'myanmar (burmese)', 'nepali', 'norwegian', 'odia', 'pashto', 'persian', 'polish', 'portuguese', 'punjabi', 'romanian', 'russian', 'samoan', 'scots gaelic', 'serbian', 'sesotho', 'shona', 'sindhi', 'sinhala', 'slovak', 'slovenian', 'somali', 'spanish', 'sundanese', 'swahili', 'swedish', 'tajik', 'tamil', 'telugu', 'thai', 'turkish', 'ukrainian', 'urdu', 'uyghur', 'uzbek', 'vietnamese', 'welsh', 'xhosa', 'yiddish', 'yoruba', 'zulu']
"""
interface = gr.Interface(fn=main,
                         inputs=[image, question],
                         outputs=[answer, final_answer],
                         examples=examples,
                         title=title,
                         description=description,
                         article=article)
interface.launch(debug=True)