mbar0075 commited on
Commit
c9baa67
1 Parent(s): 2bd58f3

Testing Commit

This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. SaRa/__pycache__/pySaliencyMap.cpython-39.pyc +0 -0
  2. SaRa/__pycache__/pySaliencyMapDefs.cpython-39.pyc +0 -0
  3. SaRa/__pycache__/saraRC1.cpython-39.pyc +0 -0
  4. SaRa/pySaliencyMap.py +288 -0
  5. SaRa/pySaliencyMapDefs.py +74 -0
  6. SaRa/saraRC1.py +1082 -0
  7. app.py +154 -0
  8. deepgaze_pytorch/__init__.py +3 -0
  9. deepgaze_pytorch/__pycache__/__init__.cpython-39.pyc +0 -0
  10. deepgaze_pytorch/__pycache__/deepgaze1.cpython-39.pyc +0 -0
  11. deepgaze_pytorch/__pycache__/deepgaze2e.cpython-39.pyc +0 -0
  12. deepgaze_pytorch/__pycache__/deepgaze3.cpython-39.pyc +0 -0
  13. deepgaze_pytorch/__pycache__/layers.cpython-39.pyc +0 -0
  14. deepgaze_pytorch/__pycache__/modules.cpython-39.pyc +0 -0
  15. deepgaze_pytorch/data.py +403 -0
  16. deepgaze_pytorch/deepgaze1.py +42 -0
  17. deepgaze_pytorch/deepgaze2e.py +151 -0
  18. deepgaze_pytorch/deepgaze3.py +110 -0
  19. deepgaze_pytorch/features/__init__.py +0 -0
  20. deepgaze_pytorch/features/__pycache__/__init__.cpython-39.pyc +0 -0
  21. deepgaze_pytorch/features/__pycache__/alexnet.cpython-39.pyc +0 -0
  22. deepgaze_pytorch/features/__pycache__/densenet.cpython-39.pyc +0 -0
  23. deepgaze_pytorch/features/__pycache__/efficientnet.cpython-39.pyc +0 -0
  24. deepgaze_pytorch/features/__pycache__/normalizer.cpython-39.pyc +0 -0
  25. deepgaze_pytorch/features/__pycache__/resnext.cpython-39.pyc +0 -0
  26. deepgaze_pytorch/features/__pycache__/shapenet.cpython-39.pyc +0 -0
  27. deepgaze_pytorch/features/alexnet.py +18 -0
  28. deepgaze_pytorch/features/bagnet.py +192 -0
  29. deepgaze_pytorch/features/densenet.py +19 -0
  30. deepgaze_pytorch/features/efficientnet.py +31 -0
  31. deepgaze_pytorch/features/efficientnet_pytorch/__init__.py +10 -0
  32. deepgaze_pytorch/features/efficientnet_pytorch/__pycache__/__init__.cpython-39.pyc +0 -0
  33. deepgaze_pytorch/features/efficientnet_pytorch/__pycache__/model.cpython-39.pyc +0 -0
  34. deepgaze_pytorch/features/efficientnet_pytorch/__pycache__/utils.cpython-39.pyc +0 -0
  35. deepgaze_pytorch/features/efficientnet_pytorch/model.py +229 -0
  36. deepgaze_pytorch/features/efficientnet_pytorch/utils.py +335 -0
  37. deepgaze_pytorch/features/inception.py +20 -0
  38. deepgaze_pytorch/features/mobilenet.py +17 -0
  39. deepgaze_pytorch/features/normalizer.py +28 -0
  40. deepgaze_pytorch/features/resnet.py +44 -0
  41. deepgaze_pytorch/features/resnext.py +27 -0
  42. deepgaze_pytorch/features/shapenet.py +89 -0
  43. deepgaze_pytorch/features/squeezenet.py +17 -0
  44. deepgaze_pytorch/features/swav.py +20 -0
  45. deepgaze_pytorch/features/uninformative.py +26 -0
  46. deepgaze_pytorch/features/vgg.py +86 -0
  47. deepgaze_pytorch/features/vggnet.py +24 -0
  48. deepgaze_pytorch/features/wsl.py +27 -0
  49. deepgaze_pytorch/layers.py +427 -0
  50. deepgaze_pytorch/metrics.py +69 -0
SaRa/__pycache__/pySaliencyMap.cpython-39.pyc ADDED
Binary file (7.79 kB).
 
SaRa/__pycache__/pySaliencyMapDefs.cpython-39.pyc ADDED
Binary file (2.01 kB).
 
SaRa/__pycache__/saraRC1.cpython-39.pyc ADDED
Binary file (18.5 kB).
 
SaRa/pySaliencyMap.py ADDED
@@ -0,0 +1,288 @@
1
+ #-------------------------------------------------------------------------------
2
+ # Name: pySaliencyMap
3
+ # Purpose: Extracting a saliency map from a single still image
4
+ #
5
+ # Author: Akisato Kimura <akisato@ieee.org>
6
+ #
7
+ # Created: April 24, 2014
8
+ # Copyright: (c) Akisato Kimura 2014-
9
+ # Licence: All rights reserved
10
+ #-------------------------------------------------------------------------------
11
+
12
+ import cv2
13
+ import numpy as np
14
+ import SaRa.pySaliencyMapDefs as pySaliencyMapDefs
15
+ import time
16
+
17
+ class pySaliencyMap:
18
+ # initialization
19
+ def __init__(self, width, height):
20
+ self.width = width
21
+ self.height = height
22
+ self.prev_frame = None
23
+ self.SM = None
24
+ self.GaborKernel0 = np.array(pySaliencyMapDefs.GaborKernel_0)
25
+ self.GaborKernel45 = np.array(pySaliencyMapDefs.GaborKernel_45)
26
+ self.GaborKernel90 = np.array(pySaliencyMapDefs.GaborKernel_90)
27
+ self.GaborKernel135 = np.array(pySaliencyMapDefs.GaborKernel_135)
28
+
29
+ # extracting color channels
30
+ def SMExtractRGBI(self, inputImage):
31
+ # convert scale of array elements
32
+ src = np.float32(inputImage) * 1./255
33
+ # split
34
+ (B, G, R) = cv2.split(src)
35
+ # extract an intensity image
36
+ I = cv2.cvtColor(src, cv2.COLOR_BGR2GRAY)
37
+ # return
38
+ return R, G, B, I
39
+
40
+ # feature maps
41
+ ## constructing a Gaussian pyramid
42
+ def FMCreateGaussianPyr(self, src):
43
+ dst = list()
44
+ dst.append(src)
45
+ for i in range(1,9):
46
+ nowdst = cv2.pyrDown(dst[i-1])
47
+ dst.append(nowdst)
48
+ return dst
49
+ ## taking center-surround differences
50
+ def FMCenterSurroundDiff(self, GaussianMaps):
51
+ dst = list()
52
+ for s in range(2,5):
53
+ now_size = GaussianMaps[s].shape
54
+ now_size = (now_size[1], now_size[0]) ## (width, height)
55
+ tmp = cv2.resize(GaussianMaps[s+3], now_size, interpolation=cv2.INTER_LINEAR)
56
+ nowdst = cv2.absdiff(GaussianMaps[s], tmp)
57
+ dst.append(nowdst)
58
+ tmp = cv2.resize(GaussianMaps[s+4], now_size, interpolation=cv2.INTER_LINEAR)
59
+ nowdst = cv2.absdiff(GaussianMaps[s], tmp)
60
+ dst.append(nowdst)
61
+ return dst
62
+ ## constructing a Gaussian pyramid + taking center-surround differences
63
+ def FMGaussianPyrCSD(self, src):
64
+ GaussianMaps = self.FMCreateGaussianPyr(src)
65
+ dst = self.FMCenterSurroundDiff(GaussianMaps)
66
+ return dst
67
+ ## intensity feature maps
68
+ def IFMGetFM(self, I):
69
+ return self.FMGaussianPyrCSD(I)
70
+ ## color feature maps
71
+ def CFMGetFM(self, R, G, B):
72
+ # max(R,G,B)
73
+ tmp1 = cv2.max(R, G)
74
+ RGBMax = cv2.max(B, tmp1)
75
+ RGBMax[RGBMax <= 0] = 0.0001 # prevent dividing by 0
76
+ # min(R,G)
77
+ RGMin = cv2.min(R, G)
78
+ # RG = (R-G)/max(R,G,B)
79
+ RG = (R - G) / RGBMax
80
+ # BY = (B-min(R,G))/max(R,G,B)
81
+ BY = (B - RGMin) / RGBMax
82
+ # clamp negative values to 0
83
+ RG[RG < 0] = 0
84
+ BY[BY < 0] = 0
85
+ # obtain feature maps in the same way as intensity
86
+ RGFM = self.FMGaussianPyrCSD(RG)
87
+ BYFM = self.FMGaussianPyrCSD(BY)
88
+ # return
89
+ return RGFM, BYFM
90
+ ## orientation feature maps
91
+ def OFMGetFM(self, src):
92
+ # creating a Gaussian pyramid
93
+ GaussianI = self.FMCreateGaussianPyr(src)
94
+ # convolving a Gabor filter with an intensity image to extract orientation features
95
+ GaborOutput0 = [ np.empty((1,1)), np.empty((1,1)) ] # dummy data: any kinds of np.array()s are OK
96
+ GaborOutput45 = [ np.empty((1,1)), np.empty((1,1)) ]
97
+ GaborOutput90 = [ np.empty((1,1)), np.empty((1,1)) ]
98
+ GaborOutput135 = [ np.empty((1,1)), np.empty((1,1)) ]
99
+ for j in range(2,9):
100
+ GaborOutput0.append( cv2.filter2D(GaussianI[j], cv2.CV_32F, self.GaborKernel0) )
101
+ GaborOutput45.append( cv2.filter2D(GaussianI[j], cv2.CV_32F, self.GaborKernel45) )
102
+ GaborOutput90.append( cv2.filter2D(GaussianI[j], cv2.CV_32F, self.GaborKernel90) )
103
+ GaborOutput135.append( cv2.filter2D(GaussianI[j], cv2.CV_32F, self.GaborKernel135) )
104
+ # calculating center-surround differences for every orientation
105
+ CSD0 = self.FMCenterSurroundDiff(GaborOutput0)
106
+ CSD45 = self.FMCenterSurroundDiff(GaborOutput45)
107
+ CSD90 = self.FMCenterSurroundDiff(GaborOutput90)
108
+ CSD135 = self.FMCenterSurroundDiff(GaborOutput135)
109
+ # concatenate
110
+ dst = list(CSD0)
111
+ dst.extend(CSD45)
112
+ dst.extend(CSD90)
113
+ dst.extend(CSD135)
114
+ # return
115
+ return dst
116
+ ## motion feature maps
117
+ def MFMGetFM(self, src):
118
+ # convert scale
119
+ I8U = np.uint8(255 * src)
120
+ # cv2.waitKey(10)
121
+ # calculating optical flows
122
+ if self.prev_frame is not None:
123
+ farne_pyr_scale= pySaliencyMapDefs.farne_pyr_scale
124
+ farne_levels = pySaliencyMapDefs.farne_levels
125
+ farne_winsize = pySaliencyMapDefs.farne_winsize
126
+ farne_iterations = pySaliencyMapDefs.farne_iterations
127
+ farne_poly_n = pySaliencyMapDefs.farne_poly_n
128
+ farne_poly_sigma = pySaliencyMapDefs.farne_poly_sigma
129
+ farne_flags = pySaliencyMapDefs.farne_flags
130
+ flow = cv2.calcOpticalFlowFarneback(\
131
+ prev = self.prev_frame, \
132
+ next = I8U, \
133
+ pyr_scale = farne_pyr_scale, \
134
+ levels = farne_levels, \
135
+ winsize = farne_winsize, \
136
+ iterations = farne_iterations, \
137
+ poly_n = farne_poly_n, \
138
+ poly_sigma = farne_poly_sigma, \
139
+ flags = farne_flags, \
140
+ flow = None \
141
+ )
142
+ flowx = flow[...,0]
143
+ flowy = flow[...,1]
144
+ else:
145
+ flowx = np.zeros(I8U.shape)
146
+ flowy = np.zeros(I8U.shape)
147
+ # create Gaussian pyramids
148
+ dst_x = self.FMGaussianPyrCSD(flowx)
149
+ dst_y = self.FMGaussianPyrCSD(flowy)
150
+ # update the current frame
151
+ self.prev_frame = np.uint8(I8U)
152
+ # return
153
+ return dst_x, dst_y
154
+
155
+ # conspicuity maps
156
+ ## standard range normalization
157
+ def SMRangeNormalize(self, src):
158
+ minn, maxx, dummy1, dummy2 = cv2.minMaxLoc(src)
159
+ if maxx!=minn:
160
+ dst = src/(maxx-minn) + minn/(minn-maxx)
161
+ else:
162
+ dst = src - minn
163
+ return dst
164
+ ## computing an average of local maxima
165
+ def SMAvgLocalMax(self, src):
166
+ # size
167
+ stepsize = pySaliencyMapDefs.default_step_local
168
+ width = src.shape[1]
169
+ height = src.shape[0]
170
+ # find local maxima
171
+ numlocal = 0
172
+ lmaxmean = 0
173
+ for y in range(0, height-stepsize, stepsize):
174
+ for x in range(0, width-stepsize, stepsize):
175
+ localimg = src[y:y+stepsize, x:x+stepsize]
176
+ lmin, lmax, dummy1, dummy2 = cv2.minMaxLoc(localimg)
177
+ lmaxmean += lmax
178
+ numlocal += 1
179
+ # averaging over all the local regions (error checking for numlocal)
180
+ if numlocal==0:
181
+ return 0
182
+ else:
183
+ return lmaxmean / numlocal
184
+ ## normalization specific for the saliency map model
185
+ def SMNormalization(self, src):
186
+ dst = self.SMRangeNormalize(src)
187
+ lmaxmean = self.SMAvgLocalMax(dst)
188
+ normcoeff = (1-lmaxmean)*(1-lmaxmean)
189
+ return dst * normcoeff
190
+ ## normalizing feature maps
191
+ def normalizeFeatureMaps(self, FM):
192
+ NFM = list()
193
+ for i in range(0,6):
194
+ normalizedImage = self.SMNormalization(FM[i])
195
+ nownfm = cv2.resize(normalizedImage, (self.width, self.height), interpolation=cv2.INTER_LINEAR)
196
+ NFM.append(nownfm)
197
+ return NFM
198
+ ## intensity conspicuity map
199
+ def ICMGetCM(self, IFM):
200
+ NIFM = self.normalizeFeatureMaps(IFM)
201
+ ICM = sum(NIFM)
202
+ return ICM
203
+ ## color conspicuity map
204
+ def CCMGetCM(self, CFM_RG, CFM_BY):
205
+ # extracting a conspicuity map for every color opponent pair
206
+ CCM_RG = self.ICMGetCM(CFM_RG)
207
+ CCM_BY = self.ICMGetCM(CFM_BY)
208
+ # merge
209
+ CCM = CCM_RG + CCM_BY
210
+ # return
211
+ return CCM
212
+ ## orientation conspicuity map
213
+ def OCMGetCM(self, OFM):
214
+ OCM = np.zeros((self.height, self.width))
215
+ for i in range (0,4):
216
+ # slicing
217
+ nowofm = OFM[i*6:(i+1)*6] # angle = i*45
218
+ # extracting a conspicuity map for every angle
219
+ NOFM = self.ICMGetCM(nowofm)
220
+ # normalize
221
+ NOFM2 = self.SMNormalization(NOFM)
222
+ # accumulate
223
+ OCM += NOFM2
224
+ return OCM
225
+ ## motion conspicuity map
226
+ def MCMGetCM(self, MFM_X, MFM_Y):
227
+ return self.CCMGetCM(MFM_X, MFM_Y)
228
+
229
+ # core
230
+ def SMGetSM(self, src):
231
+ # definitions
232
+ size = src.shape
233
+ width = size[1]
234
+ height = size[0]
235
+ # check
236
+ # if(width != self.width or height != self.height):
237
+ # sys.exit("size mismatch")
238
+ # extracting individual color channels
239
+ R, G, B, I = self.SMExtractRGBI(src)
240
+ # extracting feature maps
241
+ IFM = self.IFMGetFM(I)
242
+ CFM_RG, CFM_BY = self.CFMGetFM(R, G, B)
243
+ OFM = self.OFMGetFM(I)
244
+ MFM_X, MFM_Y = self.MFMGetFM(I)
245
+ # extracting conspicuity maps
246
+ ICM = self.ICMGetCM(IFM)
247
+ CCM = self.CCMGetCM(CFM_RG, CFM_BY)
248
+ OCM = self.OCMGetCM(OFM)
249
+ MCM = self.MCMGetCM(MFM_X, MFM_Y)
250
+ # adding all the conspicuity maps to form a saliency map
251
+ wi = pySaliencyMapDefs.weight_intensity
252
+ wc = pySaliencyMapDefs.weight_color
253
+ wo = pySaliencyMapDefs.weight_orientation
254
+ wm = pySaliencyMapDefs.weight_motion
255
+ SMMat = wi*ICM + wc*CCM + wo*OCM + wm*MCM
256
+ # normalize
257
+ normalizedSM = self.SMRangeNormalize(SMMat)
258
+ normalizedSM2 = normalizedSM.astype(np.float32)
259
+ smoothedSM = cv2.bilateralFilter(normalizedSM2, 7, 3, 1.55)
260
+ self.SM = cv2.resize(smoothedSM, (width,height), interpolation=cv2.INTER_NEAREST)
261
+ # return
262
+ return self.SM
263
+
264
+ def SMGetBinarizedSM(self, src):
265
+ # get a saliency map
266
+ if self.SM is None:
267
+ self.SM = self.SMGetSM(src)
268
+ # convert scale
269
+ SM_I8U = np.uint8(255 * self.SM)
270
+ # binarize
271
+ thresh, binarized_SM = cv2.threshold(SM_I8U, thresh=0, maxval=255, type=cv2.THRESH_BINARY+cv2.THRESH_OTSU)
272
+ return binarized_SM
273
+
274
+ def SMGetSalientRegion(self, src):
275
+ # get a binarized saliency map
276
+ binarized_SM = self.SMGetBinarizedSM(src)
277
+ # GrabCut
278
+ img = src.copy()
279
+ mask = np.where((binarized_SM!=0), cv2.GC_PR_FGD, cv2.GC_PR_BGD).astype('uint8')
280
+ bgdmodel = np.zeros((1,65),np.float64)
281
+ fgdmodel = np.zeros((1,65),np.float64)
282
+ rect = (0,0,1,1) # dummy
283
+ iterCount = 1
284
+ cv2.grabCut(img, mask=mask, rect=rect, bgdModel=bgdmodel, fgdModel=fgdmodel, iterCount=iterCount, mode=cv2.GC_INIT_WITH_MASK)
285
+ # post-processing
286
+ mask_out = np.where((mask==cv2.GC_FGD) + (mask==cv2.GC_PR_FGD), 255, 0).astype('uint8')
287
+ output = cv2.bitwise_and(img,img,mask=mask_out)
288
+ return output
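For orientation, here is a minimal usage sketch of this class, mirroring how `saraRC1.py` below calls it; the image path is a placeholder and OpenCV/NumPy are assumed to be installed:

```python
import cv2
from SaRa.pySaliencyMap import pySaliencyMap

# Load any BGR image (placeholder path).
img = cv2.imread('example.jpg')
h, w = img.shape[:2]

# Initialise the model with the image dimensions and compute the saliency map.
sm = pySaliencyMap(w, h)
saliency = sm.SMGetSM(img)          # float map, roughly in [0, 1]

# Rescale to 0-255 for display, exactly as saraRC1.return_saliency does.
saliency_u8 = cv2.normalize(saliency, None, 255, 0, cv2.NORM_MINMAX, cv2.CV_8UC1)
cv2.imwrite('saliency.png', saliency_u8)
```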
SaRa/pySaliencyMapDefs.py ADDED
@@ -0,0 +1,74 @@
1
+ #-------------------------------------------------------------------------------
2
+ # Name: pySaliencyMapDefs
3
+ # Purpose: Definitions for class pySaliencyMap
4
+ #
5
+ # Author: Akisato Kimura <akisato@ieee.org>
6
+ #
7
+ # Created: April 24, 2014
8
+ # Copyright: (c) Akisato Kimura 2014-
9
+ # Licence: All rights reserved
10
+ #-------------------------------------------------------------------------------
11
+
12
+ # parameters for computing optical flows using Gunnar Farneback's algorithm
13
+ farne_pyr_scale = 0.5
14
+ farne_levels = 3
15
+ farne_winsize = 15
16
+ farne_iterations = 3
17
+ farne_poly_n = 5
18
+ farne_poly_sigma = 1.2
19
+ farne_flags = 0
20
+
21
+ # parameters for detecting local maxima
22
+ default_step_local = 16
23
+
24
+ # feature weights
25
+ weight_intensity = 0.30
26
+ weight_color = 0.30
27
+ weight_orientation = 0.20
28
+ weight_motion = 0.20
29
+
30
+ # coefficients of Gabor filters
31
+ GaborKernel_0 = [\
32
+ [ 1.85212E-06, 1.28181E-05, -0.000350433, -0.000136537, 0.002010422, -0.000136537, -0.000350433, 1.28181E-05, 1.85212E-06 ],\
33
+ [ 2.80209E-05, 0.000193926, -0.005301717, -0.002065674, 0.030415784, -0.002065674, -0.005301717, 0.000193926, 2.80209E-05 ],\
34
+ [ 0.000195076, 0.001350077, -0.036909595, -0.014380852, 0.211749204, -0.014380852, -0.036909595, 0.001350077, 0.000195076 ],\
35
+ [ 0.000624940, 0.004325061, -0.118242318, -0.046070008, 0.678352526, -0.046070008, -0.118242318, 0.004325061, 0.000624940 ],\
36
+ [ 0.000921261, 0.006375831, -0.174308068, -0.067914552, 1.000000000, -0.067914552, -0.174308068, 0.006375831, 0.000921261 ],\
37
+ [ 0.000624940, 0.004325061, -0.118242318, -0.046070008, 0.678352526, -0.046070008, -0.118242318, 0.004325061, 0.000624940 ],\
38
+ [ 0.000195076, 0.001350077, -0.036909595, -0.014380852, 0.211749204, -0.014380852, -0.036909595, 0.001350077, 0.000195076 ],\
39
+ [ 2.80209E-05, 0.000193926, -0.005301717, -0.002065674, 0.030415784, -0.002065674, -0.005301717, 0.000193926, 2.80209E-05 ],\
40
+ [ 1.85212E-06, 1.28181E-05, -0.000350433, -0.000136537, 0.002010422, -0.000136537, -0.000350433, 1.28181E-05, 1.85212E-06 ]\
41
+ ]
42
+ GaborKernel_45 = [\
43
+ [ 4.04180E-06, 2.25320E-05, -0.000279806, -0.001028923, 3.79931E-05, 0.000744712, 0.000132863, -9.04408E-06, -1.01551E-06 ],\
44
+ [ 2.25320E-05, 0.000925120, 0.002373205, -0.013561362, -0.022947700, 0.000389916, 0.003516954, 0.000288732, -9.04408E-06 ],\
45
+ [ -0.000279806, 0.002373205, 0.044837725, 0.052928748, -0.139178011, -0.108372072, 0.000847346, 0.003516954, 0.000132863 ],\
46
+ [ -0.001028923, -0.013561362, 0.052928748, 0.460162150, 0.249959607, -0.302454279, -0.108372072, 0.000389916, 0.000744712 ],\
47
+ [ 3.79931E-05, -0.022947700, -0.139178011, 0.249959607, 1.000000000, 0.249959607, -0.139178011, -0.022947700, 3.79931E-05 ],\
48
+ [ 0.000744712, 0.003899160, -0.108372072, -0.302454279, 0.249959607, 0.460162150, 0.052928748, -0.013561362, -0.001028923 ],\
49
+ [ 0.000132863, 0.003516954, 0.000847346, -0.108372072, -0.139178011, 0.052928748, 0.044837725, 0.002373205, -0.000279806 ],\
50
+ [ -9.04408E-06, 0.000288732, 0.003516954, 0.000389916, -0.022947700, -0.013561362, 0.002373205, 0.000925120, 2.25320E-05 ],\
51
+ [ -1.01551E-06, -9.04408E-06, 0.000132863, 0.000744712, 3.79931E-05, -0.001028923, -0.000279806, 2.25320E-05, 4.04180E-06 ]\
52
+ ]
53
+ GaborKernel_90 = [\
54
+ [ 1.85212E-06, 2.80209E-05, 0.000195076, 0.000624940, 0.000921261, 0.000624940, 0.000195076, 2.80209E-05, 1.85212E-06 ],\
55
+ [ 1.28181E-05, 0.000193926, 0.001350077, 0.004325061, 0.006375831, 0.004325061, 0.001350077, 0.000193926, 1.28181E-05 ],\
56
+ [ -0.000350433, -0.005301717, -0.036909595, -0.118242318, -0.174308068, -0.118242318, -0.036909595, -0.005301717, -0.000350433 ],\
57
+ [ -0.000136537, -0.002065674, -0.014380852, -0.046070008, -0.067914552, -0.046070008, -0.014380852, -0.002065674, -0.000136537 ],\
58
+ [ 0.002010422, 0.030415784, 0.211749204, 0.678352526, 1.000000000, 0.678352526, 0.211749204, 0.030415784, 0.002010422 ],\
59
+ [ -0.000136537, -0.002065674, -0.014380852, -0.046070008, -0.067914552, -0.046070008, -0.014380852, -0.002065674, -0.000136537 ],\
60
+ [ -0.000350433, -0.005301717, -0.036909595, -0.118242318, -0.174308068, -0.118242318, -0.036909595, -0.005301717, -0.000350433 ],\
61
+ [ 1.28181E-05, 0.000193926, 0.001350077, 0.004325061, 0.006375831, 0.004325061, 0.001350077, 0.000193926, 1.28181E-05 ],\
62
+ [ 1.85212E-06, 2.80209E-05, 0.000195076, 0.000624940, 0.000921261, 0.000624940, 0.000195076, 2.80209E-05, 1.85212E-06 ]
63
+ ]
64
+ GaborKernel_135 = [\
65
+ [ -1.01551E-06, -9.04408E-06, 0.000132863, 0.000744712, 3.79931E-05, -0.001028923, -0.000279806, 2.2532E-05, 4.0418E-06 ],\
66
+ [ -9.04408E-06, 0.000288732, 0.003516954, 0.000389916, -0.022947700, -0.013561362, 0.002373205, 0.00092512, 2.2532E-05 ],\
67
+ [ 0.000132863, 0.003516954, 0.000847346, -0.108372072, -0.139178011, 0.052928748, 0.044837725, 0.002373205, -0.000279806 ],\
68
+ [ 0.000744712, 0.000389916, -0.108372072, -0.302454279, 0.249959607, 0.46016215, 0.052928748, -0.013561362, -0.001028923 ],\
69
+ [ 3.79931E-05, -0.022947700, -0.139178011, 0.249959607, 1.000000000, 0.249959607, -0.139178011, -0.0229477, 3.79931E-05 ],\
70
+ [ -0.001028923, -0.013561362, 0.052928748, 0.460162150, 0.249959607, -0.302454279, -0.108372072, 0.000389916, 0.000744712 ],\
71
+ [ -0.000279806, 0.002373205, 0.044837725, 0.052928748, -0.139178011, -0.108372072, 0.000847346, 0.003516954, 0.000132863 ],\
72
+ [ 2.25320E-05, 0.000925120, 0.002373205, -0.013561362, -0.022947700, 0.000389916, 0.003516954, 0.000288732, -9.04408E-06 ],\
73
+ [ 4.04180E-06, 2.25320E-05, -0.000279806, -0.001028923, 3.79931E-05 , 0.000744712, 0.000132863, -9.04408E-06, -1.01551E-06 ]\
74
+ ]
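The Gabor kernels above are stored as plain nested lists. As a quick sanity check, here is a short sketch (matplotlib is assumed available; `saraRC1.py` below already imports it) that converts two of them to arrays and visualises their orientations:

```python
import numpy as np
import matplotlib.pyplot as plt
import SaRa.pySaliencyMapDefs as defs

# Convert the 9x9 nested lists to arrays for inspection.
k0 = np.array(defs.GaborKernel_0)
k45 = np.array(defs.GaborKernel_45)

fig, axes = plt.subplots(1, 2, figsize=(6, 3))
axes[0].imshow(k0, cmap='gray')
axes[0].set_title('GaborKernel_0')
axes[1].imshow(k45, cmap='gray')
axes[1].set_title('GaborKernel_45')
plt.tight_layout()
plt.show()
```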
SaRa/saraRC1.py ADDED
@@ -0,0 +1,1082 @@
1
+ import cv2
2
+ import numpy as np
3
+ import math
4
+ import scipy.stats as st
5
+ from mpl_toolkits.mplot3d import Axes3D
6
+ from matplotlib.lines import Line2D
7
+ import matplotlib.pyplot as plt
8
+ import operator
9
+ import time
10
+ import os
11
+ from enum import Enum
12
+ import pandas as pd
13
+
14
+ # Akisato Kimura <akisato@ieee.org> implementation of Itti's Saliency Map Generator -- https://github.com/akisatok/pySaliencyMap
15
+ from SaRa.pySaliencyMap import pySaliencyMap
16
+
17
+
18
+ # Global Variables
19
+
20
+ # Entropy, sum, depth, centre-bias
21
+ WEIGHTS = (1, 1, 1, 1)
22
+
23
+ # segments_entropies = []
24
+ segments_scores = []
25
+ segments_coords = []
26
+
27
+ seg_dim = 0
28
+ segments = []
29
+ gt_segments = []
30
+ dws = []
31
+ sara_list = []
32
+
33
+ eval_list = []
34
+ labels_eval_list = ['Image', 'Index', 'Rank', 'Quartile', 'isGT', 'Outcome']
35
+
36
+ outcome_list = []
37
+ labels_outcome_list = ['Image', 'FN', 'FP', 'TN', 'TP']
38
+
39
+ dataframe_collection = {}
40
+ error_count = 0
41
+
42
+
43
+ # SaRa Initial Functions
44
+ def generate_segments(img, seg_count) -> list:
45
+ '''
46
+ Given an image img and the desired number of segments seg_count, this
47
+ function divides the image into segments and returns a list of segments.
48
+ '''
49
+
50
+ segments = []
51
+ segment_count = seg_count
52
+ index = 0
53
+
54
+ w_interval = int(img.shape[1] / segment_count)
55
+ h_interval = int(img.shape[0] / segment_count)
56
+
57
+ for i in range(segment_count):
58
+ for j in range(segment_count):
59
+ temp_segment = img[int(h_interval * i):int(h_interval * (i + 1)),
60
+ int(w_interval * j):int(w_interval * (j + 1))]
61
+ segments.append(temp_segment)
62
+
63
+ coord_tup = (index, int(w_interval * j), int(h_interval * i),
64
+ int(w_interval * (j + 1)), int(h_interval * (i + 1)))
65
+ segments_coords.append(coord_tup)
66
+
67
+ index += 1
68
+
69
+ return segments
70
+
71
+
72
+ def return_saliency(img, generator='itti', deepgaze_model=None, emlnet_models=None, DEVICE='cpu'):
73
+ '''
74
+ Takes an image img as input and calculates the saliency map using the
75
+ selected generator ('itti', 'deepgaze', 'fpn' or 'emlnet'). It returns the saliency map.
76
+ '''
77
+
78
+ img_width, img_height = img.shape[1], img.shape[0]
79
+
80
+ if generator == 'itti':
81
+
82
+ sm = pySaliencyMap(img_width, img_height)
83
+ saliency_map = sm.SMGetSM(img)
84
+
85
+ # Scale pixel values to 0-255 instead of float (approx 0, hence black image)
86
+ # https://stackoverflow.com/questions/48331211/how-to-use-cv2-imshow-correctly-for-the-float-image-returned-by-cv2-distancet/48333272
87
+ saliency_map = cv2.normalize(saliency_map, None, 255, 0, cv2.NORM_MINMAX, cv2.CV_8UC1)
88
+ elif generator == 'deepgaze':
89
+ import numpy as np
90
+ from scipy.misc import face
91
+ from scipy.ndimage import zoom
92
+ from scipy.special import logsumexp
93
+ import torch
94
+
95
+ import deepgaze_pytorch
96
+
97
+ # you can use DeepGazeI or DeepGazeIIE
98
+ # model = deepgaze_pytorch.DeepGazeIIE(pretrained=True).to(DEVICE)
99
+
100
+ if deepgaze_model is None:
101
+ model = deepgaze_pytorch.DeepGazeIIE(pretrained=True).to(DEVICE)
102
+ else:
103
+ model = deepgaze_model
104
+
105
+ # image = face()
106
+ image = img
107
+
108
+ # load precomputed centerbias log density (from MIT1003) over a 1024x1024 image
109
+ # you can download the centerbias from https://github.com/matthias-k/DeepGaze/releases/download/v1.0.0/centerbias_mit1003.npy
110
+ # alternatively, you can use a uniform centerbias via `centerbias_template = np.zeros((1024, 1024))`.
111
+ # centerbias_template = np.load('centerbias_mit1003.npy')
112
+ centerbias_template = np.zeros((1024, 1024))
113
+ # rescale to match image size
114
+ centerbias = zoom(centerbias_template, (image.shape[0]/centerbias_template.shape[0], image.shape[1]/centerbias_template.shape[1]), order=0, mode='nearest')
115
+ # renormalize log density
116
+ centerbias -= logsumexp(centerbias)
117
+
118
+ image_tensor = torch.tensor([image.transpose(2, 0, 1)]).to(DEVICE)
119
+ centerbias_tensor = torch.tensor([centerbias]).to(DEVICE)
120
+
121
+ log_density_prediction = model(image_tensor, centerbias_tensor)
122
+
123
+ saliency_map = cv2.resize(log_density_prediction.detach().cpu().numpy()[0, 0], (img_width, img_height))
124
+
125
+ elif generator == 'fpn':
126
+ # Add ./fpn to the system path
127
+ import sys
128
+ sys.path.append('./fpn')
129
+ import inference as inf
130
+
131
+ results_dict = {}
132
+ rt_args = inf.parse_arguments(img)
133
+
134
+ # Call the run_inference function and capture the results
135
+ pred_masks_raw_list, pred_masks_round_list = inf.run_inference(rt_args)
136
+
137
+ # Store the results in the dictionary
138
+ results_dict['pred_masks_raw'] = pred_masks_raw_list
139
+ results_dict['pred_masks_round'] = pred_masks_round_list
140
+
141
+ saliency_map = results_dict['pred_masks_raw']
142
+
143
+ if img_width > img_height:
144
+ saliency_map = cv2.resize(saliency_map, (img_width, img_width))
145
+
146
+ diff = (img_width - img_height) // 2
147
+
148
+ saliency_map = saliency_map[diff:img_width - diff, 0:img_width]
149
+ else:
150
+ saliency_map = cv2.resize(saliency_map, (img_height, img_height))
151
+
152
+ diff = (img_height - img_width) // 2
153
+
154
+ saliency_map = saliency_map[0:img_height, diff:img_height - diff]
155
+
156
+ elif generator == 'emlnet':
157
+ from emlnet.eval_combined import main as eval_combined
158
+ saliency_map = eval_combined(img, emlnet_models)
159
+
160
+ # Resize to image size
161
+ saliency_map = cv2.resize(saliency_map, (img_width, img_height))
162
+
163
+ # Normalize saliency map
164
+ saliency_map = cv2.normalize(saliency_map, None, 255, 0, cv2.NORM_MINMAX, cv2.CV_8UC1)
165
+
166
+ saliency_map = cv2.GaussianBlur(saliency_map, (31, 31), 10)
167
+ return saliency_map
168
+ saliency_map = saliency_map // 16
169
+
170
+ return saliency_map
171
+
172
+
173
+ def return_saliency_batch(images, generator='deepgaze', deepgaze_model=None, emlnet_models=None, DEVICE='cuda', BATCH_SIZE=1):
174
+ img_widths, img_heights = [], []
175
+ if generator == 'deepgaze':
176
+ import numpy as np
177
+ from scipy.misc import face
178
+ from scipy.ndimage import zoom
179
+ from scipy.special import logsumexp
180
+ import torch
181
+
182
+ import deepgaze_pytorch
183
+
184
+ # you can use DeepGazeI or DeepGazeIIE
185
+ # model = deepgaze_pytorch.DeepGazeIIE(pretrained=True).to(DEVICE)
186
+
187
+ if deepgaze_model is None:
188
+ model = deepgaze_pytorch.DeepGazeIIE(pretrained=True).to(DEVICE)
189
+ else:
190
+ model = deepgaze_model
191
+
192
+ # image = face()
193
+ # image = img
194
+ image_batch = torch.tensor([img.transpose(2, 0, 1) for img in images]).to(DEVICE)
195
+ centerbias_template = np.zeros((1024, 1024))
196
+ centerbias_tensors = []
197
+
198
+ for img in images:
199
+ centerbias = zoom(centerbias_template, (img.shape[0] / centerbias_template.shape[0], img.shape[1] / centerbias_template.shape[1]), order=0, mode='nearest')
200
+ centerbias -= logsumexp(centerbias)
201
+ centerbias_tensors.append(torch.tensor(centerbias).to(DEVICE))
202
+
203
+ # Set img_width and img_height
204
+ img_widths.append(img.shape[1])
+ img_heights.append(img.shape[0])
205
+
206
+
207
+ # rescale to match image size
208
+ # centerbias = zoom(centerbias_template, (image.shape[0]/centerbias_template.shape[0], image.shape[1]/centerbias_template.shape[1]), order=0, mode='nearest')
209
+ # # renormalize log density
210
+ # centerbias -= logsumexp(centerbias)
211
+
212
+ # image_tensor = torch.tensor([image.transpose(2, 0, 1)]).to(DEVICE)
213
+ # centerbias_tensor = torch.tensor([centerbias]).to(DEVICE)
214
+ with torch.no_grad():
215
+ # Process the batch of images in one forward pass
216
+ log_density_predictions = model(image_batch, torch.stack(centerbias_tensors))
217
+
218
+ # log_density_prediction = model(image_tensor, centerbias_tensor)
219
+
220
+ # saliency_map = cv2.resize(log_density_prediction.detach().cpu().numpy()[0, 0], (img_width, img_height))
221
+
222
+ saliency_maps = []
223
+
224
+ for i in range(len(images)):
225
+ saliency_map = cv2.resize(log_density_predictions[i, 0].cpu().numpy(), (img_widths[i], img_heights[i]))
226
+
227
+ saliency_map = cv2.normalize(saliency_map, None, 255, 0, cv2.NORM_MINMAX, cv2.CV_8UC1)
228
+
229
+ saliency_map = cv2.GaussianBlur(saliency_map, (31, 31), 10)
230
+ saliency_map = saliency_map // 16
231
+
232
+ saliency_maps.append(saliency_map)
233
+
234
+ return saliency_maps
235
+
236
+
237
+ # def return_itti_saliency(img):
238
+ # '''
239
+ # Takes an image img as input and calculates the saliency map using the
240
+ # Itti's Saliency Map Generator. It returns the saliency map.
241
+ # '''
242
+
243
+ # img_width, img_height = img.shape[1], img.shape[0]
244
+
245
+ # sm = pySaliencyMap.pySaliencyMap(img_width, img_height)
246
+ # saliency_map = sm.SMGetSM(img)
247
+
248
+ # # Scale pixel values to 0-255 instead of float (approx 0, hence black image)
249
+ # # https://stackoverflow.com/questions/48331211/how-to-use-cv2-imshow-correctly-for-the-float-image-returned-by-cv2-distancet/48333272
250
+ # saliency_map = cv2.normalize(saliency_map, None, 255, 0, cv2.NORM_MINMAX, cv2.CV_8UC1)
251
+
252
+ # return saliency_map
253
+
254
+
255
+ # Saliency Ranking
256
+ def calculate_pixel_frequency(img) -> dict:
257
+ '''
258
+ Calculates the frequency of each pixel value in the image img and
259
+ returns a dictionary containing the pixel frequencies.
260
+ '''
261
+
262
+ flt = img.flatten()
263
+ unique, counts = np.unique(flt, return_counts=True)
264
+ pixels_frequency = dict(zip(unique, counts))
265
+
266
+ return pixels_frequency
267
+
268
+
269
+ def calculate_score(H, sum, ds, cb, w):
270
+ '''
271
+ Calculates the saliency score of a segment from its entropy H, saliency sum, depth score ds, centre-bias cb and weights w. It returns the saliency score.
272
+ '''
273
+
274
+ # Normalise H
275
+ # H = (H - 0) / (math.log(2, 256) - 0)
276
+
277
+ # H = wth root of H
278
+ H = H ** w[0]
279
+
280
+ if sum > 0:
281
+ sum = np.log(sum)
282
+ sum = sum ** w[1]
283
+
284
+ ds = ds ** w[2]
285
+
286
+ cb = (cb + 1) ** w[3]
287
+
288
+ return H + sum + ds + cb
289
+
290
+
291
+ def calculate_entropy(img, w, dw) -> float:
292
+ '''
293
+ Calculates the entropy of an image img using the given weights w and
294
+ depth weights dw. It returns the entropy value.
295
+ '''
296
+
297
+ flt = img.flatten()
298
+
299
+ # c = flt.shape[0]
300
+ total_pixels = 0
301
+ t_prob = 0
302
+ # sum_of_probs = 0
303
+ entropy = 0
304
+ wt = w * 10
305
+
306
+ # if imgD=None then proceed normally
307
+ # else calculate its frequency and find max
308
+ # use this max value as a weight in entropy
309
+
310
+ pixels_frequency = calculate_pixel_frequency(flt)
311
+
312
+ total_pixels = sum(pixels_frequency.values())
313
+
314
+ for px in pixels_frequency:
315
+ t_prob = pixels_frequency[px] / total_pixels
316
+
317
+ if t_prob != 0:
318
+ entropy += (t_prob * math.log((1 / t_prob), 2))
319
+
320
+ # entropy = entropy * wt * dw
321
+
322
+ return entropy
323
+
324
+
325
+ def find_most_salient_segment(segments, kernel, dws):
326
+ '''
327
+ Finds the most salient segment among the provided segments using a
328
+ given kernel and depth weights. It returns the maximum saliency score
329
+ and the index of the most salient segment.
330
+ '''
331
+
332
+ # max_entropy = 0
333
+ max_score = 0
334
+ index = 0
335
+ i = 0
336
+
337
+ for segment in segments:
338
+ temp_entropy = calculate_entropy(segment, kernel[i], dws[i])
339
+ # Normalise segment between 0 and 255
340
+ segment = cv2.normalize(segment, None, 255, 0, cv2.NORM_MINMAX, cv2.CV_8UC1)
341
+ temp_sum = np.sum(segment)
342
+ # temp_tup = (i, temp_entropy)
343
+ # segments_entropies.append(temp_tup)
344
+
345
+ w = WEIGHTS
346
+
347
+ temp_score = calculate_score(temp_entropy, temp_sum, dws[i], kernel[i], w)
348
+
349
+ temp_tup = (i, temp_score, temp_entropy ** w[0], temp_sum ** w[1], (kernel[i] + 1) ** w[2], dws[i] ** w[3])
350
+
351
+ # segments_scores.append((i, temp_score))
352
+ segments_scores.append(temp_tup)
353
+
354
+ # if temp_entropy > max_entropy:
355
+ # max_entropy = temp_entropy
356
+ # index = i
357
+
358
+ if temp_score > max_score:
359
+ max_score = temp_score
360
+ index = i
361
+
362
+ i += 1
363
+
364
+ # return max_entropy, index
365
+ return max_score, index
366
+
367
+
368
+ def make_gaussian(size, fwhm=10, center=None):
369
+ '''
370
+ Generates a 2D Gaussian kernel with the specified size and full-width-half-maximum (fwhm). It returns the Gaussian kernel.
371
+
372
+ size: length of a side of the square
373
+ fwhm: full-width-half-maximum, which can be thought of as an effective
374
+ radius.
375
+
376
+ https://gist.github.com/andrewgiessel/4635563
377
+ '''
378
+
379
+ x = np.arange(0, size, 1, float)
380
+ y = x[:, np.newaxis]
381
+
382
+ if center is None:
383
+ x0 = y0 = size // 2
384
+ else:
385
+ x0 = center[0]
386
+ y0 = center[1]
387
+
388
+
389
+ return np.exp(-4 * np.log(2) * ((x - x0) ** 2 + (y - y0) ** 2) / fwhm ** 2)
390
+
391
+
392
+ def gen_depth_weights(d_segments, depth_map) -> list:
393
+ '''
394
+ Generates depth weights for the segments based on the depth map. It
395
+ returns a list of depth weights.
396
+ '''
397
+
398
+ hist_d, _ = np.histogram(depth_map, 256, [0, 256])
399
+
400
+ # Get first non-zero index
401
+ first_nz = next((i for i, x in enumerate(hist_d) if x), None)
402
+
403
+ # Get last non-zero index
404
+ rev = (len(hist_d) - idx for idx, item in enumerate(reversed(hist_d), 1) if item)
405
+ last_nz = next(rev, None)
406
+
407
+ mid = (first_nz + last_nz) / 2
408
+
409
+ for seg in d_segments:
410
+ hist, _ = np.histogram(seg, 256, [0, 256])
411
+ dw = 0
412
+ ind = 0
413
+ for s in hist:
414
+ if ind > mid:
415
+ dw = dw + (s * 1)
416
+ ind = ind + 1
417
+ dws.append(dw)
418
+
419
+ return dws
420
+
421
+
422
+ def gen_blank_depth_weight(d_segments):
423
+ '''
424
+ Generates blank depth weights for the segments. It returns a list of
425
+ depth weights.
426
+ '''
427
+
428
+ for _ in d_segments:
429
+ dw = 1
430
+ dws.append(dw)
431
+ return dws
432
+
433
+
434
+ # def generate_heatmap(img, mode, sorted_seg_scores, segments_coords) -> tuple:
435
+ # '''
436
+ # Generates a heatmap overlay on the input image img based on the
437
+ # provided sorted segment scores. The mode parameter determines the color
438
+ # scheme of the heatmap. It returns the image with the heatmap overlay
439
+ # and a list of segment scores.
440
+
441
+ # mode: 0 for white grid, 1 for color-coded grid
442
+ # '''
443
+
444
+ # font = cv2.FONT_HERSHEY_SIMPLEX
445
+ # # print_index = 0
446
+ # print_index = len(sorted_seg_scores) - 1
447
+ # set_value = int(0.25 * len(sorted_seg_scores))
448
+ # color = (0, 0, 0)
449
+
450
+ # max_x = 0
451
+ # max_y = 0
452
+
453
+ # overlay = np.zeros_like(img, dtype=np.uint8)
454
+ # text_overlay = np.zeros_like(img, dtype=np.uint8)
455
+
456
+ # sara_list_out = []
457
+
458
+ # for ent in reversed(sorted_seg_scores):
459
+ # quartile = 0
460
+ # if mode == 0:
461
+ # color = (255, 255, 255)
462
+ # t = 4
463
+ # elif mode == 1:
464
+ # if print_index + 1 <= set_value:
465
+ # color = (0, 0, 255, 255)
466
+ # t = 2
467
+ # quartile = 1
468
+ # elif print_index + 1 <= set_value * 2:
469
+ # color = (0, 128, 255, 192)
470
+ # t = 4
471
+ # quartile = 2
472
+ # elif print_index + 1 <= set_value * 3:
473
+ # color = (0, 255, 255, 128)
474
+ # t = 4
475
+ # t = 6
476
+ # quartile = 3
477
+ # # elif print_index + 1 <= set_value * 4:
478
+ # # color = (0, 250, 0, 64)
479
+ # # t = 8
480
+ # # quartile = 4
481
+ # else:
482
+ # color = (0, 250, 0, 64)
483
+ # t = 8
484
+ # quartile = 4
485
+
486
+
487
+ # x1 = segments_coords[ent[0]][1]
488
+ # y1 = segments_coords[ent[0]][2]
489
+ # x2 = segments_coords[ent[0]][3]
490
+ # y2 = segments_coords[ent[0]][4]
491
+
492
+ # if x2 > max_x:
493
+ # max_x = x2
494
+ # if y2 > max_y:
495
+ # max_y = y2
496
+
497
+ # x = int((x1 + x2) / 2)
498
+ # y = int((y1 + y2) / 2)
499
+
500
+
501
+
502
+ # # fill rectangle
503
+ # cv2.rectangle(overlay, (x1, y1), (x2, y2), color, -1)
504
+
505
+ # cv2.rectangle(overlay, (x1, y1), (x2, y2), (0, 0, 0), 1)
506
+ # # put text in the middle of the rectangle
507
+
508
+ # # white text
509
+ # cv2.putText(text_overlay, str(print_index), (x - 5, y),
510
+ # font, .4, (255, 255, 255), 1, cv2.LINE_AA)
511
+
512
+ # # Index, rank, score, entropy, entropy_sum, centre_bias, depth, quartile
513
+ # sara_tuple = (ent[0], print_index, ent[1], ent[2], ent[3], ent[4], ent[5], quartile)
514
+ # sara_list_out.append(sara_tuple)
515
+ # print_index -= 1
516
+
517
+ # # crop the overlay to up to x2 and y2
518
+ # overlay = overlay[0:max_y, 0:max_x]
519
+ # text_overlay = text_overlay[0:max_y, 0:max_x]
520
+ # img = img[0:max_y, 0:max_x]
521
+
522
+
523
+ # img = cv2.addWeighted(overlay, 0.3, img, 0.7, 0, img)
524
+
525
+ # img[text_overlay > 128] = text_overlay[text_overlay > 128]
526
+
527
+
528
+ # return img, sara_list_out
529
+ def generate_heatmap(img, sorted_seg_scores, segments_coords, mode=1) -> tuple:
530
+ '''
531
+ Generates a more vibrant heatmap overlay on the input image img based on the
532
+ provided sorted segment scores. It returns the image with the heatmap overlay
533
+ and a list of segment scores with quartile information.
534
+
535
+ mode: 0 for white grid, 1 for color-coded grid, 2 for heatmap to be used as a feature
536
+ '''
537
+ alpha = 0.3
538
+ if mode == 2:
539
+
540
+ font = cv2.FONT_HERSHEY_SIMPLEX
541
+ print_index = len(sorted_seg_scores) - 1
542
+ set_value = int(0.25 * len(sorted_seg_scores))
543
+
544
+ max_x = 0
545
+ max_y = 0
546
+
547
+ overlay = np.zeros_like(img, dtype=np.uint8)
548
+ text_overlay = np.zeros_like(img, dtype=np.uint8)
549
+
550
+ sara_list_out = []
551
+
552
+ scores = [score[1] for score in sorted_seg_scores]
553
+ min_score = min(scores)
554
+ max_score = max(scores)
555
+
556
+ # Choose a colormap from matplotlib
557
+ colormap = plt.get_cmap('jet') # 'jet', 'viridis', 'plasma', 'magma', 'cividis, jet_r, viridis_r, plasma_r, magma_r, cividis_r
558
+
559
+ for ent in reversed(sorted_seg_scores):
560
+ score = ent[1]
561
+ normalized_score = (score - min_score) / (max_score - min_score)
562
+ color_weight = normalized_score * score # Weighted color based on the score
563
+ color = np.array(colormap(normalized_score)[:3]) * 255 #* color_weight
564
+
565
+ x1 = segments_coords[ent[0]][1]
566
+ y1 = segments_coords[ent[0]][2]
567
+ x2 = segments_coords[ent[0]][3]
568
+ y2 = segments_coords[ent[0]][4]
569
+
570
+ if x2 > max_x:
571
+ max_x = x2
572
+ if y2 > max_y:
573
+ max_y = y2
574
+
575
+ x = int((x1 + x2) / 2)
576
+ y = int((y1 + y2) / 2)
577
+
578
+ # fill rectangle
579
+ cv2.rectangle(overlay, (x1, y1), (x2, y2), color, -1)
580
+ # black border
581
+ # cv2.rectangle(overlay, (x1, y1), (x2, y2), (0, 0, 0), 1)
582
+
583
+ # white text
584
+ # cv2.putText(text_overlay, str(print_index), (x - 5, y),
585
+ # font, .4, (255, 255, 255), 1, cv2.LINE_AA)
586
+
587
+ # Determine quartile based on print_index
588
+ if print_index + 1 <= set_value:
589
+ quartile = 1
590
+ elif print_index + 1 <= set_value * 2:
591
+ quartile = 2
592
+ elif print_index + 1 <= set_value * 3:
593
+ quartile = 3
594
+ else:
595
+ quartile = 4
596
+
597
+ sara_tuple = (ent[0], print_index, ent[1], ent[2], ent[3], ent[4], ent[5], quartile)
598
+ sara_list_out.append(sara_tuple)
599
+ print_index -= 1
600
+
601
+ overlay = overlay[0:max_y, 0:max_x]
602
+ text_overlay = text_overlay[0:max_y, 0:max_x]
603
+ img = img[0:max_y, 0:max_x]
604
+
605
+ # Create a blank grayscale image with the same dimensions as the original image
606
+ gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
607
+
608
+ gray = cv2.merge([gray, gray, gray])
609
+
610
+ gray = cv2.addWeighted(overlay, alpha, gray, 1-alpha, 0, gray)
611
+ gray[text_overlay > 128] = text_overlay[text_overlay > 128]
612
+
613
+ return gray, sara_list_out
614
+ else:
615
+ font = cv2.FONT_HERSHEY_SIMPLEX
616
+ # print_index = 0
617
+ print_index = len(sorted_seg_scores) - 1
618
+ set_value = int(0.25 * len(sorted_seg_scores))
619
+ color = (0, 0, 0)
620
+
621
+ max_x = 0
622
+ max_y = 0
623
+
624
+ overlay = np.zeros_like(img, dtype=np.uint8)
625
+ text_overlay = np.zeros_like(img, dtype=np.uint8)
626
+
627
+ sara_list_out = []
628
+
629
+ for ent in reversed(sorted_seg_scores):
630
+ quartile = 0
631
+ if mode == 0:
632
+ color = (255, 255, 255)
633
+ t = 4
634
+ elif mode == 1:
635
+ if print_index + 1 <= set_value:
636
+ color = (0, 0, 255, 255)
637
+ t = 2
638
+ quartile = 1
639
+ elif print_index + 1 <= set_value * 2:
640
+ color = (0, 128, 255, 192)
641
+ t = 4
642
+ quartile = 2
643
+ elif print_index + 1 <= set_value * 3:
644
+ color = (0, 255, 255, 128)
645
+ t = 4
646
+ t = 6
647
+ quartile = 3
648
+ # elif print_index + 1 <= set_value * 4:
649
+ # color = (0, 250, 0, 64)
650
+ # t = 8
651
+ # quartile = 4
652
+ else:
653
+ color = (0, 250, 0, 64)
654
+ t = 8
655
+ quartile = 4
656
+
657
+
658
+ x1 = segments_coords[ent[0]][1]
659
+ y1 = segments_coords[ent[0]][2]
660
+ x2 = segments_coords[ent[0]][3]
661
+ y2 = segments_coords[ent[0]][4]
662
+
663
+ if x2 > max_x:
664
+ max_x = x2
665
+ if y2 > max_y:
666
+ max_y = y2
667
+
668
+ x = int((x1 + x2) / 2)
669
+ y = int((y1 + y2) / 2)
670
+
671
+
672
+
673
+ # fill rectangle
674
+ cv2.rectangle(overlay, (x1, y1), (x2, y2), color, -1)
675
+
676
+ cv2.rectangle(overlay, (x1, y1), (x2, y2), (0, 0, 0), 1)
677
+ # put text in the middle of the rectangle
678
+
679
+ # white text
680
+ cv2.putText(text_overlay, str(print_index), (x - 5, y),
681
+ font, .4, (255, 255, 255), 1, cv2.LINE_AA)
682
+
683
+ # Index, rank, score, entropy, entropy_sum, centre_bias, depth, quartile
684
+ sara_tuple = (ent[0], print_index, ent[1], ent[2], ent[3], ent[4], ent[5], quartile)
685
+ sara_list_out.append(sara_tuple)
686
+ print_index -= 1
687
+
688
+ # crop the overlay to up to x2 and y2
689
+ overlay = overlay[0:max_y, 0:max_x]
690
+ text_overlay = text_overlay[0:max_y, 0:max_x]
691
+ img = img[0:max_y, 0:max_x]
692
+
693
+
694
+ img = cv2.addWeighted(overlay, 0.3, img, 0.7, 0, img)
695
+
696
+ img[text_overlay > 128] = text_overlay[text_overlay > 128]
697
+
698
+
699
+ return img, sara_list_out
700
+
701
+ def generate_sara(tex, tex_segments, mode=2):
702
+ '''
703
+ Generates the SaRa (Saliency Ranking) output by calculating
704
+ saliency scores for the segments of the given texture image tex. It
705
+ returns the texture image with the heatmap overlay and a list of
706
+ segment scores.
707
+ '''
708
+
709
+ gaussian_kernel_array = make_gaussian(seg_dim)
710
+ gaussian1d = gaussian_kernel_array.ravel()
711
+
712
+ dws = gen_blank_depth_weight(tex_segments)
713
+
714
+ max_h, index = find_most_salient_segment(tex_segments, gaussian1d, dws)
715
+ # dict_entropies = dict(segments_entropies)
716
+ # segments_scores list with 5 elements, use index as key for dict and store rest as list of index
717
+ dict_scores = {}
718
+
719
+ for segment in segments_scores:
720
+ # Index: score, entropy, sum, depth, centre-bias
721
+ dict_scores[segment[0]] = [segment[1], segment[2], segment[3], segment[4], segment[5]]
722
+
723
+ # sorted_entropies = sorted(dict_entropies.items(),
724
+ # key=operator.itemgetter(1), reverse=True)
725
+
726
+
727
+ # sorted_scores = sorted(dict_scores.items(),
728
+ # key=operator.itemgetter(1), reverse=True)
729
+
730
+ # Sort by first value in value list
731
+ sorted_scores = sorted(dict_scores.items(), key=lambda x: x[1][0], reverse=True)
732
+
733
+ # flatten
734
+ sorted_scores = [[i[0], i[1][0], i[1][1], i[1][2], i[1][3], i[1][4]] for i in sorted_scores]
735
+
736
+ # tex_out, sara_list_out = generate_heatmap(
737
+ # tex, 1, sorted_entropies, segments_coords)
738
+
739
+ tex_out, sara_list_out = generate_heatmap(
740
+ tex, sorted_scores, segments_coords, mode = mode)
741
+
742
+ sara_list_out = list(reversed(sara_list_out))
743
+
744
+ return tex_out, sara_list_out
745
+
746
+
747
+ def return_sara(input_img, grid, generator='itti', saliency_map=None, mode = 2):
748
+ '''
749
+ Computes the SaRa output for the given input image. It uses the
750
+ generate_sara function internally. It returns the SaRa output image and
751
+ a list of segment scores.
752
+ '''
753
+
754
+ global seg_dim
755
+ seg_dim = grid
756
+
757
+ if saliency_map is None:
758
+ saliency_map = return_saliency(input_img, generator)
759
+
760
+ tex_segments = generate_segments(saliency_map, seg_dim)
761
+
762
+ # tex_segments = generate_segments(input_img, seg_dim)
763
+ sara_output, sara_list_output = generate_sara(input_img, tex_segments, mode=mode)
764
+
765
+ return sara_output, sara_list_output
766
+
767
+
768
+ def mean_squared_error(image_a, image_b) -> float:
769
+ '''
770
+ Calculates the Mean Squared Error (MSE), i.e. sum of squared
771
+ differences between two images image_a and image_b. It returns the MSE
772
+ value.
773
+
774
+ NOTE: The two images must have the same dimension
775
+ '''
776
+
777
+ err = np.sum((image_a.astype('float') - image_b.astype('float')) ** 2)
778
+ err /= float(image_a.shape[0] * image_a.shape[1])
779
+
780
+ return err
781
+
782
+
783
+ def reset():
784
+ '''
785
+ Resets all global variables to their default values.
786
+ '''
787
+
788
+ # global segments_entropies, segments_scores, segments_coords, seg_dim, segments, gt_segments, dws, sara_list
789
+
790
+ global segments_scores, segments_coords, seg_dim, segments, gt_segments, dws, sara_list
791
+
792
+ # segments_entropies = []
793
+ segments_scores = []
794
+ segments_coords = []
795
+
796
+ seg_dim = 0
797
+ segments = []
798
+ gt_segments = []
799
+ dws = []
800
+ sara_list = []
801
+
802
+
803
+
804
+ def resize_based_on_important_ranks(img, sara_info, grid_size, rate=0.3):
805
+ def generate_segments(image, seg_count) -> dict:
806
+ """
807
+ Function to generate segments of an image
808
+
809
+ Args:
810
+ image: input image
811
+ seg_count: number of segments to generate
812
+
813
+ Returns:
814
+ segments: dictionary of segments
815
+
816
+ """
817
+ # Initializing segments dictionary
818
+ segments = {}
819
+ # Initializing segment index and segment count
820
+ segment_count = seg_count
821
+ index = 0
822
+
823
+ # Retrieving image width and height
824
+ h, w = image.shape[:2]
825
+
826
+ # Calculating width and height intervals for segments from the segment count
827
+ w_interval = w // segment_count
828
+ h_interval = h // segment_count
829
+
830
+ # Iterating through the image and generating segments
831
+ for i in range(segment_count):
832
+ for j in range(segment_count):
833
+ # Calculating segment coordinates
834
+ x1, y1 = j * w_interval, i * h_interval
835
+ x2, y2 = x1 + w_interval, y1 + h_interval
836
+
837
+ # Adding segment coordinates to segments dictionary
838
+ segments[index] = (x1, y1, x2, y2)
839
+
840
+ # Incrementing segment index
841
+ index += 1
842
+
843
+ # Returning segments dictionary
844
+ return segments
845
+
846
+ # Retrieving important ranks from SaRa
847
+ sara_dict = {
848
+ info[0]: {
849
+ 'score': info[2],
850
+ 'index': info[1]
851
+ }
852
+ for info in sara_info[1]
853
+ }
854
+
855
+ # Sorting important ranks by score
856
+ sorted_sara_dict = sorted(sara_dict.items(), key=lambda item: item[1]['score'], reverse=True)
857
+
858
+ # Generating segments
859
+ index_info = generate_segments(img, grid_size)
860
+
861
+ # Initializing most important ranks image
862
+ most_imp_ranks = np.zeros_like(img)
863
+
864
+ # Calculating maximum rank
865
+ max_rank = int(grid_size * grid_size * rate)
866
+ count = 0
867
+
868
+ # Iterating through important ranks and adding them to most important ranks image
869
+ for rank, info in sorted_sara_dict:
870
+ # Checking if rank is within maximum rank
871
+ if count <= max_rank:
872
+ # Retrieving segment coordinates
873
+ coords = index_info[rank]
874
+
875
+ # Adding segment to most important ranks image by making it white
876
+ most_imp_ranks[coords[1]:coords[3], coords[0]:coords[2]] = 255
877
+
878
+ # Incrementing count
879
+ count += 1
880
+ else:
881
+ break
882
+
883
+ # Retrieving coordinates of most important ranks
884
+ coords = np.argwhere(most_imp_ranks == 255)
885
+
886
+ # Checking if no important ranks were found and returning original image
887
+ if coords.size == 0:
888
+ return img , most_imp_ranks, [0, 0, img.shape[0], img.shape[1]]
889
+
890
+ # Cropping image based on most important ranks
891
+ x0, y0 = coords.min(axis=0)[:2]
892
+ x1, y1 = coords.max(axis=0)[:2] + 1
893
+ cropped_img = img[x0:x1, y0:y1]
894
+ return cropped_img , most_imp_ranks, [x0, y0, x1, y1]
895
+
896
+ def sara_resize(img, sara_info, grid_size, rate=0.3, iterations=2):
897
+ """
898
+ Function to resize an image based on SaRa
899
+
900
+ Args:
901
+ img: input image
902
+ sara_info: SaRa information
903
+ grid_size: size of the grid
904
+ rate: rate of important ranks
905
+ iterations: number of iterations to resize
906
+
907
+ Returns:
908
+ img: resized image
909
+ """
910
+ # Iterating through iterations
911
+ for _ in range(iterations):
912
+ # Resizing image based on important ranks
913
+ img, most_imp_ranks, coords = resize_based_on_important_ranks(img, sara_info, grid_size, rate=rate)
914
+
915
+ # Returning resized image
916
+ return img, most_imp_ranks, coords
917
+
918
+ def plot_3D(img, sara_info, grid_size, rate=0.3):
919
+ def generate_segments(image, seg_count) -> dict:
920
+ """
921
+ Function to generate segments of an image
922
+
923
+ Args:
924
+ image: input image
925
+ seg_count: number of segments to generate
926
+
927
+ Returns:
928
+ segments: dictionary of segments
929
+
930
+ """
931
+ # Initializing segments dictionary
932
+ segments = {}
933
+ # Initializing segment index and segment count
934
+ segment_count = seg_count
935
+ index = 0
936
+
937
+ # Retrieving image width and height
938
+ h, w = image.shape[:2]
939
+
940
+ # Calculating width and height intervals for segments from the segment count
941
+ w_interval = w // segment_count
942
+ h_interval = h // segment_count
943
+
944
+ # Iterating through the image and generating segments
945
+ for i in range(segment_count):
946
+ for j in range(segment_count):
947
+ # Calculating segment coordinates
948
+ x1, y1 = j * w_interval, i * h_interval
949
+ x2, y2 = x1 + w_interval, y1 + h_interval
950
+
951
+ # Adding segment coordinates to segments dictionary
952
+ segments[index] = (x1, y1, x2, y2)
953
+
954
+ # Incrementing segment index
955
+ index += 1
956
+
957
+ # Returning segments dictionary
958
+ return segments
959
+
960
+ # Extracting heatmap from SaRa information
961
+ heatmap = sara_info[0]
962
+     heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
+
+     # Retrieve the important ranks from SaRa
+     sara_dict = {
+         info[0]: {
+             'score': info[2],
+             'index': info[1]
+         }
+         for info in sara_info[1]
+     }
+
+     # Sort the ranks by score (most salient first)
+     sorted_sara_dict = sorted(sara_dict.items(), key=lambda item: item[1]['score'], reverse=True)
+
+     # Generate the grid segments for the image
+     index_info = generate_segments(img, grid_size)
+
+     # Calculate the maximum rank retained by the given rate
+     max_rank = int(grid_size * grid_size * rate)
+     count = 0
+
+     # Normalise the heatmap to [0, 1]
+     heatmap = heatmap.astype(float) / 255.0
+
+     # Create a figure with a 3D plot
+     fig = plt.figure(figsize=(20, 10))
+     ax = fig.add_subplot(111, projection='3d')
+
+     # Define the x and y coordinates for the heatmap
+     x_coords = np.linspace(0, 1, heatmap.shape[1])
+     y_coords = np.linspace(0, 1, heatmap.shape[0])
+     x, y = np.meshgrid(x_coords, y_coords)
+
+     # Define a constant z-coordinate (-10) so the heatmap sits below the distributions
+     z = np.asarray([[-10] * heatmap.shape[1]] * heatmap.shape[0])
+
+     # Plot the heatmap as a texture on the xy-plane
+     ax.plot_surface(x, y, z, facecolors=heatmap, rstride=1, cstride=1, shade=False)
+
+     # Initialise the single distribution array
+     single_distribution = np.asarray([[1e-6] * heatmap.shape[1]] * heatmap.shape[0], dtype=float)
+
+     importance = 0
+     # Build the single distribution by summing a Gaussian distribution for each segment
+     for rank, info in sorted_sara_dict:
+         # Retrieve the segment coordinates
+         coords = index_info[rank]
+
+         # Arrange the pixels of the segment on a grid for a 3D Gaussian distribution
+         x_temp = np.linspace(0, 1, coords[2] - coords[0])
+         y_temp = np.linspace(0, 1, coords[3] - coords[1])
+         x_temp, y_temp = np.meshgrid(x_temp, y_temp)
+
+         # Gaussian centred on the segment, weighted by the segment's importance
+         distribution = np.exp(-((x_temp - 0.5) ** 2 + (y_temp - 0.5) ** 2) / 0.1) * ((grid_size ** 2 - importance) / grid_size ** 2)
+
+         # Add the Gaussian distribution to the single distribution
+         single_distribution[coords[1]:coords[3], coords[0]:coords[2]] += distribution
+
+         # Increment the importance counter
+         importance += 1
+
+     # Based on the rate, calculate the minimum number of most important ranks
+     min_rank = int(grid_size * grid_size * rate)
+
+     # Calculate the scale factor for the single distribution and scale it
+     scale_factor = ((grid_size ** 2 - min_rank) / grid_size ** 2) * 5
+     single_distribution *= scale_factor
+
+     # Retrieve the max and min values of the single distribution
+     max_value = np.max(single_distribution)
+     min_value = np.min(single_distribution)
+
+     # Calculate the threshold hyperplane height from the rate
+     hyperplane = np.asarray([[(max_value - min_value) * (1 - rate) + min_value] * heatmap.shape[1]] * heatmap.shape[0])
+
+     # Plot a horizontal plane at the threshold level (hyperplane)
+     ax.plot_surface(x, y, hyperplane, rstride=1, cstride=1, color='red', alpha=0.3, shade=False)
+
+     # Plot the single distribution as a surface above the heatmap
+     ax.plot_surface(x, y, single_distribution, rstride=1, cstride=1, color='blue', shade=False)
+
+     # Set the title and axis labels
+     ax.set_title('SaRa 3D Heatmap Plot', fontsize=20)
+     ax.set_xlabel('X', fontsize=16)
+     ax.set_ylabel('Y', fontsize=16)
+     ax.set_zlabel('Z', fontsize=16)
+
+     # Set the viewing angle to look from the x/y diagonal
+     ax.view_init(elev=30, azim=45)  # Adjust the elevation (elev) and azimuth (azim) angles as needed
+     # ax.view_init(elev=0, azim=0)  # View from the top
+
+     # Add a legend to the plot
+     legend_elements = [Line2D([0], [0], color='blue', lw=4, label='Rank Distribution'),
+                        Line2D([0], [0], color='red', lw=4, label='Threshold Hyperplane ({}%)'.format(rate * 100)),
+                        Line2D([0], [0], color='green', lw=4, label='SaRa Heatmap')]
+     plt.subplots_adjust(right=0.5)
+     ax.legend(handles=legend_elements, fontsize=16, loc='center left', bbox_to_anchor=(1, 0.5))
+
+     # Invert the x axis and remove the tick labels
+     ax.invert_xaxis()
+     ax.set_xticks([])
+     ax.set_yticks([])
+     ax.set_zticks([])
+
+     # Show the plot
+     plt.show()
+
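The threshold hyperplane above sits a fraction (1 - rate) of the way up the range of the rank distribution. As a quick illustration of that thresholding step only, a minimal standalone sketch (the array values below are made up purely for illustration):

    import numpy as np

    rate = 0.3                                    # fraction of grid cells treated as most important
    single_distribution = np.random.rand(10, 10)  # stand-in for the summed rank distribution
    max_value, min_value = single_distribution.max(), single_distribution.min()

    # Same formula as in the plotting code above
    threshold = (max_value - min_value) * (1 - rate) + min_value
    print('cells above the threshold:', int((single_distribution > threshold).sum()))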
app.py ADDED
@@ -0,0 +1,154 @@
+ from typing import Tuple
+ import gradio as gr
+ import numpy as np
+ import cv2
+ import SaRa.saraRC1 as sara
+ import warnings
+ warnings.filterwarnings("ignore")
+
+
+ ALPHA = 0.3
+ GENERATORS = ['itti', 'deepgaze']
+
+ MARKDOWN = """
+ <h1 style='text-align: center'>Saliency Ranking: Itti vs. Deepgaze</h1>
+ """
+
+ IMAGE_EXAMPLES = [
+     ['https://media.roboflow.com/supervision/image-examples/people-walking.png', 9],
+     ['https://media.roboflow.com/supervision/image-examples/vehicles.png', 9],
+     ['https://media.roboflow.com/supervision/image-examples/basketball-1.png', 9],
+ ]
+
+ def detect_and_annotate(image: np.ndarray,
+                         GRID_SIZE: int,
+                         generator: str,
+                         ALPHA: float = ALPHA) -> Tuple[np.ndarray, np.ndarray]:
+     # Convert the image from BGR to RGB
+     image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+
+     # Copy and convert the image for SaRa processing
+     sara_image = image.copy()
+     sara_image = cv2.cvtColor(sara_image, cv2.COLOR_RGB2BGR)
+
+     # Reset SaRa's internal state
+     sara.reset()
+
+     # Run SaRa (original implementation on Itti)
+     sara_info = sara.return_sara(sara_image, GRID_SIZE, generator, mode=1)
+
+     # Generate the saliency map and resize it to match the image size
+     saliency_map = sara.return_saliency(image, generator=generator)
+     saliency_map = cv2.resize(saliency_map, (image.shape[1], image.shape[0]))
+
+     # Apply a colour map and convert to RGB
+     saliency_map = cv2.applyColorMap(saliency_map, cv2.COLORMAP_JET)
+     saliency_map = cv2.cvtColor(saliency_map, cv2.COLOR_BGR2RGB)
+
+     # Overlay the saliency map on the original image
+     saliency_map = cv2.addWeighted(saliency_map, ALPHA, image, 1 - ALPHA, 0)
+
+     # Extract the SaRa heatmap and convert it to RGB
+     heatmap = sara_info[0]
+     heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
+
+     return saliency_map, heatmap
+
+ def process_image(
+     input_image: np.ndarray,
+     GRIDSIZE: int,
+ ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
+     # Validate GRID_SIZE and fall back to the default
+     if GRIDSIZE is None or GRIDSIZE < 4:
+         GRIDSIZE = 9
+
+     itti_saliency_map, itti_heatmap = detect_and_annotate(
+         input_image, GRIDSIZE, 'itti')
+     deepgaze_saliency_map, deepgaze_heatmap = detect_and_annotate(
+         input_image, GRIDSIZE, 'deepgaze')
+
+     return (
+         itti_saliency_map,
+         itti_heatmap,
+         deepgaze_saliency_map,
+         deepgaze_heatmap,
+     )
+
+ grid_size_Component = gr.Slider(
+     minimum=4,
+     maximum=70,
+     value=9,
+     step=1,
+     label="Grid Size",
+     info=(
+         "The grid size for the Saliency Ranking (SaRa) model. The grid size determines "
+         "the number of regions the image is divided into. A higher grid size results in "
+         "more regions and a lower grid size results in fewer regions. The default grid "
+         "size is 9."
+     ))
+
+
+ with gr.Blocks() as demo:
+     gr.Markdown(MARKDOWN)
+     with gr.Accordion("Configuration", open=False):
+         with gr.Row():
+             grid_size_Component.render()
+     with gr.Row():
+         input_image_component = gr.Image(
+             type='numpy',  # the processing functions expect a NumPy array (cv2), not a PIL image
+             label='Input'
+         )
+     with gr.Row():
+         itti_saliency_map = gr.Image(
+             type='pil',
+             label='Itti Saliency Map'
+         )
+         itti_heatmap = gr.Image(
+             type='pil',
+             label='Itti Saliency Ranking Heatmap'
+         )
+     with gr.Row():
+         deepgaze_saliency_map = gr.Image(
+             type='pil',
+             label='DeepGaze Saliency Map'
+         )
+         deepgaze_heatmap = gr.Image(
+             type='pil',
+             label='DeepGaze Saliency Ranking Heatmap'
+         )
+     submit_button_component = gr.Button(
+         value='Submit',
+         scale=1,
+         variant='primary'
+     )
+     gr.Examples(
+         fn=process_image,
+         examples=IMAGE_EXAMPLES,
+         inputs=[
+             input_image_component,
+             grid_size_Component,
+         ],
+         outputs=[
+             itti_saliency_map,
+             itti_heatmap,
+             deepgaze_saliency_map,
+             deepgaze_heatmap,
+         ]
+     )
+
+     submit_button_component.click(
+         fn=process_image,
+         inputs=[
+             input_image_component,
+             grid_size_Component,
+         ],
+         outputs=[
+             itti_saliency_map,
+             itti_heatmap,
+             deepgaze_saliency_map,
+             deepgaze_heatmap,
+         ]
+     )
+
+ demo.launch(debug=False, show_error=True, max_threads=1)
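For reference, the same SaRa calls can be exercised outside the Gradio app; a minimal sketch, assuming a local image file input.png and the return_sara / return_saliency behaviour used in app.py above:

    import cv2
    import SaRa.saraRC1 as sara

    image = cv2.imread('input.png')                 # hypothetical input path
    sara.reset()
    sara_info = sara.return_sara(image.copy(), 9, 'itti', mode=1)
    heatmap = sara_info[0]                          # ranked-grid heatmap, as used in detect_and_annotate
    saliency = sara.return_saliency(image, generator='itti')
    cv2.imwrite('sara_heatmap.png', heatmap)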
deepgaze_pytorch/__init__.py ADDED
@@ -0,0 +1,3 @@
+ from .deepgaze1 import DeepGazeI
+ from .deepgaze2e import DeepGazeIIE
+ from .deepgaze3 import DeepGazeIII
deepgaze_pytorch/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (289 Bytes). View file
 
deepgaze_pytorch/__pycache__/deepgaze1.cpython-39.pyc ADDED
Binary file (2.12 kB). View file
 
deepgaze_pytorch/__pycache__/deepgaze2e.cpython-39.pyc ADDED
Binary file (4.29 kB). View file
 
deepgaze_pytorch/__pycache__/deepgaze3.cpython-39.pyc ADDED
Binary file (3.48 kB). View file
 
deepgaze_pytorch/__pycache__/layers.cpython-39.pyc ADDED
Binary file (13.7 kB). View file
 
deepgaze_pytorch/__pycache__/modules.cpython-39.pyc ADDED
Binary file (10.4 kB). View file
 
deepgaze_pytorch/data.py ADDED
@@ -0,0 +1,403 @@
1
+ from collections import Counter
2
+ import io
3
+ import os
4
+ import pickle
5
+ import random
6
+
7
+ from boltons.iterutils import chunked
8
+ import lmdb
9
+ import numpy as np
10
+ from PIL import Image
11
+ import pysaliency
12
+ from pysaliency.datasets import create_subset
13
+ from pysaliency.utils import remove_trailing_nans
14
+ import torch
15
+ from tqdm import tqdm
16
+
17
+
18
+ def ensure_color_image(image):
19
+ if len(image.shape) == 2:
20
+ return np.dstack([image, image, image])
21
+ return image
22
+
23
+
24
+ def x_y_to_sparse_indices(xs, ys):
25
+ # Converts list of x and y coordinates into indices and values for sparse mask
26
+ x_inds = []
27
+ y_inds = []
28
+ values = []
29
+ pair_inds = {}
30
+
31
+ for x, y in zip(xs, ys):
32
+ key = (x, y)
33
+ if key not in pair_inds:
34
+ x_inds.append(x)
35
+ y_inds.append(y)
36
+ pair_inds[key] = len(x_inds) - 1
37
+ values.append(1)
38
+ else:
39
+ values[pair_inds[key]] += 1
40
+
41
+ return np.array([y_inds, x_inds]), values
42
+
43
+
44
+ class ImageDataset(torch.utils.data.Dataset):
45
+ def __init__(
46
+ self,
47
+ stimuli,
48
+ fixations,
49
+ centerbias_model=None,
50
+ lmdb_path=None,
51
+ transform=None,
52
+ cached=None,
53
+ average='fixation'
54
+ ):
55
+ self.stimuli = stimuli
56
+ self.fixations = fixations
57
+ self.centerbias_model = centerbias_model
58
+ self.lmdb_path = lmdb_path
59
+ self.transform = transform
60
+ self.average = average
61
+
62
+ # cache only short dataset
63
+ if cached is None:
64
+ cached = len(self.stimuli) < 100
65
+
66
+ cache_fixation_data = cached
67
+
68
+ if lmdb_path is not None:
69
+ _export_dataset_to_lmdb(stimuli, centerbias_model, lmdb_path)
70
+ self.lmdb_env = lmdb.open(lmdb_path, subdir=os.path.isdir(lmdb_path),
71
+ readonly=True, lock=False,
72
+ readahead=False, meminit=False
73
+ )
74
+ cached = False
75
+ cache_fixation_data = True
76
+ else:
77
+ self.lmdb_env = None
78
+
79
+ self.cached = cached
80
+ if cached:
81
+ self._cache = {}
82
+ self.cache_fixation_data = cache_fixation_data
83
+ if cache_fixation_data:
84
+ print("Populating fixations cache")
85
+ self._xs_cache = {}
86
+ self._ys_cache = {}
87
+
88
+ for x, y, n in zip(self.fixations.x_int, self.fixations.y_int, tqdm(self.fixations.n)):
89
+ self._xs_cache.setdefault(n, []).append(x)
90
+ self._ys_cache.setdefault(n, []).append(y)
91
+
92
+ for key in list(self._xs_cache):
93
+ self._xs_cache[key] = np.array(self._xs_cache[key], dtype=int)
94
+ for key in list(self._ys_cache):
95
+ self._ys_cache[key] = np.array(self._ys_cache[key], dtype=int)
96
+
97
+ def get_shapes(self):
98
+ return list(self.stimuli.sizes)
99
+
100
+ def _get_image_data(self, n):
101
+ if self.lmdb_env:
102
+ image, centerbias_prediction = _get_image_data_from_lmdb(self.lmdb_env, n)
103
+ else:
104
+ image = np.array(self.stimuli.stimuli[n])
105
+ centerbias_prediction = self.centerbias_model.log_density(image)
106
+
107
+ image = ensure_color_image(image).astype(np.float32)
108
+ image = image.transpose(2, 0, 1)
109
+
110
+ return image, centerbias_prediction
111
+
112
+ def __getitem__(self, key):
113
+ if not self.cached or key not in self._cache:
114
+
115
+ image, centerbias_prediction = self._get_image_data(key)
116
+ centerbias_prediction = centerbias_prediction.astype(np.float32)
117
+
118
+ if self.cache_fixation_data and self.cached:
119
+ xs = self._xs_cache.pop(key)
120
+ ys = self._ys_cache.pop(key)
121
+ elif self.cache_fixation_data and not self.cached:
122
+ xs = self._xs_cache[key]
123
+ ys = self._ys_cache[key]
124
+ else:
125
+ inds = self.fixations.n == key
126
+ xs = np.array(self.fixations.x_int[inds], dtype=int)
127
+ ys = np.array(self.fixations.y_int[inds], dtype=int)
128
+
129
+ data = {
130
+ "image": image,
131
+ "x": xs,
132
+ "y": ys,
133
+ "centerbias": centerbias_prediction,
134
+ }
135
+
136
+ if self.average == 'image':
137
+ data['weight'] = 1.0
138
+ else:
139
+ data['weight'] = float(len(xs))
140
+
141
+ if self.cached:
142
+ self._cache[key] = data
143
+ else:
144
+ data = self._cache[key]
145
+
146
+ if self.transform is not None:
147
+ return self.transform(dict(data))
148
+
149
+ return data
150
+
151
+ def __len__(self):
152
+ return len(self.stimuli)
153
+
154
+
155
+ class FixationDataset(torch.utils.data.Dataset):
156
+ def __init__(
157
+ self,
158
+ stimuli, fixations,
159
+ centerbias_model=None,
160
+ lmdb_path=None,
161
+ transform=None,
162
+ included_fixations=-2,
163
+ allow_missing_fixations=False,
164
+ average='fixation',
165
+ cache_image_data=False,
166
+ ):
167
+ self.stimuli = stimuli
168
+ self.fixations = fixations
169
+ self.centerbias_model = centerbias_model
170
+ self.lmdb_path = lmdb_path
171
+
172
+ if lmdb_path is not None:
173
+ _export_dataset_to_lmdb(stimuli, centerbias_model, lmdb_path)
174
+ self.lmdb_env = lmdb.open(lmdb_path, subdir=os.path.isdir(lmdb_path),
175
+ readonly=True, lock=False,
176
+ readahead=False, meminit=False
177
+ )
178
+ cache_image_data=False
179
+ else:
180
+ self.lmdb_env = None
181
+
182
+ self.transform = transform
183
+ self.average = average
184
+
185
+ self._shapes = None
186
+
187
+ if isinstance(included_fixations, int):
188
+ if included_fixations < 0:
189
+ included_fixations = [-1 - i for i in range(-included_fixations)]
190
+ else:
191
+ raise NotImplementedError()
192
+
193
+ self.included_fixations = included_fixations
194
+ self.allow_missing_fixations = allow_missing_fixations
195
+ self.fixation_counts = Counter(fixations.n)
196
+
197
+ self.cache_image_data = cache_image_data
198
+
199
+ if self.cache_image_data:
200
+ self.image_data_cache = {}
201
+
202
+ print("Populating image cache")
203
+ for n in tqdm(range(len(self.stimuli))):
204
+ self.image_data_cache[n] = self._get_image_data(n)
205
+
206
+ def get_shapes(self):
207
+ if self._shapes is None:
208
+ shapes = list(self.stimuli.sizes)
209
+ self._shapes = [shapes[n] for n in self.fixations.n]
210
+
211
+ return self._shapes
212
+
213
+ def _get_image_data(self, n):
214
+ if self.lmdb_path:
215
+ return _get_image_data_from_lmdb(self.lmdb_env, n)
216
+ image = np.array(self.stimuli.stimuli[n])
217
+ centerbias_prediction = self.centerbias_model.log_density(image)
218
+
219
+ image = ensure_color_image(image).astype(np.float32)
220
+ image = image.transpose(2, 0, 1)
221
+
222
+ return image, centerbias_prediction
223
+
224
+ def __getitem__(self, key):
225
+ n = self.fixations.n[key]
226
+
227
+ if self.cache_image_data:
228
+ image, centerbias_prediction = self.image_data_cache[n]
229
+ else:
230
+ image, centerbias_prediction = self._get_image_data(n)
231
+
232
+ centerbias_prediction = centerbias_prediction.astype(np.float32)
233
+
234
+ x_hist = remove_trailing_nans(self.fixations.x_hist[key])
235
+ y_hist = remove_trailing_nans(self.fixations.y_hist[key])
236
+
237
+ if self.allow_missing_fixations:
238
+ _x_hist = []
239
+ _y_hist = []
240
+ for fixation_index in self.included_fixations:
241
+ if fixation_index < -len(x_hist):
242
+ _x_hist.append(np.nan)
243
+ _y_hist.append(np.nan)
244
+ else:
245
+ _x_hist.append(x_hist[fixation_index])
246
+ _y_hist.append(y_hist[fixation_index])
247
+ x_hist = np.array(_x_hist)
248
+ y_hist = np.array(_y_hist)
249
+ else:
250
+ print("Not missing")
251
+ x_hist = x_hist[self.included_fixations]
252
+ y_hist = y_hist[self.included_fixations]
253
+
254
+ data = {
255
+ "image": image,
256
+ "x": np.array([self.fixations.x_int[key]], dtype=int),
257
+ "y": np.array([self.fixations.y_int[key]], dtype=int),
258
+ "x_hist": x_hist,
259
+ "y_hist": y_hist,
260
+ "centerbias": centerbias_prediction,
261
+ }
262
+
263
+ if self.average == 'image':
264
+ data['weight'] = 1.0 / self.fixation_counts[n]
265
+ else:
266
+ data['weight'] = 1.0
267
+
268
+ if self.transform is not None:
269
+ return self.transform(data)
270
+
271
+ return data
272
+
273
+ def __len__(self):
274
+ return len(self.fixations)
275
+
276
+
277
+ class FixationMaskTransform(object):
278
+ def __init__(self, sparse=True):
279
+ super().__init__()
280
+ self.sparse = sparse
281
+
282
+ def __call__(self, item):
283
+ shape = torch.Size([item['image'].shape[1], item['image'].shape[2]])
284
+ x = item.pop('x')
285
+ y = item.pop('y')
286
+
287
+ # inds, values = x_y_to_sparse_indices(x, y)
288
+ inds = np.array([y, x])
289
+ values = np.ones(len(y), dtype=int)
290
+
291
+ mask = torch.sparse.IntTensor(torch.tensor(inds), torch.tensor(values), shape)
292
+ mask = mask.coalesce()
293
+ # sparse tensors don't work with workers...
294
+ if not self.sparse:
295
+ mask = mask.to_dense()
296
+
297
+ item['fixation_mask'] = mask
298
+
299
+ return item
300
+
301
+
302
+ class ImageDatasetSampler(torch.utils.data.Sampler):
303
+ def __init__(self, data_source, batch_size=1, ratio_used=1.0, shuffle=True):
304
+ self.ratio_used = ratio_used
305
+ self.shuffle = shuffle
306
+
307
+ shapes = data_source.get_shapes()
308
+ unique_shapes = sorted(set(shapes))
309
+
310
+ shape_indices = [[] for shape in unique_shapes]
311
+
312
+ for k, shape in enumerate(shapes):
313
+ shape_indices[unique_shapes.index(shape)].append(k)
314
+
315
+ if self.shuffle:
316
+ for indices in shape_indices:
317
+ random.shuffle(indices)
318
+
319
+ self.batches = sum([chunked(indices, size=batch_size) for indices in shape_indices], [])
320
+
321
+ def __iter__(self):
322
+ if self.shuffle:
323
+ indices = torch.randperm(len(self.batches))
324
+ else:
325
+ indices = range(len(self.batches))
326
+
327
+ if self.ratio_used < 1.0:
328
+ indices = indices[:int(self.ratio_used * len(indices))]
329
+
330
+ return iter(self.batches[i] for i in indices)
331
+
332
+ def __len__(self):
333
+ return int(self.ratio_used * len(self.batches))
334
+
335
+
336
+ def _export_dataset_to_lmdb(stimuli: pysaliency.FileStimuli, centerbias_model: pysaliency.Model, lmdb_path, write_frequency=100):
337
+ lmdb_path = os.path.expanduser(lmdb_path)
338
+ isdir = os.path.isdir(lmdb_path)
339
+
340
+ print("Generate LMDB to %s" % lmdb_path)
341
+ db = lmdb.open(lmdb_path, subdir=isdir,
342
+ map_size=1099511627776 * 2, readonly=False,
343
+ meminit=False, map_async=True)
344
+
345
+ txn = db.begin(write=True)
346
+ for idx, stimulus in enumerate(tqdm(stimuli)):
347
+ key = u'{}'.format(idx).encode('ascii')
348
+
349
+ previous_data = txn.get(key)
350
+ if previous_data:
351
+ continue
352
+
353
+ #timulus_data = stimulus.stimulus_data
354
+ stimulus_filename = stimuli.filenames[idx]
355
+ centerbias = centerbias_model.log_density(stimulus)
356
+
357
+ txn.put(
358
+ key,
359
+ _encode_filestimulus_item(stimulus_filename, centerbias)
360
+ )
361
+ if idx % write_frequency == 0:
362
+ #print("[%d/%d]" % (idx, len(stimuli)))
363
+ #print("stimulus ids", len(stimuli.stimulus_ids._cache))
364
+ #print("stimuli.cached", stimuli.cached)
365
+ #print("stimuli", len(stimuli.stimuli._cache))
366
+ #print("centerbias", len(centerbias_model._cache._cache))
367
+ txn.commit()
368
+ txn = db.begin(write=True)
369
+
370
+ # finish iterating through dataset
371
+ txn.commit()
372
+ #keys = [u'{}'.format(k).encode('ascii') for k in range(idx + 1)]
373
+ #with db.begin(write=True) as txn:
374
+ # txn.put(b'__keys__', dumps_pyarrow(keys))
375
+ # txn.put(b'__len__', dumps_pyarrow(len(keys)))
376
+
377
+ print("Flushing database ...")
378
+ db.sync()
379
+ db.close()
380
+
381
+
382
+ def _encode_filestimulus_item(filename, centerbias):
383
+ with open(filename, 'rb') as f:
384
+ image_bytes = f.read()
385
+
386
+ buffer = io.BytesIO()
387
+ pickle.dump({'image': image_bytes, 'centerbias': centerbias}, buffer)
388
+ buffer.seek(0)
389
+ return buffer.read()
390
+
391
+
392
+ def _get_image_data_from_lmdb(lmdb_env, n):
393
+ key = '{}'.format(n).encode('ascii')
394
+ with lmdb_env.begin(write=False) as txn:
395
+ byteflow = txn.get(key)
396
+ data = pickle.loads(byteflow)
397
+ buffer = io.BytesIO(data['image'])
398
+ buffer.seek(0)
399
+ image = np.array(Image.open(buffer).convert('RGB'))
400
+ centerbias_prediction = data['centerbias']
401
+ image = image.transpose(2, 0, 1)
402
+
403
+ return image, centerbias_prediction
deepgaze_pytorch/deepgaze1.py ADDED
@@ -0,0 +1,42 @@
+ from collections import OrderedDict
+
+ import torch
+ import torch.nn as nn
+
+ from torch.utils import model_zoo
+
+ from .features.alexnet import RGBalexnet
+ from .modules import FeatureExtractor, Finalizer, DeepGazeII as TorchDeepGazeII
+
+
+ class DeepGazeI(TorchDeepGazeII):
+     """DeepGaze I model
+
+     Please note that this version of DeepGaze I is not exactly the one from the original paper.
+     The original model used caffe for AlexNet and theano for the linear readout and was trained using the SFO optimizer.
+     Here, we use the torch implementation of AlexNet (without any adaptations), which doesn't use the two-stream architecture,
+     and the DeepGaze II torch implementation with a simple linear readout network.
+     The model has been retrained with Adam, but still on the same dataset (all images of MIT1003 which are of size 1024x768).
+     Also, we don't use the sparsity penalty anymore.
+
+     Reference:
+     Kümmerer, M., Theis, L., & Bethge, M. (2015). Deep Gaze I: Boosting Saliency Prediction with Feature Maps Trained on ImageNet. ICLR Workshop Track. http://arxiv.org/abs/1411.1045
+     """
+     def __init__(self, pretrained=True):
+         features = RGBalexnet()
+         feature_extractor = FeatureExtractor(features, ['1.features.10'])
+
+         readout_network = nn.Sequential(OrderedDict([
+             ('conv0', nn.Conv2d(256, 1, (1, 1), bias=False)),
+         ]))
+
+         super().__init__(
+             features=feature_extractor,
+             readout_network=readout_network,
+             downsample=2,
+             readout_factor=4,
+             saliency_map_factor=4,
+         )
+
+         if pretrained:
+             self.load_state_dict(model_zoo.load_url('https://github.com/matthias-k/DeepGaze/releases/download/v1.01/deepgaze1.pth', map_location=torch.device('cpu')))
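A minimal usage sketch for the model defined above; the forward signature (a BCHW image tensor plus a log-density centerbias tensor of the same spatial size) is assumed from the upstream DeepGaze repository:

    import numpy as np
    import torch
    from deepgaze_pytorch import DeepGazeI

    model = DeepGazeI(pretrained=True)
    image = np.random.randint(0, 255, (768, 1024, 3), dtype=np.uint8)  # stand-in image
    image_tensor = torch.tensor([image.transpose(2, 0, 1)]).float()
    centerbias = torch.zeros((1, 768, 1024))                           # uniform log centerbias
    with torch.no_grad():
        log_density = model(image_tensor, centerbias)                  # log p(fixation location | image)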
deepgaze_pytorch/deepgaze2e.py ADDED
@@ -0,0 +1,151 @@
1
+ from collections import OrderedDict
2
+ import importlib
3
+ import os
4
+
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+
10
+ from torch.utils import model_zoo
11
+
12
+ from .modules import FeatureExtractor, Finalizer, DeepGazeIIIMixture, MixtureModel
13
+
14
+ from .layers import (
15
+ Conv2dMultiInput,
16
+ LayerNorm,
17
+ LayerNormMultiInput,
18
+ Bias,
19
+ )
20
+
21
+
22
+ BACKBONES = [
23
+ {
24
+ 'type': 'deepgaze_pytorch.features.shapenet.RGBShapeNetC',
25
+ 'used_features': [
26
+ '1.module.layer3.0.conv2',
27
+ '1.module.layer3.3.conv2',
28
+ '1.module.layer3.5.conv1',
29
+ '1.module.layer3.5.conv2',
30
+ '1.module.layer4.1.conv2',
31
+ '1.module.layer4.2.conv2',
32
+ ],
33
+ 'channels': 2048,
34
+ },
35
+ {
36
+ 'type': 'deepgaze_pytorch.features.efficientnet.RGBEfficientNetB5',
37
+ 'used_features': [
38
+ '1._blocks.24._depthwise_conv',
39
+ '1._blocks.26._depthwise_conv',
40
+ '1._blocks.35._project_conv',
41
+ ],
42
+ 'channels': 2416,
43
+ },
44
+ {
45
+ 'type': 'deepgaze_pytorch.features.densenet.RGBDenseNet201',
46
+ 'used_features': [
47
+ '1.features.denseblock4.denselayer32.norm1',
48
+ '1.features.denseblock4.denselayer32.conv1',
49
+ '1.features.denseblock4.denselayer31.conv2',
50
+ ],
51
+ 'channels': 2048,
52
+ },
53
+ {
54
+ 'type': 'deepgaze_pytorch.features.resnext.RGBResNext50',
55
+ 'used_features': [
56
+ '1.layer3.5.conv1',
57
+ '1.layer3.5.conv2',
58
+ '1.layer3.4.conv2',
59
+ '1.layer4.2.conv2',
60
+ ],
61
+ 'channels': 2560,
62
+ },
63
+ ]
64
+
65
+
66
+ def build_saliency_network(input_channels):
67
+ return nn.Sequential(OrderedDict([
68
+ ('layernorm0', LayerNorm(input_channels)),
69
+ ('conv0', nn.Conv2d(input_channels, 8, (1, 1), bias=False)),
70
+ ('bias0', Bias(8)),
71
+ ('softplus0', nn.Softplus()),
72
+
73
+ ('layernorm1', LayerNorm(8)),
74
+ ('conv1', nn.Conv2d(8, 16, (1, 1), bias=False)),
75
+ ('bias1', Bias(16)),
76
+ ('softplus1', nn.Softplus()),
77
+
78
+ ('layernorm2', LayerNorm(16)),
79
+ ('conv2', nn.Conv2d(16, 1, (1, 1), bias=False)),
80
+ ('bias2', Bias(1)),
81
+ ('softplus3', nn.Softplus()),
82
+ ]))
83
+
84
+
85
+ def build_fixation_selection_network():
86
+ return nn.Sequential(OrderedDict([
87
+ ('layernorm0', LayerNormMultiInput([1, 0])),
88
+ ('conv0', Conv2dMultiInput([1, 0], 128, (1, 1), bias=False)),
89
+ ('bias0', Bias(128)),
90
+ ('softplus0', nn.Softplus()),
91
+
92
+ ('layernorm1', LayerNorm(128)),
93
+ ('conv1', nn.Conv2d(128, 16, (1, 1), bias=False)),
94
+ ('bias1', Bias(16)),
95
+ ('softplus1', nn.Softplus()),
96
+
97
+ ('conv2', nn.Conv2d(16, 1, (1, 1), bias=False)),
98
+ ]))
99
+
100
+
101
+ def build_deepgaze_mixture(backbone_config, components=10):
102
+ feature_class = import_class(backbone_config['type'])
103
+ features = feature_class()
104
+
105
+ feature_extractor = FeatureExtractor(features, backbone_config['used_features'])
106
+
107
+ saliency_networks = []
108
+ scanpath_networks = []
109
+ fixation_selection_networks = []
110
+ finalizers = []
111
+ for component in range(components):
112
+ saliency_network = build_saliency_network(backbone_config['channels'])
113
+ fixation_selection_network = build_fixation_selection_network()
114
+
115
+ saliency_networks.append(saliency_network)
116
+ scanpath_networks.append(None)
117
+ fixation_selection_networks.append(fixation_selection_network)
118
+ finalizers.append(Finalizer(sigma=8.0, learn_sigma=True, saliency_map_factor=2))
119
+
120
+ return DeepGazeIIIMixture(
121
+ features=feature_extractor,
122
+ saliency_networks=saliency_networks,
123
+ scanpath_networks=scanpath_networks,
124
+ fixation_selection_networks=fixation_selection_networks,
125
+ finalizers=finalizers,
126
+ downsample=2,
127
+ readout_factor=16,
128
+ saliency_map_factor=2,
129
+ included_fixations=[],
130
+ )
131
+
132
+
133
+ class DeepGazeIIE(MixtureModel):
134
+ """DeepGazeIIE model
135
+
136
+ :note
137
+ See Linardos, A., Kümmerer, M., Press, O., & Bethge, M. (2021). Calibrated prediction in and out-of-domain for state-of-the-art saliency modeling. ArXiv:2105.12441 [Cs], http://arxiv.org/abs/2105.12441
138
+ """
139
+ def __init__(self, pretrained=True):
140
+ # we average over 3 instances per backbone, each instance has 10 crossvalidation folds
141
+ backbone_models = [build_deepgaze_mixture(backbone_config, components=3 * 10) for backbone_config in BACKBONES]
142
+ super().__init__(backbone_models)
143
+
144
+ if pretrained:
145
+ self.load_state_dict(model_zoo.load_url('https://github.com/matthias-k/DeepGaze/releases/download/v1.0.0/deepgaze2e.pth', map_location=torch.device('cpu')))
146
+
147
+
148
+ def import_class(name):
149
+ module_name, class_name = name.rsplit('.', 1)
150
+ module = importlib.import_module(module_name)
151
+ return getattr(module, class_name)
deepgaze_pytorch/deepgaze3.py ADDED
@@ -0,0 +1,110 @@
1
+ from collections import OrderedDict
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+
7
+ from torch.utils import model_zoo
8
+
9
+ from .features.densenet import RGBDenseNet201
10
+ from .modules import FeatureExtractor, Finalizer, DeepGazeIIIMixture
11
+ from .layers import FlexibleScanpathHistoryEncoding
12
+
13
+ from .layers import (
14
+ Conv2dMultiInput,
15
+ LayerNorm,
16
+ LayerNormMultiInput,
17
+ Bias,
18
+ )
19
+
20
+
21
+ def build_saliency_network(input_channels):
22
+ return nn.Sequential(OrderedDict([
23
+ ('layernorm0', LayerNorm(input_channels)),
24
+ ('conv0', nn.Conv2d(input_channels, 8, (1, 1), bias=False)),
25
+ ('bias0', Bias(8)),
26
+ ('softplus0', nn.Softplus()),
27
+
28
+ ('layernorm1', LayerNorm(8)),
29
+ ('conv1', nn.Conv2d(8, 16, (1, 1), bias=False)),
30
+ ('bias1', Bias(16)),
31
+ ('softplus1', nn.Softplus()),
32
+
33
+ ('layernorm2', LayerNorm(16)),
34
+ ('conv2', nn.Conv2d(16, 1, (1, 1), bias=False)),
35
+ ('bias2', Bias(1)),
36
+ ('softplus2', nn.Softplus()),
37
+ ]))
38
+
39
+
40
+ def build_scanpath_network():
41
+ return nn.Sequential(OrderedDict([
42
+ ('encoding0', FlexibleScanpathHistoryEncoding(in_fixations=4, channels_per_fixation=3, out_channels=128, kernel_size=[1, 1], bias=True)),
43
+ ('softplus0', nn.Softplus()),
44
+
45
+ ('layernorm1', LayerNorm(128)),
46
+ ('conv1', nn.Conv2d(128, 16, (1, 1), bias=False)),
47
+ ('bias1', Bias(16)),
48
+ ('softplus1', nn.Softplus()),
49
+ ]))
50
+
51
+
52
+ def build_fixation_selection_network():
53
+ return nn.Sequential(OrderedDict([
54
+ ('layernorm0', LayerNormMultiInput([1, 16])),
55
+ ('conv0', Conv2dMultiInput([1, 16], 128, (1, 1), bias=False)),
56
+ ('bias0', Bias(128)),
57
+ ('softplus0', nn.Softplus()),
58
+
59
+ ('layernorm1', LayerNorm(128)),
60
+ ('conv1', nn.Conv2d(128, 16, (1, 1), bias=False)),
61
+ ('bias1', Bias(16)),
62
+ ('softplus1', nn.Softplus()),
63
+
64
+ ('conv2', nn.Conv2d(16, 1, (1, 1), bias=False)),
65
+ ]))
66
+
67
+
68
+ class DeepGazeIII(DeepGazeIIIMixture):
69
+ """DeepGazeIII model
70
+
71
+ :note
72
+ See Kümmerer, M., Bethge, M., & Wallis, T.S.A. (2022). DeepGaze III: Modeling free-viewing human scanpaths with deep learning. Journal of Vision 2022, https://doi.org/10.1167/jov.22.5.7
73
+ """
74
+ def __init__(self, pretrained=True):
75
+ features = RGBDenseNet201()
76
+
77
+ feature_extractor = FeatureExtractor(features, [
78
+ '1.features.denseblock4.denselayer32.norm1',
79
+ '1.features.denseblock4.denselayer32.conv1',
80
+ '1.features.denseblock4.denselayer31.conv2',
81
+ ])
82
+
83
+ saliency_networks = []
84
+ scanpath_networks = []
85
+ fixation_selection_networks = []
86
+ finalizers = []
87
+ for component in range(10):
88
+ saliency_network = build_saliency_network(2048)
89
+ scanpath_network = build_scanpath_network()
90
+ fixation_selection_network = build_fixation_selection_network()
91
+
92
+ saliency_networks.append(saliency_network)
93
+ scanpath_networks.append(scanpath_network)
94
+ fixation_selection_networks.append(fixation_selection_network)
95
+ finalizers.append(Finalizer(sigma=8.0, learn_sigma=True, saliency_map_factor=4))
96
+
97
+ super().__init__(
98
+ features=feature_extractor,
99
+ saliency_networks=saliency_networks,
100
+ scanpath_networks=scanpath_networks,
101
+ fixation_selection_networks=fixation_selection_networks,
102
+ finalizers=finalizers,
103
+ downsample=2,
104
+ readout_factor=4,
105
+ saliency_map_factor=4,
106
+ included_fixations=[-1, -2, -3, -4]
107
+ )
108
+
109
+ if pretrained:
110
+ self.load_state_dict(model_zoo.load_url('https://github.com/matthias-k/DeepGaze/releases/download/v1.1.0/deepgaze3.pth', map_location=torch.device('cpu')))
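Since included_fixations=[-1, -2, -3, -4] above, DeepGazeIII also conditions on the previous four fixations; the exact forward signature sketched below (image, centerbias, x_hist, y_hist) is an assumption based on the upstream repository:

    import torch
    from deepgaze_pytorch import DeepGazeIII

    model = DeepGazeIII(pretrained=True)
    image_tensor = torch.rand(1, 3, 768, 1024) * 255        # stand-in image batch
    centerbias = torch.zeros(1, 768, 1024)                   # uniform log centerbias
    x_hist = torch.tensor([[512.0, 500.0, 480.0, 450.0]])    # previous fixation x positions (assumed layout)
    y_hist = torch.tensor([[384.0, 390.0, 400.0, 420.0]])    # previous fixation y positions (assumed layout)
    with torch.no_grad():
        log_density = model(image_tensor, centerbias, x_hist, y_hist)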
deepgaze_pytorch/features/__init__.py ADDED
File without changes
deepgaze_pytorch/features/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (164 Bytes). View file
 
deepgaze_pytorch/features/__pycache__/alexnet.cpython-39.pyc ADDED
Binary file (836 Bytes). View file
 
deepgaze_pytorch/features/__pycache__/densenet.cpython-39.pyc ADDED
Binary file (852 Bytes). View file
 
deepgaze_pytorch/features/__pycache__/efficientnet.cpython-39.pyc ADDED
Binary file (1.25 kB). View file
 
deepgaze_pytorch/features/__pycache__/normalizer.cpython-39.pyc ADDED
Binary file (1.1 kB). View file
 
deepgaze_pytorch/features/__pycache__/resnext.cpython-39.pyc ADDED
Binary file (1.23 kB). View file
 
deepgaze_pytorch/features/__pycache__/shapenet.cpython-39.pyc ADDED
Binary file (3.7 kB). View file
 
deepgaze_pytorch/features/alexnet.py ADDED
@@ -0,0 +1,18 @@
+ from collections import OrderedDict
+
+ import numpy as np
+ import torch
+ import torch.nn as nn
+ import torchvision
+
+ from .normalizer import Normalizer
+
+
+ class RGBalexnet(nn.Sequential):
+     def __init__(self):
+         super(RGBalexnet, self).__init__()
+         self.model = torch.hub.load('pytorch/vision:v0.6.0', 'alexnet', pretrained=True)
+         self.normalizer = Normalizer()
+         super(RGBalexnet, self).__init__(self.normalizer, self.model)
deepgaze_pytorch/features/bagnet.py ADDED
@@ -0,0 +1,192 @@
1
+ """
2
+ This code is adapted from: https://github.com/wielandbrendel/bag-of-local-features-models
3
+ """
4
+
5
+ import torch.nn as nn
6
+ import math
7
+ import torch
8
+ from collections import OrderedDict
9
+ from torch.utils import model_zoo
10
+
11
+ from .normalizer import Normalizer
12
+
13
+
14
+ import os
15
+ dir_path = os.path.dirname(os.path.realpath(__file__))
16
+
17
+ __all__ = ['bagnet9', 'bagnet17', 'bagnet33']
18
+
19
+ model_urls = {
20
+ 'bagnet9': 'https://bitbucket.org/wielandbrendel/bag-of-feature-pretrained-models/raw/249e8fa82c0913623a807d9d35eeab9da7dcc2a8/bagnet8-34f4ccd2.pth.tar',
21
+ 'bagnet17': 'https://bitbucket.org/wielandbrendel/bag-of-feature-pretrained-models/raw/249e8fa82c0913623a807d9d35eeab9da7dcc2a8/bagnet16-105524de.pth.tar',
22
+ 'bagnet33': 'https://bitbucket.org/wielandbrendel/bag-of-feature-pretrained-models/raw/249e8fa82c0913623a807d9d35eeab9da7dcc2a8/bagnet32-2ddd53ed.pth.tar',
23
+ }
24
+
25
+
26
+ class Bottleneck(nn.Module):
27
+ expansion = 4
28
+
29
+ def __init__(self, inplanes, planes, stride=1, downsample=None, kernel_size=1):
30
+ super(Bottleneck, self).__init__()
31
+ # print('Creating bottleneck with kernel size {} and stride {} with padding {}'.format(kernel_size, stride, (kernel_size - 1) // 2))
32
+ self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
33
+ self.bn1 = nn.BatchNorm2d(planes)
34
+ self.conv2 = nn.Conv2d(planes, planes, kernel_size=kernel_size, stride=stride,
35
+ padding=0, bias=False) # changed padding from (kernel_size - 1) // 2
36
+ self.bn2 = nn.BatchNorm2d(planes)
37
+ self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
38
+ self.bn3 = nn.BatchNorm2d(planes * 4)
39
+ self.relu = nn.ReLU(inplace=True)
40
+ self.downsample = downsample
41
+ self.stride = stride
42
+
43
+ def forward(self, x, **kwargs):
44
+ residual = x
45
+
46
+ out = self.conv1(x)
47
+ out = self.bn1(out)
48
+ out = self.relu(out)
49
+
50
+ out = self.conv2(out)
51
+ out = self.bn2(out)
52
+ out = self.relu(out)
53
+
54
+ out = self.conv3(out)
55
+ out = self.bn3(out)
56
+
57
+ if self.downsample is not None:
58
+ residual = self.downsample(x)
59
+
60
+ if residual.size(-1) != out.size(-1):
61
+ diff = residual.size(-1) - out.size(-1)
62
+ residual = residual[:,:,:-diff,:-diff]
63
+
64
+ out += residual
65
+ out = self.relu(out)
66
+
67
+ return out
68
+
69
+
70
+ class BagNet(nn.Module):
71
+
72
+ def __init__(self, block, layers, strides=[1, 2, 2, 2], kernel3=[0, 0, 0, 0], num_classes=1000, avg_pool=True):
73
+ self.inplanes = 64
74
+ super(BagNet, self).__init__()
75
+ self.conv1 = nn.Conv2d(3, 64, kernel_size=1, stride=1, padding=0,
76
+ bias=False)
77
+ self.conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=0,
78
+ bias=False)
79
+ self.bn1 = nn.BatchNorm2d(64, momentum=0.001)
80
+ self.relu = nn.ReLU(inplace=True)
81
+ self.layer1 = self._make_layer(block, 64, layers[0], stride=strides[0], kernel3=kernel3[0], prefix='layer1')
82
+ self.layer2 = self._make_layer(block, 128, layers[1], stride=strides[1], kernel3=kernel3[1], prefix='layer2')
83
+ self.layer3 = self._make_layer(block, 256, layers[2], stride=strides[2], kernel3=kernel3[2], prefix='layer3')
84
+ self.layer4 = self._make_layer(block, 512, layers[3], stride=strides[3], kernel3=kernel3[3], prefix='layer4')
85
+ self.avgpool = nn.AvgPool2d(1, stride=1)
86
+ self.fc = nn.Linear(512 * block.expansion, num_classes)
87
+ self.avg_pool = avg_pool
88
+ self.block = block
89
+
90
+ for m in self.modules():
91
+ if isinstance(m, nn.Conv2d):
92
+ n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
93
+ m.weight.data.normal_(0, math.sqrt(2. / n))
94
+ elif isinstance(m, nn.BatchNorm2d):
95
+ m.weight.data.fill_(1)
96
+ m.bias.data.zero_()
97
+
98
+ def _make_layer(self, block, planes, blocks, stride=1, kernel3=0, prefix=''):
99
+ downsample = None
100
+ if stride != 1 or self.inplanes != planes * block.expansion:
101
+ downsample = nn.Sequential(
102
+ nn.Conv2d(self.inplanes, planes * block.expansion,
103
+ kernel_size=1, stride=stride, bias=False),
104
+ nn.BatchNorm2d(planes * block.expansion),
105
+ )
106
+
107
+ layers = []
108
+ kernel = 1 if kernel3 == 0 else 3
109
+ layers.append(block(self.inplanes, planes, stride, downsample, kernel_size=kernel))
110
+ self.inplanes = planes * block.expansion
111
+ for i in range(1, blocks):
112
+ kernel = 1 if kernel3 <= i else 3
113
+ layers.append(block(self.inplanes, planes, kernel_size=kernel))
114
+
115
+ return nn.Sequential(*layers)
116
+
117
+ def forward(self, x):
118
+ x = self.conv1(x)
119
+ x = self.conv2(x)
120
+ x = self.bn1(x)
121
+ x = self.relu(x)
122
+
123
+ x = self.layer1(x)
124
+ x = self.layer2(x)
125
+ x = self.layer3(x)
126
+ x = self.layer4(x)
127
+
128
+ if self.avg_pool:
129
+ x = nn.AvgPool2d(x.size()[2], stride=1)(x)
130
+ x = x.view(x.size(0), -1)
131
+ x = self.fc(x)
132
+ else:
133
+ x = x.permute(0,2,3,1)
134
+ x = self.fc(x)
135
+
136
+ return x
137
+
138
+ def bagnet33(pretrained=False, strides=[2, 2, 2, 1], **kwargs):
139
+ """Constructs a Bagnet-33 model.
140
+
141
+ Args:
142
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
143
+ """
144
+ model = BagNet(Bottleneck, [3, 4, 6, 3], strides=strides, kernel3=[1,1,1,1], **kwargs)
145
+ if pretrained:
146
+ model.load_state_dict(model_zoo.load_url(model_urls['bagnet33']))
147
+ return model
148
+
149
+ def bagnet17(pretrained=False, strides=[2, 2, 2, 1], **kwargs):
150
+ """Constructs a Bagnet-17 model.
151
+
152
+ Args:
153
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
154
+ """
155
+ model = BagNet(Bottleneck, [3, 4, 6, 3], strides=strides, kernel3=[1,1,1,0], **kwargs)
156
+ if pretrained:
157
+ model.load_state_dict(model_zoo.load_url(model_urls['bagnet17']))
158
+ return model
159
+
160
+ def bagnet9(pretrained=False, strides=[2, 2, 2, 1], **kwargs):
161
+ """Constructs a Bagnet-9 model.
162
+
163
+ Args:
164
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
165
+ """
166
+ model = BagNet(Bottleneck, [3, 4, 6, 3], strides=strides, kernel3=[1,1,0,0], **kwargs)
167
+ if pretrained:
168
+ model.load_state_dict(model_zoo.load_url(model_urls['bagnet9']))
169
+ return model
170
+
171
+ # --- DeepGaze Adaptation ----
172
+
173
+
174
+
175
+
176
+ class RGBBagNet17(nn.Sequential):
177
+ def __init__(self):
178
+ super(RGBBagNet17, self).__init__()
179
+ self.bagnet = bagnet17(pretrained=True, avg_pool=False)
180
+ self.normalizer = Normalizer()
181
+ super(RGBBagNet17, self).__init__(self.normalizer, self.bagnet)
182
+
183
+
184
+ class RGBBagNet33(nn.Sequential):
185
+ def __init__(self):
186
+ super(RGBBagNet33, self).__init__()
187
+ self.bagnet = bagnet33(pretrained=True, avg_pool=False)
188
+ self.normalizer = Normalizer()
189
+ super(RGBBagNet33, self).__init__(self.normalizer, self.bagnet)
190
+
191
+
192
+
deepgaze_pytorch/features/densenet.py ADDED
@@ -0,0 +1,19 @@
+ from collections import OrderedDict
+
+ import numpy as np
+ import torch
+ import torch.nn as nn
+ import torchvision
+
+ from .normalizer import Normalizer
+
+
+ class RGBDenseNet201(nn.Sequential):
+     def __init__(self):
+         super(RGBDenseNet201, self).__init__()
+         self.densenet = torch.hub.load('pytorch/vision:v0.6.0', 'densenet201', pretrained=True)
+         self.normalizer = Normalizer()
+         super(RGBDenseNet201, self).__init__(self.normalizer, self.densenet)
deepgaze_pytorch/features/efficientnet.py ADDED
@@ -0,0 +1,31 @@
+ from collections import OrderedDict
+
+ import numpy as np
+ import torch
+ import torch.nn as nn
+ import torchvision
+
+ from .efficientnet_pytorch import EfficientNet
+
+ from .normalizer import Normalizer
+
+
+ class RGBEfficientNetB5(nn.Sequential):
+     def __init__(self):
+         super(RGBEfficientNetB5, self).__init__()
+         self.efficientnet = EfficientNet.from_pretrained('efficientnet-b5')
+         self.normalizer = Normalizer()
+         super(RGBEfficientNetB5, self).__init__(self.normalizer, self.efficientnet)
+
+
+ class RGBEfficientNetB7(nn.Sequential):
+     def __init__(self):
+         super(RGBEfficientNetB7, self).__init__()
+         self.efficientnet = EfficientNet.from_pretrained('efficientnet-b7')
+         self.normalizer = Normalizer()
+         super(RGBEfficientNetB7, self).__init__(self.normalizer, self.efficientnet)
deepgaze_pytorch/features/efficientnet_pytorch/__init__.py ADDED
@@ -0,0 +1,10 @@
+ __version__ = "0.6.3"
+ from .model import EfficientNet
+ from .utils import (
+     GlobalParams,
+     BlockArgs,
+     BlockDecoder,
+     efficientnet,
+     get_model_params,
+ )
deepgaze_pytorch/features/efficientnet_pytorch/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (383 Bytes). View file
 
deepgaze_pytorch/features/efficientnet_pytorch/__pycache__/model.cpython-39.pyc ADDED
Binary file (6.98 kB). View file
 
deepgaze_pytorch/features/efficientnet_pytorch/__pycache__/utils.cpython-39.pyc ADDED
Binary file (12.7 kB). View file
 
deepgaze_pytorch/features/efficientnet_pytorch/model.py ADDED
@@ -0,0 +1,229 @@
1
+ import torch
2
+ from torch import nn
3
+ from torch.nn import functional as F
4
+
5
+ from .utils import (
6
+ round_filters,
7
+ round_repeats,
8
+ drop_connect,
9
+ get_same_padding_conv2d,
10
+ get_model_params,
11
+ efficientnet_params,
12
+ load_pretrained_weights,
13
+ Swish,
14
+ MemoryEfficientSwish,
15
+ )
16
+
17
+ class MBConvBlock(nn.Module):
18
+ """
19
+ Mobile Inverted Residual Bottleneck Block
20
+
21
+ Args:
22
+ block_args (namedtuple): BlockArgs, see above
23
+ global_params (namedtuple): GlobalParam, see above
24
+
25
+ Attributes:
26
+ has_se (bool): Whether the block contains a Squeeze and Excitation layer.
27
+ """
28
+
29
+ def __init__(self, block_args, global_params):
30
+ super().__init__()
31
+ self._block_args = block_args
32
+ self._bn_mom = 1 - global_params.batch_norm_momentum
33
+ self._bn_eps = global_params.batch_norm_epsilon
34
+ self.has_se = (self._block_args.se_ratio is not None) and (0 < self._block_args.se_ratio <= 1)
35
+ self.id_skip = block_args.id_skip # skip connection and drop connect
36
+
37
+ # Get static or dynamic convolution depending on image size
38
+ Conv2d = get_same_padding_conv2d(image_size=global_params.image_size)
39
+
40
+ # Expansion phase
41
+ inp = self._block_args.input_filters # number of input channels
42
+ oup = self._block_args.input_filters * self._block_args.expand_ratio # number of output channels
43
+ if self._block_args.expand_ratio != 1:
44
+ self._expand_conv = Conv2d(in_channels=inp, out_channels=oup, kernel_size=1, bias=False)
45
+ self._bn0 = nn.BatchNorm2d(num_features=oup, momentum=self._bn_mom, eps=self._bn_eps)
46
+
47
+ # Depthwise convolution phase
48
+ k = self._block_args.kernel_size
49
+ s = self._block_args.stride
50
+ self._depthwise_conv = Conv2d(
51
+ in_channels=oup, out_channels=oup, groups=oup, # groups makes it depthwise
52
+ kernel_size=k, stride=s, bias=False)
53
+ self._bn1 = nn.BatchNorm2d(num_features=oup, momentum=self._bn_mom, eps=self._bn_eps)
54
+
55
+ # Squeeze and Excitation layer, if desired
56
+ if self.has_se:
57
+ num_squeezed_channels = max(1, int(self._block_args.input_filters * self._block_args.se_ratio))
58
+ self._se_reduce = Conv2d(in_channels=oup, out_channels=num_squeezed_channels, kernel_size=1)
59
+ self._se_expand = Conv2d(in_channels=num_squeezed_channels, out_channels=oup, kernel_size=1)
60
+
61
+ # Output phase
62
+ final_oup = self._block_args.output_filters
63
+ self._project_conv = Conv2d(in_channels=oup, out_channels=final_oup, kernel_size=1, bias=False)
64
+ self._bn2 = nn.BatchNorm2d(num_features=final_oup, momentum=self._bn_mom, eps=self._bn_eps)
65
+ self._swish = MemoryEfficientSwish()
66
+
67
+ def forward(self, inputs, drop_connect_rate=None):
68
+ """
69
+ :param inputs: input tensor
70
+ :param drop_connect_rate: drop connect rate (float, between 0 and 1)
71
+ :return: output of block
72
+ """
73
+
74
+ # Expansion and Depthwise Convolution
75
+ x = inputs
76
+ if self._block_args.expand_ratio != 1:
77
+ x = self._swish(self._bn0(self._expand_conv(inputs)))
78
+ x = self._swish(self._bn1(self._depthwise_conv(x)))
79
+
80
+ # Squeeze and Excitation
81
+ if self.has_se:
82
+ x_squeezed = F.adaptive_avg_pool2d(x, 1)
83
+ x_squeezed = self._se_expand(self._swish(self._se_reduce(x_squeezed)))
84
+ x = torch.sigmoid(x_squeezed) * x
85
+
86
+ x = self._bn2(self._project_conv(x))
87
+
88
+ # Skip connection and drop connect
89
+ input_filters, output_filters = self._block_args.input_filters, self._block_args.output_filters
90
+ if self.id_skip and self._block_args.stride == 1 and input_filters == output_filters:
91
+ if drop_connect_rate:
92
+ x = drop_connect(x, p=drop_connect_rate, training=self.training)
93
+ x = x + inputs # skip connection
94
+ return x
95
+
96
+ def set_swish(self, memory_efficient=True):
97
+ """Sets swish function as memory efficient (for training) or standard (for export)"""
98
+ self._swish = MemoryEfficientSwish() if memory_efficient else Swish()
99
+
100
+
101
+ class EfficientNet(nn.Module):
102
+ """
103
+ An EfficientNet model. Most easily loaded with the .from_name or .from_pretrained methods
104
+
105
+ Args:
106
+ blocks_args (list): A list of BlockArgs to construct blocks
107
+ global_params (namedtuple): A set of GlobalParams shared between blocks
108
+
109
+ Example:
110
+ model = EfficientNet.from_pretrained('efficientnet-b0')
111
+
112
+ """
113
+
114
+ def __init__(self, blocks_args=None, global_params=None):
115
+ super().__init__()
116
+ assert isinstance(blocks_args, list), 'blocks_args should be a list'
117
+ assert len(blocks_args) > 0, 'block args must be greater than 0'
118
+ self._global_params = global_params
119
+ self._blocks_args = blocks_args
120
+
121
+ # Get static or dynamic convolution depending on image size
122
+ Conv2d = get_same_padding_conv2d(image_size=global_params.image_size)
123
+
124
+ # Batch norm parameters
125
+ bn_mom = 1 - self._global_params.batch_norm_momentum
126
+ bn_eps = self._global_params.batch_norm_epsilon
127
+
128
+ # Stem
129
+ in_channels = 3 # rgb
130
+ out_channels = round_filters(32, self._global_params) # number of output channels
131
+ self._conv_stem = Conv2d(in_channels, out_channels, kernel_size=3, stride=2, bias=False)
132
+ self._bn0 = nn.BatchNorm2d(num_features=out_channels, momentum=bn_mom, eps=bn_eps)
133
+
134
+ # Build blocks
135
+ self._blocks = nn.ModuleList([])
136
+ for block_args in self._blocks_args:
137
+
138
+ # Update block input and output filters based on depth multiplier.
139
+ block_args = block_args._replace(
140
+ input_filters=round_filters(block_args.input_filters, self._global_params),
141
+ output_filters=round_filters(block_args.output_filters, self._global_params),
142
+ num_repeat=round_repeats(block_args.num_repeat, self._global_params)
143
+ )
144
+
145
+ # The first block needs to take care of stride and filter size increase.
146
+ self._blocks.append(MBConvBlock(block_args, self._global_params))
147
+ if block_args.num_repeat > 1:
148
+ block_args = block_args._replace(input_filters=block_args.output_filters, stride=1)
149
+ for _ in range(block_args.num_repeat - 1):
150
+ self._blocks.append(MBConvBlock(block_args, self._global_params))
151
+
152
+ # Head
153
+ in_channels = block_args.output_filters # output of final block
154
+ out_channels = round_filters(1280, self._global_params)
155
+ self._conv_head = Conv2d(in_channels, out_channels, kernel_size=1, bias=False)
156
+ self._bn1 = nn.BatchNorm2d(num_features=out_channels, momentum=bn_mom, eps=bn_eps)
157
+
158
+ # Final linear layer
159
+ self._avg_pooling = nn.AdaptiveAvgPool2d(1)
160
+ self._dropout = nn.Dropout(self._global_params.dropout_rate)
161
+ self._fc = nn.Linear(out_channels, self._global_params.num_classes)
162
+ self._swish = MemoryEfficientSwish()
163
+
164
+ def set_swish(self, memory_efficient=True):
165
+ """Sets swish function as memory efficient (for training) or standard (for export)"""
166
+ self._swish = MemoryEfficientSwish() if memory_efficient else Swish()
167
+ for block in self._blocks:
168
+ block.set_swish(memory_efficient)
169
+
170
+
171
+ def extract_features(self, inputs):
172
+ """ Returns output of the final convolution layer """
173
+
174
+ # Stem
175
+ x = self._swish(self._bn0(self._conv_stem(inputs)))
176
+
177
+ # Blocks
178
+ for idx, block in enumerate(self._blocks):
179
+ drop_connect_rate = self._global_params.drop_connect_rate
180
+ if drop_connect_rate:
181
+ drop_connect_rate *= float(idx) / len(self._blocks)
182
+ x = block(x, drop_connect_rate=drop_connect_rate)
183
+
184
+ # Head
185
+ x = self._swish(self._bn1(self._conv_head(x)))
186
+
187
+ return x
188
+
189
+ def forward(self, inputs):
190
+ """ Calls extract_features to extract features, applies final linear layer, and returns logits. """
191
+ bs = inputs.size(0)
192
+ # Convolution layers
193
+ x = self.extract_features(inputs)
194
+
195
+ # Pooling and final linear layer
196
+ x = self._avg_pooling(x)
197
+ x = x.view(bs, -1)
198
+ x = self._dropout(x)
199
+ x = self._fc(x)
200
+ return x
201
+
202
+ @classmethod
203
+ def from_name(cls, model_name, override_params=None):
204
+ cls._check_model_name_is_valid(model_name)
205
+ blocks_args, global_params = get_model_params(model_name, override_params)
206
+ return cls(blocks_args, global_params)
207
+
208
+ @classmethod
209
+ def from_pretrained(cls, model_name, advprop=False, num_classes=1000, in_channels=3):
210
+ model = cls.from_name(model_name, override_params={'num_classes': num_classes})
211
+ load_pretrained_weights(model, model_name, load_fc=(num_classes == 1000), advprop=advprop)
212
+ if in_channels != 3:
213
+ Conv2d = get_same_padding_conv2d(image_size = model._global_params.image_size)
214
+ out_channels = round_filters(32, model._global_params)
215
+ model._conv_stem = Conv2d(in_channels, out_channels, kernel_size=3, stride=2, bias=False)
216
+ return model
217
+
218
+ @classmethod
219
+ def get_image_size(cls, model_name):
220
+ cls._check_model_name_is_valid(model_name)
221
+ _, _, res, _ = efficientnet_params(model_name)
222
+ return res
223
+
224
+ @classmethod
225
+ def _check_model_name_is_valid(cls, model_name):
226
+ """ Validates model name. """
227
+ valid_models = ['efficientnet-b'+str(i) for i in range(9)]
228
+ if model_name not in valid_models:
229
+ raise ValueError('model_name should be one of: ' + ', '.join(valid_models))
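DeepGaze II-E reads intermediate activations from this EfficientNet (see the used_features entries in deepgaze2e.py above). A short sketch of loading the backbone on its own, assuming the pretrained weight download succeeds; the 456-pixel resolution comes from the efficientnet-b5 entry in efficientnet_params below:

    import torch
    from deepgaze_pytorch.features.efficientnet_pytorch import EfficientNet

    backbone = EfficientNet.from_pretrained('efficientnet-b5')
    backbone.eval()
    with torch.no_grad():
        features = backbone.extract_features(torch.rand(1, 3, 456, 456))  # output of the final conv layer
    print(features.shape)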
deepgaze_pytorch/features/efficientnet_pytorch/utils.py ADDED
@@ -0,0 +1,335 @@
1
+ """
2
+ This file contains helper functions for building the model and for loading model parameters.
3
+ These helper functions are built to mirror those in the official TensorFlow implementation.
4
+ """
5
+
6
+ import re
7
+ import math
8
+ import collections
9
+ from functools import partial
10
+ import torch
11
+ from torch import nn
12
+ from torch.nn import functional as F
13
+ from torch.utils import model_zoo
14
+
15
+ ########################################################################
16
+ ############### HELPERS FUNCTIONS FOR MODEL ARCHITECTURE ###############
17
+ ########################################################################
18
+
19
+
20
+ # Parameters for the entire model (stem, all blocks, and head)
21
+ GlobalParams = collections.namedtuple('GlobalParams', [
22
+ 'batch_norm_momentum', 'batch_norm_epsilon', 'dropout_rate',
23
+ 'num_classes', 'width_coefficient', 'depth_coefficient',
24
+ 'depth_divisor', 'min_depth', 'drop_connect_rate', 'image_size'])
25
+
26
+ # Parameters for an individual model block
27
+ BlockArgs = collections.namedtuple('BlockArgs', [
28
+ 'kernel_size', 'num_repeat', 'input_filters', 'output_filters',
29
+ 'expand_ratio', 'id_skip', 'stride', 'se_ratio'])
30
+
31
+ # Change namedtuple defaults
32
+ GlobalParams.__new__.__defaults__ = (None,) * len(GlobalParams._fields)
33
+ BlockArgs.__new__.__defaults__ = (None,) * len(BlockArgs._fields)
34
+
35
+
36
+ class SwishImplementation(torch.autograd.Function):
37
+ @staticmethod
38
+ def forward(ctx, i):
39
+ result = i * torch.sigmoid(i)
40
+ ctx.save_for_backward(i)
41
+ return result
42
+
43
+ @staticmethod
44
+ def backward(ctx, grad_output):
45
+ i = ctx.saved_variables[0]
46
+ sigmoid_i = torch.sigmoid(i)
47
+ return grad_output * (sigmoid_i * (1 + i * (1 - sigmoid_i)))
48
+
49
+
50
+ class MemoryEfficientSwish(nn.Module):
51
+ def forward(self, x):
52
+ return SwishImplementation.apply(x)
53
+
54
+ class Swish(nn.Module):
55
+ def forward(self, x):
56
+ return x * torch.sigmoid(x)
57
+
58
+
59
+ def round_filters(filters, global_params):
60
+ """ Calculate and round number of filters based on depth multiplier. """
61
+ multiplier = global_params.width_coefficient
62
+ if not multiplier:
63
+ return filters
64
+ divisor = global_params.depth_divisor
65
+ min_depth = global_params.min_depth
66
+ filters *= multiplier
67
+ min_depth = min_depth or divisor
68
+ new_filters = max(min_depth, int(filters + divisor / 2) // divisor * divisor)
69
+ if new_filters < 0.9 * filters: # prevent rounding by more than 10%
70
+ new_filters += divisor
71
+ return int(new_filters)
72
+
73
+
74
+ def round_repeats(repeats, global_params):
75
+ """ Round number of filters based on depth multiplier. """
76
+ multiplier = global_params.depth_coefficient
77
+ if not multiplier:
78
+ return repeats
79
+ return int(math.ceil(multiplier * repeats))
80
+
81
+
82
+ def drop_connect(inputs, p, training):
83
+ """ Drop connect. """
84
+ if not training: return inputs
85
+ batch_size = inputs.shape[0]
86
+ keep_prob = 1 - p
87
+ random_tensor = keep_prob
88
+ random_tensor += torch.rand([batch_size, 1, 1, 1], dtype=inputs.dtype, device=inputs.device)
89
+ binary_tensor = torch.floor(random_tensor)
90
+ output = inputs / keep_prob * binary_tensor
91
+ return output
92
+
93
+
94
+ def get_same_padding_conv2d(image_size=None):
95
+ """ Chooses static padding if you have specified an image size, and dynamic padding otherwise.
96
+ Static padding is necessary for ONNX exporting of models. """
97
+ if image_size is None:
98
+ return Conv2dDynamicSamePadding
99
+ else:
100
+ return partial(Conv2dStaticSamePadding, image_size=image_size)
101
+
102
+
103
+ class Conv2dDynamicSamePadding(nn.Conv2d):
104
+ """ 2D Convolutions like TensorFlow, for a dynamic image size """
105
+
106
+ def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, groups=1, bias=True):
107
+ super().__init__(in_channels, out_channels, kernel_size, stride, 0, dilation, groups, bias)
108
+ self.stride = self.stride if len(self.stride) == 2 else [self.stride[0]] * 2
109
+
110
+ def forward(self, x):
111
+ ih, iw = x.size()[-2:]
112
+ kh, kw = self.weight.size()[-2:]
113
+ sh, sw = self.stride
114
+ oh, ow = math.ceil(ih / sh), math.ceil(iw / sw)
115
+ pad_h = max((oh - 1) * self.stride[0] + (kh - 1) * self.dilation[0] + 1 - ih, 0)
116
+ pad_w = max((ow - 1) * self.stride[1] + (kw - 1) * self.dilation[1] + 1 - iw, 0)
117
+ if pad_h > 0 or pad_w > 0:
118
+ x = F.pad(x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2])
119
+ return F.conv2d(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
120
+
121
+
122
+ class Conv2dStaticSamePadding(nn.Conv2d):
123
+ """ 2D Convolutions like TensorFlow, for a fixed image size"""
124
+
125
+ def __init__(self, in_channels, out_channels, kernel_size, image_size=None, **kwargs):
126
+ super().__init__(in_channels, out_channels, kernel_size, **kwargs)
127
+ self.stride = self.stride if len(self.stride) == 2 else [self.stride[0]] * 2
128
+
129
+ # Calculate padding based on image size and save it
130
+ assert image_size is not None
131
+ ih, iw = image_size if type(image_size) == list else [image_size, image_size]
132
+ kh, kw = self.weight.size()[-2:]
133
+ sh, sw = self.stride
134
+ oh, ow = math.ceil(ih / sh), math.ceil(iw / sw)
135
+ pad_h = max((oh - 1) * self.stride[0] + (kh - 1) * self.dilation[0] + 1 - ih, 0)
136
+ pad_w = max((ow - 1) * self.stride[1] + (kw - 1) * self.dilation[1] + 1 - iw, 0)
137
+ if pad_h > 0 or pad_w > 0:
138
+ self.static_padding = nn.ZeroPad2d((pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2))
139
+ else:
140
+ self.static_padding = Identity()
141
+
142
+ def forward(self, x):
143
+ x = self.static_padding(x)
144
+ x = F.conv2d(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
145
+ return x
146
+
147
+
148
+ class Identity(nn.Module):
149
+ def __init__(self, ):
150
+ super(Identity, self).__init__()
151
+
152
+ def forward(self, input):
153
+ return input
154
+
155
+
156
+ ########################################################################
157
+ ############## HELPERS FUNCTIONS FOR LOADING MODEL PARAMS ##############
158
+ ########################################################################
159
+
160
+
161
+ def efficientnet_params(model_name):
162
+ """ Map EfficientNet model name to parameter coefficients. """
163
+ params_dict = {
164
+ # Coefficients: width,depth,res,dropout
165
+ 'efficientnet-b0': (1.0, 1.0, 224, 0.2),
166
+ 'efficientnet-b1': (1.0, 1.1, 240, 0.2),
167
+ 'efficientnet-b2': (1.1, 1.2, 260, 0.3),
168
+ 'efficientnet-b3': (1.2, 1.4, 300, 0.3),
169
+ 'efficientnet-b4': (1.4, 1.8, 380, 0.4),
170
+ 'efficientnet-b5': (1.6, 2.2, 456, 0.4),
171
+ 'efficientnet-b6': (1.8, 2.6, 528, 0.5),
172
+ 'efficientnet-b7': (2.0, 3.1, 600, 0.5),
173
+ 'efficientnet-b8': (2.2, 3.6, 672, 0.5),
174
+ 'efficientnet-l2': (4.3, 5.3, 800, 0.5),
175
+ }
176
+ return params_dict[model_name]
177
+
178
+
179
+ class BlockDecoder(object):
180
+ """ Block Decoder for readability, straight from the official TensorFlow repository """
181
+
182
+ @staticmethod
183
+ def _decode_block_string(block_string):
184
+ """ Gets a block through a string notation of arguments. """
185
+ assert isinstance(block_string, str)
186
+
187
+ ops = block_string.split('_')
188
+ options = {}
189
+ for op in ops:
190
+ splits = re.split(r'(\d.*)', op)
191
+ if len(splits) >= 2:
192
+ key, value = splits[:2]
193
+ options[key] = value
194
+
195
+ # Check stride
196
+ assert (('s' in options and len(options['s']) == 1) or
197
+ (len(options['s']) == 2 and options['s'][0] == options['s'][1]))
198
+
199
+ return BlockArgs(
200
+ kernel_size=int(options['k']),
201
+ num_repeat=int(options['r']),
202
+ input_filters=int(options['i']),
203
+ output_filters=int(options['o']),
204
+ expand_ratio=int(options['e']),
205
+ id_skip=('noskip' not in block_string),
206
+ se_ratio=float(options['se']) if 'se' in options else None,
207
+ stride=[int(options['s'][0])])
208
+
209
+ @staticmethod
210
+ def _encode_block_string(block):
211
+ """Encodes a block to a string."""
212
+ args = [
213
+ 'r%d' % block.num_repeat,
214
+ 'k%d' % block.kernel_size,
215
+ 's%d%d' % (block.strides[0], block.strides[1]),
216
+ 'e%s' % block.expand_ratio,
217
+ 'i%d' % block.input_filters,
218
+ 'o%d' % block.output_filters
219
+ ]
220
+ if 0 < block.se_ratio <= 1:
221
+ args.append('se%s' % block.se_ratio)
222
+ if block.id_skip is False:
223
+ args.append('noskip')
224
+ return '_'.join(args)
225
+
226
+ @staticmethod
227
+ def decode(string_list):
228
+ """
229
+ Decodes a list of string notations to specify blocks inside the network.
230
+
231
+ :param string_list: a list of strings, each string is a notation of block
232
+ :return: a list of BlockArgs namedtuples of block args
233
+ """
234
+ assert isinstance(string_list, list)
235
+ blocks_args = []
236
+ for block_string in string_list:
237
+ blocks_args.append(BlockDecoder._decode_block_string(block_string))
238
+ return blocks_args
239
+
240
+ @staticmethod
241
+ def encode(blocks_args):
242
+ """
243
+ Encodes a list of BlockArgs to a list of strings.
244
+
245
+ :param blocks_args: a list of BlockArgs namedtuples of block args
246
+ :return: a list of strings, each string is a notation of block
247
+ """
248
+ block_strings = []
249
+ for block in blocks_args:
250
+ block_strings.append(BlockDecoder._encode_block_string(block))
251
+ return block_strings
252
+
253
+
254
+ def efficientnet(width_coefficient=None, depth_coefficient=None, dropout_rate=0.2,
255
+ drop_connect_rate=0.2, image_size=None, num_classes=1000):
256
+ """ Creates a efficientnet model. """
257
+
258
+ blocks_args = [
259
+ 'r1_k3_s11_e1_i32_o16_se0.25', 'r2_k3_s22_e6_i16_o24_se0.25',
260
+ 'r2_k5_s22_e6_i24_o40_se0.25', 'r3_k3_s22_e6_i40_o80_se0.25',
261
+ 'r3_k5_s11_e6_i80_o112_se0.25', 'r4_k5_s22_e6_i112_o192_se0.25',
262
+ 'r1_k3_s11_e6_i192_o320_se0.25',
263
+ ]
264
+ blocks_args = BlockDecoder.decode(blocks_args)
265
+
266
+ global_params = GlobalParams(
267
+ batch_norm_momentum=0.99,
268
+ batch_norm_epsilon=1e-3,
269
+ dropout_rate=dropout_rate,
270
+ drop_connect_rate=drop_connect_rate,
271
+ # data_format='channels_last', # removed, this is always true in PyTorch
272
+ num_classes=num_classes,
273
+ width_coefficient=width_coefficient,
274
+ depth_coefficient=depth_coefficient,
275
+ depth_divisor=8,
276
+ min_depth=None,
277
+ image_size=image_size,
278
+ )
279
+
280
+ return blocks_args, global_params
281
+
282
+
283
+ def get_model_params(model_name, override_params):
284
+ """ Get the block args and global params for a given model """
285
+ if model_name.startswith('efficientnet'):
286
+ w, d, s, p = efficientnet_params(model_name)
287
+ # note: all models have drop connect rate = 0.2
288
+ blocks_args, global_params = efficientnet(
289
+ width_coefficient=w, depth_coefficient=d, dropout_rate=p, image_size=s)
290
+ else:
291
+ raise NotImplementedError('model name is not pre-defined: %s' % model_name)
292
+ if override_params:
293
+ # ValueError will be raised here if override_params has fields not included in global_params.
294
+ global_params = global_params._replace(**override_params)
295
+ return blocks_args, global_params
296
+
297
+
298
+ url_map = {
299
+ 'efficientnet-b0': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b0-355c32eb.pth',
300
+ 'efficientnet-b1': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b1-f1951068.pth',
301
+ 'efficientnet-b2': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b2-8bb594d6.pth',
302
+ 'efficientnet-b3': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b3-5fb5a3c3.pth',
303
+ 'efficientnet-b4': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b4-6ed6700e.pth',
304
+ 'efficientnet-b5': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b5-b6417697.pth',
305
+ 'efficientnet-b6': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b6-c76e70fd.pth',
306
+ 'efficientnet-b7': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b7-dcc49843.pth',
307
+ }
308
+
309
+
310
+ url_map_advprop = {
311
+ 'efficientnet-b0': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b0-b64d5a18.pth',
312
+ 'efficientnet-b1': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b1-0f3ce85a.pth',
313
+ 'efficientnet-b2': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b2-6e9d97e5.pth',
314
+ 'efficientnet-b3': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b3-cdd7c0f4.pth',
315
+ 'efficientnet-b4': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b4-44fb3a87.pth',
316
+ 'efficientnet-b5': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b5-86493f6b.pth',
317
+ 'efficientnet-b6': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b6-ac80338e.pth',
318
+ 'efficientnet-b7': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b7-4652b6dd.pth',
319
+ 'efficientnet-b8': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b8-22a8fe65.pth',
320
+ }
321
+
322
+
323
+ def load_pretrained_weights(model, model_name, load_fc=True, advprop=False):
324
+ """ Loads pretrained weights, and downloads if loading for the first time. """
325
+ # AutoAugment or Advprop (different preprocessing)
326
+ url_map_ = url_map_advprop if advprop else url_map
327
+ state_dict = model_zoo.load_url(url_map_[model_name])
328
+ if load_fc:
329
+ model.load_state_dict(state_dict)
330
+ else:
331
+ state_dict.pop('_fc.weight')
332
+ state_dict.pop('_fc.bias')
333
+ res = model.load_state_dict(state_dict, strict=False)
334
+ assert set(res.missing_keys) == set(['_fc.weight', '_fc.bias']), 'issue loading pretrained weights'
335
+ print('Loaded pretrained weights for {}'.format(model_name))
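A minimal usage sketch of the helpers above (illustrative only, not part of this commit); the values in the comments follow from the coefficient table and the rounding rules defined in this file:

blocks_args, global_params = get_model_params('efficientnet-b4', override_params={'num_classes': 10})
print(blocks_args[0])
# BlockArgs(kernel_size=3, num_repeat=1, input_filters=32, output_filters=16,
#           expand_ratio=1, id_skip=True, stride=[1], se_ratio=0.25)
print(round_filters(32, global_params))   # 48: 32 * width_coefficient 1.4, rounded to a multiple of depth_divisor 8
print(round_repeats(2, global_params))    # 4: ceil(2 * depth_coefficient 1.8)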
deepgaze_pytorch/features/inception.py ADDED
@@ -0,0 +1,20 @@
1
+ from collections import OrderedDict
2
+
3
+ import numpy as np
4
+ import torch
5
+ import torch.nn as nn
6
+ import torchvision
7
+
8
+ from .normalizer import Normalizer
9
+
10
+
11
+
12
+
13
+ class RGBInceptionV3(nn.Sequential):
14
+ def __init__(self):
15
+ super(RGBInceptionV3, self).__init__()
16
+ self.inception = torch.hub.load('pytorch/vision:v0.6.0', 'inception_v3', pretrained=True)
17
+ self.normalizer = Normalizer()
18
+ super(RGBInceptionV3, self).__init__(self.normalizer, self.inception)
19
+
20
+
deepgaze_pytorch/features/mobilenet.py ADDED
@@ -0,0 +1,17 @@
1
+ from collections import OrderedDict
2
+
3
+ import numpy as np
4
+ import torch
5
+ import torch.nn as nn
6
+ import torchvision
7
+
8
+ from .normalizer import Normalizer
9
+
10
+
11
+
12
+ class RGBMobileNetV2(nn.Sequential):
13
+ def __init__(self):
14
+ super(RGBMobileNetV2, self).__init__()
15
+ self.mobilenet_v2 = torchvision.models.mobilenet_v2(pretrained=True)
16
+ self.normalizer = Normalizer()
17
+ super(RGBMobileNetV2, self).__init__(self.normalizer, self.mobilenet_v2)
deepgaze_pytorch/features/normalizer.py ADDED
@@ -0,0 +1,28 @@
1
+ from collections import OrderedDict
2
+
3
+ import numpy as np
4
+ import torch
5
+ import torch.nn as nn
6
+ import torchvision
7
+
8
+ class Normalizer(nn.Module):
9
+ def __init__(self):
10
+ super(Normalizer, self).__init__()
11
+ mean = np.array([0.485, 0.456, 0.406])
12
+ mean = mean[:, np.newaxis, np.newaxis]
13
+
14
+ std = np.array([0.229, 0.224, 0.225])
15
+ std = std[:, np.newaxis, np.newaxis]
16
+
17
+ # don't persist to keep old checkpoints working
18
+ self.register_buffer('mean', torch.tensor(mean), persistent=False)
19
+ self.register_buffer('std', torch.tensor(std), persistent=False)
20
+
21
+
22
+ def forward(self, tensor):
23
+ tensor = tensor / 255.0
24
+
25
+ tensor -= self.mean
26
+ tensor /= self.std
27
+
28
+ return tensor
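An illustrative sketch (not part of this commit) of what Normalizer expects: an RGB batch in the 0-255 range with shape (N, 3, H, W). The mean/std buffers are built from float64 numpy arrays, so a float64 input is used here to keep the in-place arithmetic dtype-consistent:

import torch

norm = Normalizer()
batch = torch.rand(1, 3, 224, 224, dtype=torch.float64) * 255   # fake RGB batch in [0, 255]
out = norm(batch)                                               # (x / 255 - ImageNet mean) / ImageNet std, per channel
print(out.shape)                                                # torch.Size([1, 3, 224, 224])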
deepgaze_pytorch/features/resnet.py ADDED
@@ -0,0 +1,44 @@
1
+ from collections import OrderedDict
2
+
3
+ import numpy as np
4
+ import torch
5
+ import torch.nn as nn
6
+ import torchvision
7
+
8
+ from .normalizer import Normalizer
9
+
10
+
11
+
12
+ class RGBResNet34(nn.Sequential):
13
+ def __init__(self):
14
+ super(RGBResNet34, self).__init__()
15
+ self.resnet = torchvision.models.resnet34(pretrained=True)
16
+ self.normalizer = Normalizer()
17
+ super(RGBResNet34, self).__init__(self.normalizer, self.resnet)
18
+
19
+
20
+ class RGBResNet50(nn.Sequential):
21
+ def __init__(self):
22
+ super(RGBResNet50, self).__init__()
23
+ self.resnet = torchvision.models.resnet50(pretrained=True)
24
+ self.normalizer = Normalizer()
25
+ super(RGBResNet50, self).__init__(self.normalizer, self.resnet)
26
+
27
+
28
+ class RGBResNet50_alt(nn.Sequential):
29
+ def __init__(self):
30
+ super(RGBResNet50_alt, self).__init__()
31
+ self.resnet = torchvision.models.resnet50(pretrained=True)
32
+ self.normalizer = Normalizer()
33
+ state_dict = torch.load("Resnet-AlternativePreTrain.pth")
34
+ self.resnet.load_state_dict(state_dict)
35
+ super(RGBResNet50_alt, self).__init__(self.normalizer, self.resnet)
36
+
37
+
38
+
39
+ class RGBResNet101(nn.Sequential):
40
+ def __init__(self):
41
+ super(RGBResNet101, self).__init__()
42
+ self.resnet = torchvision.models.resnet101(pretrained=True)
43
+ self.normalizer = Normalizer()
44
+ super(RGBResNet101, self).__init__(self.normalizer, self.resnet)
deepgaze_pytorch/features/resnext.py ADDED
@@ -0,0 +1,27 @@
1
+ from collections import OrderedDict
2
+
3
+ import numpy as np
4
+ import torch
5
+ import torch.nn as nn
6
+ import torchvision
7
+
8
+ from .normalizer import Normalizer
9
+
10
+
11
+
12
+ class RGBResNext50(nn.Sequential):
13
+ def __init__(self):
14
+ super(RGBResNext50, self).__init__()
15
+ self.resnext = torch.hub.load('pytorch/vision:v0.6.0', 'resnext50_32x4d', pretrained=True)
16
+ self.normalizer = Normalizer()
17
+ super(RGBResNext50, self).__init__(self.normalizer, self.resnext)
18
+
19
+
20
+ class RGBResNext101(nn.Sequential):
21
+ def __init__(self):
22
+ super(RGBResNext101, self).__init__()
23
+ self.resnext = torch.hub.load('pytorch/vision:v0.6.0', 'resnext101_32x8d', pretrained=True)
24
+ self.normalizer = Normalizer()
25
+ super(RGBResNext101, self).__init__(self.normalizer, self.resnext)
26
+
27
+
deepgaze_pytorch/features/shapenet.py ADDED
@@ -0,0 +1,89 @@
1
+ """
2
+ This code was adapted from: https://github.com/rgeirhos/texture-vs-shape
3
+ """
4
+ import os
5
+ import sys
6
+ from collections import OrderedDict
7
+ import torch
8
+ import torch.nn as nn
9
+ import torchvision
10
+ import torchvision.models
11
+ from torch.utils import model_zoo
12
+
13
+ from .normalizer import Normalizer
14
+
15
+
16
+ def load_model(model_name):
17
+
18
+ model_urls = {
19
+ 'resnet50_trained_on_SIN': 'https://bitbucket.org/robert_geirhos/texture-vs-shape-pretrained-models/raw/6f41d2e86fc60566f78de64ecff35cc61eb6436f/resnet50_train_60_epochs-c8e5653e.pth.tar',
20
+ 'resnet50_trained_on_SIN_and_IN': 'https://bitbucket.org/robert_geirhos/texture-vs-shape-pretrained-models/raw/60b770e128fffcbd8562a3ab3546c1a735432d03/resnet50_train_45_epochs_combined_IN_SF-2a0d100e.pth.tar',
21
+ 'resnet50_trained_on_SIN_and_IN_then_finetuned_on_IN': 'https://bitbucket.org/robert_geirhos/texture-vs-shape-pretrained-models/raw/60b770e128fffcbd8562a3ab3546c1a735432d03/resnet50_finetune_60_epochs_lr_decay_after_30_start_resnet50_train_45_epochs_combined_IN_SF-ca06340c.pth.tar',
22
+ 'vgg16_trained_on_SIN': 'https://bitbucket.org/robert_geirhos/texture-vs-shape-pretrained-models/raw/0008049cd10f74a944c6d5e90d4639927f8620ae/vgg16_train_60_epochs_lr0.01-6c6fcc9f.pth.tar',
23
+ 'alexnet_trained_on_SIN': 'https://bitbucket.org/robert_geirhos/texture-vs-shape-pretrained-models/raw/0008049cd10f74a944c6d5e90d4639927f8620ae/alexnet_train_60_epochs_lr0.001-b4aa5238.pth.tar',
24
+ }
25
+
26
+ if "resnet50" in model_name:
27
+ #print("Using the ResNet50 architecture.")
28
+ model = torchvision.models.resnet50(pretrained=False)
29
+ #model = torch.nn.DataParallel(model) # .cuda()
30
+ # fake DataParallel structure
31
+ model = torch.nn.Sequential(OrderedDict([('module', model)]))
32
+ checkpoint = model_zoo.load_url(model_urls[model_name], map_location=torch.device('cpu'))
33
+ elif "vgg16" in model_name:
34
+ #print("Using the VGG-16 architecture.")
35
+
36
+ # download model from URL manually and save to desired location
37
+ filepath = "./vgg16_train_60_epochs_lr0.01-6c6fcc9f.pth.tar"
38
+
39
+ assert os.path.exists(filepath), "Please download the VGG model yourself from the following link and save it locally: https://drive.google.com/drive/folders/1A0vUWyU6fTuc-xWgwQQeBvzbwi6geYQK (too large to be downloaded automatically like the other models)"
40
+
41
+ model = torchvision.models.vgg16(pretrained=False)
42
+ model.features = torch.nn.DataParallel(model.features)
43
+ model.cuda()
44
+ checkpoint = torch.load(filepath, map_location=torch.device('cpu'))
45
+
46
+
47
+ elif "alexnet" in model_name:
48
+ #print("Using the AlexNet architecture.")
49
+ model = torchvision.models.alexnet(pretrained=False)
50
+ model.features = torch.nn.DataParallel(model.features)
51
+ model.cuda()
52
+ checkpoint = model_zoo.load_url(model_urls[model_name], map_location=torch.device('cpu'))
53
+ else:
54
+ raise ValueError("unknown model architecture.")
55
+
56
+ model.load_state_dict(checkpoint["state_dict"])
57
+ return model
58
+
59
+ # --- DeepGaze Adaptation ----
60
+
61
+
62
+
63
+
64
+ class RGBShapeNetA(nn.Sequential):
65
+ def __init__(self):
66
+ super(RGBShapeNetA, self).__init__()
67
+ self.shapenet = load_model("resnet50_trained_on_SIN")
68
+ self.normalizer = Normalizer()
69
+ super(RGBShapeNetA, self).__init__(self.normalizer, self.shapenet)
70
+
71
+
72
+
73
+ class RGBShapeNetB(nn.Sequential):
74
+ def __init__(self):
75
+ super(RGBShapeNetB, self).__init__()
76
+ self.shapenet = load_model("resnet50_trained_on_SIN_and_IN")
77
+ self.normalizer = Normalizer()
78
+ super(RGBShapeNetB, self).__init__(self.normalizer, self.shapenet)
79
+
80
+
81
+ class RGBShapeNetC(nn.Sequential):
82
+ def __init__(self):
83
+ super(RGBShapeNetC, self).__init__()
84
+ self.shapenet = load_model("resnet50_trained_on_SIN_and_IN_then_finetuned_on_IN")
85
+ self.normalizer = Normalizer()
86
+ super(RGBShapeNetC, self).__init__(self.normalizer, self.shapenet)
87
+
88
+
89
+
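Illustrative sketch (not part of this commit): load_model returns the backbone wrapped in a Sequential under the key 'module', mirroring the DataParallel layout of the released checkpoints, so the network itself sits at .module:

net = load_model("resnet50_trained_on_SIN")   # downloads the Stylized-ImageNet ResNet-50 checkpoint
print(type(net.module).__name__)              # 'ResNet'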
deepgaze_pytorch/features/squeezenet.py ADDED
@@ -0,0 +1,17 @@
1
+ from collections import OrderedDict
2
+
3
+ import numpy as np
4
+ import torch
5
+ import torch.nn as nn
6
+ import torchvision
7
+
8
+ from .normalizer import Normalizer
9
+
10
+
11
+ class RGBSqueezeNet(nn.Sequential):
12
+ def __init__(self):
13
+ super(RGBSqueezeNet, self).__init__()
14
+ self.squeezenet = torch.hub.load('pytorch/vision:v0.6.0', 'squeezenet1_0', pretrained=True)
15
+ self.normalizer = Normalizer()
16
+ super(RGBSqueezeNet, self).__init__(self.normalizer, self.squeezenet)
17
+
deepgaze_pytorch/features/swav.py ADDED
@@ -0,0 +1,20 @@
1
+ from collections import OrderedDict
2
+
3
+ import numpy as np
4
+ import torch
5
+ import torch.nn as nn
6
+ import torchvision
7
+
8
+ from .normalizer import Normalizer
9
+
10
+
11
+
12
+
13
+ class RGBSwav(nn.Sequential):
14
+ def __init__(self):
15
+ super(RGBSwav, self).__init__()
16
+ self.swav = torch.hub.load('facebookresearch/swav', 'resnet50', pretrained=True)
17
+ self.normalizer = Normalizer()
18
+ super(RGBSwav, self).__init__(self.normalizer, self.swav)
19
+
20
+
deepgaze_pytorch/features/uninformative.py ADDED
@@ -0,0 +1,26 @@
1
+ from collections import OrderedDict
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+
6
+
7
+ class OnesLayer(nn.Module):
8
+ def __init__(self, size=None):
9
+ super().__init__()
10
+ self.size = size
11
+
12
+ def forward(self, tensor):
13
+ shape = list(tensor.shape)
14
+ shape[1] = 1 # return only one channel
15
+
16
+ if self.size is not None:
17
+ shape[2], shape[3] = self.size
18
+
19
+ return torch.ones(shape, dtype=torch.float32, device=tensor.device)
20
+
21
+
22
+ class UninformativeFeatures(torch.nn.Sequential):
23
+ def __init__(self):
24
+ super().__init__(OrderedDict([
25
+ ('ones', OnesLayer(size=(1, 1))),
26
+ ]))
deepgaze_pytorch/features/vgg.py ADDED
@@ -0,0 +1,86 @@
1
+ from collections import OrderedDict
2
+
3
+ import numpy as np
4
+ import torch
5
+ import torch.nn as nn
6
+ import torchvision
7
+
8
+
9
+ class VGGInputNormalization(torch.nn.Module):
10
+ def __init__(self, inplace=True):
11
+ super().__init__()
12
+
13
+ self.inplace = inplace
14
+
15
+ mean = np.array([0.485, 0.456, 0.406])
16
+ mean = mean[:, np.newaxis, np.newaxis]
17
+
18
+ std = np.array([0.229, 0.224, 0.225])
19
+ std = std[:, np.newaxis, np.newaxis]
20
+ self.register_buffer('mean', torch.tensor(mean))
21
+ self.register_buffer('std', torch.tensor(std))
22
+
23
+ def forward(self, tensor):
24
+ if self.inplace:
25
+ tensor /= 255.0
26
+ else:
27
+ tensor = tensor / 255.0
28
+
29
+ tensor -= self.mean
30
+ tensor /= self.std
31
+
32
+ return tensor
33
+
34
+
35
+ class VGG19BNNamedFeatures(torch.nn.Sequential):
36
+ def __init__(self):
37
+ names = []
38
+ for block in range(5):
39
+ block_size = 2 if block < 2 else 4
40
+ for layer in range(block_size):
41
+ names.append(f'conv{block+1}_{layer+1}')
42
+ names.append(f'bn{block+1}_{layer+1}')
43
+ names.append(f'relu{block+1}_{layer+1}')
44
+ names.append(f'pool{block+1}')
45
+
46
+ vgg = torchvision.models.vgg19_bn(pretrained=True)
47
+ vgg_features = vgg.features
48
+ vgg.classifier = torch.nn.Sequential()
49
+
50
+ assert len(names) == len(vgg_features)
51
+
52
+ named_features = OrderedDict({'normalize': VGGInputNormalization()})
53
+
54
+ for name, feature in zip(names, vgg_features):
55
+ if isinstance(feature, nn.MaxPool2d):
56
+ feature.ceil_mode = True
57
+ named_features[name] = feature
58
+
59
+ super().__init__(named_features)
60
+
61
+
62
+ class VGG19NamedFeatures(torch.nn.Sequential):
63
+ def __init__(self):
64
+ names = []
65
+ for block in range(5):
66
+ block_size = 2 if block < 2 else 4
67
+ for layer in range(block_size):
68
+ names.append(f'conv{block+1}_{layer+1}')
69
+ names.append(f'relu{block+1}_{layer+1}')
70
+ names.append(f'pool{block+1}')
71
+
72
+ vgg = torchvision.models.vgg19(pretrained=True)
73
+ vgg_features = vgg.features
74
+ vgg.classifier = torch.nn.Sequential()
75
+
76
+ assert len(names) == len(vgg_features)
77
+
78
+ named_features = OrderedDict({'normalize': VGGInputNormalization()})
79
+
80
+ for name, feature in zip(names, vgg_features):
81
+ if isinstance(feature, nn.MaxPool2d):
82
+ feature.ceil_mode = True
83
+
84
+ named_features[name] = feature
85
+
86
+ super().__init__(named_features)
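Illustrative sketch (not part of this commit): because each layer is registered under a readable name, intermediate activations can be grabbed with a plain forward hook; the .float() call only casts the float64 normalization buffers down to the convolution weights' dtype:

import torch

features = VGG19NamedFeatures().float()
activations = {}
features.relu5_1.register_forward_hook(lambda module, inp, out: activations.update(relu5_1=out))
with torch.no_grad():
    features(torch.rand(1, 3, 224, 224) * 255)   # 0-255 RGB input, normalized in place by the first layer
print(activations['relu5_1'].shape)              # torch.Size([1, 512, 14, 14])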
deepgaze_pytorch/features/vggnet.py ADDED
@@ -0,0 +1,24 @@
1
+ from collections import OrderedDict
2
+
3
+ import numpy as np
4
+ import torch
5
+ import torch.nn as nn
6
+ import torchvision
7
+
8
+ from .normalizer import Normalizer
9
+
10
+
11
+ class RGBvgg19(nn.Sequential):
12
+ def __init__(self):
13
+ super(RGBvgg19, self).__init__()
14
+ self.model = torch.hub.load('pytorch/vision:v0.6.0', 'vgg19', pretrained=True)
15
+ self.normalizer = Normalizer()
16
+ super(RGBvgg19, self).__init__(self.normalizer, self.model)
17
+
18
+
19
+ class RGBvgg11(nn.Sequential):
20
+ def __init__(self):
21
+ super(RGBvgg11, self).__init__()
22
+ self.model = torch.hub.load('pytorch/vision:v0.6.0', 'vgg11', pretrained=True)
23
+ self.normalizer = Normalizer()
24
+ super(RGBvgg11, self).__init__(self.normalizer, self.model)
deepgaze_pytorch/features/wsl.py ADDED
@@ -0,0 +1,27 @@
1
+ from collections import OrderedDict
2
+
3
+ import numpy as np
4
+ import torch
5
+ import torch.nn as nn
6
+ import torchvision
7
+
8
+ from .normalizer import Normalizer
9
+
10
+
11
+
12
+ class RGBResNext50(nn.Sequential):
13
+ def __init__(self):
14
+ super(RGBResNext50, self).__init__()
15
+ self.resnext = torch.hub.load('facebookresearch/WSL-Images', 'resnext50_32x16d_wsl')
16
+ self.normalizer = Normalizer()
17
+ super(RGBResNext50, self).__init__(self.normalizer, self.resnext)
18
+
19
+
20
+ class RGBResNext101(nn.Sequential):
21
+ def __init__(self):
22
+ super(RGBResNext101, self).__init__()
23
+ self.resnext = torch.hub.load('facebookresearch/WSL-Images', 'resnext101_32x16d_wsl')
24
+ self.normalizer = Normalizer()
25
+ super(RGBResNext101, self).__init__(self.normalizer, self.resnext)
26
+
27
+
deepgaze_pytorch/layers.py ADDED
@@ -0,0 +1,427 @@
1
+ # pylint: disable=missing-module-docstring,invalid-name
2
+ # pylint: disable=missing-docstring
3
+ # pylint: disable=line-too-long
4
+
5
+ import math
6
+
7
+ import numpy as np
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.nn.functional as F
11
+
12
+
13
+ class LayerNorm(nn.Module):
14
+ r"""Applies Layer Normalization over a mini-batch of inputs as described in
15
+ the paper `Layer Normalization`_ .
16
+
17
+ .. math::
18
+ y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta
19
+
20
+ The mean and standard-deviation are calculated separately over the last
21
+ certain number dimensions which have to be of the shape specified by
22
+ :attr:`normalized_shape`.
23
+ :math:`\gamma` and :math:`\beta` are learnable affine transform parameters of
24
+ :attr:`normalized_shape` if :attr:`elementwise_affine` is ``True``.
25
+
26
+ .. note::
27
+ Unlike Batch Normalization and Instance Normalization, which applies
28
+ scalar scale and bias for each entire channel/plane with the
29
+ :attr:`affine` option, Layer Normalization applies per-element scale and
30
+ bias with :attr:`elementwise_affine`.
31
+
32
+ This layer uses statistics computed from input data in both training and
33
+ evaluation modes.
34
+
35
+ Args:
36
+ normalized_shape (int or list or torch.Size): input shape from an expected input
37
+ of size
38
+
39
+ .. math::
40
+ [* \times \text{normalized\_shape}[0] \times \text{normalized\_shape}[1]
41
+ \times \ldots \times \text{normalized\_shape}[-1]]
42
+
43
+ If a single integer is used, it is treated as a singleton list, and this module will
44
+ normalize over the last dimension which is expected to be of that specific size.
45
+ eps: a value added to the denominator for numerical stability. Default: 1e-5
46
+ elementwise_affine: a boolean value that when set to ``True``, this module
47
+ has learnable per-element affine parameters initialized to ones (for weights)
48
+ and zeros (for biases). Default: ``True``.
49
+
50
+ Shape:
51
+ - Input: :math:`(N, *)`
52
+ - Output: :math:`(N, *)` (same shape as input)
53
+
54
+ Examples::
55
+
56
+ >>> input = torch.randn(20, 5, 10, 10)
57
+ >>> # With Learnable Parameters
58
+ >>> m = nn.LayerNorm(input.size()[1:])
59
+ >>> # Without Learnable Parameters
60
+ >>> m = nn.LayerNorm(input.size()[1:], elementwise_affine=False)
61
+ >>> # Normalize over last two dimensions
62
+ >>> m = nn.LayerNorm([10, 10])
63
+ >>> # Normalize over last dimension of size 10
64
+ >>> m = nn.LayerNorm(10)
65
+ >>> # Activating the module
66
+ >>> output = m(input)
67
+
68
+ .. _`Layer Normalization`: https://arxiv.org/abs/1607.06450
69
+ """
70
+ __constants__ = ['features', 'weight', 'bias', 'eps', 'center', 'scale']
71
+
72
+ def __init__(self, features, eps=1e-12, center=True, scale=True):
73
+ super(LayerNorm, self).__init__()
74
+ self.features = features
75
+ self.eps = eps
76
+ self.center = center
77
+ self.scale = scale
78
+
79
+ if self.scale:
80
+ self.weight = nn.Parameter(torch.Tensor(self.features))
81
+ else:
82
+ self.register_parameter('weight', None)
83
+
84
+ if self.center:
85
+ self.bias = nn.Parameter(torch.Tensor(self.features))
86
+ else:
87
+ self.register_parameter('bias', None)
88
+
89
+ self.reset_parameters()
90
+
91
+ def reset_parameters(self):
92
+ if self.scale:
93
+ nn.init.ones_(self.weight)
94
+
95
+ if self.center:
96
+ nn.init.zeros_(self.bias)
97
+
98
+ def adjust_parameter(self, tensor, parameter):
99
+ return torch.repeat_interleave(
100
+ torch.repeat_interleave(
101
+ parameter.view(-1, 1, 1),
102
+ repeats=tensor.shape[2],
103
+ dim=1),
104
+ repeats=tensor.shape[3],
105
+ dim=2
106
+ )
107
+
108
+ def forward(self, input):
109
+ normalized_shape = (self.features, input.shape[2], input.shape[3])
110
+ weight = self.adjust_parameter(input, self.weight)
111
+ bias = self.adjust_parameter(input, self.bias)
112
+ return F.layer_norm(
113
+ input, normalized_shape, weight, bias, self.eps)
114
+
115
+ def extra_repr(self):
116
+ return '{features}, eps={eps}, ' \
117
+ 'center={center}, scale={scale}'.format(**self.__dict__)
118
+
119
+
120
+ def gaussian_filter_1d(tensor, dim, sigma, truncate=4, kernel_size=None, padding_mode='replicate', padding_value=0.0):
121
+ sigma = torch.as_tensor(sigma, device=tensor.device, dtype=tensor.dtype)
122
+
123
+ if kernel_size is not None:
124
+ kernel_size = torch.as_tensor(kernel_size, device=tensor.device, dtype=torch.int64)
125
+ else:
126
+ kernel_size = torch.as_tensor(2 * torch.ceil(truncate * sigma) + 1, device=tensor.device, dtype=torch.int64)
127
+
128
+ kernel_size = kernel_size.detach()
129
+
130
+ kernel_size_int = kernel_size.detach().cpu().numpy()
131
+
132
+ mean = (torch.as_tensor(kernel_size, dtype=tensor.dtype) - 1) / 2
133
+
134
+ grid = torch.arange(kernel_size, device=tensor.device) - mean
135
+
136
+ kernel_shape = (1, 1, kernel_size)
137
+ grid = grid.view(kernel_shape)
138
+
139
+ grid = grid.detach()
140
+
141
+ source_shape = tensor.shape
142
+
143
+ tensor = torch.movedim(tensor, dim, len(source_shape)-1)
144
+ dim_last_shape = tensor.shape
145
+ assert tensor.shape[-1] == source_shape[dim]
146
+
147
+ # we need reshape instead of view for batches like B x C x H x W
148
+ tensor = tensor.reshape(-1, 1, source_shape[dim])
149
+
150
+ padding = (math.ceil((kernel_size_int - 1) / 2), math.ceil((kernel_size_int - 1) / 2))
151
+ tensor_ = F.pad(tensor, padding, padding_mode, padding_value)
152
+
153
+ # create gaussian kernel from grid using current sigma
154
+ kernel = torch.exp(-0.5 * (grid / sigma) ** 2)
155
+ kernel = kernel / kernel.sum()
156
+
157
+ # convolve input with gaussian kernel
158
+ tensor_ = F.conv1d(tensor_, kernel)
159
+ tensor_ = tensor_.view(dim_last_shape)
160
+ tensor_ = torch.movedim(tensor_, len(source_shape)-1, dim)
161
+
162
+ assert tensor_.shape == source_shape
163
+
164
+ return tensor_
165
+
166
+
167
+ class GaussianFilterNd(nn.Module):
168
+ """A differentiable gaussian filter"""
169
+
170
+ def __init__(self, dims, sigma, truncate=4, kernel_size=None, padding_mode='replicate', padding_value=0.0,
171
+ trainable=False):
172
+ """Creates a 1d gaussian filter
173
+
174
+ Args:
175
+ dims ([int]): the dimensions to which the gaussian filter is applied. Negative values won't work
176
+ sigma (float): standard deviation of the gaussian filter (blur size)
177
+ input_dims (int, optional): number of input dimensions ignoring batch and channel dimension,
178
+ i.e. use input_dims=2 for images (default: 2).
179
+ truncate (float, optional): truncate the filter at this many standard deviations (default: 4.0).
180
+ This has no effect if the `kernel_size` is explicitly set
181
+ kernel_size (int): size of the gaussian kernel convolved with the input
182
+ padding_mode (string, optional): Padding mode implemented by `torch.nn.functional.pad`.
183
+ padding_value (string, optional): Value used for constant padding.
184
+ """
185
+ # IDEA determine input_dims dynamically for every input
186
+ super(GaussianFilterNd, self).__init__()
187
+
188
+ self.dims = dims
189
+ self.sigma = nn.Parameter(torch.tensor(sigma, dtype=torch.float32), requires_grad=trainable) # default: no optimization
190
+ self.truncate = truncate
191
+ self.kernel_size = kernel_size
192
+
193
+ # setup padding
194
+ self.padding_mode = padding_mode
195
+ self.padding_value = padding_value
196
+
197
+ def forward(self, tensor):
198
+ """Applies the gaussian filter to the given tensor"""
199
+ for dim in self.dims:
200
+ tensor = gaussian_filter_1d(
201
+ tensor,
202
+ dim=dim,
203
+ sigma=self.sigma,
204
+ truncate=self.truncate,
205
+ kernel_size=self.kernel_size,
206
+ padding_mode=self.padding_mode,
207
+ padding_value=self.padding_value,
208
+ )
209
+
210
+ return tensor
211
+
212
+
213
+ class Conv2dMultiInput(nn.Module):
214
+ def __init__(self, in_channels, out_channels, kernel_size, bias=True):
215
+ super().__init__()
216
+ self.in_channels = in_channels
217
+ self.out_channels = out_channels
218
+
219
+ for k, _in_channels in enumerate(in_channels):
220
+ if _in_channels:
221
+ setattr(self, f'conv_part{k}', nn.Conv2d(_in_channels, out_channels, kernel_size, bias=bias))
222
+
223
+ def forward(self, tensors):
224
+ assert len(tensors) == len(self.in_channels)
225
+
226
+ out = None
227
+ for k, (count, tensor) in enumerate(zip(self.in_channels, tensors)):
228
+ if not count:
229
+ continue
230
+ _out = getattr(self, f'conv_part{k}')(tensor)
231
+
232
+ if out is None:
233
+ out = _out
234
+ else:
235
+ out += _out
236
+
237
+ return out
238
+
239
+ # def extra_repr(self):
240
+ # return f'{self.in_channels}'
241
+
242
+
243
+ class LayerNormMultiInput(nn.Module):
244
+ __constants__ = ['features', 'weight', 'bias', 'eps', 'center', 'scale']
245
+
246
+ def __init__(self, features, eps=1e-12, center=True, scale=True):
247
+ super().__init__()
248
+ self.features = features
249
+ self.eps = eps
250
+ self.center = center
251
+ self.scale = scale
252
+
253
+ for k, _features in enumerate(features):
254
+ if _features:
255
+ setattr(self, f'layernorm_part{k}', LayerNorm(_features, eps=eps, center=center, scale=scale))
256
+
257
+ def forward(self, tensors):
258
+ assert len(tensors) == len(self.features)
259
+
260
+ out = []
261
+ for k, (count, tensor) in enumerate(zip(self.features, tensors)):
262
+ if not count:
263
+ assert tensor is None
264
+ out.append(None)
265
+ continue
266
+ out.append(getattr(self, f'layernorm_part{k}')(tensor))
267
+
268
+ return out
269
+
270
+
271
+ class Bias(nn.Module):
272
+ def __init__(self, channels):
273
+ super().__init__()
274
+ self.channels = channels
275
+ self.bias = nn.Parameter(torch.zeros(channels))
276
+
277
+ def forward(self, tensor):
278
+ return tensor + self.bias[np.newaxis, :, np.newaxis, np.newaxis]
279
+
280
+ def extra_repr(self):
281
+ return f'channels={self.channels}'
282
+
283
+
284
+ class SelfAttention(nn.Module):
285
+ """ Self attention Layer
286
+
287
+ adapted from https://discuss.pytorch.org/t/attention-in-image-classification/80147/3
288
+ """
289
+
290
+ def __init__(self, in_channels, out_channels=None, key_channels=None, activation=None, skip_connection_with_convolution=False, return_attention=True):
291
+ super().__init__()
292
+ self.in_channels = in_channels
293
+ if out_channels is None:
294
+ out_channels = in_channels
295
+ self.out_channels = out_channels
296
+ if key_channels is None:
297
+ key_channels = in_channels // 8
298
+ self.key_channels = key_channels
299
+ self.activation = activation
300
+ self.skip_connection_with_convolution = skip_connection_with_convolution
301
+ if not self.skip_connection_with_convolution:
302
+ if self.out_channels != self.in_channels:
303
+ raise ValueError("out_channels has to be equal to in_channels with true skip connection!")
304
+ self.return_attention = return_attention
305
+
306
+ self.query_conv = nn.Conv2d(in_channels=in_channels, out_channels=key_channels, kernel_size=1)
307
+ self.key_conv = nn.Conv2d(in_channels=in_channels, out_channels=key_channels, kernel_size=1)
308
+ self.value_conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1)
309
+ self.gamma = nn.Parameter(torch.zeros(1))
310
+ if self.skip_connection_with_convolution:
311
+ self.skip_conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1)
312
+
313
+ self.softmax = nn.Softmax(dim=-1)
314
+
315
+ def forward(self, x):
316
+ """
317
+ inputs :
318
+ x : input feature maps( B X C X W X H)
319
+ returns :
320
+ out : self attention value + input feature
321
+ attention: B X N X N (N is Width*Height)
322
+ """
323
+ m_batchsize, C, width, height = x.size()
324
+ proj_query = self.query_conv(x).view(m_batchsize, -1, width * height).permute(0, 2, 1) # B X CX(N)
325
+ proj_key = self.key_conv(x).view(m_batchsize, -1, width * height) # B X C x (*W*H)
326
+ energy = torch.bmm(proj_query, proj_key) # transpose check
327
+ attention = self.softmax(energy) # BX (N) X (N)
328
+ proj_value = self.value_conv(x).view(m_batchsize, -1, width * height) # B X C X N
329
+
330
+ out = torch.bmm(proj_value, attention.permute(0, 2, 1))
331
+ out = out.view(m_batchsize, self.out_channels, width, height)
332
+
333
+ if self.skip_connection_with_convolution:
334
+ skip_connection = self.skip_conv(x)
335
+ else:
336
+ skip_connection = x
337
+ out = self.gamma * out + skip_connection
338
+
339
+ if self.activation is not None:
340
+ out = self.activation(out)
341
+
342
+ if self.return_attention:
343
+ return out, attention
344
+
345
+ return out
346
+
347
+
348
+ class MultiHeadSelfAttention(nn.Module):
349
+ """ Self attention Layer
350
+
351
+ adapted from https://discuss.pytorch.org/t/attention-in-image-classification/80147/3
352
+ """
353
+
354
+ def __init__(self, in_channels, heads, out_channels=None, key_channels=None, activation=None, skip_connection_with_convolution=False):
355
+ super().__init__()
356
+ self.heads = heads
357
+ self.heads = nn.ModuleList([SelfAttention(
358
+ in_channels=in_channels,
359
+ out_channels=out_channels,
360
+ key_channels=key_channels,
361
+ activation=activation,
362
+ skip_connection_with_convolution=skip_connection_with_convolution,
363
+ return_attention=False,
364
+ ) for _ in range(heads)])
365
+
366
+ def forward(self, tensor):
367
+ outs = [head(tensor) for head in self.heads]
368
+ out = torch.cat(outs, dim=1)
369
+ return out
370
+
371
+
372
+ class FlexibleScanpathHistoryEncoding(nn.Module):
373
+ """
374
+ a convolutional layer which works for different numbers of previous fixations.
375
+
376
+ Nonexistent fixations will deactivate the respective convolutions
377
+ the bias will be added per fixation (if the given fixation is present)
378
+ """
379
+ def __init__(self, in_fixations, channels_per_fixation, out_channels, kernel_size, bias=True,):
380
+ super().__init__()
381
+ self.in_fixations = in_fixations
382
+ self.channels_per_fixation = channels_per_fixation
383
+ self.out_channels = out_channels
384
+ self.kernel_size = kernel_size
385
+ self.bias = bias
386
+ self.convolutions = nn.ModuleList([
387
+ nn.Conv2d(
388
+ in_channels=self.channels_per_fixation,
389
+ out_channels=self.out_channels,
390
+ kernel_size=self.kernel_size,
391
+ bias=self.bias
392
+ ) for i in range(in_fixations)
393
+ ])
394
+
395
+ def forward(self, tensor):
396
+ results = None
397
+ valid_fixations = ~torch.isnan(
398
+ tensor[:, :self.in_fixations, 0, 0]
399
+ )
400
+ # print("valid fix", valid_fixations)
401
+
402
+ for fixation_index in range(self.in_fixations):
403
+ valid_indices = valid_fixations[:, fixation_index]
404
+ if not torch.any(valid_indices):
405
+ continue
406
+ this_input = tensor[
407
+ valid_indices,
408
+ fixation_index::self.in_fixations
409
+ ]
410
+ this_result = self.convolutions[fixation_index](
411
+ this_input
412
+ )
413
+ # TODO: This will break if all data points
414
+ # in the batch don't have a single fixation
415
+ # but that's not a case I intend to train
416
+ # anyway.
417
+ if results is None:
418
+ b, _, _, _ = tensor.shape
419
+ _, _, h, w = this_result.shape
420
+ results = torch.zeros(
421
+ (b, self.out_channels, h, w),
422
+ dtype=tensor.dtype,
423
+ device=tensor.device
424
+ )
425
+ results[valid_indices] += this_result
426
+
427
+ return results
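Illustrative sketch (not part of this commit): GaussianFilterNd blurs the selected dimensions with the separable 1d filter above and preserves the input shape, so it can sit at the end of a readout network:

import torch

blur = GaussianFilterNd(dims=[2, 3], sigma=5.0)   # blur height and width of an NCHW tensor
prediction = torch.rand(1, 1, 64, 64)
smoothed = blur(prediction)
print(smoothed.shape)                             # torch.Size([1, 1, 64, 64])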
deepgaze_pytorch/metrics.py ADDED
@@ -0,0 +1,69 @@
1
+ import numpy as np
2
+ from pysaliency.roc import general_roc
3
+ from pysaliency.numba_utils import auc_for_one_positive
4
+ import torch
5
+
6
+
7
+ def _general_auc(positives, negatives):
8
+ if len(positives) == 1:
9
+ return auc_for_one_positive(positives[0], negatives)
10
+ else:
11
+ return general_roc(positives, negatives)[0]
12
+
13
+
14
+ def log_likelihood(log_density, fixation_mask, weights=None):
15
+ #if weights is None:
16
+ # weights = torch.ones(log_density.shape[0])
17
+
18
+ weights = len(weights) * weights.view(-1, 1, 1) / weights.sum()
19
+
20
+ if isinstance(fixation_mask, torch.sparse.IntTensor):
21
+ dense_mask = fixation_mask.to_dense()
22
+ else:
23
+ dense_mask = fixation_mask
24
+ fixation_count = dense_mask.sum(dim=(-1, -2), keepdim=True)
25
+ ll = torch.mean(
26
+ weights * torch.sum(log_density * dense_mask, dim=(-1, -2), keepdim=True) / fixation_count
27
+ )
28
+ return (ll + np.log(log_density.shape[-1] * log_density.shape[-2])) / np.log(2)
29
+
30
+
31
+ def nss(log_density, fixation_mask, weights=None):
32
+ weights = len(weights) * weights.view(-1, 1, 1) / weights.sum()
33
+ if isinstance(fixation_mask, torch.sparse.IntTensor):
34
+ dense_mask = fixation_mask.to_dense()
35
+ else:
36
+ dense_mask = fixation_mask
37
+
38
+ fixation_count = dense_mask.sum(dim=(-1, -2), keepdim=True)
39
+
40
+ density = torch.exp(log_density)
41
+ mean, std = torch.std_mean(density, dim=(-1, -2), keepdim=True)
42
+ saliency_map = (density - mean) / std
43
+
44
+ nss = torch.mean(
45
+ weights * torch.sum(saliency_map * dense_mask, dim=(-1, -2), keepdim=True) / fixation_count
46
+ )
47
+ return nss
48
+
49
+
50
+ def auc(log_density, fixation_mask, weights=None):
51
+ weights = len(weights) * weights / weights.sum()
52
+
53
+ # TODO: This doesn't account for multiple fixations in the same location!
54
+ def image_auc(log_density, fixation_mask):
55
+ if isinstance(fixation_mask, torch.sparse.IntTensor):
56
+ dense_mask = fixation_mask.to_dense()
57
+ else:
58
+ dense_mask = fixation_mask
59
+
60
+ positives = torch.masked_select(log_density, dense_mask.type(torch.bool)).detach().cpu().numpy().astype(np.float64)
61
+ negatives = log_density.flatten().detach().cpu().numpy().astype(np.float64)
62
+
63
+ auc = _general_auc(positives, negatives)
64
+
65
+ return torch.tensor(auc)
66
+
67
+ return torch.mean(weights.cpu() * torch.tensor([
68
+ image_auc(log_density[i], fixation_mask[i]) for i in range(log_density.shape[0])
69
+ ]))
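Illustrative sketch (not part of this commit): the metrics take a per-image log-density, a fixation mask of the same spatial size and, as written above (the weights=None default is not handled), an explicit per-image weight tensor; a uniform log-density scores roughly 0 bit/fix in log_likelihood by construction:

import math
import torch

h, w = 32, 32
log_density = torch.full((1, h, w), -math.log(h * w))   # log of a uniform density over the image
fixation_mask = torch.zeros(1, h, w)
fixation_mask[0, 10, 12] = 1.0                          # a single fixation
weights = torch.ones(1)                                 # one weight per image in the batch
print(log_likelihood(log_density, fixation_mask, weights=weights))   # ~0 bit/fixation over the uniform baseline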