leonhardt committed on
Commit
ae931ab
1 Parent(s): f94bfc7
Files changed (6)
  1. app.py +51 -0
  2. captions.txt +0 -0
  3. get_loader.py +141 -0
  4. model.py +72 -0
  5. requirements.txt +2 -0
  6. utils.py +69 -0
app.py ADDED
@@ -0,0 +1,51 @@
+ import gradio as gr
+ import pandas as pd
+ import torch
+ from torchvision.transforms import transforms
+
+ from get_loader import Vocabulary
+ from model import CNNtoRNN
+
+
+ def predict(img):
+     img = transform(img)
+     output = model.caption_image(img.unsqueeze(0).to(device), vocab)
+     return " ".join(output[1:-1])
+
+ if __name__ == '__main__':
+     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+     model = CNNtoRNN(
+         embed_size=256,
+         hidden_size=256,
+         vocab_size=2994,
+         num_layers=1
+     )
+     print('Loading model weights...')
+     model.load_state_dict(torch.load(
+         'my_checkpoint.pth.tar', map_location=device)["state_dict"])
+     model.to(device)
+     model.eval()
+     print('Building vocabulary...')
+     vocab = Vocabulary(5)
+     df = pd.read_csv('captions.txt')
+     vocab.build_vocabulary(df['caption'].tolist())
+
+     transform = transforms.Compose(
+         [
+             transforms.ToPILImage(),
+             transforms.Resize((299, 299)),
+             transforms.ToTensor(),
+             transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
+         ]
+     )
+     print('Creating app...')
+     app = gr.Interface(
+         fn=predict,
+         inputs=gr.Image(shape=(256, 256)),
+         outputs="text",
+     )
+     print('done!')
+
+     app.launch(
+         share=True,
+     )
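Note that `gr.Image(shape=...)` is Gradio 3.x API; in Gradio 4.x the `shape` argument was removed. A minimal sketch of the same interface without it, assuming the `predict` function above is in scope (its torchvision transform already resizes to 299x299, so no shape hint is needed):

import gradio as gr

# Hedged sketch, not part of this commit: the equivalent interface under
# Gradio 4.x, where gr.Image no longer accepts shape=. predict() is the
# function defined in app.py above.
app = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="numpy"),  # hands predict() an HWC uint8 numpy array
    outputs="text",
)
app.launch(share=True)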
captions.txt ADDED
The diff for this file is too large to render. See raw diff
 
get_loader.py ADDED
@@ -0,0 +1,141 @@
+ import os  # when loading file paths
+
+ import pandas as pd  # for lookup in annotation file
+ import spacy  # for tokenizer
+ import torch
+ import torchvision.transforms as transforms
+ from PIL import Image  # Load img
+ from torch.nn.utils.rnn import pad_sequence  # pad batch
+ from torch.utils.data import DataLoader, Dataset
+
+ # We want to convert text -> numerical values
+ # 1. We need a Vocabulary mapping each word to an index
+ # 2. We need to set up a PyTorch dataset to load the data
+ # 3. Set up padding of every batch (all examples should be
+ #    of the same seq_len) and set up the dataloader
+
+ # Download with: python -m spacy download en_core_web_sm
+ spacy_eng = spacy.load("en_core_web_sm")
+
+
+ class Vocabulary:
+     def __init__(self, freq_threshold):
+         self.itos = {0: "<PAD>", 1: "<SOS>", 2: "<EOS>", 3: "<UNK>"}
+         self.stoi = {"<PAD>": 0, "<SOS>": 1, "<EOS>": 2, "<UNK>": 3}
+         self.freq_threshold = freq_threshold
+
+     def __len__(self):
+         return len(self.itos)
+
+     @staticmethod
+     def tokenizer_eng(text):
+         return [tok.text.lower() for tok in spacy_eng.tokenizer(text)]
+
+     def build_vocabulary(self, sentence_list):
+         frequencies = {}
+         idx = 4
+
+         for sentence in sentence_list:
+             for word in self.tokenizer_eng(sentence):
+                 if word not in frequencies:
+                     frequencies[word] = 1
+
+                 else:
+                     frequencies[word] += 1
+
+                 if frequencies[word] == self.freq_threshold:
+                     self.stoi[word] = idx
+                     self.itos[idx] = word
+                     idx += 1
+
+     def numericalize(self, text):
+         tokenized_text = self.tokenizer_eng(text)
+
+         return [
+             self.stoi[token] if token in self.stoi else self.stoi["<UNK>"]
+             for token in tokenized_text
+         ]
+
+
+ class FlickrDataset(Dataset):
+     def __init__(self, root_dir, captions_file, transform=None, freq_threshold=5):
+         self.root_dir = root_dir
+         self.df = pd.read_csv(captions_file)
+         self.transform = transform
+
+         # Get img, caption columns
+         self.imgs = self.df["image"]
+         self.captions = self.df["caption"]
+
+         # Initialize vocabulary and build vocab
+         self.vocab = Vocabulary(freq_threshold)
+         self.vocab.build_vocabulary(self.captions.tolist())
+
+     def __len__(self):
+         return len(self.df)
+
+     def __getitem__(self, index):
+         caption = self.captions[index]
+         img_id = self.imgs[index]
+         img = Image.open(os.path.join(self.root_dir, img_id)).convert("RGB")
+
+         if self.transform is not None:
+             img = self.transform(img)
+
+         numericalized_caption = [self.vocab.stoi["<SOS>"]]
+         numericalized_caption += self.vocab.numericalize(caption)
+         numericalized_caption.append(self.vocab.stoi["<EOS>"])
+
+         return img, torch.tensor(numericalized_caption)
+
+
+ class MyCollate:
+     def __init__(self, pad_idx):
+         self.pad_idx = pad_idx
+
+     def __call__(self, batch):
+         imgs = [item[0].unsqueeze(0) for item in batch]
+         imgs = torch.cat(imgs, dim=0)  # [B, C, H, W]
+         targets = [item[1] for item in batch]  # list of [L] tensors; caption lengths differ
+         targets = pad_sequence(targets, batch_first=False,  # [L, B] after padding to a common length
+                                padding_value=self.pad_idx)
+         return imgs, targets
+
+
+ def get_loader(
+     root_folder,
+     annotation_file,
+     transform,
+     batch_size=32,
+     num_workers=8,
+     shuffle=True,
+     pin_memory=True,
+ ):
+     dataset = FlickrDataset(root_folder, annotation_file, transform=transform)
+
+     pad_idx = dataset.vocab.stoi["<PAD>"]
+
+     loader = DataLoader(
+         dataset=dataset,
+         batch_size=batch_size,
+         num_workers=num_workers,
+         shuffle=shuffle,
+         pin_memory=pin_memory,
+         collate_fn=MyCollate(pad_idx=pad_idx),
+     )
+
+     return loader, dataset
+
+
+ if __name__ == "__main__":
+     transform = transforms.Compose(
+         [transforms.Resize((224, 224)), transforms.ToTensor()]
+     )
+
+     loader, dataset = get_loader(
+         "flickr8k/images/", "flickr8k/captions.txt", transform=transform
+     )
+
+     for idx, (imgs, captions) in enumerate(loader):
+         print(imgs.shape)
+         print(captions.shape)
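A small sketch (not part of the commit) of how the `Vocabulary` class above behaves on its own, assuming the `en_core_web_sm` spaCy model is installed; a threshold of 1 keeps every word after its first occurrence:

from get_loader import Vocabulary

# Hedged sketch: build a tiny vocabulary and numericalize a sentence.
vocab = Vocabulary(freq_threshold=1)          # keep words seen at least once
vocab.build_vocabulary(["a dog runs on the beach", "a dog swims"])

print(len(vocab))                             # 4 special tokens + learned words
print(vocab.numericalize("a dog flies"))      # "flies" is unseen -> <UNK> (index 3)

Note also that `MyCollate` pads with `batch_first=False`, so each batch of captions comes back as `[seq_len, batch_size]`, matching the `nn.LSTM` default used in model.py.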
model.py ADDED
@@ -0,0 +1,72 @@
+ import torch
+ import torch.nn as nn
+ import torchvision.models as models
+
+ #----------------------------------------------------------------------------
+
+ class EncoderCNN(nn.Module):
+     def __init__(self, embed_size) -> None:
+         super().__init__()
+         self.inception = models.inception_v3(pretrained=True, aux_logits=False)
+         for param in self.inception.parameters():
+             param.requires_grad = False
+         self.inception.fc = nn.Linear(self.inception.fc.in_features, embed_size)
+         self.relu = nn.ReLU(True)
+         self.times = []
+         self.dropout = nn.Dropout(0.5)
+
+     def forward(self, imgs):
+         features = self.inception(imgs)
+         return self.dropout(self.relu(features))
+
+ #----------------------------------------------------------------------------
+
+ class DecoderRNN(nn.Module):
+     def __init__(self, embed_size, hidden_size, vocab_size, num_layers) -> None:
+         super().__init__()
+         self.embed = nn.Embedding(vocab_size, embed_size)
+         self.LSTM = nn.LSTM(embed_size, hidden_size, num_layers)
+         self.linear = nn.Linear(hidden_size, vocab_size)
+         self.dropout = nn.Dropout(0.5)
+
+     def forward(self, features, captions):
+         embeddings = self.dropout(self.embed(captions))
+         # unsqueeze(0) adds the time (seq_len) dimension so the image features form the first step
+         embeddings = torch.cat([features.unsqueeze(0), embeddings], dim=0)
+         hiddens, _ = self.LSTM(embeddings)
+         outputs = self.linear(hiddens)
+         return outputs
+
+ #----------------------------------------------------------------------------
+
+ class CNNtoRNN(nn.Module):
+     def __init__(self, embed_size, hidden_size, vocab_size, num_layers) -> None:
+         super().__init__()
+         self.encoderCNN = EncoderCNN(embed_size)
+         self.decoderRNN = DecoderRNN(embed_size, hidden_size, vocab_size, num_layers)
+
+     def forward(self, imgs, captions):
+         features = self.encoderCNN(imgs)
+         outputs = self.decoderRNN(features, captions)
+         return outputs
+
+     def caption_image(self, img, vocab, max_length=50):
+         result_caption = []
+
+         with torch.no_grad():
+             x = self.encoderCNN(img).unsqueeze(0)
+             states = None
+
+             for _ in range(max_length):
+                 # predict one token at a time (greedy decoding)
+                 h, states = self.decoderRNN.LSTM(x, states)
+                 output = self.decoderRNN.linear(h.squeeze(0))
+                 predicted = output.argmax(1)
+                 # the predicted token becomes the input for the next step
+                 result_caption.append(predicted.item())
+                 x = self.decoderRNN.embed(predicted).unsqueeze(0)
+
+                 if vocab.itos[predicted.item()] == '<EOS>':
+                     break
+
+         return [vocab.itos[idx] for idx in result_caption]
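A dry-run sketch (not part of the commit) of `CNNtoRNN.caption_image` on a dummy image, assuming a torchvision version that still accepts `pretrained=True` (the pretrained inception_v3 weights are downloaded on first use); the decoder is randomly initialised, so the tokens are meaningless and this only exercises the shapes and the greedy loop:

import torch

from get_loader import Vocabulary
from model import CNNtoRNN

# Hedged sketch: tiny vocabulary stand-in so vocab_size matches len(vocab).
vocab = Vocabulary(freq_threshold=1)
vocab.build_vocabulary(["a dog runs"])

model = CNNtoRNN(embed_size=256, hidden_size=256,
                 vocab_size=len(vocab), num_layers=1)
model.eval()  # caption_image assumes inference mode

img = torch.randn(1, 3, 299, 299)       # inception_v3 expects 299x299 inputs
print(model.caption_image(img, vocab))  # list of up to max_length=50 tokens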
requirements.txt ADDED
@@ -0,0 +1,2 @@
+ spacy
+ pandas
utils.py ADDED
@@ -0,0 +1,69 @@
+ import torch
+ import torchvision.transforms as transforms
+ from PIL import Image
+
+
+ def print_examples(model, device, dataset):
+     transform = transforms.Compose(
+         [
+             transforms.Resize((299, 299)),
+             transforms.ToTensor(),
+             transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
+         ]
+     )
+
+     model.eval()
+     test_img1 = transform(Image.open("test_examples/dog.jpg").convert("RGB")).unsqueeze(
+         0
+     )
+     print("Example 1 CORRECT: Dog on a beach by the ocean")
+     print(
+         "Example 1 OUTPUT: "
+         + " ".join(model.caption_image(test_img1.to(device), dataset.vocab))
+     )
+     test_img2 = transform(
+         Image.open("test_examples/child.jpg").convert("RGB")
+     ).unsqueeze(0)
+     print("Example 2 CORRECT: Child holding red frisbee outdoors")
+     print(
+         "Example 2 OUTPUT: "
+         + " ".join(model.caption_image(test_img2.to(device), dataset.vocab))
+     )
+     test_img3 = transform(Image.open("test_examples/bus.png").convert("RGB")).unsqueeze(
+         0
+     )
+     print("Example 3 CORRECT: Bus driving by parked cars")
+     print(
+         "Example 3 OUTPUT: "
+         + " ".join(model.caption_image(test_img3.to(device), dataset.vocab))
+     )
+     test_img4 = transform(
+         Image.open("test_examples/boat.png").convert("RGB")
+     ).unsqueeze(0)
+     print("Example 4 CORRECT: A small boat in the ocean")
+     print(
+         "Example 4 OUTPUT: "
+         + " ".join(model.caption_image(test_img4.to(device), dataset.vocab))
+     )
+     test_img5 = transform(
+         Image.open("test_examples/horse.png").convert("RGB")
+     ).unsqueeze(0)
+     print("Example 5 CORRECT: A cowboy riding a horse in the desert")
+     print(
+         "Example 5 OUTPUT: "
+         + " ".join(model.caption_image(test_img5.to(device), dataset.vocab))
+     )
+     model.train()
+
+
+ def save_checkpoint(state, filename="my_checkpoint.pth.tar"):
+     print("=> Saving checkpoint")
+     torch.save(state, filename)
+
+
+ def load_checkpoint(checkpoint, model, optimizer):
+     print("=> Loading checkpoint")
+     model.load_state_dict(checkpoint["state_dict"])
+     optimizer.load_state_dict(checkpoint["optimizer"])
+     step = checkpoint["step"]
+     return step
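A sketch (not part of the commit) of how the checkpoint helpers round-trip, using the `CNNtoRNN` model from model.py; the optimizer and learning rate are illustrative, only the dictionary keys (`state_dict`, `optimizer`, `step`) are fixed by `load_checkpoint` and by app.py:

import torch
import torch.optim as optim

from model import CNNtoRNN
from utils import save_checkpoint, load_checkpoint

# Hedged sketch: the keys below are the ones load_checkpoint() expects.
model = CNNtoRNN(embed_size=256, hidden_size=256, vocab_size=2994, num_layers=1)
optimizer = optim.Adam(model.parameters(), lr=3e-4)  # illustrative optimizer/lr

checkpoint = {
    "state_dict": model.state_dict(),
    "optimizer": optimizer.state_dict(),
    "step": 0,
}
save_checkpoint(checkpoint)  # writes my_checkpoint.pth.tar
step = load_checkpoint(torch.load("my_checkpoint.pth.tar"), model, optimizer)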