# Spaces: Build error  (appears to be status text captured from the hosting page; not part of the module)
import torch
import torch.nn as nn
import torchvision.models as models
class EncoderCNN(nn.Module):
    """Image encoder: a frozen ResNet-50 backbone followed by a trainable linear projection.

    The backbone is every child module of ``resnet50`` except the last one
    (its ``fc`` layer). All backbone parameters have ``requires_grad`` set to
    ``False``; only ``self.embed`` is left trainable.

    Args:
        embed_size: Output dimension of the final linear projection.
    """

    def __init__(self, embed_size):
        super().__init__()
        # NOTE: newer torchvision releases deprecate ``pretrained=True`` in
        # favour of the ``weights=`` argument; kept as-is for compatibility
        # with the torchvision version this file was written against.
        resnet = models.resnet50(pretrained=True)
        for param in resnet.parameters():
            param.requires_grad_(False)

        # Drop the final child (the classification ``fc`` layer) and keep the rest.
        backbone_layers = list(resnet.children())[:-1]
        self.resnet = nn.Sequential(*backbone_layers)
        # Project the backbone output (``fc.in_features`` wide) to ``embed_size``.
        self.embed = nn.Linear(resnet.fc.in_features, embed_size)
        # NOTE(review): freezing parameters does not stop BatchNorm layers from
        # updating running statistics in train mode -- confirm whether callers
        # put the encoder in eval mode during training.

    def forward(self, images):
        """Encode a batch of images.

        Args:
            images: Input batch passed straight to the ResNet backbone.

        Returns:
            Tensor of shape ``(batch, embed_size)``.
        """
        features = self.resnet(images)
        # Collapse everything after the batch dimension into one vector per image.
        features = features.flatten(start_dim=1)
        return self.embed(features)
class DecoderRNN(nn.Module):
    """LSTM caption decoder: embeds tokens, runs an LSTM, and scores the vocabulary.

    Args:
        embed_size: Dimension of the token embeddings (and of the image
            feature vector fed in as the first time step).
        hidden_size: LSTM hidden-state dimension.
        vocab_size: Number of tokens in the vocabulary (embedding rows and
            output scores).
        num_layers: Number of stacked LSTM layers.

    Attributes:
        hidden_dim: Same value as ``hidden_size``.
        hidden: ``(h, c)`` tuple. Zeros until the first ``forward`` call, then
            the detached final LSTM state of the most recent ``forward`` call.
            It is never fed back into the LSTM by this class.
    """

    def __init__(self, embed_size, hidden_size, vocab_size, num_layers=1):
        super().__init__()
        self.hidden_dim = hidden_size
        self.embed = nn.Embedding(vocab_size, embed_size)
        self.lstm = nn.LSTM(embed_size, hidden_size, num_layers, batch_first=True)
        self.linear = nn.Linear(hidden_size, vocab_size)
        # Placeholder state; first dimension follows ``num_layers`` so its
        # layout matches what the LSTM itself returns.
        self.hidden = (
            torch.zeros(num_layers, 1, hidden_size),
            torch.zeros(num_layers, 1, hidden_size),
        )

    def forward(self, features, captions):
        """Score the vocabulary at every time step (teacher forcing).

        Args:
            features: Image features, shape ``(batch, embed_size)``.
            captions: Token ids, shape ``(batch, seq_len)``.

        Returns:
            Tensor of shape ``(batch, seq_len, vocab_size)``.
        """
        # The image feature occupies the first time step, so the last caption
        # token is dropped to keep the sequence length equal to ``seq_len``.
        token_embeddings = self.embed(captions[:, :-1])
        embeddings = torch.cat((features.unsqueeze(1), token_embeddings), dim=1)

        lstm_out, state = self.lstm(embeddings)
        # Keep the final state available on the instance, but detached, so the
        # autograd graph of this batch is not kept alive after the call.
        self.hidden = tuple(s.detach() for s in state)

        return self.linear(lstm_out)

    def sample(self, inputs, hidden=None, max_len=20, end_token=None):
        """Greedily decode a caption from a pre-processed image tensor.

        Args:
            inputs: First LSTM input; expected shape ``(1, 1, embed_size)``
                (a single sequence -- each predicted id is read with ``.item()``).
            hidden: Optional initial ``(h, c)`` LSTM state; ``None`` starts from zeros.
            max_len: Maximum number of tokens to generate.
            end_token: Optional token id that terminates decoding early. The
                end token itself is included in the result. ``None`` (default)
                always generates exactly ``max_len`` tokens.

        Returns:
            List of predicted token ids (Python ints), at most ``max_len`` long.
        """
        res = []
        # Inference only: no autograd graph is needed for greedy decoding.
        with torch.no_grad():
            for _ in range(max_len):
                outputs, hidden = self.lstm(inputs, hidden)
                scores = self.linear(outputs.squeeze(1))
                _, target_index = scores.max(dim=1)

                token_id = target_index.item()
                res.append(token_id)
                if end_token is not None and token_id == end_token:
                    break

                # Feed the predicted token back in as the next input step.
                inputs = self.embed(target_index).unsqueeze(1)
        return res