22 / app.py
hyo37009's picture
a
1985ff0
raw
history blame
No virus
3.04 kB
import gradio as gr
from transformers import SegformerFeatureExtractor, SegformerForSemanticSegmentation
import matplotlib.pyplot as plt
from matplotlib import gridspec
import numpy as np
from PIL import Image
import tensorflow as tf
import requests
# One checkpoint identifier feeds both the preprocessor and the network so the
# two can never be loaded from different sources by accident.
_CHECKPOINT = "nvidia/segformer-b0-finetuned-cityscapes-640-1280"

# Created once at import time and reused by every call to the handler.
feature_extractor = SegformerFeatureExtractor.from_pretrained(_CHECKPOINT)
model = SegformerForSemanticSegmentation.from_pretrained(_CHECKPOINT)
def my_palette():
    """Return the RGB colour used to paint each segmentation class.

    Entry ``i`` is the ``[R, G, B]`` colour (0-255 per channel) for class id
    ``i``.  The Cityscapes checkpoint used by this app predicts 19 classes
    (ids 0-18), so the palette carries 19 entries; with fewer, the highest
    class id would have no colour to look up.

    Returns:
        list[list[int]]: A freshly built list of 19 ``[R, G, B]`` triples.
    """
    return [
        [131, 162, 255],
        [180, 189, 255],
        [255, 227, 187],
        [255, 210, 143],
        [248, 117, 170],
        [255, 223, 223],
        [255, 246, 246],
        [174, 222, 252],
        [150, 194, 145],
        [255, 219, 170],
        [244, 238, 238],
        [50, 38, 83],
        [128, 98, 214],
        [146, 136, 248],
        [255, 210, 215],
        [255, 152, 152],
        [162, 103, 138],
        [63, 29, 56],
        [94, 84, 142],  # class id 18, the last of the 19 classes
    ]
# Class names for the legend, one per line of labels.txt; a name's position
# in the list is the class id it describes.
with open(r"labels.txt", "r") as fp:
    # Strip only the line terminator.  Slicing with [:-1] would also chop the
    # last character of the final label when the file has no trailing newline.
    labels_list = [line.rstrip("\n") for line in fp]

# Palette as an array so it can be indexed directly with arrays of class ids.
colormap = np.asarray(my_palette())
def greet(input_img):
    """Segment an image and return a figure showing the overlay and a legend.

    Args:
        input_img: The picture to segment, either a ``PIL.Image.Image`` or a
            NumPy array (which is what Gradio's ``"image"`` input hands to
            the callback by default).

    Returns:
        The matplotlib figure produced by ``draw_plot``: the input blended
        50/50 with the per-class colours, plus a legend of the classes found.

    Raises:
        ValueError: If no image was supplied.
    """
    if input_img is None:
        raise ValueError("No image provided")

    # `.size` is (width, height) on a PIL image but a plain element count on
    # an ndarray, so normalise to PIL before the size is used below.
    if not isinstance(input_img, Image.Image):
        input_img = Image.fromarray(np.asarray(input_img))
    # Force three channels so the blend with the (H, W, 3) colour mask below
    # has matching shapes even for greyscale or RGBA input.
    input_img = input_img.convert("RGB")

    inputs = feature_extractor(images=input_img, return_tensors="pt")
    outputs = model(**inputs)

    # The logits are a PyTorch tensor still attached to the autograd graph;
    # detach and convert to NumPy before handing them to TensorFlow ops.
    logits = outputs.logits.detach().cpu().numpy()

    # Move the class axis last, then upsample to the input resolution.
    # PIL reports (width, height); tf.image.resize wants (height, width).
    logits_tf = tf.transpose(logits, [0, 2, 3, 1])
    logits_tf = tf.image.resize(logits_tf, input_img.size[::-1])

    # Per-pixel class ids for the single image in the batch, as a NumPy array
    # so it can be used for boolean masking and by draw_plot.
    seg = tf.math.argmax(logits_tf, axis=-1)[0].numpy()

    color_seg = np.zeros(
        (seg.shape[0], seg.shape[1], 3), dtype=np.uint8
    )  # height, width, 3
    # Classes without a palette entry keep the initial black.
    for label, color in enumerate(colormap):
        color_seg[seg == label, :] = color

    pred_img = np.array(input_img) * 0.5 + color_seg * 0.5
    pred_img = pred_img.astype(np.uint8)

    return draw_plot(pred_img, seg)
def draw_plot(pred_img, seg):
    """Build a figure with the overlay image and a legend of detected classes.

    Args:
        pred_img: Image array to display in the large left-hand panel.
        seg: 2-D array-like of per-pixel class ids.  Anything ``np.asarray``
            can convert is accepted (NumPy array, TensorFlow tensor, ...).

    Returns:
        The matplotlib figure: the image on the left (6/7 of the width) and,
        on the right, one colour swatch per class id present in ``seg``,
        labelled with the matching entry of ``labels_list``.
    """
    fig = plt.figure(figsize=(20, 15))
    grid_spec = gridspec.GridSpec(1, 2, width_ratios=[6, 1])

    # Draw on explicit axes rather than pyplot's implicit "current axes", so
    # the result does not depend on global pyplot state.
    image_ax = fig.add_subplot(grid_spec[0])
    image_ax.imshow(pred_img)
    image_ax.axis("off")

    # A column of every known class id, mapped to its colour: shape
    # (num_labels, 1, 3), one swatch row per class.
    label_names = np.asarray(labels_list)
    full_label_map = np.arange(len(label_names)).reshape(len(label_names), 1)
    full_color_map = label_to_color_image(full_label_map)

    # np.asarray first: a TensorFlow tensor has no .astype of its own.
    unique_labels = np.unique(np.asarray(seg).astype("uint8"))

    legend_ax = fig.add_subplot(grid_spec[1])
    legend_ax.imshow(
        full_color_map[unique_labels].astype(np.uint8), interpolation="nearest"
    )
    legend_ax.yaxis.tick_right()
    legend_ax.set_yticks(list(range(len(unique_labels))))
    legend_ax.set_yticklabels(label_names[unique_labels])
    legend_ax.set_xticks([])
    legend_ax.tick_params(width=0.0, labelsize=25)
    return fig
def label_to_color_image(label):
    """Translate a 2-D array of class ids into their palette colours.

    Args:
        label: 2-D NumPy array of class ids used to index ``colormap``.

    Returns:
        The rows of ``colormap`` selected by ``label``; the result has the
        shape of ``label`` with the colour axis appended.

    Raises:
        ValueError: If ``label`` is not 2-D, or if its largest value has no
            entry in ``colormap``.
    """
    if not label.ndim == 2:
        raise ValueError("Expect 2-D input label")

    highest_id = np.max(label)
    palette_size = len(colormap)
    if palette_size <= highest_id:
        raise ValueError("label value too large.")

    colored = colormap[label]
    return colored
# Sample pictures offered to the user; the files are looked up by these
# relative paths.
_EXAMPLE_IMAGES = [
    "image (1).jpg",
    "image (2).jpg",
    "image (3).jpg",
    "image (4).jpg",
    "image (5).jpg",
]

# Web UI: one image in, one matplotlib plot out, flagging disabled.
iface = gr.Interface(
    fn=greet,
    inputs="image",
    outputs=["plot"],
    examples=_EXAMPLE_IMAGES,
    allow_flagging="never",
)

# Start serving as soon as the module is executed.
iface.launch(share=True)