# ImageCaptionDemo / model.py
import torch
import torch.nn as nn
import torchvision.models as models
#----------------------------------------------------------------------------
class EncoderCNN(nn.Module):
    def __init__(self, embed_size) -> None:
        super().__init__()
        # Pretrained Inception v3 backbone (legacy torchvision API); expects 299x299 inputs.
        self.inception = models.inception_v3(pretrained=True, aux_logits=False)
        # Freeze the backbone; only the replacement fc layer below stays trainable.
        for param in self.inception.parameters():
            param.requires_grad = False
        # Replace the classifier head with a projection to the embedding size.
        self.inception.fc = nn.Linear(self.inception.fc.in_features, embed_size)
        self.relu = nn.ReLU(True)
        self.dropout = nn.Dropout(0.5)

    def forward(self, imgs):
        # imgs: (batch, 3, 299, 299) -> features: (batch, embed_size)
        features = self.inception(imgs)
        return self.dropout(self.relu(features))
#----------------------------------------------------------------------------
class DecoderRNN(nn.Module):
    def __init__(self, embed_size, hidden_size, vocab_size, num_layers) -> None:
        super().__init__()
        self.embed = nn.Embedding(vocab_size, embed_size)
        # batch_first=False: the LSTM expects inputs of shape (seq_len, batch, embed_size).
        self.LSTM = nn.LSTM(embed_size, hidden_size, num_layers)
        self.linear = nn.Linear(hidden_size, vocab_size)
        self.dropout = nn.Dropout(0.5)

    def forward(self, features, captions):
        # captions: (seq_len, batch) -> embeddings: (seq_len, batch, embed_size)
        embeddings = self.dropout(self.embed(captions))
        # unsqueeze(0) adds the time dimension (seq_len), so the image features
        # act as the first "token" of the sequence.
        embeddings = torch.cat([features.unsqueeze(0), embeddings], dim=0)
        hiddens, _ = self.LSTM(embeddings)
        # outputs: (seq_len + 1, batch, vocab_size)
        outputs = self.linear(hiddens)
        return outputs
#----------------------------------------------------------------------------
class CNNtoRNN(nn.Module):
    def __init__(self, embed_size, hidden_size, vocab_size, num_layers) -> None:
        super().__init__()
        self.encoderCNN = EncoderCNN(embed_size)
        self.decoderRNN = DecoderRNN(embed_size, hidden_size, vocab_size, num_layers)

    def forward(self, imgs, captions):
        features = self.encoderCNN(imgs)
        outputs = self.decoderRNN(features, captions)
        return outputs

    def caption_image(self, img, vocab, max_length=50):
        result_caption = []
        with torch.no_grad():
            # Encode the image and add the time dimension: (1, batch, embed_size).
            x = self.encoderCNN(img).unsqueeze(0)
            states = None
            for _ in range(max_length):
                # Greedy decoding: predict one token at a time.
                h, states = self.decoderRNN.LSTM(x, states)
                output = self.decoderRNN.linear(h.squeeze(0))
                predicted = output.argmax(1)
                result_caption.append(predicted.item())
                # Feed the predicted token back in as the input for the next step.
                x = self.decoderRNN.embed(predicted).unsqueeze(0)
                if vocab.itos[predicted.item()] == '<EOS>':
                    break
        return [vocab.itos[idx] for idx in result_caption]
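#----------------------------------------------------------------------------
# Minimal smoke-test sketch (not part of the original demo): runs one forward
# pass and one greedy captioning pass with random tensors. The DummyVocab
# class, its itos/stoi mappings, and all sizes below are illustrative
# assumptions; the real app supplies its own vocabulary and trained weights.
# Note: constructing EncoderCNN downloads pretrained Inception v3 weights.
if __name__ == "__main__":
    class DummyVocab:
        # Hypothetical stand-in for the project's vocabulary object.
        def __init__(self):
            self.itos = {0: '<PAD>', 1: '<SOS>', 2: '<EOS>', 3: '<UNK>'}
            self.stoi = {tok: idx for idx, tok in self.itos.items()}

        def __len__(self):
            return len(self.itos)

    vocab = DummyVocab()
    model = CNNtoRNN(embed_size=256, hidden_size=256,
                     vocab_size=len(vocab), num_layers=1)
    model.eval()  # BatchNorm/Dropout in deterministic mode for the shape check

    imgs = torch.randn(2, 3, 299, 299)                # Inception v3 expects 299x299 inputs
    captions = torch.randint(0, len(vocab), (10, 2))  # (seq_len, batch)

    with torch.no_grad():
        outputs = model(imgs, captions)
    print(outputs.shape)  # expected: (11, 2, len(vocab)) -- image feature + 10 caption steps

    caption = model.caption_image(torch.randn(1, 3, 299, 299), vocab)
    print(caption)  # list of tokens from the (untrained) greedy decoder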