Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -1,5 +1,8 @@
|
|
1 |
import gradio as gr
|
2 |
|
|
|
|
|
|
|
3 |
import numpy as np
|
4 |
from PIL import Image
|
5 |
import tensorflow as tf
|
@@ -167,6 +170,58 @@ def ade_palette():
|
|
167 |
[92, 0, 255],
|
168 |
]
|
169 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
170 |
def sepia(input_img):
|
171 |
input_img = Image.fromarray(input_img)
|
172 |
|
@@ -194,8 +249,10 @@ def sepia(input_img):
|
|
194 |
# Show image + mask
|
195 |
pred_img = np.array(input_img) * 0.5 + color_seg * 0.5
|
196 |
pred_img = pred_img.astype(np.uint8)
|
197 |
-
return pred_img
|
198 |
|
199 |
-
|
|
|
|
|
|
|
200 |
|
201 |
demo.launch()
|
|
|
1 |
import gradio as gr
|
2 |
|
3 |
+
import pandas as pd
|
4 |
+
from matplotlib import gridspec
|
5 |
+
import matplotlib.pyplot as plt
|
6 |
import numpy as np
|
7 |
from PIL import Image
|
8 |
import tensorflow as tf
|
|
|
170 |
[92, 0, 255],
|
171 |
]
|
172 |
|
173 |
+
def label_to_color_image(label):
|
174 |
+
"""Adds color defined by the dataset colormap to the label.
|
175 |
+
|
176 |
+
Args:
|
177 |
+
label: A 2D array with integer type, storing the segmentation label.
|
178 |
+
|
179 |
+
Returns:
|
180 |
+
result: A 2D array with floating type. The element of the array
|
181 |
+
is the color indexed by the corresponding element in the input label
|
182 |
+
to the PASCAL color map.
|
183 |
+
|
184 |
+
Raises:
|
185 |
+
ValueError: If label is not of rank 2 or its value is larger than color
|
186 |
+
map maximum entry.
|
187 |
+
"""
|
188 |
+
if label.ndim != 2:
|
189 |
+
raise ValueError("Expect 2-D input label")
|
190 |
+
|
191 |
+
colormap = np.asarray(ade_palette())
|
192 |
+
|
193 |
+
if np.max(label) >= len(colormap):
|
194 |
+
raise ValueError("label value too large.")
|
195 |
+
|
196 |
+
return colormap[label]
|
197 |
+
|
198 |
+
def draw_plot(pred_img, seg):
|
199 |
+
fig = plt.figure(figsize=(20, 15))
|
200 |
+
|
201 |
+
grid_spec = gridspec.GridSpec(1, 2, width_ratios=[6, 1])
|
202 |
+
|
203 |
+
plt.subplot(grid_spec[0])
|
204 |
+
plt.imshow(pred_img)
|
205 |
+
plt.axis('off')
|
206 |
+
|
207 |
+
ade20k_labels_info = pd.read_csv(
|
208 |
+
"https://raw.githubusercontent.com/CSAILVision/sceneparsing/master/objectInfo150.csv"
|
209 |
+
)
|
210 |
+
labels_list = list(ade20k_labels_info["Name"])
|
211 |
+
|
212 |
+
LABEL_NAMES = np.asarray(labels_list)
|
213 |
+
FULL_LABEL_MAP = np.arange(len(LABEL_NAMES)).reshape(len(LABEL_NAMES), 1)
|
214 |
+
FULL_COLOR_MAP = label_to_color_image(FULL_LABEL_MAP)
|
215 |
+
|
216 |
+
unique_labels = np.unique(seg.numpy().astype("uint8"))
|
217 |
+
ax = plt.subplot(grid_spec[1])
|
218 |
+
plt.imshow(FULL_COLOR_MAP[unique_labels].astype(np.uint8), interpolation="nearest")
|
219 |
+
ax.yaxis.tick_right()
|
220 |
+
plt.yticks(range(len(unique_labels)), LABEL_NAMES[unique_labels])
|
221 |
+
plt.xticks([], [])
|
222 |
+
ax.tick_params(width=0.0, labelsize=25)
|
223 |
+
return fig
|
224 |
+
|
225 |
def sepia(input_img):
|
226 |
input_img = Image.fromarray(input_img)
|
227 |
|
|
|
249 |
# Show image + mask
|
250 |
pred_img = np.array(input_img) * 0.5 + color_seg * 0.5
|
251 |
pred_img = pred_img.astype(np.uint8)
|
|
|
252 |
|
253 |
+
fig = draw_plot(pred_img, seg)
|
254 |
+
return fig
|
255 |
+
|
256 |
+
demo = gr.Interface(sepia, gr.Image(shape=(200, 200)), outputs=['plot'], examples=["ADE_val_00000001.jpeg"])
|
257 |
|
258 |
demo.launch()
|