feng2022 commited on
Commit
f9827f9
1 Parent(s): 1c04a5a
.gitignore ADDED
@@ -0,0 +1,133 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ pip-wheel-metadata/
24
+ share/python-wheels/
25
+ *.egg-info/
26
+ .installed.cfg
27
+ *.egg
28
+ MANIFEST
29
+
30
+ # PyInstaller
31
+ # Usually these files are written by a python script from a template
32
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
33
+ *.manifest
34
+ *.spec
35
+
36
+ # Installer logs
37
+ pip-log.txt
38
+ pip-delete-this-directory.txt
39
+
40
+ # Unit test / coverage reports
41
+ htmlcov/
42
+ .tox/
43
+ .nox/
44
+ .coverage
45
+ .coverage.*
46
+ .cache
47
+ nosetests.xml
48
+ coverage.xml
49
+ *.cover
50
+ *.py,cover
51
+ .hypothesis/
52
+ .pytest_cache/
53
+
54
+ # Translations
55
+ *.mo
56
+ *.pot
57
+
58
+ # Django stuff:
59
+ *.log
60
+ local_settings.py
61
+ db.sqlite3
62
+ db.sqlite3-journal
63
+
64
+ # Flask stuff:
65
+ instance/
66
+ .webassets-cache
67
+
68
+ # Scrapy stuff:
69
+ .scrapy
70
+
71
+ # Sphinx documentation
72
+ docs/_build/
73
+
74
+ # PyBuilder
75
+ target/
76
+
77
+ # Jupyter Notebook
78
+ .ipynb_checkpoints
79
+
80
+ # IPython
81
+ profile_default/
82
+ ipython_config.py
83
+
84
+ # pyenv
85
+ .python-version
86
+
87
+ # pipenv
88
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
90
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
91
+ # install all needed dependencies.
92
+ #Pipfile.lock
93
+
94
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95
+ __pypackages__/
96
+
97
+ # Celery stuff
98
+ celerybeat-schedule
99
+ celerybeat.pid
100
+
101
+ # SageMath parsed files
102
+ *.sage.py
103
+
104
+ # Environments
105
+ .env
106
+ .venv
107
+ env/
108
+ venv/
109
+ ENV/
110
+ env.bak/
111
+ venv.bak/
112
+
113
+ # Spyder project settings
114
+ .spyderproject
115
+ .spyproject
116
+
117
+ # Rope project settings
118
+ .ropeproject
119
+
120
+ # mkdocs documentation
121
+ /site
122
+
123
+ # mypy
124
+ .mypy_cache/
125
+ .dmypy.json
126
+ dmypy.json
127
+
128
+ # Pyre type checker
129
+ .pyre/
130
+
131
+ wandb/
132
+ *.lmdb/
133
+ *.pkl
LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2019 Kim Seonghyeon
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
LICENSE-FID ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright [yyyy] [name of copyright owner]
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
LICENSE-LPIPS ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Copyright (c) 2018, Richard Zhang, Phillip Isola, Alexei A. Efros, Eli Shechtman, Oliver Wang
2
+ All rights reserved.
3
+
4
+ Redistribution and use in source and binary forms, with or without
5
+ modification, are permitted provided that the following conditions are met:
6
+
7
+ * Redistributions of source code must retain the above copyright notice, this
8
+ list of conditions and the following disclaimer.
9
+
10
+ * Redistributions in binary form must reproduce the above copyright notice,
11
+ this list of conditions and the following disclaimer in the documentation
12
+ and/or other materials provided with the distribution.
13
+
14
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
15
+ AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
16
+ IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
17
+ DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
18
+ FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
19
+ DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
20
+ SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
21
+ CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
22
+ OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
23
+ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
24
+
LICENSE-NVIDIA ADDED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
2
+
3
+
4
+ Nvidia Source Code License-NC
5
+
6
+ =======================================================================
7
+
8
+ 1. Definitions
9
+
10
+ "Licensor" means any person or entity that distributes its Work.
11
+
12
+ "Software" means the original work of authorship made available under
13
+ this License.
14
+
15
+ "Work" means the Software and any additions to or derivative works of
16
+ the Software that are made available under this License.
17
+
18
+ "Nvidia Processors" means any central processing unit (CPU), graphics
19
+ processing unit (GPU), field-programmable gate array (FPGA),
20
+ application-specific integrated circuit (ASIC) or any combination
21
+ thereof designed, made, sold, or provided by Nvidia or its affiliates.
22
+
23
+ The terms "reproduce," "reproduction," "derivative works," and
24
+ "distribution" have the meaning as provided under U.S. copyright law;
25
+ provided, however, that for the purposes of this License, derivative
26
+ works shall not include works that remain separable from, or merely
27
+ link (or bind by name) to the interfaces of, the Work.
28
+
29
+ Works, including the Software, are "made available" under this License
30
+ by including in or with the Work either (a) a copyright notice
31
+ referencing the applicability of this License to the Work, or (b) a
32
+ copy of this License.
33
+
34
+ 2. License Grants
35
+
36
+ 2.1 Copyright Grant. Subject to the terms and conditions of this
37
+ License, each Licensor grants to you a perpetual, worldwide,
38
+ non-exclusive, royalty-free, copyright license to reproduce,
39
+ prepare derivative works of, publicly display, publicly perform,
40
+ sublicense and distribute its Work and any resulting derivative
41
+ works in any form.
42
+
43
+ 3. Limitations
44
+
45
+ 3.1 Redistribution. You may reproduce or distribute the Work only
46
+ if (a) you do so under this License, (b) you include a complete
47
+ copy of this License with your distribution, and (c) you retain
48
+ without modification any copyright, patent, trademark, or
49
+ attribution notices that are present in the Work.
50
+
51
+ 3.2 Derivative Works. You may specify that additional or different
52
+ terms apply to the use, reproduction, and distribution of your
53
+ derivative works of the Work ("Your Terms") only if (a) Your Terms
54
+ provide that the use limitation in Section 3.3 applies to your
55
+ derivative works, and (b) you identify the specific derivative
56
+ works that are subject to Your Terms. Notwithstanding Your Terms,
57
+ this License (including the redistribution requirements in Section
58
+ 3.1) will continue to apply to the Work itself.
59
+
60
+ 3.3 Use Limitation. The Work and any derivative works thereof only
61
+ may be used or intended for use non-commercially. The Work or
62
+ derivative works thereof may be used or intended for use by Nvidia
63
+ or its affiliates commercially or non-commercially. As used herein,
64
+ "non-commercially" means for research or evaluation purposes only.
65
+
66
+ 3.4 Patent Claims. If you bring or threaten to bring a patent claim
67
+ against any Licensor (including any claim, cross-claim or
68
+ counterclaim in a lawsuit) to enforce any patents that you allege
69
+ are infringed by any Work, then your rights under this License from
70
+ such Licensor (including the grants in Sections 2.1 and 2.2) will
71
+ terminate immediately.
72
+
73
+ 3.5 Trademarks. This License does not grant any rights to use any
74
+ Licensor's or its affiliates' names, logos, or trademarks, except
75
+ as necessary to reproduce the notices described in this License.
76
+
77
+ 3.6 Termination. If you violate any term of this License, then your
78
+ rights under this License (including the grants in Sections 2.1 and
79
+ 2.2) will terminate immediately.
80
+
81
+ 4. Disclaimer of Warranty.
82
+
83
+ THE WORK IS PROVIDED "AS IS" WITHOUT WARRANTIES OR CONDITIONS OF ANY
84
+ KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WARRANTIES OR CONDITIONS OF
85
+ MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE OR
86
+ NON-INFRINGEMENT. YOU BEAR THE RISK OF UNDERTAKING ANY ACTIVITIES UNDER
87
+ THIS LICENSE.
88
+
89
+ 5. Limitation of Liability.
90
+
91
+ EXCEPT AS PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL
92
+ THEORY, WHETHER IN TORT (INCLUDING NEGLIGENCE), CONTRACT, OR OTHERWISE
93
+ SHALL ANY LICENSOR BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY DIRECT,
94
+ INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES ARISING OUT OF
95
+ OR RELATED TO THIS LICENSE, THE USE OR INABILITY TO USE THE WORK
96
+ (INCLUDING BUT NOT LIMITED TO LOSS OF GOODWILL, BUSINESS INTERRUPTION,
97
+ LOST PROFITS OR DATA, COMPUTER FAILURE OR MALFUNCTION, OR ANY OTHER
98
+ COMMERCIAL DAMAGES OR LOSSES), EVEN IF THE LICENSOR HAS BEEN ADVISED OF
99
+ THE POSSIBILITY OF SUCH DAMAGES.
100
+
101
+ =======================================================================
apply_factor.py ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+
3
+ import torch
4
+ from torchvision import utils
5
+
6
+ from model import Generator
7
+
8
+
9
+ if __name__ == "__main__":
10
+ torch.set_grad_enabled(False)
11
+
12
+ parser = argparse.ArgumentParser(description="Apply closed form factorization")
13
+
14
+ parser.add_argument(
15
+ "-i", "--index", type=int, default=0, help="index of eigenvector"
16
+ )
17
+ parser.add_argument(
18
+ "-d",
19
+ "--degree",
20
+ type=float,
21
+ default=5,
22
+ help="scalar factors for moving latent vectors along eigenvector",
23
+ )
24
+ parser.add_argument(
25
+ "--channel_multiplier",
26
+ type=int,
27
+ default=2,
28
+ help='channel multiplier factor. config-f = 2, else = 1',
29
+ )
30
+ parser.add_argument("--ckpt", type=str, required=True, help="stylegan2 checkpoints")
31
+ parser.add_argument(
32
+ "--size", type=int, default=256, help="output image size of the generator"
33
+ )
34
+ parser.add_argument(
35
+ "-n", "--n_sample", type=int, default=7, help="number of samples created"
36
+ )
37
+ parser.add_argument(
38
+ "--truncation", type=float, default=0.7, help="truncation factor"
39
+ )
40
+ parser.add_argument(
41
+ "--device", type=str, default="cuda", help="device to run the model"
42
+ )
43
+ parser.add_argument(
44
+ "--out_prefix",
45
+ type=str,
46
+ default="factor",
47
+ help="filename prefix to result samples",
48
+ )
49
+ parser.add_argument(
50
+ "factor",
51
+ type=str,
52
+ help="name of the closed form factorization result factor file",
53
+ )
54
+
55
+ args = parser.parse_args()
56
+
57
+ eigvec = torch.load(args.factor)["eigvec"].to(args.device)
58
+ ckpt = torch.load(args.ckpt)
59
+ g = Generator(args.size, 512, 8, channel_multiplier=args.channel_multiplier).to(args.device)
60
+ g.load_state_dict(ckpt["g_ema"], strict=False)
61
+
62
+ trunc = g.mean_latent(4096)
63
+
64
+ latent = torch.randn(args.n_sample, 512, device=args.device)
65
+ latent = g.get_latent(latent)
66
+
67
+ direction = args.degree * eigvec[:, args.index].unsqueeze(0)
68
+
69
+ img, _ = g(
70
+ [latent],
71
+ truncation=args.truncation,
72
+ truncation_latent=trunc,
73
+ input_is_latent=True,
74
+ )
75
+ img1, _ = g(
76
+ [latent + direction],
77
+ truncation=args.truncation,
78
+ truncation_latent=trunc,
79
+ input_is_latent=True,
80
+ )
81
+ img2, _ = g(
82
+ [latent - direction],
83
+ truncation=args.truncation,
84
+ truncation_latent=trunc,
85
+ input_is_latent=True,
86
+ )
87
+
88
+ grid = utils.save_image(
89
+ torch.cat([img1, img, img2], 0),
90
+ f"{args.out_prefix}_index-{args.index}_degree-{args.degree}.png",
91
+ normalize=True,
92
+ range=(-1, 1),
93
+ nrow=args.n_sample,
94
+ )
calc_inception.py ADDED
@@ -0,0 +1,130 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import pickle
3
+ import os
4
+
5
+ import torch
6
+ from torch import nn
7
+ from torch.nn import functional as F
8
+ from torch.utils.data import DataLoader
9
+ from torchvision import transforms
10
+ from torchvision.models import inception_v3, Inception3
11
+ import numpy as np
12
+ from tqdm import tqdm
13
+
14
+ from inception import InceptionV3
15
+ from dataset import MultiResolutionDataset
16
+
17
+
18
+ class Inception3Feature(Inception3):
19
+ def forward(self, x):
20
+ if x.shape[2] != 299 or x.shape[3] != 299:
21
+ x = F.interpolate(x, size=(299, 299), mode="bilinear", align_corners=True)
22
+
23
+ x = self.Conv2d_1a_3x3(x) # 299 x 299 x 3
24
+ x = self.Conv2d_2a_3x3(x) # 149 x 149 x 32
25
+ x = self.Conv2d_2b_3x3(x) # 147 x 147 x 32
26
+ x = F.max_pool2d(x, kernel_size=3, stride=2) # 147 x 147 x 64
27
+
28
+ x = self.Conv2d_3b_1x1(x) # 73 x 73 x 64
29
+ x = self.Conv2d_4a_3x3(x) # 73 x 73 x 80
30
+ x = F.max_pool2d(x, kernel_size=3, stride=2) # 71 x 71 x 192
31
+
32
+ x = self.Mixed_5b(x) # 35 x 35 x 192
33
+ x = self.Mixed_5c(x) # 35 x 35 x 256
34
+ x = self.Mixed_5d(x) # 35 x 35 x 288
35
+
36
+ x = self.Mixed_6a(x) # 35 x 35 x 288
37
+ x = self.Mixed_6b(x) # 17 x 17 x 768
38
+ x = self.Mixed_6c(x) # 17 x 17 x 768
39
+ x = self.Mixed_6d(x) # 17 x 17 x 768
40
+ x = self.Mixed_6e(x) # 17 x 17 x 768
41
+
42
+ x = self.Mixed_7a(x) # 17 x 17 x 768
43
+ x = self.Mixed_7b(x) # 8 x 8 x 1280
44
+ x = self.Mixed_7c(x) # 8 x 8 x 2048
45
+
46
+ x = F.avg_pool2d(x, kernel_size=8) # 8 x 8 x 2048
47
+
48
+ return x.view(x.shape[0], x.shape[1]) # 1 x 1 x 2048
49
+
50
+
51
+ def load_patched_inception_v3():
52
+ # inception = inception_v3(pretrained=True)
53
+ # inception_feat = Inception3Feature()
54
+ # inception_feat.load_state_dict(inception.state_dict())
55
+ inception_feat = InceptionV3([3], normalize_input=False)
56
+
57
+ return inception_feat
58
+
59
+
60
+ @torch.no_grad()
61
+ def extract_features(loader, inception, device):
62
+ pbar = tqdm(loader)
63
+
64
+ feature_list = []
65
+
66
+ for img in pbar:
67
+ img = img.to(device)
68
+ feature = inception(img)[0].view(img.shape[0], -1)
69
+ feature_list.append(feature.to("cpu"))
70
+
71
+ features = torch.cat(feature_list, 0)
72
+
73
+ return features
74
+
75
+
76
+ if __name__ == "__main__":
77
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
78
+
79
+ parser = argparse.ArgumentParser(
80
+ description="Calculate Inception v3 features for datasets"
81
+ )
82
+ parser.add_argument(
83
+ "--size",
84
+ type=int,
85
+ default=256,
86
+ help="image sizes used for embedding calculation",
87
+ )
88
+ parser.add_argument(
89
+ "--batch", default=64, type=int, help="batch size for inception networks"
90
+ )
91
+ parser.add_argument(
92
+ "--n_sample",
93
+ type=int,
94
+ default=50000,
95
+ help="number of samples used for embedding calculation",
96
+ )
97
+ parser.add_argument(
98
+ "--flip", action="store_true", help="apply random flipping to real images"
99
+ )
100
+ parser.add_argument("path", metavar="PATH", help="path to datset lmdb file")
101
+
102
+ args = parser.parse_args()
103
+
104
+ inception = load_patched_inception_v3()
105
+ inception = nn.DataParallel(inception).eval().to(device)
106
+
107
+ transform = transforms.Compose(
108
+ [
109
+ transforms.RandomHorizontalFlip(p=0.5 if args.flip else 0),
110
+ transforms.ToTensor(),
111
+ transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
112
+ ]
113
+ )
114
+
115
+ dset = MultiResolutionDataset(args.path, transform=transform, resolution=args.size)
116
+ loader = DataLoader(dset, batch_size=args.batch, num_workers=4)
117
+
118
+ features = extract_features(loader, inception, device).numpy()
119
+
120
+ features = features[: args.n_sample]
121
+
122
+ print(f"extracted {features.shape[0]} features")
123
+
124
+ mean = np.mean(features, 0)
125
+ cov = np.cov(features, rowvar=False)
126
+
127
+ name = os.path.splitext(os.path.basename(args.path))[0]
128
+
129
+ with open(f"inception_{name}.pkl", "wb") as f:
130
+ pickle.dump({"mean": mean, "cov": cov, "size": args.size, "path": args.path}, f)
checkpoint/.gitignore ADDED
@@ -0,0 +1 @@
 
 
1
+ *.pt
closed_form_factorization.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+
3
+ import torch
4
+
5
+
6
+ if __name__ == "__main__":
7
+ parser = argparse.ArgumentParser(
8
+ description="Extract factor/eigenvectors of latent spaces using closed form factorization"
9
+ )
10
+
11
+ parser.add_argument(
12
+ "--out", type=str, default="factor.pt", help="name of the result factor file"
13
+ )
14
+ parser.add_argument("ckpt", type=str, help="name of the model checkpoint")
15
+
16
+ args = parser.parse_args()
17
+
18
+ ckpt = torch.load(args.ckpt)
19
+ modulate = {
20
+ k: v
21
+ for k, v in ckpt["g_ema"].items()
22
+ if "modulation" in k and "to_rgbs" not in k and "weight" in k
23
+ }
24
+
25
+ weight_mat = []
26
+ for k, v in modulate.items():
27
+ weight_mat.append(v)
28
+
29
+ W = torch.cat(weight_mat, 0)
30
+ eigvec = torch.svd(W).V.to("cpu")
31
+
32
+ torch.save({"ckpt": args.ckpt, "eigvec": eigvec}, args.out)
33
+
convert_weight.py ADDED
@@ -0,0 +1,301 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import os
3
+ import sys
4
+ import pickle
5
+ import math
6
+
7
+ import torch
8
+ import numpy as np
9
+ from torchvision import utils
10
+
11
+ from model import Generator, Discriminator
12
+
13
+
14
+ def convert_modconv(vars, source_name, target_name, flip=False):
15
+ weight = vars[source_name + "/weight"].value().eval()
16
+ mod_weight = vars[source_name + "/mod_weight"].value().eval()
17
+ mod_bias = vars[source_name + "/mod_bias"].value().eval()
18
+ noise = vars[source_name + "/noise_strength"].value().eval()
19
+ bias = vars[source_name + "/bias"].value().eval()
20
+
21
+ dic = {
22
+ "conv.weight": np.expand_dims(weight.transpose((3, 2, 0, 1)), 0),
23
+ "conv.modulation.weight": mod_weight.transpose((1, 0)),
24
+ "conv.modulation.bias": mod_bias + 1,
25
+ "noise.weight": np.array([noise]),
26
+ "activate.bias": bias,
27
+ }
28
+
29
+ dic_torch = {}
30
+
31
+ for k, v in dic.items():
32
+ dic_torch[target_name + "." + k] = torch.from_numpy(v)
33
+
34
+ if flip:
35
+ dic_torch[target_name + ".conv.weight"] = torch.flip(
36
+ dic_torch[target_name + ".conv.weight"], [3, 4]
37
+ )
38
+
39
+ return dic_torch
40
+
41
+
42
+ def convert_conv(vars, source_name, target_name, bias=True, start=0):
43
+ weight = vars[source_name + "/weight"].value().eval()
44
+
45
+ dic = {"weight": weight.transpose((3, 2, 0, 1))}
46
+
47
+ if bias:
48
+ dic["bias"] = vars[source_name + "/bias"].value().eval()
49
+
50
+ dic_torch = {}
51
+
52
+ dic_torch[target_name + f".{start}.weight"] = torch.from_numpy(dic["weight"])
53
+
54
+ if bias:
55
+ dic_torch[target_name + f".{start + 1}.bias"] = torch.from_numpy(dic["bias"])
56
+
57
+ return dic_torch
58
+
59
+
60
+ def convert_torgb(vars, source_name, target_name):
61
+ weight = vars[source_name + "/weight"].value().eval()
62
+ mod_weight = vars[source_name + "/mod_weight"].value().eval()
63
+ mod_bias = vars[source_name + "/mod_bias"].value().eval()
64
+ bias = vars[source_name + "/bias"].value().eval()
65
+
66
+ dic = {
67
+ "conv.weight": np.expand_dims(weight.transpose((3, 2, 0, 1)), 0),
68
+ "conv.modulation.weight": mod_weight.transpose((1, 0)),
69
+ "conv.modulation.bias": mod_bias + 1,
70
+ "bias": bias.reshape((1, 3, 1, 1)),
71
+ }
72
+
73
+ dic_torch = {}
74
+
75
+ for k, v in dic.items():
76
+ dic_torch[target_name + "." + k] = torch.from_numpy(v)
77
+
78
+ return dic_torch
79
+
80
+
81
+ def convert_dense(vars, source_name, target_name):
82
+ weight = vars[source_name + "/weight"].value().eval()
83
+ bias = vars[source_name + "/bias"].value().eval()
84
+
85
+ dic = {"weight": weight.transpose((1, 0)), "bias": bias}
86
+
87
+ dic_torch = {}
88
+
89
+ for k, v in dic.items():
90
+ dic_torch[target_name + "." + k] = torch.from_numpy(v)
91
+
92
+ return dic_torch
93
+
94
+
95
+ def update(state_dict, new):
96
+ for k, v in new.items():
97
+ if k not in state_dict:
98
+ raise KeyError(k + " is not found")
99
+
100
+ if v.shape != state_dict[k].shape:
101
+ raise ValueError(f"Shape mismatch: {v.shape} vs {state_dict[k].shape}")
102
+
103
+ state_dict[k] = v
104
+
105
+
106
+ def discriminator_fill_statedict(statedict, vars, size):
107
+ log_size = int(math.log(size, 2))
108
+
109
+ update(statedict, convert_conv(vars, f"{size}x{size}/FromRGB", "convs.0"))
110
+
111
+ conv_i = 1
112
+
113
+ for i in range(log_size - 2, 0, -1):
114
+ reso = 4 * 2 ** i
115
+ update(
116
+ statedict,
117
+ convert_conv(vars, f"{reso}x{reso}/Conv0", f"convs.{conv_i}.conv1"),
118
+ )
119
+ update(
120
+ statedict,
121
+ convert_conv(
122
+ vars, f"{reso}x{reso}/Conv1_down", f"convs.{conv_i}.conv2", start=1
123
+ ),
124
+ )
125
+ update(
126
+ statedict,
127
+ convert_conv(
128
+ vars, f"{reso}x{reso}/Skip", f"convs.{conv_i}.skip", start=1, bias=False
129
+ ),
130
+ )
131
+ conv_i += 1
132
+
133
+ update(statedict, convert_conv(vars, f"4x4/Conv", "final_conv"))
134
+ update(statedict, convert_dense(vars, f"4x4/Dense0", "final_linear.0"))
135
+ update(statedict, convert_dense(vars, f"Output", "final_linear.1"))
136
+
137
+ return statedict
138
+
139
+
140
+ def fill_statedict(state_dict, vars, size, n_mlp):
141
+ log_size = int(math.log(size, 2))
142
+
143
+ for i in range(n_mlp):
144
+ update(state_dict, convert_dense(vars, f"G_mapping/Dense{i}", f"style.{i + 1}"))
145
+
146
+ update(
147
+ state_dict,
148
+ {
149
+ "input.input": torch.from_numpy(
150
+ vars["G_synthesis/4x4/Const/const"].value().eval()
151
+ )
152
+ },
153
+ )
154
+
155
+ update(state_dict, convert_torgb(vars, "G_synthesis/4x4/ToRGB", "to_rgb1"))
156
+
157
+ for i in range(log_size - 2):
158
+ reso = 4 * 2 ** (i + 1)
159
+ update(
160
+ state_dict,
161
+ convert_torgb(vars, f"G_synthesis/{reso}x{reso}/ToRGB", f"to_rgbs.{i}"),
162
+ )
163
+
164
+ update(state_dict, convert_modconv(vars, "G_synthesis/4x4/Conv", "conv1"))
165
+
166
+ conv_i = 0
167
+
168
+ for i in range(log_size - 2):
169
+ reso = 4 * 2 ** (i + 1)
170
+ update(
171
+ state_dict,
172
+ convert_modconv(
173
+ vars,
174
+ f"G_synthesis/{reso}x{reso}/Conv0_up",
175
+ f"convs.{conv_i}",
176
+ flip=True,
177
+ ),
178
+ )
179
+ update(
180
+ state_dict,
181
+ convert_modconv(
182
+ vars, f"G_synthesis/{reso}x{reso}/Conv1", f"convs.{conv_i + 1}"
183
+ ),
184
+ )
185
+ conv_i += 2
186
+
187
+ for i in range(0, (log_size - 2) * 2 + 1):
188
+ update(
189
+ state_dict,
190
+ {
191
+ f"noises.noise_{i}": torch.from_numpy(
192
+ vars[f"G_synthesis/noise{i}"].value().eval()
193
+ )
194
+ },
195
+ )
196
+
197
+ return state_dict
198
+
199
+
200
+ if __name__ == "__main__":
201
+ device = "cuda"
202
+
203
+ parser = argparse.ArgumentParser(
204
+ description="Tensorflow to pytorch model checkpoint converter"
205
+ )
206
+ parser.add_argument(
207
+ "--repo",
208
+ type=str,
209
+ required=True,
210
+ help="path to the offical StyleGAN2 repository with dnnlib/ folder",
211
+ )
212
+ parser.add_argument(
213
+ "--gen", action="store_true", help="convert the generator weights"
214
+ )
215
+ parser.add_argument(
216
+ "--disc", action="store_true", help="convert the discriminator weights"
217
+ )
218
+ parser.add_argument(
219
+ "--channel_multiplier",
220
+ type=int,
221
+ default=2,
222
+ help="channel multiplier factor. config-f = 2, else = 1",
223
+ )
224
+ parser.add_argument("path", metavar="PATH", help="path to the tensorflow weights")
225
+
226
+ args = parser.parse_args()
227
+
228
+ sys.path.append(args.repo)
229
+
230
+ import dnnlib
231
+ from dnnlib import tflib
232
+
233
+ tflib.init_tf()
234
+
235
+ with open(args.path, "rb") as f:
236
+ generator, discriminator, g_ema = pickle.load(f)
237
+
238
+ size = g_ema.output_shape[2]
239
+
240
+ n_mlp = 0
241
+ mapping_layers_names = g_ema.__getstate__()['components']['mapping'].list_layers()
242
+ for layer in mapping_layers_names:
243
+ if layer[0].startswith('Dense'):
244
+ n_mlp += 1
245
+
246
+ g = Generator(size, 512, n_mlp, channel_multiplier=args.channel_multiplier)
247
+ state_dict = g.state_dict()
248
+ state_dict = fill_statedict(state_dict, g_ema.vars, size, n_mlp)
249
+
250
+ g.load_state_dict(state_dict)
251
+
252
+ latent_avg = torch.from_numpy(g_ema.vars["dlatent_avg"].value().eval())
253
+
254
+ ckpt = {"g_ema": state_dict, "latent_avg": latent_avg}
255
+
256
+ if args.gen:
257
+ g_train = Generator(size, 512, n_mlp, channel_multiplier=args.channel_multiplier)
258
+ g_train_state = g_train.state_dict()
259
+ g_train_state = fill_statedict(g_train_state, generator.vars, size, n_mlp)
260
+ ckpt["g"] = g_train_state
261
+
262
+ if args.disc:
263
+ disc = Discriminator(size, channel_multiplier=args.channel_multiplier)
264
+ d_state = disc.state_dict()
265
+ d_state = discriminator_fill_statedict(d_state, discriminator.vars, size)
266
+ ckpt["d"] = d_state
267
+
268
+ name = os.path.splitext(os.path.basename(args.path))[0]
269
+ torch.save(ckpt, name + ".pt")
270
+
271
+ batch_size = {256: 16, 512: 9, 1024: 4}
272
+ n_sample = batch_size.get(size, 25)
273
+
274
+ g = g.to(device)
275
+
276
+ z = np.random.RandomState(0).randn(n_sample, 512).astype("float32")
277
+
278
+ with torch.no_grad():
279
+ img_pt, _ = g(
280
+ [torch.from_numpy(z).to(device)],
281
+ truncation=0.5,
282
+ truncation_latent=latent_avg.to(device),
283
+ randomize_noise=False,
284
+ )
285
+
286
+ Gs_kwargs = dnnlib.EasyDict()
287
+ Gs_kwargs.randomize_noise = False
288
+ img_tf = g_ema.run(z, None, **Gs_kwargs)
289
+ img_tf = torch.from_numpy(img_tf).to(device)
290
+
291
+ img_diff = ((img_pt + 1) / 2).clamp(0.0, 1.0) - ((img_tf.to(device) + 1) / 2).clamp(
292
+ 0.0, 1.0
293
+ )
294
+
295
+ img_concat = torch.cat((img_tf, img_pt, img_diff), dim=0)
296
+
297
+ print(img_diff.abs().max())
298
+
299
+ utils.save_image(
300
+ img_concat, name + ".png", nrow=n_sample, normalize=True, range=(-1, 1)
301
+ )
dataset.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from io import BytesIO
2
+
3
+ import lmdb
4
+ from PIL import Image
5
+ from torch.utils.data import Dataset
6
+
7
+
8
+ class MultiResolutionDataset(Dataset):
9
+ def __init__(self, path, transform, resolution=256):
10
+ self.env = lmdb.open(
11
+ path,
12
+ max_readers=32,
13
+ readonly=True,
14
+ lock=False,
15
+ readahead=False,
16
+ meminit=False,
17
+ )
18
+
19
+ if not self.env:
20
+ raise IOError('Cannot open lmdb dataset', path)
21
+
22
+ with self.env.begin(write=False) as txn:
23
+ self.length = int(txn.get('length'.encode('utf-8')).decode('utf-8'))
24
+
25
+ self.resolution = resolution
26
+ self.transform = transform
27
+
28
+ def __len__(self):
29
+ return self.length
30
+
31
+ def __getitem__(self, index):
32
+ with self.env.begin(write=False) as txn:
33
+ key = f'{self.resolution}-{str(index).zfill(5)}'.encode('utf-8')
34
+ img_bytes = txn.get(key)
35
+
36
+ buffer = BytesIO(img_bytes)
37
+ img = Image.open(buffer)
38
+ img = self.transform(img)
39
+
40
+ return img
distributed.py ADDED
@@ -0,0 +1,126 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import pickle
3
+
4
+ import torch
5
+ from torch import distributed as dist
6
+ from torch.utils.data.sampler import Sampler
7
+
8
+
9
+ def get_rank():
10
+ if not dist.is_available():
11
+ return 0
12
+
13
+ if not dist.is_initialized():
14
+ return 0
15
+
16
+ return dist.get_rank()
17
+
18
+
19
+ def synchronize():
20
+ if not dist.is_available():
21
+ return
22
+
23
+ if not dist.is_initialized():
24
+ return
25
+
26
+ world_size = dist.get_world_size()
27
+
28
+ if world_size == 1:
29
+ return
30
+
31
+ dist.barrier()
32
+
33
+
34
+ def get_world_size():
35
+ if not dist.is_available():
36
+ return 1
37
+
38
+ if not dist.is_initialized():
39
+ return 1
40
+
41
+ return dist.get_world_size()
42
+
43
+
44
+ def reduce_sum(tensor):
45
+ if not dist.is_available():
46
+ return tensor
47
+
48
+ if not dist.is_initialized():
49
+ return tensor
50
+
51
+ tensor = tensor.clone()
52
+ dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
53
+
54
+ return tensor
55
+
56
+
57
+ def gather_grad(params):
58
+ world_size = get_world_size()
59
+
60
+ if world_size == 1:
61
+ return
62
+
63
+ for param in params:
64
+ if param.grad is not None:
65
+ dist.all_reduce(param.grad.data, op=dist.ReduceOp.SUM)
66
+ param.grad.data.div_(world_size)
67
+
68
+
69
+ def all_gather(data):
70
+ world_size = get_world_size()
71
+
72
+ if world_size == 1:
73
+ return [data]
74
+
75
+ buffer = pickle.dumps(data)
76
+ storage = torch.ByteStorage.from_buffer(buffer)
77
+ tensor = torch.ByteTensor(storage).to('cuda')
78
+
79
+ local_size = torch.IntTensor([tensor.numel()]).to('cuda')
80
+ size_list = [torch.IntTensor([0]).to('cuda') for _ in range(world_size)]
81
+ dist.all_gather(size_list, local_size)
82
+ size_list = [int(size.item()) for size in size_list]
83
+ max_size = max(size_list)
84
+
85
+ tensor_list = []
86
+ for _ in size_list:
87
+ tensor_list.append(torch.ByteTensor(size=(max_size,)).to('cuda'))
88
+
89
+ if local_size != max_size:
90
+ padding = torch.ByteTensor(size=(max_size - local_size,)).to('cuda')
91
+ tensor = torch.cat((tensor, padding), 0)
92
+
93
+ dist.all_gather(tensor_list, tensor)
94
+
95
+ data_list = []
96
+
97
+ for size, tensor in zip(size_list, tensor_list):
98
+ buffer = tensor.cpu().numpy().tobytes()[:size]
99
+ data_list.append(pickle.loads(buffer))
100
+
101
+ return data_list
102
+
103
+
104
+ def reduce_loss_dict(loss_dict):
105
+ world_size = get_world_size()
106
+
107
+ if world_size < 2:
108
+ return loss_dict
109
+
110
+ with torch.no_grad():
111
+ keys = []
112
+ losses = []
113
+
114
+ for k in sorted(loss_dict.keys()):
115
+ keys.append(k)
116
+ losses.append(loss_dict[k])
117
+
118
+ losses = torch.stack(losses, 0)
119
+ dist.reduce(losses, dst=0)
120
+
121
+ if dist.get_rank() == 0:
122
+ losses /= world_size
123
+
124
+ reduced_losses = {k: v for k, v in zip(keys, losses)}
125
+
126
+ return reduced_losses
fid.py ADDED
@@ -0,0 +1,129 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import pickle
3
+
4
+ import torch
5
+ from torch import nn
6
+ import numpy as np
7
+ from scipy import linalg
8
+ from tqdm import tqdm
9
+
10
+ from model import Generator
11
+ from calc_inception import load_patched_inception_v3
12
+
13
+
14
+ @torch.no_grad()
15
+ def extract_feature_from_samples(
16
+ generator, inception, truncation, truncation_latent, batch_size, n_sample, device
17
+ ):
18
+ n_batch = n_sample // batch_size
19
+ resid = n_sample - (n_batch * batch_size)
20
+ batch_sizes = [batch_size] * n_batch + [resid]
21
+ features = []
22
+
23
+ for batch in tqdm(batch_sizes):
24
+ latent = torch.randn(batch, 512, device=device)
25
+ img, _ = g([latent], truncation=truncation, truncation_latent=truncation_latent)
26
+ feat = inception(img)[0].view(img.shape[0], -1)
27
+ features.append(feat.to("cpu"))
28
+
29
+ features = torch.cat(features, 0)
30
+
31
+ return features
32
+
33
+
34
+ def calc_fid(sample_mean, sample_cov, real_mean, real_cov, eps=1e-6):
35
+ cov_sqrt, _ = linalg.sqrtm(sample_cov @ real_cov, disp=False)
36
+
37
+ if not np.isfinite(cov_sqrt).all():
38
+ print("product of cov matrices is singular")
39
+ offset = np.eye(sample_cov.shape[0]) * eps
40
+ cov_sqrt = linalg.sqrtm((sample_cov + offset) @ (real_cov + offset))
41
+
42
+ if np.iscomplexobj(cov_sqrt):
43
+ if not np.allclose(np.diagonal(cov_sqrt).imag, 0, atol=1e-3):
44
+ m = np.max(np.abs(cov_sqrt.imag))
45
+
46
+ raise ValueError(f"Imaginary component {m}")
47
+
48
+ cov_sqrt = cov_sqrt.real
49
+
50
+ mean_diff = sample_mean - real_mean
51
+ mean_norm = mean_diff @ mean_diff
52
+
53
+ trace = np.trace(sample_cov) + np.trace(real_cov) - 2 * np.trace(cov_sqrt)
54
+
55
+ fid = mean_norm + trace
56
+
57
+ return fid
58
+
59
+
60
+ if __name__ == "__main__":
61
+ device = "cuda"
62
+
63
+ parser = argparse.ArgumentParser(description="Calculate FID scores")
64
+
65
+ parser.add_argument("--truncation", type=float, default=1, help="truncation factor")
66
+ parser.add_argument(
67
+ "--truncation_mean",
68
+ type=int,
69
+ default=4096,
70
+ help="number of samples to calculate mean for truncation",
71
+ )
72
+ parser.add_argument(
73
+ "--batch", type=int, default=64, help="batch size for the generator"
74
+ )
75
+ parser.add_argument(
76
+ "--n_sample",
77
+ type=int,
78
+ default=50000,
79
+ help="number of the samples for calculating FID",
80
+ )
81
+ parser.add_argument(
82
+ "--size", type=int, default=256, help="image sizes for generator"
83
+ )
84
+ parser.add_argument(
85
+ "--inception",
86
+ type=str,
87
+ default=None,
88
+ required=True,
89
+ help="path to precomputed inception embedding",
90
+ )
91
+ parser.add_argument(
92
+ "ckpt", metavar="CHECKPOINT", help="path to generator checkpoint"
93
+ )
94
+
95
+ args = parser.parse_args()
96
+
97
+ ckpt = torch.load(args.ckpt)
98
+
99
+ g = Generator(args.size, 512, 8).to(device)
100
+ g.load_state_dict(ckpt["g_ema"])
101
+ g = nn.DataParallel(g)
102
+ g.eval()
103
+
104
+ if args.truncation < 1:
105
+ with torch.no_grad():
106
+ mean_latent = g.mean_latent(args.truncation_mean)
107
+
108
+ else:
109
+ mean_latent = None
110
+
111
+ inception = nn.DataParallel(load_patched_inception_v3()).to(device)
112
+ inception.eval()
113
+
114
+ features = extract_feature_from_samples(
115
+ g, inception, args.truncation, mean_latent, args.batch, args.n_sample, device
116
+ ).numpy()
117
+ print(f"extracted {features.shape[0]} features")
118
+
119
+ sample_mean = np.mean(features, 0)
120
+ sample_cov = np.cov(features, rowvar=False)
121
+
122
+ with open(args.inception, "rb") as f:
123
+ embeds = pickle.load(f)
124
+ real_mean = embeds["mean"]
125
+ real_cov = embeds["cov"]
126
+
127
+ fid = calc_fid(sample_mean, sample_cov, real_mean, real_cov)
128
+
129
+ print("fid:", fid)
generate.py ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+
3
+ import torch
4
+ from torchvision import utils
5
+ from model import Generator
6
+ from tqdm import tqdm
7
+
8
+
9
+ def generate(args, g_ema, device, mean_latent):
10
+
11
+ with torch.no_grad():
12
+ g_ema.eval()
13
+ for i in tqdm(range(args.pics)):
14
+ sample_z = torch.randn(args.sample, args.latent, device=device)
15
+
16
+ sample, _ = g_ema(
17
+ [sample_z], truncation=args.truncation, truncation_latent=mean_latent
18
+ )
19
+
20
+ utils.save_image(
21
+ sample,
22
+ f"sample/{str(i).zfill(6)}.png",
23
+ nrow=1,
24
+ normalize=True,
25
+ range=(-1, 1),
26
+ )
27
+
28
+
29
+ if __name__ == "__main__":
30
+ device = "cuda"
31
+
32
+ parser = argparse.ArgumentParser(description="Generate samples from the generator")
33
+
34
+ parser.add_argument(
35
+ "--size", type=int, default=1024, help="output image size of the generator"
36
+ )
37
+ parser.add_argument(
38
+ "--sample",
39
+ type=int,
40
+ default=1,
41
+ help="number of samples to be generated for each image",
42
+ )
43
+ parser.add_argument(
44
+ "--pics", type=int, default=20, help="number of images to be generated"
45
+ )
46
+ parser.add_argument("--truncation", type=float, default=1, help="truncation ratio")
47
+ parser.add_argument(
48
+ "--truncation_mean",
49
+ type=int,
50
+ default=4096,
51
+ help="number of vectors to calculate mean for the truncation",
52
+ )
53
+ parser.add_argument(
54
+ "--ckpt",
55
+ type=str,
56
+ default="stylegan2-ffhq-config-f.pt",
57
+ help="path to the model checkpoint",
58
+ )
59
+ parser.add_argument(
60
+ "--channel_multiplier",
61
+ type=int,
62
+ default=2,
63
+ help="channel multiplier of the generator. config-f = 2, else = 1",
64
+ )
65
+
66
+ args = parser.parse_args()
67
+
68
+ args.latent = 512
69
+ args.n_mlp = 8
70
+
71
+ g_ema = Generator(
72
+ args.size, args.latent, args.n_mlp, channel_multiplier=args.channel_multiplier
73
+ ).to(device)
74
+ checkpoint = torch.load(args.ckpt)
75
+
76
+ g_ema.load_state_dict(checkpoint["g_ema"])
77
+
78
+ if args.truncation < 1:
79
+ with torch.no_grad():
80
+ mean_latent = g_ema.mean_latent(args.truncation_mean)
81
+ else:
82
+ mean_latent = None
83
+
84
+ generate(args, g_ema, device, mean_latent)
inception.py ADDED
@@ -0,0 +1,310 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from torchvision import models
5
+
6
+ try:
7
+ from torchvision.models.utils import load_state_dict_from_url
8
+ except ImportError:
9
+ from torch.utils.model_zoo import load_url as load_state_dict_from_url
10
+
11
+ # Inception weights ported to Pytorch from
12
+ # http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz
13
+ FID_WEIGHTS_URL = 'https://github.com/mseitzer/pytorch-fid/releases/download/fid_weights/pt_inception-2015-12-05-6726825d.pth'
14
+
15
+
16
+ class InceptionV3(nn.Module):
17
+ """Pretrained InceptionV3 network returning feature maps"""
18
+
19
+ # Index of default block of inception to return,
20
+ # corresponds to output of final average pooling
21
+ DEFAULT_BLOCK_INDEX = 3
22
+
23
+ # Maps feature dimensionality to their output blocks indices
24
+ BLOCK_INDEX_BY_DIM = {
25
+ 64: 0, # First max pooling features
26
+ 192: 1, # Second max pooling featurs
27
+ 768: 2, # Pre-aux classifier features
28
+ 2048: 3 # Final average pooling features
29
+ }
30
+
31
+ def __init__(self,
32
+ output_blocks=[DEFAULT_BLOCK_INDEX],
33
+ resize_input=True,
34
+ normalize_input=True,
35
+ requires_grad=False,
36
+ use_fid_inception=True):
37
+ """Build pretrained InceptionV3
38
+
39
+ Parameters
40
+ ----------
41
+ output_blocks : list of int
42
+ Indices of blocks to return features of. Possible values are:
43
+ - 0: corresponds to output of first max pooling
44
+ - 1: corresponds to output of second max pooling
45
+ - 2: corresponds to output which is fed to aux classifier
46
+ - 3: corresponds to output of final average pooling
47
+ resize_input : bool
48
+ If true, bilinearly resizes input to width and height 299 before
49
+ feeding input to model. As the network without fully connected
50
+ layers is fully convolutional, it should be able to handle inputs
51
+ of arbitrary size, so resizing might not be strictly needed
52
+ normalize_input : bool
53
+ If true, scales the input from range (0, 1) to the range the
54
+ pretrained Inception network expects, namely (-1, 1)
55
+ requires_grad : bool
56
+ If true, parameters of the model require gradients. Possibly useful
57
+ for finetuning the network
58
+ use_fid_inception : bool
59
+ If true, uses the pretrained Inception model used in Tensorflow's
60
+ FID implementation. If false, uses the pretrained Inception model
61
+ available in torchvision. The FID Inception model has different
62
+ weights and a slightly different structure from torchvision's
63
+ Inception model. If you want to compute FID scores, you are
64
+ strongly advised to set this parameter to true to get comparable
65
+ results.
66
+ """
67
+ super(InceptionV3, self).__init__()
68
+
69
+ self.resize_input = resize_input
70
+ self.normalize_input = normalize_input
71
+ self.output_blocks = sorted(output_blocks)
72
+ self.last_needed_block = max(output_blocks)
73
+
74
+ assert self.last_needed_block <= 3, \
75
+ 'Last possible output block index is 3'
76
+
77
+ self.blocks = nn.ModuleList()
78
+
79
+ if use_fid_inception:
80
+ inception = fid_inception_v3()
81
+ else:
82
+ inception = models.inception_v3(pretrained=True)
83
+
84
+ # Block 0: input to maxpool1
85
+ block0 = [
86
+ inception.Conv2d_1a_3x3,
87
+ inception.Conv2d_2a_3x3,
88
+ inception.Conv2d_2b_3x3,
89
+ nn.MaxPool2d(kernel_size=3, stride=2)
90
+ ]
91
+ self.blocks.append(nn.Sequential(*block0))
92
+
93
+ # Block 1: maxpool1 to maxpool2
94
+ if self.last_needed_block >= 1:
95
+ block1 = [
96
+ inception.Conv2d_3b_1x1,
97
+ inception.Conv2d_4a_3x3,
98
+ nn.MaxPool2d(kernel_size=3, stride=2)
99
+ ]
100
+ self.blocks.append(nn.Sequential(*block1))
101
+
102
+ # Block 2: maxpool2 to aux classifier
103
+ if self.last_needed_block >= 2:
104
+ block2 = [
105
+ inception.Mixed_5b,
106
+ inception.Mixed_5c,
107
+ inception.Mixed_5d,
108
+ inception.Mixed_6a,
109
+ inception.Mixed_6b,
110
+ inception.Mixed_6c,
111
+ inception.Mixed_6d,
112
+ inception.Mixed_6e,
113
+ ]
114
+ self.blocks.append(nn.Sequential(*block2))
115
+
116
+ # Block 3: aux classifier to final avgpool
117
+ if self.last_needed_block >= 3:
118
+ block3 = [
119
+ inception.Mixed_7a,
120
+ inception.Mixed_7b,
121
+ inception.Mixed_7c,
122
+ nn.AdaptiveAvgPool2d(output_size=(1, 1))
123
+ ]
124
+ self.blocks.append(nn.Sequential(*block3))
125
+
126
+ for param in self.parameters():
127
+ param.requires_grad = requires_grad
128
+
129
+ def forward(self, inp):
130
+ """Get Inception feature maps
131
+
132
+ Parameters
133
+ ----------
134
+ inp : torch.autograd.Variable
135
+ Input tensor of shape Bx3xHxW. Values are expected to be in
136
+ range (0, 1)
137
+
138
+ Returns
139
+ -------
140
+ List of torch.autograd.Variable, corresponding to the selected output
141
+ block, sorted ascending by index
142
+ """
143
+ outp = []
144
+ x = inp
145
+
146
+ if self.resize_input:
147
+ x = F.interpolate(x,
148
+ size=(299, 299),
149
+ mode='bilinear',
150
+ align_corners=False)
151
+
152
+ if self.normalize_input:
153
+ x = 2 * x - 1 # Scale from range (0, 1) to range (-1, 1)
154
+
155
+ for idx, block in enumerate(self.blocks):
156
+ x = block(x)
157
+ if idx in self.output_blocks:
158
+ outp.append(x)
159
+
160
+ if idx == self.last_needed_block:
161
+ break
162
+
163
+ return outp
164
+
165
+
166
+ def fid_inception_v3():
167
+ """Build pretrained Inception model for FID computation
168
+
169
+ The Inception model for FID computation uses a different set of weights
170
+ and has a slightly different structure than torchvision's Inception.
171
+
172
+ This method first constructs torchvision's Inception and then patches the
173
+ necessary parts that are different in the FID Inception model.
174
+ """
175
+ inception = models.inception_v3(num_classes=1008,
176
+ aux_logits=False,
177
+ pretrained=False)
178
+ inception.Mixed_5b = FIDInceptionA(192, pool_features=32)
179
+ inception.Mixed_5c = FIDInceptionA(256, pool_features=64)
180
+ inception.Mixed_5d = FIDInceptionA(288, pool_features=64)
181
+ inception.Mixed_6b = FIDInceptionC(768, channels_7x7=128)
182
+ inception.Mixed_6c = FIDInceptionC(768, channels_7x7=160)
183
+ inception.Mixed_6d = FIDInceptionC(768, channels_7x7=160)
184
+ inception.Mixed_6e = FIDInceptionC(768, channels_7x7=192)
185
+ inception.Mixed_7b = FIDInceptionE_1(1280)
186
+ inception.Mixed_7c = FIDInceptionE_2(2048)
187
+
188
+ state_dict = load_state_dict_from_url(FID_WEIGHTS_URL, progress=True)
189
+ inception.load_state_dict(state_dict)
190
+ return inception
191
+
192
+
193
+ class FIDInceptionA(models.inception.InceptionA):
194
+ """InceptionA block patched for FID computation"""
195
+ def __init__(self, in_channels, pool_features):
196
+ super(FIDInceptionA, self).__init__(in_channels, pool_features)
197
+
198
+ def forward(self, x):
199
+ branch1x1 = self.branch1x1(x)
200
+
201
+ branch5x5 = self.branch5x5_1(x)
202
+ branch5x5 = self.branch5x5_2(branch5x5)
203
+
204
+ branch3x3dbl = self.branch3x3dbl_1(x)
205
+ branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
206
+ branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl)
207
+
208
+ # Patch: Tensorflow's average pool does not use the padded zero's in
209
+ # its average calculation
210
+ branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1,
211
+ count_include_pad=False)
212
+ branch_pool = self.branch_pool(branch_pool)
213
+
214
+ outputs = [branch1x1, branch5x5, branch3x3dbl, branch_pool]
215
+ return torch.cat(outputs, 1)
216
+
217
+
218
+ class FIDInceptionC(models.inception.InceptionC):
219
+ """InceptionC block patched for FID computation"""
220
+ def __init__(self, in_channels, channels_7x7):
221
+ super(FIDInceptionC, self).__init__(in_channels, channels_7x7)
222
+
223
+ def forward(self, x):
224
+ branch1x1 = self.branch1x1(x)
225
+
226
+ branch7x7 = self.branch7x7_1(x)
227
+ branch7x7 = self.branch7x7_2(branch7x7)
228
+ branch7x7 = self.branch7x7_3(branch7x7)
229
+
230
+ branch7x7dbl = self.branch7x7dbl_1(x)
231
+ branch7x7dbl = self.branch7x7dbl_2(branch7x7dbl)
232
+ branch7x7dbl = self.branch7x7dbl_3(branch7x7dbl)
233
+ branch7x7dbl = self.branch7x7dbl_4(branch7x7dbl)
234
+ branch7x7dbl = self.branch7x7dbl_5(branch7x7dbl)
235
+
236
+ # Patch: Tensorflow's average pool does not use the padded zero's in
237
+ # its average calculation
238
+ branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1,
239
+ count_include_pad=False)
240
+ branch_pool = self.branch_pool(branch_pool)
241
+
242
+ outputs = [branch1x1, branch7x7, branch7x7dbl, branch_pool]
243
+ return torch.cat(outputs, 1)
244
+
245
+
246
+ class FIDInceptionE_1(models.inception.InceptionE):
247
+ """First InceptionE block patched for FID computation"""
248
+ def __init__(self, in_channels):
249
+ super(FIDInceptionE_1, self).__init__(in_channels)
250
+
251
+ def forward(self, x):
252
+ branch1x1 = self.branch1x1(x)
253
+
254
+ branch3x3 = self.branch3x3_1(x)
255
+ branch3x3 = [
256
+ self.branch3x3_2a(branch3x3),
257
+ self.branch3x3_2b(branch3x3),
258
+ ]
259
+ branch3x3 = torch.cat(branch3x3, 1)
260
+
261
+ branch3x3dbl = self.branch3x3dbl_1(x)
262
+ branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
263
+ branch3x3dbl = [
264
+ self.branch3x3dbl_3a(branch3x3dbl),
265
+ self.branch3x3dbl_3b(branch3x3dbl),
266
+ ]
267
+ branch3x3dbl = torch.cat(branch3x3dbl, 1)
268
+
269
+ # Patch: Tensorflow's average pool does not use the padded zero's in
270
+ # its average calculation
271
+ branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1,
272
+ count_include_pad=False)
273
+ branch_pool = self.branch_pool(branch_pool)
274
+
275
+ outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]
276
+ return torch.cat(outputs, 1)
277
+
278
+
279
+ class FIDInceptionE_2(models.inception.InceptionE):
280
+ """Second InceptionE block patched for FID computation"""
281
+ def __init__(self, in_channels):
282
+ super(FIDInceptionE_2, self).__init__(in_channels)
283
+
284
+ def forward(self, x):
285
+ branch1x1 = self.branch1x1(x)
286
+
287
+ branch3x3 = self.branch3x3_1(x)
288
+ branch3x3 = [
289
+ self.branch3x3_2a(branch3x3),
290
+ self.branch3x3_2b(branch3x3),
291
+ ]
292
+ branch3x3 = torch.cat(branch3x3, 1)
293
+
294
+ branch3x3dbl = self.branch3x3dbl_1(x)
295
+ branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
296
+ branch3x3dbl = [
297
+ self.branch3x3dbl_3a(branch3x3dbl),
298
+ self.branch3x3dbl_3b(branch3x3dbl),
299
+ ]
300
+ branch3x3dbl = torch.cat(branch3x3dbl, 1)
301
+
302
+ # Patch: The FID Inception model uses max pooling instead of average
303
+ # pooling. This is likely an error in this specific Inception
304
+ # implementation, as other Inception models use average pooling here
305
+ # (which matches the description in the paper).
306
+ branch_pool = F.max_pool2d(x, kernel_size=3, stride=1, padding=1)
307
+ branch_pool = self.branch_pool(branch_pool)
308
+
309
+ outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]
310
+ return torch.cat(outputs, 1)
lpips/__init__.py ADDED
@@ -0,0 +1,160 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ from __future__ import absolute_import
3
+ from __future__ import division
4
+ from __future__ import print_function
5
+
6
+ import numpy as np
7
+ from skimage.measure import compare_ssim
8
+ import torch
9
+ from torch.autograd import Variable
10
+
11
+ from lpips import dist_model
12
+
13
+ class PerceptualLoss(torch.nn.Module):
14
+ def __init__(self, model='net-lin', net='alex', colorspace='rgb', spatial=False, use_gpu=True, gpu_ids=[0]): # VGG using our perceptually-learned weights (LPIPS metric)
15
+ # def __init__(self, model='net', net='vgg', use_gpu=True): # "default" way of using VGG as a perceptual loss
16
+ super(PerceptualLoss, self).__init__()
17
+ print('Setting up Perceptual loss...')
18
+ self.use_gpu = use_gpu
19
+ self.spatial = spatial
20
+ self.gpu_ids = gpu_ids
21
+ self.model = dist_model.DistModel()
22
+ self.model.initialize(model=model, net=net, use_gpu=use_gpu, colorspace=colorspace, spatial=self.spatial, gpu_ids=gpu_ids)
23
+ print('...[%s] initialized'%self.model.name())
24
+ print('...Done')
25
+
26
+ def forward(self, pred, target, normalize=False):
27
+ """
28
+ Pred and target are Variables.
29
+ If normalize is True, assumes the images are between [0,1] and then scales them between [-1,+1]
30
+ If normalize is False, assumes the images are already between [-1,+1]
31
+
32
+ Inputs pred and target are Nx3xHxW
33
+ Output pytorch Variable N long
34
+ """
35
+
36
+ if normalize:
37
+ target = 2 * target - 1
38
+ pred = 2 * pred - 1
39
+
40
+ return self.model.forward(target, pred)
41
+
42
+ def normalize_tensor(in_feat,eps=1e-10):
43
+ norm_factor = torch.sqrt(torch.sum(in_feat**2,dim=1,keepdim=True))
44
+ return in_feat/(norm_factor+eps)
45
+
46
+ def l2(p0, p1, range=255.):
47
+ return .5*np.mean((p0 / range - p1 / range)**2)
48
+
49
+ def psnr(p0, p1, peak=255.):
50
+ return 10*np.log10(peak**2/np.mean((1.*p0-1.*p1)**2))
51
+
52
+ def dssim(p0, p1, range=255.):
53
+ return (1 - compare_ssim(p0, p1, data_range=range, multichannel=True)) / 2.
54
+
55
+ def rgb2lab(in_img,mean_cent=False):
56
+ from skimage import color
57
+ img_lab = color.rgb2lab(in_img)
58
+ if(mean_cent):
59
+ img_lab[:,:,0] = img_lab[:,:,0]-50
60
+ return img_lab
61
+
62
+ def tensor2np(tensor_obj):
63
+ # change dimension of a tensor object into a numpy array
64
+ return tensor_obj[0].cpu().float().numpy().transpose((1,2,0))
65
+
66
+ def np2tensor(np_obj):
67
+ # change dimenion of np array into tensor array
68
+ return torch.Tensor(np_obj[:, :, :, np.newaxis].transpose((3, 2, 0, 1)))
69
+
70
+ def tensor2tensorlab(image_tensor,to_norm=True,mc_only=False):
71
+ # image tensor to lab tensor
72
+ from skimage import color
73
+
74
+ img = tensor2im(image_tensor)
75
+ img_lab = color.rgb2lab(img)
76
+ if(mc_only):
77
+ img_lab[:,:,0] = img_lab[:,:,0]-50
78
+ if(to_norm and not mc_only):
79
+ img_lab[:,:,0] = img_lab[:,:,0]-50
80
+ img_lab = img_lab/100.
81
+
82
+ return np2tensor(img_lab)
83
+
84
+ def tensorlab2tensor(lab_tensor,return_inbnd=False):
85
+ from skimage import color
86
+ import warnings
87
+ warnings.filterwarnings("ignore")
88
+
89
+ lab = tensor2np(lab_tensor)*100.
90
+ lab[:,:,0] = lab[:,:,0]+50
91
+
92
+ rgb_back = 255.*np.clip(color.lab2rgb(lab.astype('float')),0,1)
93
+ if(return_inbnd):
94
+ # convert back to lab, see if we match
95
+ lab_back = color.rgb2lab(rgb_back.astype('uint8'))
96
+ mask = 1.*np.isclose(lab_back,lab,atol=2.)
97
+ mask = np2tensor(np.prod(mask,axis=2)[:,:,np.newaxis])
98
+ return (im2tensor(rgb_back),mask)
99
+ else:
100
+ return im2tensor(rgb_back)
101
+
102
+ def rgb2lab(input):
103
+ from skimage import color
104
+ return color.rgb2lab(input / 255.)
105
+
106
+ def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=255./2.):
107
+ image_numpy = image_tensor[0].cpu().float().numpy()
108
+ image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + cent) * factor
109
+ return image_numpy.astype(imtype)
110
+
111
+ def im2tensor(image, imtype=np.uint8, cent=1., factor=255./2.):
112
+ return torch.Tensor((image / factor - cent)
113
+ [:, :, :, np.newaxis].transpose((3, 2, 0, 1)))
114
+
115
+ def tensor2vec(vector_tensor):
116
+ return vector_tensor.data.cpu().numpy()[:, :, 0, 0]
117
+
118
+ def voc_ap(rec, prec, use_07_metric=False):
119
+ """ ap = voc_ap(rec, prec, [use_07_metric])
120
+ Compute VOC AP given precision and recall.
121
+ If use_07_metric is true, uses the
122
+ VOC 07 11 point method (default:False).
123
+ """
124
+ if use_07_metric:
125
+ # 11 point metric
126
+ ap = 0.
127
+ for t in np.arange(0., 1.1, 0.1):
128
+ if np.sum(rec >= t) == 0:
129
+ p = 0
130
+ else:
131
+ p = np.max(prec[rec >= t])
132
+ ap = ap + p / 11.
133
+ else:
134
+ # correct AP calculation
135
+ # first append sentinel values at the end
136
+ mrec = np.concatenate(([0.], rec, [1.]))
137
+ mpre = np.concatenate(([0.], prec, [0.]))
138
+
139
+ # compute the precision envelope
140
+ for i in range(mpre.size - 1, 0, -1):
141
+ mpre[i - 1] = np.maximum(mpre[i - 1], mpre[i])
142
+
143
+ # to calculate area under PR curve, look for points
144
+ # where X axis (recall) changes value
145
+ i = np.where(mrec[1:] != mrec[:-1])[0]
146
+
147
+ # and sum (\Delta recall) * prec
148
+ ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1])
149
+ return ap
150
+
151
+ def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=255./2.):
152
+ # def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=1.):
153
+ image_numpy = image_tensor[0].cpu().float().numpy()
154
+ image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + cent) * factor
155
+ return image_numpy.astype(imtype)
156
+
157
+ def im2tensor(image, imtype=np.uint8, cent=1., factor=255./2.):
158
+ # def im2tensor(image, imtype=np.uint8, cent=1., factor=1.):
159
+ return torch.Tensor((image / factor - cent)
160
+ [:, :, :, np.newaxis].transpose((3, 2, 0, 1)))
lpips/base_model.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import numpy as np
3
+ import torch
4
+ from torch.autograd import Variable
5
+ from pdb import set_trace as st
6
+ from IPython import embed
7
+
8
+ class BaseModel():
9
+ def __init__(self):
10
+ pass;
11
+
12
+ def name(self):
13
+ return 'BaseModel'
14
+
15
+ def initialize(self, use_gpu=True, gpu_ids=[0]):
16
+ self.use_gpu = use_gpu
17
+ self.gpu_ids = gpu_ids
18
+
19
+ def forward(self):
20
+ pass
21
+
22
+ def get_image_paths(self):
23
+ pass
24
+
25
+ def optimize_parameters(self):
26
+ pass
27
+
28
+ def get_current_visuals(self):
29
+ return self.input
30
+
31
+ def get_current_errors(self):
32
+ return {}
33
+
34
+ def save(self, label):
35
+ pass
36
+
37
+ # helper saving function that can be used by subclasses
38
+ def save_network(self, network, path, network_label, epoch_label):
39
+ save_filename = '%s_net_%s.pth' % (epoch_label, network_label)
40
+ save_path = os.path.join(path, save_filename)
41
+ torch.save(network.state_dict(), save_path)
42
+
43
+ # helper loading function that can be used by subclasses
44
+ def load_network(self, network, network_label, epoch_label):
45
+ save_filename = '%s_net_%s.pth' % (epoch_label, network_label)
46
+ save_path = os.path.join(self.save_dir, save_filename)
47
+ print('Loading network from %s'%save_path)
48
+ network.load_state_dict(torch.load(save_path))
49
+
50
+ def update_learning_rate():
51
+ pass
52
+
53
+ def get_image_paths(self):
54
+ return self.image_paths
55
+
56
+ def save_done(self, flag=False):
57
+ np.save(os.path.join(self.save_dir, 'done_flag'),flag)
58
+ np.savetxt(os.path.join(self.save_dir, 'done_flag'),[flag,],fmt='%i')
lpips/dist_model.py ADDED
@@ -0,0 +1,284 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ from __future__ import absolute_import
3
+
4
+ import sys
5
+ import numpy as np
6
+ import torch
7
+ from torch import nn
8
+ import os
9
+ from collections import OrderedDict
10
+ from torch.autograd import Variable
11
+ import itertools
12
+ from .base_model import BaseModel
13
+ from scipy.ndimage import zoom
14
+ import fractions
15
+ import functools
16
+ import skimage.transform
17
+ from tqdm import tqdm
18
+
19
+ from IPython import embed
20
+
21
+ from . import networks_basic as networks
22
+ import lpips as util
23
+
24
+ class DistModel(BaseModel):
25
+ def name(self):
26
+ return self.model_name
27
+
28
+ def initialize(self, model='net-lin', net='alex', colorspace='Lab', pnet_rand=False, pnet_tune=False, model_path=None,
29
+ use_gpu=True, printNet=False, spatial=False,
30
+ is_train=False, lr=.0001, beta1=0.5, version='0.1', gpu_ids=[0]):
31
+ '''
32
+ INPUTS
33
+ model - ['net-lin'] for linearly calibrated network
34
+ ['net'] for off-the-shelf network
35
+ ['L2'] for L2 distance in Lab colorspace
36
+ ['SSIM'] for ssim in RGB colorspace
37
+ net - ['squeeze','alex','vgg']
38
+ model_path - if None, will look in weights/[NET_NAME].pth
39
+ colorspace - ['Lab','RGB'] colorspace to use for L2 and SSIM
40
+ use_gpu - bool - whether or not to use a GPU
41
+ printNet - bool - whether or not to print network architecture out
42
+ spatial - bool - whether to output an array containing varying distances across spatial dimensions
43
+ spatial_shape - if given, output spatial shape. if None then spatial shape is determined automatically via spatial_factor (see below).
44
+ spatial_factor - if given, specifies upsampling factor relative to the largest spatial extent of a convolutional layer. if None then resized to size of input images.
45
+ spatial_order - spline order of filter for upsampling in spatial mode, by default 1 (bilinear).
46
+ is_train - bool - [True] for training mode
47
+ lr - float - initial learning rate
48
+ beta1 - float - initial momentum term for adam
49
+ version - 0.1 for latest, 0.0 was original (with a bug)
50
+ gpu_ids - int array - [0] by default, gpus to use
51
+ '''
52
+ BaseModel.initialize(self, use_gpu=use_gpu, gpu_ids=gpu_ids)
53
+
54
+ self.model = model
55
+ self.net = net
56
+ self.is_train = is_train
57
+ self.spatial = spatial
58
+ self.gpu_ids = gpu_ids
59
+ self.model_name = '%s [%s]'%(model,net)
60
+
61
+ if(self.model == 'net-lin'): # pretrained net + linear layer
62
+ self.net = networks.PNetLin(pnet_rand=pnet_rand, pnet_tune=pnet_tune, pnet_type=net,
63
+ use_dropout=True, spatial=spatial, version=version, lpips=True)
64
+ kw = {}
65
+ if not use_gpu:
66
+ kw['map_location'] = 'cpu'
67
+ if(model_path is None):
68
+ import inspect
69
+ model_path = os.path.abspath(os.path.join(inspect.getfile(self.initialize), '..', 'weights/v%s/%s.pth'%(version,net)))
70
+
71
+ if(not is_train):
72
+ print('Loading model from: %s'%model_path)
73
+ self.net.load_state_dict(torch.load(model_path, **kw), strict=False)
74
+
75
+ elif(self.model=='net'): # pretrained network
76
+ self.net = networks.PNetLin(pnet_rand=pnet_rand, pnet_type=net, lpips=False)
77
+ elif(self.model in ['L2','l2']):
78
+ self.net = networks.L2(use_gpu=use_gpu,colorspace=colorspace) # not really a network, only for testing
79
+ self.model_name = 'L2'
80
+ elif(self.model in ['DSSIM','dssim','SSIM','ssim']):
81
+ self.net = networks.DSSIM(use_gpu=use_gpu,colorspace=colorspace)
82
+ self.model_name = 'SSIM'
83
+ else:
84
+ raise ValueError("Model [%s] not recognized." % self.model)
85
+
86
+ self.parameters = list(self.net.parameters())
87
+
88
+ if self.is_train: # training mode
89
+ # extra network on top to go from distances (d0,d1) => predicted human judgment (h*)
90
+ self.rankLoss = networks.BCERankingLoss()
91
+ self.parameters += list(self.rankLoss.net.parameters())
92
+ self.lr = lr
93
+ self.old_lr = lr
94
+ self.optimizer_net = torch.optim.Adam(self.parameters, lr=lr, betas=(beta1, 0.999))
95
+ else: # test mode
96
+ self.net.eval()
97
+
98
+ if(use_gpu):
99
+ self.net.to(gpu_ids[0])
100
+ self.net = torch.nn.DataParallel(self.net, device_ids=gpu_ids)
101
+ if(self.is_train):
102
+ self.rankLoss = self.rankLoss.to(device=gpu_ids[0]) # just put this on GPU0
103
+
104
+ if(printNet):
105
+ print('---------- Networks initialized -------------')
106
+ networks.print_network(self.net)
107
+ print('-----------------------------------------------')
108
+
109
+ def forward(self, in0, in1, retPerLayer=False):
110
+ ''' Function computes the distance between image patches in0 and in1
111
+ INPUTS
112
+ in0, in1 - torch.Tensor object of shape Nx3xXxY - image patch scaled to [-1,1]
113
+ OUTPUT
114
+ computed distances between in0 and in1
115
+ '''
116
+
117
+ return self.net.forward(in0, in1, retPerLayer=retPerLayer)
118
+
119
+ # ***** TRAINING FUNCTIONS *****
120
+ def optimize_parameters(self):
121
+ self.forward_train()
122
+ self.optimizer_net.zero_grad()
123
+ self.backward_train()
124
+ self.optimizer_net.step()
125
+ self.clamp_weights()
126
+
127
+ def clamp_weights(self):
128
+ for module in self.net.modules():
129
+ if(hasattr(module, 'weight') and module.kernel_size==(1,1)):
130
+ module.weight.data = torch.clamp(module.weight.data,min=0)
131
+
132
+ def set_input(self, data):
133
+ self.input_ref = data['ref']
134
+ self.input_p0 = data['p0']
135
+ self.input_p1 = data['p1']
136
+ self.input_judge = data['judge']
137
+
138
+ if(self.use_gpu):
139
+ self.input_ref = self.input_ref.to(device=self.gpu_ids[0])
140
+ self.input_p0 = self.input_p0.to(device=self.gpu_ids[0])
141
+ self.input_p1 = self.input_p1.to(device=self.gpu_ids[0])
142
+ self.input_judge = self.input_judge.to(device=self.gpu_ids[0])
143
+
144
+ self.var_ref = Variable(self.input_ref,requires_grad=True)
145
+ self.var_p0 = Variable(self.input_p0,requires_grad=True)
146
+ self.var_p1 = Variable(self.input_p1,requires_grad=True)
147
+
148
+ def forward_train(self): # run forward pass
149
+ # print(self.net.module.scaling_layer.shift)
150
+ # print(torch.norm(self.net.module.net.slice1[0].weight).item(), torch.norm(self.net.module.lin0.model[1].weight).item())
151
+
152
+ self.d0 = self.forward(self.var_ref, self.var_p0)
153
+ self.d1 = self.forward(self.var_ref, self.var_p1)
154
+ self.acc_r = self.compute_accuracy(self.d0,self.d1,self.input_judge)
155
+
156
+ self.var_judge = Variable(1.*self.input_judge).view(self.d0.size())
157
+
158
+ self.loss_total = self.rankLoss.forward(self.d0, self.d1, self.var_judge*2.-1.)
159
+
160
+ return self.loss_total
161
+
162
+ def backward_train(self):
163
+ torch.mean(self.loss_total).backward()
164
+
165
+ def compute_accuracy(self,d0,d1,judge):
166
+ ''' d0, d1 are Variables, judge is a Tensor '''
167
+ d1_lt_d0 = (d1<d0).cpu().data.numpy().flatten()
168
+ judge_per = judge.cpu().numpy().flatten()
169
+ return d1_lt_d0*judge_per + (1-d1_lt_d0)*(1-judge_per)
170
+
171
+ def get_current_errors(self):
172
+ retDict = OrderedDict([('loss_total', self.loss_total.data.cpu().numpy()),
173
+ ('acc_r', self.acc_r)])
174
+
175
+ for key in retDict.keys():
176
+ retDict[key] = np.mean(retDict[key])
177
+
178
+ return retDict
179
+
180
+ def get_current_visuals(self):
181
+ zoom_factor = 256/self.var_ref.data.size()[2]
182
+
183
+ ref_img = util.tensor2im(self.var_ref.data)
184
+ p0_img = util.tensor2im(self.var_p0.data)
185
+ p1_img = util.tensor2im(self.var_p1.data)
186
+
187
+ ref_img_vis = zoom(ref_img,[zoom_factor, zoom_factor, 1],order=0)
188
+ p0_img_vis = zoom(p0_img,[zoom_factor, zoom_factor, 1],order=0)
189
+ p1_img_vis = zoom(p1_img,[zoom_factor, zoom_factor, 1],order=0)
190
+
191
+ return OrderedDict([('ref', ref_img_vis),
192
+ ('p0', p0_img_vis),
193
+ ('p1', p1_img_vis)])
194
+
195
+ def save(self, path, label):
196
+ if(self.use_gpu):
197
+ self.save_network(self.net.module, path, '', label)
198
+ else:
199
+ self.save_network(self.net, path, '', label)
200
+ self.save_network(self.rankLoss.net, path, 'rank', label)
201
+
202
+ def update_learning_rate(self,nepoch_decay):
203
+ lrd = self.lr / nepoch_decay
204
+ lr = self.old_lr - lrd
205
+
206
+ for param_group in self.optimizer_net.param_groups:
207
+ param_group['lr'] = lr
208
+
209
+ print('update lr [%s] decay: %f -> %f' % (type,self.old_lr, lr))
210
+ self.old_lr = lr
211
+
212
+ def score_2afc_dataset(data_loader, func, name=''):
213
+ ''' Function computes Two Alternative Forced Choice (2AFC) score using
214
+ distance function 'func' in dataset 'data_loader'
215
+ INPUTS
216
+ data_loader - CustomDatasetDataLoader object - contains a TwoAFCDataset inside
217
+ func - callable distance function - calling d=func(in0,in1) should take 2
218
+ pytorch tensors with shape Nx3xXxY, and return numpy array of length N
219
+ OUTPUTS
220
+ [0] - 2AFC score in [0,1], fraction of time func agrees with human evaluators
221
+ [1] - dictionary with following elements
222
+ d0s,d1s - N arrays containing distances between reference patch to perturbed patches
223
+ gts - N array in [0,1], preferred patch selected by human evaluators
224
+ (closer to "0" for left patch p0, "1" for right patch p1,
225
+ "0.6" means 60pct people preferred right patch, 40pct preferred left)
226
+ scores - N array in [0,1], corresponding to what percentage function agreed with humans
227
+ CONSTS
228
+ N - number of test triplets in data_loader
229
+ '''
230
+
231
+ d0s = []
232
+ d1s = []
233
+ gts = []
234
+
235
+ for data in tqdm(data_loader.load_data(), desc=name):
236
+ d0s+=func(data['ref'],data['p0']).data.cpu().numpy().flatten().tolist()
237
+ d1s+=func(data['ref'],data['p1']).data.cpu().numpy().flatten().tolist()
238
+ gts+=data['judge'].cpu().numpy().flatten().tolist()
239
+
240
+ d0s = np.array(d0s)
241
+ d1s = np.array(d1s)
242
+ gts = np.array(gts)
243
+ scores = (d0s<d1s)*(1.-gts) + (d1s<d0s)*gts + (d1s==d0s)*.5
244
+
245
+ return(np.mean(scores), dict(d0s=d0s,d1s=d1s,gts=gts,scores=scores))
246
+
247
+ def score_jnd_dataset(data_loader, func, name=''):
248
+ ''' Function computes JND score using distance function 'func' in dataset 'data_loader'
249
+ INPUTS
250
+ data_loader - CustomDatasetDataLoader object - contains a JNDDataset inside
251
+ func - callable distance function - calling d=func(in0,in1) should take 2
252
+ pytorch tensors with shape Nx3xXxY, and return pytorch array of length N
253
+ OUTPUTS
254
+ [0] - JND score in [0,1], mAP score (area under precision-recall curve)
255
+ [1] - dictionary with following elements
256
+ ds - N array containing distances between two patches shown to human evaluator
257
+ sames - N array containing fraction of people who thought the two patches were identical
258
+ CONSTS
259
+ N - number of test triplets in data_loader
260
+ '''
261
+
262
+ ds = []
263
+ gts = []
264
+
265
+ for data in tqdm(data_loader.load_data(), desc=name):
266
+ ds+=func(data['p0'],data['p1']).data.cpu().numpy().tolist()
267
+ gts+=data['same'].cpu().numpy().flatten().tolist()
268
+
269
+ sames = np.array(gts)
270
+ ds = np.array(ds)
271
+
272
+ sorted_inds = np.argsort(ds)
273
+ ds_sorted = ds[sorted_inds]
274
+ sames_sorted = sames[sorted_inds]
275
+
276
+ TPs = np.cumsum(sames_sorted)
277
+ FPs = np.cumsum(1-sames_sorted)
278
+ FNs = np.sum(sames_sorted)-TPs
279
+
280
+ precs = TPs/(TPs+FPs)
281
+ recs = TPs/(TPs+FNs)
282
+ score = util.voc_ap(recs,precs)
283
+
284
+ return(score, dict(ds=ds,sames=sames))
lpips/networks_basic.py ADDED
@@ -0,0 +1,187 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ from __future__ import absolute_import
3
+
4
+ import sys
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.init as init
8
+ from torch.autograd import Variable
9
+ import numpy as np
10
+ from pdb import set_trace as st
11
+ from skimage import color
12
+ from IPython import embed
13
+ from . import pretrained_networks as pn
14
+
15
+ import lpips as util
16
+
17
+ def spatial_average(in_tens, keepdim=True):
18
+ return in_tens.mean([2,3],keepdim=keepdim)
19
+
20
+ def upsample(in_tens, out_H=64): # assumes scale factor is same for H and W
21
+ in_H = in_tens.shape[2]
22
+ scale_factor = 1.*out_H/in_H
23
+
24
+ return nn.Upsample(scale_factor=scale_factor, mode='bilinear', align_corners=False)(in_tens)
25
+
26
+ # Learned perceptual metric
27
+ class PNetLin(nn.Module):
28
+ def __init__(self, pnet_type='vgg', pnet_rand=False, pnet_tune=False, use_dropout=True, spatial=False, version='0.1', lpips=True):
29
+ super(PNetLin, self).__init__()
30
+
31
+ self.pnet_type = pnet_type
32
+ self.pnet_tune = pnet_tune
33
+ self.pnet_rand = pnet_rand
34
+ self.spatial = spatial
35
+ self.lpips = lpips
36
+ self.version = version
37
+ self.scaling_layer = ScalingLayer()
38
+
39
+ if(self.pnet_type in ['vgg','vgg16']):
40
+ net_type = pn.vgg16
41
+ self.chns = [64,128,256,512,512]
42
+ elif(self.pnet_type=='alex'):
43
+ net_type = pn.alexnet
44
+ self.chns = [64,192,384,256,256]
45
+ elif(self.pnet_type=='squeeze'):
46
+ net_type = pn.squeezenet
47
+ self.chns = [64,128,256,384,384,512,512]
48
+ self.L = len(self.chns)
49
+
50
+ self.net = net_type(pretrained=not self.pnet_rand, requires_grad=self.pnet_tune)
51
+
52
+ if(lpips):
53
+ self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout)
54
+ self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout)
55
+ self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout)
56
+ self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout)
57
+ self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout)
58
+ self.lins = [self.lin0,self.lin1,self.lin2,self.lin3,self.lin4]
59
+ if(self.pnet_type=='squeeze'): # 7 layers for squeezenet
60
+ self.lin5 = NetLinLayer(self.chns[5], use_dropout=use_dropout)
61
+ self.lin6 = NetLinLayer(self.chns[6], use_dropout=use_dropout)
62
+ self.lins+=[self.lin5,self.lin6]
63
+
64
+ def forward(self, in0, in1, retPerLayer=False):
65
+ # v0.0 - original release had a bug, where input was not scaled
66
+ in0_input, in1_input = (self.scaling_layer(in0), self.scaling_layer(in1)) if self.version=='0.1' else (in0, in1)
67
+ outs0, outs1 = self.net.forward(in0_input), self.net.forward(in1_input)
68
+ feats0, feats1, diffs = {}, {}, {}
69
+
70
+ for kk in range(self.L):
71
+ feats0[kk], feats1[kk] = util.normalize_tensor(outs0[kk]), util.normalize_tensor(outs1[kk])
72
+ diffs[kk] = (feats0[kk]-feats1[kk])**2
73
+
74
+ if(self.lpips):
75
+ if(self.spatial):
76
+ res = [upsample(self.lins[kk].model(diffs[kk]), out_H=in0.shape[2]) for kk in range(self.L)]
77
+ else:
78
+ res = [spatial_average(self.lins[kk].model(diffs[kk]), keepdim=True) for kk in range(self.L)]
79
+ else:
80
+ if(self.spatial):
81
+ res = [upsample(diffs[kk].sum(dim=1,keepdim=True), out_H=in0.shape[2]) for kk in range(self.L)]
82
+ else:
83
+ res = [spatial_average(diffs[kk].sum(dim=1,keepdim=True), keepdim=True) for kk in range(self.L)]
84
+
85
+ val = res[0]
86
+ for l in range(1,self.L):
87
+ val += res[l]
88
+
89
+ if(retPerLayer):
90
+ return (val, res)
91
+ else:
92
+ return val
93
+
94
+ class ScalingLayer(nn.Module):
95
+ def __init__(self):
96
+ super(ScalingLayer, self).__init__()
97
+ self.register_buffer('shift', torch.Tensor([-.030,-.088,-.188])[None,:,None,None])
98
+ self.register_buffer('scale', torch.Tensor([.458,.448,.450])[None,:,None,None])
99
+
100
+ def forward(self, inp):
101
+ return (inp - self.shift) / self.scale
102
+
103
+
104
+ class NetLinLayer(nn.Module):
105
+ ''' A single linear layer which does a 1x1 conv '''
106
+ def __init__(self, chn_in, chn_out=1, use_dropout=False):
107
+ super(NetLinLayer, self).__init__()
108
+
109
+ layers = [nn.Dropout(),] if(use_dropout) else []
110
+ layers += [nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False),]
111
+ self.model = nn.Sequential(*layers)
112
+
113
+
114
+ class Dist2LogitLayer(nn.Module):
115
+ ''' takes 2 distances, puts through fc layers, spits out value between [0,1] (if use_sigmoid is True) '''
116
+ def __init__(self, chn_mid=32, use_sigmoid=True):
117
+ super(Dist2LogitLayer, self).__init__()
118
+
119
+ layers = [nn.Conv2d(5, chn_mid, 1, stride=1, padding=0, bias=True),]
120
+ layers += [nn.LeakyReLU(0.2,True),]
121
+ layers += [nn.Conv2d(chn_mid, chn_mid, 1, stride=1, padding=0, bias=True),]
122
+ layers += [nn.LeakyReLU(0.2,True),]
123
+ layers += [nn.Conv2d(chn_mid, 1, 1, stride=1, padding=0, bias=True),]
124
+ if(use_sigmoid):
125
+ layers += [nn.Sigmoid(),]
126
+ self.model = nn.Sequential(*layers)
127
+
128
+ def forward(self,d0,d1,eps=0.1):
129
+ return self.model.forward(torch.cat((d0,d1,d0-d1,d0/(d1+eps),d1/(d0+eps)),dim=1))
130
+
131
+ class BCERankingLoss(nn.Module):
132
+ def __init__(self, chn_mid=32):
133
+ super(BCERankingLoss, self).__init__()
134
+ self.net = Dist2LogitLayer(chn_mid=chn_mid)
135
+ # self.parameters = list(self.net.parameters())
136
+ self.loss = torch.nn.BCELoss()
137
+
138
+ def forward(self, d0, d1, judge):
139
+ per = (judge+1.)/2.
140
+ self.logit = self.net.forward(d0,d1)
141
+ return self.loss(self.logit, per)
142
+
143
+ # L2, DSSIM metrics
144
+ class FakeNet(nn.Module):
145
+ def __init__(self, use_gpu=True, colorspace='Lab'):
146
+ super(FakeNet, self).__init__()
147
+ self.use_gpu = use_gpu
148
+ self.colorspace=colorspace
149
+
150
+ class L2(FakeNet):
151
+
152
+ def forward(self, in0, in1, retPerLayer=None):
153
+ assert(in0.size()[0]==1) # currently only supports batchSize 1
154
+
155
+ if(self.colorspace=='RGB'):
156
+ (N,C,X,Y) = in0.size()
157
+ value = torch.mean(torch.mean(torch.mean((in0-in1)**2,dim=1).view(N,1,X,Y),dim=2).view(N,1,1,Y),dim=3).view(N)
158
+ return value
159
+ elif(self.colorspace=='Lab'):
160
+ value = util.l2(util.tensor2np(util.tensor2tensorlab(in0.data,to_norm=False)),
161
+ util.tensor2np(util.tensor2tensorlab(in1.data,to_norm=False)), range=100.).astype('float')
162
+ ret_var = Variable( torch.Tensor((value,) ) )
163
+ if(self.use_gpu):
164
+ ret_var = ret_var.cuda()
165
+ return ret_var
166
+
167
+ class DSSIM(FakeNet):
168
+
169
+ def forward(self, in0, in1, retPerLayer=None):
170
+ assert(in0.size()[0]==1) # currently only supports batchSize 1
171
+
172
+ if(self.colorspace=='RGB'):
173
+ value = util.dssim(1.*util.tensor2im(in0.data), 1.*util.tensor2im(in1.data), range=255.).astype('float')
174
+ elif(self.colorspace=='Lab'):
175
+ value = util.dssim(util.tensor2np(util.tensor2tensorlab(in0.data,to_norm=False)),
176
+ util.tensor2np(util.tensor2tensorlab(in1.data,to_norm=False)), range=100.).astype('float')
177
+ ret_var = Variable( torch.Tensor((value,) ) )
178
+ if(self.use_gpu):
179
+ ret_var = ret_var.cuda()
180
+ return ret_var
181
+
182
+ def print_network(net):
183
+ num_params = 0
184
+ for param in net.parameters():
185
+ num_params += param.numel()
186
+ print('Network',net)
187
+ print('Total number of parameters: %d' % num_params)
lpips/pretrained_networks.py ADDED
@@ -0,0 +1,181 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from collections import namedtuple
2
+ import torch
3
+ from torchvision import models as tv
4
+ from IPython import embed
5
+
6
+ class squeezenet(torch.nn.Module):
7
+ def __init__(self, requires_grad=False, pretrained=True):
8
+ super(squeezenet, self).__init__()
9
+ pretrained_features = tv.squeezenet1_1(pretrained=pretrained).features
10
+ self.slice1 = torch.nn.Sequential()
11
+ self.slice2 = torch.nn.Sequential()
12
+ self.slice3 = torch.nn.Sequential()
13
+ self.slice4 = torch.nn.Sequential()
14
+ self.slice5 = torch.nn.Sequential()
15
+ self.slice6 = torch.nn.Sequential()
16
+ self.slice7 = torch.nn.Sequential()
17
+ self.N_slices = 7
18
+ for x in range(2):
19
+ self.slice1.add_module(str(x), pretrained_features[x])
20
+ for x in range(2,5):
21
+ self.slice2.add_module(str(x), pretrained_features[x])
22
+ for x in range(5, 8):
23
+ self.slice3.add_module(str(x), pretrained_features[x])
24
+ for x in range(8, 10):
25
+ self.slice4.add_module(str(x), pretrained_features[x])
26
+ for x in range(10, 11):
27
+ self.slice5.add_module(str(x), pretrained_features[x])
28
+ for x in range(11, 12):
29
+ self.slice6.add_module(str(x), pretrained_features[x])
30
+ for x in range(12, 13):
31
+ self.slice7.add_module(str(x), pretrained_features[x])
32
+ if not requires_grad:
33
+ for param in self.parameters():
34
+ param.requires_grad = False
35
+
36
+ def forward(self, X):
37
+ h = self.slice1(X)
38
+ h_relu1 = h
39
+ h = self.slice2(h)
40
+ h_relu2 = h
41
+ h = self.slice3(h)
42
+ h_relu3 = h
43
+ h = self.slice4(h)
44
+ h_relu4 = h
45
+ h = self.slice5(h)
46
+ h_relu5 = h
47
+ h = self.slice6(h)
48
+ h_relu6 = h
49
+ h = self.slice7(h)
50
+ h_relu7 = h
51
+ vgg_outputs = namedtuple("SqueezeOutputs", ['relu1','relu2','relu3','relu4','relu5','relu6','relu7'])
52
+ out = vgg_outputs(h_relu1,h_relu2,h_relu3,h_relu4,h_relu5,h_relu6,h_relu7)
53
+
54
+ return out
55
+
56
+
57
+ class alexnet(torch.nn.Module):
58
+ def __init__(self, requires_grad=False, pretrained=True):
59
+ super(alexnet, self).__init__()
60
+ alexnet_pretrained_features = tv.alexnet(pretrained=pretrained).features
61
+ self.slice1 = torch.nn.Sequential()
62
+ self.slice2 = torch.nn.Sequential()
63
+ self.slice3 = torch.nn.Sequential()
64
+ self.slice4 = torch.nn.Sequential()
65
+ self.slice5 = torch.nn.Sequential()
66
+ self.N_slices = 5
67
+ for x in range(2):
68
+ self.slice1.add_module(str(x), alexnet_pretrained_features[x])
69
+ for x in range(2, 5):
70
+ self.slice2.add_module(str(x), alexnet_pretrained_features[x])
71
+ for x in range(5, 8):
72
+ self.slice3.add_module(str(x), alexnet_pretrained_features[x])
73
+ for x in range(8, 10):
74
+ self.slice4.add_module(str(x), alexnet_pretrained_features[x])
75
+ for x in range(10, 12):
76
+ self.slice5.add_module(str(x), alexnet_pretrained_features[x])
77
+ if not requires_grad:
78
+ for param in self.parameters():
79
+ param.requires_grad = False
80
+
81
+ def forward(self, X):
82
+ h = self.slice1(X)
83
+ h_relu1 = h
84
+ h = self.slice2(h)
85
+ h_relu2 = h
86
+ h = self.slice3(h)
87
+ h_relu3 = h
88
+ h = self.slice4(h)
89
+ h_relu4 = h
90
+ h = self.slice5(h)
91
+ h_relu5 = h
92
+ alexnet_outputs = namedtuple("AlexnetOutputs", ['relu1', 'relu2', 'relu3', 'relu4', 'relu5'])
93
+ out = alexnet_outputs(h_relu1, h_relu2, h_relu3, h_relu4, h_relu5)
94
+
95
+ return out
96
+
97
+ class vgg16(torch.nn.Module):
98
+ def __init__(self, requires_grad=False, pretrained=True):
99
+ super(vgg16, self).__init__()
100
+ vgg_pretrained_features = tv.vgg16(pretrained=pretrained).features
101
+ self.slice1 = torch.nn.Sequential()
102
+ self.slice2 = torch.nn.Sequential()
103
+ self.slice3 = torch.nn.Sequential()
104
+ self.slice4 = torch.nn.Sequential()
105
+ self.slice5 = torch.nn.Sequential()
106
+ self.N_slices = 5
107
+ for x in range(4):
108
+ self.slice1.add_module(str(x), vgg_pretrained_features[x])
109
+ for x in range(4, 9):
110
+ self.slice2.add_module(str(x), vgg_pretrained_features[x])
111
+ for x in range(9, 16):
112
+ self.slice3.add_module(str(x), vgg_pretrained_features[x])
113
+ for x in range(16, 23):
114
+ self.slice4.add_module(str(x), vgg_pretrained_features[x])
115
+ for x in range(23, 30):
116
+ self.slice5.add_module(str(x), vgg_pretrained_features[x])
117
+ if not requires_grad:
118
+ for param in self.parameters():
119
+ param.requires_grad = False
120
+
121
+ def forward(self, X):
122
+ h = self.slice1(X)
123
+ h_relu1_2 = h
124
+ h = self.slice2(h)
125
+ h_relu2_2 = h
126
+ h = self.slice3(h)
127
+ h_relu3_3 = h
128
+ h = self.slice4(h)
129
+ h_relu4_3 = h
130
+ h = self.slice5(h)
131
+ h_relu5_3 = h
132
+ vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3'])
133
+ out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3)
134
+
135
+ return out
136
+
137
+
138
+
139
+ class resnet(torch.nn.Module):
140
+ def __init__(self, requires_grad=False, pretrained=True, num=18):
141
+ super(resnet, self).__init__()
142
+ if(num==18):
143
+ self.net = tv.resnet18(pretrained=pretrained)
144
+ elif(num==34):
145
+ self.net = tv.resnet34(pretrained=pretrained)
146
+ elif(num==50):
147
+ self.net = tv.resnet50(pretrained=pretrained)
148
+ elif(num==101):
149
+ self.net = tv.resnet101(pretrained=pretrained)
150
+ elif(num==152):
151
+ self.net = tv.resnet152(pretrained=pretrained)
152
+ self.N_slices = 5
153
+
154
+ self.conv1 = self.net.conv1
155
+ self.bn1 = self.net.bn1
156
+ self.relu = self.net.relu
157
+ self.maxpool = self.net.maxpool
158
+ self.layer1 = self.net.layer1
159
+ self.layer2 = self.net.layer2
160
+ self.layer3 = self.net.layer3
161
+ self.layer4 = self.net.layer4
162
+
163
+ def forward(self, X):
164
+ h = self.conv1(X)
165
+ h = self.bn1(h)
166
+ h = self.relu(h)
167
+ h_relu1 = h
168
+ h = self.maxpool(h)
169
+ h = self.layer1(h)
170
+ h_conv2 = h
171
+ h = self.layer2(h)
172
+ h_conv3 = h
173
+ h = self.layer3(h)
174
+ h_conv4 = h
175
+ h = self.layer4(h)
176
+ h_conv5 = h
177
+
178
+ outputs = namedtuple("Outputs", ['relu1','conv2','conv3','conv4','conv5'])
179
+ out = outputs(h_relu1, h_conv2, h_conv3, h_conv4, h_conv5)
180
+
181
+ return out
lpips/weights/v0.0/alex.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:18720f55913d0af89042f13faa7e536a6ce1444a0914e6db9461355ece1e8cd5
3
+ size 5455
lpips/weights/v0.0/squeeze.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c27abd3a0145541baa50990817df58d3759c3f8154949f42af3b59b4e042d0bf
3
+ size 10057
lpips/weights/v0.0/vgg.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b9e4236260c3dd988fc79d2a48d645d885afcbb21f9fd595e6744cf7419b582c
3
+ size 6735
lpips/weights/v0.1/alex.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:df73285e35b22355a2df87cdb6b70b343713b667eddbda73e1977e0c860835c0
3
+ size 6009
lpips/weights/v0.1/squeeze.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4a5350f23600cb79923ce65bb07cbf57dca461329894153e05a1346bd531cf76
3
+ size 10811
lpips/weights/v0.1/vgg.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a78928a0af1e5f0fcb1f3b9e8f8c3a2a5a3de244d830ad5c1feddc79b8432868
3
+ size 7289
model.py ADDED
@@ -0,0 +1,698 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import random
3
+ import functools
4
+ import operator
5
+
6
+ import torch
7
+ from torch import nn
8
+ from torch.nn import functional as F
9
+ from torch.autograd import Function
10
+
11
+ from op import FusedLeakyReLU, fused_leaky_relu, upfirdn2d, conv2d_gradfix
12
+
13
+
14
+ class PixelNorm(nn.Module):
15
+ def __init__(self):
16
+ super().__init__()
17
+
18
+ def forward(self, input):
19
+ return input * torch.rsqrt(torch.mean(input ** 2, dim=1, keepdim=True) + 1e-8)
20
+
21
+
22
+ def make_kernel(k):
23
+ k = torch.tensor(k, dtype=torch.float32)
24
+
25
+ if k.ndim == 1:
26
+ k = k[None, :] * k[:, None]
27
+
28
+ k /= k.sum()
29
+
30
+ return k
31
+
32
+
33
+ class Upsample(nn.Module):
34
+ def __init__(self, kernel, factor=2):
35
+ super().__init__()
36
+
37
+ self.factor = factor
38
+ kernel = make_kernel(kernel) * (factor ** 2)
39
+ self.register_buffer("kernel", kernel)
40
+
41
+ p = kernel.shape[0] - factor
42
+
43
+ pad0 = (p + 1) // 2 + factor - 1
44
+ pad1 = p // 2
45
+
46
+ self.pad = (pad0, pad1)
47
+
48
+ def forward(self, input):
49
+ out = upfirdn2d(input, self.kernel, up=self.factor, down=1, pad=self.pad)
50
+
51
+ return out
52
+
53
+
54
+ class Downsample(nn.Module):
55
+ def __init__(self, kernel, factor=2):
56
+ super().__init__()
57
+
58
+ self.factor = factor
59
+ kernel = make_kernel(kernel)
60
+ self.register_buffer("kernel", kernel)
61
+
62
+ p = kernel.shape[0] - factor
63
+
64
+ pad0 = (p + 1) // 2
65
+ pad1 = p // 2
66
+
67
+ self.pad = (pad0, pad1)
68
+
69
+ def forward(self, input):
70
+ out = upfirdn2d(input, self.kernel, up=1, down=self.factor, pad=self.pad)
71
+
72
+ return out
73
+
74
+
75
+ class Blur(nn.Module):
76
+ def __init__(self, kernel, pad, upsample_factor=1):
77
+ super().__init__()
78
+
79
+ kernel = make_kernel(kernel)
80
+
81
+ if upsample_factor > 1:
82
+ kernel = kernel * (upsample_factor ** 2)
83
+
84
+ self.register_buffer("kernel", kernel)
85
+
86
+ self.pad = pad
87
+
88
+ def forward(self, input):
89
+ out = upfirdn2d(input, self.kernel, pad=self.pad)
90
+
91
+ return out
92
+
93
+
94
+ class EqualConv2d(nn.Module):
95
+ def __init__(
96
+ self, in_channel, out_channel, kernel_size, stride=1, padding=0, bias=True
97
+ ):
98
+ super().__init__()
99
+
100
+ self.weight = nn.Parameter(
101
+ torch.randn(out_channel, in_channel, kernel_size, kernel_size)
102
+ )
103
+ self.scale = 1 / math.sqrt(in_channel * kernel_size ** 2)
104
+
105
+ self.stride = stride
106
+ self.padding = padding
107
+
108
+ if bias:
109
+ self.bias = nn.Parameter(torch.zeros(out_channel))
110
+
111
+ else:
112
+ self.bias = None
113
+
114
+ def forward(self, input):
115
+ out = conv2d_gradfix.conv2d(
116
+ input,
117
+ self.weight * self.scale,
118
+ bias=self.bias,
119
+ stride=self.stride,
120
+ padding=self.padding,
121
+ )
122
+
123
+ return out
124
+
125
+ def __repr__(self):
126
+ return (
127
+ f"{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]},"
128
+ f" {self.weight.shape[2]}, stride={self.stride}, padding={self.padding})"
129
+ )
130
+
131
+
132
+ class EqualLinear(nn.Module):
133
+ def __init__(
134
+ self, in_dim, out_dim, bias=True, bias_init=0, lr_mul=1, activation=None
135
+ ):
136
+ super().__init__()
137
+
138
+ self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul))
139
+
140
+ if bias:
141
+ self.bias = nn.Parameter(torch.zeros(out_dim).fill_(bias_init))
142
+
143
+ else:
144
+ self.bias = None
145
+
146
+ self.activation = activation
147
+
148
+ self.scale = (1 / math.sqrt(in_dim)) * lr_mul
149
+ self.lr_mul = lr_mul
150
+
151
+ def forward(self, input):
152
+ if self.activation:
153
+ out = F.linear(input, self.weight * self.scale)
154
+ out = fused_leaky_relu(out, self.bias * self.lr_mul)
155
+
156
+ else:
157
+ out = F.linear(
158
+ input, self.weight * self.scale, bias=self.bias * self.lr_mul
159
+ )
160
+
161
+ return out
162
+
163
+ def __repr__(self):
164
+ return (
165
+ f"{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]})"
166
+ )
167
+
168
+
169
+ class ModulatedConv2d(nn.Module):
170
+ def __init__(
171
+ self,
172
+ in_channel,
173
+ out_channel,
174
+ kernel_size,
175
+ style_dim,
176
+ demodulate=True,
177
+ upsample=False,
178
+ downsample=False,
179
+ blur_kernel=[1, 3, 3, 1],
180
+ fused=True,
181
+ ):
182
+ super().__init__()
183
+
184
+ self.eps = 1e-8
185
+ self.kernel_size = kernel_size
186
+ self.in_channel = in_channel
187
+ self.out_channel = out_channel
188
+ self.upsample = upsample
189
+ self.downsample = downsample
190
+
191
+ if upsample:
192
+ factor = 2
193
+ p = (len(blur_kernel) - factor) - (kernel_size - 1)
194
+ pad0 = (p + 1) // 2 + factor - 1
195
+ pad1 = p // 2 + 1
196
+
197
+ self.blur = Blur(blur_kernel, pad=(pad0, pad1), upsample_factor=factor)
198
+
199
+ if downsample:
200
+ factor = 2
201
+ p = (len(blur_kernel) - factor) + (kernel_size - 1)
202
+ pad0 = (p + 1) // 2
203
+ pad1 = p // 2
204
+
205
+ self.blur = Blur(blur_kernel, pad=(pad0, pad1))
206
+
207
+ fan_in = in_channel * kernel_size ** 2
208
+ self.scale = 1 / math.sqrt(fan_in)
209
+ self.padding = kernel_size // 2
210
+
211
+ self.weight = nn.Parameter(
212
+ torch.randn(1, out_channel, in_channel, kernel_size, kernel_size)
213
+ )
214
+
215
+ self.modulation = EqualLinear(style_dim, in_channel, bias_init=1)
216
+
217
+ self.demodulate = demodulate
218
+ self.fused = fused
219
+
220
+ def __repr__(self):
221
+ return (
222
+ f"{self.__class__.__name__}({self.in_channel}, {self.out_channel}, {self.kernel_size}, "
223
+ f"upsample={self.upsample}, downsample={self.downsample})"
224
+ )
225
+
226
+ def forward(self, input, style):
227
+ batch, in_channel, height, width = input.shape
228
+
229
+ if not self.fused:
230
+ weight = self.scale * self.weight.squeeze(0)
231
+ style = self.modulation(style)
232
+
233
+ if self.demodulate:
234
+ w = weight.unsqueeze(0) * style.view(batch, 1, in_channel, 1, 1)
235
+ dcoefs = (w.square().sum((2, 3, 4)) + 1e-8).rsqrt()
236
+
237
+ input = input * style.reshape(batch, in_channel, 1, 1)
238
+
239
+ if self.upsample:
240
+ weight = weight.transpose(0, 1)
241
+ out = conv2d_gradfix.conv_transpose2d(
242
+ input, weight, padding=0, stride=2
243
+ )
244
+ out = self.blur(out)
245
+
246
+ elif self.downsample:
247
+ input = self.blur(input)
248
+ out = conv2d_gradfix.conv2d(input, weight, padding=0, stride=2)
249
+
250
+ else:
251
+ out = conv2d_gradfix.conv2d(input, weight, padding=self.padding)
252
+
253
+ if self.demodulate:
254
+ out = out * dcoefs.view(batch, -1, 1, 1)
255
+
256
+ return out
257
+
258
+ style = self.modulation(style).view(batch, 1, in_channel, 1, 1)
259
+ weight = self.scale * self.weight * style
260
+
261
+ if self.demodulate:
262
+ demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-8)
263
+ weight = weight * demod.view(batch, self.out_channel, 1, 1, 1)
264
+
265
+ weight = weight.view(
266
+ batch * self.out_channel, in_channel, self.kernel_size, self.kernel_size
267
+ )
268
+
269
+ if self.upsample:
270
+ input = input.view(1, batch * in_channel, height, width)
271
+ weight = weight.view(
272
+ batch, self.out_channel, in_channel, self.kernel_size, self.kernel_size
273
+ )
274
+ weight = weight.transpose(1, 2).reshape(
275
+ batch * in_channel, self.out_channel, self.kernel_size, self.kernel_size
276
+ )
277
+ out = conv2d_gradfix.conv_transpose2d(
278
+ input, weight, padding=0, stride=2, groups=batch
279
+ )
280
+ _, _, height, width = out.shape
281
+ out = out.view(batch, self.out_channel, height, width)
282
+ out = self.blur(out)
283
+
284
+ elif self.downsample:
285
+ input = self.blur(input)
286
+ _, _, height, width = input.shape
287
+ input = input.view(1, batch * in_channel, height, width)
288
+ out = conv2d_gradfix.conv2d(
289
+ input, weight, padding=0, stride=2, groups=batch
290
+ )
291
+ _, _, height, width = out.shape
292
+ out = out.view(batch, self.out_channel, height, width)
293
+
294
+ else:
295
+ input = input.view(1, batch * in_channel, height, width)
296
+ out = conv2d_gradfix.conv2d(
297
+ input, weight, padding=self.padding, groups=batch
298
+ )
299
+ _, _, height, width = out.shape
300
+ out = out.view(batch, self.out_channel, height, width)
301
+
302
+ return out
303
+
304
+
305
+ class NoiseInjection(nn.Module):
306
+ def __init__(self):
307
+ super().__init__()
308
+
309
+ self.weight = nn.Parameter(torch.zeros(1))
310
+
311
+ def forward(self, image, noise=None):
312
+ if noise is None:
313
+ batch, _, height, width = image.shape
314
+ noise = image.new_empty(batch, 1, height, width).normal_()
315
+
316
+ return image + self.weight * noise
317
+
318
+
319
+ class ConstantInput(nn.Module):
320
+ def __init__(self, channel, size=4):
321
+ super().__init__()
322
+
323
+ self.input = nn.Parameter(torch.randn(1, channel, size, size))
324
+
325
+ def forward(self, input):
326
+ batch = input.shape[0]
327
+ out = self.input.repeat(batch, 1, 1, 1)
328
+
329
+ return out
330
+
331
+
332
+ class StyledConv(nn.Module):
333
+ def __init__(
334
+ self,
335
+ in_channel,
336
+ out_channel,
337
+ kernel_size,
338
+ style_dim,
339
+ upsample=False,
340
+ blur_kernel=[1, 3, 3, 1],
341
+ demodulate=True,
342
+ ):
343
+ super().__init__()
344
+
345
+ self.conv = ModulatedConv2d(
346
+ in_channel,
347
+ out_channel,
348
+ kernel_size,
349
+ style_dim,
350
+ upsample=upsample,
351
+ blur_kernel=blur_kernel,
352
+ demodulate=demodulate,
353
+ )
354
+
355
+ self.noise = NoiseInjection()
356
+ # self.bias = nn.Parameter(torch.zeros(1, out_channel, 1, 1))
357
+ # self.activate = ScaledLeakyReLU(0.2)
358
+ self.activate = FusedLeakyReLU(out_channel)
359
+
360
+ def forward(self, input, style, noise=None):
361
+ out = self.conv(input, style)
362
+ out = self.noise(out, noise=noise)
363
+ # out = out + self.bias
364
+ out = self.activate(out)
365
+
366
+ return out
367
+
368
+
369
+ class ToRGB(nn.Module):
370
+ def __init__(self, in_channel, style_dim, upsample=True, blur_kernel=[1, 3, 3, 1]):
371
+ super().__init__()
372
+
373
+ if upsample:
374
+ self.upsample = Upsample(blur_kernel)
375
+
376
+ self.conv = ModulatedConv2d(in_channel, 3, 1, style_dim, demodulate=False)
377
+ self.bias = nn.Parameter(torch.zeros(1, 3, 1, 1))
378
+
379
+ def forward(self, input, style, skip=None):
380
+ out = self.conv(input, style)
381
+ out = out + self.bias
382
+
383
+ if skip is not None:
384
+ skip = self.upsample(skip)
385
+
386
+ out = out + skip
387
+
388
+ return out
389
+
390
+
391
+ class Generator(nn.Module):
392
+ def __init__(
393
+ self,
394
+ size,
395
+ style_dim,
396
+ n_mlp,
397
+ channel_multiplier=2,
398
+ blur_kernel=[1, 3, 3, 1],
399
+ lr_mlp=0.01,
400
+ ):
401
+ super().__init__()
402
+
403
+ self.size = size
404
+
405
+ self.style_dim = style_dim
406
+
407
+ layers = [PixelNorm()]
408
+
409
+ for i in range(n_mlp):
410
+ layers.append(
411
+ EqualLinear(
412
+ style_dim, style_dim, lr_mul=lr_mlp, activation="fused_lrelu"
413
+ )
414
+ )
415
+
416
+ self.style = nn.Sequential(*layers)
417
+
418
+ self.channels = {
419
+ 4: 512,
420
+ 8: 512,
421
+ 16: 512,
422
+ 32: 512,
423
+ 64: 256 * channel_multiplier,
424
+ 128: 128 * channel_multiplier,
425
+ 256: 64 * channel_multiplier,
426
+ 512: 32 * channel_multiplier,
427
+ 1024: 16 * channel_multiplier,
428
+ }
429
+
430
+ self.input = ConstantInput(self.channels[4])
431
+ self.conv1 = StyledConv(
432
+ self.channels[4], self.channels[4], 3, style_dim, blur_kernel=blur_kernel
433
+ )
434
+ self.to_rgb1 = ToRGB(self.channels[4], style_dim, upsample=False)
435
+
436
+ self.log_size = int(math.log(size, 2))
437
+ self.num_layers = (self.log_size - 2) * 2 + 1
438
+
439
+ self.convs = nn.ModuleList()
440
+ self.upsamples = nn.ModuleList()
441
+ self.to_rgbs = nn.ModuleList()
442
+ self.noises = nn.Module()
443
+
444
+ in_channel = self.channels[4]
445
+
446
+ for layer_idx in range(self.num_layers):
447
+ res = (layer_idx + 5) // 2
448
+ shape = [1, 1, 2 ** res, 2 ** res]
449
+ self.noises.register_buffer(f"noise_{layer_idx}", torch.randn(*shape))
450
+
451
+ for i in range(3, self.log_size + 1):
452
+ out_channel = self.channels[2 ** i]
453
+
454
+ self.convs.append(
455
+ StyledConv(
456
+ in_channel,
457
+ out_channel,
458
+ 3,
459
+ style_dim,
460
+ upsample=True,
461
+ blur_kernel=blur_kernel,
462
+ )
463
+ )
464
+
465
+ self.convs.append(
466
+ StyledConv(
467
+ out_channel, out_channel, 3, style_dim, blur_kernel=blur_kernel
468
+ )
469
+ )
470
+
471
+ self.to_rgbs.append(ToRGB(out_channel, style_dim))
472
+
473
+ in_channel = out_channel
474
+
475
+ self.n_latent = self.log_size * 2 - 2
476
+
477
+ def make_noise(self):
478
+ device = self.input.input.device
479
+
480
+ noises = [torch.randn(1, 1, 2 ** 2, 2 ** 2, device=device)]
481
+
482
+ for i in range(3, self.log_size + 1):
483
+ for _ in range(2):
484
+ noises.append(torch.randn(1, 1, 2 ** i, 2 ** i, device=device))
485
+
486
+ return noises
487
+
488
+ def mean_latent(self, n_latent):
489
+ latent_in = torch.randn(
490
+ n_latent, self.style_dim, device=self.input.input.device
491
+ )
492
+ latent = self.style(latent_in).mean(0, keepdim=True)
493
+
494
+ return latent
495
+
496
+ def get_latent(self, input):
497
+ return self.style(input)
498
+
499
+ def forward(
500
+ self,
501
+ styles,
502
+ return_latents=False,
503
+ inject_index=None,
504
+ truncation=1,
505
+ truncation_latent=None,
506
+ input_is_latent=False,
507
+ noise=None,
508
+ randomize_noise=True,
509
+ ):
510
+ if not input_is_latent:
511
+ styles = [self.style(s) for s in styles]
512
+
513
+ if noise is None:
514
+ if randomize_noise:
515
+ noise = [None] * self.num_layers
516
+ else:
517
+ noise = [
518
+ getattr(self.noises, f"noise_{i}") for i in range(self.num_layers)
519
+ ]
520
+
521
+ if truncation < 1:
522
+ style_t = []
523
+
524
+ for style in styles:
525
+ style_t.append(
526
+ truncation_latent + truncation * (style - truncation_latent)
527
+ )
528
+
529
+ styles = style_t
530
+
531
+ if len(styles) < 2:
532
+ inject_index = self.n_latent
533
+
534
+ if styles[0].ndim < 3:
535
+ latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
536
+
537
+ else:
538
+ latent = styles[0]
539
+
540
+ else:
541
+ if inject_index is None:
542
+ inject_index = random.randint(1, self.n_latent - 1)
543
+
544
+ latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
545
+ latent2 = styles[1].unsqueeze(1).repeat(1, self.n_latent - inject_index, 1)
546
+
547
+ latent = torch.cat([latent, latent2], 1)
548
+
549
+ out = self.input(latent)
550
+ out = self.conv1(out, latent[:, 0], noise=noise[0])
551
+
552
+ skip = self.to_rgb1(out, latent[:, 1])
553
+
554
+ i = 1
555
+ for conv1, conv2, noise1, noise2, to_rgb in zip(
556
+ self.convs[::2], self.convs[1::2], noise[1::2], noise[2::2], self.to_rgbs
557
+ ):
558
+ out = conv1(out, latent[:, i], noise=noise1)
559
+ out = conv2(out, latent[:, i + 1], noise=noise2)
560
+ skip = to_rgb(out, latent[:, i + 2], skip)
561
+
562
+ i += 2
563
+
564
+ image = skip
565
+
566
+ if return_latents:
567
+ return image, latent
568
+
569
+ else:
570
+ return image, None
571
+
572
+
573
+ class ConvLayer(nn.Sequential):
574
+ def __init__(
575
+ self,
576
+ in_channel,
577
+ out_channel,
578
+ kernel_size,
579
+ downsample=False,
580
+ blur_kernel=[1, 3, 3, 1],
581
+ bias=True,
582
+ activate=True,
583
+ ):
584
+ layers = []
585
+
586
+ if downsample:
587
+ factor = 2
588
+ p = (len(blur_kernel) - factor) + (kernel_size - 1)
589
+ pad0 = (p + 1) // 2
590
+ pad1 = p // 2
591
+
592
+ layers.append(Blur(blur_kernel, pad=(pad0, pad1)))
593
+
594
+ stride = 2
595
+ self.padding = 0
596
+
597
+ else:
598
+ stride = 1
599
+ self.padding = kernel_size // 2
600
+
601
+ layers.append(
602
+ EqualConv2d(
603
+ in_channel,
604
+ out_channel,
605
+ kernel_size,
606
+ padding=self.padding,
607
+ stride=stride,
608
+ bias=bias and not activate,
609
+ )
610
+ )
611
+
612
+ if activate:
613
+ layers.append(FusedLeakyReLU(out_channel, bias=bias))
614
+
615
+ super().__init__(*layers)
616
+
617
+
618
+ class ResBlock(nn.Module):
619
+ def __init__(self, in_channel, out_channel, blur_kernel=[1, 3, 3, 1]):
620
+ super().__init__()
621
+
622
+ self.conv1 = ConvLayer(in_channel, in_channel, 3)
623
+ self.conv2 = ConvLayer(in_channel, out_channel, 3, downsample=True)
624
+
625
+ self.skip = ConvLayer(
626
+ in_channel, out_channel, 1, downsample=True, activate=False, bias=False
627
+ )
628
+
629
+ def forward(self, input):
630
+ out = self.conv1(input)
631
+ out = self.conv2(out)
632
+
633
+ skip = self.skip(input)
634
+ out = (out + skip) / math.sqrt(2)
635
+
636
+ return out
637
+
638
+
639
+ class Discriminator(nn.Module):
640
+ def __init__(self, size, channel_multiplier=2, blur_kernel=[1, 3, 3, 1]):
641
+ super().__init__()
642
+
643
+ channels = {
644
+ 4: 512,
645
+ 8: 512,
646
+ 16: 512,
647
+ 32: 512,
648
+ 64: 256 * channel_multiplier,
649
+ 128: 128 * channel_multiplier,
650
+ 256: 64 * channel_multiplier,
651
+ 512: 32 * channel_multiplier,
652
+ 1024: 16 * channel_multiplier,
653
+ }
654
+
655
+ convs = [ConvLayer(3, channels[size], 1)]
656
+
657
+ log_size = int(math.log(size, 2))
658
+
659
+ in_channel = channels[size]
660
+
661
+ for i in range(log_size, 2, -1):
662
+ out_channel = channels[2 ** (i - 1)]
663
+
664
+ convs.append(ResBlock(in_channel, out_channel, blur_kernel))
665
+
666
+ in_channel = out_channel
667
+
668
+ self.convs = nn.Sequential(*convs)
669
+
670
+ self.stddev_group = 4
671
+ self.stddev_feat = 1
672
+
673
+ self.final_conv = ConvLayer(in_channel + 1, channels[4], 3)
674
+ self.final_linear = nn.Sequential(
675
+ EqualLinear(channels[4] * 4 * 4, channels[4], activation="fused_lrelu"),
676
+ EqualLinear(channels[4], 1),
677
+ )
678
+
679
+ def forward(self, input):
680
+ out = self.convs(input)
681
+
682
+ batch, channel, height, width = out.shape
683
+ group = min(batch, self.stddev_group)
684
+ stddev = out.view(
685
+ group, -1, self.stddev_feat, channel // self.stddev_feat, height, width
686
+ )
687
+ stddev = torch.sqrt(stddev.var(0, unbiased=False) + 1e-8)
688
+ stddev = stddev.mean([2, 3, 4], keepdims=True).squeeze(2)
689
+ stddev = stddev.repeat(group, 1, height, width)
690
+ out = torch.cat([out, stddev], 1)
691
+
692
+ out = self.final_conv(out)
693
+
694
+ out = out.view(batch, -1)
695
+ out = self.final_linear(out)
696
+
697
+ return out
698
+
non_leaking.py ADDED
@@ -0,0 +1,465 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+
3
+ import torch
4
+ from torch import autograd
5
+ from torch.nn import functional as F
6
+ import numpy as np
7
+
8
+ from distributed import reduce_sum
9
+ from op import upfirdn2d
10
+
11
+
12
+ class AdaptiveAugment:
13
+ def __init__(self, ada_aug_target, ada_aug_len, update_every, device):
14
+ self.ada_aug_target = ada_aug_target
15
+ self.ada_aug_len = ada_aug_len
16
+ self.update_every = update_every
17
+
18
+ self.ada_update = 0
19
+ self.ada_aug_buf = torch.tensor([0.0, 0.0], device=device)
20
+ self.r_t_stat = 0
21
+ self.ada_aug_p = 0
22
+
23
+ @torch.no_grad()
24
+ def tune(self, real_pred):
25
+ self.ada_aug_buf += torch.tensor(
26
+ (torch.sign(real_pred).sum().item(), real_pred.shape[0]),
27
+ device=real_pred.device,
28
+ )
29
+ self.ada_update += 1
30
+
31
+ if self.ada_update % self.update_every == 0:
32
+ self.ada_aug_buf = reduce_sum(self.ada_aug_buf)
33
+ pred_signs, n_pred = self.ada_aug_buf.tolist()
34
+
35
+ self.r_t_stat = pred_signs / n_pred
36
+
37
+ if self.r_t_stat > self.ada_aug_target:
38
+ sign = 1
39
+
40
+ else:
41
+ sign = -1
42
+
43
+ self.ada_aug_p += sign * n_pred / self.ada_aug_len
44
+ self.ada_aug_p = min(1, max(0, self.ada_aug_p))
45
+ self.ada_aug_buf.mul_(0)
46
+ self.ada_update = 0
47
+
48
+ return self.ada_aug_p
49
+
50
+
51
+ SYM6 = (
52
+ 0.015404109327027373,
53
+ 0.0034907120842174702,
54
+ -0.11799011114819057,
55
+ -0.048311742585633,
56
+ 0.4910559419267466,
57
+ 0.787641141030194,
58
+ 0.3379294217276218,
59
+ -0.07263752278646252,
60
+ -0.021060292512300564,
61
+ 0.04472490177066578,
62
+ 0.0017677118642428036,
63
+ -0.007800708325034148,
64
+ )
65
+
66
+
67
+ def translate_mat(t_x, t_y, device="cpu"):
68
+ batch = t_x.shape[0]
69
+
70
+ mat = torch.eye(3, device=device).unsqueeze(0).repeat(batch, 1, 1)
71
+ translate = torch.stack((t_x, t_y), 1)
72
+ mat[:, :2, 2] = translate
73
+
74
+ return mat
75
+
76
+
77
+ def rotate_mat(theta, device="cpu"):
78
+ batch = theta.shape[0]
79
+
80
+ mat = torch.eye(3, device=device).unsqueeze(0).repeat(batch, 1, 1)
81
+ sin_t = torch.sin(theta)
82
+ cos_t = torch.cos(theta)
83
+ rot = torch.stack((cos_t, -sin_t, sin_t, cos_t), 1).view(batch, 2, 2)
84
+ mat[:, :2, :2] = rot
85
+
86
+ return mat
87
+
88
+
89
+ def scale_mat(s_x, s_y, device="cpu"):
90
+ batch = s_x.shape[0]
91
+
92
+ mat = torch.eye(3, device=device).unsqueeze(0).repeat(batch, 1, 1)
93
+ mat[:, 0, 0] = s_x
94
+ mat[:, 1, 1] = s_y
95
+
96
+ return mat
97
+
98
+
99
+ def translate3d_mat(t_x, t_y, t_z):
100
+ batch = t_x.shape[0]
101
+
102
+ mat = torch.eye(4).unsqueeze(0).repeat(batch, 1, 1)
103
+ translate = torch.stack((t_x, t_y, t_z), 1)
104
+ mat[:, :3, 3] = translate
105
+
106
+ return mat
107
+
108
+
109
+ def rotate3d_mat(axis, theta):
110
+ batch = theta.shape[0]
111
+
112
+ u_x, u_y, u_z = axis
113
+
114
+ eye = torch.eye(3).unsqueeze(0)
115
+ cross = torch.tensor([(0, -u_z, u_y), (u_z, 0, -u_x), (-u_y, u_x, 0)]).unsqueeze(0)
116
+ outer = torch.tensor(axis)
117
+ outer = (outer.unsqueeze(1) * outer).unsqueeze(0)
118
+
119
+ sin_t = torch.sin(theta).view(-1, 1, 1)
120
+ cos_t = torch.cos(theta).view(-1, 1, 1)
121
+
122
+ rot = cos_t * eye + sin_t * cross + (1 - cos_t) * outer
123
+
124
+ eye_4 = torch.eye(4).unsqueeze(0).repeat(batch, 1, 1)
125
+ eye_4[:, :3, :3] = rot
126
+
127
+ return eye_4
128
+
129
+
130
+ def scale3d_mat(s_x, s_y, s_z):
131
+ batch = s_x.shape[0]
132
+
133
+ mat = torch.eye(4).unsqueeze(0).repeat(batch, 1, 1)
134
+ mat[:, 0, 0] = s_x
135
+ mat[:, 1, 1] = s_y
136
+ mat[:, 2, 2] = s_z
137
+
138
+ return mat
139
+
140
+
141
+ def luma_flip_mat(axis, i):
142
+ batch = i.shape[0]
143
+
144
+ eye = torch.eye(4).unsqueeze(0).repeat(batch, 1, 1)
145
+ axis = torch.tensor(axis + (0,))
146
+ flip = 2 * torch.ger(axis, axis) * i.view(-1, 1, 1)
147
+
148
+ return eye - flip
149
+
150
+
151
+ def saturation_mat(axis, i):
152
+ batch = i.shape[0]
153
+
154
+ eye = torch.eye(4).unsqueeze(0).repeat(batch, 1, 1)
155
+ axis = torch.tensor(axis + (0,))
156
+ axis = torch.ger(axis, axis)
157
+ saturate = axis + (eye - axis) * i.view(-1, 1, 1)
158
+
159
+ return saturate
160
+
161
+
162
+ def lognormal_sample(size, mean=0, std=1, device="cpu"):
163
+ return torch.empty(size, device=device).log_normal_(mean=mean, std=std)
164
+
165
+
166
+ def category_sample(size, categories, device="cpu"):
167
+ category = torch.tensor(categories, device=device)
168
+ sample = torch.randint(high=len(categories), size=(size,), device=device)
169
+
170
+ return category[sample]
171
+
172
+
173
+ def uniform_sample(size, low, high, device="cpu"):
174
+ return torch.empty(size, device=device).uniform_(low, high)
175
+
176
+
177
+ def normal_sample(size, mean=0, std=1, device="cpu"):
178
+ return torch.empty(size, device=device).normal_(mean, std)
179
+
180
+
181
+ def bernoulli_sample(size, p, device="cpu"):
182
+ return torch.empty(size, device=device).bernoulli_(p)
183
+
184
+
185
+ def random_mat_apply(p, transform, prev, eye, device="cpu"):
186
+ size = transform.shape[0]
187
+ select = bernoulli_sample(size, p, device=device).view(size, 1, 1)
188
+ select_transform = select * transform + (1 - select) * eye
189
+
190
+ return select_transform @ prev
191
+
192
+
193
+ def sample_affine(p, size, height, width, device="cpu"):
194
+ G = torch.eye(3, device=device).unsqueeze(0).repeat(size, 1, 1)
195
+ eye = G
196
+
197
+ # flip
198
+ param = category_sample(size, (0, 1))
199
+ Gc = scale_mat(1 - 2.0 * param, torch.ones(size), device=device)
200
+ G = random_mat_apply(p, Gc, G, eye, device=device)
201
+ # print('flip', G, scale_mat(1 - 2.0 * param, torch.ones(size)), sep='\n')
202
+
203
+ # 90 rotate
204
+ param = category_sample(size, (0, 3))
205
+ Gc = rotate_mat(-math.pi / 2 * param, device=device)
206
+ G = random_mat_apply(p, Gc, G, eye, device=device)
207
+ # print('90 rotate', G, rotate_mat(-math.pi / 2 * param), sep='\n')
208
+
209
+ # integer translate
210
+ param = uniform_sample((2, size), -0.125, 0.125)
211
+ param_height = torch.round(param[0] * height)
212
+ param_width = torch.round(param[1] * width)
213
+ Gc = translate_mat(param_width, param_height, device=device)
214
+ G = random_mat_apply(p, Gc, G, eye, device=device)
215
+ # print('integer translate', G, translate_mat(param_width, param_height), sep='\n')
216
+
217
+ # isotropic scale
218
+ param = lognormal_sample(size, std=0.2 * math.log(2))
219
+ Gc = scale_mat(param, param, device=device)
220
+ G = random_mat_apply(p, Gc, G, eye, device=device)
221
+ # print('isotropic scale', G, scale_mat(param, param), sep='\n')
222
+
223
+ p_rot = 1 - math.sqrt(1 - p)
224
+
225
+ # pre-rotate
226
+ param = uniform_sample(size, -math.pi, math.pi)
227
+ Gc = rotate_mat(-param, device=device)
228
+ G = random_mat_apply(p_rot, Gc, G, eye, device=device)
229
+ # print('pre-rotate', G, rotate_mat(-param), sep='\n')
230
+
231
+ # anisotropic scale
232
+ param = lognormal_sample(size, std=0.2 * math.log(2))
233
+ Gc = scale_mat(param, 1 / param, device=device)
234
+ G = random_mat_apply(p, Gc, G, eye, device=device)
235
+ # print('anisotropic scale', G, scale_mat(param, 1 / param), sep='\n')
236
+
237
+ # post-rotate
238
+ param = uniform_sample(size, -math.pi, math.pi)
239
+ Gc = rotate_mat(-param, device=device)
240
+ G = random_mat_apply(p_rot, Gc, G, eye, device=device)
241
+ # print('post-rotate', G, rotate_mat(-param), sep='\n')
242
+
243
+ # fractional translate
244
+ param = normal_sample((2, size), std=0.125)
245
+ Gc = translate_mat(param[1] * width, param[0] * height, device=device)
246
+ G = random_mat_apply(p, Gc, G, eye, device=device)
247
+ # print('fractional translate', G, translate_mat(param, param), sep='\n')
248
+
249
+ return G
250
+
251
+
252
+ def sample_color(p, size):
253
+ C = torch.eye(4).unsqueeze(0).repeat(size, 1, 1)
254
+ eye = C
255
+ axis_val = 1 / math.sqrt(3)
256
+ axis = (axis_val, axis_val, axis_val)
257
+
258
+ # brightness
259
+ param = normal_sample(size, std=0.2)
260
+ Cc = translate3d_mat(param, param, param)
261
+ C = random_mat_apply(p, Cc, C, eye)
262
+
263
+ # contrast
264
+ param = lognormal_sample(size, std=0.5 * math.log(2))
265
+ Cc = scale3d_mat(param, param, param)
266
+ C = random_mat_apply(p, Cc, C, eye)
267
+
268
+ # luma flip
269
+ param = category_sample(size, (0, 1))
270
+ Cc = luma_flip_mat(axis, param)
271
+ C = random_mat_apply(p, Cc, C, eye)
272
+
273
+ # hue rotation
274
+ param = uniform_sample(size, -math.pi, math.pi)
275
+ Cc = rotate3d_mat(axis, param)
276
+ C = random_mat_apply(p, Cc, C, eye)
277
+
278
+ # saturation
279
+ param = lognormal_sample(size, std=1 * math.log(2))
280
+ Cc = saturation_mat(axis, param)
281
+ C = random_mat_apply(p, Cc, C, eye)
282
+
283
+ return C
284
+
285
+
286
+ def make_grid(shape, x0, x1, y0, y1, device):
287
+ n, c, h, w = shape
288
+ grid = torch.empty(n, h, w, 3, device=device)
289
+ grid[:, :, :, 0] = torch.linspace(x0, x1, w, device=device)
290
+ grid[:, :, :, 1] = torch.linspace(y0, y1, h, device=device).unsqueeze(-1)
291
+ grid[:, :, :, 2] = 1
292
+
293
+ return grid
294
+
295
+
296
+ def affine_grid(grid, mat):
297
+ n, h, w, _ = grid.shape
298
+ return (grid.view(n, h * w, 3) @ mat.transpose(1, 2)).view(n, h, w, 2)
299
+
300
+
301
+ def get_padding(G, height, width, kernel_size):
302
+ device = G.device
303
+
304
+ cx = (width - 1) / 2
305
+ cy = (height - 1) / 2
306
+ cp = torch.tensor(
307
+ [(-cx, -cy, 1), (cx, -cy, 1), (cx, cy, 1), (-cx, cy, 1)], device=device
308
+ )
309
+ cp = G @ cp.T
310
+
311
+ pad_k = kernel_size // 4
312
+
313
+ pad = cp[:, :2, :].permute(1, 0, 2).flatten(1)
314
+ pad = torch.cat((-pad, pad)).max(1).values
315
+ pad = pad + torch.tensor([pad_k * 2 - cx, pad_k * 2 - cy] * 2, device=device)
316
+ pad = pad.max(torch.tensor([0, 0] * 2, device=device))
317
+ pad = pad.min(torch.tensor([width - 1, height - 1] * 2, device=device))
318
+
319
+ pad_x1, pad_y1, pad_x2, pad_y2 = pad.ceil().to(torch.int32)
320
+
321
+ return pad_x1, pad_x2, pad_y1, pad_y2
322
+
323
+
324
+ def try_sample_affine_and_pad(img, p, kernel_size, G=None):
325
+ batch, _, height, width = img.shape
326
+
327
+ G_try = G
328
+
329
+ if G is None:
330
+ G_try = torch.inverse(sample_affine(p, batch, height, width))
331
+
332
+ pad_x1, pad_x2, pad_y1, pad_y2 = get_padding(G_try, height, width, kernel_size)
333
+
334
+ img_pad = F.pad(img, (pad_x1, pad_x2, pad_y1, pad_y2), mode="reflect")
335
+
336
+ return img_pad, G_try, (pad_x1, pad_x2, pad_y1, pad_y2)
337
+
338
+
339
+ class GridSampleForward(autograd.Function):
340
+ @staticmethod
341
+ def forward(ctx, input, grid):
342
+ out = F.grid_sample(
343
+ input, grid, mode="bilinear", padding_mode="zeros", align_corners=False
344
+ )
345
+ ctx.save_for_backward(input, grid)
346
+
347
+ return out
348
+
349
+ @staticmethod
350
+ def backward(ctx, grad_output):
351
+ input, grid = ctx.saved_tensors
352
+ grad_input, grad_grid = GridSampleBackward.apply(grad_output, input, grid)
353
+
354
+ return grad_input, grad_grid
355
+
356
+
357
+ class GridSampleBackward(autograd.Function):
358
+ @staticmethod
359
+ def forward(ctx, grad_output, input, grid):
360
+ op = torch._C._jit_get_operation("aten::grid_sampler_2d_backward")
361
+ grad_input, grad_grid = op(grad_output, input, grid, 0, 0, False)
362
+ ctx.save_for_backward(grid)
363
+
364
+ return grad_input, grad_grid
365
+
366
+ @staticmethod
367
+ def backward(ctx, grad_grad_input, grad_grad_grid):
368
+ (grid,) = ctx.saved_tensors
369
+ grad_grad_output = None
370
+
371
+ if ctx.needs_input_grad[0]:
372
+ grad_grad_output = GridSampleForward.apply(grad_grad_input, grid)
373
+
374
+ return grad_grad_output, None, None
375
+
376
+
377
+ grid_sample = GridSampleForward.apply
378
+
379
+
380
+ def scale_mat_single(s_x, s_y):
381
+ return torch.tensor(((s_x, 0, 0), (0, s_y, 0), (0, 0, 1)), dtype=torch.float32)
382
+
383
+
384
+ def translate_mat_single(t_x, t_y):
385
+ return torch.tensor(((1, 0, t_x), (0, 1, t_y), (0, 0, 1)), dtype=torch.float32)
386
+
387
+
388
+ def random_apply_affine(img, p, G=None, antialiasing_kernel=SYM6):
389
+ kernel = antialiasing_kernel
390
+ len_k = len(kernel)
391
+
392
+ kernel = torch.as_tensor(kernel).to(img)
393
+ # kernel = torch.ger(kernel, kernel).to(img)
394
+ kernel_flip = torch.flip(kernel, (0,))
395
+
396
+ img_pad, G, (pad_x1, pad_x2, pad_y1, pad_y2) = try_sample_affine_and_pad(
397
+ img, p, len_k, G
398
+ )
399
+
400
+ G_inv = (
401
+ translate_mat_single((pad_x1 - pad_x2).item() / 2, (pad_y1 - pad_y2).item() / 2)
402
+ @ G
403
+ )
404
+ up_pad = (
405
+ (len_k + 2 - 1) // 2,
406
+ (len_k - 2) // 2,
407
+ (len_k + 2 - 1) // 2,
408
+ (len_k - 2) // 2,
409
+ )
410
+ img_2x = upfirdn2d(img_pad, kernel.unsqueeze(0), up=(2, 1), pad=(*up_pad[:2], 0, 0))
411
+ img_2x = upfirdn2d(img_2x, kernel.unsqueeze(1), up=(1, 2), pad=(0, 0, *up_pad[2:]))
412
+ G_inv = scale_mat_single(2, 2) @ G_inv @ scale_mat_single(1 / 2, 1 / 2)
413
+ G_inv = translate_mat_single(-0.5, -0.5) @ G_inv @ translate_mat_single(0.5, 0.5)
414
+ batch_size, channel, height, width = img.shape
415
+ pad_k = len_k // 4
416
+ shape = (batch_size, channel, (height + pad_k * 2) * 2, (width + pad_k * 2) * 2)
417
+ G_inv = (
418
+ scale_mat_single(2 / img_2x.shape[3], 2 / img_2x.shape[2])
419
+ @ G_inv
420
+ @ scale_mat_single(1 / (2 / shape[3]), 1 / (2 / shape[2]))
421
+ )
422
+ grid = F.affine_grid(G_inv[:, :2, :].to(img_2x), shape, align_corners=False)
423
+ img_affine = grid_sample(img_2x, grid)
424
+ d_p = -pad_k * 2
425
+ down_pad = (
426
+ d_p + (len_k - 2 + 1) // 2,
427
+ d_p + (len_k - 2) // 2,
428
+ d_p + (len_k - 2 + 1) // 2,
429
+ d_p + (len_k - 2) // 2,
430
+ )
431
+ img_down = upfirdn2d(
432
+ img_affine, kernel_flip.unsqueeze(0), down=(2, 1), pad=(*down_pad[:2], 0, 0)
433
+ )
434
+ img_down = upfirdn2d(
435
+ img_down, kernel_flip.unsqueeze(1), down=(1, 2), pad=(0, 0, *down_pad[2:])
436
+ )
437
+
438
+ return img_down, G
439
+
440
+
441
+ def apply_color(img, mat):
442
+ batch = img.shape[0]
443
+ img = img.permute(0, 2, 3, 1)
444
+ mat_mul = mat[:, :3, :3].transpose(1, 2).view(batch, 1, 3, 3)
445
+ mat_add = mat[:, :3, 3].view(batch, 1, 1, 3)
446
+ img = img @ mat_mul + mat_add
447
+ img = img.permute(0, 3, 1, 2)
448
+
449
+ return img
450
+
451
+
452
+ def random_apply_color(img, p, C=None):
453
+ if C is None:
454
+ C = sample_color(p, img.shape[0])
455
+
456
+ img = apply_color(img, C.to(img))
457
+
458
+ return img, C
459
+
460
+
461
+ def augment(img, p, transform_matrix=(None, None)):
462
+ img, G = random_apply_affine(img, p, transform_matrix[0])
463
+ img, C = random_apply_color(img, p, transform_matrix[1])
464
+
465
+ return img, (G, C)
op/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ from .fused_act import FusedLeakyReLU, fused_leaky_relu
2
+ from .upfirdn2d import upfirdn2d
op/conv2d_gradfix.py ADDED
@@ -0,0 +1,227 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import contextlib
2
+ import warnings
3
+
4
+ import torch
5
+ from torch import autograd
6
+ from torch.nn import functional as F
7
+
8
+ enabled = True
9
+ weight_gradients_disabled = False
10
+
11
+
12
+ @contextlib.contextmanager
13
+ def no_weight_gradients():
14
+ global weight_gradients_disabled
15
+
16
+ old = weight_gradients_disabled
17
+ weight_gradients_disabled = True
18
+ yield
19
+ weight_gradients_disabled = old
20
+
21
+
22
+ def conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
23
+ if could_use_op(input):
24
+ return conv2d_gradfix(
25
+ transpose=False,
26
+ weight_shape=weight.shape,
27
+ stride=stride,
28
+ padding=padding,
29
+ output_padding=0,
30
+ dilation=dilation,
31
+ groups=groups,
32
+ ).apply(input, weight, bias)
33
+
34
+ return F.conv2d(
35
+ input=input,
36
+ weight=weight,
37
+ bias=bias,
38
+ stride=stride,
39
+ padding=padding,
40
+ dilation=dilation,
41
+ groups=groups,
42
+ )
43
+
44
+
45
+ def conv_transpose2d(
46
+ input,
47
+ weight,
48
+ bias=None,
49
+ stride=1,
50
+ padding=0,
51
+ output_padding=0,
52
+ groups=1,
53
+ dilation=1,
54
+ ):
55
+ if could_use_op(input):
56
+ return conv2d_gradfix(
57
+ transpose=True,
58
+ weight_shape=weight.shape,
59
+ stride=stride,
60
+ padding=padding,
61
+ output_padding=output_padding,
62
+ groups=groups,
63
+ dilation=dilation,
64
+ ).apply(input, weight, bias)
65
+
66
+ return F.conv_transpose2d(
67
+ input=input,
68
+ weight=weight,
69
+ bias=bias,
70
+ stride=stride,
71
+ padding=padding,
72
+ output_padding=output_padding,
73
+ dilation=dilation,
74
+ groups=groups,
75
+ )
76
+
77
+
78
+ def could_use_op(input):
79
+ if (not enabled) or (not torch.backends.cudnn.enabled):
80
+ return False
81
+
82
+ if input.device.type != "cuda":
83
+ return False
84
+
85
+ if any(torch.__version__.startswith(x) for x in ["1.7.", "1.8."]):
86
+ return True
87
+
88
+ warnings.warn(
89
+ f"conv2d_gradfix not supported on PyTorch {torch.__version__}. Falling back to torch.nn.functional.conv2d()."
90
+ )
91
+
92
+ return False
93
+
94
+
95
+ def ensure_tuple(xs, ndim):
96
+ xs = tuple(xs) if isinstance(xs, (tuple, list)) else (xs,) * ndim
97
+
98
+ return xs
99
+
100
+
101
+ conv2d_gradfix_cache = dict()
102
+
103
+
104
+ def conv2d_gradfix(
105
+ transpose, weight_shape, stride, padding, output_padding, dilation, groups
106
+ ):
107
+ ndim = 2
108
+ weight_shape = tuple(weight_shape)
109
+ stride = ensure_tuple(stride, ndim)
110
+ padding = ensure_tuple(padding, ndim)
111
+ output_padding = ensure_tuple(output_padding, ndim)
112
+ dilation = ensure_tuple(dilation, ndim)
113
+
114
+ key = (transpose, weight_shape, stride, padding, output_padding, dilation, groups)
115
+ if key in conv2d_gradfix_cache:
116
+ return conv2d_gradfix_cache[key]
117
+
118
+ common_kwargs = dict(
119
+ stride=stride, padding=padding, dilation=dilation, groups=groups
120
+ )
121
+
122
+ def calc_output_padding(input_shape, output_shape):
123
+ if transpose:
124
+ return [0, 0]
125
+
126
+ return [
127
+ input_shape[i + 2]
128
+ - (output_shape[i + 2] - 1) * stride[i]
129
+ - (1 - 2 * padding[i])
130
+ - dilation[i] * (weight_shape[i + 2] - 1)
131
+ for i in range(ndim)
132
+ ]
133
+
134
+ class Conv2d(autograd.Function):
135
+ @staticmethod
136
+ def forward(ctx, input, weight, bias):
137
+ if not transpose:
138
+ out = F.conv2d(input=input, weight=weight, bias=bias, **common_kwargs)
139
+
140
+ else:
141
+ out = F.conv_transpose2d(
142
+ input=input,
143
+ weight=weight,
144
+ bias=bias,
145
+ output_padding=output_padding,
146
+ **common_kwargs,
147
+ )
148
+
149
+ ctx.save_for_backward(input, weight)
150
+
151
+ return out
152
+
153
+ @staticmethod
154
+ def backward(ctx, grad_output):
155
+ input, weight = ctx.saved_tensors
156
+ grad_input, grad_weight, grad_bias = None, None, None
157
+
158
+ if ctx.needs_input_grad[0]:
159
+ p = calc_output_padding(
160
+ input_shape=input.shape, output_shape=grad_output.shape
161
+ )
162
+ grad_input = conv2d_gradfix(
163
+ transpose=(not transpose),
164
+ weight_shape=weight_shape,
165
+ output_padding=p,
166
+ **common_kwargs,
167
+ ).apply(grad_output, weight, None)
168
+
169
+ if ctx.needs_input_grad[1] and not weight_gradients_disabled:
170
+ grad_weight = Conv2dGradWeight.apply(grad_output, input)
171
+
172
+ if ctx.needs_input_grad[2]:
173
+ grad_bias = grad_output.sum((0, 2, 3))
174
+
175
+ return grad_input, grad_weight, grad_bias
176
+
177
+ class Conv2dGradWeight(autograd.Function):
178
+ @staticmethod
179
+ def forward(ctx, grad_output, input):
180
+ op = torch._C._jit_get_operation(
181
+ "aten::cudnn_convolution_backward_weight"
182
+ if not transpose
183
+ else "aten::cudnn_convolution_transpose_backward_weight"
184
+ )
185
+ flags = [
186
+ torch.backends.cudnn.benchmark,
187
+ torch.backends.cudnn.deterministic,
188
+ torch.backends.cudnn.allow_tf32,
189
+ ]
190
+ grad_weight = op(
191
+ weight_shape,
192
+ grad_output,
193
+ input,
194
+ padding,
195
+ stride,
196
+ dilation,
197
+ groups,
198
+ *flags,
199
+ )
200
+ ctx.save_for_backward(grad_output, input)
201
+
202
+ return grad_weight
203
+
204
+ @staticmethod
205
+ def backward(ctx, grad_grad_weight):
206
+ grad_output, input = ctx.saved_tensors
207
+ grad_grad_output, grad_grad_input = None, None
208
+
209
+ if ctx.needs_input_grad[0]:
210
+ grad_grad_output = Conv2d.apply(input, grad_grad_weight, None)
211
+
212
+ if ctx.needs_input_grad[1]:
213
+ p = calc_output_padding(
214
+ input_shape=input.shape, output_shape=grad_output.shape
215
+ )
216
+ grad_grad_input = conv2d_gradfix(
217
+ transpose=(not transpose),
218
+ weight_shape=weight_shape,
219
+ output_padding=p,
220
+ **common_kwargs,
221
+ ).apply(grad_output, grad_grad_weight, None)
222
+
223
+ return grad_grad_output, grad_grad_input
224
+
225
+ conv2d_gradfix_cache[key] = Conv2d
226
+
227
+ return Conv2d
op/fused_act.py ADDED
@@ -0,0 +1,127 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ import torch
4
+ from torch import nn
5
+ from torch.nn import functional as F
6
+ from torch.autograd import Function
7
+ from torch.utils.cpp_extension import load
8
+
9
+
10
+ module_path = os.path.dirname(__file__)
11
+ fused = load(
12
+ "fused",
13
+ sources=[
14
+ os.path.join(module_path, "fused_bias_act.cpp"),
15
+ os.path.join(module_path, "fused_bias_act_kernel.cu"),
16
+ ],
17
+ )
18
+
19
+
20
+ class FusedLeakyReLUFunctionBackward(Function):
21
+ @staticmethod
22
+ def forward(ctx, grad_output, out, bias, negative_slope, scale):
23
+ ctx.save_for_backward(out)
24
+ ctx.negative_slope = negative_slope
25
+ ctx.scale = scale
26
+
27
+ empty = grad_output.new_empty(0)
28
+
29
+ grad_input = fused.fused_bias_act(
30
+ grad_output.contiguous(), empty, out, 3, 1, negative_slope, scale
31
+ )
32
+
33
+ dim = [0]
34
+
35
+ if grad_input.ndim > 2:
36
+ dim += list(range(2, grad_input.ndim))
37
+
38
+ if bias:
39
+ grad_bias = grad_input.sum(dim).detach()
40
+
41
+ else:
42
+ grad_bias = empty
43
+
44
+ return grad_input, grad_bias
45
+
46
+ @staticmethod
47
+ def backward(ctx, gradgrad_input, gradgrad_bias):
48
+ out, = ctx.saved_tensors
49
+ gradgrad_out = fused.fused_bias_act(
50
+ gradgrad_input.contiguous(),
51
+ gradgrad_bias,
52
+ out,
53
+ 3,
54
+ 1,
55
+ ctx.negative_slope,
56
+ ctx.scale,
57
+ )
58
+
59
+ return gradgrad_out, None, None, None, None
60
+
61
+
62
+ class FusedLeakyReLUFunction(Function):
63
+ @staticmethod
64
+ def forward(ctx, input, bias, negative_slope, scale):
65
+ empty = input.new_empty(0)
66
+
67
+ ctx.bias = bias is not None
68
+
69
+ if bias is None:
70
+ bias = empty
71
+
72
+ out = fused.fused_bias_act(input, bias, empty, 3, 0, negative_slope, scale)
73
+ ctx.save_for_backward(out)
74
+ ctx.negative_slope = negative_slope
75
+ ctx.scale = scale
76
+
77
+ return out
78
+
79
+ @staticmethod
80
+ def backward(ctx, grad_output):
81
+ out, = ctx.saved_tensors
82
+
83
+ grad_input, grad_bias = FusedLeakyReLUFunctionBackward.apply(
84
+ grad_output, out, ctx.bias, ctx.negative_slope, ctx.scale
85
+ )
86
+
87
+ if not ctx.bias:
88
+ grad_bias = None
89
+
90
+ return grad_input, grad_bias, None, None
91
+
92
+
93
+ class FusedLeakyReLU(nn.Module):
94
+ def __init__(self, channel, bias=True, negative_slope=0.2, scale=2 ** 0.5):
95
+ super().__init__()
96
+
97
+ if bias:
98
+ self.bias = nn.Parameter(torch.zeros(channel))
99
+
100
+ else:
101
+ self.bias = None
102
+
103
+ self.negative_slope = negative_slope
104
+ self.scale = scale
105
+
106
+ def forward(self, input):
107
+ return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale)
108
+
109
+
110
+ def fused_leaky_relu(input, bias=None, negative_slope=0.2, scale=2 ** 0.5):
111
+ if input.device.type == "cpu":
112
+ if bias is not None:
113
+ rest_dim = [1] * (input.ndim - bias.ndim - 1)
114
+ return (
115
+ F.leaky_relu(
116
+ input + bias.view(1, bias.shape[0], *rest_dim), negative_slope=0.2
117
+ )
118
+ * scale
119
+ )
120
+
121
+ else:
122
+ return F.leaky_relu(input, negative_slope=0.2) * scale
123
+
124
+ else:
125
+ return FusedLeakyReLUFunction.apply(
126
+ input.contiguous(), bias, negative_slope, scale
127
+ )
op/fused_bias_act.cpp ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ #include <ATen/ATen.h>
3
+ #include <torch/extension.h>
4
+
5
+ torch::Tensor fused_bias_act_op(const torch::Tensor &input,
6
+ const torch::Tensor &bias,
7
+ const torch::Tensor &refer, int act, int grad,
8
+ float alpha, float scale);
9
+
10
+ #define CHECK_CUDA(x) \
11
+ TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
12
+ #define CHECK_CONTIGUOUS(x) \
13
+ TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
14
+ #define CHECK_INPUT(x) \
15
+ CHECK_CUDA(x); \
16
+ CHECK_CONTIGUOUS(x)
17
+
18
+ torch::Tensor fused_bias_act(const torch::Tensor &input,
19
+ const torch::Tensor &bias,
20
+ const torch::Tensor &refer, int act, int grad,
21
+ float alpha, float scale) {
22
+ CHECK_INPUT(input);
23
+ CHECK_INPUT(bias);
24
+
25
+ at::DeviceGuard guard(input.device());
26
+
27
+ return fused_bias_act_op(input, bias, refer, act, grad, alpha, scale);
28
+ }
29
+
30
+ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
31
+ m.def("fused_bias_act", &fused_bias_act, "fused bias act (CUDA)");
32
+ }
op/fused_bias_act_kernel.cu ADDED
@@ -0,0 +1,105 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
2
+ //
3
+ // This work is made available under the Nvidia Source Code License-NC.
4
+ // To view a copy of this license, visit
5
+ // https://nvlabs.github.io/stylegan2/license.html
6
+
7
+ #include <torch/types.h>
8
+
9
+ #include <ATen/ATen.h>
10
+ #include <ATen/AccumulateType.h>
11
+ #include <ATen/cuda/CUDAApplyUtils.cuh>
12
+ #include <ATen/cuda/CUDAContext.h>
13
+
14
+
15
+ #include <cuda.h>
16
+ #include <cuda_runtime.h>
17
+
18
+ template <typename scalar_t>
19
+ static __global__ void
20
+ fused_bias_act_kernel(scalar_t *out, const scalar_t *p_x, const scalar_t *p_b,
21
+ const scalar_t *p_ref, int act, int grad, scalar_t alpha,
22
+ scalar_t scale, int loop_x, int size_x, int step_b,
23
+ int size_b, int use_bias, int use_ref) {
24
+ int xi = blockIdx.x * loop_x * blockDim.x + threadIdx.x;
25
+
26
+ scalar_t zero = 0.0;
27
+
28
+ for (int loop_idx = 0; loop_idx < loop_x && xi < size_x;
29
+ loop_idx++, xi += blockDim.x) {
30
+ scalar_t x = p_x[xi];
31
+
32
+ if (use_bias) {
33
+ x += p_b[(xi / step_b) % size_b];
34
+ }
35
+
36
+ scalar_t ref = use_ref ? p_ref[xi] : zero;
37
+
38
+ scalar_t y;
39
+
40
+ switch (act * 10 + grad) {
41
+ default:
42
+ case 10:
43
+ y = x;
44
+ break;
45
+ case 11:
46
+ y = x;
47
+ break;
48
+ case 12:
49
+ y = 0.0;
50
+ break;
51
+
52
+ case 30:
53
+ y = (x > 0.0) ? x : x * alpha;
54
+ break;
55
+ case 31:
56
+ y = (ref > 0.0) ? x : x * alpha;
57
+ break;
58
+ case 32:
59
+ y = 0.0;
60
+ break;
61
+ }
62
+
63
+ out[xi] = y * scale;
64
+ }
65
+ }
66
+
67
+ torch::Tensor fused_bias_act_op(const torch::Tensor &input,
68
+ const torch::Tensor &bias,
69
+ const torch::Tensor &refer, int act, int grad,
70
+ float alpha, float scale) {
71
+ int curDevice = -1;
72
+ cudaGetDevice(&curDevice);
73
+ cudaStream_t stream = at::cuda::getCurrentCUDAStream();
74
+
75
+ auto x = input.contiguous();
76
+ auto b = bias.contiguous();
77
+ auto ref = refer.contiguous();
78
+
79
+ int use_bias = b.numel() ? 1 : 0;
80
+ int use_ref = ref.numel() ? 1 : 0;
81
+
82
+ int size_x = x.numel();
83
+ int size_b = b.numel();
84
+ int step_b = 1;
85
+
86
+ for (int i = 1 + 1; i < x.dim(); i++) {
87
+ step_b *= x.size(i);
88
+ }
89
+
90
+ int loop_x = 4;
91
+ int block_size = 4 * 32;
92
+ int grid_size = (size_x - 1) / (loop_x * block_size) + 1;
93
+
94
+ auto y = torch::empty_like(x);
95
+
96
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(
97
+ x.scalar_type(), "fused_bias_act_kernel", [&] {
98
+ fused_bias_act_kernel<scalar_t><<<grid_size, block_size, 0, stream>>>(
99
+ y.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(),
100
+ b.data_ptr<scalar_t>(), ref.data_ptr<scalar_t>(), act, grad, alpha,
101
+ scale, loop_x, size_x, step_b, size_b, use_bias, use_ref);
102
+ });
103
+
104
+ return y;
105
+ }
op/upfirdn2d.cpp ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #include <ATen/ATen.h>
2
+ #include <torch/extension.h>
3
+
4
+ torch::Tensor upfirdn2d_op(const torch::Tensor &input,
5
+ const torch::Tensor &kernel, int up_x, int up_y,
6
+ int down_x, int down_y, int pad_x0, int pad_x1,
7
+ int pad_y0, int pad_y1);
8
+
9
+ #define CHECK_CUDA(x) \
10
+ TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
11
+ #define CHECK_CONTIGUOUS(x) \
12
+ TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
13
+ #define CHECK_INPUT(x) \
14
+ CHECK_CUDA(x); \
15
+ CHECK_CONTIGUOUS(x)
16
+
17
+ torch::Tensor upfirdn2d(const torch::Tensor &input, const torch::Tensor &kernel,
18
+ int up_x, int up_y, int down_x, int down_y, int pad_x0,
19
+ int pad_x1, int pad_y0, int pad_y1) {
20
+ CHECK_INPUT(input);
21
+ CHECK_INPUT(kernel);
22
+
23
+ at::DeviceGuard guard(input.device());
24
+
25
+ return upfirdn2d_op(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1,
26
+ pad_y0, pad_y1);
27
+ }
28
+
29
+ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
30
+ m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)");
31
+ }
op/upfirdn2d.py ADDED
@@ -0,0 +1,209 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from collections import abc
2
+ import os
3
+
4
+ import torch
5
+ from torch.nn import functional as F
6
+ from torch.autograd import Function
7
+ from torch.utils.cpp_extension import load
8
+
9
+
10
+ module_path = os.path.dirname(__file__)
11
+ upfirdn2d_op = load(
12
+ "upfirdn2d",
13
+ sources=[
14
+ os.path.join(module_path, "upfirdn2d.cpp"),
15
+ os.path.join(module_path, "upfirdn2d_kernel.cu"),
16
+ ],
17
+ )
18
+
19
+
20
+ class UpFirDn2dBackward(Function):
21
+ @staticmethod
22
+ def forward(
23
+ ctx, grad_output, kernel, grad_kernel, up, down, pad, g_pad, in_size, out_size
24
+ ):
25
+
26
+ up_x, up_y = up
27
+ down_x, down_y = down
28
+ g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1 = g_pad
29
+
30
+ grad_output = grad_output.reshape(-1, out_size[0], out_size[1], 1)
31
+
32
+ grad_input = upfirdn2d_op.upfirdn2d(
33
+ grad_output,
34
+ grad_kernel,
35
+ down_x,
36
+ down_y,
37
+ up_x,
38
+ up_y,
39
+ g_pad_x0,
40
+ g_pad_x1,
41
+ g_pad_y0,
42
+ g_pad_y1,
43
+ )
44
+ grad_input = grad_input.view(in_size[0], in_size[1], in_size[2], in_size[3])
45
+
46
+ ctx.save_for_backward(kernel)
47
+
48
+ pad_x0, pad_x1, pad_y0, pad_y1 = pad
49
+
50
+ ctx.up_x = up_x
51
+ ctx.up_y = up_y
52
+ ctx.down_x = down_x
53
+ ctx.down_y = down_y
54
+ ctx.pad_x0 = pad_x0
55
+ ctx.pad_x1 = pad_x1
56
+ ctx.pad_y0 = pad_y0
57
+ ctx.pad_y1 = pad_y1
58
+ ctx.in_size = in_size
59
+ ctx.out_size = out_size
60
+
61
+ return grad_input
62
+
63
+ @staticmethod
64
+ def backward(ctx, gradgrad_input):
65
+ kernel, = ctx.saved_tensors
66
+
67
+ gradgrad_input = gradgrad_input.reshape(-1, ctx.in_size[2], ctx.in_size[3], 1)
68
+
69
+ gradgrad_out = upfirdn2d_op.upfirdn2d(
70
+ gradgrad_input,
71
+ kernel,
72
+ ctx.up_x,
73
+ ctx.up_y,
74
+ ctx.down_x,
75
+ ctx.down_y,
76
+ ctx.pad_x0,
77
+ ctx.pad_x1,
78
+ ctx.pad_y0,
79
+ ctx.pad_y1,
80
+ )
81
+ # gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.out_size[0], ctx.out_size[1], ctx.in_size[3])
82
+ gradgrad_out = gradgrad_out.view(
83
+ ctx.in_size[0], ctx.in_size[1], ctx.out_size[0], ctx.out_size[1]
84
+ )
85
+
86
+ return gradgrad_out, None, None, None, None, None, None, None, None
87
+
88
+
89
+ class UpFirDn2d(Function):
90
+ @staticmethod
91
+ def forward(ctx, input, kernel, up, down, pad):
92
+ up_x, up_y = up
93
+ down_x, down_y = down
94
+ pad_x0, pad_x1, pad_y0, pad_y1 = pad
95
+
96
+ kernel_h, kernel_w = kernel.shape
97
+ batch, channel, in_h, in_w = input.shape
98
+ ctx.in_size = input.shape
99
+
100
+ input = input.reshape(-1, in_h, in_w, 1)
101
+
102
+ ctx.save_for_backward(kernel, torch.flip(kernel, [0, 1]))
103
+
104
+ out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h + down_y) // down_y
105
+ out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w + down_x) // down_x
106
+ ctx.out_size = (out_h, out_w)
107
+
108
+ ctx.up = (up_x, up_y)
109
+ ctx.down = (down_x, down_y)
110
+ ctx.pad = (pad_x0, pad_x1, pad_y0, pad_y1)
111
+
112
+ g_pad_x0 = kernel_w - pad_x0 - 1
113
+ g_pad_y0 = kernel_h - pad_y0 - 1
114
+ g_pad_x1 = in_w * up_x - out_w * down_x + pad_x0 - up_x + 1
115
+ g_pad_y1 = in_h * up_y - out_h * down_y + pad_y0 - up_y + 1
116
+
117
+ ctx.g_pad = (g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1)
118
+
119
+ out = upfirdn2d_op.upfirdn2d(
120
+ input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1
121
+ )
122
+ # out = out.view(major, out_h, out_w, minor)
123
+ out = out.view(-1, channel, out_h, out_w)
124
+
125
+ return out
126
+
127
+ @staticmethod
128
+ def backward(ctx, grad_output):
129
+ kernel, grad_kernel = ctx.saved_tensors
130
+
131
+ grad_input = None
132
+
133
+ if ctx.needs_input_grad[0]:
134
+ grad_input = UpFirDn2dBackward.apply(
135
+ grad_output,
136
+ kernel,
137
+ grad_kernel,
138
+ ctx.up,
139
+ ctx.down,
140
+ ctx.pad,
141
+ ctx.g_pad,
142
+ ctx.in_size,
143
+ ctx.out_size,
144
+ )
145
+
146
+ return grad_input, None, None, None, None
147
+
148
+
149
+ def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
150
+ if not isinstance(up, abc.Iterable):
151
+ up = (up, up)
152
+
153
+ if not isinstance(down, abc.Iterable):
154
+ down = (down, down)
155
+
156
+ if len(pad) == 2:
157
+ pad = (pad[0], pad[1], pad[0], pad[1])
158
+
159
+ if input.device.type == "cpu":
160
+ out = upfirdn2d_native(input, kernel, *up, *down, *pad)
161
+
162
+ else:
163
+ out = UpFirDn2d.apply(input, kernel, up, down, pad)
164
+
165
+ return out
166
+
167
+
168
+ def upfirdn2d_native(
169
+ input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1
170
+ ):
171
+ _, channel, in_h, in_w = input.shape
172
+ input = input.reshape(-1, in_h, in_w, 1)
173
+
174
+ _, in_h, in_w, minor = input.shape
175
+ kernel_h, kernel_w = kernel.shape
176
+
177
+ out = input.view(-1, in_h, 1, in_w, 1, minor)
178
+ out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
179
+ out = out.view(-1, in_h * up_y, in_w * up_x, minor)
180
+
181
+ out = F.pad(
182
+ out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)]
183
+ )
184
+ out = out[
185
+ :,
186
+ max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0),
187
+ max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0),
188
+ :,
189
+ ]
190
+
191
+ out = out.permute(0, 3, 1, 2)
192
+ out = out.reshape(
193
+ [-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1]
194
+ )
195
+ w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
196
+ out = F.conv2d(out, w)
197
+ out = out.reshape(
198
+ -1,
199
+ minor,
200
+ in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
201
+ in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1,
202
+ )
203
+ out = out.permute(0, 2, 3, 1)
204
+ out = out[:, ::down_y, ::down_x, :]
205
+
206
+ out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h + down_y) // down_y
207
+ out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w + down_x) // down_x
208
+
209
+ return out.view(-1, channel, out_h, out_w)
op/upfirdn2d_kernel.cu ADDED
@@ -0,0 +1,369 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
2
+ //
3
+ // This work is made available under the Nvidia Source Code License-NC.
4
+ // To view a copy of this license, visit
5
+ // https://nvlabs.github.io/stylegan2/license.html
6
+
7
+ #include <torch/types.h>
8
+
9
+ #include <ATen/ATen.h>
10
+ #include <ATen/AccumulateType.h>
11
+ #include <ATen/cuda/CUDAApplyUtils.cuh>
12
+ #include <ATen/cuda/CUDAContext.h>
13
+
14
+ #include <cuda.h>
15
+ #include <cuda_runtime.h>
16
+
17
+ static __host__ __device__ __forceinline__ int floor_div(int a, int b) {
18
+ int c = a / b;
19
+
20
+ if (c * b > a) {
21
+ c--;
22
+ }
23
+
24
+ return c;
25
+ }
26
+
27
+ struct UpFirDn2DKernelParams {
28
+ int up_x;
29
+ int up_y;
30
+ int down_x;
31
+ int down_y;
32
+ int pad_x0;
33
+ int pad_x1;
34
+ int pad_y0;
35
+ int pad_y1;
36
+
37
+ int major_dim;
38
+ int in_h;
39
+ int in_w;
40
+ int minor_dim;
41
+ int kernel_h;
42
+ int kernel_w;
43
+ int out_h;
44
+ int out_w;
45
+ int loop_major;
46
+ int loop_x;
47
+ };
48
+
49
+ template <typename scalar_t>
50
+ __global__ void upfirdn2d_kernel_large(scalar_t *out, const scalar_t *input,
51
+ const scalar_t *kernel,
52
+ const UpFirDn2DKernelParams p) {
53
+ int minor_idx = blockIdx.x * blockDim.x + threadIdx.x;
54
+ int out_y = minor_idx / p.minor_dim;
55
+ minor_idx -= out_y * p.minor_dim;
56
+ int out_x_base = blockIdx.y * p.loop_x * blockDim.y + threadIdx.y;
57
+ int major_idx_base = blockIdx.z * p.loop_major;
58
+
59
+ if (out_x_base >= p.out_w || out_y >= p.out_h ||
60
+ major_idx_base >= p.major_dim) {
61
+ return;
62
+ }
63
+
64
+ int mid_y = out_y * p.down_y + p.up_y - 1 - p.pad_y0;
65
+ int in_y = min(max(floor_div(mid_y, p.up_y), 0), p.in_h);
66
+ int h = min(max(floor_div(mid_y + p.kernel_h, p.up_y), 0), p.in_h) - in_y;
67
+ int kernel_y = mid_y + p.kernel_h - (in_y + 1) * p.up_y;
68
+
69
+ for (int loop_major = 0, major_idx = major_idx_base;
70
+ loop_major < p.loop_major && major_idx < p.major_dim;
71
+ loop_major++, major_idx++) {
72
+ for (int loop_x = 0, out_x = out_x_base;
73
+ loop_x < p.loop_x && out_x < p.out_w; loop_x++, out_x += blockDim.y) {
74
+ int mid_x = out_x * p.down_x + p.up_x - 1 - p.pad_x0;
75
+ int in_x = min(max(floor_div(mid_x, p.up_x), 0), p.in_w);
76
+ int w = min(max(floor_div(mid_x + p.kernel_w, p.up_x), 0), p.in_w) - in_x;
77
+ int kernel_x = mid_x + p.kernel_w - (in_x + 1) * p.up_x;
78
+
79
+ const scalar_t *x_p =
80
+ &input[((major_idx * p.in_h + in_y) * p.in_w + in_x) * p.minor_dim +
81
+ minor_idx];
82
+ const scalar_t *k_p = &kernel[kernel_y * p.kernel_w + kernel_x];
83
+ int x_px = p.minor_dim;
84
+ int k_px = -p.up_x;
85
+ int x_py = p.in_w * p.minor_dim;
86
+ int k_py = -p.up_y * p.kernel_w;
87
+
88
+ scalar_t v = 0.0f;
89
+
90
+ for (int y = 0; y < h; y++) {
91
+ for (int x = 0; x < w; x++) {
92
+ v += static_cast<scalar_t>(*x_p) * static_cast<scalar_t>(*k_p);
93
+ x_p += x_px;
94
+ k_p += k_px;
95
+ }
96
+
97
+ x_p += x_py - w * x_px;
98
+ k_p += k_py - w * k_px;
99
+ }
100
+
101
+ out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim +
102
+ minor_idx] = v;
103
+ }
104
+ }
105
+ }
106
+
107
+ template <typename scalar_t, int up_x, int up_y, int down_x, int down_y,
108
+ int kernel_h, int kernel_w, int tile_out_h, int tile_out_w>
109
+ __global__ void upfirdn2d_kernel(scalar_t *out, const scalar_t *input,
110
+ const scalar_t *kernel,
111
+ const UpFirDn2DKernelParams p) {
112
+ const int tile_in_h = ((tile_out_h - 1) * down_y + kernel_h - 1) / up_y + 1;
113
+ const int tile_in_w = ((tile_out_w - 1) * down_x + kernel_w - 1) / up_x + 1;
114
+
115
+ __shared__ volatile float sk[kernel_h][kernel_w];
116
+ __shared__ volatile float sx[tile_in_h][tile_in_w];
117
+
118
+ int minor_idx = blockIdx.x;
119
+ int tile_out_y = minor_idx / p.minor_dim;
120
+ minor_idx -= tile_out_y * p.minor_dim;
121
+ tile_out_y *= tile_out_h;
122
+ int tile_out_x_base = blockIdx.y * p.loop_x * tile_out_w;
123
+ int major_idx_base = blockIdx.z * p.loop_major;
124
+
125
+ if (tile_out_x_base >= p.out_w | tile_out_y >= p.out_h |
126
+ major_idx_base >= p.major_dim) {
127
+ return;
128
+ }
129
+
130
+ for (int tap_idx = threadIdx.x; tap_idx < kernel_h * kernel_w;
131
+ tap_idx += blockDim.x) {
132
+ int ky = tap_idx / kernel_w;
133
+ int kx = tap_idx - ky * kernel_w;
134
+ scalar_t v = 0.0;
135
+
136
+ if (kx < p.kernel_w & ky < p.kernel_h) {
137
+ v = kernel[(p.kernel_h - 1 - ky) * p.kernel_w + (p.kernel_w - 1 - kx)];
138
+ }
139
+
140
+ sk[ky][kx] = v;
141
+ }
142
+
143
+ for (int loop_major = 0, major_idx = major_idx_base;
144
+ loop_major < p.loop_major & major_idx < p.major_dim;
145
+ loop_major++, major_idx++) {
146
+ for (int loop_x = 0, tile_out_x = tile_out_x_base;
147
+ loop_x < p.loop_x & tile_out_x < p.out_w;
148
+ loop_x++, tile_out_x += tile_out_w) {
149
+ int tile_mid_x = tile_out_x * down_x + up_x - 1 - p.pad_x0;
150
+ int tile_mid_y = tile_out_y * down_y + up_y - 1 - p.pad_y0;
151
+ int tile_in_x = floor_div(tile_mid_x, up_x);
152
+ int tile_in_y = floor_div(tile_mid_y, up_y);
153
+
154
+ __syncthreads();
155
+
156
+ for (int in_idx = threadIdx.x; in_idx < tile_in_h * tile_in_w;
157
+ in_idx += blockDim.x) {
158
+ int rel_in_y = in_idx / tile_in_w;
159
+ int rel_in_x = in_idx - rel_in_y * tile_in_w;
160
+ int in_x = rel_in_x + tile_in_x;
161
+ int in_y = rel_in_y + tile_in_y;
162
+
163
+ scalar_t v = 0.0;
164
+
165
+ if (in_x >= 0 & in_y >= 0 & in_x < p.in_w & in_y < p.in_h) {
166
+ v = input[((major_idx * p.in_h + in_y) * p.in_w + in_x) *
167
+ p.minor_dim +
168
+ minor_idx];
169
+ }
170
+
171
+ sx[rel_in_y][rel_in_x] = v;
172
+ }
173
+
174
+ __syncthreads();
175
+ for (int out_idx = threadIdx.x; out_idx < tile_out_h * tile_out_w;
176
+ out_idx += blockDim.x) {
177
+ int rel_out_y = out_idx / tile_out_w;
178
+ int rel_out_x = out_idx - rel_out_y * tile_out_w;
179
+ int out_x = rel_out_x + tile_out_x;
180
+ int out_y = rel_out_y + tile_out_y;
181
+
182
+ int mid_x = tile_mid_x + rel_out_x * down_x;
183
+ int mid_y = tile_mid_y + rel_out_y * down_y;
184
+ int in_x = floor_div(mid_x, up_x);
185
+ int in_y = floor_div(mid_y, up_y);
186
+ int rel_in_x = in_x - tile_in_x;
187
+ int rel_in_y = in_y - tile_in_y;
188
+ int kernel_x = (in_x + 1) * up_x - mid_x - 1;
189
+ int kernel_y = (in_y + 1) * up_y - mid_y - 1;
190
+
191
+ scalar_t v = 0.0;
192
+
193
+ #pragma unroll
194
+ for (int y = 0; y < kernel_h / up_y; y++)
195
+ #pragma unroll
196
+ for (int x = 0; x < kernel_w / up_x; x++)
197
+ v += sx[rel_in_y + y][rel_in_x + x] *
198
+ sk[kernel_y + y * up_y][kernel_x + x * up_x];
199
+
200
+ if (out_x < p.out_w & out_y < p.out_h) {
201
+ out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim +
202
+ minor_idx] = v;
203
+ }
204
+ }
205
+ }
206
+ }
207
+ }
208
+
209
+ torch::Tensor upfirdn2d_op(const torch::Tensor &input,
210
+ const torch::Tensor &kernel, int up_x, int up_y,
211
+ int down_x, int down_y, int pad_x0, int pad_x1,
212
+ int pad_y0, int pad_y1) {
213
+ int curDevice = -1;
214
+ cudaGetDevice(&curDevice);
215
+ cudaStream_t stream = at::cuda::getCurrentCUDAStream();
216
+
217
+ UpFirDn2DKernelParams p;
218
+
219
+ auto x = input.contiguous();
220
+ auto k = kernel.contiguous();
221
+
222
+ p.major_dim = x.size(0);
223
+ p.in_h = x.size(1);
224
+ p.in_w = x.size(2);
225
+ p.minor_dim = x.size(3);
226
+ p.kernel_h = k.size(0);
227
+ p.kernel_w = k.size(1);
228
+ p.up_x = up_x;
229
+ p.up_y = up_y;
230
+ p.down_x = down_x;
231
+ p.down_y = down_y;
232
+ p.pad_x0 = pad_x0;
233
+ p.pad_x1 = pad_x1;
234
+ p.pad_y0 = pad_y0;
235
+ p.pad_y1 = pad_y1;
236
+
237
+ p.out_h = (p.in_h * p.up_y + p.pad_y0 + p.pad_y1 - p.kernel_h + p.down_y) /
238
+ p.down_y;
239
+ p.out_w = (p.in_w * p.up_x + p.pad_x0 + p.pad_x1 - p.kernel_w + p.down_x) /
240
+ p.down_x;
241
+
242
+ auto out =
243
+ at::empty({p.major_dim, p.out_h, p.out_w, p.minor_dim}, x.options());
244
+
245
+ int mode = -1;
246
+
247
+ int tile_out_h = -1;
248
+ int tile_out_w = -1;
249
+
250
+ if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 &&
251
+ p.kernel_h <= 4 && p.kernel_w <= 4) {
252
+ mode = 1;
253
+ tile_out_h = 16;
254
+ tile_out_w = 64;
255
+ }
256
+
257
+ if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 &&
258
+ p.kernel_h <= 3 && p.kernel_w <= 3) {
259
+ mode = 2;
260
+ tile_out_h = 16;
261
+ tile_out_w = 64;
262
+ }
263
+
264
+ if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 &&
265
+ p.kernel_h <= 4 && p.kernel_w <= 4) {
266
+ mode = 3;
267
+ tile_out_h = 16;
268
+ tile_out_w = 64;
269
+ }
270
+
271
+ if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 &&
272
+ p.kernel_h <= 2 && p.kernel_w <= 2) {
273
+ mode = 4;
274
+ tile_out_h = 16;
275
+ tile_out_w = 64;
276
+ }
277
+
278
+ if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 &&
279
+ p.kernel_h <= 4 && p.kernel_w <= 4) {
280
+ mode = 5;
281
+ tile_out_h = 8;
282
+ tile_out_w = 32;
283
+ }
284
+
285
+ if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 &&
286
+ p.kernel_h <= 2 && p.kernel_w <= 2) {
287
+ mode = 6;
288
+ tile_out_h = 8;
289
+ tile_out_w = 32;
290
+ }
291
+
292
+ dim3 block_size;
293
+ dim3 grid_size;
294
+
295
+ if (tile_out_h > 0 && tile_out_w > 0) {
296
+ p.loop_major = (p.major_dim - 1) / 16384 + 1;
297
+ p.loop_x = 1;
298
+ block_size = dim3(32 * 8, 1, 1);
299
+ grid_size = dim3(((p.out_h - 1) / tile_out_h + 1) * p.minor_dim,
300
+ (p.out_w - 1) / (p.loop_x * tile_out_w) + 1,
301
+ (p.major_dim - 1) / p.loop_major + 1);
302
+ } else {
303
+ p.loop_major = (p.major_dim - 1) / 16384 + 1;
304
+ p.loop_x = 4;
305
+ block_size = dim3(4, 32, 1);
306
+ grid_size = dim3((p.out_h * p.minor_dim - 1) / block_size.x + 1,
307
+ (p.out_w - 1) / (p.loop_x * block_size.y) + 1,
308
+ (p.major_dim - 1) / p.loop_major + 1);
309
+ }
310
+
311
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&] {
312
+ switch (mode) {
313
+ case 1:
314
+ upfirdn2d_kernel<scalar_t, 1, 1, 1, 1, 4, 4, 16, 64>
315
+ <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
316
+ x.data_ptr<scalar_t>(),
317
+ k.data_ptr<scalar_t>(), p);
318
+
319
+ break;
320
+
321
+ case 2:
322
+ upfirdn2d_kernel<scalar_t, 1, 1, 1, 1, 3, 3, 16, 64>
323
+ <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
324
+ x.data_ptr<scalar_t>(),
325
+ k.data_ptr<scalar_t>(), p);
326
+
327
+ break;
328
+
329
+ case 3:
330
+ upfirdn2d_kernel<scalar_t, 2, 2, 1, 1, 4, 4, 16, 64>
331
+ <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
332
+ x.data_ptr<scalar_t>(),
333
+ k.data_ptr<scalar_t>(), p);
334
+
335
+ break;
336
+
337
+ case 4:
338
+ upfirdn2d_kernel<scalar_t, 2, 2, 1, 1, 2, 2, 16, 64>
339
+ <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
340
+ x.data_ptr<scalar_t>(),
341
+ k.data_ptr<scalar_t>(), p);
342
+
343
+ break;
344
+
345
+ case 5:
346
+ upfirdn2d_kernel<scalar_t, 1, 1, 2, 2, 4, 4, 8, 32>
347
+ <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
348
+ x.data_ptr<scalar_t>(),
349
+ k.data_ptr<scalar_t>(), p);
350
+
351
+ break;
352
+
353
+ case 6:
354
+ upfirdn2d_kernel<scalar_t, 1, 1, 2, 2, 4, 4, 8, 32>
355
+ <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
356
+ x.data_ptr<scalar_t>(),
357
+ k.data_ptr<scalar_t>(), p);
358
+
359
+ break;
360
+
361
+ default:
362
+ upfirdn2d_kernel_large<scalar_t><<<grid_size, block_size, 0, stream>>>(
363
+ out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(),
364
+ k.data_ptr<scalar_t>(), p);
365
+ }
366
+ });
367
+
368
+ return out;
369
+ }
ppl.py ADDED
@@ -0,0 +1,130 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+
3
+ import torch
4
+ from torch.nn import functional as F
5
+ import numpy as np
6
+ from tqdm import tqdm
7
+
8
+ import lpips
9
+ from model import Generator
10
+
11
+
12
+ def normalize(x):
13
+ return x / torch.sqrt(x.pow(2).sum(-1, keepdim=True))
14
+
15
+
16
+ def slerp(a, b, t):
17
+ a = normalize(a)
18
+ b = normalize(b)
19
+ d = (a * b).sum(-1, keepdim=True)
20
+ p = t * torch.acos(d)
21
+ c = normalize(b - d * a)
22
+ d = a * torch.cos(p) + c * torch.sin(p)
23
+
24
+ return normalize(d)
25
+
26
+
27
+ def lerp(a, b, t):
28
+ return a + (b - a) * t
29
+
30
+
31
+ if __name__ == "__main__":
32
+ device = "cuda"
33
+
34
+ parser = argparse.ArgumentParser(description="Perceptual Path Length calculator")
35
+
36
+ parser.add_argument(
37
+ "--space", choices=["z", "w"], help="space that PPL calculated with"
38
+ )
39
+ parser.add_argument(
40
+ "--batch", type=int, default=64, help="batch size for the models"
41
+ )
42
+ parser.add_argument(
43
+ "--n_sample",
44
+ type=int,
45
+ default=5000,
46
+ help="number of the samples for calculating PPL",
47
+ )
48
+ parser.add_argument(
49
+ "--size", type=int, default=256, help="output image sizes of the generator"
50
+ )
51
+ parser.add_argument(
52
+ "--eps", type=float, default=1e-4, help="epsilon for numerical stability"
53
+ )
54
+ parser.add_argument(
55
+ "--crop", action="store_true", help="apply center crop to the images"
56
+ )
57
+ parser.add_argument(
58
+ "--sampling",
59
+ default="end",
60
+ choices=["end", "full"],
61
+ help="set endpoint sampling method",
62
+ )
63
+ parser.add_argument(
64
+ "ckpt", metavar="CHECKPOINT", help="path to the model checkpoints"
65
+ )
66
+
67
+ args = parser.parse_args()
68
+
69
+ latent_dim = 512
70
+
71
+ ckpt = torch.load(args.ckpt)
72
+
73
+ g = Generator(args.size, latent_dim, 8).to(device)
74
+ g.load_state_dict(ckpt["g_ema"])
75
+ g.eval()
76
+
77
+ percept = lpips.PerceptualLoss(
78
+ model="net-lin", net="vgg", use_gpu=device.startswith("cuda")
79
+ )
80
+
81
+ distances = []
82
+
83
+ n_batch = args.n_sample // args.batch
84
+ resid = args.n_sample - (n_batch * args.batch)
85
+ batch_sizes = [args.batch] * n_batch + [resid]
86
+
87
+ with torch.no_grad():
88
+ for batch in tqdm(batch_sizes):
89
+ noise = g.make_noise()
90
+
91
+ inputs = torch.randn([batch * 2, latent_dim], device=device)
92
+ if args.sampling == "full":
93
+ lerp_t = torch.rand(batch, device=device)
94
+ else:
95
+ lerp_t = torch.zeros(batch, device=device)
96
+
97
+ if args.space == "w":
98
+ latent = g.get_latent(inputs)
99
+ latent_t0, latent_t1 = latent[::2], latent[1::2]
100
+ latent_e0 = lerp(latent_t0, latent_t1, lerp_t[:, None])
101
+ latent_e1 = lerp(latent_t0, latent_t1, lerp_t[:, None] + args.eps)
102
+ latent_e = torch.stack([latent_e0, latent_e1], 1).view(*latent.shape)
103
+
104
+ image, _ = g([latent_e], input_is_latent=True, noise=noise)
105
+
106
+ if args.crop:
107
+ c = image.shape[2] // 8
108
+ image = image[:, :, c * 3 : c * 7, c * 2 : c * 6]
109
+
110
+ factor = image.shape[2] // 256
111
+
112
+ if factor > 1:
113
+ image = F.interpolate(
114
+ image, size=(256, 256), mode="bilinear", align_corners=False
115
+ )
116
+
117
+ dist = percept(image[::2], image[1::2]).view(image.shape[0] // 2) / (
118
+ args.eps ** 2
119
+ )
120
+ distances.append(dist.to("cpu").numpy())
121
+
122
+ distances = np.concatenate(distances, 0)
123
+
124
+ lo = np.percentile(distances, 1, interpolation="lower")
125
+ hi = np.percentile(distances, 99, interpolation="higher")
126
+ filtered_dist = np.extract(
127
+ np.logical_and(lo <= distances, distances <= hi), distances
128
+ )
129
+
130
+ print("ppl:", filtered_dist.mean())
prepare_data.py ADDED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ from io import BytesIO
3
+ import multiprocessing
4
+ from functools import partial
5
+
6
+ from PIL import Image
7
+ import lmdb
8
+ from tqdm import tqdm
9
+ from torchvision import datasets
10
+ from torchvision.transforms import functional as trans_fn
11
+
12
+
13
+ def resize_and_convert(img, size, resample, quality=100):
14
+ img = trans_fn.resize(img, size, resample)
15
+ img = trans_fn.center_crop(img, size)
16
+ buffer = BytesIO()
17
+ img.save(buffer, format="jpeg", quality=quality)
18
+ val = buffer.getvalue()
19
+
20
+ return val
21
+
22
+
23
+ def resize_multiple(
24
+ img, sizes=(128, 256, 512, 1024), resample=Image.LANCZOS, quality=100
25
+ ):
26
+ imgs = []
27
+
28
+ for size in sizes:
29
+ imgs.append(resize_and_convert(img, size, resample, quality))
30
+
31
+ return imgs
32
+
33
+
34
+ def resize_worker(img_file, sizes, resample):
35
+ i, file = img_file
36
+ img = Image.open(file)
37
+ img = img.convert("RGB")
38
+ out = resize_multiple(img, sizes=sizes, resample=resample)
39
+
40
+ return i, out
41
+
42
+
43
+ def prepare(
44
+ env, dataset, n_worker, sizes=(128, 256, 512, 1024), resample=Image.LANCZOS
45
+ ):
46
+ resize_fn = partial(resize_worker, sizes=sizes, resample=resample)
47
+
48
+ files = sorted(dataset.imgs, key=lambda x: x[0])
49
+ files = [(i, file) for i, (file, label) in enumerate(files)]
50
+ total = 0
51
+
52
+ with multiprocessing.Pool(n_worker) as pool:
53
+ for i, imgs in tqdm(pool.imap_unordered(resize_fn, files)):
54
+ for size, img in zip(sizes, imgs):
55
+ key = f"{size}-{str(i).zfill(5)}".encode("utf-8")
56
+
57
+ with env.begin(write=True) as txn:
58
+ txn.put(key, img)
59
+
60
+ total += 1
61
+
62
+ with env.begin(write=True) as txn:
63
+ txn.put("length".encode("utf-8"), str(total).encode("utf-8"))
64
+
65
+
66
+ if __name__ == "__main__":
67
+ parser = argparse.ArgumentParser(description="Preprocess images for model training")
68
+ parser.add_argument("--out", type=str, help="filename of the result lmdb dataset")
69
+ parser.add_argument(
70
+ "--size",
71
+ type=str,
72
+ default="128,256,512,1024",
73
+ help="resolutions of images for the dataset",
74
+ )
75
+ parser.add_argument(
76
+ "--n_worker",
77
+ type=int,
78
+ default=8,
79
+ help="number of workers for preparing dataset",
80
+ )
81
+ parser.add_argument(
82
+ "--resample",
83
+ type=str,
84
+ default="lanczos",
85
+ help="resampling methods for resizing images",
86
+ )
87
+ parser.add_argument("path", type=str, help="path to the image dataset")
88
+
89
+ args = parser.parse_args()
90
+
91
+ resample_map = {"lanczos": Image.LANCZOS, "bilinear": Image.BILINEAR}
92
+ resample = resample_map[args.resample]
93
+
94
+ sizes = [int(s.strip()) for s in args.size.split(",")]
95
+
96
+ print(f"Make dataset of image sizes:", ", ".join(str(s) for s in sizes))
97
+
98
+ imgset = datasets.ImageFolder(args.path)
99
+
100
+ with lmdb.open(args.out, map_size=1024 ** 4, readahead=False) as env:
101
+ prepare(env, imgset, args.n_worker, sizes=sizes, resample=resample)
projector.py ADDED
@@ -0,0 +1,248 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import math
3
+ import os
4
+
5
+ import torch
6
+ from torch import optim
7
+ from torch.nn import functional as F
8
+ from torchvision import transforms
9
+ from PIL import Image
10
+ from tqdm import tqdm
11
+
12
+ import lpips
13
+ from model import Generator
14
+
15
+
16
+ def noise_regularize(noises):
17
+ loss = 0
18
+
19
+ for noise in noises:
20
+ size = noise.shape[2]
21
+
22
+ while True:
23
+ loss = (
24
+ loss
25
+ + (noise * torch.roll(noise, shifts=1, dims=3)).mean().pow(2)
26
+ + (noise * torch.roll(noise, shifts=1, dims=2)).mean().pow(2)
27
+ )
28
+
29
+ if size <= 8:
30
+ break
31
+
32
+ noise = noise.reshape([-1, 1, size // 2, 2, size // 2, 2])
33
+ noise = noise.mean([3, 5])
34
+ size //= 2
35
+
36
+ return loss
37
+
38
+
39
+ def noise_normalize_(noises):
40
+ for noise in noises:
41
+ mean = noise.mean()
42
+ std = noise.std()
43
+
44
+ noise.data.add_(-mean).div_(std)
45
+
46
+
47
+ def get_lr(t, initial_lr, rampdown=0.25, rampup=0.05):
48
+ lr_ramp = min(1, (1 - t) / rampdown)
49
+ lr_ramp = 0.5 - 0.5 * math.cos(lr_ramp * math.pi)
50
+ lr_ramp = lr_ramp * min(1, t / rampup)
51
+
52
+ return initial_lr * lr_ramp
53
+
54
+
55
+ def latent_noise(latent, strength):
56
+ noise = torch.randn_like(latent) * strength
57
+
58
+ return latent + noise
59
+
60
+
61
+ def make_image(tensor):
62
+ return (
63
+ tensor.detach()
64
+ .clamp_(min=-1, max=1)
65
+ .add(1)
66
+ .div_(2)
67
+ .mul(255)
68
+ .type(torch.uint8)
69
+ .permute(0, 2, 3, 1)
70
+ .to("cpu")
71
+ .numpy()
72
+ )
73
+
74
+
75
+ if __name__ == "__main__":
76
+ device = "cuda"
77
+
78
+ parser = argparse.ArgumentParser(
79
+ description="Image projector to the generator latent spaces"
80
+ )
81
+ parser.add_argument(
82
+ "--ckpt", type=str, required=True, help="path to the model checkpoint"
83
+ )
84
+ parser.add_argument(
85
+ "--size", type=int, default=256, help="output image sizes of the generator"
86
+ )
87
+ parser.add_argument(
88
+ "--lr_rampup",
89
+ type=float,
90
+ default=0.05,
91
+ help="duration of the learning rate warmup",
92
+ )
93
+ parser.add_argument(
94
+ "--lr_rampdown",
95
+ type=float,
96
+ default=0.25,
97
+ help="duration of the learning rate decay",
98
+ )
99
+ parser.add_argument("--lr", type=float, default=0.1, help="learning rate")
100
+ parser.add_argument(
101
+ "--noise", type=float, default=0.05, help="strength of the noise level"
102
+ )
103
+ parser.add_argument(
104
+ "--noise_ramp",
105
+ type=float,
106
+ default=0.75,
107
+ help="duration of the noise level decay",
108
+ )
109
+ parser.add_argument("--step", type=int, default=1000, help="optimize iterations")
110
+ parser.add_argument(
111
+ "--noise_regularize",
112
+ type=float,
113
+ default=1e5,
114
+ help="weight of the noise regularization",
115
+ )
116
+ parser.add_argument("--mse", type=float, default=0, help="weight of the mse loss")
117
+ parser.add_argument(
118
+ "--w_plus",
119
+ action="store_true",
120
+ help="allow to use distinct latent codes to each layers",
121
+ )
122
+ parser.add_argument(
123
+ "files", metavar="FILES", nargs="+", help="path to image files to be projected"
124
+ )
125
+
126
+ args = parser.parse_args()
127
+
128
+ n_mean_latent = 10000
129
+
130
+ resize = min(args.size, 256)
131
+
132
+ transform = transforms.Compose(
133
+ [
134
+ transforms.Resize(resize),
135
+ transforms.CenterCrop(resize),
136
+ transforms.ToTensor(),
137
+ transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
138
+ ]
139
+ )
140
+
141
+ imgs = []
142
+
143
+ for imgfile in args.files:
144
+ img = transform(Image.open(imgfile).convert("RGB"))
145
+ imgs.append(img)
146
+
147
+ imgs = torch.stack(imgs, 0).to(device)
148
+
149
+ g_ema = Generator(args.size, 512, 8)
150
+ g_ema.load_state_dict(torch.load(args.ckpt)["g_ema"], strict=False)
151
+ g_ema.eval()
152
+ g_ema = g_ema.to(device)
153
+
154
+ with torch.no_grad():
155
+ noise_sample = torch.randn(n_mean_latent, 512, device=device)
156
+ latent_out = g_ema.style(noise_sample)
157
+
158
+ latent_mean = latent_out.mean(0)
159
+ latent_std = ((latent_out - latent_mean).pow(2).sum() / n_mean_latent) ** 0.5
160
+
161
+ percept = lpips.PerceptualLoss(
162
+ model="net-lin", net="vgg", use_gpu=device.startswith("cuda")
163
+ )
164
+
165
+ noises_single = g_ema.make_noise()
166
+ noises = []
167
+ for noise in noises_single:
168
+ noises.append(noise.repeat(imgs.shape[0], 1, 1, 1).normal_())
169
+
170
+ latent_in = latent_mean.detach().clone().unsqueeze(0).repeat(imgs.shape[0], 1)
171
+
172
+ if args.w_plus:
173
+ latent_in = latent_in.unsqueeze(1).repeat(1, g_ema.n_latent, 1)
174
+
175
+ latent_in.requires_grad = True
176
+
177
+ for noise in noises:
178
+ noise.requires_grad = True
179
+
180
+ optimizer = optim.Adam([latent_in] + noises, lr=args.lr)
181
+
182
+ pbar = tqdm(range(args.step))
183
+ latent_path = []
184
+
185
+ for i in pbar:
186
+ t = i / args.step
187
+ lr = get_lr(t, args.lr)
188
+ optimizer.param_groups[0]["lr"] = lr
189
+ noise_strength = latent_std * args.noise * max(0, 1 - t / args.noise_ramp) ** 2
190
+ latent_n = latent_noise(latent_in, noise_strength.item())
191
+
192
+ img_gen, _ = g_ema([latent_n], input_is_latent=True, noise=noises)
193
+
194
+ batch, channel, height, width = img_gen.shape
195
+
196
+ if height > 256:
197
+ factor = height // 256
198
+
199
+ img_gen = img_gen.reshape(
200
+ batch, channel, height // factor, factor, width // factor, factor
201
+ )
202
+ img_gen = img_gen.mean([3, 5])
203
+
204
+ p_loss = percept(img_gen, imgs).sum()
205
+ n_loss = noise_regularize(noises)
206
+ mse_loss = F.mse_loss(img_gen, imgs)
207
+
208
+ loss = p_loss + args.noise_regularize * n_loss + args.mse * mse_loss
209
+
210
+ optimizer.zero_grad()
211
+ loss.backward()
212
+ optimizer.step()
213
+
214
+ noise_normalize_(noises)
215
+
216
+ if (i + 1) % 100 == 0:
217
+ latent_path.append(latent_in.detach().clone())
218
+
219
+ pbar.set_description(
220
+ (
221
+ f"perceptual: {p_loss.item():.4f}; noise regularize: {n_loss.item():.4f};"
222
+ f" mse: {mse_loss.item():.4f}; lr: {lr:.4f}"
223
+ )
224
+ )
225
+
226
+ img_gen, _ = g_ema([latent_path[-1]], input_is_latent=True, noise=noises)
227
+
228
+ filename = os.path.splitext(os.path.basename(args.files[0]))[0] + ".pt"
229
+
230
+ img_ar = make_image(img_gen)
231
+
232
+ result_file = {}
233
+ for i, input_name in enumerate(args.files):
234
+ noise_single = []
235
+ for noise in noises:
236
+ noise_single.append(noise[i : i + 1])
237
+
238
+ result_file[input_name] = {
239
+ "img": img_gen[i],
240
+ "latent": latent_in[i],
241
+ "noise": noise_single,
242
+ }
243
+
244
+ img_name = os.path.splitext(os.path.basename(input_name))[0] + "-project.png"
245
+ pil_img = Image.fromarray(img_ar[i])
246
+ pil_img.save(img_name)
247
+
248
+ torch.save(result_file, filename)
sample/.gitignore ADDED
@@ -0,0 +1 @@
 
 
1
+ *.png
swagan.py ADDED
@@ -0,0 +1,440 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import random
3
+ import functools
4
+ import operator
5
+
6
+ import torch
7
+ from torch import nn
8
+ from torch.nn import functional as F
9
+ from torch.autograd import Function
10
+
11
+ from op import FusedLeakyReLU, fused_leaky_relu, upfirdn2d, conv2d_gradfix
12
+ from model import (
13
+ ModulatedConv2d,
14
+ StyledConv,
15
+ ConstantInput,
16
+ PixelNorm,
17
+ Upsample,
18
+ Downsample,
19
+ Blur,
20
+ EqualLinear,
21
+ ConvLayer,
22
+ )
23
+
24
+
25
+ def get_haar_wavelet(in_channels):
26
+ haar_wav_l = 1 / (2 ** 0.5) * torch.ones(1, 2)
27
+ haar_wav_h = 1 / (2 ** 0.5) * torch.ones(1, 2)
28
+ haar_wav_h[0, 0] = -1 * haar_wav_h[0, 0]
29
+
30
+ haar_wav_ll = haar_wav_l.T * haar_wav_l
31
+ haar_wav_lh = haar_wav_h.T * haar_wav_l
32
+ haar_wav_hl = haar_wav_l.T * haar_wav_h
33
+ haar_wav_hh = haar_wav_h.T * haar_wav_h
34
+
35
+ return haar_wav_ll, haar_wav_lh, haar_wav_hl, haar_wav_hh
36
+
37
+
38
+ def dwt_init(x):
39
+ x01 = x[:, :, 0::2, :] / 2
40
+ x02 = x[:, :, 1::2, :] / 2
41
+ x1 = x01[:, :, :, 0::2]
42
+ x2 = x02[:, :, :, 0::2]
43
+ x3 = x01[:, :, :, 1::2]
44
+ x4 = x02[:, :, :, 1::2]
45
+ x_LL = x1 + x2 + x3 + x4
46
+ x_HL = -x1 - x2 + x3 + x4
47
+ x_LH = -x1 + x2 - x3 + x4
48
+ x_HH = x1 - x2 - x3 + x4
49
+
50
+ return torch.cat((x_LL, x_HL, x_LH, x_HH), 1)
51
+
52
+
53
+ def iwt_init(x):
54
+ r = 2
55
+ in_batch, in_channel, in_height, in_width = x.size()
56
+ # print([in_batch, in_channel, in_height, in_width])
57
+ out_batch, out_channel, out_height, out_width = (
58
+ in_batch,
59
+ int(in_channel / (r ** 2)),
60
+ r * in_height,
61
+ r * in_width,
62
+ )
63
+ x1 = x[:, 0:out_channel, :, :] / 2
64
+ x2 = x[:, out_channel : out_channel * 2, :, :] / 2
65
+ x3 = x[:, out_channel * 2 : out_channel * 3, :, :] / 2
66
+ x4 = x[:, out_channel * 3 : out_channel * 4, :, :] / 2
67
+
68
+ h = torch.zeros([out_batch, out_channel, out_height, out_width]).float().cuda()
69
+
70
+ h[:, :, 0::2, 0::2] = x1 - x2 - x3 + x4
71
+ h[:, :, 1::2, 0::2] = x1 - x2 + x3 - x4
72
+ h[:, :, 0::2, 1::2] = x1 + x2 - x3 - x4
73
+ h[:, :, 1::2, 1::2] = x1 + x2 + x3 + x4
74
+
75
+ return h
76
+
77
+
78
+ class HaarTransform(nn.Module):
79
+ def __init__(self, in_channels):
80
+ super().__init__()
81
+
82
+ ll, lh, hl, hh = get_haar_wavelet(in_channels)
83
+
84
+ self.register_buffer("ll", ll)
85
+ self.register_buffer("lh", lh)
86
+ self.register_buffer("hl", hl)
87
+ self.register_buffer("hh", hh)
88
+
89
+ def forward(self, input):
90
+ ll = upfirdn2d(input, self.ll, down=2)
91
+ lh = upfirdn2d(input, self.lh, down=2)
92
+ hl = upfirdn2d(input, self.hl, down=2)
93
+ hh = upfirdn2d(input, self.hh, down=2)
94
+
95
+ return torch.cat((ll, lh, hl, hh), 1)
96
+
97
+
98
+ class InverseHaarTransform(nn.Module):
99
+ def __init__(self, in_channels):
100
+ super().__init__()
101
+
102
+ ll, lh, hl, hh = get_haar_wavelet(in_channels)
103
+
104
+ self.register_buffer("ll", ll)
105
+ self.register_buffer("lh", -lh)
106
+ self.register_buffer("hl", -hl)
107
+ self.register_buffer("hh", hh)
108
+
109
+ def forward(self, input):
110
+ ll, lh, hl, hh = input.chunk(4, 1)
111
+ ll = upfirdn2d(ll, self.ll, up=2, pad=(1, 0, 1, 0))
112
+ lh = upfirdn2d(lh, self.lh, up=2, pad=(1, 0, 1, 0))
113
+ hl = upfirdn2d(hl, self.hl, up=2, pad=(1, 0, 1, 0))
114
+ hh = upfirdn2d(hh, self.hh, up=2, pad=(1, 0, 1, 0))
115
+
116
+ return ll + lh + hl + hh
117
+
118
+
119
+ class ToRGB(nn.Module):
120
+ def __init__(self, in_channel, style_dim, upsample=True, blur_kernel=[1, 3, 3, 1]):
121
+ super().__init__()
122
+
123
+ if upsample:
124
+ self.iwt = InverseHaarTransform(3)
125
+ self.upsample = Upsample(blur_kernel)
126
+ self.dwt = HaarTransform(3)
127
+
128
+ self.conv = ModulatedConv2d(in_channel, 3 * 4, 1, style_dim, demodulate=False)
129
+ self.bias = nn.Parameter(torch.zeros(1, 3 * 4, 1, 1))
130
+
131
+ def forward(self, input, style, skip=None):
132
+ out = self.conv(input, style)
133
+ out = out + self.bias
134
+
135
+ if skip is not None:
136
+ skip = self.iwt(skip)
137
+ skip = self.upsample(skip)
138
+ skip = self.dwt(skip)
139
+
140
+ out = out + skip
141
+
142
+ return out
143
+
144
+
145
+ class Generator(nn.Module):
146
+ def __init__(
147
+ self,
148
+ size,
149
+ style_dim,
150
+ n_mlp,
151
+ channel_multiplier=2,
152
+ blur_kernel=[1, 3, 3, 1],
153
+ lr_mlp=0.01,
154
+ ):
155
+ super().__init__()
156
+
157
+ self.size = size
158
+
159
+ self.style_dim = style_dim
160
+
161
+ layers = [PixelNorm()]
162
+
163
+ for i in range(n_mlp):
164
+ layers.append(
165
+ EqualLinear(
166
+ style_dim, style_dim, lr_mul=lr_mlp, activation="fused_lrelu"
167
+ )
168
+ )
169
+
170
+ self.style = nn.Sequential(*layers)
171
+
172
+ self.channels = {
173
+ 4: 512,
174
+ 8: 512,
175
+ 16: 512,
176
+ 32: 512,
177
+ 64: 256 * channel_multiplier,
178
+ 128: 128 * channel_multiplier,
179
+ 256: 64 * channel_multiplier,
180
+ 512: 32 * channel_multiplier,
181
+ 1024: 16 * channel_multiplier,
182
+ }
183
+
184
+ self.input = ConstantInput(self.channels[4])
185
+ self.conv1 = StyledConv(
186
+ self.channels[4], self.channels[4], 3, style_dim, blur_kernel=blur_kernel
187
+ )
188
+ self.to_rgb1 = ToRGB(self.channels[4], style_dim, upsample=False)
189
+
190
+ self.log_size = int(math.log(size, 2)) - 1
191
+ self.num_layers = (self.log_size - 2) * 2 + 1
192
+
193
+ self.convs = nn.ModuleList()
194
+ self.upsamples = nn.ModuleList()
195
+ self.to_rgbs = nn.ModuleList()
196
+ self.noises = nn.Module()
197
+
198
+ in_channel = self.channels[4]
199
+
200
+ for layer_idx in range(self.num_layers):
201
+ res = (layer_idx + 5) // 2
202
+ shape = [1, 1, 2 ** res, 2 ** res]
203
+ self.noises.register_buffer(f"noise_{layer_idx}", torch.randn(*shape))
204
+
205
+ for i in range(3, self.log_size + 1):
206
+ out_channel = self.channels[2 ** i]
207
+
208
+ self.convs.append(
209
+ StyledConv(
210
+ in_channel,
211
+ out_channel,
212
+ 3,
213
+ style_dim,
214
+ upsample=True,
215
+ blur_kernel=blur_kernel,
216
+ )
217
+ )
218
+
219
+ self.convs.append(
220
+ StyledConv(
221
+ out_channel, out_channel, 3, style_dim, blur_kernel=blur_kernel
222
+ )
223
+ )
224
+
225
+ self.to_rgbs.append(ToRGB(out_channel, style_dim))
226
+
227
+ in_channel = out_channel
228
+
229
+ self.iwt = InverseHaarTransform(3)
230
+
231
+ self.n_latent = self.log_size * 2 - 2
232
+
233
+ def make_noise(self):
234
+ device = self.input.input.device
235
+
236
+ noises = [torch.randn(1, 1, 2 ** 2, 2 ** 2, device=device)]
237
+
238
+ for i in range(3, self.log_size + 1):
239
+ for _ in range(2):
240
+ noises.append(torch.randn(1, 1, 2 ** i, 2 ** i, device=device))
241
+
242
+ return noises
243
+
244
+ def mean_latent(self, n_latent):
245
+ latent_in = torch.randn(
246
+ n_latent, self.style_dim, device=self.input.input.device
247
+ )
248
+ latent = self.style(latent_in).mean(0, keepdim=True)
249
+
250
+ return latent
251
+
252
+ def get_latent(self, input):
253
+ return self.style(input)
254
+
255
+ def forward(
256
+ self,
257
+ styles,
258
+ return_latents=False,
259
+ inject_index=None,
260
+ truncation=1,
261
+ truncation_latent=None,
262
+ input_is_latent=False,
263
+ noise=None,
264
+ randomize_noise=True,
265
+ ):
266
+ if not input_is_latent:
267
+ styles = [self.style(s) for s in styles]
268
+
269
+ if noise is None:
270
+ if randomize_noise:
271
+ noise = [None] * self.num_layers
272
+ else:
273
+ noise = [
274
+ getattr(self.noises, f"noise_{i}") for i in range(self.num_layers)
275
+ ]
276
+
277
+ if truncation < 1:
278
+ style_t = []
279
+
280
+ for style in styles:
281
+ style_t.append(
282
+ truncation_latent + truncation * (style - truncation_latent)
283
+ )
284
+
285
+ styles = style_t
286
+
287
+ if len(styles) < 2:
288
+ inject_index = self.n_latent
289
+
290
+ if styles[0].ndim < 3:
291
+ latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
292
+
293
+ else:
294
+ latent = styles[0]
295
+
296
+ else:
297
+ if inject_index is None:
298
+ inject_index = random.randint(1, self.n_latent - 1)
299
+
300
+ latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
301
+ latent2 = styles[1].unsqueeze(1).repeat(1, self.n_latent - inject_index, 1)
302
+
303
+ latent = torch.cat([latent, latent2], 1)
304
+
305
+ out = self.input(latent)
306
+ out = self.conv1(out, latent[:, 0], noise=noise[0])
307
+
308
+ skip = self.to_rgb1(out, latent[:, 1])
309
+
310
+ i = 1
311
+ for conv1, conv2, noise1, noise2, to_rgb in zip(
312
+ self.convs[::2], self.convs[1::2], noise[1::2], noise[2::2], self.to_rgbs
313
+ ):
314
+ out = conv1(out, latent[:, i], noise=noise1)
315
+ out = conv2(out, latent[:, i + 1], noise=noise2)
316
+ skip = to_rgb(out, latent[:, i + 2], skip)
317
+
318
+ i += 2
319
+
320
+ image = self.iwt(skip)
321
+
322
+ if return_latents:
323
+ return image, latent
324
+
325
+ else:
326
+ return image, None
327
+
328
+
329
+ class ConvBlock(nn.Module):
330
+ def __init__(self, in_channel, out_channel, blur_kernel=[1, 3, 3, 1]):
331
+ super().__init__()
332
+
333
+ self.conv1 = ConvLayer(in_channel, in_channel, 3)
334
+ self.conv2 = ConvLayer(in_channel, out_channel, 3, downsample=True)
335
+
336
+ def forward(self, input):
337
+ out = self.conv1(input)
338
+ out = self.conv2(out)
339
+
340
+ return out
341
+
342
+
343
+ class FromRGB(nn.Module):
344
+ def __init__(self, out_channel, downsample=True, blur_kernel=[1, 3, 3, 1]):
345
+ super().__init__()
346
+
347
+ self.downsample = downsample
348
+
349
+ if downsample:
350
+ self.iwt = InverseHaarTransform(3)
351
+ self.downsample = Downsample(blur_kernel)
352
+ self.dwt = HaarTransform(3)
353
+
354
+ self.conv = ConvLayer(3 * 4, out_channel, 3)
355
+
356
+ def forward(self, input, skip=None):
357
+ if self.downsample:
358
+ input = self.iwt(input)
359
+ input = self.downsample(input)
360
+ input = self.dwt(input)
361
+
362
+ out = self.conv(input)
363
+
364
+ if skip is not None:
365
+ out = out + skip
366
+
367
+ return input, out
368
+
369
+
370
+ class Discriminator(nn.Module):
371
+ def __init__(self, size, channel_multiplier=2, blur_kernel=[1, 3, 3, 1]):
372
+ super().__init__()
373
+
374
+ channels = {
375
+ 4: 512,
376
+ 8: 512,
377
+ 16: 512,
378
+ 32: 512,
379
+ 64: 256 * channel_multiplier,
380
+ 128: 128 * channel_multiplier,
381
+ 256: 64 * channel_multiplier,
382
+ 512: 32 * channel_multiplier,
383
+ 1024: 16 * channel_multiplier,
384
+ }
385
+
386
+ self.dwt = HaarTransform(3)
387
+
388
+ self.from_rgbs = nn.ModuleList()
389
+ self.convs = nn.ModuleList()
390
+
391
+ log_size = int(math.log(size, 2)) - 1
392
+
393
+ in_channel = channels[size]
394
+
395
+ for i in range(log_size, 2, -1):
396
+ out_channel = channels[2 ** (i - 1)]
397
+
398
+ self.from_rgbs.append(FromRGB(in_channel, downsample=i != log_size))
399
+ self.convs.append(ConvBlock(in_channel, out_channel, blur_kernel))
400
+
401
+ in_channel = out_channel
402
+
403
+ self.from_rgbs.append(FromRGB(channels[4]))
404
+
405
+ self.stddev_group = 4
406
+ self.stddev_feat = 1
407
+
408
+ self.final_conv = ConvLayer(in_channel + 1, channels[4], 3)
409
+ self.final_linear = nn.Sequential(
410
+ EqualLinear(channels[4] * 4 * 4, channels[4], activation="fused_lrelu"),
411
+ EqualLinear(channels[4], 1),
412
+ )
413
+
414
+ def forward(self, input):
415
+ input = self.dwt(input)
416
+ out = None
417
+
418
+ for from_rgb, conv in zip(self.from_rgbs, self.convs):
419
+ input, out = from_rgb(input, out)
420
+ out = conv(out)
421
+
422
+ _, out = self.from_rgbs[-1](input, out)
423
+
424
+ batch, channel, height, width = out.shape
425
+ group = min(batch, self.stddev_group)
426
+ stddev = out.view(
427
+ group, -1, self.stddev_feat, channel // self.stddev_feat, height, width
428
+ )
429
+ stddev = torch.sqrt(stddev.var(0, unbiased=False) + 1e-8)
430
+ stddev = stddev.mean([2, 3, 4], keepdims=True).squeeze(2)
431
+ stddev = stddev.repeat(group, 1, height, width)
432
+ out = torch.cat([out, stddev], 1)
433
+
434
+ out = self.final_conv(out)
435
+
436
+ out = out.view(batch, -1)
437
+ out = self.final_linear(out)
438
+
439
+ return out
440
+
train.py ADDED
@@ -0,0 +1,531 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import math
3
+ import random
4
+ import os
5
+
6
+ import numpy as np
7
+ import torch
8
+ from torch import nn, autograd, optim
9
+ from torch.nn import functional as F
10
+ from torch.utils import data
11
+ import torch.distributed as dist
12
+ from torchvision import transforms, utils
13
+ from tqdm import tqdm
14
+
15
+ try:
16
+ import wandb
17
+
18
+ except ImportError:
19
+ wandb = None
20
+
21
+
22
+ from dataset import MultiResolutionDataset
23
+ from distributed import (
24
+ get_rank,
25
+ synchronize,
26
+ reduce_loss_dict,
27
+ reduce_sum,
28
+ get_world_size,
29
+ )
30
+ from op import conv2d_gradfix
31
+ from non_leaking import augment, AdaptiveAugment
32
+
33
+
34
+ def data_sampler(dataset, shuffle, distributed):
35
+ if distributed:
36
+ return data.distributed.DistributedSampler(dataset, shuffle=shuffle)
37
+
38
+ if shuffle:
39
+ return data.RandomSampler(dataset)
40
+
41
+ else:
42
+ return data.SequentialSampler(dataset)
43
+
44
+
45
+ def requires_grad(model, flag=True):
46
+ for p in model.parameters():
47
+ p.requires_grad = flag
48
+
49
+
50
+ def accumulate(model1, model2, decay=0.999):
51
+ par1 = dict(model1.named_parameters())
52
+ par2 = dict(model2.named_parameters())
53
+
54
+ for k in par1.keys():
55
+ par1[k].data.mul_(decay).add_(par2[k].data, alpha=1 - decay)
56
+
57
+
58
+ def sample_data(loader):
59
+ while True:
60
+ for batch in loader:
61
+ yield batch
62
+
63
+
64
+ def d_logistic_loss(real_pred, fake_pred):
65
+ real_loss = F.softplus(-real_pred)
66
+ fake_loss = F.softplus(fake_pred)
67
+
68
+ return real_loss.mean() + fake_loss.mean()
69
+
70
+
71
+ def d_r1_loss(real_pred, real_img):
72
+ with conv2d_gradfix.no_weight_gradients():
73
+ grad_real, = autograd.grad(
74
+ outputs=real_pred.sum(), inputs=real_img, create_graph=True
75
+ )
76
+ grad_penalty = grad_real.pow(2).reshape(grad_real.shape[0], -1).sum(1).mean()
77
+
78
+ return grad_penalty
79
+
80
+
81
+ def g_nonsaturating_loss(fake_pred):
82
+ loss = F.softplus(-fake_pred).mean()
83
+
84
+ return loss
85
+
86
+
87
+ def g_path_regularize(fake_img, latents, mean_path_length, decay=0.01):
88
+ noise = torch.randn_like(fake_img) / math.sqrt(
89
+ fake_img.shape[2] * fake_img.shape[3]
90
+ )
91
+ grad, = autograd.grad(
92
+ outputs=(fake_img * noise).sum(), inputs=latents, create_graph=True
93
+ )
94
+ path_lengths = torch.sqrt(grad.pow(2).sum(2).mean(1))
95
+
96
+ path_mean = mean_path_length + decay * (path_lengths.mean() - mean_path_length)
97
+
98
+ path_penalty = (path_lengths - path_mean).pow(2).mean()
99
+
100
+ return path_penalty, path_mean.detach(), path_lengths
101
+
102
+
103
+ def make_noise(batch, latent_dim, n_noise, device):
104
+ if n_noise == 1:
105
+ return torch.randn(batch, latent_dim, device=device)
106
+
107
+ noises = torch.randn(n_noise, batch, latent_dim, device=device).unbind(0)
108
+
109
+ return noises
110
+
111
+
112
+ def mixing_noise(batch, latent_dim, prob, device):
113
+ if prob > 0 and random.random() < prob:
114
+ return make_noise(batch, latent_dim, 2, device)
115
+
116
+ else:
117
+ return [make_noise(batch, latent_dim, 1, device)]
118
+
119
+
120
+ def set_grad_none(model, targets):
121
+ for n, p in model.named_parameters():
122
+ if n in targets:
123
+ p.grad = None
124
+
125
+
126
+ def train(args, loader, generator, discriminator, g_optim, d_optim, g_ema, device):
127
+ loader = sample_data(loader)
128
+
129
+ pbar = range(args.iter)
130
+
131
+ if get_rank() == 0:
132
+ pbar = tqdm(pbar, initial=args.start_iter, dynamic_ncols=True, smoothing=0.01)
133
+
134
+ mean_path_length = 0
135
+
136
+ d_loss_val = 0
137
+ r1_loss = torch.tensor(0.0, device=device)
138
+ g_loss_val = 0
139
+ path_loss = torch.tensor(0.0, device=device)
140
+ path_lengths = torch.tensor(0.0, device=device)
141
+ mean_path_length_avg = 0
142
+ loss_dict = {}
143
+
144
+ if args.distributed:
145
+ g_module = generator.module
146
+ d_module = discriminator.module
147
+
148
+ else:
149
+ g_module = generator
150
+ d_module = discriminator
151
+
152
+ accum = 0.5 ** (32 / (10 * 1000))
153
+ ada_aug_p = args.augment_p if args.augment_p > 0 else 0.0
154
+ r_t_stat = 0
155
+
156
+ if args.augment and args.augment_p == 0:
157
+ ada_augment = AdaptiveAugment(args.ada_target, args.ada_length, 8, device)
158
+
159
+ sample_z = torch.randn(args.n_sample, args.latent, device=device)
160
+
161
+ for idx in pbar:
162
+ i = idx + args.start_iter
163
+
164
+ if i > args.iter:
165
+ print("Done!")
166
+
167
+ break
168
+
169
+ real_img = next(loader)
170
+ real_img = real_img.to(device)
171
+
172
+ requires_grad(generator, False)
173
+ requires_grad(discriminator, True)
174
+
175
+ noise = mixing_noise(args.batch, args.latent, args.mixing, device)
176
+ fake_img, _ = generator(noise)
177
+
178
+ if args.augment:
179
+ real_img_aug, _ = augment(real_img, ada_aug_p)
180
+ fake_img, _ = augment(fake_img, ada_aug_p)
181
+
182
+ else:
183
+ real_img_aug = real_img
184
+
185
+ fake_pred = discriminator(fake_img)
186
+ real_pred = discriminator(real_img_aug)
187
+ d_loss = d_logistic_loss(real_pred, fake_pred)
188
+
189
+ loss_dict["d"] = d_loss
190
+ loss_dict["real_score"] = real_pred.mean()
191
+ loss_dict["fake_score"] = fake_pred.mean()
192
+
193
+ discriminator.zero_grad()
194
+ d_loss.backward()
195
+ d_optim.step()
196
+
197
+ if args.augment and args.augment_p == 0:
198
+ ada_aug_p = ada_augment.tune(real_pred)
199
+ r_t_stat = ada_augment.r_t_stat
200
+
201
+ d_regularize = i % args.d_reg_every == 0
202
+
203
+ if d_regularize:
204
+ real_img.requires_grad = True
205
+
206
+ if args.augment:
207
+ real_img_aug, _ = augment(real_img, ada_aug_p)
208
+
209
+ else:
210
+ real_img_aug = real_img
211
+
212
+ real_pred = discriminator(real_img_aug)
213
+ r1_loss = d_r1_loss(real_pred, real_img)
214
+
215
+ discriminator.zero_grad()
216
+ (args.r1 / 2 * r1_loss * args.d_reg_every + 0 * real_pred[0]).backward()
217
+
218
+ d_optim.step()
219
+
220
+ loss_dict["r1"] = r1_loss
221
+
222
+ requires_grad(generator, True)
223
+ requires_grad(discriminator, False)
224
+
225
+ noise = mixing_noise(args.batch, args.latent, args.mixing, device)
226
+ fake_img, _ = generator(noise)
227
+
228
+ if args.augment:
229
+ fake_img, _ = augment(fake_img, ada_aug_p)
230
+
231
+ fake_pred = discriminator(fake_img)
232
+ g_loss = g_nonsaturating_loss(fake_pred)
233
+
234
+ loss_dict["g"] = g_loss
235
+
236
+ generator.zero_grad()
237
+ g_loss.backward()
238
+ g_optim.step()
239
+
240
+ g_regularize = i % args.g_reg_every == 0
241
+
242
+ if g_regularize:
243
+ path_batch_size = max(1, args.batch // args.path_batch_shrink)
244
+ noise = mixing_noise(path_batch_size, args.latent, args.mixing, device)
245
+ fake_img, latents = generator(noise, return_latents=True)
246
+
247
+ path_loss, mean_path_length, path_lengths = g_path_regularize(
248
+ fake_img, latents, mean_path_length
249
+ )
250
+
251
+ generator.zero_grad()
252
+ weighted_path_loss = args.path_regularize * args.g_reg_every * path_loss
253
+
254
+ if args.path_batch_shrink:
255
+ weighted_path_loss += 0 * fake_img[0, 0, 0, 0]
256
+
257
+ weighted_path_loss.backward()
258
+
259
+ g_optim.step()
260
+
261
+ mean_path_length_avg = (
262
+ reduce_sum(mean_path_length).item() / get_world_size()
263
+ )
264
+
265
+ loss_dict["path"] = path_loss
266
+ loss_dict["path_length"] = path_lengths.mean()
267
+
268
+ accumulate(g_ema, g_module, accum)
269
+
270
+ loss_reduced = reduce_loss_dict(loss_dict)
271
+
272
+ d_loss_val = loss_reduced["d"].mean().item()
273
+ g_loss_val = loss_reduced["g"].mean().item()
274
+ r1_val = loss_reduced["r1"].mean().item()
275
+ path_loss_val = loss_reduced["path"].mean().item()
276
+ real_score_val = loss_reduced["real_score"].mean().item()
277
+ fake_score_val = loss_reduced["fake_score"].mean().item()
278
+ path_length_val = loss_reduced["path_length"].mean().item()
279
+
280
+ if get_rank() == 0:
281
+ pbar.set_description(
282
+ (
283
+ f"d: {d_loss_val:.4f}; g: {g_loss_val:.4f}; r1: {r1_val:.4f}; "
284
+ f"path: {path_loss_val:.4f}; mean path: {mean_path_length_avg:.4f}; "
285
+ f"augment: {ada_aug_p:.4f}"
286
+ )
287
+ )
288
+
289
+ if wandb and args.wandb:
290
+ wandb.log(
291
+ {
292
+ "Generator": g_loss_val,
293
+ "Discriminator": d_loss_val,
294
+ "Augment": ada_aug_p,
295
+ "Rt": r_t_stat,
296
+ "R1": r1_val,
297
+ "Path Length Regularization": path_loss_val,
298
+ "Mean Path Length": mean_path_length,
299
+ "Real Score": real_score_val,
300
+ "Fake Score": fake_score_val,
301
+ "Path Length": path_length_val,
302
+ }
303
+ )
304
+
305
+ if i % 100 == 0:
306
+ with torch.no_grad():
307
+ g_ema.eval()
308
+ sample, _ = g_ema([sample_z])
309
+ utils.save_image(
310
+ sample,
311
+ f"sample/{str(i).zfill(6)}.png",
312
+ nrow=int(args.n_sample ** 0.5),
313
+ normalize=True,
314
+ range=(-1, 1),
315
+ )
316
+
317
+ if i % 10000 == 0:
318
+ torch.save(
319
+ {
320
+ "g": g_module.state_dict(),
321
+ "d": d_module.state_dict(),
322
+ "g_ema": g_ema.state_dict(),
323
+ "g_optim": g_optim.state_dict(),
324
+ "d_optim": d_optim.state_dict(),
325
+ "args": args,
326
+ "ada_aug_p": ada_aug_p,
327
+ },
328
+ f"checkpoint/{str(i).zfill(6)}.pt",
329
+ )
330
+
331
+
332
+ if __name__ == "__main__":
333
+ device = "cuda"
334
+
335
+ parser = argparse.ArgumentParser(description="StyleGAN2 trainer")
336
+
337
+ parser.add_argument("path", type=str, help="path to the lmdb dataset")
338
+ parser.add_argument('--arch', type=str, default='stylegan2', help='model architectures (stylegan2 | swagan)')
339
+ parser.add_argument(
340
+ "--iter", type=int, default=800000, help="total training iterations"
341
+ )
342
+ parser.add_argument(
343
+ "--batch", type=int, default=16, help="batch sizes for each gpus"
344
+ )
345
+ parser.add_argument(
346
+ "--n_sample",
347
+ type=int,
348
+ default=64,
349
+ help="number of the samples generated during training",
350
+ )
351
+ parser.add_argument(
352
+ "--size", type=int, default=256, help="image sizes for the model"
353
+ )
354
+ parser.add_argument(
355
+ "--r1", type=float, default=10, help="weight of the r1 regularization"
356
+ )
357
+ parser.add_argument(
358
+ "--path_regularize",
359
+ type=float,
360
+ default=2,
361
+ help="weight of the path length regularization",
362
+ )
363
+ parser.add_argument(
364
+ "--path_batch_shrink",
365
+ type=int,
366
+ default=2,
367
+ help="batch size reducing factor for the path length regularization (reduce memory consumption)",
368
+ )
369
+ parser.add_argument(
370
+ "--d_reg_every",
371
+ type=int,
372
+ default=16,
373
+ help="interval of the applying r1 regularization",
374
+ )
375
+ parser.add_argument(
376
+ "--g_reg_every",
377
+ type=int,
378
+ default=4,
379
+ help="interval of the applying path length regularization",
380
+ )
381
+ parser.add_argument(
382
+ "--mixing", type=float, default=0.9, help="probability of latent code mixing"
383
+ )
384
+ parser.add_argument(
385
+ "--ckpt",
386
+ type=str,
387
+ default=None,
388
+ help="path to the checkpoints to resume training",
389
+ )
390
+ parser.add_argument("--lr", type=float, default=0.002, help="learning rate")
391
+ parser.add_argument(
392
+ "--channel_multiplier",
393
+ type=int,
394
+ default=2,
395
+ help="channel multiplier factor for the model. config-f = 2, else = 1",
396
+ )
397
+ parser.add_argument(
398
+ "--wandb", action="store_true", help="use weights and biases logging"
399
+ )
400
+ parser.add_argument(
401
+ "--local_rank", type=int, default=0, help="local rank for distributed training"
402
+ )
403
+ parser.add_argument(
404
+ "--augment", action="store_true", help="apply non leaking augmentation"
405
+ )
406
+ parser.add_argument(
407
+ "--augment_p",
408
+ type=float,
409
+ default=0,
410
+ help="probability of applying augmentation. 0 = use adaptive augmentation",
411
+ )
412
+ parser.add_argument(
413
+ "--ada_target",
414
+ type=float,
415
+ default=0.6,
416
+ help="target augmentation probability for adaptive augmentation",
417
+ )
418
+ parser.add_argument(
419
+ "--ada_length",
420
+ type=int,
421
+ default=500 * 1000,
422
+ help="target duraing to reach augmentation probability for adaptive augmentation",
423
+ )
424
+ parser.add_argument(
425
+ "--ada_every",
426
+ type=int,
427
+ default=256,
428
+ help="probability update interval of the adaptive augmentation",
429
+ )
430
+
431
+ args = parser.parse_args()
432
+
433
+ n_gpu = int(os.environ["WORLD_SIZE"]) if "WORLD_SIZE" in os.environ else 1
434
+ args.distributed = n_gpu > 1
435
+
436
+ if args.distributed:
437
+ torch.cuda.set_device(args.local_rank)
438
+ torch.distributed.init_process_group(backend="nccl", init_method="env://")
439
+ synchronize()
440
+
441
+ args.latent = 512
442
+ args.n_mlp = 8
443
+
444
+ args.start_iter = 0
445
+
446
+ if args.arch == 'stylegan2':
447
+ from model import Generator, Discriminator
448
+
449
+ elif args.arch == 'swagan':
450
+ from swagan import Generator, Discriminator
451
+
452
+ generator = Generator(
453
+ args.size, args.latent, args.n_mlp, channel_multiplier=args.channel_multiplier
454
+ ).to(device)
455
+ discriminator = Discriminator(
456
+ args.size, channel_multiplier=args.channel_multiplier
457
+ ).to(device)
458
+ g_ema = Generator(
459
+ args.size, args.latent, args.n_mlp, channel_multiplier=args.channel_multiplier
460
+ ).to(device)
461
+ g_ema.eval()
462
+ accumulate(g_ema, generator, 0)
463
+
464
+ g_reg_ratio = args.g_reg_every / (args.g_reg_every + 1)
465
+ d_reg_ratio = args.d_reg_every / (args.d_reg_every + 1)
466
+
467
+ g_optim = optim.Adam(
468
+ generator.parameters(),
469
+ lr=args.lr * g_reg_ratio,
470
+ betas=(0 ** g_reg_ratio, 0.99 ** g_reg_ratio),
471
+ )
472
+ d_optim = optim.Adam(
473
+ discriminator.parameters(),
474
+ lr=args.lr * d_reg_ratio,
475
+ betas=(0 ** d_reg_ratio, 0.99 ** d_reg_ratio),
476
+ )
477
+
478
+ if args.ckpt is not None:
479
+ print("load model:", args.ckpt)
480
+
481
+ ckpt = torch.load(args.ckpt, map_location=lambda storage, loc: storage)
482
+
483
+ try:
484
+ ckpt_name = os.path.basename(args.ckpt)
485
+ args.start_iter = int(os.path.splitext(ckpt_name)[0])
486
+
487
+ except ValueError:
488
+ pass
489
+
490
+ generator.load_state_dict(ckpt["g"])
491
+ discriminator.load_state_dict(ckpt["d"])
492
+ g_ema.load_state_dict(ckpt["g_ema"])
493
+
494
+ g_optim.load_state_dict(ckpt["g_optim"])
495
+ d_optim.load_state_dict(ckpt["d_optim"])
496
+
497
+ if args.distributed:
498
+ generator = nn.parallel.DistributedDataParallel(
499
+ generator,
500
+ device_ids=[args.local_rank],
501
+ output_device=args.local_rank,
502
+ broadcast_buffers=False,
503
+ )
504
+
505
+ discriminator = nn.parallel.DistributedDataParallel(
506
+ discriminator,
507
+ device_ids=[args.local_rank],
508
+ output_device=args.local_rank,
509
+ broadcast_buffers=False,
510
+ )
511
+
512
+ transform = transforms.Compose(
513
+ [
514
+ transforms.RandomHorizontalFlip(),
515
+ transforms.ToTensor(),
516
+ transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True),
517
+ ]
518
+ )
519
+
520
+ dataset = MultiResolutionDataset(args.path, transform, args.size)
521
+ loader = data.DataLoader(
522
+ dataset,
523
+ batch_size=args.batch,
524
+ sampler=data_sampler(dataset, shuffle=True, distributed=args.distributed),
525
+ drop_last=True,
526
+ )
527
+
528
+ if get_rank() == 0 and wandb is not None and args.wandb:
529
+ wandb.init(project="stylegan 2")
530
+
531
+ train(args, loader, generator, discriminator, g_optim, d_optim, g_ema, device)