import torch
import torch.nn as nn
import torchvision.models as models
#----------------------------------------------------------------------------
class EncoderCNN(nn.Module):
    def __init__(self, embed_size) -> None:
        super().__init__()
        # Pretrained Inception v3 as a frozen feature extractor.
        # Note: torchvision >= 0.13 deprecates `pretrained=` in favor of `weights=`.
        self.inception = models.inception_v3(pretrained=True, aux_logits=False)
        for param in self.inception.parameters():
            param.requires_grad = False
        # Replace the classifier head; this new layer is the only trainable part.
        self.inception.fc = nn.Linear(self.inception.fc.in_features, embed_size)
        self.relu = nn.ReLU(inplace=True)
        self.dropout = nn.Dropout(0.5)

    def forward(self, imgs):
        features = self.inception(imgs)
        return self.dropout(self.relu(features))
#----------------------------------------------------------------------------
class DecoderRNN(nn.Module):
    def __init__(self, embed_size, hidden_size, vocab_size, num_layers) -> None:
        super().__init__()
        self.embed = nn.Embedding(vocab_size, embed_size)
        self.LSTM = nn.LSTM(embed_size, hidden_size, num_layers)
        self.linear = nn.Linear(hidden_size, vocab_size)
        self.dropout = nn.Dropout(0.5)

    def forward(self, features, captions):
        embeddings = self.dropout(self.embed(captions))
        # unsqueeze(0) adds a time dimension (seq_len), so the image features
        # become the first step of the input sequence.
        embeddings = torch.cat([features.unsqueeze(0), embeddings], dim=0)
        hiddens, _ = self.LSTM(embeddings)
        outputs = self.linear(hiddens)
        return outputs
#----------------------------------------------------------------------------
class CNNtoRNN(nn.Module):
    def __init__(self, embed_size, hidden_size, vocab_size, num_layers) -> None:
        super().__init__()
        self.encoderCNN = EncoderCNN(embed_size)
        self.decoderRNN = DecoderRNN(embed_size, hidden_size, vocab_size, num_layers)

    def forward(self, imgs, captions):
        features = self.encoderCNN(imgs)
        outputs = self.decoderRNN(features, captions)
        return outputs

    def caption_image(self, img, vocab, max_length=50):
        result_caption = []
        with torch.no_grad():
            x = self.encoderCNN(img).unsqueeze(0)
            states = None
            for _ in range(max_length):
                # Predict one token at a time, carrying the LSTM state forward.
                h, states = self.decoderRNN.LSTM(x, states)
                output = self.decoderRNN.linear(h.squeeze(0))
                predicted = output.argmax(1)
                result_caption.append(predicted.item())
                # Feed the predicted token back in as the next input.
                x = self.decoderRNN.embed(predicted).unsqueeze(0)
                if vocab.itos[predicted.item()] == '<EOS>':
                    break
        return [vocab.itos[idx] for idx in result_caption]
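#----------------------------------------------------------------------------
if __name__ == '__main__':
    # Minimal smoke test -- a sketch for illustration, not part of the original
    # pipeline. The hyperparameters, the 299x299 input size, and DummyVocab
    # below are assumptions; the real vocabulary object is only assumed to
    # expose an `itos` index-to-string mapping.
    embed_size, hidden_size, vocab_size, num_layers = 256, 256, 1000, 1
    model = CNNtoRNN(embed_size, hidden_size, vocab_size, num_layers)
    model.eval()  # eval mode so Inception's batchnorm accepts small batches

    # Teacher-forced forward pass: captions are laid out (seq_len, batch).
    imgs = torch.randn(2, 3, 299, 299)  # Inception v3 expects 299x299 inputs
    captions = torch.randint(0, vocab_size, (20, 2))
    with torch.no_grad():
        out = model(imgs, captions)
    # One extra time step comes from prepending the image features.
    print(out.shape)  # expected: torch.Size([21, 2, 1000])

    # Greedy decoding with a hypothetical stand-in vocabulary.
    class DummyVocab:
        itos = {i: ('<EOS>' if i == 0 else f'tok{i}') for i in range(1000)}

    img = torch.randn(1, 3, 299, 299)
    print(model.caption_image(img, DummyVocab()))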