|
import torch |
|
import random |
|
import gradio as gr |
|
import numpy as np |
|
from transformers import MaskFormerFeatureExtractor, MaskFormerForInstanceSegmentation |
|
|
|
|
|
|
|
# Prefer the GPU when PyTorch can see one; otherwise run on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# The model and its preprocessor are loaded from the same Hub checkpoint.
CHECKPOINT = "facebook/maskformer-swin-tiny-ade"

model = MaskFormerForInstanceSegmentation.from_pretrained(CHECKPOINT).to(device)
# Inference only: put the module in evaluation mode.
model.eval()

preprocessor = MaskFormerFeatureExtractor.from_pretrained(CHECKPOINT)
|
|
|
|
|
def visualize_instance_seg_mask(mask):
    """Render a per-pixel label mask as an RGB image with one random colour per label.

    Args:
        mask: 2-D array of per-pixel labels, shape ``(H, W)``.

    Returns:
        A float64 array of shape ``(H, W, 3)`` with values in ``[0, 1]``.
        Every pixel that shares a label gets the same colour.
    """
    mask = np.asarray(mask)

    # ``inverse`` gives, for each pixel, the index of its label in ``labels``.
    labels, inverse = np.unique(mask, return_inverse=True)

    # One random colour per distinct label. All three channels are drawn from
    # the full 0-255 range; reshape keeps a (0, 3) palette for an empty mask.
    palette = np.array(
        [
            (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255))
            for _ in labels
        ],
        dtype=np.float64,
    ).reshape(-1, 3)

    # Vectorised lookup instead of a per-pixel Python loop. The inverse is
    # reshaped explicitly because NumPy 1.x returns it flattened while
    # NumPy 2.x returns it with the input's shape.
    image = palette[inverse.reshape(mask.shape)]

    return image / 255
|
|
|
def query_image(img):
    """Segment one image with MaskFormer and return a colourised label map.

    Args:
        img: The input image, presumably an ``(H, W, C)`` NumPy array as
            delivered by ``gr.Image()``, or ``None`` when the input is empty.

    Returns:
        The RGB float image produced by ``visualize_instance_seg_mask`` for the
        per-pixel argmax labels, or ``None`` when ``img`` is ``None``.
    """
    if img is None:
        return None

    # Segmentation is post-processed back to the input's spatial size.
    target_size = (img.shape[0], img.shape[1])

    # The preprocessor returns CPU tensors; they must live on the same device
    # as the model, otherwise the forward pass fails on CUDA hosts.
    inputs = preprocessor(images=img, return_tensors="pt")
    inputs = {name: tensor.to(device) for name, tensor in inputs.items()}

    with torch.no_grad():
        outputs = model(**inputs)

    # Bring the logits back to the CPU before post-processing.
    outputs.class_queries_logits = outputs.class_queries_logits.cpu()
    outputs.masks_queries_logits = outputs.masks_queries_logits.cpu()

    # Take the first (only) image of the batch; presumably this is a
    # (num_labels, H, W) score map — TODO confirm against transformers docs.
    results = preprocessor.post_process_segmentation(
        outputs=outputs, target_size=target_size
    )[0].cpu().detach()

    # Pick the highest-scoring label for every pixel.
    results = torch.argmax(results, dim=0).numpy()

    return visualize_instance_seg_mask(results)
|
|
|
|
|
# HTML/Markdown text passed to gr.Interface as the demo description.
description = """
Gradio demo for <a href="https://huggingface.co/docs/transformers/main/en/model_doc/maskformer">MaskFormer</a>,
introduced in <a href="https://arxiv.org/abs/2107.06278">Per-Pixel Classification is Not All You Need for Semantic Segmentation
</a>.
\n\n"MaskFormer is a unified framework for panoptic, instance and semantic segmentation, trained across four popular datasets (ADE20K, Cityscapes, COCO, Mapillary Vistas)."
"""
|
|
|
# Sample inputs offered in the UI below the input component.
example_images = ["assets/test_image_35.png", "assets/test_image_82.png"]

# Single image in, single image out: the image is fed to query_image.
demo = gr.Interface(
    fn=query_image,
    inputs=[gr.Image()],
    outputs="image",
    title="MaskFormer Demo",
    description=description,
    examples=example_images,
)

demo.launch()