Spaces:
Runtime error
Runtime error
Add grammar module to correct the generated answers.
Browse files
- .DS_Store +0 -0
- .gitignore +80 -1
- app.py +7 -5
- audiobot.py +7 -7
- chatbot.py +7 -5
- images/.DS_Store +0 -0
- model/predictor.py +10 -2
- requirements.txt +17 -0
.DS_Store
DELETED
Binary file (6.15 kB)
|
|
.gitignore
CHANGED
@@ -2,9 +2,88 @@
|
|
2 |
__pycache__/
|
3 |
*.py[cod]
|
4 |
|
|
|
|
|
|
|
5 |
# Distribution / packaging
|
6 |
.Python
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
7 |
|
8 |
# Installer logs
|
9 |
pip-log.txt
|
10 |
-
pip-delete-this-directory.txt
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2 |
__pycache__/
|
3 |
*.py[cod]
|
4 |
|
5 |
+
# C extensions
|
6 |
+
*.so
|
7 |
+
|
8 |
# Distribution / packaging
|
9 |
.Python
|
10 |
+
env/
|
11 |
+
build/
|
12 |
+
develop-eggs/
|
13 |
+
dist/
|
14 |
+
downloads/
|
15 |
+
eggs/
|
16 |
+
.eggs/
|
17 |
+
lib/
|
18 |
+
lib64/
|
19 |
+
parts/
|
20 |
+
sdist/
|
21 |
+
var/
|
22 |
+
*.egg-info/
|
23 |
+
.installed.cfg
|
24 |
+
*.egg
|
25 |
+
|
26 |
+
# PyInstaller
|
27 |
+
# Usually these files are written by a python script from a template
|
28 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
29 |
+
*.manifest
|
30 |
+
*.spec
|
31 |
|
32 |
# Installer logs
|
33 |
pip-log.txt
|
34 |
+
pip-delete-this-directory.txt
|
35 |
+
|
36 |
+
# Unit test / coverage reports
|
37 |
+
htmlcov/
|
38 |
+
.tox/
|
39 |
+
.coverage
|
40 |
+
.coverage.*
|
41 |
+
.cache
|
42 |
+
nosetests.xml
|
43 |
+
coverage.xml
|
44 |
+
*.cover
|
45 |
+
|
46 |
+
# Translations
|
47 |
+
*.mo
|
48 |
+
*.pot
|
49 |
+
|
50 |
+
# Django stuff:
|
51 |
+
*.log
|
52 |
+
|
53 |
+
# Sphinx documentation
|
54 |
+
docs/_build/
|
55 |
+
|
56 |
+
# PyBuilder
|
57 |
+
target/
|
58 |
+
|
59 |
+
# DotEnv configuration
|
60 |
+
.env
|
61 |
+
|
62 |
+
# Database
|
63 |
+
*.db
|
64 |
+
*.rdb
|
65 |
+
|
66 |
+
# Pycharm
|
67 |
+
.idea
|
68 |
+
|
69 |
+
# VS Code
|
70 |
+
.vscode/
|
71 |
+
|
72 |
+
# Spyder
|
73 |
+
.spyproject/
|
74 |
+
|
75 |
+
# Jupyter NB Checkpoints
|
76 |
+
.ipynb_checkpoints/
|
77 |
+
|
78 |
+
# Mac OS-specific storage files
|
79 |
+
.DS_Store
|
80 |
+
|
81 |
+
# vim
|
82 |
+
*.swp
|
83 |
+
*.swo
|
84 |
+
|
85 |
+
# Mypy cache
|
86 |
+
.mypy_cache/
|
87 |
+
|
88 |
+
# exclude generated models from source control
|
89 |
+
models/intermediate/
|
app.py
CHANGED
@@ -7,9 +7,12 @@ import chatbot
|
|
7 |
import os
|
8 |
import threading
|
9 |
|
|
|
10 |
def runInThread():
|
11 |
print('Initialize model in thread')
|
12 |
st.session_state['predictor'] = predictor.Predictor()
|
|
|
|
|
13 |
|
14 |
def run():
|
15 |
st.set_page_config(
|
@@ -19,6 +22,10 @@ def run():
|
|
19 |
)
|
20 |
|
21 |
os.environ['TOKENIZERS_PARALLELISM'] = 'false'
|
|
|
|
|
|
|
|
|
22 |
|
23 |
st.sidebar.title('VQA Bot')
|
24 |
st.sidebar.image('./images/logo.png')
|
@@ -37,10 +44,5 @@ def run():
|
|
37 |
|
38 |
st.caption("Created by Madhuri Sakhare - [Github](https://github.com/msak1612/vqa_chatbot) [Linkedin](https://www.linkedin.com/in/madhuri-sakhare/)")
|
39 |
|
40 |
-
if 'thread' not in st.session_state:
|
41 |
-
st.session_state.thread = threading.Thread(target=runInThread)
|
42 |
-
add_script_run_ctx(st.session_state.thread)
|
43 |
-
st.session_state.thread.start()
|
44 |
-
|
45 |
|
46 |
run()
|
|
|
7 |
import os
|
8 |
import threading
|
9 |
|
10 |
+
|
11 |
def runInThread():
|
12 |
print('Initialize model in thread')
|
13 |
st.session_state['predictor'] = predictor.Predictor()
|
14 |
+
print('Model is initialized')
|
15 |
+
|
16 |
|
17 |
def run():
|
18 |
st.set_page_config(
|
|
|
22 |
)
|
23 |
|
24 |
os.environ['TOKENIZERS_PARALLELISM'] = 'false'
|
25 |
+
if 'thread' not in st.session_state:
|
26 |
+
st.session_state.thread = threading.Thread(target=runInThread)
|
27 |
+
add_script_run_ctx(st.session_state.thread)
|
28 |
+
st.session_state.thread.start()
|
29 |
|
30 |
st.sidebar.title('VQA Bot')
|
31 |
st.sidebar.image('./images/logo.png')
|
|
|
44 |
|
45 |
st.caption("Created by Madhuri Sakhare - [Github](https://github.com/msak1612/vqa_chatbot) [Linkedin](https://www.linkedin.com/in/madhuri-sakhare/)")
|
46 |
|
|
|
|
|
|
|
|
|
|
|
47 |
|
48 |
run()
|
audiobot.py
CHANGED
@@ -7,6 +7,7 @@ from streamlit_bokeh_events import streamlit_bokeh_events
|
|
7 |
from bokeh.models.widgets.buttons import Button
|
8 |
import time
|
9 |
|
|
|
10 |
def show():
|
11 |
st.session_state.audio_answer = ''
|
12 |
|
@@ -18,7 +19,7 @@ def show():
|
|
18 |
</i></h4>
|
19 |
''', unsafe_allow_html=True)
|
20 |
|
21 |
-
weights = [5,2]
|
22 |
image_col, audio_col = st.columns(weights)
|
23 |
with image_col:
|
24 |
upload_pic = st.file_uploader('Choose an image...', type=[
|
@@ -30,8 +31,8 @@ def show():
|
|
30 |
st.session_state.image = None
|
31 |
|
32 |
with audio_col:
|
33 |
-
welcome_text='Hello and Welcome. I have been trained as visual question answering model. You are welcome to look at any image and ask me any questions about it. I will do my best to provide the most accurate information possible based on my expertise. Select an image of interest by pressing the browse files button. Now use the Ask question button to ask a question. Please feel free to ask me any questions about this image. Now. to get my answer. press the Get answer button.'
|
34 |
-
welcome_button = Button(label='About Me')
|
35 |
welcome_button.js_on_event('button_click', CustomJS(code=f'''
|
36 |
var u = new SpeechSynthesisUtterance();
|
37 |
u.text = '{welcome_text}';
|
@@ -43,7 +44,7 @@ def show():
|
|
43 |
|
44 |
# Speech recognition based in streamlit based on
|
45 |
# https://discuss.streamlit.io/t/speech-to-text-on-client-side-using-html5-and-streamlit-bokeh-events/7888
|
46 |
-
stt_button = Button(label='Ask Question')
|
47 |
|
48 |
stt_button.js_on_event('button_click', CustomJS(code="""
|
49 |
var recognition = new webkitSpeechRecognition();
|
@@ -51,7 +52,7 @@ def show():
|
|
51 |
recognition.interimResults = false;
|
52 |
|
53 |
recognition.onresult = function (e) {
|
54 |
-
var value =
|
55 |
for (var i = e.resultIndex; i < e.results.length; ++i) {
|
56 |
if (e.results[i].isFinal) {
|
57 |
value += e.results[i][0].transcript;
|
@@ -80,8 +81,7 @@ def show():
|
|
80 |
st.session_state.audio_answer = st.session_state.predictor.predict_answer_from_text(
|
81 |
st.session_state.image, result.get('GET_TEXT'))
|
82 |
|
83 |
-
|
84 |
-
tts_button = Button(label='Get Answer')
|
85 |
tts_button.js_on_event('button_click', CustomJS(code=f"""
|
86 |
var u = new SpeechSynthesisUtterance();
|
87 |
u.text = '{st.session_state.audio_answer}';
|
|
|
7 |
from bokeh.models.widgets.buttons import Button
|
8 |
import time
|
9 |
|
10 |
+
|
11 |
def show():
|
12 |
st.session_state.audio_answer = ''
|
13 |
|
|
|
19 |
</i></h4>
|
20 |
''', unsafe_allow_html=True)
|
21 |
|
22 |
+
weights = [5, 2]
|
23 |
image_col, audio_col = st.columns(weights)
|
24 |
with image_col:
|
25 |
upload_pic = st.file_uploader('Choose an image...', type=[
|
|
|
31 |
st.session_state.image = None
|
32 |
|
33 |
with audio_col:
|
34 |
+
welcome_text = 'Hello and Welcome. I have been trained as visual question answering model. You are welcome to look at any image and ask me any questions about it. I will do my best to provide the most accurate information possible based on my expertise. Select an image of interest by pressing the browse files button. Now use the Ask question button to ask a question. Please feel free to ask me any questions about this image. Now. to get my answer. press the Get answer button.'
|
35 |
+
welcome_button = Button(label='About Me', width=100)
|
36 |
welcome_button.js_on_event('button_click', CustomJS(code=f'''
|
37 |
var u = new SpeechSynthesisUtterance();
|
38 |
u.text = '{welcome_text}';
|
|
|
44 |
|
45 |
# Speech recognition based in streamlit based on
|
46 |
# https://discuss.streamlit.io/t/speech-to-text-on-client-side-using-html5-and-streamlit-bokeh-events/7888
|
47 |
+
stt_button = Button(label='Ask Question', width=100)
|
48 |
|
49 |
stt_button.js_on_event('button_click', CustomJS(code="""
|
50 |
var recognition = new webkitSpeechRecognition();
|
|
|
52 |
recognition.interimResults = false;
|
53 |
|
54 |
recognition.onresult = function (e) {
|
55 |
+
var value = '';
|
56 |
for (var i = e.resultIndex; i < e.results.length; ++i) {
|
57 |
if (e.results[i].isFinal) {
|
58 |
value += e.results[i][0].transcript;
|
|
|
81 |
st.session_state.audio_answer = st.session_state.predictor.predict_answer_from_text(
|
82 |
st.session_state.image, result.get('GET_TEXT'))
|
83 |
|
84 |
+
tts_button = Button(label='Get Answer', width=100)
|
|
|
85 |
tts_button.js_on_event('button_click', CustomJS(code=f"""
|
86 |
var u = new SpeechSynthesisUtterance();
|
87 |
u.text = '{st.session_state.audio_answer}';
|
chatbot.py
CHANGED
@@ -3,6 +3,7 @@ from streamlit_chat import message
|
|
3 |
from PIL import Image
|
4 |
import time
|
5 |
|
|
|
6 |
def init_chat_history():
|
7 |
if 'question' not in st.session_state:
|
8 |
st.session_state['question'] = []
|
@@ -23,9 +24,10 @@ def predict(image, input):
|
|
23 |
if image is None or not input:
|
24 |
return
|
25 |
|
26 |
-
|
27 |
-
|
28 |
-
|
|
|
29 |
|
30 |
answer = st.session_state.predictor.predict_answer_from_text(image, input)
|
31 |
st.session_state.question.append(input)
|
@@ -51,8 +53,8 @@ def show():
|
|
51 |
image = Image.open(upload_pic)
|
52 |
st.image(upload_pic, use_column_width='auto')
|
53 |
else:
|
54 |
-
st.session_state.question=[]
|
55 |
-
st.session_state.answer=[]
|
56 |
st.session_state.input = ''
|
57 |
with text_col:
|
58 |
input = st.text_input('Enter question: ', '', key='input')
|
|
|
3 |
from PIL import Image
|
4 |
import time
|
5 |
|
6 |
+
|
7 |
def init_chat_history():
|
8 |
if 'question' not in st.session_state:
|
9 |
st.session_state['question'] = []
|
|
|
24 |
if image is None or not input:
|
25 |
return
|
26 |
|
27 |
+
if 'predictor' not in st.session_state:
|
28 |
+
with st.spinner('Preparing answer...'):
|
29 |
+
while 'predictor' not in st.session_state:
|
30 |
+
time.sleep(2)
|
31 |
|
32 |
answer = st.session_state.predictor.predict_answer_from_text(image, input)
|
33 |
st.session_state.question.append(input)
|
|
|
53 |
image = Image.open(upload_pic)
|
54 |
st.image(upload_pic, use_column_width='auto')
|
55 |
else:
|
56 |
+
st.session_state.question = []
|
57 |
+
st.session_state.answer = []
|
58 |
st.session_state.input = ''
|
59 |
with text_col:
|
60 |
input = st.text_input('Enter question: ', '', key='input')
|
images/.DS_Store
DELETED
Binary file (6.15 kB)
|
|
model/predictor.py
CHANGED
@@ -1,4 +1,5 @@
|
|
1 |
import streamlit as st
|
|
|
2 |
from transformers import ViltProcessor
|
3 |
from transformers import ViltForQuestionAnswering
|
4 |
from transformers import AutoTokenizer
|
@@ -28,7 +29,9 @@ class Predictor:
|
|
28 |
'Madhuri/t5_small_vqa_fs', use_auth_token=auth_token)
|
29 |
self.qa_tokenizer = AutoTokenizer.from_pretrained(
|
30 |
'Madhuri/t5_small_vqa_fs', use_auth_token=auth_token)
|
31 |
-
|
|
|
|
|
32 |
|
33 |
def predict_answer_from_text(self, image, input):
|
34 |
if image is None:
|
@@ -54,4 +57,9 @@ class Predictor:
|
|
54 |
answers = self.qa_tokenizer.batch_decode(
|
55 |
output_ids, skip_special_tokens=True)
|
56 |
|
57 |
-
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import streamlit as st
|
2 |
+
from happytransformer import HappyTextToText, TTSettings
|
3 |
from transformers import ViltProcessor
|
4 |
from transformers import ViltForQuestionAnswering
|
5 |
from transformers import AutoTokenizer
|
|
|
29 |
'Madhuri/t5_small_vqa_fs', use_auth_token=auth_token)
|
30 |
self.qa_tokenizer = AutoTokenizer.from_pretrained(
|
31 |
'Madhuri/t5_small_vqa_fs', use_auth_token=auth_token)
|
32 |
+
self.happy_tt = HappyTextToText(
|
33 |
+
"T5", "vennify/t5-base-grammar-correction")
|
34 |
+
self.tt_args = TTSettings(num_beams=5, min_length=1)
|
35 |
|
36 |
def predict_answer_from_text(self, image, input):
|
37 |
if image is None:
|
|
|
57 |
answers = self.qa_tokenizer.batch_decode(
|
58 |
output_ids, skip_special_tokens=True)
|
59 |
|
60 |
+
# Correct the grammar of the answer
|
61 |
+
answer = self.happy_tt.generate_text(
|
62 |
+
'grammar: ' + answers[0], args=self.tt_args).text
|
63 |
+
print(
|
64 |
+
f'question - {question}, answer - {answer}, original_answer - {answers[0]}')
|
65 |
+
return answer
|
requirements.txt
CHANGED
@@ -1,3 +1,5 @@
|
|
|
|
|
|
1 |
altair==4.2.0
|
2 |
ansicolors==1.1.8
|
3 |
ansiwrap==0.8.4
|
@@ -5,6 +7,7 @@ appnope==0.1.3
|
|
5 |
argon2-cffi==21.3.0
|
6 |
argon2-cffi-bindings==21.2.0
|
7 |
asttokens==2.0.5
|
|
|
8 |
attrs==21.4.0
|
9 |
backcall==0.2.0
|
10 |
beautifulsoup4==4.11.1
|
@@ -19,15 +22,20 @@ charset-normalizer==2.0.12
|
|
19 |
click==8.1.3
|
20 |
combomethod==1.0.12
|
21 |
commonmark==0.9.1
|
|
|
22 |
debugpy==1.6.0
|
23 |
decorator==5.1.1
|
24 |
defusedxml==0.7.1
|
|
|
25 |
entrypoints==0.4
|
26 |
executing==0.8.3
|
27 |
fastjsonschema==2.15.3
|
28 |
filelock==3.7.1
|
|
|
|
|
29 |
gitdb==4.0.9
|
30 |
GitPython==3.1.27
|
|
|
31 |
huggingface-hub==0.7.0
|
32 |
idna==3.3
|
33 |
importlib-metadata==4.11.4
|
@@ -37,19 +45,24 @@ ipython-genutils==0.2.0
|
|
37 |
ipywidgets==7.7.0
|
38 |
jedi==0.18.1
|
39 |
Jinja2==3.1.2
|
|
|
40 |
jsonschema==4.6.0
|
41 |
jupyter-client==7.3.4
|
42 |
jupyter-core==4.10.0
|
43 |
jupyterlab-pygments==0.2.2
|
44 |
jupyterlab-widgets==1.1.0
|
|
|
45 |
MarkupSafe==2.1.1
|
46 |
matplotlib-inline==0.1.3
|
47 |
mementos==1.3.1
|
48 |
mistune==0.8.4
|
|
|
|
|
49 |
nbclient==0.6.4
|
50 |
nbconvert==6.5.0
|
51 |
nbformat==5.4.0
|
52 |
nest-asyncio==1.5.5
|
|
|
53 |
notebook==6.4.12
|
54 |
nulltype==2.3.1
|
55 |
numpy==1.22.4
|
@@ -81,10 +94,12 @@ PyYAML==6.0
|
|
81 |
pyzmq==23.1.0
|
82 |
regex==2022.6.2
|
83 |
requests==2.28.0
|
|
|
84 |
rich==12.4.4
|
85 |
say==1.6.6
|
86 |
semver==2.13.0
|
87 |
Send2Trash==1.8.0
|
|
|
88 |
simplere==1.2.13
|
89 |
six==1.12.0
|
90 |
smmap==5.0.0
|
@@ -112,4 +127,6 @@ validators==0.20.0
|
|
112 |
wcwidth==0.2.5
|
113 |
webencodings==0.5.1
|
114 |
widgetsnbextension==3.6.0
|
|
|
|
|
115 |
zipp==3.8.0
|
|
|
1 |
+
aiohttp==3.8.1
|
2 |
+
aiosignal==1.2.0
|
3 |
altair==4.2.0
|
4 |
ansicolors==1.1.8
|
5 |
ansiwrap==0.8.4
|
|
|
7 |
argon2-cffi==21.3.0
|
8 |
argon2-cffi-bindings==21.2.0
|
9 |
asttokens==2.0.5
|
10 |
+
async-timeout==4.0.2
|
11 |
attrs==21.4.0
|
12 |
backcall==0.2.0
|
13 |
beautifulsoup4==4.11.1
|
|
|
22 |
click==8.1.3
|
23 |
combomethod==1.0.12
|
24 |
commonmark==0.9.1
|
25 |
+
datasets==2.3.2
|
26 |
debugpy==1.6.0
|
27 |
decorator==5.1.1
|
28 |
defusedxml==0.7.1
|
29 |
+
dill==0.3.5.1
|
30 |
entrypoints==0.4
|
31 |
executing==0.8.3
|
32 |
fastjsonschema==2.15.3
|
33 |
filelock==3.7.1
|
34 |
+
frozenlist==1.3.0
|
35 |
+
fsspec==2022.5.0
|
36 |
gitdb==4.0.9
|
37 |
GitPython==3.1.27
|
38 |
+
happytransformer==2.4.1
|
39 |
huggingface-hub==0.7.0
|
40 |
idna==3.3
|
41 |
importlib-metadata==4.11.4
|
|
|
45 |
ipywidgets==7.7.0
|
46 |
jedi==0.18.1
|
47 |
Jinja2==3.1.2
|
48 |
+
joblib==1.1.0
|
49 |
jsonschema==4.6.0
|
50 |
jupyter-client==7.3.4
|
51 |
jupyter-core==4.10.0
|
52 |
jupyterlab-pygments==0.2.2
|
53 |
jupyterlab-widgets==1.1.0
|
54 |
+
language-tool-python==2.7.1
|
55 |
MarkupSafe==2.1.1
|
56 |
matplotlib-inline==0.1.3
|
57 |
mementos==1.3.1
|
58 |
mistune==0.8.4
|
59 |
+
multidict==6.0.2
|
60 |
+
multiprocess==0.70.13
|
61 |
nbclient==0.6.4
|
62 |
nbconvert==6.5.0
|
63 |
nbformat==5.4.0
|
64 |
nest-asyncio==1.5.5
|
65 |
+
nltk==3.7
|
66 |
notebook==6.4.12
|
67 |
nulltype==2.3.1
|
68 |
numpy==1.22.4
|
|
|
94 |
pyzmq==23.1.0
|
95 |
regex==2022.6.2
|
96 |
requests==2.28.0
|
97 |
+
responses==0.18.0
|
98 |
rich==12.4.4
|
99 |
say==1.6.6
|
100 |
semver==2.13.0
|
101 |
Send2Trash==1.8.0
|
102 |
+
sentencepiece==0.1.96
|
103 |
simplere==1.2.13
|
104 |
six==1.12.0
|
105 |
smmap==5.0.0
|
|
|
127 |
wcwidth==0.2.5
|
128 |
webencodings==0.5.1
|
129 |
widgetsnbextension==3.6.0
|
130 |
+
xxhash==3.0.0
|
131 |
+
yarl==1.7.2
|
132 |
zipp==3.8.0
|