nightfury committed on
Commit
95d308c
1 Parent(s): ed0f3f1

upload chkpt

chkpt, src, doc

This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. checkpoints/BFM_Fitting/01_MorphableModel.mat +1 -0
  2. checkpoints/BFM_Fitting/BFM09_model_info.mat +1 -0
  3. checkpoints/BFM_Fitting/BFM_exp_idx.mat +1 -0
  4. checkpoints/BFM_Fitting/BFM_front_idx.mat +1 -0
  5. checkpoints/BFM_Fitting/Exp_Pca.bin +1 -0
  6. checkpoints/BFM_Fitting/facemodel_info.mat +1 -0
  7. checkpoints/BFM_Fitting/select_vertex_id.mat +1 -0
  8. checkpoints/BFM_Fitting/similarity_Lm3D_all.mat +1 -0
  9. checkpoints/BFM_Fitting/std_exp.txt +1 -0
  10. checkpoints/auido2exp_00300-model.pth +1 -0
  11. checkpoints/auido2pose_00140-model.pth +1 -0
  12. checkpoints/epoch_20.pth +1 -0
  13. checkpoints/facevid2vid_00189-model.pth.tar +1 -0
  14. checkpoints/hub/checkpoints/2DFAN4-cd938726ad.zip +1 -0
  15. checkpoints/hub/checkpoints/s3fd-619a316812.pth +1 -0
  16. checkpoints/mapping_00229-model.pth.tar +1 -0
  17. checkpoints/shape_predictor_68_face_landmarks.dat +1 -0
  18. checkpoints/wav2lip.pth +1 -0
  19. docs/logo.jpg +0 -0
  20. src/audio2exp_models/audio2exp.py +41 -0
  21. src/audio2exp_models/networks.py +74 -0
  22. src/audio2pose_models/audio2pose.py +94 -0
  23. src/audio2pose_models/audio_encoder.py +64 -0
  24. src/audio2pose_models/cvae.py +149 -0
  25. src/audio2pose_models/discriminator.py +76 -0
  26. src/audio2pose_models/networks.py +140 -0
  27. src/audio2pose_models/res_unet.py +65 -0
  28. src/config/auido2exp.yaml +58 -0
  29. src/config/auido2pose.yaml +49 -0
  30. src/config/facerender.yaml +45 -0
  31. src/config/facerender_pirender.yaml +83 -0
  32. src/config/facerender_still.yaml +45 -0
  33. src/config/similarity_Lm3D_all.mat +0 -0
  34. src/face3d/data/__init__.py +116 -0
  35. src/face3d/data/base_dataset.py +125 -0
  36. src/face3d/data/flist_dataset.py +125 -0
  37. src/face3d/data/image_folder.py +66 -0
  38. src/face3d/data/template_dataset.py +75 -0
  39. src/face3d/extract_kp_videos.py +108 -0
  40. src/face3d/extract_kp_videos_safe.py +151 -0
  41. src/face3d/models/__init__.py +67 -0
  42. src/face3d/models/arcface_torch/README.md +164 -0
  43. src/face3d/models/arcface_torch/backbones/__init__.py +25 -0
  44. src/face3d/models/arcface_torch/backbones/iresnet.py +187 -0
  45. src/face3d/models/arcface_torch/backbones/iresnet2060.py +176 -0
  46. src/face3d/models/arcface_torch/backbones/mobilefacenet.py +130 -0
  47. src/face3d/models/arcface_torch/configs/3millions.py +23 -0
  48. src/face3d/models/arcface_torch/configs/3millions_pfc.py +23 -0
  49. src/face3d/models/arcface_torch/configs/__init__.py +0 -0
  50. src/face3d/models/arcface_torch/configs/base.py +56 -0
checkpoints/BFM_Fitting/01_MorphableModel.mat ADDED
@@ -0,0 +1 @@
+ ../../../../../root/.cache/huggingface/hub/models--vinthony--SadTalker/blobs/37b1f0742db356a3b1568a8365a06f5b0fe0ab687ac1c3068c803666cbd4d8e2
checkpoints/BFM_Fitting/BFM09_model_info.mat ADDED
@@ -0,0 +1 @@
+ ../../../../../root/.cache/huggingface/hub/models--vinthony--SadTalker/blobs/db8d00544f0b0182f1b8430a3bb87662b3ff674eb33c84e6f52dbe2971adb81b
checkpoints/BFM_Fitting/BFM_exp_idx.mat ADDED
@@ -0,0 +1 @@
+ ../../../../../root/.cache/huggingface/hub/models--vinthony--SadTalker/blobs/1146e4e9c3bef303a497383aa7974c014fe945c7
checkpoints/BFM_Fitting/BFM_front_idx.mat ADDED
@@ -0,0 +1 @@
+ ../../../../../root/.cache/huggingface/hub/models--vinthony--SadTalker/blobs/b9d7b0953dd1dc5b1e28144610485409ac321f9b
checkpoints/BFM_Fitting/Exp_Pca.bin ADDED
@@ -0,0 +1 @@
+ ../../../../../root/.cache/huggingface/hub/models--vinthony--SadTalker/blobs/e7f31380e6cbdaf2aeec698db220bac4f221946e4d551d88c092d47ec49b1726
checkpoints/BFM_Fitting/facemodel_info.mat ADDED
@@ -0,0 +1 @@
+ ../../../../../root/.cache/huggingface/hub/models--vinthony--SadTalker/blobs/3e516ec7297fa3248098f49ecea10579f4831c0a
checkpoints/BFM_Fitting/select_vertex_id.mat ADDED
@@ -0,0 +1 @@
+ ../../../../../root/.cache/huggingface/hub/models--vinthony--SadTalker/blobs/5b8b220093d93b133acc94ffed159f31a74854cd
checkpoints/BFM_Fitting/similarity_Lm3D_all.mat ADDED
@@ -0,0 +1 @@
+ ../../../../../root/.cache/huggingface/hub/models--vinthony--SadTalker/blobs/a0e23588302bc71fc899eef53ff06df5f4df4c1d
checkpoints/BFM_Fitting/std_exp.txt ADDED
@@ -0,0 +1 @@
+ ../../../../../root/.cache/huggingface/hub/models--vinthony--SadTalker/blobs/767b8de4ea1ca78b6f22b98ff2dee4fa345500bb
checkpoints/auido2exp_00300-model.pth ADDED
@@ -0,0 +1 @@
+ ../../../../root/.cache/huggingface/hub/models--vinthony--SadTalker/blobs/b7608f0e6b477e50e03ca569ac5b04a841b9217f89d502862fc78fda4e46dec4
checkpoints/auido2pose_00140-model.pth ADDED
@@ -0,0 +1 @@
+ ../../../../root/.cache/huggingface/hub/models--vinthony--SadTalker/blobs/4fba6701852dc57efbed25b1e4276e4ff752941860d69fc4429f08a02326ebce
checkpoints/epoch_20.pth ADDED
@@ -0,0 +1 @@
+ ../../../../root/.cache/huggingface/hub/models--vinthony--SadTalker/blobs/6d17a6b23457b521801baae583cb6a58f7238fe6721fc3d65d76407460e9149b
checkpoints/facevid2vid_00189-model.pth.tar ADDED
@@ -0,0 +1 @@
+ ../../../../root/.cache/huggingface/hub/models--vinthony--SadTalker/blobs/fbad01d46f0510276dc4521322dde6824a873a4222cd0740c85762e7067ea71d
checkpoints/hub/checkpoints/2DFAN4-cd938726ad.zip ADDED
@@ -0,0 +1 @@
+ ../../../../../../root/.cache/huggingface/hub/models--vinthony--SadTalker/blobs/cd938726adb1f15f361263cce2db9cb820c42585fa8796ec72ce19107f369a46
checkpoints/hub/checkpoints/s3fd-619a316812.pth ADDED
@@ -0,0 +1 @@
+ ../../../../../../root/.cache/huggingface/hub/models--vinthony--SadTalker/blobs/619a31681264d3f7f7fc7a16a42cbbe8b23f31a256f75a366e5a1bcd59b33543
checkpoints/mapping_00229-model.pth.tar ADDED
@@ -0,0 +1 @@
+ ../../../../root/.cache/huggingface/hub/models--vinthony--SadTalker-V002rc/blobs/62a1e06006cc963220f6477438518ed86e9788226c62ae382ddc42fbcefb83f1
checkpoints/shape_predictor_68_face_landmarks.dat ADDED
@@ -0,0 +1 @@
+ ../../../../root/.cache/huggingface/hub/models--vinthony--SadTalker/blobs/fbdc2cb80eb9aa7a758672cbfdda32ba6300efe9b6e6c7a299ff7e736b11b92f
checkpoints/wav2lip.pth ADDED
@@ -0,0 +1 @@
+ ../../../../root/.cache/huggingface/hub/models--vinthony--SadTalker/blobs/b78b681b68ad9fe6c6fb1debc6ff43ad05834a8af8a62ffc4167b7b34ef63c37
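The checkpoint entries above are symlinks into a local Hugging Face cache for the vinthony/SadTalker and vinthony/SadTalker-V002rc repos. A minimal sketch of materializing the same files without relying on cache symlinks, assuming the huggingface_hub client and the repo ids visible in the blob paths (the local_dir/allow_patterns arguments are illustrative):

# sketch: download the SadTalker weights referenced by the cache symlinks above
from huggingface_hub import snapshot_download

ckpt_dir = snapshot_download(repo_id="vinthony/SadTalker", local_dir="./checkpoints")
mapping_dir = snapshot_download(repo_id="vinthony/SadTalker-V002rc",
                                allow_patterns=["mapping_00229-model.pth.tar"],
                                local_dir="./checkpoints")
print(ckpt_dir, mapping_dir)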
docs/logo.jpg ADDED
src/audio2exp_models/audio2exp.py ADDED
@@ -0,0 +1,41 @@
+ from tqdm import tqdm
+ import torch
+ from torch import nn
+
+
+ class Audio2Exp(nn.Module):
+     def __init__(self, netG, cfg, device, prepare_training_loss=False):
+         super(Audio2Exp, self).__init__()
+         self.cfg = cfg
+         self.device = device
+         self.netG = netG.to(device)
+
+     def test(self, batch):
+
+         mel_input = batch['indiv_mels']                 # bs T 1 80 16
+         bs = mel_input.shape[0]
+         T = mel_input.shape[1]
+
+         exp_coeff_pred = []
+
+         for i in tqdm(range(0, T, 10), 'audio2exp:'):   # every 10 frames
+
+             current_mel_input = mel_input[:, i:i+10]
+
+             # ref = batch['ref'][:, :, :64].repeat((1, current_mel_input.shape[1], 1))  # bs T 64
+             ref = batch['ref'][:, :, :64][:, i:i+10]
+             ratio = batch['ratio_gt'][:, i:i+10]        # bs T
+
+             audiox = current_mel_input.view(-1, 1, 80, 16)            # bs*T 1 80 16
+
+             curr_exp_coeff_pred = self.netG(audiox, ref, ratio)       # bs T 64
+
+             exp_coeff_pred += [curr_exp_coeff_pred]
+
+         # BS x T x 64
+         results_dict = {
+             'exp_coeff_pred': torch.cat(exp_coeff_pred, axis=1)
+             }
+         return results_dict
+
+
src/audio2exp_models/networks.py ADDED
@@ -0,0 +1,74 @@
+ import torch
+ import torch.nn.functional as F
+ from torch import nn
+
+ class Conv2d(nn.Module):
+     def __init__(self, cin, cout, kernel_size, stride, padding, residual=False, use_act=True, *args, **kwargs):
+         super().__init__(*args, **kwargs)
+         self.conv_block = nn.Sequential(
+             nn.Conv2d(cin, cout, kernel_size, stride, padding),
+             nn.BatchNorm2d(cout)
+         )
+         self.act = nn.ReLU()
+         self.residual = residual
+         self.use_act = use_act
+
+     def forward(self, x):
+         out = self.conv_block(x)
+         if self.residual:
+             out += x
+
+         if self.use_act:
+             return self.act(out)
+         else:
+             return out
+
+ class SimpleWrapperV2(nn.Module):
+     def __init__(self) -> None:
+         super().__init__()
+         self.audio_encoder = nn.Sequential(
+             Conv2d(1, 32, kernel_size=3, stride=1, padding=1),
+             Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),
+             Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),
+
+             Conv2d(32, 64, kernel_size=3, stride=(3, 1), padding=1),
+             Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
+             Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
+
+             Conv2d(64, 128, kernel_size=3, stride=3, padding=1),
+             Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
+             Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
+
+             Conv2d(128, 256, kernel_size=3, stride=(3, 2), padding=1),
+             Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),
+
+             Conv2d(256, 512, kernel_size=3, stride=1, padding=0),
+             Conv2d(512, 512, kernel_size=1, stride=1, padding=0),
+         )
+
+         #### load the pre-trained audio_encoder
+         # self.audio_encoder = self.audio_encoder.to(device)
+         '''
+         wav2lip_state_dict = torch.load('/apdcephfs_cq2/share_1290939/wenxuazhang/checkpoints/wav2lip.pth')['state_dict']
+         state_dict = self.audio_encoder.state_dict()
+
+         for k, v in wav2lip_state_dict.items():
+             if 'audio_encoder' in k:
+                 print('init:', k)
+                 state_dict[k.replace('module.audio_encoder.', '')] = v
+         self.audio_encoder.load_state_dict(state_dict)
+         '''
+
+         self.mapping1 = nn.Linear(512+64+1, 64)
+         # self.mapping2 = nn.Linear(30, 64)
+         # nn.init.constant_(self.mapping1.weight, 0.)
+         nn.init.constant_(self.mapping1.bias, 0.)
+
+     def forward(self, x, ref, ratio):
+         x = self.audio_encoder(x).view(x.size(0), -1)
+         ref_reshape = ref.reshape(x.size(0), -1)
+         ratio = ratio.reshape(x.size(0), -1)
+
+         y = self.mapping1(torch.cat([x, ref_reshape, ratio], dim=1))
+         out = y.reshape(ref.shape[0], ref.shape[1], -1)  # + ref  # residual
+         return out
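A minimal smoke test, not part of the commit, wiring SimpleWrapperV2 into Audio2Exp from the previous file with random tensors shaped like the in-code comments (per-frame mel windows, 70-dim reference coefficients, per-frame blink ratio). It assumes the repository root is on PYTHONPATH so the src.* imports resolve:

# sketch: run the audio-to-expression branch end to end on dummy inputs
import torch
from src.audio2exp_models.networks import SimpleWrapperV2
from src.audio2exp_models.audio2exp import Audio2Exp

bs, T = 2, 20
model = Audio2Exp(SimpleWrapperV2(), cfg=None, device='cpu').eval()
batch = {
    'indiv_mels': torch.randn(bs, T, 1, 80, 16),  # bs T 1 80 16 mel windows
    'ref':        torch.randn(bs, T, 70),         # first 64 dims are expression coeffs
    'ratio_gt':   torch.rand(bs, T),              # eye-blink ratio per frame
}
with torch.no_grad():
    out = model.test(batch)
print(out['exp_coeff_pred'].shape)                # expected torch.Size([2, 20, 64])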
src/audio2pose_models/audio2pose.py ADDED
@@ -0,0 +1,94 @@
+ import torch
+ from torch import nn
+ from src.audio2pose_models.cvae import CVAE
+ from src.audio2pose_models.discriminator import PoseSequenceDiscriminator
+ from src.audio2pose_models.audio_encoder import AudioEncoder
+
+ class Audio2Pose(nn.Module):
+     def __init__(self, cfg, wav2lip_checkpoint, device='cuda'):
+         super().__init__()
+         self.cfg = cfg
+         self.seq_len = cfg.MODEL.CVAE.SEQ_LEN
+         self.latent_dim = cfg.MODEL.CVAE.LATENT_SIZE
+         self.device = device
+
+         self.audio_encoder = AudioEncoder(wav2lip_checkpoint, device)
+         self.audio_encoder.eval()
+         for param in self.audio_encoder.parameters():
+             param.requires_grad = False
+
+         self.netG = CVAE(cfg)
+         self.netD_motion = PoseSequenceDiscriminator(cfg)
+
+
+     def forward(self, x):
+
+         batch = {}
+         coeff_gt = x['gt'].cuda().squeeze(0)                                     # bs frame_len+1 73
+         batch['pose_motion_gt'] = coeff_gt[:, 1:, 64:70] - coeff_gt[:, :1, 64:70]  # bs frame_len 6
+         batch['ref'] = coeff_gt[:, 0, 64:70]                                     # bs 6
+         batch['class'] = x['class'].squeeze(0).cuda()                            # bs
+         indiv_mels = x['indiv_mels'].cuda().squeeze(0)                           # bs seq_len+1 80 16
+
+         # forward
+         audio_emb_list = []
+         audio_emb = self.audio_encoder(indiv_mels[:, 1:, :, :].unsqueeze(2))     # bs seq_len 512
+         batch['audio_emb'] = audio_emb
+         batch = self.netG(batch)
+
+         pose_motion_pred = batch['pose_motion_pred']                             # bs frame_len 6
+         pose_gt = coeff_gt[:, 1:, 64:70].clone()                                 # bs frame_len 6
+         pose_pred = coeff_gt[:, :1, 64:70] + pose_motion_pred                    # bs frame_len 6
+
+         batch['pose_pred'] = pose_pred
+         batch['pose_gt'] = pose_gt
+
+         return batch
+
+     def test(self, x):
+
+         batch = {}
+         ref = x['ref']                       # bs 1 70
+         batch['ref'] = x['ref'][:, 0, -6:]
+         batch['class'] = x['class']
+         bs = ref.shape[0]
+
+         indiv_mels = x['indiv_mels']         # bs T 1 80 16
+         indiv_mels_use = indiv_mels[:, 1:]   # we regard the ref as the first frame
+         num_frames = x['num_frames']
+         num_frames = int(num_frames) - 1
+
+         #
+         div = num_frames // self.seq_len
+         re = num_frames % self.seq_len
+         audio_emb_list = []
+         pose_motion_pred_list = [torch.zeros(batch['ref'].unsqueeze(1).shape, dtype=batch['ref'].dtype,
+                                              device=batch['ref'].device)]
+
+         for i in range(div):
+             z = torch.randn(bs, self.latent_dim).to(ref.device)
+             batch['z'] = z
+             audio_emb = self.audio_encoder(indiv_mels_use[:, i*self.seq_len:(i+1)*self.seq_len, :, :, :])  # bs seq_len 512
+             batch['audio_emb'] = audio_emb
+             batch = self.netG.test(batch)
+             pose_motion_pred_list.append(batch['pose_motion_pred'])  # list of bs seq_len 6
+
+         if re != 0:
+             z = torch.randn(bs, self.latent_dim).to(ref.device)
+             batch['z'] = z
+             audio_emb = self.audio_encoder(indiv_mels_use[:, -1*self.seq_len:, :, :, :])  # bs seq_len 512
+             if audio_emb.shape[1] != self.seq_len:
+                 pad_dim = self.seq_len - audio_emb.shape[1]
+                 pad_audio_emb = audio_emb[:, :1].repeat(1, pad_dim, 1)
+                 audio_emb = torch.cat([pad_audio_emb, audio_emb], 1)
+             batch['audio_emb'] = audio_emb
+             batch = self.netG.test(batch)
+             pose_motion_pred_list.append(batch['pose_motion_pred'][:, -1*re:, :])
+
+         pose_motion_pred = torch.cat(pose_motion_pred_list, dim=1)
+         batch['pose_motion_pred'] = pose_motion_pred
+
+         pose_pred = ref[:, :1, -6:] + pose_motion_pred   # bs T 6
+
+         batch['pose_pred'] = pose_pred
+         return batch
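Audio2Pose.test splits num_frames into full seq_len windows and serves the remainder from the trailing seq_len frames, keeping only the last `re` predictions of that window. The chunking arithmetic in isolation, as a sketch that is not part of the commit:

# sketch: window planning used by Audio2Pose.test
def plan_windows(num_frames, seq_len=32):
    div, re = divmod(num_frames, seq_len)
    windows = [(i * seq_len, (i + 1) * seq_len) for i in range(div)]
    if re != 0:
        # the last window re-uses the trailing seq_len frames; only its last
        # `re` predictions are kept, so every frame is predicted exactly once
        windows.append((max(num_frames - seq_len, 0), num_frames))
    return windows, re

print(plan_windows(100))  # ([(0, 32), (32, 64), (64, 96), (68, 100)], 4)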
src/audio2pose_models/audio_encoder.py ADDED
@@ -0,0 +1,64 @@
+ import torch
+ from torch import nn
+ from torch.nn import functional as F
+
+ class Conv2d(nn.Module):
+     def __init__(self, cin, cout, kernel_size, stride, padding, residual=False, *args, **kwargs):
+         super().__init__(*args, **kwargs)
+         self.conv_block = nn.Sequential(
+             nn.Conv2d(cin, cout, kernel_size, stride, padding),
+             nn.BatchNorm2d(cout)
+         )
+         self.act = nn.ReLU()
+         self.residual = residual
+
+     def forward(self, x):
+         out = self.conv_block(x)
+         if self.residual:
+             out += x
+         return self.act(out)
+
+ class AudioEncoder(nn.Module):
+     def __init__(self, wav2lip_checkpoint, device):
+         super(AudioEncoder, self).__init__()
+
+         self.audio_encoder = nn.Sequential(
+             Conv2d(1, 32, kernel_size=3, stride=1, padding=1),
+             Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),
+             Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),
+
+             Conv2d(32, 64, kernel_size=3, stride=(3, 1), padding=1),
+             Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
+             Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
+
+             Conv2d(64, 128, kernel_size=3, stride=3, padding=1),
+             Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
+             Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
+
+             Conv2d(128, 256, kernel_size=3, stride=(3, 2), padding=1),
+             Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),
+
+             Conv2d(256, 512, kernel_size=3, stride=1, padding=0),
+             Conv2d(512, 512, kernel_size=1, stride=1, padding=0),)
+
+         #### load the pre-trained audio_encoder, we do not need to load wav2lip model here.
+         # wav2lip_state_dict = torch.load(wav2lip_checkpoint, map_location=torch.device(device))['state_dict']
+         # state_dict = self.audio_encoder.state_dict()
+
+         # for k, v in wav2lip_state_dict.items():
+         #     if 'audio_encoder' in k:
+         #         state_dict[k.replace('module.audio_encoder.', '')] = v
+         # self.audio_encoder.load_state_dict(state_dict)
+
+
+     def forward(self, audio_sequences):
+         # audio_sequences = (B, T, 1, 80, 16)
+         B = audio_sequences.size(0)
+
+         audio_sequences = torch.cat([audio_sequences[:, i] for i in range(audio_sequences.size(1))], dim=0)
+
+         audio_embedding = self.audio_encoder(audio_sequences)  # B, 512, 1, 1
+         dim = audio_embedding.shape[1]
+         audio_embedding = audio_embedding.reshape((B, -1, dim, 1, 1))
+
+         return audio_embedding.squeeze(-1).squeeze(-1)  # B seq_len+1 512
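A quick shape check for AudioEncoder, as a sketch: the checkpoint and device arguments are unused here because the wav2lip loading above is commented out, so None and 'cpu' suffice for illustration.

# sketch: (B, T, 1, 80, 16) mel windows -> (B, T, 512) embeddings
import torch
from src.audio2pose_models.audio_encoder import AudioEncoder

enc = AudioEncoder(wav2lip_checkpoint=None, device='cpu').eval()
mels = torch.randn(2, 32, 1, 80, 16)   # B, T, 1, 80, 16
with torch.no_grad():
    emb = enc(mels)
print(emb.shape)                       # torch.Size([2, 32, 512])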
src/audio2pose_models/cvae.py ADDED
@@ -0,0 +1,149 @@
+ import torch
+ import torch.nn.functional as F
+ from torch import nn
+ from src.audio2pose_models.res_unet import ResUnet
+
+ def class2onehot(idx, class_num):
+
+     assert torch.max(idx).item() < class_num
+     onehot = torch.zeros(idx.size(0), class_num).to(idx.device)
+     onehot.scatter_(1, idx, 1)
+     return onehot
+
+ class CVAE(nn.Module):
+     def __init__(self, cfg):
+         super().__init__()
+         encoder_layer_sizes = cfg.MODEL.CVAE.ENCODER_LAYER_SIZES
+         decoder_layer_sizes = cfg.MODEL.CVAE.DECODER_LAYER_SIZES
+         latent_size = cfg.MODEL.CVAE.LATENT_SIZE
+         num_classes = cfg.DATASET.NUM_CLASSES
+         audio_emb_in_size = cfg.MODEL.CVAE.AUDIO_EMB_IN_SIZE
+         audio_emb_out_size = cfg.MODEL.CVAE.AUDIO_EMB_OUT_SIZE
+         seq_len = cfg.MODEL.CVAE.SEQ_LEN
+
+         self.latent_size = latent_size
+
+         self.encoder = ENCODER(encoder_layer_sizes, latent_size, num_classes,
+                                audio_emb_in_size, audio_emb_out_size, seq_len)
+         self.decoder = DECODER(decoder_layer_sizes, latent_size, num_classes,
+                                audio_emb_in_size, audio_emb_out_size, seq_len)
+     def reparameterize(self, mu, logvar):
+         std = torch.exp(0.5 * logvar)
+         eps = torch.randn_like(std)
+         return mu + eps * std
+
+     def forward(self, batch):
+         batch = self.encoder(batch)
+         mu = batch['mu']
+         logvar = batch['logvar']
+         z = self.reparameterize(mu, logvar)
+         batch['z'] = z
+         return self.decoder(batch)
+
+     def test(self, batch):
+         '''
+         class_id = batch['class']
+         z = torch.randn([class_id.size(0), self.latent_size]).to(class_id.device)
+         batch['z'] = z
+         '''
+         return self.decoder(batch)
+
+ class ENCODER(nn.Module):
+     def __init__(self, layer_sizes, latent_size, num_classes,
+                  audio_emb_in_size, audio_emb_out_size, seq_len):
+         super().__init__()
+
+         self.resunet = ResUnet()
+         self.num_classes = num_classes
+         self.seq_len = seq_len
+
+         self.MLP = nn.Sequential()
+         layer_sizes[0] += latent_size + seq_len*audio_emb_out_size + 6
+         for i, (in_size, out_size) in enumerate(zip(layer_sizes[:-1], layer_sizes[1:])):
+             self.MLP.add_module(
+                 name="L{:d}".format(i), module=nn.Linear(in_size, out_size))
+             self.MLP.add_module(name="A{:d}".format(i), module=nn.ReLU())
+
+         self.linear_means = nn.Linear(layer_sizes[-1], latent_size)
+         self.linear_logvar = nn.Linear(layer_sizes[-1], latent_size)
+         self.linear_audio = nn.Linear(audio_emb_in_size, audio_emb_out_size)
+
+         self.classbias = nn.Parameter(torch.randn(self.num_classes, latent_size))
+
+     def forward(self, batch):
+         class_id = batch['class']
+         pose_motion_gt = batch['pose_motion_gt']   # bs seq_len 6
+         ref = batch['ref']                         # bs 6
+         bs = pose_motion_gt.shape[0]
+         audio_in = batch['audio_emb']              # bs seq_len audio_emb_in_size
+
+         # pose encode
+         pose_emb = self.resunet(pose_motion_gt.unsqueeze(1))   # bs 1 seq_len 6
+         pose_emb = pose_emb.reshape(bs, -1)                    # bs seq_len*6
+
+         # audio mapping
+         print(audio_in.shape)
+         audio_out = self.linear_audio(audio_in)    # bs seq_len audio_emb_out_size
+         audio_out = audio_out.reshape(bs, -1)
+
+         class_bias = self.classbias[class_id]      # bs latent_size
+         x_in = torch.cat([ref, pose_emb, audio_out, class_bias], dim=-1)  # bs seq_len*(audio_emb_out_size+6)+latent_size
+         x_out = self.MLP(x_in)
+
+         mu = self.linear_means(x_out)
+         logvar = self.linear_logvar(x_out)         # bs latent_size (fixed: the original reused linear_means here, leaving linear_logvar unused)
+
+         batch.update({'mu': mu, 'logvar': logvar})
+         return batch
+
+ class DECODER(nn.Module):
+     def __init__(self, layer_sizes, latent_size, num_classes,
+                  audio_emb_in_size, audio_emb_out_size, seq_len):
+         super().__init__()
+
+         self.resunet = ResUnet()
+         self.num_classes = num_classes
+         self.seq_len = seq_len
+
+         self.MLP = nn.Sequential()
+         input_size = latent_size + seq_len*audio_emb_out_size + 6
+         for i, (in_size, out_size) in enumerate(zip([input_size]+layer_sizes[:-1], layer_sizes)):
+             self.MLP.add_module(
+                 name="L{:d}".format(i), module=nn.Linear(in_size, out_size))
+             if i+1 < len(layer_sizes):
+                 self.MLP.add_module(name="A{:d}".format(i), module=nn.ReLU())
+             else:
+                 self.MLP.add_module(name="sigmoid", module=nn.Sigmoid())
+
+         self.pose_linear = nn.Linear(6, 6)
+         self.linear_audio = nn.Linear(audio_emb_in_size, audio_emb_out_size)
+
+         self.classbias = nn.Parameter(torch.randn(self.num_classes, latent_size))
+
+     def forward(self, batch):
+
+         z = batch['z']                             # bs latent_size
+         bs = z.shape[0]
+         class_id = batch['class']
+         ref = batch['ref']                         # bs 6
+         audio_in = batch['audio_emb']              # bs seq_len audio_emb_in_size
+         # print('audio_in: ', audio_in[:, :, :10])
+
+         audio_out = self.linear_audio(audio_in)    # bs seq_len audio_emb_out_size
+         # print('audio_out: ', audio_out[:, :, :10])
+         audio_out = audio_out.reshape([bs, -1])    # bs seq_len*audio_emb_out_size
+         class_bias = self.classbias[class_id]      # bs latent_size
+
+         z = z + class_bias
+         x_in = torch.cat([ref, z, audio_out], dim=-1)
+         x_out = self.MLP(x_in)                     # bs layer_sizes[-1]
+         x_out = x_out.reshape((bs, self.seq_len, -1))
+
+         # print('x_out: ', x_out)
+
+         pose_emb = self.resunet(x_out.unsqueeze(1))            # bs 1 seq_len 6
+
+         pose_motion_pred = self.pose_linear(pose_emb.squeeze(1))   # bs seq_len 6
+
+         batch.update({'pose_motion_pred': pose_motion_pred})
+         return batch
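The reparameterization trick used by CVAE.forward, checked numerically as a small sketch: z = mu + exp(0.5 * logvar) * eps is a differentiable sample from N(mu, exp(logvar)).

# sketch: sampled z matches the target mean and variance
import torch

mu, logvar = torch.tensor(2.0), torch.log(torch.tensor(0.25))
eps = torch.randn(100000)
z = mu + eps * torch.exp(0.5 * logvar)
print(z.mean().item(), z.var().item())   # ~2.0 and ~0.25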
src/audio2pose_models/discriminator.py ADDED
@@ -0,0 +1,76 @@
+ import torch
+ import torch.nn.functional as F
+ from torch import nn
+
+ class ConvNormRelu(nn.Module):
+     def __init__(self, conv_type='1d', in_channels=3, out_channels=64, downsample=False,
+                  kernel_size=None, stride=None, padding=None, norm='BN', leaky=False):
+         super().__init__()
+         if kernel_size is None:
+             if downsample:
+                 kernel_size, stride, padding = 4, 2, 1
+             else:
+                 kernel_size, stride, padding = 3, 1, 1
+
+         if conv_type == '2d':
+             self.conv = nn.Conv2d(
+                 in_channels,
+                 out_channels,
+                 kernel_size,
+                 stride,
+                 padding,
+                 bias=False,
+             )
+             if norm == 'BN':
+                 self.norm = nn.BatchNorm2d(out_channels)
+             elif norm == 'IN':
+                 self.norm = nn.InstanceNorm2d(out_channels)
+             else:
+                 raise NotImplementedError
+         elif conv_type == '1d':
+             self.conv = nn.Conv1d(
+                 in_channels,
+                 out_channels,
+                 kernel_size,
+                 stride,
+                 padding,
+                 bias=False,
+             )
+             if norm == 'BN':
+                 self.norm = nn.BatchNorm1d(out_channels)
+             elif norm == 'IN':
+                 self.norm = nn.InstanceNorm1d(out_channels)
+             else:
+                 raise NotImplementedError
+         nn.init.kaiming_normal_(self.conv.weight)
+
+         self.act = nn.LeakyReLU(negative_slope=0.2, inplace=False) if leaky else nn.ReLU(inplace=True)
+
+     def forward(self, x):
+         x = self.conv(x)
+         if isinstance(self.norm, nn.InstanceNorm1d):
+             x = self.norm(x.permute((0, 2, 1))).permute((0, 2, 1))  # normalize on [C]
+         else:
+             x = self.norm(x)
+         x = self.act(x)
+         return x
+
+
+ class PoseSequenceDiscriminator(nn.Module):
+     def __init__(self, cfg):
+         super().__init__()
+         self.cfg = cfg
+         leaky = self.cfg.MODEL.DISCRIMINATOR.LEAKY_RELU
+
+         self.seq = nn.Sequential(
+             ConvNormRelu('1d', cfg.MODEL.DISCRIMINATOR.INPUT_CHANNELS, 256, downsample=True, leaky=leaky),  # B, 256, 64
+             ConvNormRelu('1d', 256, 512, downsample=True, leaky=leaky),                                     # B, 512, 32
+             ConvNormRelu('1d', 512, 1024, kernel_size=3, stride=1, padding=1, leaky=leaky),                 # B, 1024, 16
+             nn.Conv1d(1024, 1, kernel_size=3, stride=1, padding=1, bias=True)                               # B, 1, 16
+         )
+
+     def forward(self, x):
+         x = x.reshape(x.size(0), x.size(1), -1).transpose(1, 2)
+         x = self.seq(x)
+         x = x.squeeze(1)
+         return x
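A minimal sketch of running PoseSequenceDiscriminator on a pose-motion clip; the nested SimpleNamespace stands in for the yacs-style cfg the class expects and is an assumption for illustration only.

# sketch: discriminator score for a (B, seq_len, 6) pose-motion clip
import torch
from types import SimpleNamespace
from src.audio2pose_models.discriminator import PoseSequenceDiscriminator

cfg = SimpleNamespace(MODEL=SimpleNamespace(
    DISCRIMINATOR=SimpleNamespace(LEAKY_RELU=False, INPUT_CHANNELS=6)))
netD = PoseSequenceDiscriminator(cfg)

pose_motion = torch.randn(4, 64, 6)   # B, seq_len, 6 pose coefficients
score = netD(pose_motion)             # reshaped to B, 6, 64 before the 1d convs
print(score.shape)                    # torch.Size([4, 16])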
src/audio2pose_models/networks.py ADDED
@@ -0,0 +1,140 @@
+ import torch.nn as nn
+ import torch
+
+
+ class ResidualConv(nn.Module):
+     def __init__(self, input_dim, output_dim, stride, padding):
+         super(ResidualConv, self).__init__()
+
+         self.conv_block = nn.Sequential(
+             nn.BatchNorm2d(input_dim),
+             nn.ReLU(),
+             nn.Conv2d(
+                 input_dim, output_dim, kernel_size=3, stride=stride, padding=padding
+             ),
+             nn.BatchNorm2d(output_dim),
+             nn.ReLU(),
+             nn.Conv2d(output_dim, output_dim, kernel_size=3, padding=1),
+         )
+         self.conv_skip = nn.Sequential(
+             nn.Conv2d(input_dim, output_dim, kernel_size=3, stride=stride, padding=1),
+             nn.BatchNorm2d(output_dim),
+         )
+
+     def forward(self, x):
+
+         return self.conv_block(x) + self.conv_skip(x)
+
+
+ class Upsample(nn.Module):
+     def __init__(self, input_dim, output_dim, kernel, stride):
+         super(Upsample, self).__init__()
+
+         self.upsample = nn.ConvTranspose2d(
+             input_dim, output_dim, kernel_size=kernel, stride=stride
+         )
+
+     def forward(self, x):
+         return self.upsample(x)
+
+
+ class Squeeze_Excite_Block(nn.Module):
+     def __init__(self, channel, reduction=16):
+         super(Squeeze_Excite_Block, self).__init__()
+         self.avg_pool = nn.AdaptiveAvgPool2d(1)
+         self.fc = nn.Sequential(
+             nn.Linear(channel, channel // reduction, bias=False),
+             nn.ReLU(inplace=True),
+             nn.Linear(channel // reduction, channel, bias=False),
+             nn.Sigmoid(),
+         )
+
+     def forward(self, x):
+         b, c, _, _ = x.size()
+         y = self.avg_pool(x).view(b, c)
+         y = self.fc(y).view(b, c, 1, 1)
+         return x * y.expand_as(x)
+
+
+ class ASPP(nn.Module):
+     def __init__(self, in_dims, out_dims, rate=[6, 12, 18]):
+         super(ASPP, self).__init__()
+
+         self.aspp_block1 = nn.Sequential(
+             nn.Conv2d(
+                 in_dims, out_dims, 3, stride=1, padding=rate[0], dilation=rate[0]
+             ),
+             nn.ReLU(inplace=True),
+             nn.BatchNorm2d(out_dims),
+         )
+         self.aspp_block2 = nn.Sequential(
+             nn.Conv2d(
+                 in_dims, out_dims, 3, stride=1, padding=rate[1], dilation=rate[1]
+             ),
+             nn.ReLU(inplace=True),
+             nn.BatchNorm2d(out_dims),
+         )
+         self.aspp_block3 = nn.Sequential(
+             nn.Conv2d(
+                 in_dims, out_dims, 3, stride=1, padding=rate[2], dilation=rate[2]
+             ),
+             nn.ReLU(inplace=True),
+             nn.BatchNorm2d(out_dims),
+         )
+
+         self.output = nn.Conv2d(len(rate) * out_dims, out_dims, 1)
+         self._init_weights()
+
+     def forward(self, x):
+         x1 = self.aspp_block1(x)
+         x2 = self.aspp_block2(x)
+         x3 = self.aspp_block3(x)
+         out = torch.cat([x1, x2, x3], dim=1)
+         return self.output(out)
+
+     def _init_weights(self):
+         for m in self.modules():
+             if isinstance(m, nn.Conv2d):
+                 nn.init.kaiming_normal_(m.weight)
+             elif isinstance(m, nn.BatchNorm2d):
+                 m.weight.data.fill_(1)
+                 m.bias.data.zero_()
+
+
+ class Upsample_(nn.Module):
+     def __init__(self, scale=2):
+         super(Upsample_, self).__init__()
+
+         self.upsample = nn.Upsample(mode="bilinear", scale_factor=scale)
+
+     def forward(self, x):
+         return self.upsample(x)
+
+
+ class AttentionBlock(nn.Module):
+     def __init__(self, input_encoder, input_decoder, output_dim):
+         super(AttentionBlock, self).__init__()
+
+         self.conv_encoder = nn.Sequential(
+             nn.BatchNorm2d(input_encoder),
+             nn.ReLU(),
+             nn.Conv2d(input_encoder, output_dim, 3, padding=1),
+             nn.MaxPool2d(2, 2),
+         )
+
+         self.conv_decoder = nn.Sequential(
+             nn.BatchNorm2d(input_decoder),
+             nn.ReLU(),
+             nn.Conv2d(input_decoder, output_dim, 3, padding=1),
+         )
+
+         self.conv_attn = nn.Sequential(
+             nn.BatchNorm2d(output_dim),
+             nn.ReLU(),
+             nn.Conv2d(output_dim, 1, 1),
+         )
+
+     def forward(self, x1, x2):
+         out = self.conv_encoder(x1) + self.conv_decoder(x2)
+         out = self.conv_attn(out)
+         return out * x2
src/audio2pose_models/res_unet.py ADDED
@@ -0,0 +1,65 @@
+ import torch
+ import torch.nn as nn
+ from src.audio2pose_models.networks import ResidualConv, Upsample
+
+
+ class ResUnet(nn.Module):
+     def __init__(self, channel=1, filters=[32, 64, 128, 256]):
+         super(ResUnet, self).__init__()
+
+         self.input_layer = nn.Sequential(
+             nn.Conv2d(channel, filters[0], kernel_size=3, padding=1),
+             nn.BatchNorm2d(filters[0]),
+             nn.ReLU(),
+             nn.Conv2d(filters[0], filters[0], kernel_size=3, padding=1),
+         )
+         self.input_skip = nn.Sequential(
+             nn.Conv2d(channel, filters[0], kernel_size=3, padding=1)
+         )
+
+         self.residual_conv_1 = ResidualConv(filters[0], filters[1], stride=(2, 1), padding=1)
+         self.residual_conv_2 = ResidualConv(filters[1], filters[2], stride=(2, 1), padding=1)
+
+         self.bridge = ResidualConv(filters[2], filters[3], stride=(2, 1), padding=1)
+
+         self.upsample_1 = Upsample(filters[3], filters[3], kernel=(2, 1), stride=(2, 1))
+         self.up_residual_conv1 = ResidualConv(filters[3] + filters[2], filters[2], stride=1, padding=1)
+
+         self.upsample_2 = Upsample(filters[2], filters[2], kernel=(2, 1), stride=(2, 1))
+         self.up_residual_conv2 = ResidualConv(filters[2] + filters[1], filters[1], stride=1, padding=1)
+
+         self.upsample_3 = Upsample(filters[1], filters[1], kernel=(2, 1), stride=(2, 1))
+         self.up_residual_conv3 = ResidualConv(filters[1] + filters[0], filters[0], stride=1, padding=1)
+
+         self.output_layer = nn.Sequential(
+             nn.Conv2d(filters[0], 1, 1, 1),
+             nn.Sigmoid(),
+         )
+
+     def forward(self, x):
+         # Encode
+         x1 = self.input_layer(x) + self.input_skip(x)
+         x2 = self.residual_conv_1(x1)
+         x3 = self.residual_conv_2(x2)
+         # Bridge
+         x4 = self.bridge(x3)
+
+         # Decode
+         x4 = self.upsample_1(x4)
+         x5 = torch.cat([x4, x3], dim=1)
+
+         x6 = self.up_residual_conv1(x5)
+
+         x6 = self.upsample_2(x6)
+         x7 = torch.cat([x6, x2], dim=1)
+
+         x8 = self.up_residual_conv2(x7)
+
+         x8 = self.upsample_3(x8)
+         x9 = torch.cat([x8, x1], dim=1)
+
+         x10 = self.up_residual_conv3(x9)
+
+         output = self.output_layer(x10)
+
+         return output
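A shape round-trip for ResUnet on a pose-motion tensor, as a sketch: the three stride-(2, 1) stages only downsample the time axis, so seq_len must be divisible by 8, while the 6 pose channels pass through untouched.

# sketch: (bs, 1, seq_len, 6) in, (bs, 1, seq_len, 6) out
import torch
from src.audio2pose_models.res_unet import ResUnet

net = ResUnet().eval()
x = torch.randn(2, 1, 32, 6)   # bs, 1, seq_len, 6
with torch.no_grad():
    y = net(x)
print(y.shape)                 # torch.Size([2, 1, 32, 6])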
src/config/auido2exp.yaml ADDED
@@ -0,0 +1,58 @@
+ DATASET:
+   TRAIN_FILE_LIST: /apdcephfs_cq2/share_1290939/wenxuazhang/code/file_list/train.txt
+   EVAL_FILE_LIST: /apdcephfs_cq2/share_1290939/wenxuazhang/code/file_list/val.txt
+   TRAIN_BATCH_SIZE: 32
+   EVAL_BATCH_SIZE: 32
+   EXP: True
+   EXP_DIM: 64
+   FRAME_LEN: 32
+   COEFF_LEN: 73
+   NUM_CLASSES: 46
+   AUDIO_ROOT_PATH: /apdcephfs_cq2/share_1290939/wenxuazhang/voxceleb1/wav
+   COEFF_ROOT_PATH: /apdcephfs_cq2/share_1290939/wenxuazhang/voxceleb1/wav2lip_3dmm
+   LMDB_PATH: /apdcephfs_cq2/share_1290939/shadowcun/datasets/VoxCeleb/v1/imdb
+   DEBUG: True
+   NUM_REPEATS: 2
+   T: 40
+
+
+ MODEL:
+   FRAMEWORK: V2
+   AUDIOENCODER:
+     LEAKY_RELU: True
+     NORM: 'IN'
+   DISCRIMINATOR:
+     LEAKY_RELU: False
+     INPUT_CHANNELS: 6
+   CVAE:
+     AUDIO_EMB_IN_SIZE: 512
+     AUDIO_EMB_OUT_SIZE: 128
+     SEQ_LEN: 32
+     LATENT_SIZE: 256
+     ENCODER_LAYER_SIZES: [192, 1024]
+     DECODER_LAYER_SIZES: [1024, 192]
+
+
+ TRAIN:
+   MAX_EPOCH: 300
+   GENERATOR:
+     LR: 2.0e-5
+   DISCRIMINATOR:
+     LR: 1.0e-5
+   LOSS:
+     W_FEAT: 0
+     W_COEFF_EXP: 2
+     W_LM: 1.0e-2
+     W_LM_MOUTH: 0
+     W_REG: 0
+     W_SYNC: 0
+     W_COLOR: 0
+     W_EXPRESSION: 0
+     W_LIPREADING: 0.01
+     W_LIPREADING_VV: 0
+     W_EYE_BLINK: 4
+
+ TAG:
+   NAME: small_dataset
+
+
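A minimal sketch of reading these configs: the attribute-style access in audio2pose.py (cfg.MODEL.CVAE.SEQ_LEN) suggests a yacs-style CfgNode, which is assumed here rather than confirmed by the commit.

# sketch: load the audio2exp config into a yacs CfgNode (assumed dependency)
from yacs.config import CfgNode as CN

with open('src/config/auido2exp.yaml') as f:
    cfg = CN.load_cfg(f)
print(cfg.MODEL.CVAE.SEQ_LEN, cfg.DATASET.NUM_CLASSES)   # 32 46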
src/config/auido2pose.yaml ADDED
@@ -0,0 +1,49 @@
+ DATASET:
+   TRAIN_FILE_LIST: /apdcephfs_cq2/share_1290939/wenxuazhang/code/audio2pose_unet_noAudio/dataset/train_33.txt
+   EVAL_FILE_LIST: /apdcephfs_cq2/share_1290939/wenxuazhang/code/audio2pose_unet_noAudio/dataset/val.txt
+   TRAIN_BATCH_SIZE: 64
+   EVAL_BATCH_SIZE: 1
+   EXP: True
+   EXP_DIM: 64
+   FRAME_LEN: 32
+   COEFF_LEN: 73
+   NUM_CLASSES: 46
+   AUDIO_ROOT_PATH: /apdcephfs_cq2/share_1290939/wenxuazhang/voxceleb1/wav
+   COEFF_ROOT_PATH: /apdcephfs_cq2/share_1290939/shadowcun/datasets/VoxCeleb/v1/imdb
+   DEBUG: True
+
+
+ MODEL:
+   AUDIOENCODER:
+     LEAKY_RELU: True
+     NORM: 'IN'
+   DISCRIMINATOR:
+     LEAKY_RELU: False
+     INPUT_CHANNELS: 6
+   CVAE:
+     AUDIO_EMB_IN_SIZE: 512
+     AUDIO_EMB_OUT_SIZE: 6
+     SEQ_LEN: 32
+     LATENT_SIZE: 64
+     ENCODER_LAYER_SIZES: [192, 128]
+     DECODER_LAYER_SIZES: [128, 192]
+
+
+ TRAIN:
+   MAX_EPOCH: 150
+   GENERATOR:
+     LR: 1.0e-4
+   DISCRIMINATOR:
+     LR: 1.0e-4
+   LOSS:
+     LAMBDA_REG: 1
+     LAMBDA_LANDMARKS: 0
+     LAMBDA_VERTICES: 0
+     LAMBDA_GAN_MOTION: 0.7
+     LAMBDA_GAN_COEFF: 0
+     LAMBDA_KL: 1
+
+ TAG:
+   NAME: cvae_UNET_useAudio_usewav2lipAudioEncoder
+
+
src/config/facerender.yaml ADDED
@@ -0,0 +1,45 @@
+ model_params:
+   common_params:
+     num_kp: 15
+     image_channel: 3
+     feature_channel: 32
+     estimate_jacobian: False   # True
+   kp_detector_params:
+     temperature: 0.1
+     block_expansion: 32
+     max_features: 1024
+     scale_factor: 0.25         # 0.25
+     num_blocks: 5
+     reshape_channel: 16384     # 16384 = 1024 * 16
+     reshape_depth: 16
+   he_estimator_params:
+     block_expansion: 64
+     max_features: 2048
+     num_bins: 66
+   generator_params:
+     block_expansion: 64
+     max_features: 512
+     num_down_blocks: 2
+     reshape_channel: 32
+     reshape_depth: 16          # 512 = 32 * 16
+     num_resblocks: 6
+     estimate_occlusion_map: True
+     dense_motion_params:
+       block_expansion: 32
+       max_features: 1024
+       num_blocks: 5
+       reshape_depth: 16
+       compress: 4
+   discriminator_params:
+     scales: [1]
+     block_expansion: 32
+     max_features: 512
+     num_blocks: 4
+     sn: True
+   mapping_params:
+     coeff_nc: 70
+     descriptor_nc: 1024
+     layer: 3
+     num_kp: 15
+     num_bins: 66
+
src/config/facerender_pirender.yaml ADDED
@@ -0,0 +1,83 @@
+ # How often do you want to log the training stats.
+ # network_list:
+ #     gen: gen_optimizer
+ #     dis: dis_optimizer
+
+ distributed: False
+ image_to_tensorboard: True
+ snapshot_save_iter: 40000
+ snapshot_save_epoch: 20
+ snapshot_save_start_iter: 20000
+ snapshot_save_start_epoch: 10
+ image_save_iter: 1000
+ max_epoch: 200
+ logging_iter: 100
+ results_dir: ./eval_results
+
+ gen_optimizer:
+     type: adam
+     lr: 0.0001
+     adam_beta1: 0.5
+     adam_beta2: 0.999
+     lr_policy:
+         iteration_mode: True
+         type: step
+         step_size: 300000
+         gamma: 0.2
+
+ trainer:
+     type: trainers.face_trainer::FaceTrainer
+     pretrain_warp_iteration: 200000
+     loss_weight:
+         weight_perceptual_warp: 2.5
+         weight_perceptual_final: 4
+     vgg_param_warp:
+         network: vgg19
+         layers: ['relu_1_1', 'relu_2_1', 'relu_3_1', 'relu_4_1', 'relu_5_1']
+         use_style_loss: False
+         num_scales: 4
+     vgg_param_final:
+         network: vgg19
+         layers: ['relu_1_1', 'relu_2_1', 'relu_3_1', 'relu_4_1', 'relu_5_1']
+         use_style_loss: True
+         num_scales: 4
+         style_to_perceptual: 250
+     init:
+         type: 'normal'
+         gain: 0.02
+ gen:
+     type: generators.face_model::FaceGenerator
+     param:
+         mapping_net:
+             coeff_nc: 73
+             descriptor_nc: 256
+             layer: 3
+         warpping_net:
+             encoder_layer: 5
+             decoder_layer: 3
+             base_nc: 32
+         editing_net:
+             layer: 3
+             num_res_blocks: 2
+             base_nc: 64
+         common:
+             image_nc: 3
+             descriptor_nc: 256
+             max_nc: 256
+             use_spect: False
+
+
+ # Data options.
+ data:
+     type: data.vox_dataset::VoxDataset
+     path: ./dataset/vox_lmdb
+     resolution: 256
+     semantic_radius: 13
+     train:
+         batch_size: 5
+         distributed: True
+     val:
+         batch_size: 8
+         distributed: True
+
+
src/config/facerender_still.yaml ADDED
@@ -0,0 +1,45 @@
+ model_params:
+   common_params:
+     num_kp: 15
+     image_channel: 3
+     feature_channel: 32
+     estimate_jacobian: False   # True
+   kp_detector_params:
+     temperature: 0.1
+     block_expansion: 32
+     max_features: 1024
+     scale_factor: 0.25         # 0.25
+     num_blocks: 5
+     reshape_channel: 16384     # 16384 = 1024 * 16
+     reshape_depth: 16
+   he_estimator_params:
+     block_expansion: 64
+     max_features: 2048
+     num_bins: 66
+   generator_params:
+     block_expansion: 64
+     max_features: 512
+     num_down_blocks: 2
+     reshape_channel: 32
+     reshape_depth: 16          # 512 = 32 * 16
+     num_resblocks: 6
+     estimate_occlusion_map: True
+     dense_motion_params:
+       block_expansion: 32
+       max_features: 1024
+       num_blocks: 5
+       reshape_depth: 16
+       compress: 4
+   discriminator_params:
+     scales: [1]
+     block_expansion: 32
+     max_features: 512
+     num_blocks: 4
+     sn: True
+   mapping_params:
+     coeff_nc: 73
+     descriptor_nc: 1024
+     layer: 3
+     num_kp: 15
+     num_bins: 66
+
src/config/similarity_Lm3D_all.mat ADDED
Binary file (994 Bytes).
src/face3d/data/__init__.py ADDED
@@ -0,0 +1,116 @@
+ """This package includes all the modules related to data loading and preprocessing
+
+ To add a custom dataset class called 'dummy', you need to add a file called 'dummy_dataset.py' and define a subclass 'DummyDataset' inherited from BaseDataset.
+ You need to implement four functions:
+     -- <__init__>: initialize the class, first call BaseDataset.__init__(self, opt).
+     -- <__len__>: return the size of dataset.
+     -- <__getitem__>: get a data point from data loader.
+     -- <modify_commandline_options>: (optionally) add dataset-specific options and set default options.
+
+ Now you can use the dataset class by specifying flag '--dataset_mode dummy'.
+ See our template dataset class 'template_dataset.py' for more details.
+ """
+ import numpy as np
+ import importlib
+ import torch.utils.data
+ from face3d.data.base_dataset import BaseDataset
+
+
+ def find_dataset_using_name(dataset_name):
+     """Import the module "data/[dataset_name]_dataset.py".
+
+     In the file, the class called DatasetNameDataset() will
+     be instantiated. It has to be a subclass of BaseDataset,
+     and it is case-insensitive.
+     """
+     dataset_filename = "data." + dataset_name + "_dataset"
+     datasetlib = importlib.import_module(dataset_filename)
+
+     dataset = None
+     target_dataset_name = dataset_name.replace('_', '') + 'dataset'
+     for name, cls in datasetlib.__dict__.items():
+         if name.lower() == target_dataset_name.lower() \
+            and issubclass(cls, BaseDataset):
+             dataset = cls
+
+     if dataset is None:
+         raise NotImplementedError("In %s.py, there should be a subclass of BaseDataset with class name that matches %s in lowercase." % (dataset_filename, target_dataset_name))
+
+     return dataset
+
+
+ def get_option_setter(dataset_name):
+     """Return the static method <modify_commandline_options> of the dataset class."""
+     dataset_class = find_dataset_using_name(dataset_name)
+     return dataset_class.modify_commandline_options
+
+
+ def create_dataset(opt, rank=0):
+     """Create a dataset given the option.
+
+     This function wraps the class CustomDatasetDataLoader.
+     This is the main interface between this package and 'train.py'/'test.py'
+
+     Example:
+         >>> from data import create_dataset
+         >>> dataset = create_dataset(opt)
+     """
+     data_loader = CustomDatasetDataLoader(opt, rank=rank)
+     dataset = data_loader.load_data()
+     return dataset
+
+ class CustomDatasetDataLoader():
+     """Wrapper class of Dataset class that performs multi-threaded data loading"""
+
+     def __init__(self, opt, rank=0):
+         """Initialize this class
+
+         Step 1: create a dataset instance given the name [dataset_mode]
+         Step 2: create a multi-threaded data loader.
+         """
+         self.opt = opt
+         dataset_class = find_dataset_using_name(opt.dataset_mode)
+         self.dataset = dataset_class(opt)
+         self.sampler = None
+         print("rank %d %s dataset [%s] was created" % (rank, self.dataset.name, type(self.dataset).__name__))
+         if opt.use_ddp and opt.isTrain:
+             world_size = opt.world_size
+             self.sampler = torch.utils.data.distributed.DistributedSampler(
+                 self.dataset,
+                 num_replicas=world_size,
+                 rank=rank,
+                 shuffle=not opt.serial_batches
+             )
+             self.dataloader = torch.utils.data.DataLoader(
+                 self.dataset,
+                 sampler=self.sampler,
+                 num_workers=int(opt.num_threads / world_size),
+                 batch_size=int(opt.batch_size / world_size),
+                 drop_last=True)
+         else:
+             self.dataloader = torch.utils.data.DataLoader(
+                 self.dataset,
+                 batch_size=opt.batch_size,
+                 shuffle=(not opt.serial_batches) and opt.isTrain,
+                 num_workers=int(opt.num_threads),
+                 drop_last=True
+             )
+
+     def set_epoch(self, epoch):
+         self.dataset.current_epoch = epoch
+         if self.sampler is not None:
+             self.sampler.set_epoch(epoch)
+
+     def load_data(self):
+         return self
+
+     def __len__(self):
+         """Return the number of data in the dataset"""
+         return min(len(self.dataset), self.opt.max_dataset_size)
+
+     def __iter__(self):
+         """Return a batch of data"""
+         for i, data in enumerate(self.dataloader):
+             if i * self.opt.batch_size >= self.opt.max_dataset_size:
+                 break
+             yield data
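The naming rule enforced by find_dataset_using_name, shown in isolation as a sketch: a '--dataset_mode' value resolves to a data/<mode>_dataset.py module and a class whose lowercased name, with underscores removed, equals '<mode>dataset'.

# sketch: which module and class name a dataset_mode value maps to
def expected_names(dataset_mode):
    module = "data." + dataset_mode + "_dataset"
    class_key = dataset_mode.replace('_', '') + 'dataset'   # compared case-insensitively
    return module, class_key

print(expected_names("flist"))      # ('data.flist_dataset', 'flistdataset')
print(expected_names("template"))   # ('data.template_dataset', 'templatedataset')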
src/face3d/data/base_dataset.py ADDED
@@ -0,0 +1,125 @@
+ """This module implements an abstract base class (ABC) 'BaseDataset' for datasets.
+
+ It also includes common transformation functions (e.g., get_transform, __scale_width), which can be later used in subclasses.
+ """
+ import random
+ import numpy as np
+ import torch.utils.data as data
+ from PIL import Image
+ import torchvision.transforms as transforms
+ from abc import ABC, abstractmethod
+
+
+ class BaseDataset(data.Dataset, ABC):
+     """This class is an abstract base class (ABC) for datasets.
+
+     To create a subclass, you need to implement the following four functions:
+     -- <__init__>: initialize the class, first call BaseDataset.__init__(self, opt).
+     -- <__len__>: return the size of dataset.
+     -- <__getitem__>: get a data point.
+     -- <modify_commandline_options>: (optionally) add dataset-specific options and set default options.
+     """
+
+     def __init__(self, opt):
+         """Initialize the class; save the options in the class
+
+         Parameters:
+             opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
+         """
+         self.opt = opt
+         # self.root = opt.dataroot
+         self.current_epoch = 0
+
+     @staticmethod
+     def modify_commandline_options(parser, is_train):
+         """Add new dataset-specific options, and rewrite default values for existing options.
+
+         Parameters:
+             parser          -- original option parser
+             is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.
+
+         Returns:
+             the modified parser.
+         """
+         return parser
+
+     @abstractmethod
+     def __len__(self):
+         """Return the total number of images in the dataset."""
+         return 0
+
+     @abstractmethod
+     def __getitem__(self, index):
+         """Return a data point and its metadata information.
+
+         Parameters:
+             index -- a random integer for data indexing
+
+         Returns:
+             a dictionary of data with their names. It usually contains the data itself and its metadata information.
+         """
+         pass
+
+
+ def get_transform(grayscale=False):
+     transform_list = []
+     if grayscale:
+         transform_list.append(transforms.Grayscale(1))
+     transform_list += [transforms.ToTensor()]
+     return transforms.Compose(transform_list)
+
+ def get_affine_mat(opt, size):
+     # rot_rad default added so a preprocess setting without 'rot' does not raise NameError
+     shift_x, shift_y, scale, rot_angle, rot_rad, flip = 0., 0., 1., 0., 0., False
+     w, h = size
+
+     if 'shift' in opt.preprocess:
+         shift_pixs = int(opt.shift_pixs)
+         shift_x = random.randint(-shift_pixs, shift_pixs)
+         shift_y = random.randint(-shift_pixs, shift_pixs)
+     if 'scale' in opt.preprocess:
+         scale = 1 + opt.scale_delta * (2 * random.random() - 1)
+     if 'rot' in opt.preprocess:
+         rot_angle = opt.rot_angle * (2 * random.random() - 1)
+         rot_rad = -rot_angle * np.pi/180
+     if 'flip' in opt.preprocess:
+         flip = random.random() > 0.5
+
+     shift_to_origin = np.array([1, 0, -w//2, 0, 1, -h//2, 0, 0, 1]).reshape([3, 3])
+     flip_mat = np.array([-1 if flip else 1, 0, 0, 0, 1, 0, 0, 0, 1]).reshape([3, 3])
+     shift_mat = np.array([1, 0, shift_x, 0, 1, shift_y, 0, 0, 1]).reshape([3, 3])
+     rot_mat = np.array([np.cos(rot_rad), np.sin(rot_rad), 0, -np.sin(rot_rad), np.cos(rot_rad), 0, 0, 0, 1]).reshape([3, 3])
+     scale_mat = np.array([scale, 0, 0, 0, scale, 0, 0, 0, 1]).reshape([3, 3])
+     shift_to_center = np.array([1, 0, w//2, 0, 1, h//2, 0, 0, 1]).reshape([3, 3])
+
+     affine = shift_to_center @ scale_mat @ rot_mat @ shift_mat @ flip_mat @ shift_to_origin
+     affine_inv = np.linalg.inv(affine)
+     return affine, affine_inv, flip
+
+ def apply_img_affine(img, affine_inv, method=Image.BICUBIC):
+     return img.transform(img.size, Image.AFFINE, data=affine_inv.flatten()[:6], resample=Image.BICUBIC)
+
+ def apply_lm_affine(landmark, affine, flip, size):
+     _, h = size
+     lm = landmark.copy()
+     lm[:, 1] = h - 1 - lm[:, 1]
+     lm = np.concatenate((lm, np.ones([lm.shape[0], 1])), -1)
+     lm = lm @ np.transpose(affine)
+     lm[:, :2] = lm[:, :2] / lm[:, 2:]
+     lm = lm[:, :2]
+     lm[:, 1] = h - 1 - lm[:, 1]
+     if flip:
+         lm_ = lm.copy()
+         lm_[:17] = lm[16::-1]
+         lm_[17:22] = lm[26:21:-1]
+         lm_[22:27] = lm[21:16:-1]
+         lm_[31:36] = lm[35:30:-1]
+         lm_[36:40] = lm[45:41:-1]
+         lm_[40:42] = lm[47:45:-1]
+         lm_[42:46] = lm[39:35:-1]
+         lm_[46:48] = lm[41:39:-1]
+         lm_[48:55] = lm[54:47:-1]
+         lm_[55:60] = lm[59:54:-1]
+         lm_[60:65] = lm[64:59:-1]
+         lm_[65:68] = lm[67:64:-1]
+         lm = lm_
+     return lm
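A small numeric sanity check for get_affine_mat's composition order, reproduced by hand as a sketch: with no augmentation selected, every factor between the two re-centering shifts is the identity, so the composed affine is the identity matrix.

# sketch: identity case of the affine composition used above
import numpy as np

w, h = 256, 256
shift_to_origin = np.array([1, 0, -w//2, 0, 1, -h//2, 0, 0, 1]).reshape(3, 3)
shift_to_center = np.array([1, 0,  w//2, 0, 1,  h//2, 0, 0, 1]).reshape(3, 3)
affine = shift_to_center @ np.eye(3) @ np.eye(3) @ np.eye(3) @ np.eye(3) @ shift_to_origin
print(np.allclose(affine, np.eye(3)))   # True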
src/face3d/data/flist_dataset.py ADDED
@@ -0,0 +1,125 @@
+ """This script defines the custom dataset for Deep3DFaceRecon_pytorch
+ """
+
+ import os.path
+ from data.base_dataset import BaseDataset, get_transform, get_affine_mat, apply_img_affine, apply_lm_affine
+ from data.image_folder import make_dataset
+ from PIL import Image
+ import random
+ import util.util as util
+ import numpy as np
+ import json
+ import torch
+ from scipy.io import loadmat, savemat
+ import pickle
+ from util.preprocess import align_img, estimate_norm
+ from util.load_mats import load_lm3d
+
+
+ def default_flist_reader(flist):
+     """
+     flist format: impath label\nimpath label\n ... (same as caffe's filelist)
+     """
+     imlist = []
+     with open(flist, 'r') as rf:
+         for line in rf.readlines():
+             impath = line.strip()
+             imlist.append(impath)
+
+     return imlist
+
+ def jason_flist_reader(flist):
+     with open(flist, 'r') as fp:
+         info = json.load(fp)
+     return info
+
+ def parse_label(label):
+     return torch.tensor(np.array(label).astype(np.float32))
+
+
+ class FlistDataset(BaseDataset):
+     """
+     It requires one directory to host training images '/path/to/data/train'.
+     You can train the model with the dataset flag '--dataroot /path/to/data'.
+     """
+
+     def __init__(self, opt):
+         """Initialize this dataset class.
+
+         Parameters:
+             opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
+         """
+         BaseDataset.__init__(self, opt)
+
+         self.lm3d_std = load_lm3d(opt.bfm_folder)
+
+         msk_names = default_flist_reader(opt.flist)
+         self.msk_paths = [os.path.join(opt.data_root, i) for i in msk_names]
+
+         self.size = len(self.msk_paths)
+         self.opt = opt
+
+         self.name = 'train' if opt.isTrain else 'val'
+         if '_' in opt.flist:
+             self.name += '_' + opt.flist.split(os.sep)[-1].split('_')[0]
+
+
+     def __getitem__(self, index):
+         """Return a data point and its metadata information.
+
+         Parameters:
+             index (int) -- a random integer for data indexing
+
+         Returns a dictionary that contains A, B, A_paths and B_paths
+             img (tensor)    -- an image in the input domain
+             msk (tensor)    -- its corresponding attention mask
+             lm (tensor)     -- its corresponding 3d landmarks
+             im_paths (str)  -- image paths
+             aug_flag (bool) -- a flag used to tell whether it is raw or augmented
+         """
+         msk_path = self.msk_paths[index % self.size]  # make sure index is within the range
+         img_path = msk_path.replace('mask/', '')
+         lm_path = '.'.join(msk_path.replace('mask', 'landmarks').split('.')[:-1]) + '.txt'
+
+         raw_img = Image.open(img_path).convert('RGB')
+         raw_msk = Image.open(msk_path).convert('RGB')
+         raw_lm = np.loadtxt(lm_path).astype(np.float32)
+
+         _, img, lm, msk = align_img(raw_img, raw_lm, self.lm3d_std, raw_msk)
+
+         aug_flag = self.opt.use_aug and self.opt.isTrain
+         if aug_flag:
+             img, lm, msk = self._augmentation(img, lm, self.opt, msk)
+
+         _, H = img.size
+         M = estimate_norm(lm, H)
+         transform = get_transform()
+         img_tensor = transform(img)
+         msk_tensor = transform(msk)[:1, ...]
+         lm_tensor = parse_label(lm)
+         M_tensor = parse_label(M)
+
+
+         return {'imgs': img_tensor,
+                 'lms': lm_tensor,
+                 'msks': msk_tensor,
+                 'M': M_tensor,
+                 'im_paths': img_path,
+                 'aug_flag': aug_flag,
+                 'dataset': self.name}
+
+     def _augmentation(self, img, lm, opt, msk=None):
+         affine, affine_inv, flip = get_affine_mat(opt, img.size)
+         img = apply_img_affine(img, affine_inv)
+         lm = apply_lm_affine(lm, affine, flip, img.size)
+         if msk is not None:
+             msk = apply_img_affine(msk, affine_inv, method=Image.BILINEAR)
+         return img, lm, msk
+
+
+
+
+     def __len__(self):
+         """Return the total number of images in the dataset.
+         """
+         return self.size
src/face3d/data/image_folder.py ADDED
@@ -0,0 +1,66 @@
+ """A modified image folder class
+
+ We modify the official PyTorch image folder (https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py)
+ so that this class can load images from both current directory and its subdirectories.
+ """
+ import numpy as np
+ import torch.utils.data as data
+
+ from PIL import Image
+ import os
+ import os.path
+
+ IMG_EXTENSIONS = [
+     '.jpg', '.JPG', '.jpeg', '.JPEG',
+     '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
+     '.tif', '.TIF', '.tiff', '.TIFF',
+ ]
+
+
+ def is_image_file(filename):
+     return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
+
+
+ def make_dataset(dir, max_dataset_size=float("inf")):
+     images = []
+     assert os.path.isdir(dir) or os.path.islink(dir), '%s is not a valid directory' % dir
+
+     for root, _, fnames in sorted(os.walk(dir, followlinks=True)):
+         for fname in fnames:
+             if is_image_file(fname):
+                 path = os.path.join(root, fname)
+                 images.append(path)
+     return images[:min(max_dataset_size, len(images))]
+
+
+ def default_loader(path):
+     return Image.open(path).convert('RGB')
+
+
+ class ImageFolder(data.Dataset):
+
+     def __init__(self, root, transform=None, return_paths=False,
+                  loader=default_loader):
+         imgs = make_dataset(root)
+         if len(imgs) == 0:
+             raise(RuntimeError("Found 0 images in: " + root + "\n"
+                                "Supported image extensions are: " + ",".join(IMG_EXTENSIONS)))
+
+         self.root = root
+         self.imgs = imgs
+         self.transform = transform
+         self.return_paths = return_paths
+         self.loader = loader
+
+     def __getitem__(self, index):
+         path = self.imgs[index]
+         img = self.loader(path)
+         if self.transform is not None:
+             img = self.transform(img)
+         if self.return_paths:
+             return img, path
+         else:
+             return img
+
+     def __len__(self):
+         return len(self.imgs)
src/face3d/data/template_dataset.py ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Dataset class template
2
+
3
+ This module provides a template for users to implement custom datasets.
4
+ You can specify '--dataset_mode template' to use this dataset.
5
+ The class name should be consistent with both the filename and its dataset_mode option.
6
+ The filename should be <dataset_mode>_dataset.py
7
+ The class name should be <Dataset_mode>Dataset.py
8
+ You need to implement the following functions:
9
+ -- <modify_commandline_options>: Add dataset-specific options and rewrite default values for existing options.
10
+ -- <__init__>: Initialize this dataset class.
11
+ -- <__getitem__>: Return a data point and its metadata information.
12
+ -- <__len__>: Return the number of images.
13
+ """
14
+ from data.base_dataset import BaseDataset, get_transform
15
+ # from data.image_folder import make_dataset
16
+ # from PIL import Image
17
+
18
+
19
+ class TemplateDataset(BaseDataset):
20
+ """A template dataset class for you to implement custom datasets."""
21
+ @staticmethod
22
+ def modify_commandline_options(parser, is_train):
23
+ """Add new dataset-specific options, and rewrite default values for existing options.
24
+
25
+ Parameters:
26
+ parser -- original option parser
27
+ is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.
28
+
29
+ Returns:
30
+ the modified parser.
31
+ """
32
+ parser.add_argument('--new_dataset_option', type=float, default=1.0, help='new dataset option')
33
+ parser.set_defaults(max_dataset_size=10, new_dataset_option=2.0) # specify dataset-specific default values
34
+ return parser
35
+
36
+ def __init__(self, opt):
37
+ """Initialize this dataset class.
38
+
39
+ Parameters:
40
+ opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
41
+
42
+ A few things can be done here.
43
+ - save the options (have been done in BaseDataset)
44
+ - get image paths and meta information of the dataset.
45
+ - define the image transformation.
46
+ """
47
+ # save the option and dataset root
48
+ BaseDataset.__init__(self, opt)
49
+ # get the image paths of your dataset;
50
+ self.image_paths = [] # You can call sorted(make_dataset(self.root, opt.max_dataset_size)) to get all the image paths under the directory self.root
51
+ # define the default transform function. You can use <base_dataset.get_transform>; You can also define your custom transform function
52
+ self.transform = get_transform(opt)
53
+
54
+ def __getitem__(self, index):
55
+ """Return a data point and its metadata information.
56
+
57
+ Parameters:
58
+ index -- a random integer for data indexing
59
+
60
+ Returns:
61
+ a dictionary of data with their names. It usually contains the data itself and its metadata information.
62
+
63
+ Step 1: get a random image path: e.g., path = self.image_paths[index]
64
+ Step 2: load your data from the disk: e.g., image = Image.open(path).convert('RGB').
65
+ Step 3: convert your data to a PyTorch tensor. You can use helper functions such as self.transform. e.g., data = self.transform(image)
66
+ Step 4: return a data point as a dictionary.
67
+ """
68
+ path = 'temp' # needs to be a string
69
+ data_A = None # needs to be a tensor
70
+ data_B = None # needs to be a tensor
71
+ return {'data_A': data_A, 'data_B': data_B, 'path': path}
72
+
73
+ def __len__(self):
74
+ """Return the total number of images."""
75
+ return len(self.image_paths)
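For reference, a hypothetical concrete dataset following the template above might look like the sketch below (the file/class name `folder_dataset.py` / `FolderDataset`, and the use of `self.root` set by `BaseDataset`, are assumptions for illustration, not code from this commit):

```python
# folder_dataset.py -- a hypothetical '--dataset_mode folder' dataset.
from PIL import Image
from data.base_dataset import BaseDataset, get_transform
from data.image_folder import make_dataset

class FolderDataset(BaseDataset):
    def __init__(self, opt):
        BaseDataset.__init__(self, opt)
        self.image_paths = sorted(make_dataset(self.root, opt.max_dataset_size))
        self.transform = get_transform(opt)

    def __getitem__(self, index):
        path = self.image_paths[index]
        image = Image.open(path).convert('RGB')
        data = self.transform(image)            # PIL image -> tensor
        return {'data_A': data, 'data_B': data, 'path': path}

    def __len__(self):
        return len(self.image_paths)
```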
src/face3d/extract_kp_videos.py ADDED
@@ -0,0 +1,108 @@
1
+ import os
2
+ import cv2
3
+ import time
4
+ import glob
5
+ import argparse
6
+ import face_alignment
7
+ import numpy as np
8
+ from PIL import Image
9
+ from tqdm import tqdm
10
+ from itertools import cycle
11
+
12
+ from torch.multiprocessing import Pool, Process, set_start_method
13
+
14
+ class KeypointExtractor():
15
+ def __init__(self, device):
16
+ self.detector = face_alignment.FaceAlignment(face_alignment.LandmarksType._2D,
17
+ device=device)
18
+
19
+ def extract_keypoint(self, images, name=None, info=True):
20
+ if isinstance(images, list):
21
+ keypoints = []
22
+ if info:
23
+ i_range = tqdm(images,desc='landmark Det:')
24
+ else:
25
+ i_range = images
26
+
27
+ for image in i_range:
28
+ current_kp = self.extract_keypoint(image)
29
+ if np.mean(current_kp) == -1 and keypoints:
30
+ keypoints.append(keypoints[-1])
31
+ else:
32
+ keypoints.append(current_kp[None])
33
+
34
+ keypoints = np.concatenate(keypoints, 0)
35
+ np.savetxt(os.path.splitext(name)[0]+'.txt', keypoints.reshape(-1))
36
+ return keypoints
37
+ else:
38
+ while True:
39
+ try:
40
+ keypoints = self.detector.get_landmarks_from_image(np.array(images))[0]
41
+ break
42
+ except RuntimeError as e:
43
+ if str(e).startswith('CUDA'):
44
+ print("Warning: out of memory, sleep for 1s")
45
+ time.sleep(1)
46
+ else:
47
+ print(e)
48
+ break
49
+ except TypeError:
50
+ print('No face detected in this image')
51
+ shape = [68, 2]
52
+ keypoints = -1. * np.ones(shape)
53
+ break
54
+ if name is not None:
55
+ np.savetxt(os.path.splitext(name)[0]+'.txt', keypoints.reshape(-1))
56
+ return keypoints
57
+
58
+ def read_video(filename):
59
+ frames = []
60
+ cap = cv2.VideoCapture(filename)
61
+ while cap.isOpened():
62
+ ret, frame = cap.read()
63
+ if ret:
64
+ frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
65
+ frame = Image.fromarray(frame)
66
+ frames.append(frame)
67
+ else:
68
+ break
69
+ cap.release()
70
+ return frames
71
+
72
+ def run(data):
73
+ filename, opt, device = data
74
+ os.environ['CUDA_VISIBLE_DEVICES'] = device
75
+     kp_extractor = KeypointExtractor(device='cuda')  # __init__ requires a device; CUDA_VISIBLE_DEVICES above pins the GPU
76
+ images = read_video(filename)
77
+ name = filename.split('/')[-2:]
78
+ os.makedirs(os.path.join(opt.output_dir, name[-2]), exist_ok=True)
79
+ kp_extractor.extract_keypoint(
80
+ images,
81
+ name=os.path.join(opt.output_dir, name[-2], name[-1])
82
+ )
83
+
84
+ if __name__ == '__main__':
85
+ set_start_method('spawn')
86
+ parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
87
+ parser.add_argument('--input_dir', type=str, help='the folder of the input files')
88
+ parser.add_argument('--output_dir', type=str, help='the folder of the output files')
89
+ parser.add_argument('--device_ids', type=str, default='0,1')
90
+ parser.add_argument('--workers', type=int, default=4)
91
+
92
+ opt = parser.parse_args()
93
+ filenames = list()
94
+ VIDEO_EXTENSIONS_LOWERCASE = {'mp4'}
95
+ VIDEO_EXTENSIONS = VIDEO_EXTENSIONS_LOWERCASE.union({f.upper() for f in VIDEO_EXTENSIONS_LOWERCASE})
96
+ extensions = VIDEO_EXTENSIONS
97
+
98
+ for ext in extensions:
99
+ os.listdir(f'{opt.input_dir}')
100
+ print(f'{opt.input_dir}/*.{ext}')
101
+ filenames = sorted(glob.glob(f'{opt.input_dir}/*.{ext}'))
102
+ print('Total number of videos:', len(filenames))
103
+ pool = Pool(opt.workers)
104
+ args_list = cycle([opt])
105
+ device_ids = opt.device_ids.split(",")
106
+ device_ids = cycle(device_ids)
107
+ for data in tqdm(pool.imap_unordered(run, zip(filenames, args_list, device_ids))):
108
+         pass
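The `__main__` block above is a small multi-GPU CLI (e.g. `python src/face3d/extract_kp_videos.py --input_dir ./videos --output_dir ./keypoints --device_ids 0,1 --workers 4`, paths being placeholders). The extractor can also be used directly; a minimal sketch, assuming the module is importable as shown:

```python
# Extract 68-point landmarks for every frame of one clip and save them as a flat .txt.
from src.face3d.extract_kp_videos import KeypointExtractor, read_video  # assumed import path

extractor = KeypointExtractor(device='cuda')
frames = read_video('clip.mp4')                            # list of PIL frames (placeholder path)
kps = extractor.extract_keypoint(frames, name='clip.mp4')  # also writes clip.txt
print(kps.shape)                                           # (num_frames, 68, 2)
```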
src/face3d/extract_kp_videos_safe.py ADDED
@@ -0,0 +1,151 @@
1
+ import os
2
+ import cv2
3
+ import time
4
+ import glob
5
+ import argparse
6
+ import numpy as np
7
+ from PIL import Image
8
+ import torch
9
+ from tqdm import tqdm
10
+ from itertools import cycle
11
+ from torch.multiprocessing import Pool, Process, set_start_method
12
+
13
+ from facexlib.alignment import landmark_98_to_68
14
+ from facexlib.detection import init_detection_model
15
+
16
+ from facexlib.utils import load_file_from_url
17
+ from facexlib.alignment.awing_arch import FAN
18
+
19
+ def init_alignment_model(model_name, half=False, device='cuda', model_rootpath=None):
20
+ if model_name == 'awing_fan':
21
+ model = FAN(num_modules=4, num_landmarks=98, device=device)
22
+ model_url = 'https://github.com/xinntao/facexlib/releases/download/v0.1.0/alignment_WFLW_4HG.pth'
23
+ else:
24
+ raise NotImplementedError(f'{model_name} is not implemented.')
25
+
26
+ model_path = load_file_from_url(
27
+ url=model_url, model_dir='facexlib/weights', progress=True, file_name=None, save_dir=model_rootpath)
28
+ model.load_state_dict(torch.load(model_path, map_location=device)['state_dict'], strict=True)
29
+ model.eval()
30
+ model = model.to(device)
31
+ return model
32
+
33
+
34
+ class KeypointExtractor():
35
+ def __init__(self, device='cuda'):
36
+
37
+ ### gfpgan/weights
38
+ try:
39
+ import webui # in webui
40
+ root_path = 'extensions/SadTalker/gfpgan/weights'
41
+
42
+ except:
43
+ root_path = 'gfpgan/weights'
44
+
45
+ self.detector = init_alignment_model('awing_fan',device=device, model_rootpath=root_path)
46
+ self.det_net = init_detection_model('retinaface_resnet50', half=False,device=device, model_rootpath=root_path)
47
+
48
+ def extract_keypoint(self, images, name=None, info=True):
49
+ if isinstance(images, list):
50
+ keypoints = []
51
+ if info:
52
+ i_range = tqdm(images,desc='landmark Det:')
53
+ else:
54
+ i_range = images
55
+
56
+ for image in i_range:
57
+ current_kp = self.extract_keypoint(image)
58
+ # current_kp = self.detector.get_landmarks(np.array(image))
59
+ if np.mean(current_kp) == -1 and keypoints:
60
+ keypoints.append(keypoints[-1])
61
+ else:
62
+ keypoints.append(current_kp[None])
63
+
64
+ keypoints = np.concatenate(keypoints, 0)
65
+ np.savetxt(os.path.splitext(name)[0]+'.txt', keypoints.reshape(-1))
66
+ return keypoints
67
+ else:
68
+ while True:
69
+ try:
70
+ with torch.no_grad():
71
+ # face detection -> face alignment.
72
+ img = np.array(images)
73
+ bboxes = self.det_net.detect_faces(images, 0.97)
74
+
75
+ bboxes = bboxes[0]
76
+ img = img[int(bboxes[1]):int(bboxes[3]), int(bboxes[0]):int(bboxes[2]), :]
77
+
78
+ keypoints = landmark_98_to_68(self.detector.get_landmarks(img)) # [0]
79
+
80
+ #### keypoints to the original location
81
+ keypoints[:,0] += int(bboxes[0])
82
+ keypoints[:,1] += int(bboxes[1])
83
+
84
+ break
85
+ except RuntimeError as e:
86
+ if str(e).startswith('CUDA'):
87
+ print("Warning: out of memory, sleep for 1s")
88
+ time.sleep(1)
89
+ else:
90
+ print(e)
91
+ break
92
+ except TypeError:
93
+ print('No face detected in this image')
94
+ shape = [68, 2]
95
+ keypoints = -1. * np.ones(shape)
96
+ break
97
+ if name is not None:
98
+ np.savetxt(os.path.splitext(name)[0]+'.txt', keypoints.reshape(-1))
99
+ return keypoints
100
+
101
+ def read_video(filename):
102
+ frames = []
103
+ cap = cv2.VideoCapture(filename)
104
+ while cap.isOpened():
105
+ ret, frame = cap.read()
106
+ if ret:
107
+ frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
108
+ frame = Image.fromarray(frame)
109
+ frames.append(frame)
110
+ else:
111
+ break
112
+ cap.release()
113
+ return frames
114
+
115
+ def run(data):
116
+ filename, opt, device = data
117
+ os.environ['CUDA_VISIBLE_DEVICES'] = device
118
+ kp_extractor = KeypointExtractor()
119
+ images = read_video(filename)
120
+ name = filename.split('/')[-2:]
121
+ os.makedirs(os.path.join(opt.output_dir, name[-2]), exist_ok=True)
122
+ kp_extractor.extract_keypoint(
123
+ images,
124
+ name=os.path.join(opt.output_dir, name[-2], name[-1])
125
+ )
126
+
127
+ if __name__ == '__main__':
128
+ set_start_method('spawn')
129
+ parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
130
+ parser.add_argument('--input_dir', type=str, help='the folder of the input files')
131
+ parser.add_argument('--output_dir', type=str, help='the folder of the output files')
132
+ parser.add_argument('--device_ids', type=str, default='0,1')
133
+ parser.add_argument('--workers', type=int, default=4)
134
+
135
+ opt = parser.parse_args()
136
+ filenames = list()
137
+ VIDEO_EXTENSIONS_LOWERCASE = {'mp4'}
138
+ VIDEO_EXTENSIONS = VIDEO_EXTENSIONS_LOWERCASE.union({f.upper() for f in VIDEO_EXTENSIONS_LOWERCASE})
139
+ extensions = VIDEO_EXTENSIONS
140
+
141
+     for ext in extensions:
142
+         # glob each extension and accumulate matches instead of overwriting the list
143
+         filenames += sorted(glob.glob(f'{opt.input_dir}/*.{ext}'))
145
+ print('Total number of videos:', len(filenames))
146
+ pool = Pool(opt.workers)
147
+ args_list = cycle([opt])
148
+ device_ids = opt.device_ids.split(",")
149
+ device_ids = cycle(device_ids)
150
+ for data in tqdm(pool.imap_unordered(run, zip(filenames, args_list, device_ids))):
151
+         pass
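Unlike `extract_kp_videos.py`, this variant first runs facexlib's RetinaFace detector and then the AWing FAN aligner on the cropped face. A minimal single-image sketch (the image path and the import path are assumptions):

```python
# Detect a face, regress 98 landmarks, and map them back to 68 points in image coordinates.
from PIL import Image
from src.face3d.extract_kp_videos_safe import KeypointExtractor  # assumed import path

extractor = KeypointExtractor(device='cuda')          # downloads facexlib weights on first use
image = Image.open('face.png').convert('RGB')         # placeholder image
kp = extractor.extract_keypoint(image, name=None)     # (68, 2) array, or all -1 if no face found
print(kp.shape)
```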
src/face3d/models/__init__.py ADDED
@@ -0,0 +1,67 @@
1
+ """This package contains modules related to objective functions, optimizations, and network architectures.
2
+
3
+ To add a custom model class called 'dummy', you need to add a file called 'dummy_model.py' and define a subclass DummyModel inherited from BaseModel.
4
+ You need to implement the following five functions:
5
+ -- <__init__>: initialize the class; first call BaseModel.__init__(self, opt).
6
+ -- <set_input>: unpack data from dataset and apply preprocessing.
7
+ -- <forward>: produce intermediate results.
8
+ -- <optimize_parameters>: calculate loss, gradients, and update network weights.
9
+ -- <modify_commandline_options>: (optionally) add model-specific options and set default options.
10
+
11
+ In the function <__init__>, you need to define four lists:
12
+ -- self.loss_names (str list): specify the training losses that you want to plot and save.
13
+ -- self.model_names (str list): define networks used in our training.
14
+ -- self.visual_names (str list): specify the images that you want to display and save.
15
+ -- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an example.
16
+
17
+ Now you can use the model class by specifying flag '--model dummy'.
18
+ See our template model class 'template_model.py' for more details.
19
+ """
20
+
21
+ import importlib
22
+ from src.face3d.models.base_model import BaseModel
23
+
24
+
25
+ def find_model_using_name(model_name):
26
+ """Import the module "models/[model_name]_model.py".
27
+
28
+ In the file, the class called DatasetNameModel() will
29
+ be instantiated. It has to be a subclass of BaseModel,
30
+ and it is case-insensitive.
31
+ """
32
+ model_filename = "face3d.models." + model_name + "_model"
33
+ modellib = importlib.import_module(model_filename)
34
+ model = None
35
+ target_model_name = model_name.replace('_', '') + 'model'
36
+ for name, cls in modellib.__dict__.items():
37
+ if name.lower() == target_model_name.lower() \
38
+ and issubclass(cls, BaseModel):
39
+ model = cls
40
+
41
+ if model is None:
42
+ print("In %s.py, there should be a subclass of BaseModel with class name that matches %s in lowercase." % (model_filename, target_model_name))
43
+ exit(0)
44
+
45
+ return model
46
+
47
+
48
+ def get_option_setter(model_name):
49
+ """Return the static method <modify_commandline_options> of the model class."""
50
+ model_class = find_model_using_name(model_name)
51
+ return model_class.modify_commandline_options
52
+
53
+
54
+ def create_model(opt):
55
+ """Create a model given the option.
56
+
57
+ This function instantiates the model class selected by opt.model.
58
+ This is the main interface between this package and 'train.py'/'test.py'
59
+
60
+ Example:
61
+ >>> from models import create_model
62
+ >>> model = create_model(opt)
63
+ """
64
+ model = find_model_using_name(opt.model)
65
+ instance = model(opt)
66
+ print("model [%s] was created" % type(instance).__name__)
67
+ return instance
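As a companion to the docstring above, a hypothetical `dummy_model.py` skeleton that `find_model_using_name('dummy')` would pick up could look like this (method bodies are placeholders, not code from this commit):

```python
# face3d/models/dummy_model.py -- class name DummyModel, selected with '--model dummy'.
from src.face3d.models.base_model import BaseModel

class DummyModel(BaseModel):
    @staticmethod
    def modify_commandline_options(parser, is_train=True):
        return parser                                  # add model-specific flags here

    def __init__(self, opt):
        BaseModel.__init__(self, opt)
        self.loss_names, self.model_names, self.visual_names = [], [], []
        self.optimizers = []

    def set_input(self, input):                        # unpack a batch from the dataloader
        self.input = input

    def forward(self):                                 # produce intermediate results
        pass

    def optimize_parameters(self):                     # compute losses, backprop, step optimizers
        pass
```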
src/face3d/models/arcface_torch/README.md ADDED
@@ -0,0 +1,164 @@
1
+ # Distributed Arcface Training in Pytorch
2
+
3
+ This is a deep learning library that makes face recognition training efficient and effective, and that can train tens of millions of
4
+ identities on a single server.
5
+
6
+ ## Requirements
7
+
8
+ - Install [pytorch](http://pytorch.org) (torch>=1.6.0); see our doc [install.md](docs/install.md).
9
+ - `pip install -r requirements.txt`.
10
+ - Download the dataset
11
+ from [https://github.com/deepinsight/insightface/tree/master/recognition/_datasets_](https://github.com/deepinsight/insightface/tree/master/recognition/_datasets_)
12
+ .
13
+
14
+ ## How to Train
15
+
16
+ To train a model, run `train.py` with the path to the configs:
17
+
18
+ ### 1. Single node, 8 GPUs:
19
+
20
+ ```shell
21
+ python -m torch.distributed.launch --nproc_per_node=8 --nnodes=1 --node_rank=0 --master_addr="127.0.0.1" --master_port=1234 train.py configs/ms1mv3_r50
22
+ ```
23
+
24
+ ### 2. Multiple nodes, each node 8 GPUs:
25
+
26
+ Node 0:
27
+
28
+ ```shell
29
+ python -m torch.distributed.launch --nproc_per_node=8 --nnodes=2 --node_rank=0 --master_addr="ip1" --master_port=1234 train.py configs/ms1mv3_r50
30
+ ```
31
+
32
+ Node 1:
33
+
34
+ ```shell
35
+ python -m torch.distributed.launch --nproc_per_node=8 --nnodes=2 --node_rank=1 --master_addr="ip1" --master_port=1234 train.py configs/ms1mv3_r50
36
+ ```
37
+
38
+ ### 3. Training resnet2060 with 8 GPUs:
39
+
40
+ ```shell
41
+ python -m torch.distributed.launch --nproc_per_node=8 --nnodes=1 --node_rank=0 --master_addr="127.0.0.1" --master_port=1234 train.py configs/ms1mv3_r2060.py
42
+ ```
43
+
44
+ ## Model Zoo
45
+
46
+ - The models are available for non-commercial research purposes only.
47
+ - All models can be found here:
48
+ - [Baidu Yun Pan](https://pan.baidu.com/s/1CL-l4zWqsI1oDuEEYVhj-g): e8pw
49
+ - [onedrive](https://1drv.ms/u/s!AswpsDO2toNKq0lWY69vN58GR6mw?e=p9Ov5d)
50
+
51
+ ### Performance on [**ICCV2021-MFR**](http://iccv21-mfr.com/)
52
+
53
+ The ICCV2021-MFR test set consists of non-celebrities, so we can ensure that it has very little overlap with publicly available face
54
+ recognition training sets such as MS1M and CASIA, which are mostly collected from online celebrities.
55
+ As a result, we can fairly evaluate the performance of different algorithms.
56
+
57
+ For the **ICCV2021-MFR-ALL** set, TAR is measured on an all-to-all 1:1 protocol, with FAR less than 0.000001 (1e-6). The
58
+ globalised multi-racial test set contains 242,143 identities and 1,624,305 images.
59
+
60
+ For the **ICCV2021-MFR-MASK** set, TAR is measured on a mask-to-nonmask 1:1 protocol, with FAR less than 0.0001 (1e-4).
61
+ The mask test set contains 6,964 identities, 6,964 masked images and 13,928 non-masked images.
62
+ In total there are 13,928 positive pairs and 96,983,824 negative pairs.
63
+
64
+ | Datasets | backbone | Training throughput | Size / MB | **ICCV2021-MFR-MASK** | **ICCV2021-MFR-ALL** |
65
+ | :---: | :--- | :--- | :--- |:--- |:--- |
66
+ | MS1MV3 | r18 | - | 91 | **47.85** | **68.33** |
67
+ | Glint360k | r18 | 8536 | 91 | **53.32** | **72.07** |
68
+ | MS1MV3 | r34 | - | 130 | **58.72** | **77.36** |
69
+ | Glint360k | r34 | 6344 | 130 | **65.10** | **83.02** |
70
+ | MS1MV3 | r50 | 5500 | 166 | **63.85** | **80.53** |
71
+ | Glint360k | r50 | 5136 | 166 | **70.23** | **87.08** |
72
+ | MS1MV3 | r100 | - | 248 | **69.09** | **84.31** |
73
+ | Glint360k | r100 | 3332 | 248 | **75.57** | **90.66** |
74
+ | MS1MV3 | mobilefacenet | 12185 | 7.8 | **41.52** | **65.26** |
75
+ | Glint360k | mobilefacenet | 11197 | 7.8 | **44.52** | **66.48** |
76
+
77
+ ### Performance on IJB-C and Verification Datasets
78
+
79
+ | Datasets | backbone | IJBC(1e-05) | IJBC(1e-04) | agedb30 | cfp_fp | lfw | log |
80
+ | :---: | :--- | :--- | :--- | :--- |:--- |:--- |:--- |
81
+ | MS1MV3 | r18 | 92.07 | 94.66 | 97.77 | 97.73 | 99.77 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/ms1mv3_arcface_r18_fp16/training.log)|
82
+ | MS1MV3 | r34 | 94.10 | 95.90 | 98.10 | 98.67 | 99.80 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/ms1mv3_arcface_r34_fp16/training.log)|
83
+ | MS1MV3 | r50 | 94.79 | 96.46 | 98.35 | 98.96 | 99.83 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/ms1mv3_arcface_r50_fp16/training.log)|
84
+ | MS1MV3 | r100 | 95.31 | 96.81 | 98.48 | 99.06 | 99.85 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/ms1mv3_arcface_r100_fp16/training.log)|
85
+ | MS1MV3 | **r2060**| 95.34 | 97.11 | 98.67 | 99.24 | 99.87 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/ms1mv3_arcface_r2060_fp16/training.log)|
86
+ | Glint360k |r18-0.1 | 93.16 | 95.33 | 97.72 | 97.73 | 99.77 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/glint360k_cosface_r18_fp16_0.1/training.log)|
87
+ | Glint360k |r34-0.1 | 95.16 | 96.56 | 98.33 | 98.78 | 99.82 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/glint360k_cosface_r34_fp16_0.1/training.log)|
88
+ | Glint360k |r50-0.1 | 95.61 | 96.97 | 98.38 | 99.20 | 99.83 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/glint360k_cosface_r50_fp16_0.1/training.log)|
89
+ | Glint360k |r100-0.1 | 95.88 | 97.32 | 98.48 | 99.29 | 99.82 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/glint360k_cosface_r100_fp16_0.1/training.log)|
90
+
91
+ [comment]: <> (More details see [model.md]&#40;docs/modelzoo.md&#41; in docs.)
92
+
93
+
94
+ ## [Speed Benchmark](docs/speed_benchmark.md)
95
+
96
+ **Arcface Torch** can train on large-scale face recognition training sets efficiently and quickly. When the number of
97
+ classes in the training set is greater than 300K and training is sufficient, the Partial FC sampling strategy reaches the same
98
+ accuracy with several times faster training and a smaller GPU memory footprint.
99
+ Partial FC is a sparse variant of the model-parallel architecture for large-scale face recognition. Partial FC uses a
100
+ sparse softmax, where each batch dynamically samples a subset of class centers for training. In each iteration, only a
101
+ sparse part of the parameters is updated, which greatly reduces GPU memory and computation. With Partial FC,
102
+ we can scale to a training set of 29 million identities, the largest to date. Partial FC also supports multi-machine distributed
103
+ training and mixed-precision training. A minimal sketch of the sampling idea is given after this file listing.
104
+
105
+ ![Image text](https://github.com/anxiangsir/insightface_arcface_log/blob/master/partial_fc_v2.png)
106
+
107
+ More details see
108
+ [speed_benchmark.md](docs/speed_benchmark.md) in docs.
109
+
110
+ ### 1. Training speed of different parallel methods (samples / second), Tesla V100 32GB * 8. (Larger is better)
111
+
112
+ `-` means training failed because of GPU memory limitations.
113
+
114
+ | Number of Identities in Dataset | Data Parallel | Model Parallel | Partial FC 0.1 |
115
+ | :--- | :--- | :--- | :--- |
116
+ |125000 | 4681 | 4824 | 5004 |
117
+ |1400000 | **1672** | 3043 | 4738 |
118
+ |5500000 | **-** | **1389** | 3975 |
119
+ |8000000 | **-** | **-** | 3565 |
120
+ |16000000 | **-** | **-** | 2679 |
121
+ |29000000 | **-** | **-** | **1855** |
122
+
123
+ ### 2. GPU memory cost of different parallel methods (MB per GPU), Tesla V100 32GB * 8. (Smaller is better)
124
+
125
+ | Number of Identities in Dataset | Data Parallel | Model Parallel | Partial FC 0.1 |
126
+ | :--- | :--- | :--- | :--- |
127
+ |125000 | 7358 | 5306 | 4868 |
128
+ |1400000 | 32252 | 11178 | 6056 |
129
+ |5500000 | **-** | 32188 | 9854 |
130
+ |8000000 | **-** | **-** | 12310 |
131
+ |16000000 | **-** | **-** | 19950 |
132
+ |29000000 | **-** | **-** | 32324 |
133
+
134
+ ## Evaluation ICCV2021-MFR and IJB-C
135
+
136
+ More details see [eval.md](docs/eval.md) in docs.
137
+
138
+ ## Test
139
+
140
+ We tested many versions of PyTorch. Please create an issue if you are having trouble.
141
+
142
+ - [x] torch 1.6.0
143
+ - [x] torch 1.7.1
144
+ - [x] torch 1.8.0
145
+ - [x] torch 1.9.0
146
+
147
+ ## Citation
148
+
149
+ ```
150
+ @inproceedings{deng2019arcface,
151
+ title={Arcface: Additive angular margin loss for deep face recognition},
152
+ author={Deng, Jiankang and Guo, Jia and Xue, Niannan and Zafeiriou, Stefanos},
153
+ booktitle={Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition},
154
+ pages={4690--4699},
155
+ year={2019}
156
+ }
157
+ @inproceedings{an2020partical_fc,
158
+ title={Partial FC: Training 10 Million Identities on a Single Machine},
159
+ author={An, Xiang and Zhu, Xuhan and Xiao, Yang and Wu, Lan and Zhang, Ming and Gao, Yuan and Qin, Bin and
160
+ Zhang, Debing and Fu Ying},
161
+ booktitle={Arxiv 2010.05222},
162
+ year={2020}
163
+ }
164
+ ```
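As a rough illustration of the Partial FC sampling idea described in the Speed Benchmark section (a hypothetical sketch of the concept only, not the repository's implementation):

```python
import torch
import torch.nn.functional as F

def partial_fc_step(features, labels, weights, sample_rate=0.1):
    """One sketched step: softmax over a sampled subset of class centers."""
    num_classes = weights.shape[0]
    positive = torch.unique(labels)                           # centers that must be kept
    num_sample = max(int(num_classes * sample_rate), positive.numel())
    perm = torch.randperm(num_classes, device=weights.device)
    perm = perm[~torch.isin(perm, positive)][: num_sample - positive.numel()]
    index, _ = torch.sort(torch.cat([positive, perm]))        # sampled center indices
    sub_weights = weights[index]                              # sparse slice of the FC weight
    sub_labels = torch.searchsorted(index, labels)            # remap labels into sampled space
    logits = F.linear(F.normalize(features), F.normalize(sub_weights))
    return F.cross_entropy(logits, sub_labels)
```

Only `sub_weights` takes part in the forward/backward pass, so memory and compute scale with the sample rate rather than with the full number of identities.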
src/face3d/models/arcface_torch/backbones/__init__.py ADDED
@@ -0,0 +1,25 @@
1
+ from .iresnet import iresnet18, iresnet34, iresnet50, iresnet100, iresnet200
2
+ from .mobilefacenet import get_mbf
3
+
4
+
5
+ def get_model(name, **kwargs):
6
+ # resnet
7
+ if name == "r18":
8
+ return iresnet18(False, **kwargs)
9
+ elif name == "r34":
10
+ return iresnet34(False, **kwargs)
11
+ elif name == "r50":
12
+ return iresnet50(False, **kwargs)
13
+ elif name == "r100":
14
+ return iresnet100(False, **kwargs)
15
+ elif name == "r200":
16
+ return iresnet200(False, **kwargs)
17
+ elif name == "r2060":
18
+ from .iresnet2060 import iresnet2060
19
+ return iresnet2060(False, **kwargs)
20
+ elif name == "mbf":
21
+ fp16 = kwargs.get("fp16", False)
22
+ num_features = kwargs.get("num_features", 512)
23
+ return get_mbf(fp16=fp16, num_features=num_features)
24
+ else:
25
+ raise ValueError()
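A short usage sketch of the `get_model` factory above (the 112x112 input follows from the 7x7 `fc_scale` of the IResNet backbones; the import path is an assumption):

```python
# Build a backbone by name and embed a dummy batch of face crops.
import torch
from src.face3d.models.arcface_torch.backbones import get_model  # assumed import path

net = get_model("r18", fp16=False, num_features=512)
net.eval()
with torch.no_grad():
    emb = net(torch.randn(2, 3, 112, 112))
print(emb.shape)                                    # torch.Size([2, 512])
```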
src/face3d/models/arcface_torch/backbones/iresnet.py ADDED
@@ -0,0 +1,187 @@
1
+ import torch
2
+ from torch import nn
3
+
4
+ __all__ = ['iresnet18', 'iresnet34', 'iresnet50', 'iresnet100', 'iresnet200']
5
+
6
+
7
+ def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
8
+ """3x3 convolution with padding"""
9
+ return nn.Conv2d(in_planes,
10
+ out_planes,
11
+ kernel_size=3,
12
+ stride=stride,
13
+ padding=dilation,
14
+ groups=groups,
15
+ bias=False,
16
+ dilation=dilation)
17
+
18
+
19
+ def conv1x1(in_planes, out_planes, stride=1):
20
+ """1x1 convolution"""
21
+ return nn.Conv2d(in_planes,
22
+ out_planes,
23
+ kernel_size=1,
24
+ stride=stride,
25
+ bias=False)
26
+
27
+
28
+ class IBasicBlock(nn.Module):
29
+ expansion = 1
30
+ def __init__(self, inplanes, planes, stride=1, downsample=None,
31
+ groups=1, base_width=64, dilation=1):
32
+ super(IBasicBlock, self).__init__()
33
+ if groups != 1 or base_width != 64:
34
+ raise ValueError('BasicBlock only supports groups=1 and base_width=64')
35
+ if dilation > 1:
36
+ raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
37
+ self.bn1 = nn.BatchNorm2d(inplanes, eps=1e-05,)
38
+ self.conv1 = conv3x3(inplanes, planes)
39
+ self.bn2 = nn.BatchNorm2d(planes, eps=1e-05,)
40
+ self.prelu = nn.PReLU(planes)
41
+ self.conv2 = conv3x3(planes, planes, stride)
42
+ self.bn3 = nn.BatchNorm2d(planes, eps=1e-05,)
43
+ self.downsample = downsample
44
+ self.stride = stride
45
+
46
+ def forward(self, x):
47
+ identity = x
48
+ out = self.bn1(x)
49
+ out = self.conv1(out)
50
+ out = self.bn2(out)
51
+ out = self.prelu(out)
52
+ out = self.conv2(out)
53
+ out = self.bn3(out)
54
+ if self.downsample is not None:
55
+ identity = self.downsample(x)
56
+ out += identity
57
+ return out
58
+
59
+
60
+ class IResNet(nn.Module):
61
+ fc_scale = 7 * 7
62
+ def __init__(self,
63
+ block, layers, dropout=0, num_features=512, zero_init_residual=False,
64
+ groups=1, width_per_group=64, replace_stride_with_dilation=None, fp16=False):
65
+ super(IResNet, self).__init__()
66
+ self.fp16 = fp16
67
+ self.inplanes = 64
68
+ self.dilation = 1
69
+ if replace_stride_with_dilation is None:
70
+ replace_stride_with_dilation = [False, False, False]
71
+ if len(replace_stride_with_dilation) != 3:
72
+ raise ValueError("replace_stride_with_dilation should be None "
73
+ "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
74
+ self.groups = groups
75
+ self.base_width = width_per_group
76
+ self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False)
77
+ self.bn1 = nn.BatchNorm2d(self.inplanes, eps=1e-05)
78
+ self.prelu = nn.PReLU(self.inplanes)
79
+ self.layer1 = self._make_layer(block, 64, layers[0], stride=2)
80
+ self.layer2 = self._make_layer(block,
81
+ 128,
82
+ layers[1],
83
+ stride=2,
84
+ dilate=replace_stride_with_dilation[0])
85
+ self.layer3 = self._make_layer(block,
86
+ 256,
87
+ layers[2],
88
+ stride=2,
89
+ dilate=replace_stride_with_dilation[1])
90
+ self.layer4 = self._make_layer(block,
91
+ 512,
92
+ layers[3],
93
+ stride=2,
94
+ dilate=replace_stride_with_dilation[2])
95
+ self.bn2 = nn.BatchNorm2d(512 * block.expansion, eps=1e-05,)
96
+ self.dropout = nn.Dropout(p=dropout, inplace=True)
97
+ self.fc = nn.Linear(512 * block.expansion * self.fc_scale, num_features)
98
+ self.features = nn.BatchNorm1d(num_features, eps=1e-05)
99
+ nn.init.constant_(self.features.weight, 1.0)
100
+ self.features.weight.requires_grad = False
101
+
102
+ for m in self.modules():
103
+ if isinstance(m, nn.Conv2d):
104
+ nn.init.normal_(m.weight, 0, 0.1)
105
+ elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
106
+ nn.init.constant_(m.weight, 1)
107
+ nn.init.constant_(m.bias, 0)
108
+
109
+ if zero_init_residual:
110
+ for m in self.modules():
111
+ if isinstance(m, IBasicBlock):
112
+ nn.init.constant_(m.bn2.weight, 0)
113
+
114
+ def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
115
+ downsample = None
116
+ previous_dilation = self.dilation
117
+ if dilate:
118
+ self.dilation *= stride
119
+ stride = 1
120
+ if stride != 1 or self.inplanes != planes * block.expansion:
121
+ downsample = nn.Sequential(
122
+ conv1x1(self.inplanes, planes * block.expansion, stride),
123
+ nn.BatchNorm2d(planes * block.expansion, eps=1e-05, ),
124
+ )
125
+ layers = []
126
+ layers.append(
127
+ block(self.inplanes, planes, stride, downsample, self.groups,
128
+ self.base_width, previous_dilation))
129
+ self.inplanes = planes * block.expansion
130
+ for _ in range(1, blocks):
131
+ layers.append(
132
+ block(self.inplanes,
133
+ planes,
134
+ groups=self.groups,
135
+ base_width=self.base_width,
136
+ dilation=self.dilation))
137
+
138
+ return nn.Sequential(*layers)
139
+
140
+ def forward(self, x):
141
+ with torch.cuda.amp.autocast(self.fp16):
142
+ x = self.conv1(x)
143
+ x = self.bn1(x)
144
+ x = self.prelu(x)
145
+ x = self.layer1(x)
146
+ x = self.layer2(x)
147
+ x = self.layer3(x)
148
+ x = self.layer4(x)
149
+ x = self.bn2(x)
150
+ x = torch.flatten(x, 1)
151
+ x = self.dropout(x)
152
+ x = self.fc(x.float() if self.fp16 else x)
153
+ x = self.features(x)
154
+ return x
155
+
156
+
157
+ def _iresnet(arch, block, layers, pretrained, progress, **kwargs):
158
+ model = IResNet(block, layers, **kwargs)
159
+ if pretrained:
160
+ raise ValueError()
161
+ return model
162
+
163
+
164
+ def iresnet18(pretrained=False, progress=True, **kwargs):
165
+ return _iresnet('iresnet18', IBasicBlock, [2, 2, 2, 2], pretrained,
166
+ progress, **kwargs)
167
+
168
+
169
+ def iresnet34(pretrained=False, progress=True, **kwargs):
170
+ return _iresnet('iresnet34', IBasicBlock, [3, 4, 6, 3], pretrained,
171
+ progress, **kwargs)
172
+
173
+
174
+ def iresnet50(pretrained=False, progress=True, **kwargs):
175
+ return _iresnet('iresnet50', IBasicBlock, [3, 4, 14, 3], pretrained,
176
+ progress, **kwargs)
177
+
178
+
179
+ def iresnet100(pretrained=False, progress=True, **kwargs):
180
+ return _iresnet('iresnet100', IBasicBlock, [3, 13, 30, 3], pretrained,
181
+ progress, **kwargs)
182
+
183
+
184
+ def iresnet200(pretrained=False, progress=True, **kwargs):
185
+ return _iresnet('iresnet200', IBasicBlock, [6, 26, 60, 6], pretrained,
186
+ progress, **kwargs)
187
+
src/face3d/models/arcface_torch/backbones/iresnet2060.py ADDED
@@ -0,0 +1,176 @@
1
+ import torch
2
+ from torch import nn
3
+
4
+ assert torch.__version__ >= "1.8.1"
5
+ from torch.utils.checkpoint import checkpoint_sequential
6
+
7
+ __all__ = ['iresnet2060']
8
+
9
+
10
+ def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
11
+ """3x3 convolution with padding"""
12
+ return nn.Conv2d(in_planes,
13
+ out_planes,
14
+ kernel_size=3,
15
+ stride=stride,
16
+ padding=dilation,
17
+ groups=groups,
18
+ bias=False,
19
+ dilation=dilation)
20
+
21
+
22
+ def conv1x1(in_planes, out_planes, stride=1):
23
+ """1x1 convolution"""
24
+ return nn.Conv2d(in_planes,
25
+ out_planes,
26
+ kernel_size=1,
27
+ stride=stride,
28
+ bias=False)
29
+
30
+
31
+ class IBasicBlock(nn.Module):
32
+ expansion = 1
33
+
34
+ def __init__(self, inplanes, planes, stride=1, downsample=None,
35
+ groups=1, base_width=64, dilation=1):
36
+ super(IBasicBlock, self).__init__()
37
+ if groups != 1 or base_width != 64:
38
+ raise ValueError('BasicBlock only supports groups=1 and base_width=64')
39
+ if dilation > 1:
40
+ raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
41
+ self.bn1 = nn.BatchNorm2d(inplanes, eps=1e-05, )
42
+ self.conv1 = conv3x3(inplanes, planes)
43
+ self.bn2 = nn.BatchNorm2d(planes, eps=1e-05, )
44
+ self.prelu = nn.PReLU(planes)
45
+ self.conv2 = conv3x3(planes, planes, stride)
46
+ self.bn3 = nn.BatchNorm2d(planes, eps=1e-05, )
47
+ self.downsample = downsample
48
+ self.stride = stride
49
+
50
+ def forward(self, x):
51
+ identity = x
52
+ out = self.bn1(x)
53
+ out = self.conv1(out)
54
+ out = self.bn2(out)
55
+ out = self.prelu(out)
56
+ out = self.conv2(out)
57
+ out = self.bn3(out)
58
+ if self.downsample is not None:
59
+ identity = self.downsample(x)
60
+ out += identity
61
+ return out
62
+
63
+
64
+ class IResNet(nn.Module):
65
+ fc_scale = 7 * 7
66
+
67
+ def __init__(self,
68
+ block, layers, dropout=0, num_features=512, zero_init_residual=False,
69
+ groups=1, width_per_group=64, replace_stride_with_dilation=None, fp16=False):
70
+ super(IResNet, self).__init__()
71
+ self.fp16 = fp16
72
+ self.inplanes = 64
73
+ self.dilation = 1
74
+ if replace_stride_with_dilation is None:
75
+ replace_stride_with_dilation = [False, False, False]
76
+ if len(replace_stride_with_dilation) != 3:
77
+ raise ValueError("replace_stride_with_dilation should be None "
78
+ "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
79
+ self.groups = groups
80
+ self.base_width = width_per_group
81
+ self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False)
82
+ self.bn1 = nn.BatchNorm2d(self.inplanes, eps=1e-05)
83
+ self.prelu = nn.PReLU(self.inplanes)
84
+ self.layer1 = self._make_layer(block, 64, layers[0], stride=2)
85
+ self.layer2 = self._make_layer(block,
86
+ 128,
87
+ layers[1],
88
+ stride=2,
89
+ dilate=replace_stride_with_dilation[0])
90
+ self.layer3 = self._make_layer(block,
91
+ 256,
92
+ layers[2],
93
+ stride=2,
94
+ dilate=replace_stride_with_dilation[1])
95
+ self.layer4 = self._make_layer(block,
96
+ 512,
97
+ layers[3],
98
+ stride=2,
99
+ dilate=replace_stride_with_dilation[2])
100
+ self.bn2 = nn.BatchNorm2d(512 * block.expansion, eps=1e-05, )
101
+ self.dropout = nn.Dropout(p=dropout, inplace=True)
102
+ self.fc = nn.Linear(512 * block.expansion * self.fc_scale, num_features)
103
+ self.features = nn.BatchNorm1d(num_features, eps=1e-05)
104
+ nn.init.constant_(self.features.weight, 1.0)
105
+ self.features.weight.requires_grad = False
106
+
107
+ for m in self.modules():
108
+ if isinstance(m, nn.Conv2d):
109
+ nn.init.normal_(m.weight, 0, 0.1)
110
+ elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
111
+ nn.init.constant_(m.weight, 1)
112
+ nn.init.constant_(m.bias, 0)
113
+
114
+ if zero_init_residual:
115
+ for m in self.modules():
116
+ if isinstance(m, IBasicBlock):
117
+ nn.init.constant_(m.bn2.weight, 0)
118
+
119
+ def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
120
+ downsample = None
121
+ previous_dilation = self.dilation
122
+ if dilate:
123
+ self.dilation *= stride
124
+ stride = 1
125
+ if stride != 1 or self.inplanes != planes * block.expansion:
126
+ downsample = nn.Sequential(
127
+ conv1x1(self.inplanes, planes * block.expansion, stride),
128
+ nn.BatchNorm2d(planes * block.expansion, eps=1e-05, ),
129
+ )
130
+ layers = []
131
+ layers.append(
132
+ block(self.inplanes, planes, stride, downsample, self.groups,
133
+ self.base_width, previous_dilation))
134
+ self.inplanes = planes * block.expansion
135
+ for _ in range(1, blocks):
136
+ layers.append(
137
+ block(self.inplanes,
138
+ planes,
139
+ groups=self.groups,
140
+ base_width=self.base_width,
141
+ dilation=self.dilation))
142
+
143
+ return nn.Sequential(*layers)
144
+
145
+ def checkpoint(self, func, num_seg, x):
146
+ if self.training:
147
+ return checkpoint_sequential(func, num_seg, x)
148
+ else:
149
+ return func(x)
150
+
151
+ def forward(self, x):
152
+ with torch.cuda.amp.autocast(self.fp16):
153
+ x = self.conv1(x)
154
+ x = self.bn1(x)
155
+ x = self.prelu(x)
156
+ x = self.layer1(x)
157
+ x = self.checkpoint(self.layer2, 20, x)
158
+ x = self.checkpoint(self.layer3, 100, x)
159
+ x = self.layer4(x)
160
+ x = self.bn2(x)
161
+ x = torch.flatten(x, 1)
162
+ x = self.dropout(x)
163
+ x = self.fc(x.float() if self.fp16 else x)
164
+ x = self.features(x)
165
+ return x
166
+
167
+
168
+ def _iresnet(arch, block, layers, pretrained, progress, **kwargs):
169
+ model = IResNet(block, layers, **kwargs)
170
+ if pretrained:
171
+ raise ValueError()
172
+ return model
173
+
174
+
175
+ def iresnet2060(pretrained=False, progress=True, **kwargs):
176
+ return _iresnet('iresnet2060', IBasicBlock, [3, 128, 1024 - 128, 3], pretrained, progress, **kwargs)
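The `checkpoint` helper above trades compute for memory: during training, `checkpoint_sequential` discards the intermediate activations of `layer2`/`layer3` and recomputes them in the backward pass, which is what lets a 2060-layer network fit in GPU memory. A self-contained sketch of the same mechanism with toy modules (not the actual backbone):

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint_sequential

blocks = nn.Sequential(*[nn.Linear(128, 128) for _ in range(16)])
x = torch.randn(4, 128, requires_grad=True)
out = checkpoint_sequential(blocks, 4, x)   # run in 4 segments, re-run on backward
out.sum().backward()
```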
src/face3d/models/arcface_torch/backbones/mobilefacenet.py ADDED
@@ -0,0 +1,130 @@
1
+ '''
2
+ Adapted from https://github.com/cavalleria/cavaface.pytorch/blob/master/backbone/mobilefacenet.py
3
+ Original author cavalleria
4
+ '''
5
+
6
+ import torch.nn as nn
7
+ from torch.nn import Linear, Conv2d, BatchNorm1d, BatchNorm2d, PReLU, Sequential, Module
8
+ import torch
9
+
10
+
11
+ class Flatten(Module):
12
+ def forward(self, x):
13
+ return x.view(x.size(0), -1)
14
+
15
+
16
+ class ConvBlock(Module):
17
+ def __init__(self, in_c, out_c, kernel=(1, 1), stride=(1, 1), padding=(0, 0), groups=1):
18
+ super(ConvBlock, self).__init__()
19
+ self.layers = nn.Sequential(
20
+ Conv2d(in_c, out_c, kernel, groups=groups, stride=stride, padding=padding, bias=False),
21
+ BatchNorm2d(num_features=out_c),
22
+ PReLU(num_parameters=out_c)
23
+ )
24
+
25
+ def forward(self, x):
26
+ return self.layers(x)
27
+
28
+
29
+ class LinearBlock(Module):
30
+ def __init__(self, in_c, out_c, kernel=(1, 1), stride=(1, 1), padding=(0, 0), groups=1):
31
+ super(LinearBlock, self).__init__()
32
+ self.layers = nn.Sequential(
33
+ Conv2d(in_c, out_c, kernel, stride, padding, groups=groups, bias=False),
34
+ BatchNorm2d(num_features=out_c)
35
+ )
36
+
37
+ def forward(self, x):
38
+ return self.layers(x)
39
+
40
+
41
+ class DepthWise(Module):
42
+ def __init__(self, in_c, out_c, residual=False, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=1):
43
+ super(DepthWise, self).__init__()
44
+ self.residual = residual
45
+ self.layers = nn.Sequential(
46
+ ConvBlock(in_c, out_c=groups, kernel=(1, 1), padding=(0, 0), stride=(1, 1)),
47
+ ConvBlock(groups, groups, groups=groups, kernel=kernel, padding=padding, stride=stride),
48
+ LinearBlock(groups, out_c, kernel=(1, 1), padding=(0, 0), stride=(1, 1))
49
+ )
50
+
51
+ def forward(self, x):
52
+ short_cut = None
53
+ if self.residual:
54
+ short_cut = x
55
+ x = self.layers(x)
56
+ if self.residual:
57
+ output = short_cut + x
58
+ else:
59
+ output = x
60
+ return output
61
+
62
+
63
+ class Residual(Module):
64
+ def __init__(self, c, num_block, groups, kernel=(3, 3), stride=(1, 1), padding=(1, 1)):
65
+ super(Residual, self).__init__()
66
+ modules = []
67
+ for _ in range(num_block):
68
+ modules.append(DepthWise(c, c, True, kernel, stride, padding, groups))
69
+ self.layers = Sequential(*modules)
70
+
71
+ def forward(self, x):
72
+ return self.layers(x)
73
+
74
+
75
+ class GDC(Module):
76
+ def __init__(self, embedding_size):
77
+ super(GDC, self).__init__()
78
+ self.layers = nn.Sequential(
79
+ LinearBlock(512, 512, groups=512, kernel=(7, 7), stride=(1, 1), padding=(0, 0)),
80
+ Flatten(),
81
+ Linear(512, embedding_size, bias=False),
82
+ BatchNorm1d(embedding_size))
83
+
84
+ def forward(self, x):
85
+ return self.layers(x)
86
+
87
+
88
+ class MobileFaceNet(Module):
89
+ def __init__(self, fp16=False, num_features=512):
90
+ super(MobileFaceNet, self).__init__()
91
+ scale = 2
92
+ self.fp16 = fp16
93
+ self.layers = nn.Sequential(
94
+ ConvBlock(3, 64 * scale, kernel=(3, 3), stride=(2, 2), padding=(1, 1)),
95
+ ConvBlock(64 * scale, 64 * scale, kernel=(3, 3), stride=(1, 1), padding=(1, 1), groups=64),
96
+ DepthWise(64 * scale, 64 * scale, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=128),
97
+ Residual(64 * scale, num_block=4, groups=128, kernel=(3, 3), stride=(1, 1), padding=(1, 1)),
98
+ DepthWise(64 * scale, 128 * scale, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=256),
99
+ Residual(128 * scale, num_block=6, groups=256, kernel=(3, 3), stride=(1, 1), padding=(1, 1)),
100
+ DepthWise(128 * scale, 128 * scale, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=512),
101
+ Residual(128 * scale, num_block=2, groups=256, kernel=(3, 3), stride=(1, 1), padding=(1, 1)),
102
+ )
103
+ self.conv_sep = ConvBlock(128 * scale, 512, kernel=(1, 1), stride=(1, 1), padding=(0, 0))
104
+ self.features = GDC(num_features)
105
+ self._initialize_weights()
106
+
107
+ def _initialize_weights(self):
108
+ for m in self.modules():
109
+ if isinstance(m, nn.Conv2d):
110
+ nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
111
+ if m.bias is not None:
112
+ m.bias.data.zero_()
113
+ elif isinstance(m, nn.BatchNorm2d):
114
+ m.weight.data.fill_(1)
115
+ m.bias.data.zero_()
116
+ elif isinstance(m, nn.Linear):
117
+ nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
118
+ if m.bias is not None:
119
+ m.bias.data.zero_()
120
+
121
+ def forward(self, x):
122
+ with torch.cuda.amp.autocast(self.fp16):
123
+ x = self.layers(x)
124
+ x = self.conv_sep(x.float() if self.fp16 else x)
125
+ x = self.features(x)
126
+ return x
127
+
128
+
129
+ def get_mbf(fp16, num_features):
130
+ return MobileFaceNet(fp16, num_features)
src/face3d/models/arcface_torch/configs/3millions.py ADDED
@@ -0,0 +1,23 @@
1
+ from easydict import EasyDict as edict
2
+
3
+ # configs for test speed
4
+
5
+ config = edict()
6
+ config.loss = "arcface"
7
+ config.network = "r50"
8
+ config.resume = False
9
+ config.output = None
10
+ config.embedding_size = 512
11
+ config.sample_rate = 1.0
12
+ config.fp16 = True
13
+ config.momentum = 0.9
14
+ config.weight_decay = 5e-4
15
+ config.batch_size = 128
16
+ config.lr = 0.1 # batch size is 512
17
+
18
+ config.rec = "synthetic"
19
+ config.num_classes = 300 * 10000
20
+ config.num_epoch = 30
21
+ config.warmup_epoch = -1
22
+ config.decay_epoch = [10, 16, 22]
23
+ config.val_targets = []
src/face3d/models/arcface_torch/configs/3millions_pfc.py ADDED
@@ -0,0 +1,23 @@
1
+ from easydict import EasyDict as edict
2
+
3
+ # configs for test speed
4
+
5
+ config = edict()
6
+ config.loss = "arcface"
7
+ config.network = "r50"
8
+ config.resume = False
9
+ config.output = None
10
+ config.embedding_size = 512
11
+ config.sample_rate = 0.1
12
+ config.fp16 = True
13
+ config.momentum = 0.9
14
+ config.weight_decay = 5e-4
15
+ config.batch_size = 128
16
+ config.lr = 0.1 # batch size is 512
17
+
18
+ config.rec = "synthetic"
19
+ config.num_classes = 300 * 10000
20
+ config.num_epoch = 30
21
+ config.warmup_epoch = -1
22
+ config.decay_epoch = [10, 16, 22]
23
+ config.val_targets = []
src/face3d/models/arcface_torch/configs/__init__.py ADDED
File without changes
src/face3d/models/arcface_torch/configs/base.py ADDED
@@ -0,0 +1,56 @@
1
+ from easydict import EasyDict as edict
2
+
3
+ # make training faster
4
+ # our RAM is 256G
5
+ # mount -t tmpfs -o size=140G tmpfs /train_tmp
6
+
7
+ config = edict()
8
+ config.loss = "arcface"
9
+ config.network = "r50"
10
+ config.resume = False
11
+ config.output = "ms1mv3_arcface_r50"
12
+
13
+ config.dataset = "ms1m-retinaface-t1"
14
+ config.embedding_size = 512
15
+ config.sample_rate = 1
16
+ config.fp16 = False
17
+ config.momentum = 0.9
18
+ config.weight_decay = 5e-4
19
+ config.batch_size = 128
20
+ config.lr = 0.1 # batch size is 512
21
+
22
+ if config.dataset == "emore":
23
+ config.rec = "/train_tmp/faces_emore"
24
+ config.num_classes = 85742
25
+ config.num_image = 5822653
26
+ config.num_epoch = 16
27
+ config.warmup_epoch = -1
28
+ config.decay_epoch = [8, 14, ]
29
+ config.val_targets = ["lfw", ]
30
+
31
+ elif config.dataset == "ms1m-retinaface-t1":
32
+ config.rec = "/train_tmp/ms1m-retinaface-t1"
33
+ config.num_classes = 93431
34
+ config.num_image = 5179510
35
+ config.num_epoch = 25
36
+ config.warmup_epoch = -1
37
+ config.decay_epoch = [11, 17, 22]
38
+ config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
39
+
40
+ elif config.dataset == "glint360k":
41
+ config.rec = "/train_tmp/glint360k"
42
+ config.num_classes = 360232
43
+ config.num_image = 17091657
44
+ config.num_epoch = 20
45
+ config.warmup_epoch = -1
46
+ config.decay_epoch = [8, 12, 15, 18]
47
+ config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
48
+
49
+ elif config.dataset == "webface":
50
+ config.rec = "/train_tmp/faces_webface_112x112"
51
+ config.num_classes = 10572
52
+ config.num_image = "forget"
53
+ config.num_epoch = 34
54
+ config.warmup_epoch = -1
55
+ config.decay_epoch = [20, 28, 32]
56
+ config.val_targets = ["lfw", "cfp_fp", "agedb_30"]