majinyu committed
Commit 9027584
1 Parent(s): 6a354bc

commit app

.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+ images/*.jpg filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,3 @@
+ __pycache__/
+ .vscode/
+ gradio_cached_examples/
app.py ADDED
@@ -0,0 +1,363 @@
+ import os
+
+ # setup Grounded-Segment-Anything
+ os.system("python -m pip install -e 'Grounded-Segment-Anything/segment_anything'")
+ os.system("python -m pip install -e 'Grounded-Segment-Anything/GroundingDINO'")
+ os.system("pip install --upgrade diffusers[torch]")
+ os.system("pip install opencv-python pycocotools matplotlib onnxruntime onnx ipykernel")
+
+ # setup recognize-anything
+ os.system("python -m pip install -e 'recognize-anything'")
+
+ import random  # noqa: E402
+
+ import cv2  # noqa: E402
+ import groundingdino.datasets.transforms as T  # noqa: E402
+ import numpy as np  # noqa: E402
+ import torch  # noqa: E402
+ import torchvision  # noqa: E402
+ import torchvision.transforms as TS  # noqa: E402
+ from groundingdino.models import build_model  # noqa: E402
+ from groundingdino.util.slconfig import SLConfig  # noqa: E402
+ from groundingdino.util.utils import clean_state_dict, get_phrases_from_posmap  # noqa: E402
+ from PIL import Image, ImageDraw, ImageFont  # noqa: E402
+ from ram import inference_ram  # noqa: E402
+ from ram import inference_tag2text  # noqa: E402
+ from ram.models import ram  # noqa: E402
+ from ram.models import tag2text_caption  # noqa: E402
+ from segment_anything import SamPredictor, build_sam  # noqa: E402
+
+
+ # args
+ config_file = "Grounded-Segment-Anything/GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py"
+ ram_checkpoint = "./ram_swin_large_14m.pth"
+ tag2text_checkpoint = "./tag2text_swin_14m.pth"
+ grounded_checkpoint = "./groundingdino_swint_ogc.pth"
+ sam_checkpoint = "./sam_vit_h_4b8939.pth"
+ box_threshold = 0.25
+ text_threshold = 0.2
+ iou_threshold = 0.5
+ device = "cpu"
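+ # box_threshold / text_threshold filter GroundingDINO detections; iou_threshold drives the NMS step in inference()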
+
+
+ def load_model(model_config_path, model_checkpoint_path, device):
+     args = SLConfig.fromfile(model_config_path)
+     args.device = device
+     model = build_model(args)
+     checkpoint = torch.load(model_checkpoint_path, map_location="cpu")
+     load_res = model.load_state_dict(
+         clean_state_dict(checkpoint["model"]), strict=False)
+     print(load_res)
+     _ = model.eval()
+     return model
+
+
+ def get_grounding_output(model, image, caption, box_threshold, text_threshold, device="cpu"):
+     caption = caption.lower()
+     caption = caption.strip()
+     if not caption.endswith("."):
+         caption = caption + "."
+     model = model.to(device)
+     image = image.to(device)
+     with torch.no_grad():
+         outputs = model(image[None], captions=[caption])
+     logits = outputs["pred_logits"].cpu().sigmoid()[0]  # (nq, 256)
+     boxes = outputs["pred_boxes"].cpu()[0]  # (nq, 4)
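+     # boxes come back as normalized (cx, cy, w, h); inference() rescales them to pixel xyxy before NMS and SAM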
+     logits.shape[0]
+
+     # filter output
+     logits_filt = logits.clone()
+     boxes_filt = boxes.clone()
+     filt_mask = logits_filt.max(dim=1)[0] > box_threshold
+     logits_filt = logits_filt[filt_mask]  # num_filt, 256
+     boxes_filt = boxes_filt[filt_mask]  # num_filt, 4
+     logits_filt.shape[0]
+
+     # get phrase
+     tokenlizer = model.tokenizer
+     tokenized = tokenlizer(caption)
+     # build pred
+     pred_phrases = []
+     scores = []
+     for logit, box in zip(logits_filt, boxes_filt):
+         pred_phrase = get_phrases_from_posmap(
+             logit > text_threshold, tokenized, tokenlizer)
+         pred_phrases.append(pred_phrase + f"({str(logit.max().item())[:4]})")
+         scores.append(logit.max().item())
+
+     return boxes_filt, torch.Tensor(scores), pred_phrases
+
+
+ def draw_mask(mask, draw, random_color=False):
+     if random_color:
+         color = (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255), 153)
+     else:
+         color = (30, 144, 255, 153)
+
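+     # paint every non-zero mask pixel as a semi-transparent point (alpha 153) on the RGBA overlay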
+     nonzero_coords = np.transpose(np.nonzero(mask))
+
+     for coord in nonzero_coords:
+         draw.point(coord[::-1], fill=color)
+
+
+ def draw_box(box, draw, label):
+     # random color
+     color = tuple(np.random.randint(0, 255, size=3).tolist())
+     line_width = int(max(5, min(25, 0.006 * max(draw.im.size))))
+     draw.rectangle(((box[0], box[1]), (box[2], box[3])), outline=color, width=line_width)
+
+     if label:
+         font_path = os.path.join(
+             cv2.__path__[0], 'qt', 'fonts', 'DejaVuSans.ttf')
+         font_size = int(max(15, min(75, 0.02 * max(draw.im.size))))
+         font = ImageFont.truetype(font_path, size=font_size)
+         if hasattr(font, "getbbox"):
+             bbox = draw.textbbox((box[0], box[1]), str(label), font)
+         else:
+             w, h = draw.textsize(str(label), font)
+             bbox = (box[0], box[1], w + box[0], box[1] + h)
+         draw.rectangle(bbox, fill=color)
+         draw.text((box[0], box[1]), str(label), fill="white", font=font)
+
+         draw.text((box[0], box[1]), label, font=font)
+
+
+ def inference(raw_image, specified_tags, tagging_model_type, tagging_model, grounding_dino_model, sam_model):
+     raw_image = raw_image.convert("RGB")
+
+     # run tagging model
+     normalize = TS.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+     transform = TS.Compose([
+         TS.Resize((384, 384)),
+         TS.ToTensor(),
+         normalize
+     ])
+
+     image = raw_image.resize((384, 384))
+     image = transform(image).unsqueeze(0).to(device)
+
+     # Currently ", " is better for detecting single tags
+     # while ". " is a little worse in some cases
+     if tagging_model_type == "RAM":
+         res = inference_ram(image, tagging_model)
+         tags = res[0].strip(' ').replace('  ', ' ').replace(' |', ',')
+         tags_chinese = res[1].strip(' ').replace('  ', ' ').replace(' |', ',')
+         print("Tags: ", tags)
+         print("图像标签: ", tags_chinese)
+     else:
+         res = inference_tag2text(image, tagging_model, specified_tags)
+         tags = res[0].strip(' ').replace('  ', ' ').replace(' |', ',')
+         caption = res[2]
+         print(f"Tags: {tags}")
+         print(f"Caption: {caption}")
+
+     # run groundingDINO
+     transform = T.Compose([
+         T.RandomResize([800], max_size=1333),
+         T.ToTensor(),
+         T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
+     ])
+
+     image, _ = transform(raw_image, None)  # 3, h, w
+
+     boxes_filt, scores, pred_phrases = get_grounding_output(
+         grounding_dino_model, image, tags, box_threshold, text_threshold, device=device
+     )
+
+     # run SAM
+     image = np.asarray(raw_image)
+     sam_model.set_image(image)
+
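+     # convert boxes from normalized (cx, cy, w, h) to absolute (x1, y1, x2, y2) pixel coordinates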
+     size = raw_image.size
+     H, W = size[1], size[0]
+     for i in range(boxes_filt.size(0)):
+         boxes_filt[i] = boxes_filt[i] * torch.Tensor([W, H, W, H])
+         boxes_filt[i][:2] -= boxes_filt[i][2:] / 2
+         boxes_filt[i][2:] += boxes_filt[i][:2]
+
+     boxes_filt = boxes_filt.cpu()
+     # use NMS to handle overlapped boxes
+     nms_idx = torchvision.ops.nms(
+         boxes_filt, scores, iou_threshold).numpy().tolist()
+     boxes_filt = boxes_filt[nms_idx]
+     pred_phrases = [pred_phrases[idx] for idx in nms_idx]
+
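+     # prompt SAM with the remaining boxes and take one mask per box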
+     transformed_boxes = sam_model.transform.apply_boxes_torch(
+         boxes_filt, image.shape[:2]).to(device)
+
+     masks, _, _ = sam_model.predict_torch(
+         point_coords=None,
+         point_labels=None,
+         boxes=transformed_boxes.to(device),
+         multimask_output=False,
+     )
+
+     # draw output image
+     mask_image = Image.new('RGBA', size, color=(0, 0, 0, 0))
+
+     mask_draw = ImageDraw.Draw(mask_image)
+     for mask in masks:
+         draw_mask(mask[0].cpu().numpy(), mask_draw, random_color=True)
+
+     image_draw = ImageDraw.Draw(raw_image)
+
+     for box, label in zip(boxes_filt, pred_phrases):
+         draw_box(box, image_draw, label)
+
+     out_image = raw_image.convert('RGBA')
+     out_image.alpha_composite(mask_image)
+
+     # return
+     if tagging_model_type == "RAM":
+         return tags, tags_chinese, out_image
+     else:
+         return tags, caption, out_image
+
+
+ if __name__ == "__main__":
+     import gradio as gr
+
+     # load RAM
+     ram_model = ram(pretrained=ram_checkpoint, image_size=384, vit='swin_l')
+     ram_model.eval()
+     ram_model = ram_model.to(device)
+
+     # load Tag2Text
+     delete_tag_index = []  # filter out attribute and action categories that are difficult to ground
+     for i in range(3012, 3429):
+         delete_tag_index.append(i)
+
+     tag2text_model = tag2text_caption(pretrained=tag2text_checkpoint,
+                                       image_size=384,
+                                       vit='swin_b',
+                                       delete_tag_index=delete_tag_index)
+     tag2text_model.threshold = 0.64  # we reduce the threshold to obtain more tags
+     tag2text_model.eval()
+     tag2text_model = tag2text_model.to(device)
+
+     # load groundingDINO
+     grounding_dino_model = load_model(config_file, grounded_checkpoint, device=device)
+
+     # load SAM
+     sam_model = SamPredictor(build_sam(checkpoint=sam_checkpoint).to(device))
+
+     # build GUI
+     def build_gui():
+
+         description = """
+         <center><strong><font size='10'>Recognize Anything Model + Grounded-SAM</font></strong></center>
+         <br>
+         Welcome to the RAM/Tag2Text + Grounded-SAM demo! <br><br>
+         <li>
+             <b>Recognize Anything Model + Grounded-SAM:</b> Upload your image to get the <b>English and Chinese tags</b> (by RAM) and <b>masks and boxes</b> (by Grounded-SAM)!
+         </li>
+         <li>
+             <b>Tag2Text Model + Grounded-SAM:</b> Upload your image to get the <b>tags and caption</b> (by Tag2Text) and <b>masks and boxes</b> (by Grounded-SAM)!
+             (Optional: Specify tags to get the corresponding caption.)
+         </li>
+         """  # noqa
+
+         article = """
+         <p style='text-align: center'>
+             RAM and Tag2Text are trained on open-source datasets, and we keep refining and iterating on them.<br/>
+             Grounded-SAM is a combination of Grounding DINO and SAM, aiming to detect and segment anything with text inputs.<br/>
+             <a href='https://recognize-anything.github.io/' target='_blank'>Recognize Anything: A Strong Image Tagging Model</a>
+             |
+             <a href='https://tag2text.github.io/' target='_blank'>Tag2Text: Guiding Language-Image Model via Image Tagging</a>
+             |
+             <a href='https://github.com/IDEA-Research/Grounded-Segment-Anything' target='_blank'>Grounded-Segment-Anything</a>
+         </p>
+         """  # noqa
+
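+         # thin wrappers that close over the preloaded models, so the gradio callbacks only take UI inputs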
+         def inference_with_ram(img):
+             return inference(img, None, "RAM", ram_model, grounding_dino_model, sam_model)
+
+         def inference_with_t2t(img, input_tags):
+             return inference(img, input_tags, "Tag2Text", tag2text_model, grounding_dino_model, sam_model)
+
+         with gr.Blocks(title="Recognize Anything Model") as demo:
+             ###############
+             # components
+             ###############
+             gr.HTML(description)
+
+             with gr.Tab(label="Recognize Anything Model"):
+                 with gr.Row():
+                     with gr.Column():
+                         ram_in_img = gr.Image(type="pil")
+                         with gr.Row():
+                             ram_btn_run = gr.Button(value="Run")
+                             ram_btn_clear = gr.Button(value="Clear")
+                     with gr.Column():
+                         ram_out_img = gr.Image(type="pil")
+                         ram_out_tag = gr.Textbox(label="Tags")
+                         ram_out_biaoqian = gr.Textbox(label="标签")
+                 gr.Examples(
+                     examples=[
+                         ["images/demo1.jpg"],
+                         ["images/demo2.jpg"],
+                         ["images/demo4.jpg"],
+                     ],
+                     fn=inference_with_ram,
+                     inputs=[ram_in_img],
+                     outputs=[ram_out_tag, ram_out_biaoqian, ram_out_img],
+                     cache_examples=True
+                 )
+
+             with gr.Tab(label="Tag2Text Model"):
+                 with gr.Row():
+                     with gr.Column():
+                         t2t_in_img = gr.Image(type="pil")
+                         t2t_in_tag = gr.Textbox(label="User Specified Tags (Optional, separated by comma)")
+                         with gr.Row():
+                             t2t_btn_run = gr.Button(value="Run")
+                             t2t_btn_clear = gr.Button(value="Clear")
+                     with gr.Column():
+                         t2t_out_img = gr.Image(type="pil")
+                         t2t_out_tag = gr.Textbox(label="Tags")
+                         t2t_out_cap = gr.Textbox(label="Caption")
+                 gr.Examples(
+                     examples=[
+                         ["images/demo4.jpg", ""],
+                         ["images/demo4.jpg", "power line"],
+                         ["images/demo4.jpg", "track, train"],
+                     ],
+                     fn=inference_with_t2t,
+                     inputs=[t2t_in_img, t2t_in_tag],
+                     outputs=[t2t_out_tag, t2t_out_cap, t2t_out_img],
+                     cache_examples=True
+                 )
+
+             gr.HTML(article)
+
+             ###############
+             # events
+             ###############
+             # run inference
+             ram_btn_run.click(
+                 fn=inference_with_ram,
+                 inputs=[ram_in_img],
+                 outputs=[ram_out_tag, ram_out_biaoqian, ram_out_img]
+             )
+             t2t_btn_run.click(
+                 fn=inference_with_t2t,
+                 inputs=[t2t_in_img, t2t_in_tag],
+                 outputs=[t2t_out_tag, t2t_out_cap, t2t_out_img]
+             )
+
+             # clear all
+             def clear_all():
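+                 # reset the 4 image components to None and the 5 textboxes to empty strings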
+                 return [gr.update(value=None)] * 4 + [gr.update(value="")] * 5
+
+             ram_btn_clear.click(fn=clear_all, inputs=[], outputs=[
+                 ram_in_img, ram_out_img, t2t_in_img, t2t_out_img,
+                 ram_out_tag, ram_out_biaoqian, t2t_in_tag, t2t_out_tag, t2t_out_cap
+             ])
+             t2t_btn_clear.click(fn=clear_all, inputs=[], outputs=[
+                 ram_in_img, ram_out_img, t2t_in_img, t2t_out_img,
+                 ram_out_tag, ram_out_biaoqian, t2t_in_tag, t2t_out_tag, t2t_out_cap
+             ])
+
+         return demo
+
+     build_gui().launch(enable_queue=True, share=True)
groundingdino_swint_ogc.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:3b3ca2563c77c69f651d7bd133e97139c186df06231157a64c507099c52bc799
+ size 693997677
images/demo1.jpg ADDED

Git LFS Details

  • SHA256: 1b2906f4058a69936df49cb6156ec4cd117a286b420e1eb14764033bf8f3c05f
  • Pointer size: 132 Bytes
  • Size of remote file: 5.7 MB
images/demo2.jpg ADDED

Git LFS Details

  • SHA256: 5c5159bf7114d08967f95475176670043115b157bf700efa34190260cd917662
  • Pointer size: 132 Bytes
  • Size of remote file: 1.03 MB
images/demo4.jpg ADDED

Git LFS Details

  • SHA256: 5c71251326fb9ece01b5ce6334869861b3fce82eeb5cae45977e78e6332f4170
  • Pointer size: 131 Bytes
  • Size of remote file: 165 kB
ram_swin_large_14m.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:15c729c793af28b9d107c69f85836a1356d76ea830d4714699fb62e55fcc08ed
+ size 5625634877
requirements.txt ADDED
@@ -0,0 +1,27 @@
+ timm==0.4.12
+ transformers==4.15.0
+ fairscale==0.4.4
+ pycocoevalcap
+ torch
+ torchvision
+ Pillow
+ scipy
+ git+https://github.com/openai/CLIP.git
+ git+https://github.com/IDEA-Research/Grounded-Segment-Anything.git
+ git+https://github.com/xinyu1205/recognize-anything.git
+ addict
+ diffusers
+ gradio
+ huggingface_hub
+ matplotlib
+ numpy
+ onnxruntime
+ opencv_python
+ pycocotools
+ PyYAML
+ requests
+ setuptools
+ supervision
+ termcolor
+ yapf
+ nltk
sam_vit_h_4b8939.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:a7bf3b02f3ebf1267aba913ff637d9a2d5c33d3173bb679e46d9f338c26f262e
+ size 2564550879
tag2text_swin_14m.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:4ce96f0ce98f940a6680d567f66a38ccc9ca8c4e638e5f5c5c2e881a0e3502ac
+ size 4478705095