import gradio as gr
from transformers import ViltProcessor, ViltForQuestionAnswering, AutoModelForSeq2SeqLM, AutoTokenizer
import torch
import httpcore
# Compatibility shim: googletrans still references httpcore.SyncHTTPTransport,
# which newer httpcore releases removed. googletrans only uses the attribute in
# a type annotation, so a placeholder value avoids the AttributeError on import.
setattr(httpcore, 'SyncHTTPTransport', 'AsyncHTTPProxy')
from googletrans import Translator
from googletrans import LANGCODES
import re
# Set of acceptable language names: the first word of each googletrans language
# name (e.g. 'chinese (simplified)' contributes 'chinese')
acceptable_languages = set(lang.split()[0] for lang in LANGCODES)
acceptable_languages.add("mandarin")
acceptable_languages.add("cantonese")
# Translation
def google_translate(question, dest):
    translator = Translator()
    translation = translator.translate(question, dest=dest)
    print("Translation text: " + translation.text)
    print("Translation src: " + translation.src)
    return (translation.text, translation.src)
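# Illustrative call (actual output depends on the live Google Translate service):
#   google_translate("Qu'est-ce que c'est dans ma main ?", dest='en')
#   -> ("What is it in my hand?", 'fr')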
# Language name to googletrans lang_code mapping
def lang_code_match(acceptable_lang):
    # Exceptions for the Chinese languages, which googletrans does not list
    # under these names
    if acceptable_lang == 'mandarin':
        return 'zh-cn'
    elif acceptable_lang == 'cantonese' or acceptable_lang == 'chinese':
        return 'zh-tw'
    # Default: direct lookup, falling back to a first-word match because
    # multi-word names (e.g. 'scots gaelic') were collapsed to their first
    # word when acceptable_languages was built
    else:
        if acceptable_lang in LANGCODES:
            return LANGCODES[acceptable_lang]
        return next(code for name, code in LANGCODES.items()
                    if name.split()[0] == acceptable_lang)
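# Illustrative mappings: lang_code_match('french') -> 'fr',
# lang_code_match('mandarin') -> 'zh-cn', lang_code_match('chinese') -> 'zh-tw'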
# Find the destination language named in the question
def find_dest_language(sentence, src_lang):
    pattern = r'\b(' + '|'.join(acceptable_languages) + r')\b'
    match = re.search(pattern, sentence, flags=re.IGNORECASE)
    if match:
        lang_code = lang_code_match(match.group(0).lower())
        print("Destination lang: " + lang_code)
        return lang_code
    # No language named: answer in the language the question was asked in
    else:
        print("Destination lang: " + src_lang)
        return src_lang
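# Illustrative: find_dest_language("What are the monkeys doing in french?", "en")
# matches 'french' and returns 'fr'; if no language is named, the question's
# own source language is returned instead.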
# Remove the destination-language phrase from the question
def remove_language_phrase(sentence):
    # Remove "in [language]" or a bare "[language]" plus any trailing
    # whitespace or punctuation
    pattern = r'(\b(in\s)?(' + '|'.join(acceptable_languages) + r')\b)[\s,;:.!?]*'
    cleaned_sentence = re.sub(pattern, '', sentence, flags=re.IGNORECASE).strip()
    print("Language Phrase Removed: " + cleaned_sentence)
    return cleaned_sentence
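# Illustrative: remove_language_phrase("What are the monkeys doing in french?")
# -> "What are the monkeys doing" (the punctuation class also consumes the
# '?' that follows the language phrase).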
# Visual question answering with ViLT
def vqa(image, text):
    encoding = vqa_processor(image, text, return_tensors="pt")
    with torch.no_grad():
        outputs = vqa_model(**encoding)
    logits = outputs.logits
    idx = logits.argmax(-1).item()
    predicted_answer = vqa_model.config.id2label[idx]
    return predicted_answer
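# ViLT treats VQA as classification: it picks the single most likely answer
# from a fixed vocabulary of a few thousand short labels, so predictions are
# words or short phrases like "eating" or "2", never full sentences.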
# Rewrite the short VQA answer into a full sentence with FLAN-T5
def llm(cleaned_sentence, vqa_answer):
    # Prepare the input prompt
    prompt = (
        f"A question: {cleaned_sentence}\n"
        f"An answer: {vqa_answer}.\n"
        f"Based on these, answer the question with a complete sentence without extra information."
    )
    inputs = flan_tokenizer(prompt, return_tensors="pt")
    outputs = flan_model.generate(**inputs, max_length=50)
    response = flan_tokenizer.decode(outputs[0], skip_special_tokens=True)
    print("T5 prompt: " + prompt)
    print("T5 response: " + response)
    return response
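# Illustrative (actual generations vary):
#   llm("What are the monkeys doing", "eating") -> "The monkeys are eating."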
# Download the example images
torch.hub.download_url_to_file('https://media.istockphoto.com/id/1174602891/photo/two-monkeys-mom-and-cub-eat-bananas.jpg?s=612x612&w=0&k=20&c=r7VXi9d1wHhyq3iAk9D2Z3yTZiOJMlLNtjdVRBEjG7g=', 'monkeys.jpg')
torch.hub.download_url_to_file('https://img.freepik.com/premium-photo/man-holds-apple-his-hands_198067-740023.jpg', 'apple.jpg')
# Load the VQA model and the FLAN-T5 rephraser
vqa_processor = ViltProcessor.from_pretrained("dandelin/vilt-b32-finetuned-vqa")
vqa_model = ViltForQuestionAnswering.from_pretrained("dandelin/vilt-b32-finetuned-vqa")
flan_tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-large")
flan_model = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-large")
def main(image, text):
    # Translate the question to English, keeping track of its source language
    en_question, question_src_lang = google_translate(text, dest='en')
    # Answer in the language named in the question, else in the source language
    dest_lang = find_dest_language(en_question, question_src_lang)
    # Strip the language phrase so the VQA model sees only the visual question
    cleaned_sentence = remove_language_phrase(en_question)
    vqa_answer = vqa(image, cleaned_sentence)
    llm_answer = llm(cleaned_sentence, vqa_answer)
    # Translate the full-sentence answer into the destination language
    final_answer, _ = google_translate(llm_answer, dest=dest_lang)
    print("Final answer: ", final_answer)
    return vqa_answer, final_answer
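# Illustrative end-to-end trace (outputs depend on the live translation service
# and on model generations):
#   main(<apple image>, "Qu'est-ce que c'est dans ma main en anglais ?")
#   translates the question to English, detects 'english' as the destination,
#   strips the phrase, gets "apple" from ViLT, rephrases it with FLAN-T5,
#   and returns the final sentence in English.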
image = gr.Image(type="pil")
question = gr.Textbox(label="Question")
# main returns two values, so the interface needs two output components
vqa_answer_box = gr.Textbox(label="VQA answer (English)")
answer = gr.Textbox(label="Predicted answer")
examples = [
    ["monkeys.jpg", "What are the monkeys doing in French?"],
    ["apple.jpg", "Qu'est-ce que c'est dans ma main en anglais?"],
    ["monkeys.jpg", "In Uyghur, tell me how many monkeys are here?"],
    ["apple.jpg", "What color is this fruit? Chinese."]
]
title = "Cross-lingual VQA"
description = """
Visual Question Answering (VQA) across languages.
Input an image and ask a question about the image.
Click on the examples below to see sample inputs and outputs.
"""
article = """
Supports questions in the following languages:
['afrikaans', 'albanian', 'amharic', 'arabic', 'armenian', 'azerbaijani', 'basque', 'belarusian', 'bengali', 'bosnian', 'bulgarian', 'catalan', 'cebuano', 'chichewa', 'chinese (simplified)', 'chinese (traditional)', 'corsican', 'croatian', 'czech', 'danish', 'dutch', 'english', 'esperanto', 'estonian', 'filipino', 'finnish', 'french', 'frisian', 'galician', 'georgian', 'german', 'greek', 'gujarati', 'haitian creole', 'hausa', 'hawaiian', 'hebrew', 'hindi', 'hmong', 'hungarian', 'icelandic', 'igbo', 'indonesian', 'irish', 'italian', 'japanese', 'javanese', 'kannada', 'kazakh', 'khmer', 'korean', 'kurdish (kurmanji)', 'kyrgyz', 'lao', 'latin', 'latvian', 'lithuanian', 'luxembourgish', 'macedonian', 'malagasy', 'malay', 'malayalam', 'maltese', 'maori', 'marathi', 'mongolian', 'myanmar (burmese)', 'nepali', 'norwegian', 'odia', 'pashto', 'persian', 'polish', 'portuguese', 'punjabi', 'romanian', 'russian', 'samoan', 'scots gaelic', 'serbian', 'sesotho', 'shona', 'sindhi', 'sinhala', 'slovak', 'slovenian', 'somali', 'spanish', 'sundanese', 'swahili', 'swedish', 'tajik', 'tamil', 'telugu', 'thai', 'turkish', 'ukrainian', 'urdu', 'uyghur', 'uzbek', 'vietnamese', 'welsh', 'xhosa', 'yiddish', 'yoruba', 'zulu']
"""
interface = gr.Interface(fn=main,
                         inputs=[image, question],
                         outputs=[vqa_answer_box, answer],
                         examples=examples,
                         title=title,
                         description=description,
                         article=article)
interface.launch(debug=True)