import numpy as np
import gradio as gr
from PIL import Image
import torch
from transformers import MobileViTFeatureExtractor, MobileViTForSemanticSegmentation
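
# MobileViT-small backbone with a DeepLabV3 segmentation head, fine-tuned on Pascal VOC.
# Note: recent transformers releases also expose this preprocessor as MobileViTImageProcessor;
# the feature extractor class still works but is deprecated.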
model_checkpoint = "apple/deeplabv3-mobilevit-small"
feature_extractor = MobileViTFeatureExtractor.from_pretrained(model_checkpoint)
model = MobileViTForSemanticSegmentation.from_pretrained(model_checkpoint).eval()
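
# One RGB color per Pascal VOC class, indexed by class id.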
palette = np.array(
    [
        [ 0, 0, 0], [192, 0, 0], [ 0, 192, 0], [192, 192, 0],
        [ 0, 0, 192], [192, 0, 192], [ 0, 192, 192], [192, 192, 192],
        [128, 0, 0], [255, 0, 0], [128, 192, 0], [255, 192, 0],
        [128, 0, 192], [255, 0, 192], [128, 192, 192], [255, 192, 192],
        [ 0, 128, 0], [192, 128, 0], [ 0, 255, 0], [192, 255, 0],
        [ 0, 128, 192]
    ],
    dtype=np.uint8)
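
# Pascal VOC class names, in the same order as the palette.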
labels = [
    "background",
    "aeroplane",
    "bicycle",
    "bird",
    "boat",
    "bottle",
    "bus",
    "car",
    "cat",
    "chair",
    "cow",
    "diningtable",
    "dog",
    "horse",
    "motorbike",
    "person",
    "pottedplant",
    "sheep",
    "sofa",
    "train",
    "tvmonitor",
]
# Render the class labels as colored HTML badges: classes whose palette color is
# dark get white text, the rest get black text.
inverted = [ 0, 1, 4, 5, 8, 9, 12, 13, 16, 17, 20 ]  # indices of the dark palette colors
labels_colored = []
for i, label in enumerate(labels):
    r, g, b = palette[i]
    color = "white" if i in inverted else "black"
    text = f"<span style='background-color: rgb({r}, {g}, {b}); color: {color}; padding: 2px 4px;'>{label}</span>"
    labels_colored.append(text)
labels_text = " ".join(labels_colored)
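
# Static text shown in the Gradio interface.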
title = "Semantic Segmentation with MobileViT and DeepLabV3"
description = """
The input image is resized and center cropped to 512Γ—512 pixels. The segmentation output is 32Γ—32 pixels.<br>
This model has been trained on <a href="http://host.robots.ox.ac.uk/pascal/VOC/">Pascal VOC</a>.
The classes are:
""" + labels_text + "</p>"
article = """
<div style='margin:20px auto;'>
<p>Sources:<p>
<p>πŸ“œ <a href="https://arxiv.org/abs/2110.02178">MobileViT: Light-weight, General-purpose, and Mobile-friendly Vision Transformer</a></p>
<p>πŸ‹οΈ Original pretrained weights from <a href="https://github.com/apple/ml-cvnets">this GitHub repo</a></p>
<p>πŸ™ Example images from <a href="https://huggingface.co/datasets/mishig/sample_images">this dataset</a><p>
</div>
"""
examples = [
    ["cat-3.jpg"],
    ["construction-site.jpg"],
    ["dog-cat.jpg"],
    ["football-match.jpg"],
]
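
# Run the model on one image and return the color-coded class map plus an
# overlay of the predicted foreground on the (preprocessed) input image.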
def predict(image):
    with torch.no_grad():
        # The feature extractor resizes and center-crops the image to 512×512 and
        # flips the channel order to BGR, which this checkpoint expects.
        inputs = feature_extractor(image, return_tensors="pt")
        outputs = model(**inputs)

    # Recover a displayable version of the preprocessed image: flip BGR back to RGB
    # and rescale to 0-255. The pixel values don't need to be unnormalized for this
    # particular model (no mean/std normalization is applied).
    resized = (inputs["pixel_values"].numpy().squeeze().transpose(1, 2, 0)[..., ::-1] * 255).astype(np.uint8)

    # Class prediction for each pixel (argmax over the 21 class logits).
    classes = outputs.logits.argmax(1).squeeze().numpy().astype(np.uint8)
    # Map each class id to its palette color (vectorized lookup).
    colored = palette[classes]
    # Resize predictions to the preprocessed input size (not the original image size).
    colored = Image.fromarray(colored)
    colored = colored.resize((resized.shape[1], resized.shape[0]), resample=Image.Resampling.NEAREST)

    # Keep everything that is not background.
    mask = (classes != 0) * 255
    mask = Image.fromarray(mask.astype(np.uint8)).convert("RGB")
    mask = mask.resize((resized.shape[1], resized.shape[0]), resample=Image.Resampling.NEAREST)

    # Blend the mask with the input image to highlight the detected objects.
    resized = Image.fromarray(resized)
    highlighted = Image.blend(resized, mask, 0.4)

    #colored = colored.resize((256, 256), resample=Image.Resampling.BICUBIC)
    #highlighted = highlighted.resize((256, 256), resample=Image.Resampling.BICUBIC)

    return colored, highlighted
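
# Build and launch the Gradio demo.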
gr.Interface(
    fn=predict,
    inputs=gr.Image(label="Upload image"),
    outputs=[gr.Image(label="Classes"), gr.Image(label="Overlay")],
    title=title,
    description=description,
    article=article,
    examples=examples,
).launch()