hasibzunair committed on
Commit
46fdf2a
1 Parent(s): 93775f8

inital files

000001.jpg ADDED
000006.jpg ADDED
000009.jpg ADDED
app.py ADDED
@@ -0,0 +1,97 @@
+ import os
+ import torch
+ import gradio as gr
+ import argparse
+ import time
+ import torch
+ import torch.nn as nn
+ import torch.optim as optim
+
+ from tqdm import tqdm
+ from PIL import Image
+ from torch.utils.data import DataLoader
+ from PIL import Image
+ from torchvision import transforms
+
+ from pipeline.resnet_csra import ResNet_CSRA
+ from pipeline.vit_csra import VIT_B16_224_CSRA, VIT_L16_224_CSRA, VIT_CSRA
+ from pipeline.dataset import DataSet
+ from torchvision.transforms import transforms
+ from utils.evaluation.eval import voc_classes, wider_classes, coco_classes, class_dict
+
+ torch.manual_seed(0)
+
+ if torch.cuda.is_available():
+     torch.backends.cudnn.deterministic = True
+
+ # Device
+ DEVICE = "cpu"
+ print(DEVICE)
+
+ # Make directories
+ os.system("mkdir ./models")
+
+ # Get model weights
+ if not os.path.exists("./models/msl_c_voc.pth"):
+     os.system(
+         "wget -O ./models/msl_c_voc.pth https://github.com/hasibzunair/msl-recognition/releases/download/v1.0-models/msl_c_voc.pth"
+     )
+
+ # Load model
+ model = ResNet_CSRA(num_heads=1, lam=0.1, num_classes=20)
+ normalize = transforms.Normalize(mean=[0, 0, 0], std=[1, 1, 1])
+ model.to(DEVICE)
+ print("Loading weights from {}".format("./models/msl_c_voc.pth"))
+ model.load_state_dict(torch.load("./models/msl_c_voc.pth"))
+
+ # Inference!
+ def inference(img_path):
+     # read image
+     image = Image.open(img_path).convert("RGB")
+
+     # image pre-process
+     transforms_image = transforms.Compose([
+         transforms.Resize((448, 448)),
+         transforms.ToTensor(),
+         normalize
+     ])
+
+     image = transforms_image(image)
+     image = image.unsqueeze(0)
+
+     # Predict
+     result = []
+     model.eval()
+     with torch.no_grad():
+         image = image.to(DEVICE)
+         logit = model(image).squeeze(0)
+         logit = nn.Sigmoid()(logit)
+
+         pos = torch.where(logit > 0.5)[0].cpu().numpy()
+         for k in pos:
+             result.append(str(class_dict["voc07"][k]))
+     return result
+
+
+ # Define ins outs placeholders
+ inputs = gr.inputs.Image(type="filepath", label="Input Image")
+
+ # Define style
+ title = "Learning to Recognize Occluded and Small Objects with Partial Inputs"
+ description = "TBA."
+ article = "<p style='text-align: center'><a href='https://arxiv.org/abs/1512.03385' target='_blank'>Learning to Recognize Occluded and Small Objects with Partial Inputs</a> | <a href='https://github.com/hasibzunair/msl-recognition' target='_blank'>Github Repo</a></p>"
+
+ voc_classes = ("aeroplane", "bicycle", "bird", "boat", "bottle",
+                "bus", "car", "cat", "chair", "cow", "diningtable",
+                "dog", "horse", "motorbike", "person", "pottedplant",
+                "sheep", "sofa", "train", "tvmonitor")
+
+ # Run inference
+ gr.Interface(inference,
+              inputs,
+              outputs="text",
+              examples=["demo_images/000001.jpg", "demo_images/000006.jpg", "demo_images/000009.jpg"],
+              title=title,
+              description=description,
+              article=article,
+              analytics_enabled=False).launch()
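For reference, a minimal sketch of calling the demo's prediction function directly, without the Gradio UI (assuming one of the demo images added in this commit, e.g. 000001.jpg, sits next to app.py):

labels = inference("000001.jpg")  # list of VOC class names whose sigmoid score exceeds 0.5
print(labels)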
pipeline/csra.py ADDED
@@ -0,0 +1,55 @@
+ import torch
+ import torch.nn as nn
+
+
+
+ class CSRA(nn.Module): # one basic block
+     def __init__(self, input_dim, num_classes, T, lam):
+         super(CSRA, self).__init__()
+         self.T = T      # temperature
+         self.lam = lam  # Lambda
+         self.head = nn.Conv2d(input_dim, num_classes, 1, bias=False)
+         self.softmax = nn.Softmax(dim=2)
+
+     def forward(self, x):
+         # x (B d H W)
+         # normalize classifier
+         # score (B C HxW)
+         score = self.head(x) / torch.norm(self.head.weight, dim=1, keepdim=True).transpose(0,1)
+         score = score.flatten(2)
+         base_logit = torch.mean(score, dim=2)
+
+         if self.T == 99: # max-pooling
+             att_logit = torch.max(score, dim=2)[0]
+         else:
+             score_soft = self.softmax(score * self.T)
+             # https://github.com/Kevinz-code/CSRA/issues/5
+             att_logit = torch.sum(score * score_soft, dim=2)
+
+         return base_logit + self.lam * att_logit
+
+
+
+
+ class MHA(nn.Module):  # multi-head attention
+     temp_settings = {   # softmax temperature settings
+         1: [1],
+         2: [1, 99],
+         4: [1, 2, 4, 99],
+         6: [1, 2, 3, 4, 5, 99],
+         8: [1, 2, 3, 4, 5, 6, 7, 99]
+     }
+
+     def __init__(self, num_heads, lam, input_dim, num_classes):
+         super(MHA, self).__init__()
+         self.temp_list = self.temp_settings[num_heads]
+         self.multi_head = nn.ModuleList([
+             CSRA(input_dim, num_classes, self.temp_list[i], lam)
+             for i in range(num_heads)
+         ])
+
+     def forward(self, x):
+         logit = 0.
+         for head in self.multi_head:
+             logit += head(x)
+         return logit
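For reference, a minimal shape-checking sketch of the module above (it assumes a ResNet-101 feature map of size (B, 2048, 14, 14) and the 20 VOC classes used elsewhere in this Space; the input tensor is random, for illustration only):

import torch
from pipeline.csra import MHA

head = MHA(num_heads=1, lam=0.1, input_dim=2048, num_classes=20)
feat = torch.randn(2, 2048, 14, 14)   # backbone feature map (B, d, H, W)
logits = head(feat)                   # (B, 20): base_logit + lam * att_logit, summed over heads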
pipeline/dataset.py ADDED
@@ -0,0 +1,255 @@
+ import json
+ import glob
+ import random
+
+ from torch.utils.data import Dataset
+ from PIL import Image
+ from torchvision.transforms import transforms
+ import torch
+ import numpy as np
+
+ try:
+     from torchvision.transforms import InterpolationMode
+
+     BICUBIC = InterpolationMode.BICUBIC
+ except ImportError:
+     BICUBIC = Image.BICUBIC
+
+
+ # modify for transformation for vit
+ # modfify wider crop-person images
+
+
+ ###### Base data loader ######
+ class DataSet(Dataset):
+     def __init__(
+         self,
+         ann_files,
+         augs,
+         img_size,
+         dataset,
+     ):
+         self.dataset = dataset
+         self.ann_files = ann_files
+         self.augment = self.augs_function(augs, img_size)
+         self.transform = transforms.Compose(
+             [transforms.ToTensor(), transforms.Normalize(mean=[0, 0, 0], std=[1, 1, 1])]
+             # In this paper, we normalize the image data to [0, 1]
+             # You can also use the so called 'ImageNet' Normalization method
+         )
+         self.anns = []
+         self.load_anns()
+         print(self.augment)
+
+         # in wider dataset we use vit models
+         # so transformation has been changed
+         if self.dataset == "wider":
+             self.transform = transforms.Compose(
+                 [
+                     transforms.ToTensor(),
+                     transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
+                 ]
+             )
+
+     def augs_function(self, augs, img_size):
+         t = []
+         if "randomflip" in augs:
+             t.append(transforms.RandomHorizontalFlip())
+         if "ColorJitter" in augs:
+             t.append(
+                 transforms.ColorJitter(
+                     brightness=0.5, contrast=0.5, saturation=0.5, hue=0
+                 )
+             )
+         if "resizedcrop" in augs:
+             t.append(transforms.RandomResizedCrop(img_size, scale=(0.7, 1.0)))
+         if "RandAugment" in augs:
+             t.append(RandAugment())
+
+         t.append(transforms.Resize((img_size, img_size)))
+
+         return transforms.Compose(t)
+
+     def load_anns(self):
+         self.anns = []
+         for ann_file in self.ann_files:
+             json_data = json.load(open(ann_file, "r"))
+             self.anns += json_data
+
+     def __len__(self):
+         return len(self.anns)
+
+     def __getitem__(self, idx):
+         idx = idx % len(self)
+         ann = self.anns[idx]
+         img = Image.open(ann["img_path"]).convert("RGB")
+
+         if self.dataset == "wider":
+             x, y, w, h = ann["bbox"]
+             img_area = img.crop([x, y, x + w, y + h])
+             img_area = self.augment(img_area)
+             img_area = self.transform(img_area)
+             message = {
+                 "img_path": ann["img_path"],
+                 "target": torch.Tensor(ann["target"]),
+                 "img": img_area,
+             }
+         else:  # voc and coco
+             img = self.augment(img)
+             img = self.transform(img)
+             message = {
+                 "img_path": ann["img_path"],
+                 "target": torch.Tensor(ann["target"]),
+                 "img": img,
+             }
+
+         return message
+         # finally, if we use dataloader to get the data, we will get
+         # {
+         #     "img_path": list,   # length = batch_size
+         #     "target": Tensor,   # shape: batch_size * num_classes
+         #     "img": Tensor,      # shape: batch_size * 3 * 224 * 224
+         # }
+
+
+ def preprocess_scribble(img, img_size):
+     transform = transforms.Compose(
+         [
+             transforms.Resize(img_size, BICUBIC),
+             transforms.CenterCrop(img_size),
+             #_convert_image_to_rgb,
+             transforms.ToTensor(),
+         ]
+     )
+     return transform(img)
+
+
+ class DataSetMaskSup(Dataset):
+     """
+     Data loader with scribbles.
+     """
+     def __init__(
+         self,
+         ann_files,
+         augs,
+         img_size,
+         dataset,
+     ):
+         self.dataset = dataset
+         self.ann_files = ann_files
+         self.img_size = img_size
+         self.augment = self.augs_function(augs, img_size)
+         self.transform = transforms.Compose(
+             [transforms.ToTensor(), transforms.Normalize(mean=[0, 0, 0], std=[1, 1, 1])]
+             # In this paper, we normalize the image data to [0, 1]
+             # You can also use the so called 'ImageNet' Normalization method
+         )
+         self.anns = []
+         self.load_anns()
+         print(self.augment)
+
+         # scribbles
+         self._scribbles_folder = "./datasets/SCRIBBLES"
+
+         # Type of masks to use, this is hardcoded since we find that high masks
+         # work better in MSL. See paper for details.
+
+         # for low masks
+         # self._scribbles = sorted(glob.glob(self._scribbles_folder + "/*.png"))[
+         #     :1000
+         # ]
+
+         # for high masks
+         self._scribbles = sorted(glob.glob(self._scribbles_folder + "/*.png"))[::-1][
+             :1000
+         ]
+
+         # in wider dataset we use vit models
+         # so transformation has been changed
+         if self.dataset == "wider":
+             self.transform = transforms.Compose(
+                 [
+                     transforms.ToTensor(),
+                     transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
+                 ]
+             )
+
+     def augs_function(self, augs, img_size):
+         t = []
+         if "randomflip" in augs:
+             t.append(transforms.RandomHorizontalFlip())
+         if "ColorJitter" in augs:
+             t.append(
+                 transforms.ColorJitter(
+                     brightness=0.5, contrast=0.5, saturation=0.5, hue=0
+                 )
+             )
+         if "resizedcrop" in augs:
+             t.append(transforms.RandomResizedCrop(img_size, scale=(0.7, 1.0)))
+         if "RandAugment" in augs:
+             t.append(RandAugment())
+
+         t.append(transforms.Resize((img_size, img_size)))
+
+         return transforms.Compose(t)
+
+     def load_anns(self):
+         self.anns = []
+         for ann_file in self.ann_files:
+             json_data = json.load(open(ann_file, "r"))
+             self.anns += json_data
+
+     def __len__(self):
+         return len(self.anns)
+
+     def __getitem__(self, idx):
+         idx = idx % len(self)
+         ann = self.anns[idx]
+         img = Image.open(ann["img_path"]).convert("RGB")
+
+         # get scribble
+         scribble_path = self._scribbles[
+             random.randint(0, 950)
+         ]
+         scribble = Image.open(scribble_path).convert('P')
+         scribble = preprocess_scribble(scribble, self.img_size)
+
+         scribble_t = (scribble > 0).float()  # threshold to [0,1]
+         inv_scribble = (torch.max(scribble_t) - scribble_t)  # inverted scribble
+
+         if self.dataset == "wider":
+             x, y, w, h = ann["bbox"]
+             img_area = img.crop([x, y, x + w, y + h])
+             img_area = self.augment(img_area)
+             img_area = self.transform(img_area)
+
+             # masked image
+             masked_image = img_area * inv_scribble
+             message = {
+                 "img_path": ann["img_path"],
+                 "target": torch.Tensor(ann["target"]),
+                 "img": img_area,
+                 "masked_img": masked_image,
+                 #"scribble": inv_scribble,
+             }
+         else:  # voc and coco
+             img = self.augment(img)
+             img = self.transform(img)
+             # masked image
+             masked_image = img * inv_scribble
+             message = {
+                 "img_path": ann["img_path"],
+                 "target": torch.Tensor(ann["target"]),
+                 "img": img,
+                 "masked_img": masked_image,
+                 #"scribble": inv_scribble,
+             }
+
+         return message
+         # finally, if we use dataloader to get the data, we will get
+         # {
+         #     "img_path": list,        # length = batch_size
+         #     "target": Tensor,        # shape: batch_size * num_classes
+         #     "img": Tensor,           # shape: batch_size * 3 * 224 * 224
+         #     "masked_img": Tensor,    # shape: batch_size * 3 * 224 * 224
+         # }
pipeline/losses.py ADDED
@@ -0,0 +1,56 @@
+ import torch
+ import torch.nn as nn
+
+ """ASL taken from https://github.com/Alibaba-MIIL/ASL"""
+
+ # Usage
+ # global criterion_asl
+ # criterion_asl = AsymmetricLoss(gamma_neg=4, gamma_pos=0, clip=0.05, disable_torch_grad_focal_loss=True)
+ # loss3 = criterion_asl(pred1, pred2)
+
+ class AsymmetricLoss(nn.Module):
+     def __init__(self, gamma_neg=4, gamma_pos=1, clip=0.05, eps=1e-8, disable_torch_grad_focal_loss=True):
+         super(AsymmetricLoss, self).__init__()
+
+         self.gamma_neg = gamma_neg
+         self.gamma_pos = gamma_pos
+         self.clip = clip
+         self.disable_torch_grad_focal_loss = disable_torch_grad_focal_loss
+         self.eps = eps
+
+     def forward(self, x, y):
+         """"
+         Parameters
+         ----------
+         x: input logits
+         y: targets (multi-label binarized vector)
+         """
+
+         # Calculating Probabilities
+         x_sigmoid = torch.sigmoid(x)
+         xs_pos = x_sigmoid
+         xs_neg = 1 - x_sigmoid
+
+         # Asymmetric Clipping
+         if self.clip is not None and self.clip > 0:
+             xs_neg = (xs_neg + self.clip).clamp(max=1)
+
+         # Basic CE calculation
+         los_pos = y * torch.log(xs_pos.clamp(min=self.eps))
+         los_neg = (1 - y) * torch.log(xs_neg.clamp(min=self.eps))
+         loss = los_pos + los_neg
+
+         # Asymmetric Focusing
+         if self.gamma_neg > 0 or self.gamma_pos > 0:
+             if self.disable_torch_grad_focal_loss:
+                 torch.set_grad_enabled(False)
+             pt0 = xs_pos * y
+             pt1 = xs_neg * (1 - y)  # pt = p if t > 0 else 1-p
+             pt = pt0 + pt1
+             one_sided_gamma = self.gamma_pos * y + self.gamma_neg * (1 - y)
+             one_sided_w = torch.pow(1 - pt, one_sided_gamma)
+             if self.disable_torch_grad_focal_loss:
+                 torch.set_grad_enabled(True)
+             loss *= one_sided_w
+
+         return -loss.sum()
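For reference, a minimal sketch of how this loss is typically called on multi-label outputs (random logits and multi-hot targets, for illustration only):

import torch
from pipeline.losses import AsymmetricLoss

criterion = AsymmetricLoss(gamma_neg=4, gamma_pos=0, clip=0.05, disable_torch_grad_focal_loss=True)
logits = torch.randn(8, 20, requires_grad=True)   # raw model outputs (batch, num_classes)
targets = torch.randint(0, 2, (8, 20)).float()    # multi-hot labels
loss = criterion(logits, targets)                 # scalar; note the sum (not mean) reduction
loss.backward()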
pipeline/models/tresnet/layers/anti_aliasing.py ADDED
@@ -0,0 +1,60 @@
+ import torch
+ import torch.nn.parallel
+ import numpy as np
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+
+ class AntiAliasDownsampleLayer(nn.Module):
+     def __init__(self, remove_model_jit: bool = False, filt_size: int = 3, stride: int = 2,
+                  channels: int = 0):
+         super(AntiAliasDownsampleLayer, self).__init__()
+         if not remove_model_jit:
+             self.op = DownsampleJIT(filt_size, stride, channels)
+         else:
+             self.op = Downsample(filt_size, stride, channels)
+
+     def forward(self, x):
+         return self.op(x)
+
+
+ @torch.jit.script
+ class DownsampleJIT(object):
+     def __init__(self, filt_size: int = 3, stride: int = 2, channels: int = 0):
+         self.stride = stride
+         self.filt_size = filt_size
+         self.channels = channels
+
+         assert self.filt_size == 3
+         assert stride == 2
+         a = torch.tensor([1., 2., 1.])
+
+         filt = (a[:, None] * a[None, :]).clone().detach()
+         filt = filt / torch.sum(filt)
+         self.filt = filt[None, None, :, :].repeat((self.channels, 1, 1, 1)).cuda().half()
+
+     def __call__(self, input: torch.Tensor):
+         if input.dtype != self.filt.dtype:
+             self.filt = self.filt.float()
+         input_pad = F.pad(input, (1, 1, 1, 1), 'reflect')
+         return F.conv2d(input_pad, self.filt, stride=2, padding=0, groups=input.shape[1])
+
+
+ class Downsample(nn.Module):
+     def __init__(self, filt_size=3, stride=2, channels=None):
+         super(Downsample, self).__init__()
+         self.filt_size = filt_size
+         self.stride = stride
+         self.channels = channels
+
+
+         assert self.filt_size == 3
+         a = torch.tensor([1., 2., 1.])
+
+         filt = (a[:, None] * a[None, :]).clone().detach()
+         filt = filt / torch.sum(filt)
+         self.filt = filt[None, None, :, :].repeat((self.channels, 1, 1, 1))
+
+     def forward(self, input):
+         input_pad = F.pad(input, (1, 1, 1, 1), 'reflect')
+         return F.conv2d(input_pad, self.filt, stride=self.stride, padding=0, groups=input.shape[1])
pipeline/models/tresnet/layers/avg_pool.py ADDED
@@ -0,0 +1,19 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+
+
+ class FastAvgPool2d(nn.Module):
+     def __init__(self, flatten=False):
+         super(FastAvgPool2d, self).__init__()
+         self.flatten = flatten
+
+     def forward(self, x):
+         if self.flatten:
+             in_size = x.size()
+             return x.view((in_size[0], in_size[1], -1)).mean(dim=2)
+         else:
+             return x.view(x.size(0), x.size(1), -1).mean(-1).view(x.size(0), x.size(1), 1, 1)
+
+
pipeline/models/tresnet/layers/general_layers.py ADDED
@@ -0,0 +1,93 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ from pipeline.models.tresnet.layers.avg_pool import FastAvgPool2d
+
+
+ class Flatten(nn.Module):
+     def forward(self, x):
+         return x.view(x.size(0), -1)
+
+
+ class DepthToSpace(nn.Module):
+
+     def __init__(self, block_size):
+         super().__init__()
+         self.bs = block_size
+
+     def forward(self, x):
+         N, C, H, W = x.size()
+         x = x.view(N, self.bs, self.bs, C // (self.bs ** 2), H, W)  # (N, bs, bs, C//bs^2, H, W)
+         x = x.permute(0, 3, 4, 1, 5, 2).contiguous()  # (N, C//bs^2, H, bs, W, bs)
+         x = x.view(N, C // (self.bs ** 2), H * self.bs, W * self.bs)  # (N, C//bs^2, H * bs, W * bs)
+         return x
+
+
+ class SpaceToDepthModule(nn.Module):
+     def __init__(self, remove_model_jit=False):
+         super().__init__()
+         if not remove_model_jit:
+             self.op = SpaceToDepthJit()
+         else:
+             self.op = SpaceToDepth()
+
+     def forward(self, x):
+         return self.op(x)
+
+
+ class SpaceToDepth(nn.Module):
+     def __init__(self, block_size=4):
+         super().__init__()
+         assert block_size == 4
+         self.bs = block_size
+
+     def forward(self, x):
+         N, C, H, W = x.size()
+         x = x.view(N, C, H // self.bs, self.bs, W // self.bs, self.bs)  # (N, C, H//bs, bs, W//bs, bs)
+         x = x.permute(0, 3, 5, 1, 2, 4).contiguous()  # (N, bs, bs, C, H//bs, W//bs)
+         x = x.view(N, C * (self.bs ** 2), H // self.bs, W // self.bs)  # (N, C*bs^2, H//bs, W//bs)
+         return x
+
+
+ @torch.jit.script
+ class SpaceToDepthJit(object):
+     def __call__(self, x: torch.Tensor):
+         # assuming hard-coded that block_size==4 for acceleration
+         N, C, H, W = x.size()
+         x = x.view(N, C, H // 4, 4, W // 4, 4)  # (N, C, H//bs, bs, W//bs, bs)
+         x = x.permute(0, 3, 5, 1, 2, 4).contiguous()  # (N, bs, bs, C, H//bs, W//bs)
+         x = x.view(N, C * 16, H // 4, W // 4)  # (N, C*bs^2, H//bs, W//bs)
+         return x
+
+
+ class hard_sigmoid(nn.Module):
+     def __init__(self, inplace=True):
+         super(hard_sigmoid, self).__init__()
+         self.inplace = inplace
+
+     def forward(self, x):
+         if self.inplace:
+             return x.add_(3.).clamp_(0., 6.).div_(6.)
+         else:
+             return F.relu6(x + 3.) / 6.
+
+
+ class SEModule(nn.Module):
+
+     def __init__(self, channels, reduction_channels, inplace=True):
+         super(SEModule, self).__init__()
+         self.avg_pool = FastAvgPool2d()
+         self.fc1 = nn.Conv2d(channels, reduction_channels, kernel_size=1, padding=0, bias=True)
+         self.relu = nn.ReLU(inplace=inplace)
+         self.fc2 = nn.Conv2d(reduction_channels, channels, kernel_size=1, padding=0, bias=True)
+         # self.activation = hard_sigmoid(inplace=inplace)
+         self.activation = nn.Sigmoid()
+
+     def forward(self, x):
+         x_se = self.avg_pool(x)
+         x_se2 = self.fc1(x_se)
+         x_se2 = self.relu(x_se2)
+         x_se = self.fc2(x_se2)
+         x_se = self.activation(x_se)
+         return x * x_se
pipeline/models/tresnet/tresnet.py ADDED
@@ -0,0 +1,268 @@
+ import torch
+ import torch.nn as nn
+ from torch.nn import Module as Module
+ from collections import OrderedDict
+ from pipeline.models.tresnet.layers.anti_aliasing import AntiAliasDownsampleLayer
+ from .layers.avg_pool import FastAvgPool2d
+ from .layers.general_layers import SEModule, SpaceToDepthModule
+ from inplace_abn import InPlaceABN, ABN
+ import torch.nn.functional as F
+
+ def InplacABN_to_ABN(module: nn.Module) -> nn.Module:
+     # convert all InplaceABN layer to bit-accurate ABN layers.
+     if isinstance(module, InPlaceABN):
+         module_new = ABN(module.num_features, activation=module.activation,
+                          activation_param=module.activation_param)
+         for key in module.state_dict():
+             module_new.state_dict()[key].copy_(module.state_dict()[key])
+         module_new.training = module.training
+         module_new.weight.data = module_new.weight.abs() + module_new.eps
+         return module_new
+     for name, child in reversed(module._modules.items()):
+         new_child = InplacABN_to_ABN(child)
+         if new_child != child:
+             module._modules[name] = new_child
+     return module
+
+ class bottleneck_head(nn.Module):
+     def __init__(self, num_features, num_classes, bottleneck_features=200):
+         super(bottleneck_head, self).__init__()
+         self.embedding_generator = nn.ModuleList()
+         self.embedding_generator.append(nn.Linear(num_features, bottleneck_features))
+         self.embedding_generator = nn.Sequential(*self.embedding_generator)
+         self.FC = nn.Linear(bottleneck_features, num_classes)
+
+     def forward(self, x):
+         self.embedding = self.embedding_generator(x)
+         logits = self.FC(self.embedding)
+         return logits
+
+
+ def conv2d(ni, nf, stride):
+     return nn.Sequential(
+         nn.Conv2d(ni, nf, kernel_size=3, stride=stride, padding=1, bias=False),
+         nn.BatchNorm2d(nf),
+         nn.ReLU(inplace=True)
+     )
+
+
+ def conv2d_ABN(ni, nf, stride, activation="leaky_relu", kernel_size=3, activation_param=1e-2, groups=1):
+     return nn.Sequential(
+         nn.Conv2d(ni, nf, kernel_size=kernel_size, stride=stride, padding=kernel_size // 2, groups=groups,
+                   bias=False),
+         InPlaceABN(num_features=nf, activation=activation, activation_param=activation_param)
+     )
+
+
+ class BasicBlock(Module):
+     expansion = 1
+
+     def __init__(self, inplanes, planes, stride=1, downsample=None, use_se=True, anti_alias_layer=None):
+         super(BasicBlock, self).__init__()
+         if stride == 1:
+             self.conv1 = conv2d_ABN(inplanes, planes, stride=1, activation_param=1e-3)
+         else:
+             if anti_alias_layer is None:
+                 self.conv1 = conv2d_ABN(inplanes, planes, stride=2, activation_param=1e-3)
+             else:
+                 self.conv1 = nn.Sequential(conv2d_ABN(inplanes, planes, stride=1, activation_param=1e-3),
+                                            anti_alias_layer(channels=planes, filt_size=3, stride=2))
+
+         self.conv2 = conv2d_ABN(planes, planes, stride=1, activation="identity")
+         self.relu = nn.ReLU(inplace=True)
+         self.downsample = downsample
+         self.stride = stride
+         reduce_layer_planes = max(planes * self.expansion // 4, 64)
+         self.se = SEModule(planes * self.expansion, reduce_layer_planes) if use_se else None
+
+     def forward(self, x):
+         if self.downsample is not None:
+             residual = self.downsample(x)
+         else:
+             residual = x
+
+         out = self.conv1(x)
+         out = self.conv2(out)
+
+         if self.se is not None: out = self.se(out)
+
+         out += residual
+
+         out = self.relu(out)
+
+         return out
+
+
+ class Bottleneck(Module):
+     expansion = 4
+
+     def __init__(self, inplanes, planes, stride=1, downsample=None, use_se=True, anti_alias_layer=None):
+         super(Bottleneck, self).__init__()
+         self.conv1 = conv2d_ABN(inplanes, planes, kernel_size=1, stride=1, activation="leaky_relu",
+                                 activation_param=1e-3)
+         if stride == 1:
+             self.conv2 = conv2d_ABN(planes, planes, kernel_size=3, stride=1, activation="leaky_relu",
+                                     activation_param=1e-3)
+         else:
+             if anti_alias_layer is None:
+                 self.conv2 = conv2d_ABN(planes, planes, kernel_size=3, stride=2, activation="leaky_relu",
+                                         activation_param=1e-3)
+             else:
+                 self.conv2 = nn.Sequential(conv2d_ABN(planes, planes, kernel_size=3, stride=1,
+                                                       activation="leaky_relu", activation_param=1e-3),
+                                            anti_alias_layer(channels=planes, filt_size=3, stride=2))
+
+         self.conv3 = conv2d_ABN(planes, planes * self.expansion, kernel_size=1, stride=1,
+                                 activation="identity")
+
+         self.relu = nn.ReLU(inplace=True)
+         self.downsample = downsample
+         self.stride = stride
+
+         reduce_layer_planes = max(planes * self.expansion // 8, 64)
+         self.se = SEModule(planes, reduce_layer_planes) if use_se else None
+
+     def forward(self, x):
+         if self.downsample is not None:
+             residual = self.downsample(x)
+         else:
+             residual = x
+
+         out = self.conv1(x)
+         out = self.conv2(out)
+         if self.se is not None: out = self.se(out)
+
+         out = self.conv3(out)
+         out = out + residual  # no inplace
+         out = self.relu(out)
+
+         return out
+
+
+ class TResNet(Module):
+
+     def __init__(self, layers, in_chans=3, num_classes=1000, width_factor=1.0,
+                  do_bottleneck_head=False, bottleneck_features=512):
+         super(TResNet, self).__init__()
+
+         # Loss function
+         self.loss_func = F.binary_cross_entropy_with_logits
+
+         # JIT layers
+         space_to_depth = SpaceToDepthModule()
+         anti_alias_layer = AntiAliasDownsampleLayer
+         global_pool_layer = FastAvgPool2d(flatten=True)
+
+         # TResnet stages
+         self.inplanes = int(64 * width_factor)
+         self.planes = int(64 * width_factor)
+         conv1 = conv2d_ABN(in_chans * 16, self.planes, stride=1, kernel_size=3)
+         layer1 = self._make_layer(BasicBlock, self.planes, layers[0], stride=1, use_se=True,
+                                   anti_alias_layer=anti_alias_layer)  # 56x56
+         layer2 = self._make_layer(BasicBlock, self.planes * 2, layers[1], stride=2, use_se=True,
+                                   anti_alias_layer=anti_alias_layer)  # 28x28
+         layer3 = self._make_layer(Bottleneck, self.planes * 4, layers[2], stride=2, use_se=True,
+                                   anti_alias_layer=anti_alias_layer)  # 14x14
+         layer4 = self._make_layer(Bottleneck, self.planes * 8, layers[3], stride=2, use_se=False,
+                                   anti_alias_layer=anti_alias_layer)  # 7x7
+
+         # body
+         self.body = nn.Sequential(OrderedDict([
+             ('SpaceToDepth', space_to_depth),
+             ('conv1', conv1),
+             ('layer1', layer1),
+             ('layer2', layer2),
+             ('layer3', layer3),
+             ('layer4', layer4)]))
+
+         # head
+         self.embeddings = []
+         self.global_pool = nn.Sequential(OrderedDict([('global_pool_layer', global_pool_layer)]))
+         self.num_features = (self.planes * 8) * Bottleneck.expansion
+         if do_bottleneck_head:
+             fc = bottleneck_head(self.num_features, num_classes,
+                                  bottleneck_features=bottleneck_features)
+         else:
+             fc = nn.Linear(self.num_features, num_classes)
+
+         self.head = nn.Sequential(OrderedDict([('fc', fc)]))
+
+         # model initilization
+         for m in self.modules():
+             if isinstance(m, nn.Conv2d):
+                 nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='leaky_relu')
+             elif isinstance(m, nn.BatchNorm2d) or isinstance(m, InPlaceABN):
+                 nn.init.constant_(m.weight, 1)
+                 nn.init.constant_(m.bias, 0)
+
+         # residual connections special initialization
+         for m in self.modules():
+             if isinstance(m, BasicBlock):
+                 m.conv2[1].weight = nn.Parameter(torch.zeros_like(m.conv2[1].weight))  # BN to zero
+             if isinstance(m, Bottleneck):
+                 m.conv3[1].weight = nn.Parameter(torch.zeros_like(m.conv3[1].weight))  # BN to zero
+             if isinstance(m, nn.Linear): m.weight.data.normal_(0, 0.01)
+
+     def _make_layer(self, block, planes, blocks, stride=1, use_se=True, anti_alias_layer=None):
+         downsample = None
+         if stride != 1 or self.inplanes != planes * block.expansion:
+             layers = []
+             if stride == 2:
+                 # avg pooling before 1x1 conv
+                 layers.append(nn.AvgPool2d(kernel_size=2, stride=2, ceil_mode=True, count_include_pad=False))
+             layers += [conv2d_ABN(self.inplanes, planes * block.expansion, kernel_size=1, stride=1,
+                                   activation="identity")]
+             downsample = nn.Sequential(*layers)
+
+         layers = []
+         layers.append(block(self.inplanes, planes, stride, downsample, use_se=use_se,
+                             anti_alias_layer=anti_alias_layer))
+         self.inplanes = planes * block.expansion
+         for i in range(1, blocks): layers.append(
+             block(self.inplanes, planes, use_se=use_se, anti_alias_layer=anti_alias_layer))
+         return nn.Sequential(*layers)
+
+     def forward_train(self, x, target):
+         x = self.body(x)
+         self.embeddings = self.global_pool(x)
+         logits = self.head(self.embeddings)
+         loss = self.loss_func(logits, target, reduction="mean")
+         return logits, loss
+
+     def forward_test(self, x):
+         x = self.body(x)
+         self.embeddings = self.global_pool(x)
+         logits = self.head(self.embeddings)
+         return logits
+
+     def forward(self, x, target=None):
+         if target is not None:
+             return self.forward_train(x, target)
+         else:
+             return self.forward_test(x)
+
+
+ def TResnetM(num_classes):
+     """Constructs a medium TResnet model.
+     """
+     in_chans = 3
+     model = TResNet(layers=[3, 4, 11, 3], num_classes=num_classes, in_chans=in_chans)
+     return model
+
+
+ def TResnetL(num_classes):
+     """Constructs a large TResnet model.
+     """
+     in_chans = 3
+     do_bottleneck_head = False
+     model = TResNet(layers=[4, 5, 18, 3], num_classes=num_classes, in_chans=in_chans, width_factor=1.2,
+                     do_bottleneck_head=do_bottleneck_head)
+     return model
+
+
+ def TResnetXL(num_classes):
+     """Constructs a xlarge TResnet model.
+     """
+     in_chans = 3
+     model = TResNet(layers=[4, 5, 24, 3], num_classes=num_classes, in_chans=in_chans, width_factor=1.3)
+     return model
pipeline/models/utils/__init__.py ADDED
@@ -0,0 +1,2 @@
+ from .factory import create_model
+ __all__ = ['create_model']
pipeline/models/utils/factory.py ADDED
@@ -0,0 +1,25 @@
+ import logging
+
+ logger = logging.getLogger(__name__)
+
+ from ..tresnet import TResnetM, TResnetL, TResnetXL
+
+
+ def create_model(args):
+     """Create a model
+     """
+     model_params = {'args': args, 'num_classes': args.num_classes}
+     args = model_params['args']
+     args.model_name = args.model_name.lower()
+
+     if args.model_name == 'tresnet_m':
+         model = TResnetM(model_params)
+     elif args.model_name == 'tresnet_l':
+         model = TResnetL(model_params)
+     elif args.model_name == 'tresnet_xl':
+         model = TResnetXL(model_params)
+     else:
+         print("model: {} not found !!".format(args.model_name))
+         exit(-1)
+
+     return model
pipeline/resnet_csra.py ADDED
@@ -0,0 +1,94 @@
+ from torchvision.models import ResNet
+ from torchvision.models.resnet import Bottleneck, BasicBlock
+ from .csra import CSRA, MHA
+ import torch.utils.model_zoo as model_zoo
+ import logging
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+
+ model_urls = {
+     "resnet18": "https://download.pytorch.org/models/resnet18-5c106cde.pth",
+     "resnet34": "https://download.pytorch.org/models/resnet34-333f7ec4.pth",
+     "resnet50": "https://download.pytorch.org/models/resnet50-19c8e357.pth",
+     "resnet101": "https://download.pytorch.org/models/resnet101-5d3b4d8f.pth",
+     "resnet152": "https://download.pytorch.org/models/resnet152-b121ed2d.pth",
+ }
+
+
+ class ResNet_CSRA(ResNet):
+     arch_settings = {
+         18: (BasicBlock, (2, 2, 2, 2)),
+         34: (BasicBlock, (3, 4, 6, 3)),
+         50: (Bottleneck, (3, 4, 6, 3)),
+         101: (Bottleneck, (3, 4, 23, 3)),
+         152: (Bottleneck, (3, 8, 36, 3)),
+     }
+
+     def __init__(
+         self, num_heads, lam, num_classes, depth=101, input_dim=2048, cutmix=None
+     ):
+         self.block, self.layers = self.arch_settings[depth]
+         self.depth = depth
+         super(ResNet_CSRA, self).__init__(self.block, self.layers)
+         self.init_weights(pretrained=True, cutmix=cutmix)
+
+         self.classifier = MHA(num_heads, lam, input_dim, num_classes)
+         self.loss_func = F.binary_cross_entropy_with_logits
+         # todo
+         # criterion = nn.BCEWithLogitsLoss() # loss combines a Sigmoid layer and the BCELoss in one single class
+
+     def backbone(self, x):
+         x = self.conv1(x)
+         x = self.bn1(x)
+         x = self.relu(x)
+         x = self.maxpool(x)
+
+         x = self.layer1(x)
+         x = self.layer2(x)
+         x = self.layer3(x)
+         x = self.layer4(x)
+
+         return x
+
+     def forward_train(self, x, target):
+         x = self.backbone(x)
+         logit = self.classifier(x)
+         loss = self.loss_func(logit, target, reduction="mean")
+         return logit, loss
+
+     def forward_test(self, x):
+         x = self.backbone(x)
+         x = self.classifier(x)
+         return x
+
+     def forward(self, x, target=None):
+         if target is not None:
+             return self.forward_train(x, target)
+         else:
+             return self.forward_test(x)
+
+     def init_weights(self, pretrained=True, cutmix=None):
+         if cutmix is not None:
+             print("backbone params inited by CutMix pretrained model")
+             state_dict = torch.load(cutmix)
+         elif pretrained:
+             print("backbone params inited by Pytorch official model")
+             model_url = model_urls["resnet{}".format(self.depth)]
+             state_dict = model_zoo.load_url(model_url)
+
+         model_dict = self.state_dict()
+         try:
+             pretrained_dict = {k: v for k, v in state_dict.items() if k in model_dict}
+             self.load_state_dict(pretrained_dict)
+         except:
+             logger = logging.getLogger()
+             logger.info(
+                 "the keys in pretrained model is not equal to the keys in the ResNet you choose, trying to fix..."
+             )
+             state_dict = self._keysFix(model_dict, state_dict)
+             self.load_state_dict(state_dict)
+
+         # remove the original 1000-class fc
+         self.fc = nn.Sequential()
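For reference, a minimal forward-pass sketch for this wrapper, mirroring the settings used in app.py (the input tensor is random and only illustrates shapes):

import torch
from pipeline.resnet_csra import ResNet_CSRA

model = ResNet_CSRA(num_heads=1, lam=0.1, num_classes=20)  # downloads the ImageNet ResNet-101 backbone weights on first use
model.eval()
with torch.no_grad():
    logits = model(torch.randn(1, 3, 448, 448))            # (1, 20) raw logits; apply a sigmoid for per-class scores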
pipeline/timm_utils/__init__.py ADDED
@@ -0,0 +1,4 @@
+ from .tuple import to_ntuple, to_2tuple, to_3tuple, to_4tuple
+ from .drop import DropBlock2d, DropPath, drop_block_2d, drop_path
+ from .weight_init import trunc_normal_
+
pipeline/timm_utils/drop.py ADDED
@@ -0,0 +1,168 @@
+ """ DropBlock, DropPath
+
+ PyTorch implementations of DropBlock and DropPath (Stochastic Depth) regularization layers.
+
+ Papers:
+ DropBlock: A regularization method for convolutional networks (https://arxiv.org/abs/1810.12890)
+
+ Deep Networks with Stochastic Depth (https://arxiv.org/abs/1603.09382)
+
+ Code:
+ DropBlock impl inspired by two Tensorflow impl that I liked:
+  - https://github.com/tensorflow/tpu/blob/master/models/official/resnet/resnet_model.py#L74
+  - https://github.com/clovaai/assembled-cnn/blob/master/nets/blocks.py
+
+ Hacked together by / Copyright 2020 Ross Wightman
+ """
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+
+ def drop_block_2d(
+         x, drop_prob: float = 0.1, block_size: int = 7, gamma_scale: float = 1.0,
+         with_noise: bool = False, inplace: bool = False, batchwise: bool = False):
+     """ DropBlock. See https://arxiv.org/pdf/1810.12890.pdf
+
+     DropBlock with an experimental gaussian noise option. This layer has been tested on a few training
+     runs with success, but needs further validation and possibly optimization for lower runtime impact.
+     """
+     B, C, H, W = x.shape
+     total_size = W * H
+     clipped_block_size = min(block_size, min(W, H))
+     # seed_drop_rate, the gamma parameter
+     gamma = gamma_scale * drop_prob * total_size / clipped_block_size ** 2 / (
+         (W - block_size + 1) * (H - block_size + 1))
+
+     # Forces the block to be inside the feature map.
+     w_i, h_i = torch.meshgrid(torch.arange(W).to(x.device), torch.arange(H).to(x.device))
+     valid_block = ((w_i >= clipped_block_size // 2) & (w_i < W - (clipped_block_size - 1) // 2)) & \
+                   ((h_i >= clipped_block_size // 2) & (h_i < H - (clipped_block_size - 1) // 2))
+     valid_block = torch.reshape(valid_block, (1, 1, H, W)).to(dtype=x.dtype)
+
+     if batchwise:
+         # one mask for whole batch, quite a bit faster
+         uniform_noise = torch.rand((1, C, H, W), dtype=x.dtype, device=x.device)
+     else:
+         uniform_noise = torch.rand_like(x)
+     block_mask = ((2 - gamma - valid_block + uniform_noise) >= 1).to(dtype=x.dtype)
+     block_mask = -F.max_pool2d(
+         -block_mask,
+         kernel_size=clipped_block_size,  # block_size,
+         stride=1,
+         padding=clipped_block_size // 2)
+
+     if with_noise:
+         normal_noise = torch.randn((1, C, H, W), dtype=x.dtype, device=x.device) if batchwise else torch.randn_like(x)
+         if inplace:
+             x.mul_(block_mask).add_(normal_noise * (1 - block_mask))
+         else:
+             x = x * block_mask + normal_noise * (1 - block_mask)
+     else:
+         normalize_scale = (block_mask.numel() / block_mask.to(dtype=torch.float32).sum().add(1e-7)).to(x.dtype)
+         if inplace:
+             x.mul_(block_mask * normalize_scale)
+         else:
+             x = x * block_mask * normalize_scale
+     return x
+
+
+ def drop_block_fast_2d(
+         x: torch.Tensor, drop_prob: float = 0.1, block_size: int = 7,
+         gamma_scale: float = 1.0, with_noise: bool = False, inplace: bool = False, batchwise: bool = False):
+     """ DropBlock. See https://arxiv.org/pdf/1810.12890.pdf
+
+     DropBlock with an experimental gaussian noise option. Simplied from above without concern for valid
+     block mask at edges.
+     """
+     B, C, H, W = x.shape
+     total_size = W * H
+     clipped_block_size = min(block_size, min(W, H))
+     gamma = gamma_scale * drop_prob * total_size / clipped_block_size ** 2 / (
+         (W - block_size + 1) * (H - block_size + 1))
+
+     if batchwise:
+         # one mask for whole batch, quite a bit faster
+         block_mask = torch.rand((1, C, H, W), dtype=x.dtype, device=x.device) < gamma
+     else:
+         # mask per batch element
+         block_mask = torch.rand_like(x) < gamma
+     block_mask = F.max_pool2d(
+         block_mask.to(x.dtype), kernel_size=clipped_block_size, stride=1, padding=clipped_block_size // 2)
+
+     if with_noise:
+         normal_noise = torch.randn((1, C, H, W), dtype=x.dtype, device=x.device) if batchwise else torch.randn_like(x)
+         if inplace:
+             x.mul_(1. - block_mask).add_(normal_noise * block_mask)
+         else:
+             x = x * (1. - block_mask) + normal_noise * block_mask
+     else:
+         block_mask = 1 - block_mask
+         normalize_scale = (block_mask.numel() / block_mask.to(dtype=torch.float32).sum().add(1e-7)).to(dtype=x.dtype)
+         if inplace:
+             x.mul_(block_mask * normalize_scale)
+         else:
+             x = x * block_mask * normalize_scale
+     return x
+
+
+ class DropBlock2d(nn.Module):
+     """ DropBlock. See https://arxiv.org/pdf/1810.12890.pdf
+     """
+     def __init__(self,
+                  drop_prob=0.1,
+                  block_size=7,
+                  gamma_scale=1.0,
+                  with_noise=False,
+                  inplace=False,
+                  batchwise=False,
+                  fast=True):
+         super(DropBlock2d, self).__init__()
+         self.drop_prob = drop_prob
+         self.gamma_scale = gamma_scale
+         self.block_size = block_size
+         self.with_noise = with_noise
+         self.inplace = inplace
+         self.batchwise = batchwise
+         self.fast = fast  # FIXME finish comparisons of fast vs not
+
+     def forward(self, x):
+         if not self.training or not self.drop_prob:
+             return x
+         if self.fast:
+             return drop_block_fast_2d(
+                 x, self.drop_prob, self.block_size, self.gamma_scale, self.with_noise, self.inplace, self.batchwise)
+         else:
+             return drop_block_2d(
+                 x, self.drop_prob, self.block_size, self.gamma_scale, self.with_noise, self.inplace, self.batchwise)
+
+
+ def drop_path(x, drop_prob: float = 0., training: bool = False):
+     """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
+
+     This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
+     the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
+     See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
+     changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
+     'survival rate' as the argument.
+
+     """
+     if drop_prob == 0. or not training:
+         return x
+     keep_prob = 1 - drop_prob
+     shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
+     random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
+     random_tensor.floor_()  # binarize
+     output = x.div(keep_prob) * random_tensor
+     return output
+
+
+ class DropPath(nn.Module):
+     """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
+     """
+     def __init__(self, drop_prob=None):
+         super(DropPath, self).__init__()
+         self.drop_prob = drop_prob
+
+     def forward(self, x):
+         return drop_path(x, self.drop_prob, self.training)
pipeline/timm_utils/tuple.py ADDED
@@ -0,0 +1,27 @@
+ """ Layer/Module Helpers
+
+ Hacked together by / Copyright 2020 Ross Wightman
+ """
+ from itertools import repeat
+ from torch._six import container_abcs
+
+
+ # From PyTorch internals
+ def _ntuple(n):
+     def parse(x):
+         if isinstance(x, container_abcs.Iterable):
+             return x
+         return tuple(repeat(x, n))
+     return parse
+
+
+ to_1tuple = _ntuple(1)
+ to_2tuple = _ntuple(2)
+ to_3tuple = _ntuple(3)
+ to_4tuple = _ntuple(4)
+ to_ntuple = _ntuple
+
+
+
+
+
pipeline/timm_utils/weight_init.py ADDED
@@ -0,0 +1,60 @@
+ import torch
+ import math
+ import warnings
+
+
+ def _no_grad_trunc_normal_(tensor, mean, std, a, b):
+     # Cut & paste from PyTorch official master until it's in a few official releases - RW
+     # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
+     def norm_cdf(x):
+         # Computes standard normal cumulative distribution function
+         return (1. + math.erf(x / math.sqrt(2.))) / 2.
+
+     if (mean < a - 2 * std) or (mean > b + 2 * std):
+         warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
+                       "The distribution of values may be incorrect.",
+                       stacklevel=2)
+
+     with torch.no_grad():
+         # Values are generated by using a truncated uniform distribution and
+         # then using the inverse CDF for the normal distribution.
+         # Get upper and lower cdf values
+         l = norm_cdf((a - mean) / std)
+         u = norm_cdf((b - mean) / std)
+
+         # Uniformly fill tensor with values from [l, u], then translate to
+         # [2l-1, 2u-1].
+         tensor.uniform_(2 * l - 1, 2 * u - 1)
+
+         # Use inverse cdf transform for normal distribution to get truncated
+         # standard normal
+         tensor.erfinv_()
+
+         # Transform to proper mean, std
+         tensor.mul_(std * math.sqrt(2.))
+         tensor.add_(mean)
+
+         # Clamp to ensure it's in the proper range
+         tensor.clamp_(min=a, max=b)
+         return tensor
+
+
+ def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
+     # type: (Tensor, float, float, float, float) -> Tensor
+     r"""Fills the input Tensor with values drawn from a truncated
+     normal distribution. The values are effectively drawn from the
+     normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
+     with values outside :math:`[a, b]` redrawn until they are within
+     the bounds. The method used for generating the random values works
+     best when :math:`a \leq \text{mean} \leq b`.
+     Args:
+         tensor: an n-dimensional `torch.Tensor`
+         mean: the mean of the normal distribution
+         std: the standard deviation of the normal distribution
+         a: the minimum cutoff value
+         b: the maximum cutoff value
+     Examples:
+         >>> w = torch.empty(3, 5)
+         >>> nn.init.trunc_normal_(w)
+     """
+     return _no_grad_trunc_normal_(tensor, mean, std, a, b)
pipeline/vit_csra.py ADDED
@@ -0,0 +1,303 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ Vision Transformer (ViT) in PyTorch
2
+
3
+ A PyTorch implement of Vision Transformers as described in
4
+ 'An Image Is Worth 16 x 16 Words: Transformers for Image Recognition at Scale' - https://arxiv.org/abs/2010.11929
5
+
6
+ The official jax code is released and available at https://github.com/google-research/vision_transformer
7
+
8
+ Status/TODO:
9
+ * Models updated to be compatible with official impl. Args added to support backward compat for old PyTorch weights.
10
+ * Weights ported from official jax impl for 384x384 base and small models, 16x16 and 32x32 patches.
11
+ * Trained (supervised on ImageNet-1k) my custom 'small' patch model to 77.9, 'base' to 79.4 top-1 with this code.
12
+ * Hopefully find time and GPUs for SSL or unsupervised pretraining on OpenImages w/ ImageNet fine-tune in future.
13
+
14
+ Acknowledgments:
15
+ * The paper authors for releasing code and weights, thanks!
16
+ * I fixed my class token impl based on Phil Wang's https://github.com/lucidrains/vit-pytorch ... check it out
17
+ for some einops/einsum fun
18
+ * Simple transformer style inspired by Andrej Karpathy's https://github.com/karpathy/minGPT
19
+ * Bert reference code checks against Huggingface Transformers and Tensorflow Bert
20
+
21
+ Hacked together by / Copyright 2020 Ross Wightman
22
+ """
23
+ import math
24
+ import torch
25
+ import torch.nn as nn
26
+ import torch.nn.functional as F
27
+ import torch.utils.model_zoo as model_zoo
28
+ from functools import partial
29
+ from .timm_utils import DropPath, to_2tuple, trunc_normal_
30
+ from .csra import MHA, CSRA
31
+
32
+
33
+ default_cfgs = {
34
+ 'vit_base_patch16_224': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p16_224-80ecf9dd.pth',
35
+ 'vit_large_patch16_224':'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p16_224-4ee7a4dc.pth'
36
+ }
37
+
38
+
39
+
40
+ class Mlp(nn.Module):
41
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
42
+ super().__init__()
43
+ out_features = out_features or in_features
44
+ hidden_features = hidden_features or in_features
45
+ self.fc1 = nn.Linear(in_features, hidden_features)
46
+ self.act = act_layer()
47
+ self.fc2 = nn.Linear(hidden_features, out_features)
48
+ self.drop = nn.Dropout(drop)
49
+
50
+ def forward(self, x):
51
+ x = self.fc1(x)
52
+ x = self.act(x)
53
+ x = self.drop(x)
54
+ x = self.fc2(x)
55
+ x = self.drop(x)
56
+ return x
57
+
58
+
59
+ class Attention(nn.Module):
60
+ def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
61
+ super().__init__()
62
+ self.num_heads = num_heads
63
+ head_dim = dim // num_heads # 64
64
+ # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
65
+ self.scale = qk_scale or head_dim ** -0.5
66
+
67
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
68
+ self.attn_drop = nn.Dropout(attn_drop)
69
+ self.proj = nn.Linear(dim, dim)
70
+ self.proj_drop = nn.Dropout(proj_drop)
71
+
72
+ def forward(self, x):
73
+ B, N, C = x.shape
74
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
75
+ # qkv (3, B, 12, N, C/12)
76
+ # q (B, 12, N, C/12)
77
+ # k (B, 12, N, C/12)
78
+ # v (B, 12, N, C/12)
79
+ # attn (B, 12, N, N)
80
+ # x (B, 12, N, C/12)
81
+ q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
82
+
83
+ attn = (q @ k.transpose(-2, -1)) * self.scale
84
+ attn = attn.softmax(dim=-1)
85
+ attn = self.attn_drop(attn)
86
+
87
+ x = (attn @ v).transpose(1, 2).reshape(B, N, C)
88
+
89
+ x = self.proj(x)
90
+ x = self.proj_drop(x)
91
+
92
+ return x
93
+
94
+
95
+ class Block(nn.Module):
96
+
97
+ def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
98
+ drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
99
+ super().__init__()
100
+ self.norm1 = norm_layer(dim)
101
+ self.attn = Attention(
102
+ dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
103
+ # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
104
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
105
+ self.norm2 = norm_layer(dim)
106
+ mlp_hidden_dim = int(dim * mlp_ratio)
107
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
108
+
109
+ def forward(self, x):
110
+ x = x + self.drop_path(self.attn(self.norm1(x)))
111
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
112
+ return x
113
+
114
+
115
+ class PatchEmbed(nn.Module):
116
+ """ Image to Patch Embedding
117
+ """
118
+ def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
119
+ super().__init__()
120
+ img_size = to_2tuple(img_size)
121
+ patch_size = to_2tuple(patch_size)
122
+ num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
123
+ self.img_size = img_size
124
+ self.patch_size = patch_size
125
+ self.num_patches = num_patches
126
+
127
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
128
+
129
+ def forward(self, x):
130
+ B, C, H, W = x.shape
131
+ # FIXME look at relaxing size constraints
132
+ assert H == self.img_size[0] and W == self.img_size[1], \
133
+ f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
134
+ x = self.proj(x).flatten(2).transpose(1, 2)
135
+ return x
136
+
137
+
138
+ class HybridEmbed(nn.Module):
139
+ """ CNN Feature Map Embedding
140
+ Extract feature map from CNN, flatten, project to embedding dim.
141
+ """
142
+ def __init__(self, backbone, img_size=224, feature_size=None, in_chans=3, embed_dim=768):
143
+ super().__init__()
144
+ assert isinstance(backbone, nn.Module)
+ img_size = to_2tuple(img_size)
+ self.img_size = img_size
+ self.backbone = backbone
+ if feature_size is None:
+ with torch.no_grad():
+ # FIXME this is hacky, but most reliable way of determining the exact dim of the output feature
+ # map for all networks, the feature metadata has reliable channel and stride info, but using
+ # stride to calc feature dim requires info about padding of each stage that isn't captured.
+ training = backbone.training
+ if training:
+ backbone.eval()
+ o = self.backbone(torch.zeros(1, in_chans, img_size[0], img_size[1]))[-1]
+ feature_size = o.shape[-2:]
+ feature_dim = o.shape[1]
+ backbone.train(training)
+ else:
+ feature_size = to_2tuple(feature_size)
+ feature_dim = self.backbone.feature_info.channels()[-1]
+ self.num_patches = feature_size[0] * feature_size[1]
+ self.proj = nn.Linear(feature_dim, embed_dim)
+
+ def forward(self, x):
+ x = self.backbone(x)[-1]
+ x = x.flatten(2).transpose(1, 2)
+ x = self.proj(x)
+ return x
+
+
+ class VIT_CSRA(nn.Module):
+ """ Vision Transformer with support for patch or hybrid CNN input stage
+ """
+ def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
+ num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
+ drop_path_rate=0., hybrid_backbone=None, norm_layer=nn.LayerNorm, cls_num_heads=1, cls_num_cls=80, lam=0.3):
+ super().__init__()
+ self.add_w = 0.
+ self.normalize = False
+ self.num_classes = num_classes
+ self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
+
+ if hybrid_backbone is not None:
+ self.patch_embed = HybridEmbed(
+ hybrid_backbone, img_size=img_size, in_chans=in_chans, embed_dim=embed_dim)
+ else:
+ self.patch_embed = PatchEmbed(
+ img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
+ num_patches = self.patch_embed.num_patches
+ self.HW = int(math.sqrt(num_patches))
+
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
+ self.pos_drop = nn.Dropout(p=drop_rate)
+
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
+ self.blocks = nn.ModuleList([
+ Block(
+ dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
+ drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer)
+ for i in range(depth)])
+ self.norm = norm_layer(embed_dim)
+
+ # NOTE as per official impl, we could have a pre-logits representation dense layer + tanh here
+ #self.repr = nn.Linear(embed_dim, representation_size)
+ #self.repr_act = nn.Tanh()
+
+ trunc_normal_(self.pos_embed, std=.02)
+ trunc_normal_(self.cls_token, std=.02)
+ self.apply(self._init_weights)
+
+ # We add our MHA (CSRA) beside the original ViT structure below
+ self.head = nn.Sequential() # delete original classifier
+ self.classifier = MHA(input_dim=embed_dim, num_heads=cls_num_heads, num_classes=cls_num_cls, lam=lam)
+
+ self.loss_func = F.binary_cross_entropy_with_logits
+
+ def _init_weights(self, m):
+ if isinstance(m, nn.Linear):
+ trunc_normal_(m.weight, std=.02)
+ if isinstance(m, nn.Linear) and m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+ elif isinstance(m, nn.LayerNorm):
+ nn.init.constant_(m.bias, 0)
+ nn.init.constant_(m.weight, 1.0)
+
+ def backbone(self, x):
+ B = x.shape[0]
+ x = self.patch_embed(x)
+
+ cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks
+ x = torch.cat((cls_tokens, x), dim=1)
+ x = x + self.pos_embed
+ x = self.pos_drop(x)
+
+ for blk in self.blocks:
+ x = blk(x)
+ x = self.norm(x)
+
+ # (B, 1+HW, C)
+ # we use all the patch tokens (without the cls token) to form a tensor of shape (B, C, H, W)
+ x = x[:, 1:]
+ b, hw, c = x.shape
+ x = x.transpose(1, 2)
+ x = x.reshape(b, c, self.HW, self.HW)
+
+ return x
+
+ def forward_train(self, x, target):
+ x = self.backbone(x)
+ logit = self.classifier(x)
+ loss = self.loss_func(logit, target, reduction="mean")
+ return logit, loss
+
+ def forward_test(self, x):
+ x = self.backbone(x)
+ x = self.classifier(x)
+ return x
+
+ def forward(self, x, target=None):
+ if target is not None:
+ return self.forward_train(x, target)
+ else:
+ return self.forward_test(x)
+
+
+ def _conv_filter(state_dict, patch_size=16):
+ """ convert patch embedding weight from manual patchify + linear proj to conv"""
+ out_dict = {}
+ for k, v in state_dict.items():
+ if 'patch_embed.proj.weight' in k:
+ v = v.reshape((v.shape[0], 3, patch_size, patch_size))
+ out_dict[k] = v
+ return out_dict
+
+
+ def VIT_B16_224_CSRA(pretrained=True, cls_num_heads=1, cls_num_cls=80, lam=0.3):
+ model = VIT_CSRA(
+ patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
+ norm_layer=partial(nn.LayerNorm, eps=1e-6), cls_num_heads=cls_num_heads, cls_num_cls=cls_num_cls, lam=lam)
+
+ model_url = default_cfgs['vit_base_patch16_224']
+ if pretrained:
+ state_dict = model_zoo.load_url(model_url)
+ model.load_state_dict(state_dict, strict=False)
+ return model
+
+
+ def VIT_L16_224_CSRA(pretrained=True, cls_num_heads=1, cls_num_cls=80, lam=0.3):
+ model = VIT_CSRA(
+ patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True,
+ norm_layer=partial(nn.LayerNorm, eps=1e-6), cls_num_heads=cls_num_heads, cls_num_cls=cls_num_cls, lam=lam)
+
+ model_url = default_cfgs['vit_large_patch16_224']
+ if pretrained:
+ state_dict = model_zoo.load_url(model_url)
+ model.load_state_dict(state_dict, strict=False)
+ # load_pretrained(model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3))
+ return model
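For reference, a minimal shape-check of the CSRA ViT head added above. This sketch is not part of the commit and assumes the repository root is on PYTHONPATH; pretrained=False skips the ImageNet weight download.

    import torch
    from pipeline.vit_csra import VIT_B16_224_CSRA

    # ViT-B/16 backbone with a 1-head CSRA classifier for 20 classes (e.g. VOC07)
    model = VIT_B16_224_CSRA(pretrained=False, cls_num_heads=1, cls_num_cls=20, lam=0.3)
    model.eval()
    with torch.no_grad():
        logits = model(torch.zeros(2, 3, 224, 224))  # no target given, so forward_test() runs
    print(logits.shape)  # should be torch.Size([2, 20])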
requirements.txt ADDED
@@ -0,0 +1,3 @@
+ torch
+ torchvision
+ Pillow
utils/demo_images/000001.jpg ADDED
utils/demo_images/000002.jpg ADDED
utils/demo_images/000004.jpg ADDED
utils/demo_images/000006.jpg ADDED
utils/demo_images/000007.jpg ADDED
utils/demo_images/000009.jpg ADDED
utils/evaluation/cal_PR.py ADDED
@@ -0,0 +1,79 @@
+ import json
+ import numpy as np
+
+
+
+ def json_metric(score_json, target_json, num_classes, types):
+ assert len(score_json) == len(target_json)
+ scores = np.zeros((len(score_json), num_classes))
+ targets = np.zeros((len(target_json), num_classes))
+ for index in range(len(score_json)):
+ scores[index] = score_json[index]["scores"]
+ targets[index] = target_json[index]["target"]
+
+
+ return metric(scores, targets, types)
+
+ def json_metric_top3(score_json, target_json, num_classes, types):
+ assert len(score_json) == len(target_json)
+ scores = np.zeros((len(score_json), num_classes))
+ targets = np.zeros((len(target_json), num_classes))
+ for index in range(len(score_json)):
+ tmp = np.array(score_json[index]['scores'])
+ idx = np.argsort(-tmp)
+ idx_after_3 = idx[3:]
+ tmp[idx_after_3] = 0.
+
+ scores[index] = tmp
+ # scores[index] = score_json[index]["scores"]
+ targets[index] = target_json[index]["target"]
+
+ return metric(scores, targets, types)
+
+
+ def metric(scores, targets, types):
+ """
+ :param scores: the scores predicted by the model
+ :param targets: the ground-truth labels
+ :return: OP, OR, OF1, CP, CR, CF1
+ per-class Precision is TP / (TP + FP), i.e. TP / total predicted positives
+ per-class Recall is TP / total ground-truth positives
+ """
+ num, num_class = scores.shape
+ gt_num = np.zeros(num_class)
+ tp_num = np.zeros(num_class)
+ predict_num = np.zeros(num_class)
+
+
+ for index in range(num_class):
+ score = scores[:, index]
+ target = targets[:, index]
+ if types == 'wider':
+ tmp = np.where(target == 99)[0]
+ # score[tmp] = 0
+ target[tmp] = 0
+
+ if types == 'voc07':
+ tmp = np.where(target != 0)[0]
+ score = score[tmp]
+ target = target[tmp]
+ neg_id = np.where(target == -1)[0]
+ target[neg_id] = 0
+
+
+ gt_num[index] = np.sum(target == 1)
+ predict_num[index] = np.sum(score >= 0.5)
+ tp_num[index] = np.sum(target * (score >= 0.5))
+
+ predict_num[predict_num == 0] = 1 # avoid dividing by 0
+ OP = np.sum(tp_num) / np.sum(predict_num)
+ OR = np.sum(tp_num) / np.sum(gt_num)
+ OF1 = (2 * OP * OR) / (OP + OR)
+
+ #print(tp_num / predict_num)
+ #print(tp_num / gt_num)
+ CP = np.sum(tp_num / predict_num) / num_class
+ CR = np.sum(tp_num / gt_num) / num_class
+ CF1 = (2 * CP * CR) / (CP + CR)
+
+ return OP, OR, OF1, CP, CR, CF1
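A small sanity check of metric() on toy data (illustrative only, not part of the commit; with a `types` value other than 'wider' or 'voc07' no label filtering is applied, and the repo root is assumed to be on PYTHONPATH):

    import numpy as np
    from utils.evaluation.cal_PR import metric

    scores = np.array([[0.9, 0.2], [0.6, 0.7], [0.1, 0.8]])          # 3 images, 2 classes
    targets = np.array([[1, 0], [1, 1], [0, 1]], dtype=np.float64)
    OP, OR, OF1, CP, CR, CF1 = metric(scores, targets, types="coco")
    print(OP, OR, OF1, CP, CR, CF1)  # every thresholded prediction is correct here, so all six are 1.0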
utils/evaluation/cal_mAP.py ADDED
@@ -0,0 +1,56 @@
+ import os
+ import numpy as np
+ import torch
+ import json
+
+
+ def json_map(cls_id, pred_json, ann_json, types):
+ assert len(ann_json) == len(pred_json)
+ num = len(ann_json)
+ predict = np.zeros((num), dtype=np.float64)
+ target = np.zeros((num), dtype=np.float64)
+
+ for i in range(num):
+ predict[i] = pred_json[i]["scores"][cls_id]
+ target[i] = ann_json[i]["target"][cls_id]
+
+ if types == 'wider':
+ tmp = np.where(target != 99)[0]
+ predict = predict[tmp]
+ target = target[tmp]
+ num = len(tmp)
+
+ if types == 'voc07':
+ tmp = np.where(target != 0)[0]
+ predict = predict[tmp]
+ target = target[tmp]
+ neg_id = np.where(target == -1)[0]
+ target[neg_id] = 0
+ num = len(tmp)
+
+
+ tmp = np.argsort(-predict)
+ target = target[tmp]
+ predict = predict[tmp]
+
+
+ pre, obj = 0, 0
+ for i in range(num):
+ if target[i] == 1:
+ obj += 1.0
+ pre += obj / (i+1)
+ pre /= obj
+ return pre
+
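And a toy average-precision check for json_map() (again illustrative, not part of the commit):

    from utils.evaluation.cal_mAP import json_map

    pred_json = [{"scores": [0.9]}, {"scores": [0.3]}, {"scores": [0.7]}]
    ann_json = [{"target": [1]}, {"target": [0]}, {"target": [1]}]
    ap = json_map(0, pred_json, ann_json, types="coco")
    print(ap)  # both positives are ranked above the negative, so AP = 1.0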
utils/evaluation/eval.py ADDED
@@ -0,0 +1,64 @@
+ import argparse
+ import torch
+ import numpy as np
+ import json
+ from tqdm import tqdm
+ from .cal_mAP import json_map
+ from .cal_PR import json_metric, metric, json_metric_top3
+
+
+ voc_classes = ("aeroplane", "bicycle", "bird", "boat", "bottle",
+ "bus", "car", "cat", "chair", "cow", "diningtable",
+ "dog", "horse", "motorbike", "person", "pottedplant",
+ "sheep", "sofa", "train", "tvmonitor")
+ coco_classes = ('person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
+ 'train', 'truck', 'boat', 'traffic_light', 'fire_hydrant',
+ 'stop_sign', 'parking_meter', 'bench', 'bird', 'cat', 'dog',
+ 'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe',
+ 'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee',
+ 'skis', 'snowboard', 'sports_ball', 'kite', 'baseball_bat',
+ 'baseball_glove', 'skateboard', 'surfboard', 'tennis_racket',
+ 'bottle', 'wine_glass', 'cup', 'fork', 'knife', 'spoon', 'bowl',
+ 'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot',
+ 'hot_dog', 'pizza', 'donut', 'cake', 'chair', 'couch',
+ 'potted_plant', 'bed', 'dining_table', 'toilet', 'tv', 'laptop',
+ 'mouse', 'remote', 'keyboard', 'cell_phone', 'microwave',
+ 'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock',
+ 'vase', 'scissors', 'teddy_bear', 'hair_drier', 'toothbrush')
+ wider_classes = (
+ "Male","longHair","sunglass","Hat","Tshiirt","longSleeve","formal",
+ "shorts","jeans","longPants","skirt","faceMask", "logo","stripe")
+
+ class_dict = {
+ "voc07": voc_classes,
+ "coco": coco_classes,
+ "wider": wider_classes,
+ }
+
+
+
+ def evaluation(result, types, ann_path):
+ print("Evaluation")
+ classes = class_dict[types]
+ aps = np.zeros(len(classes), dtype=np.float64)
+
+ ann_json = json.load(open(ann_path, "r"))
+ pred_json = result
+
+ for i, _ in enumerate(tqdm(classes)):
+ ap = json_map(i, pred_json, ann_json, types)
+ aps[i] = ap
+ OP, OR, OF1, CP, CR, CF1 = json_metric(pred_json, ann_json, len(classes), types)
+ print("mAP: {:.4f}".format(np.mean(aps)))
+ print("CP: {:.4f}, CR: {:.4f}, CF1: {:.4f}".format(CP, CR, CF1))
+ print("OP: {:.4f}, OR: {:.4f}, OF1: {:.4f}".format(OP, OR, OF1))
+
+ # I added it here
+ class WarmUpLR(torch.optim.lr_scheduler._LRScheduler):
+ def __init__(self, optimizer, total_iters, last_epoch=-1):
+ self.total_iters = total_iters
+ super().__init__(optimizer, last_epoch=last_epoch)
+
+ def get_lr(self):
+ return [base_lr * self.last_epoch / (self.total_iters + 1e-8) for base_lr in self.base_lrs]
+
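For context, evaluation() expects `result` to be a list aligned with the annotation json, one {"scores": [...]} entry per image. A hedged sketch of that format (not part of the commit; the annotation path below is the one produced by the prepare_* scripts):

    import torch

    logits = torch.rand(4, 20)                     # pretend VOC07 logits for 4 images
    scores = torch.sigmoid(logits)
    result = [{"scores": s.tolist()} for s in scores]
    # evaluation(result, types="voc07", ann_path="data/voc07/test_voc07.json") would then
    # print mAP plus the OP/OR/OF1 and CP/CR/CF1 numbers computed by cal_PR.metric().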
utils/evaluation/warmUpLR.py ADDED
@@ -0,0 +1,11 @@
+ import torch
+
+
+ class WarmUpLR(torch.optim.lr_scheduler._LRScheduler):
+ def __init__(self, optimizer, total_iters, last_epoch=-1):
+ self.total_iters = total_iters
+ super().__init__(optimizer, last_epoch=last_epoch)
+
+ def get_lr(self):
+ return [base_lr * self.last_epoch / (self.total_iters + 1e-8) for base_lr in self.base_lrs]
+
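A short usage sketch for the warm-up scheduler above (illustrative, not in the commit): the learning rate ramps roughly linearly from 0 to the optimizer's base lr over total_iters calls to step().

    import torch
    from utils.evaluation.warmUpLR import WarmUpLR

    net = torch.nn.Linear(10, 2)
    optimizer = torch.optim.SGD(net.parameters(), lr=0.1)
    warmup = WarmUpLR(optimizer, total_iters=100)
    for step in range(100):
        optimizer.step()          # normally preceded by a forward/backward pass
        warmup.step()             # lr after this call is 0.1 * (step + 1) / 100
    print(optimizer.param_groups[0]["lr"])  # ~0.1 once warm-up is done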
utils/prepare/prepare_coco.py ADDED
@@ -0,0 +1,67 @@
+ import os
+ import json
+ import argparse
+ import numpy as np
+ from pycocotools.coco import COCO
+
+
+
+ def make_data(data_path=None, tag="train"):
+ annFile = os.path.join(data_path, "annotations/instances_{}2014.json".format(tag))
+ coco = COCO(annFile)
+
+ img_id = coco.getImgIds()
+ cat_id = coco.getCatIds()
+ img_id = list(sorted(img_id))
+ cat_trans = {}
+ for i in range(len(cat_id)):
+ cat_trans[cat_id[i]] = i
+
+ message = []
+
+
+ for i in img_id:
+ data = {}
+ target = [0] * 80
+ path = ""
+ img_info = coco.loadImgs(i)[0]
+ ann_ids = coco.getAnnIds(imgIds = i)
+ anns = coco.loadAnns(ann_ids)
+ if len(anns) == 0:
+ continue
+ else:
+ for k in range(len(anns)):
+ cls = anns[k]['category_id']
+ cls = cat_trans[cls]
+ target[cls] = 1
+ path = img_info['file_name']
+ data['target'] = target
+ data['img_path'] = os.path.join(os.path.join(data_path, "images/{}2014/".format(tag)), path)
+ message.append(data)
+
+ with open('data/coco/{}_coco2014.json'.format(tag), 'w') as f:
+ json.dump(message, f)
+
+
+
+ # The final json files are train_coco2014.json & val_coco2014.json,
+ # each with the following format:
+ # [item1, item2, item3, ......,]
+ # item1 = {
+ # "target":
+ # "img_path":
+ # }
+ if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ # Usage: --data_path /your/dataset/path/COCO2014
+ parser.add_argument("--data_path", default="Dataset/COCO2014/", type=str, help="The absolute path of COCO2014")
+ args = parser.parse_args()
+
+ if not os.path.exists("data/coco"):
+ os.makedirs("data/coco")
+
+ make_data(data_path=args.data_path, tag="train")
+ make_data(data_path=args.data_path, tag="val")
+
+ print("COCO data ready!")
+ print("data/coco/train_coco2014.json, data/coco/val_coco2014.json")
utils/prepare/prepare_voc.py ADDED
@@ -0,0 +1,149 @@
+ import os
+ import json
+ import argparse
+ import numpy as np
+ import xml.dom.minidom as XML
+
+
+
+ voc_cls_id = {"aeroplane":0, "bicycle":1, "bird":2, "boat":3, "bottle":4,
+ "bus":5, "car":6, "cat":7, "chair":8, "cow":9,
+ "diningtable":10, "dog":11, "horse":12, "motorbike":13, "person":14,
+ "pottedplant":15, "sheep":16, "sofa":17, "train":18, "tvmonitor":19}
+
+
+ def get_label(data_path):
+ print("generating labels for VOC07 dataset")
+ xml_paths = os.path.join(data_path, "VOC2007/Annotations/")
+ save_dir = "data/voc07/labels"
+
+ if not os.path.exists(save_dir):
+ os.makedirs(save_dir)
+
+ for i in os.listdir(xml_paths):
+ if not i.endswith(".xml"):
+ continue
+ s_name = i.split('.')[0] + ".txt"
+ s_dir = os.path.join(save_dir, s_name)
+ xml_path = os.path.join(xml_paths, i)
+ DomTree = XML.parse(xml_path)
+ Root = DomTree.documentElement
+
+ obj_all = Root.getElementsByTagName("object")
+ leng = len(obj_all)
+ cls = []
+ difi_tag = []
+ for obj in obj_all:
+ # get the classes
+ obj_name = obj.getElementsByTagName('name')[0]
+ one_class = obj_name.childNodes[0].data
+ cls.append(voc_cls_id[one_class])
+
+ difficult = obj.getElementsByTagName('difficult')[0]
+ difi_tag.append(difficult.childNodes[0].data)
+
+ for i, c in enumerate(cls):
+ with open(s_dir, "a") as f:
+ f.writelines("%s,%s\n" % (c, difi_tag[i]))
+
+
+ def transdifi(data_path):
+ print("generating final json file for VOC07 dataset")
+ label_dir = "data/voc07/labels/"
+ img_dir = os.path.join(data_path, "VOC2007/JPEGImages/")
+
+ # get trainval test id
+ id_dirs = os.path.join(data_path, "VOC2007/ImageSets/Main/")
+ f_train = open(os.path.join(id_dirs, "train.txt"), "r").readlines()
+ f_val = open(os.path.join(id_dirs, "val.txt"), "r").readlines()
+ f_trainval = f_train + f_val
+ f_test = open(os.path.join(id_dirs, "test.txt"), "r")
+
+ trainval_id = np.sort([int(line.strip()) for line in f_trainval]).tolist()
+ test_id = [int(line.strip()) for line in f_test]
+ trainval_data = []
+ test_data = []
+
+ # ternary label
+ # -1 means negative
+ # 0 means difficult
+ # +1 means positive
+
+ # binary label
+ # 0 means negative
+ # +1 means positive
+
+ # we use binary labels in our implementation
+
+ for item in sorted(os.listdir(label_dir)):
+ with open(os.path.join(label_dir, item), "r") as f:
+
+ target = np.array([-1] * 20)
+ classes = []
+ diffi_tag = []
+
+ for line in f.readlines():
+ cls, tag = map(int, line.strip().split(','))
+ classes.append(cls)
+ diffi_tag.append(tag)
+
+ classes = np.array(classes)
+ diffi_tag = np.array(diffi_tag)
+ for i in range(20):
+ if i in classes:
+ i_index = np.where(classes == i)[0]
+ if len(i_index) == 1:
+ target[i] = 1 - diffi_tag[i_index]
+ else:
+ if len(i_index) == sum(diffi_tag[i_index]):
+ target[i] = 0
+ else:
+ target[i] = 1
+ else:
+ continue
+ img_path = os.path.join(img_dir, item.split('.')[0]+".jpg")
+
+ if int(item.split('.')[0]) in trainval_id:
+ target[target == -1] = 0 # from ternary to binary by treating difficult as negatives
+ data = {"target": target.tolist(), "img_path": img_path}
+ trainval_data.append(data)
+ if int(item.split('.')[0]) in test_id:
+ data = {"target": target.tolist(), "img_path": img_path}
+ test_data.append(data)
+
+ json.dump(trainval_data, open("data/voc07/trainval_voc07.json", "w"))
+ json.dump(test_data, open("data/voc07/test_voc07.json", "w"))
+ print("VOC07 data preparation finished!")
+ print("data/voc07/trainval_voc07.json data/voc07/test_voc07.json")
+
+ # remove label cache
+ for item in os.listdir(label_dir):
+ os.remove(os.path.join(label_dir, item))
+ os.rmdir(label_dir)
+
+
+ # We treat difficult classes in trainval_data as negative, while ignoring them in test_data.
+ # The ignoring operation can be automatically done during evaluation (testing).
+ # The final json files are trainval_voc07.json & test_voc07.json,
+ # each with the following format:
+ # [item1, item2, item3, ......,]
+ # item1 = {
+ # "target":
+ # "img_path":
+ # }
+
+ if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ # Usage: --data_path /your/dataset/path/VOCdevkit
+ parser.add_argument("--data_path", default="Dataset/VOCdevkit/", type=str, help="The absolute path of VOCdevkit")
+ args = parser.parse_args()
+
+ if not os.path.exists("data/voc07"):
+ os.makedirs("data/voc07")
+
+ if 'VOCdevkit' not in args.data_path:
+ print("WARNING: please include \'VOCdevkit\' str in your args.data_path")
+ # exit()
+
+ get_label(args.data_path)
+ transdifi(args.data_path)
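To make the per-class rule in transdifi() concrete, a toy illustration (not part of the commit): a class with a single box gets 1 - difficult; with several boxes it becomes 0 only when all of them are difficult, otherwise 1; classes that never appear stay -1 and are later mapped to 0 for trainval images.

    import numpy as np

    classes = np.array([11, 11, 14])   # two "dog" boxes and one "person" box
    diffi_tag = np.array([1, 1, 0])    # both dog boxes are marked difficult
    target = np.array([-1] * 20)
    for i in range(20):
        if i in classes:
            i_index = np.where(classes == i)[0]
            if len(i_index) == 1:
                target[i] = 1 - diffi_tag[i_index[0]]
            elif len(i_index) == sum(diffi_tag[i_index]):
                target[i] = 0
            else:
                target[i] = 1
    print(target[11], target[14])      # 0 1 -> "dog" counts as difficult, "person" is positive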
utils/prepare/prepare_wider.py ADDED
@@ -0,0 +1,58 @@
+ import os
+ import json
+ import random
+ import argparse
+
+
+ def make_wider(tag, value, data_path):
+ img_path = os.path.join(data_path, "Image")
+ ann_path = os.path.join(data_path, "Annotations")
+ ann_file = os.path.join(ann_path, "wider_attribute_{}.json".format(tag))
+
+ data = json.load(open(ann_file, "r"))
+
+ final = []
+ image_list = data['images']
+ for image in image_list:
+ for person in image["targets"]: # iterate over each person
+ tmp = {}
+ tmp['img_path'] = os.path.join(img_path, image['file_name'])
+ tmp['bbox'] = person['bbox']
+ attr = person["attribute"]
+ for i, item in enumerate(attr):
+ if item == -1:
+ attr[i] = 0
+ if item == 0:
+ attr[i] = value # pad un-specified samples
+ if item == 1:
+ attr[i] = 1
+ tmp["target"] = attr
+ final.append(tmp)
+
+ json.dump(final, open("data/wider/{}_wider.json".format(tag), "w"))
+ print("data/wider/{}_wider.json".format(tag))
+
+
+
+ # The final json files are trainval_wider.json & test_wider.json, with the following format:
+ # [item1, item2, item3, ......,]
+ # item1 = {
+ # "target":
+ # "img_path":
+ # }
+
+
+ if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--data_path", default="Dataset/WIDER_ATTRIBUTE", type=str)
+ args = parser.parse_args()
+
+ if not os.path.exists("data/wider"):
+ os.makedirs("data/wider")
+
+ # 0 (zero) means negative, we treat un-specified attributes as negative in the trainval set
+ make_wider(tag='trainval', value=0, data_path=args.data_path)
+
+ # 99 means we ignore un-specified attributes in the test set, following previous work
+ # the number 99 can be properly identified when evaluating mAP
+ make_wider(tag='test', value=99, data_path=args.data_path)
utils/visualize.py ADDED
@@ -0,0 +1,42 @@
+ from PIL import Image
+ import json
+ import torch
+ from torchvision import transforms
+ import cv2
+ import numpy as np
+ import os
+ import torch.nn as nn
+
+ def show_cam_on_img(img, mask, img_path_save):
+ heat_map = cv2.applyColorMap(np.uint8(255*mask), cv2.COLORMAP_JET)
+ heat_map = np.float32(heat_map) / 255
+
+ cam = heat_map + np.float32(img)
+ cam = cam / np.max(cam)
+ cv2.imwrite(img_path_save, np.uint8(255 * cam))
+
+
+ img_path_read = ""
+ img_path_save = ""
+
+
+
+
+ def main():
+ img = cv2.imread(img_path_read, flags=1)
+
+ img = np.float32(cv2.resize(img, (224, 224))) / 255
+
+ # cam_all is the score tensor of shape (B, C, H, W), similar to y_raw in our Figure 1
+ # (it must be produced by the model beforehand; it is not defined in this file)
+ # cls_idx specifies the i-th class out of the C classes
+ # visualize the heatmap of class 0
+ cls_idx = 0
+ cam = cam_all[cls_idx]
+
+
+ # cam = nn.ReLU()(cam)
+ cam = cam / torch.max(cam)
+
+ cam = cv2.resize(np.array(cam), (224, 224))
+ show_cam_on_img(img, cam, img_path_save)
+
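Finally, a minimal sketch of using show_cam_on_img() with a synthetic score map (not part of the commit; assumes it is run from the repository root so the demo image added in this commit is available):

    import cv2
    import numpy as np
    from utils.visualize import show_cam_on_img

    img = cv2.imread("utils/demo_images/000001.jpg", flags=1)
    img = np.float32(cv2.resize(img, (224, 224))) / 255
    mask = np.zeros((224, 224), dtype=np.float32)
    mask[64:160, 64:160] = 1.0                     # pretend the model attends to the image centre
    show_cam_on_img(img, mask, "cam_overlay.jpg")  # writes the heatmap overlay to disk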