xvjiarui committed • Commit 249926b • 0 parent(s)

add hg spaces app
Browse files:
- README.md +9 -0
- app.py +156 -0
- packages.txt +3 -0
- requirements.txt +12 -0
README.md
ADDED
@@ -0,0 +1,9 @@
---
title: GroupViT
emoji: π
colorFrom: indigo
colorTo: red
sdk: gradio
app_file: app.py
pinned: true
---
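For reference, Hugging Face Spaces parses this YAML front matter to configure the Space: `sdk: gradio` selects the runtime and `app_file: app.py` names the entry point. A minimal sketch of reading the same block locally (an illustration, not part of this commit; it assumes PyYAML is installed, which is not among the pinned requirements):

# Hypothetical helper, not part of this commit: read the Space metadata
# the same way a tool might. Assumes PyYAML is available.
import yaml

with open('README.md') as f:
    front_matter = f.read().split('---')[1]  # text between the two '---' delimiters
meta = yaml.safe_load(front_matter)
print(meta['sdk'], meta['app_file'])  # -> gradio app.py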
app.py
ADDED
@@ -0,0 +1,156 @@
# Modified from the implementation of https://huggingface.co/akhaliq
import os
import sys
os.system("git clone https://github.com/NVlabs/GroupViT")
sys.path.append('./GroupViT')

import os.path as osp
from collections import namedtuple

import gradio as gr
import mmcv
import numpy as np
import torch
from datasets import build_text_transform
from mmcv.cnn.utils import revert_sync_batchnorm
from mmcv.image import tensor2imgs
from mmcv.parallel import collate, scatter
from models import build_model
from omegaconf import read_write
from segmentation.datasets import (COCOObjectDataset, PascalContextDataset,
                                   PascalVOCDataset)
from segmentation.evaluation import (GROUP_PALETTE, build_seg_demo_pipeline,
                                     build_seg_inference)
from utils import get_config, load_checkpoint

checkpoint_url = 'https://github.com/xvjiarui/GroupViT-1/releases/download/v1.0.0/group_vit_gcc_yfcc_30e-74d335e6.pth'
cfg_path = 'configs/group_vit_gcc_yfcc_30e.yml'
output_dir = 'demo/output'
device = 'cpu'
# vis_modes = ['first_group', 'final_group', 'input_pred_label']
vis_modes = ['input_pred_label', 'final_group']
output_labels = ['segmentation map', 'groups']
dataset_options = ['Pascal VOC', 'Pascal Context', 'COCO']
examples = [['Pascal VOC', '', 'demo/examples/voc.jpg'],
            ['Pascal Context', '', 'demo/examples/ctx.jpg'],
            ['COCO', 'rock', 'demo/examples/coco.jpg']]

PSEUDO_ARGS = namedtuple('PSEUDO_ARGS',
                         ['cfg', 'opts', 'resume', 'vis', 'local_rank'])

args = PSEUDO_ARGS(
    cfg=cfg_path, opts=[], resume=checkpoint_url, vis=vis_modes, local_rank=0)

cfg = get_config(args)

with read_write(cfg):
    cfg.evaluate.eval_only = True

# Build the model on CPU and load the released checkpoint.
model = build_model(cfg.model)
model = revert_sync_batchnorm(model)
model.to(device)
model.eval()

load_checkpoint(cfg, model, None, None)

text_transform = build_text_transform(False, cfg.data.text_aug, with_dc=False)
test_pipeline = build_seg_demo_pipeline()


def inference(dataset, additional_classes, input_img):
    # Map the dropdown choice to the matching dataset class and seg config.
    if dataset in ('voc', 'Pascal VOC'):
        dataset_class = PascalVOCDataset
        seg_cfg = 'segmentation/configs/_base_/datasets/pascal_voc12.py'
    elif dataset in ('coco', 'COCO'):
        dataset_class = COCOObjectDataset
        seg_cfg = 'segmentation/configs/_base_/datasets/coco_object164k.py'
    elif dataset in ('context', 'Pascal Context'):
        dataset_class = PascalContextDataset
        seg_cfg = 'segmentation/configs/_base_/datasets/pascal_context.py'
    else:
        raise ValueError('Unknown dataset: {}'.format(dataset))
    with read_write(cfg):
        cfg.evaluate.seg.cfg = seg_cfg

    dataset_cfg = mmcv.Config()
    dataset_cfg.CLASSES = list(dataset_class.CLASSES)
    dataset_cfg.PALETTE = dataset_class.PALETTE.copy()

    # Add user-supplied open-vocabulary classes, drawing palette colors at random.
    if len(additional_classes) > 0:
        additional_classes = additional_classes.split(',')
        additional_classes = list(
            set(additional_classes) - set(dataset_cfg.CLASSES))
        dataset_cfg.CLASSES.extend(additional_classes)
        dataset_cfg.PALETTE.extend(GROUP_PALETTE[np.random.choice(
            list(range(len(GROUP_PALETTE))), len(additional_classes))])
    seg_model = build_seg_inference(model, dataset_cfg, text_transform,
                                    cfg.evaluate.seg)

    device = next(seg_model.parameters()).device
    # prepare data
    data = dict(img=input_img)
    data = test_pipeline(data)
    data = collate([data], samples_per_gpu=1)
    if next(seg_model.parameters()).is_cuda:
        # scatter to specified GPU
        data = scatter(data, [device])[0]
    else:
        data['img_metas'] = [i.data[0] for i in data['img_metas']]
    with torch.no_grad():
        result = seg_model(return_loss=False, rescale=True, **data)

    img_tensor = data['img'][0]
    img_metas = data['img_metas'][0]
    imgs = tensor2imgs(img_tensor, **img_metas[0]['img_norm_cfg'])
    assert len(imgs) == len(img_metas)

    out_file_dict = dict()
    for img, img_meta in zip(imgs, img_metas):
        h, w, _ = img_meta['img_shape']
        img_show = img[:h, :w, :]

        ori_h, ori_w = img_meta['ori_shape'][:-1]
        img_show = mmcv.imresize(img_show, (ori_w, ori_h))

        # Render each requested visualization mode to disk.
        for vis_mode in vis_modes:
            out_file = osp.join(output_dir, 'vis_imgs', vis_mode,
                                f'{vis_mode}.jpg')
            seg_model.show_result(img_show, img_tensor.to(device), result,
                                  out_file, vis_mode)
            out_file_dict[vis_mode] = out_file

    return [out_file_dict[mode] for mode in vis_modes]


title = 'GroupViT'

description = """
Gradio demo for GroupViT: Semantic Segmentation Emerges from Text Supervision. \n
You may click on one of the examples or upload your own image. \n
GroupViT can perform open-vocabulary segmentation: you may input additional classes,
e.g. "rock" is not in the COCO dataset, but you can input it for the giraffe image.
"""

article = """
<p style='text-align: center'>
<a href='https://arxiv.org/abs/2202.11094' target='_blank'>
GroupViT: Semantic Segmentation Emerges from Text Supervision
</a>
|
<a href='https://github.com/NVlabs/GroupViT' target='_blank'>Github Repo</a></p>
"""

gr.Interface(
    inference,
    inputs=[
        gr.inputs.Dropdown(dataset_options, type='value', label='Dataset'),
        gr.inputs.Textbox(
            lines=1, placeholder=None, default='', label='More classes'),
        gr.inputs.Image(type='filepath')
    ],
    outputs=[gr.outputs.Image(label=label) for label in output_labels],
    title=title,
    description=description,
    article=article,
    examples=examples).launch(
        enable_queue=True, share=True)
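Because `inference` is an ordinary function, it can also be smoke-tested without the Gradio UI once the script above has run. A minimal sketch, assuming the cloned repo's `demo/examples/coco.jpg` is present (it mirrors the last entry of the `examples` list):

# Hypothetical direct call, not part of this commit: 'rock' is the extra
# open-vocabulary class from the demo description; the results are file paths.
seg_map_path, groups_path = inference('COCO', 'rock', 'demo/examples/coco.jpg')
print(seg_map_path)  # demo/output/vis_imgs/input_pred_label/input_pred_label.jpg
print(groups_path)   # demo/output/vis_imgs/final_group/final_group.jpg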
packages.txt
ADDED
@@ -0,0 +1,3 @@
libsm6
libxext6
python3-opencv
requirements.txt
ADDED
@@ -0,0 +1,12 @@
diffdist==0.1
einops
ftfy==6.0.3
mmcv==1.3.14
git+https://github.com/xvjiarui/mmsegmentation.git@cpu_only#egg=mmsegmentation
nltk==3.6.2
omegaconf==2.1.1
termcolor==1.1.0
timm==0.3.2
torch==1.8.0
torchvision==0.9.0
webdataset==0.1.103
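Some of these pins interact: torchvision 0.9.0 is the release paired with torch 1.8.0, and the CPU-only mmsegmentation fork is installed alongside the pinned mmcv 1.3.14. A quick version check can surface a broken environment before the model build; a hypothetical sketch, not part of this commit:

# Hypothetical startup check: verify two of the pins from requirements.txt.
import mmcv
import torch

assert mmcv.__version__ == '1.3.14', f'unexpected mmcv {mmcv.__version__}'
assert torch.__version__.startswith('1.8.0'), f'unexpected torch {torch.__version__}'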