File size: 2,332 Bytes
d4dcd19
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
import torch
import random
import gradio as gr
import numpy as np
from transformers import MaskFormerFeatureExtractor, MaskFormerForInstanceSegmentation


# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# The model weights and the matching preprocessor come from the same checkpoint.
_CHECKPOINT = "facebook/maskformer-swin-tiny-ade"
model = MaskFormerForInstanceSegmentation.from_pretrained(_CHECKPOINT).to(device)
model.eval()  # inference only: put the modules into evaluation mode
preprocessor = MaskFormerFeatureExtractor.from_pretrained(_CHECKPOINT)


def visualize_instance_seg_mask(mask):
    """Render a label map as an RGB image with one random color per label.

    Args:
        mask: Integer array of per-pixel labels (2-D in this app).

    Returns:
        A float64 array of shape ``mask.shape + (3,)`` with values in [0, 1].
        Every pixel that shares a label gets the same color.
    """
    mask = np.asarray(mask)
    labels, inverse = np.unique(mask, return_inverse=True)
    # One random RGB triple per distinct label, all channels spanning 0-255.
    # (Previously the red channel was drawn from 0-1, so it was effectively
    # always black.) reshape(-1, 3) keeps the palette 2-D for an empty mask.
    palette = np.array(
        [[random.randint(0, 255) for _ in range(3)] for _ in labels],
        dtype=np.float64,
    ).reshape(-1, 3)
    # ``inverse`` gives each pixel's index into ``labels``. Its shape varies
    # between NumPy versions (flat vs. input-shaped), so reshape explicitly.
    image = palette[inverse.reshape(mask.shape)]
    return image / 255

def query_image(img):
    """Segment an image with MaskFormer and return a color-coded class map.

    Args:
        img: Image array from the Gradio ``Image`` input, or ``None`` when the
            user submits without an image. Presumably (height, width, 3)
            uint8 — TODO confirm against the Gradio version in use.

    Returns:
        The RGB float image produced by ``visualize_instance_seg_mask`` at
        the input's height and width, or ``None`` if no image was given.
    """
    if img is None:
        # Nothing to segment; return an empty output instead of failing on
        # ``img.shape``.
        return None

    target_size = (img.shape[0], img.shape[1])
    inputs = preprocessor(images=img, return_tensors="pt")
    # The model lives on ``device``. Its inputs must live there too, or the
    # forward pass fails with a device mismatch when running on CUDA.
    inputs = {name: tensor.to(device) for name, tensor in inputs.items()}

    with torch.no_grad():
        outputs = model(**inputs)

    # Post-processing is done on the CPU.
    outputs.class_queries_logits = outputs.class_queries_logits.cpu()
    outputs.masks_queries_logits = outputs.masks_queries_logits.cpu()
    # First (only) image of the batch; presumably (num_labels, H, W) per-class
    # scores resized to ``target_size`` — TODO confirm for the installed
    # transformers version.
    class_maps = preprocessor.post_process_segmentation(
        outputs=outputs, target_size=target_size
    )[0].cpu().detach()
    # Highest-scoring label along dim 0 for every pixel.
    label_map = torch.argmax(class_maps, dim=0).numpy()
    return visualize_instance_seg_mask(label_map)


# Markdown/HTML blurb shown above the demo UI.
description = """
Gradio demo for <a href="https://huggingface.co/docs/transformers/main/en/model_doc/maskformer">MaskFormer</a>,
introduced in <a href="https://arxiv.org/abs/2107.06278">Per-Pixel Classification is Not All You Need for Semantic Segmentation</a>.
\n\nMaskFormer is a mask classification model: it predicts a set of binary masks, each paired with a single class label,
so the same architecture can be used for both semantic and panoptic segmentation. This demo runs the
<code>facebook/maskformer-swin-tiny-ade</code> checkpoint (ADE20K) and shows a semantic segmentation map in which each
predicted class is drawn in a random color.
"""

# Example images offered in the UI for one-click testing.
example_images = ["assets/test_image_35.png", "assets/test_image_82.png"]

# A single image in, a single color-coded segmentation image out.
demo = gr.Interface(
    fn=query_image,
    inputs=[gr.Image()],
    outputs="image",
    title="MaskFormer Demo",
    description=description,
    examples=example_images,
)
# Start the Gradio web app.
demo.launch()