diff --git a/SaRa/__pycache__/pySaliencyMap.cpython-39.pyc b/SaRa/__pycache__/pySaliencyMap.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2d54a9550b57f44ab6789f4f01530630a7ea0bce Binary files /dev/null and b/SaRa/__pycache__/pySaliencyMap.cpython-39.pyc differ diff --git a/SaRa/__pycache__/pySaliencyMapDefs.cpython-39.pyc b/SaRa/__pycache__/pySaliencyMapDefs.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..28db17c5804e4b94c25ff24286da83eeb4a866a5 Binary files /dev/null and b/SaRa/__pycache__/pySaliencyMapDefs.cpython-39.pyc differ diff --git a/SaRa/__pycache__/saraRC1.cpython-39.pyc b/SaRa/__pycache__/saraRC1.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7c183934bcd991f004aa325f2ce5312fdf5c4645 Binary files /dev/null and b/SaRa/__pycache__/saraRC1.cpython-39.pyc differ diff --git a/SaRa/pySaliencyMap.py b/SaRa/pySaliencyMap.py new file mode 100644 index 0000000000000000000000000000000000000000..c7ed95c2d99331f4235af663dbdf16413e09f492 --- /dev/null +++ b/SaRa/pySaliencyMap.py @@ -0,0 +1,288 @@ +#------------------------------------------------------------------------------- +# Name: pySaliencyMap +# Purpose: Extracting a saliency map from a single still image +# +# Author: Akisato Kimura +# +# Created: April 24, 2014 +# Copyright: (c) Akisato Kimura 2014- +# Licence: All rights reserved +#------------------------------------------------------------------------------- + +import cv2 +import numpy as np +import SaRa.pySaliencyMapDefs as pySaliencyMapDefs +import time + +class pySaliencyMap: + # initialization + def __init__(self, width, height): + self.width = width + self.height = height + self.prev_frame = None + self.SM = None + self.GaborKernel0 = np.array(pySaliencyMapDefs.GaborKernel_0) + self.GaborKernel45 = np.array(pySaliencyMapDefs.GaborKernel_45) + self.GaborKernel90 = np.array(pySaliencyMapDefs.GaborKernel_90) + self.GaborKernel135 = np.array(pySaliencyMapDefs.GaborKernel_135) + + # extracting color channels + def SMExtractRGBI(self, inputImage): + # convert scale of array elements + src = np.float32(inputImage) * 1./255 + # split + (B, G, R) = cv2.split(src) + # extract an intensity image + I = cv2.cvtColor(src, cv2.COLOR_BGR2GRAY) + # return + return R, G, B, I + + # feature maps + ## constructing a Gaussian pyramid + def FMCreateGaussianPyr(self, src): + dst = list() + dst.append(src) + for i in range(1,9): + nowdst = cv2.pyrDown(dst[i-1]) + dst.append(nowdst) + return dst + ## taking center-surround differences + def FMCenterSurroundDiff(self, GaussianMaps): + dst = list() + for s in range(2,5): + now_size = GaussianMaps[s].shape + now_size = (now_size[1], now_size[0]) ## (width, height) + tmp = cv2.resize(GaussianMaps[s+3], now_size, interpolation=cv2.INTER_LINEAR) + nowdst = cv2.absdiff(GaussianMaps[s], tmp) + dst.append(nowdst) + tmp = cv2.resize(GaussianMaps[s+4], now_size, interpolation=cv2.INTER_LINEAR) + nowdst = cv2.absdiff(GaussianMaps[s], tmp) + dst.append(nowdst) + return dst + ## constructing a Gaussian pyramid + taking center-surround differences + def FMGaussianPyrCSD(self, src): + GaussianMaps = self.FMCreateGaussianPyr(src) + dst = self.FMCenterSurroundDiff(GaussianMaps) + return dst + ## intensity feature maps + def IFMGetFM(self, I): + return self.FMGaussianPyrCSD(I) + ## color feature maps + def CFMGetFM(self, R, G, B): + # max(R,G,B) + tmp1 = cv2.max(R, G) + RGBMax = cv2.max(B, tmp1) + RGBMax[RGBMax <= 0] = 0.0001 # prevent dividing by 0 + # 
min(R,G) + RGMin = cv2.min(R, G) + # RG = (R-G)/max(R,G,B) + RG = (R - G) / RGBMax + # BY = (B-min(R,G)/max(R,G,B) + BY = (B - RGMin) / RGBMax + # clamp nagative values to 0 + RG[RG < 0] = 0 + BY[BY < 0] = 0 + # obtain feature maps in the same way as intensity + RGFM = self.FMGaussianPyrCSD(RG) + BYFM = self.FMGaussianPyrCSD(BY) + # return + return RGFM, BYFM + ## orientation feature maps + def OFMGetFM(self, src): + # creating a Gaussian pyramid + GaussianI = self.FMCreateGaussianPyr(src) + # convoluting a Gabor filter with an intensity image to extract oriemtation features + GaborOutput0 = [ np.empty((1,1)), np.empty((1,1)) ] # dummy data: any kinds of np.array()s are OK + GaborOutput45 = [ np.empty((1,1)), np.empty((1,1)) ] + GaborOutput90 = [ np.empty((1,1)), np.empty((1,1)) ] + GaborOutput135 = [ np.empty((1,1)), np.empty((1,1)) ] + for j in range(2,9): + GaborOutput0.append( cv2.filter2D(GaussianI[j], cv2.CV_32F, self.GaborKernel0) ) + GaborOutput45.append( cv2.filter2D(GaussianI[j], cv2.CV_32F, self.GaborKernel45) ) + GaborOutput90.append( cv2.filter2D(GaussianI[j], cv2.CV_32F, self.GaborKernel90) ) + GaborOutput135.append( cv2.filter2D(GaussianI[j], cv2.CV_32F, self.GaborKernel135) ) + # calculating center-surround differences for every oriantation + CSD0 = self.FMCenterSurroundDiff(GaborOutput0) + CSD45 = self.FMCenterSurroundDiff(GaborOutput45) + CSD90 = self.FMCenterSurroundDiff(GaborOutput90) + CSD135 = self.FMCenterSurroundDiff(GaborOutput135) + # concatenate + dst = list(CSD0) + dst.extend(CSD45) + dst.extend(CSD90) + dst.extend(CSD135) + # return + return dst + ## motion feature maps + def MFMGetFM(self, src): + # convert scale + I8U = np.uint8(255 * src) + # cv2.waitKey(10) + # calculating optical flows + if self.prev_frame is not None: + farne_pyr_scale= pySaliencyMapDefs.farne_pyr_scale + farne_levels = pySaliencyMapDefs.farne_levels + farne_winsize = pySaliencyMapDefs.farne_winsize + farne_iterations = pySaliencyMapDefs.farne_iterations + farne_poly_n = pySaliencyMapDefs.farne_poly_n + farne_poly_sigma = pySaliencyMapDefs.farne_poly_sigma + farne_flags = pySaliencyMapDefs.farne_flags + flow = cv2.calcOpticalFlowFarneback(\ + prev = self.prev_frame, \ + next = I8U, \ + pyr_scale = farne_pyr_scale, \ + levels = farne_levels, \ + winsize = farne_winsize, \ + iterations = farne_iterations, \ + poly_n = farne_poly_n, \ + poly_sigma = farne_poly_sigma, \ + flags = farne_flags, \ + flow = None \ + ) + flowx = flow[...,0] + flowy = flow[...,1] + else: + flowx = np.zeros(I8U.shape) + flowy = np.zeros(I8U.shape) + # create Gaussian pyramids + dst_x = self.FMGaussianPyrCSD(flowx) + dst_y = self.FMGaussianPyrCSD(flowy) + # update the current frame + self.prev_frame = np.uint8(I8U) + # return + return dst_x, dst_y + + # conspicuity maps + ## standard range normalization + def SMRangeNormalize(self, src): + minn, maxx, dummy1, dummy2 = cv2.minMaxLoc(src) + if maxx!=minn: + dst = src/(maxx-minn) + minn/(minn-maxx) + else: + dst = src - minn + return dst + ## computing an average of local maxima + def SMAvgLocalMax(self, src): + # size + stepsize = pySaliencyMapDefs.default_step_local + width = src.shape[1] + height = src.shape[0] + # find local maxima + numlocal = 0 + lmaxmean = 0 + for y in range(0, height-stepsize, stepsize): + for x in range(0, width-stepsize, stepsize): + localimg = src[y:y+stepsize, x:x+stepsize] + lmin, lmax, dummy1, dummy2 = cv2.minMaxLoc(localimg) + lmaxmean += lmax + numlocal += 1 + # averaging over all the local regions (error checking for numlocal) + if 
numlocal==0: + return 0 + else: + return lmaxmean / numlocal + ## normalization specific for the saliency map model + def SMNormalization(self, src): + dst = self.SMRangeNormalize(src) + lmaxmean = self.SMAvgLocalMax(dst) + normcoeff = (1-lmaxmean)*(1-lmaxmean) + return dst * normcoeff + ## normalizing feature maps + def normalizeFeatureMaps(self, FM): + NFM = list() + for i in range(0,6): + normalizedImage = self.SMNormalization(FM[i]) + nownfm = cv2.resize(normalizedImage, (self.width, self.height), interpolation=cv2.INTER_LINEAR) + NFM.append(nownfm) + return NFM + ## intensity conspicuity map + def ICMGetCM(self, IFM): + NIFM = self.normalizeFeatureMaps(IFM) + ICM = sum(NIFM) + return ICM + ## color conspicuity map + def CCMGetCM(self, CFM_RG, CFM_BY): + # extracting a conspicuity map for every color opponent pair + CCM_RG = self.ICMGetCM(CFM_RG) + CCM_BY = self.ICMGetCM(CFM_BY) + # merge + CCM = CCM_RG + CCM_BY + # return + return CCM + ## orientation conspicuity map + def OCMGetCM(self, OFM): + OCM = np.zeros((self.height, self.width)) + for i in range (0,4): + # slicing + nowofm = OFM[i*6:(i+1)*6] # angle = i*45 + # extracting a conspicuity map for every angle + NOFM = self.ICMGetCM(nowofm) + # normalize + NOFM2 = self.SMNormalization(NOFM) + # accumulate + OCM += NOFM2 + return OCM + ## motion conspicuity map + def MCMGetCM(self, MFM_X, MFM_Y): + return self.CCMGetCM(MFM_X, MFM_Y) + + # core + def SMGetSM(self, src): + # definitions + size = src.shape + width = size[1] + height = size[0] + # check +# if(width != self.width or height != self.height): +# sys.exit("size mismatch") + # extracting individual color channels + R, G, B, I = self.SMExtractRGBI(src) + # extracting feature maps + IFM = self.IFMGetFM(I) + CFM_RG, CFM_BY = self.CFMGetFM(R, G, B) + OFM = self.OFMGetFM(I) + MFM_X, MFM_Y = self.MFMGetFM(I) + # extracting conspicuity maps + ICM = self.ICMGetCM(IFM) + CCM = self.CCMGetCM(CFM_RG, CFM_BY) + OCM = self.OCMGetCM(OFM) + MCM = self.MCMGetCM(MFM_X, MFM_Y) + # adding all the conspicuity maps to form a saliency map + wi = pySaliencyMapDefs.weight_intensity + wc = pySaliencyMapDefs.weight_color + wo = pySaliencyMapDefs.weight_orientation + wm = pySaliencyMapDefs.weight_motion + SMMat = wi*ICM + wc*CCM + wo*OCM + wm*MCM + # normalize + normalizedSM = self.SMRangeNormalize(SMMat) + normalizedSM2 = normalizedSM.astype(np.float32) + smoothedSM = cv2.bilateralFilter(normalizedSM2, 7, 3, 1.55) + self.SM = cv2.resize(smoothedSM, (width,height), interpolation=cv2.INTER_NEAREST) + # return + return self.SM + + def SMGetBinarizedSM(self, src): + # get a saliency map + if self.SM is None: + self.SM = self.SMGetSM(src) + # convert scale + SM_I8U = np.uint8(255 * self.SM) + # binarize + thresh, binarized_SM = cv2.threshold(SM_I8U, thresh=0, maxval=255, type=cv2.THRESH_BINARY+cv2.THRESH_OTSU) + return binarized_SM + + def SMGetSalientRegion(self, src): + # get a binarized saliency map + binarized_SM = self.SMGetBinarizedSM(src) + # GrabCut + img = src.copy() + mask = np.where((binarized_SM!=0), cv2.GC_PR_FGD, cv2.GC_PR_BGD).astype('uint8') + bgdmodel = np.zeros((1,65),np.float64) + fgdmodel = np.zeros((1,65),np.float64) + rect = (0,0,1,1) # dummy + iterCount = 1 + cv2.grabCut(img, mask=mask, rect=rect, bgdModel=bgdmodel, fgdModel=fgdmodel, iterCount=iterCount, mode=cv2.GC_INIT_WITH_MASK) + # post-processing + mask_out = np.where((mask==cv2.GC_FGD) + (mask==cv2.GC_PR_FGD), 255, 0).astype('uint8') + output = cv2.bitwise_and(img,img,mask=mask_out) + return output diff --git 
a/SaRa/pySaliencyMapDefs.py b/SaRa/pySaliencyMapDefs.py new file mode 100644 index 0000000000000000000000000000000000000000..10d27c7acf9e11e4c90fe74065d178f8f0e28cef --- /dev/null +++ b/SaRa/pySaliencyMapDefs.py @@ -0,0 +1,74 @@ +#------------------------------------------------------------------------------- +# Name: pySaliencyMapDefs +# Purpose: Definitions for class pySaliencyMap +# +# Author: Akisato Kimura +# +# Created: April 24, 2014 +# Copyright: (c) Akisato Kimura 2014- +# Licence: All rights reserved +#------------------------------------------------------------------------------- + +# parameters for computing optical flows using the Gunner Farneback's algorithm +farne_pyr_scale = 0.5 +farne_levels = 3 +farne_winsize = 15 +farne_iterations = 3 +farne_poly_n = 5 +farne_poly_sigma = 1.2 +farne_flags = 0 + +# parameters for detecting local maxima +default_step_local = 16 + +# feature weights +weight_intensity = 0.30 +weight_color = 0.30 +weight_orientation = 0.20 +weight_motion = 0.20 + +# coefficients of Gabor filters +GaborKernel_0 = [\ + [ 1.85212E-06, 1.28181E-05, -0.000350433, -0.000136537, 0.002010422, -0.000136537, -0.000350433, 1.28181E-05, 1.85212E-06 ],\ + [ 2.80209E-05, 0.000193926, -0.005301717, -0.002065674, 0.030415784, -0.002065674, -0.005301717, 0.000193926, 2.80209E-05 ],\ + [ 0.000195076, 0.001350077, -0.036909595, -0.014380852, 0.211749204, -0.014380852, -0.036909595, 0.001350077, 0.000195076 ],\ + [ 0.000624940, 0.004325061, -0.118242318, -0.046070008, 0.678352526, -0.046070008, -0.118242318, 0.004325061, 0.000624940 ],\ + [ 0.000921261, 0.006375831, -0.174308068, -0.067914552, 1.000000000, -0.067914552, -0.174308068, 0.006375831, 0.000921261 ],\ + [ 0.000624940, 0.004325061, -0.118242318, -0.046070008, 0.678352526, -0.046070008, -0.118242318, 0.004325061, 0.000624940 ],\ + [ 0.000195076, 0.001350077, -0.036909595, -0.014380852, 0.211749204, -0.014380852, -0.036909595, 0.001350077, 0.000195076 ],\ + [ 2.80209E-05, 0.000193926, -0.005301717, -0.002065674, 0.030415784, -0.002065674, -0.005301717, 0.000193926, 2.80209E-05 ],\ + [ 1.85212E-06, 1.28181E-05, -0.000350433, -0.000136537, 0.002010422, -0.000136537, -0.000350433, 1.28181E-05, 1.85212E-06 ]\ +] +GaborKernel_45 = [\ + [ 4.04180E-06, 2.25320E-05, -0.000279806, -0.001028923, 3.79931E-05, 0.000744712, 0.000132863, -9.04408E-06, -1.01551E-06 ],\ + [ 2.25320E-05, 0.000925120, 0.002373205, -0.013561362, -0.022947700, 0.000389916, 0.003516954, 0.000288732, -9.04408E-06 ],\ + [ -0.000279806, 0.002373205, 0.044837725, 0.052928748, -0.139178011, -0.108372072, 0.000847346, 0.003516954, 0.000132863 ],\ + [ -0.001028923, -0.013561362, 0.052928748, 0.460162150, 0.249959607, -0.302454279, -0.108372072, 0.000389916, 0.000744712 ],\ + [ 3.79931E-05, -0.022947700, -0.139178011, 0.249959607, 1.000000000, 0.249959607, -0.139178011, -0.022947700, 3.79931E-05 ],\ + [ 0.000744712, 0.003899160, -0.108372072, -0.302454279, 0.249959607, 0.460162150, 0.052928748, -0.013561362, -0.001028923 ],\ + [ 0.000132863, 0.003516954, 0.000847346, -0.108372072, -0.139178011, 0.052928748, 0.044837725, 0.002373205, -0.000279806 ],\ + [ -9.04408E-06, 0.000288732, 0.003516954, 0.000389916, -0.022947700, -0.013561362, 0.002373205, 0.000925120, 2.25320E-05 ],\ + [ -1.01551E-06, -9.04408E-06, 0.000132863, 0.000744712, 3.79931E-05, -0.001028923, -0.000279806, 2.25320E-05, 4.04180E-06 ]\ +] +GaborKernel_90 = [\ + [ 1.85212E-06, 2.80209E-05, 0.000195076, 0.000624940, 0.000921261, 0.000624940, 0.000195076, 2.80209E-05, 1.85212E-06 ],\ + [ 1.28181E-05, 
0.000193926, 0.001350077, 0.004325061, 0.006375831, 0.004325061, 0.001350077, 0.000193926, 1.28181E-05 ],\ + [ -0.000350433, -0.005301717, -0.036909595, -0.118242318, -0.174308068, -0.118242318, -0.036909595, -0.005301717, -0.000350433 ],\ + [ -0.000136537, -0.002065674, -0.014380852, -0.046070008, -0.067914552, -0.046070008, -0.014380852, -0.002065674, -0.000136537 ],\ + [ 0.002010422, 0.030415784, 0.211749204, 0.678352526, 1.000000000, 0.678352526, 0.211749204, 0.030415784, 0.002010422 ],\ + [ -0.000136537, -0.002065674, -0.014380852, -0.046070008, -0.067914552, -0.046070008, -0.014380852, -0.002065674, -0.000136537 ],\ + [ -0.000350433, -0.005301717, -0.036909595, -0.118242318, -0.174308068, -0.118242318, -0.036909595, -0.005301717, -0.000350433 ],\ + [ 1.28181E-05, 0.000193926, 0.001350077, 0.004325061, 0.006375831, 0.004325061, 0.001350077, 0.000193926, 1.28181E-05 ],\ + [ 1.85212E-06, 2.80209E-05, 0.000195076, 0.000624940, 0.000921261, 0.000624940, 0.000195076, 2.80209E-05, 1.85212E-06 ] +] +GaborKernel_135 = [\ + [ -1.01551E-06, -9.04408E-06, 0.000132863, 0.000744712, 3.79931E-05, -0.001028923, -0.000279806, 2.2532E-05, 4.0418E-06 ],\ + [ -9.04408E-06, 0.000288732, 0.003516954, 0.000389916, -0.022947700, -0.013561362, 0.002373205, 0.00092512, 2.2532E-05 ],\ + [ 0.000132863, 0.003516954, 0.000847346, -0.108372072, -0.139178011, 0.052928748, 0.044837725, 0.002373205, -0.000279806 ],\ + [ 0.000744712, 0.000389916, -0.108372072, -0.302454279, 0.249959607, 0.46016215, 0.052928748, -0.013561362, -0.001028923 ],\ + [ 3.79931E-05, -0.022947700, -0.139178011, 0.249959607, 1.000000000, 0.249959607, -0.139178011, -0.0229477, 3.79931E-05 ],\ + [ -0.001028923, -0.013561362, 0.052928748, 0.460162150, 0.249959607, -0.302454279, -0.108372072, 0.000389916, 0.000744712 ],\ + [ -0.000279806, 0.002373205, 0.044837725, 0.052928748, -0.139178011, -0.108372072, 0.000847346, 0.003516954, 0.000132863 ],\ + [ 2.25320E-05, 0.000925120, 0.002373205, -0.013561362, -0.022947700, 0.000389916, 0.003516954, 0.000288732, -9.04408E-06 ],\ + [ 4.04180E-06, 2.25320E-05, -0.000279806, -0.001028923, 3.79931E-05 , 0.000744712, 0.000132863, -9.04408E-06, -1.01551E-06 ]\ +] diff --git a/SaRa/saraRC1.py b/SaRa/saraRC1.py new file mode 100644 index 0000000000000000000000000000000000000000..8576725bfe8f9da61cd334a3d2e00a6acbd24c1d --- /dev/null +++ b/SaRa/saraRC1.py @@ -0,0 +1,1082 @@ +import cv2 +import numpy as np +import math +import scipy.stats as st +from mpl_toolkits.mplot3d import Axes3D +from matplotlib.lines import Line2D +import matplotlib.pyplot as plt +import operator +import time +import os +from enum import Enum +import pandas as pd + +# Akisato Kimura implementation of Itti's Saliency Map Generator -- https://github.com/akisatok/pySaliencyMap +from SaRa.pySaliencyMap import pySaliencyMap + + +# Global Variables + +# Entropy, sum, depth, centre-bias +WEIGHTS = (1, 1, 1, 1) + +# segments_entropies = [] +segments_scores = [] +segments_coords = [] + +seg_dim = 0 +segments = [] +gt_segments = [] +dws = [] +sara_list = [] + +eval_list = [] +labels_eval_list = ['Image', 'Index', 'Rank', 'Quartile', 'isGT', 'Outcome'] + +outcome_list = [] +labels_outcome_list = ['Image', 'FN', 'FP', 'TN', 'TP'] + +dataframe_collection = {} +error_count = 0 + + +# SaRa Initial Functions +def generate_segments(img, seg_count) -> list: + ''' + Given an image img and the desired number of segments seg_count, this + function divides the image into segments and returns a list of segments. 
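+
+    Illustrative usage sketch (the 9x9 grid below is only an example value;
+    return_sara passes a saliency map and grid size to this function in the
+    same way):
+        tex_segments = generate_segments(saliency_map, 9)  # 81 tiles, row-major order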
+ ''' + + segments = [] + segment_count = seg_count + index = 0 + + w_interval = int(img.shape[1] / segment_count) + h_interval = int(img.shape[0] / segment_count) + + for i in range(segment_count): + for j in range(segment_count): + temp_segment = img[int(h_interval * i):int(h_interval * (i + 1)), + int(w_interval * j):int(w_interval * (j + 1))] + segments.append(temp_segment) + + coord_tup = (index, int(w_interval * j), int(h_interval * i), + int(w_interval * (j + 1)), int(h_interval * (i + 1))) + segments_coords.append(coord_tup) + + index += 1 + + return segments + + +def return_saliency(img, generator='itti', deepgaze_model=None, emlnet_models=None, DEVICE='cpu'): + ''' + Takes an image img as input and calculates the saliency map using the + Itti's Saliency Map Generator. It returns the saliency map. + ''' + + img_width, img_height = img.shape[1], img.shape[0] + + if generator == 'itti': + + sm = pySaliencyMap(img_width, img_height) + saliency_map = sm.SMGetSM(img) + + # Scale pixel values to 0-255 instead of float (approx 0, hence black image) + # https://stackoverflow.com/questions/48331211/how-to-use-cv2-imshow-correctly-for-the-float-image-returned-by-cv2-distancet/48333272 + saliency_map = cv2.normalize(saliency_map, None, 255, 0, cv2.NORM_MINMAX, cv2.CV_8UC1) + elif generator == 'deepgaze': + import numpy as np + from scipy.misc import face + from scipy.ndimage import zoom + from scipy.special import logsumexp + import torch + + import deepgaze_pytorch + + # you can use DeepGazeI or DeepGazeIIE + # model = deepgaze_pytorch.DeepGazeIIE(pretrained=True).to(DEVICE) + + if deepgaze_model is None: + model = deepgaze_pytorch.DeepGazeIIE(pretrained=True).to(DEVICE) + else: + model = deepgaze_model + + # image = face() + image = img + + # load precomputed centerbias log density (from MIT1003) over a 1024x1024 image + # you can download the centerbias from https://github.com/matthias-k/DeepGaze/releases/download/v1.0.0/centerbias_mit1003.npy + # alternatively, you can use a uniform centerbias via `centerbias_template = np.zeros((1024, 1024))`. 
+ # centerbias_template = np.load('centerbias_mit1003.npy') + centerbias_template = np.zeros((1024, 1024)) + # rescale to match image size + centerbias = zoom(centerbias_template, (image.shape[0]/centerbias_template.shape[0], image.shape[1]/centerbias_template.shape[1]), order=0, mode='nearest') + # renormalize log density + centerbias -= logsumexp(centerbias) + + image_tensor = torch.tensor([image.transpose(2, 0, 1)]).to(DEVICE) + centerbias_tensor = torch.tensor([centerbias]).to(DEVICE) + + log_density_prediction = model(image_tensor, centerbias_tensor) + + saliency_map = cv2.resize(log_density_prediction.detach().cpu().numpy()[0, 0], (img_width, img_height)) + + elif generator == 'fpn': + # Add ./fpn to the system path + import sys + sys.path.append('./fpn') + import inference as inf + + results_dict = {} + rt_args = inf.parse_arguments(img) + + # Call the run_inference function and capture the results + pred_masks_raw_list, pred_masks_round_list = inf.run_inference(rt_args) + + # Store the results in the dictionary + results_dict['pred_masks_raw'] = pred_masks_raw_list + results_dict['pred_masks_round'] = pred_masks_round_list + + saliency_map = results_dict['pred_masks_raw'] + + if img_width > img_height: + saliency_map = cv2.resize(saliency_map, (img_width, img_width)) + + diff = (img_width - img_height) // 2 + + saliency_map = saliency_map[diff:img_width - diff, 0:img_width] + else: + saliency_map = cv2.resize(saliency_map, (img_height, img_height)) + + diff = (img_height - img_width) // 2 + + saliency_map = saliency_map[0:img_height, diff:img_height - diff] + + elif generator == 'emlnet': + from emlnet.eval_combined import main as eval_combined + saliency_map = eval_combined(img, emlnet_models) + + # Resize to image size + saliency_map = cv2.resize(saliency_map, (img_width, img_height)) + + # Normalize saliency map + saliency_map = cv2.normalize(saliency_map, None, 255, 0, cv2.NORM_MINMAX, cv2.CV_8UC1) + + saliency_map = cv2.GaussianBlur(saliency_map, (31, 31), 10) + return saliency_map + saliency_map = saliency_map // 16 + + return saliency_map + + +def return_saliency_batch(images, generator='deepgaze', deepgaze_model=None, emlnet_models=None, DEVICE='cuda', BATCH_SIZE=1): + img_widths, img_heights = [], [] + if generator == 'deepgaze': + import numpy as np + from scipy.misc import face + from scipy.ndimage import zoom + from scipy.special import logsumexp + import torch + + import deepgaze_pytorch + + # you can use DeepGazeI or DeepGazeIIE + # model = deepgaze_pytorch.DeepGazeIIE(pretrained=True).to(DEVICE) + + if deepgaze_model is None: + model = deepgaze_pytorch.DeepGazeIIE(pretrained=True).to(DEVICE) + else: + model = deepgaze_model + + # image = face() + # image = img + image_batch = torch.tensor([img.transpose(2, 0, 1) for img in images]).to(DEVICE) + centerbias_template = np.zeros((1024, 1024)) + centerbias_tensors = [] + + for img in images: + centerbias = zoom(centerbias_template, (img.shape[0] / centerbias_template.shape[0], img.shape[1] / centerbias_template.shape[1]), order=0, mode='nearest') + centerbias -= logsumexp(centerbias) + centerbias_tensors.append(torch.tensor(centerbias).to(DEVICE)) + + # Set img_width and img_height + img_widths.append(img.shape[1]) + + + # rescale to match image size + # centerbias = zoom(centerbias_template, (image.shape[0]/centerbias_template.shape[0], image.shape[1]/centerbias_template.shape[1]), order=0, mode='nearest') + # # renormalize log density + # centerbias -= logsumexp(centerbias) + + # image_tensor = 
torch.tensor([image.transpose(2, 0, 1)]).to(DEVICE) + # centerbias_tensor = torch.tensor([centerbias]).to(DEVICE) + with torch.no_grad(): + # Process the batch of images in one forward pass + log_density_predictions = model(image_batch, torch.stack(centerbias_tensors)) + + # log_density_prediction = model(image_tensor, centerbias_tensor) + + # saliency_map = cv2.resize(log_density_prediction.detach().cpu().numpy()[0, 0], (img_width, img_height)) + + saliency_maps = [] + + for i in range(len(images)): + saliency_map = cv2.resize(log_density_predictions[i, 0].cpu().numpy(), (img_widths[i], img_widths[i])) + + saliency_map = cv2.normalize(saliency_map, None, 255, 0, cv2.NORM_MINMAX, cv2.CV_8UC1) + + saliency_map = cv2.GaussianBlur(saliency_map, (31, 31), 10) + saliency_map = saliency_map // 16 + + saliency_maps.append(saliency_map) + + return saliency_maps + + +# def return_itti_saliency(img): +# ''' +# Takes an image img as input and calculates the saliency map using the +# Itti's Saliency Map Generator. It returns the saliency map. +# ''' + +# img_width, img_height = img.shape[1], img.shape[0] + +# sm = pySaliencyMap.pySaliencyMap(img_width, img_height) +# saliency_map = sm.SMGetSM(img) + +# # Scale pixel values to 0-255 instead of float (approx 0, hence black image) +# # https://stackoverflow.com/questions/48331211/how-to-use-cv2-imshow-correctly-for-the-float-image-returned-by-cv2-distancet/48333272 +# saliency_map = cv2.normalize(saliency_map, None, 255, 0, cv2.NORM_MINMAX, cv2.CV_8UC1) + +# return saliency_map + + +# Saliency Ranking +def calculate_pixel_frequency(img) -> dict: + ''' + Calculates the frequency of each pixel value in the image img and + returns a dictionary containing the pixel frequencies. + ''' + + flt = img.flatten() + unique, counts = np.unique(flt, return_counts=True) + pixels_frequency = dict(zip(unique, counts)) + + return pixels_frequency + + +def calculate_score(H, sum, ds, cb, w): + ''' + Calculates the saliency score of an image img using the entropy H, depth score ds, centre-bias cb and weights w. It returns the saliency score. + ''' + + # Normalise H + # H = (H - 0) / (math.log(2, 256) - 0) + + # H = wth root of H + H = H ** w[0] + + if sum > 0: + sum = np.log(sum) + sum = sum ** w[1] + + ds = ds ** w[2] + + cb = (cb + 1) ** w[3] + + return H + sum + ds + cb + + +def calculate_entropy(img, w, dw) -> float: + ''' + Calculates the entropy of an image img using the given weights w and + depth weights dw. It returns the entropy value. + ''' + + flt = img.flatten() + + # c = flt.shape[0] + total_pixels = 0 + t_prob = 0 + # sum_of_probs = 0 + entropy = 0 + wt = w * 10 + + # if imgD=None then proceed normally + # else calculate its frequency and find max + # use this max value as a weight in entropy + + pixels_frequency = calculate_pixel_frequency(flt) + + total_pixels = sum(pixels_frequency.values()) + + for px in pixels_frequency: + t_prob = pixels_frequency[px] / total_pixels + + if t_prob != 0: + entropy += (t_prob * math.log((1 / t_prob), 2)) + + # entropy = entropy * wt * dw + + return entropy + + +def find_most_salient_segment(segments, kernel, dws): + ''' + Finds the most salient segment among the provided segments using a + given kernel and depth weights. It returns the maximum entropy value + and the index of the most salient segment. 
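+
+    Illustrative usage sketch mirroring how generate_sara calls this function
+    (the function names come from this module; the local variable names are
+    placeholders):
+        kernel_1d = make_gaussian(seg_dim).ravel()
+        depth_weights = gen_blank_depth_weight(tex_segments)
+        max_score, index = find_most_salient_segment(tex_segments, kernel_1d, depth_weights)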
+ ''' + + # max_entropy = 0 + max_score = 0 + index = 0 + i = 0 + + for segment in segments: + temp_entropy = calculate_entropy(segment, kernel[i], dws[i]) + # Normalise semgnet bweetn 0 and 255 + segment = cv2.normalize(segment, None, 255, 0, cv2.NORM_MINMAX, cv2.CV_8UC1) + temp_sum = np.sum(segment) + # temp_tup = (i, temp_entropy) + # segments_entropies.append(temp_tup) + + w = WEIGHTS + + temp_score = calculate_score(temp_entropy, temp_sum, dws[i], kernel[i], w) + + temp_tup = (i, temp_score, temp_entropy ** w[0], temp_sum ** w[1], (kernel[i] + 1) ** w[2], dws[i] ** w[3]) + + # segments_scores.append((i, temp_score)) + segments_scores.append(temp_tup) + + # if temp_entropy > max_entropy: + # max_entropy = temp_entropy + # index = i + + if temp_score > max_score: + max_score = temp_score + index = i + + i += 1 + + # return max_entropy, index + return max_score, index + + +def make_gaussian(size, fwhm=10, center=None): + ''' + Generates a 2D Gaussian kernel with the specified size and full-width-half-maximum (fwhm). It returns the Gaussian kernel. + + size: length of a side of the square + fwhm: full-width-half-maximum, which can be thought of as an effective + radius. + + https://gist.github.com/andrewgiessel/4635563 + ''' + + x = np.arange(0, size, 1, float) + y = x[:, np.newaxis] + + if center is None: + x0 = y0 = size // 2 + else: + x0 = center[0] + y0 = center[1] + + + return np.exp(-4 * np.log(2) * ((x - x0) ** 2 + (y - y0) ** 2) / fwhm ** 2) + + +def gen_depth_weights(d_segments, depth_map) -> list: + ''' + Generates depth weights for the segments based on the depth map. It + returns a list of depth weights. + ''' + + hist_d, _ = np.histogram(depth_map, 256, [0, 256]) + + # Get first non-zero index + first_nz = next((i for i, x in enumerate(hist_d) if x), None) + + # Get last non-zero index + rev = (len(hist_d) - idx for idx, item in enumerate(reversed(hist_d), 1) if item) + last_nz = next(rev, default=None) + + mid = (first_nz + last_nz) / 2 + + for seg in d_segments: + hist, _ = np.histogram(seg, 256, [0, 256]) + dw = 0 + ind = 0 + for s in hist: + if ind > mid: + dw = dw + (s * 1) + ind = ind + 1 + dws.append(dw) + + return dws + + +def gen_blank_depth_weight(d_segments): + ''' + Generates blank depth weights for the segments. It returns a list of + depth weights. + ''' + + for _ in d_segments: + dw = 1 + dws.append(dw) + return dws + + +# def generate_heatmap(img, mode, sorted_seg_scores, segments_coords) -> tuple: +# ''' +# Generates a heatmap overlay on the input image img based on the +# provided sorted segment scores. The mode parameter determines the color +# scheme of the heatmap. It returns the image with the heatmap overlay +# and a list of segment scores. 
+ +# mode: 0 for white grid, 1 for color-coded grid +# ''' + +# font = cv2.FONT_HERSHEY_SIMPLEX +# # print_index = 0 +# print_index = len(sorted_seg_scores) - 1 +# set_value = int(0.25 * len(sorted_seg_scores)) +# color = (0, 0, 0) + +# max_x = 0 +# max_y = 0 + +# overlay = np.zeros_like(img, dtype=np.uint8) +# text_overlay = np.zeros_like(img, dtype=np.uint8) + +# sara_list_out = [] + +# for ent in reversed(sorted_seg_scores): +# quartile = 0 +# if mode == 0: +# color = (255, 255, 255) +# t = 4 +# elif mode == 1: +# if print_index + 1 <= set_value: +# color = (0, 0, 255, 255) +# t = 2 +# quartile = 1 +# elif print_index + 1 <= set_value * 2: +# color = (0, 128, 255, 192) +# t = 4 +# quartile = 2 +# elif print_index + 1 <= set_value * 3: +# color = (0, 255, 255, 128) +# t = 4 +# t = 6 +# quartile = 3 +# # elif print_index + 1 <= set_value * 4: +# # color = (0, 250, 0, 64) +# # t = 8 +# # quartile = 4 +# else: +# color = (0, 250, 0, 64) +# t = 8 +# quartile = 4 + + +# x1 = segments_coords[ent[0]][1] +# y1 = segments_coords[ent[0]][2] +# x2 = segments_coords[ent[0]][3] +# y2 = segments_coords[ent[0]][4] + +# if x2 > max_x: +# max_x = x2 +# if y2 > max_y: +# max_y = y2 + +# x = int((x1 + x2) / 2) +# y = int((y1 + y2) / 2) + + + +# # fill rectangle +# cv2.rectangle(overlay, (x1, y1), (x2, y2), color, -1) + +# cv2.rectangle(overlay, (x1, y1), (x2, y2), (0, 0, 0), 1) +# # put text in the middle of the rectangle + +# # white text +# cv2.putText(text_overlay, str(print_index), (x - 5, y), +# font, .4, (255, 255, 255), 1, cv2.LINE_AA) + +# # Index, rank, score, entropy, entropy_sum, centre_bias, depth, quartile +# sara_tuple = (ent[0], print_index, ent[1], ent[2], ent[3], ent[4], ent[5], quartile) +# sara_list_out.append(sara_tuple) +# print_index -= 1 + +# # crop the overlay to up to x2 and y2 +# overlay = overlay[0:max_y, 0:max_x] +# text_overlay = text_overlay[0:max_y, 0:max_x] +# img = img[0:max_y, 0:max_x] + + +# img = cv2.addWeighted(overlay, 0.3, img, 0.7, 0, img) + +# img[text_overlay > 128] = text_overlay[text_overlay > 128] + + +# return img, sara_list_out +def generate_heatmap(img, sorted_seg_scores, segments_coords, mode=1) -> tuple: + ''' + Generates a more vibrant heatmap overlay on the input image img based on the + provided sorted segment scores. It returns the image with the heatmap overlay + and a list of segment scores with quartile information. 
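+
+    Illustrative call as issued by generate_sara (sorted_seg_scores and
+    segments_coords follow this module's formats):
+        heatmap_img, ranked = generate_heatmap(img, sorted_scores, segments_coords, mode=1)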
+ + mode: 0 for white grid, 1 for color-coded grid, 2 for heatmap to be used as a feature + ''' + alpha =0.3 + if mode == 2: + + font = cv2.FONT_HERSHEY_SIMPLEX + print_index = len(sorted_seg_scores) - 1 + set_value = int(0.25 * len(sorted_seg_scores)) + + max_x = 0 + max_y = 0 + + overlay = np.zeros_like(img, dtype=np.uint8) + text_overlay = np.zeros_like(img, dtype=np.uint8) + + sara_list_out = [] + + scores = [score[1] for score in sorted_seg_scores] + min_score = min(scores) + max_score = max(scores) + + # Choose a colormap from matplotlib + colormap = plt.get_cmap('jet') # 'jet', 'viridis', 'plasma', 'magma', 'cividis, jet_r, viridis_r, plasma_r, magma_r, cividis_r + + for ent in reversed(sorted_seg_scores): + score = ent[1] + normalized_score = (score - min_score) / (max_score - min_score) + color_weight = normalized_score * score # Weighted color based on the score + color = np.array(colormap(normalized_score)[:3]) * 255 #* color_weight + + x1 = segments_coords[ent[0]][1] + y1 = segments_coords[ent[0]][2] + x2 = segments_coords[ent[0]][3] + y2 = segments_coords[ent[0]][4] + + if x2 > max_x: + max_x = x2 + if y2 > max_y: + max_y = y2 + + x = int((x1 + x2) / 2) + y = int((y1 + y2) / 2) + + # fill rectangle + cv2.rectangle(overlay, (x1, y1), (x2, y2), color, -1) + # black border + # cv2.rectangle(overlay, (x1, y1), (x2, y2), (0, 0, 0), 1) + + # white text + # cv2.putText(text_overlay, str(print_index), (x - 5, y), + # font, .4, (255, 255, 255), 1, cv2.LINE_AA) + + # Determine quartile based on print_index + if print_index + 1 <= set_value: + quartile = 1 + elif print_index + 1 <= set_value * 2: + quartile = 2 + elif print_index + 1 <= set_value * 3: + quartile = 3 + else: + quartile = 4 + + sara_tuple = (ent[0], print_index, ent[1], ent[2], ent[3], ent[4], ent[5], quartile) + sara_list_out.append(sara_tuple) + print_index -= 1 + + overlay = overlay[0:max_y, 0:max_x] + text_overlay = text_overlay[0:max_y, 0:max_x] + img = img[0:max_y, 0:max_x] + + # Create a blank grayscale image with the same dimensions as the original image + gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) + + gray = cv2.merge([gray, gray, gray]) + + gray = cv2.addWeighted(overlay, alpha, gray, 1-alpha, 0, gray) + gray[text_overlay > 128] = text_overlay[text_overlay > 128] + + return gray, sara_list_out + else: + font = cv2.FONT_HERSHEY_SIMPLEX + # print_index = 0 + print_index = len(sorted_seg_scores) - 1 + set_value = int(0.25 * len(sorted_seg_scores)) + color = (0, 0, 0) + + max_x = 0 + max_y = 0 + + overlay = np.zeros_like(img, dtype=np.uint8) + text_overlay = np.zeros_like(img, dtype=np.uint8) + + sara_list_out = [] + + for ent in reversed(sorted_seg_scores): + quartile = 0 + if mode == 0: + color = (255, 255, 255) + t = 4 + elif mode == 1: + if print_index + 1 <= set_value: + color = (0, 0, 255, 255) + t = 2 + quartile = 1 + elif print_index + 1 <= set_value * 2: + color = (0, 128, 255, 192) + t = 4 + quartile = 2 + elif print_index + 1 <= set_value * 3: + color = (0, 255, 255, 128) + t = 4 + t = 6 + quartile = 3 + # elif print_index + 1 <= set_value * 4: + # color = (0, 250, 0, 64) + # t = 8 + # quartile = 4 + else: + color = (0, 250, 0, 64) + t = 8 + quartile = 4 + + + x1 = segments_coords[ent[0]][1] + y1 = segments_coords[ent[0]][2] + x2 = segments_coords[ent[0]][3] + y2 = segments_coords[ent[0]][4] + + if x2 > max_x: + max_x = x2 + if y2 > max_y: + max_y = y2 + + x = int((x1 + x2) / 2) + y = int((y1 + y2) / 2) + + + + # fill rectangle + cv2.rectangle(overlay, (x1, y1), (x2, y2), color, -1) + + 
cv2.rectangle(overlay, (x1, y1), (x2, y2), (0, 0, 0), 1) + # put text in the middle of the rectangle + + # white text + cv2.putText(text_overlay, str(print_index), (x - 5, y), + font, .4, (255, 255, 255), 1, cv2.LINE_AA) + + # Index, rank, score, entropy, entropy_sum, centre_bias, depth, quartile + sara_tuple = (ent[0], print_index, ent[1], ent[2], ent[3], ent[4], ent[5], quartile) + sara_list_out.append(sara_tuple) + print_index -= 1 + + # crop the overlay to up to x2 and y2 + overlay = overlay[0:max_y, 0:max_x] + text_overlay = text_overlay[0:max_y, 0:max_x] + img = img[0:max_y, 0:max_x] + + + img = cv2.addWeighted(overlay, 0.3, img, 0.7, 0, img) + + img[text_overlay > 128] = text_overlay[text_overlay > 128] + + + return img, sara_list_out + +def generate_sara(tex, tex_segments, mode=2): + ''' + Generates the SaRa (Salient Region Annotation) output by calculating + saliency scores for the segments of the given texture image tex. It + returns the texture image with the heatmap overlay and a list of + segment scores. + ''' + + gaussian_kernel_array = make_gaussian(seg_dim) + gaussian1d = gaussian_kernel_array.ravel() + + dws = gen_blank_depth_weight(tex_segments) + + max_h, index = find_most_salient_segment(tex_segments, gaussian1d, dws) + # dict_entropies = dict(segments_entropies) + # segments_scores list with 5 elements, use index as key for dict and store rest as list of index + dict_scores = {} + + for segment in segments_scores: + # Index: score, entropy, sum, depth, centre-bias + dict_scores[segment[0]] = [segment[1], segment[2], segment[3], segment[4], segment[5]] + + # sorted_entropies = sorted(dict_entropies.items(), + # key=operator.itemgetter(1), reverse=True) + + + # sorted_scores = sorted(dict_scores.items(), + # key=operator.itemgetter(1), reverse=True) + + # Sort by first value in value list + sorted_scores = sorted(dict_scores.items(), key=lambda x: x[1][0], reverse=True) + + # flatten + sorted_scores = [[i[0], i[1][0], i[1][1], i[1][2], i[1][3], i[1][4]] for i in sorted_scores] + + # tex_out, sara_list_out = generate_heatmap( + # tex, 1, sorted_entropies, segments_coords) + + tex_out, sara_list_out = generate_heatmap( + tex, sorted_scores, segments_coords, mode = mode) + + sara_list_out = list(reversed(sara_list_out)) + + return tex_out, sara_list_out + + +def return_sara(input_img, grid, generator='itti', saliency_map=None, mode = 2): + ''' + Computes the SaRa output for the given input image. It uses the + generate_sara function internally. It returns the SaRa output image and + a list of segment scores. + ''' + + global seg_dim + seg_dim = grid + + if saliency_map is None: + saliency_map = return_saliency(input_img, generator) + + tex_segments = generate_segments(saliency_map, seg_dim) + + # tex_segments = generate_segments(input_img, seg_dim) + sara_output, sara_list_output = generate_sara(input_img, tex_segments, mode=mode) + + return sara_output, sara_list_output + + +def mean_squared_error(image_a, image_b) -> float: + ''' + Calculates the Mean Squared Error (MSE), i.e. sum of squared + differences between two images image_a and image_b. It returns the MSE + value. + + NOTE: The two images must have the same dimension + ''' + + err = np.sum((image_a.astype('float') - image_b.astype('float')) ** 2) + err /= float(image_a.shape[0] * image_a.shape[1]) + + return err + + +def reset(): + ''' + Resets all global variables to their default values. 
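+
+    The module-level lists (segments_scores, segments_coords, dws, ...) persist
+    across calls, so call this before ranking a new image. Illustrative sketch,
+    assuming `import SaRa.saraRC1 as sara` as in app.py (the image file name is
+    hypothetical):
+        sara.reset()
+        heatmap, ranked_segments = sara.return_sara(cv2.imread('scene.jpg'), 9, generator='itti')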
+ ''' + + # global segments_entropies, segments_scores, segments_coords, seg_dim, segments, gt_segments, dws, sara_list + + global segments_scores, segments_coords, seg_dim, segments, gt_segments, dws, sara_list + + # segments_entropies = [] + segments_scores = [] + segments_coords = [] + + seg_dim = 0 + segments = [] + gt_segments = [] + dws = [] + sara_list = [] + + + +def resize_based_on_important_ranks(img, sara_info, grid_size, rate=0.3): + def generate_segments(image, seg_count) -> dict: + """ + Function to generate segments of an image + + Args: + image: input image + seg_count: number of segments to generate + + Returns: + segments: dictionary of segments + + """ + # Initializing segments dictionary + segments = {} + # Initializing segment index and segment count + segment_count = seg_count + index = 0 + + # Retrieving image width and height + h, w = image.shape[:2] + + # Calculating width and height intervals for segments from the segment count + w_interval = w // segment_count + h_interval = h // segment_count + + # Iterating through the image and generating segments + for i in range(segment_count): + for j in range(segment_count): + # Calculating segment coordinates + x1, y1 = j * w_interval, i * h_interval + x2, y2 = x1 + w_interval, y1 + h_interval + + # Adding segment coordinates to segments dictionary + segments[index] = (x1, y1, x2, y2) + + # Incrementing segment index + index += 1 + + # Returning segments dictionary + return segments + + # Retrieving important ranks from SaRa + sara_dict = { + info[0]: { + 'score': info[2], + 'index': info[1] + } + for info in sara_info[1] + } + + # Sorting important ranks by score + sorted_sara_dict = sorted(sara_dict.items(), key=lambda item: item[1]['score'], reverse=True) + + # Generating segments + index_info = generate_segments(img, grid_size) + + # Initializing most important ranks image + most_imp_ranks = np.zeros_like(img) + + # Calculating maximum rank + max_rank = int(grid_size * grid_size * rate) + count = 0 + + # Iterating through important ranks and adding them to most important ranks image + for rank, info in sorted_sara_dict: + # Checking if rank is within maximum rank + if count <= max_rank: + # Retrieving segment coordinates + coords = index_info[rank] + + # Adding segment to most important ranks image by making it white + most_imp_ranks[coords[1]:coords[3], coords[0]:coords[2]] = 255 + + # Incrementing count + count += 1 + else: + break + + # Retrieving coordinates of most important ranks + coords = np.argwhere(most_imp_ranks == 255) + + # Checking if no important ranks were found and returning original image + if coords.size == 0: + return img , most_imp_ranks, [0, 0, img.shape[0], img.shape[1]] + + # Cropping image based on most important ranks + x0, y0 = coords.min(axis=0)[:2] + x1, y1 = coords.max(axis=0)[:2] + 1 + cropped_img = img[x0:x1, y0:y1] + return cropped_img , most_imp_ranks, [x0, y0, x1, y1] + +def sara_resize(img, sara_info, grid_size, rate=0.3, iterations=2): + """ + Function to resize an image based on SaRa + + Args: + img: input image + sara_info: SaRa information + grid_size: size of the grid + rate: rate of important ranks + iterations: number of iterations to resize + + Returns: + img: resized image + """ + # Iterating through iterations + for _ in range(iterations): + # Resizing image based on important ranks + img, most_imp_ranks, coords = resize_based_on_important_ranks(img, sara_info, grid_size, rate=rate) + + # Returning resized image + return img, most_imp_ranks, coords + +def plot_3D(img, 
sara_info, grid_size, rate=0.3): + def generate_segments(image, seg_count) -> dict: + """ + Function to generate segments of an image + + Args: + image: input image + seg_count: number of segments to generate + + Returns: + segments: dictionary of segments + + """ + # Initializing segments dictionary + segments = {} + # Initializing segment index and segment count + segment_count = seg_count + index = 0 + + # Retrieving image width and height + h, w = image.shape[:2] + + # Calculating width and height intervals for segments from the segment count + w_interval = w // segment_count + h_interval = h // segment_count + + # Iterating through the image and generating segments + for i in range(segment_count): + for j in range(segment_count): + # Calculating segment coordinates + x1, y1 = j * w_interval, i * h_interval + x2, y2 = x1 + w_interval, y1 + h_interval + + # Adding segment coordinates to segments dictionary + segments[index] = (x1, y1, x2, y2) + + # Incrementing segment index + index += 1 + + # Returning segments dictionary + return segments + + # Extracting heatmap from SaRa information + heatmap = sara_info[0] + heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB) + + # Retrieving important ranks from SaRa + sara_dict = { + info[0]: { + 'score': info[2], + 'index': info[1] + } + for info in sara_info[1] + } + + # Sorting important ranks by score + sorted_sara_dict = sorted(sara_dict.items(), key=lambda item: item[1]['score'], reverse=True) + + # Generating segments + index_info = generate_segments(img, grid_size) + + # Calculating maximum rank + max_rank = int(grid_size * grid_size * rate) + count = 0 + + # Normalizing heatmap + heatmap = heatmap.astype(float) / 255.0 + + # Creating a figure + fig = plt.figure(figsize=(20, 10)) + + # Creating a 3D plot + ax = fig.add_subplot(111, projection='3d') + + # Defining the x and y coordinates for the heatmap + x_coords = np.linspace(0, 1, heatmap.shape[1]) + y_coords = np.linspace(0, 1, heatmap.shape[0]) + x, y = np.meshgrid(x_coords, y_coords) + + # Defining the z-coordinate for the heatmap (a constant, such as -5) + z = np.asarray([[-10] * heatmap.shape[1]] * heatmap.shape[0]) + + # Plotting the heatmap as a texture on the xy-plane + ax.plot_surface(x, y, z, facecolors=heatmap, rstride=1, cstride=1, shade=False) + + # Initializing the single distribution array + single_distribution = np.asarray([[1e-6] * heatmap.shape[1]] * heatmap.shape[0], dtype=float) + + importance = 0 + # Creating the single distribution by summing up Gaussian distributions for each segment + for rank, info in sorted_sara_dict: + # Retrieving segment coordinates + coords = index_info[rank] + + # Creating a Gaussian distribution for the whole segment, i.e., arrange all the pixels in the segment in a 3D Gaussian distribution + x_temp = np.linspace(0, 1, coords[2] - coords[0]) + y_temp = np.linspace(0, 1, coords[3] - coords[1]) + + # Creating a meshgrid + x_temp, y_temp = np.meshgrid(x_temp, y_temp) + + # Calculating the Gaussian distribution + distribution = np.exp(-((x_temp - 0.5) ** 2 + (y_temp - 0.5) ** 2) / 0.1) * ((grid_size ** 2 - importance) / grid_size ** 2) # (constant) + + # Adding the Gaussian distribution to the single distribution + single_distribution[coords[1]:coords[3], coords[0]:coords[2]] += distribution + + # Incrementing importance + importance +=1 + + # Based on the rate, calculating the minimum number for the most important ranks + min_rank = int(grid_size * grid_size * rate) + + # Calculating the scale factor for the single distribution + scale_factor 
= ((grid_size ** 2 - min_rank) / grid_size ** 2) * 5 + + # Scaling the distribution + single_distribution *= scale_factor + + # Retrieving the max and min values of the single distribution + max_value = np.max(single_distribution) + min_value = np.min(single_distribution) + + # Calculating the hyperplane + hyperplane = np.asarray([[(max_value - min_value)* (1 - rate) + min_value] * heatmap.shape[1]] * heatmap.shape[0]) + + # Plotting a horizontal plane at the minimum rank level (hyperplane) + ax.plot_surface(x, y, hyperplane, rstride=1, cstride=1, color='red', alpha=0.3, shade=False) + + # Plotting the single distribution as a wireframe on the xy-plane + ax.plot_surface(x, y, single_distribution, rstride=1, cstride=1, color='blue', shade=False) + + # Setting the title + ax.set_title('SaRa 3D Heatmap Plot', fontsize=20) + + # Setting the labels + ax.set_xlabel('X', fontsize=16) + ax.set_ylabel('Y', fontsize=16) + ax.set_zlabel('Z', fontsize=16) + + # Setting the viewing angle to look from the y, x diagonal position + ax.view_init(elev=30, azim=45) # Adjust the elevation (elev) and azimuth (azim) angles as needed + # ax.view_init(elev=0, azim=0) # View from the top + + # Adding legend to the plot + # Creating Line2D objects for the legend + legend_elements = [Line2D([0], [0], color='blue', lw=4, label='Rank Distribution'), + Line2D([0], [0], color='red', lw=4, label='Threshold Hyperplane ({}%)'.format(rate*100)), + Line2D([0], [0], color='green', lw=4, label='SaRa Heatmap')] + + # Creating the legend + plt.subplots_adjust(right=0.5) + ax.legend(handles=legend_elements, fontsize=16, loc='center left', bbox_to_anchor=(1, 0.5)) + + # Inverting the x axis + ax.invert_xaxis() + + # Removing labels + ax.set_xticks([]) + ax.set_yticks([]) + ax.set_zticks([]) + + # Showing the plot + plt.show() + diff --git a/app.py b/app.py new file mode 100644 index 0000000000000000000000000000000000000000..839432c39f9b561728bae4821947aec3e7f2a1c0 --- /dev/null +++ b/app.py @@ -0,0 +1,154 @@ +from typing import Tuple +import gradio as gr +import numpy as np +import cv2 +import SaRa.saraRC1 as sara +import warnings +warnings.filterwarnings("ignore") + + +ALPHA = 0.3 +GENERATORS = ['itti', 'deepgaze'] + +MARKDOWN = """ +

+# Saliency Ranking: Itti vs. Deepgaze

+""" + +IMAGE_EXAMPLES = [ + ['https://media.roboflow.com/supervision/image-examples/people-walking.png', 9], + ['https://media.roboflow.com/supervision/image-examples/vehicles.png', 9], + ['https://media.roboflow.com/supervision/image-examples/basketball-1.png', 9], +] + +def detect_and_annotate(image: np.ndarray, + GRID_SIZE: int, + generator: str, + ALPHA:float =ALPHA)-> np.ndarray: + # Convert image from BGR to RGB + image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) + + # Copy and convert the image for sara processing + sara_image = image.copy() + sara_image = cv2.cvtColor(sara_image, cv2.COLOR_RGB2BGR) + + # Resetting sara + sara.reset() + + # Running sara (Original implementation on itti) + sara_info = sara.return_sara(sara_image, GRID_SIZE, generator, mode=1) + + # Generate saliency map + saliency_map = sara.return_saliency(image, generator=generator) + # Resize saliency map to match the image size + saliency_map = cv2.resize(saliency_map, (image.shape[1], image.shape[0])) + + # Apply color map and convert to RGB + saliency_map = cv2.applyColorMap(saliency_map, cv2.COLORMAP_JET) + saliency_map = cv2.cvtColor(saliency_map, cv2.COLOR_BGR2RGB) + + # Overlay the saliency map on the original image + saliency_map = cv2.addWeighted(saliency_map, ALPHA, image, 1-ALPHA, 0) + + # Extract and convert heatmap to RGB + heatmap = sara_info[0] + heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB) + + return saliency_map, heatmap + +def process_image( + input_image: np.ndarray, + GRIDSIZE: int, +) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: + # Validate GRID_SIZE + if GRIDSIZE is None and GRIDSIZE < 4: + GRIDSIZE = 9 + + itti_saliency_map, itti_heatmap = detect_and_annotate( + input_image, sara, GRIDSIZE, 'itti') + deepgaze_saliency_map, deepgaze_heatmap = detect_and_annotate( + input_image, sara, GRIDSIZE, 'deepgaze') + + return ( + itti_saliency_map, + itti_heatmap, + deepgaze_saliency_map, + deepgaze_heatmap, + ) + +grid_size_Component = gr.Slider( + minimum=4, + maximum=70, + value=9, + step=1, + label="Grid Size", + info=( + "The grid size for the Saliency Ranking (SaRa) model. The grid size determines " + "the number of regions the image is divided into. A higher grid size results in " + "more regions and a lower grid size results in fewer regions. The default grid " + "size is 9." 
+ )) + + +with gr.Blocks() as demo: + gr.Markdown(MARKDOWN) + with gr.Accordion("Configuration", open=False): + with gr.Row(): + grid_size_Component.render() + with gr.Row(): + input_image_component = gr.Image( + type='pil', + label='Input' + ) + with gr.Row(): + itti_saliency_map = gr.Image( + type='pil', + label='Itti Saliency Map' + ) + itti_heatmap = gr.Image( + type='pil', + label='Itti Saliency Ranking Heatmap' + ) + with gr.Row(): + deepgaze_saliency_map = gr.Image( + type='pil', + label='DeepGaze Saliency Map' + ) + deepgaze_heatmap = gr.Image( + type='pil', + label='DeepGaze Saliency Ranking Heatmap' + ) + submit_button_component = gr.Button( + value='Submit', + scale=1, + variant='primary' + ) + gr.Examples( + fn=process_image, + examples=IMAGE_EXAMPLES, + inputs=[ + input_image_component, + grid_size_Component, + ], + outputs=[ + itti_saliency_map, + itti_heatmap, + deepgaze_saliency_map, + deepgaze_heatmap, + ] + ) + + submit_button_component.click( + fn=process_image, + inputs=[ + input_image_component, + grid_size_Component, + ], + outputs=[ + itti_saliency_map, + itti_heatmap, + deepgaze_saliency_map, + deepgaze_heatmap, + ] + ) + +demo.launch(debug=False, show_error=True, max_threads=1) \ No newline at end of file diff --git a/deepgaze_pytorch/__init__.py b/deepgaze_pytorch/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b3dad0637fe7e72af438c90a5b0e637c4299159b --- /dev/null +++ b/deepgaze_pytorch/__init__.py @@ -0,0 +1,3 @@ +from .deepgaze1 import DeepGazeI +from .deepgaze2e import DeepGazeIIE +from .deepgaze3 import DeepGazeIII diff --git a/deepgaze_pytorch/__pycache__/__init__.cpython-39.pyc b/deepgaze_pytorch/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8485592ea3dfa71c9e5f966615c30ff7e7eb4018 Binary files /dev/null and b/deepgaze_pytorch/__pycache__/__init__.cpython-39.pyc differ diff --git a/deepgaze_pytorch/__pycache__/deepgaze1.cpython-39.pyc b/deepgaze_pytorch/__pycache__/deepgaze1.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..68e0ae72c54aca99144986dc56c8e1b14aa7dc90 Binary files /dev/null and b/deepgaze_pytorch/__pycache__/deepgaze1.cpython-39.pyc differ diff --git a/deepgaze_pytorch/__pycache__/deepgaze2e.cpython-39.pyc b/deepgaze_pytorch/__pycache__/deepgaze2e.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e11aeb7bbc0d21070e963636b96f1dbb7daef1db Binary files /dev/null and b/deepgaze_pytorch/__pycache__/deepgaze2e.cpython-39.pyc differ diff --git a/deepgaze_pytorch/__pycache__/deepgaze3.cpython-39.pyc b/deepgaze_pytorch/__pycache__/deepgaze3.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..94438b99f3c34c9313bf1d29666b273335e581b0 Binary files /dev/null and b/deepgaze_pytorch/__pycache__/deepgaze3.cpython-39.pyc differ diff --git a/deepgaze_pytorch/__pycache__/layers.cpython-39.pyc b/deepgaze_pytorch/__pycache__/layers.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..795be699204098c31229ff9b9ab0481f8b1f1cc8 Binary files /dev/null and b/deepgaze_pytorch/__pycache__/layers.cpython-39.pyc differ diff --git a/deepgaze_pytorch/__pycache__/modules.cpython-39.pyc b/deepgaze_pytorch/__pycache__/modules.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..27074544a7d7dfccc07ede88fd0d4655c90f559f Binary files /dev/null and b/deepgaze_pytorch/__pycache__/modules.cpython-39.pyc differ diff --git 
a/deepgaze_pytorch/data.py b/deepgaze_pytorch/data.py new file mode 100644 index 0000000000000000000000000000000000000000..5f29466fda2da6122b824cec43c6c0e4241dd56e --- /dev/null +++ b/deepgaze_pytorch/data.py @@ -0,0 +1,403 @@ +from collections import Counter +import io +import os +import pickle +import random + +from boltons.iterutils import chunked +import lmdb +import numpy as np +from PIL import Image +import pysaliency +from pysaliency.datasets import create_subset +from pysaliency.utils import remove_trailing_nans +import torch +from tqdm import tqdm + + +def ensure_color_image(image): + if len(image.shape) == 2: + return np.dstack([image, image, image]) + return image + + +def x_y_to_sparse_indices(xs, ys): + # Converts list of x and y coordinates into indices and values for sparse mask + x_inds = [] + y_inds = [] + values = [] + pair_inds = {} + + for x, y in zip(xs, ys): + key = (x, y) + if key not in pair_inds: + x_inds.append(x) + y_inds.append(y) + pair_inds[key] = len(x_inds) - 1 + values.append(1) + else: + values[pair_inds[key]] += 1 + + return np.array([y_inds, x_inds]), values + + +class ImageDataset(torch.utils.data.Dataset): + def __init__( + self, + stimuli, + fixations, + centerbias_model=None, + lmdb_path=None, + transform=None, + cached=None, + average='fixation' + ): + self.stimuli = stimuli + self.fixations = fixations + self.centerbias_model = centerbias_model + self.lmdb_path = lmdb_path + self.transform = transform + self.average = average + + # cache only short dataset + if cached is None: + cached = len(self.stimuli) < 100 + + cache_fixation_data = cached + + if lmdb_path is not None: + _export_dataset_to_lmdb(stimuli, centerbias_model, lmdb_path) + self.lmdb_env = lmdb.open(lmdb_path, subdir=os.path.isdir(lmdb_path), + readonly=True, lock=False, + readahead=False, meminit=False + ) + cached = False + cache_fixation_data = True + else: + self.lmdb_env = None + + self.cached = cached + if cached: + self._cache = {} + self.cache_fixation_data = cache_fixation_data + if cache_fixation_data: + print("Populating fixations cache") + self._xs_cache = {} + self._ys_cache = {} + + for x, y, n in zip(self.fixations.x_int, self.fixations.y_int, tqdm(self.fixations.n)): + self._xs_cache.setdefault(n, []).append(x) + self._ys_cache.setdefault(n, []).append(y) + + for key in list(self._xs_cache): + self._xs_cache[key] = np.array(self._xs_cache[key], dtype=int) + for key in list(self._ys_cache): + self._ys_cache[key] = np.array(self._ys_cache[key], dtype=int) + + def get_shapes(self): + return list(self.stimuli.sizes) + + def _get_image_data(self, n): + if self.lmdb_env: + image, centerbias_prediction = _get_image_data_from_lmdb(self.lmdb_env, n) + else: + image = np.array(self.stimuli.stimuli[n]) + centerbias_prediction = self.centerbias_model.log_density(image) + + image = ensure_color_image(image).astype(np.float32) + image = image.transpose(2, 0, 1) + + return image, centerbias_prediction + + def __getitem__(self, key): + if not self.cached or key not in self._cache: + + image, centerbias_prediction = self._get_image_data(key) + centerbias_prediction = centerbias_prediction.astype(np.float32) + + if self.cache_fixation_data and self.cached: + xs = self._xs_cache.pop(key) + ys = self._ys_cache.pop(key) + elif self.cache_fixation_data and not self.cached: + xs = self._xs_cache[key] + ys = self._ys_cache[key] + else: + inds = self.fixations.n == key + xs = np.array(self.fixations.x_int[inds], dtype=int) + ys = np.array(self.fixations.y_int[inds], dtype=int) + + data = { + 
"image": image, + "x": xs, + "y": ys, + "centerbias": centerbias_prediction, + } + + if self.average == 'image': + data['weight'] = 1.0 + else: + data['weight'] = float(len(xs)) + + if self.cached: + self._cache[key] = data + else: + data = self._cache[key] + + if self.transform is not None: + return self.transform(dict(data)) + + return data + + def __len__(self): + return len(self.stimuli) + + +class FixationDataset(torch.utils.data.Dataset): + def __init__( + self, + stimuli, fixations, + centerbias_model=None, + lmdb_path=None, + transform=None, + included_fixations=-2, + allow_missing_fixations=False, + average='fixation', + cache_image_data=False, + ): + self.stimuli = stimuli + self.fixations = fixations + self.centerbias_model = centerbias_model + self.lmdb_path = lmdb_path + + if lmdb_path is not None: + _export_dataset_to_lmdb(stimuli, centerbias_model, lmdb_path) + self.lmdb_env = lmdb.open(lmdb_path, subdir=os.path.isdir(lmdb_path), + readonly=True, lock=False, + readahead=False, meminit=False + ) + cache_image_data=False + else: + self.lmdb_env = None + + self.transform = transform + self.average = average + + self._shapes = None + + if isinstance(included_fixations, int): + if included_fixations < 0: + included_fixations = [-1 - i for i in range(-included_fixations)] + else: + raise NotImplementedError() + + self.included_fixations = included_fixations + self.allow_missing_fixations = allow_missing_fixations + self.fixation_counts = Counter(fixations.n) + + self.cache_image_data = cache_image_data + + if self.cache_image_data: + self.image_data_cache = {} + + print("Populating image cache") + for n in tqdm(range(len(self.stimuli))): + self.image_data_cache[n] = self._get_image_data(n) + + def get_shapes(self): + if self._shapes is None: + shapes = list(self.stimuli.sizes) + self._shapes = [shapes[n] for n in self.fixations.n] + + return self._shapes + + def _get_image_data(self, n): + if self.lmdb_path: + return _get_image_data_from_lmdb(self.lmdb_env, n) + image = np.array(self.stimuli.stimuli[n]) + centerbias_prediction = self.centerbias_model.log_density(image) + + image = ensure_color_image(image).astype(np.float32) + image = image.transpose(2, 0, 1) + + return image, centerbias_prediction + + def __getitem__(self, key): + n = self.fixations.n[key] + + if self.cache_image_data: + image, centerbias_prediction = self.image_data_cache[n] + else: + image, centerbias_prediction = self._get_image_data(n) + + centerbias_prediction = centerbias_prediction.astype(np.float32) + + x_hist = remove_trailing_nans(self.fixations.x_hist[key]) + y_hist = remove_trailing_nans(self.fixations.y_hist[key]) + + if self.allow_missing_fixations: + _x_hist = [] + _y_hist = [] + for fixation_index in self.included_fixations: + if fixation_index < -len(x_hist): + _x_hist.append(np.nan) + _y_hist.append(np.nan) + else: + _x_hist.append(x_hist[fixation_index]) + _y_hist.append(y_hist[fixation_index]) + x_hist = np.array(_x_hist) + y_hist = np.array(_y_hist) + else: + print("Not missing") + x_hist = x_hist[self.included_fixations] + y_hist = y_hist[self.included_fixations] + + data = { + "image": image, + "x": np.array([self.fixations.x_int[key]], dtype=int), + "y": np.array([self.fixations.y_int[key]], dtype=int), + "x_hist": x_hist, + "y_hist": y_hist, + "centerbias": centerbias_prediction, + } + + if self.average == 'image': + data['weight'] = 1.0 / self.fixation_counts[n] + else: + data['weight'] = 1.0 + + if self.transform is not None: + return self.transform(data) + + return data + + def 
__len__(self): + return len(self.fixations) + + +class FixationMaskTransform(object): + def __init__(self, sparse=True): + super().__init__() + self.sparse = sparse + + def __call__(self, item): + shape = torch.Size([item['image'].shape[1], item['image'].shape[2]]) + x = item.pop('x') + y = item.pop('y') + + # inds, values = x_y_to_sparse_indices(x, y) + inds = np.array([y, x]) + values = np.ones(len(y), dtype=int) + + mask = torch.sparse.IntTensor(torch.tensor(inds), torch.tensor(values), shape) + mask = mask.coalesce() + # sparse tensors don't work with workers... + if not self.sparse: + mask = mask.to_dense() + + item['fixation_mask'] = mask + + return item + + +class ImageDatasetSampler(torch.utils.data.Sampler): + def __init__(self, data_source, batch_size=1, ratio_used=1.0, shuffle=True): + self.ratio_used = ratio_used + self.shuffle = shuffle + + shapes = data_source.get_shapes() + unique_shapes = sorted(set(shapes)) + + shape_indices = [[] for shape in unique_shapes] + + for k, shape in enumerate(shapes): + shape_indices[unique_shapes.index(shape)].append(k) + + if self.shuffle: + for indices in shape_indices: + random.shuffle(indices) + + self.batches = sum([chunked(indices, size=batch_size) for indices in shape_indices], []) + + def __iter__(self): + if self.shuffle: + indices = torch.randperm(len(self.batches)) + else: + indices = range(len(self.batches)) + + if self.ratio_used < 1.0: + indices = indices[:int(self.ratio_used * len(indices))] + + return iter(self.batches[i] for i in indices) + + def __len__(self): + return int(self.ratio_used * len(self.batches)) + + +def _export_dataset_to_lmdb(stimuli: pysaliency.FileStimuli, centerbias_model: pysaliency.Model, lmdb_path, write_frequency=100): + lmdb_path = os.path.expanduser(lmdb_path) + isdir = os.path.isdir(lmdb_path) + + print("Generate LMDB to %s" % lmdb_path) + db = lmdb.open(lmdb_path, subdir=isdir, + map_size=1099511627776 * 2, readonly=False, + meminit=False, map_async=True) + + txn = db.begin(write=True) + for idx, stimulus in enumerate(tqdm(stimuli)): + key = u'{}'.format(idx).encode('ascii') + + previous_data = txn.get(key) + if previous_data: + continue + + #timulus_data = stimulus.stimulus_data + stimulus_filename = stimuli.filenames[idx] + centerbias = centerbias_model.log_density(stimulus) + + txn.put( + key, + _encode_filestimulus_item(stimulus_filename, centerbias) + ) + if idx % write_frequency == 0: + #print("[%d/%d]" % (idx, len(stimuli))) + #print("stimulus ids", len(stimuli.stimulus_ids._cache)) + #print("stimuli.cached", stimuli.cached) + #print("stimuli", len(stimuli.stimuli._cache)) + #print("centerbias", len(centerbias_model._cache._cache)) + txn.commit() + txn = db.begin(write=True) + + # finish iterating through dataset + txn.commit() + #keys = [u'{}'.format(k).encode('ascii') for k in range(idx + 1)] + #with db.begin(write=True) as txn: + # txn.put(b'__keys__', dumps_pyarrow(keys)) + # txn.put(b'__len__', dumps_pyarrow(len(keys))) + + print("Flushing database ...") + db.sync() + db.close() + + +def _encode_filestimulus_item(filename, centerbias): + with open(filename, 'rb') as f: + image_bytes = f.read() + + buffer = io.BytesIO() + pickle.dump({'image': image_bytes, 'centerbias': centerbias}, buffer) + buffer.seek(0) + return buffer.read() + + +def _get_image_data_from_lmdb(lmdb_env, n): + key = '{}'.format(n).encode('ascii') + with lmdb_env.begin(write=False) as txn: + byteflow = txn.get(key) + data = pickle.loads(byteflow) + buffer = io.BytesIO(data['image']) + buffer.seek(0) + image = 
np.array(Image.open(buffer).convert('RGB')) + centerbias_prediction = data['centerbias'] + image = image.transpose(2, 0, 1) + + return image, centerbias_prediction \ No newline at end of file diff --git a/deepgaze_pytorch/deepgaze1.py b/deepgaze_pytorch/deepgaze1.py new file mode 100644 index 0000000000000000000000000000000000000000..1eb750a7a2160c27d91fb9564cde2fd1143bf9b7 --- /dev/null +++ b/deepgaze_pytorch/deepgaze1.py @@ -0,0 +1,42 @@ +from collections import OrderedDict + +import torch +import torch.nn as nn + +from torch.utils import model_zoo + +from .features.alexnet import RGBalexnet +from .modules import FeatureExtractor, Finalizer, DeepGazeII as TorchDeepGazeII + + +class DeepGazeI(TorchDeepGazeII): + """DeepGaze I model + + Please note that this version of DeepGaze I is not exactly the one from the original paper. + The original model used caffe for AlexNet and theano for the linear readout and was trained using the SFO optimizer. + Here, we use the torch implementation of AlexNet (without any adaptations), which doesn't use the two-steam architecture, + and the DeepGaze II torch implementation with a simple linear readout network. + The model has been retrained with Adam, but still on the same dataset (all images of MIT1003 which are of size 1024x768). + Also, we don't use the sparsity penalty anymore. + + Reference: + Kümmerer, M., Theis, L., & Bethge, M. (2015). Deep Gaze I: Boosting Saliency Prediction with Feature Maps Trained on ImageNet. ICLR Workshop Track. http://arxiv.org/abs/1411.1045 + """ + def __init__(self, pretrained=True): + features = RGBalexnet() + feature_extractor = FeatureExtractor(features, ['1.features.10']) + + readout_network = nn.Sequential(OrderedDict([ + ('conv0', nn.Conv2d(256, 1, (1, 1), bias=False)), + ])) + + super().__init__( + features=feature_extractor, + readout_network=readout_network, + downsample=2, + readout_factor=4, + saliency_map_factor=4, + ) + + if pretrained: + self.load_state_dict(model_zoo.load_url('https://github.com/matthias-k/DeepGaze/releases/download/v1.01/deepgaze1.pth', map_location=torch.device('cpu'))) diff --git a/deepgaze_pytorch/deepgaze2e.py b/deepgaze_pytorch/deepgaze2e.py new file mode 100644 index 0000000000000000000000000000000000000000..16d95988e7dc35bffc7115e91a1ea935c3b45da6 --- /dev/null +++ b/deepgaze_pytorch/deepgaze2e.py @@ -0,0 +1,151 @@ +from collections import OrderedDict +import importlib +import os + + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from torch.utils import model_zoo + +from .modules import FeatureExtractor, Finalizer, DeepGazeIIIMixture, MixtureModel + +from .layers import ( + Conv2dMultiInput, + LayerNorm, + LayerNormMultiInput, + Bias, +) + + +BACKBONES = [ + { + 'type': 'deepgaze_pytorch.features.shapenet.RGBShapeNetC', + 'used_features': [ + '1.module.layer3.0.conv2', + '1.module.layer3.3.conv2', + '1.module.layer3.5.conv1', + '1.module.layer3.5.conv2', + '1.module.layer4.1.conv2', + '1.module.layer4.2.conv2', + ], + 'channels': 2048, + }, + { + 'type': 'deepgaze_pytorch.features.efficientnet.RGBEfficientNetB5', + 'used_features': [ + '1._blocks.24._depthwise_conv', + '1._blocks.26._depthwise_conv', + '1._blocks.35._project_conv', + ], + 'channels': 2416, + }, + { + 'type': 'deepgaze_pytorch.features.densenet.RGBDenseNet201', + 'used_features': [ + '1.features.denseblock4.denselayer32.norm1', + '1.features.denseblock4.denselayer32.conv1', + '1.features.denseblock4.denselayer31.conv2', + ], + 'channels': 2048, + }, + { + 'type': 
'deepgaze_pytorch.features.resnext.RGBResNext50', + 'used_features': [ + '1.layer3.5.conv1', + '1.layer3.5.conv2', + '1.layer3.4.conv2', + '1.layer4.2.conv2', + ], + 'channels': 2560, + }, +] + + +def build_saliency_network(input_channels): + return nn.Sequential(OrderedDict([ + ('layernorm0', LayerNorm(input_channels)), + ('conv0', nn.Conv2d(input_channels, 8, (1, 1), bias=False)), + ('bias0', Bias(8)), + ('softplus0', nn.Softplus()), + + ('layernorm1', LayerNorm(8)), + ('conv1', nn.Conv2d(8, 16, (1, 1), bias=False)), + ('bias1', Bias(16)), + ('softplus1', nn.Softplus()), + + ('layernorm2', LayerNorm(16)), + ('conv2', nn.Conv2d(16, 1, (1, 1), bias=False)), + ('bias2', Bias(1)), + ('softplus3', nn.Softplus()), + ])) + + +def build_fixation_selection_network(): + return nn.Sequential(OrderedDict([ + ('layernorm0', LayerNormMultiInput([1, 0])), + ('conv0', Conv2dMultiInput([1, 0], 128, (1, 1), bias=False)), + ('bias0', Bias(128)), + ('softplus0', nn.Softplus()), + + ('layernorm1', LayerNorm(128)), + ('conv1', nn.Conv2d(128, 16, (1, 1), bias=False)), + ('bias1', Bias(16)), + ('softplus1', nn.Softplus()), + + ('conv2', nn.Conv2d(16, 1, (1, 1), bias=False)), + ])) + + +def build_deepgaze_mixture(backbone_config, components=10): + feature_class = import_class(backbone_config['type']) + features = feature_class() + + feature_extractor = FeatureExtractor(features, backbone_config['used_features']) + + saliency_networks = [] + scanpath_networks = [] + fixation_selection_networks = [] + finalizers = [] + for component in range(components): + saliency_network = build_saliency_network(backbone_config['channels']) + fixation_selection_network = build_fixation_selection_network() + + saliency_networks.append(saliency_network) + scanpath_networks.append(None) + fixation_selection_networks.append(fixation_selection_network) + finalizers.append(Finalizer(sigma=8.0, learn_sigma=True, saliency_map_factor=2)) + + return DeepGazeIIIMixture( + features=feature_extractor, + saliency_networks=saliency_networks, + scanpath_networks=scanpath_networks, + fixation_selection_networks=fixation_selection_networks, + finalizers=finalizers, + downsample=2, + readout_factor=16, + saliency_map_factor=2, + included_fixations=[], + ) + + +class DeepGazeIIE(MixtureModel): + """DeepGazeIIE model + + :note + See Linardos, A., Kümmerer, M., Press, O., & Bethge, M. (2021). Calibrated prediction in and out-of-domain for state-of-the-art saliency modeling. 
ArXiv:2105.12441 [Cs], http://arxiv.org/abs/2105.12441 + """ + def __init__(self, pretrained=True): + # we average over 3 instances per backbone, each instance has 10 crossvalidation folds + backbone_models = [build_deepgaze_mixture(backbone_config, components=3 * 10) for backbone_config in BACKBONES] + super().__init__(backbone_models) + + if pretrained: + self.load_state_dict(model_zoo.load_url('https://github.com/matthias-k/DeepGaze/releases/download/v1.0.0/deepgaze2e.pth', map_location=torch.device('cpu'))) + + +def import_class(name): + module_name, class_name = name.rsplit('.', 1) + module = importlib.import_module(module_name) + return getattr(module, class_name) diff --git a/deepgaze_pytorch/deepgaze3.py b/deepgaze_pytorch/deepgaze3.py new file mode 100644 index 0000000000000000000000000000000000000000..e275499e9d5c57f109be5ba9c6869713d366f419 --- /dev/null +++ b/deepgaze_pytorch/deepgaze3.py @@ -0,0 +1,110 @@ +from collections import OrderedDict + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from torch.utils import model_zoo + +from .features.densenet import RGBDenseNet201 +from .modules import FeatureExtractor, Finalizer, DeepGazeIIIMixture +from .layers import FlexibleScanpathHistoryEncoding + +from .layers import ( + Conv2dMultiInput, + LayerNorm, + LayerNormMultiInput, + Bias, +) + + +def build_saliency_network(input_channels): + return nn.Sequential(OrderedDict([ + ('layernorm0', LayerNorm(input_channels)), + ('conv0', nn.Conv2d(input_channels, 8, (1, 1), bias=False)), + ('bias0', Bias(8)), + ('softplus0', nn.Softplus()), + + ('layernorm1', LayerNorm(8)), + ('conv1', nn.Conv2d(8, 16, (1, 1), bias=False)), + ('bias1', Bias(16)), + ('softplus1', nn.Softplus()), + + ('layernorm2', LayerNorm(16)), + ('conv2', nn.Conv2d(16, 1, (1, 1), bias=False)), + ('bias2', Bias(1)), + ('softplus2', nn.Softplus()), + ])) + + +def build_scanpath_network(): + return nn.Sequential(OrderedDict([ + ('encoding0', FlexibleScanpathHistoryEncoding(in_fixations=4, channels_per_fixation=3, out_channels=128, kernel_size=[1, 1], bias=True)), + ('softplus0', nn.Softplus()), + + ('layernorm1', LayerNorm(128)), + ('conv1', nn.Conv2d(128, 16, (1, 1), bias=False)), + ('bias1', Bias(16)), + ('softplus1', nn.Softplus()), + ])) + + +def build_fixation_selection_network(): + return nn.Sequential(OrderedDict([ + ('layernorm0', LayerNormMultiInput([1, 16])), + ('conv0', Conv2dMultiInput([1, 16], 128, (1, 1), bias=False)), + ('bias0', Bias(128)), + ('softplus0', nn.Softplus()), + + ('layernorm1', LayerNorm(128)), + ('conv1', nn.Conv2d(128, 16, (1, 1), bias=False)), + ('bias1', Bias(16)), + ('softplus1', nn.Softplus()), + + ('conv2', nn.Conv2d(16, 1, (1, 1), bias=False)), + ])) + + +class DeepGazeIII(DeepGazeIIIMixture): + """DeepGazeIII model + + :note + See Kümmerer, M., Bethge, M., & Wallis, T.S.A. (2022). DeepGaze III: Modeling free-viewing human scanpaths with deep learning. 
Journal of Vision 2022, https://doi.org/10.1167/jov.22.5.7 + """ + def __init__(self, pretrained=True): + features = RGBDenseNet201() + + feature_extractor = FeatureExtractor(features, [ + '1.features.denseblock4.denselayer32.norm1', + '1.features.denseblock4.denselayer32.conv1', + '1.features.denseblock4.denselayer31.conv2', + ]) + + saliency_networks = [] + scanpath_networks = [] + fixation_selection_networks = [] + finalizers = [] + for component in range(10): + saliency_network = build_saliency_network(2048) + scanpath_network = build_scanpath_network() + fixation_selection_network = build_fixation_selection_network() + + saliency_networks.append(saliency_network) + scanpath_networks.append(scanpath_network) + fixation_selection_networks.append(fixation_selection_network) + finalizers.append(Finalizer(sigma=8.0, learn_sigma=True, saliency_map_factor=4)) + + super().__init__( + features=feature_extractor, + saliency_networks=saliency_networks, + scanpath_networks=scanpath_networks, + fixation_selection_networks=fixation_selection_networks, + finalizers=finalizers, + downsample=2, + readout_factor=4, + saliency_map_factor=4, + included_fixations=[-1, -2, -3, -4] + ) + + if pretrained: + self.load_state_dict(model_zoo.load_url('https://github.com/matthias-k/DeepGaze/releases/download/v1.1.0/deepgaze3.pth', map_location=torch.device('cpu'))) \ No newline at end of file diff --git a/deepgaze_pytorch/features/__init__.py b/deepgaze_pytorch/features/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/deepgaze_pytorch/features/__pycache__/__init__.cpython-39.pyc b/deepgaze_pytorch/features/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8893794198951d429733ea7536d22fc7b0eb9e60 Binary files /dev/null and b/deepgaze_pytorch/features/__pycache__/__init__.cpython-39.pyc differ diff --git a/deepgaze_pytorch/features/__pycache__/alexnet.cpython-39.pyc b/deepgaze_pytorch/features/__pycache__/alexnet.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b5ea7f5ff91f21a1950e535872a05978ac5d268a Binary files /dev/null and b/deepgaze_pytorch/features/__pycache__/alexnet.cpython-39.pyc differ diff --git a/deepgaze_pytorch/features/__pycache__/densenet.cpython-39.pyc b/deepgaze_pytorch/features/__pycache__/densenet.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..665062ce4149de88b0dc07cdbfca6f5f89c4c575 Binary files /dev/null and b/deepgaze_pytorch/features/__pycache__/densenet.cpython-39.pyc differ diff --git a/deepgaze_pytorch/features/__pycache__/efficientnet.cpython-39.pyc b/deepgaze_pytorch/features/__pycache__/efficientnet.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9938fc6bc6b27a68362a9aee19975fca7f28ecf7 Binary files /dev/null and b/deepgaze_pytorch/features/__pycache__/efficientnet.cpython-39.pyc differ diff --git a/deepgaze_pytorch/features/__pycache__/normalizer.cpython-39.pyc b/deepgaze_pytorch/features/__pycache__/normalizer.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c677f04a6d6d35a09e8c98983fc80c6a316e11fd Binary files /dev/null and b/deepgaze_pytorch/features/__pycache__/normalizer.cpython-39.pyc differ diff --git a/deepgaze_pytorch/features/__pycache__/resnext.cpython-39.pyc b/deepgaze_pytorch/features/__pycache__/resnext.cpython-39.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..ec7bfa3f67965f1c155f8af5b4e824fa6c2bde8e Binary files /dev/null and b/deepgaze_pytorch/features/__pycache__/resnext.cpython-39.pyc differ diff --git a/deepgaze_pytorch/features/__pycache__/shapenet.cpython-39.pyc b/deepgaze_pytorch/features/__pycache__/shapenet.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..95431271a702d571a3be4dea2e25eff04601af23 Binary files /dev/null and b/deepgaze_pytorch/features/__pycache__/shapenet.cpython-39.pyc differ diff --git a/deepgaze_pytorch/features/alexnet.py b/deepgaze_pytorch/features/alexnet.py new file mode 100644 index 0000000000000000000000000000000000000000..0c1e24b9ba4d3bbf1c58466eb11d6f768454e670 --- /dev/null +++ b/deepgaze_pytorch/features/alexnet.py @@ -0,0 +1,18 @@ +from collections import OrderedDict + +import numpy as np +import torch +import torch.nn as nn +import torchvision + +from .normalizer import Normalizer + + + +class RGBalexnet(nn.Sequential): + def __init__(self): + super(RGBalexnet, self).__init__() + self.model = torch.hub.load('pytorch/vision:v0.6.0', 'alexnet', pretrained=True) + self.normalizer = Normalizer() + super(RGBalexnet, self).__init__(self.normalizer, self.model) + diff --git a/deepgaze_pytorch/features/bagnet.py b/deepgaze_pytorch/features/bagnet.py new file mode 100644 index 0000000000000000000000000000000000000000..d040c853ed7735b8c8814688d130b898441dd5b8 --- /dev/null +++ b/deepgaze_pytorch/features/bagnet.py @@ -0,0 +1,192 @@ +""" +This code is adapted from: https://github.com/wielandbrendel/bag-of-local-features-models +""" + +import torch.nn as nn +import math +import torch +from collections import OrderedDict +from torch.utils import model_zoo + +from .normalizer import Normalizer + + +import os +dir_path = os.path.dirname(os.path.realpath(__file__)) + +__all__ = ['bagnet9', 'bagnet17', 'bagnet33'] + +model_urls = { + 'bagnet9': 'https://bitbucket.org/wielandbrendel/bag-of-feature-pretrained-models/raw/249e8fa82c0913623a807d9d35eeab9da7dcc2a8/bagnet8-34f4ccd2.pth.tar', + 'bagnet17': 'https://bitbucket.org/wielandbrendel/bag-of-feature-pretrained-models/raw/249e8fa82c0913623a807d9d35eeab9da7dcc2a8/bagnet16-105524de.pth.tar', + 'bagnet33': 'https://bitbucket.org/wielandbrendel/bag-of-feature-pretrained-models/raw/249e8fa82c0913623a807d9d35eeab9da7dcc2a8/bagnet32-2ddd53ed.pth.tar', +} + + +class Bottleneck(nn.Module): + expansion = 4 + + def __init__(self, inplanes, planes, stride=1, downsample=None, kernel_size=1): + super(Bottleneck, self).__init__() + # print('Creating bottleneck with kernel size {} and stride {} with padding {}'.format(kernel_size, stride, (kernel_size - 1) // 2)) + self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) + self.bn1 = nn.BatchNorm2d(planes) + self.conv2 = nn.Conv2d(planes, planes, kernel_size=kernel_size, stride=stride, + padding=0, bias=False) # changed padding from (kernel_size - 1) // 2 + self.bn2 = nn.BatchNorm2d(planes) + self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) + self.bn3 = nn.BatchNorm2d(planes * 4) + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + self.stride = stride + + def forward(self, x, **kwargs): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + if self.downsample is not None: + residual = self.downsample(x) + + if residual.size(-1) != out.size(-1): + diff = 
residual.size(-1) - out.size(-1) + residual = residual[:,:,:-diff,:-diff] + + out += residual + out = self.relu(out) + + return out + + +class BagNet(nn.Module): + + def __init__(self, block, layers, strides=[1, 2, 2, 2], kernel3=[0, 0, 0, 0], num_classes=1000, avg_pool=True): + self.inplanes = 64 + super(BagNet, self).__init__() + self.conv1 = nn.Conv2d(3, 64, kernel_size=1, stride=1, padding=0, + bias=False) + self.conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=0, + bias=False) + self.bn1 = nn.BatchNorm2d(64, momentum=0.001) + self.relu = nn.ReLU(inplace=True) + self.layer1 = self._make_layer(block, 64, layers[0], stride=strides[0], kernel3=kernel3[0], prefix='layer1') + self.layer2 = self._make_layer(block, 128, layers[1], stride=strides[1], kernel3=kernel3[1], prefix='layer2') + self.layer3 = self._make_layer(block, 256, layers[2], stride=strides[2], kernel3=kernel3[2], prefix='layer3') + self.layer4 = self._make_layer(block, 512, layers[3], stride=strides[3], kernel3=kernel3[3], prefix='layer4') + self.avgpool = nn.AvgPool2d(1, stride=1) + self.fc = nn.Linear(512 * block.expansion, num_classes) + self.avg_pool = avg_pool + self.block = block + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + m.weight.data.normal_(0, math.sqrt(2. / n)) + elif isinstance(m, nn.BatchNorm2d): + m.weight.data.fill_(1) + m.bias.data.zero_() + + def _make_layer(self, block, planes, blocks, stride=1, kernel3=0, prefix=''): + downsample = None + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + nn.Conv2d(self.inplanes, planes * block.expansion, + kernel_size=1, stride=stride, bias=False), + nn.BatchNorm2d(planes * block.expansion), + ) + + layers = [] + kernel = 1 if kernel3 == 0 else 3 + layers.append(block(self.inplanes, planes, stride, downsample, kernel_size=kernel)) + self.inplanes = planes * block.expansion + for i in range(1, blocks): + kernel = 1 if kernel3 <= i else 3 + layers.append(block(self.inplanes, planes, kernel_size=kernel)) + + return nn.Sequential(*layers) + + def forward(self, x): + x = self.conv1(x) + x = self.conv2(x) + x = self.bn1(x) + x = self.relu(x) + + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + + if self.avg_pool: + x = nn.AvgPool2d(x.size()[2], stride=1)(x) + x = x.view(x.size(0), -1) + x = self.fc(x) + else: + x = x.permute(0,2,3,1) + x = self.fc(x) + + return x + +def bagnet33(pretrained=False, strides=[2, 2, 2, 1], **kwargs): + """Constructs a Bagnet-33 model. + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + """ + model = BagNet(Bottleneck, [3, 4, 6, 3], strides=strides, kernel3=[1,1,1,1], **kwargs) + if pretrained: + model.load_state_dict(model_zoo.load_url(model_urls['bagnet33'])) + return model + +def bagnet17(pretrained=False, strides=[2, 2, 2, 1], **kwargs): + """Constructs a Bagnet-17 model. + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + """ + model = BagNet(Bottleneck, [3, 4, 6, 3], strides=strides, kernel3=[1,1,1,0], **kwargs) + if pretrained: + model.load_state_dict(model_zoo.load_url(model_urls['bagnet17'])) + return model + +def bagnet9(pretrained=False, strides=[2, 2, 2, 1], **kwargs): + """Constructs a Bagnet-9 model. 
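+ The number in the name roughly corresponds to the size (in pixels) of the
+ local image patches the features are restricted to; kernel3=[1,1,0,0] keeps
+ 3x3 convolutions only in the first two stages.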
+ + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + """ + model = BagNet(Bottleneck, [3, 4, 6, 3], strides=strides, kernel3=[1,1,0,0], **kwargs) + if pretrained: + model.load_state_dict(model_zoo.load_url(model_urls['bagnet9'])) + return model + +# --- DeepGaze Adaptation ---- + + + + +class RGBBagNet17(nn.Sequential): + def __init__(self): + super(RGBBagNet17, self).__init__() + self.bagnet = bagnet17(pretrained=True, avg_pool=False) + self.normalizer = Normalizer() + super(RGBBagNet17, self).__init__(self.normalizer, self.bagnet) + + +class RGBBagNet33(nn.Sequential): + def __init__(self): + super(RGBBagNet33, self).__init__() + self.bagnet = bagnet33(pretrained=True, avg_pool=False) + self.normalizer = Normalizer() + super(RGBBagNet33, self).__init__(self.normalizer, self.bagnet) + + + diff --git a/deepgaze_pytorch/features/densenet.py b/deepgaze_pytorch/features/densenet.py new file mode 100644 index 0000000000000000000000000000000000000000..6c61d9f801f7e8f2bdc8b1a520cce321d8a37762 --- /dev/null +++ b/deepgaze_pytorch/features/densenet.py @@ -0,0 +1,19 @@ +from collections import OrderedDict + +import numpy as np +import torch +import torch.nn as nn +import torchvision + +from .normalizer import Normalizer + + + +class RGBDenseNet201(nn.Sequential): + def __init__(self): + super(RGBDenseNet201, self).__init__() + self.densenet = torch.hub.load('pytorch/vision:v0.6.0', 'densenet201', pretrained=True) + self.normalizer = Normalizer() + super(RGBDenseNet201, self).__init__(self.normalizer, self.densenet) + + diff --git a/deepgaze_pytorch/features/efficientnet.py b/deepgaze_pytorch/features/efficientnet.py new file mode 100644 index 0000000000000000000000000000000000000000..34f430f8d09452896f57f9275f24a9cb3cee6686 --- /dev/null +++ b/deepgaze_pytorch/features/efficientnet.py @@ -0,0 +1,31 @@ +from collections import OrderedDict + +import numpy as np +import torch +import torch.nn as nn +import torchvision + +from .efficientnet_pytorch import EfficientNet + + +from .normalizer import Normalizer + + + +class RGBEfficientNetB5(nn.Sequential): + def __init__(self): + super(RGBEfficientNetB5, self).__init__() + self.efficientnet = EfficientNet.from_pretrained('efficientnet-b5') + self.normalizer = Normalizer() + super(RGBEfficientNetB5, self).__init__(self.normalizer, self.efficientnet) + + + +class RGBEfficientNetB7(nn.Sequential): + def __init__(self): + super(RGBEfficientNetB7, self).__init__() + self.efficientnet = EfficientNet.from_pretrained('efficientnet-b7') + self.normalizer = Normalizer() + super(RGBEfficientNetB7, self).__init__(self.normalizer, self.efficientnet) + + diff --git a/deepgaze_pytorch/features/efficientnet_pytorch/__init__.py b/deepgaze_pytorch/features/efficientnet_pytorch/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7868fcdfc78d790f86a42a9335bcbef415d34480 --- /dev/null +++ b/deepgaze_pytorch/features/efficientnet_pytorch/__init__.py @@ -0,0 +1,10 @@ +__version__ = "0.6.3" +from .model import EfficientNet +from .utils import ( + GlobalParams, + BlockArgs, + BlockDecoder, + efficientnet, + get_model_params, +) + diff --git a/deepgaze_pytorch/features/efficientnet_pytorch/__pycache__/__init__.cpython-39.pyc b/deepgaze_pytorch/features/efficientnet_pytorch/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..21c8e0038b8e4bb4081e07db435b7ad6d8b2bb6b Binary files /dev/null and 
b/deepgaze_pytorch/features/efficientnet_pytorch/__pycache__/__init__.cpython-39.pyc differ diff --git a/deepgaze_pytorch/features/efficientnet_pytorch/__pycache__/model.cpython-39.pyc b/deepgaze_pytorch/features/efficientnet_pytorch/__pycache__/model.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..eba4828e03d7c69e496f54a670a5d8076000d4f7 Binary files /dev/null and b/deepgaze_pytorch/features/efficientnet_pytorch/__pycache__/model.cpython-39.pyc differ diff --git a/deepgaze_pytorch/features/efficientnet_pytorch/__pycache__/utils.cpython-39.pyc b/deepgaze_pytorch/features/efficientnet_pytorch/__pycache__/utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a210eea355f9151b5872247f9fe17c7da005a1c1 Binary files /dev/null and b/deepgaze_pytorch/features/efficientnet_pytorch/__pycache__/utils.cpython-39.pyc differ diff --git a/deepgaze_pytorch/features/efficientnet_pytorch/model.py b/deepgaze_pytorch/features/efficientnet_pytorch/model.py new file mode 100644 index 0000000000000000000000000000000000000000..f2bc2afd1ddb3d6a0fec3cb7d5ee5d4070dbbc48 --- /dev/null +++ b/deepgaze_pytorch/features/efficientnet_pytorch/model.py @@ -0,0 +1,229 @@ +import torch +from torch import nn +from torch.nn import functional as F + +from .utils import ( + round_filters, + round_repeats, + drop_connect, + get_same_padding_conv2d, + get_model_params, + efficientnet_params, + load_pretrained_weights, + Swish, + MemoryEfficientSwish, +) + +class MBConvBlock(nn.Module): + """ + Mobile Inverted Residual Bottleneck Block + + Args: + block_args (namedtuple): BlockArgs, see above + global_params (namedtuple): GlobalParam, see above + + Attributes: + has_se (bool): Whether the block contains a Squeeze and Excitation layer. 
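+ The forward pass applies an optional 1x1 expansion, a depthwise convolution,
+ squeeze-and-excitation (if has_se), and a 1x1 projection, with a skip
+ connection plus drop connect when the stride is 1 and the input and output
+ filter counts match.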
+ """ + + def __init__(self, block_args, global_params): + super().__init__() + self._block_args = block_args + self._bn_mom = 1 - global_params.batch_norm_momentum + self._bn_eps = global_params.batch_norm_epsilon + self.has_se = (self._block_args.se_ratio is not None) and (0 < self._block_args.se_ratio <= 1) + self.id_skip = block_args.id_skip # skip connection and drop connect + + # Get static or dynamic convolution depending on image size + Conv2d = get_same_padding_conv2d(image_size=global_params.image_size) + + # Expansion phase + inp = self._block_args.input_filters # number of input channels + oup = self._block_args.input_filters * self._block_args.expand_ratio # number of output channels + if self._block_args.expand_ratio != 1: + self._expand_conv = Conv2d(in_channels=inp, out_channels=oup, kernel_size=1, bias=False) + self._bn0 = nn.BatchNorm2d(num_features=oup, momentum=self._bn_mom, eps=self._bn_eps) + + # Depthwise convolution phase + k = self._block_args.kernel_size + s = self._block_args.stride + self._depthwise_conv = Conv2d( + in_channels=oup, out_channels=oup, groups=oup, # groups makes it depthwise + kernel_size=k, stride=s, bias=False) + self._bn1 = nn.BatchNorm2d(num_features=oup, momentum=self._bn_mom, eps=self._bn_eps) + + # Squeeze and Excitation layer, if desired + if self.has_se: + num_squeezed_channels = max(1, int(self._block_args.input_filters * self._block_args.se_ratio)) + self._se_reduce = Conv2d(in_channels=oup, out_channels=num_squeezed_channels, kernel_size=1) + self._se_expand = Conv2d(in_channels=num_squeezed_channels, out_channels=oup, kernel_size=1) + + # Output phase + final_oup = self._block_args.output_filters + self._project_conv = Conv2d(in_channels=oup, out_channels=final_oup, kernel_size=1, bias=False) + self._bn2 = nn.BatchNorm2d(num_features=final_oup, momentum=self._bn_mom, eps=self._bn_eps) + self._swish = MemoryEfficientSwish() + + def forward(self, inputs, drop_connect_rate=None): + """ + :param inputs: input tensor + :param drop_connect_rate: drop connect rate (float, between 0 and 1) + :return: output of block + """ + + # Expansion and Depthwise Convolution + x = inputs + if self._block_args.expand_ratio != 1: + x = self._swish(self._bn0(self._expand_conv(inputs))) + x = self._swish(self._bn1(self._depthwise_conv(x))) + + # Squeeze and Excitation + if self.has_se: + x_squeezed = F.adaptive_avg_pool2d(x, 1) + x_squeezed = self._se_expand(self._swish(self._se_reduce(x_squeezed))) + x = torch.sigmoid(x_squeezed) * x + + x = self._bn2(self._project_conv(x)) + + # Skip connection and drop connect + input_filters, output_filters = self._block_args.input_filters, self._block_args.output_filters + if self.id_skip and self._block_args.stride == 1 and input_filters == output_filters: + if drop_connect_rate: + x = drop_connect(x, p=drop_connect_rate, training=self.training) + x = x + inputs # skip connection + return x + + def set_swish(self, memory_efficient=True): + """Sets swish function as memory efficient (for training) or standard (for export)""" + self._swish = MemoryEfficientSwish() if memory_efficient else Swish() + + +class EfficientNet(nn.Module): + """ + An EfficientNet model. 
Most easily loaded with the .from_name or .from_pretrained methods + + Args: + blocks_args (list): A list of BlockArgs to construct blocks + global_params (namedtuple): A set of GlobalParams shared between blocks + + Example: + model = EfficientNet.from_pretrained('efficientnet-b0') + + """ + + def __init__(self, blocks_args=None, global_params=None): + super().__init__() + assert isinstance(blocks_args, list), 'blocks_args should be a list' + assert len(blocks_args) > 0, 'block args must be greater than 0' + self._global_params = global_params + self._blocks_args = blocks_args + + # Get static or dynamic convolution depending on image size + Conv2d = get_same_padding_conv2d(image_size=global_params.image_size) + + # Batch norm parameters + bn_mom = 1 - self._global_params.batch_norm_momentum + bn_eps = self._global_params.batch_norm_epsilon + + # Stem + in_channels = 3 # rgb + out_channels = round_filters(32, self._global_params) # number of output channels + self._conv_stem = Conv2d(in_channels, out_channels, kernel_size=3, stride=2, bias=False) + self._bn0 = nn.BatchNorm2d(num_features=out_channels, momentum=bn_mom, eps=bn_eps) + + # Build blocks + self._blocks = nn.ModuleList([]) + for block_args in self._blocks_args: + + # Update block input and output filters based on depth multiplier. + block_args = block_args._replace( + input_filters=round_filters(block_args.input_filters, self._global_params), + output_filters=round_filters(block_args.output_filters, self._global_params), + num_repeat=round_repeats(block_args.num_repeat, self._global_params) + ) + + # The first block needs to take care of stride and filter size increase. + self._blocks.append(MBConvBlock(block_args, self._global_params)) + if block_args.num_repeat > 1: + block_args = block_args._replace(input_filters=block_args.output_filters, stride=1) + for _ in range(block_args.num_repeat - 1): + self._blocks.append(MBConvBlock(block_args, self._global_params)) + + # Head + in_channels = block_args.output_filters # output of final block + out_channels = round_filters(1280, self._global_params) + self._conv_head = Conv2d(in_channels, out_channels, kernel_size=1, bias=False) + self._bn1 = nn.BatchNorm2d(num_features=out_channels, momentum=bn_mom, eps=bn_eps) + + # Final linear layer + self._avg_pooling = nn.AdaptiveAvgPool2d(1) + self._dropout = nn.Dropout(self._global_params.dropout_rate) + self._fc = nn.Linear(out_channels, self._global_params.num_classes) + self._swish = MemoryEfficientSwish() + + def set_swish(self, memory_efficient=True): + """Sets swish function as memory efficient (for training) or standard (for export)""" + self._swish = MemoryEfficientSwish() if memory_efficient else Swish() + for block in self._blocks: + block.set_swish(memory_efficient) + + + def extract_features(self, inputs): + """ Returns output of the final convolution layer """ + + # Stem + x = self._swish(self._bn0(self._conv_stem(inputs))) + + # Blocks + for idx, block in enumerate(self._blocks): + drop_connect_rate = self._global_params.drop_connect_rate + if drop_connect_rate: + drop_connect_rate *= float(idx) / len(self._blocks) + x = block(x, drop_connect_rate=drop_connect_rate) + + # Head + x = self._swish(self._bn1(self._conv_head(x))) + + return x + + def forward(self, inputs): + """ Calls extract_features to extract features, applies final linear layer, and returns logits. 
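+ Inputs are expected as a batch of RGB images of shape (batch_size, 3, height, width).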
""" + bs = inputs.size(0) + # Convolution layers + x = self.extract_features(inputs) + + # Pooling and final linear layer + x = self._avg_pooling(x) + x = x.view(bs, -1) + x = self._dropout(x) + x = self._fc(x) + return x + + @classmethod + def from_name(cls, model_name, override_params=None): + cls._check_model_name_is_valid(model_name) + blocks_args, global_params = get_model_params(model_name, override_params) + return cls(blocks_args, global_params) + + @classmethod + def from_pretrained(cls, model_name, advprop=False, num_classes=1000, in_channels=3): + model = cls.from_name(model_name, override_params={'num_classes': num_classes}) + load_pretrained_weights(model, model_name, load_fc=(num_classes == 1000), advprop=advprop) + if in_channels != 3: + Conv2d = get_same_padding_conv2d(image_size = model._global_params.image_size) + out_channels = round_filters(32, model._global_params) + model._conv_stem = Conv2d(in_channels, out_channels, kernel_size=3, stride=2, bias=False) + return model + + @classmethod + def get_image_size(cls, model_name): + cls._check_model_name_is_valid(model_name) + _, _, res, _ = efficientnet_params(model_name) + return res + + @classmethod + def _check_model_name_is_valid(cls, model_name): + """ Validates model name. """ + valid_models = ['efficientnet-b'+str(i) for i in range(9)] + if model_name not in valid_models: + raise ValueError('model_name should be one of: ' + ', '.join(valid_models)) diff --git a/deepgaze_pytorch/features/efficientnet_pytorch/utils.py b/deepgaze_pytorch/features/efficientnet_pytorch/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..3870708a9758ac539a81bee46238905c2d836dcf --- /dev/null +++ b/deepgaze_pytorch/features/efficientnet_pytorch/utils.py @@ -0,0 +1,335 @@ +""" +This file contains helper functions for building the model and for loading model parameters. +These helper functions are built to mirror those in the official TensorFlow implementation. 
+""" + +import re +import math +import collections +from functools import partial +import torch +from torch import nn +from torch.nn import functional as F +from torch.utils import model_zoo + +######################################################################## +############### HELPERS FUNCTIONS FOR MODEL ARCHITECTURE ############### +######################################################################## + + +# Parameters for the entire model (stem, all blocks, and head) +GlobalParams = collections.namedtuple('GlobalParams', [ + 'batch_norm_momentum', 'batch_norm_epsilon', 'dropout_rate', + 'num_classes', 'width_coefficient', 'depth_coefficient', + 'depth_divisor', 'min_depth', 'drop_connect_rate', 'image_size']) + +# Parameters for an individual model block +BlockArgs = collections.namedtuple('BlockArgs', [ + 'kernel_size', 'num_repeat', 'input_filters', 'output_filters', + 'expand_ratio', 'id_skip', 'stride', 'se_ratio']) + +# Change namedtuple defaults +GlobalParams.__new__.__defaults__ = (None,) * len(GlobalParams._fields) +BlockArgs.__new__.__defaults__ = (None,) * len(BlockArgs._fields) + + +class SwishImplementation(torch.autograd.Function): + @staticmethod + def forward(ctx, i): + result = i * torch.sigmoid(i) + ctx.save_for_backward(i) + return result + + @staticmethod + def backward(ctx, grad_output): + i = ctx.saved_variables[0] + sigmoid_i = torch.sigmoid(i) + return grad_output * (sigmoid_i * (1 + i * (1 - sigmoid_i))) + + +class MemoryEfficientSwish(nn.Module): + def forward(self, x): + return SwishImplementation.apply(x) + +class Swish(nn.Module): + def forward(self, x): + return x * torch.sigmoid(x) + + +def round_filters(filters, global_params): + """ Calculate and round number of filters based on depth multiplier. """ + multiplier = global_params.width_coefficient + if not multiplier: + return filters + divisor = global_params.depth_divisor + min_depth = global_params.min_depth + filters *= multiplier + min_depth = min_depth or divisor + new_filters = max(min_depth, int(filters + divisor / 2) // divisor * divisor) + if new_filters < 0.9 * filters: # prevent rounding by more than 10% + new_filters += divisor + return int(new_filters) + + +def round_repeats(repeats, global_params): + """ Round number of filters based on depth multiplier. """ + multiplier = global_params.depth_coefficient + if not multiplier: + return repeats + return int(math.ceil(multiplier * repeats)) + + +def drop_connect(inputs, p, training): + """ Drop connect. """ + if not training: return inputs + batch_size = inputs.shape[0] + keep_prob = 1 - p + random_tensor = keep_prob + random_tensor += torch.rand([batch_size, 1, 1, 1], dtype=inputs.dtype, device=inputs.device) + binary_tensor = torch.floor(random_tensor) + output = inputs / keep_prob * binary_tensor + return output + + +def get_same_padding_conv2d(image_size=None): + """ Chooses static padding if you have specified an image size, and dynamic padding otherwise. + Static padding is necessary for ONNX exporting of models. 
""" + if image_size is None: + return Conv2dDynamicSamePadding + else: + return partial(Conv2dStaticSamePadding, image_size=image_size) + + +class Conv2dDynamicSamePadding(nn.Conv2d): + """ 2D Convolutions like TensorFlow, for a dynamic image size """ + + def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, groups=1, bias=True): + super().__init__(in_channels, out_channels, kernel_size, stride, 0, dilation, groups, bias) + self.stride = self.stride if len(self.stride) == 2 else [self.stride[0]] * 2 + + def forward(self, x): + ih, iw = x.size()[-2:] + kh, kw = self.weight.size()[-2:] + sh, sw = self.stride + oh, ow = math.ceil(ih / sh), math.ceil(iw / sw) + pad_h = max((oh - 1) * self.stride[0] + (kh - 1) * self.dilation[0] + 1 - ih, 0) + pad_w = max((ow - 1) * self.stride[1] + (kw - 1) * self.dilation[1] + 1 - iw, 0) + if pad_h > 0 or pad_w > 0: + x = F.pad(x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2]) + return F.conv2d(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups) + + +class Conv2dStaticSamePadding(nn.Conv2d): + """ 2D Convolutions like TensorFlow, for a fixed image size""" + + def __init__(self, in_channels, out_channels, kernel_size, image_size=None, **kwargs): + super().__init__(in_channels, out_channels, kernel_size, **kwargs) + self.stride = self.stride if len(self.stride) == 2 else [self.stride[0]] * 2 + + # Calculate padding based on image size and save it + assert image_size is not None + ih, iw = image_size if type(image_size) == list else [image_size, image_size] + kh, kw = self.weight.size()[-2:] + sh, sw = self.stride + oh, ow = math.ceil(ih / sh), math.ceil(iw / sw) + pad_h = max((oh - 1) * self.stride[0] + (kh - 1) * self.dilation[0] + 1 - ih, 0) + pad_w = max((ow - 1) * self.stride[1] + (kw - 1) * self.dilation[1] + 1 - iw, 0) + if pad_h > 0 or pad_w > 0: + self.static_padding = nn.ZeroPad2d((pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2)) + else: + self.static_padding = Identity() + + def forward(self, x): + x = self.static_padding(x) + x = F.conv2d(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups) + return x + + +class Identity(nn.Module): + def __init__(self, ): + super(Identity, self).__init__() + + def forward(self, input): + return input + + +######################################################################## +############## HELPERS FUNCTIONS FOR LOADING MODEL PARAMS ############## +######################################################################## + + +def efficientnet_params(model_name): + """ Map EfficientNet model name to parameter coefficients. """ + params_dict = { + # Coefficients: width,depth,res,dropout + 'efficientnet-b0': (1.0, 1.0, 224, 0.2), + 'efficientnet-b1': (1.0, 1.1, 240, 0.2), + 'efficientnet-b2': (1.1, 1.2, 260, 0.3), + 'efficientnet-b3': (1.2, 1.4, 300, 0.3), + 'efficientnet-b4': (1.4, 1.8, 380, 0.4), + 'efficientnet-b5': (1.6, 2.2, 456, 0.4), + 'efficientnet-b6': (1.8, 2.6, 528, 0.5), + 'efficientnet-b7': (2.0, 3.1, 600, 0.5), + 'efficientnet-b8': (2.2, 3.6, 672, 0.5), + 'efficientnet-l2': (4.3, 5.3, 800, 0.5), + } + return params_dict[model_name] + + +class BlockDecoder(object): + """ Block Decoder for readability, straight from the official TensorFlow repository """ + + @staticmethod + def _decode_block_string(block_string): + """ Gets a block through a string notation of arguments. 
""" + assert isinstance(block_string, str) + + ops = block_string.split('_') + options = {} + for op in ops: + splits = re.split(r'(\d.*)', op) + if len(splits) >= 2: + key, value = splits[:2] + options[key] = value + + # Check stride + assert (('s' in options and len(options['s']) == 1) or + (len(options['s']) == 2 and options['s'][0] == options['s'][1])) + + return BlockArgs( + kernel_size=int(options['k']), + num_repeat=int(options['r']), + input_filters=int(options['i']), + output_filters=int(options['o']), + expand_ratio=int(options['e']), + id_skip=('noskip' not in block_string), + se_ratio=float(options['se']) if 'se' in options else None, + stride=[int(options['s'][0])]) + + @staticmethod + def _encode_block_string(block): + """Encodes a block to a string.""" + args = [ + 'r%d' % block.num_repeat, + 'k%d' % block.kernel_size, + 's%d%d' % (block.strides[0], block.strides[1]), + 'e%s' % block.expand_ratio, + 'i%d' % block.input_filters, + 'o%d' % block.output_filters + ] + if 0 < block.se_ratio <= 1: + args.append('se%s' % block.se_ratio) + if block.id_skip is False: + args.append('noskip') + return '_'.join(args) + + @staticmethod + def decode(string_list): + """ + Decodes a list of string notations to specify blocks inside the network. + + :param string_list: a list of strings, each string is a notation of block + :return: a list of BlockArgs namedtuples of block args + """ + assert isinstance(string_list, list) + blocks_args = [] + for block_string in string_list: + blocks_args.append(BlockDecoder._decode_block_string(block_string)) + return blocks_args + + @staticmethod + def encode(blocks_args): + """ + Encodes a list of BlockArgs to a list of strings. + + :param blocks_args: a list of BlockArgs namedtuples of block args + :return: a list of strings, each string is a notation of block + """ + block_strings = [] + for block in blocks_args: + block_strings.append(BlockDecoder._encode_block_string(block)) + return block_strings + + +def efficientnet(width_coefficient=None, depth_coefficient=None, dropout_rate=0.2, + drop_connect_rate=0.2, image_size=None, num_classes=1000): + """ Creates a efficientnet model. """ + + blocks_args = [ + 'r1_k3_s11_e1_i32_o16_se0.25', 'r2_k3_s22_e6_i16_o24_se0.25', + 'r2_k5_s22_e6_i24_o40_se0.25', 'r3_k3_s22_e6_i40_o80_se0.25', + 'r3_k5_s11_e6_i80_o112_se0.25', 'r4_k5_s22_e6_i112_o192_se0.25', + 'r1_k3_s11_e6_i192_o320_se0.25', + ] + blocks_args = BlockDecoder.decode(blocks_args) + + global_params = GlobalParams( + batch_norm_momentum=0.99, + batch_norm_epsilon=1e-3, + dropout_rate=dropout_rate, + drop_connect_rate=drop_connect_rate, + # data_format='channels_last', # removed, this is always true in PyTorch + num_classes=num_classes, + width_coefficient=width_coefficient, + depth_coefficient=depth_coefficient, + depth_divisor=8, + min_depth=None, + image_size=image_size, + ) + + return blocks_args, global_params + + +def get_model_params(model_name, override_params): + """ Get the block args and global params for a given model """ + if model_name.startswith('efficientnet'): + w, d, s, p = efficientnet_params(model_name) + # note: all models have drop connect rate = 0.2 + blocks_args, global_params = efficientnet( + width_coefficient=w, depth_coefficient=d, dropout_rate=p, image_size=s) + else: + raise NotImplementedError('model name is not pre-defined: %s' % model_name) + if override_params: + # ValueError will be raised here if override_params has fields not included in global_params. 
+ global_params = global_params._replace(**override_params) + return blocks_args, global_params + + +url_map = { + 'efficientnet-b0': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b0-355c32eb.pth', + 'efficientnet-b1': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b1-f1951068.pth', + 'efficientnet-b2': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b2-8bb594d6.pth', + 'efficientnet-b3': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b3-5fb5a3c3.pth', + 'efficientnet-b4': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b4-6ed6700e.pth', + 'efficientnet-b5': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b5-b6417697.pth', + 'efficientnet-b6': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b6-c76e70fd.pth', + 'efficientnet-b7': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b7-dcc49843.pth', +} + + +url_map_advprop = { + 'efficientnet-b0': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b0-b64d5a18.pth', + 'efficientnet-b1': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b1-0f3ce85a.pth', + 'efficientnet-b2': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b2-6e9d97e5.pth', + 'efficientnet-b3': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b3-cdd7c0f4.pth', + 'efficientnet-b4': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b4-44fb3a87.pth', + 'efficientnet-b5': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b5-86493f6b.pth', + 'efficientnet-b6': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b6-ac80338e.pth', + 'efficientnet-b7': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b7-4652b6dd.pth', + 'efficientnet-b8': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b8-22a8fe65.pth', +} + + +def load_pretrained_weights(model, model_name, load_fc=True, advprop=False): + """ Loads pretrained weights, and downloads if loading for the first time. 
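+ With load_fc=False the classifier weights ('_fc.weight', '_fc.bias') are
+ skipped, so the network can be used purely as a feature extractor.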
""" + # AutoAugment or Advprop (different preprocessing) + url_map_ = url_map_advprop if advprop else url_map + state_dict = model_zoo.load_url(url_map_[model_name]) + if load_fc: + model.load_state_dict(state_dict) + else: + state_dict.pop('_fc.weight') + state_dict.pop('_fc.bias') + res = model.load_state_dict(state_dict, strict=False) + assert set(res.missing_keys) == set(['_fc.weight', '_fc.bias']), 'issue loading pretrained weights' + print('Loaded pretrained weights for {}'.format(model_name)) diff --git a/deepgaze_pytorch/features/inception.py b/deepgaze_pytorch/features/inception.py new file mode 100644 index 0000000000000000000000000000000000000000..51550e1c03e339746eadda629dbf45e5c3f231ff --- /dev/null +++ b/deepgaze_pytorch/features/inception.py @@ -0,0 +1,20 @@ +from collections import OrderedDict + +import numpy as np +import torch +import torch.nn as nn +import torchvision + +from .normalizer import Normalizer + + + + +class RGBInceptionV3(nn.Sequential): + def __init__(self): + super(RGBInceptionV3, self).__init__() + self.resnext = torch.hub.load('pytorch/vision:v0.6.0', 'inception_v3', pretrained=True) + self.normalizer = Normalizer() + super(RGBInceptionV3, self).__init__(self.normalizer, self.resnext) + + diff --git a/deepgaze_pytorch/features/mobilenet.py b/deepgaze_pytorch/features/mobilenet.py new file mode 100644 index 0000000000000000000000000000000000000000..3bd3fad2fc43cefb820d706b1309aed42374292b --- /dev/null +++ b/deepgaze_pytorch/features/mobilenet.py @@ -0,0 +1,17 @@ +from collections import OrderedDict + +import numpy as np +import torch +import torch.nn as nn +import torchvision + +from .normalizer import Normalizer + + + +class RGBMobileNetV2(nn.Sequential): + def __init__(self): + super(RGBMobileNetV2, self).__init__() + self.mobilenet_v2 = torchvision.models.mobilenet_v2(pretrained=True) + self.normalizer = Normalizer() + super(RGBMobileNetV2, self).__init__(self.normalizer, self.mobilenet_v2) diff --git a/deepgaze_pytorch/features/normalizer.py b/deepgaze_pytorch/features/normalizer.py new file mode 100644 index 0000000000000000000000000000000000000000..3d0c6740f425b1398de886545f783fd018177e6e --- /dev/null +++ b/deepgaze_pytorch/features/normalizer.py @@ -0,0 +1,28 @@ +from collections import OrderedDict + +import numpy as np +import torch +import torch.nn as nn +import torchvision + +class Normalizer(nn.Module): + def __init__(self): + super(Normalizer, self).__init__() + mean = np.array([0.485, 0.456, 0.406]) + mean = mean[:, np.newaxis, np.newaxis] + + std = np.array([0.229, 0.224, 0.225]) + std = std[:, np.newaxis, np.newaxis] + + # don't persist to keep old checkpoints working + self.register_buffer('mean', torch.tensor(mean), persistent=False) + self.register_buffer('std', torch.tensor(std), persistent=False) + + + def forward(self, tensor): + tensor = tensor / 255.0 + + tensor -= self.mean + tensor /= self.std + + return tensor \ No newline at end of file diff --git a/deepgaze_pytorch/features/resnet.py b/deepgaze_pytorch/features/resnet.py new file mode 100644 index 0000000000000000000000000000000000000000..4eae1ec4302e4f48cfff97b93eda7c0cd2e46cfb --- /dev/null +++ b/deepgaze_pytorch/features/resnet.py @@ -0,0 +1,44 @@ +from collections import OrderedDict + +import numpy as np +import torch +import torch.nn as nn +import torchvision + +from .normalizer import Normalizer + + + +class RGBResNet34(nn.Sequential): + def __init__(self): + super(RGBResNet34, self).__init__() + self.resnet = torchvision.models.resnet34(pretrained=True) + 
self.normalizer = Normalizer() + super(RGBResNet34, self).__init__(self.normalizer, self.resnet) + + +class RGBResNet50(nn.Sequential): + def __init__(self): + super(RGBResNet50, self).__init__() + self.resnet = torchvision.models.resnet50(pretrained=True) + self.normalizer = Normalizer() + super(RGBResNet50, self).__init__(self.normalizer, self.resnet) + + +class RGBResNet50_alt(nn.Sequential): + def __init__(self): + super(RGBResNet50, self).__init__() + self.resnet = torchvision.models.resnet50(pretrained=True) + self.normalizer = Normalizer() + state_dict = torch.load("Resnet-AlternativePreTrain.pth") + model.load_state_dict(state_dict) + super(RGBResNet50, self).__init__(self.normalizer, self.resnet) + + + +class RGBResNet101(nn.Sequential): + def __init__(self): + super(RGBResNet101, self).__init__() + self.resnet = torchvision.models.resnet101(pretrained=True) + self.normalizer = Normalizer() + super(RGBResNet101, self).__init__(self.normalizer, self.resnet) diff --git a/deepgaze_pytorch/features/resnext.py b/deepgaze_pytorch/features/resnext.py new file mode 100644 index 0000000000000000000000000000000000000000..7fe657a6783348e4011c2b178730bcf491ad6e8c --- /dev/null +++ b/deepgaze_pytorch/features/resnext.py @@ -0,0 +1,27 @@ +from collections import OrderedDict + +import numpy as np +import torch +import torch.nn as nn +import torchvision + +from .normalizer import Normalizer + + + +class RGBResNext50(nn.Sequential): + def __init__(self): + super(RGBResNext50, self).__init__() + self.resnext = torch.hub.load('pytorch/vision:v0.6.0', 'resnext50_32x4d', pretrained=True) + self.normalizer = Normalizer() + super(RGBResNext50, self).__init__(self.normalizer, self.resnext) + + +class RGBResNext101(nn.Sequential): + def __init__(self): + super(RGBResNext101, self).__init__() + self.resnext = torch.hub.load('pytorch/vision:v0.6.0', 'resnext101_32x8d', pretrained=True) + self.normalizer = Normalizer() + super(RGBResNext101, self).__init__(self.normalizer, self.resnext) + + diff --git a/deepgaze_pytorch/features/shapenet.py b/deepgaze_pytorch/features/shapenet.py new file mode 100644 index 0000000000000000000000000000000000000000..f9500cd115130f0f17bb69a9a235e204022d79c1 --- /dev/null +++ b/deepgaze_pytorch/features/shapenet.py @@ -0,0 +1,89 @@ +""" +This code was adapted from: https://github.com/rgeirhos/texture-vs-shape +""" +import os +import sys +from collections import OrderedDict +import torch +import torch.nn as nn +import torchvision +import torchvision.models +from torch.utils import model_zoo + +from .normalizer import Normalizer + + +def load_model(model_name): + + model_urls = { + 'resnet50_trained_on_SIN': 'https://bitbucket.org/robert_geirhos/texture-vs-shape-pretrained-models/raw/6f41d2e86fc60566f78de64ecff35cc61eb6436f/resnet50_train_60_epochs-c8e5653e.pth.tar', + 'resnet50_trained_on_SIN_and_IN': 'https://bitbucket.org/robert_geirhos/texture-vs-shape-pretrained-models/raw/60b770e128fffcbd8562a3ab3546c1a735432d03/resnet50_train_45_epochs_combined_IN_SF-2a0d100e.pth.tar', + 'resnet50_trained_on_SIN_and_IN_then_finetuned_on_IN': 'https://bitbucket.org/robert_geirhos/texture-vs-shape-pretrained-models/raw/60b770e128fffcbd8562a3ab3546c1a735432d03/resnet50_finetune_60_epochs_lr_decay_after_30_start_resnet50_train_45_epochs_combined_IN_SF-ca06340c.pth.tar', + 'vgg16_trained_on_SIN': 'https://bitbucket.org/robert_geirhos/texture-vs-shape-pretrained-models/raw/0008049cd10f74a944c6d5e90d4639927f8620ae/vgg16_train_60_epochs_lr0.01-6c6fcc9f.pth.tar', + 'alexnet_trained_on_SIN': 
'https://bitbucket.org/robert_geirhos/texture-vs-shape-pretrained-models/raw/0008049cd10f74a944c6d5e90d4639927f8620ae/alexnet_train_60_epochs_lr0.001-b4aa5238.pth.tar', + } + + if "resnet50" in model_name: + #print("Using the ResNet50 architecture.") + model = torchvision.models.resnet50(pretrained=False) + #model = torch.nn.DataParallel(model) # .cuda() + # fake DataParallel structrue + model = torch.nn.Sequential(OrderedDict([('module', model)])) + checkpoint = model_zoo.load_url(model_urls[model_name], map_location=torch.device('cpu')) + elif "vgg16" in model_name: + #print("Using the VGG-16 architecture.") + + # download model from URL manually and save to desired location + filepath = "./vgg16_train_60_epochs_lr0.01-6c6fcc9f.pth.tar" + + assert os.path.exists(filepath), "Please download the VGG model yourself from the following link and save it locally: https://drive.google.com/drive/folders/1A0vUWyU6fTuc-xWgwQQeBvzbwi6geYQK (too large to be downloaded automatically like the other models)" + + model = torchvision.models.vgg16(pretrained=False) + model.features = torch.nn.DataParallel(model.features) + model.cuda() + checkpoint = torch.load(filepath, map_location=torch.device('cpu')) + + + elif "alexnet" in model_name: + #print("Using the AlexNet architecture.") + model = torchvision.models.alexnet(pretrained=False) + model.features = torch.nn.DataParallel(model.features) + model.cuda() + checkpoint = model_zoo.load_url(model_urls[model_name], map_location=torch.device('cpu')) + else: + raise ValueError("unknown model architecture.") + + model.load_state_dict(checkpoint["state_dict"]) + return model + +# --- DeepGaze Adaptation ---- + + + + +class RGBShapeNetA(nn.Sequential): + def __init__(self): + super(RGBShapeNetA, self).__init__() + self.shapenet = load_model("resnet50_trained_on_SIN") + self.normalizer = Normalizer() + super(RGBShapeNetA, self).__init__(self.normalizer, self.shapenet) + + + +class RGBShapeNetB(nn.Sequential): + def __init__(self): + super(RGBShapeNetB, self).__init__() + self.shapenet = load_model("resnet50_trained_on_SIN_and_IN") + self.normalizer = Normalizer() + super(RGBShapeNetB, self).__init__(self.normalizer, self.shapenet) + + +class RGBShapeNetC(nn.Sequential): + def __init__(self): + super(RGBShapeNetC, self).__init__() + self.shapenet = load_model("resnet50_trained_on_SIN_and_IN_then_finetuned_on_IN") + self.normalizer = Normalizer() + super(RGBShapeNetC, self).__init__(self.normalizer, self.shapenet) + + + diff --git a/deepgaze_pytorch/features/squeezenet.py b/deepgaze_pytorch/features/squeezenet.py new file mode 100644 index 0000000000000000000000000000000000000000..b68a561d47c9ba878a36b3cb8c986efa6b5eec02 --- /dev/null +++ b/deepgaze_pytorch/features/squeezenet.py @@ -0,0 +1,17 @@ +from collections import OrderedDict + +import numpy as np +import torch +import torch.nn as nn +import torchvision + +from .normalizer import Normalizer + + +class RGBSqueezeNet(nn.Sequential): + def __init__(self): + super(RGBSqueezeNet, self).__init__() + self.squeezenet = torch.hub.load('pytorch/vision:v0.6.0', 'squeezenet1_0', pretrained=True) + self.normalizer = Normalizer() + super(RGBSqueezeNet, self).__init__(self.normalizer, self.squeezenet) + diff --git a/deepgaze_pytorch/features/swav.py b/deepgaze_pytorch/features/swav.py new file mode 100644 index 0000000000000000000000000000000000000000..6f08db67205169d487e24bff3891599e1b1db97c --- /dev/null +++ b/deepgaze_pytorch/features/swav.py @@ -0,0 +1,20 @@ +from collections import OrderedDict + +import numpy as np 
+import torch +import torch.nn as nn +import torchvision + +from .normalizer import Normalizer + + + + +class RGBSwav(nn.Sequential): + def __init__(self): + super(RGBSwav, self).__init__() + self.swav = torch.hub.load('facebookresearch/swav', 'resnet50', pretrained=True) + self.normalizer = Normalizer() + super(RGBSwav, self).__init__(self.normalizer, self.swav) + + diff --git a/deepgaze_pytorch/features/uninformative.py b/deepgaze_pytorch/features/uninformative.py new file mode 100644 index 0000000000000000000000000000000000000000..77eb94acc5e1ef81707f3835e5c69632854a4843 --- /dev/null +++ b/deepgaze_pytorch/features/uninformative.py @@ -0,0 +1,26 @@ +from collections import OrderedDict + +import torch +import torch.nn as nn + + +class OnesLayer(nn.Module): + def __init__(self, size=None): + super().__init__() + self.size = size + + def forward(self, tensor): + shape = list(tensor.shape) + shape[1] = 1 # return only one channel + + if self.size is not None: + shape[2], shape[3] = self.size + + return torch.ones(shape, dtype=torch.float32, device=tensor.device) + + +class UninformativeFeatures(torch.nn.Sequential): + def __init__(self): + super().__init__(OrderedDict([ + ('ones', OnesLayer(size=(1, 1))), + ])) diff --git a/deepgaze_pytorch/features/vgg.py b/deepgaze_pytorch/features/vgg.py new file mode 100644 index 0000000000000000000000000000000000000000..5924694aa10eee006cfbe07ec2fa6eb9bbd7df2a --- /dev/null +++ b/deepgaze_pytorch/features/vgg.py @@ -0,0 +1,86 @@ +from collections import OrderedDict + +import numpy as np +import torch +import torch.nn as nn +import torchvision + + +class VGGInputNormalization(torch.nn.Module): + def __init__(self, inplace=True): + super().__init__() + + self.inplace = inplace + + mean = np.array([0.485, 0.456, 0.406]) + mean = mean[:, np.newaxis, np.newaxis] + + std = np.array([0.229, 0.224, 0.225]) + std = std[:, np.newaxis, np.newaxis] + self.register_buffer('mean', torch.tensor(mean)) + self.register_buffer('std', torch.tensor(std)) + + def forward(self, tensor): + if self.inplace: + tensor /= 255.0 + else: + tensor = tensor / 255.0 + + tensor -= self.mean + tensor /= self.std + + return tensor + + +class VGG19BNNamedFeatures(torch.nn.Sequential): + def __init__(self): + names = [] + for block in range(5): + block_size = 2 if block < 2 else 4 + for layer in range(block_size): + names.append(f'conv{block+1}_{layer+1}') + names.append(f'bn{block+1}_{layer+1}') + names.append(f'relu{block+1}_{layer+1}') + names.append(f'pool{block+1}') + + vgg = torchvision.models.vgg19_bn(pretrained=True) + vgg_features = vgg.features + vgg.classifier = torch.nn.Sequential() + + assert len(names) == len(vgg_features) + + named_features = OrderedDict({'normalize': VGGInputNormalization()}) + + for name, feature in zip(names, vgg_features): + if isinstance(feature, nn.MaxPool2d): + feature.ceil_mode = True + named_features[name] = feature + + super().__init__(named_features) + + +class VGG19NamedFeatures(torch.nn.Sequential): + def __init__(self): + names = [] + for block in range(5): + block_size = 2 if block < 2 else 4 + for layer in range(block_size): + names.append(f'conv{block+1}_{layer+1}') + names.append(f'relu{block+1}_{layer+1}') + names.append(f'pool{block+1}') + + vgg = torchvision.models.vgg19(pretrained=True) + vgg_features = vgg.features + vgg.classifier = torch.nn.Sequential() + + assert len(names) == len(vgg_features) + + named_features = OrderedDict({'normalize': VGGInputNormalization()}) + + for name, feature in zip(names, vgg_features): + if 
isinstance(feature, nn.MaxPool2d): + feature.ceil_mode = True + + named_features[name] = feature + + super().__init__(named_features) diff --git a/deepgaze_pytorch/features/vggnet.py b/deepgaze_pytorch/features/vggnet.py new file mode 100644 index 0000000000000000000000000000000000000000..723b0ef1418d3054011ac56a8c1a8372d0ff9353 --- /dev/null +++ b/deepgaze_pytorch/features/vggnet.py @@ -0,0 +1,24 @@ +from collections import OrderedDict + +import numpy as np +import torch +import torch.nn as nn +import torchvision + +from .normalizer import Normalizer + + +class RGBvgg19(nn.Sequential): + def __init__(self): + super(RGBvgg19, self).__init__() + self.model = torch.hub.load('pytorch/vision:v0.6.0', 'vgg19', pretrained=True) + self.normalizer = Normalizer() + super(RGBvgg19, self).__init__(self.normalizer, self.model) + + +class RGBvgg11(nn.Sequential): + def __init__(self): + super(RGBvgg11, self).__init__() + self.model = torch.hub.load('pytorch/vision:v0.6.0', 'vgg11', pretrained=True) + self.normalizer = Normalizer() + super(RGBvgg11, self).__init__(self.normalizer, self.model) \ No newline at end of file diff --git a/deepgaze_pytorch/features/wsl.py b/deepgaze_pytorch/features/wsl.py new file mode 100644 index 0000000000000000000000000000000000000000..153205a7be7128cac53ae494bac29a10d96859aa --- /dev/null +++ b/deepgaze_pytorch/features/wsl.py @@ -0,0 +1,27 @@ +from collections import OrderedDict + +import numpy as np +import torch +import torch.nn as nn +import torchvision + +from .normalizer import Normalizer + + + +class RGBResNext50(nn.Sequential): + def __init__(self): + super(RGBResNext50, self).__init__() + self.resnext = torch.hub.load('facebookresearch/WSL-Images', 'resnext50_32x16d_wsl') + self.normalizer = Normalizer() + super(RGBResNext50, self).__init__(self.normalizer, self.resnext) + + +class RGBResNext101(nn.Sequential): + def __init__(self): + super(RGBResNext101, self).__init__() + self.resnext = torch.hub.load('facebookresearch/WSL-Images', 'resnext101_32x16d_wsl') + self.normalizer = Normalizer() + super(RGBResNext101, self).__init__(self.normalizer, self.resnext) + + diff --git a/deepgaze_pytorch/layers.py b/deepgaze_pytorch/layers.py new file mode 100644 index 0000000000000000000000000000000000000000..b1020bee1eda915cfbae77504d0c3d2ee90f8e5f --- /dev/null +++ b/deepgaze_pytorch/layers.py @@ -0,0 +1,427 @@ +# pylint: disable=missing-module-docstring,invalid-name +# pylint: disable=missing-docstring +# pylint: disable=line-too-long + +import math + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class LayerNorm(nn.Module): + r"""Applies Layer Normalization over a mini-batch of inputs as described in + the paper `Layer Normalization`_ . + + .. math:: + y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta + + The mean and standard-deviation are calculated separately over the last + certain number dimensions which have to be of the shape specified by + :attr:`normalized_shape`. + :math:`\gamma` and :math:`\beta` are learnable affine transform parameters of + :attr:`normalized_shape` if :attr:`elementwise_affine` is ``True``. + + .. note:: + Unlike Batch Normalization and Instance Normalization, which applies + scalar scale and bias for each entire channel/plane with the + :attr:`affine` option, Layer Normalization applies per-element scale and + bias with :attr:`elementwise_affine`. + + This layer uses statistics computed from input data in both training and + evaluation modes. 
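+ In this implementation the affine parameters hold one value per channel (broadcast over the spatial dimensions), and the normalization statistics are computed per sample over the channel and both spatial dimensions.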
+ + Args: + normalized_shape (int or list or torch.Size): input shape from an expected input + of size + + .. math:: + [* \times \text{normalized\_shape}[0] \times \text{normalized\_shape}[1] + \times \ldots \times \text{normalized\_shape}[-1]] + + If a single integer is used, it is treated as a singleton list, and this module will + normalize over the last dimension which is expected to be of that specific size. + eps: a value added to the denominator for numerical stability. Default: 1e-5 + elementwise_affine: a boolean value that when set to ``True``, this module + has learnable per-element affine parameters initialized to ones (for weights) + and zeros (for biases). Default: ``True``. + + Shape: + - Input: :math:`(N, *)` + - Output: :math:`(N, *)` (same shape as input) + + Examples:: + + >>> input = torch.randn(20, 5, 10, 10) + >>> # With Learnable Parameters + >>> m = nn.LayerNorm(input.size()[1:]) + >>> # Without Learnable Parameters + >>> m = nn.LayerNorm(input.size()[1:], elementwise_affine=False) + >>> # Normalize over last two dimensions + >>> m = nn.LayerNorm([10, 10]) + >>> # Normalize over last dimension of size 10 + >>> m = nn.LayerNorm(10) + >>> # Activating the module + >>> output = m(input) + + .. _`Layer Normalization`: https://arxiv.org/abs/1607.06450 + """ + __constants__ = ['features', 'weight', 'bias', 'eps', 'center', 'scale'] + + def __init__(self, features, eps=1e-12, center=True, scale=True): + super(LayerNorm, self).__init__() + self.features = features + self.eps = eps + self.center = center + self.scale = scale + + if self.scale: + self.weight = nn.Parameter(torch.Tensor(self.features)) + else: + self.register_parameter('weight', None) + + if self.center: + self.bias = nn.Parameter(torch.Tensor(self.features)) + else: + self.register_parameter('bias', None) + + self.reset_parameters() + + def reset_parameters(self): + if self.scale: + nn.init.ones_(self.weight) + + if self.center: + nn.init.zeros_(self.bias) + + def adjust_parameter(self, tensor, parameter): + return torch.repeat_interleave( + torch.repeat_interleave( + parameter.view(-1, 1, 1), + repeats=tensor.shape[2], + dim=1), + repeats=tensor.shape[3], + dim=2 + ) + + def forward(self, input): + normalized_shape = (self.features, input.shape[2], input.shape[3]) + weight = self.adjust_parameter(input, self.weight) + bias = self.adjust_parameter(input, self.bias) + return F.layer_norm( + input, normalized_shape, weight, bias, self.eps) + + def extra_repr(self): + return '{features}, eps={eps}, ' \ + 'center={center}, scale={scale}'.format(**self.__dict__) + + +def gaussian_filter_1d(tensor, dim, sigma, truncate=4, kernel_size=None, padding_mode='replicate', padding_value=0.0): + sigma = torch.as_tensor(sigma, device=tensor.device, dtype=tensor.dtype) + + if kernel_size is not None: + kernel_size = torch.as_tensor(kernel_size, device=tensor.device, dtype=torch.int64) + else: + kernel_size = torch.as_tensor(2 * torch.ceil(truncate * sigma) + 1, device=tensor.device, dtype=torch.int64) + + kernel_size = kernel_size.detach() + + kernel_size_int = kernel_size.detach().cpu().numpy() + + mean = (torch.as_tensor(kernel_size, dtype=tensor.dtype) - 1) / 2 + + grid = torch.arange(kernel_size, device=tensor.device) - mean + + kernel_shape = (1, 1, kernel_size) + grid = grid.view(kernel_shape) + + grid = grid.detach() + + source_shape = tensor.shape + + tensor = torch.movedim(tensor, dim, len(source_shape)-1) + dim_last_shape = tensor.shape + assert tensor.shape[-1] == source_shape[dim] + + # we need reshape instead of 
view for batches like B x C x H x W + tensor = tensor.reshape(-1, 1, source_shape[dim]) + + padding = (math.ceil((kernel_size_int - 1) / 2), math.ceil((kernel_size_int - 1) / 2)) + tensor_ = F.pad(tensor, padding, padding_mode, padding_value) + + # create gaussian kernel from grid using current sigma + kernel = torch.exp(-0.5 * (grid / sigma) ** 2) + kernel = kernel / kernel.sum() + + # convolve input with gaussian kernel + tensor_ = F.conv1d(tensor_, kernel) + tensor_ = tensor_.view(dim_last_shape) + tensor_ = torch.movedim(tensor_, len(source_shape)-1, dim) + + assert tensor_.shape == source_shape + + return tensor_ + + +class GaussianFilterNd(nn.Module): + """A differentiable gaussian filter""" + + def __init__(self, dims, sigma, truncate=4, kernel_size=None, padding_mode='replicate', padding_value=0.0, + trainable=False): + """Creates a 1d gaussian filter + + Args: + dims ([int]): the dimensions to which the gaussian filter is applied. Negative values won't work + sigma (float): standard deviation of the gaussian filter (blur size) + input_dims (int, optional): number of input dimensions ignoring batch and channel dimension, + i.e. use input_dims=2 for images (default: 2). + truncate (float, optional): truncate the filter at this many standard deviations (default: 4.0). + This has no effect if the `kernel_size` is explicitely set + kernel_size (int): size of the gaussian kernel convolved with the input + padding_mode (string, optional): Padding mode implemented by `torch.nn.functional.pad`. + padding_value (string, optional): Value used for constant padding. + """ + # IDEA determine input_dims dynamically for every input + super(GaussianFilterNd, self).__init__() + + self.dims = dims + self.sigma = nn.Parameter(torch.tensor(sigma, dtype=torch.float32), requires_grad=trainable) # default: no optimization + self.truncate = truncate + self.kernel_size = kernel_size + + # setup padding + self.padding_mode = padding_mode + self.padding_value = padding_value + + def forward(self, tensor): + """Applies the gaussian filter to the given tensor""" + for dim in self.dims: + tensor = gaussian_filter_1d( + tensor, + dim=dim, + sigma=self.sigma, + truncate=self.truncate, + kernel_size=self.kernel_size, + padding_mode=self.padding_mode, + padding_value=self.padding_value, + ) + + return tensor + + +class Conv2dMultiInput(nn.Module): + def __init__(self, in_channels, out_channels, kernel_size, bias=True): + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + + for k, _in_channels in enumerate(in_channels): + if _in_channels: + setattr(self, f'conv_part{k}', nn.Conv2d(_in_channels, out_channels, kernel_size, bias=bias)) + + def forward(self, tensors): + assert len(tensors) == len(self.in_channels) + + out = None + for k, (count, tensor) in enumerate(zip(self.in_channels, tensors)): + if not count: + continue + _out = getattr(self, f'conv_part{k}')(tensor) + + if out is None: + out = _out + else: + out += _out + + return out + +# def extra_repr(self): +# return f'{self.in_channels}' + + +class LayerNormMultiInput(nn.Module): + __constants__ = ['features', 'weight', 'bias', 'eps', 'center', 'scale'] + + def __init__(self, features, eps=1e-12, center=True, scale=True): + super().__init__() + self.features = features + self.eps = eps + self.center = center + self.scale = scale + + for k, _features in enumerate(features): + if _features: + setattr(self, f'layernorm_part{k}', LayerNorm(_features, eps=eps, center=center, scale=scale)) + + def forward(self, tensors): + 
assert len(tensors) == len(self.features) + + out = [] + for k, (count, tensor) in enumerate(zip(self.features, tensors)): + if not count: + assert tensor is None + out.append(None) + continue + out.append(getattr(self, f'layernorm_part{k}')(tensor)) + + return out + + +class Bias(nn.Module): + def __init__(self, channels): + super().__init__() + self.channels = channels + self.bias = nn.Parameter(torch.zeros(channels)) + + def forward(self, tensor): + return tensor + self.bias[np.newaxis, :, np.newaxis, np.newaxis] + + def extra_repr(self): + return f'channels={self.channels}' + + +class SelfAttention(nn.Module): + """ Self attention Layer + + adapted from https://discuss.pytorch.org/t/attention-in-image-classification/80147/3 + """ + + def __init__(self, in_channels, out_channels=None, key_channels=None, activation=None, skip_connection_with_convolution=False, return_attention=True): + super().__init__() + self.in_channels = in_channels + if out_channels is None: + out_channels = in_channels + self.out_channels = out_channels + if key_channels is None: + key_channels = in_channels // 8 + self.key_channels = key_channels + self.activation = activation + self.skip_connection_with_convolution = skip_connection_with_convolution + if not self.skip_connection_with_convolution: + if self.out_channels != self.in_channels: + raise ValueError("out_channels has to be equal to in_channels with true skip connection!") + self.return_attention = return_attention + + self.query_conv = nn.Conv2d(in_channels=in_channels, out_channels=key_channels, kernel_size=1) + self.key_conv = nn.Conv2d(in_channels=in_channels, out_channels=key_channels, kernel_size=1) + self.value_conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1) + self.gamma = nn.Parameter(torch.zeros(1)) + if self.skip_connection_with_convolution: + self.skip_conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1) + + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x): + """ + inputs : + x : input feature maps( B X C X W X H) + returns : + out : self attention value + input feature + attention: B X N X N (N is Width*Height) + """ + m_batchsize, C, width, height = x.size() + proj_query = self.query_conv(x).view(m_batchsize, -1, width * height).permute(0, 2, 1) # B X CX(N) + proj_key = self.key_conv(x).view(m_batchsize, -1, width * height) # B X C x (*W*H) + energy = torch.bmm(proj_query, proj_key) # transpose check + attention = self.softmax(energy) # BX (N) X (N) + proj_value = self.value_conv(x).view(m_batchsize, -1, width * height) # B X C X N + + out = torch.bmm(proj_value, attention.permute(0, 2, 1)) + out = out.view(m_batchsize, self.out_channels, width, height) + + if self.skip_connection_with_convolution: + skip_connection = self.skip_conv(x) + else: + skip_connection = x + out = self.gamma * out + skip_connection + + if self.activation is not None: + out = self.activation(out) + + if self.return_attention: + return out, attention + + return out + + +class MultiHeadSelfAttention(nn.Module): + """ Self attention Layer + + adapted from https://discuss.pytorch.org/t/attention-in-image-classification/80147/3 + """ + + def __init__(self, in_channels, heads, out_channels=None, key_channels=None, activation=None, skip_connection_with_convolution=False): + super().__init__() + self.heads = heads + self.heads = nn.ModuleList([SelfAttention( + in_channels=in_channels, + out_channels=out_channels, + key_channels=key_channels, + activation=activation, + 
skip_connection_with_convolution=skip_connection_with_convolution, + return_attention=False, + ) for _ in range(heads)]) + + def forward(self, tensor): + outs = [head(tensor) for head in self.heads] + out = torch.cat(outs, dim=1) + return out + + +class FlexibleScanpathHistoryEncoding(nn.Module): + """ + a convolutional layer which works for different numbers of previous fixations. + + Nonexistent fixations will deactivate the respective convolutions + the bias will be added per fixation (if the given fixation is present) + """ + def __init__(self, in_fixations, channels_per_fixation, out_channels, kernel_size, bias=True,): + super().__init__() + self.in_fixations = in_fixations + self.channels_per_fixation = channels_per_fixation + self.out_channels = out_channels + self.kernel_size = kernel_size + self.bias = bias + self.convolutions = nn.ModuleList([ + nn.Conv2d( + in_channels=self.channels_per_fixation, + out_channels=self.out_channels, + kernel_size=self.kernel_size, + bias=self.bias + ) for i in range(in_fixations) + ]) + + def forward(self, tensor): + results = None + valid_fixations = ~torch.isnan( + tensor[:, :self.in_fixations, 0, 0] + ) + # print("valid fix", valid_fixations) + + for fixation_index in range(self.in_fixations): + valid_indices = valid_fixations[:, fixation_index] + if not torch.any(valid_indices): + continue + this_input = tensor[ + valid_indices, + fixation_index::self.in_fixations + ] + this_result = self.convolutions[fixation_index]( + this_input + ) + # TODO: This will break if all data points + # in the batch don't have a single fixation + # but that's not a case I intend to train + # anyway. + if results is None: + b, _, _, _ = tensor.shape + _, _, h, w = this_result.shape + results = torch.zeros( + (b, self.out_channels, h, w), + dtype=tensor.dtype, + device=tensor.device + ) + results[valid_indices] += this_result + + return results diff --git a/deepgaze_pytorch/metrics.py b/deepgaze_pytorch/metrics.py new file mode 100644 index 0000000000000000000000000000000000000000..04c945078b34b63ec0979eeb3b82c5568fc44eda --- /dev/null +++ b/deepgaze_pytorch/metrics.py @@ -0,0 +1,69 @@ +import numpy as np +from pysaliency.roc import general_roc +from pysaliency.numba_utils import auc_for_one_positive +import torch + + +def _general_auc(positives, negatives): + if len(positives) == 1: + return auc_for_one_positive(positives[0], negatives) + else: + return general_roc(positives, negatives)[0] + + +def log_likelihood(log_density, fixation_mask, weights=None): + #if weights is None: + # weights = torch.ones(log_density.shape[0]) + + weights = len(weights) * weights.view(-1, 1, 1) / weights.sum() + + if isinstance(fixation_mask, torch.sparse.IntTensor): + dense_mask = fixation_mask.to_dense() + else: + dense_mask = fixation_mask + fixation_count = dense_mask.sum(dim=(-1, -2), keepdim=True) + ll = torch.mean( + weights * torch.sum(log_density * dense_mask, dim=(-1, -2), keepdim=True) / fixation_count + ) + return (ll + np.log(log_density.shape[-1] * log_density.shape[-2])) / np.log(2) + + +def nss(log_density, fixation_mask, weights=None): + weights = len(weights) * weights.view(-1, 1, 1) / weights.sum() + if isinstance(fixation_mask, torch.sparse.IntTensor): + dense_mask = fixation_mask.to_dense() + else: + dense_mask = fixation_mask + + fixation_count = dense_mask.sum(dim=(-1, -2), keepdim=True) + + density = torch.exp(log_density) + mean, std = torch.std_mean(density, dim=(-1, -2), keepdim=True) + saliency_map = (density - mean) / std + + nss = torch.mean( + weights * 
torch.sum(saliency_map * dense_mask, dim=(-1, -2), keepdim=True) / fixation_count + ) + return nss + + +def auc(log_density, fixation_mask, weights=None): + weights = len(weights) * weights / weights.sum() + + # TODO: This doesn't account for multiple fixations in the same location! + def image_auc(log_density, fixation_mask): + if isinstance(fixation_mask, torch.sparse.IntTensor): + dense_mask = fixation_mask.to_dense() + else: + dense_mask = fixation_mask + + positives = torch.masked_select(log_density, dense_mask.type(torch.bool)).detach().cpu().numpy().astype(np.float64) + negatives = log_density.flatten().detach().cpu().numpy().astype(np.float64) + + auc = _general_auc(positives, negatives) + + return torch.tensor(auc) + + return torch.mean(weights.cpu() * torch.tensor([ + image_auc(log_density[i], fixation_mask[i]) for i in range(log_density.shape[0]) + ])) diff --git a/deepgaze_pytorch/modules.py b/deepgaze_pytorch/modules.py new file mode 100644 index 0000000000000000000000000000000000000000..0f3c722b3cccadc68410ffa68e096cf39c39bd94 --- /dev/null +++ b/deepgaze_pytorch/modules.py @@ -0,0 +1,343 @@ +import functools +import math + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .layers import GaussianFilterNd + + +def encode_scanpath_features(x_hist, y_hist, size, device=None, include_x=True, include_y=True, include_duration=False): + assert include_x + assert include_y + assert not include_duration + + height = size[0] + width = size[1] + + xs = torch.arange(width, dtype=torch.float32).to(device) + ys = torch.arange(height, dtype=torch.float32).to(device) + YS, XS = torch.meshgrid(ys, xs, indexing='ij') + + XS = torch.repeat_interleave( + torch.repeat_interleave( + XS[np.newaxis, np.newaxis, :, :], + repeats=x_hist.shape[0], + dim=0, + ), + repeats=x_hist.shape[1], + dim=1, + ) + + YS = torch.repeat_interleave( + torch.repeat_interleave( + YS[np.newaxis, np.newaxis, :, :], + repeats=y_hist.shape[0], + dim=0, + ), + repeats=y_hist.shape[1], + dim=1, + ) + + XS -= x_hist.unsqueeze(2).unsqueeze(3) + YS -= y_hist.unsqueeze(2).unsqueeze(3) + + distances = torch.sqrt(XS**2 + YS**2) + + return torch.cat((XS, YS, distances), axis=1) + +class FeatureExtractor(torch.nn.Module): + def __init__(self, features, targets): + super().__init__() + self.features = features + self.targets = targets + #print("Targets are {}".format(targets)) + self.outputs = {} + + for target in targets: + layer = dict([*self.features.named_modules()])[target] + layer.register_forward_hook(self.save_outputs_hook(target)) + + def save_outputs_hook(self, layer_id: str): + def fn(_, __, output): + self.outputs[layer_id] = output.clone() + return fn + + def forward(self, x): + + self.outputs.clear() + self.features(x) + return [self.outputs[target] for target in self.targets] + + +def upscale(tensor, size): + tensor_size = torch.tensor(tensor.shape[2:]).type(torch.float32) + target_size = torch.tensor(size).type(torch.float32) + factors = torch.ceil(target_size / tensor_size) + factor = torch.max(factors).type(torch.int64).to(tensor.device) + assert factor >= 1 + + tensor = torch.repeat_interleave(tensor, factor, dim=2) + tensor = torch.repeat_interleave(tensor, factor, dim=3) + + tensor = tensor[:, :, :size[0], :size[1]] + + return tensor + + +class Finalizer(nn.Module): + """Transforms a readout into a gaze prediction + + A readout network returns a single, spatial map of probable gaze locations. 
+ This module bundles the common processing steps necessary to transform this into + the predicted gaze distribution: + + - resizing to the stimulus size + - smoothing of the prediction using a gaussian filter + - removing of channel and time dimension + - weighted addition of the center bias + - normalization + """ + + def __init__( + self, + sigma, + kernel_size=None, + learn_sigma=False, + center_bias_weight=1.0, + learn_center_bias_weight=True, + saliency_map_factor=4, + ): + """Creates a new finalizer + + Args: + size (tuple): target size for the predictions + sigma (float): standard deviation of the gaussian kernel used for smoothing + kernel_size (int, optional): size of the gaussian kernel + learn_sigma (bool, optional): If True, the standard deviation of the gaussian kernel will + be learned (default: False) + center_bias (string or tensor): the center bias + center_bias_weight (float, optional): initial weight of the center bias + learn_center_bias_weight (bool, optional): If True, the center bias weight will be + learned (default: True) + """ + super(Finalizer, self).__init__() + + self.saliency_map_factor = saliency_map_factor + + self.gauss = GaussianFilterNd([2, 3], sigma, truncate=3, trainable=learn_sigma) + self.center_bias_weight = nn.Parameter(torch.Tensor([center_bias_weight]), requires_grad=learn_center_bias_weight) + + def forward(self, readout, centerbias): + """Applies the finalization steps to the given readout""" + + downscaled_centerbias = F.interpolate( + centerbias.view(centerbias.shape[0], 1, centerbias.shape[1], centerbias.shape[2]), + scale_factor=1 / self.saliency_map_factor, + recompute_scale_factor=False, + )[:, 0, :, :] + + out = F.interpolate( + readout, + size=[downscaled_centerbias.shape[1], downscaled_centerbias.shape[2]] + ) + + # apply gaussian filter + out = self.gauss(out) + + # remove channel dimension + out = out[:, 0, :, :] + + # add to center bias + out = out + self.center_bias_weight * downscaled_centerbias + + out = F.interpolate(out[:, np.newaxis, :, :], size=[centerbias.shape[1], centerbias.shape[2]])[:, 0, :, :] + + # normalize + out = out - out.logsumexp(dim=(1, 2), keepdim=True) + + return out + + +class DeepGazeII(torch.nn.Module): + def __init__(self, features, readout_network, downsample=2, readout_factor=16, saliency_map_factor=2, initial_sigma=8.0): + super().__init__() + + self.readout_factor = readout_factor + self.saliency_map_factor = saliency_map_factor + + self.features = features + + for param in self.features.parameters(): + param.requires_grad = False + self.features.eval() + + self.readout_network = readout_network + self.finalizer = Finalizer( + sigma=initial_sigma, + learn_sigma=True, + saliency_map_factor=self.saliency_map_factor, + ) + self.downsample = downsample + + def forward(self, x, centerbias): + orig_shape = x.shape + x = F.interpolate( + x, + scale_factor=1 / self.downsample, + recompute_scale_factor=False, + ) + x = self.features(x) + + readout_shape = [math.ceil(orig_shape[2] / self.downsample / self.readout_factor), math.ceil(orig_shape[3] / self.downsample / self.readout_factor)] + x = [F.interpolate(item, readout_shape) for item in x] + + x = torch.cat(x, dim=1) + x = self.readout_network(x) + x = self.finalizer(x, centerbias) + + return x + + def train(self, mode=True): + self.features.eval() + self.readout_network.train(mode=mode) + self.finalizer.train(mode=mode) + + +class DeepGazeIII(torch.nn.Module): + def __init__(self, features, saliency_network, scanpath_network, fixation_selection_network, 
downsample=2, readout_factor=2, saliency_map_factor=2, included_fixations=-2, initial_sigma=8.0): + super().__init__() + + self.downsample = downsample + self.readout_factor = readout_factor + self.saliency_map_factor = saliency_map_factor + self.included_fixations = included_fixations + + self.features = features + + for param in self.features.parameters(): + param.requires_grad = False + self.features.eval() + + self.saliency_network = saliency_network + self.scanpath_network = scanpath_network + self.fixation_selection_network = fixation_selection_network + + self.finalizer = Finalizer( + sigma=initial_sigma, + learn_sigma=True, + saliency_map_factor=self.saliency_map_factor, + ) + + def forward(self, x, centerbias, x_hist=None, y_hist=None, durations=None): + orig_shape = x.shape + x = F.interpolate(x, scale_factor=1 / self.downsample) + x = self.features(x) + + readout_shape = [math.ceil(orig_shape[2] / self.downsample / self.readout_factor), math.ceil(orig_shape[3] / self.downsample / self.readout_factor)] + x = [F.interpolate(item, readout_shape) for item in x] + + x = torch.cat(x, dim=1) + x = self.saliency_network(x) + + if self.scanpath_network is not None: + scanpath_features = encode_scanpath_features(x_hist, y_hist, size=(orig_shape[2], orig_shape[3]), device=x.device) + #scanpath_features = F.interpolate(scanpath_features, scale_factor=1 / self.downsample / self.readout_factor) + scanpath_features = F.interpolate(scanpath_features, readout_shape) + y = self.scanpath_network(scanpath_features) + else: + y = None + + x = self.fixation_selection_network((x, y)) + + x = self.finalizer(x, centerbias) + + return x + + def train(self, mode=True): + self.features.eval() + self.saliency_network.train(mode=mode) + if self.scanpath_network is not None: + self.scanpath_network.train(mode=mode) + self.fixation_selection_network.train(mode=mode) + self.finalizer.train(mode=mode) + + +class DeepGazeIIIMixture(torch.nn.Module): + def __init__(self, features, saliency_networks, scanpath_networks, fixation_selection_networks, finalizers, downsample=2, readout_factor=2, saliency_map_factor=2, included_fixations=-2, initial_sigma=8.0): + super().__init__() + + self.downsample = downsample + self.readout_factor = readout_factor + self.saliency_map_factor = saliency_map_factor + self.included_fixations = included_fixations + + self.features = features + + for param in self.features.parameters(): + param.requires_grad = False + self.features.eval() + + self.saliency_networks = torch.nn.ModuleList(saliency_networks) + self.scanpath_networks = torch.nn.ModuleList(scanpath_networks) + self.fixation_selection_networks = torch.nn.ModuleList(fixation_selection_networks) + self.finalizers = torch.nn.ModuleList(finalizers) + + def forward(self, x, centerbias, x_hist=None, y_hist=None, durations=None): + orig_shape = x.shape + x = F.interpolate( + x, + scale_factor=1 / self.downsample, + recompute_scale_factor=False, + ) + x = self.features(x) + + readout_shape = [math.ceil(orig_shape[2] / self.downsample / self.readout_factor), math.ceil(orig_shape[3] / self.downsample / self.readout_factor)] + x = [F.interpolate(item, readout_shape) for item in x] + + x = torch.cat(x, dim=1) + + predictions = [] + + readout_input = x + + for saliency_network, scanpath_network, fixation_selection_network, finalizer in zip( + self.saliency_networks, self.scanpath_networks, self.fixation_selection_networks, self.finalizers + ): + + x = saliency_network(readout_input) + + if scanpath_network is not None: + scanpath_features = 
encode_scanpath_features(x_hist, y_hist, size=(orig_shape[2], orig_shape[3]), device=x.device) + scanpath_features = F.interpolate(scanpath_features, readout_shape) + y = scanpath_network(scanpath_features) + else: + y = None + + x = fixation_selection_network((x, y)) + + x = finalizer(x, centerbias) + + predictions.append(x[:, np.newaxis, :, :]) + + predictions = torch.cat(predictions, dim=1) - np.log(len(self.saliency_networks)) + + prediction = predictions.logsumexp(dim=(1), keepdim=True) + + return prediction + + +class MixtureModel(torch.nn.Module): + def __init__(self, models): + super().__init__() + self.models = torch.nn.ModuleList(models) + + def forward(self, *args, **kwargs): + predictions = [model.forward(*args, **kwargs) for model in self.models] + predictions = torch.cat(predictions, dim=1) + predictions -= np.log(len(self.models)) + prediction = predictions.logsumexp(dim=(1), keepdim=True) + + return prediction diff --git a/deepgaze_pytorch/training.py b/deepgaze_pytorch/training.py new file mode 100644 index 0000000000000000000000000000000000000000..6443a5488ddda7fa7cb460ff918ba290130ad2b0 --- /dev/null +++ b/deepgaze_pytorch/training.py @@ -0,0 +1,309 @@ +# flake8: noqa E501 +# pylint: disable=not-callable +# E501: line too long + +from collections import defaultdict +from datetime import datetime +import glob +import os +import tempfile + +from boltons.cacheutils import cached, LRU +from boltons.fileutils import atomic_save, mkdir_p +from boltons.iterutils import windowed +from IPython import get_ipython +from IPython.display import display +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd +import pysaliency +from pysaliency.filter_datasets import iterate_crossvalidation +from pysaliency.plotting import visualize_distribution +import torch +from torch.utils.tensorboard import SummaryWriter +from tqdm import tqdm +import yaml + +from .data import ImageDataset, FixationDataset, ImageDatasetSampler, FixationMaskTransform +#from .loading import import_class, build_model, DeepGazeCheckpointModel, SharedPyTorchModel, _get_from_config +from .metrics import log_likelihood, nss, auc +from .modules import DeepGazeII + + + +baseline_performance = cached(LRU(max_size=3))(lambda model, *args, **kwargs: model.information_gain(*args, **kwargs)) + + +def eval_epoch(model, dataset, baseline_information_gain, device, metrics=None): + model.eval() + + if metrics is None: + metrics = ['LL', 'IG', 'NSS', 'AUC'] + + metric_scores = {} + metric_functions = { + 'LL': log_likelihood, + 'NSS': nss, + 'AUC': auc, + } + batch_weights = [] + + with torch.no_grad(): + pbar = tqdm(dataset) + for batch in pbar: + image = batch.pop('image').to(device) + centerbias = batch.pop('centerbias').to(device) + fixation_mask = batch.pop('fixation_mask').to(device) + x_hist = batch.pop('x_hist', torch.tensor([])).to(device) + y_hist = batch.pop('y_hist', torch.tensor([])).to(device) + weights = batch.pop('weight').to(device) + durations = batch.pop('durations', torch.tensor([])).to(device) + + kwargs = {} + for key, value in dict(batch).items(): + kwargs[key] = value.to(device) + + if isinstance(model, DeepGazeII): + log_density = model(image, centerbias, **kwargs) + else: + log_density = model(image, centerbias, x_hist=x_hist, y_hist=y_hist, durations=durations, **kwargs) + + for metric_name, metric_fn in metric_functions.items(): + if metric_name not in metrics: + continue + metric_scores.setdefault(metric_name, []).append(metric_fn(log_density, fixation_mask, 
weights=weights).detach().cpu().numpy()) + batch_weights.append(weights.detach().cpu().numpy().sum()) + + for display_metric in ['LL', 'NSS', 'AUC']: + if display_metric in metrics: + pbar.set_description('{} {:.05f}'.format(display_metric, np.average(metric_scores[display_metric], weights=batch_weights))) + break + + data = {metric_name: np.average(scores, weights=batch_weights) for metric_name, scores in metric_scores.items()} + if 'IG' in metrics: + data['IG'] = data['LL'] - baseline_information_gain + + return data + +def train_epoch(model, dataset, optimizer, device): + model.train() + losses = [] + batch_weights = [] + + pbar = tqdm(dataset) + for batch in pbar: + optimizer.zero_grad() + + image = batch.pop('image').to(device) + centerbias = batch.pop('centerbias').to(device) + fixation_mask = batch.pop('fixation_mask').to(device) + x_hist = batch.pop('x_hist', torch.tensor([])).to(device) + y_hist = batch.pop('y_hist', torch.tensor([])).to(device) + weights = batch.pop('weight').to(device) + durations = batch.pop('durations', torch.tensor([])).to(device) + + kwargs = {} + for key, value in dict(batch).items(): + kwargs[key] = value.to(device) + + if isinstance(model, DeepGazeII): + log_density = model(image, centerbias, **kwargs) + else: + log_density = model(image, centerbias, x_hist=x_hist, y_hist=y_hist, durations=durations, **kwargs) + + loss = -log_likelihood(log_density, fixation_mask, weights=weights) + losses.append(loss.detach().cpu().numpy()) + + batch_weights.append(weights.detach().cpu().numpy().sum()) + + pbar.set_description('{:.05f}'.format(np.average(losses, weights=batch_weights))) + + loss.backward() + + optimizer.step() + + return np.average(losses, weights=batch_weights) + + +def restore_from_checkpoint(model, optimizer, scheduler, path): + print("Restoring from", path) + data = torch.load(path) + if 'optimizer' in data: + # checkpoint contains training progress + model.load_state_dict(data['model']) + optimizer.load_state_dict(data['optimizer']) + scheduler.load_state_dict(data['scheduler']) + torch.set_rng_state(data['rng_state']) + return data['step'], data['loss'] + else: + # checkpoint contains just a model + missing_keys, unexpected_keys = model.load_state_dict(data, strict=False) + if missing_keys: + print("WARNING! missing keys", missing_keys) + if unexpected_keys: + print("WARNING! 
Unexpected keys", unexpected_keys) + + +def save_training_state(model, optimizer, scheduler, step, loss, path): + data = { + 'model': model.state_dict(), + 'optimizer': optimizer.state_dict(), + 'scheduler': scheduler.state_dict(), + 'rng_state': torch.get_rng_state(), + 'step': step, + 'loss': loss, + } + + with atomic_save(path, text_mode=False, overwrite_part=True) as f: + torch.save(data, f) + + + + +def _train(this_directory, + model, + train_loader, train_baseline_log_likelihood, + val_loader, val_baseline_log_likelihood, + optimizer, lr_scheduler, + #optimizer_config, lr_scheduler_config, + minimum_learning_rate, + #initial_learning_rate, learning_rate_scheduler, learning_rate_decay, learning_rate_decay_epochs, learning_rate_backlook, learning_rate_reset_strategy, minimum_learning_rate, + validation_metric='IG', + validation_metrics=['IG', 'LL', 'AUC', 'NSS'], + validation_epochs=1, + startwith=None, + device=None): + mkdir_p(this_directory) + + if os.path.isfile(os.path.join(this_directory, 'final.pth')): + print("Training Already finished") + return + + if device is None: + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + + print("Using device", device) + + model.to(device) + + val_metrics = defaultdict(lambda: []) + + if startwith is not None: + restore_from_checkpoint(model, optimizer, lr_scheduler, startwith) + + writer = SummaryWriter(os.path.join(this_directory, 'log'), flush_secs=30) + + columns = ['epoch', 'timestamp', 'learning_rate', 'loss'] + print("validation metrics", validation_metrics) + for metric in validation_metrics: + columns.append(f'validation_{metric}') + + progress = pd.DataFrame(columns=columns) + + step = 0 + last_loss = np.nan + + def save_step(): + + save_training_state( + model, optimizer, lr_scheduler, step, last_loss, + '{}/step-{:04d}.pth'.format(this_directory, step), + ) + + #f = visualize(model, vis_data_loader) + #display_if_in_IPython(f) + + #writer.add_figure('prediction', f, step) + writer.add_scalar('training/loss', last_loss, step) + writer.add_scalar('training/learning_rate', optimizer.state_dict()['param_groups'][0]['lr'], step) + writer.add_scalar('parameters/sigma', model.finalizer.gauss.sigma.detach().cpu().numpy(), step) + writer.add_scalar('parameters/center_bias_weight', model.finalizer.center_bias_weight.detach().cpu().numpy()[0], step) + + if step % validation_epochs == 0: + _val_metrics = eval_epoch(model, val_loader, val_baseline_log_likelihood, device, metrics=validation_metrics) + else: + print("Skipping validation") + _val_metrics = {} + + for key, value in _val_metrics.items(): + val_metrics[key].append(value) + + for key, value in _val_metrics.items(): + writer.add_scalar(f'validation/{key}', value, step) + + new_row = { + 'epoch': step, + 'timestamp': datetime.utcnow(), + 'learning_rate': optimizer.state_dict()['param_groups'][0]['lr'], + 'loss': last_loss, + #'validation_ig': val_igs[-1] + } + for key, value in _val_metrics.items(): + new_row['validation_{}'.format(key)] = value + + progress.loc[step] = new_row + + print(progress.tail(n=2)) + print(progress[['validation_{}'.format(key) for key in val_metrics]].idxmax(axis=0)) + + with atomic_save('{}/log.csv'.format(this_directory), text_mode=True, overwrite_part=True) as f: + progress.to_csv(f) + + for old_step in range(1, step): + # only check if we are computing validation metrics... 
+ if validation_metric in val_metrics and val_metrics[validation_metric] and old_step == np.argmax(val_metrics[validation_metric]): + continue + for filename in glob.glob('{}/step-{:04d}.pth'.format(this_directory, old_step)): + print("removing", filename) + os.remove(filename) + + old_checkpoints = sorted(glob.glob(os.path.join(this_directory, 'step-*.pth'))) + if old_checkpoints: + last_checkpoint = old_checkpoints[-1] + print("Found old checkpoint", last_checkpoint) + step, last_loss = restore_from_checkpoint(model, optimizer, lr_scheduler, last_checkpoint) + print("Setting step to", step) + + if step == 0: + print("Beginning training") + save_step() + + else: + print("Continuing from step", step) + progress = pd.read_csv(os.path.join(this_directory, 'log.csv'), index_col=0) + val_metrics = {} + for column_name in progress.columns: + if column_name.startswith('validation_'): + val_metrics[column_name.split('validation_', 1)[1]] = list(progress[column_name]) + + if step not in progress.epoch.values: + print("Epoch not yet evaluated, evaluating...") + save_step() + + # We have to make one scheduler step here, since we make the + # scheduler step _after_ saving the checkpoint + lr_scheduler.step() + + print(progress) + + while optimizer.state_dict()['param_groups'][0]['lr'] >= minimum_learning_rate: + step += 1 + last_loss = train_epoch(model, train_loader, optimizer, device) + save_step() + lr_scheduler.step() + + + + #if learning_rate_reset_strategy == 'validation': + # best_step = np.argmax(val_metrics[validation_metric]) + # print("Best previous validation in step {}, saving as final result".format(best_step)) + # restore_from_checkpoint(model, optimizer, scheduler, os.path.join(this_directory, 'step-{:04d}.pth'.format(best_step))) + #else: + # print("Not resetting to best validation epoch") + + torch.save(model.state_dict(), '{}/final.pth'.format(this_directory)) + + for filename in glob.glob(os.path.join(this_directory, 'step-*')): + print("removing", filename) + os.remove(filename) \ No newline at end of file diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..662799d538aaed2dff1bf12e485e9252f5cbb4fd --- /dev/null +++ b/requirements.txt @@ -0,0 +1,5 @@ +setuptools<70.0.0 +awscli==1.29.54 +gradio==4.19.2 +torch==2.3.1 +opencv-python==4.9.0 \ No newline at end of file
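For reference, a minimal usage sketch of the pieces introduced in this diff (not part of the patch itself). It assumes the package is importable as deepgaze_pytorch, that the pretrained VGG-19 weights can be downloaded, and that the small readout network below is purely illustrative rather than the readout architecture used in the actual DeepGaze models.

import torch
import torch.nn as nn

from deepgaze_pytorch.features.vgg import VGG19NamedFeatures
from deepgaze_pytorch.modules import DeepGazeII, FeatureExtractor

# hook three late VGG-19 feature maps; FeatureExtractor returns them as a list of tensors
features = FeatureExtractor(VGG19NamedFeatures(), ['relu5_1', 'relu5_2', 'relu5_4'])

# toy readout network (illustrative only): 3 x 512 concatenated channels -> 1 saliency channel
readout = nn.Sequential(
    nn.Conv2d(3 * 512, 16, kernel_size=1),
    nn.Softplus(),
    nn.Conv2d(16, 1, kernel_size=1),
)

model = DeepGazeII(features=features, readout_network=readout)
model.eval()

image = torch.rand(1, 3, 384, 512) * 255   # inputs are expected in the [0, 255] range
centerbias = torch.zeros(1, 384, 512)      # uniform (flat) log center bias

with torch.no_grad():
    log_density = model(image, centerbias)  # shape (1, 384, 512)

# the Finalizer normalizes the output, so each image yields a log probability map
print(float(log_density.exp().sum()))       # ~1.0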