Madhuri committed on
Commit 4c71f4e
1 Parent(s): 25f8b3c

Use a client-server approach for model-to-UI communication.
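
With this change the model no longer runs inside the Streamlit process. app.py now spawns the FastAPI app in server.py under uvicorn on port 8080 in a separate process; the server loads the VQA model once in a background thread, and the chat and audio pages talk to it over HTTP through the new helpers in helper.py. A minimal sketch of the round trip the UI now performs, assuming the server is already up on port 8080 (the local image path is hypothetical):

```python
import requests

# Upload once; the server names the file by its SHA-256 hash and returns that name.
with open('images/image1.jpg', 'rb') as fh:  # hypothetical local file
    filename = requests.post('http://0.0.0.0:8080/uploadfile/',
                             files={'file': fh}).json()['filename']

# Ask questions about the uploaded image by its server-side filename.
resp = requests.get('http://0.0.0.0:8080/vqa',
                    params={'image': filename, 'question': 'what is in the picture?'})
print(resp.json()['answer'])
```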

Files changed (8)
  1. .gitignore +4 -1
  2. app.py +33 -18
  3. audiobot.py +9 -24
  4. chatbot.py +13 -32
  5. helper.py +43 -0
  6. model/predictor.py +0 -2
  7. requirements.txt +10 -0
  8. server.py +68 -0
.gitignore CHANGED
@@ -86,4 +86,7 @@ target/
 .mypy_cache/
 
 # exclude generated models from source control
-models/intermediate/
+models/intermediate/
+
+# exclude uploaded images from source control
+images/upload_*
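
The new `images/upload_*` pattern covers the content-hashed files (`images/upload_<sha256>.jpg`) that server.py writes for user uploads, keeping them out of source control while the bundled gallery images (`images/image*`) remain tracked.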
app.py CHANGED
@@ -1,32 +1,21 @@
+import uvicorn
 import streamlit as st
+from multiprocessing import Process
 
-from model import predictor
-from streamlit.scriptrunner import add_script_run_ctx
+import socket
+import time
 import audiobot
 import chatbot
 import os
-import threading
 
 
-def runInThread():
-    print('Initialize model in thread')
-    st.session_state['predictor'] = predictor.Predictor()
-    print('Model is initialized')
-
-
-def run():
+def run_st_app():
     st.set_page_config(
         page_title='Welcome to Visual Question Answering - Bot',
         page_icon=':robot:',
         layout='wide'
     )
 
-    os.environ['TOKENIZERS_PARALLELISM'] = 'false'
-    if 'thread' not in st.session_state:
-        st.session_state.thread = threading.Thread(target=runInThread)
-        add_script_run_ctx(st.session_state.thread)
-        st.session_state.thread.start()
-
     st.sidebar.title('VQA Bot')
     st.sidebar.write('''
     VQA Bot addresses the challenge of visual question answering with the chat and voice assistance.
@@ -41,7 +30,33 @@ def run():
     elif selected_page == 'VQA Audiobot':
         audiobot.show()
 
-    st.caption("Created by Madhuri Sakhare - [Github](https://github.com/msak1612/vqa_chatbot) [Linkedin](https://www.linkedin.com/in/madhuri-sakhare/)")
+    st.caption(
+        'Created by Madhuri Sakhare - [Github](https://github.com/msak1612/vqa_chatbot) [Linkedin](https://www.linkedin.com/in/madhuri-sakhare/)')
+
+
+def run_uvicorn():
+    os.system('uvicorn server:app --port 8080 --host 0.0.0.0 --workers 1')
+
+
+def start_server():
+    if not is_port_in_use(8080):
+        with st.spinner(text='Loading models...'):
+            proc = Process(target=run_uvicorn, args=(), daemon=True)
+            proc.start()
+            while not is_port_in_use(8080):
+                time.sleep(1)
+            st.success('Models are loaded.')
+
+
+def is_port_in_use(port):
+    # Find whether port is in use, via https://stackoverflow.com/questions/2470971/fast-way-to-test-if-a-port-is-in-use-using-python
+    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
+        return s.connect_ex(('0.0.0.0', port)) == 0
 
-run()
+
+if __name__ == '__main__':
+    run_st_app()
+    if 'server' not in st.session_state:
+        st.session_state['server'] = True
+        start_server()
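
Because Streamlit re-runs the script on every interaction, start_server() is guarded twice: the 'server' session-state flag keeps one session from re-entering it, and is_port_in_use() (a connect_ex probe) keeps a second session from spawning a second uvicorn once the port is bound. A quick manual smoke test against the spawned server, assuming the defaults above:

```python
# Run from a separate shell once app.py reports 'Models are loaded.'
import requests

print(requests.get('http://0.0.0.0:8080/').json())  # expected: {'app': 'Thanks for visiting!!'}
```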
 
 
audiobot.py CHANGED
@@ -1,4 +1,3 @@
-from turtle import width
 import streamlit as st
 from PIL import Image
 from bokeh.models.widgets import Button
@@ -7,28 +6,14 @@ from st_clickable_images import clickable_images
 from streamlit_bokeh_events import streamlit_bokeh_events
 from bokeh.models.widgets.buttons import Button
 import time
-from os.path import *
-from os import listdir
-import base64
-
-def update_gallery_images():
-    if 'gallery' not in st.session_state:
-        st.session_state.gallery = []
-        st.session_state.gallery_images = []
-        image_path = join(dirname(abspath(__file__)), 'images')
-        for f in listdir(image_path):
-            if f.startswith('image'):
-                with open(join(image_path, f), "rb") as image:
-                    encoded = base64.b64encode(image.read()).decode()
-                    st.session_state.gallery.append(
-                        f"data:image/jpeg;base64,{encoded}")
-                    st.session_state.gallery_images.append(join(image_path, f))
+from helper import *
 
 
 def upload_image_callback():
-    st.session_state.uploaded_image = st.session_state.uploader
+    st.session_state.uploaded_image = upload_image_to_server()
    st.session_state.input = ''
 
+
 def show():
     st.session_state.audio_answer = ''
 
@@ -63,8 +48,10 @@ def show():
         on_change=upload_image_callback, key='uploader')
 
     if st.session_state.uploaded_image is not None:
-        st.session_state.image = Image.open(st.session_state.uploaded_image)
-        st.image(st.session_state.uploaded_image, use_column_width='always')
+        st.session_state.image = Image.open(
+            st.session_state.uploaded_image)
+        st.image(st.session_state.uploaded_image,
+                 use_column_width='always')
     else:
         st.session_state.image = None
         st.session_state.input = ''
@@ -118,10 +105,8 @@ def show():
     if 'question' not in st.session_state or st.session_state.question != result.get('GET_TEXT'):
         st.session_state['question'] = result.get('GET_TEXT')
         with st.spinner('Preparing answer...'):
-            while 'predictor' not in st.session_state:
-                time.sleep(2)
-            st.session_state.audio_answer = st.session_state.predictor.predict_answer_from_text(
-                st.session_state.image, result.get('GET_TEXT'))
+            st.session_state.audio_answer = request_answer(
+                st.session_state.server_image_file, result.get('GET_TEXT'))
 
     tts_button = Button(label='Get Answer', width=100)
     tts_button.js_on_event('button_click', CustomJS(code=f"""
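
One thing worth double-checking here: the answer path reads st.session_state.server_image_file, but nothing in this changeset assigns that key; upload_image_callback stores the server-side filename in st.session_state.uploaded_image (the key chatbot.py passes to request_answer). Unless server_image_file is set elsewhere, this lookup would raise on first use.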
chatbot.py CHANGED
@@ -1,12 +1,9 @@
-import streamlit as st
 from streamlit_chat import message
 from st_clickable_images import clickable_images
 from PIL import Image
-import time
-from os.path import *
-from os import listdir
-import base64
+from helper import *
 
+import streamlit as st
 
 def init_chat_history():
     if 'question' not in st.session_state:
@@ -29,39 +26,22 @@ def predict(image, input):
     if image is None or not input:
         return
 
-    if 'predictor' not in st.session_state:
-        with st.spinner('Preparing answer...'):
-            while 'predictor' not in st.session_state:
-                time.sleep(2)
-
-    answer = st.session_state.predictor.predict_answer_from_text(image, input)
-    st.session_state.question.append(input)
-    st.session_state.answer.append(answer)
-    while len(st.session_state.question) >= 5:
-        st.session_state.answer.pop(0)
-        st.session_state.question.pop(0)
-
-
-def update_gallery_images():
-    if 'gallery' not in st.session_state:
-        st.session_state.gallery = []
-        st.session_state.gallery_images = []
-        image_path = join(dirname(abspath(__file__)), 'images')
-        for f in listdir(image_path):
-            if f.startswith('image'):
-                with open(join(image_path, f), "rb") as image:
-                    encoded = base64.b64encode(image.read()).decode()
-                    st.session_state.gallery.append(
-                        f"data:image/jpeg;base64,{encoded}")
-                    st.session_state.gallery_images.append(join(image_path, f))
+    with st.spinner('Preparing answer...'):
+        answer = request_answer(st.session_state.uploaded_image, input)
+        st.session_state.question.append(input)
+        st.session_state.answer.append(answer)
+        while len(st.session_state.question) >= 5:
+            st.session_state.answer.pop(0)
+            st.session_state.question.pop(0)
 
 
 def upload_image_callback():
-    st.session_state.uploaded_image = st.session_state.uploader
+    st.session_state.uploaded_image = upload_image_to_server()
     st.session_state.question = []
     st.session_state.answer = []
     st.session_state.input = ''
 
+
 def show():
     init_chat_history()
 
@@ -98,7 +78,8 @@ def show():
 
     if st.session_state.uploaded_image is not None:
         image = Image.open(st.session_state.uploaded_image)
-        st.image(st.session_state.uploaded_image, use_column_width='always')
+        st.image(st.session_state.uploaded_image,
+                 use_column_width='always')
     else:
        st.session_state.question = []
        st.session_state.answer = []
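
The predict() loop caps the stored history by popping the oldest question/answer pair until fewer than five remain. An equivalent shape (an alternative sketch, not what this commit uses) is a bounded deque, which drops old entries automatically:

```python
from collections import deque

# maxlen=4 mirrors the committed loop, which pops until len < 5 after each append.
questions = deque(maxlen=4)
answers = deque(maxlen=4)

questions.append('what is in the picture?')
answers.append('a cat')  # once 4 pairs are stored, the oldest falls off on append
```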
helper.py ADDED
@@ -0,0 +1,43 @@
+from os import listdir
+from os.path import *
+from PIL import Image
+from io import BytesIO
+
+import streamlit as st
+import base64
+import requests
+
+
+def update_gallery_images():
+    if 'gallery' not in st.session_state:
+        st.session_state['gallery'] = []
+        st.session_state['gallery_images'] = []
+        image_path = join(dirname(abspath(__file__)), 'images')
+        for f in listdir(image_path):
+            if f.startswith('image'):
+                with open(join(image_path, f), "rb") as image:
+                    encoded = base64.b64encode(image.read()).decode()
+                    st.session_state.gallery.append(
+                        f"data:image/jpeg;base64,{encoded}")
+                    st.session_state.gallery_images.append(join(image_path, f))
+
+
+def upload_image_to_server():
+    if st.session_state.uploader is not None:
+        image = Image.open(st.session_state.uploader)
+        byte_io = BytesIO()
+        image.save(byte_io, 'png')
+        byte_io.seek(0)
+        file = {'file': byte_io}
+        response = requests.post('http://0.0.0.0:8080/uploadfile/', files=file)
+        if response.status_code == 200:
+            return response.json()['filename']
+    return None
+
+
+def request_answer(image, question):
+    response = requests.get(
+        f'http://0.0.0.0:8080/vqa?image={image}&question={question}')
+    if response.status_code == 200:
+        return response.json()['answer']
+    return 'I do not understand. Please ask again.'
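
A robustness note on request_answer(): the question is interpolated straight into the query string, so characters like `&`, `#`, or `?` in a question would corrupt the URL. requests can URL-encode for you via `params=`; a sketch of the same call with encoding (same endpoint, same behavior otherwise):

```python
import requests

def request_answer(image, question):
    # Same GET /vqa call as helper.py, but with URL-encoded query parameters.
    response = requests.get('http://0.0.0.0:8080/vqa',
                            params={'image': image, 'question': question})
    if response.status_code == 200:
        return response.json()['answer']
    return 'I do not understand. Please ask again.'
```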
model/predictor.py CHANGED
@@ -1,4 +1,3 @@
-import streamlit as st
 from happytransformer import HappyTextToText, TTSettings
 from transformers import ViltProcessor
 from transformers import ViltForQuestionAnswering
@@ -18,7 +17,6 @@ question.
 '''
 
 
-@st.experimental_singleton
 class Predictor:
     def __init__(self):
         auth_token = os.environ.get('TOKEN') or True
requirements.txt CHANGED
@@ -3,6 +3,7 @@ aiosignal==1.2.0
 altair==4.2.0
 ansicolors==1.1.8
 ansiwrap==0.8.4
+anyio==3.6.1
 appnope==0.1.3
 argon2-cffi==21.3.0
 argon2-cffi-bindings==21.2.0
@@ -29,12 +30,14 @@ defusedxml==0.7.1
 dill==0.3.5.1
 entrypoints==0.4
 executing==0.8.3
+fastapi==0.78.0
 fastjsonschema==2.15.3
 filelock==3.7.1
 frozenlist==1.3.0
 fsspec==2022.5.0
 gitdb==4.0.9
 GitPython==3.1.27
+h11==0.13.0
 happytransformer==2.4.1
 huggingface-hub==0.7.0
 idna==3.3
@@ -74,6 +77,8 @@ parso==0.8.3
 pexpect==4.8.0
 pickleshare==0.7.5
 Pillow==9.1.1
+pox==0.3.1
+ppft==1.7.6.5
 prometheus-client==0.14.1
 prompt-toolkit==3.0.29
 protobuf==3.20.1
@@ -82,12 +87,14 @@ ptyprocess==0.7.0
 pure-eval==0.2.2
 pyarrow==8.0.0
 pycparser==2.21
+pydantic==1.9.1
 pydeck==0.7.1
 Pygments==2.12.0
 Pympler==1.0.1
 pyparsing==3.0.9
 pyrsistent==0.18.1
 python-dateutil==2.8.2
+python-multipart==0.0.5
 pytz==2022.1
 pytz-deprecation-shim==0.1.0.post0
 PyYAML==6.0
@@ -106,9 +113,11 @@ simplere==1.2.13
 six==1.12.0
 sklearn==0.0
 smmap==5.0.0
+sniffio==1.2.0
 soupsieve==2.3.2.post1
 st-clickable-images==0.0.3
 stack-data==0.3.0
+starlette==0.19.1
 streamlit==1.10.0
 streamlit-bokeh-events==0.1.2
 streamlit-chat==0.0.2.1
@@ -128,6 +137,7 @@ typing_extensions==4.2.0
 tzdata==2022.1
 tzlocal==4.2
 urllib3==1.26.9
+uvicorn==0.18.1
 validators==0.20.0
 wcwidth==0.2.5
 webencodings==0.5.1
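
The new pins are the HTTP stack added by this commit: fastapi, starlette, pydantic, and their async plumbing (anyio, sniffio, h11), uvicorn as the ASGI server, and python-multipart, which FastAPI needs to parse UploadFile form data. pox and ppft are pathos-family utilities that do not appear to be imported by the new code; they likely arrived via a pip freeze.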
server.py ADDED
@@ -0,0 +1,68 @@
+from fastapi import FastAPI, File, UploadFile
+from model import predictor
+from os import listdir
+from os.path import *
+from PIL import Image
+
+import os
+import hashlib
+import threading
+import time
+
+gpredictor = None
+app = FastAPI()
+
+@app.get('/')
+def root():
+    return {'app': 'Thanks for visiting!!'}
+
+
+@app.get('/favicon.ico', include_in_schema=False)
+@app.post('/uploadfile/')
+async def create_upload_file(file: UploadFile = File(...)):
+    contents = await file.read()
+    hash = hashlib.sha256(contents).hexdigest()
+    file.filename = f'images/upload_{hash}.jpg'
+    if not os.path.isfile(file.filename):
+        with open(file.filename, 'wb') as f:
+            f.write(contents)
+    images[file.filename] = Image.open(file.filename)
+    return {'filename': file.filename}
+
+
+@app.get('/vqa')
+async def answer(
+    image: str,
+    question: str
+):
+    if image not in images:
+        print('not in image')
+        pil_image = Image.open(image)
+        images[image] = pil_image
+    else:
+        pil_image = images[image]
+    while gpredictor is None:
+        time.sleep(1)
+    answer = gpredictor.predict_answer_from_text(pil_image, question)
+    return {'answer': answer}
+
+os.environ['TOKENIZERS_PARALLELISM'] = 'false'
+images = {}
+
+def runInThread():
+    collect_images()
+    print('Initialize model in thread')
+    global gpredictor
+    gpredictor = predictor.Predictor()
+    print('Model is initialized')
+
+
+def collect_images():
+    image_path = join(dirname(abspath(__file__)), 'images')
+    for f in listdir(image_path):
+        if f.startswith('image'):
+            full_image_path = join(image_path, f)
+            images[full_image_path] = Image.open(full_image_path)
+
+thread = threading.Thread(target=runInThread)
+thread.start()
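
Note the stacked decorators on the upload handler: `@app.get('/favicon.ico', include_in_schema=False)` routes favicon GETs to create_upload_file, which requires a multipart file and will answer 422. If the goal is only to quiet favicon 404s, a separate no-op route is the usual shape; a sketch (not part of this commit):

```python
from fastapi import FastAPI
from fastapi.responses import Response

app = FastAPI()

@app.get('/favicon.ico', include_in_schema=False)
async def favicon():
    # Empty 204 keeps browsers' favicon requests out of the error log.
    return Response(status_code=204)
```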