BigData-AI @ KSU committed
Commit ffb81ab · 1 Parent(s): 149c4c6

Actual BigMed model uploaded; the examples still need fixing


This commit is without the CLIP folder, as that model will be cloned from GitHub.
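As a rough sketch of the setup this implies (the clone target is an assumption; the Space may rely on a fork of openai/CLIP rather than the upstream repository), the CLIP/ package that `from CLIP import clip` expects could be fetched at startup like this:

import os
import subprocess

# Hypothetical bootstrap step: clone a CLIP repo into ./CLIP so that
# `from CLIP import clip` in MED_VQA_Huggyface_Gradio.py resolves.
# The URL is assumed; a project-specific fork may be required instead, since the
# backbone's encode_image/encode_text are expected to return three values here.
if not os.path.isdir("CLIP"):
    subprocess.run(
        ["git", "clone", "--depth", "1", "https://github.com/openai/CLIP.git", "CLIP"],
        check=True,
    )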

.gitattributes CHANGED
@@ -32,3 +32,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ PathVQA_2Decoders_1024_30iterations_Trial4_CLIPVIT32.pth.tar filter=lfs diff=lfs merge=lfs -text
.gitignore CHANGED
@@ -1 +1,3 @@
- .idea/
+ .idea/
+ __pycache__/
+ CLIP/
MED_VQA_Huggyface_Gradio.py ADDED
@@ -0,0 +1,181 @@
+ ##### VQA MED Demo
+
+ import gradio as gr
+ from transformers import ViltProcessor, ViltForQuestionAnswering
+ import torch
+ import torch.nn as nn
+ from transformers import CLIPTokenizer
+ from CLIP import clip
+ from Transformers_for_Caption import Transformer_Caption
+ import numpy as np
+ import torchvision.transforms as transforms
+
+ class Config(object):
+     def __init__(self):
+         # Learning Rates
+         # Transformer
+         self.hidden_dim = 512
+         self.pad_token_id = 0
+         self.max_position_embeddings = 76
+         self.layer_norm_eps = 1e-12
+         self.dropout = 0.1
+         self.vocab_size = 49408
+
+         self.enc_layers = 1
+         self.dec_layers = 1
+         self.dim_feedforward = 1024  # 2048
+         self.nheads = 4
+         self.pre_norm = True
+         # Dataset
+         # self.dir = os.getcwd() + '/data/coco'
+         self.limit = -1
+
+
+
+ ##### OUR MODEL
+
+ class VQA_Net(nn.Module):
+     def __init__(self, num_classes):
+         super(VQA_Net, self).__init__()
+         # self.VIT = deit_base_distilled_patch16_224(pretrained=True)
+         # self.VIT = vit_base_patch16_224_dino(pretrained=True)
+         # self.VIT = vit_base_patch32_sam_224(pretrained=True)  ###### please note that we used only 6 layers
+         # self.VIT = maxvit_rmlp_nano_rw_256(pretrained=True)
+         # self.VIT = vit_base_patch8_224(pretrained=True)
+         # self.VIT = m = tf_efficientnetv2_m(pretrained=True, features_only=True, out_indices=(1, 3), feature_location='expansion')
+         self.backbone, _ = clip.load('ViT-B/32', 'cpu', jit=False)
+         self.input_proj = nn.LayerNorm(512)  # nn.Sequential(nn.LayerNorm(768), nn.Linear(768, 768), nn.GELU(), nn.Dropout(0.1))
+         self.transformer_decoder = Transformer_Caption(config, num_decoder_layers=2)
+         self.mlp = nn.Sequential(nn.Sequential(nn.Linear(512, num_classes)))  # MLP(256, 512, 30522, 1) 49408)
+         # self.samples_proj = nn.Sequential(nn.Linear(768, 512))
+         self.samples_proj = nn.Identity()
+         self.question_proj = nn.Identity()  # nn.Sequential(nn.Linear(512, 512, bias=False))  # nn.Sequential(nn.LayerNorm(768), nn.Linear(768, 768), nn.GELU(), nn.Dropout(0.1))
+         # self.tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")
+
+     def forward(self, samples, question_in, answer_out, mask_answer):
+         # print('Here')
+         # print(samples.shape)
+         _, _, samples = self.backbone.encode_image(samples)
+
+         # samples = self.VIT(samples)
+         # print(samples.shape)
+         samples = samples.float()
+         # samples = self.VIT(samples)
+         # print(samples.shape)
+         # samples = samples.view(-1, 512, 8 * 8)
+         # print(img_seq.shape)
+         # samples = samples.permute(0, 2, 1)
+         # samples = samples[:, 0:, :] @ self.samples_proj
+         samples = self.samples_proj(samples)
+         # print(samples.shape)
+         # print(samples.shape)
+         _, _, question_in = self.backbone.encode_text(question_in)
+         # print(question_in.shape)
+         # samples = self.samples_proj(samples.float())
+         question_in = self.question_proj(question_in.float())
+         # print(question_in.shape)
+         # print(samples.shape)
+         samples = torch.cat((samples, question_in), dim=1)
+         # print(samples.shape)
+
+         # src, mask = features[-1].decompose()
+         # assert mask is not None
+         hs = self.transformer_decoder(self.input_proj(samples.permute(1, 0, 2).float()), answer_out, tgt_mask=mask_answer)
+         out = self.mlp(hs.permute(1, 0, 2))
+         # print(out.shape)
+         return out
+
+ config = Config()
+ Tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")
+ My_VQA = VQA_Net(num_classes=len(Tokenizer))
+ My_VQA.load_state_dict(torch.load("./PathVQA_2Decoders_1024_30iterations_Trial4_CLIPVIT32.pth.tar", map_location=torch.device("cuda" if torch.cuda.is_available() else "cpu")))
+
+
+ tfms = transforms.Compose([
+     # transforms.Lambda(under_max),
+     transforms.Resize((224, 224)),
+     transforms.ToTensor(),
+     transforms.Normalize(mean=[0.485, 0.456, 0.406],
+                          std=[0.229, 0.224, 0.225])
+     # transforms.Normalize(0.5, 0.5),
+ ])
+
+
+ def answer_question(image, text_question):
+     with torch.no_grad():
+         for iter in range(1):
+             start_token = Tokenizer.convert_tokens_to_ids("<|startoftext|>")
+             # end_token = Tokenizer.convert_tokens_to_ids("<|endoftext|>")
+             # start_token = tokenizer.convert_tokens_to_ids(tokenizer._cls_token)
+             caption = torch.zeros((1, config.max_position_embeddings), dtype=torch.long)
+             cap_mask = torch.ones((1, config.max_position_embeddings), dtype=torch.bool)
+             caption[:, 0] = start_token
+             cap_mask[:, 0] = False
+             print(text_question)
+             if text_question.find('?') > -1:
+                 text_question = text_question.split('?')[0].lower()
+             text_question = np.array(Tokenizer.encode_plus(text_question, max_length=77, pad_to_max_length=True, return_attention_mask=True,
+                                                            return_token_type_ids=False, truncation=True)['input_ids'])
+             # print(torch.Tensor(text_question).unsqueeze(0).long())
+             for i in range(config.max_position_embeddings - 1):
+                 predictions = My_VQA(image.unsqueeze(0), torch.Tensor(text_question).unsqueeze(0).long(), caption, cap_mask)
+                 predictions = predictions[:, i, :]
+                 predicted_id = torch.argmax(predictions, axis=-1)
+                 caption[:, i + 1] = predicted_id[0]
+                 cap_mask[:, i + 1] = False
+                 if predicted_id[0] == 49407:
+                     break
+             # print('question:')
+             # print(batch_test['question'])
+             cap_result_intermediate = Tokenizer.decode(caption[0].tolist(), skip_special_tokens=True)
+             # print('+++++++++++++++++++++++++++++++++++')
+             # print("True:")
+             # print(ref_sentence)
+             cap_result = cap_result_intermediate.split('!')
+             # ref_sentence = batch_test['answer'].lower()
+             # print(ref_sentence)
+             # print("Predict:")
+             # print(cap_result)
+             # image_disp = inv_Normalize(batch_test['image'])[0].permute(1, 2, 0).detach().cpu().numpy()
+             # print('************************')
+             # plt.imshow(image_disp)
+     return cap_result
+
+
+ def infer_answer_question(image, text):
+     if text is None:
+         cap_result = "please write a question"
+     elif image is None:
+         cap_result = "please upload an image"
+     else:
+         image_encoded = tfms(image)
+         print(image_encoded)
+         cap_result = answer_question(image_encoded, text)[0]
+
+     return cap_result
+
+
+ image = gr.inputs.Image(type="pil")
+ question = gr.inputs.Textbox(label="Question")
+ answer = gr.outputs.Textbox(label="Predicted answer")
+ examples = [["train_0000.jpg", "Where are liver stem cells (oval cells) located?"],
+             ["train_0001.jpg", "What are stained here with an immunohistochemical stain for cytokeratin 7?"],
+             ["train_0002.jpg", "What are bile duct cells and canals of Hering stained here with for cytokeratin 7?"],
+             ["train_0003.jpg", "Are bile duct cells and canals of Hering stained here with an immunohistochemical stain for cytokeratin 7?"],
+             ["train_0018.jpg", "Is there an infarct in the brain hypertrophy?"],
+             ["train_0019.jpg", "What is ischemic coagulative necrosis?"]]
+
+ title = "Interactive Visual Question Answering demo (BigMed@ai: Artificial Intelligence for Large-Scale Medical Image Analysis)"
+ description = "<div style='display: flex;align-items: center;justify-content: space-between;'><p style='width:60vw;'>Gradio demo for a medical VQA model trained on the PathVQA dataset. To use it, upload an image, type a question, and click 'submit', or click one of the examples to load them.</p><a href='https://github.com/dandelin/ViLT' target='_blank' class='link'><img src='file/GitHub.png' style='justify-self:margin-top:0.5em;center; width:calc(200px + 5vw);'></a></div>"
+ ### link to paper and github code
+ article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2102.03334' target='_blank'>BigMed@ai</a> | <a href='https://github.com/dandelin/ViLT' target='_blank'>Github Repo</a></p>"
+
+ interface = gr.Interface(fn=infer_answer_question,
+                          inputs=[image, question],
+                          outputs=answer,
+                          examples=examples,
+                          title=title,
+                          description=description,
+                          article=article,
+                          enable_queue=True)
+ interface.launch(debug=True)
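For reference, the ids hard-coded in the decoding loop above correspond to the CLIP BPE tokenizer's special tokens; a standalone sanity check (a sketch, independent of the app) would be:

# Standalone check of the special-token ids assumed by answer_question().
from transformers import CLIPTokenizer

tok = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")
print(tok.convert_tokens_to_ids("<|startoftext|>"))  # 49406, written into caption[:, 0]
print(tok.convert_tokens_to_ids("<|endoftext|>"))    # 49407, the stop condition in the loop
print(len(tok))                                      # 49408, i.e. the num_classes of the output head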
PathVQA_2Decoders_1024_30iterations_Trial4_CLIPVIT32.pth.tar ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:79262e9686303e4e8c515078b820341394b6b380382be0819c2c01d9dd9eaa51
+ size 589964081
README.md CHANGED
@@ -5,7 +5,7 @@ colorFrom: yellow
  colorTo: red
  sdk: gradio
  sdk_version: 3.15.0
- app_file: app.py
+ app_file: MED_VQA_Huggyface_Gradio.py
  pinned: false
  ---
 
Transformers_for_Caption.py ADDED
@@ -0,0 +1,364 @@
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
+ import copy
+ from typing import Optional, List
+
+ import torch
+ import torch.nn.functional as F
+ from torch import nn, Tensor
+
+
+ class Transformer_Caption(nn.Module):
+
+     def __init__(self, config, d_model=512, nhead=4, num_encoder_layers=1,
+                  num_decoder_layers=2, dim_feedforward=1024, dropout=0.1,
+                  activation="gelu", normalize_before=False,
+                  return_intermediate_dec=False):
+         super().__init__()
+
+         encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward,
+                                                 dropout, activation, normalize_before)
+         encoder_norm = nn.LayerNorm(d_model) if normalize_before else None
+         self.encoder = TransformerEncoder(
+             encoder_layer, num_encoder_layers, encoder_norm)
+
+         self.embeddings = DecoderEmbeddings(config)
+         decoder_layer = TransformerDecoderLayer(d_model, nhead, dim_feedforward,
+                                                 dropout, activation, normalize_before)
+         decoder_norm = nn.LayerNorm(d_model)
+         self.decoder = TransformerDecoder(decoder_layer, num_decoder_layers, decoder_norm,
+                                           return_intermediate=return_intermediate_dec)
+         print("Num decoders:")
+         print(num_decoder_layers)
+         self._reset_parameters()
+
+         self.d_model = d_model
+         self.nhead = nhead
+
+     def _reset_parameters(self):
+         for p in self.parameters():
+             if p.dim() > 1:
+                 nn.init.xavier_uniform_(p)
+
+     def forward(self, src, tgt, tgt_mask):
+         # flatten NxCxHxW to HWxNxC
+         # print("HERRRRRR")
+         # print(src.shape)
+         h, bs, w = src.shape
+         # src = src.permute(1, 0, 2)
+         # print("SRCCCCCCCC")
+         # print(src.shape)
+         # pos_embed = pos_embed.flatten(2).permute(2, 0, 1)
+         # mask = mask.flatten(1)
+         # print(num_decoder_layers)
+
+         tgt = self.embeddings(tgt).permute(1, 0, 2)
+         query_embed = self.embeddings.position_embeddings.weight.unsqueeze(1)
+         query_embed = query_embed.repeat(1, bs, 1)
+         # print("firstmyyyyyyyyyyyyyy")
+         # print(tgt.shape)
+         # print(tgt_mask.shape)
+         # print(pos_embed.shape)
+         # print(query_embed.shape)
+         # print(generate_square_subsequent_mask(len(tgt)).to(tgt.device).shape)
+         # print(src.shape)
+
+         # memory = self.encoder(src, src_key_padding_mask=None, pos=None)
+         # memory = self.encoder(src)
+         # print("then....")
+         # print(tgt_mask.shape)
+         hs = self.decoder(tgt, src, memory_key_padding_mask=None, tgt_key_padding_mask=tgt_mask,
+                           pos=None, query_pos=query_embed,
+                           tgt_mask=generate_square_subsequent_mask(len(tgt)).to(tgt.device))
+         # hs = self.decoder(tgt, memory, tgt_key_padding_mask=tgt_mask, query_pos=query_embed, tgt_mask=generate_square_subsequent_mask(len(tgt)).to(tgt.device))
+
+         return hs
+
+
+ class TransformerEncoder(nn.Module):
+
+     def __init__(self, encoder_layer, num_layers, norm=None):
+         super().__init__()
+         self.layers = _get_clones(encoder_layer, num_layers)
+         self.num_layers = num_layers
+         self.norm = norm
+
+     def forward(self, src,
+                 mask: Optional[Tensor] = None,
+                 src_key_padding_mask: Optional[Tensor] = None,
+                 pos: Optional[Tensor] = None):
+         output = src
+
+         for layer in self.layers:
+             output = layer(output, src_mask=mask,
+                            src_key_padding_mask=src_key_padding_mask, pos=pos)
+
+         if self.norm is not None:
+             output = self.norm(output)
+
+         return output
+
+
+ class TransformerDecoder(nn.Module):
+
+     def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False):
+         super().__init__()
+         self.layers = _get_clones(decoder_layer, num_layers)
+         self.num_layers = num_layers
+         self.norm = norm
+         self.return_intermediate = return_intermediate
+
+     def forward(self, tgt, memory,
+                 tgt_mask: Optional[Tensor] = None,
+                 memory_mask: Optional[Tensor] = None,
+                 tgt_key_padding_mask: Optional[Tensor] = None,
+                 memory_key_padding_mask: Optional[Tensor] = None,
+                 pos: Optional[Tensor] = None,
+                 query_pos: Optional[Tensor] = None):
+         output = tgt
+
+         intermediate = []
+
+         for layer in self.layers:
+             output = layer(output, memory, tgt_mask=tgt_mask,
+                            memory_mask=memory_mask,
+                            tgt_key_padding_mask=tgt_key_padding_mask,
+                            memory_key_padding_mask=memory_key_padding_mask,
+                            pos=pos, query_pos=query_pos)
+             if self.return_intermediate:
+                 intermediate.append(self.norm(output))
+
+         if self.norm is not None:
+             output = self.norm(output)
+             if self.return_intermediate:
+                 intermediate.pop()
+                 intermediate.append(output)
+
+         if self.return_intermediate:
+             return torch.stack(intermediate)
+
+         return output
+
+
+ class TransformerEncoderLayer(nn.Module):
+
+     def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
+                  activation="relu", normalize_before=False):
+         super().__init__()
+         self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
+         # Implementation of Feedforward model
+         self.linear1 = nn.Linear(d_model, dim_feedforward)
+         self.dropout = nn.Dropout(dropout)
+         self.linear2 = nn.Linear(dim_feedforward, d_model)
+
+         self.norm1 = nn.LayerNorm(d_model)
+         self.norm2 = nn.LayerNorm(d_model)
+         self.dropout1 = nn.Dropout(dropout)
+         self.dropout2 = nn.Dropout(dropout)
+
+         self.activation = _get_activation_fn(activation)
+         self.normalize_before = normalize_before
+
+     def with_pos_embed(self, tensor, pos: Optional[Tensor]):
+         return tensor if pos is None else tensor + pos
+
+     def forward_post(self,
+                      src,
+                      src_mask: Optional[Tensor] = None,
+                      src_key_padding_mask: Optional[Tensor] = None,
+                      pos: Optional[Tensor] = None):
+         q = k = self.with_pos_embed(src, pos)
+         src2 = self.self_attn(q, k, value=src, attn_mask=src_mask,
+                               key_padding_mask=src_key_padding_mask)[0]
+         src = src + self.dropout1(src2)
+         src = self.norm1(src)
+         src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
+         src = src + self.dropout2(src2)
+         src = self.norm2(src)
+         return src
+
+     def forward_pre(self, src,
+                     src_mask: Optional[Tensor] = None,
+                     src_key_padding_mask: Optional[Tensor] = None,
+                     pos: Optional[Tensor] = None):
+         src2 = self.norm1(src)
+         q = k = self.with_pos_embed(src2, pos)
+         src2 = self.self_attn(q, k, value=src2, attn_mask=src_mask,
+                               key_padding_mask=src_key_padding_mask)[0]
+         src = src + self.dropout1(src2)
+         src2 = self.norm2(src)
+         src2 = self.linear2(self.dropout(self.activation(self.linear1(src2))))
+         src = src + self.dropout2(src2)
+         return src
+
+     def forward(self, src,
+                 src_mask: Optional[Tensor] = None,
+                 src_key_padding_mask: Optional[Tensor] = None,
+                 pos: Optional[Tensor] = None):
+         if self.normalize_before:
+             return self.forward_pre(src, src_mask, src_key_padding_mask, pos)
+         return self.forward_post(src, src_mask, src_key_padding_mask, pos)
+
+
+ class TransformerDecoderLayer(nn.Module):
+
+     def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
+                  activation="relu", normalize_before=False):
+         super().__init__()
+         self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
+         self.multihead_attn = nn.MultiheadAttention(
+             d_model, nhead, dropout=dropout)
+         # Implementation of Feedforward model
+         self.linear1 = nn.Linear(d_model, dim_feedforward)
+         self.dropout = nn.Dropout(dropout)
+         self.linear2 = nn.Linear(dim_feedforward, d_model)
+
+         self.norm1 = nn.LayerNorm(d_model)
+         self.norm2 = nn.LayerNorm(d_model)
+         self.norm3 = nn.LayerNorm(d_model)
+         self.dropout1 = nn.Dropout(dropout)
+         self.dropout2 = nn.Dropout(dropout)
+         self.dropout3 = nn.Dropout(dropout)
+
+         self.activation = _get_activation_fn(activation)
+         self.normalize_before = normalize_before
+
+
+     def with_pos_embed(self, tensor, pos: Optional[Tensor]):
+         return tensor if pos is None else tensor + pos
+
+     def forward_post(self, tgt, memory,
+                      tgt_mask: Optional[Tensor] = None,
+                      memory_mask: Optional[Tensor] = None,
+                      tgt_key_padding_mask: Optional[Tensor] = None,
+                      memory_key_padding_mask: Optional[Tensor] = None,
+                      pos: Optional[Tensor] = None,
+                      query_pos: Optional[Tensor] = None):
+         # print(tgt.shape)
+         # print(query_pos.shape)
+
+         q = k = self.with_pos_embed(tgt, query_pos)
+         tgt2 = self.self_attn(q, k, value=tgt, attn_mask=tgt_mask,
+                               key_padding_mask=tgt_key_padding_mask)[0]
+         tgt = tgt + self.dropout1(tgt2)
+         tgt = self.norm1(tgt)
+         tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt, query_pos),
+                                    key=self.with_pos_embed(memory, pos),
+                                    value=memory, attn_mask=memory_mask,
+                                    key_padding_mask=memory_key_padding_mask)[0]
+         tgt = tgt + self.dropout2(tgt2)
+         tgt = self.norm2(tgt)
+         tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
+         tgt = tgt + self.dropout3(tgt2)
+         tgt = self.norm3(tgt)
+         return tgt
+
+     def forward_pre(self, tgt, memory,
+                     tgt_mask: Optional[Tensor] = None,
+                     memory_mask: Optional[Tensor] = None,
+                     tgt_key_padding_mask: Optional[Tensor] = None,
+                     memory_key_padding_mask: Optional[Tensor] = None,
+                     pos: Optional[Tensor] = None,
+                     query_pos: Optional[Tensor] = None):
+         tgt2 = self.norm1(tgt)
+         q = k = self.with_pos_embed(tgt2, query_pos)
+         tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask,
+                               key_padding_mask=tgt_key_padding_mask)[0]
+         tgt = tgt + self.dropout1(tgt2)
+         tgt2 = self.norm2(tgt)
+         tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt2, query_pos),
+                                    key=self.with_pos_embed(memory, pos),
+                                    value=memory, attn_mask=memory_mask,
+                                    key_padding_mask=memory_key_padding_mask)[0]
+         tgt = tgt + self.dropout2(tgt2)
+         tgt2 = self.norm3(tgt)
+         tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
+         tgt = tgt + self.dropout3(tgt2)
+         return tgt
+
+     def forward(self, tgt, memory,
+                 tgt_mask: Optional[Tensor] = None,
+                 memory_mask: Optional[Tensor] = None,
+                 tgt_key_padding_mask: Optional[Tensor] = None,
+                 memory_key_padding_mask: Optional[Tensor] = None,
+                 pos: Optional[Tensor] = None,
+                 query_pos: Optional[Tensor] = None):
+         if self.normalize_before:
+             return self.forward_pre(tgt, memory, tgt_mask, memory_mask,
+                                     tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos)
+         return self.forward_post(tgt, memory, tgt_mask, memory_mask,
+                                  tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos)
+
+
+ class DecoderEmbeddings(nn.Module):
+     def __init__(self, config):
+         super().__init__()
+         self.word_embeddings = nn.Embedding(
+             config.vocab_size, config.hidden_dim, padding_idx=config.pad_token_id)
+         self.position_embeddings = nn.Embedding(
+             config.max_position_embeddings, config.hidden_dim
+         )
+
+         self.LayerNorm = torch.nn.LayerNorm(
+             config.hidden_dim, eps=config.layer_norm_eps)
+         self.dropout = nn.Dropout(config.dropout)
+
+     def forward(self, x):
+         input_shape = x.size()
+         x = x.long()
+         # print(x.shape)
+         seq_length = input_shape[1]
+         device = x.device
+
+         position_ids = torch.arange(
+             seq_length, dtype=torch.long, device=device)
+         position_ids = position_ids.unsqueeze(0).expand(input_shape)
+         input_embeds = self.word_embeddings(x)
+         position_embeds = self.position_embeddings(position_ids)
+
+
+         embeddings = input_embeds + position_embeds
+         embeddings = self.LayerNorm(embeddings)
+         embeddings = self.dropout(embeddings)
+
+         # print(embeddings)
+
+         return embeddings
+
+
+ def _get_clones(module, N):
+     return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
+
+
+ def _get_activation_fn(activation):
+     """Return an activation function given a string"""
+     if activation == "relu":
+         return F.relu
+     if activation == "gelu":
+         return F.gelu
+     if activation == "glu":
+         return F.glu
+     raise RuntimeError(F"activation should be relu/gelu/glu, not {activation}.")
+
+
+ def generate_square_subsequent_mask(sz):
+     r"""Generate a square mask for the sequence. The masked positions are filled with float('-inf').
+     Unmasked positions are filled with float(0.0).
+     """
+     mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
+     mask = mask.float().masked_fill(mask == 0, float(
+         '-inf')).masked_fill(mask == 1, float(0.0))
+     return mask
+
+
+ def build_transformer(config):
+     return Transformer_Caption(
+         config,
+         d_model=config.hidden_dim,
+         dropout=config.dropout,
+         nhead=config.nheads,
+         dim_feedforward=config.dim_feedforward,
+         num_encoder_layers=config.enc_layers,
+         num_decoder_layers=config.dec_layers,
+         normalize_before=config.pre_norm,
+         return_intermediate_dec=False,
+     )
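A minimal shape check for the caption decoder added above (a sketch; Cfg is a stand-in that mirrors the Config fields DecoderEmbeddings reads, and the dummy memory length is arbitrary):

# Sketch: exercise Transformer_Caption with dummy inputs to confirm tensor shapes.
import torch
from Transformers_for_Caption import Transformer_Caption, generate_square_subsequent_mask

class Cfg:
    # Mirrors the fields of Config in MED_VQA_Huggyface_Gradio.py used by the embeddings.
    vocab_size = 49408
    hidden_dim = 512
    pad_token_id = 0
    max_position_embeddings = 76
    layer_norm_eps = 1e-12
    dropout = 0.1

model = Transformer_Caption(Cfg(), num_decoder_layers=2).eval()

memory = torch.randn(60, 1, 512)                # (src_len, batch, d_model), e.g. image+question tokens
caption = torch.zeros(1, 76, dtype=torch.long)  # token ids, batch-first as in the app
cap_mask = torch.ones(1, 76, dtype=torch.bool)  # True marks padding positions
cap_mask[:, 0] = False                          # only the start token is "real" so far

with torch.no_grad():
    hs = model(memory, caption, cap_mask)
print(hs.shape)                                 # torch.Size([76, 1, 512])

print(generate_square_subsequent_mask(4))       # the causal mask applied inside the decoder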
requirements.txt CHANGED
@@ -1,2 +1,5 @@
- torch == 1.13.1
- transformers == 4.25.1
+ clip-by-openai
+ transformers
+ torch
+ numpy
+ ftfy
train_0000.jpg ADDED
train_0001.jpg ADDED
train_0002.jpg ADDED
train_0003.jpg ADDED
train_0004.jpg ADDED
train_0018.jpg ADDED
train_0019.jpg ADDED
train_0020.jpg ADDED
train_0021.jpg ADDED
train_0022.jpg ADDED