Madhuri committed on
Commit
78c15e2
1 Parent(s): 460b215

Add question validation model in front of VQA

Browse files
chatbot.py CHANGED
@@ -33,6 +33,9 @@ def predict(image, input):
33
  answer = st.session_state.predictor.predict_answer_from_text(image, input)
34
  st.session_state.question.append(input)
35
  st.session_state.answer.append(answer)
 
 
 
36
 
37
 
38
  def show():
33
  answer = st.session_state.predictor.predict_answer_from_text(image, input)
34
  st.session_state.question.append(input)
35
  st.session_state.answer.append(answer)
36
+ while len(st.session_state.question) >= 5:
37
+ st.session_state.answer.pop(0)
38
+ st.session_state.question.pop(0)
39
 
40
 
41
  def show():
model/predictor.py CHANGED
@@ -4,12 +4,13 @@ from transformers import ViltProcessor
4
  from transformers import ViltForQuestionAnswering
5
  from transformers import AutoTokenizer
6
  from transformers import AutoModelForSeq2SeqLM
 
7
 
8
  import os
9
  import re
10
  import string
11
  import torch
12
-
13
 
14
  '''
15
  Visual Question Answering Model to generate answer statement for
@@ -32,16 +33,26 @@ class Predictor:
32
  self.happy_tt = HappyTextToText(
33
  "T5", "vennify/t5-base-grammar-correction")
34
  self.tt_args = TTSettings(num_beams=5, min_length=1)
 
 
 
 
 
 
 
35
 
36
  def predict_answer_from_text(self, image, input):
37
  if image is None:
38
- return 'Please select an image...'
39
 
40
  chars = re.escape(string.punctuation)
41
  question = re.sub(r'['+chars+']', '', input)
42
  if not question or len(question.split()) < 3:
43
  return 'I cannot understand, please ask a valid question...'
44
 
 
 
 
45
  # process question using image model
46
  encoding = self.vqa_processor(image, question, return_tensors='pt')
47
  with torch.no_grad():
4
  from transformers import ViltForQuestionAnswering
5
  from transformers import AutoTokenizer
6
  from transformers import AutoModelForSeq2SeqLM
7
+ from joblib import load
8
 
9
  import os
10
  import re
11
  import string
12
  import torch
13
+ import pandas as pd
14
 
15
  '''
16
  Visual Question Answering Model to generate answer statement for
33
  self.happy_tt = HappyTextToText(
34
  "T5", "vennify/t5-base-grammar-correction")
35
  self.tt_args = TTSettings(num_beams=5, min_length=1)
36
+ model_path= os.path.join( os.path.dirname(os.path.abspath(__file__)), 'qa_classifier.joblib')
37
+ self.qa_classifier = load(model_path)
38
+
39
+ def is_valid_question(self, question):
40
+ df=pd.DataFrame()
41
+ df['sentence']=[question]
42
+ return self.qa_classifier.predict(df['sentence'])[0] == 1
43
 
44
  def predict_answer_from_text(self, image, input):
45
  if image is None:
46
+ return 'Please select an image and ask a question...'
47
 
48
  chars = re.escape(string.punctuation)
49
  question = re.sub(r'['+chars+']', '', input)
50
  if not question or len(question.split()) < 3:
51
  return 'I cannot understand, please ask a valid question...'
52
 
53
+ if not self.is_valid_question(question):
54
+ return 'I can understand only questions, can you please ask a valid question...'
55
+
56
  # process question using image model
57
  encoding = self.vqa_processor(image, question, return_tensors='pt')
58
  with torch.no_grad():
model/qa_classifier.joblib ADDED
@@ -0,0 +1,3 @@
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:37ab3777a05c935a42303a7ad07e4dc8bfc6ea07ce8a3da4dbfab16104f9c7af
3
+ size 197373
requirements.txt CHANGED
@@ -97,11 +97,14 @@ requests==2.28.0
97
  responses==0.18.0
98
  rich==12.4.4
99
  say==1.6.6
 
 
100
  semver==2.13.0
101
  Send2Trash==1.8.0
102
  sentencepiece==0.1.96
103
  simplere==1.2.13
104
  six==1.12.0
 
105
  smmap==5.0.0
106
  soupsieve==2.3.2.post1
107
  stack-data==0.3.0
@@ -110,6 +113,7 @@ streamlit-bokeh-events==0.1.2
110
  streamlit-chat==0.0.2.1
111
  terminado==0.15.0
112
  textwrap3==0.9.2
 
113
  tinycss2==1.1.1
114
  tokenizers==0.12.1
115
  toml==0.10.2
97
  responses==0.18.0
98
  rich==12.4.4
99
  say==1.6.6
100
+ scikit-learn==1.1.1
101
+ scipy==1.8.1
102
  semver==2.13.0
103
  Send2Trash==1.8.0
104
  sentencepiece==0.1.96
105
  simplere==1.2.13
106
  six==1.12.0
107
+ sklearn==0.0
108
  smmap==5.0.0
109
  soupsieve==2.3.2.post1
110
  stack-data==0.3.0
113
  streamlit-chat==0.0.2.1
114
  terminado==0.15.0
115
  textwrap3==0.9.2
116
+ threadpoolctl==3.1.0
117
  tinycss2==1.1.1
118
  tokenizers==0.12.1
119
  toml==0.10.2