Spaces:
Runtime error
Runtime error
File size: 3,297 Bytes
a08f593 1129909 a08f593 1129909 a08f593 1129909 a08f593 1129909 a08f593 28f8a0e a08f593 1129909 a08f593 1129909 a08f593 1129909 a08f593 1129909 a08f593 1129909 a08f593 1129909 a08f593 1129909 a08f593 1129909 a08f593 ded0d65 a08f593 594b8bf a08f593 1129909 a08f593 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 |
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from matplotlib import gridspec
from matplotlib.figure import Figure
from PIL import Image
from transformers import SegformerImageProcessor, TFSegformerForSemanticSegmentation
# Load the SegFormer image processor and TF model from the same Hugging Face
# Hub checkpoint (B0, fine-tuned on Cityscapes at 1024x1024).
MODEL_CHECKPOINT = "nvidia/segformer-b0-finetuned-cityscapes-1024-1024"
feature_extractor = SegformerImageProcessor.from_pretrained(MODEL_CHECKPOINT)
model = TFSegformerForSemanticSegmentation.from_pretrained(MODEL_CHECKPOINT)
# Class names (`labels_list`) are read from labels.txt further down, next to the
# colour palette. An identical read used to sit here, but it was overwritten by
# that later one before anything used it.
# Colour palette for the segmentation classes. The function name is historical:
# these are not the ADE20K colours. The checkpoint loaded above is fine-tuned
# on Cityscapes, which has 19 classes, so the palette needs 19 entries.
def ade_palette():
    """Return one RGB colour per class index as a fresh list of ``[r, g, b]`` lists.

    Entry ``i`` is the colour used for class ``i``. There are 19 entries so that
    every class index the Cityscapes checkpoint can predict (0-18) has a colour.
    With only 18 entries, class 18 was left black in the overlay, and a 19-line
    ``labels.txt`` made ``label_to_color_image`` reject the legend map.
    """
    return [
        [255, 0, 0],
        [255, 187, 0],
        [255, 228, 0],
        [29, 219, 22],
        [178, 204, 255],
        [1, 0, 255],
        [165, 102, 255],
        [217, 65, 197],
        [116, 116, 116],
        [204, 114, 61],
        [206, 242, 121],
        [61, 183, 204],
        [94, 94, 94],
        [196, 183, 59],
        [246, 246, 246],
        [209, 178, 255],
        [0, 87, 102],
        [153, 0, 76],
        [119, 11, 32],  # index 18 ("bicycle" in the standard Cityscapes train-ID order)
    ]
# Class names, one per line of labels.txt. Presumably line i names model class i
# (TODO confirm against model.config.id2label).
with open("labels.txt", "r", encoding="utf-8") as fp:
    # rstrip("\n") rather than line[:-1]: slicing chops a real character off the
    # last label when the file has no trailing newline.
    labels_list = [line.rstrip("\n") for line in fp]
# (num_colours, 3) integer array; row i is the RGB colour for class i.
colormap = np.asarray(ade_palette())
# Label to color image mapping
def label_to_color_image(label, palette=None):
    """Map a 2-D array of class indices to an RGB image.

    Args:
        label: 2-D integer array (or array-like) of class indices.
        palette: Optional ``(num_classes, 3)`` array-like of RGB colours.
            Defaults to the module-level ``colormap``.

    Returns:
        Array of shape ``label.shape + (3,)`` whose ``[i, j]`` entry is
        ``palette[label[i, j]]``.

    Raises:
        ValueError: if ``label`` is not 2-D, contains a negative index, or
            contains an index outside the palette.
    """
    label = np.asarray(label)
    palette = np.asarray(colormap if palette is None else palette)
    if label.ndim != 2:
        raise ValueError("Expect 2-D input label")
    # Nothing to validate for an empty map, and np.max would raise on it.
    if label.size == 0:
        return np.zeros(label.shape + palette.shape[1:], dtype=palette.dtype)
    # Negative indices would otherwise silently wrap to the end of the palette.
    if np.min(label) < 0:
        raise ValueError("label value must be non-negative.")
    if np.max(label) >= len(palette):
        raise ValueError("label value too large.")
    return palette[label]
# Draw segmentation plot
def draw_plot(pred_img, seg):
    """Build a figure with the overlay image on the left and a colour legend on the right.

    Args:
        pred_img: Image to show in the left panel (``sepia`` passes the blended
            image/segmentation overlay as a uint8 array).
        seg: 2-D array of per-pixel class indices. Accepts a TensorFlow eager
            tensor (as ``sepia`` passes) or anything ``np.asarray`` accepts.

    Returns:
        A ``matplotlib.figure.Figure`` with two panels in a 6:1 width ratio.
        The right panel lists only the labels that occur in ``seg``.

    Raises:
        ValueError: propagated from ``label_to_color_image`` if ``labels_list``
            has more entries than the colour palette.
    """
    # Create the Figure directly instead of through pyplot. pyplot keeps every
    # figure it creates alive until it is explicitly closed, so a long-running
    # server would accumulate one figure per request. A bare Figure is
    # garbage-collected normally and avoids pyplot's global, non-thread-safe state.
    fig = Figure(figsize=(20, 15))
    grid_spec = gridspec.GridSpec(1, 2, width_ratios=[6, 1], figure=fig)

    image_ax = fig.add_subplot(grid_spec[0])
    image_ax.imshow(pred_img)
    image_ax.axis("off")

    label_names = np.asarray(labels_list)
    # One-pixel-wide column image containing every class colour, in index order.
    full_label_map = np.arange(len(label_names)).reshape(len(label_names), 1)
    full_color_map = label_to_color_image(full_label_map)
    # np.asarray handles both tf EagerTensors and numpy arrays. intp (not uint8)
    # avoids silently wrapping class indices above 255.
    unique_labels = np.unique(np.asarray(seg).astype(np.intp))

    legend_ax = fig.add_subplot(grid_spec[1])
    legend_ax.imshow(full_color_map[unique_labels].astype(np.uint8), interpolation="nearest")
    legend_ax.yaxis.tick_right()
    legend_ax.set_yticks(range(len(unique_labels)))
    legend_ax.set_yticklabels(label_names[unique_labels])
    legend_ax.set_xticks([])
    legend_ax.tick_params(width=0.0, labelsize=25)
    return fig
# Segmentation handler used by the Gradio interface
def sepia(input_img):
    """Segment an image with SegFormer and return a figure of the overlay plus legend.

    Despite its name, this applies no sepia filter. The name is kept because the
    Gradio interface refers to it.

    Args:
        input_img: Image as a numpy array, as delivered by the ``gr.Image`` input.
            Assumed to be H x W (x C) uint8; TODO confirm for other image modes.

    Returns:
        The matplotlib figure built by ``draw_plot``.
    """
    # Force 3-channel RGB so the 50/50 blend with the (H, W, 3) colour mask
    # below also broadcasts for RGBA or greyscale uploads.
    input_img = Image.fromarray(input_img).convert("RGB")
    inputs = feature_extractor(images=input_img, return_tensors="tf")
    logits = model(**inputs).logits
    # The logits are channels-first. tf.image.resize wants channels-last and a
    # (height, width) target, while PIL's .size is (width, height).
    logits = tf.transpose(logits, [0, 2, 3, 1])
    logits = tf.image.resize(logits, input_img.size[::-1])
    seg = tf.math.argmax(logits, axis=-1)[0]

    # Colour each pixel by its class. Classes beyond the palette stay black,
    # matching the previous per-label loop, but this needs one tensor->numpy
    # conversion and one vectorised pass instead of one of each per colour.
    seg_np = seg.numpy()
    color_seg = np.zeros((seg_np.shape[0], seg_np.shape[1], 3), dtype=np.uint8)
    in_palette = seg_np < len(colormap)
    color_seg[in_palette] = colormap[seg_np[in_palette]]

    pred_img = (np.array(input_img) * 0.5 + color_seg * 0.5).astype(np.uint8)
    return draw_plot(pred_img, seg)
# Gradio Interface
# Gradio 3.x's `shape=` crops and resizes uploads before they reach `sepia`.
# Gradio 4 removed that parameter, and passing it raises TypeError at startup,
# so fall back to a plain image input there.
try:
    image_input = gr.Image(shape=(800, 1200))
except TypeError:
    image_input = gr.Image()

demo = gr.Interface(
    fn=sepia,
    inputs=image_input,
    outputs=["plot"],
    examples=["citiscape-1.jpeg", "citiscape-2.jpeg"],
    allow_flagging="never",
)
# Launch the interface
demo.launch()