File size: 3,848 Bytes
62cf953
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
import gradio as gr
import matplotlib.pyplot as plt
import tensorflow as tf

# Restore the SavedModel stored in the local "model/" directory at import
# time, then rebind the same name to its "serving_default" signature so the
# rest of the module can invoke `loaded_model(...)` directly.
# NOTE(review): after the rebinding below, this module no longer holds a
# reference to the object returned by `load()`; confirm the TensorFlow
# version in use keeps the restored variables alive through the signature.
loaded_model = tf.saved_model.load("model/")
loaded_model = loaded_model.signatures["serving_default"]


def get_target_shape(original_shape):
    """Choose the supported resolution closest to the image's aspect ratio.

    The ratio ``original_shape[0] / original_shape[1]`` is compared with the
    ratios of three fixed shapes: square ``(320, 320)``, landscape
    ``(240, 320)`` and portrait ``(320, 240)``. The shape with the smallest
    absolute difference is returned. On a tie, square is preferred over
    landscape and landscape over portrait.

    Args:
        original_shape: Sequence whose first two entries are used as the
            numerator and denominator of the aspect ratio.

    Returns:
        One of the three shape tuples listed above.

    Raises:
        ZeroDivisionError: If ``original_shape[1]`` is an integer zero.
    """
    aspect_ratio = original_shape[0] / original_shape[1]

    # Listed in priority order: the first shape matching the minimum wins.
    candidates = ((320, 320), (240, 320), (320, 240))
    gaps = [abs(aspect_ratio - height / width) for height, width in candidates]
    smallest_gap = min(gaps)

    # Only the first two candidates are tested explicitly; the last one is
    # the unconditional fallback.
    for shape, gap in zip(candidates[:-1], gaps):
        if gap == smallest_gap:
            return shape
    return candidates[-1]


def preprocess_input(input_image, target_shape):
    """Resize an image to fit ``target_shape`` and pad it to exactly that size.

    A leading batch axis is added, the image is resized with its aspect
    ratio preserved, and the remaining space along the two spatial axes is
    filled by ``tf.pad`` (default padding mode). The padding is split as
    evenly as possible; when the amount is odd, the extra row/column goes
    after the image.

    Args:
        input_image: Image with three axes; after the batch axis is added,
            axes 1 and 2 are treated as vertical and horizontal.
        target_shape: ``(height, width)`` the padded tensor should have.

    Returns:
        A tuple ``(tensor, vertical_padding, horizontal_padding)`` where the
        padding entries are ``[before, after]`` lists for each spatial axis.
    """
    batch = tf.image.resize(
        tf.expand_dims(input_image, axis=0),
        target_shape,
        preserve_aspect_ratio=True,
    )

    # Axis 1 pairs with target_shape[0] (vertical), axis 2 with
    # target_shape[1] (horizontal).
    paddings = []
    for axis, target_size in enumerate(target_shape, start=1):
        missing = target_size - batch.shape[axis]
        before = missing // 2
        paddings.append([before, missing - before])

    vertical_padding, horizontal_padding = paddings

    # The batch and channel axes are left unpadded.
    batch = tf.pad(
        batch, [[0, 0], vertical_padding, horizontal_padding, [0, 0]]
    )

    return batch, vertical_padding, horizontal_padding


def postprocess_output(
    output_tensor, vertical_padding, horizontal_padding, original_shape
):
    """Turn a padded model output into a colour image of the original size.

    The padded border is cropped away, the result is resized to
    ``original_shape``, singleton axes are squeezed out, and the values are
    mapped through matplotlib's ``inferno`` colormap. Only the first three
    channels of the colormap output are kept.

    Args:
        output_tensor: Model output with four axes; axes 1 and 2 are treated
            as vertical and horizontal.
        vertical_padding: ``[before, after]`` padding to strip from axis 1.
        horizontal_padding: ``[before, after]`` padding to strip from axis 2.
        original_shape: ``(height, width)`` to resize the cropped map to.

    Returns:
        The colour-mapped array with its last axis truncated to three
        channels.
    """
    top, bottom = vertical_padding[0], vertical_padding[1]
    left, right = horizontal_padding[0], horizontal_padding[1]
    padded_height = output_tensor.shape[1]
    padded_width = output_tensor.shape[2]

    # Undo the padding that was added around the image before inference.
    cropped = output_tensor[
        :, top : padded_height - bottom, left : padded_width - right, :
    ]
    restored = tf.image.resize(cropped, original_shape)

    saliency = restored.numpy().squeeze()
    return plt.cm.inferno(saliency)[..., :3]


def compute_saliency(input_image, alpha=0.65):
    """Predict a saliency map for an image and blend it over that image.

    The image is preprocessed to a supported resolution, passed through the
    loaded model, and the ``"output"`` entry of the result is converted back
    to a colour map of the original size. That map is then mixed with the
    input image, which is divided by 255 first.

    Args:
        input_image: Array-like image exposing ``.shape``, or ``None``.
            Presumably holds 0-255 pixel values, given the division by 255
            (TODO confirm against the caller).
        alpha: Weight of the saliency map in the blend; the input image gets
            ``1 - alpha``.

    Returns:
        The blended image, or ``None`` when ``input_image`` is ``None``.
    """
    if input_image is None:
        return None

    image_shape = input_image.shape[:2]
    model_input, pad_rows, pad_cols = preprocess_input(
        input_image, get_target_shape(image_shape)
    )

    prediction = loaded_model(model_input)["output"]
    heatmap = postprocess_output(prediction, pad_rows, pad_cols, image_shape)

    return alpha * heatmap + (1 - alpha) * input_image / 255


# Image paths (relative to the working directory) passed to the interface
# as its `examples` argument.
examples = [
    "examples/kirsten-frank-o1sXiz_LU1A-unsplash.jpg",
    "examples/oscar-fickel-F5ze5FkEu1g-unsplash.jpg",
    "examples/ting-tian-_79ZJS8pV70-unsplash.jpg",
    "examples/gina-domenique-LmrAUrHinqk-unsplash.jpg",
    "examples/robby-mccullough-r05GkQBcaPM-unsplash.jpg",
]

# Gradio interface: one image input, one image output, wired to
# `compute_saliency`. Only the image is supplied as an input component, so
# `alpha` keeps its default value. Flagging is disabled.
demo = gr.Interface(
    fn=compute_saliency,
    inputs=gr.Image(label="Input Image"),
    outputs=gr.Image(label="Saliency Map"),
    examples=examples,
    title="Visual Saliency Prediction",
    description="A demo to predict where humans fixate an image using a deep learning model trained on eye movement data. Upload an image file, take a snapshot from your webcam, or paste an image from the clipboard to compute the saliency map.",
    article="For more information on the model, check out [GitHub](https://github.com/alexanderkroner/saliency) and the corresponding [paper](https://www.sciencedirect.com/science/article/pii/S0893608020301660).",
    allow_flagging="never",
)

# Start the app (with request queuing enabled) only when run as a script.
if __name__ == "__main__":
    demo.queue().launch()