import numpy as np
import gradio as gr
from PIL import Image
import torch
from transformers import MobileViTFeatureExtractor, MobileViTForSemanticSegmentation
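
# MobileViT-small backbone with a DeepLabV3 segmentation head, fine-tuned on
# PASCAL VOC (21 classes). The feature extractor is configured to resize inputs
# to 512x512 without center cropping.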
model_checkpoint = "apple/deeplabv3-mobilevit-small"
feature_extractor = MobileViTFeatureExtractor.from_pretrained(model_checkpoint, do_center_crop=False, size=(512, 512))
model = MobileViTForSemanticSegmentation.from_pretrained(model_checkpoint).eval()
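
# Standard PASCAL VOC color map: entry i is the RGB color used to draw class i.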
# From https://gist.github.com/kaixin96/457cc3d3be699f1f5b2fd4cdb638d4b4
palette = np.array([
    [  0,   0,   0], [128,   0,   0], [  0, 128,   0], [128, 128,   0], [  0,   0, 128],
    [128,   0, 128], [  0, 128, 128], [128, 128, 128], [ 64,   0,   0], [192,   0,   0],
    [ 64, 128,   0], [192, 128,   0], [ 64,   0, 128], [192,   0, 128], [ 64, 128, 128],
    [192, 128, 128], [  0,  64,   0], [128,  64,   0], [  0, 192,   0], [128, 192,   0],
    [  0,  64, 128]], dtype=np.uint8)

def predict(image):
    with torch.no_grad():
        inputs = feature_extractor(image, return_tensors="pt")
        outputs = model(**inputs)
        # Per-pixel class index: argmax over the class dimension of the logits.
        classes = outputs.logits.argmax(1).squeeze().numpy().astype(np.uint8)

    # Super slow method but it works
    colored = np.zeros((classes.shape[0], classes.shape[1], 3), dtype=np.uint8)
    for y in range(classes.shape[0]):
        for x in range(classes.shape[1]):
            colored[y, x] = palette[classes[y, x]]
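
    # The same lookup can also be done in one vectorized step with NumPy fancy
    # indexing, which avoids the Python-level double loop:
    #     colored = palette[classes]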

    # TODO: overlay mask on image?

    # Resize the color mask back to the resolution of the uploaded image
    # (the Gradio input is a NumPy array of shape (height, width, 3)).
    out_image = Image.fromarray(colored)
    out_image = out_image.resize((image.shape[1], image.shape[0]), resample=Image.NEAREST)
    return out_image

# Note: gr.inputs / gr.outputs is the legacy Gradio component namespace;
# recent Gradio releases use gr.Image() for both the input and the output.
gr.Interface(
    fn=predict,
    inputs=gr.inputs.Image(label="Upload image"),
    outputs=gr.outputs.Image(),
    title="Semantic Segmentation with MobileViT and DeepLabV3",
).launch()
# TODO: combo box with some example images
# TODO: combo box with classes to show on the output, if none then do argmax