ryefoxlime committed on
Commit
18a9dce
1 Parent(s): de8e1d2

Upload 26 files


TAD Bot Face Detection algorithm
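This commit wires a real-time webcam loop (detectfaces.py) on top of the POSTER V2 training/evaluation script (main.py). A rough invocation sketch follows; the script names and argparse flags come from the files below, while the dataset layout and the pretrained weights (raf-db-model_best.pth, models/pretrain/ir50.pth, models/pretrain/mobilefacenet_model_best.pth.tar) are assumed to already be present locally:

    # live webcam emotion detection using the bundled best checkpoint
    python detectfaces.py

    # train POSTER V2 on RAF-DB (see main.py for the full flag list)
    python main.py --data raf-db/DATASET --epochs 200 --batch-size 2

    # evaluate a saved checkpoint on the test split
    python main.py --evaluate ./checkpoint/<timestamp>model_best.pth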

detectfaces.py ADDED
@@ -0,0 +1,102 @@
+ from main import *  # wildcard import supplies torch, os, PIL's Image and pyramid_trans_expr2
+ import cv2
+ import time
+
+ model_path = "raf-db-model_best.pth"
+
+ if torch.backends.mps.is_available():
+     device = "mps"
+ elif torch.cuda.is_available():
+     device = "cuda"
+ else:
+     device = "cpu"
+
+ model = pyramid_trans_expr2(img_size=224, num_classes=7)
+
+ model = torch.nn.DataParallel(model)
+ model = model.to(device)
+ currtime = time.strftime("%H:%M:%S")
+ print(currtime)
+
+
+ def main():
+     if model_path is not None:
+         if os.path.isfile(model_path):
+             print("=> loading checkpoint '{}'".format(model_path))
+             checkpoint = torch.load(model_path, map_location=device)
+             best_acc = checkpoint["best_acc"]
+             best_acc = best_acc.to()  # .to() with no arguments is a no-op; kept for parity with main.py
+             print(f"best_acc:{best_acc}")
+             model.load_state_dict(checkpoint["state_dict"])
+             print(
+                 "=> loaded checkpoint '{}' (epoch {})".format(
+                     model_path, checkpoint["epoch"]
+                 )
+             )
+         else:
+             print("=> no checkpoint found at '{}'".format(model_path))
+     imagecapture(model)
+     return
+
+
+ def imagecapture(model):
+     from prediction import predict  # predict() runs the expression classifier on a single PIL image
+
+     cap = cv2.VideoCapture(0)
+     if not cap.isOpened():
+         print("Error: Could not open webcam.")
+         exit()
+
+     face_cascade = cv2.CascadeClassifier(
+         cv2.data.haarcascades + "haarcascade_frontalface_default.xml"
+     )
+
+     start_time = None  # unused; retained from the original capture/timer logic
+     capturing = False
+
+     while True:
+         ret, frame = cap.read()
+
+         if not ret:
+             print("Error: Could not read frame.")
+             break
+
+         gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
+
+         faces = face_cascade.detectMultiScale(
+             gray, scaleFactor=1.3, minNeighbors=5, minSize=(30, 30)
+         )
+
+         # Display the frame
+         cv2.imshow("Webcam", frame)
+
+         # If a face is detected, crop it and run the expression prediction
+         if len(faces) > 0:
+             currtimeimg = time.strftime("%H:%M:%S")  # timestamp of this detection
+             print(f"[!]Face detected at {currtimeimg}")
+             face_region = frame[
+                 faces[0][1] : faces[0][1] + faces[0][3],
+                 faces[0][0] : faces[0][0] + faces[0][2],
+             ]  # Crop the face region
+             face_pil_image = Image.fromarray(
+                 cv2.cvtColor(face_region, cv2.COLOR_BGR2RGB)
+             )  # Convert to PIL image
+             print("[!]Start Expressions")
+             print(f"-->Prediction starting at {currtimeimg}")
+             predictions = predict(model, image_path=face_pil_image)
+             print(f"-->Done prediction at {currtimeimg}")
+
+             # Reset capturing
+             capturing = False
+
+         # Break the loop if the 'q' key is pressed
+         if cv2.waitKey(1) & 0xFF == ord("q"):
+             break
+
+     # Release the webcam and close the OpenCV window
+     cap.release()
+     cv2.destroyAllWindows()
+
+
+ if __name__ == "__main__":
+     main()
face_detection.py ADDED
@@ -0,0 +1,23 @@
+ from deepface import DeepFace
+ import matplotlib.pyplot as plt
+ from PIL import Image
+ import numpy as np
+ import time
+
+
+ def face_detection(img_path):
+     # Despite the name, `img_path` is expected to be a PIL image: it is passed
+     # through np.array() for detection and cropped with Image.crop() below.
+     currtime = time.strftime("%H:%M:%S")
+     face_objs = DeepFace.extract_faces(
+         np.array(img_path), detector_backend="mtcnn", enforce_detection=False
+     )
+
+     # Crop the first detected face out of the original image
+     coordinates = face_objs[0]["facial_area"]
+     image = img_path
+     cropped_image = image.crop(
+         (
+             coordinates["x"],
+             coordinates["y"],
+             coordinates["x"] + coordinates["w"],
+             coordinates["y"] + coordinates["h"],
+         )
+     )
+     cropped_image.save(f"Images/test_{currtime}.jpg")
+     return cropped_image
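A minimal usage sketch for face_detection(), assuming a local photo sample.jpg and an existing Images/ output directory (both are illustrative, not part of this commit):

    from PIL import Image
    from face_detection import face_detection

    img = Image.open("sample.jpg")   # hypothetical input photo
    cropped = face_detection(img)    # MTCNN detection via DeepFace, crop of the first face
    cropped.show()                   # a copy is also written to Images/test_<HH:MM:SS>.jpg

Note that the saved filename embeds an HH:MM:SS timestamp containing colons, so the write requires an Images/ directory and will fail on Windows filesystems.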
main.py ADDED
@@ -0,0 +1,602 @@
1
+ import shutil
2
+ import warnings
3
+ from sklearn import metrics
4
+ from sklearn.metrics import confusion_matrix
5
+ from PIL import Image
6
+
7
+ warnings.filterwarnings("ignore")
8
+ import torch.utils.data as data
9
+ import os
10
+ import argparse
11
+ from sklearn.metrics import f1_score, confusion_matrix
12
+ from data_preprocessing.sam import SAM
13
+ import torch.nn.parallel
14
+ import torch.backends.cudnn as cudnn
15
+ import torch.optim
16
+ import torch.utils.data
17
+ import torch.utils.data.distributed
18
+ import matplotlib.pyplot as plt
19
+ import torchvision.datasets as datasets
20
+ import torchvision.transforms as transforms
21
+ import numpy as np
22
+ import datetime
23
+ from torchsampler import ImbalancedDatasetSampler
24
+ from models.PosterV2_7cls import *
25
+
26
+
27
+ warnings.filterwarnings("ignore", category=UserWarning)
28
+
29
+ now = datetime.datetime.now()
30
+ time_str = now.strftime("[%m-%d]-[%H-%M]-")
31
+ if torch.backends.mps.is_available():
32
+ device = "mps"
33
+ elif torch.cuda.is_available():
34
+ device = "cuda"
35
+ else:
36
+ device = "cpu"
37
+
38
+ print(f"Using device: {device}")
39
+
40
+ parser = argparse.ArgumentParser()
41
+ parser.add_argument("--data", type=str, default=r"raf-db/DATASET")
42
+ parser.add_argument(
43
+ "--data_type",
44
+ default="RAF-DB",
45
+ choices=["RAF-DB", "AffectNet-7", "CAER-S"],
46
+ type=str,
47
+ help="dataset option",
48
+ )
49
+ parser.add_argument(
50
+ "--checkpoint_path", type=str, default="./checkpoint/" + time_str + "model.pth"
51
+ )
52
+ parser.add_argument(
53
+ "--best_checkpoint_path",
54
+ type=str,
55
+ default="./checkpoint/" + time_str + "model_best.pth",
56
+ )
57
+ parser.add_argument(
58
+ "-j",
59
+ "--workers",
60
+ default=4,
61
+ type=int,
62
+ metavar="N",
63
+ help="number of data loading workers",
64
+ )
65
+ parser.add_argument(
66
+ "--epochs", default=200, type=int, metavar="N", help="number of total epochs to run"
67
+ )
68
+ parser.add_argument(
69
+ "--start-epoch",
70
+ default=0,
71
+ type=int,
72
+ metavar="N",
73
+ help="manual epoch number (useful on restarts)",
74
+ )
75
+ parser.add_argument("-b", "--batch-size", default=2, type=int, metavar="N")
76
+ parser.add_argument(
77
+ "--optimizer", type=str, default="adam", help="Optimizer, adam or sgd."
78
+ )
79
+
80
+ parser.add_argument(
81
+ "--lr", "--learning-rate", default=0.000035, type=float, metavar="LR", dest="lr"
82
+ )
83
+ parser.add_argument("--momentum", default=0.9, type=float, metavar="M")
84
+ parser.add_argument(
85
+ "--wd", "--weight-decay", default=1e-4, type=float, metavar="W", dest="weight_decay"
86
+ )
87
+ parser.add_argument(
88
+ "-p", "--print-freq", default=30, type=int, metavar="N", help="print frequency"
89
+ )
90
+ parser.add_argument(
91
+ "--resume", default=None, type=str, metavar="PATH", help="path to checkpoint"
92
+ )
93
+ parser.add_argument(
94
+ "-e", "--evaluate", default=None, type=str, help="evaluate model on test set"
95
+ )
96
+ parser.add_argument("--beta", type=float, default=0.6)
97
+ parser.add_argument("--gpu", type=str, default="0")
98
+
99
+ parser.add_argument(
100
+ "-i", "--image", type=str, help="upload a single image to test the prediction"
101
+ )
102
+ parser.add_argument("-t", "--test", type=str, help="test model on single image")
103
+ args = parser.parse_args()
104
+
105
+
106
+ def main():
107
+ # os.environ["CUDA_VISIBLE_DEVICES"] = device
108
+ best_acc = 0
109
+ # print("Training time: " + now.strftime("%m-%d %H:%M"))
110
+
111
+ # create model
112
+ model = pyramid_trans_expr2(img_size=224, num_classes=7)
113
+
114
+ model = torch.nn.DataParallel(model)
115
+ model = model.to(device)
116
+
117
+ criterion = torch.nn.CrossEntropyLoss()
118
+
119
+ if args.optimizer == "adamw":
120
+ base_optimizer = torch.optim.AdamW
121
+ elif args.optimizer == "adam":
122
+ base_optimizer = torch.optim.Adam
123
+ elif args.optimizer == "sgd":
124
+ base_optimizer = torch.optim.SGD
125
+ else:
126
+ raise ValueError("Optimizer not supported.")
127
+
128
+ optimizer = SAM(
129
+ model.parameters(),
130
+ base_optimizer,
131
+ lr=args.lr,
132
+ rho=0.05,
133
+ adaptive=False,
134
+ )
135
+ scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.98)
136
+ recorder = RecorderMeter(args.epochs)
137
+ recorder1 = RecorderMeter1(args.epochs)
138
+
139
+ if args.resume:
140
+ if os.path.isfile(args.resume):
141
+ print("=> loading checkpoint '{}'".format(args.resume))
142
+ checkpoint = torch.load(args.resume)
143
+ args.start_epoch = checkpoint["epoch"]
144
+ best_acc = checkpoint["best_acc"]
145
+ recorder = checkpoint["recorder"]
146
+ recorder1 = checkpoint["recorder1"]
147
+ best_acc = best_acc.to()
148
+ model.load_state_dict(checkpoint["state_dict"])
149
+ optimizer.load_state_dict(checkpoint["optimizer"])
150
+ print(
151
+ "=> loaded checkpoint '{}' (epoch {})".format(
152
+ args.resume, checkpoint["epoch"]
153
+ )
154
+ )
155
+ else:
156
+ print("=> no checkpoint found at '{}'".format(args.resume))
157
+ cudnn.benchmark = True
158
+
159
+ # Data loading code
160
+ traindir = os.path.join(args.data, "train")
161
+
162
+ valdir = os.path.join(args.data, "test")
163
+
164
+ if args.evaluate is None:
165
+ if args.data_type == "RAF-DB":
166
+ train_dataset = datasets.ImageFolder(
167
+ traindir,
168
+ transforms.Compose(
169
+ [
170
+ transforms.Resize((224, 224)),
171
+ transforms.RandomHorizontalFlip(),
172
+ transforms.ToTensor(),
173
+ transforms.Normalize(
174
+ mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
175
+ ),
176
+ transforms.RandomErasing(scale=(0.02, 0.1)),
177
+ ]
178
+ ),
179
+ )
180
+ else:
181
+ train_dataset = datasets.ImageFolder(
182
+ traindir,
183
+ transforms.Compose(
184
+ [
185
+ transforms.Resize((224, 224)),
186
+ transforms.RandomHorizontalFlip(),
187
+ transforms.ToTensor(),
188
+ transforms.Normalize(
189
+ mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
190
+ ),
191
+ transforms.RandomErasing(p=1, scale=(0.05, 0.05)),
192
+ ]
193
+ ),
194
+ )
195
+
196
+ if args.data_type == "AffectNet-7":
197
+ train_loader = torch.utils.data.DataLoader(
198
+ train_dataset,
199
+ sampler=ImbalancedDatasetSampler(train_dataset),
200
+ batch_size=args.batch_size,
201
+ shuffle=False,
202
+ num_workers=args.workers,
203
+ pin_memory=True,
204
+ )
205
+
206
+ else:
207
+ train_loader = torch.utils.data.DataLoader(
208
+ train_dataset,
209
+ batch_size=args.batch_size,
210
+ shuffle=True,
211
+ num_workers=args.workers,
212
+ pin_memory=True,
213
+ )
214
+
215
+ test_dataset = datasets.ImageFolder(
216
+ valdir,
217
+ transforms.Compose(
218
+ [
219
+ transforms.Resize((224, 224)),
220
+ transforms.ToTensor(),
221
+ transforms.Normalize(
222
+ mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
223
+ ),
224
+ ]
225
+ ),
226
+ )
227
+
228
+ val_loader = torch.utils.data.DataLoader(
229
+ test_dataset,
230
+ batch_size=args.batch_size,
231
+ shuffle=False,
232
+ num_workers=args.workers,
233
+ pin_memory=True,
234
+ )
235
+
236
+ if args.evaluate is not None:
237
+ from validation import validate
238
+
239
+ if os.path.isfile(args.evaluate):
240
+ print("=> loading checkpoint '{}'".format(args.evaluate))
241
+ checkpoint = torch.load(args.evaluate, map_location=device)
242
+ best_acc = checkpoint["best_acc"]
243
+ best_acc = best_acc.to()
244
+ print(f"best_acc:{best_acc}")
245
+ model.load_state_dict(checkpoint["state_dict"])
246
+ print(
247
+ "=> loaded checkpoint '{}' (epoch {})".format(
248
+ args.evaluate, checkpoint["epoch"]
249
+ )
250
+ )
251
+ else:
252
+ print("=> no checkpoint found at '{}'".format(args.evaluate))
253
+ validate(val_loader, model, criterion, args)
254
+ return
255
+
256
+ if args.test is not None:
257
+ from prediction import predict
258
+
259
+ if os.path.isfile(args.test):
260
+ print("=> loading checkpoint '{}'".format(args.test))
261
+ checkpoint = torch.load(args.test, map_location=device)
262
+ best_acc = checkpoint["best_acc"]
263
+ best_acc = best_acc.to()
264
+ print(f"best_acc:{best_acc}")
265
+ model.load_state_dict(checkpoint["state_dict"])
266
+ print(
267
+ "=> loaded checkpoint '{}' (epoch {})".format(
268
+ args.test, checkpoint["epoch"]
269
+ )
270
+ )
271
+ else:
272
+ print("=> no checkpoint found at '{}'".format(args.test))
273
+ predict(model, image_path=args.image)
274
+
275
+ return
276
+ from validation import validate  # also needed for the per-epoch validation below
+
+ matrix = None
277
+
278
+ for epoch in range(args.start_epoch, args.epochs):
279
+ current_learning_rate = optimizer.state_dict()["param_groups"][0]["lr"]
280
+ print("Current learning rate: ", current_learning_rate)
281
+ txt_name = "./log/" + time_str + "log.txt"
282
+ with open(txt_name, "a") as f:
283
+ f.write("Current learning rate: " + str(current_learning_rate) + "\n")
284
+
285
+ # train for one epoch
286
+ train_acc, train_los = train(
287
+ train_loader, model, criterion, optimizer, epoch, args
288
+ )
289
+
290
+ # evaluate on validation set
291
+ val_acc, val_los, output, target, D = validate(
292
+ val_loader, model, criterion, args
293
+ )
294
+
295
+ scheduler.step()
296
+
297
+ recorder.update(epoch, train_los, train_acc, val_los, val_acc)
298
+ recorder1.update(output, target)
299
+
300
+ curve_name = time_str + "cnn.png"
301
+ recorder.plot_curve(os.path.join("./log/", curve_name))
302
+
303
+ # remember best acc and save checkpoint
304
+ is_best = val_acc > best_acc
305
+ best_acc = max(val_acc, best_acc)
306
+
307
+ print("Current best accuracy: ", best_acc.item())
308
+
309
+ if is_best:
310
+ matrix = D
311
+
312
+ print("Current best matrix: ", matrix)
313
+
314
+ txt_name = "./log/" + time_str + "log.txt"
315
+ with open(txt_name, "a") as f:
316
+ f.write("Current best accuracy: " + str(best_acc.item()) + "\n")
317
+
318
+ save_checkpoint(
319
+ {
320
+ "epoch": epoch + 1,
321
+ "state_dict": model.state_dict(),
322
+ "best_acc": best_acc,
323
+ "optimizer": optimizer.state_dict(),
324
+ "recorder1": recorder1,
325
+ "recorder": recorder,
326
+ },
327
+ is_best,
328
+ args,
329
+ )
330
+
331
+
332
+ def train(train_loader, model, criterion, optimizer, epoch, args):
333
+ losses = AverageMeter("Loss", ":.4f")
334
+ top1 = AverageMeter("Accuracy", ":6.3f")
335
+ progress = ProgressMeter(
336
+ len(train_loader), [losses, top1], prefix="Epoch: [{}]".format(epoch)
337
+ )
338
+
339
+ # switch to train mode
340
+ model.train()
341
+
342
+ for i, (images, target) in enumerate(train_loader):
343
+ images = images.to(device)
344
+ target = target.to(device)
345
+
346
+ # compute output
347
+ output = model(images)
348
+ loss = criterion(output, target)
349
+
350
+ # measure accuracy and record loss
351
+ acc1, _ = accuracy(output, target, topk=(1, 5))
352
+ losses.update(loss.item(), images.size(0))
353
+ top1.update(acc1[0], images.size(0))
354
+
355
+ # compute gradient and do SGD step
356
+ optimizer.zero_grad()
357
+ loss.backward()
358
+ # optimizer.step()
359
+ optimizer.first_step(zero_grad=True)
360
+ images = images.to(device)
361
+ target = target.to(device)
362
+
363
+ # compute output
364
+ output = model(images)
365
+ loss = criterion(output, target)
366
+
367
+ # measure accuracy and record loss
368
+ acc1, _ = accuracy(output, target, topk=(1, 5))
369
+ losses.update(loss.item(), images.size(0))
370
+ top1.update(acc1[0], images.size(0))
371
+
372
+ # compute gradient and do SGD step
373
+ optimizer.zero_grad()
374
+ loss.backward()
375
+ optimizer.second_step(zero_grad=True)
376
+
377
+ # print loss and accuracy
378
+ if i % args.print_freq == 0:
379
+ progress.display(i)
380
+
381
+ return top1.avg, losses.avg
382
+
383
+
384
+ def save_checkpoint(state, is_best, args):
385
+ torch.save(state, args.checkpoint_path)
386
+ if is_best:
387
+ state.pop("optimizer")  # drop the optimizer state from the copy saved as the best checkpoint
388
+ torch.save(state, args.best_checkpoint_path)
389
+
390
+
391
+ class AverageMeter(object):
392
+ """Computes and stores the average and current value"""
393
+
394
+ def __init__(self, name, fmt=":f"):
395
+ self.name = name
396
+ self.fmt = fmt
397
+ self.reset()
398
+
399
+ def reset(self):
400
+ self.val = 0
401
+ self.avg = 0
402
+ self.sum = 0
403
+ self.count = 0
404
+
405
+ def update(self, val, n=1):
406
+ self.val = val
407
+ self.sum += val * n
408
+ self.count += n
409
+ self.avg = self.sum / self.count
410
+
411
+ def __str__(self):
412
+ fmtstr = "{name} {val" + self.fmt + "} ({avg" + self.fmt + "})"
413
+ return fmtstr.format(**self.__dict__)
414
+
415
+
416
+ class ProgressMeter(object):
417
+ def __init__(self, num_batches, meters, prefix=""):
418
+ self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
419
+ self.meters = meters
420
+ self.prefix = prefix
421
+
422
+ def display(self, batch):
423
+ entries = [self.prefix + self.batch_fmtstr.format(batch)]
424
+ entries += [str(meter) for meter in self.meters]
425
+ print_txt = "\t".join(entries)
426
+ print(print_txt)
427
+ txt_name = "./log/" + time_str + "log.txt"
428
+ with open(txt_name, "a") as f:
429
+ f.write(print_txt + "\n")
430
+
431
+ def _get_batch_fmtstr(self, num_batches):
432
+ num_digits = len(str(num_batches // 1))
433
+ fmt = "{:" + str(num_digits) + "d}"
434
+ return "[" + fmt + "/" + fmt.format(num_batches) + "]"
435
+
436
+
437
+ def accuracy(output, target, topk=(1,)):
438
+ """Computes the accuracy over the k top predictions for the specified values of k"""
439
+ with torch.no_grad():
440
+ maxk = max(topk)
441
+ batch_size = target.size(0)
442
+ _, pred = output.topk(maxk, 1, True, True)
443
+ pred = pred.t()
444
+ correct = pred.eq(target.view(1, -1).expand_as(pred))
445
+ res = []
446
+ for k in topk:
447
+ correct_k = correct[:k].contiguous().view(-1).float().sum(0, keepdim=True)
448
+ res.append(correct_k.mul_(100.0 / batch_size))
449
+ return res
450
+
451
+
452
+ labels = ["A", "B", "C", "F", "G", "H", "I", "J", "K", "L", "M", "N", "O"]
453
+
454
+
455
+ class RecorderMeter1(object):
456
+ """Computes and stores the minimum loss value and its epoch index"""
457
+
458
+ def __init__(self, total_epoch):
459
+ self.reset(total_epoch)
460
+
461
+ def reset(self, total_epoch):
462
+ self.total_epoch = total_epoch
463
+ self.current_epoch = 0
464
+ self.epoch_losses = np.zeros(
465
+ (self.total_epoch, 2), dtype=np.float32
466
+ ) # [epoch, train/val]
467
+ self.epoch_accuracy = np.zeros(
468
+ (self.total_epoch, 2), dtype=np.float32
469
+ ) # [epoch, train/val]
470
+
471
+ def update(self, output, target):
472
+ self.y_pred = output
473
+ self.y_true = target
474
+
475
+ def plot_confusion_matrix(self, cm, title="Confusion Matrix", cmap=plt.cm.binary):
476
+ plt.imshow(cm, interpolation="nearest", cmap=cmap)
477
+ y_true = self.y_true
478
+ y_pred = self.y_pred
479
+
480
+ plt.title(title)
481
+ plt.colorbar()
482
+ xlocations = np.array(range(len(labels)))
483
+ plt.xticks(xlocations, labels, rotation=90)
484
+ plt.yticks(xlocations, labels)
485
+ plt.ylabel("True label")
486
+ plt.xlabel("Predicted label")
487
+
488
+ cm = confusion_matrix(y_true, y_pred)
489
+ np.set_printoptions(precision=2)
490
+ cm_normalized = cm.astype("float") / cm.sum(axis=1)[:, np.newaxis]
491
+ plt.figure(figsize=(12, 8), dpi=120)
492
+
493
+ ind_array = np.arange(len(labels))
494
+ x, y = np.meshgrid(ind_array, ind_array)
495
+ for x_val, y_val in zip(x.flatten(), y.flatten()):
496
+ c = cm_normalized[y_val][x_val]
497
+ if c > 0.01:
498
+ plt.text(
499
+ x_val,
500
+ y_val,
501
+ "%0.2f" % (c,),
502
+ color="red",
503
+ fontsize=7,
504
+ va="center",
505
+ ha="center",
506
+ )
507
+ # offset the tick
508
+ tick_marks = np.arange(len(labels))
509
+ plt.gca().set_xticks(tick_marks, minor=True)
510
+ plt.gca().set_yticks(tick_marks, minor=True)
511
+ plt.gca().xaxis.set_ticks_position("none")
512
+ plt.gca().yaxis.set_ticks_position("none")
513
+ plt.grid(True, which="minor", linestyle="-")
514
+ plt.gcf().subplots_adjust(bottom=0.15)
515
+
516
+ plot_confusion_matrix(cm_normalized, title="Normalized confusion matrix")
517
+ # show confusion matrix
518
+ plt.savefig("./log/confusion_matrix.png", format="png")
519
+ # fig.savefig(save_path, dpi=dpi, bbox_inches='tight')
520
+ print("Saved figure")
521
+ plt.show()
522
+
523
+ def matrix(self):
524
+ target = self.y_true
525
+ output = self.y_pred
526
+ im_re_label = np.array(target)
527
+ im_pre_label = np.array(output)
528
+ y_true = im_re_label.flatten()
529
+ # im_re_label.transpose()
530
+ y_pred = im_pre_label.flatten()
531
+ im_pre_label.transpose()
532
+
533
+
534
+ class RecorderMeter(object):
535
+ """Computes and stores the minimum loss value and its epoch index"""
536
+
537
+ def __init__(self, total_epoch):
538
+ self.reset(total_epoch)
539
+
540
+ def reset(self, total_epoch):
541
+ self.total_epoch = total_epoch
542
+ self.current_epoch = 0
543
+ self.epoch_losses = np.zeros(
544
+ (self.total_epoch, 2), dtype=np.float32
545
+ ) # [epoch, train/val]
546
+ self.epoch_accuracy = np.zeros(
547
+ (self.total_epoch, 2), dtype=np.float32
548
+ ) # [epoch, train/val]
549
+
550
+ def update(self, idx, train_loss, train_acc, val_loss, val_acc):
551
+ self.epoch_losses[idx, 0] = train_loss * 30
552
+ self.epoch_losses[idx, 1] = val_loss * 30
553
+ self.epoch_accuracy[idx, 0] = train_acc
554
+ self.epoch_accuracy[idx, 1] = val_acc
555
+ self.current_epoch = idx + 1
556
+
557
+ def plot_curve(self, save_path):
558
+ title = "the accuracy/loss curve of train/val"
559
+ dpi = 80
560
+ width, height = 1800, 800
561
+ legend_fontsize = 10
562
+ figsize = width / float(dpi), height / float(dpi)
563
+
564
+ fig = plt.figure(figsize=figsize)
565
+ x_axis = np.array([i for i in range(self.total_epoch)]) # epochs
566
+ y_axis = np.zeros(self.total_epoch)
567
+
568
+ plt.xlim(0, self.total_epoch)
569
+ plt.ylim(0, 100)
570
+ interval_y = 5
571
+ interval_x = 5
572
+ plt.xticks(np.arange(0, self.total_epoch + interval_x, interval_x))
573
+ plt.yticks(np.arange(0, 100 + interval_y, interval_y))
574
+ plt.grid()
575
+ plt.title(title, fontsize=20)
576
+ plt.xlabel("the training epoch", fontsize=16)
577
+ plt.ylabel("accuracy", fontsize=16)
578
+
579
+ y_axis[:] = self.epoch_accuracy[:, 0]
580
+ plt.plot(x_axis, y_axis, color="g", linestyle="-", label="train-accuracy", lw=2)
581
+ plt.legend(loc=4, fontsize=legend_fontsize)
582
+
583
+ y_axis[:] = self.epoch_accuracy[:, 1]
584
+ plt.plot(x_axis, y_axis, color="y", linestyle="-", label="valid-accuracy", lw=2)
585
+ plt.legend(loc=4, fontsize=legend_fontsize)
586
+
587
+ y_axis[:] = self.epoch_losses[:, 0]
588
+ plt.plot(x_axis, y_axis, color="g", linestyle=":", label="train-loss-x30", lw=2)
589
+ plt.legend(loc=4, fontsize=legend_fontsize)
590
+
591
+ y_axis[:] = self.epoch_losses[:, 1]
592
+ plt.plot(x_axis, y_axis, color="y", linestyle=":", label="valid-loss-x30", lw=2)
593
+ plt.legend(loc=4, fontsize=legend_fontsize)
594
+
595
+ if save_path is not None:
596
+ fig.savefig(save_path, dpi=dpi, bbox_inches="tight")
597
+ print("Saved figure")
598
+ plt.close(fig)
599
+
600
+
601
+ if __name__ == "__main__":
602
+ main()
models/.DS_Store ADDED
Binary file (8.2 kB).
 
models/PosterV2_7cls.py ADDED
@@ -0,0 +1,441 @@
1
+ import os
2
+ import torch
3
+ import torch.nn as nn
4
+ from torch.nn import functional as F
5
+ from .mobilefacenet import MobileFaceNet
6
+ from .ir50 import Backbone
7
+ from .vit_model import VisionTransformer, PatchEmbed
8
+ from timm.models.layers import trunc_normal_, DropPath
9
+ from thop import profile
10
+
11
+
12
+ def load_pretrained_weights(model, checkpoint):
13
+ import collections
14
+
15
+ if "state_dict" in checkpoint:
16
+ state_dict = checkpoint["state_dict"]
17
+ else:
18
+ state_dict = checkpoint
19
+ model_dict = model.state_dict()
20
+ new_state_dict = collections.OrderedDict()
21
+ matched_layers, discarded_layers = [], []
22
+ for k, v in state_dict.items():
23
+ # If the pretrained state_dict was saved as nn.DataParallel,
24
+ # keys would contain "module.", which should be ignored.
25
+ if k.startswith("module."):
26
+ k = k[7:]
27
+ if k in model_dict and model_dict[k].size() == v.size():
28
+ new_state_dict[k] = v
29
+ matched_layers.append(k)
30
+ else:
31
+ discarded_layers.append(k)
32
+ # new_state_dict.requires_grad = False
33
+ model_dict.update(new_state_dict)
34
+
35
+ model.load_state_dict(model_dict)
36
+ print("load_weight", len(matched_layers))
37
+ return model
38
+
39
+
40
+ def window_partition(x, window_size, h_w, w_w):
41
+ """
42
+ Args:
43
+ x: (B, H, W, C)
44
+ window_size: window size
45
+
46
+ Returns:
47
+ local window features (num_windows*B, window_size, window_size, C)
48
+ """
49
+ B, H, W, C = x.shape
50
+ x = x.view(B, h_w, window_size, w_w, window_size, C)
51
+ windows = (
52
+ x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
53
+ )
54
+ return windows
55
+
56
+
57
+ class window(nn.Module):
58
+ def __init__(self, window_size, dim):
59
+ super(window, self).__init__()
60
+ self.window_size = window_size
61
+ self.norm = nn.LayerNorm(dim)
62
+
63
+ def forward(self, x):
64
+ x = x.permute(0, 2, 3, 1)
65
+ B, H, W, C = x.shape
66
+ x = self.norm(x)
67
+ shortcut = x
68
+ h_w = int(torch.div(H, self.window_size).item())
69
+ w_w = int(torch.div(W, self.window_size).item())
70
+ x_windows = window_partition(x, self.window_size, h_w, w_w)
71
+ x_windows = x_windows.view(-1, self.window_size * self.window_size, C)
72
+ return x_windows, shortcut
73
+
74
+
75
+ class WindowAttentionGlobal(nn.Module):
76
+ """
77
+ Global window attention based on: "Hatamizadeh et al.,
78
+ Global Context Vision Transformers <https://arxiv.org/abs/2206.09959>"
79
+ """
80
+
81
+ def __init__(
82
+ self,
83
+ dim,
84
+ num_heads,
85
+ window_size,
86
+ qkv_bias=True,
87
+ qk_scale=None,
88
+ attn_drop=0.0,
89
+ proj_drop=0.0,
90
+ ):
91
+ """
92
+ Args:
93
+ dim: feature size dimension.
94
+ num_heads: number of attention head.
95
+ window_size: window size.
96
+ qkv_bias: bool argument for query, key, value learnable bias.
97
+ qk_scale: bool argument to scaling query, key.
98
+ attn_drop: attention dropout rate.
99
+ proj_drop: output dropout rate.
100
+ """
101
+
102
+ super().__init__()
103
+ window_size = (window_size, window_size)
104
+ self.window_size = window_size
105
+ self.num_heads = num_heads
106
+ head_dim = torch.div(dim, num_heads)
107
+ self.scale = qk_scale or head_dim**-0.5
108
+ self.relative_position_bias_table = nn.Parameter(
109
+ torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)
110
+ )
111
+ coords_h = torch.arange(self.window_size[0])
112
+ coords_w = torch.arange(self.window_size[1])
113
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w]))
114
+ coords_flatten = torch.flatten(coords, 1)
115
+ relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
116
+ relative_coords = relative_coords.permute(1, 2, 0).contiguous()
117
+ relative_coords[:, :, 0] += self.window_size[0] - 1
118
+ relative_coords[:, :, 1] += self.window_size[1] - 1
119
+ relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
120
+ relative_position_index = relative_coords.sum(-1)
121
+ self.register_buffer("relative_position_index", relative_position_index)
122
+ self.qkv = nn.Linear(dim, dim * 2, bias=qkv_bias)
123
+ self.attn_drop = nn.Dropout(attn_drop)
124
+ self.proj = nn.Linear(dim, dim)
125
+ self.proj_drop = nn.Dropout(proj_drop)
126
+ trunc_normal_(self.relative_position_bias_table, std=0.02)
127
+ self.softmax = nn.Softmax(dim=-1)
128
+
129
+ def forward(self, x, q_global):
130
+ # print(f'q_global.shape:{q_global.shape}')
131
+ # print(f'x.shape:{x.shape}')
132
+ B_, N, C = x.shape
133
+ B = q_global.shape[0]
134
+ head_dim = int(torch.div(C, self.num_heads).item())
135
+ B_dim = int(torch.div(B_, B).item())
136
+ kv = (
137
+ self.qkv(x)
138
+ .reshape(B_, N, 2, self.num_heads, head_dim)
139
+ .permute(2, 0, 3, 1, 4)
140
+ )
141
+ k, v = kv[0], kv[1]
142
+ q_global = q_global.repeat(1, B_dim, 1, 1, 1)
143
+ q = q_global.reshape(B_, self.num_heads, N, head_dim)
144
+ q = q * self.scale
145
+ attn = q @ k.transpose(-2, -1)
146
+ relative_position_bias = self.relative_position_bias_table[
147
+ self.relative_position_index.view(-1)
148
+ ].view(
149
+ self.window_size[0] * self.window_size[1],
150
+ self.window_size[0] * self.window_size[1],
151
+ -1,
152
+ )
153
+ relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()
154
+ attn = attn + relative_position_bias.unsqueeze(0)
155
+ attn = self.softmax(attn)
156
+ attn = self.attn_drop(attn)
157
+ x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
158
+ x = self.proj(x)
159
+ x = self.proj_drop(x)
160
+ return x
161
+
162
+
163
+ def _to_channel_last(x):
164
+ """
165
+ Args:
166
+ x: (B, C, H, W)
167
+
168
+ Returns:
169
+ x: (B, H, W, C)
170
+ """
171
+ return x.permute(0, 2, 3, 1)
172
+
173
+
174
+ def _to_channel_first(x):
175
+ return x.permute(0, 3, 1, 2)
176
+
177
+
178
+ def _to_query(x, N, num_heads, dim_head):
179
+ B = x.shape[0]
180
+ x = x.reshape(B, 1, N, num_heads, dim_head).permute(0, 1, 3, 2, 4)
181
+ return x
182
+
183
+
184
+ class Mlp(nn.Module):
185
+ """
186
+ Multi-Layer Perceptron (MLP) block
187
+ """
188
+
189
+ def __init__(
190
+ self,
191
+ in_features,
192
+ hidden_features=None,
193
+ out_features=None,
194
+ act_layer=nn.GELU,
195
+ drop=0.0,
196
+ ):
197
+ """
198
+ Args:
199
+ in_features: input features dimension.
200
+ hidden_features: hidden features dimension.
201
+ out_features: output features dimension.
202
+ act_layer: activation function.
203
+ drop: dropout rate.
204
+ """
205
+
206
+ super().__init__()
207
+ out_features = out_features or in_features
208
+ hidden_features = hidden_features or in_features
209
+ self.fc1 = nn.Linear(in_features, hidden_features)
210
+ self.act = act_layer()
211
+ self.fc2 = nn.Linear(hidden_features, out_features)
212
+ self.drop = nn.Dropout(drop)
213
+
214
+ def forward(self, x):
215
+ x = self.fc1(x)
216
+ x = self.act(x)
217
+ x = self.drop(x)
218
+ x = self.fc2(x)
219
+ x = self.drop(x)
220
+ return x
221
+
222
+
223
+ def window_reverse(windows, window_size, H, W, h_w, w_w):
224
+ """
225
+ Args:
226
+ windows: local window features (num_windows*B, window_size, window_size, C)
227
+ window_size: Window size
228
+ H: Height of image
229
+ W: Width of image
230
+
231
+ Returns:
232
+ x: (B, H, W, C)
233
+ """
234
+ B = int(windows.shape[0] / (H * W / window_size / window_size))
235
+ x = windows.view(B, h_w, w_w, window_size, window_size, -1)
236
+ x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
237
+ return x
238
+
239
+
240
+ class feedforward(nn.Module):
241
+ def __init__(
242
+ self,
243
+ dim,
244
+ window_size,
245
+ mlp_ratio=4.0,
246
+ act_layer=nn.GELU,
247
+ drop=0.0,
248
+ drop_path=0.0,
249
+ layer_scale=None,
250
+ ):
251
+ super(feedforward, self).__init__()
252
+ if layer_scale is not None and type(layer_scale) in [int, float]:
253
+ self.layer_scale = True
254
+ self.gamma1 = nn.Parameter(
255
+ layer_scale * torch.ones(dim), requires_grad=True
256
+ )
257
+ self.gamma2 = nn.Parameter(
258
+ layer_scale * torch.ones(dim), requires_grad=True
259
+ )
260
+ else:
261
+ self.gamma1 = 1.0
262
+ self.gamma2 = 1.0
263
+ self.window_size = window_size
264
+ self.mlp = Mlp(
265
+ in_features=dim,
266
+ hidden_features=int(dim * mlp_ratio),
267
+ act_layer=act_layer,
268
+ drop=drop,
269
+ )
270
+ self.norm = nn.LayerNorm(dim)
271
+ self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
272
+
273
+ def forward(self, attn_windows, shortcut):
274
+ B, H, W, C = shortcut.shape
275
+ h_w = int(torch.div(H, self.window_size).item())
276
+ w_w = int(torch.div(W, self.window_size).item())
277
+ x = window_reverse(attn_windows, self.window_size, H, W, h_w, w_w)
278
+ x = shortcut + self.drop_path(self.gamma1 * x)
279
+ x = x + self.drop_path(self.gamma2 * self.mlp(self.norm(x)))
280
+ return x
281
+
282
+
283
+ class pyramid_trans_expr2(nn.Module):
284
+ def __init__(
285
+ self,
286
+ img_size=224,
287
+ num_classes=7,
288
+ window_size=[28, 14, 7],
289
+ num_heads=[2, 4, 8],
290
+ dims=[64, 128, 256],
291
+ embed_dim=768,
292
+ ):
293
+ super().__init__()
294
+
295
+ self.img_size = img_size
296
+ self.num_heads = num_heads
297
+ self.dim_head = []
298
+ for num_head, dim in zip(num_heads, dims):
299
+ self.dim_head.append(int(torch.div(dim, num_head).item()))
300
+ self.num_classes = num_classes
301
+ self.window_size = window_size
302
+ self.N = [win * win for win in window_size]
303
+ self.face_landback = MobileFaceNet([112, 112], 136)
304
+
305
+ mobilefacenet_path = os.path.join(
306
+ os.getcwd(), "models/pretrain/mobilefacenet_model_best.pth.tar"
307
+ )
308
+ ir50_path = os.path.join(os.getcwd(), "models/pretrain/ir50.pth")
309
+
310
+ print(mobilefacenet_path)
311
+ face_landback_checkpoint = torch.load(
312
+ mobilefacenet_path,
313
+ map_location=lambda storage, loc: storage,
314
+ )
315
+ self.face_landback.load_state_dict(face_landback_checkpoint["state_dict"])
316
+
317
+ for param in self.face_landback.parameters():
318
+ param.requires_grad = False
319
+
320
+ self.VIT = VisionTransformer(depth=2, embed_dim=embed_dim)
321
+
322
+ self.ir_back = Backbone(50, 0.0, "ir")
323
+ ir_checkpoint = torch.load(
324
+ ir50_path,
325
+ map_location=lambda storage, loc: storage,
326
+ )
327
+
328
+ self.ir_back = load_pretrained_weights(self.ir_back, ir_checkpoint)
329
+
330
+ self.attn1 = WindowAttentionGlobal(
331
+ dim=dims[0], num_heads=num_heads[0], window_size=window_size[0]
332
+ )
333
+ self.attn2 = WindowAttentionGlobal(
334
+ dim=dims[1], num_heads=num_heads[1], window_size=window_size[1]
335
+ )
336
+ self.attn3 = WindowAttentionGlobal(
337
+ dim=dims[2], num_heads=num_heads[2], window_size=window_size[2]
338
+ )
339
+ self.window1 = window(window_size=window_size[0], dim=dims[0])
340
+ self.window2 = window(window_size=window_size[1], dim=dims[1])
341
+ self.window3 = window(window_size=window_size[2], dim=dims[2])
342
+ self.conv1 = nn.Conv2d(
343
+ in_channels=dims[0],
344
+ out_channels=dims[0],
345
+ kernel_size=3,
346
+ stride=2,
347
+ padding=1,
348
+ )
349
+ self.conv2 = nn.Conv2d(
350
+ in_channels=dims[1],
351
+ out_channels=dims[1],
352
+ kernel_size=3,
353
+ stride=2,
354
+ padding=1,
355
+ )
356
+ self.conv3 = nn.Conv2d(
357
+ in_channels=dims[2],
358
+ out_channels=dims[2],
359
+ kernel_size=3,
360
+ stride=2,
361
+ padding=1,
362
+ )
363
+
364
+ dpr = [x.item() for x in torch.linspace(0, 0.5, 5)]
365
+ self.ffn1 = feedforward(
366
+ dim=dims[0], window_size=window_size[0], layer_scale=1e-5, drop_path=dpr[0]
367
+ )
368
+ self.ffn2 = feedforward(
369
+ dim=dims[1], window_size=window_size[1], layer_scale=1e-5, drop_path=dpr[1]
370
+ )
371
+ self.ffn3 = feedforward(
372
+ dim=dims[2], window_size=window_size[2], layer_scale=1e-5, drop_path=dpr[2]
373
+ )
374
+
375
+ self.last_face_conv = nn.Conv2d(
376
+ in_channels=512, out_channels=256, kernel_size=3, padding=1
377
+ )
378
+
379
+ self.embed_q = nn.Sequential(
380
+ nn.Conv2d(dims[0], 768, kernel_size=3, stride=2, padding=1),
381
+ nn.Conv2d(768, 768, kernel_size=3, stride=2, padding=1),
382
+ )
383
+ self.embed_k = nn.Sequential(
384
+ nn.Conv2d(dims[1], 768, kernel_size=3, stride=2, padding=1)
385
+ )
386
+ self.embed_v = PatchEmbed(img_size=14, patch_size=14, in_c=256, embed_dim=768)
387
+
388
+ def forward(self, x):
389
+ x_face = F.interpolate(x, size=112)
390
+ x_face1, x_face2, x_face3 = self.face_landback(x_face)
391
+ x_face3 = self.last_face_conv(x_face3)
392
+ x_face1, x_face2, x_face3 = (
393
+ _to_channel_last(x_face1),
394
+ _to_channel_last(x_face2),
395
+ _to_channel_last(x_face3),
396
+ )
397
+
398
+ q1, q2, q3 = (
399
+ _to_query(x_face1, self.N[0], self.num_heads[0], self.dim_head[0]),
400
+ _to_query(x_face2, self.N[1], self.num_heads[1], self.dim_head[1]),
401
+ _to_query(x_face3, self.N[2], self.num_heads[2], self.dim_head[2]),
402
+ )
403
+
404
+ x_ir1, x_ir2, x_ir3 = self.ir_back(x)
405
+
406
+ x_ir1, x_ir2, x_ir3 = self.conv1(x_ir1), self.conv2(x_ir2), self.conv3(x_ir3)
407
+ x_window1, shortcut1 = self.window1(x_ir1)
408
+ x_window2, shortcut2 = self.window2(x_ir2)
409
+ x_window3, shortcut3 = self.window3(x_ir3)
410
+
411
+ o1, o2, o3 = (
412
+ self.attn1(x_window1, q1),
413
+ self.attn2(x_window2, q2),
414
+ self.attn3(x_window3, q3),
415
+ )
416
+
417
+ o1, o2, o3 = (
418
+ self.ffn1(o1, shortcut1),
419
+ self.ffn2(o2, shortcut2),
420
+ self.ffn3(o3, shortcut3),
421
+ )
422
+
423
+ o1, o2, o3 = _to_channel_first(o1), _to_channel_first(o2), _to_channel_first(o3)
424
+
425
+ o1, o2, o3 = (
426
+ self.embed_q(o1).flatten(2).transpose(1, 2),
427
+ self.embed_k(o2).flatten(2).transpose(1, 2),
428
+ self.embed_v(o3),
429
+ )
430
+
431
+ o = torch.cat([o1, o2, o3], dim=1)
432
+
433
+ out = self.VIT(o)
434
+ return out
435
+
436
+
437
+ def compute_param_flop():
438
+ model = pyramid_trans_expr2()
439
+ img = torch.rand(size=(1, 3, 224, 224))
440
+ flops, params = profile(model, inputs=(img,))
441
+ print(f"flops:{flops/1000**3}G,params:{params/1000**2}M")
models/PosterV2_8cls.py ADDED
@@ -0,0 +1,317 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from torch.nn import functional as F
4
+ from .mobilefacenet import MobileFaceNet
5
+ from .ir50 import Backbone
6
+ from .vit_model_8 import VisionTransformer, PatchEmbed
7
+ from timm.models.layers import trunc_normal_, DropPath
8
+ from thop import profile
9
+
10
+ def load_pretrained_weights(model, checkpoint):
11
+ import collections
12
+ if 'state_dict' in checkpoint:
13
+ state_dict = checkpoint['state_dict']
14
+ else:
15
+ state_dict = checkpoint
16
+ model_dict = model.state_dict()
17
+ new_state_dict = collections.OrderedDict()
18
+ matched_layers, discarded_layers = [], []
19
+ for k, v in state_dict.items():
20
+ # If the pretrained state_dict was saved as nn.DataParallel,
21
+ # keys would contain "module.", which should be ignored.
22
+ if k.startswith('module.'):
23
+ k = k[7:]
24
+ if k in model_dict and model_dict[k].size() == v.size():
25
+ new_state_dict[k] = v
26
+ matched_layers.append(k)
27
+ else:
28
+ discarded_layers.append(k)
29
+ # new_state_dict.requires_grad = False
30
+ model_dict.update(new_state_dict)
31
+
32
+ model.load_state_dict(model_dict)
33
+ print('load_weight', len(matched_layers))
34
+ return model
35
+
36
+ def window_partition(x, window_size, h_w, w_w):
37
+ """
38
+ Args:
39
+ x: (B, H, W, C)
40
+ window_size: window size
41
+
42
+ Returns:
43
+ local window features (num_windows*B, window_size, window_size, C)
44
+ """
45
+ B, H, W, C = x.shape
46
+ x = x.view(B, h_w, window_size, w_w, window_size, C)
47
+ windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
48
+ return windows
49
+
50
+ class window(nn.Module):
51
+ def __init__(self, window_size, dim):
52
+ super(window, self).__init__()
53
+ self.window_size = window_size
54
+ self.norm = nn.LayerNorm(dim)
55
+ def forward(self, x):
56
+ x = x.permute(0, 2, 3, 1)
57
+ B, H, W, C = x.shape
58
+ x = self.norm(x)
59
+ shortcut = x
60
+ h_w = int(torch.div(H, self.window_size).item())
61
+ w_w = int(torch.div(W, self.window_size).item())
62
+ x_windows = window_partition(x, self.window_size, h_w, w_w)
63
+ x_windows = x_windows.view(-1, self.window_size * self.window_size, C)
64
+ return x_windows, shortcut
65
+
66
+ class WindowAttentionGlobal(nn.Module):
67
+ """
68
+ Global window attention based on: "Hatamizadeh et al.,
69
+ Global Context Vision Transformers <https://arxiv.org/abs/2206.09959>"
70
+ """
71
+
72
+ def __init__(self,
73
+ dim,
74
+ num_heads,
75
+ window_size,
76
+ qkv_bias=True,
77
+ qk_scale=None,
78
+ attn_drop=0.,
79
+ proj_drop=0.,
80
+ ):
81
+ """
82
+ Args:
83
+ dim: feature size dimension.
84
+ num_heads: number of attention head.
85
+ window_size: window size.
86
+ qkv_bias: bool argument for query, key, value learnable bias.
87
+ qk_scale: bool argument to scaling query, key.
88
+ attn_drop: attention dropout rate.
89
+ proj_drop: output dropout rate.
90
+ """
91
+
92
+ super().__init__()
93
+ window_size = (window_size, window_size)
94
+ self.window_size = window_size
95
+ self.num_heads = num_heads
96
+ head_dim = torch.div(dim, num_heads)
97
+ self.scale = qk_scale or head_dim ** -0.5
98
+ self.relative_position_bias_table = nn.Parameter(
99
+ torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))
100
+ coords_h = torch.arange(self.window_size[0])
101
+ coords_w = torch.arange(self.window_size[1])
102
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w]))
103
+ coords_flatten = torch.flatten(coords, 1)
104
+ relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
105
+ relative_coords = relative_coords.permute(1, 2, 0).contiguous()
106
+ relative_coords[:, :, 0] += self.window_size[0] - 1
107
+ relative_coords[:, :, 1] += self.window_size[1] - 1
108
+ relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
109
+ relative_position_index = relative_coords.sum(-1)
110
+ self.register_buffer("relative_position_index", relative_position_index)
111
+ self.qkv = nn.Linear(dim, dim * 2, bias=qkv_bias)
112
+ self.attn_drop = nn.Dropout(attn_drop)
113
+ self.proj = nn.Linear(dim, dim)
114
+ self.proj_drop = nn.Dropout(proj_drop)
115
+ trunc_normal_(self.relative_position_bias_table, std=.02)
116
+ self.softmax = nn.Softmax(dim=-1)
117
+
118
+ def forward(self, x, q_global):
119
+ # print(f'q_global.shape:{q_global.shape}')
120
+ # print(f'x.shape:{x.shape}')
121
+ B_, N, C = x.shape
122
+ B = q_global.shape[0]
123
+ head_dim = int(torch.div(C, self.num_heads).item())
124
+ B_dim = int(torch.div(B_, B).item())
125
+ kv = self.qkv(x).reshape(B_, N, 2, self.num_heads, head_dim).permute(2, 0, 3, 1, 4)
126
+ k, v = kv[0], kv[1]
127
+ q_global = q_global.repeat(1, B_dim, 1, 1, 1)
128
+ q = q_global.reshape(B_, self.num_heads, N, head_dim)
129
+ q = q * self.scale
130
+ attn = (q @ k.transpose(-2, -1))
131
+ relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
132
+ self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)
133
+ relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()
134
+ attn = attn + relative_position_bias.unsqueeze(0)
135
+ attn = self.softmax(attn)
136
+ attn = self.attn_drop(attn)
137
+ x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
138
+ x = self.proj(x)
139
+ x = self.proj_drop(x)
140
+ return x
141
+
142
+ def _to_channel_last(x):
143
+ """
144
+ Args:
145
+ x: (B, C, H, W)
146
+
147
+ Returns:
148
+ x: (B, H, W, C)
149
+ """
150
+ return x.permute(0, 2, 3, 1)
151
+
152
+ def _to_channel_first(x):
153
+ return x.permute(0, 3, 1, 2)
154
+
155
+ def _to_query(x, N, num_heads, dim_head):
156
+ B = x.shape[0]
157
+ x = x.reshape(B, 1, N, num_heads, dim_head).permute(0, 1, 3, 2, 4)
158
+ return x
159
+
160
+ class Mlp(nn.Module):
161
+ """
162
+ Multi-Layer Perceptron (MLP) block
163
+ """
164
+
165
+ def __init__(self,
166
+ in_features,
167
+ hidden_features=None,
168
+ out_features=None,
169
+ act_layer=nn.GELU,
170
+ drop=0.):
171
+ """
172
+ Args:
173
+ in_features: input features dimension.
174
+ hidden_features: hidden features dimension.
175
+ out_features: output features dimension.
176
+ act_layer: activation function.
177
+ drop: dropout rate.
178
+ """
179
+
180
+ super().__init__()
181
+ out_features = out_features or in_features
182
+ hidden_features = hidden_features or in_features
183
+ self.fc1 = nn.Linear(in_features, hidden_features)
184
+ self.act = act_layer()
185
+ self.fc2 = nn.Linear(hidden_features, out_features)
186
+ self.drop = nn.Dropout(drop)
187
+
188
+ def forward(self, x):
189
+ x = self.fc1(x)
190
+ x = self.act(x)
191
+ x = self.drop(x)
192
+ x = self.fc2(x)
193
+ x = self.drop(x)
194
+ return x
195
+
196
+ def window_reverse(windows, window_size, H, W, h_w, w_w):
197
+ """
198
+ Args:
199
+ windows: local window features (num_windows*B, window_size, window_size, C)
200
+ window_size: Window size
201
+ H: Height of image
202
+ W: Width of image
203
+
204
+ Returns:
205
+ x: (B, H, W, C)
206
+ """
207
+ B = int(windows.shape[0] / (H * W / window_size / window_size))
208
+ x = windows.view(B, h_w, w_w, window_size, window_size, -1)
209
+ x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
210
+ return x
211
+
212
+ class feedforward(nn.Module):
213
+ def __init__(self, dim, window_size, mlp_ratio=4., act_layer=nn.GELU, drop=0., drop_path=0., layer_scale=None):
214
+ super(feedforward, self).__init__()
215
+ if layer_scale is not None and type(layer_scale) in [int, float]:
216
+ self.layer_scale = True
217
+ self.gamma1 = nn.Parameter(layer_scale * torch.ones(dim), requires_grad=True)
218
+ self.gamma2 = nn.Parameter(layer_scale * torch.ones(dim), requires_grad=True)
219
+ else:
220
+ self.gamma1 = 1.0
221
+ self.gamma2 = 1.0
222
+ self.window_size = window_size
223
+ self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop)
224
+ self.norm = nn.LayerNorm(dim)
225
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
226
+ def forward(self, attn_windows, shortcut):
227
+ B, H, W, C = shortcut.shape
228
+ h_w = int(torch.div(H, self.window_size).item())
229
+ w_w = int(torch.div(W, self.window_size).item())
230
+ x = window_reverse(attn_windows, self.window_size, H, W, h_w, w_w)
231
+ x = shortcut + self.drop_path(self.gamma1 * x)
232
+ x = x + self.drop_path(self.gamma2 * self.mlp(self.norm(x)))
233
+ return x
234
+
235
+ class pyramid_trans_expr2(nn.Module):
236
+ def __init__(self, img_size=224, num_classes=8, window_size=[28,14,7], num_heads=[2, 4, 8], dims=[64, 128, 256], embed_dim=768):
237
+ super().__init__()
238
+
239
+ self.img_size = img_size
240
+ self.num_heads = num_heads
241
+ self.dim_head = []
242
+ for num_head, dim in zip(num_heads, dims):
243
+ self.dim_head.append(int(torch.div(dim, num_head).item()))
244
+ self.num_classes = num_classes
245
+ self.window_size = window_size
246
+ self.N = [win * win for win in window_size]
247
+ self.face_landback = MobileFaceNet([112, 112], 136)
248
+ face_landback_checkpoint = torch.load(r'./pretrain/mobilefacenet_model_best.pth.tar',
249
+ map_location=lambda storage, loc: storage)
250
+ self.face_landback.load_state_dict(face_landback_checkpoint['state_dict'])
251
+
252
+ for param in self.face_landback.parameters():
253
+ param.requires_grad = False
254
+
255
+ self.VIT = VisionTransformer(depth=2, embed_dim=embed_dim, num_classes=num_classes)
256
+
257
+ self.ir_back = Backbone(50, 0.0, 'ir')
258
+ ir_checkpoint = torch.load(r'./pretrain/ir50.pth', map_location=lambda storage, loc: storage)
259
+
260
+ self.ir_back = load_pretrained_weights(self.ir_back, ir_checkpoint)
261
+
262
+ self.attn1 = WindowAttentionGlobal(dim=dims[0], num_heads=num_heads[0], window_size=window_size[0])
263
+ self.attn2 = WindowAttentionGlobal(dim=dims[1], num_heads=num_heads[1], window_size=window_size[1])
264
+ self.attn3 = WindowAttentionGlobal(dim=dims[2], num_heads=num_heads[2], window_size=window_size[2])
265
+ self.window1 = window(window_size=window_size[0], dim=dims[0])
266
+ self.window2 = window(window_size=window_size[1], dim=dims[1])
267
+ self.window3 = window(window_size=window_size[2], dim=dims[2])
268
+ self.conv1 = nn.Conv2d(in_channels=dims[0], out_channels=dims[0], kernel_size=3, stride=2, padding=1)
269
+ self.conv2 = nn.Conv2d(in_channels=dims[1], out_channels=dims[1], kernel_size=3, stride=2, padding=1)
270
+ self.conv3 = nn.Conv2d(in_channels=dims[2], out_channels=dims[2], kernel_size=3, stride=2, padding=1)
271
+
272
+ dpr = [x.item() for x in torch.linspace(0, 0.5, 5)]
273
+ self.ffn1 = feedforward(dim=dims[0], window_size=window_size[0], layer_scale=1e-5, drop_path=dpr[0])
274
+ self.ffn2 = feedforward(dim=dims[1], window_size=window_size[1], layer_scale=1e-5, drop_path=dpr[1])
275
+ self.ffn3 = feedforward(dim=dims[2], window_size=window_size[2], layer_scale=1e-5, drop_path=dpr[2])
276
+
277
+ self.last_face_conv = nn.Conv2d(in_channels=512, out_channels=256, kernel_size=3, padding=1)
278
+
279
+ self.embed_q = nn.Sequential(nn.Conv2d(dims[0], 768, kernel_size=3, stride=2, padding=1),
280
+ nn.Conv2d(768, 768, kernel_size=3, stride=2, padding=1))
281
+ self.embed_k = nn.Sequential(nn.Conv2d(dims[1], 768, kernel_size=3, stride=2, padding=1))
282
+ self.embed_v = PatchEmbed(img_size=14, patch_size=14, in_c=256, embed_dim=768)
283
+
284
+ def forward(self, x):
285
+ x_face = F.interpolate(x, size=112)
286
+ x_face1 , x_face2, x_face3 = self.face_landback(x_face)
287
+ x_face3 = self.last_face_conv(x_face3)
288
+ x_face1, x_face2, x_face3 = _to_channel_last(x_face1), _to_channel_last(x_face2), _to_channel_last(x_face3)
289
+
290
+ q1, q2, q3 = _to_query(x_face1, self.N[0], self.num_heads[0], self.dim_head[0]), \
291
+ _to_query(x_face2, self.N[1], self.num_heads[1], self.dim_head[1]), \
292
+ _to_query(x_face3, self.N[2], self.num_heads[2], self.dim_head[2])
293
+
294
+ x_ir1, x_ir2, x_ir3 = self.ir_back(x)
295
+ x_ir1, x_ir2, x_ir3 = self.conv1(x_ir1), self.conv2(x_ir2), self.conv3(x_ir3)
296
+ x_window1, shortcut1 = self.window1(x_ir1)
297
+ x_window2, shortcut2 = self.window2(x_ir2)
298
+ x_window3, shortcut3 = self.window3(x_ir3)
299
+
300
+ o1, o2, o3 = self.attn1(x_window1, q1), self.attn2(x_window2, q2), self.attn3(x_window3, q3)
301
+
302
+ o1, o2, o3 = self.ffn1(o1, shortcut1), self.ffn2(o2, shortcut2), self.ffn3(o3, shortcut3)
303
+
304
+ o1, o2, o3 = _to_channel_first(o1), _to_channel_first(o2), _to_channel_first(o3)
305
+
306
+ o1, o2, o3 = self.embed_q(o1).flatten(2).transpose(1, 2), self.embed_k(o2).flatten(2).transpose(1, 2), self.embed_v(o3)
307
+
308
+ o = torch.cat([o1, o2, o3], dim=1)
309
+
310
+ out = self.VIT(o)
311
+ return out
312
+
313
+ def compute_param_flop():
314
+ model = pyramid_trans_expr2()
315
+ img = torch.rand(size=(1,3,224,224))
316
+ flops, params = profile(model, inputs=(img,))
317
+ print(f'flops:{flops/1000**3}G,params:{params/1000**2}M')
models/__pycache__/PosterV2_7cls.cpython-310.pyc ADDED
Binary file (12.1 kB).
 
models/__pycache__/PosterV2_7cls.cpython-311.pyc ADDED
Binary file (24.9 kB).
 
models/__pycache__/ir50.cpython-310.pyc ADDED
Binary file (6.01 kB).
 
models/__pycache__/ir50.cpython-311.pyc ADDED
Binary file (12 kB).
 
models/__pycache__/mobilefacenet.cpython-310.pyc ADDED
Binary file (6.5 kB).
 
models/__pycache__/mobilefacenet.cpython-311.pyc ADDED
Binary file (12.6 kB).
 
models/__pycache__/vit_model.cpython-310.pyc ADDED
Binary file (19.6 kB).
 
models/__pycache__/vit_model.cpython-311.pyc ADDED
Binary file (34.9 kB).
 
models/ir50.py ADDED
@@ -0,0 +1,272 @@
1
+ from torch.nn import Linear, Conv2d, BatchNorm1d, BatchNorm2d, PReLU, ReLU, Sigmoid, Dropout2d, Dropout, AvgPool2d, \
2
+ MaxPool2d, AdaptiveAvgPool2d, Sequential, Module, Parameter
3
+ import torch.nn.functional as F
4
+ import torch
5
+ from collections import namedtuple
6
+ import math
7
+ import pdb
8
+
9
+
10
+ ################################## Original Arcface Model #############################################################
11
+
12
+ class Flatten(Module):
13
+ def forward(self, input):
14
+ return input.view(input.size(0), -1)
15
+
16
+
17
+ def l2_norm(input, axis=1):
18
+ norm = torch.norm(input, 2, axis, True)
19
+ output = torch.div(input, norm)
20
+ return output
21
+
22
+
23
+ class SEModule(Module):
24
+ def __init__(self, channels, reduction):
25
+ super(SEModule, self).__init__()
26
+ self.avg_pool = AdaptiveAvgPool2d(1)
27
+ self.fc1 = Conv2d(
28
+ channels, channels // reduction, kernel_size=1, padding=0, bias=False)
29
+ self.relu = ReLU(inplace=True)
30
+ self.fc2 = Conv2d(
31
+ channels // reduction, channels, kernel_size=1, padding=0, bias=False)
32
+ self.sigmoid = Sigmoid()
33
+
34
+ def forward(self, x):
35
+ module_input = x
36
+ x = self.avg_pool(x)
37
+ x = self.fc1(x)
38
+ x = self.relu(x)
39
+ x = self.fc2(x)
40
+ x = self.sigmoid(x)
41
+ return module_input * x
42
+
43
+
44
+ # i = 0
45
+
46
+ class bottleneck_IR(Module):
47
+ def __init__(self, in_channel, depth, stride):
48
+ super(bottleneck_IR, self).__init__()
49
+ if in_channel == depth:
50
+ self.shortcut_layer = MaxPool2d(1, stride)
51
+ else:
52
+ self.shortcut_layer = Sequential(
53
+ Conv2d(in_channel, depth, (1, 1), stride, bias=False), BatchNorm2d(depth))
54
+ self.res_layer = Sequential(
55
+ BatchNorm2d(in_channel),
56
+ Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False), PReLU(depth),
57
+ Conv2d(depth, depth, (3, 3), stride, 1, bias=False), BatchNorm2d(depth))
58
+ i = 0
59
+
60
+ def forward(self, x):
61
+ shortcut = self.shortcut_layer(x)
62
+ # print(shortcut.shape)
63
+ # print('---s---')
64
+ res = self.res_layer(x)
65
+ # print(res.shape)
66
+ # print('---r---')
67
+ # i = i + 50
68
+ # print(i)
69
+ # print('50')
70
+ return res + shortcut
71
+
72
+
73
+ class bottleneck_IR_SE(Module):
74
+ def __init__(self, in_channel, depth, stride):
75
+ super(bottleneck_IR_SE, self).__init__()
76
+ if in_channel == depth:
77
+ self.shortcut_layer = MaxPool2d(1, stride)
78
+ else:
79
+ self.shortcut_layer = Sequential(
80
+ Conv2d(in_channel, depth, (1, 1), stride, bias=False),
81
+ BatchNorm2d(depth))
82
+ self.res_layer = Sequential(
83
+ BatchNorm2d(in_channel),
84
+ Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False),
85
+ PReLU(depth),
86
+ Conv2d(depth, depth, (3, 3), stride, 1, bias=False),
87
+ BatchNorm2d(depth),
88
+ SEModule(depth, 16)
89
+ )
90
+
91
+ def forward(self, x):
92
+ shortcut = self.shortcut_layer(x)
93
+ res = self.res_layer(x)
94
+ return res + shortcut
95
+
96
+
97
+ class Bottleneck(namedtuple('Block', ['in_channel', 'depth', 'stride'])):
98
+ '''A named tuple describing a ResNet block.'''
99
+ # print('50')
100
+
101
+
102
+ def get_block(in_channel, depth, num_units, stride=2):
103
+ return [Bottleneck(in_channel, depth, stride)] + [Bottleneck(depth, depth, 1) for i in range(num_units - 1)]
104
+
105
+
106
+ def get_blocks(num_layers):
107
+ if num_layers == 50:
108
+ blocks1 = [
109
+ get_block(in_channel=64, depth=64, num_units=3),
110
+ # get_block(in_channel=64, depth=128, num_units=4),
111
+ # get_block(in_channel=128, depth=256, num_units=14),
112
+ # get_block(in_channel=256, depth=512, num_units=3)
113
+ ]
114
+ blocks2 = [
115
+ # get_block(in_channel=64, depth=64, num_units=3),
116
+ get_block(in_channel=64, depth=128, num_units=4),
117
+ # get_block(in_channel=128, depth=256, num_units=14),
118
+ # get_block(in_channel=256, depth=512, num_units=3)
119
+ ]
120
+ blocks3 = [
121
+ # get_block(in_channel=64, depth=64, num_units=3),
122
+ # get_block(in_channel=64, depth=128, num_units=4),
123
+ get_block(in_channel=128, depth=256, num_units=14),
124
+ # get_block(in_channel=256, depth=512, num_units=3)
125
+ ]
126
+
127
+ elif num_layers == 100:
128
+ blocks1 = [get_block(in_channel=64, depth=64, num_units=3)]
129
+ blocks2 = [get_block(in_channel=64, depth=128, num_units=13)]
130
+ blocks3 = [get_block(in_channel=128, depth=256, num_units=30)]
131
+ # the depth-512 stage is omitted because Backbone.forward only uses body1-body3
132
+ elif num_layers == 152:
133
+ blocks1 = [get_block(in_channel=64, depth=64, num_units=3)]
134
+ blocks2 = [get_block(in_channel=64, depth=128, num_units=8)]
135
+ blocks3 = [get_block(in_channel=128, depth=256, num_units=36)]
136
+ # the depth-512 stage is omitted because Backbone.forward only uses body1-body3
137
+ else:
138
+ raise ValueError('num_layers should be 50, 100 or 152, got {}'.format(num_layers))
139
+
140
+
141
+ return blocks1, blocks2, blocks3
142
+
143
+
144
+ class Backbone(Module):
145
+ def __init__(self, num_layers, drop_ratio, mode='ir'):
146
+ super(Backbone, self).__init__()
147
+ # assert num_layers in [50, 100, 152], 'num_layers should be 50,100, or 152'
148
+ assert mode in ['ir', 'ir_se'], 'mode should be ir or ir_se'
149
+ blocks1, blocks2, blocks3 = get_blocks(num_layers)
150
+ # blocks2 = get_blocks(num_layers)
151
+ if mode == 'ir':
152
+ unit_module = bottleneck_IR
153
+ elif mode == 'ir_se':
154
+ unit_module = bottleneck_IR_SE
155
+ self.input_layer = Sequential(Conv2d(3, 64, (3, 3), 1, 1, bias=False),
156
+ BatchNorm2d(64),
157
+ PReLU(64))
158
+ self.output_layer = Sequential(BatchNorm2d(512),
159
+ Dropout(drop_ratio),
160
+ Flatten(),
161
+ Linear(512 * 7 * 7, 512),
162
+ BatchNorm1d(512))
163
+ modules1 = []
164
+ for block in blocks1:
165
+ for bottleneck in block:
166
+ modules1.append(
167
+ unit_module(bottleneck.in_channel,
168
+ bottleneck.depth,
169
+ bottleneck.stride))
170
+
171
+ modules2 = []
172
+ for block in blocks2:
173
+ for bottleneck in block:
174
+ modules2.append(
175
+ unit_module(bottleneck.in_channel,
176
+ bottleneck.depth,
177
+ bottleneck.stride))
178
+
179
+ modules3 = []
180
+ for block in blocks3:
181
+ for bottleneck in block:
182
+ modules3.append(
183
+ unit_module(bottleneck.in_channel,
184
+ bottleneck.depth,
185
+ bottleneck.stride))
186
+ # modules4 = []
187
+ # for block in blocks4:
188
+ # for bottleneck in block:
189
+ # modules4.append(
190
+ # unit_module(bottleneck.in_channel,
191
+ # bottleneck.depth,
192
+ # bottleneck.stride))
193
+ self.body1 = Sequential(*modules1)
194
+ self.body2 = Sequential(*modules2)
195
+ self.body3 = Sequential(*modules3)
196
+ # self.body4 = Sequential(*modules4)
197
+
198
+ def forward(self, x):
199
+ x = F.interpolate(x, size=112)
200
+ x = self.input_layer(x)
201
+ x1 = self.body1(x)
202
+ x2 = self.body2(x1)
203
+ x3 = self.body3(x2)
204
+
205
+ # x = self.output_layer(x)
206
+ # return l2_norm(x)
207
+
208
+ return x1, x2, x3
209
+
210
+ def load_pretrained_weights(model, checkpoint):
211
+ import collections
212
+ if 'state_dict' in checkpoint:
213
+ state_dict = checkpoint['state_dict']
214
+ else:
215
+ state_dict = checkpoint
216
+ model_dict = model.state_dict()
217
+ new_state_dict = collections.OrderedDict()
218
+ matched_layers, discarded_layers = [], []
219
+ for i, (k, v) in enumerate(state_dict.items()):
220
+ # print(i)
221
+
222
+ # If the pretrained state_dict was saved as nn.DataParallel,
223
+ # keys would contain "module.", which should be ignored.
224
+ if k.startswith('module.'):
225
+ k = k[7:]
226
+ if k in model_dict and model_dict[k].size() == v.size():
227
+
228
+ new_state_dict[k] = v
229
+ matched_layers.append(k)
230
+ else:
231
+ # print(k)
232
+ discarded_layers.append(k)
233
+ # new_state_dict.requires_grad = False
234
+ model_dict.update(new_state_dict)
235
+ model.load_state_dict(model_dict)
236
+ print('load_weight', len(matched_layers))
237
+ return model
238
+
239
+ # model = Backbone(50, 0.0, 'ir')
240
+ # ir_checkpoint = torch.load(r'C:\Users\86187\Desktop\project\mixfacial\models\pretrain\new_ir50.pth')
241
+ # print('hello')
242
+ # i1, i2, i3 = 0, 0, 0
243
+ # ir_checkpoint = torch.load(r'C:\Users\86187\Desktop\project\mixfacial\models\pretrain\ir50.pth', map_location=lambda storage, loc: storage)
244
+ # for (k1, v1), (k2, v2) in zip(model.state_dict().items(), ir_checkpoint.items()):
245
+ # print(f'k1:{k1}, k2:{k2}')
246
+ # model.state_dict()[k1] = v2
247
+
248
+ # torch.save(model.state_dict(), r'C:\Users\86187\Desktop\project\mixfacial\models\pretrain\new_ir50.pth')
249
+ # print(k)
250
+ # if k.startswith('body1'):
251
+ # i1+=1
252
+ # if k.startswith('body2'):
253
+ # i2+=1
254
+ # if k.startswith('body3'):
255
+ # i3+=1
256
+ # print(f'i1:{i1}, i2:{i2}, i3:{i3}')
257
+
258
+ # print('-'*100)
259
+ # ir_checkpoint = torch.load(r'C:\Users\86187\Desktop\project\mixfacial\models\pretrain\ir50.pth', map_location=lambda storage, loc: storage)
260
+ # le = 0
261
+ # for k, v in ir_checkpoint.items():
262
+ # # print(k)
263
+ # if k.startswith('body'):
264
+ # if le < i1:
265
+ # le += 1
266
+ # key = k.split('.')[0] + str(1) + k.split('.')[1:]
267
+ # print(key)
268
+ # # ir_checkpoint = ir_checkpoint["model"]
269
+ # model = load_pretrained_weights(model, ir_checkpoint)
270
+ # img = torch.rand(size=(2,3,224,224))
271
+ # out1, out2, out3 = model(img)
272
+ # print(out1.shape, out2.shape, out3.shape)
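The commented-out smoke test above can be turned into a short sanity check. The following is a minimal sketch (not part of the commit); it assumes the ir50.pth weights shipped under models/pretrain/ and uses the Backbone and load_pretrained_weights defined in this file.

    import torch

    backbone = Backbone(num_layers=50, drop_ratio=0.0, mode='ir')
    ckpt = torch.load('models/pretrain/ir50.pth', map_location='cpu')
    backbone = load_pretrained_weights(backbone, ckpt)

    x = torch.rand(2, 3, 224, 224)        # forward() resizes to 112x112 internally
    x1, x2, x3 = backbone(x)              # three pyramid levels from body1 / body2 / body3
    print(x1.shape, x2.shape, x3.shape)   # expected: [2, 64, 56, 56], [2, 128, 28, 28], [2, 256, 14, 14]
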
models/matrix.py ADDED
@@ -0,0 +1,62 @@
1
+ import itertools
2
+ import matplotlib.pyplot as plt
3
+ import numpy as np
4
+
5
+
6
+ plt.rcParams['font.sans-serif'] = ['SimHei']
7
+ plt.rcParams['axes.unicode_minus'] = False
8
+
9
+
10
+ # -*- coding:utf-8 -*-
11
+
12
+ def plot_confusion_matrix(cm, classes,
13
+ normalize=False,
14
+ title='Confusion matrix',
15
+ cmap=plt.cm.Blues):
16
+ """
17
+ This function prints and plots the confusion matrix.
18
+ Normalization can be applied by setting `normalize=True`.
19
+ """
20
+ if normalize:
21
+ cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
22
+ print("Normalized confusion matrix")
23
+ else:
24
+ print('Confusion matrix, without normalization')
25
+
26
+ print(cm)
27
+
28
+ plt.imshow(cm, interpolation='nearest', cmap=cmap)
29
+ plt.title(title)
30
+ plt.colorbar()
31
+ tick_marks = np.arange(len(classes))
32
+ plt.xticks(tick_marks, classes, fontsize=16)
33
+ plt.yticks(tick_marks, classes, fontsize=16)
34
+
35
+ fmt = '.2f' if normalize else 'd'
36
+ thresh = cm.max() / 2.
37
+ for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
38
+ plt.text(j, i, format(cm[i, j], fmt),
39
+ horizontalalignment="center",
40
+ color="white" if cm[i, j] > thresh else "black")
41
+
42
+ plt.tight_layout()
43
+ plt.ylabel('True Label',fontsize=12)
44
+ plt.xlabel('Predicted Label',fontsize=12)
45
+ plt.show()
46
+
47
+
48
+
49
+ cnf_matrix = np.array([[ 299 , 6 , 5 , 3 , 1 , 4, 11],
50
+ [ 9, 51 , 0, 2 , 8, 2 , 2],
51
+ [ 2 , 1 ,120 , 6 ,13 , 9 , 9],
52
+ [ 5 , 1 , 7 ,1148 , 2 , 4 , 18],
53
+ [ 0 , 0 , 9 , 4 ,442 , 1 , 22],
54
+ [ 2 ,0 , 7 , 3 , 0 ,145 , 5],
55
+ [ 10 ,0, 6 ,11, 29 , 0, 624]])
56
+
57
+ class_names = ["SU", 'FE', 'AN', 'HA', 'SA', 'DI', 'NE']
58
+
59
+
60
+ plt.figure(dpi=200)
61
+ plot_confusion_matrix(cnf_matrix, classes=class_names, normalize=True,
62
+ title=None)
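The script above plots a hard-coded RAF-DB confusion matrix. As a usage sketch (the y_true/y_pred arrays below are made-up placeholders, not data from the commit), plot_confusion_matrix can also be fed a matrix accumulated from integer labels:

    import numpy as np

    y_true = np.array([0, 1, 2, 3, 4, 5, 6, 3, 6])   # hypothetical ground-truth class ids
    y_pred = np.array([0, 1, 2, 3, 4, 5, 6, 6, 4])   # hypothetical predicted class ids
    cm = np.zeros((7, 7), dtype=int)
    for t, p in zip(y_true, y_pred):
        cm[t, p] += 1                                # count (true, predicted) pairs
    plot_confusion_matrix(cm, classes=["SU", "FE", "AN", "HA", "SA", "DI", "NE"], normalize=True)
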
models/mobilefacenet.py ADDED
@@ -0,0 +1,193 @@
1
+ from torch.nn import Linear, Conv2d, BatchNorm1d, BatchNorm2d, PReLU, ReLU, Sigmoid, Dropout2d, Dropout, AvgPool2d, \
2
+ MaxPool2d, AdaptiveAvgPool2d, Sequential, Module, Parameter
3
+ import torch.nn.functional as F
4
+ import torch
5
+ import torch.nn as nn
6
+ from collections import namedtuple
7
+ import math
8
+ import pdb
9
+
10
+
11
+ ################################## Original Arcface Model #############################################################
12
+ ######## ccc#######################
13
+ class Flatten(Module):
14
+ def forward(self, input):
15
+ return input.view(input.size(0), -1)
16
+
17
+
18
+ ################################## MobileFaceNet #############################################################
19
+
20
+ class Conv_block(Module):
21
+ def __init__(self, in_c, out_c, kernel=(1, 1), stride=(1, 1), padding=(0, 0), groups=1):
22
+ super(Conv_block, self).__init__()
23
+ self.conv = Conv2d(in_c, out_channels=out_c, kernel_size=kernel, groups=groups, stride=stride, padding=padding,
24
+ bias=False)
25
+ self.bn = BatchNorm2d(out_c)
26
+ self.prelu = PReLU(out_c)
27
+
28
+ def forward(self, x):
29
+ x = self.conv(x)
30
+ x = self.bn(x)
31
+ x = self.prelu(x)
32
+ return x
33
+
34
+
35
+ class Linear_block(Module):
36
+ def __init__(self, in_c, out_c, kernel=(1, 1), stride=(1, 1), padding=(0, 0), groups=1):
37
+ super(Linear_block, self).__init__()
38
+ self.conv = Conv2d(in_c, out_channels=out_c, kernel_size=kernel, groups=groups, stride=stride, padding=padding,
39
+ bias=False)
40
+ self.bn = BatchNorm2d(out_c)
41
+
42
+ def forward(self, x):
43
+ x = self.conv(x)
44
+ x = self.bn(x)
45
+ return x
46
+
47
+
48
+ class Depth_Wise(Module):
49
+ def __init__(self, in_c, out_c, residual=False, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=1):
50
+ super(Depth_Wise, self).__init__()
51
+ self.conv = Conv_block(in_c, out_c=groups, kernel=(1, 1), padding=(0, 0), stride=(1, 1))
52
+ self.conv_dw = Conv_block(groups, groups, groups=groups, kernel=kernel, padding=padding, stride=stride)
53
+ self.project = Linear_block(groups, out_c, kernel=(1, 1), padding=(0, 0), stride=(1, 1))
54
+ self.residual = residual
55
+
56
+ def forward(self, x):
57
+ if self.residual:
58
+ short_cut = x
59
+ x = self.conv(x)
60
+ x = self.conv_dw(x)
61
+ x = self.project(x)
62
+ if self.residual:
63
+ output = short_cut + x
64
+ else:
65
+ output = x
66
+ return output
67
+
68
+
69
+ class Residual(Module):
70
+ def __init__(self, c, num_block, groups, kernel=(3, 3), stride=(1, 1), padding=(1, 1)):
71
+ super(Residual, self).__init__()
72
+ modules = []
73
+ for _ in range(num_block):
74
+ modules.append(
75
+ Depth_Wise(c, c, residual=True, kernel=kernel, padding=padding, stride=stride, groups=groups))
76
+ self.model = Sequential(*modules)
77
+
78
+ def forward(self, x):
79
+ return self.model(x)
80
+
81
+
82
+ class GNAP(Module):
83
+ def __init__(self, embedding_size):
84
+ super(GNAP, self).__init__()
85
+ assert embedding_size == 512
86
+ self.bn1 = BatchNorm2d(512, affine=False)
87
+ self.pool = nn.AdaptiveAvgPool2d((1, 1))
88
+
89
+ self.bn2 = BatchNorm1d(512, affine=False)
90
+
91
+ def forward(self, x):
92
+ x = self.bn1(x)
93
+ x_norm = torch.norm(x, 2, 1, True)
94
+ x_norm_mean = torch.mean(x_norm)
95
+ weight = x_norm_mean / x_norm
96
+ x = x * weight
97
+ x = self.pool(x)
98
+ x = x.view(x.shape[0], -1)
99
+ feature = self.bn2(x)
100
+ return feature
101
+
102
+
103
+ class GDC(Module):
104
+ def __init__(self, embedding_size):
105
+ super(GDC, self).__init__()
106
+ self.conv_6_dw = Linear_block(512, 512, groups=512, kernel=(7, 7), stride=(1, 1), padding=(0, 0))
107
+ self.conv_6_flatten = Flatten()
108
+ self.linear = Linear(512, embedding_size, bias=False)
109
+ # self.bn = BatchNorm1d(embedding_size, affine=False)
110
+ self.bn = BatchNorm1d(embedding_size)
111
+
112
+ def forward(self, x):
113
+ x = self.conv_6_dw(x) #### [B, 512, 1, 1]
114
+ x = self.conv_6_flatten(x) #### [B, 512]
115
+ x = self.linear(x) #### [B, 136]
116
+ x = self.bn(x)
117
+ return x
118
+
119
+
120
+ class MobileFaceNet(Module):
121
+ def __init__(self, input_size, embedding_size=512, output_name="GDC"):
122
+ super(MobileFaceNet, self).__init__()
123
+ assert output_name in ["GNAP", 'GDC']
124
+ assert input_size[0] in [112]
125
+ self.conv1 = Conv_block(3, 64, kernel=(3, 3), stride=(2, 2), padding=(1, 1))
126
+ self.conv2_dw = Conv_block(64, 64, kernel=(3, 3), stride=(1, 1), padding=(1, 1), groups=64)
127
+ self.conv_23 = Depth_Wise(64, 64, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=128)
128
+ self.conv_3 = Residual(64, num_block=4, groups=128, kernel=(3, 3), stride=(1, 1), padding=(1, 1))
129
+ self.conv_34 = Depth_Wise(64, 128, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=256)
130
+ self.conv_4 = Residual(128, num_block=6, groups=256, kernel=(3, 3), stride=(1, 1), padding=(1, 1))
131
+ self.conv_45 = Depth_Wise(128, 128, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=512)
132
+ self.conv_5 = Residual(128, num_block=2, groups=256, kernel=(3, 3), stride=(1, 1), padding=(1, 1))
133
+ self.conv_6_sep = Conv_block(128, 512, kernel=(1, 1), stride=(1, 1), padding=(0, 0))
134
+ if output_name == "GNAP":
135
+ self.output_layer = GNAP(512)
136
+ else:
137
+ self.output_layer = GDC(embedding_size)
138
+
139
+ self._initialize_weights()
140
+
141
+ def _initialize_weights(self):
142
+ for m in self.modules():
143
+ if isinstance(m, nn.Conv2d):
144
+ nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
145
+ if m.bias is not None:
146
+ m.bias.data.zero_()
147
+ elif isinstance(m, nn.BatchNorm2d):
148
+ m.weight.data.fill_(1)
149
+ m.bias.data.zero_()
150
+ elif isinstance(m, nn.Linear):
151
+ nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
152
+ if m.bias is not None:
153
+ m.bias.data.zero_()
154
+
155
+ def forward(self, x):
156
+ out = self.conv1(x)
157
+ # print(out.shape)
158
+ out = self.conv2_dw(out)
159
+ # print(out.shape)
160
+ out = self.conv_23(out)
161
+ # print(out.shape)
162
+ out3 = self.conv_3(out)
163
+ # print(out.shape)
164
+ out = self.conv_34(out3)
165
+ # print(out.shape)
166
+ out4 = self.conv_4(out) # [128, 14, 14]
167
+ # print(out.shape)
168
+ out = self.conv_45(out4) # [128, 7, 7]
169
+ # print(out.shape)
170
+ out = self.conv_5(out) # [128, 7, 7]
171
+ # print(out.shape)
172
+ conv_features = self.conv_6_sep(out) ##### [B, 512, 7, 7]
173
+ out = self.output_layer(conv_features) ##### [B, 136]
174
+ return out3, out4, conv_features
175
+
176
+
177
+ # model = MobileFaceNet([112, 112],136)
178
+ # input = torch.ones(8,3,112,112).cuda()
179
+ # model = model.cuda()
180
+ # x = model(input)
181
+ # import numpy as np
182
+ # parameters = model.parameters()
183
+ # parameters = sum([np.prod(p.size()) for p in parameters]) / 1_000_000
184
+ # print('Total Parameters: %.3fM' % parameters)
185
+ #
186
+ #
187
+ # from ptflops import get_model_complexity_info
188
+ # macs, params = get_model_complexity_info(model, (3, 112, 112), as_strings=True,
189
+ # print_per_layer_stat=True, verbose=True)
190
+ # print('{:<30} {:<8}'.format('Computational complexity: ', macs))
191
+ # print('{:<30} {:<8}'.format('Number of parameters: ', params))
192
+ #
193
+ # print(x.shape)
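As with the IR-50 backbone, the commented lines above can be adapted into a quick shape check. A minimal sketch (an assumption, not part of the commit) of running the landmark backbone, which returns three feature maps rather than a single embedding:

    import torch

    model = MobileFaceNet([112, 112], 136)       # input size must be 112x112
    x = torch.rand(2, 3, 112, 112)
    out3, out4, feats = model(x)
    print(out3.shape, out4.shape, feats.shape)   # expected: [2, 64, 28, 28], [2, 128, 14, 14], [2, 512, 7, 7]
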
models/pretrain/.DS_Store ADDED
Binary file (6.15 kB).
 
models/pretrain/.gitignore ADDED
@@ -0,0 +1,2 @@
1
+ *
2
+ !.gitignore
models/pretrain/ir50.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:62fcfa833776648f818b15fac4f5b760d76847316097e8e046f77ac445defb75
3
+ size 122022895
models/pretrain/mobilefacenet_model_best.pth.tar ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b994af026bfddbafc507a6f1c8737a9896bab20ed2b0cfb6ae90b81736970313
3
+ size 12281146
models/vit_model.py ADDED
@@ -0,0 +1,828 @@
1
+ """
2
+ original code from rwightman:
3
+ https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
4
+ """
5
+ from functools import partial
6
+ from collections import OrderedDict
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch
11
+ import torch.nn as nn
12
+ import torch.nn.functional as F
13
+ import torch.hub
14
+ from functools import partial
15
+ # import mat
16
+ # from vision_transformer.ir50 import Backbone
17
+
18
+
19
+ import torch
20
+ import torch.nn as nn
21
+ import torch.nn.functional as F
22
+ import torch.hub
23
+ from functools import partial
24
+ import math
25
+
26
+ from timm.models.layers import DropPath, to_2tuple, trunc_normal_
27
+ from timm.models.registry import register_model
28
+ from timm.models.vision_transformer import _cfg, Mlp, Block
29
+ # from .ir50 import Backbone
30
+
31
+
32
+ def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
33
+ """3x3 convolution with padding"""
34
+ return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
35
+ padding=dilation, groups=groups, bias=False, dilation=dilation)
36
+
37
+
38
+ def conv1x1(in_planes, out_planes, stride=1):
39
+ """1x1 convolution"""
40
+ return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
41
+
42
+
43
+ def drop_path(x, drop_prob: float = 0., training: bool = False):
44
+ """
45
+ Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
46
+ This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
47
+ the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
48
+ See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
49
+ changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
50
+ 'survival rate' as the argument.
51
+ """
52
+ if drop_prob == 0. or not training:
53
+ return x
54
+ keep_prob = 1 - drop_prob
55
+ shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
56
+ random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
57
+ random_tensor.floor_() # binarize
58
+ output = x.div(keep_prob) * random_tensor
59
+ return output
60
+
61
+
62
+ class BasicBlock(nn.Module):
63
+ __constants__ = ['downsample']
64
+
65
+ def __init__(self, inplanes, planes, stride=1, downsample=None):
66
+ super(BasicBlock, self).__init__()
67
+ norm_layer = nn.BatchNorm2d
68
+ self.conv1 = conv3x3(inplanes, planes, stride)
69
+ self.bn1 = norm_layer(planes)
70
+ self.relu = nn.ReLU(inplace=True)
71
+ self.conv2 = conv3x3(planes, planes)
72
+ self.bn2 = norm_layer(planes)
73
+ self.downsample = downsample
74
+ self.stride = stride
75
+
76
+ def forward(self, x):
77
+ identity = x
78
+
79
+ out = self.conv1(x)
80
+ out = self.bn1(out)
81
+ out = self.relu(out)
82
+ out = self.conv2(out)
83
+ out = self.bn2(out)
84
+
85
+ if self.downsample is not None:
86
+ identity = self.downsample(x)
87
+
88
+ out += identity
89
+ out = self.relu(out)
90
+
91
+ return out
92
+
93
+
94
+ class DropPath(nn.Module):
95
+ """
96
+ Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
97
+ """
98
+
99
+ def __init__(self, drop_prob=None):
100
+ super(DropPath, self).__init__()
101
+ self.drop_prob = drop_prob
102
+
103
+ def forward(self, x):
104
+ return drop_path(x, self.drop_prob, self.training)
105
+
106
+
107
+ class PatchEmbed(nn.Module):
108
+ """
109
+ 2D Image to Patch Embedding
110
+ """
111
+
112
+ def __init__(self, img_size=14, patch_size=16, in_c=256, embed_dim=768, norm_layer=None):
113
+ super().__init__()
114
+ img_size = (img_size, img_size)
115
+ patch_size = (patch_size, patch_size)
116
+ self.img_size = img_size
117
+ self.patch_size = patch_size
118
+ self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
119
+ self.num_patches = self.grid_size[0] * self.grid_size[1]
120
+
121
+ self.proj = nn.Conv2d(256, 768, kernel_size=1)
122
+ self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
123
+
124
+ def forward(self, x):
125
+ B, C, H, W = x.shape
126
+ # assert H == self.img_size[0] and W == self.img_size[1], \
127
+ # f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
128
+ # print(x.shape)
129
+
130
+ # flatten: [B, C, H, W] -> [B, C, HW]
131
+ # transpose: [B, C, HW] -> [B, HW, C]
132
+ x = self.proj(x).flatten(2).transpose(1, 2)
133
+ x = self.norm(x)
134
+ return x
135
+
136
+
137
+ class Attention(nn.Module):
138
+ def __init__(self,
139
+ dim, in_chans, # dim of the input tokens
140
+ num_heads=8,
141
+ qkv_bias=False,
142
+ qk_scale=None,
143
+ attn_drop_ratio=0.,
144
+ proj_drop_ratio=0.):
145
+ super(Attention, self).__init__()
146
+ self.num_heads = num_heads
147
+ self.img_chanel = in_chans + 1
148
+ head_dim = dim // num_heads
149
+ self.scale = head_dim ** -0.5
150
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
151
+ self.attn_drop = nn.Dropout(attn_drop_ratio)
152
+ self.proj = nn.Linear(dim, dim)
153
+ self.proj_drop = nn.Dropout(proj_drop_ratio)
154
+
155
+ def forward(self, x):
156
+ x_img = x[:, :self.img_chanel, :]
157
+ # [batch_size, num_patches + 1, total_embed_dim]
158
+ B, N, C = x_img.shape
159
+ # print(C)
160
+ qkv = self.qkv(x_img).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
161
+ q, k, v = qkv[0], qkv[1], qkv[2]
162
+ # k, v = kv.unbind(0) # make torchscript happy (cannot use tensor as tuple)
163
+ # q = x_img.reshape(B, -1, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
164
+ attn = (q @ k.transpose(-2, -1)) * self.scale
165
+ attn = attn.softmax(dim=-1)
166
+ attn = self.attn_drop(attn)
167
+
168
+ x_img = (attn @ v).transpose(1, 2).reshape(B, N, C)
169
+ x_img = self.proj(x_img)
170
+ x_img = self.proj_drop(x_img)
171
+ #
172
+ #
173
+ # # qkv(): -> [batch_size, num_patches + 1, 3 * total_embed_dim]
174
+ # # reshape: -> [batch_size, num_patches + 1, 3, num_heads, embed_dim_per_head]
175
+ # # permute: -> [3, batch_size, num_heads, num_patches + 1, embed_dim_per_head]
176
+ # qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
177
+ # # [batch_size, num_heads, num_patches + 1, embed_dim_per_head]
178
+ # q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
179
+ #
180
+ # # transpose: -> [batch_size, num_heads, embed_dim_per_head, num_patches + 1]
181
+ # # @: multiply -> [batch_size, num_heads, num_patches + 1, num_patches + 1]
182
+ # attn = (q @ k.transpose(-2, -1)) * self.scale
183
+ # attn = attn.softmax(dim=-1)
184
+ # attn = self.attn_drop(attn)
185
+ #
186
+ # # @: multiply -> [batch_size, num_heads, num_patches + 1, embed_dim_per_head]
187
+ # # transpose: -> [batch_size, num_patches + 1, num_heads, embed_dim_per_head]
188
+ # # reshape: -> [batch_size, num_patches + 1, total_embed_dim]
189
+ # x = (attn @ v).transpose(1, 2).reshape(B, N, C)
190
+ # x = self.proj(x)
191
+ # x = self.proj_drop(x)
192
+ return x_img
193
+
194
+
195
+ class AttentionBlock(nn.Module):
196
+ __constants__ = ['downsample']
197
+
198
+ def __init__(self, inplanes, planes, stride=1, downsample=None):
199
+ super(AttentionBlock, self).__init__()
200
+ norm_layer = nn.BatchNorm2d
201
+ self.conv1 = conv3x3(inplanes, planes, stride)
202
+ self.bn1 = norm_layer(planes)
203
+ self.relu = nn.ReLU(inplace=True)
204
+ self.conv2 = conv3x3(planes, planes)
205
+ self.bn2 = norm_layer(planes)
206
+ self.downsample = downsample
207
+ self.stride = stride
208
+ # self.cbam = CBAM(planes, 16)
209
+ self.inplanes = inplanes
210
+ self.eca_block = eca_block()
211
+
212
+ def forward(self, x):
213
+ identity = x
214
+
215
+ out = self.conv1(x)
216
+ out = self.bn1(out)
217
+ out = self.relu(out)
218
+
219
+ out = self.conv2(out)
220
+ out = self.bn2(out)
221
+ inplanes = self.inplanes
222
+ out = self.eca_block(out)
223
+ if self.downsample is not None:
224
+ identity = self.downsample(x)
225
+
226
+ out += identity
227
+ out = self.relu(out)
228
+
229
+ return out
230
+
231
+
232
+ class Mlp(nn.Module):
233
+ """
234
+ MLP as used in Vision Transformer, MLP-Mixer and related networks
235
+ """
236
+
237
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
238
+ super().__init__()
239
+ out_features = out_features or in_features
240
+ hidden_features = hidden_features or in_features
241
+ self.fc1 = nn.Linear(in_features, hidden_features)
242
+ self.act = act_layer()
243
+ self.fc2 = nn.Linear(hidden_features, out_features)
244
+ self.drop = nn.Dropout(drop)
245
+
246
+ def forward(self, x):
247
+ x = self.fc1(x)
248
+ x = self.act(x)
249
+ x = self.drop(x)
250
+ x = self.fc2(x)
251
+ x = self.drop(x)
252
+ return x
253
+
254
+
255
+ class Block(nn.Module):
256
+ def __init__(self,
257
+ dim, in_chans,
258
+ num_heads,
259
+ mlp_ratio=4.,
260
+ qkv_bias=False,
261
+ qk_scale=None,
262
+ drop_ratio=0.,
263
+ attn_drop_ratio=0.,
264
+ drop_path_ratio=0.,
265
+ act_layer=nn.GELU,
266
+ norm_layer=nn.LayerNorm):
267
+ super(Block, self).__init__()
268
+ self.norm1 = norm_layer(dim)
269
+ self.img_chanel = in_chans + 1
270
+
271
+ self.conv = nn.Conv1d(self.img_chanel, self.img_chanel, 1)
272
+ self.attn = Attention(dim, in_chans=in_chans, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
273
+ attn_drop_ratio=attn_drop_ratio, proj_drop_ratio=drop_ratio)
274
+ # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
275
+ self.drop_path = DropPath(drop_path_ratio) if drop_path_ratio > 0. else nn.Identity()
276
+ self.norm2 = norm_layer(dim)
277
+ mlp_hidden_dim = int(dim * mlp_ratio)
278
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop_ratio)
279
+
280
+ def forward(self, x):
281
+ # x = x + self.drop_path(self.attn(self.norm1(x)))
282
+ # x = x + self.drop_path(self.mlp(self.norm2(x)))
283
+
284
+ x_img = x
285
+ # [:, :self.img_chanel, :]
286
+ # x_lm = x[:, self.img_chanel:, :]
287
+ x_img = x_img + self.drop_path(self.attn(self.norm1(x)))
288
+ x = x_img + self.drop_path(self.mlp(self.norm2(x_img)))
289
+ #
290
+ # x_lm = x_lm + self.drop_path(self.attn_lm(self.norm3(x)))
291
+ # x_lm = x_lm + self.drop_path(self.mlp2(self.norm4(x_lm)))
292
+ # x = torch.cat((x_img, x_lm), dim=1)
293
+ # x = self.conv(x)
294
+
295
+ return x
296
+
297
+
298
+ class ClassificationHead(nn.Module):
299
+ def __init__(self, input_dim: int, target_dim: int):
300
+ super().__init__()
301
+ self.linear = torch.nn.Linear(input_dim, target_dim)
302
+
303
+ def forward(self, x):
304
+ x = x.view(x.size(0), -1)
305
+ y_hat = self.linear(x)
306
+ return y_hat
307
+
308
+
309
+ def load_pretrained_weights(model, checkpoint):
310
+ import collections
311
+ if 'state_dict' in checkpoint:
312
+ state_dict = checkpoint['state_dict']
313
+ else:
314
+ state_dict = checkpoint
315
+ model_dict = model.state_dict()
316
+ new_state_dict = collections.OrderedDict()
317
+ matched_layers, discarded_layers = [], []
318
+ for k, v in state_dict.items():
319
+ # If the pretrained state_dict was saved as nn.DataParallel,
320
+ # keys would contain "module.", which should be ignored.
321
+ if k.startswith('module.'):
322
+ k = k[7:]
323
+ if k in model_dict and model_dict[k].size() == v.size():
324
+ new_state_dict[k] = v
325
+ matched_layers.append(k)
326
+ else:
327
+ discarded_layers.append(k)
328
+ # new_state_dict.requires_grad = False
329
+ model_dict.update(new_state_dict)
330
+
331
+ model.load_state_dict(model_dict)
332
+ print('load_weight', len(matched_layers))
333
+ return model
334
+
335
+ class eca_block(nn.Module):
336
+ def __init__(self, channel=128, b=1, gamma=2):
337
+ super(eca_block, self).__init__()
338
+ kernel_size = int(abs((math.log(channel, 2) + b) / gamma))
339
+ kernel_size = kernel_size if kernel_size % 2 else kernel_size + 1
340
+
341
+ self.avg_pool = nn.AdaptiveAvgPool2d(1)
342
+ self.conv = nn.Conv1d(1, 1, kernel_size=kernel_size, padding=(kernel_size - 1) // 2, bias=False)
343
+ self.sigmoid = nn.Sigmoid()
344
+
345
+ def forward(self, x):
346
+ y = self.avg_pool(x)
347
+ y = self.conv(y.squeeze(-1).transpose(-1, -2)).transpose(-1, -2).unsqueeze(-1)
348
+ y = self.sigmoid(y)
349
+ return x * y.expand_as(x)
350
+ #
351
+ #
352
+ # class IR20(nn.Module):
353
+ # def __init__(self, img_size_=112, num_classes=7, layers=[2, 2, 2, 2]):
354
+ # super().__init__()
355
+ # norm_layer = nn.BatchNorm2d
356
+ # self.img_size = img_size_
357
+ # self._norm_layer = norm_layer
358
+ # self.num_classes = num_classes
359
+ # self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
360
+ # self.bn1 = norm_layer(64)
361
+ # self.relu = nn.ReLU(inplace=True)
362
+ # self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
363
+ # # self.face_landback = MobileFaceNet([112, 112],136)
364
+ # # face_landback_checkpoint = torch.load('./models/pretrain/mobilefacenet_model_best.pth.tar', map_location=lambda storage, loc: storage)
365
+ # # self.face_landback.load_state_dict(face_landback_checkpoint['state_dict'])
366
+ # self.layer1 = self._make_layer(BasicBlock, 64, 64, layers[0])
367
+ # self.layer2 = self._make_layer(BasicBlock, 64, 128, layers[1], stride=2)
368
+ # self.layer3 = self._make_layer(AttentionBlock, 128, 256, layers[2], stride=2)
369
+ # self.layer4 = self._make_layer(AttentionBlock, 256, 256, layers[3], stride=1)
370
+ # self.ir_back = Backbone(50, 51, 52, 0.0, 'ir')
371
+ # self.ir_layer = nn.Linear(1024, 512)
372
+ # # ir_checkpoint = torch.load(r'F:\0815crossvit\vision_transformer\models\pretrain\Pretrained_on_MSCeleb.pth.tar',
373
+ # # map_location=lambda storage, loc: storage)
374
+ # # ir_checkpoint = ir_checkpoint['state_dict']
375
+ # # self.face_landback.load_state_dict(face_landback_checkpoint['state_dict'])
376
+ # # checkpoint = torch.load('./checkpoint/Pretrained_on_MSCeleb.pth.tar')
377
+ # # pre_trained_dict = checkpoint['state_dict']
378
+ # # IR20.load_state_dict(ir_checkpoint, strict=False)
379
+ # # self.IR = load_pretrained_weights(IR, ir_checkpoint)
380
+ #
381
+ # def _make_layer(self, block, inplanes, planes, blocks, stride=1):
382
+ # norm_layer = self._norm_layer
383
+ # downsample = None
384
+ # if stride != 1 or inplanes != planes:
385
+ # downsample = nn.Sequential(conv1x1(inplanes, planes, stride), norm_layer(planes))
386
+ # layers = []
387
+ # layers.append(block(inplanes, planes, stride, downsample))
388
+ # inplanes = planes
389
+ # for _ in range(1, blocks):
390
+ # layers.append(block(inplanes, planes))
391
+ # return nn.Sequential(*layers)
392
+ #
393
+ # def forward(self, x):
394
+ # x_ir = self.ir_back(x)
395
+ # # x_ir = self.ir_layer(x_ir)
396
+ # # print(x_ir.shape)
397
+ # # x = F.interpolate(x, size=112)
398
+ # # x = self.conv1(x)
399
+ # # x = self.bn1(x)
400
+ # # x = self.relu(x)
401
+ # # x = self.maxpool(x)
402
+ # #
403
+ # # x = self.layer1(x)
404
+ # # x = self.layer2(x)
405
+ # # x = self.layer3(x)
406
+ # # x = self.layer4(x)
407
+ # # print(x.shape)
408
+ # # print(x)
409
+ # out = x_ir
410
+ #
411
+ # return out
412
+ #
413
+ #
414
+ # class IR(nn.Module):
415
+ # def __init__(self, img_size_=112, num_classes=7):
416
+ # super().__init__()
417
+ # depth = 8
418
+ # # if type == "small":
419
+ # # depth = 4
420
+ # # if type == "base":
421
+ # # depth = 6
422
+ # # if type == "large":
423
+ # # depth = 8
424
+ #
425
+ # self.img_size = img_size_
426
+ # self.num_classes = num_classes
427
+ # self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
428
+ # # self.bn1 = norm_layer(64)
429
+ # self.relu = nn.ReLU(inplace=True)
430
+ # self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
431
+ # # self.face_landback = MobileFaceNet([112, 112],136)
432
+ # # face_landback_checkpoint = torch.load('./models/pretrain/mobilefacenet_model_best.pth.tar', map_location=lambda storage, loc: storage)
433
+ # # self.face_landback.load_state_dict(face_landback_checkpoint['state_dict'])
434
+ #
435
+ # # for param in self.face_landback.parameters():
436
+ # # param.requires_grad = False
437
+ #
438
+ # ###########################################################################333
439
+ #
440
+ # self.ir_back = IR20()
441
+ #
442
+ # # ir_checkpoint = torch.load(r'F:\0815crossvit\vision_transformer\models\pretrain\ir50.pth',
443
+ # # map_location=lambda storage, loc: storage)
444
+ # # # ir_checkpoint = ir_checkpoint["model"]
445
+ # # self.ir_back = load_pretrained_weights(self.ir_back, ir_checkpoint)
446
+ # # load_state_dict(checkpoint_model, strict=False)
447
+ # # self.ir_layer = nn.Linear(1024,512)
448
+ #
449
+ # #############################################################3
450
+ # #
451
+ # # self.pyramid_fuse = HyVisionTransformer(in_chans=49, q_chanel = 49, embed_dim=512,
452
+ # # depth=depth, num_heads=8, mlp_ratio=2.,
453
+ # # drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1)
454
+ #
455
+ # # self.se_block = SE_block(input_dim=512)
456
+ # self.head = ClassificationHead(input_dim=768, target_dim=self.num_classes)
457
+ #
458
+ # def forward(self, x):
459
+ # B_ = x.shape[0]
460
+ # # x_face = F.interpolate(x, size=112)
461
+ # # _, x_face = self.face_landback(x_face)
462
+ # # x_face = x_face.view(B_, -1, 49).transpose(1,2)
463
+ # ############### landmark x_face ([B, 49, 512])
464
+ # x_ir = self.ir_back(x)
465
+ # # print(x_ir.shape)
466
+ # # x_ir = self.ir_layer(x_ir)
467
+ # # print(x_ir.shape)
468
+ # ############### image x_ir ([B, 49, 512])
469
+ #
470
+ # # y_hat = self.pyramid_fuse(x_ir, x_face)
471
+ # # y_hat = self.se_block(y_hat)
472
+ # # y_feat = y_hat
473
+ #
474
+ # # out = self.head(x_ir)
475
+ #
476
+ # out = x_ir
477
+ # return out
478
+
479
+
480
+ class eca_block(nn.Module):
481
+ def __init__(self, channel=196, b=1, gamma=2):
482
+ super(eca_block, self).__init__()
483
+ kernel_size = int(abs((math.log(channel, 2) + b) / gamma))
484
+ kernel_size = kernel_size if kernel_size % 2 else kernel_size + 1
485
+
486
+ self.avg_pool = nn.AdaptiveAvgPool2d(1)
487
+ self.conv = nn.Conv1d(1, 1, kernel_size=kernel_size, padding=(kernel_size - 1) // 2, bias=False)
488
+ self.sigmoid = nn.Sigmoid()
489
+
490
+ def forward(self, x):
491
+ y = self.avg_pool(x)
492
+ y = self.conv(y.squeeze(-1).transpose(-1, -2)).transpose(-1, -2).unsqueeze(-1)
493
+ y = self.sigmoid(y)
494
+ return x * y.expand_as(x)
495
+
496
+ class SE_block(nn.Module):
497
+ def __init__(self, input_dim: int):
498
+ super().__init__()
499
+ self.linear1 = torch.nn.Linear(input_dim, input_dim)
500
+ self.relu = nn.ReLU()
501
+ self.linear2 = torch.nn.Linear(input_dim, input_dim)
502
+ self.sigmod = nn.Sigmoid()
503
+
504
+ def forward(self, x):
505
+ x1 = self.linear1(x)
506
+ x1 = self.relu(x1)
507
+ x1 = self.linear2(x1)
508
+ x1 = self.sigmod(x1)
509
+ x = x * x1
510
+ return x
511
+
512
+
513
+ class VisionTransformer(nn.Module):
514
+ def __init__(self, img_size=14, patch_size=14, in_c=147, num_classes=7,
515
+ embed_dim=768, depth=6, num_heads=8, mlp_ratio=4.0, qkv_bias=True,
516
+ qk_scale=None, representation_size=None, distilled=False, drop_ratio=0.,
517
+ attn_drop_ratio=0., drop_path_ratio=0., embed_layer=PatchEmbed, norm_layer=None,
518
+ act_layer=None):
519
+ """
520
+ Args:
521
+ img_size (int, tuple): input image size
522
+ patch_size (int, tuple): patch size
523
+ in_c (int): number of input channels
524
+ num_classes (int): number of classes for classification head
525
+ embed_dim (int): embedding dimension
526
+ depth (int): depth of transformer
527
+ num_heads (int): number of attention heads
528
+ mlp_ratio (int): ratio of mlp hidden dim to embedding dim
529
+ qkv_bias (bool): enable bias for qkv if True
530
+ qk_scale (float): override default qk scale of head_dim ** -0.5 if set
531
+ representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set
532
+ distilled (bool): model includes a distillation token and head as in DeiT models
533
+ drop_ratio (float): dropout rate
534
+ attn_drop_ratio (float): attention dropout rate
535
+ drop_path_ratio (float): stochastic depth rate
536
+ embed_layer (nn.Module): patch embedding layer
537
+ norm_layer: (nn.Module): normalization layer
538
+ """
539
+ super(VisionTransformer, self).__init__()
540
+ self.num_classes = num_classes
541
+ self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
542
+ self.num_tokens = 2 if distilled else 1
543
+ norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
544
+ act_layer = act_layer or nn.GELU
545
+
546
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
547
+ self.pos_embed = nn.Parameter(torch.zeros(1, in_c + 1, embed_dim))
548
+ self.pos_drop = nn.Dropout(p=drop_ratio)
549
+
550
+ self.se_block = SE_block(input_dim=embed_dim)
551
+
552
+
553
+ self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_c=256, embed_dim=768)
554
+ num_patches = self.patch_embed.num_patches
555
+ self.head = ClassificationHead(input_dim=embed_dim, target_dim=self.num_classes)
556
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
557
+ self.dist_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if distilled else None
558
+ # self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
559
+ self.pos_drop = nn.Dropout(p=drop_ratio)
560
+ # self.IR = IR()
561
+ self.eca_block = eca_block()
562
+
563
+
564
+ # self.ir_back = Backbone(50, 0.0, 'ir')
565
+ # ir_checkpoint = torch.load('./models/pretrain/ir50.pth', map_location=lambda storage, loc: storage)
566
+ # # ir_checkpoint = ir_checkpoint["model"]
567
+ # self.ir_back = load_pretrained_weights(self.ir_back, ir_checkpoint)
568
+
569
+ self.CON1 = nn.Conv2d(256, 768, kernel_size=1, stride=1, bias=False)
570
+ self.IRLinear1 = nn.Linear(1024, 768)
571
+ self.IRLinear2 = nn.Linear(768, 512)
572
+ self.eca_block = eca_block()
573
+ dpr = [x.item() for x in torch.linspace(0, drop_path_ratio, depth)] # stochastic depth decay rule
574
+ self.blocks = nn.Sequential(*[
575
+ Block(dim=embed_dim, in_chans=in_c, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias,
576
+ qk_scale=qk_scale,
577
+ drop_ratio=drop_ratio, attn_drop_ratio=attn_drop_ratio, drop_path_ratio=dpr[i],
578
+ norm_layer=norm_layer, act_layer=act_layer)
579
+ for i in range(depth)
580
+ ])
581
+ self.norm = norm_layer(embed_dim)
582
+
583
+ # Representation layer
584
+ if representation_size and not distilled:
585
+ self.has_logits = True
586
+ self.num_features = representation_size
587
+ self.pre_logits = nn.Sequential(OrderedDict([
588
+ ("fc", nn.Linear(embed_dim, representation_size)),
589
+ ("act", nn.Tanh())
590
+ ]))
591
+ else:
592
+ self.has_logits = False
593
+ self.pre_logits = nn.Identity()
594
+
595
+ # Classifier head(s)
596
+ # self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
597
+ self.head_dist = None
598
+ if distilled:
599
+ self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity()
600
+
601
+ # Weight init
602
+ nn.init.trunc_normal_(self.pos_embed, std=0.02)
603
+ if self.dist_token is not None:
604
+ nn.init.trunc_normal_(self.dist_token, std=0.02)
605
+
606
+ nn.init.trunc_normal_(self.cls_token, std=0.02)
607
+ self.apply(_init_vit_weights)
608
+
609
+ def forward_features(self, x):
610
+ # [B, C, H, W] -> [B, num_patches, embed_dim]
611
+ # x = self.patch_embed(x) # [B, 196, 768]
612
+ # [1, 1, 768] -> [B, 1, 768]
613
+ # print(x.shape)
614
+
615
+ cls_token = self.cls_token.expand(x.shape[0], -1, -1)
616
+ if self.dist_token is None:
617
+ x = torch.cat((cls_token, x), dim=1) # [B, 197, 768]
618
+ else:
619
+ x = torch.cat((cls_token, self.dist_token.expand(x.shape[0], -1, -1), x), dim=1)
620
+ # print(x.shape)
621
+ x = self.pos_drop(x + self.pos_embed)
622
+ x = self.blocks(x)
623
+ x = self.norm(x)
624
+ if self.dist_token is None:
625
+ return self.pre_logits(x[:, 0])
626
+ else:
627
+ return x[:, 0], x[:, 1]
628
+
629
+ def forward(self, x):
630
+
631
+ # B = x.shape[0]
632
+ # print(x)
633
+ # x = self.eca_block(x)
634
+ # x = self.IR(x)
635
+ # x = eca_block(x)
636
+ # x = self.ir_back(x)
637
+ # print(x.shape)
638
+ # x = self.CON1(x)
639
+ # x = x.view(-1, 196, 768)
640
+ #
641
+ # # print(x.shape)
642
+ # # x = self.IRLinear1(x)
643
+ # # print(x)
644
+ # x_cls = torch.mean(x, 1).view(B, 1, -1)
645
+ # x = torch.cat((x_cls, x), dim=1)
646
+ # # print(x.shape)
647
+ # x = self.pos_drop(x + self.pos_embed)
648
+ # # print(x.shape)
649
+ # x = self.blocks(x)
650
+ # # print(x)
651
+ # x = self.norm(x)
652
+ # # print(x)
653
+ # # x1 = self.IRLinear2(x)
654
+ # x1 = x[:, 0, :]
655
+
656
+ # print(x1)
657
+ # print(x1.shape)
658
+
659
+ x = self.forward_features(x)
660
+ # # print(x.shape)
661
+ # if self.head_dist is not None:
662
+ # x, x_dist = self.head(x[0]), self.head_dist(x[1])
663
+ # if self.training and not torch.jit.is_scripting():
664
+ # # during inference, return the average of both classifier predictions
665
+ # return x, x_dist
666
+ # else:
667
+ # return (x + x_dist) / 2
668
+ # else:
669
+ # print(x.shape)
670
+ x = self.se_block(x)
671
+
672
+ x1 = self.head(x)
673
+
674
+ return x1
675
+
676
+
677
+ def _init_vit_weights(m):
678
+ """
679
+ ViT weight initialization
680
+ :param m: module
681
+ """
682
+ if isinstance(m, nn.Linear):
683
+ nn.init.trunc_normal_(m.weight, std=.01)
684
+ if m.bias is not None:
685
+ nn.init.zeros_(m.bias)
686
+ elif isinstance(m, nn.Conv2d):
687
+ nn.init.kaiming_normal_(m.weight, mode="fan_out")
688
+ if m.bias is not None:
689
+ nn.init.zeros_(m.bias)
690
+ elif isinstance(m, nn.LayerNorm):
691
+ nn.init.zeros_(m.bias)
692
+ nn.init.ones_(m.weight)
693
+
694
+
695
+ def vit_base_patch16_224(num_classes: int = 7):
696
+ """
697
+ ViT-Base model (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
698
+ ImageNet-1k weights @ 224x224, source https://github.com/google-research/vision_transformer.
699
+ weights ported from official Google JAX impl:
700
+ link: https://pan.baidu.com/s/1zqb08naP0RPqqfSXfkB2EA  password: eu9f
701
+ """
702
+ model = VisionTransformer(img_size=224,
703
+ patch_size=16,
704
+ embed_dim=768,
705
+ depth=12,
706
+ num_heads=12,
707
+ representation_size=None,
708
+ num_classes=num_classes)
709
+
710
+ return model
711
+
712
+
713
+ def vit_base_patch16_224_in21k(num_classes: int = 21843, has_logits: bool = True):
714
+ """
715
+ ViT-Base model (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
716
+ ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
717
+ weights ported from official Google JAX impl:
718
+ https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_patch16_224_in21k-e5005f0a.pth
719
+ """
720
+ model = VisionTransformer(img_size=224,
721
+ patch_size=16,
722
+ embed_dim=768,
723
+ depth=12,
724
+ num_heads=12,
725
+ representation_size=768 if has_logits else None,
726
+ num_classes=num_classes)
727
+ return model
728
+
729
+
730
+ def vit_base_patch32_224(num_classes: int = 1000):
731
+ """
732
+ ViT-Base model (ViT-B/32) from original paper (https://arxiv.org/abs/2010.11929).
733
+ ImageNet-1k weights @ 224x224, source https://github.com/google-research/vision_transformer.
734
+ weights ported from official Google JAX impl:
735
+ link: https://pan.baidu.com/s/1hCv0U8pQomwAtHBYc4hmZg  password: s5hl
736
+ """
737
+ model = VisionTransformer(img_size=224,
738
+ patch_size=32,
739
+ embed_dim=768,
740
+ depth=12,
741
+ num_heads=12,
742
+ representation_size=None,
743
+ num_classes=num_classes)
744
+ return model
745
+
746
+
747
+ def vit_base_patch32_224_in21k(num_classes: int = 21843, has_logits: bool = True):
748
+ """
749
+ ViT-Base model (ViT-B/32) from original paper (https://arxiv.org/abs/2010.11929).
750
+ ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
751
+ weights ported from official Google JAX impl:
752
+ https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_patch32_224_in21k-8db57226.pth
753
+ """
754
+ model = VisionTransformer(img_size=224,
755
+ patch_size=32,
756
+ embed_dim=768,
757
+ depth=12,
758
+ num_heads=12,
759
+ representation_size=768 if has_logits else None,
760
+ num_classes=num_classes)
761
+ return model
762
+
763
+
764
+ def vit_large_patch16_224(num_classes: int = 1000):
765
+ """
766
+ ViT-Large model (ViT-L/16) from original paper (https://arxiv.org/abs/2010.11929).
767
+ ImageNet-1k weights @ 224x224, source https://github.com/google-research/vision_transformer.
768
+ weights ported from official Google JAX impl:
769
+ link: https://pan.baidu.com/s/1cxBgZJJ6qUWPSBNcE4TdRQ  password: qqt8
770
+ """
771
+ model = VisionTransformer(img_size=224,
772
+ patch_size=16,
773
+ embed_dim=1024,
774
+ depth=24,
775
+ num_heads=16,
776
+ representation_size=None,
777
+ num_classes=num_classes)
778
+ return model
779
+
780
+
781
+ def vit_large_patch16_224_in21k(num_classes: int = 21843, has_logits: bool = True):
782
+ """
783
+ ViT-Large model (ViT-L/16) from original paper (https://arxiv.org/abs/2010.11929).
784
+ ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
785
+ weights ported from official Google JAX impl:
786
+ https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_patch16_224_in21k-606da67d.pth
787
+ """
788
+ model = VisionTransformer(img_size=224,
789
+ patch_size=16,
790
+ embed_dim=1024,
791
+ depth=24,
792
+ num_heads=16,
793
+ representation_size=1024 if has_logits else None,
794
+ num_classes=num_classes)
795
+ return model
796
+
797
+
798
+ def vit_large_patch32_224_in21k(num_classes: int = 21843, has_logits: bool = True):
799
+ """
800
+ ViT-Large model (ViT-L/32) from original paper (https://arxiv.org/abs/2010.11929).
801
+ ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
802
+ weights ported from official Google JAX impl:
803
+ https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_patch32_224_in21k-9046d2e7.pth
804
+ """
805
+ model = VisionTransformer(img_size=224,
806
+ patch_size=32,
807
+ embed_dim=1024,
808
+ depth=24,
809
+ num_heads=16,
810
+ representation_size=1024 if has_logits else None,
811
+ num_classes=num_classes)
812
+ return model
813
+
814
+
815
+ def vit_huge_patch14_224_in21k(num_classes: int = 21843, has_logits: bool = True):
816
+ """
817
+ ViT-Huge model (ViT-H/14) from original paper (https://arxiv.org/abs/2010.11929).
818
+ ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
819
+ NOTE: converted weights not currently available, too large for github release hosting.
820
+ """
821
+ model = VisionTransformer(img_size=224,
822
+ patch_size=14,
823
+ embed_dim=1280,
824
+ depth=32,
825
+ num_heads=16,
826
+ representation_size=1280 if has_logits else None,
827
+ num_classes=num_classes)
828
+ return model
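Note that forward() in this variant consumes a token sequence (the patch-embedding path is commented out), so the vit_base_*/vit_large_* constructors above still expect token input. A minimal sketch of a forward pass (an assumption about intended use; in the full pyramid model the tokens come from the CNN feature maps):

    import torch

    vit = VisionTransformer(in_c=147, num_classes=7, embed_dim=768, depth=6, num_heads=8)
    tokens = torch.rand(2, 147, 768)   # [batch, in_c tokens, embed_dim]; pos_embed is sized for in_c + 1
    logits = vit(tokens)               # cls token -> SE_block -> ClassificationHead
    print(logits.shape)                # torch.Size([2, 7])
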
models/vit_model_8.py ADDED
@@ -0,0 +1,828 @@
1
+ """
2
+ original code from rwightman:
3
+ https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
4
+ """
5
+ from functools import partial
6
+ from collections import OrderedDict
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch
11
+ import torch.nn as nn
12
+ import torch.nn.functional as F
13
+ import torch.hub
14
+ from functools import partial
15
+ # import mat
16
+ # from vision_transformer.ir50 import Backbone
17
+
18
+
19
+ import torch
20
+ import torch.nn as nn
21
+ import torch.nn.functional as F
22
+ import torch.hub
23
+ from functools import partial
24
+ import math
25
+
26
+ from timm.models.layers import DropPath, to_2tuple, trunc_normal_
27
+ from timm.models.registry import register_model
28
+ from timm.models.vision_transformer import _cfg, Mlp, Block
29
+ from .ir50 import Backbone
30
+
31
+
32
+ def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
33
+ """3x3 convolution with padding"""
34
+ return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
35
+ padding=dilation, groups=groups, bias=False, dilation=dilation)
36
+
37
+
38
+ def conv1x1(in_planes, out_planes, stride=1):
39
+ """1x1 convolution"""
40
+ return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
41
+
42
+
43
+ def drop_path(x, drop_prob: float = 0., training: bool = False):
44
+ """
45
+ Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
46
+ This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
47
+ the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
48
+ See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
49
+ changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
50
+ 'survival rate' as the argument.
51
+ """
52
+ if drop_prob == 0. or not training:
53
+ return x
54
+ keep_prob = 1 - drop_prob
55
+ shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
56
+ random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
57
+ random_tensor.floor_() # binarize
58
+ output = x.div(keep_prob) * random_tensor
59
+ return output
60
+
61
+
62
+ class BasicBlock(nn.Module):
63
+ __constants__ = ['downsample']
64
+
65
+ def __init__(self, inplanes, planes, stride=1, downsample=None):
66
+ super(BasicBlock, self).__init__()
67
+ norm_layer = nn.BatchNorm2d
68
+ self.conv1 = conv3x3(inplanes, planes, stride)
69
+ self.bn1 = norm_layer(planes)
70
+ self.relu = nn.ReLU(inplace=True)
71
+ self.conv2 = conv3x3(planes, planes)
72
+ self.bn2 = norm_layer(planes)
73
+ self.downsample = downsample
74
+ self.stride = stride
75
+
76
+ def forward(self, x):
77
+ identity = x
78
+
79
+ out = self.conv1(x)
80
+ out = self.bn1(out)
81
+ out = self.relu(out)
82
+ out = self.conv2(out)
83
+ out = self.bn2(out)
84
+
85
+ if self.downsample is not None:
86
+ identity = self.downsample(x)
87
+
88
+ out += identity
89
+ out = self.relu(out)
90
+
91
+ return out
92
+
93
+
94
+ class DropPath(nn.Module):
95
+ """
96
+ Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
97
+ """
98
+
99
+ def __init__(self, drop_prob=None):
100
+ super(DropPath, self).__init__()
101
+ self.drop_prob = drop_prob
102
+
103
+ def forward(self, x):
104
+ return drop_path(x, self.drop_prob, self.training)
105
+
106
+
107
+ class PatchEmbed(nn.Module):
108
+ """
109
+ 2D Image to Patch Embedding
110
+ """
111
+
112
+ def __init__(self, img_size=14, patch_size=16, in_c=256, embed_dim=768, norm_layer=None):
113
+ super().__init__()
114
+ img_size = (img_size, img_size)
115
+ patch_size = (patch_size, patch_size)
116
+ self.img_size = img_size
117
+ self.patch_size = patch_size
118
+ self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
119
+ self.num_patches = self.grid_size[0] * self.grid_size[1]
120
+
121
+ self.proj = nn.Conv2d(256, 768, kernel_size=1)
122
+ self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
123
+
124
+ def forward(self, x):
125
+ B, C, H, W = x.shape
126
+ # assert H == self.img_size[0] and W == self.img_size[1], \
127
+ # f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
128
+ # print(x.shape)
129
+
130
+ # flatten: [B, C, H, W] -> [B, C, HW]
131
+ # transpose: [B, C, HW] -> [B, HW, C]
132
+ x = self.proj(x).flatten(2).transpose(1, 2)
133
+ x = self.norm(x)
134
+ return x
135
+
136
+
137
+ class Attention(nn.Module):
138
+ def __init__(self,
139
+ dim, in_chans, # dim of the input tokens
140
+ num_heads=8,
141
+ qkv_bias=False,
142
+ qk_scale=None,
143
+ attn_drop_ratio=0.,
144
+ proj_drop_ratio=0.):
145
+ super(Attention, self).__init__()
146
+ self.num_heads = num_heads
147
+ self.img_chanel = in_chans + 1
148
+ head_dim = dim // num_heads
149
+ self.scale = head_dim ** -0.5
150
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
151
+ self.attn_drop = nn.Dropout(attn_drop_ratio)
152
+ self.proj = nn.Linear(dim, dim)
153
+ self.proj_drop = nn.Dropout(proj_drop_ratio)
154
+
155
+ def forward(self, x):
156
+ x_img = x[:, :self.img_chanel, :]
157
+ # [batch_size, num_patches + 1, total_embed_dim]
158
+ B, N, C = x_img.shape
159
+ # print(C)
160
+ qkv = self.qkv(x_img).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
161
+ q, k, v = qkv[0], qkv[1], qkv[2]
162
+ # k, v = kv.unbind(0) # make torchscript happy (cannot use tensor as tuple)
163
+ # q = x_img.reshape(B, -1, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
164
+ attn = (q @ k.transpose(-2, -1)) * self.scale
165
+ attn = attn.softmax(dim=-1)
166
+ attn = self.attn_drop(attn)
167
+
168
+ x_img = (attn @ v).transpose(1, 2).reshape(B, N, C)
169
+ x_img = self.proj(x_img)
170
+ x_img = self.proj_drop(x_img)
171
+ #
172
+ #
173
+ # # qkv(): -> [batch_size, num_patches + 1, 3 * total_embed_dim]
174
+ # # reshape: -> [batch_size, num_patches + 1, 3, num_heads, embed_dim_per_head]
175
+ # # permute: -> [3, batch_size, num_heads, num_patches + 1, embed_dim_per_head]
176
+ # qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
177
+ # # [batch_size, num_heads, num_patches + 1, embed_dim_per_head]
178
+ # q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
179
+ #
180
+ # # transpose: -> [batch_size, num_heads, embed_dim_per_head, num_patches + 1]
181
+ # # @: multiply -> [batch_size, num_heads, num_patches + 1, num_patches + 1]
182
+ # attn = (q @ k.transpose(-2, -1)) * self.scale
183
+ # attn = attn.softmax(dim=-1)
184
+ # attn = self.attn_drop(attn)
185
+ #
186
+ # # @: multiply -> [batch_size, num_heads, num_patches + 1, embed_dim_per_head]
187
+ # # transpose: -> [batch_size, num_patches + 1, num_heads, embed_dim_per_head]
188
+ # # reshape: -> [batch_size, num_patches + 1, total_embed_dim]
189
+ # x = (attn @ v).transpose(1, 2).reshape(B, N, C)
190
+ # x = self.proj(x)
191
+ # x = self.proj_drop(x)
192
+ return x_img
193
+
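+ # A minimal sketch of the Attention module above: it attends over the first in_chans + 1 tokens
+ # (class token plus image tokens) and returns a tensor of the same shape for that slice.
+ def _attention_usage_sketch():
+     attn = Attention(dim=768, in_chans=196, num_heads=8)
+     x = torch.randn(2, 197, 768)            # [batch, cls + image tokens, dim]
+     out = attn(x)                           # -> [2, 197, 768]
+     return out.shape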
194
+
195
+ class AttentionBlock(nn.Module):
196
+ __constants__ = ['downsample']
197
+
198
+ def __init__(self, inplanes, planes, stride=1, downsample=None):
199
+ super(AttentionBlock, self).__init__()
200
+ norm_layer = nn.BatchNorm2d
201
+ self.conv1 = conv3x3(inplanes, planes, stride)
202
+ self.bn1 = norm_layer(planes)
203
+ self.relu = nn.ReLU(inplace=True)
204
+ self.conv2 = conv3x3(planes, planes)
205
+ self.bn2 = norm_layer(planes)
206
+ self.downsample = downsample
207
+ self.stride = stride
208
+ # self.cbam = CBAM(planes, 16)
209
+ self.inplanes = inplanes
210
+ self.eca_block = eca_block()
211
+
212
+ def forward(self, x):
213
+ identity = x
214
+
215
+ out = self.conv1(x)
216
+ out = self.bn1(out)
217
+ out = self.relu(out)
218
+
219
+ out = self.conv2(out)
220
+ out = self.bn2(out)
221
+ inplanes = self.inplanes
222
+ out = self.eca_block(out)
223
+ if self.downsample is not None:
224
+ identity = self.downsample(x)
225
+
226
+ out += identity
227
+ out = self.relu(out)
228
+
229
+ return out
230
+
231
+
232
+ class Mlp(nn.Module):
233
+ """
234
+ MLP as used in Vision Transformer, MLP-Mixer and related networks
235
+ """
236
+
237
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
238
+ super().__init__()
239
+ out_features = out_features or in_features
240
+ hidden_features = hidden_features or in_features
241
+ self.fc1 = nn.Linear(in_features, hidden_features)
242
+ self.act = act_layer()
243
+ self.fc2 = nn.Linear(hidden_features, out_features)
244
+ self.drop = nn.Dropout(drop)
245
+
246
+ def forward(self, x):
247
+ x = self.fc1(x)
248
+ x = self.act(x)
249
+ x = self.drop(x)
250
+ x = self.fc2(x)
251
+ x = self.drop(x)
252
+ return x
253
+
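+ # A minimal sketch of the Mlp block: two linear layers with GELU and dropout, applied per token.
+ def _mlp_usage_sketch():
+     mlp = Mlp(in_features=768, hidden_features=768 * 4, drop=0.1)
+     x = torch.randn(2, 197, 768)
+     return mlp(x).shape                     # -> [2, 197, 768]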
254
+
255
+ class Block(nn.Module):
256
+ def __init__(self,
257
+ dim, in_chans,
258
+ num_heads,
259
+ mlp_ratio=4.,
260
+ qkv_bias=False,
261
+ qk_scale=None,
262
+ drop_ratio=0.,
263
+ attn_drop_ratio=0.,
264
+ drop_path_ratio=0.,
265
+ act_layer=nn.GELU,
266
+ norm_layer=nn.LayerNorm):
267
+ super(Block, self).__init__()
268
+ self.norm1 = norm_layer(dim)
269
+ self.img_chanel = in_chans + 1
270
+
271
+ self.conv = nn.Conv1d(self.img_chanel, self.img_chanel, 1)
272
+ self.attn = Attention(dim, in_chans=in_chans, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
273
+ attn_drop_ratio=attn_drop_ratio, proj_drop_ratio=drop_ratio)
274
+ # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
275
+ self.drop_path = DropPath(drop_path_ratio) if drop_path_ratio > 0. else nn.Identity()
276
+ self.norm2 = norm_layer(dim)
277
+ mlp_hidden_dim = int(dim * mlp_ratio)
278
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop_ratio)
279
+
280
+ def forward(self, x):
281
+ # x = x + self.drop_path(self.attn(self.norm1(x)))
282
+ # x = x + self.drop_path(self.mlp(self.norm2(x)))
283
+
284
+ x_img = x
285
+ # [:, :self.img_chanel, :]
286
+ # x_lm = x[:, self.img_chanel:, :]
287
+ x_img = x_img + self.drop_path(self.attn(self.norm1(x)))
288
+ x = x_img + self.drop_path(self.mlp(self.norm2(x_img)))
289
+ #
290
+ # x_lm = x_lm + self.drop_path(self.attn_lm(self.norm3(x)))
291
+ # x_lm = x_lm + self.drop_path(self.mlp2(self.norm4(x_lm)))
292
+ # x = torch.cat((x_img, x_lm), dim=1)
293
+ # x = self.conv(x)
294
+
295
+ return x
296
+
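+ # A minimal sketch of a full transformer Block: pre-norm attention followed by a pre-norm MLP,
+ # each wrapped in a residual connection with optional DropPath.
+ def _block_usage_sketch():
+     blk = Block(dim=768, in_chans=196, num_heads=8, mlp_ratio=4.0, drop_path_ratio=0.1)
+     x = torch.randn(2, 197, 768)
+     return blk(x).shape                     # -> [2, 197, 768]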
297
+
298
+ class ClassificationHead(nn.Module):
299
+ def __init__(self, input_dim: int, target_dim: int):
300
+ super().__init__()
301
+ self.linear = torch.nn.Linear(input_dim, target_dim)
302
+
303
+ def forward(self, x):
304
+ x = x.view(x.size(0), -1)
305
+ y_hat = self.linear(x)
306
+ return y_hat
307
+
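+ # A minimal sketch of ClassificationHead: it flattens the input and applies one linear layer.
+ def _classification_head_sketch():
+     head = ClassificationHead(input_dim=768, target_dim=7)
+     return head(torch.randn(2, 768)).shape  # -> [2, 7]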
308
+
309
+ def load_pretrained_weights(model, checkpoint):
310
+ import collections
311
+ if 'state_dict' in checkpoint:
312
+ state_dict = checkpoint['state_dict']
313
+ else:
314
+ state_dict = checkpoint
315
+ model_dict = model.state_dict()
316
+ new_state_dict = collections.OrderedDict()
317
+ matched_layers, discarded_layers = [], []
318
+ for k, v in state_dict.items():
319
+ # If the pretrained state_dict was saved as nn.DataParallel,
320
+ # keys would contain "module.", which should be ignored.
321
+ if k.startswith('module.'):
322
+ k = k[7:]
323
+ if k in model_dict and model_dict[k].size() == v.size():
324
+ new_state_dict[k] = v
325
+ matched_layers.append(k)
326
+ else:
327
+ discarded_layers.append(k)
328
+ # new_state_dict.requires_grad = False
329
+ model_dict.update(new_state_dict)
330
+
331
+ model.load_state_dict(model_dict)
332
+ print('load_weight', len(matched_layers))
333
+ return model
334
+
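+ # A minimal sketch of load_pretrained_weights: it strips any "module." prefix, keeps only
+ # size-matched tensors, and loads them non-strictly. The checkpoint path below is hypothetical.
+ def _load_pretrained_weights_sketch(model):
+     checkpoint = torch.load("pretrained_backbone.pth", map_location="cpu")  # hypothetical file
+     return load_pretrained_weights(model, checkpoint)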
335
+ class eca_block(nn.Module):
336
+ def __init__(self, channel=128, b=1, gamma=2):
337
+ super(eca_block, self).__init__()
338
+ kernel_size = int(abs((math.log(channel, 2) + b) / gamma))
339
+ kernel_size = kernel_size if kernel_size % 2 else kernel_size + 1
340
+
341
+ self.avg_pool = nn.AdaptiveAvgPool2d(1)
342
+ self.conv = nn.Conv1d(1, 1, kernel_size=kernel_size, padding=(kernel_size - 1) // 2, bias=False)
343
+ self.sigmoid = nn.Sigmoid()
344
+
345
+ def forward(self, x):
346
+ y = self.avg_pool(x)
347
+ y = self.conv(y.squeeze(-1).transpose(-1, -2)).transpose(-1, -2).unsqueeze(-1)
348
+ y = self.sigmoid(y)
349
+ return x * y.expand_as(x)
350
+ #
351
+ #
352
+ # class IR20(nn.Module):
353
+ # def __init__(self, img_size_=112, num_classes=7, layers=[2, 2, 2, 2]):
354
+ # super().__init__()
355
+ # norm_layer = nn.BatchNorm2d
356
+ # self.img_size = img_size_
357
+ # self._norm_layer = norm_layer
358
+ # self.num_classes = num_classes
359
+ # self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
360
+ # self.bn1 = norm_layer(64)
361
+ # self.relu = nn.ReLU(inplace=True)
362
+ # self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
363
+ # # self.face_landback = MobileFaceNet([112, 112],136)
364
+ # # face_landback_checkpoint = torch.load('./models/pretrain/mobilefacenet_model_best.pth.tar', map_location=lambda storage, loc: storage)
365
+ # # self.face_landback.load_state_dict(face_landback_checkpoint['state_dict'])
366
+ # self.layer1 = self._make_layer(BasicBlock, 64, 64, layers[0])
367
+ # self.layer2 = self._make_layer(BasicBlock, 64, 128, layers[1], stride=2)
368
+ # self.layer3 = self._make_layer(AttentionBlock, 128, 256, layers[2], stride=2)
369
+ # self.layer4 = self._make_layer(AttentionBlock, 256, 256, layers[3], stride=1)
370
+ # self.ir_back = Backbone(50, 51, 52, 0.0, 'ir')
371
+ # self.ir_layer = nn.Linear(1024, 512)
372
+ # # ir_checkpoint = torch.load(r'F:\0815crossvit\vision_transformer\models\pretrain\Pretrained_on_MSCeleb.pth.tar',
373
+ # # map_location=lambda storage, loc: storage)
374
+ # # ir_checkpoint = ir_checkpoint['state_dict']
375
+ # # self.face_landback.load_state_dict(face_landback_checkpoint['state_dict'])
376
+ # # checkpoint = torch.load('./checkpoint/Pretrained_on_MSCeleb.pth.tar')
377
+ # # pre_trained_dict = checkpoint['state_dict']
378
+ # # IR20.load_state_dict(ir_checkpoint, strict=False)
379
+ # # self.IR = load_pretrained_weights(IR, ir_checkpoint)
380
+ #
381
+ # def _make_layer(self, block, inplanes, planes, blocks, stride=1):
382
+ # norm_layer = self._norm_layer
383
+ # downsample = None
384
+ # if stride != 1 or inplanes != planes:
385
+ # downsample = nn.Sequential(conv1x1(inplanes, planes, stride), norm_layer(planes))
386
+ # layers = []
387
+ # layers.append(block(inplanes, planes, stride, downsample))
388
+ # inplanes = planes
389
+ # for _ in range(1, blocks):
390
+ # layers.append(block(inplanes, planes))
391
+ # return nn.Sequential(*layers)
392
+ #
393
+ # def forward(self, x):
394
+ # x_ir = self.ir_back(x)
395
+ # # x_ir = self.ir_layer(x_ir)
396
+ # # print(x_ir.shape)
397
+ # # x = F.interpolate(x, size=112)
398
+ # # x = self.conv1(x)
399
+ # # x = self.bn1(x)
400
+ # # x = self.relu(x)
401
+ # # x = self.maxpool(x)
402
+ # #
403
+ # # x = self.layer1(x)
404
+ # # x = self.layer2(x)
405
+ # # x = self.layer3(x)
406
+ # # x = self.layer4(x)
407
+ # # print(x.shape)
408
+ # # print(x)
409
+ # out = x_ir
410
+ #
411
+ # return out
412
+ #
413
+ #
414
+ # class IR(nn.Module):
415
+ # def __init__(self, img_size_=112, num_classes=7):
416
+ # super().__init__()
417
+ # depth = 8
418
+ # # if type == "small":
419
+ # # depth = 4
420
+ # # if type == "base":
421
+ # # depth = 6
422
+ # # if type == "large":
423
+ # # depth = 8
424
+ #
425
+ # self.img_size = img_size_
426
+ # self.num_classes = num_classes
427
+ # self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
428
+ # # self.bn1 = norm_layer(64)
429
+ # self.relu = nn.ReLU(inplace=True)
430
+ # self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
431
+ # # self.face_landback = MobileFaceNet([112, 112],136)
432
+ # # face_landback_checkpoint = torch.load('./models/pretrain/mobilefacenet_model_best.pth.tar', map_location=lambda storage, loc: storage)
433
+ # # self.face_landback.load_state_dict(face_landback_checkpoint['state_dict'])
434
+ #
435
+ # # for param in self.face_landback.parameters():
436
+ # # param.requires_grad = False
437
+ #
438
+ # ###########################################################################333
439
+ #
440
+ # self.ir_back = IR20()
441
+ #
442
+ # # ir_checkpoint = torch.load(r'F:\0815crossvit\vision_transformer\models\pretrain\ir50.pth',
443
+ # # map_location=lambda storage, loc: storage)
444
+ # # # ir_checkpoint = ir_checkpoint["model"]
445
+ # # self.ir_back = load_pretrained_weights(self.ir_back, ir_checkpoint)
446
+ # # load_state_dict(checkpoint_model, strict=False)
447
+ # # self.ir_layer = nn.Linear(1024,512)
448
+ #
449
+ # #############################################################3
450
+ # #
451
+ # # self.pyramid_fuse = HyVisionTransformer(in_chans=49, q_chanel = 49, embed_dim=512,
452
+ # # depth=depth, num_heads=8, mlp_ratio=2.,
453
+ # # drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1)
454
+ #
455
+ # # self.se_block = SE_block(input_dim=512)
456
+ # self.head = ClassificationHead(input_dim=768, target_dim=self.num_classes)
457
+ #
458
+ # def forward(self, x):
459
+ # B_ = x.shape[0]
460
+ # # x_face = F.interpolate(x, size=112)
461
+ # # _, x_face = self.face_landback(x_face)
462
+ # # x_face = x_face.view(B_, -1, 49).transpose(1,2)
463
+ # ############### landmark x_face ([B, 49, 512])
464
+ # x_ir = self.ir_back(x)
465
+ # # print(x_ir.shape)
466
+ # # x_ir = self.ir_layer(x_ir)
467
+ # # print(x_ir.shape)
468
+ # ############### image x_ir ([B, 49, 512])
469
+ #
470
+ # # y_hat = self.pyramid_fuse(x_ir, x_face)
471
+ # # y_hat = self.se_block(y_hat)
472
+ # # y_feat = y_hat
473
+ #
474
+ # # out = self.head(x_ir)
475
+ #
476
+ # out = x_ir
477
+ # return out
478
+
479
+
480
+ # note: this second eca_block definition (default channel=196) shadows the one defined earlier
+ # (default channel=128); modules instantiated at runtime pick up this later version
+ class eca_block(nn.Module):
481
+ def __init__(self, channel=196, b=1, gamma=2):
482
+ super(eca_block, self).__init__()
483
+ kernel_size = int(abs((math.log(channel, 2) + b) / gamma))
484
+ kernel_size = kernel_size if kernel_size % 2 else kernel_size + 1
485
+
486
+ self.avg_pool = nn.AdaptiveAvgPool2d(1)
487
+ self.conv = nn.Conv1d(1, 1, kernel_size=kernel_size, padding=(kernel_size - 1) // 2, bias=False)
488
+ self.sigmoid = nn.Sigmoid()
489
+
490
+ def forward(self, x):
491
+ y = self.avg_pool(x)
492
+ y = self.conv(y.squeeze(-1).transpose(-1, -2)).transpose(-1, -2).unsqueeze(-1)
493
+ y = self.sigmoid(y)
494
+ return x * y.expand_as(x)
495
+
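+ # A minimal sketch of eca_block (Efficient Channel Attention): a channel-wise gate computed from
+ # global average pooling followed by a small 1D convolution across channels.
+ def _eca_block_usage_sketch():
+     eca = eca_block(channel=128)            # `channel` only sets the 1D kernel size
+     feat = torch.randn(2, 128, 14, 14)
+     return eca(feat).shape                  # -> [2, 128, 14, 14], channels re-weighted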
496
+ class SE_block(nn.Module):
497
+ def __init__(self, input_dim: int):
498
+ super().__init__()
499
+ self.linear1 = torch.nn.Linear(input_dim, input_dim)
500
+ self.relu = nn.ReLU()
501
+ self.linear2 = torch.nn.Linear(input_dim, input_dim)
502
+ self.sigmod = nn.Sigmoid()
503
+
504
+ def forward(self, x):
505
+ x1 = self.linear1(x)
506
+ x1 = self.relu(x1)
507
+ x1 = self.linear2(x1)
508
+ x1 = self.sigmod(x1)
509
+ x = x * x1
510
+ return x
511
+
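+ # A minimal sketch of SE_block: a squeeze-and-excitation style gate on a flat feature vector.
+ def _se_block_usage_sketch():
+     se = SE_block(input_dim=768)
+     feat = torch.randn(2, 768)
+     return se(feat).shape                   # -> [2, 768]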
512
+
513
+ class VisionTransformer(nn.Module):
514
+ def __init__(self, img_size=14, patch_size=14, in_c=147, num_classes=8,
515
+ embed_dim=768, depth=6, num_heads=8, mlp_ratio=4.0, qkv_bias=True,
516
+ qk_scale=None, representation_size=None, distilled=False, drop_ratio=0.,
517
+ attn_drop_ratio=0., drop_path_ratio=0., embed_layer=PatchEmbed, norm_layer=None,
518
+ act_layer=None):
519
+ """
520
+ Args:
521
+ img_size (int, tuple): input image size
522
+ patch_size (int, tuple): patch size
523
+ in_c (int): number of input channels
524
+ num_classes (int): number of classes for classification head
525
+ embed_dim (int): embedding dimension
526
+ depth (int): depth of transformer
527
+ num_heads (int): number of attention heads
528
+ mlp_ratio (int): ratio of mlp hidden dim to embedding dim
529
+ qkv_bias (bool): enable bias for qkv if True
530
+ qk_scale (float): override default qk scale of head_dim ** -0.5 if set
531
+ representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set
532
+ distilled (bool): model includes a distillation token and head as in DeiT models
533
+ drop_ratio (float): dropout rate
534
+ attn_drop_ratio (float): attention dropout rate
535
+ drop_path_ratio (float): stochastic depth rate
536
+ embed_layer (nn.Module): patch embedding layer
537
+ norm_layer: (nn.Module): normalization layer
538
+ """
539
+ super(VisionTransformer, self).__init__()
540
+ self.num_classes = num_classes
541
+ self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
542
+ self.num_tokens = 2 if distilled else 1
543
+ norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
544
+ act_layer = act_layer or nn.GELU
545
+
546
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
547
+ self.pos_embed = nn.Parameter(torch.zeros(1, in_c + 1, embed_dim))
548
+ self.pos_drop = nn.Dropout(p=drop_ratio)
549
+
550
+ self.se_block = SE_block(input_dim=embed_dim)
551
+
552
+
553
+ self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_c=256, embed_dim=768)
554
+ num_patches = self.patch_embed.num_patches
555
+ self.head = ClassificationHead(input_dim=embed_dim, target_dim=self.num_classes)
556
+ # note: cls_token, pos_drop and eca_block are re-created below; these later assignments
+ # overwrite the ones made a few lines above
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
557
+ self.dist_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if distilled else None
558
+ # self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
559
+ self.pos_drop = nn.Dropout(p=drop_ratio)
560
+ # self.IR = IR()
561
+ self.eca_block = eca_block()
562
+
563
+
564
+ # self.ir_back = Backbone(50, 0.0, 'ir')
565
+ # ir_checkpoint = torch.load('./models/pretrain/ir50.pth', map_location=lambda storage, loc: storage)
566
+ # # ir_checkpoint = ir_checkpoint["model"]
567
+ # self.ir_back = load_pretrained_weights(self.ir_back, ir_checkpoint)
568
+
569
+ self.CON1 = nn.Conv2d(256, 768, kernel_size=1, stride=1, bias=False)
570
+ self.IRLinear1 = nn.Linear(1024, 768)
571
+ self.IRLinear2 = nn.Linear(768, 512)
572
+ self.eca_block = eca_block()
573
+ dpr = [x.item() for x in torch.linspace(0, drop_path_ratio, depth)] # stochastic depth decay rule
574
+ self.blocks = nn.Sequential(*[
575
+ Block(dim=embed_dim, in_chans=in_c, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias,
576
+ qk_scale=qk_scale,
577
+ drop_ratio=drop_ratio, attn_drop_ratio=attn_drop_ratio, drop_path_ratio=dpr[i],
578
+ norm_layer=norm_layer, act_layer=act_layer)
579
+ for i in range(depth)
580
+ ])
581
+ self.norm = norm_layer(embed_dim)
582
+
583
+ # Representation layer
584
+ if representation_size and not distilled:
585
+ self.has_logits = True
586
+ self.num_features = representation_size
587
+ self.pre_logits = nn.Sequential(OrderedDict([
588
+ ("fc", nn.Linear(embed_dim, representation_size)),
589
+ ("act", nn.Tanh())
590
+ ]))
591
+ else:
592
+ self.has_logits = False
593
+ self.pre_logits = nn.Identity()
594
+
595
+ # Classifier head(s)
596
+ # self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
597
+ self.head_dist = None
598
+ if distilled:
599
+ self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity()
600
+
601
+ # Weight init
602
+ nn.init.trunc_normal_(self.pos_embed, std=0.02)
603
+ if self.dist_token is not None:
604
+ nn.init.trunc_normal_(self.dist_token, std=0.02)
605
+
606
+ nn.init.trunc_normal_(self.cls_token, std=0.02)
607
+ self.apply(_init_vit_weights)
608
+
609
+ def forward_features(self, x):
610
+ # [B, C, H, W] -> [B, num_patches, embed_dim]
611
+ # x = self.patch_embed(x) # [B, 196, 768]
612
+ # [1, 1, 768] -> [B, 1, 768]
613
+ # print(x.shape)
614
+
615
+ cls_token = self.cls_token.expand(x.shape[0], -1, -1)
616
+ if self.dist_token is None:
617
+ x = torch.cat((cls_token, x), dim=1) # [B, 197, 768]
618
+ else:
619
+ x = torch.cat((cls_token, self.dist_token.expand(x.shape[0], -1, -1), x), dim=1)
620
+ # print(x.shape)
621
+ x = self.pos_drop(x + self.pos_embed)
622
+ x = self.blocks(x)
623
+ x = self.norm(x)
624
+ if self.dist_token is None:
625
+ return self.pre_logits(x[:, 0])
626
+ else:
627
+ return x[:, 0], x[:, 1]
628
+
629
+ def forward(self, x):
630
+
631
+ # B = x.shape[0]
632
+ # print(x)
633
+ # x = self.eca_block(x)
634
+ # x = self.IR(x)
635
+ # x = eca_block(x)
636
+ # x = self.ir_back(x)
637
+ # print(x.shape)
638
+ # x = self.CON1(x)
639
+ # x = x.view(-1, 196, 768)
640
+ #
641
+ # # print(x.shape)
642
+ # # x = self.IRLinear1(x)
643
+ # # print(x)
644
+ # x_cls = torch.mean(x, 1).view(B, 1, -1)
645
+ # x = torch.cat((x_cls, x), dim=1)
646
+ # # print(x.shape)
647
+ # x = self.pos_drop(x + self.pos_embed)
648
+ # # print(x.shape)
649
+ # x = self.blocks(x)
650
+ # # print(x)
651
+ # x = self.norm(x)
652
+ # # print(x)
653
+ # # x1 = self.IRLinear2(x)
654
+ # x1 = x[:, 0, :]
655
+
656
+ # print(x1)
657
+ # print(x1.shape)
658
+
659
+ x = self.forward_features(x)
660
+ # # print(x.shape)
661
+ # if self.head_dist is not None:
662
+ # x, x_dist = self.head(x[0]), self.head_dist(x[1])
663
+ # if self.training and not torch.jit.is_scripting():
664
+ # # during inference, return the average of both classifier predictions
665
+ # return x, x_dist
666
+ # else:
667
+ # return (x + x_dist) / 2
668
+ # else:
669
+ # print(x.shape)
670
+ x = self.se_block(x)
671
+
672
+ x1 = self.head(x)
673
+
674
+ return x1
675
+
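+ # A minimal sketch of this VisionTransformer variant: forward_features() does not call
+ # patch_embed, so the model expects pre-extracted feature tokens of shape [B, in_c, embed_dim]
+ # and classifies from the prepended class token.
+ def _vision_transformer_usage_sketch():
+     vit = VisionTransformer(in_c=147, embed_dim=768, depth=6, num_heads=8, num_classes=7)
+     tokens = torch.randn(2, 147, 768)       # pre-extracted feature tokens
+     logits = vit(tokens)                    # -> [2, 7]
+     return logits.shape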
676
+
677
+ def _init_vit_weights(m):
678
+ """
679
+ ViT weight initialization
680
+ :param m: module
681
+ """
682
+ if isinstance(m, nn.Linear):
683
+ nn.init.trunc_normal_(m.weight, std=.01)
684
+ if m.bias is not None:
685
+ nn.init.zeros_(m.bias)
686
+ elif isinstance(m, nn.Conv2d):
687
+ nn.init.kaiming_normal_(m.weight, mode="fan_out")
688
+ if m.bias is not None:
689
+ nn.init.zeros_(m.bias)
690
+ elif isinstance(m, nn.LayerNorm):
691
+ nn.init.zeros_(m.bias)
692
+ nn.init.ones_(m.weight)
693
+
694
+
695
+ def vit_base_patch16_224(num_classes: int = 7):
696
+ """
697
+ ViT-Base model (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
698
+ ImageNet-1k weights @ 224x224, source https://github.com/google-research/vision_transformer.
699
+ weights ported from official Google JAX impl:
700
+ link: https://pan.baidu.com/s/1zqb08naP0RPqqfSXfkB2EA (access code: eu9f)
701
+ """
702
+ model = VisionTransformer(img_size=224,
703
+ patch_size=16,
704
+ embed_dim=768,
705
+ depth=12,
706
+ num_heads=12,
707
+ representation_size=None,
708
+ num_classes=num_classes)
709
+
710
+ return model
711
+
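+ # A minimal sketch: build the ViT-B/16 configuration defined above and count its parameters.
+ def _vit_base_patch16_224_sketch():
+     m = vit_base_patch16_224(num_classes=7)
+     return sum(p.numel() for p in m.parameters())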
712
+
713
+ def vit_base_patch16_224_in21k(num_classes: int = 21843, has_logits: bool = True):
714
+ """
715
+ ViT-Base model (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
716
+ ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
717
+ weights ported from official Google JAX impl:
718
+ https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_patch16_224_in21k-e5005f0a.pth
719
+ """
720
+ model = VisionTransformer(img_size=224,
721
+ patch_size=16,
722
+ embed_dim=768,
723
+ depth=12,
724
+ num_heads=12,
725
+ representation_size=768 if has_logits else None,
726
+ num_classes=num_classes)
727
+ return model
728
+
729
+
730
+ def vit_base_patch32_224(num_classes: int = 1000):
731
+ """
732
+ ViT-Base model (ViT-B/32) from original paper (https://arxiv.org/abs/2010.11929).
733
+ ImageNet-1k weights @ 224x224, source https://github.com/google-research/vision_transformer.
734
+ weights ported from official Google JAX impl:
735
+ link: https://pan.baidu.com/s/1hCv0U8pQomwAtHBYc4hmZg (access code: s5hl)
736
+ """
737
+ model = VisionTransformer(img_size=224,
738
+ patch_size=32,
739
+ embed_dim=768,
740
+ depth=12,
741
+ num_heads=12,
742
+ representation_size=None,
743
+ num_classes=num_classes)
744
+ return model
745
+
746
+
747
+ def vit_base_patch32_224_in21k(num_classes: int = 21843, has_logits: bool = True):
748
+ """
749
+ ViT-Base model (ViT-B/32) from original paper (https://arxiv.org/abs/2010.11929).
750
+ ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
751
+ weights ported from official Google JAX impl:
752
+ https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_patch32_224_in21k-8db57226.pth
753
+ """
754
+ model = VisionTransformer(img_size=224,
755
+ patch_size=32,
756
+ embed_dim=768,
757
+ depth=12,
758
+ num_heads=12,
759
+ representation_size=768 if has_logits else None,
760
+ num_classes=num_classes)
761
+ return model
762
+
763
+
764
+ def vit_large_patch16_224(num_classes: int = 1000):
765
+ """
766
+ ViT-Large model (ViT-L/16) from original paper (https://arxiv.org/abs/2010.11929).
767
+ ImageNet-1k weights @ 224x224, source https://github.com/google-research/vision_transformer.
768
+ weights ported from official Google JAX impl:
769
+ link: https://pan.baidu.com/s/1cxBgZJJ6qUWPSBNcE4TdRQ (access code: qqt8)
770
+ """
771
+ model = VisionTransformer(img_size=224,
772
+ patch_size=16,
773
+ embed_dim=1024,
774
+ depth=24,
775
+ num_heads=16,
776
+ representation_size=None,
777
+ num_classes=num_classes)
778
+ return model
779
+
780
+
781
+ def vit_large_patch16_224_in21k(num_classes: int = 21843, has_logits: bool = True):
782
+ """
783
+ ViT-Large model (ViT-L/16) from original paper (https://arxiv.org/abs/2010.11929).
784
+ ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
785
+ weights ported from official Google JAX impl:
786
+ https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_patch16_224_in21k-606da67d.pth
787
+ """
788
+ model = VisionTransformer(img_size=224,
789
+ patch_size=16,
790
+ embed_dim=1024,
791
+ depth=24,
792
+ num_heads=16,
793
+ representation_size=1024 if has_logits else None,
794
+ num_classes=num_classes)
795
+ return model
796
+
797
+
798
+ def vit_large_patch32_224_in21k(num_classes: int = 21843, has_logits: bool = True):
799
+ """
800
+ ViT-Large model (ViT-L/32) from original paper (https://arxiv.org/abs/2010.11929).
801
+ ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
802
+ weights ported from official Google JAX impl:
803
+ https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_patch32_224_in21k-9046d2e7.pth
804
+ """
805
+ model = VisionTransformer(img_size=224,
806
+ patch_size=32,
807
+ embed_dim=1024,
808
+ depth=24,
809
+ num_heads=16,
810
+ representation_size=1024 if has_logits else None,
811
+ num_classes=num_classes)
812
+ return model
813
+
814
+
815
+ def vit_huge_patch14_224_in21k(num_classes: int = 21843, has_logits: bool = True):
816
+ """
817
+ ViT-Huge model (ViT-H/14) from original paper (https://arxiv.org/abs/2010.11929).
818
+ ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
819
+ NOTE: converted weights not currently available, too large for github release hosting.
820
+ """
821
+ model = VisionTransformer(img_size=224,
822
+ patch_size=14,
823
+ embed_dim=1280,
824
+ depth=32,
825
+ num_heads=16,
826
+ representation_size=1280 if has_logits else None,
827
+ num_classes=num_classes)
828
+ return model
prediction.py ADDED
@@ -0,0 +1,103 @@
1
+ from main import *
2
+ from deepface import DeepFace
3
+
4
+ # Checking for all types of devices available
5
+ if torch.backends.mps.is_available():
6
+ device = "mps"
7
+ elif torch.cuda.is_available():
8
+ device = "cuda"
9
+ else:
10
+ device = "cpu"
11
+
12
+ print(f"Using device: {device}")
13
+ # Predicting the model
14
+ # def prediction(model, image_path):
15
+ model = pyramid_trans_expr2(img_size=224, num_classes=7)
16
+
17
+ model = torch.nn.DataParallel(model)
18
+ model = model.to(device)
19
+
20
+ model_path = "raf-db-model_best.pth"
21
+ image_arr = []
22
+ for foldername, subfolders, filenames in os.walk(
23
+ "/Users/futuregadgetlab/Downloads/Testing/"
24
+ ):
25
+ for filename in filenames:
26
+ # Construct the full path to the file
27
+ file_path = os.path.join(foldername, filename)
28
+ image_arr.append(f"{file_path}")
29
+
30
+
31
+ def main():
32
+ if model_path is not None:
33
+ if os.path.isfile(model_path):
34
+ print("=> loading checkpoint '{}'".format(model_path))
35
+ checkpoint = torch.load(model_path, map_location=device)
36
+ best_acc = checkpoint["best_acc"]
37
+ best_acc = best_acc.to()
38
+ print(f"best_acc:{best_acc}")
39
+ model.load_state_dict(checkpoint["state_dict"])
40
+ print(
41
+ "=> loaded checkpoint '{}' (epoch {})".format(
42
+ model_path, checkpoint["epoch"]
43
+ )
44
+ )
45
+ else:
46
+ print("=> no checkpoint found at '{}'".format(model_path))
47
+ predict(model, image_path=image_arr)
48
+ return
49
+
50
+
51
+ def predict(model, image_path):
52
+ from face_detection import face_detection
53
+
54
+ with torch.no_grad():
55
+ transform = transforms.Compose(
56
+ [
57
+ transforms.Resize((224, 224)),
58
+ transforms.RandomHorizontalFlip(),
59
+ transforms.ToTensor(),
60
+ transforms.Normalize(
61
+ mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
62
+ ),
63
+ # note: horizontal flip and random erasing are train-style augmentations that this
+ # transform also applies at inference time
+ transforms.RandomErasing(p=1, scale=(0.05, 0.05)),
64
+ ]
65
+ )
66
+ face = face_detection(image_path)
67
+ image_tensor = transform(face).unsqueeze(0)
68
+ image_tensor = image_tensor.to(device)
69
+
70
+ model.eval()
71
+ img_pred = model(image_tensor)
72
+ topk = (3,)
73
+ with torch.no_grad():
74
+ maxk = max(topk)
75
+ # batch_size = target.size(0)
76
+ _, pred = img_pred.topk(maxk, 1, True, True)
77
+ pred = pred.t()
78
+
79
+ img_pred = pred
80
+ img_pred = img_pred.squeeze().cpu().numpy()
81
+ im_pre_label = np.array(img_pred)
82
+ y_pred = im_pre_label.flatten()
83
+ emotions = {
84
+ 0: "Surprise",
85
+ 1: "Fear",
86
+ 2: "Disgust",
87
+ 3: "Happy",
88
+ 4: "Sad",
89
+ 5: "Angry",
90
+ 6: "Neutral",
91
+ }
92
+ labels = []
93
+ for i in y_pred:
94
+ labels.append(emotions.get(i))
95
+
96
+ print(
97
+ f"-->Image Path {image_path} [!] The predicted labels are {y_pred} and the label is {labels}"
98
+ )
99
+ return
100
+
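+ # A minimal sketch of calling predict() on a single image, mirroring how detectfaces.py passes
+ # a PIL image into predict(); the file path below is hypothetical.
+ def _predict_single_image_sketch():
+     from PIL import Image
+     face_img = Image.open("sample_face.jpg").convert("RGB")  # hypothetical image file
+     predict(model, image_path=face_img)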
101
+
102
+ if __name__ == "__main__":
103
+ main()
raf-db-model_best.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d9bf1d0d88238966ce0d1a289a2bb5f927ec2fe635ef1ec4396c323028924701
3
+ size 238971279
requirements.txt ADDED
@@ -0,0 +1,131 @@
1
+ appdirs==1.4.4
2
+ asgiref==3.7.2
3
+ attr==0.3.1
4
+ azure-core==1.29.5
5
+ azure-storage-blob==12.18.3
6
+ bleach==5.0.1
7
+ boto==2.49.0
8
+ boto3==1.16.63
9
+ botocore==1.19.63
10
+ boxing==0.1.4
11
+ Brotli @ file:///Users/runner/miniforge3/conda-bld/brotli-split_1695989934239/work
12
+ certifi==2023.7.22
13
+ cffi==1.16.0
14
+ charset-normalizer @ file:///home/conda/feedstock_root/build_artifacts/charset-normalizer_1698833585322/work
15
+ click==8.1.7
16
+ colorama==0.4.6
17
+ contourpy @ file:///Users/runner/miniforge3/conda-bld/contourpy_1699041448398/work
18
+ coreapi==2.3.3
19
+ coreschema==0.0.4
20
+ cryptography==41.0.5
21
+ cycler @ file:///home/conda/feedstock_root/build_artifacts/cycler_1696677705766/work
22
+ defusedxml==0.7.1
23
+ Django==3.2.20
24
+ django-annoying==0.10.6
25
+ django-cors-headers==3.6.0
26
+ django-debug-toolbar==3.2.1
27
+ django-environ==0.10.0
28
+ django-extensions==3.1.0
29
+ django-filter==2.4.0
30
+ django-model-utils==4.1.1
31
+ django-ranged-fileresponse==0.1.2
32
+ django-rest-swagger==2.2.0
33
+ django-rq==2.5.1
34
+ django-storages==1.12.3
35
+ django-user-agents==0.4.0
36
+ djangorestframework==3.13.1
37
+ drf-dynamic-fields==0.3.0
38
+ drf-flex-fields==0.9.5
39
+ drf-generators==0.3.0
40
+ drf-yasg==1.20.0
41
+ expiringdict==1.2.2
42
+ filelock @ file:///home/conda/feedstock_root/build_artifacts/filelock_1698714947081/work
43
+ fonttools @ file:///Users/runner/miniforge3/conda-bld/fonttools_1699023568720/work
44
+ fsspec==2023.10.0
45
+ gmpy2 @ file:///Users/runner/miniforge3/conda-bld/gmpy2_1666808749046/work
46
+ google-api-core==2.11.0
47
+ google-cloud-appengine-logging==1.1.0
48
+ google-cloud-audit-log==0.2.0
49
+ google-cloud-core==2.3.2
50
+ google-cloud-logging==2.7.1
51
+ google-cloud-storage==2.5.0
52
+ google-crc32c==1.5.0
53
+ google-resumable-media==2.3.3
54
+ googleapis-common-protos==1.56.4
55
+ grpc-google-iam-v1==0.12.4
56
+ grpcio-status==1.59.2
57
+ htmlmin==0.1.12
58
+ huggingface-hub==0.18.0
59
+ idna @ file:///home/conda/feedstock_root/build_artifacts/idna_1663625384323/work
60
+ ijson==3.2.3
61
+ inflection==0.5.1
62
+ isodate==0.6.1
63
+ itypes==1.2.0
64
+ Jinja2 @ file:///home/conda/feedstock_root/build_artifacts/jinja2_1654302431367/work
65
+ jmespath==0.10.0
66
+ joblib==1.3.2
67
+ jsonschema==3.2.0
68
+ kiwisolver @ file:///Users/runner/miniforge3/conda-bld/kiwisolver_1695380058985/work
69
+ label-studio==1.8.2.post1
70
+ label-studio-converter==0.0.54rc0
71
+ label-studio-tools==0.0.3
72
+ launchdarkly-server-sdk==7.5.0
73
+ lockfile==0.12.2
74
+ lxml==4.9.3
75
+ MarkupSafe @ file:///Users/runner/miniforge3/conda-bld/markupsafe_1695367660391/work
76
+ matplotlib @ file:///Users/runner/miniforge3/conda-bld/matplotlib-suite_1698868590489/work
77
+ mpmath @ file:///home/conda/feedstock_root/build_artifacts/mpmath_1678228039184/work
78
+ munkres==1.1.4
79
+ networkx @ file:///home/conda/feedstock_root/build_artifacts/networkx_1698504735452/work
80
+ nltk==3.6.7
81
+ numpy @ file:///Users/runner/miniforge3/conda-bld/numpy_1694920094885/work/dist/numpy-1.26.0-cp311-cp311-macosx_11_0_arm64.whl#sha256=6909902123b8421906e90ad77fb0041d9eb2d95bbdc29f3d09c7d244b0e0e5a5
82
+ openapi-codec==1.3.2
83
+ ordered-set==4.0.2
84
+ packaging @ file:///home/conda/feedstock_root/build_artifacts/packaging_1696202382185/work
85
+ pandas==2.1.2
86
+ Pillow @ file:///Users/runner/miniforge3/conda-bld/pillow_1697423665652/work
87
+ proto-plus==1.22.3
88
+ psycopg2-binary==2.9.6
89
+ pycparser==2.21
90
+ pyparsing @ file:///home/conda/feedstock_root/build_artifacts/pyparsing_1690737849915/work
91
+ pyRFC3339==1.1
92
+ pyrsistent==0.20.0
93
+ PySocks @ file:///home/conda/feedstock_root/build_artifacts/pysocks_1661604839144/work
94
+ python-dateutil @ file:///home/conda/feedstock_root/build_artifacts/python-dateutil_1626286286081/work
95
+ python-json-logger==2.0.4
96
+ pytz==2023.3.post1
97
+ PyYAML @ file:///Users/runner/miniforge3/conda-bld/pyyaml_1695373486380/work
98
+ redis==3.5.3
99
+ requests @ file:///home/conda/feedstock_root/build_artifacts/requests_1684774241324/work
100
+ rq==1.10.1
101
+ ruamel.yaml==0.18.5
102
+ ruamel.yaml.clib==0.2.8
103
+ rules==2.2
104
+ s3transfer==0.3.7
105
+ safetensors==0.4.0
106
+ scikit-learn==1.3.2
107
+ scipy==1.11.3
108
+ semver==2.13.0
109
+ sentry-sdk==1.34.0
110
+ simplejson==3.19.2
111
+ six @ file:///home/conda/feedstock_root/build_artifacts/six_1620240208055/work
112
+ sqlparse==0.4.4
113
+ sympy @ file:///home/conda/feedstock_root/build_artifacts/sympy_1684180540116/work
114
+ thop==0.1.1.post2209072238
115
+ threadpoolctl==3.2.0
116
+ timm==0.9.10
117
+ torch==2.1.0
118
+ torchaudio==2.1.0
119
+ torchsampler==0.1.2
120
+ torchvision==0.16.0
121
+ tornado @ file:///Users/runner/miniforge3/conda-bld/tornado_1695373481350/work
122
+ tqdm==4.66.1
123
+ typing_extensions @ file:///home/conda/feedstock_root/build_artifacts/typing_extensions_1695040754690/work
124
+ tzdata @ file:///home/conda/feedstock_root/build_artifacts/python-tzdata_1680081134351/work
125
+ ua-parser==0.18.0
126
+ ujson==5.8.0
127
+ uritemplate==4.1.1
128
+ urllib3 @ file:///home/conda/feedstock_root/build_artifacts/urllib3_1697720414277/work
129
+ user-agents==2.2.0
130
+ webencodings==0.5.1
131
+ xmljson==0.2.0