akhaliq3 committed on
Commit
4a7bfa8
1 Parent(s): 506da10
Files changed (1)
  1. app.py +224 -0
app.py ADDED
@@ -0,0 +1,224 @@
import collections
import os
import sys
import tempfile
import urllib.request

from matplotlib import gridspec
from matplotlib import pyplot as plt
import numpy as np
from PIL import Image
import requests
import tensorflow as tf
import gradio as gr
from subprocess import call

# Download two example street-scene images for the demo's example gallery.
url1 = 'https://cdn.pixabay.com/photo/2014/09/07/21/52/city-438393_1280.jpg'
r = requests.get(url1, allow_redirects=True)
with open("city1.jpg", 'wb') as f:
  f.write(r.content)
url2 = 'https://cdn.pixabay.com/photo/2016/02/19/11/36/canal-1209808_1280.jpg'
r = requests.get(url2, allow_redirects=True)
with open("city2.jpg", 'wb') as f:
  f.write(r.content)
DatasetInfo = collections.namedtuple(
    'DatasetInfo',
    'num_classes, label_divisor, thing_list, colormap, class_names')


def _cityscapes_label_colormap():
  """Creates the label colormap used in the Cityscapes segmentation benchmark.

  See more about the Cityscapes dataset at https://www.cityscapes-dataset.com/
  M. Cordts, et al. "The Cityscapes Dataset for Semantic Urban Scene
  Understanding." CVPR. 2016.

  Returns:
    A 2-D numpy array with each row being a mapped RGB color (in uint8 range).
  """
  colormap = np.zeros((256, 3), dtype=np.uint8)
  colormap[0] = [128, 64, 128]    # road
  colormap[1] = [244, 35, 232]    # sidewalk
  colormap[2] = [70, 70, 70]      # building
  colormap[3] = [102, 102, 156]   # wall
  colormap[4] = [190, 153, 153]   # fence
  colormap[5] = [153, 153, 153]   # pole
  colormap[6] = [250, 170, 30]    # traffic light
  colormap[7] = [220, 220, 0]     # traffic sign
  colormap[8] = [107, 142, 35]    # vegetation
  colormap[9] = [152, 251, 152]   # terrain
  colormap[10] = [70, 130, 180]   # sky
  colormap[11] = [220, 20, 60]    # person
  colormap[12] = [255, 0, 0]      # rider
  colormap[13] = [0, 0, 142]      # car
  colormap[14] = [0, 0, 70]       # truck
  colormap[15] = [0, 60, 100]     # bus
  colormap[16] = [0, 80, 100]     # train
  colormap[17] = [0, 0, 230]      # motorcycle
  colormap[18] = [119, 11, 32]    # bicycle
  return colormap


def _cityscapes_class_names():
  return ('road', 'sidewalk', 'building', 'wall', 'fence', 'pole',
          'traffic light', 'traffic sign', 'vegetation', 'terrain', 'sky',
          'person', 'rider', 'car', 'truck', 'bus', 'train', 'motorcycle',
          'bicycle')


def cityscapes_dataset_information():
  return DatasetInfo(
      num_classes=19,
      label_divisor=1000,
      thing_list=tuple(range(11, 19)),
      colormap=_cityscapes_label_colormap(),
      class_names=_cityscapes_class_names())
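# DeepLab2 encodes panoptic labels as
#   panoptic_id = semantic_id * label_divisor + instance_id,
# which is how color_panoptic_map below recovers both maps. A worked example
# with hypothetical values, assuming label_divisor=1000 as above: panoptic id
# 13002 decodes to semantic_id = 13002 // 1000 = 13 ('car') and
# instance_id = 13002 % 1000 = 2.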
def perturb_color(color, noise, used_colors, max_trials=50, random_state=None):
  """Perturbs the color with some noise.

  If `used_colors` is not None, we will return a color that has
  not appeared in it before.

  Args:
    color: A numpy array with three elements [R, G, B].
    noise: Integer, specifying the amount of perturbing noise (in uint8 range).
    used_colors: A set, used to keep track of used colors.
    max_trials: An integer, maximum trials to generate a random color.
    random_state: An optional np.random.RandomState. If passed, will be used to
      generate random numbers.

  Returns:
    A perturbed color that has not appeared in used_colors.
  """
  if random_state is None:
    random_state = np.random
  for _ in range(max_trials):
    random_color = color + random_state.randint(
        low=-noise, high=noise + 1, size=3)
    random_color = np.clip(random_color, 0, 255)
    if tuple(random_color) not in used_colors:
      used_colors.add(tuple(random_color))
      return random_color
  print('Max trials reached and a duplicate color will be used. Please '
        'consider increasing the noise in `perturb_color()`.')
  return random_color
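# A minimal usage sketch (hypothetical values): draw two distinct shades of
# the 'car' base color, with a set tracking colors already handed out.
#   used = set()
#   base = _cityscapes_label_colormap()[13]  # 'car'
#   shade_a = perturb_color(base, noise=60, used_colors=used)
#   shade_b = perturb_color(base, noise=60, used_colors=used)  # != shade_a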
def color_panoptic_map(panoptic_prediction, dataset_info, perturb_noise):
  """Helper method to colorize an output panoptic map.

  Args:
    panoptic_prediction: A 2D numpy array, the panoptic prediction from a
      deeplab model.
    dataset_info: A DatasetInfo object, the dataset associated with the model.
    perturb_noise: Integer, the amount of noise (in uint8 range) added to each
      instance of the same semantic class.

  Returns:
    colored_panoptic_map: A 3D numpy array with last dimension of 3, colored
      panoptic prediction map.
    used_colors: A dictionary mapping semantic_ids to a set of colors used
      in `colored_panoptic_map`.
  """
  if panoptic_prediction.ndim != 2:
    raise ValueError('Expect 2-D panoptic prediction. Got {}'.format(
        panoptic_prediction.shape))
  semantic_map = panoptic_prediction // dataset_info.label_divisor
  instance_map = panoptic_prediction % dataset_info.label_divisor
  height, width = panoptic_prediction.shape
  colored_panoptic_map = np.zeros((height, width, 3), dtype=np.uint8)
  used_colors = collections.defaultdict(set)
  # Use a fixed seed to reproduce the same visualization.
  random_state = np.random.RandomState(0)
  unique_semantic_ids = np.unique(semantic_map)
  for semantic_id in unique_semantic_ids:
    semantic_mask = semantic_map == semantic_id
    if semantic_id in dataset_info.thing_list:
      # For a `thing` class, add a small amount of random noise to its
      # predefined semantic colormap entry, so that different instances get
      # visually distinct colors.
      unique_instance_ids = np.unique(instance_map[semantic_mask])
      for instance_id in unique_instance_ids:
        instance_mask = np.logical_and(semantic_mask,
                                       instance_map == instance_id)
        random_color = perturb_color(
            dataset_info.colormap[semantic_id],
            perturb_noise,
            used_colors[semantic_id],
            random_state=random_state)
        colored_panoptic_map[instance_mask] = random_color
    else:
      # For a `stuff` class, use the predefined semantic color.
      colored_panoptic_map[semantic_mask] = dataset_info.colormap[semantic_id]
      used_colors[semantic_id].add(tuple(dataset_info.colormap[semantic_id]))
  return colored_panoptic_map, used_colors
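# A self-contained sketch on a toy input (not part of the app): a 2x2 map with
# 'road' (stuff, panoptic id 0) and one 'car' instance (panoptic id
# 13 * 1000 + 1 = 13001).
#   toy = np.array([[0, 0], [13001, 13001]])
#   colored, used = color_panoptic_map(toy, cityscapes_dataset_information(), 60)
#   # colored[0, 0] == [128, 64, 128] (road); colored[1, 0] is a perturbed
#   # shade of the 'car' base color [0, 0, 142].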
def vis_segmentation(image,
                     panoptic_prediction,
                     dataset_info,
                     perturb_noise=60):
  """Visualizes the input image, segmentation map and overlay view."""
  plt.figure(figsize=(30, 20))
  grid_spec = gridspec.GridSpec(2, 2)

  ax = plt.subplot(grid_spec[0])
  plt.imshow(image)
  plt.axis('off')
  ax.set_title('input image', fontsize=20)

  ax = plt.subplot(grid_spec[1])
  panoptic_map, used_colors = color_panoptic_map(panoptic_prediction,
                                                 dataset_info, perturb_noise)
  plt.imshow(panoptic_map)
  plt.axis('off')
  ax.set_title('panoptic map', fontsize=20)

  ax = plt.subplot(grid_spec[2])
  plt.imshow(image)
  plt.imshow(panoptic_map, alpha=0.7)
  plt.axis('off')
  ax.set_title('panoptic overlay', fontsize=20)

  ax = plt.subplot(grid_spec[3])
  max_num_instances = max(len(color) for color in used_colors.values())
  # Build an RGBA image to serve as the legend: one row per semantic class,
  # one column per instance color; alpha 0 marks unused cells.
  legend = np.zeros((len(used_colors), max_num_instances, 4), dtype=np.uint8)
  class_names = []
  for i, semantic_id in enumerate(sorted(used_colors)):
    legend[i, :len(used_colors[semantic_id]), :3] = np.array(
        list(used_colors[semantic_id]))
    legend[i, :len(used_colors[semantic_id]), 3] = 255
    if semantic_id < dataset_info.num_classes:
      class_names.append(dataset_info.class_names[semantic_id])
    else:
      class_names.append('ignore')
  plt.imshow(legend, interpolation='nearest')
  ax.yaxis.tick_left()
  plt.yticks(range(len(legend)), class_names, fontsize=15)
  plt.xticks([], [])
  ax.tick_params(width=0.0, grid_linewidth=0.0)
  plt.grid(False)
  return plt
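# Note: vis_segmentation returns the pyplot module itself, which is what the
# legacy Gradio "plot" output type used below accepts. Because each call
# creates a new module-level figure that is never closed, a long-running demo
# can accumulate open figures; calling plt.close('all') at the top of
# vis_segmentation would be one possible mitigation (a suggestion, not part
# of the original code).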
def run_cmd(command):
  try:
    print(command)
    call(command, shell=True)
  except KeyboardInterrupt:
    print("Process interrupted")
    sys.exit(1)
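# The model archive is unpacked below by shelling out to `tar`, which assumes
# the binary exists on the host. A portable alternative sketch using only the
# standard library:
#   import tarfile
#   with tarfile.open(download_path, 'r:gz') as archive:
#     archive.extractall(model_dir)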
MODEL_NAME = 'resnet50_os32_panoptic_deeplab_cityscapes_crowd_trainfine_saved_model'

# Cityscapes checkpoints published alongside DeepLab2; only MODEL_NAME is
# actually downloaded and loaded below.
_MODELS = ('resnet50_os32_panoptic_deeplab_cityscapes_crowd_trainfine_saved_model',
           'resnet50_beta_os32_panoptic_deeplab_cityscapes_trainfine_saved_model',
           'wide_resnet41_os16_panoptic_deeplab_cityscapes_trainfine_saved_model',
           'swidernet_sac_1_1_1_os16_panoptic_deeplab_cityscapes_trainfine_saved_model',
           'swidernet_sac_1_1_3_os16_panoptic_deeplab_cityscapes_trainfine_saved_model',
           'swidernet_sac_1_1_4.5_os16_panoptic_deeplab_cityscapes_trainfine_saved_model',
           'axial_swidernet_1_1_1_os16_axial_deeplab_cityscapes_trainfine_saved_model',
           'axial_swidernet_1_1_3_os16_axial_deeplab_cityscapes_trainfine_saved_model',
           'axial_swidernet_1_1_4.5_os16_axial_deeplab_cityscapes_trainfine_saved_model',
           'max_deeplab_s_backbone_os16_axial_deeplab_cityscapes_trainfine_saved_model',
           'max_deeplab_l_backbone_os16_axial_deeplab_cityscapes_trainfine_saved_model')
_DOWNLOAD_URL_PATTERN = 'https://storage.googleapis.com/gresearch/tf-deeplab/saved_model/%s.tar.gz'

_MODEL_NAME_TO_URL_AND_DATASET = {
    model: (_DOWNLOAD_URL_PATTERN % model, cityscapes_dataset_information())
    for model in _MODELS
}

MODEL_URL, DATASET_INFO = _MODEL_NAME_TO_URL_AND_DATASET[MODEL_NAME]

# Download the SavedModel archive, unpack it (the archive extracts into a
# directory named after the model), and load it.
model_dir = tempfile.mkdtemp()
download_path = os.path.join(model_dir, MODEL_NAME + '.gz')
urllib.request.urlretrieve(MODEL_URL, download_path)
run_cmd("tar -xzvf " + download_path + " -C " + model_dir)
LOADED_MODEL = tf.saved_model.load(os.path.join(model_dir, MODEL_NAME))
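# As used below, the loaded SavedModel is called directly on an unbatched
# [height, width, 3] uint8 tensor and returns a dict whose 'panoptic_pred'
# entry is batched, so [0] selects the single image's panoptic map. The exact
# input/output signature is a property of the exported DeepLab2 models, not
# something defined in this file.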
def inference(image):
  # Resize to a fixed 512x512 input (this does not preserve aspect ratio).
  image = image.resize(size=(512, 512))
  im = np.array(image)
  output = LOADED_MODEL(tf.cast(im, tf.uint8))
  return vis_segmentation(im, output['panoptic_pred'][0], DATASET_INFO)
title = "DeepLab2"
description = "Demo for DeepLab2. To use it, simply upload an image, or click one of the examples to load it. Read more at the links below."
article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2106.09748'>DeepLab2: A TensorFlow Library for Deep Labeling</a> | <a href='https://github.com/google-research/deeplab2'>Github Repo</a></p>"

gr.Interface(
    inference,
    [gr.inputs.Image(type="pil", label="Input")],
    gr.outputs.Image(type="plot", label="Output"),
    title=title,
    description=description,
    article=article,
    examples=[
        ["city1.jpg"],
        ["city2.jpg"]
    ]).launch()
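# Note: gr.inputs / gr.outputs and the "plot" output type are part of the
# legacy (pre-3.x) Gradio API, so this script assumes a Gradio 1.x/2.x
# environment. Under Gradio 3+, a rough equivalent would be
# gr.Image(type="pil") for the input and gr.Plot() for the output, with
# vis_segmentation returning a Figure object instead of the pyplot module.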