Bill Ton Hoang Nguyen commited on
Commit
e5e3367
1 Parent(s): e36fc61
Files changed (43) hide show
  1. .DS_Store +0 -0
  2. models/.DS_Store +0 -0
  3. models/stylegan2/.DS_Store +0 -0
  4. models/stylegan2/stylegan2-pytorch/.DS_Store +0 -0
  5. models/stylegan2/stylegan2-pytorch/LICENSE +21 -0
  6. models/stylegan2/stylegan2-pytorch/LICENSE-FID +201 -0
  7. models/stylegan2/stylegan2-pytorch/LICENSE-LPIPS +24 -0
  8. models/stylegan2/stylegan2-pytorch/LICENSE-NVIDIA +101 -0
  9. models/stylegan2/stylegan2-pytorch/README.md +83 -0
  10. models/stylegan2/stylegan2-pytorch/calc_inception.py +116 -0
  11. models/stylegan2/stylegan2-pytorch/checkpoint/.gitignore +1 -0
  12. models/stylegan2/stylegan2-pytorch/convert_weight.py +283 -0
  13. models/stylegan2/stylegan2-pytorch/dataset.py +40 -0
  14. models/stylegan2/stylegan2-pytorch/distributed.py +126 -0
  15. models/stylegan2/stylegan2-pytorch/fid.py +107 -0
  16. models/stylegan2/stylegan2-pytorch/generate.py +55 -0
  17. models/stylegan2/stylegan2-pytorch/inception.py +310 -0
  18. models/stylegan2/stylegan2-pytorch/lpips/__init__.py +160 -0
  19. models/stylegan2/stylegan2-pytorch/lpips/base_model.py +58 -0
  20. models/stylegan2/stylegan2-pytorch/lpips/dist_model.py +284 -0
  21. models/stylegan2/stylegan2-pytorch/lpips/networks_basic.py +187 -0
  22. models/stylegan2/stylegan2-pytorch/lpips/pretrained_networks.py +181 -0
  23. models/stylegan2/stylegan2-pytorch/lpips/weights/v0.0/alex.pth +0 -0
  24. models/stylegan2/stylegan2-pytorch/lpips/weights/v0.0/squeeze.pth +0 -0
  25. models/stylegan2/stylegan2-pytorch/lpips/weights/v0.0/vgg.pth +0 -0
  26. models/stylegan2/stylegan2-pytorch/lpips/weights/v0.1/alex.pth +0 -0
  27. models/stylegan2/stylegan2-pytorch/lpips/weights/v0.1/squeeze.pth +0 -0
  28. models/stylegan2/stylegan2-pytorch/lpips/weights/v0.1/vgg.pth +0 -0
  29. models/stylegan2/stylegan2-pytorch/model.py +703 -0
  30. models/stylegan2/stylegan2-pytorch/non_leaking.py +137 -0
  31. models/stylegan2/stylegan2-pytorch/op/__init__.py +2 -0
  32. models/stylegan2/stylegan2-pytorch/op/fused_act.py +92 -0
  33. models/stylegan2/stylegan2-pytorch/op/fused_bias_act.cpp +21 -0
  34. models/stylegan2/stylegan2-pytorch/op/fused_bias_act_kernel.cu +99 -0
  35. models/stylegan2/stylegan2-pytorch/op/setup.py +33 -0
  36. models/stylegan2/stylegan2-pytorch/op/upfirdn2d.cpp +23 -0
  37. models/stylegan2/stylegan2-pytorch/op/upfirdn2d.py +198 -0
  38. models/stylegan2/stylegan2-pytorch/op/upfirdn2d_kernel.cu +369 -0
  39. models/stylegan2/stylegan2-pytorch/ppl.py +104 -0
  40. models/stylegan2/stylegan2-pytorch/prepare_data.py +82 -0
  41. models/stylegan2/stylegan2-pytorch/projector.py +203 -0
  42. models/stylegan2/stylegan2-pytorch/sample/.gitignore +1 -0
  43. models/stylegan2/stylegan2-pytorch/train.py +413 -0
.DS_Store CHANGED
Binary files a/.DS_Store and b/.DS_Store differ
 
models/.DS_Store CHANGED
Binary files a/models/.DS_Store and b/models/.DS_Store differ
 
models/stylegan2/.DS_Store CHANGED
Binary files a/models/stylegan2/.DS_Store and b/models/stylegan2/.DS_Store differ
 
models/stylegan2/stylegan2-pytorch/.DS_Store ADDED
Binary file (8.2 kB). View file
 
models/stylegan2/stylegan2-pytorch/LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2019 Kim Seonghyeon
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
models/stylegan2/stylegan2-pytorch/LICENSE-FID ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright [yyyy] [name of copyright owner]
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
models/stylegan2/stylegan2-pytorch/LICENSE-LPIPS ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Copyright (c) 2018, Richard Zhang, Phillip Isola, Alexei A. Efros, Eli Shechtman, Oliver Wang
2
+ All rights reserved.
3
+
4
+ Redistribution and use in source and binary forms, with or without
5
+ modification, are permitted provided that the following conditions are met:
6
+
7
+ * Redistributions of source code must retain the above copyright notice, this
8
+ list of conditions and the following disclaimer.
9
+
10
+ * Redistributions in binary form must reproduce the above copyright notice,
11
+ this list of conditions and the following disclaimer in the documentation
12
+ and/or other materials provided with the distribution.
13
+
14
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
15
+ AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
16
+ IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
17
+ DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
18
+ FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
19
+ DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
20
+ SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
21
+ CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
22
+ OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
23
+ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
24
+
models/stylegan2/stylegan2-pytorch/LICENSE-NVIDIA ADDED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
2
+
3
+
4
+ Nvidia Source Code License-NC
5
+
6
+ =======================================================================
7
+
8
+ 1. Definitions
9
+
10
+ "Licensor" means any person or entity that distributes its Work.
11
+
12
+ "Software" means the original work of authorship made available under
13
+ this License.
14
+
15
+ "Work" means the Software and any additions to or derivative works of
16
+ the Software that are made available under this License.
17
+
18
+ "Nvidia Processors" means any central processing unit (CPU), graphics
19
+ processing unit (GPU), field-programmable gate array (FPGA),
20
+ application-specific integrated circuit (ASIC) or any combination
21
+ thereof designed, made, sold, or provided by Nvidia or its affiliates.
22
+
23
+ The terms "reproduce," "reproduction," "derivative works," and
24
+ "distribution" have the meaning as provided under U.S. copyright law;
25
+ provided, however, that for the purposes of this License, derivative
26
+ works shall not include works that remain separable from, or merely
27
+ link (or bind by name) to the interfaces of, the Work.
28
+
29
+ Works, including the Software, are "made available" under this License
30
+ by including in or with the Work either (a) a copyright notice
31
+ referencing the applicability of this License to the Work, or (b) a
32
+ copy of this License.
33
+
34
+ 2. License Grants
35
+
36
+ 2.1 Copyright Grant. Subject to the terms and conditions of this
37
+ License, each Licensor grants to you a perpetual, worldwide,
38
+ non-exclusive, royalty-free, copyright license to reproduce,
39
+ prepare derivative works of, publicly display, publicly perform,
40
+ sublicense and distribute its Work and any resulting derivative
41
+ works in any form.
42
+
43
+ 3. Limitations
44
+
45
+ 3.1 Redistribution. You may reproduce or distribute the Work only
46
+ if (a) you do so under this License, (b) you include a complete
47
+ copy of this License with your distribution, and (c) you retain
48
+ without modification any copyright, patent, trademark, or
49
+ attribution notices that are present in the Work.
50
+
51
+ 3.2 Derivative Works. You may specify that additional or different
52
+ terms apply to the use, reproduction, and distribution of your
53
+ derivative works of the Work ("Your Terms") only if (a) Your Terms
54
+ provide that the use limitation in Section 3.3 applies to your
55
+ derivative works, and (b) you identify the specific derivative
56
+ works that are subject to Your Terms. Notwithstanding Your Terms,
57
+ this License (including the redistribution requirements in Section
58
+ 3.1) will continue to apply to the Work itself.
59
+
60
+ 3.3 Use Limitation. The Work and any derivative works thereof only
61
+ may be used or intended for use non-commercially. The Work or
62
+ derivative works thereof may be used or intended for use by Nvidia
63
+ or its affiliates commercially or non-commercially. As used herein,
64
+ "non-commercially" means for research or evaluation purposes only.
65
+
66
+ 3.4 Patent Claims. If you bring or threaten to bring a patent claim
67
+ against any Licensor (including any claim, cross-claim or
68
+ counterclaim in a lawsuit) to enforce any patents that you allege
69
+ are infringed by any Work, then your rights under this License from
70
+ such Licensor (including the grants in Sections 2.1 and 2.2) will
71
+ terminate immediately.
72
+
73
+ 3.5 Trademarks. This License does not grant any rights to use any
74
+ Licensor's or its affiliates' names, logos, or trademarks, except
75
+ as necessary to reproduce the notices described in this License.
76
+
77
+ 3.6 Termination. If you violate any term of this License, then your
78
+ rights under this License (including the grants in Sections 2.1 and
79
+ 2.2) will terminate immediately.
80
+
81
+ 4. Disclaimer of Warranty.
82
+
83
+ THE WORK IS PROVIDED "AS IS" WITHOUT WARRANTIES OR CONDITIONS OF ANY
84
+ KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WARRANTIES OR CONDITIONS OF
85
+ MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE OR
86
+ NON-INFRINGEMENT. YOU BEAR THE RISK OF UNDERTAKING ANY ACTIVITIES UNDER
87
+ THIS LICENSE.
88
+
89
+ 5. Limitation of Liability.
90
+
91
+ EXCEPT AS PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL
92
+ THEORY, WHETHER IN TORT (INCLUDING NEGLIGENCE), CONTRACT, OR OTHERWISE
93
+ SHALL ANY LICENSOR BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY DIRECT,
94
+ INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES ARISING OUT OF
95
+ OR RELATED TO THIS LICENSE, THE USE OR INABILITY TO USE THE WORK
96
+ (INCLUDING BUT NOT LIMITED TO LOSS OF GOODWILL, BUSINESS INTERRUPTION,
97
+ LOST PROFITS OR DATA, COMPUTER FAILURE OR MALFUNCTION, OR ANY OTHER
98
+ COMMERCIAL DAMAGES OR LOSSES), EVEN IF THE LICENSOR HAS BEEN ADVISED OF
99
+ THE POSSIBILITY OF SUCH DAMAGES.
100
+
101
+ =======================================================================
models/stylegan2/stylegan2-pytorch/README.md ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # StyleGAN 2 in PyTorch
2
+
3
+ Implementation of Analyzing and Improving the Image Quality of StyleGAN (https://arxiv.org/abs/1912.04958) in PyTorch
4
+
5
+ ## Notice
6
+
7
+ I have tried to match official implementation as close as possible, but maybe there are some details I missed. So please use this implementation with care.
8
+
9
+ ## Requirements
10
+
11
+ I have tested on:
12
+
13
+ * PyTorch 1.3.1
14
+ * CUDA 10.1/10.2
15
+
16
+ ## Usage
17
+
18
+ First create lmdb datasets:
19
+
20
+ > python prepare_data.py --out LMDB_PATH --n_worker N_WORKER --size SIZE1,SIZE2,SIZE3,... DATASET_PATH
21
+
22
+ This will convert images to jpeg and pre-resizes it. This implementation does not use progressive growing, but you can create multiple resolution datasets using size arguments with comma separated lists, for the cases that you want to try another resolutions later.
23
+
24
+ Then you can train model in distributed settings
25
+
26
+ > python -m torch.distributed.launch --nproc_per_node=N_GPU --master_port=PORT train.py --batch BATCH_SIZE LMDB_PATH
27
+
28
+ train.py supports Weights & Biases logging. If you want to use it, add --wandb arguments to the script.
29
+
30
+ ### Convert weight from official checkpoints
31
+
32
+ You need to clone official repositories, (https://github.com/NVlabs/stylegan2) as it is requires for load official checkpoints.
33
+
34
+ Next, create a conda environment with TF-GPU and Torch-CPU (using GPU for both results in CUDA version mismatches):<br>
35
+ `conda create -n tf_torch python=3.7 requests tensorflow-gpu=1.14 cudatoolkit=10.0 numpy=1.14 pytorch=1.6 torchvision cpuonly -c pytorch`
36
+
37
+ For example, if you cloned repositories in ~/stylegan2 and downloaded stylegan2-ffhq-config-f.pkl, You can convert it like this:
38
+
39
+ > python convert_weight.py --repo ~/stylegan2 stylegan2-ffhq-config-f.pkl
40
+
41
+ This will create converted stylegan2-ffhq-config-f.pt file.
42
+
43
+ If using GCC, you might have to set `-D_GLIBCXX_USE_CXX11_ABI=1` in `~/stylegan2/dnnlib/tflib/custom_ops.py`.
44
+
45
+ ### Generate samples
46
+
47
+ > python generate.py --sample N_FACES --pics N_PICS --ckpt PATH_CHECKPOINT
48
+
49
+ You should change your size (--size 256 for example) if you train with another dimension.
50
+
51
+ ### Project images to latent spaces
52
+
53
+ > python projector.py --ckpt [CHECKPOINT] --size [GENERATOR_OUTPUT_SIZE] FILE1 FILE2 ...
54
+
55
+ ## Pretrained Checkpoints
56
+
57
+ [Link](https://drive.google.com/open?id=1PQutd-JboOCOZqmd95XWxWrO8gGEvRcO)
58
+
59
+ I have trained the 256px model on FFHQ 550k iterations. I got FID about 4.5. Maybe data preprocessing, resolution, training loop could made this difference, but currently I don't know the exact reason of FID differences.
60
+
61
+ ## Samples
62
+
63
+ ![Sample with truncation](doc/sample.png)
64
+
65
+ At 110,000 iterations. (trained on 3.52M images)
66
+
67
+ ### Samples from converted weights
68
+
69
+ ![Sample from FFHQ](doc/stylegan2-ffhq-config-f.png)
70
+
71
+ Sample from FFHQ (1024px)
72
+
73
+ ![Sample from LSUN Church](doc/stylegan2-church-config-f.png)
74
+
75
+ Sample from LSUN Church (256px)
76
+
77
+ ## License
78
+
79
+ Model details and custom CUDA kernel codes are from official repostiories: https://github.com/NVlabs/stylegan2
80
+
81
+ Codes for Learned Perceptual Image Patch Similarity, LPIPS came from https://github.com/richzhang/PerceptualSimilarity
82
+
83
+ To match FID scores more closely to tensorflow official implementations, I have used FID Inception V3 implementations in https://github.com/mseitzer/pytorch-fid
models/stylegan2/stylegan2-pytorch/calc_inception.py ADDED
@@ -0,0 +1,116 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import pickle
3
+ import os
4
+
5
+ import torch
6
+ from torch import nn
7
+ from torch.nn import functional as F
8
+ from torch.utils.data import DataLoader
9
+ from torchvision import transforms
10
+ from torchvision.models import inception_v3, Inception3
11
+ import numpy as np
12
+ from tqdm import tqdm
13
+
14
+ from inception import InceptionV3
15
+ from dataset import MultiResolutionDataset
16
+
17
+
18
+ class Inception3Feature(Inception3):
19
+ def forward(self, x):
20
+ if x.shape[2] != 299 or x.shape[3] != 299:
21
+ x = F.interpolate(x, size=(299, 299), mode='bilinear', align_corners=True)
22
+
23
+ x = self.Conv2d_1a_3x3(x) # 299 x 299 x 3
24
+ x = self.Conv2d_2a_3x3(x) # 149 x 149 x 32
25
+ x = self.Conv2d_2b_3x3(x) # 147 x 147 x 32
26
+ x = F.max_pool2d(x, kernel_size=3, stride=2) # 147 x 147 x 64
27
+
28
+ x = self.Conv2d_3b_1x1(x) # 73 x 73 x 64
29
+ x = self.Conv2d_4a_3x3(x) # 73 x 73 x 80
30
+ x = F.max_pool2d(x, kernel_size=3, stride=2) # 71 x 71 x 192
31
+
32
+ x = self.Mixed_5b(x) # 35 x 35 x 192
33
+ x = self.Mixed_5c(x) # 35 x 35 x 256
34
+ x = self.Mixed_5d(x) # 35 x 35 x 288
35
+
36
+ x = self.Mixed_6a(x) # 35 x 35 x 288
37
+ x = self.Mixed_6b(x) # 17 x 17 x 768
38
+ x = self.Mixed_6c(x) # 17 x 17 x 768
39
+ x = self.Mixed_6d(x) # 17 x 17 x 768
40
+ x = self.Mixed_6e(x) # 17 x 17 x 768
41
+
42
+ x = self.Mixed_7a(x) # 17 x 17 x 768
43
+ x = self.Mixed_7b(x) # 8 x 8 x 1280
44
+ x = self.Mixed_7c(x) # 8 x 8 x 2048
45
+
46
+ x = F.avg_pool2d(x, kernel_size=8) # 8 x 8 x 2048
47
+
48
+ return x.view(x.shape[0], x.shape[1]) # 1 x 1 x 2048
49
+
50
+
51
+ def load_patched_inception_v3():
52
+ # inception = inception_v3(pretrained=True)
53
+ # inception_feat = Inception3Feature()
54
+ # inception_feat.load_state_dict(inception.state_dict())
55
+ inception_feat = InceptionV3([3], normalize_input=False)
56
+
57
+ return inception_feat
58
+
59
+
60
+ @torch.no_grad()
61
+ def extract_features(loader, inception, device):
62
+ pbar = tqdm(loader)
63
+
64
+ feature_list = []
65
+
66
+ for img in pbar:
67
+ img = img.to(device)
68
+ feature = inception(img)[0].view(img.shape[0], -1)
69
+ feature_list.append(feature.to('cpu'))
70
+
71
+ features = torch.cat(feature_list, 0)
72
+
73
+ return features
74
+
75
+
76
+ if __name__ == '__main__':
77
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
78
+
79
+ parser = argparse.ArgumentParser(
80
+ description='Calculate Inception v3 features for datasets'
81
+ )
82
+ parser.add_argument('--size', type=int, default=256)
83
+ parser.add_argument('--batch', default=64, type=int, help='batch size')
84
+ parser.add_argument('--n_sample', type=int, default=50000)
85
+ parser.add_argument('--flip', action='store_true')
86
+ parser.add_argument('path', metavar='PATH', help='path to datset lmdb file')
87
+
88
+ args = parser.parse_args()
89
+
90
+ inception = load_patched_inception_v3()
91
+ inception = nn.DataParallel(inception).eval().to(device)
92
+
93
+ transform = transforms.Compose(
94
+ [
95
+ transforms.RandomHorizontalFlip(p=0.5 if args.flip else 0),
96
+ transforms.ToTensor(),
97
+ transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
98
+ ]
99
+ )
100
+
101
+ dset = MultiResolutionDataset(args.path, transform=transform, resolution=args.size)
102
+ loader = DataLoader(dset, batch_size=args.batch, num_workers=4)
103
+
104
+ features = extract_features(loader, inception, device).numpy()
105
+
106
+ features = features[: args.n_sample]
107
+
108
+ print(f'extracted {features.shape[0]} features')
109
+
110
+ mean = np.mean(features, 0)
111
+ cov = np.cov(features, rowvar=False)
112
+
113
+ name = os.path.splitext(os.path.basename(args.path))[0]
114
+
115
+ with open(f'inception_{name}.pkl', 'wb') as f:
116
+ pickle.dump({'mean': mean, 'cov': cov, 'size': args.size, 'path': args.path}, f)
models/stylegan2/stylegan2-pytorch/checkpoint/.gitignore ADDED
@@ -0,0 +1 @@
 
 
1
+ *.pt
models/stylegan2/stylegan2-pytorch/convert_weight.py ADDED
@@ -0,0 +1,283 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import os
3
+ import sys
4
+ import pickle
5
+ import math
6
+
7
+ import torch
8
+ import numpy as np
9
+ from torchvision import utils
10
+
11
+ from model import Generator, Discriminator
12
+
13
+
14
+ def convert_modconv(vars, source_name, target_name, flip=False):
15
+ weight = vars[source_name + '/weight'].value().eval()
16
+ mod_weight = vars[source_name + '/mod_weight'].value().eval()
17
+ mod_bias = vars[source_name + '/mod_bias'].value().eval()
18
+ noise = vars[source_name + '/noise_strength'].value().eval()
19
+ bias = vars[source_name + '/bias'].value().eval()
20
+
21
+ dic = {
22
+ 'conv.weight': np.expand_dims(weight.transpose((3, 2, 0, 1)), 0),
23
+ 'conv.modulation.weight': mod_weight.transpose((1, 0)),
24
+ 'conv.modulation.bias': mod_bias + 1,
25
+ 'noise.weight': np.array([noise]),
26
+ 'activate.bias': bias,
27
+ }
28
+
29
+ dic_torch = {}
30
+
31
+ for k, v in dic.items():
32
+ dic_torch[target_name + '.' + k] = torch.from_numpy(v)
33
+
34
+ if flip:
35
+ dic_torch[target_name + '.conv.weight'] = torch.flip(
36
+ dic_torch[target_name + '.conv.weight'], [3, 4]
37
+ )
38
+
39
+ return dic_torch
40
+
41
+
42
+ def convert_conv(vars, source_name, target_name, bias=True, start=0):
43
+ weight = vars[source_name + '/weight'].value().eval()
44
+
45
+ dic = {'weight': weight.transpose((3, 2, 0, 1))}
46
+
47
+ if bias:
48
+ dic['bias'] = vars[source_name + '/bias'].value().eval()
49
+
50
+ dic_torch = {}
51
+
52
+ dic_torch[target_name + f'.{start}.weight'] = torch.from_numpy(dic['weight'])
53
+
54
+ if bias:
55
+ dic_torch[target_name + f'.{start + 1}.bias'] = torch.from_numpy(dic['bias'])
56
+
57
+ return dic_torch
58
+
59
+
60
+ def convert_torgb(vars, source_name, target_name):
61
+ weight = vars[source_name + '/weight'].value().eval()
62
+ mod_weight = vars[source_name + '/mod_weight'].value().eval()
63
+ mod_bias = vars[source_name + '/mod_bias'].value().eval()
64
+ bias = vars[source_name + '/bias'].value().eval()
65
+
66
+ dic = {
67
+ 'conv.weight': np.expand_dims(weight.transpose((3, 2, 0, 1)), 0),
68
+ 'conv.modulation.weight': mod_weight.transpose((1, 0)),
69
+ 'conv.modulation.bias': mod_bias + 1,
70
+ 'bias': bias.reshape((1, 3, 1, 1)),
71
+ }
72
+
73
+ dic_torch = {}
74
+
75
+ for k, v in dic.items():
76
+ dic_torch[target_name + '.' + k] = torch.from_numpy(v)
77
+
78
+ return dic_torch
79
+
80
+
81
+ def convert_dense(vars, source_name, target_name):
82
+ weight = vars[source_name + '/weight'].value().eval()
83
+ bias = vars[source_name + '/bias'].value().eval()
84
+
85
+ dic = {'weight': weight.transpose((1, 0)), 'bias': bias}
86
+
87
+ dic_torch = {}
88
+
89
+ for k, v in dic.items():
90
+ dic_torch[target_name + '.' + k] = torch.from_numpy(v)
91
+
92
+ return dic_torch
93
+
94
+
95
+ def update(state_dict, new):
96
+ for k, v in new.items():
97
+ if k not in state_dict:
98
+ raise KeyError(k + ' is not found')
99
+
100
+ if v.shape != state_dict[k].shape:
101
+ raise ValueError(f'Shape mismatch: {v.shape} vs {state_dict[k].shape}')
102
+
103
+ state_dict[k] = v
104
+
105
+
106
+ def discriminator_fill_statedict(statedict, vars, size):
107
+ log_size = int(math.log(size, 2))
108
+
109
+ update(statedict, convert_conv(vars, f'{size}x{size}/FromRGB', 'convs.0'))
110
+
111
+ conv_i = 1
112
+
113
+ for i in range(log_size - 2, 0, -1):
114
+ reso = 4 * 2 ** i
115
+ update(
116
+ statedict,
117
+ convert_conv(vars, f'{reso}x{reso}/Conv0', f'convs.{conv_i}.conv1'),
118
+ )
119
+ update(
120
+ statedict,
121
+ convert_conv(
122
+ vars, f'{reso}x{reso}/Conv1_down', f'convs.{conv_i}.conv2', start=1
123
+ ),
124
+ )
125
+ update(
126
+ statedict,
127
+ convert_conv(
128
+ vars, f'{reso}x{reso}/Skip', f'convs.{conv_i}.skip', start=1, bias=False
129
+ ),
130
+ )
131
+ conv_i += 1
132
+
133
+ update(statedict, convert_conv(vars, f'4x4/Conv', 'final_conv'))
134
+ update(statedict, convert_dense(vars, f'4x4/Dense0', 'final_linear.0'))
135
+ update(statedict, convert_dense(vars, f'Output', 'final_linear.1'))
136
+
137
+ return statedict
138
+
139
+
140
+ def fill_statedict(state_dict, vars, size):
141
+ log_size = int(math.log(size, 2))
142
+
143
+ for i in range(8):
144
+ update(state_dict, convert_dense(vars, f'G_mapping/Dense{i}', f'style.{i + 1}'))
145
+
146
+ update(
147
+ state_dict,
148
+ {
149
+ 'input.input': torch.from_numpy(
150
+ vars['G_synthesis/4x4/Const/const'].value().eval()
151
+ )
152
+ },
153
+ )
154
+
155
+ update(state_dict, convert_torgb(vars, 'G_synthesis/4x4/ToRGB', 'to_rgb1'))
156
+
157
+ for i in range(log_size - 2):
158
+ reso = 4 * 2 ** (i + 1)
159
+ update(
160
+ state_dict,
161
+ convert_torgb(vars, f'G_synthesis/{reso}x{reso}/ToRGB', f'to_rgbs.{i}'),
162
+ )
163
+
164
+ update(state_dict, convert_modconv(vars, 'G_synthesis/4x4/Conv', 'conv1'))
165
+
166
+ conv_i = 0
167
+
168
+ for i in range(log_size - 2):
169
+ reso = 4 * 2 ** (i + 1)
170
+ update(
171
+ state_dict,
172
+ convert_modconv(
173
+ vars,
174
+ f'G_synthesis/{reso}x{reso}/Conv0_up',
175
+ f'convs.{conv_i}',
176
+ flip=True,
177
+ ),
178
+ )
179
+ update(
180
+ state_dict,
181
+ convert_modconv(
182
+ vars, f'G_synthesis/{reso}x{reso}/Conv1', f'convs.{conv_i + 1}'
183
+ ),
184
+ )
185
+ conv_i += 2
186
+
187
+ for i in range(0, (log_size - 2) * 2 + 1):
188
+ update(
189
+ state_dict,
190
+ {
191
+ f'noises.noise_{i}': torch.from_numpy(
192
+ vars[f'G_synthesis/noise{i}'].value().eval()
193
+ )
194
+ },
195
+ )
196
+
197
+ return state_dict
198
+
199
+
200
+ if __name__ == '__main__':
201
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
202
+ print('Using PyTorch device', device)
203
+
204
+ parser = argparse.ArgumentParser()
205
+ parser.add_argument('--repo', type=str, required=True)
206
+ parser.add_argument('--gen', action='store_true')
207
+ parser.add_argument('--disc', action='store_true')
208
+ parser.add_argument('--channel_multiplier', type=int, default=2)
209
+ parser.add_argument('path', metavar='PATH')
210
+
211
+ args = parser.parse_args()
212
+
213
+ sys.path.append(args.repo)
214
+
215
+ import dnnlib
216
+ from dnnlib import tflib
217
+
218
+ tflib.init_tf()
219
+
220
+ with open(args.path, 'rb') as f:
221
+ generator, discriminator, g_ema = pickle.load(f)
222
+
223
+ size = g_ema.output_shape[2]
224
+
225
+ g = Generator(size, 512, 8, channel_multiplier=args.channel_multiplier)
226
+ state_dict = g.state_dict()
227
+ state_dict = fill_statedict(state_dict, g_ema.vars, size)
228
+
229
+ g.load_state_dict(state_dict)
230
+
231
+ latent_avg = torch.from_numpy(g_ema.vars['dlatent_avg'].value().eval())
232
+
233
+ ckpt = {'g_ema': state_dict, 'latent_avg': latent_avg}
234
+
235
+ if args.gen:
236
+ g_train = Generator(size, 512, 8, channel_multiplier=args.channel_multiplier)
237
+ g_train_state = g_train.state_dict()
238
+ g_train_state = fill_statedict(g_train_state, generator.vars, size)
239
+ ckpt['g'] = g_train_state
240
+
241
+ if args.disc:
242
+ disc = Discriminator(size, channel_multiplier=args.channel_multiplier)
243
+ d_state = disc.state_dict()
244
+ d_state = discriminator_fill_statedict(d_state, discriminator.vars, size)
245
+ ckpt['d'] = d_state
246
+
247
+ name = os.path.splitext(os.path.basename(args.path))[0]
248
+ outpath = os.path.join(os.getcwd(), f'{name}.pt')
249
+ print('Saving', outpath)
250
+ try:
251
+ torch.save(ckpt, outpath, _use_new_zipfile_serialization=False)
252
+ except TypeError:
253
+ torch.save(ckpt, outpath)
254
+
255
+
256
+ print('Generating TF-Torch comparison images')
257
+ batch_size = {256: 8, 512: 4, 1024: 2}
258
+ n_sample = batch_size.get(size, 4)
259
+
260
+ g = g.to(device)
261
+
262
+ z = np.random.RandomState(0).randn(n_sample, 512).astype('float32')
263
+
264
+ with torch.no_grad():
265
+ img_pt, _ = g(
266
+ [torch.from_numpy(z).to(device)],
267
+ truncation=0.5,
268
+ truncation_latent=latent_avg.to(device),
269
+ )
270
+
271
+ img_tf = g_ema.run(z, None, randomize_noise=False)
272
+ img_tf = torch.from_numpy(img_tf).to(device)
273
+
274
+ img_diff = ((img_pt + 1) / 2).clamp(0.0, 1.0) - ((img_tf.to(device) + 1) / 2).clamp(
275
+ 0.0, 1.0
276
+ )
277
+
278
+ img_concat = torch.cat((img_tf, img_pt, img_diff), dim=0)
279
+ utils.save_image(
280
+ img_concat, name + '.png', nrow=n_sample, normalize=True, range=(-1, 1)
281
+ )
282
+ print('Done')
283
+
models/stylegan2/stylegan2-pytorch/dataset.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from io import BytesIO
2
+
3
+ import lmdb
4
+ from PIL import Image
5
+ from torch.utils.data import Dataset
6
+
7
+
8
+ class MultiResolutionDataset(Dataset):
9
+ def __init__(self, path, transform, resolution=256):
10
+ self.env = lmdb.open(
11
+ path,
12
+ max_readers=32,
13
+ readonly=True,
14
+ lock=False,
15
+ readahead=False,
16
+ meminit=False,
17
+ )
18
+
19
+ if not self.env:
20
+ raise IOError('Cannot open lmdb dataset', path)
21
+
22
+ with self.env.begin(write=False) as txn:
23
+ self.length = int(txn.get('length'.encode('utf-8')).decode('utf-8'))
24
+
25
+ self.resolution = resolution
26
+ self.transform = transform
27
+
28
+ def __len__(self):
29
+ return self.length
30
+
31
+ def __getitem__(self, index):
32
+ with self.env.begin(write=False) as txn:
33
+ key = f'{self.resolution}-{str(index).zfill(5)}'.encode('utf-8')
34
+ img_bytes = txn.get(key)
35
+
36
+ buffer = BytesIO(img_bytes)
37
+ img = Image.open(buffer)
38
+ img = self.transform(img)
39
+
40
+ return img
models/stylegan2/stylegan2-pytorch/distributed.py ADDED
@@ -0,0 +1,126 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import pickle
3
+
4
+ import torch
5
+ from torch import distributed as dist
6
+ from torch.utils.data.sampler import Sampler
7
+
8
+
9
+ def get_rank():
10
+ if not dist.is_available():
11
+ return 0
12
+
13
+ if not dist.is_initialized():
14
+ return 0
15
+
16
+ return dist.get_rank()
17
+
18
+
19
+ def synchronize():
20
+ if not dist.is_available():
21
+ return
22
+
23
+ if not dist.is_initialized():
24
+ return
25
+
26
+ world_size = dist.get_world_size()
27
+
28
+ if world_size == 1:
29
+ return
30
+
31
+ dist.barrier()
32
+
33
+
34
+ def get_world_size():
35
+ if not dist.is_available():
36
+ return 1
37
+
38
+ if not dist.is_initialized():
39
+ return 1
40
+
41
+ return dist.get_world_size()
42
+
43
+
44
+ def reduce_sum(tensor):
45
+ if not dist.is_available():
46
+ return tensor
47
+
48
+ if not dist.is_initialized():
49
+ return tensor
50
+
51
+ tensor = tensor.clone()
52
+ dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
53
+
54
+ return tensor
55
+
56
+
57
+ def gather_grad(params):
58
+ world_size = get_world_size()
59
+
60
+ if world_size == 1:
61
+ return
62
+
63
+ for param in params:
64
+ if param.grad is not None:
65
+ dist.all_reduce(param.grad.data, op=dist.ReduceOp.SUM)
66
+ param.grad.data.div_(world_size)
67
+
68
+
69
+ def all_gather(data):
70
+ world_size = get_world_size()
71
+
72
+ if world_size == 1:
73
+ return [data]
74
+
75
+ buffer = pickle.dumps(data)
76
+ storage = torch.ByteStorage.from_buffer(buffer)
77
+ tensor = torch.ByteTensor(storage).to('cuda')
78
+
79
+ local_size = torch.IntTensor([tensor.numel()]).to('cuda')
80
+ size_list = [torch.IntTensor([0]).to('cuda') for _ in range(world_size)]
81
+ dist.all_gather(size_list, local_size)
82
+ size_list = [int(size.item()) for size in size_list]
83
+ max_size = max(size_list)
84
+
85
+ tensor_list = []
86
+ for _ in size_list:
87
+ tensor_list.append(torch.ByteTensor(size=(max_size,)).to('cuda'))
88
+
89
+ if local_size != max_size:
90
+ padding = torch.ByteTensor(size=(max_size - local_size,)).to('cuda')
91
+ tensor = torch.cat((tensor, padding), 0)
92
+
93
+ dist.all_gather(tensor_list, tensor)
94
+
95
+ data_list = []
96
+
97
+ for size, tensor in zip(size_list, tensor_list):
98
+ buffer = tensor.cpu().numpy().tobytes()[:size]
99
+ data_list.append(pickle.loads(buffer))
100
+
101
+ return data_list
102
+
103
+
104
+ def reduce_loss_dict(loss_dict):
105
+ world_size = get_world_size()
106
+
107
+ if world_size < 2:
108
+ return loss_dict
109
+
110
+ with torch.no_grad():
111
+ keys = []
112
+ losses = []
113
+
114
+ for k in sorted(loss_dict.keys()):
115
+ keys.append(k)
116
+ losses.append(loss_dict[k])
117
+
118
+ losses = torch.stack(losses, 0)
119
+ dist.reduce(losses, dst=0)
120
+
121
+ if dist.get_rank() == 0:
122
+ losses /= world_size
123
+
124
+ reduced_losses = {k: v for k, v in zip(keys, losses)}
125
+
126
+ return reduced_losses
models/stylegan2/stylegan2-pytorch/fid.py ADDED
@@ -0,0 +1,107 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import pickle
3
+
4
+ import torch
5
+ from torch import nn
6
+ import numpy as np
7
+ from scipy import linalg
8
+ from tqdm import tqdm
9
+
10
+ from model import Generator
11
+ from calc_inception import load_patched_inception_v3
12
+
13
+
14
+ @torch.no_grad()
15
+ def extract_feature_from_samples(
16
+ generator, inception, truncation, truncation_latent, batch_size, n_sample, device
17
+ ):
18
+ n_batch = n_sample // batch_size
19
+ resid = n_sample - (n_batch * batch_size)
20
+ batch_sizes = [batch_size] * n_batch + [resid]
21
+ features = []
22
+
23
+ for batch in tqdm(batch_sizes):
24
+ latent = torch.randn(batch, 512, device=device)
25
+ img, _ = g([latent], truncation=truncation, truncation_latent=truncation_latent)
26
+ feat = inception(img)[0].view(img.shape[0], -1)
27
+ features.append(feat.to('cpu'))
28
+
29
+ features = torch.cat(features, 0)
30
+
31
+ return features
32
+
33
+
34
+ def calc_fid(sample_mean, sample_cov, real_mean, real_cov, eps=1e-6):
35
+ cov_sqrt, _ = linalg.sqrtm(sample_cov @ real_cov, disp=False)
36
+
37
+ if not np.isfinite(cov_sqrt).all():
38
+ print('product of cov matrices is singular')
39
+ offset = np.eye(sample_cov.shape[0]) * eps
40
+ cov_sqrt = linalg.sqrtm((sample_cov + offset) @ (real_cov + offset))
41
+
42
+ if np.iscomplexobj(cov_sqrt):
43
+ if not np.allclose(np.diagonal(cov_sqrt).imag, 0, atol=1e-3):
44
+ m = np.max(np.abs(cov_sqrt.imag))
45
+
46
+ raise ValueError(f'Imaginary component {m}')
47
+
48
+ cov_sqrt = cov_sqrt.real
49
+
50
+ mean_diff = sample_mean - real_mean
51
+ mean_norm = mean_diff @ mean_diff
52
+
53
+ trace = np.trace(sample_cov) + np.trace(real_cov) - 2 * np.trace(cov_sqrt)
54
+
55
+ fid = mean_norm + trace
56
+
57
+ return fid
58
+
59
+
60
+ if __name__ == '__main__':
61
+ device = 'cuda'
62
+
63
+ parser = argparse.ArgumentParser()
64
+
65
+ parser.add_argument('--truncation', type=float, default=1)
66
+ parser.add_argument('--truncation_mean', type=int, default=4096)
67
+ parser.add_argument('--batch', type=int, default=64)
68
+ parser.add_argument('--n_sample', type=int, default=50000)
69
+ parser.add_argument('--size', type=int, default=256)
70
+ parser.add_argument('--inception', type=str, default=None, required=True)
71
+ parser.add_argument('ckpt', metavar='CHECKPOINT')
72
+
73
+ args = parser.parse_args()
74
+
75
+ ckpt = torch.load(args.ckpt)
76
+
77
+ g = Generator(args.size, 512, 8).to(device)
78
+ g.load_state_dict(ckpt['g_ema'])
79
+ g = nn.DataParallel(g)
80
+ g.eval()
81
+
82
+ if args.truncation < 1:
83
+ with torch.no_grad():
84
+ mean_latent = g.mean_latent(args.truncation_mean)
85
+
86
+ else:
87
+ mean_latent = None
88
+
89
+ inception = nn.DataParallel(load_patched_inception_v3()).to(device)
90
+ inception.eval()
91
+
92
+ features = extract_feature_from_samples(
93
+ g, inception, args.truncation, mean_latent, args.batch, args.n_sample, device
94
+ ).numpy()
95
+ print(f'extracted {features.shape[0]} features')
96
+
97
+ sample_mean = np.mean(features, 0)
98
+ sample_cov = np.cov(features, rowvar=False)
99
+
100
+ with open(args.inception, 'rb') as f:
101
+ embeds = pickle.load(f)
102
+ real_mean = embeds['mean']
103
+ real_cov = embeds['cov']
104
+
105
+ fid = calc_fid(sample_mean, sample_cov, real_mean, real_cov)
106
+
107
+ print('fid:', fid)
models/stylegan2/stylegan2-pytorch/generate.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+
3
+ import torch
4
+ from torchvision import utils
5
+ from model import Generator
6
+ from tqdm import tqdm
7
+ def generate(args, g_ema, device, mean_latent):
8
+
9
+ with torch.no_grad():
10
+ g_ema.eval()
11
+ for i in tqdm(range(args.pics)):
12
+ sample_z = torch.randn(args.sample, args.latent, device=device)
13
+
14
+ sample, _ = g_ema([sample_z], truncation=args.truncation, truncation_latent=mean_latent)
15
+
16
+ utils.save_image(
17
+ sample,
18
+ f'sample/{str(i).zfill(6)}.png',
19
+ nrow=1,
20
+ normalize=True,
21
+ range=(-1, 1),
22
+ )
23
+
24
+ if __name__ == '__main__':
25
+ device = 'cuda'
26
+
27
+ parser = argparse.ArgumentParser()
28
+
29
+ parser.add_argument('--size', type=int, default=1024)
30
+ parser.add_argument('--sample', type=int, default=1)
31
+ parser.add_argument('--pics', type=int, default=20)
32
+ parser.add_argument('--truncation', type=float, default=1)
33
+ parser.add_argument('--truncation_mean', type=int, default=4096)
34
+ parser.add_argument('--ckpt', type=str, default="stylegan2-ffhq-config-f.pt")
35
+ parser.add_argument('--channel_multiplier', type=int, default=2)
36
+
37
+ args = parser.parse_args()
38
+
39
+ args.latent = 512
40
+ args.n_mlp = 8
41
+
42
+ g_ema = Generator(
43
+ args.size, args.latent, args.n_mlp, channel_multiplier=args.channel_multiplier
44
+ ).to(device)
45
+ checkpoint = torch.load(args.ckpt)
46
+
47
+ g_ema.load_state_dict(checkpoint['g_ema'])
48
+
49
+ if args.truncation < 1:
50
+ with torch.no_grad():
51
+ mean_latent = g_ema.mean_latent(args.truncation_mean)
52
+ else:
53
+ mean_latent = None
54
+
55
+ generate(args, g_ema, device, mean_latent)
models/stylegan2/stylegan2-pytorch/inception.py ADDED
@@ -0,0 +1,310 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from torchvision import models
5
+
6
+ try:
7
+ from torchvision.models.utils import load_state_dict_from_url
8
+ except ImportError:
9
+ from torch.utils.model_zoo import load_url as load_state_dict_from_url
10
+
11
+ # Inception weights ported to Pytorch from
12
+ # http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz
13
+ FID_WEIGHTS_URL = 'https://github.com/mseitzer/pytorch-fid/releases/download/fid_weights/pt_inception-2015-12-05-6726825d.pth'
14
+
15
+
16
+ class InceptionV3(nn.Module):
17
+ """Pretrained InceptionV3 network returning feature maps"""
18
+
19
+ # Index of default block of inception to return,
20
+ # corresponds to output of final average pooling
21
+ DEFAULT_BLOCK_INDEX = 3
22
+
23
+ # Maps feature dimensionality to their output blocks indices
24
+ BLOCK_INDEX_BY_DIM = {
25
+ 64: 0, # First max pooling features
26
+ 192: 1, # Second max pooling featurs
27
+ 768: 2, # Pre-aux classifier features
28
+ 2048: 3 # Final average pooling features
29
+ }
30
+
31
+ def __init__(self,
32
+ output_blocks=[DEFAULT_BLOCK_INDEX],
33
+ resize_input=True,
34
+ normalize_input=True,
35
+ requires_grad=False,
36
+ use_fid_inception=True):
37
+ """Build pretrained InceptionV3
38
+
39
+ Parameters
40
+ ----------
41
+ output_blocks : list of int
42
+ Indices of blocks to return features of. Possible values are:
43
+ - 0: corresponds to output of first max pooling
44
+ - 1: corresponds to output of second max pooling
45
+ - 2: corresponds to output which is fed to aux classifier
46
+ - 3: corresponds to output of final average pooling
47
+ resize_input : bool
48
+ If true, bilinearly resizes input to width and height 299 before
49
+ feeding input to model. As the network without fully connected
50
+ layers is fully convolutional, it should be able to handle inputs
51
+ of arbitrary size, so resizing might not be strictly needed
52
+ normalize_input : bool
53
+ If true, scales the input from range (0, 1) to the range the
54
+ pretrained Inception network expects, namely (-1, 1)
55
+ requires_grad : bool
56
+ If true, parameters of the model require gradients. Possibly useful
57
+ for finetuning the network
58
+ use_fid_inception : bool
59
+ If true, uses the pretrained Inception model used in Tensorflow's
60
+ FID implementation. If false, uses the pretrained Inception model
61
+ available in torchvision. The FID Inception model has different
62
+ weights and a slightly different structure from torchvision's
63
+ Inception model. If you want to compute FID scores, you are
64
+ strongly advised to set this parameter to true to get comparable
65
+ results.
66
+ """
67
+ super(InceptionV3, self).__init__()
68
+
69
+ self.resize_input = resize_input
70
+ self.normalize_input = normalize_input
71
+ self.output_blocks = sorted(output_blocks)
72
+ self.last_needed_block = max(output_blocks)
73
+
74
+ assert self.last_needed_block <= 3, \
75
+ 'Last possible output block index is 3'
76
+
77
+ self.blocks = nn.ModuleList()
78
+
79
+ if use_fid_inception:
80
+ inception = fid_inception_v3()
81
+ else:
82
+ inception = models.inception_v3(pretrained=True)
83
+
84
+ # Block 0: input to maxpool1
85
+ block0 = [
86
+ inception.Conv2d_1a_3x3,
87
+ inception.Conv2d_2a_3x3,
88
+ inception.Conv2d_2b_3x3,
89
+ nn.MaxPool2d(kernel_size=3, stride=2)
90
+ ]
91
+ self.blocks.append(nn.Sequential(*block0))
92
+
93
+ # Block 1: maxpool1 to maxpool2
94
+ if self.last_needed_block >= 1:
95
+ block1 = [
96
+ inception.Conv2d_3b_1x1,
97
+ inception.Conv2d_4a_3x3,
98
+ nn.MaxPool2d(kernel_size=3, stride=2)
99
+ ]
100
+ self.blocks.append(nn.Sequential(*block1))
101
+
102
+ # Block 2: maxpool2 to aux classifier
103
+ if self.last_needed_block >= 2:
104
+ block2 = [
105
+ inception.Mixed_5b,
106
+ inception.Mixed_5c,
107
+ inception.Mixed_5d,
108
+ inception.Mixed_6a,
109
+ inception.Mixed_6b,
110
+ inception.Mixed_6c,
111
+ inception.Mixed_6d,
112
+ inception.Mixed_6e,
113
+ ]
114
+ self.blocks.append(nn.Sequential(*block2))
115
+
116
+ # Block 3: aux classifier to final avgpool
117
+ if self.last_needed_block >= 3:
118
+ block3 = [
119
+ inception.Mixed_7a,
120
+ inception.Mixed_7b,
121
+ inception.Mixed_7c,
122
+ nn.AdaptiveAvgPool2d(output_size=(1, 1))
123
+ ]
124
+ self.blocks.append(nn.Sequential(*block3))
125
+
126
+ for param in self.parameters():
127
+ param.requires_grad = requires_grad
128
+
129
+ def forward(self, inp):
130
+ """Get Inception feature maps
131
+
132
+ Parameters
133
+ ----------
134
+ inp : torch.autograd.Variable
135
+ Input tensor of shape Bx3xHxW. Values are expected to be in
136
+ range (0, 1)
137
+
138
+ Returns
139
+ -------
140
+ List of torch.autograd.Variable, corresponding to the selected output
141
+ block, sorted ascending by index
142
+ """
143
+ outp = []
144
+ x = inp
145
+
146
+ if self.resize_input:
147
+ x = F.interpolate(x,
148
+ size=(299, 299),
149
+ mode='bilinear',
150
+ align_corners=False)
151
+
152
+ if self.normalize_input:
153
+ x = 2 * x - 1 # Scale from range (0, 1) to range (-1, 1)
154
+
155
+ for idx, block in enumerate(self.blocks):
156
+ x = block(x)
157
+ if idx in self.output_blocks:
158
+ outp.append(x)
159
+
160
+ if idx == self.last_needed_block:
161
+ break
162
+
163
+ return outp
164
+
165
+
166
+ def fid_inception_v3():
167
+ """Build pretrained Inception model for FID computation
168
+
169
+ The Inception model for FID computation uses a different set of weights
170
+ and has a slightly different structure than torchvision's Inception.
171
+
172
+ This method first constructs torchvision's Inception and then patches the
173
+ necessary parts that are different in the FID Inception model.
174
+ """
175
+ inception = models.inception_v3(num_classes=1008,
176
+ aux_logits=False,
177
+ pretrained=False)
178
+ inception.Mixed_5b = FIDInceptionA(192, pool_features=32)
179
+ inception.Mixed_5c = FIDInceptionA(256, pool_features=64)
180
+ inception.Mixed_5d = FIDInceptionA(288, pool_features=64)
181
+ inception.Mixed_6b = FIDInceptionC(768, channels_7x7=128)
182
+ inception.Mixed_6c = FIDInceptionC(768, channels_7x7=160)
183
+ inception.Mixed_6d = FIDInceptionC(768, channels_7x7=160)
184
+ inception.Mixed_6e = FIDInceptionC(768, channels_7x7=192)
185
+ inception.Mixed_7b = FIDInceptionE_1(1280)
186
+ inception.Mixed_7c = FIDInceptionE_2(2048)
187
+
188
+ state_dict = load_state_dict_from_url(FID_WEIGHTS_URL, progress=True)
189
+ inception.load_state_dict(state_dict)
190
+ return inception
191
+
192
+
193
+ class FIDInceptionA(models.inception.InceptionA):
194
+ """InceptionA block patched for FID computation"""
195
+ def __init__(self, in_channels, pool_features):
196
+ super(FIDInceptionA, self).__init__(in_channels, pool_features)
197
+
198
+ def forward(self, x):
199
+ branch1x1 = self.branch1x1(x)
200
+
201
+ branch5x5 = self.branch5x5_1(x)
202
+ branch5x5 = self.branch5x5_2(branch5x5)
203
+
204
+ branch3x3dbl = self.branch3x3dbl_1(x)
205
+ branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
206
+ branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl)
207
+
208
+ # Patch: Tensorflow's average pool does not use the padded zeros in
209
+ # its average calculation
210
+ branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1,
211
+ count_include_pad=False)
212
+ branch_pool = self.branch_pool(branch_pool)
213
+
214
+ outputs = [branch1x1, branch5x5, branch3x3dbl, branch_pool]
215
+ return torch.cat(outputs, 1)
216
+
217
+
218
+ class FIDInceptionC(models.inception.InceptionC):
219
+ """InceptionC block patched for FID computation"""
220
+ def __init__(self, in_channels, channels_7x7):
221
+ super(FIDInceptionC, self).__init__(in_channels, channels_7x7)
222
+
223
+ def forward(self, x):
224
+ branch1x1 = self.branch1x1(x)
225
+
226
+ branch7x7 = self.branch7x7_1(x)
227
+ branch7x7 = self.branch7x7_2(branch7x7)
228
+ branch7x7 = self.branch7x7_3(branch7x7)
229
+
230
+ branch7x7dbl = self.branch7x7dbl_1(x)
231
+ branch7x7dbl = self.branch7x7dbl_2(branch7x7dbl)
232
+ branch7x7dbl = self.branch7x7dbl_3(branch7x7dbl)
233
+ branch7x7dbl = self.branch7x7dbl_4(branch7x7dbl)
234
+ branch7x7dbl = self.branch7x7dbl_5(branch7x7dbl)
235
+
236
+ # Patch: Tensorflow's average pool does not use the padded zeros in
237
+ # its average calculation
238
+ branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1,
239
+ count_include_pad=False)
240
+ branch_pool = self.branch_pool(branch_pool)
241
+
242
+ outputs = [branch1x1, branch7x7, branch7x7dbl, branch_pool]
243
+ return torch.cat(outputs, 1)
244
+
245
+
246
+ class FIDInceptionE_1(models.inception.InceptionE):
247
+ """First InceptionE block patched for FID computation"""
248
+ def __init__(self, in_channels):
249
+ super(FIDInceptionE_1, self).__init__(in_channels)
250
+
251
+ def forward(self, x):
252
+ branch1x1 = self.branch1x1(x)
253
+
254
+ branch3x3 = self.branch3x3_1(x)
255
+ branch3x3 = [
256
+ self.branch3x3_2a(branch3x3),
257
+ self.branch3x3_2b(branch3x3),
258
+ ]
259
+ branch3x3 = torch.cat(branch3x3, 1)
260
+
261
+ branch3x3dbl = self.branch3x3dbl_1(x)
262
+ branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
263
+ branch3x3dbl = [
264
+ self.branch3x3dbl_3a(branch3x3dbl),
265
+ self.branch3x3dbl_3b(branch3x3dbl),
266
+ ]
267
+ branch3x3dbl = torch.cat(branch3x3dbl, 1)
268
+
269
+ # Patch: Tensorflow's average pool does not use the padded zeros in
270
+ # its average calculation
271
+ branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1,
272
+ count_include_pad=False)
273
+ branch_pool = self.branch_pool(branch_pool)
274
+
275
+ outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]
276
+ return torch.cat(outputs, 1)
277
+
278
+
279
+ class FIDInceptionE_2(models.inception.InceptionE):
280
+ """Second InceptionE block patched for FID computation"""
281
+ def __init__(self, in_channels):
282
+ super(FIDInceptionE_2, self).__init__(in_channels)
283
+
284
+ def forward(self, x):
285
+ branch1x1 = self.branch1x1(x)
286
+
287
+ branch3x3 = self.branch3x3_1(x)
288
+ branch3x3 = [
289
+ self.branch3x3_2a(branch3x3),
290
+ self.branch3x3_2b(branch3x3),
291
+ ]
292
+ branch3x3 = torch.cat(branch3x3, 1)
293
+
294
+ branch3x3dbl = self.branch3x3dbl_1(x)
295
+ branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
296
+ branch3x3dbl = [
297
+ self.branch3x3dbl_3a(branch3x3dbl),
298
+ self.branch3x3dbl_3b(branch3x3dbl),
299
+ ]
300
+ branch3x3dbl = torch.cat(branch3x3dbl, 1)
301
+
302
+ # Patch: The FID Inception model uses max pooling instead of average
303
+ # pooling. This is likely an error in this specific Inception
304
+ # implementation, as other Inception models use average pooling here
305
+ # (which matches the description in the paper).
306
+ branch_pool = F.max_pool2d(x, kernel_size=3, stride=1, padding=1)
307
+ branch_pool = self.branch_pool(branch_pool)
308
+
309
+ outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]
310
+ return torch.cat(outputs, 1)
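For orientation, a minimal sketch of how the InceptionV3 wrapper above might be driven to produce the pool3 features used for FID statistics; the constructor keyword names are assumed to match the attributes set in __init__, and the FID weights are downloaded on first use.

import torch

extractor = InceptionV3(output_blocks=[3], resize_input=True,
                        normalize_input=True, use_fid_inception=True)
extractor.eval()

images = torch.rand(8, 3, 256, 256)          # B x 3 x H x W, values in (0, 1)
with torch.no_grad():
    pool3 = extractor(images)[0]             # 8 x 2048 x 1 x 1 after the adaptive avg-pool
features = pool3.squeeze(-1).squeeze(-1)     # 8 x 2048 vectors for the FID mean/covariance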
models/stylegan2/stylegan2-pytorch/lpips/__init__.py ADDED
@@ -0,0 +1,160 @@
1
+
2
+ from __future__ import absolute_import
3
+ from __future__ import division
4
+ from __future__ import print_function
5
+
6
+ import numpy as np
7
+ from skimage.measure import compare_ssim
8
+ import torch
9
+ from torch.autograd import Variable
10
+
11
+ from lpips import dist_model
12
+
13
+ class PerceptualLoss(torch.nn.Module):
14
+ def __init__(self, model='net-lin', net='alex', colorspace='rgb', spatial=False, use_gpu=True, gpu_ids=[0]): # VGG using our perceptually-learned weights (LPIPS metric)
15
+ # def __init__(self, model='net', net='vgg', use_gpu=True): # "default" way of using VGG as a perceptual loss
16
+ super(PerceptualLoss, self).__init__()
17
+ print('Setting up Perceptual loss...')
18
+ self.use_gpu = use_gpu
19
+ self.spatial = spatial
20
+ self.gpu_ids = gpu_ids
21
+ self.model = dist_model.DistModel()
22
+ self.model.initialize(model=model, net=net, use_gpu=use_gpu, colorspace=colorspace, spatial=self.spatial, gpu_ids=gpu_ids)
23
+ print('...[%s] initialized'%self.model.name())
24
+ print('...Done')
25
+
26
+ def forward(self, pred, target, normalize=False):
27
+ """
28
+ Pred and target are Variables.
29
+ If normalize is True, assumes the images are between [0,1] and then scales them between [-1,+1]
30
+ If normalize is False, assumes the images are already between [-1,+1]
31
+
32
+ Inputs pred and target are Nx3xHxW
33
+ Output pytorch Variable N long
34
+ """
35
+
36
+ if normalize:
37
+ target = 2 * target - 1
38
+ pred = 2 * pred - 1
39
+
40
+ return self.model.forward(target, pred)
41
+
42
+ def normalize_tensor(in_feat,eps=1e-10):
43
+ norm_factor = torch.sqrt(torch.sum(in_feat**2,dim=1,keepdim=True))
44
+ return in_feat/(norm_factor+eps)
45
+
46
+ def l2(p0, p1, range=255.):
47
+ return .5*np.mean((p0 / range - p1 / range)**2)
48
+
49
+ def psnr(p0, p1, peak=255.):
50
+ return 10*np.log10(peak**2/np.mean((1.*p0-1.*p1)**2))
51
+
52
+ def dssim(p0, p1, range=255.):
53
+ return (1 - compare_ssim(p0, p1, data_range=range, multichannel=True)) / 2.
54
+
55
+ def rgb2lab(in_img,mean_cent=False):
56
+ from skimage import color
57
+ img_lab = color.rgb2lab(in_img)
58
+ if(mean_cent):
59
+ img_lab[:,:,0] = img_lab[:,:,0]-50
60
+ return img_lab
61
+
62
+ def tensor2np(tensor_obj):
63
+ # change dimension of a tensor object into a numpy array
64
+ return tensor_obj[0].cpu().float().numpy().transpose((1,2,0))
65
+
66
+ def np2tensor(np_obj):
67
+ # change dimension of np array into tensor array
68
+ return torch.Tensor(np_obj[:, :, :, np.newaxis].transpose((3, 2, 0, 1)))
69
+
70
+ def tensor2tensorlab(image_tensor,to_norm=True,mc_only=False):
71
+ # image tensor to lab tensor
72
+ from skimage import color
73
+
74
+ img = tensor2im(image_tensor)
75
+ img_lab = color.rgb2lab(img)
76
+ if(mc_only):
77
+ img_lab[:,:,0] = img_lab[:,:,0]-50
78
+ if(to_norm and not mc_only):
79
+ img_lab[:,:,0] = img_lab[:,:,0]-50
80
+ img_lab = img_lab/100.
81
+
82
+ return np2tensor(img_lab)
83
+
84
+ def tensorlab2tensor(lab_tensor,return_inbnd=False):
85
+ from skimage import color
86
+ import warnings
87
+ warnings.filterwarnings("ignore")
88
+
89
+ lab = tensor2np(lab_tensor)*100.
90
+ lab[:,:,0] = lab[:,:,0]+50
91
+
92
+ rgb_back = 255.*np.clip(color.lab2rgb(lab.astype('float')),0,1)
93
+ if(return_inbnd):
94
+ # convert back to lab, see if we match
95
+ lab_back = color.rgb2lab(rgb_back.astype('uint8'))
96
+ mask = 1.*np.isclose(lab_back,lab,atol=2.)
97
+ mask = np2tensor(np.prod(mask,axis=2)[:,:,np.newaxis])
98
+ return (im2tensor(rgb_back),mask)
99
+ else:
100
+ return im2tensor(rgb_back)
101
+
102
+ def rgb2lab(input):
103
+ from skimage import color
104
+ return color.rgb2lab(input / 255.)
105
+
106
+ def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=255./2.):
107
+ image_numpy = image_tensor[0].cpu().float().numpy()
108
+ image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + cent) * factor
109
+ return image_numpy.astype(imtype)
110
+
111
+ def im2tensor(image, imtype=np.uint8, cent=1., factor=255./2.):
112
+ return torch.Tensor((image / factor - cent)
113
+ [:, :, :, np.newaxis].transpose((3, 2, 0, 1)))
114
+
115
+ def tensor2vec(vector_tensor):
116
+ return vector_tensor.data.cpu().numpy()[:, :, 0, 0]
117
+
118
+ def voc_ap(rec, prec, use_07_metric=False):
119
+ """ ap = voc_ap(rec, prec, [use_07_metric])
120
+ Compute VOC AP given precision and recall.
121
+ If use_07_metric is true, uses the
122
+ VOC 07 11 point method (default:False).
123
+ """
124
+ if use_07_metric:
125
+ # 11 point metric
126
+ ap = 0.
127
+ for t in np.arange(0., 1.1, 0.1):
128
+ if np.sum(rec >= t) == 0:
129
+ p = 0
130
+ else:
131
+ p = np.max(prec[rec >= t])
132
+ ap = ap + p / 11.
133
+ else:
134
+ # correct AP calculation
135
+ # first append sentinel values at the end
136
+ mrec = np.concatenate(([0.], rec, [1.]))
137
+ mpre = np.concatenate(([0.], prec, [0.]))
138
+
139
+ # compute the precision envelope
140
+ for i in range(mpre.size - 1, 0, -1):
141
+ mpre[i - 1] = np.maximum(mpre[i - 1], mpre[i])
142
+
143
+ # to calculate area under PR curve, look for points
144
+ # where X axis (recall) changes value
145
+ i = np.where(mrec[1:] != mrec[:-1])[0]
146
+
147
+ # and sum (\Delta recall) * prec
148
+ ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1])
149
+ return ap
150
+
151
+ def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=255./2.):
152
+ # def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=1.):
153
+ image_numpy = image_tensor[0].cpu().float().numpy()
154
+ image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + cent) * factor
155
+ return image_numpy.astype(imtype)
156
+
157
+ def im2tensor(image, imtype=np.uint8, cent=1., factor=255./2.):
158
+ # def im2tensor(image, imtype=np.uint8, cent=1., factor=1.):
159
+ return torch.Tensor((image / factor - cent)
160
+ [:, :, :, np.newaxis].transpose((3, 2, 0, 1)))
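A rough usage sketch for the PerceptualLoss wrapper defined at the top of this file; it assumes the lpips package above is importable under that name and that its bundled weights/ directory is present.

import torch
from lpips import PerceptualLoss

loss_fn = PerceptualLoss(model='net-lin', net='alex', use_gpu=False, gpu_ids=[])

img0 = torch.rand(1, 3, 256, 256)            # inputs in [0, 1]
img1 = torch.rand(1, 3, 256, 256)
dist = loss_fn(img0, img1, normalize=True)   # normalize=True rescales to [-1, 1] internally
print(dist)                                  # one LPIPS distance per image in the batch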
models/stylegan2/stylegan2-pytorch/lpips/base_model.py ADDED
@@ -0,0 +1,58 @@
1
+ import os
2
+ import numpy as np
3
+ import torch
4
+ from torch.autograd import Variable
5
+ from pdb import set_trace as st
6
+ from IPython import embed
7
+
8
+ class BaseModel():
9
+ def __init__(self):
10
+ pass
11
+
12
+ def name(self):
13
+ return 'BaseModel'
14
+
15
+ def initialize(self, use_gpu=True, gpu_ids=[0]):
16
+ self.use_gpu = use_gpu
17
+ self.gpu_ids = gpu_ids
18
+
19
+ def forward(self):
20
+ pass
21
+
22
+ def get_image_paths(self):
23
+ pass
24
+
25
+ def optimize_parameters(self):
26
+ pass
27
+
28
+ def get_current_visuals(self):
29
+ return self.input
30
+
31
+ def get_current_errors(self):
32
+ return {}
33
+
34
+ def save(self, label):
35
+ pass
36
+
37
+ # helper saving function that can be used by subclasses
38
+ def save_network(self, network, path, network_label, epoch_label):
39
+ save_filename = '%s_net_%s.pth' % (epoch_label, network_label)
40
+ save_path = os.path.join(path, save_filename)
41
+ torch.save(network.state_dict(), save_path)
42
+
43
+ # helper loading function that can be used by subclasses
44
+ def load_network(self, network, network_label, epoch_label):
45
+ save_filename = '%s_net_%s.pth' % (epoch_label, network_label)
46
+ save_path = os.path.join(self.save_dir, save_filename)
47
+ print('Loading network from %s'%save_path)
48
+ network.load_state_dict(torch.load(save_path))
49
+
50
+ def update_learning_rate(self):
51
+ pass
52
+
53
+ def get_image_paths(self):
54
+ return self.image_paths
55
+
56
+ def save_done(self, flag=False):
57
+ np.save(os.path.join(self.save_dir, 'done_flag'),flag)
58
+ np.savetxt(os.path.join(self.save_dir, 'done_flag'),[flag,],fmt='%i')
models/stylegan2/stylegan2-pytorch/lpips/dist_model.py ADDED
@@ -0,0 +1,284 @@
1
+
2
+ from __future__ import absolute_import
3
+
4
+ import sys
5
+ import numpy as np
6
+ import torch
7
+ from torch import nn
8
+ import os
9
+ from collections import OrderedDict
10
+ from torch.autograd import Variable
11
+ import itertools
12
+ from .base_model import BaseModel
13
+ from scipy.ndimage import zoom
14
+ import fractions
15
+ import functools
16
+ import skimage.transform
17
+ from tqdm import tqdm
18
+
19
+ from IPython import embed
20
+
21
+ from . import networks_basic as networks
22
+ import lpips as util
23
+
24
+ class DistModel(BaseModel):
25
+ def name(self):
26
+ return self.model_name
27
+
28
+ def initialize(self, model='net-lin', net='alex', colorspace='Lab', pnet_rand=False, pnet_tune=False, model_path=None,
29
+ use_gpu=True, printNet=False, spatial=False,
30
+ is_train=False, lr=.0001, beta1=0.5, version='0.1', gpu_ids=[0]):
31
+ '''
32
+ INPUTS
33
+ model - ['net-lin'] for linearly calibrated network
34
+ ['net'] for off-the-shelf network
35
+ ['L2'] for L2 distance in Lab colorspace
36
+ ['SSIM'] for ssim in RGB colorspace
37
+ net - ['squeeze','alex','vgg']
38
+ model_path - if None, will look in weights/[NET_NAME].pth
39
+ colorspace - ['Lab','RGB'] colorspace to use for L2 and SSIM
40
+ use_gpu - bool - whether or not to use a GPU
41
+ printNet - bool - whether or not to print network architecture out
42
+ spatial - bool - whether to output an array containing varying distances across spatial dimensions
43
+ spatial_shape - if given, output spatial shape. if None then spatial shape is determined automatically via spatial_factor (see below).
44
+ spatial_factor - if given, specifies upsampling factor relative to the largest spatial extent of a convolutional layer. if None then resized to size of input images.
45
+ spatial_order - spline order of filter for upsampling in spatial mode, by default 1 (bilinear).
46
+ is_train - bool - [True] for training mode
47
+ lr - float - initial learning rate
48
+ beta1 - float - initial momentum term for adam
49
+ version - 0.1 for latest, 0.0 was original (with a bug)
50
+ gpu_ids - int array - [0] by default, gpus to use
51
+ '''
52
+ BaseModel.initialize(self, use_gpu=use_gpu, gpu_ids=gpu_ids)
53
+
54
+ self.model = model
55
+ self.net = net
56
+ self.is_train = is_train
57
+ self.spatial = spatial
58
+ self.gpu_ids = gpu_ids
59
+ self.model_name = '%s [%s]'%(model,net)
60
+
61
+ if(self.model == 'net-lin'): # pretrained net + linear layer
62
+ self.net = networks.PNetLin(pnet_rand=pnet_rand, pnet_tune=pnet_tune, pnet_type=net,
63
+ use_dropout=True, spatial=spatial, version=version, lpips=True)
64
+ kw = {}
65
+ if not use_gpu:
66
+ kw['map_location'] = 'cpu'
67
+ if(model_path is None):
68
+ import inspect
69
+ model_path = os.path.abspath(os.path.join(inspect.getfile(self.initialize), '..', 'weights/v%s/%s.pth'%(version,net)))
70
+
71
+ if(not is_train):
72
+ print('Loading model from: %s'%model_path)
73
+ self.net.load_state_dict(torch.load(model_path, **kw), strict=False)
74
+
75
+ elif(self.model=='net'): # pretrained network
76
+ self.net = networks.PNetLin(pnet_rand=pnet_rand, pnet_type=net, lpips=False)
77
+ elif(self.model in ['L2','l2']):
78
+ self.net = networks.L2(use_gpu=use_gpu,colorspace=colorspace) # not really a network, only for testing
79
+ self.model_name = 'L2'
80
+ elif(self.model in ['DSSIM','dssim','SSIM','ssim']):
81
+ self.net = networks.DSSIM(use_gpu=use_gpu,colorspace=colorspace)
82
+ self.model_name = 'SSIM'
83
+ else:
84
+ raise ValueError("Model [%s] not recognized." % self.model)
85
+
86
+ self.parameters = list(self.net.parameters())
87
+
88
+ if self.is_train: # training mode
89
+ # extra network on top to go from distances (d0,d1) => predicted human judgment (h*)
90
+ self.rankLoss = networks.BCERankingLoss()
91
+ self.parameters += list(self.rankLoss.net.parameters())
92
+ self.lr = lr
93
+ self.old_lr = lr
94
+ self.optimizer_net = torch.optim.Adam(self.parameters, lr=lr, betas=(beta1, 0.999))
95
+ else: # test mode
96
+ self.net.eval()
97
+
98
+ if(use_gpu):
99
+ self.net.to(gpu_ids[0])
100
+ self.net = torch.nn.DataParallel(self.net, device_ids=gpu_ids)
101
+ if(self.is_train):
102
+ self.rankLoss = self.rankLoss.to(device=gpu_ids[0]) # just put this on GPU0
103
+
104
+ if(printNet):
105
+ print('---------- Networks initialized -------------')
106
+ networks.print_network(self.net)
107
+ print('-----------------------------------------------')
108
+
109
+ def forward(self, in0, in1, retPerLayer=False):
110
+ ''' Function computes the distance between image patches in0 and in1
111
+ INPUTS
112
+ in0, in1 - torch.Tensor object of shape Nx3xXxY - image patch scaled to [-1,1]
113
+ OUTPUT
114
+ computed distances between in0 and in1
115
+ '''
116
+
117
+ return self.net.forward(in0, in1, retPerLayer=retPerLayer)
118
+
119
+ # ***** TRAINING FUNCTIONS *****
120
+ def optimize_parameters(self):
121
+ self.forward_train()
122
+ self.optimizer_net.zero_grad()
123
+ self.backward_train()
124
+ self.optimizer_net.step()
125
+ self.clamp_weights()
126
+
127
+ def clamp_weights(self):
128
+ for module in self.net.modules():
129
+ if(hasattr(module, 'weight') and module.kernel_size==(1,1)):
130
+ module.weight.data = torch.clamp(module.weight.data,min=0)
131
+
132
+ def set_input(self, data):
133
+ self.input_ref = data['ref']
134
+ self.input_p0 = data['p0']
135
+ self.input_p1 = data['p1']
136
+ self.input_judge = data['judge']
137
+
138
+ if(self.use_gpu):
139
+ self.input_ref = self.input_ref.to(device=self.gpu_ids[0])
140
+ self.input_p0 = self.input_p0.to(device=self.gpu_ids[0])
141
+ self.input_p1 = self.input_p1.to(device=self.gpu_ids[0])
142
+ self.input_judge = self.input_judge.to(device=self.gpu_ids[0])
143
+
144
+ self.var_ref = Variable(self.input_ref,requires_grad=True)
145
+ self.var_p0 = Variable(self.input_p0,requires_grad=True)
146
+ self.var_p1 = Variable(self.input_p1,requires_grad=True)
147
+
148
+ def forward_train(self): # run forward pass
149
+ # print(self.net.module.scaling_layer.shift)
150
+ # print(torch.norm(self.net.module.net.slice1[0].weight).item(), torch.norm(self.net.module.lin0.model[1].weight).item())
151
+
152
+ self.d0 = self.forward(self.var_ref, self.var_p0)
153
+ self.d1 = self.forward(self.var_ref, self.var_p1)
154
+ self.acc_r = self.compute_accuracy(self.d0,self.d1,self.input_judge)
155
+
156
+ self.var_judge = Variable(1.*self.input_judge).view(self.d0.size())
157
+
158
+ self.loss_total = self.rankLoss.forward(self.d0, self.d1, self.var_judge*2.-1.)
159
+
160
+ return self.loss_total
161
+
162
+ def backward_train(self):
163
+ torch.mean(self.loss_total).backward()
164
+
165
+ def compute_accuracy(self,d0,d1,judge):
166
+ ''' d0, d1 are Variables, judge is a Tensor '''
167
+ d1_lt_d0 = (d1<d0).cpu().data.numpy().flatten()
168
+ judge_per = judge.cpu().numpy().flatten()
169
+ return d1_lt_d0*judge_per + (1-d1_lt_d0)*(1-judge_per)
170
+
171
+ def get_current_errors(self):
172
+ retDict = OrderedDict([('loss_total', self.loss_total.data.cpu().numpy()),
173
+ ('acc_r', self.acc_r)])
174
+
175
+ for key in retDict.keys():
176
+ retDict[key] = np.mean(retDict[key])
177
+
178
+ return retDict
179
+
180
+ def get_current_visuals(self):
181
+ zoom_factor = 256/self.var_ref.data.size()[2]
182
+
183
+ ref_img = util.tensor2im(self.var_ref.data)
184
+ p0_img = util.tensor2im(self.var_p0.data)
185
+ p1_img = util.tensor2im(self.var_p1.data)
186
+
187
+ ref_img_vis = zoom(ref_img,[zoom_factor, zoom_factor, 1],order=0)
188
+ p0_img_vis = zoom(p0_img,[zoom_factor, zoom_factor, 1],order=0)
189
+ p1_img_vis = zoom(p1_img,[zoom_factor, zoom_factor, 1],order=0)
190
+
191
+ return OrderedDict([('ref', ref_img_vis),
192
+ ('p0', p0_img_vis),
193
+ ('p1', p1_img_vis)])
194
+
195
+ def save(self, path, label):
196
+ if(self.use_gpu):
197
+ self.save_network(self.net.module, path, '', label)
198
+ else:
199
+ self.save_network(self.net, path, '', label)
200
+ self.save_network(self.rankLoss.net, path, 'rank', label)
201
+
202
+ def update_learning_rate(self,nepoch_decay):
203
+ lrd = self.lr / nepoch_decay
204
+ lr = self.old_lr - lrd
205
+
206
+ for param_group in self.optimizer_net.param_groups:
207
+ param_group['lr'] = lr
208
+
209
+ print('update lr [%s] decay: %f -> %f' % (self.model_name, self.old_lr, lr))
210
+ self.old_lr = lr
211
+
212
+ def score_2afc_dataset(data_loader, func, name=''):
213
+ ''' Function computes Two Alternative Forced Choice (2AFC) score using
214
+ distance function 'func' in dataset 'data_loader'
215
+ INPUTS
216
+ data_loader - CustomDatasetDataLoader object - contains a TwoAFCDataset inside
217
+ func - callable distance function - calling d=func(in0,in1) should take 2
218
+ pytorch tensors with shape Nx3xXxY, and return numpy array of length N
219
+ OUTPUTS
220
+ [0] - 2AFC score in [0,1], fraction of time func agrees with human evaluators
221
+ [1] - dictionary with following elements
222
+ d0s,d1s - N arrays containing distances between reference patch to perturbed patches
223
+ gts - N array in [0,1], preferred patch selected by human evaluators
224
+ (closer to "0" for left patch p0, "1" for right patch p1,
225
+ "0.6" means 60pct people preferred right patch, 40pct preferred left)
226
+ scores - N array in [0,1], corresponding to what percentage function agreed with humans
227
+ CONSTS
228
+ N - number of test triplets in data_loader
229
+ '''
230
+
231
+ d0s = []
232
+ d1s = []
233
+ gts = []
234
+
235
+ for data in tqdm(data_loader.load_data(), desc=name):
236
+ d0s+=func(data['ref'],data['p0']).data.cpu().numpy().flatten().tolist()
237
+ d1s+=func(data['ref'],data['p1']).data.cpu().numpy().flatten().tolist()
238
+ gts+=data['judge'].cpu().numpy().flatten().tolist()
239
+
240
+ d0s = np.array(d0s)
241
+ d1s = np.array(d1s)
242
+ gts = np.array(gts)
243
+ scores = (d0s<d1s)*(1.-gts) + (d1s<d0s)*gts + (d1s==d0s)*.5
244
+
245
+ return(np.mean(scores), dict(d0s=d0s,d1s=d1s,gts=gts,scores=scores))
246
+
247
+ def score_jnd_dataset(data_loader, func, name=''):
248
+ ''' Function computes JND score using distance function 'func' in dataset 'data_loader'
249
+ INPUTS
250
+ data_loader - CustomDatasetDataLoader object - contains a JNDDataset inside
251
+ func - callable distance function - calling d=func(in0,in1) should take 2
252
+ pytorch tensors with shape Nx3xXxY, and return pytorch array of length N
253
+ OUTPUTS
254
+ [0] - JND score in [0,1], mAP score (area under precision-recall curve)
255
+ [1] - dictionary with following elements
256
+ ds - N array containing distances between two patches shown to human evaluator
257
+ sames - N array containing fraction of people who thought the two patches were identical
258
+ CONSTS
259
+ N - number of test triplets in data_loader
260
+ '''
261
+
262
+ ds = []
263
+ gts = []
264
+
265
+ for data in tqdm(data_loader.load_data(), desc=name):
266
+ ds+=func(data['p0'],data['p1']).data.cpu().numpy().tolist()
267
+ gts+=data['same'].cpu().numpy().flatten().tolist()
268
+
269
+ sames = np.array(gts)
270
+ ds = np.array(ds)
271
+
272
+ sorted_inds = np.argsort(ds)
273
+ ds_sorted = ds[sorted_inds]
274
+ sames_sorted = sames[sorted_inds]
275
+
276
+ TPs = np.cumsum(sames_sorted)
277
+ FPs = np.cumsum(1-sames_sorted)
278
+ FNs = np.sum(sames_sorted)-TPs
279
+
280
+ precs = TPs/(TPs+FPs)
281
+ recs = TPs/(TPs+FNs)
282
+ score = util.voc_ap(recs,precs)
283
+
284
+ return(score, dict(ds=ds,sames=sames))
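For reference, a hypothetical direct use of DistModel (PerceptualLoss in lpips/__init__.py is the usual entry point); inputs are expected to be scaled to [-1, 1].

import torch

img0 = 2 * torch.rand(1, 3, 64, 64) - 1      # N x 3 x H x W, scaled to [-1, 1]
img1 = 2 * torch.rand(1, 3, 64, 64) - 1

dm = DistModel()
dm.initialize(model='net-lin', net='vgg', use_gpu=False)
print(dm.name())                              # "net-lin [vgg]"
d = dm.forward(img0, img1)                    # per-image perceptual distances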
models/stylegan2/stylegan2-pytorch/lpips/networks_basic.py ADDED
@@ -0,0 +1,187 @@
1
+
2
+ from __future__ import absolute_import
3
+
4
+ import sys
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.init as init
8
+ from torch.autograd import Variable
9
+ import numpy as np
10
+ from pdb import set_trace as st
11
+ from skimage import color
12
+ from IPython import embed
13
+ from . import pretrained_networks as pn
14
+
15
+ import lpips as util
16
+
17
+ def spatial_average(in_tens, keepdim=True):
18
+ return in_tens.mean([2,3],keepdim=keepdim)
19
+
20
+ def upsample(in_tens, out_H=64): # assumes scale factor is same for H and W
21
+ in_H = in_tens.shape[2]
22
+ scale_factor = 1.*out_H/in_H
23
+
24
+ return nn.Upsample(scale_factor=scale_factor, mode='bilinear', align_corners=False)(in_tens)
25
+
26
+ # Learned perceptual metric
27
+ class PNetLin(nn.Module):
28
+ def __init__(self, pnet_type='vgg', pnet_rand=False, pnet_tune=False, use_dropout=True, spatial=False, version='0.1', lpips=True):
29
+ super(PNetLin, self).__init__()
30
+
31
+ self.pnet_type = pnet_type
32
+ self.pnet_tune = pnet_tune
33
+ self.pnet_rand = pnet_rand
34
+ self.spatial = spatial
35
+ self.lpips = lpips
36
+ self.version = version
37
+ self.scaling_layer = ScalingLayer()
38
+
39
+ if(self.pnet_type in ['vgg','vgg16']):
40
+ net_type = pn.vgg16
41
+ self.chns = [64,128,256,512,512]
42
+ elif(self.pnet_type=='alex'):
43
+ net_type = pn.alexnet
44
+ self.chns = [64,192,384,256,256]
45
+ elif(self.pnet_type=='squeeze'):
46
+ net_type = pn.squeezenet
47
+ self.chns = [64,128,256,384,384,512,512]
48
+ self.L = len(self.chns)
49
+
50
+ self.net = net_type(pretrained=not self.pnet_rand, requires_grad=self.pnet_tune)
51
+
52
+ if(lpips):
53
+ self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout)
54
+ self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout)
55
+ self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout)
56
+ self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout)
57
+ self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout)
58
+ self.lins = [self.lin0,self.lin1,self.lin2,self.lin3,self.lin4]
59
+ if(self.pnet_type=='squeeze'): # 7 layers for squeezenet
60
+ self.lin5 = NetLinLayer(self.chns[5], use_dropout=use_dropout)
61
+ self.lin6 = NetLinLayer(self.chns[6], use_dropout=use_dropout)
62
+ self.lins+=[self.lin5,self.lin6]
63
+
64
+ def forward(self, in0, in1, retPerLayer=False):
65
+ # v0.0 - original release had a bug, where input was not scaled
66
+ in0_input, in1_input = (self.scaling_layer(in0), self.scaling_layer(in1)) if self.version=='0.1' else (in0, in1)
67
+ outs0, outs1 = self.net.forward(in0_input), self.net.forward(in1_input)
68
+ feats0, feats1, diffs = {}, {}, {}
69
+
70
+ for kk in range(self.L):
71
+ feats0[kk], feats1[kk] = util.normalize_tensor(outs0[kk]), util.normalize_tensor(outs1[kk])
72
+ diffs[kk] = (feats0[kk]-feats1[kk])**2
73
+
74
+ if(self.lpips):
75
+ if(self.spatial):
76
+ res = [upsample(self.lins[kk].model(diffs[kk]), out_H=in0.shape[2]) for kk in range(self.L)]
77
+ else:
78
+ res = [spatial_average(self.lins[kk].model(diffs[kk]), keepdim=True) for kk in range(self.L)]
79
+ else:
80
+ if(self.spatial):
81
+ res = [upsample(diffs[kk].sum(dim=1,keepdim=True), out_H=in0.shape[2]) for kk in range(self.L)]
82
+ else:
83
+ res = [spatial_average(diffs[kk].sum(dim=1,keepdim=True), keepdim=True) for kk in range(self.L)]
84
+
85
+ val = res[0]
86
+ for l in range(1,self.L):
87
+ val += res[l]
88
+
89
+ if(retPerLayer):
90
+ return (val, res)
91
+ else:
92
+ return val
93
+
94
+ class ScalingLayer(nn.Module):
95
+ def __init__(self):
96
+ super(ScalingLayer, self).__init__()
97
+ self.register_buffer('shift', torch.Tensor([-.030,-.088,-.188])[None,:,None,None])
98
+ self.register_buffer('scale', torch.Tensor([.458,.448,.450])[None,:,None,None])
99
+
100
+ def forward(self, inp):
101
+ return (inp - self.shift) / self.scale
102
+
103
+
104
+ class NetLinLayer(nn.Module):
105
+ ''' A single linear layer which does a 1x1 conv '''
106
+ def __init__(self, chn_in, chn_out=1, use_dropout=False):
107
+ super(NetLinLayer, self).__init__()
108
+
109
+ layers = [nn.Dropout(),] if(use_dropout) else []
110
+ layers += [nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False),]
111
+ self.model = nn.Sequential(*layers)
112
+
113
+
114
+ class Dist2LogitLayer(nn.Module):
115
+ ''' takes 2 distances, puts through fc layers, spits out value between [0,1] (if use_sigmoid is True) '''
116
+ def __init__(self, chn_mid=32, use_sigmoid=True):
117
+ super(Dist2LogitLayer, self).__init__()
118
+
119
+ layers = [nn.Conv2d(5, chn_mid, 1, stride=1, padding=0, bias=True),]
120
+ layers += [nn.LeakyReLU(0.2,True),]
121
+ layers += [nn.Conv2d(chn_mid, chn_mid, 1, stride=1, padding=0, bias=True),]
122
+ layers += [nn.LeakyReLU(0.2,True),]
123
+ layers += [nn.Conv2d(chn_mid, 1, 1, stride=1, padding=0, bias=True),]
124
+ if(use_sigmoid):
125
+ layers += [nn.Sigmoid(),]
126
+ self.model = nn.Sequential(*layers)
127
+
128
+ def forward(self,d0,d1,eps=0.1):
129
+ return self.model.forward(torch.cat((d0,d1,d0-d1,d0/(d1+eps),d1/(d0+eps)),dim=1))
130
+
131
+ class BCERankingLoss(nn.Module):
132
+ def __init__(self, chn_mid=32):
133
+ super(BCERankingLoss, self).__init__()
134
+ self.net = Dist2LogitLayer(chn_mid=chn_mid)
135
+ # self.parameters = list(self.net.parameters())
136
+ self.loss = torch.nn.BCELoss()
137
+
138
+ def forward(self, d0, d1, judge):
139
+ per = (judge+1.)/2.
140
+ self.logit = self.net.forward(d0,d1)
141
+ return self.loss(self.logit, per)
142
+
143
+ # L2, DSSIM metrics
144
+ class FakeNet(nn.Module):
145
+ def __init__(self, use_gpu=True, colorspace='Lab'):
146
+ super(FakeNet, self).__init__()
147
+ self.use_gpu = use_gpu
148
+ self.colorspace=colorspace
149
+
150
+ class L2(FakeNet):
151
+
152
+ def forward(self, in0, in1, retPerLayer=None):
153
+ assert(in0.size()[0]==1) # currently only supports batchSize 1
154
+
155
+ if(self.colorspace=='RGB'):
156
+ (N,C,X,Y) = in0.size()
157
+ value = torch.mean(torch.mean(torch.mean((in0-in1)**2,dim=1).view(N,1,X,Y),dim=2).view(N,1,1,Y),dim=3).view(N)
158
+ return value
159
+ elif(self.colorspace=='Lab'):
160
+ value = util.l2(util.tensor2np(util.tensor2tensorlab(in0.data,to_norm=False)),
161
+ util.tensor2np(util.tensor2tensorlab(in1.data,to_norm=False)), range=100.).astype('float')
162
+ ret_var = Variable( torch.Tensor((value,) ) )
163
+ if(self.use_gpu):
164
+ ret_var = ret_var.cuda()
165
+ return ret_var
166
+
167
+ class DSSIM(FakeNet):
168
+
169
+ def forward(self, in0, in1, retPerLayer=None):
170
+ assert(in0.size()[0]==1) # currently only supports batchSize 1
171
+
172
+ if(self.colorspace=='RGB'):
173
+ value = util.dssim(1.*util.tensor2im(in0.data), 1.*util.tensor2im(in1.data), range=255.).astype('float')
174
+ elif(self.colorspace=='Lab'):
175
+ value = util.dssim(util.tensor2np(util.tensor2tensorlab(in0.data,to_norm=False)),
176
+ util.tensor2np(util.tensor2tensorlab(in1.data,to_norm=False)), range=100.).astype('float')
177
+ ret_var = Variable( torch.Tensor((value,) ) )
178
+ if(self.use_gpu):
179
+ ret_var = ret_var.cuda()
180
+ return ret_var
181
+
182
+ def print_network(net):
183
+ num_params = 0
184
+ for param in net.parameters():
185
+ num_params += param.numel()
186
+ print('Network',net)
187
+ print('Total number of parameters: %d' % num_params)
models/stylegan2/stylegan2-pytorch/lpips/pretrained_networks.py ADDED
@@ -0,0 +1,181 @@
1
+ from collections import namedtuple
2
+ import torch
3
+ from torchvision import models as tv
4
+ from IPython import embed
5
+
6
+ class squeezenet(torch.nn.Module):
7
+ def __init__(self, requires_grad=False, pretrained=True):
8
+ super(squeezenet, self).__init__()
9
+ pretrained_features = tv.squeezenet1_1(pretrained=pretrained).features
10
+ self.slice1 = torch.nn.Sequential()
11
+ self.slice2 = torch.nn.Sequential()
12
+ self.slice3 = torch.nn.Sequential()
13
+ self.slice4 = torch.nn.Sequential()
14
+ self.slice5 = torch.nn.Sequential()
15
+ self.slice6 = torch.nn.Sequential()
16
+ self.slice7 = torch.nn.Sequential()
17
+ self.N_slices = 7
18
+ for x in range(2):
19
+ self.slice1.add_module(str(x), pretrained_features[x])
20
+ for x in range(2,5):
21
+ self.slice2.add_module(str(x), pretrained_features[x])
22
+ for x in range(5, 8):
23
+ self.slice3.add_module(str(x), pretrained_features[x])
24
+ for x in range(8, 10):
25
+ self.slice4.add_module(str(x), pretrained_features[x])
26
+ for x in range(10, 11):
27
+ self.slice5.add_module(str(x), pretrained_features[x])
28
+ for x in range(11, 12):
29
+ self.slice6.add_module(str(x), pretrained_features[x])
30
+ for x in range(12, 13):
31
+ self.slice7.add_module(str(x), pretrained_features[x])
32
+ if not requires_grad:
33
+ for param in self.parameters():
34
+ param.requires_grad = False
35
+
36
+ def forward(self, X):
37
+ h = self.slice1(X)
38
+ h_relu1 = h
39
+ h = self.slice2(h)
40
+ h_relu2 = h
41
+ h = self.slice3(h)
42
+ h_relu3 = h
43
+ h = self.slice4(h)
44
+ h_relu4 = h
45
+ h = self.slice5(h)
46
+ h_relu5 = h
47
+ h = self.slice6(h)
48
+ h_relu6 = h
49
+ h = self.slice7(h)
50
+ h_relu7 = h
51
+ vgg_outputs = namedtuple("SqueezeOutputs", ['relu1','relu2','relu3','relu4','relu5','relu6','relu7'])
52
+ out = vgg_outputs(h_relu1,h_relu2,h_relu3,h_relu4,h_relu5,h_relu6,h_relu7)
53
+
54
+ return out
55
+
56
+
57
+ class alexnet(torch.nn.Module):
58
+ def __init__(self, requires_grad=False, pretrained=True):
59
+ super(alexnet, self).__init__()
60
+ alexnet_pretrained_features = tv.alexnet(pretrained=pretrained).features
61
+ self.slice1 = torch.nn.Sequential()
62
+ self.slice2 = torch.nn.Sequential()
63
+ self.slice3 = torch.nn.Sequential()
64
+ self.slice4 = torch.nn.Sequential()
65
+ self.slice5 = torch.nn.Sequential()
66
+ self.N_slices = 5
67
+ for x in range(2):
68
+ self.slice1.add_module(str(x), alexnet_pretrained_features[x])
69
+ for x in range(2, 5):
70
+ self.slice2.add_module(str(x), alexnet_pretrained_features[x])
71
+ for x in range(5, 8):
72
+ self.slice3.add_module(str(x), alexnet_pretrained_features[x])
73
+ for x in range(8, 10):
74
+ self.slice4.add_module(str(x), alexnet_pretrained_features[x])
75
+ for x in range(10, 12):
76
+ self.slice5.add_module(str(x), alexnet_pretrained_features[x])
77
+ if not requires_grad:
78
+ for param in self.parameters():
79
+ param.requires_grad = False
80
+
81
+ def forward(self, X):
82
+ h = self.slice1(X)
83
+ h_relu1 = h
84
+ h = self.slice2(h)
85
+ h_relu2 = h
86
+ h = self.slice3(h)
87
+ h_relu3 = h
88
+ h = self.slice4(h)
89
+ h_relu4 = h
90
+ h = self.slice5(h)
91
+ h_relu5 = h
92
+ alexnet_outputs = namedtuple("AlexnetOutputs", ['relu1', 'relu2', 'relu3', 'relu4', 'relu5'])
93
+ out = alexnet_outputs(h_relu1, h_relu2, h_relu3, h_relu4, h_relu5)
94
+
95
+ return out
96
+
97
+ class vgg16(torch.nn.Module):
98
+ def __init__(self, requires_grad=False, pretrained=True):
99
+ super(vgg16, self).__init__()
100
+ vgg_pretrained_features = tv.vgg16(pretrained=pretrained).features
101
+ self.slice1 = torch.nn.Sequential()
102
+ self.slice2 = torch.nn.Sequential()
103
+ self.slice3 = torch.nn.Sequential()
104
+ self.slice4 = torch.nn.Sequential()
105
+ self.slice5 = torch.nn.Sequential()
106
+ self.N_slices = 5
107
+ for x in range(4):
108
+ self.slice1.add_module(str(x), vgg_pretrained_features[x])
109
+ for x in range(4, 9):
110
+ self.slice2.add_module(str(x), vgg_pretrained_features[x])
111
+ for x in range(9, 16):
112
+ self.slice3.add_module(str(x), vgg_pretrained_features[x])
113
+ for x in range(16, 23):
114
+ self.slice4.add_module(str(x), vgg_pretrained_features[x])
115
+ for x in range(23, 30):
116
+ self.slice5.add_module(str(x), vgg_pretrained_features[x])
117
+ if not requires_grad:
118
+ for param in self.parameters():
119
+ param.requires_grad = False
120
+
121
+ def forward(self, X):
122
+ h = self.slice1(X)
123
+ h_relu1_2 = h
124
+ h = self.slice2(h)
125
+ h_relu2_2 = h
126
+ h = self.slice3(h)
127
+ h_relu3_3 = h
128
+ h = self.slice4(h)
129
+ h_relu4_3 = h
130
+ h = self.slice5(h)
131
+ h_relu5_3 = h
132
+ vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3'])
133
+ out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3)
134
+
135
+ return out
136
+
137
+
138
+
139
+ class resnet(torch.nn.Module):
140
+ def __init__(self, requires_grad=False, pretrained=True, num=18):
141
+ super(resnet, self).__init__()
142
+ if(num==18):
143
+ self.net = tv.resnet18(pretrained=pretrained)
144
+ elif(num==34):
145
+ self.net = tv.resnet34(pretrained=pretrained)
146
+ elif(num==50):
147
+ self.net = tv.resnet50(pretrained=pretrained)
148
+ elif(num==101):
149
+ self.net = tv.resnet101(pretrained=pretrained)
150
+ elif(num==152):
151
+ self.net = tv.resnet152(pretrained=pretrained)
152
+ self.N_slices = 5
153
+
154
+ self.conv1 = self.net.conv1
155
+ self.bn1 = self.net.bn1
156
+ self.relu = self.net.relu
157
+ self.maxpool = self.net.maxpool
158
+ self.layer1 = self.net.layer1
159
+ self.layer2 = self.net.layer2
160
+ self.layer3 = self.net.layer3
161
+ self.layer4 = self.net.layer4
162
+
163
+ def forward(self, X):
164
+ h = self.conv1(X)
165
+ h = self.bn1(h)
166
+ h = self.relu(h)
167
+ h_relu1 = h
168
+ h = self.maxpool(h)
169
+ h = self.layer1(h)
170
+ h_conv2 = h
171
+ h = self.layer2(h)
172
+ h_conv3 = h
173
+ h = self.layer3(h)
174
+ h_conv4 = h
175
+ h = self.layer4(h)
176
+ h_conv5 = h
177
+
178
+ outputs = namedtuple("Outputs", ['relu1','conv2','conv3','conv4','conv5'])
179
+ out = outputs(h_relu1, h_conv2, h_conv3, h_conv4, h_conv5)
180
+
181
+ return out
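As a quick sanity check, the feature slices returned by the vgg16 wrapper above can be inspected as follows (a sketch; downloads torchvision's pretrained VGG-16 weights on first use).

import torch

net = vgg16(requires_grad=False, pretrained=True)
x = torch.rand(1, 3, 64, 64)
feats = net(x)                                # namedtuple: relu1_2 ... relu5_3
print([f.shape for f in feats])               # channel counts 64, 128, 256, 512, 512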
models/stylegan2/stylegan2-pytorch/lpips/weights/v0.0/alex.pth ADDED
Binary file (5.46 kB). View file
 
models/stylegan2/stylegan2-pytorch/lpips/weights/v0.0/squeeze.pth ADDED
Binary file (10.1 kB). View file
 
models/stylegan2/stylegan2-pytorch/lpips/weights/v0.0/vgg.pth ADDED
Binary file (6.74 kB). View file
 
models/stylegan2/stylegan2-pytorch/lpips/weights/v0.1/alex.pth ADDED
Binary file (6.01 kB). View file
 
models/stylegan2/stylegan2-pytorch/lpips/weights/v0.1/squeeze.pth ADDED
Binary file (10.8 kB). View file
 
models/stylegan2/stylegan2-pytorch/lpips/weights/v0.1/vgg.pth ADDED
Binary file (7.29 kB). View file
 
models/stylegan2/stylegan2-pytorch/model.py ADDED
@@ -0,0 +1,703 @@
1
+ import math
2
+ import random
3
+ import functools
4
+ import operator
5
+
6
+ import torch
7
+ from torch import nn
8
+ from torch.nn import functional as F
9
+ from torch.autograd import Function
10
+
11
+ from op import FusedLeakyReLU, fused_leaky_relu, upfirdn2d
12
+
13
+
14
+ class PixelNorm(nn.Module):
15
+ def __init__(self):
16
+ super().__init__()
17
+
18
+ def forward(self, input):
19
+ return input * torch.rsqrt(torch.mean(input ** 2, dim=1, keepdim=True) + 1e-8)
20
+
21
+
22
+ def make_kernel(k):
23
+ k = torch.tensor(k, dtype=torch.float32)
24
+
25
+ if k.ndim == 1:
26
+ k = k[None, :] * k[:, None]
27
+
28
+ k /= k.sum()
29
+
30
+ return k
31
+
32
+
33
+ class Upsample(nn.Module):
34
+ def __init__(self, kernel, factor=2):
35
+ super().__init__()
36
+
37
+ self.factor = factor
38
+ kernel = make_kernel(kernel) * (factor ** 2)
39
+ self.register_buffer('kernel', kernel)
40
+
41
+ p = kernel.shape[0] - factor
42
+
43
+ pad0 = (p + 1) // 2 + factor - 1
44
+ pad1 = p // 2
45
+
46
+ self.pad = (pad0, pad1)
47
+
48
+ def forward(self, input):
49
+ out = upfirdn2d(input, self.kernel, up=self.factor, down=1, pad=self.pad)
50
+
51
+ return out
52
+
53
+
54
+ class Downsample(nn.Module):
55
+ def __init__(self, kernel, factor=2):
56
+ super().__init__()
57
+
58
+ self.factor = factor
59
+ kernel = make_kernel(kernel)
60
+ self.register_buffer('kernel', kernel)
61
+
62
+ p = kernel.shape[0] - factor
63
+
64
+ pad0 = (p + 1) // 2
65
+ pad1 = p // 2
66
+
67
+ self.pad = (pad0, pad1)
68
+
69
+ def forward(self, input):
70
+ out = upfirdn2d(input, self.kernel, up=1, down=self.factor, pad=self.pad)
71
+
72
+ return out
73
+
74
+
75
+ class Blur(nn.Module):
76
+ def __init__(self, kernel, pad, upsample_factor=1):
77
+ super().__init__()
78
+
79
+ kernel = make_kernel(kernel)
80
+
81
+ if upsample_factor > 1:
82
+ kernel = kernel * (upsample_factor ** 2)
83
+
84
+ self.register_buffer('kernel', kernel)
85
+
86
+ self.pad = pad
87
+
88
+ def forward(self, input):
89
+ out = upfirdn2d(input, self.kernel, pad=self.pad)
90
+
91
+ return out
92
+
93
+
94
+ class EqualConv2d(nn.Module):
95
+ def __init__(
96
+ self, in_channel, out_channel, kernel_size, stride=1, padding=0, bias=True
97
+ ):
98
+ super().__init__()
99
+
100
+ self.weight = nn.Parameter(
101
+ torch.randn(out_channel, in_channel, kernel_size, kernel_size)
102
+ )
103
+ self.scale = 1 / math.sqrt(in_channel * kernel_size ** 2)
104
+
105
+ self.stride = stride
106
+ self.padding = padding
107
+
108
+ if bias:
109
+ self.bias = nn.Parameter(torch.zeros(out_channel))
110
+
111
+ else:
112
+ self.bias = None
113
+
114
+ def forward(self, input):
115
+ out = F.conv2d(
116
+ input,
117
+ self.weight * self.scale,
118
+ bias=self.bias,
119
+ stride=self.stride,
120
+ padding=self.padding,
121
+ )
122
+
123
+ return out
124
+
125
+ def __repr__(self):
126
+ return (
127
+ f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]},'
128
+ f' {self.weight.shape[2]}, stride={self.stride}, padding={self.padding})'
129
+ )
130
+
131
+
132
+ class EqualLinear(nn.Module):
133
+ def __init__(
134
+ self, in_dim, out_dim, bias=True, bias_init=0, lr_mul=1, activation=None
135
+ ):
136
+ super().__init__()
137
+
138
+ self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul))
139
+
140
+ if bias:
141
+ self.bias = nn.Parameter(torch.zeros(out_dim).fill_(bias_init))
142
+
143
+ else:
144
+ self.bias = None
145
+
146
+ self.activation = activation
147
+
148
+ self.scale = (1 / math.sqrt(in_dim)) * lr_mul
149
+ self.lr_mul = lr_mul
150
+
151
+ def forward(self, input):
152
+ if self.activation:
153
+ out = F.linear(input, self.weight * self.scale)
154
+ out = fused_leaky_relu(out, self.bias * self.lr_mul)
155
+
156
+ else:
157
+ out = F.linear(
158
+ input, self.weight * self.scale, bias=self.bias * self.lr_mul
159
+ )
160
+
161
+ return out
162
+
163
+ def __repr__(self):
164
+ return (
165
+ f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]})'
166
+ )
167
+
168
+
169
+ class ScaledLeakyReLU(nn.Module):
170
+ def __init__(self, negative_slope=0.2):
171
+ super().__init__()
172
+
173
+ self.negative_slope = negative_slope
174
+
175
+ def forward(self, input):
176
+ out = F.leaky_relu(input, negative_slope=self.negative_slope)
177
+
178
+ return out * math.sqrt(2)
179
+
180
+
181
+ class ModulatedConv2d(nn.Module):
182
+ def __init__(
183
+ self,
184
+ in_channel,
185
+ out_channel,
186
+ kernel_size,
187
+ style_dim,
188
+ demodulate=True,
189
+ upsample=False,
190
+ downsample=False,
191
+ blur_kernel=[1, 3, 3, 1],
192
+ ):
193
+ super().__init__()
194
+
195
+ self.eps = 1e-8
196
+ self.kernel_size = kernel_size
197
+ self.in_channel = in_channel
198
+ self.out_channel = out_channel
199
+ self.upsample = upsample
200
+ self.downsample = downsample
201
+
202
+ if upsample:
203
+ factor = 2
204
+ p = (len(blur_kernel) - factor) - (kernel_size - 1)
205
+ pad0 = (p + 1) // 2 + factor - 1
206
+ pad1 = p // 2 + 1
207
+
208
+ self.blur = Blur(blur_kernel, pad=(pad0, pad1), upsample_factor=factor)
209
+
210
+ if downsample:
211
+ factor = 2
212
+ p = (len(blur_kernel) - factor) + (kernel_size - 1)
213
+ pad0 = (p + 1) // 2
214
+ pad1 = p // 2
215
+
216
+ self.blur = Blur(blur_kernel, pad=(pad0, pad1))
217
+
218
+ fan_in = in_channel * kernel_size ** 2
219
+ self.scale = 1 / math.sqrt(fan_in)
220
+ self.padding = kernel_size // 2
221
+
222
+ self.weight = nn.Parameter(
223
+ torch.randn(1, out_channel, in_channel, kernel_size, kernel_size)
224
+ )
225
+
226
+ self.modulation = EqualLinear(style_dim, in_channel, bias_init=1)
227
+
228
+ self.demodulate = demodulate
229
+
230
+ def __repr__(self):
231
+ return (
232
+ f'{self.__class__.__name__}({self.in_channel}, {self.out_channel}, {self.kernel_size}, '
233
+ f'upsample={self.upsample}, downsample={self.downsample})'
234
+ )
235
+
236
+ def forward(self, input, style):
237
+ batch, in_channel, height, width = input.shape
238
+
239
+ style = self.modulation(style).view(batch, 1, in_channel, 1, 1)
240
+ weight = self.scale * self.weight * style
241
+
242
+ if self.demodulate:
243
+ demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-8)
244
+ weight = weight * demod.view(batch, self.out_channel, 1, 1, 1)
245
+
246
+ weight = weight.view(
247
+ batch * self.out_channel, in_channel, self.kernel_size, self.kernel_size
248
+ )
249
+
250
+ if self.upsample:
251
+ input = input.view(1, batch * in_channel, height, width)
252
+ weight = weight.view(
253
+ batch, self.out_channel, in_channel, self.kernel_size, self.kernel_size
254
+ )
255
+ weight = weight.transpose(1, 2).reshape(
256
+ batch * in_channel, self.out_channel, self.kernel_size, self.kernel_size
257
+ )
258
+ out = F.conv_transpose2d(input, weight, padding=0, stride=2, groups=batch)
259
+ _, _, height, width = out.shape
260
+ out = out.view(batch, self.out_channel, height, width)
261
+ out = self.blur(out)
262
+
263
+ elif self.downsample:
264
+ input = self.blur(input)
265
+ _, _, height, width = input.shape
266
+ input = input.view(1, batch * in_channel, height, width)
267
+ out = F.conv2d(input, weight, padding=0, stride=2, groups=batch)
268
+ _, _, height, width = out.shape
269
+ out = out.view(batch, self.out_channel, height, width)
270
+
271
+ else:
272
+ input = input.view(1, batch * in_channel, height, width)
273
+ out = F.conv2d(input, weight, padding=self.padding, groups=batch)
274
+ _, _, height, width = out.shape
275
+ out = out.view(batch, self.out_channel, height, width)
276
+
277
+ return out
278
+
279
+
280
+ class NoiseInjection(nn.Module):
281
+ def __init__(self):
282
+ super().__init__()
283
+
284
+ self.weight = nn.Parameter(torch.zeros(1))
285
+
286
+ def forward(self, image, noise=None):
287
+ if noise is None:
288
+ batch, _, height, width = image.shape
289
+ noise = image.new_empty(batch, 1, height, width).normal_()
290
+
291
+ return image + self.weight * noise
292
+
293
+
294
+ class ConstantInput(nn.Module):
295
+ def __init__(self, channel, size=4):
296
+ super().__init__()
297
+
298
+ self.input = nn.Parameter(torch.randn(1, channel, size, size))
299
+
300
+ def forward(self, input):
301
+ batch = input.shape[0]
302
+ out = self.input.repeat(batch, 1, 1, 1)
303
+
304
+ return out
305
+
306
+
307
+ class StyledConv(nn.Module):
308
+ def __init__(
309
+ self,
310
+ in_channel,
311
+ out_channel,
312
+ kernel_size,
313
+ style_dim,
314
+ upsample=False,
315
+ blur_kernel=[1, 3, 3, 1],
316
+ demodulate=True,
317
+ ):
318
+ super().__init__()
319
+
320
+ self.conv = ModulatedConv2d(
321
+ in_channel,
322
+ out_channel,
323
+ kernel_size,
324
+ style_dim,
325
+ upsample=upsample,
326
+ blur_kernel=blur_kernel,
327
+ demodulate=demodulate,
328
+ )
329
+
330
+ self.noise = NoiseInjection()
331
+ # self.bias = nn.Parameter(torch.zeros(1, out_channel, 1, 1))
332
+ # self.activate = ScaledLeakyReLU(0.2)
333
+ self.activate = FusedLeakyReLU(out_channel)
334
+
335
+ def forward(self, input, style, noise=None):
336
+ out = self.conv(input, style)
337
+ out = self.noise(out, noise=noise)
338
+ # out = out + self.bias
339
+ out = self.activate(out)
340
+
341
+ return out
342
+
343
+
344
+ class ToRGB(nn.Module):
345
+ def __init__(self, in_channel, style_dim, upsample=True, blur_kernel=[1, 3, 3, 1]):
346
+ super().__init__()
347
+
348
+ if upsample:
349
+ self.upsample = Upsample(blur_kernel)
350
+
351
+ self.conv = ModulatedConv2d(in_channel, 3, 1, style_dim, demodulate=False)
352
+ self.bias = nn.Parameter(torch.zeros(1, 3, 1, 1))
353
+
354
+ def forward(self, input, style, skip=None):
355
+ out = self.conv(input, style)
356
+ out = out + self.bias
357
+
358
+ if skip is not None:
359
+ skip = self.upsample(skip)
360
+
361
+ out = out + skip
362
+
363
+ return out
364
+
365
+ # Wrapper that gives name to tensor
366
+ class NamedTensor(nn.Module):
367
+ def __init__(self):
368
+ super().__init__()
369
+
370
+ def forward(self, x):
371
+ return x
372
+
373
+ # Give each style a unique name
374
+ class StridedStyle(nn.ModuleList):
375
+ def __init__(self, n_latents):
376
+ super().__init__([NamedTensor() for _ in range(n_latents)])
377
+ self.n_latents = n_latents
378
+
379
+ def forward(self, x):
380
+ # x already strided
381
+ styles = [self[i](x[:, i, :]) for i in range(self.n_latents)]
382
+ return torch.stack(styles, dim=1)
383
+
384
+ class Generator(nn.Module):
385
+ def __init__(
386
+ self,
387
+ size,
388
+ style_dim,
389
+ n_mlp,
390
+ channel_multiplier=2,
391
+ blur_kernel=[1, 3, 3, 1],
392
+ lr_mlp=0.01,
393
+ ):
394
+ super().__init__()
395
+
396
+ self.size = size
397
+
398
+ self.style_dim = style_dim
399
+
400
+ layers = [PixelNorm()]
401
+
402
+ for i in range(n_mlp):
403
+ layers.append(
404
+ EqualLinear(
405
+ style_dim, style_dim, lr_mul=lr_mlp, activation='fused_lrelu'
406
+ )
407
+ )
408
+
409
+ self.style = nn.Sequential(*layers)
410
+
411
+ self.channels = {
412
+ 4: 512,
413
+ 8: 512,
414
+ 16: 512,
415
+ 32: 512,
416
+ 64: 256 * channel_multiplier,
417
+ 128: 128 * channel_multiplier,
418
+ 256: 64 * channel_multiplier,
419
+ 512: 32 * channel_multiplier,
420
+ 1024: 16 * channel_multiplier,
421
+ }
422
+
423
+ self.input = ConstantInput(self.channels[4])
424
+ self.conv1 = StyledConv(
425
+ self.channels[4], self.channels[4], 3, style_dim, blur_kernel=blur_kernel
426
+ )
427
+ self.to_rgb1 = ToRGB(self.channels[4], style_dim, upsample=False)
428
+
429
+ self.log_size = int(math.log(size, 2))
430
+ self.num_layers = (self.log_size - 2) * 2 + 1
431
+
432
+ self.convs = nn.ModuleList()
433
+ self.upsamples = nn.ModuleList()
434
+ self.to_rgbs = nn.ModuleList()
435
+ self.noises = nn.Module()
436
+
437
+ in_channel = self.channels[4]
438
+
439
+ for layer_idx in range(self.num_layers):
440
+ res = (layer_idx + 5) // 2
441
+ shape = [1, 1, 2 ** res, 2 ** res]
442
+ self.noises.register_buffer(f'noise_{layer_idx}', torch.randn(*shape))
443
+
444
+ for i in range(3, self.log_size + 1):
445
+ out_channel = self.channels[2 ** i]
446
+
447
+ self.convs.append(
448
+ StyledConv(
449
+ in_channel,
450
+ out_channel,
451
+ 3,
452
+ style_dim,
453
+ upsample=True,
454
+ blur_kernel=blur_kernel,
455
+ )
456
+ )
457
+
458
+ self.convs.append(
459
+ StyledConv(
460
+ out_channel, out_channel, 3, style_dim, blur_kernel=blur_kernel
461
+ )
462
+ )
463
+
464
+ self.to_rgbs.append(ToRGB(out_channel, style_dim))
465
+
466
+ in_channel = out_channel
467
+
468
+ self.n_latent = self.log_size * 2 - 2
469
+ self.strided_style = StridedStyle(self.n_latent)
470
+
471
+ def make_noise(self):
472
+ device = self.input.input.device
473
+
474
+ noises = [torch.randn(1, 1, 2 ** 2, 2 ** 2, device=device)]
475
+
476
+ for i in range(3, self.log_size + 1):
477
+ for _ in range(2):
478
+ noises.append(torch.randn(1, 1, 2 ** i, 2 ** i, device=device))
479
+
480
+ return noises
481
+
482
+ def mean_latent(self, n_latent):
483
+ latent_in = torch.randn(
484
+ n_latent, self.style_dim, device=self.input.input.device
485
+ )
486
+ latent = self.style(latent_in).mean(0, keepdim=True)
487
+
488
+ return latent
489
+
490
+ def get_latent(self, input):
491
+ return self.style(input)
492
+
493
+ def forward(
494
+ self,
495
+ styles,
496
+ return_latents=False,
497
+ inject_index=None,
498
+ truncation=1,
499
+ truncation_latent=None,
500
+ input_is_w=False,
501
+ noise=None,
502
+ randomize_noise=True,
503
+ ):
504
+ if not input_is_w:
505
+ styles = [self.style(s) for s in styles]
506
+
507
+ if noise is None:
508
+ if randomize_noise:
509
+ noise = [None] * self.num_layers
510
+ else:
511
+ noise = [
512
+ getattr(self.noises, f'noise_{i}') for i in range(self.num_layers)
513
+ ]
514
+
515
+ if truncation < 1:
516
+ style_t = []
517
+
518
+ for style in styles:
519
+ style_t.append(
520
+ truncation_latent + truncation * (style - truncation_latent)
521
+ )
522
+
523
+ styles = style_t
524
+
525
+ if len(styles) == 1:
526
+ # One global latent
527
+ inject_index = self.n_latent
528
+
529
+ if styles[0].ndim < 3:
530
+ latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
531
+
532
+ else:
533
+ latent = styles[0]
534
+
535
+ elif len(styles) == 2:
536
+ # Latent mixing with two latents
537
+ if inject_index is None:
538
+ inject_index = random.randint(1, self.n_latent - 1)
539
+
540
+ latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
541
+ latent2 = styles[1].unsqueeze(1).repeat(1, self.n_latent - inject_index, 1)
542
+
543
+ latent = self.strided_style(torch.cat([latent, latent2], 1))
544
+ else:
545
+ # One latent per layer
546
+ assert len(styles) == self.n_latent, f'Expected {self.n_latent} latents, got {len(styles)}'
547
+ styles = torch.stack(styles, dim=1) # [N, 18, 512]
548
+ latent = self.strided_style(styles)
549
+
550
+ out = self.input(latent)
551
+ out = self.conv1(out, latent[:, 0], noise=noise[0])
552
+
553
+ skip = self.to_rgb1(out, latent[:, 1])
554
+
555
+ i = 1
556
+ for conv1, conv2, noise1, noise2, to_rgb in zip(
557
+ self.convs[::2], self.convs[1::2], noise[1::2], noise[2::2], self.to_rgbs
558
+ ):
559
+ out = conv1(out, latent[:, i], noise=noise1)
560
+ out = conv2(out, latent[:, i + 1], noise=noise2)
561
+ skip = to_rgb(out, latent[:, i + 2], skip)
562
+
563
+ i += 2
564
+
565
+ image = skip
566
+
567
+ if return_latents:
568
+ return image, latent
569
+
570
+ else:
571
+ return image, None
572
+
573
+
574
+ class ConvLayer(nn.Sequential):
575
+ def __init__(
576
+ self,
577
+ in_channel,
578
+ out_channel,
579
+ kernel_size,
580
+ downsample=False,
581
+ blur_kernel=[1, 3, 3, 1],
582
+ bias=True,
583
+ activate=True,
584
+ ):
585
+ layers = []
586
+
587
+ if downsample:
588
+ factor = 2
589
+ p = (len(blur_kernel) - factor) + (kernel_size - 1)
590
+ pad0 = (p + 1) // 2
591
+ pad1 = p // 2
592
+
593
+ layers.append(Blur(blur_kernel, pad=(pad0, pad1)))
594
+
595
+ stride = 2
596
+ self.padding = 0
597
+
598
+ else:
599
+ stride = 1
600
+ self.padding = kernel_size // 2
601
+
602
+ layers.append(
603
+ EqualConv2d(
604
+ in_channel,
605
+ out_channel,
606
+ kernel_size,
607
+ padding=self.padding,
608
+ stride=stride,
609
+ bias=bias and not activate,
610
+ )
611
+ )
612
+
613
+ if activate:
614
+ if bias:
615
+ layers.append(FusedLeakyReLU(out_channel))
616
+
617
+ else:
618
+ layers.append(ScaledLeakyReLU(0.2))
619
+
620
+ super().__init__(*layers)
621
+
622
+
623
+ class ResBlock(nn.Module):
624
+ def __init__(self, in_channel, out_channel, blur_kernel=[1, 3, 3, 1]):
625
+ super().__init__()
626
+
627
+ self.conv1 = ConvLayer(in_channel, in_channel, 3)
628
+ self.conv2 = ConvLayer(in_channel, out_channel, 3, downsample=True)
629
+
630
+ self.skip = ConvLayer(
631
+ in_channel, out_channel, 1, downsample=True, activate=False, bias=False
632
+ )
633
+
634
+ def forward(self, input):
635
+ out = self.conv1(input)
636
+ out = self.conv2(out)
637
+
638
+ skip = self.skip(input)
639
+ out = (out + skip) / math.sqrt(2)
640
+
641
+ return out
642
+
643
+
644
+ class Discriminator(nn.Module):
645
+ def __init__(self, size, channel_multiplier=2, blur_kernel=[1, 3, 3, 1]):
646
+ super().__init__()
647
+
648
+ channels = {
649
+ 4: 512,
650
+ 8: 512,
651
+ 16: 512,
652
+ 32: 512,
653
+ 64: 256 * channel_multiplier,
654
+ 128: 128 * channel_multiplier,
655
+ 256: 64 * channel_multiplier,
656
+ 512: 32 * channel_multiplier,
657
+ 1024: 16 * channel_multiplier,
658
+ }
659
+
660
+ convs = [ConvLayer(3, channels[size], 1)]
661
+
662
+ log_size = int(math.log(size, 2))
663
+
664
+ in_channel = channels[size]
665
+
666
+ for i in range(log_size, 2, -1):
667
+ out_channel = channels[2 ** (i - 1)]
668
+
669
+ convs.append(ResBlock(in_channel, out_channel, blur_kernel))
670
+
671
+ in_channel = out_channel
672
+
673
+ self.convs = nn.Sequential(*convs)
674
+
675
+ self.stddev_group = 4
676
+ self.stddev_feat = 1
677
+
678
+ self.final_conv = ConvLayer(in_channel + 1, channels[4], 3)
679
+ self.final_linear = nn.Sequential(
680
+ EqualLinear(channels[4] * 4 * 4, channels[4], activation='fused_lrelu'),
681
+ EqualLinear(channels[4], 1),
682
+ )
683
+
684
+ def forward(self, input):
685
+ out = self.convs(input)
686
+
687
+ batch, channel, height, width = out.shape
688
+ group = min(batch, self.stddev_group)
689
+ stddev = out.view(
690
+ group, -1, self.stddev_feat, channel // self.stddev_feat, height, width
691
+ )
692
+ stddev = torch.sqrt(stddev.var(0, unbiased=False) + 1e-8)
693
+ stddev = stddev.mean([2, 3, 4], keepdims=True).squeeze(2)
694
+ stddev = stddev.repeat(group, 1, height, width)
695
+ out = torch.cat([out, stddev], 1)
696
+
697
+ out = self.final_conv(out)
698
+
699
+ out = out.view(batch, -1)
700
+ out = self.final_linear(out)
701
+
702
+ return out
703
+
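An illustrative sketch of driving the Generator.forward defined above, assuming the stylegan2-pytorch directory is on the Python path, a 256px configuration with a 512-dim style space and an 8-layer mapping network (as instantiated in train.py), and untrained weights, so it only demonstrates shapes and the keyword arguments:

import torch
from model import Generator

g = Generator(256, 512, 8).eval()          # size, style_dim, n_mlp as in train.py
with torch.no_grad():
    mean_w = g.mean_latent(4096)           # average W used for truncation
    z = torch.randn(4, 512)
    img, _ = g([z], truncation=0.7, truncation_latent=mean_w)
    w = g.get_latent(z)                    # map Z -> W explicitly
    img_w, _ = g([w], input_is_w=True)     # skip the mapping network
print(img.shape)                           # [4, 3, 256, 256]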
models/stylegan2/stylegan2-pytorch/non_leaking.py ADDED
@@ -0,0 +1,137 @@
1
+ import math
2
+
3
+ import torch
4
+ from torch.nn import functional as F
5
+
6
+
7
+ def translate_mat(t_x, t_y):
8
+ batch = t_x.shape[0]
9
+
10
+ mat = torch.eye(3).unsqueeze(0).repeat(batch, 1, 1)
11
+ translate = torch.stack((t_x, t_y), 1)
12
+ mat[:, :2, 2] = translate
13
+
14
+ return mat
15
+
16
+
17
+ def rotate_mat(theta):
18
+ batch = theta.shape[0]
19
+
20
+ mat = torch.eye(3).unsqueeze(0).repeat(batch, 1, 1)
21
+ sin_t = torch.sin(theta)
22
+ cos_t = torch.cos(theta)
23
+ rot = torch.stack((cos_t, -sin_t, sin_t, cos_t), 1).view(batch, 2, 2)
24
+ mat[:, :2, :2] = rot
25
+
26
+ return mat
27
+
28
+
29
+ def scale_mat(s_x, s_y):
30
+ batch = s_x.shape[0]
31
+
32
+ mat = torch.eye(3).unsqueeze(0).repeat(batch, 1, 1)
33
+ mat[:, 0, 0] = s_x
34
+ mat[:, 1, 1] = s_y
35
+
36
+ return mat
37
+
38
+
39
+ def lognormal_sample(size, mean=0, std=1):
40
+ return torch.empty(size).log_normal_(mean=mean, std=std)
41
+
42
+
43
+ def category_sample(size, categories):
44
+ category = torch.tensor(categories)
45
+ sample = torch.randint(high=len(categories), size=(size,))
46
+
47
+ return category[sample]
48
+
49
+
50
+ def uniform_sample(size, low, high):
51
+ return torch.empty(size).uniform_(low, high)
52
+
53
+
54
+ def normal_sample(size, mean=0, std=1):
55
+ return torch.empty(size).normal_(mean, std)
56
+
57
+
58
+ def bernoulli_sample(size, p):
59
+ return torch.empty(size).bernoulli_(p)
60
+
61
+
62
+ def random_affine_apply(p, transform, prev, eye):
63
+ size = transform.shape[0]
64
+ select = bernoulli_sample(size, p).view(size, 1, 1)
65
+ select_transform = select * transform + (1 - select) * eye
66
+
67
+ return select_transform @ prev
68
+
69
+
70
+ def sample_affine(p, size, height, width):
71
+ G = torch.eye(3).unsqueeze(0).repeat(size, 1, 1)
72
+ eye = G
73
+
74
+ # flip
75
+ param = category_sample(size, (0, 1))
76
+ Gc = scale_mat(1 - 2.0 * param, torch.ones(size))
77
+ G = random_affine_apply(p, Gc, G, eye)
78
+ # print('flip', G, scale_mat(1 - 2.0 * param, torch.ones(size)), sep='\n')
79
+
80
+ # 90 rotate
81
+ param = category_sample(size, (0, 3))
82
+ Gc = rotate_mat(-math.pi / 2 * param)
83
+ G = random_affine_apply(p, Gc, G, eye)
84
+ # print('90 rotate', G, rotate_mat(-math.pi / 2 * param), sep='\n')
85
+
86
+ # integer translate
87
+ param = uniform_sample(size, -0.125, 0.125)
88
+ param_height = torch.round(param * height) / height
89
+ param_width = torch.round(param * width) / width
90
+ Gc = translate_mat(param_width, param_height)
91
+ G = random_affine_apply(p, Gc, G, eye)
92
+ # print('integer translate', G, translate_mat(param_width, param_height), sep='\n')
93
+
94
+ # isotropic scale
95
+ param = lognormal_sample(size, std=0.2 * math.log(2))
96
+ Gc = scale_mat(param, param)
97
+ G = random_affine_apply(p, Gc, G, eye)
98
+ # print('isotropic scale', G, scale_mat(param, param), sep='\n')
99
+
100
+ p_rot = 1 - math.sqrt(1 - p)
101
+
102
+ # pre-rotate
103
+ param = uniform_sample(size, -math.pi, math.pi)
104
+ Gc = rotate_mat(-param)
105
+ G = random_affine_apply(p_rot, Gc, G, eye)
106
+ # print('pre-rotate', G, rotate_mat(-param), sep='\n')
107
+
108
+ # anisotropic scale
109
+ param = lognormal_sample(size, std=0.2 * math.log(2))
110
+ Gc = scale_mat(param, 1 / param)
111
+ G = random_affine_apply(p, Gc, G, eye)
112
+ # print('anisotropic scale', G, scale_mat(param, 1 / param), sep='\n')
113
+
114
+ # post-rotate
115
+ param = uniform_sample(size, -math.pi, math.pi)
116
+ Gc = rotate_mat(-param)
117
+ G = random_affine_apply(p_rot, Gc, G, eye)
118
+ # print('post-rotate', G, rotate_mat(-param), sep='\n')
119
+
120
+ # fractional translate
121
+ param = normal_sample(size, std=0.125)
122
+ Gc = translate_mat(param, param)
123
+ G = random_affine_apply(p, Gc, G, eye)
124
+ # print('fractional translate', G, translate_mat(param, param), sep='\n')
125
+
126
+ return G
127
+
128
+
129
+ def apply_affine(img, G):
130
+ grid = F.affine_grid(
131
+ torch.inverse(G).to(img)[:, :2, :], img.shape, align_corners=False
132
+ )
133
+ img_affine = F.grid_sample(
134
+ img, grid, mode="bilinear", align_corners=False, padding_mode="reflection"
135
+ )
136
+
137
+ return img_affine
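An illustrative sketch of the augmentation pipeline above, sampling one random affine matrix per image (each transform applied with probability p) and warping the batch, assuming a CPU tensor batch:

import torch
from non_leaking import sample_affine, apply_affine

imgs = torch.randn(8, 3, 64, 64)                         # stand-in image batch
G = sample_affine(p=0.5, size=8, height=64, width=64)    # one 3x3 affine per image
aug = apply_affine(imgs, G)                              # bilinear warp, reflection padding
print(aug.shape)                                         # [8, 3, 64, 64]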
models/stylegan2/stylegan2-pytorch/op/__init__.py ADDED
@@ -0,0 +1,2 @@
1
+ from .fused_act import FusedLeakyReLU, fused_leaky_relu
2
+ from .upfirdn2d import upfirdn2d
models/stylegan2/stylegan2-pytorch/op/fused_act.py ADDED
@@ -0,0 +1,92 @@
1
+ import os
2
+ import platform
3
+
4
+ import torch
5
+ from torch import nn
6
+ from torch.autograd import Function
7
+ import torch.nn.functional as F
8
+ from torch.utils.cpp_extension import load
9
+
10
+ use_fallback = False
11
+
12
+ # Try loading precompiled, otherwise use native fallback
13
+ try:
14
+ import fused
15
+ except ModuleNotFoundError as e:
16
+ print('StyleGAN2: Optimized CUDA op FusedLeakyReLU not available, using native PyTorch fallback.')
17
+ use_fallback = True
18
+
19
+
20
+ class FusedLeakyReLUFunctionBackward(Function):
21
+ @staticmethod
22
+ def forward(ctx, grad_output, out, negative_slope, scale):
23
+ ctx.save_for_backward(out)
24
+ ctx.negative_slope = negative_slope
25
+ ctx.scale = scale
26
+
27
+ empty = grad_output.new_empty(0)
28
+
29
+ grad_input = fused.fused_bias_act(
30
+ grad_output, empty, out, 3, 1, negative_slope, scale
31
+ )
32
+
33
+ dim = [0]
34
+
35
+ if grad_input.ndim > 2:
36
+ dim += list(range(2, grad_input.ndim))
37
+
38
+ grad_bias = grad_input.sum(dim).detach()
39
+
40
+ return grad_input, grad_bias
41
+
42
+ @staticmethod
43
+ def backward(ctx, gradgrad_input, gradgrad_bias):
44
+ out, = ctx.saved_tensors
45
+ gradgrad_out = fused.fused_bias_act(
46
+ gradgrad_input, gradgrad_bias, out, 3, 1, ctx.negative_slope, ctx.scale
47
+ )
48
+
49
+ return gradgrad_out, None, None, None
50
+
51
+
52
+ class FusedLeakyReLUFunction(Function):
53
+ @staticmethod
54
+ def forward(ctx, input, bias, negative_slope, scale):
55
+ empty = input.new_empty(0)
56
+ out = fused.fused_bias_act(input, bias, empty, 3, 0, negative_slope, scale)
57
+ ctx.save_for_backward(out)
58
+ ctx.negative_slope = negative_slope
59
+ ctx.scale = scale
60
+
61
+ return out
62
+
63
+ @staticmethod
64
+ def backward(ctx, grad_output):
65
+ out, = ctx.saved_tensors
66
+
67
+ grad_input, grad_bias = FusedLeakyReLUFunctionBackward.apply(
68
+ grad_output, out, ctx.negative_slope, ctx.scale
69
+ )
70
+
71
+ return grad_input, grad_bias, None, None
72
+
73
+
74
+ class FusedLeakyReLU(nn.Module):
75
+ def __init__(self, channel, negative_slope=0.2, scale=2 ** 0.5):
76
+ super().__init__()
77
+
78
+ self.bias = nn.Parameter(torch.zeros(channel))
79
+ self.negative_slope = negative_slope
80
+ self.scale = scale
81
+
82
+ def forward(self, input):
83
+ return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale)
84
+
85
+
86
+ def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2 ** 0.5):
87
+ if use_fallback or input.device.type == 'cpu':
88
+ return scale * F.leaky_relu(
89
+ input + bias.view((1, -1)+(1,)*(input.ndim-2)), negative_slope=negative_slope
90
+ )
91
+ else:
92
+ return FusedLeakyReLUFunction.apply(input, bias, negative_slope, scale)
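On CPU, or whenever the compiled `fused` extension is missing, fused_leaky_relu above reduces to a bias add followed by a scaled leaky ReLU. A small sanity-check sketch, assuming the repository root is on the Python path:

import torch
import torch.nn.functional as F
from op.fused_act import fused_leaky_relu

x = torch.randn(2, 8, 4, 4)
b = torch.randn(8)
y = fused_leaky_relu(x, b)                               # fallback path on CPU
ref = (2 ** 0.5) * F.leaky_relu(x + b.view(1, -1, 1, 1), 0.2)
print(torch.allclose(y, ref))                            # True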
models/stylegan2/stylegan2-pytorch/op/fused_bias_act.cpp ADDED
@@ -0,0 +1,21 @@
1
+ #include <torch/extension.h>
2
+
3
+
4
+ torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer,
5
+ int act, int grad, float alpha, float scale);
6
+
7
+ #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
8
+ #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
9
+ #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
10
+
11
+ torch::Tensor fused_bias_act(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer,
12
+ int act, int grad, float alpha, float scale) {
13
+ CHECK_CUDA(input);
14
+ CHECK_CUDA(bias);
15
+
16
+ return fused_bias_act_op(input, bias, refer, act, grad, alpha, scale);
17
+ }
18
+
19
+ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
20
+ m.def("fused_bias_act", &fused_bias_act, "fused bias act (CUDA)");
21
+ }
models/stylegan2/stylegan2-pytorch/op/fused_bias_act_kernel.cu ADDED
@@ -0,0 +1,99 @@
1
+ // Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
2
+ //
3
+ // This work is made available under the Nvidia Source Code License-NC.
4
+ // To view a copy of this license, visit
5
+ // https://nvlabs.github.io/stylegan2/license.html
6
+
7
+ #include <torch/types.h>
8
+
9
+ #include <ATen/ATen.h>
10
+ #include <ATen/AccumulateType.h>
11
+ #include <ATen/cuda/CUDAContext.h>
12
+ #include <ATen/cuda/CUDAApplyUtils.cuh>
13
+
14
+ #include <cuda.h>
15
+ #include <cuda_runtime.h>
16
+
17
+
18
+ template <typename scalar_t>
19
+ static __global__ void fused_bias_act_kernel(scalar_t* out, const scalar_t* p_x, const scalar_t* p_b, const scalar_t* p_ref,
20
+ int act, int grad, scalar_t alpha, scalar_t scale, int loop_x, int size_x, int step_b, int size_b, int use_bias, int use_ref) {
21
+ int xi = blockIdx.x * loop_x * blockDim.x + threadIdx.x;
22
+
23
+ scalar_t zero = 0.0;
24
+
25
+ for (int loop_idx = 0; loop_idx < loop_x && xi < size_x; loop_idx++, xi += blockDim.x) {
26
+ scalar_t x = p_x[xi];
27
+
28
+ if (use_bias) {
29
+ x += p_b[(xi / step_b) % size_b];
30
+ }
31
+
32
+ scalar_t ref = use_ref ? p_ref[xi] : zero;
33
+
34
+ scalar_t y;
35
+
36
+ switch (act * 10 + grad) {
37
+ default:
38
+ case 10: y = x; break;
39
+ case 11: y = x; break;
40
+ case 12: y = 0.0; break;
41
+
42
+ case 30: y = (x > 0.0) ? x : x * alpha; break;
43
+ case 31: y = (ref > 0.0) ? x : x * alpha; break;
44
+ case 32: y = 0.0; break;
45
+ }
46
+
47
+ out[xi] = y * scale;
48
+ }
49
+ }
50
+
51
+
52
+ torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer,
53
+ int act, int grad, float alpha, float scale) {
54
+ int curDevice = -1;
55
+ cudaGetDevice(&curDevice);
56
+ cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice);
57
+
58
+ auto x = input.contiguous();
59
+ auto b = bias.contiguous();
60
+ auto ref = refer.contiguous();
61
+
62
+ int use_bias = b.numel() ? 1 : 0;
63
+ int use_ref = ref.numel() ? 1 : 0;
64
+
65
+ int size_x = x.numel();
66
+ int size_b = b.numel();
67
+ int step_b = 1;
68
+
69
+ for (int i = 1 + 1; i < x.dim(); i++) {
70
+ step_b *= x.size(i);
71
+ }
72
+
73
+ int loop_x = 4;
74
+ int block_size = 4 * 32;
75
+ int grid_size = (size_x - 1) / (loop_x * block_size) + 1;
76
+
77
+ auto y = torch::empty_like(x);
78
+
79
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "fused_bias_act_kernel", [&] {
80
+ fused_bias_act_kernel<scalar_t><<<grid_size, block_size, 0, stream>>>(
81
+ y.data_ptr<scalar_t>(),
82
+ x.data_ptr<scalar_t>(),
83
+ b.data_ptr<scalar_t>(),
84
+ ref.data_ptr<scalar_t>(),
85
+ act,
86
+ grad,
87
+ alpha,
88
+ scale,
89
+ loop_x,
90
+ size_x,
91
+ step_b,
92
+ size_b,
93
+ use_bias,
94
+ use_ref
95
+ );
96
+ });
97
+
98
+ return y;
99
+ }
models/stylegan2/stylegan2-pytorch/op/setup.py ADDED
@@ -0,0 +1,33 @@
1
+ from setuptools import setup
2
+ from torch.utils.cpp_extension import CUDAExtension, BuildExtension
3
+ from pathlib import Path
4
+
5
+ # Usage:
6
+ # python setup.py install (or python setup.py bdist_wheel)
7
+ # NB: Windows: run from VS2017 x64 Native Tool Command Prompt
8
+
9
+ rootdir = (Path(__file__).parent / '..' / 'op').resolve()
10
+
11
+ setup(
12
+ name='upfirdn2d',
13
+ ext_modules=[
14
+ CUDAExtension('upfirdn2d_op',
15
+ [str(rootdir / 'upfirdn2d.cpp'), str(rootdir / 'upfirdn2d_kernel.cu')],
16
+ )
17
+ ],
18
+ cmdclass={
19
+ 'build_ext': BuildExtension
20
+ }
21
+ )
22
+
23
+ setup(
24
+ name='fused',
25
+ ext_modules=[
26
+ CUDAExtension('fused',
27
+ [str(rootdir / 'fused_bias_act.cpp'), str(rootdir / 'fused_bias_act_kernel.cu')],
28
+ )
29
+ ],
30
+ cmdclass={
31
+ 'build_ext': BuildExtension
32
+ }
33
+ )
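As an alternative to building wheels with this setup.py, the same extensions can in principle be JIT-compiled at import time via torch.utils.cpp_extension.load (fused_act.py and upfirdn2d.py import `load` and fall back to native PyTorch when the compiled modules are absent). A hypothetical sketch, assuming CUDA and a host C++ toolchain are available:

from pathlib import Path
from torch.utils.cpp_extension import load

op_dir = Path(__file__).parent              # hypothetical: the op/ directory
fused = load(
    'fused',
    sources=[str(op_dir / 'fused_bias_act.cpp'), str(op_dir / 'fused_bias_act_kernel.cu')],
)
upfirdn2d_op = load(
    'upfirdn2d_op',
    sources=[str(op_dir / 'upfirdn2d.cpp'), str(op_dir / 'upfirdn2d_kernel.cu')],
)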
models/stylegan2/stylegan2-pytorch/op/upfirdn2d.cpp ADDED
@@ -0,0 +1,23 @@
1
+ #include <torch/extension.h>
2
+
3
+
4
+ torch::Tensor upfirdn2d_op(const torch::Tensor& input, const torch::Tensor& kernel,
5
+ int up_x, int up_y, int down_x, int down_y,
6
+ int pad_x0, int pad_x1, int pad_y0, int pad_y1);
7
+
8
+ #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
9
+ #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
10
+ #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
11
+
12
+ torch::Tensor upfirdn2d(const torch::Tensor& input, const torch::Tensor& kernel,
13
+ int up_x, int up_y, int down_x, int down_y,
14
+ int pad_x0, int pad_x1, int pad_y0, int pad_y1) {
15
+ CHECK_CUDA(input);
16
+ CHECK_CUDA(kernel);
17
+
18
+ return upfirdn2d_op(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1);
19
+ }
20
+
21
+ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
22
+ m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)");
23
+ }
models/stylegan2/stylegan2-pytorch/op/upfirdn2d.py ADDED
@@ -0,0 +1,198 @@
1
+ import os
2
+ import platform
3
+
4
+ import torch
5
+ from torch.nn import functional as F
6
+ from torch.autograd import Function
7
+ from torch.utils.cpp_extension import load
8
+
9
+ use_fallback = False
10
+
11
+ # Try loading precompiled, otherwise use native fallback
12
+ try:
13
+ import upfirdn2d_op
14
+ except ModuleNotFoundError as e:
15
+ print('StyleGAN2: Optimized CUDA op UpFirDn2d not available, using native PyTorch fallback.')
16
+ use_fallback = True
17
+
18
+ class UpFirDn2dBackward(Function):
19
+ @staticmethod
20
+ def forward(
21
+ ctx, grad_output, kernel, grad_kernel, up, down, pad, g_pad, in_size, out_size
22
+ ):
23
+
24
+ up_x, up_y = up
25
+ down_x, down_y = down
26
+ g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1 = g_pad
27
+
28
+ grad_output = grad_output.reshape(-1, out_size[0], out_size[1], 1)
29
+
30
+ grad_input = upfirdn2d_op.upfirdn2d(
31
+ grad_output,
32
+ grad_kernel,
33
+ down_x,
34
+ down_y,
35
+ up_x,
36
+ up_y,
37
+ g_pad_x0,
38
+ g_pad_x1,
39
+ g_pad_y0,
40
+ g_pad_y1,
41
+ )
42
+ grad_input = grad_input.view(in_size[0], in_size[1], in_size[2], in_size[3])
43
+
44
+ ctx.save_for_backward(kernel)
45
+
46
+ pad_x0, pad_x1, pad_y0, pad_y1 = pad
47
+
48
+ ctx.up_x = up_x
49
+ ctx.up_y = up_y
50
+ ctx.down_x = down_x
51
+ ctx.down_y = down_y
52
+ ctx.pad_x0 = pad_x0
53
+ ctx.pad_x1 = pad_x1
54
+ ctx.pad_y0 = pad_y0
55
+ ctx.pad_y1 = pad_y1
56
+ ctx.in_size = in_size
57
+ ctx.out_size = out_size
58
+
59
+ return grad_input
60
+
61
+ @staticmethod
62
+ def backward(ctx, gradgrad_input):
63
+ kernel, = ctx.saved_tensors
64
+
65
+ gradgrad_input = gradgrad_input.reshape(-1, ctx.in_size[2], ctx.in_size[3], 1)
66
+
67
+ gradgrad_out = upfirdn2d_op.upfirdn2d(
68
+ gradgrad_input,
69
+ kernel,
70
+ ctx.up_x,
71
+ ctx.up_y,
72
+ ctx.down_x,
73
+ ctx.down_y,
74
+ ctx.pad_x0,
75
+ ctx.pad_x1,
76
+ ctx.pad_y0,
77
+ ctx.pad_y1,
78
+ )
79
+ # gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.out_size[0], ctx.out_size[1], ctx.in_size[3])
80
+ gradgrad_out = gradgrad_out.view(
81
+ ctx.in_size[0], ctx.in_size[1], ctx.out_size[0], ctx.out_size[1]
82
+ )
83
+
84
+ return gradgrad_out, None, None, None, None, None, None, None, None
85
+
86
+
87
+ class UpFirDn2d(Function):
88
+ @staticmethod
89
+ def forward(ctx, input, kernel, up, down, pad):
90
+ up_x, up_y = up
91
+ down_x, down_y = down
92
+ pad_x0, pad_x1, pad_y0, pad_y1 = pad
93
+
94
+ kernel_h, kernel_w = kernel.shape
95
+ batch, channel, in_h, in_w = input.shape
96
+ ctx.in_size = input.shape
97
+
98
+ input = input.reshape(-1, in_h, in_w, 1)
99
+
100
+ ctx.save_for_backward(kernel, torch.flip(kernel, [0, 1]))
101
+
102
+ out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1
103
+ out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1
104
+ ctx.out_size = (out_h, out_w)
105
+
106
+ ctx.up = (up_x, up_y)
107
+ ctx.down = (down_x, down_y)
108
+ ctx.pad = (pad_x0, pad_x1, pad_y0, pad_y1)
109
+
110
+ g_pad_x0 = kernel_w - pad_x0 - 1
111
+ g_pad_y0 = kernel_h - pad_y0 - 1
112
+ g_pad_x1 = in_w * up_x - out_w * down_x + pad_x0 - up_x + 1
113
+ g_pad_y1 = in_h * up_y - out_h * down_y + pad_y0 - up_y + 1
114
+
115
+ ctx.g_pad = (g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1)
116
+
117
+ out = upfirdn2d_op.upfirdn2d(
118
+ input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1
119
+ )
120
+ # out = out.view(major, out_h, out_w, minor)
121
+ out = out.view(-1, channel, out_h, out_w)
122
+
123
+ return out
124
+
125
+ @staticmethod
126
+ def backward(ctx, grad_output):
127
+ kernel, grad_kernel = ctx.saved_tensors
128
+
129
+ grad_input = UpFirDn2dBackward.apply(
130
+ grad_output,
131
+ kernel,
132
+ grad_kernel,
133
+ ctx.up,
134
+ ctx.down,
135
+ ctx.pad,
136
+ ctx.g_pad,
137
+ ctx.in_size,
138
+ ctx.out_size,
139
+ )
140
+
141
+ return grad_input, None, None, None, None
142
+
143
+
144
+ def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
145
+ if use_fallback or input.device.type == "cpu":
146
+ out = upfirdn2d_native(
147
+ input, kernel, up, up, down, down, pad[0], pad[1], pad[0], pad[1]
148
+ )
149
+ else:
150
+ out = UpFirDn2d.apply(
151
+ input, kernel, (up, up), (down, down), (pad[0], pad[1], pad[0], pad[1])
152
+ )
153
+
154
+ return out
155
+
156
+
157
+ def upfirdn2d_native(
158
+ input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1
159
+ ):
160
+ _, channel, in_h, in_w = input.shape
161
+ input = input.reshape(-1, in_h, in_w, 1)
162
+
163
+ _, in_h, in_w, minor = input.shape
164
+ kernel_h, kernel_w = kernel.shape
165
+
166
+ out = input.view(-1, in_h, 1, in_w, 1, minor)
167
+ out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
168
+ out = out.view(-1, in_h * up_y, in_w * up_x, minor)
169
+
170
+ out = F.pad(
171
+ out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)]
172
+ )
173
+ out = out[
174
+ :,
175
+ max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0),
176
+ max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0),
177
+ :,
178
+ ]
179
+
180
+ out = out.permute(0, 3, 1, 2)
181
+ out = out.reshape(
182
+ [-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1]
183
+ )
184
+ w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
185
+ out = F.conv2d(out, w)
186
+ out = out.reshape(
187
+ -1,
188
+ minor,
189
+ in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
190
+ in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1,
191
+ )
192
+ out = out.permute(0, 2, 3, 1)
193
+ out = out[:, ::down_y, ::down_x, :]
194
+
195
+ out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1
196
+ out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1
197
+
198
+ return out.view(-1, channel, out_h, out_w)
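The wrapper above produces outputs of size (in * up + pad0 + pad1 - kernel) // down + 1 per spatial dimension. An illustrative sketch of a 2x upsampling blur with the usual [1, 3, 3, 1] separable kernel, using the native CPU fallback and pad values chosen so the output is exactly twice the input:

import torch
from op.upfirdn2d import upfirdn2d

k = torch.tensor([1., 3., 3., 1.])
kernel = k[None, :] * k[:, None]
kernel = kernel / kernel.sum()                           # normalized 4x4 blur kernel

x = torch.randn(1, 3, 16, 16)
y = upfirdn2d(x, kernel, up=2, down=1, pad=(2, 1))       # (16*2 + 2 + 1 - 4) // 1 + 1 = 32
print(y.shape)                                           # [1, 3, 32, 32]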
models/stylegan2/stylegan2-pytorch/op/upfirdn2d_kernel.cu ADDED
@@ -0,0 +1,369 @@
1
+ // Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
2
+ //
3
+ // This work is made available under the Nvidia Source Code License-NC.
4
+ // To view a copy of this license, visit
5
+ // https://nvlabs.github.io/stylegan2/license.html
6
+
7
+ #include <torch/types.h>
8
+
9
+ #include <ATen/ATen.h>
10
+ #include <ATen/AccumulateType.h>
11
+ #include <ATen/cuda/CUDAApplyUtils.cuh>
12
+ #include <ATen/cuda/CUDAContext.h>
13
+
14
+ #include <cuda.h>
15
+ #include <cuda_runtime.h>
16
+
17
+ static __host__ __device__ __forceinline__ int floor_div(int a, int b) {
18
+ int c = a / b;
19
+
20
+ if (c * b > a) {
21
+ c--;
22
+ }
23
+
24
+ return c;
25
+ }
26
+
27
+ struct UpFirDn2DKernelParams {
28
+ int up_x;
29
+ int up_y;
30
+ int down_x;
31
+ int down_y;
32
+ int pad_x0;
33
+ int pad_x1;
34
+ int pad_y0;
35
+ int pad_y1;
36
+
37
+ int major_dim;
38
+ int in_h;
39
+ int in_w;
40
+ int minor_dim;
41
+ int kernel_h;
42
+ int kernel_w;
43
+ int out_h;
44
+ int out_w;
45
+ int loop_major;
46
+ int loop_x;
47
+ };
48
+
49
+ template <typename scalar_t>
50
+ __global__ void upfirdn2d_kernel_large(scalar_t *out, const scalar_t *input,
51
+ const scalar_t *kernel,
52
+ const UpFirDn2DKernelParams p) {
53
+ int minor_idx = blockIdx.x * blockDim.x + threadIdx.x;
54
+ int out_y = minor_idx / p.minor_dim;
55
+ minor_idx -= out_y * p.minor_dim;
56
+ int out_x_base = blockIdx.y * p.loop_x * blockDim.y + threadIdx.y;
57
+ int major_idx_base = blockIdx.z * p.loop_major;
58
+
59
+ if (out_x_base >= p.out_w || out_y >= p.out_h ||
60
+ major_idx_base >= p.major_dim) {
61
+ return;
62
+ }
63
+
64
+ int mid_y = out_y * p.down_y + p.up_y - 1 - p.pad_y0;
65
+ int in_y = min(max(floor_div(mid_y, p.up_y), 0), p.in_h);
66
+ int h = min(max(floor_div(mid_y + p.kernel_h, p.up_y), 0), p.in_h) - in_y;
67
+ int kernel_y = mid_y + p.kernel_h - (in_y + 1) * p.up_y;
68
+
69
+ for (int loop_major = 0, major_idx = major_idx_base;
70
+ loop_major < p.loop_major && major_idx < p.major_dim;
71
+ loop_major++, major_idx++) {
72
+ for (int loop_x = 0, out_x = out_x_base;
73
+ loop_x < p.loop_x && out_x < p.out_w; loop_x++, out_x += blockDim.y) {
74
+ int mid_x = out_x * p.down_x + p.up_x - 1 - p.pad_x0;
75
+ int in_x = min(max(floor_div(mid_x, p.up_x), 0), p.in_w);
76
+ int w = min(max(floor_div(mid_x + p.kernel_w, p.up_x), 0), p.in_w) - in_x;
77
+ int kernel_x = mid_x + p.kernel_w - (in_x + 1) * p.up_x;
78
+
79
+ const scalar_t *x_p =
80
+ &input[((major_idx * p.in_h + in_y) * p.in_w + in_x) * p.minor_dim +
81
+ minor_idx];
82
+ const scalar_t *k_p = &kernel[kernel_y * p.kernel_w + kernel_x];
83
+ int x_px = p.minor_dim;
84
+ int k_px = -p.up_x;
85
+ int x_py = p.in_w * p.minor_dim;
86
+ int k_py = -p.up_y * p.kernel_w;
87
+
88
+ scalar_t v = 0.0f;
89
+
90
+ for (int y = 0; y < h; y++) {
91
+ for (int x = 0; x < w; x++) {
92
+ v += static_cast<scalar_t>(*x_p) * static_cast<scalar_t>(*k_p);
93
+ x_p += x_px;
94
+ k_p += k_px;
95
+ }
96
+
97
+ x_p += x_py - w * x_px;
98
+ k_p += k_py - w * k_px;
99
+ }
100
+
101
+ out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim +
102
+ minor_idx] = v;
103
+ }
104
+ }
105
+ }
106
+
107
+ template <typename scalar_t, int up_x, int up_y, int down_x, int down_y,
108
+ int kernel_h, int kernel_w, int tile_out_h, int tile_out_w>
109
+ __global__ void upfirdn2d_kernel(scalar_t *out, const scalar_t *input,
110
+ const scalar_t *kernel,
111
+ const UpFirDn2DKernelParams p) {
112
+ const int tile_in_h = ((tile_out_h - 1) * down_y + kernel_h - 1) / up_y + 1;
113
+ const int tile_in_w = ((tile_out_w - 1) * down_x + kernel_w - 1) / up_x + 1;
114
+
115
+ __shared__ volatile float sk[kernel_h][kernel_w];
116
+ __shared__ volatile float sx[tile_in_h][tile_in_w];
117
+
118
+ int minor_idx = blockIdx.x;
119
+ int tile_out_y = minor_idx / p.minor_dim;
120
+ minor_idx -= tile_out_y * p.minor_dim;
121
+ tile_out_y *= tile_out_h;
122
+ int tile_out_x_base = blockIdx.y * p.loop_x * tile_out_w;
123
+ int major_idx_base = blockIdx.z * p.loop_major;
124
+
125
+ if (tile_out_x_base >= p.out_w | tile_out_y >= p.out_h |
126
+ major_idx_base >= p.major_dim) {
127
+ return;
128
+ }
129
+
130
+ for (int tap_idx = threadIdx.x; tap_idx < kernel_h * kernel_w;
131
+ tap_idx += blockDim.x) {
132
+ int ky = tap_idx / kernel_w;
133
+ int kx = tap_idx - ky * kernel_w;
134
+ scalar_t v = 0.0;
135
+
136
+ if (kx < p.kernel_w & ky < p.kernel_h) {
137
+ v = kernel[(p.kernel_h - 1 - ky) * p.kernel_w + (p.kernel_w - 1 - kx)];
138
+ }
139
+
140
+ sk[ky][kx] = v;
141
+ }
142
+
143
+ for (int loop_major = 0, major_idx = major_idx_base;
144
+ loop_major < p.loop_major & major_idx < p.major_dim;
145
+ loop_major++, major_idx++) {
146
+ for (int loop_x = 0, tile_out_x = tile_out_x_base;
147
+ loop_x < p.loop_x & tile_out_x < p.out_w;
148
+ loop_x++, tile_out_x += tile_out_w) {
149
+ int tile_mid_x = tile_out_x * down_x + up_x - 1 - p.pad_x0;
150
+ int tile_mid_y = tile_out_y * down_y + up_y - 1 - p.pad_y0;
151
+ int tile_in_x = floor_div(tile_mid_x, up_x);
152
+ int tile_in_y = floor_div(tile_mid_y, up_y);
153
+
154
+ __syncthreads();
155
+
156
+ for (int in_idx = threadIdx.x; in_idx < tile_in_h * tile_in_w;
157
+ in_idx += blockDim.x) {
158
+ int rel_in_y = in_idx / tile_in_w;
159
+ int rel_in_x = in_idx - rel_in_y * tile_in_w;
160
+ int in_x = rel_in_x + tile_in_x;
161
+ int in_y = rel_in_y + tile_in_y;
162
+
163
+ scalar_t v = 0.0;
164
+
165
+ if (in_x >= 0 & in_y >= 0 & in_x < p.in_w & in_y < p.in_h) {
166
+ v = input[((major_idx * p.in_h + in_y) * p.in_w + in_x) *
167
+ p.minor_dim +
168
+ minor_idx];
169
+ }
170
+
171
+ sx[rel_in_y][rel_in_x] = v;
172
+ }
173
+
174
+ __syncthreads();
175
+ for (int out_idx = threadIdx.x; out_idx < tile_out_h * tile_out_w;
176
+ out_idx += blockDim.x) {
177
+ int rel_out_y = out_idx / tile_out_w;
178
+ int rel_out_x = out_idx - rel_out_y * tile_out_w;
179
+ int out_x = rel_out_x + tile_out_x;
180
+ int out_y = rel_out_y + tile_out_y;
181
+
182
+ int mid_x = tile_mid_x + rel_out_x * down_x;
183
+ int mid_y = tile_mid_y + rel_out_y * down_y;
184
+ int in_x = floor_div(mid_x, up_x);
185
+ int in_y = floor_div(mid_y, up_y);
186
+ int rel_in_x = in_x - tile_in_x;
187
+ int rel_in_y = in_y - tile_in_y;
188
+ int kernel_x = (in_x + 1) * up_x - mid_x - 1;
189
+ int kernel_y = (in_y + 1) * up_y - mid_y - 1;
190
+
191
+ scalar_t v = 0.0;
192
+
193
+ #pragma unroll
194
+ for (int y = 0; y < kernel_h / up_y; y++)
195
+ #pragma unroll
196
+ for (int x = 0; x < kernel_w / up_x; x++)
197
+ v += sx[rel_in_y + y][rel_in_x + x] *
198
+ sk[kernel_y + y * up_y][kernel_x + x * up_x];
199
+
200
+ if (out_x < p.out_w & out_y < p.out_h) {
201
+ out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim +
202
+ minor_idx] = v;
203
+ }
204
+ }
205
+ }
206
+ }
207
+ }
208
+
209
+ torch::Tensor upfirdn2d_op(const torch::Tensor &input,
210
+ const torch::Tensor &kernel, int up_x, int up_y,
211
+ int down_x, int down_y, int pad_x0, int pad_x1,
212
+ int pad_y0, int pad_y1) {
213
+ int curDevice = -1;
214
+ cudaGetDevice(&curDevice);
215
+ cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice);
216
+
217
+ UpFirDn2DKernelParams p;
218
+
219
+ auto x = input.contiguous();
220
+ auto k = kernel.contiguous();
221
+
222
+ p.major_dim = x.size(0);
223
+ p.in_h = x.size(1);
224
+ p.in_w = x.size(2);
225
+ p.minor_dim = x.size(3);
226
+ p.kernel_h = k.size(0);
227
+ p.kernel_w = k.size(1);
228
+ p.up_x = up_x;
229
+ p.up_y = up_y;
230
+ p.down_x = down_x;
231
+ p.down_y = down_y;
232
+ p.pad_x0 = pad_x0;
233
+ p.pad_x1 = pad_x1;
234
+ p.pad_y0 = pad_y0;
235
+ p.pad_y1 = pad_y1;
236
+
237
+ p.out_h = (p.in_h * p.up_y + p.pad_y0 + p.pad_y1 - p.kernel_h + p.down_y) /
238
+ p.down_y;
239
+ p.out_w = (p.in_w * p.up_x + p.pad_x0 + p.pad_x1 - p.kernel_w + p.down_x) /
240
+ p.down_x;
241
+
242
+ auto out =
243
+ at::empty({p.major_dim, p.out_h, p.out_w, p.minor_dim}, x.options());
244
+
245
+ int mode = -1;
246
+
247
+ int tile_out_h = -1;
248
+ int tile_out_w = -1;
249
+
250
+ if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 &&
251
+ p.kernel_h <= 4 && p.kernel_w <= 4) {
252
+ mode = 1;
253
+ tile_out_h = 16;
254
+ tile_out_w = 64;
255
+ }
256
+
257
+ if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 &&
258
+ p.kernel_h <= 3 && p.kernel_w <= 3) {
259
+ mode = 2;
260
+ tile_out_h = 16;
261
+ tile_out_w = 64;
262
+ }
263
+
264
+ if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 &&
265
+ p.kernel_h <= 4 && p.kernel_w <= 4) {
266
+ mode = 3;
267
+ tile_out_h = 16;
268
+ tile_out_w = 64;
269
+ }
270
+
271
+ if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 &&
272
+ p.kernel_h <= 2 && p.kernel_w <= 2) {
273
+ mode = 4;
274
+ tile_out_h = 16;
275
+ tile_out_w = 64;
276
+ }
277
+
278
+ if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 &&
279
+ p.kernel_h <= 4 && p.kernel_w <= 4) {
280
+ mode = 5;
281
+ tile_out_h = 8;
282
+ tile_out_w = 32;
283
+ }
284
+
285
+ if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 &&
286
+ p.kernel_h <= 2 && p.kernel_w <= 2) {
287
+ mode = 6;
288
+ tile_out_h = 8;
289
+ tile_out_w = 32;
290
+ }
291
+
292
+ dim3 block_size;
293
+ dim3 grid_size;
294
+
295
+ if (tile_out_h > 0 && tile_out_w > 0) {
296
+ p.loop_major = (p.major_dim - 1) / 16384 + 1;
297
+ p.loop_x = 1;
298
+ block_size = dim3(32 * 8, 1, 1);
299
+ grid_size = dim3(((p.out_h - 1) / tile_out_h + 1) * p.minor_dim,
300
+ (p.out_w - 1) / (p.loop_x * tile_out_w) + 1,
301
+ (p.major_dim - 1) / p.loop_major + 1);
302
+ } else {
303
+ p.loop_major = (p.major_dim - 1) / 16384 + 1;
304
+ p.loop_x = 4;
305
+ block_size = dim3(4, 32, 1);
306
+ grid_size = dim3((p.out_h * p.minor_dim - 1) / block_size.x + 1,
307
+ (p.out_w - 1) / (p.loop_x * block_size.y) + 1,
308
+ (p.major_dim - 1) / p.loop_major + 1);
309
+ }
310
+
311
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&] {
312
+ switch (mode) {
313
+ case 1:
314
+ upfirdn2d_kernel<scalar_t, 1, 1, 1, 1, 4, 4, 16, 64>
315
+ <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
316
+ x.data_ptr<scalar_t>(),
317
+ k.data_ptr<scalar_t>(), p);
318
+
319
+ break;
320
+
321
+ case 2:
322
+ upfirdn2d_kernel<scalar_t, 1, 1, 1, 1, 3, 3, 16, 64>
323
+ <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
324
+ x.data_ptr<scalar_t>(),
325
+ k.data_ptr<scalar_t>(), p);
326
+
327
+ break;
328
+
329
+ case 3:
330
+ upfirdn2d_kernel<scalar_t, 2, 2, 1, 1, 4, 4, 16, 64>
331
+ <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
332
+ x.data_ptr<scalar_t>(),
333
+ k.data_ptr<scalar_t>(), p);
334
+
335
+ break;
336
+
337
+ case 4:
338
+ upfirdn2d_kernel<scalar_t, 2, 2, 1, 1, 2, 2, 16, 64>
339
+ <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
340
+ x.data_ptr<scalar_t>(),
341
+ k.data_ptr<scalar_t>(), p);
342
+
343
+ break;
344
+
345
+ case 5:
346
+ upfirdn2d_kernel<scalar_t, 1, 1, 2, 2, 4, 4, 8, 32>
347
+ <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
348
+ x.data_ptr<scalar_t>(),
349
+ k.data_ptr<scalar_t>(), p);
350
+
351
+ break;
352
+
353
+ case 6:
354
+ upfirdn2d_kernel<scalar_t, 1, 1, 2, 2, 4, 4, 8, 32>
355
+ <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
356
+ x.data_ptr<scalar_t>(),
357
+ k.data_ptr<scalar_t>(), p);
358
+
359
+ break;
360
+
361
+ default:
362
+ upfirdn2d_kernel_large<scalar_t><<<grid_size, block_size, 0, stream>>>(
363
+ out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(),
364
+ k.data_ptr<scalar_t>(), p);
365
+ }
366
+ });
367
+
368
+ return out;
369
+ }
models/stylegan2/stylegan2-pytorch/ppl.py ADDED
@@ -0,0 +1,104 @@
1
+ import argparse
2
+
3
+ import torch
4
+ from torch.nn import functional as F
5
+ import numpy as np
6
+ from tqdm import tqdm
7
+
8
+ import lpips
9
+ from model import Generator
10
+
11
+
12
+ def normalize(x):
13
+ return x / torch.sqrt(x.pow(2).sum(-1, keepdim=True))
14
+
15
+
16
+ def slerp(a, b, t):
17
+ a = normalize(a)
18
+ b = normalize(b)
19
+ d = (a * b).sum(-1, keepdim=True)
20
+ p = t * torch.acos(d)
21
+ c = normalize(b - d * a)
22
+ d = a * torch.cos(p) + c * torch.sin(p)
23
+
24
+ return normalize(d)
25
+
26
+
27
+ def lerp(a, b, t):
28
+ return a + (b - a) * t
29
+
30
+
31
+ if __name__ == '__main__':
32
+ device = 'cuda'
33
+
34
+ parser = argparse.ArgumentParser()
35
+
36
+ parser.add_argument('--space', choices=['z', 'w'])
37
+ parser.add_argument('--batch', type=int, default=64)
38
+ parser.add_argument('--n_sample', type=int, default=5000)
39
+ parser.add_argument('--size', type=int, default=256)
40
+ parser.add_argument('--eps', type=float, default=1e-4)
41
+ parser.add_argument('--crop', action='store_true')
42
+ parser.add_argument('ckpt', metavar='CHECKPOINT')
43
+
44
+ args = parser.parse_args()
45
+
46
+ latent_dim = 512
47
+
48
+ ckpt = torch.load(args.ckpt)
49
+
50
+ g = Generator(args.size, latent_dim, 8).to(device)
51
+ g.load_state_dict(ckpt['g_ema'])
52
+ g.eval()
53
+
54
+ percept = lpips.PerceptualLoss(
55
+ model='net-lin', net='vgg', use_gpu=device.startswith('cuda')
56
+ )
57
+
58
+ distances = []
59
+
60
+ n_batch = args.n_sample // args.batch
61
+ resid = args.n_sample - (n_batch * args.batch)
62
+ batch_sizes = [args.batch] * n_batch + [resid]
63
+
64
+ with torch.no_grad():
65
+ for batch in tqdm(batch_sizes):
66
+ noise = g.make_noise()
67
+
68
+ inputs = torch.randn([batch * 2, latent_dim], device=device)
69
+ lerp_t = torch.rand(batch, device=device)
70
+
71
+ if args.space == 'w':
72
+ latent = g.get_latent(inputs)
73
+ latent_t0, latent_t1 = latent[::2], latent[1::2]
74
+ latent_e0 = lerp(latent_t0, latent_t1, lerp_t[:, None])
75
+ latent_e1 = lerp(latent_t0, latent_t1, lerp_t[:, None] + args.eps)
76
+ latent_e = torch.stack([latent_e0, latent_e1], 1).view(*latent.shape)
77
+
78
+ image, _ = g([latent_e], input_is_w=True, noise=noise)
79
+
80
+ if args.crop:
81
+ c = image.shape[2] // 8
82
+ image = image[:, :, c * 3 : c * 7, c * 2 : c * 6]
83
+
84
+ factor = image.shape[2] // 256
85
+
86
+ if factor > 1:
87
+ image = F.interpolate(
88
+ image, size=(256, 256), mode='bilinear', align_corners=False
89
+ )
90
+
91
+ dist = percept(image[::2], image[1::2]).view(image.shape[0] // 2) / (
92
+ args.eps ** 2
93
+ )
94
+ distances.append(dist.to('cpu').numpy())
95
+
96
+ distances = np.concatenate(distances, 0)
97
+
98
+ lo = np.percentile(distances, 1, interpolation='lower')
99
+ hi = np.percentile(distances, 99, interpolation='higher')
100
+ filtered_dist = np.extract(
101
+ np.logical_and(lo <= distances, distances <= hi), distances
102
+ )
103
+
104
+ print('ppl:', filtered_dist.mean())
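Per pair of endpoints, the distance accumulated above is LPIPS(G(lerp(w0, w1, t)), G(lerp(w0, w1, t + eps))) / eps**2, a finite-difference estimate of perceptual path length, and the final reduction trims the 1st/99th percentiles before averaging. The trimming step in isolation, shown on stand-in data:

import numpy as np

distances = np.abs(np.random.randn(1000))                # stand-in for the collected distances
lo = np.percentile(distances, 1, interpolation='lower')
hi = np.percentile(distances, 99, interpolation='higher')
ppl = distances[(distances >= lo) & (distances <= hi)].mean()
print(ppl)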
models/stylegan2/stylegan2-pytorch/prepare_data.py ADDED
@@ -0,0 +1,82 @@
1
+ import argparse
2
+ from io import BytesIO
3
+ import multiprocessing
4
+ from functools import partial
5
+
6
+ from PIL import Image
7
+ import lmdb
8
+ from tqdm import tqdm
9
+ from torchvision import datasets
10
+ from torchvision.transforms import functional as trans_fn
11
+
12
+
13
+ def resize_and_convert(img, size, resample, quality=100):
14
+ img = trans_fn.resize(img, size, resample)
15
+ img = trans_fn.center_crop(img, size)
16
+ buffer = BytesIO()
17
+ img.save(buffer, format='jpeg', quality=quality)
18
+ val = buffer.getvalue()
19
+
20
+ return val
21
+
22
+
23
+ def resize_multiple(img, sizes=(128, 256, 512, 1024), resample=Image.LANCZOS, quality=100):
24
+ imgs = []
25
+
26
+ for size in sizes:
27
+ imgs.append(resize_and_convert(img, size, resample, quality))
28
+
29
+ return imgs
30
+
31
+
32
+ def resize_worker(img_file, sizes, resample):
33
+ i, file = img_file
34
+ img = Image.open(file)
35
+ img = img.convert('RGB')
36
+ out = resize_multiple(img, sizes=sizes, resample=resample)
37
+
38
+ return i, out
39
+
40
+
41
+ def prepare(env, dataset, n_worker, sizes=(128, 256, 512, 1024), resample=Image.LANCZOS):
42
+ resize_fn = partial(resize_worker, sizes=sizes, resample=resample)
43
+
44
+ files = sorted(dataset.imgs, key=lambda x: x[0])
45
+ files = [(i, file) for i, (file, label) in enumerate(files)]
46
+ total = 0
47
+
48
+ with multiprocessing.Pool(n_worker) as pool:
49
+ for i, imgs in tqdm(pool.imap_unordered(resize_fn, files)):
50
+ for size, img in zip(sizes, imgs):
51
+ key = f'{size}-{str(i).zfill(5)}'.encode('utf-8')
52
+
53
+ with env.begin(write=True) as txn:
54
+ txn.put(key, img)
55
+
56
+ total += 1
57
+
58
+ with env.begin(write=True) as txn:
59
+ txn.put('length'.encode('utf-8'), str(total).encode('utf-8'))
60
+
61
+
62
+ if __name__ == '__main__':
63
+ parser = argparse.ArgumentParser()
64
+ parser.add_argument('--out', type=str)
65
+ parser.add_argument('--size', type=str, default='128,256,512,1024')
66
+ parser.add_argument('--n_worker', type=int, default=8)
67
+ parser.add_argument('--resample', type=str, default='lanczos')
68
+ parser.add_argument('path', type=str)
69
+
70
+ args = parser.parse_args()
71
+
72
+ resample_map = {'lanczos': Image.LANCZOS, 'bilinear': Image.BILINEAR}
73
+ resample = resample_map[args.resample]
74
+
75
+ sizes = [int(s.strip()) for s in args.size.split(',')]
76
+
77
+ print('Making dataset at image sizes:', ', '.join(str(s) for s in sizes))
78
+
79
+ imgset = datasets.ImageFolder(args.path)
80
+
81
+ with lmdb.open(args.out, map_size=1024 ** 4, readahead=False) as env:
82
+ prepare(env, imgset, args.n_worker, sizes=sizes, resample=resample)
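The resulting LMDB stores one JPEG per (size, index) pair under keys of the form f'{size}-{index:05d}', plus a 'length' entry recording the total. An illustrative sketch of reading one entry back, assuming a hypothetical output path 'out_lmdb' and that 256 was among the prepared sizes:

from io import BytesIO
import lmdb
from PIL import Image

with lmdb.open('out_lmdb', readonly=True, lock=False, readahead=False) as env:
    with env.begin(write=False) as txn:
        length = int(txn.get('length'.encode('utf-8')).decode('utf-8'))
        img_bytes = txn.get(f'{256}-{str(0).zfill(5)}'.encode('utf-8'))

img = Image.open(BytesIO(img_bytes))
print(length, img.size)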
models/stylegan2/stylegan2-pytorch/projector.py ADDED
@@ -0,0 +1,203 @@
1
+ import argparse
2
+ import math
3
+ import os
4
+
5
+ import torch
6
+ from torch import optim
7
+ from torch.nn import functional as F
8
+ from torchvision import transforms
9
+ from PIL import Image
10
+ from tqdm import tqdm
11
+
12
+ import lpips
13
+ from model import Generator
14
+
15
+
16
+ def noise_regularize(noises):
17
+ loss = 0
18
+
19
+ for noise in noises:
20
+ size = noise.shape[2]
21
+
22
+ while True:
23
+ loss = (
24
+ loss
25
+ + (noise * torch.roll(noise, shifts=1, dims=3)).mean().pow(2)
26
+ + (noise * torch.roll(noise, shifts=1, dims=2)).mean().pow(2)
27
+ )
28
+
29
+ if size <= 8:
30
+ break
31
+
32
+ noise = noise.reshape([1, 1, size // 2, 2, size // 2, 2])
33
+ noise = noise.mean([3, 5])
34
+ size //= 2
35
+
36
+ return loss
37
+
38
+
39
+ def noise_normalize_(noises):
40
+ for noise in noises:
41
+ mean = noise.mean()
42
+ std = noise.std()
43
+
44
+ noise.data.add_(-mean).div_(std)
45
+
46
+
47
+ def get_lr(t, initial_lr, rampdown=0.25, rampup=0.05):
48
+ lr_ramp = min(1, (1 - t) / rampdown)
49
+ lr_ramp = 0.5 - 0.5 * math.cos(lr_ramp * math.pi)
50
+ lr_ramp = lr_ramp * min(1, t / rampup)
51
+
52
+ return initial_lr * lr_ramp
53
+
54
+
55
+ def latent_noise(latent, strength):
56
+ noise = torch.randn_like(latent) * strength
57
+
58
+ return latent + noise
59
+
60
+
61
+ def make_image(tensor):
62
+ return (
63
+ tensor.detach()
64
+ .clamp_(min=-1, max=1)
65
+ .add(1)
66
+ .div_(2)
67
+ .mul(255)
68
+ .type(torch.uint8)
69
+ .permute(0, 2, 3, 1)
70
+ .to('cpu')
71
+ .numpy()
72
+ )
73
+
74
+
75
+ if __name__ == '__main__':
76
+ device = 'cuda'
77
+
78
+ parser = argparse.ArgumentParser()
79
+ parser.add_argument('--ckpt', type=str, required=True)
80
+ parser.add_argument('--size', type=int, default=256)
81
+ parser.add_argument('--lr_rampup', type=float, default=0.05)
82
+ parser.add_argument('--lr_rampdown', type=float, default=0.25)
83
+ parser.add_argument('--lr', type=float, default=0.1)
84
+ parser.add_argument('--noise', type=float, default=0.05)
85
+ parser.add_argument('--noise_ramp', type=float, default=0.75)
86
+ parser.add_argument('--step', type=int, default=1000)
87
+ parser.add_argument('--noise_regularize', type=float, default=1e5)
88
+ parser.add_argument('--mse', type=float, default=0)
89
+ parser.add_argument('--w_plus', action='store_true')
90
+ parser.add_argument('files', metavar='FILES', nargs='+')
91
+
92
+ args = parser.parse_args()
93
+
94
+ n_mean_latent = 10000
95
+
96
+ resize = min(args.size, 256)
97
+
98
+ transform = transforms.Compose(
99
+ [
100
+ transforms.Resize(resize),
101
+ transforms.CenterCrop(resize),
102
+ transforms.ToTensor(),
103
+ transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
104
+ ]
105
+ )
106
+
107
+ imgs = []
108
+
109
+ for imgfile in args.files:
110
+ img = transform(Image.open(imgfile).convert('RGB'))
111
+ imgs.append(img)
112
+
113
+ imgs = torch.stack(imgs, 0).to(device)
114
+
115
+ g_ema = Generator(args.size, 512, 8)
116
+ g_ema.load_state_dict(torch.load(args.ckpt)['g_ema'], strict=False)
117
+ g_ema.eval()
118
+ g_ema = g_ema.to(device)
119
+
120
+ with torch.no_grad():
121
+ noise_sample = torch.randn(n_mean_latent, 512, device=device)
122
+ latent_out = g_ema.style(noise_sample)
123
+
124
+ latent_mean = latent_out.mean(0)
125
+ latent_std = ((latent_out - latent_mean).pow(2).sum() / n_mean_latent) ** 0.5
126
+
127
+ percept = lpips.PerceptualLoss(
128
+ model='net-lin', net='vgg', use_gpu=device.startswith('cuda')
129
+ )
130
+
131
+ noises = g_ema.make_noise()
132
+
133
+ latent_in = latent_mean.detach().clone().unsqueeze(0).repeat(imgs.shape[0], 1)
134
+
135
+ if args.w_plus:
136
+ latent_in = latent_in.unsqueeze(1).repeat(1, g_ema.n_latent, 1)
137
+
138
+ latent_in.requires_grad = True
139
+
140
+ for noise in noises:
141
+ noise.requires_grad = True
142
+
143
+ optimizer = optim.Adam([latent_in] + noises, lr=args.lr)
144
+
145
+ pbar = tqdm(range(args.step))
146
+ latent_path = []
147
+
148
+ for i in pbar:
149
+ t = i / args.step
150
+ lr = get_lr(t, args.lr)
151
+ optimizer.param_groups[0]['lr'] = lr
152
+ noise_strength = latent_std * args.noise * max(0, 1 - t / args.noise_ramp) ** 2
153
+ latent_n = latent_noise(latent_in, noise_strength.item())
154
+
155
+ img_gen, _ = g_ema([latent_n], input_is_w=True, noise=noises)
156
+
157
+ batch, channel, height, width = img_gen.shape
158
+
159
+ if height > 256:
160
+ factor = height // 256
161
+
162
+ img_gen = img_gen.reshape(
163
+ batch, channel, height // factor, factor, width // factor, factor
164
+ )
165
+ img_gen = img_gen.mean([3, 5])
166
+
167
+ p_loss = percept(img_gen, imgs).sum()
168
+ n_loss = noise_regularize(noises)
169
+ mse_loss = F.mse_loss(img_gen, imgs)
170
+
171
+ loss = p_loss + args.noise_regularize * n_loss + args.mse * mse_loss
172
+
173
+ optimizer.zero_grad()
174
+ loss.backward()
175
+ optimizer.step()
176
+
177
+ noise_normalize_(noises)
178
+
179
+ if (i + 1) % 100 == 0:
180
+ latent_path.append(latent_in.detach().clone())
181
+
182
+ pbar.set_description(
183
+ (
184
+ f'perceptual: {p_loss.item():.4f}; noise regularize: {n_loss.item():.4f};'
185
+ f' mse: {mse_loss.item():.4f}; lr: {lr:.4f}'
186
+ )
187
+ )
188
+
189
+ result_file = {'noises': noises}
190
+
191
+ img_gen, _ = g_ema([latent_path[-1]], input_is_w=True, noise=noises)
192
+
193
+ filename = os.path.splitext(os.path.basename(args.files[0]))[0] + '.pt'
194
+
195
+ img_ar = make_image(img_gen)
196
+
197
+ for i, input_name in enumerate(args.files):
198
+ result_file[input_name] = {'img': img_gen[i], 'latent': latent_in[i]}
199
+ img_name = os.path.splitext(os.path.basename(input_name))[0] + '-project.png'
200
+ pil_img = Image.fromarray(img_ar[i])
201
+ pil_img.save(img_name)
202
+
203
+ torch.save(result_file, filename)
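The saved .pt maps each input path to its generated image and optimized latent, alongside the optimized noise maps. An illustrative sketch of loading it back (filenames here are hypothetical placeholders):

import torch

result = torch.load('photo.pt', map_location='cpu')      # hypothetical: <basename of first input>.pt
noises = result['noises']                                # list of optimized noise tensors
entry = result['photo.png']                              # keyed by the path passed on the command line
latent = entry['latent']                                 # [512], or [n_latent, 512] with --w_plus
print(latent.shape)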
models/stylegan2/stylegan2-pytorch/sample/.gitignore ADDED
@@ -0,0 +1 @@
1
+ *.png
models/stylegan2/stylegan2-pytorch/train.py ADDED
@@ -0,0 +1,413 @@
1
+ import argparse
2
+ import math
3
+ import random
4
+ import os
5
+
6
+ import numpy as np
7
+ import torch
8
+ from torch import nn, autograd, optim
9
+ from torch.nn import functional as F
10
+ from torch.utils import data
11
+ import torch.distributed as dist
12
+ from torchvision import transforms, utils
13
+ from tqdm import tqdm
14
+
15
+ try:
16
+ import wandb
17
+
18
+ except ImportError:
19
+ wandb = None
20
+
21
+ from model import Generator, Discriminator
22
+ from dataset import MultiResolutionDataset
23
+ from distributed import (
24
+ get_rank,
25
+ synchronize,
26
+ reduce_loss_dict,
27
+ reduce_sum,
28
+ get_world_size,
29
+ )
30
+
31
+
32
+ def data_sampler(dataset, shuffle, distributed):
33
+ if distributed:
34
+ return data.distributed.DistributedSampler(dataset, shuffle=shuffle)
35
+
36
+ if shuffle:
37
+ return data.RandomSampler(dataset)
38
+
39
+ else:
40
+ return data.SequentialSampler(dataset)
41
+
42
+
43
+ def requires_grad(model, flag=True):
44
+ for p in model.parameters():
45
+ p.requires_grad = flag
46
+
47
+
48
+ def accumulate(model1, model2, decay=0.999):
49
+ par1 = dict(model1.named_parameters())
50
+ par2 = dict(model2.named_parameters())
51
+
52
+ for k in par1.keys():
53
+ par1[k].data.mul_(decay).add_(1 - decay, par2[k].data)
54
+
55
+
56
+ def sample_data(loader):
57
+ while True:
58
+ for batch in loader:
59
+ yield batch
60
+
61
+
62
+ def d_logistic_loss(real_pred, fake_pred):
63
+ real_loss = F.softplus(-real_pred)
64
+ fake_loss = F.softplus(fake_pred)
65
+
66
+ return real_loss.mean() + fake_loss.mean()
67
+
68
+
69
+ def d_r1_loss(real_pred, real_img):
70
+ grad_real, = autograd.grad(
71
+ outputs=real_pred.sum(), inputs=real_img, create_graph=True
72
+ )
73
+ grad_penalty = grad_real.pow(2).view(grad_real.shape[0], -1).sum(1).mean()
74
+
75
+ return grad_penalty
76
+
77
+
78
+ def g_nonsaturating_loss(fake_pred):
79
+ loss = F.softplus(-fake_pred).mean()
80
+
81
+ return loss
82
+
83
+
84
+ def g_path_regularize(fake_img, latents, mean_path_length, decay=0.01):
85
+ noise = torch.randn_like(fake_img) / math.sqrt(
86
+ fake_img.shape[2] * fake_img.shape[3]
87
+ )
88
+ grad, = autograd.grad(
89
+ outputs=(fake_img * noise).sum(), inputs=latents, create_graph=True
90
+ )
91
+ path_lengths = torch.sqrt(grad.pow(2).sum(2).mean(1))
92
+
93
+ path_mean = mean_path_length + decay * (path_lengths.mean() - mean_path_length)
94
+
95
+ path_penalty = (path_lengths - path_mean).pow(2).mean()
96
+
97
+ return path_penalty, path_mean.detach(), path_lengths
98
+
99
+
100
+ def make_noise(batch, latent_dim, n_noise, device):
101
+ if n_noise == 1:
102
+ return torch.randn(batch, latent_dim, device=device)
103
+
104
+ noises = torch.randn(n_noise, batch, latent_dim, device=device).unbind(0)
105
+
106
+ return noises
107
+
108
+
109
+ def mixing_noise(batch, latent_dim, prob, device):
110
+ if prob > 0 and random.random() < prob:
111
+ return make_noise(batch, latent_dim, 2, device)
112
+
113
+ else:
114
+ return [make_noise(batch, latent_dim, 1, device)]
115
+
116
+
117
+ def set_grad_none(model, targets):
118
+ for n, p in model.named_parameters():
119
+ if n in targets:
120
+ p.grad = None
121
+
122
+
123
+ def train(args, loader, generator, discriminator, g_optim, d_optim, g_ema, device):
124
+ loader = sample_data(loader)
125
+
126
+ pbar = range(args.iter)
127
+
128
+ if get_rank() == 0:
129
+ pbar = tqdm(pbar, initial=args.start_iter, dynamic_ncols=True, smoothing=0.01)
130
+
131
+ mean_path_length = 0
132
+
133
+ d_loss_val = 0
134
+ r1_loss = torch.tensor(0.0, device=device)
135
+ g_loss_val = 0
136
+ path_loss = torch.tensor(0.0, device=device)
137
+ path_lengths = torch.tensor(0.0, device=device)
138
+ mean_path_length_avg = 0
139
+ loss_dict = {}
140
+
141
+ if args.distributed:
142
+ g_module = generator.module
143
+ d_module = discriminator.module
144
+
145
+ else:
146
+ g_module = generator
147
+ d_module = discriminator
148
+
149
+ accum = 0.5 ** (32 / (10 * 1000))
150
+
151
+ sample_z = torch.randn(args.n_sample, args.latent, device=device)
152
+
153
+ for idx in pbar:
154
+ i = idx + args.start_iter
155
+
156
+ if i > args.iter:
157
+ print("Done!")
158
+
159
+ break
160
+
161
+ real_img = next(loader)
162
+ real_img = real_img.to(device)
163
+
164
+ requires_grad(generator, False)
165
+ requires_grad(discriminator, True)
166
+
167
+ noise = mixing_noise(args.batch, args.latent, args.mixing, device)
168
+ fake_img, _ = generator(noise)
169
+ fake_pred = discriminator(fake_img)
170
+
171
+ real_pred = discriminator(real_img)
172
+ d_loss = d_logistic_loss(real_pred, fake_pred)
173
+
174
+ loss_dict["d"] = d_loss
175
+ loss_dict["real_score"] = real_pred.mean()
176
+ loss_dict["fake_score"] = fake_pred.mean()
177
+
178
+ discriminator.zero_grad()
179
+ d_loss.backward()
180
+ d_optim.step()
181
+
182
+ d_regularize = i % args.d_reg_every == 0
183
+
184
+ if d_regularize:
185
+ real_img.requires_grad = True
186
+ real_pred = discriminator(real_img)
187
+ r1_loss = d_r1_loss(real_pred, real_img)
188
+
189
+ discriminator.zero_grad()
190
+ (args.r1 / 2 * r1_loss * args.d_reg_every + 0 * real_pred[0]).backward()
191
+
192
+ d_optim.step()
193
+
194
+ loss_dict["r1"] = r1_loss
195
+
196
+ requires_grad(generator, True)
197
+ requires_grad(discriminator, False)
198
+
199
+ noise = mixing_noise(args.batch, args.latent, args.mixing, device)
200
+ fake_img, _ = generator(noise)
201
+ fake_pred = discriminator(fake_img)
202
+ g_loss = g_nonsaturating_loss(fake_pred)
203
+
204
+ loss_dict["g"] = g_loss
205
+
206
+ generator.zero_grad()
207
+ g_loss.backward()
208
+ g_optim.step()
209
+
210
+ g_regularize = i % args.g_reg_every == 0
211
+
212
+ if g_regularize:
213
+ path_batch_size = max(1, args.batch // args.path_batch_shrink)
214
+ noise = mixing_noise(path_batch_size, args.latent, args.mixing, device)
215
+ fake_img, latents = generator(noise, return_latents=True)
216
+
217
+ path_loss, mean_path_length, path_lengths = g_path_regularize(
218
+ fake_img, latents, mean_path_length
219
+ )
220
+
221
+ generator.zero_grad()
222
+ weighted_path_loss = args.path_regularize * args.g_reg_every * path_loss
223
+
224
+ if args.path_batch_shrink:
225
+ weighted_path_loss += 0 * fake_img[0, 0, 0, 0]
226
+
227
+ weighted_path_loss.backward()
228
+
229
+ g_optim.step()
230
+
231
+ mean_path_length_avg = (
232
+ reduce_sum(mean_path_length).item() / get_world_size()
233
+ )
234
+
235
+ loss_dict["path"] = path_loss
236
+ loss_dict["path_length"] = path_lengths.mean()
237
+
238
+ accumulate(g_ema, g_module, accum)
239
+
240
+ loss_reduced = reduce_loss_dict(loss_dict)
241
+
242
+ d_loss_val = loss_reduced["d"].mean().item()
243
+ g_loss_val = loss_reduced["g"].mean().item()
244
+ r1_val = loss_reduced["r1"].mean().item()
245
+ path_loss_val = loss_reduced["path"].mean().item()
246
+ real_score_val = loss_reduced["real_score"].mean().item()
247
+ fake_score_val = loss_reduced["fake_score"].mean().item()
248
+ path_length_val = loss_reduced["path_length"].mean().item()
249
+
250
+ if get_rank() == 0:
251
+ pbar.set_description(
252
+ (
253
+ f"d: {d_loss_val:.4f}; g: {g_loss_val:.4f}; r1: {r1_val:.4f}; "
254
+ f"path: {path_loss_val:.4f}; mean path: {mean_path_length_avg:.4f}"
255
+ )
256
+ )
257
+
258
+ if wandb and args.wandb:
259
+ wandb.log(
260
+ {
261
+ "Generator": g_loss_val,
262
+ "Discriminator": d_loss_val,
263
+ "R1": r1_val,
264
+ "Path Length Regularization": path_loss_val,
265
+ "Mean Path Length": mean_path_length,
266
+ "Real Score": real_score_val,
267
+ "Fake Score": fake_score_val,
268
+ "Path Length": path_length_val,
269
+ }
270
+ )
271
+
272
+ if i % 100 == 0:
273
+ with torch.no_grad():
274
+ g_ema.eval()
275
+ sample, _ = g_ema([sample_z])
276
+ utils.save_image(
277
+ sample,
278
+ f"sample/{str(i).zfill(6)}.png",
279
+ nrow=int(args.n_sample ** 0.5),
280
+ normalize=True,
281
+ range=(-1, 1),
282
+ )
283
+
+             if i % 10000 == 0:
+                 torch.save(
+                     {
+                         "g": g_module.state_dict(),
+                         "d": d_module.state_dict(),
+                         "g_ema": g_ema.state_dict(),
+                         "g_optim": g_optim.state_dict(),
+                         "d_optim": d_optim.state_dict(),
+                     },
+                     f"checkpoint/{str(i).zfill(6)}.pt",
+                 )
+
+
+ if __name__ == "__main__":
+     device = "cuda"
+
+     parser = argparse.ArgumentParser()
+
+     parser.add_argument("path", type=str)
+     parser.add_argument("--iter", type=int, default=800000)
+     parser.add_argument("--batch", type=int, default=16)
+     parser.add_argument("--n_sample", type=int, default=64)
+     parser.add_argument("--size", type=int, default=256)
+     parser.add_argument("--r1", type=float, default=10)
+     parser.add_argument("--path_regularize", type=float, default=2)
+     parser.add_argument("--path_batch_shrink", type=int, default=2)
+     parser.add_argument("--d_reg_every", type=int, default=16)
+     parser.add_argument("--g_reg_every", type=int, default=4)
+     parser.add_argument("--mixing", type=float, default=0.9)
+     parser.add_argument("--ckpt", type=str, default=None)
+     parser.add_argument("--lr", type=float, default=0.002)
+     parser.add_argument("--channel_multiplier", type=int, default=2)
+     parser.add_argument("--wandb", action="store_true")
+     parser.add_argument("--local_rank", type=int, default=0)
+
+     args = parser.parse_args()
+
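+     # torch.distributed.launch / torchrun export WORLD_SIZE; running with more
+     # than one process enables distributed training.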
+     n_gpu = int(os.environ["WORLD_SIZE"]) if "WORLD_SIZE" in os.environ else 1
+     args.distributed = n_gpu > 1
+
+     if args.distributed:
+         torch.cuda.set_device(args.local_rank)
+         torch.distributed.init_process_group(backend="nccl", init_method="env://")
+         synchronize()
+
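+     # Fixed StyleGAN2 architecture settings: 512-dimensional latent code and an
+     # 8-layer mapping network.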
+     args.latent = 512
+     args.n_mlp = 8
+
+     args.start_iter = 0
+
+     generator = Generator(
+         args.size, args.latent, args.n_mlp, channel_multiplier=args.channel_multiplier
+     ).to(device)
+     discriminator = Discriminator(
+         args.size, channel_multiplier=args.channel_multiplier
+     ).to(device)
+     g_ema = Generator(
+         args.size, args.latent, args.n_mlp, channel_multiplier=args.channel_multiplier
+     ).to(device)
+     g_ema.eval()
+     accumulate(g_ema, generator, 0)
+
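+     # Lazy-regularization correction: since the regularizers run only every N steps,
+     # the learning rate is scaled by N / (N + 1) and the Adam betas are raised to that power.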
+     g_reg_ratio = args.g_reg_every / (args.g_reg_every + 1)
+     d_reg_ratio = args.d_reg_every / (args.d_reg_every + 1)
+
+     g_optim = optim.Adam(
+         generator.parameters(),
+         lr=args.lr * g_reg_ratio,
+         betas=(0 ** g_reg_ratio, 0.99 ** g_reg_ratio),
+     )
+     d_optim = optim.Adam(
+         discriminator.parameters(),
+         lr=args.lr * d_reg_ratio,
+         betas=(0 ** d_reg_ratio, 0.99 ** d_reg_ratio),
+     )
+
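+     # Optionally resume training; the starting iteration is parsed from the checkpoint file name.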
+     if args.ckpt is not None:
+         print("load model:", args.ckpt)
+
+         ckpt = torch.load(args.ckpt, map_location=lambda storage, loc: storage)
+
+         try:
+             ckpt_name = os.path.basename(args.ckpt)
+             args.start_iter = int(os.path.splitext(ckpt_name)[0])
+
+         except ValueError:
+             pass
+
+         generator.load_state_dict(ckpt["g"])
+         discriminator.load_state_dict(ckpt["d"])
+         g_ema.load_state_dict(ckpt["g_ema"])
+
+         g_optim.load_state_dict(ckpt["g_optim"])
+         d_optim.load_state_dict(ckpt["d_optim"])
+
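+     # Wrap both networks for multi-GPU training. broadcast_buffers=False avoids
+     # re-broadcasting module buffers (such as the generator's fixed noise maps)
+     # from rank 0 on every forward pass.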
+     if args.distributed:
+         generator = nn.parallel.DistributedDataParallel(
+             generator,
+             device_ids=[args.local_rank],
+             output_device=args.local_rank,
+             broadcast_buffers=False,
+         )
+
+         discriminator = nn.parallel.DistributedDataParallel(
+             discriminator,
+             device_ids=[args.local_rank],
+             output_device=args.local_rank,
+             broadcast_buffers=False,
+         )
+
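+     # Data pipeline: random horizontal flips and per-channel normalization from [0, 1] to [-1, 1].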
+     transform = transforms.Compose(
+         [
+             transforms.RandomHorizontalFlip(),
+             transforms.ToTensor(),
+             transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True),
+         ]
+     )
+
+     dataset = MultiResolutionDataset(args.path, transform, args.size)
+     loader = data.DataLoader(
+         dataset,
+         batch_size=args.batch,
+         sampler=data_sampler(dataset, shuffle=True, distributed=args.distributed),
+         drop_last=True,
+     )
+
+     if get_rank() == 0 and wandb is not None and args.wandb:
+         wandb.init(project="stylegan 2")
+
+     train(args, loader, generator, discriminator, g_optim, d_optim, g_ema, device)