Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -1,307 +1,78 @@
|
|
1 |
-
import
|
2 |
-
import
|
3 |
-
from
|
4 |
|
5 |
-
# setup Grouded-Segment-Anything
|
6 |
-
# building GroundingDINO requires torch but imports it before installing,
|
7 |
-
# so directly installing in requirements.txt causes dependency error.
|
8 |
-
# 1. build with "-e" option to keep the bin file in ./GroundingDINO/groundingdino/, rather than in site-package dir.
|
9 |
-
os.system("pip install -e ./GroundingDINO/")
|
10 |
-
# 2. for unknown reason, "import groundingdino" will fill due to unable to find the module, even after installing.
|
11 |
-
# add ./GroundingDINO/ to PATH, so package "groundingdino" can be imported.
|
12 |
-
sys.path.append(str(Path(__file__).parent / "GroundingDINO"))
|
13 |
-
|
14 |
-
import random # noqa: E402
|
15 |
-
|
16 |
-
import cv2 # noqa: E402
|
17 |
-
import groundingdino.datasets.transforms as T # noqa: E402
|
18 |
-
import numpy as np # noqa: E402
|
19 |
-
import torch # noqa: E402
|
20 |
-
import torchvision # noqa: E402
|
21 |
-
import torchvision.transforms as TS # noqa: E402
|
22 |
-
from groundingdino.models import build_model # noqa: E402
|
23 |
-
from groundingdino.util.slconfig import SLConfig # noqa: E402
|
24 |
-
from groundingdino.util.utils import clean_state_dict, get_phrases_from_posmap # noqa: E402
|
25 |
-
from PIL import Image, ImageDraw, ImageFont # noqa: E402
|
26 |
-
from ram import inference_ram # noqa: E402
|
27 |
-
from ram import inference_tag2text # noqa: E402
|
28 |
-
from ram.models import ram # noqa: E402
|
29 |
-
from ram.models import tag2text_caption # noqa: E402
|
30 |
-
from segment_anything import SamPredictor, build_sam # noqa: E402
|
31 |
-
|
32 |
-
|
33 |
-
# args
|
34 |
-
config_file = "GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py"
|
35 |
ram_checkpoint = "./ram_swin_large_14m.pth"
|
36 |
tag2text_checkpoint = "./tag2text_swin_14m.pth"
|
37 |
-
|
38 |
-
|
39 |
-
box_threshold = 0.25
|
40 |
-
text_threshold = 0.2
|
41 |
-
iou_threshold = 0.5
|
42 |
-
device = "cpu"
|
43 |
-
|
44 |
-
|
45 |
-
def load_model(model_config_path, model_checkpoint_path, device):
|
46 |
-
args = SLConfig.fromfile(model_config_path)
|
47 |
-
args.device = device
|
48 |
-
model = build_model(args)
|
49 |
-
checkpoint = torch.load(model_checkpoint_path, map_location="cpu")
|
50 |
-
load_res = model.load_state_dict(
|
51 |
-
clean_state_dict(checkpoint["model"]), strict=False)
|
52 |
-
print(load_res)
|
53 |
-
_ = model.eval()
|
54 |
-
return model
|
55 |
-
|
56 |
-
|
57 |
-
def get_grounding_output(model, image, caption, box_threshold, text_threshold, device="cpu"):
|
58 |
-
caption = caption.lower()
|
59 |
-
caption = caption.strip()
|
60 |
-
if not caption.endswith("."):
|
61 |
-
caption = caption + "."
|
62 |
-
model = model.to(device)
|
63 |
-
image = image.to(device)
|
64 |
-
with torch.no_grad():
|
65 |
-
outputs = model(image[None], captions=[caption])
|
66 |
-
logits = outputs["pred_logits"].cpu().sigmoid()[0] # (nq, 256)
|
67 |
-
boxes = outputs["pred_boxes"].cpu()[0] # (nq, 4)
|
68 |
-
logits.shape[0]
|
69 |
-
|
70 |
-
# filter output
|
71 |
-
logits_filt = logits.clone()
|
72 |
-
boxes_filt = boxes.clone()
|
73 |
-
filt_mask = logits_filt.max(dim=1)[0] > box_threshold
|
74 |
-
logits_filt = logits_filt[filt_mask] # num_filt, 256
|
75 |
-
boxes_filt = boxes_filt[filt_mask] # num_filt, 4
|
76 |
-
logits_filt.shape[0]
|
77 |
-
|
78 |
-
# get phrase
|
79 |
-
tokenlizer = model.tokenizer
|
80 |
-
tokenized = tokenlizer(caption)
|
81 |
-
# build pred
|
82 |
-
pred_phrases = []
|
83 |
-
scores = []
|
84 |
-
for logit, box in zip(logits_filt, boxes_filt):
|
85 |
-
pred_phrase = get_phrases_from_posmap(
|
86 |
-
logit > text_threshold, tokenized, tokenlizer)
|
87 |
-
pred_phrases.append(pred_phrase + f"({str(logit.max().item())[:4]})")
|
88 |
-
scores.append(logit.max().item())
|
89 |
-
|
90 |
-
return boxes_filt, torch.Tensor(scores), pred_phrases
|
91 |
-
|
92 |
-
|
93 |
-
def draw_mask(mask, draw, random_color=False):
|
94 |
-
if random_color:
|
95 |
-
color = (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255), 153)
|
96 |
-
else:
|
97 |
-
color = (30, 144, 255, 153)
|
98 |
-
|
99 |
-
nonzero_coords = np.transpose(np.nonzero(mask))
|
100 |
-
|
101 |
-
for coord in nonzero_coords:
|
102 |
-
draw.point(coord[::-1], fill=color)
|
103 |
-
|
104 |
-
|
105 |
-
def draw_box(box, draw, label):
|
106 |
-
# random color
|
107 |
-
color = tuple(np.random.randint(0, 255, size=3).tolist())
|
108 |
-
line_width = int(max(4, min(20, 0.006*max(draw.im.size))))
|
109 |
-
draw.rectangle(((box[0], box[1]), (box[2], box[3])), outline=color, width=line_width)
|
110 |
-
|
111 |
-
if label:
|
112 |
-
font_path = os.path.join(
|
113 |
-
cv2.__path__[0], 'qt', 'fonts', 'DejaVuSans.ttf')
|
114 |
-
font_size = int(max(12, min(60, 0.02*max(draw.im.size))))
|
115 |
-
font = ImageFont.truetype(font_path, size=font_size)
|
116 |
-
if hasattr(font, "getbbox"):
|
117 |
-
bbox = draw.textbbox((box[0], box[1]), str(label), font)
|
118 |
-
else:
|
119 |
-
w, h = draw.textsize(str(label), font)
|
120 |
-
bbox = (box[0], box[1], w + box[0], box[1] + h)
|
121 |
-
draw.rectangle(bbox, fill=color)
|
122 |
-
draw.text((box[0], box[1]), str(label), fill="white", font=font)
|
123 |
-
|
124 |
-
draw.text((box[0], box[1]), label, font=font)
|
125 |
|
126 |
|
127 |
@torch.no_grad()
|
128 |
-
def inference(
|
129 |
-
raw_image, specified_tags, do_det_seg,
|
130 |
-
tagging_model_type, tagging_model, grounding_dino_model, sam_model
|
131 |
-
):
|
132 |
print(f"Start processing, image size {raw_image.size}")
|
133 |
-
raw_image = raw_image.convert("RGB")
|
134 |
-
|
135 |
-
# run tagging model
|
136 |
-
normalize = TS.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
|
137 |
-
transform = TS.Compose([
|
138 |
-
TS.Resize((384, 384)),
|
139 |
-
TS.ToTensor(),
|
140 |
-
normalize
|
141 |
-
])
|
142 |
|
143 |
-
image = raw_image.
|
144 |
-
image = transform(image).unsqueeze(0).to(device)
|
145 |
|
146 |
-
# Currently ", " is better for detecting single tags
|
147 |
-
# while ". " is a little worse in some case
|
148 |
if tagging_model_type == "RAM":
|
149 |
res = inference_ram(image, tagging_model)
|
150 |
-
tags = res[0].strip(' ').replace(' ', ' ')
|
151 |
-
tags_chinese = res[1].strip(' ').replace(' ', ' ')
|
152 |
print("Tags: ", tags)
|
153 |
-
print("
|
|
|
154 |
else:
|
155 |
res = inference_tag2text(image, tagging_model, specified_tags)
|
156 |
-
tags = res[0].strip(' ').replace(' ', ' ')
|
157 |
caption = res[2]
|
158 |
print(f"Tags: {tags}")
|
159 |
print(f"Caption: {caption}")
|
|
|
160 |
|
161 |
-
# return
|
162 |
-
if not do_det_seg:
|
163 |
-
if tagging_model_type == "RAM":
|
164 |
-
return tags.replace(", ", " | "), tags_chinese.replace(", ", " | "), None
|
165 |
-
else:
|
166 |
-
return tags.replace(", ", " | "), caption, None
|
167 |
-
|
168 |
-
# run groundingDINO
|
169 |
-
transform = T.Compose([
|
170 |
-
T.RandomResize([800], max_size=1333),
|
171 |
-
T.ToTensor(),
|
172 |
-
T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
|
173 |
-
])
|
174 |
-
|
175 |
-
image, _ = transform(raw_image, None) # 3, h, w
|
176 |
-
|
177 |
-
boxes_filt, scores, pred_phrases = get_grounding_output(
|
178 |
-
grounding_dino_model, image, tags, box_threshold, text_threshold, device=device
|
179 |
-
)
|
180 |
-
print("GroundingDINO finished")
|
181 |
|
182 |
-
|
183 |
-
|
184 |
-
sam_model.set_image(image)
|
185 |
|
186 |
-
size = raw_image.size
|
187 |
-
H, W = size[1], size[0]
|
188 |
-
for i in range(boxes_filt.size(0)):
|
189 |
-
boxes_filt[i] = boxes_filt[i] * torch.Tensor([W, H, W, H])
|
190 |
-
boxes_filt[i][:2] -= boxes_filt[i][2:] / 2
|
191 |
-
boxes_filt[i][2:] += boxes_filt[i][:2]
|
192 |
|
193 |
-
|
194 |
-
|
195 |
-
print(f"Before NMS: {boxes_filt.shape[0]} boxes")
|
196 |
-
nms_idx = torchvision.ops.nms(boxes_filt, scores, iou_threshold).numpy().tolist()
|
197 |
-
boxes_filt = boxes_filt[nms_idx]
|
198 |
-
pred_phrases = [pred_phrases[idx] for idx in nms_idx]
|
199 |
-
print(f"After NMS: {boxes_filt.shape[0]} boxes")
|
200 |
-
|
201 |
-
transformed_boxes = sam_model.transform.apply_boxes_torch(boxes_filt, image.shape[:2]).to(device)
|
202 |
-
|
203 |
-
masks, _, _ = sam_model.predict_torch(
|
204 |
-
point_coords=None,
|
205 |
-
point_labels=None,
|
206 |
-
boxes=transformed_boxes.to(device),
|
207 |
-
multimask_output=False,
|
208 |
-
)
|
209 |
-
print("SAM finished")
|
210 |
-
|
211 |
-
# draw output image
|
212 |
-
mask_image = Image.new('RGBA', size, color=(0, 0, 0, 0))
|
213 |
-
|
214 |
-
mask_draw = ImageDraw.Draw(mask_image)
|
215 |
-
for mask in masks:
|
216 |
-
draw_mask(mask[0].cpu().numpy(), mask_draw, random_color=True)
|
217 |
-
|
218 |
-
image_draw = ImageDraw.Draw(raw_image)
|
219 |
-
|
220 |
-
for box, label in zip(boxes_filt, pred_phrases):
|
221 |
-
draw_box(box, image_draw, label)
|
222 |
-
|
223 |
-
out_image = raw_image.convert('RGBA')
|
224 |
-
out_image.alpha_composite(mask_image)
|
225 |
-
|
226 |
-
# return
|
227 |
-
if tagging_model_type == "RAM":
|
228 |
-
return tags.replace(", ", " | "), tags_chinese.replace(", ", " | "), out_image
|
229 |
-
else:
|
230 |
-
return tags.replace(", ", " | "), caption, out_image
|
231 |
|
232 |
|
233 |
if __name__ == "__main__":
|
234 |
import gradio as gr
|
235 |
|
236 |
-
# load
|
237 |
-
|
238 |
-
ram_model.eval()
|
239 |
-
|
240 |
-
|
241 |
-
# load Tag2Text
|
242 |
-
delete_tag_index = [] # filter out attributes and action categories which are difficult to grounding
|
243 |
-
for i in range(3012, 3429):
|
244 |
-
delete_tag_index.append(i)
|
245 |
-
|
246 |
-
tag2text_model = tag2text_caption(pretrained=tag2text_checkpoint,
|
247 |
-
image_size=384,
|
248 |
-
vit='swin_b',
|
249 |
-
delete_tag_index=delete_tag_index)
|
250 |
-
tag2text_model.threshold = 0.64 # we reduce the threshold to obtain more tags
|
251 |
-
tag2text_model.eval()
|
252 |
-
tag2text_model = tag2text_model.to(device)
|
253 |
-
|
254 |
-
# load groundingDINO
|
255 |
-
grounding_dino_model = load_model(config_file, grounded_checkpoint, device=device)
|
256 |
-
|
257 |
-
# load SAM
|
258 |
-
sam_model = SamPredictor(build_sam(checkpoint=sam_checkpoint).to(device))
|
259 |
|
260 |
# build GUI
|
261 |
def build_gui():
|
262 |
|
263 |
description = """
|
264 |
-
<center><strong><font size='10'>Recognize Anything Model
|
265 |
<br>
|
266 |
-
Welcome to the
|
267 |
<li>
|
268 |
<b>Recognize Anything Model:</b> Upload your image to get the <b>English and Chinese tags</b>!
|
269 |
</li>
|
270 |
<li>
|
271 |
-
<b>Tag2Text Model:</b> Upload your image to get the <b>tags and caption</b>!
|
272 |
-
(Optional: Specify tags to get the corresponding caption.)
|
273 |
-
</li>
|
274 |
-
<li>
|
275 |
-
<b>Grounded-SAM:</b> Tick the checkbox to get <b>boxes</b> and <b>masks</b> of tags!
|
276 |
</li>
|
277 |
-
<
|
278 |
-
Great thanks to <a href='https://huggingface.co/majinyu' target='_blank'>Ma Jinyu</a>, the major contributor of this demo
|
279 |
""" # noqa
|
280 |
|
281 |
article = """
|
282 |
<p style='text-align: center'>
|
283 |
RAM and Tag2Text are trained on open-source datasets, and we are persisting in refining and iterating upon it.<br/>
|
284 |
-
Grounded-SAM is a combination of Grounding DINO and SAM aming to detect and segment anything with text inputs.<br/>
|
285 |
<a href='https://recognize-anything.github.io/' target='_blank'>Recognize Anything: A Strong Image Tagging Model</a>
|
286 |
|
|
287 |
<a href='https://https://tag2text.github.io/' target='_blank'>Tag2Text: Guiding Language-Image Model via Image Tagging</a>
|
288 |
-
|
|
289 |
-
<a href='https://github.com/IDEA-Research/Grounded-Segment-Anything' target='_blank'>Grounded-Segment-Anything</a>
|
290 |
</p>
|
291 |
""" # noqa
|
292 |
|
293 |
-
def inference_with_ram(img, do_det_seg):
|
294 |
-
return inference(
|
295 |
-
img, None, do_det_seg,
|
296 |
-
"RAM", ram_model, grounding_dino_model, sam_model
|
297 |
-
)
|
298 |
-
|
299 |
-
def inference_with_t2t(img, input_tags, do_det_seg):
|
300 |
-
return inference(
|
301 |
-
img, input_tags, do_det_seg,
|
302 |
-
"Tag2Text", tag2text_model, grounding_dino_model, sam_model
|
303 |
-
)
|
304 |
-
|
305 |
with gr.Blocks(title="Recognize Anything Model") as demo:
|
306 |
###############
|
307 |
# components
|
@@ -312,23 +83,24 @@ if __name__ == "__main__":
|
|
312 |
with gr.Row():
|
313 |
with gr.Column():
|
314 |
ram_in_img = gr.Image(type="pil")
|
315 |
-
ram_opt_det_seg = gr.Checkbox(label="Get Boxes and Masks with Grounded-SAM", value=True)
|
316 |
with gr.Row():
|
317 |
ram_btn_run = gr.Button(value="Run")
|
318 |
-
|
|
|
|
|
|
|
319 |
with gr.Column():
|
320 |
-
ram_out_img = gr.Image(type="pil")
|
321 |
ram_out_tag = gr.Textbox(label="Tags")
|
322 |
ram_out_biaoqian = gr.Textbox(label="标签")
|
323 |
gr.Examples(
|
324 |
examples=[
|
325 |
-
["images/demo1.jpg"
|
326 |
-
["images/demo2.jpg"
|
327 |
-
["images/demo4.jpg"
|
328 |
],
|
329 |
fn=inference_with_ram,
|
330 |
-
inputs=[ram_in_img
|
331 |
-
outputs=[ram_out_tag, ram_out_biaoqian
|
332 |
cache_examples=True
|
333 |
)
|
334 |
|
@@ -337,23 +109,24 @@ if __name__ == "__main__":
|
|
337 |
with gr.Column():
|
338 |
t2t_in_img = gr.Image(type="pil")
|
339 |
t2t_in_tag = gr.Textbox(label="User Specified Tags (Optional, separated by comma)")
|
340 |
-
t2t_opt_det_seg = gr.Checkbox(label="Get Boxes and Masks with Grounded-SAM", value=True)
|
341 |
with gr.Row():
|
342 |
t2t_btn_run = gr.Button(value="Run")
|
343 |
-
|
|
|
|
|
|
|
344 |
with gr.Column():
|
345 |
-
t2t_out_img = gr.Image(type="pil")
|
346 |
t2t_out_tag = gr.Textbox(label="Tags")
|
347 |
t2t_out_cap = gr.Textbox(label="Caption")
|
348 |
gr.Examples(
|
349 |
examples=[
|
350 |
-
["images/demo4.jpg", ""
|
351 |
-
["images/demo4.jpg", "power line"
|
352 |
-
["images/demo4.jpg", "track, train"
|
353 |
],
|
354 |
fn=inference_with_t2t,
|
355 |
-
inputs=[t2t_in_img, t2t_in_tag
|
356 |
-
outputs=[t2t_out_tag, t2t_out_cap
|
357 |
cache_examples=True
|
358 |
)
|
359 |
|
@@ -365,22 +138,20 @@ if __name__ == "__main__":
|
|
365 |
# run inference
|
366 |
ram_btn_run.click(
|
367 |
fn=inference_with_ram,
|
368 |
-
inputs=[ram_in_img
|
369 |
-
outputs=[ram_out_tag, ram_out_biaoqian
|
370 |
)
|
371 |
t2t_btn_run.click(
|
372 |
fn=inference_with_t2t,
|
373 |
-
inputs=[t2t_in_img, t2t_in_tag
|
374 |
-
outputs=[t2t_out_tag, t2t_out_cap
|
375 |
)
|
376 |
|
377 |
-
# hide or show image output
|
378 |
-
ram_opt_det_seg.change(fn=lambda b: gr.update(visible=b), inputs=[ram_opt_det_seg], outputs=[ram_out_img])
|
379 |
-
t2t_opt_det_seg.change(fn=lambda b: gr.update(visible=b), inputs=[t2t_opt_det_seg], outputs=[t2t_out_img])
|
380 |
-
|
381 |
# clear
|
382 |
-
ram_btn_clear
|
383 |
-
|
|
|
|
|
384 |
|
385 |
return demo
|
386 |
|
|
|
1 |
+
import torch
|
2 |
+
from ram import get_transform, inference_ram, inference_tag2text
|
3 |
+
from ram.models import ram, tag2text_caption
|
4 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
5 |
ram_checkpoint = "./ram_swin_large_14m.pth"
|
6 |
tag2text_checkpoint = "./tag2text_swin_14m.pth"
|
7 |
+
image_size = 384
|
8 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
9 |
|
10 |
|
11 |
@torch.no_grad()
|
12 |
+
def inference(raw_image, specified_tags, tagging_model_type, tagging_model, transform):
|
|
|
|
|
|
|
13 |
print(f"Start processing, image size {raw_image.size}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
14 |
|
15 |
+
image = transform(raw_image).unsqueeze(0).to(device)
|
|
|
16 |
|
|
|
|
|
17 |
if tagging_model_type == "RAM":
|
18 |
res = inference_ram(image, tagging_model)
|
19 |
+
tags = res[0].strip(' ').replace(' ', ' ')
|
20 |
+
tags_chinese = res[1].strip(' ').replace(' ', ' ')
|
21 |
print("Tags: ", tags)
|
22 |
+
print("标签: ", tags_chinese)
|
23 |
+
return tags, tags_chinese
|
24 |
else:
|
25 |
res = inference_tag2text(image, tagging_model, specified_tags)
|
26 |
+
tags = res[0].strip(' ').replace(' ', ' ')
|
27 |
caption = res[2]
|
28 |
print(f"Tags: {tags}")
|
29 |
print(f"Caption: {caption}")
|
30 |
+
return tags, caption
|
31 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
32 |
|
33 |
+
def inference_with_ram(img):
|
34 |
+
return inference(img, None, "RAM", ram_model, transform)
|
|
|
35 |
|
|
|
|
|
|
|
|
|
|
|
|
|
36 |
|
37 |
+
def inference_with_t2t(img, input_tags):
|
38 |
+
return inference(img, input_tags, "Tag2Text", tag2text_model, transform)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
39 |
|
40 |
|
41 |
if __name__ == "__main__":
|
42 |
import gradio as gr
|
43 |
|
44 |
+
# get transform and load models
|
45 |
+
transform = get_transform(image_size=image_size)
|
46 |
+
ram_model = ram(pretrained=ram_checkpoint, image_size=image_size, vit='swin_l').eval().to(device)
|
47 |
+
tag2text_model = tag2text_caption(
|
48 |
+
pretrained=tag2text_checkpoint, image_size=image_size, vit='swin_b').eval().to(device)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
49 |
|
50 |
# build GUI
|
51 |
def build_gui():
|
52 |
|
53 |
description = """
|
54 |
+
<center><strong><font size='10'>Recognize Anything Model</font></strong></center>
|
55 |
<br>
|
56 |
+
<p>Welcome to the <a href='https://recognize-anything.github.io/' target='_blank'>Recognize Anything Model</a> / <a href='https://tag2text.github.io/Tag2Text' target='_blank'>Tag2Text Model</a> demo!</p>
|
57 |
<li>
|
58 |
<b>Recognize Anything Model:</b> Upload your image to get the <b>English and Chinese tags</b>!
|
59 |
</li>
|
60 |
<li>
|
61 |
+
<b>Tag2Text Model:</b> Upload your image to get the <b>tags and caption</b>! (Optional: Specify tags to get the corresponding caption.)
|
|
|
|
|
|
|
|
|
62 |
</li>
|
63 |
+
<p><b>More over:</b> Combine with <a href='https://github.com/IDEA-Research/Grounded-Segment-Anything' target='_blank'>Grounded-SAM</a>, you can get <b>boxes and masks</b>! Please run <a href='https://github.com/xinyu1205/recognize-anything/blob/main/gui_demo.ipynb' target='_blank'>this notebook</a> to try out!</p>
|
64 |
+
<p>Great thanks to <a href='https://huggingface.co/majinyu' target='_blank'>Ma Jinyu</a>, the major contributor of this demo!</p>
|
65 |
""" # noqa
|
66 |
|
67 |
article = """
|
68 |
<p style='text-align: center'>
|
69 |
RAM and Tag2Text are trained on open-source datasets, and we are persisting in refining and iterating upon it.<br/>
|
|
|
70 |
<a href='https://recognize-anything.github.io/' target='_blank'>Recognize Anything: A Strong Image Tagging Model</a>
|
71 |
|
|
72 |
<a href='https://https://tag2text.github.io/' target='_blank'>Tag2Text: Guiding Language-Image Model via Image Tagging</a>
|
|
|
|
|
73 |
</p>
|
74 |
""" # noqa
|
75 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
76 |
with gr.Blocks(title="Recognize Anything Model") as demo:
|
77 |
###############
|
78 |
# components
|
|
|
83 |
with gr.Row():
|
84 |
with gr.Column():
|
85 |
ram_in_img = gr.Image(type="pil")
|
|
|
86 |
with gr.Row():
|
87 |
ram_btn_run = gr.Button(value="Run")
|
88 |
+
try:
|
89 |
+
ram_btn_clear = gr.ClearButton()
|
90 |
+
except AttributeError: # old gradio does not have ClearButton, not big problem
|
91 |
+
ram_btn_clear = None
|
92 |
with gr.Column():
|
|
|
93 |
ram_out_tag = gr.Textbox(label="Tags")
|
94 |
ram_out_biaoqian = gr.Textbox(label="标签")
|
95 |
gr.Examples(
|
96 |
examples=[
|
97 |
+
["images/demo1.jpg"],
|
98 |
+
["images/demo2.jpg"],
|
99 |
+
["images/demo4.jpg"],
|
100 |
],
|
101 |
fn=inference_with_ram,
|
102 |
+
inputs=[ram_in_img],
|
103 |
+
outputs=[ram_out_tag, ram_out_biaoqian],
|
104 |
cache_examples=True
|
105 |
)
|
106 |
|
|
|
109 |
with gr.Column():
|
110 |
t2t_in_img = gr.Image(type="pil")
|
111 |
t2t_in_tag = gr.Textbox(label="User Specified Tags (Optional, separated by comma)")
|
|
|
112 |
with gr.Row():
|
113 |
t2t_btn_run = gr.Button(value="Run")
|
114 |
+
try:
|
115 |
+
t2t_btn_clear = gr.ClearButton()
|
116 |
+
except AttributeError: # old gradio does not have ClearButton, not big problem
|
117 |
+
t2t_btn_clear = None
|
118 |
with gr.Column():
|
|
|
119 |
t2t_out_tag = gr.Textbox(label="Tags")
|
120 |
t2t_out_cap = gr.Textbox(label="Caption")
|
121 |
gr.Examples(
|
122 |
examples=[
|
123 |
+
["images/demo4.jpg", ""],
|
124 |
+
["images/demo4.jpg", "power line"],
|
125 |
+
["images/demo4.jpg", "track, train"],
|
126 |
],
|
127 |
fn=inference_with_t2t,
|
128 |
+
inputs=[t2t_in_img, t2t_in_tag],
|
129 |
+
outputs=[t2t_out_tag, t2t_out_cap],
|
130 |
cache_examples=True
|
131 |
)
|
132 |
|
|
|
138 |
# run inference
|
139 |
ram_btn_run.click(
|
140 |
fn=inference_with_ram,
|
141 |
+
inputs=[ram_in_img],
|
142 |
+
outputs=[ram_out_tag, ram_out_biaoqian]
|
143 |
)
|
144 |
t2t_btn_run.click(
|
145 |
fn=inference_with_t2t,
|
146 |
+
inputs=[t2t_in_img, t2t_in_tag],
|
147 |
+
outputs=[t2t_out_tag, t2t_out_cap]
|
148 |
)
|
149 |
|
|
|
|
|
|
|
|
|
150 |
# clear
|
151 |
+
if ram_btn_clear is not None:
|
152 |
+
ram_btn_clear.add([ram_in_img, ram_out_tag, ram_out_biaoqian])
|
153 |
+
if t2t_btn_clear is not None:
|
154 |
+
t2t_btn_clear.add([t2t_in_img, t2t_in_tag, t2t_out_tag, t2t_out_cap])
|
155 |
|
156 |
return demo
|
157 |
|