sankhyikii's picture
minor update to look good visually
864e5e7
raw history blame
No virus
3.61 kB
import csv
import os
import sys
import cv2
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import onnxruntime as ort
from matplotlib import gridspec
# Class names and display colors, indexed by class id: draw_plot and sepia
# look both up with the ids taken from the segmentation map, so the line order
# of the two text files must match the model's output channels (assumed —
# TODO confirm against the model's id2label mapping).
ade_palette = []
labels_list = []

# Raise the csv module's field-size cap. Nothing in this file uses csv
# directly; presumably this works around large fields in Gradio's CSV-based
# example/flag logs — TODO confirm. sys.maxsize does not fit in a C long on
# platforms where long is 32-bit (e.g. Windows) and raises OverflowError
# there, so fall back to the largest 32-bit value.
try:
    csv.field_size_limit(sys.maxsize)
except OverflowError:
    csv.field_size_limit(2**31 - 1)

# labels.txt: one class name per line. Strip only the line terminator;
# slicing off the last character would truncate a final line that has no
# trailing newline.
with open("labels.txt", "r", encoding="utf-8") as fp:
    for line in fp:
        labels_list.append(line.rstrip("\r\n"))

# ade_palette.txt: one bracketed, comma-separated list of ints per line
# (presumably an RGB triple). Blank lines, e.g. an extra newline at EOF, are
# skipped instead of crashing on int("").
with open("ade_palette.txt", "r", encoding="utf-8") as fp:
    for line in fp:
        entry = line.strip().strip("[]")
        if entry:
            ade_palette.append([int(value) for value in entry.split(",")])
colormap = np.asarray(ade_palette)

model_filename = "segformer-b5-finetuned-ade-640-640.onnx"
sess_options = ort.SessionOptions()
# os.cpu_count() can return None; 0 tells ONNX Runtime to choose its default.
sess_options.intra_op_num_threads = os.cpu_count() or 0
sess = ort.InferenceSession(
    model_filename, sess_options, providers=["CPUExecutionProvider"]
)
def label_to_color_image(label):
    """Look up the palette color for every class id in a 2-D label array.

    Args:
        label: 2-D integer array of class ids.

    Returns:
        ``colormap[label]`` — one palette row per element of ``label``.

    Raises:
        ValueError: if ``label`` is not 2-D, or contains an id that has no
            entry in the module-level ``colormap``.
    """
    if label.ndim == 2:
        if label.max() < len(colormap):
            return colormap[label]
        raise ValueError("label value too large.")
    raise ValueError("Expect 2-D input label")
def draw_plot(pred_img, seg):
    """Render the blended prediction next to a legend of the classes present.

    The figure is built with the object-oriented ``matplotlib.figure.Figure``
    API rather than ``pyplot``. pyplot keeps every figure it creates alive
    until it is explicitly closed, so one figure per request leaks memory in a
    long-running server. Its implicit "current figure/axes" state is also not
    safe to share across request threads.

    Args:
        pred_img: image shown in the left panel (passed to ``imshow``).
        seg: array of class ids; each distinct id gets one legend row.

    Returns:
        The ``matplotlib.figure.Figure`` holding both panels.
    """
    from matplotlib.figure import Figure

    fig = Figure(figsize=(20, 15))
    grid_spec = gridspec.GridSpec(1, 2, width_ratios=[6, 1])

    # Left panel: the overlay image without axis decoration.
    image_ax = fig.add_subplot(grid_spec[0])
    image_ax.imshow(pred_img)
    image_ax.axis("off")

    # Right panel: one color swatch per class present in `seg`, with the
    # class names as tick labels on the right-hand side.
    label_names = np.asarray(labels_list)
    full_label_map = np.arange(len(label_names)).reshape(len(label_names), 1)
    full_color_map = label_to_color_image(full_label_map)
    unique_labels = np.unique(seg)

    legend_ax = fig.add_subplot(grid_spec[1])
    legend_ax.imshow(
        full_color_map[unique_labels].astype(np.uint8), interpolation="nearest"
    )
    legend_ax.yaxis.tick_right()
    legend_ax.set_yticks(range(len(unique_labels)))
    legend_ax.set_yticklabels(label_names[unique_labels])
    legend_ax.set_xticks([])
    legend_ax.tick_params(width=0.0, labelsize=25)
    return fig
def sepia(input_img):
    """Segment the image file at ``input_img`` and return a result figure.

    The name is a leftover from Gradio's template; it is kept because the
    Interface below is wired to it.

    Steps:
        1. Read the file with OpenCV, resize it to 640x640 and convert it
           to RGB.
        2. Run the ONNX session and take the per-pixel argmax class.
        3. Paint each class with its palette color and blend 50/50 with the
           input.
        4. Hand the result to ``draw_plot``.

    Args:
        input_img: path to an image file readable by ``cv2.imread``.

    Returns:
        The figure produced by ``draw_plot``.

    Raises:
        ValueError: if the file cannot be decoded as an image.
    """
    img = cv2.imread(input_img)
    if img is None:
        # cv2.imread signals failure by returning None rather than raising.
        raise ValueError(f"Could not read image file: {input_img!r}")
    img = cv2.resize(img, (640, 640)).astype(np.float32)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    # HWC -> NCHW batch of one. NOTE(review): pixels are fed as raw 0-255 RGB
    # without mean/std normalisation; presumably the exported ONNX graph
    # expects that — TODO confirm against the export notebook.
    img_batch = np.transpose(img[np.newaxis], (0, 3, 1, 2))
    logits = sess.run(None, {"pixel_values": img_batch})[0]

    # Class id per pixel: argmax over the channel axis of the first (only)
    # batch item. uint8 assumes fewer than 256 classes, as before.
    seg = np.argmax(logits[0], axis=0).astype(np.uint8)
    # Upsample with nearest-neighbour: interpolating class ids bilinearly
    # invents in-between ids along region boundaries, i.e. wrong classes.
    seg = cv2.resize(seg, (640, 640), interpolation=cv2.INTER_NEAREST)

    # Paint each pixel with its class color; ids without a palette entry stay
    # black.
    color_seg = np.zeros((seg.shape[0], seg.shape[1], 3), dtype=np.uint8)
    known = seg < len(colormap)
    color_seg[known] = colormap[seg[known]]

    # Keep the mask in the same channel order as `img` (RGB, converted above).
    # matplotlib displays arrays as RGB, and draw_plot's legend uses the
    # unflipped palette. Reversing channels here would make the overlay colors
    # disagree with the legend.
    pred_img = (img * 0.5 + color_seg * 0.5).astype(np.uint8)
    return draw_plot(pred_img, seg)
title = "SegFormer(ADE20k) in TensorFlow"
# Markdown text passed to the Interface as its description.
description = """
This is a demo of TensorFlow SegFormer from the official 🤗 `transformers` package.
The pre-trained model was trained to segment scene-specific images. We are **currently using an ONNX model converted from the TensorFlow-based SegFormer to improve the latency**.
The average latency of an inference is **21** and **8** seconds for the TensorFlow and ONNX-converted models respectively
(in [Colab](https://github.com/deep-diver/segformer-tf-transformers/blob/main/notebooks/TFSegFormer_ONNX.ipynb)).
Check out the [repository](https://github.com/deep-diver/segformer-tf-transformers) to find out how to run inference, fine-tune the model on a custom dataset, and more.
"""
# Wire `sepia` (image file path in, matplotlib figure out) into a simple UI.
# NOTE(review): `gr.inputs.*` is Gradio's legacy input namespace (deprecated
# in 3.x, removed in 4.0); switch to `gr.Image` when upgrading Gradio.
demo = gr.Interface(
    sepia,
    gr.inputs.Image(type="filepath"),
    outputs=["plot"],
    examples=["ADE_val_00000001.jpeg"],
    allow_flagging="never",
    title=title,
    description=description,
)
demo.launch()