from happytransformer import HappyTextToText, TTSettings
from transformers import ViltProcessor
from transformers import ViltForQuestionAnswering
from transformers import AutoTokenizer
from transformers import AutoModelForSeq2SeqLM
from joblib import load
import os
import re
import string
import torch
import pandas as pd

'''
Visual question answering predictor: answers a question about an image with a
short label from ViLT, then expands it into a full answer statement and
corrects its grammar.
'''


class Predictor:
    def __init__(self):
        auth_token = os.environ.get('TOKEN') or True
        # ViLT model that predicts a short answer label for an image/question pair.
        self.vqa_processor = ViltProcessor.from_pretrained(
            'dandelin/vilt-b32-finetuned-vqa')
        self.vqa_model = ViltForQuestionAnswering.from_pretrained(
            'dandelin/vilt-b32-finetuned-vqa')
        # Seq2seq model that expands a question plus short answer into a sentence.
        self.qa_model = AutoModelForSeq2SeqLM.from_pretrained(
            'Madhuri/t5_small_vqa_fs', use_auth_token=auth_token)
        self.qa_tokenizer = AutoTokenizer.from_pretrained(
            'Madhuri/t5_small_vqa_fs', use_auth_token=auth_token)
        # T5 grammar-correction model used to polish the generated sentence.
        self.happy_tt = HappyTextToText(
            'T5', 'vennify/t5-base-grammar-correction')
        self.tt_args = TTSettings(num_beams=5, min_length=1)
        # Classifier that decides whether the input text is a question.
        model_path = os.path.join(
            os.path.dirname(os.path.abspath(__file__)), 'qa_classifier.joblib')
        self.qa_classifier = load(model_path)

    def is_valid_question(self, question):
        # The classifier predicts 1 when the text looks like a question.
        df = pd.DataFrame({'sentence': [question]})
        return self.qa_classifier.predict(df['sentence'])[0] == 1

    def predict_answer_from_text(self, image, input):
        if image is None:
            return 'Please select an image and ask a question...'

        # Strip punctuation and make sure the input is a real question.
        chars = re.escape(string.punctuation)
        question = re.sub(r'[' + chars + ']', '', input)
        if not question or len(question.split()) < 3:
            return 'I cannot understand, please ask a valid question...'
        if not self.is_valid_question(question):
            return 'I can understand only questions, can you please ask a valid question...'

        # Predict a short answer label with the ViLT image/question model.
        encoding = self.vqa_processor(image, question, return_tensors='pt')
        with torch.no_grad():
            outputs = self.vqa_model(**encoding)
        short_answer = self.vqa_model.config.id2label[
            outputs.logits.argmax(-1).item()]

        # Generate a full answer statement from the question and short answer.
        prompt = question + '. ' + short_answer
        input_ids = self.qa_tokenizer(prompt, return_tensors='pt').input_ids
        with torch.no_grad():
            output_ids = self.qa_model.generate(input_ids)
        answers = self.qa_tokenizer.batch_decode(
            output_ids, skip_special_tokens=True)

        # Correct the grammar of the generated answer.
        answer = self.happy_tt.generate_text(
            'grammar: ' + answers[0], args=self.tt_args).text
        print(
            f'question - {question}, answer - {answer}, original_answer - {answers[0]}')
        return answer
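

# Minimal usage sketch (assumptions: the caller supplies a PIL image; the
# 'example.jpg' path and this __main__ guard are illustrative only and are
# not part of the original Space code).
if __name__ == '__main__':
    from PIL import Image

    predictor = Predictor()
    image = Image.open('example.jpg').convert('RGB')
    print(predictor.predict_answer_from_text(image, 'What color is the cat?'))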