Spaces:
Runtime error
Runtime error
Add question validation model in front of VQA
Browse files
- chatbot.py +3 -0
- model/predictor.py +13 -2
- model/qa_classifier.joblib +3 -0
- requirements.txt +4 -0
chatbot.py
CHANGED
@@ -33,6 +33,9 @@ def predict(image, input):
|
|
33 |
answer = st.session_state.predictor.predict_answer_from_text(image, input)
|
34 |
st.session_state.question.append(input)
|
35 |
st.session_state.answer.append(answer)
|
|
|
|
|
|
|
36 |
|
37 |
|
38 |
def show():
|
33 |
answer = st.session_state.predictor.predict_answer_from_text(image, input)
|
34 |
st.session_state.question.append(input)
|
35 |
st.session_state.answer.append(answer)
|
36 |
+
while len(st.session_state.question) >= 5:
|
37 |
+
st.session_state.answer.pop(0)
|
38 |
+
st.session_state.question.pop(0)
|
39 |
|
40 |
|
41 |
def show():
|
model/predictor.py
CHANGED
@@ -4,12 +4,13 @@ from transformers import ViltProcessor
|
|
4 |
from transformers import ViltForQuestionAnswering
|
5 |
from transformers import AutoTokenizer
|
6 |
from transformers import AutoModelForSeq2SeqLM
|
|
|
7 |
|
8 |
import os
|
9 |
import re
|
10 |
import string
|
11 |
import torch
|
12 |
-
|
13 |
|
14 |
'''
|
15 |
Visual Question Answering Model to generate answer statement for
|
@@ -32,16 +33,26 @@ class Predictor:
|
|
32 |
self.happy_tt = HappyTextToText(
|
33 |
"T5", "vennify/t5-base-grammar-correction")
|
34 |
self.tt_args = TTSettings(num_beams=5, min_length=1)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
35 |
|
36 |
def predict_answer_from_text(self, image, input):
|
37 |
if image is None:
|
38 |
-
return 'Please select an image...'
|
39 |
|
40 |
chars = re.escape(string.punctuation)
|
41 |
question = re.sub(r'['+chars+']', '', input)
|
42 |
if not question or len(question.split()) < 3:
|
43 |
return 'I cannot understand, please ask a valid question...'
|
44 |
|
|
|
|
|
|
|
45 |
# process question using image model
|
46 |
encoding = self.vqa_processor(image, question, return_tensors='pt')
|
47 |
with torch.no_grad():
|
4 |
from transformers import ViltForQuestionAnswering
|
5 |
from transformers import AutoTokenizer
|
6 |
from transformers import AutoModelForSeq2SeqLM
|
7 |
+
from joblib import load
|
8 |
|
9 |
import os
|
10 |
import re
|
11 |
import string
|
12 |
import torch
|
13 |
+
import pandas as pd
|
14 |
|
15 |
'''
|
16 |
Visual Question Answering Model to generate answer statement for
|
33 |
self.happy_tt = HappyTextToText(
|
34 |
"T5", "vennify/t5-base-grammar-correction")
|
35 |
self.tt_args = TTSettings(num_beams=5, min_length=1)
|
36 |
+
model_path= os.path.join( os.path.dirname(os.path.abspath(__file__)), 'qa_classifier.joblib')
|
37 |
+
self.qa_classifier = load(model_path)
|
38 |
+
|
39 |
+
def is_valid_question(self, question):
|
40 |
+
df=pd.DataFrame()
|
41 |
+
df['sentence']=[question]
|
42 |
+
return self.qa_classifier.predict(df['sentence'])[0] == 1
|
43 |
|
44 |
def predict_answer_from_text(self, image, input):
|
45 |
if image is None:
|
46 |
+
return 'Please select an image and ask a question...'
|
47 |
|
48 |
chars = re.escape(string.punctuation)
|
49 |
question = re.sub(r'['+chars+']', '', input)
|
50 |
if not question or len(question.split()) < 3:
|
51 |
return 'I cannot understand, please ask a valid question...'
|
52 |
|
53 |
+
if not self.is_valid_question(question):
|
54 |
+
return 'I can understand only questions, can you please ask a valid question...'
|
55 |
+
|
56 |
# process question using image model
|
57 |
encoding = self.vqa_processor(image, question, return_tensors='pt')
|
58 |
with torch.no_grad():
|
model/qa_classifier.joblib
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:37ab3777a05c935a42303a7ad07e4dc8bfc6ea07ce8a3da4dbfab16104f9c7af
|
3 |
+
size 197373
|
requirements.txt
CHANGED
@@ -97,11 +97,14 @@ requests==2.28.0
|
|
97 |
responses==0.18.0
|
98 |
rich==12.4.4
|
99 |
say==1.6.6
|
|
|
|
|
100 |
semver==2.13.0
|
101 |
Send2Trash==1.8.0
|
102 |
sentencepiece==0.1.96
|
103 |
simplere==1.2.13
|
104 |
six==1.12.0
|
|
|
105 |
smmap==5.0.0
|
106 |
soupsieve==2.3.2.post1
|
107 |
stack-data==0.3.0
|
@@ -110,6 +113,7 @@ streamlit-bokeh-events==0.1.2
|
|
110 |
streamlit-chat==0.0.2.1
|
111 |
terminado==0.15.0
|
112 |
textwrap3==0.9.2
|
|
|
113 |
tinycss2==1.1.1
|
114 |
tokenizers==0.12.1
|
115 |
toml==0.10.2
|
97 |
responses==0.18.0
|
98 |
rich==12.4.4
|
99 |
say==1.6.6
|
100 |
+
scikit-learn==1.1.1
|
101 |
+
scipy==1.8.1
|
102 |
semver==2.13.0
|
103 |
Send2Trash==1.8.0
|
104 |
sentencepiece==0.1.96
|
105 |
simplere==1.2.13
|
106 |
six==1.12.0
|
107 |
+
sklearn==0.0
|
108 |
smmap==5.0.0
|
109 |
soupsieve==2.3.2.post1
|
110 |
stack-data==0.3.0
|
113 |
streamlit-chat==0.0.2.1
|
114 |
terminado==0.15.0
|
115 |
textwrap3==0.9.2
|
116 |
+
threadpoolctl==3.1.0
|
117 |
tinycss2==1.1.1
|
118 |
tokenizers==0.12.1
|
119 |
toml==0.10.2
|