"""Gradio demo: multilingual visual question answering.

Pipeline (see ``vqa_main``):
1. Translate the user's question to English with googletrans.
2. Answer it with ViLT, which produces a short label.
3. Rephrase that label as a complete sentence with FLAN-T5.
4. Translate the sentence back to the question's detected language.
"""
import os

import gradio as gr
import torch
from googletrans import Translator
from transformers import (
    T5ForConditionalGeneration,
    T5Tokenizer,
    ViltForQuestionAnswering,
    ViltProcessor,
)
from transformers.utils import logging

logging.set_verbosity_info()
logger = logging.get_logger("transformers")


# Translation
def google_translate(question, dest):
    """Translate *question* into language *dest* using googletrans.

    Args:
        question: Text to translate.
        dest: Target language code passed to ``Translator.translate``
            (e.g. ``"en"``).

    Returns:
        A ``(translated_text, detected_source_language)`` tuple.
    """
    translator = Translator()
    translation = translator.translate(question, dest=dest)
    logger.info("Translation text: %s", translation.text)
    logger.info("Translation src: %s", translation.src)
    return (translation.text, translation.src)


# Load ViLT
VILT_CHECKPOINT = "dandelin/vilt-b32-finetuned-vqa"
vilt_processor = ViltProcessor.from_pretrained(VILT_CHECKPOINT)
vilt_model = ViltForQuestionAnswering.from_pretrained(VILT_CHECKPOINT)


def vilt_vqa(image, question):
    """Answer an English *question* about *image* with ViLT.

    Args:
        image: Image accepted by ``ViltProcessor`` (a PIL image from the UI).
        question: English question text.

    Returns:
        The highest-scoring answer label from the model's ``id2label`` table.
    """
    inputs = vilt_processor(image, question, return_tensors="pt")
    with torch.no_grad():
        outputs = vilt_model(**inputs)
    idx = outputs.logits.argmax(-1).item()
    answer = vilt_model.config.id2label[idx]
    logger.info("ViLT: %s", answer)
    return answer


# Load FLAN-T5
T5_CHECKPOINT = "google/flan-t5-large"
t5_tokenizer = T5Tokenizer.from_pretrained(T5_CHECKPOINT)
t5_model = T5ForConditionalGeneration.from_pretrained(T5_CHECKPOINT, device_map="auto")


def flan_t5_complete_sentence(question, answer):
    """Turn a short VQA *answer* into a full-sentence reply with FLAN-T5.

    Args:
        question: English question text.
        answer: Short answer produced by ``vilt_vqa``.

    Returns:
        A list of decoded strings, one per generated sequence. Callers take
        element ``[0]``.
    """
    input_text = (
        f"A question: {question} An answer: {answer}. "
        "Based on these, answer the question with a complete sentence "
        "without extra information."
    )
    logger.info("T5 input: %s", input_text)
    # device_map="auto" may place the model on a GPU, while the tokenizer
    # returns CPU tensors. Move the inputs to the model's device so
    # generate() does not fail on a device mismatch.
    inputs = t5_tokenizer(input_text, return_tensors="pt").to(t5_model.device)
    outputs = t5_model.generate(**inputs, max_length=50)
    result_sentence = t5_tokenizer.batch_decode(outputs, skip_special_tokens=True)
    logger.info("T5 output: %s", result_sentence)
    return result_sentence


# Main function
def vqa_main(image, question):
    """Answer *question* about *image* in the question's own language.

    Args:
        image: PIL image from the Gradio image input, or ``None`` if no image
            was uploaded.
        question: Question text in any language googletrans can detect.

    Returns:
        The answer sentence, translated back to the detected language of
        *question*. If the image or question is missing, returns a short
        English prompt instead.
    """
    # Without this guard the ViLT processor fails with an opaque error.
    if image is None or not (question or "").strip():
        return "Please upload an image and type a question."

    en_question, question_src_lang = google_translate(question, dest="en")
    vqa_answer = vilt_vqa(image, en_question)
    llm_answer = flan_t5_complete_sentence(en_question, vqa_answer)[0]
    if question_src_lang == "en":
        # Already in the user's language; skip a redundant network round trip.
        final_answer = llm_answer
    else:
        final_answer, _ = google_translate(llm_answer, dest=question_src_lang)
    logger.info("Final Answer: %s", final_answer)
    return final_answer


# Home page text
title = "Interactive demo: Multilingual VQA"
description = """
Upload an image, type a question, click 'submit', or click one of the examples to load them.
Note: You may see "Error" displayed for output due to Gradio's incompatibility, if that is the case, please click "Logs" above to obtain the output for your request.
"""

# Each language is listed once. 'hebrew' previously appeared twice, which
# inflated the advertised count to 107. The count below is derived from the
# tuple, so it cannot drift out of sync again.
SUPPORTED_LANGUAGES = (
    "afrikaans", "albanian", "amharic", "arabic", "armenian", "azerbaijani",
    "basque", "belarusian", "bengali", "bosnian", "bulgarian", "catalan",
    "cebuano", "chichewa", "chinese (simplified)", "chinese (traditional)",
    "corsican", "croatian", "czech", "danish", "dutch", "english",
    "esperanto", "estonian", "filipino", "finnish", "french", "frisian",
    "galician", "georgian", "german", "greek", "gujarati", "haitian creole",
    "hausa", "hawaiian", "hebrew", "hindi", "hmong", "hungarian",
    "icelandic", "igbo", "indonesian", "irish", "italian", "japanese",
    "javanese", "kannada", "kazakh", "khmer", "korean", "kurdish (kurmanji)",
    "kyrgyz", "lao", "latin", "latvian", "lithuanian", "luxembourgish",
    "macedonian", "malagasy", "malay", "malayalam", "maltese", "maori",
    "marathi", "mongolian", "myanmar (burmese)", "nepali", "norwegian",
    "odia", "pashto", "persian", "polish", "portuguese", "punjabi",
    "romanian", "russian", "samoan", "scots gaelic", "serbian", "sesotho",
    "shona", "sindhi", "sinhala", "slovak", "slovenian", "somali",
    "spanish", "sundanese", "swahili", "swedish", "tajik", "tamil",
    "telugu", "thai", "turkish", "ukrainian", "urdu", "uyghur", "uzbek",
    "vietnamese", "welsh", "xhosa", "yiddish", "yoruba", "zulu",
)
article = (
    f"Supported {len(SUPPORTED_LANGUAGES)} Languages: "
    + ", ".join(f"'{name}'" for name in SUPPORTED_LANGUAGES)
)

# Load example images. They must exist on disk before the Interface is built
# with them as examples. Existing files are reused, so restarts do not
# re-download them.
EXAMPLE_IMAGES = {
    "apple.jpg": "http://farm3.staticflickr.com/2710/4520550856_7a9f9ea59d_z.jpg",
    "cats.jpg": "http://images.cocodataset.org/val2017/000000039769.jpg",
}
for filename, url in EXAMPLE_IMAGES.items():
    if not os.path.exists(filename):
        torch.hub.download_url_to_file(url, filename)

# Define home page variables. Top-level components replace the deprecated
# gr.inputs / gr.outputs namespaces.
image = gr.Image(type="pil")
question = gr.Textbox(label="Question")
answer = gr.Textbox(label="Predicted answer")
examples = [
    ["apple.jpg", "Qu'est-ce que c'est dans ma main?"],
    ["cats.jpg", "What are the cats doing?"],
]

demo = gr.Interface(
    fn=vqa_main,
    inputs=[image, question],
    outputs=answer,
    examples=examples,
    title=title,
    description=description,
    article=article,
)

if __name__ == "__main__":
    # queue() replaces the deprecated Interface(enable_queue=True) flag.
    demo.queue().launch(debug=True, show_error=True)