Madhuri committed
Commit 5560825
1 Parent(s): 7aa61b0

Add grammar module to correct the generated answers.

Files changed (8):
  1. .DS_Store +0 -0
  2. .gitignore +80 -1
  3. app.py +7 -5
  4. audiobot.py +7 -7
  5. chatbot.py +7 -5
  6. images/.DS_Store +0 -0
  7. model/predictor.py +10 -2
  8. requirements.txt +17 -0
.DS_Store DELETED
Binary file (6.15 kB)
 
.gitignore CHANGED
@@ -2,9 +2,88 @@
 __pycache__/
 *.py[cod]
 
+# C extensions
+*.so
+
 # Distribution / packaging
 .Python
+env/
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+*.egg-info/
+.installed.cfg
+*.egg
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
 
 # Installer logs
 pip-log.txt
-pip-delete-this-directory.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+target/
+
+# DotEnv configuration
+.env
+
+# Database
+*.db
+*.rdb
+
+# Pycharm
+.idea
+
+# VS Code
+.vscode/
+
+# Spyder
+.spyproject/
+
+# Jupyter NB Checkpoints
+.ipynb_checkpoints/
+
+# Mac OS-specific storage files
+.DS_Store
+
+# vim
+*.swp
+*.swo
+
+# Mypy cache
+.mypy_cache/
+
+# exclude generated models from source control
+models/intermediate/
app.py CHANGED
@@ -7,9 +7,12 @@ import chatbot
 import os
 import threading
 
+
 def runInThread():
     print('Initialize model in thread')
     st.session_state['predictor'] = predictor.Predictor()
+    print('Model is initialized')
+
 
 def run():
     st.set_page_config(
@@ -19,6 +22,10 @@ def run():
     )
 
     os.environ['TOKENIZERS_PARALLELISM'] = 'false'
+    if 'thread' not in st.session_state:
+        st.session_state.thread = threading.Thread(target=runInThread)
+        add_script_run_ctx(st.session_state.thread)
+        st.session_state.thread.start()
 
     st.sidebar.title('VQA Bot')
     st.sidebar.image('./images/logo.png')
@@ -37,10 +44,5 @@ def run():
 
     st.caption("Created by Madhuri Sakhare - [Github](https://github.com/msak1612/vqa_chatbot) [Linkedin](https://www.linkedin.com/in/madhuri-sakhare/)")
 
-    if 'thread' not in st.session_state:
-        st.session_state.thread = threading.Thread(target=runInThread)
-        add_script_run_ctx(st.session_state.thread)
-        st.session_state.thread.start()
-
 
 run()
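The net effect of these app.py hunks is to start model initialization at the top of run(), before the sidebar is drawn, instead of at the bottom of the script. A minimal sketch of the same background-initialization pattern, assuming a recent Streamlit (the add_script_run_ctx import path has moved between releases, so treat the import below as an assumption):

import threading
import time

import streamlit as st
from streamlit.runtime.scriptrunner import add_script_run_ctx


def load_model():
    # Stand-in for a slow model download/initialization.
    time.sleep(5)
    st.session_state['predictor'] = 'ready'


if 'thread' not in st.session_state:
    st.session_state.thread = threading.Thread(target=load_model)
    # Attach the current script-run context so the worker thread can
    # write to st.session_state without Streamlit context warnings.
    add_script_run_ctx(st.session_state.thread)
    st.session_state.thread.start()

st.write('UI renders immediately; the model loads in the background.')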
audiobot.py CHANGED
@@ -7,6 +7,7 @@ from streamlit_bokeh_events import streamlit_bokeh_events
 from bokeh.models.widgets.buttons import Button
 import time
 
+
 def show():
     st.session_state.audio_answer = ''
 
@@ -18,7 +19,7 @@ def show():
     </i></h4>
     ''', unsafe_allow_html=True)
 
-    weights = [5,2]
+    weights = [5, 2]
     image_col, audio_col = st.columns(weights)
     with image_col:
         upload_pic = st.file_uploader('Choose an image...', type=[
@@ -30,8 +31,8 @@ def show():
         st.session_state.image = None
 
     with audio_col:
-        welcome_text='Hello and Welcome. I have been trained as visual question answering model. You are welcome to look at any image and ask me any questions about it. I will do my best to provide the most accurate information possible based on my expertise. Select an image of interest by pressing the browse files button. Now use the Ask question button to ask a question. Please feel free to ask me any questions about this image. Now. to get my answer. press the Get answer button.'
-        welcome_button = Button(label='About Me')
+        welcome_text = 'Hello and Welcome. I have been trained as visual question answering model. You are welcome to look at any image and ask me any questions about it. I will do my best to provide the most accurate information possible based on my expertise. Select an image of interest by pressing the browse files button. Now use the Ask question button to ask a question. Please feel free to ask me any questions about this image. Now. to get my answer. press the Get answer button.'
+        welcome_button = Button(label='About Me', width=100)
         welcome_button.js_on_event('button_click', CustomJS(code=f'''
             var u = new SpeechSynthesisUtterance();
             u.text = '{welcome_text}';
@@ -43,7 +44,7 @@ def show():
 
        # Speech recognition based in streamlit based on
        # https://discuss.streamlit.io/t/speech-to-text-on-client-side-using-html5-and-streamlit-bokeh-events/7888
-        stt_button = Button(label='Ask Question')
+        stt_button = Button(label='Ask Question', width=100)
 
        stt_button.js_on_event('button_click', CustomJS(code="""
            var recognition = new webkitSpeechRecognition();
@@ -51,7 +52,7 @@ def show():
            recognition.interimResults = false;
 
            recognition.onresult = function (e) {
-                var value = "";
+                var value = '';
                for (var i = e.resultIndex; i < e.results.length; ++i) {
                    if (e.results[i].isFinal) {
                        value += e.results[i][0].transcript;
@@ -80,8 +81,7 @@ def show():
            st.session_state.audio_answer = st.session_state.predictor.predict_answer_from_text(
                st.session_state.image, result.get('GET_TEXT'))
 
-
-        tts_button = Button(label='Get Answer')
+        tts_button = Button(label='Get Answer', width=100)
        tts_button.js_on_event('button_click', CustomJS(code=f"""
            var u = new SpeechSynthesisUtterance();
            u.text = '{st.session_state.audio_answer}';
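These audiobot.py hunks are mostly cosmetic (uniform button widths, consistent quoting), but they sit inside a pattern worth spelling out: a Bokeh Button runs browser-side JavaScript through CustomJS, the JavaScript dispatches a DOM CustomEvent, and streamlit_bokeh_events delivers the payload back to Python on the next rerun. A stripped-down sketch of that round-trip, with an illustrative label and event name:

import streamlit as st
from bokeh.models import CustomJS
from bokeh.models.widgets.buttons import Button
from streamlit_bokeh_events import streamlit_bokeh_events

ask = Button(label='Ask Question', width=100)
ask.js_on_event('button_click', CustomJS(code="""
    // Web Speech API; webkitSpeechRecognition is Chrome-only.
    var recognition = new webkitSpeechRecognition();
    recognition.onresult = function (e) {
        var text = e.results[0][0].transcript;
        // Hand the transcript back to Python as a DOM CustomEvent.
        document.dispatchEvent(new CustomEvent('GET_TEXT', {detail: text}));
    };
    recognition.start();
"""))

result = streamlit_bokeh_events(
    ask,
    events='GET_TEXT',        # event name(s) Python listens for
    key='listen',
    refresh_on_update=False,  # rerun only when an event arrives
    override_height=75,
    debounce_time=0)

if result and 'GET_TEXT' in result:
    st.write('You said:', result.get('GET_TEXT'))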
chatbot.py CHANGED
@@ -3,6 +3,7 @@ from streamlit_chat import message
 from PIL import Image
 import time
 
+
 def init_chat_history():
     if 'question' not in st.session_state:
         st.session_state['question'] = []
@@ -23,9 +24,10 @@ def predict(image, input):
     if image is None or not input:
         return
 
-    with st.spinner('Preparing answer...'):
-        while 'predictor' not in st.session_state:
-            time.sleep(2)
+    if 'predictor' not in st.session_state:
+        with st.spinner('Preparing answer...'):
+            while 'predictor' not in st.session_state:
+                time.sleep(2)
 
     answer = st.session_state.predictor.predict_answer_from_text(image, input)
     st.session_state.question.append(input)
@@ -51,8 +53,8 @@ def show():
         image = Image.open(upload_pic)
         st.image(upload_pic, use_column_width='auto')
     else:
-        st.session_state.question=[]
-        st.session_state.answer=[]
+        st.session_state.question = []
+        st.session_state.answer = []
         st.session_state.input = ''
     with text_col:
         input = st.text_input('Enter question: ', '', key='input')
images/.DS_Store DELETED
Binary file (6.15 kB)
 
model/predictor.py CHANGED
@@ -1,4 +1,5 @@
 import streamlit as st
+from happytransformer import HappyTextToText, TTSettings
 from transformers import ViltProcessor
 from transformers import ViltForQuestionAnswering
 from transformers import AutoTokenizer
@@ -28,7 +29,9 @@ class Predictor:
             'Madhuri/t5_small_vqa_fs', use_auth_token=auth_token)
         self.qa_tokenizer = AutoTokenizer.from_pretrained(
             'Madhuri/t5_small_vqa_fs', use_auth_token=auth_token)
-
+        self.happy_tt = HappyTextToText(
+            "T5", "vennify/t5-base-grammar-correction")
+        self.tt_args = TTSettings(num_beams=5, min_length=1)
 
     def predict_answer_from_text(self, image, input):
         if image is None:
@@ -54,4 +57,9 @@ class Predictor:
         answers = self.qa_tokenizer.batch_decode(
             output_ids, skip_special_tokens=True)
 
-        return answers[0]
+        # Correct the grammar of the answer
+        answer = self.happy_tt.generate_text(
+            'grammar: ' + answers[0], args=self.tt_args).text
+        print(
+            f'question - {question}, answer - {answer}, original_answer - {answers[0]}')
+        return answer
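This is the commit's core change: every answer decoded from the VQA model is passed through vennify/t5-base-grammar-correction (wrapped by happytransformer) before being returned, and the print statement logs the corrected answer next to the original for comparison. The grammar module is easy to exercise on its own; a minimal standalone sketch, with an illustrative input sentence:

from happytransformer import HappyTextToText, TTSettings

# The 'grammar: ' prefix is the task prompt this T5 checkpoint was
# fine-tuned with; num_beams/min_length mirror the settings in this commit.
happy_tt = HappyTextToText('T5', 'vennify/t5-base-grammar-correction')
tt_args = TTSettings(num_beams=5, min_length=1)

raw_answer = 'there is two dog in picture'  # illustrative raw VQA output
result = happy_tt.generate_text('grammar: ' + raw_answer, args=tt_args)
print(result.text)  # expected output along the lines of 'There are two dogs in the picture.'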
requirements.txt CHANGED
@@ -1,3 +1,5 @@
+aiohttp==3.8.1
+aiosignal==1.2.0
 altair==4.2.0
 ansicolors==1.1.8
 ansiwrap==0.8.4
@@ -5,6 +7,7 @@ appnope==0.1.3
 argon2-cffi==21.3.0
 argon2-cffi-bindings==21.2.0
 asttokens==2.0.5
+async-timeout==4.0.2
 attrs==21.4.0
 backcall==0.2.0
 beautifulsoup4==4.11.1
@@ -19,15 +22,20 @@ charset-normalizer==2.0.12
 click==8.1.3
 combomethod==1.0.12
 commonmark==0.9.1
+datasets==2.3.2
 debugpy==1.6.0
 decorator==5.1.1
 defusedxml==0.7.1
+dill==0.3.5.1
 entrypoints==0.4
 executing==0.8.3
 fastjsonschema==2.15.3
 filelock==3.7.1
+frozenlist==1.3.0
+fsspec==2022.5.0
 gitdb==4.0.9
 GitPython==3.1.27
+happytransformer==2.4.1
 huggingface-hub==0.7.0
 idna==3.3
 importlib-metadata==4.11.4
@@ -37,19 +45,24 @@ ipython-genutils==0.2.0
 ipywidgets==7.7.0
 jedi==0.18.1
 Jinja2==3.1.2
+joblib==1.1.0
 jsonschema==4.6.0
 jupyter-client==7.3.4
 jupyter-core==4.10.0
 jupyterlab-pygments==0.2.2
 jupyterlab-widgets==1.1.0
+language-tool-python==2.7.1
 MarkupSafe==2.1.1
 matplotlib-inline==0.1.3
 mementos==1.3.1
 mistune==0.8.4
+multidict==6.0.2
+multiprocess==0.70.13
 nbclient==0.6.4
 nbconvert==6.5.0
 nbformat==5.4.0
 nest-asyncio==1.5.5
+nltk==3.7
 notebook==6.4.12
 nulltype==2.3.1
 numpy==1.22.4
@@ -81,10 +94,12 @@ PyYAML==6.0
 pyzmq==23.1.0
 regex==2022.6.2
 requests==2.28.0
+responses==0.18.0
 rich==12.4.4
 say==1.6.6
 semver==2.13.0
 Send2Trash==1.8.0
+sentencepiece==0.1.96
 simplere==1.2.13
 six==1.12.0
 smmap==5.0.0
@@ -112,4 +127,6 @@ validators==0.20.0
 wcwidth==0.2.5
 webencodings==0.5.1
 widgetsnbextension==3.6.0
+xxhash==3.0.0
+yarl==1.7.2
 zipp==3.8.0