Expose function

#5
by merve HF staff - opened
Files changed (1) hide show
  1. app.py +137 -143
app.py CHANGED
@@ -163,152 +163,146 @@ def draw_entity_boxes_on_image(image, entities, show=False, save_path=None, enti
163
  return pil_image
164
 
165
 
166
- def main():
167
 
168
- ckpt = "microsoft/kosmos-2-patch14-224"
169
 
170
- model = AutoModelForVision2Seq.from_pretrained(ckpt).to("cuda")
171
- processor = AutoProcessor.from_pretrained(ckpt)
172
 
173
- def generate_predictions(image_input, text_input):
174
 
175
- # Save the image and load it again to match the original Kosmos-2 demo.
176
- # (https://github.com/microsoft/unilm/blob/f4695ed0244a275201fff00bee495f76670fbe70/kosmos-2/demo/gradio_app.py#L345-L346)
177
- user_image_path = "/tmp/user_input_test_image.jpg"
178
- image_input.save(user_image_path)
179
- # This might give different results from the original argument `image_input`
180
- image_input = Image.open(user_image_path)
181
 
182
- if text_input == "Brief":
183
- text_input = "<grounding>An image of"
184
- elif text_input == "Detailed":
185
- text_input = "<grounding>Describe this image in detail:"
186
- else:
187
- text_input = f"<grounding>{text_input}"
188
-
189
- inputs = processor(text=text_input, images=image_input, return_tensors="pt").to("cuda")
190
-
191
- generated_ids = model.generate(
192
- pixel_values=inputs["pixel_values"],
193
- input_ids=inputs["input_ids"],
194
- attention_mask=inputs["attention_mask"],
195
- image_embeds=None,
196
- image_embeds_position_mask=inputs["image_embeds_position_mask"],
197
- use_cache=True,
198
- max_new_tokens=128,
199
- )
200
-
201
- generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
202
-
203
- # By default, the generated text is cleanup and the entities are extracted.
204
- processed_text, entities = processor.post_process_generation(generated_text)
205
-
206
- annotated_image = draw_entity_boxes_on_image(image_input, entities, show=False)
207
-
208
- color_id = -1
209
- entity_info = []
210
- filtered_entities = []
211
- for entity in entities:
212
- entity_name, (start, end), bboxes = entity
213
- if start == end:
214
- # skip bounding bbox without a `phrase` associated
215
- continue
216
- color_id += 1
217
- # for bbox_id, _ in enumerate(bboxes):
218
- # if start is None and bbox_id > 0:
219
- # color_id += 1
220
- entity_info.append(((start, end), color_id))
221
- filtered_entities.append(entity)
222
-
223
- colored_text = []
224
- prev_start = 0
225
- end = 0
226
- for idx, ((start, end), color_id) in enumerate(entity_info):
227
- if start > prev_start:
228
- colored_text.append((processed_text[prev_start:start], None))
229
- colored_text.append((processed_text[start:end], f"{color_id}"))
230
- prev_start = end
231
-
232
- if end < len(processed_text):
233
- colored_text.append((processed_text[end:len(processed_text)], None))
234
-
235
- return annotated_image, colored_text, str(filtered_entities)
236
-
237
- term_of_use = """
238
- ### Terms of use
239
- By using this model, users are required to agree to the following terms:
240
- The model is intended for academic and research purposes.
241
- The utilization of the model to create unsuitable material is strictly forbidden and not endorsed by this work.
242
- The accountability for any improper or unacceptable application of the model rests exclusively with the individuals who generated such content.
243
-
244
- ### License
245
- This project is licensed under the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct).
246
- """
247
 
248
- with gr.Blocks(title="Kosmos-2", theme=gr.themes.Base()).queue() as demo:
249
- gr.Markdown(("""
250
- # Kosmos-2: Grounding Multimodal Large Language Models to the World
251
- [[Paper]](https://arxiv.org/abs/2306.14824) [[Code]](https://github.com/microsoft/unilm/blob/master/kosmos-2)
252
- """))
253
- with gr.Row():
254
- with gr.Column():
255
- image_input = gr.Image(type="pil", label="Test Image")
256
- text_input = gr.Radio(["Brief", "Detailed"], label="Description Type", value="Brief")
257
-
258
- run_button = gr.Button(label="Run", visible=True)
259
-
260
- with gr.Column():
261
- image_output = gr.Image(type="pil")
262
- text_output1 = gr.HighlightedText(
263
- label="Generated Description",
264
- combine_adjacent=False,
265
- show_legend=True,
266
- ).style(color_map=color_map)
267
-
268
- with gr.Row():
269
- with gr.Column():
270
- gr.Examples(examples=[
271
- ["images/two_dogs.jpg", "Detailed"],
272
- ["images/snowman.png", "Brief"],
273
- ["images/man_ball.png", "Detailed"],
274
- ], inputs=[image_input, text_input])
275
- with gr.Column():
276
- gr.Examples(examples=[
277
- ["images/six_planes.png", "Brief"],
278
- ["images/quadrocopter.jpg", "Brief"],
279
- ["images/carnaby_street.jpg", "Brief"],
280
- ], inputs=[image_input, text_input])
281
- gr.Markdown(term_of_use)
282
-
283
- # record which text span (label) is selected
284
- selected = gr.Number(-1, show_label=False, placeholder="Selected", visible=False)
285
-
286
- # record the current `entities`
287
- entity_output = gr.Textbox(visible=False)
288
-
289
- # get the current selected span label
290
- def get_text_span_label(evt: gr.SelectData):
291
- if evt.value[-1] is None:
292
- return -1
293
- return int(evt.value[-1])
294
- # and set this information to `selected`
295
- text_output1.select(get_text_span_label, None, selected)
296
-
297
- # update output image when we change the span (enity) selection
298
- def update_output_image(img_input, image_output, entities, idx):
299
- entities = ast.literal_eval(entities)
300
- updated_image = draw_entity_boxes_on_image(img_input, entities, entity_index=idx)
301
- return updated_image
302
- selected.change(update_output_image, [image_input, image_output, entity_output, selected], [image_output])
303
-
304
- run_button.click(fn=generate_predictions,
305
- inputs=[image_input, text_input],
306
- outputs=[image_output, text_output1, entity_output],
307
- show_progress=True, queue=True)
308
-
309
- demo.launch(share=False)
310
-
311
-
312
- if __name__ == "__main__":
313
- main()
314
- # trigger
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
163
  return pil_image
164
 
165
 
 
166
 
167
+ ckpt = "microsoft/kosmos-2-patch14-224"
168
 
169
+ model = AutoModelForVision2Seq.from_pretrained(ckpt).to("cuda")
170
+ processor = AutoProcessor.from_pretrained(ckpt)
171
 
172
+ def generate_predictions(image_input, text_input):
173
 
174
+ # Save the image and load it again to match the original Kosmos-2 demo.
175
+ # (https://github.com/microsoft/unilm/blob/f4695ed0244a275201fff00bee495f76670fbe70/kosmos-2/demo/gradio_app.py#L345-L346)
176
+ user_image_path = "/tmp/user_input_test_image.jpg"
177
+ image_input.save(user_image_path)
178
+ # This might give different results from the original argument `image_input`
179
+ image_input = Image.open(user_image_path)
180
 
181
+ if text_input == "Brief":
182
+ text_input = "<grounding>An image of"
183
+ elif text_input == "Detailed":
184
+ text_input = "<grounding>Describe this image in detail:"
185
+ else:
186
+ text_input = f"<grounding>{text_input}"
187
+
188
+ inputs = processor(text=text_input, images=image_input, return_tensors="pt").to("cuda")
189
+
190
+ generated_ids = model.generate(
191
+ pixel_values=inputs["pixel_values"],
192
+ input_ids=inputs["input_ids"],
193
+ attention_mask=inputs["attention_mask"],
194
+ image_embeds=None,
195
+ image_embeds_position_mask=inputs["image_embeds_position_mask"],
196
+ use_cache=True,
197
+ max_new_tokens=128,
198
+ )
199
+
200
+ generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
201
+
202
+ # By default, the generated text is cleanup and the entities are extracted.
203
+ processed_text, entities = processor.post_process_generation(generated_text)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
204
 
205
+ annotated_image = draw_entity_boxes_on_image(image_input, entities, show=False)
206
+
207
+ color_id = -1
208
+ entity_info = []
209
+ filtered_entities = []
210
+ for entity in entities:
211
+ entity_name, (start, end), bboxes = entity
212
+ if start == end:
213
+ # skip bounding bbox without a `phrase` associated
214
+ continue
215
+ color_id += 1
216
+ # for bbox_id, _ in enumerate(bboxes):
217
+ # if start is None and bbox_id > 0:
218
+ # color_id += 1
219
+ entity_info.append(((start, end), color_id))
220
+ filtered_entities.append(entity)
221
+
222
+ colored_text = []
223
+ prev_start = 0
224
+ end = 0
225
+ for idx, ((start, end), color_id) in enumerate(entity_info):
226
+ if start > prev_start:
227
+ colored_text.append((processed_text[prev_start:start], None))
228
+ colored_text.append((processed_text[start:end], f"{color_id}"))
229
+ prev_start = end
230
+
231
+ if end < len(processed_text):
232
+ colored_text.append((processed_text[end:len(processed_text)], None))
233
+
234
+ return annotated_image, colored_text, str(filtered_entities)
235
+
236
+ term_of_use = """
237
+ ### Terms of use
238
+ By using this model, users are required to agree to the following terms:
239
+ The model is intended for academic and research purposes.
240
+ The utilization of the model to create unsuitable material is strictly forbidden and not endorsed by this work.
241
+ The accountability for any improper or unacceptable application of the model rests exclusively with the individuals who generated such content.
242
+
243
+ ### License
244
+ This project is licensed under the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct).
245
+ """
246
+
247
+ with gr.Blocks(title="Kosmos-2", theme=gr.themes.Base()).queue() as demo:
248
+ gr.Markdown(("""
249
+ # Kosmos-2: Grounding Multimodal Large Language Models to the World
250
+ [[Paper]](https://arxiv.org/abs/2306.14824) [[Code]](https://github.com/microsoft/unilm/blob/master/kosmos-2)
251
+ """))
252
+ with gr.Row():
253
+ with gr.Column():
254
+ image_input = gr.Image(type="pil", label="Test Image")
255
+ text_input = gr.Radio(["Brief", "Detailed"], label="Description Type", value="Brief")
256
+
257
+ run_button = gr.Button(label="Run", visible=True)
258
+
259
+ with gr.Column():
260
+ image_output = gr.Image(type="pil")
261
+ text_output1 = gr.HighlightedText(
262
+ label="Generated Description",
263
+ combine_adjacent=False,
264
+ show_legend=True,
265
+ ).style(color_map=color_map)
266
+
267
+ with gr.Row():
268
+ with gr.Column():
269
+ gr.Examples(examples=[
270
+ ["images/two_dogs.jpg", "Detailed"],
271
+ ["images/snowman.png", "Brief"],
272
+ ["images/man_ball.png", "Detailed"],
273
+ ], inputs=[image_input, text_input])
274
+ with gr.Column():
275
+ gr.Examples(examples=[
276
+ ["images/six_planes.png", "Brief"],
277
+ ["images/quadrocopter.jpg", "Brief"],
278
+ ["images/carnaby_street.jpg", "Brief"],
279
+ ], inputs=[image_input, text_input])
280
+ gr.Markdown(term_of_use)
281
+
282
+ # record which text span (label) is selected
283
+ selected = gr.Number(-1, show_label=False, placeholder="Selected", visible=False)
284
+
285
+ # record the current `entities`
286
+ entity_output = gr.Textbox(visible=False)
287
+
288
+ # get the current selected span label
289
+ def get_text_span_label(evt: gr.SelectData):
290
+ if evt.value[-1] is None:
291
+ return -1
292
+ return int(evt.value[-1])
293
+ # and set this information to `selected`
294
+ text_output1.select(get_text_span_label, None, selected)
295
+
296
+ # update output image when we change the span (enity) selection
297
+ def update_output_image(img_input, image_output, entities, idx):
298
+ entities = ast.literal_eval(entities)
299
+ updated_image = draw_entity_boxes_on_image(img_input, entities, entity_index=idx)
300
+ return updated_image
301
+ selected.change(update_output_image, [image_input, image_output, entity_output, selected], [image_output])
302
+
303
+ run_button.click(fn=generate_predictions,
304
+ inputs=[image_input, text_input],
305
+ outputs=[image_output, text_output1, entity_output],
306
+ show_progress=True, queue=True)
307
+
308
+ demo.launch(share=False)