Reevee committed
Commit f39e999 (0 parents)
This view is limited to 50 files because the commit contains too many changes.
Files changed (50)
  1. .gitattributes +2 -0
  2. README.md +11 -0
  3. app.py +31 -0
  4. configs/__init__.py +0 -0
  5. configs/data_configs.py +41 -0
  6. configs/paths_config.py +23 -0
  7. configs/transforms_config.py +152 -0
  8. criteria/__init__.py +0 -0
  9. criteria/id_loss.py +44 -0
  10. criteria/lpips/__init__.py +0 -0
  11. criteria/lpips/lpips.py +35 -0
  12. criteria/lpips/networks.py +96 -0
  13. criteria/lpips/utils.py +30 -0
  14. criteria/moco_loss.py +69 -0
  15. criteria/w_norm.py +14 -0
  16. datasets/__init__.py +0 -0
  17. datasets/augmentations.py +110 -0
  18. datasets/gt_res_dataset.py +32 -0
  19. datasets/images_dataset.py +33 -0
  20. datasets/inference_dataset.py +22 -0
  21. dnnlib/__init__.py +9 -0
  22. dnnlib/util.py +477 -0
  23. legacy.py +384 -0
  24. model_build.py +95 -0
  25. models/__init__.py +0 -0
  26. models/encoders/__init__.py +0 -0
  27. models/encoders/helpers.py +119 -0
  28. models/encoders/model_irse.py +84 -0
  29. models/encoders/psp_encoders.py +186 -0
  30. models/mtcnn/__init__.py +0 -0
  31. models/mtcnn/mtcnn.py +156 -0
  32. models/mtcnn/mtcnn_pytorch/__init__.py +0 -0
  33. models/mtcnn/mtcnn_pytorch/src/__init__.py +2 -0
  34. models/mtcnn/mtcnn_pytorch/src/align_trans.py +304 -0
  35. models/mtcnn/mtcnn_pytorch/src/box_utils.py +238 -0
  36. models/mtcnn/mtcnn_pytorch/src/detector.py +126 -0
  37. models/mtcnn/mtcnn_pytorch/src/first_stage.py +101 -0
  38. models/mtcnn/mtcnn_pytorch/src/get_nets.py +171 -0
  39. models/mtcnn/mtcnn_pytorch/src/matlab_cp2tform.py +350 -0
  40. models/mtcnn/mtcnn_pytorch/src/visualization_utils.py +31 -0
  41. models/psp.py +118 -0
  42. models/stylegan2/__init__.py +0 -0
  43. models/stylegan2/model.py +674 -0
  44. models/stylegan2/op/__init__.py +2 -0
  45. models/stylegan2/op/fused_act.py +37 -0
  46. models/stylegan2/op/upfirdn2d.py +60 -0
  47. pretrained/ohayou_face.pkl +3 -0
  48. pretrained/ohayou_face.pt +3 -0
  49. requirements.txt +10 -0
  50. torch_utils/__init__.py +9 -0
.gitattributes ADDED
@@ -0,0 +1,2 @@
+ pretrained/ohayou_face.pt filter=lfs diff=lfs merge=lfs -text
+ pretrained/ohayou_face.pkl filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,11 @@
+ ---
+ title: Ohayou_Face
+ emoji: ⚡
+ colorFrom: red
+ colorTo: yellow
+ sdk: gradio
+ app_file: app.py
+ pinned: false
+ ---
+
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces#reference
app.py ADDED
@@ -0,0 +1,31 @@
+ import os
+ from PIL import Image
+ import gradio as gr
+ from torchvision import transforms
+ import easydict
+ import torch
+ import numpy as np
+ import model_build
+
+
+ psp = model_build.build_psp()
+ stylegan2 = model_build.build_stylegan2()
+
+ pretransform = transforms.Compose([
+     transforms.Resize((256, 256)),
+     transforms.ToTensor(),
+     transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])
+
+ def pipeline(img):
+     img = model_build.img_preprocess(img, pretransform)
+     with torch.no_grad():
+         _, latent_space = psp(img.float(), randomize_noise=True, resize=False, return_latents=True)
+         img = stylegan2(latent_space, noise_mode='none')
+     img = Image.fromarray(np.array((img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8).squeeze(0)[20:-20,:,:]))
+     img.save('output.png')
+     return 'output.png'
+
+ examples=[['momoi_out.png',False], ['churuki_out.png', False], ['fgfgfggf.png', False], ['dsfd.png', False]]
+ description="The male image doesn't work well. 1:1 ratio image recommended (square croppable after uploading). If the background is not monochromatic, it can be mixed with hair color. It takes an average of 5 seconds, but it can take longer if there is a lot of traffic. 남성 이미지에는 잘 작동하지 않음. 1:1비율 권장(업로드 후 정사각형 자르기 가능), 배경이 단색이 아니면 머리색과 섞일 수 있음. 트래픽이 많으면 5초 이상 걸릴 수 있음. Email:krkmfn@gmail.com"
+ gr.Interface(pipeline, [gr.inputs.Image(type="pil")], gr.outputs.Image(type="file"),description=description,allow_flagging=False,examples=examples,allow_screenshot=False,enable_queue=False).launch()
+
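
The dense Image.fromarray(...) expression inside pipeline() packs the whole post-processing step into one line. The sketch below unpacks the same logic for readability; tensor_to_pil and the random input are illustrative stand-ins, not part of this commit.

import torch
from PIL import Image

def tensor_to_pil(img: torch.Tensor) -> Image.Image:
    # img: (1, 3, H, W) StyleGAN2 output, roughly in [-1, 1]
    img = img.permute(0, 2, 3, 1)             # NCHW -> NHWC
    img = (img * 127.5 + 128).clamp(0, 255)   # [-1, 1] -> [0, 255]
    img = img.to(torch.uint8).squeeze(0)      # drop the batch dimension
    img = img[20:-20, :, :]                   # crop 20 rows top and bottom, as app.py does
    return Image.fromarray(img.cpu().numpy())

fake = torch.rand(1, 3, 512, 512) * 2 - 1     # stand-in for the stylegan2(latent_space, ...) output
tensor_to_pil(fake).save('output.png')
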
configs/__init__.py ADDED
File without changes
configs/data_configs.py ADDED
@@ -0,0 +1,41 @@
1
+ from configs import transforms_config
2
+ from configs.paths_config import dataset_paths
3
+
4
+
5
+ DATASETS = {
6
+ 'ffhq_encode': {
7
+ 'transforms': transforms_config.EncodeTransforms,
8
+ 'train_source_root': dataset_paths['ffhq'],
9
+ 'train_target_root': dataset_paths['ffhq'],
10
+ 'test_source_root': dataset_paths['celeba_test'],
11
+ 'test_target_root': dataset_paths['celeba_test'],
12
+ },
13
+ 'furry': {
14
+ 'transforms': transforms_config.FrontalizationTransforms,
15
+ 'train_source_root': dataset_paths['anime'],
16
+ 'train_target_root': dataset_paths['anime'],
17
+ 'test_source_root': dataset_paths['gogal'],
18
+ 'test_target_root': dataset_paths['gogal'],
19
+ },
20
+ 'celebs_sketch_to_face': {
21
+ 'transforms': transforms_config.SketchToImageTransforms,
22
+ 'train_source_root': dataset_paths['celeba_train_sketch'],
23
+ 'train_target_root': dataset_paths['celeba_train'],
24
+ 'test_source_root': dataset_paths['celeba_test_sketch'],
25
+ 'test_target_root': dataset_paths['celeba_test'],
26
+ },
27
+ 'celebs_seg_to_face': {
28
+ 'transforms': transforms_config.SegToImageTransforms,
29
+ 'train_source_root': dataset_paths['celeba_train_segmentation'],
30
+ 'train_target_root': dataset_paths['celeba_train'],
31
+ 'test_source_root': dataset_paths['celeba_test_segmentation'],
32
+ 'test_target_root': dataset_paths['celeba_test'],
33
+ },
34
+ 'celebs_super_resolution': {
35
+ 'transforms': transforms_config.SuperResTransforms,
36
+ 'train_source_root': dataset_paths['celeba_train'],
37
+ 'train_target_root': dataset_paths['celeba_train'],
38
+ 'test_source_root': dataset_paths['celeba_test'],
39
+ 'test_target_root': dataset_paths['celeba_test'],
40
+ },
41
+ }
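
DATASETS only maps a name to a transforms class plus source/target roots; the training code is expected to look an entry up and instantiate its transforms. A hedged sketch of that lookup follows (the opts namespace is a stand-in for the repo's real options object, which carries many more fields):

from argparse import Namespace
from configs.data_configs import DATASETS

opts = Namespace(dataset_type='furry', label_nc=0, resize_factors=None)  # stand-in options

dataset_args = DATASETS[opts.dataset_type]
transforms_dict = dataset_args['transforms'](opts).get_transforms()
print(sorted(transforms_dict.keys()))
# -> ['transform_gt_train', 'transform_inference', 'transform_source', 'transform_test']
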
configs/paths_config.py ADDED
@@ -0,0 +1,23 @@
+ dataset_paths = {
+     'celeba_train': '',
+     'celeba_test': '',
+     'celeba_train_sketch': '',
+     'celeba_test_sketch': '',
+     'celeba_train_segmentation': '',
+     'celeba_test_segmentation': '',
+     'ffhq': '',
+     'anime' : '/content/drive/MyDrive/Dataset/anime',
+     'gogal' : '/content/drive/MyDrive/All Data/고갈왕'
+ }
+
+ model_paths = {
+     'stylegan_ffhq': 'pretrained_models/stylegan2-ffhq-config-f.pt',
+     'ir_se50': 'pretrained_models/model_ir_se50.pth',
+     'circular_face': 'pretrained_models/CurricularFace_Backbone.pth',
+     'mtcnn_pnet': 'pretrained_models/mtcnn/pnet.npy',
+     'mtcnn_rnet': 'pretrained_models/mtcnn/rnet.npy',
+     'mtcnn_onet': 'pretrained_models/mtcnn/onet.npy',
+     'shape_predictor': 'shape_predictor_68_face_landmarks.dat',
+     'moco': 'pretrained_models/moco_v2_800ep_pretrain.pth.tar',
+     'anime' : '/content/drive/MyDrive/StyleGAN2-ada/result/pretrained/anime_face.pt'
+ }
configs/transforms_config.py ADDED
@@ -0,0 +1,152 @@
1
+ from abc import abstractmethod
2
+ import torchvision.transforms as transforms
3
+ from datasets import augmentations
4
+
5
+
6
+ class TransformsConfig(object):
7
+
8
+ def __init__(self, opts):
9
+ self.opts = opts
10
+
11
+ @abstractmethod
12
+ def get_transforms(self):
13
+ pass
14
+
15
+
16
+ class EncodeTransforms(TransformsConfig):
17
+
18
+ def __init__(self, opts):
19
+ super(EncodeTransforms, self).__init__(opts)
20
+
21
+ def get_transforms(self):
22
+ transforms_dict = {
23
+ 'transform_gt_train': transforms.Compose([
24
+ transforms.Resize((256, 256)),
25
+ transforms.RandomHorizontalFlip(0.5),
26
+ transforms.ToTensor(),
27
+ transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]),
28
+ 'transform_source': None,
29
+ 'transform_test': transforms.Compose([
30
+ transforms.Resize((256, 256)),
31
+ transforms.ToTensor(),
32
+ transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]),
33
+ 'transform_inference': transforms.Compose([
34
+ transforms.Resize((256, 256)),
35
+ transforms.ToTensor(),
36
+ transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])
37
+ }
38
+ return transforms_dict
39
+
40
+
41
+ class FrontalizationTransforms(TransformsConfig):
42
+
43
+ def __init__(self, opts):
44
+ super(FrontalizationTransforms, self).__init__(opts)
45
+
46
+ def get_transforms(self):
47
+ transforms_dict = {
48
+ 'transform_gt_train': transforms.Compose([
49
+ transforms.Resize((256, 256)),
50
+ transforms.ToTensor(),
51
+ transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]),
52
+ 'transform_source': transforms.Compose([
53
+ transforms.Resize((256, 256)),
54
+ transforms.ToTensor(),
55
+ transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]),
56
+ 'transform_test': transforms.Compose([
57
+ transforms.Resize((256, 256)),
58
+ transforms.ToTensor(),
59
+ transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]),
60
+ 'transform_inference': transforms.Compose([
61
+ transforms.Resize((256, 256)),
62
+ transforms.ToTensor(),
63
+ transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])
64
+ }
65
+ return transforms_dict
66
+
67
+
68
+ class SketchToImageTransforms(TransformsConfig):
69
+
70
+ def __init__(self, opts):
71
+ super(SketchToImageTransforms, self).__init__(opts)
72
+
73
+ def get_transforms(self):
74
+ transforms_dict = {
75
+ 'transform_gt_train': transforms.Compose([
76
+ transforms.Resize((256, 256)),
77
+ transforms.ToTensor(),
78
+ transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]),
79
+ 'transform_source': transforms.Compose([
80
+ transforms.Resize((256, 256)),
81
+ transforms.ToTensor()]),
82
+ 'transform_test': transforms.Compose([
83
+ transforms.Resize((256, 256)),
84
+ transforms.ToTensor(),
85
+ transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]),
86
+ 'transform_inference': transforms.Compose([
87
+ transforms.Resize((256, 256)),
88
+ transforms.ToTensor()]),
89
+ }
90
+ return transforms_dict
91
+
92
+
93
+ class SegToImageTransforms(TransformsConfig):
94
+
95
+ def __init__(self, opts):
96
+ super(SegToImageTransforms, self).__init__(opts)
97
+
98
+ def get_transforms(self):
99
+ transforms_dict = {
100
+ 'transform_gt_train': transforms.Compose([
101
+ transforms.Resize((256, 256)),
102
+ transforms.ToTensor(),
103
+ transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]),
104
+ 'transform_source': transforms.Compose([
105
+ transforms.Resize((256, 256)),
106
+ augmentations.ToOneHot(self.opts.label_nc),
107
+ transforms.ToTensor()]),
108
+ 'transform_test': transforms.Compose([
109
+ transforms.Resize((256, 256)),
110
+ transforms.ToTensor(),
111
+ transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]),
112
+ 'transform_inference': transforms.Compose([
113
+ transforms.Resize((256, 256)),
114
+ augmentations.ToOneHot(self.opts.label_nc),
115
+ transforms.ToTensor()])
116
+ }
117
+ return transforms_dict
118
+
119
+
120
+ class SuperResTransforms(TransformsConfig):
121
+
122
+ def __init__(self, opts):
123
+ super(SuperResTransforms, self).__init__(opts)
124
+
125
+ def get_transforms(self):
126
+ if self.opts.resize_factors is None:
127
+ self.opts.resize_factors = '1,2,4,8,16,32'
128
+ factors = [int(f) for f in self.opts.resize_factors.split(",")]
129
+ print("Performing down-sampling with factors: {}".format(factors))
130
+ transforms_dict = {
131
+ 'transform_gt_train': transforms.Compose([
132
+ transforms.Resize((256, 256)),
133
+ transforms.ToTensor(),
134
+ transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]),
135
+ 'transform_source': transforms.Compose([
136
+ transforms.Resize((256, 256)),
137
+ augmentations.BilinearResize(factors=factors),
138
+ transforms.Resize((256, 256)),
139
+ transforms.ToTensor(),
140
+ transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]),
141
+ 'transform_test': transforms.Compose([
142
+ transforms.Resize((256, 256)),
143
+ transforms.ToTensor(),
144
+ transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]),
145
+ 'transform_inference': transforms.Compose([
146
+ transforms.Resize((256, 256)),
147
+ augmentations.BilinearResize(factors=factors),
148
+ transforms.Resize((256, 256)),
149
+ transforms.ToTensor(),
150
+ transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])
151
+ }
152
+ return transforms_dict
criteria/__init__.py ADDED
File without changes
criteria/id_loss.py ADDED
@@ -0,0 +1,44 @@
1
+ import torch
2
+ from torch import nn
3
+ from configs.paths_config import model_paths
4
+ from models.encoders.model_irse import Backbone
5
+
6
+
7
+ class IDLoss(nn.Module):
8
+ def __init__(self):
9
+ super(IDLoss, self).__init__()
10
+ print('Loading ResNet ArcFace')
11
+ self.facenet = Backbone(input_size=112, num_layers=50, drop_ratio=0.6, mode='ir_se')
12
+ self.facenet.load_state_dict(torch.load(model_paths['ir_se50']))
13
+ self.face_pool = torch.nn.AdaptiveAvgPool2d((112, 112))
14
+ self.facenet.eval()
15
+
16
+ def extract_feats(self, x):
17
+ x = x[:, :, 35:223, 32:220] # Crop interesting region
18
+ x = self.face_pool(x)
19
+ x_feats = self.facenet(x)
20
+ return x_feats
21
+
22
+ def forward(self, y_hat, y, x):
23
+ n_samples = x.shape[0]
24
+ x_feats = self.extract_feats(x)
25
+ y_feats = self.extract_feats(y) # Otherwise use the feature from there
26
+ y_hat_feats = self.extract_feats(y_hat)
27
+ y_feats = y_feats.detach()
28
+ loss = 0
29
+ sim_improvement = 0
30
+ id_logs = []
31
+ count = 0
32
+ for i in range(n_samples):
33
+ diff_target = y_hat_feats[i].dot(y_feats[i])
34
+ diff_input = y_hat_feats[i].dot(x_feats[i])
35
+ diff_views = y_feats[i].dot(x_feats[i])
36
+ id_logs.append({'diff_target': float(diff_target),
37
+ 'diff_input': float(diff_input),
38
+ 'diff_views': float(diff_views)})
39
+ loss += 1 - diff_target
40
+ id_diff = float(diff_target) - float(diff_views)
41
+ sim_improvement += id_diff
42
+ count += 1
43
+
44
+ return loss / count, sim_improvement / count, id_logs
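
The per-sample dot products in IDLoss.forward are taken between ArcFace embeddings, which the ir_se50 backbone typically returns L2-normalized, so each term behaves like a cosine similarity and the loss is 1 - cos(y_hat, y). A toy illustration with random unit vectors (no face network involved):

import torch
import torch.nn.functional as F

# Random stand-ins for ArcFace embeddings of generated (y_hat) and target (y) faces.
emb_y_hat = F.normalize(torch.randn(4, 512), dim=1)
emb_y = F.normalize(torch.randn(4, 512), dim=1)

cos_sim = (emb_y_hat * emb_y).sum(dim=1)  # same as the per-row dot products in the loop
loss = (1 - cos_sim).mean()               # matches loss / count above
print(float(loss))
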
criteria/lpips/__init__.py ADDED
File without changes
criteria/lpips/lpips.py ADDED
@@ -0,0 +1,35 @@
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ from criteria.lpips.networks import get_network, LinLayers
5
+ from criteria.lpips.utils import get_state_dict
6
+
7
+
8
+ class LPIPS(nn.Module):
9
+ r"""Creates a criterion that measures
10
+ Learned Perceptual Image Patch Similarity (LPIPS).
11
+ Arguments:
12
+ net_type (str): the network type to compare the features:
13
+ 'alex' | 'squeeze' | 'vgg'. Default: 'alex'.
14
+ version (str): the version of LPIPS. Default: 0.1.
15
+ """
16
+ def __init__(self, net_type: str = 'alex', version: str = '0.1'):
17
+
18
+ assert version in ['0.1'], 'v0.1 is only supported now'
19
+
20
+ super(LPIPS, self).__init__()
21
+
22
+ # pretrained network
23
+ self.net = get_network(net_type).to("cuda")
24
+
25
+ # linear layers
26
+ self.lin = LinLayers(self.net.n_channels_list).to("cuda")
27
+ self.lin.load_state_dict(get_state_dict(net_type, version))
28
+
29
+ def forward(self, x: torch.Tensor, y: torch.Tensor):
30
+ feat_x, feat_y = self.net(x), self.net(y)
31
+
32
+ diff = [(fx - fy) ** 2 for fx, fy in zip(feat_x, feat_y)]
33
+ res = [l(d).mean((2, 3), True) for d, l in zip(diff, self.lin)]
34
+
35
+ return torch.sum(torch.cat(res, 0)) / x.shape[0]
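
A hedged usage sketch for the criterion above. It assumes a CUDA device (the networks are moved to "cuda" in __init__) and inputs already normalized to [-1, 1], which is what the registered mean/std buffers expect:

import torch
from criteria.lpips.lpips import LPIPS

lpips_loss = LPIPS(net_type='alex').eval()              # downloads the linear-layer weights on first use
x = torch.rand(2, 3, 256, 256, device='cuda') * 2 - 1   # stand-in images in [-1, 1]
y = torch.rand(2, 3, 256, 256, device='cuda') * 2 - 1
with torch.no_grad():
    print(float(lpips_loss(x, y)))                      # mean LPIPS distance over the batch
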
criteria/lpips/networks.py ADDED
@@ -0,0 +1,96 @@
1
+ from typing import Sequence
2
+
3
+ from itertools import chain
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ from torchvision import models
8
+
9
+ from criteria.lpips.utils import normalize_activation
10
+
11
+
12
+ def get_network(net_type: str):
13
+ if net_type == 'alex':
14
+ return AlexNet()
15
+ elif net_type == 'squeeze':
16
+ return SqueezeNet()
17
+ elif net_type == 'vgg':
18
+ return VGG16()
19
+ else:
20
+ raise NotImplementedError('choose net_type from [alex, squeeze, vgg].')
21
+
22
+
23
+ class LinLayers(nn.ModuleList):
24
+ def __init__(self, n_channels_list: Sequence[int]):
25
+ super(LinLayers, self).__init__([
26
+ nn.Sequential(
27
+ nn.Identity(),
28
+ nn.Conv2d(nc, 1, 1, 1, 0, bias=False)
29
+ ) for nc in n_channels_list
30
+ ])
31
+
32
+ for param in self.parameters():
33
+ param.requires_grad = False
34
+
35
+
36
+ class BaseNet(nn.Module):
37
+ def __init__(self):
38
+ super(BaseNet, self).__init__()
39
+
40
+ # register buffer
41
+ self.register_buffer(
42
+ 'mean', torch.Tensor([-.030, -.088, -.188])[None, :, None, None])
43
+ self.register_buffer(
44
+ 'std', torch.Tensor([.458, .448, .450])[None, :, None, None])
45
+
46
+ def set_requires_grad(self, state: bool):
47
+ for param in chain(self.parameters(), self.buffers()):
48
+ param.requires_grad = state
49
+
50
+ def z_score(self, x: torch.Tensor):
51
+ return (x - self.mean) / self.std
52
+
53
+ def forward(self, x: torch.Tensor):
54
+ x = self.z_score(x)
55
+
56
+ output = []
57
+ for i, (_, layer) in enumerate(self.layers._modules.items(), 1):
58
+ x = layer(x)
59
+ if i in self.target_layers:
60
+ output.append(normalize_activation(x))
61
+ if len(output) == len(self.target_layers):
62
+ break
63
+ return output
64
+
65
+
66
+ class SqueezeNet(BaseNet):
67
+ def __init__(self):
68
+ super(SqueezeNet, self).__init__()
69
+
70
+ self.layers = models.squeezenet1_1(True).features
71
+ self.target_layers = [2, 5, 8, 10, 11, 12, 13]
72
+ self.n_channels_list = [64, 128, 256, 384, 384, 512, 512]
73
+
74
+ self.set_requires_grad(False)
75
+
76
+
77
+ class AlexNet(BaseNet):
78
+ def __init__(self):
79
+ super(AlexNet, self).__init__()
80
+
81
+ self.layers = models.alexnet(True).features
82
+ self.target_layers = [2, 5, 8, 10, 12]
83
+ self.n_channels_list = [64, 192, 384, 256, 256]
84
+
85
+ self.set_requires_grad(False)
86
+
87
+
88
+ class VGG16(BaseNet):
89
+ def __init__(self):
90
+ super(VGG16, self).__init__()
91
+
92
+ self.layers = models.vgg16(True).features
93
+ self.target_layers = [4, 9, 16, 23, 30]
94
+ self.n_channels_list = [64, 128, 256, 512, 512]
95
+
96
+ self.set_requires_grad(False)
criteria/lpips/utils.py ADDED
@@ -0,0 +1,30 @@
+ from collections import OrderedDict
+
+ import torch
+
+
+ def normalize_activation(x, eps=1e-10):
+     norm_factor = torch.sqrt(torch.sum(x ** 2, dim=1, keepdim=True))
+     return x / (norm_factor + eps)
+
+
+ def get_state_dict(net_type: str = 'alex', version: str = '0.1'):
+     # build url
+     url = 'https://raw.githubusercontent.com/richzhang/PerceptualSimilarity/' \
+           + f'master/lpips/weights/v{version}/{net_type}.pth'
+
+     # download
+     old_state_dict = torch.hub.load_state_dict_from_url(
+         url, progress=True,
+         map_location=None if torch.cuda.is_available() else torch.device('cpu')
+     )
+
+     # rename keys
+     new_state_dict = OrderedDict()
+     for key, val in old_state_dict.items():
+         new_key = key
+         new_key = new_key.replace('lin', '')
+         new_key = new_key.replace('model.', '')
+         new_state_dict[new_key] = val
+
+     return new_state_dict
criteria/moco_loss.py ADDED
@@ -0,0 +1,69 @@
1
+ import torch
2
+ from torch import nn
3
+ import torch.nn.functional as F
4
+ from configs.paths_config import model_paths
5
+
6
+
7
+ class MocoLoss(nn.Module):
8
+
9
+ def __init__(self):
10
+ super(MocoLoss, self).__init__()
11
+ print("Loading MOCO model from path: {}".format(model_paths["moco"]))
12
+ self.model = self.__load_model()
13
+ self.model.cuda()
14
+ self.model.eval()
15
+
16
+ @staticmethod
17
+ def __load_model():
18
+ import torchvision.models as models
19
+ model = models.__dict__["resnet50"]()
20
+ # freeze all layers but the last fc
21
+ for name, param in model.named_parameters():
22
+ if name not in ['fc.weight', 'fc.bias']:
23
+ param.requires_grad = False
24
+ checkpoint = torch.load(model_paths['moco'], map_location="cpu")
25
+ state_dict = checkpoint['state_dict']
26
+ # rename moco pre-trained keys
27
+ for k in list(state_dict.keys()):
28
+ # retain only encoder_q up to before the embedding layer
29
+ if k.startswith('module.encoder_q') and not k.startswith('module.encoder_q.fc'):
30
+ # remove prefix
31
+ state_dict[k[len("module.encoder_q."):]] = state_dict[k]
32
+ # delete renamed or unused k
33
+ del state_dict[k]
34
+ msg = model.load_state_dict(state_dict, strict=False)
35
+ assert set(msg.missing_keys) == {"fc.weight", "fc.bias"}
36
+ # remove output layer
37
+ model = nn.Sequential(*list(model.children())[:-1]).cuda()
38
+ return model
39
+
40
+ def extract_feats(self, x):
41
+ x = F.interpolate(x, size=224)
42
+ x_feats = self.model(x)
43
+ x_feats = nn.functional.normalize(x_feats, dim=1)
44
+ x_feats = x_feats.squeeze()
45
+ return x_feats
46
+
47
+ def forward(self, y_hat, y, x):
48
+ n_samples = x.shape[0]
49
+ x_feats = self.extract_feats(x)
50
+ y_feats = self.extract_feats(y)
51
+ y_hat_feats = self.extract_feats(y_hat)
52
+ y_feats = y_feats.detach()
53
+ loss = 0
54
+ sim_improvement = 0
55
+ sim_logs = []
56
+ count = 0
57
+ for i in range(n_samples):
58
+ diff_target = y_hat_feats[i].dot(y_feats[i])
59
+ diff_input = y_hat_feats[i].dot(x_feats[i])
60
+ diff_views = y_feats[i].dot(x_feats[i])
61
+ sim_logs.append({'diff_target': float(diff_target),
62
+ 'diff_input': float(diff_input),
63
+ 'diff_views': float(diff_views)})
64
+ loss += 1 - diff_target
65
+ sim_diff = float(diff_target) - float(diff_views)
66
+ sim_improvement += sim_diff
67
+ count += 1
68
+
69
+ return loss / count, sim_improvement / count, sim_logs
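
The key-renaming loop in __load_model keeps only the query encoder's weights: it strips the 'module.encoder_q.' prefix and deletes every original key, including the fc head and the key encoder. A toy illustration of that loop on a fake state dict:

state_dict = {
    'module.encoder_q.conv1.weight': 'kept',
    'module.encoder_q.fc.0.weight': 'dropped (fc head)',
    'module.encoder_k.conv1.weight': 'dropped (key encoder)',
}
for k in list(state_dict.keys()):
    if k.startswith('module.encoder_q') and not k.startswith('module.encoder_q.fc'):
        state_dict[k[len('module.encoder_q.'):]] = state_dict[k]
    del state_dict[k]
print(state_dict)  # {'conv1.weight': 'kept'}
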
criteria/w_norm.py ADDED
@@ -0,0 +1,14 @@
+ import torch
+ from torch import nn
+
+
+ class WNormLoss(nn.Module):
+
+     def __init__(self, start_from_latent_avg=True):
+         super(WNormLoss, self).__init__()
+         self.start_from_latent_avg = start_from_latent_avg
+
+     def forward(self, latent, latent_avg=None):
+         if self.start_from_latent_avg:
+             latent = latent - latent_avg
+         return torch.sum(latent.norm(2, dim=(1, 2))) / latent.shape[0]
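
On a W+ code of shape (batch, n_styles, 512), WNormLoss is the batch mean of the L2 norm of (latent - latent_avg) taken over the last two dimensions. A quick sanity check with latent_avg = 0:

import torch
from criteria.w_norm import WNormLoss

w = torch.randn(4, 18, 512)          # stand-in W+ codes
w_avg = torch.zeros(512)             # broadcasts against the last dimension
crit = WNormLoss(start_from_latent_avg=True)
assert torch.isclose(crit(w, w_avg), w.norm(2, dim=(1, 2)).mean())
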
datasets/__init__.py ADDED
File without changes
datasets/augmentations.py ADDED
@@ -0,0 +1,110 @@
1
+ import numpy as np
2
+ import torch
3
+ from torch import nn
4
+ from torch.nn import functional as F
5
+ from torchvision import transforms
6
+
7
+
8
+ class ToOneHot(object):
9
+ """ Convert the input PIL image to a one-hot torch tensor """
10
+ def __init__(self, n_classes=None):
11
+ self.n_classes = n_classes
12
+
13
+ def onehot_initialization(self, a):
14
+ if self.n_classes is None:
15
+ self.n_classes = len(np.unique(a))
16
+ out = np.zeros(a.shape + (self.n_classes, ), dtype=int)
17
+ out[self.__all_idx(a, axis=2)] = 1
18
+ return out
19
+
20
+ def __all_idx(self, idx, axis):
21
+ grid = np.ogrid[tuple(map(slice, idx.shape))]
22
+ grid.insert(axis, idx)
23
+ return tuple(grid)
24
+
25
+ def __call__(self, img):
26
+ img = np.array(img)
27
+ one_hot = self.onehot_initialization(img)
28
+ return one_hot
29
+
30
+
31
+ class BilinearResize(object):
32
+ def __init__(self, factors=[1, 2, 4, 8, 16, 32]):
33
+ self.factors = factors
34
+
35
+ def __call__(self, image):
36
+ factor = np.random.choice(self.factors, size=1)[0]
37
+ D = BicubicDownSample(factor=factor, cuda=False)
38
+ img_tensor = transforms.ToTensor()(image).unsqueeze(0)
39
+ img_tensor_lr = D(img_tensor)[0].clamp(0, 1)
40
+ img_low_res = transforms.ToPILImage()(img_tensor_lr)
41
+ return img_low_res
42
+
43
+
44
+ class BicubicDownSample(nn.Module):
45
+ def bicubic_kernel(self, x, a=-0.50):
46
+ """
47
+ This equation is exactly copied from the website below:
48
+ https://clouard.users.greyc.fr/Pantheon/experiments/rescaling/index-en.html#bicubic
49
+ """
50
+ abs_x = torch.abs(x)
51
+ if abs_x <= 1.:
52
+ return (a + 2.) * torch.pow(abs_x, 3.) - (a + 3.) * torch.pow(abs_x, 2.) + 1
53
+ elif 1. < abs_x < 2.:
54
+ return a * torch.pow(abs_x, 3) - 5. * a * torch.pow(abs_x, 2.) + 8. * a * abs_x - 4. * a
55
+ else:
56
+ return 0.0
57
+
58
+ def __init__(self, factor=4, cuda=True, padding='reflect'):
59
+ super().__init__()
60
+ self.factor = factor
61
+ size = factor * 4
62
+ k = torch.tensor([self.bicubic_kernel((i - torch.floor(torch.tensor(size / 2)) + 0.5) / factor)
63
+ for i in range(size)], dtype=torch.float32)
64
+ k = k / torch.sum(k)
65
+ k1 = torch.reshape(k, shape=(1, 1, size, 1))
66
+ self.k1 = torch.cat([k1, k1, k1], dim=0)
67
+ k2 = torch.reshape(k, shape=(1, 1, 1, size))
68
+ self.k2 = torch.cat([k2, k2, k2], dim=0)
69
+ self.cuda = '.cuda' if cuda else ''
70
+ self.padding = padding
71
+ for param in self.parameters():
72
+ param.requires_grad = False
73
+
74
+ def forward(self, x, nhwc=False, clip_round=False, byte_output=False):
75
+ filter_height = self.factor * 4
76
+ filter_width = self.factor * 4
77
+ stride = self.factor
78
+
79
+ pad_along_height = max(filter_height - stride, 0)
80
+ pad_along_width = max(filter_width - stride, 0)
81
+ filters1 = self.k1.type('torch{}.FloatTensor'.format(self.cuda))
82
+ filters2 = self.k2.type('torch{}.FloatTensor'.format(self.cuda))
83
+
84
+ # compute actual padding values for each side
85
+ pad_top = pad_along_height // 2
86
+ pad_bottom = pad_along_height - pad_top
87
+ pad_left = pad_along_width // 2
88
+ pad_right = pad_along_width - pad_left
89
+
90
+ # apply mirror padding
91
+ if nhwc:
92
+ x = torch.transpose(torch.transpose(x, 2, 3), 1, 2) # NHWC to NCHW
93
+
94
+ # downscaling performed by 1-d convolution
95
+ x = F.pad(x, (0, 0, pad_top, pad_bottom), self.padding)
96
+ x = F.conv2d(input=x, weight=filters1, stride=(stride, 1), groups=3)
97
+ if clip_round:
98
+ x = torch.clamp(torch.round(x), 0.0, 255.)
99
+
100
+ x = F.pad(x, (pad_left, pad_right, 0, 0), self.padding)
101
+ x = F.conv2d(input=x, weight=filters2, stride=(1, stride), groups=3)
102
+ if clip_round:
103
+ x = torch.clamp(torch.round(x), 0.0, 255.)
104
+
105
+ if nhwc:
106
+ x = torch.transpose(torch.transpose(x, 1, 3), 1, 2)
107
+ if byte_output:
108
+ return x.type('torch{}.ByteTensor'.format(self.cuda))
109
+ else:
110
+ return x
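
BilinearResize wraps the CPU BicubicDownSample above: it samples one factor, downsamples the tensor, and returns a degraded PIL image that SuperResTransforms then resizes back to 256x256. A hedged usage sketch ('example.png' is a hypothetical input file, not part of this commit):

from PIL import Image
from datasets.augmentations import BilinearResize

img = Image.open('example.png').convert('RGB')   # hypothetical input image
degrade = BilinearResize(factors=[4])            # always downsample by 4x in this sketch
low_res = degrade(img)                           # PIL image at roughly 1/4 of the original size
low_res.save('example_x4_down.png')
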
datasets/gt_res_dataset.py ADDED
@@ -0,0 +1,32 @@
1
+ #!/usr/bin/python
2
+ # encoding: utf-8
3
+ import os
4
+ from torch.utils.data import Dataset
5
+ from PIL import Image
6
+
7
+
8
+ class GTResDataset(Dataset):
9
+
10
+ def __init__(self, root_path, gt_dir=None, transform=None, transform_train=None):
11
+ self.pairs = []
12
+ for f in os.listdir(root_path):
13
+ image_path = os.path.join(root_path, f)
14
+ gt_path = os.path.join(gt_dir, f)
15
+ if f.endswith(".jpg") or f.endswith(".png"):
16
+ self.pairs.append([image_path, gt_path.replace('.png', '.jpg'), None])
17
+ self.transform = transform
18
+ self.transform_train = transform_train
19
+
20
+ def __len__(self):
21
+ return len(self.pairs)
22
+
23
+ def __getitem__(self, index):
24
+ from_path, to_path, _ = self.pairs[index]
25
+ from_im = Image.open(from_path).convert('RGB')
26
+ to_im = Image.open(to_path).convert('RGB')
27
+
28
+ if self.transform:
29
+ to_im = self.transform(to_im)
30
+ from_im = self.transform(from_im)
31
+
32
+ return from_im, to_im
datasets/images_dataset.py ADDED
@@ -0,0 +1,33 @@
1
+ from torch.utils.data import Dataset
2
+ from PIL import Image
3
+ from utils import data_utils
4
+
5
+
6
+ class ImagesDataset(Dataset):
7
+
8
+ def __init__(self, source_root, target_root, opts, target_transform=None, source_transform=None):
9
+ self.source_paths = sorted(data_utils.make_dataset(source_root))
10
+ self.target_paths = sorted(data_utils.make_dataset(target_root))
11
+ self.source_transform = source_transform
12
+ self.target_transform = target_transform
13
+ self.opts = opts
14
+
15
+ def __len__(self):
16
+ return len(self.source_paths)
17
+
18
+ def __getitem__(self, index):
19
+ from_path = self.source_paths[index]
20
+ from_im = Image.open(from_path)
21
+ from_im = from_im.convert('RGB') if self.opts.label_nc == 0 else from_im.convert('L')
22
+
23
+ to_path = self.target_paths[index]
24
+ to_im = Image.open(to_path).convert('RGB')
25
+ if self.target_transform:
26
+ to_im = self.target_transform(to_im)
27
+
28
+ if self.source_transform:
29
+ from_im = self.source_transform(from_im)
30
+ else:
31
+ from_im = to_im
32
+
33
+ return from_im, to_im
datasets/inference_dataset.py ADDED
@@ -0,0 +1,22 @@
1
+ from torch.utils.data import Dataset
2
+ from PIL import Image
3
+ from utils import data_utils
4
+
5
+
6
+ class InferenceDataset(Dataset):
7
+
8
+ def __init__(self, root, opts, transform=None):
9
+ self.paths = sorted(data_utils.make_dataset(root))
10
+ self.transform = transform
11
+ self.opts = opts
12
+
13
+ def __len__(self):
14
+ return len(self.paths)
15
+
16
+ def __getitem__(self, index):
17
+ from_path = self.paths[index]
18
+ from_im = Image.open(from_path)
19
+ from_im = from_im.convert('RGB') if self.opts.label_nc == 0 else from_im.convert('L')
20
+ if self.transform:
21
+ from_im = self.transform(from_im)
22
+ return from_im
dnnlib/__init__.py ADDED
@@ -0,0 +1,9 @@
+ # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
+ #
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
+ # and proprietary rights in and to this software, related documentation
+ # and any modifications thereto. Any use, reproduction, disclosure or
+ # distribution of this software and related documentation without an express
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+ from .util import EasyDict, make_cache_dir_path
dnnlib/util.py ADDED
@@ -0,0 +1,477 @@
1
+ # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ """Miscellaneous utility classes and functions."""
10
+
11
+ import ctypes
12
+ import fnmatch
13
+ import importlib
14
+ import inspect
15
+ import numpy as np
16
+ import os
17
+ import shutil
18
+ import sys
19
+ import types
20
+ import io
21
+ import pickle
22
+ import re
23
+ import requests
24
+ import html
25
+ import hashlib
26
+ import glob
27
+ import tempfile
28
+ import urllib
29
+ import urllib.request
30
+ import uuid
31
+
32
+ from distutils.util import strtobool
33
+ from typing import Any, List, Tuple, Union
34
+
35
+
36
+ # Util classes
37
+ # ------------------------------------------------------------------------------------------
38
+
39
+
40
+ class EasyDict(dict):
41
+ """Convenience class that behaves like a dict but allows access with the attribute syntax."""
42
+
43
+ def __getattr__(self, name: str) -> Any:
44
+ try:
45
+ return self[name]
46
+ except KeyError:
47
+ raise AttributeError(name)
48
+
49
+ def __setattr__(self, name: str, value: Any) -> None:
50
+ self[name] = value
51
+
52
+ def __delattr__(self, name: str) -> None:
53
+ del self[name]
54
+
55
+
56
+ class Logger(object):
57
+ """Redirect stderr to stdout, optionally print stdout to a file, and optionally force flushing on both stdout and the file."""
58
+
59
+ def __init__(self, file_name: str = None, file_mode: str = "w", should_flush: bool = True):
60
+ self.file = None
61
+
62
+ if file_name is not None:
63
+ self.file = open(file_name, file_mode)
64
+
65
+ self.should_flush = should_flush
66
+ self.stdout = sys.stdout
67
+ self.stderr = sys.stderr
68
+
69
+ sys.stdout = self
70
+ sys.stderr = self
71
+
72
+ def __enter__(self) -> "Logger":
73
+ return self
74
+
75
+ def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
76
+ self.close()
77
+
78
+ def write(self, text: Union[str, bytes]) -> None:
79
+ """Write text to stdout (and a file) and optionally flush."""
80
+ if isinstance(text, bytes):
81
+ text = text.decode()
82
+ if len(text) == 0: # workaround for a bug in VSCode debugger: sys.stdout.write(''); sys.stdout.flush() => crash
83
+ return
84
+
85
+ if self.file is not None:
86
+ self.file.write(text)
87
+
88
+ self.stdout.write(text)
89
+
90
+ if self.should_flush:
91
+ self.flush()
92
+
93
+ def flush(self) -> None:
94
+ """Flush written text to both stdout and a file, if open."""
95
+ if self.file is not None:
96
+ self.file.flush()
97
+
98
+ self.stdout.flush()
99
+
100
+ def close(self) -> None:
101
+ """Flush, close possible files, and remove stdout/stderr mirroring."""
102
+ self.flush()
103
+
104
+ # if using multiple loggers, prevent closing in wrong order
105
+ if sys.stdout is self:
106
+ sys.stdout = self.stdout
107
+ if sys.stderr is self:
108
+ sys.stderr = self.stderr
109
+
110
+ if self.file is not None:
111
+ self.file.close()
112
+ self.file = None
113
+
114
+
115
+ # Cache directories
116
+ # ------------------------------------------------------------------------------------------
117
+
118
+ _dnnlib_cache_dir = None
119
+
120
+ def set_cache_dir(path: str) -> None:
121
+ global _dnnlib_cache_dir
122
+ _dnnlib_cache_dir = path
123
+
124
+ def make_cache_dir_path(*paths: str) -> str:
125
+ if _dnnlib_cache_dir is not None:
126
+ return os.path.join(_dnnlib_cache_dir, *paths)
127
+ if 'DNNLIB_CACHE_DIR' in os.environ:
128
+ return os.path.join(os.environ['DNNLIB_CACHE_DIR'], *paths)
129
+ if 'HOME' in os.environ:
130
+ return os.path.join(os.environ['HOME'], '.cache', 'dnnlib', *paths)
131
+ if 'USERPROFILE' in os.environ:
132
+ return os.path.join(os.environ['USERPROFILE'], '.cache', 'dnnlib', *paths)
133
+ return os.path.join(tempfile.gettempdir(), '.cache', 'dnnlib', *paths)
134
+
135
+ # Small util functions
136
+ # ------------------------------------------------------------------------------------------
137
+
138
+
139
+ def format_time(seconds: Union[int, float]) -> str:
140
+ """Convert the seconds to human readable string with days, hours, minutes and seconds."""
141
+ s = int(np.rint(seconds))
142
+
143
+ if s < 60:
144
+ return "{0}s".format(s)
145
+ elif s < 60 * 60:
146
+ return "{0}m {1:02}s".format(s // 60, s % 60)
147
+ elif s < 24 * 60 * 60:
148
+ return "{0}h {1:02}m {2:02}s".format(s // (60 * 60), (s // 60) % 60, s % 60)
149
+ else:
150
+ return "{0}d {1:02}h {2:02}m".format(s // (24 * 60 * 60), (s // (60 * 60)) % 24, (s // 60) % 60)
151
+
152
+
153
+ def ask_yes_no(question: str) -> bool:
154
+ """Ask the user the question until the user inputs a valid answer."""
155
+ while True:
156
+ try:
157
+ print("{0} [y/n]".format(question))
158
+ return strtobool(input().lower())
159
+ except ValueError:
160
+ pass
161
+
162
+
163
+ def tuple_product(t: Tuple) -> Any:
164
+ """Calculate the product of the tuple elements."""
165
+ result = 1
166
+
167
+ for v in t:
168
+ result *= v
169
+
170
+ return result
171
+
172
+
173
+ _str_to_ctype = {
174
+ "uint8": ctypes.c_ubyte,
175
+ "uint16": ctypes.c_uint16,
176
+ "uint32": ctypes.c_uint32,
177
+ "uint64": ctypes.c_uint64,
178
+ "int8": ctypes.c_byte,
179
+ "int16": ctypes.c_int16,
180
+ "int32": ctypes.c_int32,
181
+ "int64": ctypes.c_int64,
182
+ "float32": ctypes.c_float,
183
+ "float64": ctypes.c_double
184
+ }
185
+
186
+
187
+ def get_dtype_and_ctype(type_obj: Any) -> Tuple[np.dtype, Any]:
188
+ """Given a type name string (or an object having a __name__ attribute), return matching Numpy and ctypes types that have the same size in bytes."""
189
+ type_str = None
190
+
191
+ if isinstance(type_obj, str):
192
+ type_str = type_obj
193
+ elif hasattr(type_obj, "__name__"):
194
+ type_str = type_obj.__name__
195
+ elif hasattr(type_obj, "name"):
196
+ type_str = type_obj.name
197
+ else:
198
+ raise RuntimeError("Cannot infer type name from input")
199
+
200
+ assert type_str in _str_to_ctype.keys()
201
+
202
+ my_dtype = np.dtype(type_str)
203
+ my_ctype = _str_to_ctype[type_str]
204
+
205
+ assert my_dtype.itemsize == ctypes.sizeof(my_ctype)
206
+
207
+ return my_dtype, my_ctype
208
+
209
+
210
+ def is_pickleable(obj: Any) -> bool:
211
+ try:
212
+ with io.BytesIO() as stream:
213
+ pickle.dump(obj, stream)
214
+ return True
215
+ except:
216
+ return False
217
+
218
+
219
+ # Functionality to import modules/objects by name, and call functions by name
220
+ # ------------------------------------------------------------------------------------------
221
+
222
+ def get_module_from_obj_name(obj_name: str) -> Tuple[types.ModuleType, str]:
223
+ """Searches for the underlying module behind the name to some python object.
224
+ Returns the module and the object name (original name with module part removed)."""
225
+
226
+ # allow convenience shorthands, substitute them by full names
227
+ obj_name = re.sub("^np.", "numpy.", obj_name)
228
+ obj_name = re.sub("^tf.", "tensorflow.", obj_name)
229
+
230
+ # list alternatives for (module_name, local_obj_name)
231
+ parts = obj_name.split(".")
232
+ name_pairs = [(".".join(parts[:i]), ".".join(parts[i:])) for i in range(len(parts), 0, -1)]
233
+
234
+ # try each alternative in turn
235
+ for module_name, local_obj_name in name_pairs:
236
+ try:
237
+ module = importlib.import_module(module_name) # may raise ImportError
238
+ get_obj_from_module(module, local_obj_name) # may raise AttributeError
239
+ return module, local_obj_name
240
+ except:
241
+ pass
242
+
243
+ # maybe some of the modules themselves contain errors?
244
+ for module_name, _local_obj_name in name_pairs:
245
+ try:
246
+ importlib.import_module(module_name) # may raise ImportError
247
+ except ImportError:
248
+ if not str(sys.exc_info()[1]).startswith("No module named '" + module_name + "'"):
249
+ raise
250
+
251
+ # maybe the requested attribute is missing?
252
+ for module_name, local_obj_name in name_pairs:
253
+ try:
254
+ module = importlib.import_module(module_name) # may raise ImportError
255
+ get_obj_from_module(module, local_obj_name) # may raise AttributeError
256
+ except ImportError:
257
+ pass
258
+
259
+ # we are out of luck, but we have no idea why
260
+ raise ImportError(obj_name)
261
+
262
+
263
+ def get_obj_from_module(module: types.ModuleType, obj_name: str) -> Any:
264
+ """Traverses the object name and returns the last (rightmost) python object."""
265
+ if obj_name == '':
266
+ return module
267
+ obj = module
268
+ for part in obj_name.split("."):
269
+ obj = getattr(obj, part)
270
+ return obj
271
+
272
+
273
+ def get_obj_by_name(name: str) -> Any:
274
+ """Finds the python object with the given name."""
275
+ module, obj_name = get_module_from_obj_name(name)
276
+ return get_obj_from_module(module, obj_name)
277
+
278
+
279
+ def call_func_by_name(*args, func_name: str = None, **kwargs) -> Any:
280
+ """Finds the python object with the given name and calls it as a function."""
281
+ assert func_name is not None
282
+ func_obj = get_obj_by_name(func_name)
283
+ assert callable(func_obj)
284
+ return func_obj(*args, **kwargs)
285
+
286
+
287
+ def construct_class_by_name(*args, class_name: str = None, **kwargs) -> Any:
288
+ """Finds the python class with the given name and constructs it with the given arguments."""
289
+ return call_func_by_name(*args, func_name=class_name, **kwargs)
290
+
291
+
292
+ def get_module_dir_by_obj_name(obj_name: str) -> str:
293
+ """Get the directory path of the module containing the given object name."""
294
+ module, _ = get_module_from_obj_name(obj_name)
295
+ return os.path.dirname(inspect.getfile(module))
296
+
297
+
298
+ def is_top_level_function(obj: Any) -> bool:
299
+ """Determine whether the given object is a top-level function, i.e., defined at module scope using 'def'."""
300
+ return callable(obj) and obj.__name__ in sys.modules[obj.__module__].__dict__
301
+
302
+
303
+ def get_top_level_function_name(obj: Any) -> str:
304
+ """Return the fully-qualified name of a top-level function."""
305
+ assert is_top_level_function(obj)
306
+ module = obj.__module__
307
+ if module == '__main__':
308
+ module = os.path.splitext(os.path.basename(sys.modules[module].__file__))[0]
309
+ return module + "." + obj.__name__
310
+
311
+
312
+ # File system helpers
313
+ # ------------------------------------------------------------------------------------------
314
+
315
+ def list_dir_recursively_with_ignore(dir_path: str, ignores: List[str] = None, add_base_to_relative: bool = False) -> List[Tuple[str, str]]:
316
+ """List all files recursively in a given directory while ignoring given file and directory names.
317
+ Returns list of tuples containing both absolute and relative paths."""
318
+ assert os.path.isdir(dir_path)
319
+ base_name = os.path.basename(os.path.normpath(dir_path))
320
+
321
+ if ignores is None:
322
+ ignores = []
323
+
324
+ result = []
325
+
326
+ for root, dirs, files in os.walk(dir_path, topdown=True):
327
+ for ignore_ in ignores:
328
+ dirs_to_remove = [d for d in dirs if fnmatch.fnmatch(d, ignore_)]
329
+
330
+ # dirs need to be edited in-place
331
+ for d in dirs_to_remove:
332
+ dirs.remove(d)
333
+
334
+ files = [f for f in files if not fnmatch.fnmatch(f, ignore_)]
335
+
336
+ absolute_paths = [os.path.join(root, f) for f in files]
337
+ relative_paths = [os.path.relpath(p, dir_path) for p in absolute_paths]
338
+
339
+ if add_base_to_relative:
340
+ relative_paths = [os.path.join(base_name, p) for p in relative_paths]
341
+
342
+ assert len(absolute_paths) == len(relative_paths)
343
+ result += zip(absolute_paths, relative_paths)
344
+
345
+ return result
346
+
347
+
348
+ def copy_files_and_create_dirs(files: List[Tuple[str, str]]) -> None:
349
+ """Takes in a list of tuples of (src, dst) paths and copies files.
350
+ Will create all necessary directories."""
351
+ for file in files:
352
+ target_dir_name = os.path.dirname(file[1])
353
+
354
+ # will create all intermediate-level directories
355
+ if not os.path.exists(target_dir_name):
356
+ os.makedirs(target_dir_name)
357
+
358
+ shutil.copyfile(file[0], file[1])
359
+
360
+
361
+ # URL helpers
362
+ # ------------------------------------------------------------------------------------------
363
+
364
+ def is_url(obj: Any, allow_file_urls: bool = False) -> bool:
365
+ """Determine whether the given object is a valid URL string."""
366
+ if not isinstance(obj, str) or not "://" in obj:
367
+ return False
368
+ if allow_file_urls and obj.startswith('file://'):
369
+ return True
370
+ try:
371
+ res = requests.compat.urlparse(obj)
372
+ if not res.scheme or not res.netloc or not "." in res.netloc:
373
+ return False
374
+ res = requests.compat.urlparse(requests.compat.urljoin(obj, "/"))
375
+ if not res.scheme or not res.netloc or not "." in res.netloc:
376
+ return False
377
+ except:
378
+ return False
379
+ return True
380
+
381
+
382
+ def open_url(url: str, cache_dir: str = None, num_attempts: int = 10, verbose: bool = True, return_filename: bool = False, cache: bool = True) -> Any:
383
+ """Download the given URL and return a binary-mode file object to access the data."""
384
+ assert num_attempts >= 1
385
+ assert not (return_filename and (not cache))
386
+
387
+ # Doesn't look like an URL scheme so interpret it as a local filename.
388
+ if not re.match('^[a-z]+://', url):
389
+ return url if return_filename else open(url, "rb")
390
+
391
+ # Handle file URLs. This code handles unusual file:// patterns that
392
+ # arise on Windows:
393
+ #
394
+ # file:///c:/foo.txt
395
+ #
396
+ # which would translate to a local '/c:/foo.txt' filename that's
397
+ # invalid. Drop the forward slash for such pathnames.
398
+ #
399
+ # If you touch this code path, you should test it on both Linux and
400
+ # Windows.
401
+ #
402
+ # Some internet resources suggest using urllib.request.url2pathname() but
403
+ # but that converts forward slashes to backslashes and this causes
404
+ # its own set of problems.
405
+ if url.startswith('file://'):
406
+ filename = urllib.parse.urlparse(url).path
407
+ if re.match(r'^/[a-zA-Z]:', filename):
408
+ filename = filename[1:]
409
+ return filename if return_filename else open(filename, "rb")
410
+
411
+ assert is_url(url)
412
+
413
+ # Lookup from cache.
414
+ if cache_dir is None:
415
+ cache_dir = make_cache_dir_path('downloads')
416
+
417
+ url_md5 = hashlib.md5(url.encode("utf-8")).hexdigest()
418
+ if cache:
419
+ cache_files = glob.glob(os.path.join(cache_dir, url_md5 + "_*"))
420
+ if len(cache_files) == 1:
421
+ filename = cache_files[0]
422
+ return filename if return_filename else open(filename, "rb")
423
+
424
+ # Download.
425
+ url_name = None
426
+ url_data = None
427
+ with requests.Session() as session:
428
+ if verbose:
429
+ print("Downloading %s ..." % url, end="", flush=True)
430
+ for attempts_left in reversed(range(num_attempts)):
431
+ try:
432
+ with session.get(url) as res:
433
+ res.raise_for_status()
434
+ if len(res.content) == 0:
435
+ raise IOError("No data received")
436
+
437
+ if len(res.content) < 8192:
438
+ content_str = res.content.decode("utf-8")
439
+ if "download_warning" in res.headers.get("Set-Cookie", ""):
440
+ links = [html.unescape(link) for link in content_str.split('"') if "export=download" in link]
441
+ if len(links) == 1:
442
+ url = requests.compat.urljoin(url, links[0])
443
+ raise IOError("Google Drive virus checker nag")
444
+ if "Google Drive - Quota exceeded" in content_str:
445
+ raise IOError("Google Drive download quota exceeded -- please try again later")
446
+
447
+ match = re.search(r'filename="([^"]*)"', res.headers.get("Content-Disposition", ""))
448
+ url_name = match[1] if match else url
449
+ url_data = res.content
450
+ if verbose:
451
+ print(" done")
452
+ break
453
+ except KeyboardInterrupt:
454
+ raise
455
+ except:
456
+ if not attempts_left:
457
+ if verbose:
458
+ print(" failed")
459
+ raise
460
+ if verbose:
461
+ print(".", end="", flush=True)
462
+
463
+ # Save to cache.
464
+ if cache:
465
+ safe_name = re.sub(r"[^0-9a-zA-Z-._]", "_", url_name)
466
+ cache_file = os.path.join(cache_dir, url_md5 + "_" + safe_name)
467
+ temp_file = os.path.join(cache_dir, "tmp_" + uuid.uuid4().hex + "_" + url_md5 + "_" + safe_name)
468
+ os.makedirs(cache_dir, exist_ok=True)
469
+ with open(temp_file, "wb") as f:
470
+ f.write(url_data)
471
+ os.replace(temp_file, cache_file) # atomic
472
+ if return_filename:
473
+ return cache_file
474
+
475
+ # Return data as file object.
476
+ assert not return_filename
477
+ return io.BytesIO(url_data)
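
open_url above is how this repo reads network pickles: a plain local path is opened directly, while an http(s) URL is downloaded once into the dnnlib cache directory and reused on later calls. A hedged sketch (pretrained/ohayou_face.pkl is the LFS file added in this commit; only its first bytes are read here):

import dnnlib

# Local path: open_url returns a regular binary file object.
with dnnlib.util.open_url('pretrained/ohayou_face.pkl') as f:
    print(f.read(2))  # pickle data starts with b'\x80' once the LFS blob has been pulled

# URL inputs are downloaded and cached under make_cache_dir_path('downloads'),
# which defaults to ~/.cache/dnnlib/downloads.
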
legacy.py ADDED
@@ -0,0 +1,384 @@
1
+ # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ import click
10
+ import pickle
11
+ import re
12
+ import copy
13
+ import numpy as np
14
+ import torch
15
+ import dnnlib
16
+ from torch_utils import misc
17
+
18
+ #----------------------------------------------------------------------------
19
+
20
+ # !!! custom
21
+ def load_network_pkl(f, force_fp16=False, custom=False, **ex_kwargs):
22
+ # def load_network_pkl(f, force_fp16=False):
23
+ data = _LegacyUnpickler(f).load()
24
+ # data = pickle.load(f, encoding='latin1')
25
+
26
+ # Legacy TensorFlow pickle => convert.
27
+ if isinstance(data, tuple) and len(data) == 3 and all(isinstance(net, _TFNetworkStub) for net in data):
28
+ tf_G, tf_D, tf_Gs = data
29
+ G = convert_tf_generator(tf_G, custom=custom, **ex_kwargs)
30
+ D = convert_tf_discriminator(tf_D)
31
+ G_ema = convert_tf_generator(tf_Gs, custom=custom, **ex_kwargs)
32
+ data = dict(G=G, D=D, G_ema=G_ema)
33
+ # !!! custom
34
+ assert isinstance(data['G'], torch.nn.Module)
35
+ assert isinstance(data['D'], torch.nn.Module)
36
+ nets = ['G', 'D', 'G_ema']
37
+ elif isinstance(data, _TFNetworkStub):
38
+ G_ema = convert_tf_generator(data, custom=custom, **ex_kwargs)
39
+ data = dict(G_ema=G_ema)
40
+ nets = ['G_ema']
41
+ else:
42
+ # !!! custom
43
+ if custom is True:
44
+ G_ema = custom_generator(data, **ex_kwargs)
45
+ data = dict(G_ema=G_ema)
46
+ nets = ['G_ema']
47
+ else:
48
+ nets = []
49
+ for name in ['G', 'D', 'G_ema']:
50
+ if name in data.keys():
51
+ nets.append(name)
52
+ # print(nets)
53
+
54
+ # Add missing fields.
55
+ if 'training_set_kwargs' not in data:
56
+ data['training_set_kwargs'] = None
57
+ if 'augment_pipe' not in data:
58
+ data['augment_pipe'] = None
59
+
60
+ # Validate contents.
61
+ assert isinstance(data['G_ema'], torch.nn.Module)
62
+ assert isinstance(data['training_set_kwargs'], (dict, type(None)))
63
+ assert isinstance(data['augment_pipe'], (torch.nn.Module, type(None)))
64
+
65
+ # Force FP16.
66
+ if force_fp16:
67
+ for key in nets: # !!! custom
68
+ old = data[key]
69
+ kwargs = copy.deepcopy(old.init_kwargs)
70
+ if key.startswith('G'):
71
+ kwargs.synthesis_kwargs = dnnlib.EasyDict(kwargs.get('synthesis_kwargs', {}))
72
+ kwargs.synthesis_kwargs.num_fp16_res = 4
73
+ kwargs.synthesis_kwargs.conv_clamp = 256
74
+ if key.startswith('D'):
75
+ kwargs.num_fp16_res = 4
76
+ kwargs.conv_clamp = 256
77
+ if kwargs != old.init_kwargs:
78
+ new = type(old)(**kwargs).eval().requires_grad_(False)
79
+ misc.copy_params_and_buffers(old, new, require_all=True)
80
+ data[key] = new
81
+ return data
82
+
83
+ #----------------------------------------------------------------------------
84
+
85
+ class _TFNetworkStub(dnnlib.EasyDict):
86
+ pass
87
+
88
+ class _LegacyUnpickler(pickle.Unpickler):
89
+ def find_class(self, module, name):
90
+ if module == 'dnnlib.tflib.network' and name == 'Network':
91
+ return _TFNetworkStub
92
+ return super().find_class(module, name)
93
+
94
+ #----------------------------------------------------------------------------
95
+
96
+ def _collect_tf_params(tf_net):
97
+ # pylint: disable=protected-access
98
+ tf_params = dict()
99
+ def recurse(prefix, tf_net):
100
+ for name, value in tf_net.variables:
101
+ tf_params[prefix + name] = value
102
+ for name, comp in tf_net.components.items():
103
+ recurse(prefix + name + '/', comp)
104
+ recurse('', tf_net)
105
+ return tf_params
106
+
107
+ #----------------------------------------------------------------------------
108
+
109
+ def _populate_module_params(module, *patterns):
110
+ for name, tensor in misc.named_params_and_buffers(module):
111
+ found = False
112
+ value = None
113
+ for pattern, value_fn in zip(patterns[0::2], patterns[1::2]):
114
+ match = re.fullmatch(pattern, name)
115
+ if match:
116
+ found = True
117
+ if value_fn is not None:
118
+ value = value_fn(*match.groups())
119
+ break
120
+ try:
121
+ assert found
122
+ if value is not None:
123
+ tensor.copy_(torch.from_numpy(np.array(value)))
124
+ except:
125
+ print(name, list(tensor.shape))
126
+ raise
127
+
128
+ #----------------------------------------------------------------------------
129
+
130
+ # !!! custom
131
+ def custom_generator(data, **ex_kwargs):
132
+ from training import stylegan2_multi as networks
133
+ try: # saved? (with new fix)
134
+ fmap_base = data['G_ema'].synthesis.fmap_base
135
+ except: # default from original configs
136
+ fmap_base = 32768 if data['G_ema'].img_resolution >= 512 else 16384
137
+ kwargs = dnnlib.EasyDict(
138
+ z_dim = data['G_ema'].z_dim,
139
+ c_dim = data['G_ema'].c_dim,
140
+ w_dim = data['G_ema'].w_dim,
141
+ img_resolution = data['G_ema'].img_resolution,
142
+ img_channels = data['G_ema'].img_channels,
143
+ init_res = [4,4], # hacky
144
+ mapping_kwargs = dnnlib.EasyDict(num_layers = data['G_ema'].mapping.num_layers),
145
+ synthesis_kwargs = dnnlib.EasyDict(channel_base = fmap_base, **ex_kwargs),
146
+ )
147
+ G_out = networks.Generator(**kwargs).eval().requires_grad_(False)
148
+ misc.copy_params_and_buffers(data['G_ema'], G_out, require_all=False)
149
+ return G_out
150
+
151
+ # !!! custom
152
+ def convert_tf_generator(tf_G, custom=False, **ex_kwargs):
153
+ # def convert_tf_generator(tf_G):
154
+ if tf_G.version < 4:
155
+ raise ValueError('TensorFlow pickle version too low')
156
+
157
+ # Collect kwargs.
158
+ tf_kwargs = tf_G.static_kwargs
159
+ known_kwargs = set()
160
+ def kwarg(tf_name, default=None, none=None):
161
+ known_kwargs.add(tf_name)
162
+ val = tf_kwargs.get(tf_name, default)
163
+ return val if val is not None else none
164
+
165
+ # Convert kwargs.
166
+ kwargs = dnnlib.EasyDict(
167
+ z_dim = kwarg('latent_size', 512),
168
+ c_dim = kwarg('label_size', 0),
169
+ w_dim = kwarg('dlatent_size', 512),
170
+ img_resolution = kwarg('resolution', 1024),
171
+ img_channels = kwarg('num_channels', 3),
172
+ mapping_kwargs = dnnlib.EasyDict(
173
+ num_layers = kwarg('mapping_layers', 8),
174
+ embed_features = kwarg('label_fmaps', None),
175
+ layer_features = kwarg('mapping_fmaps', None),
176
+ activation = kwarg('mapping_nonlinearity', 'lrelu'),
177
+ lr_multiplier = kwarg('mapping_lrmul', 0.01),
178
+ w_avg_beta = kwarg('w_avg_beta', 0.995, none=1),
179
+ ),
180
+ synthesis_kwargs = dnnlib.EasyDict(
181
+ channel_base = kwarg('fmap_base', 16384) * 2,
182
+ channel_max = kwarg('fmap_max', 512),
183
+ num_fp16_res = kwarg('num_fp16_res', 0),
184
+ conv_clamp = kwarg('conv_clamp', None),
185
+ architecture = kwarg('architecture', 'skip'),
186
+ resample_filter = kwarg('resample_kernel', [1,3,3,1]),
187
+ use_noise = kwarg('use_noise', True),
188
+ activation = kwarg('nonlinearity', 'lrelu'),
189
+ ),
190
+ # !!! custom
191
+ # init_res = kwarg('init_res', [4,4]),
192
+ )
193
+
194
+ # Check for unknown kwargs.
195
+ kwarg('truncation_psi')
196
+ kwarg('truncation_cutoff')
197
+ kwarg('style_mixing_prob')
198
+ kwarg('structure')
199
+ unknown_kwargs = list(set(tf_kwargs.keys()) - known_kwargs)
200
+ # !!! custom
201
+ if custom:
202
+ kwargs.init_res = [4,4]
203
+ kwargs.synthesis_kwargs = dnnlib.EasyDict(**kwargs.synthesis_kwargs, **ex_kwargs)
204
+ if len(unknown_kwargs) > 0:
205
+ print('Unknown TensorFlow data! This may result in problems with your converted model.')
206
+ print(unknown_kwargs)
207
+ #raise ValueError('Unknown TensorFlow kwargs:', unknown_kwargs)
208
+ # raise ValueError('Unknown TensorFlow kwarg', unknown_kwargs[0])
209
+ # try:
210
+ # if ex_kwargs['verbose'] is True: print(kwargs.synthesis_kwargs)
211
+ # except: pass
212
+
213
+ # Collect params.
214
+ tf_params = _collect_tf_params(tf_G)
215
+ for name, value in list(tf_params.items()):
216
+ match = re.fullmatch(r'ToRGB_lod(\d+)/(.*)', name)
217
+ if match:
218
+ r = kwargs.img_resolution // (2 ** int(match.group(1)))
219
+ tf_params[f'{r}x{r}/ToRGB/{match.group(2)}'] = value
220
+ kwargs.synthesis_kwargs.architecture = 'orig'
221
+ #for name, value in tf_params.items(): print(f'{name:<50s}{list(value.shape)}')
222
+
223
+ # Convert params.
224
+ if custom:
225
+ from training import stylegan2_multi as networks
226
+ else:
227
+ from training import networks
228
+ G = networks.Generator(**kwargs).eval().requires_grad_(False)
229
+ # pylint: disable=unnecessary-lambda
230
+ _populate_module_params(G,
231
+ r'mapping\.w_avg', lambda: tf_params[f'dlatent_avg'],
232
+ r'mapping\.embed\.weight', lambda: tf_params[f'mapping/LabelEmbed/weight'].transpose(),
233
+ r'mapping\.embed\.bias', lambda: tf_params[f'mapping/LabelEmbed/bias'],
234
+ r'mapping\.fc(\d+)\.weight', lambda i: tf_params[f'mapping/Dense{i}/weight'].transpose(),
235
+ r'mapping\.fc(\d+)\.bias', lambda i: tf_params[f'mapping/Dense{i}/bias'],
236
+ r'synthesis\.b4\.const', lambda: tf_params[f'synthesis/4x4/Const/const'][0],
237
+ r'synthesis\.b4\.conv1\.weight', lambda: tf_params[f'synthesis/4x4/Conv/weight'].transpose(3, 2, 0, 1),
238
+ r'synthesis\.b4\.conv1\.bias', lambda: tf_params[f'synthesis/4x4/Conv/bias'],
239
+ r'synthesis\.b4\.conv1\.noise_const', lambda: tf_params[f'synthesis/noise0'][0, 0],
240
+ r'synthesis\.b4\.conv1\.noise_strength', lambda: tf_params[f'synthesis/4x4/Conv/noise_strength'],
241
+ r'synthesis\.b4\.conv1\.affine\.weight', lambda: tf_params[f'synthesis/4x4/Conv/mod_weight'].transpose(),
242
+ r'synthesis\.b4\.conv1\.affine\.bias', lambda: tf_params[f'synthesis/4x4/Conv/mod_bias'] + 1,
243
+ r'synthesis\.b(\d+)\.conv0\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/Conv0_up/weight'][::-1, ::-1].transpose(3, 2, 0, 1),
244
+ r'synthesis\.b(\d+)\.conv0\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/Conv0_up/bias'],
245
+ r'synthesis\.b(\d+)\.conv0\.noise_const', lambda r: tf_params[f'synthesis/noise{int(np.log2(int(r)))*2-5}'][0, 0],
246
+ r'synthesis\.b(\d+)\.conv0\.noise_strength', lambda r: tf_params[f'synthesis/{r}x{r}/Conv0_up/noise_strength'],
247
+ r'synthesis\.b(\d+)\.conv0\.affine\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/Conv0_up/mod_weight'].transpose(),
248
+ r'synthesis\.b(\d+)\.conv0\.affine\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/Conv0_up/mod_bias'] + 1,
249
+ r'synthesis\.b(\d+)\.conv1\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/Conv1/weight'].transpose(3, 2, 0, 1),
250
+ r'synthesis\.b(\d+)\.conv1\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/Conv1/bias'],
251
+ r'synthesis\.b(\d+)\.conv1\.noise_const', lambda r: tf_params[f'synthesis/noise{int(np.log2(int(r)))*2-4}'][0, 0],
252
+ r'synthesis\.b(\d+)\.conv1\.noise_strength', lambda r: tf_params[f'synthesis/{r}x{r}/Conv1/noise_strength'],
253
+ r'synthesis\.b(\d+)\.conv1\.affine\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/Conv1/mod_weight'].transpose(),
254
+ r'synthesis\.b(\d+)\.conv1\.affine\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/Conv1/mod_bias'] + 1,
255
+ r'synthesis\.b(\d+)\.torgb\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/ToRGB/weight'].transpose(3, 2, 0, 1),
256
+ r'synthesis\.b(\d+)\.torgb\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/ToRGB/bias'],
257
+ r'synthesis\.b(\d+)\.torgb\.affine\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/ToRGB/mod_weight'].transpose(),
258
+ r'synthesis\.b(\d+)\.torgb\.affine\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/ToRGB/mod_bias'] + 1,
259
+ r'synthesis\.b(\d+)\.skip\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/Skip/weight'][::-1, ::-1].transpose(3, 2, 0, 1),
260
+ r'.*\.resample_filter', None,
261
+ )
262
+ return G
263
+
264
+ #----------------------------------------------------------------------------
265
+
266
+ def convert_tf_discriminator(tf_D):
267
+ if tf_D.version < 4:
268
+ raise ValueError('TensorFlow pickle version too low')
269
+
270
+ # Collect kwargs.
271
+ tf_kwargs = tf_D.static_kwargs
272
+ known_kwargs = set()
273
+ def kwarg(tf_name, default=None):
274
+ known_kwargs.add(tf_name)
275
+ return tf_kwargs.get(tf_name, default)
276
+
277
+ # Convert kwargs.
278
+ kwargs = dnnlib.EasyDict(
279
+ c_dim = kwarg('label_size', 0),
280
+ img_resolution = kwarg('resolution', 1024),
281
+ img_channels = kwarg('num_channels', 3),
282
+ architecture = kwarg('architecture', 'resnet'),
283
+ channel_base = kwarg('fmap_base', 16384) * 2,
284
+ channel_max = kwarg('fmap_max', 512),
285
+ num_fp16_res = kwarg('num_fp16_res', 0),
286
+ conv_clamp = kwarg('conv_clamp', None),
287
+ cmap_dim = kwarg('mapping_fmaps', None),
288
+ block_kwargs = dnnlib.EasyDict(
289
+ activation = kwarg('nonlinearity', 'lrelu'),
290
+ resample_filter = kwarg('resample_kernel', [1,3,3,1]),
291
+ freeze_layers = kwarg('freeze_layers', 0),
292
+ ),
293
+ mapping_kwargs = dnnlib.EasyDict(
294
+ num_layers = kwarg('mapping_layers', 0),
295
+ embed_features = kwarg('mapping_fmaps', None),
296
+ layer_features = kwarg('mapping_fmaps', None),
297
+ activation = kwarg('nonlinearity', 'lrelu'),
298
+ lr_multiplier = kwarg('mapping_lrmul', 0.1),
299
+ ),
300
+ epilogue_kwargs = dnnlib.EasyDict(
301
+ mbstd_group_size = kwarg('mbstd_group_size', None),
302
+ mbstd_num_channels = kwarg('mbstd_num_features', 1),
303
+ activation = kwarg('nonlinearity', 'lrelu'),
304
+ ),
305
+ # !!! custom
306
+ # init_res = kwarg('init_res', [4,4]),
307
+ )
308
+
309
+ # Check for unknown kwargs.
310
+ kwarg('structure')
311
+ unknown_kwargs = list(set(tf_kwargs.keys()) - known_kwargs)
312
+ if len(unknown_kwargs) > 0:
313
+ print('Unknown TensorFlow data! This may result in problems with your converted model.')
314
+ print(unknown_kwargs)
315
+ # originally this repo threw errors:
316
+ # raise ValueError('Unknown TensorFlow kwarg', unknown_kwargs[0])
317
+
318
+ # Collect params.
319
+ tf_params = _collect_tf_params(tf_D)
320
+ for name, value in list(tf_params.items()):
321
+ match = re.fullmatch(r'FromRGB_lod(\d+)/(.*)', name)
322
+ if match:
323
+ r = kwargs.img_resolution // (2 ** int(match.group(1)))
324
+ tf_params[f'{r}x{r}/FromRGB/{match.group(2)}'] = value
325
+ kwargs.architecture = 'orig'
326
+ #for name, value in tf_params.items(): print(f'{name:<50s}{list(value.shape)}')
327
+
328
+ # Convert params.
329
+ from training import networks
330
+ D = networks.Discriminator(**kwargs).eval().requires_grad_(False)
331
+ # pylint: disable=unnecessary-lambda
332
+ _populate_module_params(D,
333
+ r'b(\d+)\.fromrgb\.weight', lambda r: tf_params[f'{r}x{r}/FromRGB/weight'].transpose(3, 2, 0, 1),
334
+ r'b(\d+)\.fromrgb\.bias', lambda r: tf_params[f'{r}x{r}/FromRGB/bias'],
335
+ r'b(\d+)\.conv(\d+)\.weight', lambda r, i: tf_params[f'{r}x{r}/Conv{i}{["","_down"][int(i)]}/weight'].transpose(3, 2, 0, 1),
336
+ r'b(\d+)\.conv(\d+)\.bias', lambda r, i: tf_params[f'{r}x{r}/Conv{i}{["","_down"][int(i)]}/bias'],
337
+ r'b(\d+)\.skip\.weight', lambda r: tf_params[f'{r}x{r}/Skip/weight'].transpose(3, 2, 0, 1),
338
+ r'mapping\.embed\.weight', lambda: tf_params[f'LabelEmbed/weight'].transpose(),
339
+ r'mapping\.embed\.bias', lambda: tf_params[f'LabelEmbed/bias'],
340
+ r'mapping\.fc(\d+)\.weight', lambda i: tf_params[f'Mapping{i}/weight'].transpose(),
341
+ r'mapping\.fc(\d+)\.bias', lambda i: tf_params[f'Mapping{i}/bias'],
342
+ r'b4\.conv\.weight', lambda: tf_params[f'4x4/Conv/weight'].transpose(3, 2, 0, 1),
343
+ r'b4\.conv\.bias', lambda: tf_params[f'4x4/Conv/bias'],
344
+ r'b4\.fc\.weight', lambda: tf_params[f'4x4/Dense0/weight'].transpose(),
345
+ r'b4\.fc\.bias', lambda: tf_params[f'4x4/Dense0/bias'],
346
+ r'b4\.out\.weight', lambda: tf_params[f'Output/weight'].transpose(),
347
+ r'b4\.out\.bias', lambda: tf_params[f'Output/bias'],
348
+ r'.*\.resample_filter', None,
349
+ )
350
+ return D
351
+
352
+ #----------------------------------------------------------------------------
353
+
354
+ @click.command()
355
+ @click.option('--source', help='Input pickle', required=True, metavar='PATH')
356
+ @click.option('--dest', help='Output pickle', required=True, metavar='PATH')
357
+ @click.option('--force-fp16', help='Force the networks to use FP16', type=bool, default=False, metavar='BOOL', show_default=True)
358
+ def convert_network_pickle(source, dest, force_fp16):
359
+ """Convert legacy network pickle into the native PyTorch format.
360
+
361
+ The tool is able to load the main network configurations exported using the TensorFlow version of StyleGAN2 or StyleGAN2-ADA.
362
+ It does not support e.g. StyleGAN2-ADA comparison methods, StyleGAN2 configs A-D, or StyleGAN1 networks.
363
+
364
+ Example:
365
+
366
+ \b
367
+ python legacy.py \\
368
+ --source=https://nvlabs-fi-cdn.nvidia.com/stylegan2/networks/stylegan2-cat-config-f.pkl \\
369
+ --dest=stylegan2-cat-config-f.pkl
370
+ """
371
+ print(f'Loading "{source}"...')
372
+ with dnnlib.util.open_url(source) as f:
373
+ data = load_network_pkl(f, force_fp16=force_fp16)
374
+ print(f'Saving "{dest}"...')
375
+ with open(dest, 'wb') as f:
376
+ pickle.dump(data, f)
377
+ print('Done.')
378
+
379
+ #----------------------------------------------------------------------------
380
+
381
+ if __name__ == "__main__":
382
+ convert_network_pickle() # pylint: disable=no-value-for-parameter
383
+
384
+ #----------------------------------------------------------------------------
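A minimal usage sketch for the converter above, run from Python instead of the CLI. It assumes the pickle shipped in this repo (pretrained/ohayou_face.pkl) and uses load_network_pkl, the entry point defined earlier in this file; the printed attributes are just a sanity check.

import dnnlib
import legacy

# open_url accepts a local path or a URL; load_network_pkl handles both legacy
# TensorFlow pickles (converted via the functions above) and native PyTorch pickles.
with dnnlib.util.open_url('pretrained/ohayou_face.pkl') as f:
    data = legacy.load_network_pkl(f)
G = data['G_ema']                       # converted generator module
print(G.img_resolution, G.z_dim)        # output resolution and latent size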
model_build.py ADDED
@@ -0,0 +1,95 @@
1
+ import os
2
+ import glob
3
+
4
+ import numpy as np
5
+ from numpy import linalg
6
+ import PIL.Image as Image
7
+ import torch
8
+ from torchvision import transforms
9
+ from tqdm import tqdm
10
+ from argparse import Namespace
11
+ import easydict
12
+
13
+ import legacy
14
+ import dnnlib
15
+
16
+ from opensimplex import OpenSimplex
17
+
18
+ from configs import data_configs
19
+ from models.psp import pSp
20
+
21
+
22
+ def build_stylegan2(
23
+ increment = 0.01,
24
+ network_pkl = 'pretrained/ohayou_face.pkl',
25
+ process = 'image', #['image', 'interpolation','truncation','interpolation-truncation']
26
+ random_seed = 0,
27
+ diameter = 100.0,
28
+ scale_type = 'pad', #['pad', 'padside', 'symm','symmside']
29
+ size = [512, 512],
30
+ seeds = [0],
31
+ space = 'z', #['z', 'w']
32
+ fps = 24,
33
+ frames = 240,
34
+ noise_mode = 'none', #['const', 'random', 'none']
35
+ outdir = 'path',
36
+ projected_w = 'path',
37
+ easing = 'linear',
38
+ device = 'cpu'
39
+
40
+ ):
41
+
42
+ G_kwargs = dnnlib.EasyDict()
43
+ G_kwargs.size = size
44
+ G_kwargs.scale_type = scale_type
45
+
46
+ device = torch.device(device)
47
+ with dnnlib.util.open_url(network_pkl) as f:
48
+ # G = legacy.load_network_pkl(f)['G_ema'].to(device) # type: ignore
49
+ G = legacy.load_network_pkl(f, custom=True, **G_kwargs)['G_ema'].to(device) # type: ignore
50
+
51
+ return G.synthesis
52
+
53
+
54
+ def build_psp():
55
+ test_opts = easydict.EasyDict({
56
+ # arguments for inference script
57
+ 'checkpoint_path' : 'pretrained/ohayou_face.pt',
58
+ 'couple_outputs' : False,
59
+ 'resize_outputs' : False,
60
+
61
+ 'test_batch_size' : 1,
62
+ 'test_workers' : 1,
63
+
64
+ # arguments for style-mixing script
65
+ 'n_images' : None,
66
+ 'n_outputs_to_generate' : 5,
67
+ 'mix_alpha' : None,
68
+ 'latent_mask' : None,
69
+
70
+ # arguments for super-resolution
71
+ 'resize_factors' : None,
72
+ })
73
+
74
+ # update test options with options used during training
75
+ ckpt = torch.load(test_opts.checkpoint_path, map_location='cpu')
76
+ opts = ckpt['opts']
77
+ opts.update(vars(test_opts))
78
+ if 'learn_in_w' not in opts:
79
+ opts['learn_in_w'] = False
80
+ opts = Namespace(**opts)
81
+ opts.device = 'cpu'
82
+ net = pSp(opts)
83
+ net.eval()
84
+ return net
85
+
86
+ def img_preprocess(img, transform):
87
+ if (img.mode == 'RGBA') or (img.mode == 'P'):
88
+ img.load()
89
+ background = Image.new("RGB", img.size, (255, 255, 255))
90
+ background.paste(img, mask=img.convert('RGBA').split()[3]) # alpha channel as the paste mask (also covers mode 'P')
91
+ img = background
92
+ assert img.mode == 'RGB'
93
+ img = transform(img)
94
+ img = img.unsqueeze(dim=0)
95
+ return img
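A quick check of img_preprocess above: an image with an alpha channel is flattened onto white and returned as a batched tensor. The transform here is a stand-in, not the one the app uses.

from PIL import Image
from torchvision import transforms
from model_build import img_preprocess

rgba = Image.new('RGBA', (64, 64), (255, 0, 0, 128))              # toy image with alpha
t = transforms.Compose([transforms.Resize((256, 256)), transforms.ToTensor()])
x = img_preprocess(rgba, t)
print(x.shape)                                                    # torch.Size([1, 3, 256, 256])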
models/__init__.py ADDED
File without changes
models/encoders/__init__.py ADDED
File without changes
models/encoders/helpers.py ADDED
@@ -0,0 +1,119 @@
1
+ from collections import namedtuple
2
+ import torch
3
+ from torch.nn import Conv2d, BatchNorm2d, PReLU, ReLU, Sigmoid, MaxPool2d, AdaptiveAvgPool2d, Sequential, Module
4
+
5
+ """
6
+ ArcFace implementation from [TreB1eN](https://github.com/TreB1eN/InsightFace_Pytorch)
7
+ """
8
+
9
+
10
+ class Flatten(Module):
11
+ def forward(self, input):
12
+ return input.view(input.size(0), -1)
13
+
14
+
15
+ def l2_norm(input, axis=1):
16
+ norm = torch.norm(input, 2, axis, True)
17
+ output = torch.div(input, norm)
18
+ return output
19
+
20
+
21
+ class Bottleneck(namedtuple('Block', ['in_channel', 'depth', 'stride'])):
22
+ """ A named tuple describing a ResNet block. """
23
+
24
+
25
+ def get_block(in_channel, depth, num_units, stride=2):
26
+ return [Bottleneck(in_channel, depth, stride)] + [Bottleneck(depth, depth, 1) for i in range(num_units - 1)]
27
+
28
+
29
+ def get_blocks(num_layers):
30
+ if num_layers == 50:
31
+ blocks = [
32
+ get_block(in_channel=64, depth=64, num_units=3),
33
+ get_block(in_channel=64, depth=128, num_units=4),
34
+ get_block(in_channel=128, depth=256, num_units=14),
35
+ get_block(in_channel=256, depth=512, num_units=3)
36
+ ]
37
+ elif num_layers == 100:
38
+ blocks = [
39
+ get_block(in_channel=64, depth=64, num_units=3),
40
+ get_block(in_channel=64, depth=128, num_units=13),
41
+ get_block(in_channel=128, depth=256, num_units=30),
42
+ get_block(in_channel=256, depth=512, num_units=3)
43
+ ]
44
+ elif num_layers == 152:
45
+ blocks = [
46
+ get_block(in_channel=64, depth=64, num_units=3),
47
+ get_block(in_channel=64, depth=128, num_units=8),
48
+ get_block(in_channel=128, depth=256, num_units=36),
49
+ get_block(in_channel=256, depth=512, num_units=3)
50
+ ]
51
+ else:
52
+ raise ValueError("Invalid number of layers: {}. Must be one of [50, 100, 152]".format(num_layers))
53
+ return blocks
54
+
55
+
56
+ class SEModule(Module):
57
+ def __init__(self, channels, reduction):
58
+ super(SEModule, self).__init__()
59
+ self.avg_pool = AdaptiveAvgPool2d(1)
60
+ self.fc1 = Conv2d(channels, channels // reduction, kernel_size=1, padding=0, bias=False)
61
+ self.relu = ReLU(inplace=True)
62
+ self.fc2 = Conv2d(channels // reduction, channels, kernel_size=1, padding=0, bias=False)
63
+ self.sigmoid = Sigmoid()
64
+
65
+ def forward(self, x):
66
+ module_input = x
67
+ x = self.avg_pool(x)
68
+ x = self.fc1(x)
69
+ x = self.relu(x)
70
+ x = self.fc2(x)
71
+ x = self.sigmoid(x)
72
+ return module_input * x
73
+
74
+
75
+ class bottleneck_IR(Module):
76
+ def __init__(self, in_channel, depth, stride):
77
+ super(bottleneck_IR, self).__init__()
78
+ if in_channel == depth:
79
+ self.shortcut_layer = MaxPool2d(1, stride)
80
+ else:
81
+ self.shortcut_layer = Sequential(
82
+ Conv2d(in_channel, depth, (1, 1), stride, bias=False),
83
+ BatchNorm2d(depth)
84
+ )
85
+ self.res_layer = Sequential(
86
+ BatchNorm2d(in_channel),
87
+ Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False), PReLU(depth),
88
+ Conv2d(depth, depth, (3, 3), stride, 1, bias=False), BatchNorm2d(depth)
89
+ )
90
+
91
+ def forward(self, x):
92
+ shortcut = self.shortcut_layer(x)
93
+ res = self.res_layer(x)
94
+ return res + shortcut
95
+
96
+
97
+ class bottleneck_IR_SE(Module):
98
+ def __init__(self, in_channel, depth, stride):
99
+ super(bottleneck_IR_SE, self).__init__()
100
+ if in_channel == depth:
101
+ self.shortcut_layer = MaxPool2d(1, stride)
102
+ else:
103
+ self.shortcut_layer = Sequential(
104
+ Conv2d(in_channel, depth, (1, 1), stride, bias=False),
105
+ BatchNorm2d(depth)
106
+ )
107
+ self.res_layer = Sequential(
108
+ BatchNorm2d(in_channel),
109
+ Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False),
110
+ PReLU(depth),
111
+ Conv2d(depth, depth, (3, 3), stride, 1, bias=False),
112
+ BatchNorm2d(depth),
113
+ SEModule(depth, 16)
114
+ )
115
+
116
+ def forward(self, x):
117
+ shortcut = self.shortcut_layer(x)
118
+ res = self.res_layer(x)
119
+ return res + shortcut
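A small sketch of the helpers above: get_blocks returns the per-stage bottleneck specs, and a single SE bottleneck with stride 2 halves the spatial resolution. The input sizes are illustrative.

import torch
from models.encoders.helpers import get_blocks, bottleneck_IR_SE

blocks = get_blocks(50)
print([len(stage) for stage in blocks])           # [3, 4, 14, 3] bottlenecks per stage

unit = bottleneck_IR_SE(in_channel=64, depth=128, stride=2)
out = unit(torch.randn(1, 64, 56, 56))            # residual branch + shortcut, both stride 2
print(out.shape)                                   # torch.Size([1, 128, 28, 28])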
models/encoders/model_irse.py ADDED
@@ -0,0 +1,84 @@
1
+ from torch.nn import Linear, Conv2d, BatchNorm1d, BatchNorm2d, PReLU, Dropout, Sequential, Module
2
+ from models.encoders.helpers import get_blocks, Flatten, bottleneck_IR, bottleneck_IR_SE, l2_norm
3
+
4
+ """
5
+ Modified Backbone implementation from [TreB1eN](https://github.com/TreB1eN/InsightFace_Pytorch)
6
+ """
7
+
8
+
9
+ class Backbone(Module):
10
+ def __init__(self, input_size, num_layers, mode='ir', drop_ratio=0.4, affine=True):
11
+ super(Backbone, self).__init__()
12
+ assert input_size in [112, 224], "input_size should be 112 or 224"
13
+ assert num_layers in [50, 100, 152], "num_layers should be 50, 100 or 152"
14
+ assert mode in ['ir', 'ir_se'], "mode should be ir or ir_se"
15
+ blocks = get_blocks(num_layers)
16
+ if mode == 'ir':
17
+ unit_module = bottleneck_IR
18
+ elif mode == 'ir_se':
19
+ unit_module = bottleneck_IR_SE
20
+ self.input_layer = Sequential(Conv2d(3, 64, (3, 3), 1, 1, bias=False),
21
+ BatchNorm2d(64),
22
+ PReLU(64))
23
+ if input_size == 112:
24
+ self.output_layer = Sequential(BatchNorm2d(512),
25
+ Dropout(drop_ratio),
26
+ Flatten(),
27
+ Linear(512 * 7 * 7, 512),
28
+ BatchNorm1d(512, affine=affine))
29
+ else:
30
+ self.output_layer = Sequential(BatchNorm2d(512),
31
+ Dropout(drop_ratio),
32
+ Flatten(),
33
+ Linear(512 * 14 * 14, 512),
34
+ BatchNorm1d(512, affine=affine))
35
+
36
+ modules = []
37
+ for block in blocks:
38
+ for bottleneck in block:
39
+ modules.append(unit_module(bottleneck.in_channel,
40
+ bottleneck.depth,
41
+ bottleneck.stride))
42
+ self.body = Sequential(*modules)
43
+
44
+ def forward(self, x):
45
+ x = self.input_layer(x)
46
+ x = self.body(x)
47
+ x = self.output_layer(x)
48
+ return l2_norm(x)
49
+
50
+
51
+ def IR_50(input_size):
52
+ """Constructs a ir-50 model."""
53
+ model = Backbone(input_size, num_layers=50, mode='ir', drop_ratio=0.4, affine=False)
54
+ return model
55
+
56
+
57
+ def IR_101(input_size):
58
+ """Constructs a ir-101 model."""
59
+ model = Backbone(input_size, num_layers=100, mode='ir', drop_ratio=0.4, affine=False)
60
+ return model
61
+
62
+
63
+ def IR_152(input_size):
64
+ """Constructs a ir-152 model."""
65
+ model = Backbone(input_size, num_layers=152, mode='ir', drop_ratio=0.4, affine=False)
66
+ return model
67
+
68
+
69
+ def IR_SE_50(input_size):
70
+ """Constructs a ir_se-50 model."""
71
+ model = Backbone(input_size, num_layers=50, mode='ir_se', drop_ratio=0.4, affine=False)
72
+ return model
73
+
74
+
75
+ def IR_SE_101(input_size):
76
+ """Constructs a ir_se-101 model."""
77
+ model = Backbone(input_size, num_layers=100, mode='ir_se', drop_ratio=0.4, affine=False)
78
+ return model
79
+
80
+
81
+ def IR_SE_152(input_size):
82
+ """Constructs a ir_se-152 model."""
83
+ model = Backbone(input_size, num_layers=152, mode='ir_se', drop_ratio=0.4, affine=False)
84
+ return model
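A sketch of the backbone's contract: a 112x112 face crop maps to a 512-d embedding that the final l2_norm scales to unit length. Random weights here, purely to show shapes.

import torch
from models.encoders.model_irse import IR_SE_50

net = IR_SE_50(112).eval()
with torch.no_grad():
    emb = net(torch.randn(2, 3, 112, 112))
print(emb.shape)             # torch.Size([2, 512])
print(emb.norm(dim=1))       # ~1.0 per row, because of the final l2_norm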
models/encoders/psp_encoders.py ADDED
@@ -0,0 +1,186 @@
1
+ import numpy as np
2
+ import torch
3
+ import torch.nn.functional as F
4
+ from torch import nn
5
+ from torch.nn import Linear, Conv2d, BatchNorm2d, PReLU, Sequential, Module
6
+
7
+ from models.encoders.helpers import get_blocks, Flatten, bottleneck_IR, bottleneck_IR_SE
8
+ from models.stylegan2.model import EqualLinear
9
+
10
+
11
+ class GradualStyleBlock(Module):
12
+ def __init__(self, in_c, out_c, spatial):
13
+ super(GradualStyleBlock, self).__init__()
14
+ self.out_c = out_c
15
+ self.spatial = spatial
16
+ num_pools = int(np.log2(spatial))
17
+ modules = []
18
+ modules += [Conv2d(in_c, out_c, kernel_size=3, stride=2, padding=1),
19
+ nn.LeakyReLU()]
20
+ for i in range(num_pools - 1):
21
+ modules += [
22
+ Conv2d(out_c, out_c, kernel_size=3, stride=2, padding=1),
23
+ nn.LeakyReLU()
24
+ ]
25
+ self.convs = nn.Sequential(*modules)
26
+ self.linear = EqualLinear(out_c, out_c, lr_mul=1)
27
+
28
+ def forward(self, x):
29
+ x = self.convs(x)
30
+ x = x.view(-1, self.out_c)
31
+ x = self.linear(x)
32
+ return x
33
+
34
+
35
+ class GradualStyleEncoder(Module):
36
+ def __init__(self, num_layers, mode='ir', opts=None):
37
+ super(GradualStyleEncoder, self).__init__()
38
+ assert num_layers in [50, 100, 152], 'num_layers should be 50,100, or 152'
39
+ assert mode in ['ir', 'ir_se'], 'mode should be ir or ir_se'
40
+ blocks = get_blocks(num_layers)
41
+ if mode == 'ir':
42
+ unit_module = bottleneck_IR
43
+ elif mode == 'ir_se':
44
+ unit_module = bottleneck_IR_SE
45
+ self.input_layer = Sequential(Conv2d(opts.input_nc, 64, (3, 3), 1, 1, bias=False),
46
+ BatchNorm2d(64),
47
+ PReLU(64))
48
+ modules = []
49
+ for block in blocks:
50
+ for bottleneck in block:
51
+ modules.append(unit_module(bottleneck.in_channel,
52
+ bottleneck.depth,
53
+ bottleneck.stride))
54
+ self.body = Sequential(*modules)
55
+
56
+ self.styles = nn.ModuleList()
57
+ self.style_count = opts.n_styles
58
+ self.coarse_ind = 3
59
+ self.middle_ind = 7
60
+ for i in range(self.style_count):
61
+ if i < self.coarse_ind:
62
+ style = GradualStyleBlock(512, 512, 16)
63
+ elif i < self.middle_ind:
64
+ style = GradualStyleBlock(512, 512, 32)
65
+ else:
66
+ style = GradualStyleBlock(512, 512, 64)
67
+ self.styles.append(style)
68
+ self.latlayer1 = nn.Conv2d(256, 512, kernel_size=1, stride=1, padding=0)
69
+ self.latlayer2 = nn.Conv2d(128, 512, kernel_size=1, stride=1, padding=0)
70
+
71
+ def _upsample_add(self, x, y):
72
+ '''Upsample and add two feature maps.
73
+ Args:
74
+ x: (Variable) top feature map to be upsampled.
75
+ y: (Variable) lateral feature map.
76
+ Returns:
77
+ (Variable) added feature map.
78
+ Note in PyTorch, when input size is odd, the upsampled feature map
79
+ with `F.upsample(..., scale_factor=2, mode='nearest')`
80
+ maybe not equal to the lateral feature map size.
81
+ e.g.
82
+ original input size: [N,_,15,15] ->
83
+ conv2d feature map size: [N,_,8,8] ->
84
+ upsampled feature map size: [N,_,16,16]
85
+ So we choose bilinear upsample which supports arbitrary output sizes.
86
+ '''
87
+ _, _, H, W = y.size()
88
+ return F.interpolate(x, size=(H, W), mode='bilinear', align_corners=True) + y
89
+
90
+ def forward(self, x):
91
+ x = self.input_layer(x)
92
+
93
+ latents = []
94
+ modulelist = list(self.body._modules.values())
95
+ for i, l in enumerate(modulelist):
96
+ x = l(x)
97
+ if i == 6:
98
+ c1 = x
99
+ elif i == 20:
100
+ c2 = x
101
+ elif i == 23:
102
+ c3 = x
103
+
104
+ for j in range(self.coarse_ind):
105
+ latents.append(self.styles[j](c3))
106
+
107
+ p2 = self._upsample_add(c3, self.latlayer1(c2))
108
+ for j in range(self.coarse_ind, self.middle_ind):
109
+ latents.append(self.styles[j](p2))
110
+
111
+ p1 = self._upsample_add(p2, self.latlayer2(c1))
112
+ for j in range(self.middle_ind, self.style_count):
113
+ latents.append(self.styles[j](p1))
114
+
115
+ out = torch.stack(latents, dim=1)
116
+ return out
117
+
118
+
119
+ class BackboneEncoderUsingLastLayerIntoW(Module):
120
+ def __init__(self, num_layers, mode='ir', opts=None):
121
+ super(BackboneEncoderUsingLastLayerIntoW, self).__init__()
122
+ print('Using BackboneEncoderUsingLastLayerIntoW')
123
+ assert num_layers in [50, 100, 152], 'num_layers should be 50,100, or 152'
124
+ assert mode in ['ir', 'ir_se'], 'mode should be ir or ir_se'
125
+ blocks = get_blocks(num_layers)
126
+ if mode == 'ir':
127
+ unit_module = bottleneck_IR
128
+ elif mode == 'ir_se':
129
+ unit_module = bottleneck_IR_SE
130
+ self.input_layer = Sequential(Conv2d(opts.input_nc, 64, (3, 3), 1, 1, bias=False),
131
+ BatchNorm2d(64),
132
+ PReLU(64))
133
+ self.output_pool = torch.nn.AdaptiveAvgPool2d((1, 1))
134
+ self.linear = EqualLinear(512, 512, lr_mul=1)
135
+ modules = []
136
+ for block in blocks:
137
+ for bottleneck in block:
138
+ modules.append(unit_module(bottleneck.in_channel,
139
+ bottleneck.depth,
140
+ bottleneck.stride))
141
+ self.body = Sequential(*modules)
142
+
143
+ def forward(self, x):
144
+ x = self.input_layer(x)
145
+ x = self.body(x)
146
+ x = self.output_pool(x)
147
+ x = x.view(-1, 512)
148
+ x = self.linear(x)
149
+ return x
150
+
151
+
152
+ class BackboneEncoderUsingLastLayerIntoWPlus(Module):
153
+ def __init__(self, num_layers, mode='ir', opts=None):
154
+ super(BackboneEncoderUsingLastLayerIntoWPlus, self).__init__()
155
+ print('Using BackboneEncoderUsingLastLayerIntoWPlus')
156
+ assert num_layers in [50, 100, 152], 'num_layers should be 50,100, or 152'
157
+ assert mode in ['ir', 'ir_se'], 'mode should be ir or ir_se'
158
+ blocks = get_blocks(num_layers)
159
+ if mode == 'ir':
160
+ unit_module = bottleneck_IR
161
+ elif mode == 'ir_se':
162
+ unit_module = bottleneck_IR_SE
163
+ self.n_styles = opts.n_styles
164
+ self.input_layer = Sequential(Conv2d(opts.input_nc, 64, (3, 3), 1, 1, bias=False),
165
+ BatchNorm2d(64),
166
+ PReLU(64))
167
+ self.output_layer_2 = Sequential(BatchNorm2d(512),
168
+ torch.nn.AdaptiveAvgPool2d((7, 7)),
169
+ Flatten(),
170
+ Linear(512 * 7 * 7, 512))
171
+ self.linear = EqualLinear(512, 512 * self.n_styles, lr_mul=1)
172
+ modules = []
173
+ for block in blocks:
174
+ for bottleneck in block:
175
+ modules.append(unit_module(bottleneck.in_channel,
176
+ bottleneck.depth,
177
+ bottleneck.stride))
178
+ self.body = Sequential(*modules)
179
+
180
+ def forward(self, x):
181
+ x = self.input_layer(x)
182
+ x = self.body(x)
183
+ x = self.output_layer_2(x)
184
+ x = self.linear(x)
185
+ x = x.view(-1, self.n_styles, 512)
186
+ return x
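A sketch of the map2style path above. In pSp the number of style vectors follows the generator resolution (n_styles = 2*log2(resolution) - 2, e.g. 16 at 512px); the opts object below is a stand-in carrying only the two fields the encoder reads.

import torch
import easydict
from models.encoders.psp_encoders import GradualStyleEncoder

opts = easydict.EasyDict({'input_nc': 3, 'n_styles': 16})    # 16 styles -> 512px generator
enc = GradualStyleEncoder(50, 'ir_se', opts).eval()
with torch.no_grad():
    styles = enc(torch.randn(1, 3, 256, 256))                # the encoder is typically fed 256x256 crops
print(styles.shape)                                           # torch.Size([1, 16, 512])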
models/mtcnn/__init__.py ADDED
File without changes
models/mtcnn/mtcnn.py ADDED
@@ -0,0 +1,156 @@
1
+ import numpy as np
2
+ import torch
3
+ from PIL import Image
4
+ from models.mtcnn.mtcnn_pytorch.src.get_nets import PNet, RNet, ONet
5
+ from models.mtcnn.mtcnn_pytorch.src.box_utils import nms, calibrate_box, get_image_boxes, convert_to_square
6
+ from models.mtcnn.mtcnn_pytorch.src.first_stage import run_first_stage
7
+ from models.mtcnn.mtcnn_pytorch.src.align_trans import get_reference_facial_points, warp_and_crop_face
8
+
9
+ device = 'cuda:0'
10
+
11
+
12
+ class MTCNN():
13
+ def __init__(self):
14
+ print(device)
15
+ self.pnet = PNet().to(device)
16
+ self.rnet = RNet().to(device)
17
+ self.onet = ONet().to(device)
18
+ self.pnet.eval()
19
+ self.rnet.eval()
20
+ self.onet.eval()
21
+ self.refrence = get_reference_facial_points(default_square=True)
22
+
23
+ def align(self, img):
24
+ _, landmarks = self.detect_faces(img)
25
+ if len(landmarks) == 0:
26
+ return None, None
27
+ facial5points = [[landmarks[0][j], landmarks[0][j + 5]] for j in range(5)]
28
+ warped_face, tfm = warp_and_crop_face(np.array(img), facial5points, self.refrence, crop_size=(112, 112))
29
+ return Image.fromarray(warped_face), tfm
30
+
31
+ def align_multi(self, img, limit=None, min_face_size=30.0):
32
+ boxes, landmarks = self.detect_faces(img, min_face_size)
33
+ if limit:
34
+ boxes = boxes[:limit]
35
+ landmarks = landmarks[:limit]
36
+ faces = []
37
+ tfms = []
38
+ for landmark in landmarks:
39
+ facial5points = [[landmark[j], landmark[j + 5]] for j in range(5)]
40
+ warped_face, tfm = warp_and_crop_face(np.array(img), facial5points, self.refrence, crop_size=(112, 112))
41
+ faces.append(Image.fromarray(warped_face))
42
+ tfms.append(tfm)
43
+ return boxes, faces, tfms
44
+
45
+ def detect_faces(self, image, min_face_size=20.0,
46
+ thresholds=[0.15, 0.25, 0.35],
47
+ nms_thresholds=[0.7, 0.7, 0.7]):
48
+ """
49
+ Arguments:
50
+ image: an instance of PIL.Image.
51
+ min_face_size: a float number.
52
+ thresholds: a list of length 3.
53
+ nms_thresholds: a list of length 3.
54
+
55
+ Returns:
56
+ two float numpy arrays of shapes [n_boxes, 4] and [n_boxes, 10],
57
+ bounding boxes and facial landmarks.
58
+ """
59
+
60
+ # BUILD AN IMAGE PYRAMID
61
+ width, height = image.size
62
+ min_length = min(height, width)
63
+
64
+ min_detection_size = 12
65
+ factor = 0.707 # sqrt(0.5)
66
+
67
+ # scales for scaling the image
68
+ scales = []
69
+
70
+ # scales the image so that
71
+ # minimum size that we can detect equals to
72
+ # minimum face size that we want to detect
73
+ m = min_detection_size / min_face_size
74
+ min_length *= m
75
+
76
+ factor_count = 0
77
+ while min_length > min_detection_size:
78
+ scales.append(m * factor ** factor_count)
79
+ min_length *= factor
80
+ factor_count += 1
81
+
82
+ # STAGE 1
83
+
84
+ # it will be returned
85
+ bounding_boxes = []
86
+
87
+ with torch.no_grad():
88
+ # run P-Net on different scales
89
+ for s in scales:
90
+ boxes = run_first_stage(image, self.pnet, scale=s, threshold=thresholds[0])
91
+ bounding_boxes.append(boxes)
92
+
93
+ # collect boxes (and offsets, and scores) from different scales
94
+ bounding_boxes = [i for i in bounding_boxes if i is not None]
95
+ bounding_boxes = np.vstack(bounding_boxes)
96
+
97
+ keep = nms(bounding_boxes[:, 0:5], nms_thresholds[0])
98
+ bounding_boxes = bounding_boxes[keep]
99
+
100
+ # use offsets predicted by pnet to transform bounding boxes
101
+ bounding_boxes = calibrate_box(bounding_boxes[:, 0:5], bounding_boxes[:, 5:])
102
+ # shape [n_boxes, 5]
103
+
104
+ bounding_boxes = convert_to_square(bounding_boxes)
105
+ bounding_boxes[:, 0:4] = np.round(bounding_boxes[:, 0:4])
106
+
107
+ # STAGE 2
108
+
109
+ img_boxes = get_image_boxes(bounding_boxes, image, size=24)
110
+ img_boxes = torch.FloatTensor(img_boxes).to(device)
111
+
112
+ output = self.rnet(img_boxes)
113
+ offsets = output[0].cpu().data.numpy() # shape [n_boxes, 4]
114
+ probs = output[1].cpu().data.numpy() # shape [n_boxes, 2]
115
+
116
+ keep = np.where(probs[:, 1] > thresholds[1])[0]
117
+ bounding_boxes = bounding_boxes[keep]
118
+ bounding_boxes[:, 4] = probs[keep, 1].reshape((-1,))
119
+ offsets = offsets[keep]
120
+
121
+ keep = nms(bounding_boxes, nms_thresholds[1])
122
+ bounding_boxes = bounding_boxes[keep]
123
+ bounding_boxes = calibrate_box(bounding_boxes, offsets[keep])
124
+ bounding_boxes = convert_to_square(bounding_boxes)
125
+ bounding_boxes[:, 0:4] = np.round(bounding_boxes[:, 0:4])
126
+
127
+ # STAGE 3
128
+
129
+ img_boxes = get_image_boxes(bounding_boxes, image, size=48)
130
+ if len(img_boxes) == 0:
131
+ return [], []
132
+ img_boxes = torch.FloatTensor(img_boxes).to(device)
133
+ output = self.onet(img_boxes)
134
+ landmarks = output[0].cpu().data.numpy() # shape [n_boxes, 10]
135
+ offsets = output[1].cpu().data.numpy() # shape [n_boxes, 4]
136
+ probs = output[2].cpu().data.numpy() # shape [n_boxes, 2]
137
+
138
+ keep = np.where(probs[:, 1] > thresholds[2])[0]
139
+ bounding_boxes = bounding_boxes[keep]
140
+ bounding_boxes[:, 4] = probs[keep, 1].reshape((-1,))
141
+ offsets = offsets[keep]
142
+ landmarks = landmarks[keep]
143
+
144
+ # compute landmark points
145
+ width = bounding_boxes[:, 2] - bounding_boxes[:, 0] + 1.0
146
+ height = bounding_boxes[:, 3] - bounding_boxes[:, 1] + 1.0
147
+ xmin, ymin = bounding_boxes[:, 0], bounding_boxes[:, 1]
148
+ landmarks[:, 0:5] = np.expand_dims(xmin, 1) + np.expand_dims(width, 1) * landmarks[:, 0:5]
149
+ landmarks[:, 5:10] = np.expand_dims(ymin, 1) + np.expand_dims(height, 1) * landmarks[:, 5:10]
150
+
151
+ bounding_boxes = calibrate_box(bounding_boxes, offsets)
152
+ keep = nms(bounding_boxes, nms_thresholds[2], mode='min')
153
+ bounding_boxes = bounding_boxes[keep]
154
+ landmarks = landmarks[keep]
155
+
156
+ return bounding_boxes, landmarks
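A usage sketch for the class above. It hard-codes cuda:0, so this needs a GPU plus the pretrained P/R/O-Net weights that mtcnn_pytorch ships; the image path is a placeholder.

from PIL import Image
from models.mtcnn.mtcnn import MTCNN

mtcnn = MTCNN()
img = Image.open('face.jpg').convert('RGB')      # placeholder path
aligned, tfm = mtcnn.align(img)                  # 112x112 crop and the 2x3 affine used to make it
if aligned is not None:
    aligned.save('face_aligned.jpg')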
models/mtcnn/mtcnn_pytorch/__init__.py ADDED
File without changes
models/mtcnn/mtcnn_pytorch/src/__init__.py ADDED
@@ -0,0 +1,2 @@
1
+ from .visualization_utils import show_bboxes
2
+ from .detector import detect_faces
models/mtcnn/mtcnn_pytorch/src/align_trans.py ADDED
@@ -0,0 +1,304 @@
1
+ # -*- coding: utf-8 -*-
2
+ """
3
+ Created on Mon Apr 24 15:43:29 2017
4
+ @author: zhaoy
5
+ """
6
+ import numpy as np
7
+ import cv2
8
+
9
+ # from scipy.linalg import lstsq
10
+ # from scipy.ndimage import geometric_transform # , map_coordinates
11
+
12
+ from models.mtcnn.mtcnn_pytorch.src.matlab_cp2tform import get_similarity_transform_for_cv2
13
+
14
+ # reference facial points, a list of coordinates (x,y)
15
+ REFERENCE_FACIAL_POINTS = [
16
+ [30.29459953, 51.69630051],
17
+ [65.53179932, 51.50139999],
18
+ [48.02519989, 71.73660278],
19
+ [33.54930115, 92.3655014],
20
+ [62.72990036, 92.20410156]
21
+ ]
22
+
23
+ DEFAULT_CROP_SIZE = (96, 112)
24
+
25
+
26
+ class FaceWarpException(Exception):
27
+ def __str__(self):
28
+ return 'In File {}:{}'.format(
29
+ __file__, super().__str__())
30
+
31
+
32
+ def get_reference_facial_points(output_size=None,
33
+ inner_padding_factor=0.0,
34
+ outer_padding=(0, 0),
35
+ default_square=False):
36
+ """
37
+ Function:
38
+ ----------
39
+ get reference 5 key points according to crop settings:
40
+ 0. Set default crop_size:
41
+ if default_square:
42
+ crop_size = (112, 112)
43
+ else:
44
+ crop_size = (96, 112)
45
+ 1. Pad the crop_size by inner_padding_factor in each side;
46
+ 2. Resize crop_size into (output_size - outer_padding*2),
47
+ pad into output_size with outer_padding;
48
+ 3. Output reference_5point;
49
+ Parameters:
50
+ ----------
51
+ @output_size: (w, h) or None
52
+ size of aligned face image
53
+ @inner_padding_factor: (w_factor, h_factor)
54
+ padding factor for inner (w, h)
55
+ @outer_padding: (w_pad, h_pad)
56
+ each row is a pair of coordinates (x, y)
57
+ @default_square: True or False
58
+ if True:
59
+ default crop_size = (112, 112)
60
+ else:
61
+ default crop_size = (96, 112);
62
+ !!! make sure, if output_size is not None:
63
+ (output_size - outer_padding)
64
+ = some_scale * (default crop_size * (1.0 + inner_padding_factor))
65
+ Returns:
66
+ ----------
67
+ @reference_5point: 5x2 np.array
68
+ each row is a pair of transformed coordinates (x, y)
69
+ """
70
+ # print('\n===> get_reference_facial_points():')
71
+
72
+ # print('---> Params:')
73
+ # print(' output_size: ', output_size)
74
+ # print(' inner_padding_factor: ', inner_padding_factor)
75
+ # print(' outer_padding:', outer_padding)
76
+ # print(' default_square: ', default_square)
77
+
78
+ tmp_5pts = np.array(REFERENCE_FACIAL_POINTS)
79
+ tmp_crop_size = np.array(DEFAULT_CROP_SIZE)
80
+
81
+ # 0) make the inner region a square
82
+ if default_square:
83
+ size_diff = max(tmp_crop_size) - tmp_crop_size
84
+ tmp_5pts += size_diff / 2
85
+ tmp_crop_size += size_diff
86
+
87
+ # print('---> default:')
88
+ # print(' crop_size = ', tmp_crop_size)
89
+ # print(' reference_5pts = ', tmp_5pts)
90
+
91
+ if (output_size and
92
+ output_size[0] == tmp_crop_size[0] and
93
+ output_size[1] == tmp_crop_size[1]):
94
+ # print('output_size == DEFAULT_CROP_SIZE {}: return default reference points'.format(tmp_crop_size))
95
+ return tmp_5pts
96
+
97
+ if (inner_padding_factor == 0 and
98
+ outer_padding == (0, 0)):
99
+ if output_size is None:
100
+ # print('No paddings to do: return default reference points')
101
+ return tmp_5pts
102
+ else:
103
+ raise FaceWarpException(
104
+ 'No paddings to do, output_size must be None or {}'.format(tmp_crop_size))
105
+
106
+ # check output size
107
+ if not (0 <= inner_padding_factor <= 1.0):
108
+ raise FaceWarpException('Not (0 <= inner_padding_factor <= 1.0)')
109
+
110
+ if ((inner_padding_factor > 0 or outer_padding[0] > 0 or outer_padding[1] > 0)
111
+ and output_size is None):
112
+ output_size = tmp_crop_size * \
113
+ (1 + inner_padding_factor * 2).astype(np.int32)
114
+ output_size += np.array(outer_padding)
115
+ # print(' deduced from paddings, output_size = ', output_size)
116
+
117
+ if not (outer_padding[0] < output_size[0]
118
+ and outer_padding[1] < output_size[1]):
119
+ raise FaceWarpException('Not (outer_padding[0] < output_size[0]'
120
+ 'and outer_padding[1] < output_size[1])')
121
+
122
+ # 1) pad the inner region according inner_padding_factor
123
+ # print('---> STEP1: pad the inner region according inner_padding_factor')
124
+ if inner_padding_factor > 0:
125
+ size_diff = tmp_crop_size * inner_padding_factor * 2
126
+ tmp_5pts += size_diff / 2
127
+ tmp_crop_size += np.round(size_diff).astype(np.int32)
128
+
129
+ # print(' crop_size = ', tmp_crop_size)
130
+ # print(' reference_5pts = ', tmp_5pts)
131
+
132
+ # 2) resize the padded inner region
133
+ # print('---> STEP2: resize the padded inner region')
134
+ size_bf_outer_pad = np.array(output_size) - np.array(outer_padding) * 2
135
+ # print(' crop_size = ', tmp_crop_size)
136
+ # print(' size_bf_outer_pad = ', size_bf_outer_pad)
137
+
138
+ if size_bf_outer_pad[0] * tmp_crop_size[1] != size_bf_outer_pad[1] * tmp_crop_size[0]:
139
+ raise FaceWarpException('Must have (output_size - outer_padding)'
140
+ '= some_scale * (crop_size * (1.0 + inner_padding_factor)')
141
+
142
+ scale_factor = size_bf_outer_pad[0].astype(np.float32) / tmp_crop_size[0]
143
+ # print(' resize scale_factor = ', scale_factor)
144
+ tmp_5pts = tmp_5pts * scale_factor
145
+ # size_diff = tmp_crop_size * (scale_factor - min(scale_factor))
146
+ # tmp_5pts = tmp_5pts + size_diff / 2
147
+ tmp_crop_size = size_bf_outer_pad
148
+ # print(' crop_size = ', tmp_crop_size)
149
+ # print(' reference_5pts = ', tmp_5pts)
150
+
151
+ # 3) add outer_padding to make output_size
152
+ reference_5point = tmp_5pts + np.array(outer_padding)
153
+ tmp_crop_size = output_size
154
+ # print('---> STEP3: add outer_padding to make output_size')
155
+ # print(' crop_size = ', tmp_crop_size)
156
+ # print(' reference_5pts = ', tmp_5pts)
157
+
158
+ # print('===> end get_reference_facial_points\n')
159
+
160
+ return reference_5point
161
+
162
+
163
+ def get_affine_transform_matrix(src_pts, dst_pts):
164
+ """
165
+ Function:
166
+ ----------
167
+ get affine transform matrix 'tfm' from src_pts to dst_pts
168
+ Parameters:
169
+ ----------
170
+ @src_pts: Kx2 np.array
171
+ source points matrix, each row is a pair of coordinates (x, y)
172
+ @dst_pts: Kx2 np.array
173
+ destination points matrix, each row is a pair of coordinates (x, y)
174
+ Returns:
175
+ ----------
176
+ @tfm: 2x3 np.array
177
+ transform matrix from src_pts to dst_pts
178
+ """
179
+
180
+ tfm = np.float32([[1, 0, 0], [0, 1, 0]])
181
+ n_pts = src_pts.shape[0]
182
+ ones = np.ones((n_pts, 1), src_pts.dtype)
183
+ src_pts_ = np.hstack([src_pts, ones])
184
+ dst_pts_ = np.hstack([dst_pts, ones])
185
+
186
+ # #print(('src_pts_:\n' + str(src_pts_))
187
+ # #print(('dst_pts_:\n' + str(dst_pts_))
188
+
189
+ A, res, rank, s = np.linalg.lstsq(src_pts_, dst_pts_)
190
+
191
+ # #print(('np.linalg.lstsq return A: \n' + str(A))
192
+ # #print(('np.linalg.lstsq return res: \n' + str(res))
193
+ # #print(('np.linalg.lstsq return rank: \n' + str(rank))
194
+ # #print(('np.linalg.lstsq return s: \n' + str(s))
195
+
196
+ if rank == 3:
197
+ tfm = np.float32([
198
+ [A[0, 0], A[1, 0], A[2, 0]],
199
+ [A[0, 1], A[1, 1], A[2, 1]]
200
+ ])
201
+ elif rank == 2:
202
+ tfm = np.float32([
203
+ [A[0, 0], A[1, 0], 0],
204
+ [A[0, 1], A[1, 1], 0]
205
+ ])
206
+
207
+ return tfm
208
+
209
+
210
+ def warp_and_crop_face(src_img,
211
+ facial_pts,
212
+ reference_pts=None,
213
+ crop_size=(96, 112),
214
+ align_type='similarity'):
215
+ """
216
+ Function:
217
+ ----------
218
+ apply affine transform 'trans' to uv
219
+ Parameters:
220
+ ----------
221
+ @src_img: 3x3 np.array
222
+ input image
223
+ @facial_pts: could be
224
+ 1)a list of K coordinates (x,y)
225
+ or
226
+ 2) Kx2 or 2xK np.array
227
+ each row or col is a pair of coordinates (x, y)
228
+ @reference_pts: could be
229
+ 1) a list of K coordinates (x,y)
230
+ or
231
+ 2) Kx2 or 2xK np.array
232
+ each row or col is a pair of coordinates (x, y)
233
+ or
234
+ 3) None
235
+ if None, use default reference facial points
236
+ @crop_size: (w, h)
237
+ output face image size
238
+ @align_type: transform type, could be one of
239
+ 1) 'similarity': use similarity transform
240
+ 2) 'cv2_affine': use the first 3 points to do affine transform,
241
+ by calling cv2.getAffineTransform()
242
+ 3) 'affine': use all points to do affine transform
243
+ Returns:
244
+ ----------
245
+ @face_img: output face image with size (w, h) = @crop_size
246
+ """
247
+
248
+ if reference_pts is None:
249
+ if crop_size[0] == 96 and crop_size[1] == 112:
250
+ reference_pts = REFERENCE_FACIAL_POINTS
251
+ else:
252
+ default_square = False
253
+ inner_padding_factor = 0
254
+ outer_padding = (0, 0)
255
+ output_size = crop_size
256
+
257
+ reference_pts = get_reference_facial_points(output_size,
258
+ inner_padding_factor,
259
+ outer_padding,
260
+ default_square)
261
+
262
+ ref_pts = np.float32(reference_pts)
263
+ ref_pts_shp = ref_pts.shape
264
+ if max(ref_pts_shp) < 3 or min(ref_pts_shp) != 2:
265
+ raise FaceWarpException(
266
+ 'reference_pts.shape must be (K,2) or (2,K) and K>2')
267
+
268
+ if ref_pts_shp[0] == 2:
269
+ ref_pts = ref_pts.T
270
+
271
+ src_pts = np.float32(facial_pts)
272
+ src_pts_shp = src_pts.shape
273
+ if max(src_pts_shp) < 3 or min(src_pts_shp) != 2:
274
+ raise FaceWarpException(
275
+ 'facial_pts.shape must be (K,2) or (2,K) and K>2')
276
+
277
+ if src_pts_shp[0] == 2:
278
+ src_pts = src_pts.T
279
+
280
+ # #print('--->src_pts:\n', src_pts
281
+ # #print('--->ref_pts\n', ref_pts
282
+
283
+ if src_pts.shape != ref_pts.shape:
284
+ raise FaceWarpException(
285
+ 'facial_pts and reference_pts must have the same shape')
286
+
287
+ if align_type == 'cv2_affine':
288
+ tfm = cv2.getAffineTransform(src_pts[0:3], ref_pts[0:3])
289
+ # #print(('cv2.getAffineTransform() returns tfm=\n' + str(tfm))
290
+ elif align_type == 'affine':
291
+ tfm = get_affine_transform_matrix(src_pts, ref_pts)
292
+ # #print(('get_affine_transform_matrix() returns tfm=\n' + str(tfm))
293
+ else:
294
+ tfm = get_similarity_transform_for_cv2(src_pts, ref_pts)
295
+ # #print(('get_similarity_transform_for_cv2() returns tfm=\n' + str(tfm))
296
+
297
+ # #print('--->Transform matrix: '
298
+ # #print(('type(tfm):' + str(type(tfm)))
299
+ # #print(('tfm.dtype:' + str(tfm.dtype))
300
+ # #print( tfm
301
+
302
+ face_img = cv2.warpAffine(src_img, tfm, (crop_size[0], crop_size[1]))
303
+
304
+ return face_img, tfm
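A sketch of warp_and_crop_face: given five landmarks (two eyes, nose tip, two mouth corners) as (x, y) pairs, it warps the face onto the square 112x112 reference layout. The image path and landmark values below are made up.

import numpy as np
from PIL import Image
from models.mtcnn.mtcnn_pytorch.src.align_trans import (
    get_reference_facial_points, warp_and_crop_face)

src = np.array(Image.open('face.jpg').convert('RGB'))    # placeholder path
five_pts = [[120, 150], [180, 148], [150, 185],
            [128, 215], [175, 213]]                       # eyes, nose, mouth corners
ref = get_reference_facial_points(default_square=True)   # square 112x112 reference points
crop, tfm = warp_and_crop_face(src, five_pts, ref, crop_size=(112, 112))
print(crop.shape)                                         # (112, 112, 3)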
models/mtcnn/mtcnn_pytorch/src/box_utils.py ADDED
@@ -0,0 +1,238 @@
1
+ import numpy as np
2
+ from PIL import Image
3
+
4
+
5
+ def nms(boxes, overlap_threshold=0.5, mode='union'):
6
+ """Non-maximum suppression.
7
+
8
+ Arguments:
9
+ boxes: a float numpy array of shape [n, 5],
10
+ where each row is (xmin, ymin, xmax, ymax, score).
11
+ overlap_threshold: a float number.
12
+ mode: 'union' or 'min'.
13
+
14
+ Returns:
15
+ list with indices of the selected boxes
16
+ """
17
+
18
+ # if there are no boxes, return the empty list
19
+ if len(boxes) == 0:
20
+ return []
21
+
22
+ # list of picked indices
23
+ pick = []
24
+
25
+ # grab the coordinates of the bounding boxes
26
+ x1, y1, x2, y2, score = [boxes[:, i] for i in range(5)]
27
+
28
+ area = (x2 - x1 + 1.0) * (y2 - y1 + 1.0)
29
+ ids = np.argsort(score) # in increasing order
30
+
31
+ while len(ids) > 0:
32
+
33
+ # grab index of the largest value
34
+ last = len(ids) - 1
35
+ i = ids[last]
36
+ pick.append(i)
37
+
38
+ # compute intersections
39
+ # of the box with the largest score
40
+ # with the rest of boxes
41
+
42
+ # left top corner of intersection boxes
43
+ ix1 = np.maximum(x1[i], x1[ids[:last]])
44
+ iy1 = np.maximum(y1[i], y1[ids[:last]])
45
+
46
+ # right bottom corner of intersection boxes
47
+ ix2 = np.minimum(x2[i], x2[ids[:last]])
48
+ iy2 = np.minimum(y2[i], y2[ids[:last]])
49
+
50
+ # width and height of intersection boxes
51
+ w = np.maximum(0.0, ix2 - ix1 + 1.0)
52
+ h = np.maximum(0.0, iy2 - iy1 + 1.0)
53
+
54
+ # intersections' areas
55
+ inter = w * h
56
+ if mode == 'min':
57
+ overlap = inter / np.minimum(area[i], area[ids[:last]])
58
+ elif mode == 'union':
59
+ # intersection over union (IoU)
60
+ overlap = inter / (area[i] + area[ids[:last]] - inter)
61
+
62
+ # delete all boxes where overlap is too big
63
+ ids = np.delete(
64
+ ids,
65
+ np.concatenate([[last], np.where(overlap > overlap_threshold)[0]])
66
+ )
67
+
68
+ return pick
69
+
70
+
71
+ def convert_to_square(bboxes):
72
+ """Convert bounding boxes to a square form.
73
+
74
+ Arguments:
75
+ bboxes: a float numpy array of shape [n, 5].
76
+
77
+ Returns:
78
+ a float numpy array of shape [n, 5],
79
+ squared bounding boxes.
80
+ """
81
+
82
+ square_bboxes = np.zeros_like(bboxes)
83
+ x1, y1, x2, y2 = [bboxes[:, i] for i in range(4)]
84
+ h = y2 - y1 + 1.0
85
+ w = x2 - x1 + 1.0
86
+ max_side = np.maximum(h, w)
87
+ square_bboxes[:, 0] = x1 + w * 0.5 - max_side * 0.5
88
+ square_bboxes[:, 1] = y1 + h * 0.5 - max_side * 0.5
89
+ square_bboxes[:, 2] = square_bboxes[:, 0] + max_side - 1.0
90
+ square_bboxes[:, 3] = square_bboxes[:, 1] + max_side - 1.0
91
+ return square_bboxes
92
+
93
+
94
+ def calibrate_box(bboxes, offsets):
95
+ """Transform bounding boxes to be more like true bounding boxes.
96
+ 'offsets' is one of the outputs of the nets.
97
+
98
+ Arguments:
99
+ bboxes: a float numpy array of shape [n, 5].
100
+ offsets: a float numpy array of shape [n, 4].
101
+
102
+ Returns:
103
+ a float numpy array of shape [n, 5].
104
+ """
105
+ x1, y1, x2, y2 = [bboxes[:, i] for i in range(4)]
106
+ w = x2 - x1 + 1.0
107
+ h = y2 - y1 + 1.0
108
+ w = np.expand_dims(w, 1)
109
+ h = np.expand_dims(h, 1)
110
+
111
+ # this is what happening here:
112
+ # tx1, ty1, tx2, ty2 = [offsets[:, i] for i in range(4)]
113
+ # x1_true = x1 + tx1*w
114
+ # y1_true = y1 + ty1*h
115
+ # x2_true = x2 + tx2*w
116
+ # y2_true = y2 + ty2*h
117
+ # below is just more compact form of this
118
+
119
+ # are offsets always such that
120
+ # x1 < x2 and y1 < y2 ?
121
+
122
+ translation = np.hstack([w, h, w, h]) * offsets
123
+ bboxes[:, 0:4] = bboxes[:, 0:4] + translation
124
+ return bboxes
125
+
126
+
127
+ def get_image_boxes(bounding_boxes, img, size=24):
128
+ """Cut out boxes from the image.
129
+
130
+ Arguments:
131
+ bounding_boxes: a float numpy array of shape [n, 5].
132
+ img: an instance of PIL.Image.
133
+ size: an integer, size of cutouts.
134
+
135
+ Returns:
136
+ a float numpy array of shape [n, 3, size, size].
137
+ """
138
+
139
+ num_boxes = len(bounding_boxes)
140
+ width, height = img.size
141
+
142
+ [dy, edy, dx, edx, y, ey, x, ex, w, h] = correct_bboxes(bounding_boxes, width, height)
143
+ img_boxes = np.zeros((num_boxes, 3, size, size), 'float32')
144
+
145
+ for i in range(num_boxes):
146
+ img_box = np.zeros((h[i], w[i], 3), 'uint8')
147
+
148
+ img_array = np.asarray(img, 'uint8')
149
+ img_box[dy[i]:(edy[i] + 1), dx[i]:(edx[i] + 1), :] = \
150
+ img_array[y[i]:(ey[i] + 1), x[i]:(ex[i] + 1), :]
151
+
152
+ # resize
153
+ img_box = Image.fromarray(img_box)
154
+ img_box = img_box.resize((size, size), Image.BILINEAR)
155
+ img_box = np.asarray(img_box, 'float32')
156
+
157
+ img_boxes[i, :, :, :] = _preprocess(img_box)
158
+
159
+ return img_boxes
160
+
161
+
162
+ def correct_bboxes(bboxes, width, height):
163
+ """Crop boxes that are too big and get coordinates
164
+ with respect to cutouts.
165
+
166
+ Arguments:
167
+ bboxes: a float numpy array of shape [n, 5],
168
+ where each row is (xmin, ymin, xmax, ymax, score).
169
+ width: a float number.
170
+ height: a float number.
171
+
172
+ Returns:
173
+ dy, dx, edy, edx: a int numpy arrays of shape [n],
174
+ coordinates of the boxes with respect to the cutouts.
175
+ y, x, ey, ex: a int numpy arrays of shape [n],
176
+ corrected ymin, xmin, ymax, xmax.
177
+ h, w: a int numpy arrays of shape [n],
178
+ just heights and widths of boxes.
179
+
180
+ in the following order:
181
+ [dy, edy, dx, edx, y, ey, x, ex, w, h].
182
+ """
183
+
184
+ x1, y1, x2, y2 = [bboxes[:, i] for i in range(4)]
185
+ w, h = x2 - x1 + 1.0, y2 - y1 + 1.0
186
+ num_boxes = bboxes.shape[0]
187
+
188
+ # 'e' stands for end
189
+ # (x, y) -> (ex, ey)
190
+ x, y, ex, ey = x1, y1, x2, y2
191
+
192
+ # we need to cut out a box from the image.
193
+ # (x, y, ex, ey) are corrected coordinates of the box
194
+ # in the image.
195
+ # (dx, dy, edx, edy) are coordinates of the box in the cutout
196
+ # from the image.
197
+ dx, dy = np.zeros((num_boxes,)), np.zeros((num_boxes,))
198
+ edx, edy = w.copy() - 1.0, h.copy() - 1.0
199
+
200
+ # if box's bottom right corner is too far right
201
+ ind = np.where(ex > width - 1.0)[0]
202
+ edx[ind] = w[ind] + width - 2.0 - ex[ind]
203
+ ex[ind] = width - 1.0
204
+
205
+ # if box's bottom right corner is too low
206
+ ind = np.where(ey > height - 1.0)[0]
207
+ edy[ind] = h[ind] + height - 2.0 - ey[ind]
208
+ ey[ind] = height - 1.0
209
+
210
+ # if box's top left corner is too far left
211
+ ind = np.where(x < 0.0)[0]
212
+ dx[ind] = 0.0 - x[ind]
213
+ x[ind] = 0.0
214
+
215
+ # if box's top left corner is too high
216
+ ind = np.where(y < 0.0)[0]
217
+ dy[ind] = 0.0 - y[ind]
218
+ y[ind] = 0.0
219
+
220
+ return_list = [dy, edy, dx, edx, y, ey, x, ex, w, h]
221
+ return_list = [i.astype('int32') for i in return_list]
222
+
223
+ return return_list
224
+
225
+
226
+ def _preprocess(img):
227
+ """Preprocessing step before feeding the network.
228
+
229
+ Arguments:
230
+ img: a float numpy array of shape [h, w, c].
231
+
232
+ Returns:
233
+ a float numpy array of shape [1, c, h, w].
234
+ """
235
+ img = img.transpose((2, 0, 1))
236
+ img = np.expand_dims(img, 0)
237
+ img = (img - 127.5) * 0.0078125
238
+ return img
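A toy run of the NMS above: two heavily overlapping boxes collapse onto the higher-scoring one, while a distant box survives.

import numpy as np
from models.mtcnn.mtcnn_pytorch.src.box_utils import nms

boxes = np.array([
    [0.0,  0.0, 10.0, 10.0, 0.9],    # highest score, kept
    [1.0,  1.0, 11.0, 11.0, 0.8],    # IoU ~0.7 with the first box, suppressed
    [50.0, 50.0, 60.0, 60.0, 0.7],   # no overlap, kept
])
print(nms(boxes, overlap_threshold=0.5))   # [0, 2]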
models/mtcnn/mtcnn_pytorch/src/detector.py ADDED
@@ -0,0 +1,126 @@
1
+ import numpy as np
2
+ import torch
3
+ from torch.autograd import Variable
4
+ from .get_nets import PNet, RNet, ONet
5
+ from .box_utils import nms, calibrate_box, get_image_boxes, convert_to_square
6
+ from .first_stage import run_first_stage
7
+
8
+
9
+ def detect_faces(image, min_face_size=20.0,
10
+ thresholds=[0.6, 0.7, 0.8],
11
+ nms_thresholds=[0.7, 0.7, 0.7]):
12
+ """
13
+ Arguments:
14
+ image: an instance of PIL.Image.
15
+ min_face_size: a float number.
16
+ thresholds: a list of length 3.
17
+ nms_thresholds: a list of length 3.
18
+
19
+ Returns:
20
+ two float numpy arrays of shapes [n_boxes, 4] and [n_boxes, 10],
21
+ bounding boxes and facial landmarks.
22
+ """
23
+
24
+ # LOAD MODELS
25
+ pnet = PNet()
26
+ rnet = RNet()
27
+ onet = ONet()
28
+ onet.eval()
29
+
30
+ # BUILD AN IMAGE PYRAMID
31
+ width, height = image.size
32
+ min_length = min(height, width)
33
+
34
+ min_detection_size = 12
35
+ factor = 0.707 # sqrt(0.5)
36
+
37
+ # scales for scaling the image
38
+ scales = []
39
+
40
+ # scales the image so that
41
+ # minimum size that we can detect equals to
42
+ # minimum face size that we want to detect
43
+ m = min_detection_size / min_face_size
44
+ min_length *= m
45
+
46
+ factor_count = 0
47
+ while min_length > min_detection_size:
48
+ scales.append(m * factor ** factor_count)
49
+ min_length *= factor
50
+ factor_count += 1
51
+
52
+ # STAGE 1
53
+
54
+ # it will be returned
55
+ bounding_boxes = []
56
+
57
+ with torch.no_grad():
58
+ # run P-Net on different scales
59
+ for s in scales:
60
+ boxes = run_first_stage(image, pnet, scale=s, threshold=thresholds[0])
61
+ bounding_boxes.append(boxes)
62
+
63
+ # collect boxes (and offsets, and scores) from different scales
64
+ bounding_boxes = [i for i in bounding_boxes if i is not None]
65
+ bounding_boxes = np.vstack(bounding_boxes)
66
+
67
+ keep = nms(bounding_boxes[:, 0:5], nms_thresholds[0])
68
+ bounding_boxes = bounding_boxes[keep]
69
+
70
+ # use offsets predicted by pnet to transform bounding boxes
71
+ bounding_boxes = calibrate_box(bounding_boxes[:, 0:5], bounding_boxes[:, 5:])
72
+ # shape [n_boxes, 5]
73
+
74
+ bounding_boxes = convert_to_square(bounding_boxes)
75
+ bounding_boxes[:, 0:4] = np.round(bounding_boxes[:, 0:4])
76
+
77
+ # STAGE 2
78
+
79
+ img_boxes = get_image_boxes(bounding_boxes, image, size=24)
80
+ img_boxes = torch.FloatTensor(img_boxes)
81
+
82
+ output = rnet(img_boxes)
83
+ offsets = output[0].data.numpy() # shape [n_boxes, 4]
84
+ probs = output[1].data.numpy() # shape [n_boxes, 2]
85
+
86
+ keep = np.where(probs[:, 1] > thresholds[1])[0]
87
+ bounding_boxes = bounding_boxes[keep]
88
+ bounding_boxes[:, 4] = probs[keep, 1].reshape((-1,))
89
+ offsets = offsets[keep]
90
+
91
+ keep = nms(bounding_boxes, nms_thresholds[1])
92
+ bounding_boxes = bounding_boxes[keep]
93
+ bounding_boxes = calibrate_box(bounding_boxes, offsets[keep])
94
+ bounding_boxes = convert_to_square(bounding_boxes)
95
+ bounding_boxes[:, 0:4] = np.round(bounding_boxes[:, 0:4])
96
+
97
+ # STAGE 3
98
+
99
+ img_boxes = get_image_boxes(bounding_boxes, image, size=48)
100
+ if len(img_boxes) == 0:
101
+ return [], []
102
+ img_boxes = torch.FloatTensor(img_boxes)
103
+ output = onet(img_boxes)
104
+ landmarks = output[0].data.numpy() # shape [n_boxes, 10]
105
+ offsets = output[1].data.numpy() # shape [n_boxes, 4]
106
+ probs = output[2].data.numpy() # shape [n_boxes, 2]
107
+
108
+ keep = np.where(probs[:, 1] > thresholds[2])[0]
109
+ bounding_boxes = bounding_boxes[keep]
110
+ bounding_boxes[:, 4] = probs[keep, 1].reshape((-1,))
111
+ offsets = offsets[keep]
112
+ landmarks = landmarks[keep]
113
+
114
+ # compute landmark points
115
+ width = bounding_boxes[:, 2] - bounding_boxes[:, 0] + 1.0
116
+ height = bounding_boxes[:, 3] - bounding_boxes[:, 1] + 1.0
117
+ xmin, ymin = bounding_boxes[:, 0], bounding_boxes[:, 1]
118
+ landmarks[:, 0:5] = np.expand_dims(xmin, 1) + np.expand_dims(width, 1) * landmarks[:, 0:5]
119
+ landmarks[:, 5:10] = np.expand_dims(ymin, 1) + np.expand_dims(height, 1) * landmarks[:, 5:10]
120
+
121
+ bounding_boxes = calibrate_box(bounding_boxes, offsets)
122
+ keep = nms(bounding_boxes, nms_thresholds[2], mode='min')
123
+ bounding_boxes = bounding_boxes[keep]
124
+ landmarks = landmarks[keep]
125
+
126
+ return bounding_boxes, landmarks
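The three stages above live inside the module's face-detection entry point. A minimal usage sketch, assuming the enclosing function is `detect_faces` (as in the standard mtcnn_pytorch package) and that the CUDA device hard-coded in first_stage.py is available; the input path is hypothetical:

from PIL import Image
from models.mtcnn.mtcnn_pytorch.src.detector import detect_faces
from models.mtcnn.mtcnn_pytorch.src.visualization_utils import show_bboxes

img = Image.open('face.jpg').convert('RGB')      # hypothetical input image
bounding_boxes, landmarks = detect_faces(img)    # [n, 5] boxes with scores, [n, 10] landmarks
show_bboxes(img, bounding_boxes, landmarks).save('face_with_boxes.jpg')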
models/mtcnn/mtcnn_pytorch/src/first_stage.py ADDED
@@ -0,0 +1,101 @@
1
+ import torch
2
+ from torch.autograd import Variable
3
+ import math
4
+ from PIL import Image
5
+ import numpy as np
6
+ from .box_utils import nms, _preprocess
7
+
8
+ # device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
9
+ device = 'cuda:0'
10
+
11
+
12
+ def run_first_stage(image, net, scale, threshold):
13
+ """Run P-Net, generate bounding boxes, and do NMS.
14
+
15
+ Arguments:
16
+ image: an instance of PIL.Image.
17
+ net: an instance of pytorch's nn.Module, P-Net.
18
+ scale: a float number,
19
+ scale width and height of the image by this number.
20
+ threshold: a float number,
21
+ threshold on the probability of a face when generating
22
+ bounding boxes from predictions of the net.
23
+
24
+ Returns:
25
+ a float numpy array of shape [n_boxes, 9],
26
+ bounding boxes with scores and offsets (4 + 1 + 4).
27
+ """
28
+
29
+ # scale the image and convert it to a float array
30
+ width, height = image.size
31
+ sw, sh = math.ceil(width * scale), math.ceil(height * scale)
32
+ img = image.resize((sw, sh), Image.BILINEAR)
33
+ img = np.asarray(img, 'float32')
34
+
35
+ img = torch.FloatTensor(_preprocess(img)).to(device)
36
+ with torch.no_grad():
37
+ output = net(img)
38
+ probs = output[1].cpu().data.numpy()[0, 1, :, :]
39
+ offsets = output[0].cpu().data.numpy()
40
+ # probs: probability of a face at each sliding window
41
+ # offsets: transformations to true bounding boxes
42
+
43
+ boxes = _generate_bboxes(probs, offsets, scale, threshold)
44
+ if len(boxes) == 0:
45
+ return None
46
+
47
+ keep = nms(boxes[:, 0:5], overlap_threshold=0.5)
48
+ return boxes[keep]
49
+
50
+
51
+ def _generate_bboxes(probs, offsets, scale, threshold):
52
+ """Generate bounding boxes at places
53
+ where there is probably a face.
54
+
55
+ Arguments:
56
+ probs: a float numpy array of shape [n, m].
57
+ offsets: a float numpy array of shape [1, 4, n, m].
58
+ scale: a float number,
59
+ width and height of the image were scaled by this number.
60
+ threshold: a float number.
61
+
62
+ Returns:
63
+ a float numpy array of shape [n_boxes, 9]
64
+ """
65
+
66
+ # applying P-Net is equivalent, in some sense, to
67
+ # moving 12x12 window with stride 2
68
+ stride = 2
69
+ cell_size = 12
70
+
71
+ # indices of boxes where there is probably a face
72
+ inds = np.where(probs > threshold)
73
+
74
+ if inds[0].size == 0:
75
+ return np.array([])
76
+
77
+ # transformations of bounding boxes
78
+ tx1, ty1, tx2, ty2 = [offsets[0, i, inds[0], inds[1]] for i in range(4)]
79
+ # they are defined as:
80
+ # w = x2 - x1 + 1
81
+ # h = y2 - y1 + 1
82
+ # x1_true = x1 + tx1*w
83
+ # x2_true = x2 + tx2*w
84
+ # y1_true = y1 + ty1*h
85
+ # y2_true = y2 + ty2*h
86
+
87
+ offsets = np.array([tx1, ty1, tx2, ty2])
88
+ score = probs[inds[0], inds[1]]
89
+
90
+ # P-Net is applied to scaled images
91
+ # so we need to rescale bounding boxes back
92
+ bounding_boxes = np.vstack([
93
+ np.round((stride * inds[1] + 1.0) / scale),
94
+ np.round((stride * inds[0] + 1.0) / scale),
95
+ np.round((stride * inds[1] + 1.0 + cell_size) / scale),
96
+ np.round((stride * inds[0] + 1.0 + cell_size) / scale),
97
+ score, offsets
98
+ ])
99
+ # why is one added?
100
+
101
+ return bounding_boxes.T
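The index-to-box arithmetic in `_generate_bboxes` can be checked by hand: a detection at grid cell (row i, col j) corresponds to a 12x12 window placed with stride 2 in the scaled image, mapped back through `scale`. A small sketch with made-up indices and scale:

stride, cell_size, scale = 2, 12, 0.5
i, j = 10, 7                                          # hypothetical grid indices
x1 = round((stride * j + 1.0) / scale)                # 30
y1 = round((stride * i + 1.0) / scale)                # 42
x2 = round((stride * j + 1.0 + cell_size) / scale)    # 54
y2 = round((stride * i + 1.0 + cell_size) / scale)    # 66
print(x1, y1, x2, y2)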
models/mtcnn/mtcnn_pytorch/src/get_nets.py ADDED
@@ -0,0 +1,171 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from collections import OrderedDict
5
+ import numpy as np
6
+
7
+ from configs.paths_config import model_paths
8
+ PNET_PATH = model_paths["mtcnn_pnet"]
9
+ ONET_PATH = model_paths["mtcnn_onet"]
10
+ RNET_PATH = model_paths["mtcnn_rnet"]
11
+
12
+
13
+ class Flatten(nn.Module):
14
+
15
+ def __init__(self):
16
+ super(Flatten, self).__init__()
17
+
18
+ def forward(self, x):
19
+ """
20
+ Arguments:
21
+ x: a float tensor with shape [batch_size, c, h, w].
22
+ Returns:
23
+ a float tensor with shape [batch_size, c*h*w].
24
+ """
25
+
26
+ # without this transpose the pretrained model does not work
27
+ x = x.transpose(3, 2).contiguous()
28
+
29
+ return x.view(x.size(0), -1)
30
+
31
+
32
+ class PNet(nn.Module):
33
+
34
+ def __init__(self):
35
+ super().__init__()
36
+
37
+ # suppose we have input with size HxW, then
38
+ # after first layer: H - 2,
39
+ # after pool: ceil((H - 2)/2),
40
+ # after second conv: ceil((H - 2)/2) - 2,
41
+ # after last conv: ceil((H - 2)/2) - 4,
42
+ # and the same for W
43
+
44
+ self.features = nn.Sequential(OrderedDict([
45
+ ('conv1', nn.Conv2d(3, 10, 3, 1)),
46
+ ('prelu1', nn.PReLU(10)),
47
+ ('pool1', nn.MaxPool2d(2, 2, ceil_mode=True)),
48
+
49
+ ('conv2', nn.Conv2d(10, 16, 3, 1)),
50
+ ('prelu2', nn.PReLU(16)),
51
+
52
+ ('conv3', nn.Conv2d(16, 32, 3, 1)),
53
+ ('prelu3', nn.PReLU(32))
54
+ ]))
55
+
56
+ self.conv4_1 = nn.Conv2d(32, 2, 1, 1)
57
+ self.conv4_2 = nn.Conv2d(32, 4, 1, 1)
58
+
59
+ weights = np.load(PNET_PATH, allow_pickle=True)[()]
60
+ for n, p in self.named_parameters():
61
+ p.data = torch.FloatTensor(weights[n])
62
+
63
+ def forward(self, x):
64
+ """
65
+ Arguments:
66
+ x: a float tensor with shape [batch_size, 3, h, w].
67
+ Returns:
68
+ b: a float tensor with shape [batch_size, 4, h', w'].
69
+ a: a float tensor with shape [batch_size, 2, h', w'].
70
+ """
71
+ x = self.features(x)
72
+ a = self.conv4_1(x)
73
+ b = self.conv4_2(x)
74
+ a = F.softmax(a, dim=-1)
75
+ return b, a
76
+
77
+
78
+ class RNet(nn.Module):
79
+
80
+ def __init__(self):
81
+ super().__init__()
82
+
83
+ self.features = nn.Sequential(OrderedDict([
84
+ ('conv1', nn.Conv2d(3, 28, 3, 1)),
85
+ ('prelu1', nn.PReLU(28)),
86
+ ('pool1', nn.MaxPool2d(3, 2, ceil_mode=True)),
87
+
88
+ ('conv2', nn.Conv2d(28, 48, 3, 1)),
89
+ ('prelu2', nn.PReLU(48)),
90
+ ('pool2', nn.MaxPool2d(3, 2, ceil_mode=True)),
91
+
92
+ ('conv3', nn.Conv2d(48, 64, 2, 1)),
93
+ ('prelu3', nn.PReLU(64)),
94
+
95
+ ('flatten', Flatten()),
96
+ ('conv4', nn.Linear(576, 128)),
97
+ ('prelu4', nn.PReLU(128))
98
+ ]))
99
+
100
+ self.conv5_1 = nn.Linear(128, 2)
101
+ self.conv5_2 = nn.Linear(128, 4)
102
+
103
+ weights = np.load(RNET_PATH, allow_pickle=True)[()]
104
+ for n, p in self.named_parameters():
105
+ p.data = torch.FloatTensor(weights[n])
106
+
107
+ def forward(self, x):
108
+ """
109
+ Arguments:
110
+ x: a float tensor with shape [batch_size, 3, h, w].
111
+ Returns:
112
+ b: a float tensor with shape [batch_size, 4].
113
+ a: a float tensor with shape [batch_size, 2].
114
+ """
115
+ x = self.features(x)
116
+ a = self.conv5_1(x)
117
+ b = self.conv5_2(x)
118
+ a = F.softmax(a, dim=-1)
119
+ return b, a
120
+
121
+
122
+ class ONet(nn.Module):
123
+
124
+ def __init__(self):
125
+ super().__init__()
126
+
127
+ self.features = nn.Sequential(OrderedDict([
128
+ ('conv1', nn.Conv2d(3, 32, 3, 1)),
129
+ ('prelu1', nn.PReLU(32)),
130
+ ('pool1', nn.MaxPool2d(3, 2, ceil_mode=True)),
131
+
132
+ ('conv2', nn.Conv2d(32, 64, 3, 1)),
133
+ ('prelu2', nn.PReLU(64)),
134
+ ('pool2', nn.MaxPool2d(3, 2, ceil_mode=True)),
135
+
136
+ ('conv3', nn.Conv2d(64, 64, 3, 1)),
137
+ ('prelu3', nn.PReLU(64)),
138
+ ('pool3', nn.MaxPool2d(2, 2, ceil_mode=True)),
139
+
140
+ ('conv4', nn.Conv2d(64, 128, 2, 1)),
141
+ ('prelu4', nn.PReLU(128)),
142
+
143
+ ('flatten', Flatten()),
144
+ ('conv5', nn.Linear(1152, 256)),
145
+ ('drop5', nn.Dropout(0.25)),
146
+ ('prelu5', nn.PReLU(256)),
147
+ ]))
148
+
149
+ self.conv6_1 = nn.Linear(256, 2)
150
+ self.conv6_2 = nn.Linear(256, 4)
151
+ self.conv6_3 = nn.Linear(256, 10)
152
+
153
+ weights = np.load(ONET_PATH, allow_pickle=True)[()]
154
+ for n, p in self.named_parameters():
155
+ p.data = torch.FloatTensor(weights[n])
156
+
157
+ def forward(self, x):
158
+ """
159
+ Arguments:
160
+ x: a float tensor with shape [batch_size, 3, h, w].
161
+ Returns:
162
+ c: a float tensor with shape [batch_size, 10].
163
+ b: a float tensor with shape [batch_size, 4].
164
+ a: a float tensor with shape [batch_size, 2].
165
+ """
166
+ x = self.features(x)
167
+ a = self.conv6_1(x)
168
+ b = self.conv6_2(x)
169
+ c = self.conv6_3(x)
170
+ a = F.softmax(a, dim=-1)
171
+ return c, b, a
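A quick shape check on P-Net matches the receptive-field comment in its constructor: a 12x12 input collapses to a 1x1 output map. A minimal sketch, assuming the pretrained .npy weight files referenced via configs/paths_config.py are present on disk:

import torch
from models.mtcnn.mtcnn_pytorch.src.get_nets import PNet

net = PNet().eval()
with torch.no_grad():
    offsets, probs = net(torch.randn(1, 3, 12, 12))   # smallest detectable window
print(offsets.shape, probs.shape)   # torch.Size([1, 4, 1, 1]) torch.Size([1, 2, 1, 1])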
models/mtcnn/mtcnn_pytorch/src/matlab_cp2tform.py ADDED
@@ -0,0 +1,350 @@
1
+ # -*- coding: utf-8 -*-
2
+ """
3
+ Created on Tue Jul 11 06:54:28 2017
4
+
5
+ @author: zhaoyafei
6
+ """
7
+
8
+ import numpy as np
9
+ from numpy.linalg import inv, norm, lstsq
10
+ from numpy.linalg import matrix_rank as rank
11
+
12
+
13
+ class MatlabCp2tormException(Exception):
14
+ def __str__(self):
15
+ return 'In File {}:{}'.format(
16
+ __file__, super().__str__())
17
+
18
+
19
+ def tformfwd(trans, uv):
20
+ """
21
+ Function:
22
+ ----------
23
+ apply affine transform 'trans' to uv
24
+
25
+ Parameters:
26
+ ----------
27
+ @trans: 3x3 np.array
28
+ transform matrix
29
+ @uv: Kx2 np.array
30
+ each row is a pair of coordinates (x, y)
31
+
32
+ Returns:
33
+ ----------
34
+ @xy: Kx2 np.array
35
+ each row is a pair of transformed coordinates (x, y)
36
+ """
37
+ uv = np.hstack((
38
+ uv, np.ones((uv.shape[0], 1))
39
+ ))
40
+ xy = np.dot(uv, trans)
41
+ xy = xy[:, 0:-1]
42
+ return xy
43
+
44
+
45
+ def tforminv(trans, uv):
46
+ """
47
+ Function:
48
+ ----------
49
+ apply the inverse of affine transform 'trans' to uv
50
+
51
+ Parameters:
52
+ ----------
53
+ @trans: 3x3 np.array
54
+ transform matrix
55
+ @uv: Kx2 np.array
56
+ each row is a pair of coordinates (x, y)
57
+
58
+ Returns:
59
+ ----------
60
+ @xy: Kx2 np.array
61
+ each row is a pair of inverse-transformed coordinates (x, y)
62
+ """
63
+ Tinv = inv(trans)
64
+ xy = tformfwd(Tinv, uv)
65
+ return xy
66
+
67
+
68
+ def findNonreflectiveSimilarity(uv, xy, options=None):
69
+ options = {'K': 2}
70
+
71
+ K = options['K']
72
+ M = xy.shape[0]
73
+ x = xy[:, 0].reshape((-1, 1)) # use reshape to keep a column vector
74
+ y = xy[:, 1].reshape((-1, 1)) # use reshape to keep a column vector
75
+ # print('--->x, y:\n', x, y
76
+
77
+ tmp1 = np.hstack((x, y, np.ones((M, 1)), np.zeros((M, 1))))
78
+ tmp2 = np.hstack((y, -x, np.zeros((M, 1)), np.ones((M, 1))))
79
+ X = np.vstack((tmp1, tmp2))
80
+ # print('--->X.shape: ', X.shape
81
+ # print('X:\n', X
82
+
83
+ u = uv[:, 0].reshape((-1, 1)) # use reshape to keep a column vector
84
+ v = uv[:, 1].reshape((-1, 1)) # use reshape to keep a column vector
85
+ U = np.vstack((u, v))
86
+ # print('--->U.shape: ', U.shape
87
+ # print('U:\n', U
88
+
89
+ # We know that X * r = U
90
+ if rank(X) >= 2 * K:
91
+ r, _, _, _ = lstsq(X, U, rcond=None) # Make sure this is what I want
92
+ r = np.squeeze(r)
93
+ else:
94
+ raise Exception('cp2tform:twoUniquePointsReq')
95
+
96
+ # print('--->r:\n', r
97
+
98
+ sc = r[0]
99
+ ss = r[1]
100
+ tx = r[2]
101
+ ty = r[3]
102
+
103
+ Tinv = np.array([
104
+ [sc, -ss, 0],
105
+ [ss, sc, 0],
106
+ [tx, ty, 1]
107
+ ])
108
+
109
+ # print('--->Tinv:\n', Tinv
110
+
111
+ T = inv(Tinv)
112
+ # print('--->T:\n', T
113
+
114
+ T[:, 2] = np.array([0, 0, 1])
115
+
116
+ return T, Tinv
117
+
118
+
119
+ def findSimilarity(uv, xy, options=None):
120
+ options = {'K': 2}
121
+
122
+ # uv = np.array(uv)
123
+ # xy = np.array(xy)
124
+
125
+ # Solve for trans1
126
+ trans1, trans1_inv = findNonreflectiveSimilarity(uv, xy, options)
127
+
128
+ # Solve for trans2
129
+
130
+ # manually reflect the xy data across the Y-axis
131
+ xyR = xy.copy()  # copy so the reflection below does not mutate the original xy
132
+ xyR[:, 0] = -1 * xyR[:, 0]
133
+
134
+ trans2r, trans2r_inv = findNonreflectiveSimilarity(uv, xyR, options)
135
+
136
+ # manually reflect the tform to undo the reflection done on xyR
137
+ TreflectY = np.array([
138
+ [-1, 0, 0],
139
+ [0, 1, 0],
140
+ [0, 0, 1]
141
+ ])
142
+
143
+ trans2 = np.dot(trans2r, TreflectY)
144
+
145
+ # Figure out if trans1 or trans2 is better
146
+ xy1 = tformfwd(trans1, uv)
147
+ norm1 = norm(xy1 - xy)
148
+
149
+ xy2 = tformfwd(trans2, uv)
150
+ norm2 = norm(xy2 - xy)
151
+
152
+ if norm1 <= norm2:
153
+ return trans1, trans1_inv
154
+ else:
155
+ trans2_inv = inv(trans2)
156
+ return trans2, trans2_inv
157
+
158
+
159
+ def get_similarity_transform(src_pts, dst_pts, reflective=True):
160
+ """
161
+ Function:
162
+ ----------
163
+ Find Similarity Transform Matrix 'trans':
164
+ u = src_pts[:, 0]
165
+ v = src_pts[:, 1]
166
+ x = dst_pts[:, 0]
167
+ y = dst_pts[:, 1]
168
+ [x, y, 1] = [u, v, 1] * trans
169
+
170
+ Parameters:
171
+ ----------
172
+ @src_pts: Kx2 np.array
173
+ source points, each row is a pair of coordinates (x, y)
174
+ @dst_pts: Kx2 np.array
175
+ destination points, each row is a pair of transformed
176
+ coordinates (x, y)
177
+ @reflective: True or False
178
+ if True:
179
+ use reflective similarity transform
180
+ else:
181
+ use non-reflective similarity transform
182
+
183
+ Returns:
184
+ ----------
185
+ @trans: 3x3 np.array
186
+ transform matrix from uv to xy
187
+ trans_inv: 3x3 np.array
188
+ inverse of trans, transform matrix from xy to uv
189
+ """
190
+
191
+ if reflective:
192
+ trans, trans_inv = findSimilarity(src_pts, dst_pts)
193
+ else:
194
+ trans, trans_inv = findNonreflectiveSimilarity(src_pts, dst_pts)
195
+
196
+ return trans, trans_inv
197
+
198
+
199
+ def cvt_tform_mat_for_cv2(trans):
200
+ """
201
+ Function:
202
+ ----------
203
+ Convert Transform Matrix 'trans' into 'cv2_trans' which could be
204
+ directly used by cv2.warpAffine():
205
+ u = src_pts[:, 0]
206
+ v = src_pts[:, 1]
207
+ x = dst_pts[:, 0]
208
+ y = dst_pts[:, 1]
209
+ [x, y].T = cv_trans * [u, v, 1].T
210
+
211
+ Parameters:
212
+ ----------
213
+ @trans: 3x3 np.array
214
+ transform matrix from uv to xy
215
+
216
+ Returns:
217
+ ----------
218
+ @cv2_trans: 2x3 np.array
219
+ transform matrix from src_pts to dst_pts, could be directly used
220
+ for cv2.warpAffine()
221
+ """
222
+ cv2_trans = trans[:, 0:2].T
223
+
224
+ return cv2_trans
225
+
226
+
227
+ def get_similarity_transform_for_cv2(src_pts, dst_pts, reflective=True):
228
+ """
229
+ Function:
230
+ ----------
231
+ Find Similarity Transform Matrix 'cv2_trans' which could be
232
+ directly used by cv2.warpAffine():
233
+ u = src_pts[:, 0]
234
+ v = src_pts[:, 1]
235
+ x = dst_pts[:, 0]
236
+ y = dst_pts[:, 1]
237
+ [x, y].T = cv_trans * [u, v, 1].T
238
+
239
+ Parameters:
240
+ ----------
241
+ @src_pts: Kx2 np.array
242
+ source points, each row is a pair of coordinates (x, y)
243
+ @dst_pts: Kx2 np.array
244
+ destination points, each row is a pair of transformed
245
+ coordinates (x, y)
246
+ reflective: True or False
247
+ if True:
248
+ use reflective similarity transform
249
+ else:
250
+ use non-reflective similarity transform
251
+
252
+ Returns:
253
+ ----------
254
+ @cv2_trans: 2x3 np.array
255
+ transform matrix from src_pts to dst_pts, could be directly used
256
+ for cv2.warpAffine()
257
+ """
258
+ trans, trans_inv = get_similarity_transform(src_pts, dst_pts, reflective)
259
+ cv2_trans = cvt_tform_mat_for_cv2(trans)
260
+
261
+ return cv2_trans
262
+
263
+
264
+ if __name__ == '__main__':
265
+ """
266
+ u = [0, 6, -2]
267
+ v = [0, 3, 5]
268
+ x = [-1, 0, 4]
269
+ y = [-1, -10, 4]
270
+
271
+ # In Matlab, run:
272
+ #
273
+ # uv = [u'; v'];
274
+ # xy = [x'; y'];
275
+ # tform_sim=cp2tform(uv,xy,'similarity');
276
+ #
277
+ # trans = tform_sim.tdata.T
278
+ # ans =
279
+ # -0.0764 -1.6190 0
280
+ # 1.6190 -0.0764 0
281
+ # -3.2156 0.0290 1.0000
282
+ # trans_inv = tform_sim.tdata.Tinv
283
+ # ans =
284
+ #
285
+ # -0.0291 0.6163 0
286
+ # -0.6163 -0.0291 0
287
+ # -0.0756 1.9826 1.0000
288
+ # xy_m=tformfwd(tform_sim, u,v)
289
+ #
290
+ # xy_m =
291
+ #
292
+ # -3.2156 0.0290
293
+ # 1.1833 -9.9143
294
+ # 5.0323 2.8853
295
+ # uv_m=tforminv(tform_sim, x,y)
296
+ #
297
+ # uv_m =
298
+ #
299
+ # 0.5698 1.3953
300
+ # 6.0872 2.2733
301
+ # -2.6570 4.3314
302
+ """
303
+ u = [0, 6, -2]
304
+ v = [0, 3, 5]
305
+ x = [-1, 0, 4]
306
+ y = [-1, -10, 4]
307
+
308
+ uv = np.array((u, v)).T
309
+ xy = np.array((x, y)).T
310
+
311
+ print('\n--->uv:')
312
+ print(uv)
313
+ print('\n--->xy:')
314
+ print(xy)
315
+
316
+ trans, trans_inv = get_similarity_transform(uv, xy)
317
+
318
+ print('\n--->trans matrix:')
319
+ print(trans)
320
+
321
+ print('\n--->trans_inv matrix:')
322
+ print(trans_inv)
323
+
324
+ print('\n---> apply transform to uv')
325
+ print('\nxy_m = uv_augmented * trans')
326
+ uv_aug = np.hstack((
327
+ uv, np.ones((uv.shape[0], 1))
328
+ ))
329
+ xy_m = np.dot(uv_aug, trans)
330
+ print(xy_m)
331
+
332
+ print('\nxy_m = tformfwd(trans, uv)')
333
+ xy_m = tformfwd(trans, uv)
334
+ print(xy_m)
335
+
336
+ print('\n---> apply inverse transform to xy')
337
+ print('\nuv_m = xy_augmented * trans_inv')
338
+ xy_aug = np.hstack((
339
+ xy, np.ones((xy.shape[0], 1))
340
+ ))
341
+ uv_m = np.dot(xy_aug, trans_inv)
342
+ print(uv_m)
343
+
344
+ print('\nuv_m = tformfwd(trans_inv, xy)')
345
+ uv_m = tformfwd(trans_inv, xy)
346
+ print(uv_m)
347
+
348
+ uv_m = tforminv(trans, xy)
349
+ print('\nuv_m = tforminv(trans, xy)')
350
+ print(uv_m)
models/mtcnn/mtcnn_pytorch/src/visualization_utils.py ADDED
@@ -0,0 +1,31 @@
1
+ from PIL import ImageDraw
2
+
3
+
4
+ def show_bboxes(img, bounding_boxes, facial_landmarks=[]):
5
+ """Draw bounding boxes and facial landmarks.
6
+
7
+ Arguments:
8
+ img: an instance of PIL.Image.
9
+ bounding_boxes: a float numpy array of shape [n, 5].
10
+ facial_landmarks: a float numpy array of shape [n, 10].
11
+
12
+ Returns:
13
+ an instance of PIL.Image.
14
+ """
15
+
16
+ img_copy = img.copy()
17
+ draw = ImageDraw.Draw(img_copy)
18
+
19
+ for b in bounding_boxes:
20
+ draw.rectangle([
21
+ (b[0], b[1]), (b[2], b[3])
22
+ ], outline='white')
23
+
24
+ for p in facial_landmarks:
25
+ for i in range(5):
26
+ draw.ellipse([
27
+ (p[i] - 1.0, p[i + 5] - 1.0),
28
+ (p[i] + 1.0, p[i + 5] + 1.0)
29
+ ], outline='blue')
30
+
31
+ return img_copy
models/psp.py ADDED
@@ -0,0 +1,118 @@
1
+ """
2
+ This file defines the core research contribution
3
+ """
4
+ import matplotlib
5
+ matplotlib.use('Agg')
6
+ import math
7
+
8
+ import torch
9
+ from torch import nn
10
+ from models.encoders import psp_encoders
11
+ from models.stylegan2.model import Generator
12
+ from configs.paths_config import model_paths
13
+
14
+
15
+ def get_keys(d, name):
16
+ if 'state_dict' in d:
17
+ d = d['state_dict']
18
+ d_filt = {k[len(name) + 1:]: v for k, v in d.items() if k[:len(name)] == name}
19
+ return d_filt
20
+
21
+
22
+ class pSp(nn.Module):
23
+
24
+ def __init__(self, opts):
25
+ super(pSp, self).__init__()
26
+ self.set_opts(opts)
27
+ # compute number of style inputs based on the output resolution
28
+ self.opts.n_styles = int(math.log(self.opts.output_size, 2)) * 2 - 2
29
+ # Define architecture
30
+ self.encoder = self.set_encoder()
31
+ self.decoder = Generator(self.opts.output_size, 512, 8)
32
+ self.face_pool = torch.nn.AdaptiveAvgPool2d((256, 256))
33
+ # Load weights if needed
34
+ self.load_weights()
35
+
36
+ def set_encoder(self):
37
+ if self.opts.encoder_type == 'GradualStyleEncoder':
38
+ encoder = psp_encoders.GradualStyleEncoder(50, 'ir_se', self.opts)
39
+ elif self.opts.encoder_type == 'BackboneEncoderUsingLastLayerIntoW':
40
+ encoder = psp_encoders.BackboneEncoderUsingLastLayerIntoW(50, 'ir_se', self.opts)
41
+ elif self.opts.encoder_type == 'BackboneEncoderUsingLastLayerIntoWPlus':
42
+ encoder = psp_encoders.BackboneEncoderUsingLastLayerIntoWPlus(50, 'ir_se', self.opts)
43
+ else:
44
+ raise Exception('{} is not a valid encoder type'.format(self.opts.encoder_type))
45
+ return encoder
46
+
47
+ def load_weights(self):
48
+ if self.opts.checkpoint_path is not None:
49
+ print('Loading pSp from checkpoint: {}'.format(self.opts.checkpoint_path))
50
+ ckpt = torch.load(self.opts.checkpoint_path, map_location='cpu')
51
+ self.encoder.load_state_dict(get_keys(ckpt, 'encoder'), strict=True)
52
+ self.decoder.load_state_dict(get_keys(ckpt, 'decoder'), strict=True)
53
+ self.__load_latent_avg(ckpt)
54
+ else:
55
+ print('Loading encoders weights from irse50!')
56
+ encoder_ckpt = torch.load(model_paths['ir_se50'])
57
+ # if input to encoder is not an RGB image, do not load the input layer weights
58
+ if self.opts.label_nc != 0:
59
+ encoder_ckpt = {k: v for k, v in encoder_ckpt.items() if "input_layer" not in k}
60
+ self.encoder.load_state_dict(encoder_ckpt, strict=False)
61
+ print('Loading decoder weights from pretrained!')
62
+ ckpt = torch.load(self.opts.stylegan_weights)
63
+ self.decoder.load_state_dict(ckpt['g_ema'], strict=False)
64
+ if self.opts.learn_in_w:
65
+ self.__load_latent_avg(ckpt, repeat=1)
66
+ else:
67
+ self.__load_latent_avg(ckpt, repeat=self.opts.n_styles)
68
+
69
+ def forward(self, x, resize=True, latent_mask=None, input_code=False, randomize_noise=True,
70
+ inject_latent=None, return_latents=False, alpha=None):
71
+ if input_code:
72
+ codes = x
73
+ else:
74
+ codes = self.encoder(x)
75
+ # normalize with respect to the center of an average face
76
+ if self.opts.start_from_latent_avg:
77
+ if self.opts.learn_in_w:
78
+ codes = codes + self.latent_avg.repeat(codes.shape[0], 1)
79
+ else:
80
+ codes = codes + self.latent_avg.repeat(codes.shape[0], 1, 1)
81
+
82
+
83
+ if latent_mask is not None:
84
+ for i in latent_mask:
85
+ if inject_latent is not None:
86
+ if alpha is not None:
87
+ codes[:, i] = alpha * inject_latent[:, i] + (1 - alpha) * codes[:, i]
88
+ else:
89
+ codes[:, i] = inject_latent[:, i]
90
+ else:
91
+ codes[:, i] = 0
92
+
93
+ input_is_latent = not input_code
94
+
95
+ if return_latents:
96
+ result_latent = self.decoder([codes], input_is_latent=input_is_latent, randomize_noise=randomize_noise, return_latents=return_latents)
97
+ return result_latent
98
+ else:
99
+ images, result_latent = self.decoder([codes],
100
+ input_is_latent=input_is_latent,
101
+ randomize_noise=randomize_noise,
102
+ return_latents=return_latents)
103
+
104
+ if resize:
105
+ images = self.face_pool(images)
106
+
107
+ return images
108
+
109
+ def set_opts(self, opts):
110
+ self.opts = opts
111
+
112
+ def __load_latent_avg(self, ckpt, repeat=None):
113
+ if 'latent_avg' in ckpt:
114
+ self.latent_avg = ckpt['latent_avg'].to(self.opts.device)
115
+ if repeat is not None:
116
+ self.latent_avg = self.latent_avg.repeat(repeat, 1)
117
+ else:
118
+ self.latent_avg = None
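For reference, the n_styles value computed in __init__ is the number of W+ style vectors the encoder must produce, one per generator layer: n_styles = 2 * log2(output_size) - 2. A quick check:

import math
for output_size in (256, 512, 1024):
    print(output_size, int(math.log(output_size, 2)) * 2 - 2)   # 14, 16, 18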
models/stylegan2/__init__.py ADDED
File without changes
models/stylegan2/model.py ADDED
@@ -0,0 +1,674 @@
1
+ import math
2
+ import random
3
+ import torch
4
+ from torch import nn
5
+ from torch.nn import functional as F
6
+
7
+ from models.stylegan2.op import FusedLeakyReLU, fused_leaky_relu, upfirdn2d
8
+
9
+
10
+ class PixelNorm(nn.Module):
11
+ def __init__(self):
12
+ super().__init__()
13
+
14
+ def forward(self, input):
15
+ return input * torch.rsqrt(torch.mean(input ** 2, dim=1, keepdim=True) + 1e-8)
16
+
17
+
18
+ def make_kernel(k):
19
+ k = torch.tensor(k, dtype=torch.float32)
20
+
21
+ if k.ndim == 1:
22
+ k = k[None, :] * k[:, None]
23
+
24
+ k /= k.sum()
25
+
26
+ return k
27
+
28
+
29
+ class Upsample(nn.Module):
30
+ def __init__(self, kernel, factor=2):
31
+ super().__init__()
32
+
33
+ self.factor = factor
34
+ kernel = make_kernel(kernel) * (factor ** 2)
35
+ self.register_buffer('kernel', kernel)
36
+
37
+ p = kernel.shape[0] - factor
38
+
39
+ pad0 = (p + 1) // 2 + factor - 1
40
+ pad1 = p // 2
41
+
42
+ self.pad = (pad0, pad1)
43
+
44
+ def forward(self, input):
45
+ out = upfirdn2d(input, self.kernel, up=self.factor, down=1, pad=self.pad)
46
+
47
+ return out
48
+
49
+
50
+ class Downsample(nn.Module):
51
+ def __init__(self, kernel, factor=2):
52
+ super().__init__()
53
+
54
+ self.factor = factor
55
+ kernel = make_kernel(kernel)
56
+ self.register_buffer('kernel', kernel)
57
+
58
+ p = kernel.shape[0] - factor
59
+
60
+ pad0 = (p + 1) // 2
61
+ pad1 = p // 2
62
+
63
+ self.pad = (pad0, pad1)
64
+
65
+ def forward(self, input):
66
+ out = upfirdn2d(input, self.kernel, up=1, down=self.factor, pad=self.pad)
67
+
68
+ return out
69
+
70
+
71
+ class Blur(nn.Module):
72
+ def __init__(self, kernel, pad, upsample_factor=1):
73
+ super().__init__()
74
+
75
+ kernel = make_kernel(kernel)
76
+
77
+ if upsample_factor > 1:
78
+ kernel = kernel * (upsample_factor ** 2)
79
+
80
+ self.register_buffer('kernel', kernel)
81
+
82
+ self.pad = pad
83
+
84
+ def forward(self, input):
85
+ out = upfirdn2d(input, self.kernel, pad=self.pad)
86
+
87
+ return out
88
+
89
+
90
+ class EqualConv2d(nn.Module):
91
+ def __init__(
92
+ self, in_channel, out_channel, kernel_size, stride=1, padding=0, bias=True
93
+ ):
94
+ super().__init__()
95
+
96
+ self.weight = nn.Parameter(
97
+ torch.randn(out_channel, in_channel, kernel_size, kernel_size)
98
+ )
99
+ self.scale = 1 / math.sqrt(in_channel * kernel_size ** 2)
100
+
101
+ self.stride = stride
102
+ self.padding = padding
103
+
104
+ if bias:
105
+ self.bias = nn.Parameter(torch.zeros(out_channel))
106
+
107
+ else:
108
+ self.bias = None
109
+
110
+ def forward(self, input):
111
+ out = F.conv2d(
112
+ input,
113
+ self.weight * self.scale,
114
+ bias=self.bias,
115
+ stride=self.stride,
116
+ padding=self.padding,
117
+ )
118
+
119
+ return out
120
+
121
+ def __repr__(self):
122
+ return (
123
+ f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]},'
124
+ f' {self.weight.shape[2]}, stride={self.stride}, padding={self.padding})'
125
+ )
126
+
127
+
128
+ class EqualLinear(nn.Module):
129
+ def __init__(
130
+ self, in_dim, out_dim, bias=True, bias_init=0, lr_mul=1, activation=None
131
+ ):
132
+ super().__init__()
133
+
134
+ self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul))
135
+
136
+ if bias:
137
+ self.bias = nn.Parameter(torch.zeros(out_dim).fill_(bias_init))
138
+
139
+ else:
140
+ self.bias = None
141
+
142
+ self.activation = activation
143
+
144
+ self.scale = (1 / math.sqrt(in_dim)) * lr_mul
145
+ self.lr_mul = lr_mul
146
+
147
+ def forward(self, input):
148
+ if self.activation:
149
+ out = F.linear(input, self.weight * self.scale)
150
+ out = fused_leaky_relu(out, self.bias * self.lr_mul)
151
+
152
+ else:
153
+ out = F.linear(
154
+ input, self.weight * self.scale, bias=self.bias * self.lr_mul
155
+ )
156
+
157
+ return out
158
+
159
+ def __repr__(self):
160
+ return (
161
+ f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]})'
162
+ )
163
+
164
+
165
+ class ScaledLeakyReLU(nn.Module):
166
+ def __init__(self, negative_slope=0.2):
167
+ super().__init__()
168
+
169
+ self.negative_slope = negative_slope
170
+
171
+ def forward(self, input):
172
+ out = F.leaky_relu(input, negative_slope=self.negative_slope)
173
+
174
+ return out * math.sqrt(2)
175
+
176
+
177
+ class ModulatedConv2d(nn.Module):
178
+ def __init__(
179
+ self,
180
+ in_channel,
181
+ out_channel,
182
+ kernel_size,
183
+ style_dim,
184
+ demodulate=True,
185
+ upsample=False,
186
+ downsample=False,
187
+ blur_kernel=[1, 3, 3, 1],
188
+ ):
189
+ super().__init__()
190
+
191
+ self.eps = 1e-8
192
+ self.kernel_size = kernel_size
193
+ self.in_channel = in_channel
194
+ self.out_channel = out_channel
195
+ self.upsample = upsample
196
+ self.downsample = downsample
197
+
198
+ if upsample:
199
+ factor = 2
200
+ p = (len(blur_kernel) - factor) - (kernel_size - 1)
201
+ pad0 = (p + 1) // 2 + factor - 1
202
+ pad1 = p // 2 + 1
203
+
204
+ self.blur = Blur(blur_kernel, pad=(pad0, pad1), upsample_factor=factor)
205
+
206
+ if downsample:
207
+ factor = 2
208
+ p = (len(blur_kernel) - factor) + (kernel_size - 1)
209
+ pad0 = (p + 1) // 2
210
+ pad1 = p // 2
211
+
212
+ self.blur = Blur(blur_kernel, pad=(pad0, pad1))
213
+
214
+ fan_in = in_channel * kernel_size ** 2
215
+ self.scale = 1 / math.sqrt(fan_in)
216
+ self.padding = kernel_size // 2
217
+
218
+ self.weight = nn.Parameter(
219
+ torch.randn(1, out_channel, in_channel, kernel_size, kernel_size)
220
+ )
221
+
222
+ self.modulation = EqualLinear(style_dim, in_channel, bias_init=1)
223
+
224
+ self.demodulate = demodulate
225
+
226
+ def __repr__(self):
227
+ return (
228
+ f'{self.__class__.__name__}({self.in_channel}, {self.out_channel}, {self.kernel_size}, '
229
+ f'upsample={self.upsample}, downsample={self.downsample})'
230
+ )
231
+
232
+ def forward(self, input, style):
233
+ batch, in_channel, height, width = input.shape
234
+
235
+ style = self.modulation(style).view(batch, 1, in_channel, 1, 1)
236
+ weight = self.scale * self.weight * style
237
+
238
+ if self.demodulate:
239
+ demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-8)
240
+ weight = weight * demod.view(batch, self.out_channel, 1, 1, 1)
241
+
242
+ weight = weight.view(
243
+ batch * self.out_channel, in_channel, self.kernel_size, self.kernel_size
244
+ )
245
+
246
+ if self.upsample:
247
+ input = input.view(1, batch * in_channel, height, width)
248
+ weight = weight.view(
249
+ batch, self.out_channel, in_channel, self.kernel_size, self.kernel_size
250
+ )
251
+ weight = weight.transpose(1, 2).reshape(
252
+ batch * in_channel, self.out_channel, self.kernel_size, self.kernel_size
253
+ )
254
+ out = F.conv_transpose2d(input, weight, padding=0, stride=2, groups=batch)
255
+ _, _, height, width = out.shape
256
+ out = out.view(batch, self.out_channel, height, width)
257
+ out = self.blur(out)
258
+
259
+ elif self.downsample:
260
+ input = self.blur(input)
261
+ _, _, height, width = input.shape
262
+ input = input.view(1, batch * in_channel, height, width)
263
+ out = F.conv2d(input, weight, padding=0, stride=2, groups=batch)
264
+ _, _, height, width = out.shape
265
+ out = out.view(batch, self.out_channel, height, width)
266
+
267
+ else:
268
+ input = input.view(1, batch * in_channel, height, width)
269
+ out = F.conv2d(input, weight, padding=self.padding, groups=batch)
270
+ _, _, height, width = out.shape
271
+ out = out.view(batch, self.out_channel, height, width)
272
+
273
+ return out
274
+
275
+
276
+ class NoiseInjection(nn.Module):
277
+ def __init__(self):
278
+ super().__init__()
279
+
280
+ self.weight = nn.Parameter(torch.zeros(1))
281
+
282
+ def forward(self, image, noise=None):
283
+ if noise is None:
284
+ batch, _, height, width = image.shape
285
+ noise = image.new_empty(batch, 1, height, width).normal_()
286
+
287
+ return image + self.weight * noise
288
+
289
+
290
+ class ConstantInput(nn.Module):
291
+ def __init__(self, channel, size=4):
292
+ super().__init__()
293
+
294
+ self.input = nn.Parameter(torch.randn(1, channel, size, size))
295
+
296
+ def forward(self, input):
297
+ batch = input.shape[0]
298
+ out = self.input.repeat(batch, 1, 1, 1)
299
+
300
+ return out
301
+
302
+
303
+ class StyledConv(nn.Module):
304
+ def __init__(
305
+ self,
306
+ in_channel,
307
+ out_channel,
308
+ kernel_size,
309
+ style_dim,
310
+ upsample=False,
311
+ blur_kernel=[1, 3, 3, 1],
312
+ demodulate=True,
313
+ ):
314
+ super().__init__()
315
+
316
+ self.conv = ModulatedConv2d(
317
+ in_channel,
318
+ out_channel,
319
+ kernel_size,
320
+ style_dim,
321
+ upsample=upsample,
322
+ blur_kernel=blur_kernel,
323
+ demodulate=demodulate,
324
+ )
325
+
326
+ self.noise = NoiseInjection()
327
+ # self.bias = nn.Parameter(torch.zeros(1, out_channel, 1, 1))
328
+ # self.activate = ScaledLeakyReLU(0.2)
329
+ self.activate = FusedLeakyReLU(out_channel)
330
+
331
+ def forward(self, input, style, noise=None):
332
+ out = self.conv(input, style)
333
+ out = self.noise(out, noise=noise)
334
+ # out = out + self.bias
335
+ out = self.activate(out)
336
+
337
+ return out
338
+
339
+
340
+ class ToRGB(nn.Module):
341
+ def __init__(self, in_channel, style_dim, upsample=True, blur_kernel=[1, 3, 3, 1]):
342
+ super().__init__()
343
+
344
+ if upsample:
345
+ self.upsample = Upsample(blur_kernel)
346
+
347
+ self.conv = ModulatedConv2d(in_channel, 3, 1, style_dim, demodulate=False)
348
+ self.bias = nn.Parameter(torch.zeros(1, 3, 1, 1))
349
+
350
+ def forward(self, input, style, skip=None):
351
+ out = self.conv(input, style)
352
+ out = out + self.bias
353
+
354
+ if skip is not None:
355
+ skip = self.upsample(skip)
356
+
357
+ out = out + skip
358
+
359
+ return out
360
+
361
+
362
+ class Generator(nn.Module):
363
+ def __init__(
364
+ self,
365
+ size,
366
+ style_dim,
367
+ n_mlp,
368
+ channel_multiplier=2,
369
+ blur_kernel=[1, 3, 3, 1],
370
+ lr_mlp=0.01,
371
+ ):
372
+ super().__init__()
373
+
374
+ self.size = size
375
+
376
+ self.style_dim = style_dim
377
+
378
+ layers = [PixelNorm()]
379
+
380
+ for i in range(n_mlp):
381
+ layers.append(
382
+ EqualLinear(
383
+ style_dim, style_dim, lr_mul=lr_mlp, activation='fused_lrelu'
384
+ )
385
+ )
386
+
387
+ self.style = nn.Sequential(*layers)
388
+
389
+ self.channels = {
390
+ 4: 512,
391
+ 8: 512,
392
+ 16: 512,
393
+ 32: 512,
394
+ 64: 256 * channel_multiplier,
395
+ 128: 128 * channel_multiplier,
396
+ 256: 64 * channel_multiplier,
397
+ 512: 32 * channel_multiplier,
398
+ 1024: 16 * channel_multiplier,
399
+ }
400
+
401
+ self.input = ConstantInput(self.channels[4])
402
+ self.conv1 = StyledConv(
403
+ self.channels[4], self.channels[4], 3, style_dim, blur_kernel=blur_kernel
404
+ )
405
+ self.to_rgb1 = ToRGB(self.channels[4], style_dim, upsample=False)
406
+
407
+ self.log_size = int(math.log(size, 2))
408
+ self.num_layers = (self.log_size - 2) * 2 + 1
409
+
410
+ self.convs = nn.ModuleList()
411
+ self.upsamples = nn.ModuleList()
412
+ self.to_rgbs = nn.ModuleList()
413
+ self.noises = nn.Module()
414
+
415
+ in_channel = self.channels[4]
416
+
417
+ for layer_idx in range(self.num_layers):
418
+ res = (layer_idx + 5) // 2
419
+ shape = [1, 1, 2 ** res, 2 ** res]
420
+ self.noises.register_buffer(f'noise_{layer_idx}', torch.randn(*shape))
421
+
422
+ for i in range(3, self.log_size + 1):
423
+ out_channel = self.channels[2 ** i]
424
+
425
+ self.convs.append(
426
+ StyledConv(
427
+ in_channel,
428
+ out_channel,
429
+ 3,
430
+ style_dim,
431
+ upsample=True,
432
+ blur_kernel=blur_kernel,
433
+ )
434
+ )
435
+
436
+ self.convs.append(
437
+ StyledConv(
438
+ out_channel, out_channel, 3, style_dim, blur_kernel=blur_kernel
439
+ )
440
+ )
441
+
442
+ self.to_rgbs.append(ToRGB(out_channel, style_dim))
443
+
444
+ in_channel = out_channel
445
+
446
+ self.n_latent = self.log_size * 2 - 2
447
+
448
+ def make_noise(self):
449
+ device = self.input.input.device
450
+
451
+ noises = [torch.randn(1, 1, 2 ** 2, 2 ** 2, device=device)]
452
+
453
+ for i in range(3, self.log_size + 1):
454
+ for _ in range(2):
455
+ noises.append(torch.randn(1, 1, 2 ** i, 2 ** i, device=device))
456
+
457
+ return noises
458
+
459
+ def mean_latent(self, n_latent):
460
+ latent_in = torch.randn(
461
+ n_latent, self.style_dim, device=self.input.input.device
462
+ )
463
+ latent = self.style(latent_in).mean(0, keepdim=True)
464
+
465
+ return latent
466
+
467
+ def get_latent(self, input):
468
+ return self.style(input)
469
+
470
+ def forward(
471
+ self,
472
+ styles,
473
+ return_latents=False,
474
+ return_features=False,
475
+ inject_index=None,
476
+ truncation=1,
477
+ truncation_latent=None,
478
+ input_is_latent=False,
479
+ noise=None,
480
+ randomize_noise=True,
481
+ ):
482
+ if not input_is_latent:
483
+ styles = [self.style(s) for s in styles]
484
+
485
+ if noise is None:
486
+ if randomize_noise:
487
+ noise = [None] * self.num_layers
488
+ else:
489
+ noise = [
490
+ getattr(self.noises, f'noise_{i}') for i in range(self.num_layers)
491
+ ]
492
+
493
+ if truncation < 1:
494
+ style_t = []
495
+
496
+ for style in styles:
497
+ style_t.append(
498
+ truncation_latent + truncation * (style - truncation_latent)
499
+ )
500
+
501
+ styles = style_t
502
+
503
+ if len(styles) < 2:
504
+ inject_index = self.n_latent
505
+
506
+ if styles[0].ndim < 3:
507
+ latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
508
+ else:
509
+ latent = styles[0]
510
+
511
+ else:
512
+ if inject_index is None:
513
+ inject_index = random.randint(1, self.n_latent - 1)
514
+
515
+ latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
516
+ latent2 = styles[1].unsqueeze(1).repeat(1, self.n_latent - inject_index, 1)
517
+
518
+ latent = torch.cat([latent, latent2], 1)
519
+
520
+ if return_latents:
521
+ return latent
522
+
523
+ out = self.input(latent)
524
+ out = self.conv1(out, latent[:, 0], noise=noise[0])
525
+
526
+ skip = self.to_rgb1(out, latent[:, 1])
527
+
528
+ i = 1
529
+ for conv1, conv2, noise1, noise2, to_rgb in zip(
530
+ self.convs[::2], self.convs[1::2], noise[1::2], noise[2::2], self.to_rgbs
531
+ ):
532
+ out = conv1(out, latent[:, i], noise=noise1)
533
+ out = conv2(out, latent[:, i + 1], noise=noise2)
534
+ skip = to_rgb(out, latent[:, i + 2], skip)
535
+
536
+ i += 2
537
+
538
+ image = skip
539
+
540
+ if return_features:
541
+ return image, out
542
+ else:
543
+ return image, None
544
+
545
+
546
+ class ConvLayer(nn.Sequential):
547
+ def __init__(
548
+ self,
549
+ in_channel,
550
+ out_channel,
551
+ kernel_size,
552
+ downsample=False,
553
+ blur_kernel=[1, 3, 3, 1],
554
+ bias=True,
555
+ activate=True,
556
+ ):
557
+ layers = []
558
+
559
+ if downsample:
560
+ factor = 2
561
+ p = (len(blur_kernel) - factor) + (kernel_size - 1)
562
+ pad0 = (p + 1) // 2
563
+ pad1 = p // 2
564
+
565
+ layers.append(Blur(blur_kernel, pad=(pad0, pad1)))
566
+
567
+ stride = 2
568
+ self.padding = 0
569
+
570
+ else:
571
+ stride = 1
572
+ self.padding = kernel_size // 2
573
+
574
+ layers.append(
575
+ EqualConv2d(
576
+ in_channel,
577
+ out_channel,
578
+ kernel_size,
579
+ padding=self.padding,
580
+ stride=stride,
581
+ bias=bias and not activate,
582
+ )
583
+ )
584
+
585
+ if activate:
586
+ if bias:
587
+ layers.append(FusedLeakyReLU(out_channel))
588
+
589
+ else:
590
+ layers.append(ScaledLeakyReLU(0.2))
591
+
592
+ super().__init__(*layers)
593
+
594
+
595
+ class ResBlock(nn.Module):
596
+ def __init__(self, in_channel, out_channel, blur_kernel=[1, 3, 3, 1]):
597
+ super().__init__()
598
+
599
+ self.conv1 = ConvLayer(in_channel, in_channel, 3)
600
+ self.conv2 = ConvLayer(in_channel, out_channel, 3, downsample=True)
601
+
602
+ self.skip = ConvLayer(
603
+ in_channel, out_channel, 1, downsample=True, activate=False, bias=False
604
+ )
605
+
606
+ def forward(self, input):
607
+ out = self.conv1(input)
608
+ out = self.conv2(out)
609
+
610
+ skip = self.skip(input)
611
+ out = (out + skip) / math.sqrt(2)
612
+
613
+ return out
614
+
615
+
616
+ class Discriminator(nn.Module):
617
+ def __init__(self, size, channel_multiplier=2, blur_kernel=[1, 3, 3, 1]):
618
+ super().__init__()
619
+
620
+ channels = {
621
+ 4: 512,
622
+ 8: 512,
623
+ 16: 512,
624
+ 32: 512,
625
+ 64: 256 * channel_multiplier,
626
+ 128: 128 * channel_multiplier,
627
+ 256: 64 * channel_multiplier,
628
+ 512: 32 * channel_multiplier,
629
+ 1024: 16 * channel_multiplier,
630
+ }
631
+
632
+ convs = [ConvLayer(3, channels[size], 1)]
633
+
634
+ log_size = int(math.log(size, 2))
635
+
636
+ in_channel = channels[size]
637
+
638
+ for i in range(log_size, 2, -1):
639
+ out_channel = channels[2 ** (i - 1)]
640
+
641
+ convs.append(ResBlock(in_channel, out_channel, blur_kernel))
642
+
643
+ in_channel = out_channel
644
+
645
+ self.convs = nn.Sequential(*convs)
646
+
647
+ self.stddev_group = 4
648
+ self.stddev_feat = 1
649
+
650
+ self.final_conv = ConvLayer(in_channel + 1, channels[4], 3)
651
+ self.final_linear = nn.Sequential(
652
+ EqualLinear(channels[4] * 4 * 4, channels[4], activation='fused_lrelu'),
653
+ EqualLinear(channels[4], 1),
654
+ )
655
+
656
+ def forward(self, input):
657
+ out = self.convs(input)
658
+
659
+ batch, channel, height, width = out.shape
660
+ group = min(batch, self.stddev_group)
661
+ stddev = out.view(
662
+ group, -1, self.stddev_feat, channel // self.stddev_feat, height, width
663
+ )
664
+ stddev = torch.sqrt(stddev.var(0, unbiased=False) + 1e-8)
665
+ stddev = stddev.mean([2, 3, 4], keepdims=True).squeeze(2)
666
+ stddev = stddev.repeat(group, 1, height, width)
667
+ out = torch.cat([out, stddev], 1)
668
+
669
+ out = self.final_conv(out)
670
+
671
+ out = out.view(batch, -1)
672
+ out = self.final_linear(out)
673
+
674
+ return out
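A minimal random-weight sampling sketch for the Generator above (checkpoint loading omitted); styles are passed as a list of Z codes unless input_is_latent=True:

import torch
from models.stylegan2.model import Generator

g = Generator(size=256, style_dim=512, n_mlp=8).eval()
z = torch.randn(4, 512)                                # batch of latent codes
with torch.no_grad():
    mean_w = g.mean_latent(4096)                       # truncation anchor
    images, _ = g([z], truncation=0.7, truncation_latent=mean_w)
print(images.shape)                                    # torch.Size([4, 3, 256, 256])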
models/stylegan2/op/__init__.py ADDED
@@ -0,0 +1,2 @@
1
+ from .fused_act import FusedLeakyReLU, fused_leaky_relu
2
+ from .upfirdn2d import upfirdn2d
models/stylegan2/op/fused_act.py ADDED
@@ -0,0 +1,37 @@
1
+ import os
2
+
3
+ import torch
4
+ from torch import nn
5
+ from torch.autograd import Function
6
+ from torch.nn import functional as F
7
+
8
+
9
+ module_path = os.path.dirname(__file__)
10
+
11
+
12
+ class FusedLeakyReLU(nn.Module):
13
+ def __init__(self, channel, negative_slope=0.2, scale=2 ** 0.5):
14
+ super().__init__()
15
+
16
+ self.bias = nn.Parameter(torch.zeros(channel))
17
+ self.negative_slope = negative_slope
18
+ self.scale = scale
19
+
20
+ def forward(self, input):
21
+ return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale)
22
+
23
+ def fused_leaky_relu(input, bias=None, negative_slope=0.2, scale=2 ** 0.5):
24
+ if input.device.type == "cpu":
25
+ if bias is not None:
26
+ rest_dim = [1] * (input.ndim - bias.ndim - 1)
27
+ return (
28
+ F.leaky_relu(
29
+ input + bias.view(1, bias.shape[0], *rest_dim), negative_slope=negative_slope
30
+ )
31
+ * scale
32
+ )
33
+
34
+ else:
35
+ return F.leaky_relu(input, negative_slope=negative_slope) * scale
36
+
37
+
models/stylegan2/op/upfirdn2d.py ADDED
@@ -0,0 +1,60 @@
1
+ import os
2
+
3
+ import torch
4
+ from torch.autograd import Function
5
+ from torch.nn import functional as F
6
+
7
+
8
+
9
+ module_path = os.path.dirname(__file__)
10
+
11
+ def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
12
+ out = upfirdn2d_native(
13
+ input, kernel, up, up, down, down, pad[0], pad[1], pad[0], pad[1]
14
+ )
15
+
16
+ return out
17
+
18
+
19
+ def upfirdn2d_native(
20
+ input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1
21
+ ):
22
+ _, channel, in_h, in_w = input.shape
23
+ input = input.reshape(-1, in_h, in_w, 1)
24
+
25
+ _, in_h, in_w, minor = input.shape
26
+ kernel_h, kernel_w = kernel.shape
27
+
28
+ out = input.view(-1, in_h, 1, in_w, 1, minor)
29
+ out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
30
+ out = out.view(-1, in_h * up_y, in_w * up_x, minor)
31
+
32
+ out = F.pad(
33
+ out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)]
34
+ )
35
+ out = out[
36
+ :,
37
+ max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0),
38
+ max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0),
39
+ :,
40
+ ]
41
+
42
+ out = out.permute(0, 3, 1, 2)
43
+ out = out.reshape(
44
+ [-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1]
45
+ )
46
+ w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
47
+ out = F.conv2d(out, w)
48
+ out = out.reshape(
49
+ -1,
50
+ minor,
51
+ in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
52
+ in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1,
53
+ )
54
+ out = out.permute(0, 2, 3, 1)
55
+ out = out[:, ::down_y, ::down_x, :]
56
+
57
+ out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h + down_y) // down_y
58
+ out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w + down_x) // down_x
59
+
60
+ return out.view(-1, channel, out_h, out_w)
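The output size of this native implementation is (in * up + pad0 + pad1 - k + down) // down per spatial axis. For example, 2x upsampling of a 4x4 map with the default [1, 3, 3, 1] blur kernel, using the same padding the Upsample module in model.py computes:

import torch
from models.stylegan2.model import make_kernel
from models.stylegan2.op import upfirdn2d

x = torch.randn(1, 1, 4, 4)
k = make_kernel([1, 3, 3, 1]) * 4        # factor ** 2, as in Upsample
out = upfirdn2d(x, k, up=2, down=1, pad=(2, 1))
print(out.shape)                         # torch.Size([1, 1, 8, 8])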
pretrained/ohayou_face.pkl ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:89abef3962a9ca6b214f1447e7050725b73d41822d7381e1f4d0f96ac8035381
3
+ size 363965331
pretrained/ohayou_face.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c63de7970a7af6cc5b5c0cf677eb16095f2aaabd68dab41fcc3851bb5c7464f9
3
+ size 1077486507
requirements.txt ADDED
@@ -0,0 +1,10 @@
1
+ torch
2
+ numpy
3
+ torchvision
4
+ Pillow
5
+ tqdm
6
+ imageio
7
+ scipy
8
+ easydict
9
+ opensimplex==0.3
10
+ ninja
torch_utils/__init__.py ADDED
@@ -0,0 +1,9 @@
1
+ # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ # empty