Skyy93 committed
Commit
a4fb052
1 Parent(s): 5f45e1d

Add all files

.gitignore ADDED
File without changes
app.py ADDED
@@ -0,0 +1,120 @@
+ from urllib.request import urlopen
+ import argparse
+ import clip
+ from PIL import Image
+ import pandas as pd
+ import time
+ import torch
+ from dataloader.extract_features_dataloader import transform_resize, question_preprocess
+ from model.vqa_model import NetVQA
+ from dataclasses import dataclass
+ from torch.cuda.amp import autocast
+ import gradio as gr
+
+
+ @dataclass
+ class InferenceConfig:
+     '''
+     Describes the configuration of the inference process.
+     '''
+     model: str = "RN50x64"
+     checkpoint_root_clip: str = "./checkpoints/clip"
+     checkpoint_root_head: str = "./checkpoints/head"
+
+     use_question_preprocess: bool = True  # True: replace the trailing "?" with "."
+
+     aux_mapping = {0: "unanswerable",
+                    1: "unsuitable",
+                    2: "yes",
+                    3: "no",
+                    4: "number",
+                    5: "color",
+                    6: "other"}
+     folds = 10
+     tta = False
+
+     # Data
+     n_classes: int = 5726
+
+     # class mapping
+     class_mapping: str = "./data/annotations/class_mapping.csv"
+
+     device = "cuda" if torch.cuda.is_available() else "cpu"
+
+
+ config = InferenceConfig()
+
+ # load class mapping
+ cm = pd.read_csv(config.class_mapping)
+ classid_to_answer = {}
+ for i in range(len(cm)):
+     row = cm.iloc[i]
+     classid_to_answer[row["class_id"]] = row["answer"]
+
+ clip_model, preprocess = clip.load(config.model, download_root=config.checkpoint_root_clip)
+
+ model = NetVQA(config).to(config.device)
+
+ config.checkpoint_head = "{}/{}.pt".format(config.checkpoint_root_head, config.model)
+
+ model_state_dict = torch.load(config.checkpoint_head, map_location=config.device)
+ model.load_state_dict(model_state_dict, strict=True)
+
+ #%%
+ # Select preprocessing
+ image_transforms = transform_resize(clip_model.visual.input_resolution)
+
+ if config.use_question_preprocess:
+     question_transforms = question_preprocess
+ else:
+     question_transforms = None
+
+ clip_model.eval()
+ model.eval()  # disable dropout in the classification head during inference
+
+
+ def predict(img, text):
+     img = Image.fromarray(img)
+     if config.tta:
+         # TTA path: expects image_transforms to be a list of transforms
+         image_augmentations = []
+         for transform in image_transforms:
+             image_augmentations.append(transform(img))
+         img = torch.stack(image_augmentations, dim=0)
+     else:
+         img = image_transforms(img)
+         img = img.unsqueeze(dim=0)
+
+     question = question_transforms(text) if question_transforms else text
+     question_tokens = clip.tokenize(question, truncate=True)
+     with torch.no_grad():
+         img = img.to(config.device)
+         img_feature = clip_model.encode_image(img)
+         if config.tta:
+             # weighted sum of the augmented features; requires config.features_selection to be set
+             weights = torch.tensor(config.features_selection).reshape((len(config.features_selection), 1))
+             img_feature = img_feature * weights.to(config.device)
+             img_feature = img_feature.sum(0)
+             img_feature = img_feature.unsqueeze(0)
+
+         question_tokens = question_tokens.to(config.device)
+         question_feature = clip_model.encode_text(question_tokens)
+
+         with autocast():
+             output, output_aux = model(img_feature, question_feature)
+
+     prediction_vqa = dict()
+     output = output.cpu().squeeze(0)
+     for k, v in classid_to_answer.items():
+         prediction_vqa[v] = float(output[k])
+
+     prediction_aux = dict()
+     output_aux = output_aux.cpu().squeeze(0)
+     for k, v in config.aux_mapping.items():
+         prediction_aux[v] = float(output_aux[k])
+
+     return prediction_vqa, prediction_aux
+
+
+ gr.Interface(fn=predict,
+              inputs=[gr.Image(label='Image'), gr.Textbox(label='Question')],
+              outputs=[gr.outputs.Label(label='Answer', num_top_classes=5), gr.outputs.Label(label='Answer Category', num_top_classes=7)],
+              examples=[['examples/VizWiz_train_00004056.jpg', 'Is that a beer or a coke?'], ['examples/VizWiz_train_00017146.jpg', 'Can you tell me what\'s on this envelope please?'], ['examples/VizWiz_val_00003077.jpg', 'What is this?']]
+              ).launch()
+
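As a quick sanity check, the predict function above can also be called directly in a Python session, bypassing the Gradio UI. The snippet below is a minimal sketch under the assumption that the checkpoints and the bundled example images are available locally; numpy is used only to mimic the array that gr.Image passes to fn:

import numpy as np
from PIL import Image

img = np.array(Image.open("examples/VizWiz_val_00003077.jpg"))
answers, categories = predict(img, "What is this?")
# print the five most probable answers and the full answer-category distribution
print(sorted(answers.items(), key=lambda kv: kv[1], reverse=True)[:5])
print(sorted(categories.items(), key=lambda kv: kv[1], reverse=True))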
data/annotations/class_mapping.csv ADDED
The diff for this file is too large to render. See raw diff
 
dataloader/__pycache__/extract_features_dataloader.cpython-39.pyc ADDED
Binary file (5.13 kB).
 
dataloader/extract_features_dataloader.py ADDED
@@ -0,0 +1,268 @@
+ import pandas as pd
+ import os
+ import torch
+ from PIL import Image
+ from torch.utils.data import Dataset
+ import clip
+ from torch.utils.data import DataLoader
+ import torchvision.transforms as tf
+ import torchvision.transforms.functional as TF
+
+
+ try:
+     from torchvision.transforms import InterpolationMode
+     BICUBIC = InterpolationMode.BICUBIC
+ except ImportError:
+     BICUBIC = Image.BICUBIC
+
+
+ class ExtractFeaturesDataset(Dataset):
+     def __init__(self,
+                  annotations,
+                  img_path,
+                  image_transforms=None,
+                  question_transforms=None,
+                  tta=False):
+
+         self.img_path = img_path
+         self.image_transforms = image_transforms
+         self.question_transforms = question_transforms
+
+         self.img_ids = annotations["image_id"].values
+         self.split = annotations["split"].values
+         self.questions = annotations["question"].values
+
+         self.tta = tta
+
+     def __getitem__(self, index):
+
+         image_id = self.img_ids[index]
+         split = self.split[index]
+
+         # image input
+         with open(os.path.join(self.img_path, split, image_id), "rb") as f:
+             img = Image.open(f)
+             img.load()  # read the image data while the file handle is still open
+
+         if self.tta:
+             image_augmentations = []
+             for transform in self.image_transforms:
+                 image_augmentations.append(transform(img))
+             img = torch.stack(image_augmentations, dim=0)
+         else:
+             img = self.image_transforms(img)
+
+         question = self.questions[index]
+
+         if self.question_transforms:
+             question = self.question_transforms(question)
+
+         # question input
+         question = clip.tokenize(question, truncate=True)
+         question = question.squeeze()
+
+         return img, question, image_id
+
+     def __len__(self):
+         return len(self.img_ids)
+
+
+ def _convert_image_to_rgb(image):
+     return image.convert("RGB")
+
+
+ def Sharpen(sharpness_factor=1.0):
+
+     def wrapper(x):
+         return TF.adjust_sharpness(x, sharpness_factor)
+
+     return wrapper
+
+
+ def Rotate(angle=0.0):
+
+     def wrapper(x):
+         return TF.rotate(x, angle)
+
+     return wrapper
+
+
+ def transform_crop(n_px):
+     return tf.Compose([
+         tf.Resize(n_px, interpolation=BICUBIC),
+         tf.CenterCrop(n_px),
+         _convert_image_to_rgb,
+         tf.ToTensor(),
+         tf.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
+     ])
+
+
+ def transform_crop_rotate(n_px, rotation_angle=0.0):
+     return tf.Compose([
+         Rotate(angle=rotation_angle),
+         tf.Resize(n_px, interpolation=BICUBIC),
+         tf.CenterCrop(n_px),
+         _convert_image_to_rgb,
+         tf.ToTensor(),
+         tf.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
+     ])
+
+
+ def transform_resize(n_px):
+     return tf.Compose([
+         tf.Resize((n_px, n_px), interpolation=BICUBIC),
+         _convert_image_to_rgb,
+         tf.ToTensor(),
+         tf.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
+     ])
+
+
+ def transform_resize_rotate(n_px, rotation_angle=0.0):
+     return tf.Compose([
+         Rotate(angle=rotation_angle),
+         tf.Resize((n_px, n_px), interpolation=BICUBIC),
+         _convert_image_to_rgb,
+         tf.ToTensor(),
+         tf.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
+     ])
+
+
+ def get_tta_preprocess(img_size):
+
+     img_preprocess = [
+         transform_crop(img_size),
+         transform_crop_rotate(img_size, rotation_angle=90.0),
+         transform_crop_rotate(img_size, rotation_angle=270.0),
+         transform_resize(img_size),
+         transform_resize_rotate(img_size, rotation_angle=90.0),
+         transform_resize_rotate(img_size, rotation_angle=270.0),
+     ]
+
+     return img_preprocess
+
+
+ def question_preprocess(question, debug=False):
+
+     question = question.replace("?", ".")
+
+     if question[-1] == " ":
+         question = question[:-1]
+
+     if question[-1] != ".":
+         question = question + "."
+
+     if debug:
+         print("Question:", question)
+
+     return question
+
+
+ def get_dataloader_extraction(config):
+
+     if config.use_question_preprocess:
+         print("Using custom preprocessing: Question")
+         question_transforms = question_preprocess
+     else:
+         question_transforms = None
+
+     if config.tta:
+         print("Using augmentation transforms")
+         img_preprocess = get_tta_preprocess(config.img_size)
+     else:
+         print("Using original CLIP transforms")
+         img_preprocess = transform_crop(config.img_size)
+
+     train_data = pd.read_csv(config.train_annotations_path)
+
+     train_dataset = ExtractFeaturesDataset(annotations=train_data,
+                                            img_path=config.img_path,
+                                            image_transforms=img_preprocess,
+                                            question_transforms=question_transforms,
+                                            tta=config.tta)
+
+     train_loader = DataLoader(dataset=train_dataset,
+                               batch_size=config.batch_size,
+                               shuffle=False,
+                               num_workers=config.num_workers)
+
+     test_data = pd.read_csv(config.test_annotations_path)
+
+     test_dataset = ExtractFeaturesDataset(annotations=test_data,
+                                           img_path=config.img_path,
+                                           image_transforms=img_preprocess,
+                                           question_transforms=question_transforms,
+                                           tta=config.tta)
+
+     test_loader = DataLoader(dataset=test_dataset,
+                              batch_size=config.batch_size,
+                              shuffle=False,
+                              num_workers=config.num_workers)
+
+     return train_loader, test_loader
+
+
+ def get_dataloader_inference(config):
+
+     if config.use_question_preprocess:
+         print("Using custom preprocessing: Question")
+         question_transforms = question_preprocess
+     else:
+         question_transforms = None
+
+     if config.tta:
+         print("Using augmentation transforms")
+         img_preprocess = transform_resize(config.img_size)
+     else:
+         print("Using original CLIP transforms")
+         img_preprocess = transform_crop(config.img_size)
+
+     train_data = pd.read_csv(config.train_annotations_path)
+
+     train_dataset = ExtractFeaturesDataset(annotations=train_data,
+                                            img_path=config.img_path,
+                                            image_transforms=img_preprocess,
+                                            question_transforms=question_transforms,
+                                            tta=config.tta)
+
+     train_loader = DataLoader(dataset=train_dataset,
+                               batch_size=config.batch_size,
+                               shuffle=False,
+                               num_workers=config.num_workers)
+
+     test_data = pd.read_csv(config.test_annotations_path)
+
+     test_dataset = ExtractFeaturesDataset(annotations=test_data,
+                                           img_path=config.img_path,
+                                           image_transforms=img_preprocess,
+                                           question_transforms=question_transforms,
+                                           tta=config.tta)
+
+     test_loader = DataLoader(dataset=test_dataset,
+                              batch_size=config.batch_size,
+                              shuffle=False,
+                              num_workers=config.num_workers)
+
+     return train_loader, test_loader
+
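A short usage sketch for the helpers above; the 448 px resolution is an assumed placeholder (at runtime the app takes it from clip_model.visual.input_resolution), and the example image path is reused from the app's examples list:

from PIL import Image
from dataloader.extract_features_dataloader import transform_resize, question_preprocess

print(question_preprocess("Is that a beer or a coke? "))             # -> "Is that a beer or a coke."
preprocess = transform_resize(448)                                    # assumed input resolution
tensor = preprocess(Image.open("examples/VizWiz_val_00003077.jpg"))   # tensor of shape (3, 448, 448)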
examples/VizWiz_train_00004056.jpg ADDED
examples/VizWiz_train_00017146.jpg ADDED
examples/VizWiz_val_00003077.jpg ADDED
model/__pycache__/vqa_model.cpython-39.pyc ADDED
Binary file (2.84 kB).
 
model/vqa_model.py ADDED
@@ -0,0 +1,123 @@
+ import torch
+
+
+ class HeadVQA(torch.nn.Module):
+     def __init__(self, train_config):
+         super().__init__()
+
+         embedding_size = {'RN50': 1024,
+                           'RN101': 512,
+                           'RN50x4': 640,
+                           'RN50x16': 768,
+                           'RN50x64': 1024,
+                           'ViT-B/32': 512,
+                           'ViT-B/16': 512,
+                           'ViT-L/14': 768,
+                           'ViT-L/14@336px': 768}
+
+         n_aux_classes = len(set(train_config.aux_mapping.values()))
+
+         self.ln1 = torch.nn.LayerNorm(embedding_size[train_config.model] * 2)
+         self.dp1 = torch.nn.Dropout(0.5)
+         self.fc1 = torch.nn.Linear(embedding_size[train_config.model] * 2, 512)
+
+         self.ln2 = torch.nn.LayerNorm(512)
+         self.dp2 = torch.nn.Dropout(0.5)
+         self.fc2 = torch.nn.Linear(512, train_config.n_classes)
+
+         self.fc_aux = torch.nn.Linear(512, n_aux_classes)
+         self.fc_gate = torch.nn.Linear(n_aux_classes, train_config.n_classes)
+         self.act_gate = torch.nn.Sigmoid()
+
+     def forward(self, img_features, question_features):
+         xc = torch.cat((img_features, question_features), dim=-1)
+
+         x = self.ln1(xc)
+         x = self.dp1(x)
+         x = self.fc1(x)
+
+         aux = self.fc_aux(x)
+
+         gate = self.fc_gate(aux)
+         gate = self.act_gate(gate)
+
+         x = self.ln2(x)
+         x = self.dp2(x)
+         vqa = self.fc2(x)
+
+         output = vqa * gate
+
+         return output, aux
+
+
+ class NetVQA(torch.nn.Module):
+     def __init__(self, train_config):
+         super().__init__()
+
+         self.heads = torch.nn.ModuleList()
+
+         if isinstance(train_config.folds, list):
+             self.num_heads = len(train_config.folds)
+         else:
+             self.num_heads = train_config.folds
+
+         for i in range(self.num_heads):
+             self.heads.append(HeadVQA(train_config))
+
+     def forward(self, img_features, question_features):
+
+         output = []
+         output_aux = []
+
+         for head in self.heads:
+
+             logits, logits_aux = head(img_features, question_features)
+
+             probs = logits.softmax(-1)
+             probs_aux = logits_aux.softmax(-1)
+
+             output.append(probs)
+             output_aux.append(probs_aux)
+
+         output = torch.stack(output, dim=-1).mean(-1)
+         output_aux = torch.stack(output_aux, dim=-1).mean(-1)
+
+         return output, output_aux
+
+
+ def merge_vqa(train_config):
+
+     # Initialize model
+     model = NetVQA(train_config)
+
+     for fold in train_config.folds:
+
+         print("load weights from fold {} into head {}".format(fold, fold))
+
+         checkpoint_path = "{}/{}/fold_{}".format(train_config.model_path, train_config.model, fold)
+
+         if train_config.crossvalidation:
+             # load best checkpoint
+             model_state_dict = torch.load('{}/weights_best.pth'.format(checkpoint_path))
+         else:
+             # load checkpoint on train end
+             model_state_dict = torch.load('{}/weights_end.pth'.format(checkpoint_path))
+
+         model.heads[fold].load_state_dict(model_state_dict, strict=True)
+
+     checkpoint_path = "{}/{}/weights_merged.pth".format(train_config.model_path, train_config.model)
+
+     print("Saving weights of merged model:", checkpoint_path)
+
+     torch.save(model.state_dict(), checkpoint_path)
+
+     return model
+
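To see how the ensemble behaves end to end, here is a minimal, self-contained sketch of a forward pass through NetVQA with random features. DemoConfig is a hypothetical stand-in that carries only the fields the model reads, with values mirroring the InferenceConfig in app.py (CLIP's RN50x64 produces 1024-dimensional image and text embeddings):

import torch
from model.vqa_model import NetVQA

class DemoConfig:  # hypothetical config; only the attributes NetVQA/HeadVQA read
    model = "RN50x64"
    n_classes = 5726
    folds = 10
    aux_mapping = {0: "unanswerable", 1: "unsuitable", 2: "yes", 3: "no",
                   4: "number", 5: "color", 6: "other"}

net = NetVQA(DemoConfig()).eval()
img_feat = torch.randn(1, 1024)   # stands in for clip_model.encode_image output
txt_feat = torch.randn(1, 1024)   # stands in for clip_model.encode_text output
with torch.no_grad():
    probs, probs_aux = net(img_feat, txt_feat)
print(probs.shape, probs_aux.shape)  # torch.Size([1, 5726]) torch.Size([1, 7])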