| | from typing import Sequence |
| |
|
| | from itertools import chain |
| |
|
| | import torch |
| | import torch.nn as nn |
| | from torchvision import models |
| |
|
| | from .utils import normalize_activation |
| |
|
| |
|
def get_network(net_type: str):
    """Build the backbone feature network selected by ``net_type``.

    Args:
        net_type: one of ``'alex'``, ``'squeeze'`` or ``'vgg'``.

    Returns:
        A new ``AlexNet``, ``SqueezeNet`` or ``VGG16`` instance respectively.

    Raises:
        NotImplementedError: if ``net_type`` is not one of the names above.
    """
    # Reject unknown names up front so the dispatch below only handles
    # supported backbones.
    if net_type not in ('alex', 'squeeze', 'vgg'):
        raise NotImplementedError('choose net_type from [alex, squeeze, vgg].')

    if net_type == 'alex':
        return AlexNet()
    if net_type == 'squeeze':
        return SqueezeNet()
    return VGG16()
| |
|
| |
|
class LinLayers(nn.ModuleList):
    """Frozen stack of per-feature linear heads.

    For every entry ``n`` in ``n_channels_list`` one head is created: an
    ``nn.Sequential`` holding an ``nn.Identity`` followed by a bias-free
    1x1 ``nn.Conv2d`` that maps ``n`` input channels to a single output
    channel. All parameters are frozen (``requires_grad = False``).
    """

    def __init__(self, n_channels_list: Sequence[int]):
        heads = []
        for n_channels in n_channels_list:
            # NOTE(review): the leading Identity presumably keeps parameter
            # names as ``<i>.1.weight`` so saved head weights load by name
            # -- confirm before removing it.
            heads.append(nn.Sequential(
                nn.Identity(),
                nn.Conv2d(n_channels, 1, kernel_size=1, stride=1, padding=0,
                          bias=False),
            ))
        super().__init__(heads)

        # Freeze every head weight.
        self.requires_grad_(False)
| |
|
| |
|
class BaseNet(nn.Module):
    """Common base for the frozen backbone feature extractors.

    Subclasses are expected to set:

    * ``self.layers`` -- an iterable container of modules applied in order
      (every subclass in this file assigns a torchvision ``.features``
      ``nn.Sequential``);
    * ``self.target_layers`` -- 1-based indices into ``self.layers`` whose
      outputs :meth:`forward` collects.
    """

    def __init__(self):
        super(BaseNet, self).__init__()

        # Per-channel shift/scale used by z_score(). The [None, :, None, None]
        # indexing gives shape (1, 3, 1, 1) so they broadcast over dim 1 of
        # the input. NOTE(review): these look like the LPIPS "ScalingLayer"
        # constants, which assume inputs in [-1, 1] -- confirm.
        # torch.tensor() is used instead of the legacy torch.Tensor(...)
        # constructor.
        self.register_buffer(
            'mean', torch.tensor([-.030, -.088, -.188])[None, :, None, None])
        self.register_buffer(
            'std', torch.tensor([.458, .448, .450])[None, :, None, None])

    def set_requires_grad(self, state: bool):
        """Set ``requires_grad = state`` on every parameter AND buffer."""
        for param in chain(self.parameters(), self.buffers()):
            param.requires_grad = state

    def z_score(self, x: torch.Tensor):
        """Return ``(x - mean) / std``, broadcast over channel dim 1.

        ``x`` is assumed to be (N, 3, H, W) -- TODO confirm with callers.
        """
        return (x - self.mean) / self.std

    def forward(self, x: torch.Tensor):
        """Collect normalized activations at ``self.target_layers``.

        ``x`` is z-scored, then passed through ``self.layers`` one module at
        a time. After the i-th module (1-based) runs, its output is passed
        through ``normalize_activation`` and collected if ``i`` is a target.
        Iteration stops once every target has been collected, so trailing
        layers are never evaluated.

        Returns:
            A list with one entry per distinct target index, in layer order.
        """
        x = self.z_score(x)

        # A set gives O(1) membership tests and lets the early exit fire
        # even if target_layers contains duplicate indices.
        targets = set(self.target_layers)
        output = []
        # Iterate the container directly rather than the private
        # nn.Module._modules mapping (its keys were never used).
        for i, layer in enumerate(self.layers, 1):
            x = layer(x)
            if i in targets:
                output.append(normalize_activation(x))
            if len(output) == len(targets):
                break
        return output
| |
|
| |
|
class SqueezeNet(BaseNet):
    """Frozen SqueezeNet 1.1 feature extractor with ImageNet weights.

    ``target_layers`` holds the 1-based indices of the ``features`` modules
    whose outputs are collected. ``n_channels_list`` is presumably the
    channel count of each collected activation, in the same order -- confirm.
    """

    def __init__(self):
        super(SqueezeNet, self).__init__()

        # Request pretrained weights through the explicit weights enum, as
        # VGG16 does. The positional ``pretrained`` flag
        # (``squeezenet1_1(True)``) is deprecated since torchvision 0.13 and
        # resolves to these same IMAGENET1K_V1 weights.
        self.layers = models.squeezenet1_1(
            weights=models.SqueezeNet1_1_Weights.IMAGENET1K_V1).features
        self.target_layers = [2, 5, 8, 10, 11, 12, 13]
        self.n_channels_list = [64, 128, 256, 384, 384, 512, 512]

        # Freeze all parameters and buffers.
        self.set_requires_grad(False)
| |
|
| |
|
class AlexNet(BaseNet):
    """Frozen AlexNet feature extractor with ImageNet weights.

    ``target_layers`` holds the 1-based indices of the ``features`` modules
    whose outputs are collected. ``n_channels_list`` is presumably the
    channel count of each collected activation, in the same order -- confirm.
    """

    def __init__(self):
        super(AlexNet, self).__init__()

        # Request pretrained weights through the explicit weights enum, as
        # VGG16 does. The positional ``pretrained`` flag (``alexnet(True)``)
        # is deprecated since torchvision 0.13 and resolves to these same
        # IMAGENET1K_V1 weights.
        self.layers = models.alexnet(
            weights=models.AlexNet_Weights.IMAGENET1K_V1).features
        self.target_layers = [2, 5, 8, 10, 12]
        self.n_channels_list = [64, 192, 384, 256, 256]

        # Freeze all parameters and buffers.
        self.set_requires_grad(False)
| |
|
| |
|
class VGG16(BaseNet):
    """Frozen VGG-16 feature extractor loaded with IMAGENET1K_V1 weights.

    ``target_layers`` holds the 1-based indices of the ``features`` modules
    whose outputs are collected. ``n_channels_list`` is presumably the
    channel count of each collected activation, in the same order -- confirm.
    """

    def __init__(self):
        super().__init__()

        pretrained = models.VGG16_Weights.IMAGENET1K_V1
        self.layers = models.vgg16(weights=pretrained).features

        self.target_layers = [4, 9, 16, 23, 30]
        self.n_channels_list = [64, 128, 256, 512, 512]

        # Freeze all parameters and buffers.
        self.set_requires_grad(False)
| |
|