Jyothirmai committed
Commit 12ff552
1 Parent(s): 75d178b

Upload vitGPT.py

Files changed (1)
  1. vitGPT.py +349 -0
vitGPT.py ADDED
@@ -0,0 +1,349 @@
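"""Image captioning with a ViT encoder and a GPT-2 style decoder.

A timm ViT-B/16 (`vit_base_patch16_224`) encodes the image into patch features,
GPT-2 style decoder blocks attend to those features through cross-attention,
pretrained GPT-2 small weights are copied into the language-model side, and a
fine-tuned checkpoint ("captioner.pt") is loaded so that `generate_caption` can
decode a caption for an input image.
"""
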
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from timm import create_model, list_models
from types import SimpleNamespace
from transformers import GPT2LMHeadModel, GPT2TokenizerFast, get_linear_schedule_with_warmup
import albumentations as A
from albumentations.pytorch import ToTensorV2
from PIL import Image
from pathlib import Path
from sklearn.model_selection import train_test_split
from torch.cuda.amp import GradScaler, autocast
from tqdm.auto import tqdm
import gc
import json

class GPT2Attention(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.embed_dim = config.embed_dim
        self.n_heads = config.num_heads
        assert self.embed_dim % self.n_heads == 0, 'embedding dimension must be divisible by number of heads'
        self.head_size = self.embed_dim // self.n_heads
        self.seq_len = config.seq_len

        # Fused projection producing q, k and v, matching GPT-2's c_attn layer.
        self.c_attn = nn.Linear(self.embed_dim, self.head_size * self.n_heads * 3, bias=True)
        self.scale = self.head_size ** -0.5

        # Lower-triangular causal mask for autoregressive decoding.
        self.register_buffer('mask', torch.tril(torch.ones(1, 1, self.seq_len, self.seq_len)))

        self.c_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=True)

        self.attn_dropout = nn.Dropout(config.attention_dropout)
        self.resid_dropout = nn.Dropout(config.residual_dropout)

    def forward(self, x):
        b, t, c = x.shape
        # q, k, v shape individually: batch_size x seq_len x embed_dim
        # qk_t = q @ k_t, where q = b x t x head_dim and k_t = b x head_dim x t
        q, k, v = self.c_attn(x).chunk(3, dim=-1)
        q = q.view(b, t, self.n_heads, self.head_size).permute(0, 2, 1, 3)  # batch x n_heads x seq_len x head_dim
        k = k.view(b, t, self.n_heads, self.head_size).permute(0, 2, 1, 3)
        v = v.view(b, t, self.n_heads, self.head_size).permute(0, 2, 1, 3)

        qk_t = (q @ k.transpose(-2, -1)) * self.scale
        qk_t = qk_t.masked_fill(self.mask[:, :, :t, :t] == 0, float('-inf'))
        qk_t = F.softmax(qk_t, dim=-1)
        weights = self.attn_dropout(qk_t)

        attention = weights @ v  # batch x n_heads x t x head_size
        attention = attention.permute(0, 2, 1, 3).contiguous().view(b, t, c)  # batch x t x embed_dim

        out = self.c_proj(attention)
        out = self.resid_dropout(out)

        return out

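# Illustrative shape check for the causal self-attention block (commented out;
# it uses model_config, which is defined towards the end of this file):
#
#   _attn = GPT2Attention(model_config)
#   _attn(torch.randn(2, 16, model_config.embed_dim)).shape   # -> torch.Size([2, 16, 768])
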
class GPT2CrossAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.embed_dim = config.embed_dim
        self.n_heads = config.num_heads
        assert self.embed_dim % self.n_heads == 0, 'embedding dimension must be divisible by number of heads'
        self.head_size = self.embed_dim // self.n_heads
        self.seq_len = config.seq_len

        self.q = nn.Linear(self.embed_dim, self.embed_dim)
        self.k = nn.Linear(self.embed_dim, self.embed_dim)
        self.v = nn.Linear(self.embed_dim, self.embed_dim)
        self.scale = self.head_size ** -0.5

        self.c_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=True)

        self.attn_dropout = nn.Dropout(config.attention_dropout)
        self.resid_dropout = nn.Dropout(config.residual_dropout)

        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)

    def forward(self, q, k, v):
        b, t, c = q.shape

        q = self.q(q)
        k = self.k(k)
        v = self.v(v)

        q = q.view(b, q.size(1), self.n_heads, self.head_size).permute(0, 2, 1, 3)  # batch x n_heads x seq_len x head_dim
        k = k.view(b, k.size(1), self.n_heads, self.head_size).permute(0, 2, 1, 3)
        v = v.view(b, v.size(1), self.n_heads, self.head_size).permute(0, 2, 1, 3)

        qk_t = (q @ k.transpose(-2, -1)) * self.scale
        qk_t = F.softmax(qk_t, dim=-1)
        weights = self.attn_dropout(qk_t)

        attention = weights @ v  # batch x n_heads x t x head_size
        attention = attention.permute(0, 2, 1, 3).contiguous().view(b, t, c)  # batch x t x embed_dim

        out = self.c_proj(attention)
        out = self.resid_dropout(out)

        return out

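# GPT2CrossAttention differs from GPT2Attention above in three ways: it uses separate
# q/k/v projections, it applies no causal mask (the decoder may attend to every image
# patch), and it is initialised from scratch, since pretrained GPT-2 has no
# cross-attention weights to copy (see the ignore list in VisionGPT2Model.from_pretrained).
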
class GPT2MLP(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.embed_dim = config.embed_dim
        self.mlp_ratio = config.mlp_ratio
        self.mlp_dropout = config.mlp_dropout

        self.c_fc = nn.Linear(self.embed_dim, self.embed_dim * self.mlp_ratio)
        self.c_proj = nn.Linear(self.embed_dim * self.mlp_ratio, self.embed_dim)
        self.act = nn.GELU()
        self.dropout = nn.Dropout(self.mlp_dropout)

    def forward(self, x):
        x = self.c_fc(x)
        x = self.act(x)
        x = self.c_proj(x)
        x = self.dropout(x)
        return x

class GPT2Block(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.embed_dim = config.embed_dim
        self.ln_1 = nn.LayerNorm(self.embed_dim)
        self.attn = GPT2Attention(config)
        self.ln_2 = nn.LayerNorm(self.embed_dim)
        self.mlp = GPT2MLP(config)
        self.ln_3 = nn.LayerNorm(self.embed_dim)
        self.cross_attn = GPT2CrossAttention(config)

    def forward(self, x, enc_out):
        x = x + self.attn(self.ln_1(x))
        x = x + self.cross_attn(self.ln_2(x), enc_out, enc_out)
        x = x + self.mlp(self.ln_3(x))
        return x

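# Illustrative shape check for one decoder block (commented out; as above it relies
# on model_config from the end of this file):
#
#   _block = GPT2Block(model_config)
#   _txt = torch.randn(2, 16, model_config.embed_dim)    # caption token hidden states
#   _img = torch.randn(2, 197, model_config.embed_dim)   # ViT-B/16 features: 196 patches + CLS
#   _block(_txt, _img).shape                             # -> torch.Size([2, 16, 768])
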
class VisionGPT2Model(nn.Module):
    def __init__(self, config):
        super().__init__()

        self.config = config
        print(torch.cuda.is_available())
        # Image side: reuse a pretrained ViT-B/16 patch embedding, CLS token,
        # positional embedding and the first `config.depth` transformer blocks.
        vit = create_model('vit_base_patch16_224', pretrained=True, num_classes=0)
        self.patch_embed = vit.patch_embed
        num_patches = self.patch_embed.num_patches

        self.cls_token = vit.cls_token
        embed_len = num_patches + vit.num_prefix_tokens
        self.pos_embed = vit.pos_embed
        self.pos_drop = nn.Dropout(p=0.)

        self.blocks = nn.ModuleList([vit.blocks[i] for i in range(config.depth)])

        # Text side: GPT-2 style token/position embeddings plus `config.depth` decoder blocks.
        self.transformer = nn.ModuleDict(dict(
            wte = nn.Embedding(config.vocab_size, config.embed_dim),
            wpe = nn.Embedding(config.seq_len, config.embed_dim),
            drop = nn.Dropout(config.emb_dropout),
            h = nn.ModuleList([GPT2Block(config) for _ in range(config.depth)]),
            ln_f = nn.LayerNorm(config.embed_dim)
        ))
        self.lm_head = nn.Linear(config.embed_dim, config.vocab_size, bias=False)
        # Tie the token embedding and the output projection, as in GPT-2.
        self.transformer.wte.weight = self.lm_head.weight

    def _pos_embed(self, x):
        pos_embed = self.pos_embed
        x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
        x = x + pos_embed
        return self.pos_drop(x)

    def pretrained_layers_trainable(self, trainable=False):
        layers = [
            self.cls_token, self.patch_embed, self.pos_embed, self.blocks,
            self.transformer.wte, self.transformer.wpe,
            self.transformer.ln_f, self.lm_head
        ]
        gpt_layers = [[
            self.transformer.h[i].ln_1, self.transformer.h[i].ln_2,
            self.transformer.h[i].attn, self.transformer.h[i].mlp
        ] for i in range(self.config.depth)]
        for l in gpt_layers:
            layers.extend(l)

        for layer in layers:
            if not isinstance(layer, nn.Parameter):
                for p in layer.parameters():
                    p.requires_grad = trainable
            else:
                layer.requires_grad = trainable

        total_frozen_params = sum([p.numel() for p in self.parameters() if not p.requires_grad])
        print(f'{total_frozen_params=}')

    def unfreeze_gpt_layers(self):
        gpt_layers = [[
            self.transformer.h[i].ln_1, self.transformer.h[i].ln_2,
            self.transformer.h[i].attn, self.transformer.h[i].mlp
        ] for i in range(self.config.depth)]
        flatten = []
        for l in gpt_layers:
            flatten.extend(l)

        for layer in flatten:
            if not isinstance(layer, nn.Parameter):
                for p in layer.parameters():
                    p.requires_grad = True
            else:
                layer.requires_grad = True

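    # The two helpers above are presumably used at training time: freeze every
    # pretrained ViT/GPT-2 layer first, then selectively unfreeze the GPT-2 blocks, e.g.
    #
    #   model.pretrained_layers_trainable(trainable=False)
    #   model.unfreeze_gpt_layers()
    #
    # Neither is called in this inference script; the fine-tuned checkpoint is simply loaded below.
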
    @classmethod
    def from_pretrained(cls, config):
        model = cls(config)
        sd = model.state_dict()
        keys = sd.keys()
        ignore_matches = ['blocks.', 'cross_attn.', 'ln_3', 'cls_token', 'pos_embed', 'patch_embed.', '.attn.mask']
        vit_keys = [key for key in keys if any(match in key for match in ignore_matches)]
        gpt_keys = [key for key in keys if key not in vit_keys]

        gpt2_small = GPT2LMHeadModel.from_pretrained('gpt2')
        sd_hf = gpt2_small.state_dict()
        hf_keys = sd_hf.keys()
        hf_keys = [k for k in hf_keys if not k.endswith('.attn.masked_bias')]
        hf_keys = [k for k in hf_keys if not k.endswith('.attn.bias')]
        transposed = ['attn.c_attn.weight', 'attn.c_proj.weight', 'mlp.c_fc.weight', 'mlp.c_proj.weight']

        for k in hf_keys:
            if any(match in k for match in ignore_matches):
                continue
            if any(k.endswith(w) for w in transposed):
                assert sd_hf[k].shape[::-1] == sd[k].shape
                with torch.no_grad():
                    sd[k].copy_(sd_hf[k].t())
            else:
                assert sd_hf[k].shape == sd[k].shape
                with torch.no_grad():
                    sd[k].copy_(sd_hf[k])

        model.load_state_dict(sd)

        return model

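    # Note on the `transposed` keys above: Hugging Face's GPT-2 implements c_attn,
    # c_proj and c_fc as Conv1D modules, whose weight matrices are stored as the
    # transpose of nn.Linear's, hence the .t() when copying them into this model.
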
    def forward(self, image, input_ids, labels=None):

        # Encode the image into patch features.
        image = self.patch_embed(image)
        image = self._pos_embed(image)

        # Embed the caption tokens; input_ids: batch x seq_len.
        token_embeddings = self.transformer.wte(input_ids)
        pos_embs = torch.arange(0, input_ids.size(1), device=input_ids.device)  # keep positions on the tokens' device
        positional_embeddings = self.transformer.wpe(pos_embs)
        input_ids = self.transformer.drop(token_embeddings + positional_embeddings)

        # Run ViT blocks and decoder blocks in lockstep; each decoder block
        # cross-attends to the current image features.
        for i in range(self.config.depth):
            image = self.blocks[i](image)
            input_ids = self.transformer.h[i](input_ids, image)

        input_ids = self.transformer.ln_f(input_ids)

        if labels is not None:
            # Training: cross-entropy loss over the full sequence.
            lm_logits = self.lm_head(input_ids)
            loss = F.cross_entropy(lm_logits.view(-1, lm_logits.shape[-1]), labels.view(-1))
            return loss

        # Inference: only the logits for the last position are needed.
        lm_logits = self.lm_head(input_ids[:, [-1], :])
        return lm_logits

    @torch.no_grad()  # decoding only; no gradients are needed
    def generate(self, image, sequence, tokenizer, max_tokens=50, temperature=1.0, deterministic=False):
        for _ in range(max_tokens):
            out = self(image, sequence)
            out = out[:, -1, :] / temperature
            probs = F.softmax(out, dim=-1)
            if deterministic:
                next_token = torch.argmax(probs, dim=-1, keepdim=True)
            else:
                next_token = torch.multinomial(probs, num_samples=1)
            sequence = torch.cat([sequence, next_token], dim=1)
            if next_token.item() == tokenizer.eos_token_id:
                break

        return sequence.cpu().flatten()


model_config = SimpleNamespace(
    vocab_size = 50_257,
    embed_dim = 768,
    num_heads = 12,
    seq_len = 1024,
    depth = 12,
    attention_dropout = 0.1,
    residual_dropout = 0.1,
    mlp_ratio = 4,
    mlp_dropout = 0.1,
    emb_dropout = 0.1,
    device = 'cpu'
)

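# These hyperparameters mirror GPT-2 small (12 layers, 12 heads, 768-dim embeddings,
# 1024-token context, 50,257-token vocabulary), which is what lets from_pretrained
# copy the Hugging Face 'gpt2' weights tensor-for-tensor.
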
model = VisionGPT2Model.from_pretrained(model_config)
model.load_state_dict(torch.load("captioner.pt", map_location='cpu'))  # Use 'cuda' if you have a GPU
model.eval()  # Set the model to evaluation mode


def generate_caption(image, max_tokens=50, temperature=0.9, deterministic=True):
    tokenizer = GPT2TokenizerFast.from_pretrained('gpt2')
    tokenizer.pad_token = tokenizer.eos_token

    gen_tfms = A.Compose([
        A.Resize(224, 224),
        A.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], always_apply=True),
        ToTensorV2()
    ])

    # Convert to RGB so grayscale inputs (e.g. single-channel X-ray PNGs) get the
    # three channels the ViT patch embedding expects.
    image = Image.open(image).convert('RGB')
    image = np.array(image)
    image = gen_tfms(image=image)['image']
    image = image.unsqueeze(0)
    # Start the caption with the BOS token.
    sequence = torch.ones(1, 1).long() * tokenizer.bos_token_id

    caption = model.generate(
        image,
        sequence,
        tokenizer,
        max_tokens=max_tokens,
        temperature=temperature,
        deterministic=deterministic,
    )
    caption = tokenizer.decode(caption.numpy(), skip_special_tokens=True)
    print(caption)
    return caption

image = "/Users/jkottu/Desktop/image-captioning-chest-xrays/sample_images/CXR191_IM-0591-1001.png"
generate_caption(image)
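
# Example of sampled (non-greedy) decoding with the same helper (commented out):
#
#   generate_caption(image, max_tokens=60, temperature=0.8, deterministic=False)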