trysem feng2022 committed
Commit 5b68e3e · 0 Parent(s)

Duplicate from feng2022/Time-TravelRephotography


Co-authored-by: Time-travelRephotography <feng2022@users.noreply.huggingface.co>

This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. .gitattributes +27 -0
  2. .gitignore +133 -0
  3. .gitmodules +9 -0
  4. README.md +14 -0
  5. Time_TravelRephotography/LICENSE +21 -0
  6. Time_TravelRephotography/LICENSE-NVIDIA +101 -0
  7. Time_TravelRephotography/LICENSE-STYLEGAN2 +21 -0
  8. Time_TravelRephotography/losses/color_transfer_loss.py +60 -0
  9. Time_TravelRephotography/losses/contextual_loss/.gitignore +104 -0
  10. Time_TravelRephotography/losses/contextual_loss/LICENSE +21 -0
  11. Time_TravelRephotography/losses/contextual_loss/__init__.py +1 -0
  12. Time_TravelRephotography/losses/contextual_loss/config.py +2 -0
  13. Time_TravelRephotography/losses/contextual_loss/functional.py +198 -0
  14. Time_TravelRephotography/losses/contextual_loss/modules/__init__.py +4 -0
  15. Time_TravelRephotography/losses/contextual_loss/modules/contextual.py +121 -0
  16. Time_TravelRephotography/losses/contextual_loss/modules/contextual_bilateral.py +69 -0
  17. Time_TravelRephotography/losses/contextual_loss/modules/vgg.py +48 -0
  18. Time_TravelRephotography/losses/joint_loss.py +167 -0
  19. Time_TravelRephotography/losses/perceptual_loss.py +111 -0
  20. Time_TravelRephotography/losses/reconstruction.py +119 -0
  21. Time_TravelRephotography/losses/regularize_noise.py +37 -0
  22. Time_TravelRephotography/model.py +697 -0
  23. Time_TravelRephotography/models/__init__.py +0 -0
  24. Time_TravelRephotography/models/degrade.py +122 -0
  25. Time_TravelRephotography/models/encoder.py +66 -0
  26. Time_TravelRephotography/models/encoder4editing/.gitignore +133 -0
  27. Time_TravelRephotography/models/encoder4editing/LICENSE +21 -0
  28. Time_TravelRephotography/models/encoder4editing/README.md +143 -0
  29. Time_TravelRephotography/models/encoder4editing/__init__.py +15 -0
  30. Time_TravelRephotography/models/encoder4editing/bash_scripts/inference.sh +15 -0
  31. Time_TravelRephotography/models/encoder4editing/configs/__init__.py +0 -0
  32. Time_TravelRephotography/models/encoder4editing/configs/data_configs.py +41 -0
  33. Time_TravelRephotography/models/encoder4editing/configs/paths_config.py +28 -0
  34. Time_TravelRephotography/models/encoder4editing/configs/transforms_config.py +62 -0
  35. Time_TravelRephotography/models/encoder4editing/criteria/__init__.py +0 -0
  36. Time_TravelRephotography/models/encoder4editing/criteria/id_loss.py +47 -0
  37. Time_TravelRephotography/models/encoder4editing/criteria/lpips/__init__.py +0 -0
  38. Time_TravelRephotography/models/encoder4editing/criteria/lpips/lpips.py +35 -0
  39. Time_TravelRephotography/models/encoder4editing/criteria/lpips/networks.py +96 -0
  40. Time_TravelRephotography/models/encoder4editing/criteria/lpips/utils.py +30 -0
  41. Time_TravelRephotography/models/encoder4editing/criteria/moco_loss.py +71 -0
  42. Time_TravelRephotography/models/encoder4editing/criteria/w_norm.py +14 -0
  43. Time_TravelRephotography/models/encoder4editing/datasets/__init__.py +0 -0
  44. Time_TravelRephotography/models/encoder4editing/datasets/gt_res_dataset.py +32 -0
  45. Time_TravelRephotography/models/encoder4editing/datasets/images_dataset.py +33 -0
  46. Time_TravelRephotography/models/encoder4editing/datasets/inference_dataset.py +25 -0
  47. Time_TravelRephotography/models/encoder4editing/editings/ganspace.py +22 -0
  48. Time_TravelRephotography/models/encoder4editing/editings/ganspace_pca/cars_pca.pt +3 -0
  49. Time_TravelRephotography/models/encoder4editing/editings/ganspace_pca/ffhq_pca.pt +3 -0
  50. Time_TravelRephotography/models/encoder4editing/editings/interfacegan_directions/age.pt +3 -0
.gitattributes ADDED
@@ -0,0 +1,27 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zstandard filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,133 @@
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ pip-wheel-metadata/
24
+ share/python-wheels/
25
+ *.egg-info/
26
+ .installed.cfg
27
+ *.egg
28
+ MANIFEST
29
+
30
+ # PyInstaller
31
+ # Usually these files are written by a python script from a template
32
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
33
+ *.manifest
34
+ *.spec
35
+
36
+ # Installer logs
37
+ pip-log.txt
38
+ pip-delete-this-directory.txt
39
+
40
+ # Unit test / coverage reports
41
+ htmlcov/
42
+ .tox/
43
+ .nox/
44
+ .coverage
45
+ .coverage.*
46
+ .cache
47
+ nosetests.xml
48
+ coverage.xml
49
+ *.cover
50
+ *.py,cover
51
+ .hypothesis/
52
+ .pytest_cache/
53
+
54
+ # Translations
55
+ *.mo
56
+ *.pot
57
+
58
+ # Django stuff:
59
+ *.log
60
+ local_settings.py
61
+ db.sqlite3
62
+ db.sqlite3-journal
63
+
64
+ # Flask stuff:
65
+ instance/
66
+ .webassets-cache
67
+
68
+ # Scrapy stuff:
69
+ .scrapy
70
+
71
+ # Sphinx documentation
72
+ docs/_build/
73
+
74
+ # PyBuilder
75
+ target/
76
+
77
+ # Jupyter Notebook
78
+ .ipynb_checkpoints
79
+
80
+ # IPython
81
+ profile_default/
82
+ ipython_config.py
83
+
84
+ # pyenv
85
+ .python-version
86
+
87
+ # pipenv
88
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
90
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
91
+ # install all needed dependencies.
92
+ #Pipfile.lock
93
+
94
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95
+ __pypackages__/
96
+
97
+ # Celery stuff
98
+ celerybeat-schedule
99
+ celerybeat.pid
100
+
101
+ # SageMath parsed files
102
+ *.sage.py
103
+
104
+ # Environments
105
+ .env
106
+ .venv
107
+ env/
108
+ venv/
109
+ ENV/
110
+ env.bak/
111
+ venv.bak/
112
+
113
+ # Spyder project settings
114
+ .spyderproject
115
+ .spyproject
116
+
117
+ # Rope project settings
118
+ .ropeproject
119
+
120
+ # mkdocs documentation
121
+ /site
122
+
123
+ # mypy
124
+ .mypy_cache/
125
+ .dmypy.json
126
+ dmypy.json
127
+
128
+ # Pyre type checker
129
+ .pyre/
130
+
131
+ wandb/
132
+ *.lmdb/
133
+ *.pkl
.gitmodules ADDED
@@ -0,0 +1,9 @@
+ [submodule "third_party/face_parsing"]
+ path = third_party/face_parsing
+ url = https://github.com/Time-Travel-Rephotography/face-parsing.PyTorch.git
+ [submodule "models/encoder4editing"]
+ path = models/encoder4editing
+ url = https://github.com/Time-Travel-Rephotography/encoder4editing.git
+ [submodule "losses/contextual_loss"]
+ path = losses/contextual_loss
+ url = https://github.com/Time-Travel-Rephotography/contextual_loss_pytorch.git
README.md ADDED
@@ -0,0 +1,14 @@
+ ---
+ title: Time TravelRephotography
+ emoji: 🦀
+ colorFrom: yellow
+ colorTo: red
+ sdk: gradio
+ sdk_version: 2.9.4
+ app_file: app.py
+ pinned: false
+ license: mit
+ duplicated_from: feng2022/Time-TravelRephotography
+ ---
+
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces#reference
Time_TravelRephotography/LICENSE ADDED
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2020 Time-Travel-Rephotography
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
Time_TravelRephotography/LICENSE-NVIDIA ADDED
@@ -0,0 +1,101 @@
1
+ Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
2
+
3
+
4
+ Nvidia Source Code License-NC
5
+
6
+ =======================================================================
7
+
8
+ 1. Definitions
9
+
10
+ "Licensor" means any person or entity that distributes its Work.
11
+
12
+ "Software" means the original work of authorship made available under
13
+ this License.
14
+
15
+ "Work" means the Software and any additions to or derivative works of
16
+ the Software that are made available under this License.
17
+
18
+ "Nvidia Processors" means any central processing unit (CPU), graphics
19
+ processing unit (GPU), field-programmable gate array (FPGA),
20
+ application-specific integrated circuit (ASIC) or any combination
21
+ thereof designed, made, sold, or provided by Nvidia or its affiliates.
22
+
23
+ The terms "reproduce," "reproduction," "derivative works," and
24
+ "distribution" have the meaning as provided under U.S. copyright law;
25
+ provided, however, that for the purposes of this License, derivative
26
+ works shall not include works that remain separable from, or merely
27
+ link (or bind by name) to the interfaces of, the Work.
28
+
29
+ Works, including the Software, are "made available" under this License
30
+ by including in or with the Work either (a) a copyright notice
31
+ referencing the applicability of this License to the Work, or (b) a
32
+ copy of this License.
33
+
34
+ 2. License Grants
35
+
36
+ 2.1 Copyright Grant. Subject to the terms and conditions of this
37
+ License, each Licensor grants to you a perpetual, worldwide,
38
+ non-exclusive, royalty-free, copyright license to reproduce,
39
+ prepare derivative works of, publicly display, publicly perform,
40
+ sublicense and distribute its Work and any resulting derivative
41
+ works in any form.
42
+
43
+ 3. Limitations
44
+
45
+ 3.1 Redistribution. You may reproduce or distribute the Work only
46
+ if (a) you do so under this License, (b) you include a complete
47
+ copy of this License with your distribution, and (c) you retain
48
+ without modification any copyright, patent, trademark, or
49
+ attribution notices that are present in the Work.
50
+
51
+ 3.2 Derivative Works. You may specify that additional or different
52
+ terms apply to the use, reproduction, and distribution of your
53
+ derivative works of the Work ("Your Terms") only if (a) Your Terms
54
+ provide that the use limitation in Section 3.3 applies to your
55
+ derivative works, and (b) you identify the specific derivative
56
+ works that are subject to Your Terms. Notwithstanding Your Terms,
57
+ this License (including the redistribution requirements in Section
58
+ 3.1) will continue to apply to the Work itself.
59
+
60
+ 3.3 Use Limitation. The Work and any derivative works thereof only
61
+ may be used or intended for use non-commercially. The Work or
62
+ derivative works thereof may be used or intended for use by Nvidia
63
+ or its affiliates commercially or non-commercially. As used herein,
64
+ "non-commercially" means for research or evaluation purposes only.
65
+
66
+ 3.4 Patent Claims. If you bring or threaten to bring a patent claim
67
+ against any Licensor (including any claim, cross-claim or
68
+ counterclaim in a lawsuit) to enforce any patents that you allege
69
+ are infringed by any Work, then your rights under this License from
70
+ such Licensor (including the grants in Sections 2.1 and 2.2) will
71
+ terminate immediately.
72
+
73
+ 3.5 Trademarks. This License does not grant any rights to use any
74
+ Licensor's or its affiliates' names, logos, or trademarks, except
75
+ as necessary to reproduce the notices described in this License.
76
+
77
+ 3.6 Termination. If you violate any term of this License, then your
78
+ rights under this License (including the grants in Sections 2.1 and
79
+ 2.2) will terminate immediately.
80
+
81
+ 4. Disclaimer of Warranty.
82
+
83
+ THE WORK IS PROVIDED "AS IS" WITHOUT WARRANTIES OR CONDITIONS OF ANY
84
+ KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WARRANTIES OR CONDITIONS OF
85
+ MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE OR
86
+ NON-INFRINGEMENT. YOU BEAR THE RISK OF UNDERTAKING ANY ACTIVITIES UNDER
87
+ THIS LICENSE.
88
+
89
+ 5. Limitation of Liability.
90
+
91
+ EXCEPT AS PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL
92
+ THEORY, WHETHER IN TORT (INCLUDING NEGLIGENCE), CONTRACT, OR OTHERWISE
93
+ SHALL ANY LICENSOR BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY DIRECT,
94
+ INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES ARISING OUT OF
95
+ OR RELATED TO THIS LICENSE, THE USE OR INABILITY TO USE THE WORK
96
+ (INCLUDING BUT NOT LIMITED TO LOSS OF GOODWILL, BUSINESS INTERRUPTION,
97
+ LOST PROFITS OR DATA, COMPUTER FAILURE OR MALFUNCTION, OR ANY OTHER
98
+ COMMERCIAL DAMAGES OR LOSSES), EVEN IF THE LICENSOR HAS BEEN ADVISED OF
99
+ THE POSSIBILITY OF SUCH DAMAGES.
100
+
101
+ =======================================================================
Time_TravelRephotography/LICENSE-STYLEGAN2 ADDED
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2019 Kim Seonghyeon
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
Time_TravelRephotography/losses/color_transfer_loss.py ADDED
@@ -0,0 +1,60 @@
+ from typing import List, Optional
+
+ import torch
+ from torch import nn
+ from torch.nn.functional import (
+     smooth_l1_loss,
+ )
+
+
+ def flatten_CHW(im: torch.Tensor) -> torch.Tensor:
+     """
+     (B, C, H, W) -> (B, -1)
+     """
+     B = im.shape[0]
+     return im.reshape(B, -1)
+
+
+ def stddev(x: torch.Tensor) -> torch.Tensor:
+     """
+     x: (B, -1), assumed to be mean-normalized
+     Returns:
+         stddev: (B)
+     """
+     return torch.sqrt(torch.mean(x * x, dim=-1))
+
+
+ def gram_matrix(input_):
+     B, C = input_.shape[:2]
+     features = input_.view(B, C, -1)
+     N = features.shape[-1]
+     G = torch.bmm(features, features.transpose(1, 2))  # C x C
+     return G.div(C * N)
+
+
+ class ColorTransferLoss(nn.Module):
+     """Penalize the Gram matrix difference between StyleGAN2's ToRGB outputs"""
+     def __init__(
+         self,
+         init_rgbs,
+         scale_rgb: bool = False
+     ):
+         super().__init__()
+
+         with torch.no_grad():
+             init_feats = [x.detach() for x in init_rgbs]
+             self.stds = [stddev(flatten_CHW(rgb)) if scale_rgb else 1 for rgb in init_feats]  # (B, 1, 1, 1) or scalar
+             self.grams = [gram_matrix(rgb / std) for rgb, std in zip(init_feats, self.stds)]
+
+     def forward(self, rgbs: List[torch.Tensor], level: int = None):
+         if level is None:
+             level = len(self.grams)
+
+         feats = rgbs
+         loss = 0
+         for i, (rgb, std) in enumerate(zip(feats[:level], self.stds[:level])):
+             G = gram_matrix(rgb / std)
+             loss = loss + smooth_l1_loss(G, self.grams[i])
+
+         return loss
+
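For reference, `ColorTransferLoss` caches the Gram matrices of the initial ToRGB activations and then penalizes deviation from them with a smooth-L1 term. A minimal usage sketch follows; the import path and the random stand-in tensors are assumptions for illustration, not part of this commit.

```python
import torch

from losses.color_transfer_loss import ColorTransferLoss  # import path assumed from this repo layout

# Stand-in for StyleGAN2 ToRGB outputs at three resolutions (B, 3, H, W).
init_rgbs = [torch.randn(1, 3, s, s) for s in (16, 32, 64)]
criterion = ColorTransferLoss(init_rgbs=init_rgbs, scale_rgb=False)

# Later ToRGB outputs from the optimized latent; the loss is small when their
# Gram matrices stay close to the cached ones.
new_rgbs = [rgb + 0.1 * torch.randn_like(rgb) for rgb in init_rgbs]
loss = criterion(new_rgbs, level=2)  # only the first two levels contribute
print(float(loss))
```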
Time_TravelRephotography/losses/contextual_loss/.gitignore ADDED
@@ -0,0 +1,104 @@
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ *.egg-info/
24
+ .installed.cfg
25
+ *.egg
26
+ MANIFEST
27
+
28
+ # PyInstaller
29
+ # Usually these files are written by a python script from a template
30
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
31
+ *.manifest
32
+ *.spec
33
+
34
+ # Installer logs
35
+ pip-log.txt
36
+ pip-delete-this-directory.txt
37
+
38
+ # Unit test / coverage reports
39
+ htmlcov/
40
+ .tox/
41
+ .coverage
42
+ .coverage.*
43
+ .cache
44
+ nosetests.xml
45
+ coverage.xml
46
+ *.cover
47
+ .hypothesis/
48
+ .pytest_cache/
49
+
50
+ # Translations
51
+ *.mo
52
+ *.pot
53
+
54
+ # Django stuff:
55
+ *.log
56
+ local_settings.py
57
+ db.sqlite3
58
+
59
+ # Flask stuff:
60
+ instance/
61
+ .webassets-cache
62
+
63
+ # Scrapy stuff:
64
+ .scrapy
65
+
66
+ # Sphinx documentation
67
+ docs/_build/
68
+
69
+ # PyBuilder
70
+ target/
71
+
72
+ # Jupyter Notebook
73
+ .ipynb_checkpoints
74
+
75
+ # pyenv
76
+ .python-version
77
+
78
+ # celery beat schedule file
79
+ celerybeat-schedule
80
+
81
+ # SageMath parsed files
82
+ *.sage.py
83
+
84
+ # Environments
85
+ .env
86
+ .venv
87
+ env/
88
+ venv/
89
+ ENV/
90
+ env.bak/
91
+ venv.bak/
92
+
93
+ # Spyder project settings
94
+ .spyderproject
95
+ .spyproject
96
+
97
+ # Rope project settings
98
+ .ropeproject
99
+
100
+ # mkdocs documentation
101
+ /site
102
+
103
+ # mypy
104
+ .mypy_cache/
Time_TravelRephotography/losses/contextual_loss/LICENSE ADDED
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2019 Sou Uchida
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
Time_TravelRephotography/losses/contextual_loss/__init__.py ADDED
@@ -0,0 +1 @@
+ from .modules import *
Time_TravelRephotography/losses/contextual_loss/config.py ADDED
@@ -0,0 +1,2 @@
+ # TODO: add support for L1, L2 etc.
+ LOSS_TYPES = ['cosine', 'l1', 'l2']
Time_TravelRephotography/losses/contextual_loss/functional.py ADDED
@@ -0,0 +1,198 @@
1
+ import torch
2
+ import torch.nn.functional as F
3
+
4
+ from .config import LOSS_TYPES
5
+
6
+ __all__ = ['contextual_loss', 'contextual_bilateral_loss']
7
+
8
+
9
+ def contextual_loss(x: torch.Tensor,
10
+ y: torch.Tensor,
11
+ band_width: float = 0.5,
12
+ loss_type: str = 'cosine',
13
+ all_dist: bool = False):
14
+ """
15
+ Computes contextual loss between x and y.
16
+ Most of this code is copied from
17
+ https://gist.github.com/yunjey/3105146c736f9c1055463c33b4c989da.
18
+
19
+ Parameters
20
+ ---
21
+ x : torch.Tensor
22
+ features of shape (N, C, H, W).
23
+ y : torch.Tensor
24
+ features of shape (N, C, H, W).
25
+ band_width : float, optional
26
+ a band-width parameter used to convert distance to similarity.
27
+ in the paper, this is described as :math:`h`.
28
+ loss_type : str, optional
29
+ a loss type to measure the distance between features.
30
+ Note: `l1` and `l2` frequently raise OOM.
31
+
32
+ Returns
33
+ ---
34
+ cx_loss : torch.Tensor
35
+ contextual loss between x and y (Eq (1) in the paper)
36
+ """
37
+
38
+ assert x.size() == y.size(), 'input tensor must have the same size.'
39
+ assert loss_type in LOSS_TYPES, f'select a loss type from {LOSS_TYPES}.'
40
+
41
+ N, C, H, W = x.size()
42
+
43
+ if loss_type == 'cosine':
44
+ dist_raw = compute_cosine_distance(x, y)
45
+ elif loss_type == 'l1':
46
+ dist_raw = compute_l1_distance(x, y)
47
+ elif loss_type == 'l2':
48
+ dist_raw = compute_l2_distance(x, y)
49
+
50
+ dist_tilde = compute_relative_distance(dist_raw)
51
+ cx = compute_cx(dist_tilde, band_width)
52
+ if all_dist:
53
+ return cx
54
+
55
+ cx = torch.mean(torch.max(cx, dim=1)[0], dim=1) # Eq(1)
56
+ cx_loss = torch.mean(-torch.log(cx + 1e-5)) # Eq(5)
57
+
58
+ return cx_loss
59
+
60
+
61
+ # TODO: Operation check
62
+ def contextual_bilateral_loss(x: torch.Tensor,
63
+ y: torch.Tensor,
64
+ weight_sp: float = 0.1,
65
+ band_width: float = 1.,
66
+ loss_type: str = 'cosine'):
67
+ """
68
+ Computes Contextual Bilateral (CoBi) Loss between x and y,
69
+ proposed in https://arxiv.org/pdf/1905.05169.pdf.
70
+
71
+ Parameters
72
+ ---
73
+ x : torch.Tensor
74
+ features of shape (N, C, H, W).
75
+ y : torch.Tensor
76
+ features of shape (N, C, H, W).
77
+ band_width : float, optional
78
+ a band-width parameter used to convert distance to similarity.
79
+ in the paper, this is described as :math:`h`.
80
+ loss_type : str, optional
81
+ a loss type to measure the distance between features.
82
+ Note: `l1` and `l2` frequently raise OOM.
83
+
84
+ Returns
85
+ ---
86
+ cx_loss : torch.Tensor
87
+ contextual loss between x and y (Eq (1) in the paper).
88
+ k_arg_max_NC : torch.Tensor
89
+ indices to maximize similarity over channels.
90
+ """
91
+
92
+ assert x.size() == y.size(), 'input tensor must have the same size.'
93
+ assert loss_type in LOSS_TYPES, f'select a loss type from {LOSS_TYPES}.'
94
+
95
+ # spatial loss
96
+ grid = compute_meshgrid(x.shape).to(x.device)
97
+ dist_raw = compute_l2_distance(grid, grid)
98
+ dist_tilde = compute_relative_distance(dist_raw)
99
+ cx_sp = compute_cx(dist_tilde, band_width)
100
+
101
+ # feature loss
102
+ if loss_type == 'cosine':
103
+ dist_raw = compute_cosine_distance(x, y)
104
+ elif loss_type == 'l1':
105
+ dist_raw = compute_l1_distance(x, y)
106
+ elif loss_type == 'l2':
107
+ dist_raw = compute_l2_distance(x, y)
108
+ dist_tilde = compute_relative_distance(dist_raw)
109
+ cx_feat = compute_cx(dist_tilde, band_width)
110
+
111
+ # combined loss
112
+ cx_combine = (1. - weight_sp) * cx_feat + weight_sp * cx_sp
113
+
114
+ k_max_NC, _ = torch.max(cx_combine, dim=2, keepdim=True)
115
+
116
+ cx = k_max_NC.mean(dim=1)
117
+ cx_loss = torch.mean(-torch.log(cx + 1e-5))
118
+
119
+ return cx_loss
120
+
121
+
122
+ def compute_cx(dist_tilde, band_width):
123
+ w = torch.exp((1 - dist_tilde) / band_width) # Eq(3)
124
+ cx = w / torch.sum(w, dim=2, keepdim=True) # Eq(4)
125
+ return cx
126
+
127
+
128
+ def compute_relative_distance(dist_raw):
129
+ dist_min, _ = torch.min(dist_raw, dim=2, keepdim=True)
130
+ dist_tilde = dist_raw / (dist_min + 1e-5)
131
+ return dist_tilde
132
+
133
+
134
+ def compute_cosine_distance(x, y):
135
+ # mean shifting by channel-wise mean of `y`.
136
+ y_mu = y.mean(dim=(0, 2, 3), keepdim=True)
137
+ x_centered = x - y_mu
138
+ y_centered = y - y_mu
139
+
140
+ # L2 normalization
141
+ x_normalized = F.normalize(x_centered, p=2, dim=1)
142
+ y_normalized = F.normalize(y_centered, p=2, dim=1)
143
+
144
+ # channel-wise vectorization
145
+ N, C, *_ = x.size()
146
+ x_normalized = x_normalized.reshape(N, C, -1) # (N, C, H*W)
147
+ y_normalized = y_normalized.reshape(N, C, -1) # (N, C, H*W)
148
+
149
+ # cosine similarity
150
+ cosine_sim = torch.bmm(x_normalized.transpose(1, 2),
151
+ y_normalized) # (N, H*W, H*W)
152
+
153
+ # convert to distance
154
+ dist = 1 - cosine_sim
155
+
156
+ return dist
157
+
158
+
159
+ # TODO: Considering avoiding OOM.
160
+ def compute_l1_distance(x: torch.Tensor, y: torch.Tensor):
161
+ N, C, H, W = x.size()
162
+ x_vec = x.view(N, C, -1)
163
+ y_vec = y.view(N, C, -1)
164
+
165
+ dist = x_vec.unsqueeze(2) - y_vec.unsqueeze(3)
166
+ dist = dist.abs().sum(dim=1)
167
+ dist = dist.transpose(1, 2).reshape(N, H*W, H*W)
168
+ dist = dist.clamp(min=0.)
169
+
170
+ return dist
171
+
172
+
173
+ # TODO: Considering avoiding OOM.
174
+ def compute_l2_distance(x, y):
175
+ N, C, H, W = x.size()
176
+ x_vec = x.view(N, C, -1)
177
+ y_vec = y.view(N, C, -1)
178
+ x_s = torch.sum(x_vec ** 2, dim=1)
179
+ y_s = torch.sum(y_vec ** 2, dim=1)
180
+
181
+ A = y_vec.transpose(1, 2) @ x_vec
182
+ dist = y_s - 2 * A + x_s.transpose(0, 1)
183
+ dist = dist.transpose(1, 2).reshape(N, H*W, H*W)
184
+ dist = dist.clamp(min=0.)
185
+
186
+ return dist
187
+
188
+
189
+ def compute_meshgrid(shape):
190
+ N, C, H, W = shape
191
+ rows = torch.arange(0, H, dtype=torch.float32) / (H + 1)
192
+ cols = torch.arange(0, W, dtype=torch.float32) / (W + 1)
193
+
194
+ feature_grid = torch.meshgrid(rows, cols)
195
+ feature_grid = torch.stack(feature_grid).unsqueeze(0)
196
+ feature_grid = torch.cat([feature_grid for _ in range(N)], dim=0)
197
+
198
+ return feature_grid
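`contextual_loss` converts pairwise feature distances into similarities with band width `h`, normalizes them per query position, and averages the per-position maxima (Eq. 1 and Eq. 5 referenced in the code comments). A quick sanity-check sketch on random features; the import path is assumed from this repo layout.

```python
import torch

from losses.contextual_loss import functional as F  # assumed import path

x = torch.randn(2, 8, 16, 16)  # (N, C, H, W) features
y = torch.randn(2, 8, 16, 16)

# Cosine-distance contextual loss; smaller when x and y share similar feature statistics.
loss = F.contextual_loss(x, y, band_width=0.5, loss_type='cosine')
print(loss.item())
```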
Time_TravelRephotography/losses/contextual_loss/modules/__init__.py ADDED
@@ -0,0 +1,4 @@
+ from .contextual import ContextualLoss
+ from .contextual_bilateral import ContextualBilateralLoss
+
+ __all__ = ['ContextualLoss', 'ContextualBilateralLoss']
Time_TravelRephotography/losses/contextual_loss/modules/contextual.py ADDED
@@ -0,0 +1,121 @@
1
+ import random
2
+ from typing import (
3
+ Iterable,
4
+ List,
5
+ Optional,
6
+ )
7
+
8
+ import numpy as np
9
+ import torch
10
+ import torch.nn as nn
11
+
12
+ from .vgg import VGG19
13
+ from .. import functional as F
14
+ from ..config import LOSS_TYPES
15
+
16
+
17
+ class ContextualLoss(nn.Module):
18
+ """
19
+ Creates a criterion that measures the contextual loss.
20
+
21
+ Parameters
22
+ ---
23
+ band_width : int, optional
24
+ a band_width parameter described as :math:`h` in the paper.
25
+ use_vgg : bool, optional
26
+ if you want to use VGG feature, set this `True`.
27
+ vgg_layer : str, optional
28
+ intermediate layer name for VGG feature.
29
+ Now we support layer names:
30
+ `['relu1_2', 'relu2_2', 'relu3_4', 'relu4_4', 'relu5_4']`
31
+ """
32
+
33
+ def __init__(
34
+ self,
35
+ band_width: float = 0.5,
36
+ loss_type: str = 'cosine',
37
+ use_vgg: bool = False,
38
+ vgg_layers: List[str] = ['relu3_4'],
39
+ feature_1d_size: int = 64,
40
+ ):
41
+
42
+ super().__init__()
43
+
44
+ assert band_width > 0, 'band_width parameter must be positive.'
45
+ assert loss_type in LOSS_TYPES,\
46
+ f'select a loss type from {LOSS_TYPES}.'
47
+
48
+ self.loss_type = loss_type
49
+ self.band_width = band_width
50
+ self.feature_1d_size = feature_1d_size
51
+
52
+ if use_vgg:
53
+ self.vgg_model = VGG19()
54
+ self.vgg_layers = vgg_layers
55
+ self.register_buffer(
56
+ name='vgg_mean',
57
+ tensor=torch.tensor(
58
+ [[[0.485]], [[0.456]], [[0.406]]], requires_grad=False)
59
+ )
60
+ self.register_buffer(
61
+ name='vgg_std',
62
+ tensor=torch.tensor(
63
+ [[[0.229]], [[0.224]], [[0.225]]], requires_grad=False)
64
+ )
65
+
66
+ def forward(self, x: torch.Tensor, y: torch.Tensor, all_dist: bool = False):
67
+ if not hasattr(self, 'vgg_model'):
68
+ return self.contextual_loss(x, y, self.feature_1d_size, self.band_width, all_dist=all_dist)
69
+
70
+
71
+ x = self.forward_vgg(x)
72
+ y = self.forward_vgg(y)
73
+
74
+ loss = 0
75
+ for layer in self.vgg_layers:
76
+ # picking up vgg feature maps
77
+ fx = getattr(x, layer)
78
+ fy = getattr(y, layer)
79
+ loss = loss + self.contextual_loss(
80
+ fx, fy, self.feature_1d_size, self.band_width, all_dist=all_dist, loss_type=self.loss_type
81
+ )
82
+ return loss
83
+
84
+ def forward_vgg(self, x: torch.Tensor):
85
+ assert x.shape[1] == 3, 'VGG model takes 3-channel images.'
86
+ # [-1, 1] -> [0, 1]
87
+ x = (x + 1) * 0.5
88
+
89
+ # normalization
90
+ x = x.sub(self.vgg_mean.detach()).div(self.vgg_std)
91
+ return self.vgg_model(x)
92
+
93
+ @classmethod
94
+ def contextual_loss(
95
+ cls,
96
+ x: torch.Tensor, y: torch.Tensor,
97
+ feature_1d_size: int,
98
+ band_width: int,
99
+ all_dist: bool = False,
100
+ loss_type: str = 'cosine',
101
+ ) -> torch.Tensor:
102
+ feature_size = feature_1d_size ** 2
103
+ if np.prod(x.shape[2:]) > feature_size or np.prod(y.shape[2:]) > feature_size:
104
+ x, indices = cls.random_sampling(x, feature_1d_size=feature_1d_size)
105
+ y, _ = cls.random_sampling(y, feature_1d_size=feature_1d_size, indices=indices)
106
+
107
+ return F.contextual_loss(x, y, band_width, all_dist=all_dist, loss_type=loss_type)
108
+
109
+ @staticmethod
110
+ def random_sampling(
111
+ tensor_NCHW: torch.Tensor, feature_1d_size: int, indices: Optional[List] = None
112
+ ):
113
+ N, C, H, W = tensor_NCHW.shape
114
+ S = H * W
115
+ tensor_NCS = tensor_NCHW.reshape([N, C, S])
116
+ if indices is None:
117
+ all_indices = list(range(S))
118
+ random.shuffle(all_indices)
119
+ indices = all_indices[:feature_1d_size**2]
120
+ res = tensor_NCS[:, :, indices].reshape(N, -1, feature_1d_size, feature_1d_size)
121
+ return res, indices
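`ContextualLoss` can also be used without the VGG backbone, in which case `forward` compares the given feature tensors directly, and random spatial sub-sampling only kicks in when H*W exceeds `feature_1d_size**2`. A minimal sketch under that assumption (import path assumed, repo root on PYTHONPATH).

```python
import torch

from losses.contextual_loss import ContextualLoss  # re-exported via losses/contextual_loss/__init__.py

criterion = ContextualLoss(band_width=0.5, loss_type='cosine', use_vgg=False, feature_1d_size=64)

feat_a = torch.randn(1, 16, 32, 32)
feat_b = torch.randn(1, 16, 32, 32)
loss = criterion(feat_a, feat_b)  # no sub-sampling here: 32*32 <= 64**2
print(loss.item())
```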
Time_TravelRephotography/losses/contextual_loss/modules/contextual_bilateral.py ADDED
@@ -0,0 +1,69 @@
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ from .vgg import VGG19
5
+ from .. import functional as F
6
+ from ..config import LOSS_TYPES
7
+
8
+
9
+ class ContextualBilateralLoss(nn.Module):
10
+ """
11
+ Creates a criterion that measures the contextual bilateral loss.
12
+
13
+ Parameters
14
+ ---
15
+ weight_sp : float, optional
16
+ a balancing weight between spatial and feature loss.
17
+ band_width : int, optional
18
+ a band_width parameter described as :math:`h` in the paper.
19
+ use_vgg : bool, optional
20
+ if you want to use VGG feature, set this `True`.
21
+ vgg_layer : str, optional
22
+ intermediate layer name for VGG feature.
23
+ Now we support layer names:
24
+ `['relu1_2', 'relu2_2', 'relu3_4', 'relu4_4', 'relu5_4']`
25
+ """
26
+
27
+ def __init__(self,
28
+ weight_sp: float = 0.1,
29
+ band_width: float = 0.5,
30
+ loss_type: str = 'cosine',
31
+ use_vgg: bool = False,
32
+ vgg_layer: str = 'relu3_4'):
33
+
34
+ super(ContextualBilateralLoss, self).__init__()
35
+
36
+ assert band_width > 0, 'band_width parameter must be positive.'
37
+ assert loss_type in LOSS_TYPES,\
38
+ f'select a loss type from {LOSS_TYPES}.'
39
+
40
+ self.band_width = band_width
41
+
42
+ if use_vgg:
43
+ self.vgg_model = VGG19()
44
+ self.vgg_layer = vgg_layer
45
+ self.register_buffer(
46
+ name='vgg_mean',
47
+ tensor=torch.tensor(
48
+ [[[0.485]], [[0.456]], [[0.406]]], requires_grad=False)
49
+ )
50
+ self.register_buffer(
51
+ name='vgg_std',
52
+ tensor=torch.tensor(
53
+ [[[0.229]], [[0.224]], [[0.225]]], requires_grad=False)
54
+ )
55
+
56
+ def forward(self, x, y):
57
+ if hasattr(self, 'vgg_model'):
58
+ assert x.shape[1] == 3 and y.shape[1] == 3,\
59
+ 'VGG model takes 3-channel images.'
60
+
61
+ # normalization
62
+ x = x.sub(self.vgg_mean.detach()).div(self.vgg_std.detach())
63
+ y = y.sub(self.vgg_mean.detach()).div(self.vgg_std.detach())
64
+
65
+ # picking up vgg feature maps
66
+ x = getattr(self.vgg_model(x), self.vgg_layer)
67
+ y = getattr(self.vgg_model(y), self.vgg_layer)
68
+
69
+ return F.contextual_bilateral_loss(x, y, self.band_width)
Time_TravelRephotography/losses/contextual_loss/modules/vgg.py ADDED
@@ -0,0 +1,48 @@
+ from collections import namedtuple
+
+ import torch.nn as nn
+ import torchvision.models.vgg as vgg
+
+
+ class VGG19(nn.Module):
+     def __init__(self, requires_grad=False):
+         super(VGG19, self).__init__()
+         vgg_pretrained_features = vgg.vgg19(pretrained=True).features
+         self.slice1 = nn.Sequential()
+         self.slice2 = nn.Sequential()
+         self.slice3 = nn.Sequential()
+         self.slice4 = nn.Sequential()
+         self.slice5 = nn.Sequential()
+         for x in range(4):
+             self.slice1.add_module(str(x), vgg_pretrained_features[x])
+         for x in range(4, 9):
+             self.slice2.add_module(str(x), vgg_pretrained_features[x])
+         for x in range(9, 18):
+             self.slice3.add_module(str(x), vgg_pretrained_features[x])
+         for x in range(18, 27):
+             self.slice4.add_module(str(x), vgg_pretrained_features[x])
+         for x in range(27, 36):
+             self.slice5.add_module(str(x), vgg_pretrained_features[x])
+         if not requires_grad:
+             for param in self.parameters():
+                 param.requires_grad = False
+
+     def forward(self, X):
+         h = self.slice1(X)
+         h_relu1_2 = h
+         h = self.slice2(h)
+         h_relu2_2 = h
+         h = self.slice3(h)
+         h_relu3_4 = h
+         h = self.slice4(h)
+         h_relu4_4 = h
+         h = self.slice5(h)
+         h_relu5_4 = h
+
+         vgg_outputs = namedtuple(
+             "VggOutputs", ['relu1_2', 'relu2_2',
+                            'relu3_4', 'relu4_4', 'relu5_4'])
+         out = vgg_outputs(h_relu1_2, h_relu2_2,
+                           h_relu3_4, h_relu4_4, h_relu5_4)
+
+         return out
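The namedtuple return type makes it easy to pick activation maps by name. A small usage sketch; it downloads the torchvision VGG-19 weights on first use, and the import path is assumed from this repo layout.

```python
import torch

from losses.contextual_loss.modules.vgg import VGG19  # assumed import path

model = VGG19().eval()           # pretrained torchvision weights, gradients disabled
x = torch.randn(1, 3, 224, 224)  # callers are expected to ImageNet-normalize real inputs
with torch.no_grad():
    feats = model(x)
print(feats.relu3_4.shape)       # e.g. torch.Size([1, 256, 56, 56])
```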
Time_TravelRephotography/losses/joint_loss.py ADDED
@@ -0,0 +1,167 @@
1
+ from argparse import (
2
+ ArgumentParser,
3
+ Namespace,
4
+ )
5
+ from typing import (
6
+ Dict,
7
+ Iterable,
8
+ Optional,
9
+ Tuple,
10
+ )
11
+
12
+ import numpy as np
13
+ import torch
14
+ from torch import nn
15
+
16
+ from utils.misc import (
17
+ optional_string,
18
+ iterable_to_str,
19
+ )
20
+
21
+ from .contextual_loss import ContextualLoss
22
+ from .color_transfer_loss import ColorTransferLoss
23
+ from .regularize_noise import NoiseRegularizer
24
+ from .reconstruction import (
25
+ EyeLoss,
26
+ FaceLoss,
27
+ create_perceptual_loss,
28
+ ReconstructionArguments,
29
+ )
30
+
31
+ class LossArguments:
32
+ @staticmethod
33
+ def add_arguments(parser: ArgumentParser):
34
+ ReconstructionArguments.add_arguments(parser)
35
+
36
+ parser.add_argument("--color_transfer", type=float, default=1e10, help="color transfer loss weight")
37
+ parser.add_argument("--eye", type=float, default=0.1, help="eye loss weight")
38
+ parser.add_argument('--noise_regularize', type=float, default=5e4)
39
+ # contextual loss
40
+ parser.add_argument("--contextual", type=float, default=0.1, help="contextual loss weight")
41
+ parser.add_argument("--cx_layers", nargs='*', help="contextual loss layers",
42
+ choices=['relu1_2', 'relu2_2', 'relu3_4', 'relu4_4', 'relu5_4'],
43
+ default=['relu3_4', 'relu2_2', 'relu1_2'])
44
+
45
+ @staticmethod
46
+ def to_string(args: Namespace) -> str:
47
+ return (
48
+ ReconstructionArguments.to_string(args)
49
+ + optional_string(args.eye > 0, f"-eye{args.eye}")
50
+ + optional_string(args.color_transfer, f"-color{args.color_transfer:.1e}")
51
+ + optional_string(
52
+ args.contextual,
53
+ f"-cx{args.contextual}({iterable_to_str(args.cx_layers)})"
54
+ )
55
+ #+ optional_string(args.mse, f"-mse{args.mse}")
56
+ + optional_string(args.noise_regularize, f"-NR{args.noise_regularize:.1e}")
57
+ )
58
+
59
+
60
+ class BakedMultiContextualLoss(nn.Module):
61
+ """Random sample different image patches for different vgg layers."""
62
+ def __init__(self, sibling: torch.Tensor, args: Namespace, size: int = 256):
63
+ super().__init__()
64
+
65
+ self.cxs = nn.ModuleList([ContextualLoss(use_vgg=True, vgg_layers=[layer])
66
+ for layer in args.cx_layers])
67
+ self.size = size
68
+ self.sibling = sibling.detach()
69
+
70
+ def forward(self, img: torch.Tensor):
71
+ cx_loss = 0
72
+ for cx in self.cxs:
73
+ h, w = np.random.randint(0, high=img.shape[-1] - self.size, size=2)
74
+ cx_loss = cx(self.sibling[..., h:h+self.size, w:w+self.size], img[..., h:h+self.size, w:w+self.size]) + cx_loss
75
+ return cx_loss
76
+
77
+
78
+ class BakedContextualLoss(ContextualLoss):
79
+ def __init__(self, sibling: torch.Tensor, args: Namespace, size: int = 256):
80
+ super().__init__(use_vgg=True, vgg_layers=args.cx_layers)
81
+ self.size = size
82
+ self.sibling = sibling.detach()
83
+
84
+ def forward(self, img: torch.Tensor):
85
+ h, w = np.random.randint(0, high=img.shape[-1] - self.size, size=2)
86
+ return super().forward(self.sibling[..., h:h+self.size, w:w+self.size], img[..., h:h+self.size, w:w+self.size])
87
+
88
+
89
+ class JointLoss(nn.Module):
90
+ def __init__(
91
+ self,
92
+ args: Namespace,
93
+ target: torch.Tensor,
94
+ sibling: Optional[torch.Tensor],
95
+ sibling_rgbs: Optional[Iterable[torch.Tensor]] = None,
96
+ ):
97
+ super().__init__()
98
+
99
+ self.weights = {
100
+ "face": 1., "eye": args.eye,
101
+ "contextual": args.contextual, "color_transfer": args.color_transfer,
102
+ "noise": args.noise_regularize,
103
+ }
104
+
105
+ reconstruction = {}
106
+ if args.vgg > 0 or args.vggface > 0:
107
+ percept = create_perceptual_loss(args)
108
+ reconstruction.update(
109
+ {"face": FaceLoss(target, input_size=args.generator_size, size=args.recon_size, percept=percept)}
110
+ )
111
+ if args.eye > 0:
112
+ reconstruction.update(
113
+ {"eye": EyeLoss(target, input_size=args.generator_size, percept=percept)}
114
+ )
115
+ self.reconstruction = nn.ModuleDict(reconstruction)
116
+
117
+ exemplar = {}
118
+ if args.contextual > 0 and len(args.cx_layers) > 0:
119
+ assert sibling is not None
120
+ exemplar.update(
121
+ {"contextual": BakedContextualLoss(sibling, args)}
122
+ )
123
+ if args.color_transfer > 0:
124
+ assert sibling_rgbs is not None
125
+ self.sibling_rgbs = sibling_rgbs
126
+ exemplar.update(
127
+ {"color_transfer": ColorTransferLoss(init_rgbs=sibling_rgbs)}
128
+ )
129
+ self.exemplar = nn.ModuleDict(exemplar)
130
+
131
+ if args.noise_regularize > 0:
132
+ self.noise_criterion = NoiseRegularizer()
133
+
134
+ def forward(
135
+ self, img, degrade=None, noises=None, rgbs=None, rgb_level: Optional[int] = None
136
+ ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
137
+ """
138
+ Args:
139
+ rgbs: results from the ToRGB layers
140
+ """
141
+ # TODO: add current optimization resolution for noises
142
+
143
+ losses = {}
144
+
145
+ # reconstruction losses
146
+ for name, criterion in self.reconstruction.items():
147
+ losses[name] = criterion(img, degrade=degrade)
148
+
149
+ # exemplar losses
150
+ if 'contextual' in self.exemplar:
151
+ losses["contextual"] = self.exemplar["contextual"](img)
152
+ if "color_transfer" in self.exemplar:
153
+ assert rgbs is not None
154
+ losses["color_transfer"] = self.exemplar["color_transfer"](rgbs, level=rgb_level)
155
+
156
+ # noise regularizer
157
+ if self.weights["noise"] > 0:
158
+ losses["noise"] = self.noise_criterion(noises)
159
+
160
+ total_loss = 0
161
+ for name, loss in losses.items():
162
+ total_loss = total_loss + self.weights[name] * loss
163
+ return total_loss, losses
164
+
165
+ def update_sibling(self, sibling: torch.Tensor):
166
+ assert "contextual" in self.exemplar
167
+ self.exemplar["contextual"].sibling = sibling.detach()
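`JointLoss` is a weighted sum of the individual terms, with the weights taken directly from the command-line arguments defined in `LossArguments`. A hedged sketch of wiring up those arguments with their defaults (import paths assumed; `utils.misc` and the model modules must be importable from the repo root).

```python
from argparse import ArgumentParser

from losses.joint_loss import LossArguments  # assumed import path

parser = ArgumentParser()
LossArguments.add_arguments(parser)
args = parser.parse_args([])  # all defaults: vgg=1, vggface=0.3, eye=0.1, contextual=0.1, ...

# String tag used to name experiment outputs, built from the loss weights.
print(LossArguments.to_string(args))  # e.g. "s256-vgg1-vggface0.3-eye0.1-color1.0e+10-..."
```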
Time_TravelRephotography/losses/perceptual_loss.py ADDED
@@ -0,0 +1,111 @@
1
+ """
2
+ Code borrowed from https://gist.github.com/alper111/8233cdb0414b4cb5853f2f730ab95a49#file-vgg_perceptual_loss-py-L5
3
+ """
4
+ import torch
5
+ import torchvision
6
+ from models.vggface import VGGFaceFeats
7
+
8
+
9
+ def cos_loss(fi, ft):
10
+ return 1 - torch.nn.functional.cosine_similarity(fi, ft).mean()
11
+
12
+
13
+ class VGGPerceptualLoss(torch.nn.Module):
14
+ def __init__(self, resize=False):
15
+ super(VGGPerceptualLoss, self).__init__()
16
+ blocks = []
17
+ blocks.append(torchvision.models.vgg16(pretrained=True).features[:4].eval())
18
+ blocks.append(torchvision.models.vgg16(pretrained=True).features[4:9].eval())
19
+ blocks.append(torchvision.models.vgg16(pretrained=True).features[9:16].eval())
20
+ blocks.append(torchvision.models.vgg16(pretrained=True).features[16:23].eval())
21
+ for bl in blocks:
22
+ for p in bl:
23
+ p.requires_grad = False
24
+ self.blocks = torch.nn.ModuleList(blocks)
25
+ self.transform = torch.nn.functional.interpolate
26
+ self.mean = torch.nn.Parameter(torch.tensor([0.485, 0.456, 0.406]).view(1,3,1,1))
27
+ self.std = torch.nn.Parameter(torch.tensor([0.229, 0.224, 0.225]).view(1,3,1,1))
28
+ self.resize = resize
29
+
30
+ def forward(self, input, target, max_layer=4, cos_dist: bool = False):
31
+ target = (target + 1) * 0.5
32
+ input = (input + 1) * 0.5
33
+
34
+ if input.shape[1] != 3:
35
+ input = input.repeat(1, 3, 1, 1)
36
+ target = target.repeat(1, 3, 1, 1)
37
+ input = (input-self.mean) / self.std
38
+ target = (target-self.mean) / self.std
39
+ if self.resize:
40
+ input = self.transform(input, mode='bilinear', size=(224, 224), align_corners=False)
41
+ target = self.transform(target, mode='bilinear', size=(224, 224), align_corners=False)
42
+ x = input
43
+ y = target
44
+ loss = 0.0
45
+ loss_func = cos_loss if cos_dist else torch.nn.functional.l1_loss
46
+ for bi, block in enumerate(self.blocks[:max_layer]):
47
+ x = block(x)
48
+ y = block(y)
49
+ loss += loss_func(x, y.detach())
50
+ return loss
51
+
52
+
53
+ class VGGFacePerceptualLoss(torch.nn.Module):
54
+ def __init__(self, weight_path: str = "checkpoint/vgg_face_dag.pt", resize: bool = False):
55
+ super().__init__()
56
+ self.vgg = VGGFaceFeats()
57
+ self.vgg.load_state_dict(torch.load(weight_path))
58
+
59
+ mean = torch.tensor(self.vgg.meta["mean"]).view(1, 3, 1, 1) / 255.0
60
+ self.register_buffer("mean", mean)
61
+
62
+ self.transform = torch.nn.functional.interpolate
63
+ self.resize = resize
64
+
65
+ def forward(self, input, target, max_layer: int = 4, cos_dist: bool = False):
66
+ target = (target + 1) * 0.5
67
+ input = (input + 1) * 0.5
68
+
69
+ # preprocessing
70
+ if input.shape[1] != 3:
71
+ input = input.repeat(1, 3, 1, 1)
72
+ target = target.repeat(1, 3, 1, 1)
73
+ input = input - self.mean
74
+ target = target - self.mean
75
+ if self.resize:
76
+ input = self.transform(input, mode='bilinear', size=(224, 224), align_corners=False)
77
+ target = self.transform(target, mode='bilinear', size=(224, 224), align_corners=False)
78
+
79
+ input_feats = self.vgg(input)
80
+ target_feats = self.vgg(target)
81
+
82
+ loss_func = cos_loss if cos_dist else torch.nn.functional.l1_loss
83
+ # calc perceptual loss
84
+ loss = 0.0
85
+ for fi, ft in zip(input_feats[:max_layer], target_feats[:max_layer]):
86
+ loss = loss + loss_func(fi, ft.detach())
87
+ return loss
88
+
89
+
90
+ class PerceptualLoss(torch.nn.Module):
91
+ def __init__(
92
+ self, lambda_vggface: float = 0.025 / 0.15, lambda_vgg: float = 1, eps: float = 1e-8, cos_dist: bool = False
93
+ ):
94
+ super().__init__()
95
+ self.register_buffer("lambda_vggface", torch.tensor(lambda_vggface))
96
+ self.register_buffer("lambda_vgg", torch.tensor(lambda_vgg))
97
+ self.cos_dist = cos_dist
98
+
99
+ if lambda_vgg > eps:
100
+ self.vgg = VGGPerceptualLoss()
101
+ if lambda_vggface > eps:
102
+ self.vggface = VGGFacePerceptualLoss()
103
+
104
+ def forward(self, input, target, eps=1e-8, use_vggface: bool = True, use_vgg=True, max_vgg_layer=4):
105
+ loss = 0.0
106
+ if self.lambda_vgg > eps and use_vgg:
107
+ loss = loss + self.lambda_vgg * self.vgg(input, target, max_layer=max_vgg_layer)
108
+ if self.lambda_vggface > eps and use_vggface:
109
+ loss = loss + self.lambda_vggface * self.vggface(input, target, cos_dist=self.cos_dist)
110
+ return loss
111
+
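The VGG-face branch of `PerceptualLoss` needs the `checkpoint/vgg_face_dag.pt` weights, so a quick local test can zero its weight and exercise only the torchvision VGG-16 branch. A minimal sketch (import path assumed; the VGG-16 weights are downloaded on first use).

```python
import torch

from losses.perceptual_loss import PerceptualLoss  # assumed import path

# lambda_vggface=0 skips the VGG-face branch, so no extra checkpoint is required.
percept = PerceptualLoss(lambda_vggface=0.0, lambda_vgg=1.0)

img_a = torch.rand(1, 3, 256, 256) * 2 - 1  # inputs are expected in [-1, 1]
img_b = torch.rand(1, 3, 256, 256) * 2 - 1
loss = percept(img_a, img_b)
print(loss.item())
```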
Time_TravelRephotography/losses/reconstruction.py ADDED
@@ -0,0 +1,119 @@
1
+ from argparse import (
2
+ ArgumentParser,
3
+ Namespace,
4
+ )
5
+ from typing import Optional
6
+
7
+ import numpy as np
8
+ import torch
9
+ from torch import nn
10
+
11
+ from losses.perceptual_loss import PerceptualLoss
12
+ from models.degrade import Downsample
13
+ from utils.misc import optional_string
14
+
15
+
16
+ class ReconstructionArguments:
17
+ @staticmethod
18
+ def add_arguments(parser: ArgumentParser):
19
+ parser.add_argument("--vggface", type=float, default=0.3, help="vggface")
20
+ parser.add_argument("--vgg", type=float, default=1, help="vgg")
21
+ parser.add_argument('--recon_size', type=int, default=256, help="size for face reconstruction loss")
22
+
23
+ @staticmethod
24
+ def to_string(args: Namespace) -> str:
25
+ return (
26
+ f"s{args.recon_size}"
27
+ + optional_string(args.vgg > 0, f"-vgg{args.vgg}")
28
+ + optional_string(args.vggface > 0, f"-vggface{args.vggface}")
29
+ )
30
+
31
+
32
+ def create_perceptual_loss(args: Namespace):
33
+ return PerceptualLoss(lambda_vgg=args.vgg, lambda_vggface=args.vggface, cos_dist=False)
34
+
35
+
36
+ class EyeLoss(nn.Module):
37
+ def __init__(
38
+ self,
39
+ target: torch.Tensor,
40
+ input_size: int = 1024,
41
+ input_channels: int = 3,
42
+ percept: Optional[nn.Module] = None,
43
+ args: Optional[Namespace] = None
44
+ ):
45
+ """
46
+ target: target image
47
+ """
48
+ assert not (percept is None and args is None)
49
+
50
+ super().__init__()
51
+
52
+ self.target = target
53
+
54
+ target_size = target.shape[-1]
55
+ self.downsample = Downsample(input_size, target_size, input_channels) \
56
+ if target_size != input_size else (lambda x: x)
57
+
58
+ self.percept = percept if percept is not None else create_perceptual_loss(args)
59
+
60
+ eye_size = np.array((224, 224))
61
+ btlrs = []
62
+ for sgn in [1, -1]:
63
+ center = np.array((480, 384 * sgn)) # (y, x)
64
+ b, t = center[0] - eye_size[0] // 2, center[0] + eye_size[0] // 2
65
+ l, r = center[1] - eye_size[1] // 2, center[1] + eye_size[1] // 2
66
+ btlrs.append((np.array((b, t, l, r)) / 1024 * target_size).astype(int))
67
+ self.btlrs = np.stack(btlrs, axis=0)
68
+
69
+ def forward(self, img: torch.Tensor, degrade: nn.Module = None):
70
+ """
71
+ img: it should be the degraded version of the generated image
72
+ """
73
+ if degrade is not None:
74
+ img = degrade(img, downsample=self.downsample)
75
+
76
+ loss = 0
77
+ for (b, t, l, r) in self.btlrs:
78
+ loss = loss + self.percept(
79
+ img[:, :, b:t, l:r], self.target[:, :, b:t, l:r],
80
+ use_vggface=False, max_vgg_layer=4,
81
+ # use_vgg=False,
82
+ )
83
+ return loss
84
+
85
+
86
+ class FaceLoss(nn.Module):
87
+ def __init__(
88
+ self,
89
+ target: torch.Tensor,
90
+ input_size: int = 1024,
91
+ input_channels: int = 3,
92
+ size: int = 256,
93
+ percept: Optional[nn.Module] = None,
94
+ args: Optional[Namespace] = None
95
+ ):
96
+ """
97
+ target: target image
98
+ """
99
+ assert not (percept is None and args is None)
100
+
101
+ super().__init__()
102
+
103
+ target_size = target.shape[-1]
104
+ self.target = target if target_size == size \
105
+ else Downsample(target_size, size, target.shape[1]).to(target.device)(target)
106
+
107
+ self.downsample = Downsample(input_size, size, input_channels) \
108
+ if size != input_size else (lambda x: x)
109
+
110
+ self.percept = percept if percept is not None else create_perceptual_loss(args)
111
+
112
+ def forward(self, img: torch.Tensor, degrade: nn.Module = None):
113
+ """
114
+ img: it should be the degraded version of the generated image
115
+ """
116
+ if degrade is not None:
117
+ img = degrade(img, downsample=self.downsample)
118
+ loss = self.percept(img, self.target)
119
+ return loss
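The eye crops in `EyeLoss` are hard-coded in 1024×1024 FFHQ coordinates (centers at (480, ±384), 224×224 windows) and rescaled to the target resolution; note that the left-eye window ends up with negative x-indices, which Python slicing interprets as counting from the right edge. The arithmetic can be checked in isolation with this standalone sketch (not code from the commit).

```python
import numpy as np

target_size = 256
eye_size = np.array((224, 224))
btlrs = []
for sgn in [1, -1]:
    center = np.array((480, 384 * sgn))  # (y, x) in 1024x1024 FFHQ coordinates
    b, t = center[0] - eye_size[0] // 2, center[0] + eye_size[0] // 2
    l, r = center[1] - eye_size[1] // 2, center[1] + eye_size[1] // 2
    btlrs.append((np.array((b, t, l, r)) / 1024 * target_size).astype(int))

print([box.tolist() for box in btlrs])  # [[92, 148, 68, 124], [92, 148, -124, -68]]
```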
Time_TravelRephotography/losses/regularize_noise.py ADDED
@@ -0,0 +1,37 @@
+ from typing import Iterable
+
+ import torch
+ from torch import nn
+
+
+ class NoiseRegularizer(nn.Module):
+     def forward(self, noises: Iterable[torch.Tensor]):
+         loss = 0
+
+         for noise in noises:
+             size = noise.shape[2]
+
+             while True:
+                 loss = (
+                     loss
+                     + (noise * torch.roll(noise, shifts=1, dims=3)).mean().pow(2)
+                     + (noise * torch.roll(noise, shifts=1, dims=2)).mean().pow(2)
+                 )
+
+                 if size <= 8:
+                     break
+
+                 noise = noise.reshape([1, 1, size // 2, 2, size // 2, 2])
+                 noise = noise.mean([3, 5])
+                 size //= 2
+
+         return loss
+
+     @staticmethod
+     def normalize(noises: Iterable[torch.Tensor]):
+         for noise in noises:
+             mean = noise.mean()
+             std = noise.std()
+
+             noise.data.add_(-mean).div_(std)
+
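`NoiseRegularizer` penalizes spatial auto-correlation of the per-layer noise maps at progressively coarser scales, pushing the optimized noise toward i.i.d. Gaussian statistics. A standalone usage sketch; the shapes mimic StyleGAN2 noise buffers and are assumptions, not part of the commit.

```python
import torch

from losses.regularize_noise import NoiseRegularizer  # assumed import path

# One (1, 1, S, S) noise map per resolution, 8x8 up to 128x128.
noises = [torch.randn(1, 1, 2 ** k, 2 ** k, requires_grad=True) for k in range(3, 8)]

reg = NoiseRegularizer()
loss = reg(noises)
loss.backward()                      # gradients flow back into the noise maps
NoiseRegularizer.normalize(noises)   # in-place: zero mean, unit std per map
print(loss.item())
```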
Time_TravelRephotography/model.py ADDED
@@ -0,0 +1,697 @@
1
+ import math
2
+ import random
3
+ import functools
4
+ import operator
5
+ import numpy as np
6
+
7
+ import torch
8
+ from torch import nn
9
+ from torch.nn import functional as F
10
+ from torch.autograd import Function
11
+
12
+ from op import FusedLeakyReLU, fused_leaky_relu, upfirdn2d
13
+
14
+
15
+ class PixelNorm(nn.Module):
16
+ def __init__(self):
17
+ super().__init__()
18
+
19
+ def forward(self, input):
20
+ return input * torch.rsqrt(torch.mean(input ** 2, dim=1, keepdim=True) + 1e-8)
21
+
22
+
23
+ def make_kernel(k):
24
+ k = torch.tensor(k, dtype=torch.float32)
25
+
26
+ if k.ndim == 1:
27
+ k = k[None, :] * k[:, None]
28
+
29
+ k /= k.sum()
30
+
31
+ return k
32
+
33
+
34
+ class Upsample(nn.Module):
35
+ def __init__(self, kernel, factor=2):
36
+ super().__init__()
37
+
38
+ self.factor = factor
39
+ kernel = make_kernel(kernel) * (factor ** 2)
40
+ self.register_buffer('kernel', kernel)
41
+
42
+ p = kernel.shape[0] - factor
43
+
44
+ pad0 = (p + 1) // 2 + factor - 1
45
+ pad1 = p // 2
46
+
47
+ self.pad = (pad0, pad1)
48
+
49
+ def forward(self, input):
50
+ out = upfirdn2d(input, self.kernel, up=self.factor, down=1, pad=self.pad)
51
+
52
+ return out
53
+
54
+
55
+ class Downsample(nn.Module):
56
+ def __init__(self, kernel, factor=2):
57
+ super().__init__()
58
+
59
+ self.factor = factor
60
+ kernel = make_kernel(kernel)
61
+ self.register_buffer('kernel', kernel)
62
+
63
+ p = kernel.shape[0] - factor
64
+
65
+ pad0 = (p + 1) // 2
66
+ pad1 = p // 2
67
+
68
+ self.pad = (pad0, pad1)
69
+
70
+ def forward(self, input):
71
+ out = upfirdn2d(input, self.kernel, up=1, down=self.factor, pad=self.pad)
72
+
73
+ return out
74
+
75
+
76
+ class Blur(nn.Module):
77
+ def __init__(self, kernel, pad, upsample_factor=1):
78
+ super().__init__()
79
+
80
+ kernel = make_kernel(kernel)
81
+
82
+ if upsample_factor > 1:
83
+ kernel = kernel * (upsample_factor ** 2)
84
+
85
+ self.register_buffer('kernel', kernel)
86
+
87
+ self.pad = pad
88
+
89
+ def forward(self, input):
90
+ out = upfirdn2d(input, self.kernel, pad=self.pad)
91
+
92
+ return out
93
+
94
+
95
+ class EqualConv2d(nn.Module):
96
+ def __init__(
97
+ self, in_channel, out_channel, kernel_size, stride=1, padding=0, bias=True
98
+ ):
99
+ super().__init__()
100
+
101
+ self.weight = nn.Parameter(
102
+ torch.randn(out_channel, in_channel, kernel_size, kernel_size)
103
+ )
104
+ self.scale = 1 / math.sqrt(in_channel * kernel_size ** 2)
105
+
106
+ self.stride = stride
107
+ self.padding = padding
108
+
109
+ if bias:
110
+ self.bias = nn.Parameter(torch.zeros(out_channel))
111
+
112
+ else:
113
+ self.bias = None
114
+
115
+ def forward(self, input):
116
+ out = F.conv2d(
117
+ input,
118
+ self.weight * self.scale,
119
+ bias=self.bias,
120
+ stride=self.stride,
121
+ padding=self.padding,
122
+ )
123
+
124
+ return out
125
+
126
+ def __repr__(self):
127
+ return (
128
+ f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]},'
129
+ f' {self.weight.shape[2]}, stride={self.stride}, padding={self.padding})'
130
+ )
131
+
132
+
133
+ class EqualLinear(nn.Module):
134
+ def __init__(
135
+ self, in_dim, out_dim, bias=True, bias_init=0, lr_mul=1, activation=None
136
+ ):
137
+ super().__init__()
138
+
139
+ self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul))
140
+
141
+ if bias:
142
+ self.bias = nn.Parameter(torch.zeros(out_dim).fill_(bias_init))
143
+
144
+ else:
145
+ self.bias = None
146
+
147
+ self.activation = activation
148
+
149
+ self.scale = (1 / math.sqrt(in_dim)) * lr_mul
150
+ self.lr_mul = lr_mul
151
+
152
+ def forward(self, input):
153
+ if self.activation:
154
+ out = F.linear(input, self.weight * self.scale)
155
+ out = fused_leaky_relu(out, self.bias * self.lr_mul)
156
+
157
+ else:
158
+ out = F.linear(
159
+ input, self.weight * self.scale, bias=self.bias * self.lr_mul
160
+ )
161
+
162
+ return out
163
+
164
+ def __repr__(self):
165
+ return (
166
+ f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]})'
167
+ )
168
+
169
+
170
+ class ScaledLeakyReLU(nn.Module):
171
+ def __init__(self, negative_slope=0.2):
172
+ super().__init__()
173
+
174
+ self.negative_slope = negative_slope
175
+
176
+ def forward(self, input):
177
+ out = F.leaky_relu(input, negative_slope=self.negative_slope)
178
+
179
+ return out * math.sqrt(2)
180
+
181
+
182
+ class ModulatedConv2d(nn.Module):
183
+ def __init__(
184
+ self,
185
+ in_channel,
186
+ out_channel,
187
+ kernel_size,
188
+ style_dim,
189
+ demodulate=True,
190
+ upsample=False,
191
+ downsample=False,
192
+ blur_kernel=[1, 3, 3, 1],
193
+ ):
194
+ super().__init__()
195
+
196
+ self.eps = 1e-8
197
+ self.kernel_size = kernel_size
198
+ self.in_channel = in_channel
199
+ self.out_channel = out_channel
200
+ self.upsample = upsample
201
+ self.downsample = downsample
202
+
203
+ if upsample:
204
+ factor = 2
205
+ p = (len(blur_kernel) - factor) - (kernel_size - 1)
206
+ pad0 = (p + 1) // 2 + factor - 1
207
+ pad1 = p // 2 + 1
208
+
209
+ self.blur = Blur(blur_kernel, pad=(pad0, pad1), upsample_factor=factor)
210
+
211
+ if downsample:
212
+ factor = 2
213
+ p = (len(blur_kernel) - factor) + (kernel_size - 1)
214
+ pad0 = (p + 1) // 2
215
+ pad1 = p // 2
216
+
217
+ self.blur = Blur(blur_kernel, pad=(pad0, pad1))
218
+
219
+ fan_in = in_channel * kernel_size ** 2
220
+ self.scale = 1 / math.sqrt(fan_in)
221
+ self.padding = kernel_size // 2
222
+
223
+ self.weight = nn.Parameter(
224
+ torch.randn(1, out_channel, in_channel, kernel_size, kernel_size)
225
+ )
226
+
227
+ self.modulation = EqualLinear(style_dim, in_channel, bias_init=1)
228
+
229
+ self.demodulate = demodulate
230
+
231
+ def __repr__(self):
232
+ return (
233
+ f'{self.__class__.__name__}({self.in_channel}, {self.out_channel}, {self.kernel_size}, '
234
+ f'upsample={self.upsample}, downsample={self.downsample})'
235
+ )
236
+
237
+ def forward(self, input, style):
238
+ batch, in_channel, height, width = input.shape
239
+
240
+ style = self.modulation(style).view(batch, 1, in_channel, 1, 1)
241
+ weight = self.scale * self.weight * style
242
+
243
+ if self.demodulate:
244
+ demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-8)
245
+ weight = weight * demod.view(batch, self.out_channel, 1, 1, 1)
246
+
247
+ weight = weight.view(
248
+ batch * self.out_channel, in_channel, self.kernel_size, self.kernel_size
249
+ )
250
+
251
+ if self.upsample:
252
+ input = input.view(1, batch * in_channel, height, width)
253
+ weight = weight.view(
254
+ batch, self.out_channel, in_channel, self.kernel_size, self.kernel_size
255
+ )
256
+ weight = weight.transpose(1, 2).reshape(
257
+ batch * in_channel, self.out_channel, self.kernel_size, self.kernel_size
258
+ )
259
+ out = F.conv_transpose2d(input, weight, padding=0, stride=2, groups=batch)
260
+ _, _, height, width = out.shape
261
+ out = out.view(batch, self.out_channel, height, width)
262
+ out = self.blur(out)
263
+
264
+ elif self.downsample:
265
+ input = self.blur(input)
266
+ _, _, height, width = input.shape
267
+ input = input.view(1, batch * in_channel, height, width)
268
+ out = F.conv2d(input, weight, padding=0, stride=2, groups=batch)
269
+ _, _, height, width = out.shape
270
+ out = out.view(batch, self.out_channel, height, width)
271
+
272
+ else:
273
+ input = input.view(1, batch * in_channel, height, width)
274
+ out = F.conv2d(input, weight, padding=self.padding, groups=batch)
275
+ _, _, height, width = out.shape
276
+ out = out.view(batch, self.out_channel, height, width)
277
+
278
+ return out
279
+
280
+
281
+ class NoiseInjection(nn.Module):
282
+ def __init__(self):
283
+ super().__init__()
284
+
285
+ self.weight = nn.Parameter(torch.zeros(1))
286
+
287
+ def forward(self, image, noise=None):
288
+ if noise is None:
289
+ batch, _, height, width = image.shape
290
+ noise = image.new_empty(batch, 1, height, width).normal_()
291
+
292
+ return image + self.weight * noise
293
+
294
+
295
+ class ConstantInput(nn.Module):
296
+ def __init__(self, channel, size=4):
297
+ super().__init__()
298
+
299
+ self.input = nn.Parameter(torch.randn(1, channel, size, size))
300
+
301
+ def forward(self, input):
302
+ batch = input.shape[0]
303
+ out = self.input.repeat(batch, 1, 1, 1)
304
+
305
+ return out
306
+
307
+
308
+ class StyledConv(nn.Module):
309
+ def __init__(
310
+ self,
311
+ in_channel,
312
+ out_channel,
313
+ kernel_size,
314
+ style_dim,
315
+ upsample=False,
316
+ blur_kernel=[1, 3, 3, 1],
317
+ demodulate=True,
318
+ ):
319
+ super().__init__()
320
+
321
+ self.conv = ModulatedConv2d(
322
+ in_channel,
323
+ out_channel,
324
+ kernel_size,
325
+ style_dim,
326
+ upsample=upsample,
327
+ blur_kernel=blur_kernel,
328
+ demodulate=demodulate,
329
+ )
330
+
331
+ self.noise = NoiseInjection()
332
+ # self.bias = nn.Parameter(torch.zeros(1, out_channel, 1, 1))
333
+ # self.activate = ScaledLeakyReLU(0.2)
334
+ self.activate = FusedLeakyReLU(out_channel)
335
+
336
+ def forward(self, input, style, noise=None):
337
+ out = self.conv(input, style)
338
+ out = self.noise(out, noise=noise)
339
+ # out = out + self.bias
340
+ out = self.activate(out)
341
+
342
+ return out
343
+
344
+
345
+ class ToRGB(nn.Module):
346
+ def __init__(self, in_channel, style_dim, upsample=True, blur_kernel=[1, 3, 3, 1]):
347
+ super().__init__()
348
+
349
+ if upsample:
350
+ self.upsample = Upsample(blur_kernel)
351
+
352
+ self.conv = ModulatedConv2d(in_channel, 3, 1, style_dim, demodulate=False)
353
+ self.bias = nn.Parameter(torch.zeros(1, 3, 1, 1))
354
+
355
+ def forward(self, input, style, skip=None):
356
+ out = self.conv(input, style)
357
+ style_modulated = out
358
+ out = out + self.bias
359
+
360
+ if skip is not None:
361
+ skip = self.upsample(skip)
362
+
363
+ out = out + skip
364
+
365
+ return out, style_modulated
366
+
367
+
368
+ class Generator(nn.Module):
369
+ def __init__(
370
+ self,
371
+ size,
372
+ style_dim,
373
+ n_mlp,
374
+ channel_multiplier=2,
375
+ blur_kernel=[1, 3, 3, 1],
376
+ lr_mlp=0.01,
377
+ ):
378
+ super().__init__()
379
+
380
+ self.size = size
381
+
382
+ self.style_dim = style_dim
383
+
384
+ layers = [PixelNorm()]
385
+
386
+ for i in range(n_mlp):
387
+ layers.append(
388
+ EqualLinear(
389
+ style_dim, style_dim, lr_mul=lr_mlp, activation='fused_lrelu'
390
+ )
391
+ )
392
+
393
+ self.style = nn.Sequential(*layers)
394
+
395
+ self.channels = {
396
+ 4: 512,
397
+ 8: 512,
398
+ 16: 512,
399
+ 32: 512,
400
+ 64: 256 * channel_multiplier,
401
+ 128: 128 * channel_multiplier,
402
+ 256: 64 * channel_multiplier,
403
+ 512: 32 * channel_multiplier,
404
+ 1024: 16 * channel_multiplier,
405
+ }
406
+
407
+ self.input = ConstantInput(self.channels[4])
408
+ self.conv1 = StyledConv(
409
+ self.channels[4], self.channels[4], 3, style_dim, blur_kernel=blur_kernel
410
+ )
411
+ self.to_rgb1 = ToRGB(self.channels[4], style_dim, upsample=False)
412
+
413
+ self.log_size = int(math.log(size, 2))
414
+ self.num_layers = (self.log_size - 2) * 2 + 1
415
+
416
+ self.convs = nn.ModuleList()
417
+ self.upsamples = nn.ModuleList()
418
+ self.to_rgbs = nn.ModuleList()
419
+ self.noises = nn.Module()
420
+
421
+ in_channel = self.channels[4]
422
+
423
+ for layer_idx in range(self.num_layers):
424
+ res = (layer_idx + 5) // 2
425
+ shape = [1, 1, 2 ** res, 2 ** res]
426
+ self.noises.register_buffer(f'noise_{layer_idx}', torch.randn(*shape))
427
+
428
+ for i in range(3, self.log_size + 1):
429
+ out_channel = self.channels[2 ** i]
430
+
431
+ self.convs.append(
432
+ StyledConv(
433
+ in_channel,
434
+ out_channel,
435
+ 3,
436
+ style_dim,
437
+ upsample=True,
438
+ blur_kernel=blur_kernel,
439
+ )
440
+ )
441
+
442
+ self.convs.append(
443
+ StyledConv(
444
+ out_channel, out_channel, 3, style_dim, blur_kernel=blur_kernel
445
+ )
446
+ )
447
+
448
+ self.to_rgbs.append(ToRGB(out_channel, style_dim))
449
+
450
+ in_channel = out_channel
451
+
452
+ self.n_latent = self.log_size * 2 - 2
453
+
454
+ @property
455
+ def device(self):
456
+ # TODO if multi-gpu is expected, could use the following more expensive version
457
+ #device, = list(set(p.device for p in self.parameters()))
458
+ return next(self.parameters()).device
459
+
460
+ @staticmethod
461
+ def get_latent_size(size):
462
+ log_size = int(math.log(size, 2))
463
+ return log_size * 2 - 2
464
+
465
+ @staticmethod
466
+ def make_noise_by_size(size: int, device: torch.device):
467
+ log_size = int(math.log(size, 2))
468
+ noises = [torch.randn(1, 1, 2 ** 2, 2 ** 2, device=device)]
469
+
470
+ for i in range(3, log_size + 1):
471
+ for _ in range(2):
472
+ noises.append(torch.randn(1, 1, 2 ** i, 2 ** i, device=device))
473
+
474
+ return noises
475
+
476
+
477
+ def make_noise(self):
478
+ return self.make_noise_by_size(self.size, self.input.input.device)
479
+
480
+ def mean_latent(self, n_latent):
481
+ latent_in = torch.randn(
482
+ n_latent, self.style_dim, device=self.input.input.device
483
+ )
484
+ latent = self.style(latent_in).mean(0, keepdim=True)
485
+
486
+ return latent
487
+
488
+ def get_latent(self, input):
489
+ return self.style(input)
490
+
491
+ def forward(
492
+ self,
493
+ styles,
494
+ return_latents=False,
495
+ inject_index=None,
496
+ truncation=1,
497
+ truncation_latent=None,
498
+ input_is_latent=False,
499
+ noise=None,
500
+ randomize_noise=True,
501
+ ):
502
+ if not input_is_latent:
503
+ styles = [self.style(s) for s in styles]
504
+
505
+ if noise is None:
506
+ if randomize_noise:
507
+ noise = [None] * self.num_layers
508
+ else:
509
+ noise = [
510
+ getattr(self.noises, f'noise_{i}') for i in range(self.num_layers)
511
+ ]
512
+
513
+ if truncation < 1:
514
+ style_t = []
515
+
516
+ for style in styles:
517
+ style_t.append(
518
+ truncation_latent + truncation * (style - truncation_latent)
519
+ )
520
+
521
+ styles = style_t
522
+
523
+ if len(styles) < 2:
524
+ inject_index = self.n_latent
525
+
526
+ if styles[0].ndim < 3:
527
+ latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
528
+
529
+ else:
530
+ latent = styles[0]
531
+
532
+ else:
533
+ if inject_index is None:
534
+ inject_index = random.randint(1, self.n_latent - 1)
535
+
536
+ latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
537
+ latent2 = styles[1].unsqueeze(1).repeat(1, self.n_latent - inject_index, 1)
538
+
539
+ latent = torch.cat([latent, latent2], 1)
540
+
541
+ out = self.input(latent)
542
+ out = self.conv1(out, latent[:, 0], noise=noise[0])
543
+
544
+ skip, rgb_mod = self.to_rgb1(out, latent[:, 1])
545
+
546
+
547
+ rgbs = [rgb_mod] # per-resolution modulated ToRGB outputs (before bias/skip accumulation)
548
+ i = 1
549
+ for conv1, conv2, noise1, noise2, to_rgb in zip(
550
+ self.convs[::2], self.convs[1::2], noise[1::2], noise[2::2], self.to_rgbs
551
+ ):
552
+ out = conv1(out, latent[:, i], noise=noise1)
553
+ out = conv2(out, latent[:, i + 1], noise=noise2)
554
+ skip, rgb_mod = to_rgb(out, latent[:, i + 2], skip)
555
+ rgbs.append(rgb_mod)
556
+
557
+ i += 2
558
+
559
+ image = skip
560
+
561
+ if return_latents:
562
+ return image, latent, rgbs
563
+
564
+ else:
565
+ return image, None, rgbs
566
+
567
+
568
+ class ConvLayer(nn.Sequential):
569
+ def __init__(
570
+ self,
571
+ in_channel,
572
+ out_channel,
573
+ kernel_size,
574
+ downsample=False,
575
+ blur_kernel=[1, 3, 3, 1],
576
+ bias=True,
577
+ activate=True,
578
+ ):
579
+ layers = []
580
+
581
+ if downsample:
582
+ factor = 2
583
+ p = (len(blur_kernel) - factor) + (kernel_size - 1)
584
+ pad0 = (p + 1) // 2
585
+ pad1 = p // 2
586
+
587
+ layers.append(Blur(blur_kernel, pad=(pad0, pad1)))
588
+
589
+ stride = 2
590
+ self.padding = 0
591
+
592
+ else:
593
+ stride = 1
594
+ self.padding = kernel_size // 2
595
+
596
+ layers.append(
597
+ EqualConv2d(
598
+ in_channel,
599
+ out_channel,
600
+ kernel_size,
601
+ padding=self.padding,
602
+ stride=stride,
603
+ bias=bias and not activate,
604
+ )
605
+ )
606
+
607
+ if activate:
608
+ if bias:
609
+ layers.append(FusedLeakyReLU(out_channel))
610
+
611
+ else:
612
+ layers.append(ScaledLeakyReLU(0.2))
613
+
614
+ super().__init__(*layers)
615
+
616
+
617
+ class ResBlock(nn.Module):
618
+ def __init__(self, in_channel, out_channel, blur_kernel=[1, 3, 3, 1]):
619
+ super().__init__()
620
+
621
+ self.conv1 = ConvLayer(in_channel, in_channel, 3)
622
+ self.conv2 = ConvLayer(in_channel, out_channel, 3, downsample=True)
623
+
624
+ self.skip = ConvLayer(
625
+ in_channel, out_channel, 1, downsample=True, activate=False, bias=False
626
+ )
627
+
628
+ def forward(self, input):
629
+ out = self.conv1(input)
630
+ out = self.conv2(out)
631
+
632
+ skip = self.skip(input)
633
+ out = (out + skip) / math.sqrt(2)
634
+
635
+ return out
636
+
637
+
638
+ class Discriminator(nn.Module):
639
+ def __init__(self, size, channel_multiplier=2, blur_kernel=[1, 3, 3, 1]):
640
+ super().__init__()
641
+
642
+ channels = {
643
+ 4: 512,
644
+ 8: 512,
645
+ 16: 512,
646
+ 32: 512,
647
+ 64: 256 * channel_multiplier,
648
+ 128: 128 * channel_multiplier,
649
+ 256: 64 * channel_multiplier,
650
+ 512: 32 * channel_multiplier,
651
+ 1024: 16 * channel_multiplier,
652
+ }
653
+
654
+ convs = [ConvLayer(3, channels[size], 1)]
655
+
656
+ log_size = int(math.log(size, 2))
657
+
658
+ in_channel = channels[size]
659
+
660
+ for i in range(log_size, 2, -1):
661
+ out_channel = channels[2 ** (i - 1)]
662
+
663
+ convs.append(ResBlock(in_channel, out_channel, blur_kernel))
664
+
665
+ in_channel = out_channel
666
+
667
+ self.convs = nn.Sequential(*convs)
668
+
669
+ self.stddev_group = 4
670
+ self.stddev_feat = 1
671
+
672
+ self.final_conv = ConvLayer(in_channel + 1, channels[4], 3)
673
+ self.final_linear = nn.Sequential(
674
+ EqualLinear(channels[4] * 4 * 4, channels[4], activation='fused_lrelu'),
675
+ EqualLinear(channels[4], 1),
676
+ )
677
+
678
+ def forward(self, input):
679
+ out = self.convs(input)
680
+
681
+ batch, channel, height, width = out.shape
682
+ group = min(batch, self.stddev_group)
683
+ stddev = out.view(
684
+ group, -1, self.stddev_feat, channel // self.stddev_feat, height, width
685
+ )
686
+ stddev = torch.sqrt(stddev.var(0, unbiased=False) + 1e-8)
687
+ stddev = stddev.mean([2, 3, 4], keepdims=True).squeeze(2)
688
+ stddev = stddev.repeat(group, 1, height, width)
689
+ out = torch.cat([out, stddev], 1)
690
+
691
+ out = self.final_conv(out)
692
+
693
+ out = out.view(batch, -1)
694
+ out = self.final_linear(out)
695
+
696
+ return out
697
+
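The `Generator` above is the rosinality-style StyleGAN2 generator, extended so that `forward` also returns the per-resolution modulated `ToRGB` outputs (`rgbs`). A minimal usage sketch, assuming the custom `op` CUDA extensions compile on the current machine and a rosinality-format FFHQ checkpoint with a `g_ema` entry sits at the path below (the path mirrors the default in `configs/paths_config.py` and may differ locally):

```
import torch
from model import Generator  # Time_TravelRephotography/model.py (this file)

g = Generator(size=1024, style_dim=512, n_mlp=8).eval().cuda()
ckpt = torch.load("pretrained_models/stylegan2-ffhq-config-f.pt", map_location="cpu")
g.load_state_dict(ckpt["g_ema"], strict=False)  # "g_ema" key assumed (rosinality format)

with torch.no_grad():
    mean_w = g.mean_latent(4096)                  # truncation center in W
    z = torch.randn(1, 512, device=g.device)      # one z sample
    img, latent, rgbs = g(
        [z],                                      # list of style codes (z space here)
        truncation=0.7,
        truncation_latent=mean_w,
        return_latents=True,
    )
# img is (1, 3, 1024, 1024) in [-1, 1]; rgbs collects the modulated ToRGB activations.
```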
Time_TravelRephotography/models/__init__.py ADDED
File without changes
Time_TravelRephotography/models/degrade.py ADDED
@@ -0,0 +1,122 @@
1
+ from argparse import (
2
+ ArgumentParser,
3
+ Namespace,
4
+ )
5
+
6
+ import torch
7
+ from torch import nn
8
+ from torch.nn import functional as F
9
+
10
+ from utils.misc import optional_string
11
+
12
+ from .gaussian_smoothing import GaussianSmoothing
13
+
14
+
15
+ class DegradeArguments:
16
+ @staticmethod
17
+ def add_arguments(parser: ArgumentParser):
18
+ parser.add_argument('--spectral_sensitivity', choices=["g", "b", "gb"], default="g",
19
+ help="Type of spectral sensitivity. g: grayscale (panchromatic), b: blue-sensitive, gb: green+blue (orthochromatic)")
20
+ parser.add_argument('--gaussian', type=float, default=0,
21
+ help="estimated blur radius in pixels of the input photo if it is scaled to 1024x1024")
22
+
23
+ @staticmethod
24
+ def to_string(args: Namespace) -> str:
25
+ return (
26
+ f"{args.spectral_sensitivity}"
27
+ + optional_string(args.gaussian > 0, f"-G{args.gaussian}")
28
+ )
29
+
30
+
31
+ class CameraResponse(nn.Module):
32
+ def __init__(self):
33
+ super().__init__()
34
+
35
+ self.register_parameter("gamma", nn.Parameter(torch.ones(1)))
36
+ self.register_parameter("offset", nn.Parameter(torch.zeros(1)))
37
+ self.register_parameter("gain", nn.Parameter(torch.ones(1)))
38
+
39
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
40
+ x = torch.clamp(x, max=1, min=-1+1e-2)
41
+ x = (1 + x) * 0.5
42
+ x = self.offset + self.gain * torch.pow(x, self.gamma)
43
+ x = (x - 0.5) * 2
44
+ # b = torch.clamp(b, max=1, min=-1)
45
+ return x
46
+
47
+
48
+ class SpectralResponse(nn.Module):
49
+ # TODO: use enum instead for color mode
50
+ def __init__(self, spectral_sensitivity: str = 'b'):
51
+ assert spectral_sensitivity in ("g", "b", "gb"), f"spectral_sensitivity {spectral_sensitivity} is not implemented."
52
+
53
+ super().__init__()
54
+
55
+ self.spectral_sensitivity = spectral_sensitivity
56
+
57
+ if self.spectral_sensitivity == "g":
58
+ self.register_buffer("to_gray", torch.tensor([0.299, 0.587, 0.114]).reshape(1, -1, 1, 1))
59
+
60
+ def forward(self, rgb: torch.Tensor) -> torch.Tensor:
61
+ if self.spectral_sensitivity == "b":
62
+ x = rgb[:, -1:]
63
+ elif self.spectral_sensitivity == "gb":
64
+ x = (rgb[:, 1:2] + rgb[:, -1:]) * 0.5
65
+ else:
66
+ assert self.spectral_sensitivity == "g"
67
+ x = (rgb * self.to_gray).sum(dim=1, keepdim=True)
68
+ return x
69
+
70
+
71
+ class Downsample(nn.Module):
72
+ """Antialiasing downsampling"""
73
+ def __init__(self, input_size: int, output_size: int, channels: int):
74
+ super().__init__()
75
+ if input_size % output_size == 0:
76
+ self.stride = input_size // output_size
77
+ self.grid = None
78
+ else:
79
+ self.stride = 1
80
+ step = input_size / output_size
81
+ x = torch.arange(output_size) * step
82
+ Y, X = torch.meshgrid(x, x)
83
+ grid = torch.stack((X, Y), dim=-1)
84
+ grid /= torch.Tensor((input_size - 1, input_size - 1)).view(1, 1, -1)
85
+ grid = grid * 2 - 1
86
+ self.register_buffer("grid", grid)
87
+ sigma = 0.5 * input_size / output_size
88
+ #print(f"{input_size} -> {output_size}: sigma={sigma}")
89
+ self.blur = GaussianSmoothing(channels, int(2 * (sigma * 2) + 1 + 0.5), sigma)
90
+
91
+ def forward(self, im: torch.Tensor):
92
+ out = self.blur(im, stride=self.stride)
93
+ if self.grid is not None:
94
+ out = F.grid_sample(out, self.grid[None].expand(im.shape[0], -1, -1, -1))
95
+ return out
96
+
97
+
98
+
99
+ class Degrade(nn.Module):
100
+ """
101
+ Simulate the degradation of antique film
102
+ """
103
+ def __init__(self, args:Namespace):
104
+ super().__init__()
105
+ self.srf = SpectralResponse(args.spectral_sensitivity)
106
+ self.crf = CameraResponse()
107
+ self.gaussian = None
108
+ if args.gaussian is not None and args.gaussian > 0:
109
+ self.gaussian = GaussianSmoothing(3, 2 * int(args.gaussian * 2 + 0.5) + 1, args.gaussian)
110
+
111
+ def forward(self, img: torch.Tensor, downsample: nn.Module = None):
112
+ if self.gaussian is not None:
113
+ img = self.gaussian(img)
114
+ if downsample is not None:
115
+ img = downsample(img)
116
+ img = self.srf(img)
117
+ img = self.crf(img)
118
+ # Repeat a single-channel result back to 3 channels so downstream code always sees RGB
119
+ return img.repeat((1, 3, 1, 1)) if img.shape[1] == 1 else img
120
+
121
+
122
+
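`Degrade` chains the optional Gaussian blur, the spectral response, and the learnable camera response curve; the `Downsample` module above provides the matching antialiased resize. A small sketch of wiring them together on a fake RGB batch in [-1, 1] (the settings are hypothetical, and `GaussianSmoothing` from this repository is assumed to take `(channels, kernel_size, sigma)` and an optional `stride` at call time, as used above):

```
from argparse import Namespace
import torch
from models.degrade import Degrade, Downsample

args = Namespace(spectral_sensitivity="b", gaussian=1.5)   # blue-sensitive film, mild blur
degrade = Degrade(args).cuda()
downsample = Downsample(1024, 256, channels=3).cuda()      # antialiased 1024 -> 256

rgb = torch.rand(4, 3, 1024, 1024, device="cuda") * 2 - 1  # stand-in generator output
old_photo = degrade(rgb, downsample=downsample)            # (4, 3, 256, 256), single channel repeated to RGB
```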
Time_TravelRephotography/models/encoder.py ADDED
@@ -0,0 +1,66 @@
1
+ from argparse import Namespace, ArgumentParser
2
+ from functools import partial
3
+
4
+ from torch import nn
5
+
6
+ from .resnet import ResNetBasicBlock, activation_func, norm_module, Conv2dAuto
7
+
8
+
9
+ def add_arguments(parser: ArgumentParser) -> ArgumentParser:
10
+ parser.add_argument("--latent_size", type=int, default=512, help="latent size")
11
+ return parser
12
+
13
+
14
+ def create_model(args) -> nn.Module:
15
+ in_channels = 3 if "rgb" in args and args.rgb else 1
16
+ return Encoder(in_channels, args.encoder_size, latent_size=args.latent_size)
17
+
18
+
19
+ class Flatten(nn.Module):
20
+ def forward(self, input_):
21
+ return input_.view(input_.size(0), -1)
22
+
23
+
24
+ class Encoder(nn.Module):
25
+ def __init__(
26
+ self, in_channels: int, size: int, latent_size: int = 512,
27
+ activation: str = 'leaky_relu', norm: str = "instance"
28
+ ):
29
+ super().__init__()
30
+
31
+ out_channels0 = 64
32
+ norm_m = norm_module(norm)
33
+ self.conv0 = nn.Sequential(
34
+ Conv2dAuto(in_channels, out_channels0, kernel_size=5),
35
+ norm_m(out_channels0),
36
+ activation_func(activation),
37
+ )
38
+
39
+ pool_kernel = 2
40
+ self.pool = nn.AvgPool2d(pool_kernel)
41
+
42
+ num_channels = [128, 256, 512, 512]
43
+ # FIXME: this is a hack
44
+ if size >= 256:
45
+ num_channels.append(512)
46
+
47
+ residual = partial(ResNetBasicBlock, activation=activation, norm=norm, bias=True)
48
+ residual_blocks = nn.ModuleList()
49
+ for in_channel, out_channel in zip([out_channels0] + num_channels[:-1], num_channels):
50
+ residual_blocks.append(residual(in_channel, out_channel))
51
+ residual_blocks.append(nn.AvgPool2d(pool_kernel))
52
+ self.residual_blocks = nn.Sequential(*residual_blocks)
53
+
54
+ self.last = nn.Sequential(
55
+ nn.ReLU(),
56
+ nn.AvgPool2d(4), # TODO: not sure whether this would cause a problem
57
+ Flatten(),
58
+ nn.Linear(num_channels[-1], latent_size, bias=True)
59
+ )
60
+
61
+ def forward(self, input_):
62
+ out = self.conv0(input_)
63
+ out = self.pool(out)
64
+ out = self.residual_blocks(out)
65
+ out = self.last(out)
66
+ return out
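A quick sketch of instantiating this `Encoder` directly, assuming the sibling `models/resnet.py` (providing `ResNetBasicBlock`, `activation_func`, `norm_module`, and `Conv2dAuto`) is importable:

```
import torch
from models.encoder import Encoder

enc = Encoder(in_channels=1, size=256, latent_size=512)
x = torch.randn(2, 1, 256, 256)   # grayscale crops at the encoder resolution
w = enc(x)                        # -> (2, 512): one latent vector per image
```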
Time_TravelRephotography/models/encoder4editing/.gitignore ADDED
@@ -0,0 +1,133 @@
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ pip-wheel-metadata/
24
+ share/python-wheels/
25
+ *.egg-info/
26
+ .installed.cfg
27
+ *.egg
28
+ MANIFEST
29
+
30
+ # PyInstaller
31
+ # Usually these files are written by a python script from a template
32
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
33
+ *.manifest
34
+ *.spec
35
+
36
+ # Installer logs
37
+ pip-log.txt
38
+ pip-delete-this-directory.txt
39
+
40
+ # Unit test / coverage reports
41
+ htmlcov/
42
+ .tox/
43
+ .nox/
44
+ .coverage
45
+ .coverage.*
46
+ .cache
47
+ nosetests.xml
48
+ coverage.xml
49
+ *.cover
50
+ *.py,cover
51
+ .hypothesis/
52
+ .pytest_cache/
53
+
54
+ # Translations
55
+ *.mo
56
+ *.pot
57
+
58
+ # Django stuff:
59
+ *.log
60
+ local_settings.py
61
+ db.sqlite3
62
+ db.sqlite3-journal
63
+
64
+ # Flask stuff:
65
+ instance/
66
+ .webassets-cache
67
+
68
+ # Scrapy stuff:
69
+ .scrapy
70
+
71
+ # Sphinx documentation
72
+ docs/_build/
73
+
74
+ # PyBuilder
75
+ target/
76
+
77
+ # Jupyter Notebook
78
+ .ipynb_checkpoints
79
+
80
+ # IPython
81
+ profile_default/
82
+ ipython_config.py
83
+
84
+ # pyenv
85
+ .python-version
86
+
87
+ # pipenv
88
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
90
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
91
+ # install all needed dependencies.
92
+ #Pipfile.lock
93
+
94
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95
+ __pypackages__/
96
+
97
+ # Celery stuff
98
+ celerybeat-schedule
99
+ celerybeat.pid
100
+
101
+ # SageMath parsed files
102
+ *.sage.py
103
+
104
+ # Environments
105
+ .env
106
+ .venv
107
+ env/
108
+ venv/
109
+ ENV/
110
+ env.bak/
111
+ venv.bak/
112
+
113
+ # Spyder project settings
114
+ .spyderproject
115
+ .spyproject
116
+
117
+ # Rope project settings
118
+ .ropeproject
119
+
120
+ # mkdocs documentation
121
+ /site
122
+
123
+ # mypy
124
+ .mypy_cache/
125
+ .dmypy.json
126
+ dmypy.json
127
+
128
+ # Pyre type checker
129
+ .pyre/
130
+
131
+ # Custom dataset
132
+ pretrained_models
133
+ results_test
Time_TravelRephotography/models/encoder4editing/LICENSE ADDED
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2021 omertov
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
Time_TravelRephotography/models/encoder4editing/README.md ADDED
@@ -0,0 +1,143 @@
1
+ # Designing an Encoder for StyleGAN Image Manipulation (SIGGRAPH 2021)
2
+ <a href="https://arxiv.org/abs/2102.02766"><img src="https://img.shields.io/badge/arXiv-2008.00951-b31b1b.svg"></a>
3
+ <a href="https://opensource.org/licenses/MIT"><img src="https://img.shields.io/badge/License-MIT-yellow.svg"></a>
4
+ [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](http://colab.research.google.com/github/omertov/encoder4editing/blob/main/notebooks/inference_playground.ipynb)
5
+
6
+ > Recently, there has been a surge of diverse methods for performing image editing by employing pre-trained unconditional generators. Applying these methods on real images, however, remains a challenge, as it necessarily requires the inversion of the images into their latent space. To successfully invert a real image, one needs to find a latent code that reconstructs the input image accurately, and more importantly, allows for its meaningful manipulation. In this paper, we carefully study the latent space of StyleGAN, the state-of-the-art unconditional generator. We identify and analyze the existence of a distortion-editability tradeoff and a distortion-perception tradeoff within the StyleGAN latent space. We then suggest two principles for designing encoders in a manner that allows one to control the proximity of the inversions to regions that StyleGAN was originally trained on. We present an encoder based on our two principles that is specifically designed for facilitating editing on real images by balancing these tradeoffs. By evaluating its performance qualitatively and quantitatively on numerous challenging domains, including cars and horses, we show that our inversion method, followed by common editing techniques, achieves superior real-image editing quality, with only a small reconstruction accuracy drop.
7
+
8
+ <p align="center">
9
+ <img src="docs/teaser.jpg" width="800px"/>
10
+ </p>
11
+
12
+ ## Description
13
+ Official Implementation of "<a href="https://arxiv.org/abs/2102.02766">Designing an Encoder for StyleGAN Image Manipulation</a>" paper for both training and evaluation.
14
+ The e4e encoder is specifically designed to complement existing image manipulation techniques performed over StyleGAN's latent space.
15
+
16
+ ## Recent Updates
17
+ `2021.08.17`: Add single style code encoder (use `--encoder_type SingleStyleCodeEncoder`). <br />
18
+ `2021.03.25`: Add pose editing direction.
19
+
20
+ ## Getting Started
21
+ ### Prerequisites
22
+ - Linux or macOS
23
+ - NVIDIA GPU + CUDA CuDNN (CPU may be possible with some modifications, but is not inherently supported)
24
+ - Python 3
25
+
26
+ ### Installation
27
+ - Clone the repository:
28
+ ```
29
+ git clone https://github.com/omertov/encoder4editing.git
30
+ cd encoder4editing
31
+ ```
32
+ - Dependencies:
33
+ We recommend running this repository using [Anaconda](https://docs.anaconda.com/anaconda/install/).
34
+ All dependencies for defining the environment are provided in `environment/e4e_env.yaml`.
35
+
36
+ ### Inference Notebook
37
+ We provide a Jupyter notebook in `notebooks/inference_playground.ipynb` that allows one to encode real images and perform several edits on them using StyleGAN.
38
+
39
+ ### Pretrained Models
40
+ Please download the pre-trained models from the following links. Each e4e model contains the entire pSp framework architecture, including the encoder and decoder weights.
41
+ | Path | Description
42
+ | :--- | :----------
43
+ |[FFHQ Inversion](https://drive.google.com/file/d/1cUv_reLE6k3604or78EranS7XzuVMWeO/view?usp=sharing) | FFHQ e4e encoder.
44
+ |[Cars Inversion](https://drive.google.com/file/d/17faPqBce2m1AQeLCLHUVXaDfxMRU2QcV/view?usp=sharing) | Cars e4e encoder.
45
+ |[Horse Inversion](https://drive.google.com/file/d/1TkLLnuX86B_BMo2ocYD0kX9kWh53rUVX/view?usp=sharing) | Horse e4e encoder.
46
+ |[Church Inversion](https://drive.google.com/file/d/1-L0ZdnQLwtdy6-A_Ccgq5uNJGTqE7qBa/view?usp=sharing) | Church e4e encoder.
47
+
48
+ If you wish to use one of the pretrained models for training or inference, you may do so using the flag `--checkpoint_path`.
49
+
50
+ In addition, we provide various auxiliary models needed for training your own e4e model from scratch.
51
+ | Path | Description
52
+ | :--- | :----------
53
+ |[FFHQ StyleGAN](https://drive.google.com/file/d/1EM87UquaoQmk17Q8d5kYIAHqu0dkYqdT/view?usp=sharing) | StyleGAN model pretrained on FFHQ taken from [rosinality](https://github.com/rosinality/stylegan2-pytorch) with 1024x1024 output resolution.
54
+ |[IR-SE50 Model](https://drive.google.com/file/d/1KW7bjndL3QG3sxBbZxreGHigcCCpsDgn/view?usp=sharing) | Pretrained IR-SE50 model taken from [TreB1eN](https://github.com/TreB1eN/InsightFace_Pytorch) for use in our ID loss during training.
55
+ |[MOCOv2 Model](https://drive.google.com/file/d/18rLcNGdteX5LwT7sv_F7HWr12HpVEzVe/view?usp=sharing) | Pretrained ResNet-50 model trained using MOCOv2 for use in our similarity loss for domains other than human faces during training.
56
+
57
+ By default, we assume that all auxiliary models are downloaded and saved to the directory `pretrained_models`. However, you may use your own paths by changing the necessary values in `configs/paths_config.py`.
58
+
59
+ ## Training
60
+ To train the e4e encoder, make sure the paths to the required models, as well as the training and testing data, are configured in `configs/paths_config.py` and `configs/data_configs.py`.
61
+ #### **Training the e4e Encoder**
62
+ ```
63
+ python scripts/train.py \
64
+ --dataset_type cars_encode \
65
+ --exp_dir new/experiment/directory \
66
+ --start_from_latent_avg \
67
+ --use_w_pool \
68
+ --w_discriminator_lambda 0.1 \
69
+ --progressive_start 20000 \
70
+ --id_lambda 0.5 \
71
+ --val_interval 10000 \
72
+ --max_steps 200000 \
73
+ --stylegan_size 512 \
74
+ --stylegan_weights path/to/pretrained/stylegan.pt \
75
+ --workers 8 \
76
+ --batch_size 8 \
77
+ --test_batch_size 4 \
78
+ --test_workers 4
79
+ ```
80
+
81
+ #### Training on your own dataset
82
+ In order to train the e4e encoder on a custom dataset, perform the following adjustments:
83
+ 1. Insert the paths to your train and test data into the `dataset_paths` variable defined in `configs/paths_config.py`:
84
+ ```
85
+ dataset_paths = {
86
+ 'my_train_data': '/path/to/train/images/directory',
87
+ 'my_test_data': '/path/to/test/images/directory'
88
+ }
89
+ ```
90
+ 2. Configure a new dataset under the DATASETS variable defined in `configs/data_configs.py`:
91
+ ```
92
+ DATASETS = {
93
+ 'my_data_encode': {
94
+ 'transforms': transforms_config.EncodeTransforms,
95
+ 'train_source_root': dataset_paths['my_train_data'],
96
+ 'train_target_root': dataset_paths['my_train_data'],
97
+ 'test_source_root': dataset_paths['my_test_data'],
98
+ 'test_target_root': dataset_paths['my_test_data']
99
+ }
100
+ }
101
+ ```
102
+ Refer to `configs/transforms_config.py` for the transformations applied to the train and test images during training.
103
+
104
+ 3. Finally, run a training session with `--dataset_type my_data_encode`.
105
+
106
+ ## Inference
107
+ Having trained your model, you can use `scripts/inference.py` to apply the model on a set of images.
108
+ For example,
109
+ ```
110
+ python scripts/inference.py \
111
+ --images_dir=/path/to/images/directory \
112
+ --save_dir=/path/to/saving/directory \
113
+ path/to/checkpoint.pt
114
+ ```
115
+
116
+ ## Latent Editing Consistency (LEC)
117
+ As described in the paper, we suggest a new metric, Latent Editing Consistency (LEC), for evaluating the encoder's
118
+ performance.
119
+ We provide an example for calculating the metric over the FFHQ StyleGAN using the aging editing direction in
120
+ `metrics/LEC.py`.
121
+
122
+ To run the example:
123
+ ```
124
+ cd metrics
125
+ python LEC.py \
126
+ --images_dir=/path/to/images/directory \
127
+ path/to/checkpoint.pt
128
+ ```
129
+
130
+ ## Acknowledgments
131
+ This code borrows heavily from [pixel2style2pixel](https://github.com/eladrich/pixel2style2pixel)
132
+
133
+ ## Citation
134
+ If you use this code for your research, please cite our paper <a href="https://arxiv.org/abs/2102.02766">Designing an Encoder for StyleGAN Image Manipulation</a>:
135
+
136
+ ```
137
+ @article{tov2021designing,
138
+ title={Designing an Encoder for StyleGAN Image Manipulation},
139
+ author={Tov, Omer and Alaluf, Yuval and Nitzan, Yotam and Patashnik, Or and Cohen-Or, Daniel},
140
+ journal={arXiv preprint arXiv:2102.02766},
141
+ year={2021}
142
+ }
143
+ ```
Time_TravelRephotography/models/encoder4editing/__init__.py ADDED
@@ -0,0 +1,15 @@
1
+ from .utils.model_utils import setup_model
2
+
3
+
4
+ def get_latents(net, x, is_cars=False):
5
+ codes = net.encoder(x)
6
+ if net.opts.start_from_latent_avg:
7
+ if codes.ndim == 2:
8
+ codes = codes + net.latent_avg.repeat(codes.shape[0], 1, 1)[:, 0, :]
9
+ else:
10
+ codes = codes + net.latent_avg.repeat(codes.shape[0], 1, 1)
11
+ if codes.shape[1] == 18 and is_cars:
12
+ codes = codes[:, :16, :]
13
+ return codes
14
+
15
+
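A hedged sketch of combining this `get_latents` helper with `setup_model` to invert aligned face crops; the exact `setup_model` signature lives in `utils/model_utils.py` and is assumed here to take a checkpoint path and a device and return `(net, opts)`:

```
import torch
from models.encoder4editing import setup_model, get_latents

net, opts = setup_model("pretrained_models/e4e_ffhq_encode.pt", "cuda")  # assumed signature
net.eval()

x = torch.randn(1, 3, 256, 256, device="cuda")  # stand-in for a normalized 256x256 face crop
with torch.no_grad():
    codes = get_latents(net, x)                 # W+ codes, (1, 18, 512) for the FFHQ model
```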
Time_TravelRephotography/models/encoder4editing/bash_scripts/inference.sh ADDED
@@ -0,0 +1,15 @@
1
+ set -exo
2
+
3
+ list="$1"
4
+ ckpt="${2:-pretrained_models/e4e_ffhq_encode.pt}"
5
+
6
+ base_dir="$REPHOTO/dataset/historically_interesting/aligned/manual_celebrity_in_19th_century/tier1/${list}/"
7
+ save_dir="results_test/${list}/"
8
+
9
+
10
+ TORCH_EXTENSIONS_DIR=/tmp/torch_extensions
11
+ PYTHONPATH="" \
12
+ python scripts/inference.py \
13
+ --images_dir="${base_dir}" \
14
+ --save_dir="${save_dir}" \
15
+ "${ckpt}"
Time_TravelRephotography/models/encoder4editing/configs/__init__.py ADDED
File without changes
Time_TravelRephotography/models/encoder4editing/configs/data_configs.py ADDED
@@ -0,0 +1,41 @@
1
+ from configs import transforms_config
2
+ from configs.paths_config import dataset_paths
3
+
4
+
5
+ DATASETS = {
6
+ 'ffhq_encode': {
7
+ 'transforms': transforms_config.EncodeTransforms,
8
+ 'train_source_root': dataset_paths['ffhq'],
9
+ 'train_target_root': dataset_paths['ffhq'],
10
+ 'test_source_root': dataset_paths['celeba_test'],
11
+ 'test_target_root': dataset_paths['celeba_test'],
12
+ },
13
+ 'cars_encode': {
14
+ 'transforms': transforms_config.CarsEncodeTransforms,
15
+ 'train_source_root': dataset_paths['cars_train'],
16
+ 'train_target_root': dataset_paths['cars_train'],
17
+ 'test_source_root': dataset_paths['cars_test'],
18
+ 'test_target_root': dataset_paths['cars_test'],
19
+ },
20
+ 'horse_encode': {
21
+ 'transforms': transforms_config.EncodeTransforms,
22
+ 'train_source_root': dataset_paths['horse_train'],
23
+ 'train_target_root': dataset_paths['horse_train'],
24
+ 'test_source_root': dataset_paths['horse_test'],
25
+ 'test_target_root': dataset_paths['horse_test'],
26
+ },
27
+ 'church_encode': {
28
+ 'transforms': transforms_config.EncodeTransforms,
29
+ 'train_source_root': dataset_paths['church_train'],
30
+ 'train_target_root': dataset_paths['church_train'],
31
+ 'test_source_root': dataset_paths['church_test'],
32
+ 'test_target_root': dataset_paths['church_test'],
33
+ },
34
+ 'cats_encode': {
35
+ 'transforms': transforms_config.EncodeTransforms,
36
+ 'train_source_root': dataset_paths['cats_train'],
37
+ 'train_target_root': dataset_paths['cats_train'],
38
+ 'test_source_root': dataset_paths['cats_test'],
39
+ 'test_target_root': dataset_paths['cats_test'],
40
+ }
41
+ }
Time_TravelRephotography/models/encoder4editing/configs/paths_config.py ADDED
@@ -0,0 +1,28 @@
1
+ dataset_paths = {
2
+ # Face Datasets (In the paper: FFHQ - train, CelebAHQ - test)
3
+ 'ffhq': '',
4
+ 'celeba_test': '',
5
+
6
+ # Cars Dataset (In the paper: Stanford cars)
7
+ 'cars_train': '',
8
+ 'cars_test': '',
9
+
10
+ # Horse Dataset (In the paper: LSUN Horse)
11
+ 'horse_train': '',
12
+ 'horse_test': '',
13
+
14
+ # Church Dataset (In the paper: LSUN Church)
15
+ 'church_train': '',
16
+ 'church_test': '',
17
+
18
+ # Cats Dataset (In the paper: LSUN Cat)
19
+ 'cats_train': '',
20
+ 'cats_test': ''
21
+ }
22
+
23
+ model_paths = {
24
+ 'stylegan_ffhq': 'pretrained_models/stylegan2-ffhq-config-f.pt',
25
+ 'ir_se50': 'pretrained_models/model_ir_se50.pth',
26
+ 'shape_predictor': 'pretrained_models/shape_predictor_68_face_landmarks.dat',
27
+ 'moco': 'pretrained_models/moco_v2_800ep_pretrain.pth'
28
+ }
Time_TravelRephotography/models/encoder4editing/configs/transforms_config.py ADDED
@@ -0,0 +1,62 @@
1
+ from abc import abstractmethod
2
+ import torchvision.transforms as transforms
3
+
4
+
5
+ class TransformsConfig(object):
6
+
7
+ def __init__(self, opts):
8
+ self.opts = opts
9
+
10
+ @abstractmethod
11
+ def get_transforms(self):
12
+ pass
13
+
14
+
15
+ class EncodeTransforms(TransformsConfig):
16
+
17
+ def __init__(self, opts):
18
+ super(EncodeTransforms, self).__init__(opts)
19
+
20
+ def get_transforms(self):
21
+ transforms_dict = {
22
+ 'transform_gt_train': transforms.Compose([
23
+ transforms.Resize((256, 256)),
24
+ transforms.RandomHorizontalFlip(0.5),
25
+ transforms.ToTensor(),
26
+ transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]),
27
+ 'transform_source': None,
28
+ 'transform_test': transforms.Compose([
29
+ transforms.Resize((256, 256)),
30
+ transforms.ToTensor(),
31
+ transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]),
32
+ 'transform_inference': transforms.Compose([
33
+ transforms.Resize((256, 256)),
34
+ transforms.ToTensor(),
35
+ transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])
36
+ }
37
+ return transforms_dict
38
+
39
+
40
+ class CarsEncodeTransforms(TransformsConfig):
41
+
42
+ def __init__(self, opts):
43
+ super(CarsEncodeTransforms, self).__init__(opts)
44
+
45
+ def get_transforms(self):
46
+ transforms_dict = {
47
+ 'transform_gt_train': transforms.Compose([
48
+ transforms.Resize((192, 256)),
49
+ transforms.RandomHorizontalFlip(0.5),
50
+ transforms.ToTensor(),
51
+ transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]),
52
+ 'transform_source': None,
53
+ 'transform_test': transforms.Compose([
54
+ transforms.Resize((192, 256)),
55
+ transforms.ToTensor(),
56
+ transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]),
57
+ 'transform_inference': transforms.Compose([
58
+ transforms.Resize((192, 256)),
59
+ transforms.ToTensor(),
60
+ transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])
61
+ }
62
+ return transforms_dict
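For reference, a small sketch of pulling the inference transform out of `EncodeTransforms` and applying it to a PIL image (`opts` can be an empty namespace here because the class never reads it; the image path is hypothetical):

```
from argparse import Namespace
from PIL import Image
from configs.transforms_config import EncodeTransforms

transform = EncodeTransforms(Namespace()).get_transforms()['transform_inference']
img = Image.open('face.jpg').convert('RGB')
tensor = transform(img)   # -> (3, 256, 256) tensor normalized to [-1, 1]
```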
Time_TravelRephotography/models/encoder4editing/criteria/__init__.py ADDED
File without changes
Time_TravelRephotography/models/encoder4editing/criteria/id_loss.py ADDED
@@ -0,0 +1,47 @@
1
+ import torch
2
+ from torch import nn
3
+ from configs.paths_config import model_paths
4
+ from models.encoders.model_irse import Backbone
5
+
6
+
7
+ class IDLoss(nn.Module):
8
+ def __init__(self):
9
+ super(IDLoss, self).__init__()
10
+ print('Loading ResNet ArcFace')
11
+ self.facenet = Backbone(input_size=112, num_layers=50, drop_ratio=0.6, mode='ir_se')
12
+ self.facenet.load_state_dict(torch.load(model_paths['ir_se50']))
13
+ self.face_pool = torch.nn.AdaptiveAvgPool2d((112, 112))
14
+ self.facenet.eval()
15
+ for module in [self.facenet, self.face_pool]:
16
+ for param in module.parameters():
17
+ param.requires_grad = False
18
+
19
+ def extract_feats(self, x):
20
+ x = x[:, :, 35:223, 32:220] # Crop interesting region
21
+ x = self.face_pool(x)
22
+ x_feats = self.facenet(x)
23
+ return x_feats
24
+
25
+ def forward(self, y_hat, y, x):
26
+ n_samples = x.shape[0]
27
+ x_feats = self.extract_feats(x)
28
+ y_feats = self.extract_feats(y) # target features (detached below so no gradient flows into them)
29
+ y_hat_feats = self.extract_feats(y_hat)
30
+ y_feats = y_feats.detach()
31
+ loss = 0
32
+ sim_improvement = 0
33
+ id_logs = []
34
+ count = 0
35
+ for i in range(n_samples):
36
+ diff_target = y_hat_feats[i].dot(y_feats[i])
37
+ diff_input = y_hat_feats[i].dot(x_feats[i])
38
+ diff_views = y_feats[i].dot(x_feats[i])
39
+ id_logs.append({'diff_target': float(diff_target),
40
+ 'diff_input': float(diff_input),
41
+ 'diff_views': float(diff_views)})
42
+ loss += 1 - diff_target
43
+ id_diff = float(diff_target) - float(diff_views)
44
+ sim_improvement += id_diff
45
+ count += 1
46
+
47
+ return loss / count, sim_improvement / count, id_logs
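A usage sketch for `IDLoss`, assuming `pretrained_models/model_ir_se50.pth` has been downloaded and `models.encoders.model_irse.Backbone` from this repository is importable; all three inputs are 256x256 images in [-1, 1]:

```
import torch
from criteria.id_loss import IDLoss

id_loss = IDLoss().cuda()
y_hat = torch.rand(2, 3, 256, 256, device="cuda") * 2 - 1  # reconstructions
y = torch.rand(2, 3, 256, 256, device="cuda") * 2 - 1      # targets
x = y.clone()                                              # original inputs
loss, sim_improvement, logs = id_loss(y_hat, y, x)
```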
Time_TravelRephotography/models/encoder4editing/criteria/lpips/__init__.py ADDED
File without changes
Time_TravelRephotography/models/encoder4editing/criteria/lpips/lpips.py ADDED
@@ -0,0 +1,35 @@
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ from criteria.lpips.networks import get_network, LinLayers
5
+ from criteria.lpips.utils import get_state_dict
6
+
7
+
8
+ class LPIPS(nn.Module):
9
+ r"""Creates a criterion that measures
10
+ Learned Perceptual Image Patch Similarity (LPIPS).
11
+ Arguments:
12
+ net_type (str): the network type to compare the features:
13
+ 'alex' | 'squeeze' | 'vgg'. Default: 'alex'.
14
+ version (str): the version of LPIPS. Default: 0.1.
15
+ """
16
+ def __init__(self, net_type: str = 'alex', version: str = '0.1'):
17
+
18
+ assert version in ['0.1'], 'v0.1 is only supported now'
19
+
20
+ super(LPIPS, self).__init__()
21
+
22
+ # pretrained network
23
+ self.net = get_network(net_type).to("cuda")
24
+
25
+ # linear layers
26
+ self.lin = LinLayers(self.net.n_channels_list).to("cuda")
27
+ self.lin.load_state_dict(get_state_dict(net_type, version))
28
+
29
+ def forward(self, x: torch.Tensor, y: torch.Tensor):
30
+ feat_x, feat_y = self.net(x), self.net(y)
31
+
32
+ diff = [(fx - fy) ** 2 for fx, fy in zip(feat_x, feat_y)]
33
+ res = [l(d).mean((2, 3), True) for d, l in zip(diff, self.lin)]
34
+
35
+ return torch.sum(torch.cat(res, 0)) / x.shape[0]
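This `LPIPS` variant pins both the backbone and the linear heads to CUDA, so inputs must be CUDA tensors in [-1, 1]. A minimal sketch (the linear-layer weights are fetched from the PerceptualSimilarity repository on first use):

```
import torch
from criteria.lpips.lpips import LPIPS

lpips = LPIPS(net_type='alex')
a = torch.rand(1, 3, 256, 256, device='cuda') * 2 - 1
b = torch.rand(1, 3, 256, 256, device='cuda') * 2 - 1
distance = lpips(a, b)   # scalar perceptual distance, averaged over the batch
```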
Time_TravelRephotography/models/encoder4editing/criteria/lpips/networks.py ADDED
@@ -0,0 +1,96 @@
1
+ from typing import Sequence
2
+
3
+ from itertools import chain
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ from torchvision import models
8
+
9
+ from criteria.lpips.utils import normalize_activation
10
+
11
+
12
+ def get_network(net_type: str):
13
+ if net_type == 'alex':
14
+ return AlexNet()
15
+ elif net_type == 'squeeze':
16
+ return SqueezeNet()
17
+ elif net_type == 'vgg':
18
+ return VGG16()
19
+ else:
20
+ raise NotImplementedError('choose net_type from [alex, squeeze, vgg].')
21
+
22
+
23
+ class LinLayers(nn.ModuleList):
24
+ def __init__(self, n_channels_list: Sequence[int]):
25
+ super(LinLayers, self).__init__([
26
+ nn.Sequential(
27
+ nn.Identity(),
28
+ nn.Conv2d(nc, 1, 1, 1, 0, bias=False)
29
+ ) for nc in n_channels_list
30
+ ])
31
+
32
+ for param in self.parameters():
33
+ param.requires_grad = False
34
+
35
+
36
+ class BaseNet(nn.Module):
37
+ def __init__(self):
38
+ super(BaseNet, self).__init__()
39
+
40
+ # register buffer
41
+ self.register_buffer(
42
+ 'mean', torch.Tensor([-.030, -.088, -.188])[None, :, None, None])
43
+ self.register_buffer(
44
+ 'std', torch.Tensor([.458, .448, .450])[None, :, None, None])
45
+
46
+ def set_requires_grad(self, state: bool):
47
+ for param in chain(self.parameters(), self.buffers()):
48
+ param.requires_grad = state
49
+
50
+ def z_score(self, x: torch.Tensor):
51
+ return (x - self.mean) / self.std
52
+
53
+ def forward(self, x: torch.Tensor):
54
+ x = self.z_score(x)
55
+
56
+ output = []
57
+ for i, (_, layer) in enumerate(self.layers._modules.items(), 1):
58
+ x = layer(x)
59
+ if i in self.target_layers:
60
+ output.append(normalize_activation(x))
61
+ if len(output) == len(self.target_layers):
62
+ break
63
+ return output
64
+
65
+
66
+ class SqueezeNet(BaseNet):
67
+ def __init__(self):
68
+ super(SqueezeNet, self).__init__()
69
+
70
+ self.layers = models.squeezenet1_1(True).features
71
+ self.target_layers = [2, 5, 8, 10, 11, 12, 13]
72
+ self.n_channels_list = [64, 128, 256, 384, 384, 512, 512]
73
+
74
+ self.set_requires_grad(False)
75
+
76
+
77
+ class AlexNet(BaseNet):
78
+ def __init__(self):
79
+ super(AlexNet, self).__init__()
80
+
81
+ self.layers = models.alexnet(True).features
82
+ self.target_layers = [2, 5, 8, 10, 12]
83
+ self.n_channels_list = [64, 192, 384, 256, 256]
84
+
85
+ self.set_requires_grad(False)
86
+
87
+
88
+ class VGG16(BaseNet):
89
+ def __init__(self):
90
+ super(VGG16, self).__init__()
91
+
92
+ self.layers = models.vgg16(True).features
93
+ self.target_layers = [4, 9, 16, 23, 30]
94
+ self.n_channels_list = [64, 128, 256, 512, 512]
95
+
96
+ self.set_requires_grad(False)
Time_TravelRephotography/models/encoder4editing/criteria/lpips/utils.py ADDED
@@ -0,0 +1,30 @@
1
+ from collections import OrderedDict
2
+
3
+ import torch
4
+
5
+
6
+ def normalize_activation(x, eps=1e-10):
7
+ norm_factor = torch.sqrt(torch.sum(x ** 2, dim=1, keepdim=True))
8
+ return x / (norm_factor + eps)
9
+
10
+
11
+ def get_state_dict(net_type: str = 'alex', version: str = '0.1'):
12
+ # build url
13
+ url = 'https://raw.githubusercontent.com/richzhang/PerceptualSimilarity/' \
14
+ + f'master/lpips/weights/v{version}/{net_type}.pth'
15
+
16
+ # download
17
+ old_state_dict = torch.hub.load_state_dict_from_url(
18
+ url, progress=True,
19
+ map_location=None if torch.cuda.is_available() else torch.device('cpu')
20
+ )
21
+
22
+ # rename keys
23
+ new_state_dict = OrderedDict()
24
+ for key, val in old_state_dict.items():
25
+ new_key = key
26
+ new_key = new_key.replace('lin', '')
27
+ new_key = new_key.replace('model.', '')
28
+ new_state_dict[new_key] = val
29
+
30
+ return new_state_dict
Time_TravelRephotography/models/encoder4editing/criteria/moco_loss.py ADDED
@@ -0,0 +1,71 @@
1
+ import torch
2
+ from torch import nn
3
+ import torch.nn.functional as F
4
+
5
+ from configs.paths_config import model_paths
6
+
7
+
8
+ class MocoLoss(nn.Module):
9
+
10
+ def __init__(self, opts):
11
+ super(MocoLoss, self).__init__()
12
+ print("Loading MOCO model from path: {}".format(model_paths["moco"]))
13
+ self.model = self.__load_model()
14
+ self.model.eval()
15
+ for param in self.model.parameters():
16
+ param.requires_grad = False
17
+
18
+ @staticmethod
19
+ def __load_model():
20
+ import torchvision.models as models
21
+ model = models.__dict__["resnet50"]()
22
+ # freeze all layers but the last fc
23
+ for name, param in model.named_parameters():
24
+ if name not in ['fc.weight', 'fc.bias']:
25
+ param.requires_grad = False
26
+ checkpoint = torch.load(model_paths['moco'], map_location="cpu")
27
+ state_dict = checkpoint['state_dict']
28
+ # rename moco pre-trained keys
29
+ for k in list(state_dict.keys()):
30
+ # retain only encoder_q up to before the embedding layer
31
+ if k.startswith('module.encoder_q') and not k.startswith('module.encoder_q.fc'):
32
+ # remove prefix
33
+ state_dict[k[len("module.encoder_q."):]] = state_dict[k]
34
+ # delete renamed or unused k
35
+ del state_dict[k]
36
+ msg = model.load_state_dict(state_dict, strict=False)
37
+ assert set(msg.missing_keys) == {"fc.weight", "fc.bias"}
38
+ # remove output layer
39
+ model = nn.Sequential(*list(model.children())[:-1]).cuda()
40
+ return model
41
+
42
+ def extract_feats(self, x):
43
+ x = F.interpolate(x, size=224)
44
+ x_feats = self.model(x)
45
+ x_feats = nn.functional.normalize(x_feats, dim=1)
46
+ x_feats = x_feats.squeeze()
47
+ return x_feats
48
+
49
+ def forward(self, y_hat, y, x):
50
+ n_samples = x.shape[0]
51
+ x_feats = self.extract_feats(x)
52
+ y_feats = self.extract_feats(y)
53
+ y_hat_feats = self.extract_feats(y_hat)
54
+ y_feats = y_feats.detach()
55
+ loss = 0
56
+ sim_improvement = 0
57
+ sim_logs = []
58
+ count = 0
59
+ for i in range(n_samples):
60
+ diff_target = y_hat_feats[i].dot(y_feats[i])
61
+ diff_input = y_hat_feats[i].dot(x_feats[i])
62
+ diff_views = y_feats[i].dot(x_feats[i])
63
+ sim_logs.append({'diff_target': float(diff_target),
64
+ 'diff_input': float(diff_input),
65
+ 'diff_views': float(diff_views)})
66
+ loss += 1 - diff_target
67
+ sim_diff = float(diff_target) - float(diff_views)
68
+ sim_improvement += sim_diff
69
+ count += 1
70
+
71
+ return loss / count, sim_improvement / count, sim_logs
Time_TravelRephotography/models/encoder4editing/criteria/w_norm.py ADDED
@@ -0,0 +1,14 @@
1
+ import torch
2
+ from torch import nn
3
+
4
+
5
+ class WNormLoss(nn.Module):
6
+
7
+ def __init__(self, start_from_latent_avg=True):
8
+ super(WNormLoss, self).__init__()
9
+ self.start_from_latent_avg = start_from_latent_avg
10
+
11
+ def forward(self, latent, latent_avg=None):
12
+ if self.start_from_latent_avg:
13
+ latent = latent - latent_avg
14
+ return torch.sum(latent.norm(2, dim=(1, 2))) / latent.shape[0]
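`WNormLoss` simply penalizes the batch-mean L2 norm of the latent offsets from the average latent. A tiny sketch:

```
import torch
from criteria.w_norm import WNormLoss

w_norm = WNormLoss(start_from_latent_avg=True)
latent = torch.randn(4, 18, 512)     # W+ codes for a batch of 4
latent_avg = torch.zeros(18, 512)    # broadcasts against the batch
reg = w_norm(latent, latent_avg)     # mean L2 norm of (latent - latent_avg)
```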
Time_TravelRephotography/models/encoder4editing/datasets/__init__.py ADDED
File without changes
Time_TravelRephotography/models/encoder4editing/datasets/gt_res_dataset.py ADDED
@@ -0,0 +1,32 @@
1
+ #!/usr/bin/python
2
+ # encoding: utf-8
3
+ import os
4
+ from torch.utils.data import Dataset
5
+ from PIL import Image
6
+ import torch
7
+
8
+ class GTResDataset(Dataset):
9
+
10
+ def __init__(self, root_path, gt_dir=None, transform=None, transform_train=None):
11
+ self.pairs = []
12
+ for f in os.listdir(root_path):
13
+ image_path = os.path.join(root_path, f)
14
+ gt_path = os.path.join(gt_dir, f)
15
+ if f.endswith(".jpg") or f.endswith(".png"):
16
+ self.pairs.append([image_path, gt_path.replace('.png', '.jpg'), None])
17
+ self.transform = transform
18
+ self.transform_train = transform_train
19
+
20
+ def __len__(self):
21
+ return len(self.pairs)
22
+
23
+ def __getitem__(self, index):
24
+ from_path, to_path, _ = self.pairs[index]
25
+ from_im = Image.open(from_path).convert('RGB')
26
+ to_im = Image.open(to_path).convert('RGB')
27
+
28
+ if self.transform:
29
+ to_im = self.transform(to_im)
30
+ from_im = self.transform(from_im)
31
+
32
+ return from_im, to_im
Time_TravelRephotography/models/encoder4editing/datasets/images_dataset.py ADDED
@@ -0,0 +1,33 @@
1
+ from torch.utils.data import Dataset
2
+ from PIL import Image
3
+ from utils import data_utils
4
+
5
+
6
+ class ImagesDataset(Dataset):
7
+
8
+ def __init__(self, source_root, target_root, opts, target_transform=None, source_transform=None):
9
+ self.source_paths = sorted(data_utils.make_dataset(source_root))
10
+ self.target_paths = sorted(data_utils.make_dataset(target_root))
11
+ self.source_transform = source_transform
12
+ self.target_transform = target_transform
13
+ self.opts = opts
14
+
15
+ def __len__(self):
16
+ return len(self.source_paths)
17
+
18
+ def __getitem__(self, index):
19
+ from_path = self.source_paths[index]
20
+ from_im = Image.open(from_path)
21
+ from_im = from_im.convert('RGB')
22
+
23
+ to_path = self.target_paths[index]
24
+ to_im = Image.open(to_path).convert('RGB')
25
+ if self.target_transform:
26
+ to_im = self.target_transform(to_im)
27
+
28
+ if self.source_transform:
29
+ from_im = self.source_transform(from_im)
30
+ else:
31
+ from_im = to_im
32
+
33
+ return from_im, to_im
Time_TravelRephotography/models/encoder4editing/datasets/inference_dataset.py ADDED
@@ -0,0 +1,25 @@
1
+ from torch.utils.data import Dataset
2
+ from PIL import Image
3
+ from utils import data_utils
4
+
5
+
6
+ class InferenceDataset(Dataset):
7
+
8
+ def __init__(self, root, opts, transform=None, preprocess=None):
9
+ self.paths = sorted(data_utils.make_dataset(root))
10
+ self.transform = transform
11
+ self.preprocess = preprocess
12
+ self.opts = opts
13
+
14
+ def __len__(self):
15
+ return len(self.paths)
16
+
17
+ def __getitem__(self, index):
18
+ from_path = self.paths[index]
19
+ if self.preprocess is not None:
20
+ from_im = self.preprocess(from_path)
21
+ else:
22
+ from_im = Image.open(from_path).convert('RGB')
23
+ if self.transform:
24
+ from_im = self.transform(from_im)
25
+ return from_im
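`InferenceDataset` pairs naturally with a standard `DataLoader`; a sketch assuming `utils.data_utils.make_dataset` from this repository is importable and using a hypothetical image directory:

```
from argparse import Namespace
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from datasets.inference_dataset import InferenceDataset

transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize([0.5] * 3, [0.5] * 3),
])
dataset = InferenceDataset('/path/to/images/directory', Namespace(), transform=transform)
loader = DataLoader(dataset, batch_size=4, shuffle=False, num_workers=2)
for batch in loader:
    ...  # batch: (B, 3, 256, 256) tensors ready for the encoder
```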
Time_TravelRephotography/models/encoder4editing/editings/ganspace.py ADDED
@@ -0,0 +1,22 @@
1
+ import torch
2
+
3
+
4
+ def edit(latents, pca, edit_directions):
5
+ edit_latents = []
6
+ for latent in latents:
7
+ for pca_idx, start, end, strength in edit_directions:
8
+ delta = get_delta(pca, latent, pca_idx, strength)
9
+ delta_padded = torch.zeros(latent.shape).to('cuda')
10
+ delta_padded[start:end] += delta.repeat(end - start, 1)
11
+ edit_latents.append(latent + delta_padded)
12
+ return torch.stack(edit_latents)
13
+
14
+
15
+ def get_delta(pca, latent, idx, strength):
16
+ # pca: ganspace checkpoint. latent: (16, 512) w+
17
+ w_centered = latent - pca['mean'].to('cuda')
18
+ lat_comp = pca['comp'].to('cuda')
19
+ lat_std = pca['std'].to('cuda')
20
+ w_coord = torch.sum(w_centered[0].reshape(-1)*lat_comp[idx].reshape(-1)) / lat_std[idx]
21
+ delta = (strength - w_coord)*lat_comp[idx]*lat_std[idx]
22
+ return delta
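`edit` shifts each W+ code along a GANSpace principal component within a chosen layer range. A sketch using one of the PCA checkpoints shipped in `editings/ganspace_pca/` (the direction tuple `(component index, start layer, end layer, strength)` below is hypothetical; latents must live on CUDA because the helpers move the PCA tensors there):

```
import torch
from editings.ganspace import edit

pca = torch.load('editings/ganspace_pca/ffhq_pca.pt')   # keys: 'mean', 'comp', 'std'
latents = torch.randn(1, 18, 512, device='cuda')        # stand-in W+ codes from the encoder
directions = [(54, 7, 8, 20)]                           # (pca_idx, start, end, strength)
edited = edit(latents, pca, directions)                 # -> (1, 18, 512)
```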
Time_TravelRephotography/models/encoder4editing/editings/ganspace_pca/cars_pca.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a5c3bae61ecd85de077fbbf103f5f30cf4b7676fe23a8508166eaf2ce73c8392
3
+ size 167562
Time_TravelRephotography/models/encoder4editing/editings/ganspace_pca/ffhq_pca.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4d7f9df1c96180d9026b9cb8d04753579fbf385f321a9d0e263641601c5e5d36
3
+ size 167562
Time_TravelRephotography/models/encoder4editing/editings/interfacegan_directions/age.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:50074516b1629707d89b5e43d6b8abd1792212fa3b961a87a11323d6a5222ae0
3
+ size 2808