ethanNeuralImage committed
Commit fdd9364
Parent: 563b842

remove PTI from model

app.py CHANGED
@@ -29,7 +29,7 @@ from ris.blend import blend_latents
 from ris.model import Generator as RIS_Generator
 
 #from models.pti.manipulator import Manipulator
-#from models.pti.wrapper import Generator_wrapper
+#from models.pti.wrapper import Generator as Generator_wrapper
 #from models.pti.e4e_projection import projection
 
 from metrics import FaceMetric
@@ -334,12 +334,6 @@ with gr.Blocks() as demo:
         else:
             e4e_output = [None, None, None, None, e4e_metrics_text]
         output_imgs.extend(e4e_output)
-        if 'PTI' in inverter_bools:
-            pti_output = None, None, None, None
-            #manipulator.set_real_img_projection(src, inv_mode='w+', pti_mode='s')
-        else:
-            pti_output = None, None, None, None
-        output_imgs.extend(pti_output)
         return output_imgs
     submit_button.click(
         submit,
@@ -349,9 +343,10 @@ with gr.Blocks() as demo:
             gd_bool, neutral_text, target_text, alpha, beta,
            ris_bool, ref_img
         ],
-        [output_hyperstyle_invert, output_hyperstyle_mapper, output_hyperstyle_gd, output_hyperstyle_ris, output_hypersyle_metrics,
-         output_e4e_invert, output_e4e_mapper, output_e4e_gd, output_e4e_ris, output_e4e_metrics,
-         output_pti_invert, output_pti_mapper, output_pti_gd, output_pti_ris]
+        [
+            output_hyperstyle_invert, output_hyperstyle_mapper, output_hyperstyle_gd, output_hyperstyle_ris, output_hypersyle_metrics,
+            output_e4e_invert, output_e4e_mapper, output_e4e_gd, output_e4e_ris, output_e4e_metrics,
+        ]
     )
 
 demo.launch()
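Note on the two hunks above: in Gradio, the function wired to a click event must return exactly one value per component in the `outputs` list, so removing the PTI branch from `submit` forces the matching removal of the four `output_pti_*` components from the outputs. A minimal sketch of that contract, with illustrative component names rather than the app's real ones:

import gradio as gr

with gr.Blocks() as demo:
    button = gr.Button('run')
    out_img = gr.Image()
    out_text = gr.Textbox()

    def run():
        # One return value per entry in `outputs` below; a length
        # mismatch raises an error when the event fires.
        return None, 'metrics'

    button.click(run, inputs=[], outputs=[out_img, out_text])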
configs.py ADDED
@@ -0,0 +1,86 @@
+import numpy as np
+
+
+class GENERATOR_CONFIGS:
+    """StyleGAN2-ada generator configuration
+    """
+    def __init__(self, resolution=1024):
+        channel_base = 32768 if resolution >= 1024 else 16384
+        self.G_kwargs = {
+            'class_name': 'training.networks.Generator',
+            'z_dim': 512,
+            'w_dim': 512,
+            'mapping_kwargs': {'num_layers': 8},
+            'synthesis_kwargs': {
+                'channel_base': channel_base,
+                'channel_max': 512,
+                'num_fp16_res': 4,
+                'conv_clamp': 256
+            }
+        }
+        self.common_kwargs = {'c_dim': 0, 'img_resolution': resolution, 'img_channels': 3}
+        self.w_idx_lst = [
+            0, 1,         # 4
+            1, 2, 3,      # 8
+            3, 4, 5,      # 16
+            5, 6, 7,      # 32
+            7, 8, 9,      # 64
+            9, 10, 11,    # 128
+            11, 12, 13,   # 256
+            13, 14, 15,   # 512
+            15, 16, 17,   # 1024
+        ]
+        cutoff_idx = int(np.log2(1024/resolution) * (-3))
+        if resolution < 1024:
+            self.w_idx_lst = self.w_idx_lst[:cutoff_idx]
+
+
+class PATH_CONFIGS:
+    """Paths configuration
+    """
+    def __init__(self):
+        self.e4e = 'pretrained/e4e_ffhq_encode.pt'
+        self.stylegan2_ada_ffhq = 'pretrained/ffhq.pkl'
+        self.ir_se50 = 'pretrained/model_ir_se50.pth'
+        self.dlib = 'pretrained/shape_predictor_68_face_landmarks.dat'
+
+class PTI_HPARAMS:
+    """Pivot-tuning-inversion related hyper-parameters
+    """
+    def __init__(self):
+        # Architectures
+        self.lpips_type = 'alex'
+        self.first_inv_type = 'w+'
+        self.optim_type = 'adam'
+
+        # Locality regularization
+        self.latent_ball_num_of_samples = 1
+        self.locality_regularization_interval = 1
+        self.use_locality_regularization = False
+        self.regulizer_l2_lambda = 0.1
+        self.regulizer_lpips_lambda = 0.1
+        self.regulizer_alpha = 30
+
+        ## Loss
+        self.pt_l2_lambda = 1
+        self.pt_lpips_lambda = 1
+
+        ## Steps
+        self.LPIPS_value_threshold = 0.06
+        self.max_pti_steps = 350
+        self.first_inv_steps = 450
+        self.max_images_to_invert = 30
+
+        ## Optimization
+        self.pti_learning_rate = 3e-4
+        self.first_inv_lr = 5e-3
+        self.train_batch_size = 1
+
+
+class PTI_GLOBAL_CFGS:
+    def __init__(self):
+        self.training_step = 1
+        self.imgage_rec_result_log_snapshot = 100
+        self.pivotal_training_steps = 0
+        self.model_snapshot_interval = 400
+        self.run_name = ''
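Each triplet in `w_idx_lst` lists the W+ indices feeding one synthesis resolution, and the cutoff arithmetic trims three entries per halving below 1024. A quick check of that arithmetic, using only the class added above (assumes the new configs.py is importable from the repo root):

from configs import GENERATOR_CONFIGS

# log2(1024/256) = 2 halvings, so 2 * 3 = 6 entries are trimmed from
# the tail of the 26-entry list, keeping blocks up to 256x256.
cfg = GENERATOR_CONFIGS(resolution=256)
assert len(cfg.w_idx_lst) == 26 - 6
assert cfg.common_kwargs['img_resolution'] == 256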
models/pti/__init__.py ADDED
@@ -0,0 +1,3 @@
+from . import dnnlib
+from . import torch_utils
+from . import training
models/pti/legacy.py CHANGED
@@ -68,6 +68,8 @@ class _LegacyUnpickler(pickle.Unpickler):
     def find_class(self, module, name):
         if module == 'dnnlib.tflib.network' and name == 'Network':
             return _TFNetworkStub
+        if module.split('.')[0] in ['dlib_utils', 'dnnlib', 'torch_utils', 'training']:
+            module = 'models.pti.'+module
         return super().find_class(module, name)
 
 #----------------------------------------------------------------------------
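The two added lines are a standard unpickling shim: checkpoints serialized when `dnnlib`, `torch_utils`, and friends lived at the top level store those module paths inside the pickle, so loading them after the code moves under `models.pti` requires rewriting the path before resolution. A self-contained sketch of the same technique (class and helper names are hypothetical):

import io
import pickle

class RemappingUnpickler(pickle.Unpickler):
    # Top-level package names that old checkpoints baked into the pickle.
    RELOCATED = ('dlib_utils', 'dnnlib', 'torch_utils', 'training')

    def find_class(self, module, name):
        if module.split('.')[0] in self.RELOCATED:
            module = 'models.pti.' + module  # redirect to the new location
        return super().find_class(module, name)

def load_relocated(raw: bytes):
    return RemappingUnpickler(io.BytesIO(raw)).load()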
pretrained_models/pti/ffhq.pkl ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a205a346e86a9ddaae702e118097d014b7b8bd719491396a162cca438f2f524c
+size 381624121
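The three lines above are a Git LFS pointer, not the checkpoint itself: the repository tracks only the hash and byte size of the roughly 382 MB pickle, which is stored out of band. Pointer files are plain key/value lines, so a small hypothetical helper can read one:

def parse_lfs_pointer(text):
    # Each line is 'key value'; split on the first space only.
    return dict(line.split(' ', 1) for line in text.strip().splitlines())

ptr = parse_lfs_pointer(open('pretrained_models/pti/ffhq.pkl').read())
assert ptr['version'] == 'https://git-lfs.github.com/spec/v1'
print(ptr['oid'], ptr['size'])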
ris/manipulator.py CHANGED
@@ -289,7 +289,7 @@ if __name__ == '__main__':
 
     device = torch.device('cuda:0')
     ckpt = args.ckpt
-    G = Generator(ckpt, device)
+    G = Generator_wrapper(ckpt, device)
 
     face_preprocess = args.face_preprocess
     dataset_name = args.dataset_name
ris/wrapper.py CHANGED
@@ -7,15 +7,15 @@ import PIL.Image
 import torch
 from torchvision.transforms import transforms
 
-import dnnlib
+import models.pti.dnnlib as dnnlib
 import legacy
-from configs_gd import GENERATOR_CONFIGS
-from dlib_utils.face_alignment import image_align
-from dlib_utils.landmarks_detector import LandmarksDetector
-from torch_utils.misc import copy_params_and_buffers
+from models.pti.configs import GENERATOR_CONFIGS
+from models.pti.dlib_utils.face_alignment import image_align
+from models.pti.dlib_utils.landmarks_detector import LandmarksDetector
+from models.pti.torch_utils.misc import copy_params_and_buffers
 
-from pivot_tuning_inversion.utils.ImagesDataset import ImagesDataset, ImageLatentsDataset
-from pivot_tuning_inversion.training.coaches.multi_id_coach import MultiIDCoach
+from models.pti.pivot_tuning_inversion.utils.ImagesDataset import ImagesDataset, ImageLatentsDataset
+from models.pti.pivot_tuning_inversion.training.coaches.multi_id_coach import MultiIDCoach
 
 
 class FaceLandmarksDetector:
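These rewrites apply the same relocation as the legacy.py hunk: with the PTI code vendored under models/pti/, only package-qualified imports resolve when the app is launched from the repository root (`import legacy` presumably stays bare because a legacy module is already importable from the app's own path). A small probe, safe to run from any directory, shows which form resolves:

import importlib

for name in ('dnnlib', 'models.pti.dnnlib'):
    try:
        importlib.import_module(name)
        print(name, '-> importable')
    except ModuleNotFoundError:
        print(name, '-> not found on sys.path')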