import numpy as np
import gradio as gr
from PIL import Image
import torch
from transformers import MobileViTFeatureExtractor, MobileViTForSemanticSegmentation
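
# Load the MobileViT + DeepLabV3 checkpoint and put the model in eval mode.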
model_checkpoint = "apple/deeplabv3-mobilevit-small"
feature_extractor = MobileViTFeatureExtractor.from_pretrained(model_checkpoint) #, do_center_crop=False, size=(512, 512))
model = MobileViTForSemanticSegmentation.from_pretrained(model_checkpoint).eval()
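
# A color for each of the model's 21 classes (index 0 is the background).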
palette = np.array(
    [
        [  0,   0,   0], [192,   0,   0], [  0, 192,   0], [192, 192,   0],
        [  0,   0, 192], [192,   0, 192], [  0, 192, 192], [192, 192, 192],
        [128,   0,   0], [255,   0,   0], [128, 192,   0], [255, 192,   0],
        [128,   0, 192], [255,   0, 192], [128, 192, 192], [255, 192, 192],
        [  0, 128,   0], [192, 128,   0], [  0, 255,   0], [192, 255,   0],
        [  0, 128, 192],
    ],
    dtype=np.uint8)


def predict(image):
    with torch.no_grad():
        inputs = feature_extractor(image, return_tensors="pt")
        outputs = model(**inputs)

    # Get the preprocessed image. The pixel values don't need to be unnormalized
    # for this particular model, but the feature extractor flips the channel order
    # to BGR, so flip it back to RGB for display.
    resized = (inputs["pixel_values"].numpy().squeeze().transpose(1, 2, 0)[..., ::-1] * 255).astype(np.uint8)

    # Class predictions for each pixel.
    classes = outputs.logits.argmax(1).squeeze().numpy().astype(np.uint8)
    # Map each predicted class index to its palette color.
    colored = palette[classes]
    # Resize predictions to input size (not original size).
    colored = Image.fromarray(colored)
    colored = colored.resize((resized.shape[1], resized.shape[0]), resample=Image.NEAREST)

    # Keep everything that is not background.
    mask = (classes != 0) * 255
    mask = Image.fromarray(mask.astype(np.uint8)).convert("RGB")
    mask = mask.resize((resized.shape[1], resized.shape[0]), resample=Image.NEAREST)

    # Blend with the input image.
    resized = Image.fromarray(resized)
    highlighted = Image.blend(resized, mask, 0.4)

    return colored, highlighted


gr.Interface(
    fn=predict,
    inputs=gr.inputs.Image(label="Upload image"),
    outputs=[gr.outputs.Image(label="Classes"), gr.outputs.Image(label="Highlighted")],
    title="Semantic Segmentation with MobileViT and DeepLabV3",
).launch()
# TODO: combo box with some example images
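
# One possible way to do this (a sketch, not the author's implementation):
# gr.Interface accepts an `examples` argument, which Gradio renders as a row of
# clickable example inputs. The file names below are hypothetical placeholders.
#
#   gr.Interface(
#       fn=predict,
#       inputs=gr.inputs.Image(label="Upload image"),
#       outputs=[gr.outputs.Image(label="Classes"), gr.outputs.Image(label="Highlighted")],
#       title="Semantic Segmentation with MobileViT and DeepLabV3",
#       examples=[["examples/cat.jpg"], ["examples/street.jpg"]],
#   ).launch()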