# mbar0075's picture
# Testing Commit
# c9baa67
# raw
# history blame
# 11.8 kB
import functools
import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from .layers import GaussianFilterNd
def encode_scanpath_features(x_hist, y_hist, size, device=None, include_x=True, include_y=True, include_duration=False):
    """Encode a fixation history as per-pixel feature maps.

    For every past fixation, three maps over the stimulus grid are produced:
    the horizontal offset, the vertical offset and the euclidean distance of
    each pixel to that fixation.

    Args:
        x_hist: (batch, fixations) tensor of past fixation x coordinates.
        y_hist: (batch, fixations) tensor of past fixation y coordinates.
        size: (height, width) of the stimulus.
        device: device on which the coordinate grids are created.
        include_x / include_y / include_duration: only the default
            combination (x and y, no durations) is implemented.

    Returns:
        Tensor of shape (batch, 3 * fixations, height, width): the x offsets,
        then the y offsets, then the distances, concatenated along dim 1.
    """
    assert include_x
    assert include_y
    assert not include_duration
    height, width = size[0], size[1]
    xs = torch.arange(width, dtype=torch.float32).to(device)
    ys = torch.arange(height, dtype=torch.float32).to(device)
    YS, XS = torch.meshgrid(ys, xs, indexing='ij')
    # One writable copy of the coordinate grid per (batch, fixation) pair.
    XS = XS[None, None, :, :].expand(x_hist.shape[0], x_hist.shape[1], height, width).clone()
    YS = YS[None, None, :, :].expand(y_hist.shape[0], y_hist.shape[1], height, width).clone()
    # Shift each grid so it is centered on the corresponding past fixation
    # (in-place, so the float32 dtype of the grids is preserved).
    XS -= x_hist.unsqueeze(2).unsqueeze(3)
    YS -= y_hist.unsqueeze(2).unsqueeze(3)
    distances = (XS**2 + YS**2).sqrt()
    return torch.cat((XS, YS, distances), dim=1)
class FeatureExtractor(torch.nn.Module):
    """Wrap a backbone and capture intermediate activations by module name.

    A forward hook is registered on every module named in ``targets``;
    running the backbone fills ``self.outputs`` and ``forward`` returns the
    captured activations in the order of ``targets``.
    """
    def __init__(self, features, targets):
        super().__init__()
        self.features = features
        self.targets = targets
        self.outputs = {}
        modules_by_name = dict(self.features.named_modules())
        for name in targets:
            modules_by_name[name].register_forward_hook(self.save_outputs_hook(name))

    def save_outputs_hook(self, layer_id: str):
        """Return a forward hook storing a clone of `layer_id`'s output."""
        def hook(module, inputs, output):
            self.outputs[layer_id] = output.clone()
        return hook

    def forward(self, x):
        self.outputs.clear()
        self.features(x)
        return [self.outputs[name] for name in self.targets]
def upscale(tensor, size):
    """Upscale a NCHW tensor by integer pixel repetition, then crop to `size`.

    The smallest integer factor that makes both spatial dimensions at least
    as large as ``size`` is used to repeat pixels along dims 2 and 3; the
    result is then cropped to exactly ``size`` (height, width).
    """
    current_size = torch.tensor(tensor.shape[2:]).type(torch.float32)
    wanted_size = torch.tensor(size).type(torch.float32)
    factor = int(torch.ceil(wanted_size / current_size).max())
    assert factor >= 1
    for dim in (2, 3):
        tensor = torch.repeat_interleave(tensor, factor, dim=dim)
    return tensor[:, :, :size[0], :size[1]]
class Finalizer(nn.Module):
    """Transforms a readout into a gaze prediction
    A readout network returns a single, spatial map of probable gaze locations.
    This module bundles the common processing steps necessary to transform this into
    the predicted gaze distribution:
    - resizing to the stimulus size
    - smoothing of the prediction using a gaussian filter
    - removing of channel and time dimension
    - weighted addition of the center bias
    - normalization
    """
    def __init__(
        self,
        sigma,
        kernel_size=None,
        learn_sigma=False,
        center_bias_weight=1.0,
        learn_center_bias_weight=True,
        saliency_map_factor=4,
    ):
        """Creates a new finalizer
        Args:
            sigma (float): standard deviation of the gaussian kernel used for smoothing
            kernel_size (int, optional): size of the gaussian kernel
                NOTE(review): not used by this implementation — GaussianFilterNd
                is constructed from sigma and truncate only.
            learn_sigma (bool, optional): If True, the standard deviation of the gaussian kernel will
                be learned (default: False)
            center_bias_weight (float, optional): initial weight of the center bias
            learn_center_bias_weight (bool, optional): If True, the center bias weight will be
                learned (default: True)
            saliency_map_factor (int, optional): downscaling factor for the internal
                saliency-map resolution at which smoothing and the center-bias
                addition are performed (default: 4)
        """
        super(Finalizer, self).__init__()
        self.saliency_map_factor = saliency_map_factor
        # Gaussian smoothing over the spatial dims (2, 3); sigma is trainable iff learn_sigma.
        self.gauss = GaussianFilterNd([2, 3], sigma, truncate=3, trainable=learn_sigma)
        # Scalar mixing weight for the center-bias contribution.
        self.center_bias_weight = nn.Parameter(torch.Tensor([center_bias_weight]), requires_grad=learn_center_bias_weight)
    def forward(self, readout, centerbias):
        """Applies the finalization steps to the given readout"""
        # Work at a reduced resolution (1 / saliency_map_factor of the stimulus):
        # downscale the (N, H, W) center bias via a temporary channel dimension.
        downscaled_centerbias = F.interpolate(
            centerbias.view(centerbias.shape[0], 1, centerbias.shape[1], centerbias.shape[2]),
            scale_factor=1 / self.saliency_map_factor,
            recompute_scale_factor=False,
        )[:, 0, :, :]
        # Bring the readout onto the same downscaled grid.
        out = F.interpolate(
            readout,
            size=[downscaled_centerbias.shape[1], downscaled_centerbias.shape[2]]
        )
        # apply gaussian filter
        out = self.gauss(out)
        # remove channel dimension
        out = out[:, 0, :, :]
        # add to center bias (weighted)
        out = out + self.center_bias_weight * downscaled_centerbias
        # Upsample back to the full stimulus resolution.
        out = F.interpolate(out[:, np.newaxis, :, :], size=[centerbias.shape[1], centerbias.shape[2]])[:, 0, :, :]
        # normalize so each map is a log probability distribution over pixels
        out = out - out.logsumexp(dim=(1, 2), keepdim=True)
        return out
class DeepGazeII(torch.nn.Module):
    """DeepGaze II: frozen feature backbone + readout network + finalizer.

    Args:
        features: feature-extractor backbone; its parameters are frozen and it
            is kept permanently in eval mode.
        readout_network: maps the concatenated feature maps to a readout.
        downsample: input downsampling factor applied before the backbone.
        readout_factor: additional downsampling of the readout grid.
        saliency_map_factor: internal downscaling used by the finalizer.
        initial_sigma: initial smoothing sigma of the finalizer.
    """
    def __init__(self, features, readout_network, downsample=2, readout_factor=16, saliency_map_factor=2, initial_sigma=8.0):
        super().__init__()
        self.readout_factor = readout_factor
        self.saliency_map_factor = saliency_map_factor
        self.features = features
        # Freeze the backbone: no gradients, eval mode (also re-enforced in train()).
        for param in self.features.parameters():
            param.requires_grad = False
        self.features.eval()
        self.readout_network = readout_network
        self.finalizer = Finalizer(
            sigma=initial_sigma,
            learn_sigma=True,
            saliency_map_factor=self.saliency_map_factor,
        )
        self.downsample = downsample

    def forward(self, x, centerbias):
        """Return a normalized log-density map for images `x` (N, C, H, W)."""
        orig_shape = x.shape
        x = F.interpolate(
            x,
            scale_factor=1 / self.downsample,
            recompute_scale_factor=False,
        )
        x = self.features(x)
        # Resample all feature maps to one common grid before concatenation.
        readout_shape = [math.ceil(orig_shape[2] / self.downsample / self.readout_factor), math.ceil(orig_shape[3] / self.downsample / self.readout_factor)]
        x = [F.interpolate(item, readout_shape) for item in x]
        x = torch.cat(x, dim=1)
        x = self.readout_network(x)
        x = self.finalizer(x, centerbias)
        return x

    def train(self, mode=True):
        """Set train mode on the trainable parts while keeping the backbone frozen.

        Bug fix: also updates ``self.training`` and returns ``self``, matching
        the ``torch.nn.Module.train`` contract so calls can be chained.
        """
        self.training = mode
        self.features.eval()
        self.readout_network.train(mode=mode)
        self.finalizer.train(mode=mode)
        return self
class DeepGazeIII(torch.nn.Module):
    """DeepGaze III: frozen backbone + saliency, scanpath and fixation-selection networks.

    Args:
        features: feature-extractor backbone; parameters frozen, kept in eval mode.
        saliency_network: maps concatenated feature maps to saliency features.
        scanpath_network: processes encoded fixation-history features, or None
            to disable the scanpath pathway.
        fixation_selection_network: combines (saliency, scanpath) into a readout.
        downsample: input downsampling factor applied before the backbone.
        readout_factor: additional downsampling of the readout grid.
        saliency_map_factor: internal downscaling used by the finalizer.
        included_fixations: which past fixations are used (stored, not read here).
        initial_sigma: initial smoothing sigma of the finalizer.
    """
    def __init__(self, features, saliency_network, scanpath_network, fixation_selection_network, downsample=2, readout_factor=2, saliency_map_factor=2, included_fixations=-2, initial_sigma=8.0):
        super().__init__()
        self.downsample = downsample
        self.readout_factor = readout_factor
        self.saliency_map_factor = saliency_map_factor
        self.included_fixations = included_fixations
        self.features = features
        # Freeze the backbone: no gradients, eval mode (also re-enforced in train()).
        for param in self.features.parameters():
            param.requires_grad = False
        self.features.eval()
        self.saliency_network = saliency_network
        self.scanpath_network = scanpath_network
        self.fixation_selection_network = fixation_selection_network
        self.finalizer = Finalizer(
            sigma=initial_sigma,
            learn_sigma=True,
            saliency_map_factor=self.saliency_map_factor,
        )

    def forward(self, x, centerbias, x_hist=None, y_hist=None, durations=None):
        """Return a normalized log-density map, optionally conditioned on scanpath history."""
        orig_shape = x.shape
        # Consistency fix: pass recompute_scale_factor=False like every other
        # interpolate-by-scale call in this file (DeepGazeII, the mixture, Finalizer).
        x = F.interpolate(
            x,
            scale_factor=1 / self.downsample,
            recompute_scale_factor=False,
        )
        x = self.features(x)
        # Resample all feature maps to one common grid before concatenation.
        readout_shape = [math.ceil(orig_shape[2] / self.downsample / self.readout_factor), math.ceil(orig_shape[3] / self.downsample / self.readout_factor)]
        x = [F.interpolate(item, readout_shape) for item in x]
        x = torch.cat(x, dim=1)
        x = self.saliency_network(x)
        if self.scanpath_network is not None:
            # Encode the fixation history at full stimulus resolution, then
            # bring it onto the readout grid.
            scanpath_features = encode_scanpath_features(x_hist, y_hist, size=(orig_shape[2], orig_shape[3]), device=x.device)
            scanpath_features = F.interpolate(scanpath_features, readout_shape)
            y = self.scanpath_network(scanpath_features)
        else:
            y = None
        x = self.fixation_selection_network((x, y))
        x = self.finalizer(x, centerbias)
        return x

    def train(self, mode=True):
        """Set train mode on the trainable parts while keeping the backbone frozen.

        Bug fix: also updates ``self.training`` and returns ``self``, matching
        the ``torch.nn.Module.train`` contract so calls can be chained.
        """
        self.training = mode
        self.features.eval()
        self.saliency_network.train(mode=mode)
        if self.scanpath_network is not None:
            self.scanpath_network.train(mode=mode)
        self.fixation_selection_network.train(mode=mode)
        self.finalizer.train(mode=mode)
        return self
class DeepGazeIIIMixture(torch.nn.Module):
    """Mixture of DeepGaze III readout heads sharing one frozen feature backbone.

    Each mixture component consists of a saliency network, an optional
    scanpath network (None disables the scanpath pathway for that component),
    a fixation-selection network and a finalizer. Component predictions are
    combined as a uniform mixture in log space.
    """
    def __init__(self, features, saliency_networks, scanpath_networks, fixation_selection_networks, finalizers, downsample=2, readout_factor=2, saliency_map_factor=2, included_fixations=-2, initial_sigma=8.0):
        super().__init__()
        self.downsample = downsample
        self.readout_factor = readout_factor
        self.saliency_map_factor = saliency_map_factor
        self.included_fixations = included_fixations
        self.features = features
        # Freeze the backbone: no gradients, eval mode (also re-enforced in train()).
        for param in self.features.parameters():
            param.requires_grad = False
        self.features.eval()
        self.saliency_networks = torch.nn.ModuleList(saliency_networks)
        self.scanpath_networks = torch.nn.ModuleList(scanpath_networks)
        self.fixation_selection_networks = torch.nn.ModuleList(fixation_selection_networks)
        self.finalizers = torch.nn.ModuleList(finalizers)

    def forward(self, x, centerbias, x_hist=None, y_hist=None, durations=None):
        """Return the log-density mixture prediction, shape (N, 1, H, W)."""
        orig_shape = x.shape
        x = F.interpolate(
            x,
            scale_factor=1 / self.downsample,
            recompute_scale_factor=False,
        )
        x = self.features(x)
        # Resample all feature maps to one common grid before concatenation.
        readout_shape = [math.ceil(orig_shape[2] / self.downsample / self.readout_factor), math.ceil(orig_shape[3] / self.downsample / self.readout_factor)]
        x = [F.interpolate(item, readout_shape) for item in x]
        x = torch.cat(x, dim=1)
        predictions = []
        readout_input = x
        for saliency_network, scanpath_network, fixation_selection_network, finalizer in zip(
            self.saliency_networks, self.scanpath_networks, self.fixation_selection_networks, self.finalizers
        ):
            x = saliency_network(readout_input)
            if scanpath_network is not None:
                scanpath_features = encode_scanpath_features(x_hist, y_hist, size=(orig_shape[2], orig_shape[3]), device=x.device)
                scanpath_features = F.interpolate(scanpath_features, readout_shape)
                y = scanpath_network(scanpath_features)
            else:
                y = None
            x = fixation_selection_network((x, y))
            x = finalizer(x, centerbias)
            predictions.append(x[:, np.newaxis, :, :])
        # Uniform mixture in log space: subtract log(K), then logsumexp over components.
        predictions = torch.cat(predictions, dim=1) - np.log(len(self.saliency_networks))
        prediction = predictions.logsumexp(dim=(1), keepdim=True)
        return prediction

    def train(self, mode=True):
        """Set train mode on the trainable parts while keeping the backbone frozen.

        Consistency fix: DeepGazeII and DeepGazeIII override train() so the
        frozen backbone stays in eval mode; this class previously did not, so
        calling .train() flipped the backbone back into train mode.
        """
        self.training = mode
        self.features.eval()
        self.saliency_networks.train(mode=mode)
        self.scanpath_networks.train(mode=mode)
        self.fixation_selection_networks.train(mode=mode)
        self.finalizers.train(mode=mode)
        return self
class MixtureModel(torch.nn.Module):
    """Uniform mixture of models that return log densities.

    Each submodel is expected to return a tensor of shape (N, 1, H, W); the
    mixture output has the same shape.
    """
    def __init__(self, models):
        super().__init__()
        self.models = torch.nn.ModuleList(models)

    def forward(self, *args, **kwargs):
        # Bug fix: invoke the models via __call__ instead of .forward() so that
        # registered forward hooks (and the rest of the Module call machinery) run.
        predictions = [model(*args, **kwargs) for model in self.models]
        predictions = torch.cat(predictions, dim=1)
        # Uniform mixture weights: subtract log(K) before the logsumexp.
        predictions -= np.log(len(self.models))
        prediction = predictions.logsumexp(dim=(1), keepdim=True)
        return prediction