# image_to_text / model.py
# Author: ovi054 — "first commit" (00abfdc)
# CNN encoder + LSTM decoder for image captioning.
import torch
import torch.nn as nn
import torchvision.models as models
class EncoderCNN(nn.Module):
    """Extract a fixed-size feature embedding from a batch of images.

    A pretrained ResNet-50 backbone (classifier head removed, weights
    frozen) feeds a single trainable linear layer that projects the
    pooled features down to ``embed_size``.
    """

    def __init__(self, embed_size):
        super(EncoderCNN, self).__init__()
        # NOTE(review): `pretrained=True` is deprecated in newer
        # torchvision in favor of `weights=...` — confirm the installed
        # version before modernizing.
        backbone = models.resnet50(pretrained=True)
        # Freeze the backbone; only the projection layer is trained.
        for p in backbone.parameters():
            p.requires_grad_(False)
        # Drop the final FC layer, keep conv stages + global avg-pool.
        self.resnet = nn.Sequential(*list(backbone.children())[:-1])
        self.embed = nn.Linear(backbone.fc.in_features, embed_size)

    def forward(self, images):
        """Map an image batch to embeddings of shape (batch, embed_size)."""
        pooled = self.resnet(images)
        flat = pooled.view(pooled.size(0), -1)
        return self.embed(flat)
class DecoderRNN(nn.Module):
    """LSTM decoder that generates caption token scores from image features.

    Training uses teacher forcing via :meth:`forward`; inference uses the
    greedy :meth:`sample` loop.

    Args:
        embed_size: dimensionality of word (and image-feature) embeddings.
        hidden_size: LSTM hidden-state size.
        vocab_size: number of tokens in the caption vocabulary.
        num_layers: number of stacked LSTM layers (default 1).
    """

    def __init__(self, embed_size, hidden_size, vocab_size, num_layers=1):
        super(DecoderRNN, self).__init__()
        self.hidden_dim = hidden_size
        self.embed = nn.Embedding(vocab_size, embed_size)
        self.lstm = nn.LSTM(embed_size, hidden_size, num_layers,
                            batch_first=True)
        self.linear = nn.Linear(hidden_size, vocab_size)
        # Holds the last (h, c) pair produced by forward(). The previous
        # code pre-filled this with zeros of hard-coded shape
        # (1, 1, hidden_size) — wrong for num_layers > 1 or batch > 1 —
        # and never actually fed it to the LSTM, so None is the honest
        # initial value.
        self.hidden = None

    def forward(self, features, captions):
        """Teacher-forced decoding.

        Args:
            features: image embeddings of shape (batch, embed_size).
            captions: token-id tensor of shape (batch, seq_len); the last
                token is dropped so the image feature can act as the first
                input step.

        Returns:
            Unnormalized vocabulary scores of shape
            (batch, seq_len, vocab_size).
        """
        cap_embedding = self.embed(captions[:, :-1])
        # Prepend the image feature as the first "word" of the sequence.
        embeddings = torch.cat((features.unsqueeze(1), cap_embedding), 1)
        # Default zero initial state — identical to the original behavior,
        # which also never passed an explicit state here.
        lstm_out, self.hidden = self.lstm(embeddings)
        return self.linear(lstm_out)

    def sample(self, inputs, hidden=None, max_len=20):
        """Greedily decode a caption from a pre-processed image tensor.

        Args:
            inputs: feature tensor of shape (1, 1, embed_size).
            hidden: optional initial LSTM state; defaults to zeros.
            max_len: number of tokens to generate.

        Returns:
            List of ``max_len`` predicted token ids (Python ints).
        """
        res = []
        for _ in range(max_len):
            outputs, hidden = self.lstm(inputs, hidden)
            outputs = self.linear(outputs.squeeze(1))
            target_index = outputs.max(1)[1]
            # NOTE(review): .item() assumes batch size 1 — confirm callers
            # never pass a batched input here.
            res.append(target_index.item())
            # Feed the predicted token back in as the next input step.
            inputs = self.embed(target_index).unsqueeze(1)
        return res