ki1207 committed
Commit 32c980d
1 Parent(s): 615c29d

Create app.py

Files changed (1)
  1. app.py +322 -0
app.py ADDED
@@ -0,0 +1,322 @@
+ #!/usr/bin/env python
+
+ from __future__ import annotations
+
+ import os
+ import string
+
+ import gradio as gr
+ import PIL.Image
+ import spaces
+ import torch
+ from transformers import InstructBlipProcessor, InstructBlipForConditionalGeneration
+
+ DESCRIPTION = "# [BLIP-2](https://github.com/salesforce/LAVIS/tree/main/projects/blip2)"
+
+ if not torch.cuda.is_available():
+     DESCRIPTION += "\n<p>Running on CPU.</p>"
+
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+
+ MODEL_ID = "Salesforce/instructblip-flan-t5-xl"
+
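+ # Load the InstructBLIP processor and the model in 8-bit precision
+ # (load_in_8bit requires bitsandbytes; device_map="auto" requires accelerate).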
+ processor = InstructBlipProcessor.from_pretrained(MODEL_ID)
+ model = InstructBlipForConditionalGeneration.from_pretrained(MODEL_ID, device_map="auto", load_in_8bit=True)
+
+
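+ # Generate a caption for the input image using the selected decoding settings.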
+ @spaces.GPU
+ def generate_caption(
+     image: PIL.Image.Image,
+     decoding_method: str = "Nucleus sampling",
+     temperature: float = 1.0,
+     length_penalty: float = 1.0,
+     repetition_penalty: float = 1.5,
+     max_length: int = 50,
+     min_length: int = 1,
+     num_beams: int = 5,
+     top_p: float = 0.9,
+ ) -> str:
+     inputs = processor(images=image, return_tensors="pt").to(device, torch.float16)
+     generated_ids = model.generate(
+         pixel_values=inputs.pixel_values,
+         do_sample=decoding_method == "Nucleus sampling",
+         temperature=temperature,
+         length_penalty=length_penalty,
+         repetition_penalty=repetition_penalty,
+         max_length=max_length,
+         min_length=min_length,
+         num_beams=num_beams,
+         top_p=top_p,
+     )
+     result = processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
+     return result
+
+
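+ # Answer a free-form text prompt about the image; also used by the chat handler below.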
+ @spaces.GPU
+ def answer_question(
+     image: PIL.Image.Image,
+     prompt: str,
+     decoding_method: str = "Nucleus sampling",
+     temperature: float = 1.0,
+     length_penalty: float = 1.0,
+     repetition_penalty: float = 1.5,
+     max_length: int = 50,
+     min_length: int = 1,
+     num_beams: int = 5,
+     top_p: float = 0.9,
+ ) -> str:
+     inputs = processor(images=image, text=prompt, return_tensors="pt").to(device, torch.float16)
+     generated_ids = model.generate(
+         **inputs,
+         do_sample=decoding_method == "Nucleus sampling",
+         temperature=temperature,
+         length_penalty=length_penalty,
+         repetition_penalty=repetition_penalty,
+         max_length=max_length,
+         min_length=min_length,
+         num_beams=num_beams,
+         top_p=top_p,
+     )
+     result = processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
+     return result
+
+
+ def postprocess_output(output: str) -> str:
+     if output and output[-1] not in string.punctuation:
+         output += "."
+     return output
+
+
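+ # Maintain two parallel histories: the raw dialogue shown in the Chatbot,
+ # and a "Question: ... Answer:" form that is joined into the model prompt.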
+ def chat(
+     image: PIL.Image.Image,
+     text: str,
+     decoding_method: str = "Nucleus sampling",
+     temperature: float = 1.0,
+     length_penalty: float = 1.0,
+     repetition_penalty: float = 1.5,
+     max_length: int = 50,
+     min_length: int = 1,
+     num_beams: int = 5,
+     top_p: float = 0.9,
+     history_orig: list[str] = [],
+     history_qa: list[str] = [],
+ ) -> tuple[list[tuple[str, str]], list[str], list[str]]:
+     history_orig.append(text)
+     text_qa = f"Question: {text} Answer:"
+     history_qa.append(text_qa)
+     prompt = " ".join(history_qa)
+
+     output = answer_question(
+         image=image,
+         prompt=prompt,
+         decoding_method=decoding_method,
+         temperature=temperature,
+         length_penalty=length_penalty,
+         repetition_penalty=repetition_penalty,
+         max_length=max_length,
+         min_length=min_length,
+         num_beams=num_beams,
+         top_p=top_p,
+     )
+     output = postprocess_output(output)
+     history_orig.append(output)
+     history_qa.append(output)
+
+     chat_val = list(zip(history_orig[0::2], history_orig[1::2]))
+     return chat_val, history_orig, history_qa
+
+
+ examples = [
+     [
+         "images/house.png",
+         "How could someone get out of the house?",
+     ],
+     [
+         "images/flower.jpg",
+         "What is this flower and where is its origin?",
+     ],
+     [
+         "images/pizza.jpg",
+         "What are the steps to cook it?",
+     ],
+     [
+         "images/sunset.jpg",
+         "Here is a romantic message to go along with the photo:",
+     ],
+     [
+         "images/forbidden_city.webp",
+         "In what dynasties was this place built?",
+     ],
+ ]
+
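+ # Build the Gradio UI: one image input with captioning and VQA chat tabs, plus advanced generation settings.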
+ with gr.Blocks() as demo:
+     gr.Markdown(DESCRIPTION)
+
+     with gr.Group():
+         image = gr.Image(type="pil")
+         with gr.Tabs():
+             with gr.Tab(label="Image Captioning"):
+                 caption_button = gr.Button("Caption it!")
+                 caption_output = gr.Textbox(label="Caption Output", show_label=False, container=False)
+             with gr.Tab(label="Visual Question Answering"):
+                 chatbot = gr.Chatbot(label="VQA Chat", show_label=False)
+                 history_orig = gr.State(value=[])
+                 history_qa = gr.State(value=[])
+                 vqa_input = gr.Text(label="Chat Input", show_label=False, max_lines=1, container=False)
+                 with gr.Row():
+                     clear_chat_button = gr.Button("Clear")
+                     chat_button = gr.Button("Submit", variant="primary")
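+     # Generation hyperparameters shared by captioning and the VQA chat.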
+     with gr.Accordion(label="Advanced settings", open=False):
+         text_decoding_method = gr.Radio(
+             label="Text Decoding Method",
+             choices=["Beam search", "Nucleus sampling"],
+             value="Nucleus sampling",
+         )
+         temperature = gr.Slider(
+             label="Temperature",
+             info="Used with nucleus sampling.",
+             minimum=0.5,
+             maximum=1.0,
+             step=0.1,
+             value=1.0,
+         )
+         length_penalty = gr.Slider(
+             label="Length Penalty",
+             info="Set to a larger value for longer sequences; used with beam search.",
+             minimum=-1.0,
+             maximum=2.0,
+             step=0.2,
+             value=1.0,
+         )
+         repetition_penalty = gr.Slider(
+             label="Repetition Penalty",
+             info="A larger value discourages repetition.",
+             minimum=1.0,
+             maximum=5.0,
+             step=0.5,
+             value=1.5,
+         )
+         max_length = gr.Slider(
+             label="Max Length",
+             minimum=20,
+             maximum=512,
+             step=1,
+             value=50,
+         )
+         min_length = gr.Slider(
+             label="Minimum Length",
+             minimum=1,
+             maximum=100,
+             step=1,
+             value=1,
+         )
+         num_beams = gr.Slider(
+             label="Number of Beams",
+             minimum=1,
+             maximum=10,
+             step=1,
+             value=5,
+         )
+         top_p = gr.Slider(
+             label="Top P",
+             info="Used with nucleus sampling.",
+             minimum=0.5,
+             maximum=1.0,
+             step=0.1,
+             value=0.9,
+         )
+
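+     # Example image/question pairs; selecting one fills the image and chat input fields.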
+     gr.Examples(
+         examples=examples,
+         inputs=[image, vqa_input],
+         outputs=caption_output,
+         fn=generate_caption,
+     )
+
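+     # "Caption it!" runs generate_caption with the current advanced settings.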
+     caption_button.click(
+         fn=generate_caption,
+         inputs=[
+             image,
+             text_decoding_method,
+             temperature,
+             length_penalty,
+             repetition_penalty,
+             max_length,
+             min_length,
+             num_beams,
+             top_p,
+         ],
+         outputs=caption_output,
+         api_name="caption",
+     )
+
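+     # Submitting the textbox or clicking Submit runs the same chat handler, then clears the input on success.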
+     chat_inputs = [
+         image,
+         vqa_input,
+         text_decoding_method,
+         temperature,
+         length_penalty,
+         repetition_penalty,
+         max_length,
+         min_length,
+         num_beams,
+         top_p,
+         history_orig,
+         history_qa,
+     ]
+     chat_outputs = [
+         chatbot,
+         history_orig,
+         history_qa,
+     ]
+     vqa_input.submit(
+         fn=chat,
+         inputs=chat_inputs,
+         outputs=chat_outputs,
+     ).success(
+         fn=lambda: "",
+         outputs=vqa_input,
+         queue=False,
+         api_name=False,
+     )
+     chat_button.click(
+         fn=chat,
+         inputs=chat_inputs,
+         outputs=chat_outputs,
+         api_name="chat",
+     ).success(
+         fn=lambda: "",
+         outputs=vqa_input,
+         queue=False,
+         api_name=False,
+     )
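+     # Clearing the chat or changing the image resets the relevant outputs and both history states.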
+     clear_chat_button.click(
+         fn=lambda: ("", [], [], []),
+         inputs=None,
+         outputs=[
+             vqa_input,
+             chatbot,
+             history_orig,
+             history_qa,
+         ],
+         queue=False,
+         api_name="clear",
+     )
+     image.change(
+         fn=lambda: ("", [], [], []),
+         inputs=None,
+         outputs=[
+             caption_output,
+             chatbot,
+             history_orig,
+             history_qa,
+         ],
+         queue=False,
+     )
+
+ if __name__ == "__main__":
+     demo.queue(max_size=10).launch()