Commit 8bdb03f by majinyu
Parent: b087a41

remove Grounded-SAM part due to limited resource

Files changed (3):
  1. README.md +1 -3
  2. app.py +53 -282
  3. requirements.txt +1 -18
README.md CHANGED
@@ -10,12 +10,10 @@ pinned: false
 license: mit
 ---
 
-# Recognize Anything Model + Grounded-SAM DEMO
+# Recognize Anything Model / Tag2Text Model
 
 Please refer to these homepages for details of these projects:
 
 - Recognize Anthing Model: [https://recognize-anything.github.io/](https://recognize-anything.github.io/)
 
 - Tag2Text Model: [https://https://tag2text.github.io/](https://https://tag2text.github.io/)
-
-- Grounded-SAM: [https://github.com/IDEA-Research/Grounded-Segment-Anything](https://github.com/IDEA-Research/Grounded-Segment-Anything)
app.py CHANGED
@@ -1,307 +1,78 @@
-import os
-import sys
-from pathlib import Path
+import torch
+from ram import get_transform, inference_ram, inference_tag2text
+from ram.models import ram, tag2text_caption
 
-# setup Grouded-Segment-Anything
-# building GroundingDINO requires torch but imports it before installing,
-# so directly installing in requirements.txt causes dependency error.
-# 1. build with "-e" option to keep the bin file in ./GroundingDINO/groundingdino/, rather than in site-package dir.
-os.system("pip install -e ./GroundingDINO/")
-# 2. for unknown reason, "import groundingdino" will fill due to unable to find the module, even after installing.
-# add ./GroundingDINO/ to PATH, so package "groundingdino" can be imported.
-sys.path.append(str(Path(__file__).parent / "GroundingDINO"))
-
-import random  # noqa: E402
-
-import cv2  # noqa: E402
-import groundingdino.datasets.transforms as T  # noqa: E402
-import numpy as np  # noqa: E402
-import torch  # noqa: E402
-import torchvision  # noqa: E402
-import torchvision.transforms as TS  # noqa: E402
-from groundingdino.models import build_model  # noqa: E402
-from groundingdino.util.slconfig import SLConfig  # noqa: E402
-from groundingdino.util.utils import clean_state_dict, get_phrases_from_posmap  # noqa: E402
-from PIL import Image, ImageDraw, ImageFont  # noqa: E402
-from ram import inference_ram  # noqa: E402
-from ram import inference_tag2text  # noqa: E402
-from ram.models import ram  # noqa: E402
-from ram.models import tag2text_caption  # noqa: E402
-from segment_anything import SamPredictor, build_sam  # noqa: E402
-
-
-# args
-config_file = "GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py"
 ram_checkpoint = "./ram_swin_large_14m.pth"
 tag2text_checkpoint = "./tag2text_swin_14m.pth"
-grounded_checkpoint = "./groundingdino_swint_ogc.pth"
-sam_checkpoint = "./sam_vit_h_4b8939.pth"
-box_threshold = 0.25
-text_threshold = 0.2
-iou_threshold = 0.5
-device = "cpu"
-
-
-def load_model(model_config_path, model_checkpoint_path, device):
-    args = SLConfig.fromfile(model_config_path)
-    args.device = device
-    model = build_model(args)
-    checkpoint = torch.load(model_checkpoint_path, map_location="cpu")
-    load_res = model.load_state_dict(
-        clean_state_dict(checkpoint["model"]), strict=False)
-    print(load_res)
-    _ = model.eval()
-    return model
-
-
-def get_grounding_output(model, image, caption, box_threshold, text_threshold, device="cpu"):
-    caption = caption.lower()
-    caption = caption.strip()
-    if not caption.endswith("."):
-        caption = caption + "."
-    model = model.to(device)
-    image = image.to(device)
-    with torch.no_grad():
-        outputs = model(image[None], captions=[caption])
-    logits = outputs["pred_logits"].cpu().sigmoid()[0]  # (nq, 256)
-    boxes = outputs["pred_boxes"].cpu()[0]  # (nq, 4)
-    logits.shape[0]
-
-    # filter output
-    logits_filt = logits.clone()
-    boxes_filt = boxes.clone()
-    filt_mask = logits_filt.max(dim=1)[0] > box_threshold
-    logits_filt = logits_filt[filt_mask]  # num_filt, 256
-    boxes_filt = boxes_filt[filt_mask]  # num_filt, 4
-    logits_filt.shape[0]
-
-    # get phrase
-    tokenlizer = model.tokenizer
-    tokenized = tokenlizer(caption)
-    # build pred
-    pred_phrases = []
-    scores = []
-    for logit, box in zip(logits_filt, boxes_filt):
-        pred_phrase = get_phrases_from_posmap(
-            logit > text_threshold, tokenized, tokenlizer)
-        pred_phrases.append(pred_phrase + f"({str(logit.max().item())[:4]})")
-        scores.append(logit.max().item())
-
-    return boxes_filt, torch.Tensor(scores), pred_phrases
-
-
-def draw_mask(mask, draw, random_color=False):
-    if random_color:
-        color = (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255), 153)
-    else:
-        color = (30, 144, 255, 153)
-
-    nonzero_coords = np.transpose(np.nonzero(mask))
-
-    for coord in nonzero_coords:
-        draw.point(coord[::-1], fill=color)
-
-
-def draw_box(box, draw, label):
-    # random color
-    color = tuple(np.random.randint(0, 255, size=3).tolist())
-    line_width = int(max(4, min(20, 0.006*max(draw.im.size))))
-    draw.rectangle(((box[0], box[1]), (box[2], box[3])), outline=color, width=line_width)
-
-    if label:
-        font_path = os.path.join(
-            cv2.__path__[0], 'qt', 'fonts', 'DejaVuSans.ttf')
-        font_size = int(max(12, min(60, 0.02*max(draw.im.size))))
-        font = ImageFont.truetype(font_path, size=font_size)
-        if hasattr(font, "getbbox"):
-            bbox = draw.textbbox((box[0], box[1]), str(label), font)
-        else:
-            w, h = draw.textsize(str(label), font)
-            bbox = (box[0], box[1], w + box[0], box[1] + h)
-        draw.rectangle(bbox, fill=color)
-        draw.text((box[0], box[1]), str(label), fill="white", font=font)
-
-    draw.text((box[0], box[1]), label, font=font)
+image_size = 384
+device = "cuda" if torch.cuda.is_available() else "cpu"
 
 
 @torch.no_grad()
-def inference(
-    raw_image, specified_tags, do_det_seg,
-    tagging_model_type, tagging_model, grounding_dino_model, sam_model
-):
+def inference(raw_image, specified_tags, tagging_model_type, tagging_model, transform):
     print(f"Start processing, image size {raw_image.size}")
-    raw_image = raw_image.convert("RGB")
-
-    # run tagging model
-    normalize = TS.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
-    transform = TS.Compose([
-        TS.Resize((384, 384)),
-        TS.ToTensor(),
-        normalize
-    ])
 
-    image = raw_image.resize((384, 384))
-    image = transform(image).unsqueeze(0).to(device)
+    image = transform(raw_image).unsqueeze(0).to(device)
 
-    # Currently ", " is better for detecting single tags
-    # while ". " is a little worse in some case
     if tagging_model_type == "RAM":
         res = inference_ram(image, tagging_model)
-        tags = res[0].strip(' ').replace('  ', ' ').replace(' |', ',')
-        tags_chinese = res[1].strip(' ').replace('  ', ' ').replace(' |', ',')
+        tags = res[0].strip(' ').replace('  ', ' ')
+        tags_chinese = res[1].strip(' ').replace('  ', ' ')
        print("Tags: ", tags)
-        print("图像标签: ", tags_chinese)
+        print("标签: ", tags_chinese)
+        return tags, tags_chinese
     else:
         res = inference_tag2text(image, tagging_model, specified_tags)
-        tags = res[0].strip(' ').replace('  ', ' ').replace(' |', ',')
+        tags = res[0].strip(' ').replace('  ', ' ')
         caption = res[2]
         print(f"Tags: {tags}")
         print(f"Caption: {caption}")
+        return tags, caption
 
-    # return
-    if not do_det_seg:
-        if tagging_model_type == "RAM":
-            return tags.replace(", ", " | "), tags_chinese.replace(", ", " | "), None
-        else:
-            return tags.replace(", ", " | "), caption, None
-
-    # run groundingDINO
-    transform = T.Compose([
-        T.RandomResize([800], max_size=1333),
-        T.ToTensor(),
-        T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
-    ])
-
-    image, _ = transform(raw_image, None)  # 3, h, w
-
-    boxes_filt, scores, pred_phrases = get_grounding_output(
-        grounding_dino_model, image, tags, box_threshold, text_threshold, device=device
-    )
-    print("GroundingDINO finished")
-
-    # run SAM
-    image = np.asarray(raw_image)
-    sam_model.set_image(image)
-
-    size = raw_image.size
-    H, W = size[1], size[0]
-    for i in range(boxes_filt.size(0)):
-        boxes_filt[i] = boxes_filt[i] * torch.Tensor([W, H, W, H])
-        boxes_filt[i][:2] -= boxes_filt[i][2:] / 2
-        boxes_filt[i][2:] += boxes_filt[i][:2]
-
-    boxes_filt = boxes_filt.cpu()
-    # use NMS to handle overlapped boxes
-    print(f"Before NMS: {boxes_filt.shape[0]} boxes")
-    nms_idx = torchvision.ops.nms(boxes_filt, scores, iou_threshold).numpy().tolist()
-    boxes_filt = boxes_filt[nms_idx]
-    pred_phrases = [pred_phrases[idx] for idx in nms_idx]
-    print(f"After NMS: {boxes_filt.shape[0]} boxes")
-
-    transformed_boxes = sam_model.transform.apply_boxes_torch(boxes_filt, image.shape[:2]).to(device)
-
-    masks, _, _ = sam_model.predict_torch(
-        point_coords=None,
-        point_labels=None,
-        boxes=transformed_boxes.to(device),
-        multimask_output=False,
-    )
-    print("SAM finished")
-
-    # draw output image
-    mask_image = Image.new('RGBA', size, color=(0, 0, 0, 0))
-
-    mask_draw = ImageDraw.Draw(mask_image)
-    for mask in masks:
-        draw_mask(mask[0].cpu().numpy(), mask_draw, random_color=True)
-
-    image_draw = ImageDraw.Draw(raw_image)
-
-    for box, label in zip(boxes_filt, pred_phrases):
-        draw_box(box, image_draw, label)
-
-    out_image = raw_image.convert('RGBA')
-    out_image.alpha_composite(mask_image)
-
-    # return
-    if tagging_model_type == "RAM":
-        return tags.replace(", ", " | "), tags_chinese.replace(", ", " | "), out_image
-    else:
-        return tags.replace(", ", " | "), caption, out_image
+
+def inference_with_ram(img):
+    return inference(img, None, "RAM", ram_model, transform)
+
+
+def inference_with_t2t(img, input_tags):
+    return inference(img, input_tags, "Tag2Text", tag2text_model, transform)
 
 
 if __name__ == "__main__":
     import gradio as gr
 
-    # load RAM
-    ram_model = ram(pretrained=ram_checkpoint, image_size=384, vit='swin_l')
-    ram_model.eval()
-    ram_model = ram_model.to(device)
-
-    # load Tag2Text
-    delete_tag_index = []  # filter out attributes and action categories which are difficult to grounding
-    for i in range(3012, 3429):
-        delete_tag_index.append(i)
-
-    tag2text_model = tag2text_caption(pretrained=tag2text_checkpoint,
-                                      image_size=384,
-                                      vit='swin_b',
-                                      delete_tag_index=delete_tag_index)
-    tag2text_model.threshold = 0.64  # we reduce the threshold to obtain more tags
-    tag2text_model.eval()
-    tag2text_model = tag2text_model.to(device)
-
-    # load groundingDINO
-    grounding_dino_model = load_model(config_file, grounded_checkpoint, device=device)
-
-    # load SAM
-    sam_model = SamPredictor(build_sam(checkpoint=sam_checkpoint).to(device))
+    # get transform and load models
+    transform = get_transform(image_size=image_size)
+    ram_model = ram(pretrained=ram_checkpoint, image_size=image_size, vit='swin_l').eval().to(device)
+    tag2text_model = tag2text_caption(
+        pretrained=tag2text_checkpoint, image_size=image_size, vit='swin_b').eval().to(device)
 
     # build GUI
     def build_gui():

        description = """
-        <center><strong><font size='10'>Recognize Anything Model + Grounded-SAM</font></strong></center>
+        <center><strong><font size='10'>Recognize Anything Model</font></strong></center>
        <br>
-        Welcome to the RAM/Tag2Text + Grounded-SAM demo! <br><br>
+        <p>Welcome to the <a href='https://recognize-anything.github.io/' target='_blank'>Recognize Anything Model</a> / <a href='https://tag2text.github.io/Tag2Text' target='_blank'>Tag2Text Model</a> demo!</p>
        <li>
        <b>Recognize Anything Model:</b> Upload your image to get the <b>English and Chinese tags</b>!
        </li>
        <li>
-        <b>Tag2Text Model:</b> Upload your image to get the <b>tags and caption</b>!
-        (Optional: Specify tags to get the corresponding caption.)
-        </li>
-        <li>
-        <b>Grounded-SAM:</b> Tick the checkbox to get <b>boxes</b> and <b>masks</b> of tags!
+        <b>Tag2Text Model:</b> Upload your image to get the <b>tags and caption</b>! (Optional: Specify tags to get the corresponding caption.)
        </li>
-        <br>
-        Great thanks to <a href='https://huggingface.co/majinyu' target='_blank'>Ma Jinyu</a>, the major contributor of this demo!
+        <p><b>More over:</b> Combine with <a href='https://github.com/IDEA-Research/Grounded-Segment-Anything' target='_blank'>Grounded-SAM</a>, you can get <b>boxes and masks</b>! Please run <a href='https://github.com/xinyu1205/recognize-anything/blob/main/gui_demo.ipynb' target='_blank'>this notebook</a> to try out!</p>
+        <p>Great thanks to <a href='https://huggingface.co/majinyu' target='_blank'>Ma Jinyu</a>, the major contributor of this demo!</p>
        """  # noqa

        article = """
        <p style='text-align: center'>
            RAM and Tag2Text are trained on open-source datasets, and we are persisting in refining and iterating upon it.<br/>
-            Grounded-SAM is a combination of Grounding DINO and SAM aming to detect and segment anything with text inputs.<br/>
            <a href='https://recognize-anything.github.io/' target='_blank'>Recognize Anything: A Strong Image Tagging Model</a>
            |
            <a href='https://https://tag2text.github.io/' target='_blank'>Tag2Text: Guiding Language-Image Model via Image Tagging</a>
-            |
-            <a href='https://github.com/IDEA-Research/Grounded-Segment-Anything' target='_blank'>Grounded-Segment-Anything</a>
        </p>
        """  # noqa

-        def inference_with_ram(img, do_det_seg):
-            return inference(
-                img, None, do_det_seg,
-                "RAM", ram_model, grounding_dino_model, sam_model
-            )
-
-        def inference_with_t2t(img, input_tags, do_det_seg):
-            return inference(
-                img, input_tags, do_det_seg,
-                "Tag2Text", tag2text_model, grounding_dino_model, sam_model
-            )
-
        with gr.Blocks(title="Recognize Anything Model") as demo:
            ###############
            # components
@@ -312,23 +83,24 @@ if __name__ == "__main__":
            with gr.Row():
                with gr.Column():
                    ram_in_img = gr.Image(type="pil")
-                    ram_opt_det_seg = gr.Checkbox(label="Get Boxes and Masks with Grounded-SAM", value=True)
                    with gr.Row():
                        ram_btn_run = gr.Button(value="Run")
-                        ram_btn_clear = gr.ClearButton()
+                        try:
+                            ram_btn_clear = gr.ClearButton()
+                        except AttributeError:  # old gradio does not have ClearButton, not big problem
+                            ram_btn_clear = None
                with gr.Column():
-                    ram_out_img = gr.Image(type="pil")
                    ram_out_tag = gr.Textbox(label="Tags")
                    ram_out_biaoqian = gr.Textbox(label="标签")
            gr.Examples(
                examples=[
-                    ["images/demo1.jpg", True],
-                    ["images/demo2.jpg", True],
-                    ["images/demo4.jpg", True],
+                    ["images/demo1.jpg"],
+                    ["images/demo2.jpg"],
+                    ["images/demo4.jpg"],
                ],
                fn=inference_with_ram,
-                inputs=[ram_in_img, ram_opt_det_seg],
-                outputs=[ram_out_tag, ram_out_biaoqian, ram_out_img],
+                inputs=[ram_in_img],
+                outputs=[ram_out_tag, ram_out_biaoqian],
                cache_examples=True
            )

@@ -337,23 +109,24 @@
                with gr.Column():
                    t2t_in_img = gr.Image(type="pil")
                    t2t_in_tag = gr.Textbox(label="User Specified Tags (Optional, separated by comma)")
-                    t2t_opt_det_seg = gr.Checkbox(label="Get Boxes and Masks with Grounded-SAM", value=True)
                    with gr.Row():
                        t2t_btn_run = gr.Button(value="Run")
-                        t2t_btn_clear = gr.ClearButton()
+                        try:
+                            t2t_btn_clear = gr.ClearButton()
+                        except AttributeError:  # old gradio does not have ClearButton, not big problem
+                            t2t_btn_clear = None
                with gr.Column():
-                    t2t_out_img = gr.Image(type="pil")
                    t2t_out_tag = gr.Textbox(label="Tags")
                    t2t_out_cap = gr.Textbox(label="Caption")
            gr.Examples(
                examples=[
-                    ["images/demo4.jpg", "", True],
-                    ["images/demo4.jpg", "power line", False],
-                    ["images/demo4.jpg", "track, train", False],
+                    ["images/demo4.jpg", ""],
+                    ["images/demo4.jpg", "power line"],
+                    ["images/demo4.jpg", "track, train"],
                ],
                fn=inference_with_t2t,
-                inputs=[t2t_in_img, t2t_in_tag, t2t_opt_det_seg],
-                outputs=[t2t_out_tag, t2t_out_cap, t2t_out_img],
+                inputs=[t2t_in_img, t2t_in_tag],
+                outputs=[t2t_out_tag, t2t_out_cap],
                cache_examples=True
            )

@@ -365,22 +138,20 @@ if __name__ == "__main__":
            # run inference
            ram_btn_run.click(
                fn=inference_with_ram,
-                inputs=[ram_in_img, ram_opt_det_seg],
-                outputs=[ram_out_tag, ram_out_biaoqian, ram_out_img]
+                inputs=[ram_in_img],
+                outputs=[ram_out_tag, ram_out_biaoqian]
            )
            t2t_btn_run.click(
                fn=inference_with_t2t,
-                inputs=[t2t_in_img, t2t_in_tag, t2t_opt_det_seg],
-                outputs=[t2t_out_tag, t2t_out_cap, t2t_out_img]
+                inputs=[t2t_in_img, t2t_in_tag],
+                outputs=[t2t_out_tag, t2t_out_cap]
            )

-            # hide or show image output
-            ram_opt_det_seg.change(fn=lambda b: gr.update(visible=b), inputs=[ram_opt_det_seg], outputs=[ram_out_img])
-            t2t_opt_det_seg.change(fn=lambda b: gr.update(visible=b), inputs=[t2t_opt_det_seg], outputs=[t2t_out_img])
-
            # clear
-            ram_btn_clear.add([ram_in_img, ram_out_img, ram_out_tag, ram_out_biaoqian])
-            t2t_btn_clear.add([t2t_in_img, t2t_in_tag, t2t_out_img, t2t_out_tag, t2t_out_cap])
+            if ram_btn_clear is not None:
+                ram_btn_clear.add([ram_in_img, ram_out_tag, ram_out_biaoqian])
+            if t2t_btn_clear is not None:
+                t2t_btn_clear.add([t2t_in_img, t2t_in_tag, t2t_out_tag, t2t_out_cap])

        return demo
 
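For reference, a minimal sketch of how the simplified tagging path in the new app.py can be exercised outside Gradio. It assumes the `ram` package from requirements.txt is installed, that the two checkpoints named in app.py (`ram_swin_large_14m.pth`, `tag2text_swin_14m.pth`) sit in the working directory, and that a demo image such as `images/demo1.jpg` is available; this code is not part of the commit itself.

```python
# Sketch only (not in the commit): run the new tagging-only paths without the
# Gradio UI, mirroring the calls made in the rewritten app.py.
import torch
from PIL import Image
from ram import get_transform, inference_ram, inference_tag2text
from ram.models import ram, tag2text_caption

device = "cuda" if torch.cuda.is_available() else "cpu"
transform = get_transform(image_size=384)

# RAM: English and Chinese tags (assumed checkpoint path, as in app.py)
ram_model = ram(pretrained="./ram_swin_large_14m.pth", image_size=384, vit='swin_l').eval().to(device)
raw_image = Image.open("images/demo1.jpg").convert("RGB")  # same PIL input gr.Image provides
image = transform(raw_image).unsqueeze(0).to(device)
with torch.no_grad():
    res = inference_ram(image, ram_model)
print("Tags:", res[0])
print("标签:", res[1])

# Tag2Text: tags plus a caption, optionally guided by user-specified tags
t2t_model = tag2text_caption(pretrained="./tag2text_swin_14m.pth", image_size=384, vit='swin_b').eval().to(device)
with torch.no_grad():
    res = inference_tag2text(image, t2t_model, "power line")
print("Tags:", res[0])
print("Caption:", res[2])
```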
requirements.txt CHANGED
@@ -1,5 +1,5 @@
 timm==0.4.12
-transformers
+transformers==4.15.0
 fairscale==0.4.4
 pycocoevalcap
 torch
@@ -7,21 +7,4 @@ torchvision
 Pillow
 scipy
 git+https://github.com/openai/CLIP.git
-git+https://github.com/IDEA-Research/Grounded-Segment-Anything.git/#subdirectory=segment_anything
 git+https://github.com/xinyu1205/recognize-anything.git
-addict
-diffusers[torch]
-gradio
-huggingface_hub
-matplotlib
-numpy
-onnxruntime
-opencv_python
-pycocotools
-PyYAML
-requests
-setuptools
-supervision
-termcolor
-yapf
-nltk