File size: 1,813 Bytes
4fba7a2
f254011
 
4fba7a2
 
480594f
 
 
 
f254011
480594f
f254011
a49b93f
480594f
a49b93f
480594f
 
a49b93f
f254011
 
480594f
 
a49b93f
f254011
a49b93f
480594f
 
a49b93f
f254011
 
a49b93f
f254011
a49b93f
480594f
f254011
480594f
 
f254011
 
 
 
 
 
 
a49b93f
f254011
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
import gradio as gr
import numpy as np
from transformers import MaskFormerFeatureExtractor, MaskFormerForInstanceSegmentation


# One checkpoint name drives both the preprocessor and the model so the two
# can never drift apart. To try the larger variant, point this at
# "facebook/maskformer-swin-large-coco" instead.
_CHECKPOINT = "facebook/maskformer-swin-tiny-coco"

model = MaskFormerForInstanceSegmentation.from_pretrained(_CHECKPOINT)
feature_extractor = MaskFormerFeatureExtractor.from_pretrained(_CHECKPOINT)

def visualize_instance_seg_mask(mask, id2label):
    """Render a 2-D label mask as an RGB image with one random color per label.

    As a side effect, prints a dict mapping the name of every label present
    in the mask to the fraction of pixels it covers.

    Args:
        mask: 2-D NumPy array holding one label id per pixel.
        id2label: Mapping from label id to label name. It must contain every
            id that occurs in ``mask``.

    Returns:
        Float array of shape ``(height, width, 3)`` with values in ``[0, 1]``.
        All pixels that share a label id share a color.

    Raises:
        KeyError: If ``mask`` contains an id that is missing from ``id2label``.
    """
    height, width = mask.shape[0], mask.shape[1]
    total_pixels = height * width

    # A single pass yields the distinct ids, each pixel's index into them,
    # and the per-id pixel counts, replacing a Python-level per-pixel loop.
    label_ids, inverse, counts = np.unique(
        mask, return_inverse=True, return_counts=True
    )
    # The shape of the inverse index differs between NumPy versions
    # (flat vs. input-shaped), so normalize it explicitly.
    inverse = inverse.reshape(height, width)

    n_labels = len(label_ids)
    # One random color per label.
    # NOTE(review): the red channel is drawn from {0, 1} only (upper bound 2,
    # not 256), so colors are effectively green/blue mixes. Kept as-is to
    # preserve the existing palette; confirm whether 256 was intended.
    palette = np.column_stack(
        (
            np.random.randint(0, 2, size=n_labels),
            np.random.randint(0, 256, size=n_labels),
            np.random.randint(0, 256, size=n_labels),
        )
    )

    # Look up each pixel's color and scale 0..255 channel values to 0..1.
    image = palette[inverse] / 255

    # .tolist() gives plain Python numbers so the printed dict stays readable.
    label2count = {
        id2label[label_id]: count / total_pixels
        for label_id, count in zip(label_ids.tolist(), counts.tolist())
    }
    print(label2count)

    return image


def query_image(img):
    """Segment an image and return a color-coded visualization of the result.

    Args:
        img: Image array whose first two dimensions are height and width, or
            ``None`` when no image was supplied (presumably what the UI
            passes when the form is submitted empty).

    Returns:
        The image produced by ``visualize_instance_seg_mask`` for the
        predicted segmentation map, or ``None`` if ``img`` is ``None``.
    """
    if img is None:
        # Nothing to segment; return an empty result instead of failing on
        # ``img.shape`` below.
        return None

    # Imported here because the rest of the module does not use torch directly.
    import torch

    target_size = (img.shape[0], img.shape[1])
    inputs = feature_extractor(images=img, return_tensors="pt")

    # Pure inference: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        outputs = model(**inputs)

    # Ask for the map at the input's own height/width; a single image was
    # passed, so take the first (only) entry of the returned list.
    segmentation = feature_extractor.post_process_semantic_segmentation(
        outputs=outputs, target_sizes=[target_size]
    )[0]
    return visualize_instance_seg_mask(segmentation.numpy(), model.config.id2label)


# Web UI: one uploaded image in, the color-coded segmentation image out.
demo = gr.Interface(
    query_image,
    inputs=[gr.Image()],
    outputs="image",
    # The title names the checkpoint this script actually loads (swin-tiny);
    # it previously advertised swin-large. Keep it in sync if that changes.
    title="maskformer-swin-tiny-coco results",
    allow_flagging="never",
    analytics_enabled=None,
)

demo.launch(show_api=False)