Dongxu Li committed on
Commit 81cf2fa
Parent(s): f7f5be8

finish adding opt for captioning.

Files changed (1): app.py (+77, -17)
app.py CHANGED
@@ -14,7 +14,7 @@ def encode_image(image):
     return buffered
 
 
-def query_api(
+def query_chat_api(
     image, prompt, decoding_method, temperature, len_penalty, repetition_penalty
 ):
 
@@ -41,6 +41,34 @@ def query_api(
     return "Error: " + response.text
 
 
+def query_caption_api(
+    image, decoding_method, temperature, len_penalty, repetition_penalty
+):
+
+    url = endpoint.url
+    # replace /generate with /caption
+    url = url.replace("/generate", "/caption")
+
+    headers = {"User-Agent": "BLIP-2 HuggingFace Space"}
+
+    data = {
+        "use_nucleus_sampling": decoding_method == "Nucleus sampling",
+        "temperature": temperature,
+        "length_penalty": len_penalty,
+        "repetition_penalty": repetition_penalty,
+    }
+
+    image = encode_image(image)
+    files = {"image": image}
+
+    response = requests.post(url, data=data, files=files, headers=headers)
+
+    if response.status_code == 200:
+        return response.json()
+    else:
+        return "Error: " + response.text
+
+
 def postprocess_output(output):
     # if last character is not a punctuation, add a full stop
     if not output[0][-1] in string.punctuation:
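Note: the added query_caption_api mirrors query_chat_api but posts to a /caption route and sends no prompt. A minimal standalone sketch of the same request follows; the URL is a placeholder (app.py resolves it from an endpoint object outside this diff), and the JPEG encoding is an assumption about what encode_image does.

# Sketch only: CAPTION_URL is hypothetical; app.py derives the real URL
# from endpoint.url by replacing /generate with /caption.
from io import BytesIO

import requests
from PIL import Image

CAPTION_URL = "https://example.com/caption"  # placeholder endpoint

buffered = BytesIO()
Image.open("photo.jpg").save(buffered, format="JPEG")  # assumed encode_image behavior
buffered.seek(0)

data = {
    "use_nucleus_sampling": False,  # "Beam search" in the UI
    "temperature": 1.0,
    "length_penalty": 1.0,
    "repetition_penalty": 1.5,
}
response = requests.post(
    CAPTION_URL,
    data=data,
    files={"image": buffered},
    headers={"User-Agent": "BLIP-2 HuggingFace Space"},
)
print(response.json())  # app.py treats the JSON reply as a list of captions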
@@ -49,7 +77,7 @@ def postprocess_output(output):
     return output
 
 
-def inference(
+def inference_chat(
     image,
     text_input,
     decoding_method,
@@ -64,7 +92,7 @@ def inference(
     prompt = " ".join(history)
     print(prompt)
 
-    output = query_api(
+    output = query_chat_api(
         image, prompt, decoding_method, temperature, length_penalty, repetition_penalty
     )
     output = postprocess_output(output)
@@ -77,6 +105,20 @@ def inference(
     return {chatbot: chat, state: history}
 
 
+def inference_caption(
+    image,
+    decoding_method,
+    temperature,
+    length_penalty,
+    repetition_penalty,
+):
+    output = query_caption_api(
+        image, decoding_method, temperature, length_penalty, repetition_penalty
+    )
+
+    return output[0]
+
+
 title = """<h1 align="center">BLIP-2</h1>"""
 description = """Gradio demo for BLIP-2, a multimodal chatbot from Salesforce Research. To use it, simply upload your image, or click one of the examples to load them. Please visit our <a href='https://github.com/salesforce/LAVIS/tree/main/projects/blip2' target='_blank'>project webpage</a>.</p>
 <p> <strong>Disclaimer</strong>: This is a research prototype and is not intended for production use. No data including but not restricted to text and images is collected. </p>"""
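Unlike inference_chat, the new inference_caption keeps no conversation state and simply unwraps the first element of the server's JSON reply. A hedged usage sketch, assuming the caption endpoint is reachable; the image path and parameter values are illustrative:

# Calling inference_caption directly, outside the Gradio UI.
from PIL import Image

caption = inference_caption(
    Image.open("photo.jpg"),  # illustrative path
    "Beam search",            # decoding_method
    1.0,                      # temperature
    1.0,                      # length_penalty
    1.5,                      # repetition_penalty
)
print(caption)  # first element of the list returned by /caption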
@@ -101,16 +143,15 @@ with gr.Blocks() as iface:
     with gr.Row():
         with gr.Column():
             image_input = gr.Image(type="pil")
-            text_input = gr.Textbox(lines=2, label="Text input")
-
-            sampling = gr.Radio(
-                choices=["Beam search", "Nucleus sampling"],
-                value="Beam search",
-                label="Text Decoding Method",
-                interactive=True,
-            )
 
             with gr.Row():
+                sampling = gr.Radio(
+                    choices=["Beam search", "Nucleus sampling"],
+                    value="Beam search",
+                    label="Text Decoding Method",
+                    interactive=True,
+                )
+
                 temperature = gr.Slider(
                     minimum=0.5,
                     maximum=1.0,
@@ -134,13 +175,32 @@ with gr.Blocks() as iface:
                     value=10.0,
                     step=0.5,
                     interactive=True,
-                    label="Repetition Penalty",
+                    label="Repeat Penalty",
+                )
+
+            with gr.Row():
+                caption_output = gr.Textbox(lines=2, label="Caption Output")
+                caption_button = gr.Button(
+                    value="Caption it!", interactive=True, variant="primary"
+                )
+                caption_button.click(
+                    inference_caption,
+                    [
+                        image_input,
+                        sampling,
+                        temperature,
+                        len_penalty,
+                        rep_penalty,
+                    ],
+                    [caption_output],
                 )
 
         with gr.Column():
+            chat_input = gr.Textbox(lines=2, label="Chat Input")
+
             with gr.Row():
                 chatbot = gr.Chatbot()
-                image_input.change(lambda: (None, []), [], [chatbot, state])
+                image_input.change(lambda: (None, "", "", []), [], [chatbot, chat_input, caption_output, state])
 
             with gr.Row():
@@ -148,17 +208,17 @@ with gr.Blocks() as iface:
                 clear_button.click(
                     lambda: ("", None, [], []),
                     [],
-                    [text_input, image_input, chatbot, state],
+                    [chat_input, image_input, chatbot, state],
                 )
 
                 submit_button = gr.Button(
                     value="Submit", interactive=True, variant="primary"
                 )
                 submit_button.click(
-                    inference,
+                    inference_chat,
                     [
                         image_input,
-                        text_input,
+                        chat_input,
                         sampling,
                         temperature,
                         len_penalty,
@@ -170,7 +230,7 @@ with gr.Blocks() as iface:
 
     examples = gr.Examples(
         examples=examples,
-        inputs=[image_input, text_input],
+        inputs=[image_input, chat_input],
     )
 
 iface.queue(concurrency_count=1, api_open=False, max_size=20)
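Taken together, the caption path can run as a stripped-down demo of its own. A self-contained sketch, assuming inference_caption is importable from app.py; the slider ranges and defaults are illustrative, not taken from this diff:

import gradio as gr

from app import inference_caption  # assumption: app.py exposes this

with gr.Blocks() as demo:
    image_input = gr.Image(type="pil")
    sampling = gr.Radio(
        choices=["Beam search", "Nucleus sampling"],
        value="Beam search",
        label="Text Decoding Method",
    )
    temperature = gr.Slider(0.5, 1.0, value=1.0, label="Temperature")
    len_penalty = gr.Slider(-1.0, 2.0, value=1.0, label="Length Penalty")
    rep_penalty = gr.Slider(1.0, 10.0, value=1.5, label="Repeat Penalty")
    caption_output = gr.Textbox(lines=2, label="Caption Output")
    caption_button = gr.Button(value="Caption it!", variant="primary")
    caption_button.click(
        inference_caption,
        [image_input, sampling, temperature, len_penalty, rep_penalty],
        [caption_output],
    )

# mirrors app.py's queue settings (gradio 3.x API)
demo.queue(concurrency_count=1, api_open=False, max_size=20)
demo.launch()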
 