Spaces:
Runtime error
Runtime error
Commit
•
fdd9364
1
Parent(s):
563b842
remove PTI from model
Browse files
- app.py +5 -10
- configs.py +86 -0
- models/pti/__init__.py +3 -0
- models/pti/legacy.py +2 -0
- pretrained_models/pti/ffhq.pkl +3 -0
- ris/manipulator.py +1 -1
- ris/wrapper.py +7 -7
app.py
CHANGED
@@ -29,7 +29,7 @@ from ris.blend import blend_latents
|
|
29 |
from ris.model import Generator as RIS_Generator
|
30 |
|
31 |
#from models.pti.manipulator import Manipulator
|
32 |
-
#from models.pti.wrapper import Generator_wrapper
|
33 |
#from models.pti.e4e_projection import projection
|
34 |
|
35 |
from metrics import FaceMetric
|
@@ -334,12 +334,6 @@ with gr.Blocks() as demo:
|
|
334 |
else:
|
335 |
e4e_output = [None, None, None, None, e4e_metrics_text]
|
336 |
output_imgs.extend(e4e_output)
|
337 |
-
if 'PTI' in inverter_bools:
|
338 |
-
pti_output = None, None, None, None
|
339 |
-
#manipulator.set_real_img_projection(src, inv_mode='w+', pti_mode='s')
|
340 |
-
else:
|
341 |
-
pti_output = None, None, None, None
|
342 |
-
output_imgs.extend(pti_output)
|
343 |
return output_imgs
|
344 |
submit_button.click(
|
345 |
submit,
|
@@ -349,9 +343,10 @@ with gr.Blocks() as demo:
|
|
349 |
gd_bool, neutral_text, target_text, alpha, beta,
|
350 |
ris_bool, ref_img
|
351 |
],
|
352 |
-
[
|
353 |
-
|
354 |
-
|
|
|
355 |
)
|
356 |
|
357 |
demo.launch()
|
|
|
29 |
from ris.model import Generator as RIS_Generator
|
30 |
|
31 |
#from models.pti.manipulator import Manipulator
|
32 |
+
#from models.pti.wrapper import Generator as Generator_wrapper
|
33 |
#from models.pti.e4e_projection import projection
|
34 |
|
35 |
from metrics import FaceMetric
|
|
|
334 |
else:
|
335 |
e4e_output = [None, None, None, None, e4e_metrics_text]
|
336 |
output_imgs.extend(e4e_output)
|
|
|
|
|
|
|
|
|
|
|
|
|
337 |
return output_imgs
|
338 |
submit_button.click(
|
339 |
submit,
|
|
|
343 |
gd_bool, neutral_text, target_text, alpha, beta,
|
344 |
ris_bool, ref_img
|
345 |
],
|
346 |
+
[
|
347 |
+
output_hyperstyle_invert, output_hyperstyle_mapper, output_hyperstyle_gd, output_hyperstyle_ris, output_hypersyle_metrics,
|
348 |
+
output_e4e_invert, output_e4e_mapper, output_e4e_gd, output_e4e_ris, output_e4e_metrics,
|
349 |
+
]
|
350 |
)
|
351 |
|
352 |
demo.launch()
|
configs.py
ADDED
@@ -0,0 +1,86 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
|
3 |
+
|
4 |
+
class GENERATOR_CONFIGS:
|
5 |
+
"""StyleGAN2-ada generator configuration
|
6 |
+
"""
|
7 |
+
def __init__(self, resolution=1024):
|
8 |
+
channel_base = 32768 if resolution >= 1024 else 16384
|
9 |
+
self.G_kwargs = {
|
10 |
+
'class_name': 'training.networks.Generator',
|
11 |
+
'z_dim': 512,
|
12 |
+
'w_dim': 512,
|
13 |
+
'mapping_kwargs': {'num_layers': 8},
|
14 |
+
'synthesis_kwargs': {
|
15 |
+
'channel_base': channel_base,
|
16 |
+
'channel_max': 512,
|
17 |
+
'num_fp16_res': 4,
|
18 |
+
'conv_clamp': 256
|
19 |
+
}
|
20 |
+
}
|
21 |
+
self.common_kwargs = {'c_dim': 0, 'img_resolution': resolution, 'img_channels': 3}
|
22 |
+
self.w_idx_lst = [
|
23 |
+
0,1, # 4
|
24 |
+
1,2,3, # 8
|
25 |
+
3,4,5, # 16
|
26 |
+
5,6,7, # 32
|
27 |
+
7,8,9, # 64
|
28 |
+
9,10,11, # 128
|
29 |
+
11,12,13, # 256
|
30 |
+
13,14,15, # 512
|
31 |
+
15,16,17, # 1024
|
32 |
+
]
|
33 |
+
cutoff_idx = int(np.log2(1024/resolution) * (-3))
|
34 |
+
if resolution < 1024:
|
35 |
+
self.w_idx_lst = self.w_idx_lst[:cutoff_idx]
|
36 |
+
|
37 |
+
|
38 |
+
class PATH_CONFIGS:
|
39 |
+
"""Paths configuration
|
40 |
+
"""
|
41 |
+
def __init__(self):
|
42 |
+
self.e4e = 'pretrained/e4e_ffhq_encode.pt'
|
43 |
+
self.stylegan2_ada_ffhq = 'pretrained/ffhq.pkl'
|
44 |
+
self.ir_se50 = 'pretrained/model_ir_se50.pth'
|
45 |
+
self.dlib = 'pretrained/shape_predictor_68_face_landmarks.dat'
|
46 |
+
|
47 |
+
class PTI_HPARAMS:
|
48 |
+
"""Pivot-tuning-inversion related hyper-parameters
|
49 |
+
"""
|
50 |
+
def __init__(self):
|
51 |
+
# Architectures
|
52 |
+
self.lpips_type = 'alex'
|
53 |
+
self.first_inv_type = 'w+'
|
54 |
+
self.optim_type = 'adam'
|
55 |
+
|
56 |
+
# Locality regularization
|
57 |
+
self.latent_ball_num_of_samples = 1
|
58 |
+
self.locality_regularization_interval = 1
|
59 |
+
self.use_locality_regularization = False
|
60 |
+
self.regulizer_l2_lambda = 0.1
|
61 |
+
self.regulizer_lpips_lambda = 0.1
|
62 |
+
self.regulizer_alpha = 30
|
63 |
+
|
64 |
+
## Loss
|
65 |
+
self.pt_l2_lambda = 1
|
66 |
+
self.pt_lpips_lambda = 1
|
67 |
+
|
68 |
+
## Steps
|
69 |
+
self.LPIPS_value_threshold = 0.06
|
70 |
+
self.max_pti_steps = 350
|
71 |
+
self.first_inv_steps = 450
|
72 |
+
self.max_images_to_invert = 30
|
73 |
+
|
74 |
+
## Optimization
|
75 |
+
self.pti_learning_rate = 3e-4
|
76 |
+
self.first_inv_lr = 5e-3
|
77 |
+
self.train_batch_size = 1
|
78 |
+
|
79 |
+
|
80 |
+
class PTI_GLOBAL_CFGS:
|
81 |
+
def __init__(self):
|
82 |
+
self.training_step = 1
|
83 |
+
self.imgage_rec_result_log_snapshot = 100
|
84 |
+
self.pivotal_training_steps = 0
|
85 |
+
self.model_snapshot_interval = 400
|
86 |
+
self.run_name = ''
|
models/pti/__init__.py
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
from . import dnnlib
|
2 |
+
from . import torch_utils
|
3 |
+
from . import training
|
models/pti/legacy.py
CHANGED
@@ -68,6 +68,8 @@ class _LegacyUnpickler(pickle.Unpickler):
|
|
68 |
def find_class(self, module, name):
|
69 |
if module == 'dnnlib.tflib.network' and name == 'Network':
|
70 |
return _TFNetworkStub
|
|
|
|
|
71 |
return super().find_class(module, name)
|
72 |
|
73 |
#----------------------------------------------------------------------------
|
|
|
68 |
def find_class(self, module, name):
|
69 |
if module == 'dnnlib.tflib.network' and name == 'Network':
|
70 |
return _TFNetworkStub
|
71 |
+
if module.split('.')[0] in ['dlib_utils', 'dnnlib', 'torch_utils', 'training']:
|
72 |
+
module = 'models.pti.'+module
|
73 |
return super().find_class(module, name)
|
74 |
|
75 |
#----------------------------------------------------------------------------
|
pretrained_models/pti/ffhq.pkl
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:a205a346e86a9ddaae702e118097d014b7b8bd719491396a162cca438f2f524c
|
3 |
+
size 381624121
|
ris/manipulator.py
CHANGED
@@ -289,7 +289,7 @@ if __name__ == '__main__':
|
|
289 |
|
290 |
device = torch.device('cuda:0')
|
291 |
ckpt = args.ckpt
|
292 |
-
G =
|
293 |
|
294 |
face_preprocess = args.face_preprocess
|
295 |
dataset_name = args.dataset_name
|
|
|
289 |
|
290 |
device = torch.device('cuda:0')
|
291 |
ckpt = args.ckpt
|
292 |
+
G = Generator_wrapper(ckpt, device)
|
293 |
|
294 |
face_preprocess = args.face_preprocess
|
295 |
dataset_name = args.dataset_name
|
ris/wrapper.py
CHANGED
@@ -7,15 +7,15 @@ import PIL.Image
|
|
7 |
import torch
|
8 |
from torchvision.transforms import transforms
|
9 |
|
10 |
-
import dnnlib
|
11 |
import legacy
|
12 |
-
from
|
13 |
-
from dlib_utils.face_alignment import image_align
|
14 |
-
from dlib_utils.landmarks_detector import LandmarksDetector
|
15 |
-
from torch_utils.misc import copy_params_and_buffers
|
16 |
|
17 |
-
from pivot_tuning_inversion.utils.ImagesDataset import ImagesDataset, ImageLatentsDataset
|
18 |
-
from pivot_tuning_inversion.training.coaches.multi_id_coach import MultiIDCoach
|
19 |
|
20 |
|
21 |
class FaceLandmarksDetector:
|
|
|
7 |
import torch
|
8 |
from torchvision.transforms import transforms
|
9 |
|
10 |
+
import models.pti.dnnlib as dnnlib
|
11 |
import legacy
|
12 |
+
from models.pti.configs import GENERATOR_CONFIGS
|
13 |
+
from models.pti.dlib_utils.face_alignment import image_align
|
14 |
+
from models.pti.dlib_utils.landmarks_detector import LandmarksDetector
|
15 |
+
from models.pti.torch_utils.misc import copy_params_and_buffers
|
16 |
|
17 |
+
from models.pti.pivot_tuning_inversion.utils.ImagesDataset import ImagesDataset, ImageLatentsDataset
|
18 |
+
from models.pti.pivot_tuning_inversion.training.coaches.multi_id_coach import MultiIDCoach
|
19 |
|
20 |
|
21 |
class FaceLandmarksDetector:
|