Spaces:
Build error
Build error
add app
Browse files- app.py +51 -0
- captions.txt +0 -0
- get_loader.py +141 -0
- model.py +72 -0
- requirements.txt +2 -0
- utils.py +69 -0
app.py
ADDED
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
import pandas as pd
|
3 |
+
import torch
|
4 |
+
from torchvision.transforms import transforms
|
5 |
+
|
6 |
+
from get_loader import Vocabulary
|
7 |
+
from model import CNNtoRNN
|
8 |
+
|
9 |
+
|
10 |
+
def predict(img):
|
11 |
+
img = transform(img)
|
12 |
+
output = model.caption_image(img.unsqueeze(0).to(device), vocab)
|
13 |
+
return " ".join(output[1:-1])
|
14 |
+
|
15 |
+
if __name__ == '__main__':
|
16 |
+
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
17 |
+
model = CNNtoRNN(
|
18 |
+
embed_size=256,
|
19 |
+
hidden_size=256,
|
20 |
+
vocab_size=2994,
|
21 |
+
num_layers=1
|
22 |
+
)
|
23 |
+
print('Loading model weights...')
|
24 |
+
model.load_state_dict(torch.load(
|
25 |
+
'my_checkpoint.pth.tar', map_location=device)["state_dict"])
|
26 |
+
model.to(device)
|
27 |
+
model.eval()
|
28 |
+
print('Building vocabulary...')
|
29 |
+
vocab = Vocabulary(5)
|
30 |
+
df = pd.read_csv('captions.txt')
|
31 |
+
vocab.build_vocabulary(df['caption'].tolist())
|
32 |
+
|
33 |
+
transform = transforms.Compose(
|
34 |
+
[
|
35 |
+
transforms.ToPILImage(),
|
36 |
+
transforms.Resize((299, 299)),
|
37 |
+
transforms.ToTensor(),
|
38 |
+
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
|
39 |
+
]
|
40 |
+
)
|
41 |
+
print('Creating app...')
|
42 |
+
app = gr.Interface(
|
43 |
+
fn=predict,
|
44 |
+
inputs=gr.Image(shape=(256, 256)),
|
45 |
+
outputs="text",
|
46 |
+
)
|
47 |
+
print('done!')
|
48 |
+
|
49 |
+
app.launch(
|
50 |
+
share=True,
|
51 |
+
)
|
captions.txt
ADDED
The diff for this file is too large to render.
See raw diff
|
|
get_loader.py
ADDED
@@ -0,0 +1,141 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os # when loading file paths
|
2 |
+
|
3 |
+
import pandas as pd # for lookup in annotation file
|
4 |
+
import spacy # for tokenizer
|
5 |
+
import torch
|
6 |
+
import torchvision.transforms as transforms
|
7 |
+
from PIL import Image # Load img
|
8 |
+
from torch.nn.utils.rnn import pad_sequence # pad batch
|
9 |
+
from torch.utils.data import DataLoader, Dataset
|
10 |
+
|
11 |
+
# We want to convert text -> numerical values
|
12 |
+
# 1. We need a Vocabulary mapping each word to a index
|
13 |
+
# 2. We need to setup a Pytorch dataset to load the data
|
14 |
+
# 3. Setup padding of every batch (all examples should be
|
15 |
+
# of same seq_len and setup dataloader)
|
16 |
+
|
17 |
+
# Download with: python -m spacy download en
|
18 |
+
spacy_eng = spacy.load("en_core_web_sm")
|
19 |
+
|
20 |
+
|
21 |
+
class Vocabulary:
|
22 |
+
def __init__(self, freq_threshold):
|
23 |
+
self.itos = {0: "<PAD>", 1: "<SOS>", 2: "<EOS>", 3: "<UNK>"}
|
24 |
+
self.stoi = {"<PAD>": 0, "<SOS>": 1, "<EOS>": 2, "<UNK>": 3}
|
25 |
+
self.freq_threshold = freq_threshold
|
26 |
+
|
27 |
+
def __len__(self):
|
28 |
+
return len(self.itos)
|
29 |
+
|
30 |
+
@staticmethod
|
31 |
+
def tokenizer_eng(text):
|
32 |
+
return [tok.text.lower() for tok in spacy_eng.tokenizer(text)]
|
33 |
+
|
34 |
+
def build_vocabulary(self, sentence_list):
|
35 |
+
frequencies = {}
|
36 |
+
idx = 4
|
37 |
+
|
38 |
+
for sentence in sentence_list:
|
39 |
+
for word in self.tokenizer_eng(sentence):
|
40 |
+
if word not in frequencies:
|
41 |
+
frequencies[word] = 1
|
42 |
+
|
43 |
+
else:
|
44 |
+
frequencies[word] += 1
|
45 |
+
|
46 |
+
if frequencies[word] == self.freq_threshold:
|
47 |
+
self.stoi[word] = idx
|
48 |
+
self.itos[idx] = word
|
49 |
+
idx += 1
|
50 |
+
|
51 |
+
def numericalize(self, text):
|
52 |
+
tokenized_text = self.tokenizer_eng(text)
|
53 |
+
|
54 |
+
return [
|
55 |
+
self.stoi[token] if token in self.stoi else self.stoi["<UNK>"]
|
56 |
+
for token in tokenized_text
|
57 |
+
]
|
58 |
+
|
59 |
+
|
60 |
+
class FlickrDataset(Dataset):
|
61 |
+
def __init__(self, root_dir, captions_file, transform=None, freq_threshold=5):
|
62 |
+
self.root_dir = root_dir
|
63 |
+
self.df = pd.read_csv(captions_file)
|
64 |
+
self.transform = transform
|
65 |
+
|
66 |
+
# Get img, caption columns
|
67 |
+
self.imgs = self.df["image"]
|
68 |
+
self.captions = self.df["caption"]
|
69 |
+
|
70 |
+
# Initialize vocabulary and build vocab
|
71 |
+
self.vocab = Vocabulary(freq_threshold)
|
72 |
+
self.vocab.build_vocabulary(self.captions.tolist())
|
73 |
+
|
74 |
+
def __len__(self):
|
75 |
+
return len(self.df)
|
76 |
+
|
77 |
+
def __getitem__(self, index):
|
78 |
+
caption = self.captions[index]
|
79 |
+
img_id = self.imgs[index]
|
80 |
+
img = Image.open(os.path.join(self.root_dir, img_id)).convert("RGB")
|
81 |
+
|
82 |
+
if self.transform is not None:
|
83 |
+
img = self.transform(img)
|
84 |
+
|
85 |
+
numericalized_caption = [self.vocab.stoi["<SOS>"]]
|
86 |
+
numericalized_caption += self.vocab.numericalize(caption)
|
87 |
+
numericalized_caption.append(self.vocab.stoi["<EOS>"])
|
88 |
+
|
89 |
+
return img, torch.tensor(numericalized_caption)
|
90 |
+
|
91 |
+
|
92 |
+
class MyCollate:
|
93 |
+
def __init__(self, pad_idx):
|
94 |
+
self.pad_idx = pad_idx
|
95 |
+
|
96 |
+
def __call__(self, batch):
|
97 |
+
imgs = [item[0].unsqueeze(0) for item in batch]
|
98 |
+
imgs = torch.cat(imgs, dim=0) # [BCHW]
|
99 |
+
targets = [item[1] for item in batch] # [BL] L长度不同
|
100 |
+
targets = pad_sequence(targets, batch_first=False, # [LB] L长度相同
|
101 |
+
padding_value=self.pad_idx)
|
102 |
+
return imgs, targets
|
103 |
+
|
104 |
+
|
105 |
+
def get_loader(
|
106 |
+
root_folder,
|
107 |
+
annotation_file,
|
108 |
+
transform,
|
109 |
+
batch_size=32,
|
110 |
+
num_workers=8,
|
111 |
+
shuffle=True,
|
112 |
+
pin_memory=True,
|
113 |
+
):
|
114 |
+
dataset = FlickrDataset(root_folder, annotation_file, transform=transform)
|
115 |
+
|
116 |
+
pad_idx = dataset.vocab.stoi["<PAD>"]
|
117 |
+
|
118 |
+
loader = DataLoader(
|
119 |
+
dataset=dataset,
|
120 |
+
batch_size=batch_size,
|
121 |
+
num_workers=num_workers,
|
122 |
+
shuffle=shuffle,
|
123 |
+
pin_memory=pin_memory,
|
124 |
+
collate_fn=MyCollate(pad_idx=pad_idx),
|
125 |
+
)
|
126 |
+
|
127 |
+
return loader, dataset
|
128 |
+
|
129 |
+
|
130 |
+
if __name__ == "__main__":
|
131 |
+
transform = transforms.Compose(
|
132 |
+
[transforms.Resize((224, 224)), transforms.ToTensor(),]
|
133 |
+
)
|
134 |
+
|
135 |
+
loader, dataset = get_loader(
|
136 |
+
"flickr8k/images/", "flickr8k/captions.txt", transform=transform
|
137 |
+
)
|
138 |
+
|
139 |
+
for idx, (imgs, captions) in enumerate(loader):
|
140 |
+
print(imgs.shape)
|
141 |
+
print(captions.shape)
|
model.py
ADDED
@@ -0,0 +1,72 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torchvision.models as models
|
4 |
+
|
5 |
+
#----------------------------------------------------------------------------
|
6 |
+
|
7 |
+
class EncoderCNN(nn.Module):
|
8 |
+
def __init__(self, embed_size) -> None:
|
9 |
+
super().__init__()
|
10 |
+
self.inception = models.inception_v3(pretrained=True, aux_logits=False)
|
11 |
+
for param in self.inception.parameters():
|
12 |
+
param.requires_grad = False
|
13 |
+
self.inception.fc = nn.Linear(self.inception.fc.in_features, embed_size)
|
14 |
+
self.relu = nn.ReLU(True)
|
15 |
+
self.times = []
|
16 |
+
self.dropout = nn.Dropout(0.5)
|
17 |
+
|
18 |
+
def forward(self, imgs):
|
19 |
+
features = self.inception(imgs)
|
20 |
+
return self.dropout(self.relu(features))
|
21 |
+
|
22 |
+
#----------------------------------------------------------------------------
|
23 |
+
|
24 |
+
class DecoderRNN(nn.Module):
|
25 |
+
def __init__(self, embed_size, hidden_size, vocab_size, num_layers) -> None:
|
26 |
+
super().__init__()
|
27 |
+
self.embed = nn.Embedding(vocab_size, embed_size)
|
28 |
+
self.LSTM = nn.LSTM(embed_size, hidden_size, num_layers)
|
29 |
+
self.linear = nn.Linear(hidden_size, vocab_size)
|
30 |
+
self.dropout = nn.Dropout(0.5)
|
31 |
+
|
32 |
+
def forward(self, features, captions):
|
33 |
+
embbedings = self.dropout(self.embed(captions))
|
34 |
+
# unsqueeze(0) 添加时间维度seq_len
|
35 |
+
embbedings = torch.cat([features.unsqueeze(0), embbedings], dim=0)
|
36 |
+
hiddens, _ = self.LSTM(embbedings)
|
37 |
+
outputs = self.linear(hiddens)
|
38 |
+
return outputs
|
39 |
+
|
40 |
+
#----------------------------------------------------------------------------
|
41 |
+
|
42 |
+
class CNNtoRNN(nn.Module):
|
43 |
+
def __init__(self, embed_size, hidden_size, vocab_size, num_layers) -> None:
|
44 |
+
super().__init__()
|
45 |
+
self.encoderCNN = EncoderCNN(embed_size)
|
46 |
+
self.decoderRNN = DecoderRNN(embed_size, hidden_size, vocab_size, num_layers)
|
47 |
+
|
48 |
+
def forward(self, imgs, captions):
|
49 |
+
features = self.encoderCNN(imgs)
|
50 |
+
outputs = self.decoderRNN(features, captions)
|
51 |
+
return outputs
|
52 |
+
|
53 |
+
def caption_image(self, img, vocab, max_length=50):
|
54 |
+
result_caption = []
|
55 |
+
|
56 |
+
with torch.no_grad():
|
57 |
+
x = self.encoderCNN(img).unsqueeze(0)
|
58 |
+
states = None
|
59 |
+
|
60 |
+
for _ in range(max_length):
|
61 |
+
# 逐个预测
|
62 |
+
h, states = self.decoderRNN.LSTM(x, states)
|
63 |
+
output = self.decoderRNN.linear(h.squeeze(0))
|
64 |
+
predicted = output.argmax(1)
|
65 |
+
# 预测的值作为下一次预测的输入
|
66 |
+
result_caption.append(predicted.item())
|
67 |
+
x = self.decoderRNN.embed(predicted).unsqueeze(0)
|
68 |
+
|
69 |
+
if vocab.itos[predicted.item()] == '<EOS>':
|
70 |
+
break
|
71 |
+
|
72 |
+
return [vocab.itos[idx] for idx in result_caption]
|
requirements.txt
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
spacy
|
2 |
+
pandas
|
utils.py
ADDED
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torchvision.transforms as transforms
|
3 |
+
from PIL import Image
|
4 |
+
|
5 |
+
|
6 |
+
def print_examples(model, device, dataset):
|
7 |
+
transform = transforms.Compose(
|
8 |
+
[
|
9 |
+
transforms.Resize((299, 299)),
|
10 |
+
transforms.ToTensor(),
|
11 |
+
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
|
12 |
+
]
|
13 |
+
)
|
14 |
+
|
15 |
+
model.eval()
|
16 |
+
test_img1 = transform(Image.open("test_examples/dog.jpg").convert("RGB")).unsqueeze(
|
17 |
+
0
|
18 |
+
)
|
19 |
+
print("Example 1 CORRECT: Dog on a beach by the ocean")
|
20 |
+
print(
|
21 |
+
"Example 1 OUTPUT: "
|
22 |
+
+ " ".join(model.caption_image(test_img1.to(device), dataset.vocab))
|
23 |
+
)
|
24 |
+
test_img2 = transform(
|
25 |
+
Image.open("test_examples/child.jpg").convert("RGB")
|
26 |
+
).unsqueeze(0)
|
27 |
+
print("Example 2 CORRECT: Child holding red frisbee outdoors")
|
28 |
+
print(
|
29 |
+
"Example 2 OUTPUT: "
|
30 |
+
+ " ".join(model.caption_image(test_img2.to(device), dataset.vocab))
|
31 |
+
)
|
32 |
+
test_img3 = transform(Image.open("test_examples/bus.png").convert("RGB")).unsqueeze(
|
33 |
+
0
|
34 |
+
)
|
35 |
+
print("Example 3 CORRECT: Bus driving by parked cars")
|
36 |
+
print(
|
37 |
+
"Example 3 OUTPUT: "
|
38 |
+
+ " ".join(model.caption_image(test_img3.to(device), dataset.vocab))
|
39 |
+
)
|
40 |
+
test_img4 = transform(
|
41 |
+
Image.open("test_examples/boat.png").convert("RGB")
|
42 |
+
).unsqueeze(0)
|
43 |
+
print("Example 4 CORRECT: A small boat in the ocean")
|
44 |
+
print(
|
45 |
+
"Example 4 OUTPUT: "
|
46 |
+
+ " ".join(model.caption_image(test_img4.to(device), dataset.vocab))
|
47 |
+
)
|
48 |
+
test_img5 = transform(
|
49 |
+
Image.open("test_examples/horse.png").convert("RGB")
|
50 |
+
).unsqueeze(0)
|
51 |
+
print("Example 5 CORRECT: A cowboy riding a horse in the desert")
|
52 |
+
print(
|
53 |
+
"Example 5 OUTPUT: "
|
54 |
+
+ " ".join(model.caption_image(test_img5.to(device), dataset.vocab))
|
55 |
+
)
|
56 |
+
model.train()
|
57 |
+
|
58 |
+
|
59 |
+
def save_checkpoint(state, filename="my_checkpoint.pth.tar"):
|
60 |
+
print("=> Saving checkpoint")
|
61 |
+
torch.save(state, filename)
|
62 |
+
|
63 |
+
|
64 |
+
def load_checkpoint(checkpoint, model, optimizer):
|
65 |
+
print("=> Loading checkpoint")
|
66 |
+
model.load_state_dict(checkpoint["state_dict"])
|
67 |
+
optimizer.load_state_dict(checkpoint["optimizer"])
|
68 |
+
step = checkpoint["step"]
|
69 |
+
return step
|