Spaces:
Runtime error
Runtime error
import torch | |
from torch import nn | |
from torchvision import models | |
from einops import rearrange | |
from torchvision.models._utils import IntermediateLayerGetter | |
class Vgg(nn.Module):
    """VGG-BN convolutional backbone that encodes an image batch as a feature sequence.

    The torchvision VGG feature extractor is reused, with every ``MaxPool2d``
    replaced by an ``AvgPool2d`` whose kernel size and stride come from ``ks``
    and ``ss`` (consumed in network order). A dropout layer and a 1x1
    convolution then project the 512 backbone channels to ``hidden`` channels.

    Args:
        name: Backbone variant, ``'vgg11_bn'`` or ``'vgg19_bn'``.
        ss: Per-pooling-layer strides, indexed in the order the pools appear.
        ks: Per-pooling-layer kernel sizes, indexed in the order the pools appear.
        hidden: Number of output channels of the final 1x1 projection.
        pretrained: When True, load torchvision's default pretrained weights;
            when False, build the backbone with random initialisation.
        dropout: Dropout probability applied before the 1x1 projection.

    Raises:
        ValueError: If ``name`` is not a supported variant, or if ``ks`` / ``ss``
            have fewer entries than the backbone has pooling layers.
    """

    def __init__(self, name, ss, ks, hidden, pretrained=True, dropout=0.5):
        super().__init__()
        if name == 'vgg11_bn':
            builder = models.vgg11_bn
        elif name == 'vgg19_bn':
            builder = models.vgg19_bn
        else:
            # Without this branch `cnn` would be unbound and fail obscurely below.
            raise ValueError(
                f"Unsupported backbone {name!r}; expected 'vgg11_bn' or 'vgg19_bn'"
            )

        # Honour `pretrained`: weights=None yields a randomly initialised model
        # and avoids downloading the pretrained checkpoint.
        cnn = builder(weights='DEFAULT' if pretrained else None)
        features = cnn.features

        # Locate the pools first so the Sequential is not mutated while iterating it.
        pool_positions = [
            i for i, layer in enumerate(features) if isinstance(layer, nn.MaxPool2d)
        ]
        if len(ks) < len(pool_positions) or len(ss) < len(pool_positions):
            raise ValueError(
                f"{name} has {len(pool_positions)} pooling layers but got "
                f"{len(ks)} kernel sizes and {len(ss)} strides"
            )
        for pool_idx, layer_idx in enumerate(pool_positions):
            features[layer_idx] = nn.AvgPool2d(
                kernel_size=ks[pool_idx], stride=ss[pool_idx], padding=0
            )

        self.features = features
        self.dropout = nn.Dropout(dropout)
        # 512 is the channel count of the last conv stage of torchvision's VGG variants.
        self.last_conv_1x1 = nn.Conv2d(512, hidden, 1)

    def forward(self, x):
        """Encode an image batch as a sequence of spatial feature vectors.

        Shape:
            - x: (N, C, H, W)
            - output: (W' * H', N, hidden), where H' and W' are the feature-map
              height and width after the backbone. The sequence is ordered
              width-major: all H' rows of column 0, then column 1, and so on.
        """
        conv = self.features(x)
        conv = self.dropout(conv)
        conv = self.last_conv_1x1(conv)
        # (N, hidden, H', W') -> (N, hidden, W', H') so flattening is width-major.
        conv = conv.transpose(-1, -2)
        conv = conv.flatten(2)
        # (N, hidden, W'*H') -> (W'*H', N, hidden): sequence-first layout.
        conv = conv.permute(-1, 0, 1)
        return conv
def vgg11_bn(ss, ks, hidden, pretrained=True, dropout=0.5):
    """Construct a ``Vgg`` backbone using the ``vgg11_bn`` variant.

    Every argument is forwarded unchanged to ``Vgg``.
    """
    backbone = Vgg('vgg11_bn', ss, ks, hidden, pretrained=pretrained, dropout=dropout)
    return backbone
def vgg19_bn(ss, ks, hidden, pretrained=True, dropout=0.5):
    """Construct a ``Vgg`` backbone using the ``vgg19_bn`` variant.

    Every argument is forwarded unchanged to ``Vgg``.
    """
    backbone = Vgg('vgg19_bn', ss, ks, hidden, pretrained=pretrained, dropout=dropout)
    return backbone