"""Cross-lingual visual question answering (VQA) Gradio app.

A question in any language is translated to English, answered about the image
with ViLT, expanded into a full sentence with Flan-T5, and translated back into
either the question's own language or a language named in the question.
"""
import re

import gradio as gr
import httpcore
import torch
from transformers import (
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    ViltForQuestionAnswering,
    ViltProcessor,
)

# NOTE(review): compatibility shim. The httpx version pinned by googletrans
# presumably references `httpcore.SyncHTTPTransport`, which newer httpcore
# releases no longer define; giving the attribute any value lets the
# googletrans import below succeed. It must run BEFORE googletrans is imported.
# TODO confirm against the installed googletrans/httpx versions.
setattr(httpcore, 'SyncHTTPTransport', 'AsyncHTTPProxy')

from googletrans import LANGCODES, Translator  # noqa: E402  (must follow the shim)

# Map the FIRST word of each googletrans language name to its code. Questions
# are matched on that single word, so e.g. "haitian" must resolve to the code
# of "haitian creole" and "kurdish" to that of "kurdish (kurmanji)"; looking the
# first word up directly in LANGCODES would raise KeyError for those.
_LANGUAGE_CODES = {name.split()[0]: code for name, code in LANGCODES.items()}
# "chinese (simplified)" and "chinese (traditional)" collide on "chinese", so
# pin the intended codes for the Chinese variants explicitly.
_LANGUAGE_CODES.update({"mandarin": "zh-cn", "cantonese": "zh-tw", "chinese": "zh-tw"})

# List of acceptable languages: lower-case single words that may appear in a
# question to select the answer language.
acceptable_languages = set(_LANGUAGE_CODES)

# Alternation of all acceptable language words. Sorted (longest first) so the
# pattern is deterministic regardless of set iteration order; escaped so a
# language name can never be misread as regex syntax.
_LANGUAGE_ALTERNATION = "|".join(
    re.escape(lang)
    for lang in sorted(acceptable_languages, key=lambda s: (-len(s), s))
)
# A whole-word language mention, e.g. "French" in "... in French?".
_DEST_LANGUAGE_RE = re.compile(r"\b(" + _LANGUAGE_ALTERNATION + r")\b", re.IGNORECASE)
# "in <language>" or a bare "<language>", plus any whitespace/punctuation that
# directly follows it.
_LANGUAGE_PHRASE_RE = re.compile(
    r"\b(in\s+)?(" + _LANGUAGE_ALTERNATION + r")\b[\s,;:.!?]*", re.IGNORECASE
)


# Translation
def google_translate(question, dest):
    """Translate *question* into the language code *dest* using googletrans.

    No ``src`` is passed, so googletrans's default source language
    (auto-detection) applies.

    Returns:
        A ``(translated_text, source_language_code)`` tuple.
    """
    translator = Translator()
    translation = translator.translate(question, dest=dest)
    print("Translation text: " + translation.text)
    print("Translation src: " + translation.src)
    return (translation.text, translation.src)


# Lang to lang_code mapping
def lang_code_match(accaptable_lang):
    """Return the googletrans language code for a language name.

    Accepts the single-word names in ``acceptable_languages`` (e.g. "french",
    "haitian", "mandarin") as well as full googletrans names such as
    "haitian creole". Matching is case-insensitive. "mandarin" maps to
    ``zh-cn``; "chinese" and "cantonese" map to ``zh-tw``.

    Raises:
        KeyError: if the name is not a known language.
    """
    name = accaptable_lang.lower()
    if name in _LANGUAGE_CODES:
        return _LANGUAGE_CODES[name]
    # Full multi-word googletrans names, e.g. "chinese (simplified)".
    return LANGCODES[name]


# Find destination language
def find_dest_language(sentence, src_lang):
    """Choose the language code the final answer should be written in.

    If *sentence* mentions an acceptable language (whole word,
    case-insensitive), the first such mention wins. Otherwise the answer goes
    back in *src_lang*, the language the question was asked in.
    """
    match = _DEST_LANGUAGE_RE.search(sentence)
    dest = lang_code_match(match.group(0)) if match else src_lang
    print("Destination lang: " + dest)
    return dest


# Remove destination language context
def remove_language_phrase(sentence):
    """Strip answer-language mentions from *sentence* before it reaches the models.

    Removes every "in <language>" or bare "<language>" (whole word,
    case-insensitive) along with any whitespace or ``, ; : . ! ?`` that
    directly follows it, then trims the result. For example,
    "What are the monkeys doing in French?" -> "What are the monkeys doing".
    """
    cleaned_sentence = _LANGUAGE_PHRASE_RE.sub("", sentence).strip()
    print("Language Phrase Removed: " + cleaned_sentence)
    return cleaned_sentence


def vqa(image, text):
    """Answer the English question *text* about *image* with ViLT.

    Returns the label (from the model's ``id2label`` mapping) of the
    highest-scoring answer class.
    """
    encoding = vqa_processor(image, text, return_tensors="pt")
    with torch.no_grad():  # inference only; no gradients needed
        outputs = vqa_model(**encoding)
    idx = outputs.logits.argmax(-1).item()
    return vqa_model.config.id2label[idx]


def llm(cleaned_sentence, vqa_answer):
    """Rephrase the short VQA answer as a complete English sentence with Flan-T5.

    Args:
        cleaned_sentence: The English question with language mentions removed.
        vqa_answer: The short answer produced by :func:`vqa`.

    Returns:
        The decoded Flan-T5 output (generation capped at 50 tokens).
    """
    # Prepare the input prompt
    prompt = (
        f"A question: {cleaned_sentence}\n"
        f"An answer: {vqa_answer}.\n"
        f"Based on these, answer the question with a complete sentence without extra information."
    )
    inputs = flan_tokenizer(prompt, return_tensors="pt")
    outputs = flan_model.generate(**inputs, max_length=50)
    response = flan_tokenizer.decode(outputs[0], skip_special_tokens=True)
    print("T5 prompt: " + prompt)
    print("T5 response: " + response)
    return response


# Example images shown in the Gradio examples table.
torch.hub.download_url_to_file('https://media.istockphoto.com/id/1174602891/photo/two-monkeys-mom-and-cub-eat-bananas.jpg?s=612x612&w=0&k=20&c=r7VXi9d1wHhyq3iAk9D2Z3yTZiOJMlLNtjdVRBEjG7g=', 'monkeys.jpg')
torch.hub.download_url_to_file('https://img.freepik.com/premium-photo/man-holds-apple-his-hands_198067-740023.jpg', 'apple.jpg')

# Models are loaded once at import time and used as globals by vqa()/llm().
vqa_processor = ViltProcessor.from_pretrained("dandelin/vilt-b32-finetuned-vqa")
vqa_model = ViltForQuestionAnswering.from_pretrained("dandelin/vilt-b32-finetuned-vqa")
flan_tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-large")
flan_model = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-large")


def main(image, text):
    """Gradio handler: answer *text* about *image* in the requested language.

    Pipeline: translate the question to English (recording its source
    language) -> choose the answer language -> strip the language mention ->
    ViLT short answer -> Flan-T5 full sentence -> translate into the answer
    language.

    Returns:
        ``(vqa_answer, final_answer)``: the raw English ViLT answer and the
        translated full-sentence answer, one per output component.
    """
    # Guard against a missing image or blank question before calling any
    # model or the translation service.
    if image is None or not (text or "").strip():
        return "", "Please provide an image and a question."

    en_question, question_src_lang = google_translate(text, dest='en')
    dest_lang = find_dest_language(en_question, question_src_lang)
    cleaned_sentence = remove_language_phrase(en_question)
    vqa_answer = vqa(image, cleaned_sentence)
    llm_answer = llm(cleaned_sentence, vqa_answer)
    final_answer, _ = google_translate(llm_answer, dest=dest_lang)
    print("Final answer: ", final_answer)
    return vqa_answer, final_answer


image = gr.Image(type="pil")
question = gr.Textbox(label="Question")
# main() returns two values, so the interface needs two output components;
# with a single Textbox the whole tuple would be forced into one box.
vqa_output = gr.Textbox(label="VQA model answer (English)")
answer = gr.Textbox(label="Predicted answer")

examples = [
    ["monkeys.jpg", "What are the monkeys doing in French?"],
    ["apple.jpg", "Qu'est-ce que c'est dans ma main en anglais?"],
    ["monkeys.jpg", "In Uyghur, tell me how many monkeys are here?"],
    ["apple.jpg", "What color is this fruit? Chinese."],
]

title = "Cross-lingual VQA"
description = """
Visual Question Answering (VQA) across Languages

Input an image and ask a question regarding the image.
Click on the examples below to see sample inputs and outputs.
"""
# Generated from LANGCODES so the documented list always matches what the
# code accepts (and contains no duplicates).
article = (
    "Ask your question in any language. The answer is given in the language of "
    "the question, unless the question names another language (e.g. \"in French\"). "
    "Languages that can be named: "
    + ", ".join(sorted(LANGCODES))
    + ", plus \"mandarin\" and \"cantonese\"."
)

interface = gr.Interface(
    fn=main,
    inputs=[image, question],
    outputs=[vqa_output, answer],
    examples=examples,
    title=title,
    description=description,
    article=article,
)
interface.launch(debug=True)