# caption-gen / encoder.py
# (provenance, preserved from the original file page: author "Sher1988",
#  commit "Change structure of the project.", revision eb55711)
from torchvision.models import resnet50, ResNet50_Weights
import torch.nn as nn
class EncoderCNN(nn.Module):
    """Image encoder: a frozen ImageNet-pretrained ResNet-50 backbone whose
    pooled features are projected to a fixed-size embedding.

    The backbone ALWAYS loads pretrained weights; ``fine_tune`` only controls
    whether the deepest residual stage (``layer4``) is trainable.
    """

    def __init__(self, embed_size, fine_tune=False):
        """
        Args:
            embed_size (int): dimensionality of the output image embedding.
            fine_tune (bool): if True, unfreeze ``layer4`` so it is updated
                during training; all earlier layers remain frozen.
        """
        super().__init__()
        # BUG FIX: previously weights=DEFAULT was loaded only when
        # fine_tune=True, so the default configuration froze a *randomly
        # initialized* backbone. Pretrained weights are now loaded
        # unconditionally, which is the intended behavior.
        resnet = resnet50(weights=ResNet50_Weights.DEFAULT)

        # Freeze the entire backbone by default.
        for param in resnet.parameters():
            param.requires_grad = False
        if fine_tune:
            # Unfreeze only the last residual stage for fine-tuning.
            for param in resnet.layer4.parameters():
                param.requires_grad = True

        # Drop the final classification head (resnet.fc); the remaining
        # stages end with global average pooling -> (B, 2048, 1, 1).
        backbone = list(resnet.children())[:-1]
        self.resnet = nn.Sequential(*backbone)
        # Trainable projection from the 2048 backbone features to embed_size,
        # followed by batch norm over the embedding dimension.
        self.fc = nn.Linear(resnet.fc.in_features, embed_size)
        self.bn = nn.BatchNorm1d(num_features=embed_size, momentum=0.01)

    def forward(self, images):
        """Encode a batch of images.

        Args:
            images: float tensor of shape (B, 3, H, W).

        Returns:
            Tensor of shape (B, embed_size).
        """
        features = self.resnet(images)  # (B, 2048, 1, 1)
        # The flatten IS required: nn.Linear operates on the last dimension,
        # which is 1 here until the spatial dims are collapsed into 2048.
        features = features.reshape(features.shape[0], -1)  # (B, 2048)
        return self.bn(self.fc(features))  # (B, embed_size)