|
|
from typing import Sequence |
|
|
|
|
|
from itertools import chain |
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
from torchvision import models |
|
|
|
|
|
from criteria.lpips.utils import normalize_activation |
|
|
|
|
|
|
|
|
def get_network(net_type: str):
    """Build the LPIPS feature backbone selected by ``net_type``.

    Args:
        net_type: ``"alex"`` -> ``AlexNet``, ``"squeeze"`` -> ``SqueezeNet``,
            ``"vgg"`` -> ``VGG16``.

    Returns:
        A freshly constructed instance of the selected backbone class.

    Raises:
        NotImplementedError: if ``net_type`` is none of the three names above.
    """
    # Reject unknown names up front so the dispatch below only sees valid ones.
    if net_type not in ("alex", "squeeze", "vgg"):
        raise NotImplementedError("choose net_type from [alex, squeeze, vgg].")

    if net_type == "alex":
        return AlexNet()
    if net_type == "squeeze":
        return SqueezeNet()
    # Only "vgg" can reach this point.
    return VGG16()
|
|
|
|
|
|
|
|
class LinLayers(nn.ModuleList):
    """Frozen 1x1 convolution heads, one per entry of ``n_channels_list``.

    Head ``i`` is ``nn.Sequential(nn.Identity(), nn.Conv2d(c_i, 1, kernel 1,
    stride 1, padding 0, no bias))``. It maps a ``c_i``-channel feature map to
    a single channel. Every parameter has ``requires_grad`` switched off once
    the heads are built.
    """

    def __init__(self, n_channels_list: Sequence[int]):
        heads = []
        for in_channels in n_channels_list:
            # The leading Identity puts the conv at index 1 inside each head,
            # so its weight appears under the state_dict key "<i>.1.weight".
            # Keep this nesting so those keys stay the same.
            heads.append(
                nn.Sequential(
                    nn.Identity(),
                    nn.Conv2d(
                        in_channels, 1, kernel_size=1, stride=1, padding=0, bias=False
                    ),
                )
            )
        super().__init__(heads)

        # Freeze every head parameter (same effect as looping over
        # self.parameters() and clearing requires_grad).
        self.requires_grad_(False)
|
|
|
|
|
|
|
|
class BaseNet(nn.Module):
    """Shared input normalization and activation tapping for LPIPS backbones.

    ``BaseNet`` does not define ``layers`` or ``target_layers``, but
    ``forward`` reads both. A subclass must therefore assign them in its own
    ``__init__``:

    * ``layers`` -- an iterable container of modules applied in order
      (e.g. an ``nn.Sequential``);
    * ``target_layers`` -- 1-based positions in ``layers`` whose outputs are
      collected by ``forward``.
    """

    def __init__(self):
        super().__init__()

        # Per-channel shift and scale. The [None, :, None, None] indexing gives
        # them shape (1, 3, 1, 1), so they broadcast against an input whose
        # dim 1 has 3 channels (presumably (N, 3, H, W) images -- not checked
        # here). Registering them as buffers makes them move with the module
        # and appear in state_dict without becoming parameters.
        # NOTE(review): these look like the upstream LPIPS input-scaling
        # constants -- confirm against the reference implementation.
        # torch.tensor (rather than the legacy torch.Tensor constructor)
        # infers the default floating dtype from the Python floats.
        self.register_buffer(
            "mean", torch.tensor([-0.030, -0.088, -0.188])[None, :, None, None]
        )
        self.register_buffer(
            "std", torch.tensor([0.458, 0.448, 0.450])[None, :, None, None]
        )

    def set_requires_grad(self, state: bool):
        """Set ``requires_grad = state`` on every parameter AND every buffer.

        The buffers (including ``mean`` and ``std``) are covered as well, not
        only the parameters.
        """
        for tensor in chain(self.parameters(), self.buffers()):
            tensor.requires_grad = state

    def z_score(self, x: torch.Tensor):
        """Return ``(x - self.mean) / self.std`` (broadcast per channel)."""
        return (x - self.mean) / self.std

    def forward(self, x: torch.Tensor):
        """Normalize ``x``, run it through ``self.layers``, and tap activations.

        Layers are counted from 1. The output of the layer at each position
        listed in ``self.target_layers`` goes through ``normalize_activation``
        and is appended to the result. The loop stops as soon as
        ``len(self.target_layers)`` activations have been collected, so layers
        after the last target are never evaluated.

        Note: if ``target_layers`` is empty, the first layer still runs before
        the loop exits. A target position beyond the number of layers is
        silently skipped, so the returned list is shorter.

        Returns:
            list of tensors, in layer order.
        """
        x = self.z_score(x)

        output = []
        # Iterate the container through its public iterator instead of the
        # private ``_modules`` dict. For nn.Sequential / nn.ModuleList this
        # yields the same modules in the same order.
        for i, layer in enumerate(self.layers, 1):
            x = layer(x)
            if i in self.target_layers:
                output.append(normalize_activation(x))
            if len(output) == len(self.target_layers):
                break
        return output
|
|
|
|
|
|
|
|
class SqueezeNet(BaseNet):
    """LPIPS backbone: the ``features`` stack of torchvision's SqueezeNet 1.1.

    Activations are tapped after layers 2, 5, 8, 10, 11, 12 and 13 (1-based)
    by ``BaseNet.forward``. All parameters and buffers are frozen.
    """

    def __init__(self):
        super().__init__()

        # `pretrained=True` by keyword is the same argument the positional
        # `True` filled; torchvision maps it to the default pretrained weights.
        backbone = models.squeezenet1_1(pretrained=True)
        self.layers = backbone.features

        # 1-based positions in self.layers whose outputs BaseNet.forward keeps.
        self.target_layers = [2, 5, 8, 10, 11, 12, 13]
        # One entry per target position, in the same order; presumably the
        # channel count of each tapped activation (LinLayers takes a list of
        # this name as conv input sizes) -- confirm.
        self.n_channels_list = [64, 128, 256, 384, 384, 512, 512]

        # Freeze everything: parameters and buffers.
        self.set_requires_grad(False)
|
|
|
|
|
|
|
|
class AlexNet(BaseNet):
    """LPIPS backbone: the ``features`` stack of torchvision's AlexNet.

    Activations are tapped after layers 2, 5, 8, 10 and 12 (1-based) by
    ``BaseNet.forward``. All parameters and buffers are frozen.
    """

    def __init__(self):
        super().__init__()

        # `pretrained=True` by keyword is the same argument the positional
        # `True` filled; torchvision maps it to the default pretrained weights.
        backbone = models.alexnet(pretrained=True)
        self.layers = backbone.features

        # 1-based positions in self.layers whose outputs BaseNet.forward keeps.
        self.target_layers = [2, 5, 8, 10, 12]
        # One entry per target position, in the same order; presumably the
        # channel count of each tapped activation (LinLayers takes a list of
        # this name as conv input sizes) -- confirm.
        self.n_channels_list = [64, 192, 384, 256, 256]

        # Freeze everything: parameters and buffers.
        self.set_requires_grad(False)
|
|
|
|
|
|
|
|
class VGG16(BaseNet):
    """LPIPS backbone: the ``features`` stack of torchvision's VGG-16.

    Activations are tapped after layers 4, 9, 16, 23 and 30 (1-based) by
    ``BaseNet.forward``. All parameters and buffers are frozen.
    """

    def __init__(self):
        super().__init__()

        # `pretrained=True` by keyword is the same argument the positional
        # `True` filled; torchvision maps it to the default pretrained weights.
        backbone = models.vgg16(pretrained=True)
        self.layers = backbone.features

        # 1-based positions in self.layers whose outputs BaseNet.forward keeps.
        self.target_layers = [4, 9, 16, 23, 30]
        # One entry per target position, in the same order; presumably the
        # channel count of each tapped activation (LinLayers takes a list of
        # this name as conv input sizes) -- confirm.
        self.n_channels_list = [64, 128, 256, 512, 512]

        # Freeze everything: parameters and buffers.
        self.set_requires_grad(False)
|
|
|