import torch
import torch.nn as nn
import torchvision.models as models


class EncoderCNN(nn.Module):
    def __init__(self, embed_size):
        super().__init__()
        # Load a ResNet-50 pre-trained on ImageNet and freeze its weights;
        # only the new embedding layer below is trained.
        # (On torchvision >= 0.13, `pretrained=True` is deprecated in favor
        # of `weights=models.ResNet50_Weights.DEFAULT`.)
        resnet = models.resnet50(pretrained=True)
        for param in resnet.parameters():
            param.requires_grad_(False)

        # Drop the final fully connected layer, keeping the convolutional
        # backbone and average pooling as a fixed feature extractor.
        modules = list(resnet.children())[:-1]
        self.resnet = nn.Sequential(*modules)
        self.embed = nn.Linear(resnet.fc.in_features, embed_size)

    def forward(self, images):
        features = self.resnet(images)                   # (batch, 2048, 1, 1)
        features = features.view(features.size(0), -1)   # flatten to (batch, 2048)
        features = self.embed(features)                  # project to (batch, embed_size)
        return features
    

class DecoderRNN(nn.Module):
    def __init__(self, embed_size, hidden_size, vocab_size, num_layers=1):
        super().__init__()
        # Word embedding lookup, an LSTM over the embedded sequence, and a
        # linear layer projecting hidden states to vocabulary scores.
        # (The original kept a hard-coded `self.hidden` zero state sized for
        # batch 1 and a single layer; it was never read, so it is dropped.)
        self.embed = nn.Embedding(vocab_size, embed_size)
        self.lstm = nn.LSTM(embed_size, hidden_size, num_layers, batch_first=True)
        self.linear = nn.Linear(hidden_size, vocab_size)
    
    def forward(self, features, captions):
        # Drop the final token so the output sequence lines up with the
        # targets, then embed the caption tokens.
        cap_embedding = self.embed(captions[:, :-1])
        # Prepend the image feature vector as the first input step.
        embeddings = torch.cat((features.unsqueeze(1), cap_embedding), 1)
        # The LSTM defaults to a zero hidden state when none is supplied.
        lstm_out, _ = self.lstm(embeddings)
        outputs = self.linear(lstm_out)                  # (batch, seq_len, vocab_size)
        return outputs


    def sample(self, inputs, hidden=None, max_len=20):
        """Greedy decoding: accepts a pre-processed image feature tensor of
        shape (1, 1, embed_size) and returns a predicted sentence as a list
        of max_len token ids. Assumes a batch size of 1."""
        res = []
        for _ in range(max_len):
            outputs, hidden = self.lstm(inputs, hidden)    # (1, 1, hidden_size)
            outputs = self.linear(outputs.squeeze(1))      # (1, vocab_size)
            target_index = outputs.argmax(1)               # most likely next token
            res.append(target_index.item())
            # Feed the predicted token back in as the next input step.
            inputs = self.embed(target_index).unsqueeze(1) # (1, 1, embed_size)
        return res
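

# --- Usage sketch (not part of the original model file) ---
# A minimal, hedged example of wiring the encoder and decoder together so the
# tensor shapes are easy to verify. The embed_size, hidden_size, vocab_size,
# batch size, and caption length below are illustrative assumptions, as are
# the random tensors standing in for real images and tokenized captions.
if __name__ == "__main__":
    embed_size, hidden_size, vocab_size = 256, 512, 1000
    encoder = EncoderCNN(embed_size)
    decoder = DecoderRNN(embed_size, hidden_size, vocab_size)
    encoder.eval()  # keep the frozen ResNet's batch-norm statistics fixed

    images = torch.randn(4, 3, 224, 224)                 # batch of 4 RGB images
    captions = torch.randint(0, vocab_size, (4, 12))     # batch of 12-token captions

    features = encoder(images)                           # (4, embed_size)
    outputs = decoder(features, captions)                # (4, 12, vocab_size)
    print(outputs.shape)

    # Greedy sampling expects a single feature shaped (1, 1, embed_size).
    sampled_ids = decoder.sample(features[0:1].unsqueeze(1))
    print(sampled_ids)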