Reevee committed · Commit f39e999 · Parent(s): first

This view is limited to 50 files because the commit contains too many changes.
- .gitattributes +2 -0
- README.md +11 -0
- app.py +31 -0
- configs/__init__.py +0 -0
- configs/data_configs.py +41 -0
- configs/paths_config.py +23 -0
- configs/transforms_config.py +152 -0
- criteria/__init__.py +0 -0
- criteria/id_loss.py +44 -0
- criteria/lpips/__init__.py +0 -0
- criteria/lpips/lpips.py +35 -0
- criteria/lpips/networks.py +96 -0
- criteria/lpips/utils.py +30 -0
- criteria/moco_loss.py +69 -0
- criteria/w_norm.py +14 -0
- datasets/__init__.py +0 -0
- datasets/augmentations.py +110 -0
- datasets/gt_res_dataset.py +32 -0
- datasets/images_dataset.py +33 -0
- datasets/inference_dataset.py +22 -0
- dnnlib/__init__.py +9 -0
- dnnlib/util.py +477 -0
- legacy.py +384 -0
- model_build.py +95 -0
- models/__init__.py +0 -0
- models/encoders/__init__.py +0 -0
- models/encoders/helpers.py +119 -0
- models/encoders/model_irse.py +84 -0
- models/encoders/psp_encoders.py +186 -0
- models/mtcnn/__init__.py +0 -0
- models/mtcnn/mtcnn.py +156 -0
- models/mtcnn/mtcnn_pytorch/__init__.py +0 -0
- models/mtcnn/mtcnn_pytorch/src/__init__.py +2 -0
- models/mtcnn/mtcnn_pytorch/src/align_trans.py +304 -0
- models/mtcnn/mtcnn_pytorch/src/box_utils.py +238 -0
- models/mtcnn/mtcnn_pytorch/src/detector.py +126 -0
- models/mtcnn/mtcnn_pytorch/src/first_stage.py +101 -0
- models/mtcnn/mtcnn_pytorch/src/get_nets.py +171 -0
- models/mtcnn/mtcnn_pytorch/src/matlab_cp2tform.py +350 -0
- models/mtcnn/mtcnn_pytorch/src/visualization_utils.py +31 -0
- models/psp.py +118 -0
- models/stylegan2/__init__.py +0 -0
- models/stylegan2/model.py +674 -0
- models/stylegan2/op/__init__.py +2 -0
- models/stylegan2/op/fused_act.py +37 -0
- models/stylegan2/op/upfirdn2d.py +60 -0
- pretrained/ohayou_face.pkl +3 -0
- pretrained/ohayou_face.pt +3 -0
- requirements.txt +10 -0
- torch_utils/__init__.py +9 -0
.gitattributes
ADDED
@@ -0,0 +1,2 @@
+pretrained/ohayou_face.pt filter=lfs diff=lfs merge=lfs -text
+pretrained/ohayou_face.pkl filter=lfs diff=lfs merge=lfs -text
README.md
ADDED
@@ -0,0 +1,11 @@
+---
+title: Ohayou_Face
+emoji: ⚡
+colorFrom: red
+colorTo: yellow
+sdk: gradio
+app_file: app.py
+pinned: false
+---
+
+Check out the configuration reference at https://huggingface.co/docs/hub/spaces#reference
app.py
ADDED
@@ -0,0 +1,31 @@
+import os
+from PIL import Image
+import gradio as gr
+from torchvision import transforms
+import easydict
+import torch
+import numpy as np
+import model_build
+
+
+psp = model_build.build_psp()
+stylegan2 = model_build.build_stylegan2()
+
+pretransform = transforms.Compose([
+    transforms.Resize((256, 256)),
+    transforms.ToTensor(),
+    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])
+
+def pipeline(img):
+    img = model_build.img_preprocess(img, pretransform)
+    with torch.no_grad():
+        _, latent_space = psp(img.float(), randomize_noise=True, resize=False, return_latents=True)
+        img = stylegan2(latent_space, noise_mode='none')
+    img = Image.fromarray(np.array((img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8).squeeze(0)[20:-20, :, :]))
+    img.save('output.png')
+    return 'output.png'
+
+examples = [['momoi_out.png', False], ['churuki_out.png', False], ['fgfgfggf.png', False], ['dsfd.png', False]]
+description = "Male images don't work well. A 1:1 ratio image is recommended (you can crop to a square after uploading). If the background is not a solid color, it may blend into the hair color. Inference takes about 5 seconds on average, but can take longer under heavy traffic. 남성 이미지에는 잘 작동하지 않음. 1:1비율 권장(업로드 후 정사각형 자르기 가능), 배경이 단색이 아니면 머리색과 섞일 수 있음. 트래픽이 많으면 5초 이상 걸릴 수 있음. Email: krkmfn@gmail.com"
+gr.Interface(pipeline, [gr.inputs.Image(type="pil")], gr.outputs.Image(type="file"), description=description, allow_flagging=False, examples=examples, allow_screenshot=False, enable_queue=False).launch()
+
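The output-conversion line in `pipeline` is dense; for reference, it is the standard StyleGAN2 mapping from a float image in [-1, 1] to 8-bit RGB. A standalone sketch of just that step, with a random tensor standing in for the generator output (the `[20:-20]` row crop in app.py additionally trims the top and bottom of the frame):

```python
import torch
import numpy as np
from PIL import Image

# Stand-in for a generator output of shape (1, 3, H, W) with values in [-1, 1].
fake = torch.rand(1, 3, 512, 512) * 2 - 1

# [-1, 1] -> [0, 255]: scale by 127.5, shift by 128, clamp, cast to uint8,
# and move channels last (NCHW -> NHWC) so PIL can consume the array.
arr = (fake.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
Image.fromarray(np.array(arr.squeeze(0))).save('preview.png')
```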
configs/__init__.py
ADDED
File without changes
configs/data_configs.py
ADDED
@@ -0,0 +1,41 @@
+from configs import transforms_config
+from configs.paths_config import dataset_paths
+
+
+DATASETS = {
+    'ffhq_encode': {
+        'transforms': transforms_config.EncodeTransforms,
+        'train_source_root': dataset_paths['ffhq'],
+        'train_target_root': dataset_paths['ffhq'],
+        'test_source_root': dataset_paths['celeba_test'],
+        'test_target_root': dataset_paths['celeba_test'],
+    },
+    'furry': {
+        'transforms': transforms_config.FrontalizationTransforms,
+        'train_source_root': dataset_paths['anime'],
+        'train_target_root': dataset_paths['anime'],
+        'test_source_root': dataset_paths['gogal'],
+        'test_target_root': dataset_paths['gogal'],
+    },
+    'celebs_sketch_to_face': {
+        'transforms': transforms_config.SketchToImageTransforms,
+        'train_source_root': dataset_paths['celeba_train_sketch'],
+        'train_target_root': dataset_paths['celeba_train'],
+        'test_source_root': dataset_paths['celeba_test_sketch'],
+        'test_target_root': dataset_paths['celeba_test'],
+    },
+    'celebs_seg_to_face': {
+        'transforms': transforms_config.SegToImageTransforms,
+        'train_source_root': dataset_paths['celeba_train_segmentation'],
+        'train_target_root': dataset_paths['celeba_train'],
+        'test_source_root': dataset_paths['celeba_test_segmentation'],
+        'test_target_root': dataset_paths['celeba_test'],
+    },
+    'celebs_super_resolution': {
+        'transforms': transforms_config.SuperResTransforms,
+        'train_source_root': dataset_paths['celeba_train'],
+        'train_target_root': dataset_paths['celeba_train'],
+        'test_source_root': dataset_paths['celeba_test'],
+        'test_target_root': dataset_paths['celeba_test'],
+    },
+}
configs/paths_config.py
ADDED
@@ -0,0 +1,23 @@
+dataset_paths = {
+    'celeba_train': '',
+    'celeba_test': '',
+    'celeba_train_sketch': '',
+    'celeba_test_sketch': '',
+    'celeba_train_segmentation': '',
+    'celeba_test_segmentation': '',
+    'ffhq': '',
+    'anime': '/content/drive/MyDrive/Dataset/anime',
+    'gogal': '/content/drive/MyDrive/All Data/고갈왕'
+}
+
+model_paths = {
+    'stylegan_ffhq': 'pretrained_models/stylegan2-ffhq-config-f.pt',
+    'ir_se50': 'pretrained_models/model_ir_se50.pth',
+    'circular_face': 'pretrained_models/CurricularFace_Backbone.pth',
+    'mtcnn_pnet': 'pretrained_models/mtcnn/pnet.npy',
+    'mtcnn_rnet': 'pretrained_models/mtcnn/rnet.npy',
+    'mtcnn_onet': 'pretrained_models/mtcnn/onet.npy',
+    'shape_predictor': 'shape_predictor_68_face_landmarks.dat',
+    'moco': 'pretrained_models/moco_v2_800ep_pretrain.pth.tar',
+    'anime': '/content/drive/MyDrive/StyleGAN2-ada/result/pretrained/anime_face.pt'
+}
configs/transforms_config.py
ADDED
@@ -0,0 +1,152 @@
+from abc import abstractmethod
+import torchvision.transforms as transforms
+from datasets import augmentations
+
+
+class TransformsConfig(object):
+
+    def __init__(self, opts):
+        self.opts = opts
+
+    @abstractmethod
+    def get_transforms(self):
+        pass
+
+
+class EncodeTransforms(TransformsConfig):
+
+    def __init__(self, opts):
+        super(EncodeTransforms, self).__init__(opts)
+
+    def get_transforms(self):
+        transforms_dict = {
+            'transform_gt_train': transforms.Compose([
+                transforms.Resize((256, 256)),
+                transforms.RandomHorizontalFlip(0.5),
+                transforms.ToTensor(),
+                transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]),
+            'transform_source': None,
+            'transform_test': transforms.Compose([
+                transforms.Resize((256, 256)),
+                transforms.ToTensor(),
+                transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]),
+            'transform_inference': transforms.Compose([
+                transforms.Resize((256, 256)),
+                transforms.ToTensor(),
+                transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])
+        }
+        return transforms_dict
+
+
+class FrontalizationTransforms(TransformsConfig):
+
+    def __init__(self, opts):
+        super(FrontalizationTransforms, self).__init__(opts)
+
+    def get_transforms(self):
+        transforms_dict = {
+            'transform_gt_train': transforms.Compose([
+                transforms.Resize((256, 256)),
+                transforms.ToTensor(),
+                transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]),
+            'transform_source': transforms.Compose([
+                transforms.Resize((256, 256)),
+                transforms.ToTensor(),
+                transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]),
+            'transform_test': transforms.Compose([
+                transforms.Resize((256, 256)),
+                transforms.ToTensor(),
+                transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]),
+            'transform_inference': transforms.Compose([
+                transforms.Resize((256, 256)),
+                transforms.ToTensor(),
+                transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])
+        }
+        return transforms_dict
+
+
+class SketchToImageTransforms(TransformsConfig):
+
+    def __init__(self, opts):
+        super(SketchToImageTransforms, self).__init__(opts)
+
+    def get_transforms(self):
+        transforms_dict = {
+            'transform_gt_train': transforms.Compose([
+                transforms.Resize((256, 256)),
+                transforms.ToTensor(),
+                transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]),
+            'transform_source': transforms.Compose([
+                transforms.Resize((256, 256)),
+                transforms.ToTensor()]),
+            'transform_test': transforms.Compose([
+                transforms.Resize((256, 256)),
+                transforms.ToTensor(),
+                transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]),
+            'transform_inference': transforms.Compose([
+                transforms.Resize((256, 256)),
+                transforms.ToTensor()]),
+        }
+        return transforms_dict
+
+
+class SegToImageTransforms(TransformsConfig):
+
+    def __init__(self, opts):
+        super(SegToImageTransforms, self).__init__(opts)
+
+    def get_transforms(self):
+        transforms_dict = {
+            'transform_gt_train': transforms.Compose([
+                transforms.Resize((256, 256)),
+                transforms.ToTensor(),
+                transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]),
+            'transform_source': transforms.Compose([
+                transforms.Resize((256, 256)),
+                augmentations.ToOneHot(self.opts.label_nc),
+                transforms.ToTensor()]),
+            'transform_test': transforms.Compose([
+                transforms.Resize((256, 256)),
+                transforms.ToTensor(),
+                transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]),
+            'transform_inference': transforms.Compose([
+                transforms.Resize((256, 256)),
+                augmentations.ToOneHot(self.opts.label_nc),
+                transforms.ToTensor()])
+        }
+        return transforms_dict
+
+
+class SuperResTransforms(TransformsConfig):
+
+    def __init__(self, opts):
+        super(SuperResTransforms, self).__init__(opts)
+
+    def get_transforms(self):
+        if self.opts.resize_factors is None:
+            self.opts.resize_factors = '1,2,4,8,16,32'
+        factors = [int(f) for f in self.opts.resize_factors.split(",")]
+        print("Performing down-sampling with factors: {}".format(factors))
+        transforms_dict = {
+            'transform_gt_train': transforms.Compose([
+                transforms.Resize((256, 256)),
+                transforms.ToTensor(),
+                transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]),
+            'transform_source': transforms.Compose([
+                transforms.Resize((256, 256)),
+                augmentations.BilinearResize(factors=factors),
+                transforms.Resize((256, 256)),
+                transforms.ToTensor(),
+                transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]),
+            'transform_test': transforms.Compose([
+                transforms.Resize((256, 256)),
+                transforms.ToTensor(),
+                transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]),
+            'transform_inference': transforms.Compose([
+                transforms.Resize((256, 256)),
+                augmentations.BilinearResize(factors=factors),
+                transforms.Resize((256, 256)),
+                transforms.ToTensor(),
+                transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])
+        }
+        return transforms_dict
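A minimal usage sketch of these configs. The `opts` object here is a hypothetical stand-in; the real one comes from the training argument parser, but `EncodeTransforms` reads no fields from it, so an empty EasyDict suffices:

```python
import easydict
from configs.transforms_config import EncodeTransforms

opts = easydict.EasyDict({})                    # hypothetical minimal opts
tfms = EncodeTransforms(opts).get_transforms()
inference_tfm = tfms['transform_inference']     # Resize -> ToTensor -> Normalize to [-1, 1]
```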
criteria/__init__.py
ADDED
File without changes
criteria/id_loss.py
ADDED
@@ -0,0 +1,44 @@
+import torch
+from torch import nn
+from configs.paths_config import model_paths
+from models.encoders.model_irse import Backbone
+
+
+class IDLoss(nn.Module):
+    def __init__(self):
+        super(IDLoss, self).__init__()
+        print('Loading ResNet ArcFace')
+        self.facenet = Backbone(input_size=112, num_layers=50, drop_ratio=0.6, mode='ir_se')
+        self.facenet.load_state_dict(torch.load(model_paths['ir_se50']))
+        self.face_pool = torch.nn.AdaptiveAvgPool2d((112, 112))
+        self.facenet.eval()
+
+    def extract_feats(self, x):
+        x = x[:, :, 35:223, 32:220]  # Crop interesting region
+        x = self.face_pool(x)
+        x_feats = self.facenet(x)
+        return x_feats
+
+    def forward(self, y_hat, y, x):
+        n_samples = x.shape[0]
+        x_feats = self.extract_feats(x)
+        y_feats = self.extract_feats(y)  # Otherwise use the feature from there
+        y_hat_feats = self.extract_feats(y_hat)
+        y_feats = y_feats.detach()
+        loss = 0
+        sim_improvement = 0
+        id_logs = []
+        count = 0
+        for i in range(n_samples):
+            diff_target = y_hat_feats[i].dot(y_feats[i])
+            diff_input = y_hat_feats[i].dot(x_feats[i])
+            diff_views = y_feats[i].dot(x_feats[i])
+            id_logs.append({'diff_target': float(diff_target),
+                            'diff_input': float(diff_input),
+                            'diff_views': float(diff_views)})
+            loss += 1 - diff_target
+            id_diff = float(diff_target) - float(diff_views)
+            sim_improvement += id_diff
+            count += 1
+
+        return loss / count, sim_improvement / count, id_logs
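Assuming the IR-SE50 backbone L2-normalizes its embeddings (which the standard model_irse implementation does), the dot products above are cosine similarities, so each sample contributes 1 - cos(f(ŷ), f(y)) to the loss. A toy numeric sketch with random unit vectors standing in for real ArcFace embeddings:

```python
import torch
import torch.nn.functional as F

# Random unit vectors stand in for embeddings of output / target / input.
y_hat_f = F.normalize(torch.randn(512), dim=0)
y_f = F.normalize(torch.randn(512), dim=0)
x_f = F.normalize(torch.randn(512), dim=0)

loss_term = 1 - y_hat_f.dot(y_f)                             # the 'diff_target' term
improvement = float(y_hat_f.dot(y_f)) - float(y_f.dot(x_f))  # gain over input-target similarity
```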
criteria/lpips/__init__.py
ADDED
File without changes
criteria/lpips/lpips.py
ADDED
@@ -0,0 +1,35 @@
+import torch
+import torch.nn as nn
+
+from criteria.lpips.networks import get_network, LinLayers
+from criteria.lpips.utils import get_state_dict
+
+
+class LPIPS(nn.Module):
+    r"""Creates a criterion that measures
+    Learned Perceptual Image Patch Similarity (LPIPS).
+    Arguments:
+        net_type (str): the network type to compare the features:
+                        'alex' | 'squeeze' | 'vgg'. Default: 'alex'.
+        version (str): the version of LPIPS. Default: 0.1.
+    """
+    def __init__(self, net_type: str = 'alex', version: str = '0.1'):
+
+        assert version in ['0.1'], 'v0.1 is only supported now'
+
+        super(LPIPS, self).__init__()
+
+        # pretrained network
+        self.net = get_network(net_type).to("cuda")
+
+        # linear layers
+        self.lin = LinLayers(self.net.n_channels_list).to("cuda")
+        self.lin.load_state_dict(get_state_dict(net_type, version))
+
+    def forward(self, x: torch.Tensor, y: torch.Tensor):
+        feat_x, feat_y = self.net(x), self.net(y)
+
+        diff = [(fx - fy) ** 2 for fx, fy in zip(feat_x, feat_y)]
+        res = [l(d).mean((2, 3), True) for d, l in zip(diff, self.lin)]
+
+        return torch.sum(torch.cat(res, 0)) / x.shape[0]
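A usage sketch. Note the module hard-codes `.to("cuda")`, so a GPU is required, and constructing it downloads the torchvision backbone weights plus the LPIPS linear weights (via `get_state_dict`) on first use:

```python
import torch
from criteria.lpips.lpips import LPIPS

lpips_loss = LPIPS(net_type='alex')
# Inputs are assumed to be RGB batches in [-1, 1], matching the
# Normalize([0.5]*3, [0.5]*3) transforms used elsewhere in this repo.
x = torch.rand(2, 3, 256, 256, device='cuda') * 2 - 1
y = torch.rand(2, 3, 256, 256, device='cuda') * 2 - 1
distance = lpips_loss(x, y)   # scalar, averaged over the batch
```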
criteria/lpips/networks.py
ADDED
@@ -0,0 +1,96 @@
+from typing import Sequence
+
+from itertools import chain
+
+import torch
+import torch.nn as nn
+from torchvision import models
+
+from criteria.lpips.utils import normalize_activation
+
+
+def get_network(net_type: str):
+    if net_type == 'alex':
+        return AlexNet()
+    elif net_type == 'squeeze':
+        return SqueezeNet()
+    elif net_type == 'vgg':
+        return VGG16()
+    else:
+        raise NotImplementedError('choose net_type from [alex, squeeze, vgg].')
+
+
+class LinLayers(nn.ModuleList):
+    def __init__(self, n_channels_list: Sequence[int]):
+        super(LinLayers, self).__init__([
+            nn.Sequential(
+                nn.Identity(),
+                nn.Conv2d(nc, 1, 1, 1, 0, bias=False)
+            ) for nc in n_channels_list
+        ])
+
+        for param in self.parameters():
+            param.requires_grad = False
+
+
+class BaseNet(nn.Module):
+    def __init__(self):
+        super(BaseNet, self).__init__()
+
+        # register buffer
+        self.register_buffer(
+            'mean', torch.Tensor([-.030, -.088, -.188])[None, :, None, None])
+        self.register_buffer(
+            'std', torch.Tensor([.458, .448, .450])[None, :, None, None])
+
+    def set_requires_grad(self, state: bool):
+        for param in chain(self.parameters(), self.buffers()):
+            param.requires_grad = state
+
+    def z_score(self, x: torch.Tensor):
+        return (x - self.mean) / self.std
+
+    def forward(self, x: torch.Tensor):
+        x = self.z_score(x)
+
+        output = []
+        for i, (_, layer) in enumerate(self.layers._modules.items(), 1):
+            x = layer(x)
+            if i in self.target_layers:
+                output.append(normalize_activation(x))
+            if len(output) == len(self.target_layers):
+                break
+        return output
+
+
+class SqueezeNet(BaseNet):
+    def __init__(self):
+        super(SqueezeNet, self).__init__()
+
+        self.layers = models.squeezenet1_1(True).features
+        self.target_layers = [2, 5, 8, 10, 11, 12, 13]
+        self.n_channels_list = [64, 128, 256, 384, 384, 512, 512]
+
+        self.set_requires_grad(False)
+
+
+class AlexNet(BaseNet):
+    def __init__(self):
+        super(AlexNet, self).__init__()
+
+        self.layers = models.alexnet(True).features
+        self.target_layers = [2, 5, 8, 10, 12]
+        self.n_channels_list = [64, 192, 384, 256, 256]
+
+        self.set_requires_grad(False)
+
+
+class VGG16(BaseNet):
+    def __init__(self):
+        super(VGG16, self).__init__()
+
+        self.layers = models.vgg16(True).features
+        self.target_layers = [4, 9, 16, 23, 30]
+        self.n_channels_list = [64, 128, 256, 512, 512]
+
+        self.set_requires_grad(False)
criteria/lpips/utils.py
ADDED
@@ -0,0 +1,30 @@
+from collections import OrderedDict
+
+import torch
+
+
+def normalize_activation(x, eps=1e-10):
+    norm_factor = torch.sqrt(torch.sum(x ** 2, dim=1, keepdim=True))
+    return x / (norm_factor + eps)
+
+
+def get_state_dict(net_type: str = 'alex', version: str = '0.1'):
+    # build url
+    url = 'https://raw.githubusercontent.com/richzhang/PerceptualSimilarity/' \
+        + f'master/lpips/weights/v{version}/{net_type}.pth'
+
+    # download
+    old_state_dict = torch.hub.load_state_dict_from_url(
+        url, progress=True,
+        map_location=None if torch.cuda.is_available() else torch.device('cpu')
+    )
+
+    # rename keys
+    new_state_dict = OrderedDict()
+    for key, val in old_state_dict.items():
+        new_key = key
+        new_key = new_key.replace('lin', '')
+        new_key = new_key.replace('model.', '')
+        new_state_dict[new_key] = val
+
+    return new_state_dict
criteria/moco_loss.py
ADDED
@@ -0,0 +1,69 @@
+import torch
+from torch import nn
+import torch.nn.functional as F
+from configs.paths_config import model_paths
+
+
+class MocoLoss(nn.Module):
+
+    def __init__(self):
+        super(MocoLoss, self).__init__()
+        print("Loading MOCO model from path: {}".format(model_paths["moco"]))
+        self.model = self.__load_model()
+        self.model.cuda()
+        self.model.eval()
+
+    @staticmethod
+    def __load_model():
+        import torchvision.models as models
+        model = models.__dict__["resnet50"]()
+        # freeze all layers but the last fc
+        for name, param in model.named_parameters():
+            if name not in ['fc.weight', 'fc.bias']:
+                param.requires_grad = False
+        checkpoint = torch.load(model_paths['moco'], map_location="cpu")
+        state_dict = checkpoint['state_dict']
+        # rename moco pre-trained keys
+        for k in list(state_dict.keys()):
+            # retain only encoder_q up to before the embedding layer
+            if k.startswith('module.encoder_q') and not k.startswith('module.encoder_q.fc'):
+                # remove prefix
+                state_dict[k[len("module.encoder_q."):]] = state_dict[k]
+            # delete renamed or unused k
+            del state_dict[k]
+        msg = model.load_state_dict(state_dict, strict=False)
+        assert set(msg.missing_keys) == {"fc.weight", "fc.bias"}
+        # remove output layer
+        model = nn.Sequential(*list(model.children())[:-1]).cuda()
+        return model
+
+    def extract_feats(self, x):
+        x = F.interpolate(x, size=224)
+        x_feats = self.model(x)
+        x_feats = nn.functional.normalize(x_feats, dim=1)
+        x_feats = x_feats.squeeze()
+        return x_feats
+
+    def forward(self, y_hat, y, x):
+        n_samples = x.shape[0]
+        x_feats = self.extract_feats(x)
+        y_feats = self.extract_feats(y)
+        y_hat_feats = self.extract_feats(y_hat)
+        y_feats = y_feats.detach()
+        loss = 0
+        sim_improvement = 0
+        sim_logs = []
+        count = 0
+        for i in range(n_samples):
+            diff_target = y_hat_feats[i].dot(y_feats[i])
+            diff_input = y_hat_feats[i].dot(x_feats[i])
+            diff_views = y_feats[i].dot(x_feats[i])
+            sim_logs.append({'diff_target': float(diff_target),
+                             'diff_input': float(diff_input),
+                             'diff_views': float(diff_views)})
+            loss += 1 - diff_target
+            sim_diff = float(diff_target) - float(diff_views)
+            sim_improvement += sim_diff
+            count += 1
+
+        return loss / count, sim_improvement / count, sim_logs
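The key-renaming loop in `__load_model` is worth seeing in isolation: it keeps only the MoCo query encoder's weights, strips their prefix, and drops everything else. On a toy state dict:

```python
state_dict = {
    'module.encoder_q.conv1.weight': 1,
    'module.encoder_q.fc.weight': 2,     # embedding head: dropped
    'module.encoder_k.conv1.weight': 3,  # key encoder: dropped
}
for k in list(state_dict.keys()):
    if k.startswith('module.encoder_q') and not k.startswith('module.encoder_q.fc'):
        state_dict[k[len('module.encoder_q.'):]] = state_dict[k]
    del state_dict[k]        # every original key is removed
print(state_dict)            # {'conv1.weight': 1}
```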
criteria/w_norm.py
ADDED
@@ -0,0 +1,14 @@
+import torch
+from torch import nn
+
+
+class WNormLoss(nn.Module):
+
+    def __init__(self, start_from_latent_avg=True):
+        super(WNormLoss, self).__init__()
+        self.start_from_latent_avg = start_from_latent_avg
+
+    def forward(self, latent, latent_avg=None):
+        if self.start_from_latent_avg:
+            latent = latent - latent_avg
+        return torch.sum(latent.norm(2, dim=(1, 2))) / latent.shape[0]
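Equivalently, this loss is the batch mean of the Frobenius norm of each latent's offset from the average latent. A sketch with hypothetical W+ shapes (18 style vectors of width 512 per image):

```python
import torch

latent = torch.randn(4, 18, 512)   # batch of W+ codes
latent_avg = torch.zeros(512)      # broadcasts across batch and style dimensions

delta = latent - latent_avg
loss = delta.norm(2, dim=(1, 2)).sum() / latent.shape[0]
```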
datasets/__init__.py
ADDED
File without changes
datasets/augmentations.py
ADDED
@@ -0,0 +1,110 @@
+import numpy as np
+import torch
+from torch import nn
+from torch.nn import functional as F
+from torchvision import transforms
+
+
+class ToOneHot(object):
+    """ Convert the input PIL image to a one-hot torch tensor """
+    def __init__(self, n_classes=None):
+        self.n_classes = n_classes
+
+    def onehot_initialization(self, a):
+        if self.n_classes is None:
+            self.n_classes = len(np.unique(a))
+        out = np.zeros(a.shape + (self.n_classes, ), dtype=int)
+        out[self.__all_idx(a, axis=2)] = 1
+        return out
+
+    def __all_idx(self, idx, axis):
+        grid = np.ogrid[tuple(map(slice, idx.shape))]
+        grid.insert(axis, idx)
+        return tuple(grid)
+
+    def __call__(self, img):
+        img = np.array(img)
+        one_hot = self.onehot_initialization(img)
+        return one_hot
+
+
+class BilinearResize(object):
+    def __init__(self, factors=[1, 2, 4, 8, 16, 32]):
+        self.factors = factors
+
+    def __call__(self, image):
+        factor = np.random.choice(self.factors, size=1)[0]
+        D = BicubicDownSample(factor=factor, cuda=False)
+        img_tensor = transforms.ToTensor()(image).unsqueeze(0)
+        img_tensor_lr = D(img_tensor)[0].clamp(0, 1)
+        img_low_res = transforms.ToPILImage()(img_tensor_lr)
+        return img_low_res
+
+
+class BicubicDownSample(nn.Module):
+    def bicubic_kernel(self, x, a=-0.50):
+        """
+        This equation is exactly copied from the website below:
+        https://clouard.users.greyc.fr/Pantheon/experiments/rescaling/index-en.html#bicubic
+        """
+        abs_x = torch.abs(x)
+        if abs_x <= 1.:
+            return (a + 2.) * torch.pow(abs_x, 3.) - (a + 3.) * torch.pow(abs_x, 2.) + 1
+        elif 1. < abs_x < 2.:
+            return a * torch.pow(abs_x, 3) - 5. * a * torch.pow(abs_x, 2.) + 8. * a * abs_x - 4. * a
+        else:
+            return 0.0
+
+    def __init__(self, factor=4, cuda=True, padding='reflect'):
+        super().__init__()
+        self.factor = factor
+        size = factor * 4
+        k = torch.tensor([self.bicubic_kernel((i - torch.floor(torch.tensor(size / 2)) + 0.5) / factor)
+                          for i in range(size)], dtype=torch.float32)
+        k = k / torch.sum(k)
+        k1 = torch.reshape(k, shape=(1, 1, size, 1))
+        self.k1 = torch.cat([k1, k1, k1], dim=0)
+        k2 = torch.reshape(k, shape=(1, 1, 1, size))
+        self.k2 = torch.cat([k2, k2, k2], dim=0)
+        self.cuda = '.cuda' if cuda else ''
+        self.padding = padding
+        for param in self.parameters():
+            param.requires_grad = False
+
+    def forward(self, x, nhwc=False, clip_round=False, byte_output=False):
+        filter_height = self.factor * 4
+        filter_width = self.factor * 4
+        stride = self.factor
+
+        pad_along_height = max(filter_height - stride, 0)
+        pad_along_width = max(filter_width - stride, 0)
+        filters1 = self.k1.type('torch{}.FloatTensor'.format(self.cuda))
+        filters2 = self.k2.type('torch{}.FloatTensor'.format(self.cuda))
+
+        # compute actual padding values for each side
+        pad_top = pad_along_height // 2
+        pad_bottom = pad_along_height - pad_top
+        pad_left = pad_along_width // 2
+        pad_right = pad_along_width - pad_left
+
+        # apply mirror padding
+        if nhwc:
+            x = torch.transpose(torch.transpose(x, 2, 3), 1, 2)  # NHWC to NCHW
+
+        # downscaling performed by 1-d convolution
+        x = F.pad(x, (0, 0, pad_top, pad_bottom), self.padding)
+        x = F.conv2d(input=x, weight=filters1, stride=(stride, 1), groups=3)
+        if clip_round:
+            x = torch.clamp(torch.round(x), 0.0, 255.)
+
+        x = F.pad(x, (pad_left, pad_right, 0, 0), self.padding)
+        x = F.conv2d(input=x, weight=filters2, stride=(1, stride), groups=3)
+        if clip_round:
+            x = torch.clamp(torch.round(x), 0.0, 255.)
+
+        if nhwc:
+            x = torch.transpose(torch.transpose(x, 1, 3), 1, 2)
+        if byte_output:
+            return x.type('torch{}.ByteTensor'.format(self.cuda))
+        else:
+            return x
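BilinearResize is how the super-resolution transforms simulate low-quality inputs: pick a random factor, bicubic-downsample, and let the following Resize bring the image back to 256x256. A usage sketch ('face.png' is a placeholder path):

```python
from PIL import Image
from datasets.augmentations import BilinearResize

degrade = BilinearResize(factors=[2, 4, 8])
low_res = degrade(Image.open('face.png').convert('RGB'))
```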
datasets/gt_res_dataset.py
ADDED
@@ -0,0 +1,32 @@
+#!/usr/bin/python
+# encoding: utf-8
+import os
+from torch.utils.data import Dataset
+from PIL import Image
+
+
+class GTResDataset(Dataset):
+
+    def __init__(self, root_path, gt_dir=None, transform=None, transform_train=None):
+        self.pairs = []
+        for f in os.listdir(root_path):
+            image_path = os.path.join(root_path, f)
+            gt_path = os.path.join(gt_dir, f)
+            if f.endswith(".jpg") or f.endswith(".png"):
+                self.pairs.append([image_path, gt_path.replace('.png', '.jpg'), None])
+        self.transform = transform
+        self.transform_train = transform_train
+
+    def __len__(self):
+        return len(self.pairs)
+
+    def __getitem__(self, index):
+        from_path, to_path, _ = self.pairs[index]
+        from_im = Image.open(from_path).convert('RGB')
+        to_im = Image.open(to_path).convert('RGB')
+
+        if self.transform:
+            to_im = self.transform(to_im)
+            from_im = self.transform(from_im)
+
+        return from_im, to_im
datasets/images_dataset.py
ADDED
@@ -0,0 +1,33 @@
+from torch.utils.data import Dataset
+from PIL import Image
+from utils import data_utils
+
+
+class ImagesDataset(Dataset):
+
+    def __init__(self, source_root, target_root, opts, target_transform=None, source_transform=None):
+        self.source_paths = sorted(data_utils.make_dataset(source_root))
+        self.target_paths = sorted(data_utils.make_dataset(target_root))
+        self.source_transform = source_transform
+        self.target_transform = target_transform
+        self.opts = opts
+
+    def __len__(self):
+        return len(self.source_paths)
+
+    def __getitem__(self, index):
+        from_path = self.source_paths[index]
+        from_im = Image.open(from_path)
+        from_im = from_im.convert('RGB') if self.opts.label_nc == 0 else from_im.convert('L')
+
+        to_path = self.target_paths[index]
+        to_im = Image.open(to_path).convert('RGB')
+        if self.target_transform:
+            to_im = self.target_transform(to_im)
+
+        if self.source_transform:
+            from_im = self.source_transform(from_im)
+        else:
+            from_im = to_im
+
+        return from_im, to_im
datasets/inference_dataset.py
ADDED
@@ -0,0 +1,22 @@
+from torch.utils.data import Dataset
+from PIL import Image
+from utils import data_utils
+
+
+class InferenceDataset(Dataset):
+
+    def __init__(self, root, opts, transform=None):
+        self.paths = sorted(data_utils.make_dataset(root))
+        self.transform = transform
+        self.opts = opts
+
+    def __len__(self):
+        return len(self.paths)
+
+    def __getitem__(self, index):
+        from_path = self.paths[index]
+        from_im = Image.open(from_path)
+        from_im = from_im.convert('RGB') if self.opts.label_nc == 0 else from_im.convert('L')
+        if self.transform:
+            from_im = self.transform(from_im)
+        return from_im
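Putting the dataset and transform pieces together for inference, a sketch: the `utils.data_utils` helper these datasets import is not among the files shown here, and 'imgs/' is a placeholder directory.

```python
import easydict
from torch.utils.data import DataLoader
from configs.transforms_config import EncodeTransforms
from datasets.inference_dataset import InferenceDataset

opts = easydict.EasyDict({'label_nc': 0})   # 0 keeps images in RGB
tfm = EncodeTransforms(opts).get_transforms()['transform_inference']
loader = DataLoader(InferenceDataset('imgs/', opts, transform=tfm),
                    batch_size=4, shuffle=False)
```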
dnnlib/__init__.py
ADDED
@@ -0,0 +1,9 @@
+# Copyright (c) 2021, NVIDIA CORPORATION.  All rights reserved.
+#
+# NVIDIA CORPORATION and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto.  Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+from .util import EasyDict, make_cache_dir_path
dnnlib/util.py
ADDED
@@ -0,0 +1,477 @@
+# Copyright (c) 2021, NVIDIA CORPORATION.  All rights reserved.
+#
+# NVIDIA CORPORATION and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto.  Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+"""Miscellaneous utility classes and functions."""
+
+import ctypes
+import fnmatch
+import importlib
+import inspect
+import numpy as np
+import os
+import shutil
+import sys
+import types
+import io
+import pickle
+import re
+import requests
+import html
+import hashlib
+import glob
+import tempfile
+import urllib
+import urllib.request
+import uuid
+
+from distutils.util import strtobool
+from typing import Any, List, Tuple, Union
+
+
+# Util classes
+# ------------------------------------------------------------------------------------------
+
+
+class EasyDict(dict):
+    """Convenience class that behaves like a dict but allows access with the attribute syntax."""
+
+    def __getattr__(self, name: str) -> Any:
+        try:
+            return self[name]
+        except KeyError:
+            raise AttributeError(name)
+
+    def __setattr__(self, name: str, value: Any) -> None:
+        self[name] = value
+
+    def __delattr__(self, name: str) -> None:
+        del self[name]
+
+
+class Logger(object):
+    """Redirect stderr to stdout, optionally print stdout to a file, and optionally force flushing on both stdout and the file."""
+
+    def __init__(self, file_name: str = None, file_mode: str = "w", should_flush: bool = True):
+        self.file = None
+
+        if file_name is not None:
+            self.file = open(file_name, file_mode)
+
+        self.should_flush = should_flush
+        self.stdout = sys.stdout
+        self.stderr = sys.stderr
+
+        sys.stdout = self
+        sys.stderr = self
+
+    def __enter__(self) -> "Logger":
+        return self
+
+    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
+        self.close()
+
+    def write(self, text: Union[str, bytes]) -> None:
+        """Write text to stdout (and a file) and optionally flush."""
+        if isinstance(text, bytes):
+            text = text.decode()
+        if len(text) == 0: # workaround for a bug in VSCode debugger: sys.stdout.write(''); sys.stdout.flush() => crash
+            return
+
+        if self.file is not None:
+            self.file.write(text)
+
+        self.stdout.write(text)
+
+        if self.should_flush:
+            self.flush()
+
+    def flush(self) -> None:
+        """Flush written text to both stdout and a file, if open."""
+        if self.file is not None:
+            self.file.flush()
+
+        self.stdout.flush()
+
+    def close(self) -> None:
+        """Flush, close possible files, and remove stdout/stderr mirroring."""
+        self.flush()
+
+        # if using multiple loggers, prevent closing in wrong order
+        if sys.stdout is self:
+            sys.stdout = self.stdout
+        if sys.stderr is self:
+            sys.stderr = self.stderr
+
+        if self.file is not None:
+            self.file.close()
+            self.file = None
+
+
+# Cache directories
+# ------------------------------------------------------------------------------------------
+
+_dnnlib_cache_dir = None
+
+def set_cache_dir(path: str) -> None:
+    global _dnnlib_cache_dir
+    _dnnlib_cache_dir = path
+
+def make_cache_dir_path(*paths: str) -> str:
+    if _dnnlib_cache_dir is not None:
+        return os.path.join(_dnnlib_cache_dir, *paths)
+    if 'DNNLIB_CACHE_DIR' in os.environ:
+        return os.path.join(os.environ['DNNLIB_CACHE_DIR'], *paths)
+    if 'HOME' in os.environ:
+        return os.path.join(os.environ['HOME'], '.cache', 'dnnlib', *paths)
+    if 'USERPROFILE' in os.environ:
+        return os.path.join(os.environ['USERPROFILE'], '.cache', 'dnnlib', *paths)
+    return os.path.join(tempfile.gettempdir(), '.cache', 'dnnlib', *paths)
+
+# Small util functions
+# ------------------------------------------------------------------------------------------
+
+
+def format_time(seconds: Union[int, float]) -> str:
+    """Convert the seconds to human readable string with days, hours, minutes and seconds."""
+    s = int(np.rint(seconds))
+
+    if s < 60:
+        return "{0}s".format(s)
+    elif s < 60 * 60:
+        return "{0}m {1:02}s".format(s // 60, s % 60)
+    elif s < 24 * 60 * 60:
+        return "{0}h {1:02}m {2:02}s".format(s // (60 * 60), (s // 60) % 60, s % 60)
+    else:
+        return "{0}d {1:02}h {2:02}m".format(s // (24 * 60 * 60), (s // (60 * 60)) % 24, (s // 60) % 60)
+
+
+def ask_yes_no(question: str) -> bool:
+    """Ask the user the question until the user inputs a valid answer."""
+    while True:
+        try:
+            print("{0} [y/n]".format(question))
+            return strtobool(input().lower())
+        except ValueError:
+            pass
+
+
+def tuple_product(t: Tuple) -> Any:
+    """Calculate the product of the tuple elements."""
+    result = 1
+
+    for v in t:
+        result *= v
+
+    return result
+
+
+_str_to_ctype = {
+    "uint8": ctypes.c_ubyte,
+    "uint16": ctypes.c_uint16,
+    "uint32": ctypes.c_uint32,
+    "uint64": ctypes.c_uint64,
+    "int8": ctypes.c_byte,
+    "int16": ctypes.c_int16,
+    "int32": ctypes.c_int32,
+    "int64": ctypes.c_int64,
+    "float32": ctypes.c_float,
+    "float64": ctypes.c_double
+}
+
+
+def get_dtype_and_ctype(type_obj: Any) -> Tuple[np.dtype, Any]:
+    """Given a type name string (or an object having a __name__ attribute), return matching Numpy and ctypes types that have the same size in bytes."""
+    type_str = None
+
+    if isinstance(type_obj, str):
+        type_str = type_obj
+    elif hasattr(type_obj, "__name__"):
+        type_str = type_obj.__name__
+    elif hasattr(type_obj, "name"):
+        type_str = type_obj.name
+    else:
+        raise RuntimeError("Cannot infer type name from input")
+
+    assert type_str in _str_to_ctype.keys()
+
+    my_dtype = np.dtype(type_str)
+    my_ctype = _str_to_ctype[type_str]
+
+    assert my_dtype.itemsize == ctypes.sizeof(my_ctype)
+
+    return my_dtype, my_ctype
+
+
+def is_pickleable(obj: Any) -> bool:
+    try:
+        with io.BytesIO() as stream:
+            pickle.dump(obj, stream)
+        return True
+    except:
+        return False
+
+
+# Functionality to import modules/objects by name, and call functions by name
+# ------------------------------------------------------------------------------------------
+
+def get_module_from_obj_name(obj_name: str) -> Tuple[types.ModuleType, str]:
+    """Searches for the underlying module behind the name to some python object.
+    Returns the module and the object name (original name with module part removed)."""
+
+    # allow convenience shorthands, substitute them by full names
+    obj_name = re.sub("^np.", "numpy.", obj_name)
+    obj_name = re.sub("^tf.", "tensorflow.", obj_name)
+
+    # list alternatives for (module_name, local_obj_name)
+    parts = obj_name.split(".")
+    name_pairs = [(".".join(parts[:i]), ".".join(parts[i:])) for i in range(len(parts), 0, -1)]
+
+    # try each alternative in turn
+    for module_name, local_obj_name in name_pairs:
+        try:
+            module = importlib.import_module(module_name) # may raise ImportError
+            get_obj_from_module(module, local_obj_name) # may raise AttributeError
+            return module, local_obj_name
+        except:
+            pass
+
+    # maybe some of the modules themselves contain errors?
+    for module_name, _local_obj_name in name_pairs:
+        try:
+            importlib.import_module(module_name) # may raise ImportError
+        except ImportError:
+            if not str(sys.exc_info()[1]).startswith("No module named '" + module_name + "'"):
+                raise
+
+    # maybe the requested attribute is missing?
+    for module_name, local_obj_name in name_pairs:
+        try:
+            module = importlib.import_module(module_name) # may raise ImportError
+            get_obj_from_module(module, local_obj_name) # may raise AttributeError
+        except ImportError:
+            pass
+
+    # we are out of luck, but we have no idea why
+    raise ImportError(obj_name)
+
+
+def get_obj_from_module(module: types.ModuleType, obj_name: str) -> Any:
+    """Traverses the object name and returns the last (rightmost) python object."""
+    if obj_name == '':
+        return module
+    obj = module
+    for part in obj_name.split("."):
+        obj = getattr(obj, part)
+    return obj
+
+
+def get_obj_by_name(name: str) -> Any:
+    """Finds the python object with the given name."""
+    module, obj_name = get_module_from_obj_name(name)
+    return get_obj_from_module(module, obj_name)
+
+
+def call_func_by_name(*args, func_name: str = None, **kwargs) -> Any:
+    """Finds the python object with the given name and calls it as a function."""
+    assert func_name is not None
+    func_obj = get_obj_by_name(func_name)
+    assert callable(func_obj)
+    return func_obj(*args, **kwargs)
+
+
+def construct_class_by_name(*args, class_name: str = None, **kwargs) -> Any:
+    """Finds the python class with the given name and constructs it with the given arguments."""
+    return call_func_by_name(*args, func_name=class_name, **kwargs)
+
+
+def get_module_dir_by_obj_name(obj_name: str) -> str:
+    """Get the directory path of the module containing the given object name."""
+    module, _ = get_module_from_obj_name(obj_name)
+    return os.path.dirname(inspect.getfile(module))
+
+
+def is_top_level_function(obj: Any) -> bool:
+    """Determine whether the given object is a top-level function, i.e., defined at module scope using 'def'."""
+    return callable(obj) and obj.__name__ in sys.modules[obj.__module__].__dict__
+
+
+def get_top_level_function_name(obj: Any) -> str:
+    """Return the fully-qualified name of a top-level function."""
+    assert is_top_level_function(obj)
+    module = obj.__module__
+    if module == '__main__':
+        module = os.path.splitext(os.path.basename(sys.modules[module].__file__))[0]
+    return module + "." + obj.__name__
+
+
+# File system helpers
+# ------------------------------------------------------------------------------------------
+
+def list_dir_recursively_with_ignore(dir_path: str, ignores: List[str] = None, add_base_to_relative: bool = False) -> List[Tuple[str, str]]:
+    """List all files recursively in a given directory while ignoring given file and directory names.
+    Returns list of tuples containing both absolute and relative paths."""
+    assert os.path.isdir(dir_path)
+    base_name = os.path.basename(os.path.normpath(dir_path))
+
+    if ignores is None:
+        ignores = []
+
+    result = []
+
+    for root, dirs, files in os.walk(dir_path, topdown=True):
+        for ignore_ in ignores:
+            dirs_to_remove = [d for d in dirs if fnmatch.fnmatch(d, ignore_)]
+
+            # dirs need to be edited in-place
+            for d in dirs_to_remove:
+                dirs.remove(d)
+
+            files = [f for f in files if not fnmatch.fnmatch(f, ignore_)]
+
+        absolute_paths = [os.path.join(root, f) for f in files]
+        relative_paths = [os.path.relpath(p, dir_path) for p in absolute_paths]
+
+        if add_base_to_relative:
+            relative_paths = [os.path.join(base_name, p) for p in relative_paths]
+
+        assert len(absolute_paths) == len(relative_paths)
+        result += zip(absolute_paths, relative_paths)
+
+    return result
+
+
+def copy_files_and_create_dirs(files: List[Tuple[str, str]]) -> None:
+    """Takes in a list of tuples of (src, dst) paths and copies files.
+    Will create all necessary directories."""
+    for file in files:
+        target_dir_name = os.path.dirname(file[1])
+
+        # will create all intermediate-level directories
+        if not os.path.exists(target_dir_name):
+            os.makedirs(target_dir_name)
+
+        shutil.copyfile(file[0], file[1])
+
+
+# URL helpers
+# ------------------------------------------------------------------------------------------
+
+def is_url(obj: Any, allow_file_urls: bool = False) -> bool:
+    """Determine whether the given object is a valid URL string."""
+    if not isinstance(obj, str) or not "://" in obj:
+        return False
+    if allow_file_urls and obj.startswith('file://'):
+        return True
+    try:
+        res = requests.compat.urlparse(obj)
+        if not res.scheme or not res.netloc or not "." in res.netloc:
+            return False
+        res = requests.compat.urlparse(requests.compat.urljoin(obj, "/"))
+        if not res.scheme or not res.netloc or not "." in res.netloc:
+            return False
+    except:
+        return False
+    return True
+
+
+def open_url(url: str, cache_dir: str = None, num_attempts: int = 10, verbose: bool = True, return_filename: bool = False, cache: bool = True) -> Any:
+    """Download the given URL and return a binary-mode file object to access the data."""
+    assert num_attempts >= 1
+    assert not (return_filename and (not cache))
+
+    # Doesn't look like an URL scheme so interpret it as a local filename.
+    if not re.match('^[a-z]+://', url):
+        return url if return_filename else open(url, "rb")
+
+    # Handle file URLs. This code handles unusual file:// patterns that
+    # arise on Windows:
+    #
+    # file:///c:/foo.txt
+    #
+    # which would translate to a local '/c:/foo.txt' filename that's
+    # invalid. Drop the forward slash for such pathnames.
+    #
+    # If you touch this code path, you should test it on both Linux and
+    # Windows.
+    #
+    # Some internet resources suggest using urllib.request.url2pathname() but
+    # but that converts forward slashes to backslashes and this causes
+    # its own set of problems.
+    if url.startswith('file://'):
+        filename = urllib.parse.urlparse(url).path
+        if re.match(r'^/[a-zA-Z]:', filename):
+            filename = filename[1:]
+        return filename if return_filename else open(filename, "rb")
+
+    assert is_url(url)
+
+    # Lookup from cache.
+    if cache_dir is None:
+        cache_dir = make_cache_dir_path('downloads')
+
+    url_md5 = hashlib.md5(url.encode("utf-8")).hexdigest()
+    if cache:
+        cache_files = glob.glob(os.path.join(cache_dir, url_md5 + "_*"))
+        if len(cache_files) == 1:
+            filename = cache_files[0]
+            return filename if return_filename else open(filename, "rb")
+
+    # Download.
+    url_name = None
+    url_data = None
+    with requests.Session() as session:
+        if verbose:
+            print("Downloading %s ..." % url, end="", flush=True)
+        for attempts_left in reversed(range(num_attempts)):
+            try:
+                with session.get(url) as res:
+                    res.raise_for_status()
+                    if len(res.content) == 0:
+                        raise IOError("No data received")
+
+                    if len(res.content) < 8192:
+                        content_str = res.content.decode("utf-8")
+                        if "download_warning" in res.headers.get("Set-Cookie", ""):
+                            links = [html.unescape(link) for link in content_str.split('"') if "export=download" in link]
+                            if len(links) == 1:
+                                url = requests.compat.urljoin(url, links[0])
+                                raise IOError("Google Drive virus checker nag")
+                        if "Google Drive - Quota exceeded" in content_str:
+                            raise IOError("Google Drive download quota exceeded -- please try again later")
+
+                    match = re.search(r'filename="([^"]*)"', res.headers.get("Content-Disposition", ""))
+                    url_name = match[1] if match else url
+                    url_data = res.content
+                    if verbose:
+                        print(" done")
+                    break
+            except KeyboardInterrupt:
+                raise
+            except:
+                if not attempts_left:
+                    if verbose:
+                        print(" failed")
+                    raise
+                if verbose:
+                    print(".", end="", flush=True)
+
+    # Save to cache.
+    if cache:
+        safe_name = re.sub(r"[^0-9a-zA-Z-._]", "_", url_name)
+        cache_file = os.path.join(cache_dir, url_md5 + "_" + safe_name)
+        temp_file = os.path.join(cache_dir, "tmp_" + uuid.uuid4().hex + "_" + url_md5 + "_" + safe_name)
+        os.makedirs(cache_dir, exist_ok=True)
+        with open(temp_file, "wb") as f:
+            f.write(url_data)
+        os.replace(temp_file, cache_file) # atomic
+        if return_filename:
+            return cache_file
+
+    # Return data as file object.
+    assert not return_filename
+    return io.BytesIO(url_data)
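Two small behaviors worth noting in dnnlib: EasyDict writes attribute assignments through to the underlying dict, and open_url falls back to a plain `open()` for local paths while caching real downloads under `~/.cache/dnnlib/downloads` by default. A short sketch:

```python
import dnnlib
from dnnlib.util import format_time, open_url

cfg = dnnlib.EasyDict(lr=0.002)
cfg.batch = 8                       # attribute access writes through to the dict
assert cfg['batch'] == 8 and cfg.lr == 0.002

print(format_time(3700))            # '1h 01m 40s'

with open_url('pretrained/ohayou_face.pkl') as f:   # no URL scheme -> local open()
    magic = f.read(2)
```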
legacy.py
ADDED
@@ -0,0 +1,384 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.

import click
import pickle
import re
import copy
import numpy as np
import torch
import dnnlib
from torch_utils import misc

#----------------------------------------------------------------------------

# !!! custom
def load_network_pkl(f, force_fp16=False, custom=False, **ex_kwargs):
# def load_network_pkl(f, force_fp16=False):
    data = _LegacyUnpickler(f).load()
    # data = pickle.load(f, encoding='latin1')

    # Legacy TensorFlow pickle => convert.
    if isinstance(data, tuple) and len(data) == 3 and all(isinstance(net, _TFNetworkStub) for net in data):
        tf_G, tf_D, tf_Gs = data
        G = convert_tf_generator(tf_G, custom=custom, **ex_kwargs)
        D = convert_tf_discriminator(tf_D)
        G_ema = convert_tf_generator(tf_Gs, custom=custom, **ex_kwargs)
        data = dict(G=G, D=D, G_ema=G_ema)
        # !!! custom
        assert isinstance(data['G'], torch.nn.Module)
        assert isinstance(data['D'], torch.nn.Module)
        nets = ['G', 'D', 'G_ema']
    elif isinstance(data, _TFNetworkStub):
        G_ema = convert_tf_generator(data, custom=custom, **ex_kwargs)
        data = dict(G_ema=G_ema)
        nets = ['G_ema']
    else:
        # !!! custom
        if custom is True:
            G_ema = custom_generator(data, **ex_kwargs)
            data = dict(G_ema=G_ema)
            nets = ['G_ema']
        else:
            nets = []
            for name in ['G', 'D', 'G_ema']:
                if name in data.keys():
                    nets.append(name)
    # print(nets)

    # Add missing fields.
    if 'training_set_kwargs' not in data:
        data['training_set_kwargs'] = None
    if 'augment_pipe' not in data:
        data['augment_pipe'] = None

    # Validate contents.
    assert isinstance(data['G_ema'], torch.nn.Module)
    assert isinstance(data['training_set_kwargs'], (dict, type(None)))
    assert isinstance(data['augment_pipe'], (torch.nn.Module, type(None)))

    # Force FP16.
    if force_fp16:
        for key in nets: # !!! custom
            old = data[key]
            kwargs = copy.deepcopy(old.init_kwargs)
            if key.startswith('G'):
                kwargs.synthesis_kwargs = dnnlib.EasyDict(kwargs.get('synthesis_kwargs', {}))
                kwargs.synthesis_kwargs.num_fp16_res = 4
                kwargs.synthesis_kwargs.conv_clamp = 256
            if key.startswith('D'):
                kwargs.num_fp16_res = 4
                kwargs.conv_clamp = 256
            if kwargs != old.init_kwargs:
                new = type(old)(**kwargs).eval().requires_grad_(False)
                misc.copy_params_and_buffers(old, new, require_all=True)
                data[key] = new
    return data

#----------------------------------------------------------------------------

class _TFNetworkStub(dnnlib.EasyDict):
    pass

class _LegacyUnpickler(pickle.Unpickler):
    def find_class(self, module, name):
        if module == 'dnnlib.tflib.network' and name == 'Network':
            return _TFNetworkStub
        return super().find_class(module, name)

#----------------------------------------------------------------------------

def _collect_tf_params(tf_net):
    # pylint: disable=protected-access
    tf_params = dict()
    def recurse(prefix, tf_net):
        for name, value in tf_net.variables:
            tf_params[prefix + name] = value
        for name, comp in tf_net.components.items():
            recurse(prefix + name + '/', comp)
    recurse('', tf_net)
    return tf_params

#----------------------------------------------------------------------------

def _populate_module_params(module, *patterns):
    for name, tensor in misc.named_params_and_buffers(module):
        found = False
        value = None
        for pattern, value_fn in zip(patterns[0::2], patterns[1::2]):
            match = re.fullmatch(pattern, name)
            if match:
                found = True
                if value_fn is not None:
                    value = value_fn(*match.groups())
                break
        try:
            assert found
            if value is not None:
                tensor.copy_(torch.from_numpy(np.array(value)))
        except:
            print(name, list(tensor.shape))
            raise

#----------------------------------------------------------------------------

# !!! custom
def custom_generator(data, **ex_kwargs):
    from training import stylegan2_multi as networks
    try: # saved? (with new fix)
        fmap_base = data['G_ema'].synthesis.fmap_base
    except: # default from original configs
        fmap_base = 32768 if data['G_ema'].img_resolution >= 512 else 16384
    kwargs = dnnlib.EasyDict(
        z_dim = data['G_ema'].z_dim,
        c_dim = data['G_ema'].c_dim,
        w_dim = data['G_ema'].w_dim,
        img_resolution = data['G_ema'].img_resolution,
        img_channels = data['G_ema'].img_channels,
        init_res = [4,4], # hacky
        mapping_kwargs = dnnlib.EasyDict(num_layers = data['G_ema'].mapping.num_layers),
        synthesis_kwargs = dnnlib.EasyDict(channel_base = fmap_base, **ex_kwargs),
    )
    G_out = networks.Generator(**kwargs).eval().requires_grad_(False)
    misc.copy_params_and_buffers(data['G_ema'], G_out, require_all=False)
    return G_out

# !!! custom
def convert_tf_generator(tf_G, custom=False, **ex_kwargs):
# def convert_tf_generator(tf_G):
    if tf_G.version < 4:
        raise ValueError('TensorFlow pickle version too low')

    # Collect kwargs.
    tf_kwargs = tf_G.static_kwargs
    known_kwargs = set()
    def kwarg(tf_name, default=None, none=None):
        known_kwargs.add(tf_name)
        val = tf_kwargs.get(tf_name, default)
        return val if val is not None else none

    # Convert kwargs.
    kwargs = dnnlib.EasyDict(
        z_dim = kwarg('latent_size', 512),
        c_dim = kwarg('label_size', 0),
        w_dim = kwarg('dlatent_size', 512),
        img_resolution = kwarg('resolution', 1024),
        img_channels = kwarg('num_channels', 3),
        mapping_kwargs = dnnlib.EasyDict(
            num_layers = kwarg('mapping_layers', 8),
            embed_features = kwarg('label_fmaps', None),
            layer_features = kwarg('mapping_fmaps', None),
            activation = kwarg('mapping_nonlinearity', 'lrelu'),
            lr_multiplier = kwarg('mapping_lrmul', 0.01),
            w_avg_beta = kwarg('w_avg_beta', 0.995, none=1),
        ),
        synthesis_kwargs = dnnlib.EasyDict(
            channel_base = kwarg('fmap_base', 16384) * 2,
            channel_max = kwarg('fmap_max', 512),
            num_fp16_res = kwarg('num_fp16_res', 0),
            conv_clamp = kwarg('conv_clamp', None),
            architecture = kwarg('architecture', 'skip'),
            resample_filter = kwarg('resample_kernel', [1,3,3,1]),
            use_noise = kwarg('use_noise', True),
            activation = kwarg('nonlinearity', 'lrelu'),
        ),
        # !!! custom
        # init_res = kwarg('init_res', [4,4]),
    )

    # Check for unknown kwargs.
    kwarg('truncation_psi')
    kwarg('truncation_cutoff')
    kwarg('style_mixing_prob')
    kwarg('structure')
    unknown_kwargs = list(set(tf_kwargs.keys()) - known_kwargs)
    # !!! custom
    if custom:
        kwargs.init_res = [4,4]
        kwargs.synthesis_kwargs = dnnlib.EasyDict(**kwargs.synthesis_kwargs, **ex_kwargs)
    if len(unknown_kwargs) > 0:
        print('Unknown TensorFlow data! This may result in problems with your converted model.')
        print(unknown_kwargs)
        #raise ValueError('Unknown TensorFlow kwargs:', unknown_kwargs)
        # raise ValueError('Unknown TensorFlow kwarg', unknown_kwargs[0])
    # try:
    #     if ex_kwargs['verbose'] is True: print(kwargs.synthesis_kwargs)
    # except: pass

    # Collect params.
    tf_params = _collect_tf_params(tf_G)
    for name, value in list(tf_params.items()):
        match = re.fullmatch(r'ToRGB_lod(\d+)/(.*)', name)
        if match:
            r = kwargs.img_resolution // (2 ** int(match.group(1)))
            tf_params[f'{r}x{r}/ToRGB/{match.group(2)}'] = value
            kwargs.synthesis_kwargs.architecture = 'orig'
    #for name, value in tf_params.items(): print(f'{name:<50s}{list(value.shape)}')

    # Convert params.
    if custom:
        from training import stylegan2_multi as networks
    else:
        from training import networks
    G = networks.Generator(**kwargs).eval().requires_grad_(False)
    # pylint: disable=unnecessary-lambda
    _populate_module_params(G,
        r'mapping\.w_avg', lambda: tf_params[f'dlatent_avg'],
        r'mapping\.embed\.weight', lambda: tf_params[f'mapping/LabelEmbed/weight'].transpose(),
        r'mapping\.embed\.bias', lambda: tf_params[f'mapping/LabelEmbed/bias'],
        r'mapping\.fc(\d+)\.weight', lambda i: tf_params[f'mapping/Dense{i}/weight'].transpose(),
        r'mapping\.fc(\d+)\.bias', lambda i: tf_params[f'mapping/Dense{i}/bias'],
        r'synthesis\.b4\.const', lambda: tf_params[f'synthesis/4x4/Const/const'][0],
        r'synthesis\.b4\.conv1\.weight', lambda: tf_params[f'synthesis/4x4/Conv/weight'].transpose(3, 2, 0, 1),
        r'synthesis\.b4\.conv1\.bias', lambda: tf_params[f'synthesis/4x4/Conv/bias'],
        r'synthesis\.b4\.conv1\.noise_const', lambda: tf_params[f'synthesis/noise0'][0, 0],
        r'synthesis\.b4\.conv1\.noise_strength', lambda: tf_params[f'synthesis/4x4/Conv/noise_strength'],
        r'synthesis\.b4\.conv1\.affine\.weight', lambda: tf_params[f'synthesis/4x4/Conv/mod_weight'].transpose(),
        r'synthesis\.b4\.conv1\.affine\.bias', lambda: tf_params[f'synthesis/4x4/Conv/mod_bias'] + 1,
        r'synthesis\.b(\d+)\.conv0\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/Conv0_up/weight'][::-1, ::-1].transpose(3, 2, 0, 1),
        r'synthesis\.b(\d+)\.conv0\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/Conv0_up/bias'],
        r'synthesis\.b(\d+)\.conv0\.noise_const', lambda r: tf_params[f'synthesis/noise{int(np.log2(int(r)))*2-5}'][0, 0],
        r'synthesis\.b(\d+)\.conv0\.noise_strength', lambda r: tf_params[f'synthesis/{r}x{r}/Conv0_up/noise_strength'],
        r'synthesis\.b(\d+)\.conv0\.affine\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/Conv0_up/mod_weight'].transpose(),
        r'synthesis\.b(\d+)\.conv0\.affine\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/Conv0_up/mod_bias'] + 1,
        r'synthesis\.b(\d+)\.conv1\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/Conv1/weight'].transpose(3, 2, 0, 1),
        r'synthesis\.b(\d+)\.conv1\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/Conv1/bias'],
        r'synthesis\.b(\d+)\.conv1\.noise_const', lambda r: tf_params[f'synthesis/noise{int(np.log2(int(r)))*2-4}'][0, 0],
        r'synthesis\.b(\d+)\.conv1\.noise_strength', lambda r: tf_params[f'synthesis/{r}x{r}/Conv1/noise_strength'],
        r'synthesis\.b(\d+)\.conv1\.affine\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/Conv1/mod_weight'].transpose(),
        r'synthesis\.b(\d+)\.conv1\.affine\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/Conv1/mod_bias'] + 1,
        r'synthesis\.b(\d+)\.torgb\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/ToRGB/weight'].transpose(3, 2, 0, 1),
        r'synthesis\.b(\d+)\.torgb\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/ToRGB/bias'],
        r'synthesis\.b(\d+)\.torgb\.affine\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/ToRGB/mod_weight'].transpose(),
        r'synthesis\.b(\d+)\.torgb\.affine\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/ToRGB/mod_bias'] + 1,
        r'synthesis\.b(\d+)\.skip\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/Skip/weight'][::-1, ::-1].transpose(3, 2, 0, 1),
        r'.*\.resample_filter', None,
    )
    return G

#----------------------------------------------------------------------------

def convert_tf_discriminator(tf_D):
    if tf_D.version < 4:
        raise ValueError('TensorFlow pickle version too low')

    # Collect kwargs.
    tf_kwargs = tf_D.static_kwargs
    known_kwargs = set()
    def kwarg(tf_name, default=None):
        known_kwargs.add(tf_name)
        return tf_kwargs.get(tf_name, default)

    # Convert kwargs.
    kwargs = dnnlib.EasyDict(
        c_dim = kwarg('label_size', 0),
        img_resolution = kwarg('resolution', 1024),
        img_channels = kwarg('num_channels', 3),
        architecture = kwarg('architecture', 'resnet'),
        channel_base = kwarg('fmap_base', 16384) * 2,
        channel_max = kwarg('fmap_max', 512),
        num_fp16_res = kwarg('num_fp16_res', 0),
        conv_clamp = kwarg('conv_clamp', None),
        cmap_dim = kwarg('mapping_fmaps', None),
        block_kwargs = dnnlib.EasyDict(
            activation = kwarg('nonlinearity', 'lrelu'),
            resample_filter = kwarg('resample_kernel', [1,3,3,1]),
            freeze_layers = kwarg('freeze_layers', 0),
        ),
        mapping_kwargs = dnnlib.EasyDict(
            num_layers = kwarg('mapping_layers', 0),
            embed_features = kwarg('mapping_fmaps', None),
            layer_features = kwarg('mapping_fmaps', None),
            activation = kwarg('nonlinearity', 'lrelu'),
            lr_multiplier = kwarg('mapping_lrmul', 0.1),
        ),
        epilogue_kwargs = dnnlib.EasyDict(
            mbstd_group_size = kwarg('mbstd_group_size', None),
            mbstd_num_channels = kwarg('mbstd_num_features', 1),
            activation = kwarg('nonlinearity', 'lrelu'),
        ),
        # !!! custom
        # init_res = kwarg('init_res', [4,4]),
    )

    # Check for unknown kwargs.
    kwarg('structure')
    unknown_kwargs = list(set(tf_kwargs.keys()) - known_kwargs)
    if len(unknown_kwargs) > 0:
        print('Unknown TensorFlow data! This may result in problems with your converted model.')
        print(unknown_kwargs)
        # originally this repo threw errors:
        # raise ValueError('Unknown TensorFlow kwarg', unknown_kwargs[0])

    # Collect params.
    tf_params = _collect_tf_params(tf_D)
    for name, value in list(tf_params.items()):
        match = re.fullmatch(r'FromRGB_lod(\d+)/(.*)', name)
        if match:
            r = kwargs.img_resolution // (2 ** int(match.group(1)))
            tf_params[f'{r}x{r}/FromRGB/{match.group(2)}'] = value
            kwargs.architecture = 'orig'
    #for name, value in tf_params.items(): print(f'{name:<50s}{list(value.shape)}')

    # Convert params.
    from training import networks
    D = networks.Discriminator(**kwargs).eval().requires_grad_(False)
    # pylint: disable=unnecessary-lambda
    _populate_module_params(D,
        r'b(\d+)\.fromrgb\.weight', lambda r: tf_params[f'{r}x{r}/FromRGB/weight'].transpose(3, 2, 0, 1),
        r'b(\d+)\.fromrgb\.bias', lambda r: tf_params[f'{r}x{r}/FromRGB/bias'],
        r'b(\d+)\.conv(\d+)\.weight', lambda r, i: tf_params[f'{r}x{r}/Conv{i}{["","_down"][int(i)]}/weight'].transpose(3, 2, 0, 1),
        r'b(\d+)\.conv(\d+)\.bias', lambda r, i: tf_params[f'{r}x{r}/Conv{i}{["","_down"][int(i)]}/bias'],
        r'b(\d+)\.skip\.weight', lambda r: tf_params[f'{r}x{r}/Skip/weight'].transpose(3, 2, 0, 1),
        r'mapping\.embed\.weight', lambda: tf_params[f'LabelEmbed/weight'].transpose(),
        r'mapping\.embed\.bias', lambda: tf_params[f'LabelEmbed/bias'],
        r'mapping\.fc(\d+)\.weight', lambda i: tf_params[f'Mapping{i}/weight'].transpose(),
        r'mapping\.fc(\d+)\.bias', lambda i: tf_params[f'Mapping{i}/bias'],
        r'b4\.conv\.weight', lambda: tf_params[f'4x4/Conv/weight'].transpose(3, 2, 0, 1),
        r'b4\.conv\.bias', lambda: tf_params[f'4x4/Conv/bias'],
        r'b4\.fc\.weight', lambda: tf_params[f'4x4/Dense0/weight'].transpose(),
        r'b4\.fc\.bias', lambda: tf_params[f'4x4/Dense0/bias'],
        r'b4\.out\.weight', lambda: tf_params[f'Output/weight'].transpose(),
        r'b4\.out\.bias', lambda: tf_params[f'Output/bias'],
        r'.*\.resample_filter', None,
    )
    return D

#----------------------------------------------------------------------------

@click.command()
@click.option('--source', help='Input pickle', required=True, metavar='PATH')
@click.option('--dest', help='Output pickle', required=True, metavar='PATH')
@click.option('--force-fp16', help='Force the networks to use FP16', type=bool, default=False, metavar='BOOL', show_default=True)
def convert_network_pickle(source, dest, force_fp16):
    """Convert legacy network pickle into the native PyTorch format.

    The tool is able to load the main network configurations exported using the TensorFlow version of StyleGAN2 or StyleGAN2-ADA.
    It does not support e.g. StyleGAN2-ADA comparison methods, StyleGAN2 configs A-D, or StyleGAN1 networks.

    Example:

    \b
    python legacy.py \\
        --source=https://nvlabs-fi-cdn.nvidia.com/stylegan2/networks/stylegan2-cat-config-f.pkl \\
        --dest=stylegan2-cat-config-f.pkl
    """
    print(f'Loading "{source}"...')
    with dnnlib.util.open_url(source) as f:
        data = load_network_pkl(f, force_fp16=force_fp16)
    print(f'Saving "{dest}"...')
    with open(dest, 'wb') as f:
        pickle.dump(data, f)
    print('Done.')

#----------------------------------------------------------------------------

if __name__ == "__main__":
    convert_network_pickle() # pylint: disable=no-value-for-parameter

#----------------------------------------------------------------------------
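The same conversion the CLI wraps can also be driven from Python. A minimal sketch, with placeholder file names:

import pickle
import dnnlib
import legacy

# Load a legacy TF pickle (or an already-converted one) and re-save it
# in the native PyTorch format; both paths are placeholders.
with dnnlib.util.open_url('network-snapshot-tf.pkl') as f:
    data = legacy.load_network_pkl(f, force_fp16=False)
with open('network-snapshot-pt.pkl', 'wb') as f:
    pickle.dump(data, f)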
model_build.py
ADDED
@@ -0,0 +1,95 @@
import os
import glob

import numpy as np
from numpy import linalg
import PIL.Image as Image
import torch
from torchvision import transforms
from tqdm import tqdm
from argparse import Namespace
import easydict

import legacy
import dnnlib

from opensimplex import OpenSimplex

from configs import data_configs
from models.psp import pSp


def build_stylegan2(
    increment = 0.01,
    network_pkl = 'pretrained/ohayou_face.pkl', # generator checkpoint shipped with this Space
    process = 'image', #['image', 'interpolation','truncation','interpolation-truncation']
    random_seed = 0,
    diameter = 100.0,
    scale_type = 'pad', #['pad', 'padside', 'symm','symmside']
    size = [512, 512],
    seeds = [0],
    space = 'z', #['z', 'w']
    fps = 24,
    frames = 240,
    noise_mode = 'none', #['const', 'random', 'none']
    outdir = 'path',
    projected_w = 'path',
    easing = 'linear',
    device = 'cpu'
):
    G_kwargs = dnnlib.EasyDict()
    G_kwargs.size = size
    G_kwargs.scale_type = scale_type

    device = torch.device(device)
    with dnnlib.util.open_url(network_pkl) as f:
        # G = legacy.load_network_pkl(f)['G_ema'].to(device) # type: ignore
        G = legacy.load_network_pkl(f, custom=True, **G_kwargs)['G_ema'].to(device) # type: ignore

    return G.synthesis


def build_psp():
    test_opts = easydict.EasyDict({
        # arguments for inference script
        'checkpoint_path' : 'pretrained/ohayou_face.pt', # pSp checkpoint shipped with this Space
        'couple_outputs' : False,
        'resize_outputs' : False,

        'test_batch_size' : 1,
        'test_workers' : 1,

        # arguments for style-mixing script
        'n_images' : None,
        'n_outputs_to_generate' : 5,
        'mix_alpha' : None,
        'latent_mask' : None,

        # arguments for super-resolution
        'resize_factors' : None,
    })

    # update test options with options used during training
    ckpt = torch.load(test_opts.checkpoint_path, map_location='cpu')
    opts = ckpt['opts']
    opts.update(vars(test_opts))
    if 'learn_in_w' not in opts:
        opts['learn_in_w'] = False
    opts = Namespace(**opts)
    opts.device = 'cpu'
    net = pSp(opts)
    net.eval()
    return net

def img_preprocess(img, transform):
    if (img.mode == 'RGBA') or (img.mode == 'P'):
        img = img.convert('RGBA')  # palette images may carry transparency; normalize first
        background = Image.new("RGB", img.size, (255, 255, 255))
        background.paste(img, mask=img.split()[3]) # 3 is the alpha channel
        img = background
    assert img.mode == 'RGB'
    img = transform(img)
    img = img.unsqueeze(dim=0)
    return img
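A minimal usage sketch for img_preprocess; the file name and transform below are illustrative assumptions, not values defined in this file:

from PIL import Image
from torchvision import transforms
from model_build import img_preprocess

to_tensor = transforms.Compose([transforms.Resize((256, 256)), transforms.ToTensor()])
batch = img_preprocess(Image.open('face.png'), to_tensor)  # RGBA/P inputs get a white background
print(batch.shape)  # torch.Size([1, 3, 256, 256])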
models/__init__.py
ADDED
File without changes
models/encoders/__init__.py
ADDED
File without changes
models/encoders/helpers.py
ADDED
@@ -0,0 +1,119 @@
from collections import namedtuple
import torch
from torch.nn import Conv2d, BatchNorm2d, PReLU, ReLU, Sigmoid, MaxPool2d, AdaptiveAvgPool2d, Sequential, Module

"""
ArcFace implementation from [TreB1eN](https://github.com/TreB1eN/InsightFace_Pytorch)
"""


class Flatten(Module):
    def forward(self, input):
        return input.view(input.size(0), -1)


def l2_norm(input, axis=1):
    norm = torch.norm(input, 2, axis, True)
    output = torch.div(input, norm)
    return output


class Bottleneck(namedtuple('Block', ['in_channel', 'depth', 'stride'])):
    """ A named tuple describing a ResNet block. """


def get_block(in_channel, depth, num_units, stride=2):
    return [Bottleneck(in_channel, depth, stride)] + [Bottleneck(depth, depth, 1) for i in range(num_units - 1)]


def get_blocks(num_layers):
    if num_layers == 50:
        blocks = [
            get_block(in_channel=64, depth=64, num_units=3),
            get_block(in_channel=64, depth=128, num_units=4),
            get_block(in_channel=128, depth=256, num_units=14),
            get_block(in_channel=256, depth=512, num_units=3)
        ]
    elif num_layers == 100:
        blocks = [
            get_block(in_channel=64, depth=64, num_units=3),
            get_block(in_channel=64, depth=128, num_units=13),
            get_block(in_channel=128, depth=256, num_units=30),
            get_block(in_channel=256, depth=512, num_units=3)
        ]
    elif num_layers == 152:
        blocks = [
            get_block(in_channel=64, depth=64, num_units=3),
            get_block(in_channel=64, depth=128, num_units=8),
            get_block(in_channel=128, depth=256, num_units=36),
            get_block(in_channel=256, depth=512, num_units=3)
        ]
    else:
        raise ValueError("Invalid number of layers: {}. Must be one of [50, 100, 152]".format(num_layers))
    return blocks


class SEModule(Module):
    def __init__(self, channels, reduction):
        super(SEModule, self).__init__()
        self.avg_pool = AdaptiveAvgPool2d(1)
        self.fc1 = Conv2d(channels, channels // reduction, kernel_size=1, padding=0, bias=False)
        self.relu = ReLU(inplace=True)
        self.fc2 = Conv2d(channels // reduction, channels, kernel_size=1, padding=0, bias=False)
        self.sigmoid = Sigmoid()

    def forward(self, x):
        module_input = x
        x = self.avg_pool(x)
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        x = self.sigmoid(x)
        return module_input * x


class bottleneck_IR(Module):
    def __init__(self, in_channel, depth, stride):
        super(bottleneck_IR, self).__init__()
        if in_channel == depth:
            self.shortcut_layer = MaxPool2d(1, stride)
        else:
            self.shortcut_layer = Sequential(
                Conv2d(in_channel, depth, (1, 1), stride, bias=False),
                BatchNorm2d(depth)
            )
        self.res_layer = Sequential(
            BatchNorm2d(in_channel),
            Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False), PReLU(depth),
            Conv2d(depth, depth, (3, 3), stride, 1, bias=False), BatchNorm2d(depth)
        )

    def forward(self, x):
        shortcut = self.shortcut_layer(x)
        res = self.res_layer(x)
        return res + shortcut


class bottleneck_IR_SE(Module):
    def __init__(self, in_channel, depth, stride):
        super(bottleneck_IR_SE, self).__init__()
        if in_channel == depth:
            self.shortcut_layer = MaxPool2d(1, stride)
        else:
            self.shortcut_layer = Sequential(
                Conv2d(in_channel, depth, (1, 1), stride, bias=False),
                BatchNorm2d(depth)
            )
        self.res_layer = Sequential(
            BatchNorm2d(in_channel),
            Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False),
            PReLU(depth),
            Conv2d(depth, depth, (3, 3), stride, 1, bias=False),
            BatchNorm2d(depth),
            SEModule(depth, 16)
        )

    def forward(self, x):
        shortcut = self.shortcut_layer(x)
        res = self.res_layer(x)
        return res + shortcut
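A quick shape check for the squeeze-and-excitation block above (a sketch; assumes this repo is on the import path):

import torch
from models.encoders.helpers import SEModule, get_blocks

se = SEModule(channels=64, reduction=16)
x = torch.randn(2, 64, 32, 32)
assert se(x).shape == x.shape  # channel-wise re-weighting keeps the shape
print(sum(len(block) for block in get_blocks(50)))  # 24 bottleneck units in ir-50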
models/encoders/model_irse.py
ADDED
@@ -0,0 +1,84 @@
from torch.nn import Linear, Conv2d, BatchNorm1d, BatchNorm2d, PReLU, Dropout, Sequential, Module
from models.encoders.helpers import get_blocks, Flatten, bottleneck_IR, bottleneck_IR_SE, l2_norm

"""
Modified Backbone implementation from [TreB1eN](https://github.com/TreB1eN/InsightFace_Pytorch)
"""


class Backbone(Module):
    def __init__(self, input_size, num_layers, mode='ir', drop_ratio=0.4, affine=True):
        super(Backbone, self).__init__()
        assert input_size in [112, 224], "input_size should be 112 or 224"
        assert num_layers in [50, 100, 152], "num_layers should be 50, 100 or 152"
        assert mode in ['ir', 'ir_se'], "mode should be ir or ir_se"
        blocks = get_blocks(num_layers)
        if mode == 'ir':
            unit_module = bottleneck_IR
        elif mode == 'ir_se':
            unit_module = bottleneck_IR_SE
        self.input_layer = Sequential(Conv2d(3, 64, (3, 3), 1, 1, bias=False),
                                      BatchNorm2d(64),
                                      PReLU(64))
        if input_size == 112:
            self.output_layer = Sequential(BatchNorm2d(512),
                                           Dropout(drop_ratio),
                                           Flatten(),
                                           Linear(512 * 7 * 7, 512),
                                           BatchNorm1d(512, affine=affine))
        else:
            self.output_layer = Sequential(BatchNorm2d(512),
                                           Dropout(drop_ratio),
                                           Flatten(),
                                           Linear(512 * 14 * 14, 512),
                                           BatchNorm1d(512, affine=affine))

        modules = []
        for block in blocks:
            for bottleneck in block:
                modules.append(unit_module(bottleneck.in_channel,
                                           bottleneck.depth,
                                           bottleneck.stride))
        self.body = Sequential(*modules)

    def forward(self, x):
        x = self.input_layer(x)
        x = self.body(x)
        x = self.output_layer(x)
        return l2_norm(x)


def IR_50(input_size):
    """Constructs a ir-50 model."""
    model = Backbone(input_size, num_layers=50, mode='ir', drop_ratio=0.4, affine=False)
    return model


def IR_101(input_size):
    """Constructs a ir-101 model."""
    model = Backbone(input_size, num_layers=100, mode='ir', drop_ratio=0.4, affine=False)
    return model


def IR_152(input_size):
    """Constructs a ir-152 model."""
    model = Backbone(input_size, num_layers=152, mode='ir', drop_ratio=0.4, affine=False)
    return model


def IR_SE_50(input_size):
    """Constructs a ir_se-50 model."""
    model = Backbone(input_size, num_layers=50, mode='ir_se', drop_ratio=0.4, affine=False)
    return model


def IR_SE_101(input_size):
    """Constructs a ir_se-101 model."""
    model = Backbone(input_size, num_layers=100, mode='ir_se', drop_ratio=0.4, affine=False)
    return model


def IR_SE_152(input_size):
    """Constructs a ir_se-152 model."""
    model = Backbone(input_size, num_layers=152, mode='ir_se', drop_ratio=0.4, affine=False)
    return model
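A sketch of embedding dummy 112x112 crops with the IR-SE-50 constructor above (weights are randomly initialized here; in practice a pretrained ArcFace checkpoint would be loaded first):

import torch
from models.encoders.model_irse import IR_SE_50

net = IR_SE_50(input_size=112).eval()
with torch.no_grad():
    emb = net(torch.randn(4, 3, 112, 112))
print(emb.shape)        # torch.Size([4, 512])
print(emb.norm(dim=1))  # ~1.0 per row, courtesy of l2_norm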
models/encoders/psp_encoders.py
ADDED
@@ -0,0 +1,186 @@
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from torch.nn import Linear, Conv2d, BatchNorm2d, PReLU, Sequential, Module

from models.encoders.helpers import get_blocks, Flatten, bottleneck_IR, bottleneck_IR_SE
from models.stylegan2.model import EqualLinear


class GradualStyleBlock(Module):
    def __init__(self, in_c, out_c, spatial):
        super(GradualStyleBlock, self).__init__()
        self.out_c = out_c
        self.spatial = spatial
        num_pools = int(np.log2(spatial))
        modules = []
        modules += [Conv2d(in_c, out_c, kernel_size=3, stride=2, padding=1),
                    nn.LeakyReLU()]
        for i in range(num_pools - 1):
            modules += [
                Conv2d(out_c, out_c, kernel_size=3, stride=2, padding=1),
                nn.LeakyReLU()
            ]
        self.convs = nn.Sequential(*modules)
        self.linear = EqualLinear(out_c, out_c, lr_mul=1)

    def forward(self, x):
        x = self.convs(x)
        x = x.view(-1, self.out_c)
        x = self.linear(x)
        return x


class GradualStyleEncoder(Module):
    def __init__(self, num_layers, mode='ir', opts=None):
        super(GradualStyleEncoder, self).__init__()
        assert num_layers in [50, 100, 152], 'num_layers should be 50,100, or 152'
        assert mode in ['ir', 'ir_se'], 'mode should be ir or ir_se'
        blocks = get_blocks(num_layers)
        if mode == 'ir':
            unit_module = bottleneck_IR
        elif mode == 'ir_se':
            unit_module = bottleneck_IR_SE
        self.input_layer = Sequential(Conv2d(opts.input_nc, 64, (3, 3), 1, 1, bias=False),
                                      BatchNorm2d(64),
                                      PReLU(64))
        modules = []
        for block in blocks:
            for bottleneck in block:
                modules.append(unit_module(bottleneck.in_channel,
                                           bottleneck.depth,
                                           bottleneck.stride))
        self.body = Sequential(*modules)

        self.styles = nn.ModuleList()
        self.style_count = opts.n_styles
        self.coarse_ind = 3
        self.middle_ind = 7
        for i in range(self.style_count):
            if i < self.coarse_ind:
                style = GradualStyleBlock(512, 512, 16)
            elif i < self.middle_ind:
                style = GradualStyleBlock(512, 512, 32)
            else:
                style = GradualStyleBlock(512, 512, 64)
            self.styles.append(style)
        self.latlayer1 = nn.Conv2d(256, 512, kernel_size=1, stride=1, padding=0)
        self.latlayer2 = nn.Conv2d(128, 512, kernel_size=1, stride=1, padding=0)

    def _upsample_add(self, x, y):
        '''Upsample and add two feature maps.
        Args:
          x: (Variable) top feature map to be upsampled.
          y: (Variable) lateral feature map.
        Returns:
          (Variable) added feature map.
        Note in PyTorch, when input size is odd, the upsampled feature map
        with `F.upsample(..., scale_factor=2, mode='nearest')`
        maybe not equal to the lateral feature map size.
        e.g.
        original input size: [N,_,15,15] ->
        conv2d feature map size: [N,_,8,8] ->
        upsampled feature map size: [N,_,16,16]
        So we choose bilinear upsample which supports arbitrary output sizes.
        '''
        _, _, H, W = y.size()
        return F.interpolate(x, size=(H, W), mode='bilinear', align_corners=True) + y

    def forward(self, x):
        x = self.input_layer(x)

        latents = []
        modulelist = list(self.body._modules.values())
        for i, l in enumerate(modulelist):
            x = l(x)
            if i == 6:
                c1 = x
            elif i == 20:
                c2 = x
            elif i == 23:
                c3 = x

        for j in range(self.coarse_ind):
            latents.append(self.styles[j](c3))

        p2 = self._upsample_add(c3, self.latlayer1(c2))
        for j in range(self.coarse_ind, self.middle_ind):
            latents.append(self.styles[j](p2))

        p1 = self._upsample_add(p2, self.latlayer2(c1))
        for j in range(self.middle_ind, self.style_count):
            latents.append(self.styles[j](p1))

        out = torch.stack(latents, dim=1)
        return out


class BackboneEncoderUsingLastLayerIntoW(Module):
    def __init__(self, num_layers, mode='ir', opts=None):
        super(BackboneEncoderUsingLastLayerIntoW, self).__init__()
        print('Using BackboneEncoderUsingLastLayerIntoW')
        assert num_layers in [50, 100, 152], 'num_layers should be 50,100, or 152'
        assert mode in ['ir', 'ir_se'], 'mode should be ir or ir_se'
        blocks = get_blocks(num_layers)
        if mode == 'ir':
            unit_module = bottleneck_IR
        elif mode == 'ir_se':
            unit_module = bottleneck_IR_SE
        self.input_layer = Sequential(Conv2d(opts.input_nc, 64, (3, 3), 1, 1, bias=False),
                                      BatchNorm2d(64),
                                      PReLU(64))
        self.output_pool = torch.nn.AdaptiveAvgPool2d((1, 1))
        self.linear = EqualLinear(512, 512, lr_mul=1)
        modules = []
        for block in blocks:
            for bottleneck in block:
                modules.append(unit_module(bottleneck.in_channel,
                                           bottleneck.depth,
                                           bottleneck.stride))
        self.body = Sequential(*modules)

    def forward(self, x):
        x = self.input_layer(x)
        x = self.body(x)
        x = self.output_pool(x)
        x = x.view(-1, 512)
        x = self.linear(x)
        return x


class BackboneEncoderUsingLastLayerIntoWPlus(Module):
    def __init__(self, num_layers, mode='ir', opts=None):
        super(BackboneEncoderUsingLastLayerIntoWPlus, self).__init__()
        print('Using BackboneEncoderUsingLastLayerIntoWPlus')
        assert num_layers in [50, 100, 152], 'num_layers should be 50,100, or 152'
        assert mode in ['ir', 'ir_se'], 'mode should be ir or ir_se'
        blocks = get_blocks(num_layers)
        if mode == 'ir':
            unit_module = bottleneck_IR
        elif mode == 'ir_se':
            unit_module = bottleneck_IR_SE
        self.n_styles = opts.n_styles
        self.input_layer = Sequential(Conv2d(opts.input_nc, 64, (3, 3), 1, 1, bias=False),
                                      BatchNorm2d(64),
                                      PReLU(64))
        self.output_layer_2 = Sequential(BatchNorm2d(512),
                                         torch.nn.AdaptiveAvgPool2d((7, 7)),
                                         Flatten(),
                                         Linear(512 * 7 * 7, 512))
        self.linear = EqualLinear(512, 512 * self.n_styles, lr_mul=1)
        modules = []
        for block in blocks:
            for bottleneck in block:
                modules.append(unit_module(bottleneck.in_channel,
                                           bottleneck.depth,
                                           bottleneck.stride))
        self.body = Sequential(*modules)

    def forward(self, x):
        x = self.input_layer(x)
        x = self.body(x)
        x = self.output_layer_2(x)
        x = self.linear(x)
        x = x.view(-1, self.n_styles, 512)
        return x
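A sketch of instantiating the feature-pyramid encoder above. input_nc and n_styles are the only opts fields the class reads; the values below are illustrative (18 styles matches a 1024px StyleGAN2):

import easydict
import torch
from models.encoders.psp_encoders import GradualStyleEncoder

opts = easydict.EasyDict({'input_nc': 3, 'n_styles': 18})
encoder = GradualStyleEncoder(50, 'ir_se', opts).eval()
with torch.no_grad():
    w_plus = encoder(torch.randn(1, 3, 256, 256))
print(w_plus.shape)  # torch.Size([1, 18, 512]): one style vector per generator layer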
models/mtcnn/__init__.py
ADDED
File without changes
models/mtcnn/mtcnn.py
ADDED
@@ -0,0 +1,156 @@
import numpy as np
import torch
from PIL import Image
from models.mtcnn.mtcnn_pytorch.src.get_nets import PNet, RNet, ONet
from models.mtcnn.mtcnn_pytorch.src.box_utils import nms, calibrate_box, get_image_boxes, convert_to_square
from models.mtcnn.mtcnn_pytorch.src.first_stage import run_first_stage
from models.mtcnn.mtcnn_pytorch.src.align_trans import get_reference_facial_points, warp_and_crop_face

device = 'cuda:0'


class MTCNN():
    def __init__(self):
        print(device)
        self.pnet = PNet().to(device)
        self.rnet = RNet().to(device)
        self.onet = ONet().to(device)
        self.pnet.eval()
        self.rnet.eval()
        self.onet.eval()
        self.refrence = get_reference_facial_points(default_square=True)

    def align(self, img):
        _, landmarks = self.detect_faces(img)
        if len(landmarks) == 0:
            return None, None
        facial5points = [[landmarks[0][j], landmarks[0][j + 5]] for j in range(5)]
        warped_face, tfm = warp_and_crop_face(np.array(img), facial5points, self.refrence, crop_size=(112, 112))
        return Image.fromarray(warped_face), tfm

    def align_multi(self, img, limit=None, min_face_size=30.0):
        boxes, landmarks = self.detect_faces(img, min_face_size)
        if limit:
            boxes = boxes[:limit]
            landmarks = landmarks[:limit]
        faces = []
        tfms = []
        for landmark in landmarks:
            facial5points = [[landmark[j], landmark[j + 5]] for j in range(5)]
            warped_face, tfm = warp_and_crop_face(np.array(img), facial5points, self.refrence, crop_size=(112, 112))
            faces.append(Image.fromarray(warped_face))
            tfms.append(tfm)
        return boxes, faces, tfms

    def detect_faces(self, image, min_face_size=20.0,
                     thresholds=[0.15, 0.25, 0.35],
                     nms_thresholds=[0.7, 0.7, 0.7]):
        """
        Arguments:
            image: an instance of PIL.Image.
            min_face_size: a float number.
            thresholds: a list of length 3.
            nms_thresholds: a list of length 3.

        Returns:
            two float numpy arrays of shapes [n_boxes, 4] and [n_boxes, 10],
            bounding boxes and facial landmarks.
        """

        # BUILD AN IMAGE PYRAMID
        width, height = image.size
        min_length = min(height, width)

        min_detection_size = 12
        factor = 0.707  # sqrt(0.5)

        # scales for scaling the image
        scales = []

        # scales the image so that
        # minimum size that we can detect equals to
        # minimum face size that we want to detect
        m = min_detection_size / min_face_size
        min_length *= m

        factor_count = 0
        while min_length > min_detection_size:
            scales.append(m * factor ** factor_count)
            min_length *= factor
            factor_count += 1

        # STAGE 1

        # it will be returned
        bounding_boxes = []

        with torch.no_grad():
            # run P-Net on different scales
            for s in scales:
                boxes = run_first_stage(image, self.pnet, scale=s, threshold=thresholds[0])
                bounding_boxes.append(boxes)

            # collect boxes (and offsets, and scores) from different scales
            bounding_boxes = [i for i in bounding_boxes if i is not None]
            bounding_boxes = np.vstack(bounding_boxes)

            keep = nms(bounding_boxes[:, 0:5], nms_thresholds[0])
            bounding_boxes = bounding_boxes[keep]

            # use offsets predicted by pnet to transform bounding boxes
            bounding_boxes = calibrate_box(bounding_boxes[:, 0:5], bounding_boxes[:, 5:])
            # shape [n_boxes, 5]

            bounding_boxes = convert_to_square(bounding_boxes)
            bounding_boxes[:, 0:4] = np.round(bounding_boxes[:, 0:4])

            # STAGE 2

            img_boxes = get_image_boxes(bounding_boxes, image, size=24)
            img_boxes = torch.FloatTensor(img_boxes).to(device)

            output = self.rnet(img_boxes)
            offsets = output[0].cpu().data.numpy()  # shape [n_boxes, 4]
            probs = output[1].cpu().data.numpy()  # shape [n_boxes, 2]

            keep = np.where(probs[:, 1] > thresholds[1])[0]
            bounding_boxes = bounding_boxes[keep]
            bounding_boxes[:, 4] = probs[keep, 1].reshape((-1,))
            offsets = offsets[keep]

            keep = nms(bounding_boxes, nms_thresholds[1])
            bounding_boxes = bounding_boxes[keep]
            bounding_boxes = calibrate_box(bounding_boxes, offsets[keep])
            bounding_boxes = convert_to_square(bounding_boxes)
            bounding_boxes[:, 0:4] = np.round(bounding_boxes[:, 0:4])

            # STAGE 3

            img_boxes = get_image_boxes(bounding_boxes, image, size=48)
            if len(img_boxes) == 0:
                return [], []
            img_boxes = torch.FloatTensor(img_boxes).to(device)
            output = self.onet(img_boxes)
            landmarks = output[0].cpu().data.numpy()  # shape [n_boxes, 10]
            offsets = output[1].cpu().data.numpy()  # shape [n_boxes, 4]
            probs = output[2].cpu().data.numpy()  # shape [n_boxes, 2]

            keep = np.where(probs[:, 1] > thresholds[2])[0]
            bounding_boxes = bounding_boxes[keep]
            bounding_boxes[:, 4] = probs[keep, 1].reshape((-1,))
            offsets = offsets[keep]
            landmarks = landmarks[keep]

            # compute landmark points
            width = bounding_boxes[:, 2] - bounding_boxes[:, 0] + 1.0
            height = bounding_boxes[:, 3] - bounding_boxes[:, 1] + 1.0
            xmin, ymin = bounding_boxes[:, 0], bounding_boxes[:, 1]
            landmarks[:, 0:5] = np.expand_dims(xmin, 1) + np.expand_dims(width, 1) * landmarks[:, 0:5]
            landmarks[:, 5:10] = np.expand_dims(ymin, 1) + np.expand_dims(height, 1) * landmarks[:, 5:10]

            bounding_boxes = calibrate_box(bounding_boxes, offsets)
            keep = nms(bounding_boxes, nms_thresholds[2], mode='min')
            bounding_boxes = bounding_boxes[keep]
            landmarks = landmarks[keep]

        return bounding_boxes, landmarks
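A usage sketch for the three-stage detector above. Note that this module pins device to 'cuda:0', so as written it needs a GPU, and it assumes the P/R/O-Net weight files loaded by get_nets are present; the image path is a placeholder:

from PIL import Image
from models.mtcnn.mtcnn import MTCNN

mtcnn = MTCNN()
img = Image.open('photo.jpg').convert('RGB')
boxes, landmarks = mtcnn.detect_faces(img)  # scored boxes and 5-point landmarks
face, tfm = mtcnn.align(img)                # 112x112 crop plus the affine used
if face is not None:
    face.save('aligned.jpg')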
models/mtcnn/mtcnn_pytorch/__init__.py
ADDED
File without changes
models/mtcnn/mtcnn_pytorch/src/__init__.py
ADDED
@@ -0,0 +1,2 @@
from .visualization_utils import show_bboxes
from .detector import detect_faces
models/mtcnn/mtcnn_pytorch/src/align_trans.py
ADDED
@@ -0,0 +1,304 @@
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
"""
|
3 |
+
Created on Mon Apr 24 15:43:29 2017
|
4 |
+
@author: zhaoy
|
5 |
+
"""
|
6 |
+
import numpy as np
|
7 |
+
import cv2
|
8 |
+
|
9 |
+
# from scipy.linalg import lstsq
|
10 |
+
# from scipy.ndimage import geometric_transform # , map_coordinates
|
11 |
+
|
12 |
+
from models.mtcnn.mtcnn_pytorch.src.matlab_cp2tform import get_similarity_transform_for_cv2
|
13 |
+
|
14 |
+
# reference facial points, a list of coordinates (x,y)
|
15 |
+
REFERENCE_FACIAL_POINTS = [
|
16 |
+
[30.29459953, 51.69630051],
|
17 |
+
[65.53179932, 51.50139999],
|
18 |
+
[48.02519989, 71.73660278],
|
19 |
+
[33.54930115, 92.3655014],
|
20 |
+
[62.72990036, 92.20410156]
|
21 |
+
]
|
22 |
+
|
23 |
+
DEFAULT_CROP_SIZE = (96, 112)
|
24 |
+
|
25 |
+
|
26 |
+
class FaceWarpException(Exception):
|
27 |
+
def __str__(self):
|
28 |
+
return 'In File {}:{}'.format(
|
29 |
+
__file__, super.__str__(self))
|
30 |
+
|
31 |
+
|
32 |
+
def get_reference_facial_points(output_size=None,
|
33 |
+
inner_padding_factor=0.0,
|
34 |
+
outer_padding=(0, 0),
|
35 |
+
default_square=False):
|
36 |
+
"""
|
37 |
+
Function:
|
38 |
+
----------
|
39 |
+
get reference 5 key points according to crop settings:
|
40 |
+
0. Set default crop_size:
|
41 |
+
if default_square:
|
42 |
+
crop_size = (112, 112)
|
43 |
+
else:
|
44 |
+
crop_size = (96, 112)
|
45 |
+
1. Pad the crop_size by inner_padding_factor in each side;
|
46 |
+
2. Resize crop_size into (output_size - outer_padding*2),
|
47 |
+
pad into output_size with outer_padding;
|
48 |
+
3. Output reference_5point;
|
49 |
+
Parameters:
|
50 |
+
----------
|
51 |
+
@output_size: (w, h) or None
|
52 |
+
size of aligned face image
|
53 |
+
@inner_padding_factor: (w_factor, h_factor)
|
54 |
+
padding factor for inner (w, h)
|
55 |
+
@outer_padding: (w_pad, h_pad)
|
56 |
+
each row is a pair of coordinates (x, y)
|
57 |
+
@default_square: True or False
|
58 |
+
if True:
|
59 |
+
default crop_size = (112, 112)
|
60 |
+
else:
|
61 |
+
default crop_size = (96, 112);
|
62 |
+
!!! make sure, if output_size is not None:
|
63 |
+
(output_size - outer_padding)
|
64 |
+
= some_scale * (default crop_size * (1.0 + inner_padding_factor))
|
65 |
+
Returns:
|
66 |
+
----------
|
67 |
+
@reference_5point: 5x2 np.array
|
68 |
+
each row is a pair of transformed coordinates (x, y)
|
69 |
+
"""
|
70 |
+
# print('\n===> get_reference_facial_points():')
|
71 |
+
|
72 |
+
# print('---> Params:')
|
73 |
+
# print(' output_size: ', output_size)
|
74 |
+
# print(' inner_padding_factor: ', inner_padding_factor)
|
75 |
+
# print(' outer_padding:', outer_padding)
|
76 |
+
# print(' default_square: ', default_square)
|
77 |
+
|
78 |
+
tmp_5pts = np.array(REFERENCE_FACIAL_POINTS)
|
79 |
+
tmp_crop_size = np.array(DEFAULT_CROP_SIZE)
|
80 |
+
|
81 |
+
# 0) make the inner region a square
|
82 |
+
if default_square:
|
83 |
+
size_diff = max(tmp_crop_size) - tmp_crop_size
|
84 |
+
tmp_5pts += size_diff / 2
|
85 |
+
tmp_crop_size += size_diff
|
86 |
+
|
87 |
+
# print('---> default:')
|
88 |
+
# print(' crop_size = ', tmp_crop_size)
|
89 |
+
# print(' reference_5pts = ', tmp_5pts)
|
90 |
+
|
91 |
+
if (output_size and
|
92 |
+
output_size[0] == tmp_crop_size[0] and
|
93 |
+
output_size[1] == tmp_crop_size[1]):
|
94 |
+
# print('output_size == DEFAULT_CROP_SIZE {}: return default reference points'.format(tmp_crop_size))
|
95 |
+
return tmp_5pts
|
96 |
+
|
97 |
+
if (inner_padding_factor == 0 and
|
98 |
+
outer_padding == (0, 0)):
|
99 |
+
if output_size is None:
|
100 |
+
# print('No paddings to do: return default reference points')
|
101 |
+
return tmp_5pts
|
102 |
+
else:
|
103 |
+
raise FaceWarpException(
|
104 |
+
'No paddings to do, output_size must be None or {}'.format(tmp_crop_size))
|
105 |
+
|
106 |
+
# check output size
|
107 |
+
if not (0 <= inner_padding_factor <= 1.0):
|
108 |
+
raise FaceWarpException('Not (0 <= inner_padding_factor <= 1.0)')
|
109 |
+
|
110 |
+
if ((inner_padding_factor > 0 or outer_padding[0] > 0 or outer_padding[1] > 0)
|
111 |
+
and output_size is None):
|
112 |
+
output_size = tmp_crop_size * \
|
113 |
+
(1 + inner_padding_factor * 2).astype(np.int32)
|
114 |
+
output_size += np.array(outer_padding)
|
115 |
+
# print(' deduced from paddings, output_size = ', output_size)
|
116 |
+
|
117 |
+
if not (outer_padding[0] < output_size[0]
|
118 |
+
and outer_padding[1] < output_size[1]):
|
119 |
+
raise FaceWarpException('Not (outer_padding[0] < output_size[0]'
|
120 |
+
'and outer_padding[1] < output_size[1])')
|
121 |
+
|
122 |
+
# 1) pad the inner region according inner_padding_factor
|
123 |
+
# print('---> STEP1: pad the inner region according inner_padding_factor')
|
124 |
+
if inner_padding_factor > 0:
|
125 |
+
size_diff = tmp_crop_size * inner_padding_factor * 2
|
126 |
+
tmp_5pts += size_diff / 2
|
127 |
+
tmp_crop_size += np.round(size_diff).astype(np.int32)
|
128 |
+
|
129 |
+
# print(' crop_size = ', tmp_crop_size)
|
130 |
+
# print(' reference_5pts = ', tmp_5pts)
|

    # 2) resize the padded inner region
    size_bf_outer_pad = np.array(output_size) - np.array(outer_padding) * 2

    if size_bf_outer_pad[0] * tmp_crop_size[1] != size_bf_outer_pad[1] * tmp_crop_size[0]:
        raise FaceWarpException('Must have (output_size - outer_padding) '
                                '= some_scale * (crop_size * (1.0 + inner_padding_factor))')

    scale_factor = size_bf_outer_pad[0].astype(np.float32) / tmp_crop_size[0]
    tmp_5pts = tmp_5pts * scale_factor
    tmp_crop_size = size_bf_outer_pad

    # 3) add outer_padding to make output_size
    reference_5point = tmp_5pts + np.array(outer_padding)
    tmp_crop_size = output_size

    return reference_5point


def get_affine_transform_matrix(src_pts, dst_pts):
    """
    Function:
    ----------
    get affine transform matrix 'tfm' from src_pts to dst_pts

    Parameters:
    ----------
    @src_pts: Kx2 np.array
        source points matrix, each row is a pair of coordinates (x, y)
    @dst_pts: Kx2 np.array
        destination points matrix, each row is a pair of coordinates (x, y)

    Returns:
    ----------
    @tfm: 2x3 np.array
        transform matrix from src_pts to dst_pts
    """

    tfm = np.float32([[1, 0, 0], [0, 1, 0]])
    n_pts = src_pts.shape[0]
    ones = np.ones((n_pts, 1), src_pts.dtype)
    src_pts_ = np.hstack([src_pts, ones])
    dst_pts_ = np.hstack([dst_pts, ones])

    # least-squares fit of the augmented transform; rcond=None silences the
    # NumPy FutureWarning and keeps the current default behavior
    A, res, rank, s = np.linalg.lstsq(src_pts_, dst_pts_, rcond=None)

    if rank == 3:
        tfm = np.float32([
            [A[0, 0], A[1, 0], A[2, 0]],
            [A[0, 1], A[1, 1], A[2, 1]]
        ])
    elif rank == 2:
        tfm = np.float32([
            [A[0, 0], A[1, 0], 0],
            [A[0, 1], A[1, 1], 0]
        ])

    return tfm


def warp_and_crop_face(src_img,
                       facial_pts,
                       reference_pts=None,
                       crop_size=(96, 112),
                       align_type='similarity'):
    """
    Function:
    ----------
    apply an affine/similarity transform to src_img and crop out the face

    Parameters:
    ----------
    @src_img: HxWx3 np.array
        input image
    @facial_pts: could be
        1) a list of K coordinates (x, y)
        or
        2) Kx2 or 2xK np.array
            each row or col is a pair of coordinates (x, y)
    @reference_pts: could be
        1) a list of K coordinates (x, y)
        or
        2) Kx2 or 2xK np.array
            each row or col is a pair of coordinates (x, y)
        or
        3) None
            if None, use default reference facial points
    @crop_size: (w, h)
        output face image size
    @align_type: transform type, could be one of
        1) 'similarity': use similarity transform
        2) 'cv2_affine': use the first 3 points to do affine transform,
            by calling cv2.getAffineTransform()
        3) 'affine': use all points to do affine transform

    Returns:
    ----------
    @face_img: output face image with size (w, h) = @crop_size
    @tfm: the 2x3 transform matrix that was applied
    """

    if reference_pts is None:
        if crop_size[0] == 96 and crop_size[1] == 112:
            reference_pts = REFERENCE_FACIAL_POINTS
        else:
            default_square = False
            inner_padding_factor = 0
            outer_padding = (0, 0)
            output_size = crop_size

            reference_pts = get_reference_facial_points(output_size,
                                                        inner_padding_factor,
                                                        outer_padding,
                                                        default_square)

    ref_pts = np.float32(reference_pts)
    ref_pts_shp = ref_pts.shape
    if max(ref_pts_shp) < 3 or min(ref_pts_shp) != 2:
        raise FaceWarpException(
            'reference_pts.shape must be (K,2) or (2,K) and K>2')

    if ref_pts_shp[0] == 2:
        ref_pts = ref_pts.T

    src_pts = np.float32(facial_pts)
    src_pts_shp = src_pts.shape
    if max(src_pts_shp) < 3 or min(src_pts_shp) != 2:
        raise FaceWarpException(
            'facial_pts.shape must be (K,2) or (2,K) and K>2')

    if src_pts_shp[0] == 2:
        src_pts = src_pts.T

    if src_pts.shape != ref_pts.shape:
        raise FaceWarpException(
            'facial_pts and reference_pts must have the same shape')

    # string equality, not identity: the original used 'is', which only
    # happens to work because of small-string interning
    if align_type == 'cv2_affine':
        tfm = cv2.getAffineTransform(src_pts[0:3], ref_pts[0:3])
    elif align_type == 'affine':
        tfm = get_affine_transform_matrix(src_pts, ref_pts)
    else:
        tfm = get_similarity_transform_for_cv2(src_pts, ref_pts)

    face_img = cv2.warpAffine(src_img, tfm, (crop_size[0], crop_size[1]))

    return face_img, tfm
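A minimal usage sketch for warp_and_crop_face (not part of the commit): the input path is a placeholder, and the landmark reshaping assumes the [x1..x5, y1..y5] row layout returned by the detector defined later in this commit.

import numpy as np
from PIL import Image

img = Image.open('face.jpg').convert('RGB')      # hypothetical input image
_, landmarks = detect_faces(img)                 # rows of shape [10]: x1..x5, y1..y5
points = landmarks[0].reshape(2, 5).T            # -> 5x2 array of (x, y) pairs
face_img, tfm = warp_and_crop_face(np.array(img), points,
                                   crop_size=(112, 112),
                                   align_type='similarity')
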
models/mtcnn/mtcnn_pytorch/src/box_utils.py
ADDED
@@ -0,0 +1,238 @@
import numpy as np
from PIL import Image


def nms(boxes, overlap_threshold=0.5, mode='union'):
    """Non-maximum suppression.

    Arguments:
        boxes: a float numpy array of shape [n, 5],
            where each row is (xmin, ymin, xmax, ymax, score).
        overlap_threshold: a float number.
        mode: 'union' or 'min'.

    Returns:
        list with indices of the selected boxes
    """

    # if there are no boxes, return the empty list
    if len(boxes) == 0:
        return []

    # list of picked indices
    pick = []

    # grab the coordinates of the bounding boxes
    x1, y1, x2, y2, score = [boxes[:, i] for i in range(5)]

    area = (x2 - x1 + 1.0) * (y2 - y1 + 1.0)
    ids = np.argsort(score)  # in increasing order

    while len(ids) > 0:

        # grab index of the largest value
        last = len(ids) - 1
        i = ids[last]
        pick.append(i)

        # compute intersections
        # of the box with the largest score
        # with the rest of boxes

        # left top corner of intersection boxes
        ix1 = np.maximum(x1[i], x1[ids[:last]])
        iy1 = np.maximum(y1[i], y1[ids[:last]])

        # right bottom corner of intersection boxes
        ix2 = np.minimum(x2[i], x2[ids[:last]])
        iy2 = np.minimum(y2[i], y2[ids[:last]])

        # width and height of intersection boxes
        w = np.maximum(0.0, ix2 - ix1 + 1.0)
        h = np.maximum(0.0, iy2 - iy1 + 1.0)

        # intersections' areas
        inter = w * h
        if mode == 'min':
            overlap = inter / np.minimum(area[i], area[ids[:last]])
        elif mode == 'union':
            # intersection over union (IoU)
            overlap = inter / (area[i] + area[ids[:last]] - inter)

        # delete all boxes where overlap is too big
        ids = np.delete(
            ids,
            np.concatenate([[last], np.where(overlap > overlap_threshold)[0]])
        )

    return pick


def convert_to_square(bboxes):
    """Convert bounding boxes to a square form.

    Arguments:
        bboxes: a float numpy array of shape [n, 5].

    Returns:
        a float numpy array of shape [n, 5],
            squared bounding boxes.
    """

    square_bboxes = np.zeros_like(bboxes)
    x1, y1, x2, y2 = [bboxes[:, i] for i in range(4)]
    h = y2 - y1 + 1.0
    w = x2 - x1 + 1.0
    max_side = np.maximum(h, w)
    square_bboxes[:, 0] = x1 + w * 0.5 - max_side * 0.5
    square_bboxes[:, 1] = y1 + h * 0.5 - max_side * 0.5
    square_bboxes[:, 2] = square_bboxes[:, 0] + max_side - 1.0
    square_bboxes[:, 3] = square_bboxes[:, 1] + max_side - 1.0
    return square_bboxes


def calibrate_box(bboxes, offsets):
    """Transform bounding boxes to be more like true bounding boxes.
    'offsets' is one of the outputs of the nets.

    Arguments:
        bboxes: a float numpy array of shape [n, 5].
        offsets: a float numpy array of shape [n, 4].

    Returns:
        a float numpy array of shape [n, 5].
    """
    x1, y1, x2, y2 = [bboxes[:, i] for i in range(4)]
    w = x2 - x1 + 1.0
    h = y2 - y1 + 1.0
    w = np.expand_dims(w, 1)
    h = np.expand_dims(h, 1)

    # this is what is happening here:
    # tx1, ty1, tx2, ty2 = [offsets[:, i] for i in range(4)]
    # x1_true = x1 + tx1*w
    # y1_true = y1 + ty1*h
    # x2_true = x2 + tx2*w
    # y2_true = y2 + ty2*h
    # below is just a more compact form of this

    translation = np.hstack([w, h, w, h]) * offsets
    bboxes[:, 0:4] = bboxes[:, 0:4] + translation
    return bboxes


def get_image_boxes(bounding_boxes, img, size=24):
    """Cut out boxes from the image.

    Arguments:
        bounding_boxes: a float numpy array of shape [n, 5].
        img: an instance of PIL.Image.
        size: an integer, size of cutouts.

    Returns:
        a float numpy array of shape [n, 3, size, size].
    """

    num_boxes = len(bounding_boxes)
    width, height = img.size

    [dy, edy, dx, edx, y, ey, x, ex, w, h] = correct_bboxes(bounding_boxes, width, height)
    img_boxes = np.zeros((num_boxes, 3, size, size), 'float32')

    for i in range(num_boxes):
        img_box = np.zeros((h[i], w[i], 3), 'uint8')

        img_array = np.asarray(img, 'uint8')
        img_box[dy[i]:(edy[i] + 1), dx[i]:(edx[i] + 1), :] = \
            img_array[y[i]:(ey[i] + 1), x[i]:(ex[i] + 1), :]

        # resize
        img_box = Image.fromarray(img_box)
        img_box = img_box.resize((size, size), Image.BILINEAR)
        img_box = np.asarray(img_box, 'float32')

        img_boxes[i, :, :, :] = _preprocess(img_box)

    return img_boxes


def correct_bboxes(bboxes, width, height):
    """Crop boxes that are too big and get coordinates
    with respect to cutouts.

    Arguments:
        bboxes: a float numpy array of shape [n, 5],
            where each row is (xmin, ymin, xmax, ymax, score).
        width: a float number.
        height: a float number.

    Returns:
        dy, dx, edy, edx: int numpy arrays of shape [n],
            coordinates of the boxes with respect to the cutouts.
        y, x, ey, ex: int numpy arrays of shape [n],
            corrected ymin, xmin, ymax, xmax.
        h, w: int numpy arrays of shape [n],
            just heights and widths of boxes.

        in the following order:
            [dy, edy, dx, edx, y, ey, x, ex, w, h].
    """

    x1, y1, x2, y2 = [bboxes[:, i] for i in range(4)]
    w, h = x2 - x1 + 1.0, y2 - y1 + 1.0
    num_boxes = bboxes.shape[0]

    # 'e' stands for end
    # (x, y) -> (ex, ey)
    x, y, ex, ey = x1, y1, x2, y2

    # we need to cut out a box from the image.
    # (x, y, ex, ey) are corrected coordinates of the box
    # in the image.
    # (dx, dy, edx, edy) are coordinates of the box in the cutout
    # from the image.
    dx, dy = np.zeros((num_boxes,)), np.zeros((num_boxes,))
    edx, edy = w.copy() - 1.0, h.copy() - 1.0

    # if box's bottom right corner is too far right
    ind = np.where(ex > width - 1.0)[0]
    edx[ind] = w[ind] + width - 2.0 - ex[ind]
    ex[ind] = width - 1.0

    # if box's bottom right corner is too low
    ind = np.where(ey > height - 1.0)[0]
    edy[ind] = h[ind] + height - 2.0 - ey[ind]
    ey[ind] = height - 1.0

    # if box's top left corner is too far left
    ind = np.where(x < 0.0)[0]
    dx[ind] = 0.0 - x[ind]
    x[ind] = 0.0

    # if box's top left corner is too high
    ind = np.where(y < 0.0)[0]
    dy[ind] = 0.0 - y[ind]
    y[ind] = 0.0

    return_list = [dy, edy, dx, edx, y, ey, x, ex, w, h]
    return_list = [i.astype('int32') for i in return_list]

    return return_list


def _preprocess(img):
    """Preprocessing step before feeding the network.

    Arguments:
        img: a float numpy array of shape [h, w, c].

    Returns:
        a float numpy array of shape [1, c, h, w].
    """
    img = img.transpose((2, 0, 1))
    img = np.expand_dims(img, 0)
    img = (img - 127.5) * 0.0078125
    return img
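A quick illustration of nms (a sketch, not part of the commit): two heavily overlapping boxes collapse to the higher-scoring one, while a distant box survives.

import numpy as np

boxes = np.array([
    [10, 10, 50, 50, 0.9],      # kept: highest score
    [12, 12, 52, 52, 0.8],      # suppressed: IoU with the first box ~ 0.83 > 0.5
    [100, 100, 140, 140, 0.7],  # kept: no overlap with the others
])
print(nms(boxes, overlap_threshold=0.5, mode='union'))  # -> [0, 2]
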
models/mtcnn/mtcnn_pytorch/src/detector.py
ADDED
@@ -0,0 +1,126 @@
import numpy as np
import torch
from .get_nets import PNet, RNet, ONet
from .box_utils import nms, calibrate_box, get_image_boxes, convert_to_square
from .first_stage import run_first_stage


def detect_faces(image, min_face_size=20.0,
                 thresholds=[0.6, 0.7, 0.8],
                 nms_thresholds=[0.7, 0.7, 0.7]):
    """
    Arguments:
        image: an instance of PIL.Image.
        min_face_size: a float number.
        thresholds: a list of length 3.
        nms_thresholds: a list of length 3.

    Returns:
        two float numpy arrays of shapes [n_boxes, 4] and [n_boxes, 10],
        bounding boxes and facial landmarks.
    """

    # LOAD MODELS
    pnet = PNet()
    rnet = RNet()
    onet = ONet()
    # switch all three nets to inference mode (ONet in particular has dropout)
    pnet.eval()
    rnet.eval()
    onet.eval()

    # BUILD AN IMAGE PYRAMID
    width, height = image.size
    min_length = min(height, width)

    min_detection_size = 12
    factor = 0.707  # sqrt(0.5)

    # scales for scaling the image
    scales = []

    # scale the image so that the
    # minimum size that we can detect equals
    # the minimum face size that we want to detect
    m = min_detection_size / min_face_size
    min_length *= m

    factor_count = 0
    while min_length > min_detection_size:
        scales.append(m * factor ** factor_count)
        min_length *= factor
        factor_count += 1

    # STAGE 1

    # it will be returned
    bounding_boxes = []

    with torch.no_grad():
        # run P-Net on different scales
        for s in scales:
            boxes = run_first_stage(image, pnet, scale=s, threshold=thresholds[0])
            bounding_boxes.append(boxes)

        # collect boxes (and offsets, and scores) from different scales
        bounding_boxes = [i for i in bounding_boxes if i is not None]
        if len(bounding_boxes) == 0:
            # no candidate windows at any scale; np.vstack would fail here
            return [], []
        bounding_boxes = np.vstack(bounding_boxes)

        keep = nms(bounding_boxes[:, 0:5], nms_thresholds[0])
        bounding_boxes = bounding_boxes[keep]

        # use offsets predicted by pnet to transform bounding boxes
        bounding_boxes = calibrate_box(bounding_boxes[:, 0:5], bounding_boxes[:, 5:])
        # shape [n_boxes, 5]

        bounding_boxes = convert_to_square(bounding_boxes)
        bounding_boxes[:, 0:4] = np.round(bounding_boxes[:, 0:4])

        # STAGE 2

        img_boxes = get_image_boxes(bounding_boxes, image, size=24)
        img_boxes = torch.FloatTensor(img_boxes)

        output = rnet(img_boxes)
        offsets = output[0].data.numpy()  # shape [n_boxes, 4]
        probs = output[1].data.numpy()  # shape [n_boxes, 2]

        keep = np.where(probs[:, 1] > thresholds[1])[0]
        bounding_boxes = bounding_boxes[keep]
        bounding_boxes[:, 4] = probs[keep, 1].reshape((-1,))
        offsets = offsets[keep]

        keep = nms(bounding_boxes, nms_thresholds[1])
        bounding_boxes = bounding_boxes[keep]
        bounding_boxes = calibrate_box(bounding_boxes, offsets[keep])
        bounding_boxes = convert_to_square(bounding_boxes)
        bounding_boxes[:, 0:4] = np.round(bounding_boxes[:, 0:4])

        # STAGE 3

        img_boxes = get_image_boxes(bounding_boxes, image, size=48)
        if len(img_boxes) == 0:
            return [], []
        img_boxes = torch.FloatTensor(img_boxes)
        output = onet(img_boxes)
        landmarks = output[0].data.numpy()  # shape [n_boxes, 10]
        offsets = output[1].data.numpy()  # shape [n_boxes, 4]
        probs = output[2].data.numpy()  # shape [n_boxes, 2]

        keep = np.where(probs[:, 1] > thresholds[2])[0]
        bounding_boxes = bounding_boxes[keep]
        bounding_boxes[:, 4] = probs[keep, 1].reshape((-1,))
        offsets = offsets[keep]
        landmarks = landmarks[keep]

        # compute landmark points
        width = bounding_boxes[:, 2] - bounding_boxes[:, 0] + 1.0
        height = bounding_boxes[:, 3] - bounding_boxes[:, 1] + 1.0
        xmin, ymin = bounding_boxes[:, 0], bounding_boxes[:, 1]
        landmarks[:, 0:5] = np.expand_dims(xmin, 1) + np.expand_dims(width, 1) * landmarks[:, 0:5]
        landmarks[:, 5:10] = np.expand_dims(ymin, 1) + np.expand_dims(height, 1) * landmarks[:, 5:10]

        bounding_boxes = calibrate_box(bounding_boxes, offsets)
        keep = nms(bounding_boxes, nms_thresholds[2], mode='min')
        bounding_boxes = bounding_boxes[keep]
        landmarks = landmarks[keep]

    return bounding_boxes, landmarks
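Putting the three stages together (a usage sketch, not part of the commit; the file name is a placeholder):

from PIL import Image
# from models.mtcnn.mtcnn_pytorch.src.detector import detect_faces  # import path in this repo

img = Image.open('photo.jpg').convert('RGB')   # hypothetical input
boxes, landmarks = detect_faces(img)
# boxes:     [n, 5]  -> (xmin, ymin, xmax, ymax, score)
# landmarks: [n, 10] -> (x1..x5, y1..y5) for eyes, nose tip, mouth corners
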
models/mtcnn/mtcnn_pytorch/src/first_stage.py
ADDED
@@ -0,0 +1,101 @@
import math

import numpy as np
import torch
from PIL import Image

from .box_utils import nms, _preprocess

# fall back to CPU when CUDA is unavailable; the previous hard-coded
# 'cuda:0' crashes on CPU-only hosts
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


def run_first_stage(image, net, scale, threshold):
    """Run P-Net, generate bounding boxes, and do NMS.

    Arguments:
        image: an instance of PIL.Image.
        net: an instance of pytorch's nn.Module, P-Net.
        scale: a float number,
            scale width and height of the image by this number.
        threshold: a float number,
            threshold on the probability of a face when generating
            bounding boxes from predictions of the net.

    Returns:
        a float numpy array of shape [n_boxes, 9],
            bounding boxes with scores and offsets (4 + 1 + 4).
    """

    # scale the image and convert it to a float array
    width, height = image.size
    sw, sh = math.ceil(width * scale), math.ceil(height * scale)
    img = image.resize((sw, sh), Image.BILINEAR)
    img = np.asarray(img, 'float32')

    img = torch.FloatTensor(_preprocess(img)).to(device)
    with torch.no_grad():
        output = net(img)
        probs = output[1].cpu().data.numpy()[0, 1, :, :]
        offsets = output[0].cpu().data.numpy()
        # probs: probability of a face at each sliding window
        # offsets: transformations to true bounding boxes

        boxes = _generate_bboxes(probs, offsets, scale, threshold)
        if len(boxes) == 0:
            return None

        keep = nms(boxes[:, 0:5], overlap_threshold=0.5)
    return boxes[keep]


def _generate_bboxes(probs, offsets, scale, threshold):
    """Generate bounding boxes at places
    where there is probably a face.

    Arguments:
        probs: a float numpy array of shape [n, m].
        offsets: a float numpy array of shape [1, 4, n, m].
        scale: a float number,
            width and height of the image were scaled by this number.
        threshold: a float number.

    Returns:
        a float numpy array of shape [n_boxes, 9]
    """

    # applying P-Net is equivalent, in some sense, to
    # moving a 12x12 window with stride 2
    stride = 2
    cell_size = 12

    # indices of boxes where there is probably a face
    inds = np.where(probs > threshold)

    if inds[0].size == 0:
        return np.array([])

    # transformations of bounding boxes
    tx1, ty1, tx2, ty2 = [offsets[0, i, inds[0], inds[1]] for i in range(4)]
    # they are defined as:
    # w = x2 - x1 + 1
    # h = y2 - y1 + 1
    # x1_true = x1 + tx1*w
    # x2_true = x2 + tx2*w
    # y1_true = y1 + ty1*h
    # y2_true = y2 + ty2*h

    offsets = np.array([tx1, ty1, tx2, ty2])
    score = probs[inds[0], inds[1]]

    # P-Net is applied to scaled images,
    # so we need to rescale bounding boxes back;
    # the +1.0 shifts the 0-based window indices to 1-based pixel coordinates
    bounding_boxes = np.vstack([
        np.round((stride * inds[1] + 1.0) / scale),
        np.round((stride * inds[0] + 1.0) / scale),
        np.round((stride * inds[1] + 1.0 + cell_size) / scale),
        np.round((stride * inds[0] + 1.0 + cell_size) / scale),
        score, offsets
    ])

    return bounding_boxes.T
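The pyramid logic in detect_faces maps onto run_first_stage like this (a numeric sketch, assuming a 400x400 image and the defaults above):

min_detection_size, factor = 12, 0.707
min_face_size, min_length = 20.0, 400.0

m = min_detection_size / min_face_size   # 0.6
min_length *= m                          # 240.0
scales = []
while min_length > min_detection_size:
    scales.append(m * factor ** len(scales))
    min_length *= factor
# scales ~ [0.6, 0.424, 0.3, ...] (9 entries here): each smaller scale lets
# the fixed 12x12 P-Net window cover a larger face in the original image.
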
models/mtcnn/mtcnn_pytorch/src/get_nets.py
ADDED
@@ -0,0 +1,171 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from collections import OrderedDict
import numpy as np

from configs.paths_config import model_paths
PNET_PATH = model_paths["mtcnn_pnet"]
ONET_PATH = model_paths["mtcnn_onet"]
RNET_PATH = model_paths["mtcnn_rnet"]


class Flatten(nn.Module):

    def __init__(self):
        super(Flatten, self).__init__()

    def forward(self, x):
        """
        Arguments:
            x: a float tensor with shape [batch_size, c, h, w].
        Returns:
            a float tensor with shape [batch_size, c*h*w].
        """

        # without this, the pretrained model doesn't work
        x = x.transpose(3, 2).contiguous()

        return x.view(x.size(0), -1)


class PNet(nn.Module):

    def __init__(self):
        super().__init__()

        # suppose we have input with size HxW, then
        # after first layer: H - 2,
        # after pool: ceil((H - 2)/2),
        # after second conv: ceil((H - 2)/2) - 2,
        # after last conv: ceil((H - 2)/2) - 4,
        # and the same for W

        self.features = nn.Sequential(OrderedDict([
            ('conv1', nn.Conv2d(3, 10, 3, 1)),
            ('prelu1', nn.PReLU(10)),
            ('pool1', nn.MaxPool2d(2, 2, ceil_mode=True)),

            ('conv2', nn.Conv2d(10, 16, 3, 1)),
            ('prelu2', nn.PReLU(16)),

            ('conv3', nn.Conv2d(16, 32, 3, 1)),
            ('prelu3', nn.PReLU(32))
        ]))

        self.conv4_1 = nn.Conv2d(32, 2, 1, 1)
        self.conv4_2 = nn.Conv2d(32, 4, 1, 1)

        weights = np.load(PNET_PATH, allow_pickle=True)[()]
        for n, p in self.named_parameters():
            p.data = torch.FloatTensor(weights[n])

    def forward(self, x):
        """
        Arguments:
            x: a float tensor with shape [batch_size, 3, h, w].
        Returns:
            b: a float tensor with shape [batch_size, 4, h', w'].
            a: a float tensor with shape [batch_size, 2, h', w'].
        """
        x = self.features(x)
        a = self.conv4_1(x)
        b = self.conv4_2(x)
        # the two class scores live in the channel dimension of the
        # [batch, 2, h', w'] map, so normalize over dim=1 (the original
        # dim=-1 would softmax across image width instead)
        a = F.softmax(a, dim=1)
        return b, a


class RNet(nn.Module):

    def __init__(self):
        super().__init__()

        self.features = nn.Sequential(OrderedDict([
            ('conv1', nn.Conv2d(3, 28, 3, 1)),
            ('prelu1', nn.PReLU(28)),
            ('pool1', nn.MaxPool2d(3, 2, ceil_mode=True)),

            ('conv2', nn.Conv2d(28, 48, 3, 1)),
            ('prelu2', nn.PReLU(48)),
            ('pool2', nn.MaxPool2d(3, 2, ceil_mode=True)),

            ('conv3', nn.Conv2d(48, 64, 2, 1)),
            ('prelu3', nn.PReLU(64)),

            ('flatten', Flatten()),
            ('conv4', nn.Linear(576, 128)),
            ('prelu4', nn.PReLU(128))
        ]))

        self.conv5_1 = nn.Linear(128, 2)
        self.conv5_2 = nn.Linear(128, 4)

        weights = np.load(RNET_PATH, allow_pickle=True)[()]
        for n, p in self.named_parameters():
            p.data = torch.FloatTensor(weights[n])

    def forward(self, x):
        """
        Arguments:
            x: a float tensor with shape [batch_size, 3, h, w].
        Returns:
            b: a float tensor with shape [batch_size, 4].
            a: a float tensor with shape [batch_size, 2].
        """
        x = self.features(x)
        a = self.conv5_1(x)
        b = self.conv5_2(x)
        a = F.softmax(a, dim=-1)
        return b, a


class ONet(nn.Module):

    def __init__(self):
        super().__init__()

        self.features = nn.Sequential(OrderedDict([
            ('conv1', nn.Conv2d(3, 32, 3, 1)),
            ('prelu1', nn.PReLU(32)),
            ('pool1', nn.MaxPool2d(3, 2, ceil_mode=True)),

            ('conv2', nn.Conv2d(32, 64, 3, 1)),
            ('prelu2', nn.PReLU(64)),
            ('pool2', nn.MaxPool2d(3, 2, ceil_mode=True)),

            ('conv3', nn.Conv2d(64, 64, 3, 1)),
            ('prelu3', nn.PReLU(64)),
            ('pool3', nn.MaxPool2d(2, 2, ceil_mode=True)),

            ('conv4', nn.Conv2d(64, 128, 2, 1)),
            ('prelu4', nn.PReLU(128)),

            ('flatten', Flatten()),
            ('conv5', nn.Linear(1152, 256)),
            ('drop5', nn.Dropout(0.25)),
            ('prelu5', nn.PReLU(256)),
        ]))

        self.conv6_1 = nn.Linear(256, 2)
        self.conv6_2 = nn.Linear(256, 4)
        self.conv6_3 = nn.Linear(256, 10)

        weights = np.load(ONET_PATH, allow_pickle=True)[()]
        for n, p in self.named_parameters():
            p.data = torch.FloatTensor(weights[n])

    def forward(self, x):
        """
        Arguments:
            x: a float tensor with shape [batch_size, 3, h, w].
        Returns:
            c: a float tensor with shape [batch_size, 10].
            b: a float tensor with shape [batch_size, 4].
            a: a float tensor with shape [batch_size, 2].
        """
        x = self.features(x)
        a = self.conv6_1(x)
        b = self.conv6_2(x)
        c = self.conv6_3(x)
        a = F.softmax(a, dim=-1)
        return c, b, a
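The three nets load their parameters from .npy dictionaries keyed by parameter name. A quick way to sanity-check a checkpoint against a module (a sketch, assuming the mtcnn weight files referenced by configs.paths_config are present):

import numpy as np

weights = np.load(PNET_PATH, allow_pickle=True)[()]   # dict: param name -> ndarray
net = PNet()
for name, param in net.named_parameters():
    assert name in weights, f'missing key: {name}'
    assert tuple(param.shape) == weights[name].shape, f'shape mismatch: {name}'
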
models/mtcnn/mtcnn_pytorch/src/matlab_cp2tform.py
ADDED
@@ -0,0 +1,350 @@
# -*- coding: utf-8 -*-
"""
Created on Tue Jul 11 06:54:28 2017

@author: zhaoyafei
"""

import numpy as np
from numpy.linalg import inv, norm, lstsq
from numpy.linalg import matrix_rank as rank


class MatlabCp2tormException(Exception):
    def __str__(self):
        return 'In File {}:{}'.format(
            __file__, super().__str__())


def tformfwd(trans, uv):
    """
    Function:
    ----------
    apply affine transform 'trans' to uv

    Parameters:
    ----------
    @trans: 3x3 np.array
        transform matrix
    @uv: Kx2 np.array
        each row is a pair of coordinates (x, y)

    Returns:
    ----------
    @xy: Kx2 np.array
        each row is a pair of transformed coordinates (x, y)
    """
    uv = np.hstack((
        uv, np.ones((uv.shape[0], 1))
    ))
    xy = np.dot(uv, trans)
    xy = xy[:, 0:-1]
    return xy


def tforminv(trans, uv):
    """
    Function:
    ----------
    apply the inverse of affine transform 'trans' to uv

    Parameters:
    ----------
    @trans: 3x3 np.array
        transform matrix
    @uv: Kx2 np.array
        each row is a pair of coordinates (x, y)

    Returns:
    ----------
    @xy: Kx2 np.array
        each row is a pair of inverse-transformed coordinates (x, y)
    """
    Tinv = inv(trans)
    xy = tformfwd(Tinv, uv)
    return xy


def findNonreflectiveSimilarity(uv, xy, options=None):
    # respect a caller-supplied options dict; the original overwrote it
    if options is None:
        options = {'K': 2}

    K = options['K']
    M = xy.shape[0]
    x = xy[:, 0].reshape((-1, 1))  # use reshape to keep a column vector
    y = xy[:, 1].reshape((-1, 1))  # use reshape to keep a column vector

    tmp1 = np.hstack((x, y, np.ones((M, 1)), np.zeros((M, 1))))
    tmp2 = np.hstack((y, -x, np.zeros((M, 1)), np.ones((M, 1))))
    X = np.vstack((tmp1, tmp2))

    u = uv[:, 0].reshape((-1, 1))  # use reshape to keep a column vector
    v = uv[:, 1].reshape((-1, 1))  # use reshape to keep a column vector
    U = np.vstack((u, v))

    # We know that X * r = U
    if rank(X) >= 2 * K:
        r, _, _, _ = lstsq(X, U, rcond=None)
        r = np.squeeze(r)
    else:
        raise Exception('cp2tform:twoUniquePointsReq')

    sc = r[0]
    ss = r[1]
    tx = r[2]
    ty = r[3]

    Tinv = np.array([
        [sc, -ss, 0],
        [ss, sc, 0],
        [tx, ty, 1]
    ])

    T = inv(Tinv)

    T[:, 2] = np.array([0, 0, 1])

    return T, Tinv


def findSimilarity(uv, xy, options=None):
    if options is None:
        options = {'K': 2}

    # Solve for trans1
    trans1, trans1_inv = findNonreflectiveSimilarity(uv, xy, options)

    # Solve for trans2

    # manually reflect the xy data across the Y-axis
    # (copy first: the original aliased xy and mutated the caller's array)
    xyR = xy.copy()
    xyR[:, 0] = -1 * xyR[:, 0]

    trans2r, trans2r_inv = findNonreflectiveSimilarity(uv, xyR, options)

    # manually reflect the tform to undo the reflection done on xyR
    TreflectY = np.array([
        [-1, 0, 0],
        [0, 1, 0],
        [0, 0, 1]
    ])

    trans2 = np.dot(trans2r, TreflectY)

    # Figure out if trans1 or trans2 is better
    xy1 = tformfwd(trans1, uv)
    norm1 = norm(xy1 - xy)

    xy2 = tformfwd(trans2, uv)
    norm2 = norm(xy2 - xy)

    if norm1 <= norm2:
        return trans1, trans1_inv
    else:
        trans2_inv = inv(trans2)
        return trans2, trans2_inv


def get_similarity_transform(src_pts, dst_pts, reflective=True):
    """
    Function:
    ----------
    Find Similarity Transform Matrix 'trans':
        u = src_pts[:, 0]
        v = src_pts[:, 1]
        x = dst_pts[:, 0]
        y = dst_pts[:, 1]
        [x, y, 1] = [u, v, 1] * trans

    Parameters:
    ----------
    @src_pts: Kx2 np.array
        source points, each row is a pair of coordinates (x, y)
    @dst_pts: Kx2 np.array
        destination points, each row is a pair of transformed
        coordinates (x, y)
    @reflective: True or False
        if True:
            use reflective similarity transform
        else:
            use non-reflective similarity transform

    Returns:
    ----------
    @trans: 3x3 np.array
        transform matrix from uv to xy
    @trans_inv: 3x3 np.array
        inverse of trans, transform matrix from xy to uv
    """

    if reflective:
        trans, trans_inv = findSimilarity(src_pts, dst_pts)
    else:
        trans, trans_inv = findNonreflectiveSimilarity(src_pts, dst_pts)

    return trans, trans_inv


def cvt_tform_mat_for_cv2(trans):
    """
    Function:
    ----------
    Convert Transform Matrix 'trans' into 'cv2_trans' which could be
    directly used by cv2.warpAffine():
        u = src_pts[:, 0]
        v = src_pts[:, 1]
        x = dst_pts[:, 0]
        y = dst_pts[:, 1]
        [x, y].T = cv_trans * [u, v, 1].T

    Parameters:
    ----------
    @trans: 3x3 np.array
        transform matrix from uv to xy

    Returns:
    ----------
    @cv2_trans: 2x3 np.array
        transform matrix from src_pts to dst_pts, could be directly used
        for cv2.warpAffine()
    """
    cv2_trans = trans[:, 0:2].T

    return cv2_trans


def get_similarity_transform_for_cv2(src_pts, dst_pts, reflective=True):
    """
    Function:
    ----------
    Find Similarity Transform Matrix 'cv2_trans' which could be
    directly used by cv2.warpAffine():
        u = src_pts[:, 0]
        v = src_pts[:, 1]
        x = dst_pts[:, 0]
        y = dst_pts[:, 1]
        [x, y].T = cv_trans * [u, v, 1].T

    Parameters:
    ----------
    @src_pts: Kx2 np.array
        source points, each row is a pair of coordinates (x, y)
    @dst_pts: Kx2 np.array
        destination points, each row is a pair of transformed
        coordinates (x, y)
    @reflective: True or False
        if True:
            use reflective similarity transform
        else:
            use non-reflective similarity transform

    Returns:
    ----------
    @cv2_trans: 2x3 np.array
        transform matrix from src_pts to dst_pts, could be directly used
        for cv2.warpAffine()
    """
    trans, trans_inv = get_similarity_transform(src_pts, dst_pts, reflective)
    cv2_trans = cvt_tform_mat_for_cv2(trans)

    return cv2_trans


if __name__ == '__main__':
    """
    u = [0, 6, -2]
    v = [0, 3, 5]
    x = [-1, 0, 4]
    y = [-1, -10, 4]

    # In Matlab, run:
    #
    #   uv = [u'; v'];
    #   xy = [x'; y'];
    #   tform_sim = cp2tform(uv, xy, 'similarity');
    #
    #   trans = tform_sim.tdata.T
    #   ans =
    #       -0.0764   -1.6190         0
    #        1.6190   -0.0764         0
    #       -3.2156    0.0290    1.0000
    #   trans_inv = tform_sim.tdata.Tinv
    #   ans =
    #       -0.0291    0.6163         0
    #       -0.6163   -0.0291         0
    #       -0.0756    1.9826    1.0000
    #   xy_m = tformfwd(tform_sim, u, v)
    #   xy_m =
    #       -3.2156    0.0290
    #        1.1833   -9.9143
    #        5.0323    2.8853
    #   uv_m = tforminv(tform_sim, x, y)
    #   uv_m =
    #        0.5698    1.3953
    #        6.0872    2.2733
    #       -2.6570    4.3314
    """
    u = [0, 6, -2]
    v = [0, 3, 5]
    x = [-1, 0, 4]
    y = [-1, -10, 4]

    uv = np.array((u, v)).T
    xy = np.array((x, y)).T

    print('\n--->uv:')
    print(uv)
    print('\n--->xy:')
    print(xy)

    trans, trans_inv = get_similarity_transform(uv, xy)

    print('\n--->trans matrix:')
    print(trans)

    print('\n--->trans_inv matrix:')
    print(trans_inv)

    print('\n---> apply transform to uv')
    print('\nxy_m = uv_augmented * trans')
    uv_aug = np.hstack((
        uv, np.ones((uv.shape[0], 1))
    ))
    xy_m = np.dot(uv_aug, trans)
    print(xy_m)

    print('\nxy_m = tformfwd(trans, uv)')
    xy_m = tformfwd(trans, uv)
    print(xy_m)

    print('\n---> apply inverse transform to xy')
    print('\nuv_m = xy_augmented * trans_inv')
    xy_aug = np.hstack((
        xy, np.ones((xy.shape[0], 1))
    ))
    uv_m = np.dot(xy_aug, trans_inv)
    print(uv_m)

    print('\nuv_m = tformfwd(trans_inv, xy)')
    uv_m = tformfwd(trans_inv, xy)
    print(uv_m)

    uv_m = tforminv(trans, xy)
    print('\nuv_m = tforminv(trans, xy)')
    print(uv_m)
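The least-squares fit above stacks the K point pairs into a 2K x 4 system whose four unknowns (sc, ss, tx, ty) parameterize rotation-plus-scale and translation. A round-trip check (a sketch, not part of the commit):

import numpy as np

src = np.float32([[0, 0], [6, 3], [-2, 5]])
dst = np.float32([[-1, -1], [0, -10], [4, 4]])

T, Tinv = findNonreflectiveSimilarity(src, dst)
print(np.round(tformfwd(T, src), 4))                 # ~ dst (least-squares fit)
print(np.round(tformfwd(Tinv, tformfwd(T, src)), 4)) # ~ src (inverse undoes it)
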
models/mtcnn/mtcnn_pytorch/src/visualization_utils.py
ADDED
@@ -0,0 +1,31 @@
from PIL import ImageDraw


def show_bboxes(img, bounding_boxes, facial_landmarks=()):
    """Draw bounding boxes and facial landmarks.

    Arguments:
        img: an instance of PIL.Image.
        bounding_boxes: a float numpy array of shape [n, 5].
        facial_landmarks: a float numpy array of shape [n, 10].

    Returns:
        an instance of PIL.Image.
    """

    img_copy = img.copy()
    draw = ImageDraw.Draw(img_copy)

    for b in bounding_boxes:
        draw.rectangle([
            (b[0], b[1]), (b[2], b[3])
        ], outline='white')

    for p in facial_landmarks:
        for i in range(5):
            draw.ellipse([
                (p[i] - 1.0, p[i + 5] - 1.0),
                (p[i] + 1.0, p[i + 5] + 1.0)
            ], outline='blue')

    return img_copy
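A sketch tying the detector and this helper together (file names are placeholders):

from PIL import Image

img = Image.open('group.jpg').convert('RGB')   # hypothetical input
boxes, landmarks = detect_faces(img)
annotated = show_bboxes(img, boxes, landmarks)
annotated.save('group_annotated.jpg')
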
models/psp.py
ADDED
@@ -0,0 +1,118 @@
"""
This file defines the core research contribution
"""
import matplotlib
matplotlib.use('Agg')
import math

import torch
from torch import nn
from models.encoders import psp_encoders
from models.stylegan2.model import Generator
from configs.paths_config import model_paths


def get_keys(d, name):
    if 'state_dict' in d:
        d = d['state_dict']
    d_filt = {k[len(name) + 1:]: v for k, v in d.items() if k[:len(name)] == name}
    return d_filt


class pSp(nn.Module):

    def __init__(self, opts):
        super(pSp, self).__init__()
        self.set_opts(opts)
        # compute number of style inputs based on the output resolution
        self.opts.n_styles = int(math.log(self.opts.output_size, 2)) * 2 - 2
        # Define architecture
        self.encoder = self.set_encoder()
        self.decoder = Generator(self.opts.output_size, 512, 8)
        self.face_pool = torch.nn.AdaptiveAvgPool2d((256, 256))
        # Load weights if needed
        self.load_weights()

    def set_encoder(self):
        if self.opts.encoder_type == 'GradualStyleEncoder':
            encoder = psp_encoders.GradualStyleEncoder(50, 'ir_se', self.opts)
        elif self.opts.encoder_type == 'BackboneEncoderUsingLastLayerIntoW':
            encoder = psp_encoders.BackboneEncoderUsingLastLayerIntoW(50, 'ir_se', self.opts)
        elif self.opts.encoder_type == 'BackboneEncoderUsingLastLayerIntoWPlus':
            encoder = psp_encoders.BackboneEncoderUsingLastLayerIntoWPlus(50, 'ir_se', self.opts)
        else:
            raise Exception('{} is not a valid encoder type'.format(self.opts.encoder_type))
        return encoder

    def load_weights(self):
        if self.opts.checkpoint_path is not None:
            print('Loading pSp from checkpoint: {}'.format(self.opts.checkpoint_path))
            ckpt = torch.load(self.opts.checkpoint_path, map_location='cpu')
            self.encoder.load_state_dict(get_keys(ckpt, 'encoder'), strict=True)
            self.decoder.load_state_dict(get_keys(ckpt, 'decoder'), strict=True)
            self.__load_latent_avg(ckpt)
        else:
            print('Loading encoders weights from irse50!')
            encoder_ckpt = torch.load(model_paths['ir_se50'])
            # if input to encoder is not an RGB image, do not load the input layer weights
            if self.opts.label_nc != 0:
                encoder_ckpt = {k: v for k, v in encoder_ckpt.items() if "input_layer" not in k}
            self.encoder.load_state_dict(encoder_ckpt, strict=False)
            print('Loading decoder weights from pretrained!')
            ckpt = torch.load(self.opts.stylegan_weights)
            self.decoder.load_state_dict(ckpt['g_ema'], strict=False)
            if self.opts.learn_in_w:
                self.__load_latent_avg(ckpt, repeat=1)
            else:
                self.__load_latent_avg(ckpt, repeat=self.opts.n_styles)

    def forward(self, x, resize=True, latent_mask=None, input_code=False, randomize_noise=True,
                inject_latent=None, return_latents=False, alpha=None):
        if input_code:
            codes = x
        else:
            codes = self.encoder(x)
            # normalize with respect to the center of an average face
            if self.opts.start_from_latent_avg:
                if self.opts.learn_in_w:
                    codes = codes + self.latent_avg.repeat(codes.shape[0], 1)
                else:
                    codes = codes + self.latent_avg.repeat(codes.shape[0], 1, 1)

        if latent_mask is not None:
            for i in latent_mask:
                if inject_latent is not None:
                    if alpha is not None:
                        codes[:, i] = alpha * inject_latent[:, i] + (1 - alpha) * codes[:, i]
                    else:
                        codes[:, i] = inject_latent[:, i]
                else:
                    codes[:, i] = 0

        input_is_latent = not input_code

        if return_latents:
            # the decoder returns an (images, latents) tuple here; it is passed
            # through unchanged so callers can unpack both (see app.py)
            result_latent = self.decoder([codes],
                                         input_is_latent=input_is_latent,
                                         randomize_noise=randomize_noise,
                                         return_latents=return_latents)
            return result_latent
        else:
            images, result_latent = self.decoder([codes],
                                                 input_is_latent=input_is_latent,
                                                 randomize_noise=randomize_noise,
                                                 return_latents=return_latents)

            if resize:
                images = self.face_pool(images)

            return images

    def set_opts(self, opts):
        self.opts = opts

    def __load_latent_avg(self, ckpt, repeat=None):
        if 'latent_avg' in ckpt:
            self.latent_avg = ckpt['latent_avg'].to(self.opts.device)
            if repeat is not None:
                self.latent_avg = self.latent_avg.repeat(repeat, 1)
        else:
            self.latent_avg = None
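The opts object only needs attribute access, so an easydict.EasyDict (already a dependency of app.py) is enough to instantiate pSp; app.py goes through model_build.build_psp() for this. A hand-rolled sketch with illustrative option values, not values taken from this commit:

from easydict import EasyDict
from models.psp import pSp

opts = EasyDict(
    encoder_type='GradualStyleEncoder',
    checkpoint_path='pretrained/ohayou_face.pt',  # checkpoint shipped in this repo
    output_size=256,          # illustrative; must match the trained generator
    input_nc=3,
    label_nc=0,
    learn_in_w=False,
    start_from_latent_avg=True,
    device='cpu',
)
net = pSp(opts).eval()
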
models/stylegan2/__init__.py
ADDED
File without changes
|
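models/stylegan2/model.py below re-implements the StyleGAN2 generator (rosinality's PyTorch port). Its central trick, weight modulation/demodulation in ModulatedConv2d, is worth a standalone sketch before reading the file: the style vector scales the kernel per input channel, then each output channel is renormalized to unit norm (a sketch, not part of the commit):

import torch

batch, cin, cout, k = 2, 8, 16, 3
weight = torch.randn(1, cout, cin, k, k)
style = torch.randn(batch, 1, cin, 1, 1)           # output of the modulation FC

w = weight * style                                  # modulate per input channel
demod = torch.rsqrt(w.pow(2).sum([2, 3, 4]) + 1e-8)
w = w * demod.view(batch, cout, 1, 1, 1)            # demodulate per output channel
print(w.pow(2).sum([2, 3, 4]).sqrt())               # ~ all ones
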
models/stylegan2/model.py
ADDED
@@ -0,0 +1,674 @@
1 |
+
import math
|
2 |
+
import random
|
3 |
+
import torch
|
4 |
+
from torch import nn
|
5 |
+
from torch.nn import functional as F
|
6 |
+
|
7 |
+
from models.stylegan2.op import FusedLeakyReLU, fused_leaky_relu, upfirdn2d
|
8 |
+
|
9 |
+
|
10 |
+
class PixelNorm(nn.Module):
|
11 |
+
def __init__(self):
|
12 |
+
super().__init__()
|
13 |
+
|
14 |
+
def forward(self, input):
|
15 |
+
return input * torch.rsqrt(torch.mean(input ** 2, dim=1, keepdim=True) + 1e-8)
|
16 |
+
|
17 |
+
|
18 |
+
def make_kernel(k):
|
19 |
+
k = torch.tensor(k, dtype=torch.float32)
|
20 |
+
|
21 |
+
if k.ndim == 1:
|
22 |
+
k = k[None, :] * k[:, None]
|
23 |
+
|
24 |
+
k /= k.sum()
|
25 |
+
|
26 |
+
return k
|
27 |
+
|
28 |
+
|
29 |
+
class Upsample(nn.Module):
|
30 |
+
def __init__(self, kernel, factor=2):
|
31 |
+
super().__init__()
|
32 |
+
|
33 |
+
self.factor = factor
|
34 |
+
kernel = make_kernel(kernel) * (factor ** 2)
|
35 |
+
self.register_buffer('kernel', kernel)
|
36 |
+
|
37 |
+
p = kernel.shape[0] - factor
|
38 |
+
|
39 |
+
pad0 = (p + 1) // 2 + factor - 1
|
40 |
+
pad1 = p // 2
|
41 |
+
|
42 |
+
self.pad = (pad0, pad1)
|
43 |
+
|
44 |
+
def forward(self, input):
|
45 |
+
out = upfirdn2d(input, self.kernel, up=self.factor, down=1, pad=self.pad)
|
46 |
+
|
47 |
+
return out
|
48 |
+
|
49 |
+
|
50 |
+
class Downsample(nn.Module):
|
51 |
+
def __init__(self, kernel, factor=2):
|
52 |
+
super().__init__()
|
53 |
+
|
54 |
+
self.factor = factor
|
55 |
+
kernel = make_kernel(kernel)
|
56 |
+
self.register_buffer('kernel', kernel)
|
57 |
+
|
58 |
+
p = kernel.shape[0] - factor
|
59 |
+
|
60 |
+
pad0 = (p + 1) // 2
|
61 |
+
pad1 = p // 2
|
62 |
+
|
63 |
+
self.pad = (pad0, pad1)
|
64 |
+
|
65 |
+
def forward(self, input):
|
66 |
+
out = upfirdn2d(input, self.kernel, up=1, down=self.factor, pad=self.pad)
|
67 |
+
|
68 |
+
return out
|
69 |
+
|
70 |
+
|
71 |
+
class Blur(nn.Module):
|
72 |
+
def __init__(self, kernel, pad, upsample_factor=1):
|
73 |
+
super().__init__()
|
74 |
+
|
75 |
+
kernel = make_kernel(kernel)
|
76 |
+
|
77 |
+
if upsample_factor > 1:
|
78 |
+
kernel = kernel * (upsample_factor ** 2)
|
79 |
+
|
80 |
+
self.register_buffer('kernel', kernel)
|
81 |
+
|
82 |
+
self.pad = pad
|
83 |
+
|
84 |
+
def forward(self, input):
|
85 |
+
out = upfirdn2d(input, self.kernel, pad=self.pad)
|
86 |
+
|
87 |
+
return out
|
88 |
+
|
89 |
+
|
90 |
+
class EqualConv2d(nn.Module):
|
91 |
+
def __init__(
|
92 |
+
self, in_channel, out_channel, kernel_size, stride=1, padding=0, bias=True
|
93 |
+
):
|
94 |
+
super().__init__()
|
95 |
+
|
96 |
+
self.weight = nn.Parameter(
|
97 |
+
torch.randn(out_channel, in_channel, kernel_size, kernel_size)
|
98 |
+
)
|
99 |
+
self.scale = 1 / math.sqrt(in_channel * kernel_size ** 2)
|
100 |
+
|
101 |
+
self.stride = stride
|
102 |
+
self.padding = padding
|
103 |
+
|
104 |
+
if bias:
|
105 |
+
self.bias = nn.Parameter(torch.zeros(out_channel))
|
106 |
+
|
107 |
+
else:
|
108 |
+
self.bias = None
|
109 |
+
|
110 |
+
def forward(self, input):
|
111 |
+
out = F.conv2d(
|
112 |
+
input,
|
113 |
+
self.weight * self.scale,
|
114 |
+
bias=self.bias,
|
115 |
+
stride=self.stride,
|
116 |
+
padding=self.padding,
|
117 |
+
)
|
118 |
+
|
119 |
+
return out
|
120 |
+
|
121 |
+
def __repr__(self):
|
122 |
+
return (
|
123 |
+
f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]},'
|
124 |
+
f' {self.weight.shape[2]}, stride={self.stride}, padding={self.padding})'
|
125 |
+
)
|
126 |
+
|
127 |
+
|
128 |
+
class EqualLinear(nn.Module):
|
129 |
+
def __init__(
|
130 |
+
self, in_dim, out_dim, bias=True, bias_init=0, lr_mul=1, activation=None
|
131 |
+
):
|
132 |
+
super().__init__()
|
133 |
+
|
134 |
+
self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul))
|
135 |
+
|
136 |
+
if bias:
|
137 |
+
self.bias = nn.Parameter(torch.zeros(out_dim).fill_(bias_init))
|
138 |
+
|
139 |
+
else:
|
140 |
+
self.bias = None
|
141 |
+
|
142 |
+
self.activation = activation
|
143 |
+
|
144 |
+
self.scale = (1 / math.sqrt(in_dim)) * lr_mul
|
145 |
+
self.lr_mul = lr_mul
|
146 |
+
|
147 |
+
def forward(self, input):
|
148 |
+
if self.activation:
|
149 |
+
out = F.linear(input, self.weight * self.scale)
|
150 |
+
out = fused_leaky_relu(out, self.bias * self.lr_mul)
|
151 |
+
|
152 |
+
else:
|
153 |
+
out = F.linear(
|
154 |
+
input, self.weight * self.scale, bias=self.bias * self.lr_mul
|
155 |
+
)
|
156 |
+
|
157 |
+
return out
|
158 |
+
|
159 |
+
def __repr__(self):
|
160 |
+
return (
|
161 |
+
f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]})'
|
        )


class ScaledLeakyReLU(nn.Module):
    def __init__(self, negative_slope=0.2):
        super().__init__()

        self.negative_slope = negative_slope

    def forward(self, input):
        out = F.leaky_relu(input, negative_slope=self.negative_slope)

        return out * math.sqrt(2)


class ModulatedConv2d(nn.Module):
    def __init__(
        self,
        in_channel,
        out_channel,
        kernel_size,
        style_dim,
        demodulate=True,
        upsample=False,
        downsample=False,
        blur_kernel=[1, 3, 3, 1],
    ):
        super().__init__()

        self.eps = 1e-8
        self.kernel_size = kernel_size
        self.in_channel = in_channel
        self.out_channel = out_channel
        self.upsample = upsample
        self.downsample = downsample

        if upsample:
            factor = 2
            p = (len(blur_kernel) - factor) - (kernel_size - 1)
            pad0 = (p + 1) // 2 + factor - 1
            pad1 = p // 2 + 1

            self.blur = Blur(blur_kernel, pad=(pad0, pad1), upsample_factor=factor)

        if downsample:
            factor = 2
            p = (len(blur_kernel) - factor) + (kernel_size - 1)
            pad0 = (p + 1) // 2
            pad1 = p // 2

            self.blur = Blur(blur_kernel, pad=(pad0, pad1))

        fan_in = in_channel * kernel_size ** 2
        self.scale = 1 / math.sqrt(fan_in)
        self.padding = kernel_size // 2

        self.weight = nn.Parameter(
            torch.randn(1, out_channel, in_channel, kernel_size, kernel_size)
        )

        self.modulation = EqualLinear(style_dim, in_channel, bias_init=1)

        self.demodulate = demodulate

    def __repr__(self):
        return (
            f'{self.__class__.__name__}({self.in_channel}, {self.out_channel}, {self.kernel_size}, '
            f'upsample={self.upsample}, downsample={self.downsample})'
        )

    def forward(self, input, style):
        batch, in_channel, height, width = input.shape

        style = self.modulation(style).view(batch, 1, in_channel, 1, 1)
        weight = self.scale * self.weight * style

        if self.demodulate:
            demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-8)
            weight = weight * demod.view(batch, self.out_channel, 1, 1, 1)

        weight = weight.view(
            batch * self.out_channel, in_channel, self.kernel_size, self.kernel_size
        )

        if self.upsample:
            input = input.view(1, batch * in_channel, height, width)
            weight = weight.view(
                batch, self.out_channel, in_channel, self.kernel_size, self.kernel_size
            )
            weight = weight.transpose(1, 2).reshape(
                batch * in_channel, self.out_channel, self.kernel_size, self.kernel_size
            )
            out = F.conv_transpose2d(input, weight, padding=0, stride=2, groups=batch)
            _, _, height, width = out.shape
            out = out.view(batch, self.out_channel, height, width)
            out = self.blur(out)

        elif self.downsample:
            input = self.blur(input)
            _, _, height, width = input.shape
            input = input.view(1, batch * in_channel, height, width)
            out = F.conv2d(input, weight, padding=0, stride=2, groups=batch)
            _, _, height, width = out.shape
            out = out.view(batch, self.out_channel, height, width)

        else:
            input = input.view(1, batch * in_channel, height, width)
            out = F.conv2d(input, weight, padding=self.padding, groups=batch)
            _, _, height, width = out.shape
            out = out.view(batch, self.out_channel, height, width)

        return out

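# Illustrative sketch (not part of the original file): the shape bookkeeping
# above reduces to the following per-sample computation, with s the style
# scale from self.modulation and W the base filter:
#
#   W'  = scale * W * s                            # modulate input channels
#   d   = rsqrt(sum(W'**2 over in, kh, kw) + eps)
#   W'' = W' * d                                   # demodulate each output filter
#
# Folding the batch into the group dimension (groups=batch) then lets one
# grouped convolution apply a different filter bank to every sample.
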
class NoiseInjection(nn.Module):
    def __init__(self):
        super().__init__()

        self.weight = nn.Parameter(torch.zeros(1))

    def forward(self, image, noise=None):
        if noise is None:
            batch, _, height, width = image.shape
            noise = image.new_empty(batch, 1, height, width).normal_()

        return image + self.weight * noise


class ConstantInput(nn.Module):
    def __init__(self, channel, size=4):
        super().__init__()

        self.input = nn.Parameter(torch.randn(1, channel, size, size))

    def forward(self, input):
        batch = input.shape[0]
        out = self.input.repeat(batch, 1, 1, 1)

        return out

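# Note (not part of the original file): NoiseInjection adds a single learned
# scalar weight times per-pixel Gaussian noise, e.g.
#
#   y = NoiseInjection()(torch.randn(1, 8, 4, 4))  # same shape as the input
#
# Since the weight is initialized to zero, a fresh layer passes images
# through unchanged.
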
class StyledConv(nn.Module):
    def __init__(
        self,
        in_channel,
        out_channel,
        kernel_size,
        style_dim,
        upsample=False,
        blur_kernel=[1, 3, 3, 1],
        demodulate=True,
    ):
        super().__init__()

        self.conv = ModulatedConv2d(
            in_channel,
            out_channel,
            kernel_size,
            style_dim,
            upsample=upsample,
            blur_kernel=blur_kernel,
            demodulate=demodulate,
        )

        self.noise = NoiseInjection()
        # self.bias = nn.Parameter(torch.zeros(1, out_channel, 1, 1))
        # self.activate = ScaledLeakyReLU(0.2)
        self.activate = FusedLeakyReLU(out_channel)

    def forward(self, input, style, noise=None):
        out = self.conv(input, style)
        out = self.noise(out, noise=noise)
        # out = out + self.bias
        out = self.activate(out)

        return out


class ToRGB(nn.Module):
    def __init__(self, in_channel, style_dim, upsample=True, blur_kernel=[1, 3, 3, 1]):
        super().__init__()

        if upsample:
            self.upsample = Upsample(blur_kernel)

        self.conv = ModulatedConv2d(in_channel, 3, 1, style_dim, demodulate=False)
        self.bias = nn.Parameter(torch.zeros(1, 3, 1, 1))

    def forward(self, input, style, skip=None):
        out = self.conv(input, style)
        out = out + self.bias

        if skip is not None:
            skip = self.upsample(skip)

            out = out + skip

        return out

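# Note (not part of the original file): ToRGB projects features to a
# 3-channel image with a demodulation-free 1x1 modulated convolution plus
# bias; when the previous RGB output is passed as `skip`, it is upsampled and
# added in, so the final image is a sum of RGB contributions from every
# resolution.
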
class Generator(nn.Module):
    def __init__(
        self,
        size,
        style_dim,
        n_mlp,
        channel_multiplier=2,
        blur_kernel=[1, 3, 3, 1],
        lr_mlp=0.01,
    ):
        super().__init__()

        self.size = size

        self.style_dim = style_dim

        layers = [PixelNorm()]

        for i in range(n_mlp):
            layers.append(
                EqualLinear(
                    style_dim, style_dim, lr_mul=lr_mlp, activation='fused_lrelu'
                )
            )

        self.style = nn.Sequential(*layers)

        self.channels = {
            4: 512,
            8: 512,
            16: 512,
            32: 512,
            64: 256 * channel_multiplier,
            128: 128 * channel_multiplier,
            256: 64 * channel_multiplier,
            512: 32 * channel_multiplier,
            1024: 16 * channel_multiplier,
        }

        self.input = ConstantInput(self.channels[4])
        self.conv1 = StyledConv(
            self.channels[4], self.channels[4], 3, style_dim, blur_kernel=blur_kernel
        )
        self.to_rgb1 = ToRGB(self.channels[4], style_dim, upsample=False)

        self.log_size = int(math.log(size, 2))
        self.num_layers = (self.log_size - 2) * 2 + 1

        self.convs = nn.ModuleList()
        self.upsamples = nn.ModuleList()
        self.to_rgbs = nn.ModuleList()
        self.noises = nn.Module()

        in_channel = self.channels[4]

        for layer_idx in range(self.num_layers):
            res = (layer_idx + 5) // 2
            shape = [1, 1, 2 ** res, 2 ** res]
            self.noises.register_buffer(f'noise_{layer_idx}', torch.randn(*shape))

        for i in range(3, self.log_size + 1):
            out_channel = self.channels[2 ** i]

            self.convs.append(
                StyledConv(
                    in_channel,
                    out_channel,
                    3,
                    style_dim,
                    upsample=True,
                    blur_kernel=blur_kernel,
                )
            )

            self.convs.append(
                StyledConv(
                    out_channel, out_channel, 3, style_dim, blur_kernel=blur_kernel
                )
            )

            self.to_rgbs.append(ToRGB(out_channel, style_dim))

            in_channel = out_channel

        self.n_latent = self.log_size * 2 - 2

    def make_noise(self):
        device = self.input.input.device

        noises = [torch.randn(1, 1, 2 ** 2, 2 ** 2, device=device)]

        for i in range(3, self.log_size + 1):
            for _ in range(2):
                noises.append(torch.randn(1, 1, 2 ** i, 2 ** i, device=device))

        return noises

    def mean_latent(self, n_latent):
        latent_in = torch.randn(
            n_latent, self.style_dim, device=self.input.input.device
        )
        latent = self.style(latent_in).mean(0, keepdim=True)

        return latent

    def get_latent(self, input):
        return self.style(input)

    def forward(
        self,
        styles,
        return_latents=False,
        return_features=False,
        inject_index=None,
        truncation=1,
        truncation_latent=None,
        input_is_latent=False,
        noise=None,
        randomize_noise=True,
    ):
        if not input_is_latent:
            styles = [self.style(s) for s in styles]

        if noise is None:
            if randomize_noise:
                noise = [None] * self.num_layers
            else:
                noise = [
                    getattr(self.noises, f'noise_{i}') for i in range(self.num_layers)
                ]

        if truncation < 1:
            style_t = []

            for style in styles:
                style_t.append(
                    truncation_latent + truncation * (style - truncation_latent)
                )

            styles = style_t

        if len(styles) < 2:
            inject_index = self.n_latent

            if styles[0].ndim < 3:
                latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
            else:
                latent = styles[0]

        else:
            if inject_index is None:
                inject_index = random.randint(1, self.n_latent - 1)

            latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
            latent2 = styles[1].unsqueeze(1).repeat(1, self.n_latent - inject_index, 1)

            latent = torch.cat([latent, latent2], 1)

        if return_latents:
            return latent

        out = self.input(latent)
        out = self.conv1(out, latent[:, 0], noise=noise[0])

        skip = self.to_rgb1(out, latent[:, 1])

        i = 1
        for conv1, conv2, noise1, noise2, to_rgb in zip(
            self.convs[::2], self.convs[1::2], noise[1::2], noise[2::2], self.to_rgbs
        ):
            out = conv1(out, latent[:, i], noise=noise1)
            out = conv2(out, latent[:, i + 1], noise=noise2)
            skip = to_rgb(out, latent[:, i + 2], skip)

            i += 2

        image = skip

        if return_features:
            return image, out
        else:
            return image, None

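# Hedged usage sketch (not part of the original file); the hyperparameters
# below are illustrative, not this Space's checkpoint configuration:
#
#   g = Generator(size=256, style_dim=512, n_mlp=8)
#   z = torch.randn(1, 512)
#   img, _ = g([z], truncation=0.7, truncation_latent=g.mean_latent(4096))
#   # img: (1, 3, 256, 256), values roughly in [-1, 1]
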
class ConvLayer(nn.Sequential):
    def __init__(
        self,
        in_channel,
        out_channel,
        kernel_size,
        downsample=False,
        blur_kernel=[1, 3, 3, 1],
        bias=True,
        activate=True,
    ):
        layers = []

        if downsample:
            factor = 2
            p = (len(blur_kernel) - factor) + (kernel_size - 1)
            pad0 = (p + 1) // 2
            pad1 = p // 2

            layers.append(Blur(blur_kernel, pad=(pad0, pad1)))

            stride = 2
            self.padding = 0

        else:
            stride = 1
            self.padding = kernel_size // 2

        layers.append(
            EqualConv2d(
                in_channel,
                out_channel,
                kernel_size,
                padding=self.padding,
                stride=stride,
                bias=bias and not activate,
            )
        )

        if activate:
            if bias:
                layers.append(FusedLeakyReLU(out_channel))

            else:
                layers.append(ScaledLeakyReLU(0.2))

        super().__init__(*layers)

class ResBlock(nn.Module):
    def __init__(self, in_channel, out_channel, blur_kernel=[1, 3, 3, 1]):
        super().__init__()

        self.conv1 = ConvLayer(in_channel, in_channel, 3)
        self.conv2 = ConvLayer(in_channel, out_channel, 3, downsample=True)

        self.skip = ConvLayer(
            in_channel, out_channel, 1, downsample=True, activate=False, bias=False
        )

    def forward(self, input):
        out = self.conv1(input)
        out = self.conv2(out)

        skip = self.skip(input)
        out = (out + skip) / math.sqrt(2)

        return out

class Discriminator(nn.Module):
    def __init__(self, size, channel_multiplier=2, blur_kernel=[1, 3, 3, 1]):
        super().__init__()

        channels = {
            4: 512,
            8: 512,
            16: 512,
            32: 512,
            64: 256 * channel_multiplier,
            128: 128 * channel_multiplier,
            256: 64 * channel_multiplier,
            512: 32 * channel_multiplier,
            1024: 16 * channel_multiplier,
        }

        convs = [ConvLayer(3, channels[size], 1)]

        log_size = int(math.log(size, 2))

        in_channel = channels[size]

        for i in range(log_size, 2, -1):
            out_channel = channels[2 ** (i - 1)]

            convs.append(ResBlock(in_channel, out_channel, blur_kernel))

            in_channel = out_channel

        self.convs = nn.Sequential(*convs)

        self.stddev_group = 4
        self.stddev_feat = 1

        self.final_conv = ConvLayer(in_channel + 1, channels[4], 3)
        self.final_linear = nn.Sequential(
            EqualLinear(channels[4] * 4 * 4, channels[4], activation='fused_lrelu'),
            EqualLinear(channels[4], 1),
        )

    def forward(self, input):
        out = self.convs(input)

        batch, channel, height, width = out.shape
        group = min(batch, self.stddev_group)
        stddev = out.view(
            group, -1, self.stddev_feat, channel // self.stddev_feat, height, width
        )
        stddev = torch.sqrt(stddev.var(0, unbiased=False) + 1e-8)
        stddev = stddev.mean([2, 3, 4], keepdim=True).squeeze(2)
        stddev = stddev.repeat(group, 1, height, width)
        out = torch.cat([out, stddev], 1)

        out = self.final_conv(out)

        out = out.view(batch, -1)
        out = self.final_linear(out)

        return out
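# Note (not part of the original file): the stddev block in forward() is
# minibatch standard deviation — each group of up to stddev_group samples
# contributes one extra feature map holding the mean per-pixel standard
# deviation across the group, which lets the discriminator penalize
# low-variety batches.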
models/stylegan2/op/__init__.py
ADDED
@@ -0,0 +1,2 @@
from .fused_act import FusedLeakyReLU, fused_leaky_relu
from .upfirdn2d import upfirdn2d
models/stylegan2/op/fused_act.py
ADDED
@@ -0,0 +1,37 @@
import torch
from torch import nn
from torch.nn import functional as F


class FusedLeakyReLU(nn.Module):
    def __init__(self, channel, negative_slope=0.2, scale=2 ** 0.5):
        super().__init__()

        self.bias = nn.Parameter(torch.zeros(channel))
        self.negative_slope = negative_slope
        self.scale = scale

    def forward(self, input):
        return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale)


def fused_leaky_relu(input, bias=None, negative_slope=0.2, scale=2 ** 0.5):
    # Pure-PyTorch fallback for the original fused CUDA op: add the
    # per-channel bias, apply leaky ReLU, and rescale to preserve the
    # activation magnitude. Works on any device, so the previous CPU-only
    # guard (which silently returned None on GPU) is gone, and the
    # negative_slope argument is now honored instead of a hard-coded 0.2.
    if bias is not None:
        rest_dim = [1] * (input.ndim - bias.ndim - 1)
        return (
            F.leaky_relu(
                input + bias.view(1, bias.shape[0], *rest_dim),
                negative_slope=negative_slope,
            )
            * scale
        )

    else:
        return F.leaky_relu(input, negative_slope=negative_slope) * scale
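# Quick check (illustrative, not part of the original file):
#
#   act = FusedLeakyReLU(8)
#   y = act(torch.randn(2, 8, 4, 4))  # same shape; output scaled by sqrt(2)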
models/stylegan2/op/upfirdn2d.py
ADDED
@@ -0,0 +1,60 @@
import torch
from torch.nn import functional as F


def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
    out = upfirdn2d_native(
        input, kernel, up, up, down, down, pad[0], pad[1], pad[0], pad[1]
    )

    return out


def upfirdn2d_native(
    input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1
):
    # Pure-PyTorch upsample-FIR-downsample: insert zeros to upsample, pad,
    # convolve with the flipped FIR kernel, then downsample by striding.
    _, channel, in_h, in_w = input.shape
    input = input.reshape(-1, in_h, in_w, 1)

    _, in_h, in_w, minor = input.shape
    kernel_h, kernel_w = kernel.shape

    # Upsample by inserting up_x - 1 / up_y - 1 zeros between samples.
    out = input.view(-1, in_h, 1, in_w, 1, minor)
    out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
    out = out.view(-1, in_h * up_y, in_w * up_x, minor)

    # Apply the (possibly negative) padding: pad with zeros, then crop.
    out = F.pad(
        out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)]
    )
    out = out[
        :,
        max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0),
        max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0),
        :,
    ]

    # Filter with the flipped kernel (true convolution, not cross-correlation).
    out = out.permute(0, 3, 1, 2)
    out = out.reshape(
        [-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1]
    )
    w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
    out = F.conv2d(out, w)
    out = out.reshape(
        -1,
        minor,
        in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
        in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1,
    )
    out = out.permute(0, 2, 3, 1)

    # Downsample by keeping every down_y-th row and down_x-th column.
    out = out[:, ::down_y, ::down_x, :]

    out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h + down_y) // down_y
    out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w + down_x) // down_x

    return out.view(-1, channel, out_h, out_w)
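# Illustrative example (assumed shapes, not part of the original file),
# mirroring how the Blur/Upsample layers call this with a 4-tap binomial
# kernel:
#
#   k = torch.tensor([1., 3., 3., 1.])
#   k = k[None, :] * k[:, None]
#   k = k / k.sum()
#   x = torch.randn(1, 3, 8, 8)
#   y = upfirdn2d(x, k, up=2, down=1, pad=(2, 1))  # y: (1, 3, 16, 16)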
pretrained/ohayou_face.pkl
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:89abef3962a9ca6b214f1447e7050725b73d41822d7381e1f4d0f96ac8035381
size 363965331
pretrained/ohayou_face.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:c63de7970a7af6cc5b5c0cf677eb16095f2aaabd68dab41fcc3851bb5c7464f9
size 1077486507
requirements.txt
ADDED
@@ -0,0 +1,10 @@
torch
numpy
torchvision
Pillow
tqdm
imageio
scipy
easydict
opensimplex==0.3
ninja
torch_utils/__init__.py
ADDED
@@ -0,0 +1,9 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.

# empty