chansung commited on
Commit
538bf82
1 Parent(s): 00d53da

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +59 -2
app.py CHANGED
@@ -1,5 +1,8 @@
1
  import gradio as gr
2
 
 
 
 
3
  import numpy as np
4
  from PIL import Image
5
  import tensorflow as tf
@@ -167,6 +170,58 @@ def ade_palette():
167
  [92, 0, 255],
168
  ]
169
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
170
  def sepia(input_img):
171
  input_img = Image.fromarray(input_img)
172
 
@@ -194,8 +249,10 @@ def sepia(input_img):
194
  # Show image + mask
195
  pred_img = np.array(input_img) * 0.5 + color_seg * 0.5
196
  pred_img = pred_img.astype(np.uint8)
197
- return pred_img
198
 
199
- demo = gr.Interface(sepia, gr.Image(shape=(200, 200)), "image", examples=["ADE_val_00000001.jpeg"])
 
 
 
200
 
201
  demo.launch()
 
1
  import gradio as gr
2
 
3
+ import pandas as pd
4
+ from matplotlib import gridspec
5
+ import matplotlib.pyplot as plt
6
  import numpy as np
7
  from PIL import Image
8
  import tensorflow as tf
 
170
  [92, 0, 255],
171
  ]
172
 
173
+ def label_to_color_image(label):
174
+ """Adds color defined by the dataset colormap to the label.
175
+
176
+ Args:
177
+ label: A 2D array with integer type, storing the segmentation label.
178
+
179
+ Returns:
180
+ result: A 2D array with floating type. The element of the array
181
+ is the color indexed by the corresponding element in the input label
182
+ to the PASCAL color map.
183
+
184
+ Raises:
185
+ ValueError: If label is not of rank 2 or its value is larger than color
186
+ map maximum entry.
187
+ """
188
+ if label.ndim != 2:
189
+ raise ValueError("Expect 2-D input label")
190
+
191
+ colormap = np.asarray(ade_palette())
192
+
193
+ if np.max(label) >= len(colormap):
194
+ raise ValueError("label value too large.")
195
+
196
+ return colormap[label]
197
+
198
+ def draw_plot(pred_img, seg):
199
+ fig = plt.figure(figsize=(20, 15))
200
+
201
+ grid_spec = gridspec.GridSpec(1, 2, width_ratios=[6, 1])
202
+
203
+ plt.subplot(grid_spec[0])
204
+ plt.imshow(pred_img)
205
+ plt.axis('off')
206
+
207
+ ade20k_labels_info = pd.read_csv(
208
+ "https://raw.githubusercontent.com/CSAILVision/sceneparsing/master/objectInfo150.csv"
209
+ )
210
+ labels_list = list(ade20k_labels_info["Name"])
211
+
212
+ LABEL_NAMES = np.asarray(labels_list)
213
+ FULL_LABEL_MAP = np.arange(len(LABEL_NAMES)).reshape(len(LABEL_NAMES), 1)
214
+ FULL_COLOR_MAP = label_to_color_image(FULL_LABEL_MAP)
215
+
216
+ unique_labels = np.unique(seg.numpy().astype("uint8"))
217
+ ax = plt.subplot(grid_spec[1])
218
+ plt.imshow(FULL_COLOR_MAP[unique_labels].astype(np.uint8), interpolation="nearest")
219
+ ax.yaxis.tick_right()
220
+ plt.yticks(range(len(unique_labels)), LABEL_NAMES[unique_labels])
221
+ plt.xticks([], [])
222
+ ax.tick_params(width=0.0, labelsize=25)
223
+ return fig
224
+
225
  def sepia(input_img):
226
  input_img = Image.fromarray(input_img)
227
 
 
249
  # Show image + mask
250
  pred_img = np.array(input_img) * 0.5 + color_seg * 0.5
251
  pred_img = pred_img.astype(np.uint8)
 
252
 
253
+ fig = draw_plot(pred_img, seg)
254
+ return fig
255
+
256
+ demo = gr.Interface(sepia, gr.Image(shape=(200, 200)), outputs=['plot'], examples=["ADE_val_00000001.jpeg"])
257
 
258
  demo.launch()