import numpy as np
import gradio as gr
from PIL import Image

import torch
from transformers import MobileViTFeatureExtractor, MobileViTForSemanticSegmentation


model_checkpoint = "apple/deeplabv3-mobilevit-small"
feature_extractor = MobileViTFeatureExtractor.from_pretrained(model_checkpoint)
model = MobileViTForSemanticSegmentation.from_pretrained(model_checkpoint).eval()
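# Note: newer transformers releases expose this preprocessor as
# MobileViTImageProcessor; MobileViTFeatureExtractor is the older,
# now-deprecated name for the same class.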

# Pascal VOC color palette: one RGB color per class index (21 classes,
# including background).
palette = np.array(
    [
        [  0,   0,   0], [192,   0,   0], [  0, 192,   0], [192, 192,   0],
        [  0,   0, 192], [192,   0, 192], [  0, 192, 192], [192, 192, 192],
        [128,   0,   0], [255,   0,   0], [128, 192,   0], [255, 192,   0],
        [128,   0, 192], [255,   0, 192], [128, 192, 192], [255, 192, 192],
        [  0, 128,   0], [192, 128,   0], [  0, 255,   0], [192, 255,   0],
        [  0, 128, 192],
    ],
    dtype=np.uint8,
)

labels = [
    "background",
    "aeroplane",
    "bicycle",
    "bird",
    "boat",
    "bottle",
    "bus",
    "car",
    "cat",
    "chair",
    "cow",
    "diningtable",
    "dog",
    "horse",
    "motorbike",
    "person",
    "pottedplant",
    "sheep",
    "sofa",
    "train",
    "tvmonitor",
]
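# The list order matches the model's class indices: labels[i] is the class
# drawn with color palette[i].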

# Draw the labels as colored swatches. Light colors use black text, dark
# colors use white text.
inverted = [0, 1, 4, 5, 8, 9, 12, 13, 16, 17, 20]
labels_colored = []
for i, label in enumerate(labels):
    r, g, b = palette[i]
    color = "white" if i in inverted else "black"
    text = f"<span style='background-color: rgb({r}, {g}, {b}); color: {color}; padding: 2px 4px;'>{label}</span>"
    labels_colored.append(text)
labels_text = " ".join(labels_colored)

title = "Semantic Segmentation with MobileViT and DeepLabV3"

description = """
The input image is resized and center cropped to 512Γ—512 pixels. The segmentation output is 32Γ—32 pixels.<br>
This model has been trained on <a href="http://host.robots.ox.ac.uk/pascal/VOC/">Pascal VOC</a>.
The classes are:
""" + labels_text + "</p>"

article = """
<div style='margin:20px auto;'>

<p>Sources:</p>

<p>📜 <a href="https://arxiv.org/abs/2110.02178">MobileViT: Light-weight, General-purpose, and Mobile-friendly Vision Transformer</a></p>

<p>🏋️ Original pretrained weights from <a href="https://github.com/apple/ml-cvnets">this GitHub repo</a></p>

<p>🐙 Example images from <a href="https://huggingface.co/datasets/mishig/sample_images">this dataset</a></p>

</div>
"""

examples = [
    ["cat-3.jpg"],
    ["construction-site.jpg"],
    ["dog-cat.jpg"],
    ["football-match.jpg"],
]


def predict(image):
    with torch.no_grad():
        inputs = feature_extractor(image, return_tensors="pt")
        outputs = model(**inputs)

    # Recover the preprocessed image for display. The feature extractor flips
    # the channel order to BGR for this model, so [..., ::-1] flips it back
    # to RGB; pixel values are only rescaled to [0, 1] (no mean/std
    # normalization), so multiplying by 255 undoes the preprocessing.
    resized = (inputs["pixel_values"].numpy().squeeze().transpose(1, 2, 0)[..., ::-1] * 255).astype(np.uint8)

    # Class predictions for each pixel. The logits have shape
    # (batch, num_labels, height, width), so the argmax over dim 1 yields a
    # (height, width) map of class indices.
    classes = outputs.logits.argmax(1).squeeze().numpy().astype(np.uint8)

    # Map each class index to its palette color. NumPy fancy indexing does
    # the per-pixel lookup in one vectorized step.
    colored = palette[classes]

    # Upscale the predictions to the preprocessed input size (not the
    # original image size). NEAREST keeps the class boundaries hard instead
    # of blending neighboring colors.
    colored = Image.fromarray(colored)
    colored = colored.resize((resized.shape[1], resized.shape[0]), resample=Image.Resampling.NEAREST)

    # Keep everything that is not background.
    mask = (classes != 0) * 255
    mask = Image.fromarray(mask.astype(np.uint8)).convert("RGB")
    mask = mask.resize((resized.shape[1], resized.shape[0]), resample=Image.Resampling.NEAREST)

    # Blend the mask with the preprocessed input image; 0.4 is the weight of
    # the mask in the blend.
    resized = Image.fromarray(resized)
    highlighted = Image.blend(resized, mask, 0.4)

    return colored, highlighted
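
# For a quick test without the Gradio UI, predict can be called directly,
# e.g. (assuming the example images listed above sit next to this script):
#
#   colored, highlighted = predict(Image.open("cat-3.jpg"))
#   colored.save("classes.png")
#   highlighted.save("overlay.png")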


gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil", label="Upload image"),
    outputs=[gr.Image(label="Classes"), gr.Image(label="Overlay")],
    title=title,
    description=description,
    article=article,
    examples=examples,
).launch()