xvjiarui committed on
Commit
249926b
•
0 Parent(s):

add hg spaces app

Files changed (4)
  1. README.md +9 -0
  2. app.py +156 -0
  3. packages.txt +3 -0
  4. requirements.txt +12 -0
README.md ADDED
@@ -0,0 +1,9 @@
+ ---
+ title: GroupViT
+ emoji: 👀
+ colorFrom: indigo
+ colorTo: red
+ sdk: gradio
+ app_file: app.py
+ pinned: true
+ ---
app.py ADDED
@@ -0,0 +1,156 @@
+ # Modified from the implementation of https://huggingface.co/akhaliq
+ import os
+ import sys
+ os.system("git clone https://github.com/NVlabs/GroupViT")
+ sys.path.append('./GroupViT')
+
+ import os.path as osp
+ from collections import namedtuple
+
+ import gradio as gr
+ import mmcv
+ import numpy as np
+ import torch
+ from datasets import build_text_transform
+ from mmcv.cnn.utils import revert_sync_batchnorm
+ from mmcv.image import tensor2imgs
+ from mmcv.parallel import collate, scatter
+ from models import build_model
+ from omegaconf import read_write
+ from segmentation.datasets import (COCOObjectDataset, PascalContextDataset,
+                                    PascalVOCDataset)
+ from segmentation.evaluation import (GROUP_PALETTE, build_seg_demo_pipeline,
+                                      build_seg_inference)
+ from utils import get_config, load_checkpoint
+
+ checkpoint_url = 'https://github.com/xvjiarui/GroupViT-1/releases/download/v1.0.0/group_vit_gcc_yfcc_30e-74d335e6.pth'
+ cfg_path = 'configs/group_vit_gcc_yfcc_30e.yml'
+ output_dir = 'demo/output'
+ device = 'cpu'
+ # vis_modes = ['first_group', 'final_group', 'input_pred_label']
+ vis_modes = ['input_pred_label', 'final_group']
+ output_labels = ['segmentation map', 'groups']
+ dataset_options = ['Pascal VOC', 'Pascal Context', 'COCO']
+ examples = [['Pascal VOC', '', 'demo/examples/voc.jpg'],
+             ['Pascal Context', '', 'demo/examples/ctx.jpg'],
+             ['COCO', 'rock', 'demo/examples/coco.jpg']]
+
+ PSEUDO_ARGS = namedtuple('PSEUDO_ARGS',
+                          ['cfg', 'opts', 'resume', 'vis', 'local_rank'])
+
+ args = PSEUDO_ARGS(
+     cfg=cfg_path, opts=[], resume=checkpoint_url, vis=vis_modes, local_rank=0)
+
+ cfg = get_config(args)
+
+ with read_write(cfg):
+     cfg.evaluate.eval_only = True
+
+ model = build_model(cfg.model)
+ model = revert_sync_batchnorm(model)
+ model.to(device)
+ model.eval()
+
+ load_checkpoint(cfg, model, None, None)
+
+ text_transform = build_text_transform(False, cfg.data.text_aug, with_dc=False)
+ test_pipeline = build_seg_demo_pipeline()
+
+
+ def inference(dataset, additional_classes, input_img):
+     if dataset == 'voc' or dataset == 'Pascal VOC':
+         dataset_class = PascalVOCDataset
+         seg_cfg = 'segmentation/configs/_base_/datasets/pascal_voc12.py'
+     elif dataset == 'coco' or dataset == 'COCO':
+         dataset_class = COCOObjectDataset
+         seg_cfg = 'segmentation/configs/_base_/datasets/coco_object164k.py'
+     elif dataset == 'context' or dataset == 'Pascal Context':
+         dataset_class = PascalContextDataset
+         seg_cfg = 'segmentation/configs/_base_/datasets/pascal_context.py'
+     else:
+         raise ValueError('Unknown dataset: {}'.format(dataset))
+     with read_write(cfg):
+         cfg.evaluate.seg.cfg = seg_cfg
+
+     dataset_cfg = mmcv.Config()
+     dataset_cfg.CLASSES = list(dataset_class.CLASSES)
+     dataset_cfg.PALETTE = dataset_class.PALETTE.copy()
+
+     if len(additional_classes) > 0:
+         additional_classes = additional_classes.split(',')
+         additional_classes = list(
+             set(additional_classes) - set(dataset_cfg.CLASSES))
+         dataset_cfg.CLASSES.extend(additional_classes)
+         dataset_cfg.PALETTE.extend(GROUP_PALETTE[np.random.choice(
+             list(range(len(GROUP_PALETTE))), len(additional_classes))])
+     seg_model = build_seg_inference(model, dataset_cfg, text_transform,
+                                     cfg.evaluate.seg)
+
+     device = next(seg_model.parameters()).device
+     # prepare data
+     data = dict(img=input_img)
+     data = test_pipeline(data)
+     data = collate([data], samples_per_gpu=1)
+     if next(seg_model.parameters()).is_cuda:
+         # scatter to specified GPU
+         data = scatter(data, [device])[0]
+     else:
+         data['img_metas'] = [i.data[0] for i in data['img_metas']]
+     with torch.no_grad():
+         result = seg_model(return_loss=False, rescale=True, **data)
+
+     img_tensor = data['img'][0]
+     img_metas = data['img_metas'][0]
+     imgs = tensor2imgs(img_tensor, **img_metas[0]['img_norm_cfg'])
+     assert len(imgs) == len(img_metas)
+
+     out_file_dict = dict()
+     for img, img_meta in zip(imgs, img_metas):
+         h, w, _ = img_meta['img_shape']
+         img_show = img[:h, :w, :]
+
+         ori_h, ori_w = img_meta['ori_shape'][:-1]
+         img_show = mmcv.imresize(img_show, (ori_w, ori_h))
+
+         for vis_mode in vis_modes:
+             out_file = osp.join(output_dir, 'vis_imgs', vis_mode,
+                                 f'{vis_mode}.jpg')
+             seg_model.show_result(img_show, img_tensor.to(device), result,
+                                   out_file, vis_mode)
+             out_file_dict[vis_mode] = out_file
+
+     return [out_file_dict[mode] for mode in vis_modes]
+
+
+ title = 'GroupViT'
+
+ description = """
+ Gradio Demo for GroupViT: Semantic Segmentation Emerges from Text Supervision. \n
+ You may click one of the examples or upload your own image. \n
+ GroupViT can perform open-vocabulary segmentation: you may input extra classes,
+ e.g. "rock" is not in the COCO dataset, but you can add it for the giraffe image.
+ """
+
+ article = """
+ <p style='text-align: center'>
+ <a href='https://arxiv.org/abs/2202.11094' target='_blank'>
+ GroupViT: Semantic Segmentation Emerges from Text Supervision
+ </a>
+ |
+ <a href='https://github.com/NVlabs/GroupViT' target='_blank'>Github Repo</a></p>
+ """
+
+ gr.Interface(
+     inference,
+     inputs=[
+         gr.inputs.Dropdown(dataset_options, type='value', label='Dataset'),
+         gr.inputs.Textbox(
+             lines=1, placeholder=None, default='', label='More classes'),
+         gr.inputs.Image(type='filepath')
+     ],
+     outputs=[gr.outputs.Image(label=label) for label in output_labels],
+     title=title,
+     description=description,
+     article=article,
+     examples=examples).launch(
+         enable_queue=True, share=True)
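
The demo description above notes that extra classes can be supplied for open-vocabulary segmentation. A quick way to sanity-check that path without the Gradio UI is to call `inference` directly, e.g. by appending a couple of lines just before the `gr.Interface(...).launch(...)` call. This is a minimal sketch, not part of the commit: it reuses the COCO entry from the `examples` list, and the image path may need a `GroupViT/` prefix depending on the working directory.

```python
# Hypothetical smoke test (not in the commit): call inference() directly.
# The return order follows vis_modes = ['input_pred_label', 'final_group'],
# i.e. the rendered segmentation map first, then the group visualization.
seg_map_path, groups_path = inference('COCO', 'rock', 'demo/examples/coco.jpg')
print('segmentation map:', seg_map_path)
print('groups:', groups_path)
```
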
packages.txt ADDED
@@ -0,0 +1,3 @@
+ libsm6
+ libxext6
+ python3-opencv
requirements.txt ADDED
@@ -0,0 +1,12 @@
+ diffdist==0.1
+ einops
+ ftfy==6.0.3
+ mmcv==1.3.14
+ git+https://github.com/xvjiarui/mmsegmentation.git@cpu_only#egg=mmsegmentation
+ nltk==3.6.2
+ omegaconf==2.1.1
+ termcolor==1.1.0
+ timm==0.3.2
+ torch==1.8.0
+ torchvision==0.9.0
+ webdataset==0.1.103