import os

import cv2
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import onnxruntime as ort
from matplotlib import gridspec

ade_palette = []
labels_list = []

with open(r"labels.txt", "r") as fp:
    for line in fp:
        labels_list.append(line[:-1])

with open(r"ade_palette.txt", "r") as fp:
    for line in fp:
        tmp_list = list(map(int, line[:-1].strip("][").split(", ")))
        ade_palette.append(tmp_list)

colormap = np.asarray(ade_palette)
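
# Hedged sanity check (assumes ade_palette.txt provides at least one RGB row
# per entry in labels.txt); fails fast if the two files drift out of sync.
assert colormap.shape[0] >= len(labels_list), (
    f"palette has {colormap.shape[0]} colors but labels.txt lists "
    f"{len(labels_list)} classes"
)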

# SegFormer-B5 fine-tuned on ADE20k (640x640), exported to ONNX and run on CPU.
model_filename = "segformer-b5-finetuned-ade-640-640.onnx"
sess_options = ort.SessionOptions()
sess_options.intra_op_num_threads = os.cpu_count()
sess = ort.InferenceSession(
    model_filename, sess_options, providers=["CPUExecutionProvider"]
)
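
# Optional debugging sketch (standard onnxruntime APIs only): list the graph's
# input/output names to confirm the "pixel_values" feed used below.
for node in sess.get_inputs():
    print("model input:", node.name, node.shape)
for node in sess.get_outputs():
    print("model output:", node.name, node.shape)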


def label_to_color_image(label):
    """Map a 2-D array of class ids to an RGB image via the ADE palette."""
    if label.ndim != 2:
        raise ValueError("Expected a 2-D label array.")

    if np.max(label) >= len(colormap):
        raise ValueError("Label value exceeds the palette size.")

    return colormap[label]
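

# Example: label_to_color_image(np.array([[0, 1], [1, 0]])) returns a
# (2, 2, 3) RGB array drawn from the first two palette rows.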


def draw_plot(pred_img, seg):
    """Render the blended prediction beside a legend of the classes present."""
    fig = plt.figure(figsize=(20, 15))

    grid_spec = gridspec.GridSpec(1, 2, width_ratios=[6, 1])

    plt.subplot(grid_spec[0])
    plt.imshow(pred_img)
    plt.axis("off")

    LABEL_NAMES = np.asarray(labels_list)
    FULL_LABEL_MAP = np.arange(len(LABEL_NAMES)).reshape(len(LABEL_NAMES), 1)
    FULL_COLOR_MAP = label_to_color_image(FULL_LABEL_MAP)

    unique_labels = np.unique(seg)
    ax = plt.subplot(grid_spec[1])
    plt.imshow(FULL_COLOR_MAP[unique_labels].astype(np.uint8), interpolation="nearest")
    ax.yaxis.tick_right()
    plt.yticks(range(len(unique_labels)), LABEL_NAMES[unique_labels])
    plt.xticks([], [])
    ax.tick_params(width=0.0, labelsize=25)
    return fig


def segment(input_img):
    # Load and preprocess: resize to the model's 640x640 input, convert
    # OpenCV's BGR to RGB, and build an NCHW float32 batch.
    img = cv2.imread(input_img)
    img = cv2.resize(img, (640, 640)).astype(np.float32)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    img_batch = np.expand_dims(img, axis=0)
    img_batch = np.transpose(img_batch, (0, 3, 1, 2))

    # The model returns logits of shape (1, num_classes, H/4, W/4).
    logits = sess.run(None, {"pixel_values": img_batch})[0]

    logits = np.transpose(logits, (0, 2, 3, 1))
    seg = np.argmax(logits, axis=-1)[0].astype("float32")
    # Nearest-neighbor interpolation keeps class ids intact while upsampling.
    seg = cv2.resize(seg, (640, 640), interpolation=cv2.INTER_NEAREST).astype("uint8")

    color_seg = np.zeros(
        (seg.shape[0], seg.shape[1], 3), dtype=np.uint8
    )  # height, width, 3

    for label, color in enumerate(colormap):
        color_seg[seg == label, :] = color

    # Keep the mask in RGB: the image is already RGB and matplotlib expects
    # RGB, so a BGR flip here would make the overlay disagree with the legend.

    # Blend the input image and the color mask 50/50
    pred_img = img * 0.5 + color_seg * 0.5
    pred_img = pred_img.astype(np.uint8)

    fig = draw_plot(pred_img, seg)
    return fig
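

# Minimal local usage sketch (assumes the bundled example image is present):
#     fig = segment("ADE_val_00000001.jpeg")
#     fig.savefig("prediction.png")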


title = "SegFormer(ADE20k) in TensorFlow"
description = """

This demo runs the TensorFlow SegFormer from the official 🤗 `transformers` package. The pre-trained model was trained to segment scene-centric images (ADE20k). We are **currently using an ONNX model converted from the TensorFlow-based SegFormer to improve latency**: average inference latency is **21** seconds with the TensorFlow model and **8** seconds with the ONNX-converted one (measured in [Colab](https://github.com/deep-diver/segformer-tf-transformers/blob/main/notebooks/TFSegFormer_ONNX.ipynb)). Check out the [repository](https://github.com/deep-diver/segformer-tf-transformers) to find out how to run inference, fine-tune the model on a custom dataset, and more.

"""

demo = gr.Interface(
    segment,
    gr.Image(type="filepath"),
    outputs=["plot"],
    examples=["ADE_val_00000001.jpeg"],
    allow_flagging="never",
    title=title,
    description=description,
)

demo.launch()
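
# Note: launch() here uses Gradio's defaults; when running locally (outside a
# hosted Space), demo.launch(share=True) would expose a temporary public URL.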