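"""Gradio demo that segments an image with MaskFormer and reports the share of
pixels belonging to vegetation ("tree-merged", "grass-merged"), plus a blended
overlay highlighting those regions in green."""
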
import glob

import gradio as gr
import numpy as np
import torch
from PIL import Image
from transformers import MaskFormerForInstanceSegmentation, MaskFormerImageProcessor


example_images = sorted(glob.glob('examples/map*.jpg'))

model_id = f"facebook/maskformer-swin-large-coco"
vegetation_labels = ["tree-merged", "grass-merged"]

preprocessor = MaskFormerImageProcessor.from_pretrained(model_id)
model = MaskFormerForInstanceSegmentation.from_pretrained(model_id)
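
# Optional sketch (assumption: a CUDA-capable PyTorch build); the demo runs on
# CPU as written. The inputs in query_image would need the same .to(device) move:
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# model.to(device)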


def visualize_instance_seg_mask(img_in, mask, id2label, included_labels):
    """Overlay the segmentation mask on the input image and tally pixels per label."""
    img_out = np.zeros((mask.shape[0], mask.shape[1], 3))
    image_total_pixels = mask.shape[0] * mask.shape[1]
    label_ids = np.unique(mask)

    def get_color(label_id):
        # Vegetation labels get a fixed green; everything else gets a random
        # dark, bluish color so the vegetation overlay stands out.
        if id2label[label_id] in included_labels:
            return (0, 140, 0)
        return (np.random.randint(0, 2), np.random.randint(0, 4), np.random.randint(0, 256))

    id2color = {label_id: get_color(label_id) for label_id in label_ids}
    id2count = {label_id: 0 for label_id in label_ids}

    # Paint every pixel with its label's color and count pixels per label.
    for i in range(img_out.shape[0]):
        for j in range(img_out.shape[1]):
            img_out[i, j, :] = id2color[mask[i, j]]
            id2count[mask[i, j]] += 1

    # 50/50 blend of the original image and the colored mask.
    image_res = (0.5 * img_in + 0.5 * img_out).astype(np.uint8)

    vegetation_count = sum(id2count[label_id] for label_id in label_ids if id2label[label_id] in included_labels)

    # Each row: label, share of image pixels, and the side length of a square
    # with the same area (sqrt of the pixel fraction, displayed in "m").
    dataframe_vegetation_items = [[
        f"{id2label[label_id]}",
        f"{(100 * id2count[label_id] / image_total_pixels):.2f} %",
        f"{np.sqrt(id2count[label_id] / image_total_pixels):.2f} m"
        ] for label_id in label_ids if id2label[label_id] in included_labels]
    dataframe_all_items = [[
        f"{id2label[label_id]}",
        f"{(100 * id2count[label_id] / image_total_pixels):.2f} %",
        f"{np.sqrt(id2count[label_id] / image_total_pixels):.2f} m"
        ] for label_id in label_ids]
    dataframe_vegetation_total = [[
        "vegetation",
        f"{(100 * vegetation_count / image_total_pixels):.2f} %",
        f"{np.sqrt(vegetation_count / image_total_pixels):.2f} m"]]

    # The UI shows only the aggregated vegetation row; the per-label tables
    # above are kept so they can easily be swapped in instead. (The previous
    # empty-row fallback was dead code: this list always has exactly one row.)
    dataframe = dataframe_vegetation_total

    return image_res, dataframe


def query_image(image_path):
    img = np.array(Image.open(image_path).convert("RGB"))  # force 3 channels
    img_size = (img.shape[0], img.shape[1])
    inputs = preprocessor(images=img, return_tensors="pt")
    with torch.no_grad():  # inference only; no gradients needed
        outputs = model(**inputs)
    # Resize the predicted mask back to the original image resolution.
    results = preprocessor.post_process_semantic_segmentation(outputs=outputs, target_sizes=[img_size])[0]
    mask_img, dataframe = visualize_instance_seg_mask(img, results.numpy(), model.config.id2label, vegetation_labels)
    return mask_img, dataframe
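
# Sketch of calling the pipeline without the UI ("examples/map1.jpg" is a
# hypothetical file name matching the glob above):
#   overlay, info = query_image("examples/map1.jpg")
#   Image.fromarray(overlay).save("vegetation_overlay.png")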


demo = gr.Interface(
    title="Maskformer (large-coco)",
    description="Using [facebook/maskformer-swin-large-coco](https://huggingface.co/facebook/maskformer-swin-large-coco) model to calculate percentage of pixels in an image that belong to vegetation.",

    fn=query_image,
    inputs=[gr.Image(type="filepath", label="Input Image")],
    outputs=[
        gr.Image(label="Vegetation"),
        gr.DataFrame(label="Info", headers=["Object Label", "Pixel Percent", "Square Length"])
    ],

    examples=example_images,
    cache_examples=True,

    allow_flagging="never",
    analytics_enabled=False
)

demo.launch(show_api=False)
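
# When run directly (e.g. `python app.py`, the conventional file name on
# Hugging Face Spaces), Gradio serves the UI on http://localhost:7860 by default.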