Irpan
edits
ea5126d
raw
history blame
6.18 kB
import gradio as gr
from transformers import ViltProcessor, ViltForQuestionAnswering, AutoModelForSeq2SeqLM, AutoTokenizer
import torch
import httpcore
# HACK(review): presumably the googletrans release imported below references
# httpcore.SyncHTTPTransport, which newer httpcore versions no longer define.
# Assigning a placeholder (a plain string, not a real transport class) before the
# googletrans import lets that import succeed — confirm against the pinned versions.
setattr(httpcore, 'SyncHTTPTransport', 'AsyncHTTPProxy')
from googletrans import Translator
from googletrans import LANGCODES
import re
# Words recognised as target-language names inside a question.
# For every LANGCODES key (a language name) only its first word is kept,
# e.g. 'chinese (simplified)' -> 'chinese'. 'mandarin' and 'cantonese' are
# added as extra aliases; lang_code_match maps them explicitly.
acceptable_languages = {name.split()[0] for name in LANGCODES}
acceptable_languages.update(("mandarin", "cantonese"))
# Translation
def google_translate(question, dest):
    """Translate *question* into language *dest* using googletrans.

    Args:
        question: Text to translate.
        dest: googletrans destination language code (e.g. 'en').

    Returns:
        A ``(translated_text, detected_source_code)`` tuple.
    """
    result = Translator().translate(question, dest=dest)
    text, source = result.text, result.src
    print("Translation text: " + text)
    print("Translation src: " + source)
    return text, source
# Lang to lang_code mapping
def lang_code_match(accaptable_lang, langcodes=None):
    """Map a lower-cased language word to a googletrans language code.

    ``acceptable_languages`` keeps only the first word of each googletrans
    language name, so a matched word may be a prefix of the real key
    (e.g. 'haitian' for 'haitian creole', 'scots' for 'scots gaelic').
    Such words are resolved back to the full name here instead of raising.

    Args:
        accaptable_lang: Lower-cased language word found in the question.
        langcodes: Optional name -> code mapping; defaults to googletrans'
            ``LANGCODES``. Exposed mainly so the mapping can be injected.

    Returns:
        The language code string.

    Raises:
        KeyError: if the word matches no known language name.
    """
    if langcodes is None:
        langcodes = LANGCODES
    # Chinese variants are mapped explicitly: Mandarin -> simplified,
    # Cantonese / generic Chinese -> traditional.
    if accaptable_lang == 'mandarin':
        return 'zh-cn'
    if accaptable_lang in ('cantonese', 'chinese'):
        return 'zh-tw'
    if accaptable_lang in langcodes:
        return langcodes[accaptable_lang]
    # Multi-word names were truncated to their first word when the
    # acceptable-language set was built; map that word back to its code.
    for full_name, code in langcodes.items():
        if full_name.split()[0] == accaptable_lang:
            return code
    raise KeyError(accaptable_lang)
# Find destination language
def find_dest_language(sentence, src_lang):
    """Decide which language the final answer should be written in.

    Searches *sentence* (case-insensitively, whole words only) for the
    leftmost word from ``acceptable_languages``. If one is found, its code
    (via ``lang_code_match``) is returned. Otherwise the answer goes back in
    the question's own language, *src_lang*.
    """
    alternatives = '|'.join(acceptable_languages)
    found = re.search(r'\b(' + alternatives + r')\b', sentence, flags=re.IGNORECASE)
    if not found:
        print("Destination lang:" + src_lang)
        return src_lang
    code = lang_code_match(found.group(0).lower())
    print("Destination lang: " + code)
    return code
# Remove destination language context
def remove_language_phrase(sentence):
    """Strip every mention of a target language from *sentence*.

    Each occurrence of a word from ``acceptable_languages`` is removed
    (case-insensitive, whole words), together with an optional preceding
    "in " and any whitespace or punctuation (,;:.!?) that directly follows it.
    The result is then trimmed. Note that this can also drop a closing '?'
    or '.' that sat right after the language word.
    """
    names = '|'.join(acceptable_languages)
    phrase = r'(\b(in\s)?(' + names + r')\b)[\s,;:.!?]*'
    result = re.sub(phrase, '', sentence, flags=re.IGNORECASE)
    result = result.strip()
    print("Language Phrase Removed: " + result)
    return result
def vqa(image, text):
    """Answer *text* about *image* with the ViLT VQA model.

    Returns the label string (from ``vqa_model.config.id2label``) of the
    highest-scoring class.
    """
    model_inputs = vqa_processor(image, text, return_tensors="pt")
    # Inference only: no autograd graph needed.
    with torch.no_grad():
        scores = vqa_model(**model_inputs).logits
    best_index = scores.argmax(-1).item()
    return vqa_model.config.id2label[best_index]
def llm(cleaned_sentence, vqa_answer):
    """Turn a short VQA label into a full-sentence answer with FLAN-T5.

    The question and the short answer are packed into one prompt. The
    seq2seq model generates a reply of at most 50 tokens (``max_length=50``),
    which is decoded without special tokens and returned.
    """
    prompt = "".join([
        f"A question: {cleaned_sentence}\n",
        f"An answer: {vqa_answer}.\n",
        "Based on these, answer the question with a complete sentence without extra information.",
    ])
    encoded = flan_tokenizer(prompt, return_tensors="pt")
    generated = flan_model.generate(**encoded, max_length=50)
    reply = flan_tokenizer.decode(generated[0], skip_special_tokens=True)
    print("T5 prompt: " + prompt)
    print("T5 response: " + reply)
    return reply
# Fetch the two sample images used by the Gradio examples table.
for _src_url, _local_name in (
    ('https://media.istockphoto.com/id/1174602891/photo/two-monkeys-mom-and-cub-eat-bananas.jpg?s=612x612&w=0&k=20&c=r7VXi9d1wHhyq3iAk9D2Z3yTZiOJMlLNtjdVRBEjG7g=', 'monkeys.jpg'),
    ('https://img.freepik.com/premium-photo/man-holds-apple-his-hands_198067-740023.jpg', 'apple.jpg'),
):
    torch.hub.download_url_to_file(_src_url, _local_name)

# Pretrained checkpoints: ViLT fine-tuned for VQA, and FLAN-T5-large for
# rephrasing the short VQA label into a sentence.
_VQA_CHECKPOINT = "dandelin/vilt-b32-finetuned-vqa"
_FLAN_CHECKPOINT = "google/flan-t5-large"

vqa_processor = ViltProcessor.from_pretrained(_VQA_CHECKPOINT)
vqa_model = ViltForQuestionAnswering.from_pretrained(_VQA_CHECKPOINT)
flan_tokenizer = AutoTokenizer.from_pretrained(_FLAN_CHECKPOINT)
flan_model = AutoModelForSeq2SeqLM.from_pretrained(_FLAN_CHECKPOINT)
def main(image, text):
    """Run the full cross-lingual VQA pipeline for one request.

    1. Translate the question to English, also detecting its language.
    2. Pick the answer language: one named in the question, else the
       question's own language.
    3. Remove the language mention, ask ViLT for a short answer, and let
       FLAN-T5 rephrase it as a sentence.
    4. Translate that sentence into the chosen language.

    Returns:
        ``(short VQA label, translated full-sentence answer)``.
    """
    english_question, asked_in = google_translate(text, dest='en')
    target_lang = find_dest_language(english_question, asked_in)
    question_only = remove_language_phrase(english_question)
    short_answer = vqa(image, question_only)
    sentence_answer = llm(question_only, short_answer)
    translated, _detected = google_translate(sentence_answer, dest=target_lang)
    print("Final answer: ", translated)
    return short_answer, translated
# Gradio components: a PIL image plus a free-text question in, answer text out.
image = gr.Image(type="pil")
question = gr.Textbox(label="Question")
answer = gr.Textbox(label="Predicted answer")

# Clickable sample rows, each [image file, question]; the image files are the
# ones downloaded at startup.
examples = [
    [image_file, prompt_text]
    for image_file, prompt_text in (
        ("monkeys.jpg", "What are the monkeys doing in French?"),
        ("apple.jpg", "Qu'est-ce que c'est dans ma main en anglais?"),
        ("monkeys.jpg", "In Uyghur, tell me how many monkeys are here?"),
        ("apple.jpg", "What color is this fruit? Chinese."),
    )
]
# Static texts shown on the Gradio page: header title, description above the
# form, and the article (footer) listing the supported language names.
title = "Cross-lingual VQA"
description = """
Visual Question Answering (VQA) across Languages
Input an image and ask a question regarding the image.
Click on the examples below to see sample inputs and outputs.
"""
# The list below previously contained 'hebrew' twice; each name now appears once.
article = """
Supports questions regarding the following languages:
['afrikaans', 'albanian', 'amharic', 'arabic', 'armenian', 'azerbaijani', 'basque', 'belarusian', 'bengali', 'bosnian', 'bulgarian', 'catalan', 'cebuano', 'chichewa', 'chinese (simplified)', 'chinese (traditional)', 'corsican', 'croatian', 'czech', 'danish', 'dutch', 'english', 'esperanto', 'estonian', 'filipino', 'finnish', 'french', 'frisian', 'galician', 'georgian', 'german', 'greek', 'gujarati', 'haitian creole', 'hausa', 'hawaiian', 'hebrew', 'hindi', 'hmong', 'hungarian', 'icelandic', 'igbo', 'indonesian', 'irish', 'italian', 'japanese', 'javanese', 'kannada', 'kazakh', 'khmer', 'korean', 'kurdish (kurmanji)', 'kyrgyz', 'lao', 'latin', 'latvian', 'lithuanian', 'luxembourgish', 'macedonian', 'malagasy', 'malay', 'malayalam', 'maltese', 'maori', 'marathi', 'mongolian', 'myanmar (burmese)', 'nepali', 'norwegian', 'odia', 'pashto', 'persian', 'polish', 'portuguese', 'punjabi', 'romanian', 'russian', 'samoan', 'scots gaelic', 'serbian', 'sesotho', 'shona', 'sindhi', 'sinhala', 'slovak', 'slovenian', 'somali', 'spanish', 'sundanese', 'swahili', 'swedish', 'tajik', 'tamil', 'telugu', 'thai', 'turkish', 'ukrainian', 'urdu', 'uyghur', 'uzbek', 'vietnamese', 'welsh', 'xhosa', 'yiddish', 'yoruba', 'zulu']
"""
# Build and launch the web UI.
# main() returns two values, (raw VQA label, translated sentence answer), so the
# interface needs two output components in that same order. With only one output,
# Gradio would show the whole tuple as a single string.
interface = gr.Interface(fn=main,
                         inputs=[image, question],
                         outputs=[gr.Textbox(label="VQA answer"), answer],
                         examples=examples,
                         title=title,
                         description=description,
                         article=article)
interface.launch(debug=True)