import gradio as gr
import numpy as np
import tensorflow as tf
from huggingface_hub import from_pretrained_keras
from PIL import Image

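# The `@v1667722548` suffix pins the model repo to a specific tagged revision
# on the Hugging Face Hub, keeping the demo reproducible.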
MODEL_CKPT = "chansung/segmentation-training-pipeline@v1667722548"
MODEL = from_pretrained_keras(MODEL_CKPT)

RESOLUTION = 128  # input resolution expected by the model (128x128)

PETS_PALETTE = []
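# palette.txt stores one "[r, g, b]" color per class index, one list per
# line; lines containing "#" are skipped as comments.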
with open(r"./palette.txt", "r") as fp:
    for line in fp:
        if "#" not in line:
            tmp_list = list(map(int, line[:-1].strip("][").split(", ")))
            PETS_PALETTE.append(tmp_list)


def preprocess_input(image: Image.Image) -> tf.Tensor:
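    """Resize the image to the model resolution, scale to [0, 1], and add a batch dimension."""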
    image = np.array(image)
    image = tf.convert_to_tensor(image)

    image = tf.image.resize(image, (RESOLUTION, RESOLUTION))
    image = image / 255.0  # scale pixel values to [0, 1]

    return tf.expand_dims(image, 0)


# The utility get_seg_overlay() below is adapted from:
# https://github.com/deep-diver/semantic-segmentation-ml-pipeline/blob/main/notebooks/inference_from_SavedModel.ipynb


def get_seg_overlay(image, seg):
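    """Blend a normalized input image 50/50 with a color-coded segmentation mask."""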
    color_seg = np.zeros(
        (seg.shape[0], seg.shape[1], 3), dtype=np.uint8
    )  # height, width, 3
    palette = np.array(PETS_PALETTE)

    for label, color in enumerate(palette):
        color_seg[seg == label, :] = color

    # Show image + mask. The input image is normalized to [0, 1] by
    # preprocess_input(), so scale it back to [0, 255] before blending it
    # 50/50 with the color mask; otherwise the mask saturates the blend.
    img = np.array(image) * 255 * 0.5 + color_seg * 0.5
    img = np.clip(img, 0, 255)
    img = img.astype(np.uint8)
    return img


def run_model(image: Image.Image) -> tf.Tensor:
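    """Run the segmentation model and return a 2D tensor of per-pixel class indices."""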
    preprocessed_image = preprocess_input(image)
    prediction = MODEL.predict(preprocessed_image)

    seg_mask = tf.math.argmax(prediction, -1)
    seg_mask = tf.squeeze(seg_mask)
    return seg_mask


def get_predictions(image: Image.Image):
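    """Preprocess the image, run the model, and overlay the predicted mask.

    Note: run_model() already preprocesses the image; it is preprocessed again
    here so the overlay is drawn on the same resized, normalized input the
    model saw.
    """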
    predicted_segmentation_mask = run_model(image)
    preprocessed_image = preprocess_input(image)
    preprocessed_image = tf.squeeze(preprocessed_image, 0)

    pred_img = get_seg_overlay(
        preprocessed_image.numpy(), predicted_segmentation_mask.numpy()
    )
    return Image.fromarray(pred_img)


title = (
    "Simple demo for a semantic segmentation model trained on the PETS dataset."
)

description = """

Note that the outputs of this demo won't be state-of-the-art: the underlying project focuses on the ops side of deploying a semantic
segmentation model rather than on raw model quality. For more details, check out the repository: https://github.com/deep-diver/semantic-segmentation-ml-pipeline/.

"""

demo = gr.Interface(
    get_predictions,
    gr.Image(type="pil"),
    "pil",
    allow_flagging="never",
    title=title,
    description=description,
    examples=[
        ["test-image1.png"],
        ["test-image2.png"],
        ["test-image3.png"],
        ["test-image4.png"],
        ["test-image5.png"],
    ],
)

demo.launch()