"""
This code was adapted from: https://github.com/rgeirhos/texture-vs-shape
"""

import os
import sys
from collections import OrderedDict

import torch
import torch.nn as nn
import torchvision
import torchvision.models
from torch.utils import model_zoo

from .normalizer import Normalizer


# Checkpoint URLs keyed by the model names accepted by ``load_model``.
# The VGG-16 entry is kept for reference only: load_model reads that
# checkpoint from a local file instead (see _VGG16_CHECKPOINT_PATH).
_MODEL_URLS = {
    'resnet50_trained_on_SIN': 'https://bitbucket.org/robert_geirhos/texture-vs-shape-pretrained-models/raw/6f41d2e86fc60566f78de64ecff35cc61eb6436f/resnet50_train_60_epochs-c8e5653e.pth.tar',
    'resnet50_trained_on_SIN_and_IN': 'https://bitbucket.org/robert_geirhos/texture-vs-shape-pretrained-models/raw/60b770e128fffcbd8562a3ab3546c1a735432d03/resnet50_train_45_epochs_combined_IN_SF-2a0d100e.pth.tar',
    'resnet50_trained_on_SIN_and_IN_then_finetuned_on_IN': 'https://bitbucket.org/robert_geirhos/texture-vs-shape-pretrained-models/raw/60b770e128fffcbd8562a3ab3546c1a735432d03/resnet50_finetune_60_epochs_lr_decay_after_30_start_resnet50_train_45_epochs_combined_IN_SF-ca06340c.pth.tar',
    'vgg16_trained_on_SIN': 'https://bitbucket.org/robert_geirhos/texture-vs-shape-pretrained-models/raw/0008049cd10f74a944c6d5e90d4639927f8620ae/vgg16_train_60_epochs_lr0.01-6c6fcc9f.pth.tar',
    'alexnet_trained_on_SIN': 'https://bitbucket.org/robert_geirhos/texture-vs-shape-pretrained-models/raw/0008049cd10f74a944c6d5e90d4639927f8620ae/alexnet_train_60_epochs_lr0.001-b4aa5238.pth.tar',
}

# The VGG-16 checkpoint has to be downloaded manually to this path
# (relative to the current working directory).
_VGG16_CHECKPOINT_PATH = "./vgg16_train_60_epochs_lr0.01-6c6fcc9f.pth.tar"


def _wrap_features(model):
    """Wrap ``model.features`` in DataParallel; move to GPU only if one exists.

    The wrapping makes the parameter names match the checkpoint's keys.
    When CUDA is available, DataParallel requires its module to live on the
    first CUDA device, so the model is moved there. Without CUDA, DataParallel
    just runs the wrapped module on the CPU, so no move is attempted.
    """
    model.features = torch.nn.DataParallel(model.features)
    if torch.cuda.is_available():
        model.cuda()
    return model


def load_model(model_name):
    """Build a torchvision architecture and load texture-vs-shape weights into it.

    The architecture is picked by substring match on ``model_name``
    ("resnet50", then "vgg16", then "alexnet"). ResNet-50 and AlexNet
    checkpoints are fetched from ``_MODEL_URLS`` via ``model_zoo.load_url``.
    The VGG-16 checkpoint must already exist at ``_VGG16_CHECKPOINT_PATH``.
    All checkpoints are deserialized onto the CPU.

    Args:
        model_name: name of the pretrained model, e.g. a key of ``_MODEL_URLS``.

    Returns:
        The model with ``checkpoint["state_dict"]`` loaded. ResNet-50 is
        returned wrapped in an ``nn.Sequential`` whose single child is named
        ``module``. For VGG-16/AlexNet, ``model.features`` is a
        ``DataParallel`` and the model is on the GPU if CUDA is available.

    Raises:
        ValueError: if ``model_name`` matches no supported architecture.
        FileNotFoundError: if the VGG-16 checkpoint file is missing.
        KeyError: if a ResNet-50/AlexNet name is not a key of ``_MODEL_URLS``.
    """
    cpu = torch.device('cpu')

    if "resnet50" in model_name:
        # No-argument construction gives random init on both old
        # (pretrained=False default) and new (weights=None default) torchvision.
        model = torchvision.models.resnet50()
        # Fake DataParallel structure: a child named "module" reproduces the
        # "module." key prefix of the checkpoint without needing a GPU.
        model = torch.nn.Sequential(OrderedDict([('module', model)]))
        checkpoint = model_zoo.load_url(_MODEL_URLS[model_name], map_location=cpu)
    elif "vgg16" in model_name:
        # Too large to be fetched automatically; must be downloaded by hand.
        if not os.path.exists(_VGG16_CHECKPOINT_PATH):
            raise FileNotFoundError(
                "Please download the VGG model yourself from the following link "
                "and save it locally: "
                "https://drive.google.com/drive/folders/1A0vUWyU6fTuc-xWgwQQeBvzbwi6geYQK "
                "(too large to be downloaded automatically like the other models)"
            )
        model = _wrap_features(torchvision.models.vgg16())
        checkpoint = torch.load(_VGG16_CHECKPOINT_PATH, map_location=cpu)
    elif "alexnet" in model_name:
        model = _wrap_features(torchvision.models.alexnet())
        checkpoint = model_zoo.load_url(_MODEL_URLS[model_name], map_location=cpu)
    else:
        raise ValueError(f"unknown model architecture: {model_name!r}.")

    model.load_state_dict(checkpoint["state_dict"])
    return model


# --- DeepGaze Adaptation ----

class _RGBShapeNet(nn.Sequential):
    """Sequential of a ``Normalizer`` followed by a pretrained texture-vs-shape model.

    Children are registered positionally: ``0`` is the normalizer and ``1`` is
    the network returned by ``load_model`` (state-dict keys therefore look
    like ``1.module.conv1.weight``).
    """

    def __init__(self, model_name):
        # Load the network before building the normalizer (original order),
        # then initialise the Sequential exactly once.
        shapenet = load_model(model_name)
        super().__init__(Normalizer(), shapenet)


class RGBShapeNetA(_RGBShapeNet):
    """Normalizer + ResNet-50 checkpoint ``resnet50_trained_on_SIN``."""

    def __init__(self):
        super().__init__("resnet50_trained_on_SIN")


class RGBShapeNetB(_RGBShapeNet):
    """Normalizer + ResNet-50 checkpoint ``resnet50_trained_on_SIN_and_IN``."""

    def __init__(self):
        super().__init__("resnet50_trained_on_SIN_and_IN")


class RGBShapeNetC(_RGBShapeNet):
    """Normalizer + ResNet-50 checkpoint
    ``resnet50_trained_on_SIN_and_IN_then_finetuned_on_IN``."""

    def __init__(self):
        super().__init__("resnet50_trained_on_SIN_and_IN_then_finetuned_on_IN")