"""DeepGaze IIE: a mixture of DeepGaze readout networks over several backbones."""

from collections import OrderedDict
import importlib
import os

import torch
import torch.nn as nn
import torch.nn.functional as F

from torch.utils import model_zoo

from .modules import FeatureExtractor, Finalizer, DeepGazeIIIMixture, MixtureModel

from .layers import (
    Conv2dMultiInput,
    LayerNorm,
    LayerNormMultiInput,
    Bias,
)


# One entry per backbone used by ``DeepGazeIIE``:
#   'type'          -- dotted path of the feature class, resolved by ``import_class``
#   'used_features' -- layer names handed to ``FeatureExtractor``
#   'channels'      -- input channel count handed to ``build_saliency_network``
#                      (presumably the total channels of the used features;
#                      verify against the feature classes)
BACKBONES = [
    {
        'type': 'deepgaze_pytorch.features.shapenet.RGBShapeNetC',
        'used_features': [
            '1.module.layer3.0.conv2',
            '1.module.layer3.3.conv2',
            '1.module.layer3.5.conv1',
            '1.module.layer3.5.conv2',
            '1.module.layer4.1.conv2',
            '1.module.layer4.2.conv2',
        ],
        'channels': 2048,
    },
    {
        'type': 'deepgaze_pytorch.features.efficientnet.RGBEfficientNetB5',
        'used_features': [
            '1._blocks.24._depthwise_conv',
            '1._blocks.26._depthwise_conv',
            '1._blocks.35._project_conv',
        ],
        'channels': 2416,
    },
    {
        'type': 'deepgaze_pytorch.features.densenet.RGBDenseNet201',
        'used_features': [
            '1.features.denseblock4.denselayer32.norm1',
            '1.features.denseblock4.denselayer32.conv1',
            '1.features.denseblock4.denselayer31.conv2',
        ],
        'channels': 2048,
    },
    {
        'type': 'deepgaze_pytorch.features.resnext.RGBResNext50',
        'used_features': [
            '1.layer3.5.conv1',
            '1.layer3.5.conv2',
            '1.layer3.4.conv2',
            '1.layer4.2.conv2',
        ],
        'channels': 2560,
    },
]


def build_saliency_network(input_channels):
    """Build the saliency readout network.

    The network is a stack of three 1x1 convolutions
    (``input_channels`` -> 8 -> 16 -> 1). Each convolution is preceded by a
    ``LayerNorm`` and followed by a separate ``Bias`` layer and a
    ``Softplus`` non-linearity.

    Args:
        input_channels: number of channels of the incoming feature tensor.

    Returns:
        An ``nn.Sequential`` with named submodules.
    """
    return nn.Sequential(OrderedDict([
        ('layernorm0', LayerNorm(input_channels)),
        ('conv0', nn.Conv2d(input_channels, 8, (1, 1), bias=False)),
        ('bias0', Bias(8)),
        ('softplus0', nn.Softplus()),

        ('layernorm1', LayerNorm(8)),
        ('conv1', nn.Conv2d(8, 16, (1, 1), bias=False)),
        ('bias1', Bias(16)),
        ('softplus1', nn.Softplus()),

        ('layernorm2', LayerNorm(16)),
        ('conv2', nn.Conv2d(16, 1, (1, 1), bias=False)),
        ('bias2', Bias(1)),
        # NOTE: the key is 'softplus3' (not 'softplus2'); it is kept verbatim
        # so that submodule names stay stable.
        ('softplus3', nn.Softplus()),
    ]))


def build_fixation_selection_network():
    """Build the fixation selection network.

    The first two layers are the multi-input variants of ``LayerNorm`` and
    ``Conv2d``; the remaining layers are a 1x1 convolution stack
    (128 -> 16 -> 1). The final convolution has no bias and no
    non-linearity after it.

    Returns:
        An ``nn.Sequential`` with named submodules.
    """
    # NOTE(review): [1, 0] is presumably the channel count of each input
    # (saliency, scanpath) -- confirm against LayerNormMultiInput and
    # Conv2dMultiInput.
    return nn.Sequential(OrderedDict([
        ('layernorm0', LayerNormMultiInput([1, 0])),
        ('conv0', Conv2dMultiInput([1, 0], 128, (1, 1), bias=False)),
        ('bias0', Bias(128)),
        ('softplus0', nn.Softplus()),

        ('layernorm1', LayerNorm(128)),
        ('conv1', nn.Conv2d(128, 16, (1, 1), bias=False)),
        ('bias1', Bias(16)),
        ('softplus1', nn.Softplus()),

        ('conv2', nn.Conv2d(16, 1, (1, 1), bias=False)),
    ]))


def build_deepgaze_mixture(backbone_config, components=10):
    """Build a ``DeepGazeIIIMixture`` on top of a single backbone.

    Args:
        backbone_config: dict with the keys ``'type'`` (dotted path of the
            feature class), ``'used_features'`` (layer names for the
            ``FeatureExtractor``) and ``'channels'`` (input channels of each
            saliency network); see ``BACKBONES``.
        components: number of mixture components. Every component gets its
            own saliency network, fixation selection network and finalizer.

    Returns:
        The assembled ``DeepGazeIIIMixture``. No scanpath networks are used
        (each entry is ``None``) and ``included_fixations`` is empty.
    """
    feature_class = import_class(backbone_config['type'])
    features = feature_class()

    feature_extractor = FeatureExtractor(features, backbone_config['used_features'])

    saliency_networks = []
    scanpath_networks = []
    fixation_selection_networks = []
    finalizers = []

    # The modules of one component are created together, one component after
    # the other, so the order in which parameters are initialised is fixed.
    for _ in range(components):
        saliency_network = build_saliency_network(backbone_config['channels'])
        fixation_selection_network = build_fixation_selection_network()

        saliency_networks.append(saliency_network)
        scanpath_networks.append(None)
        fixation_selection_networks.append(fixation_selection_network)
        finalizers.append(Finalizer(sigma=8.0, learn_sigma=True, saliency_map_factor=2))

    return DeepGazeIIIMixture(
        features=feature_extractor,
        saliency_networks=saliency_networks,
        scanpath_networks=scanpath_networks,
        fixation_selection_networks=fixation_selection_networks,
        finalizers=finalizers,
        downsample=2,
        readout_factor=16,
        saliency_map_factor=2,
        included_fixations=[],
    )


class DeepGazeIIE(MixtureModel):
    """DeepGazeIIE model

    Combines one ``DeepGazeIIIMixture`` per entry of ``BACKBONES`` in a
    ``MixtureModel``.

    :note
    See Linardos, A., Kümmerer, M., Press, O., & Bethge, M. (2021).
    Calibrated prediction in and out-of-domain for state-of-the-art saliency
    modeling. ArXiv:2105.12441 [Cs], http://arxiv.org/abs/2105.12441
    """
    def __init__(self, pretrained=True):
        """Build the model and optionally load the released weights.

        Args:
            pretrained: if true, download the published state dict and load
                it (mapped to the CPU) into the model.
        """
        # we average over 3 instances per backbone, each instance has 10 crossvalidation folds
        backbone_models = [
            build_deepgaze_mixture(backbone_config, components=3 * 10)
            for backbone_config in BACKBONES
        ]
        super().__init__(backbone_models)

        if pretrained:
            weights_url = 'https://github.com/matthias-k/DeepGaze/releases/download/v1.0.0/deepgaze2e.pth'
            self.load_state_dict(model_zoo.load_url(weights_url, map_location=torch.device('cpu')))


def import_class(name):
    """Import and return the attribute addressed by a dotted path.

    Args:
        name: fully qualified name such as ``'package.module.ClassName'``.
            Everything before the last dot is imported as a module, the part
            after it is looked up on that module.

    Returns:
        The object found under ``name``.

    Raises:
        ValueError: if ``name`` has no module part (no dot, or nothing in
            front of the last dot).
        ImportError: if the module cannot be imported.
        AttributeError: if the module has no attribute of that name.
    """
    module_name, separator, class_name = name.rpartition('.')
    if not separator or not module_name:
        # Fail with an actionable message instead of an unpacking error.
        raise ValueError(
            f"import_class expects a dotted path like 'package.module.ClassName', got {name!r}"
        )

    module = importlib.import_module(module_name)
    return getattr(module, class_name)