Spaces:
Running
Running
import functools | |
import math | |
import numpy as np | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from .layers import GaussianFilterNd | |
def encode_scanpath_features(x_hist, y_hist, size, device=None, include_x=True, include_y=True, include_duration=False):
    """Encode previous fixations as dense per-pixel displacement maps.

    For every previous fixation ``(x_hist[b, f], y_hist[b, f])`` three maps over an
    image of ``size = (height, width)`` are produced: the horizontal offset
    ``column - x``, the vertical offset ``row - y`` and the Euclidean distance
    between each pixel and the fixation.

    Args:
        x_hist: tensor of shape (B, F) holding the x coordinates of the fixations.
        y_hist: tensor of shape (B, F) holding the y coordinates of the fixations.
        size: (height, width) of the maps to produce.
        device: device on which to build the maps; defaults to ``x_hist.device``.
        include_x, include_y, include_duration: only the combination
            ``True, True, False`` is supported.

    Returns:
        float32 tensor of shape (B, 3 * F, height, width). Channels are ordered
        [F x-offsets, F y-offsets, F distances].

    Raises:
        NotImplementedError: if an unsupported feature combination is requested.
    """
    if not (include_x and include_y) or include_duration:
        raise NotImplementedError(
            "only include_x=True, include_y=True, include_duration=False is supported"
        )
    if device is None:
        # Build the grid where the fixation history lives instead of always on CPU.
        device = x_hist.device

    height = size[0]
    width = size[1]
    cols = torch.arange(width, dtype=torch.float32, device=device)
    rows = torch.arange(height, dtype=torch.float32, device=device)
    fix_x = x_hist.to(device=device, dtype=torch.float32)[:, :, None, None]
    fix_y = y_hist.to(device=device, dtype=torch.float32)[:, :, None, None]

    # Broadcast instead of materialising B*F copies of the coordinate grid:
    # (1, 1, 1, W) - (B, F, 1, 1) -> (B, F, 1, W) and (1, 1, H, 1) - (B, F, 1, 1) -> (B, F, H, 1).
    dx = cols.view(1, 1, 1, width) - fix_x
    dy = rows.view(1, 1, height, 1) - fix_y
    distances = torch.sqrt(dx ** 2 + dy ** 2)  # broadcasts to (B, F, H, W)

    full_shape = distances.shape
    return torch.cat((dx.expand(full_shape), dy.expand(full_shape), distances), dim=1)
class FeatureExtractor(torch.nn.Module):
    """Run a backbone and return the outputs of selected named submodules.

    A forward hook is registered on every module named in ``targets``. Calling the
    extractor runs ``features`` once and returns clones of the captured outputs,
    in the order of ``targets``. The backbone's own return value is discarded.
    """

    def __init__(self, features, targets):
        """
        Args:
            features: the backbone module to run.
            targets: names (as reported by ``features.named_modules()``) of the
                submodules whose outputs should be captured.

        Raises:
            KeyError: if any target name is not a submodule of ``features``.
                No hooks are registered in that case.
        """
        super().__init__()
        self.features = features
        self.targets = targets
        self.outputs = {}

        # Build the name -> module lookup once instead of once per target.
        modules = dict(self.features.named_modules())
        missing = [target for target in targets if target not in modules]
        if missing:
            # Validate everything before registering any hook, so a bad name
            # does not leave stray hooks attached to the backbone.
            raise KeyError(f"not submodules of the feature network: {missing}")

        # Keep the handles so the hooks can be removed later if needed.
        self._hook_handles = [
            modules[target].register_forward_hook(self.save_outputs_hook(target))
            for target in targets
        ]

    def save_outputs_hook(self, layer_id: str):
        """Return a forward hook that stores a clone of the module output under ``layer_id``."""
        def fn(_, __, output):
            self.outputs[layer_id] = output.clone()
        return fn

    def forward(self, x):
        """Run ``features`` on ``x`` and return the captured outputs in ``targets`` order."""
        # Drop outputs from the previous call so stale tensors are never returned.
        self.outputs.clear()
        self.features(x)
        return [self.outputs[target] for target in self.targets]
def upscale(tensor, size):
    """Upscale a 4-D tensor by pixel repetition, then crop it to ``size``.

    One integer factor ``k = max_i ceil(size[i] / tensor.shape[2 + i])`` is applied
    to both spatial axes (dims 2 and 3). The result is then cropped to
    ``size[0] x size[1]``.

    Args:
        tensor: 4-D tensor whose dims 2 and 3 are repeated.
        size: target (height, width).

    Returns:
        tensor of shape (tensor.shape[0], tensor.shape[1], size[0], size[1]).

    Raises:
        ValueError: if a spatial dimension of ``tensor`` is empty, or the target
            size leads to a repetition factor below 1.
    """
    source_size = [int(s) for s in tensor.shape[2:]]
    target_size = [int(s) for s in size]
    if any(s <= 0 for s in source_size):
        raise ValueError(f"cannot upscale a tensor with empty spatial dimensions {tuple(source_size)}")

    # Exact integer ceil division in plain Python. This avoids building helper
    # tensors and moving a scalar to the tensor's device (a host/device sync on GPU).
    factor = max(-(-t // s) for t, s in zip(target_size, source_size))
    if factor < 1:
        raise ValueError(f"invalid target size {tuple(target_size)}")

    tensor = torch.repeat_interleave(tensor, factor, dim=2)
    tensor = torch.repeat_interleave(tensor, factor, dim=3)
    return tensor[:, :, :target_size[0], :target_size[1]]
class Finalizer(nn.Module):
    """Transforms a readout into a gaze prediction.

    A readout network returns a single, spatial map of probable gaze locations.
    This module bundles the common processing steps needed to turn it into the
    predicted gaze distribution:

    - resizing to the downscaled center-bias size
    - smoothing of the prediction with a gaussian filter
    - removing the channel dimension
    - weighted addition of the center bias
    - resizing to the full center-bias (stimulus) size
    - normalization to a log-density (logsumexp over the image equals 0)
    """

    def __init__(
        self,
        sigma,
        kernel_size=None,
        learn_sigma=False,
        center_bias_weight=1.0,
        learn_center_bias_weight=True,
        saliency_map_factor=4,
    ):
        """Creates a new finalizer.

        Args:
            sigma (float): standard deviation of the gaussian kernel used for smoothing.
            kernel_size (int, optional): accepted for backward compatibility but
                currently unused; the gaussian filter is built with ``truncate=3``
                instead.
            learn_sigma (bool, optional): if True, the standard deviation of the
                gaussian kernel is learned (default: False).
            center_bias_weight (float, optional): initial weight of the center bias.
            learn_center_bias_weight (bool, optional): if True, the center bias
                weight is learned (default: True).
            saliency_map_factor (int, optional): factor by which the center bias is
                downscaled before the readout is combined with it (default: 4).
        """
        super(Finalizer, self).__init__()

        self.saliency_map_factor = saliency_map_factor

        self.gauss = GaussianFilterNd([2, 3], sigma, truncate=3, trainable=learn_sigma)
        # torch.tensor instead of the legacy torch.Tensor constructor; the value
        # still gets torch's default floating-point dtype.
        self.center_bias_weight = nn.Parameter(
            torch.tensor([float(center_bias_weight)]),
            requires_grad=learn_center_bias_weight,
        )

    def forward(self, readout, centerbias):
        """Applies the finalization steps to the given readout.

        Args:
            readout: 4-D readout map; only channel 0 is kept after smoothing.
            centerbias: center-bias map of shape (B, H, W).

        Returns:
            log-density of shape (B, H, W), normalized so that logsumexp over the
            two spatial axes is 0.
        """
        # unsqueeze (rather than view) also works for non-contiguous center biases.
        downscaled_centerbias = F.interpolate(
            centerbias.unsqueeze(1),
            scale_factor=1 / self.saliency_map_factor,
            recompute_scale_factor=False,
        )[:, 0, :, :]

        out = F.interpolate(
            readout,
            size=[downscaled_centerbias.shape[1], downscaled_centerbias.shape[2]],
        )

        # apply gaussian filter
        out = self.gauss(out)

        # remove channel dimension
        out = out[:, 0, :, :]

        # add to center bias
        out = out + self.center_bias_weight * downscaled_centerbias

        # bring the prediction back to the full center-bias resolution
        out = F.interpolate(out.unsqueeze(1), size=[centerbias.shape[1], centerbias.shape[2]])[:, 0, :, :]

        # normalize
        return out - out.logsumexp(dim=(1, 2), keepdim=True)
class DeepGazeII(torch.nn.Module):
    """DeepGaze II style model: frozen features -> readout network -> Finalizer.

    Args:
        features: frozen feature network, called on the downsampled image. It is
            expected to return a list of 4-D feature maps (e.g. a FeatureExtractor).
        readout_network: module mapping the concatenated, resized feature maps to
            a readout map.
        downsample: factor by which the input image is downscaled before feature
            extraction.
        readout_factor: additional downscaling applied to the feature maps before
            the readout.
        saliency_map_factor: forwarded to the Finalizer (center-bias downscaling).
        initial_sigma: initial standard deviation of the Finalizer's learnable
            gaussian blur.
    """

    def __init__(self, features, readout_network, downsample=2, readout_factor=16, saliency_map_factor=2, initial_sigma=8.0):
        super().__init__()

        self.readout_factor = readout_factor
        self.saliency_map_factor = saliency_map_factor

        self.features = features

        # The backbone is frozen: no gradients, always in eval mode.
        for param in self.features.parameters():
            param.requires_grad = False
        self.features.eval()

        self.readout_network = readout_network
        self.finalizer = Finalizer(
            sigma=initial_sigma,
            learn_sigma=True,
            saliency_map_factor=self.saliency_map_factor,
        )
        self.downsample = downsample

    def forward(self, x, centerbias):
        """Predict a log-density for image batch ``x`` given ``centerbias``.

        Args:
            x: image batch, 4-D with spatial dims 2 and 3.
            centerbias: center-bias map of shape (B, H, W), passed to the Finalizer.

        Returns:
            the Finalizer's output (a log-density of shape (B, H, W)).
        """
        orig_shape = x.shape
        x = F.interpolate(
            x,
            scale_factor=1 / self.downsample,
            recompute_scale_factor=False,
        )
        x = self.features(x)

        # All feature maps are resized to a common grid before concatenation.
        readout_shape = [
            math.ceil(orig_shape[2] / self.downsample / self.readout_factor),
            math.ceil(orig_shape[3] / self.downsample / self.readout_factor),
        ]
        x = torch.cat([F.interpolate(item, readout_shape) for item in x], dim=1)

        x = self.readout_network(x)
        return self.finalizer(x, centerbias)

    def train(self, mode=True):
        """Set train/eval mode on the trainable parts while keeping ``features`` in eval mode.

        Like :meth:`torch.nn.Module.train`, this updates ``self.training`` and
        returns ``self``, so ``model = model.eval()`` and chained calls keep working.
        """
        super().train(mode)
        # Re-freeze the backbone after super() switched every child to ``mode``.
        self.features.eval()
        return self
class DeepGazeIII(torch.nn.Module):
    """DeepGaze III style model combining image saliency with the scanpath history.

    Pipeline: frozen features -> saliency network; optional scanpath network on the
    encoded fixation history; fixation selection network on the pair; Finalizer.

    Args:
        features: frozen feature network, expected to return a list of 4-D maps.
        saliency_network: module applied to the concatenated feature maps.
        scanpath_network: module applied to the encoded fixation history, or None
            to ignore the history.
        fixation_selection_network: module called with the tuple
            ``(saliency_output, scanpath_output_or_None)``.
        downsample: factor by which the input image is downscaled first.
        readout_factor: additional downscaling applied to the feature maps.
        saliency_map_factor: forwarded to the Finalizer (center-bias downscaling).
        included_fixations: stored on the model; not used in ``forward``.
        initial_sigma: initial standard deviation of the Finalizer's gaussian blur.
    """

    def __init__(self, features, saliency_network, scanpath_network, fixation_selection_network, downsample=2, readout_factor=2, saliency_map_factor=2, included_fixations=-2, initial_sigma=8.0):
        super().__init__()

        self.downsample = downsample
        self.readout_factor = readout_factor
        self.saliency_map_factor = saliency_map_factor
        self.included_fixations = included_fixations

        self.features = features

        # The backbone is frozen: no gradients, always in eval mode.
        for param in self.features.parameters():
            param.requires_grad = False
        self.features.eval()

        self.saliency_network = saliency_network
        self.scanpath_network = scanpath_network
        self.fixation_selection_network = fixation_selection_network

        self.finalizer = Finalizer(
            sigma=initial_sigma,
            learn_sigma=True,
            saliency_map_factor=self.saliency_map_factor,
        )

    def forward(self, x, centerbias, x_hist=None, y_hist=None, durations=None):
        """Predict a log-density for image batch ``x``.

        Args:
            x: image batch, 4-D with spatial dims 2 and 3.
            centerbias: center-bias map of shape (B, H, W).
            x_hist, y_hist: fixation coordinates of shape (B, F). Both are required
                when a scanpath network is configured.
            durations: currently unused.

        Returns:
            the Finalizer's output (a log-density of shape (B, H, W)).

        Raises:
            ValueError: if a scanpath network is configured but ``x_hist`` or
                ``y_hist`` is missing.
        """
        orig_shape = x.shape
        # recompute_scale_factor=False keeps this consistent with DeepGazeII.
        x = F.interpolate(x, scale_factor=1 / self.downsample, recompute_scale_factor=False)
        x = self.features(x)

        readout_shape = [
            math.ceil(orig_shape[2] / self.downsample / self.readout_factor),
            math.ceil(orig_shape[3] / self.downsample / self.readout_factor),
        ]
        x = torch.cat([F.interpolate(item, readout_shape) for item in x], dim=1)

        x = self.saliency_network(x)

        if self.scanpath_network is not None:
            if x_hist is None or y_hist is None:
                raise ValueError("x_hist and y_hist are required when a scanpath network is configured")
            scanpath_features = encode_scanpath_features(
                x_hist, y_hist, size=(orig_shape[2], orig_shape[3]), device=x.device
            )
            scanpath_features = F.interpolate(scanpath_features, readout_shape)
            y = self.scanpath_network(scanpath_features)
        else:
            y = None

        x = self.fixation_selection_network((x, y))
        return self.finalizer(x, centerbias)

    def train(self, mode=True):
        """Set train/eval mode on the trainable parts while keeping ``features`` in eval mode.

        Like :meth:`torch.nn.Module.train`, this updates ``self.training`` and
        returns ``self``, so ``model = model.eval()`` and chained calls keep working.
        A ``None`` scanpath network is not a child module and is skipped.
        """
        super().train(mode)
        # Re-freeze the backbone after super() switched every child to ``mode``.
        self.features.eval()
        return self
class DeepGazeIIIMixture(torch.nn.Module):
    """Uniform mixture of several DeepGaze III heads sharing one frozen backbone.

    Component ``i`` is made of ``saliency_networks[i]``, ``scanpath_networks[i]``
    (may be None), ``fixation_selection_networks[i]`` and ``finalizers[i]``. The
    component log-densities are averaged in probability space.

    Args:
        features: frozen feature network, expected to return a list of 4-D maps.
        saliency_networks, scanpath_networks, fixation_selection_networks,
            finalizers: per-component modules; all four must have the same length.
        downsample: factor by which the input image is downscaled first.
        readout_factor: additional downscaling applied to the feature maps.
        saliency_map_factor: stored on the model; the finalizers are passed in
            already built.
        included_fixations: stored on the model; not used in ``forward``.
        initial_sigma: unused here, since the finalizers are passed in.

    Raises:
        ValueError: if the per-component lists have different lengths. ``zip``
            would otherwise silently drop components while the normalisation still
            counts them.
    """

    def __init__(self, features, saliency_networks, scanpath_networks, fixation_selection_networks, finalizers, downsample=2, readout_factor=2, saliency_map_factor=2, included_fixations=-2, initial_sigma=8.0):
        super().__init__()

        saliency_networks = list(saliency_networks)
        scanpath_networks = list(scanpath_networks)
        fixation_selection_networks = list(fixation_selection_networks)
        finalizers = list(finalizers)
        lengths = (
            len(saliency_networks),
            len(scanpath_networks),
            len(fixation_selection_networks),
            len(finalizers),
        )
        if len(set(lengths)) != 1:
            raise ValueError(
                "all component lists must have the same length, got "
                f"saliency={lengths[0]}, scanpath={lengths[1]}, "
                f"fixation_selection={lengths[2]}, finalizers={lengths[3]}"
            )

        self.downsample = downsample
        self.readout_factor = readout_factor
        self.saliency_map_factor = saliency_map_factor
        self.included_fixations = included_fixations

        self.features = features

        # The backbone is frozen: no gradients, always in eval mode.
        for param in self.features.parameters():
            param.requires_grad = False
        self.features.eval()

        self.saliency_networks = torch.nn.ModuleList(saliency_networks)
        self.scanpath_networks = torch.nn.ModuleList(scanpath_networks)
        self.fixation_selection_networks = torch.nn.ModuleList(fixation_selection_networks)
        self.finalizers = torch.nn.ModuleList(finalizers)

    def forward(self, x, centerbias, x_hist=None, y_hist=None, durations=None):
        """Predict the mixture log-density for image batch ``x``.

        Args:
            x: image batch, 4-D with spatial dims 2 and 3.
            centerbias: center-bias map passed to every finalizer.
            x_hist, y_hist: fixation coordinates of shape (B, F). Both are required
                when any component has a scanpath network.
            durations: currently unused.

        Returns:
            log-density of shape (B, 1, H, W): the logsumexp over components of the
            component predictions minus ``log(n_components)``.

        Raises:
            ValueError: if a scanpath network is present but ``x_hist`` or
                ``y_hist`` is missing.
        """
        orig_shape = x.shape
        x = F.interpolate(
            x,
            scale_factor=1 / self.downsample,
            recompute_scale_factor=False,
        )
        x = self.features(x)

        readout_shape = [
            math.ceil(orig_shape[2] / self.downsample / self.readout_factor),
            math.ceil(orig_shape[3] / self.downsample / self.readout_factor),
        ]
        readout_input = torch.cat([F.interpolate(item, readout_shape) for item in x], dim=1)

        # Encoded lazily and only once: the encoding does not depend on the component.
        scanpath_features = None

        predictions = []
        for saliency_network, scanpath_network, fixation_selection_network, finalizer in zip(
            self.saliency_networks, self.scanpath_networks, self.fixation_selection_networks, self.finalizers
        ):
            saliency = saliency_network(readout_input)

            if scanpath_network is not None:
                if scanpath_features is None:
                    if x_hist is None or y_hist is None:
                        raise ValueError("x_hist and y_hist are required when a scanpath network is configured")
                    scanpath_features = F.interpolate(
                        encode_scanpath_features(
                            x_hist, y_hist, size=(orig_shape[2], orig_shape[3]), device=readout_input.device
                        ),
                        readout_shape,
                    )
                y = scanpath_network(scanpath_features)
            else:
                y = None

            prediction = fixation_selection_network((saliency, y))
            prediction = finalizer(prediction, centerbias)
            predictions.append(prediction.unsqueeze(1))

        # Uniform mixture: log(mean_i p_i) = logsumexp_i(log p_i) - log(n).
        predictions = torch.cat(predictions, dim=1) - math.log(len(self.saliency_networks))
        return predictions.logsumexp(dim=1, keepdim=True)

    def train(self, mode=True):
        """Set train/eval mode on the heads while keeping ``features`` in eval mode.

        Without this override, ``model.train()`` would switch the frozen backbone
        into train mode as well. Returns ``self`` like :meth:`torch.nn.Module.train`.
        """
        super().train(mode)
        # Re-freeze the backbone after super() switched every child to ``mode``.
        self.features.eval()
        return self
class MixtureModel(torch.nn.Module):
    """Uniform mixture of several log-density models.

    Every model is called with the same arguments. The result is
    ``log(mean_i exp(prediction_i))`` with shape (B, 1, H, W).
    """

    def __init__(self, models):
        super().__init__()
        self.models = torch.nn.ModuleList(models)

    def forward(self, *args, **kwargs):
        """Evaluate all models and combine their log-densities.

        Each model is expected to return either a (B, 1, H, W) or a (B, H, W)
        log-density. 3-D outputs get a mixture axis inserted at dim 1. Without it
        they would be concatenated along the image height, mixing rows instead
        of models.
        """
        predictions = []
        for model in self.models:
            # Call the module (not .forward) so registered hooks run as usual.
            prediction = model(*args, **kwargs)
            if prediction.dim() == 3:
                prediction = prediction.unsqueeze(1)
            predictions.append(prediction)

        # Uniform mixture: log(mean_i p_i) = logsumexp_i(log p_i) - log(n).
        predictions = torch.cat(predictions, dim=1) - math.log(len(self.models))
        return predictions.logsumexp(dim=1, keepdim=True)