from happytransformer import HappyTextToText, TTSettings
from transformers import ViltProcessor
from transformers import ViltForQuestionAnswering
from transformers import AutoTokenizer
from transformers import AutoModelForSeq2SeqLM
from joblib import load
import os
import re
import string
import torch
import pandas as pd


class Predictor:
    '''
    Visual question answering pipeline: ViLT produces a short answer from the
    image and question, a T5 model expands it into a full answer statement,
    and a grammar-correction model polishes the result.
    '''

    def __init__(self):
        auth_token = os.environ.get('TOKEN') or True
        # ViLT model for extracting a short answer from the image + question.
        self.vqa_processor = ViltProcessor.from_pretrained(
            'dandelin/vilt-b32-finetuned-vqa')
        self.vqa_model = ViltForQuestionAnswering.from_pretrained(
            'dandelin/vilt-b32-finetuned-vqa')
        # T5 model for turning question + short answer into a full sentence.
        self.qa_model = AutoModelForSeq2SeqLM.from_pretrained(
            'Madhuri/t5_small_vqa_fs', use_auth_token=auth_token)
        self.qa_tokenizer = AutoTokenizer.from_pretrained(
            'Madhuri/t5_small_vqa_fs', use_auth_token=auth_token)
        # Grammar-correction model for cleaning up the generated sentence.
        self.happy_tt = HappyTextToText(
            'T5', 'vennify/t5-base-grammar-correction')
        self.tt_args = TTSettings(num_beams=5, min_length=1)
        # Classifier that decides whether the input is actually a question.
        model_path = os.path.join(
            os.path.dirname(os.path.abspath(__file__)), 'qa_classifier.joblib')
        self.qa_classifier = load(model_path)

    def is_valid_question(self, question):
        df = pd.DataFrame()
        df['sentence'] = [question]
        return self.qa_classifier.predict(df['sentence'])[0] == 1

    def predict_answer_from_text(self, image, input):
        if image is None:
            return 'Please select an image and ask a question...'

        # Strip punctuation and reject inputs that are too short or that the
        # classifier does not recognize as a question.
        chars = re.escape(string.punctuation)
        question = re.sub(r'[' + chars + ']', '', input)
        if not question or len(question.split()) < 3:
            return 'I cannot understand, please ask a valid question...'
        if not self.is_valid_question(question):
            return 'I can understand only questions, can you please ask a valid question...'

        # Process the question and image with the VQA model to get a short answer.
        encoding = self.vqa_processor(image, question, return_tensors='pt')
        with torch.no_grad():
            outputs = self.vqa_model(**encoding)
        short_answer = self.vqa_model.config.id2label[
            outputs.logits.argmax(-1).item()]

        # Generate an answer statement with the sentence generator model.
        prompt = question + '. ' + short_answer
        input_ids = self.qa_tokenizer(prompt, return_tensors='pt').input_ids
        with torch.no_grad():
            output_ids = self.qa_model.generate(input_ids)
        answers = self.qa_tokenizer.batch_decode(
            output_ids, skip_special_tokens=True)

        # Correct the grammar of the generated answer.
        answer = self.happy_tt.generate_text(
            'grammar: ' + answers[0], args=self.tt_args).text
        print(
            f'question - {question}, answer - {answer}, original_answer - {answers[0]}')
        return answer
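

# A minimal usage sketch (assumptions: the module is run directly, a local
# image file exists, and the 'TOKEN' env var is set if the private model
# requires it; 'sample.jpg' and the question below are hypothetical).
if __name__ == '__main__':
    from PIL import Image

    predictor = Predictor()
    demo_image = Image.open('sample.jpg')  # hypothetical local image
    # Runs the full pipeline: validation -> ViLT short answer ->
    # sentence generation -> grammar correction.
    print(predictor.predict_answer_from_text(demo_image, 'What color is the car'))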