rushil78 Vageesh1 committed on
Commit
f55b152
0 Parent(s):

Duplicate from Vageesh1/clip_gpt2

Browse files

Co-authored-by: vageesh <Vageesh1@users.noreply.huggingface.co>

Files changed (15) hide show
  1. .gitattributes +34 -0
  2. COCO_model.h5 +3 -0
  3. README.md +13 -0
  4. app.py +84 -0
  5. engine.py +42 -0
  6. model.h5 +3 -0
  7. model.py +220 -0
  8. model_2.py +0 -0
  9. model_trained.pth +3 -0
  10. neuralnet/dataset.py +139 -0
  11. neuralnet/model.py +71 -0
  12. neuralnet/train.py +130 -0
  13. neuralnet/utils.py +42 -0
  14. requirements.txt +19 -0
  15. vocab.json +0 -0
.gitattributes ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tflite filter=lfs diff=lfs merge=lfs -text
29
+ *.tgz filter=lfs diff=lfs merge=lfs -text
30
+ *.wasm filter=lfs diff=lfs merge=lfs -text
31
+ *.xz filter=lfs diff=lfs merge=lfs -text
32
+ *.zip filter=lfs diff=lfs merge=lfs -text
33
+ *.zst filter=lfs diff=lfs merge=lfs -text
34
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
COCO_model.h5 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:35200360d19ea02ce5c8f007c8bf6d8297e3c16ae3b3fb4b6eeb24ec1c07f8e6
3
+ size 636283447
README.md ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: Clip Gpt2
3
+ emoji: 🐨
4
+ colorFrom: pink
5
+ colorTo: indigo
6
+ sdk: streamlit
7
+ sdk_version: 1.19.0
8
+ app_file: app.py
9
+ pinned: false
10
+ duplicated_from: Vageesh1/clip_gpt2
11
+ ---
12
+
13
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import clip
3
+ import PIL.Image
4
+ from PIL import Image
5
+ import skimage.io as io
6
+ import streamlit as st
7
+ from transformers import GPT2Tokenizer, GPT2LMHeadModel, AdamW, get_linear_schedule_with_warmup
8
+ from transformers import GPT2TokenizerFast, ViTImageProcessor, VisionEncoderDecoderModel
9
+ from model import generate2,ClipCaptionModel
10
+ from engine import inference
11
+
12
+
13
# ViT encoder + GPT-2 decoder captioner: start from the public
# "nlpconnect/vit-gpt2-image-captioning" checkpoint, then overlay the local
# weights from 'model_trained.pth'. strict=False means missing or unexpected
# keys in that file are tolerated instead of raising.
model_trained = VisionEncoderDecoderModel.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
model_trained.load_state_dict(torch.load('model_trained.pth',map_location=torch.device('cpu')),strict=False)
image_processor = ViTImageProcessor.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
# NOTE(review): the name `tokenizer` is rebound to a plain GPT2Tokenizer
# further down this module, so this fast tokenizer is shadowed at runtime.
tokenizer = GPT2TokenizerFast.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
17
+
18
def show_n_generate(img, model, greedy = True):
    """Caption an image with an encoder-decoder model.

    Args:
        img: Path or file-like object accepted by ``PIL.Image.open``.
        model: Object exposing ``generate(pixel_values, ...)``.
        greedy: When true, call ``generate`` with only ``max_new_tokens``
            set; otherwise sample with ``do_sample=True`` and ``top_k=5``.

    Returns:
        The decoded text of the first generated sequence, with special
        tokens stripped.
    """
    # Uploaded PNGs can be RGBA, greyscale or palette images. Force RGB so
    # the processor always receives three channels (presumably what the
    # ViT processor expects -- TODO confirm against its config).
    image = Image.open(img).convert("RGB")
    pixel_values = image_processor(image, return_tensors ="pt").pixel_values

    if greedy:
        generated_ids = model.generate(pixel_values, max_new_tokens = 30)
    else:
        generated_ids = model.generate(
            pixel_values,
            do_sample=True,
            max_new_tokens = 30,
            top_k=5)
    generated_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
    return generated_text
32
+
33
# Everything in this demo runs on CPU.
device = "cpu"
clip_model, preprocess = clip.load("ViT-B/32", device=device, jit=False)
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

# Number of GPT-2 embedding positions the CLIP feature is projected into.
prefix_length = 10

# Captioner loaded from 'model.h5'; strict=False tolerates key mismatches.
model = ClipCaptionModel(prefix_length)
model.load_state_dict(torch.load('model.h5',map_location=torch.device('cpu')),strict=False)
model = model.eval()

# Captioner loaded from 'COCO_model.h5'.
coco_model = ClipCaptionModel(prefix_length)
coco_model.load_state_dict(torch.load('COCO_model.h5',map_location=torch.device('cpu')),strict=False)
# Fix: the original called `model.eval()` a second time here, leaving
# `coco_model` in training mode.
coco_model = coco_model.eval()
48
+
49
+
50
def ui():
    """Render the Streamlit page: upload an image, pick a model, show a caption.

    Nothing beyond the uploader is drawn until a file has been uploaded.
    """
    st.markdown("# Image Captioning")
    # st.markdown("## Done By- Vageesh and Rushil")
    uploaded_file = st.file_uploader("Upload an Image", type=['png', 'jpeg', 'jpg'])

    if uploaded_file is None:
        return

    option = st.selectbox('Please select the Model',('Clip Captioning','Attention Decoder','VIT+GPT2'))

    if option == 'Clip Captioning':
        # CLIP preprocessing is only needed on this path, so do it here
        # instead of for every option.
        pil_image = PIL.Image.fromarray(io.imread(uploaded_file))
        image = preprocess(pil_image).unsqueeze(0).to(device)
        with torch.no_grad():
            prefix = clip_model.encode_image(image).to(device, dtype=torch.float32)
            prefix_embed = model.clip_project(prefix).reshape(1, prefix_length, -1)
            caption = generate2(model, tokenizer, embed=prefix_embed)
    elif option == 'Attention Decoder':
        caption = inference(uploaded_file)
    else:
        # 'VIT+GPT2' is offered in the select box but its inference call was
        # disabled; tell the user instead of silently rendering nothing.
        st.image(uploaded_file, width = 500, channels = 'RGB')
        st.markdown("The VIT+GPT2 model is currently disabled in this demo.")
        return

    st.image(uploaded_file, width = 500, channels = 'RGB')
    st.markdown("**PREDICTION:** " + caption)


if __name__ == '__main__':
    ui()
84
+
engine.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import torchvision.transforms as transforms
4
+ from PIL import Image
5
+ import json
6
+ from neuralnet.model import SeqToSeq
7
+ import wget
8
+
9
url = "https://github.com/Koushik0901/Image-Captioning/releases/download/v1.0/flickr30k.pt"
# Fetch the checkpoint only when it is not already on disk. The original
# downloaded it on every import, although inference() always loads
# './flickr30k.pt'.
filename = "flickr30k.pt"
if not os.path.exists(filename):
    filename = wget.download(url, out=filename)
12
+
13
def inference(img_path):
    """Caption one image with the SeqToSeq model loaded from './flickr30k.pt'.

    Args:
        img_path: Path or file-like object accepted by ``PIL.Image.open``.

    Returns:
        The predicted tokens joined by spaces, with the first and last
        token dropped (presumably the start/end markers -- TODO confirm).
    """
    transform = transforms.Compose(
        [
            transforms.Resize((299, 299)),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ]
    )

    # Context manager so the vocabulary file handle is closed promptly.
    with open('./vocab.json') as vocab_file:
        vocabulary = json.load(vocab_file)

    model_params = {"embed_size":256, "hidden_size":512, "vocab_size": 7666, "num_layers": 3, "device":"cpu"}
    model = SeqToSeq(**model_params)
    checkpoint = torch.load('./flickr30k.pt', map_location = 'cpu')
    model.load_state_dict(checkpoint['state_dict'])
    model.eval()

    img = transform(Image.open(img_path).convert("RGB")).unsqueeze(0)

    # caption_image() runs the encoder itself; the original's extra
    # `model.encoder(img)` call (result unused) has been removed.
    out_captions = model.caption_image(img, vocabulary['itos'], 50)
    return " ".join(out_captions[1:-1])


if __name__ == '__main__':
    print(inference('./test_examples/dog.png'))
model.h5 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2a36a09076b9779de2807d3aa533d455a398d70c1250aeb24a5cc9110e3d59a4
3
+ size 636272061
model.py ADDED
@@ -0,0 +1,220 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import clip
2
+ import os
3
+ from torch import nn
4
+ import numpy as np
5
+ import torch
6
+ import torch.nn.functional as nnf
7
+ import sys
8
+ from typing import Tuple, List, Union, Optional
9
+ from transformers import GPT2Tokenizer, GPT2LMHeadModel, AdamW, get_linear_schedule_with_warmup
10
+ from tqdm import tqdm, trange
11
+ import skimage.io as io
12
+ import PIL.Image
13
+
14
+
15
# Short aliases intended for type annotations in this module.
N = type(None)
V = np.array  # NOTE(review): np.array is a factory function, not a type; ARRAY is the ndarray class
ARRAY = np.ndarray
ARRAYS = Union[Tuple[ARRAY, ...], List[ARRAY]]
VS = Union[Tuple[V, ...], List[V]]
VN = Union[V, N]
VNS = Union[VS, N]
T = torch.Tensor
TS = Union[Tuple[T, ...], List[T]]
TN = Optional[T]
TNS = Union[Tuple[TN, ...], List[TN]]
TSN = Optional[TS]
TA = Union[T, ARRAY]
28
+
29
+
30
D = torch.device
# Fix: `get_device` returned the name `CPU`, which was never defined, so
# every call on a machine without CUDA raised NameError.
CPU = torch.device('cpu')


def get_device(device_id: int) -> D:
    """Return the CUDA device for *device_id*, or the CPU device without CUDA.

    Args:
        device_id: Requested CUDA index; values above the last available
            index are clamped to it.

    Returns:
        ``torch.device('cpu')`` when CUDA is unavailable, otherwise
        ``torch.device(f'cuda:{index}')``.
    """
    if not torch.cuda.is_available():
        return CPU
    device_id = min(torch.cuda.device_count() - 1, device_id)
    return torch.device(f'cuda:{device_id}')


CUDA = get_device
40
+
41
# Import-time side effect: make sure "<parent of cwd>/pretrained_models" exists.
current_directory = os.getcwd()
save_path = os.path.join(os.path.dirname(current_directory), "pretrained_models")
os.makedirs(save_path, exist_ok=True)
# The 'model_wieghts.pt' spelling is part of the path; correcting it would
# change which file this points at.
model_path = os.path.join(save_path, 'model_wieghts.pt')
45
+
46
+
47
class MLP(nn.Module):
    """Stack of ``nn.Linear`` layers with an activation between each pair.

    ``sizes`` lists the layer widths; no activation follows the last linear
    layer. The stack is stored as ``self.model`` (an ``nn.Sequential``).
    """

    def __init__(self, sizes: Tuple[int, ...], bias=True, act=nn.Tanh):
        super().__init__()
        final_index = len(sizes) - 2
        stack = []
        for index, (fan_in, fan_out) in enumerate(zip(sizes[:-1], sizes[1:])):
            stack.append(nn.Linear(fan_in, fan_out, bias=bias))
            # Every linear layer except the last is followed by `act`.
            if index != final_index:
                stack.append(act())
        self.model = nn.Sequential(*stack)

    def forward(self, x: T) -> T:
        """Apply the layer stack to ``x``."""
        return self.model(x)
60
+
61
class ClipCaptionModel(nn.Module):
    """GPT-2 language model conditioned on a projected prefix vector.

    ``clip_project`` maps a ``prefix_size`` vector to ``prefix_length``
    GPT-2 embedding positions, which are prepended to the token embeddings.
    """

    def __init__(self, prefix_length: int, prefix_size: int = 512):
        super().__init__()
        self.prefix_length = prefix_length
        self.gpt = GPT2LMHeadModel.from_pretrained('gpt2')
        self.gpt_embedding_size = self.gpt.transformer.wte.weight.shape[1]
        projected_size = self.gpt_embedding_size * prefix_length
        if prefix_length > 10:  # not enough memory
            self.clip_project = nn.Linear(prefix_size, projected_size)
        else:
            self.clip_project = MLP((prefix_size, projected_size // 2, projected_size))

    #@functools.lru_cache #FIXME
    def get_dummy_token(self, batch_size: int, device: D) -> T:
        """Return an all-zero int64 tensor of shape (batch_size, prefix_length)."""
        shape = (batch_size, self.prefix_length)
        return torch.zeros(*shape, dtype=torch.int64, device=device)

    def forward(self, tokens: T, prefix: T, mask: Optional[T] = None, labels: Optional[T] = None):
        """Run GPT-2 on the projected prefix followed by the token embeddings.

        When ``labels`` is given, its value is replaced by zero placeholders
        for the prefix positions concatenated with ``tokens``.
        """
        token_embeds = self.gpt.transformer.wte(tokens)
        projected = self.clip_project(prefix)
        prefix_embeds = projected.view(-1, self.prefix_length, self.gpt_embedding_size)
        inputs = torch.cat((prefix_embeds, token_embeds), dim=1)
        if labels is not None:
            placeholder = self.get_dummy_token(tokens.shape[0], tokens.device)
            labels = torch.cat((placeholder, tokens), dim=1)
        return self.gpt(inputs_embeds=inputs, labels=labels, attention_mask=mask)
88
+
89
+
90
class ClipCaptionPrefix(ClipCaptionModel):
    """Variant that exposes only the projection weights and keeps GPT-2 in eval mode."""

    def train(self, mode: bool = True):
        """Set the training mode on the module, then force ``self.gpt`` back to eval."""
        super().train(mode)
        self.gpt.eval()
        return self

    def parameters(self, recurse: bool = True):
        """Return only the parameters of ``clip_project``."""
        projection = self.clip_project
        return projection.parameters()
99
+
100
def generate_beam(model, tokenizer, beam_size: int = 5, prompt=None, embed=None,
                  entry_length=67, temperature=1., stop_token: str = '.'):
    """Generate captions with beam search and return them best-first.

    Args:
        model: Object exposing ``gpt`` (called with ``inputs_embeds``) and
            ``gpt.transformer.wte``.
        tokenizer: Tokenizer providing ``encode`` and ``decode``.
        beam_size: Number of beams kept at every step.
        prompt: Text to start from; only used when ``embed`` is None.
        embed: Pre-computed input embeddings to start from.
        entry_length: Maximum number of generation steps.
        temperature: Logit divisor; values <= 0 are treated as 1.0.
        stop_token: A beam stops once it emits the first token id of this
            string.

    Returns:
        ``beam_size`` decoded strings, ordered by descending
        length-normalised score.
    """

    model.eval()
    stop_token_index = tokenizer.encode(stop_token)[0]
    tokens = None
    scores = None
    device = next(model.parameters()).device
    # Per-beam generated length and "has emitted the stop token" flag.
    seq_lengths = torch.ones(beam_size, device=device)
    is_stopped = torch.zeros(beam_size, device=device, dtype=torch.bool)
    with torch.no_grad():
        if embed is not None:
            generated = embed
        else:
            if tokens is None:
                tokens = torch.tensor(tokenizer.encode(prompt))
                tokens = tokens.unsqueeze(0).to(device)
                generated = model.gpt.transformer.wte(tokens)
        for i in range(entry_length):
            outputs = model.gpt(inputs_embeds=generated)
            logits = outputs.logits
            logits = logits[:, -1, :] / (temperature if temperature > 0 else 1.0)
            # Convert to log-probabilities so beam scores add up.
            logits = logits.softmax(-1).log()
            if scores is None:
                # First step: seed the beams with the top `beam_size` tokens
                # and replicate the context once per beam.
                scores, next_tokens = logits.topk(beam_size, -1)
                generated = generated.expand(beam_size, *generated.shape[1:])
                next_tokens, scores = next_tokens.permute(1, 0), scores.squeeze(0)
                if tokens is None:
                    tokens = next_tokens
                else:
                    tokens = tokens.expand(beam_size, *tokens.shape[1:])
                    tokens = torch.cat((tokens, next_tokens), dim=1)
            else:
                # A stopped beam may only "continue" with token 0 at zero
                # cost, so its score stays fixed and it yields one candidate.
                logits[is_stopped] = -float(np.inf)
                logits[is_stopped, 0] = 0
                scores_sum = scores[:, None] + logits
                seq_lengths[~is_stopped] += 1
                # Rank candidates by length-averaged score over the
                # flattened (beam, vocab) grid.
                scores_sum_average = scores_sum / seq_lengths[:, None]
                scores_sum_average, next_tokens = scores_sum_average.view(-1).topk(beam_size, -1)
                # Flat index -> (source beam, token id).
                next_tokens_source = next_tokens // scores_sum.shape[1]
                seq_lengths = seq_lengths[next_tokens_source]
                next_tokens = next_tokens % scores_sum.shape[1]
                next_tokens = next_tokens.unsqueeze(1)
                tokens = tokens[next_tokens_source]
                tokens = torch.cat((tokens, next_tokens), dim=1)
                generated = generated[next_tokens_source]
                # Undo the averaging so `scores` holds summed log-probs again.
                scores = scores_sum_average * seq_lengths
                is_stopped = is_stopped[next_tokens_source]
            next_token_embed = model.gpt.transformer.wte(next_tokens.squeeze()).view(generated.shape[0], 1, -1)
            generated = torch.cat((generated, next_token_embed), dim=1)
            # `+` on bool tensors is used here to accumulate the stop flags.
            is_stopped = is_stopped + next_tokens.eq(stop_token_index).squeeze()
            if is_stopped.all():
                break
    scores = scores / seq_lengths
    output_list = tokens.cpu().numpy()
    # Decode only the first `length` tokens of each beam.
    output_texts = [tokenizer.decode(output[:int(length)]) for output, length in zip(output_list, seq_lengths)]
    order = scores.argsort(descending=True)
    output_texts = [output_texts[i] for i in order]
    return output_texts
159
+
160
def generate2(
        model,
        tokenizer,
        tokens=None,
        prompt=None,
        embed=None,
        entry_count=1,
        entry_length=67,  # maximum number of generation steps per entry
        top_p=0.8,
        temperature=1.,
        stop_token: str = '.',
):
    """Generate a caption token by token and return the first entry's text.

    Args:
        model: Object exposing ``gpt`` (called with ``inputs_embeds``) and
            ``gpt.transformer.wte``.
        tokenizer: Tokenizer providing ``encode`` and ``decode``.
        tokens: Optional starting token ids; used when ``embed`` is None.
        prompt: Text encoded into ``tokens`` when both ``embed`` and
            ``tokens`` are None.
        embed: Pre-computed input embeddings to start from.
        entry_count: Number of entries to generate; only the first is
            returned.
        entry_length: Maximum number of steps per entry.
        top_p: Cumulative-probability threshold for the nucleus filter.
        temperature: Logit divisor; values <= 0 are treated as 1.0.
        stop_token: Generation stops once the first token id of this string
            is produced.

    Returns:
        The decoded text of the first generated entry.
    """
    model.eval()
    generated_num = 0  # NOTE(review): never read in this function
    generated_list = []
    stop_token_index = tokenizer.encode(stop_token)[0]
    filter_value = -float("Inf")
    device = next(model.parameters()).device

    with torch.no_grad():

        for entry_idx in trange(entry_count):
            if embed is not None:
                generated = embed
            else:
                if tokens is None:
                    tokens = torch.tensor(tokenizer.encode(prompt))
                    tokens = tokens.unsqueeze(0).to(device)

                generated = model.gpt.transformer.wte(tokens)

            for i in range(entry_length):

                outputs = model.gpt(inputs_embeds=generated)
                logits = outputs.logits
                logits = logits[:, -1, :] / (temperature if temperature > 0 else 1.0)
                # Nucleus (top-p) filter: mask tokens beyond the smallest set
                # whose cumulative probability exceeds `top_p`.
                sorted_logits, sorted_indices = torch.sort(logits, descending=True)
                cumulative_probs = torch.cumsum(nnf.softmax(sorted_logits, dim=-1), dim=-1)
                sorted_indices_to_remove = cumulative_probs > top_p
                # Shift right by one so the token crossing the threshold is
                # kept, and never remove the top-ranked token.
                sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[
                                                    ..., :-1
                                                    ].clone()
                sorted_indices_to_remove[..., 0] = 0

                indices_to_remove = sorted_indices[sorted_indices_to_remove]
                logits[:, indices_to_remove] = filter_value
                # argmax picks the top-ranked token, which is never filtered,
                # so decoding is effectively greedy.
                next_token = torch.argmax(logits, -1).unsqueeze(0)
                next_token_embed = model.gpt.transformer.wte(next_token)
                if tokens is None:
                    tokens = next_token
                else:
                    tokens = torch.cat((tokens, next_token), dim=1)
                generated = torch.cat((generated, next_token_embed), dim=1)
                if stop_token_index == next_token.item():
                    break

            output_list = list(tokens.squeeze().cpu().numpy())
            output_text = tokenizer.decode(output_list)
            generated_list.append(output_text)

    return generated_list[0]
model_2.py ADDED
File without changes
model_trained.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6f44c397a407f1687578a0346cbe19262b4ba6954c3256ec656ade873ac57d07
3
+ size 982140285
neuralnet/dataset.py ADDED
@@ -0,0 +1,139 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os # when loading file paths
2
+ import pandas as pd # for lookup in annotation file
3
+ import spacy # for tokenizer
4
+ import torch
5
+ from torch.nn.utils.rnn import pad_sequence # pad batch
6
+ from torch.utils.data import DataLoader, Dataset
7
+ from PIL import Image # Load img
8
+ import torchvision.transforms as transforms
9
+ import json
10
+
11
# Requires the spaCy model loaded below; download with:
#   python -m spacy download en_core_web_sm
spacy_eng = spacy.load("en_core_web_sm")
13
+
14
+
15
class Vocabulary:
    """Word <-> index mapping with four reserved special tokens.

    A word is added once its running count reaches ``freq_threshold``.
    """

    def __init__(self, freq_threshold):
        self.itos = {0: "<PAD>", 1: "<SOS>", 2: "<EOS>", 3: "<UNK>"}
        self.stoi = {token: index for index, token in self.itos.items()}
        self.freq_threshold = freq_threshold

    def __len__(self):
        return len(self.stoi)

    @staticmethod
    def tokenizer_eng(text):
        """Return the lower-cased spaCy tokens of ``text``."""
        lowered = []
        for tok in spacy_eng.tokenizer(text):
            lowered.append(tok.text.lower())
        return lowered

    def build_vocabulary(self, sentence_list):
        """Count words over ``sentence_list`` and register those hitting the threshold."""
        counts = {}
        next_index = 4  # indices 0-3 are the special tokens

        for sentence in sentence_list:
            for word in self.tokenizer_eng(sentence):
                seen = counts.get(word, 0) + 1
                counts[word] = seen
                if seen != self.freq_threshold:
                    continue
                self.stoi[word] = next_index
                self.itos[next_index] = word
                next_index += 1

    def numericalize(self, text):
        """Map ``text`` to a list of indices, using "<UNK>" for unknown tokens."""
        indices = []
        for token in self.tokenizer_eng(text):
            try:
                indices.append(self.stoi[token])
            except KeyError:
                indices.append(self.stoi["<UNK>"])
        return indices
52
+
53
+
54
class FlickrDataset(Dataset):
    """Image/caption dataset backed by a CSV with "image_name" and "comment" columns."""

    def __init__(self, root_dir, captions_file, transform=None, freq_threshold=5):
        self.root_dir = root_dir
        self.df = pd.read_csv(captions_file)
        self.transform = transform

        # Columns holding the image file names and their captions.
        self.imgs = self.df["image_name"]
        self.captions = self.df["comment"]

        # The vocabulary is built from every caption in the file.
        self.vocab = Vocabulary(freq_threshold)
        self.vocab.build_vocabulary(self.captions.tolist())

    def __len__(self):
        return len(self.df)

    def __getitem__(self, index):
        """Return ``(image, caption_tensor)`` for row ``index``.

        The caption is wrapped in the "<SOS>" / "<EOS>" indices.
        """
        caption = self.captions[index]
        image_path = os.path.join(self.root_dir, self.imgs[index])
        img = Image.open(image_path).convert("RGB")
        if self.transform is not None:
            img = self.transform(img)

        stoi = self.vocab.stoi
        token_ids = [stoi["<SOS>"], *self.vocab.numericalize(caption), stoi["<EOS>"]]
        return img, torch.tensor(token_ids)
84
+
85
+
86
class MyCollate:
    """Collate function that batches images and pads captions with ``pad_idx``."""

    def __init__(self, pad_idx):
        self.pad_idx = pad_idx

    def __call__(self, batch):
        """Return ``(images, targets)``.

        Images are stacked along a new first dimension; captions are padded
        with ``batch_first=False``.
        """
        images, captions = [], []
        for image, caption in batch:
            images.append(image)
            captions.append(caption)
        stacked = torch.stack(images, dim=0)
        padded = pad_sequence(captions, batch_first=False, padding_value=self.pad_idx)
        return stacked, padded
97
+
98
+
99
def get_loader(
    root_folder,
    annotation_file,
    transform,
    batch_size=64,
    num_workers=2,
    shuffle=True,
    pin_memory=True,
):
    """Build a ``FlickrDataset`` and a ``DataLoader`` over it.

    Captions are padded with the dataset vocabulary's "<PAD>" index.

    Returns:
        ``(loader, dataset)``.
    """
    dataset = FlickrDataset(root_folder, annotation_file, transform=transform)
    collate = MyCollate(pad_idx=dataset.vocab.stoi["<PAD>"])
    loader_options = {
        "batch_size": batch_size,
        "num_workers": num_workers,
        "shuffle": shuffle,
        "pin_memory": pin_memory,
        "collate_fn": collate,
    }
    return DataLoader(dataset=dataset, **loader_options), dataset
122
+
123
+
124
if __name__ == "__main__":
    # Smoke test: load one batch, print its shapes and dump the vocabulary.
    transform = transforms.Compose(
        [transforms.Resize((224, 224)), transforms.ToTensor(),]
    )

    loader, dataset = get_loader(
        "/home/koushik/vscode/Projects/pytorch/img2text_v1/flickr30k/flickr30k_images/", "/home/koushik/vscode/Projects/pytorch/img2text_v1/flickr30k/results.csv", transform=transform
    )

    for idx, (imgs, captions) in enumerate(loader):
        print(imgs.shape)
        print(captions.shape)
        print(len(dataset.vocab))
        test = {"itos":dataset.vocab.itos, "stoi": dataset.vocab.stoi}
        # Context manager so the JSON file is flushed and closed.
        with open('test.json', 'w') as out_file:
            json.dump(test, out_file)
        break
neuralnet/model.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torchvision.models as models
4
+
5
+
6
class InceptionEncoder(nn.Module):
    """Inception v3 backbone whose final ``fc`` layer outputs ``embed_size`` features."""

    def __init__(self, embed_size, train_CNN=False):
        super().__init__()
        self.train_CNN = train_CNN
        backbone = models.inception_v3(pretrained=True, aux_logits=False)
        # Swap the classifier head for a projection to the embedding size.
        backbone.fc = nn.Linear(backbone.fc.in_features, embed_size)
        self.inception = backbone
        self.relu = nn.ReLU()
        self.bn = nn.BatchNorm1d(embed_size, momentum = 0.01)
        self.dropout = nn.Dropout(0.5)

    def forward(self, images):
        """Return dropout(relu(batch_norm(inception(images))))."""
        normalized = self.bn(self.inception(images))
        activated = self.relu(normalized)
        return self.dropout(activated)
20
+
21
+
22
class LstmDecoder(nn.Module):
    """LSTM caption decoder: embedding -> LSTM -> linear vocabulary projection."""

    def __init__(self, embed_size, hidden_size, vocab_size, num_layers, device = 'cpu'):
        super(LstmDecoder, self).__init__()
        self.num_layers = num_layers
        self.hidden_size = hidden_size
        # Kept for backward compatibility; forward() no longer relies on it.
        self.device = device
        self.embed = nn.Embedding(vocab_size, embed_size)
        self.lstm = nn.LSTM(embed_size, hidden_size, num_layers = self.num_layers)
        self.linear = nn.Linear(hidden_size, vocab_size)
        self.dropout = nn.Dropout(0.5)

    def forward(self, encoder_out, captions):
        """Score vocabulary entries for every step of the sequence.

        ``encoder_out`` is prepended as the first time step (via
        ``unsqueeze(0)`` and concatenation on dim 0) before the embedded
        ``captions``.

        Returns:
            The linear projection of every LSTM output step.
        """
        embeddings = self.dropout(self.embed(captions))
        embeddings = torch.cat((encoder_out.unsqueeze(0), embeddings), dim=0)
        # Fix: create the zero initial states on the device/dtype of the
        # actual inputs instead of the constructor's `device` value, so the
        # module keeps working after `.to(...)`. The original's
        # requires_grad_()/detach() pair had no effect and is dropped.
        state_shape = (self.num_layers, encoder_out.shape[0], self.hidden_size)
        h0 = torch.zeros(state_shape, device=embeddings.device, dtype=embeddings.dtype)
        c0 = torch.zeros(state_shape, device=embeddings.device, dtype=embeddings.dtype)
        hiddens, _ = self.lstm(embeddings, (h0, c0))
        return self.linear(hiddens)
41
+
42
+
43
class SeqToSeq(nn.Module):
    """Image captioner combining ``InceptionEncoder`` and ``LstmDecoder``."""

    def __init__(self, embed_size, hidden_size, vocab_size, num_layers, device = 'cpu'):
        super().__init__()
        self.encoder = InceptionEncoder(embed_size)
        self.decoder = LstmDecoder(embed_size, hidden_size, vocab_size, num_layers, device)

    def forward(self, images, captions):
        """Encode ``images`` and decode them together with ``captions``."""
        return self.decoder(self.encoder(images), captions)

    def caption_image(self, image, vocabulary, max_length = 50):
        """Greedily decode a caption for a single image.

        ``vocabulary`` is indexed with ``str(token_id)``. Decoding stops
        after ``max_length`` steps or once "<EOS>" is predicted (that token
        is included in the result).

        Returns:
            The list of predicted words.
        """
        token_ids = []

        with torch.no_grad():
            step_input = self.encoder(image).unsqueeze(0)
            lstm_state = None

            for _ in range(max_length):
                hidden, lstm_state = self.decoder.lstm(step_input, lstm_state)
                scores = self.decoder.linear(hidden.squeeze(0))
                best = scores.argmax(1)
                best_id = best.item()
                token_ids.append(best_id)
                # Feed the predicted token back in as the next input.
                step_input = self.decoder.embed(best).unsqueeze(0)

                if vocabulary[str(best_id)] == "<EOS>":
                    break

        return [vocabulary[str(token_id)] for token_id in token_ids]
neuralnet/train.py ADDED
@@ -0,0 +1,130 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from tqdm import tqdm
3
+ import torch.nn as nn
4
+ import torch.optim as optim
5
+ import torchvision.transforms as transforms
6
+ from torch.utils.tensorboard import SummaryWriter # For TensorBoard
7
+ from utils import save_checkpoint, load_checkpoint, print_examples
8
+ from dataset import get_loader
9
+ from model import SeqToSeq
10
+ from tabulate import tabulate # To tabulate loss and epoch
11
+ import argparse
12
+ import json
13
+
14
def main(args):
    """Train the SeqToSeq captioner.

    Args:
        args: Namespace providing ``root_dir``, ``csv_file``, ``log_dir``,
            ``save_path``, ``embed_size``, ``hidden_size``, ``num_layers``,
            ``lr``, ``num_epochs`` and, optionally, ``batch_size``.
    """
    transform = transforms.Compose(
        [
            transforms.Resize((356, 356)),
            transforms.RandomCrop((299, 299)),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ]
    )

    train_loader, _ = get_loader(
        root_folder = args.root_dir,
        annotation_file = args.csv_file,
        transform=transform,
        # Fix: honour --batch_size (previously hard-coded to 64); fall back
        # to 64 for namespaces that do not carry it.
        batch_size = getattr(args, 'batch_size', 64),
        num_workers=2,
    )
    with open('vocab.json') as vocab_file:
        vocab = json.load(vocab_file)

    torch.backends.cudnn.benchmark = True
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    load_model = False
    save_model = True
    train_CNN = False

    # Hyperparameters
    embed_size = args.embed_size
    hidden_size = args.hidden_size
    vocab_size = len(vocab['stoi'])
    num_layers = args.num_layers
    learning_rate = args.lr
    num_epochs = args.num_epochs

    # for tensorboard
    writer = SummaryWriter(args.log_dir)
    step = 0
    model_params = {'embed_size': embed_size, 'hidden_size': hidden_size, 'vocab_size':vocab_size, 'num_layers':num_layers}
    # initialize model, loss etc
    model = SeqToSeq(**model_params, device = device).to(device)
    criterion = nn.CrossEntropyLoss(ignore_index = vocab['stoi']["<PAD>"])
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    # Always train the replaced Inception `fc` layer; the rest of the CNN
    # is trained only when `train_CNN` is set.
    for name, param in model.encoder.inception.named_parameters():
        if "fc.weight" in name or "fc.bias" in name:
            param.requires_grad = True
        else:
            param.requires_grad = train_CNN

    # load from a saved checkpoint
    if load_model:
        step = load_checkpoint(torch.load(args.save_path), model, optimizer)

    model.train()
    best_loss, best_epoch = 10, 0
    for epoch in range(num_epochs):
        print_examples(model, device, vocab['itos'])

        for idx, (imgs, captions) in tqdm(
                enumerate(train_loader), total=len(train_loader), leave=False):
            imgs = imgs.to(device)
            captions = captions.to(device)

            outputs = model(imgs, captions[:-1])
            loss = criterion(
                outputs.reshape(-1, outputs.shape[2]), captions.reshape(-1)
            )

            writer.add_scalar("Training loss", loss.item(), global_step=step)
            step += 1

            optimizer.zero_grad()
            # Fix: `loss.backward(loss)` passed the loss as the upstream
            # gradient, scaling every gradient by the loss value.
            loss.backward()
            optimizer.step()

        # Per-epoch bookkeeping, based on the last batch's loss.
        train_loss = loss.item()
        if train_loss < best_loss:
            best_loss = train_loss
            best_epoch = epoch + 1
            if save_model:
                checkpoint = {
                    "model_params": model_params,
                    "state_dict": model.state_dict(),
                    "optimizer": optimizer.state_dict(),
                    "step": step
                }
                save_checkpoint(checkpoint, args.save_path)

        table = [["Loss:", train_loss],
                 ["Step:", step],
                 ["Epoch:", epoch + 1],
                 ["Best Loss:", best_loss],
                 ["Best Epoch:", best_epoch]]
        print(tabulate(table))
110
+
111
+
112
if __name__ == "__main__":
    # Command-line entry point: parse the options below and start training.

    parser = argparse.ArgumentParser()

    # Data, logging and checkpoint locations.
    parser.add_argument('--root_dir', type = str, default = './flickr30k/flickr30k_images', help = 'path to images folder')
    parser.add_argument('--csv_file', type = str, default = './flickr30k/results.csv', help = 'path to captions csv file')
    parser.add_argument('--log_dir', type = str, default = './drive/MyDrive/TensorBoard/', help = 'path to save tensorboard logs')
    parser.add_argument('--save_path', type = str, default = './drive/MyDrive/checkpoints/Seq2Seq.pt', help = 'path to save checkpoint')
    # Model Params
    parser.add_argument('--batch_size', type = int, default = 64)
    parser.add_argument('--num_epochs', type = int, default = 100)
    parser.add_argument('--embed_size', type = int, default=256)
    parser.add_argument('--hidden_size', type = int, default=512)
    parser.add_argument('--lr', type = float, default= 0.001)
    parser.add_argument('--num_layers', type = int, default = 3, help = 'number of lstm layers')

    args = parser.parse_args()

    main(args)
neuralnet/utils.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torchvision.transforms as transforms
3
+ from PIL import Image
4
+
5
+
6
def print_examples(model, device, vocab):
    """Print the model's caption for each bundled example image.

    The model is switched to eval mode for the predictions and put back
    into training mode before returning.

    Args:
        model: Object exposing ``caption_image(image, vocab)``.
        device: Device the image tensors are moved to.
        vocab: Vocabulary passed through to ``caption_image``.
    """
    transform = transforms.Compose(
        [transforms.Resize((299, 299)),
         transforms.ToTensor(),
         transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
    )
    # Fix: the original printed "wave.png" for surfing.png; the label is now
    # derived from the file name so the two cannot drift apart.
    example_files = ("dog.png", "dirt_bike.png", "surfing.png", "horse.png", "camera.png")

    model.eval()

    for file_name in example_files:
        image = transform(Image.open("./test_examples/" + file_name).convert("RGB")).unsqueeze(0)
        print(file_name + " PREDICTION: " + " ".join(model.caption_image(image.to(device), vocab)))

    model.train()
30
+
31
+
32
def save_checkpoint(state, filename="/content/drive/MyDrive/checkpoints/Seq2Seq.pt"):
    """Announce the save on stdout, then write ``state`` to ``filename`` with ``torch.save``."""
    print("=> Saving checkpoint")
    torch.save(obj=state, f=filename)
35
+
36
+
37
def load_checkpoint(checkpoint, model, optimizer):
    """Restore ``model`` and ``optimizer`` from ``checkpoint`` and return its "step" value."""
    print("=> Loading checkpoint")
    for target, key in ((model, "state_dict"), (optimizer, "optimizer")):
        target.load_state_dict(checkpoint[key])
    return checkpoint["step"]
requirements.txt ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ torch
2
+ torchvision
3
+ ftfy
4
+ git+https://github.com/openai/CLIP.git
5
+ regex
6
+ tqdm
7
+ streamlit
8
+ scikit-image
9
+ pillow
10
+ pandas
11
+ transformers
12
+ numpy
13
+ spacy
14
+ tqdm
15
+ tabulate
16
+ click==7.1.1
17
+ gdown
18
+ wget
19
+ altair<5
vocab.json ADDED
The diff for this file is too large to render. See raw diff