GeorgiosIoannouCoder commited on
Commit
c7742ac
·
verified ·
1 Parent(s): 89ada58

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +325 -0
app.py ADDED
@@ -0,0 +1,325 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #############################################################################################################################
2
+ # Filename : app.py
3
+ # Description: A Streamlit application to utilize five models back to back
4
+ # Models used:
5
+ # 1. Visual Question Answering (VQA).
6
+ # 2. Fill-Mask.
7
+ # 3. Text2text Generation.
8
+ # 4. Text Generation.
9
+ # 5. Topic.
10
+ # Author : Georgios Ioannou
11
+ #
12
+ # Copyright © 2024 by Georgios Ioannou
13
+ #############################################################################################################################
14
+
15
+ # Import libraries.
16
+
17
+ import streamlit as st # Build the GUI of the application.
18
+ import torch # Load Salesforce/blip model(s) on GPU.
19
+
20
+ from bertopic import BERTopic # Topic model inference.
21
+ from PIL import Image # Open and identify a given image file.
22
+ from transformers import (
23
+ pipeline,
24
+ BlipProcessor,
25
+ BlipForQuestionAnswering,
26
+ ) # VQA model inference.
27
+
28
+ #############################################################################################################################
29
+
30
+ # Function to apply local CSS.
31
+
32
+
33
+ def local_css(file_name):
34
+ with open(file_name) as f:
35
+ st.markdown(f"<style>{f.read()}</style>", unsafe_allow_html=True)
36
+
37
+
38
+ #############################################################################################################################
39
+
40
+ # Model 1.
41
+ # Model 1 gets input from the user.
42
+ # User -> Model 1
43
+
44
+ # Load the Visual Question Answering (VQA) model directly.
45
+ # Using transformers.
46
+
47
+
48
+ @st.cache_resource
49
+ def load_model_blip():
50
+ blip_processor_base = BlipProcessor.from_pretrained("Salesforce/blip-vqa-base")
51
+ blip_model_base = BlipForQuestionAnswering.from_pretrained(
52
+ "Salesforce/blip-vqa-base"
53
+ )
54
+
55
+ # Backup model.
56
+ # blip_processor_large = BlipProcessor.from_pretrained("Salesforce/blip-vqa-capfilt-large")
57
+ # blip_model_large = BlipForQuestionAnswering.from_pretrained("Salesforce/blip-vqa-capfilt-large")
58
+ # return blip_processor_large, blip_model_large
59
+
60
+ return blip_processor_base, blip_model_base
61
+
62
+
63
+ # General function for any Salesforce/blip model(s).
64
+ # VQA model.
65
+
66
+
67
+ def generate_answer_blip(processor, model, image, question):
68
+ # Prepare image + question.
69
+
70
+ inputs = processor(images=image, text=question, return_tensors="pt")
71
+
72
+ generated_ids = model.generate(**inputs, max_length=50)
73
+
74
+ generated_answer = processor.batch_decode(generated_ids, skip_special_tokens=True)
75
+
76
+ return generated_answer
77
+
78
+
79
+ # Generate answer from the Salesforce/blip model(s).
80
+ # VQA model.
81
+
82
+
83
+ @st.cache_resource
84
+ def generate_answer(image, question):
85
+ answer_blip_base = generate_answer_blip(
86
+ processor=blip_processor_base,
87
+ model=blip_model_base,
88
+ image=image,
89
+ question=question,
90
+ )
91
+
92
+ # answer_blip_large = generate_answer_blip(blip_processor_large, blip_model_large, image, question)
93
+ # return answer_blip_large
94
+
95
+ return answer_blip_base
96
+
97
+
98
+ #############################################################################################################################
99
+
100
+ # Model 2.
101
+ # Model 2 gets input from Model 1.
102
+ # User -> Model 1 -> Model 2
103
+
104
+
105
+ @st.cache_resource
106
+ def load_model_fill_mask():
107
+ return pipeline(task="fill-mask", model="bert-base-uncased")
108
+
109
+
110
+ #############################################################################################################################
111
+
112
+ # Model 3.
113
+ # Model 3 gets input from Model 2.
114
+ # User -> Model 1 -> Model 2 -> Model 3
115
+
116
+
117
+ @st.cache_resource
118
+ def load_model_text2text_generation():
119
+ return pipeline(
120
+ task="text2text-generation", model="facebook/blenderbot-400M-distill"
121
+ )
122
+
123
+
124
+ #############################################################################################################################
125
+
126
+ # Model 4.
127
+ # Model 4 gets input from Model 3.
128
+ # User -> Model 1 -> Model 2 -> Model 3 -> Model 4
129
+
130
+
131
+ @st.cache_resource
132
+ def load_model_fill_text_generation():
133
+ return pipeline(task="text-generation", model="gpt2")
134
+
135
+
136
+ #############################################################################################################################
137
+
138
+ # Model 5.
139
+ # Model 5 gets input from Model 4.
140
+ # User -> Model 1 -> Model 2 -> Model 3 -> Model 4 -> Model 5
141
+
142
+
143
+ @st.cache_resource
144
+ def load_model_bertopic1():
145
+ return BERTopic.load(path="davanstrien/chat_topics")
146
+
147
+
148
+ @st.cache_resource
149
+ def load_model_bertopic2():
150
+ return BERTopic.load(path="MaartenGr/BERTopic_ArXiv")
151
+
152
+
153
+ #############################################################################################################################
154
+ # Page title and favicon.
155
+
156
+ st.set_page_config(page_title="Visual Question Answering", page_icon="❓")
157
+
158
+ #############################################################################################################################
159
+
160
+ # Load the Salesforce/blip model directly.
161
+
162
+ if torch.cuda.is_available():
163
+ device = torch.device("cuda")
164
+ # elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
165
+ # device = torch.device("mps")
166
+ else:
167
+ device = torch.device("cpu")
168
+
169
+ blip_processor_base, blip_model_base = load_model_blip()
170
+ blip_model_base.to(device)
171
+
172
+ #############################################################################################################################
173
+ # Main function to create the Streamlit web application.
174
+ #
175
+ # 5 MODEL INFERENCES.
176
+ # User Input = Image + Question About The Image.
177
+ # User -> Model 1 -> Model 2 -> Model 3 -> Model 4 -> Model 5
178
+
179
+
180
+ def main():
181
+ try:
182
+ #####################################################################################################################
183
+
184
+ # Load CSS.
185
+
186
+ local_css("styles/style.css")
187
+
188
+ #####################################################################################################################
189
+
190
+ # Title.
191
+
192
+ title = f"""<h1 align="center" style="font-family: monospace; font-size: 2.1rem; margin-top: -4rem">
193
+ Georgios Ioannou's Visual Question Answering</h1>"""
194
+ st.markdown(title, unsafe_allow_html=True)
195
+ # st.title("ChefBot - Automated Recipe Assistant")
196
+
197
+ #####################################################################################################################
198
+
199
+ # Subtitle.
200
+
201
+ subtitle = f"""<h2 align="center" style="font-family: monospace; font-size: 1.5rem; margin-top: -2rem">
202
+ CUNY Tech Prep Tutorial 4</h2>"""
203
+ st.markdown(subtitle, unsafe_allow_html=True)
204
+
205
+ #####################################################################################################################
206
+
207
+ # Image.
208
+
209
+ image = "./ctp.png"
210
+ left_co, cent_co, last_co = st.columns(3)
211
+ with cent_co:
212
+ st.image(image=image)
213
+
214
+ #####################################################################################################################
215
+
216
+ # User input (Image).
217
+ image = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
218
+
219
+ if image is not None:
220
+ bytes_data = image.getvalue()
221
+
222
+ with open(image.name, "wb") as file:
223
+
224
+ file.write(bytes_data)
225
+ st.image(image, caption="Uploaded Image.", use_column_width=True)
226
+ raw_image = Image.open(image.name).convert("RGB")
227
+
228
+ # User input (Question).
229
+ question = st.text_input("What's your question?")
230
+
231
+ #############################################################################################################
232
+
233
+ if question != "":
234
+ # Model 1.
235
+ with st.spinner(
236
+ text="VQA inference..."
237
+ ): # Spinner to keep the application interactive.
238
+ # Model inference.
239
+
240
+ answer = generate_answer(raw_image, question)[0]
241
+ st.success(f"VQA: {answer}")
242
+
243
+ bbu_pipeline = load_model_fill_mask()
244
+ text = (
245
+ "I love " + answer + " and I would like to know how to [MASK]."
246
+ )
247
+
248
+ #########################################################################################################
249
+
250
+ # Model 2.
251
+ with st.spinner(
252
+ text="Fill-Mask inference..."
253
+ ): # Spinner to keep the application interactive.
254
+ # Model inference.
255
+ bbu_pipeline_output = bbu_pipeline(text)
256
+ bbu_output = bbu_pipeline_output[0]["sequence"]
257
+ st.success(f"Fill-Mask: {bbu_output}")
258
+
259
+ facebook_pipeline = load_model_text2text_generation()
260
+ utterance = bbu_output
261
+
262
+ #########################################################################################################
263
+
264
+ # Model 3.
265
+ with st.spinner(
266
+ text="Text2text Generation inference..."
267
+ ): # Spinner to keep the application interactive.
268
+ # Model inference.
269
+ facebook_pipeline_output = facebook_pipeline(utterance)
270
+ facebook_output = facebook_pipeline_output[0]["generated_text"]
271
+ st.success(f"Text2text Generation: {facebook_output}")
272
+
273
+ gpt2_pipeline = load_model_fill_text_generation()
274
+
275
+ #########################################################################################################
276
+
277
+ # Model 4.
278
+ with st.spinner(
279
+ text="Fill Text Generation inference..."
280
+ ): # Spinner to keep the application interactive.
281
+ # Model inference.
282
+ gpt2_pipeline_output = gpt2_pipeline(facebook_output)
283
+ gpt2_output = gpt2_pipeline_output[0]["generated_text"]
284
+ st.success(f"Fill Text Generation: {gpt2_output}")
285
+
286
+ #########################################################################################################
287
+
288
+ # Model 5.
289
+ topic_model_1 = load_model_bertopic1()
290
+ topic, prob = topic_model_1.transform(gpt2_pipeline_output)
291
+ topic_model_1_output = topic_model_1.get_topic_info(topic[0])[
292
+ "Representation"
293
+ ][0]
294
+ st.success(
295
+ f"Topic(s) from davanstrien/chat_topics: {topic_model_1_output}"
296
+ )
297
+
298
+ topic_model_2 = load_model_bertopic2()
299
+ topic, prob = topic_model_2.transform(gpt2_pipeline_output)
300
+ topic_model_2_output = topic_model_2.get_topic_info(topic[0])[
301
+ "Representation"
302
+ ][0]
303
+ st.success(
304
+ f"Topic(s) from MaartenGr/BERTopic_ArXiv: {topic_model_1_output}"
305
+ )
306
+ except Exception as e:
307
+ # General exception/error handling.
308
+
309
+ st.error(e)
310
+
311
+ # GitHub repository of author.
312
+
313
+ st.markdown(
314
+ f"""
315
+ <p align="center" style="font-family: monospace; color: #FAF9F6; font-size: 1rem;"><b> Check out our
316
+ <a href="https://github.com/GeorgiosIoannouCoder/" style="color: #FAF9F6;"> GitHub repository</a></b>
317
+ </p>
318
+ """,
319
+ unsafe_allow_html=True,
320
+ )
321
+
322
+
323
+ #############################################################################################################################
324
+ if __name__ == "__main__":
325
+ main()