Spanicin commited on
Commit
ebebd31
1 Parent(s): f40faa5

Delete src

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. src/audio2exp_models/__pycache__/audio2exp.cpython-38.pyc +0 -0
  2. src/audio2exp_models/__pycache__/networks.cpython-38.pyc +0 -0
  3. src/audio2exp_models/audio2exp.py +0 -41
  4. src/audio2exp_models/networks.py +0 -74
  5. src/audio2pose_models/__pycache__/audio2pose.cpython-38.pyc +0 -0
  6. src/audio2pose_models/__pycache__/audio_encoder.cpython-38.pyc +0 -0
  7. src/audio2pose_models/__pycache__/cvae.cpython-38.pyc +0 -0
  8. src/audio2pose_models/__pycache__/discriminator.cpython-38.pyc +0 -0
  9. src/audio2pose_models/__pycache__/networks.cpython-38.pyc +0 -0
  10. src/audio2pose_models/__pycache__/res_unet.cpython-38.pyc +0 -0
  11. src/audio2pose_models/audio2pose.py +0 -94
  12. src/audio2pose_models/audio_encoder.py +0 -64
  13. src/audio2pose_models/cvae.py +0 -149
  14. src/audio2pose_models/discriminator.py +0 -76
  15. src/audio2pose_models/networks.py +0 -140
  16. src/audio2pose_models/res_unet.py +0 -65
  17. src/config/auido2exp.yaml +0 -58
  18. src/config/auido2pose.yaml +0 -49
  19. src/config/facerender.yaml +0 -45
  20. src/config/facerender_still.yaml +0 -45
  21. src/face3d/__pycache__/extract_kp_videos.cpython-38.pyc +0 -0
  22. src/face3d/__pycache__/extract_kp_videos.cpython-39.pyc +0 -0
  23. src/face3d/data/__init__.py +0 -116
  24. src/face3d/data/base_dataset.py +0 -125
  25. src/face3d/data/flist_dataset.py +0 -125
  26. src/face3d/data/image_folder.py +0 -66
  27. src/face3d/data/template_dataset.py +0 -75
  28. src/face3d/extract_kp_videos.py +0 -108
  29. src/face3d/extract_kp_videos_safe.py +0 -138
  30. src/face3d/models/__init__.py +0 -67
  31. src/face3d/models/__pycache__/__init__.cpython-38.pyc +0 -0
  32. src/face3d/models/__pycache__/__init__.cpython-39.pyc +0 -0
  33. src/face3d/models/__pycache__/base_model.cpython-38.pyc +0 -0
  34. src/face3d/models/__pycache__/base_model.cpython-39.pyc +0 -0
  35. src/face3d/models/__pycache__/networks.cpython-38.pyc +0 -0
  36. src/face3d/models/__pycache__/networks.cpython-39.pyc +0 -0
  37. src/face3d/models/arcface_torch/README.md +0 -164
  38. src/face3d/models/arcface_torch/backbones/__init__.py +0 -25
  39. src/face3d/models/arcface_torch/backbones/__pycache__/__init__.cpython-38.pyc +0 -0
  40. src/face3d/models/arcface_torch/backbones/__pycache__/__init__.cpython-39.pyc +0 -0
  41. src/face3d/models/arcface_torch/backbones/__pycache__/iresnet.cpython-38.pyc +0 -0
  42. src/face3d/models/arcface_torch/backbones/__pycache__/iresnet.cpython-39.pyc +0 -0
  43. src/face3d/models/arcface_torch/backbones/__pycache__/mobilefacenet.cpython-38.pyc +0 -0
  44. src/face3d/models/arcface_torch/backbones/__pycache__/mobilefacenet.cpython-39.pyc +0 -0
  45. src/face3d/models/arcface_torch/backbones/iresnet.py +0 -187
  46. src/face3d/models/arcface_torch/backbones/iresnet2060.py +0 -176
  47. src/face3d/models/arcface_torch/backbones/mobilefacenet.py +0 -130
  48. src/face3d/models/arcface_torch/configs/3millions.py +0 -23
  49. src/face3d/models/arcface_torch/configs/3millions_pfc.py +0 -23
  50. src/face3d/models/arcface_torch/configs/__init__.py +0 -0
src/audio2exp_models/__pycache__/audio2exp.cpython-38.pyc DELETED
Binary file (1.25 kB)
 
src/audio2exp_models/__pycache__/networks.cpython-38.pyc DELETED
Binary file (2.1 kB)
 
src/audio2exp_models/audio2exp.py DELETED
@@ -1,41 +0,0 @@
1
- from tqdm import tqdm
2
- import torch
3
- from torch import nn
4
-
5
-
6
- class Audio2Exp(nn.Module):
7
- def __init__(self, netG, cfg, device, prepare_training_loss=False):
8
- super(Audio2Exp, self).__init__()
9
- self.cfg = cfg
10
- self.device = device
11
- self.netG = netG.to(device)
12
-
13
- def test(self, batch):
14
-
15
- mel_input = batch['indiv_mels'] # bs T 1 80 16
16
- bs = mel_input.shape[0]
17
- T = mel_input.shape[1]
18
-
19
- exp_coeff_pred = []
20
-
21
- for i in tqdm(range(0, T, 10),'audio2exp:'): # every 10 frames
22
-
23
- current_mel_input = mel_input[:,i:i+10]
24
-
25
- #ref = batch['ref'][:, :, :64].repeat((1,current_mel_input.shape[1],1)) #bs T 64
26
- ref = batch['ref'][:, :, :64][:, i:i+10]
27
- ratio = batch['ratio_gt'][:, i:i+10] #bs T
28
-
29
- audiox = current_mel_input.view(-1, 1, 80, 16) # bs*T 1 80 16
30
-
31
- curr_exp_coeff_pred = self.netG(audiox, ref, ratio) # bs T 64
32
-
33
- exp_coeff_pred += [curr_exp_coeff_pred]
34
-
35
- # BS x T x 64
36
- results_dict = {
37
- 'exp_coeff_pred': torch.cat(exp_coeff_pred, axis=1)
38
- }
39
- return results_dict
40
-
41
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/audio2exp_models/networks.py DELETED
@@ -1,74 +0,0 @@
1
- import torch
2
- import torch.nn.functional as F
3
- from torch import nn
4
-
5
- class Conv2d(nn.Module):
6
- def __init__(self, cin, cout, kernel_size, stride, padding, residual=False, use_act = True, *args, **kwargs):
7
- super().__init__(*args, **kwargs)
8
- self.conv_block = nn.Sequential(
9
- nn.Conv2d(cin, cout, kernel_size, stride, padding),
10
- nn.BatchNorm2d(cout)
11
- )
12
- self.act = nn.ReLU()
13
- self.residual = residual
14
- self.use_act = use_act
15
-
16
- def forward(self, x):
17
- out = self.conv_block(x)
18
- if self.residual:
19
- out += x
20
-
21
- if self.use_act:
22
- return self.act(out)
23
- else:
24
- return out
25
-
26
- class SimpleWrapperV2(nn.Module):
27
- def __init__(self) -> None:
28
- super().__init__()
29
- self.audio_encoder = nn.Sequential(
30
- Conv2d(1, 32, kernel_size=3, stride=1, padding=1),
31
- Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),
32
- Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),
33
-
34
- Conv2d(32, 64, kernel_size=3, stride=(3, 1), padding=1),
35
- Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
36
- Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
37
-
38
- Conv2d(64, 128, kernel_size=3, stride=3, padding=1),
39
- Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
40
- Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
41
-
42
- Conv2d(128, 256, kernel_size=3, stride=(3, 2), padding=1),
43
- Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),
44
-
45
- Conv2d(256, 512, kernel_size=3, stride=1, padding=0),
46
- Conv2d(512, 512, kernel_size=1, stride=1, padding=0),
47
- )
48
-
49
- #### load the pre-trained audio_encoder
50
- #self.audio_encoder = self.audio_encoder.to(device)
51
- '''
52
- wav2lip_state_dict = torch.load('/apdcephfs_cq2/share_1290939/wenxuazhang/checkpoints/wav2lip.pth')['state_dict']
53
- state_dict = self.audio_encoder.state_dict()
54
-
55
- for k,v in wav2lip_state_dict.items():
56
- if 'audio_encoder' in k:
57
- print('init:', k)
58
- state_dict[k.replace('module.audio_encoder.', '')] = v
59
- self.audio_encoder.load_state_dict(state_dict)
60
- '''
61
-
62
- self.mapping1 = nn.Linear(512+64+1, 64)
63
- #self.mapping2 = nn.Linear(30, 64)
64
- #nn.init.constant_(self.mapping1.weight, 0.)
65
- nn.init.constant_(self.mapping1.bias, 0.)
66
-
67
- def forward(self, x, ref, ratio):
68
- x = self.audio_encoder(x).view(x.size(0), -1)
69
- ref_reshape = ref.reshape(x.size(0), -1)
70
- ratio = ratio.reshape(x.size(0), -1)
71
-
72
- y = self.mapping1(torch.cat([x, ref_reshape, ratio], dim=1))
73
- out = y.reshape(ref.shape[0], ref.shape[1], -1) #+ ref # resudial
74
- return out
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/audio2pose_models/__pycache__/audio2pose.cpython-38.pyc DELETED
Binary file (2.82 kB)
 
src/audio2pose_models/__pycache__/audio_encoder.cpython-38.pyc DELETED
Binary file (2.38 kB)
 
src/audio2pose_models/__pycache__/cvae.cpython-38.pyc DELETED
Binary file (4.65 kB)
 
src/audio2pose_models/__pycache__/discriminator.cpython-38.pyc DELETED
Binary file (2.42 kB)
 
src/audio2pose_models/__pycache__/networks.cpython-38.pyc DELETED
Binary file (4.7 kB)
 
src/audio2pose_models/__pycache__/res_unet.cpython-38.pyc DELETED
Binary file (1.88 kB)
 
src/audio2pose_models/audio2pose.py DELETED
@@ -1,94 +0,0 @@
1
- import torch
2
- from torch import nn
3
- from src.audio2pose_models.cvae import CVAE
4
- from src.audio2pose_models.discriminator import PoseSequenceDiscriminator
5
- from src.audio2pose_models.audio_encoder import AudioEncoder
6
-
7
- class Audio2Pose(nn.Module):
8
- def __init__(self, cfg, wav2lip_checkpoint, device='cuda'):
9
- super().__init__()
10
- self.cfg = cfg
11
- self.seq_len = cfg.MODEL.CVAE.SEQ_LEN
12
- self.latent_dim = cfg.MODEL.CVAE.LATENT_SIZE
13
- self.device = device
14
-
15
- self.audio_encoder = AudioEncoder(wav2lip_checkpoint, device)
16
- self.audio_encoder.eval()
17
- for param in self.audio_encoder.parameters():
18
- param.requires_grad = False
19
-
20
- self.netG = CVAE(cfg)
21
- self.netD_motion = PoseSequenceDiscriminator(cfg)
22
-
23
-
24
- def forward(self, x):
25
-
26
- batch = {}
27
- coeff_gt = x['gt'].cuda().squeeze(0) #bs frame_len+1 73
28
- batch['pose_motion_gt'] = coeff_gt[:, 1:, -9:-3] - coeff_gt[:, :1, -9:-3] #bs frame_len 6
29
- batch['ref'] = coeff_gt[:, 0, -9:-3] #bs 6
30
- batch['class'] = x['class'].squeeze(0).cuda() # bs
31
- indiv_mels= x['indiv_mels'].cuda().squeeze(0) # bs seq_len+1 80 16
32
-
33
- # forward
34
- audio_emb_list = []
35
- audio_emb = self.audio_encoder(indiv_mels[:, 1:, :, :].unsqueeze(2)) #bs seq_len 512
36
- batch['audio_emb'] = audio_emb
37
- batch = self.netG(batch)
38
-
39
- pose_motion_pred = batch['pose_motion_pred'] # bs frame_len 6
40
- pose_gt = coeff_gt[:, 1:, -9:-3].clone() # bs frame_len 6
41
- pose_pred = coeff_gt[:, :1, -9:-3] + pose_motion_pred # bs frame_len 6
42
-
43
- batch['pose_pred'] = pose_pred
44
- batch['pose_gt'] = pose_gt
45
-
46
- return batch
47
-
48
- def test(self, x):
49
-
50
- batch = {}
51
- ref = x['ref'] #bs 1 70
52
- batch['ref'] = x['ref'][:,0,-6:]
53
- batch['class'] = x['class']
54
- bs = ref.shape[0]
55
-
56
- indiv_mels= x['indiv_mels'] # bs T 1 80 16
57
- indiv_mels_use = indiv_mels[:, 1:] # we regard the ref as the first frame
58
- num_frames = x['num_frames']
59
- num_frames = int(num_frames) - 1
60
-
61
- #
62
- div = num_frames//self.seq_len
63
- re = num_frames%self.seq_len
64
- audio_emb_list = []
65
- pose_motion_pred_list = [torch.zeros(batch['ref'].unsqueeze(1).shape, dtype=batch['ref'].dtype,
66
- device=batch['ref'].device)]
67
-
68
- for i in range(div):
69
- z = torch.randn(bs, self.latent_dim).to(ref.device)
70
- batch['z'] = z
71
- audio_emb = self.audio_encoder(indiv_mels_use[:, i*self.seq_len:(i+1)*self.seq_len,:,:,:]) #bs seq_len 512
72
- batch['audio_emb'] = audio_emb
73
- batch = self.netG.test(batch)
74
- pose_motion_pred_list.append(batch['pose_motion_pred']) #list of bs seq_len 6
75
-
76
- if re != 0:
77
- z = torch.randn(bs, self.latent_dim).to(ref.device)
78
- batch['z'] = z
79
- audio_emb = self.audio_encoder(indiv_mels_use[:, -1*self.seq_len:,:,:,:]) #bs seq_len 512
80
- if audio_emb.shape[1] != self.seq_len:
81
- pad_dim = self.seq_len-audio_emb.shape[1]
82
- pad_audio_emb = audio_emb[:, :1].repeat(1, pad_dim, 1)
83
- audio_emb = torch.cat([pad_audio_emb, audio_emb], 1)
84
- batch['audio_emb'] = audio_emb
85
- batch = self.netG.test(batch)
86
- pose_motion_pred_list.append(batch['pose_motion_pred'][:,-1*re:,:])
87
-
88
- pose_motion_pred = torch.cat(pose_motion_pred_list, dim = 1)
89
- batch['pose_motion_pred'] = pose_motion_pred
90
-
91
- pose_pred = ref[:, :1, -6:] + pose_motion_pred # bs T 6
92
-
93
- batch['pose_pred'] = pose_pred
94
- return batch
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/audio2pose_models/audio_encoder.py DELETED
@@ -1,64 +0,0 @@
1
- import torch
2
- from torch import nn
3
- from torch.nn import functional as F
4
-
5
- class Conv2d(nn.Module):
6
- def __init__(self, cin, cout, kernel_size, stride, padding, residual=False, *args, **kwargs):
7
- super().__init__(*args, **kwargs)
8
- self.conv_block = nn.Sequential(
9
- nn.Conv2d(cin, cout, kernel_size, stride, padding),
10
- nn.BatchNorm2d(cout)
11
- )
12
- self.act = nn.ReLU()
13
- self.residual = residual
14
-
15
- def forward(self, x):
16
- out = self.conv_block(x)
17
- if self.residual:
18
- out += x
19
- return self.act(out)
20
-
21
- class AudioEncoder(nn.Module):
22
- def __init__(self, wav2lip_checkpoint, device):
23
- super(AudioEncoder, self).__init__()
24
-
25
- self.audio_encoder = nn.Sequential(
26
- Conv2d(1, 32, kernel_size=3, stride=1, padding=1),
27
- Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),
28
- Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),
29
-
30
- Conv2d(32, 64, kernel_size=3, stride=(3, 1), padding=1),
31
- Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
32
- Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
33
-
34
- Conv2d(64, 128, kernel_size=3, stride=3, padding=1),
35
- Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
36
- Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
37
-
38
- Conv2d(128, 256, kernel_size=3, stride=(3, 2), padding=1),
39
- Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),
40
-
41
- Conv2d(256, 512, kernel_size=3, stride=1, padding=0),
42
- Conv2d(512, 512, kernel_size=1, stride=1, padding=0),)
43
-
44
- #### load the pre-trained audio_encoder
45
- wav2lip_state_dict = torch.load(wav2lip_checkpoint, map_location=torch.device(device))['state_dict']
46
- state_dict = self.audio_encoder.state_dict()
47
-
48
- for k,v in wav2lip_state_dict.items():
49
- if 'audio_encoder' in k:
50
- state_dict[k.replace('module.audio_encoder.', '')] = v
51
- self.audio_encoder.load_state_dict(state_dict)
52
-
53
-
54
- def forward(self, audio_sequences):
55
- # audio_sequences = (B, T, 1, 80, 16)
56
- B = audio_sequences.size(0)
57
-
58
- audio_sequences = torch.cat([audio_sequences[:, i] for i in range(audio_sequences.size(1))], dim=0)
59
-
60
- audio_embedding = self.audio_encoder(audio_sequences) # B, 512, 1, 1
61
- dim = audio_embedding.shape[1]
62
- audio_embedding = audio_embedding.reshape((B, -1, dim, 1, 1))
63
-
64
- return audio_embedding.squeeze(-1).squeeze(-1) #B seq_len+1 512
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/audio2pose_models/cvae.py DELETED
@@ -1,149 +0,0 @@
1
- import torch
2
- import torch.nn.functional as F
3
- from torch import nn
4
- from src.audio2pose_models.res_unet import ResUnet
5
-
6
- def class2onehot(idx, class_num):
7
-
8
- assert torch.max(idx).item() < class_num
9
- onehot = torch.zeros(idx.size(0), class_num).to(idx.device)
10
- onehot.scatter_(1, idx, 1)
11
- return onehot
12
-
13
- class CVAE(nn.Module):
14
- def __init__(self, cfg):
15
- super().__init__()
16
- encoder_layer_sizes = cfg.MODEL.CVAE.ENCODER_LAYER_SIZES
17
- decoder_layer_sizes = cfg.MODEL.CVAE.DECODER_LAYER_SIZES
18
- latent_size = cfg.MODEL.CVAE.LATENT_SIZE
19
- num_classes = cfg.DATASET.NUM_CLASSES
20
- audio_emb_in_size = cfg.MODEL.CVAE.AUDIO_EMB_IN_SIZE
21
- audio_emb_out_size = cfg.MODEL.CVAE.AUDIO_EMB_OUT_SIZE
22
- seq_len = cfg.MODEL.CVAE.SEQ_LEN
23
-
24
- self.latent_size = latent_size
25
-
26
- self.encoder = ENCODER(encoder_layer_sizes, latent_size, num_classes,
27
- audio_emb_in_size, audio_emb_out_size, seq_len)
28
- self.decoder = DECODER(decoder_layer_sizes, latent_size, num_classes,
29
- audio_emb_in_size, audio_emb_out_size, seq_len)
30
- def reparameterize(self, mu, logvar):
31
- std = torch.exp(0.5 * logvar)
32
- eps = torch.randn_like(std)
33
- return mu + eps * std
34
-
35
- def forward(self, batch):
36
- batch = self.encoder(batch)
37
- mu = batch['mu']
38
- logvar = batch['logvar']
39
- z = self.reparameterize(mu, logvar)
40
- batch['z'] = z
41
- return self.decoder(batch)
42
-
43
- def test(self, batch):
44
- '''
45
- class_id = batch['class']
46
- z = torch.randn([class_id.size(0), self.latent_size]).to(class_id.device)
47
- batch['z'] = z
48
- '''
49
- return self.decoder(batch)
50
-
51
- class ENCODER(nn.Module):
52
- def __init__(self, layer_sizes, latent_size, num_classes,
53
- audio_emb_in_size, audio_emb_out_size, seq_len):
54
- super().__init__()
55
-
56
- self.resunet = ResUnet()
57
- self.num_classes = num_classes
58
- self.seq_len = seq_len
59
-
60
- self.MLP = nn.Sequential()
61
- layer_sizes[0] += latent_size + seq_len*audio_emb_out_size + 6
62
- for i, (in_size, out_size) in enumerate(zip(layer_sizes[:-1], layer_sizes[1:])):
63
- self.MLP.add_module(
64
- name="L{:d}".format(i), module=nn.Linear(in_size, out_size))
65
- self.MLP.add_module(name="A{:d}".format(i), module=nn.ReLU())
66
-
67
- self.linear_means = nn.Linear(layer_sizes[-1], latent_size)
68
- self.linear_logvar = nn.Linear(layer_sizes[-1], latent_size)
69
- self.linear_audio = nn.Linear(audio_emb_in_size, audio_emb_out_size)
70
-
71
- self.classbias = nn.Parameter(torch.randn(self.num_classes, latent_size))
72
-
73
- def forward(self, batch):
74
- class_id = batch['class']
75
- pose_motion_gt = batch['pose_motion_gt'] #bs seq_len 6
76
- ref = batch['ref'] #bs 6
77
- bs = pose_motion_gt.shape[0]
78
- audio_in = batch['audio_emb'] # bs seq_len audio_emb_in_size
79
-
80
- #pose encode
81
- pose_emb = self.resunet(pose_motion_gt.unsqueeze(1)) #bs 1 seq_len 6
82
- pose_emb = pose_emb.reshape(bs, -1) #bs seq_len*6
83
-
84
- #audio mapping
85
- print(audio_in.shape)
86
- audio_out = self.linear_audio(audio_in) # bs seq_len audio_emb_out_size
87
- audio_out = audio_out.reshape(bs, -1)
88
-
89
- class_bias = self.classbias[class_id] #bs latent_size
90
- x_in = torch.cat([ref, pose_emb, audio_out, class_bias], dim=-1) #bs seq_len*(audio_emb_out_size+6)+latent_size
91
- x_out = self.MLP(x_in)
92
-
93
- mu = self.linear_means(x_out)
94
- logvar = self.linear_means(x_out) #bs latent_size
95
-
96
- batch.update({'mu':mu, 'logvar':logvar})
97
- return batch
98
-
99
- class DECODER(nn.Module):
100
- def __init__(self, layer_sizes, latent_size, num_classes,
101
- audio_emb_in_size, audio_emb_out_size, seq_len):
102
- super().__init__()
103
-
104
- self.resunet = ResUnet()
105
- self.num_classes = num_classes
106
- self.seq_len = seq_len
107
-
108
- self.MLP = nn.Sequential()
109
- input_size = latent_size + seq_len*audio_emb_out_size + 6
110
- for i, (in_size, out_size) in enumerate(zip([input_size]+layer_sizes[:-1], layer_sizes)):
111
- self.MLP.add_module(
112
- name="L{:d}".format(i), module=nn.Linear(in_size, out_size))
113
- if i+1 < len(layer_sizes):
114
- self.MLP.add_module(name="A{:d}".format(i), module=nn.ReLU())
115
- else:
116
- self.MLP.add_module(name="sigmoid", module=nn.Sigmoid())
117
-
118
- self.pose_linear = nn.Linear(6, 6)
119
- self.linear_audio = nn.Linear(audio_emb_in_size, audio_emb_out_size)
120
-
121
- self.classbias = nn.Parameter(torch.randn(self.num_classes, latent_size))
122
-
123
- def forward(self, batch):
124
-
125
- z = batch['z'] #bs latent_size
126
- bs = z.shape[0]
127
- class_id = batch['class']
128
- ref = batch['ref'] #bs 6
129
- audio_in = batch['audio_emb'] # bs seq_len audio_emb_in_size
130
- #print('audio_in: ', audio_in[:, :, :10])
131
-
132
- audio_out = self.linear_audio(audio_in) # bs seq_len audio_emb_out_size
133
- #print('audio_out: ', audio_out[:, :, :10])
134
- audio_out = audio_out.reshape([bs, -1]) # bs seq_len*audio_emb_out_size
135
- class_bias = self.classbias[class_id] #bs latent_size
136
-
137
- z = z + class_bias
138
- x_in = torch.cat([ref, z, audio_out], dim=-1)
139
- x_out = self.MLP(x_in) # bs layer_sizes[-1]
140
- x_out = x_out.reshape((bs, self.seq_len, -1))
141
-
142
- #print('x_out: ', x_out)
143
-
144
- pose_emb = self.resunet(x_out.unsqueeze(1)) #bs 1 seq_len 6
145
-
146
- pose_motion_pred = self.pose_linear(pose_emb.squeeze(1)) #bs seq_len 6
147
-
148
- batch.update({'pose_motion_pred':pose_motion_pred})
149
- return batch
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/audio2pose_models/discriminator.py DELETED
@@ -1,76 +0,0 @@
1
- import torch
2
- import torch.nn.functional as F
3
- from torch import nn
4
-
5
- class ConvNormRelu(nn.Module):
6
- def __init__(self, conv_type='1d', in_channels=3, out_channels=64, downsample=False,
7
- kernel_size=None, stride=None, padding=None, norm='BN', leaky=False):
8
- super().__init__()
9
- if kernel_size is None:
10
- if downsample:
11
- kernel_size, stride, padding = 4, 2, 1
12
- else:
13
- kernel_size, stride, padding = 3, 1, 1
14
-
15
- if conv_type == '2d':
16
- self.conv = nn.Conv2d(
17
- in_channels,
18
- out_channels,
19
- kernel_size,
20
- stride,
21
- padding,
22
- bias=False,
23
- )
24
- if norm == 'BN':
25
- self.norm = nn.BatchNorm2d(out_channels)
26
- elif norm == 'IN':
27
- self.norm = nn.InstanceNorm2d(out_channels)
28
- else:
29
- raise NotImplementedError
30
- elif conv_type == '1d':
31
- self.conv = nn.Conv1d(
32
- in_channels,
33
- out_channels,
34
- kernel_size,
35
- stride,
36
- padding,
37
- bias=False,
38
- )
39
- if norm == 'BN':
40
- self.norm = nn.BatchNorm1d(out_channels)
41
- elif norm == 'IN':
42
- self.norm = nn.InstanceNorm1d(out_channels)
43
- else:
44
- raise NotImplementedError
45
- nn.init.kaiming_normal_(self.conv.weight)
46
-
47
- self.act = nn.LeakyReLU(negative_slope=0.2, inplace=False) if leaky else nn.ReLU(inplace=True)
48
-
49
- def forward(self, x):
50
- x = self.conv(x)
51
- if isinstance(self.norm, nn.InstanceNorm1d):
52
- x = self.norm(x.permute((0, 2, 1))).permute((0, 2, 1)) # normalize on [C]
53
- else:
54
- x = self.norm(x)
55
- x = self.act(x)
56
- return x
57
-
58
-
59
- class PoseSequenceDiscriminator(nn.Module):
60
- def __init__(self, cfg):
61
- super().__init__()
62
- self.cfg = cfg
63
- leaky = self.cfg.MODEL.DISCRIMINATOR.LEAKY_RELU
64
-
65
- self.seq = nn.Sequential(
66
- ConvNormRelu('1d', cfg.MODEL.DISCRIMINATOR.INPUT_CHANNELS, 256, downsample=True, leaky=leaky), # B, 256, 64
67
- ConvNormRelu('1d', 256, 512, downsample=True, leaky=leaky), # B, 512, 32
68
- ConvNormRelu('1d', 512, 1024, kernel_size=3, stride=1, padding=1, leaky=leaky), # B, 1024, 16
69
- nn.Conv1d(1024, 1, kernel_size=3, stride=1, padding=1, bias=True) # B, 1, 16
70
- )
71
-
72
- def forward(self, x):
73
- x = x.reshape(x.size(0), x.size(1), -1).transpose(1, 2)
74
- x = self.seq(x)
75
- x = x.squeeze(1)
76
- return x
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/audio2pose_models/networks.py DELETED
@@ -1,140 +0,0 @@
1
- import torch.nn as nn
2
- import torch
3
-
4
-
5
- class ResidualConv(nn.Module):
6
- def __init__(self, input_dim, output_dim, stride, padding):
7
- super(ResidualConv, self).__init__()
8
-
9
- self.conv_block = nn.Sequential(
10
- nn.BatchNorm2d(input_dim),
11
- nn.ReLU(),
12
- nn.Conv2d(
13
- input_dim, output_dim, kernel_size=3, stride=stride, padding=padding
14
- ),
15
- nn.BatchNorm2d(output_dim),
16
- nn.ReLU(),
17
- nn.Conv2d(output_dim, output_dim, kernel_size=3, padding=1),
18
- )
19
- self.conv_skip = nn.Sequential(
20
- nn.Conv2d(input_dim, output_dim, kernel_size=3, stride=stride, padding=1),
21
- nn.BatchNorm2d(output_dim),
22
- )
23
-
24
- def forward(self, x):
25
-
26
- return self.conv_block(x) + self.conv_skip(x)
27
-
28
-
29
- class Upsample(nn.Module):
30
- def __init__(self, input_dim, output_dim, kernel, stride):
31
- super(Upsample, self).__init__()
32
-
33
- self.upsample = nn.ConvTranspose2d(
34
- input_dim, output_dim, kernel_size=kernel, stride=stride
35
- )
36
-
37
- def forward(self, x):
38
- return self.upsample(x)
39
-
40
-
41
- class Squeeze_Excite_Block(nn.Module):
42
- def __init__(self, channel, reduction=16):
43
- super(Squeeze_Excite_Block, self).__init__()
44
- self.avg_pool = nn.AdaptiveAvgPool2d(1)
45
- self.fc = nn.Sequential(
46
- nn.Linear(channel, channel // reduction, bias=False),
47
- nn.ReLU(inplace=True),
48
- nn.Linear(channel // reduction, channel, bias=False),
49
- nn.Sigmoid(),
50
- )
51
-
52
- def forward(self, x):
53
- b, c, _, _ = x.size()
54
- y = self.avg_pool(x).view(b, c)
55
- y = self.fc(y).view(b, c, 1, 1)
56
- return x * y.expand_as(x)
57
-
58
-
59
- class ASPP(nn.Module):
60
- def __init__(self, in_dims, out_dims, rate=[6, 12, 18]):
61
- super(ASPP, self).__init__()
62
-
63
- self.aspp_block1 = nn.Sequential(
64
- nn.Conv2d(
65
- in_dims, out_dims, 3, stride=1, padding=rate[0], dilation=rate[0]
66
- ),
67
- nn.ReLU(inplace=True),
68
- nn.BatchNorm2d(out_dims),
69
- )
70
- self.aspp_block2 = nn.Sequential(
71
- nn.Conv2d(
72
- in_dims, out_dims, 3, stride=1, padding=rate[1], dilation=rate[1]
73
- ),
74
- nn.ReLU(inplace=True),
75
- nn.BatchNorm2d(out_dims),
76
- )
77
- self.aspp_block3 = nn.Sequential(
78
- nn.Conv2d(
79
- in_dims, out_dims, 3, stride=1, padding=rate[2], dilation=rate[2]
80
- ),
81
- nn.ReLU(inplace=True),
82
- nn.BatchNorm2d(out_dims),
83
- )
84
-
85
- self.output = nn.Conv2d(len(rate) * out_dims, out_dims, 1)
86
- self._init_weights()
87
-
88
- def forward(self, x):
89
- x1 = self.aspp_block1(x)
90
- x2 = self.aspp_block2(x)
91
- x3 = self.aspp_block3(x)
92
- out = torch.cat([x1, x2, x3], dim=1)
93
- return self.output(out)
94
-
95
- def _init_weights(self):
96
- for m in self.modules():
97
- if isinstance(m, nn.Conv2d):
98
- nn.init.kaiming_normal_(m.weight)
99
- elif isinstance(m, nn.BatchNorm2d):
100
- m.weight.data.fill_(1)
101
- m.bias.data.zero_()
102
-
103
-
104
- class Upsample_(nn.Module):
105
- def __init__(self, scale=2):
106
- super(Upsample_, self).__init__()
107
-
108
- self.upsample = nn.Upsample(mode="bilinear", scale_factor=scale)
109
-
110
- def forward(self, x):
111
- return self.upsample(x)
112
-
113
-
114
- class AttentionBlock(nn.Module):
115
- def __init__(self, input_encoder, input_decoder, output_dim):
116
- super(AttentionBlock, self).__init__()
117
-
118
- self.conv_encoder = nn.Sequential(
119
- nn.BatchNorm2d(input_encoder),
120
- nn.ReLU(),
121
- nn.Conv2d(input_encoder, output_dim, 3, padding=1),
122
- nn.MaxPool2d(2, 2),
123
- )
124
-
125
- self.conv_decoder = nn.Sequential(
126
- nn.BatchNorm2d(input_decoder),
127
- nn.ReLU(),
128
- nn.Conv2d(input_decoder, output_dim, 3, padding=1),
129
- )
130
-
131
- self.conv_attn = nn.Sequential(
132
- nn.BatchNorm2d(output_dim),
133
- nn.ReLU(),
134
- nn.Conv2d(output_dim, 1, 1),
135
- )
136
-
137
- def forward(self, x1, x2):
138
- out = self.conv_encoder(x1) + self.conv_decoder(x2)
139
- out = self.conv_attn(out)
140
- return out * x2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/audio2pose_models/res_unet.py DELETED
@@ -1,65 +0,0 @@
1
- import torch
2
- import torch.nn as nn
3
- from src.audio2pose_models.networks import ResidualConv, Upsample
4
-
5
-
6
- class ResUnet(nn.Module):
7
- def __init__(self, channel=1, filters=[32, 64, 128, 256]):
8
- super(ResUnet, self).__init__()
9
-
10
- self.input_layer = nn.Sequential(
11
- nn.Conv2d(channel, filters[0], kernel_size=3, padding=1),
12
- nn.BatchNorm2d(filters[0]),
13
- nn.ReLU(),
14
- nn.Conv2d(filters[0], filters[0], kernel_size=3, padding=1),
15
- )
16
- self.input_skip = nn.Sequential(
17
- nn.Conv2d(channel, filters[0], kernel_size=3, padding=1)
18
- )
19
-
20
- self.residual_conv_1 = ResidualConv(filters[0], filters[1], stride=(2,1), padding=1)
21
- self.residual_conv_2 = ResidualConv(filters[1], filters[2], stride=(2,1), padding=1)
22
-
23
- self.bridge = ResidualConv(filters[2], filters[3], stride=(2,1), padding=1)
24
-
25
- self.upsample_1 = Upsample(filters[3], filters[3], kernel=(2,1), stride=(2,1))
26
- self.up_residual_conv1 = ResidualConv(filters[3] + filters[2], filters[2], stride=1, padding=1)
27
-
28
- self.upsample_2 = Upsample(filters[2], filters[2], kernel=(2,1), stride=(2,1))
29
- self.up_residual_conv2 = ResidualConv(filters[2] + filters[1], filters[1], stride=1, padding=1)
30
-
31
- self.upsample_3 = Upsample(filters[1], filters[1], kernel=(2,1), stride=(2,1))
32
- self.up_residual_conv3 = ResidualConv(filters[1] + filters[0], filters[0], stride=1, padding=1)
33
-
34
- self.output_layer = nn.Sequential(
35
- nn.Conv2d(filters[0], 1, 1, 1),
36
- nn.Sigmoid(),
37
- )
38
-
39
- def forward(self, x):
40
- # Encode
41
- x1 = self.input_layer(x) + self.input_skip(x)
42
- x2 = self.residual_conv_1(x1)
43
- x3 = self.residual_conv_2(x2)
44
- # Bridge
45
- x4 = self.bridge(x3)
46
-
47
- # Decode
48
- x4 = self.upsample_1(x4)
49
- x5 = torch.cat([x4, x3], dim=1)
50
-
51
- x6 = self.up_residual_conv1(x5)
52
-
53
- x6 = self.upsample_2(x6)
54
- x7 = torch.cat([x6, x2], dim=1)
55
-
56
- x8 = self.up_residual_conv2(x7)
57
-
58
- x8 = self.upsample_3(x8)
59
- x9 = torch.cat([x8, x1], dim=1)
60
-
61
- x10 = self.up_residual_conv3(x9)
62
-
63
- output = self.output_layer(x10)
64
-
65
- return output
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/config/auido2exp.yaml DELETED
@@ -1,58 +0,0 @@
1
- DATASET:
2
- TRAIN_FILE_LIST: /apdcephfs_cq2/share_1290939/wenxuazhang/code/file_list/train.txt
3
- EVAL_FILE_LIST: /apdcephfs_cq2/share_1290939/wenxuazhang/code/file_list/val.txt
4
- TRAIN_BATCH_SIZE: 32
5
- EVAL_BATCH_SIZE: 32
6
- EXP: True
7
- EXP_DIM: 64
8
- FRAME_LEN: 32
9
- COEFF_LEN: 73
10
- NUM_CLASSES: 46
11
- AUDIO_ROOT_PATH: /apdcephfs_cq2/share_1290939/wenxuazhang/voxceleb1/wav
12
- COEFF_ROOT_PATH: /apdcephfs_cq2/share_1290939/wenxuazhang/voxceleb1/wav2lip_3dmm
13
- LMDB_PATH: /apdcephfs_cq2/share_1290939/shadowcun/datasets/VoxCeleb/v1/imdb
14
- DEBUG: True
15
- NUM_REPEATS: 2
16
- T: 40
17
-
18
-
19
- MODEL:
20
- FRAMEWORK: V2
21
- AUDIOENCODER:
22
- LEAKY_RELU: True
23
- NORM: 'IN'
24
- DISCRIMINATOR:
25
- LEAKY_RELU: False
26
- INPUT_CHANNELS: 6
27
- CVAE:
28
- AUDIO_EMB_IN_SIZE: 512
29
- AUDIO_EMB_OUT_SIZE: 128
30
- SEQ_LEN: 32
31
- LATENT_SIZE: 256
32
- ENCODER_LAYER_SIZES: [192, 1024]
33
- DECODER_LAYER_SIZES: [1024, 192]
34
-
35
-
36
- TRAIN:
37
- MAX_EPOCH: 300
38
- GENERATOR:
39
- LR: 2.0e-5
40
- DISCRIMINATOR:
41
- LR: 1.0e-5
42
- LOSS:
43
- W_FEAT: 0
44
- W_COEFF_EXP: 2
45
- W_LM: 1.0e-2
46
- W_LM_MOUTH: 0
47
- W_REG: 0
48
- W_SYNC: 0
49
- W_COLOR: 0
50
- W_EXPRESSION: 0
51
- W_LIPREADING: 0.01
52
- W_LIPREADING_VV: 0
53
- W_EYE_BLINK: 4
54
-
55
- TAG:
56
- NAME: small_dataset
57
-
58
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/config/auido2pose.yaml DELETED
@@ -1,49 +0,0 @@
1
- DATASET:
2
- TRAIN_FILE_LIST: /apdcephfs_cq2/share_1290939/wenxuazhang/code/audio2pose_unet_noAudio/dataset/train_33.txt
3
- EVAL_FILE_LIST: /apdcephfs_cq2/share_1290939/wenxuazhang/code/audio2pose_unet_noAudio/dataset/val.txt
4
- TRAIN_BATCH_SIZE: 64
5
- EVAL_BATCH_SIZE: 1
6
- EXP: True
7
- EXP_DIM: 64
8
- FRAME_LEN: 32
9
- COEFF_LEN: 73
10
- NUM_CLASSES: 46
11
- AUDIO_ROOT_PATH: /apdcephfs_cq2/share_1290939/wenxuazhang/voxceleb1/wav
12
- COEFF_ROOT_PATH: /apdcephfs_cq2/share_1290939/shadowcun/datasets/VoxCeleb/v1/imdb
13
- DEBUG: True
14
-
15
-
16
- MODEL:
17
- AUDIOENCODER:
18
- LEAKY_RELU: True
19
- NORM: 'IN'
20
- DISCRIMINATOR:
21
- LEAKY_RELU: False
22
- INPUT_CHANNELS: 6
23
- CVAE:
24
- AUDIO_EMB_IN_SIZE: 512
25
- AUDIO_EMB_OUT_SIZE: 6
26
- SEQ_LEN: 32
27
- LATENT_SIZE: 64
28
- ENCODER_LAYER_SIZES: [192, 128]
29
- DECODER_LAYER_SIZES: [128, 192]
30
-
31
-
32
- TRAIN:
33
- MAX_EPOCH: 150
34
- GENERATOR:
35
- LR: 1.0e-4
36
- DISCRIMINATOR:
37
- LR: 1.0e-4
38
- LOSS:
39
- LAMBDA_REG: 1
40
- LAMBDA_LANDMARKS: 0
41
- LAMBDA_VERTICES: 0
42
- LAMBDA_GAN_MOTION: 0.7
43
- LAMBDA_GAN_COEFF: 0
44
- LAMBDA_KL: 1
45
-
46
- TAG:
47
- NAME: cvae_UNET_useAudio_usewav2lipAudioEncoder
48
-
49
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/config/facerender.yaml DELETED
@@ -1,45 +0,0 @@
1
- model_params:
2
- common_params:
3
- num_kp: 15
4
- image_channel: 3
5
- feature_channel: 32
6
- estimate_jacobian: False # True
7
- kp_detector_params:
8
- temperature: 0.1
9
- block_expansion: 32
10
- max_features: 1024
11
- scale_factor: 0.25 # 0.25
12
- num_blocks: 5
13
- reshape_channel: 16384 # 16384 = 1024 * 16
14
- reshape_depth: 16
15
- he_estimator_params:
16
- block_expansion: 64
17
- max_features: 2048
18
- num_bins: 66
19
- generator_params:
20
- block_expansion: 64
21
- max_features: 512
22
- num_down_blocks: 2
23
- reshape_channel: 32
24
- reshape_depth: 16 # 512 = 32 * 16
25
- num_resblocks: 6
26
- estimate_occlusion_map: True
27
- dense_motion_params:
28
- block_expansion: 32
29
- max_features: 1024
30
- num_blocks: 5
31
- reshape_depth: 16
32
- compress: 4
33
- discriminator_params:
34
- scales: [1]
35
- block_expansion: 32
36
- max_features: 512
37
- num_blocks: 4
38
- sn: True
39
- mapping_params:
40
- coeff_nc: 70
41
- descriptor_nc: 1024
42
- layer: 3
43
- num_kp: 15
44
- num_bins: 66
45
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/config/facerender_still.yaml DELETED
@@ -1,45 +0,0 @@
1
- model_params:
2
- common_params:
3
- num_kp: 15
4
- image_channel: 3
5
- feature_channel: 32
6
- estimate_jacobian: False # True
7
- kp_detector_params:
8
- temperature: 0.1
9
- block_expansion: 32
10
- max_features: 1024
11
- scale_factor: 0.25 # 0.25
12
- num_blocks: 5
13
- reshape_channel: 16384 # 16384 = 1024 * 16
14
- reshape_depth: 16
15
- he_estimator_params:
16
- block_expansion: 64
17
- max_features: 2048
18
- num_bins: 66
19
- generator_params:
20
- block_expansion: 64
21
- max_features: 512
22
- num_down_blocks: 2
23
- reshape_channel: 32
24
- reshape_depth: 16 # 512 = 32 * 16
25
- num_resblocks: 6
26
- estimate_occlusion_map: True
27
- dense_motion_params:
28
- block_expansion: 32
29
- max_features: 1024
30
- num_blocks: 5
31
- reshape_depth: 16
32
- compress: 4
33
- discriminator_params:
34
- scales: [1]
35
- block_expansion: 32
36
- max_features: 512
37
- num_blocks: 4
38
- sn: True
39
- mapping_params:
40
- coeff_nc: 73
41
- descriptor_nc: 1024
42
- layer: 3
43
- num_kp: 15
44
- num_bins: 66
45
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/face3d/__pycache__/extract_kp_videos.cpython-38.pyc DELETED
Binary file (3.55 kB)
 
src/face3d/__pycache__/extract_kp_videos.cpython-39.pyc DELETED
Binary file (3.55 kB)
 
src/face3d/data/__init__.py DELETED
@@ -1,116 +0,0 @@
1
- """This package includes all the modules related to data loading and preprocessing
2
-
3
- To add a custom dataset class called 'dummy', you need to add a file called 'dummy_dataset.py' and define a subclass 'DummyDataset' inherited from BaseDataset.
4
- You need to implement four functions:
5
- -- <__init__>: initialize the class, first call BaseDataset.__init__(self, opt).
6
- -- <__len__>: return the size of dataset.
7
- -- <__getitem__>: get a data point from data loader.
8
- -- <modify_commandline_options>: (optionally) add dataset-specific options and set default options.
9
-
10
- Now you can use the dataset class by specifying flag '--dataset_mode dummy'.
11
- See our template dataset class 'template_dataset.py' for more details.
12
- """
13
- import numpy as np
14
- import importlib
15
- import torch.utils.data
16
- from face3d.data.base_dataset import BaseDataset
17
-
18
-
19
- def find_dataset_using_name(dataset_name):
20
- """Import the module "data/[dataset_name]_dataset.py".
21
-
22
- In the file, the class called DatasetNameDataset() will
23
- be instantiated. It has to be a subclass of BaseDataset,
24
- and it is case-insensitive.
25
- """
26
- dataset_filename = "data." + dataset_name + "_dataset"
27
- datasetlib = importlib.import_module(dataset_filename)
28
-
29
- dataset = None
30
- target_dataset_name = dataset_name.replace('_', '') + 'dataset'
31
- for name, cls in datasetlib.__dict__.items():
32
- if name.lower() == target_dataset_name.lower() \
33
- and issubclass(cls, BaseDataset):
34
- dataset = cls
35
-
36
- if dataset is None:
37
- raise NotImplementedError("In %s.py, there should be a subclass of BaseDataset with class name that matches %s in lowercase." % (dataset_filename, target_dataset_name))
38
-
39
- return dataset
40
-
41
-
42
- def get_option_setter(dataset_name):
43
- """Return the static method <modify_commandline_options> of the dataset class."""
44
- dataset_class = find_dataset_using_name(dataset_name)
45
- return dataset_class.modify_commandline_options
46
-
47
-
48
- def create_dataset(opt, rank=0):
49
- """Create a dataset given the option.
50
-
51
- This function wraps the class CustomDatasetDataLoader.
52
- This is the main interface between this package and 'train.py'/'test.py'
53
-
54
- Example:
55
- >>> from data import create_dataset
56
- >>> dataset = create_dataset(opt)
57
- """
58
- data_loader = CustomDatasetDataLoader(opt, rank=rank)
59
- dataset = data_loader.load_data()
60
- return dataset
61
-
62
- class CustomDatasetDataLoader():
63
- """Wrapper class of Dataset class that performs multi-threaded data loading"""
64
-
65
- def __init__(self, opt, rank=0):
66
- """Initialize this class
67
-
68
- Step 1: create a dataset instance given the name [dataset_mode]
69
- Step 2: create a multi-threaded data loader.
70
- """
71
- self.opt = opt
72
- dataset_class = find_dataset_using_name(opt.dataset_mode)
73
- self.dataset = dataset_class(opt)
74
- self.sampler = None
75
- print("rank %d %s dataset [%s] was created" % (rank, self.dataset.name, type(self.dataset).__name__))
76
- if opt.use_ddp and opt.isTrain:
77
- world_size = opt.world_size
78
- self.sampler = torch.utils.data.distributed.DistributedSampler(
79
- self.dataset,
80
- num_replicas=world_size,
81
- rank=rank,
82
- shuffle=not opt.serial_batches
83
- )
84
- self.dataloader = torch.utils.data.DataLoader(
85
- self.dataset,
86
- sampler=self.sampler,
87
- num_workers=int(opt.num_threads / world_size),
88
- batch_size=int(opt.batch_size / world_size),
89
- drop_last=True)
90
- else:
91
- self.dataloader = torch.utils.data.DataLoader(
92
- self.dataset,
93
- batch_size=opt.batch_size,
94
- shuffle=(not opt.serial_batches) and opt.isTrain,
95
- num_workers=int(opt.num_threads),
96
- drop_last=True
97
- )
98
-
99
- def set_epoch(self, epoch):
100
- self.dataset.current_epoch = epoch
101
- if self.sampler is not None:
102
- self.sampler.set_epoch(epoch)
103
-
104
- def load_data(self):
105
- return self
106
-
107
- def __len__(self):
108
- """Return the number of data in the dataset"""
109
- return min(len(self.dataset), self.opt.max_dataset_size)
110
-
111
- def __iter__(self):
112
- """Return a batch of data"""
113
- for i, data in enumerate(self.dataloader):
114
- if i * self.opt.batch_size >= self.opt.max_dataset_size:
115
- break
116
- yield data
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/face3d/data/base_dataset.py DELETED
@@ -1,125 +0,0 @@
1
- """This module implements an abstract base class (ABC) 'BaseDataset' for datasets.
2
-
3
- It also includes common transformation functions (e.g., get_transform, __scale_width), which can be later used in subclasses.
4
- """
5
- import random
6
- import numpy as np
7
- import torch.utils.data as data
8
- from PIL import Image
9
- import torchvision.transforms as transforms
10
- from abc import ABC, abstractmethod
11
-
12
-
13
- class BaseDataset(data.Dataset, ABC):
14
- """This class is an abstract base class (ABC) for datasets.
15
-
16
- To create a subclass, you need to implement the following four functions:
17
- -- <__init__>: initialize the class, first call BaseDataset.__init__(self, opt).
18
- -- <__len__>: return the size of dataset.
19
- -- <__getitem__>: get a data point.
20
- -- <modify_commandline_options>: (optionally) add dataset-specific options and set default options.
21
- """
22
-
23
- def __init__(self, opt):
24
- """Initialize the class; save the options in the class
25
-
26
- Parameters:
27
- opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions
28
- """
29
- self.opt = opt
30
- # self.root = opt.dataroot
31
- self.current_epoch = 0
32
-
33
- @staticmethod
34
- def modify_commandline_options(parser, is_train):
35
- """Add new dataset-specific options, and rewrite default values for existing options.
36
-
37
- Parameters:
38
- parser -- original option parser
39
- is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.
40
-
41
- Returns:
42
- the modified parser.
43
- """
44
- return parser
45
-
46
- @abstractmethod
47
- def __len__(self):
48
- """Return the total number of images in the dataset."""
49
- return 0
50
-
51
- @abstractmethod
52
- def __getitem__(self, index):
53
- """Return a data point and its metadata information.
54
-
55
- Parameters:
56
- index - - a random integer for data indexing
57
-
58
- Returns:
59
- a dictionary of data with their names. It ususally contains the data itself and its metadata information.
60
- """
61
- pass
62
-
63
-
64
- def get_transform(grayscale=False):
65
- transform_list = []
66
- if grayscale:
67
- transform_list.append(transforms.Grayscale(1))
68
- transform_list += [transforms.ToTensor()]
69
- return transforms.Compose(transform_list)
70
-
71
- def get_affine_mat(opt, size):
72
- shift_x, shift_y, scale, rot_angle, flip = 0., 0., 1., 0., False
73
- w, h = size
74
-
75
- if 'shift' in opt.preprocess:
76
- shift_pixs = int(opt.shift_pixs)
77
- shift_x = random.randint(-shift_pixs, shift_pixs)
78
- shift_y = random.randint(-shift_pixs, shift_pixs)
79
- if 'scale' in opt.preprocess:
80
- scale = 1 + opt.scale_delta * (2 * random.random() - 1)
81
- if 'rot' in opt.preprocess:
82
- rot_angle = opt.rot_angle * (2 * random.random() - 1)
83
- rot_rad = -rot_angle * np.pi/180
84
- if 'flip' in opt.preprocess:
85
- flip = random.random() > 0.5
86
-
87
- shift_to_origin = np.array([1, 0, -w//2, 0, 1, -h//2, 0, 0, 1]).reshape([3, 3])
88
- flip_mat = np.array([-1 if flip else 1, 0, 0, 0, 1, 0, 0, 0, 1]).reshape([3, 3])
89
- shift_mat = np.array([1, 0, shift_x, 0, 1, shift_y, 0, 0, 1]).reshape([3, 3])
90
- rot_mat = np.array([np.cos(rot_rad), np.sin(rot_rad), 0, -np.sin(rot_rad), np.cos(rot_rad), 0, 0, 0, 1]).reshape([3, 3])
91
- scale_mat = np.array([scale, 0, 0, 0, scale, 0, 0, 0, 1]).reshape([3, 3])
92
- shift_to_center = np.array([1, 0, w//2, 0, 1, h//2, 0, 0, 1]).reshape([3, 3])
93
-
94
- affine = shift_to_center @ scale_mat @ rot_mat @ shift_mat @ flip_mat @ shift_to_origin
95
- affine_inv = np.linalg.inv(affine)
96
- return affine, affine_inv, flip
97
-
98
- def apply_img_affine(img, affine_inv, method=Image.BICUBIC):
99
- return img.transform(img.size, Image.AFFINE, data=affine_inv.flatten()[:6], resample=Image.BICUBIC)
100
-
101
- def apply_lm_affine(landmark, affine, flip, size):
102
- _, h = size
103
- lm = landmark.copy()
104
- lm[:, 1] = h - 1 - lm[:, 1]
105
- lm = np.concatenate((lm, np.ones([lm.shape[0], 1])), -1)
106
- lm = lm @ np.transpose(affine)
107
- lm[:, :2] = lm[:, :2] / lm[:, 2:]
108
- lm = lm[:, :2]
109
- lm[:, 1] = h - 1 - lm[:, 1]
110
- if flip:
111
- lm_ = lm.copy()
112
- lm_[:17] = lm[16::-1]
113
- lm_[17:22] = lm[26:21:-1]
114
- lm_[22:27] = lm[21:16:-1]
115
- lm_[31:36] = lm[35:30:-1]
116
- lm_[36:40] = lm[45:41:-1]
117
- lm_[40:42] = lm[47:45:-1]
118
- lm_[42:46] = lm[39:35:-1]
119
- lm_[46:48] = lm[41:39:-1]
120
- lm_[48:55] = lm[54:47:-1]
121
- lm_[55:60] = lm[59:54:-1]
122
- lm_[60:65] = lm[64:59:-1]
123
- lm_[65:68] = lm[67:64:-1]
124
- lm = lm_
125
- return lm
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/face3d/data/flist_dataset.py DELETED
@@ -1,125 +0,0 @@
1
- """This script defines the custom dataset for Deep3DFaceRecon_pytorch
2
- """
3
-
4
- import os.path
5
- from data.base_dataset import BaseDataset, get_transform, get_affine_mat, apply_img_affine, apply_lm_affine
6
- from data.image_folder import make_dataset
7
- from PIL import Image
8
- import random
9
- import util.util as util
10
- import numpy as np
11
- import json
12
- import torch
13
- from scipy.io import loadmat, savemat
14
- import pickle
15
- from util.preprocess import align_img, estimate_norm
16
- from util.load_mats import load_lm3d
17
-
18
-
19
- def default_flist_reader(flist):
20
- """
21
- flist format: impath label\nimpath label\n ...(same to caffe's filelist)
22
- """
23
- imlist = []
24
- with open(flist, 'r') as rf:
25
- for line in rf.readlines():
26
- impath = line.strip()
27
- imlist.append(impath)
28
-
29
- return imlist
30
-
31
- def jason_flist_reader(flist):
32
- with open(flist, 'r') as fp:
33
- info = json.load(fp)
34
- return info
35
-
36
- def parse_label(label):
37
- return torch.tensor(np.array(label).astype(np.float32))
38
-
39
-
40
- class FlistDataset(BaseDataset):
41
- """
42
- It requires one directories to host training images '/path/to/data/train'
43
- You can train the model with the dataset flag '--dataroot /path/to/data'.
44
- """
45
-
46
- def __init__(self, opt):
47
- """Initialize this dataset class.
48
-
49
- Parameters:
50
- opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
51
- """
52
- BaseDataset.__init__(self, opt)
53
-
54
- self.lm3d_std = load_lm3d(opt.bfm_folder)
55
-
56
- msk_names = default_flist_reader(opt.flist)
57
- self.msk_paths = [os.path.join(opt.data_root, i) for i in msk_names]
58
-
59
- self.size = len(self.msk_paths)
60
- self.opt = opt
61
-
62
- self.name = 'train' if opt.isTrain else 'val'
63
- if '_' in opt.flist:
64
- self.name += '_' + opt.flist.split(os.sep)[-1].split('_')[0]
65
-
66
-
67
- def __getitem__(self, index):
68
- """Return a data point and its metadata information.
69
-
70
- Parameters:
71
- index (int) -- a random integer for data indexing
72
-
73
- Returns a dictionary that contains A, B, A_paths and B_paths
74
- img (tensor) -- an image in the input domain
75
- msk (tensor) -- its corresponding attention mask
76
- lm (tensor) -- its corresponding 3d landmarks
77
- im_paths (str) -- image paths
78
- aug_flag (bool) -- a flag used to tell whether its raw or augmented
79
- """
80
- msk_path = self.msk_paths[index % self.size] # make sure index is within then range
81
- img_path = msk_path.replace('mask/', '')
82
- lm_path = '.'.join(msk_path.replace('mask', 'landmarks').split('.')[:-1]) + '.txt'
83
-
84
- raw_img = Image.open(img_path).convert('RGB')
85
- raw_msk = Image.open(msk_path).convert('RGB')
86
- raw_lm = np.loadtxt(lm_path).astype(np.float32)
87
-
88
- _, img, lm, msk = align_img(raw_img, raw_lm, self.lm3d_std, raw_msk)
89
-
90
- aug_flag = self.opt.use_aug and self.opt.isTrain
91
- if aug_flag:
92
- img, lm, msk = self._augmentation(img, lm, self.opt, msk)
93
-
94
- _, H = img.size
95
- M = estimate_norm(lm, H)
96
- transform = get_transform()
97
- img_tensor = transform(img)
98
- msk_tensor = transform(msk)[:1, ...]
99
- lm_tensor = parse_label(lm)
100
- M_tensor = parse_label(M)
101
-
102
-
103
- return {'imgs': img_tensor,
104
- 'lms': lm_tensor,
105
- 'msks': msk_tensor,
106
- 'M': M_tensor,
107
- 'im_paths': img_path,
108
- 'aug_flag': aug_flag,
109
- 'dataset': self.name}
110
-
111
- def _augmentation(self, img, lm, opt, msk=None):
112
- affine, affine_inv, flip = get_affine_mat(opt, img.size)
113
- img = apply_img_affine(img, affine_inv)
114
- lm = apply_lm_affine(lm, affine, flip, img.size)
115
- if msk is not None:
116
- msk = apply_img_affine(msk, affine_inv, method=Image.BILINEAR)
117
- return img, lm, msk
118
-
119
-
120
-
121
-
122
- def __len__(self):
123
- """Return the total number of images in the dataset.
124
- """
125
- return self.size
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/face3d/data/image_folder.py DELETED
@@ -1,66 +0,0 @@
1
- """A modified image folder class
2
-
3
- We modify the official PyTorch image folder (https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py)
4
- so that this class can load images from both current directory and its subdirectories.
5
- """
6
- import numpy as np
7
- import torch.utils.data as data
8
-
9
- from PIL import Image
10
- import os
11
- import os.path
12
-
13
- IMG_EXTENSIONS = [
14
- '.jpg', '.JPG', '.jpeg', '.JPEG',
15
- '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
16
- '.tif', '.TIF', '.tiff', '.TIFF',
17
- ]
18
-
19
-
20
- def is_image_file(filename):
21
- return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
22
-
23
-
24
- def make_dataset(dir, max_dataset_size=float("inf")):
25
- images = []
26
- assert os.path.isdir(dir) or os.path.islink(dir), '%s is not a valid directory' % dir
27
-
28
- for root, _, fnames in sorted(os.walk(dir, followlinks=True)):
29
- for fname in fnames:
30
- if is_image_file(fname):
31
- path = os.path.join(root, fname)
32
- images.append(path)
33
- return images[:min(max_dataset_size, len(images))]
34
-
35
-
36
- def default_loader(path):
37
- return Image.open(path).convert('RGB')
38
-
39
-
40
- class ImageFolder(data.Dataset):
41
-
42
- def __init__(self, root, transform=None, return_paths=False,
43
- loader=default_loader):
44
- imgs = make_dataset(root)
45
- if len(imgs) == 0:
46
- raise(RuntimeError("Found 0 images in: " + root + "\n"
47
- "Supported image extensions are: " + ",".join(IMG_EXTENSIONS)))
48
-
49
- self.root = root
50
- self.imgs = imgs
51
- self.transform = transform
52
- self.return_paths = return_paths
53
- self.loader = loader
54
-
55
- def __getitem__(self, index):
56
- path = self.imgs[index]
57
- img = self.loader(path)
58
- if self.transform is not None:
59
- img = self.transform(img)
60
- if self.return_paths:
61
- return img, path
62
- else:
63
- return img
64
-
65
- def __len__(self):
66
- return len(self.imgs)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/face3d/data/template_dataset.py DELETED
@@ -1,75 +0,0 @@
1
- """Dataset class template
2
-
3
- This module provides a template for users to implement custom datasets.
4
- You can specify '--dataset_mode template' to use this dataset.
5
- The class name should be consistent with both the filename and its dataset_mode option.
6
- The filename should be <dataset_mode>_dataset.py
7
- The class name should be <Dataset_mode>Dataset.py
8
- You need to implement the following functions:
9
- -- <modify_commandline_options>: Add dataset-specific options and rewrite default values for existing options.
10
- -- <__init__>: Initialize this dataset class.
11
- -- <__getitem__>: Return a data point and its metadata information.
12
- -- <__len__>: Return the number of images.
13
- """
14
- from data.base_dataset import BaseDataset, get_transform
15
- # from data.image_folder import make_dataset
16
- # from PIL import Image
17
-
18
-
19
- class TemplateDataset(BaseDataset):
20
- """A template dataset class for you to implement custom datasets."""
21
- @staticmethod
22
- def modify_commandline_options(parser, is_train):
23
- """Add new dataset-specific options, and rewrite default values for existing options.
24
-
25
- Parameters:
26
- parser -- original option parser
27
- is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.
28
-
29
- Returns:
30
- the modified parser.
31
- """
32
- parser.add_argument('--new_dataset_option', type=float, default=1.0, help='new dataset option')
33
- parser.set_defaults(max_dataset_size=10, new_dataset_option=2.0) # specify dataset-specific default values
34
- return parser
35
-
36
- def __init__(self, opt):
37
- """Initialize this dataset class.
38
-
39
- Parameters:
40
- opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
41
-
42
- A few things can be done here.
43
- - save the options (have been done in BaseDataset)
44
- - get image paths and meta information of the dataset.
45
- - define the image transformation.
46
- """
47
- # save the option and dataset root
48
- BaseDataset.__init__(self, opt)
49
- # get the image paths of your dataset;
50
- self.image_paths = [] # You can call sorted(make_dataset(self.root, opt.max_dataset_size)) to get all the image paths under the directory self.root
51
- # define the default transform function. You can use <base_dataset.get_transform>; You can also define your custom transform function
52
- self.transform = get_transform(opt)
53
-
54
- def __getitem__(self, index):
55
- """Return a data point and its metadata information.
56
-
57
- Parameters:
58
- index -- a random integer for data indexing
59
-
60
- Returns:
61
- a dictionary of data with their names. It usually contains the data itself and its metadata information.
62
-
63
- Step 1: get a random image path: e.g., path = self.image_paths[index]
64
- Step 2: load your data from the disk: e.g., image = Image.open(path).convert('RGB').
65
- Step 3: convert your data to a PyTorch tensor. You can use helpder functions such as self.transform. e.g., data = self.transform(image)
66
- Step 4: return a data point as a dictionary.
67
- """
68
- path = 'temp' # needs to be a string
69
- data_A = None # needs to be a tensor
70
- data_B = None # needs to be a tensor
71
- return {'data_A': data_A, 'data_B': data_B, 'path': path}
72
-
73
- def __len__(self):
74
- """Return the total number of images."""
75
- return len(self.image_paths)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/face3d/extract_kp_videos.py DELETED
@@ -1,108 +0,0 @@
1
- import os
2
- import cv2
3
- import time
4
- import glob
5
- import argparse
6
- import face_alignment
7
- import numpy as np
8
- from PIL import Image
9
- from tqdm import tqdm
10
- from itertools import cycle
11
-
12
- from torch.multiprocessing import Pool, Process, set_start_method
13
-
14
- class KeypointExtractor():
15
- def __init__(self, device):
16
- self.detector = face_alignment.FaceAlignment(face_alignment.LandmarksType.TWO_D,
17
- device=device)
18
-
19
- def extract_keypoint(self, images, name=None, info=True):
20
- if isinstance(images, list):
21
- keypoints = []
22
- if info:
23
- i_range = tqdm(images,desc='landmark Det:')
24
- else:
25
- i_range = images
26
-
27
- for image in i_range:
28
- current_kp = self.extract_keypoint(image)
29
- if np.mean(current_kp) == -1 and keypoints:
30
- keypoints.append(keypoints[-1])
31
- else:
32
- keypoints.append(current_kp[None])
33
-
34
- keypoints = np.concatenate(keypoints, 0)
35
- np.savetxt(os.path.splitext(name)[0]+'.txt', keypoints.reshape(-1))
36
- return keypoints
37
- else:
38
- while True:
39
- try:
40
- keypoints = self.detector.get_landmarks_from_image(np.array(images))[0]
41
- break
42
- except RuntimeError as e:
43
- if str(e).startswith('CUDA'):
44
- print("Warning: out of memory, sleep for 1s")
45
- time.sleep(1)
46
- else:
47
- print(e)
48
- break
49
- except TypeError:
50
- print('No face detected in this image')
51
- shape = [68, 2]
52
- keypoints = -1. * np.ones(shape)
53
- break
54
- if name is not None:
55
- np.savetxt(os.path.splitext(name)[0]+'.txt', keypoints.reshape(-1))
56
- return keypoints
57
-
58
- def read_video(filename):
59
- frames = []
60
- cap = cv2.VideoCapture(filename)
61
- while cap.isOpened():
62
- ret, frame = cap.read()
63
- if ret:
64
- frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
65
- frame = Image.fromarray(frame)
66
- frames.append(frame)
67
- else:
68
- break
69
- cap.release()
70
- return frames
71
-
72
- def run(data):
73
- filename, opt, device = data
74
- os.environ['CUDA_VISIBLE_DEVICES'] = device
75
- kp_extractor = KeypointExtractor()
76
- images = read_video(filename)
77
- name = filename.split('/')[-2:]
78
- os.makedirs(os.path.join(opt.output_dir, name[-2]), exist_ok=True)
79
- kp_extractor.extract_keypoint(
80
- images,
81
- name=os.path.join(opt.output_dir, name[-2], name[-1])
82
- )
83
-
84
- if __name__ == '__main__':
85
- set_start_method('spawn')
86
- parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
87
- parser.add_argument('--input_dir', type=str, help='the folder of the input files')
88
- parser.add_argument('--output_dir', type=str, help='the folder of the output files')
89
- parser.add_argument('--device_ids', type=str, default='0,1')
90
- parser.add_argument('--workers', type=int, default=4)
91
-
92
- opt = parser.parse_args()
93
- filenames = list()
94
- VIDEO_EXTENSIONS_LOWERCASE = {'mp4'}
95
- VIDEO_EXTENSIONS = VIDEO_EXTENSIONS_LOWERCASE.union({f.upper() for f in VIDEO_EXTENSIONS_LOWERCASE})
96
- extensions = VIDEO_EXTENSIONS
97
-
98
- for ext in extensions:
99
- os.listdir(f'{opt.input_dir}')
100
- print(f'{opt.input_dir}/*.{ext}')
101
- filenames = sorted(glob.glob(f'{opt.input_dir}/*.{ext}'))
102
- print('Total number of videos:', len(filenames))
103
- pool = Pool(opt.workers)
104
- args_list = cycle([opt])
105
- device_ids = opt.device_ids.split(",")
106
- device_ids = cycle(device_ids)
107
- for data in tqdm(pool.imap_unordered(run, zip(filenames, args_list, device_ids))):
108
- None
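For reference, a short sketch of how the `KeypointExtractor` above might be used directly on a video, outside the multiprocessing entry point; the paths are placeholders and the output directory is assumed to exist.

```python
# Hypothetical direct use of the KeypointExtractor defined above.
extractor = KeypointExtractor(device='cuda')
frames = read_video('/path/to/video.mp4')               # list of PIL frames
keypoints = extractor.extract_keypoint(
    frames, name='/path/to/output/video.mp4')           # also writes video.txt next to it
print(keypoints.shape)                                   # (num_frames, 68, 2)
```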
src/face3d/extract_kp_videos_safe.py DELETED
@@ -1,138 +0,0 @@
1
- import os
2
- import cv2
3
- import time
4
- import glob
5
- import argparse
6
- import numpy as np
7
- from PIL import Image
8
- import torch
9
- from tqdm import tqdm
10
- from itertools import cycle
11
- from facexlib.alignment import init_alignment_model, landmark_98_to_68
12
- from facexlib.detection import init_detection_model
13
- from torch.multiprocessing import Pool, Process, set_start_method
14
-
15
-
16
- class KeypointExtractor():
17
- def __init__(self, device='cuda'):
18
-
19
- ### gfpgan/weights
20
- try:
21
- import webui # in webui
22
- root_path = 'extensions/SadTalker/gfpgan/weights'
23
-
24
- except:
25
- root_path = 'gfpgan/weights'
26
-
27
- self.detector = init_alignment_model('awing_fan',device=device, model_rootpath=root_path)
28
- self.det_net = init_detection_model('retinaface_resnet50', half=False,device=device, model_rootpath=root_path)
29
-
30
- def extract_keypoint(self, images, name=None, info=True):
31
- if isinstance(images, list):
32
- keypoints = []
33
- if info:
34
- i_range = tqdm(images,desc='landmark Det:')
35
- else:
36
- i_range = images
37
-
38
- for image in i_range:
39
- current_kp = self.extract_keypoint(image)
40
- # current_kp = self.detector.get_landmarks(np.array(image))
41
- if np.mean(current_kp) == -1 and keypoints:
42
- keypoints.append(keypoints[-1])
43
- else:
44
- keypoints.append(current_kp[None])
45
-
46
- keypoints = np.concatenate(keypoints, 0)
47
- np.savetxt(os.path.splitext(name)[0]+'.txt', keypoints.reshape(-1))
48
- return keypoints
49
- else:
50
- while True:
51
- try:
52
- with torch.no_grad():
53
- # face detection -> face alignment.
54
- img = np.array(images)
55
- bboxes = self.det_net.detect_faces(images, 0.97)
56
-
57
- bboxes = bboxes[0]
58
-
59
- # bboxes[0] -= 100
60
- # bboxes[1] -= 100
61
- # bboxes[2] += 100
62
- # bboxes[3] += 100
63
- img = img[int(bboxes[1]):int(bboxes[3]), int(bboxes[0]):int(bboxes[2]), :]
64
-
65
- keypoints = landmark_98_to_68(self.detector.get_landmarks(img)) # [0]
66
-
67
- #### keypoints to the original location
68
- keypoints[:,0] += int(bboxes[0])
69
- keypoints[:,1] += int(bboxes[1])
70
-
71
- break
72
- except RuntimeError as e:
73
- if str(e).startswith('CUDA'):
74
- print("Warning: out of memory, sleep for 1s")
75
- time.sleep(1)
76
- else:
77
- print(e)
78
- break
79
- except TypeError:
80
- print('No face detected in this image')
81
- shape = [68, 2]
82
- keypoints = -1. * np.ones(shape)
83
- break
84
- if name is not None:
85
- np.savetxt(os.path.splitext(name)[0]+'.txt', keypoints.reshape(-1))
86
- return keypoints
87
-
88
- def read_video(filename):
89
- frames = []
90
- cap = cv2.VideoCapture(filename)
91
- while cap.isOpened():
92
- ret, frame = cap.read()
93
- if ret:
94
- frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
95
- frame = Image.fromarray(frame)
96
- frames.append(frame)
97
- else:
98
- break
99
- cap.release()
100
- return frames
101
-
102
- def run(data):
103
- filename, opt, device = data
104
- os.environ['CUDA_VISIBLE_DEVICES'] = device
105
- kp_extractor = KeypointExtractor()
106
- images = read_video(filename)
107
- name = filename.split('/')[-2:]
108
- os.makedirs(os.path.join(opt.output_dir, name[-2]), exist_ok=True)
109
- kp_extractor.extract_keypoint(
110
- images,
111
- name=os.path.join(opt.output_dir, name[-2], name[-1])
112
- )
113
-
114
- if __name__ == '__main__':
115
- set_start_method('spawn')
116
- parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
117
- parser.add_argument('--input_dir', type=str, help='the folder of the input files')
118
- parser.add_argument('--output_dir', type=str, help='the folder of the output files')
119
- parser.add_argument('--device_ids', type=str, default='0,1')
120
- parser.add_argument('--workers', type=int, default=4)
121
-
122
- opt = parser.parse_args()
123
- filenames = list()
124
- VIDEO_EXTENSIONS_LOWERCASE = {'mp4'}
125
- VIDEO_EXTENSIONS = VIDEO_EXTENSIONS_LOWERCASE.union({f.upper() for f in VIDEO_EXTENSIONS_LOWERCASE})
126
- extensions = VIDEO_EXTENSIONS
127
-
128
- for ext in extensions:
129
- os.listdir(f'{opt.input_dir}')
130
- print(f'{opt.input_dir}/*.{ext}')
131
- filenames = sorted(glob.glob(f'{opt.input_dir}/*.{ext}'))
132
- print('Total number of videos:', len(filenames))
133
- pool = Pool(opt.workers)
134
- args_list = cycle([opt])
135
- device_ids = opt.device_ids.split(",")
136
- device_ids = cycle(device_ids)
137
- for data in tqdm(pool.imap_unordered(run, zip(filenames, args_list, device_ids))):
138
- None
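And a single-image sketch for this facexlib-based variant; the image path is a placeholder and the detection/alignment weights are assumed to have been downloaded into `gfpgan/weights`.

```python
# Hypothetical single-image use of the facexlib-based extractor above.
from PIL import Image

extractor = KeypointExtractor(device='cuda')
image = Image.open('/path/to/face.png').convert('RGB')
kps = extractor.extract_keypoint(image)   # detect face -> 98-pt alignment -> 68-pt landmarks
print(kps.shape)                          # (68, 2); all -1 if no face was detected
```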
src/face3d/models/__init__.py DELETED
@@ -1,67 +0,0 @@
1
- """This package contains modules related to objective functions, optimizations, and network architectures.
2
-
3
- To add a custom model class called 'dummy', you need to add a file called 'dummy_model.py' and define a subclass DummyModel inherited from BaseModel.
4
- You need to implement the following five functions:
5
- -- <__init__>: initialize the class; first call BaseModel.__init__(self, opt).
6
- -- <set_input>: unpack data from dataset and apply preprocessing.
7
- -- <forward>: produce intermediate results.
8
- -- <optimize_parameters>: calculate loss, gradients, and update network weights.
9
- -- <modify_commandline_options>: (optionally) add model-specific options and set default options.
10
-
11
- In the function <__init__>, you need to define four lists:
12
- -- self.loss_names (str list): specify the training losses that you want to plot and save.
13
- -- self.model_names (str list): define networks used in our training.
14
- -- self.visual_names (str list): specify the images that you want to display and save.
15
- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for a usage example.
16
-
17
- Now you can use the model class by specifying flag '--model dummy'.
18
- See our template model class 'template_model.py' for more details.
19
- """
20
-
21
- import importlib
22
- from src.face3d.models.base_model import BaseModel
23
-
24
-
25
- def find_model_using_name(model_name):
26
- """Import the module "models/[model_name]_model.py".
27
-
28
- In the file, the class called [ModelName]Model() will
29
- be instantiated. It has to be a subclass of BaseModel,
30
- and the match is case-insensitive.
31
- """
32
- model_filename = "face3d.models." + model_name + "_model"
33
- modellib = importlib.import_module(model_filename)
34
- model = None
35
- target_model_name = model_name.replace('_', '') + 'model'
36
- for name, cls in modellib.__dict__.items():
37
- if name.lower() == target_model_name.lower() \
38
- and issubclass(cls, BaseModel):
39
- model = cls
40
-
41
- if model is None:
42
- print("In %s.py, there should be a subclass of BaseModel with class name that matches %s in lowercase." % (model_filename, target_model_name))
43
- exit(0)
44
-
45
- return model
46
-
47
-
48
- def get_option_setter(model_name):
49
- """Return the static method <modify_commandline_options> of the model class."""
50
- model_class = find_model_using_name(model_name)
51
- return model_class.modify_commandline_options
52
-
53
-
54
- def create_model(opt):
55
- """Create a model given the option.
56
-
57
- This function instantiates the model class specified by opt.model.
58
- This is the main interface between this package and 'train.py'/'test.py'
59
-
60
- Example:
61
- >>> from models import create_model
62
- >>> model = create_model(opt)
63
- """
64
- model = find_model_using_name(opt.model)
65
- instance = model(opt)
66
- print("model [%s] was created" % type(instance).__name__)
67
- return instance
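The package docstring above lists the five functions and four attribute lists a custom model must define, but no skeleton is shown. A minimal, hypothetical sketch of what `dummy_model.py` could look like (the `modify_commandline_options` signature is assumed from the standard CycleGAN-style template):

```python
# Hypothetical skeleton of src/face3d/models/dummy_model.py following the
# package docstring above; all losses, networks, and optimizers are placeholders.
from src.face3d.models.base_model import BaseModel


class DummyModel(BaseModel):
    @staticmethod
    def modify_commandline_options(parser, is_train=True):
        return parser                      # optionally add model-specific options

    def __init__(self, opt):
        BaseModel.__init__(self, opt)      # first call the BaseModel constructor
        self.loss_names = []               # training losses to plot and save
        self.model_names = []              # networks used in training
        self.visual_names = []             # images to display and save
        self.optimizers = []               # one optimizer per network (or chained)

    def set_input(self, input):
        pass                               # unpack data and apply preprocessing

    def forward(self):
        pass                               # produce intermediate results

    def optimize_parameters(self):
        pass                               # compute losses, backprop, update weights
```

With such a file in place, the model would be selected via `--model dummy`, matching the lookup in `find_model_using_name`.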
src/face3d/models/__pycache__/__init__.cpython-38.pyc DELETED
Binary file (3.23 kB)
 
src/face3d/models/__pycache__/__init__.cpython-39.pyc DELETED
Binary file (3.23 kB)
 
src/face3d/models/__pycache__/base_model.cpython-38.pyc DELETED
Binary file (12.4 kB)
 
src/face3d/models/__pycache__/base_model.cpython-39.pyc DELETED
Binary file (12.4 kB)
 
src/face3d/models/__pycache__/networks.cpython-38.pyc DELETED
Binary file (17.1 kB)
 
src/face3d/models/__pycache__/networks.cpython-39.pyc DELETED
Binary file (17.1 kB)
 
src/face3d/models/arcface_torch/README.md DELETED
@@ -1,164 +0,0 @@
1
- # Distributed Arcface Training in Pytorch
2
-
3
- This is a deep learning library that makes face recognition efficient and effective, and that can train tens of
4
- millions of identities on a single server.
5
-
6
- ## Requirements
7
-
8
- Install [pytorch](http://pytorch.org) (torch>=1.6.0); see our doc [install.md](docs/install.md).
9
- - `pip install -r requirements.txt`.
10
- - Download the dataset
11
- from [https://github.com/deepinsight/insightface/tree/master/recognition/_datasets_](https://github.com/deepinsight/insightface/tree/master/recognition/_datasets_)
12
- .
13
-
14
- ## How to Train
15
-
16
- To train a model, run `train.py` with the path to the configs:
17
-
18
- ### 1. Single node, 8 GPUs:
19
-
20
- ```shell
21
- python -m torch.distributed.launch --nproc_per_node=8 --nnodes=1 --node_rank=0 --master_addr="127.0.0.1" --master_port=1234 train.py configs/ms1mv3_r50
22
- ```
23
-
24
- ### 2. Multiple nodes, each node 8 GPUs:
25
-
26
- Node 0:
27
-
28
- ```shell
29
- python -m torch.distributed.launch --nproc_per_node=8 --nnodes=2 --node_rank=0 --master_addr="ip1" --master_port=1234 train.py configs/ms1mv3_r50
30
- ```
31
-
32
- Node 1:
33
-
34
- ```shell
35
- python -m torch.distributed.launch --nproc_per_node=8 --nnodes=2 --node_rank=1 --master_addr="ip1" --master_port=1234 train.py configs/ms1mv3_r50
36
- ```
37
-
38
- ### 3. Training resnet2060 with 8 GPUs:
39
-
40
- ```shell
41
- python -m torch.distributed.launch --nproc_per_node=8 --nnodes=1 --node_rank=0 --master_addr="127.0.0.1" --master_port=1234 train.py configs/ms1mv3_r2060.py
42
- ```
43
-
44
- ## Model Zoo
45
-
46
- - The models are available for non-commercial research purposes only.
47
- - All models can be found in here.
48
- - [Baidu Yun Pan](https://pan.baidu.com/s/1CL-l4zWqsI1oDuEEYVhj-g): e8pw
49
- - [onedrive](https://1drv.ms/u/s!AswpsDO2toNKq0lWY69vN58GR6mw?e=p9Ov5d)
50
-
51
- ### Performance on [**ICCV2021-MFR**](http://iccv21-mfr.com/)
52
-
53
- The ICCV2021-MFR test set consists of non-celebrities, so we can ensure that it has very little overlap with publicly
54
- available face recognition training sets such as MS1M and CASIA, which were mostly collected from online celebrities.
55
- As a result, we can evaluate the fair performance of different algorithms.
56
-
57
- For the **ICCV2021-MFR-ALL** set, TAR is measured under the all-to-all 1:1 protocol, with FAR less than 0.000001 (1e-6). The
58
- globalised multi-racial test set contains 242,143 identities and 1,624,305 images.
59
-
60
- For the **ICCV2021-MFR-MASK** set, TAR is measured under the mask-to-nonmask 1:1 protocol, with FAR less than 0.0001 (1e-4).
61
- The mask test set contains 6,964 identities, 6,964 masked images, and 13,928 non-masked images.
62
- In total, there are 13,928 positive pairs and 96,983,824 negative pairs.
63
-
64
- | Datasets | backbone | Training throughput | Size / MB | **ICCV2021-MFR-MASK** | **ICCV2021-MFR-ALL** |
65
- | :---: | :--- | :--- | :--- |:--- |:--- |
66
- | MS1MV3 | r18 | - | 91 | **47.85** | **68.33** |
67
- | Glint360k | r18 | 8536 | 91 | **53.32** | **72.07** |
68
- | MS1MV3 | r34 | - | 130 | **58.72** | **77.36** |
69
- | Glint360k | r34 | 6344 | 130 | **65.10** | **83.02** |
70
- | MS1MV3 | r50 | 5500 | 166 | **63.85** | **80.53** |
71
- | Glint360k | r50 | 5136 | 166 | **70.23** | **87.08** |
72
- | MS1MV3 | r100 | - | 248 | **69.09** | **84.31** |
73
- | Glint360k | r100 | 3332 | 248 | **75.57** | **90.66** |
74
- | MS1MV3 | mobilefacenet | 12185 | 7.8 | **41.52** | **65.26** |
75
- | Glint360k | mobilefacenet | 11197 | 7.8 | **44.52** | **66.48** |
76
-
77
- ### Performance on IJB-C and Verification Datasets
78
-
79
- | Datasets | backbone | IJBC(1e-05) | IJBC(1e-04) | agedb30 | cfp_fp | lfw | log |
80
- | :---: | :--- | :--- | :--- | :--- |:--- |:--- |:--- |
81
- | MS1MV3 | r18 | 92.07 | 94.66 | 97.77 | 97.73 | 99.77 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/ms1mv3_arcface_r18_fp16/training.log)|
82
- | MS1MV3 | r34 | 94.10 | 95.90 | 98.10 | 98.67 | 99.80 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/ms1mv3_arcface_r34_fp16/training.log)|
83
- | MS1MV3 | r50 | 94.79 | 96.46 | 98.35 | 98.96 | 99.83 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/ms1mv3_arcface_r50_fp16/training.log)|
84
- | MS1MV3 | r100 | 95.31 | 96.81 | 98.48 | 99.06 | 99.85 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/ms1mv3_arcface_r100_fp16/training.log)|
85
- | MS1MV3 | **r2060**| 95.34 | 97.11 | 98.67 | 99.24 | 99.87 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/ms1mv3_arcface_r2060_fp16/training.log)|
86
- | Glint360k |r18-0.1 | 93.16 | 95.33 | 97.72 | 97.73 | 99.77 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/glint360k_cosface_r18_fp16_0.1/training.log)|
87
- | Glint360k |r34-0.1 | 95.16 | 96.56 | 98.33 | 98.78 | 99.82 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/glint360k_cosface_r34_fp16_0.1/training.log)|
88
- | Glint360k |r50-0.1 | 95.61 | 96.97 | 98.38 | 99.20 | 99.83 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/glint360k_cosface_r50_fp16_0.1/training.log)|
89
- | Glint360k |r100-0.1 | 95.88 | 97.32 | 98.48 | 99.29 | 99.82 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/glint360k_cosface_r100_fp16_0.1/training.log)|
90
-
91
- [comment]: <> (More details see [model.md]&#40;docs/modelzoo.md&#41; in docs.)
92
-
93
-
94
- ## [Speed Benchmark](docs/speed_benchmark.md)
95
-
96
- **Arcface Torch** can train large-scale face recognition training sets efficiently and quickly. When the number of
97
- classes in the training set is greater than 300K and training is sufficient, the partial FC sampling strategy reaches the
98
- same accuracy with several times faster training and a smaller GPU memory footprint.
99
- Partial FC is a sparse variant of the model-parallel architecture for large-scale face recognition. It uses a
100
- sparse softmax in which each batch dynamically samples a subset of class centers for training. In each iteration, only a
101
- sparse part of the parameters is updated, which greatly reduces GPU memory and computation. With Partial FC,
102
- we can scale to a training set of 29 million identities, the largest to date. Partial FC also supports multi-machine distributed
103
- training and mixed-precision training (a minimal sketch of the class-center sampling idea follows this README).
104
-
105
- ![Image text](https://github.com/anxiangsir/insightface_arcface_log/blob/master/partial_fc_v2.png)
106
-
107
- For more details, see
108
- [speed_benchmark.md](docs/speed_benchmark.md) in docs.
109
-
110
- ### 1. Training speed of different parallel methods (samples / second), Tesla V100 32GB * 8. (Larger is better)
111
-
112
- `-` means training failed because of GPU memory limitations.
113
-
114
- | Number of Identities in Dataset | Data Parallel | Model Parallel | Partial FC 0.1 |
115
- | :--- | :--- | :--- | :--- |
116
- |125000 | 4681 | 4824 | 5004 |
117
- |1400000 | **1672** | 3043 | 4738 |
118
- |5500000 | **-** | **1389** | 3975 |
119
- |8000000 | **-** | **-** | 3565 |
120
- |16000000 | **-** | **-** | 2679 |
121
- |29000000 | **-** | **-** | **1855** |
122
-
123
- ### 2. GPU memory cost of different parallel methods (MB per GPU), Tesla V100 32GB * 8. (Smaller is better)
124
-
125
- | Number of Identities in Dataset | Data Parallel | Model Parallel | Partial FC 0.1 |
126
- | :--- | :--- | :--- | :--- |
127
- |125000 | 7358 | 5306 | 4868 |
128
- |1400000 | 32252 | 11178 | 6056 |
129
- |5500000 | **-** | 32188 | 9854 |
130
- |8000000 | **-** | **-** | 12310 |
131
- |16000000 | **-** | **-** | 19950 |
132
- |29000000 | **-** | **-** | 32324 |
133
-
134
- ## Evaluation ICCV2021-MFR and IJB-C
135
-
136
- For more details, see [eval.md](docs/eval.md) in docs.
137
-
138
- ## Test
139
-
140
- We tested many versions of PyTorch. Please create an issue if you are having trouble.
141
-
142
- - [x] torch 1.6.0
143
- - [x] torch 1.7.1
144
- - [x] torch 1.8.0
145
- - [x] torch 1.9.0
146
-
147
- ## Citation
148
-
149
- ```
150
- @inproceedings{deng2019arcface,
151
- title={Arcface: Additive angular margin loss for deep face recognition},
152
- author={Deng, Jiankang and Guo, Jia and Xue, Niannan and Zafeiriou, Stefanos},
153
- booktitle={Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition},
154
- pages={4690--4699},
155
- year={2019}
156
- }
157
- @inproceedings{an2020partical_fc,
158
- title={Partial FC: Training 10 Million Identities on a Single Machine},
159
- author={An, Xiang and Zhu, Xuhan and Xiao, Yang and Wu, Lan and Zhang, Ming and Gao, Yuan and Qin, Bin and
160
- Zhang, Debing and Fu Ying},
161
- booktitle={Arxiv 2010.05222},
162
- year={2020}
163
- }
164
- ```
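The Partial FC paragraph in the Speed Benchmark section above describes sampling a subset of class centers per batch. The following is an illustrative, self-contained sketch of that sampling idea in plain PyTorch; it is not the repository's actual `PartialFC` implementation, and the class count and 0.1 sample rate are only examples.

```python
# Illustrative sketch of the Partial FC idea: keep the centers of the classes
# present in the batch plus a random subset of the others, and compute the
# softmax only over that subset. NOT the repository's PartialFC module.
import torch
import torch.nn.functional as F

num_classes, emb_dim, sample_rate = 10_000, 512, 0.1
weight = torch.randn(num_classes, emb_dim, requires_grad=True)   # full class-center matrix

def partial_fc_logits(embeddings, labels):
    # assumes the number of distinct classes in the batch is below num_sample
    num_sample = int(sample_rate * num_classes)
    positive = torch.unique(labels)                    # classes present in this batch (sorted)
    mask = torch.ones(num_classes, dtype=torch.bool)
    mask[positive] = False
    negative = torch.nonzero(mask).squeeze(1)
    perm = torch.randperm(negative.numel())[: num_sample - positive.numel()]
    index = torch.cat([positive, negative[perm]])      # sampled class indices
    sub_weight = weight[index]                         # only these rows receive gradients
    new_labels = torch.searchsorted(positive, labels)  # remap labels into the sampled set
    logits = embeddings @ sub_weight.t()               # (batch, num_sample) instead of (batch, num_classes)
    return logits, new_labels

embeddings = torch.randn(8, emb_dim)
labels = torch.randint(0, num_classes, (8,))
logits, new_labels = partial_fc_logits(embeddings, labels)
loss = F.cross_entropy(logits, new_labels)             # train as usual on the sampled subset
```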
src/face3d/models/arcface_torch/backbones/__init__.py DELETED
@@ -1,25 +0,0 @@
1
- from .iresnet import iresnet18, iresnet34, iresnet50, iresnet100, iresnet200
2
- from .mobilefacenet import get_mbf
3
-
4
-
5
- def get_model(name, **kwargs):
6
- # resnet
7
- if name == "r18":
8
- return iresnet18(False, **kwargs)
9
- elif name == "r34":
10
- return iresnet34(False, **kwargs)
11
- elif name == "r50":
12
- return iresnet50(False, **kwargs)
13
- elif name == "r100":
14
- return iresnet100(False, **kwargs)
15
- elif name == "r200":
16
- return iresnet200(False, **kwargs)
17
- elif name == "r2060":
18
- from .iresnet2060 import iresnet2060
19
- return iresnet2060(False, **kwargs)
20
- elif name == "mbf":
21
- fp16 = kwargs.get("fp16", False)
22
- num_features = kwargs.get("num_features", 512)
23
- return get_mbf(fp16=fp16, num_features=num_features)
24
- else:
25
- raise ValueError()
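A quick usage sketch of the `get_model` factory above; the 112x112 input size is the usual ArcFace aligned-crop size and is an assumption here, consistent with the 7x7 `fc_scale` of the backbones.

```python
# Hypothetical usage of the get_model factory above.
import torch

net = get_model("r50", fp16=False, num_features=512)   # builds the iresnet50 backbone
net.eval()
with torch.no_grad():
    feat = net(torch.randn(1, 3, 112, 112))            # aligned 112x112 face crop (assumed)
print(feat.shape)                                       # torch.Size([1, 512])
```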
src/face3d/models/arcface_torch/backbones/__pycache__/__init__.cpython-38.pyc DELETED
Binary file (823 Bytes)
 
src/face3d/models/arcface_torch/backbones/__pycache__/__init__.cpython-39.pyc DELETED
Binary file (847 Bytes)
 
src/face3d/models/arcface_torch/backbones/__pycache__/iresnet.cpython-38.pyc DELETED
Binary file (5.39 kB)
 
src/face3d/models/arcface_torch/backbones/__pycache__/iresnet.cpython-39.pyc DELETED
Binary file (5.47 kB)
 
src/face3d/models/arcface_torch/backbones/__pycache__/mobilefacenet.cpython-38.pyc DELETED
Binary file (5.45 kB)
 
src/face3d/models/arcface_torch/backbones/__pycache__/mobilefacenet.cpython-39.pyc DELETED
Binary file (5.47 kB)
 
src/face3d/models/arcface_torch/backbones/iresnet.py DELETED
@@ -1,187 +0,0 @@
1
- import torch
2
- from torch import nn
3
-
4
- __all__ = ['iresnet18', 'iresnet34', 'iresnet50', 'iresnet100', 'iresnet200']
5
-
6
-
7
- def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
8
- """3x3 convolution with padding"""
9
- return nn.Conv2d(in_planes,
10
- out_planes,
11
- kernel_size=3,
12
- stride=stride,
13
- padding=dilation,
14
- groups=groups,
15
- bias=False,
16
- dilation=dilation)
17
-
18
-
19
- def conv1x1(in_planes, out_planes, stride=1):
20
- """1x1 convolution"""
21
- return nn.Conv2d(in_planes,
22
- out_planes,
23
- kernel_size=1,
24
- stride=stride,
25
- bias=False)
26
-
27
-
28
- class IBasicBlock(nn.Module):
29
- expansion = 1
30
- def __init__(self, inplanes, planes, stride=1, downsample=None,
31
- groups=1, base_width=64, dilation=1):
32
- super(IBasicBlock, self).__init__()
33
- if groups != 1 or base_width != 64:
34
- raise ValueError('BasicBlock only supports groups=1 and base_width=64')
35
- if dilation > 1:
36
- raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
37
- self.bn1 = nn.BatchNorm2d(inplanes, eps=1e-05,)
38
- self.conv1 = conv3x3(inplanes, planes)
39
- self.bn2 = nn.BatchNorm2d(planes, eps=1e-05,)
40
- self.prelu = nn.PReLU(planes)
41
- self.conv2 = conv3x3(planes, planes, stride)
42
- self.bn3 = nn.BatchNorm2d(planes, eps=1e-05,)
43
- self.downsample = downsample
44
- self.stride = stride
45
-
46
- def forward(self, x):
47
- identity = x
48
- out = self.bn1(x)
49
- out = self.conv1(out)
50
- out = self.bn2(out)
51
- out = self.prelu(out)
52
- out = self.conv2(out)
53
- out = self.bn3(out)
54
- if self.downsample is not None:
55
- identity = self.downsample(x)
56
- out += identity
57
- return out
58
-
59
-
60
- class IResNet(nn.Module):
61
- fc_scale = 7 * 7
62
- def __init__(self,
63
- block, layers, dropout=0, num_features=512, zero_init_residual=False,
64
- groups=1, width_per_group=64, replace_stride_with_dilation=None, fp16=False):
65
- super(IResNet, self).__init__()
66
- self.fp16 = fp16
67
- self.inplanes = 64
68
- self.dilation = 1
69
- if replace_stride_with_dilation is None:
70
- replace_stride_with_dilation = [False, False, False]
71
- if len(replace_stride_with_dilation) != 3:
72
- raise ValueError("replace_stride_with_dilation should be None "
73
- "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
74
- self.groups = groups
75
- self.base_width = width_per_group
76
- self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False)
77
- self.bn1 = nn.BatchNorm2d(self.inplanes, eps=1e-05)
78
- self.prelu = nn.PReLU(self.inplanes)
79
- self.layer1 = self._make_layer(block, 64, layers[0], stride=2)
80
- self.layer2 = self._make_layer(block,
81
- 128,
82
- layers[1],
83
- stride=2,
84
- dilate=replace_stride_with_dilation[0])
85
- self.layer3 = self._make_layer(block,
86
- 256,
87
- layers[2],
88
- stride=2,
89
- dilate=replace_stride_with_dilation[1])
90
- self.layer4 = self._make_layer(block,
91
- 512,
92
- layers[3],
93
- stride=2,
94
- dilate=replace_stride_with_dilation[2])
95
- self.bn2 = nn.BatchNorm2d(512 * block.expansion, eps=1e-05,)
96
- self.dropout = nn.Dropout(p=dropout, inplace=True)
97
- self.fc = nn.Linear(512 * block.expansion * self.fc_scale, num_features)
98
- self.features = nn.BatchNorm1d(num_features, eps=1e-05)
99
- nn.init.constant_(self.features.weight, 1.0)
100
- self.features.weight.requires_grad = False
101
-
102
- for m in self.modules():
103
- if isinstance(m, nn.Conv2d):
104
- nn.init.normal_(m.weight, 0, 0.1)
105
- elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
106
- nn.init.constant_(m.weight, 1)
107
- nn.init.constant_(m.bias, 0)
108
-
109
- if zero_init_residual:
110
- for m in self.modules():
111
- if isinstance(m, IBasicBlock):
112
- nn.init.constant_(m.bn2.weight, 0)
113
-
114
- def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
115
- downsample = None
116
- previous_dilation = self.dilation
117
- if dilate:
118
- self.dilation *= stride
119
- stride = 1
120
- if stride != 1 or self.inplanes != planes * block.expansion:
121
- downsample = nn.Sequential(
122
- conv1x1(self.inplanes, planes * block.expansion, stride),
123
- nn.BatchNorm2d(planes * block.expansion, eps=1e-05, ),
124
- )
125
- layers = []
126
- layers.append(
127
- block(self.inplanes, planes, stride, downsample, self.groups,
128
- self.base_width, previous_dilation))
129
- self.inplanes = planes * block.expansion
130
- for _ in range(1, blocks):
131
- layers.append(
132
- block(self.inplanes,
133
- planes,
134
- groups=self.groups,
135
- base_width=self.base_width,
136
- dilation=self.dilation))
137
-
138
- return nn.Sequential(*layers)
139
-
140
- def forward(self, x):
141
- with torch.cuda.amp.autocast(self.fp16):
142
- x = self.conv1(x)
143
- x = self.bn1(x)
144
- x = self.prelu(x)
145
- x = self.layer1(x)
146
- x = self.layer2(x)
147
- x = self.layer3(x)
148
- x = self.layer4(x)
149
- x = self.bn2(x)
150
- x = torch.flatten(x, 1)
151
- x = self.dropout(x)
152
- x = self.fc(x.float() if self.fp16 else x)
153
- x = self.features(x)
154
- return x
155
-
156
-
157
- def _iresnet(arch, block, layers, pretrained, progress, **kwargs):
158
- model = IResNet(block, layers, **kwargs)
159
- if pretrained:
160
- raise ValueError()
161
- return model
162
-
163
-
164
- def iresnet18(pretrained=False, progress=True, **kwargs):
165
- return _iresnet('iresnet18', IBasicBlock, [2, 2, 2, 2], pretrained,
166
- progress, **kwargs)
167
-
168
-
169
- def iresnet34(pretrained=False, progress=True, **kwargs):
170
- return _iresnet('iresnet34', IBasicBlock, [3, 4, 6, 3], pretrained,
171
- progress, **kwargs)
172
-
173
-
174
- def iresnet50(pretrained=False, progress=True, **kwargs):
175
- return _iresnet('iresnet50', IBasicBlock, [3, 4, 14, 3], pretrained,
176
- progress, **kwargs)
177
-
178
-
179
- def iresnet100(pretrained=False, progress=True, **kwargs):
180
- return _iresnet('iresnet100', IBasicBlock, [3, 13, 30, 3], pretrained,
181
- progress, **kwargs)
182
-
183
-
184
- def iresnet200(pretrained=False, progress=True, **kwargs):
185
- return _iresnet('iresnet200', IBasicBlock, [6, 26, 60, 6], pretrained,
186
- progress, **kwargs)
187
-
src/face3d/models/arcface_torch/backbones/iresnet2060.py DELETED
@@ -1,176 +0,0 @@
1
- import torch
2
- from torch import nn
3
-
4
- assert torch.__version__ >= "1.8.1"
5
- from torch.utils.checkpoint import checkpoint_sequential
6
-
7
- __all__ = ['iresnet2060']
8
-
9
-
10
- def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
11
- """3x3 convolution with padding"""
12
- return nn.Conv2d(in_planes,
13
- out_planes,
14
- kernel_size=3,
15
- stride=stride,
16
- padding=dilation,
17
- groups=groups,
18
- bias=False,
19
- dilation=dilation)
20
-
21
-
22
- def conv1x1(in_planes, out_planes, stride=1):
23
- """1x1 convolution"""
24
- return nn.Conv2d(in_planes,
25
- out_planes,
26
- kernel_size=1,
27
- stride=stride,
28
- bias=False)
29
-
30
-
31
- class IBasicBlock(nn.Module):
32
- expansion = 1
33
-
34
- def __init__(self, inplanes, planes, stride=1, downsample=None,
35
- groups=1, base_width=64, dilation=1):
36
- super(IBasicBlock, self).__init__()
37
- if groups != 1 or base_width != 64:
38
- raise ValueError('BasicBlock only supports groups=1 and base_width=64')
39
- if dilation > 1:
40
- raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
41
- self.bn1 = nn.BatchNorm2d(inplanes, eps=1e-05, )
42
- self.conv1 = conv3x3(inplanes, planes)
43
- self.bn2 = nn.BatchNorm2d(planes, eps=1e-05, )
44
- self.prelu = nn.PReLU(planes)
45
- self.conv2 = conv3x3(planes, planes, stride)
46
- self.bn3 = nn.BatchNorm2d(planes, eps=1e-05, )
47
- self.downsample = downsample
48
- self.stride = stride
49
-
50
- def forward(self, x):
51
- identity = x
52
- out = self.bn1(x)
53
- out = self.conv1(out)
54
- out = self.bn2(out)
55
- out = self.prelu(out)
56
- out = self.conv2(out)
57
- out = self.bn3(out)
58
- if self.downsample is not None:
59
- identity = self.downsample(x)
60
- out += identity
61
- return out
62
-
63
-
64
- class IResNet(nn.Module):
65
- fc_scale = 7 * 7
66
-
67
- def __init__(self,
68
- block, layers, dropout=0, num_features=512, zero_init_residual=False,
69
- groups=1, width_per_group=64, replace_stride_with_dilation=None, fp16=False):
70
- super(IResNet, self).__init__()
71
- self.fp16 = fp16
72
- self.inplanes = 64
73
- self.dilation = 1
74
- if replace_stride_with_dilation is None:
75
- replace_stride_with_dilation = [False, False, False]
76
- if len(replace_stride_with_dilation) != 3:
77
- raise ValueError("replace_stride_with_dilation should be None "
78
- "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
79
- self.groups = groups
80
- self.base_width = width_per_group
81
- self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False)
82
- self.bn1 = nn.BatchNorm2d(self.inplanes, eps=1e-05)
83
- self.prelu = nn.PReLU(self.inplanes)
84
- self.layer1 = self._make_layer(block, 64, layers[0], stride=2)
85
- self.layer2 = self._make_layer(block,
86
- 128,
87
- layers[1],
88
- stride=2,
89
- dilate=replace_stride_with_dilation[0])
90
- self.layer3 = self._make_layer(block,
91
- 256,
92
- layers[2],
93
- stride=2,
94
- dilate=replace_stride_with_dilation[1])
95
- self.layer4 = self._make_layer(block,
96
- 512,
97
- layers[3],
98
- stride=2,
99
- dilate=replace_stride_with_dilation[2])
100
- self.bn2 = nn.BatchNorm2d(512 * block.expansion, eps=1e-05, )
101
- self.dropout = nn.Dropout(p=dropout, inplace=True)
102
- self.fc = nn.Linear(512 * block.expansion * self.fc_scale, num_features)
103
- self.features = nn.BatchNorm1d(num_features, eps=1e-05)
104
- nn.init.constant_(self.features.weight, 1.0)
105
- self.features.weight.requires_grad = False
106
-
107
- for m in self.modules():
108
- if isinstance(m, nn.Conv2d):
109
- nn.init.normal_(m.weight, 0, 0.1)
110
- elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
111
- nn.init.constant_(m.weight, 1)
112
- nn.init.constant_(m.bias, 0)
113
-
114
- if zero_init_residual:
115
- for m in self.modules():
116
- if isinstance(m, IBasicBlock):
117
- nn.init.constant_(m.bn2.weight, 0)
118
-
119
- def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
120
- downsample = None
121
- previous_dilation = self.dilation
122
- if dilate:
123
- self.dilation *= stride
124
- stride = 1
125
- if stride != 1 or self.inplanes != planes * block.expansion:
126
- downsample = nn.Sequential(
127
- conv1x1(self.inplanes, planes * block.expansion, stride),
128
- nn.BatchNorm2d(planes * block.expansion, eps=1e-05, ),
129
- )
130
- layers = []
131
- layers.append(
132
- block(self.inplanes, planes, stride, downsample, self.groups,
133
- self.base_width, previous_dilation))
134
- self.inplanes = planes * block.expansion
135
- for _ in range(1, blocks):
136
- layers.append(
137
- block(self.inplanes,
138
- planes,
139
- groups=self.groups,
140
- base_width=self.base_width,
141
- dilation=self.dilation))
142
-
143
- return nn.Sequential(*layers)
144
-
145
- def checkpoint(self, func, num_seg, x):
146
- if self.training:
147
- return checkpoint_sequential(func, num_seg, x)
148
- else:
149
- return func(x)
150
-
151
- def forward(self, x):
152
- with torch.cuda.amp.autocast(self.fp16):
153
- x = self.conv1(x)
154
- x = self.bn1(x)
155
- x = self.prelu(x)
156
- x = self.layer1(x)
157
- x = self.checkpoint(self.layer2, 20, x)
158
- x = self.checkpoint(self.layer3, 100, x)
159
- x = self.layer4(x)
160
- x = self.bn2(x)
161
- x = torch.flatten(x, 1)
162
- x = self.dropout(x)
163
- x = self.fc(x.float() if self.fp16 else x)
164
- x = self.features(x)
165
- return x
166
-
167
-
168
- def _iresnet(arch, block, layers, pretrained, progress, **kwargs):
169
- model = IResNet(block, layers, **kwargs)
170
- if pretrained:
171
- raise ValueError()
172
- return model
173
-
174
-
175
- def iresnet2060(pretrained=False, progress=True, **kwargs):
176
- return _iresnet('iresnet2060', IBasicBlock, [3, 128, 1024 - 128, 3], pretrained, progress, **kwargs)
src/face3d/models/arcface_torch/backbones/mobilefacenet.py DELETED
@@ -1,130 +0,0 @@
1
- '''
2
- Adapted from https://github.com/cavalleria/cavaface.pytorch/blob/master/backbone/mobilefacenet.py
3
- Original author cavalleria
4
- '''
5
-
6
- import torch.nn as nn
7
- from torch.nn import Linear, Conv2d, BatchNorm1d, BatchNorm2d, PReLU, Sequential, Module
8
- import torch
9
-
10
-
11
- class Flatten(Module):
12
- def forward(self, x):
13
- return x.view(x.size(0), -1)
14
-
15
-
16
- class ConvBlock(Module):
17
- def __init__(self, in_c, out_c, kernel=(1, 1), stride=(1, 1), padding=(0, 0), groups=1):
18
- super(ConvBlock, self).__init__()
19
- self.layers = nn.Sequential(
20
- Conv2d(in_c, out_c, kernel, groups=groups, stride=stride, padding=padding, bias=False),
21
- BatchNorm2d(num_features=out_c),
22
- PReLU(num_parameters=out_c)
23
- )
24
-
25
- def forward(self, x):
26
- return self.layers(x)
27
-
28
-
29
- class LinearBlock(Module):
30
- def __init__(self, in_c, out_c, kernel=(1, 1), stride=(1, 1), padding=(0, 0), groups=1):
31
- super(LinearBlock, self).__init__()
32
- self.layers = nn.Sequential(
33
- Conv2d(in_c, out_c, kernel, stride, padding, groups=groups, bias=False),
34
- BatchNorm2d(num_features=out_c)
35
- )
36
-
37
- def forward(self, x):
38
- return self.layers(x)
39
-
40
-
41
- class DepthWise(Module):
42
- def __init__(self, in_c, out_c, residual=False, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=1):
43
- super(DepthWise, self).__init__()
44
- self.residual = residual
45
- self.layers = nn.Sequential(
46
- ConvBlock(in_c, out_c=groups, kernel=(1, 1), padding=(0, 0), stride=(1, 1)),
47
- ConvBlock(groups, groups, groups=groups, kernel=kernel, padding=padding, stride=stride),
48
- LinearBlock(groups, out_c, kernel=(1, 1), padding=(0, 0), stride=(1, 1))
49
- )
50
-
51
- def forward(self, x):
52
- short_cut = None
53
- if self.residual:
54
- short_cut = x
55
- x = self.layers(x)
56
- if self.residual:
57
- output = short_cut + x
58
- else:
59
- output = x
60
- return output
61
-
62
-
63
- class Residual(Module):
64
- def __init__(self, c, num_block, groups, kernel=(3, 3), stride=(1, 1), padding=(1, 1)):
65
- super(Residual, self).__init__()
66
- modules = []
67
- for _ in range(num_block):
68
- modules.append(DepthWise(c, c, True, kernel, stride, padding, groups))
69
- self.layers = Sequential(*modules)
70
-
71
- def forward(self, x):
72
- return self.layers(x)
73
-
74
-
75
- class GDC(Module):
76
- def __init__(self, embedding_size):
77
- super(GDC, self).__init__()
78
- self.layers = nn.Sequential(
79
- LinearBlock(512, 512, groups=512, kernel=(7, 7), stride=(1, 1), padding=(0, 0)),
80
- Flatten(),
81
- Linear(512, embedding_size, bias=False),
82
- BatchNorm1d(embedding_size))
83
-
84
- def forward(self, x):
85
- return self.layers(x)
86
-
87
-
88
- class MobileFaceNet(Module):
89
- def __init__(self, fp16=False, num_features=512):
90
- super(MobileFaceNet, self).__init__()
91
- scale = 2
92
- self.fp16 = fp16
93
- self.layers = nn.Sequential(
94
- ConvBlock(3, 64 * scale, kernel=(3, 3), stride=(2, 2), padding=(1, 1)),
95
- ConvBlock(64 * scale, 64 * scale, kernel=(3, 3), stride=(1, 1), padding=(1, 1), groups=64),
96
- DepthWise(64 * scale, 64 * scale, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=128),
97
- Residual(64 * scale, num_block=4, groups=128, kernel=(3, 3), stride=(1, 1), padding=(1, 1)),
98
- DepthWise(64 * scale, 128 * scale, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=256),
99
- Residual(128 * scale, num_block=6, groups=256, kernel=(3, 3), stride=(1, 1), padding=(1, 1)),
100
- DepthWise(128 * scale, 128 * scale, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=512),
101
- Residual(128 * scale, num_block=2, groups=256, kernel=(3, 3), stride=(1, 1), padding=(1, 1)),
102
- )
103
- self.conv_sep = ConvBlock(128 * scale, 512, kernel=(1, 1), stride=(1, 1), padding=(0, 0))
104
- self.features = GDC(num_features)
105
- self._initialize_weights()
106
-
107
- def _initialize_weights(self):
108
- for m in self.modules():
109
- if isinstance(m, nn.Conv2d):
110
- nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
111
- if m.bias is not None:
112
- m.bias.data.zero_()
113
- elif isinstance(m, nn.BatchNorm2d):
114
- m.weight.data.fill_(1)
115
- m.bias.data.zero_()
116
- elif isinstance(m, nn.Linear):
117
- nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
118
- if m.bias is not None:
119
- m.bias.data.zero_()
120
-
121
- def forward(self, x):
122
- with torch.cuda.amp.autocast(self.fp16):
123
- x = self.layers(x)
124
- x = self.conv_sep(x.float() if self.fp16 else x)
125
- x = self.features(x)
126
- return x
127
-
128
-
129
- def get_mbf(fp16, num_features):
130
- return MobileFaceNet(fp16, num_features)
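Similarly, a brief sanity-check sketch for the MobileFaceNet backbone above; again, the 112x112 crop size is an assumption that matches the 7x7 kernel of the GDC head.

```python
# Hypothetical sanity check of the MobileFaceNet backbone defined above.
import torch

net = get_mbf(fp16=False, num_features=512)
net.eval()
with torch.no_grad():
    emb = net(torch.randn(2, 3, 112, 112))   # two 112x112 face crops (assumed size)
print(emb.shape)                             # torch.Size([2, 512])
```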
src/face3d/models/arcface_torch/configs/3millions.py DELETED
@@ -1,23 +0,0 @@
1
- from easydict import EasyDict as edict
2
-
3
- # configs for test speed
4
-
5
- config = edict()
6
- config.loss = "arcface"
7
- config.network = "r50"
8
- config.resume = False
9
- config.output = None
10
- config.embedding_size = 512
11
- config.sample_rate = 1.0
12
- config.fp16 = True
13
- config.momentum = 0.9
14
- config.weight_decay = 5e-4
15
- config.batch_size = 128
16
- config.lr = 0.1 # batch size is 512
17
-
18
- config.rec = "synthetic"
19
- config.num_classes = 300 * 10000
20
- config.num_epoch = 30
21
- config.warmup_epoch = -1
22
- config.decay_epoch = [10, 16, 22]
23
- config.val_targets = []
src/face3d/models/arcface_torch/configs/3millions_pfc.py DELETED
@@ -1,23 +0,0 @@
1
- from easydict import EasyDict as edict
2
-
3
- # configs for test speed
4
-
5
- config = edict()
6
- config.loss = "arcface"
7
- config.network = "r50"
8
- config.resume = False
9
- config.output = None
10
- config.embedding_size = 512
11
- config.sample_rate = 0.1
12
- config.fp16 = True
13
- config.momentum = 0.9
14
- config.weight_decay = 5e-4
15
- config.batch_size = 128
16
- config.lr = 0.1 # batch size is 512
17
-
18
- config.rec = "synthetic"
19
- config.num_classes = 300 * 10000
20
- config.num_epoch = 30
21
- config.warmup_epoch = -1
22
- config.decay_epoch = [10, 16, 22]
23
- config.val_targets = []
src/face3d/models/arcface_torch/configs/__init__.py DELETED
File without changes