Duplicate from feng2022/Time-TravelRephotography
Co-authored-by: Time-travelRephotography <feng2022@users.noreply.huggingface.co>
This view is limited to 50 files because it contains too many changes.
- .gitattributes +27 -0
- .gitignore +133 -0
- .gitmodules +9 -0
- README.md +14 -0
- Time_TravelRephotography/LICENSE +21 -0
- Time_TravelRephotography/LICENSE-NVIDIA +101 -0
- Time_TravelRephotography/LICENSE-STYLEGAN2 +21 -0
- Time_TravelRephotography/losses/color_transfer_loss.py +60 -0
- Time_TravelRephotography/losses/contextual_loss/.gitignore +104 -0
- Time_TravelRephotography/losses/contextual_loss/LICENSE +21 -0
- Time_TravelRephotography/losses/contextual_loss/__init__.py +1 -0
- Time_TravelRephotography/losses/contextual_loss/config.py +2 -0
- Time_TravelRephotography/losses/contextual_loss/functional.py +198 -0
- Time_TravelRephotography/losses/contextual_loss/modules/__init__.py +4 -0
- Time_TravelRephotography/losses/contextual_loss/modules/contextual.py +121 -0
- Time_TravelRephotography/losses/contextual_loss/modules/contextual_bilateral.py +69 -0
- Time_TravelRephotography/losses/contextual_loss/modules/vgg.py +48 -0
- Time_TravelRephotography/losses/joint_loss.py +167 -0
- Time_TravelRephotography/losses/perceptual_loss.py +111 -0
- Time_TravelRephotography/losses/reconstruction.py +119 -0
- Time_TravelRephotography/losses/regularize_noise.py +37 -0
- Time_TravelRephotography/model.py +697 -0
- Time_TravelRephotography/models/__init__.py +0 -0
- Time_TravelRephotography/models/degrade.py +122 -0
- Time_TravelRephotography/models/encoder.py +66 -0
- Time_TravelRephotography/models/encoder4editing/.gitignore +133 -0
- Time_TravelRephotography/models/encoder4editing/LICENSE +21 -0
- Time_TravelRephotography/models/encoder4editing/README.md +143 -0
- Time_TravelRephotography/models/encoder4editing/__init__.py +15 -0
- Time_TravelRephotography/models/encoder4editing/bash_scripts/inference.sh +15 -0
- Time_TravelRephotography/models/encoder4editing/configs/__init__.py +0 -0
- Time_TravelRephotography/models/encoder4editing/configs/data_configs.py +41 -0
- Time_TravelRephotography/models/encoder4editing/configs/paths_config.py +28 -0
- Time_TravelRephotography/models/encoder4editing/configs/transforms_config.py +62 -0
- Time_TravelRephotography/models/encoder4editing/criteria/__init__.py +0 -0
- Time_TravelRephotography/models/encoder4editing/criteria/id_loss.py +47 -0
- Time_TravelRephotography/models/encoder4editing/criteria/lpips/__init__.py +0 -0
- Time_TravelRephotography/models/encoder4editing/criteria/lpips/lpips.py +35 -0
- Time_TravelRephotography/models/encoder4editing/criteria/lpips/networks.py +96 -0
- Time_TravelRephotography/models/encoder4editing/criteria/lpips/utils.py +30 -0
- Time_TravelRephotography/models/encoder4editing/criteria/moco_loss.py +71 -0
- Time_TravelRephotography/models/encoder4editing/criteria/w_norm.py +14 -0
- Time_TravelRephotography/models/encoder4editing/datasets/__init__.py +0 -0
- Time_TravelRephotography/models/encoder4editing/datasets/gt_res_dataset.py +32 -0
- Time_TravelRephotography/models/encoder4editing/datasets/images_dataset.py +33 -0
- Time_TravelRephotography/models/encoder4editing/datasets/inference_dataset.py +25 -0
- Time_TravelRephotography/models/encoder4editing/editings/ganspace.py +22 -0
- Time_TravelRephotography/models/encoder4editing/editings/ganspace_pca/cars_pca.pt +3 -0
- Time_TravelRephotography/models/encoder4editing/editings/ganspace_pca/ffhq_pca.pt +3 -0
- Time_TravelRephotography/models/encoder4editing/editings/interfacegan_directions/age.pt +3 -0
.gitattributes
ADDED
@@ -0,0 +1,27 @@
*.7z filter=lfs diff=lfs merge=lfs -text
*.arrow filter=lfs diff=lfs merge=lfs -text
*.bin filter=lfs diff=lfs merge=lfs -text
*.bz2 filter=lfs diff=lfs merge=lfs -text
*.ftz filter=lfs diff=lfs merge=lfs -text
*.gz filter=lfs diff=lfs merge=lfs -text
*.h5 filter=lfs diff=lfs merge=lfs -text
*.joblib filter=lfs diff=lfs merge=lfs -text
*.lfs.* filter=lfs diff=lfs merge=lfs -text
*.model filter=lfs diff=lfs merge=lfs -text
*.msgpack filter=lfs diff=lfs merge=lfs -text
*.onnx filter=lfs diff=lfs merge=lfs -text
*.ot filter=lfs diff=lfs merge=lfs -text
*.parquet filter=lfs diff=lfs merge=lfs -text
*.pb filter=lfs diff=lfs merge=lfs -text
*.pt filter=lfs diff=lfs merge=lfs -text
*.pth filter=lfs diff=lfs merge=lfs -text
*.rar filter=lfs diff=lfs merge=lfs -text
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.tar.* filter=lfs diff=lfs merge=lfs -text
*.tflite filter=lfs diff=lfs merge=lfs -text
*.tgz filter=lfs diff=lfs merge=lfs -text
*.wasm filter=lfs diff=lfs merge=lfs -text
*.xz filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zstandard filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text

.gitignore
ADDED
@@ -0,0 +1,133 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

wandb/
*.lmdb/
*.pkl

.gitmodules
ADDED
@@ -0,0 +1,9 @@
[submodule "third_party/face_parsing"]
	path = third_party/face_parsing
	url = https://github.com/Time-Travel-Rephotography/face-parsing.PyTorch.git
[submodule "models/encoder4editing"]
	path = models/encoder4editing
	url = https://github.com/Time-Travel-Rephotography/encoder4editing.git
[submodule "losses/contextual_loss"]
	path = losses/contextual_loss
	url = https://github.com/Time-Travel-Rephotography/contextual_loss_pytorch.git

README.md
ADDED
@@ -0,0 +1,14 @@
---
title: Time TravelRephotography
emoji: 🦀
colorFrom: yellow
colorTo: red
sdk: gradio
sdk_version: 2.9.4
app_file: app.py
pinned: false
license: mit
duplicated_from: feng2022/Time-TravelRephotography
---

Check out the configuration reference at https://huggingface.co/docs/hub/spaces#reference

Time_TravelRephotography/LICENSE
ADDED
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2020 Time-Travel-Rephotography

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

Time_TravelRephotography/LICENSE-NVIDIA
ADDED
@@ -0,0 +1,101 @@
Copyright (c) 2019, NVIDIA Corporation. All rights reserved.


Nvidia Source Code License-NC

=======================================================================

1. Definitions

"Licensor" means any person or entity that distributes its Work.

"Software" means the original work of authorship made available under
this License.

"Work" means the Software and any additions to or derivative works of
the Software that are made available under this License.

"Nvidia Processors" means any central processing unit (CPU), graphics
processing unit (GPU), field-programmable gate array (FPGA),
application-specific integrated circuit (ASIC) or any combination
thereof designed, made, sold, or provided by Nvidia or its affiliates.

The terms "reproduce," "reproduction," "derivative works," and
"distribution" have the meaning as provided under U.S. copyright law;
provided, however, that for the purposes of this License, derivative
works shall not include works that remain separable from, or merely
link (or bind by name) to the interfaces of, the Work.

Works, including the Software, are "made available" under this License
by including in or with the Work either (a) a copyright notice
referencing the applicability of this License to the Work, or (b) a
copy of this License.

2. License Grants

2.1 Copyright Grant. Subject to the terms and conditions of this
License, each Licensor grants to you a perpetual, worldwide,
non-exclusive, royalty-free, copyright license to reproduce,
prepare derivative works of, publicly display, publicly perform,
sublicense and distribute its Work and any resulting derivative
works in any form.

3. Limitations

3.1 Redistribution. You may reproduce or distribute the Work only
if (a) you do so under this License, (b) you include a complete
copy of this License with your distribution, and (c) you retain
without modification any copyright, patent, trademark, or
attribution notices that are present in the Work.

3.2 Derivative Works. You may specify that additional or different
terms apply to the use, reproduction, and distribution of your
derivative works of the Work ("Your Terms") only if (a) Your Terms
provide that the use limitation in Section 3.3 applies to your
derivative works, and (b) you identify the specific derivative
works that are subject to Your Terms. Notwithstanding Your Terms,
this License (including the redistribution requirements in Section
3.1) will continue to apply to the Work itself.

3.3 Use Limitation. The Work and any derivative works thereof only
may be used or intended for use non-commercially. The Work or
derivative works thereof may be used or intended for use by Nvidia
or its affiliates commercially or non-commercially. As used herein,
"non-commercially" means for research or evaluation purposes only.

3.4 Patent Claims. If you bring or threaten to bring a patent claim
against any Licensor (including any claim, cross-claim or
counterclaim in a lawsuit) to enforce any patents that you allege
are infringed by any Work, then your rights under this License from
such Licensor (including the grants in Sections 2.1 and 2.2) will
terminate immediately.

3.5 Trademarks. This License does not grant any rights to use any
Licensor's or its affiliates' names, logos, or trademarks, except
as necessary to reproduce the notices described in this License.

3.6 Termination. If you violate any term of this License, then your
rights under this License (including the grants in Sections 2.1 and
2.2) will terminate immediately.

4. Disclaimer of Warranty.

THE WORK IS PROVIDED "AS IS" WITHOUT WARRANTIES OR CONDITIONS OF ANY
KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WARRANTIES OR CONDITIONS OF
MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE OR
NON-INFRINGEMENT. YOU BEAR THE RISK OF UNDERTAKING ANY ACTIVITIES UNDER
THIS LICENSE.

5. Limitation of Liability.

EXCEPT AS PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL
THEORY, WHETHER IN TORT (INCLUDING NEGLIGENCE), CONTRACT, OR OTHERWISE
SHALL ANY LICENSOR BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY DIRECT,
INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES ARISING OUT OF
OR RELATED TO THIS LICENSE, THE USE OR INABILITY TO USE THE WORK
(INCLUDING BUT NOT LIMITED TO LOSS OF GOODWILL, BUSINESS INTERRUPTION,
LOST PROFITS OR DATA, COMPUTER FAILURE OR MALFUNCTION, OR ANY OTHER
COMMERCIAL DAMAGES OR LOSSES), EVEN IF THE LICENSOR HAS BEEN ADVISED OF
THE POSSIBILITY OF SUCH DAMAGES.

=======================================================================

Time_TravelRephotography/LICENSE-STYLEGAN2
ADDED
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2019 Kim Seonghyeon

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

Time_TravelRephotography/losses/color_transfer_loss.py
ADDED
@@ -0,0 +1,60 @@
from typing import List, Optional

import torch
from torch import nn
from torch.nn.functional import (
    smooth_l1_loss,
)


def flatten_CHW(im: torch.Tensor) -> torch.Tensor:
    """
    (B, C, H, W) -> (B, -1)
    """
    B = im.shape[0]
    return im.reshape(B, -1)


def stddev(x: torch.Tensor) -> torch.Tensor:
    """
    x: (B, -1), assume with mean normalized
    Returns:
        stddev: (B)
    """
    return torch.sqrt(torch.mean(x * x, dim=-1))


def gram_matrix(input_):
    B, C = input_.shape[:2]
    features = input_.view(B, C, -1)
    N = features.shape[-1]
    G = torch.bmm(features, features.transpose(1, 2))  # C x C
    return G.div(C * N)


class ColorTransferLoss(nn.Module):
    """Penalize the gram matrix difference between StyleGAN2's ToRGB outputs"""
    def __init__(
        self,
        init_rgbs,
        scale_rgb: bool = False
    ):
        super().__init__()

        with torch.no_grad():
            init_feats = [x.detach() for x in init_rgbs]
            self.stds = [stddev(flatten_CHW(rgb)) if scale_rgb else 1 for rgb in init_feats]  # (B, 1, 1, 1) or scalar
            self.grams = [gram_matrix(rgb / std) for rgb, std in zip(init_feats, self.stds)]

    def forward(self, rgbs: List[torch.Tensor], level: int = None):
        if level is None:
            level = len(self.grams)

        feats = rgbs
        loss = 0
        for i, (rgb, std) in enumerate(zip(feats[:level], self.stds[:level])):
            G = gram_matrix(rgb / std)
            loss = loss + smooth_l1_loss(G, self.grams[i])

        return loss

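For orientation, a minimal usage sketch of the ColorTransferLoss added above (illustrative only, not part of this commit): the random tensors stand in for StyleGAN2 ToRGB activations, and the import path assumes the Time_TravelRephotography directory is on PYTHONPATH.

    import torch
    from losses.color_transfer_loss import ColorTransferLoss

    # Stand-ins for ToRGB outputs at two resolutions, shaped (B, 3, H, W).
    sibling_rgbs = [torch.randn(1, 3, 8, 8), torch.randn(1, 3, 16, 16)]
    criterion = ColorTransferLoss(init_rgbs=sibling_rgbs)

    # Gram-matrix difference against the stored reference activations.
    rgbs = [torch.randn(1, 3, 8, 8), torch.randn(1, 3, 16, 16)]
    loss_all = criterion(rgbs)              # all ToRGB levels
    loss_coarse = criterion(rgbs, level=1)  # only the coarsest ToRGB output
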
Time_TravelRephotography/losses/contextual_loss/.gitignore
ADDED
@@ -0,0 +1,104 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/

Time_TravelRephotography/losses/contextual_loss/LICENSE
ADDED
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2019 Sou Uchida

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

Time_TravelRephotography/losses/contextual_loss/__init__.py
ADDED
@@ -0,0 +1 @@
from .modules import *

Time_TravelRephotography/losses/contextual_loss/config.py
ADDED
@@ -0,0 +1,2 @@
# TODO: add supports for L1, L2 etc.
LOSS_TYPES = ['cosine', 'l1', 'l2']

Time_TravelRephotography/losses/contextual_loss/functional.py
ADDED
@@ -0,0 +1,198 @@
import torch
import torch.nn.functional as F

from .config import LOSS_TYPES

__all__ = ['contextual_loss', 'contextual_bilateral_loss']


def contextual_loss(x: torch.Tensor,
                    y: torch.Tensor,
                    band_width: float = 0.5,
                    loss_type: str = 'cosine',
                    all_dist: bool = False):
    """
    Computes contextual loss between x and y.
    Most of this code is copied from
        https://gist.github.com/yunjey/3105146c736f9c1055463c33b4c989da.

    Parameters
    ---
    x : torch.Tensor
        features of shape (N, C, H, W).
    y : torch.Tensor
        features of shape (N, C, H, W).
    band_width : float, optional
        a band-width parameter used to convert distance to similarity.
        in the paper, this is described as :math:`h`.
    loss_type : str, optional
        a loss type to measure the distance between features.
        Note: `l1` and `l2` frequently raise OOM.

    Returns
    ---
    cx_loss : torch.Tensor
        contextual loss between x and y (Eq (1) in the paper)
    """

    assert x.size() == y.size(), 'input tensor must have the same size.'
    assert loss_type in LOSS_TYPES, f'select a loss type from {LOSS_TYPES}.'

    N, C, H, W = x.size()

    if loss_type == 'cosine':
        dist_raw = compute_cosine_distance(x, y)
    elif loss_type == 'l1':
        dist_raw = compute_l1_distance(x, y)
    elif loss_type == 'l2':
        dist_raw = compute_l2_distance(x, y)

    dist_tilde = compute_relative_distance(dist_raw)
    cx = compute_cx(dist_tilde, band_width)
    if all_dist:
        return cx

    cx = torch.mean(torch.max(cx, dim=1)[0], dim=1)  # Eq(1)
    cx_loss = torch.mean(-torch.log(cx + 1e-5))  # Eq(5)

    return cx_loss


# TODO: Operation check
def contextual_bilateral_loss(x: torch.Tensor,
                              y: torch.Tensor,
                              weight_sp: float = 0.1,
                              band_width: float = 1.,
                              loss_type: str = 'cosine'):
    """
    Computes Contextual Bilateral (CoBi) Loss between x and y,
        proposed in https://arxiv.org/pdf/1905.05169.pdf.

    Parameters
    ---
    x : torch.Tensor
        features of shape (N, C, H, W).
    y : torch.Tensor
        features of shape (N, C, H, W).
    band_width : float, optional
        a band-width parameter used to convert distance to similarity.
        in the paper, this is described as :math:`h`.
    loss_type : str, optional
        a loss type to measure the distance between features.
        Note: `l1` and `l2` frequently raise OOM.

    Returns
    ---
    cx_loss : torch.Tensor
        contextual loss between x and y (Eq (1) in the paper).
    k_arg_max_NC : torch.Tensor
        indices to maximize similarity over channels.
    """

    assert x.size() == y.size(), 'input tensor must have the same size.'
    assert loss_type in LOSS_TYPES, f'select a loss type from {LOSS_TYPES}.'

    # spatial loss
    grid = compute_meshgrid(x.shape).to(x.device)
    dist_raw = compute_l2_distance(grid, grid)
    dist_tilde = compute_relative_distance(dist_raw)
    cx_sp = compute_cx(dist_tilde, band_width)

    # feature loss
    if loss_type == 'cosine':
        dist_raw = compute_cosine_distance(x, y)
    elif loss_type == 'l1':
        dist_raw = compute_l1_distance(x, y)
    elif loss_type == 'l2':
        dist_raw = compute_l2_distance(x, y)
    dist_tilde = compute_relative_distance(dist_raw)
    cx_feat = compute_cx(dist_tilde, band_width)

    # combined loss
    cx_combine = (1. - weight_sp) * cx_feat + weight_sp * cx_sp

    k_max_NC, _ = torch.max(cx_combine, dim=2, keepdim=True)

    cx = k_max_NC.mean(dim=1)
    cx_loss = torch.mean(-torch.log(cx + 1e-5))

    return cx_loss


def compute_cx(dist_tilde, band_width):
    w = torch.exp((1 - dist_tilde) / band_width)  # Eq(3)
    cx = w / torch.sum(w, dim=2, keepdim=True)  # Eq(4)
    return cx


def compute_relative_distance(dist_raw):
    dist_min, _ = torch.min(dist_raw, dim=2, keepdim=True)
    dist_tilde = dist_raw / (dist_min + 1e-5)
    return dist_tilde


def compute_cosine_distance(x, y):
    # mean shifting by channel-wise mean of `y`.
    y_mu = y.mean(dim=(0, 2, 3), keepdim=True)
    x_centered = x - y_mu
    y_centered = y - y_mu

    # L2 normalization
    x_normalized = F.normalize(x_centered, p=2, dim=1)
    y_normalized = F.normalize(y_centered, p=2, dim=1)

    # channel-wise vectorization
    N, C, *_ = x.size()
    x_normalized = x_normalized.reshape(N, C, -1)  # (N, C, H*W)
    y_normalized = y_normalized.reshape(N, C, -1)  # (N, C, H*W)

    # cosine similarity
    cosine_sim = torch.bmm(x_normalized.transpose(1, 2),
                           y_normalized)  # (N, H*W, H*W)

    # convert to distance
    dist = 1 - cosine_sim

    return dist


# TODO: Considering avoiding OOM.
def compute_l1_distance(x: torch.Tensor, y: torch.Tensor):
    N, C, H, W = x.size()
    x_vec = x.view(N, C, -1)
    y_vec = y.view(N, C, -1)

    dist = x_vec.unsqueeze(2) - y_vec.unsqueeze(3)
    dist = dist.abs().sum(dim=1)
    dist = dist.transpose(1, 2).reshape(N, H*W, H*W)
    dist = dist.clamp(min=0.)

    return dist


# TODO: Considering avoiding OOM.
def compute_l2_distance(x, y):
    N, C, H, W = x.size()
    x_vec = x.view(N, C, -1)
    y_vec = y.view(N, C, -1)
    x_s = torch.sum(x_vec ** 2, dim=1)
    y_s = torch.sum(y_vec ** 2, dim=1)

    A = y_vec.transpose(1, 2) @ x_vec
    dist = y_s - 2 * A + x_s.transpose(0, 1)
    dist = dist.transpose(1, 2).reshape(N, H*W, H*W)
    dist = dist.clamp(min=0.)

    return dist


def compute_meshgrid(shape):
    N, C, H, W = shape
    rows = torch.arange(0, H, dtype=torch.float32) / (H + 1)
    cols = torch.arange(0, W, dtype=torch.float32) / (W + 1)

    feature_grid = torch.meshgrid(rows, cols)
    feature_grid = torch.stack(feature_grid).unsqueeze(0)
    feature_grid = torch.cat([feature_grid for _ in range(N)], dim=0)

    return feature_grid

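A minimal call into the functional API above (illustrative only, not part of this commit; torchvision must be installed and the Time_TravelRephotography directory on PYTHONPATH). Small feature maps keep the (H*W, H*W) distance matrix cheap.

    import torch
    from losses.contextual_loss import functional as F

    x = torch.randn(1, 16, 8, 8)  # (N, C, H, W) feature maps
    y = torch.randn(1, 16, 8, 8)
    cx_loss = F.contextual_loss(x, y, band_width=0.5, loss_type='cosine')
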
Time_TravelRephotography/losses/contextual_loss/modules/__init__.py
ADDED
@@ -0,0 +1,4 @@
from .contextual import ContextualLoss
from .contextual_bilateral import ContextualBilateralLoss

__all__ = ['ContextualLoss', 'ContextualBilateralLoss']

Time_TravelRephotography/losses/contextual_loss/modules/contextual.py
ADDED
@@ -0,0 +1,121 @@
import random
from typing import (
    Iterable,
    List,
    Optional,
)

import numpy as np
import torch
import torch.nn as nn

from .vgg import VGG19
from .. import functional as F
from ..config import LOSS_TYPES


class ContextualLoss(nn.Module):
    """
    Creates a criterion that measures the contextual loss.

    Parameters
    ---
    band_width : int, optional
        a band_width parameter described as :math:`h` in the paper.
    use_vgg : bool, optional
        if you want to use VGG feature, set this `True`.
    vgg_layers : List[str], optional
        intermediate layer names for VGG features.
        Now we support layer names:
            `['relu1_2', 'relu2_2', 'relu3_4', 'relu4_4', 'relu5_4']`
    """

    def __init__(
        self,
        band_width: float = 0.5,
        loss_type: str = 'cosine',
        use_vgg: bool = False,
        vgg_layers: List[str] = ['relu3_4'],
        feature_1d_size: int = 64,
    ):

        super().__init__()

        assert band_width > 0, 'band_width parameter must be positive.'
        assert loss_type in LOSS_TYPES,\
            f'select a loss type from {LOSS_TYPES}.'

        self.loss_type = loss_type
        self.band_width = band_width
        self.feature_1d_size = feature_1d_size

        if use_vgg:
            self.vgg_model = VGG19()
            self.vgg_layers = vgg_layers
            self.register_buffer(
                name='vgg_mean',
                tensor=torch.tensor(
                    [[[0.485]], [[0.456]], [[0.406]]], requires_grad=False)
            )
            self.register_buffer(
                name='vgg_std',
                tensor=torch.tensor(
                    [[[0.229]], [[0.224]], [[0.225]]], requires_grad=False)
            )

    def forward(self, x: torch.Tensor, y: torch.Tensor, all_dist: bool = False):
        if not hasattr(self, 'vgg_model'):
            return self.contextual_loss(x, y, self.feature_1d_size, self.band_width, all_dist=all_dist)

        x = self.forward_vgg(x)
        y = self.forward_vgg(y)

        loss = 0
        for layer in self.vgg_layers:
            # picking up vgg feature maps
            fx = getattr(x, layer)
            fy = getattr(y, layer)
            loss = loss + self.contextual_loss(
                fx, fy, self.feature_1d_size, self.band_width, all_dist=all_dist, loss_type=self.loss_type
            )
        return loss

    def forward_vgg(self, x: torch.Tensor):
        assert x.shape[1] == 3, 'VGG model takes 3 channel images.'
        # [-1, 1] -> [0, 1]
        x = (x + 1) * 0.5

        # normalization
        x = x.sub(self.vgg_mean.detach()).div(self.vgg_std)
        return self.vgg_model(x)

    @classmethod
    def contextual_loss(
        cls,
        x: torch.Tensor, y: torch.Tensor,
        feature_1d_size: int,
        band_width: int,
        all_dist: bool = False,
        loss_type: str = 'cosine',
    ) -> torch.Tensor:
        feature_size = feature_1d_size ** 2
        if np.prod(x.shape[2:]) > feature_size or np.prod(y.shape[2:]) > feature_size:
            x, indices = cls.random_sampling(x, feature_1d_size=feature_1d_size)
            y, _ = cls.random_sampling(y, feature_1d_size=feature_1d_size, indices=indices)

        return F.contextual_loss(x, y, band_width, all_dist=all_dist, loss_type=loss_type)

    @staticmethod
    def random_sampling(
        tensor_NCHW: torch.Tensor, feature_1d_size: int, indices: Optional[List] = None
    ):
        N, C, H, W = tensor_NCHW.shape
        S = H * W
        tensor_NCS = tensor_NCHW.reshape([N, C, S])
        if indices is None:
            all_indices = list(range(S))
            random.shuffle(all_indices)
            indices = all_indices[:feature_1d_size**2]
        res = tensor_NCS[:, :, indices].reshape(N, -1, feature_1d_size, feature_1d_size)
        return res, indices

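A minimal usage sketch of the ContextualLoss module above (illustrative only, not part of this commit). Without use_vgg it compares feature maps directly; with use_vgg=True it would instead extract VGG19 features from 3-channel images in [-1, 1], which is how joint_loss.py uses it.

    import torch
    from losses.contextual_loss import ContextualLoss

    criterion = ContextualLoss(band_width=0.5)  # no VGG: inputs are raw feature maps
    x = torch.randn(1, 32, 16, 16)
    y = torch.randn(1, 32, 16, 16)
    loss = criterion(x, y)  # maps with more than 64*64 positions are randomly subsampled
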
Time_TravelRephotography/losses/contextual_loss/modules/contextual_bilateral.py
ADDED
@@ -0,0 +1,69 @@
import torch
import torch.nn as nn

from .vgg import VGG19
from .. import functional as F
from ..config import LOSS_TYPES


class ContextualBilateralLoss(nn.Module):
    """
    Creates a criterion that measures the contextual bilateral loss.

    Parameters
    ---
    weight_sp : float, optional
        a balancing weight between spatial and feature loss.
    band_width : int, optional
        a band_width parameter described as :math:`h` in the paper.
    use_vgg : bool, optional
        if you want to use VGG feature, set this `True`.
    vgg_layer : str, optional
        intermediate layer name for VGG feature.
        Now we support layer names:
            `['relu1_2', 'relu2_2', 'relu3_4', 'relu4_4', 'relu5_4']`
    """

    def __init__(self,
                 weight_sp: float = 0.1,
                 band_width: float = 0.5,
                 loss_type: str = 'cosine',
                 use_vgg: bool = False,
                 vgg_layer: str = 'relu3_4'):

        super(ContextualBilateralLoss, self).__init__()

        assert band_width > 0, 'band_width parameter must be positive.'
        assert loss_type in LOSS_TYPES,\
            f'select a loss type from {LOSS_TYPES}.'

        self.band_width = band_width

        if use_vgg:
            self.vgg_model = VGG19()
            self.vgg_layer = vgg_layer
            self.register_buffer(
                name='vgg_mean',
                tensor=torch.tensor(
                    [[[0.485]], [[0.456]], [[0.406]]], requires_grad=False)
            )
            self.register_buffer(
                name='vgg_std',
                tensor=torch.tensor(
                    [[[0.229]], [[0.224]], [[0.225]]], requires_grad=False)
            )

    def forward(self, x, y):
        if hasattr(self, 'vgg_model'):
            assert x.shape[1] == 3 and y.shape[1] == 3,\
                'VGG model takes 3 channel images.'

            # normalization
            x = x.sub(self.vgg_mean.detach()).div(self.vgg_std.detach())
            y = y.sub(self.vgg_mean.detach()).div(self.vgg_std.detach())

            # picking up vgg feature maps
            x = getattr(self.vgg_model(x), self.vgg_layer)
            y = getattr(self.vgg_model(y), self.vgg_layer)

        return F.contextual_bilateral_loss(x, y, self.band_width)

Time_TravelRephotography/losses/contextual_loss/modules/vgg.py
ADDED
@@ -0,0 +1,48 @@
from collections import namedtuple

import torch.nn as nn
import torchvision.models.vgg as vgg


class VGG19(nn.Module):
    def __init__(self, requires_grad=False):
        super(VGG19, self).__init__()
        vgg_pretrained_features = vgg.vgg19(pretrained=True).features
        self.slice1 = nn.Sequential()
        self.slice2 = nn.Sequential()
        self.slice3 = nn.Sequential()
        self.slice4 = nn.Sequential()
        self.slice5 = nn.Sequential()
        for x in range(4):
            self.slice1.add_module(str(x), vgg_pretrained_features[x])
        for x in range(4, 9):
            self.slice2.add_module(str(x), vgg_pretrained_features[x])
        for x in range(9, 18):
            self.slice3.add_module(str(x), vgg_pretrained_features[x])
        for x in range(18, 27):
            self.slice4.add_module(str(x), vgg_pretrained_features[x])
        for x in range(27, 36):
            self.slice5.add_module(str(x), vgg_pretrained_features[x])
        if not requires_grad:
            for param in self.parameters():
                param.requires_grad = False

    def forward(self, X):
        h = self.slice1(X)
        h_relu1_2 = h
        h = self.slice2(h)
        h_relu2_2 = h
        h = self.slice3(h)
        h_relu3_4 = h
        h = self.slice4(h)
        h_relu4_4 = h
        h = self.slice5(h)
        h_relu5_4 = h

        vgg_outputs = namedtuple(
            "VggOutputs", ['relu1_2', 'relu2_2',
                           'relu3_4', 'relu4_4', 'relu5_4'])
        out = vgg_outputs(h_relu1_2, h_relu2_2,
                          h_relu3_4, h_relu4_4, h_relu5_4)

        return out

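A short sketch of the VGG19 wrapper above (illustrative only, not part of this commit; the torchvision VGG19 weights are downloaded on first use).

    import torch
    from losses.contextual_loss.modules.vgg import VGG19

    extractor = VGG19().eval()        # all parameters frozen by default
    img = torch.rand(1, 3, 224, 224)  # callers are expected to normalize beforehand
    with torch.no_grad():
        feats = extractor(img)
    print(feats.relu3_4.shape)        # torch.Size([1, 256, 56, 56])
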
Time_TravelRephotography/losses/joint_loss.py
ADDED
@@ -0,0 +1,167 @@
from argparse import (
    ArgumentParser,
    Namespace,
)
from typing import (
    Dict,
    Iterable,
    Optional,
    Tuple,
)

import numpy as np
import torch
from torch import nn

from utils.misc import (
    optional_string,
    iterable_to_str,
)

from .contextual_loss import ContextualLoss
from .color_transfer_loss import ColorTransferLoss
from .regularize_noise import NoiseRegularizer
from .reconstruction import (
    EyeLoss,
    FaceLoss,
    create_perceptual_loss,
    ReconstructionArguments,
)

class LossArguments:
    @staticmethod
    def add_arguments(parser: ArgumentParser):
        ReconstructionArguments.add_arguments(parser)

        parser.add_argument("--color_transfer", type=float, default=1e10, help="color transfer loss weight")
        parser.add_argument("--eye", type=float, default=0.1, help="eye loss weight")
        parser.add_argument('--noise_regularize', type=float, default=5e4)
        # contextual loss
        parser.add_argument("--contextual", type=float, default=0.1, help="contextual loss weight")
        parser.add_argument("--cx_layers", nargs='*', help="contextual loss layers",
                            choices=['relu1_2', 'relu2_2', 'relu3_4', 'relu4_4', 'relu5_4'],
                            default=['relu3_4', 'relu2_2', 'relu1_2'])

    @staticmethod
    def to_string(args: Namespace) -> str:
        return (
            ReconstructionArguments.to_string(args)
            + optional_string(args.eye > 0, f"-eye{args.eye}")
            + optional_string(args.color_transfer, f"-color{args.color_transfer:.1e}")
            + optional_string(
                args.contextual,
                f"-cx{args.contextual}({iterable_to_str(args.cx_layers)})"
            )
            #+ optional_string(args.mse, f"-mse{args.mse}")
            + optional_string(args.noise_regularize, f"-NR{args.noise_regularize:.1e}")
        )


class BakedMultiContextualLoss(nn.Module):
    """Randomly sample different image patches for different VGG layers."""
    def __init__(self, sibling: torch.Tensor, args: Namespace, size: int = 256):
        super().__init__()

        self.cxs = nn.ModuleList([ContextualLoss(use_vgg=True, vgg_layers=[layer])
                                  for layer in args.cx_layers])
        self.size = size
        self.sibling = sibling.detach()

    def forward(self, img: torch.Tensor):
        cx_loss = 0
        for cx in self.cxs:
            h, w = np.random.randint(0, high=img.shape[-1] - self.size, size=2)
            cx_loss = cx(self.sibling[..., h:h+self.size, w:w+self.size], img[..., h:h+self.size, w:w+self.size]) + cx_loss
        return cx_loss


class BakedContextualLoss(ContextualLoss):
    def __init__(self, sibling: torch.Tensor, args: Namespace, size: int = 256):
        super().__init__(use_vgg=True, vgg_layers=args.cx_layers)
        self.size = size
        self.sibling = sibling.detach()

    def forward(self, img: torch.Tensor):
        h, w = np.random.randint(0, high=img.shape[-1] - self.size, size=2)
        return super().forward(self.sibling[..., h:h+self.size, w:w+self.size], img[..., h:h+self.size, w:w+self.size])


class JointLoss(nn.Module):
    def __init__(
        self,
        args: Namespace,
        target: torch.Tensor,
        sibling: Optional[torch.Tensor],
        sibling_rgbs: Optional[Iterable[torch.Tensor]] = None,
    ):
        super().__init__()

        self.weights = {
            "face": 1., "eye": args.eye,
            "contextual": args.contextual, "color_transfer": args.color_transfer,
            "noise": args.noise_regularize,
        }

        reconstruction = {}
        if args.vgg > 0 or args.vggface > 0:
            percept = create_perceptual_loss(args)
            reconstruction.update(
                {"face": FaceLoss(target, input_size=args.generator_size, size=args.recon_size, percept=percept)}
            )
            if args.eye > 0:
                reconstruction.update(
                    {"eye": EyeLoss(target, input_size=args.generator_size, percept=percept)}
                )
        self.reconstruction = nn.ModuleDict(reconstruction)

        exemplar = {}
        if args.contextual > 0 and len(args.cx_layers) > 0:
            assert sibling is not None
            exemplar.update(
                {"contextual": BakedContextualLoss(sibling, args)}
            )
        if args.color_transfer > 0:
            assert sibling_rgbs is not None
            self.sibling_rgbs = sibling_rgbs
            exemplar.update(
                {"color_transfer": ColorTransferLoss(init_rgbs=sibling_rgbs)}
            )
        self.exemplar = nn.ModuleDict(exemplar)

        if args.noise_regularize > 0:
            self.noise_criterion = NoiseRegularizer()

    def forward(
        self, img, degrade=None, noises=None, rgbs=None, rgb_level: Optional[int] = None
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """
        Args:
            rgbs: results from the ToRGB layers
        """
        # TODO: add current optimization resolution for noises

        losses = {}

        # reconstruction losses
        for name, criterion in self.reconstruction.items():
            losses[name] = criterion(img, degrade=degrade)

        # exemplar losses
        if 'contextual' in self.exemplar:
            losses["contextual"] = self.exemplar["contextual"](img)
        if "color_transfer" in self.exemplar:
            assert rgbs is not None
            losses["color_transfer"] = self.exemplar["color_transfer"](rgbs, level=rgb_level)

        # noise regularizer
        if self.weights["noise"] > 0:
            losses["noise"] = self.noise_criterion(noises)

        total_loss = 0
        for name, loss in losses.items():
            total_loss = total_loss + self.weights[name] * loss
        return total_loss, losses

    def update_sibling(self, sibling: torch.Tensor):
        assert "contextual" in self.exemplar
        self.exemplar["contextual"].sibling = sibling.detach()

Time_TravelRephotography/losses/perceptual_loss.py
ADDED
@@ -0,0 +1,111 @@
"""
Code borrowed from https://gist.github.com/alper111/8233cdb0414b4cb5853f2f730ab95a49#file-vgg_perceptual_loss-py-L5
"""
import torch
import torchvision
from models.vggface import VGGFaceFeats


def cos_loss(fi, ft):
    return 1 - torch.nn.functional.cosine_similarity(fi, ft).mean()


class VGGPerceptualLoss(torch.nn.Module):
    def __init__(self, resize=False):
        super(VGGPerceptualLoss, self).__init__()
        blocks = []
        blocks.append(torchvision.models.vgg16(pretrained=True).features[:4].eval())
        blocks.append(torchvision.models.vgg16(pretrained=True).features[4:9].eval())
        blocks.append(torchvision.models.vgg16(pretrained=True).features[9:16].eval())
        blocks.append(torchvision.models.vgg16(pretrained=True).features[16:23].eval())
        for bl in blocks:
            for p in bl:
                p.requires_grad = False
        self.blocks = torch.nn.ModuleList(blocks)
        self.transform = torch.nn.functional.interpolate
        self.mean = torch.nn.Parameter(torch.tensor([0.485, 0.456, 0.406]).view(1,3,1,1))
        self.std = torch.nn.Parameter(torch.tensor([0.229, 0.224, 0.225]).view(1,3,1,1))
        self.resize = resize

    def forward(self, input, target, max_layer=4, cos_dist: bool = False):
        target = (target + 1) * 0.5
        input = (input + 1) * 0.5

        if input.shape[1] != 3:
            input = input.repeat(1, 3, 1, 1)
            target = target.repeat(1, 3, 1, 1)
        input = (input-self.mean) / self.std
        target = (target-self.mean) / self.std
        if self.resize:
            input = self.transform(input, mode='bilinear', size=(224, 224), align_corners=False)
            target = self.transform(target, mode='bilinear', size=(224, 224), align_corners=False)
        x = input
        y = target
        loss = 0.0
        loss_func = cos_loss if cos_dist else torch.nn.functional.l1_loss
        for bi, block in enumerate(self.blocks[:max_layer]):
            x = block(x)
            y = block(y)
            loss += loss_func(x, y.detach())
        return loss


class VGGFacePerceptualLoss(torch.nn.Module):
    def __init__(self, weight_path: str = "checkpoint/vgg_face_dag.pt", resize: bool = False):
        super().__init__()
        self.vgg = VGGFaceFeats()
        self.vgg.load_state_dict(torch.load(weight_path))

        mean = torch.tensor(self.vgg.meta["mean"]).view(1, 3, 1, 1) / 255.0
        self.register_buffer("mean", mean)

        self.transform = torch.nn.functional.interpolate
        self.resize = resize

    def forward(self, input, target, max_layer: int = 4, cos_dist: bool = False):
        target = (target + 1) * 0.5
        input = (input + 1) * 0.5

        # preprocessing
        if input.shape[1] != 3:
            input = input.repeat(1, 3, 1, 1)
            target = target.repeat(1, 3, 1, 1)
        input = input - self.mean
        target = target - self.mean
        if self.resize:
            input = self.transform(input, mode='bilinear', size=(224, 224), align_corners=False)
            target = self.transform(target, mode='bilinear', size=(224, 224), align_corners=False)

        input_feats = self.vgg(input)
        target_feats = self.vgg(target)

        loss_func = cos_loss if cos_dist else torch.nn.functional.l1_loss
        # calc perceptual loss
        loss = 0.0
        for fi, ft in zip(input_feats[:max_layer], target_feats[:max_layer]):
            loss = loss + loss_func(fi, ft.detach())
        return loss


class PerceptualLoss(torch.nn.Module):
    def __init__(
        self, lambda_vggface: float = 0.025 / 0.15, lambda_vgg: float = 1, eps: float = 1e-8, cos_dist: bool = False
    ):
        super().__init__()
        self.register_buffer("lambda_vggface", torch.tensor(lambda_vggface))
        self.register_buffer("lambda_vgg", torch.tensor(lambda_vgg))
        self.cos_dist = cos_dist

        if lambda_vgg > eps:
            self.vgg = VGGPerceptualLoss()
        if lambda_vggface > eps:
            self.vggface = VGGFacePerceptualLoss()

    def forward(self, input, target, eps=1e-8, use_vggface: bool = True, use_vgg=True, max_vgg_layer=4):
        loss = 0.0
        if self.lambda_vgg > eps and use_vgg:
            loss = loss + self.lambda_vgg * self.vgg(input, target, max_layer=max_vgg_layer)
        if self.lambda_vggface > eps and use_vggface:
            loss = loss + self.lambda_vggface * self.vggface(input, target, cos_dist=self.cos_dist)
        return loss

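A minimal usage sketch of PerceptualLoss above (illustrative only, not part of this commit). It assumes models/vggface.py exists in the full repository (it is not among the 50 files shown) and that the torchvision VGG16 weights can be downloaded; lambda_vggface=0 skips the VGGFace branch, which would otherwise need checkpoint/vgg_face_dag.pt.

    import torch
    from losses.perceptual_loss import PerceptualLoss

    percept = PerceptualLoss(lambda_vgg=1.0, lambda_vggface=0.0)
    pred = torch.rand(1, 3, 256, 256) * 2 - 1    # images in [-1, 1]
    target = torch.rand(1, 3, 256, 256) * 2 - 1
    loss = percept(pred, target)
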
Time_TravelRephotography/losses/reconstruction.py
ADDED
@@ -0,0 +1,119 @@
from argparse import (
    ArgumentParser,
    Namespace,
)
from typing import Optional

import numpy as np
import torch
from torch import nn

from losses.perceptual_loss import PerceptualLoss
from models.degrade import Downsample
from utils.misc import optional_string


class ReconstructionArguments:
    @staticmethod
    def add_arguments(parser: ArgumentParser):
        parser.add_argument("--vggface", type=float, default=0.3, help="vggface")
        parser.add_argument("--vgg", type=float, default=1, help="vgg")
        parser.add_argument('--recon_size', type=int, default=256, help="size for face reconstruction loss")

    @staticmethod
    def to_string(args: Namespace) -> str:
        return (
            f"s{args.recon_size}"
            + optional_string(args.vgg > 0, f"-vgg{args.vgg}")
            + optional_string(args.vggface > 0, f"-vggface{args.vggface}")
        )


def create_perceptual_loss(args: Namespace):
    return PerceptualLoss(lambda_vgg=args.vgg, lambda_vggface=args.vggface, cos_dist=False)


class EyeLoss(nn.Module):
    def __init__(
        self,
        target: torch.Tensor,
        input_size: int = 1024,
        input_channels: int = 3,
        percept: Optional[nn.Module] = None,
        args: Optional[Namespace] = None
    ):
        """
        target: target image
        """
        assert not (percept is None and args is None)

        super().__init__()

        self.target = target

        target_size = target.shape[-1]
        self.downsample = Downsample(input_size, target_size, input_channels) \
            if target_size != input_size else (lambda x: x)

        self.percept = percept if percept is not None else create_perceptual_loss(args)

        eye_size = np.array((224, 224))
        btlrs = []
        for sgn in [1, -1]:
            center = np.array((480, 384 * sgn))  # (y, x)
            b, t = center[0] - eye_size[0] // 2, center[0] + eye_size[0] // 2
            l, r = center[1] - eye_size[1] // 2, center[1] + eye_size[1] // 2
            btlrs.append((np.array((b, t, l, r)) / 1024 * target_size).astype(int))
        self.btlrs = np.stack(btlrs, axis=0)

    def forward(self, img: torch.Tensor, degrade: nn.Module = None):
        """
        img: it should be the degraded version of the generated image
        """
        if degrade is not None:
            img = degrade(img, downsample=self.downsample)

        loss = 0
        for (b, t, l, r) in self.btlrs:
            loss = loss + self.percept(
                img[:, :, b:t, l:r], self.target[:, :, b:t, l:r],
                use_vggface=False, max_vgg_layer=4,
                # use_vgg=False,
            )
        return loss


class FaceLoss(nn.Module):
    def __init__(
        self,
        target: torch.Tensor,
        input_size: int = 1024,
        input_channels: int = 3,
        size: int = 256,
        percept: Optional[nn.Module] = None,
        args: Optional[Namespace] = None
    ):
        """
        target: target image
        """
        assert not (percept is None and args is None)

        super().__init__()

        target_size = target.shape[-1]
        self.target = target if target_size == size \
            else Downsample(target_size, size, target.shape[1]).to(target.device)(target)

        self.downsample = Downsample(input_size, size, input_channels) \
            if size != input_size else (lambda x: x)

        self.percept = percept if percept is not None else create_perceptual_loss(args)

    def forward(self, img: torch.Tensor, degrade: nn.Module = None):
        """
        img: it should be the degraded version of the generated image
        """
        if degrade is not None:
            img = degrade(img, downsample=self.downsample)
        loss = self.percept(img, self.target)
        return loss

Time_TravelRephotography/losses/regularize_noise.py
ADDED
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Iterable
|
2 |
+
|
3 |
+
import torch
|
4 |
+
from torch import nn
|
5 |
+
|
6 |
+
|
7 |
+
class NoiseRegularizer(nn.Module):
|
8 |
+
def forward(self, noises: Iterable[torch.Tensor]):
|
9 |
+
loss = 0
|
10 |
+
|
11 |
+
for noise in noises:
|
12 |
+
size = noise.shape[2]
|
13 |
+
|
14 |
+
while True:
|
15 |
+
loss = (
|
16 |
+
loss
|
17 |
+
+ (noise * torch.roll(noise, shifts=1, dims=3)).mean().pow(2)
|
18 |
+
+ (noise * torch.roll(noise, shifts=1, dims=2)).mean().pow(2)
|
19 |
+
)
|
20 |
+
|
21 |
+
if size <= 8:
|
22 |
+
break
|
23 |
+
|
24 |
+
noise = noise.reshape([1, 1, size // 2, 2, size // 2, 2])
|
25 |
+
noise = noise.mean([3, 5])
|
26 |
+
size //= 2
|
27 |
+
|
28 |
+
return loss
|
29 |
+
|
30 |
+
@staticmethod
|
31 |
+
def normalize(noises: Iterable[torch.Tensor]):
|
32 |
+
for noise in noises:
|
33 |
+
mean = noise.mean()
|
34 |
+
std = noise.std()
|
35 |
+
|
36 |
+
noise.data.add_(-mean).div_(std)
|
37 |
+
|
Time_TravelRephotography/model.py
ADDED
@@ -0,0 +1,697 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
import random
|
3 |
+
import functools
|
4 |
+
import operator
|
5 |
+
import numpy as np
|
6 |
+
|
7 |
+
import torch
|
8 |
+
from torch import nn
|
9 |
+
from torch.nn import functional as F
|
10 |
+
from torch.autograd import Function
|
11 |
+
|
12 |
+
from op import FusedLeakyReLU, fused_leaky_relu, upfirdn2d
|
13 |
+
|
14 |
+
|
15 |
+
class PixelNorm(nn.Module):
|
16 |
+
def __init__(self):
|
17 |
+
super().__init__()
|
18 |
+
|
19 |
+
def forward(self, input):
|
20 |
+
return input * torch.rsqrt(torch.mean(input ** 2, dim=1, keepdim=True) + 1e-8)
|
21 |
+
|
22 |
+
|
23 |
+
def make_kernel(k):
|
24 |
+
k = torch.tensor(k, dtype=torch.float32)
|
25 |
+
|
26 |
+
if k.ndim == 1:
|
27 |
+
k = k[None, :] * k[:, None]
|
28 |
+
|
29 |
+
k /= k.sum()
|
30 |
+
|
31 |
+
return k
|
32 |
+
|
33 |
+
|
34 |
+
class Upsample(nn.Module):
|
35 |
+
def __init__(self, kernel, factor=2):
|
36 |
+
super().__init__()
|
37 |
+
|
38 |
+
self.factor = factor
|
39 |
+
kernel = make_kernel(kernel) * (factor ** 2)
|
40 |
+
self.register_buffer('kernel', kernel)
|
41 |
+
|
42 |
+
p = kernel.shape[0] - factor
|
43 |
+
|
44 |
+
pad0 = (p + 1) // 2 + factor - 1
|
45 |
+
pad1 = p // 2
|
46 |
+
|
47 |
+
self.pad = (pad0, pad1)
|
48 |
+
|
49 |
+
def forward(self, input):
|
50 |
+
out = upfirdn2d(input, self.kernel, up=self.factor, down=1, pad=self.pad)
|
51 |
+
|
52 |
+
return out
|
53 |
+
|
54 |
+
|
55 |
+
class Downsample(nn.Module):
|
56 |
+
def __init__(self, kernel, factor=2):
|
57 |
+
super().__init__()
|
58 |
+
|
59 |
+
self.factor = factor
|
60 |
+
kernel = make_kernel(kernel)
|
61 |
+
self.register_buffer('kernel', kernel)
|
62 |
+
|
63 |
+
p = kernel.shape[0] - factor
|
64 |
+
|
65 |
+
pad0 = (p + 1) // 2
|
66 |
+
pad1 = p // 2
|
67 |
+
|
68 |
+
self.pad = (pad0, pad1)
|
69 |
+
|
70 |
+
def forward(self, input):
|
71 |
+
out = upfirdn2d(input, self.kernel, up=1, down=self.factor, pad=self.pad)
|
72 |
+
|
73 |
+
return out
|
74 |
+
|
75 |
+
|
76 |
+
class Blur(nn.Module):
|
77 |
+
def __init__(self, kernel, pad, upsample_factor=1):
|
78 |
+
super().__init__()
|
79 |
+
|
80 |
+
kernel = make_kernel(kernel)
|
81 |
+
|
82 |
+
if upsample_factor > 1:
|
83 |
+
kernel = kernel * (upsample_factor ** 2)
|
84 |
+
|
85 |
+
self.register_buffer('kernel', kernel)
|
86 |
+
|
87 |
+
self.pad = pad
|
88 |
+
|
89 |
+
def forward(self, input):
|
90 |
+
out = upfirdn2d(input, self.kernel, pad=self.pad)
|
91 |
+
|
92 |
+
return out
|
93 |
+
|
94 |
+
|
95 |
+
class EqualConv2d(nn.Module):
|
96 |
+
def __init__(
|
97 |
+
self, in_channel, out_channel, kernel_size, stride=1, padding=0, bias=True
|
98 |
+
):
|
99 |
+
super().__init__()
|
100 |
+
|
101 |
+
self.weight = nn.Parameter(
|
102 |
+
torch.randn(out_channel, in_channel, kernel_size, kernel_size)
|
103 |
+
)
|
104 |
+
self.scale = 1 / math.sqrt(in_channel * kernel_size ** 2)
|
105 |
+
|
106 |
+
self.stride = stride
|
107 |
+
self.padding = padding
|
108 |
+
|
109 |
+
if bias:
|
110 |
+
self.bias = nn.Parameter(torch.zeros(out_channel))
|
111 |
+
|
112 |
+
else:
|
113 |
+
self.bias = None
|
114 |
+
|
115 |
+
def forward(self, input):
|
116 |
+
out = F.conv2d(
|
117 |
+
input,
|
118 |
+
self.weight * self.scale,
|
119 |
+
bias=self.bias,
|
120 |
+
stride=self.stride,
|
121 |
+
padding=self.padding,
|
122 |
+
)
|
123 |
+
|
124 |
+
return out
|
125 |
+
|
126 |
+
def __repr__(self):
|
127 |
+
return (
|
128 |
+
f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]},'
|
129 |
+
f' {self.weight.shape[2]}, stride={self.stride}, padding={self.padding})'
|
130 |
+
)
|
131 |
+
|
132 |
+
|
133 |
+
class EqualLinear(nn.Module):
|
134 |
+
def __init__(
|
135 |
+
self, in_dim, out_dim, bias=True, bias_init=0, lr_mul=1, activation=None
|
136 |
+
):
|
137 |
+
super().__init__()
|
138 |
+
|
139 |
+
self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul))
|
140 |
+
|
141 |
+
if bias:
|
142 |
+
self.bias = nn.Parameter(torch.zeros(out_dim).fill_(bias_init))
|
143 |
+
|
144 |
+
else:
|
145 |
+
self.bias = None
|
146 |
+
|
147 |
+
self.activation = activation
|
148 |
+
|
149 |
+
self.scale = (1 / math.sqrt(in_dim)) * lr_mul
|
150 |
+
self.lr_mul = lr_mul
|
151 |
+
|
152 |
+
def forward(self, input):
|
153 |
+
if self.activation:
|
154 |
+
out = F.linear(input, self.weight * self.scale)
|
155 |
+
out = fused_leaky_relu(out, self.bias * self.lr_mul)
|
156 |
+
|
157 |
+
else:
|
158 |
+
out = F.linear(
|
159 |
+
input, self.weight * self.scale, bias=self.bias * self.lr_mul
|
160 |
+
)
|
161 |
+
|
162 |
+
return out
|
163 |
+
|
164 |
+
def __repr__(self):
|
165 |
+
return (
|
166 |
+
f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]})'
|
167 |
+
)
|
168 |
+
|
169 |
+
|
170 |
+
class ScaledLeakyReLU(nn.Module):
|
171 |
+
def __init__(self, negative_slope=0.2):
|
172 |
+
super().__init__()
|
173 |
+
|
174 |
+
self.negative_slope = negative_slope
|
175 |
+
|
176 |
+
def forward(self, input):
|
177 |
+
out = F.leaky_relu(input, negative_slope=self.negative_slope)
|
178 |
+
|
179 |
+
return out * math.sqrt(2)
|
180 |
+
|
181 |
+
|
182 |
+
class ModulatedConv2d(nn.Module):
|
183 |
+
def __init__(
|
184 |
+
self,
|
185 |
+
in_channel,
|
186 |
+
out_channel,
|
187 |
+
kernel_size,
|
188 |
+
style_dim,
|
189 |
+
demodulate=True,
|
190 |
+
upsample=False,
|
191 |
+
downsample=False,
|
192 |
+
blur_kernel=[1, 3, 3, 1],
|
193 |
+
):
|
194 |
+
super().__init__()
|
195 |
+
|
196 |
+
self.eps = 1e-8
|
197 |
+
self.kernel_size = kernel_size
|
198 |
+
self.in_channel = in_channel
|
199 |
+
self.out_channel = out_channel
|
200 |
+
self.upsample = upsample
|
201 |
+
self.downsample = downsample
|
202 |
+
|
203 |
+
if upsample:
|
204 |
+
factor = 2
|
205 |
+
p = (len(blur_kernel) - factor) - (kernel_size - 1)
|
206 |
+
pad0 = (p + 1) // 2 + factor - 1
|
207 |
+
pad1 = p // 2 + 1
|
208 |
+
|
209 |
+
self.blur = Blur(blur_kernel, pad=(pad0, pad1), upsample_factor=factor)
|
210 |
+
|
211 |
+
if downsample:
|
212 |
+
factor = 2
|
213 |
+
p = (len(blur_kernel) - factor) + (kernel_size - 1)
|
214 |
+
pad0 = (p + 1) // 2
|
215 |
+
pad1 = p // 2
|
216 |
+
|
217 |
+
self.blur = Blur(blur_kernel, pad=(pad0, pad1))
|
218 |
+
|
219 |
+
fan_in = in_channel * kernel_size ** 2
|
220 |
+
self.scale = 1 / math.sqrt(fan_in)
|
221 |
+
self.padding = kernel_size // 2
|
222 |
+
|
223 |
+
self.weight = nn.Parameter(
|
224 |
+
torch.randn(1, out_channel, in_channel, kernel_size, kernel_size)
|
225 |
+
)
|
226 |
+
|
227 |
+
self.modulation = EqualLinear(style_dim, in_channel, bias_init=1)
|
228 |
+
|
229 |
+
self.demodulate = demodulate
|
230 |
+
|
231 |
+
def __repr__(self):
|
232 |
+
return (
|
233 |
+
f'{self.__class__.__name__}({self.in_channel}, {self.out_channel}, {self.kernel_size}, '
|
234 |
+
f'upsample={self.upsample}, downsample={self.downsample})'
|
235 |
+
)
|
236 |
+
|
237 |
+
def forward(self, input, style):
|
238 |
+
batch, in_channel, height, width = input.shape
|
239 |
+
|
240 |
+
style = self.modulation(style).view(batch, 1, in_channel, 1, 1)
|
241 |
+
weight = self.scale * self.weight * style
|
242 |
+
|
243 |
+
if self.demodulate:
|
244 |
+
demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-8)
|
245 |
+
weight = weight * demod.view(batch, self.out_channel, 1, 1, 1)
|
246 |
+
|
247 |
+
weight = weight.view(
|
248 |
+
batch * self.out_channel, in_channel, self.kernel_size, self.kernel_size
|
249 |
+
)
|
250 |
+
|
251 |
+
if self.upsample:
|
252 |
+
input = input.view(1, batch * in_channel, height, width)
|
253 |
+
weight = weight.view(
|
254 |
+
batch, self.out_channel, in_channel, self.kernel_size, self.kernel_size
|
255 |
+
)
|
256 |
+
weight = weight.transpose(1, 2).reshape(
|
257 |
+
batch * in_channel, self.out_channel, self.kernel_size, self.kernel_size
|
258 |
+
)
|
259 |
+
out = F.conv_transpose2d(input, weight, padding=0, stride=2, groups=batch)
|
260 |
+
_, _, height, width = out.shape
|
261 |
+
out = out.view(batch, self.out_channel, height, width)
|
262 |
+
out = self.blur(out)
|
263 |
+
|
264 |
+
elif self.downsample:
|
265 |
+
input = self.blur(input)
|
266 |
+
_, _, height, width = input.shape
|
267 |
+
input = input.view(1, batch * in_channel, height, width)
|
268 |
+
out = F.conv2d(input, weight, padding=0, stride=2, groups=batch)
|
269 |
+
_, _, height, width = out.shape
|
270 |
+
out = out.view(batch, self.out_channel, height, width)
|
271 |
+
|
272 |
+
else:
|
273 |
+
input = input.view(1, batch * in_channel, height, width)
|
274 |
+
out = F.conv2d(input, weight, padding=self.padding, groups=batch)
|
275 |
+
_, _, height, width = out.shape
|
276 |
+
out = out.view(batch, self.out_channel, height, width)
|
277 |
+
|
278 |
+
return out
|
279 |
+
|
280 |
+
|
281 |
+
class NoiseInjection(nn.Module):
|
282 |
+
def __init__(self):
|
283 |
+
super().__init__()
|
284 |
+
|
285 |
+
self.weight = nn.Parameter(torch.zeros(1))
|
286 |
+
|
287 |
+
def forward(self, image, noise=None):
|
288 |
+
if noise is None:
|
289 |
+
batch, _, height, width = image.shape
|
290 |
+
noise = image.new_empty(batch, 1, height, width).normal_()
|
291 |
+
|
292 |
+
return image + self.weight * noise
|
293 |
+
|
294 |
+
|
295 |
+
class ConstantInput(nn.Module):
|
296 |
+
def __init__(self, channel, size=4):
|
297 |
+
super().__init__()
|
298 |
+
|
299 |
+
self.input = nn.Parameter(torch.randn(1, channel, size, size))
|
300 |
+
|
301 |
+
def forward(self, input):
|
302 |
+
batch = input.shape[0]
|
303 |
+
out = self.input.repeat(batch, 1, 1, 1)
|
304 |
+
|
305 |
+
return out
|
306 |
+
|
307 |
+
|
308 |
+
class StyledConv(nn.Module):
|
309 |
+
def __init__(
|
310 |
+
self,
|
311 |
+
in_channel,
|
312 |
+
out_channel,
|
313 |
+
kernel_size,
|
314 |
+
style_dim,
|
315 |
+
upsample=False,
|
316 |
+
blur_kernel=[1, 3, 3, 1],
|
317 |
+
demodulate=True,
|
318 |
+
):
|
319 |
+
super().__init__()
|
320 |
+
|
321 |
+
self.conv = ModulatedConv2d(
|
322 |
+
in_channel,
|
323 |
+
out_channel,
|
324 |
+
kernel_size,
|
325 |
+
style_dim,
|
326 |
+
upsample=upsample,
|
327 |
+
blur_kernel=blur_kernel,
|
328 |
+
demodulate=demodulate,
|
329 |
+
)
|
330 |
+
|
331 |
+
self.noise = NoiseInjection()
|
332 |
+
# self.bias = nn.Parameter(torch.zeros(1, out_channel, 1, 1))
|
333 |
+
# self.activate = ScaledLeakyReLU(0.2)
|
334 |
+
self.activate = FusedLeakyReLU(out_channel)
|
335 |
+
|
336 |
+
def forward(self, input, style, noise=None):
|
337 |
+
out = self.conv(input, style)
|
338 |
+
out = self.noise(out, noise=noise)
|
339 |
+
# out = out + self.bias
|
340 |
+
out = self.activate(out)
|
341 |
+
|
342 |
+
return out
|
343 |
+
|
344 |
+
|
345 |
+
class ToRGB(nn.Module):
|
346 |
+
def __init__(self, in_channel, style_dim, upsample=True, blur_kernel=[1, 3, 3, 1]):
|
347 |
+
super().__init__()
|
348 |
+
|
349 |
+
if upsample:
|
350 |
+
self.upsample = Upsample(blur_kernel)
|
351 |
+
|
352 |
+
self.conv = ModulatedConv2d(in_channel, 3, 1, style_dim, demodulate=False)
|
353 |
+
self.bias = nn.Parameter(torch.zeros(1, 3, 1, 1))
|
354 |
+
|
355 |
+
def forward(self, input, style, skip=None):
|
356 |
+
out = self.conv(input, style)
|
357 |
+
style_modulated = out
|
358 |
+
out = out + self.bias
|
359 |
+
|
360 |
+
if skip is not None:
|
361 |
+
skip = self.upsample(skip)
|
362 |
+
|
363 |
+
out = out + skip
|
364 |
+
|
365 |
+
return out, style_modulated
|
366 |
+
|
367 |
+
|
368 |
+
class Generator(nn.Module):
|
369 |
+
def __init__(
|
370 |
+
self,
|
371 |
+
size,
|
372 |
+
style_dim,
|
373 |
+
n_mlp,
|
374 |
+
channel_multiplier=2,
|
375 |
+
blur_kernel=[1, 3, 3, 1],
|
376 |
+
lr_mlp=0.01,
|
377 |
+
):
|
378 |
+
super().__init__()
|
379 |
+
|
380 |
+
self.size = size
|
381 |
+
|
382 |
+
self.style_dim = style_dim
|
383 |
+
|
384 |
+
layers = [PixelNorm()]
|
385 |
+
|
386 |
+
for i in range(n_mlp):
|
387 |
+
layers.append(
|
388 |
+
EqualLinear(
|
389 |
+
style_dim, style_dim, lr_mul=lr_mlp, activation='fused_lrelu'
|
390 |
+
)
|
391 |
+
)
|
392 |
+
|
393 |
+
self.style = nn.Sequential(*layers)
|
394 |
+
|
395 |
+
self.channels = {
|
396 |
+
4: 512,
|
397 |
+
8: 512,
|
398 |
+
16: 512,
|
399 |
+
32: 512,
|
400 |
+
64: 256 * channel_multiplier,
|
401 |
+
128: 128 * channel_multiplier,
|
402 |
+
256: 64 * channel_multiplier,
|
403 |
+
512: 32 * channel_multiplier,
|
404 |
+
1024: 16 * channel_multiplier,
|
405 |
+
}
|
406 |
+
|
407 |
+
self.input = ConstantInput(self.channels[4])
|
408 |
+
self.conv1 = StyledConv(
|
409 |
+
self.channels[4], self.channels[4], 3, style_dim, blur_kernel=blur_kernel
|
410 |
+
)
|
411 |
+
self.to_rgb1 = ToRGB(self.channels[4], style_dim, upsample=False)
|
412 |
+
|
413 |
+
self.log_size = int(math.log(size, 2))
|
414 |
+
self.num_layers = (self.log_size - 2) * 2 + 1
|
415 |
+
|
416 |
+
self.convs = nn.ModuleList()
|
417 |
+
self.upsamples = nn.ModuleList()
|
418 |
+
self.to_rgbs = nn.ModuleList()
|
419 |
+
self.noises = nn.Module()
|
420 |
+
|
421 |
+
in_channel = self.channels[4]
|
422 |
+
|
423 |
+
for layer_idx in range(self.num_layers):
|
424 |
+
res = (layer_idx + 5) // 2
|
425 |
+
shape = [1, 1, 2 ** res, 2 ** res]
|
426 |
+
self.noises.register_buffer(f'noise_{layer_idx}', torch.randn(*shape))
|
427 |
+
|
428 |
+
for i in range(3, self.log_size + 1):
|
429 |
+
out_channel = self.channels[2 ** i]
|
430 |
+
|
431 |
+
self.convs.append(
|
432 |
+
StyledConv(
|
433 |
+
in_channel,
|
434 |
+
out_channel,
|
435 |
+
3,
|
436 |
+
style_dim,
|
437 |
+
upsample=True,
|
438 |
+
blur_kernel=blur_kernel,
|
439 |
+
)
|
440 |
+
)
|
441 |
+
|
442 |
+
self.convs.append(
|
443 |
+
StyledConv(
|
444 |
+
out_channel, out_channel, 3, style_dim, blur_kernel=blur_kernel
|
445 |
+
)
|
446 |
+
)
|
447 |
+
|
448 |
+
self.to_rgbs.append(ToRGB(out_channel, style_dim))
|
449 |
+
|
450 |
+
in_channel = out_channel
|
451 |
+
|
452 |
+
self.n_latent = self.log_size * 2 - 2
|
453 |
+
|
454 |
+
@property
|
455 |
+
def device(self):
|
456 |
+
# TODO if multi-gpu is expected, could use the following more expensive version
|
457 |
+
#device, = list(set(p.device for p in self.parameters()))
|
458 |
+
return next(self.parameters()).device
|
459 |
+
|
460 |
+
@staticmethod
|
461 |
+
def get_latent_size(size):
|
462 |
+
log_size = int(math.log(size, 2))
|
463 |
+
return log_size * 2 - 2
|
464 |
+
|
465 |
+
@staticmethod
|
466 |
+
def make_noise_by_size(size: int, device: torch.device):
|
467 |
+
log_size = int(math.log(size, 2))
|
468 |
+
noises = [torch.randn(1, 1, 2 ** 2, 2 ** 2, device=device)]
|
469 |
+
|
470 |
+
for i in range(3, log_size + 1):
|
471 |
+
for _ in range(2):
|
472 |
+
noises.append(torch.randn(1, 1, 2 ** i, 2 ** i, device=device))
|
473 |
+
|
474 |
+
return noises
|
475 |
+
|
476 |
+
|
477 |
+
def make_noise(self):
|
478 |
+
return self.make_noise_by_size(self.size, self.input.input.device)
|
479 |
+
|
480 |
+
def mean_latent(self, n_latent):
|
481 |
+
latent_in = torch.randn(
|
482 |
+
n_latent, self.style_dim, device=self.input.input.device
|
483 |
+
)
|
484 |
+
latent = self.style(latent_in).mean(0, keepdim=True)
|
485 |
+
|
486 |
+
return latent
|
487 |
+
|
488 |
+
def get_latent(self, input):
|
489 |
+
return self.style(input)
|
490 |
+
|
491 |
+
def forward(
|
492 |
+
self,
|
493 |
+
styles,
|
494 |
+
return_latents=False,
|
495 |
+
inject_index=None,
|
496 |
+
truncation=1,
|
497 |
+
truncation_latent=None,
|
498 |
+
input_is_latent=False,
|
499 |
+
noise=None,
|
500 |
+
randomize_noise=True,
|
501 |
+
):
|
502 |
+
if not input_is_latent:
|
503 |
+
styles = [self.style(s) for s in styles]
|
504 |
+
|
505 |
+
if noise is None:
|
506 |
+
if randomize_noise:
|
507 |
+
noise = [None] * self.num_layers
|
508 |
+
else:
|
509 |
+
noise = [
|
510 |
+
getattr(self.noises, f'noise_{i}') for i in range(self.num_layers)
|
511 |
+
]
|
512 |
+
|
513 |
+
if truncation < 1:
|
514 |
+
style_t = []
|
515 |
+
|
516 |
+
for style in styles:
|
517 |
+
style_t.append(
|
518 |
+
truncation_latent + truncation * (style - truncation_latent)
|
519 |
+
)
|
520 |
+
|
521 |
+
styles = style_t
|
522 |
+
|
523 |
+
if len(styles) < 2:
|
524 |
+
inject_index = self.n_latent
|
525 |
+
|
526 |
+
if styles[0].ndim < 3:
|
527 |
+
latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
|
528 |
+
|
529 |
+
else:
|
530 |
+
latent = styles[0]
|
531 |
+
|
532 |
+
else:
|
533 |
+
if inject_index is None:
|
534 |
+
inject_index = random.randint(1, self.n_latent - 1)
|
535 |
+
|
536 |
+
latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
|
537 |
+
latent2 = styles[1].unsqueeze(1).repeat(1, self.n_latent - inject_index, 1)
|
538 |
+
|
539 |
+
latent = torch.cat([latent, latent2], 1)
|
540 |
+
|
541 |
+
out = self.input(latent)
|
542 |
+
out = self.conv1(out, latent[:, 0], noise=noise[0])
|
543 |
+
|
544 |
+
skip, rgb_mod = self.to_rgb1(out, latent[:, 1])
|
545 |
+
|
546 |
+
|
547 |
+
rgbs = [rgb_mod] # all but the last skip
|
548 |
+
i = 1
|
549 |
+
for conv1, conv2, noise1, noise2, to_rgb in zip(
|
550 |
+
self.convs[::2], self.convs[1::2], noise[1::2], noise[2::2], self.to_rgbs
|
551 |
+
):
|
552 |
+
out = conv1(out, latent[:, i], noise=noise1)
|
553 |
+
out = conv2(out, latent[:, i + 1], noise=noise2)
|
554 |
+
skip, rgb_mod = to_rgb(out, latent[:, i + 2], skip)
|
555 |
+
rgbs.append(rgb_mod)
|
556 |
+
|
557 |
+
i += 2
|
558 |
+
|
559 |
+
image = skip
|
560 |
+
|
561 |
+
if return_latents:
|
562 |
+
return image, latent, rgbs
|
563 |
+
|
564 |
+
else:
|
565 |
+
return image, None, rgbs
|
566 |
+
|
567 |
+
|
568 |
+
class ConvLayer(nn.Sequential):
|
569 |
+
def __init__(
|
570 |
+
self,
|
571 |
+
in_channel,
|
572 |
+
out_channel,
|
573 |
+
kernel_size,
|
574 |
+
downsample=False,
|
575 |
+
blur_kernel=[1, 3, 3, 1],
|
576 |
+
bias=True,
|
577 |
+
activate=True,
|
578 |
+
):
|
579 |
+
layers = []
|
580 |
+
|
581 |
+
if downsample:
|
582 |
+
factor = 2
|
583 |
+
p = (len(blur_kernel) - factor) + (kernel_size - 1)
|
584 |
+
pad0 = (p + 1) // 2
|
585 |
+
pad1 = p // 2
|
586 |
+
|
587 |
+
layers.append(Blur(blur_kernel, pad=(pad0, pad1)))
|
588 |
+
|
589 |
+
stride = 2
|
590 |
+
self.padding = 0
|
591 |
+
|
592 |
+
else:
|
593 |
+
stride = 1
|
594 |
+
self.padding = kernel_size // 2
|
595 |
+
|
596 |
+
layers.append(
|
597 |
+
EqualConv2d(
|
598 |
+
in_channel,
|
599 |
+
out_channel,
|
600 |
+
kernel_size,
|
601 |
+
padding=self.padding,
|
602 |
+
stride=stride,
|
603 |
+
bias=bias and not activate,
|
604 |
+
)
|
605 |
+
)
|
606 |
+
|
607 |
+
if activate:
|
608 |
+
if bias:
|
609 |
+
layers.append(FusedLeakyReLU(out_channel))
|
610 |
+
|
611 |
+
else:
|
612 |
+
layers.append(ScaledLeakyReLU(0.2))
|
613 |
+
|
614 |
+
super().__init__(*layers)
|
615 |
+
|
616 |
+
|
617 |
+
class ResBlock(nn.Module):
|
618 |
+
def __init__(self, in_channel, out_channel, blur_kernel=[1, 3, 3, 1]):
|
619 |
+
super().__init__()
|
620 |
+
|
621 |
+
self.conv1 = ConvLayer(in_channel, in_channel, 3)
|
622 |
+
self.conv2 = ConvLayer(in_channel, out_channel, 3, downsample=True)
|
623 |
+
|
624 |
+
self.skip = ConvLayer(
|
625 |
+
in_channel, out_channel, 1, downsample=True, activate=False, bias=False
|
626 |
+
)
|
627 |
+
|
628 |
+
def forward(self, input):
|
629 |
+
out = self.conv1(input)
|
630 |
+
out = self.conv2(out)
|
631 |
+
|
632 |
+
skip = self.skip(input)
|
633 |
+
out = (out + skip) / math.sqrt(2)
|
634 |
+
|
635 |
+
return out
|
636 |
+
|
637 |
+
|
638 |
+
class Discriminator(nn.Module):
|
639 |
+
def __init__(self, size, channel_multiplier=2, blur_kernel=[1, 3, 3, 1]):
|
640 |
+
super().__init__()
|
641 |
+
|
642 |
+
channels = {
|
643 |
+
4: 512,
|
644 |
+
8: 512,
|
645 |
+
16: 512,
|
646 |
+
32: 512,
|
647 |
+
64: 256 * channel_multiplier,
|
648 |
+
128: 128 * channel_multiplier,
|
649 |
+
256: 64 * channel_multiplier,
|
650 |
+
512: 32 * channel_multiplier,
|
651 |
+
1024: 16 * channel_multiplier,
|
652 |
+
}
|
653 |
+
|
654 |
+
convs = [ConvLayer(3, channels[size], 1)]
|
655 |
+
|
656 |
+
log_size = int(math.log(size, 2))
|
657 |
+
|
658 |
+
in_channel = channels[size]
|
659 |
+
|
660 |
+
for i in range(log_size, 2, -1):
|
661 |
+
out_channel = channels[2 ** (i - 1)]
|
662 |
+
|
663 |
+
convs.append(ResBlock(in_channel, out_channel, blur_kernel))
|
664 |
+
|
665 |
+
in_channel = out_channel
|
666 |
+
|
667 |
+
self.convs = nn.Sequential(*convs)
|
668 |
+
|
669 |
+
self.stddev_group = 4
|
670 |
+
self.stddev_feat = 1
|
671 |
+
|
672 |
+
self.final_conv = ConvLayer(in_channel + 1, channels[4], 3)
|
673 |
+
self.final_linear = nn.Sequential(
|
674 |
+
EqualLinear(channels[4] * 4 * 4, channels[4], activation='fused_lrelu'),
|
675 |
+
EqualLinear(channels[4], 1),
|
676 |
+
)
|
677 |
+
|
678 |
+
def forward(self, input):
|
679 |
+
out = self.convs(input)
|
680 |
+
|
681 |
+
batch, channel, height, width = out.shape
|
682 |
+
group = min(batch, self.stddev_group)
|
683 |
+
stddev = out.view(
|
684 |
+
group, -1, self.stddev_feat, channel // self.stddev_feat, height, width
|
685 |
+
)
|
686 |
+
stddev = torch.sqrt(stddev.var(0, unbiased=False) + 1e-8)
|
687 |
+
stddev = stddev.mean([2, 3, 4], keepdims=True).squeeze(2)
|
688 |
+
stddev = stddev.repeat(group, 1, height, width)
|
689 |
+
out = torch.cat([out, stddev], 1)
|
690 |
+
|
691 |
+
out = self.final_conv(out)
|
692 |
+
|
693 |
+
out = out.view(batch, -1)
|
694 |
+
out = self.final_linear(out)
|
695 |
+
|
696 |
+
return out
|
697 |
+
|
Time_TravelRephotography/models/__init__.py
ADDED
File without changes
|
Time_TravelRephotography/models/degrade.py
ADDED
@@ -0,0 +1,122 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from argparse import (
|
2 |
+
ArgumentParser,
|
3 |
+
Namespace,
|
4 |
+
)
|
5 |
+
|
6 |
+
import torch
|
7 |
+
from torch import nn
|
8 |
+
from torch.nn import functional as F
|
9 |
+
|
10 |
+
from utils.misc import optional_string
|
11 |
+
|
12 |
+
from .gaussian_smoothing import GaussianSmoothing
|
13 |
+
|
14 |
+
|
15 |
+
class DegradeArguments:
|
16 |
+
@staticmethod
|
17 |
+
def add_arguments(parser: ArgumentParser):
|
18 |
+
parser.add_argument('--spectral_sensitivity', choices=["g", "b", "gb"], default="g",
|
19 |
+
help="Type of spectral sensitivity. g: grayscale (panchromatic), b: blue-sensitive, gb: green+blue (orthochromatic)")
|
20 |
+
parser.add_argument('--gaussian', type=float, default=0,
|
21 |
+
help="estimated blur radius in pixels of the input photo if it is scaled to 1024x1024")
|
22 |
+
|
23 |
+
@staticmethod
|
24 |
+
def to_string(args: Namespace) -> str:
|
25 |
+
return (
|
26 |
+
f"{args.spectral_sensitivity}"
|
27 |
+
+ optional_string(args.gaussian > 0, f"-G{args.gaussian}")
|
28 |
+
)
|
29 |
+
|
30 |
+
|
31 |
+
class CameraResponse(nn.Module):
|
32 |
+
def __init__(self):
|
33 |
+
super().__init__()
|
34 |
+
|
35 |
+
self.register_parameter("gamma", nn.Parameter(torch.ones(1)))
|
36 |
+
self.register_parameter("offset", nn.Parameter(torch.zeros(1)))
|
37 |
+
self.register_parameter("gain", nn.Parameter(torch.ones(1)))
|
38 |
+
|
39 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
40 |
+
x = torch.clamp(x, max=1, min=-1+1e-2)
|
41 |
+
x = (1 + x) * 0.5
|
42 |
+
x = self.offset + self.gain * torch.pow(x, self.gamma)
|
43 |
+
x = (x - 0.5) * 2
|
44 |
+
# b = torch.clamp(b, max=1, min=-1)
|
45 |
+
return x
|
46 |
+
|
47 |
+
|
48 |
+
class SpectralResponse(nn.Module):
|
49 |
+
# TODO: use enum instead for color mode
|
50 |
+
def __init__(self, spectral_sensitivity: str = 'b'):
|
51 |
+
assert spectral_sensitivity in ("g", "b", "gb"), f"spectral_sensitivity {spectral_sensitivity} is not implemented."
|
52 |
+
|
53 |
+
super().__init__()
|
54 |
+
|
55 |
+
self.spectral_sensitivity = spectral_sensitivity
|
56 |
+
|
57 |
+
if self.spectral_sensitivity == "g":
|
58 |
+
self.register_buffer("to_gray", torch.tensor([0.299, 0.587, 0.114]).reshape(1, -1, 1, 1))
|
59 |
+
|
60 |
+
def forward(self, rgb: torch.Tensor) -> torch.Tensor:
|
61 |
+
if self.spectral_sensitivity == "b":
|
62 |
+
x = rgb[:, -1:]
|
63 |
+
elif self.spectral_sensitivity == "gb":
|
64 |
+
x = (rgb[:, 1:2] + rgb[:, -1:]) * 0.5
|
65 |
+
else:
|
66 |
+
assert self.spectral_sensitivity == "g"
|
67 |
+
x = (rgb * self.to_gray).sum(dim=1, keepdim=True)
|
68 |
+
return x
|
69 |
+
|
70 |
+
|
71 |
+
class Downsample(nn.Module):
|
72 |
+
"""Antialiasing downsampling"""
|
73 |
+
def __init__(self, input_size: int, output_size: int, channels: int):
|
74 |
+
super().__init__()
|
75 |
+
if input_size % output_size == 0:
|
76 |
+
self.stride = input_size // output_size
|
77 |
+
self.grid = None
|
78 |
+
else:
|
79 |
+
self.stride = 1
|
80 |
+
step = input_size / output_size
|
81 |
+
x = torch.arange(output_size) * step
|
82 |
+
Y, X = torch.meshgrid(x, x)
|
83 |
+
grid = torch.stack((X, Y), dim=-1)
|
84 |
+
grid /= torch.Tensor((input_size - 1, input_size - 1)).view(1, 1, -1)
|
85 |
+
grid = grid * 2 - 1
|
86 |
+
self.register_buffer("grid", grid)
|
87 |
+
sigma = 0.5 * input_size / output_size
|
88 |
+
#print(f"{input_size} -> {output_size}: sigma={sigma}")
|
89 |
+
self.blur = GaussianSmoothing(channels, int(2 * (sigma * 2) + 1 + 0.5), sigma)
|
90 |
+
|
91 |
+
def forward(self, im: torch.Tensor):
|
92 |
+
out = self.blur(im, stride=self.stride)
|
93 |
+
if self.grid is not None:
|
94 |
+
out = F.grid_sample(out, self.grid[None].expand(im.shape[0], -1, -1, -1))
|
95 |
+
return out
|
96 |
+
|
97 |
+
|
98 |
+
|
99 |
+
class Degrade(nn.Module):
|
100 |
+
"""
|
101 |
+
Simulate the degradation of antique film
|
102 |
+
"""
|
103 |
+
def __init__(self, args:Namespace):
|
104 |
+
super().__init__()
|
105 |
+
self.srf = SpectralResponse(args.spectral_sensitivity)
|
106 |
+
self.crf = CameraResponse()
|
107 |
+
self.gaussian = None
|
108 |
+
if args.gaussian is not None and args.gaussian > 0:
|
109 |
+
self.gaussian = GaussianSmoothing(3, 2 * int(args.gaussian * 2 + 0.5) + 1, args.gaussian)
|
110 |
+
|
111 |
+
def forward(self, img: torch.Tensor, downsample: nn.Module = None):
|
112 |
+
if self.gaussian is not None:
|
113 |
+
img = self.gaussian(img)
|
114 |
+
if downsample is not None:
|
115 |
+
img = downsample(img)
|
116 |
+
img = self.srf(img)
|
117 |
+
img = self.crf(img)
|
118 |
+
# Note that I changed it back to 3 channels
|
119 |
+
return img.repeat((1, 3, 1, 1)) if img.shape[1] == 1 else img
|
120 |
+
|
121 |
+
|
122 |
+
|
Time_TravelRephotography/models/encoder.py
ADDED
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from argparse import Namespace, ArgumentParser
|
2 |
+
from functools import partial
|
3 |
+
|
4 |
+
from torch import nn
|
5 |
+
|
6 |
+
from .resnet import ResNetBasicBlock, activation_func, norm_module, Conv2dAuto
|
7 |
+
|
8 |
+
|
9 |
+
def add_arguments(parser: ArgumentParser) -> ArgumentParser:
|
10 |
+
parser.add_argument("--latent_size", type=int, default=512, help="latent size")
|
11 |
+
return parser
|
12 |
+
|
13 |
+
|
14 |
+
def create_model(args) -> nn.Module:
|
15 |
+
in_channels = 3 if "rgb" in args and args.rgb else 1
|
16 |
+
return Encoder(in_channels, args.encoder_size, latent_size=args.latent_size)
|
17 |
+
|
18 |
+
|
19 |
+
class Flatten(nn.Module):
|
20 |
+
def forward(self, input_):
|
21 |
+
return input_.view(input_.size(0), -1)
|
22 |
+
|
23 |
+
|
24 |
+
class Encoder(nn.Module):
|
25 |
+
def __init__(
|
26 |
+
self, in_channels: int, size: int, latent_size: int = 512,
|
27 |
+
activation: str = 'leaky_relu', norm: str = "instance"
|
28 |
+
):
|
29 |
+
super().__init__()
|
30 |
+
|
31 |
+
out_channels0 = 64
|
32 |
+
norm_m = norm_module(norm)
|
33 |
+
self.conv0 = nn.Sequential(
|
34 |
+
Conv2dAuto(in_channels, out_channels0, kernel_size=5),
|
35 |
+
norm_m(out_channels0),
|
36 |
+
activation_func(activation),
|
37 |
+
)
|
38 |
+
|
39 |
+
pool_kernel = 2
|
40 |
+
self.pool = nn.AvgPool2d(pool_kernel)
|
41 |
+
|
42 |
+
num_channels = [128, 256, 512, 512]
|
43 |
+
# FIXME: this is a hack
|
44 |
+
if size >= 256:
|
45 |
+
num_channels.append(512)
|
46 |
+
|
47 |
+
residual = partial(ResNetBasicBlock, activation=activation, norm=norm, bias=True)
|
48 |
+
residual_blocks = nn.ModuleList()
|
49 |
+
for in_channel, out_channel in zip([out_channels0] + num_channels[:-1], num_channels):
|
50 |
+
residual_blocks.append(residual(in_channel, out_channel))
|
51 |
+
residual_blocks.append(nn.AvgPool2d(pool_kernel))
|
52 |
+
self.residual_blocks = nn.Sequential(*residual_blocks)
|
53 |
+
|
54 |
+
self.last = nn.Sequential(
|
55 |
+
nn.ReLU(),
|
56 |
+
nn.AvgPool2d(4), # TODO: not sure whehter this would cause problem
|
57 |
+
Flatten(),
|
58 |
+
nn.Linear(num_channels[-1], latent_size, bias=True)
|
59 |
+
)
|
60 |
+
|
61 |
+
def forward(self, input_):
|
62 |
+
out = self.conv0(input_)
|
63 |
+
out = self.pool(out)
|
64 |
+
out = self.residual_blocks(out)
|
65 |
+
out = self.last(out)
|
66 |
+
return out
|
Time_TravelRephotography/models/encoder4editing/.gitignore
ADDED
@@ -0,0 +1,133 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Byte-compiled / optimized / DLL files
|
2 |
+
__pycache__/
|
3 |
+
*.py[cod]
|
4 |
+
*$py.class
|
5 |
+
|
6 |
+
# C extensions
|
7 |
+
*.so
|
8 |
+
|
9 |
+
# Distribution / packaging
|
10 |
+
.Python
|
11 |
+
build/
|
12 |
+
develop-eggs/
|
13 |
+
dist/
|
14 |
+
downloads/
|
15 |
+
eggs/
|
16 |
+
.eggs/
|
17 |
+
lib/
|
18 |
+
lib64/
|
19 |
+
parts/
|
20 |
+
sdist/
|
21 |
+
var/
|
22 |
+
wheels/
|
23 |
+
pip-wheel-metadata/
|
24 |
+
share/python-wheels/
|
25 |
+
*.egg-info/
|
26 |
+
.installed.cfg
|
27 |
+
*.egg
|
28 |
+
MANIFEST
|
29 |
+
|
30 |
+
# PyInstaller
|
31 |
+
# Usually these files are written by a python script from a template
|
32 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
33 |
+
*.manifest
|
34 |
+
*.spec
|
35 |
+
|
36 |
+
# Installer logs
|
37 |
+
pip-log.txt
|
38 |
+
pip-delete-this-directory.txt
|
39 |
+
|
40 |
+
# Unit test / coverage reports
|
41 |
+
htmlcov/
|
42 |
+
.tox/
|
43 |
+
.nox/
|
44 |
+
.coverage
|
45 |
+
.coverage.*
|
46 |
+
.cache
|
47 |
+
nosetests.xml
|
48 |
+
coverage.xml
|
49 |
+
*.cover
|
50 |
+
*.py,cover
|
51 |
+
.hypothesis/
|
52 |
+
.pytest_cache/
|
53 |
+
|
54 |
+
# Translations
|
55 |
+
*.mo
|
56 |
+
*.pot
|
57 |
+
|
58 |
+
# Django stuff:
|
59 |
+
*.log
|
60 |
+
local_settings.py
|
61 |
+
db.sqlite3
|
62 |
+
db.sqlite3-journal
|
63 |
+
|
64 |
+
# Flask stuff:
|
65 |
+
instance/
|
66 |
+
.webassets-cache
|
67 |
+
|
68 |
+
# Scrapy stuff:
|
69 |
+
.scrapy
|
70 |
+
|
71 |
+
# Sphinx documentation
|
72 |
+
docs/_build/
|
73 |
+
|
74 |
+
# PyBuilder
|
75 |
+
target/
|
76 |
+
|
77 |
+
# Jupyter Notebook
|
78 |
+
.ipynb_checkpoints
|
79 |
+
|
80 |
+
# IPython
|
81 |
+
profile_default/
|
82 |
+
ipython_config.py
|
83 |
+
|
84 |
+
# pyenv
|
85 |
+
.python-version
|
86 |
+
|
87 |
+
# pipenv
|
88 |
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
89 |
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
90 |
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
91 |
+
# install all needed dependencies.
|
92 |
+
#Pipfile.lock
|
93 |
+
|
94 |
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
|
95 |
+
__pypackages__/
|
96 |
+
|
97 |
+
# Celery stuff
|
98 |
+
celerybeat-schedule
|
99 |
+
celerybeat.pid
|
100 |
+
|
101 |
+
# SageMath parsed files
|
102 |
+
*.sage.py
|
103 |
+
|
104 |
+
# Environments
|
105 |
+
.env
|
106 |
+
.venv
|
107 |
+
env/
|
108 |
+
venv/
|
109 |
+
ENV/
|
110 |
+
env.bak/
|
111 |
+
venv.bak/
|
112 |
+
|
113 |
+
# Spyder project settings
|
114 |
+
.spyderproject
|
115 |
+
.spyproject
|
116 |
+
|
117 |
+
# Rope project settings
|
118 |
+
.ropeproject
|
119 |
+
|
120 |
+
# mkdocs documentation
|
121 |
+
/site
|
122 |
+
|
123 |
+
# mypy
|
124 |
+
.mypy_cache/
|
125 |
+
.dmypy.json
|
126 |
+
dmypy.json
|
127 |
+
|
128 |
+
# Pyre type checker
|
129 |
+
.pyre/
|
130 |
+
|
131 |
+
# Custom dataset
|
132 |
+
pretrained_models
|
133 |
+
results_test
|
Time_TravelRephotography/models/encoder4editing/LICENSE
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
MIT License
|
2 |
+
|
3 |
+
Copyright (c) 2021 omertov
|
4 |
+
|
5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
6 |
+
of this software and associated documentation files (the "Software"), to deal
|
7 |
+
in the Software without restriction, including without limitation the rights
|
8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
9 |
+
copies of the Software, and to permit persons to whom the Software is
|
10 |
+
furnished to do so, subject to the following conditions:
|
11 |
+
|
12 |
+
The above copyright notice and this permission notice shall be included in all
|
13 |
+
copies or substantial portions of the Software.
|
14 |
+
|
15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
21 |
+
SOFTWARE.
|
Time_TravelRephotography/models/encoder4editing/README.md
ADDED
@@ -0,0 +1,143 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Designing an Encoder for StyleGAN Image Manipulation (SIGGRAPH 2021)
|
2 |
+
<a href="https://arxiv.org/abs/2102.02766"><img src="https://img.shields.io/badge/arXiv-2008.00951-b31b1b.svg"></a>
|
3 |
+
<a href="https://opensource.org/licenses/MIT"><img src="https://img.shields.io/badge/License-MIT-yellow.svg"></a>
|
4 |
+
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](http://colab.research.google.com/github/omertov/encoder4editing/blob/main/notebooks/inference_playground.ipynb)
|
5 |
+
|
6 |
+
> Recently, there has been a surge of diverse methods for performing image editing by employing pre-trained unconditional generators. Applying these methods on real images, however, remains a challenge, as it necessarily requires the inversion of the images into their latent space. To successfully invert a real image, one needs to find a latent code that reconstructs the input image accurately, and more importantly, allows for its meaningful manipulation. In this paper, we carefully study the latent space of StyleGAN, the state-of-the-art unconditional generator. We identify and analyze the existence of a distortion-editability tradeoff and a distortion-perception tradeoff within the StyleGAN latent space. We then suggest two principles for designing encoders in a manner that allows one to control the proximity of the inversions to regions that StyleGAN was originally trained on. We present an encoder based on our two principles that is specifically designed for facilitating editing on real images by balancing these tradeoffs. By evaluating its performance qualitatively and quantitatively on numerous challenging domains, including cars and horses, we show that our inversion method, followed by common editing techniques, achieves superior real-image editing quality, with only a small reconstruction accuracy drop.
|
7 |
+
|
8 |
+
<p align="center">
|
9 |
+
<img src="docs/teaser.jpg" width="800px"/>
|
10 |
+
</p>
|
11 |
+
|
12 |
+
## Description
|
13 |
+
Official Implementation of "<a href="https://arxiv.org/abs/2102.02766">Designing an Encoder for StyleGAN Image Manipulation</a>" paper for both training and evaluation.
|
14 |
+
The e4e encoder is specifically designed to complement existing image manipulation techniques performed over StyleGAN's latent space.
|
15 |
+
|
16 |
+
## Recent Updates
|
17 |
+
`2021.08.17`: Add single style code encoder (use `--encoder_type SingleStyleCodeEncoder`). <br />
|
18 |
+
`2021.03.25`: Add pose editing direction.
|
19 |
+
|
20 |
+
## Getting Started
|
21 |
+
### Prerequisites
|
22 |
+
- Linux or macOS
|
23 |
+
- NVIDIA GPU + CUDA CuDNN (CPU may be possible with some modifications, but is not inherently supported)
|
24 |
+
- Python 3
|
25 |
+
|
26 |
+
### Installation
|
27 |
+
- Clone the repository:
|
28 |
+
```
|
29 |
+
git clone https://github.com/omertov/encoder4editing.git
|
30 |
+
cd encoder4editing
|
31 |
+
```
|
32 |
+
- Dependencies:
|
33 |
+
We recommend running this repository using [Anaconda](https://docs.anaconda.com/anaconda/install/).
|
34 |
+
All dependencies for defining the environment are provided in `environment/e4e_env.yaml`.
|
35 |
+
|
36 |
+
### Inference Notebook
|
37 |
+
We provide a Jupyter notebook found in `notebooks/inference_playground.ipynb` that allows one to encode and perform several editings on real images using StyleGAN.
|
38 |
+
|
39 |
+
### Pretrained Models
|
40 |
+
Please download the pre-trained models from the following links. Each e4e model contains the entire pSp framework architecture, including the encoder and decoder weights.
|
41 |
+
| Path | Description
|
42 |
+
| :--- | :----------
|
43 |
+
|[FFHQ Inversion](https://drive.google.com/file/d/1cUv_reLE6k3604or78EranS7XzuVMWeO/view?usp=sharing) | FFHQ e4e encoder.
|
44 |
+
|[Cars Inversion](https://drive.google.com/file/d/17faPqBce2m1AQeLCLHUVXaDfxMRU2QcV/view?usp=sharing) | Cars e4e encoder.
|
45 |
+
|[Horse Inversion](https://drive.google.com/file/d/1TkLLnuX86B_BMo2ocYD0kX9kWh53rUVX/view?usp=sharing) | Horse e4e encoder.
|
46 |
+
|[Church Inversion](https://drive.google.com/file/d/1-L0ZdnQLwtdy6-A_Ccgq5uNJGTqE7qBa/view?usp=sharing) | Church e4e encoder.
|
47 |
+
|
48 |
+
If you wish to use one of the pretrained models for training or inference, you may do so using the flag `--checkpoint_path`.
|
49 |
+
|
50 |
+
In addition, we provide various auxiliary models needed for training your own e4e model from scratch.
|
51 |
+
| Path | Description
|
52 |
+
| :--- | :----------
|
53 |
+
|[FFHQ StyleGAN](https://drive.google.com/file/d/1EM87UquaoQmk17Q8d5kYIAHqu0dkYqdT/view?usp=sharing) | StyleGAN model pretrained on FFHQ taken from [rosinality](https://github.com/rosinality/stylegan2-pytorch) with 1024x1024 output resolution.
|
54 |
+
|[IR-SE50 Model](https://drive.google.com/file/d/1KW7bjndL3QG3sxBbZxreGHigcCCpsDgn/view?usp=sharing) | Pretrained IR-SE50 model taken from [TreB1eN](https://github.com/TreB1eN/InsightFace_Pytorch) for use in our ID loss during training.
|
55 |
+
|[MOCOv2 Model](https://drive.google.com/file/d/18rLcNGdteX5LwT7sv_F7HWr12HpVEzVe/view?usp=sharing) | Pretrained ResNet-50 model trained using MOCOv2 for use in our simmilarity loss for domains other then human faces during training.
|
56 |
+
|
57 |
+
By default, we assume that all auxiliary models are downloaded and saved to the directory `pretrained_models`. However, you may use your own paths by changing the necessary values in `configs/path_configs.py`.
|
58 |
+
|
59 |
+
## Training
|
60 |
+
To train the e4e encoder, make sure the paths to the required models, as well as training and testing data is configured in `configs/path_configs.py` and `configs/data_configs.py`.
|
61 |
+
#### **Training the e4e Encoder**
|
62 |
+
```
|
63 |
+
python scripts/train.py \
|
64 |
+
--dataset_type cars_encode \
|
65 |
+
--exp_dir new/experiment/directory \
|
66 |
+
--start_from_latent_avg \
|
67 |
+
--use_w_pool \
|
68 |
+
--w_discriminator_lambda 0.1 \
|
69 |
+
--progressive_start 20000 \
|
70 |
+
--id_lambda 0.5 \
|
71 |
+
--val_interval 10000 \
|
72 |
+
--max_steps 200000 \
|
73 |
+
--stylegan_size 512 \
|
74 |
+
--stylegan_weights path/to/pretrained/stylegan.pt \
|
75 |
+
--workers 8 \
|
76 |
+
--batch_size 8 \
|
77 |
+
--test_batch_size 4 \
|
78 |
+
--test_workers 4
|
79 |
+
```
|
80 |
+
|
81 |
+
#### Training on your own dataset
|
82 |
+
In order to train the e4e encoder on a custom dataset, perform the following adjustments:
|
83 |
+
1. Insert the paths to your train and test data into the `dataset_paths` variable defined in `configs/paths_config.py`:
|
84 |
+
```
|
85 |
+
dataset_paths = {
|
86 |
+
'my_train_data': '/path/to/train/images/directory',
|
87 |
+
'my_test_data': '/path/to/test/images/directory'
|
88 |
+
}
|
89 |
+
```
|
90 |
+
2. Configure a new dataset under the DATASETS variable defined in `configs/data_configs.py`:
|
91 |
+
```
|
92 |
+
DATASETS = {
|
93 |
+
'my_data_encode': {
|
94 |
+
'transforms': transforms_config.EncodeTransforms,
|
95 |
+
'train_source_root': dataset_paths['my_train_data'],
|
96 |
+
'train_target_root': dataset_paths['my_train_data'],
|
97 |
+
'test_source_root': dataset_paths['my_test_data'],
|
98 |
+
'test_target_root': dataset_paths['my_test_data']
|
99 |
+
}
|
100 |
+
}
|
101 |
+
```
|
102 |
+
Refer to `configs/transforms_config.py` for the transformations applied to the train and test images during training.
|
103 |
+
|
104 |
+
3. Finally, run a training session with `--dataset_type my_data_encode`.
|
105 |
+
|
106 |
+
## Inference
|
107 |
+
Having trained your model, you can use `scripts/inference.py` to apply the model on a set of images.
|
108 |
+
For example,
|
109 |
+
```
|
110 |
+
python scripts/inference.py \
|
111 |
+
--images_dir=/path/to/images/directory \
|
112 |
+
--save_dir=/path/to/saving/directory \
|
113 |
+
path/to/checkpoint.pt
|
114 |
+
```
|
115 |
+
|
116 |
+
## Latent Editing Consistency (LEC)
|
117 |
+
As described in the paper, we suggest a new metric, Latent Editing Consistency (LEC), for evaluating the encoder's
|
118 |
+
performance.
|
119 |
+
We provide an example for calculating the metric over the FFHQ StyleGAN using the aging editing direction in
|
120 |
+
`metrics/LEC.py`.
|
121 |
+
|
122 |
+
To run the example:
|
123 |
+
```
|
124 |
+
cd metrics
|
125 |
+
python LEC.py \
|
126 |
+
--images_dir=/path/to/images/directory \
|
127 |
+
path/to/checkpoint.pt
|
128 |
+
```
|
129 |
+
|
130 |
+
## Acknowledgments
|
131 |
+
This code borrows heavily from [pixel2style2pixel](https://github.com/eladrich/pixel2style2pixel)
|
132 |
+
|
133 |
+
## Citation
|
134 |
+
If you use this code for your research, please cite our paper <a href="https://arxiv.org/abs/2102.02766">Designing an Encoder for StyleGAN Image Manipulation</a>:
|
135 |
+
|
136 |
+
```
|
137 |
+
@article{tov2021designing,
|
138 |
+
title={Designing an Encoder for StyleGAN Image Manipulation},
|
139 |
+
author={Tov, Omer and Alaluf, Yuval and Nitzan, Yotam and Patashnik, Or and Cohen-Or, Daniel},
|
140 |
+
journal={arXiv preprint arXiv:2102.02766},
|
141 |
+
year={2021}
|
142 |
+
}
|
143 |
+
```
|
Time_TravelRephotography/models/encoder4editing/__init__.py
ADDED
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .utils.model_utils import setup_model
|
2 |
+
|
3 |
+
|
4 |
+
def get_latents(net, x, is_cars=False):
|
5 |
+
codes = net.encoder(x)
|
6 |
+
if net.opts.start_from_latent_avg:
|
7 |
+
if codes.ndim == 2:
|
8 |
+
codes = codes + net.latent_avg.repeat(codes.shape[0], 1, 1)[:, 0, :]
|
9 |
+
else:
|
10 |
+
codes = codes + net.latent_avg.repeat(codes.shape[0], 1, 1)
|
11 |
+
if codes.shape[1] == 18 and is_cars:
|
12 |
+
codes = codes[:, :16, :]
|
13 |
+
return codes
|
14 |
+
|
15 |
+
|
Time_TravelRephotography/models/encoder4editing/bash_scripts/inference.sh
ADDED
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
set -exo
|
2 |
+
|
3 |
+
list="$1"
|
4 |
+
ckpt="${2:-pretrained_models/e4e_ffhq_encode.pt}"
|
5 |
+
|
6 |
+
base_dir="$REPHOTO/dataset/historically_interesting/aligned/manual_celebrity_in_19th_century/tier1/${list}/"
|
7 |
+
save_dir="results_test/${list}/"
|
8 |
+
|
9 |
+
|
10 |
+
TORCH_EXTENSIONS_DIR=/tmp/torch_extensions
|
11 |
+
PYTHONPATH="" \
|
12 |
+
python scripts/inference.py \
|
13 |
+
--images_dir="${base_dir}" \
|
14 |
+
--save_dir="${save_dir}" \
|
15 |
+
"${ckpt}"
|
Time_TravelRephotography/models/encoder4editing/configs/__init__.py
ADDED
File without changes
|
Time_TravelRephotography/models/encoder4editing/configs/data_configs.py
ADDED
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from configs import transforms_config
|
2 |
+
from configs.paths_config import dataset_paths
|
3 |
+
|
4 |
+
|
5 |
+
DATASETS = {
|
6 |
+
'ffhq_encode': {
|
7 |
+
'transforms': transforms_config.EncodeTransforms,
|
8 |
+
'train_source_root': dataset_paths['ffhq'],
|
9 |
+
'train_target_root': dataset_paths['ffhq'],
|
10 |
+
'test_source_root': dataset_paths['celeba_test'],
|
11 |
+
'test_target_root': dataset_paths['celeba_test'],
|
12 |
+
},
|
13 |
+
'cars_encode': {
|
14 |
+
'transforms': transforms_config.CarsEncodeTransforms,
|
15 |
+
'train_source_root': dataset_paths['cars_train'],
|
16 |
+
'train_target_root': dataset_paths['cars_train'],
|
17 |
+
'test_source_root': dataset_paths['cars_test'],
|
18 |
+
'test_target_root': dataset_paths['cars_test'],
|
19 |
+
},
|
20 |
+
'horse_encode': {
|
21 |
+
'transforms': transforms_config.EncodeTransforms,
|
22 |
+
'train_source_root': dataset_paths['horse_train'],
|
23 |
+
'train_target_root': dataset_paths['horse_train'],
|
24 |
+
'test_source_root': dataset_paths['horse_test'],
|
25 |
+
'test_target_root': dataset_paths['horse_test'],
|
26 |
+
},
|
27 |
+
'church_encode': {
|
28 |
+
'transforms': transforms_config.EncodeTransforms,
|
29 |
+
'train_source_root': dataset_paths['church_train'],
|
30 |
+
'train_target_root': dataset_paths['church_train'],
|
31 |
+
'test_source_root': dataset_paths['church_test'],
|
32 |
+
'test_target_root': dataset_paths['church_test'],
|
33 |
+
},
|
34 |
+
'cats_encode': {
|
35 |
+
'transforms': transforms_config.EncodeTransforms,
|
36 |
+
'train_source_root': dataset_paths['cats_train'],
|
37 |
+
'train_target_root': dataset_paths['cats_train'],
|
38 |
+
'test_source_root': dataset_paths['cats_test'],
|
39 |
+
'test_target_root': dataset_paths['cats_test'],
|
40 |
+
}
|
41 |
+
}
|
Time_TravelRephotography/models/encoder4editing/configs/paths_config.py
ADDED
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
dataset_paths = {
|
2 |
+
# Face Datasets (In the paper: FFHQ - train, CelebAHQ - test)
|
3 |
+
'ffhq': '',
|
4 |
+
'celeba_test': '',
|
5 |
+
|
6 |
+
# Cars Dataset (In the paper: Stanford cars)
|
7 |
+
'cars_train': '',
|
8 |
+
'cars_test': '',
|
9 |
+
|
10 |
+
# Horse Dataset (In the paper: LSUN Horse)
|
11 |
+
'horse_train': '',
|
12 |
+
'horse_test': '',
|
13 |
+
|
14 |
+
# Church Dataset (In the paper: LSUN Church)
|
15 |
+
'church_train': '',
|
16 |
+
'church_test': '',
|
17 |
+
|
18 |
+
# Cats Dataset (In the paper: LSUN Cat)
|
19 |
+
'cats_train': '',
|
20 |
+
'cats_test': ''
|
21 |
+
}
|
22 |
+
|
23 |
+
model_paths = {
|
24 |
+
'stylegan_ffhq': 'pretrained_models/stylegan2-ffhq-config-f.pt',
|
25 |
+
'ir_se50': 'pretrained_models/model_ir_se50.pth',
|
26 |
+
'shape_predictor': 'pretrained_models/shape_predictor_68_face_landmarks.dat',
|
27 |
+
'moco': 'pretrained_models/moco_v2_800ep_pretrain.pth'
|
28 |
+
}
|
Time_TravelRephotography/models/encoder4editing/configs/transforms_config.py
ADDED
@@ -0,0 +1,62 @@
from abc import abstractmethod
import torchvision.transforms as transforms


class TransformsConfig(object):

    def __init__(self, opts):
        self.opts = opts

    @abstractmethod
    def get_transforms(self):
        pass


class EncodeTransforms(TransformsConfig):

    def __init__(self, opts):
        super(EncodeTransforms, self).__init__(opts)

    def get_transforms(self):
        transforms_dict = {
            'transform_gt_train': transforms.Compose([
                transforms.Resize((256, 256)),
                transforms.RandomHorizontalFlip(0.5),
                transforms.ToTensor(),
                transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]),
            'transform_source': None,
            'transform_test': transforms.Compose([
                transforms.Resize((256, 256)),
                transforms.ToTensor(),
                transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]),
            'transform_inference': transforms.Compose([
                transforms.Resize((256, 256)),
                transforms.ToTensor(),
                transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])
        }
        return transforms_dict


class CarsEncodeTransforms(TransformsConfig):

    def __init__(self, opts):
        super(CarsEncodeTransforms, self).__init__(opts)

    def get_transforms(self):
        transforms_dict = {
            'transform_gt_train': transforms.Compose([
                transforms.Resize((192, 256)),
                transforms.RandomHorizontalFlip(0.5),
                transforms.ToTensor(),
                transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]),
            'transform_source': None,
            'transform_test': transforms.Compose([
                transforms.Resize((192, 256)),
                transforms.ToTensor(),
                transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]),
            'transform_inference': transforms.Compose([
                transforms.Resize((192, 256)),
                transforms.ToTensor(),
                transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])
        }
        return transforms_dict
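As a quick illustration (an assumption-laden sketch, not part of the diff), EncodeTransforms can be instantiated with a minimal opts value and applied to a PIL image; the result is a 3x256x256 tensor normalized to roughly [-1, 1]. The image path is a placeholder.

# Hypothetical usage sketch; run from the encoder4editing root.
from PIL import Image
from configs.transforms_config import EncodeTransforms

transforms_dict = EncodeTransforms(opts=None).get_transforms()   # opts is only stored, so None works
img = Image.open('example.jpg').convert('RGB')                   # placeholder path
tensor = transforms_dict['transform_inference'](img)
print(tensor.shape, tensor.min().item(), tensor.max().item())    # torch.Size([3, 256, 256]), ~-1..1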
Time_TravelRephotography/models/encoder4editing/criteria/__init__.py
ADDED
File without changes
Time_TravelRephotography/models/encoder4editing/criteria/id_loss.py
ADDED
@@ -0,0 +1,47 @@
import torch
from torch import nn
from configs.paths_config import model_paths
from models.encoders.model_irse import Backbone


class IDLoss(nn.Module):
    def __init__(self):
        super(IDLoss, self).__init__()
        print('Loading ResNet ArcFace')
        self.facenet = Backbone(input_size=112, num_layers=50, drop_ratio=0.6, mode='ir_se')
        self.facenet.load_state_dict(torch.load(model_paths['ir_se50']))
        self.face_pool = torch.nn.AdaptiveAvgPool2d((112, 112))
        self.facenet.eval()
        for module in [self.facenet, self.face_pool]:
            for param in module.parameters():
                param.requires_grad = False

    def extract_feats(self, x):
        x = x[:, :, 35:223, 32:220]  # Crop interesting region
        x = self.face_pool(x)
        x_feats = self.facenet(x)
        return x_feats

    def forward(self, y_hat, y, x):
        n_samples = x.shape[0]
        x_feats = self.extract_feats(x)
        y_feats = self.extract_feats(y)  # Otherwise use the feature from there
        y_hat_feats = self.extract_feats(y_hat)
        y_feats = y_feats.detach()
        loss = 0
        sim_improvement = 0
        id_logs = []
        count = 0
        for i in range(n_samples):
            diff_target = y_hat_feats[i].dot(y_feats[i])
            diff_input = y_hat_feats[i].dot(x_feats[i])
            diff_views = y_feats[i].dot(x_feats[i])
            id_logs.append({'diff_target': float(diff_target),
                            'diff_input': float(diff_input),
                            'diff_views': float(diff_views)})
            loss += 1 - diff_target
            id_diff = float(diff_target) - float(diff_views)
            sim_improvement += id_diff
            count += 1

        return loss / count, sim_improvement / count, id_logs
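For context, a hedged sketch of how this loss is typically invoked during encoder training. The ir_se50 weights from model_paths must exist on disk, and the random tensors below stand in for aligned 256x256 face images in [-1, 1]; none of this is prescribed by the file above.

# Hypothetical usage sketch; run from the encoder4editing root.
import torch
from criteria.id_loss import IDLoss

id_loss = IDLoss()                          # loads pretrained_models/model_ir_se50.pth
y_hat = torch.randn(2, 3, 256, 256)         # generator reconstruction (stand-in)
y = torch.randn(2, 3, 256, 256)             # target image (stand-in)
x = torch.randn(2, 3, 256, 256)             # input image (stand-in)
loss, sim_improvement, id_logs = id_loss(y_hat, y, x)
print(float(loss), sim_improvement)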
Time_TravelRephotography/models/encoder4editing/criteria/lpips/__init__.py
ADDED
File without changes
Time_TravelRephotography/models/encoder4editing/criteria/lpips/lpips.py
ADDED
@@ -0,0 +1,35 @@
import torch
import torch.nn as nn

from criteria.lpips.networks import get_network, LinLayers
from criteria.lpips.utils import get_state_dict


class LPIPS(nn.Module):
    r"""Creates a criterion that measures
    Learned Perceptual Image Patch Similarity (LPIPS).
    Arguments:
        net_type (str): the network type to compare the features:
                        'alex' | 'squeeze' | 'vgg'. Default: 'alex'.
        version (str): the version of LPIPS. Default: 0.1.
    """
    def __init__(self, net_type: str = 'alex', version: str = '0.1'):

        assert version in ['0.1'], 'v0.1 is only supported now'

        super(LPIPS, self).__init__()

        # pretrained network
        self.net = get_network(net_type).to("cuda")

        # linear layers
        self.lin = LinLayers(self.net.n_channels_list).to("cuda")
        self.lin.load_state_dict(get_state_dict(net_type, version))

    def forward(self, x: torch.Tensor, y: torch.Tensor):
        feat_x, feat_y = self.net(x), self.net(y)

        diff = [(fx - fy) ** 2 for fx, fy in zip(feat_x, feat_y)]
        res = [l(d).mean((2, 3), True) for d, l in zip(diff, self.lin)]

        return torch.sum(torch.cat(res, 0)) / x.shape[0]
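A short, hedged example of calling this criterion. Note that __init__ hard-codes "cuda" and downloads the linear-layer weights on first use, so a GPU and network access are assumed; the batches of random images are stand-ins.

# Hypothetical usage sketch; requires a CUDA device.
import torch
from criteria.lpips.lpips import LPIPS

lpips = LPIPS(net_type='alex')
x = torch.rand(4, 3, 256, 256, device='cuda') * 2 - 1   # stand-in images in [-1, 1]
y = torch.rand(4, 3, 256, 256, device='cuda') * 2 - 1
print(lpips(x, y).item())                                # mean LPIPS over the batch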
Time_TravelRephotography/models/encoder4editing/criteria/lpips/networks.py
ADDED
@@ -0,0 +1,96 @@
from typing import Sequence

from itertools import chain

import torch
import torch.nn as nn
from torchvision import models

from criteria.lpips.utils import normalize_activation


def get_network(net_type: str):
    if net_type == 'alex':
        return AlexNet()
    elif net_type == 'squeeze':
        return SqueezeNet()
    elif net_type == 'vgg':
        return VGG16()
    else:
        raise NotImplementedError('choose net_type from [alex, squeeze, vgg].')


class LinLayers(nn.ModuleList):
    def __init__(self, n_channels_list: Sequence[int]):
        super(LinLayers, self).__init__([
            nn.Sequential(
                nn.Identity(),
                nn.Conv2d(nc, 1, 1, 1, 0, bias=False)
            ) for nc in n_channels_list
        ])

        for param in self.parameters():
            param.requires_grad = False


class BaseNet(nn.Module):
    def __init__(self):
        super(BaseNet, self).__init__()

        # register buffer
        self.register_buffer(
            'mean', torch.Tensor([-.030, -.088, -.188])[None, :, None, None])
        self.register_buffer(
            'std', torch.Tensor([.458, .448, .450])[None, :, None, None])

    def set_requires_grad(self, state: bool):
        for param in chain(self.parameters(), self.buffers()):
            param.requires_grad = state

    def z_score(self, x: torch.Tensor):
        return (x - self.mean) / self.std

    def forward(self, x: torch.Tensor):
        x = self.z_score(x)

        output = []
        for i, (_, layer) in enumerate(self.layers._modules.items(), 1):
            x = layer(x)
            if i in self.target_layers:
                output.append(normalize_activation(x))
            if len(output) == len(self.target_layers):
                break
        return output


class SqueezeNet(BaseNet):
    def __init__(self):
        super(SqueezeNet, self).__init__()

        self.layers = models.squeezenet1_1(True).features
        self.target_layers = [2, 5, 8, 10, 11, 12, 13]
        self.n_channels_list = [64, 128, 256, 384, 384, 512, 512]

        self.set_requires_grad(False)


class AlexNet(BaseNet):
    def __init__(self):
        super(AlexNet, self).__init__()

        self.layers = models.alexnet(True).features
        self.target_layers = [2, 5, 8, 10, 12]
        self.n_channels_list = [64, 192, 384, 256, 256]

        self.set_requires_grad(False)


class VGG16(BaseNet):
    def __init__(self):
        super(VGG16, self).__init__()

        self.layers = models.vgg16(True).features
        self.target_layers = [4, 9, 16, 23, 30]
        self.n_channels_list = [64, 128, 256, 512, 512]

        self.set_requires_grad(False)
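To make the shapes concrete, a small sketch (an assumption, not part of the commit) of what get_network returns: a frozen backbone whose forward pass yields one normalized activation map per target layer, with channel counts matching n_channels_list.

# Hypothetical sketch; torchvision will download the pretrained AlexNet weights.
import torch
from criteria.lpips.networks import get_network

net = get_network('alex')
feats = net(torch.randn(1, 3, 256, 256))
for f, nc in zip(feats, net.n_channels_list):
    print(tuple(f.shape), 'expected channels:', nc)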
Time_TravelRephotography/models/encoder4editing/criteria/lpips/utils.py
ADDED
@@ -0,0 +1,30 @@
from collections import OrderedDict

import torch


def normalize_activation(x, eps=1e-10):
    norm_factor = torch.sqrt(torch.sum(x ** 2, dim=1, keepdim=True))
    return x / (norm_factor + eps)


def get_state_dict(net_type: str = 'alex', version: str = '0.1'):
    # build url
    url = 'https://raw.githubusercontent.com/richzhang/PerceptualSimilarity/' \
        + f'master/lpips/weights/v{version}/{net_type}.pth'

    # download
    old_state_dict = torch.hub.load_state_dict_from_url(
        url, progress=True,
        map_location=None if torch.cuda.is_available() else torch.device('cpu')
    )

    # rename keys
    new_state_dict = OrderedDict()
    for key, val in old_state_dict.items():
        new_key = key
        new_key = new_key.replace('lin', '')
        new_key = new_key.replace('model.', '')
        new_state_dict[new_key] = val

    return new_state_dict
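The key renaming above adapts the downloaded LPIPS weights to the LinLayers module defined in networks.py. As a hedged illustration (the example key is assumed from the upstream checkpoint format), a key such as 'lin0.model.1.weight' becomes '0.1.weight', which matches how a ModuleList of Sequential(Identity, Conv2d) layers names its parameters.

# Hypothetical illustration of the renaming loop (no download involved).
key = 'lin0.model.1.weight'                              # assumed upstream key format
print(key.replace('lin', '').replace('model.', ''))      # -> '0.1.weight'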
Time_TravelRephotography/models/encoder4editing/criteria/moco_loss.py
ADDED
@@ -0,0 +1,71 @@
import torch
from torch import nn
import torch.nn.functional as F

from configs.paths_config import model_paths


class MocoLoss(nn.Module):

    def __init__(self, opts):
        super(MocoLoss, self).__init__()
        print("Loading MOCO model from path: {}".format(model_paths["moco"]))
        self.model = self.__load_model()
        self.model.eval()
        for param in self.model.parameters():
            param.requires_grad = False

    @staticmethod
    def __load_model():
        import torchvision.models as models
        model = models.__dict__["resnet50"]()
        # freeze all layers but the last fc
        for name, param in model.named_parameters():
            if name not in ['fc.weight', 'fc.bias']:
                param.requires_grad = False
        checkpoint = torch.load(model_paths['moco'], map_location="cpu")
        state_dict = checkpoint['state_dict']
        # rename moco pre-trained keys
        for k in list(state_dict.keys()):
            # retain only encoder_q up to before the embedding layer
            if k.startswith('module.encoder_q') and not k.startswith('module.encoder_q.fc'):
                # remove prefix
                state_dict[k[len("module.encoder_q."):]] = state_dict[k]
            # delete renamed or unused k
            del state_dict[k]
        msg = model.load_state_dict(state_dict, strict=False)
        assert set(msg.missing_keys) == {"fc.weight", "fc.bias"}
        # remove output layer
        model = nn.Sequential(*list(model.children())[:-1]).cuda()
        return model

    def extract_feats(self, x):
        x = F.interpolate(x, size=224)
        x_feats = self.model(x)
        x_feats = nn.functional.normalize(x_feats, dim=1)
        x_feats = x_feats.squeeze()
        return x_feats

    def forward(self, y_hat, y, x):
        n_samples = x.shape[0]
        x_feats = self.extract_feats(x)
        y_feats = self.extract_feats(y)
        y_hat_feats = self.extract_feats(y_hat)
        y_feats = y_feats.detach()
        loss = 0
        sim_improvement = 0
        sim_logs = []
        count = 0
        for i in range(n_samples):
            diff_target = y_hat_feats[i].dot(y_feats[i])
            diff_input = y_hat_feats[i].dot(x_feats[i])
            diff_views = y_feats[i].dot(x_feats[i])
            sim_logs.append({'diff_target': float(diff_target),
                             'diff_input': float(diff_input),
                             'diff_views': float(diff_views)})
            loss += 1 - diff_target
            sim_diff = float(diff_target) - float(diff_views)
            sim_improvement += sim_diff
            count += 1

        return loss / count, sim_improvement / count, sim_logs
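A hedged usage sketch, analogous to the IDLoss example above: the MoCo v2 checkpoint listed in model_paths must be present, and the backbone is moved to CUDA, so a GPU is assumed. The tensor shapes are stand-ins.

# Hypothetical usage sketch; requires model_paths['moco'] and a CUDA device.
import torch
from criteria.moco_loss import MocoLoss

moco_loss = MocoLoss(opts=None)                        # opts is accepted but unused here
y_hat = torch.randn(2, 3, 256, 256, device='cuda')     # stand-in reconstruction
y = torch.randn(2, 3, 256, 256, device='cuda')         # stand-in target
x = torch.randn(2, 3, 256, 256, device='cuda')         # stand-in input
loss, sim_improvement, sim_logs = moco_loss(y_hat, y, x)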
Time_TravelRephotography/models/encoder4editing/criteria/w_norm.py
ADDED
@@ -0,0 +1,14 @@
import torch
from torch import nn


class WNormLoss(nn.Module):

    def __init__(self, start_from_latent_avg=True):
        super(WNormLoss, self).__init__()
        self.start_from_latent_avg = start_from_latent_avg

    def forward(self, latent, latent_avg=None):
        if self.start_from_latent_avg:
            latent = latent - latent_avg
        return torch.sum(latent.norm(2, dim=(1, 2))) / latent.shape[0]
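A minimal sketch of the expected tensor shapes (the 18x512 W+ layout is an assumption for StyleGAN2-FFHQ at 1024; the loss itself only requires that latent_avg broadcast over the batch).

# Hypothetical usage sketch.
import torch
from criteria.w_norm import WNormLoss

w_norm = WNormLoss(start_from_latent_avg=True)
latent = torch.randn(4, 18, 512)      # batch of W+ codes (assumed layout)
latent_avg = torch.randn(18, 512)     # broadcasts over the batch dimension
print(w_norm(latent, latent_avg).item())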
Time_TravelRephotography/models/encoder4editing/datasets/__init__.py
ADDED
File without changes
Time_TravelRephotography/models/encoder4editing/datasets/gt_res_dataset.py
ADDED
@@ -0,0 +1,32 @@
#!/usr/bin/python
# encoding: utf-8
import os
from torch.utils.data import Dataset
from PIL import Image
import torch


class GTResDataset(Dataset):

    def __init__(self, root_path, gt_dir=None, transform=None, transform_train=None):
        self.pairs = []
        for f in os.listdir(root_path):
            image_path = os.path.join(root_path, f)
            gt_path = os.path.join(gt_dir, f)
            if f.endswith(".jpg") or f.endswith(".png"):
                self.pairs.append([image_path, gt_path.replace('.png', '.jpg'), None])
        self.transform = transform
        self.transform_train = transform_train

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, index):
        from_path, to_path, _ = self.pairs[index]
        from_im = Image.open(from_path).convert('RGB')
        to_im = Image.open(to_path).convert('RGB')

        if self.transform:
            to_im = self.transform(to_im)
            from_im = self.transform(from_im)

        return from_im, to_im
Time_TravelRephotography/models/encoder4editing/datasets/images_dataset.py
ADDED
@@ -0,0 +1,33 @@
from torch.utils.data import Dataset
from PIL import Image
from utils import data_utils


class ImagesDataset(Dataset):

    def __init__(self, source_root, target_root, opts, target_transform=None, source_transform=None):
        self.source_paths = sorted(data_utils.make_dataset(source_root))
        self.target_paths = sorted(data_utils.make_dataset(target_root))
        self.source_transform = source_transform
        self.target_transform = target_transform
        self.opts = opts

    def __len__(self):
        return len(self.source_paths)

    def __getitem__(self, index):
        from_path = self.source_paths[index]
        from_im = Image.open(from_path)
        from_im = from_im.convert('RGB')

        to_path = self.target_paths[index]
        to_im = Image.open(to_path).convert('RGB')
        if self.target_transform:
            to_im = self.target_transform(to_im)

        if self.source_transform:
            from_im = self.source_transform(from_im)
        else:
            from_im = to_im

        return from_im, to_im
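For orientation, a hedged sketch of how this dataset is typically wired together with data_configs and a transforms config; the 'horse_encode' key, the minimal Namespace opts, and the DATASETS name are assumptions, and the corresponding paths in configs/paths_config.py must already be filled in.

# Hypothetical usage sketch; run from the encoder4editing root.
from argparse import Namespace
from torch.utils.data import DataLoader
from configs import data_configs
from datasets.images_dataset import ImagesDataset

opts = Namespace(dataset_type='horse_encode')          # assumed minimal opts
cfg = data_configs.DATASETS[opts.dataset_type]         # DATASETS name assumed from upstream e4e
transforms_dict = cfg['transforms'](opts).get_transforms()
train_dataset = ImagesDataset(source_root=cfg['train_source_root'],
                              target_root=cfg['train_target_root'],
                              source_transform=transforms_dict['transform_source'],
                              target_transform=transforms_dict['transform_gt_train'],
                              opts=opts)
loader = DataLoader(train_dataset, batch_size=4, shuffle=True)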
Time_TravelRephotography/models/encoder4editing/datasets/inference_dataset.py
ADDED
@@ -0,0 +1,25 @@
from torch.utils.data import Dataset
from PIL import Image
from utils import data_utils


class InferenceDataset(Dataset):

    def __init__(self, root, opts, transform=None, preprocess=None):
        self.paths = sorted(data_utils.make_dataset(root))
        self.transform = transform
        self.preprocess = preprocess
        self.opts = opts

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, index):
        from_path = self.paths[index]
        if self.preprocess is not None:
            from_im = self.preprocess(from_path)
        else:
            from_im = Image.open(from_path).convert('RGB')
        if self.transform:
            from_im = self.transform(from_im)
        return from_im
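A brief, hedged counterpart for inference: images from a folder are loaded one by one and passed through the inference transform. The folder path and the empty Namespace opts are placeholders.

# Hypothetical usage sketch; run from the encoder4editing root.
from argparse import Namespace
from torch.utils.data import DataLoader
from configs.transforms_config import EncodeTransforms
from datasets.inference_dataset import InferenceDataset

opts = Namespace()                                                    # placeholder opts
transform = EncodeTransforms(opts).get_transforms()['transform_inference']
dataset = InferenceDataset(root='path/to/aligned_images', opts=opts, transform=transform)
loader = DataLoader(dataset, batch_size=1, shuffle=False)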
Time_TravelRephotography/models/encoder4editing/editings/ganspace.py
ADDED
@@ -0,0 +1,22 @@
import torch


def edit(latents, pca, edit_directions):
    edit_latents = []
    for latent in latents:
        for pca_idx, start, end, strength in edit_directions:
            delta = get_delta(pca, latent, pca_idx, strength)
            delta_padded = torch.zeros(latent.shape).to('cuda')
            delta_padded[start:end] += delta.repeat(end - start, 1)
            edit_latents.append(latent + delta_padded)
    return torch.stack(edit_latents)


def get_delta(pca, latent, idx, strength):
    # pca: ganspace checkpoint. latent: (16, 512) w+
    w_centered = latent - pca['mean'].to('cuda')
    lat_comp = pca['comp'].to('cuda')
    lat_std = pca['std'].to('cuda')
    w_coord = torch.sum(w_centered[0].reshape(-1)*lat_comp[idx].reshape(-1)) / lat_std[idx]
    delta = (strength - w_coord)*lat_comp[idx]*lat_std[idx]
    return delta
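A hedged sketch of applying a GANSpace edit with one of the PCA checkpoints added below. The checkpoint is assumed to contain 'mean', 'comp' and 'std' tensors, the 18x512 W+ shape is an assumption for FFHQ at 1024 (the comment above mentions 16 styles), and the direction tuple is illustrative rather than a recommended edit; a CUDA device is required because the module hard-codes 'cuda'.

# Hypothetical usage sketch; requires a CUDA device.
import torch
from editings import ganspace

pca = torch.load('editings/ganspace_pca/ffhq_pca.pt')   # checkpoint added in this commit
latents = torch.randn(2, 18, 512, device='cuda')        # stand-in W+ codes
directions = [(1, 3, 8, 5.0)]                           # (pca index, start layer, end layer, strength)
edited = ganspace.edit(latents, pca, directions)
print(edited.shape)                                     # one edited latent per (latent, direction) pair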
Time_TravelRephotography/models/encoder4editing/editings/ganspace_pca/cars_pca.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:a5c3bae61ecd85de077fbbf103f5f30cf4b7676fe23a8508166eaf2ce73c8392
size 167562
Time_TravelRephotography/models/encoder4editing/editings/ganspace_pca/ffhq_pca.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:4d7f9df1c96180d9026b9cb8d04753579fbf385f321a9d0e263641601c5e5d36
size 167562
Time_TravelRephotography/models/encoder4editing/editings/interfacegan_directions/age.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:50074516b1629707d89b5e43d6b8abd1792212fa3b961a87a11323d6a5222ae0
size 2808