SixOpen committed on
Commit 2efe949
1 Parent(s): 6cbd03a

Create app.py

Files changed (1)
  1. app.py +225 -0
app.py ADDED
@@ -0,0 +1,225 @@
+ import os
+ from unittest.mock import patch
+ import spaces
+ import gradio as gr
+ from transformers import AutoProcessor, AutoModelForCausalLM
+ from transformers.dynamic_module_utils import get_imports
+ import torch
+ import requests
+ from PIL import Image, ImageDraw
+ import random
+ import numpy as np
+ import matplotlib.pyplot as plt
+ import matplotlib.patches as patches
+ import cv2
+ import io
+
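+ # Florence-2's remote modeling code lists flash_attn among its required imports; the
+ # patch below filters it out so the model can load in environments without flash-attn.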
+ def workaround_fixed_get_imports(filename: str | os.PathLike) -> list[str]:
+     if not str(filename).endswith("/modeling_florence2.py"):
+         return get_imports(filename)
+     imports = get_imports(filename)
+     imports.remove("flash_attn")
+     return imports
+
+ with patch("transformers.dynamic_module_utils.get_imports", workaround_fixed_get_imports):
+     model = AutoModelForCausalLM.from_pretrained("microsoft/Florence-2-large-ft", trust_remote_code=True).to("cuda").eval()
+     processor = AutoProcessor.from_pretrained("microsoft/Florence-2-large-ft", trust_remote_code=True)
+
+ colormap = ['blue', 'orange', 'green', 'purple', 'brown', 'pink', 'gray', 'olive', 'cyan', 'red',
+             'lime', 'indigo', 'violet', 'aqua', 'magenta', 'coral', 'gold', 'tan', 'skyblue']
+
+ def fig_to_pil(fig):
+     buf = io.BytesIO()
+     fig.savefig(buf, format='png')
+     buf.seek(0)
+     img = Image.open(buf)
+     plt.close(fig)  # release the figure so repeated calls don't accumulate memory
+     return img
+
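+ # Run a single Florence-2 task prompt (optionally with extra text input) on an image
+ # and return the parsed result dict keyed by the task token.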
+ @spaces.GPU
+ def run_example(task_prompt, image, text_input=None):
+     if text_input is None:
+         prompt = task_prompt
+     else:
+         prompt = task_prompt + text_input
+     inputs = processor(text=prompt, images=image, return_tensors="pt").to("cuda")
+     with torch.inference_mode():
+         generated_ids = model.generate(
+             input_ids=inputs["input_ids"],
+             pixel_values=inputs["pixel_values"],
+             max_new_tokens=1024,
+             early_stopping=False,
+             do_sample=False,
+             num_beams=3,
+         )
+     generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
+     parsed_answer = processor.post_process_generation(
+         generated_text,
+         task=task_prompt,
+         image_size=(image.size[0], image.size[1])
+     )
+     return parsed_answer
+
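+ # Overlay rectangular boxes and labels from detection-style results ('bboxes'/'labels').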
+ def plot_bbox(image, data):
+     fig, ax = plt.subplots()
+     ax.imshow(image)
+     for bbox, label in zip(data['bboxes'], data['labels']):
+         x1, y1, x2, y2 = bbox
+         rect = patches.Rectangle((x1, y1), x2-x1, y2-y1, linewidth=1, edgecolor='r', facecolor='none')
+         ax.add_patch(rect)
+         plt.text(x1, y1, label, color='white', fontsize=8, bbox=dict(facecolor='indigo', alpha=0.5))
+     ax.axis('off')
+     return fig_to_pil(fig)
+
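+ # Overlay segmentation polygons from results with 'polygons'/'labels', optionally filled.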
+ def draw_polygons(image, prediction, fill_mask=False):
+     fig, ax = plt.subplots()
+     ax.imshow(image)
+     scale = 1
+     for polygons, label in zip(prediction['polygons'], prediction['labels']):
+         color = random.choice(colormap)
+         fill_color = random.choice(colormap) if fill_mask else None
+         for _polygon in polygons:
+             _polygon = np.array(_polygon).reshape(-1, 2)
+             if _polygon.shape[0] < 3:
+                 print('Invalid polygon:', _polygon)
+                 continue
+             _polygon = (_polygon * scale).reshape(-1).tolist()
+             if len(_polygon) % 2 != 0:
+                 print('Invalid polygon:', _polygon)
+                 continue
+             polygon_points = np.array(_polygon).reshape(-1, 2)
+             if fill_mask:
+                 polygon = patches.Polygon(polygon_points, edgecolor=color, facecolor=fill_color, linewidth=2)
+             else:
+                 polygon = patches.Polygon(polygon_points, edgecolor=color, fill=False, linewidth=2)
+             ax.add_patch(polygon)
+             plt.text(polygon_points[0, 0], polygon_points[0, 1], label, color='white', fontsize=8, bbox=dict(facecolor=color, alpha=0.5))
+     ax.axis('off')
+     return fig_to_pil(fig)
+
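+ # Overlay quadrilateral text regions from OCR-with-region results ('quad_boxes'/'labels').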
+ def draw_ocr_bboxes(image, prediction):
+     fig, ax = plt.subplots()
+     ax.imshow(image)
+     scale = 1
+     bboxes, labels = prediction['quad_boxes'], prediction['labels']
+     for box, label in zip(bboxes, labels):
+         color = random.choice(colormap)
+         new_box = (np.array(box) * scale).tolist()
+         # matplotlib's Polygon expects an (N, 2) array, so reshape the flat quad box
+         polygon = patches.Polygon(np.array(new_box).reshape(-1, 2), edgecolor=color, fill=False, linewidth=3)
+         ax.add_patch(polygon)
+         plt.text(new_box[0], new_box[1], label, color='white', fontsize=8, bbox=dict(facecolor=color, alpha=0.5))
+     ax.axis('off')
+     return fig_to_pil(fig)
+
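+ # convert_to_od_format is used by the Open Vocabulary Detection task in process_image
+ # but is not defined elsewhere in this file; a minimal sketch, assuming the
+ # <OPEN_VOCABULARY_DETECTION> result uses the 'bboxes' and 'bboxes_labels' keys.
+ def convert_to_od_format(data):
+     return {'bboxes': data.get('bboxes', []), 'labels': data.get('bboxes_labels', [])}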
+
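+ # Run the selected detection task on every frame and write an annotated mp4; each frame
+ # is a separate model call, so long clips can take a while.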
+ @spaces.GPU(duration=120)
+ def process_video(input_video_path, task_prompt):
+     # the video tab passes the radio label, so translate it to the Florence-2 task token
+     video_task_mapping = {"Object Detection": "<OD>", "Dense Region Caption": "<DENSE_REGION_CAPTION>"}
+     task_prompt = video_task_mapping.get(task_prompt, task_prompt)
+
+     cap = cv2.VideoCapture(input_video_path)
+     if not cap.isOpened():
+         print("Error: Can't open the video file.")
+         return
+
+     frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
+     frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
+     fps = cap.get(cv2.CAP_PROP_FPS)
+     fourcc = cv2.VideoWriter_fourcc(*'mp4v')
+     out = cv2.VideoWriter("output_vid.mp4", fourcc, fps, (frame_width, frame_height))
+
+     while cap.isOpened():
+         ret, frame = cap.read()
+         if not ret:
+             break
+
+         frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+         pil_image = Image.fromarray(frame_rgb)
+
+         result = run_example(task_prompt, pil_image)
+
+         if task_prompt == "<OD>":
+             processed_image = plot_bbox(pil_image, result['<OD>'])
+         elif task_prompt == "<DENSE_REGION_CAPTION>":
+             processed_image = plot_bbox(pil_image, result['<DENSE_REGION_CAPTION>'])
+         else:
+             processed_image = pil_image
+
+         # annotated frames come back as figure-sized RGBA images; convert to RGB and
+         # resize so the VideoWriter receives frames of the expected shape
+         processed_image = processed_image.convert("RGB")
+         if processed_image.size != (frame_width, frame_height):
+             processed_image = processed_image.resize((frame_width, frame_height))
+
+         processed_frame = cv2.cvtColor(np.array(processed_image), cv2.COLOR_RGB2BGR)
+         out.write(processed_frame)
+
+     cap.release()
+     out.release()
+     cv2.destroyAllWindows()
+     return "output_vid.mp4"
+
+ css = """
+ #output {
+     min-height: 100px;
+     overflow: auto;
+     border: 1px solid #ccc;
+ }
+ """
+
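+ # UI: an Image tab for single-image tasks and a Video tab for frame-by-frame detection.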
+ with gr.Blocks(css=css) as demo:
+     gr.HTML("<h1><center>Microsoft Florence-2-large-ft</center></h1>")
+     with gr.Tab(label="Image"):
+         with gr.Row():
+             with gr.Column():
+                 input_img = gr.Image(label="Input Picture", type="pil")
+                 task_radio = gr.Radio(
+                     ["Caption", "Detailed Caption", "More Detailed Caption", "Caption to Phrase Grounding",
+                      "Object Detection", "Dense Region Caption", "Region Proposal", "Referring Expression Segmentation",
+                      "Region to Segmentation", "Open Vocabulary Detection", "Region to Category", "Region to Description",
+                      "OCR", "OCR with Region"],
+                     label="Task", value="Caption"
+                 )
+                 text_input = gr.Textbox(label="Text Input (optional)", visible=False)
+                 submit_btn = gr.Button(value="Submit")
+             with gr.Column():
+                 output_text = gr.Textbox(label="Results")
+                 output_image = gr.Image(label="Image", type="pil")
+
+     with gr.Tab(label="Video"):
+         with gr.Row():
+             with gr.Column():
+                 input_video = gr.Video(label="Video")
+                 video_task_radio = gr.Radio(
+                     ["Object Detection", "Dense Region Caption"],
+                     label="Video Task", value="Object Detection"
+                 )
+                 video_submit_btn = gr.Button(value="Process Video")
+             with gr.Column():
+                 output_video = gr.Video(label="Video")
+
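+     # Show the free-text input only for tasks that take a phrase or region as extra input.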
+     def update_text_input(task):
+         return gr.update(visible=task in ["Caption to Phrase Grounding", "Referring Expression Segmentation",
+                                           "Region to Segmentation", "Open Vocabulary Detection", "Region to Category",
+                                           "Region to Description"])
+
+     task_radio.change(fn=update_text_input, inputs=task_radio, outputs=text_input)
+
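+     # Map each UI task name to its Florence-2 task token and a formatter that turns the
+     # parsed result into (text output, image output) for the two output components.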
+     def process_image(image, task, text):
+         task_mapping = {
+             "Caption": ("<CAPTION>", lambda result: (result['<CAPTION>'], image)),
+             "Detailed Caption": ("<DETAILED_CAPTION>", lambda result: (result['<DETAILED_CAPTION>'], image)),
+             "More Detailed Caption": ("<MORE_DETAILED_CAPTION>", lambda result: (result['<MORE_DETAILED_CAPTION>'], image)),
+             "Caption to Phrase Grounding": ("<CAPTION_TO_PHRASE_GROUNDING>", lambda result: (str(result['<CAPTION_TO_PHRASE_GROUNDING>']), plot_bbox(image, result['<CAPTION_TO_PHRASE_GROUNDING>']))),
+             "Object Detection": ("<OD>", lambda result: (str(result['<OD>']), plot_bbox(image, result['<OD>']))),
+             "Dense Region Caption": ("<DENSE_REGION_CAPTION>", lambda result: (str(result['<DENSE_REGION_CAPTION>']), plot_bbox(image, result['<DENSE_REGION_CAPTION>']))),
+             "Region Proposal": ("<REGION_PROPOSAL>", lambda result: (str(result['<REGION_PROPOSAL>']), plot_bbox(image, result['<REGION_PROPOSAL>']))),
+             "Referring Expression Segmentation": ("<REFERRING_EXPRESSION_SEGMENTATION>", lambda result: (str(result['<REFERRING_EXPRESSION_SEGMENTATION>']), draw_polygons(image, result['<REFERRING_EXPRESSION_SEGMENTATION>'], fill_mask=True))),
+             "Region to Segmentation": ("<REGION_TO_SEGMENTATION>", lambda result: (str(result['<REGION_TO_SEGMENTATION>']), draw_polygons(image, result['<REGION_TO_SEGMENTATION>'], fill_mask=True))),
+             "Open Vocabulary Detection": ("<OPEN_VOCABULARY_DETECTION>", lambda result: (str(convert_to_od_format(result['<OPEN_VOCABULARY_DETECTION>'])), plot_bbox(image, convert_to_od_format(result['<OPEN_VOCABULARY_DETECTION>'])))),
+             "Region to Category": ("<REGION_TO_CATEGORY>", lambda result: (result['<REGION_TO_CATEGORY>'], image)),
+             "Region to Description": ("<REGION_TO_DESCRIPTION>", lambda result: (result['<REGION_TO_DESCRIPTION>'], image)),
+             "OCR": ("<OCR>", lambda result: (result['<OCR>'], image)),
+             "OCR with Region": ("<OCR_WITH_REGION>", lambda result: (str(result['<OCR_WITH_REGION>']), draw_ocr_bboxes(image, result['<OCR_WITH_REGION>']))),
+         }
+
+         if task in task_mapping:
+             prompt, process_func = task_mapping[task]
+             result = run_example(prompt, image, text)
+             return process_func(result)
+         else:
+             return "", image
+
+     submit_btn.click(fn=process_image, inputs=[input_img, task_radio, text_input], outputs=[output_text, output_image])
+     video_submit_btn.click(fn=process_video, inputs=[input_video, video_task_radio], outputs=output_video)
+
+ demo.launch()