ohjho committed on
Commit
70f6db8
1 Parent(s): b99a21f

added st app for testing

Files changed (5)
  1. README.md +96 -1
  2. app.py +101 -0
  3. download.py +475 -0
  4. requirements.txt +6 -0
  5. run_gradio.py +126 -0
README.md CHANGED
@@ -8,5 +8,100 @@ sdk_version: 1.9.0
  app_file: app.py
  pinned: false
  ---

- Check out the configuration reference at https://huggingface.co/docs/hub/spaces#reference
+ # Saliency Based Image Cropping
+
+ This repo was forked by the [Miro team](https://miro.io/#) to create the interface [here]()
+
+ # Contextual Encoder-Decoder Network <br/> for Visual Saliency Prediction
+
+ ![](https://img.shields.io/badge/python-v3.6.8-orange.svg?style=flat-square)
+ ![](https://img.shields.io/badge/tensorflow-v1.13.1-orange.svg?style=flat-square)
+ ![](https://img.shields.io/badge/matplotlib-v3.0.3-orange.svg?style=flat-square)
+ ![](https://img.shields.io/badge/requests-v2.21.0-orange.svg?style=flat-square)
+
+ <img src="./figures/results.jpg" width="800"/>
+
+ This repository contains the official *TensorFlow* implementation of the MSI-Net (multi-scale information network), as described in the Neural Networks paper [Contextual encoder-decoder network for visual saliency prediction](https://www.sciencedirect.com/science/article/pii/S0893608020301660) (2020) and on [arXiv](https://arxiv.org/abs/1902.06634).
+
+ **_Abstract:_** *Predicting salient regions in natural images requires the detection of objects that are present in a scene. To develop robust representations for this challenging task, high-level visual features at multiple spatial scales must be extracted and augmented with contextual information. However, existing models aimed at explaining human fixation maps do not incorporate such a mechanism explicitly. Here we propose an approach based on a convolutional neural network pre-trained on a large-scale image classification task. The architecture forms an encoder-decoder structure and includes a module with multiple convolutional layers at different dilation rates to capture multi-scale features in parallel. Moreover, we combine the resulting representations with global scene information for accurately predicting visual saliency. Our model achieves competitive and consistent results across multiple evaluation metrics on two public saliency benchmarks and we demonstrate the effectiveness of the suggested approach on five datasets and selected examples. Compared to state of the art approaches, the network is based on a lightweight image classification backbone and hence presents a suitable choice for applications with limited computational resources, such as (virtual) robotic systems, to estimate human fixations across complex natural scenes.*
+
+ Our results are available on the original [MIT saliency benchmark](http://saliency.mit.edu/results.html) and the updated [MIT/Tübingen saliency benchmark](https://saliency.tuebingen.ai/results.html). The latter are derived from a probabilistic version of our model with metric-specific postprocessing for a fair model comparison.
+
+ ## Reference
+
+ If you use this code in your research, please cite the following paper:
+
+ ```
+ @article{kroner2020contextual,
+   title={Contextual encoder-decoder network for visual saliency prediction},
+   author={Kroner, Alexander and Senden, Mario and Driessens, Kurt and Goebel, Rainer},
+   url={http://www.sciencedirect.com/science/article/pii/S0893608020301660},
+   doi={https://doi.org/10.1016/j.neunet.2020.05.004},
+   journal={Neural Networks},
+   publisher={Elsevier},
+   year={2020},
+   volume={129},
+   pages={261--270},
+   issn={0893-6080}
+ }
+ ```
+
+ ## Architecture
+
+ <img src="./figures/architecture.jpg" width="700"/>
+
+ ## Requirements
+
+ | Package    | Version |
+ |:----------:|:-------:|
+ | python     | 3.6.8   |
+ | tensorflow | 1.13.1  |
+ | matplotlib | 3.0.3   |
+ | requests   | 2.21.0  |
+ | scipy      | 1.4.1   |
+
+ The code was tested and is compatible with both Windows and Linux. We strongly recommend using *TensorFlow* with GPU acceleration, especially when training the model. Nevertheless, a slower CPU version is officially supported.
+
+ ## Training
+
+ The results of our paper can be reproduced by first training the MSI-Net via the following command:
+
+ ```
+ python main.py train
+ ```
+
+ This will start the training procedure for the SALICON dataset with the hyperparameters defined in `config.py`. If you want to optimize the model for CPU usage, please change the corresponding `device` value in the configuration file. Optionally, the dataset and download path can be specified via command line arguments:
+
+ ```
+ python main.py train -d DATA -p PATH
+ ```
+
+ Here, the `DATA` argument must be `salicon`, `mit1003`, `cat2000`, `dutomron`, `pascals`, `osie`, or `fiwi`. The model must first be trained on the SALICON dataset before it can be fine-tuned on any of the other ones. By default, the selected saliency dataset will be downloaded to the folder `data/`, but you can point to a different directory via the `PATH` argument.
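
For example, a typical sequence (using the documented default `data/` directory) would pre-train on SALICON and then fine-tune on MIT1003:

```
python main.py train -d salicon -p data/
python main.py train -d mit1003 -p data/
```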
+
+ All results are then stored under the folder `results/`, which contains the training history and model checkpoints. This makes it possible to continue training or to perform inference on test instances, as described in the next section.
+
+ ## Testing
+
+ To test a pre-trained model on image data and produce saliency maps, execute the following command:
+
+ ```
+ python main.py test -d DATA -p PATH
+ ```
+
+ If no checkpoint is available from prior training, it will automatically download our pre-trained model to `weights/`. The `DATA` argument defines which network is used and must be `salicon`, `mit1003`, `cat2000`, `dutomron`, `pascals`, `osie`, or `fiwi`. It will then resize the input images to the dimensions specified in the configuration file. Note that this might lead to excessive image padding depending on the selected dataset.
+
+ The `PATH` argument points to the folder where the test data is stored but can also denote a single image file directly. As with network training, the `device` value can be changed to CPU in the configuration file. This ensures that the model optimized for CPU will be utilized and hence improves the inference speed. All results are finally stored in the folder `results/images/` with the original image dimensions.
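
For instance, to produce a saliency map for a single image with the MIT1003 network (the file path below is illustrative):

```
python main.py test -d mit1003 -p /path/to/image.jpg
```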
+
+ ## Demo
+
+ <img src="./demo/demo.gif" width="750"/>
+
+ A demonstration of saliency prediction in the browser is available [here](https://storage.googleapis.com/msi-net/demo/index.html). It computes saliency maps based on the input from a webcam via *TensorFlow.js*. Since the library uses the machine's hardware, model performance is dependent on your local configuration. The buttons allow you to select the quality, ranging from *very low* for a version trained on low image resolution with high inference speed, to *very high* for a version trained on high image resolution with slow inference speed.
+
+ ## Contact
+
+ For questions, bug reports, and suggestions about this work, please create an [issue](https://github.com/alexanderkroner/saliency/issues) in this repository.
app.py ADDED
@@ -0,0 +1,101 @@
+ import streamlit as st
+
+ import os, sys, io
+ import warnings
+ import urllib.request as urllib
+ import numpy as np
+ from PIL import Image, ImageColor
+
+ from run_gradio import load_model, test_model
+
+ ### Some Utils Functions ###
+ def get_image(st_asset = st.sidebar, as_np_arr = False, extension_list = ['jpg', 'jpeg', 'png']):
+     image_url, image_fh = None, None
+     if st_asset.checkbox('use image URL?'):
+         image_url = st_asset.text_input("Enter Image URL")
+     else:
+         image_fh = st_asset.file_uploader(label = "Upload your image", type = extension_list)
+
+     im = None
+     if image_url:
+         response = urllib.urlopen(image_url)
+         im = Image.open(io.BytesIO(bytearray(response.read())))
+     elif image_fh:
+         im = Image.open(image_fh)
+
+     if im and as_np_arr:
+         im = np.array(im)
+     return im
+
+ def show_miro_logo(use_column_width = False, width = 100, st_asset = st.sidebar):
+     logo_url = 'https://miro.medium.com/max/1400/0*qLL-32srlq6Y_iTm.png'
+     st_asset.image(logo_url, use_column_width = use_column_width, channels = 'BGR', output_format = 'PNG', width = width)
+
+ def im_draw_bbox(pil_im, x0, y0, x1, y1, color = 'black', width = 3, caption = None,
+                  bbv_label_only = False):
+     '''
+     draw a bounding box on the input image pil_im and return the annotated image
+     Args:
+         color: color name as read by Pillow.ImageColor
+         bbv_label_only: only draw a caption flag via bbox_visualizer
+     '''
+     import bbox_visualizer as bbv
+     if any(isinstance(i, float) for i in [x0, y0, x1, y1]):
+         warnings.warn('im_draw_bbox: at least one of x0, y0, x1, y1 is of the type float and is converted to int.')
+     x0 = int(x0)
+     y0 = int(y0)
+     x1 = int(x1)
+     y1 = int(y1)
+
+     if bbv_label_only:
+         if caption:
+             im_array = bbv.draw_flag_with_label(np.array(pil_im),
+                                                 label = caption,
+                                                 bbox = [x0, y0, x1, y1],
+                                                 line_color = ImageColor.getrgb(color),
+                                                 text_bg_color = ImageColor.getrgb(color)
+                                                 )
+         else:
+             raise ValueError('im_draw_bbox: bbv_label_only is True but caption is None')
+     else:
+         im_array = bbv.draw_rectangle(np.array(pil_im),
+                                       bbox = [x0, y0, x1, y1],
+                                       bbox_color = ImageColor.getrgb(color),
+                                       thickness = width
+                                       )
+         im_array = bbv.add_label(im_array, label = caption,
+                                  bbox = [x0, y0, x1, y1],
+                                  text_bg_color = ImageColor.getrgb(color)
+                                  ) if caption else im_array
+     return Image.fromarray(im_array)
+
+ ### Streamlit App ###
+
+ def Main(model_dict):
+     st.set_page_config(layout = 'wide')
+     show_miro_logo()
+     with st.sidebar.expander('Saliency Demo'):
+         st.info('''
+         [TensorFlow Implementation of MSI-Net](https://github.com/alexanderkroner/saliency)
+         which achieved
+         [SoTA performance](https://saliency.tuebingen.ai/results.html) on the
+         [MIT Saliency Benchmark dataset](http://saliency.mit.edu/datasets.html)
+         ''')
+
+     im = get_image(st_asset = st.sidebar.expander('Input Image', expanded = True), extension_list = ['jpg', 'jpeg'])
+     aspect_ratio = st.sidebar.selectbox('aspect ratio', help = 'to demo saliency cropping',
+                                         options = ['', '16x9', '4x3'])
+     if im:
+         aspect_ratio_tup = tuple([int(i) for i in aspect_ratio.split('x')]) if aspect_ratio else None
+         saliency_im = test_model(np.array(im), model_dict = model_dict,
+                                  aspect_ratio_tup = aspect_ratio_tup)
+
+         l_col, r_col = st.columns(2)
+         l_col.image(im, caption = 'Input Image')
+         r_col.image(saliency_im, caption = 'Saliency Map')
+     else:
+         st.warning(':point_left: please provide an image')
+
+
+ if __name__ == '__main__':
+     model_dict = load_model()
+     Main(model_dict = model_dict)
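
Assuming the pinned dependencies from `requirements.txt` are installed, the app above should run locally via the standard Streamlit CLI:

```
streamlit run app.py
```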
download.py ADDED
@@ -0,0 +1,475 @@
+ import io
+ import os
+ import zipfile
+
+ import gdown
+ import h5py
+ import numpy as np
+ import requests
+ from matplotlib.pyplot import imread, imsave
+ from scipy.io import loadmat
+ from scipy.ndimage import gaussian_filter
+
+
+ def download_salicon(data_path):
+     """Downloads the SALICON dataset. Three folders are then created that
+     contain the stimuli, binary fixation maps, and blurred saliency
+     distributions respectively.
+
+     Args:
+         data_path (str): Defines the path where the dataset will be
+             downloaded and extracted to.
+
+     .. seealso:: The code for downloading files from Google Drive is based
+         on the solution provided at [https://bit.ly/2JSVgMQ].
+     """
+
+     print(">> Downloading SALICON dataset...", end="", flush=True)
+
+     default_path = data_path + "salicon/"
+     fixations_path = default_path + "fixations/"
+     saliency_path = default_path + "saliency/"
+
+     os.makedirs(fixations_path, exist_ok=True)
+     os.makedirs(saliency_path, exist_ok=True)
+
+     ids = ["1g8j-hTT-51IG1UFwP0xTGhLdgIUCW5e5",
+            "1P-jeZXCsjoKO79OhFUgnj6FGcyvmLDPj",
+            "1PnO7szbdub1559LfjYHMy65EDC4VhJC8"]
+
+     urls = ["https://drive.google.com/uc?id=" +
+             i + "&export=download" for i in ids]
+
+     save_paths = [default_path, fixations_path, saliency_path]
+
+     for count, url in enumerate(urls):
+         gdown.download(url, data_path + "tmp.zip", quiet=True)
+
+         with zipfile.ZipFile(data_path + "tmp.zip", "r") as zip_ref:
+             for file in zip_ref.namelist():
+                 if "test" not in file:
+                     zip_ref.extract(file, save_paths[count])
+
+     os.rename(default_path + "images", default_path + "stimuli")
+
+     os.remove(data_path + "tmp.zip")
+
+     print("done!", flush=True)
+
+
+ def download_mit1003(data_path):
+     """Downloads the MIT1003 dataset. Three folders are then created that
+     contain the stimuli, binary fixation maps, and blurred saliency
+     distributions respectively.
+
+     Args:
+         data_path (str): Defines the path where the dataset will be
+             downloaded and extracted to.
+     """
+
+     print(">> Downloading MIT1003 dataset...", end="", flush=True)
+
+     default_path = data_path + "mit1003/"
+     stimuli_path = default_path + "stimuli/"
+     fixations_path = default_path + "fixations/"
+     saliency_path = default_path + "saliency/"
+
+     os.makedirs(stimuli_path, exist_ok=True)
+     os.makedirs(fixations_path, exist_ok=True)
+     os.makedirs(saliency_path, exist_ok=True)
+
+     url = "https://people.csail.mit.edu/tjudd/WherePeopleLook/ALLSTIMULI.zip"
+
+     with open(data_path + "tmp.zip", "wb") as f:
+         f.write(requests.get(url).content)
+
+     with zipfile.ZipFile(data_path + "tmp.zip", "r") as zip_ref:
+         for file in zip_ref.namelist():
+             if file.endswith(".jpeg"):
+                 file_name = os.path.split(file)[1]
+                 file_path = stimuli_path + file_name
+
+                 with open(file_path, "wb") as stimulus:
+                     stimulus.write(zip_ref.read(file))
+
+     url = "https://people.csail.mit.edu/tjudd/WherePeopleLook/ALLFIXATIONMAPS.zip"
+
+     with open(data_path + "tmp.zip", "wb") as f:
+         f.write(requests.get(url).content)
+
+     with zipfile.ZipFile(data_path + "tmp.zip", "r") as zip_ref:
+         for file in zip_ref.namelist():
+             file_name = os.path.split(file)[1]
+
+             if file.endswith("Pts.jpg"):
+                 file_path = fixations_path + file_name
+
+                 # this file is mistakenly included in the dataset and can be ignored
+                 if file_name == "i05june05_static_street_boston_p1010764fixPts.jpg":
+                     continue
+
+                 with open(file_path, "wb") as fixations:
+                     fixations.write(zip_ref.read(file))
+
+             elif file.endswith("Map.jpg"):
+                 file_path = saliency_path + file_name
+
+                 with open(file_path, "wb") as saliency:
+                     saliency.write(zip_ref.read(file))
+
+     os.remove(data_path + "tmp.zip")
+
+     print("done!", flush=True)
+
+
+ def download_cat2000(data_path):
+     """Downloads the CAT2000 dataset. Three folders are then created that
+     contain the stimuli, binary fixation maps, and blurred saliency
+     distributions respectively.
+
+     Args:
+         data_path (str): Defines the path where the dataset will be
+             downloaded and extracted to.
+     """
+
+     print(">> Downloading CAT2000 dataset...", end="", flush=True)
+
+     default_path = data_path + "cat2000/"
+
+     os.makedirs(data_path, exist_ok=True)
+
+     url = "http://saliency.mit.edu/trainSet.zip"
+
+     with open(data_path + "tmp.zip", "wb") as f:
+         f.write(requests.get(url).content)
+
+     with zipfile.ZipFile(data_path + "tmp.zip", "r") as zip_ref:
+         for file in zip_ref.namelist():
+             if not ("Output" in file or "allFixData" in file):
+                 zip_ref.extract(file, data_path)
+
+     os.rename(data_path + "trainSet/", default_path)
+
+     os.rename(default_path + "Stimuli", default_path + "stimuli")
+     os.rename(default_path + "FIXATIONLOCS", default_path + "fixations")
+     os.rename(default_path + "FIXATIONMAPS", default_path + "saliency")
+
+     os.remove(data_path + "tmp.zip")
+
+     print("done!", flush=True)
+
+
+ def download_dutomron(data_path):
+     """Downloads the DUT-OMRON dataset. Three folders are then created that
+     contain the stimuli, binary fixation maps, and blurred saliency
+     distributions respectively.
+
+     Args:
+         data_path (str): Defines the path where the dataset will be
+             downloaded and extracted to.
+     """
+
+     print(">> Downloading DUTOMRON dataset...", end="", flush=True)
+
+     default_path = data_path + "dutomron/"
+     stimuli_path = default_path + "stimuli/"
+     fixations_path = default_path + "fixations/"
+     saliency_path = default_path + "saliency/"
+
+     os.makedirs(stimuli_path, exist_ok=True)
+     os.makedirs(fixations_path, exist_ok=True)
+     os.makedirs(saliency_path, exist_ok=True)
+
+     url = "http://saliencydetection.net/dut-omron/download/DUT-OMRON-image.zip"
+
+     with open(data_path + "tmp.zip", "wb") as f:
+         f.write(requests.get(url).content)
+
+     with zipfile.ZipFile(data_path + "tmp.zip", "r") as zip_ref:
+         for file in zip_ref.namelist():
+             if file.endswith(".jpg") and "._" not in file:
+                 file_name = os.path.basename(file)
+                 file_path = stimuli_path + file_name
+
+                 with open(file_path, "wb") as stimulus:
+                     stimulus.write(zip_ref.read(file))
+
+     url = "http://saliencydetection.net/dut-omron/download/DUT-OMRON-eye-fixations.zip"
+
+     with open(data_path + "tmp.zip", "wb") as f:
+         f.write(requests.get(url).content)
+
+     with zipfile.ZipFile(data_path + "tmp.zip", "r") as zip_ref:
+         for file in zip_ref.namelist():
+             if file.endswith(".mat") and "._" not in file:
+                 file_name = os.path.basename(file)
+                 file_name = os.path.splitext(file_name)[0] + ".png"
+
+                 loaded_zip = io.BytesIO(zip_ref.read(file))
+
+                 fixations = loadmat(loaded_zip)["s"]
+                 sorted_idx = fixations[:, 2].argsort()
+                 fixations = fixations[sorted_idx]
+
+                 size = fixations[0, :2]
+
+                 fixations_map = np.zeros((size[1], size[0]))
+
+                 fixations_map[fixations[1:, 1],
+                               fixations[1:, 0]] = 1
+
+                 saliency_map = gaussian_filter(fixations_map, 16)
+
+                 imsave(saliency_path + file_name, saliency_map, cmap="gray")
+                 imsave(fixations_path + file_name, fixations_map, cmap="gray")
+
+     os.remove(data_path + "tmp.zip")
+
+     print("done!", flush=True)
+
+
+ def download_pascals(data_path):
+     """Downloads the PASCAL-S dataset. Three folders are then created that
+     contain the stimuli, binary fixation maps, and blurred saliency
+     distributions respectively.
+
+     Args:
+         data_path (str): Defines the path where the dataset will be
+             downloaded and extracted to.
+     """
+
+     print(">> Downloading PASCALS dataset...", end="", flush=True)
+
+     default_path = data_path + "pascals/"
+     stimuli_path = default_path + "stimuli/"
+     fixations_path = default_path + "fixations/"
+     saliency_path = default_path + "saliency/"
+
+     os.makedirs(stimuli_path, exist_ok=True)
+     os.makedirs(fixations_path, exist_ok=True)
+     os.makedirs(saliency_path, exist_ok=True)
+
+     url = "http://cbs.ic.gatech.edu/salobj/download/salObj.zip"
+
+     with open(data_path + "tmp.zip", "wb") as f:
+         f.write(requests.get(url).content)
+
+     with zipfile.ZipFile(data_path + "tmp.zip", "r") as zip_ref:
+         for file in zip_ref.namelist():
+             file_name = os.path.basename(file)
+
+             if file.endswith(".jpg") and "imgs/pascal" in file:
+                 file_path = stimuli_path + file_name
+
+                 with open(file_path, "wb") as stimulus:
+                     stimulus.write(zip_ref.read(file))
+
+             elif file.endswith(".png") and "pascal/humanFix" in file:
+                 file_path = saliency_path + file_name
+
+                 with open(file_path, "wb") as saliency:
+                     saliency.write(zip_ref.read(file))
+
+             elif "pascalFix.mat" in file:
+                 loaded_zip = io.BytesIO(zip_ref.read(file))
+
+                 with h5py.File(loaded_zip, "r") as f:
+                     fixations = np.array(f.get("fixCell"))[0]
+
+                     fixations_list = []
+
+                     for reference in fixations:
+                         obj = np.array(f[reference])
+                         obj = np.stack((obj[0], obj[1]), axis=-1)
+                         fixations_list.append(obj)
+
+             elif "pascalSize.mat" in file:
+                 loaded_zip = io.BytesIO(zip_ref.read(file))
+
+                 with h5py.File(loaded_zip, "r") as f:
+                     sizes = np.array(f.get("sizeData"))
+                     sizes = np.transpose(sizes, (1, 0))
+
+     for idx, value in enumerate(fixations_list):
+         size = [int(x) for x in sizes[idx]]
+         fixations_map = np.zeros(size)
+
+         for fixation in value:
+             fixations_map[int(fixation[0]) - 1,
+                           int(fixation[1]) - 1] = 1
+
+         file_name = str(idx + 1) + ".png"
+         file_path = fixations_path + file_name
+
+         imsave(file_path, fixations_map, cmap="gray")
+
+     os.remove(data_path + "tmp.zip")
+
+     print("done!", flush=True)
+
+
+ def download_osie(data_path):
+     """Downloads the OSIE dataset. Three folders are then created that
+     contain the stimuli, binary fixation maps, and blurred saliency
+     distributions respectively.
+
+     Args:
+         data_path (str): Defines the path where the dataset will be
+             downloaded and extracted to.
+     """
+
+     print(">> Downloading OSIE dataset...", end="", flush=True)
+
+     default_path = data_path + "osie/"
+     stimuli_path = default_path + "stimuli/"
+     fixations_path = default_path + "fixations/"
+     saliency_path = default_path + "saliency/"
+
+     os.makedirs(stimuli_path, exist_ok=True)
+     os.makedirs(fixations_path, exist_ok=True)
+     os.makedirs(saliency_path, exist_ok=True)
+
+     url = "https://github.com/NUS-VIP/predicting-human-gaze-beyond-pixels/archive/master.zip"
+
+     with open(data_path + "tmp.zip", "wb") as f:
+         f.write(requests.get(url).content)
+
+     with zipfile.ZipFile(data_path + "tmp.zip", "r") as zip_ref:
+         for file in zip_ref.namelist():
+             file_name = os.path.basename(file)
+
+             if file.endswith(".jpg") and "data/stimuli" in file:
+                 file_path = stimuli_path + file_name
+
+                 with open(file_path, "wb") as stimulus:
+                     stimulus.write(zip_ref.read(file))
+
+             elif file_name == "fixations.mat":
+                 loaded_zip = io.BytesIO(zip_ref.read(file))
+
+                 loaded_mat = loadmat(loaded_zip)["fixations"]
+
+                 for idx, value in enumerate(loaded_mat):
+                     subjects = value[0][0][0][1]
+
+                     fixations_map = np.zeros((600, 800))
+
+                     for subject in subjects:
+                         x_vals = subject[0][0][0][0][0]
+                         y_vals = subject[0][0][0][1][0]
+
+                         fixations = np.stack((y_vals, x_vals), axis=-1)
+                         fixations = fixations.astype(int)
+
+                         fixations_map[fixations[:, 0],
+                                       fixations[:, 1]] = 1
+
+                     file_name = str(1001 + idx) + ".png"
+
+                     saliency_map = gaussian_filter(fixations_map, 16)
+
+                     imsave(saliency_path + file_name, saliency_map, cmap="gray")
+                     imsave(fixations_path + file_name, fixations_map, cmap="gray")
+
+     os.remove(data_path + "tmp.zip")
+
+     print("done!", flush=True)
+
+
+ def download_fiwi(data_path):
+     """Downloads the FIWI dataset. Three folders are then created that
+     contain the stimuli, binary fixation maps, and blurred saliency
+     distributions respectively.
+
+     Args:
+         data_path (str): Defines the path where the dataset will be
+             downloaded and extracted to.
+     """
+
+     print(">> Downloading FIWI dataset...", end="", flush=True)
+
+     default_path = data_path + "fiwi/"
+     stimuli_path = default_path + "stimuli/"
+     fixations_path = default_path + "fixations/"
+     saliency_path = default_path + "saliency/"
+
+     os.makedirs(stimuli_path, exist_ok=True)
+     os.makedirs(fixations_path, exist_ok=True)
+     os.makedirs(saliency_path, exist_ok=True)
+
+     url = "https://www.dropbox.com/s/30nxg2uwd1wpb80/webpage_dataset.zip?dl=1"
+
+     with open(data_path + "tmp.zip", "wb") as f:
+         f.write(requests.get(url).content)
+
+     with zipfile.ZipFile(data_path + "tmp.zip", "r") as zip_ref:
+         for file in zip_ref.namelist():
+             file_name = os.path.basename(file)
+
+             if file.endswith(".png") and "stimuli" in file:
+                 file_path = stimuli_path + file_name
+
+                 with open(file_path, "wb") as stimulus:
+                     stimulus.write(zip_ref.read(file))
+
+             elif file.endswith(".png") and "all5" in file:
+                 loaded_zip = io.BytesIO(zip_ref.read(file))
+
+                 fixations = imread(loaded_zip)
+                 saliency = gaussian_filter(fixations, 30)
+
+                 imsave(saliency_path + file_name, saliency, cmap="gray")
+                 imsave(fixations_path + file_name, fixations, cmap="gray")
+
+     os.remove(data_path + "tmp.zip")
+
+     print("done!", flush=True)
+
+
+ def download_pretrained_weights(data_path, key):
+     """Downloads the pre-trained weights for the VGG16 model when
+     training or the MSI-Net when testing on new data instances.
+
+     Args:
+         data_path (str): Defines the path where the weights will be
+             downloaded and extracted to.
+         key (str): Describes the type of model for which the weights will
+             be downloaded. This contains the device and dataset.
+
+     .. seealso:: The code for downloading files from Google Drive is based
+         on the solution provided at [https://bit.ly/2JSVgMQ].
+     """
+
+     print(">> Downloading pre-trained weights...", end="", flush=True)
+
+     os.makedirs(data_path, exist_ok=True)
+
+     ids = {
+         "vgg16_hybrid": "1ff0va472Xs1bvidCwRlW3Ctf7Hbyyn7p",
+         "model_salicon_cpu": "1Xy9C72pcA8DO4CY0rc6B7wsuE9L9DDZY",
+         "model_salicon_gpu": "1Th7fqVYx25ePMZz4LYsjNQWgAu8tJqwL",
+         "model_mit1003_cpu": "1jsESjYtsTvkMqKftA4rdstfB7mSYw5Ec",
+         "model_mit1003_gpu": "1P_tWxBl3igZlzcHGp5H3T3kzsOskWeG6",
+         "model_cat2000_cpu": "1XxaEx7xxD6rHasQTa-VY7T7eVpGhMxuV",
+         "model_cat2000_gpu": "1T6ChEGB6Mf02gKXrENjdeD6XXJkE_jHh",
+         "model_dutomron_cpu": "14tuRZpKi8LMDKRHNVUylu6RuAaXLjHTa",
+         "model_dutomron_gpu": "15LG_M45fpYC1pTwnwmArNTZw_Z3BOIA-",
+         "model_pascals_cpu": "1af9IvBqFamKWx64Ror6ALivuKNioOVIf",
+         "model_pascals_gpu": "1C-T-RQzX2SaiY9Nw1HmaSx6syyCt01Z0",
+         "model_osie_cpu": "1JD1tvAqZGxj_gEGmIfoxb9dTe5HOaHj1",
+         "model_osie_gpu": "1g8UPr1hGpUdOSWerRb751pZqiWBOZOCh",
+         "model_fiwi_cpu": "19qj9nAjd5gVHLB71oRn_YfYDw5n4Uf2X",
+         "model_fiwi_gpu": "12OpIMIi2IyDVaxkE2d37XO9uUsSYf1Ec"
+     }
+
+     url = "https://drive.google.com/uc?id=" + ids[key] + "&export=download"
+
+     gdown.download(url, data_path + "tmp.zip", quiet=True)
+
+     with zipfile.ZipFile(data_path + "tmp.zip", "r") as zip_ref:
+         for file in zip_ref.namelist():
+             zip_ref.extract(file, data_path)
+
+     os.remove(data_path + "tmp.zip")
+
+     print("done!", flush=True)
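
A brief usage sketch for these helpers, with illustrative paths; the `model_mit1003_cpu` key matches the default that `run_gradio.load_model` requests:

```
import download

# fetch a dataset into data/mit1003/{stimuli,fixations,saliency}/
download.download_mit1003("data/")

# fetch the frozen CPU graph that load_model expects under weights/
download.download_pretrained_weights("weights/", "model_mit1003_cpu")
```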
requirements.txt ADDED
@@ -0,0 +1,6 @@
+ tensorflow==1.13.1
+ protobuf==3.19.0
+ matplotlib==3.0.3
+ requests==2.21.0
+ scipy==1.4.1
+ streamlit==0.89.0
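
These pins target the TensorFlow 1.x graph API used in `run_gradio.py`; installing them together into a fresh virtual environment is the safest path:

```
pip install -r requirements.txt
```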
run_gradio.py ADDED
@@ -0,0 +1,126 @@
+ from PIL import Image, ImageDraw
+ import gradio as gr
+ import numpy as np
+ import tensorflow as tf
+ import download, os, sys
+
+
+ def best_window(saliency, aspect_ratio=(16, 9)):
+     """ returns left, right, bottom, top
+     saliency is np.array with shape (height, width)
+     aspect_ratio is tuple of (width, height)
+     """
+     orig_height, orig_width = saliency.shape
+     move_vertically = orig_height >= orig_width / aspect_ratio[0] * \
+         aspect_ratio[1]
+     if move_vertically:
+         saliency_per_row = np.sum(saliency, axis=1)
+         height = round(orig_width / aspect_ratio[0] * aspect_ratio[1])
+         convolved_saliency = np.convolve(saliency_per_row, np.ones(height),
+                                          "valid")
+         max_row = np.argmax(convolved_saliency)
+         return 0, orig_width, max_row, max_row + height
+     else:
+         saliency_per_col = np.sum(saliency, axis=0)
+         width = round(orig_height / aspect_ratio[1] * aspect_ratio[0])
+         convolved_saliency = np.convolve(saliency_per_col, np.ones(width),
+                                          "valid")
+         max_col = np.argmax(convolved_saliency)
+         return max_col, max_col + width, 0, orig_height
+
+
+ def overlay_saliency(img, sal_map, bbox=None):
+     background = img.convert("RGBA")
+     overlay = sal_map.convert("RGBA")
+     overlaid = Image.blend(background, overlay, 0.75)
+     draw = ImageDraw.Draw(overlaid)
+     if bbox:
+         draw.rectangle(
+             [bbox['left'], bbox['bottom'], bbox['right'], bbox['top']],
+             outline="orange", width=5)
+     return overlaid
+
+
+ def get_saliency_sum_box(crop_data, bounded, saliency):
+     left, right, bottom, top = int(crop_data["x"]), int(
+         crop_data["x"] + crop_data["width"]), int(crop_data["y"]), int(
+         crop_data["y"] + crop_data["height"])
+     sal_sum = np.sum(saliency[bottom:top, left:right])
+     total = np.sum(saliency)
+     pct_sal = round(100 * sal_sum / total, 2)
+     draw = ImageDraw.Draw(bounded)
+     draw.rectangle([left, bottom, right, top], outline="red", width=5)
+     return bounded, pct_sal
+
+
+ def test_model(im_arr, model_dict, aspect_ratio_tup = None):
+     # original_arr, crop_data = original_arr
+     # crop_data["original_height"] = original_arr.shape[0]
+     # crop_data["original_width"] = original_arr.shape[1]
+     original_img = Image.fromarray(im_arr).convert('RGB')
+     w, h = original_img.size
+     h_ = int(400 / w * h)
+     resized_img = original_img.resize((400, h_))
+     resized_arr = np.asarray(resized_img)
+
+     resized_arr = resized_arr[np.newaxis, ...]
+     saliency_arr = model_dict['sess'].run(model_dict['predicted_maps'],
+                                           feed_dict={
+                                               model_dict['input_plhd']: resized_arr
+                                           })
+     saliency_arr = saliency_arr.squeeze()
+
+     saliency_img = Image.fromarray(np.uint8(saliency_arr * 255), 'L')
+     saliency_resized_img = saliency_img.resize((w, h))
+
+     saliency_resized_arr = np.asarray(saliency_resized_img)
+     saliency_zero_one = np.divide(saliency_resized_arr, 255.0)
+
+     bbox = None
+     if aspect_ratio_tup:
+         left, right, bottom, top = best_window(saliency_resized_arr,
+                                                aspect_ratio=aspect_ratio_tup)
+         bbox = {'left': left, 'right': right, 'bottom': bottom, 'top': top}
+         # output = original_arr[bottom:top, left:right, :]
+
+     bounded = overlay_saliency(original_img, saliency_resized_img, bbox=bbox)
+     return bounded
+     # with_sal_box, pct_sal = get_saliency_sum_box(crop_data, bounded,
+     #                                              saliency_zero_one)
+     # sal_sum = str(pct_sal) + "%"
+     # return with_sal_box, sal_sum
+
+ def load_model(model_name = "weights/model_mit1003_cpu.pb"):
+     ### Model loading code
+     graph_def = tf.GraphDef()
+     if not os.path.isfile(model_name):
+         download.download_pretrained_weights('weights/', 'model_mit1003_cpu')
+
+     with tf.gfile.Open(model_name, "rb") as file:
+         graph_def.ParseFromString(file.read())
+     input_plhd = tf.placeholder(tf.float32, (None, None, None, 3))
+     [predicted_maps] = tf.import_graph_def(graph_def,
+                                            input_map={"input": input_plhd},
+                                            return_elements=["output:0"])
+
+     sess = tf.Session()
+     return {
+         'sess': sess,
+         'predicted_maps': predicted_maps,
+         'input_plhd': input_plhd
+     }
+
+ if __name__ == '__main__':
+     examples = [["images/1.jpg"],
+                 ["images/2.jpg"]]
+
+     thumbnail = "https://ibb.co/hXdbDyD"
+     model_dict = load_model()
+     io = gr.Interface(lambda im_arr: test_model(im_arr, model_dict),
+                       gr.inputs.Image(label="Your Image", tool='select'),
+                       gr.outputs.Image(label="Saliency Map"),
+                       allow_flagging=False,
+                       thumbnail=thumbnail,
+                       examples=examples, analytics_enabled=False)
+
+     io.launch(debug=True)
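
A minimal end-to-end sketch (the image path is illustrative). The synthetic check also shows what `best_window` computes: with all saliency mass in rows 80-89 of a 100x100 map and a 16:9 target, the window height is round(100 / 16 * 9) = 56 and the returned box is (0, 100, 34, 90):

```
import numpy as np
from PIL import Image
from run_gradio import best_window, load_model, test_model

# synthetic check of best_window: saliency concentrated in rows 80-89
sal = np.zeros((100, 100))
sal[80:90, :] = 1.0
print(best_window(sal, aspect_ratio=(16, 9)))  # -> (0, 100, 34, 90)

# full pipeline on one image; pre-trained weights are downloaded on first use
model_dict = load_model()
im_arr = np.array(Image.open("images/1.jpg").convert("RGB"))
overlaid = test_model(im_arr, model_dict, aspect_ratio_tup=(16, 9))
overlaid.save("overlaid_16x9.png")
```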