import gradio as gr
import re
import torch
from transformers.utils import logging
from transformers import ViltProcessor, ViltForQuestionAnswering, T5Tokenizer, T5ForConditionalGeneration
import httpcore
# Monkey-patch: googletrans still references httpcore.SyncHTTPTransport at import
# time, which newer httpcore versions no longer expose; any placeholder value works
setattr(httpcore, 'SyncHTTPTransport', 'AsyncHTTPProxy')
from googletrans import Translator
from googletrans import LANGCODES
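# googletrans's LANGCODES maps lowercase language names to ISO codes,
# e.g. 'english' -> 'en', 'chinese (simplified)' -> 'zh-cn'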
# Build the set of language names a question may mention; keep only the first
# word of multi-word names (e.g. "chinese (simplified)" -> "chinese")
acceptable_languages = set(lang.split()[0] for lang in LANGCODES)
acceptable_languages.add("mandarin")
acceptable_languages.add("cantonese")
logging.set_verbosity_info()
logger = logging.get_logger("transformers")
# Translation
def google_translate(question, dest):
    translator = Translator()
    translation = translator.translate(question, dest=dest)
    logger.info("Translation text: " + translation.text)
    logger.info("Translation src: " + translation.src)
    return (translation.text, translation.src)
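# Illustrative example (actual output depends on the Google Translate API):
# google_translate("Combien de chats y a-t-il ?", dest='en')
# -> ("How many cats are there?", 'fr')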
# Map a language name to its googletrans language code
def lang_code_match(acceptable_lang):
    # Special cases for Chinese languages
    if acceptable_lang == 'mandarin':
        return 'zh-cn'
    elif acceptable_lang in ('cantonese', 'chinese'):
        return 'zh-tw'
    # Default: look up the name in googletrans LANGCODES
    else:
        return LANGCODES[acceptable_lang]
# Find destination language
def find_dest_language(sentence, src_lang):
    pattern = r'\b(' + '|'.join(acceptable_languages) + r')\b'
    match = re.search(pattern, sentence, flags=re.IGNORECASE)
    if match:
        lang_code = lang_code_match(match.group(0).lower())
        logger.info("Destination lang: " + lang_code)
        return lang_code
    else:
        logger.info("Destination lang: " + src_lang)
        return src_lang
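# Illustrative example: find_dest_language("What is this called in korean?", 'en')
# -> 'ko'; when no language is named, the question's own language is returned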
# Remove destination language context
def remove_language_phrase(sentence):
    # Remove "in [acceptable_languages]" or "[acceptable_languages]" and any
    # non-closing punctuation around it
    pattern = r'(\b(in\s)?(' + '|'.join(acceptable_languages) + r')\b)[\s,;:.!?]*'
    cleaned_sentence = re.sub(pattern, '', sentence, flags=re.IGNORECASE).strip()
    logger.info("Language Phrase Removed: " + cleaned_sentence)
    return cleaned_sentence
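# Illustrative example: remove_language_phrase("What is this called in korean?")
# -> "What is this called" (the trailing punctuation class also strips the '?')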
# Load ViLT (vision-and-language transformer fine-tuned for VQA)
vilt_processor = ViltProcessor.from_pretrained("dandelin/vilt-b32-finetuned-vqa")
vilt_model = ViltForQuestionAnswering.from_pretrained("dandelin/vilt-b32-finetuned-vqa")
def vilt_vqa(image, question):
    inputs = vilt_processor(image, question, return_tensors="pt")
    with torch.no_grad():
        outputs = vilt_model(**inputs)
    logits = outputs.logits
    idx = logits.argmax(-1).item()
    answer = vilt_model.config.id2label[idx]
    logger.info("ViLT: " + answer)
    # Log the top 10 candidate answers for debugging
    topk_values, topk_indices = torch.topk(logits, 10, dim=-1)
    topk_answers = [vilt_model.config.id2label[idx.item()] for idx in topk_indices[0]]
    logger.info("ViLT top 10 answers: " + str(topk_answers))
    return answer
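# Note: ViLT's VQA head is a classifier over the VQAv2 answer vocabulary
# (a few thousand short answers), so its outputs are single words or short
# phrases; the sentence-building step below expands them into full sentences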
# Load FLAN-T5
t5_tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-large")
t5_model = T5ForConditionalGeneration.from_pretrained("google/flan-t5-large", device_map="auto")
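# Note: device_map="auto" requires the `accelerate` package; it places the model
# on a GPU when one is available and otherwise falls back to CPU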
def flan_t5_complete_sentence(question, answer):
    # Ask FLAN-T5 to rephrase the short ViLT answer as a complete sentence
    input_text = f"A question: {question} An answer: {answer}. Based on these, answer the question with a complete sentence without extra information."
    logger.info("T5 input: " + input_text)
    inputs = t5_tokenizer(input_text, return_tensors="pt")
    outputs = t5_model.generate(**inputs, max_length=50)
    result_sentence = t5_tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
    logger.info("T5 output: " + result_sentence)
    return result_sentence
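# Illustrative example (output is model-dependent):
# flan_t5_complete_sentence("What color is the apple?", "red")
# -> something like "The apple is red."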
# Main function
def vqa_main(image, question):
    # Translate the question to English, remembering its source language
    en_question, question_src_lang = google_translate(question, dest='en')
    # Answer language: one named in the question, else the question's own language
    dest_lang = find_dest_language(en_question, question_src_lang)
    # Strip the "in <language>" phrase so it doesn't confuse the VQA model
    cleaned_question = remove_language_phrase(en_question)
    vqa_answer = vilt_vqa(image, cleaned_question)
    llm_answer = flan_t5_complete_sentence(cleaned_question, vqa_answer)
    final_answer, answer_src_lang = google_translate(llm_answer, dest=dest_lang)
    logger.info("Final Answer: " + final_answer)
    return final_answer
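# End-to-end example: for "Qu'est-ce que j'ai dans la main en anglais?" the
# question is translated fr -> en, "english" sets dest_lang='en', the language
# phrase is stripped, ViLT answers, FLAN-T5 builds a sentence, and the final
# translation to English is effectively a no-op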
# Home page text
title = "Interactive demo: Cross-Lingual VQA"
description = """
Upload an image and type a question, then click 'Submit', or click one of the examples to load it.
Note: This demo runs on a CPU, so it may take a few minutes to produce an answer. For better performance, consider duplicating this Space to your own account and upgrading to a GPU runtime.
"""
article = """
107 supported languages: Afrikaans, Albanian, Amharic, Arabic, Armenian, Azerbaijani, Basque, Belarusian, Bengali, Bosnian, Bulgarian, Catalan, Cebuano, Chichewa, Chinese (Simplified), Chinese (Traditional), Corsican, Croatian, Czech, Danish, Dutch, English, Esperanto, Estonian, Filipino, Finnish, French, Frisian, Galician, Georgian, German, Greek, Gujarati, Haitian Creole, Hausa, Hawaiian, Hebrew, Hindi, Hmong, Hungarian, Icelandic, Igbo, Indonesian, Irish, Italian, Japanese, Javanese, Kannada, Kazakh, Khmer, Korean, Kurdish (Kurmanji), Kyrgyz, Lao, Latin, Latvian, Lithuanian, Luxembourgish, Macedonian, Malagasy, Malay, Malayalam, Maltese, Maori, Marathi, Mongolian, Myanmar (Burmese), Nepali, Norwegian, Odia, Pashto, Persian, Polish, Portuguese, Punjabi, Romanian, Russian, Samoan, Scots Gaelic, Serbian, Sesotho, Shona, Sindhi, Sinhala, Slovak, Slovenian, Somali, Spanish, Sundanese, Swahili, Swedish, Tajik, Tamil, Telugu, Thai, Turkish, Ukrainian, Urdu, Uyghur, Uzbek, Vietnamese, Welsh, Xhosa, Yiddish, Yoruba, Zulu
"""
# Load example images
torch.hub.download_url_to_file('http://farm3.staticflickr.com/2710/4520550856_7a9f9ea59d_z.jpg', 'apple.jpg')
torch.hub.download_url_to_file('http://images.cocodataset.org/val2017/000000039769.jpg', 'cats.jpg')
torch.hub.download_url_to_file('https://media.istockphoto.com/id/1174602891/photo/two-monkeys-mom-and-cub-eat-bananas.jpg?s=612x612&w=0&k=20&c=r7VXi9d1wHhyq3iAk9D2Z3yTZiOJMlLNtjdVRBEjG7g=', 'monkey.jpg')
# Define home page variables
image = gr.Image(type="pil")
question = gr.Textbox(label="Question")
answer = gr.Textbox(label="Predicted answer")
examples = [
    ["apple.jpg", "Qu'est-ce que j'ai dans la main en anglais?"],  # French: "What am I holding in my hand, in English?"
    ["cats.jpg", "How many cats are here?"],
    ["monkey.jpg", "In Korean, what are these animals called?"],
    ["apple.jpg", "What color is this? Answer in Uyghur."],
    ["cats.jpg", "What are the cats doing in German?"],
    ["monkey.jpg", "Maymunlar ne yiyor, Çince cevap ver."]  # Turkish: "What are the monkeys eating? Answer in Chinese."
]
demo = gr.Interface(fn=vqa_main,
                    inputs=[image, question],
                    outputs=answer,
                    examples=examples,
                    title=title,
                    description=description,
                    article=article)
demo.launch(debug=True, show_error=True)