Spaces:
Runtime error
Runtime error
from torchvision.models import resnet50, ResNet50_Weights
import torch.nn as nn
class EncoderCNN(nn.Module):
    """CNN encoder mapping images to fixed-size embeddings.

    An ImageNet-pretrained ResNet-50 backbone (classification head removed)
    produces 2048-d globally-pooled features, which are projected to
    ``embed_size`` and batch-normalized.

    Args:
        embed_size: Dimension of the output image embedding.
        fine_tune: If True, the deepest residual stage (``layer4``) stays
            trainable; all earlier layers are always frozen.
    """

    def __init__(self, embed_size, fine_tune=False):
        super().__init__()
        # Always load pretrained weights. Previously they were loaded only
        # when fine_tune=True, so fine_tune=False produced a *frozen,
        # randomly-initialized* backbone whose features are meaningless and
        # can never be trained — a bug.
        resnet = resnet50(weights=ResNet50_Weights.DEFAULT)
        for param in resnet.parameters():
            param.requires_grad = False
        if fine_tune:
            # Unfreeze only the last stage; earlier stages keep their
            # generic low-level pretrained features.
            for param in resnet.layer4.parameters():
                param.requires_grad = True
        # Drop the final fc classifier; keep conv stages + global avg pool.
        self.resnet = nn.Sequential(*list(resnet.children())[:-1])
        self.fc = nn.Linear(resnet.fc.in_features, embed_size)
        self.bn = nn.BatchNorm1d(num_features=embed_size, momentum=0.01)

    def forward(self, images):
        """Encode a batch of images.

        Args:
            images: Tensor of shape (B, 3, H, W).

        Returns:
            Tensor of shape (B, embed_size).
        """
        features = self.resnet(images)       # (B, 2048, 1, 1)
        # Flatten IS required: nn.Linear operates on the last dimension,
        # which must equal in_features (2048), not (…, 1, 1).
        features = features.flatten(1)       # (B, 2048)
        return self.bn(self.fc(features))    # (B, embed_size)