Sanket committed on
Commit
3d37b6e
1 Parent(s): ff4715d
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +2 -6
  2. LICENSE +21 -0
  3. README.md +32 -6
  4. app.py +204 -0
  5. e4e/.gitignore +129 -0
  6. e4e/criteria/__init__.py +0 -0
  7. e4e/criteria/id_loss.py +47 -0
  8. e4e/criteria/lpips/__init__.py +0 -0
  9. e4e/criteria/lpips/lpips.py +35 -0
  10. e4e/criteria/lpips/networks.py +96 -0
  11. e4e/criteria/lpips/utils.py +30 -0
  12. e4e/criteria/moco_loss.py +71 -0
  13. e4e/criteria/w_norm.py +14 -0
  14. e4e/datasets/__init__.py +0 -0
  15. e4e/datasets/gt_res_dataset.py +32 -0
  16. e4e/datasets/images_dataset.py +33 -0
  17. e4e/datasets/inference_dataset.py +25 -0
  18. e4e/editings/ganspace.py +22 -0
  19. e4e/editings/ganspace_pca/cars_pca.pt +3 -0
  20. e4e/editings/ganspace_pca/ffhq_pca.pt +3 -0
  21. e4e/editings/interfacegan_directions/age.pt +3 -0
  22. e4e/editings/interfacegan_directions/pose.pt +3 -0
  23. e4e/editings/interfacegan_directions/smile.pt +3 -0
  24. e4e/editings/latent_editor.py +45 -0
  25. e4e/editings/sefa.py +46 -0
  26. e4e/environment/e4e_env.yaml +73 -0
  27. e4e/metrics/LEC.py +134 -0
  28. e4e/models/__init__.py +0 -0
  29. e4e/models/discriminator.py +20 -0
  30. e4e/models/encoders/__init__.py +0 -0
  31. e4e/models/encoders/helpers.py +140 -0
  32. e4e/models/encoders/model_irse.py +84 -0
  33. e4e/models/encoders/psp_encoders.py +200 -0
  34. e4e/models/latent_codes_pool.py +55 -0
  35. e4e/models/psp.py +99 -0
  36. e4e/models/stylegan2/__init__.py +0 -0
  37. e4e/models/stylegan2/model.py +678 -0
  38. e4e/models/stylegan2/op/__init__.py +0 -0
  39. e4e/models/stylegan2/op/fused_act.py +85 -0
  40. e4e/models/stylegan2/op/fused_bias_act.cpp +21 -0
  41. e4e/models/stylegan2/op/fused_bias_act_kernel.cu +99 -0
  42. e4e/models/stylegan2/op/upfirdn2d.cpp +23 -0
  43. e4e/models/stylegan2/op/upfirdn2d.py +184 -0
  44. e4e/models/stylegan2/op/upfirdn2d_kernel.cu +272 -0
  45. e4e/notebooks/images/car_img.jpg +0 -0
  46. e4e/notebooks/images/church_img.jpg +0 -0
  47. e4e/notebooks/images/horse_img.jpg +0 -0
  48. e4e/notebooks/images/input_img.jpg +0 -0
  49. e4e/options/__init__.py +0 -0
  50. e4e/options/train_options.py +84 -0
.gitattributes CHANGED
@@ -1,6 +1,7 @@
1
  *.7z filter=lfs diff=lfs merge=lfs -text
2
  *.arrow filter=lfs diff=lfs merge=lfs -text
3
  *.bin filter=lfs diff=lfs merge=lfs -text
 
4
  *.bz2 filter=lfs diff=lfs merge=lfs -text
5
  *.ftz filter=lfs diff=lfs merge=lfs -text
6
  *.gz filter=lfs diff=lfs merge=lfs -text
@@ -9,13 +10,9 @@
9
  *.lfs.* filter=lfs diff=lfs merge=lfs -text
10
  *.model filter=lfs diff=lfs merge=lfs -text
11
  *.msgpack filter=lfs diff=lfs merge=lfs -text
12
- *.npy filter=lfs diff=lfs merge=lfs -text
13
- *.npz filter=lfs diff=lfs merge=lfs -text
14
  *.onnx filter=lfs diff=lfs merge=lfs -text
15
  *.ot filter=lfs diff=lfs merge=lfs -text
16
  *.parquet filter=lfs diff=lfs merge=lfs -text
17
- *.pickle filter=lfs diff=lfs merge=lfs -text
18
- *.pkl filter=lfs diff=lfs merge=lfs -text
19
  *.pb filter=lfs diff=lfs merge=lfs -text
20
  *.pt filter=lfs diff=lfs merge=lfs -text
21
  *.pth filter=lfs diff=lfs merge=lfs -text
@@ -24,8 +21,7 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
24
  *.tar.* filter=lfs diff=lfs merge=lfs -text
25
  *.tflite filter=lfs diff=lfs merge=lfs -text
26
  *.tgz filter=lfs diff=lfs merge=lfs -text
27
- *.wasm filter=lfs diff=lfs merge=lfs -text
28
  *.xz filter=lfs diff=lfs merge=lfs -text
29
  *.zip filter=lfs diff=lfs merge=lfs -text
30
- *.zst filter=lfs diff=lfs merge=lfs -text
31
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
1
  *.7z filter=lfs diff=lfs merge=lfs -text
2
  *.arrow filter=lfs diff=lfs merge=lfs -text
3
  *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bin.* filter=lfs diff=lfs merge=lfs -text
5
  *.bz2 filter=lfs diff=lfs merge=lfs -text
6
  *.ftz filter=lfs diff=lfs merge=lfs -text
7
  *.gz filter=lfs diff=lfs merge=lfs -text
 
10
  *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
  *.model filter=lfs diff=lfs merge=lfs -text
12
  *.msgpack filter=lfs diff=lfs merge=lfs -text
 
 
13
  *.onnx filter=lfs diff=lfs merge=lfs -text
14
  *.ot filter=lfs diff=lfs merge=lfs -text
15
  *.parquet filter=lfs diff=lfs merge=lfs -text
 
 
16
  *.pb filter=lfs diff=lfs merge=lfs -text
17
  *.pt filter=lfs diff=lfs merge=lfs -text
18
  *.pth filter=lfs diff=lfs merge=lfs -text
 
21
  *.tar.* filter=lfs diff=lfs merge=lfs -text
22
  *.tflite filter=lfs diff=lfs merge=lfs -text
23
  *.tgz filter=lfs diff=lfs merge=lfs -text
 
24
  *.xz filter=lfs diff=lfs merge=lfs -text
25
  *.zip filter=lfs diff=lfs merge=lfs -text
26
+ *.zstandard filter=lfs diff=lfs merge=lfs -text
27
  *tfevents* filter=lfs diff=lfs merge=lfs -text
LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2021 Min Jin Chong
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
README.md CHANGED
@@ -1,12 +1,38 @@
1
  ---
2
- title: JoJoGan Powerhow2
3
- emoji: 📚
4
- colorFrom: red
5
- colorTo: blue
6
  sdk: gradio
7
- sdk_version: 3.2
8
  app_file: app.py
9
  pinned: false
10
  ---
11
 
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
+ title: JoJoGAN
3
+ emoji: 🌍
4
+ colorFrom: green
5
+ colorTo: yellow
6
  sdk: gradio
7
+ sdk_version: 3.1.1
8
  app_file: app.py
9
  pinned: false
10
  ---
11
 
12
+ # Configuration
13
+
14
+ `title`: _string_
15
+ Display title for the Space
16
+
17
+ `emoji`: _string_
18
+ Space emoji (emoji-only character allowed)
19
+
20
+ `colorFrom`: _string_
21
+ Color for Thumbnail gradient (red, yellow, green, blue, indigo, purple, pink, gray)
22
+
23
+ `colorTo`: _string_
24
+ Color for Thumbnail gradient (red, yellow, green, blue, indigo, purple, pink, gray)
25
+
26
+ `sdk`: _string_
27
+ Can be either `gradio` or `streamlit`
28
+
29
+ `sdk_version` : _string_
30
+ Only applicable for `streamlit` SDK.
31
+ See [doc](https://hf.co/docs/hub/spaces) for more info on supported versions.
32
+
33
+ `app_file`: _string_
34
+ Path to your main application file (which contains either `gradio` or `streamlit` Python code).
35
+ Path is relative to the root of the repository.
36
+
37
+ `pinned`: _boolean_
38
+ Whether the Space stays on top of your list.
app.py ADDED
@@ -0,0 +1,204 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from PIL import Image
3
+ import torch
4
+ import gradio as gr
5
+ import torch
6
+ torch.backends.cudnn.benchmark = True
7
+ from torchvision import transforms, utils
8
+ from util import *
9
+ from PIL import Image
10
+ import math
11
+ import random
12
+ import numpy as np
13
+ from torch import nn, autograd, optim
14
+ from torch.nn import functional as F
15
+ from tqdm import tqdm
16
+ import lpips
17
+ from model import *
18
+
19
+
20
+ #from e4e_projection import projection as e4e_projection
21
+
22
+ from copy import deepcopy
23
+ import imageio
24
+
25
+ import os
26
+ import sys
27
+ import numpy as np
28
+ from PIL import Image
29
+ import torch
30
+ import torchvision.transforms as transforms
31
+ from argparse import Namespace
32
+ from e4e.models.psp import pSp
33
+ from util import *
34
+ from huggingface_hub import hf_hub_download
35
+
36
# All tensors and models in this script are placed on the CPU.
device = 'cpu'

# Download the encoder checkpoint from the Hugging Face Hub and read it on the CPU.
model_path_e = hf_hub_download(
    repo_id="akhaliq/JoJoGAN_e4e_ffhq_encode",
    filename="e4e_ffhq_encode.pt",
)
ckpt = torch.load(model_path_e, map_location='cpu')

# The checkpoint carries its own options dict; point it at the downloaded file
# and expose it as attribute-style options for pSp.
opts = ckpt['opts']
opts['checkpoint_path'] = model_path_e
opts = Namespace(**opts)

# Encoder network used by ``projection``; kept in eval mode.
net = pSp(opts, device).eval().to(device)
43
+
44
@torch.no_grad()
def projection(img, name, device='cuda'):
    """Encode ``img`` into a latent code with the module-level ``net``.

    The image is resized and center-cropped to 256x256, converted to a tensor
    and normalized with mean/std 0.5 per channel before being passed through
    ``net``.  The first latent of the resulting batch is saved to ``name``
    (under the key ``'latent'``) and returned.

    Args:
        img: image accepted by the torchvision transforms (presumably a PIL
            image -- confirm against ``align_face``).
        name: destination path handed to ``torch.save``.
        device: device the input tensor is moved to before encoding.

    Returns:
        The first element of the latents returned by ``net``.
    """
    preprocess = transforms.Compose(
        [
            transforms.Resize(256),
            transforms.CenterCrop(256),
            transforms.ToTensor(),
            transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
        ]
    )
    batch = preprocess(img).unsqueeze(0).to(device)
    # The reconstructed images are not needed here, only the latents.
    _, w_plus = net(batch, randomize_noise=False, return_latents=True)
    latent = w_plus[0]
    torch.save({'latent': latent}, name)
    return latent
62
+
63
+
64
+
65
+
66
device = 'cpu'

latent_dim = 512

# Base generator: built once, loaded from the "g_ema" entry of the downloaded
# checkpoint.  Every style-specific generator below starts as a copy of it.
model_path_s = hf_hub_download(repo_id="akhaliq/jojogan-stylegan2-ffhq-config-f", filename="stylegan2-ffhq-config-f.pt")
original_generator = Generator(1024, latent_dim, 8, 2).to(device)
ckpt = torch.load(model_path_s, map_location=lambda storage, loc: storage)
original_generator.load_state_dict(ckpt["g_ema"], strict=False)
mean_latent = original_generator.mean_latent(10000)


def _load_style_generator(repo_id, filename):
    """Return a copy of ``original_generator`` loaded with fine-tuned weights.

    The checkpoint is downloaded from the Hugging Face Hub and its ``"g"``
    entry is loaded non-strictly into a deep copy of the base generator.
    Keeping the checkpoint dict local to this function lets it be released
    as soon as the weights have been copied, instead of staying referenced
    from a module-level global for the lifetime of the process.

    Args:
        repo_id: Hub repository that hosts the checkpoint.
        filename: checkpoint file name inside that repository.
    """
    weights_path = hf_hub_download(repo_id=repo_id, filename=filename)
    state = torch.load(weights_path, map_location=lambda storage, loc: storage)
    generator = deepcopy(original_generator)
    generator.load_state_dict(state["g"], strict=False)
    return generator


transform = transforms.Compose(
    [
        transforms.Resize((1024, 1024)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]
)

# One generator per style offered in the UI; the names are used by ``inference``.
generatorjojo = _load_style_generator("akhaliq/JoJoGAN-jojo", "jojo_preserve_color.pt")
generatordisney = _load_style_generator("akhaliq/jojogan-disney", "disney_preserve_color.pt")
generatorjinx = _load_style_generator("akhaliq/jojo-gan-jinx", "arcane_jinx_preserve_color.pt")
generatorcaitlyn = _load_style_generator("akhaliq/jojogan-arcane", "arcane_caitlyn_preserve_color.pt")
generatoryasuho = _load_style_generator("akhaliq/JoJoGAN-jojo", "jojo_yasuho_preserve_color.pt")
generatorarcanemulti = _load_style_generator("akhaliq/jojogan-arcane", "arcane_multi_preserve_color.pt")
generatorart = _load_style_generator("akhaliq/jojo-gan-art", "art.pt")
generatorspider = _load_style_generator("akhaliq/jojo-gan-spiderverse", "Spiderverse-face-500iters-8face.pt")
generatorsketch = _load_style_generator("akhaliq/jojogan-sketch", "sketch_multi.pt")
159
+
160
def inference(img, model):
    """Stylize an uploaded face image with the generator selected by ``model``.

    The image is written to ``out.jpg``, aligned with ``align_face``, encoded
    with ``projection`` (which also writes ``test.pt``) and decoded by the
    chosen generator.  The result is written to ``filename.jpeg``.

    Args:
        img: image object providing ``save`` (the UI passes a PIL image).
        model: style name from the dropdown; any name not listed below falls
            back to the sketch generator.

    Returns:
        The path ``'filename.jpeg'`` of the written output image.
    """
    img.save('out.jpg')
    aligned_face = align_face('out.jpg')

    my_w = projection(aligned_face, "test.pt", device).unsqueeze(0)

    # Dropdown label -> generator.  'Sketch' (and anything unknown) is the default.
    generators_by_label = {
        'JoJo': generatorjojo,
        'Disney': generatordisney,
        'Jinx': generatorjinx,
        'Caitlyn': generatorcaitlyn,
        'Yasuho': generatoryasuho,
        'Arcane Multi': generatorarcanemulti,
        'Art': generatorart,
        'Spider-Verse': generatorspider,
    }
    chosen_generator = generators_by_label.get(model, generatorsketch)

    with torch.no_grad():
        my_sample = chosen_generator(my_w, input_is_latent=True)

    # Channel-first -> channel-last array for the image writer.
    npimage = my_sample[0].permute(1, 2, 0).detach().numpy()
    imageio.imwrite('filename.jpeg', npimage)
    return 'filename.jpeg'
197
+
198
# Texts shown on the demo page.
title = "JoJoGAN"
description = "Gradio Demo for JoJoGAN: One Shot Face Stylization. To use it, simply upload your image, or click one of the examples to load them. Read more at the links below."

article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2112.11641' target='_blank'>JoJoGAN: One Shot Face Stylization</a>| <a href='https://github.com/mchong6/JoJoGAN' target='_blank'>Github Repo Pytorch</a></p> <center><img src='https://visitor-badge.glitch.me/badge?page_id=akhaliq_jojogan' alt='visitor badge'></center>"

# Each example row is [image path, style name].
examples = [['mona.png', 'Jinx']]

# Build the interface (image + style dropdown in, image file out) and start it.
gr.Interface(
    inference,
    [
        gr.inputs.Image(type="pil"),
        gr.inputs.Dropdown(
            choices=['JoJo', 'Disney', 'Jinx', 'Caitlyn', 'Yasuho', 'Arcane Multi', 'Art', 'Spider-Verse', 'Sketch'],
            type="value",
            default='JoJo',
            label="Model",
        ),
    ],
    gr.outputs.Image(type="file"),
    title=title,
    description=description,
    article=article,
    allow_flagging=False,
    examples=examples,
    allow_screenshot=False,
).launch()
e4e/.gitignore ADDED
@@ -0,0 +1,129 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ pip-wheel-metadata/
24
+ share/python-wheels/
25
+ *.egg-info/
26
+ .installed.cfg
27
+ *.egg
28
+ MANIFEST
29
+
30
+ # PyInstaller
31
+ # Usually these files are written by a python script from a template
32
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
33
+ *.manifest
34
+ *.spec
35
+
36
+ # Installer logs
37
+ pip-log.txt
38
+ pip-delete-this-directory.txt
39
+
40
+ # Unit test / coverage reports
41
+ htmlcov/
42
+ .tox/
43
+ .nox/
44
+ .coverage
45
+ .coverage.*
46
+ .cache
47
+ nosetests.xml
48
+ coverage.xml
49
+ *.cover
50
+ *.py,cover
51
+ .hypothesis/
52
+ .pytest_cache/
53
+
54
+ # Translations
55
+ *.mo
56
+ *.pot
57
+
58
+ # Django stuff:
59
+ *.log
60
+ local_settings.py
61
+ db.sqlite3
62
+ db.sqlite3-journal
63
+
64
+ # Flask stuff:
65
+ instance/
66
+ .webassets-cache
67
+
68
+ # Scrapy stuff:
69
+ .scrapy
70
+
71
+ # Sphinx documentation
72
+ docs/_build/
73
+
74
+ # PyBuilder
75
+ target/
76
+
77
+ # Jupyter Notebook
78
+ .ipynb_checkpoints
79
+
80
+ # IPython
81
+ profile_default/
82
+ ipython_config.py
83
+
84
+ # pyenv
85
+ .python-version
86
+
87
+ # pipenv
88
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
90
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
91
+ # install all needed dependencies.
92
+ #Pipfile.lock
93
+
94
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95
+ __pypackages__/
96
+
97
+ # Celery stuff
98
+ celerybeat-schedule
99
+ celerybeat.pid
100
+
101
+ # SageMath parsed files
102
+ *.sage.py
103
+
104
+ # Environments
105
+ .env
106
+ .venv
107
+ env/
108
+ venv/
109
+ ENV/
110
+ env.bak/
111
+ venv.bak/
112
+
113
+ # Spyder project settings
114
+ .spyderproject
115
+ .spyproject
116
+
117
+ # Rope project settings
118
+ .ropeproject
119
+
120
+ # mkdocs documentation
121
+ /site
122
+
123
+ # mypy
124
+ .mypy_cache/
125
+ .dmypy.json
126
+ dmypy.json
127
+
128
+ # Pyre type checker
129
+ .pyre/
e4e/criteria/__init__.py ADDED
File without changes
e4e/criteria/id_loss.py ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
+ from configs.paths_config import model_paths
4
+ from models.encoders.model_irse import Backbone
5
+
6
+
7
class IDLoss(nn.Module):
    """Similarity loss between embeddings produced by a frozen ``Backbone``.

    Images are cropped, pooled to 112x112 and embedded by ``self.facenet``;
    the loss is ``1 - <embedding(y_hat), embedding(y)>`` averaged over the batch.
    """

    def __init__(self):
        super(IDLoss, self).__init__()
        print('Loading ResNet ArcFace')
        self.facenet = Backbone(input_size=112, num_layers=50, drop_ratio=0.6, mode='ir_se')
        # Deserialize onto the CPU so a checkpoint that was saved from a GPU can
        # still be read on a CPU-only host; load_state_dict copies the values
        # into the module's own parameters afterwards.
        self.facenet.load_state_dict(torch.load(model_paths['ir_se50'], map_location='cpu'))
        self.face_pool = torch.nn.AdaptiveAvgPool2d((112, 112))
        self.facenet.eval()
        # The embedding network is a fixed feature extractor, never trained here.
        for module in [self.facenet, self.face_pool]:
            for param in module.parameters():
                param.requires_grad = False

    def extract_feats(self, x):
        """Crop a fixed window, pool it to 112x112 and return ``facenet``'s output."""
        x = x[:, :, 35:223, 32:220]  # Crop interesting region
        x = self.face_pool(x)
        x_feats = self.facenet(x)
        return x_feats

    def forward(self, y_hat, y, x):
        """Compute the loss for a batch.

        Args:
            y_hat: generated images.
            y: target images (their features are detached).
            x: input images; its batch size sets the number of samples.

        Returns:
            Tuple ``(loss, sim_improvement, id_logs)``: the mean of
            ``1 - diff_target``, the mean of ``diff_target - diff_views`` as a
            Python float, and one dict of the three dot products per sample.

        Raises:
            ZeroDivisionError: if the batch is empty.
        """
        n_samples = x.shape[0]
        x_feats = self.extract_feats(x)
        # Target features act as constants: no gradient flows into them.
        y_feats = self.extract_feats(y).detach()
        y_hat_feats = self.extract_feats(y_hat)
        loss = 0
        sim_improvement = 0
        id_logs = []
        for i in range(n_samples):
            diff_target = y_hat_feats[i].dot(y_feats[i])
            diff_input = y_hat_feats[i].dot(x_feats[i])
            diff_views = y_feats[i].dot(x_feats[i])
            id_logs.append({'diff_target': float(diff_target),
                            'diff_input': float(diff_input),
                            'diff_views': float(diff_views)})
            loss += 1 - diff_target
            sim_improvement += float(diff_target) - float(diff_views)

        return loss / n_samples, sim_improvement / n_samples, id_logs
e4e/criteria/lpips/__init__.py ADDED
File without changes
e4e/criteria/lpips/lpips.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ from criteria.lpips.networks import get_network, LinLayers
5
+ from criteria.lpips.utils import get_state_dict
6
+
7
+
8
class LPIPS(nn.Module):
    r"""Creates a criterion that measures
    Learned Perceptual Image Patch Similarity (LPIPS).

    Arguments:
        net_type (str): the network type to compare the features:
                        'alex' | 'squeeze' | 'vgg'. Default: 'alex'.
        version (str): the version of LPIPS. Default: 0.1.
    """
    def __init__(self, net_type: str = 'alex', version: str = '0.1'):

        assert version in ['0.1'], 'v0.1 is only supported now'

        super(LPIPS, self).__init__()

        # Use the GPU when there is one, otherwise stay on the CPU instead of
        # failing on hosts without CUDA.
        device = "cuda" if torch.cuda.is_available() else "cpu"

        # pretrained network
        self.net = get_network(net_type).to(device)

        # linear layers
        self.lin = LinLayers(self.net.n_channels_list).to(device)
        self.lin.load_state_dict(get_state_dict(net_type, version))

    def forward(self, x: torch.Tensor, y: torch.Tensor):
        """Return the summed per-layer distances divided by the batch size of ``x``."""
        feat_x, feat_y = self.net(x), self.net(y)

        # Squared feature difference per layer, weighted by that layer's linear
        # head and averaged over the two trailing (spatial) dimensions.
        diff = [(fx - fy) ** 2 for fx, fy in zip(feat_x, feat_y)]
        res = [l(d).mean((2, 3), True) for d, l in zip(diff, self.lin)]

        return torch.sum(torch.cat(res, 0)) / x.shape[0]
e4e/criteria/lpips/networks.py ADDED
@@ -0,0 +1,96 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Sequence
2
+
3
+ from itertools import chain
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ from torchvision import models
8
+
9
+ from criteria.lpips.utils import normalize_activation
10
+
11
+
12
def get_network(net_type: str):
    """Instantiate the feature network named by ``net_type``.

    Args:
        net_type: one of ``'alex'``, ``'squeeze'`` or ``'vgg'``.

    Raises:
        NotImplementedError: for any other value.
    """
    if net_type == 'alex':
        return AlexNet()
    if net_type == 'squeeze':
        return SqueezeNet()
    if net_type == 'vgg':
        return VGG16()
    raise NotImplementedError('choose net_type from [alex, squeeze, vgg].')
21
+
22
+
23
class LinLayers(nn.ModuleList):
    """One frozen, bias-free 1x1 convolution (to a single channel) per entry.

    Each element is ``Sequential(Identity, Conv2d(nc, 1, kernel_size=1))`` for
    the corresponding channel count in ``n_channels_list``.
    """

    def __init__(self, n_channels_list: Sequence[int]):
        heads = []
        for nc in n_channels_list:
            heads.append(
                nn.Sequential(
                    nn.Identity(),
                    nn.Conv2d(nc, 1, 1, 1, 0, bias=False),
                )
            )
        super(LinLayers, self).__init__(heads)

        # These weights are loaded from a checkpoint and never trained.
        for weight in self.parameters():
            weight.requires_grad = False
34
+
35
+
36
class BaseNet(nn.Module):
    """Shared base for the feature networks.

    Subclasses are expected to set ``self.layers`` and ``self.target_layers``
    (1-based positions of the layers whose activations are collected).
    """

    def __init__(self):
        super(BaseNet, self).__init__()

        # Per-channel constants used by ``z_score``, shaped (1, 3, 1, 1) so they
        # broadcast over a 4-D input; stored as buffers.
        mean = torch.Tensor([-.030, -.088, -.188])[None, :, None, None]
        std = torch.Tensor([.458, .448, .450])[None, :, None, None]
        self.register_buffer('mean', mean)
        self.register_buffer('std', std)

    def set_requires_grad(self, state: bool):
        """Set ``requires_grad`` to ``state`` on every parameter and buffer."""
        for tensor in chain(self.parameters(), self.buffers()):
            tensor.requires_grad = state

    def z_score(self, x: torch.Tensor):
        """Shift and scale ``x`` with the registered ``mean`` and ``std``."""
        return (x - self.mean) / self.std

    def forward(self, x: torch.Tensor):
        """Run ``self.layers`` in order and return the normalized target activations."""
        x = self.z_score(x)

        collected = []
        wanted = len(self.target_layers)
        for position, layer in enumerate(self.layers._modules.values(), 1):
            x = layer(x)
            if position in self.target_layers:
                collected.append(normalize_activation(x))
            # Nothing after the last requested layer is needed.
            if len(collected) == wanted:
                break
        return collected
64
+
65
+
66
class SqueezeNet(BaseNet):
    """Feature network over the ``features`` of torchvision's ``squeezenet1_1``."""

    def __init__(self):
        super().__init__()

        # ``True`` is passed as the first positional argument
        # (presumably "pretrained" -- confirm for the installed torchvision).
        backbone = models.squeezenet1_1(True)
        self.layers = backbone.features
        self.target_layers = [2, 5, 8, 10, 11, 12, 13]
        self.n_channels_list = [64, 128, 256, 384, 384, 512, 512]

        self.set_requires_grad(False)
75
+
76
+
77
class AlexNet(BaseNet):
    """Feature network over the ``features`` of torchvision's ``alexnet``."""

    def __init__(self):
        super().__init__()

        # ``True`` is passed as the first positional argument
        # (presumably "pretrained" -- confirm for the installed torchvision).
        backbone = models.alexnet(True)
        self.layers = backbone.features
        self.target_layers = [2, 5, 8, 10, 12]
        self.n_channels_list = [64, 192, 384, 256, 256]

        self.set_requires_grad(False)
86
+
87
+
88
class VGG16(BaseNet):
    """Feature network over the ``features`` of torchvision's ``vgg16``."""

    def __init__(self):
        super().__init__()

        # ``True`` is passed as the first positional argument
        # (presumably "pretrained" -- confirm for the installed torchvision).
        backbone = models.vgg16(True)
        self.layers = backbone.features
        self.target_layers = [4, 9, 16, 23, 30]
        self.n_channels_list = [64, 128, 256, 512, 512]

        self.set_requires_grad(False)
e4e/criteria/lpips/utils.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from collections import OrderedDict
2
+
3
+ import torch
4
+
5
+
6
def normalize_activation(x, eps=1e-10):
    """Divide ``x`` by its L2 norm along dimension 1.

    ``eps`` is added to the norm so an all-zero vector does not divide by zero.
    """
    channel_norm = x.pow(2).sum(dim=1, keepdim=True).sqrt()
    return x / (channel_norm + eps)
9
+
10
+
11
def get_state_dict(net_type: str = 'alex', version: str = '0.1'):
    """Download the LPIPS linear-layer weights and return them with renamed keys.

    The weights are fetched from the richzhang/PerceptualSimilarity repository
    and every key has the substrings ``'lin'`` and ``'model.'`` removed.

    Args:
        net_type: network name used in the weight file name.
        version: LPIPS version used in the weight directory name.

    Returns:
        An ``OrderedDict`` mapping the renamed keys to the downloaded tensors.
    """
    url = ('https://raw.githubusercontent.com/richzhang/PerceptualSimilarity/'
           f'master/lpips/weights/v{version}/{net_type}.pth')

    # Map onto the CPU only when no GPU is available.
    if torch.cuda.is_available():
        location = None
    else:
        location = torch.device('cpu')
    downloaded = torch.hub.load_state_dict_from_url(
        url, progress=True, map_location=location
    )

    renamed = OrderedDict()
    for original_key, tensor in downloaded.items():
        renamed[original_key.replace('lin', '').replace('model.', '')] = tensor
    return renamed
e4e/criteria/moco_loss.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
+ import torch.nn.functional as F
4
+
5
+ from configs.paths_config import model_paths
6
+
7
+
8
class MocoLoss(nn.Module):
    """Feature-similarity loss based on a frozen ResNet-50 loaded from a MoCo checkpoint.

    The loss is ``1 - <feat(y_hat), feat(y)>`` averaged over the batch, where
    ``feat`` is the L2-normalized output of the network without its final layer.
    """

    def __init__(self, opts):
        # ``opts`` is accepted for interface compatibility; it is not read here.
        super(MocoLoss, self).__init__()
        print("Loading MOCO model from path: {}".format(model_paths["moco"]))
        self.model = self.__load_model()
        self.model.eval()
        for param in self.model.parameters():
            param.requires_grad = False

    @staticmethod
    def __load_model():
        """Build a ResNet-50, load the ``encoder_q`` weights and drop the output layer."""
        import torchvision.models as models
        model = models.__dict__["resnet50"]()
        # freeze all layers but the last fc
        for name, param in model.named_parameters():
            if name not in ['fc.weight', 'fc.bias']:
                param.requires_grad = False
        checkpoint = torch.load(model_paths['moco'], map_location="cpu")
        state_dict = checkpoint['state_dict']
        # rename moco pre-trained keys
        for k in list(state_dict.keys()):
            # retain only encoder_q up to before the embedding layer
            if k.startswith('module.encoder_q') and not k.startswith('module.encoder_q.fc'):
                # remove prefix
                state_dict[k[len("module.encoder_q."):]] = state_dict[k]
            # delete renamed or unused k
            del state_dict[k]
        msg = model.load_state_dict(state_dict, strict=False)
        assert set(msg.missing_keys) == {"fc.weight", "fc.bias"}
        # remove output layer
        model = nn.Sequential(*list(model.children())[:-1])
        # Use the GPU when there is one instead of failing on hosts without CUDA.
        if torch.cuda.is_available():
            model = model.cuda()
        return model

    def extract_feats(self, x):
        """Return one L2-normalized feature row per image in ``x``."""
        x = F.interpolate(x, size=224)
        x_feats = self.model(x)
        x_feats = nn.functional.normalize(x_feats, dim=1)
        # Flatten everything after the batch axis.  A plain squeeze() would also
        # drop the batch axis for a batch of one, and the per-sample dot
        # products in forward() would then fail on 0-d tensors.
        x_feats = x_feats.flatten(start_dim=1)
        return x_feats

    def forward(self, y_hat, y, x):
        """Compute the loss for a batch.

        Args:
            y_hat: generated images.
            y: target images (their features are detached).
            x: input images; its batch size sets the number of samples.

        Returns:
            Tuple ``(loss, sim_improvement, sim_logs)``: the mean of
            ``1 - diff_target``, the mean of ``diff_target - diff_views`` as a
            Python float, and one dict of the three dot products per sample.

        Raises:
            ZeroDivisionError: if the batch is empty.
        """
        n_samples = x.shape[0]
        x_feats = self.extract_feats(x)
        y_feats = self.extract_feats(y).detach()
        y_hat_feats = self.extract_feats(y_hat)
        loss = 0
        sim_improvement = 0
        sim_logs = []
        for i in range(n_samples):
            diff_target = y_hat_feats[i].dot(y_feats[i])
            diff_input = y_hat_feats[i].dot(x_feats[i])
            diff_views = y_feats[i].dot(x_feats[i])
            sim_logs.append({'diff_target': float(diff_target),
                             'diff_input': float(diff_input),
                             'diff_views': float(diff_views)})
            loss += 1 - diff_target
            sim_improvement += float(diff_target) - float(diff_views)

        return loss / n_samples, sim_improvement / n_samples, sim_logs
e4e/criteria/w_norm.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
+
4
+
5
class WNormLoss(nn.Module):
    """Mean over the batch of the L2 norm of each latent.

    With ``start_from_latent_avg`` set, ``latent_avg`` is subtracted first.
    """

    def __init__(self, start_from_latent_avg=True):
        super().__init__()
        self.start_from_latent_avg = start_from_latent_avg

    def forward(self, latent, latent_avg=None):
        """Return ``sum_i ||latent_i||_2 / batch_size`` with the norm taken over dims 1 and 2."""
        offset = latent - latent_avg if self.start_from_latent_avg else latent
        per_sample_norm = offset.norm(2, dim=(1, 2))
        return torch.sum(per_sample_norm) / offset.shape[0]
e4e/datasets/__init__.py ADDED
File without changes
e4e/datasets/gt_res_dataset.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/python
2
+ # encoding: utf-8
3
+ import os
4
+ from torch.utils.data import Dataset
5
+ from PIL import Image
6
+ import torch
7
+
8
class GTResDataset(Dataset):
    """Pairs every ``.jpg``/``.png`` file in ``root_path`` with a ``.jpg`` of the same stem in ``gt_dir``."""

    def __init__(self, root_path, gt_dir=None, transform=None, transform_train=None):
        """Collect the (image, ground-truth) path pairs.

        Args:
            root_path: directory that is listed for input images.
            gt_dir: directory holding the ground-truth ``.jpg`` files; required
                as soon as ``root_path`` contains at least one image.
            transform: optional callable applied to both images of a pair.
            transform_train: stored on the instance; not used by this class.
        """
        self.pairs = []
        # Sorted so the index -> pair mapping does not depend on listing order.
        for f in sorted(os.listdir(root_path)):
            if not f.endswith((".jpg", ".png")):
                continue
            # Only the file's own extension is swapped; a ".png" elsewhere in
            # the name or in gt_dir is left untouched.
            gt_name = f[:-len(".png")] + ".jpg" if f.endswith(".png") else f
            image_path = os.path.join(root_path, f)
            gt_path = os.path.join(gt_dir, gt_name)
            self.pairs.append([image_path, gt_path, None])
        self.transform = transform
        self.transform_train = transform_train

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, index):
        """Open both images of pair ``index`` as RGB and apply ``transform`` if set."""
        from_path, to_path, _ = self.pairs[index]
        from_im = Image.open(from_path).convert('RGB')
        to_im = Image.open(to_path).convert('RGB')

        if self.transform:
            to_im = self.transform(to_im)
            from_im = self.transform(from_im)

        return from_im, to_im
e4e/datasets/images_dataset.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torch.utils.data import Dataset
2
+ from PIL import Image
3
+ from utils import data_utils
4
+
5
+
6
class ImagesDataset(Dataset):
    """Index-aligned (source, target) image pairs from two sorted path lists."""

    def __init__(self, source_root, target_root, opts, target_transform=None, source_transform=None):
        self.opts = opts
        self.source_transform = source_transform
        self.target_transform = target_transform
        # Both listings are sorted so position i in one matches position i in the other.
        self.source_paths = sorted(data_utils.make_dataset(source_root))
        self.target_paths = sorted(data_utils.make_dataset(target_root))

    def __len__(self):
        # The length follows the source list.
        return len(self.source_paths)

    def __getitem__(self, index):
        """Return ``(from_im, to_im)`` for ``index``, both opened as RGB.

        Without a source transform the (possibly transformed) target image is
        returned in both positions.
        """
        source_img = Image.open(self.source_paths[index]).convert('RGB')
        target_img = Image.open(self.target_paths[index]).convert('RGB')

        if self.target_transform:
            target_img = self.target_transform(target_img)

        if not self.source_transform:
            return target_img, target_img
        return self.source_transform(source_img), target_img
e4e/datasets/inference_dataset.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torch.utils.data import Dataset
2
+ from PIL import Image
3
+ from utils import data_utils
4
+
5
+
6
class InferenceDataset(Dataset):
    """Images listed under ``root``, loaded one at a time in sorted path order."""

    def __init__(self, root, opts, transform=None, preprocess=None):
        self.opts = opts
        self.preprocess = preprocess
        self.transform = transform
        self.paths = sorted(data_utils.make_dataset(root))

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, index):
        """Load item ``index`` via ``preprocess`` (or as an RGB image) and apply ``transform`` if set."""
        path = self.paths[index]
        if self.preprocess is None:
            image = Image.open(path).convert('RGB')
        else:
            # A custom loader receives the path itself.
            image = self.preprocess(path)
        return self.transform(image) if self.transform else image
e4e/editings/ganspace.py ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+
4
def edit(latents, pca, edit_directions):
    """Apply GANSpace edits to a batch of latent codes.

    :param latents: iterable of per-sample latent codes; each code is indexed by layer on dim 0.
    :param pca: GANSpace checkpoint mapping holding 'mean', 'comp' and 'std' tensors.
    :param edit_directions: iterable of ``(pca_idx, start, end, strength)`` tuples; layers
        ``start:end`` of a code are moved along principal component ``pca_idx``.
    :return: stacked tensor with one edited code per (latent, direction) pair, kept on the
        device of the input latents.
    """
    edit_latents = []
    for latent in latents:
        for pca_idx, start, end, strength in edit_directions:
            delta = get_delta(pca, latent, pca_idx, strength)
            # Allocate on the latent's own device instead of assuming CUDA.
            delta_padded = torch.zeros(latent.shape, device=latent.device)
            delta_padded[start:end] += delta.repeat(end - start, 1)
            edit_latents.append(latent + delta_padded)
    return torch.stack(edit_latents)


def get_delta(pca, latent, idx, strength):
    """Return the offset that moves ``latent`` to ``strength`` along PCA component ``idx``.

    :param pca: GANSpace checkpoint mapping holding 'mean', 'comp' and 'std' tensors.
    :param latent: a single latent code; only its first row is projected onto the component.
    :param idx: index of the principal component.
    :param strength: target coordinate along the component, in units of its standard deviation.
    """
    # pca: ganspace checkpoint. latent: (16, 512) w+  (shape taken from the original note)
    device = latent.device
    w_centered = latent - pca['mean'].to(device)
    lat_comp = pca['comp'].to(device)
    lat_std = pca['std'].to(device)
    # Current coordinate of the code along the component, normalised by its std.
    w_coord = torch.sum(w_centered[0].reshape(-1) * lat_comp[idx].reshape(-1)) / lat_std[idx]
    delta = (strength - w_coord) * lat_comp[idx] * lat_std[idx]
    return delta
e4e/editings/ganspace_pca/cars_pca.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a5c3bae61ecd85de077fbbf103f5f30cf4b7676fe23a8508166eaf2ce73c8392
3
+ size 167562
e4e/editings/ganspace_pca/ffhq_pca.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4d7f9df1c96180d9026b9cb8d04753579fbf385f321a9d0e263641601c5e5d36
3
+ size 167562
e4e/editings/interfacegan_directions/age.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:50074516b1629707d89b5e43d6b8abd1792212fa3b961a87a11323d6a5222ae0
3
+ size 2808
e4e/editings/interfacegan_directions/pose.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:736e0eacc8488fa0b020a2e7bd256b957284c364191dfea693705e5d06d43e7d
3
+ size 37624
e4e/editings/interfacegan_directions/smile.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:817a7e732b59dee9eba862bec8bd7e8373568443bc9f9731a21cf9b0356f0653
3
+ size 2808
e4e/editings/latent_editor.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import sys
3
+ sys.path.append(".")
4
+ sys.path.append("..")
5
+ from editings import ganspace, sefa
6
+ from utils.common import tensor2im
7
+
8
+
9
class LatentEditor(object):
    """Applies latent-space edits and renders the edited codes with a StyleGAN generator."""

    def __init__(self, stylegan_generator, is_cars=False):
        """
        :param stylegan_generator: generator called as
            ``generator([latents], randomize_noise=False, input_is_latent=True)``.
        :param is_cars: when True the rendered images are cropped from 512x512 to 384x512.
        """
        self.generator = stylegan_generator
        self.is_cars = is_cars  # Since the cars StyleGAN output is 384x512, there is a need to crop the 512x512 output.

    def apply_ganspace(self, latent, ganspace_pca, edit_directions):
        """Edit ``latent`` with GANSpace directions and return the rendered image."""
        edit_latents = ganspace.edit(latent, ganspace_pca, edit_directions)
        return self._latents_to_image(edit_latents)

    def apply_interfacegan(self, latent, direction, factor=1, factor_range=None):
        """Move ``latent`` along ``direction`` and return the rendered image.

        :param factor: single editing strength, used when ``factor_range`` is None.
        :param factor_range: arguments for ``range``; one edit per integer factor, for example (-5, 5).
        """
        if factor_range is not None:  # Apply a range of editing factors. for example, (-5, 5)
            edit_latents = torch.cat([latent + f * direction for f in range(*factor_range)])
        else:
            edit_latents = latent + factor * direction
        return self._latents_to_image(edit_latents)

    def apply_sefa(self, latent, indices=None, **kwargs):
        """Edit ``latent`` with SeFa and return the rendered image.

        :param indices: layer indices to factorize; defaults to ``[2, 3, 4, 5]``.
        :param kwargs: forwarded unchanged to ``sefa.edit``.
        """
        if indices is None:
            # Fresh list per call instead of a shared mutable default argument.
            indices = [2, 3, 4, 5]
        edit_latents = sefa.edit(self.generator, latent, indices, **kwargs)
        return self._latents_to_image(edit_latents)

    # Currently, in order to apply StyleFlow editings, one should run inference,
    # save the latent codes and load them from the official StyleFlow repository.
    # def apply_styleflow(self):
    #     pass

    def _latents_to_image(self, latents):
        """Render ``latents`` and return all images concatenated side by side as one image."""
        with torch.no_grad():
            images, _ = self.generator([latents], randomize_noise=False, input_is_latent=True)
            if self.is_cars:
                images = images[:, :, 64:448, :]  # 512x512 -> 384x512
        # Concatenate along the width axis of the per-image (C, H, W) tensors.
        horizontal_concat_image = torch.cat(list(images), 2)
        final_image = tensor2im(horizontal_concat_image)
        return final_image
e4e/editings/sefa.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+ from tqdm import tqdm
4
+
5
+
6
def edit(generator, latents, indices, semantics=1, start_distance=-15.0, end_distance=15.0, num_samples=1, step=11):
    """Produce SeFa-edited copies of ``latents`` along the leading semantic directions.

    :param generator: generator whose modulation weights are factorized by ``factorize_weight``.
    :param latents: tensor of latent codes, indexed (sample, layer, channel).
    :param indices: layer indices passed to ``factorize_weight``.
    :param semantics: number of semantic directions to traverse.
    :param start_distance: first offset applied along a direction.
    :param end_distance: last offset applied along a direction.
    :param num_samples: number of leading samples of ``latents`` to edit.
    :param step: number of evenly spaced offsets between the two distances.
    :return: concatenation of all edited codes, ordered by semantic, sample, then distance,
        placed on the device of ``latents``.
    """
    layers, boundaries, _ = factorize_weight(generator, indices)
    codes = latents.detach().cpu().numpy()  # (1,18,512) per the original note

    distances = np.linspace(start_distance, end_distance, step)
    # Return the codes on the caller's device rather than forcing CUDA.
    device = latents.device

    edited_latents = []
    for sem_id in tqdm(range(semantics), desc='Semantic ', leave=False):
        boundary = boundaries[sem_id:sem_id + 1]
        for sam_id in tqdm(range(num_samples), desc='Sample ', leave=False):
            code = codes[sam_id:sam_id + 1]
            for d in distances:
                temp_code = code.copy()
                temp_code[:, layers, :] += boundary * d
                edited_latents.append(torch.from_numpy(temp_code).float().to(device))
    return torch.cat(edited_latents)
26
+
27
+
28
def factorize_weight(g_ema, layers='all'):
    """Eigen-decompose the generator's style-modulation weights (SeFa).

    :param g_ema: generator exposing ``conv1``, ``convs`` and ``num_layers``.
    :param layers: ``'all'`` or an iterable of layer indices, where 0 refers to ``conv1`` and
        ``k > 0`` refers to ``convs[k - 1]``.
    :return: ``(layers, boundaries, eigen_values)``; ``layers`` holds the ``convs`` indices that
        were used and each row of ``boundaries`` is one eigenvector.
    """

    def _modulation_matrix(styled_conv):
        # Transposed modulation weight of one styled convolution, as a numpy array.
        return styled_conv.conv.modulation.weight.T.cpu().detach().numpy()

    use_all = layers == 'all'
    with_first = use_all or 0 in layers

    if use_all:
        layers = list(range(g_ema.num_layers - 1))
    else:
        layers = [l - 1 for l in layers if l != 0]

    weights = [_modulation_matrix(g_ema.conv1)] if with_first else []
    weights.extend(_modulation_matrix(g_ema.convs[idx]) for idx in layers)

    stacked = np.concatenate(weights, axis=1).astype(np.float32)
    # Give every column unit length before forming the Gram matrix.
    stacked = stacked / np.linalg.norm(stacked, axis=0, keepdims=True)
    eigen_values, eigen_vectors = np.linalg.eig(stacked.dot(stacked.T))
    return layers, eigen_vectors.T, eigen_values
e4e/environment/e4e_env.yaml ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: e4e_env
2
+ channels:
3
+ - conda-forge
4
+ - defaults
5
+ dependencies:
6
+ - _libgcc_mutex=0.1=main
7
+ - ca-certificates=2020.4.5.1=hecc5488_0
8
+ - certifi=2020.4.5.1=py36h9f0ad1d_0
9
+ - libedit=3.1.20181209=hc058e9b_0
10
+ - libffi=3.2.1=hd88cf55_4
11
+ - libgcc-ng=9.1.0=hdf63c60_0
12
+ - libstdcxx-ng=9.1.0=hdf63c60_0
13
+ - ncurses=6.2=he6710b0_1
14
+ - ninja=1.10.0=hc9558a2_0
15
+ - openssl=1.1.1g=h516909a_0
16
+ - pip=20.0.2=py36_3
17
+ - python=3.6.7=h0371630_0
18
+ - python_abi=3.6=1_cp36m
19
+ - readline=7.0=h7b6447c_5
20
+ - setuptools=46.4.0=py36_0
21
+ - sqlite=3.31.1=h62c20be_1
22
+ - tk=8.6.8=hbc83047_0
23
+ - wheel=0.34.2=py36_0
24
+ - xz=5.2.5=h7b6447c_0
25
+ - zlib=1.2.11=h7b6447c_3
26
+ - pip:
27
+ - absl-py==0.9.0
28
+ - cachetools==4.1.0
29
+ - chardet==3.0.4
30
+ - cycler==0.10.0
31
+ - decorator==4.4.2
32
+ - future==0.18.2
33
+ - google-auth==1.15.0
34
+ - google-auth-oauthlib==0.4.1
35
+ - grpcio==1.29.0
36
+ - idna==2.9
37
+ - imageio==2.8.0
38
+ - importlib-metadata==1.6.0
39
+ - kiwisolver==1.2.0
40
+ - markdown==3.2.2
41
+ - matplotlib==3.2.1
42
+ - mxnet==1.6.0
43
+ - networkx==2.4
44
+ - numpy==1.18.4
45
+ - oauthlib==3.1.0
46
+ - opencv-python==4.2.0.34
47
+ - pillow==7.1.2
48
+ - protobuf==3.12.1
49
+ - pyasn1==0.4.8
50
+ - pyasn1-modules==0.2.8
51
+ - pyparsing==2.4.7
52
+ - python-dateutil==2.8.1
53
+ - pytorch-lightning==0.7.1
54
+ - pywavelets==1.1.1
55
+ - requests==2.23.0
56
+ - requests-oauthlib==1.3.0
57
+ - rsa==4.0
58
+ - scikit-image==0.17.2
59
+ - scipy==1.4.1
60
+ - six==1.15.0
61
+ - tensorboard==2.2.1
62
+ - tensorboard-plugin-wit==1.6.0.post3
63
+ - tensorboardx==1.9
64
+ - tifffile==2020.5.25
65
+ - torch==1.6.0
66
+ - torchvision==0.7.1
67
+ - tqdm==4.46.0
68
+ - urllib3==1.25.9
69
+ - werkzeug==1.0.1
70
+ - zipp==3.1.0
71
+ - pyaml
72
+ prefix: ~/anaconda3/envs/e4e_env
73
+
e4e/metrics/LEC.py ADDED
@@ -0,0 +1,134 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ import argparse
3
+ import torch
4
+ import numpy as np
5
+ from torch.utils.data import DataLoader
6
+
7
+ sys.path.append(".")
8
+ sys.path.append("..")
9
+
10
+ from configs import data_configs
11
+ from datasets.images_dataset import ImagesDataset
12
+ from utils.model_utils import setup_model
13
+
14
+
15
class LEC:
    def __init__(self, net, is_cars=False):
        """
        Latent Editing Consistency metric as proposed in the main paper.
        :param net: e4e model loaded over the pSp framework.
        :param is_cars: An indication as to whether or not to crop the middle of the StyleGAN's output images.
        """
        self.net = net
        self.is_cars = is_cars

    def _device(self, fallback):
        """
        Resolve the device the inputs should be moved to.
        :param fallback: Device to use when the network exposes no parameters.
        :return: The device of the network's first parameter, or `fallback`.
        """
        try:
            return next(self.net.parameters()).device
        except (AttributeError, StopIteration, TypeError):
            return fallback

    def _encode(self, images):
        """
        Encodes the given images into StyleGAN's latent space.
        :param images: Tensor of shape NxCxHxW representing the images to be encoded.
        :return: Tensor of shape NxKx512 representing the latent space embeddings of the given image (in W(K, *) space).
        """
        codes = self.net.encoder(images)
        assert codes.ndim == 3, f"Invalid latent codes shape, should be NxKx512 but is {codes.shape}"
        # normalize with respect to the center of an average face
        if self.net.opts.start_from_latent_avg:
            codes = codes + self.net.latent_avg.repeat(codes.shape[0], 1, 1)
        return codes

    def _generate(self, codes):
        """
        Generate the StyleGAN2 images of the given codes
        :param codes: Tensor of shape NxKx512 representing the StyleGAN's latent codes (in W(K, *) space).
        :return: Tensor of shape NxCxHxW representing the generated images.
        """
        images, _ = self.net.decoder([codes], input_is_latent=True, randomize_noise=False, return_latents=True)
        images = self.net.face_pool(images)
        if self.is_cars:
            images = images[:, :, 32:224, :]
        return images

    @staticmethod
    def _filter_outliers(arr):
        """
        Keep only the values lying between the 1st and 99th percentile (both inclusive).
        :param arr: Sequence of numeric values.
        :return: 1-D numpy array with the retained values.
        """
        arr = np.array(arr)

        def percentile(q, rule):
            # Newer NumPy names the keyword `method`; older releases only accept `interpolation`.
            try:
                return np.percentile(arr, q, method=rule)
            except TypeError:
                return np.percentile(arr, q, interpolation=rule)

        lo = percentile(1, "lower")
        hi = percentile(99, "higher")
        return np.extract(
            np.logical_and(lo <= arr, arr <= hi), arr
        )

    def calculate_metric(self, data_loader, edit_function, inverse_edit_function):
        """
        Calculate the LEC metric score.
        :param data_loader: An iterable that returns a tuple of (images, _), similar to the training data loader.
        :param edit_function: A function that receives latent codes and performs a semantically meaningful edit in the
        latent space.
        :param inverse_edit_function: A function that receives latent codes and performs the inverse edit of the
        `edit_function` parameter.
        :return: The LEC metric score.
        """
        distances = []
        device = None
        with torch.no_grad():
            for batch in data_loader:
                x, _ = batch
                if device is None:
                    # Follow the network's device instead of relying on a module-level global.
                    device = self._device(x.device)
                inputs = x.to(device).float()

                codes = self._encode(inputs)
                edited_codes = edit_function(codes)
                edited_image = self._generate(edited_codes)
                edited_image_inversion_codes = self._encode(edited_image)
                inverse_edit_codes = inverse_edit_function(edited_image_inversion_codes)

                dist = (codes - inverse_edit_codes).norm(2, dim=(1, 2)).mean()
                distances.append(dist.to("cpu").numpy())

        distances = self._filter_outliers(distances)
        return distances.mean()
87
+
88
+
89
if __name__ == "__main__":
    # Kept as a module-level name because code outside this block may refer to it.
    device = "cuda"

    # Command-line interface.
    parser = argparse.ArgumentParser(description="LEC metric calculator")
    parser.add_argument("--batch", type=int, default=8, help="batch size for the models")
    parser.add_argument("--images_dir", type=str, default=None,
                        help="Path to the images directory on which we calculate the LEC score")
    parser.add_argument("ckpt", metavar="CHECKPOINT", help="path to the model checkpoints")
    args = parser.parse_args()
    print(args)

    # Model and dataset configuration come from the checkpoint's stored options.
    net, opts = setup_model(args.ckpt, device)
    dataset_args = data_configs.DATASETS[opts.dataset_type]
    transforms_dict = dataset_args['transforms'](opts).get_transforms()

    if args.images_dir is not None:
        images_directory = args.images_dir
    else:
        images_directory = dataset_args['test_source_root']

    # Source and target share one directory: only the images themselves are needed.
    test_dataset = ImagesDataset(source_root=images_directory,
                                 target_root=images_directory,
                                 source_transform=transforms_dict['transform_source'],
                                 target_transform=transforms_dict['transform_test'],
                                 opts=opts)

    data_loader = DataLoader(test_dataset,
                             batch_size=args.batch,
                             shuffle=False,
                             num_workers=2,
                             drop_last=True)

    print(f'dataset length: {len(test_dataset)}')

    # In the following example, we are using an InterfaceGAN based editing to calculate the LEC metric.
    # Change the provided example according to your domain and needs.
    direction = torch.load('../editings/interfacegan_directions/age.pt').to(device)

    def edit_func_example(codes):
        """Shift the codes three steps along the loaded direction."""
        return codes + 3 * direction

    def inverse_edit_func_example(codes):
        """Undo `edit_func_example`."""
        return codes - 3 * direction

    lec = LEC(net, is_cars='car' in opts.dataset_type)
    result = lec.calculate_metric(data_loader, edit_func_example, inverse_edit_func_example)
    print(f"LEC: {result}")
e4e/models/__init__.py ADDED
File without changes
e4e/models/discriminator.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torch import nn
2
+
3
+
4
class LatentCodesDiscriminator(nn.Module):
    """MLP discriminator that maps a latent code to a single score."""

    def __init__(self, style_dim, n_mlp):
        """
        :param style_dim: dimensionality of the latent codes fed to the network.
        :param n_mlp: total number of linear layers (``n_mlp - 1`` hidden layers plus the output layer).
        """
        super().__init__()

        self.style_dim = style_dim

        layers = []
        for _ in range(n_mlp - 1):
            layers.append(
                nn.Linear(style_dim, style_dim)
            )
            layers.append(nn.LeakyReLU(0.2))
        # The output layer must consume `style_dim` features; a fixed 512 only worked when style_dim == 512.
        layers.append(nn.Linear(style_dim, 1))
        self.mlp = nn.Sequential(*layers)

    def forward(self, w):
        """Return the discriminator score for the latent codes ``w``."""
        return self.mlp(w)
e4e/models/encoders/__init__.py ADDED
File without changes
e4e/models/encoders/helpers.py ADDED
@@ -0,0 +1,140 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from collections import namedtuple
2
+ import torch
3
+ import torch.nn.functional as F
4
+ from torch.nn import Conv2d, BatchNorm2d, PReLU, ReLU, Sigmoid, MaxPool2d, AdaptiveAvgPool2d, Sequential, Module
5
+
6
+ """
7
+ ArcFace implementation from [TreB1eN](https://github.com/TreB1eN/InsightFace_Pytorch)
8
+ """
9
+
10
+
11
class Flatten(Module):
    """Collapse every dimension after the first into a single one."""

    def forward(self, input):
        # reshape (unlike view) also accepts non-contiguous tensors and still
        # returns a view whenever the memory layout allows it.
        return input.reshape(input.size(0), -1)
14
+
15
+
16
def l2_norm(input, axis=1):
    """Scale ``input`` so that its L2 norm along ``axis`` equals one."""
    length = input.norm(p=2, dim=axis, keepdim=True)
    return input / length
20
+
21
+
22
class Bottleneck(namedtuple('Block', ['in_channel', 'depth', 'stride'])):
    """Immutable description of one ResNet unit: input channels, output depth and stride."""
24
+
25
+
26
def get_block(in_channel, depth, num_units, stride=2):
    """Describe one stage: a strided unit that changes the channel count, then ``num_units - 1`` stride-1 units."""
    first = Bottleneck(in_channel, depth, stride)
    rest = [Bottleneck(depth, depth, 1)] * (num_units - 1)
    return [first] + rest
28
+
29
+
30
def get_blocks(num_layers):
    """Return the four-stage block layout of the IR network with ``num_layers`` layers.

    :param num_layers: one of 50, 100 or 152.
    :raises ValueError: for any other value.
    """
    # Number of units per stage for every supported depth.
    units_per_stage = {
        50: (3, 4, 14, 3),
        100: (3, 13, 30, 3),
        152: (3, 8, 36, 3),
    }
    if num_layers not in units_per_stage:
        raise ValueError("Invalid number of layers: {}. Must be one of [50, 100, 152]".format(num_layers))

    # (in_channel, depth) of each stage; identical for all depths.
    stage_channels = ((64, 64), (64, 128), (128, 256), (256, 512))
    return [
        get_block(in_channel=in_channel, depth=depth, num_units=num_units)
        for (in_channel, depth), num_units in zip(stage_channels, units_per_stage[num_layers])
    ]
55
+
56
+
57
class SEModule(Module):
    """Squeeze-and-excitation gate: rescales each channel by a learned factor in (0, 1)."""

    def __init__(self, channels, reduction):
        """
        :param channels: number of input (and output) channels.
        :param reduction: divisor applied to ``channels`` for the hidden 1x1 convolution.
        """
        super().__init__()
        squeezed = channels // reduction
        self.avg_pool = AdaptiveAvgPool2d(1)
        self.fc1 = Conv2d(channels, squeezed, kernel_size=1, padding=0, bias=False)
        self.relu = ReLU(inplace=True)
        self.fc2 = Conv2d(squeezed, channels, kernel_size=1, padding=0, bias=False)
        self.sigmoid = Sigmoid()

    def forward(self, x):
        """Return ``x`` multiplied channel-wise by its excitation gate."""
        gate = self.avg_pool(x)
        gate = self.relu(self.fc1(gate))
        gate = self.sigmoid(self.fc2(gate))
        return x * gate
74
+
75
+
76
class bottleneck_IR(Module):
    """Improved-residual unit: BN-conv-PReLU-conv-BN branch added to a shortcut branch."""

    def __init__(self, in_channel, depth, stride):
        """
        :param in_channel: channels of the incoming feature map.
        :param depth: channels of the produced feature map.
        :param stride: stride of the second 3x3 convolution and of the shortcut.
        """
        super().__init__()
        if in_channel != depth:
            # Channel count changes: project the shortcut with a strided 1x1 convolution.
            self.shortcut_layer = Sequential(
                Conv2d(in_channel, depth, (1, 1), stride, bias=False),
                BatchNorm2d(depth),
            )
        else:
            # Same channel count: a 1x1 max-pool only subsamples spatially.
            self.shortcut_layer = MaxPool2d(1, stride)
        self.res_layer = Sequential(
            BatchNorm2d(in_channel),
            Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False),
            PReLU(depth),
            Conv2d(depth, depth, (3, 3), stride, 1, bias=False),
            BatchNorm2d(depth),
        )

    def forward(self, x):
        """Return residual branch plus shortcut branch."""
        identity = self.shortcut_layer(x)
        return self.res_layer(x) + identity
96
+
97
+
98
class bottleneck_IR_SE(Module):
    """Improved-residual unit whose residual branch ends with a squeeze-and-excitation gate."""

    def __init__(self, in_channel, depth, stride):
        """
        :param in_channel: channels of the incoming feature map.
        :param depth: channels of the produced feature map.
        :param stride: stride of the second 3x3 convolution and of the shortcut.
        """
        super().__init__()
        if in_channel != depth:
            # Channel count changes: project the shortcut with a strided 1x1 convolution.
            self.shortcut_layer = Sequential(
                Conv2d(in_channel, depth, (1, 1), stride, bias=False),
                BatchNorm2d(depth),
            )
        else:
            # Same channel count: a 1x1 max-pool only subsamples spatially.
            self.shortcut_layer = MaxPool2d(1, stride)
        self.res_layer = Sequential(
            BatchNorm2d(in_channel),
            Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False),
            PReLU(depth),
            Conv2d(depth, depth, (3, 3), stride, 1, bias=False),
            BatchNorm2d(depth),
            SEModule(depth, 16),
        )

    def forward(self, x):
        """Return residual branch plus shortcut branch."""
        identity = self.shortcut_layer(x)
        return self.res_layer(x) + identity
121
+
122
+
123
+ def _upsample_add(x, y):
124
+ """Upsample and add two feature maps.
125
+ Args:
126
+ x: (Variable) top feature map to be upsampled.
127
+ y: (Variable) lateral feature map.
128
+ Returns:
129
+ (Variable) added feature map.
130
+ Note in PyTorch, when input size is odd, the upsampled feature map
131
+ with `F.upsample(..., scale_factor=2, mode='nearest')`
132
+ maybe not equal to the lateral feature map size.
133
+ e.g.
134
+ original input size: [N,_,15,15] ->
135
+ conv2d feature map size: [N,_,8,8] ->
136
+ upsampled feature map size: [N,_,16,16]
137
+ So we choose bilinear upsample which supports arbitrary output sizes.
138
+ """
139
+ _, _, H, W = y.size()
140
+ return F.interpolate(x, size=(H, W), mode='bilinear', align_corners=True) + y
e4e/models/encoders/model_irse.py ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torch.nn import Linear, Conv2d, BatchNorm1d, BatchNorm2d, PReLU, Dropout, Sequential, Module
2
+ from e4e.models.encoders.helpers import get_blocks, Flatten, bottleneck_IR, bottleneck_IR_SE, l2_norm
3
+
4
+ """
5
+ Modified Backbone implementation from [TreB1eN](https://github.com/TreB1eN/InsightFace_Pytorch)
6
+ """
7
+
8
+
9
class Backbone(Module):
    """IR / IR-SE face-recognition backbone producing L2-normalised 512-d embeddings."""

    def __init__(self, input_size, num_layers, mode='ir', drop_ratio=0.4, affine=True):
        """
        :param input_size: side length of the input images, 112 or 224.
        :param num_layers: network depth, one of 50, 100 or 152.
        :param mode: 'ir' for plain units, 'ir_se' for squeeze-and-excitation units.
        :param drop_ratio: dropout probability in the output head.
        :param affine: whether the final BatchNorm1d has learnable affine parameters.
        """
        super().__init__()
        assert input_size in [112, 224], "input_size should be 112 or 224"
        assert num_layers in [50, 100, 152], "num_layers should be 50, 100 or 152"
        assert mode in ['ir', 'ir_se'], "mode should be ir or ir_se"
        blocks = get_blocks(num_layers)
        if mode == 'ir':
            unit_module = bottleneck_IR
        elif mode == 'ir_se':
            unit_module = bottleneck_IR_SE

        self.input_layer = Sequential(
            Conv2d(3, 64, (3, 3), 1, 1, bias=False),
            BatchNorm2d(64),
            PReLU(64),
        )

        # Side of the feature map entering the linear head: 7 for 112 inputs, 14 otherwise.
        side = 7 if input_size == 112 else 14
        self.output_layer = Sequential(
            BatchNorm2d(512),
            Dropout(drop_ratio),
            Flatten(),
            Linear(512 * side * side, 512),
            BatchNorm1d(512, affine=affine),
        )

        units = [
            unit_module(spec.in_channel, spec.depth, spec.stride)
            for stage in blocks
            for spec in stage
        ]
        self.body = Sequential(*units)

    def forward(self, x):
        """Return the L2-normalised embedding of the image batch ``x``."""
        features = self.body(self.input_layer(x))
        return l2_norm(self.output_layer(features))
49
+
50
+
51
def IR_50(input_size):
    """Build the 50-layer IR backbone (no SE units, non-affine final batch norm)."""
    return Backbone(input_size, num_layers=50, mode='ir', drop_ratio=0.4, affine=False)
55
+
56
+
57
def IR_101(input_size):
    """Build the IR-101 backbone; it is constructed with ``num_layers=100``."""
    return Backbone(input_size, num_layers=100, mode='ir', drop_ratio=0.4, affine=False)
61
+
62
+
63
def IR_152(input_size):
    """Build the 152-layer IR backbone (no SE units, non-affine final batch norm)."""
    return Backbone(input_size, num_layers=152, mode='ir', drop_ratio=0.4, affine=False)
67
+
68
+
69
def IR_SE_50(input_size):
    """Build the 50-layer IR-SE backbone (SE units, non-affine final batch norm)."""
    return Backbone(input_size, num_layers=50, mode='ir_se', drop_ratio=0.4, affine=False)
73
+
74
+
75
def IR_SE_101(input_size):
    """Build the IR-SE-101 backbone; it is constructed with ``num_layers=100``."""
    return Backbone(input_size, num_layers=100, mode='ir_se', drop_ratio=0.4, affine=False)
79
+
80
+
81
def IR_SE_152(input_size):
    """Build the 152-layer IR-SE backbone (SE units, non-affine final batch norm)."""
    return Backbone(input_size, num_layers=152, mode='ir_se', drop_ratio=0.4, affine=False)
e4e/models/encoders/psp_encoders.py ADDED
@@ -0,0 +1,200 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from enum import Enum
2
+ import math
3
+ import numpy as np
4
+ import torch
5
+ from torch import nn
6
+ from torch.nn import Conv2d, BatchNorm2d, PReLU, Sequential, Module
7
+
8
+ from e4e.models.encoders.helpers import get_blocks, bottleneck_IR, bottleneck_IR_SE, _upsample_add
9
+ from e4e.models.stylegan2.model import EqualLinear
10
+
11
+
12
class ProgressiveStage(Enum):
    """Progressive-training stages.

    The integer value of a stage bounds how many per-layer deltas the encoder
    adds on top of the base code; ``Inference`` is the largest value.
    """

    # Only the base code is produced.
    WTraining = 0
    # One more delta is enabled with each subsequent stage.
    Delta1Training = 1
    Delta2Training = 2
    Delta3Training = 3
    Delta4Training = 4
    Delta5Training = 5
    Delta6Training = 6
    Delta7Training = 7
    Delta8Training = 8
    Delta9Training = 9
    Delta10Training = 10
    Delta11Training = 11
    Delta12Training = 12
    Delta13Training = 13
    Delta14Training = 14
    Delta15Training = 15
    Delta16Training = 16
    Delta17Training = 17
    # Every available delta is used.
    Inference = 18
32
+
33
+
34
class GradualStyleBlock(Module):
    """Map a feature map to one style vector with strided convolutions and a linear layer."""

    def __init__(self, in_c, out_c, spatial):
        """
        :param in_c: channels of the incoming feature map.
        :param out_c: size of the produced style vector.
        :param spatial: side length of the incoming feature map; ``log2(spatial)`` stride-2
            convolutions are stacked (always at least one).
        """
        super().__init__()
        self.out_c = out_c
        self.spatial = spatial
        num_pools = int(np.log2(spatial))

        # The first convolution maps in_c -> out_c, every further one keeps out_c.
        conv_inputs = [in_c] + [out_c] * (num_pools - 1)
        modules = []
        for channels in conv_inputs:
            modules.append(Conv2d(channels, out_c, kernel_size=3, stride=2, padding=1))
            modules.append(nn.LeakyReLU())
        self.convs = nn.Sequential(*modules)
        self.linear = EqualLinear(out_c, out_c, lr_mul=1)

    def forward(self, x):
        """Return the style vectors for the feature maps ``x``, reshaped to ``(-1, out_c)``."""
        flat = self.convs(x).view(-1, self.out_c)
        return self.linear(flat)
56
+
57
+
58
class GradualStyleEncoder(Module):
    """Encoder with one style head per generator layer, fed from three feature-pyramid levels."""

    def __init__(self, num_layers, mode='ir', opts=None):
        """
        :param num_layers: backbone depth, one of 50, 100 or 152.
        :param mode: 'ir' for plain units, 'ir_se' for squeeze-and-excitation units.
        :param opts: options object; ``opts.stylegan_size`` determines the number of style heads.
        """
        super().__init__()
        assert num_layers in [50, 100, 152], 'num_layers should be 50,100, or 152'
        assert mode in ['ir', 'ir_se'], 'mode should be ir or ir_se'
        blocks = get_blocks(num_layers)
        if mode == 'ir':
            unit_module = bottleneck_IR
        elif mode == 'ir_se':
            unit_module = bottleneck_IR_SE
        self.input_layer = Sequential(
            Conv2d(3, 64, (3, 3), 1, 1, bias=False),
            BatchNorm2d(64),
            PReLU(64),
        )
        units = [
            unit_module(spec.in_channel, spec.depth, spec.stride)
            for stage in blocks
            for spec in stage
        ]
        self.body = Sequential(*units)

        self.styles = nn.ModuleList()
        log_size = int(math.log(opts.stylegan_size, 2))
        self.style_count = 2 * log_size - 2
        self.coarse_ind = 3
        self.middle_ind = 7
        for i in range(self.style_count):
            # Coarse heads read the 16x16 map, middle heads 32x32, fine heads 64x64.
            spatial = 16 if i < self.coarse_ind else 32 if i < self.middle_ind else 64
            self.styles.append(GradualStyleBlock(512, 512, spatial))
        self.latlayer1 = nn.Conv2d(256, 512, kernel_size=1, stride=1, padding=0)
        self.latlayer2 = nn.Conv2d(128, 512, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        """Return the style codes of ``x`` stacked along dim 1."""
        x = self.input_layer(x)

        # Tap the backbone after units 6, 20 and 23 (fine, middle and coarse features).
        for i, unit in enumerate(self.body._modules.values()):
            x = unit(x)
            if i == 6:
                c1 = x
            elif i == 20:
                c2 = x
            elif i == 23:
                c3 = x

        latents = [self.styles[j](c3) for j in range(self.coarse_ind)]

        p2 = _upsample_add(c3, self.latlayer1(c2))
        latents += [self.styles[j](p2) for j in range(self.coarse_ind, self.middle_ind)]

        p1 = _upsample_add(p2, self.latlayer2(c1))
        latents += [self.styles[j](p1) for j in range(self.middle_ind, self.style_count)]

        return torch.stack(latents, dim=1)
122
+
123
+
124
class Encoder4Editing(Module):
    """Encoder that predicts one base code plus per-layer deltas, enabled progressively."""

    def __init__(self, num_layers, mode='ir', opts=None):
        """
        :param num_layers: backbone depth, one of 50, 100 or 152.
        :param mode: 'ir' for plain units, 'ir_se' for squeeze-and-excitation units.
        :param opts: options object; ``opts.stylegan_size`` determines the number of style heads.
        """
        super().__init__()
        assert num_layers in [50, 100, 152], 'num_layers should be 50,100, or 152'
        assert mode in ['ir', 'ir_se'], 'mode should be ir or ir_se'
        blocks = get_blocks(num_layers)
        if mode == 'ir':
            unit_module = bottleneck_IR
        elif mode == 'ir_se':
            unit_module = bottleneck_IR_SE
        self.input_layer = Sequential(
            Conv2d(3, 64, (3, 3), 1, 1, bias=False),
            BatchNorm2d(64),
            PReLU(64),
        )
        units = [
            unit_module(spec.in_channel, spec.depth, spec.stride)
            for stage in blocks
            for spec in stage
        ]
        self.body = Sequential(*units)

        self.styles = nn.ModuleList()
        log_size = int(math.log(opts.stylegan_size, 2))
        self.style_count = 2 * log_size - 2
        self.coarse_ind = 3
        self.middle_ind = 7

        for i in range(self.style_count):
            # Coarse heads read the 16x16 map, middle heads 32x32, fine heads 64x64.
            spatial = 16 if i < self.coarse_ind else 32 if i < self.middle_ind else 64
            self.styles.append(GradualStyleBlock(512, 512, spatial))

        self.latlayer1 = nn.Conv2d(256, 512, kernel_size=1, stride=1, padding=0)
        self.latlayer2 = nn.Conv2d(128, 512, kernel_size=1, stride=1, padding=0)

        self.progressive_stage = ProgressiveStage.Inference

    def get_deltas_starting_dimensions(self):
        ''' Get a list of the initial dimension of every delta from which it is applied '''
        return list(range(self.style_count))  # Each dimension has a delta applied to it

    def set_progressive_stage(self, new_stage: ProgressiveStage):
        """Switch the progressive-training stage and report the change on stdout."""
        self.progressive_stage = new_stage
        print('Changed progressive stage to: ', new_stage)

    def forward(self, x):
        """Return the codes of ``x``: the base code repeated per layer plus the enabled deltas."""
        x = self.input_layer(x)

        # Tap the backbone after units 6, 20 and 23 (fine, middle and coarse features).
        for i, unit in enumerate(self.body._modules.values()):
            x = unit(x)
            if i == 6:
                c1 = x
            elif i == 20:
                c2 = x
            elif i == 23:
                c3 = x

        # Infer main W and duplicate it
        base = self.styles[0](c3)
        w = base.repeat(self.style_count, 1, 1).permute(1, 0, 2)

        # The stage value limits how many deltas are inferred.
        last_delta = min(self.progressive_stage.value + 1, self.style_count)
        features = c3
        for i in range(1, last_delta):  # Infer additional deltas
            if i == self.coarse_ind:
                p2 = _upsample_add(c3, self.latlayer1(c2))  # FPN's middle features
                features = p2
            elif i == self.middle_ind:
                p1 = _upsample_add(p2, self.latlayer2(c1))  # FPN's fine features
                features = p1
            w[:, i] += self.styles[i](features)
        return w
e4e/models/latent_codes_pool.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import random
2
+ import torch
3
+
4
+
5
class LatentCodesPool:
    """This class implements latent codes buffer that stores previously generated w latent codes.
    This buffer enables us to update discriminators using a history of generated w's
    rather than the ones produced by the latest encoder.
    """

    def __init__(self, pool_size):
        """Initialize the LatentCodesPool class
        Parameters:
            pool_size (int) -- the size of the latent codes buffer, if pool_size=0, the buffer is never used
        """
        self.pool_size = pool_size
        # Always create the (possibly unused) buffer so the attributes exist for every pool size.
        self.num_ws = 0
        self.ws = []

    def query(self, ws):
        """Return w's from the pool.
        Parameters:
            ws: the latest generated w's from the generator
        Returns w's from the buffer.
        By 50/100, the buffer will return input w's.
        By 50/100, the buffer will return w's previously stored in the buffer,
        and insert the current w's to the buffer.
        """
        if self.pool_size == 0:  # if the buffer size is 0, do nothing
            return ws
        return_ws = []
        for w in ws:  # ws.shape: (batch, 512) or (batch, n_latent, 512)
            if w.ndim == 2:
                i = random.randint(0, len(w) - 1)  # apply a random latent index as a candidate
                w = w[i]
            self.handle_w(w, return_ws)
        return_ws = torch.stack(return_ws, 0)  # collect all the codes and return
        return return_ws

    def handle_w(self, w, return_ws):
        """Route one code through the buffer and append the code to return to ``return_ws``.

        Parameters:
            w: a single latent code
            return_ws (list) -- output list, extended in place with exactly one code
        """
        if self.num_ws < self.pool_size:  # if the buffer is not full; keep inserting current codes to the buffer
            self.num_ws = self.num_ws + 1
            self.ws.append(w)
            return_ws.append(w)
        else:
            p = random.uniform(0, 1)
            if p > 0.5:  # by 50% chance, the buffer will return a previously stored latent code, and insert the current code into the buffer
                random_id = random.randint(0, self.pool_size - 1)  # randint is inclusive
                tmp = self.ws[random_id].clone()
                self.ws[random_id] = w
                return_ws.append(tmp)
            else:  # by another 50% chance, the buffer will return the current code
                return_ws.append(w)
e4e/models/psp.py ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import matplotlib
2
+
3
+ matplotlib.use('Agg')
4
+ import torch
5
+ from torch import nn
6
+ from e4e.models.encoders import psp_encoders
7
+ from e4e.models.stylegan2.model import Generator
8
+ from e4e.configs.paths_config import model_paths
9
+
10
+
11
def get_keys(d, name):
    """Extract the sub-state-dict that lives under the module prefix `name`.

    Parameters:
        d (dict)   -- a state dict, or a checkpoint holding one under 'state_dict'
        name (str) -- module prefix, e.g. 'encoder'

    Returns a new dict with every '<name>.<rest>' key renamed to '<rest>'.
    Only keys that start with `name` followed by a dot are kept, so a sibling
    module such as '<name>_ema' is not mistaken for `name`.
    """
    if 'state_dict' in d:
        d = d['state_dict']
    prefix = name + '.'
    return {k[len(prefix):]: v for k, v in d.items() if k.startswith(prefix)}
16
+
17
+
18
class pSp(nn.Module):
    """Encoder plus StyleGAN2 generator wrapper.

    The encoder chosen by ``opts.encoder_type`` maps inputs to latent codes,
    which the ``Generator`` decoder turns into images. Weights come either
    from a full checkpoint (``opts.checkpoint_path``) or from separate
    pretrained encoder / generator files.
    """

    def __init__(self, opts, device):
        """Build encoder, decoder and pooling layer, then load weights.

        Parameters:
            opts   -- options object; this class reads `encoder_type`, `stylegan_size`,
                      `checkpoint_path`, `stylegan_weights` and `start_from_latent_avg`
            device -- device the average latent is moved to
        """
        super(pSp, self).__init__()
        self.opts = opts
        self.device = device
        # Define architecture
        self.encoder = self.set_encoder()
        self.decoder = Generator(opts.stylegan_size, 512, 8, channel_multiplier=2)
        self.face_pool = torch.nn.AdaptiveAvgPool2d((256, 256))
        # Load weights if needed
        self.load_weights()

    def set_encoder(self):
        """Instantiate the encoder named by ``opts.encoder_type``.

        Raises:
            ValueError: if the encoder type is not recognised.
        """
        if self.opts.encoder_type == 'GradualStyleEncoder':
            encoder = psp_encoders.GradualStyleEncoder(50, 'ir_se', self.opts)
        elif self.opts.encoder_type == 'Encoder4Editing':
            encoder = psp_encoders.Encoder4Editing(50, 'ir_se', self.opts)
        else:
            # ValueError is still an Exception, so existing broad handlers keep working.
            raise ValueError('{} is not a valid encoders'.format(self.opts.encoder_type))
        return encoder

    def load_weights(self):
        """Load encoder/decoder weights and the average latent.

        Every checkpoint is deserialised onto the CPU first (``load_state_dict``
        copies into the modules' own storage), so weights saved from a GPU
        also load on CPU-only hosts.
        """
        if self.opts.checkpoint_path is not None:
            print('Loading e4e over the pSp framework from checkpoint: {}'.format(self.opts.checkpoint_path))
            ckpt = torch.load(self.opts.checkpoint_path, map_location='cpu')
            self.encoder.load_state_dict(get_keys(ckpt, 'encoder'), strict=True)
            self.decoder.load_state_dict(get_keys(ckpt, 'decoder'), strict=True)
            self.__load_latent_avg(ckpt)
        else:
            print('Loading encoders weights from irse50!')
            encoder_ckpt = torch.load(model_paths['ir_se50'], map_location='cpu')
            self.encoder.load_state_dict(encoder_ckpt, strict=False)
            print('Loading decoder weights from pretrained!')
            ckpt = torch.load(self.opts.stylegan_weights, map_location='cpu')
            self.decoder.load_state_dict(ckpt['g_ema'], strict=False)
            self.__load_latent_avg(ckpt, repeat=self.encoder.style_count)

    def forward(self, x, resize=True, latent_mask=None, input_code=False, randomize_noise=True,
                inject_latent=None, return_latents=False, alpha=None):
        """Encode `x` (unless it already is a code) and decode it into images.

        Parameters:
            x               -- encoder input, or latent codes when `input_code` is True
            resize          -- pool the decoder output to 256x256
            latent_mask     -- iterable of latent indices to overwrite
            input_code      -- treat `x` as codes and skip the encoder
            randomize_noise -- forwarded to the decoder
            inject_latent   -- codes written into the masked indices (zeros if None)
            return_latents  -- also return the latents reported by the decoder
            alpha           -- blend factor between `inject_latent` and the current codes

        Returns the images, or ``(images, latents)`` when `return_latents` is True.
        """
        if input_code:
            codes = x
        else:
            codes = self.encoder(x)
            # normalize with respect to the center of an average face
            if self.opts.start_from_latent_avg:
                if codes.ndim == 2:
                    codes = codes + self.latent_avg.repeat(codes.shape[0], 1, 1)[:, 0, :]
                else:
                    codes = codes + self.latent_avg.repeat(codes.shape[0], 1, 1)

        if latent_mask is not None:
            # NOTE: these writes are in place; with input_code=True they modify the caller's tensor.
            for i in latent_mask:
                if inject_latent is not None:
                    if alpha is not None:
                        codes[:, i] = alpha * inject_latent[:, i] + (1 - alpha) * codes[:, i]
                    else:
                        codes[:, i] = inject_latent[:, i]
                else:
                    codes[:, i] = 0

        # Encoder outputs are passed as latents; externally supplied codes go
        # through the decoder's style mapping instead.
        input_is_latent = not input_code
        images, result_latent = self.decoder([codes],
                                             input_is_latent=input_is_latent,
                                             randomize_noise=randomize_noise,
                                             return_latents=return_latents)

        if resize:
            images = self.face_pool(images)

        if return_latents:
            return images, result_latent
        else:
            return images

    def __load_latent_avg(self, ckpt, repeat=None):
        """Store ``ckpt['latent_avg']`` on `self.device` (optionally tiled `repeat` times), or None if absent."""
        if 'latent_avg' in ckpt:
            self.latent_avg = ckpt['latent_avg'].to(self.device)
            if repeat is not None:
                self.latent_avg = self.latent_avg.repeat(repeat, 1)
        else:
            self.latent_avg = None
e4e/models/stylegan2/__init__.py ADDED
File without changes
e4e/models/stylegan2/model.py ADDED
@@ -0,0 +1,678 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import random
3
+ import torch
4
+ from torch import nn
5
+ from torch.nn import functional as F
6
+
7
+ if torch.cuda.is_available():
8
+ from op.fused_act import FusedLeakyReLU, fused_leaky_relu
9
+ from op.upfirdn2d import upfirdn2d
10
+ else:
11
+ from op.fused_act_cpu import FusedLeakyReLU, fused_leaky_relu
12
+ from op.upfirdn2d_cpu import upfirdn2d
13
+
14
+
15
class PixelNorm(nn.Module):
    """Scale the input so its mean square along dim 1 is (approximately) one."""

    def forward(self, input):
        mean_square = input.pow(2).mean(dim=1, keepdim=True)
        # The small constant keeps rsqrt finite for all-zero vectors.
        return input * torch.rsqrt(mean_square + 1e-8)
21
+
22
+
23
def make_kernel(k):
    """Turn FIR taps into a float32 kernel whose entries sum to one.

    A 1-D list of taps is expanded to its separable 2-D outer product;
    inputs that are not 1-D are only normalised.
    """
    kernel = torch.tensor(k, dtype=torch.float32)
    if kernel.ndim == 1:
        kernel = kernel.unsqueeze(0) * kernel.unsqueeze(1)
    return kernel / kernel.sum()
32
+
33
+
34
class Upsample(nn.Module):
    """FIR-filtered upsampling by an integer `factor` via `upfirdn2d`."""

    def __init__(self, kernel, factor=2):
        super().__init__()
        self.factor = factor

        # The kernel is multiplied by factor**2 before being stored.
        fir = make_kernel(kernel) * (factor ** 2)
        self.register_buffer('kernel', fir)

        excess = fir.shape[0] - factor
        self.pad = ((excess + 1) // 2 + factor - 1, excess // 2)

    def forward(self, input):
        return upfirdn2d(input, self.kernel, up=self.factor, down=1, pad=self.pad)
53
+
54
+
55
class Downsample(nn.Module):
    """FIR-filtered downsampling by an integer `factor` via `upfirdn2d`."""

    def __init__(self, kernel, factor=2):
        super().__init__()
        self.factor = factor

        fir = make_kernel(kernel)
        self.register_buffer('kernel', fir)

        excess = fir.shape[0] - factor
        self.pad = ((excess + 1) // 2, excess // 2)

    def forward(self, input):
        return upfirdn2d(input, self.kernel, up=1, down=self.factor, pad=self.pad)
74
+
75
+
76
class Blur(nn.Module):
    """Apply a normalised FIR kernel with explicit padding (no resampling)."""

    def __init__(self, kernel, pad, upsample_factor=1):
        super().__init__()

        fir = make_kernel(kernel)
        if upsample_factor > 1:
            # Rescale the kernel when this blur follows an upsampling step.
            fir = fir * (upsample_factor ** 2)

        self.register_buffer('kernel', fir)
        self.pad = pad

    def forward(self, input):
        return upfirdn2d(input, self.kernel, pad=self.pad)
93
+
94
+
95
class EqualConv2d(nn.Module):
    """2-D convolution whose weight is rescaled by 1/sqrt(fan_in) on every call."""

    def __init__(
        self, in_channel, out_channel, kernel_size, stride=1, padding=0, bias=True
    ):
        super().__init__()

        weight_shape = (out_channel, in_channel, kernel_size, kernel_size)
        self.weight = nn.Parameter(torch.randn(*weight_shape))
        # Runtime scale: weights stay unit-variance, the scale is applied in forward.
        self.scale = 1 / math.sqrt(in_channel * kernel_size ** 2)

        self.stride = stride
        self.padding = padding

        self.bias = nn.Parameter(torch.zeros(out_channel)) if bias else None

    def forward(self, input):
        scaled_weight = self.weight * self.scale
        return F.conv2d(
            input,
            scaled_weight,
            bias=self.bias,
            stride=self.stride,
            padding=self.padding,
        )

    def __repr__(self):
        out_ch, in_ch, ksize = self.weight.shape[:3]
        return (
            f'{self.__class__.__name__}({in_ch}, {out_ch},'
            f' {ksize}, stride={self.stride}, padding={self.padding})'
        )
131
+
132
+
133
class EqualLinear(nn.Module):
    """Linear layer with runtime weight scaling and a learning-rate multiplier.

    Parameters:
        in_dim, out_dim -- feature sizes
        bias            -- create a bias parameter; with False the layer applies no bias
        bias_init       -- initial bias value
        lr_mul          -- multiplier folded into the weight scale and the bias
        activation      -- any truthy value routes the output through `fused_leaky_relu`
    """

    def __init__(
        self, in_dim, out_dim, bias=True, bias_init=0, lr_mul=1, activation=None
    ):
        super().__init__()

        self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul))

        if bias:
            self.bias = nn.Parameter(torch.zeros(out_dim).fill_(bias_init))

        else:
            self.bias = None

        self.activation = activation

        self.scale = (1 / math.sqrt(in_dim)) * lr_mul
        self.lr_mul = lr_mul

    def forward(self, input):
        weight = self.weight * self.scale
        # A layer built with bias=False has no parameter to scale; multiplying
        # None by lr_mul would raise TypeError.
        bias = None if self.bias is None else self.bias * self.lr_mul

        if self.activation:
            out = F.linear(input, weight)
            if bias is None:
                # The fused op is always given a bias tensor; zeros leave the result unchanged.
                bias = out.new_zeros(out.shape[-1])
            return fused_leaky_relu(out, bias)

        return F.linear(input, weight, bias=bias)

    def __repr__(self):
        return (
            f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]})'
        )
168
+
169
+
170
class ScaledLeakyReLU(nn.Module):
    """Leaky ReLU whose output is multiplied by sqrt(2)."""

    def __init__(self, negative_slope=0.2):
        super().__init__()
        self.negative_slope = negative_slope

    def forward(self, input):
        activated = F.leaky_relu(input, negative_slope=self.negative_slope)
        return activated * math.sqrt(2)
180
+
181
+
182
class ModulatedConv2d(nn.Module):
    """Convolution whose weights are modulated per sample by a style vector.

    The style is mapped to one factor per input channel, multiplied into the
    shared weight and (optionally) demodulated. The per-sample convolutions
    are executed as one grouped convolution with `groups=batch`.
    """

    def __init__(
        self,
        in_channel,
        out_channel,
        kernel_size,
        style_dim,
        demodulate=True,
        upsample=False,
        downsample=False,
        blur_kernel=[1, 3, 3, 1],
    ):
        super().__init__()

        self.eps = 1e-8
        self.kernel_size = kernel_size
        self.in_channel = in_channel
        self.out_channel = out_channel
        self.upsample = upsample
        self.downsample = downsample

        taps = len(blur_kernel)

        if upsample:
            factor = 2
            excess = (taps - factor) - (kernel_size - 1)
            up_pad = ((excess + 1) // 2 + factor - 1, excess // 2 + 1)
            self.blur = Blur(blur_kernel, pad=up_pad, upsample_factor=factor)

        if downsample:
            factor = 2
            excess = (taps - factor) + (kernel_size - 1)
            down_pad = ((excess + 1) // 2, excess // 2)
            self.blur = Blur(blur_kernel, pad=down_pad)

        self.scale = 1 / math.sqrt(in_channel * kernel_size ** 2)  # 1 / sqrt(fan_in)
        self.padding = kernel_size // 2

        self.weight = nn.Parameter(
            torch.randn(1, out_channel, in_channel, kernel_size, kernel_size)
        )

        self.modulation = EqualLinear(style_dim, in_channel, bias_init=1)

        self.demodulate = demodulate

    def __repr__(self):
        return (
            f'{self.__class__.__name__}({self.in_channel}, {self.out_channel}, {self.kernel_size}, '
            f'upsample={self.upsample}, downsample={self.downsample})'
        )

    def forward(self, input, style):
        batch, in_channel, height, width = input.shape
        ksize = self.kernel_size
        out_ch = self.out_channel

        # One modulation factor per (sample, input channel).
        style = self.modulation(style).view(batch, 1, in_channel, 1, 1)
        weight = self.scale * self.weight * style

        if self.demodulate:
            demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-8)
            weight = weight * demod.view(batch, out_ch, 1, 1, 1)

        # Fold the batch into the channel axis so one grouped conv handles all samples.
        weight = weight.view(batch * out_ch, in_channel, ksize, ksize)

        if self.upsample:
            folded = input.view(1, batch * in_channel, height, width)
            # conv_transpose2d expects (in, out/groups, k, k): swap the channel axes per sample.
            weight = weight.view(batch, out_ch, in_channel, ksize, ksize)
            weight = weight.transpose(1, 2).reshape(batch * in_channel, out_ch, ksize, ksize)
            out = F.conv_transpose2d(folded, weight, padding=0, stride=2, groups=batch)
            out = out.view(batch, out_ch, out.shape[2], out.shape[3])
            return self.blur(out)

        if self.downsample:
            blurred = self.blur(input)
            folded = blurred.view(1, batch * in_channel, blurred.shape[2], blurred.shape[3])
            out = F.conv2d(folded, weight, padding=0, stride=2, groups=batch)
        else:
            folded = input.view(1, batch * in_channel, height, width)
            out = F.conv2d(folded, weight, padding=self.padding, groups=batch)

        return out.view(batch, out_ch, out.shape[2], out.shape[3])
279
+
280
+
281
class NoiseInjection(nn.Module):
    """Add noise scaled by a single learned weight (initialised to zero)."""

    def __init__(self):
        super().__init__()
        self.weight = nn.Parameter(torch.zeros(1))

    def forward(self, image, noise=None):
        if noise is None:
            # Draw fresh single-channel Gaussian noise matching the image's batch and spatial size.
            n, _, h, w = image.shape
            noise = image.new_empty(n, 1, h, w).normal_()
        return image + self.weight * noise
293
+
294
+
295
class ConstantInput(nn.Module):
    """Learned constant tensor, tiled to the batch size of whatever is passed in."""

    def __init__(self, channel, size=4):
        super().__init__()
        self.input = nn.Parameter(torch.randn(1, channel, size, size))

    def forward(self, input):
        # Only the batch dimension of `input` is used.
        batch_size = input.shape[0]
        return self.input.repeat(batch_size, 1, 1, 1)
306
+
307
+
308
class StyledConv(nn.Module):
    """Modulated convolution, then noise injection, then fused bias + leaky ReLU."""

    def __init__(
        self,
        in_channel,
        out_channel,
        kernel_size,
        style_dim,
        upsample=False,
        blur_kernel=[1, 3, 3, 1],
        demodulate=True,
    ):
        super().__init__()

        self.conv = ModulatedConv2d(
            in_channel,
            out_channel,
            kernel_size,
            style_dim,
            upsample=upsample,
            blur_kernel=blur_kernel,
            demodulate=demodulate,
        )
        self.noise = NoiseInjection()
        # The bias lives inside the fused activation.
        self.activate = FusedLeakyReLU(out_channel)

    def forward(self, input, style, noise=None):
        modulated = self.conv(input, style)
        noisy = self.noise(modulated, noise=noise)
        return self.activate(noisy)
343
+
344
+
345
class ToRGB(nn.Module):
    """Project features to 3 channels and add the (upsampled) skip image, if any."""

    def __init__(self, in_channel, style_dim, upsample=True, blur_kernel=[1, 3, 3, 1]):
        super().__init__()

        # NOTE: `self.upsample` only exists when upsample=True; passing a skip
        # image to an instance built with upsample=False fails in forward.
        if upsample:
            self.upsample = Upsample(blur_kernel)

        self.conv = ModulatedConv2d(in_channel, 3, 1, style_dim, demodulate=False)
        self.bias = nn.Parameter(torch.zeros(1, 3, 1, 1))

    def forward(self, input, style, skip=None):
        rgb = self.conv(input, style) + self.bias

        if skip is None:
            return rgb

        return rgb + self.upsample(skip)
365
+
366
+
367
class Generator(nn.Module):
    """StyleGAN2-style synthesis network with its style mapping MLP.

    Parameters:
        size               -- output resolution; must resolve to keys of `self.channels`
        style_dim          -- width of the style vectors
        n_mlp              -- number of `EqualLinear` layers in the mapping network
        channel_multiplier -- scales channel counts at resolutions of 64 and above
        blur_kernel        -- FIR taps used by the upsampling convolutions
        lr_mlp             -- learning-rate multiplier of the mapping layers
    """

    def __init__(
        self,
        size,
        style_dim,
        n_mlp,
        channel_multiplier=2,
        blur_kernel=[1, 3, 3, 1],
        lr_mlp=0.01,
    ):
        super().__init__()

        self.size = size
        self.style_dim = style_dim

        # Mapping network: pixel norm followed by n_mlp activated linear layers.
        mapping = [PixelNorm()]
        mapping.extend(
            EqualLinear(style_dim, style_dim, lr_mul=lr_mlp, activation='fused_lrelu')
            for _ in range(n_mlp)
        )
        self.style = nn.Sequential(*mapping)

        self.channels = {
            4: 512,
            8: 512,
            16: 512,
            32: 512,
            64: 256 * channel_multiplier,
            128: 128 * channel_multiplier,
            256: 64 * channel_multiplier,
            512: 32 * channel_multiplier,
            1024: 16 * channel_multiplier,
        }
        base_channel = self.channels[4]

        self.input = ConstantInput(base_channel)
        self.conv1 = StyledConv(
            base_channel, base_channel, 3, style_dim, blur_kernel=blur_kernel
        )
        self.to_rgb1 = ToRGB(base_channel, style_dim, upsample=False)

        self.log_size = int(math.log(size, 2))
        self.num_layers = (self.log_size - 2) * 2 + 1

        self.convs = nn.ModuleList()
        self.upsamples = nn.ModuleList()
        self.to_rgbs = nn.ModuleList()
        self.noises = nn.Module()

        # One fixed noise map per styled conv: 4x4 for the first, then two per resolution.
        for layer_idx in range(self.num_layers):
            side = 2 ** ((layer_idx + 5) // 2)
            self.noises.register_buffer(
                f'noise_{layer_idx}', torch.randn(1, 1, side, side)
            )

        in_channel = base_channel
        for exponent in range(3, self.log_size + 1):
            out_channel = self.channels[2 ** exponent]

            # Each resolution: an upsampling conv, a same-size conv and a to-RGB head.
            self.convs.append(
                StyledConv(
                    in_channel,
                    out_channel,
                    3,
                    style_dim,
                    upsample=True,
                    blur_kernel=blur_kernel,
                )
            )
            self.convs.append(
                StyledConv(out_channel, out_channel, 3, style_dim, blur_kernel=blur_kernel)
            )
            self.to_rgbs.append(ToRGB(out_channel, style_dim))

            in_channel = out_channel

        self.n_latent = self.log_size * 2 - 2

    def make_noise(self):
        """Return a fresh list of per-layer noise maps on the model's device."""
        device = self.input.input.device

        noises = [torch.randn(1, 1, 2 ** 2, 2 ** 2, device=device)]
        for exponent in range(3, self.log_size + 1):
            side = 2 ** exponent
            for _ in range(2):
                noises.append(torch.randn(1, 1, side, side, device=device))

        return noises

    def mean_latent(self, n_latent):
        """Average the mapped latents of `n_latent` random style vectors."""
        samples = torch.randn(n_latent, self.style_dim, device=self.input.input.device)
        return self.style(samples).mean(0, keepdim=True)

    def get_latent(self, input):
        """Run `input` through the style mapping network."""
        return self.style(input)

    def forward(
        self,
        styles,
        return_latents=False,
        return_features=False,
        inject_index=None,
        truncation=1,
        truncation_latent=None,
        input_is_latent=False,
        noise=None,
        randomize_noise=True,
    ):
        """Synthesise images from one or two style codes.

        Always returns a pair: ``(image, latent)`` if `return_latents`, else
        ``(image, features)`` if `return_features`, else ``(image, None)``.
        """
        if not input_is_latent:
            styles = [self.style(s) for s in styles]

        if noise is None:
            if randomize_noise:
                # None makes every layer draw fresh noise.
                noise = [None] * self.num_layers
            else:
                noise = [
                    getattr(self.noises, f'noise_{i}') for i in range(self.num_layers)
                ]

        if truncation < 1:
            styles = [
                truncation_latent + truncation * (style - truncation_latent)
                for style in styles
            ]

        if len(styles) < 2:
            inject_index = self.n_latent
            first = styles[0]
            if first.ndim < 3:
                # A single code per sample is broadcast to every layer.
                latent = first.unsqueeze(1).repeat(1, inject_index, 1)
            else:
                latent = first
        else:
            # Style mixing: first code up to inject_index, second code afterwards.
            if inject_index is None:
                inject_index = random.randint(1, self.n_latent - 1)

            head = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
            tail = styles[1].unsqueeze(1).repeat(1, self.n_latent - inject_index, 1)
            latent = torch.cat([head, tail], 1)

        out = self.input(latent)
        out = self.conv1(out, latent[:, 0], noise=noise[0])
        skip = self.to_rgb1(out, latent[:, 1])

        stages = zip(
            self.convs[::2], self.convs[1::2], noise[1::2], noise[2::2], self.to_rgbs
        )
        for stage, (up_conv, same_conv, up_noise, same_noise, to_rgb) in enumerate(stages):
            idx = 1 + 2 * stage  # latent index consumed by the upsampling conv
            out = up_conv(out, latent[:, idx], noise=up_noise)
            out = same_conv(out, latent[:, idx + 1], noise=same_noise)
            skip = to_rgb(out, latent[:, idx + 2], skip)

        image = skip

        if return_latents:
            return image, latent
        if return_features:
            return image, out
        return image, None
548
+
549
+
550
class ConvLayer(nn.Sequential):
    """Optional blur, an `EqualConv2d`, and an optional activation, as one Sequential."""

    def __init__(
        self,
        in_channel,
        out_channel,
        kernel_size,
        downsample=False,
        blur_kernel=[1, 3, 3, 1],
        bias=True,
        activate=True,
    ):
        stages = []

        if downsample:
            # Blur first, then let the stride-2 convolution do the downsampling.
            factor = 2
            excess = (len(blur_kernel) - factor) + (kernel_size - 1)
            stages.append(Blur(blur_kernel, pad=((excess + 1) // 2, excess // 2)))
            stride, padding = 2, 0
        else:
            stride, padding = 1, kernel_size // 2

        # With an activation, the bias (if any) is applied by the fused activation instead.
        conv_bias = bias and not activate
        stages.append(
            EqualConv2d(
                in_channel,
                out_channel,
                kernel_size,
                padding=padding,
                stride=stride,
                bias=conv_bias,
            )
        )

        if activate:
            stages.append(FusedLeakyReLU(out_channel) if bias else ScaledLeakyReLU(0.2))

        super().__init__(*stages)
        self.padding = padding
597
+
598
+
599
class ResBlock(nn.Module):
    """Residual downsampling block: two convs plus a 1x1 skip, combined and divided by sqrt(2)."""

    def __init__(self, in_channel, out_channel, blur_kernel=[1, 3, 3, 1]):
        super().__init__()

        # NOTE: `blur_kernel` is accepted but not forwarded; the inner layers
        # use ConvLayer's own default kernel.
        self.conv1 = ConvLayer(in_channel, in_channel, 3)
        self.conv2 = ConvLayer(in_channel, out_channel, 3, downsample=True)
        self.skip = ConvLayer(
            in_channel, out_channel, 1, downsample=True, activate=False, bias=False
        )

    def forward(self, input):
        main = self.conv2(self.conv1(input))
        shortcut = self.skip(input)
        return (main + shortcut) / math.sqrt(2)
618
+
619
+
620
class Discriminator(nn.Module):
    """Residual discriminator with a minibatch standard-deviation feature.

    Parameters:
        size               -- input resolution (must be a key of the channel table)
        channel_multiplier -- scales channel counts at resolutions of 64 and above
        blur_kernel        -- passed to every `ResBlock`
    """

    def __init__(self, size, channel_multiplier=2, blur_kernel=[1, 3, 3, 1]):
        super().__init__()

        channels = {
            4: 512,
            8: 512,
            16: 512,
            32: 512,
            64: 256 * channel_multiplier,
            128: 128 * channel_multiplier,
            256: 64 * channel_multiplier,
            512: 32 * channel_multiplier,
            1024: 16 * channel_multiplier,
        }

        in_channel = channels[size]
        blocks = [ConvLayer(3, in_channel, 1)]

        # Halve the resolution once per step until 4x4 is reached.
        log_size = int(math.log(size, 2))
        for exponent in range(log_size, 2, -1):
            out_channel = channels[2 ** (exponent - 1)]
            blocks.append(ResBlock(in_channel, out_channel, blur_kernel))
            in_channel = out_channel

        self.convs = nn.Sequential(*blocks)

        self.stddev_group = 4
        self.stddev_feat = 1

        # +1 input channel for the appended standard-deviation map.
        self.final_conv = ConvLayer(in_channel + 1, channels[4], 3)
        self.final_linear = nn.Sequential(
            EqualLinear(channels[4] * 4 * 4, channels[4], activation='fused_lrelu'),
            EqualLinear(channels[4], 1),
        )

    def forward(self, input):
        features = self.convs(input)
        batch, channel, height, width = features.shape

        # Minibatch stddev: per-group standard deviation, averaged into one extra channel.
        group = min(batch, self.stddev_group)
        grouped = features.view(
            group, -1, self.stddev_feat, channel // self.stddev_feat, height, width
        )
        stddev = torch.sqrt(grouped.var(0, unbiased=False) + 1e-8)
        stddev = stddev.mean([2, 3, 4], keepdims=True).squeeze(2)
        stddev = stddev.repeat(group, 1, height, width)

        out = self.final_conv(torch.cat([features, stddev], 1))
        return self.final_linear(out.view(batch, -1))
e4e/models/stylegan2/op/__init__.py ADDED
File without changes
e4e/models/stylegan2/op/fused_act.py ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ import torch
4
+ from torch import nn
5
+ from torch.autograd import Function
6
+ from torch.utils.cpp_extension import load
7
+
8
# JIT-compile and load the fused bias + activation extension that sits next to this file.
module_path = os.path.dirname(__file__)
fused = load(
    'fused',
    sources=[
        os.path.join(module_path, source_name)
        for source_name in ('fused_bias_act.cpp', 'fused_bias_act_kernel.cu')
    ],
)
16
+
17
+
18
class FusedLeakyReLUFunctionBackward(Function):
    """Backward pass of the fused leaky ReLU, itself differentiable."""

    @staticmethod
    def forward(ctx, grad_output, out, negative_slope, scale):
        ctx.save_for_backward(out)
        ctx.negative_slope = negative_slope
        ctx.scale = scale

        # An empty tensor tells the kernel that no bias is to be added.
        no_bias = grad_output.new_empty(0)
        grad_input = fused.fused_bias_act(
            grad_output, no_bias, out, 3, 1, negative_slope, scale
        )

        # The bias gradient sums over every dimension except dim 1.
        reduce_dims = [0] + list(range(2, grad_input.ndim))
        grad_bias = grad_input.sum(reduce_dims).detach()

        return grad_input, grad_bias

    @staticmethod
    def backward(ctx, gradgrad_input, gradgrad_bias):
        (out,) = ctx.saved_tensors
        gradgrad_out = fused.fused_bias_act(
            gradgrad_input, gradgrad_bias, out, 3, 1, ctx.negative_slope, ctx.scale
        )

        # Only `grad_output` receives a gradient.
        return gradgrad_out, None, None, None
48
+
49
+
50
class FusedLeakyReLUFunction(Function):
    """Autograd function for the fused bias-add + leaky ReLU + scale extension op."""

    @staticmethod
    def forward(ctx, input, bias, negative_slope, scale):
        # An empty reference tensor: the forward pass needs none.
        no_ref = input.new_empty(0)
        out = fused.fused_bias_act(input, bias, no_ref, 3, 0, negative_slope, scale)

        # The backward pass is computed from the output, so that is what gets saved.
        ctx.save_for_backward(out)
        ctx.negative_slope = negative_slope
        ctx.scale = scale

        return out

    @staticmethod
    def backward(ctx, grad_output):
        (out,) = ctx.saved_tensors
        grad_input, grad_bias = FusedLeakyReLUFunctionBackward.apply(
            grad_output, out, ctx.negative_slope, ctx.scale
        )

        # No gradients for negative_slope and scale.
        return grad_input, grad_bias, None, None
70
+
71
+
72
class FusedLeakyReLU(nn.Module):
    """Module wrapper around `fused_leaky_relu` that owns a learnable per-channel bias."""

    def __init__(self, channel, negative_slope=0.2, scale=2 ** 0.5):
        super().__init__()

        self.bias = nn.Parameter(torch.zeros(channel))
        self.negative_slope = negative_slope
        self.scale = scale

    def forward(self, input):
        slope, gain = self.negative_slope, self.scale
        return fused_leaky_relu(input, self.bias, slope, gain)
82
+
83
+
84
def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2 ** 0.5):
    """Functional form: add `bias`, apply leaky ReLU and scale via the fused extension op."""
    result = FusedLeakyReLUFunction.apply(input, bias, negative_slope, scale)
    return result
e4e/models/stylegan2/op/fused_bias_act.cpp ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #include <torch/extension.h>
2
+
3
+
4
+ torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer,
5
+ int act, int grad, float alpha, float scale);
6
+
7
+ #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
8
+ #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
9
+ #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
10
+
11
+ torch::Tensor fused_bias_act(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer,
12
+ int act, int grad, float alpha, float scale) {
13
+ CHECK_CUDA(input);
14
+ CHECK_CUDA(bias);
15
+
16
+ return fused_bias_act_op(input, bias, refer, act, grad, alpha, scale);
17
+ }
18
+
19
+ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
20
+ m.def("fused_bias_act", &fused_bias_act, "fused bias act (CUDA)");
21
+ }
e4e/models/stylegan2/op/fused_bias_act_kernel.cu ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
2
+ //
3
+ // This work is made available under the Nvidia Source Code License-NC.
4
+ // To view a copy of this license, visit
5
+ // https://nvlabs.github.io/stylegan2/license.html
6
+
7
+ #include <torch/types.h>
8
+
9
+ #include <ATen/ATen.h>
10
+ #include <ATen/AccumulateType.h>
11
+ #include <ATen/cuda/CUDAContext.h>
12
+ #include <ATen/cuda/CUDAApplyUtils.cuh>
13
+
14
+ #include <cuda.h>
15
+ #include <cuda_runtime.h>
16
+
17
+
18
// Element-wise kernel: optional bias add, activation (or its gradient form,
// selected by act * 10 + grad), then multiplication by `scale`.
// Each thread handles up to `loop_x` elements spaced blockDim.x apart.
template <typename scalar_t>
static __global__ void fused_bias_act_kernel(scalar_t* out, const scalar_t* p_x, const scalar_t* p_b, const scalar_t* p_ref,
    int act, int grad, scalar_t alpha, scalar_t scale, int loop_x, int size_x, int step_b, int size_b, int use_bias, int use_ref) {
    const scalar_t zero = 0.0;
    const int mode = act * 10 + grad;

    int xi = blockIdx.x * loop_x * blockDim.x + threadIdx.x;

    for (int iter = 0; iter < loop_x && xi < size_x; iter++, xi += blockDim.x) {
        scalar_t x = p_x[xi];

        if (use_bias) {
            // step_b elements share one bias entry; the index wraps every size_b entries.
            x += p_b[(xi / step_b) % size_b];
        }

        scalar_t ref = use_ref ? p_ref[xi] : zero;

        scalar_t y;

        switch (mode) {
            case 12:
            case 32:
                y = 0.0;
                break;

            case 30:
                y = (x > 0.0) ? x : x * alpha;
                break;

            case 31:
                // Gradient form: the sign test uses the reference tensor.
                y = (ref > 0.0) ? x : x * alpha;
                break;

            default:  // modes 10 and 11, and any unrecognised mode: pass-through
                y = x;
                break;
        }

        out[xi] = y * scale;
    }
}
50
+
51
+
52
// Host-side launcher for fused_bias_act_kernel on the current CUDA stream.
// An empty `bias` / `refer` tensor disables the corresponding kernel input.
torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer,
    int act, int grad, float alpha, float scale) {
    int device_index = -1;
    cudaGetDevice(&device_index);
    cudaStream_t stream = at::cuda::getCurrentCUDAStream(device_index);

    auto x = input.contiguous();
    auto b = bias.contiguous();
    auto ref = refer.contiguous();

    const int size_x = x.numel();
    const int size_b = b.numel();
    const int use_bias = size_b ? 1 : 0;
    const int use_ref = ref.numel() ? 1 : 0;

    // Number of consecutive elements sharing one bias entry:
    // the product of all dimensions after dim 1.
    int step_b = 1;
    for (int d = 2; d < x.dim(); d++) {
        step_b *= x.size(d);
    }

    const int loop_x = 4;
    const int block_size = 4 * 32;
    const int grid_size = (size_x - 1) / (loop_x * block_size) + 1;

    auto y = torch::empty_like(x);

    AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "fused_bias_act_kernel", [&] {
        fused_bias_act_kernel<scalar_t><<<grid_size, block_size, 0, stream>>>(
            y.data_ptr<scalar_t>(),
            x.data_ptr<scalar_t>(),
            b.data_ptr<scalar_t>(),
            ref.data_ptr<scalar_t>(),
            act,
            grad,
            alpha,
            scale,
            loop_x,
            size_x,
            step_b,
            size_b,
            use_bias,
            use_ref
        );
    });

    return y;
}
e4e/models/stylegan2/op/upfirdn2d.cpp ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #include <torch/extension.h>
2
+
3
+
4
+ torch::Tensor upfirdn2d_op(const torch::Tensor& input, const torch::Tensor& kernel,
5
+ int up_x, int up_y, int down_x, int down_y,
6
+ int pad_x0, int pad_x1, int pad_y0, int pad_y1);
7
+
8
+ #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
9
+ #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
10
+ #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
11
+
12
+ torch::Tensor upfirdn2d(const torch::Tensor& input, const torch::Tensor& kernel,
13
+ int up_x, int up_y, int down_x, int down_y,
14
+ int pad_x0, int pad_x1, int pad_y0, int pad_y1) {
15
+ CHECK_CUDA(input);
16
+ CHECK_CUDA(kernel);
17
+
18
+ return upfirdn2d_op(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1);
19
+ }
20
+
21
+ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
22
+ m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)");
23
+ }
e4e/models/stylegan2/op/upfirdn2d.py ADDED
@@ -0,0 +1,184 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ import torch
4
+ from torch.autograd import Function
5
+ from torch.utils.cpp_extension import load
6
+
7
# Build and load the `upfirdn2d` C++/CUDA extension when this module is
# imported; both source files are expected to sit next to this file.
module_path = os.path.dirname(__file__)
upfirdn2d_op = load(
    'upfirdn2d',
    sources=[
        os.path.join(module_path, source_file)
        for source_file in ('upfirdn2d.cpp', 'upfirdn2d_kernel.cu')
    ],
)
15
+
16
+
17
class UpFirDn2dBackward(Function):
    """Autograd function computing the input gradient of ``UpFirDn2d``.

    ``forward`` runs the extension op on the incoming gradient using
    ``grad_kernel`` with the up/down factors swapped and the ``g_pad``
    padding. ``backward`` (the double-backward pass) re-applies the op with
    the original kernel, factors and padding.
    """

    @staticmethod
    def forward(
        ctx, grad_output, kernel, grad_kernel, up, down, pad, g_pad, in_size, out_size
    ):
        up_x, up_y = up
        down_x, down_y = down
        pad_x0, pad_x1, pad_y0, pad_y1 = pad
        g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1 = g_pad

        # Remember everything the double-backward pass needs.
        ctx.save_for_backward(kernel)
        ctx.up_x, ctx.up_y = up_x, up_y
        ctx.down_x, ctx.down_y = down_x, down_y
        ctx.pad_x0, ctx.pad_x1 = pad_x0, pad_x1
        ctx.pad_y0, ctx.pad_y1 = pad_y0, pad_y1
        ctx.in_size = in_size
        ctx.out_size = out_size

        # The op takes a 4-D tensor with a trailing dimension of size 1.
        flat_grad = grad_output.reshape(-1, out_size[0], out_size[1], 1)

        # Note the swapped roles: down factors are passed as "up" and vice versa.
        grad_input = upfirdn2d_op.upfirdn2d(
            flat_grad,
            grad_kernel,
            down_x,
            down_y,
            up_x,
            up_y,
            g_pad_x0,
            g_pad_x1,
            g_pad_y0,
            g_pad_y1,
        )

        return grad_input.view(in_size[0], in_size[1], in_size[2], in_size[3])

    @staticmethod
    def backward(ctx, gradgrad_input):
        (kernel,) = ctx.saved_tensors
        in_size = ctx.in_size
        out_h, out_w = ctx.out_size[0], ctx.out_size[1]

        flat_gradgrad = gradgrad_input.reshape(-1, in_size[2], in_size[3], 1)

        gradgrad_out = upfirdn2d_op.upfirdn2d(
            flat_gradgrad,
            kernel,
            ctx.up_x,
            ctx.up_y,
            ctx.down_x,
            ctx.down_y,
            ctx.pad_x0,
            ctx.pad_x1,
            ctx.pad_y0,
            ctx.pad_y1,
        )
        gradgrad_out = gradgrad_out.view(in_size[0], in_size[1], out_h, out_w)

        # One gradient slot per forward() argument; only grad_output gets one.
        return (gradgrad_out,) + (None,) * 8
83
+
84
+
85
class UpFirDn2d(Function):
    """Autograd wrapper around the ``upfirdn2d`` extension op.

    ``forward`` expects a 4-D ``input`` (unpacked as batch, channel, height,
    width) and a 2-D ``kernel``; ``up`` and ``down`` are ``(x, y)`` pairs and
    ``pad`` is ``(x0, x1, y0, y1)``. The gradient with respect to ``input``
    is delegated to ``UpFirDn2dBackward``; no gradient is returned for the
    other arguments.
    """

    @staticmethod
    def forward(ctx, input, kernel, up, down, pad):
        up_x, up_y = up
        down_x, down_y = down
        pad_x0, pad_x1, pad_y0, pad_y1 = pad

        kernel_h, kernel_w = kernel.shape
        _, channel, in_h, in_w = input.shape

        # Output extent after upsampling, padding, filtering and downsampling.
        out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1
        out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1

        # Padding passed to the op by the backward pass.
        g_pad_x0 = kernel_w - pad_x0 - 1
        g_pad_y0 = kernel_h - pad_y0 - 1
        g_pad_x1 = in_w * up_x - out_w * down_x + pad_x0 - up_x + 1
        g_pad_y1 = in_h * up_y - out_h * down_y + pad_y0 - up_y + 1

        # State for backward: original kernel plus its flip along both axes.
        ctx.save_for_backward(kernel, torch.flip(kernel, [0, 1]))
        ctx.in_size = input.shape
        ctx.out_size = (out_h, out_w)
        ctx.up = (up_x, up_y)
        ctx.down = (down_x, down_y)
        ctx.pad = (pad_x0, pad_x1, pad_y0, pad_y1)
        ctx.g_pad = (g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1)

        # Fold batch and channel together; the op gets a trailing dim of 1.
        flat_input = input.reshape(-1, in_h, in_w, 1)
        result = upfirdn2d_op.upfirdn2d(
            flat_input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1
        )

        return result.view(-1, channel, out_h, out_w)

    @staticmethod
    def backward(ctx, grad_output):
        kernel, grad_kernel = ctx.saved_tensors

        grad_input = UpFirDn2dBackward.apply(
            grad_output,
            kernel,
            grad_kernel,
            ctx.up,
            ctx.down,
            ctx.pad,
            ctx.g_pad,
            ctx.in_size,
            ctx.out_size,
        )

        # Only `input` receives a gradient; kernel, up, down and pad do not.
        return (grad_input,) + (None,) * 4
140
+
141
+
142
def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
    """Apply ``UpFirDn2d`` with the same settings along x and y.

    ``up`` and ``down`` are scalar factors used for both axes; ``pad`` is a
    ``(before, after)`` pair that is likewise reused for both axes.
    """
    pad_before, pad_after = pad[0], pad[1]
    return UpFirDn2d.apply(
        input,
        kernel,
        (up, up),
        (down, down),
        (pad_before, pad_after, pad_before, pad_after),
    )
148
+
149
+
150
def upfirdn2d_native(
    input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1
):
    """Upsample, FIR-filter and downsample using plain PyTorch ops.

    Args:
        input: 4-D tensor unpacked as ``(N, in_h, in_w, minor)``.
        kernel: 2-D filter of shape ``(kernel_h, kernel_w)``.
        up_x, up_y: integer upsampling factors; zeros are inserted between
            samples.
        down_x, down_y: integer downsampling strides applied to the result.
        pad_x0, pad_x1, pad_y0, pad_y1: padding before/after along x and y.
            Positive values zero-pad, negative values crop.

    Returns:
        Tensor of shape ``(N, out_h, out_w, minor)``.
    """
    # `F` is not imported at module level in this file, so the original body
    # raised NameError on first use; import it here to keep the block
    # self-contained.
    import torch.nn.functional as F

    _, in_h, in_w, minor = input.shape
    kernel_h, kernel_w = kernel.shape

    # Upsample by zero insertion: add singleton axes after H and W, pad them
    # to the up factors, then merge them back into H and W.
    out = input.view(-1, in_h, 1, in_w, 1, minor)
    out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
    out = out.view(-1, in_h * up_y, in_w * up_x, minor)

    # Positive padding is applied with F.pad; negative padding is a crop.
    out = F.pad(
        out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)]
    )
    out = out[
        :,
        max(-pad_y0, 0): out.shape[1] - max(-pad_y1, 0),
        max(-pad_x0, 0): out.shape[2] - max(-pad_x1, 0),
        :,
    ]

    # Filter every (batch, minor) plane independently as a 1-channel image.
    out = out.permute(0, 3, 1, 2)
    out = out.reshape(
        [-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1]
    )
    # conv2d computes a correlation, so the kernel is flipped first.
    w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
    out = F.conv2d(out, w)
    out = out.reshape(
        -1,
        minor,
        in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
        in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1,
    )
    out = out.permute(0, 2, 3, 1)

    # Downsample by striding.
    return out[:, ::down_y, ::down_x, :]
e4e/models/stylegan2/op/upfirdn2d_kernel.cu ADDED
@@ -0,0 +1,272 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
2
+ //
3
+ // This work is made available under the Nvidia Source Code License-NC.
4
+ // To view a copy of this license, visit
5
+ // https://nvlabs.github.io/stylegan2/license.html
6
+
7
+ #include <torch/types.h>
8
+
9
+ #include <ATen/ATen.h>
10
+ #include <ATen/AccumulateType.h>
11
+ #include <ATen/cuda/CUDAContext.h>
12
+ #include <ATen/cuda/CUDAApplyUtils.cuh>
13
+
14
+ #include <cuda.h>
15
+ #include <cuda_runtime.h>
16
+
17
+
18
// Integer division of a by b, stepping the truncated quotient down by one
// whenever quotient * b exceeds a (e.g. negative a with positive b).
static __host__ __device__ __forceinline__ int floor_div(int a, int b) {
    const int quotient = a / b;

    return (quotient * b > a) ? quotient - 1 : quotient;
}
27
+
28
+
29
// Runtime parameters passed by value to upfirdn2d_kernel.
// Field order is unchanged from the original declaration.
struct UpFirDn2DKernelParams {
    // Up/down sampling factors and padding, as supplied by the caller.
    int up_x, up_y;
    int down_x, down_y;
    int pad_x0, pad_x1;
    int pad_y0, pad_y1;

    // Input tensor sizes, read from dimensions 0..3 of the input.
    int major_dim;
    int in_h, in_w;
    int minor_dim;

    // Filter sizes, read from dimensions 0..1 of the kernel tensor.
    int kernel_h, kernel_w;

    // Output spatial size.
    int out_h, out_w;

    // Per-block iteration counts over the major dimension and over tiles along x.
    int loop_major;
    int loop_x;
};
50
+
51
+
52
// Tiled upsample / FIR filter / downsample kernel.
//
// Template parameters fix the up/down factors, the dimensions of the shared
// filter array (kernel_h x kernel_w) and the output tile size at compile
// time. The runtime filter size p.kernel_h x p.kernel_w may be smaller than
// the template size; taps outside the runtime size are stored as zero.
//
// Tensors are indexed as [major][y][x][minor] (see the index expressions for
// `input` and `out` below).
//
// Block mapping:
//   blockIdx.x -> (output tile row, minor index), packed as row * minor_dim + minor
//   blockIdx.y -> base output tile column (times p.loop_x tiles per block)
//   blockIdx.z -> base major index (times p.loop_major entries per block)
template <typename scalar_t, int up_x, int up_y, int down_x, int down_y, int kernel_h, int kernel_w, int tile_out_h, int tile_out_w>
__global__ void upfirdn2d_kernel(scalar_t* out, const scalar_t* input, const scalar_t* kernel, const UpFirDn2DKernelParams p) {
    // Input tile extent used to produce one output tile.
    const int tile_in_h = ((tile_out_h - 1) * down_y + kernel_h - 1) / up_y + 1;
    const int tile_in_w = ((tile_out_w - 1) * down_x + kernel_w - 1) / up_x + 1;

    // Shared copies of the filter taps and of the current input tile.
    // Both are stored as float regardless of scalar_t.
    __shared__ volatile float sk[kernel_h][kernel_w];
    __shared__ volatile float sx[tile_in_h][tile_in_w];

    // Unpack blockIdx.x into the minor index and the first output row of the tile.
    int minor_idx = blockIdx.x;
    int tile_out_y = minor_idx / p.minor_dim;
    minor_idx -= tile_out_y * p.minor_dim;
    tile_out_y *= tile_out_h;
    int tile_out_x_base = blockIdx.y * p.loop_x * tile_out_w;
    int major_idx_base = blockIdx.z * p.loop_major;

    // Blocks that start entirely outside the output have nothing to do.
    // (Bitwise | on the comparison results: all three are always evaluated.)
    if (tile_out_x_base >= p.out_w | tile_out_y >= p.out_h | major_idx_base >= p.major_dim) {
        return;
    }

    // Load the filter into shared memory, flipped along both axes; each
    // thread handles taps threadIdx.x, threadIdx.x + blockDim.x, ...
    for (int tap_idx = threadIdx.x; tap_idx < kernel_h * kernel_w; tap_idx += blockDim.x) {
        int ky = tap_idx / kernel_w;
        int kx = tap_idx - ky * kernel_w;
        scalar_t v = 0.0;

        // Taps beyond the runtime filter size stay zero.
        if (kx < p.kernel_w & ky < p.kernel_h) {
            v = kernel[(p.kernel_h - 1 - ky) * p.kernel_w + (p.kernel_w - 1 - kx)];
        }

        sk[ky][kx] = v;
    }

    // Iterate over this block's share of the major dimension and of the tiles along x.
    for (int loop_major = 0, major_idx = major_idx_base; loop_major < p.loop_major & major_idx < p.major_dim; loop_major++, major_idx++) {
        for (int loop_x = 0, tile_out_x = tile_out_x_base; loop_x < p.loop_x & tile_out_x < p.out_w; loop_x++, tile_out_x += tile_out_w) {
            // Position of the tile's first output sample in the upsampled,
            // padded grid, and the input coordinate it maps to.
            int tile_mid_x = tile_out_x * down_x + up_x - 1 - p.pad_x0;
            int tile_mid_y = tile_out_y * down_y + up_y - 1 - p.pad_y0;
            int tile_in_x = floor_div(tile_mid_x, up_x);
            int tile_in_y = floor_div(tile_mid_y, up_y);

            // Barrier before sx is overwritten.
            // NOTE(review): presumably this keeps threads still reading
            // sx from the previous iteration from racing with the writes
            // below, and on the first iteration also orders the sk writes
            // above before their use - confirm.
            __syncthreads();

            // Load the input tile; coordinates outside the input read as zero.
            for (int in_idx = threadIdx.x; in_idx < tile_in_h * tile_in_w; in_idx += blockDim.x) {
                int rel_in_y = in_idx / tile_in_w;
                int rel_in_x = in_idx - rel_in_y * tile_in_w;
                int in_x = rel_in_x + tile_in_x;
                int in_y = rel_in_y + tile_in_y;

                scalar_t v = 0.0;

                if (in_x >= 0 & in_y >= 0 & in_x < p.in_w & in_y < p.in_h) {
                    v = input[((major_idx * p.in_h + in_y) * p.in_w + in_x) * p.minor_dim + minor_idx];
                }

                sx[rel_in_y][rel_in_x] = v;
            }

            // Barrier between filling sx and reading it.
            __syncthreads();
            // Each thread computes output samples threadIdx.x, threadIdx.x + blockDim.x, ... of the tile.
            for (int out_idx = threadIdx.x; out_idx < tile_out_h * tile_out_w; out_idx += blockDim.x) {
                int rel_out_y = out_idx / tile_out_w;
                int rel_out_x = out_idx - rel_out_y * tile_out_w;
                int out_x = rel_out_x + tile_out_x;
                int out_y = rel_out_y + tile_out_y;

                // Map the output sample to its first contributing input
                // sample (relative to the tile) and first filter tap.
                int mid_x = tile_mid_x + rel_out_x * down_x;
                int mid_y = tile_mid_y + rel_out_y * down_y;
                int in_x = floor_div(mid_x, up_x);
                int in_y = floor_div(mid_y, up_y);
                int rel_in_x = in_x - tile_in_x;
                int rel_in_y = in_y - tile_in_y;
                int kernel_x = (in_x + 1) * up_x - mid_x - 1;
                int kernel_y = (in_y + 1) * up_y - mid_y - 1;

                scalar_t v = 0.0;

                // Accumulate over every up_y-th / up_x-th tap of the shared filter.
                #pragma unroll
                for (int y = 0; y < kernel_h / up_y; y++)
                    #pragma unroll
                    for (int x = 0; x < kernel_w / up_x; x++)
                        v += sx[rel_in_y + y][rel_in_x + x] * sk[kernel_y + y * up_y][kernel_x + x * up_x];

                // Only samples inside the output are written (edge tiles may overhang).
                if (out_x < p.out_w & out_y < p.out_h) {
                    out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim + minor_idx] = v;
                }
            }
        }
    }
}
138
+
139
+
140
// Host-side launcher for upfirdn2d_kernel.
//
// `input` is read as a 4-D tensor [major, in_h, in_w, minor] and `kernel` as
// a 2-D tensor [kernel_h, kernel_w]; both are made contiguous before use.
// Returns a new tensor [major, out_h, out_w, minor] with the input's options.
//
// Only the up/down/filter-size combinations listed below have a kernel
// instantiation. Any other combination raises an error via TORCH_CHECK
// instead of silently returning an uninitialized tensor.
torch::Tensor upfirdn2d_op(const torch::Tensor& input, const torch::Tensor& kernel,
                           int up_x, int up_y, int down_x, int down_y,
                           int pad_x0, int pad_x1, int pad_y0, int pad_y1) {
    int curDevice = -1;
    cudaGetDevice(&curDevice);
    cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice);

    UpFirDn2DKernelParams p;

    auto x = input.contiguous();
    auto k = kernel.contiguous();

    p.major_dim = x.size(0);
    p.in_h = x.size(1);
    p.in_w = x.size(2);
    p.minor_dim = x.size(3);
    p.kernel_h = k.size(0);
    p.kernel_w = k.size(1);
    p.up_x = up_x;
    p.up_y = up_y;
    p.down_x = down_x;
    p.down_y = down_y;
    p.pad_x0 = pad_x0;
    p.pad_x1 = pad_x1;
    p.pad_y0 = pad_y0;
    p.pad_y1 = pad_y1;

    p.out_h = (p.in_h * p.up_y + p.pad_y0 + p.pad_y1 - p.kernel_h + p.down_y) / p.down_y;
    p.out_w = (p.in_w * p.up_x + p.pad_x0 + p.pad_x1 - p.kernel_w + p.down_x) / p.down_x;

    auto out = at::empty({p.major_dim, p.out_h, p.out_w, p.minor_dim}, x.options());

    // Initialized so that "no matching mode" is detectable; previously the
    // tile sizes were read uninitialized in that case.
    int mode = -1;
    int tile_out_h = -1;
    int tile_out_w = -1;

    // Mode selection. The checks are independent `if`s, so a later, more
    // specific match (smaller filter) overrides an earlier one.
    if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 && p.kernel_h <= 4 && p.kernel_w <= 4) {
        mode = 1;
        tile_out_h = 16;
        tile_out_w = 64;
    }

    if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 && p.kernel_h <= 3 && p.kernel_w <= 3) {
        mode = 2;
        tile_out_h = 16;
        tile_out_w = 64;
    }

    if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 && p.kernel_h <= 4 && p.kernel_w <= 4) {
        mode = 3;
        tile_out_h = 16;
        tile_out_w = 64;
    }

    if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 && p.kernel_h <= 2 && p.kernel_w <= 2) {
        mode = 4;
        tile_out_h = 16;
        tile_out_w = 64;
    }

    if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 && p.kernel_h <= 4 && p.kernel_w <= 4) {
        mode = 5;
        tile_out_h = 8;
        tile_out_w = 32;
    }

    if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 && p.kernel_h <= 2 && p.kernel_w <= 2) {
        mode = 6;
        tile_out_h = 8;
        tile_out_w = 32;
    }

    TORCH_CHECK(mode != -1,
                "upfirdn2d: unsupported configuration (up=", p.up_x, "x", p.up_y,
                ", down=", p.down_x, "x", p.down_y,
                ", kernel=", p.kernel_h, "x", p.kernel_w, ")");

    // Nothing to compute; also avoids launching with a zero-sized grid.
    if (out.numel() == 0) {
        return out;
    }

    // Launch geometry: 256 threads per block; grid.x packs (tile row, minor),
    // grid.y covers tile columns, grid.z covers the major dimension, with
    // loop_major entries handled per block so grid.z stays at or below 16384.
    p.loop_major = (p.major_dim - 1) / 16384 + 1;
    p.loop_x = 1;
    dim3 block_size(32 * 8, 1, 1);
    dim3 grid_size(((p.out_h - 1) / tile_out_h + 1) * p.minor_dim,
                   (p.out_w - 1) / (p.loop_x * tile_out_w) + 1,
                   (p.major_dim - 1) / p.loop_major + 1);

    AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&] {
        switch (mode) {
        case 1:
            upfirdn2d_kernel<scalar_t, 1, 1, 1, 1, 4, 4, 16, 64><<<grid_size, block_size, 0, stream>>>(
                out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(), k.data_ptr<scalar_t>(), p
            );

            break;

        case 2:
            upfirdn2d_kernel<scalar_t, 1, 1, 1, 1, 3, 3, 16, 64><<<grid_size, block_size, 0, stream>>>(
                out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(), k.data_ptr<scalar_t>(), p
            );

            break;

        case 3:
            upfirdn2d_kernel<scalar_t, 2, 2, 1, 1, 4, 4, 16, 64><<<grid_size, block_size, 0, stream>>>(
                out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(), k.data_ptr<scalar_t>(), p
            );

            break;

        case 4:
            upfirdn2d_kernel<scalar_t, 2, 2, 1, 1, 2, 2, 16, 64><<<grid_size, block_size, 0, stream>>>(
                out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(), k.data_ptr<scalar_t>(), p
            );

            break;

        case 5:
            upfirdn2d_kernel<scalar_t, 1, 1, 2, 2, 4, 4, 8, 32><<<grid_size, block_size, 0, stream>>>(
                out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(), k.data_ptr<scalar_t>(), p
            );

            break;

        case 6:
            // NOTE(review): this uses the same 4x4 instantiation as case 5
            // although mode 6 is selected for filters up to 2x2. The kernel
            // zero-fills unused taps, so the instantiation is left unchanged
            // here; a 2x2 instantiation may have been intended - confirm.
            upfirdn2d_kernel<scalar_t, 1, 1, 2, 2, 4, 4, 8, 32><<<grid_size, block_size, 0, stream>>>(
                out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(), k.data_ptr<scalar_t>(), p
            );

            break;
        }
    });

    return out;
}
e4e/notebooks/images/car_img.jpg ADDED
e4e/notebooks/images/church_img.jpg ADDED
e4e/notebooks/images/horse_img.jpg ADDED
e4e/notebooks/images/input_img.jpg ADDED
e4e/options/__init__.py ADDED
File without changes
e4e/options/train_options.py ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from argparse import ArgumentParser
2
+ from configs.paths_config import model_paths
3
+
4
+
5
class TrainOptions:
    """Command-line options for training.

    Builds an ``ArgumentParser`` on construction (``self.parser``) and
    registers every training flag in ``initialize``. Call ``parse`` to read
    the options from ``sys.argv``.
    """

    def __init__(self):
        self.parser = ArgumentParser()
        self.initialize()

    @staticmethod
    def _str2bool(value):
        """Convert a command-line string to a bool.

        ``type=bool`` treats every non-empty string (including ``'False'``)
        as ``True``; this converter parses the text instead. The empty string
        maps to ``False``, as it did with ``bool``.

        Raises:
            ValueError: if the text is not a recognised boolean spelling
                (argparse reports this as an invalid argument value).
        """
        if isinstance(value, bool):
            return value
        text = str(value).strip().lower()
        if text in ('true', 't', 'yes', 'y', '1'):
            return True
        if text in ('false', 'f', 'no', 'n', '0', ''):
            return False
        raise ValueError('expected a boolean value, got {!r}'.format(value))

    def initialize(self):
        """Register all training arguments on ``self.parser``."""
        self.parser.add_argument('--exp_dir', type=str, help='Path to experiment output directory')
        self.parser.add_argument('--dataset_type', default='ffhq_encode', type=str,
                                 help='Type of dataset/experiment to run')
        self.parser.add_argument('--encoder_type', default='Encoder4Editing', type=str, help='Which encoder to use')

        self.parser.add_argument('--batch_size', default=4, type=int, help='Batch size for training')
        self.parser.add_argument('--test_batch_size', default=2, type=int, help='Batch size for testing and inference')
        self.parser.add_argument('--workers', default=4, type=int, help='Number of train dataloader workers')
        self.parser.add_argument('--test_workers', default=2, type=int,
                                 help='Number of test/inference dataloader workers')

        self.parser.add_argument('--learning_rate', default=0.0001, type=float, help='Optimizer learning rate')
        self.parser.add_argument('--optim_name', default='ranger', type=str, help='Which optimizer to use')
        # Parsed with _str2bool so that "--train_decoder False" really yields False.
        self.parser.add_argument('--train_decoder', default=False, type=self._str2bool,
                                 help='Whether to train the decoder model')
        self.parser.add_argument('--start_from_latent_avg', action='store_true',
                                 help='Whether to add average latent vector to generate codes from encoder.')
        self.parser.add_argument('--lpips_type', default='alex', type=str, help='LPIPS backbone')

        self.parser.add_argument('--lpips_lambda', default=0.8, type=float, help='LPIPS loss multiplier factor')
        self.parser.add_argument('--id_lambda', default=0.1, type=float, help='ID loss multiplier factor')
        self.parser.add_argument('--l2_lambda', default=1.0, type=float, help='L2 loss multiplier factor')

        self.parser.add_argument('--stylegan_weights', default=model_paths['stylegan_ffhq'], type=str,
                                 help='Path to StyleGAN model weights')
        self.parser.add_argument('--stylegan_size', default=1024, type=int,
                                 help='size of pretrained StyleGAN Generator')
        self.parser.add_argument('--checkpoint_path', default=None, type=str, help='Path to pSp model checkpoint')

        self.parser.add_argument('--max_steps', default=500000, type=int, help='Maximum number of training steps')
        self.parser.add_argument('--image_interval', default=100, type=int,
                                 help='Interval for logging train images during training')
        self.parser.add_argument('--board_interval', default=50, type=int,
                                 help='Interval for logging metrics to tensorboard')
        self.parser.add_argument('--val_interval', default=1000, type=int, help='Validation interval')
        self.parser.add_argument('--save_interval', default=None, type=int, help='Model checkpoint interval')

        # Discriminator flags
        self.parser.add_argument('--w_discriminator_lambda', default=0, type=float, help='Dw loss multiplier')
        self.parser.add_argument('--w_discriminator_lr', default=2e-5, type=float, help='Dw learning rate')
        self.parser.add_argument("--r1", type=float, default=10, help="weight of the r1 regularization")
        self.parser.add_argument("--d_reg_every", type=int, default=16,
                                 help="interval for applying r1 regularization")
        self.parser.add_argument('--use_w_pool', action='store_true',
                                 help='Whether to store a latnet codes pool for the discriminator\'s training')
        self.parser.add_argument("--w_pool_size", type=int, default=50,
                                 help="W\'s pool size, depends on --use_w_pool")

        # e4e specific
        self.parser.add_argument('--delta_norm', type=int, default=2, help="norm type of the deltas")
        self.parser.add_argument('--delta_norm_lambda', type=float, default=2e-4, help="lambda for delta norm loss")

        # Progressive training
        self.parser.add_argument('--progressive_steps', nargs='+', type=int, default=None,
                                 help="The training steps of training new deltas. steps[i] starts the delta_i training")
        self.parser.add_argument('--progressive_start', type=int, default=None,
                                 help="The training step to start training the deltas, overrides progressive_steps")
        self.parser.add_argument('--progressive_step_every', type=int, default=2_000,
                                 help="Amount of training steps for each progressive step")

        # Save additional training info to enable future training continuation from produced checkpoints
        self.parser.add_argument('--save_training_data', action='store_true',
                                 help='Save intermediate training data to resume training from the checkpoint')
        self.parser.add_argument('--sub_exp_dir', default=None, type=str, help='Name of sub experiment directory')
        self.parser.add_argument('--keep_optimizer', action='store_true',
                                 help='Whether to continue from the checkpoint\'s optimizer')
        self.parser.add_argument('--resume_training_from_ckpt', default=None, type=str,
                                 help='Path to training checkpoint, works when --save_training_data was set to True')
        self.parser.add_argument('--update_param_list', nargs='+', type=str, default=None,
                                 help="Name of training parameters to update the loaded training checkpoint")

    def parse(self):
        """Parse the process's command-line arguments and return the namespace."""
        opts = self.parser.parse_args()
        return opts