JinHyeong99 commited on
Commit
f6fd199
1 Parent(s): d7d9b14
Files changed (3) hide show
  1. app.py +244 -0
  2. labels.txt +18 -0
  3. requirements.txt +6 -0
app.py ADDED
@@ -0,0 +1,244 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+
3
+ from matplotlib import gridspec
4
+ import matplotlib.pyplot as plt
5
+ import numpy as np
6
+ from PIL import Image
7
+ import tensorflow as tf
8
+ from transformers import SegformerFeatureExtractor, TFSegformerForSemanticSegmentation
9
+
10
# Load the pretrained SegFormer checkpoint from the Hugging Face Hub.
# "mattmdjaga/segformer_b2_clothes" is an 18-class clothes/person-part
# segmentation model (class names come from labels.txt below); the
# commented-out checkpoint is the ADE20K model this demo was adapted from.
# NOTE: this downloads weights at import time — requires network on first run.
feature_extractor = SegformerFeatureExtractor.from_pretrained(
    # "nvidia/segformer-b5-finetuned-ade-640-640"
    "mattmdjaga/segformer_b2_clothes"
)
model = TFSegformerForSemanticSegmentation.from_pretrained(
    # "nvidia/segformer-b5-finetuned-ade-640-640"
    "mattmdjaga/segformer_b2_clothes"
)
18
+
19
def ade_palette():
    """Return the per-class RGB color table used to colorize segmentation maps.

    Row ``i`` is the ``[R, G, B]`` color (0-255) for class index ``i``.

    NOTE(review): despite the name inherited from the original ADE20K demo,
    these 150 values are a custom palette, not the official ADE20K one.
    Only the first rows (one per entry in labels.txt) are actually used by
    the clothes model loaded above; some colors repeat further down —
    confirm before reusing this palette elsewhere.
    """
    return [
        [204, 87, 92],
        [112, 185, 212],
        [45, 189, 106],
        [234, 123, 67],
        [78, 56, 123],
        [210, 32, 89],
        [90, 180, 56],
        [155, 102, 200],
        [33, 147, 176],
        [255, 183, 76],
        [67, 123, 89],
        [190, 60, 45],
        [134, 112, 200],
        [56, 45, 189],
        [200, 56, 123],
        [87, 92, 204],
        [120, 56, 123],
        [45, 78, 123],
        [156, 200, 56],
        [32, 90, 210],
        [56, 123, 67],
        [180, 56, 123],
        [123, 67, 45],
        [45, 134, 200],
        [67, 56, 123],
        [78, 123, 67],
        [32, 210, 90],
        [45, 56, 189],
        [123, 56, 123],
        [56, 156, 200],
        [189, 56, 45],
        [112, 200, 56],
        [56, 123, 45],
        [200, 32, 90],
        [123, 45, 78],
        [200, 156, 56],
        [45, 67, 123],
        [56, 45, 78],
        [45, 56, 123],
        [123, 67, 56],
        [56, 78, 123],
        [210, 90, 32],
        [123, 56, 189],
        [45, 200, 134],
        [67, 123, 56],
        [123, 45, 67],
        [90, 32, 210],
        [200, 45, 78],
        [32, 210, 90],
        [45, 123, 67],
        [165, 42, 87],
        [72, 145, 167],
        [15, 158, 75],
        [209, 89, 40],
        [32, 21, 121],
        [184, 20, 100],
        [56, 135, 15],
        [128, 92, 176],
        [1, 119, 140],
        [220, 151, 43],
        [41, 97, 72],
        [148, 38, 27],
        [107, 86, 176],
        [21, 26, 136],
        [174, 27, 90],
        [91, 96, 204],
        [108, 50, 107],
        [27, 45, 136],
        [168, 200, 52],
        [7, 102, 27],
        [42, 93, 56],
        [140, 52, 112],
        [92, 107, 168],
        [17, 118, 176],
        [59, 50, 174],
        [206, 40, 143],
        [44, 19, 142],
        [23, 168, 75],
        [54, 57, 189],
        [144, 21, 15],
        [15, 176, 35],
        [107, 19, 79],
        [204, 52, 114],
        [48, 173, 83],
        [11, 120, 53],
        [206, 104, 28],
        [20, 31, 153],
        [27, 21, 93],
        [11, 206, 138],
        [112, 30, 83],
        [68, 91, 152],
        [153, 13, 43],
        [25, 114, 54],
        [92, 27, 150],
        [108, 42, 59],
        [194, 77, 5],
        [145, 48, 83],
        [7, 113, 19],
        [25, 92, 113],
        [60, 168, 79],
        [78, 33, 120],
        [89, 176, 205],
        [27, 200, 94],
        [210, 67, 23],
        [123, 89, 189],
        [225, 56, 112],
        [75, 156, 45],
        [172, 104, 200],
        [15, 170, 197],
        [240, 133, 65],
        [89, 156, 112],
        [214, 88, 57],
        [156, 134, 200],
        [78, 57, 189],
        [200, 78, 123],
        [106, 120, 210],
        [145, 56, 112],
        [89, 120, 189],
        [185, 206, 56],
        [47, 99, 28],
        [112, 189, 78],
        [200, 112, 89],
        [89, 145, 112],
        [78, 106, 189],
        [112, 78, 189],
        [156, 112, 78],
        [28, 210, 99],
        [78, 89, 189],
        [189, 78, 57],
        [112, 200, 78],
        [189, 47, 78],
        [205, 112, 57],
        [78, 145, 57],
        [200, 78, 112],
        [99, 89, 145],
        [200, 156, 78],
        [57, 78, 145],
        [78, 57, 99],
        [57, 78, 145],
        [145, 112, 78],
        [78, 89, 145],
        [210, 99, 28],
        [145, 78, 189],
        [57, 200, 136],
        [89, 156, 78],
        [145, 78, 99],
        [99, 28, 210],
        [189, 78, 47],
        [28, 210, 99],
        [78, 145, 57],
    ]
173
+
174
# Class-index -> class-name lookup for the segmentation model, read from
# labels.txt (one label per line, in class-index order).
labels_list = []

with open(r'labels.txt', 'r') as fp:
    for line in fp:
        # Fix: the original used line[:-1], which chops the last character
        # of the final label whenever the file has no trailing newline
        # (e.g. "Scarf" -> "Scar").  rstrip('\n') removes only a newline.
        labels_list.append(line.rstrip('\n'))

# Color lookup table as an array: row i is the RGB color for class index i.
colormap = np.asarray(ade_palette())
181
+
182
def label_to_color_image(label, palette=None):
    """Map a 2-D array of class indices to an RGB color image.

    Args:
        label: 2-D integer array of shape (H, W) holding class indices.
        palette: optional array-like of shape (num_classes, 3) with one RGB
            row per class.  Defaults to the module-level ``colormap`` so all
            existing callers behave exactly as before; passing it explicitly
            makes the function reusable/testable without module state.

    Returns:
        Array of shape (H, W, 3): the palette color of each pixel's class.

    Raises:
        ValueError: if ``label`` is not 2-D, or contains an index outside
            the palette.
    """
    lut = colormap if palette is None else np.asarray(palette)
    if label.ndim != 2:
        raise ValueError("Expect 2-D input label")

    if np.max(label) >= len(lut):
        raise ValueError("label value too large.")
    return lut[label]
189
+
190
def draw_plot(pred_img, seg):
    """Render the blended segmentation image with a color legend.

    Args:
        pred_img: RGB array to show in the main (left) panel.
        seg: 2-D tensor of class indices (must support ``.numpy()``);
            only the classes actually present appear in the legend.

    Returns:
        The matplotlib Figure containing both panels.
    """
    figure = plt.figure(figsize=(20, 15))
    spec = gridspec.GridSpec(1, 2, width_ratios=[6, 1])

    # Left panel: the prediction overlay itself, no axes.
    plt.subplot(spec[0])
    plt.imshow(pred_img)
    plt.axis('off')

    # One-column image where row i is class i's color; indexing it with the
    # present classes yields the legend swatches.
    names = np.asarray(labels_list)
    index_column = np.arange(len(names)).reshape(len(names), 1)
    color_column = label_to_color_image(index_column)

    present = np.unique(seg.numpy().astype("uint8"))
    legend_ax = plt.subplot(spec[1])
    plt.imshow(color_column[present].astype(np.uint8), interpolation="nearest")
    legend_ax.yaxis.tick_right()
    # Label each swatch with its class name; no x ticks at all.
    plt.yticks(range(len(present)), names[present])
    plt.xticks([], [])
    legend_ax.tick_params(width=0.0, labelsize=25)
    return figure
210
+
211
def sepia(input_img):
    """Run semantic segmentation on an image and return a result figure.

    Despite the name (left over from the Gradio template), this does no
    sepia filtering: it feeds the image through the SegFormer model,
    upsamples the logits back to the input size, blends the predicted
    class colors over the image at 50% opacity, and returns the
    matplotlib figure produced by draw_plot (image + class legend).

    Args:
        input_img: image as a numpy array, as delivered by gr.Image.
    """
    input_img = Image.fromarray(input_img)

    inputs = feature_extractor(images=input_img, return_tensors="tf")
    outputs = model(**inputs)
    # logits come out channels-first at reduced resolution — presumably
    # (batch, num_classes, H/4, W/4); verify against the model docs.
    logits = outputs.logits

    # Move classes to the last axis so tf.image.resize can upsample H x W.
    logits = tf.transpose(logits, [0, 2, 3, 1])
    logits = tf.image.resize(
        logits, input_img.size[::-1]
    )  # We reverse the shape of `image` because `image.size` returns width and height.
    # Per-pixel argmax over classes; [0] drops the batch dimension.
    seg = tf.math.argmax(logits, axis=-1)[0]

    color_seg = np.zeros(
        (seg.shape[0], seg.shape[1], 3), dtype=np.uint8
    )  # height, width, 3
    # Paint every pixel with its predicted class color.
    for label, color in enumerate(colormap):
        color_seg[seg.numpy() == label, :] = color

    # Show image + mask
    pred_img = np.array(input_img) * 0.5 + color_seg * 0.5
    pred_img = pred_img.astype(np.uint8)

    fig = draw_plot(pred_img, seg)
    return fig
236
+
237
# Gradio UI: one input image in, the segmentation figure out.
# NOTE(review): `gr.Image(shape=...)` only exists in Gradio 3.x (it resizes
# input to 400x600); the argument was removed in Gradio 4 — confirm the
# pinned gradio version.  The example files must ship alongside app.py.
demo = gr.Interface(fn=sepia,
                    inputs=gr.Image(shape=(400, 600)),
                    outputs=['plot'],
                    examples=["ADE_val_00000001.jpeg", "ADE_val_00001159.jpg", "ADE_val_00001248.jpg", "ADE_val_00001472.jpg"],
                    allow_flagging='never')


demo.launch()
labels.txt ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Background
2
+ Hat
3
+ Hair
4
+ Sunglasses
5
+ Upper-clothes
6
+ Skirt
7
+ Pants
8
+ Dress
9
+ Belt
10
+ Left-shoe
11
+ Right-shoe
12
+ Face
13
+ Left-leg
14
+ Right-leg
15
+ Left-arm
16
+ Right-arm
17
+ Bag
18
+ Scarf
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ torch
2
+ transformers
3
+ tensorflow
4
+ numpy
5
+ Pillow
6
+ matplotlib