Spaces:
Runtime error
Runtime error
Bill Ton Hoang Nguyen
commited on
Commit
•
e5e3367
1
Parent(s):
e36fc61
LOL
Browse files- .DS_Store +0 -0
- models/.DS_Store +0 -0
- models/stylegan2/.DS_Store +0 -0
- models/stylegan2/stylegan2-pytorch/.DS_Store +0 -0
- models/stylegan2/stylegan2-pytorch/LICENSE +21 -0
- models/stylegan2/stylegan2-pytorch/LICENSE-FID +201 -0
- models/stylegan2/stylegan2-pytorch/LICENSE-LPIPS +24 -0
- models/stylegan2/stylegan2-pytorch/LICENSE-NVIDIA +101 -0
- models/stylegan2/stylegan2-pytorch/README.md +83 -0
- models/stylegan2/stylegan2-pytorch/calc_inception.py +116 -0
- models/stylegan2/stylegan2-pytorch/checkpoint/.gitignore +1 -0
- models/stylegan2/stylegan2-pytorch/convert_weight.py +283 -0
- models/stylegan2/stylegan2-pytorch/dataset.py +40 -0
- models/stylegan2/stylegan2-pytorch/distributed.py +126 -0
- models/stylegan2/stylegan2-pytorch/fid.py +107 -0
- models/stylegan2/stylegan2-pytorch/generate.py +55 -0
- models/stylegan2/stylegan2-pytorch/inception.py +310 -0
- models/stylegan2/stylegan2-pytorch/lpips/__init__.py +160 -0
- models/stylegan2/stylegan2-pytorch/lpips/base_model.py +58 -0
- models/stylegan2/stylegan2-pytorch/lpips/dist_model.py +284 -0
- models/stylegan2/stylegan2-pytorch/lpips/networks_basic.py +187 -0
- models/stylegan2/stylegan2-pytorch/lpips/pretrained_networks.py +181 -0
- models/stylegan2/stylegan2-pytorch/lpips/weights/v0.0/alex.pth +0 -0
- models/stylegan2/stylegan2-pytorch/lpips/weights/v0.0/squeeze.pth +0 -0
- models/stylegan2/stylegan2-pytorch/lpips/weights/v0.0/vgg.pth +0 -0
- models/stylegan2/stylegan2-pytorch/lpips/weights/v0.1/alex.pth +0 -0
- models/stylegan2/stylegan2-pytorch/lpips/weights/v0.1/squeeze.pth +0 -0
- models/stylegan2/stylegan2-pytorch/lpips/weights/v0.1/vgg.pth +0 -0
- models/stylegan2/stylegan2-pytorch/model.py +703 -0
- models/stylegan2/stylegan2-pytorch/non_leaking.py +137 -0
- models/stylegan2/stylegan2-pytorch/op/__init__.py +2 -0
- models/stylegan2/stylegan2-pytorch/op/fused_act.py +92 -0
- models/stylegan2/stylegan2-pytorch/op/fused_bias_act.cpp +21 -0
- models/stylegan2/stylegan2-pytorch/op/fused_bias_act_kernel.cu +99 -0
- models/stylegan2/stylegan2-pytorch/op/setup.py +33 -0
- models/stylegan2/stylegan2-pytorch/op/upfirdn2d.cpp +23 -0
- models/stylegan2/stylegan2-pytorch/op/upfirdn2d.py +198 -0
- models/stylegan2/stylegan2-pytorch/op/upfirdn2d_kernel.cu +369 -0
- models/stylegan2/stylegan2-pytorch/ppl.py +104 -0
- models/stylegan2/stylegan2-pytorch/prepare_data.py +82 -0
- models/stylegan2/stylegan2-pytorch/projector.py +203 -0
- models/stylegan2/stylegan2-pytorch/sample/.gitignore +1 -0
- models/stylegan2/stylegan2-pytorch/train.py +413 -0
.DS_Store
CHANGED
Binary files a/.DS_Store and b/.DS_Store differ
|
|
models/.DS_Store
CHANGED
Binary files a/models/.DS_Store and b/models/.DS_Store differ
|
|
models/stylegan2/.DS_Store
CHANGED
Binary files a/models/stylegan2/.DS_Store and b/models/stylegan2/.DS_Store differ
|
|
models/stylegan2/stylegan2-pytorch/.DS_Store
ADDED
Binary file (8.2 kB). View file
|
|
models/stylegan2/stylegan2-pytorch/LICENSE
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
MIT License
|
2 |
+
|
3 |
+
Copyright (c) 2019 Kim Seonghyeon
|
4 |
+
|
5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
6 |
+
of this software and associated documentation files (the "Software"), to deal
|
7 |
+
in the Software without restriction, including without limitation the rights
|
8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
9 |
+
copies of the Software, and to permit persons to whom the Software is
|
10 |
+
furnished to do so, subject to the following conditions:
|
11 |
+
|
12 |
+
The above copyright notice and this permission notice shall be included in all
|
13 |
+
copies or substantial portions of the Software.
|
14 |
+
|
15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
21 |
+
SOFTWARE.
|
models/stylegan2/stylegan2-pytorch/LICENSE-FID
ADDED
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Apache License
|
2 |
+
Version 2.0, January 2004
|
3 |
+
http://www.apache.org/licenses/
|
4 |
+
|
5 |
+
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
6 |
+
|
7 |
+
1. Definitions.
|
8 |
+
|
9 |
+
"License" shall mean the terms and conditions for use, reproduction,
|
10 |
+
and distribution as defined by Sections 1 through 9 of this document.
|
11 |
+
|
12 |
+
"Licensor" shall mean the copyright owner or entity authorized by
|
13 |
+
the copyright owner that is granting the License.
|
14 |
+
|
15 |
+
"Legal Entity" shall mean the union of the acting entity and all
|
16 |
+
other entities that control, are controlled by, or are under common
|
17 |
+
control with that entity. For the purposes of this definition,
|
18 |
+
"control" means (i) the power, direct or indirect, to cause the
|
19 |
+
direction or management of such entity, whether by contract or
|
20 |
+
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
21 |
+
outstanding shares, or (iii) beneficial ownership of such entity.
|
22 |
+
|
23 |
+
"You" (or "Your") shall mean an individual or Legal Entity
|
24 |
+
exercising permissions granted by this License.
|
25 |
+
|
26 |
+
"Source" form shall mean the preferred form for making modifications,
|
27 |
+
including but not limited to software source code, documentation
|
28 |
+
source, and configuration files.
|
29 |
+
|
30 |
+
"Object" form shall mean any form resulting from mechanical
|
31 |
+
transformation or translation of a Source form, including but
|
32 |
+
not limited to compiled object code, generated documentation,
|
33 |
+
and conversions to other media types.
|
34 |
+
|
35 |
+
"Work" shall mean the work of authorship, whether in Source or
|
36 |
+
Object form, made available under the License, as indicated by a
|
37 |
+
copyright notice that is included in or attached to the work
|
38 |
+
(an example is provided in the Appendix below).
|
39 |
+
|
40 |
+
"Derivative Works" shall mean any work, whether in Source or Object
|
41 |
+
form, that is based on (or derived from) the Work and for which the
|
42 |
+
editorial revisions, annotations, elaborations, or other modifications
|
43 |
+
represent, as a whole, an original work of authorship. For the purposes
|
44 |
+
of this License, Derivative Works shall not include works that remain
|
45 |
+
separable from, or merely link (or bind by name) to the interfaces of,
|
46 |
+
the Work and Derivative Works thereof.
|
47 |
+
|
48 |
+
"Contribution" shall mean any work of authorship, including
|
49 |
+
the original version of the Work and any modifications or additions
|
50 |
+
to that Work or Derivative Works thereof, that is intentionally
|
51 |
+
submitted to Licensor for inclusion in the Work by the copyright owner
|
52 |
+
or by an individual or Legal Entity authorized to submit on behalf of
|
53 |
+
the copyright owner. For the purposes of this definition, "submitted"
|
54 |
+
means any form of electronic, verbal, or written communication sent
|
55 |
+
to the Licensor or its representatives, including but not limited to
|
56 |
+
communication on electronic mailing lists, source code control systems,
|
57 |
+
and issue tracking systems that are managed by, or on behalf of, the
|
58 |
+
Licensor for the purpose of discussing and improving the Work, but
|
59 |
+
excluding communication that is conspicuously marked or otherwise
|
60 |
+
designated in writing by the copyright owner as "Not a Contribution."
|
61 |
+
|
62 |
+
"Contributor" shall mean Licensor and any individual or Legal Entity
|
63 |
+
on behalf of whom a Contribution has been received by Licensor and
|
64 |
+
subsequently incorporated within the Work.
|
65 |
+
|
66 |
+
2. Grant of Copyright License. Subject to the terms and conditions of
|
67 |
+
this License, each Contributor hereby grants to You a perpetual,
|
68 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
69 |
+
copyright license to reproduce, prepare Derivative Works of,
|
70 |
+
publicly display, publicly perform, sublicense, and distribute the
|
71 |
+
Work and such Derivative Works in Source or Object form.
|
72 |
+
|
73 |
+
3. Grant of Patent License. Subject to the terms and conditions of
|
74 |
+
this License, each Contributor hereby grants to You a perpetual,
|
75 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
76 |
+
(except as stated in this section) patent license to make, have made,
|
77 |
+
use, offer to sell, sell, import, and otherwise transfer the Work,
|
78 |
+
where such license applies only to those patent claims licensable
|
79 |
+
by such Contributor that are necessarily infringed by their
|
80 |
+
Contribution(s) alone or by combination of their Contribution(s)
|
81 |
+
with the Work to which such Contribution(s) was submitted. If You
|
82 |
+
institute patent litigation against any entity (including a
|
83 |
+
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
84 |
+
or a Contribution incorporated within the Work constitutes direct
|
85 |
+
or contributory patent infringement, then any patent licenses
|
86 |
+
granted to You under this License for that Work shall terminate
|
87 |
+
as of the date such litigation is filed.
|
88 |
+
|
89 |
+
4. Redistribution. You may reproduce and distribute copies of the
|
90 |
+
Work or Derivative Works thereof in any medium, with or without
|
91 |
+
modifications, and in Source or Object form, provided that You
|
92 |
+
meet the following conditions:
|
93 |
+
|
94 |
+
(a) You must give any other recipients of the Work or
|
95 |
+
Derivative Works a copy of this License; and
|
96 |
+
|
97 |
+
(b) You must cause any modified files to carry prominent notices
|
98 |
+
stating that You changed the files; and
|
99 |
+
|
100 |
+
(c) You must retain, in the Source form of any Derivative Works
|
101 |
+
that You distribute, all copyright, patent, trademark, and
|
102 |
+
attribution notices from the Source form of the Work,
|
103 |
+
excluding those notices that do not pertain to any part of
|
104 |
+
the Derivative Works; and
|
105 |
+
|
106 |
+
(d) If the Work includes a "NOTICE" text file as part of its
|
107 |
+
distribution, then any Derivative Works that You distribute must
|
108 |
+
include a readable copy of the attribution notices contained
|
109 |
+
within such NOTICE file, excluding those notices that do not
|
110 |
+
pertain to any part of the Derivative Works, in at least one
|
111 |
+
of the following places: within a NOTICE text file distributed
|
112 |
+
as part of the Derivative Works; within the Source form or
|
113 |
+
documentation, if provided along with the Derivative Works; or,
|
114 |
+
within a display generated by the Derivative Works, if and
|
115 |
+
wherever such third-party notices normally appear. The contents
|
116 |
+
of the NOTICE file are for informational purposes only and
|
117 |
+
do not modify the License. You may add Your own attribution
|
118 |
+
notices within Derivative Works that You distribute, alongside
|
119 |
+
or as an addendum to the NOTICE text from the Work, provided
|
120 |
+
that such additional attribution notices cannot be construed
|
121 |
+
as modifying the License.
|
122 |
+
|
123 |
+
You may add Your own copyright statement to Your modifications and
|
124 |
+
may provide additional or different license terms and conditions
|
125 |
+
for use, reproduction, or distribution of Your modifications, or
|
126 |
+
for any such Derivative Works as a whole, provided Your use,
|
127 |
+
reproduction, and distribution of the Work otherwise complies with
|
128 |
+
the conditions stated in this License.
|
129 |
+
|
130 |
+
5. Submission of Contributions. Unless You explicitly state otherwise,
|
131 |
+
any Contribution intentionally submitted for inclusion in the Work
|
132 |
+
by You to the Licensor shall be under the terms and conditions of
|
133 |
+
this License, without any additional terms or conditions.
|
134 |
+
Notwithstanding the above, nothing herein shall supersede or modify
|
135 |
+
the terms of any separate license agreement you may have executed
|
136 |
+
with Licensor regarding such Contributions.
|
137 |
+
|
138 |
+
6. Trademarks. This License does not grant permission to use the trade
|
139 |
+
names, trademarks, service marks, or product names of the Licensor,
|
140 |
+
except as required for reasonable and customary use in describing the
|
141 |
+
origin of the Work and reproducing the content of the NOTICE file.
|
142 |
+
|
143 |
+
7. Disclaimer of Warranty. Unless required by applicable law or
|
144 |
+
agreed to in writing, Licensor provides the Work (and each
|
145 |
+
Contributor provides its Contributions) on an "AS IS" BASIS,
|
146 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
147 |
+
implied, including, without limitation, any warranties or conditions
|
148 |
+
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
149 |
+
PARTICULAR PURPOSE. You are solely responsible for determining the
|
150 |
+
appropriateness of using or redistributing the Work and assume any
|
151 |
+
risks associated with Your exercise of permissions under this License.
|
152 |
+
|
153 |
+
8. Limitation of Liability. In no event and under no legal theory,
|
154 |
+
whether in tort (including negligence), contract, or otherwise,
|
155 |
+
unless required by applicable law (such as deliberate and grossly
|
156 |
+
negligent acts) or agreed to in writing, shall any Contributor be
|
157 |
+
liable to You for damages, including any direct, indirect, special,
|
158 |
+
incidental, or consequential damages of any character arising as a
|
159 |
+
result of this License or out of the use or inability to use the
|
160 |
+
Work (including but not limited to damages for loss of goodwill,
|
161 |
+
work stoppage, computer failure or malfunction, or any and all
|
162 |
+
other commercial damages or losses), even if such Contributor
|
163 |
+
has been advised of the possibility of such damages.
|
164 |
+
|
165 |
+
9. Accepting Warranty or Additional Liability. While redistributing
|
166 |
+
the Work or Derivative Works thereof, You may choose to offer,
|
167 |
+
and charge a fee for, acceptance of support, warranty, indemnity,
|
168 |
+
or other liability obligations and/or rights consistent with this
|
169 |
+
License. However, in accepting such obligations, You may act only
|
170 |
+
on Your own behalf and on Your sole responsibility, not on behalf
|
171 |
+
of any other Contributor, and only if You agree to indemnify,
|
172 |
+
defend, and hold each Contributor harmless for any liability
|
173 |
+
incurred by, or claims asserted against, such Contributor by reason
|
174 |
+
of your accepting any such warranty or additional liability.
|
175 |
+
|
176 |
+
END OF TERMS AND CONDITIONS
|
177 |
+
|
178 |
+
APPENDIX: How to apply the Apache License to your work.
|
179 |
+
|
180 |
+
To apply the Apache License to your work, attach the following
|
181 |
+
boilerplate notice, with the fields enclosed by brackets "[]"
|
182 |
+
replaced with your own identifying information. (Don't include
|
183 |
+
the brackets!) The text should be enclosed in the appropriate
|
184 |
+
comment syntax for the file format. We also recommend that a
|
185 |
+
file or class name and description of purpose be included on the
|
186 |
+
same "printed page" as the copyright notice for easier
|
187 |
+
identification within third-party archives.
|
188 |
+
|
189 |
+
Copyright [yyyy] [name of copyright owner]
|
190 |
+
|
191 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
192 |
+
you may not use this file except in compliance with the License.
|
193 |
+
You may obtain a copy of the License at
|
194 |
+
|
195 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
196 |
+
|
197 |
+
Unless required by applicable law or agreed to in writing, software
|
198 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
199 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
200 |
+
See the License for the specific language governing permissions and
|
201 |
+
limitations under the License.
|
models/stylegan2/stylegan2-pytorch/LICENSE-LPIPS
ADDED
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Copyright (c) 2018, Richard Zhang, Phillip Isola, Alexei A. Efros, Eli Shechtman, Oliver Wang
|
2 |
+
All rights reserved.
|
3 |
+
|
4 |
+
Redistribution and use in source and binary forms, with or without
|
5 |
+
modification, are permitted provided that the following conditions are met:
|
6 |
+
|
7 |
+
* Redistributions of source code must retain the above copyright notice, this
|
8 |
+
list of conditions and the following disclaimer.
|
9 |
+
|
10 |
+
* Redistributions in binary form must reproduce the above copyright notice,
|
11 |
+
this list of conditions and the following disclaimer in the documentation
|
12 |
+
and/or other materials provided with the distribution.
|
13 |
+
|
14 |
+
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
15 |
+
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
16 |
+
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
17 |
+
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
18 |
+
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
19 |
+
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
20 |
+
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
21 |
+
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
22 |
+
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
23 |
+
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
24 |
+
|
models/stylegan2/stylegan2-pytorch/LICENSE-NVIDIA
ADDED
@@ -0,0 +1,101 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
|
2 |
+
|
3 |
+
|
4 |
+
Nvidia Source Code License-NC
|
5 |
+
|
6 |
+
=======================================================================
|
7 |
+
|
8 |
+
1. Definitions
|
9 |
+
|
10 |
+
"Licensor" means any person or entity that distributes its Work.
|
11 |
+
|
12 |
+
"Software" means the original work of authorship made available under
|
13 |
+
this License.
|
14 |
+
|
15 |
+
"Work" means the Software and any additions to or derivative works of
|
16 |
+
the Software that are made available under this License.
|
17 |
+
|
18 |
+
"Nvidia Processors" means any central processing unit (CPU), graphics
|
19 |
+
processing unit (GPU), field-programmable gate array (FPGA),
|
20 |
+
application-specific integrated circuit (ASIC) or any combination
|
21 |
+
thereof designed, made, sold, or provided by Nvidia or its affiliates.
|
22 |
+
|
23 |
+
The terms "reproduce," "reproduction," "derivative works," and
|
24 |
+
"distribution" have the meaning as provided under U.S. copyright law;
|
25 |
+
provided, however, that for the purposes of this License, derivative
|
26 |
+
works shall not include works that remain separable from, or merely
|
27 |
+
link (or bind by name) to the interfaces of, the Work.
|
28 |
+
|
29 |
+
Works, including the Software, are "made available" under this License
|
30 |
+
by including in or with the Work either (a) a copyright notice
|
31 |
+
referencing the applicability of this License to the Work, or (b) a
|
32 |
+
copy of this License.
|
33 |
+
|
34 |
+
2. License Grants
|
35 |
+
|
36 |
+
2.1 Copyright Grant. Subject to the terms and conditions of this
|
37 |
+
License, each Licensor grants to you a perpetual, worldwide,
|
38 |
+
non-exclusive, royalty-free, copyright license to reproduce,
|
39 |
+
prepare derivative works of, publicly display, publicly perform,
|
40 |
+
sublicense and distribute its Work and any resulting derivative
|
41 |
+
works in any form.
|
42 |
+
|
43 |
+
3. Limitations
|
44 |
+
|
45 |
+
3.1 Redistribution. You may reproduce or distribute the Work only
|
46 |
+
if (a) you do so under this License, (b) you include a complete
|
47 |
+
copy of this License with your distribution, and (c) you retain
|
48 |
+
without modification any copyright, patent, trademark, or
|
49 |
+
attribution notices that are present in the Work.
|
50 |
+
|
51 |
+
3.2 Derivative Works. You may specify that additional or different
|
52 |
+
terms apply to the use, reproduction, and distribution of your
|
53 |
+
derivative works of the Work ("Your Terms") only if (a) Your Terms
|
54 |
+
provide that the use limitation in Section 3.3 applies to your
|
55 |
+
derivative works, and (b) you identify the specific derivative
|
56 |
+
works that are subject to Your Terms. Notwithstanding Your Terms,
|
57 |
+
this License (including the redistribution requirements in Section
|
58 |
+
3.1) will continue to apply to the Work itself.
|
59 |
+
|
60 |
+
3.3 Use Limitation. The Work and any derivative works thereof only
|
61 |
+
may be used or intended for use non-commercially. The Work or
|
62 |
+
derivative works thereof may be used or intended for use by Nvidia
|
63 |
+
or its affiliates commercially or non-commercially. As used herein,
|
64 |
+
"non-commercially" means for research or evaluation purposes only.
|
65 |
+
|
66 |
+
3.4 Patent Claims. If you bring or threaten to bring a patent claim
|
67 |
+
against any Licensor (including any claim, cross-claim or
|
68 |
+
counterclaim in a lawsuit) to enforce any patents that you allege
|
69 |
+
are infringed by any Work, then your rights under this License from
|
70 |
+
such Licensor (including the grants in Sections 2.1 and 2.2) will
|
71 |
+
terminate immediately.
|
72 |
+
|
73 |
+
3.5 Trademarks. This License does not grant any rights to use any
|
74 |
+
Licensor's or its affiliates' names, logos, or trademarks, except
|
75 |
+
as necessary to reproduce the notices described in this License.
|
76 |
+
|
77 |
+
3.6 Termination. If you violate any term of this License, then your
|
78 |
+
rights under this License (including the grants in Sections 2.1 and
|
79 |
+
2.2) will terminate immediately.
|
80 |
+
|
81 |
+
4. Disclaimer of Warranty.
|
82 |
+
|
83 |
+
THE WORK IS PROVIDED "AS IS" WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
84 |
+
KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WARRANTIES OR CONDITIONS OF
|
85 |
+
MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE OR
|
86 |
+
NON-INFRINGEMENT. YOU BEAR THE RISK OF UNDERTAKING ANY ACTIVITIES UNDER
|
87 |
+
THIS LICENSE.
|
88 |
+
|
89 |
+
5. Limitation of Liability.
|
90 |
+
|
91 |
+
EXCEPT AS PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL
|
92 |
+
THEORY, WHETHER IN TORT (INCLUDING NEGLIGENCE), CONTRACT, OR OTHERWISE
|
93 |
+
SHALL ANY LICENSOR BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY DIRECT,
|
94 |
+
INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES ARISING OUT OF
|
95 |
+
OR RELATED TO THIS LICENSE, THE USE OR INABILITY TO USE THE WORK
|
96 |
+
(INCLUDING BUT NOT LIMITED TO LOSS OF GOODWILL, BUSINESS INTERRUPTION,
|
97 |
+
LOST PROFITS OR DATA, COMPUTER FAILURE OR MALFUNCTION, OR ANY OTHER
|
98 |
+
COMMERCIAL DAMAGES OR LOSSES), EVEN IF THE LICENSOR HAS BEEN ADVISED OF
|
99 |
+
THE POSSIBILITY OF SUCH DAMAGES.
|
100 |
+
|
101 |
+
=======================================================================
|
models/stylegan2/stylegan2-pytorch/README.md
ADDED
@@ -0,0 +1,83 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# StyleGAN 2 in PyTorch
|
2 |
+
|
3 |
+
Implementation of Analyzing and Improving the Image Quality of StyleGAN (https://arxiv.org/abs/1912.04958) in PyTorch
|
4 |
+
|
5 |
+
## Notice
|
6 |
+
|
7 |
+
I have tried to match official implementation as close as possible, but maybe there are some details I missed. So please use this implementation with care.
|
8 |
+
|
9 |
+
## Requirements
|
10 |
+
|
11 |
+
I have tested on:
|
12 |
+
|
13 |
+
* PyTorch 1.3.1
|
14 |
+
* CUDA 10.1/10.2
|
15 |
+
|
16 |
+
## Usage
|
17 |
+
|
18 |
+
First create lmdb datasets:
|
19 |
+
|
20 |
+
> python prepare_data.py --out LMDB_PATH --n_worker N_WORKER --size SIZE1,SIZE2,SIZE3,... DATASET_PATH
|
21 |
+
|
22 |
+
This will convert images to jpeg and pre-resizes it. This implementation does not use progressive growing, but you can create multiple resolution datasets using size arguments with comma separated lists, for the cases that you want to try another resolutions later.
|
23 |
+
|
24 |
+
Then you can train model in distributed settings
|
25 |
+
|
26 |
+
> python -m torch.distributed.launch --nproc_per_node=N_GPU --master_port=PORT train.py --batch BATCH_SIZE LMDB_PATH
|
27 |
+
|
28 |
+
train.py supports Weights & Biases logging. If you want to use it, add --wandb arguments to the script.
|
29 |
+
|
30 |
+
### Convert weight from official checkpoints
|
31 |
+
|
32 |
+
You need to clone official repositories, (https://github.com/NVlabs/stylegan2) as it is requires for load official checkpoints.
|
33 |
+
|
34 |
+
Next, create a conda environment with TF-GPU and Torch-CPU (using GPU for both results in CUDA version mismatches):<br>
|
35 |
+
`conda create -n tf_torch python=3.7 requests tensorflow-gpu=1.14 cudatoolkit=10.0 numpy=1.14 pytorch=1.6 torchvision cpuonly -c pytorch`
|
36 |
+
|
37 |
+
For example, if you cloned repositories in ~/stylegan2 and downloaded stylegan2-ffhq-config-f.pkl, You can convert it like this:
|
38 |
+
|
39 |
+
> python convert_weight.py --repo ~/stylegan2 stylegan2-ffhq-config-f.pkl
|
40 |
+
|
41 |
+
This will create converted stylegan2-ffhq-config-f.pt file.
|
42 |
+
|
43 |
+
If using GCC, you might have to set `-D_GLIBCXX_USE_CXX11_ABI=1` in `~/stylegan2/dnnlib/tflib/custom_ops.py`.
|
44 |
+
|
45 |
+
### Generate samples
|
46 |
+
|
47 |
+
> python generate.py --sample N_FACES --pics N_PICS --ckpt PATH_CHECKPOINT
|
48 |
+
|
49 |
+
You should change your size (--size 256 for example) if you train with another dimension.
|
50 |
+
|
51 |
+
### Project images to latent spaces
|
52 |
+
|
53 |
+
> python projector.py --ckpt [CHECKPOINT] --size [GENERATOR_OUTPUT_SIZE] FILE1 FILE2 ...
|
54 |
+
|
55 |
+
## Pretrained Checkpoints
|
56 |
+
|
57 |
+
[Link](https://drive.google.com/open?id=1PQutd-JboOCOZqmd95XWxWrO8gGEvRcO)
|
58 |
+
|
59 |
+
I have trained the 256px model on FFHQ 550k iterations. I got FID about 4.5. Maybe data preprocessing, resolution, training loop could made this difference, but currently I don't know the exact reason of FID differences.
|
60 |
+
|
61 |
+
## Samples
|
62 |
+
|
63 |
+
![Sample with truncation](doc/sample.png)
|
64 |
+
|
65 |
+
At 110,000 iterations. (trained on 3.52M images)
|
66 |
+
|
67 |
+
### Samples from converted weights
|
68 |
+
|
69 |
+
![Sample from FFHQ](doc/stylegan2-ffhq-config-f.png)
|
70 |
+
|
71 |
+
Sample from FFHQ (1024px)
|
72 |
+
|
73 |
+
![Sample from LSUN Church](doc/stylegan2-church-config-f.png)
|
74 |
+
|
75 |
+
Sample from LSUN Church (256px)
|
76 |
+
|
77 |
+
## License
|
78 |
+
|
79 |
+
Model details and custom CUDA kernel codes are from official repostiories: https://github.com/NVlabs/stylegan2
|
80 |
+
|
81 |
+
Codes for Learned Perceptual Image Patch Similarity, LPIPS came from https://github.com/richzhang/PerceptualSimilarity
|
82 |
+
|
83 |
+
To match FID scores more closely to tensorflow official implementations, I have used FID Inception V3 implementations in https://github.com/mseitzer/pytorch-fid
|
models/stylegan2/stylegan2-pytorch/calc_inception.py
ADDED
@@ -0,0 +1,116 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import pickle
|
3 |
+
import os
|
4 |
+
|
5 |
+
import torch
|
6 |
+
from torch import nn
|
7 |
+
from torch.nn import functional as F
|
8 |
+
from torch.utils.data import DataLoader
|
9 |
+
from torchvision import transforms
|
10 |
+
from torchvision.models import inception_v3, Inception3
|
11 |
+
import numpy as np
|
12 |
+
from tqdm import tqdm
|
13 |
+
|
14 |
+
from inception import InceptionV3
|
15 |
+
from dataset import MultiResolutionDataset
|
16 |
+
|
17 |
+
|
18 |
+
class Inception3Feature(Inception3):
|
19 |
+
def forward(self, x):
|
20 |
+
if x.shape[2] != 299 or x.shape[3] != 299:
|
21 |
+
x = F.interpolate(x, size=(299, 299), mode='bilinear', align_corners=True)
|
22 |
+
|
23 |
+
x = self.Conv2d_1a_3x3(x) # 299 x 299 x 3
|
24 |
+
x = self.Conv2d_2a_3x3(x) # 149 x 149 x 32
|
25 |
+
x = self.Conv2d_2b_3x3(x) # 147 x 147 x 32
|
26 |
+
x = F.max_pool2d(x, kernel_size=3, stride=2) # 147 x 147 x 64
|
27 |
+
|
28 |
+
x = self.Conv2d_3b_1x1(x) # 73 x 73 x 64
|
29 |
+
x = self.Conv2d_4a_3x3(x) # 73 x 73 x 80
|
30 |
+
x = F.max_pool2d(x, kernel_size=3, stride=2) # 71 x 71 x 192
|
31 |
+
|
32 |
+
x = self.Mixed_5b(x) # 35 x 35 x 192
|
33 |
+
x = self.Mixed_5c(x) # 35 x 35 x 256
|
34 |
+
x = self.Mixed_5d(x) # 35 x 35 x 288
|
35 |
+
|
36 |
+
x = self.Mixed_6a(x) # 35 x 35 x 288
|
37 |
+
x = self.Mixed_6b(x) # 17 x 17 x 768
|
38 |
+
x = self.Mixed_6c(x) # 17 x 17 x 768
|
39 |
+
x = self.Mixed_6d(x) # 17 x 17 x 768
|
40 |
+
x = self.Mixed_6e(x) # 17 x 17 x 768
|
41 |
+
|
42 |
+
x = self.Mixed_7a(x) # 17 x 17 x 768
|
43 |
+
x = self.Mixed_7b(x) # 8 x 8 x 1280
|
44 |
+
x = self.Mixed_7c(x) # 8 x 8 x 2048
|
45 |
+
|
46 |
+
x = F.avg_pool2d(x, kernel_size=8) # 8 x 8 x 2048
|
47 |
+
|
48 |
+
return x.view(x.shape[0], x.shape[1]) # 1 x 1 x 2048
|
49 |
+
|
50 |
+
|
51 |
+
def load_patched_inception_v3():
|
52 |
+
# inception = inception_v3(pretrained=True)
|
53 |
+
# inception_feat = Inception3Feature()
|
54 |
+
# inception_feat.load_state_dict(inception.state_dict())
|
55 |
+
inception_feat = InceptionV3([3], normalize_input=False)
|
56 |
+
|
57 |
+
return inception_feat
|
58 |
+
|
59 |
+
|
60 |
+
@torch.no_grad()
def extract_features(loader, inception, device):
    """Run every batch from `loader` through `inception` and return the
    concatenated features as a single CPU tensor of shape (N, feat_dim)."""
    chunks = []

    for img in tqdm(loader):
        feat = inception(img.to(device))[0].view(img.shape[0], -1)
        chunks.append(feat.to('cpu'))

    return torch.cat(chunks, 0)
|
74 |
+
|
75 |
+
|
76 |
+
if __name__ == '__main__':
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    parser = argparse.ArgumentParser(
        description='Calculate Inception v3 features for datasets'
    )
    parser.add_argument('--size', type=int, default=256)
    parser.add_argument('--batch', default=64, type=int, help='batch size')
    parser.add_argument('--n_sample', type=int, default=50000)
    parser.add_argument('--flip', action='store_true')
    parser.add_argument('path', metavar='PATH', help='path to datset lmdb file')

    args = parser.parse_args()

    inception = load_patched_inception_v3()
    inception = nn.DataParallel(inception).eval().to(device)

    # Scale images to [-1, 1]; the patched inception skips its own renorm.
    transform = transforms.Compose(
        [
            transforms.RandomHorizontalFlip(p=0.5 if args.flip else 0),
            transforms.ToTensor(),
            transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
        ]
    )

    dset = MultiResolutionDataset(args.path, transform=transform, resolution=args.size)
    loader = DataLoader(dset, batch_size=args.batch, num_workers=4)

    features = extract_features(loader, inception, device).numpy()

    # Keep at most n_sample feature vectors for the statistics.
    features = features[: args.n_sample]

    print(f'extracted {features.shape[0]} features')

    mean = np.mean(features, 0)
    cov = np.cov(features, rowvar=False)

    # Cache the real-data statistics; fid.py consumes this pickle.
    name = os.path.splitext(os.path.basename(args.path))[0]

    with open(f'inception_{name}.pkl', 'wb') as f:
        pickle.dump({'mean': mean, 'cov': cov, 'size': args.size, 'path': args.path}, f)
|
models/stylegan2/stylegan2-pytorch/checkpoint/.gitignore
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
*.pt
|
models/stylegan2/stylegan2-pytorch/convert_weight.py
ADDED
@@ -0,0 +1,283 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import os
|
3 |
+
import sys
|
4 |
+
import pickle
|
5 |
+
import math
|
6 |
+
|
7 |
+
import torch
|
8 |
+
import numpy as np
|
9 |
+
from torchvision import utils
|
10 |
+
|
11 |
+
from model import Generator, Discriminator
|
12 |
+
|
13 |
+
|
14 |
+
def convert_modconv(vars, source_name, target_name, flip=False):
    """Convert a TF StyleGAN2 modulated-conv layer into PyTorch state-dict entries.

    `vars` maps TF variable names to variables; values are fetched via
    .value().eval(). Returns a dict of torch tensors keyed by
    '<target_name>.<param>'. When `flip` is set, the conv kernel is flipped
    spatially (needed for the upsampling convolutions).
    """
    def fetch(suffix):
        return vars[source_name + '/' + suffix].value().eval()

    converted = {
        'conv.weight': np.expand_dims(fetch('weight').transpose((3, 2, 0, 1)), 0),
        'conv.modulation.weight': fetch('mod_weight').transpose((1, 0)),
        'conv.modulation.bias': fetch('mod_bias') + 1,
        'noise.weight': np.array([fetch('noise_strength')]),
        'activate.bias': fetch('bias'),
    }

    out = {
        f'{target_name}.{key}': torch.from_numpy(value)
        for key, value in converted.items()
    }

    if flip:
        weight_key = f'{target_name}.conv.weight'
        out[weight_key] = torch.flip(out[weight_key], [3, 4])

    return out
|
40 |
+
|
41 |
+
|
42 |
+
def convert_conv(vars, source_name, target_name, bias=True, start=0):
    """Convert a plain TF conv layer (HWIO weights) to PyTorch (OIHW) entries.

    The kernel lands at '<target_name>.<start>.weight'; the optional bias at
    '<target_name>.<start + 1>.bias' (in the PyTorch model the bias lives in
    the activation module that follows the conv).
    """
    weight = vars[source_name + '/weight'].value().eval()

    out = {
        f'{target_name}.{start}.weight': torch.from_numpy(
            weight.transpose((3, 2, 0, 1))
        )
    }

    if bias:
        bias_val = vars[source_name + '/bias'].value().eval()
        out[f'{target_name}.{start + 1}.bias'] = torch.from_numpy(bias_val)

    return out
|
58 |
+
|
59 |
+
|
60 |
+
def convert_torgb(vars, source_name, target_name):
    """Convert a TF ToRGB (modulated 1x1 conv) layer to PyTorch entries.

    The output bias is reshaped to (1, 3, 1, 1) so it broadcasts over a batch
    of RGB images in the PyTorch model.
    """
    def fetch(suffix):
        return vars[source_name + '/' + suffix].value().eval()

    converted = {
        'conv.weight': np.expand_dims(fetch('weight').transpose((3, 2, 0, 1)), 0),
        'conv.modulation.weight': fetch('mod_weight').transpose((1, 0)),
        'conv.modulation.bias': fetch('mod_bias') + 1,
        'bias': fetch('bias').reshape((1, 3, 1, 1)),
    }

    return {
        f'{target_name}.{key}': torch.from_numpy(value)
        for key, value in converted.items()
    }
|
79 |
+
|
80 |
+
|
81 |
+
def convert_dense(vars, source_name, target_name):
    """Convert a TF fully-connected layer to PyTorch Linear entries.

    TF stores dense weights as (in, out); PyTorch Linear expects (out, in),
    hence the transpose.
    """
    weight = vars[source_name + '/weight'].value().eval()
    bias = vars[source_name + '/bias'].value().eval()

    return {
        f'{target_name}.weight': torch.from_numpy(weight.transpose((1, 0))),
        f'{target_name}.bias': torch.from_numpy(bias),
    }
|
93 |
+
|
94 |
+
|
95 |
+
def update(state_dict, new):
    """Copy entries from `new` into `state_dict` in place.

    Raises KeyError for unknown keys and ValueError on shape mismatch so a
    botched conversion fails loudly instead of silently loading garbage.
    """
    for key, value in new.items():
        if key not in state_dict:
            raise KeyError(key + ' is not found')

        if value.shape != state_dict[key].shape:
            raise ValueError(f'Shape mismatch: {value.shape} vs {state_dict[key].shape}')

        state_dict[key] = value
|
104 |
+
|
105 |
+
|
106 |
+
def discriminator_fill_statedict(statedict, vars, size):
    """Populate a PyTorch Discriminator state dict from TF StyleGAN2 variables.

    `size` is the image resolution; residual blocks run from `size` down to
    4x4, each halving the feature map. Returns the mutated `statedict`.
    """
    log_size = int(math.log(size, 2))

    # FromRGB at the input resolution.
    update(statedict, convert_conv(vars, f'{size}x{size}/FromRGB', 'convs.0'))

    conv_i = 1

    # Residual blocks from size/2 down to 8x8.
    for i in range(log_size - 2, 0, -1):
        reso = 4 * 2 ** i
        update(
            statedict,
            convert_conv(vars, f'{reso}x{reso}/Conv0', f'convs.{conv_i}.conv1'),
        )
        update(
            statedict,
            convert_conv(
                vars, f'{reso}x{reso}/Conv1_down', f'convs.{conv_i}.conv2', start=1
            ),
        )
        # Skip connection has no bias in either framework.
        update(
            statedict,
            convert_conv(
                vars, f'{reso}x{reso}/Skip', f'convs.{conv_i}.skip', start=1, bias=False
            ),
        )
        conv_i += 1

    # Final 4x4 conv plus the two dense layers producing the logit.
    update(statedict, convert_conv(vars, f'4x4/Conv', 'final_conv'))
    update(statedict, convert_dense(vars, f'4x4/Dense0', 'final_linear.0'))
    update(statedict, convert_dense(vars, f'Output', 'final_linear.1'))

    return statedict
|
138 |
+
|
139 |
+
|
140 |
+
def fill_statedict(state_dict, vars, size):
    """Populate a PyTorch Generator state dict from TF StyleGAN2 variables.

    Covers the 8-layer mapping network, the learned constant input, every
    ToRGB and modulated conv, and the per-layer noise buffers. Returns the
    mutated `state_dict`.
    """
    log_size = int(math.log(size, 2))

    # Mapping network: TF Dense0..Dense7 -> style.1..style.8
    # (style.0 is the PixelNorm module, which has no parameters).
    for i in range(8):
        update(state_dict, convert_dense(vars, f'G_mapping/Dense{i}', f'style.{i + 1}'))

    # Learned 4x4 constant input of the synthesis network.
    update(
        state_dict,
        {
            'input.input': torch.from_numpy(
                vars['G_synthesis/4x4/Const/const'].value().eval()
            )
        },
    )

    update(state_dict, convert_torgb(vars, 'G_synthesis/4x4/ToRGB', 'to_rgb1'))

    # One ToRGB per upsampling stage (8x8 ... size x size).
    for i in range(log_size - 2):
        reso = 4 * 2 ** (i + 1)
        update(
            state_dict,
            convert_torgb(vars, f'G_synthesis/{reso}x{reso}/ToRGB', f'to_rgbs.{i}'),
        )

    update(state_dict, convert_modconv(vars, 'G_synthesis/4x4/Conv', 'conv1'))

    conv_i = 0

    # Per stage: one upsampling conv (kernel flipped for the transposed
    # convolution) followed by one regular modulated conv.
    for i in range(log_size - 2):
        reso = 4 * 2 ** (i + 1)
        update(
            state_dict,
            convert_modconv(
                vars,
                f'G_synthesis/{reso}x{reso}/Conv0_up',
                f'convs.{conv_i}',
                flip=True,
            ),
        )
        update(
            state_dict,
            convert_modconv(
                vars, f'G_synthesis/{reso}x{reso}/Conv1', f'convs.{conv_i + 1}'
            ),
        )
        conv_i += 2

    # Fixed per-layer noise buffers (one per conv, plus the 4x4 layer).
    for i in range(0, (log_size - 2) * 2 + 1):
        update(
            state_dict,
            {
                f'noises.noise_{i}': torch.from_numpy(
                    vars[f'G_synthesis/noise{i}'].value().eval()
                )
            },
        )

    return state_dict
|
198 |
+
|
199 |
+
|
200 |
+
if __name__ == '__main__':
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print('Using PyTorch device', device)

    parser = argparse.ArgumentParser()
    parser.add_argument('--repo', type=str, required=True)
    parser.add_argument('--gen', action='store_true')
    parser.add_argument('--disc', action='store_true')
    parser.add_argument('--channel_multiplier', type=int, default=2)
    parser.add_argument('path', metavar='PATH')

    args = parser.parse_args()

    # The official TF StyleGAN2 repo provides dnnlib; make it importable.
    sys.path.append(args.repo)

    import dnnlib
    from dnnlib import tflib

    tflib.init_tf()

    # NOTE(review): unpickling a checkpoint executes arbitrary code —
    # only convert pickles from sources you trust.
    with open(args.path, 'rb') as f:
        generator, discriminator, g_ema = pickle.load(f)

    size = g_ema.output_shape[2]

    g = Generator(size, 512, 8, channel_multiplier=args.channel_multiplier)
    state_dict = g.state_dict()
    state_dict = fill_statedict(state_dict, g_ema.vars, size)

    g.load_state_dict(state_dict)

    latent_avg = torch.from_numpy(g_ema.vars['dlatent_avg'].value().eval())

    ckpt = {'g_ema': state_dict, 'latent_avg': latent_avg}

    # Optionally convert the non-EMA generator and the discriminator too.
    if args.gen:
        g_train = Generator(size, 512, 8, channel_multiplier=args.channel_multiplier)
        g_train_state = g_train.state_dict()
        g_train_state = fill_statedict(g_train_state, generator.vars, size)
        ckpt['g'] = g_train_state

    if args.disc:
        disc = Discriminator(size, channel_multiplier=args.channel_multiplier)
        d_state = disc.state_dict()
        d_state = discriminator_fill_statedict(d_state, discriminator.vars, size)
        ckpt['d'] = d_state

    name = os.path.splitext(os.path.basename(args.path))[0]
    outpath = os.path.join(os.getcwd(), f'{name}.pt')
    print('Saving', outpath)
    # Legacy (non-zipfile) serialization keeps the checkpoint loadable by
    # older PyTorch; fall back when this PyTorch lacks the keyword.
    try:
        torch.save(ckpt, outpath, _use_new_zipfile_serialization=False)
    except TypeError:
        torch.save(ckpt, outpath)


    # Sanity check: render the same latents with TF and PyTorch side by side
    # (rows: TF output, PyTorch output, their difference).
    print('Generating TF-Torch comparison images')
    batch_size = {256: 8, 512: 4, 1024: 2}
    n_sample = batch_size.get(size, 4)

    g = g.to(device)

    # Fixed seed so the comparison is reproducible.
    z = np.random.RandomState(0).randn(n_sample, 512).astype('float32')

    with torch.no_grad():
        img_pt, _ = g(
            [torch.from_numpy(z).to(device)],
            truncation=0.5,
            truncation_latent=latent_avg.to(device),
        )

    img_tf = g_ema.run(z, None, randomize_noise=False)
    img_tf = torch.from_numpy(img_tf).to(device)

    img_diff = ((img_pt + 1) / 2).clamp(0.0, 1.0) - ((img_tf.to(device) + 1) / 2).clamp(
        0.0, 1.0
    )

    img_concat = torch.cat((img_tf, img_pt, img_diff), dim=0)
    utils.save_image(
        img_concat, name + '.png', nrow=n_sample, normalize=True, range=(-1, 1)
    )
    print('Done')
|
283 |
+
|
models/stylegan2/stylegan2-pytorch/dataset.py
ADDED
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from io import BytesIO
|
2 |
+
|
3 |
+
import lmdb
|
4 |
+
from PIL import Image
|
5 |
+
from torch.utils.data import Dataset
|
6 |
+
|
7 |
+
|
8 |
+
class MultiResolutionDataset(Dataset):
    """LMDB-backed image dataset with images pre-encoded at several resolutions.

    Keys are '<resolution>-<index zero-padded to 5 digits>'; a 'length' key
    stores the sample count. Images are decoded with PIL and passed through
    `transform`.
    """

    def __init__(self, path, transform, resolution=256):
        # Read-only, lock-free environment so multiple DataLoader workers
        # can share the same LMDB file.
        self.env = lmdb.open(
            path,
            max_readers=32,
            readonly=True,
            lock=False,
            readahead=False,
            meminit=False,
        )

        if not self.env:
            raise IOError('Cannot open lmdb dataset', path)

        with self.env.begin(write=False) as txn:
            self.length = int(txn.get('length'.encode('utf-8')).decode('utf-8'))

        self.resolution = resolution
        self.transform = transform

    def __len__(self):
        return self.length

    def __getitem__(self, index):
        key = f'{self.resolution}-{str(index).zfill(5)}'.encode('utf-8')

        with self.env.begin(write=False) as txn:
            img_bytes = txn.get(key)

        img = Image.open(BytesIO(img_bytes))
        return self.transform(img)
|
models/stylegan2/stylegan2-pytorch/distributed.py
ADDED
@@ -0,0 +1,126 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
import pickle
|
3 |
+
|
4 |
+
import torch
|
5 |
+
from torch import distributed as dist
|
6 |
+
from torch.utils.data.sampler import Sampler
|
7 |
+
|
8 |
+
|
9 |
+
def get_rank():
    """Return this process's distributed rank, or 0 when not running distributed."""
    if dist.is_available() and dist.is_initialized():
        return dist.get_rank()

    return 0
|
17 |
+
|
18 |
+
|
19 |
+
def synchronize():
    """Barrier across all workers; a no-op outside multi-process runs."""
    if not (dist.is_available() and dist.is_initialized()):
        return

    if dist.get_world_size() > 1:
        dist.barrier()
|
32 |
+
|
33 |
+
|
34 |
+
def get_world_size():
    """Return the number of distributed workers (1 when not distributed)."""
    if dist.is_available() and dist.is_initialized():
        return dist.get_world_size()

    return 1
|
42 |
+
|
43 |
+
|
44 |
+
def reduce_sum(tensor):
    """Sum `tensor` across all workers.

    Returns the input untouched when not running distributed; otherwise
    all-reduces a clone so the caller's tensor is not modified in place.
    """
    if not (dist.is_available() and dist.is_initialized()):
        return tensor

    summed = tensor.clone()
    dist.all_reduce(summed, op=dist.ReduceOp.SUM)

    return summed
|
55 |
+
|
56 |
+
|
57 |
+
def gather_grad(params):
    """Average the gradients of `params` across workers, in place.

    No-op when running single-process; parameters without gradients are
    skipped.
    """
    world_size = get_world_size()

    if world_size == 1:
        return

    for p in params:
        if p.grad is None:
            continue
        dist.all_reduce(p.grad.data, op=dist.ReduceOp.SUM)
        p.grad.data.div_(world_size)
|
67 |
+
|
68 |
+
|
69 |
+
def all_gather(data):
    """Gather an arbitrary picklable object from every worker.

    Objects are pickled into CUDA byte tensors, padded to the maximum length
    across workers, all-gathered, and unpickled. Returns a list with one
    entry per worker (just [data] when not running distributed).
    """
    world_size = get_world_size()

    if world_size == 1:
        return [data]

    # Serialize the payload into a CUDA byte tensor.
    buffer = pickle.dumps(data)
    storage = torch.ByteStorage.from_buffer(buffer)
    tensor = torch.ByteTensor(storage).to('cuda')

    # Exchange payload sizes first so every rank can pad to the common max.
    local_size = torch.IntTensor([tensor.numel()]).to('cuda')
    size_list = [torch.IntTensor([0]).to('cuda') for _ in range(world_size)]
    dist.all_gather(size_list, local_size)
    size_list = [int(size.item()) for size in size_list]
    max_size = max(size_list)

    tensor_list = []
    for _ in size_list:
        tensor_list.append(torch.ByteTensor(size=(max_size,)).to('cuda'))

    # Pad our own tensor up to max_size (all_gather requires equal shapes).
    if local_size != max_size:
        padding = torch.ByteTensor(size=(max_size - local_size,)).to('cuda')
        tensor = torch.cat((tensor, padding), 0)

    dist.all_gather(tensor_list, tensor)

    data_list = []

    # Strip each rank's padding before unpickling.
    for size, tensor in zip(size_list, tensor_list):
        buffer = tensor.cpu().numpy().tobytes()[:size]
        data_list.append(pickle.loads(buffer))

    return data_list
|
102 |
+
|
103 |
+
|
104 |
+
def reduce_loss_dict(loss_dict):
    """Average a dict of scalar loss tensors across workers.

    Only rank 0 receives the averaged values (dist.reduce to dst=0); other
    ranks get the summed-but-unscaled tensors. Returns the input unchanged
    when running single-process.
    """
    world_size = get_world_size()

    if world_size < 2:
        return loss_dict

    with torch.no_grad():
        # Sort keys so every rank stacks the losses in the same order.
        keys = sorted(loss_dict.keys())
        losses = torch.stack([loss_dict[key] for key in keys], 0)

        dist.reduce(losses, dst=0)

        if dist.get_rank() == 0:
            losses /= world_size

        return dict(zip(keys, losses))
|
models/stylegan2/stylegan2-pytorch/fid.py
ADDED
@@ -0,0 +1,107 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import pickle
|
3 |
+
|
4 |
+
import torch
|
5 |
+
from torch import nn
|
6 |
+
import numpy as np
|
7 |
+
from scipy import linalg
|
8 |
+
from tqdm import tqdm
|
9 |
+
|
10 |
+
from model import Generator
|
11 |
+
from calc_inception import load_patched_inception_v3
|
12 |
+
|
13 |
+
|
14 |
+
@torch.no_grad()
def extract_feature_from_samples(
    generator, inception, truncation, truncation_latent, batch_size, n_sample, device
):
    """Sample `n_sample` images from `generator` and return their Inception
    features as a (n_sample, feat_dim) CPU tensor.

    Bug fixes vs. the original:
      * uses the `generator` argument instead of the module-global `g`, so the
        function works when imported from another script;
      * skips the trailing batch when n_sample divides batch_size evenly
        (the original always appended a batch of size 0).
    """
    n_batch = n_sample // batch_size
    resid = n_sample - (n_batch * batch_size)
    batch_sizes = [batch_size] * n_batch + ([resid] if resid > 0 else [])
    features = []

    for batch in tqdm(batch_sizes):
        latent = torch.randn(batch, 512, device=device)
        img, _ = generator(
            [latent], truncation=truncation, truncation_latent=truncation_latent
        )
        feat = inception(img)[0].view(img.shape[0], -1)
        features.append(feat.to('cpu'))

    features = torch.cat(features, 0)

    return features
|
32 |
+
|
33 |
+
|
34 |
+
def calc_fid(sample_mean, sample_cov, real_mean, real_cov, eps=1e-6):
    """Fréchet Inception Distance between two Gaussians (mean, covariance).

    FID = ||mu_s - mu_r||^2 + Tr(C_s + C_r - 2 (C_s C_r)^{1/2}).
    When the product of covariances is singular, eps * I is added to both
    before retrying the matrix square root.
    """
    cov_sqrt, _ = linalg.sqrtm(sample_cov @ real_cov, disp=False)

    if not np.isfinite(cov_sqrt).all():
        print('product of cov matrices is singular')
        offset = np.eye(sample_cov.shape[0]) * eps
        cov_sqrt = linalg.sqrtm((sample_cov + offset) @ (real_cov + offset))

    if np.iscomplexobj(cov_sqrt):
        # Small imaginary parts are numerical noise; large ones are an error.
        if not np.allclose(np.diagonal(cov_sqrt).imag, 0, atol=1e-3):
            m = np.max(np.abs(cov_sqrt.imag))

            raise ValueError(f'Imaginary component {m}')

        cov_sqrt = cov_sqrt.real

    diff = sample_mean - real_mean
    trace_term = np.trace(sample_cov) + np.trace(real_cov) - 2 * np.trace(cov_sqrt)

    return diff @ diff + trace_term
|
58 |
+
|
59 |
+
|
60 |
+
if __name__ == '__main__':
    # Consistent with calc_inception.py: fall back to CPU when CUDA is absent.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    parser = argparse.ArgumentParser()

    parser.add_argument('--truncation', type=float, default=1)
    parser.add_argument('--truncation_mean', type=int, default=4096)
    parser.add_argument('--batch', type=int, default=64)
    parser.add_argument('--n_sample', type=int, default=50000)
    parser.add_argument('--size', type=int, default=256)
    parser.add_argument('--inception', type=str, default=None, required=True)
    parser.add_argument('ckpt', metavar='CHECKPOINT')

    args = parser.parse_args()

    ckpt = torch.load(args.ckpt)

    g = Generator(args.size, 512, 8).to(device)
    g.load_state_dict(ckpt['g_ema'])
    g.eval()

    # Bug fix: compute the mean latent on the bare Generator BEFORE wrapping
    # in nn.DataParallel — DataParallel does not forward the wrapped module's
    # custom methods, so the original `g.mean_latent(...)` after wrapping
    # raised AttributeError.
    if args.truncation < 1:
        with torch.no_grad():
            mean_latent = g.mean_latent(args.truncation_mean)

    else:
        mean_latent = None

    g = nn.DataParallel(g)

    inception = nn.DataParallel(load_patched_inception_v3()).to(device)
    inception.eval()

    features = extract_feature_from_samples(
        g, inception, args.truncation, mean_latent, args.batch, args.n_sample, device
    ).numpy()
    print(f'extracted {features.shape[0]} features')

    sample_mean = np.mean(features, 0)
    sample_cov = np.cov(features, rowvar=False)

    # Real-data statistics cached by calc_inception.py.
    with open(args.inception, 'rb') as f:
        embeds = pickle.load(f)
        real_mean = embeds['mean']
        real_cov = embeds['cov']

    fid = calc_fid(sample_mean, sample_cov, real_mean, real_cov)

    print('fid:', fid)
|
models/stylegan2/stylegan2-pytorch/generate.py
ADDED
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
|
3 |
+
import torch
|
4 |
+
from torchvision import utils
|
5 |
+
from model import Generator
|
6 |
+
from tqdm import tqdm
|
7 |
+
def generate(args, g_ema, device, mean_latent):
    """Sample `args.pics` images from `g_ema` and save them under sample/.

    Each file holds `args.sample` images named by its zero-padded index.
    `mean_latent` is the truncation target (may be None when truncation=1).
    """
    g_ema.eval()

    with torch.no_grad():
        for idx in tqdm(range(args.pics)):
            z = torch.randn(args.sample, args.latent, device=device)

            image, _ = g_ema(
                [z], truncation=args.truncation, truncation_latent=mean_latent
            )

            utils.save_image(
                image,
                f'sample/{str(idx).zfill(6)}.png',
                nrow=1,
                normalize=True,
                range=(-1, 1),
            )
|
23 |
+
|
24 |
+
if __name__ == '__main__':
    device = 'cuda'

    parser = argparse.ArgumentParser()

    parser.add_argument('--size', type=int, default=1024)
    parser.add_argument('--sample', type=int, default=1)
    parser.add_argument('--pics', type=int, default=20)
    parser.add_argument('--truncation', type=float, default=1)
    parser.add_argument('--truncation_mean', type=int, default=4096)
    parser.add_argument('--ckpt', type=str, default="stylegan2-ffhq-config-f.pt")
    parser.add_argument('--channel_multiplier', type=int, default=2)

    args = parser.parse_args()

    # Latent dimensionality and mapping-network depth fixed for config-f.
    args.latent = 512
    args.n_mlp = 8

    g_ema = Generator(
        args.size, args.latent, args.n_mlp, channel_multiplier=args.channel_multiplier
    ).to(device)
    checkpoint = torch.load(args.ckpt)

    g_ema.load_state_dict(checkpoint['g_ema'])

    # Truncation trick: estimate the mean latent only when truncation < 1.
    if args.truncation < 1:
        with torch.no_grad():
            mean_latent = g_ema.mean_latent(args.truncation_mean)
    else:
        mean_latent = None

    generate(args, g_ema, device, mean_latent)
|
models/stylegan2/stylegan2-pytorch/inception.py
ADDED
@@ -0,0 +1,310 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
from torchvision import models
|
5 |
+
|
6 |
+
try:
|
7 |
+
from torchvision.models.utils import load_state_dict_from_url
|
8 |
+
except ImportError:
|
9 |
+
from torch.utils.model_zoo import load_url as load_state_dict_from_url
|
10 |
+
|
11 |
+
# Inception weights ported to Pytorch from
|
12 |
+
# http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz
|
13 |
+
FID_WEIGHTS_URL = 'https://github.com/mseitzer/pytorch-fid/releases/download/fid_weights/pt_inception-2015-12-05-6726825d.pth'
|
14 |
+
|
15 |
+
|
16 |
+
class InceptionV3(nn.Module):
    """Pretrained InceptionV3 network returning feature maps"""

    # Index of default block of inception to return,
    # corresponds to output of final average pooling
    DEFAULT_BLOCK_INDEX = 3

    # Maps feature dimensionality to their output blocks indices
    BLOCK_INDEX_BY_DIM = {
        64: 0,   # First max pooling features
        192: 1,  # Second max pooling features
        768: 2,  # Pre-aux classifier features
        2048: 3  # Final average pooling features
    }

    # NOTE(review): mutable default argument below is shared across calls;
    # harmless here because it is never mutated, but worth confirming.
    def __init__(self,
                 output_blocks=[DEFAULT_BLOCK_INDEX],
                 resize_input=True,
                 normalize_input=True,
                 requires_grad=False,
                 use_fid_inception=True):
        """Build pretrained InceptionV3

        Parameters
        ----------
        output_blocks : list of int
            Indices of blocks to return features of. Possible values are:
                - 0: corresponds to output of first max pooling
                - 1: corresponds to output of second max pooling
                - 2: corresponds to output which is fed to aux classifier
                - 3: corresponds to output of final average pooling
        resize_input : bool
            If true, bilinearly resizes input to width and height 299 before
            feeding input to model. As the network without fully connected
            layers is fully convolutional, it should be able to handle inputs
            of arbitrary size, so resizing might not be strictly needed
        normalize_input : bool
            If true, scales the input from range (0, 1) to the range the
            pretrained Inception network expects, namely (-1, 1)
        requires_grad : bool
            If true, parameters of the model require gradients. Possibly useful
            for finetuning the network
        use_fid_inception : bool
            If true, uses the pretrained Inception model used in Tensorflow's
            FID implementation. If false, uses the pretrained Inception model
            available in torchvision. The FID Inception model has different
            weights and a slightly different structure from torchvision's
            Inception model. If you want to compute FID scores, you are
            strongly advised to set this parameter to true to get comparable
            results.
        """
        super(InceptionV3, self).__init__()

        self.resize_input = resize_input
        self.normalize_input = normalize_input
        self.output_blocks = sorted(output_blocks)
        self.last_needed_block = max(output_blocks)

        assert self.last_needed_block <= 3, \
            'Last possible output block index is 3'

        self.blocks = nn.ModuleList()

        if use_fid_inception:
            inception = fid_inception_v3()
        else:
            inception = models.inception_v3(pretrained=True)

        # Block 0: input to maxpool1
        block0 = [
            inception.Conv2d_1a_3x3,
            inception.Conv2d_2a_3x3,
            inception.Conv2d_2b_3x3,
            nn.MaxPool2d(kernel_size=3, stride=2)
        ]
        self.blocks.append(nn.Sequential(*block0))

        # Block 1: maxpool1 to maxpool2
        # Later blocks are only built when actually needed, so a feature
        # extractor for early blocks stays small.
        if self.last_needed_block >= 1:
            block1 = [
                inception.Conv2d_3b_1x1,
                inception.Conv2d_4a_3x3,
                nn.MaxPool2d(kernel_size=3, stride=2)
            ]
            self.blocks.append(nn.Sequential(*block1))

        # Block 2: maxpool2 to aux classifier
        if self.last_needed_block >= 2:
            block2 = [
                inception.Mixed_5b,
                inception.Mixed_5c,
                inception.Mixed_5d,
                inception.Mixed_6a,
                inception.Mixed_6b,
                inception.Mixed_6c,
                inception.Mixed_6d,
                inception.Mixed_6e,
            ]
            self.blocks.append(nn.Sequential(*block2))

        # Block 3: aux classifier to final avgpool
        if self.last_needed_block >= 3:
            block3 = [
                inception.Mixed_7a,
                inception.Mixed_7b,
                inception.Mixed_7c,
                nn.AdaptiveAvgPool2d(output_size=(1, 1))
            ]
            self.blocks.append(nn.Sequential(*block3))

        for param in self.parameters():
            param.requires_grad = requires_grad

    def forward(self, inp):
        """Get Inception feature maps

        Parameters
        ----------
        inp : torch.autograd.Variable
            Input tensor of shape Bx3xHxW. Values are expected to be in
            range (0, 1)

        Returns
        -------
        List of torch.autograd.Variable, corresponding to the selected output
        block, sorted ascending by index
        """
        outp = []
        x = inp

        if self.resize_input:
            x = F.interpolate(x,
                              size=(299, 299),
                              mode='bilinear',
                              align_corners=False)

        if self.normalize_input:
            x = 2 * x - 1  # Scale from range (0, 1) to range (-1, 1)

        for idx, block in enumerate(self.blocks):
            x = block(x)
            if idx in self.output_blocks:
                outp.append(x)

            # Skip blocks past the deepest requested output.
            if idx == self.last_needed_block:
                break

        return outp
|
164 |
+
|
165 |
+
|
166 |
+
def fid_inception_v3():
    """Build the pretrained Inception model used for FID computation.

    The FID Inception model uses a different set of weights and has a
    slightly different structure than torchvision's Inception.  Torchvision's
    model is therefore constructed first, the differing blocks are swapped
    out, and finally the FID weights are loaded.
    """
    inception = models.inception_v3(num_classes=1008,
                                    aux_logits=False,
                                    pretrained=False)

    # Swap in the blocks whose pooling behaviour differs in the TF FID model.
    inception.Mixed_5b = FIDInceptionA(192, pool_features=32)
    inception.Mixed_5c = FIDInceptionA(256, pool_features=64)
    inception.Mixed_5d = FIDInceptionA(288, pool_features=64)
    inception.Mixed_6b = FIDInceptionC(768, channels_7x7=128)
    inception.Mixed_6c = FIDInceptionC(768, channels_7x7=160)
    inception.Mixed_6d = FIDInceptionC(768, channels_7x7=160)
    inception.Mixed_6e = FIDInceptionC(768, channels_7x7=192)
    inception.Mixed_7b = FIDInceptionE_1(1280)
    inception.Mixed_7c = FIDInceptionE_2(2048)

    # Fetch and apply the dedicated FID weights.
    state_dict = load_state_dict_from_url(FID_WEIGHTS_URL, progress=True)
    inception.load_state_dict(state_dict)
    return inception
|
191 |
+
|
192 |
+
|
193 |
+
class FIDInceptionA(models.inception.InceptionA):
    """InceptionA block patched to match the TF FID implementation."""

    def __init__(self, in_channels, pool_features):
        super(FIDInceptionA, self).__init__(in_channels, pool_features)

    def forward(self, x):
        out_1x1 = self.branch1x1(x)

        out_5x5 = self.branch5x5_2(self.branch5x5_1(x))

        out_3x3dbl = self.branch3x3dbl_1(x)
        out_3x3dbl = self.branch3x3dbl_2(out_3x3dbl)
        out_3x3dbl = self.branch3x3dbl_3(out_3x3dbl)

        # Patch: Tensorflow's average pool does not use the padded zeros in
        # its average calculation, hence count_include_pad=False.
        pooled = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1,
                              count_include_pad=False)
        pooled = self.branch_pool(pooled)

        return torch.cat([out_1x1, out_5x5, out_3x3dbl, pooled], 1)
|
216 |
+
|
217 |
+
|
218 |
+
class FIDInceptionC(models.inception.InceptionC):
    """InceptionC block patched to match the TF FID implementation."""

    def __init__(self, in_channels, channels_7x7):
        super(FIDInceptionC, self).__init__(in_channels, channels_7x7)

    def forward(self, x):
        out_1x1 = self.branch1x1(x)

        out_7x7 = self.branch7x7_1(x)
        out_7x7 = self.branch7x7_2(out_7x7)
        out_7x7 = self.branch7x7_3(out_7x7)

        out_7x7dbl = self.branch7x7dbl_1(x)
        out_7x7dbl = self.branch7x7dbl_2(out_7x7dbl)
        out_7x7dbl = self.branch7x7dbl_3(out_7x7dbl)
        out_7x7dbl = self.branch7x7dbl_4(out_7x7dbl)
        out_7x7dbl = self.branch7x7dbl_5(out_7x7dbl)

        # Patch: Tensorflow's average pool does not use the padded zeros in
        # its average calculation, hence count_include_pad=False.
        pooled = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1,
                              count_include_pad=False)
        pooled = self.branch_pool(pooled)

        return torch.cat([out_1x1, out_7x7, out_7x7dbl, pooled], 1)
|
244 |
+
|
245 |
+
|
246 |
+
class FIDInceptionE_1(models.inception.InceptionE):
    """First InceptionE block patched to match the TF FID implementation."""

    def __init__(self, in_channels):
        super(FIDInceptionE_1, self).__init__(in_channels)

    def forward(self, x):
        out_1x1 = self.branch1x1(x)

        stem_3x3 = self.branch3x3_1(x)
        out_3x3 = torch.cat([
            self.branch3x3_2a(stem_3x3),
            self.branch3x3_2b(stem_3x3),
        ], 1)

        stem_3x3dbl = self.branch3x3dbl_2(self.branch3x3dbl_1(x))
        out_3x3dbl = torch.cat([
            self.branch3x3dbl_3a(stem_3x3dbl),
            self.branch3x3dbl_3b(stem_3x3dbl),
        ], 1)

        # Patch: Tensorflow's average pool does not use the padded zeros in
        # its average calculation, hence count_include_pad=False.
        pooled = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1,
                              count_include_pad=False)
        pooled = self.branch_pool(pooled)

        return torch.cat([out_1x1, out_3x3, out_3x3dbl, pooled], 1)
|
277 |
+
|
278 |
+
|
279 |
+
class FIDInceptionE_2(models.inception.InceptionE):
    """Second InceptionE block patched to match the TF FID implementation."""

    def __init__(self, in_channels):
        super(FIDInceptionE_2, self).__init__(in_channels)

    def forward(self, x):
        out_1x1 = self.branch1x1(x)

        stem_3x3 = self.branch3x3_1(x)
        out_3x3 = torch.cat([
            self.branch3x3_2a(stem_3x3),
            self.branch3x3_2b(stem_3x3),
        ], 1)

        stem_3x3dbl = self.branch3x3dbl_2(self.branch3x3dbl_1(x))
        out_3x3dbl = torch.cat([
            self.branch3x3dbl_3a(stem_3x3dbl),
            self.branch3x3dbl_3b(stem_3x3dbl),
        ], 1)

        # Patch: The FID Inception model uses max pooling instead of average
        # pooling. This is likely an error in this specific Inception
        # implementation, as other Inception models use average pooling here
        # (which matches the description in the paper).
        pooled = F.max_pool2d(x, kernel_size=3, stride=1, padding=1)
        pooled = self.branch_pool(pooled)

        return torch.cat([out_1x1, out_3x3, out_3x3dbl, pooled], 1)
|
models/stylegan2/stylegan2-pytorch/lpips/__init__.py
ADDED
@@ -0,0 +1,160 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
from __future__ import absolute_import
|
3 |
+
from __future__ import division
|
4 |
+
from __future__ import print_function
|
5 |
+
|
6 |
+
import numpy as np
|
7 |
+
from skimage.measure import compare_ssim
|
8 |
+
import torch
|
9 |
+
from torch.autograd import Variable
|
10 |
+
|
11 |
+
from lpips import dist_model
|
12 |
+
|
13 |
+
class PerceptualLoss(torch.nn.Module):
    """LPIPS perceptual distance wrapped as a torch loss module.

    Delegates all work to a ``dist_model.DistModel`` configured at
    construction time (by default the linearly calibrated 'net-lin' variant
    on AlexNet features).
    """

    def __init__(self, model='net-lin', net='alex', colorspace='rgb', spatial=False, use_gpu=True, gpu_ids=[0]): # VGG using our perceptually-learned weights (LPIPS metric)
    # def __init__(self, model='net', net='vgg', use_gpu=True): # "default" way of using VGG as a perceptual loss
        super(PerceptualLoss, self).__init__()
        print('Setting up Perceptual loss...')
        self.use_gpu = use_gpu
        self.spatial = spatial
        self.gpu_ids = gpu_ids
        # The underlying distance model does the heavy lifting.
        self.model = dist_model.DistModel()
        self.model.initialize(model=model, net=net, use_gpu=use_gpu,
                              colorspace=colorspace, spatial=self.spatial,
                              gpu_ids=gpu_ids)
        print('...[%s] initialized'%self.model.name())
        print('...Done')

    def forward(self, pred, target, normalize=False):
        """Compute the perceptual distance between ``pred`` and ``target``.

        Pred and target are Variables of shape Nx3xHxW.
        If normalize is True, assumes the images are between [0,1] and then
        scales them between [-1,+1].
        If normalize is False, assumes the images are already between [-1,+1].

        Returns a pytorch Variable of length N.
        """
        if normalize:
            # Rescale from [0, 1] to the [-1, 1] range the model expects.
            target = 2 * target - 1
            pred = 2 * pred - 1

        return self.model.forward(target, pred)
|
41 |
+
|
42 |
+
def normalize_tensor(in_feat, eps=1e-10):
    """L2-normalise ``in_feat`` along dim 1 (the channel dimension)."""
    norm_factor = (in_feat ** 2).sum(dim=1, keepdim=True).sqrt()
    # eps guards against division by zero for all-zero feature vectors.
    return in_feat / (norm_factor + eps)
|
45 |
+
|
46 |
+
def l2(p0, p1, range=255.):
    """Half the mean squared error between p0 and p1, after dividing by *range*.

    NOTE: the parameter name shadows the builtin ``range``; kept for API
    compatibility with existing keyword callers.
    """
    diff = p0 / range - p1 / range
    return .5 * np.mean(diff ** 2)
|
48 |
+
|
49 |
+
def psnr(p0, p1, peak=255.):
    """Peak signal-to-noise ratio between p0 and p1, in dB."""
    # Multiply by 1. so integer inputs are promoted to float before squaring.
    mse = np.mean((1. * p0 - 1. * p1) ** 2)
    return 10 * np.log10(peak ** 2 / mse)
|
51 |
+
|
52 |
+
def dssim(p0, p1, range=255.):
    """Structural dissimilarity: (1 - SSIM) / 2, in [0, 1].

    FIX: ``skimage.measure.compare_ssim`` was removed in scikit-image 0.18;
    prefer the modern ``skimage.metrics.structural_similarity`` and fall back
    to the legacy module-level ``compare_ssim`` for old scikit-image versions.
    (NOTE(review): the module-level ``from skimage.measure import compare_ssim``
    import at the top of this file still fails on modern scikit-image and
    should be made lazy as well.)
    """
    try:
        from skimage.metrics import structural_similarity
        # channel_axis=-1 is the modern spelling of multichannel=True.
        ssim = structural_similarity(p0, p1, data_range=range, channel_axis=-1)
    except ImportError:
        ssim = compare_ssim(p0, p1, data_range=range, multichannel=True)
    return (1 - ssim) / 2.
|
54 |
+
|
55 |
+
def rgb2lab(in_img, mean_cent=False):
    """Convert an RGB image (HxWx3) to CIE-Lab colourspace.

    NOTE(review): this definition is shadowed by a later ``rgb2lab(input)``
    in the same module, so module-level lookups resolve to the later one.
    """
    from skimage import color
    img_lab = color.rgb2lab(in_img)
    if mean_cent:
        # Centre the L (luminance) channel around zero.
        img_lab[:, :, 0] = img_lab[:, :, 0] - 50
    return img_lab
|
61 |
+
|
62 |
+
def tensor2np(tensor_obj):
    """Take the first image of a NxCxHxW tensor as an HxWxC numpy array."""
    chw = tensor_obj[0].cpu().float().numpy()
    return chw.transpose((1, 2, 0))
|
65 |
+
|
66 |
+
def np2tensor(np_obj):
    """Turn an HxWxC numpy image into a 1xCxHxW torch tensor."""
    expanded = np_obj[:, :, :, np.newaxis]          # HxWxCx1
    return torch.Tensor(expanded.transpose((3, 2, 0, 1)))  # 1xCxHxW
|
69 |
+
|
70 |
+
def tensor2tensorlab(image_tensor, to_norm=True, mc_only=False):
    """Convert an image tensor (NxCxHxW, [-1,1]) to a Lab tensor.

    ``mc_only`` only mean-centres the L channel; ``to_norm`` additionally
    rescales the result into [-0.5, 0.5]-ish normalised Lab.
    """
    from skimage import color

    rgb_img = tensor2im(image_tensor)
    lab_img = color.rgb2lab(rgb_img)
    if mc_only:
        lab_img[:, :, 0] = lab_img[:, :, 0] - 50
    if to_norm and not mc_only:
        lab_img[:, :, 0] = lab_img[:, :, 0] - 50
        lab_img = lab_img / 100.

    return np2tensor(lab_img)
|
83 |
+
|
84 |
+
def tensorlab2tensor(lab_tensor, return_inbnd=False):
    """Convert a normalised Lab tensor back to an RGB image tensor.

    When ``return_inbnd`` is True, also returns a mask flagging pixels whose
    Lab values survive the round trip (i.e. were inside the RGB gamut).
    """
    import warnings
    from skimage import color
    warnings.filterwarnings("ignore")

    # Undo the normalisation applied by tensor2tensorlab.
    lab = tensor2np(lab_tensor) * 100.
    lab[:, :, 0] = lab[:, :, 0] + 50

    rgb_back = 255. * np.clip(color.lab2rgb(lab.astype('float')), 0, 1)
    if not return_inbnd:
        return im2tensor(rgb_back)

    # Convert back to Lab and see if we match the original values.
    lab_back = color.rgb2lab(rgb_back.astype('uint8'))
    mask = 1. * np.isclose(lab_back, lab, atol=2.)
    mask = np2tensor(np.prod(mask, axis=2)[:, :, np.newaxis])
    return (im2tensor(rgb_back), mask)
|
101 |
+
|
102 |
+
def rgb2lab(input):
    """RGB (0-255) to CIE-Lab conversion.

    NOTE(review): shadows the earlier ``rgb2lab(in_img, mean_cent)`` in this
    module; the parameter name also shadows the builtin ``input`` (kept for
    API compatibility).
    """
    from skimage import color
    return color.rgb2lab(input / 255.)
|
105 |
+
|
106 |
+
def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=255./2.):
    """First image of a NxCxHxW tensor in [-1,1] as an HxWxC uint8 array."""
    chw = image_tensor[0].cpu().float().numpy()
    hwc = (chw.transpose((1, 2, 0)) + cent) * factor
    return hwc.astype(imtype)
|
110 |
+
|
111 |
+
def im2tensor(image, imtype=np.uint8, cent=1., factor=255./2.):
    """HxWxC image in [0,255] to a 1xCxHxW torch tensor in [-1,1]."""
    scaled = image / factor - cent
    return torch.Tensor(scaled[:, :, :, np.newaxis].transpose((3, 2, 0, 1)))
|
114 |
+
|
115 |
+
def tensor2vec(vector_tensor):
    """Squeeze a NxCx1x1 tensor down to an NxC numpy array."""
    arr = vector_tensor.data.cpu().numpy()
    return arr[:, :, 0, 0]
|
117 |
+
|
118 |
+
def voc_ap(rec, prec, use_07_metric=False):
    """ap = voc_ap(rec, prec, [use_07_metric])

    Compute VOC Average Precision given precision and recall arrays.
    If use_07_metric is true, uses the VOC 07 11-point method
    (default: False).
    """
    if use_07_metric:
        # VOC-07: sample precision at 11 equally spaced recall thresholds.
        ap = 0.
        for thresh in np.arange(0., 1.1, 0.1):
            above = rec >= thresh
            p = np.max(prec[above]) if np.sum(above) > 0 else 0
            ap = ap + p / 11.
        return ap

    # Correct AP calculation: exact area under the PR curve.
    # First append sentinel values at the ends.
    mrec = np.concatenate(([0.], rec, [1.]))
    mpre = np.concatenate(([0.], prec, [0.]))

    # Compute the precision envelope (monotonically non-increasing).
    for i in range(mpre.size - 1, 0, -1):
        mpre[i - 1] = np.maximum(mpre[i - 1], mpre[i])

    # Integrate only where the X axis (recall) actually changes value.
    idx = np.where(mrec[1:] != mrec[:-1])[0]

    # Sum (delta recall) * precision.
    return np.sum((mrec[idx + 1] - mrec[idx]) * mpre[idx + 1])
|
150 |
+
|
151 |
+
def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=255./2.):
    """First image of a NxCxHxW tensor in [-1,1] as an HxWxC uint8 array.

    NOTE(review): exact duplicate of the earlier tensor2im in this module;
    this later definition is the binding in effect.
    """
    chw = image_tensor[0].cpu().float().numpy()
    hwc = (chw.transpose((1, 2, 0)) + cent) * factor
    return hwc.astype(imtype)
|
156 |
+
|
157 |
+
def im2tensor(image, imtype=np.uint8, cent=1., factor=255./2.):
    """HxWxC image in [0,255] to a 1xCxHxW torch tensor in [-1,1].

    NOTE(review): exact duplicate of the earlier im2tensor in this module;
    this later definition is the binding in effect.
    """
    scaled = image / factor - cent
    return torch.Tensor(scaled[:, :, :, np.newaxis].transpose((3, 2, 0, 1)))
|
models/stylegan2/stylegan2-pytorch/lpips/base_model.py
ADDED
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import numpy as np
|
3 |
+
import torch
|
4 |
+
from torch.autograd import Variable
|
5 |
+
from pdb import set_trace as st
|
6 |
+
from IPython import embed
|
7 |
+
|
8 |
+
class BaseModel():
    """Minimal base class for LPIPS models: no-op hooks plus checkpoint
    save/load helpers for subclasses."""

    def __init__(self):
        pass

    def name(self):
        return 'BaseModel'

    def initialize(self, use_gpu=True, gpu_ids=[0]):
        # Records device configuration; subclasses call this first.
        self.use_gpu = use_gpu
        self.gpu_ids = gpu_ids

    def forward(self):
        pass

    def optimize_parameters(self):
        pass

    def get_current_visuals(self):
        return self.input

    def get_current_errors(self):
        return {}

    def save(self, label):
        pass

    # helper saving function that can be used by subclasses
    def save_network(self, network, path, network_label, epoch_label):
        save_filename = '%s_net_%s.pth' % (epoch_label, network_label)
        save_path = os.path.join(path, save_filename)
        torch.save(network.state_dict(), save_path)

    # helper loading function that can be used by subclasses
    def load_network(self, network, network_label, epoch_label):
        save_filename = '%s_net_%s.pth' % (epoch_label, network_label)
        save_path = os.path.join(self.save_dir, save_filename)
        print('Loading network from %s'%save_path)
        network.load_state_dict(torch.load(save_path))

    def update_learning_rate(self):
        # FIX: the original definition was `def update_learning_rate():`
        # (missing `self`), which raises TypeError when called as a method.
        pass

    def get_image_paths(self):
        # FIX: removed the earlier duplicate no-op `get_image_paths`; this
        # binding (returning self.image_paths) was the one in effect anyway.
        return self.image_paths

    def save_done(self, flag=False):
        # Persist a completion flag both as .npy and as a readable text file.
        np.save(os.path.join(self.save_dir, 'done_flag'),flag)
        np.savetxt(os.path.join(self.save_dir, 'done_flag'),[flag,],fmt='%i')
|
models/stylegan2/stylegan2-pytorch/lpips/dist_model.py
ADDED
@@ -0,0 +1,284 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
from __future__ import absolute_import
|
3 |
+
|
4 |
+
import sys
|
5 |
+
import numpy as np
|
6 |
+
import torch
|
7 |
+
from torch import nn
|
8 |
+
import os
|
9 |
+
from collections import OrderedDict
|
10 |
+
from torch.autograd import Variable
|
11 |
+
import itertools
|
12 |
+
from .base_model import BaseModel
|
13 |
+
from scipy.ndimage import zoom
|
14 |
+
import fractions
|
15 |
+
import functools
|
16 |
+
import skimage.transform
|
17 |
+
from tqdm import tqdm
|
18 |
+
|
19 |
+
from IPython import embed
|
20 |
+
|
21 |
+
from . import networks_basic as networks
|
22 |
+
import lpips as util
|
23 |
+
|
24 |
+
class DistModel(BaseModel):
    """LPIPS distance model: wraps a perceptual network (or L2/SSIM baseline)
    with optional training machinery for the ranking loss."""

    def name(self):
        return self.model_name

    def initialize(self, model='net-lin', net='alex', colorspace='Lab', pnet_rand=False, pnet_tune=False, model_path=None,
            use_gpu=True, printNet=False, spatial=False,
            is_train=False, lr=.0001, beta1=0.5, version='0.1', gpu_ids=[0]):
        '''
        INPUTS
            model - ['net-lin'] for linearly calibrated network
                    ['net'] for off-the-shelf network
                    ['L2'] for L2 distance in Lab colorspace
                    ['SSIM'] for ssim in RGB colorspace
            net - ['squeeze','alex','vgg']
            model_path - if None, will look in weights/[NET_NAME].pth
            colorspace - ['Lab','RGB'] colorspace to use for L2 and SSIM
            use_gpu - bool - whether or not to use a GPU
            printNet - bool - whether or not to print network architecture out
            spatial - bool - whether to output an array containing varying distances across spatial dimensions
            is_train - bool - [True] for training mode
            lr - float - initial learning rate
            beta1 - float - initial momentum term for adam
            version - 0.1 for latest, 0.0 was original (with a bug)
            gpu_ids - int array - [0] by default, gpus to use
        '''
        BaseModel.initialize(self, use_gpu=use_gpu, gpu_ids=gpu_ids)

        self.model = model
        self.net = net
        self.is_train = is_train
        self.spatial = spatial
        self.gpu_ids = gpu_ids
        self.model_name = '%s [%s]'%(model,net)

        if(self.model == 'net-lin'): # pretrained net + linear layer
            self.net = networks.PNetLin(pnet_rand=pnet_rand, pnet_tune=pnet_tune, pnet_type=net,
                use_dropout=True, spatial=spatial, version=version, lpips=True)
            kw = {}
            if not use_gpu:
                kw['map_location'] = 'cpu'
            if(model_path is None):
                # Default weights live next to this file under weights/v{version}/.
                import inspect
                model_path = os.path.abspath(os.path.join(inspect.getfile(self.initialize), '..', 'weights/v%s/%s.pth'%(version,net)))

            if(not is_train):
                print('Loading model from: %s'%model_path)
                self.net.load_state_dict(torch.load(model_path, **kw), strict=False)

        elif(self.model=='net'): # pretrained network, no calibration layer
            self.net = networks.PNetLin(pnet_rand=pnet_rand, pnet_type=net, lpips=False)
        elif(self.model in ['L2','l2']):
            self.net = networks.L2(use_gpu=use_gpu,colorspace=colorspace) # not really a network, only for testing
            self.model_name = 'L2'
        elif(self.model in ['DSSIM','dssim','SSIM','ssim']):
            self.net = networks.DSSIM(use_gpu=use_gpu,colorspace=colorspace)
            self.model_name = 'SSIM'
        else:
            raise ValueError("Model [%s] not recognized." % self.model)

        self.parameters = list(self.net.parameters())

        if self.is_train: # training mode
            # extra network on top to go from distances (d0,d1) => predicted human judgment (h*)
            self.rankLoss = networks.BCERankingLoss()
            self.parameters += list(self.rankLoss.net.parameters())
            self.lr = lr
            self.old_lr = lr
            self.optimizer_net = torch.optim.Adam(self.parameters, lr=lr, betas=(beta1, 0.999))
        else: # test mode
            self.net.eval()

        if(use_gpu):
            self.net.to(gpu_ids[0])
            self.net = torch.nn.DataParallel(self.net, device_ids=gpu_ids)
            if(self.is_train):
                self.rankLoss = self.rankLoss.to(device=gpu_ids[0]) # just put this on GPU0

        if(printNet):
            print('---------- Networks initialized -------------')
            networks.print_network(self.net)
            print('-----------------------------------------------')

    def forward(self, in0, in1, retPerLayer=False):
        ''' Function computes the distance between image patches in0 and in1
        INPUTS
            in0, in1 - torch.Tensor object of shape Nx3xXxY - image patch scaled to [-1,1]
        OUTPUT
            computed distances between in0 and in1
        '''
        return self.net.forward(in0, in1, retPerLayer=retPerLayer)

    # ***** TRAINING FUNCTIONS *****
    def optimize_parameters(self):
        self.forward_train()
        self.optimizer_net.zero_grad()
        self.backward_train()
        self.optimizer_net.step()
        self.clamp_weights()

    def clamp_weights(self):
        # Keep the calibration 1x1-conv weights non-negative.
        for module in self.net.modules():
            if(hasattr(module, 'weight') and module.kernel_size==(1,1)):
                module.weight.data = torch.clamp(module.weight.data,min=0)

    def set_input(self, data):
        # data: dict with reference patch, two perturbed patches and the
        # human judgement for the triplet.
        self.input_ref = data['ref']
        self.input_p0 = data['p0']
        self.input_p1 = data['p1']
        self.input_judge = data['judge']

        if(self.use_gpu):
            self.input_ref = self.input_ref.to(device=self.gpu_ids[0])
            self.input_p0 = self.input_p0.to(device=self.gpu_ids[0])
            self.input_p1 = self.input_p1.to(device=self.gpu_ids[0])
            self.input_judge = self.input_judge.to(device=self.gpu_ids[0])

        self.var_ref = Variable(self.input_ref,requires_grad=True)
        self.var_p0 = Variable(self.input_p0,requires_grad=True)
        self.var_p1 = Variable(self.input_p1,requires_grad=True)

    def forward_train(self): # run forward pass
        self.d0 = self.forward(self.var_ref, self.var_p0)
        self.d1 = self.forward(self.var_ref, self.var_p1)
        self.acc_r = self.compute_accuracy(self.d0,self.d1,self.input_judge)

        # Map judge in [0,1] to targets in [-1,1] for the ranking loss.
        self.var_judge = Variable(1.*self.input_judge).view(self.d0.size())

        self.loss_total = self.rankLoss.forward(self.d0, self.d1, self.var_judge*2.-1.)

        return self.loss_total

    def backward_train(self):
        torch.mean(self.loss_total).backward()

    def compute_accuracy(self,d0,d1,judge):
        ''' d0, d1 are Variables, judge is a Tensor '''
        d1_lt_d0 = (d1<d0).cpu().data.numpy().flatten()
        judge_per = judge.cpu().numpy().flatten()
        # Fraction of (soft) human votes the model agrees with.
        return d1_lt_d0*judge_per + (1-d1_lt_d0)*(1-judge_per)

    def get_current_errors(self):
        retDict = OrderedDict([('loss_total', self.loss_total.data.cpu().numpy()),
                            ('acc_r', self.acc_r)])

        for key in retDict.keys():
            retDict[key] = np.mean(retDict[key])

        return retDict

    def get_current_visuals(self):
        # Upscale patches to 256px for display.
        zoom_factor = 256/self.var_ref.data.size()[2]

        ref_img = util.tensor2im(self.var_ref.data)
        p0_img = util.tensor2im(self.var_p0.data)
        p1_img = util.tensor2im(self.var_p1.data)

        ref_img_vis = zoom(ref_img,[zoom_factor, zoom_factor, 1],order=0)
        p0_img_vis = zoom(p0_img,[zoom_factor, zoom_factor, 1],order=0)
        p1_img_vis = zoom(p1_img,[zoom_factor, zoom_factor, 1],order=0)

        return OrderedDict([('ref', ref_img_vis),
                            ('p0', p0_img_vis),
                            ('p1', p1_img_vis)])

    def save(self, path, label):
        if(self.use_gpu):
            self.save_network(self.net.module, path, '', label)
        else:
            self.save_network(self.net, path, '', label)
        self.save_network(self.rankLoss.net, path, 'rank', label)

    def update_learning_rate(self,nepoch_decay):
        # Linear decay of the learning rate over nepoch_decay epochs.
        lrd = self.lr / nepoch_decay
        lr = self.old_lr - lrd

        for param_group in self.optimizer_net.param_groups:
            param_group['lr'] = lr

        # FIX: the original format string interpolated the *builtin* `type`
        # (printing "<class 'type'>"); use the model name instead.
        print('update lr [%s] decay: %f -> %f' % (self.model_name, self.old_lr, lr))
        self.old_lr = lr
|
211 |
+
|
212 |
+
def score_2afc_dataset(data_loader, func, name=''):
    ''' Function computes Two Alternative Forced Choice (2AFC) score using
        distance function 'func' in dataset 'data_loader'
    INPUTS
        data_loader - CustomDatasetDataLoader object - contains a TwoAFCDataset inside
        func - callable distance function - calling d=func(in0,in1) should take 2
            pytorch tensors with shape Nx3xXxY, and return numpy array of length N
    OUTPUTS
        [0] - 2AFC score in [0,1], fraction of time func agrees with human evaluators
        [1] - dictionary with distances d0s/d1s, human judgements gts, and
            per-triplet agreement scores
    '''
    d0s, d1s, gts = [], [], []

    for data in tqdm(data_loader.load_data(), desc=name):
        d0s += func(data['ref'],data['p0']).data.cpu().numpy().flatten().tolist()
        d1s += func(data['ref'],data['p1']).data.cpu().numpy().flatten().tolist()
        gts += data['judge'].cpu().numpy().flatten().tolist()

    d0s = np.array(d0s)
    d1s = np.array(d1s)
    gts = np.array(gts)
    # Agreement with the (soft) human vote; exact ties earn half credit.
    scores = (d0s<d1s)*(1.-gts) + (d1s<d0s)*gts + (d1s==d0s)*.5

    return(np.mean(scores), dict(d0s=d0s,d1s=d1s,gts=gts,scores=scores))
|
246 |
+
|
247 |
+
def score_jnd_dataset(data_loader, func, name=''):
    ''' Function computes JND score using distance function 'func' in dataset 'data_loader'
    INPUTS
        data_loader - CustomDatasetDataLoader object - contains a JNDDataset inside
        func - callable distance function - calling d=func(in0,in1) should take 2
            pytorch tensors with shape Nx3xXxY, and return pytorch array of length N
    OUTPUTS
        [0] - JND score in [0,1], mAP score (area under precision-recall curve)
        [1] - dictionary with distances ds and human same/different votes sames
    '''
    ds, gts = [], []

    for data in tqdm(data_loader.load_data(), desc=name):
        ds += func(data['p0'],data['p1']).data.cpu().numpy().tolist()
        gts += data['same'].cpu().numpy().flatten().tolist()

    sames = np.array(gts)
    ds = np.array(ds)

    # Sweep the distance threshold by sorting pairs from closest to farthest.
    order = np.argsort(ds)
    sames_sorted = sames[order]

    TPs = np.cumsum(sames_sorted)
    FPs = np.cumsum(1-sames_sorted)
    FNs = np.sum(sames_sorted)-TPs

    precs = TPs/(TPs+FPs)
    recs = TPs/(TPs+FNs)
    # Area under the precision-recall curve.
    score = util.voc_ap(recs,precs)

    return(score, dict(ds=ds,sames=sames))
|
models/stylegan2/stylegan2-pytorch/lpips/networks_basic.py
ADDED
@@ -0,0 +1,187 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
from __future__ import absolute_import
|
3 |
+
|
4 |
+
import sys
|
5 |
+
import torch
|
6 |
+
import torch.nn as nn
|
7 |
+
import torch.nn.init as init
|
8 |
+
from torch.autograd import Variable
|
9 |
+
import numpy as np
|
10 |
+
from pdb import set_trace as st
|
11 |
+
from skimage import color
|
12 |
+
from IPython import embed
|
13 |
+
from . import pretrained_networks as pn
|
14 |
+
|
15 |
+
import lpips as util
|
16 |
+
|
17 |
+
def spatial_average(in_tens, keepdim=True):
    """Mean over the two spatial dimensions (dims 2 and 3) of a NCHW tensor."""
    return in_tens.mean(dim=[2, 3], keepdim=keepdim)
|
19 |
+
|
20 |
+
def upsample(in_tens, out_H=64):
    """Bilinearly upsample a NCHW tensor so its height becomes ``out_H``.

    Assumes the scale factor is the same for H and W.
    """
    in_H = in_tens.shape[2]
    scale_factor = 1. * out_H / in_H
    resizer = nn.Upsample(scale_factor=scale_factor, mode='bilinear',
                          align_corners=False)
    return resizer(in_tens)
|
25 |
+
|
26 |
+
# Learned perceptual metric
|
27 |
+
# Learned perceptual metric
class PNetLin(nn.Module):
    """LPIPS network: a (frozen) backbone plus per-stage 1x1 linear heads."""

    def __init__(self, pnet_type='vgg', pnet_rand=False, pnet_tune=False, use_dropout=True, spatial=False, version='0.1', lpips=True):
        super(PNetLin, self).__init__()

        self.pnet_type = pnet_type
        self.pnet_tune = pnet_tune
        self.pnet_rand = pnet_rand
        self.spatial = spatial
        self.lpips = lpips
        self.version = version
        self.scaling_layer = ScalingLayer()

        # Select the backbone and its per-stage channel counts.
        if(self.pnet_type in ['vgg','vgg16']):
            net_type = pn.vgg16
            self.chns = [64,128,256,512,512]
        elif(self.pnet_type=='alex'):
            net_type = pn.alexnet
            self.chns = [64,192,384,256,256]
        elif(self.pnet_type=='squeeze'):
            net_type = pn.squeezenet
            self.chns = [64,128,256,384,384,512,512]
        self.L = len(self.chns)

        self.net = net_type(pretrained=not self.pnet_rand, requires_grad=self.pnet_tune)

        if(lpips):
            # One learned 1x1-conv calibration head per backbone stage.
            self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout)
            self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout)
            self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout)
            self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout)
            self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout)
            # NOTE: kept as a plain Python list (not nn.ModuleList) to
            # preserve the original module/state_dict layout.
            self.lins = [self.lin0,self.lin1,self.lin2,self.lin3,self.lin4]
            if(self.pnet_type=='squeeze'): # 7 layers for squeezenet
                self.lin5 = NetLinLayer(self.chns[5], use_dropout=use_dropout)
                self.lin6 = NetLinLayer(self.chns[6], use_dropout=use_dropout)
                self.lins+=[self.lin5,self.lin6]

    def forward(self, in0, in1, retPerLayer=False):
        # v0.0 - original release had a bug, where input was not scaled
        if self.version == '0.1':
            in0_input, in1_input = self.scaling_layer(in0), self.scaling_layer(in1)
        else:
            in0_input, in1_input = in0, in1
        outs0 = self.net.forward(in0_input)
        outs1 = self.net.forward(in1_input)

        # Per-stage: channel-normalise both feature maps, take squared diff.
        feats0, feats1, diffs = {}, {}, {}
        for kk in range(self.L):
            feats0[kk] = util.normalize_tensor(outs0[kk])
            feats1[kk] = util.normalize_tensor(outs1[kk])
            diffs[kk] = (feats0[kk]-feats1[kk])**2

        # Collapse each stage's difference map to a scalar (or a spatial map).
        if(self.lpips):
            if(self.spatial):
                res = [upsample(self.lins[kk].model(diffs[kk]), out_H=in0.shape[2]) for kk in range(self.L)]
            else:
                res = [spatial_average(self.lins[kk].model(diffs[kk]), keepdim=True) for kk in range(self.L)]
        else:
            if(self.spatial):
                res = [upsample(diffs[kk].sum(dim=1,keepdim=True), out_H=in0.shape[2]) for kk in range(self.L)]
            else:
                res = [spatial_average(diffs[kk].sum(dim=1,keepdim=True), keepdim=True) for kk in range(self.L)]

        # Total distance is the sum over stages.
        val = res[0]
        for l in range(1,self.L):
            val += res[l]

        if(retPerLayer):
            return (val, res)
        else:
            return val
|
93 |
+
|
94 |
+
class ScalingLayer(nn.Module):
    """Shifts and scales an image with fixed per-channel constants (the
    normalization used by the LPIPS reference implementation)."""

    def __init__(self):
        super(ScalingLayer, self).__init__()
        # Constant (non-trainable) buffers, broadcast over N, H, W.
        self.register_buffer('shift', torch.Tensor([-.030,-.088,-.188])[None,:,None,None])
        self.register_buffer('scale', torch.Tensor([.458,.448,.450])[None,:,None,None])

    def forward(self, inp):
        shifted = inp - self.shift
        return shifted / self.scale
|
104 |
+
class NetLinLayer(nn.Module):
    ''' A single linear layer which does a 1x1 conv '''

    def __init__(self, chn_in, chn_out=1, use_dropout=False):
        super(NetLinLayer, self).__init__()
        modules = []
        if use_dropout:
            modules.append(nn.Dropout())
        # The 1x1 conv acts as a learned per-channel weighting.
        modules.append(nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False))
        self.model = nn.Sequential(*modules)
|
114 |
+
class Dist2LogitLayer(nn.Module):
    ''' takes 2 distances, puts through fc layers, spits out value between [0,1] (if use_sigmoid is True) '''

    def __init__(self, chn_mid=32, use_sigmoid=True):
        super(Dist2LogitLayer, self).__init__()
        # Five input channels: d0, d1, their difference, and the two ratios.
        net = [
            nn.Conv2d(5, chn_mid, 1, stride=1, padding=0, bias=True),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(chn_mid, chn_mid, 1, stride=1, padding=0, bias=True),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(chn_mid, 1, 1, stride=1, padding=0, bias=True),
        ]
        if use_sigmoid:
            net.append(nn.Sigmoid())
        self.model = nn.Sequential(*net)

    def forward(self, d0, d1, eps=0.1):
        # eps keeps the ratio features finite when a distance is near zero.
        features = torch.cat((d0, d1, d0 - d1, d0 / (d1 + eps), d1 / (d0 + eps)), dim=1)
        return self.model.forward(features)
|
131 |
+
class BCERankingLoss(nn.Module):
    """Binary cross-entropy on a learned logit ranking two distances against
    a human judgement in [-1, 1]."""

    def __init__(self, chn_mid=32):
        super(BCERankingLoss, self).__init__()
        self.net = Dist2LogitLayer(chn_mid=chn_mid)
        # self.parameters = list(self.net.parameters())
        self.loss = torch.nn.BCELoss()

    def forward(self, d0, d1, judge):
        # Map the judgement from [-1, 1] onto a BCE target in [0, 1].
        per = (judge + 1.) / 2.
        self.logit = self.net.forward(d0, d1)
        return self.loss(self.logit, per)
143 |
+
# L2, DSSIM metrics
|
144 |
+
class FakeNet(nn.Module):
    """Base class for the non-learned metrics (L2, DSSIM); only stores the
    device and colorspace configuration."""

    def __init__(self, use_gpu=True, colorspace='Lab'):
        super(FakeNet, self).__init__()
        self.use_gpu = use_gpu
        self.colorspace = colorspace
|
150 |
+
class L2(FakeNet):
    """Mean-squared-error 'metric' with the same call signature as the
    learned networks; operates in RGB or Lab colorspace."""

    def forward(self, in0, in1, retPerLayer=None):
        assert(in0.size()[0]==1) # currently only supports batchSize 1

        if self.colorspace == 'RGB':
            N, C, X, Y = in0.size()
            # Mean over channels, then height, then width -> scalar per image.
            value = torch.mean(torch.mean(torch.mean((in0-in1)**2,dim=1).view(N,1,X,Y),dim=2).view(N,1,1,Y),dim=3).view(N)
            return value
        elif self.colorspace == 'Lab':
            # Convert to Lab on the CPU and compute the L2 there.
            value = util.l2(util.tensor2np(util.tensor2tensorlab(in0.data,to_norm=False)),
                util.tensor2np(util.tensor2tensorlab(in1.data,to_norm=False)), range=100.).astype('float')
            ret_var = Variable(torch.Tensor((value,)))
            if self.use_gpu:
                ret_var = ret_var.cuda()
            return ret_var
|
167 |
+
class DSSIM(FakeNet):
    """Structural-dissimilarity metric wrapper; delegates to util.dssim."""

    def forward(self, in0, in1, retPerLayer=None):
        assert(in0.size()[0]==1) # currently only supports batchSize 1

        if self.colorspace == 'RGB':
            value = util.dssim(1.*util.tensor2im(in0.data), 1.*util.tensor2im(in1.data), range=255.).astype('float')
        elif self.colorspace == 'Lab':
            value = util.dssim(util.tensor2np(util.tensor2tensorlab(in0.data,to_norm=False)),
                util.tensor2np(util.tensor2tensorlab(in1.data,to_norm=False)), range=100.).astype('float')
        ret_var = Variable(torch.Tensor((value,)))
        if self.use_gpu:
            ret_var = ret_var.cuda()
        return ret_var
|
182 |
+
def print_network(net):
    """Print a module's structure followed by its total parameter count."""
    # sum() over a generator replaces the manual accumulation loop.
    num_params = sum(param.numel() for param in net.parameters())
    print('Network', net)
    print('Total number of parameters: %d' % num_params)
|
models/stylegan2/stylegan2-pytorch/lpips/pretrained_networks.py
ADDED
@@ -0,0 +1,181 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from collections import namedtuple
|
2 |
+
import torch
|
3 |
+
from torchvision import models as tv
|
4 |
+
from IPython import embed
|
5 |
+
|
6 |
+
class squeezenet(torch.nn.Module):
    """SqueezeNet 1.1 feature extractor split into 7 sequential slices, one
    per LPIPS feature tap."""

    def __init__(self, requires_grad=False, pretrained=True):
        super(squeezenet, self).__init__()
        pretrained_features = tv.squeezenet1_1(pretrained=pretrained).features
        # (start, end) indices into torchvision's feature stack per slice;
        # original layer indices are kept as submodule names for state_dict
        # compatibility.
        bounds = [(0, 2), (2, 5), (5, 8), (8, 10), (10, 11), (11, 12), (12, 13)]
        slices = []
        for lo, hi in bounds:
            seq = torch.nn.Sequential()
            for idx in range(lo, hi):
                seq.add_module(str(idx), pretrained_features[idx])
            slices.append(seq)
        (self.slice1, self.slice2, self.slice3, self.slice4,
         self.slice5, self.slice6, self.slice7) = slices
        self.N_slices = 7
        if not requires_grad:
            for param in self.parameters():
                param.requires_grad = False

    def forward(self, X):
        taps = []
        h = X
        for seq in (self.slice1, self.slice2, self.slice3, self.slice4,
                    self.slice5, self.slice6, self.slice7):
            h = seq(h)
            taps.append(h)
        vgg_outputs = namedtuple("SqueezeOutputs", ['relu1','relu2','relu3','relu4','relu5','relu6','relu7'])
        out = vgg_outputs(*taps)

        return out
|
56 |
+
|
57 |
+
class alexnet(torch.nn.Module):
    """AlexNet feature extractor split into 5 sequential slices, one per
    LPIPS feature tap."""

    def __init__(self, requires_grad=False, pretrained=True):
        super(alexnet, self).__init__()
        alexnet_pretrained_features = tv.alexnet(pretrained=pretrained).features
        # (start, end) indices into torchvision's feature stack per slice;
        # original layer indices are kept as submodule names.
        bounds = [(0, 2), (2, 5), (5, 8), (8, 10), (10, 12)]
        slices = []
        for lo, hi in bounds:
            seq = torch.nn.Sequential()
            for idx in range(lo, hi):
                seq.add_module(str(idx), alexnet_pretrained_features[idx])
            slices.append(seq)
        self.slice1, self.slice2, self.slice3, self.slice4, self.slice5 = slices
        self.N_slices = 5
        if not requires_grad:
            for param in self.parameters():
                param.requires_grad = False

    def forward(self, X):
        taps = []
        h = X
        for seq in (self.slice1, self.slice2, self.slice3, self.slice4, self.slice5):
            h = seq(h)
            taps.append(h)
        alexnet_outputs = namedtuple("AlexnetOutputs", ['relu1', 'relu2', 'relu3', 'relu4', 'relu5'])
        out = alexnet_outputs(*taps)

        return out
|
97 |
+
class vgg16(torch.nn.Module):
    """VGG-16 feature extractor split into 5 sequential slices, ending at
    relu1_2, relu2_2, relu3_3, relu4_3 and relu5_3 respectively."""

    def __init__(self, requires_grad=False, pretrained=True):
        super(vgg16, self).__init__()
        vgg_pretrained_features = tv.vgg16(pretrained=pretrained).features
        # (start, end) indices into torchvision's feature stack per slice;
        # original layer indices are kept as submodule names.
        bounds = [(0, 4), (4, 9), (9, 16), (16, 23), (23, 30)]
        slices = []
        for lo, hi in bounds:
            seq = torch.nn.Sequential()
            for idx in range(lo, hi):
                seq.add_module(str(idx), vgg_pretrained_features[idx])
            slices.append(seq)
        self.slice1, self.slice2, self.slice3, self.slice4, self.slice5 = slices
        self.N_slices = 5
        if not requires_grad:
            for param in self.parameters():
                param.requires_grad = False

    def forward(self, X):
        taps = []
        h = X
        for seq in (self.slice1, self.slice2, self.slice3, self.slice4, self.slice5):
            h = seq(h)
            taps.append(h)
        vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3'])
        out = vgg_outputs(*taps)

        return out
|
137 |
+
|
138 |
+
|
139 |
+
class resnet(torch.nn.Module):
    """ResNet feature extractor exposing five intermediate activations."""

    def __init__(self, requires_grad=False, pretrained=True, num=18):
        super(resnet, self).__init__()
        # Dispatch table replaces the if/elif chain over supported depths.
        constructors = {
            18: tv.resnet18,
            34: tv.resnet34,
            50: tv.resnet50,
            101: tv.resnet101,
            152: tv.resnet152,
        }
        self.net = constructors[num](pretrained=pretrained)
        self.N_slices = 5

        # Re-expose the stages used as feature taps.
        self.conv1 = self.net.conv1
        self.bn1 = self.net.bn1
        self.relu = self.net.relu
        self.maxpool = self.net.maxpool
        self.layer1 = self.net.layer1
        self.layer2 = self.net.layer2
        self.layer3 = self.net.layer3
        self.layer4 = self.net.layer4

    def forward(self, X):
        h = self.relu(self.bn1(self.conv1(X)))
        h_relu1 = h
        h = self.layer1(self.maxpool(h))
        h_conv2 = h
        h = self.layer2(h)
        h_conv3 = h
        h = self.layer3(h)
        h_conv4 = h
        h = self.layer4(h)
        h_conv5 = h

        outputs = namedtuple("Outputs", ['relu1','conv2','conv3','conv4','conv5'])
        out = outputs(h_relu1, h_conv2, h_conv3, h_conv4, h_conv5)

        return out
|
models/stylegan2/stylegan2-pytorch/lpips/weights/v0.0/alex.pth
ADDED
Binary file (5.46 kB). View file
|
|
models/stylegan2/stylegan2-pytorch/lpips/weights/v0.0/squeeze.pth
ADDED
Binary file (10.1 kB). View file
|
|
models/stylegan2/stylegan2-pytorch/lpips/weights/v0.0/vgg.pth
ADDED
Binary file (6.74 kB). View file
|
|
models/stylegan2/stylegan2-pytorch/lpips/weights/v0.1/alex.pth
ADDED
Binary file (6.01 kB). View file
|
|
models/stylegan2/stylegan2-pytorch/lpips/weights/v0.1/squeeze.pth
ADDED
Binary file (10.8 kB). View file
|
|
models/stylegan2/stylegan2-pytorch/lpips/weights/v0.1/vgg.pth
ADDED
Binary file (7.29 kB). View file
|
|
models/stylegan2/stylegan2-pytorch/model.py
ADDED
@@ -0,0 +1,703 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
import random
|
3 |
+
import functools
|
4 |
+
import operator
|
5 |
+
|
6 |
+
import torch
|
7 |
+
from torch import nn
|
8 |
+
from torch.nn import functional as F
|
9 |
+
from torch.autograd import Function
|
10 |
+
|
11 |
+
from op import FusedLeakyReLU, fused_leaky_relu, upfirdn2d
|
12 |
+
|
13 |
+
|
14 |
+
class PixelNorm(nn.Module):
    """Normalizes each spatial position's feature vector (over the channel
    dimension) to unit RMS."""

    def __init__(self):
        super().__init__()

    def forward(self, input):
        # 1e-8 guards against division by zero for all-zero vectors.
        inv_rms = torch.rsqrt(torch.mean(input ** 2, dim=1, keepdim=True) + 1e-8)
        return input * inv_rms
|
21 |
+
|
22 |
+
def make_kernel(k):
    """Build a normalized 2D FIR kernel from a 1D or 2D tap list."""
    kernel = torch.tensor(k, dtype=torch.float32)

    # A 1D kernel is treated as separable: outer product with itself.
    if kernel.ndim == 1:
        kernel = kernel[None, :] * kernel[:, None]

    # Normalize so the kernel sums to one (unit DC gain).
    kernel /= kernel.sum()

    return kernel
|
32 |
+
|
33 |
+
class Upsample(nn.Module):
    """FIR-filtered upsampling (default 2x) via upfirdn2d."""

    def __init__(self, kernel, factor=2):
        super().__init__()

        self.factor = factor
        # Scale the kernel so overall gain stays 1 after zero-insertion.
        filt = make_kernel(kernel) * (factor ** 2)
        self.register_buffer('kernel', filt)

        pad_total = filt.shape[0] - factor
        self.pad = ((pad_total + 1) // 2 + factor - 1, pad_total // 2)

    def forward(self, input):
        return upfirdn2d(input, self.kernel, up=self.factor, down=1, pad=self.pad)
53 |
+
|
54 |
+
class Downsample(nn.Module):
    """FIR-filtered downsampling (default 2x) via upfirdn2d."""

    def __init__(self, kernel, factor=2):
        super().__init__()

        self.factor = factor
        filt = make_kernel(kernel)
        self.register_buffer('kernel', filt)

        pad_total = filt.shape[0] - factor
        self.pad = ((pad_total + 1) // 2, pad_total // 2)

    def forward(self, input):
        return upfirdn2d(input, self.kernel, up=1, down=self.factor, pad=self.pad)
74 |
+
|
75 |
+
class Blur(nn.Module):
    """Applies a fixed FIR blur kernel, with optional gain compensation when
    used after an upsampling convolution."""

    def __init__(self, kernel, pad, upsample_factor=1):
        super().__init__()

        filt = make_kernel(kernel)
        if upsample_factor > 1:
            # Compensate for the energy spread introduced by upsampling.
            filt = filt * (upsample_factor ** 2)
        self.register_buffer('kernel', filt)

        self.pad = pad

    def forward(self, input):
        return upfirdn2d(input, self.kernel, pad=self.pad)
|
93 |
+
|
94 |
+
class EqualConv2d(nn.Module):
    """Conv2d with equalized learning rate: weights are stored at unit
    variance and rescaled by 1/sqrt(fan_in) on every forward pass."""

    def __init__(
        self, in_channel, out_channel, kernel_size, stride=1, padding=0, bias=True
    ):
        super().__init__()

        self.weight = nn.Parameter(
            torch.randn(out_channel, in_channel, kernel_size, kernel_size)
        )
        # He-style scaling applied at runtime instead of at initialization.
        self.scale = 1 / math.sqrt(in_channel * kernel_size ** 2)

        self.stride = stride
        self.padding = padding

        self.bias = nn.Parameter(torch.zeros(out_channel)) if bias else None

    def forward(self, input):
        return F.conv2d(
            input,
            self.weight * self.scale,
            bias=self.bias,
            stride=self.stride,
            padding=self.padding,
        )

    def __repr__(self):
        w = self.weight.shape
        return (
            f'{self.__class__.__name__}({w[1]}, {w[0]},'
            f' {w[2]}, stride={self.stride}, padding={self.padding})'
        )
|
131 |
+
|
132 |
+
class EqualLinear(nn.Module):
    """Linear layer with equalized learning rate and optional fused
    leaky-ReLU activation.

    Weights are stored divided by ``lr_mul`` and rescaled at runtime by
    ``(1/sqrt(in_dim)) * lr_mul``; the bias is multiplied by ``lr_mul`` on
    the fly.
    """

    def __init__(
        self, in_dim, out_dim, bias=True, bias_init=0, lr_mul=1, activation=None
    ):
        super().__init__()

        self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul))

        if bias:
            self.bias = nn.Parameter(torch.zeros(out_dim).fill_(bias_init))

        else:
            self.bias = None

        self.activation = activation

        self.scale = (1 / math.sqrt(in_dim)) * lr_mul
        self.lr_mul = lr_mul

    def forward(self, input):
        # Bug fix: with bias=False, self.bias is None and the original
        # `self.bias * self.lr_mul` raised a TypeError. Guard the None case.
        bias = self.bias * self.lr_mul if self.bias is not None else None

        if self.activation:
            out = F.linear(input, self.weight * self.scale)
            out = fused_leaky_relu(out, bias)

        else:
            out = F.linear(input, self.weight * self.scale, bias=bias)

        return out

    def __repr__(self):
        return (
            f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]})'
        )
|
168 |
+
|
169 |
+
class ScaledLeakyReLU(nn.Module):
    """Leaky ReLU followed by a sqrt(2) gain to preserve signal variance."""

    def __init__(self, negative_slope=0.2):
        super().__init__()
        self.negative_slope = negative_slope

    def forward(self, input):
        activated = F.leaky_relu(input, negative_slope=self.negative_slope)
        return activated * math.sqrt(2)
|
180 |
+
|
181 |
+
class ModulatedConv2d(nn.Module):
    """StyleGAN2 modulated convolution.

    A per-sample style vector modulates the convolution weights (and
    optionally demodulates them back to unit norm), implemented as a grouped
    convolution with one group per batch sample. Supports optional 2x
    up/downsampling combined with an FIR blur.
    """

    def __init__(
        self,
        in_channel,
        out_channel,
        kernel_size,
        style_dim,
        demodulate=True,
        upsample=False,
        downsample=False,
        blur_kernel=[1, 3, 3, 1],
    ):
        super().__init__()

        self.eps = 1e-8
        self.kernel_size = kernel_size
        self.in_channel = in_channel
        self.out_channel = out_channel
        self.upsample = upsample
        self.downsample = downsample

        if upsample:
            factor = 2
            # Blur padding chosen so the output is exactly factor * input size.
            p = (len(blur_kernel) - factor) - (kernel_size - 1)
            pad0 = (p + 1) // 2 + factor - 1
            pad1 = p // 2 + 1

            self.blur = Blur(blur_kernel, pad=(pad0, pad1), upsample_factor=factor)

        if downsample:
            factor = 2
            p = (len(blur_kernel) - factor) + (kernel_size - 1)
            pad0 = (p + 1) // 2
            pad1 = p // 2

            self.blur = Blur(blur_kernel, pad=(pad0, pad1))

        # Equalized learning rate: scale weights at runtime by 1/sqrt(fan_in).
        fan_in = in_channel * kernel_size ** 2
        self.scale = 1 / math.sqrt(fan_in)
        self.padding = kernel_size // 2

        self.weight = nn.Parameter(
            torch.randn(1, out_channel, in_channel, kernel_size, kernel_size)
        )

        # Maps the style vector to one multiplier per input channel.
        self.modulation = EqualLinear(style_dim, in_channel, bias_init=1)

        self.demodulate = demodulate

    def __repr__(self):
        return (
            f'{self.__class__.__name__}({self.in_channel}, {self.out_channel}, {self.kernel_size}, '
            f'upsample={self.upsample}, downsample={self.downsample})'
        )

    def forward(self, input, style):
        batch, in_channel, height, width = input.shape

        # Modulate: scale each input channel of the weights by the style.
        style = self.modulation(style).view(batch, 1, in_channel, 1, 1)
        weight = self.scale * self.weight * style

        if self.demodulate:
            # Demodulate: renormalize each output filter to unit L2 norm.
            demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-8)
            weight = weight * demod.view(batch, self.out_channel, 1, 1, 1)

        # Fold the batch dimension into the filter dimension so a single
        # grouped conv applies a different weight set per sample.
        weight = weight.view(
            batch * self.out_channel, in_channel, self.kernel_size, self.kernel_size
        )

        if self.upsample:
            input = input.view(1, batch * in_channel, height, width)
            # conv_transpose2d expects (in, out, kH, kW): swap and regroup.
            weight = weight.view(
                batch, self.out_channel, in_channel, self.kernel_size, self.kernel_size
            )
            weight = weight.transpose(1, 2).reshape(
                batch * in_channel, self.out_channel, self.kernel_size, self.kernel_size
            )
            out = F.conv_transpose2d(input, weight, padding=0, stride=2, groups=batch)
            _, _, height, width = out.shape
            out = out.view(batch, self.out_channel, height, width)
            out = self.blur(out)

        elif self.downsample:
            input = self.blur(input)
            _, _, height, width = input.shape
            input = input.view(1, batch * in_channel, height, width)
            out = F.conv2d(input, weight, padding=0, stride=2, groups=batch)
            _, _, height, width = out.shape
            out = out.view(batch, self.out_channel, height, width)

        else:
            input = input.view(1, batch * in_channel, height, width)
            out = F.conv2d(input, weight, padding=self.padding, groups=batch)
            _, _, height, width = out.shape
            out = out.view(batch, self.out_channel, height, width)

        return out
279 |
+
|
280 |
+
class NoiseInjection(nn.Module):
    """Adds per-pixel Gaussian noise scaled by a single learned weight
    (initialized to zero, so the module starts out as an identity)."""

    def __init__(self):
        super().__init__()
        self.weight = nn.Parameter(torch.zeros(1))

    def forward(self, image, noise=None):
        if noise is None:
            batch, _, height, width = image.shape
            # One noise channel, broadcast across all feature channels.
            noise = image.new_empty(batch, 1, height, width).normal_()

        return image + self.weight * noise
|
293 |
+
|
294 |
+
class ConstantInput(nn.Module):
    """Learned constant tensor used as the generator's starting activation;
    the forward argument is only consulted for its batch size."""

    def __init__(self, channel, size=4):
        super().__init__()
        self.input = nn.Parameter(torch.randn(1, channel, size, size))

    def forward(self, input):
        batch = input.shape[0]
        return self.input.repeat(batch, 1, 1, 1)
|
306 |
+
|
307 |
+
class StyledConv(nn.Module):
    """One StyleGAN2 synthesis layer: modulated convolution, noise
    injection, then a fused leaky-ReLU activation."""

    def __init__(
        self,
        in_channel,
        out_channel,
        kernel_size,
        style_dim,
        upsample=False,
        blur_kernel=[1, 3, 3, 1],
        demodulate=True,
    ):
        super().__init__()

        self.conv = ModulatedConv2d(
            in_channel,
            out_channel,
            kernel_size,
            style_dim,
            upsample=upsample,
            blur_kernel=blur_kernel,
            demodulate=demodulate,
        )

        self.noise = NoiseInjection()
        # self.bias = nn.Parameter(torch.zeros(1, out_channel, 1, 1))
        # self.activate = ScaledLeakyReLU(0.2)
        self.activate = FusedLeakyReLU(out_channel)

    def forward(self, input, style, noise=None):
        out = self.noise(self.conv(input, style), noise=noise)
        # out = out + self.bias
        return self.activate(out)
|
343 |
+
|
344 |
+
class ToRGB(nn.Module):
    """1x1 modulated conv projecting features to RGB, optionally upsampling
    and accumulating a skip connection from the previous resolution."""

    def __init__(self, in_channel, style_dim, upsample=True, blur_kernel=[1, 3, 3, 1]):
        super().__init__()

        if upsample:
            self.upsample = Upsample(blur_kernel)

        # No demodulation for the RGB projection.
        self.conv = ModulatedConv2d(in_channel, 3, 1, style_dim, demodulate=False)
        self.bias = nn.Parameter(torch.zeros(1, 3, 1, 1))

    def forward(self, input, style, skip=None):
        out = self.conv(input, style) + self.bias

        if skip is not None:
            # NOTE(review): assumes upsample=True whenever a skip is passed.
            out = out + self.upsample(skip)

        return out
365 |
+
# Wrapper that gives name to tensor
|
366 |
+
class NamedTensor(nn.Module):
    """Identity module; exists only so the tensor appears as a named node."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        return x
|
373 |
+
# Give each style a unique name
|
374 |
+
class StridedStyle(nn.ModuleList):
    """Holds one NamedTensor per latent slot so each style slice gets a
    unique node name; forward re-stacks the slices unchanged."""

    def __init__(self, n_latents):
        super().__init__([NamedTensor() for _ in range(n_latents)])
        self.n_latents = n_latents

    def forward(self, x):
        # x is already strided: [N, n_latents, style_dim]
        per_slot = [module(x[:, idx, :]) for idx, module in enumerate(self)]
        return torch.stack(per_slot, dim=1)
|
384 |
+
class Generator(nn.Module):
|
385 |
+
def __init__(
|
386 |
+
self,
|
387 |
+
size,
|
388 |
+
style_dim,
|
389 |
+
n_mlp,
|
390 |
+
channel_multiplier=2,
|
391 |
+
blur_kernel=[1, 3, 3, 1],
|
392 |
+
lr_mlp=0.01,
|
393 |
+
):
|
394 |
+
super().__init__()
|
395 |
+
|
396 |
+
self.size = size
|
397 |
+
|
398 |
+
self.style_dim = style_dim
|
399 |
+
|
400 |
+
layers = [PixelNorm()]
|
401 |
+
|
402 |
+
for i in range(n_mlp):
|
403 |
+
layers.append(
|
404 |
+
EqualLinear(
|
405 |
+
style_dim, style_dim, lr_mul=lr_mlp, activation='fused_lrelu'
|
406 |
+
)
|
407 |
+
)
|
408 |
+
|
409 |
+
self.style = nn.Sequential(*layers)
|
410 |
+
|
411 |
+
self.channels = {
|
412 |
+
4: 512,
|
413 |
+
8: 512,
|
414 |
+
16: 512,
|
415 |
+
32: 512,
|
416 |
+
64: 256 * channel_multiplier,
|
417 |
+
128: 128 * channel_multiplier,
|
418 |
+
256: 64 * channel_multiplier,
|
419 |
+
512: 32 * channel_multiplier,
|
420 |
+
1024: 16 * channel_multiplier,
|
421 |
+
}
|
422 |
+
|
423 |
+
self.input = ConstantInput(self.channels[4])
|
424 |
+
self.conv1 = StyledConv(
|
425 |
+
self.channels[4], self.channels[4], 3, style_dim, blur_kernel=blur_kernel
|
426 |
+
)
|
427 |
+
self.to_rgb1 = ToRGB(self.channels[4], style_dim, upsample=False)
|
428 |
+
|
429 |
+
self.log_size = int(math.log(size, 2))
|
430 |
+
self.num_layers = (self.log_size - 2) * 2 + 1
|
431 |
+
|
432 |
+
self.convs = nn.ModuleList()
|
433 |
+
self.upsamples = nn.ModuleList()
|
434 |
+
self.to_rgbs = nn.ModuleList()
|
435 |
+
self.noises = nn.Module()
|
436 |
+
|
437 |
+
in_channel = self.channels[4]
|
438 |
+
|
439 |
+
for layer_idx in range(self.num_layers):
|
440 |
+
res = (layer_idx + 5) // 2
|
441 |
+
shape = [1, 1, 2 ** res, 2 ** res]
|
442 |
+
self.noises.register_buffer(f'noise_{layer_idx}', torch.randn(*shape))
|
443 |
+
|
444 |
+
for i in range(3, self.log_size + 1):
|
445 |
+
out_channel = self.channels[2 ** i]
|
446 |
+
|
447 |
+
self.convs.append(
|
448 |
+
StyledConv(
|
449 |
+
in_channel,
|
450 |
+
out_channel,
|
451 |
+
3,
|
452 |
+
style_dim,
|
453 |
+
upsample=True,
|
454 |
+
blur_kernel=blur_kernel,
|
455 |
+
)
|
456 |
+
)
|
457 |
+
|
458 |
+
self.convs.append(
|
459 |
+
StyledConv(
|
460 |
+
out_channel, out_channel, 3, style_dim, blur_kernel=blur_kernel
|
461 |
+
)
|
462 |
+
)
|
463 |
+
|
464 |
+
self.to_rgbs.append(ToRGB(out_channel, style_dim))
|
465 |
+
|
466 |
+
in_channel = out_channel
|
467 |
+
|
468 |
+
self.n_latent = self.log_size * 2 - 2
|
469 |
+
self.strided_style = StridedStyle(self.n_latent)
|
470 |
+
|
471 |
+
def make_noise(self):
|
472 |
+
device = self.input.input.device
|
473 |
+
|
474 |
+
noises = [torch.randn(1, 1, 2 ** 2, 2 ** 2, device=device)]
|
475 |
+
|
476 |
+
for i in range(3, self.log_size + 1):
|
477 |
+
for _ in range(2):
|
478 |
+
noises.append(torch.randn(1, 1, 2 ** i, 2 ** i, device=device))
|
479 |
+
|
480 |
+
return noises
|
481 |
+
|
482 |
+
def mean_latent(self, n_latent):
|
483 |
+
latent_in = torch.randn(
|
484 |
+
n_latent, self.style_dim, device=self.input.input.device
|
485 |
+
)
|
486 |
+
latent = self.style(latent_in).mean(0, keepdim=True)
|
487 |
+
|
488 |
+
return latent
|
489 |
+
|
490 |
+
def get_latent(self, input):
|
491 |
+
return self.style(input)
|
492 |
+
|
493 |
+
def forward(
    self,
    styles,
    return_latents=False,
    inject_index=None,
    truncation=1,
    truncation_latent=None,
    input_is_w=False,
    noise=None,
    randomize_noise=True,
):
    """Synthesize an image from one or more latent codes.

    Args:
        styles: list of latent tensors. One entry = a single global latent
            (broadcast to every layer); two entries = style mixing at
            `inject_index`; otherwise exactly `self.n_latent` entries, one
            per layer.
        return_latents: if True, also return the per-layer latent tensor.
        inject_index: crossover layer for two-latent mixing (random if None).
        truncation / truncation_latent: truncation trick; latents are pulled
            toward `truncation_latent` when truncation < 1.
        input_is_w: if True, `styles` are already W-space vectors and the
            mapping network is skipped.
        noise: explicit per-layer noise list, or None.
        randomize_noise: when `noise` is None, use fresh noise (True) or the
            registered fixed buffers (False).

    Returns:
        (image, latent) where latent is None unless return_latents is True.
    """
    if not input_is_w:
        styles = [self.style(s) for s in styles]

    if noise is None:
        if randomize_noise:
            noise = [None] * self.num_layers
        else:
            noise = [
                getattr(self.noises, f'noise_{i}') for i in range(self.num_layers)
            ]

    if truncation < 1:
        style_t = []
        for style in styles:
            style_t.append(
                truncation_latent + truncation * (style - truncation_latent)
            )
        styles = style_t

    if len(styles) == 1:
        # One global latent
        inject_index = self.n_latent
        if styles[0].ndim < 3:
            latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
        else:
            latent = styles[0]
    elif len(styles) == 2:
        # Latent mixing with two latents
        if inject_index is None:
            inject_index = random.randint(1, self.n_latent - 1)

        latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
        latent2 = styles[1].unsqueeze(1).repeat(1, self.n_latent - inject_index, 1)

        latent = self.strided_style(torch.cat([latent, latent2], 1))
    else:
        # One latent per layer.
        # Bug fix: the message referenced nonexistent `self.n_latents`, so a
        # failed assert raised AttributeError instead of AssertionError.
        assert len(styles) == self.n_latent, \
            f'Expected {self.n_latent} latents, got {len(styles)}'
        styles = torch.stack(styles, dim=1)  # [N, 18, 512]
        latent = self.strided_style(styles)

    out = self.input(latent)
    out = self.conv1(out, latent[:, 0], noise=noise[0])

    skip = self.to_rgb1(out, latent[:, 1])

    # Walk the (upsample conv, plain conv, toRGB) triples; each consumes
    # two latents for the convs and one for the skip-RGB accumulation.
    i = 1
    for conv1, conv2, noise1, noise2, to_rgb in zip(
        self.convs[::2], self.convs[1::2], noise[1::2], noise[2::2], self.to_rgbs
    ):
        out = conv1(out, latent[:, i], noise=noise1)
        out = conv2(out, latent[:, i + 1], noise=noise2)
        skip = to_rgb(out, latent[:, i + 2], skip)

        i += 2

    image = skip

    if return_latents:
        return image, latent
    else:
        return image, None
|
572 |
+
|
573 |
+
|
574 |
+
class ConvLayer(nn.Sequential):
    """Convolution block: optional blur + stride-2 downsample, an equalized
    conv, and an optional (fused) leaky-ReLU activation."""

    def __init__(
        self,
        in_channel,
        out_channel,
        kernel_size,
        downsample=False,
        blur_kernel=[1, 3, 3, 1],
        bias=True,
        activate=True,
    ):
        modules = []

        if downsample:
            # Anti-aliasing blur placed before the strided convolution.
            factor = 2
            p = (len(blur_kernel) - factor) + (kernel_size - 1)
            modules.append(Blur(blur_kernel, pad=((p + 1) // 2, p // 2)))
            stride = 2
            self.padding = 0
        else:
            stride = 1
            self.padding = kernel_size // 2

        modules.append(
            EqualConv2d(
                in_channel,
                out_channel,
                kernel_size,
                padding=self.padding,
                stride=stride,
                # When a fused activation follows, it supplies its own bias.
                bias=bias and not activate,
            )
        )

        if activate:
            modules.append(
                FusedLeakyReLU(out_channel) if bias else ScaledLeakyReLU(0.2)
            )

        super().__init__(*modules)
|
621 |
+
|
622 |
+
|
623 |
+
class ResBlock(nn.Module):
    """Residual downsampling block: two convs plus a linear 1x1 downsampled
    skip connection, with the sum rescaled by 1/sqrt(2)."""

    def __init__(self, in_channel, out_channel, blur_kernel=[1, 3, 3, 1]):
        super().__init__()

        self.conv1 = ConvLayer(in_channel, in_channel, 3)
        self.conv2 = ConvLayer(in_channel, out_channel, 3, downsample=True)

        # 1x1 projection so the shortcut matches shape and channel count.
        self.skip = ConvLayer(
            in_channel, out_channel, 1, downsample=True, activate=False, bias=False
        )

    def forward(self, input):
        branch = self.conv2(self.conv1(input))
        shortcut = self.skip(input)
        # Divide by sqrt(2) to keep the variance of the sum stable.
        return (branch + shortcut) / math.sqrt(2)
|
642 |
+
|
643 |
+
|
644 |
+
class Discriminator(nn.Module):
    """StyleGAN2 discriminator: fromRGB, residual downsampling to 4x4,
    minibatch-stddev channel, then a final conv + two linear layers
    producing one realness score per image."""

    def __init__(self, size, channel_multiplier=2, blur_kernel=[1, 3, 3, 1]):
        super().__init__()

        # Feature width per spatial resolution; widths above 32x32 scale
        # with channel_multiplier.
        channels = {
            4: 512,
            8: 512,
            16: 512,
            32: 512,
            64: 256 * channel_multiplier,
            128: 128 * channel_multiplier,
            256: 64 * channel_multiplier,
            512: 32 * channel_multiplier,
            1024: 16 * channel_multiplier,
        }

        # 1x1 fromRGB conv at the input resolution.
        convs = [ConvLayer(3, channels[size], 1)]

        log_size = int(math.log(size, 2))

        in_channel = channels[size]

        # One ResBlock per halving of resolution, down to 4x4.
        for i in range(log_size, 2, -1):
            out_channel = channels[2 ** (i - 1)]

            convs.append(ResBlock(in_channel, out_channel, blur_kernel))

            in_channel = out_channel

        self.convs = nn.Sequential(*convs)

        # Minibatch standard deviation settings: group size and number of
        # feature splits used when computing the statistics channel.
        self.stddev_group = 4
        self.stddev_feat = 1

        # +1 input channel for the appended stddev statistics map.
        self.final_conv = ConvLayer(in_channel + 1, channels[4], 3)
        self.final_linear = nn.Sequential(
            EqualLinear(channels[4] * 4 * 4, channels[4], activation='fused_lrelu'),
            EqualLinear(channels[4], 1),
        )

    def forward(self, input):
        # input: (batch, 3, size, size) image batch -> (batch, 1) scores.
        out = self.convs(input)

        batch, channel, height, width = out.shape
        # Minibatch stddev: per-group standard deviation over the batch,
        # averaged into a single map and appended as an extra channel.
        group = min(batch, self.stddev_group)
        stddev = out.view(
            group, -1, self.stddev_feat, channel // self.stddev_feat, height, width
        )
        stddev = torch.sqrt(stddev.var(0, unbiased=False) + 1e-8)
        # NOTE(review): `keepdims` is the NumPy-style alias of torch's
        # `keepdim` — confirm the installed torch version accepts it.
        stddev = stddev.mean([2, 3, 4], keepdims=True).squeeze(2)
        stddev = stddev.repeat(group, 1, height, width)
        out = torch.cat([out, stddev], 1)

        out = self.final_conv(out)

        out = out.view(batch, -1)
        out = self.final_linear(out)

        return out
|
703 |
+
|
models/stylegan2/stylegan2-pytorch/non_leaking.py
ADDED
@@ -0,0 +1,137 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
|
3 |
+
import torch
|
4 |
+
from torch.nn import functional as F
|
5 |
+
|
6 |
+
|
7 |
+
def translate_mat(t_x, t_y):
    """Build a batch of 3x3 homogeneous translation matrices.

    t_x, t_y: 1-D tensors of per-sample offsets (same length).
    """
    batch = t_x.shape[0]

    mat = torch.eye(3).unsqueeze(0).repeat(batch, 1, 1)
    # Translation goes into the last column of the top two rows.
    mat[:, :2, 2] = torch.stack((t_x, t_y), 1)

    return mat
|
15 |
+
|
16 |
+
|
17 |
+
def rotate_mat(theta):
    """Build a batch of 3x3 homogeneous rotation matrices from angles `theta`."""
    batch = theta.shape[0]

    mat = torch.eye(3).unsqueeze(0).repeat(batch, 1, 1)
    cos_t = torch.cos(theta)
    sin_t = torch.sin(theta)
    # 2x2 rotation block [[cos, -sin], [sin, cos]] per sample.
    mat[:, :2, :2] = torch.stack((cos_t, -sin_t, sin_t, cos_t), 1).view(batch, 2, 2)

    return mat
|
27 |
+
|
28 |
+
|
29 |
+
def scale_mat(s_x, s_y):
    """Build a batch of 3x3 homogeneous scaling matrices with per-axis factors."""
    batch = s_x.shape[0]

    mat = torch.eye(3).unsqueeze(0).repeat(batch, 1, 1)
    mat[:, 0, 0], mat[:, 1, 1] = s_x, s_y

    return mat
|
37 |
+
|
38 |
+
|
39 |
+
def lognormal_sample(size, mean=0, std=1):
    """Draw `size` log-normal samples (underlying normal has mean/std)."""
    out = torch.empty(size)
    return out.log_normal_(mean=mean, std=std)
|
41 |
+
|
42 |
+
|
43 |
+
def category_sample(size, categories):
    """Draw `size` values uniformly at random from the discrete `categories`."""
    choices = torch.tensor(categories)
    idx = torch.randint(high=len(categories), size=(size,))
    return choices[idx]
|
48 |
+
|
49 |
+
|
50 |
+
def uniform_sample(size, low, high):
    """Draw `size` samples uniformly from [low, high)."""
    out = torch.empty(size)
    out.uniform_(low, high)
    return out
|
52 |
+
|
53 |
+
|
54 |
+
def normal_sample(size, mean=0, std=1):
    """Draw `size` normal samples with the given mean and std."""
    out = torch.empty(size)
    out.normal_(mean, std)
    return out
|
56 |
+
|
57 |
+
|
58 |
+
def bernoulli_sample(size, p):
    """Draw `size` Bernoulli(p) samples as a float tensor of 0s and 1s."""
    out = torch.empty(size)
    out.bernoulli_(p)
    return out
|
60 |
+
|
61 |
+
|
62 |
+
def random_affine_apply(p, transform, prev, eye):
    """Per sample, left-compose `transform` onto `prev` with probability p,
    otherwise compose the identity `eye`."""
    n = transform.shape[0]
    mask = bernoulli_sample(n, p).view(n, 1, 1)
    chosen = mask * transform + (1 - mask) * eye

    return chosen @ prev
|
68 |
+
|
69 |
+
|
70 |
+
def sample_affine(p, size, height, width):
    """Sample a batch of random 3x3 affine matrices (ADA-style augmentation).

    Each geometric transform below is applied independently with probability
    `p` per sample; the two free rotations use a split probability `p_rot`
    so that the combined chance of any rotation is `p`.
    """
    eye = torch.eye(3).unsqueeze(0).repeat(size, 1, 1)
    G = eye

    # Horizontal flip (scale x by -1 half of the time when selected).
    param = category_sample(size, (0, 1))
    G = random_affine_apply(p, scale_mat(1 - 2.0 * param, torch.ones(size)), G, eye)

    # Rotation by a random multiple of 90 degrees.
    param = category_sample(size, (0, 3))
    G = random_affine_apply(p, rotate_mat(-math.pi / 2 * param), G, eye)

    # Pixel-aligned (integer) translation up to 12.5% of the image size.
    param = uniform_sample(size, -0.125, 0.125)
    param_height = torch.round(param * height) / height
    param_width = torch.round(param * width) / width
    G = random_affine_apply(p, translate_mat(param_width, param_height), G, eye)

    # Isotropic scaling, log-normal around 1.
    param = lognormal_sample(size, std=0.2 * math.log(2))
    G = random_affine_apply(p, scale_mat(param, param), G, eye)

    # Split probability: applying pre- and post-rotation each with p_rot
    # yields an overall rotation probability of p.
    p_rot = 1 - math.sqrt(1 - p)

    # Pre-rotation.
    param = uniform_sample(size, -math.pi, math.pi)
    G = random_affine_apply(p_rot, rotate_mat(-param), G, eye)

    # Anisotropic (area-preserving) scaling.
    param = lognormal_sample(size, std=0.2 * math.log(2))
    G = random_affine_apply(p, scale_mat(param, 1 / param), G, eye)

    # Post-rotation.
    param = uniform_sample(size, -math.pi, math.pi)
    G = random_affine_apply(p_rot, rotate_mat(-param), G, eye)

    # Sub-pixel (fractional) translation.
    param = normal_sample(size, std=0.125)
    G = random_affine_apply(p, translate_mat(param, param), G, eye)

    return G
|
127 |
+
|
128 |
+
|
129 |
+
def apply_affine(img, G):
    """Warp `img` by the inverse of the affine batch `G` using bilinear
    grid sampling with reflection padding."""
    theta = torch.inverse(G).to(img)[:, :2, :]
    grid = F.affine_grid(theta, img.shape, align_corners=False)
    return F.grid_sample(
        img, grid, mode="bilinear", align_corners=False, padding_mode="reflection"
    )
|
models/stylegan2/stylegan2-pytorch/op/__init__.py
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
from .fused_act import FusedLeakyReLU, fused_leaky_relu
|
2 |
+
from .upfirdn2d import upfirdn2d
|
models/stylegan2/stylegan2-pytorch/op/fused_act.py
ADDED
@@ -0,0 +1,92 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import platform
|
3 |
+
|
4 |
+
import torch
|
5 |
+
from torch import nn
|
6 |
+
from torch.autograd import Function
|
7 |
+
import torch.nn.functional as F
|
8 |
+
from torch.utils.cpp_extension import load
|
9 |
+
|
10 |
+
# Flipped to True when the compiled CUDA extension is unavailable;
# fused_leaky_relu() then uses the pure-PyTorch implementation instead.
use_fallback = False

# Try loading precompiled, otherwise use native fallback
try:
    import fused
except ModuleNotFoundError as e:
    print('StyleGAN2: Optimized CUDA op FusedLeakyReLU not available, using native PyTorch fallback.')
    use_fallback = True
|
18 |
+
|
19 |
+
|
20 |
+
class FusedLeakyReLUFunctionBackward(Function):
    """Backward pass of FusedLeakyReLUFunction, implemented as its own
    autograd Function so that double backpropagation works through the
    fused CUDA op."""

    @staticmethod
    def forward(ctx, grad_output, out, negative_slope, scale):
        # `out` (the saved forward output) is passed as the reference tensor;
        # the CUDA op (act=3, grad=1) uses its sign to pick the slope.
        ctx.save_for_backward(out)
        ctx.negative_slope = negative_slope
        ctx.scale = scale

        empty = grad_output.new_empty(0)

        grad_input = fused.fused_bias_act(
            grad_output, empty, out, 3, 1, negative_slope, scale
        )

        # Bias gradient: sum over every dimension except channels (dim 1).
        dim = [0]

        if grad_input.ndim > 2:
            dim += list(range(2, grad_input.ndim))

        grad_bias = grad_input.sum(dim).detach()

        return grad_input, grad_bias

    @staticmethod
    def backward(ctx, gradgrad_input, gradgrad_bias):
        out, = ctx.saved_tensors
        # Second-order gradient reuses the same masked-slope op (act=3, grad=1).
        gradgrad_out = fused.fused_bias_act(
            gradgrad_input, gradgrad_bias, out, 3, 1, ctx.negative_slope, ctx.scale
        )

        return gradgrad_out, None, None, None
|
50 |
+
|
51 |
+
|
52 |
+
class FusedLeakyReLUFunction(Function):
    """Autograd wrapper for the fused bias + leaky ReLU + scale CUDA op."""

    @staticmethod
    def forward(ctx, input, bias, negative_slope, scale):
        empty = input.new_empty(0)
        # act=3 selects leaky ReLU; grad=0 is the forward evaluation.
        out = fused.fused_bias_act(input, bias, empty, 3, 0, negative_slope, scale)
        # Save the output, not the input: its sign tells the backward op
        # which slope was applied at each element.
        ctx.save_for_backward(out)
        ctx.negative_slope = negative_slope
        ctx.scale = scale

        return out

    @staticmethod
    def backward(ctx, grad_output):
        out, = ctx.saved_tensors

        grad_input, grad_bias = FusedLeakyReLUFunctionBackward.apply(
            grad_output, out, ctx.negative_slope, ctx.scale
        )

        return grad_input, grad_bias, None, None
|
72 |
+
|
73 |
+
|
74 |
+
class FusedLeakyReLU(nn.Module):
    """Leaky ReLU with a learned per-channel bias and output scaling,
    backed by the fused CUDA kernel when available."""

    def __init__(self, channel, negative_slope=0.2, scale=2 ** 0.5):
        super().__init__()

        # One learnable bias value per channel, initialized to zero.
        self.bias = nn.Parameter(torch.zeros(channel))
        self.negative_slope, self.scale = negative_slope, scale

    def forward(self, input):
        return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale)
|
84 |
+
|
85 |
+
|
86 |
+
def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2 ** 0.5):
    """Apply bias add + leaky ReLU + scaling.

    Routes to the compiled CUDA op when available and the input is on GPU;
    otherwise uses the equivalent native PyTorch expression.
    """
    if use_fallback or input.device.type == 'cpu':
        # Reshape bias to (1, C, 1, ..., 1) so it broadcasts over the input.
        bias_shape = (1, -1) + (1,) * (input.ndim - 2)
        return scale * F.leaky_relu(
            input + bias.view(bias_shape), negative_slope=negative_slope
        )
    else:
        return FusedLeakyReLUFunction.apply(input, bias, negative_slope, scale)
|
models/stylegan2/stylegan2-pytorch/op/fused_bias_act.cpp
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#include <torch/extension.h>
|
2 |
+
|
3 |
+
|
4 |
+
torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer,
|
5 |
+
int act, int grad, float alpha, float scale);
|
6 |
+
|
7 |
+
#define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
|
8 |
+
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
|
9 |
+
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
|
10 |
+
|
11 |
+
// Python-facing wrapper: validates that the tensors live on the GPU, then
// dispatches to the CUDA implementation in fused_bias_act_kernel.cu.
// Contiguity is not checked here; the op itself calls .contiguous().
torch::Tensor fused_bias_act(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer,
    int act, int grad, float alpha, float scale) {
    CHECK_CUDA(input);
    CHECK_CUDA(bias);

    return fused_bias_act_op(input, bias, refer, act, grad, alpha, scale);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("fused_bias_act", &fused_bias_act, "fused bias act (CUDA)");
}
|
models/stylegan2/stylegan2-pytorch/op/fused_bias_act_kernel.cu
ADDED
@@ -0,0 +1,99 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
// Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
|
2 |
+
//
|
3 |
+
// This work is made available under the Nvidia Source Code License-NC.
|
4 |
+
// To view a copy of this license, visit
|
5 |
+
// https://nvlabs.github.io/stylegan2/license.html
|
6 |
+
|
7 |
+
#include <torch/types.h>
|
8 |
+
|
9 |
+
#include <ATen/ATen.h>
|
10 |
+
#include <ATen/AccumulateType.h>
|
11 |
+
#include <ATen/cuda/CUDAContext.h>
|
12 |
+
#include <ATen/cuda/CUDAApplyUtils.cuh>
|
13 |
+
|
14 |
+
#include <cuda.h>
|
15 |
+
#include <cuda_runtime.h>
|
16 |
+
|
17 |
+
|
18 |
+
// Elementwise fused bias-add + activation over a flattened tensor.
// `act * 10 + grad` selects the operation: act 1 = linear, act 3 = leaky
// ReLU; grad 0 = forward value, grad 1 = first derivative (uses the sign of
// the saved forward output in p_ref), grad 2 = second derivative (zero for
// these piecewise-linear activations). The result is multiplied by `scale`.
template <typename scalar_t>
static __global__ void fused_bias_act_kernel(scalar_t* out, const scalar_t* p_x, const scalar_t* p_b, const scalar_t* p_ref,
    int act, int grad, scalar_t alpha, scalar_t scale, int loop_x, int size_x, int step_b, int size_b, int use_bias, int use_ref) {
    int xi = blockIdx.x * loop_x * blockDim.x + threadIdx.x;

    scalar_t zero = 0.0;

    // Each thread handles up to loop_x elements, striding by blockDim.x.
    for (int loop_idx = 0; loop_idx < loop_x && xi < size_x; loop_idx++, xi += blockDim.x) {
        scalar_t x = p_x[xi];

        if (use_bias) {
            // Bias broadcasts along the channel dim: step_b elements share one bias.
            x += p_b[(xi / step_b) % size_b];
        }

        scalar_t ref = use_ref ? p_ref[xi] : zero;

        scalar_t y;

        switch (act * 10 + grad) {
        default:
        case 10: y = x; break;                            // linear, forward
        case 11: y = x; break;                            // linear, 1st derivative
        case 12: y = 0.0; break;                          // linear, 2nd derivative

        case 30: y = (x > 0.0) ? x : x * alpha; break;    // leaky ReLU, forward
        case 31: y = (ref > 0.0) ? x : x * alpha; break;  // leaky ReLU, 1st derivative
        case 32: y = 0.0; break;                          // leaky ReLU, 2nd derivative
        }

        out[xi] = y * scale;
    }
}
|
50 |
+
|
51 |
+
|
52 |
+
// Launch the fused bias + activation kernel over a flattened view of `input`.
// `bias` broadcasts along dimension 1 (channels); `refer` is the saved
// forward output used when computing gradients (may be empty).
torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer,
    int act, int grad, float alpha, float scale) {
    int curDevice = -1;
    cudaGetDevice(&curDevice);
    cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice);

    auto x = input.contiguous();
    auto b = bias.contiguous();
    auto ref = refer.contiguous();

    // Empty bias/refer tensors disable the corresponding code paths.
    int use_bias = b.numel() ? 1 : 0;
    int use_ref = ref.numel() ? 1 : 0;

    int size_x = x.numel();
    int size_b = b.numel();
    int step_b = 1;

    // step_b = product of all dims after the channel dim (dim 1), i.e. how
    // many consecutive elements share one bias value in the flat layout.
    for (int i = 1 + 1; i < x.dim(); i++) {
        step_b *= x.size(i);
    }

    int loop_x = 4;
    int block_size = 4 * 32;
    int grid_size = (size_x - 1) / (loop_x * block_size) + 1;

    auto y = torch::empty_like(x);

    AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "fused_bias_act_kernel", [&] {
        fused_bias_act_kernel<scalar_t><<<grid_size, block_size, 0, stream>>>(
            y.data_ptr<scalar_t>(),
            x.data_ptr<scalar_t>(),
            b.data_ptr<scalar_t>(),
            ref.data_ptr<scalar_t>(),
            act,
            grad,
            alpha,
            scale,
            loop_x,
            size_x,
            step_b,
            size_b,
            use_bias,
            use_ref
        );
    });

    return y;
}
|
models/stylegan2/stylegan2-pytorch/op/setup.py
ADDED
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from setuptools import setup
from torch.utils.cpp_extension import CUDAExtension, BuildExtension
from pathlib import Path

# Usage:
# python setup.py install (or python setup.py bdist_wheel)
# NB: Windows: run from VS2017 x64 Native Tool Command Prompt

# C++/CUDA sources live in the sibling `op` directory.
rootdir = (Path(__file__).parent / '..' / 'op').resolve()

# Build the upfirdn2d (upsample - FIR filter - downsample) CUDA extension.
setup(
    name='upfirdn2d',
    ext_modules=[
        CUDAExtension('upfirdn2d_op',
            [str(rootdir / 'upfirdn2d.cpp'), str(rootdir / 'upfirdn2d_kernel.cu')],
        )
    ],
    cmdclass={
        'build_ext': BuildExtension
    }
)

# Build the fused bias + leaky ReLU activation CUDA extension.
setup(
    name='fused',
    ext_modules=[
        CUDAExtension('fused',
            [str(rootdir / 'fused_bias_act.cpp'), str(rootdir / 'fused_bias_act_kernel.cu')],
        )
    ],
    cmdclass={
        'build_ext': BuildExtension
    }
)
|
models/stylegan2/stylegan2-pytorch/op/upfirdn2d.cpp
ADDED
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#include <torch/extension.h>
|
2 |
+
|
3 |
+
|
4 |
+
torch::Tensor upfirdn2d_op(const torch::Tensor& input, const torch::Tensor& kernel,
|
5 |
+
int up_x, int up_y, int down_x, int down_y,
|
6 |
+
int pad_x0, int pad_x1, int pad_y0, int pad_y1);
|
7 |
+
|
8 |
+
#define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
|
9 |
+
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
|
10 |
+
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
|
11 |
+
|
12 |
+
// Python-facing wrapper: checks that both tensors are CUDA tensors, then
// forwards to the CUDA implementation (defined in upfirdn2d_kernel.cu).
torch::Tensor upfirdn2d(const torch::Tensor& input, const torch::Tensor& kernel,
    int up_x, int up_y, int down_x, int down_y,
    int pad_x0, int pad_x1, int pad_y0, int pad_y1) {
    CHECK_CUDA(input);
    CHECK_CUDA(kernel);

    return upfirdn2d_op(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)");
}
|
models/stylegan2/stylegan2-pytorch/op/upfirdn2d.py
ADDED
@@ -0,0 +1,198 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import platform
|
3 |
+
|
4 |
+
import torch
|
5 |
+
from torch.nn import functional as F
|
6 |
+
from torch.autograd import Function
|
7 |
+
from torch.utils.cpp_extension import load
|
8 |
+
|
9 |
+
# Flipped to True when the compiled CUDA extension is missing;
# upfirdn2d() then routes through the pure-PyTorch upfirdn2d_native().
use_fallback = False

# Try loading precompiled, otherwise use native fallback
try:
    import upfirdn2d_op
except ModuleNotFoundError as e:
    print('StyleGAN2: Optimized CUDA op UpFirDn2d not available, using native PyTorch fallback.')
    use_fallback = True
|
17 |
+
|
18 |
+
class UpFirDn2dBackward(Function):
    """Gradient of UpFirDn2d.

    The adjoint of upsample/FIR/downsample is itself an upfirdn2d call with
    up and down factors swapped, the spatially flipped kernel, and the
    complementary padding `g_pad` computed in UpFirDn2d.forward.
    """

    @staticmethod
    def forward(
        ctx, grad_output, kernel, grad_kernel, up, down, pad, g_pad, in_size, out_size
    ):

        up_x, up_y = up
        down_x, down_y = down
        g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1 = g_pad

        # CUDA op expects (N*C, H, W, 1) layout.
        grad_output = grad_output.reshape(-1, out_size[0], out_size[1], 1)

        # up/down are swapped and grad_kernel is the flipped kernel: this is
        # the transpose of the forward resampling.
        grad_input = upfirdn2d_op.upfirdn2d(
            grad_output,
            grad_kernel,
            down_x,
            down_y,
            up_x,
            up_y,
            g_pad_x0,
            g_pad_x1,
            g_pad_y0,
            g_pad_y1,
        )
        grad_input = grad_input.view(in_size[0], in_size[1], in_size[2], in_size[3])

        # Save the (unflipped) kernel for the double-backward pass.
        ctx.save_for_backward(kernel)

        pad_x0, pad_x1, pad_y0, pad_y1 = pad

        ctx.up_x = up_x
        ctx.up_y = up_y
        ctx.down_x = down_x
        ctx.down_y = down_y
        ctx.pad_x0 = pad_x0
        ctx.pad_x1 = pad_x1
        ctx.pad_y0 = pad_y0
        ctx.pad_y1 = pad_y1
        ctx.in_size = in_size
        ctx.out_size = out_size

        return grad_input

    @staticmethod
    def backward(ctx, gradgrad_input):
        kernel, = ctx.saved_tensors

        gradgrad_input = gradgrad_input.reshape(-1, ctx.in_size[2], ctx.in_size[3], 1)

        # Double backward is simply the original forward resampling again.
        gradgrad_out = upfirdn2d_op.upfirdn2d(
            gradgrad_input,
            kernel,
            ctx.up_x,
            ctx.up_y,
            ctx.down_x,
            ctx.down_y,
            ctx.pad_x0,
            ctx.pad_x1,
            ctx.pad_y0,
            ctx.pad_y1,
        )
        # gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.out_size[0], ctx.out_size[1], ctx.in_size[3])
        gradgrad_out = gradgrad_out.view(
            ctx.in_size[0], ctx.in_size[1], ctx.out_size[0], ctx.out_size[1]
        )

        return gradgrad_out, None, None, None, None, None, None, None, None
|
85 |
+
|
86 |
+
|
87 |
+
class UpFirDn2d(Function):
    """Autograd wrapper around the upfirdn2d CUDA op.

    forward computes the output size and the complementary gradient padding
    (g_pad) so that backward can be expressed as another upfirdn2d call.
    """

    @staticmethod
    def forward(ctx, input, kernel, up, down, pad):
        up_x, up_y = up
        down_x, down_y = down
        pad_x0, pad_x1, pad_y0, pad_y1 = pad

        kernel_h, kernel_w = kernel.shape
        batch, channel, in_h, in_w = input.shape
        ctx.in_size = input.shape

        # CUDA op expects (N*C, H, W, 1) layout.
        input = input.reshape(-1, in_h, in_w, 1)

        # Save both the kernel and its spatial flip (used by backward).
        ctx.save_for_backward(kernel, torch.flip(kernel, [0, 1]))

        # Standard resampled-size formula for up/pad/filter/down.
        out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1
        out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1
        ctx.out_size = (out_h, out_w)

        ctx.up = (up_x, up_y)
        ctx.down = (down_x, down_y)
        ctx.pad = (pad_x0, pad_x1, pad_y0, pad_y1)

        # Padding of the transpose (gradient) operation.
        g_pad_x0 = kernel_w - pad_x0 - 1
        g_pad_y0 = kernel_h - pad_y0 - 1
        g_pad_x1 = in_w * up_x - out_w * down_x + pad_x0 - up_x + 1
        g_pad_y1 = in_h * up_y - out_h * down_y + pad_y0 - up_y + 1

        ctx.g_pad = (g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1)

        out = upfirdn2d_op.upfirdn2d(
            input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1
        )
        # out = out.view(major, out_h, out_w, minor)
        out = out.view(-1, channel, out_h, out_w)

        return out

    @staticmethod
    def backward(ctx, grad_output):
        kernel, grad_kernel = ctx.saved_tensors

        grad_input = UpFirDn2dBackward.apply(
            grad_output,
            kernel,
            grad_kernel,
            ctx.up,
            ctx.down,
            ctx.pad,
            ctx.g_pad,
            ctx.in_size,
            ctx.out_size,
        )

        return grad_input, None, None, None, None
|
142 |
+
|
143 |
+
|
144 |
+
def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
    """Upsample by `up`, FIR-filter with `kernel`, then downsample by `down`.

    input: (N, C, H, W) tensor; kernel: 2D filter tensor; pad: (pad0, pad1)
    applied to both spatial axes. Uses the CUDA extension when available
    and the input is on GPU, otherwise the pure-PyTorch reference path.
    """
    if use_fallback or input.device.type == "cpu":
        out = upfirdn2d_native(
            input, kernel, up, up, down, down, pad[0], pad[1], pad[0], pad[1]
        )
    else:
        out = UpFirDn2d.apply(
            input, kernel, (up, up), (down, down), (pad[0], pad[1], pad[0], pad[1])
        )

    return out
|
155 |
+
|
156 |
+
|
157 |
+
def upfirdn2d_native(
|
158 |
+
input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1
|
159 |
+
):
|
160 |
+
_, channel, in_h, in_w = input.shape
|
161 |
+
input = input.reshape(-1, in_h, in_w, 1)
|
162 |
+
|
163 |
+
_, in_h, in_w, minor = input.shape
|
164 |
+
kernel_h, kernel_w = kernel.shape
|
165 |
+
|
166 |
+
out = input.view(-1, in_h, 1, in_w, 1, minor)
|
167 |
+
out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
|
168 |
+
out = out.view(-1, in_h * up_y, in_w * up_x, minor)
|
169 |
+
|
170 |
+
out = F.pad(
|
171 |
+
out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)]
|
172 |
+
)
|
173 |
+
out = out[
|
174 |
+
:,
|
175 |
+
max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0),
|
176 |
+
max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0),
|
177 |
+
:,
|
178 |
+
]
|
179 |
+
|
180 |
+
out = out.permute(0, 3, 1, 2)
|
181 |
+
out = out.reshape(
|
182 |
+
[-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1]
|
183 |
+
)
|
184 |
+
w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
|
185 |
+
out = F.conv2d(out, w)
|
186 |
+
out = out.reshape(
|
187 |
+
-1,
|
188 |
+
minor,
|
189 |
+
in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
|
190 |
+
in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1,
|
191 |
+
)
|
192 |
+
out = out.permute(0, 2, 3, 1)
|
193 |
+
out = out[:, ::down_y, ::down_x, :]
|
194 |
+
|
195 |
+
out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1
|
196 |
+
out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1
|
197 |
+
|
198 |
+
return out.view(-1, channel, out_h, out_w)
|
models/stylegan2/stylegan2-pytorch/op/upfirdn2d_kernel.cu
ADDED
@@ -0,0 +1,369 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
// Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
|
2 |
+
//
|
3 |
+
// This work is made available under the Nvidia Source Code License-NC.
|
4 |
+
// To view a copy of this license, visit
|
5 |
+
// https://nvlabs.github.io/stylegan2/license.html
|
6 |
+
|
7 |
+
#include <torch/types.h>
|
8 |
+
|
9 |
+
#include <ATen/ATen.h>
|
10 |
+
#include <ATen/AccumulateType.h>
|
11 |
+
#include <ATen/cuda/CUDAApplyUtils.cuh>
|
12 |
+
#include <ATen/cuda/CUDAContext.h>
|
13 |
+
|
14 |
+
#include <cuda.h>
|
15 |
+
#include <cuda_runtime.h>
|
16 |
+
|
17 |
+
// Integer division rounding toward negative infinity (C++ '/' truncates
// toward zero), e.g. floor_div(-1, 2) == -1.  Used to map output pixel
// coordinates back to input coordinates under upsampling.
static __host__ __device__ __forceinline__ int floor_div(int a, int b) {
  int q = a / b;
  // Truncation overshot the floor when the (nonzero) remainder had the
  // opposite sign of b; step down once in that case.
  return (q * b > a) ? q - 1 : q;
}
|
26 |
+
|
27 |
+
// Parameters shared by both upfirdn2d kernels.  The input tensor is viewed
// as (major_dim, in_h, in_w, minor_dim); up_*/down_* are the resampling
// factors and pad_* the zero padding applied before FIR filtering.
struct UpFirDn2DKernelParams {
  int up_x;    // horizontal upsampling factor
  int up_y;    // vertical upsampling factor
  int down_x;  // horizontal downsampling factor
  int down_y;  // vertical downsampling factor
  int pad_x0;  // left padding
  int pad_x1;  // right padding
  int pad_y0;  // top padding
  int pad_y1;  // bottom padding

  int major_dim;   // leading batch-like dimension
  int in_h;        // input height
  int in_w;        // input width
  int minor_dim;   // trailing channel-like dimension
  int kernel_h;    // FIR filter height (runtime size)
  int kernel_w;    // FIR filter width (runtime size)
  int out_h;       // output height
  int out_w;       // output width
  int loop_major;  // majors processed per block (grid-z loop count)
  int loop_x;      // output-x chunks processed per block (grid-y loop count)
};
|
48 |
+
|
49 |
+
// Generic fallback kernel: one thread per (minor, out_y, out_x) position,
// walking raw input/kernel pointers with precomputed strides.  Handles any
// up/down/kernel configuration; slower than the tiled specialization below.
template <typename scalar_t>
__global__ void upfirdn2d_kernel_large(scalar_t *out, const scalar_t *input,
                                       const scalar_t *kernel,
                                       const UpFirDn2DKernelParams p) {
  // blockIdx.x/threadIdx.x jointly encode (out_y, minor index).
  int minor_idx = blockIdx.x * blockDim.x + threadIdx.x;
  int out_y = minor_idx / p.minor_dim;
  minor_idx -= out_y * p.minor_dim;
  int out_x_base = blockIdx.y * p.loop_x * blockDim.y + threadIdx.y;
  int major_idx_base = blockIdx.z * p.loop_major;

  if (out_x_base >= p.out_w || out_y >= p.out_h ||
      major_idx_base >= p.major_dim) {
    return;
  }

  // Vertical footprint: range of input rows [in_y, in_y + h) contributing
  // to this output row, and the starting tap row in the filter.
  int mid_y = out_y * p.down_y + p.up_y - 1 - p.pad_y0;
  int in_y = min(max(floor_div(mid_y, p.up_y), 0), p.in_h);
  int h = min(max(floor_div(mid_y + p.kernel_h, p.up_y), 0), p.in_h) - in_y;
  int kernel_y = mid_y + p.kernel_h - (in_y + 1) * p.up_y;

  // Each block covers loop_major majors and loop_x output-x positions.
  for (int loop_major = 0, major_idx = major_idx_base;
       loop_major < p.loop_major && major_idx < p.major_dim;
       loop_major++, major_idx++) {
    for (int loop_x = 0, out_x = out_x_base;
         loop_x < p.loop_x && out_x < p.out_w; loop_x++, out_x += blockDim.y) {
      // Horizontal footprint, analogous to the vertical one above.
      int mid_x = out_x * p.down_x + p.up_x - 1 - p.pad_x0;
      int in_x = min(max(floor_div(mid_x, p.up_x), 0), p.in_w);
      int w = min(max(floor_div(mid_x + p.kernel_w, p.up_x), 0), p.in_w) - in_x;
      int kernel_x = mid_x + p.kernel_w - (in_x + 1) * p.up_x;

      // Raw pointers plus per-step strides: input advances by minor_dim per
      // x step; the filter pointer walks backwards by up_x/up_y (flip).
      const scalar_t *x_p =
          &input[((major_idx * p.in_h + in_y) * p.in_w + in_x) * p.minor_dim +
                 minor_idx];
      const scalar_t *k_p = &kernel[kernel_y * p.kernel_w + kernel_x];
      int x_px = p.minor_dim;
      int k_px = -p.up_x;
      int x_py = p.in_w * p.minor_dim;
      int k_py = -p.up_y * p.kernel_w;

      scalar_t v = 0.0f;

      // Accumulate the h x w dot product between input patch and taps.
      for (int y = 0; y < h; y++) {
        for (int x = 0; x < w; x++) {
          v += static_cast<scalar_t>(*x_p) * static_cast<scalar_t>(*k_p);
          x_p += x_px;
          k_p += k_px;
        }

        // Rewind the row advance, then step to the next row.
        x_p += x_py - w * x_px;
        k_p += k_py - w * k_px;
      }

      out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim +
          minor_idx] = v;
    }
  }
}
|
106 |
+
|
107 |
+
// Tiled, fully-templated specialization: all resampling factors and sizes
// are compile-time constants so the inner tap loops unroll.  Each block
// stages the (flipped) filter and one input tile in shared memory, then
// computes a tile_out_h x tile_out_w output tile.
template <typename scalar_t, int up_x, int up_y, int down_x, int down_y,
          int kernel_h, int kernel_w, int tile_out_h, int tile_out_w>
__global__ void upfirdn2d_kernel(scalar_t *out, const scalar_t *input,
                                 const scalar_t *kernel,
                                 const UpFirDn2DKernelParams p) {
  // Input footprint required to produce one output tile.
  const int tile_in_h = ((tile_out_h - 1) * down_y + kernel_h - 1) / up_y + 1;
  const int tile_in_w = ((tile_out_w - 1) * down_x + kernel_w - 1) / up_x + 1;

  __shared__ volatile float sk[kernel_h][kernel_w];    // staged flipped taps
  __shared__ volatile float sx[tile_in_h][tile_in_w];  // staged input tile

  // blockIdx.x encodes (output tile row, minor index).
  int minor_idx = blockIdx.x;
  int tile_out_y = minor_idx / p.minor_dim;
  minor_idx -= tile_out_y * p.minor_dim;
  tile_out_y *= tile_out_h;
  int tile_out_x_base = blockIdx.y * p.loop_x * tile_out_w;
  int major_idx_base = blockIdx.z * p.loop_major;

  // NOTE(review): bitwise '|'/'&' on boolean conditions throughout —
  // presumably to avoid short-circuit branching; matches upstream code.
  if (tile_out_x_base >= p.out_w | tile_out_y >= p.out_h |
      major_idx_base >= p.major_dim) {
    return;
  }

  // Stage the filter, flipped in both axes (convolution orientation) and
  // zero-padded up to the compile-time kernel_h x kernel_w when the runtime
  // filter is smaller.  Threads stride over taps by blockDim.x.
  for (int tap_idx = threadIdx.x; tap_idx < kernel_h * kernel_w;
       tap_idx += blockDim.x) {
    int ky = tap_idx / kernel_w;
    int kx = tap_idx - ky * kernel_w;
    scalar_t v = 0.0;

    if (kx < p.kernel_w & ky < p.kernel_h) {
      v = kernel[(p.kernel_h - 1 - ky) * p.kernel_w + (p.kernel_w - 1 - kx)];
    }

    sk[ky][kx] = v;
  }

  for (int loop_major = 0, major_idx = major_idx_base;
       loop_major < p.loop_major & major_idx < p.major_dim;
       loop_major++, major_idx++) {
    for (int loop_x = 0, tile_out_x = tile_out_x_base;
         loop_x < p.loop_x & tile_out_x < p.out_w;
         loop_x++, tile_out_x += tile_out_w) {
      // Top-left corner of this tile in "mid" (upsampled+padded) and input
      // coordinates.
      int tile_mid_x = tile_out_x * down_x + up_x - 1 - p.pad_x0;
      int tile_mid_y = tile_out_y * down_y + up_y - 1 - p.pad_y0;
      int tile_in_x = floor_div(tile_mid_x, up_x);
      int tile_in_y = floor_div(tile_mid_y, up_y);

      // Don't overwrite sx while another iteration is still reading it.
      __syncthreads();

      // Cooperative load of the input tile; out-of-bounds reads become 0
      // (implicit zero padding).
      for (int in_idx = threadIdx.x; in_idx < tile_in_h * tile_in_w;
           in_idx += blockDim.x) {
        int rel_in_y = in_idx / tile_in_w;
        int rel_in_x = in_idx - rel_in_y * tile_in_w;
        int in_x = rel_in_x + tile_in_x;
        int in_y = rel_in_y + tile_in_y;

        scalar_t v = 0.0;

        if (in_x >= 0 & in_y >= 0 & in_x < p.in_w & in_y < p.in_h) {
          v = input[((major_idx * p.in_h + in_y) * p.in_w + in_x) *
                        p.minor_dim +
                    minor_idx];
        }

        sx[rel_in_y][rel_in_x] = v;
      }

      __syncthreads();
      // Each thread strides over output pixels of the tile.
      for (int out_idx = threadIdx.x; out_idx < tile_out_h * tile_out_w;
           out_idx += blockDim.x) {
        int rel_out_y = out_idx / tile_out_w;
        int rel_out_x = out_idx - rel_out_y * tile_out_w;
        int out_x = rel_out_x + tile_out_x;
        int out_y = rel_out_y + tile_out_y;

        // Position of this output pixel in mid/input coordinates and the
        // phase offset into the (flipped) filter.
        int mid_x = tile_mid_x + rel_out_x * down_x;
        int mid_y = tile_mid_y + rel_out_y * down_y;
        int in_x = floor_div(mid_x, up_x);
        int in_y = floor_div(mid_y, up_y);
        int rel_in_x = in_x - tile_in_x;
        int rel_in_y = in_y - tile_in_y;
        int kernel_x = (in_x + 1) * up_x - mid_x - 1;
        int kernel_y = (in_y + 1) * up_y - mid_y - 1;

        scalar_t v = 0.0;

        // Only every up-th tap aligns with an actual input sample, hence
        // the kernel stride of up_x/up_y.
#pragma unroll
        for (int y = 0; y < kernel_h / up_y; y++)
#pragma unroll
          for (int x = 0; x < kernel_w / up_x; x++)
            v += sx[rel_in_y + y][rel_in_x + x] *
                 sk[kernel_y + y * up_y][kernel_x + x * up_x];

        if (out_x < p.out_w & out_y < p.out_h) {
          out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim +
              minor_idx] = v;
        }
      }
    }
  }
}
|
208 |
+
|
209 |
+
// Host-side entry point: fills the parameter struct from tensor shapes,
// picks a templated fast path for common StyleGAN2 configurations, and
// launches the kernel on the current CUDA stream.
// Input is expected as a contiguous 4-D tensor (major, in_h, in_w, minor);
// kernel is 2-D (kernel_h, kernel_w).  Returns (major, out_h, out_w, minor).
torch::Tensor upfirdn2d_op(const torch::Tensor &input,
                           const torch::Tensor &kernel, int up_x, int up_y,
                           int down_x, int down_y, int pad_x0, int pad_x1,
                           int pad_y0, int pad_y1) {
  int curDevice = -1;
  cudaGetDevice(&curDevice);
  cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice);

  UpFirDn2DKernelParams p;

  auto x = input.contiguous();
  auto k = kernel.contiguous();

  p.major_dim = x.size(0);
  p.in_h = x.size(1);
  p.in_w = x.size(2);
  p.minor_dim = x.size(3);
  p.kernel_h = k.size(0);
  p.kernel_w = k.size(1);
  p.up_x = up_x;
  p.up_y = up_y;
  p.down_x = down_x;
  p.down_y = down_y;
  p.pad_x0 = pad_x0;
  p.pad_x1 = pad_x1;
  p.pad_y0 = pad_y0;
  p.pad_y1 = pad_y1;

  // Standard upfirdn output size formula.
  p.out_h = (p.in_h * p.up_y + p.pad_y0 + p.pad_y1 - p.kernel_h + p.down_y) /
            p.down_y;
  p.out_w = (p.in_w * p.up_x + p.pad_x0 + p.pad_x1 - p.kernel_w + p.down_x) /
            p.down_x;

  auto out =
      at::empty({p.major_dim, p.out_h, p.out_w, p.minor_dim}, x.options());

  // mode selects a compile-time specialization; -1 falls back to the
  // generic upfirdn2d_kernel_large.  Later matches override earlier ones,
  // so the tightest (smallest-kernel) specialization wins.
  int mode = -1;

  int tile_out_h = -1;
  int tile_out_w = -1;

  // 1:1 resampling (plain FIR), kernel up to 4x4.
  if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 &&
      p.kernel_h <= 4 && p.kernel_w <= 4) {
    mode = 1;
    tile_out_h = 16;
    tile_out_w = 64;
  }

  if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 &&
      p.kernel_h <= 3 && p.kernel_w <= 3) {
    mode = 2;
    tile_out_h = 16;
    tile_out_w = 64;
  }

  // 2x upsampling.
  if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 &&
      p.kernel_h <= 4 && p.kernel_w <= 4) {
    mode = 3;
    tile_out_h = 16;
    tile_out_w = 64;
  }

  if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 &&
      p.kernel_h <= 2 && p.kernel_w <= 2) {
    mode = 4;
    tile_out_h = 16;
    tile_out_w = 64;
  }

  // 2x downsampling.
  if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 &&
      p.kernel_h <= 4 && p.kernel_w <= 4) {
    mode = 5;
    tile_out_h = 8;
    tile_out_w = 32;
  }

  if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 &&
      p.kernel_h <= 2 && p.kernel_w <= 2) {
    mode = 6;
    tile_out_h = 8;
    tile_out_w = 32;
  }

  dim3 block_size;
  dim3 grid_size;

  if (tile_out_h > 0 && tile_out_w > 0) {
    // Tiled path: 256 threads per block; grid-x spans (tile rows x minor),
    // grid-y the x tiles, grid-z the majors.
    p.loop_major = (p.major_dim - 1) / 16384 + 1;
    p.loop_x = 1;
    block_size = dim3(32 * 8, 1, 1);
    grid_size = dim3(((p.out_h - 1) / tile_out_h + 1) * p.minor_dim,
                     (p.out_w - 1) / (p.loop_x * tile_out_w) + 1,
                     (p.major_dim - 1) / p.loop_major + 1);
  } else {
    // Generic path: small 4x32 blocks, each thread covering loop_x columns.
    p.loop_major = (p.major_dim - 1) / 16384 + 1;
    p.loop_x = 4;
    block_size = dim3(4, 32, 1);
    grid_size = dim3((p.out_h * p.minor_dim - 1) / block_size.x + 1,
                     (p.out_w - 1) / (p.loop_x * block_size.y) + 1,
                     (p.major_dim - 1) / p.loop_major + 1);
  }

  AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&] {
    switch (mode) {
    case 1:
      upfirdn2d_kernel<scalar_t, 1, 1, 1, 1, 4, 4, 16, 64>
          <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
                                                 x.data_ptr<scalar_t>(),
                                                 k.data_ptr<scalar_t>(), p);

      break;

    case 2:
      upfirdn2d_kernel<scalar_t, 1, 1, 1, 1, 3, 3, 16, 64>
          <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
                                                 x.data_ptr<scalar_t>(),
                                                 k.data_ptr<scalar_t>(), p);

      break;

    case 3:
      upfirdn2d_kernel<scalar_t, 2, 2, 1, 1, 4, 4, 16, 64>
          <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
                                                 x.data_ptr<scalar_t>(),
                                                 k.data_ptr<scalar_t>(), p);

      break;

    case 4:
      upfirdn2d_kernel<scalar_t, 2, 2, 1, 1, 2, 2, 16, 64>
          <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
                                                 x.data_ptr<scalar_t>(),
                                                 k.data_ptr<scalar_t>(), p);

      break;

    case 5:
      upfirdn2d_kernel<scalar_t, 1, 1, 2, 2, 4, 4, 8, 32>
          <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
                                                 x.data_ptr<scalar_t>(),
                                                 k.data_ptr<scalar_t>(), p);

      break;

    case 6:
      // NOTE: mode 6 (kernel <= 2) reuses the 4x4-tap specialization; this
      // is still correct because unused taps are zero-padded in sk, just
      // not the tightest possible template.
      upfirdn2d_kernel<scalar_t, 1, 1, 2, 2, 4, 4, 8, 32>
          <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
                                                 x.data_ptr<scalar_t>(),
                                                 k.data_ptr<scalar_t>(), p);

      break;

    default:
      upfirdn2d_kernel_large<scalar_t><<<grid_size, block_size, 0, stream>>>(
          out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(),
          k.data_ptr<scalar_t>(), p);
    }
  });

  return out;
}
|
models/stylegan2/stylegan2-pytorch/ppl.py
ADDED
@@ -0,0 +1,104 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
|
3 |
+
import torch
|
4 |
+
from torch.nn import functional as F
|
5 |
+
import numpy as np
|
6 |
+
from tqdm import tqdm
|
7 |
+
|
8 |
+
import lpips
|
9 |
+
from model import Generator
|
10 |
+
|
11 |
+
|
12 |
+
def normalize(x):
    """Project *x* onto the unit hypersphere along its last dimension."""
    length = torch.sqrt(x.pow(2).sum(-1, keepdim=True))
    return x / length
|
14 |
+
|
15 |
+
|
16 |
+
def slerp(a, b, t):
    """Spherical interpolation between *a* and *b* at fraction *t*.

    Both endpoints are projected to the unit sphere (last dimension), and
    the result is re-normalized before returning.
    """
    def _unit(v):
        # Same formula as the module-level normalize(), inlined here.
        return v / torch.sqrt(v.pow(2).sum(-1, keepdim=True))

    a = _unit(a)
    b = _unit(b)
    cos_ab = (a * b).sum(-1, keepdim=True)
    angle = t * torch.acos(cos_ab)
    ortho = _unit(b - cos_ab * a)
    mixed = a * torch.cos(angle) + ortho * torch.sin(angle)

    return _unit(mixed)
|
25 |
+
|
26 |
+
|
27 |
+
def lerp(a, b, t):
    """Linear interpolation: returns a at t == 0 and b at t == 1."""
    delta = b - a
    return a + delta * t
|
29 |
+
|
30 |
+
|
31 |
+
if __name__ == '__main__':
    # Perceptual Path Length (PPL) evaluation for a trained generator
    # (StyleGAN2, Karras et al.): average LPIPS distance between images
    # generated from epsilon-separated latent interpolations.
    device = 'cuda'

    parser = argparse.ArgumentParser()

    # Default to 'w': previously, leaving --space unset (or passing 'z')
    # crashed with a NameError because only the 'w' branch defined latent_e.
    parser.add_argument('--space', choices=['z', 'w'], default='w')
    parser.add_argument('--batch', type=int, default=64)
    parser.add_argument('--n_sample', type=int, default=5000)
    parser.add_argument('--size', type=int, default=256)
    parser.add_argument('--eps', type=float, default=1e-4)
    parser.add_argument('--crop', action='store_true')
    parser.add_argument('ckpt', metavar='CHECKPOINT')

    args = parser.parse_args()

    latent_dim = 512

    ckpt = torch.load(args.ckpt)

    g = Generator(args.size, latent_dim, 8).to(device)
    g.load_state_dict(ckpt['g_ema'])
    g.eval()

    percept = lpips.PerceptualLoss(
        model='net-lin', net='vgg', use_gpu=device.startswith('cuda')
    )

    distances = []

    n_batch = args.n_sample // args.batch
    resid = args.n_sample - (n_batch * args.batch)
    # Bug fix: only append the residual batch when it is non-empty;
    # a zero-sized batch used to be appended whenever n_sample was an
    # exact multiple of batch, producing empty tensors downstream.
    batch_sizes = [args.batch] * n_batch + ([resid] if resid > 0 else [])

    with torch.no_grad():
        for batch in tqdm(batch_sizes):
            noise = g.make_noise()

            # Latents are drawn in interleaved pairs: even rows are one
            # endpoint, odd rows the other.
            inputs = torch.randn([batch * 2, latent_dim], device=device)
            lerp_t = torch.rand(batch, device=device)

            if args.space == 'w':
                # Interpolate linearly in W space.
                latent = g.get_latent(inputs)
                latent_t0, latent_t1 = latent[::2], latent[1::2]
                latent_e0 = lerp(latent_t0, latent_t1, lerp_t[:, None])
                latent_e1 = lerp(latent_t0, latent_t1, lerp_t[:, None] + args.eps)
                latent_e = torch.stack([latent_e0, latent_e1], 1).view(*latent.shape)
                image, _ = g([latent_e], input_is_latent=True, noise=noise)

            else:
                # Z space: spherical interpolation of Gaussian latents.
                # (Previously unimplemented: --space z raised NameError.)
                latent_t0, latent_t1 = inputs[::2], inputs[1::2]
                latent_e0 = slerp(latent_t0, latent_t1, lerp_t[:, None])
                latent_e1 = slerp(latent_t0, latent_t1, lerp_t[:, None] + args.eps)
                latent_e = torch.stack([latent_e0, latent_e1], 1).view(*inputs.shape)
                image, _ = g([latent_e], input_is_latent=False, noise=noise)

            if args.crop:
                # Keep the central region (rows 3/8..7/8, cols 2/8..6/8).
                c = image.shape[2] // 8
                image = image[:, :, c * 3 : c * 7, c * 2 : c * 6]

            # LPIPS operates at <= 256px; downsample larger outputs.
            factor = image.shape[2] // 256

            if factor > 1:
                image = F.interpolate(
                    image, size=(256, 256), mode='bilinear', align_corners=False
                )

            # Pairwise perceptual distance scaled by eps^2 approximates the
            # squared path derivative along the interpolation.
            dist = percept(image[::2], image[1::2]).view(image.shape[0] // 2) / (
                args.eps ** 2
            )
            distances.append(dist.to('cpu').numpy())

    distances = np.concatenate(distances, 0)

    # Discard the extreme 1% tails before averaging.
    lo = np.percentile(distances, 1, interpolation='lower')
    hi = np.percentile(distances, 99, interpolation='higher')
    filtered_dist = np.extract(
        np.logical_and(lo <= distances, distances <= hi), distances
    )

    print('ppl:', filtered_dist.mean())
|
models/stylegan2/stylegan2-pytorch/prepare_data.py
ADDED
@@ -0,0 +1,82 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
from io import BytesIO
|
3 |
+
import multiprocessing
|
4 |
+
from functools import partial
|
5 |
+
|
6 |
+
from PIL import Image
|
7 |
+
import lmdb
|
8 |
+
from tqdm import tqdm
|
9 |
+
from torchvision import datasets
|
10 |
+
from torchvision.transforms import functional as trans_fn
|
11 |
+
|
12 |
+
|
13 |
+
def resize_and_convert(img, size, resample, quality=100):
    """Resize *img* to *size*, center-crop to a square, and return the
    JPEG-encoded bytes."""
    cropped = trans_fn.center_crop(trans_fn.resize(img, size, resample), size)
    buf = BytesIO()
    cropped.save(buf, format='jpeg', quality=quality)

    return buf.getvalue()
|
21 |
+
|
22 |
+
|
23 |
+
def resize_multiple(img, sizes=(128, 256, 512, 1024), resample=Image.LANCZOS, quality=100):
    """Encode *img* as JPEG bytes at every resolution in *sizes*."""
    return [resize_and_convert(img, size, resample, quality) for size in sizes]
|
30 |
+
|
31 |
+
|
32 |
+
def resize_worker(img_file, sizes, resample):
    """Multiprocessing worker: load one (index, path) pair and return the
    index together with its multi-resolution JPEG encodings."""
    idx, path = img_file
    image = Image.open(path).convert('RGB')

    return idx, resize_multiple(image, sizes=sizes, resample=resample)
|
39 |
+
|
40 |
+
|
41 |
+
def prepare(env, dataset, n_worker, sizes=(128, 256, 512, 1024), resample=Image.LANCZOS):
    """Resize every image of *dataset* to each resolution in *sizes* and
    store the JPEG bytes in the LMDB environment *env*.

    Keys have the form ``'<size>-<index zero-padded to 5 digits>'``; a final
    ``'length'`` key records the total number of images processed.
    """
    resize_fn = partial(resize_worker, sizes=sizes, resample=resample)

    # Sort by file path so indices are deterministic across runs.
    files = sorted(dataset.imgs, key=lambda x: x[0])
    files = [(i, file) for i, (file, label) in enumerate(files)]
    total = 0

    with multiprocessing.Pool(n_worker) as pool:
        # imap_unordered returns results in completion order; the index i
        # carried through resize_worker keeps the keys consistent anyway.
        for i, imgs in tqdm(pool.imap_unordered(resize_fn, files)):
            for size, img in zip(sizes, imgs):
                key = f'{size}-{str(i).zfill(5)}'.encode('utf-8')

                # NOTE(review): one write transaction per image; batching
                # commits would likely be faster for large datasets.
                with env.begin(write=True) as txn:
                    txn.put(key, img)

            total += 1

        with env.begin(write=True) as txn:
            txn.put('length'.encode('utf-8'), str(total).encode('utf-8'))
|
60 |
+
|
61 |
+
|
62 |
+
if __name__ == '__main__':
    # CLI: convert an ImageFolder directory tree into a multi-resolution
    # LMDB dataset consumable by dataset.MultiResolutionDataset.
    parser = argparse.ArgumentParser()
    parser.add_argument('--out', type=str)  # output LMDB path
    parser.add_argument('--size', type=str, default='128,256,512,1024')  # comma-separated sizes
    parser.add_argument('--n_worker', type=int, default=8)  # resize worker processes
    parser.add_argument('--resample', type=str, default='lanczos')
    parser.add_argument('path', type=str)  # root of the input image folder

    args = parser.parse_args()

    resample_map = {'lanczos': Image.LANCZOS, 'bilinear': Image.BILINEAR}
    resample = resample_map[args.resample]

    sizes = [int(s.strip()) for s in args.size.split(',')]

    print(f'Make dataset of image sizes:', ', '.join(str(s) for s in sizes))

    imgset = datasets.ImageFolder(args.path)

    # map_size reserves 1 TiB of address space (not actual disk usage).
    with lmdb.open(args.out, map_size=1024 ** 4, readahead=False) as env:
        prepare(env, imgset, args.n_worker, sizes=sizes, resample=resample)
|
models/stylegan2/stylegan2-pytorch/projector.py
ADDED
@@ -0,0 +1,203 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import math
|
3 |
+
import os
|
4 |
+
|
5 |
+
import torch
|
6 |
+
from torch import optim
|
7 |
+
from torch.nn import functional as F
|
8 |
+
from torchvision import transforms
|
9 |
+
from PIL import Image
|
10 |
+
from tqdm import tqdm
|
11 |
+
|
12 |
+
import lpips
|
13 |
+
from model import Generator
|
14 |
+
|
15 |
+
|
16 |
+
def noise_regularize(noises):
    """Regularizer that pushes each noise map toward spatial white noise.

    For every tensor the squared mean correlation with its one-pixel shifts
    (horizontal, then vertical) is accumulated; the map is then 2x2
    block-averaged and the measurement repeated until it is 8x8 or smaller.
    """
    total = 0

    for noise in noises:
        size = noise.shape[2]

        while True:
            shift_x = (noise * torch.roll(noise, shifts=1, dims=3)).mean().pow(2)
            shift_y = (noise * torch.roll(noise, shifts=1, dims=2)).mean().pow(2)
            total = total + shift_x + shift_y

            if size <= 8:
                break

            # Downsample by averaging disjoint 2x2 blocks.
            noise = noise.reshape([1, 1, size // 2, 2, size // 2, 2]).mean([3, 5])
            size //= 2

    return total
|
37 |
+
|
38 |
+
|
39 |
+
def noise_normalize_(noises):
    """In-place: standardize each noise tensor to zero mean, unit std."""
    for noise in noises:
        mu, sigma = noise.mean(), noise.std()
        noise.data.add_(-mu).div_(sigma)
|
45 |
+
|
46 |
+
|
47 |
+
def get_lr(t, initial_lr, rampdown=0.25, rampup=0.05):
    """Learning-rate schedule: linear warm-up, cosine ramp-down.

    *t* is the training progress in [0, 1]; warm-up spans the first
    *rampup* fraction and the cosine decay the final *rampdown* fraction.
    """
    warmup = min(1, t / rampup)
    cooldown = 0.5 - 0.5 * math.cos(min(1, (1 - t) / rampdown) * math.pi)

    return initial_lr * cooldown * warmup
|
53 |
+
|
54 |
+
|
55 |
+
def latent_noise(latent, strength):
    """Return *latent* perturbed by Gaussian noise scaled by *strength*."""
    perturbation = torch.randn_like(latent) * strength
    return latent + perturbation
|
59 |
+
|
60 |
+
|
61 |
+
def make_image(tensor):
    """Convert a batch of [-1, 1] images (N, C, H, W) to uint8 HWC numpy.

    Note: clamps the input tensor's storage in place (clamp_ on the
    detached view), matching the original behavior.
    """
    img = tensor.detach().clamp_(min=-1, max=1)
    scaled = (img.add(1)).div_(2).mul(255)
    return scaled.type(torch.uint8).permute(0, 2, 3, 1).to('cpu').numpy()
|
73 |
+
|
74 |
+
|
75 |
+
if __name__ == '__main__':
    # Project real images into the generator's latent space by optimizing
    # W (or W+) latents and per-layer noise maps against LPIPS + MSE.
    device = 'cuda'

    parser = argparse.ArgumentParser()
    parser.add_argument('--ckpt', type=str, required=True)
    parser.add_argument('--size', type=int, default=256)
    parser.add_argument('--lr_rampup', type=float, default=0.05)
    parser.add_argument('--lr_rampdown', type=float, default=0.25)
    parser.add_argument('--lr', type=float, default=0.1)
    parser.add_argument('--noise', type=float, default=0.05)
    parser.add_argument('--noise_ramp', type=float, default=0.75)
    parser.add_argument('--step', type=int, default=1000)
    parser.add_argument('--noise_regularize', type=float, default=1e5)
    parser.add_argument('--mse', type=float, default=0)
    parser.add_argument('--w_plus', action='store_true')
    parser.add_argument('files', metavar='FILES', nargs='+')

    args = parser.parse_args()

    n_mean_latent = 10000

    # LPIPS works at <= 256px.
    resize = min(args.size, 256)

    transform = transforms.Compose(
        [
            transforms.Resize(resize),
            transforms.CenterCrop(resize),
            transforms.ToTensor(),
            transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
        ]
    )

    imgs = []

    for imgfile in args.files:
        img = transform(Image.open(imgfile).convert('RGB'))
        imgs.append(img)

    imgs = torch.stack(imgs, 0).to(device)

    g_ema = Generator(args.size, 512, 8)
    g_ema.load_state_dict(torch.load(args.ckpt)['g_ema'], strict=False)
    g_ema.eval()
    g_ema = g_ema.to(device)

    # Estimate the mean and spread of W from a large sample of mapped z's;
    # optimization starts at the mean and is perturbed relative to the std.
    with torch.no_grad():
        noise_sample = torch.randn(n_mean_latent, 512, device=device)
        latent_out = g_ema.style(noise_sample)

        latent_mean = latent_out.mean(0)
        latent_std = ((latent_out - latent_mean).pow(2).sum() / n_mean_latent) ** 0.5

    percept = lpips.PerceptualLoss(
        model='net-lin', net='vgg', use_gpu=device.startswith('cuda')
    )

    noises = g_ema.make_noise()

    # Bug fix: one latent row per input image. This was hard-coded to
    # repeat(2, 1), which broke whenever the number of files was not 2.
    latent_in = latent_mean.detach().clone().unsqueeze(0).repeat(imgs.shape[0], 1)

    if args.w_plus:
        # Independent latent per generator layer (W+ space).
        latent_in = latent_in.unsqueeze(1).repeat(1, g_ema.n_latent, 1)

    latent_in.requires_grad = True

    for noise in noises:
        noise.requires_grad = True

    optimizer = optim.Adam([latent_in] + noises, lr=args.lr)

    pbar = tqdm(range(args.step))
    latent_path = []

    for i in pbar:
        t = i / args.step
        lr = get_lr(t, args.lr)
        optimizer.param_groups[0]['lr'] = lr
        # Latent-space noise anneals quadratically to 0 over noise_ramp.
        noise_strength = latent_std * args.noise * max(0, 1 - t / args.noise_ramp) ** 2
        latent_n = latent_noise(latent_in, noise_strength.item())

        img_gen, _ = g_ema([latent_n], input_is_latent=True, noise=noises)

        batch, channel, height, width = img_gen.shape

        if height > 256:
            # Block-average down to 256px before the perceptual loss.
            factor = height // 256

            img_gen = img_gen.reshape(
                batch, channel, height // factor, factor, width // factor, factor
            )
            img_gen = img_gen.mean([3, 5])

        p_loss = percept(img_gen, imgs).sum()
        n_loss = noise_regularize(noises)
        mse_loss = F.mse_loss(img_gen, imgs)

        loss = p_loss + args.noise_regularize * n_loss + args.mse * mse_loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        noise_normalize_(noises)

        if (i + 1) % 100 == 0:
            latent_path.append(latent_in.detach().clone())

        pbar.set_description(
            (
                f'perceptual: {p_loss.item():.4f}; noise regularize: {n_loss.item():.4f};'
                f' mse: {mse_loss.item():.4f}; lr: {lr:.4f}'
            )
        )

    result_file = {'noises': noises}

    img_gen, _ = g_ema([latent_path[-1]], input_is_latent=True, noise=noises)

    filename = os.path.splitext(os.path.basename(args.files[0]))[0] + '.pt'

    img_ar = make_image(img_gen)

    for i, input_name in enumerate(args.files):
        result_file[input_name] = {'img': img_gen[i], 'latent': latent_in[i]}
        img_name = os.path.splitext(os.path.basename(input_name))[0] + '-project.png'
        pil_img = Image.fromarray(img_ar[i])
        pil_img.save(img_name)

    torch.save(result_file, filename)
|
models/stylegan2/stylegan2-pytorch/sample/.gitignore
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
*.png
|
models/stylegan2/stylegan2-pytorch/train.py
ADDED
@@ -0,0 +1,413 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import math
|
3 |
+
import random
|
4 |
+
import os
|
5 |
+
|
6 |
+
import numpy as np
|
7 |
+
import torch
|
8 |
+
from torch import nn, autograd, optim
|
9 |
+
from torch.nn import functional as F
|
10 |
+
from torch.utils import data
|
11 |
+
import torch.distributed as dist
|
12 |
+
from torchvision import transforms, utils
|
13 |
+
from tqdm import tqdm
|
14 |
+
|
15 |
+
try:
|
16 |
+
import wandb
|
17 |
+
|
18 |
+
except ImportError:
|
19 |
+
wandb = None
|
20 |
+
|
21 |
+
from model import Generator, Discriminator
|
22 |
+
from dataset import MultiResolutionDataset
|
23 |
+
from distributed import (
|
24 |
+
get_rank,
|
25 |
+
synchronize,
|
26 |
+
reduce_loss_dict,
|
27 |
+
reduce_sum,
|
28 |
+
get_world_size,
|
29 |
+
)
|
30 |
+
|
31 |
+
|
32 |
+
def data_sampler(dataset, shuffle, distributed):
    """Select the sampler matching the run mode.

    Distributed runs shard via DistributedSampler; otherwise a random or
    sequential sampler depending on *shuffle*.
    """
    if distributed:
        return data.distributed.DistributedSampler(dataset, shuffle=shuffle)

    if shuffle:
        return data.RandomSampler(dataset)

    return data.SequentialSampler(dataset)
|
41 |
+
|
42 |
+
|
43 |
+
def requires_grad(model, flag=True):
    """Enable or disable gradient tracking for every parameter of *model*."""
    for param in model.parameters():
        param.requires_grad = flag
|
46 |
+
|
47 |
+
|
48 |
+
def accumulate(model1, model2, decay=0.999):
    """EMA update: model1 <- decay * model1 + (1 - decay) * model2.

    Used to maintain the exponential moving average generator (g_ema);
    operates on .data so it never touches autograd state.
    """
    par1 = dict(model1.named_parameters())
    par2 = dict(model2.named_parameters())

    for k in par1.keys():
        # Fix: the old add_(Number, Tensor) overload is deprecated and
        # removed in modern PyTorch; use the alpha keyword form instead.
        par1[k].data.mul_(decay).add_(par2[k].data, alpha=1 - decay)
|
54 |
+
|
55 |
+
|
56 |
+
def sample_data(loader):
    """Yield batches from *loader* forever, restarting at each epoch end."""
    while True:
        yield from loader
|
60 |
+
|
61 |
+
|
62 |
+
def d_logistic_loss(real_pred, fake_pred):
    """Logistic discriminator loss: softplus(-D(real)) + softplus(D(fake))."""
    return F.softplus(-real_pred).mean() + F.softplus(fake_pred).mean()
|
67 |
+
|
68 |
+
|
69 |
+
def d_r1_loss(real_pred, real_img):
    """R1 penalty: mean squared gradient norm of D's output at real images."""
    (grad_real,) = autograd.grad(
        outputs=real_pred.sum(), inputs=real_img, create_graph=True
    )
    per_sample = grad_real.view(grad_real.shape[0], -1).pow(2).sum(1)
    return per_sample.mean()
|
76 |
+
|
77 |
+
|
78 |
+
def g_nonsaturating_loss(fake_pred):
    """Non-saturating generator loss: mean softplus(-D(fake))."""
    return F.softplus(-fake_pred).mean()
|
82 |
+
|
83 |
+
|
84 |
+
def g_path_regularize(fake_img, latents, mean_path_length, decay=0.01):
    """StyleGAN2 path-length regularizer.

    Estimates |J^T y| for random image-space noise y, updates an
    exponential running mean of the path lengths, and penalizes their
    squared deviation from that mean.

    Returns (penalty, updated running mean (detached), per-sample lengths).
    """
    height, width = fake_img.shape[2], fake_img.shape[3]
    noise = torch.randn_like(fake_img) / math.sqrt(height * width)
    (grad,) = autograd.grad(
        outputs=(fake_img * noise).sum(), inputs=latents, create_graph=True
    )
    path_lengths = torch.sqrt(grad.pow(2).sum(2).mean(1))

    path_mean = mean_path_length + decay * (path_lengths.mean() - mean_path_length)

    path_penalty = (path_lengths - path_mean).pow(2).mean()

    return path_penalty, path_mean.detach(), path_lengths
|
98 |
+
|
99 |
+
|
def make_noise(batch, latent_dim, n_noise, device):
    """Sample latent noise: one (batch, latent_dim) tensor when n_noise == 1,
    otherwise a tuple of n_noise such tensors."""
    if n_noise == 1:
        return torch.randn(batch, latent_dim, device=device)

    stacked = torch.randn(n_noise, batch, latent_dim, device=device)
    return stacked.unbind(0)
107 |
+
|
108 |
+
|
def mixing_noise(batch, latent_dim, prob, device):
    """With probability *prob*, return two latents (style mixing); otherwise a 1-list."""
    mix = prob > 0 and random.random() < prob
    if mix:
        return make_noise(batch, latent_dim, 2, device)

    return [make_noise(batch, latent_dim, 1, device)]
115 |
+
|
116 |
+
|
def set_grad_none(model, targets):
    """Clear the .grad of every named parameter whose name appears in *targets*."""
    for name, param in model.named_parameters():
        if name in targets:
            param.grad = None
121 |
+
|
122 |
+
|
def train(args, loader, generator, discriminator, g_optim, d_optim, g_ema, device):
    """StyleGAN2 training loop: alternating D/G steps with lazy regularization.

    Runs args.iter iterations of logistic GAN training, applying the R1
    penalty to D every args.d_reg_every steps and path-length regularization
    to G every args.g_reg_every steps, while maintaining an EMA generator
    (g_ema). Rank 0 logs progress, saves sample grids every 100 iterations
    and checkpoints every 10000 iterations.
    """
    # Wrap the DataLoader in an infinite generator so next() never raises.
    loader = sample_data(loader)

    pbar = range(args.iter)

    # Only the main process shows a progress bar.
    if get_rank() == 0:
        pbar = tqdm(pbar, initial=args.start_iter, dynamic_ncols=True, smoothing=0.01)

    mean_path_length = 0

    # Pre-initialize loss values so the logging dict is populated even on
    # iterations where the lazy regularizers are skipped.
    d_loss_val = 0
    r1_loss = torch.tensor(0.0, device=device)
    g_loss_val = 0
    path_loss = torch.tensor(0.0, device=device)
    path_lengths = torch.tensor(0.0, device=device)
    mean_path_length_avg = 0
    loss_dict = {}

    # Under DDP, checkpointing and EMA need the underlying module.
    if args.distributed:
        g_module = generator.module
        d_module = discriminator.module

    else:
        g_module = generator
        d_module = discriminator

    # EMA decay chosen so the half-life is ~10k images at batch 32.
    accum = 0.5 ** (32 / (10 * 1000))

    # Fixed latent batch reused for all sample grids, for visual comparability.
    sample_z = torch.randn(args.n_sample, args.latent, device=device)

    for idx in pbar:
        # Offset by start_iter so resumed runs continue their numbering.
        i = idx + args.start_iter

        if i > args.iter:
            print("Done!")

            break

        real_img = next(loader)
        real_img = real_img.to(device)

        # --- Discriminator step: freeze G, train D ---
        requires_grad(generator, False)
        requires_grad(discriminator, True)

        noise = mixing_noise(args.batch, args.latent, args.mixing, device)
        fake_img, _ = generator(noise)
        fake_pred = discriminator(fake_img)

        real_pred = discriminator(real_img)
        d_loss = d_logistic_loss(real_pred, fake_pred)

        loss_dict["d"] = d_loss
        loss_dict["real_score"] = real_pred.mean()
        loss_dict["fake_score"] = fake_pred.mean()

        discriminator.zero_grad()
        d_loss.backward()
        d_optim.step()

        # Lazy R1 regularization: only every d_reg_every iterations.
        d_regularize = i % args.d_reg_every == 0

        if d_regularize:
            # Gradients w.r.t. the real images are needed for the R1 penalty.
            real_img.requires_grad = True
            real_pred = discriminator(real_img)
            r1_loss = d_r1_loss(real_pred, real_img)

            discriminator.zero_grad()
            # Scale by d_reg_every to compensate for the lazy schedule; the
            # "+ 0 * real_pred[0]" term keeps D's output in the graph (a DDP
            # unused-parameter workaround that does not change the value).
            (args.r1 / 2 * r1_loss * args.d_reg_every + 0 * real_pred[0]).backward()

            d_optim.step()

        loss_dict["r1"] = r1_loss

        # --- Generator step: freeze D, train G ---
        requires_grad(generator, True)
        requires_grad(discriminator, False)

        noise = mixing_noise(args.batch, args.latent, args.mixing, device)
        fake_img, _ = generator(noise)
        fake_pred = discriminator(fake_img)
        g_loss = g_nonsaturating_loss(fake_pred)

        loss_dict["g"] = g_loss

        generator.zero_grad()
        g_loss.backward()
        g_optim.step()

        # Lazy path-length regularization: only every g_reg_every iterations.
        g_regularize = i % args.g_reg_every == 0

        if g_regularize:
            # Use a smaller batch for the (expensive) path-length pass.
            path_batch_size = max(1, args.batch // args.path_batch_shrink)
            noise = mixing_noise(path_batch_size, args.latent, args.mixing, device)
            fake_img, latents = generator(noise, return_latents=True)

            path_loss, mean_path_length, path_lengths = g_path_regularize(
                fake_img, latents, mean_path_length
            )

            generator.zero_grad()
            # Scale by g_reg_every to compensate for the lazy schedule.
            weighted_path_loss = args.path_regularize * args.g_reg_every * path_loss

            if args.path_batch_shrink:
                # Zero-valued term that ties fake_img into the graph
                # (DDP unused-output workaround; no effect on the loss value).
                weighted_path_loss += 0 * fake_img[0, 0, 0, 0]

            weighted_path_loss.backward()

            g_optim.step()

            # Average the running path-length mean across all workers.
            # NOTE(review): mean_path_length is the tensor returned by
            # g_path_regularize here; reduce_sum is a project helper — verify
            # it accepts a 0-dim tensor.
            mean_path_length_avg = (
                reduce_sum(mean_path_length).item() / get_world_size()
            )

        loss_dict["path"] = path_loss
        loss_dict["path_length"] = path_lengths.mean()

        # Fold the current generator weights into the EMA copy.
        accumulate(g_ema, g_module, accum)

        # Average the logged losses across workers (project helper).
        loss_reduced = reduce_loss_dict(loss_dict)

        d_loss_val = loss_reduced["d"].mean().item()
        g_loss_val = loss_reduced["g"].mean().item()
        r1_val = loss_reduced["r1"].mean().item()
        path_loss_val = loss_reduced["path"].mean().item()
        real_score_val = loss_reduced["real_score"].mean().item()
        fake_score_val = loss_reduced["fake_score"].mean().item()
        path_length_val = loss_reduced["path_length"].mean().item()

        # All logging / sampling / checkpointing happens on rank 0 only.
        if get_rank() == 0:
            pbar.set_description(
                (
                    f"d: {d_loss_val:.4f}; g: {g_loss_val:.4f}; r1: {r1_val:.4f}; "
                    f"path: {path_loss_val:.4f}; mean path: {mean_path_length_avg:.4f}"
                )
            )

            if wandb and args.wandb:
                wandb.log(
                    {
                        "Generator": g_loss_val,
                        "Discriminator": d_loss_val,
                        "R1": r1_val,
                        "Path Length Regularization": path_loss_val,
                        "Mean Path Length": mean_path_length,
                        "Real Score": real_score_val,
                        "Fake Score": fake_score_val,
                        "Path Length": path_length_val,
                    }
                )

            # Periodic sample grid from the EMA generator.
            if i % 100 == 0:
                with torch.no_grad():
                    g_ema.eval()
                    sample, _ = g_ema([sample_z])
                    # NOTE(review): `range=` was renamed `value_range=` in
                    # newer torchvision — confirm the pinned version.
                    utils.save_image(
                        sample,
                        f"sample/{str(i).zfill(6)}.png",
                        nrow=int(args.n_sample ** 0.5),
                        normalize=True,
                        range=(-1, 1),
                    )

            # Periodic full checkpoint (models + optimizers).
            if i % 10000 == 0:
                torch.save(
                    {
                        "g": g_module.state_dict(),
                        "d": d_module.state_dict(),
                        "g_ema": g_ema.state_dict(),
                        "g_optim": g_optim.state_dict(),
                        "d_optim": d_optim.state_dict(),
                    },
                    f"checkpoint/{str(i).zfill(6)}.pt",
                )
if __name__ == "__main__":
    # Training requires CUDA; no CPU fallback is provided.
    device = "cuda"

    parser = argparse.ArgumentParser()

    parser.add_argument("path", type=str)  # path to the LMDB dataset
    parser.add_argument("--iter", type=int, default=800000)
    parser.add_argument("--batch", type=int, default=16)
    parser.add_argument("--n_sample", type=int, default=64)  # latents per sample grid
    parser.add_argument("--size", type=int, default=256)  # image resolution
    parser.add_argument("--r1", type=float, default=10)  # R1 penalty weight
    parser.add_argument("--path_regularize", type=float, default=2)
    parser.add_argument("--path_batch_shrink", type=int, default=2)
    parser.add_argument("--d_reg_every", type=int, default=16)  # R1 cadence
    parser.add_argument("--g_reg_every", type=int, default=4)  # path-reg cadence
    parser.add_argument("--mixing", type=float, default=0.9)  # style-mixing prob
    parser.add_argument("--ckpt", type=str, default=None)  # resume checkpoint
    parser.add_argument("--lr", type=float, default=0.002)
    parser.add_argument("--channel_multiplier", type=int, default=2)
    parser.add_argument("--wandb", action="store_true")
    parser.add_argument("--local_rank", type=int, default=0)  # set by torch launcher

    args = parser.parse_args()

    # WORLD_SIZE is set by the distributed launcher; >1 GPU enables DDP.
    n_gpu = int(os.environ["WORLD_SIZE"]) if "WORLD_SIZE" in os.environ else 1
    args.distributed = n_gpu > 1

    if args.distributed:
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(backend="nccl", init_method="env://")
        synchronize()

    # Fixed StyleGAN2 architecture hyperparameters.
    args.latent = 512
    args.n_mlp = 8

    args.start_iter = 0

    generator = Generator(
        args.size, args.latent, args.n_mlp, channel_multiplier=args.channel_multiplier
    ).to(device)
    discriminator = Discriminator(
        args.size, channel_multiplier=args.channel_multiplier
    ).to(device)
    g_ema = Generator(
        args.size, args.latent, args.n_mlp, channel_multiplier=args.channel_multiplier
    ).to(device)
    g_ema.eval()
    # decay=0 copies the generator's initial weights into the EMA model.
    accumulate(g_ema, generator, 0)

    # Lazy regularization rescales lr/betas so effective optimization matches
    # the non-lazy schedule (StyleGAN2 paper, appendix B).
    g_reg_ratio = args.g_reg_every / (args.g_reg_every + 1)
    d_reg_ratio = args.d_reg_every / (args.d_reg_every + 1)

    g_optim = optim.Adam(
        generator.parameters(),
        lr=args.lr * g_reg_ratio,
        betas=(0 ** g_reg_ratio, 0.99 ** g_reg_ratio),
    )
    d_optim = optim.Adam(
        discriminator.parameters(),
        lr=args.lr * d_reg_ratio,
        betas=(0 ** d_reg_ratio, 0.99 ** d_reg_ratio),
    )

    if args.ckpt is not None:
        print("load model:", args.ckpt)

        # map_location keeps tensors on CPU until the explicit .to(device).
        ckpt = torch.load(args.ckpt, map_location=lambda storage, loc: storage)

        try:
            # Checkpoints are named by iteration (e.g. 010000.pt); recover
            # the resume point from the filename when possible.
            ckpt_name = os.path.basename(args.ckpt)
            args.start_iter = int(os.path.splitext(ckpt_name)[0])

        except ValueError:
            pass

        generator.load_state_dict(ckpt["g"])
        discriminator.load_state_dict(ckpt["d"])
        g_ema.load_state_dict(ckpt["g_ema"])

        g_optim.load_state_dict(ckpt["g_optim"])
        d_optim.load_state_dict(ckpt["d_optim"])

    if args.distributed:
        # broadcast_buffers=False: each rank keeps its own buffers (e.g. the
        # per-layer noise inputs), which this codebase relies on.
        generator = nn.parallel.DistributedDataParallel(
            generator,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
            broadcast_buffers=False,
        )

        discriminator = nn.parallel.DistributedDataParallel(
            discriminator,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
            broadcast_buffers=False,
        )

    # Images are normalized to [-1, 1] to match the generator's tanh-range output.
    transform = transforms.Compose(
        [
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True),
        ]
    )

    dataset = MultiResolutionDataset(args.path, transform, args.size)
    loader = data.DataLoader(
        dataset,
        batch_size=args.batch,
        sampler=data_sampler(dataset, shuffle=True, distributed=args.distributed),
        drop_last=True,
    )

    # wandb is presumably imported with a try/except at the top of the file
    # (None when unavailable) — hence the None check here.
    if get_rank() == 0 and wandb is not None and args.wandb:
        wandb.init(project="stylegan 2")

    train(args, loader, generator, discriminator, g_optim, d_optim, g_ema, device)