File size: 2,929 Bytes
5560825
7a69915
 
 
 
78c15e2
7a69915
 
3089ae4
 
7a69915
78c15e2
7a69915
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5560825
 
 
78c15e2
 
 
 
 
 
 
3089ae4
 
 
78c15e2
3089ae4
 
 
 
 
7a69915
78c15e2
 
 
7a69915
 
 
 
 
 
 
 
 
 
 
 
 
 
3089ae4
5560825
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
from happytransformer import HappyTextToText, TTSettings
from transformers import ViltProcessor
from transformers import ViltForQuestionAnswering
from transformers import AutoTokenizer
from transformers import AutoModelForSeq2SeqLM
from joblib import load

import os
import re
import string
import torch
import pandas as pd

'''
Visual Question Answering model: generates a full answer statement for a
question asked about an image.
'''


class Predictor:
    """Visual question answering pipeline.

    Chains four models: a ViLT VQA model (short answer from image +
    question), a T5 sentence generator (expands the short answer into a
    full statement), a T5 grammar-correction model, and a joblib-serialized
    classifier that screens out inputs that are not questions.
    """

    # Translation table that strips every ASCII punctuation character.
    # Built once at class level instead of re-escaping and re-matching a
    # regex on every call; removes exactly the same characters.
    _PUNCT_TABLE = str.maketrans('', '', string.punctuation)

    def __init__(self):
        """Load all models and the question classifier.

        Reads the optional TOKEN environment variable for hub
        authentication; `or True` falls back to the locally cached hub
        credentials when no explicit token is set.
        """
        auth_token = os.environ.get('TOKEN') or True
        self.vqa_processor = ViltProcessor.from_pretrained(
            'dandelin/vilt-b32-finetuned-vqa')
        self.vqa_model = ViltForQuestionAnswering.from_pretrained(
            'dandelin/vilt-b32-finetuned-vqa')
        self.qa_model = AutoModelForSeq2SeqLM.from_pretrained(
            'Madhuri/t5_small_vqa_fs', use_auth_token=auth_token)
        self.qa_tokenizer = AutoTokenizer.from_pretrained(
            'Madhuri/t5_small_vqa_fs', use_auth_token=auth_token)
        self.happy_tt = HappyTextToText(
            "T5", "vennify/t5-base-grammar-correction")
        self.tt_args = TTSettings(num_beams=5, min_length=1)
        # The classifier ships next to this module.
        model_path = os.path.join(
            os.path.dirname(os.path.abspath(__file__)), 'qa_classifier.joblib')
        self.qa_classifier = load(model_path)

    def is_valid_question(self, question):
        """Return True when the classifier labels *question* as a question.

        The classifier was trained on a pandas text column, so a
        one-element Series is passed instead of building a throwaway
        DataFrame around it.
        """
        return self.qa_classifier.predict(pd.Series([question]))[0] == 1

    def predict_answer_from_text(self, image, input):
        """Answer the question *input* about *image*.

        Returns a grammatically corrected full-sentence answer, or a
        user-facing hint string when the image is missing or the text is
        not a usable question.

        NOTE(review): the parameter name `input` shadows the builtin; it
        is kept unchanged so existing keyword callers keep working.
        """
        if image is None:
            return 'Please select an image and ask a question...'

        # Strip punctuation before validating; a usable question needs at
        # least three words.
        question = input.translate(self._PUNCT_TABLE)
        if not question or len(question.split()) < 3:
            return 'I cannot understand, please ask a valid question...'

        if not self.is_valid_question(question):
            return 'I can understand only questions, can you please ask a valid question...'

        # Short answer (single word/phrase) from the ViLT VQA model.
        encoding = self.vqa_processor(image, question, return_tensors='pt')
        with torch.no_grad():
            outputs = self.vqa_model(**encoding)
        short_answer = self.vqa_model.config.id2label[
            outputs.logits.argmax(-1).item()]

        # Expand question + short answer into a full statement with the
        # seq2seq sentence generator.
        prompt = question + '. ' + short_answer
        input_ids = self.qa_tokenizer(prompt, return_tensors='pt').input_ids
        with torch.no_grad():
            output_ids = self.qa_model.generate(input_ids)
        answers = self.qa_tokenizer.batch_decode(
            output_ids, skip_special_tokens=True)

        # Clean up the grammar of the generated statement.
        answer = self.happy_tt.generate_text(
            'grammar: ' + answers[0], args=self.tt_args).text
        print(
            f'question - {question}, answer - {answer}, original_answer - {answers[0]}')
        return answer