"""Visual Question Answering model that generates a full answer statement for a
question about an image.

Pipeline: ViLT classifies a short answer from (image, question), then a T5
sentence generator rewrites "question. short_answer" into a natural sentence.
"""
import os

import streamlit as st
import torch
from transformers import AutoModelForSeq2SeqLM
from transformers import AutoTokenizer
from transformers import ViltForQuestionAnswering
from transformers import ViltProcessor

# Model identifiers hoisted to constants so the paired from_pretrained calls
# (processor/model, tokenizer/model) cannot silently drift apart.
VQA_MODEL_ID = 'dandelin/vilt-b32-finetuned-vqa'
QA_MODEL_ID = 'Madhuri/t5_small_vqa_fs'


# NOTE(review): st.experimental_singleton is deprecated in newer Streamlit
# (replaced by st.cache_resource); kept as-is for compatibility with the
# Streamlit version this app targets.
@st.experimental_singleton
class Predictor:
    """Cached two-stage VQA predictor (constructed once per Streamlit server)."""

    def __init__(self):
        # If the TOKEN env var is unset (or empty), fall back to True, which
        # tells transformers to use the locally cached Hugging Face login token.
        auth_token = os.environ.get('TOKEN') or True
        self.vqa_processor = ViltProcessor.from_pretrained(VQA_MODEL_ID)
        self.vqa_model = ViltForQuestionAnswering.from_pretrained(VQA_MODEL_ID)
        self.qa_model = AutoModelForSeq2SeqLM.from_pretrained(
            QA_MODEL_ID, use_auth_token=auth_token)
        self.qa_tokenizer = AutoTokenizer.from_pretrained(
            QA_MODEL_ID, use_auth_token=auth_token)

    def predict_answer_from_text(self, image, question: str) -> str:
        """Answer *question* about *image* as a full sentence.

        Args:
            image: PIL image (or anything ViltProcessor accepts) to query.
            question: natural-language question about the image.

        Returns:
            A generated answer sentence, or '' when the question is empty
            or no image is provided.
        """
        if not question or image is None:
            return ''

        # Stage 1: ViLT scores the fixed VQA answer vocabulary; take the
        # argmax label as the short answer.
        encoding = self.vqa_processor(image, question, return_tensors='pt')
        with torch.no_grad():
            outputs = self.vqa_model(**encoding)
        short_answer = self.vqa_model.config.id2label[
            outputs.logits.argmax(-1).item()]

        # Stage 2: feed "question. short_answer" to the T5 generator to
        # produce a complete answer statement.
        prompt = question + '. ' + short_answer
        input_ids = self.qa_tokenizer(prompt, return_tensors='pt').input_ids
        with torch.no_grad():
            output_ids = self.qa_model.generate(input_ids)
        answers = self.qa_tokenizer.batch_decode(
            output_ids, skip_special_tokens=True)
        return answers[0]