byterocker20 commited on
Commit
b538d20
1 Parent(s): 3aa769d

Add application file

Browse files
Files changed (4) hide show
  1. app.py +290 -0
  2. image1.jpg +0 -0
  3. image2.jpg +0 -0
  4. requirements.txt +2 -0
app.py ADDED
@@ -0,0 +1,290 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from transformers import AutoProcessor, AutoModelForCausalLM
3
+
4
+ import requests
5
+ import copy
6
+
7
+ from PIL import Image, ImageDraw, ImageFont
8
+ import io
9
+ import matplotlib.pyplot as plt
10
+ import matplotlib.patches as patches
11
+
12
+ import random
13
+ import numpy as np
14
+
15
+ import os
16
+ import subprocess
17
+
18
+ from unittest.mock import patch
19
+ from transformers.dynamic_module_utils import get_imports
20
+ def fixed_get_imports(filename: str | os.PathLike) -> list[str]:
21
+ if not str(filename).endswith("modeling_florence2.py"):
22
+ return get_imports(filename)
23
+ imports = get_imports(filename)
24
+ imports.remove("flash_attn")
25
+ return imports
26
+
27
+
28
+ models = {
29
+ 'microsoft/Florence-2-base-ft': AutoModelForCausalLM.from_pretrained('microsoft/Florence-2-base-ft', trust_remote_code=True, device_map='cpu').eval(),
30
+ 'microsoft/Florence-2-base': AutoModelForCausalLM.from_pretrained('microsoft/Florence-2-base', trust_remote_code=True, device_map='cpu').eval(),
31
+ }
32
+
33
+ processors = {
34
+ 'microsoft/Florence-2-base-ft': AutoProcessor.from_pretrained('microsoft/Florence-2-base-ft', trust_remote_code=True),
35
+ 'microsoft/Florence-2-base': AutoProcessor.from_pretrained('microsoft/Florence-2-base', trust_remote_code=True),
36
+ }
37
+
38
+
39
+ DESCRIPTION = "# [Florence-2 Demo](https://huggingface.co/microsoft/Florence-2-base)"
40
+
41
+ colormap = ['blue','orange','green','purple','brown','pink','gray','olive','cyan','red',
42
+ 'lime','indigo','violet','aqua','magenta','coral','gold','tan','skyblue']
43
+
44
+ def fig_to_pil(fig):
45
+ buf = io.BytesIO()
46
+ fig.savefig(buf, format='png')
47
+ buf.seek(0)
48
+ return Image.open(buf)
49
+
50
+
51
+ def run_example(task_prompt, image, text_input=None, model_id='microsoft/Florence-2-large'):
52
+ with patch("transformers.dynamic_module_utils.get_imports", fixed_get_imports):
53
+ model = models[model_id]
54
+ processor = processors[model_id]
55
+ if text_input is None:
56
+ prompt = task_prompt
57
+ else:
58
+ prompt = task_prompt + text_input
59
+ inputs = processor(text=prompt, images=image, return_tensors="pt").to("cuda")
60
+ generated_ids = model.generate(
61
+ input_ids=inputs["input_ids"],
62
+ pixel_values=inputs["pixel_values"],
63
+ max_new_tokens=1024,
64
+ early_stopping=False,
65
+ do_sample=False,
66
+ num_beams=3,
67
+ )
68
+ generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
69
+ parsed_answer = processor.post_process_generation(
70
+ generated_text,
71
+ task=task_prompt,
72
+ image_size=(image.width, image.height)
73
+ )
74
+ return parsed_answer
75
+
76
+ def plot_bbox(image, data):
77
+ fig, ax = plt.subplots()
78
+ ax.imshow(image)
79
+ for bbox, label in zip(data['bboxes'], data['labels']):
80
+ x1, y1, x2, y2 = bbox
81
+ rect = patches.Rectangle((x1, y1), x2-x1, y2-y1, linewidth=1, edgecolor='r', facecolor='none')
82
+ ax.add_patch(rect)
83
+ plt.text(x1, y1, label, color='white', fontsize=8, bbox=dict(facecolor='red', alpha=0.5))
84
+ ax.axis('off')
85
+ return fig
86
+
87
+ def draw_polygons(image, prediction, fill_mask=False):
88
+
89
+ draw = ImageDraw.Draw(image)
90
+ scale = 1
91
+ for polygons, label in zip(prediction['polygons'], prediction['labels']):
92
+ color = random.choice(colormap)
93
+ fill_color = random.choice(colormap) if fill_mask else None
94
+ for _polygon in polygons:
95
+ _polygon = np.array(_polygon).reshape(-1, 2)
96
+ if len(_polygon) < 3:
97
+ print('Invalid polygon:', _polygon)
98
+ continue
99
+ _polygon = (_polygon * scale).reshape(-1).tolist()
100
+ if fill_mask:
101
+ draw.polygon(_polygon, outline=color, fill=fill_color)
102
+ else:
103
+ draw.polygon(_polygon, outline=color)
104
+ draw.text((_polygon[0] + 8, _polygon[1] + 2), label, fill=color)
105
+ return image
106
+
107
+ def convert_to_od_format(data):
108
+ bboxes = data.get('bboxes', [])
109
+ labels = data.get('bboxes_labels', [])
110
+ od_results = {
111
+ 'bboxes': bboxes,
112
+ 'labels': labels
113
+ }
114
+ return od_results
115
+
116
+ def draw_ocr_bboxes(image, prediction):
117
+ scale = 1
118
+ draw = ImageDraw.Draw(image)
119
+ bboxes, labels = prediction['quad_boxes'], prediction['labels']
120
+ for box, label in zip(bboxes, labels):
121
+ color = random.choice(colormap)
122
+ new_box = (np.array(box) * scale).tolist()
123
+ draw.polygon(new_box, width=3, outline=color)
124
+ draw.text((new_box[0]+8, new_box[1]+2),
125
+ "{}".format(label),
126
+ align="right",
127
+ fill=color)
128
+ return image
129
+
130
+ def process_image(image, task_prompt, text_input=None, model_id='microsoft/Florence-2-large'):
131
+ image = Image.fromarray(image) # Convert NumPy array to PIL Image
132
+ if task_prompt == 'Caption':
133
+ task_prompt = '<CAPTION>'
134
+ results = run_example(task_prompt, image, model_id=model_id)
135
+ return results, None
136
+ elif task_prompt == 'Detailed Caption':
137
+ task_prompt = '<DETAILED_CAPTION>'
138
+ results = run_example(task_prompt, image, model_id=model_id)
139
+ return results, None
140
+ elif task_prompt == 'More Detailed Caption':
141
+ task_prompt = '<MORE_DETAILED_CAPTION>'
142
+ results = run_example(task_prompt, image, model_id=model_id)
143
+ return results, None
144
+ elif task_prompt == 'Caption + Grounding':
145
+ task_prompt = '<CAPTION>'
146
+ results = run_example(task_prompt, image, model_id=model_id)
147
+ text_input = results[task_prompt]
148
+ task_prompt = '<CAPTION_TO_PHRASE_GROUNDING>'
149
+ results = run_example(task_prompt, image, text_input, model_id)
150
+ results['<CAPTION>'] = text_input
151
+ fig = plot_bbox(image, results['<CAPTION_TO_PHRASE_GROUNDING>'])
152
+ return results, fig_to_pil(fig)
153
+ elif task_prompt == 'Detailed Caption + Grounding':
154
+ task_prompt = '<DETAILED_CAPTION>'
155
+ results = run_example(task_prompt, image, model_id=model_id)
156
+ text_input = results[task_prompt]
157
+ task_prompt = '<CAPTION_TO_PHRASE_GROUNDING>'
158
+ results = run_example(task_prompt, image, text_input, model_id)
159
+ results['<DETAILED_CAPTION>'] = text_input
160
+ fig = plot_bbox(image, results['<CAPTION_TO_PHRASE_GROUNDING>'])
161
+ return results, fig_to_pil(fig)
162
+ elif task_prompt == 'More Detailed Caption + Grounding':
163
+ task_prompt = '<MORE_DETAILED_CAPTION>'
164
+ results = run_example(task_prompt, image, model_id=model_id)
165
+ text_input = results[task_prompt]
166
+ task_prompt = '<CAPTION_TO_PHRASE_GROUNDING>'
167
+ results = run_example(task_prompt, image, text_input, model_id)
168
+ results['<MORE_DETAILED_CAPTION>'] = text_input
169
+ fig = plot_bbox(image, results['<CAPTION_TO_PHRASE_GROUNDING>'])
170
+ return results, fig_to_pil(fig)
171
+ elif task_prompt == 'Object Detection':
172
+ task_prompt = '<OD>'
173
+ results = run_example(task_prompt, image, model_id=model_id)
174
+ fig = plot_bbox(image, results['<OD>'])
175
+ return results, fig_to_pil(fig)
176
+ elif task_prompt == 'Dense Region Caption':
177
+ task_prompt = '<DENSE_REGION_CAPTION>'
178
+ results = run_example(task_prompt, image, model_id=model_id)
179
+ fig = plot_bbox(image, results['<DENSE_REGION_CAPTION>'])
180
+ return results, fig_to_pil(fig)
181
+ elif task_prompt == 'Region Proposal':
182
+ task_prompt = '<REGION_PROPOSAL>'
183
+ results = run_example(task_prompt, image, model_id=model_id)
184
+ fig = plot_bbox(image, results['<REGION_PROPOSAL>'])
185
+ return results, fig_to_pil(fig)
186
+ elif task_prompt == 'Caption to Phrase Grounding':
187
+ task_prompt = '<CAPTION_TO_PHRASE_GROUNDING>'
188
+ results = run_example(task_prompt, image, text_input, model_id)
189
+ fig = plot_bbox(image, results['<CAPTION_TO_PHRASE_GROUNDING>'])
190
+ return results, fig_to_pil(fig)
191
+ elif task_prompt == 'Referring Expression Segmentation':
192
+ task_prompt = '<REFERRING_EXPRESSION_SEGMENTATION>'
193
+ results = run_example(task_prompt, image, text_input, model_id)
194
+ output_image = copy.deepcopy(image)
195
+ output_image = draw_polygons(output_image, results['<REFERRING_EXPRESSION_SEGMENTATION>'], fill_mask=True)
196
+ return results, output_image
197
+ elif task_prompt == 'Region to Segmentation':
198
+ task_prompt = '<REGION_TO_SEGMENTATION>'
199
+ results = run_example(task_prompt, image, text_input, model_id)
200
+ output_image = copy.deepcopy(image)
201
+ output_image = draw_polygons(output_image, results['<REGION_TO_SEGMENTATION>'], fill_mask=True)
202
+ return results, output_image
203
+ elif task_prompt == 'Open Vocabulary Detection':
204
+ task_prompt = '<OPEN_VOCABULARY_DETECTION>'
205
+ results = run_example(task_prompt, image, text_input, model_id)
206
+ bbox_results = convert_to_od_format(results['<OPEN_VOCABULARY_DETECTION>'])
207
+ fig = plot_bbox(image, bbox_results)
208
+ return results, fig_to_pil(fig)
209
+ elif task_prompt == 'Region to Category':
210
+ task_prompt = '<REGION_TO_CATEGORY>'
211
+ results = run_example(task_prompt, image, text_input, model_id)
212
+ return results, None
213
+ elif task_prompt == 'Region to Description':
214
+ task_prompt = '<REGION_TO_DESCRIPTION>'
215
+ results = run_example(task_prompt, image, text_input, model_id)
216
+ return results, None
217
+ elif task_prompt == 'OCR':
218
+ task_prompt = '<OCR>'
219
+ results = run_example(task_prompt, image, model_id=model_id)
220
+ return results, None
221
+ elif task_prompt == 'OCR with Region':
222
+ task_prompt = '<OCR_WITH_REGION>'
223
+ results = run_example(task_prompt, image, model_id=model_id)
224
+ output_image = copy.deepcopy(image)
225
+ output_image = draw_ocr_bboxes(output_image, results['<OCR_WITH_REGION>'])
226
+ return results, output_image
227
+ else:
228
+ return "", None # Return empty string and None for unknown task prompts
229
+
230
+ css = """
231
+ #output {
232
+ height: 500px;
233
+ overflow: auto;
234
+ border: 1px solid #ccc;
235
+ }
236
+ """
237
+
238
+
239
+ single_task_list =[
240
+ 'Caption', 'Detailed Caption', 'More Detailed Caption', 'Object Detection',
241
+ 'Dense Region Caption', 'Region Proposal', 'Caption to Phrase Grounding',
242
+ 'Referring Expression Segmentation', 'Region to Segmentation',
243
+ 'Open Vocabulary Detection', 'Region to Category', 'Region to Description',
244
+ 'OCR', 'OCR with Region'
245
+ ]
246
+
247
+ cascased_task_list =[
248
+ 'Caption + Grounding', 'Detailed Caption + Grounding', 'More Detailed Caption + Grounding'
249
+ ]
250
+
251
+
252
+ def update_task_dropdown(choice):
253
+ if choice == 'Cascased task':
254
+ return gr.Dropdown(choices=cascased_task_list, value='Caption + Grounding')
255
+ else:
256
+ return gr.Dropdown(choices=single_task_list, value='Caption')
257
+
258
+
259
+
260
+ with gr.Blocks(css=css) as demo:
261
+ gr.Markdown(DESCRIPTION)
262
+ with gr.Tab(label="Florence-2 Image Captioning"):
263
+ with gr.Row():
264
+ with gr.Column():
265
+ input_img = gr.Image(label="Input Picture")
266
+ model_selector = gr.Dropdown(choices=list(models.keys()), label="Model", value='microsoft/Florence-2-large')
267
+ task_type = gr.Radio(choices=['Single task', 'Cascased task'], label='Task type selector', value='Single task')
268
+ task_prompt = gr.Dropdown(choices=single_task_list, label="Task Prompt", value="Caption")
269
+ task_type.change(fn=update_task_dropdown, inputs=task_type, outputs=task_prompt)
270
+ text_input = gr.Textbox(label="Text Input (optional)")
271
+ submit_btn = gr.Button(value="Submit")
272
+ with gr.Column():
273
+ output_text = gr.Textbox(label="Output Text")
274
+ output_img = gr.Image(label="Output Image")
275
+
276
+ gr.Examples(
277
+ examples=[
278
+ ["image1.jpg", 'Object Detection'],
279
+ ["image2.jpg", 'OCR with Region']
280
+ ],
281
+ inputs=[input_img, task_prompt],
282
+ outputs=[output_text, output_img],
283
+ fn=process_image,
284
+ cache_examples=True,
285
+ label='Try examples'
286
+ )
287
+
288
+ submit_btn.click(process_image, [input_img, task_prompt, text_input, model_selector], [output_text, output_img])
289
+
290
+ demo.launch(debug=True)
image1.jpg ADDED
image2.jpg ADDED
requirements.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ transformers
2
+ timm