| from typing import Sequence |
|
|
| from itertools import chain |
|
|
| import torch |
| import torch.nn as nn |
| from torchvision import models |
|
|
| from eval_tool.lpips.utils import normalize_activation |
|
|
|
|
def get_network(net_type: str):
    """Build the backbone feature extractor selected by *net_type*.

    Args:
        net_type: one of ``'alex'``, ``'squeeze'`` or ``'vgg'`` (exact,
            case-sensitive match).

    Returns:
        A freshly constructed ``AlexNet``, ``SqueezeNet`` or ``VGG16``.

    Raises:
        NotImplementedError: if *net_type* is not one of the names above.
    """
    # Reject unknown names up front so the dispatch below is a flat list.
    if net_type not in ('alex', 'squeeze', 'vgg'):
        raise NotImplementedError('choose net_type from [alex, squeeze, vgg].')

    if net_type == 'alex':
        return AlexNet()
    if net_type == 'squeeze':
        return SqueezeNet()
    return VGG16()
|
|
|
|
class LinLayers(nn.ModuleList):
    """Frozen list of 1x1 convolution heads, each projecting to one channel.

    For every entry ``nc`` of *n_channels_list* (in order) one head
    ``nn.Sequential(nn.Identity(), nn.Conv2d(nc, 1, 1, 1, 0, bias=False))``
    is appended. The leading ``nn.Identity`` is a pass-through that keeps the
    convolution at index 1 of each head, so the state-dict keys read
    ``"<head>.1.weight"``. NOTE(review): presumably that key layout must
    match externally loaded weights -- confirm before reordering the head.

    All parameters are created with ``requires_grad=False``.
    """

    def __init__(self, n_channels_list: Sequence[int]):
        heads = []
        for channels in n_channels_list:
            projection = nn.Conv2d(
                channels, 1, kernel_size=1, stride=1, padding=0, bias=False)
            # Slot 0 is a no-op placeholder; the conv lives at slot 1.
            heads.append(nn.Sequential(nn.Identity(), projection))
        super(LinLayers, self).__init__(heads)

        # These heads are not trained here: freeze every weight.
        self.requires_grad_(False)
|
|
|
|
class BaseNet(nn.Module):
    """Common base for the frozen backbones used as feature extractors.

    Subclasses are expected to set, in their ``__init__``:

    * ``layers`` -- an iterable container of modules (e.g. ``nn.Sequential``)
      applied to the input in order;
    * ``target_layers`` -- 1-based positions in ``layers`` whose outputs are
      collected by :meth:`forward`;
    * ``n_channels_list`` -- not used by this class; presumably the channel
      count of each collected activation, consumed by callers.
    """

    def __init__(self):
        super(BaseNet, self).__init__()

        # Per-channel shift and scale used by ``z_score``. The
        # ``[None, :, None, None]`` indexing gives shape (1, 3, 1, 1), so they
        # broadcast channel-wise over a 4-D input (presumably N, 3, H, W).
        # NOTE(review): these look like the LPIPS scaling-layer constants,
        # which expect inputs in [-1, 1] -- confirm against callers.
        self.register_buffer(
            'mean', torch.tensor([-.030, -.088, -.188])[None, :, None, None])
        self.register_buffer(
            'std', torch.tensor([.458, .448, .450])[None, :, None, None])

    def set_requires_grad(self, state: bool):
        """Set ``requires_grad`` to *state* on every parameter and buffer."""
        for param in chain(self.parameters(), self.buffers()):
            param.requires_grad = state

    def z_score(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``(x - mean) / std``, broadcast per channel."""
        return (x - self.mean) / self.std

    def forward(self, x: torch.Tensor):
        """Run *x* through ``layers`` and collect the requested activations.

        Args:
            x: input batch; it is standardised with :meth:`z_score` first.

        Returns:
            A list holding ``normalize_activation`` applied to the output of
            each layer whose 1-based position is in ``target_layers``, in
            network order. Iteration stops as soon as ``len(target_layers)``
            activations have been collected, so later layers are skipped.
        """
        x = self.z_score(x)

        targets = self.target_layers
        output = []
        # Use the container's public iterator rather than the private
        # ``_modules`` dict; nn.Sequential yields its children in order.
        for i, layer in enumerate(self.layers, 1):
            x = layer(x)
            if i in targets:
                output.append(normalize_activation(x))
                # Everything requested has been gathered: skip the rest.
                if len(output) == len(targets):
                    break
        return output
|
|
|
|
class SqueezeNet(BaseNet):
    """SqueezeNet 1.1 feature extractor with frozen ImageNet-pretrained weights."""

    def __init__(self):
        super(SqueezeNet, self).__init__()

        # Request pretrained weights through the ``weights=`` enum: passing
        # ``pretrained`` (especially positionally) is deprecated since
        # torchvision 0.13. Older torchvision has no ``*_Weights`` enums, so
        # fall back to the legacy keyword there.
        weights_enum = getattr(models, 'SqueezeNet1_1_Weights', None)
        if weights_enum is not None:
            backbone = models.squeezenet1_1(weights=weights_enum.IMAGENET1K_V1)
        else:
            backbone = models.squeezenet1_1(pretrained=True)
        self.layers = backbone.features
        # 1-based positions in ``self.layers`` whose outputs forward() collects.
        self.target_layers = [2, 5, 8, 10, 11, 12, 13]
        # Presumably the channel count of each collected activation (same
        # length and order as ``target_layers``).
        self.n_channels_list = [64, 128, 256, 384, 384, 512, 512]

        # Freeze all parameters and the mean/std buffers.
        self.set_requires_grad(False)
|
|
|
|
class AlexNet(BaseNet):
    """AlexNet feature extractor with frozen ImageNet-pretrained weights."""

    def __init__(self):
        super(AlexNet, self).__init__()

        # Request pretrained weights through the ``weights=`` enum: passing
        # ``pretrained`` (especially positionally) is deprecated since
        # torchvision 0.13. Older torchvision has no ``*_Weights`` enums, so
        # fall back to the legacy keyword there.
        weights_enum = getattr(models, 'AlexNet_Weights', None)
        if weights_enum is not None:
            backbone = models.alexnet(weights=weights_enum.IMAGENET1K_V1)
        else:
            backbone = models.alexnet(pretrained=True)
        self.layers = backbone.features
        # 1-based positions in ``self.layers`` whose outputs forward() collects.
        self.target_layers = [2, 5, 8, 10, 12]
        # Presumably the channel count of each collected activation (same
        # length and order as ``target_layers``).
        self.n_channels_list = [64, 192, 384, 256, 256]

        # Freeze all parameters and the mean/std buffers.
        self.set_requires_grad(False)
|
|
|
|
class VGG16(BaseNet):
    """VGG-16 feature extractor with frozen ImageNet-pretrained weights."""

    def __init__(self):
        super(VGG16, self).__init__()

        # Request pretrained weights through the ``weights=`` enum: passing
        # ``pretrained`` (especially positionally) is deprecated since
        # torchvision 0.13. Older torchvision has no ``*_Weights`` enums, so
        # fall back to the legacy keyword there.
        weights_enum = getattr(models, 'VGG16_Weights', None)
        if weights_enum is not None:
            backbone = models.vgg16(weights=weights_enum.IMAGENET1K_V1)
        else:
            backbone = models.vgg16(pretrained=True)
        self.layers = backbone.features
        # 1-based positions in ``self.layers`` whose outputs forward() collects.
        self.target_layers = [4, 9, 16, 23, 30]
        # Presumably the channel count of each collected activation (same
        # length and order as ``target_layers``).
        self.n_channels_list = [64, 128, 256, 512, 512]

        # Freeze all parameters and the mean/std buffers.
        self.set_requires_grad(False)