vitaliykinakh committed on
Commit
8d6cd57
•
1 Parent(s): 931138c
.gitignore ADDED
@@ -0,0 +1,351 @@
+
+ # Created by https://www.toptal.com/developers/gitignore/api/python,pycharm,jupyternotebooks,images
+ # Edit at https://www.toptal.com/developers/gitignore?templates=python,pycharm,jupyternotebooks,images
+
+ .idea/
+ models/InfoSCC-GAN/
+ models/BigGAN/
+ models/CVAE
+ *.csv
+
+
+ ### Images ###
+ # JPEG
+ *.jpg
+ *.jpeg
+ *.jpe
+ *.jif
+ *.jfif
+ *.jfi
+
+ # JPEG 2000
+ *.jp2
+ *.j2k
+ *.jpf
+ *.jpx
+ *.jpm
+ *.mj2
+
+ # JPEG XR
+ *.jxr
+ *.hdp
+ *.wdp
+
+ # Graphics Interchange Format
+ *.gif
+
+ # RAW
+ *.raw
+
+ # Web P
+ *.webp
+
+ # Portable Network Graphics
+ *.png
+
+ # Animated Portable Network Graphics
+ *.apng
+
+ # Multiple-image Network Graphics
+ *.mng
+
+ # Tagged Image File Format
+ *.tiff
+ *.tif
+
+ # Scalable Vector Graphics
+ *.svg
+ *.svgz
+
+ # Portable Document Format
+ *.pdf
+
+ # X BitMap
+ *.xbm
+
+ # BMP
+ *.bmp
+ *.dib
+
+ # ICO
+ *.ico
+
+ # 3D Images
+ *.3dm
+ *.max
+
+ ### JupyterNotebooks ###
+ # gitignore template for Jupyter Notebooks
+ # website: http://jupyter.org/
+
+ .ipynb_checkpoints
+ */.ipynb_checkpoints/*
+
+ # IPython
+ profile_default/
+ ipython_config.py
+
+ # Remove previous ipynb_checkpoints
+ # git rm -r .ipynb_checkpoints/
+
+ ### PyCharm ###
+ # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio, WebStorm and Rider
+ # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839
+
+ # User-specific stuff
+ .idea/**/workspace.xml
+ .idea/**/tasks.xml
+ .idea/**/usage.statistics.xml
+ .idea/**/dictionaries
+ .idea/**/shelf
+
+ # AWS User-specific
+ .idea/**/aws.xml
+
+ # Generated files
+ .idea/**/contentModel.xml
+
+ # Sensitive or high-churn files
+ .idea/**/dataSources/
+ .idea/**/dataSources.ids
+ .idea/**/dataSources.local.xml
+ .idea/**/sqlDataSources.xml
+ .idea/**/dynamic.xml
+ .idea/**/uiDesigner.xml
+ .idea/**/dbnavigator.xml
+
+ # Gradle
+ .idea/**/gradle.xml
+ .idea/**/libraries
+
+ # Gradle and Maven with auto-import
+ # When using Gradle or Maven with auto-import, you should exclude module files,
+ # since they will be recreated, and may cause churn. Uncomment if using
+ # auto-import.
+ # .idea/artifacts
+ # .idea/compiler.xml
+ # .idea/jarRepositories.xml
+ # .idea/modules.xml
+ # .idea/*.iml
+ # .idea/modules
+ # *.iml
+ # *.ipr
+
+ # CMake
+ cmake-build-*/
+
+ # Mongo Explorer plugin
+ .idea/**/mongoSettings.xml
+
+ # File-based project format
+ *.iws
+
+ # IntelliJ
+ out/
+
+ # mpeltonen/sbt-idea plugin
+ .idea_modules/
+
+ # JIRA plugin
+ atlassian-ide-plugin.xml
+
+ # Cursive Clojure plugin
+ .idea/replstate.xml
+
+ # SonarLint plugin
+ .idea/sonarlint/
+
+ # Crashlytics plugin (for Android Studio and IntelliJ)
+ com_crashlytics_export_strings.xml
+ crashlytics.properties
+ crashlytics-build.properties
+ fabric.properties
+
+ # Editor-based Rest Client
+ .idea/httpRequests
+
+ # Android studio 3.1+ serialized cache file
+ .idea/caches/build_file_checksums.ser
+
+ ### PyCharm Patch ###
+ # Comment Reason: https://github.com/joeblau/gitignore.io/issues/186#issuecomment-215987721
+
+ # *.iml
+ # modules.xml
+ # .idea/misc.xml
+ # *.ipr
+
+ # Sonarlint plugin
+ # https://plugins.jetbrains.com/plugin/7973-sonarlint
+ .idea/**/sonarlint/
+
+ # SonarQube Plugin
+ # https://plugins.jetbrains.com/plugin/7238-sonarqube-community-plugin
+ .idea/**/sonarIssues.xml
+
+ # Markdown Navigator plugin
+ # https://plugins.jetbrains.com/plugin/7896-markdown-navigator-enhanced
+ .idea/**/markdown-navigator.xml
+ .idea/**/markdown-navigator-enh.xml
+ .idea/**/markdown-navigator/
+
+ # Cache file creation bug
+ # See https://youtrack.jetbrains.com/issue/JBR-2257
+ .idea/$CACHE_FILE$
+
+ # CodeStream plugin
+ # https://plugins.jetbrains.com/plugin/12206-codestream
+ .idea/codestream.xml
+
+ ### Python ###
+ # Byte-compiled / optimized / DLL files
+ __pycache__/
+ *.py[cod]
+ *$py.class
+
+ # C extensions
+ *.so
+
+ # Distribution / packaging
+ .Python
+ build/
+ develop-eggs/
+ dist/
+ downloads/
+ eggs/
+ .eggs/
+ lib/
+ lib64/
+ parts/
+ sdist/
+ var/
+ wheels/
+ share/python-wheels/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+ MANIFEST
+
+ # PyInstaller
+ # Usually these files are written by a python script from a template
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
+ *.manifest
+ *.spec
+
+ # Installer logs
+ pip-log.txt
+ pip-delete-this-directory.txt
+
+ # Unit test / coverage reports
+ htmlcov/
+ .tox/
+ .nox/
+ .coverage
+ .coverage.*
+ .cache
+ nosetests.xml
+ coverage.xml
+ *.cover
+ *.py,cover
+ .hypothesis/
+ .pytest_cache/
+ cover/
+
+ # Translations
+ *.mo
+ *.pot
+
+ # Django stuff:
+ *.log
+ local_settings.py
+ db.sqlite3
+ db.sqlite3-journal
+
+ # Flask stuff:
+ instance/
+ .webassets-cache
+
+ # Scrapy stuff:
+ .scrapy
+
+ # Sphinx documentation
+ docs/_build/
+
+ # PyBuilder
+ .pybuilder/
+ target/
+
+ # Jupyter Notebook
+
+ # IPython
+
+ # pyenv
+ # For a library or package, you might want to ignore these files since the code is
+ # intended to run in multiple environments; otherwise, check them in:
+ # .python-version
+
+ # pipenv
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
+ # install all needed dependencies.
+ #Pipfile.lock
+
+ # poetry
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
+ # commonly ignored for libraries.
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
+ #poetry.lock
+
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow
+ __pypackages__/
+
+ # Celery stuff
+ celerybeat-schedule
+ celerybeat.pid
+
+ # SageMath parsed files
+ *.sage.py
+
+ # Environments
+ .env
+ .venv
+ env/
+ venv/
+ ENV/
+ env.bak/
+ venv.bak/
+
+ # Spyder project settings
+ .spyderproject
+ .spyproject
+
+ # Rope project settings
+ .ropeproject
+
+ # mkdocs documentation
+ /site
+
+ # mypy
+ .mypy_cache/
+ .dmypy.json
+ dmypy.json
+
+ # Pyre type checker
+ .pyre/
+
+ # pytype static type analyzer
+ .pytype/
+
+ # Cython debug symbols
+ cython_debug/
+
+ # PyCharm
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
+ #.idea/
+
+ # End of https://www.toptal.com/developers/gitignore/api/python,pycharm,jupyternotebooks,images
app.py ADDED
@@ -0,0 +1,20 @@
+ import streamlit as st
+
+ # Custom imports
+ from src.app import MultiPage
+ from src.app import explore_infoscc_gan, explore_biggan, explore_cvae, compare_models
+
+ # Create an instance of the app
+ app = MultiPage()
+
+ # Title of the main page
+ st.title('Galaxy Zoo generation')
+
+ # Add all your applications (pages) here
+ app.add_page('Compare models', compare_models.app)
+ app.add_page('Explore BigGAN', explore_biggan.app)
+ app.add_page('Explore cVAE', explore_cvae.app)
+ app.add_page('Explore InfoSCC-GAN', explore_infoscc_gan.app)
+
+ # The main app
+ app.run()
requirements.txt ADDED
@@ -0,0 +1,7 @@
+ -f https://download.pytorch.org/whl/torch_stable.html
+ gdown
+ googledrivedownloader==0.4
+ pandas==1.4.1
+ streamlit==1.7.0
+ torch==1.9.1+cpu
+ torchvision==0.10.1+cpu
src/app/__init__.py ADDED
@@ -0,0 +1 @@
+ from .multipage import MultiPage
src/app/compare_models.py ADDED
@@ -0,0 +1,292 @@
+ from pathlib import Path
+ import math
+
+ import streamlit as st
+ import numpy as np
+
+ import torch
+ import torch.nn.functional as F
+
+ import src.app.params as params
+ from src.app.questions import q1, q1_options, q2, q2_options, q3, q3_options, q4, q4_options, q5, q5_options, \
+     q6, q6_options, q7, q7_options, q8, q8_options, q9, q9_options, q10, q10_options, q11, q11_options
+ from src.models import ConditionalGenerator as InfoSCC_GAN
+ from src.models.big.BigGAN2 import Generator as BigGAN2Generator
+ from src.models import ConditionalDecoder as cVAE
+ from src.data import get_labels_train, make_galaxy_labels_hierarchical
+ from src.utils import download_file, sample_labels
+
+
+ device = params.device
+ bs = 10  # number of images to generate per model
+ n_cols = int(math.sqrt(bs))
+ size = params.size
+ n_layers = int(math.log2(size) - 2)
+
+ # manual labels
+ q1_out = [0] * len(q1_options)
+ q2_out = [0] * len(q2_options)
+ q3_out = [0] * len(q3_options)
+ q4_out = [0] * len(q4_options)
+ q5_out = [0] * len(q5_options)
+ q6_out = [0] * len(q6_options)
+ q7_out = [0] * len(q7_options)
+ q8_out = [0] * len(q8_options)
+ q9_out = [0] * len(q9_options)
+ q10_out = [0] * len(q10_options)
+ q11_out = [0] * len(q11_options)
+
+
+ def clear_out(elems=None):
+     global q1_out, q2_out, q3_out, q4_out, q5_out, q6_out, q7_out, q8_out, q9_out, q10_out, q11_out
+
+     if elems is None:
+         elems = list(range(1, 12))
+
+     if 1 in elems:
+         q1_out = [0] * len(q1_options)
+     if 2 in elems:
+         q2_out = [0] * len(q2_options)
+     if 3 in elems:
+         q3_out = [0] * len(q3_options)
+     if 4 in elems:
+         q4_out = [0] * len(q4_options)
+     if 5 in elems:
+         q5_out = [0] * len(q5_options)
+     if 6 in elems:
+         q6_out = [0] * len(q6_options)
+     if 7 in elems:
+         q7_out = [0] * len(q7_options)
+     if 8 in elems:
+         q8_out = [0] * len(q8_options)
+     if 9 in elems:
+         q9_out = [0] * len(q9_options)
+     if 10 in elems:
+         q10_out = [0] * len(q10_options)
+     if 11 in elems:
+         q11_out = [0] * len(q11_options)
+
+
+ @st.cache(allow_output_mutation=True)
+ def load_model(model_type: str):
+
+     print(f'Loading model: {model_type}')
+     if model_type == 'InfoSCC-GAN':
+         g = InfoSCC_GAN(size=params.size,
+                         y_size=params.shape_label,
+                         z_size=params.noise_dim)
+
+         if not Path(params.path_infoscc_gan).exists():
+             download_file(params.drive_id_infoscc_gan, params.path_infoscc_gan)
+
+         ckpt = torch.load(params.path_infoscc_gan, map_location=torch.device('cpu'))
+         g.load_state_dict(ckpt['g_ema'])
+     elif model_type == 'BigGAN':
+         g = BigGAN2Generator()
+
+         if not Path(params.path_biggan).exists():
+             download_file(params.drive_id_biggan, params.path_biggan)
+
+         ckpt = torch.load(params.path_biggan, map_location=torch.device('cpu'))
+         g.load_state_dict(ckpt)
+     elif model_type == 'cVAE':
+         g = cVAE()
+
+         if not Path(params.path_cvae).exists():
+             download_file(params.drive_id_cvae, params.path_cvae)
+
+         ckpt = torch.load(params.path_cvae, map_location=torch.device('cpu'))
+         g.load_state_dict(ckpt)
+     else:
+         raise ValueError('Unsupported model')
+     g = g.eval().to(device=params.device)
+     return g
+
+
+ @st.cache
+ def get_labels() -> torch.Tensor:
+     path_labels = params.path_labels
+
+     if not Path(path_labels).exists():
+         download_file(params.drive_id_labels, path_labels)
+
+     labels_train = get_labels_train(path_labels)
+     return labels_train
+
+
+ def get_eps(n: int) -> torch.Tensor:
+     eps = torch.randn((n, params.dim_z), device=device)
+     return eps
+
+
+ def app():
+     global q1_out, q2_out, q3_out, q4_out, q5_out, q6_out, q7_out, q8_out, q9_out, q10_out, q11_out
+
+     st.title('Compare models')
+     st.markdown('This demo allows you to compare the BigGAN, InfoSCC-GAN and cVAE models for conditional galaxy generation.')
+     st.markdown('Each column shows images generated with the same labels by one of the models')
+
+     biggan = load_model('BigGAN')
+     infoscc_gan = load_model('InfoSCC-GAN')
+     cvae = load_model('cVAE')
+     labels_train = get_labels()
+
+     eps = get_eps(bs)  # for BigGAN and cVAE
+     eps_infoscc = infoscc_gan.sample_eps(bs)
+
+     zs = np.array([[0.0] * params.n_basis] * n_layers, dtype=np.float32)
+     zs_torch = torch.from_numpy(zs).unsqueeze(0).repeat(bs, 1, 1).to(device)
+
+     # ========================== Labels ================================
+     st.subheader('Label')
+     st.markdown(r'There are two ways of selecting labels: __Random__ - sample random labels from the dataset;'
+                 r' __Manual__ - select labels manually (advanced use). When using __Manual__, all of the images will be'
+                 r' generated with the same labels')
+     label_type = st.radio('Label type', options=['Random', 'Manual (Advanced)'])
+     if label_type == 'Random':
+         labels = sample_labels(labels_train, bs).to(device)
+
+         st.markdown(r'Click on the __Sample labels__ button to sample random input labels')
+         change_label = st.button('Sample labels')
+
+         if change_label:
+             labels = sample_labels(labels_train, bs).to(device)
+     elif label_type == 'Manual (Advanced)':
+         st.markdown('Answer the questions below')
+
+         q1_select_box = st.selectbox(q1, options=q1_options)
+         clear_out()
+         q1_out[q1_options.index(q1_select_box)] = 1
+         # 1
+
+         if q1_select_box == 'Smooth':
+             q7_select_box = st.selectbox(q7, options=q7_options)
+             clear_out([2, 3, 4, 5, 6, 7, 8, 9, 10, 11])
+             q7_out[q7_options.index(q7_select_box)] = 1
+             # 1 - 7
+
+             q6_select_box = st.selectbox(q6, options=q6_options)
+             clear_out([2, 3, 4, 5, 6, 8, 9, 10, 11])
+             q6_out[q6_options.index(q6_select_box)] = 1
+             # 1 - 7 - 6
+
+             if q6_select_box == 'Yes':
+                 q8_select_box = st.selectbox(q8, options=q8_options)
+                 clear_out([2, 3, 4, 5, 8, 9, 10, 11])
+                 q8_out[q8_options.index(q8_select_box)] = 1
+                 # 1 - 7 - 6 - 8 - end
+
+         elif q1_select_box == 'Features or disk':
+             q2_select_box = st.selectbox(q2, options=q2_options)
+             clear_out([2, 3, 4, 5, 6, 7, 8, 9, 10, 11])
+             q2_out[q2_options.index(q2_select_box)] = 1
+             # 1 - 2
+
+             if q2_select_box == 'Yes':
+                 q9_select_box = st.selectbox(q9, options=q9_options)
+                 clear_out([3, 4, 5, 6, 7, 8, 9, 10, 11])
+                 q9_out[q9_options.index(q9_select_box)] = 1
+                 # 1 - 2 - 9
+
+                 q6_select_box = st.selectbox(q6, options=q6_options)
+                 clear_out([3, 4, 5, 6, 7, 8, 10, 11])
+                 q6_out[q6_options.index(q6_select_box)] = 1
+                 # 1 - 2 - 9 - 6
+
+                 if q6_select_box == 'Yes':
+                     q8_select_box = st.selectbox(q8, options=q8_options)
+                     clear_out([3, 4, 5, 7, 8, 10, 11])
+                     q8_out[q8_options.index(q8_select_box)] = 1
+                     # 1 - 2 - 9 - 6 - 8
+             else:
+                 q3_select_box = st.selectbox(q3, options=q3_options)
+                 clear_out([3, 4, 5, 6, 7, 8, 9, 10, 11])
+                 q3_out[q3_options.index(q3_select_box)] = 1
+                 # 1 - 2 - 3
+
+                 q4_select_box = st.selectbox(q4, options=q4_options)
+                 clear_out([4, 5, 6, 7, 8, 9, 10, 11])
+                 q4_out[q4_options.index(q4_select_box)] = 1
+                 # 1 - 2 - 3 - 4
+
+                 if q4_select_box == 'Yes':
+                     q10_select_box = st.selectbox(q10, options=q10_options)
+                     clear_out([5, 6, 7, 8, 9, 10, 11])
+                     q10_out[q10_options.index(q10_select_box)] = 1
+                     # 1 - 2 - 3 - 4 - 10
+
+                     q11_select_box = st.selectbox(q11, options=q11_options)
+                     clear_out([5, 6, 7, 8, 9, 11])
+                     q11_out[q11_options.index(q11_select_box)] = 1
+                     # 1 - 2 - 3 - 4 - 10 - 11
+
+                     q5_select_box = st.selectbox(q5, options=q5_options)
+                     clear_out([5, 6, 7, 8, 9])
+                     q5_out[q5_options.index(q5_select_box)] = 1
+                     # 1 - 2 - 3 - 4 - 10 - 11 - 5
+
+                     q6_select_box = st.selectbox(q6, options=q6_options)
+                     clear_out([6, 7, 8, 9])
+                     q6_out[q6_options.index(q6_select_box)] = 1
+                     # 1 - 2 - 3 - 4 - 10 - 11 - 5 - 6
+
+                     if q6_select_box == 'Yes':
+                         q8_select_box = st.selectbox(q8, options=q8_options)
+                         clear_out([7, 8, 9])
+                         q8_out[q8_options.index(q8_select_box)] = 1
+                         # 1 - 2 - 3 - 4 - 10 - 11 - 5 - 6 - 8 - End
+                 else:
+                     q5_select_box = st.selectbox(q5, options=q5_options)
+                     clear_out([5, 6, 7, 8, 9, 10, 11])
+                     q5_out[q5_options.index(q5_select_box)] = 1
+                     # 1 - 2 - 3 - 4 - 5
+
+                     q6_select_box = st.selectbox(q6, options=q6_options)
+                     clear_out([6, 7, 8, 9, 10, 11])
+                     q6_out[q6_options.index(q6_select_box)] = 1
+                     # 1 - 2 - 3 - 4 - 5 - 6
+
+                     if q6_select_box == 'Yes':
+                         q8_select_box = st.selectbox(q8, options=q8_options)
+                         clear_out([7, 8, 9, 10, 11])
+                         q8_out[q8_options.index(q8_select_box)] = 1
+                         # 1 - 2 - 3 - 4 - 5 - 6 - 8 - End
+
+         labels = [*q1_out, *q2_out, *q3_out, *q4_out, *q5_out, *q6_out, *q7_out, *q8_out, *q9_out, *q10_out, *q11_out]
+         labels = torch.Tensor(labels).to(device)
+         labels = labels.unsqueeze(0).repeat(bs, 1)
+         labels = make_galaxy_labels_hierarchical(labels)
+         clear_out()
+     # ========================== Labels ================================
+
+     st.subheader('Noise')
+     st.markdown(r'Click on the __Change eps__ button to resample the input $\varepsilon$ latent vectors')
+     change_eps = st.button('Change eps')
+     if change_eps:
+         eps = get_eps(bs)  # for BigGAN and cVAE
+         eps_infoscc = infoscc_gan.sample_eps(bs)
+
+     with torch.no_grad():
+         imgs_biggan = biggan(eps, labels).squeeze(0).cpu()
+         imgs_infoscc = infoscc_gan(labels, eps_infoscc, zs_torch).squeeze(0).cpu()
+         imgs_cvae = cvae(eps, labels).squeeze(0).cpu()
+
+     if params.upsample:
+         imgs_biggan = F.interpolate(imgs_biggan, (size * 4, size * 4), mode='bicubic')
+         imgs_infoscc = F.interpolate(imgs_infoscc, (size * 4, size * 4), mode='bicubic')
+         imgs_cvae = F.interpolate(imgs_cvae, (size * 4, size * 4), mode='bicubic')
+
+     imgs_biggan = torch.clip(imgs_biggan, 0, 1)
+     imgs_biggan = [(imgs_biggan[i].permute(1, 2, 0).numpy() * 255).astype(np.uint8) for i in range(bs)]
+     imgs_infoscc = [(imgs_infoscc[i].permute(1, 2, 0).numpy() * 127.5 + 127.5).astype(np.uint8) for i in range(bs)]
+     imgs_cvae = [(imgs_cvae[i].permute(1, 2, 0).numpy() * 127.5 + 127.5).astype(np.uint8) for i in range(bs)]
+
+     c1, c2, c3 = st.columns(3)
+     c1.header('BigGAN')
+     c1.image(imgs_biggan, use_column_width=True)
+
+     c2.header('InfoSCC-GAN')
+     c2.image(imgs_infoscc, use_column_width=True)
+
+     c3.header('cVAE')
+     c3.image(imgs_cvae, use_column_width=True)
src/app/explore_biggan.py ADDED
@@ -0,0 +1,256 @@
+ import math
+ from pathlib import Path
+
+ import streamlit as st
+ import numpy as np
+
+ import torch
+ import torch.nn.functional as F
+
+ import src.app.params as params
+ from src.app.questions import q1, q1_options, q2, q2_options, q3, q3_options, q4, q4_options, q5, q5_options, \
+     q6, q6_options, q7, q7_options, q8, q8_options, q9, q9_options, q10, q10_options, q11, q11_options
+ from src.models.big.BigGAN2 import Generator as BigGAN2Generator
+ from src.data import get_labels_train, make_galaxy_labels_hierarchical
+ from src.utils import download_file, sample_labels
+
+
+ # global parameters
+ device = params.device
+ size = params.size
+ y_size = shape_label = params.shape_label
+ n_channels = params.n_channels
+ upsample = params.upsample
+ dim_z = params.dim_z
+ bs = 16  # number of samples to generate
+ n_cols = int(math.sqrt(bs))
+ model_path = params.path_biggan
+ drive_id = params.drive_id_biggan
+ path_labels = params.path_labels
+
+ # manual labels
+ q1_out = [0] * len(q1_options)
+ q2_out = [0] * len(q2_options)
+ q3_out = [0] * len(q3_options)
+ q4_out = [0] * len(q4_options)
+ q5_out = [0] * len(q5_options)
+ q6_out = [0] * len(q6_options)
+ q7_out = [0] * len(q7_options)
+ q8_out = [0] * len(q8_options)
+ q9_out = [0] * len(q9_options)
+ q10_out = [0] * len(q10_options)
+ q11_out = [0] * len(q11_options)
+
+
+ def clear_out(elems=None):
+     global q1_out, q2_out, q3_out, q4_out, q5_out, q6_out, q7_out, q8_out, q9_out, q10_out, q11_out
+
+     if elems is None:
+         elems = list(range(1, 12))
+
+     if 1 in elems:
+         q1_out = [0] * len(q1_options)
+     if 2 in elems:
+         q2_out = [0] * len(q2_options)
+     if 3 in elems:
+         q3_out = [0] * len(q3_options)
+     if 4 in elems:
+         q4_out = [0] * len(q4_options)
+     if 5 in elems:
+         q5_out = [0] * len(q5_options)
+     if 6 in elems:
+         q6_out = [0] * len(q6_options)
+     if 7 in elems:
+         q7_out = [0] * len(q7_options)
+     if 8 in elems:
+         q8_out = [0] * len(q8_options)
+     if 9 in elems:
+         q9_out = [0] * len(q9_options)
+     if 10 in elems:
+         q10_out = [0] * len(q10_options)
+     if 11 in elems:
+         q11_out = [0] * len(q11_options)
+
+
+ @st.cache(allow_output_mutation=True)
+ def load_model(model_path: str) -> BigGAN2Generator:
+
+     print(f'Loading model: {model_path}')
+     g = BigGAN2Generator()
+     ckpt = torch.load(model_path, map_location=torch.device('cpu'))
+     g.load_state_dict(ckpt)
+     g.eval().to(device)
+     return g
+
+
+ def get_eps(n: int) -> torch.Tensor:
+     eps = torch.randn((n, dim_z), device=device)
+     return eps
+
+
+ @st.cache
+ def get_labels() -> torch.Tensor:
+     if not Path(path_labels).exists():
+         download_file(params.drive_id_labels, path_labels)
+
+     labels_train = get_labels_train(path_labels)
+     return labels_train
+
+
+ def app():
+     global q1_out, q2_out, q3_out, q4_out, q5_out, q6_out, q7_out, q8_out, q9_out, q10_out, q11_out
+
+     st.title('Explore BigGAN')
+     st.markdown('This demo shows BigGAN for conditional galaxy generation')
+
+     if not Path(model_path).exists():
+         download_file(drive_id, model_path)
+
+     model = load_model(model_path)
+     eps = get_eps(bs)
+     labels_train = get_labels()
+
+     # ========================== Labels ================================
+     st.subheader('Label')
+     st.markdown(r'There are two ways of selecting labels: __Random__ - sample random labels from the dataset;'
+                 r' __Manual__ - select labels manually (advanced use). When using __Manual__, all of the images will be'
+                 r' generated with the same labels')
+     label_type = st.radio('Label type', options=['Random', 'Manual (Advanced)'])
+     if label_type == 'Random':
+         labels = sample_labels(labels_train, bs).to(device)
+
+         st.markdown(r'Click on the __Sample labels__ button to sample random input labels')
+         change_label = st.button('Sample labels')
+
+         if change_label:
+             labels = sample_labels(labels_train, bs).to(device)
+     elif label_type == 'Manual (Advanced)':
+         st.markdown('Answer the questions below')
+
+         q1_select_box = st.selectbox(q1, options=q1_options)
+         clear_out()
+         q1_out[q1_options.index(q1_select_box)] = 1
+         # 1
+
+         if q1_select_box == 'Smooth':
+             q7_select_box = st.selectbox(q7, options=q7_options)
+             clear_out([2, 3, 4, 5, 6, 7, 8, 9, 10, 11])
+             q7_out[q7_options.index(q7_select_box)] = 1
+             # 1 - 7
+
+             q6_select_box = st.selectbox(q6, options=q6_options)
+             clear_out([2, 3, 4, 5, 6, 8, 9, 10, 11])
+             q6_out[q6_options.index(q6_select_box)] = 1
+             # 1 - 7 - 6
+
+             if q6_select_box == 'Yes':
+                 q8_select_box = st.selectbox(q8, options=q8_options)
+                 clear_out([2, 3, 4, 5, 8, 9, 10, 11])
+                 q8_out[q8_options.index(q8_select_box)] = 1
+                 # 1 - 7 - 6 - 8 - end
+
+         elif q1_select_box == 'Features or disk':
+             q2_select_box = st.selectbox(q2, options=q2_options)
+             clear_out([2, 3, 4, 5, 6, 7, 8, 9, 10, 11])
+             q2_out[q2_options.index(q2_select_box)] = 1
+             # 1 - 2
+
+             if q2_select_box == 'Yes':
+                 q9_select_box = st.selectbox(q9, options=q9_options)
+                 clear_out([3, 4, 5, 6, 7, 8, 9, 10, 11])
+                 q9_out[q9_options.index(q9_select_box)] = 1
+                 # 1 - 2 - 9
+
+                 q6_select_box = st.selectbox(q6, options=q6_options)
+                 clear_out([3, 4, 5, 6, 7, 8, 10, 11])
+                 q6_out[q6_options.index(q6_select_box)] = 1
+                 # 1 - 2 - 9 - 6
+
+                 if q6_select_box == 'Yes':
+                     q8_select_box = st.selectbox(q8, options=q8_options)
+                     clear_out([3, 4, 5, 7, 8, 10, 11])
+                     q8_out[q8_options.index(q8_select_box)] = 1
+                     # 1 - 2 - 9 - 6 - 8
+             else:
+                 q3_select_box = st.selectbox(q3, options=q3_options)
+                 clear_out([3, 4, 5, 6, 7, 8, 9, 10, 11])
+                 q3_out[q3_options.index(q3_select_box)] = 1
+                 # 1 - 2 - 3
+
+                 q4_select_box = st.selectbox(q4, options=q4_options)
+                 clear_out([4, 5, 6, 7, 8, 9, 10, 11])
+                 q4_out[q4_options.index(q4_select_box)] = 1
+                 # 1 - 2 - 3 - 4
+
+                 if q4_select_box == 'Yes':
+                     q10_select_box = st.selectbox(q10, options=q10_options)
+                     clear_out([5, 6, 7, 8, 9, 10, 11])
+                     q10_out[q10_options.index(q10_select_box)] = 1
+                     # 1 - 2 - 3 - 4 - 10
+
+                     q11_select_box = st.selectbox(q11, options=q11_options)
+                     clear_out([5, 6, 7, 8, 9, 11])
+                     q11_out[q11_options.index(q11_select_box)] = 1
+                     # 1 - 2 - 3 - 4 - 10 - 11
+
+                     q5_select_box = st.selectbox(q5, options=q5_options)
+                     clear_out([5, 6, 7, 8, 9])
+                     q5_out[q5_options.index(q5_select_box)] = 1
+                     # 1 - 2 - 3 - 4 - 10 - 11 - 5
+
+                     q6_select_box = st.selectbox(q6, options=q6_options)
+                     clear_out([6, 7, 8, 9])
+                     q6_out[q6_options.index(q6_select_box)] = 1
+                     # 1 - 2 - 3 - 4 - 10 - 11 - 5 - 6
+
+                     if q6_select_box == 'Yes':
+                         q8_select_box = st.selectbox(q8, options=q8_options)
+                         clear_out([7, 8, 9])
+                         q8_out[q8_options.index(q8_select_box)] = 1
+                         # 1 - 2 - 3 - 4 - 10 - 11 - 5 - 6 - 8 - End
+                 else:
+                     q5_select_box = st.selectbox(q5, options=q5_options)
+                     clear_out([5, 6, 7, 8, 9, 10, 11])
+                     q5_out[q5_options.index(q5_select_box)] = 1
+                     # 1 - 2 - 3 - 4 - 5
+
+                     q6_select_box = st.selectbox(q6, options=q6_options)
+                     clear_out([6, 7, 8, 9, 10, 11])
+                     q6_out[q6_options.index(q6_select_box)] = 1
+                     # 1 - 2 - 3 - 4 - 5 - 6
+
+                     if q6_select_box == 'Yes':
+                         q8_select_box = st.selectbox(q8, options=q8_options)
+                         clear_out([7, 8, 9, 10, 11])
+                         q8_out[q8_options.index(q8_select_box)] = 1
+                         # 1 - 2 - 3 - 4 - 5 - 6 - 8 - End
+
+         labels = [*q1_out, *q2_out, *q3_out, *q4_out, *q5_out, *q6_out, *q7_out, *q8_out, *q9_out, *q10_out, *q11_out]
+         labels = torch.Tensor(labels).to(device)
+         labels = labels.unsqueeze(0).repeat(bs, 1)
+         labels = make_galaxy_labels_hierarchical(labels)
+         clear_out()
+     # ========================== Labels ================================
+
+     st.subheader('Noise')
+     st.markdown(r'Click on the __Change eps__ button to resample the input $\varepsilon$ latent vectors')
+     change_eps = st.button('Change eps')
+     if change_eps:
+         eps = get_eps(bs)
+
+     with torch.no_grad():
+         imgs = model(eps, labels)
+
+     if upsample:
+         imgs = F.interpolate(imgs, (size * 4, size * 4), mode='bicubic')
+
+     imgs = torch.clip(imgs, 0, 1)
+     imgs = [(imgs[i].permute(1, 2, 0).numpy() * 255).astype(np.uint8) for i in range(bs)]
+
+     counter = 0
+     for r in range(bs // n_cols):
+         cols = st.columns(n_cols)
+
+         for c in range(n_cols):
+             cols[c].image(imgs[counter])
+             counter += 1
src/app/explore_cvae.py ADDED
@@ -0,0 +1,255 @@
+ import math
+ from pathlib import Path
+
+ import streamlit as st
+ import numpy as np
+
+ import torch
+ import torch.nn.functional as F
+
+ import src.app.params as params
+ from src.app.questions import q1, q1_options, q2, q2_options, q3, q3_options, q4, q4_options, q5, q5_options, \
+     q6, q6_options, q7, q7_options, q8, q8_options, q9, q9_options, q10, q10_options, q11, q11_options
+ from src.models import ConditionalDecoder
+ from src.data import get_labels_train, make_galaxy_labels_hierarchical
+ from src.utils import download_file, sample_labels
+
+
+ # global parameters
+ device = params.device
+ size = params.size
+ y_size = shape_label = params.shape_label
+ n_channels = params.n_channels
+ upsample = params.upsample
+ dim_z = params.dim_z
+ bs = 16  # number of samples to generate
+ n_cols = int(math.sqrt(bs))
+ model_path = params.path_cvae
+ drive_id = params.drive_id_cvae
+ path_labels = params.path_labels
+
+ # manual labels
+ q1_out = [0] * len(q1_options)
+ q2_out = [0] * len(q2_options)
+ q3_out = [0] * len(q3_options)
+ q4_out = [0] * len(q4_options)
+ q5_out = [0] * len(q5_options)
+ q6_out = [0] * len(q6_options)
+ q7_out = [0] * len(q7_options)
+ q8_out = [0] * len(q8_options)
+ q9_out = [0] * len(q9_options)
+ q10_out = [0] * len(q10_options)
+ q11_out = [0] * len(q11_options)
+
+
+ def clear_out(elems=None):
+     global q1_out, q2_out, q3_out, q4_out, q5_out, q6_out, q7_out, q8_out, q9_out, q10_out, q11_out
+
+     if elems is None:
+         elems = list(range(1, 12))
+
+     if 1 in elems:
+         q1_out = [0] * len(q1_options)
+     if 2 in elems:
+         q2_out = [0] * len(q2_options)
+     if 3 in elems:
+         q3_out = [0] * len(q3_options)
+     if 4 in elems:
+         q4_out = [0] * len(q4_options)
+     if 5 in elems:
+         q5_out = [0] * len(q5_options)
+     if 6 in elems:
+         q6_out = [0] * len(q6_options)
+     if 7 in elems:
+         q7_out = [0] * len(q7_options)
+     if 8 in elems:
+         q8_out = [0] * len(q8_options)
+     if 9 in elems:
+         q9_out = [0] * len(q9_options)
+     if 10 in elems:
+         q10_out = [0] * len(q10_options)
+     if 11 in elems:
+         q11_out = [0] * len(q11_options)
+
+
+ @st.cache(allow_output_mutation=True)
+ def load_model(model_path: str) -> ConditionalDecoder:
+
+     print(f'Loading model: {model_path}')
+     g = ConditionalDecoder()
+     ckpt = torch.load(model_path, map_location=torch.device('cpu'))
+     g.load_state_dict(ckpt)
+     g.eval().to(device)
+     return g
+
+
+ def get_eps(n: int) -> torch.Tensor:
+     eps = torch.randn((n, dim_z), device=device)
+     return eps
+
+
+ @st.cache
+ def get_labels() -> torch.Tensor:
+     if not Path(path_labels).exists():
+         download_file(params.drive_id_labels, path_labels)
+
+     labels_train = get_labels_train(path_labels)
+     return labels_train
+
+
+ def app():
+     global q1_out, q2_out, q3_out, q4_out, q5_out, q6_out, q7_out, q8_out, q9_out, q10_out, q11_out
+
+     st.title('Explore cVAE')
+     st.markdown('This demo shows cVAE for conditional galaxy generation')
+
+     if not Path(model_path).exists():
+         download_file(drive_id, model_path)
+
+     model = load_model(model_path)
+     eps = get_eps(bs)
+     labels_train = get_labels()
+
+     # ========================== Labels ================================
+     st.subheader('Label')
+     st.markdown(r'There are two ways of selecting labels: __Random__ - sample random labels from the dataset;'
+                 r' __Manual__ - select labels manually (advanced use). When using __Manual__, all of the images will be'
+                 r' generated with the same labels')
+     label_type = st.radio('Label type', options=['Random', 'Manual (Advanced)'])
+     if label_type == 'Random':
+         labels = sample_labels(labels_train, bs).to(device)
+
+         st.markdown(r'Click on the __Sample labels__ button to sample random input labels')
+         change_label = st.button('Sample labels')
+
+         if change_label:
+             labels = sample_labels(labels_train, bs).to(device)
+     elif label_type == 'Manual (Advanced)':
+         st.markdown('Answer the questions below')
+
+         q1_select_box = st.selectbox(q1, options=q1_options)
+         clear_out()
+         q1_out[q1_options.index(q1_select_box)] = 1
+         # 1
+
+         if q1_select_box == 'Smooth':
+             q7_select_box = st.selectbox(q7, options=q7_options)
+             clear_out([2, 3, 4, 5, 6, 7, 8, 9, 10, 11])
+             q7_out[q7_options.index(q7_select_box)] = 1
+             # 1 - 7
+
+             q6_select_box = st.selectbox(q6, options=q6_options)
+             clear_out([2, 3, 4, 5, 6, 8, 9, 10, 11])
+             q6_out[q6_options.index(q6_select_box)] = 1
+             # 1 - 7 - 6
+
+             if q6_select_box == 'Yes':
+                 q8_select_box = st.selectbox(q8, options=q8_options)
+                 clear_out([2, 3, 4, 5, 8, 9, 10, 11])
+                 q8_out[q8_options.index(q8_select_box)] = 1
+                 # 1 - 7 - 6 - 8 - end
+
+         elif q1_select_box == 'Features or disk':
+             q2_select_box = st.selectbox(q2, options=q2_options)
+             clear_out([2, 3, 4, 5, 6, 7, 8, 9, 10, 11])
+             q2_out[q2_options.index(q2_select_box)] = 1
+             # 1 - 2
+
+             if q2_select_box == 'Yes':
+                 q9_select_box = st.selectbox(q9, options=q9_options)
+                 clear_out([3, 4, 5, 6, 7, 8, 9, 10, 11])
+                 q9_out[q9_options.index(q9_select_box)] = 1
+                 # 1 - 2 - 9
+
+                 q6_select_box = st.selectbox(q6, options=q6_options)
+                 clear_out([3, 4, 5, 6, 7, 8, 10, 11])
+                 q6_out[q6_options.index(q6_select_box)] = 1
+                 # 1 - 2 - 9 - 6
+
+                 if q6_select_box == 'Yes':
+                     q8_select_box = st.selectbox(q8, options=q8_options)
+                     clear_out([3, 4, 5, 7, 8, 10, 11])
+                     q8_out[q8_options.index(q8_select_box)] = 1
+                     # 1 - 2 - 9 - 6 - 8
+             else:
+                 q3_select_box = st.selectbox(q3, options=q3_options)
+                 clear_out([3, 4, 5, 6, 7, 8, 9, 10, 11])
+                 q3_out[q3_options.index(q3_select_box)] = 1
+                 # 1 - 2 - 3
+
+                 q4_select_box = st.selectbox(q4, options=q4_options)
+                 clear_out([4, 5, 6, 7, 8, 9, 10, 11])
+                 q4_out[q4_options.index(q4_select_box)] = 1
+                 # 1 - 2 - 3 - 4
+
+                 if q4_select_box == 'Yes':
+                     q10_select_box = st.selectbox(q10, options=q10_options)
+                     clear_out([5, 6, 7, 8, 9, 10, 11])
+                     q10_out[q10_options.index(q10_select_box)] = 1
+                     # 1 - 2 - 3 - 4 - 10
+
+                     q11_select_box = st.selectbox(q11, options=q11_options)
+                     clear_out([5, 6, 7, 8, 9, 11])
+                     q11_out[q11_options.index(q11_select_box)] = 1
+                     # 1 - 2 - 3 - 4 - 10 - 11
+
+                     q5_select_box = st.selectbox(q5, options=q5_options)
+                     clear_out([5, 6, 7, 8, 9])
+                     q5_out[q5_options.index(q5_select_box)] = 1
+                     # 1 - 2 - 3 - 4 - 10 - 11 - 5
+
+                     q6_select_box = st.selectbox(q6, options=q6_options)
+                     clear_out([6, 7, 8, 9])
+                     q6_out[q6_options.index(q6_select_box)] = 1
+                     # 1 - 2 - 3 - 4 - 10 - 11 - 5 - 6
+
+                     if q6_select_box == 'Yes':
+                         q8_select_box = st.selectbox(q8, options=q8_options)
+                         clear_out([7, 8, 9])
+                         q8_out[q8_options.index(q8_select_box)] = 1
+                         # 1 - 2 - 3 - 4 - 10 - 11 - 5 - 6 - 8 - End
+                 else:
+                     q5_select_box = st.selectbox(q5, options=q5_options)
+                     clear_out([5, 6, 7, 8, 9, 10, 11])
+                     q5_out[q5_options.index(q5_select_box)] = 1
+                     # 1 - 2 - 3 - 4 - 5
+
+                     q6_select_box = st.selectbox(q6, options=q6_options)
+                     clear_out([6, 7, 8, 9, 10, 11])
+                     q6_out[q6_options.index(q6_select_box)] = 1
+                     # 1 - 2 - 3 - 4 - 5 - 6
+
+                     if q6_select_box == 'Yes':
+                         q8_select_box = st.selectbox(q8, options=q8_options)
+                         clear_out([7, 8, 9, 10, 11])
+                         q8_out[q8_options.index(q8_select_box)] = 1
+                         # 1 - 2 - 3 - 4 - 5 - 6 - 8 - End
+
+         labels = [*q1_out, *q2_out, *q3_out, *q4_out, *q5_out, *q6_out, *q7_out, *q8_out, *q9_out, *q10_out, *q11_out]
+         labels = torch.Tensor(labels).to(device)
+         labels = labels.unsqueeze(0).repeat(bs, 1)
+         labels = make_galaxy_labels_hierarchical(labels)
+         clear_out()
+     # ========================== Labels ================================
+
+     st.subheader('Noise')
+     st.markdown(r'Click on the __Change eps__ button to resample the input $\varepsilon$ latent vectors')
+     change_eps = st.button('Change eps')
+     if change_eps:
+         eps = get_eps(bs)
+
+     with torch.no_grad():
+         imgs = model(eps, labels)
+
+     if upsample:
+         imgs = F.interpolate(imgs, (size * 4, size * 4), mode='bicubic')
+
+     imgs = [(imgs[i].permute(1, 2, 0).numpy() * 127.5 + 127.5).astype(np.uint8) for i in range(bs)]
+
+     counter = 0
+     for r in range(bs // n_cols):
+         cols = st.columns(n_cols)
+
+         for c in range(n_cols):
+             cols[c].image(imgs[counter])
+             counter += 1
src/app/explore_infoscc_gan.py ADDED
@@ -0,0 +1,263 @@
+ from pathlib import Path
+ import math
+
+ import numpy as np
+ import streamlit as st
+
+ import torch
+ import torch.nn.functional as F
+
+ import src.app.params as params
+ from src.app.questions import q1, q1_options, q2, q2_options, q3, q3_options, q4, q4_options, q5, q5_options, \
+     q6, q6_options, q7, q7_options, q8, q8_options, q9, q9_options, q10, q10_options, q11, q11_options
+ from src.models import ConditionalGenerator
+ from src.data import get_labels_train, make_galaxy_labels_hierarchical
+ from src.utils import download_file, sample_labels
+
+ # global parameters
+ device = params.device
+ size = params.size
+ y_size = params.shape_label
+ n_channels = params.n_channels
+ upsample = params.upsample
+ z_size = noise_dim = params.noise_dim
+ n_layers = int(math.log2(size) - 2)
+ n_basis = params.n_basis
+ y_type = params.y_type
+ bs = 16  # number of samples to generate
+ n_cols = int(math.sqrt(bs))
+ model_path = params.path_infoscc_gan  # path to the model
+ drive_id = params.drive_id_infoscc_gan  # Google Drive id of the model
+ path_labels = params.path_labels
+
+ # manual labels
+ q1_out = [0] * len(q1_options)
+ q2_out = [0] * len(q2_options)
+ q3_out = [0] * len(q3_options)
+ q4_out = [0] * len(q4_options)
+ q5_out = [0] * len(q5_options)
+ q6_out = [0] * len(q6_options)
+ q7_out = [0] * len(q7_options)
+ q8_out = [0] * len(q8_options)
+ q9_out = [0] * len(q9_options)
+ q10_out = [0] * len(q10_options)
+ q11_out = [0] * len(q11_options)
+
+
+ def clear_out(elems=None):
+     global q1_out, q2_out, q3_out, q4_out, q5_out, q6_out, q7_out, q8_out, q9_out, q10_out, q11_out
+
+     if elems is None:
+         elems = list(range(1, 12))
+
+     if 1 in elems:
+         q1_out = [0] * len(q1_options)
+     if 2 in elems:
+         q2_out = [0] * len(q2_options)
+     if 3 in elems:
+         q3_out = [0] * len(q3_options)
+     if 4 in elems:
+         q4_out = [0] * len(q4_options)
+     if 5 in elems:
+         q5_out = [0] * len(q5_options)
+     if 6 in elems:
+         q6_out = [0] * len(q6_options)
+     if 7 in elems:
+         q7_out = [0] * len(q7_options)
+     if 8 in elems:
+         q8_out = [0] * len(q8_options)
+     if 9 in elems:
+         q9_out = [0] * len(q9_options)
+     if 10 in elems:
+         q10_out = [0] * len(q10_options)
+     if 11 in elems:
+         q11_out = [0] * len(q11_options)
+
+
+ @st.cache(allow_output_mutation=True)
+ def load_model(model_path: str) -> ConditionalGenerator:
+
+     print(f'Loading model: {model_path}')
+     g_ema = ConditionalGenerator(size, y_size, z_size, n_channels, n_basis, noise_dim)
+     ckpt = torch.load(model_path, map_location=torch.device('cpu'))
+     g_ema.load_state_dict(ckpt['g_ema'])
+     g_ema.eval().to(device)
+     return g_ema
+
+
+ @st.cache
+ def get_labels() -> torch.Tensor:
+     if not Path(path_labels).exists():
+         download_file(params.drive_id_labels, path_labels)
+     labels_train = get_labels_train(path_labels)
+     return labels_train
+
+
+ def app():
+     global q1_out, q2_out, q3_out, q4_out, q5_out, q6_out, q7_out, q8_out, q9_out, q10_out, q11_out
+
+     st.title('Explore InfoSCC-GAN')
+     st.markdown('This demo shows InfoSCC-GAN for conditional galaxy generation')
+     st.subheader(r'<- Use the sidebar to explore the $z_1, ..., z_k$ latent variables')
+
+     if not Path(model_path).exists():
+         download_file(drive_id, model_path)
+
+     model = load_model(model_path)
+     eps = model.sample_eps(bs).to(device)
+     labels_train = get_labels()
+
+     # get zs
+     zs = np.array([[0.0] * n_basis] * n_layers, dtype=np.float32)
+
+     for l in range(n_layers):
+         st.sidebar.markdown(f'## Layer: {l}')
+         for d in range(n_basis):
+             zs[l][d] = st.sidebar.slider(f'Dimension: {d}', key=f'{l}{d}',
+                                          min_value=-5., max_value=5., value=0., step=0.1)
+
+     # ========================== Labels ================================
+     st.subheader('Label')
+     st.markdown(r'There are two ways of selecting labels: __Random__ - sample random labels from the dataset;'
+                 r' __Manual__ - select labels manually (advanced use). When using __Manual__, all of the images will be'
+                 r' generated with the same labels')
+     label_type = st.radio('Label type', options=['Random', 'Manual (Advanced)'])
+     if label_type == 'Random':
+         labels = sample_labels(labels_train, bs).to(device)
+
+         st.markdown(r'Click on the __Sample labels__ button to sample random input labels')
+         change_label = st.button('Sample labels')
+
+         if change_label:
+             labels = sample_labels(labels_train, bs).to(device)
+     elif label_type == 'Manual (Advanced)':
+         st.markdown('Answer the questions below')
+
+         q1_select_box = st.selectbox(q1, options=q1_options)
+         clear_out()
+         q1_out[q1_options.index(q1_select_box)] = 1
+         # 1
+
+         if q1_select_box == 'Smooth':
+             q7_select_box = st.selectbox(q7, options=q7_options)
+             clear_out([2, 3, 4, 5, 6, 7, 8, 9, 10, 11])
+             q7_out[q7_options.index(q7_select_box)] = 1
+             # 1 - 7
+
+             q6_select_box = st.selectbox(q6, options=q6_options)
+             clear_out([2, 3, 4, 5, 6, 8, 9, 10, 11])
+             q6_out[q6_options.index(q6_select_box)] = 1
+             # 1 - 7 - 6
+
+             if q6_select_box == 'Yes':
+                 q8_select_box = st.selectbox(q8, options=q8_options)
+                 clear_out([2, 3, 4, 5, 8, 9, 10, 11])
+                 q8_out[q8_options.index(q8_select_box)] = 1
+                 # 1 - 7 - 6 - 8 - end
+
+         elif q1_select_box == 'Features or disk':
+             q2_select_box = st.selectbox(q2, options=q2_options)
+             clear_out([2, 3, 4, 5, 6, 7, 8, 9, 10, 11])
+             q2_out[q2_options.index(q2_select_box)] = 1
+             # 1 - 2
+
+             if q2_select_box == 'Yes':
+                 q9_select_box = st.selectbox(q9, options=q9_options)
+                 clear_out([3, 4, 5, 6, 7, 8, 9, 10, 11])
+                 q9_out[q9_options.index(q9_select_box)] = 1
+                 # 1 - 2 - 9
+
+                 q6_select_box = st.selectbox(q6, options=q6_options)
+                 clear_out([3, 4, 5, 6, 7, 8, 10, 11])
+                 q6_out[q6_options.index(q6_select_box)] = 1
+                 # 1 - 2 - 9 - 6
+
+                 if q6_select_box == 'Yes':
+                     q8_select_box = st.selectbox(q8, options=q8_options)
+                     clear_out([3, 4, 5, 7, 8, 10, 11])
+                     q8_out[q8_options.index(q8_select_box)] = 1
+                     # 1 - 2 - 9 - 6 - 8
+             else:
+                 q3_select_box = st.selectbox(q3, options=q3_options)
+                 clear_out([3, 4, 5, 6, 7, 8, 9, 10, 11])
+                 q3_out[q3_options.index(q3_select_box)] = 1
+                 # 1 - 2 - 3
+
+                 q4_select_box = st.selectbox(q4, options=q4_options)
+                 clear_out([4, 5, 6, 7, 8, 9, 10, 11])
+                 q4_out[q4_options.index(q4_select_box)] = 1
+                 # 1 - 2 - 3 - 4
+
+                 if q4_select_box == 'Yes':
+                     q10_select_box = st.selectbox(q10, options=q10_options)
+                     clear_out([5, 6, 7, 8, 9, 10, 11])
+                     q10_out[q10_options.index(q10_select_box)] = 1
+                     # 1 - 2 - 3 - 4 - 10
+
+                     q11_select_box = st.selectbox(q11, options=q11_options)
+                     clear_out([5, 6, 7, 8, 9, 11])
+                     q11_out[q11_options.index(q11_select_box)] = 1
+                     # 1 - 2 - 3 - 4 - 10 - 11
+
+                     q5_select_box = st.selectbox(q5, options=q5_options)
+                     clear_out([5, 6, 7, 8, 9])
+                     q5_out[q5_options.index(q5_select_box)] = 1
+                     # 1 - 2 - 3 - 4 - 10 - 11 - 5
+
+                     q6_select_box = st.selectbox(q6, options=q6_options)
+                     clear_out([6, 7, 8, 9])
+                     q6_out[q6_options.index(q6_select_box)] = 1
+                     # 1 - 2 - 3 - 4 - 10 - 11 - 5 - 6
+
+                     if q6_select_box == 'Yes':
+                         q8_select_box = st.selectbox(q8, options=q8_options)
+                         clear_out([7, 8, 9])
+                         q8_out[q8_options.index(q8_select_box)] = 1
+                         # 1 - 2 - 3 - 4 - 10 - 11 - 5 - 6 - 8 - End
+                 else:
+                     q5_select_box = st.selectbox(q5, options=q5_options)
+                     clear_out([5, 6, 7, 8, 9, 10, 11])
+                     q5_out[q5_options.index(q5_select_box)] = 1
+                     # 1 - 2 - 3 - 4 - 5
+
+                     q6_select_box = st.selectbox(q6, options=q6_options)
+                     clear_out([6, 7, 8, 9, 10, 11])
+                     q6_out[q6_options.index(q6_select_box)] = 1
+                     # 1 - 2 - 3 - 4 - 5 - 6
+
+                     if q6_select_box == 'Yes':
+                         q8_select_box = st.selectbox(q8, options=q8_options)
+                         clear_out([7, 8, 9, 10, 11])
+                         q8_out[q8_options.index(q8_select_box)] = 1
+                         # 1 - 2 - 3 - 4 - 5 - 6 - 8 - End
+
+         labels = [*q1_out, *q2_out, *q3_out, *q4_out, *q5_out, *q6_out, *q7_out, *q8_out, *q9_out, *q10_out, *q11_out]
+         labels = torch.Tensor(labels).to(device)
+         labels = labels.unsqueeze(0).repeat(bs, 1)
+         labels = make_galaxy_labels_hierarchical(labels)
+         clear_out()
+     # ========================== Labels ================================
+
+     st.subheader('Noise')
+     st.markdown(r'Click on the __Change eps__ button to resample the input $\varepsilon$ latent vectors')
+     change_eps = st.button('Change eps')
+     if change_eps:
+         eps = model.sample_eps(bs).to(device)
+
+     zs_torch = torch.from_numpy(zs).unsqueeze(0).repeat(bs, 1, 1).to(device)
+
+     with torch.no_grad():
+         imgs = model(labels, eps, zs_torch).squeeze(0).cpu()
+
+     if upsample:
+         imgs = F.interpolate(imgs, (size * 4, size * 4), mode='bicubic')
+
+     imgs = [(imgs[i].permute(1, 2, 0).numpy() * 127.5 + 127.5).astype(np.uint8) for i in range(bs)]
+
+     counter = 0
+     for r in range(bs // n_cols):
+         cols = st.columns(n_cols)
+
+         for c in range(n_cols):
+             cols[c].image(imgs[counter])
+             counter += 1
src/app/multipage.py ADDED
@@ -0,0 +1,41 @@
+ """
+ This file is the framework for generating multiple Streamlit applications
+ through an object-oriented framework.
+ """
+
+ # Import necessary libraries
+ import streamlit as st
+
+
+ # Define the MultiPage class to manage the multiple apps in our program
+ class MultiPage:
+     """Framework for combining multiple Streamlit applications."""
+
+     def __init__(self) -> None:
+         """Constructor to initialize the list which will store all our applications as an instance variable."""
+         self.pages = []
+
+     def add_page(self, title, func) -> None:
+         """Method to add pages to the project.
+
+         Args:
+             title (str): The title of the page which we are adding to the list of apps
+             func: Python function to render this page in Streamlit
+         """
+         self.pages.append({
+             "title": title,
+             "function": func
+         })
+
+     def run(self):
+         # Dropdown to select the page to run
+         page = st.sidebar.selectbox(
+             'App Navigation',
+             self.pages,
+             format_func=lambda page: page['title']
+         )
+
+         # Run the app function
+         page['function']()
src/app/params.py ADDED
@@ -0,0 +1,25 @@
+ """
+ This file contains the list of global parameters for the Galaxy Zoo generation app
+ """
+
+ device = 'cpu'
+ size = 64  # generated image size
+ shape_label = 37  # shape of the input label
+ n_channels = 3  # number of color channels in the image
+ upsample = True  # if True, generated images will be upsampled
+ noise_dim = 512  # noise size in InfoSCC-GAN
+ n_basis = 6  # size of additional z vectors in InfoSCC-GAN
+ y_type = 'real'  # type of labels in InfoSCC-GAN
+ dim_z = 128  # z vector size in BigGAN and cVAE
+
+ path_infoscc_gan = './models/InfoSCC-GAN/generator.pt'
+ drive_id_infoscc_gan = '1_kIujc497OH0ZJ7PNPwS5_otNlS7jMLI'
+
+ path_biggan = './models/BigGAN/generator.pth'
+ drive_id_biggan = '1sMSDdnQ5GjHcno5knHTDSKAKhhoHh_4z'
+
+ path_cvae = './models/CVAE/generator.pth'
+ drive_id_cvae = '17FmLvhwXq8PQMrD1CtjqyoAy5BobYMTE'
+
+ path_labels = './data/training_solutions_rev1.csv'
+ drive_id_labels = '1dzsB_HdGtmSHE4pCppamISpFaJBfPF7E'
src/app/questions.py ADDED
@@ -0,0 +1,36 @@
+ """
+ This file contains the questions and options for the manual labeling
+ """
+
+ q1 = 'Is the object a smooth galaxy, a galaxy with features/disk or a star?'
+ q1_options = ['Smooth', 'Features or disk', 'Star or artifact']
+
+ q2 = 'Is it edge-on?'
+ q2_options = ['Yes', 'No']
+
+ q3 = 'Is there a bar?'
+ q3_options = ['Yes', 'No']
+
+ q4 = 'Is there a spiral pattern?'
+ q4_options = ['Yes', 'No']
+
+ q5 = 'How prominent is the central bulge?'
+ q5_options = ['No bulge', 'Just noticeable', 'Obvious', 'Dominant']
+
+ q6 = 'Is there anything "odd" about the galaxy?'
+ q6_options = ['Yes', 'No']
+
+ q7 = 'How round is the smooth galaxy?'
+ q7_options = ['Completely round', 'In between', 'Cigar-shaped']
+
+ q8 = 'What is the odd feature?'
+ q8_options = ['Ring', 'Lens or arc', 'Disturbed', 'Irregular', 'Other', 'Merger', 'Dust lane']
+
+ q9 = 'What shape is the bulge in the edge-on galaxy?'
+ q9_options = ['Rounded', 'Boxy', 'No bulge']
+
+ q10 = 'How tightly wound are the spiral arms?'
+ q10_options = ['Tight', 'Medium', 'Loose']
+
+ q11 = 'How many spiral arms are there?'
+ q11_options = ['1', '2', '3', '4', 'more than four', "can't tell"]
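
Note: the one-hot answer vectors for these eleven questions concatenate to the 37-dimensional label used throughout the app (3+2+2+2+4+2+3+7+3+3+6 = 37, matching shape_label in src/app/params.py). A minimal illustrative sketch, not part of the commit, of how the pages assemble that vector from one answer per question:

# Hypothetical sketch: build a flat 37-dim label vector from one answer per question.
# Assumes the q*_options lists defined above; mirrors the q*_out indexing in the pages.
from src.app.questions import q1_options, q2_options, q3_options, q4_options, q5_options, \
    q6_options, q7_options, q8_options, q9_options, q10_options, q11_options

all_options = [q1_options, q2_options, q3_options, q4_options, q5_options, q6_options,
               q7_options, q8_options, q9_options, q10_options, q11_options]
answers = ['Smooth', 'No', 'No', 'No', 'No bulge', 'No', 'Completely round',
           'Ring', 'Rounded', 'Tight', '1']  # one (arbitrary) choice per question
label = []
for options, answer in zip(all_options, answers):
    one_hot = [0] * len(options)
    one_hot[options.index(answer)] = 1  # same indexing as q*_out in the explore_* pages
    label.extend(one_hot)
assert len(label) == 37  # matches shape_label in src/app/params.py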
src/data/__init__.py ADDED
@@ -0,0 +1,2 @@
+ from .data import get_labels_train
+ from .labels import make_galaxy_labels_hierarchical
src/data/data.py ADDED
@@ -0,0 +1,10 @@
+ from pandas import read_csv
+
+ import torch
+
+
+ def get_labels_train(file_galaxy_labels) -> torch.Tensor:
+     df_galaxy_labels = read_csv(file_galaxy_labels)
+     labels_train = df_galaxy_labels[df_galaxy_labels.columns[1:]].values
+     labels_train = torch.from_numpy(labels_train).float()
+     return labels_train
src/data/labels.py ADDED
@@ -0,0 +1,63 @@
1
+ import numpy as np
2
+
3
+ import torch
4
+
5
+
6
+ class_groups = {
7
+ # group : indices (assuming 0th position is id)
8
+ 0: (),
9
+ 1: (1, 2, 3),
10
+ 2: (4, 5),
11
+ 3: (6, 7),
12
+ 4: (8, 9),
13
+ 5: (10, 11, 12, 13),
14
+ 6: (14, 15),
15
+ 7: (16, 17, 18),
16
+ 8: (19, 20, 21, 22, 23, 24, 25),
17
+ 9: (26, 27, 28),
18
+ 10: (29, 30, 31),
19
+ 11: (32, 33, 34, 35, 36, 37),
20
+ }
21
+
22
+
23
+ class_groups_indices = {g: np.array(ixs)-1 for g, ixs in class_groups.items()}
24
+
25
+
26
+ hierarchy = {
27
+ # group : parent (group, label)
28
+ 2: (1, 1),
29
+ 3: (2, 1),
30
+ 4: (2, 1),
31
+ 5: (2, 1),
32
+ 7: (1, 0),
33
+ 8: (6, 0),
34
+ 9: (2, 0),
35
+ 10: (4, 0),
36
+ 11: (4, 0),
37
+ }
38
+
39
+
40
+ def make_galaxy_labels_hierarchical(labels: torch.Tensor) -> torch.Tensor:
41
+ """ transform groups of galaxy label probabilities to follow the hierarchical order defined in galaxy zoo
42
+ more info here: https://www.kaggle.com/c/galaxy-zoo-the-galaxy-challenge/overview/the-galaxy-zoo-decision-tree
43
+ labels is a NxL torch tensor, where N is the batch size and L is the number of labels,
44
+ all labels should be > 1
45
+ the indices of label groups are listed in class_groups_indices
46
+
47
+ Return
48
+ ------
49
+ hierarchical_labels : NxL torch tensor, where L is the total number of labels
50
+ """
51
+ shift = labels.shape[1] > 37 ## in case the id is included at 0th position, shift indices accordingly
52
+ index = lambda i: class_groups_indices[i] + shift
53
+
54
+ for i in range(1, 12):
55
+ ## normalize probabilities to 1
56
+ norm = torch.sum(labels[:, index(i)], dim=1, keepdims=True)
57
+ norm[norm == 0] += 1e-4 ## add small number to prevent NaNs dividing by zero, yet keep track of gradient
58
+ labels[:, index(i)] /= norm
59
+ ## renormalize according to hierarchical structure
60
+ if i not in [1, 6]:
61
+ parent_group_label = labels[:, index(hierarchy[i][0])]
62
+ labels[:, index(i)] *= parent_group_label[:, hierarchy[i][1]].unsqueeze(-1)
63
+ return labels
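+
+
+ ## Usage sketch (illustrative): renormalize a random batch of raw label scores.
+ ## After the call, each child question group sums to the probability of its
+ ## parent answer (see `hierarchy` above), while the root groups 1 and 6 sum to 1.
+ # labels = torch.rand(8, 37)  # non-negative scores, no id column
+ # labels = make_galaxy_labels_hierarchical(labels)
+ # assert torch.allclose(labels[:, class_groups_indices[1]].sum(dim=1), torch.ones(8))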
src/models/__init__.py ADDED
@@ -0,0 +1,2 @@
1
+ from .infoscc_gan import ConditionalGenerator
2
+ from .cvae import ConditionalDecoder
src/models/big/BigGAN2.py ADDED
@@ -0,0 +1,469 @@
1
+ import functools
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ from torch.nn import init
6
+ import torch.optim as optim
7
+ import torch.nn.functional as F
8
+
9
+ import src.models.big.layers as layers
10
+ from src.models.parameter import labels_dim, parameter
11
+ from src.models.neuralnetwork import NeuralNetwork
12
+
13
+
14
+ # Architectures for G
15
+ # Attention is passed in the format '32_64' to mean applying an attention
16
+ # block at both resolution 32x32 and 64x64. Just '64' will apply at 64x64.
17
+ def G_arch(ch=64, attention='64', ksize='333333', dilation='111111'):
18
+ arch = {}
19
+ arch[512] = {'in_channels' : [ch * item for item in [16, 16, 8, 8, 4, 2, 1]],
20
+ 'out_channels' : [ch * item for item in [16, 8, 8, 4, 2, 1, 1]],
21
+ 'upsample' : [True] * 7,
22
+ 'resolution' : [8, 16, 32, 64, 128, 256, 512],
23
+ 'attention' : {2**i: (2**i in [int(item) for item in attention.split('_')])
24
+ for i in range(3,10)}}
25
+ arch[256] = {'in_channels' : [ch * item for item in [16, 16, 8, 8, 4, 2]],
26
+ 'out_channels' : [ch * item for item in [16, 8, 8, 4, 2, 1]],
27
+ 'upsample' : [True] * 6,
28
+ 'resolution' : [8, 16, 32, 64, 128, 256],
29
+ 'attention' : {2**i: (2**i in [int(item) for item in attention.split('_')])
30
+ for i in range(3,9)}}
31
+ arch[128] = {'in_channels' : [ch * item for item in [16, 16, 8, 4, 2]],
32
+ 'out_channels' : [ch * item for item in [16, 8, 4, 2, 1]],
33
+ 'upsample' : [True] * 5,
34
+ 'resolution' : [8, 16, 32, 64, 128],
35
+ 'attention' : {2**i: (2**i in [int(item) for item in attention.split('_')])
36
+ for i in range(3,8)}}
37
+ arch[64] = {'in_channels' : [ch * item for item in [16, 16, 8, 4]],
38
+ 'out_channels' : [ch * item for item in [16, 8, 4, 2]],
39
+ 'upsample' : [True] * 4,
40
+ 'resolution' : [8, 16, 32, 64],
41
+ 'attention' : {2**i: (2**i in [int(item) for item in attention.split('_')])
42
+ for i in range(3,7)}}
43
+ arch[32] = {'in_channels' : [ch * item for item in [4, 4, 4]],
44
+ 'out_channels' : [ch * item for item in [4, 4, 4]],
45
+ 'upsample' : [True] * 3,
46
+ 'resolution' : [8, 16, 32],
47
+ 'attention' : {2**i: (2**i in [int(item) for item in attention.split('_')])
48
+ for i in range(3,6)}}
49
+
50
+ return arch
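+
+ ## Illustration of the attention spec above (illustrative values):
+ ## G_arch(ch=64, attention='32_64')[64]['attention']
+ ## -> {8: False, 16: False, 32: True, 64: True}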
51
+
52
+ class Generator(NeuralNetwork):
53
+ def __init__(self, G_ch=64, dim_z=128, bottom_width=4, resolution=64, labels_dim=labels_dim,
54
+ G_kernel_size=3, G_attn='64', n_classes=1,
55
+ num_G_SVs=1, num_G_SV_itrs=1,
56
+ G_shared=True, shared_dim=0, hier=False,
57
+ cross_replica=False, mybn=False,
58
+ G_activation=nn.ReLU(inplace=False),
59
+ G_lr=5e-5, G_B1=0.0, G_B2=0.999, adam_eps=1e-8,
60
+ BN_eps=1e-5, SN_eps=1e-12, G_mixed_precision=False, G_fp16=False,
61
+ G_init='ortho', skip_init=False, no_optim=False,
62
+ G_param='SN', norm_style='bn',
63
+ **kwargs):
64
+ super(Generator, self).__init__()
65
+ # Channel width multiplier
66
+ self.ch = G_ch
67
+ # Dimensionality of the latent space
68
+ self.dim_z = dim_z
69
+ # The initial spatial dimensions
70
+ self.bottom_width = bottom_width
71
+ # Resolution of the output
72
+ self.resolution = resolution
73
+ # Kernel size?
74
+ self.kernel_size = G_kernel_size
75
+ # Attention?
76
+ self.attention = G_attn
77
+ # number of classes, for use in categorical conditional generation
78
+ self.n_classes = n_classes
79
+ # Use shared embeddings?
80
+ self.G_shared = G_shared
81
+ # Dimensionality of the shared embedding? Unused if not using G_shared
82
+ self.shared_dim = shared_dim if shared_dim > 0 else dim_z
83
+ # Hierarchical latent space?
84
+ self.hier = hier
85
+ # Cross replica batchnorm?
86
+ self.cross_replica = cross_replica
87
+ # Use my batchnorm?
88
+ self.mybn = mybn
89
+ # nonlinearity for residual blocks
90
+ self.activation = G_activation
91
+ # Initialization style
92
+ self.init = G_init
93
+ # Parameterization style
94
+ self.G_param = G_param
95
+ # Normalization style
96
+ self.norm_style = norm_style
97
+ # Epsilon for BatchNorm?
98
+ self.BN_eps = BN_eps
99
+ # Epsilon for Spectral Norm?
100
+ self.SN_eps = SN_eps
101
+ # fp16?
102
+ self.fp16 = G_fp16
103
+ # Architecture dict
104
+ self.arch = G_arch(self.ch, self.attention)[resolution]
105
+
106
+ # If using hierarchical latents, adjust z
107
+ if self.hier:
108
+ # Number of places z slots into
109
+ self.num_slots = len(self.arch['in_channels']) + 1
110
+ self.z_chunk_size = (self.dim_z // self.num_slots)
111
+ # Recalculate latent dimensionality for even splitting into chunks
112
+ self.dim_z = self.z_chunk_size * self.num_slots
113
+ else:
114
+ self.num_slots = 1
115
+ self.z_chunk_size = 0
116
+
117
+ # Which convs, batchnorms, and linear layers to use
118
+ if self.G_param == 'SN':
119
+ self.which_conv = functools.partial(layers.SNConv2d,
120
+ kernel_size=3, padding=1,
121
+ num_svs=num_G_SVs, num_itrs=num_G_SV_itrs,
122
+ eps=self.SN_eps)
123
+ self.which_linear = functools.partial(layers.SNLinear,
124
+ num_svs=num_G_SVs, num_itrs=num_G_SV_itrs,
125
+ eps=self.SN_eps)
126
+ else:
127
+ self.which_conv = functools.partial(nn.Conv2d, kernel_size=3, padding=1)
128
+ self.which_linear = nn.Linear
129
+
130
+ # We use a non-spectral-normed embedding here regardless;
131
+ # For some reason applying SN to G's embedding seems to randomly cripple G
132
+ self.which_embedding = nn.Embedding
133
+ bn_linear = (functools.partial(self.which_linear, bias=False) if self.G_shared
134
+ else self.which_embedding)
135
+ self.which_bn = functools.partial(layers.ccbn,
136
+ which_linear=bn_linear,
137
+ cross_replica=self.cross_replica,
138
+ mybn=self.mybn,
139
+ input_size=(self.shared_dim + self.z_chunk_size if self.G_shared
140
+ else self.n_classes),
141
+ norm_style=self.norm_style,
142
+ eps=self.BN_eps)
143
+
144
+
145
+ # Prepare model
146
+ # prepare label input
147
+ self.transform_label_layer = torch.nn.Linear(labels_dim, 128)
148
+ # If not using shared embeddings, self.shared is just a passthrough
149
+ self.shared = (self.which_embedding(n_classes, self.shared_dim) if G_shared
150
+ else layers.identity())
151
+ # First linear layer
152
+ self.linear = self.which_linear(self.dim_z // self.num_slots,
153
+ self.arch['in_channels'][0] * (self.bottom_width **2))
154
+
155
+ # self.blocks is a doubly-nested list of modules, the outer loop intended
156
+ # to be over blocks at a given resolution (resblocks and/or self-attention)
157
+ # while the inner loop is over a given block
158
+ self.blocks = []
159
+ for index in range(len(self.arch['out_channels'])):
160
+ self.blocks += [[layers.GBlock(in_channels=self.arch['in_channels'][index],
161
+ out_channels=self.arch['out_channels'][index],
162
+ which_conv=self.which_conv,
163
+ which_bn=self.which_bn,
164
+ activation=self.activation,
165
+ upsample=(functools.partial(F.interpolate, scale_factor=2)
166
+ if self.arch['upsample'][index] else None))]]
167
+
168
+ # If attention on this block, attach it to the end
169
+ if self.arch['attention'][self.arch['resolution'][index]]:
170
+ print('Adding attention layer in G at resolution %d' % self.arch['resolution'][index])
171
+ self.blocks[-1] += [layers.Attention(self.arch['out_channels'][index], self.which_conv)]
172
+
173
+ # Turn self.blocks into a ModuleList so that it's all properly registered.
174
+ self.blocks = nn.ModuleList([nn.ModuleList(block) for block in self.blocks])
175
+
176
+ # output layer: batchnorm-relu-conv.
177
+ # Consider using a non-spectral conv here
178
+ self.output_layer = nn.Sequential(layers.bn(self.arch['out_channels'][-1],
179
+ cross_replica=self.cross_replica,
180
+ mybn=self.mybn),
181
+ self.activation,
182
+ self.which_conv(self.arch['out_channels'][-1], 3))
183
+
184
+ # Initialize weights. Optionally skip init for testing.
185
+ if not skip_init:
186
+ self.init_weights()
187
+
188
+ # Set up optimizer
189
+ # If this is an EMA copy, no need for an optim, so just return now
190
+ if no_optim:
191
+ return
192
+ self.lr, self.B1, self.B2, self.adam_eps = G_lr, G_B1, G_B2, adam_eps
193
+ if G_mixed_precision:
194
+ print('Using fp16 adam in G...')
195
+ import utils
196
+ self.optim = utils.Adam16(params=self.parameters(), lr=self.lr,
197
+ betas=(self.B1, self.B2), weight_decay=0,
198
+ eps=self.adam_eps)
199
+ else:
200
+ self.optim = optim.Adam(params=self.parameters(), lr=self.lr,
201
+ betas=(self.B1, self.B2), weight_decay=0,
202
+ eps=self.adam_eps)
203
+
204
+ # LR scheduling, left here for forward compatibility
205
+ # self.lr_sched = {'itr' : 0}# if self.progressive else {}
206
+ # self.j = 0
207
+ self.set_optimizer(parameter.optimizer, lr=parameter.learning_rate, betas=parameter.betas)
208
+
209
+ # Initialize
210
+ def init_weights(self):
211
+ self.param_count = 0
212
+ for module in self.modules():
213
+ if (isinstance(module, nn.Conv2d)
214
+ or isinstance(module, nn.Linear)
215
+ or isinstance(module, nn.Embedding)):
216
+ if self.init == 'ortho':
217
+ init.orthogonal_(module.weight)
218
+ elif self.init == 'N02':
219
+ init.normal_(module.weight, 0, 0.02)
220
+ elif self.init in ['glorot', 'xavier']:
221
+ init.xavier_uniform_(module.weight)
222
+ else:
223
+ print('Init style not recognized...')
224
+ self.param_count += sum([p.data.nelement() for p in module.parameters()])
225
+ print("Param count for G's initialized parameters: %d" % self.param_count)
226
+
227
+
228
+ def transform_labels(self, labels):
229
+ """ prepore labels for input to generator """
230
+ return self.transform_label_layer(labels)
231
+
232
+
233
+ # Note on this forward function: we pass in a y vector which has
234
+ # already been passed through G.shared to enable easy class-wise
235
+ # interpolation later. If we passed in the one-hot and then ran it through
236
+ # G.shared in this forward function, it would be harder to handle.
237
+ def forward(self, z, y):
238
+ # If hierarchical, concatenate zs and ys
239
+ y = self.transform_labels(y)
240
+ if self.hier:
241
+ zs = torch.split(z, self.z_chunk_size, 1)
242
+ z = zs[0]
243
+ ys = [torch.cat([y, item], 1) for item in zs[1:]]
244
+ else:
245
+ ys = [y] * len(self.blocks)
246
+
247
+ # First linear layer
248
+ h = self.linear(z)
249
+ # Reshape
250
+ h = h.view(h.size(0), -1, self.bottom_width, self.bottom_width)
251
+
252
+ # Loop over blocks
253
+ for index, blocklist in enumerate(self.blocks):
254
+ # Second inner loop in case block has multiple layers
255
+ for block in blocklist:
256
+ h = block(h, ys[index])
257
+
258
+ # Apply batchnorm-relu-conv-sigmoid at output (the original BigGAN uses tanh; see the commented line below)
259
+ return torch.sigmoid(self.output_layer(h))
260
+ # return torch.tanh(self.output_layer(h))
261
+
262
+
263
+ # Discriminator architecture, same paradigm as G's above
264
+ def D_arch(ch=64, attention='64', ksize='333333', dilation='111111'):
265
+ arch = {}
266
+ arch[256] = {'in_channels' : [3] + [ch*item for item in [1, 2, 4, 8, 8, 16]],
267
+ 'out_channels' : [item * ch for item in [1, 2, 4, 8, 8, 16, 16]],
268
+ 'downsample' : [True] * 6 + [False],
269
+ 'resolution' : [128, 64, 32, 16, 8, 4, 4 ],
270
+ 'attention' : {2**i: 2**i in [int(item) for item in attention.split('_')]
271
+ for i in range(2,8)}}
272
+ arch[128] = {'in_channels' : [3] + [ch*item for item in [1, 2, 4, 8, 16]],
273
+ 'out_channels' : [item * ch for item in [1, 2, 4, 8, 16, 16]],
274
+ 'downsample' : [True] * 5 + [False],
275
+ 'resolution' : [64, 32, 16, 8, 4, 4],
276
+ 'attention' : {2**i: 2**i in [int(item) for item in attention.split('_')]
277
+ for i in range(2,8)}}
278
+ arch[64] = {'in_channels' : [3] + [ch*item for item in [1, 2, 4, 8]],
279
+ 'out_channels' : [item * ch for item in [1, 2, 4, 8, 16]],
280
+ 'downsample' : [True] * 4 + [False],
281
+ 'resolution' : [32, 16, 8, 4, 4],
282
+ 'attention' : {2**i: 2**i in [int(item) for item in attention.split('_')]
283
+ for i in range(2,7)}}
284
+ arch[32] = {'in_channels' : [3] + [item * ch for item in [4, 4, 4]],
285
+ 'out_channels' : [item * ch for item in [4, 4, 4, 4]],
286
+ 'downsample' : [True, True, False, False],
287
+ 'resolution' : [16, 16, 16, 16],
288
+ 'attention' : {2**i: 2**i in [int(item) for item in attention.split('_')]
289
+ for i in range(2,6)}}
290
+ return arch
291
+
292
+ class Discriminator(NeuralNetwork):
293
+
294
+ def __init__(self, D_ch=64, D_wide=True, resolution=64, labels_dim=labels_dim,
295
+ D_kernel_size=3, D_attn='64', n_classes=1,
296
+ num_D_SVs=1, num_D_SV_itrs=1, D_activation=nn.ReLU(inplace=False),
297
+ D_lr=2e-4, D_B1=0.0, D_B2=0.999, adam_eps=1e-8,
298
+ SN_eps=1e-12, output_dim=1, D_mixed_precision=False, D_fp16=False,
299
+ D_init='ortho', skip_init=False, D_param='SN', **kwargs):
300
+ super(Discriminator, self).__init__()
301
+ # Width multiplier
302
+ self.ch = D_ch
303
+ # Use Wide D as in BigGAN and SA-GAN or skinny D as in SN-GAN?
304
+ self.D_wide = D_wide
305
+ # Resolution
306
+ self.resolution = resolution
307
+ # Kernel size
308
+ self.kernel_size = D_kernel_size
309
+ # Attention?
310
+ self.attention = D_attn
311
+ # Number of classes
312
+ self.n_classes = n_classes
313
+ # Activation
314
+ self.activation = D_activation
315
+ # Initialization style
316
+ self.init = D_init
317
+ # Parameterization style
318
+ self.D_param = D_param
319
+ # Epsilon for Spectral Norm?
320
+ self.SN_eps = SN_eps
321
+ # Fp16?
322
+ self.fp16 = D_fp16
323
+ # Architecture
324
+ self.arch = D_arch(self.ch, self.attention)[resolution]
325
+
326
+ # Which convs, batchnorms, and linear layers to use
327
+ # No option to turn off SN in D right now
328
+ if self.D_param == 'SN':
329
+ self.which_conv = functools.partial(layers.SNConv2d,
330
+ kernel_size=3, padding=1,
331
+ num_svs=num_D_SVs, num_itrs=num_D_SV_itrs,
332
+ eps=self.SN_eps)
333
+ self.which_linear = functools.partial(layers.SNLinear,
334
+ num_svs=num_D_SVs, num_itrs=num_D_SV_itrs,
335
+ eps=self.SN_eps)
336
+ self.which_embedding = functools.partial(layers.SNEmbedding,
337
+ num_svs=num_D_SVs, num_itrs=num_D_SV_itrs,
338
+ eps=self.SN_eps)
339
+ # Prepare model
340
+ # prepare label input
341
+ self.transform_label_layer = torch.nn.Linear(labels_dim, 1024)
342
+ # self.blocks is a doubly-nested list of modules, the outer loop intended
343
+ # to be over blocks at a given resolution (resblocks and/or self-attention)
344
+ self.blocks = []
345
+ for index in range(len(self.arch['out_channels'])):
346
+ self.blocks += [[layers.DBlock(in_channels=self.arch['in_channels'][index],
347
+ out_channels=self.arch['out_channels'][index],
348
+ which_conv=self.which_conv,
349
+ wide=self.D_wide,
350
+ activation=self.activation,
351
+ preactivation=(index > 0),
352
+ downsample=(nn.AvgPool2d(2) if self.arch['downsample'][index] else None))]]
353
+ # If attention on this block, attach it to the end
354
+ if self.arch['attention'][self.arch['resolution'][index]]:
355
+ print('Adding attention layer in D at resolution %d' % self.arch['resolution'][index])
356
+ self.blocks[-1] += [layers.Attention(self.arch['out_channels'][index],
357
+ self.which_conv)]
358
+ # Turn self.blocks into a ModuleList so that it's all properly registered.
359
+ self.blocks = nn.ModuleList([nn.ModuleList(block) for block in self.blocks])
360
+ # Linear output layer. The output dimension is typically 1, but may be
361
+ # larger if we're e.g. turning this into a VAE with an inference output
362
+ self.linear = self.which_linear(self.arch['out_channels'][-1], output_dim)
363
+ # Embedding for projection discrimination
364
+ self.embed = self.which_embedding(self.n_classes, self.arch['out_channels'][-1])
365
+
366
+ # Initialize weights
367
+ if not skip_init:
368
+ self.init_weights()
369
+
370
+ # Set up optimizer
371
+ self.lr, self.B1, self.B2, self.adam_eps = D_lr, D_B1, D_B2, adam_eps
372
+ if D_mixed_precision:
373
+ print('Using fp16 adam in D...')
374
+ import utils
375
+ self.optim = utils.Adam16(params=self.parameters(), lr=self.lr,
376
+ betas=(self.B1, self.B2), weight_decay=0, eps=self.adam_eps)
377
+ else:
378
+ self.optim = optim.Adam(params=self.parameters(), lr=self.lr,
379
+ betas=(self.B1, self.B2), weight_decay=0, eps=self.adam_eps)
380
+ # LR scheduling, left here for forward compatibility
381
+ # self.lr_sched = {'itr' : 0}# if self.progressive else {}
382
+ # self.j = 0
383
+ self.set_optimizer(parameter.optimizer, lr=parameter.learning_rate*3, betas=parameter.betas)
384
+
385
+ # Initialize
386
+ def init_weights(self):
387
+ self.param_count = 0
388
+ for module in self.modules():
389
+ if (isinstance(module, nn.Conv2d)
390
+ or isinstance(module, nn.Linear)
391
+ or isinstance(module, nn.Embedding)):
392
+ if self.init == 'ortho':
393
+ init.orthogonal_(module.weight)
394
+ elif self.init == 'N02':
395
+ init.normal_(module.weight, 0, 0.02)
396
+ elif self.init in ['glorot', 'xavier']:
397
+ init.xavier_uniform_(module.weight)
398
+ else:
399
+ print('Init style not recognized...')
400
+ self.param_count += sum([p.data.nelement() for p in module.parameters()])
401
+ print("Param count for D's initialized parameters: %d" % self.param_count)
402
+
403
+ def transform_labels(self, labels):
404
+ """ prepore labels for input to discriminator """
405
+ return self.transform_label_layer(labels)
406
+
407
+
408
+ def forward(self, x, y=None):
409
+ # Stick x into h for cleaner for loops without flow control
410
+ h = x
411
+ # Loop over blocks
412
+ for index, blocklist in enumerate(self.blocks):
413
+ for block in blocklist:
414
+ h = block(h)
415
+ # Apply global sum pooling as in SN-GAN
416
+ h = torch.sum(self.activation(h), [2, 3])
417
+ # Get initial class-unconditional output
418
+ out = self.linear(h)
419
+ # Get projection of final featureset onto class vectors and add to evidence
420
+ y = self.transform_labels(y)
421
+ out = out + torch.sum(y * h, 1, keepdim=True)
422
+ # out = out + torch.sum(self.embed(y) * h, 1, keepdim=True) ## use y = torch.tensor(0)
423
+ return out
424
+
425
+ # Parallelized G_D to minimize cross-gpu communication
426
+ # Without this, Generator outputs would get all-gathered and then rebroadcast.
427
+ class G_D(nn.Module):
428
+ def __init__(self, G, D):
429
+ super(G_D, self).__init__()
430
+ self.G = G
431
+ self.D = D
432
+
433
+ def forward(self, z, gy, x=None, dy=None, train_G=False, return_G_z=False,
434
+ split_D=False):
435
+ # If training G, enable grad tape
436
+ with torch.set_grad_enabled(train_G):
437
+ # Get Generator output given noise
438
+ G_z = self.G(z, self.G.shared(gy))
439
+ # Cast as necessary
440
+ if self.G.fp16 and not self.D.fp16:
441
+ G_z = G_z.float()
442
+ if self.D.fp16 and not self.G.fp16:
443
+ G_z = G_z.half()
444
+ # Split_D means to run D once with real data and once with fake,
445
+ # rather than concatenating along the batch dimension.
446
+ if split_D:
447
+ D_fake = self.D(G_z, gy)
448
+ if x is not None:
449
+ D_real = self.D(x, dy)
450
+ return D_fake, D_real
451
+ else:
452
+ if return_G_z:
453
+ return D_fake, G_z
454
+ else:
455
+ return D_fake
456
+ # If real data is provided, concatenate it with the Generator's output
457
+ # along the batch dimension for improved efficiency.
458
+ else:
459
+ D_input = torch.cat([G_z, x], 0) if x is not None else G_z
460
+ D_class = torch.cat([gy, dy], 0) if dy is not None else gy
461
+ # Get Discriminator output
462
+ D_out = self.D(D_input, D_class)
463
+ if x is not None:
464
+ return torch.split(D_out, [G_z.shape[0], x.shape[0]]) # D_fake, D_real
465
+ else:
466
+ if return_G_z:
467
+ return D_out, G_z
468
+ else:
469
+ return D_out
src/models/big/LICENSE ADDED
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2019 Andy Brock
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
src/models/big/README.md ADDED
@@ -0,0 +1,144 @@
1
+ # BigGAN-PyTorch
2
+ The author's officially unofficial PyTorch BigGAN implementation.
3
+
4
+ ![Dogball? Dogball!](imgs/header_image.jpg?raw=true "Dogball? Dogball!")
5
+
6
+
7
+ This repo contains code for 4-8 GPU training of BigGANs from [Large Scale GAN Training for High Fidelity Natural Image Synthesis](https://arxiv.org/abs/1809.11096) by Andrew Brock, Jeff Donahue, and Karen Simonyan.
8
+
9
+ This code is by Andy Brock and Alex Andonian.
10
+
11
+ ## How To Use This Code
12
+ You will need:
13
+
14
+ - [PyTorch](https://PyTorch.org/), version 1.0.1
15
+ - tqdm, numpy, scipy, and h5py
16
+ - The ImageNet training set
17
+
18
+ First, you may optionally prepare a pre-processed HDF5 version of your target dataset for faster I/O. Following this (or not), you'll need the Inception moments needed to calculate FID. These can both be done by modifying and running
19
+
20
+ ```sh
21
+ sh scripts/utils/prepare_data.sh
22
+ ```
23
+
24
+ This script by default assumes your ImageNet training set is downloaded into the root folder `data` in this directory, and will prepare the cached HDF5 at 128x128 pixel resolution.
25
+
26
+ In the scripts folder, there are multiple bash scripts which will train BigGANs with different batch sizes. This code assumes you do not have access to a full TPU pod, and accordingly
27
+ spoofs mega-batches by using gradient accumulation (averaging grads over multiple minibatches, and only taking an optimizer step after N accumulations). By default, the `launch_BigGAN_bs256x8.sh` script trains a
28
+ full-sized BigGAN model with a batch size of 256 and 8 gradient accumulations, for a total batch size of 2048. On 8xV100 with full-precision training (no Tensor cores), this script takes 15 days to train to 150k iterations.
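+
+ As a rough sketch (not this repo's actual training loop; the toy model, optimizer, and sizes below are illustrative stand-ins), gradient accumulation amounts to:
+
+ ```python
+ import torch
+ import torch.nn as nn
+
+ model = nn.Linear(10, 1)                        # toy stand-in for G or D
+ optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
+ num_accumulations = 8                           # e.g. BS256 x 8 -> effective 2048
+
+ optimizer.zero_grad()
+ for i in range(num_accumulations):
+     x = torch.randn(256, 10)                    # one minibatch
+     loss = model(x).pow(2).mean() / num_accumulations  # average grads over the mega-batch
+     loss.backward()                             # grads accumulate in .grad
+ optimizer.step()                                # one optimizer step per mega-batch
+ optimizer.zero_grad()
+ ```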
29
+
30
+ You will first need to figure out the maximum batch size your setup can support. The pre-trained models provided here were trained on 8xV100 (16GB VRAM each) which can support slightly more than the BS256 used by default.
31
+ Once you've determined this, you should modify the script so that the batch size times the number of gradient accumulations is equal to your desired total batch size (BigGAN defaults to 2048).
32
+
33
+ Note also that this script uses the `--load_in_mem` arg, which loads the entire (~64GB) I128.hdf5 file into RAM for faster data loading. If you don't have enough RAM to support this (probably 96GB+), remove this argument.
34
+
35
+
36
+ ## Metrics and Sampling
37
+ ![I believe I can fly!](imgs/interp_sample.jpg?raw=true "I believe I can fly!")
38
+
39
+ During training, this script will output logs with training metrics and test metrics, will save multiple copies (2 most recent and 5 highest-scoring) of the model weights/optimizer params, and will produce samples and interpolations every time it saves weights.
40
+ The logs folder contains scripts to process these logs and plot the results using MATLAB (sorry not sorry).
41
+
42
+ After training, one can use `sample.py` to produce additional samples and interpolations, test with different truncation values, batch sizes, number of standing stat accumulations, etc. See the `sample_BigGAN_bs256x8.sh` script for an example.
43
+
44
+ By default, everything is saved to weights/samples/logs/data folders which are assumed to be in the same folder as this repo.
45
+ You can point all of these to a different base folder using the `--base_root` argument, or pick specific locations for each of these with their respective arguments (e.g. `--logs_root`).
46
+
47
+ We include scripts to run BigGAN-deep, but we have not fully trained a model using them, so consider them untested. Additionally, we include scripts to run a model on CIFAR, and to run SA-GAN (with EMA) and SN-GAN on ImageNet. The SA-GAN code assumes you have 4xTitanX (or equivalent in terms of GPU RAM) and will run with a batch size of 128 and 2 gradient accumulations.
48
+
49
+ ## An Important Note on Inception Metrics
50
+ This repo uses the PyTorch in-built inception network to calculate IS and FID.
51
+ These scores are different from the scores you would get using the official TF inception code, and are only for monitoring purposes!
52
+ Run sample.py on your model, with the `--sample_npz` argument, then run inception_tf13 to calculate the actual TensorFlow IS. Note that you will need to have TensorFlow 1.3 or earlier installed, as TF1.4+ breaks the original IS code.
53
+
54
+ ## Pretrained models
55
+ ![PyTorch Inception Score and FID](imgs/IS_FID.png)
56
+ We include two pretrained model checkpoints (with G, D, the EMA copy of G, the optimizers, and the state dict):
57
+ - The main checkpoint is for a BigGAN trained on ImageNet at 128x128, using BS256 and 8 gradient accumulations, taken just before collapse, with a TF Inception Score of 97.35 +/- 1.79: [LINK](https://drive.google.com/open?id=1nAle7FCVFZdix2--ks0r5JBkFnKw8ctW)
58
+ - An earlier checkpoint of the first model (100k G iters), at high performance but well before collapse, which may be easier to fine-tune: [LINK](https://drive.google.com/open?id=1dmZrcVJUAWkPBGza_XgswSuT-UODXZcO)
59
+
60
+
61
+
62
+ Pretrained models for Places-365 coming soon.
63
+
64
+ This repo also contains scripts for porting the original TFHub BigGAN Generator weights to PyTorch. See the scripts in the TFHub folder for more details.
65
+
66
+ ## Fine-tuning, Using Your Own Dataset, or Making New Training Functions
67
+ ![That's deep, man](imgs/DeepSamples.png?raw=true "Deep Samples")
68
+
69
+ If you wish to resume interrupted training or fine-tune a pre-trained model, run the same launch script but with the `--resume` argument added.
70
+ Experiment names are automatically generated from the configuration, but can be overridden using the `--experiment_name` arg (for example, if you wish to fine-tune a model using modified optimizer settings).
71
+
72
+ To prep your own dataset, you will need to add it to datasets.py and modify the convenience dicts in utils.py (dset_dict, imsize_dict, root_dict, nclass_dict, classes_per_sheet_dict) to have the appropriate metadata for your dataset.
73
+ Repeat the process in prepare_data.sh (optionally produce an HDF5 preprocessed copy, and calculate the Inception Moments for FID).
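+
+ As a rough sketch (the dataset name, values, and loader below are hypothetical, not part of this repo), the utils.py edits would look like:
+
+ ```python
+ # Hypothetical entries for a new 64x64, 10-class dataset called 'MyData64',
+ # assuming utils.py exposes the loader classes from datasets.py as e.g. dset.ImageFolder.
+ dset_dict['MyData64'] = dset.ImageFolder        # loader class
+ imsize_dict['MyData64'] = 64                    # training resolution
+ root_dict['MyData64'] = 'MyData'                # folder under the data root
+ nclass_dict['MyData64'] = 10                    # number of classes
+ classes_per_sheet_dict['MyData64'] = 10         # classes per sample sheet
+ ```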
74
+
75
+ By default, the training script will save the top 5 best checkpoints as measured by Inception Score.
76
+ For datasets other than ImageNet, Inception Score can be a very poor measure of quality, so you will likely want to use `--which_best FID` instead.
77
+
78
+ To use your own training function (e.g. train a BigVAE): either modify train_fns.GAN_training_function or add a new train fn and add it after the `if config['which_train_fn'] == 'GAN':` line in `train.py`.
79
+
80
+
81
+ ## Neat Stuff
82
+ - We include the full training and metrics logs [here](https://drive.google.com/open?id=1ZhY9Mg2b_S4QwxNmt57aXJ9FOC3ZN1qb) for reference. I've found that one of the hardest things about re-implementing a paper can be checking if the logs line up early in training,
83
+ especially if training takes multiple weeks. Hopefully these will be helpful for future work.
84
+ - We include an accelerated FID calculation--the original scipy version can require upwards of 10 minutes to calculate the matrix sqrt; this version uses an accelerated PyTorch implementation to calculate it in under a second.
85
+ - We include an accelerated, low-memory consumption ortho reg implementation.
86
+ - By default, we only compute the top singular value (the spectral norm), but this code supports computing more SVs through the `--num_G_SVs` argument.
87
+
88
+ ## Key Differences Between This Code And The Original BigGAN
89
+ - We use the optimizer settings from SA-GAN (G_lr=1e-4, D_lr=4e-4, num_D_steps=1, as opposed to BigGAN's G_lr=5e-5, D_lr=2e-4, num_D_steps=2).
90
+ While slightly less performant, this was the first corner we cut to bring training times down.
91
+ - By default, we do not use Cross-Replica BatchNorm (AKA Synced BatchNorm).
92
+ The two variants we tried (a custom, naive one and the one included in this repo) have slightly different gradients (albeit identical forward passes) from the built-in BatchNorm, which appear to be sufficient to cripple training.
93
+ - Gradient accumulation means that we update the SV estimates and the BN statistics 8 times more frequently. This means that the BN stats are much closer to standing stats, and that the singular value estimates tend to be more accurate.
94
+ Because of this, we measure metrics by default with G in test mode (using the BatchNorm running stat estimates instead of computing standing stats as in the paper). We do still support standing stats (see the sample.sh scripts).
95
+ This could also conceivably result in gradients from the earlier accumulations being stale, but in practice this does not appear to be a problem.
96
+ - The currently provided pretrained models were not trained with orthogonal regularization. Training without ortho reg seems to increase the probability that models will not be amenable to truncation,
97
+ but it looks like this particular model got a winning ticket. Regardless, we provide two highly optimized (fast and minimal memory consumption) ortho reg implementations which directly compute the ortho reg. gradients.
98
+
99
+ ## A Note On The Design Of This Repo
100
+ This code is designed from the ground up to serve as an extensible, hackable base for further research code.
101
+ We've put a lot of thought into making sure the abstractions are the *right* thickness for research--not so thick as to be impenetrable, but not so thin as to be useless.
102
+ The key idea is that if you want to experiment with a SOTA setup and make some modification (try out your own new loss function, architecture, self-attention block, etc) you should be able to easily do so just by dropping your code in one or two places, without having to worry about the rest of the codebase.
103
+ Things like the use of self.which_conv and functools.partial in the BigGAN.py model definition were put together with this in mind, as was the design of the Spectral Norm class inheritance.
104
+
105
+ With that said, this is a somewhat large codebase for a single project. While we tried to be thorough with the comments, if there's something you think could be more clear, better written, or better refactored, please feel free to raise an issue or a pull request.
106
+
107
+ ## Feature Requests
108
+ Want to work on or improve this code? There are a couple things this repo would benefit from, but which don't yet work.
109
+
110
+ - Synchronized BatchNorm (AKA Cross-Replica BatchNorm). We tried out two variants of this, but for some unknown reason it crippled training each time.
111
+ We have not tried the [apex](https://github.com/NVIDIA/apex) SyncBN as my school's servers are on ancient NVIDIA drivers that don't support it--apex would probably be a good place to start.
112
+ - Mixed precision training and making use of Tensor cores. This repo includes a naive mixed-precision Adam implementation which works early in training but leads to early collapse, and doesn't do anything to activate Tensor cores (it just reduces memory consumption).
113
+ As above, integrating [apex](https://github.com/NVIDIA/apex) into this code and employing its mixed-precision training techniques to take advantage of Tensor cores and reduce memory consumption could yield substantial speed gains.
114
+
115
+ ## Misc Notes
116
+ See [This directory](https://gist.github.com/yrevar/942d3a0ac09ec9e5eb3a) for ImageNet labels.
117
+
118
+ If you use this code, please cite
119
+ ```text
120
+ @inproceedings{
121
+ brock2018large,
122
+ title={Large Scale {GAN} Training for High Fidelity Natural Image Synthesis},
123
+ author={Andrew Brock and Jeff Donahue and Karen Simonyan},
124
+ booktitle={International Conference on Learning Representations},
125
+ year={2019},
126
+ url={https://openreview.net/forum?id=B1xsqj09Fm},
127
+ }
128
+ ```
129
+
130
+ ## Acknowledgments
131
+ Thanks to Google for the generous cloud credit donations.
132
+
133
+ [SyncBN](https://github.com/vacancy/Synchronized-BatchNorm-PyTorch) by Jiayuan Mao and Tete Xiao.
134
+
135
+ [Progress bar](https://github.com/Lasagne/Recipes/tree/master/papers/densenet) originally from Jan SchlΓΌter.
136
+
137
+ Test metrics logger from [VoxNet.](https://github.com/dimatura/voxnet)
138
+
139
+ PyTorch [implementation of cov](https://discuss.PyTorch.org/t/covariance-and-gradient-support/16217/2) from Modar M. Alfadly.
140
+
141
+ PyTorch [fast Matrix Sqrt](https://github.com/msubhransu/matrix-sqrt) for FID from Tsung-Yu Lin and Subhransu Maji.
142
+
143
+ TensorFlow Inception Score code from [OpenAI's Improved-GAN.](https://github.com/openai/improved-gan)
144
+
src/models/big/__init__.py ADDED
File without changes
src/models/big/animal_hash.py ADDED
@@ -0,0 +1,439 @@
1
+ c = ['Aardvark', 'Abyssinian', 'Affenpinscher', 'Akbash', 'Akita', 'Albatross',
2
+ 'Alligator', 'Alpaca', 'Angelfish', 'Ant', 'Anteater', 'Antelope', 'Ape',
3
+ 'Armadillo', 'Ass', 'Avocet', 'Axolotl', 'Baboon', 'Badger', 'Balinese',
4
+ 'Bandicoot', 'Barb', 'Barnacle', 'Barracuda', 'Bat', 'Beagle', 'Bear',
5
+ 'Beaver', 'Bee', 'Beetle', 'Binturong', 'Bird', 'Birman', 'Bison',
6
+ 'Bloodhound', 'Boar', 'Bobcat', 'Bombay', 'Bongo', 'Bonobo', 'Booby',
7
+ 'Budgerigar', 'Buffalo', 'Bulldog', 'Bullfrog', 'Burmese', 'Butterfly',
8
+ 'Caiman', 'Camel', 'Capybara', 'Caracal', 'Caribou', 'Cassowary', 'Cat',
9
+ 'Caterpillar', 'Catfish', 'Cattle', 'Centipede', 'Chameleon', 'Chamois',
10
+ 'Cheetah', 'Chicken', 'Chihuahua', 'Chimpanzee', 'Chinchilla', 'Chinook',
11
+ 'Chipmunk', 'Chough', 'Cichlid', 'Clam', 'Coati', 'Cobra', 'Cockroach',
12
+ 'Cod', 'Collie', 'Coral', 'Cormorant', 'Cougar', 'Cow', 'Coyote',
13
+ 'Crab', 'Crane', 'Crocodile', 'Crow', 'Curlew', 'Cuscus', 'Cuttlefish',
14
+ 'Dachshund', 'Dalmatian', 'Deer', 'Dhole', 'Dingo', 'Dinosaur', 'Discus',
15
+ 'Dodo', 'Dog', 'Dogball', 'Dogfish', 'Dolphin', 'Donkey', 'Dormouse',
16
+ 'Dove', 'Dragonfly', 'Drever', 'Duck', 'Dugong', 'Dunker', 'Dunlin',
17
+ 'Eagle', 'Earwig', 'Echidna', 'Eel', 'Eland', 'Elephant', 'ElephantSeal',
18
+ 'Elk', 'Emu', 'Falcon', 'Ferret', 'Finch', 'Fish', 'Flamingo', 'Flounder',
19
+ 'Fly', 'Fossa', 'Fox', 'Frigatebird', 'Frog', 'Galago', 'Gar', 'Gaur',
20
+ 'Gazelle', 'Gecko', 'Gerbil', 'Gharial', 'GiantPanda', 'Gibbon', 'Giraffe',
21
+ 'Gnat', 'Gnu', 'Goat', 'Goldfinch', 'Goldfish', 'Goose', 'Gopher',
22
+ 'Gorilla', 'Goshawk', 'Grasshopper', 'Greyhound', 'Grouse', 'Guanaco',
23
+ 'GuineaFowl', 'GuineaPig', 'Gull', 'Guppy', 'Hamster', 'Hare', 'Harrier',
24
+ 'Havanese', 'Hawk', 'Hedgehog', 'Heron', 'Herring', 'Himalayan',
25
+ 'Hippopotamus', 'Hornet', 'Horse', 'Human', 'Hummingbird', 'Hyena',
26
+ 'Ibis', 'Iguana', 'Impala', 'Indri', 'Insect', 'Jackal', 'Jaguar',
27
+ 'Javanese', 'Jay', 'Jellyfish', 'Kakapo', 'Kangaroo', 'Kingfisher',
28
+ 'Kiwi', 'Koala', 'KomodoDragon', 'Kouprey', 'Kudu', 'Labradoodle',
29
+ 'Ladybird', 'Lapwing', 'Lark', 'Lemming', 'Lemur', 'Leopard', 'Liger',
30
+ 'Lion', 'Lionfish', 'Lizard', 'Llama', 'Lobster', 'Locust', 'Loris',
31
+ 'Louse', 'Lynx', 'Lyrebird', 'Macaw', 'Magpie', 'Mallard', 'Maltese',
32
+ 'Manatee', 'Mandrill', 'Markhor', 'Marten', 'Mastiff', 'Mayfly', 'Meerkat',
33
+ 'Millipede', 'Mink', 'Mole', 'Molly', 'Mongoose', 'Mongrel', 'Monkey',
34
+ 'Moorhen', 'Moose', 'Mosquito', 'Moth', 'Mouse', 'Mule', 'Narwhal',
35
+ 'Neanderthal', 'Newfoundland', 'Newt', 'Nightingale', 'Numbat', 'Ocelot',
36
+ 'Octopus', 'Okapi', 'Olm', 'Opossum', 'Orang-utan', 'Oryx', 'Ostrich',
37
+ 'Otter', 'Owl', 'Ox', 'Oyster', 'Pademelon', 'Panther', 'Parrot',
38
+ 'Partridge', 'Peacock', 'Peafowl', 'Pekingese', 'Pelican', 'Penguin',
39
+ 'Persian', 'Pheasant', 'Pig', 'Pigeon', 'Pika', 'Pike', 'Piranha',
40
+ 'Platypus', 'Pointer', 'Pony', 'Poodle', 'Porcupine', 'Porpoise',
41
+ 'Possum', 'PrairieDog', 'Prawn', 'Puffin', 'Pug', 'Puma', 'Quail',
42
+ 'Quelea', 'Quetzal', 'Quokka', 'Quoll', 'Rabbit', 'Raccoon', 'Ragdoll',
43
+ 'Rail', 'Ram', 'Rat', 'Rattlesnake', 'Raven', 'RedDeer', 'RedPanda',
44
+ 'Reindeer', 'Rhinoceros', 'Robin', 'Rook', 'Rottweiler', 'Ruff',
45
+ 'Salamander', 'Salmon', 'SandDollar', 'Sandpiper', 'Saola',
46
+ 'Sardine', 'Scorpion', 'SeaLion', 'SeaUrchin', 'Seahorse',
47
+ 'Seal', 'Serval', 'Shark', 'Sheep', 'Shrew', 'Shrimp', 'Siamese',
48
+ 'Siberian', 'Skunk', 'Sloth', 'Snail', 'Snake', 'Snowshoe', 'Somali',
49
+ 'Sparrow', 'Spider', 'Sponge', 'Squid', 'Squirrel', 'Starfish', 'Starling',
50
+ 'Stingray', 'Stinkbug', 'Stoat', 'Stork', 'Swallow', 'Swan', 'Tang',
51
+ 'Tapir', 'Tarsier', 'Termite', 'Tetra', 'Tiffany', 'Tiger', 'Toad',
52
+ 'Tortoise', 'Toucan', 'Tropicbird', 'Trout', 'Tuatara', 'Turkey',
53
+ 'Turtle', 'Uakari', 'Uguisu', 'Umbrellabird', 'Viper', 'Vulture',
54
+ 'Wallaby', 'Walrus', 'Warthog', 'Wasp', 'WaterBuffalo', 'Weasel',
55
+ 'Whale', 'Whippet', 'Wildebeest', 'Wolf', 'Wolverine', 'Wombat',
56
+ 'Woodcock', 'Woodlouse', 'Woodpecker', 'Worm', 'Wrasse', 'Wren',
57
+ 'Yak', 'Zebra', 'Zebu', 'Zonkey']
58
+ a = ['able', 'above', 'absent', 'absolute', 'abstract', 'abundant', 'academic',
59
+ 'acceptable', 'accepted', 'accessible', 'accurate', 'accused', 'active',
60
+ 'actual', 'acute', 'added', 'additional', 'adequate', 'adjacent',
61
+ 'administrative', 'adorable', 'advanced', 'adverse', 'advisory',
62
+ 'aesthetic', 'afraid', 'african', 'aggregate', 'aggressive', 'agreeable',
63
+ 'agreed', 'agricultural', 'alert', 'alive', 'alleged', 'allied', 'alone',
64
+ 'alright', 'alternative', 'amateur', 'amazing', 'ambitious', 'american',
65
+ 'amused', 'ancient', 'angry', 'annoyed', 'annual', 'anonymous', 'anxious',
66
+ 'appalling', 'apparent', 'applicable', 'appropriate', 'arab', 'arbitrary',
67
+ 'architectural', 'armed', 'arrogant', 'artificial', 'artistic', 'ashamed',
68
+ 'asian', 'asleep', 'assistant', 'associated', 'atomic', 'attractive',
69
+ 'australian', 'automatic', 'autonomous', 'available', 'average',
70
+ 'awake', 'aware', 'awful', 'awkward', 'back', 'bad', 'balanced', 'bare',
71
+ 'basic', 'beautiful', 'beneficial', 'better', 'bewildered', 'big',
72
+ 'binding', 'biological', 'bitter', 'bizarre', 'black', 'blank', 'blind',
73
+ 'blonde', 'bloody', 'blue', 'blushing', 'boiling', 'bold', 'bored',
74
+ 'boring', 'bottom', 'brainy', 'brave', 'breakable', 'breezy', 'brief',
75
+ 'bright', 'brilliant', 'british', 'broad', 'broken', 'brown', 'bumpy',
76
+ 'burning', 'busy', 'calm', 'canadian', 'capable', 'capitalist', 'careful',
77
+ 'casual', 'catholic', 'causal', 'cautious', 'central', 'certain',
78
+ 'changing', 'characteristic', 'charming', 'cheap', 'cheerful', 'chemical',
79
+ 'chief', 'chilly', 'chinese', 'chosen', 'christian', 'chronic', 'chubby',
80
+ 'circular', 'civic', 'civil', 'civilian', 'classic', 'classical', 'clean',
81
+ 'clear', 'clever', 'clinical', 'close', 'closed', 'cloudy', 'clumsy',
82
+ 'coastal', 'cognitive', 'coherent', 'cold', 'collective', 'colonial',
83
+ 'colorful', 'colossal', 'coloured', 'colourful', 'combative', 'combined',
84
+ 'comfortable', 'coming', 'commercial', 'common', 'communist', 'compact',
85
+ 'comparable', 'comparative', 'compatible', 'competent', 'competitive',
86
+ 'complete', 'complex', 'complicated', 'comprehensive', 'compulsory',
87
+ 'conceptual', 'concerned', 'concrete', 'condemned', 'confident',
88
+ 'confidential', 'confused', 'conscious', 'conservation', 'conservative',
89
+ 'considerable', 'consistent', 'constant', 'constitutional',
90
+ 'contemporary', 'content', 'continental', 'continued', 'continuing',
91
+ 'continuous', 'controlled', 'controversial', 'convenient', 'conventional',
92
+ 'convinced', 'convincing', 'cooing', 'cool', 'cooperative', 'corporate',
93
+ 'correct', 'corresponding', 'costly', 'courageous', 'crazy', 'creative',
94
+ 'creepy', 'criminal', 'critical', 'crooked', 'crowded', 'crucial',
95
+ 'crude', 'cruel', 'cuddly', 'cultural', 'curious', 'curly', 'current',
96
+ 'curved', 'cute', 'daily', 'damaged', 'damp', 'dangerous', 'dark', 'dead',
97
+ 'deaf', 'deafening', 'dear', 'decent', 'decisive', 'deep', 'defeated',
98
+ 'defensive', 'defiant', 'definite', 'deliberate', 'delicate', 'delicious',
99
+ 'delighted', 'delightful', 'democratic', 'dependent', 'depressed',
100
+ 'desirable', 'desperate', 'detailed', 'determined', 'developed',
101
+ 'developing', 'devoted', 'different', 'difficult', 'digital', 'diplomatic',
102
+ 'direct', 'dirty', 'disabled', 'disappointed', 'disastrous',
103
+ 'disciplinary', 'disgusted', 'distant', 'distinct', 'distinctive',
104
+ 'distinguished', 'disturbed', 'disturbing', 'diverse', 'divine', 'dizzy',
105
+ 'domestic', 'dominant', 'double', 'doubtful', 'drab', 'dramatic',
106
+ 'dreadful', 'driving', 'drunk', 'dry', 'dual', 'due', 'dull', 'dusty',
107
+ 'dutch', 'dying', 'dynamic', 'eager', 'early', 'eastern', 'easy',
108
+ 'economic', 'educational', 'eerie', 'effective', 'efficient',
109
+ 'elaborate', 'elated', 'elderly', 'eldest', 'electoral', 'electric',
110
+ 'electrical', 'electronic', 'elegant', 'eligible', 'embarrassed',
111
+ 'embarrassing', 'emotional', 'empirical', 'empty', 'enchanting',
112
+ 'encouraging', 'endless', 'energetic', 'english', 'enormous',
113
+ 'enthusiastic', 'entire', 'entitled', 'envious', 'environmental', 'equal',
114
+ 'equivalent', 'essential', 'established', 'estimated', 'ethical',
115
+ 'ethnic', 'european', 'eventual', 'everyday', 'evident', 'evil',
116
+ 'evolutionary', 'exact', 'excellent', 'exceptional', 'excess',
117
+ 'excessive', 'excited', 'exciting', 'exclusive', 'existing', 'exotic',
118
+ 'expected', 'expensive', 'experienced', 'experimental', 'explicit',
119
+ 'extended', 'extensive', 'external', 'extra', 'extraordinary', 'extreme',
120
+ 'exuberant', 'faint', 'fair', 'faithful', 'familiar', 'famous', 'fancy',
121
+ 'fantastic', 'far', 'fascinating', 'fashionable', 'fast', 'fat', 'fatal',
122
+ 'favourable', 'favourite', 'federal', 'fellow', 'female', 'feminist',
123
+ 'few', 'fierce', 'filthy', 'final', 'financial', 'fine', 'firm', 'fiscal',
124
+ 'fit', 'fixed', 'flaky', 'flat', 'flexible', 'fluffy', 'fluttering',
125
+ 'flying', 'following', 'fond', 'foolish', 'foreign', 'formal',
126
+ 'formidable', 'forthcoming', 'fortunate', 'forward', 'fragile',
127
+ 'frail', 'frantic', 'free', 'french', 'frequent', 'fresh', 'friendly',
128
+ 'frightened', 'front', 'frozen', 'fucking', 'full', 'full-time', 'fun',
129
+ 'functional', 'fundamental', 'funny', 'furious', 'future', 'fuzzy',
130
+ 'gastric', 'gay', 'general', 'generous', 'genetic', 'gentle', 'genuine',
131
+ 'geographical', 'german', 'giant', 'gigantic', 'given', 'glad',
132
+ 'glamorous', 'gleaming', 'global', 'glorious', 'golden', 'good',
133
+ 'gorgeous', 'gothic', 'governing', 'graceful', 'gradual', 'grand',
134
+ 'grateful', 'greasy', 'great', 'greek', 'green', 'grey', 'grieving',
135
+ 'grim', 'gross', 'grotesque', 'growing', 'grubby', 'grumpy', 'guilty',
136
+ 'handicapped', 'handsome', 'happy', 'hard', 'harsh', 'head', 'healthy',
137
+ 'heavy', 'helpful', 'helpless', 'hidden', 'high', 'high-pitched',
138
+ 'hilarious', 'hissing', 'historic', 'historical', 'hollow', 'holy',
139
+ 'homeless', 'homely', 'hon', 'honest', 'horizontal', 'horrible',
140
+ 'hostile', 'hot', 'huge', 'human', 'hungry', 'hurt', 'hushed', 'husky',
141
+ 'icy', 'ideal', 'identical', 'ideological', 'ill', 'illegal',
142
+ 'imaginative', 'immediate', 'immense', 'imperial', 'implicit',
143
+ 'important', 'impossible', 'impressed', 'impressive', 'improved',
144
+ 'inadequate', 'inappropriate', 'inc', 'inclined', 'increased',
145
+ 'increasing', 'incredible', 'independent', 'indian', 'indirect',
146
+ 'individual', 'industrial', 'inevitable', 'influential', 'informal',
147
+ 'inherent', 'initial', 'injured', 'inland', 'inner', 'innocent',
148
+ 'innovative', 'inquisitive', 'instant', 'institutional', 'insufficient',
149
+ 'intact', 'integral', 'integrated', 'intellectual', 'intelligent',
150
+ 'intense', 'intensive', 'interested', 'interesting', 'interim',
151
+ 'interior', 'intermediate', 'internal', 'international', 'intimate',
152
+ 'invisible', 'involved', 'iraqi', 'irish', 'irrelevant', 'islamic',
153
+ 'isolated', 'israeli', 'italian', 'itchy', 'japanese', 'jealous',
154
+ 'jewish', 'jittery', 'joint', 'jolly', 'joyous', 'judicial', 'juicy',
155
+ 'junior', 'just', 'keen', 'key', 'kind', 'known', 'korean', 'labour',
156
+ 'large', 'large-scale', 'late', 'latin', 'lazy', 'leading', 'left',
157
+ 'legal', 'legislative', 'legitimate', 'lengthy', 'lesser', 'level',
158
+ 'lexical', 'liable', 'liberal', 'light', 'like', 'likely', 'limited',
159
+ 'linear', 'linguistic', 'liquid', 'literary', 'little', 'live', 'lively',
160
+ 'living', 'local', 'logical', 'lonely', 'long', 'long-term', 'loose',
161
+ 'lost', 'loud', 'lovely', 'low', 'loyal', 'ltd', 'lucky', 'mad',
162
+ 'magenta', 'magic', 'magnetic', 'magnificent', 'main', 'major', 'male',
163
+ 'mammoth', 'managerial', 'managing', 'manual', 'many', 'marginal',
164
+ 'marine', 'marked', 'married', 'marvellous', 'marxist', 'mass', 'massive',
165
+ 'mathematical', 'mature', 'maximum', 'mean', 'meaningful', 'mechanical',
166
+ 'medical', 'medieval', 'melodic', 'melted', 'mental', 'mere',
167
+ 'metropolitan', 'mid', 'middle', 'middle-class', 'mighty', 'mild',
168
+ 'military', 'miniature', 'minimal', 'minimum', 'ministerial', 'minor',
169
+ 'miserable', 'misleading', 'missing', 'misty', 'mixed', 'moaning',
170
+ 'mobile', 'moderate', 'modern', 'modest', 'molecular', 'monetary',
171
+ 'monthly', 'moral', 'motionless', 'muddy', 'multiple', 'mushy',
172
+ 'musical', 'mute', 'mutual', 'mysterious', 'naked', 'narrow', 'nasty',
173
+ 'national', 'native', 'natural', 'naughty', 'naval', 'near', 'nearby',
174
+ 'neat', 'necessary', 'negative', 'neighbouring', 'nervous', 'net',
175
+ 'neutral', 'new', 'nice', 'nineteenth-century', 'noble', 'noisy',
176
+ 'normal', 'northern', 'nosy', 'notable', 'novel', 'nuclear', 'numerous',
177
+ 'nursing', 'nutritious', 'nutty', 'obedient', 'objective', 'obliged',
178
+ 'obnoxious', 'obvious', 'occasional', 'occupational', 'odd', 'official',
179
+ 'ok', 'okay', 'old', 'old-fashioned', 'olympic', 'only', 'open',
180
+ 'operational', 'opposite', 'optimistic', 'oral', 'orange', 'ordinary',
181
+ 'organic', 'organisational', 'original', 'orthodox', 'other', 'outdoor',
182
+ 'outer', 'outrageous', 'outside', 'outstanding', 'overall', 'overseas',
183
+ 'overwhelming', 'painful', 'pale', 'palestinian', 'panicky', 'parallel',
184
+ 'parental', 'parliamentary', 'part-time', 'partial', 'particular',
185
+ 'passing', 'passive', 'past', 'patient', 'payable', 'peaceful',
186
+ 'peculiar', 'perfect', 'permanent', 'persistent', 'personal', 'petite',
187
+ 'philosophical', 'physical', 'pink', 'plain', 'planned', 'plastic',
188
+ 'pleasant', 'pleased', 'poised', 'polish', 'polite', 'political', 'poor',
189
+ 'popular', 'positive', 'possible', 'post-war', 'potential', 'powerful',
190
+ 'practical', 'precious', 'precise', 'preferred', 'pregnant',
191
+ 'preliminary', 'premier', 'prepared', 'present', 'presidential',
192
+ 'pretty', 'previous', 'prickly', 'primary', 'prime', 'primitive',
193
+ 'principal', 'printed', 'prior', 'private', 'probable', 'productive',
194
+ 'professional', 'profitable', 'profound', 'progressive', 'prominent',
195
+ 'promising', 'proper', 'proposed', 'prospective', 'protective',
196
+ 'protestant', 'proud', 'provincial', 'psychiatric', 'psychological',
197
+ 'public', 'puny', 'pure', 'purple', 'purring', 'puzzled', 'quaint',
198
+ 'qualified', 'quick', 'quickest', 'quiet', 'racial', 'radical', 'rainy',
199
+ 'random', 'rapid', 'rare', 'raspy', 'rational', 'ratty', 'raw', 'ready',
200
+ 'real', 'realistic', 'rear', 'reasonable', 'recent', 'red', 'reduced',
201
+ 'redundant', 'regional', 'registered', 'regular', 'regulatory', 'related',
202
+ 'relative', 'relaxed', 'relevant', 'reliable', 'relieved', 'religious',
203
+ 'reluctant', 'remaining', 'remarkable', 'remote', 'renewed',
204
+ 'representative', 'repulsive', 'required', 'resident', 'residential',
205
+ 'resonant', 'respectable', 'respective', 'responsible', 'resulting',
206
+ 'retail', 'retired', 'revolutionary', 'rich', 'ridiculous', 'right',
207
+ 'rigid', 'ripe', 'rising', 'rival', 'roasted', 'robust', 'rolling',
208
+ 'roman', 'romantic', 'rotten', 'rough', 'round', 'royal', 'rubber',
209
+ 'rude', 'ruling', 'running', 'rural', 'russian', 'sacred', 'sad', 'safe',
210
+ 'salty', 'satisfactory', 'satisfied', 'scared', 'scary', 'scattered',
211
+ 'scientific', 'scornful', 'scottish', 'scrawny', 'screeching',
212
+ 'secondary', 'secret', 'secure', 'select', 'selected', 'selective',
213
+ 'selfish', 'semantic', 'senior', 'sensible', 'sensitive', 'separate',
214
+ 'serious', 'severe', 'sexual', 'shaggy', 'shaky', 'shallow', 'shared',
215
+ 'sharp', 'sheer', 'shiny', 'shivering', 'shocked', 'short', 'short-term',
216
+ 'shrill', 'shy', 'sick', 'significant', 'silent', 'silky', 'silly',
217
+ 'similar', 'simple', 'single', 'skilled', 'skinny', 'sleepy', 'slight',
218
+ 'slim', 'slimy', 'slippery', 'slow', 'small', 'smart', 'smiling',
219
+ 'smoggy', 'smooth', 'so-called', 'social', 'socialist', 'soft', 'solar',
220
+ 'sole', 'solid', 'sophisticated', 'sore', 'sorry', 'sound', 'sour',
221
+ 'southern', 'soviet', 'spanish', 'spare', 'sparkling', 'spatial',
222
+ 'special', 'specific', 'specified', 'spectacular', 'spicy', 'spiritual',
223
+ 'splendid', 'spontaneous', 'sporting', 'spotless', 'spotty', 'square',
224
+ 'squealing', 'stable', 'stale', 'standard', 'static', 'statistical',
225
+ 'statutory', 'steady', 'steep', 'sticky', 'stiff', 'still', 'stingy',
226
+ 'stormy', 'straight', 'straightforward', 'strange', 'strategic',
227
+ 'strict', 'striking', 'striped', 'strong', 'structural', 'stuck',
228
+ 'stupid', 'subjective', 'subsequent', 'substantial', 'subtle',
229
+ 'successful', 'successive', 'sudden', 'sufficient', 'suitable',
230
+ 'sunny', 'super', 'superb', 'superior', 'supporting', 'supposed',
231
+ 'supreme', 'sure', 'surprised', 'surprising', 'surrounding',
232
+ 'surviving', 'suspicious', 'sweet', 'swift', 'swiss', 'symbolic',
233
+ 'sympathetic', 'systematic', 'tall', 'tame', 'tan', 'tart',
234
+ 'tasteless', 'tasty', 'technical', 'technological', 'teenage',
235
+ 'temporary', 'tender', 'tense', 'terrible', 'territorial', 'testy',
236
+ 'then', 'theoretical', 'thick', 'thin', 'thirsty', 'thorough',
237
+ 'thoughtful', 'thoughtless', 'thundering', 'tight', 'tiny', 'tired',
238
+ 'top', 'tory', 'total', 'tough', 'toxic', 'traditional', 'tragic',
239
+ 'tremendous', 'tricky', 'tropical', 'troubled', 'turkish', 'typical',
240
+ 'ugliest', 'ugly', 'ultimate', 'unable', 'unacceptable', 'unaware',
241
+ 'uncertain', 'unchanged', 'uncomfortable', 'unconscious', 'underground',
242
+ 'underlying', 'unemployed', 'uneven', 'unexpected', 'unfair',
243
+ 'unfortunate', 'unhappy', 'uniform', 'uninterested', 'unique', 'united',
244
+ 'universal', 'unknown', 'unlikely', 'unnecessary', 'unpleasant',
245
+ 'unsightly', 'unusual', 'unwilling', 'upper', 'upset', 'uptight',
246
+ 'urban', 'urgent', 'used', 'useful', 'useless', 'usual', 'vague',
247
+ 'valid', 'valuable', 'variable', 'varied', 'various', 'varying', 'vast',
248
+ 'verbal', 'vertical', 'very', 'victorian', 'victorious', 'video-taped',
249
+ 'violent', 'visible', 'visiting', 'visual', 'vital', 'vivacious',
250
+ 'vivid', 'vocational', 'voiceless', 'voluntary', 'vulnerable',
251
+ 'wandering', 'warm', 'wasteful', 'watery', 'weak', 'wealthy', 'weary',
252
+ 'wee', 'weekly', 'weird', 'welcome', 'well', 'well-known', 'welsh',
253
+ 'western', 'wet', 'whispering', 'white', 'whole', 'wicked', 'wide',
254
+ 'wide-eyed', 'widespread', 'wild', 'willing', 'wise', 'witty',
255
+ 'wonderful', 'wooden', 'working', 'working-class', 'worldwide',
256
+ 'worried', 'worrying', 'worthwhile', 'worthy', 'written', 'wrong',
257
+ 'yellow', 'young', 'yummy', 'zany', 'zealous']
258
+ b = ['abiding', 'accelerating', 'accepting', 'accomplishing', 'achieving',
259
+ 'acquiring', 'acteding', 'activating', 'adapting', 'adding', 'addressing',
260
+ 'administering', 'admiring', 'admiting', 'adopting', 'advising', 'affording',
261
+ 'agreeing', 'alerting', 'alighting', 'allowing', 'altereding', 'amusing',
262
+ 'analyzing', 'announcing', 'annoying', 'answering', 'anticipating',
263
+ 'apologizing', 'appearing', 'applauding', 'applieding', 'appointing',
264
+ 'appraising', 'appreciating', 'approving', 'arbitrating', 'arguing',
265
+ 'arising', 'arranging', 'arresting', 'arriving', 'ascertaining', 'asking',
266
+ 'assembling', 'assessing', 'assisting', 'assuring', 'attaching', 'attacking',
267
+ 'attaining', 'attempting', 'attending', 'attracting', 'auditeding', 'avoiding',
268
+ 'awaking', 'backing', 'baking', 'balancing', 'baning', 'banging', 'baring',
269
+ 'bating', 'bathing', 'battling', 'bing', 'beaming', 'bearing', 'beating',
270
+ 'becoming', 'beging', 'begining', 'behaving', 'beholding', 'belonging',
271
+ 'bending', 'beseting', 'beting', 'biding', 'binding', 'biting', 'bleaching',
272
+ 'bleeding', 'blessing', 'blinding', 'blinking', 'bloting', 'blowing',
273
+ 'blushing', 'boasting', 'boiling', 'bolting', 'bombing', 'booking',
274
+ 'boring', 'borrowing', 'bouncing', 'bowing', 'boxing', 'braking',
275
+ 'branching', 'breaking', 'breathing', 'breeding', 'briefing', 'bringing',
276
+ 'broadcasting', 'bruising', 'brushing', 'bubbling', 'budgeting', 'building',
277
+ 'bumping', 'burning', 'bursting', 'burying', 'busting', 'buying', 'buzing',
278
+ 'calculating', 'calling', 'camping', 'caring', 'carrying', 'carving',
279
+ 'casting', 'cataloging', 'catching', 'causing', 'challenging', 'changing',
280
+ 'charging', 'charting', 'chasing', 'cheating', 'checking', 'cheering',
281
+ 'chewing', 'choking', 'choosing', 'choping', 'claiming', 'claping',
282
+ 'clarifying', 'classifying', 'cleaning', 'clearing', 'clinging', 'cliping',
283
+ 'closing', 'clothing', 'coaching', 'coiling', 'collecting', 'coloring',
284
+ 'combing', 'coming', 'commanding', 'communicating', 'comparing', 'competing',
285
+ 'compiling', 'complaining', 'completing', 'composing', 'computing',
286
+ 'conceiving', 'concentrating', 'conceptualizing', 'concerning', 'concluding',
287
+ 'conducting', 'confessing', 'confronting', 'confusing', 'connecting',
288
+ 'conserving', 'considering', 'consisting', 'consolidating', 'constructing',
289
+ 'consulting', 'containing', 'continuing', 'contracting', 'controling',
290
+ 'converting', 'coordinating', 'copying', 'correcting', 'correlating',
291
+ 'costing', 'coughing', 'counseling', 'counting', 'covering', 'cracking',
292
+ 'crashing', 'crawling', 'creating', 'creeping', 'critiquing', 'crossing',
293
+ 'crushing', 'crying', 'curing', 'curling', 'curving', 'cuting', 'cycling',
294
+ 'daming', 'damaging', 'dancing', 'daring', 'dealing', 'decaying', 'deceiving',
295
+ 'deciding', 'decorating', 'defining', 'delaying', 'delegating', 'delighting',
296
+ 'delivering', 'demonstrating', 'depending', 'describing', 'deserting',
297
+ 'deserving', 'designing', 'destroying', 'detailing', 'detecting',
298
+ 'determining', 'developing', 'devising', 'diagnosing', 'diging',
299
+ 'directing', 'disagreing', 'disappearing', 'disapproving', 'disarming',
300
+ 'discovering', 'disliking', 'dispensing', 'displaying', 'disproving',
301
+ 'dissecting', 'distributing', 'diving', 'diverting', 'dividing', 'doing',
302
+ 'doubling', 'doubting', 'drafting', 'draging', 'draining', 'dramatizing',
303
+ 'drawing', 'dreaming', 'dressing', 'drinking', 'driping', 'driving',
304
+ 'dropping', 'drowning', 'druming', 'drying', 'dusting', 'dwelling',
305
+ 'earning', 'eating', 'editeding', 'educating', 'eliminating',
306
+ 'embarrassing', 'employing', 'emptying', 'enacteding', 'encouraging',
307
+ 'ending', 'enduring', 'enforcing', 'engineering', 'enhancing',
308
+ 'enjoying', 'enlisting', 'ensuring', 'entering', 'entertaining',
309
+ 'escaping', 'establishing', 'estimating', 'evaluating', 'examining',
310
+ 'exceeding', 'exciting', 'excusing', 'executing', 'exercising', 'exhibiting',
311
+ 'existing', 'expanding', 'expecting', 'expediting', 'experimenting',
312
+ 'explaining', 'exploding', 'expressing', 'extending', 'extracting',
313
+ 'facing', 'facilitating', 'fading', 'failing', 'fancying', 'fastening',
314
+ 'faxing', 'fearing', 'feeding', 'feeling', 'fencing', 'fetching', 'fighting',
315
+ 'filing', 'filling', 'filming', 'finalizing', 'financing', 'finding',
316
+ 'firing', 'fiting', 'fixing', 'flaping', 'flashing', 'fleing', 'flinging',
317
+ 'floating', 'flooding', 'flowing', 'flowering', 'flying', 'folding',
318
+ 'following', 'fooling', 'forbiding', 'forcing', 'forecasting', 'foregoing',
319
+ 'foreseing', 'foretelling', 'forgeting', 'forgiving', 'forming',
320
+ 'formulating', 'forsaking', 'framing', 'freezing', 'frightening', 'frying',
321
+ 'gathering', 'gazing', 'generating', 'geting', 'giving', 'glowing', 'gluing',
322
+ 'going', 'governing', 'grabing', 'graduating', 'grating', 'greasing', 'greeting',
323
+ 'grinning', 'grinding', 'griping', 'groaning', 'growing', 'guaranteeing',
324
+ 'guarding', 'guessing', 'guiding', 'hammering', 'handing', 'handling',
325
+ 'handwriting', 'hanging', 'happening', 'harassing', 'harming', 'hating',
326
+ 'haunting', 'heading', 'healing', 'heaping', 'hearing', 'heating', 'helping',
327
+ 'hiding', 'hitting', 'holding', 'hooking', 'hoping', 'hopping', 'hovering',
328
+ 'hugging', 'humming', 'hunting', 'hurrying', 'hurting', 'hypothesizing',
329
+ 'identifying', 'ignoring', 'illustrating', 'imagining', 'implementing',
330
+ 'impressing', 'improving', 'improvising', 'including', 'increasing',
331
+ 'inducing', 'influencing', 'informing', 'initiating', 'injecting',
332
+ 'injuring', 'inlaying', 'innovating', 'inputing', 'inspecting',
333
+ 'inspiring', 'installing', 'instituting', 'instructing', 'insuring',
334
+ 'integrating', 'intending', 'intensifying', 'interesting',
335
+ 'interfering', 'interlaying', 'interpreting', 'interrupting',
336
+ 'interviewing', 'introducing', 'inventing', 'inventorying',
337
+ 'investigating', 'inviting', 'irritating', 'itching', 'jailing',
338
+ 'jamming', 'jogging', 'joining', 'joking', 'judging', 'juggling', 'jumping',
339
+ 'justifying', 'keeping', 'kepting', 'kicking', 'killing', 'kissing', 'kneeling',
340
+ 'kniting', 'knocking', 'knotting', 'knowing', 'labeling', 'landing', 'lasting',
341
+ 'laughing', 'launching', 'laying', 'leading', 'leaning', 'leaping', 'learning',
342
+ 'leaving', 'lecturing', 'leding', 'lending', 'leting', 'leveling',
343
+ 'licensing', 'licking', 'lying', 'lifteding', 'lighting', 'lightening',
344
+ 'liking', 'listing', 'listening', 'living', 'loading', 'locating',
345
+ 'locking', 'loging', 'longing', 'looking', 'losing', 'loving',
346
+ 'maintaining', 'making', 'maning', 'managing', 'manipulating',
347
+ 'manufacturing', 'mapping', 'marching', 'marking', 'marketing',
348
+ 'marrying', 'matching', 'mating', 'mattering', 'meaning', 'measuring',
349
+ 'meddling', 'mediating', 'meeting', 'melting', 'memorizing',
350
+ 'mending', 'mentoring', 'milking', 'mining', 'misleading', 'missing',
351
+ 'misspelling', 'mistaking', 'misunderstanding', 'mixing', 'moaning',
352
+ 'modeling', 'modifying', 'monitoring', 'mooring', 'motivating',
353
+ 'mourning', 'moving', 'mowing', 'muddling', 'muging', 'multiplying',
354
+ 'murdering', 'nailing', 'naming', 'navigating', 'needing', 'negotiating',
355
+ 'nesting', 'noding', 'nominating', 'normalizing', 'noting', 'noticing',
356
+ 'numbering', 'obeying', 'objecting', 'observing', 'obtaining', 'occuring',
357
+ 'offending', 'offering', 'officiating', 'opening', 'operating', 'ordering',
358
+ 'organizing', 'orienteding', 'originating', 'overcoming', 'overdoing',
359
+ 'overdrawing', 'overflowing', 'overhearing', 'overtaking', 'overthrowing',
360
+ 'owing', 'owning', 'packing', 'paddling', 'painting', 'parking', 'parting',
361
+ 'participating', 'passing', 'pasting', 'pating', 'pausing', 'paying',
362
+ 'pecking', 'pedaling', 'peeling', 'peeping', 'perceiving', 'perfecting',
363
+ 'performing', 'permiting', 'persuading', 'phoning', 'photographing',
364
+ 'picking', 'piloting', 'pinching', 'pining', 'pinpointing', 'pioneering',
365
+ 'placing', 'planing', 'planting', 'playing', 'pleading', 'pleasing',
366
+ 'plugging', 'pointing', 'poking', 'polishing', 'poping', 'possessing',
367
+ 'posting', 'pouring', 'practicing', 'praiseding', 'praying', 'preaching',
368
+ 'preceding', 'predicting', 'prefering', 'preparing', 'prescribing',
369
+ 'presenting', 'preserving', 'preseting', 'presiding', 'pressing',
370
+ 'pretending', 'preventing', 'pricking', 'printing', 'processing',
371
+ 'procuring', 'producing', 'professing', 'programing', 'progressing',
372
+ 'projecting', 'promising', 'promoting', 'proofreading', 'proposing',
373
+ 'protecting', 'proving', 'providing', 'publicizing', 'pulling', 'pumping',
374
+ 'punching', 'puncturing', 'punishing', 'purchasing', 'pushing', 'puting',
375
+ 'qualifying', 'questioning', 'queuing', 'quiting', 'racing', 'radiating',
376
+ 'raining', 'raising', 'ranking', 'rating', 'reaching', 'reading',
377
+ 'realigning', 'realizing', 'reasoning', 'receiving', 'recognizing',
378
+ 'recommending', 'reconciling', 'recording', 'recruiting', 'reducing',
379
+ 'referring', 'reflecting', 'refusing', 'regreting', 'regulating',
380
+ 'rehabilitating', 'reigning', 'reinforcing', 'rejecting', 'rejoicing',
381
+ 'relating', 'relaxing', 'releasing', 'relying', 'remaining', 'remembering',
382
+ 'reminding', 'removing', 'rendering', 'reorganizing', 'repairing',
383
+ 'repeating', 'replacing', 'replying', 'reporting', 'representing',
384
+ 'reproducing', 'requesting', 'rescuing', 'researching', 'resolving',
385
+ 'responding', 'restoreding', 'restructuring', 'retiring', 'retrieving',
386
+ 'returning', 'reviewing', 'revising', 'rhyming', 'riding',
387
+ 'ringing', 'rinsing', 'rising', 'risking', 'robing', 'rocking', 'rolling',
388
+ 'roting', 'rubing', 'ruining', 'ruling', 'runing', 'rushing', 'sacking',
389
+ 'sailing', 'satisfying', 'saving', 'sawing', 'saying', 'scaring',
390
+ 'scattering', 'scheduling', 'scolding', 'scorching', 'scraping',
391
+ 'scratching', 'screaming', 'screwing', 'scribbling', 'scrubing',
392
+ 'sealing', 'searching', 'securing', 'seing', 'seeking', 'selecting',
393
+ 'selling', 'sending', 'sensing', 'separating', 'serving', 'servicing',
394
+ 'seting', 'settling', 'sewing', 'shading', 'shaking', 'shaping',
395
+ 'sharing', 'shaving', 'shearing', 'sheding', 'sheltering', 'shining',
396
+ 'shivering', 'shocking', 'shoing', 'shooting', 'shoping', 'showing',
397
+ 'shrinking', 'shruging', 'shuting', 'sighing', 'signing', 'signaling',
398
+ 'simplifying', 'sining', 'singing', 'sinking', 'siping', 'siting',
399
+ 'sketching', 'skiing', 'skiping', 'slaping', 'slaying', 'sleeping',
400
+ 'sliding', 'slinging', 'slinking', 'sliping', 'sliting', 'slowing',
401
+ 'smashing', 'smelling', 'smiling', 'smiting', 'smoking', 'snatching',
402
+ 'sneaking', 'sneezing', 'sniffing', 'snoring', 'snowing', 'soaking',
403
+ 'solving', 'soothing', 'soothsaying', 'sorting', 'sounding', 'sowing',
404
+ 'sparing', 'sparking', 'sparkling', 'speaking', 'specifying', 'speeding',
405
+ 'spelling', 'spending', 'spilling', 'spining', 'spiting', 'spliting',
406
+ 'spoiling', 'spoting', 'spraying', 'spreading', 'springing', 'sprouting',
407
+ 'squashing', 'squeaking', 'squealing', 'squeezing', 'staining', 'stamping',
408
+ 'standing', 'staring', 'starting', 'staying', 'stealing', 'steering',
409
+ 'stepping', 'sticking', 'stimulating', 'stinging', 'stinking', 'stirring',
410
+ 'stitching', 'stoping', 'storing', 'straping', 'streamlining',
411
+ 'strengthening', 'stretching', 'striding', 'striking', 'stringing',
412
+ 'stripping', 'striving', 'stroking', 'structuring', 'studying',
413
+ 'stuffing', 'subleting', 'subtracting', 'succeeding', 'sucking',
414
+ 'suffering', 'suggesting', 'suiting', 'summarizing', 'supervising',
415
+ 'supplying', 'supporting', 'supposing', 'surprising', 'surrounding',
416
+ 'suspecting', 'suspending', 'swearing', 'sweating', 'sweeping', 'swelling',
417
+ 'swimming', 'swinging', 'switching', 'symbolizing', 'synthesizing',
418
+ 'systemizing', 'tabulating', 'taking', 'talking', 'taming', 'taping',
419
+ 'targeting', 'tasting', 'teaching', 'tearing', 'teasing', 'telephoning',
420
+ 'telling', 'tempting', 'terrifying', 'testing', 'thanking', 'thawing',
421
+ 'thinking', 'thriving', 'throwing', 'thrusting', 'ticking', 'tickling',
422
+ 'tying', 'timing', 'tiping', 'tiring', 'touching', 'touring', 'towing',
423
+ 'tracing', 'trading', 'training', 'transcribing', 'transfering',
424
+ 'transforming', 'translating', 'transporting', 'traping', 'traveling',
425
+ 'treading', 'treating', 'trembling', 'tricking', 'triping', 'troting',
426
+ 'troubling', 'troubleshooting', 'trusting', 'trying', 'tuging', 'tumbling',
427
+ 'turning', 'tutoring', 'twisting', 'typing', 'undergoing', 'understanding',
428
+ 'undertaking', 'undressing', 'unfastening', 'unifying', 'uniting',
429
+ 'unlocking', 'unpacking', 'untidying', 'updating', 'upgrading',
430
+ 'upholding', 'upseting', 'using', 'utilizing', 'vanishing', 'verbalizing',
431
+ 'verifying', 'vexing', 'visiting', 'wailing', 'waiting', 'waking',
432
+ 'walking', 'wandering', 'wanting', 'warming', 'warning', 'washing',
433
+ 'wasting', 'watching', 'watering', 'waving', 'wearing', 'weaving',
434
+ 'wedding', 'weeping', 'weighing', 'welcoming', 'wending', 'weting',
435
+ 'whining', 'whiping', 'whirling', 'whispering', 'whistling', 'wining',
436
+ 'winding', 'winking', 'wiping', 'wishing', 'withdrawing', 'withholding',
437
+ 'withstanding', 'wobbling', 'wondering', 'working', 'worrying', 'wrapping',
438
+ 'wrecking', 'wrestling', 'wriggling', 'wringing', 'writing', 'x-raying',
439
+ 'yawning', 'yelling', 'zipping', 'zooming']
src/models/big/cheat sheet ADDED
@@ -0,0 +1,30 @@
1
+ from big.BigGAN2 import Generator, Discriminator
2
+
3
+ from big.losses import generator_loss, discriminator_loss
4
+
5
+ generator = Generator().cuda()
6
+ discriminator = Discriminator().cuda()
7
+
8
+ label_transformed_fake = label_fc_net(label_fake)
9
+ label_transformed_real = label_fc_net(label_real)
10
+
11
+ generated_images = generator(decoder_input, label_transformed_fake)
12
+
13
+ #disc training
14
+
15
+ prediction_fake = discriminator(generated_images.detach(), label_transformed_fake).view(-1)
16
+ prediction_real = discriminator(images, label_transformed_real).view(-1)
17
+
18
+ d_loss_real, d_loss_fake = discriminator_loss(prediction_fake, prediction_real)
+ (d_loss_real + d_loss_fake).backward()
19
+
20
+ discriminator.optim.step()
21
+
22
+
23
+ #gen training
24
+
25
+ prediction = discriminator(generated_images,label_transformed_fake).view(-1)
26
+
27
+ g_loss = generator_loss(prediction)
28
+ g_loss.backward()
29
+
30
+ generator.optim.step()
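
The cheat sheet above is deliberately skeletal: it omits zero_grad calls, device placement, and the sampling of decoder_input and the fake labels. Below is a minimal sketch of one full training step built from the same pieces, assuming (as the snippet does) that label_fc_net is an nn.Module mapping labels into the conditioning space and that generator/discriminator each carry an .optim optimizer; the latent size and the label sampling are placeholders, not repo values, and generator_loss/discriminator_loss are the imports from big.losses above.

import torch

def train_step(generator, discriminator, label_fc_net, images, label_real, z_dim=128):
    batch = images.size(0)
    # Placeholders: sample latents, reuse shuffled real labels as fake labels
    decoder_input = torch.randn(batch, z_dim, device=images.device)
    label_fake = label_real[torch.randperm(batch, device=images.device)]
    label_t_fake = label_fc_net(label_fake)
    label_t_real = label_fc_net(label_real)

    generated_images = generator(decoder_input, label_t_fake)

    # Discriminator step: detach fakes so gradients stop at the generator
    discriminator.optim.zero_grad()
    prediction_fake = discriminator(generated_images.detach(), label_t_fake).view(-1)
    prediction_real = discriminator(images, label_t_real).view(-1)
    d_loss_real, d_loss_fake = discriminator_loss(prediction_fake, prediction_real)
    (d_loss_real + d_loss_fake).backward()
    discriminator.optim.step()

    # Generator step: same fakes, no detach, adversarial loss only
    generator.optim.zero_grad()
    prediction = discriminator(generated_images, label_t_fake).view(-1)
    g_loss = generator_loss(prediction)
    g_loss.backward()
    generator.optim.step()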
src/models/big/datasets.py ADDED
@@ -0,0 +1,362 @@
1
+ ''' Datasets
2
+ This file contains definitions for our CIFAR, ImageFolder, and HDF5 datasets
3
+ '''
4
+ import os
5
+ import os.path
6
+ import sys
7
+ from PIL import Image
8
+ import numpy as np
9
+ from tqdm import tqdm, trange
10
+
11
+ import torchvision.datasets as dset
12
+ import torchvision.transforms as transforms
13
+ from torchvision.datasets.utils import download_url, check_integrity
14
+ import torch.utils.data as data
15
+ from torch.utils.data import DataLoader
16
+
17
+ IMG_EXTENSIONS = ['.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm']
18
+
19
+
20
+ def is_image_file(filename):
21
+ """Checks if a file is an image.
22
+
23
+ Args:
24
+ filename (string): path to a file
25
+
26
+ Returns:
27
+ bool: True if the filename ends with a known image extension
28
+ """
29
+ filename_lower = filename.lower()
30
+ return any(filename_lower.endswith(ext) for ext in IMG_EXTENSIONS)
31
+
32
+
33
+ def find_classes(dir):
34
+ classes = [d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d))]
35
+ classes.sort()
36
+ class_to_idx = {classes[i]: i for i in range(len(classes))}
37
+ return classes, class_to_idx
38
+
39
+
40
+ def make_dataset(dir, class_to_idx):
41
+ images = []
42
+ dir = os.path.expanduser(dir)
43
+ for target in tqdm(sorted(os.listdir(dir))):
44
+ d = os.path.join(dir, target)
45
+ if not os.path.isdir(d):
46
+ continue
47
+
48
+ for root, _, fnames in sorted(os.walk(d)):
49
+ for fname in sorted(fnames):
50
+ if is_image_file(fname):
51
+ path = os.path.join(root, fname)
52
+ item = (path, class_to_idx[target])
53
+ images.append(item)
54
+
55
+ return images
56
+
57
+
58
+ def pil_loader(path):
59
+ # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
60
+ with open(path, 'rb') as f:
61
+ img = Image.open(f)
62
+ return img.convert('RGB')
63
+
64
+
65
+ def accimage_loader(path):
66
+ import accimage
67
+ try:
68
+ return accimage.Image(path)
69
+ except IOError:
70
+ # Potentially a decoding problem, fall back to PIL.Image
71
+ return pil_loader(path)
72
+
73
+
74
+ def default_loader(path):
75
+ from torchvision import get_image_backend
76
+ if get_image_backend() == 'accimage':
77
+ return accimage_loader(path)
78
+ else:
79
+ return pil_loader(path)
80
+
81
+
82
+ class ImageFolder(data.Dataset):
83
+ """A generic data loader where the images are arranged in this way: ::
84
+
85
+ root/dogball/xxx.png
86
+ root/dogball/xxy.png
87
+ root/dogball/xxz.png
88
+
89
+ root/cat/123.png
90
+ root/cat/nsdf3.png
91
+ root/cat/asd932_.png
92
+
93
+ Args:
94
+ root (string): Root directory path.
95
+ transform (callable, optional): A function/transform that takes in a PIL image
96
+ and returns a transformed version, e.g. ``transforms.RandomCrop``.
97
+ target_transform (callable, optional): A function/transform that takes in the
98
+ target and transforms it.
99
+ loader (callable, optional): A function to load an image given its path.
100
+
101
+ Attributes:
102
+ classes (list): List of the class names.
103
+ class_to_idx (dict): Dict with items (class_name, class_index).
104
+ imgs (list): List of (image path, class_index) tuples
105
+ """
106
+
107
+ def __init__(self, root, transform=None, target_transform=None,
108
+ loader=default_loader, load_in_mem=False,
109
+ index_filename='imagenet_imgs.npz', **kwargs):
110
+ classes, class_to_idx = find_classes(root)
111
+ # Load pre-computed image directory walk
112
+ if os.path.exists(index_filename):
113
+ print('Loading pre-saved Index file %s...' % index_filename)
114
+ imgs = np.load(index_filename)['imgs']
115
+ # If first time, walk the folder directory and save the
116
+ # results to a pre-computed file.
117
+ else:
118
+ print('Generating Index file %s...' % index_filename)
119
+ imgs = make_dataset(root, class_to_idx)
120
+ np.savez_compressed(index_filename, **{'imgs' : imgs})
121
+ if len(imgs) == 0:
122
+ raise RuntimeError("Found 0 images in subfolders of: " + root + "\n"
123
+ "Supported image extensions are: " + ",".join(IMG_EXTENSIONS))
124
+
125
+ self.root = root
126
+ self.imgs = imgs
127
+ self.classes = classes
128
+ self.class_to_idx = class_to_idx
129
+ self.transform = transform
130
+ self.target_transform = target_transform
131
+ self.loader = loader
132
+ self.load_in_mem = load_in_mem
133
+
134
+ if self.load_in_mem:
135
+ print('Loading all images into memory...')
136
+ self.data, self.labels = [], []
137
+ for index in tqdm(range(len(self.imgs))):
138
+ path, target = imgs[index][0], imgs[index][1]
139
+ self.data.append(self.transform(self.loader(path)))
140
+ self.labels.append(target)
141
+
142
+
143
+ def __getitem__(self, index):
144
+ """
145
+ Args:
146
+ index (int): Index
147
+
148
+ Returns:
149
+ tuple: (image, target) where target is class_index of the target class.
150
+ """
151
+ if self.load_in_mem:
152
+ img = self.data[index]
153
+ target = self.labels[index]
154
+ else:
155
+ path, target = self.imgs[index]
156
+ img = self.loader(str(path))
157
+ if self.transform is not None:
158
+ img = self.transform(img)
159
+
160
+ if self.target_transform is not None:
161
+ target = self.target_transform(target)
162
+
163
+ # print(img.size(), target)
164
+ return img, int(target)
165
+
166
+ def __len__(self):
167
+ return len(self.imgs)
168
+
169
+ def __repr__(self):
170
+ fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
171
+ fmt_str += ' Number of datapoints: {}\n'.format(self.__len__())
172
+ fmt_str += ' Root Location: {}\n'.format(self.root)
173
+ tmp = ' Transforms (if any): '
174
+ fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
175
+ tmp = ' Target Transforms (if any): '
176
+ fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
177
+ return fmt_str
178
+
179
+
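
A usage sketch for the ImageFolder above, with a placeholder root and image size. Two caveats worth noting: the index_filename cache is keyed by filename only, so point it somewhere unique per root; and with load_in_mem=True the transform runs once at load time, so keep random augmentation out of it.

import torchvision.transforms as transforms
from torch.utils.data import DataLoader

transform = transforms.Compose([
    transforms.Resize(64),
    transforms.CenterCrop(64),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),  # to roughly [-1, 1]
])
dataset = ImageFolder(root='/path/to/train', transform=transform,
                      load_in_mem=False, index_filename='train_imgs.npz')
loader = DataLoader(dataset, batch_size=64, shuffle=True, num_workers=4)
images, targets = next(iter(loader))   # images: (64, 3, 64, 64); targets: LongTensor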
180
+ ''' ILSVRC_HDF5: A dataset to support I/O from an HDF5 to avoid
181
+ having to load individual images all the time. '''
182
+ import h5py as h5
183
+ import torch
184
+ class ILSVRC_HDF5(data.Dataset):
185
+ def __init__(self, root, transform=None, target_transform=None,
186
+ load_in_mem=False, train=True,download=False, validate_seed=0,
187
+ val_split=0, **kwargs): # last four are dummies
188
+
189
+ self.root = root
190
+ self.num_imgs = len(h5.File(root, 'r')['labels'])
191
+
192
+ # self.transform = transform
193
+ self.target_transform = target_transform
194
+
195
+ # Set the transform here
196
+ self.transform = transform
197
+
198
+ # load the entire dataset into memory?
199
+ self.load_in_mem = load_in_mem
200
+
201
+ # If loading into memory, do so now
202
+ if self.load_in_mem:
203
+ print('Loading %s into memory...' % root)
204
+ with h5.File(root,'r') as f:
205
+ self.data = f['imgs'][:]
206
+ self.labels = f['labels'][:]
207
+
208
+ def __getitem__(self, index):
209
+ """
210
+ Args:
211
+ index (int): Index
212
+
213
+ Returns:
214
+ tuple: (image, target) where target is class_index of the target class.
215
+ """
216
+ # If loaded the entire dataset in RAM, get image from memory
217
+ if self.load_in_mem:
218
+ img = self.data[index]
219
+ target = self.labels[index]
220
+
221
+ # Else load it from disk
222
+ else:
223
+ with h5.File(self.root,'r') as f:
224
+ img = f['imgs'][index]
225
+ target = f['labels'][index]
226
+
227
+
228
+ # if self.transform is not None:
229
+ # img = self.transform(img)
230
+ # Apply my own transform
231
+ img = ((torch.from_numpy(img).float() / 255) - 0.5) * 2
232
+
233
+ if self.target_transform is not None:
234
+ target = self.target_transform(target)
235
+
236
+ return img, int(target)
237
+
238
+ def __len__(self):
239
+ return self.num_imgs
240
+ # return len(self.f['imgs'])
241
+
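
ILSVRC_HDF5 skips the usual transform and instead rescales uint8 pixels to [-1, 1] inline. A sketch of writing a compatible file, assuming (consistent with __getitem__ above) CHW uint8 arrays under 'imgs' and integer class indices under 'labels'; the chunking choice is an assumption, not a repo requirement.

import h5py as h5
import numpy as np

def write_hdf5(path, images_u8, labels):
    # images_u8: (N, 3, H, W) uint8; labels: (N,) integer class indices
    with h5.File(path, 'w') as f:
        f.create_dataset('imgs', data=images_u8, dtype='uint8',
                         chunks=(1,) + images_u8.shape[1:])  # one image per chunk
        f.create_dataset('labels', data=np.asarray(labels, dtype='int64'))

# dataset = ILSVRC_HDF5('/path/to/I64.hdf5', load_in_mem=False)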
242
+ import pickle
243
+ class CIFAR10(dset.CIFAR10):
244
+
245
+ def __init__(self, root, train=True,
246
+ transform=None, target_transform=None,
247
+ download=True, validate_seed=0,
248
+ val_split=0, load_in_mem=True, **kwargs):
249
+ self.root = os.path.expanduser(root)
250
+ self.transform = transform
251
+ self.target_transform = target_transform
252
+ self.train = train # training set or test set
253
+ self.val_split = val_split
254
+
255
+ if download:
256
+ self.download()
257
+
258
+ if not self._check_integrity():
259
+ raise RuntimeError('Dataset not found or corrupted.' +
260
+ ' You can use download=True to download it')
261
+
262
+ # now load the pickled numpy arrays
263
+ self.data = []
264
+ self.labels = []
265
+ for fentry in self.train_list:
266
+ f = fentry[0]
267
+ file = os.path.join(self.root, self.base_folder, f)
268
+ fo = open(file, 'rb')
269
+ if sys.version_info[0] == 2:
270
+ entry = pickle.load(fo)
271
+ else:
272
+ entry = pickle.load(fo, encoding='latin1')
273
+ self.data.append(entry['data'])
274
+ if 'labels' in entry:
275
+ self.labels += entry['labels']
276
+ else:
277
+ self.labels += entry['fine_labels']
278
+ fo.close()
279
+
280
+ self.data = np.concatenate(self.data)
281
+ # Randomly select indices for validation
282
+ if self.val_split > 0:
283
+ label_indices = [[] for _ in range(max(self.labels)+1)]
284
+ for i,l in enumerate(self.labels):
285
+ label_indices[l] += [i]
286
+ label_indices = np.asarray(label_indices)
287
+
288
+ # randomly grab 500 elements of each class
289
+ np.random.seed(validate_seed)
290
+ self.val_indices = []
291
+ for l_i in label_indices:
292
+ self.val_indices += list(l_i[np.random.choice(len(l_i), int(len(self.data) * val_split) // (max(self.labels) + 1) ,replace=False)])
293
+
294
+ if self.train == 'validate':
295
+ self.data = self.data[self.val_indices]
296
+ self.labels = list(np.asarray(self.labels)[self.val_indices])
297
+
298
+ self.data = self.data.reshape((int(50e3 * self.val_split), 3, 32, 32))
299
+ self.data = self.data.transpose((0, 2, 3, 1)) # convert to HWC
300
+
301
+ elif self.train:
302
+ print(np.shape(self.data))
303
+ if self.val_split > 0:
304
+ self.data = np.delete(self.data,self.val_indices,axis=0)
305
+ self.labels = list(np.delete(np.asarray(self.labels),self.val_indices,axis=0))
306
+
307
+ self.data = self.data.reshape((int(50e3 * (1.-self.val_split)), 3, 32, 32))
308
+ self.data = self.data.transpose((0, 2, 3, 1)) # convert to HWC
309
+ else:
310
+ f = self.test_list[0][0]
311
+ file = os.path.join(self.root, self.base_folder, f)
312
+ fo = open(file, 'rb')
313
+ if sys.version_info[0] == 2:
314
+ entry = pickle.load(fo)
315
+ else:
316
+ entry = pickle.load(fo, encoding='latin1')
317
+ self.data = entry['data']
318
+ if 'labels' in entry:
319
+ self.labels = entry['labels']
320
+ else:
321
+ self.labels = entry['fine_labels']
322
+ fo.close()
323
+ self.data = self.data.reshape((10000, 3, 32, 32))
324
+ self.data = self.data.transpose((0, 2, 3, 1)) # convert to HWC
325
+
326
+ def __getitem__(self, index):
327
+ """
328
+ Args:
329
+ index (int): Index
330
+ Returns:
331
+ tuple: (image, target) where target is index of the target class.
332
+ """
333
+ img, target = self.data[index], self.labels[index]
334
+
335
+ # doing this so that it is consistent with all other datasets
336
+ # to return a PIL Image
337
+ img = Image.fromarray(img)
338
+
339
+ if self.transform is not None:
340
+ img = self.transform(img)
341
+
342
+ if self.target_transform is not None:
343
+ target = self.target_transform(target)
344
+
345
+ return img, target
346
+
347
+ def __len__(self):
348
+ return len(self.data)
349
+
350
+
351
+ class CIFAR100(CIFAR10):
352
+ base_folder = 'cifar-100-python'
353
+ url = "http://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz"
354
+ filename = "cifar-100-python.tar.gz"
355
+ tgz_md5 = 'eb9058c3a382ffc7106e4002c42a8d85'
356
+ train_list = [
357
+ ['train', '16019d7e3df5f24257cddd939b257f8d'],
358
+ ]
359
+
360
+ test_list = [
361
+ ['test', 'f0ef6b0ae62326f3e7ffdfab6717acfc'],
362
+ ]
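
Usage sketch for the CIFAR10 subclass above: val_split carves a class-balanced validation set out of the 50k training images, and train='validate' selects exactly that split, provided validate_seed and val_split match between the two instantiations. The root path is a placeholder, and the attribute layout (train_list, base_folder) follows the older torchvision CIFAR API this class relies on.

import torchvision.transforms as transforms

transform = transforms.Compose([transforms.ToTensor()])
train_set = CIFAR10('/tmp/cifar', train=True, transform=transform,
                    download=True, val_split=0.1, validate_seed=0)
val_set = CIFAR10('/tmp/cifar', train='validate', transform=transform,
                  download=True, val_split=0.1, validate_seed=0)
print(len(train_set), len(val_set))   # 45000 5000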
src/models/big/layers.py ADDED
@@ -0,0 +1,456 @@
1
+ ''' Layers
2
+ This file contains various layers for the BigGAN models.
3
+ '''
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ from torch.nn import Parameter as P
8
+
9
+ from src.models.big.sync_batchnorm import SynchronizedBatchNorm2d as SyncBN2d
10
+
11
+
12
+ # Projection of x onto y
13
+ def proj(x, y):
14
+ return torch.mm(y, x.t()) * y / torch.mm(y, y.t())
15
+
16
+
17
+ # Orthogonalize x wrt list of vectors ys
18
+ def gram_schmidt(x, ys):
19
+ for y in ys:
20
+ x = x - proj(x, y)
21
+ return x
22
+
23
+
24
+ # Apply num_itrs steps of the power method to estimate top N singular values.
25
+ def power_iteration(W, u_, update=True, eps=1e-12):
26
+ # Lists holding singular vectors and values
27
+ us, vs, svs = [], [], []
28
+ for i, u in enumerate(u_):
29
+ # Run one step of the power iteration
30
+ with torch.no_grad():
31
+ v = torch.matmul(u, W)
32
+ # Run Gram-Schmidt to subtract components of all other singular vectors
33
+ v = F.normalize(gram_schmidt(v, vs), eps=eps)
34
+ # Add to the list
35
+ vs += [v]
36
+ # Update the other singular vector
37
+ u = torch.matmul(v, W.t())
38
+ # Run Gram-Schmidt to subtract components of all other singular vectors
39
+ u = F.normalize(gram_schmidt(u, us), eps=eps)
40
+ # Add to the list
41
+ us += [u]
42
+ if update:
43
+ u_[i][:] = u
44
+ # Compute this singular value and add it to the list
45
+ svs += [torch.squeeze(torch.matmul(torch.matmul(v, W.t()), u.t()))]
46
+ #svs += [torch.sum(F.linear(u, W.transpose(0, 1)) * v)]
47
+ return svs, us, vs
48
+
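
A quick numerical check of power_iteration (a verification sketch, not repo code): iterated enough times on a fixed matrix, the first estimate converges to the largest singular value; torch.linalg.svdvals (PyTorch >= 1.9) is used as the reference.

import torch

W = torch.randn(64, 128)
u_ = [torch.randn(1, 64) for _ in range(2)]    # two singular-vector estimates, as in SN
for _ in range(100):                           # plenty of iterations to converge
    svs, us, vs = power_iteration(W, u_, update=True)
print(svs[0].item(), torch.linalg.svdvals(W)[0].item())  # should roughly agree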
49
+
50
+ # Convenience passthrough function
51
+ class identity(nn.Module):
52
+ def forward(self, input):
53
+ return input
54
+
55
+
56
+ # Spectral normalization base class
57
+ class SN(object):
58
+ def __init__(self, num_svs, num_itrs, num_outputs, transpose=False, eps=1e-12):
59
+ # Number of power iterations per step
60
+ self.num_itrs = num_itrs
61
+ # Number of singular values
62
+ self.num_svs = num_svs
63
+ # Transposed?
64
+ self.transpose = transpose
65
+ # Epsilon value for avoiding divide-by-0
66
+ self.eps = eps
67
+ # Register a singular vector for each sv
68
+ for i in range(self.num_svs):
69
+ self.register_buffer('u%d' % i, torch.randn(1, num_outputs))
70
+ self.register_buffer('sv%d' % i, torch.ones(1))
71
+
72
+ # Singular vectors (u side)
73
+ @property
74
+ def u(self):
75
+ return [getattr(self, 'u%d' % i) for i in range(self.num_svs)]
76
+
77
+ # Singular values;
78
+ # note that these buffers are just for logging and are not used in training.
79
+ @property
80
+ def sv(self):
81
+ return [getattr(self, 'sv%d' % i) for i in range(self.num_svs)]
82
+
83
+ # Compute the spectrally-normalized weight
84
+ def W_(self):
85
+ W_mat = self.weight.view(self.weight.size(0), -1)
86
+ if self.transpose:
87
+ W_mat = W_mat.t()
88
+ # Apply num_itrs power iterations
89
+ for _ in range(self.num_itrs):
90
+ svs, us, vs = power_iteration(W_mat, self.u, update=self.training, eps=self.eps)
91
+ # Update the svs
92
+ if self.training:
93
+ with torch.no_grad(): # Make sure to do this in a no_grad() context or you'll get memory leaks!
94
+ for i, sv in enumerate(svs):
95
+ self.sv[i][:] = sv
96
+ return self.weight / svs[0]
97
+
98
+
99
+ # 2D Conv layer with spectral norm
100
+ class SNConv2d(nn.Conv2d, SN):
101
+ def __init__(self, in_channels, out_channels, kernel_size, stride=1,
102
+ padding=0, dilation=1, groups=1, bias=True,
103
+ num_svs=1, num_itrs=1, eps=1e-12):
104
+ nn.Conv2d.__init__(self, in_channels, out_channels, kernel_size, stride,
105
+ padding, dilation, groups, bias)
106
+ SN.__init__(self, num_svs, num_itrs, out_channels, eps=eps)
107
+ def forward(self, x):
108
+ return F.conv2d(x, self.W_(), self.bias, self.stride,
109
+ self.padding, self.dilation, self.groups)
110
+
111
+
112
+ # Linear layer with spectral norm
113
+ class SNLinear(nn.Linear, SN):
114
+ def __init__(self, in_features, out_features, bias=True,
115
+ num_svs=1, num_itrs=1, eps=1e-12):
116
+ nn.Linear.__init__(self, in_features, out_features, bias)
117
+ SN.__init__(self, num_svs, num_itrs, out_features, eps=eps)
118
+ def forward(self, x):
119
+ return F.linear(x, self.W_(), self.bias)
120
+
121
+
122
+ # Embedding layer with spectral norm
123
+ # We use num_embeddings as the dim instead of embedding_dim here
124
+ # for convenience's sake
125
+ class SNEmbedding(nn.Embedding, SN):
126
+ def __init__(self, num_embeddings, embedding_dim, padding_idx=None,
127
+ max_norm=None, norm_type=2, scale_grad_by_freq=False,
128
+ sparse=False, _weight=None,
129
+ num_svs=1, num_itrs=1, eps=1e-12):
130
+ nn.Embedding.__init__(self, num_embeddings, embedding_dim, padding_idx,
131
+ max_norm, norm_type, scale_grad_by_freq,
132
+ sparse, _weight)
133
+ SN.__init__(self, num_svs, num_itrs, num_embeddings, eps=eps)
134
+ def forward(self, x):
135
+ return F.embedding(x, self.W_())
136
+
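
The SN* classes are drop-in replacements for their nn counterparts: each forward in train mode runs one power iteration, so after a few batches the effective weight W_() has a top singular value close to 1. A small sketch:

import torch

layer = SNLinear(128, 64, num_svs=1, num_itrs=1)
layer.train()
x = torch.randn(32, 128)
for _ in range(50):               # each call advances the power iteration once
    _ = layer(x)
print(layer.sv[0].item())         # logged estimate of the unnormalized top sv
w_bar = layer.W_()                # the spectrally normalized weight actually used
print(torch.linalg.svdvals(w_bar)[0].item())   # ~1.0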
137
+
138
+ # A non-local block as used in SA-GAN
139
+ # Note that the implementation as described in the paper is largely incorrect;
140
+ # refer to the released code for the actual implementation.
141
+ class Attention(nn.Module):
142
+ def __init__(self, ch, which_conv=SNConv2d, name='attention'):
143
+ super(Attention, self).__init__()
144
+ # Channel multiplier
145
+ self.ch = ch
146
+ self.which_conv = which_conv
147
+ self.theta = self.which_conv(self.ch, self.ch // 8, kernel_size=1, padding=0, bias=False)
148
+ self.phi = self.which_conv(self.ch, self.ch // 8, kernel_size=1, padding=0, bias=False)
149
+ self.g = self.which_conv(self.ch, self.ch // 2, kernel_size=1, padding=0, bias=False)
150
+ self.o = self.which_conv(self.ch // 2, self.ch, kernel_size=1, padding=0, bias=False)
151
+ # Learnable gain parameter
152
+ self.gamma = P(torch.tensor(0.), requires_grad=True)
153
+ def forward(self, x, y=None):
154
+ # Apply convs
155
+ theta = self.theta(x)
156
+ phi = F.max_pool2d(self.phi(x), [2,2])
157
+ g = F.max_pool2d(self.g(x), [2,2])
158
+ # Perform reshapes
159
+ theta = theta.view(-1, self.ch // 8, x.shape[2] * x.shape[3])
160
+ phi = phi.view(-1, self.ch // 8, x.shape[2] * x.shape[3] // 4)
161
+ g = g.view(-1, self.ch // 2, x.shape[2] * x.shape[3] // 4)
162
+ # Matmul and softmax to get attention maps
163
+ beta = F.softmax(torch.bmm(theta.transpose(1, 2), phi), -1)
164
+ # Attention map times g path
165
+ o = self.o(torch.bmm(g, beta.transpose(1,2)).view(-1, self.ch // 2, x.shape[2], x.shape[3]))
166
+ return self.gamma * o + x
167
+
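
A shape walkthrough for the Attention block, as a sketch: with ch=64 on a 32x32 map, theta is (B, 8, 1024), phi and g are max-pooled down to 256 positions, and the attention map beta is (B, 1024, 256). Spatial dims must be even for the pooling; since gamma starts at 0, the block is an exact identity at init.

import torch

attn = Attention(ch=64)           # which_conv defaults to SNConv2d
x = torch.randn(2, 64, 32, 32)
out = attn(x)
print(out.shape)                  # torch.Size([2, 64, 32, 32])
print(torch.equal(out, x))        # True at init, because gamma == 0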
168
+
169
+ # Fused batchnorm op
170
+ def fused_bn(x, mean, var, gain=None, bias=None, eps=1e-5):
171
+ # Apply scale and shift--if gain and bias are provided, fuse them here
172
+ # Prepare scale
173
+ scale = torch.rsqrt(var + eps)
174
+ # If a gain is provided, use it
175
+ if gain is not None:
176
+ scale = scale * gain
177
+ # Prepare shift
178
+ shift = mean * scale
179
+ # If bias is provided, use it
180
+ if bias is not None:
181
+ shift = shift - bias
182
+ return x * scale - shift
183
+ #return ((x - mean) / ((var + eps) ** 0.5)) * gain + bias # The unfused way.
184
+
185
+
186
+ # Manual BN
187
+ # Calculate means and variances using mean-of-squares minus mean-squared
188
+ def manual_bn(x, gain=None, bias=None, return_mean_var=False, eps=1e-5):
189
+ # Cast x to float32 if necessary
190
+ float_x = x.float()
191
+ # Calculate expected value of x (m) and expected value of x**2 (m2)
192
+ # Mean of x
193
+ m = torch.mean(float_x, [0, 2, 3], keepdim=True)
194
+ # Mean of x squared
195
+ m2 = torch.mean(float_x ** 2, [0, 2, 3], keepdim=True)
196
+ # Calculate variance as mean of squared minus mean squared.
197
+ var = (m2 - m ** 2)
198
+ # Cast back to float 16 if necessary
199
+ var = var.type(x.type())
200
+ m = m.type(x.type())
201
+ # Return mean and variance for updating stored mean/var if requested
202
+ if return_mean_var:
203
+ return fused_bn(x, m, var, gain, bias, eps), m.squeeze(), var.squeeze()
204
+ else:
205
+ return fused_bn(x, m, var, gain, bias, eps)
206
+
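
A sanity-check sketch for manual_bn/fused_bn: with no gain or bias they should match F.batch_norm in training mode up to float error, since both normalize with the biased variance and the same default eps.

import torch
import torch.nn.functional as F

x = torch.randn(8, 16, 4, 4)
ours = manual_bn(x)
ref = F.batch_norm(x, None, None, training=True)   # no running stats, batch stats only
print((ours - ref).abs().max().item())             # ~1e-6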
207
+
208
+ # My batchnorm, supports standing stats
209
+ class myBN(nn.Module):
210
+ def __init__(self, num_channels, eps=1e-5, momentum=0.1):
211
+ super(myBN, self).__init__()
212
+ # momentum for updating running stats
213
+ self.momentum = momentum
214
+ # epsilon to avoid dividing by 0
215
+ self.eps = eps
218
+ # Register buffers
219
+ self.register_buffer('stored_mean', torch.zeros(num_channels))
220
+ self.register_buffer('stored_var', torch.ones(num_channels))
221
+ self.register_buffer('accumulation_counter', torch.zeros(1))
222
+ # Accumulate running means and vars
223
+ self.accumulate_standing = False
224
+
225
+ # reset standing stats
226
+ def reset_stats(self):
227
+ self.stored_mean[:] = 0
228
+ self.stored_var[:] = 0
229
+ self.accumulation_counter[:] = 0
230
+
231
+ def forward(self, x, gain, bias):
232
+ if self.training:
233
+ out, mean, var = manual_bn(x, gain, bias, return_mean_var=True, eps=self.eps)
234
+ # If accumulating standing stats, increment them
235
+ if self.accumulate_standing:
236
+ self.stored_mean[:] = self.stored_mean + mean.data
237
+ self.stored_var[:] = self.stored_var + var.data
238
+ self.accumulation_counter += 1.0
239
+ # If not accumulating standing stats, take running averages
240
+ else:
241
+ self.stored_mean[:] = self.stored_mean * (1 - self.momentum) + mean * self.momentum
242
+ self.stored_var[:] = self.stored_var * (1 - self.momentum) + var * self.momentum
243
+ return out
244
+ # If not in training mode, use the stored statistics
245
+ else:
246
+ mean = self.stored_mean.view(1, -1, 1, 1)
247
+ var = self.stored_var.view(1, -1, 1, 1)
248
+ # If using standing stats, divide them by the accumulation counter
249
+ if self.accumulate_standing:
250
+ mean = mean / self.accumulation_counter
251
+ var = var / self.accumulation_counter
252
+ return fused_bn(x, mean, var, gain, bias, self.eps)
253
+
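
A usage sketch for myBN's standing statistics, the trick BigGAN uses before sampling: accumulate raw sums of batch means/vars over several forwards, then eval mode divides by the counter. The gain/bias shapes follow what ccbn below passes in.

import torch

bn_layer = myBN(num_channels=16)
gain = torch.ones(1, 16, 1, 1)
bias = torch.zeros(1, 16, 1, 1)

bn_layer.train()
bn_layer.accumulate_standing = True
bn_layer.reset_stats()
for _ in range(16):                                # accumulate standing stats
    _ = bn_layer(torch.randn(8, 16, 4, 4), gain, bias)

bn_layer.eval()                                    # stored stats / counter are used here
out = bn_layer(torch.randn(8, 16, 4, 4), gain, bias)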
254
+
255
+ # Simple function to handle groupnorm norm stylization
256
+ def groupnorm(x, norm_style):
257
+ # If number of channels specified in norm_style:
258
+ if 'ch' in norm_style:
259
+ ch = int(norm_style.split('_')[-1])
260
+ groups = max(int(x.shape[1]) // ch, 1)
261
+ # If number of groups specified in norm style
262
+ elif 'grp' in norm_style:
263
+ groups = int(norm_style.split('_')[-1])
264
+ # If neither, default to groups = 16
265
+ else:
266
+ groups = 16
267
+ return F.group_norm(x, groups)
268
+
269
+
270
+ # Class-conditional bn
271
+ # output size is the number of channels, input size is for the linear layers
272
+ # Andy's Note: this class feels messy but I'm not really sure how to clean it up
273
+ # Suggestions welcome! (By which I mean, refactor this and make a pull request
274
+ # if you want to make this more readable/usable).
275
+ class ccbn(nn.Module):
276
+ def __init__(self, output_size, input_size, which_linear, eps=1e-5, momentum=0.1,
277
+ cross_replica=False, mybn=False, norm_style='bn',):
278
+ super(ccbn, self).__init__()
279
+ self.output_size, self.input_size = output_size, input_size
280
+ # Prepare gain and bias layers
281
+ self.gain = which_linear(input_size, output_size)
282
+ self.bias = which_linear(input_size, output_size)
283
+ # epsilon to avoid dividing by 0
284
+ self.eps = eps
285
+ # Momentum
286
+ self.momentum = momentum
287
+ # Use cross-replica batchnorm?
288
+ self.cross_replica = cross_replica
289
+ # Use my batchnorm?
290
+ self.mybn = mybn
291
+ # Norm style?
292
+ self.norm_style = norm_style
293
+
294
+ if self.cross_replica:
295
+ self.bn = SyncBN2d(output_size, eps=self.eps, momentum=self.momentum, affine=False)
296
+ elif self.mybn:
297
+ self.bn = myBN(output_size, self.eps, self.momentum)
298
+ elif self.norm_style in ['bn', 'in']:
299
+ self.register_buffer('stored_mean', torch.zeros(output_size))
300
+ self.register_buffer('stored_var', torch.ones(output_size))
301
+
302
+
303
+ def forward(self, x, y):
304
+ # Calculate class-conditional gains and biases
305
+ gain = (1 + self.gain(y)).view(y.size(0), -1, 1, 1)
306
+ bias = self.bias(y).view(y.size(0), -1, 1, 1)
307
+ # If using my batchnorm
308
+ if self.mybn or self.cross_replica:
309
+ return self.bn(x, gain=gain, bias=bias)
310
+ # else:
311
+ else:
312
+ if self.norm_style == 'bn':
313
+ out = F.batch_norm(x, self.stored_mean, self.stored_var, None, None,
314
+ self.training, 0.1, self.eps)
315
+ elif self.norm_style == 'in':
316
+ out = F.instance_norm(x, self.stored_mean, self.stored_var, None, None,
317
+ self.training, 0.1, self.eps)
318
+ elif self.norm_style == 'gn':
319
+ out = groupnorm(x, self.norm_style)
320
+ elif self.norm_style == 'nonorm':
321
+ out = x
322
+ return out * gain + bias
323
+ def extra_repr(self):
324
+ s = 'out: {output_size}, in: {input_size},'
325
+ s += ' cross_replica={cross_replica}'
326
+ return s.format(**self.__dict__)
327
+
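
A minimal ccbn sketch: the conditioning vector y (a class embedding, optionally concatenated with a chunk of z) is mapped by two linear layers to a per-channel gain and bias. which_linear is whatever linear class the network was built with; plain nn.Linear is assumed here.

import torch
import torch.nn as nn

cc = ccbn(output_size=64, input_size=32, which_linear=nn.Linear)
x = torch.randn(4, 64, 8, 8)      # feature map to normalize
y = torch.randn(4, 32)            # conditioning vector
print(cc(x, y).shape)             # torch.Size([4, 64, 8, 8])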
328
+
329
+ # Normal, non-class-conditional BN
330
+ class bn(nn.Module):
331
+ def __init__(self, output_size, eps=1e-5, momentum=0.1,
332
+ cross_replica=False, mybn=False):
333
+ super(bn, self).__init__()
334
+ self.output_size= output_size
335
+ # Prepare gain and bias layers
336
+ self.gain = P(torch.ones(output_size), requires_grad=True)
337
+ self.bias = P(torch.zeros(output_size), requires_grad=True)
338
+ # epsilon to avoid dividing by 0
339
+ self.eps = eps
340
+ # Momentum
341
+ self.momentum = momentum
342
+ # Use cross-replica batchnorm?
343
+ self.cross_replica = cross_replica
344
+ # Use my batchnorm?
345
+ self.mybn = mybn
346
+
347
+ if self.cross_replica:
348
+ self.bn = SyncBN2d(output_size, eps=self.eps, momentum=self.momentum, affine=False)
349
+ elif mybn:
350
+ self.bn = myBN(output_size, self.eps, self.momentum)
351
+ # Register buffers if neither of the above
352
+ else:
353
+ self.register_buffer('stored_mean', torch.zeros(output_size))
354
+ self.register_buffer('stored_var', torch.ones(output_size))
355
+
356
+ def forward(self, x, y=None):
357
+ if self.cross_replica or self.mybn:
358
+ gain = self.gain.view(1,-1,1,1)
359
+ bias = self.bias.view(1,-1,1,1)
360
+ return self.bn(x, gain=gain, bias=bias)
361
+ else:
362
+ return F.batch_norm(x, self.stored_mean, self.stored_var, self.gain,
363
+ self.bias, self.training, self.momentum, self.eps)
364
+
365
+
366
+ # Generator blocks
367
+ # Note that this class assumes the kernel size and padding (and any other
368
+ # settings) have been selected in the main generator module and passed in
369
+ # through the which_conv arg. Similar rules apply with which_bn (the input
370
+ # size [which is actually the number of channels of the conditional info] must
371
+ # be preselected)
372
+ class GBlock(nn.Module):
373
+ def __init__(self, in_channels, out_channels,
374
+ which_conv=nn.Conv2d, which_bn=bn, activation=None,
375
+ upsample=None):
376
+ super(GBlock, self).__init__()
377
+
378
+ self.in_channels, self.out_channels = in_channels, out_channels
379
+ self.which_conv, self.which_bn = which_conv, which_bn
380
+ self.activation = activation
381
+ self.upsample = upsample
382
+ # Conv layers
383
+ self.conv1 = self.which_conv(self.in_channels, self.out_channels)
384
+ self.conv2 = self.which_conv(self.out_channels, self.out_channels)
385
+ self.learnable_sc = in_channels != out_channels or upsample
386
+ if self.learnable_sc:
387
+ self.conv_sc = self.which_conv(in_channels, out_channels,
388
+ kernel_size=1, padding=0)
389
+ # Batchnorm layers
390
+ self.bn1 = self.which_bn(in_channels)
391
+ self.bn2 = self.which_bn(out_channels)
392
+ # upsample layers
393
+ self.upsample = upsample
394
+
395
+ def forward(self, x, y):
396
+ h = self.activation(self.bn1(x, y))
397
+ if self.upsample:
398
+ h = self.upsample(h)
399
+ x = self.upsample(x)
400
+ h = self.conv1(h)
401
+ h = self.activation(self.bn2(h, y))
402
+ h = self.conv2(h)
403
+ if self.learnable_sc:
404
+ x = self.conv_sc(x)
405
+ return h + x
406
+
407
+
408
+ # Residual block for the discriminator
409
+ class DBlock(nn.Module):
410
+ def __init__(self, in_channels, out_channels, which_conv=SNConv2d, wide=True,
411
+ preactivation=False, activation=None, downsample=None,):
412
+ super(DBlock, self).__init__()
413
+ self.in_channels, self.out_channels = in_channels, out_channels
414
+ # If using wide D (as in SA-GAN and BigGAN), change the channel pattern
415
+ self.hidden_channels = self.out_channels if wide else self.in_channels
416
+ self.which_conv = which_conv
417
+ self.preactivation = preactivation
418
+ self.activation = activation
419
+ self.downsample = downsample
420
+
421
+ # Conv layers
422
+ self.conv1 = self.which_conv(self.in_channels, self.hidden_channels)
423
+ self.conv2 = self.which_conv(self.hidden_channels, self.out_channels)
424
+ self.learnable_sc = True if (in_channels != out_channels) or downsample else False
425
+ if self.learnable_sc:
426
+ self.conv_sc = self.which_conv(in_channels, out_channels,
427
+ kernel_size=1, padding=0)
428
+ def shortcut(self, x):
429
+ if self.preactivation:
430
+ if self.learnable_sc:
431
+ x = self.conv_sc(x)
432
+ if self.downsample:
433
+ x = self.downsample(x)
434
+ else:
435
+ if self.downsample:
436
+ x = self.downsample(x)
437
+ if self.learnable_sc:
438
+ x = self.conv_sc(x)
439
+ return x
440
+
441
+ def forward(self, x):
442
+ if self.preactivation:
443
+ # h = self.activation(x) # NOT TODAY SATAN
444
+ # Andy's note: This line *must* be an out-of-place ReLU or it
445
+ # will negatively affect the shortcut connection.
446
+ h = F.relu(x)
447
+ else:
448
+ h = x
449
+ h = self.conv1(h)
450
+ h = self.conv2(self.activation(h))
451
+ if self.downsample:
452
+ h = self.downsample(h)
453
+
454
+ return h + self.shortcut(x)
455
+
456
+ # dogball
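
How these blocks are meant to be wired, as an illustrative sketch: the main network pre-binds its conv, bn, and upsample choices with functools.partial and passes them in as which_conv/which_bn, so GBlock never hardcodes kernel sizes or conditioning dims. Channel counts here are placeholders.

import functools
import torch
import torch.nn as nn
import torch.nn.functional as F

which_conv = functools.partial(SNConv2d, kernel_size=3, padding=1)
which_bn = functools.partial(ccbn, input_size=32, which_linear=nn.Linear)

gblock = GBlock(in_channels=64, out_channels=32,
                which_conv=which_conv, which_bn=which_bn,
                activation=nn.ReLU(inplace=False),
                upsample=functools.partial(F.interpolate, scale_factor=2))

x, y = torch.randn(2, 64, 16, 16), torch.randn(2, 32)
print(gblock(x, y).shape)         # torch.Size([2, 32, 32, 32])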
src/models/big/sync_batchnorm/__init__.py ADDED
@@ -0,0 +1,12 @@
1
+ # -*- coding: utf-8 -*-
2
+ # File : __init__.py
3
+ # Author : Jiayuan Mao
4
+ # Email : maojiayuan@gmail.com
5
+ # Date : 27/01/2018
6
+ #
7
+ # This file is part of Synchronized-BatchNorm-PyTorch.
8
+ # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9
+ # Distributed under MIT License.
10
+
11
+ from .batchnorm import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, SynchronizedBatchNorm3d
12
+ from .replicate import DataParallelWithCallback, patch_replication_callback
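
The re-exports above are the intended entry points: replicas made by plain nn.DataParallel would not share a SyncMaster, so the package ships a callback-aware wrapper. A usage sketch, assuming a model containing SynchronizedBatchNorm2d layers and at least two visible GPUs; patch_replication_callback can instead retrofit an existing nn.DataParallel object.

import torch.nn as nn
from src.models.big.sync_batchnorm import (SynchronizedBatchNorm2d,
                                           DataParallelWithCallback)

model = nn.Sequential(nn.Conv2d(3, 16, 3, padding=1),
                      SynchronizedBatchNorm2d(16),
                      nn.ReLU())
model = DataParallelWithCallback(model.cuda(), device_ids=[0, 1])
# forward passes now reduce BN statistics across both replicas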
src/models/big/sync_batchnorm/batchnorm.py ADDED
@@ -0,0 +1,349 @@
1
+ # -*- coding: utf-8 -*-
2
+ # File : batchnorm.py
3
+ # Author : Jiayuan Mao
4
+ # Email : maojiayuan@gmail.com
5
+ # Date : 27/01/2018
6
+ #
7
+ # This file is part of Synchronized-BatchNorm-PyTorch.
8
+ # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9
+ # Distributed under MIT License.
10
+
11
+ import collections
12
+
13
+ import torch
14
+ import torch.nn.functional as F
15
+
16
+ from torch.nn.modules.batchnorm import _BatchNorm
17
+ from torch.nn.parallel._functions import ReduceAddCoalesced, Broadcast
18
+
19
+ from .comm import SyncMaster
20
+
21
+ __all__ = ['SynchronizedBatchNorm1d', 'SynchronizedBatchNorm2d', 'SynchronizedBatchNorm3d']
22
+
23
+
24
+ def _sum_ft(tensor):
25
+ """sum over the first and last dimention"""
26
+ return tensor.sum(dim=0).sum(dim=-1)
27
+
28
+
29
+ def _unsqueeze_ft(tensor):
30
+ """add new dementions at the front and the tail"""
31
+ return tensor.unsqueeze(0).unsqueeze(-1)
32
+
33
+
34
+ _ChildMessage = collections.namedtuple('_ChildMessage', ['sum', 'ssum', 'sum_size'])
35
+ _MasterMessage = collections.namedtuple('_MasterMessage', ['sum', 'inv_std'])
36
+ # _MasterMessage = collections.namedtuple('_MasterMessage', ['sum', 'ssum', 'sum_size'])
37
+
38
+ class _SynchronizedBatchNorm(_BatchNorm):
39
+ def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True):
40
+ super(_SynchronizedBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum, affine=affine)
41
+
42
+ self._sync_master = SyncMaster(self._data_parallel_master)
43
+
44
+ self._is_parallel = False
45
+ self._parallel_id = None
46
+ self._slave_pipe = None
47
+
48
+ def forward(self, input, gain=None, bias=None):
49
+ # If this is not a parallel computation, or we are in evaluation mode, use PyTorch's implementation.
50
+ if not (self._is_parallel and self.training):
51
+ out = F.batch_norm(
52
+ input, self.running_mean, self.running_var, self.weight, self.bias,
53
+ self.training, self.momentum, self.eps)
54
+ if gain is not None:
55
+ out = out + gain
56
+ if bias is not None:
57
+ out = out + bias
58
+ return out
59
+
60
+ # Resize the input to (B, C, -1).
61
+ input_shape = input.size()
62
+ # print(input_shape)
63
+ input = input.view(input.size(0), input.size(1), -1)
64
+
65
+ # Compute the sum and square-sum.
66
+ sum_size = input.size(0) * input.size(2)
67
+ input_sum = _sum_ft(input)
68
+ input_ssum = _sum_ft(input ** 2)
69
+ # Reduce-and-broadcast the statistics.
70
+ # print('it begins')
71
+ if self._parallel_id == 0:
72
+ mean, inv_std = self._sync_master.run_master(_ChildMessage(input_sum, input_ssum, sum_size))
73
+ else:
74
+ mean, inv_std = self._slave_pipe.run_slave(_ChildMessage(input_sum, input_ssum, sum_size))
75
+ # if self._parallel_id == 0:
76
+ # # print('here')
77
+ # sum, ssum, num = self._sync_master.run_master(_ChildMessage(input_sum, input_ssum, sum_size))
78
+ # else:
79
+ # # print('there')
80
+ # sum, ssum, num = self._slave_pipe.run_slave(_ChildMessage(input_sum, input_ssum, sum_size))
81
+
82
+ # print('how2')
83
+ # num = sum_size
84
+ # print('Sum: %f, ssum: %f, sumsize: %f, insum: %f' %(float(sum.sum().cpu()), float(ssum.sum().cpu()), float(sum_size), float(input_sum.sum().cpu())))
85
+ # Fix the graph
86
+ # sum = (sum.detach() - input_sum.detach()) + input_sum
87
+ # ssum = (ssum.detach() - input_ssum.detach()) + input_ssum
88
+
89
+ # mean = sum / num
90
+ # var = ssum / num - mean ** 2
91
+ # # var = (ssum - mean * sum) / num
92
+ # inv_std = torch.rsqrt(var + self.eps)
93
+
94
+ # Compute the output.
95
+ if gain is not None:
96
+ # print('gaining')
97
+ # scale = _unsqueeze_ft(inv_std) * gain.squeeze(-1)
98
+ # shift = _unsqueeze_ft(mean) * scale - bias.squeeze(-1)
99
+ # output = input * scale - shift
100
+ output = (input - _unsqueeze_ft(mean)) * (_unsqueeze_ft(inv_std) * gain.squeeze(-1)) + bias.squeeze(-1)
101
+ elif self.affine:
102
+ # MJY:: Fuse the multiplication for speed.
103
+ output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std * self.weight) + _unsqueeze_ft(self.bias)
104
+ else:
105
+ output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std)
106
+
107
+ # Reshape it.
108
+ return output.view(input_shape)
109
+
110
+ def __data_parallel_replicate__(self, ctx, copy_id):
111
+ self._is_parallel = True
112
+ self._parallel_id = copy_id
113
+
114
+ # parallel_id == 0 means master device.
115
+ if self._parallel_id == 0:
116
+ ctx.sync_master = self._sync_master
117
+ else:
118
+ self._slave_pipe = ctx.sync_master.register_slave(copy_id)
119
+
120
+ def _data_parallel_master(self, intermediates):
121
+ """Reduce the sum and square-sum, compute the statistics, and broadcast it."""
122
+
123
+ # Always using same "device order" makes the ReduceAdd operation faster.
124
+ # Thanks to:: Tete Xiao (http://tetexiao.com/)
125
+ intermediates = sorted(intermediates, key=lambda i: i[1].sum.get_device())
126
+
127
+ to_reduce = [i[1][:2] for i in intermediates]
128
+ to_reduce = [j for i in to_reduce for j in i] # flatten
129
+ target_gpus = [i[1].sum.get_device() for i in intermediates]
130
+
131
+ sum_size = sum([i[1].sum_size for i in intermediates])
132
+ sum_, ssum = ReduceAddCoalesced.apply(target_gpus[0], 2, *to_reduce)
133
+ mean, inv_std = self._compute_mean_std(sum_, ssum, sum_size)
134
+
135
+ broadcasted = Broadcast.apply(target_gpus, mean, inv_std)
136
+ # print('a')
137
+ # print(type(sum_), type(ssum), type(sum_size), sum_.shape, ssum.shape, sum_size)
138
+ # broadcasted = Broadcast.apply(target_gpus, sum_, ssum, torch.tensor(sum_size).float().to(sum_.device))
139
+ # print('b')
140
+ outputs = []
141
+ for i, rec in enumerate(intermediates):
142
+ outputs.append((rec[0], _MasterMessage(*broadcasted[i*2:i*2+2])))
143
+ # outputs.append((rec[0], _MasterMessage(*broadcasted[i*3:i*3+3])))
144
+
145
+ return outputs
146
+
147
+ def _compute_mean_std(self, sum_, ssum, size):
148
+ """Compute the mean and standard-deviation with sum and square-sum. This method
149
+ also maintains the moving average on the master device."""
150
+ assert size > 1, 'BatchNorm computes unbiased standard-deviation, which requires size > 1.'
151
+ mean = sum_ / size
152
+ sumvar = ssum - sum_ * mean
153
+ unbias_var = sumvar / (size - 1)
154
+ bias_var = sumvar / size
155
+
156
+ self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean.data
157
+ self.running_var = (1 - self.momentum) * self.running_var + self.momentum * unbias_var.data
158
+ return mean, torch.rsqrt(bias_var + self.eps)
159
+ # return mean, bias_var.clamp(self.eps) ** -0.5
160
+
161
+
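
The master-side math in _compute_mean_std, checked standalone (a sketch, not repo code): mean = sum/n; the biased variance (sum of squares minus mean times sum, over n) feeds inv_std for normalization, while the unbiased variant over n-1 feeds running_var.

import torch

x = torch.randn(1000)
s, ss, n = x.sum(), (x ** 2).sum(), x.numel()
mean = s / n
sumvar = ss - s * mean
print(torch.allclose(sumvar / n, x.var(unbiased=False), atol=1e-5))       # biased -> inv_std
print(torch.allclose(sumvar / (n - 1), x.var(unbiased=True), atol=1e-5))  # unbiased -> running_var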
162
+ class SynchronizedBatchNorm1d(_SynchronizedBatchNorm):
163
+ r"""Applies Synchronized Batch Normalization over a 2d or 3d input that is seen as a
164
+ mini-batch.
165
+
166
+ .. math::
167
+
168
+ y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta
169
+
170
+ This module differs from the built-in PyTorch BatchNorm1d as the mean and
171
+ standard-deviation are reduced across all devices during training.
172
+
173
+ For example, when one uses `nn.DataParallel` to wrap the network during
174
+ training, PyTorch's implementation normalizes the tensor on each device using
175
+ the statistics only on that device, which accelerates the computation and
176
+ is also easy to implement, but the statistics might be inaccurate.
177
+ Instead, in this synchronized version, the statistics will be computed
178
+ over all training samples distributed on multiple devices.
179
+
180
+ Note that, for the one-GPU or CPU-only case, this module behaves exactly the same
181
+ as the built-in PyTorch implementation.
182
+
183
+ The mean and standard-deviation are calculated per-dimension over
184
+ the mini-batches and gamma and beta are learnable parameter vectors
185
+ of size C (where C is the input size).
186
+
187
+ During training, this layer keeps a running estimate of its computed mean
188
+ and variance. The running sum is kept with a default momentum of 0.1.
189
+
190
+ During evaluation, this running mean/variance is used for normalization.
191
+
192
+ Because the BatchNorm is done over the `C` dimension, computing statistics
193
+ on `(N, L)` slices, it's common terminology to call this Temporal BatchNorm
194
+
195
+ Args:
196
+ num_features: num_features from an expected input of size
197
+ `batch_size x num_features [x width]`
198
+ eps: a value added to the denominator for numerical stability.
199
+ Default: 1e-5
200
+ momentum: the value used for the running_mean and running_var
201
+ computation. Default: 0.1
202
+ affine: a boolean value that when set to ``True``, gives the layer learnable
203
+ affine parameters. Default: ``True``
204
+
205
+ Shape:
206
+ - Input: :math:`(N, C)` or :math:`(N, C, L)`
207
+ - Output: :math:`(N, C)` or :math:`(N, C, L)` (same shape as input)
208
+
209
+ Examples:
210
+ >>> # With Learnable Parameters
211
+ >>> m = SynchronizedBatchNorm1d(100)
212
+ >>> # Without Learnable Parameters
213
+ >>> m = SynchronizedBatchNorm1d(100, affine=False)
214
+ >>> input = torch.autograd.Variable(torch.randn(20, 100))
215
+ >>> output = m(input)
216
+ """
217
+
218
+ def _check_input_dim(self, input):
219
+ if input.dim() != 2 and input.dim() != 3:
220
+ raise ValueError('expected 2D or 3D input (got {}D input)'
221
+ .format(input.dim()))
222
+ super(SynchronizedBatchNorm1d, self)._check_input_dim(input)
223
+
224
+
225
+ class SynchronizedBatchNorm2d(_SynchronizedBatchNorm):
226
+ r"""Applies Batch Normalization over a 4d input that is seen as a mini-batch
227
+ of 3d inputs
228
+
229
+ .. math::
230
+
231
+ y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta
232
+
233
+ This module differs from the built-in PyTorch BatchNorm2d as the mean and
234
+ standard-deviation are reduced across all devices during training.
235
+
236
+ For example, when one uses `nn.DataParallel` to wrap the network during
237
+ training, PyTorch's implementation normalizes the tensor on each device using
238
+ the statistics only on that device, which accelerates the computation and
239
+ is also easy to implement, but the statistics might be inaccurate.
240
+ Instead, in this synchronized version, the statistics will be computed
241
+ over all training samples distributed on multiple devices.
242
+
243
+ Note that, for the one-GPU or CPU-only case, this module behaves exactly the same
244
+ as the built-in PyTorch implementation.
245
+
246
+ The mean and standard-deviation are calculated per-dimension over
247
+ the mini-batches and gamma and beta are learnable parameter vectors
248
+ of size C (where C is the input size).
249
+
250
+ During training, this layer keeps a running estimate of its computed mean
251
+ and variance. The running sum is kept with a default momentum of 0.1.
252
+
253
+ During evaluation, this running mean/variance is used for normalization.
254
+
255
+ Because the BatchNorm is done over the `C` dimension, computing statistics
256
+ on `(N, H, W)` slices, it's common terminology to call this Spatial BatchNorm
257
+
258
+ Args:
259
+ num_features: num_features from an expected input of
260
+ size batch_size x num_features x height x width
261
+ eps: a value added to the denominator for numerical stability.
262
+ Default: 1e-5
263
+ momentum: the value used for the running_mean and running_var
264
+ computation. Default: 0.1
265
+ affine: a boolean value that when set to ``True``, gives the layer learnable
266
+ affine parameters. Default: ``True``
267
+
268
+ Shape:
269
+ - Input: :math:`(N, C, H, W)`
270
+ - Output: :math:`(N, C, H, W)` (same shape as input)
271
+
272
+ Examples:
273
+ >>> # With Learnable Parameters
274
+ >>> m = SynchronizedBatchNorm2d(100)
275
+ >>> # Without Learnable Parameters
276
+ >>> m = SynchronizedBatchNorm2d(100, affine=False)
277
+ >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45))
278
+ >>> output = m(input)
279
+ """
280
+
281
+ def _check_input_dim(self, input):
282
+ if input.dim() != 4:
283
+ raise ValueError('expected 4D input (got {}D input)'
284
+ .format(input.dim()))
285
+ super(SynchronizedBatchNorm2d, self)._check_input_dim(input)
286
+
287
+
288
+ class SynchronizedBatchNorm3d(_SynchronizedBatchNorm):
289
+ r"""Applies Batch Normalization over a 5d input that is seen as a mini-batch
290
+ of 4d inputs
291
+
292
+ .. math::
293
+
294
+ y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta
295
+
296
+ This module differs from the built-in PyTorch BatchNorm3d as the mean and
297
+ standard-deviation are reduced across all devices during training.
298
+
299
+ For example, when one uses `nn.DataParallel` to wrap the network during
300
+ training, PyTorch's implementation normalizes the tensor on each device using
301
+ the statistics only on that device, which accelerates the computation and
302
+ is also easy to implement, but the statistics might be inaccurate.
303
+ Instead, in this synchronized version, the statistics will be computed
304
+ over all training samples distributed on multiple devices.
305
+
306
+ Note that, for the one-GPU or CPU-only case, this module behaves exactly the same
307
+ as the built-in PyTorch implementation.
308
+
309
+ The mean and standard-deviation are calculated per-dimension over
310
+ the mini-batches and gamma and beta are learnable parameter vectors
311
+ of size C (where C is the input size).
312
+
313
+ During training, this layer keeps a running estimate of its computed mean
314
+ and variance. The running estimates are kept with a default momentum of 0.1.
315
+
316
+ During evaluation, this running mean/variance is used for normalization.
317
+
318
+ Because the BatchNorm is done over the `C` dimension, computing statistics
319
+ on `(N, D, H, W)` slices, it's common terminology to call this Volumetric BatchNorm
320
+ or Spatio-temporal BatchNorm
321
+
322
+ Args:
323
+ num_features: num_features from an expected input of
324
+ size batch_size x num_features x depth x height x width
325
+ eps: a value added to the denominator for numerical stability.
326
+ Default: 1e-5
327
+ momentum: the value used for the running_mean and running_var
328
+ computation. Default: 0.1
329
+ affine: a boolean value that when set to ``True``, gives the layer learnable
330
+ affine parameters. Default: ``True``
331
+
332
+ Shape:
333
+ - Input: :math:`(N, C, D, H, W)`
334
+ - Output: :math:`(N, C, D, H, W)` (same shape as input)
335
+
336
+ Examples:
337
+ >>> # With Learnable Parameters
338
+ >>> m = SynchronizedBatchNorm3d(100)
339
+ >>> # Without Learnable Parameters
340
+ >>> m = SynchronizedBatchNorm3d(100, affine=False)
341
+ >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45, 10))
342
+ >>> output = m(input)
343
+ """
344
+
345
+ def _check_input_dim(self, input):
346
+ if input.dim() != 5:
347
+ raise ValueError('expected 5D input (got {}D input)'
348
+ .format(input.dim()))
349
+ super(SynchronizedBatchNorm3d, self)._check_input_dim(input)
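+
+
+ # A minimal usage sketch (illustrative, not part of the upstream file). It
+ # assumes the sibling `.replicate` module added in this commit and two
+ # visible GPUs; `_demo_sync_bn` is a hypothetical helper name. The
+ # synchronized layers only behave differently from the built-in ones once
+ # the model is replicated with the callback-aware DataParallel.
+ def _demo_sync_bn():
+     from .replicate import DataParallelWithCallback
+     net = torch.nn.Sequential(
+         torch.nn.Conv2d(3, 16, 3, padding=1),
+         SynchronizedBatchNorm2d(16),
+         torch.nn.ReLU(inplace=True)).cuda()
+     # Batch statistics are now reduced across both replicas.
+     net = DataParallelWithCallback(net, device_ids=[0, 1])
+     return net(torch.randn(8, 3, 32, 32).cuda())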
src/models/big/sync_batchnorm/batchnorm_reimpl.py ADDED
@@ -0,0 +1,74 @@
1
+ #! /usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+ # File : batchnorm_reimpl.py
4
+ # Author : acgtyrant
5
+ # Date : 11/01/2018
6
+ #
7
+ # This file is part of Synchronized-BatchNorm-PyTorch.
8
+ # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9
+ # Distributed under MIT License.
10
+
11
+ import torch
12
+ import torch.nn as nn
13
+ import torch.nn.init as init
14
+
15
+ __all__ = ['BatchNorm2dReimpl']
16
+
17
+
18
+ class BatchNorm2dReimpl(nn.Module):
19
+ """
20
+ A re-implementation of batch normalization, used for testing the numerical
21
+ stability.
22
+
23
+ Author: acgtyrant
24
+ See also:
25
+ https://github.com/vacancy/Synchronized-BatchNorm-PyTorch/issues/14
26
+ """
27
+ def __init__(self, num_features, eps=1e-5, momentum=0.1):
28
+ super().__init__()
29
+
30
+ self.num_features = num_features
31
+ self.eps = eps
32
+ self.momentum = momentum
33
+ self.weight = nn.Parameter(torch.empty(num_features))
34
+ self.bias = nn.Parameter(torch.empty(num_features))
35
+ self.register_buffer('running_mean', torch.zeros(num_features))
36
+ self.register_buffer('running_var', torch.ones(num_features))
37
+ self.reset_parameters()
38
+
39
+ def reset_running_stats(self):
40
+ self.running_mean.zero_()
41
+ self.running_var.fill_(1)
42
+
43
+ def reset_parameters(self):
44
+ self.reset_running_stats()
45
+ init.uniform_(self.weight)
46
+ init.zeros_(self.bias)
47
+
48
+ def forward(self, input_):
49
+ batchsize, channels, height, width = input_.size()
50
+ numel = batchsize * height * width
51
+ input_ = input_.permute(1, 0, 2, 3).contiguous().view(channels, numel)
52
+ sum_ = input_.sum(1)
53
+ sum_of_square = input_.pow(2).sum(1)
54
+ mean = sum_ / numel
55
+ sumvar = sum_of_square - sum_ * mean
56
+
57
+ self.running_mean = (
58
+ (1 - self.momentum) * self.running_mean
59
+ + self.momentum * mean.detach()
60
+ )
61
+ unbias_var = sumvar / (numel - 1)
62
+ self.running_var = (
63
+ (1 - self.momentum) * self.running_var
64
+ + self.momentum * unbias_var.detach()
65
+ )
66
+
67
+ bias_var = sumvar / numel
68
+ inv_std = 1 / (bias_var + self.eps).pow(0.5)
69
+ output = (
70
+ (input_ - mean.unsqueeze(1)) * inv_std.unsqueeze(1) *
71
+ self.weight.unsqueeze(1) + self.bias.unsqueeze(1))
72
+
73
+ return output.view(channels, batchsize, height, width).permute(1, 0, 2, 3).contiguous()
74
+
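+
+ # A quick numerical check (illustrative, not part of the upstream file):
+ # with matching affine parameters, this re-implementation should agree with
+ # nn.BatchNorm2d up to floating-point tolerance. `_demo_check` is a
+ # hypothetical helper name.
+ def _demo_check():
+     torch.manual_seed(0)
+     ref = nn.BatchNorm2d(8, eps=1e-5, momentum=0.1)
+     bn = BatchNorm2dReimpl(8, eps=1e-5, momentum=0.1)
+     # Copy the reference affine parameters so only the normalization differs.
+     bn.weight.data.copy_(ref.weight.data)
+     bn.bias.data.copy_(ref.bias.data)
+     x = torch.randn(4, 8, 5, 5)
+     return torch.allclose(bn(x), ref(x), atol=1e-5)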
src/models/big/sync_batchnorm/comm.py ADDED
@@ -0,0 +1,137 @@
1
+ # -*- coding: utf-8 -*-
2
+ # File : comm.py
3
+ # Author : Jiayuan Mao
4
+ # Email : maojiayuan@gmail.com
5
+ # Date : 27/01/2018
6
+ #
7
+ # This file is part of Synchronized-BatchNorm-PyTorch.
8
+ # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9
+ # Distributed under MIT License.
10
+
11
+ import queue
12
+ import collections
13
+ import threading
14
+
15
+ __all__ = ['FutureResult', 'SlavePipe', 'SyncMaster']
16
+
17
+
18
+ class FutureResult(object):
19
+ """A thread-safe future implementation. Used only as one-to-one pipe."""
20
+
21
+ def __init__(self):
22
+ self._result = None
23
+ self._lock = threading.Lock()
24
+ self._cond = threading.Condition(self._lock)
25
+
26
+ def put(self, result):
27
+ with self._lock:
28
+ assert self._result is None, 'Previous result hasn\'t been fetched.'
29
+ self._result = result
30
+ self._cond.notify()
31
+
32
+ def get(self):
33
+ with self._lock:
34
+ if self._result is None:
35
+ self._cond.wait()
36
+
37
+ res = self._result
38
+ self._result = None
39
+ return res
40
+
41
+
42
+ _MasterRegistry = collections.namedtuple('MasterRegistry', ['result'])
43
+ _SlavePipeBase = collections.namedtuple('_SlavePipeBase', ['identifier', 'queue', 'result'])
44
+
45
+
46
+ class SlavePipe(_SlavePipeBase):
47
+ """Pipe for master-slave communication."""
48
+
49
+ def run_slave(self, msg):
50
+ self.queue.put((self.identifier, msg))
51
+ ret = self.result.get()
52
+ self.queue.put(True)
53
+ return ret
54
+
55
+
56
+ class SyncMaster(object):
57
+ """An abstract `SyncMaster` object.
58
+
59
+ - During replication, as the data parallel will trigger a callback on each module, all slave devices should
60
+ call `register_slave(id)` and obtain a `SlavePipe` to communicate with the master.
61
+ - During the forward pass, the master device invokes `run_master`; all messages from the slave devices are collected
62
+ and passed to the registered callback.
63
+ - After receiving the messages, the master device gathers the information and determines the message to be passed
64
+ back to each slave device.
65
+ """
66
+
67
+ def __init__(self, master_callback):
68
+ """
69
+
70
+ Args:
71
+ master_callback: a callback to be invoked after having collected messages from slave devices.
72
+ """
73
+ self._master_callback = master_callback
74
+ self._queue = queue.Queue()
75
+ self._registry = collections.OrderedDict()
76
+ self._activated = False
77
+
78
+ def __getstate__(self):
79
+ return {'master_callback': self._master_callback}
80
+
81
+ def __setstate__(self, state):
82
+ self.__init__(state['master_callback'])
83
+
84
+ def register_slave(self, identifier):
85
+ """
86
+ Register a slave device.
87
+
88
+ Args:
89
+ identifier: an identifier, usually is the device id.
90
+
91
+ Returns: a `SlavePipe` object which can be used to communicate with the master device.
92
+
93
+ """
94
+ if self._activated:
95
+ assert self._queue.empty(), 'Queue is not clean before next initialization.'
96
+ self._activated = False
97
+ self._registry.clear()
98
+ future = FutureResult()
99
+ self._registry[identifier] = _MasterRegistry(future)
100
+ return SlavePipe(identifier, self._queue, future)
101
+
102
+ def run_master(self, master_msg):
103
+ """
104
+ Main entry for the master device in each forward pass.
105
+ The messages are first collected from each device (including the master device), and then
106
+ the registered callback is invoked to compute the message to be sent back to each device
107
+ (including the master device).
108
+
109
+ Args:
110
+ master_msg: the message that the master wants to send to itself. This will be placed as the first
111
+ message when calling `master_callback`. For detailed usage, see `_SynchronizedBatchNorm` for an example.
112
+
113
+ Returns: the message to be sent back to the master device.
114
+
115
+ """
116
+ self._activated = True
117
+
118
+ intermediates = [(0, master_msg)]
119
+ for i in range(self.nr_slaves):
120
+ intermediates.append(self._queue.get())
121
+
122
+ results = self._master_callback(intermediates)
123
+ assert results[0][0] == 0, 'The first result should belong to the master.'
124
+
125
+ for i, res in results:
126
+ if i == 0:
127
+ continue
128
+ self._registry[i].result.put(res)
129
+
130
+ for i in range(self.nr_slaves):
131
+ assert self._queue.get() is True
132
+
133
+ return results[0][1]
134
+
135
+ @property
136
+ def nr_slaves(self):
137
+ return len(self._registry)
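+
+
+ # An illustrative round trip (not part of the upstream file) showing the
+ # master/slave protocol with a single slave thread; `_demo_sync` is a
+ # hypothetical helper name. The callback sees [(0, 5), (1, 10)] and doubles
+ # each payload, so the master keeps 10 and the slave receives 20.
+ def _demo_sync():
+     master = SyncMaster(lambda msgs: [(i, m * 2) for i, m in msgs])
+     pipe = master.register_slave(1)  # must register before run_master
+     out = {}
+     t = threading.Thread(target=lambda: out.setdefault('slave', pipe.run_slave(10)))
+     t.start()
+     out['master'] = master.run_master(5)
+     t.join()
+     return out  # {'slave': 20, 'master': 10}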
src/models/big/sync_batchnorm/replicate.py ADDED
@@ -0,0 +1,94 @@
1
+ # -*- coding: utf-8 -*-
2
+ # File : replicate.py
3
+ # Author : Jiayuan Mao
4
+ # Email : maojiayuan@gmail.com
5
+ # Date : 27/01/2018
6
+ #
7
+ # This file is part of Synchronized-BatchNorm-PyTorch.
8
+ # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9
+ # Distributed under MIT License.
10
+
11
+ import functools
12
+
13
+ from torch.nn.parallel.data_parallel import DataParallel
14
+
15
+ __all__ = [
16
+ 'CallbackContext',
17
+ 'execute_replication_callbacks',
18
+ 'DataParallelWithCallback',
19
+ 'patch_replication_callback'
20
+ ]
21
+
22
+
23
+ class CallbackContext(object):
24
+ pass
25
+
26
+
27
+ def execute_replication_callbacks(modules):
28
+ """
29
+ Execute a replication callback `__data_parallel_replicate__` on each module created by the original replication.
30
+
31
+ The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`
32
+
33
+ Note that, as all replicated modules are isomorphic, we assign each sub-module a context
34
+ (shared among multiple copies of this module on different devices).
35
+ Through this context, different copies can share some information.
36
+
37
+ We guarantee that the callback on the master copy (the first copy) will be called ahead of the callback
38
+ of any slave copies.
39
+ """
40
+ master_copy = modules[0]
41
+ nr_modules = len(list(master_copy.modules()))
42
+ ctxs = [CallbackContext() for _ in range(nr_modules)]
43
+
44
+ for i, module in enumerate(modules):
45
+ for j, m in enumerate(module.modules()):
46
+ if hasattr(m, '__data_parallel_replicate__'):
47
+ m.__data_parallel_replicate__(ctxs[j], i)
48
+
49
+
50
+ class DataParallelWithCallback(DataParallel):
51
+ """
52
+ Data Parallel with a replication callback.
53
+
54
+ A replication callback `__data_parallel_replicate__` on each module will be invoked after the copies are created by
55
+ the original `replicate` function.
56
+ The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`
57
+
58
+ Examples:
59
+ > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
60
+ > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
61
+ # sync_bn.__data_parallel_replicate__ will be invoked.
62
+ """
63
+
64
+ def replicate(self, module, device_ids):
65
+ modules = super(DataParallelWithCallback, self).replicate(module, device_ids)
66
+ execute_replication_callbacks(modules)
67
+ return modules
68
+
69
+
70
+ def patch_replication_callback(data_parallel):
71
+ """
72
+ Monkey-patch an existing `DataParallel` object. Add the replication callback.
73
+ Useful when you have customized `DataParallel` implementation.
74
+
75
+ Examples:
76
+ > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
77
+ > sync_bn = DataParallel(sync_bn, device_ids=[0, 1])
78
+ > patch_replication_callback(sync_bn)
79
+ # this is equivalent to
80
+ > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
81
+ > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
82
+ """
83
+
84
+ assert isinstance(data_parallel, DataParallel)
85
+
86
+ old_replicate = data_parallel.replicate
87
+
88
+ @functools.wraps(old_replicate)
89
+ def new_replicate(module, device_ids):
90
+ modules = old_replicate(module, device_ids)
91
+ execute_replication_callbacks(modules)
92
+ return modules
93
+
94
+ data_parallel.replicate = new_replicate
src/models/big/sync_batchnorm/unittest.py ADDED
@@ -0,0 +1,29 @@
1
+ # -*- coding: utf-8 -*-
2
+ # File : unittest.py
3
+ # Author : Jiayuan Mao
4
+ # Email : maojiayuan@gmail.com
5
+ # Date : 27/01/2018
6
+ #
7
+ # This file is part of Synchronized-BatchNorm-PyTorch.
8
+ # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9
+ # Distributed under MIT License.
10
+
11
+ import unittest
12
+ import torch
13
+
14
+
15
+ class TorchTestCase(unittest.TestCase):
16
+ def assertTensorClose(self, x, y):
17
+ adiff = float((x - y).abs().max())
18
+ if (y == 0).all():
19
+ rdiff = 'NaN'
20
+ else:
21
+ rdiff = float((adiff / y).abs().max())
22
+
23
+ message = (
24
+ 'Tensor close check failed\n'
25
+ 'adiff={}\n'
26
+ 'rdiff={}\n'
27
+ ).format(adiff, rdiff)
28
+ self.assertTrue(torch.allclose(x, y), message)
29
+
src/models/big/utils.py ADDED
@@ -0,0 +1,1193 @@
1
+ #!/usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+
4
+ ''' Utilities file
5
+ This file contains utility functions for bookkeeping, logging, and data loading.
6
+ Methods which directly affect training should either go in layers, the model,
7
+ or train_fns.py.
8
+ '''
9
+
10
+ from __future__ import print_function
11
+ import sys
12
+ import os
13
+ import numpy as np
14
+ import time
15
+ import datetime
16
+ import json
17
+ import pickle
18
+ from argparse import ArgumentParser
19
+ import animal_hash
20
+
21
+ import torch
22
+ import torch.nn as nn
23
+ import torch.nn.functional as F
24
+ import torchvision
25
+ import torchvision.transforms as transforms
26
+ from torch.utils.data import DataLoader
27
+
28
+ import datasets as dset
29
+
30
+ def prepare_parser():
31
+ usage = 'Parser for all scripts.'
32
+ parser = ArgumentParser(description=usage)
33
+
34
+ ### Dataset/Dataloader stuff ###
35
+ parser.add_argument(
36
+ '--dataset', type=str, default='I128_hdf5',
37
+ help='Which Dataset to train on, out of I128, I256, C10, C100; '
38
+ 'append "_hdf5" to use the hdf5 version for ILSVRC '
39
+ '(default: %(default)s)')
40
+ parser.add_argument(
41
+ '--augment', action='store_true', default=False,
42
+ help='Augment with random crops and flips (default: %(default)s)')
43
+ parser.add_argument(
44
+ '--num_workers', type=int, default=8,
45
+ help='Number of dataloader workers; consider using fewer for HDF5 '
46
+ '(default: %(default)s)')
47
+ parser.add_argument(
48
+ '--no_pin_memory', action='store_false', dest='pin_memory', default=True,
49
+ help='Pin data into memory through dataloader? (default: %(default)s)')
50
+ parser.add_argument(
51
+ '--shuffle', action='store_true', default=False,
52
+ help='Shuffle the data (strongly recommended)? (default: %(default)s)')
53
+ parser.add_argument(
54
+ '--load_in_mem', action='store_true', default=False,
55
+ help='Load all data into memory? (default: %(default)s)')
56
+ parser.add_argument(
57
+ '--use_multiepoch_sampler', action='store_true', default=False,
58
+ help='Use the multi-epoch sampler for dataloader? (default: %(default)s)')
59
+
60
+
61
+ ### Model stuff ###
62
+ parser.add_argument(
63
+ '--model', type=str, default='BigGAN',
64
+ help='Name of the model module (default: %(default)s)')
65
+ parser.add_argument(
66
+ '--G_param', type=str, default='SN',
67
+ help='Parameterization style to use for G, spectral norm (SN) or SVD (SVD)'
68
+ ' or None (default: %(default)s)')
69
+ parser.add_argument(
70
+ '--D_param', type=str, default='SN',
71
+ help='Parameterization style to use for D, spectral norm (SN) or SVD (SVD)'
72
+ ' or None (default: %(default)s)')
73
+ parser.add_argument(
74
+ '--G_ch', type=int, default=64,
75
+ help='Channel multiplier for G (default: %(default)s)')
76
+ parser.add_argument(
77
+ '--D_ch', type=int, default=64,
78
+ help='Channel multiplier for D (default: %(default)s)')
79
+ parser.add_argument(
80
+ '--G_depth', type=int, default=1,
81
+ help='Number of resblocks per stage in G? (default: %(default)s)')
82
+ parser.add_argument(
83
+ '--D_depth', type=int, default=1,
84
+ help='Number of resblocks per stage in D? (default: %(default)s)')
85
+ parser.add_argument(
86
+ '--D_thin', action='store_false', dest='D_wide', default=True,
87
+ help='Use the SN-GAN channel pattern for D? (default: %(default)s)')
88
+ parser.add_argument(
89
+ '--G_shared', action='store_true', default=False,
90
+ help='Use shared embeddings in G? (default: %(default)s)')
91
+ parser.add_argument(
92
+ '--shared_dim', type=int, default=0,
93
+ help='G\'s shared embedding dimensionality; if 0, will be equal to dim_z. '
94
+ '(default: %(default)s)')
95
+ parser.add_argument(
96
+ '--dim_z', type=int, default=128,
97
+ help='Noise dimensionality (default: %(default)s)')
98
+ parser.add_argument(
99
+ '--z_var', type=float, default=1.0,
100
+ help='Noise variance (default: %(default)s)')
101
+ parser.add_argument(
102
+ '--hier', action='store_true', default=False,
103
+ help='Use hierarchical z in G? (default: %(default)s)')
104
+ parser.add_argument(
105
+ '--cross_replica', action='store_true', default=False,
106
+ help='Cross_replica batchnorm in G? (default: %(default)s)')
107
+ parser.add_argument(
108
+ '--mybn', action='store_true', default=False,
109
+ help='Use my batchnorm, which supports standing stats (default: %(default)s)')
110
+ parser.add_argument(
111
+ '--G_nl', type=str, default='relu',
112
+ help='Activation function for G (default: %(default)s)')
113
+ parser.add_argument(
114
+ '--D_nl', type=str, default='relu',
115
+ help='Activation function for D (default: %(default)s)')
116
+ parser.add_argument(
117
+ '--G_attn', type=str, default='64',
118
+ help='What resolutions to use attention on for G (underscore separated) '
119
+ '(default: %(default)s)')
120
+ parser.add_argument(
121
+ '--D_attn', type=str, default='64',
122
+ help='What resolutions to use attention on for D (underscore separated) '
123
+ '(default: %(default)s)')
124
+ parser.add_argument(
125
+ '--norm_style', type=str, default='bn',
126
+ help='Normalizer style for G, one of bn [batchnorm], in [instancenorm], '
127
+ 'ln [layernorm], gn [groupnorm] (default: %(default)s)')
128
+
129
+ ### Model init stuff ###
130
+ parser.add_argument(
131
+ '--seed', type=int, default=0,
132
+ help='Random seed to use; affects both initialization and '
133
+ ' dataloading. (default: %(default)s)')
134
+ parser.add_argument(
135
+ '--G_init', type=str, default='ortho',
136
+ help='Init style to use for G (default: %(default)s)')
137
+ parser.add_argument(
138
+ '--D_init', type=str, default='ortho',
139
+ help='Init style to use for D (default: %(default)s)')
140
+ parser.add_argument(
141
+ '--skip_init', action='store_true', default=False,
142
+ help='Skip initialization, ideal for testing when ortho init was used '
143
+ '(default: %(default)s)')
144
+
145
+ ### Optimizer stuff ###
146
+ parser.add_argument(
147
+ '--G_lr', type=float, default=5e-5,
148
+ help='Learning rate to use for Generator (default: %(default)s)')
149
+ parser.add_argument(
150
+ '--D_lr', type=float, default=2e-4,
151
+ help='Learning rate to use for Discriminator (default: %(default)s)')
152
+ parser.add_argument(
153
+ '--G_B1', type=float, default=0.0,
154
+ help='Beta1 to use for Generator (default: %(default)s)')
155
+ parser.add_argument(
156
+ '--D_B1', type=float, default=0.0,
157
+ help='Beta1 to use for Discriminator (default: %(default)s)')
158
+ parser.add_argument(
159
+ '--G_B2', type=float, default=0.999,
160
+ help='Beta2 to use for Generator (default: %(default)s)')
161
+ parser.add_argument(
162
+ '--D_B2', type=float, default=0.999,
163
+ help='Beta2 to use for Discriminator (default: %(default)s)')
164
+
165
+ ### Batch size, parallel, and precision stuff ###
166
+ parser.add_argument(
167
+ '--batch_size', type=int, default=64,
168
+ help='Default overall batchsize (default: %(default)s)')
169
+ parser.add_argument(
170
+ '--G_batch_size', type=int, default=0,
171
+ help='Batch size to use for G; if 0, same as D (default: %(default)s)')
172
+ parser.add_argument(
173
+ '--num_G_accumulations', type=int, default=1,
174
+ help='Number of passes to accumulate G\'s gradients over '
175
+ '(default: %(default)s)')
176
+ parser.add_argument(
177
+ '--num_D_steps', type=int, default=2,
178
+ help='Number of D steps per G step (default: %(default)s)')
179
+ parser.add_argument(
180
+ '--num_D_accumulations', type=int, default=1,
181
+ help='Number of passes to accumulate D\'s gradients over '
182
+ '(default: %(default)s)')
183
+ parser.add_argument(
184
+ '--split_D', action='store_true', default=False,
185
+ help='Run D twice rather than concatenating inputs? (default: %(default)s)')
186
+ parser.add_argument(
187
+ '--num_epochs', type=int, default=100,
188
+ help='Number of epochs to train for (default: %(default)s)')
189
+ parser.add_argument(
190
+ '--parallel', action='store_true', default=False,
191
+ help='Train with multiple GPUs (default: %(default)s)')
192
+ parser.add_argument(
193
+ '--G_fp16', action='store_true', default=False,
194
+ help='Train with half-precision in G? (default: %(default)s)')
195
+ parser.add_argument(
196
+ '--D_fp16', action='store_true', default=False,
197
+ help='Train with half-precision in D? (default: %(default)s)')
198
+ parser.add_argument(
199
+ '--D_mixed_precision', action='store_true', default=False,
200
+ help='Train with half-precision activations but fp32 params in D? '
201
+ '(default: %(default)s)')
202
+ parser.add_argument(
203
+ '--G_mixed_precision', action='store_true', default=False,
204
+ help='Train with half-precision activations but fp32 params in G? '
205
+ '(default: %(default)s)')
206
+ parser.add_argument(
207
+ '--accumulate_stats', action='store_true', default=False,
208
+ help='Accumulate "standing" batchnorm stats? (default: %(default)s)')
209
+ parser.add_argument(
210
+ '--num_standing_accumulations', type=int, default=16,
211
+ help='Number of forward passes to use in accumulating standing stats? '
212
+ '(default: %(default)s)')
213
+
214
+ ### Bookkeping stuff ###
215
+ parser.add_argument(
216
+ '--G_eval_mode', action='store_true', default=False,
217
+ help='Run G in eval mode (running/standing stats?) at sample/test time? '
218
+ '(default: %(default)s)')
219
+ parser.add_argument(
220
+ '--save_every', type=int, default=2000,
221
+ help='Save every X iterations (default: %(default)s)')
222
+ parser.add_argument(
223
+ '--num_save_copies', type=int, default=2,
224
+ help='How many copies to save (default: %(default)s)')
225
+ parser.add_argument(
226
+ '--num_best_copies', type=int, default=2,
227
+ help='How many previous best checkpoints to save (default: %(default)s)')
228
+ parser.add_argument(
229
+ '--which_best', type=str, default='IS',
230
+ help='Which metric to use to determine when to save new "best" '
231
+ 'checkpoints, one of IS or FID (default: %(default)s)')
232
+ parser.add_argument(
233
+ '--no_fid', action='store_true', default=False,
234
+ help='Calculate IS only, not FID? (default: %(default)s)')
235
+ parser.add_argument(
236
+ '--test_every', type=int, default=5000,
237
+ help='Test every X iterations (default: %(default)s)')
238
+ parser.add_argument(
239
+ '--num_inception_images', type=int, default=50000,
240
+ help='Number of samples to compute inception metrics with '
241
+ '(default: %(default)s)')
242
+ parser.add_argument(
243
+ '--hashname', action='store_true', default=False,
244
+ help='Use a hash of the experiment name instead of the full config '
245
+ '(default: %(default)s)')
246
+ parser.add_argument(
247
+ '--base_root', type=str, default='',
248
+ help='Default location to store all weights, samples, data, and logs '
249
+ ' (default: %(default)s)')
250
+ parser.add_argument(
251
+ '--data_root', type=str, default='data',
252
+ help='Default location where data is stored (default: %(default)s)')
253
+ parser.add_argument(
254
+ '--weights_root', type=str, default='weights',
255
+ help='Default location to store weights (default: %(default)s)')
256
+ parser.add_argument(
257
+ '--logs_root', type=str, default='logs',
258
+ help='Default location to store logs (default: %(default)s)')
259
+ parser.add_argument(
260
+ '--samples_root', type=str, default='samples',
261
+ help='Default location to store samples (default: %(default)s)')
262
+ parser.add_argument(
263
+ '--pbar', type=str, default='mine',
264
+ help='Type of progressbar to use; one of "mine" or "tqdm" '
265
+ '(default: %(default)s)')
266
+ parser.add_argument(
267
+ '--name_suffix', type=str, default='',
268
+ help='Suffix for experiment name for loading weights for sampling '
269
+ '(consider "best0") (default: %(default)s)')
270
+ parser.add_argument(
271
+ '--experiment_name', type=str, default='',
272
+ help='Optionally override the automatic experiment naming with this arg. '
273
+ '(default: %(default)s)')
274
+ parser.add_argument(
275
+ '--config_from_name', action='store_true', default=False,
276
+ help='Use a hash of the experiment name instead of the full config '
277
+ '(default: %(default)s)')
278
+
279
+ ### EMA Stuff ###
280
+ parser.add_argument(
281
+ '--ema', action='store_true', default=False,
282
+ help='Keep an ema of G\'s weights? (default: %(default)s)')
283
+ parser.add_argument(
284
+ '--ema_decay', type=float, default=0.9999,
285
+ help='EMA decay rate (default: %(default)s)')
286
+ parser.add_argument(
287
+ '--use_ema', action='store_true', default=False,
288
+ help='Use the EMA parameters of G for evaluation? (default: %(default)s)')
289
+ parser.add_argument(
290
+ '--ema_start', type=int, default=0,
291
+ help='When to start updating the EMA weights (default: %(default)s)')
292
+
293
+ ### Numerical precision and SV stuff ###
294
+ parser.add_argument(
295
+ '--adam_eps', type=float, default=1e-8,
296
+ help='epsilon value to use for Adam (default: %(default)s)')
297
+ parser.add_argument(
298
+ '--BN_eps', type=float, default=1e-5,
299
+ help='epsilon value to use for BatchNorm (default: %(default)s)')
300
+ parser.add_argument(
301
+ '--SN_eps', type=float, default=1e-8,
302
+ help='epsilon value to use for Spectral Norm (default: %(default)s)')
303
+ parser.add_argument(
304
+ '--num_G_SVs', type=int, default=1,
305
+ help='Number of SVs to track in G (default: %(default)s)')
306
+ parser.add_argument(
307
+ '--num_D_SVs', type=int, default=1,
308
+ help='Number of SVs to track in D (default: %(default)s)')
309
+ parser.add_argument(
310
+ '--num_G_SV_itrs', type=int, default=1,
311
+ help='Number of SV itrs in G (default: %(default)s)')
312
+ parser.add_argument(
313
+ '--num_D_SV_itrs', type=int, default=1,
314
+ help='Number of SV itrs in D (default: %(default)s)')
315
+
316
+ ### Ortho reg stuff ###
317
+ parser.add_argument(
318
+ '--G_ortho', type=float, default=0.0, # 1e-4 is default for BigGAN
319
+ help='Modified ortho reg coefficient in G (default: %(default)s)')
320
+ parser.add_argument(
321
+ '--D_ortho', type=float, default=0.0,
322
+ help='Modified ortho reg coefficient in D (default: %(default)s)')
323
+ parser.add_argument(
324
+ '--toggle_grads', action='store_true', default=True,
325
+ help='Toggle D and G''s "requires_grad" settings when not training them? '
326
+ ' (default: %(default)s)')
327
+
328
+ ### Which train function ###
329
+ parser.add_argument(
330
+ '--which_train_fn', type=str, default='GAN',
331
+ help='How2trainyourbois (default: %(default)s)')
332
+
333
+ ### Resume training stuff
334
+ parser.add_argument(
335
+ '--load_weights', type=str, default='',
336
+ help='Suffix for which weights to load (e.g. best0, copy0) '
337
+ '(default: %(default)s)')
338
+ parser.add_argument(
339
+ '--resume', action='store_true', default=False,
340
+ help='Resume training? (default: %(default)s)')
341
+
342
+ ### Log stuff ###
343
+ parser.add_argument(
344
+ '--logstyle', type=str, default='%3.3e',
345
+ help='What style to use when logging training metrics?'
346
+ 'One of: %#.#f/ %#.#e (float/exp, text),'
347
+ 'pickle (python pickle),'
348
+ 'npz (numpy zip),'
349
+ 'mat (MATLAB .mat file) (default: %(default)s)')
350
+ parser.add_argument(
351
+ '--log_G_spectra', action='store_true', default=False,
352
+ help='Log the top 3 singular values in each SN layer in G? '
353
+ '(default: %(default)s)')
354
+ parser.add_argument(
355
+ '--log_D_spectra', action='store_true', default=False,
356
+ help='Log the top 3 singular values in each SN layer in D? '
357
+ '(default: %(default)s)')
358
+ parser.add_argument(
359
+ '--sv_log_interval', type=int, default=10,
360
+ help='Iteration interval for logging singular values '
361
+ ' (default: %(default)s)')
362
+
363
+ return parser
364
+
365
+ # Arguments for sample.py; not presently used in train.py
366
+ def add_sample_parser(parser):
367
+ parser.add_argument(
368
+ '--sample_npz', action='store_true', default=False,
369
+ help='Sample "sample_num_npz" images and save to npz? '
370
+ '(default: %(default)s)')
371
+ parser.add_argument(
372
+ '--sample_num_npz', type=int, default=50000,
373
+ help='Number of images to sample when sampling NPZs '
374
+ '(default: %(default)s)')
375
+ parser.add_argument(
376
+ '--sample_sheets', action='store_true', default=False,
377
+ help='Produce class-conditional sample sheets and stick them in '
378
+ 'the samples root? (default: %(default)s)')
379
+ parser.add_argument(
380
+ '--sample_interps', action='store_true', default=False,
381
+ help='Produce interpolation sheets and stick them in '
382
+ 'the samples root? (default: %(default)s)')
383
+ parser.add_argument(
384
+ '--sample_sheet_folder_num', type=int, default=-1,
385
+ help='Number to use for the folder for these sample sheets '
386
+ '(default: %(default)s)')
387
+ parser.add_argument(
388
+ '--sample_random', action='store_true', default=False,
389
+ help='Produce a single random sheet? (default: %(default)s)')
390
+ parser.add_argument(
391
+ '--sample_trunc_curves', type=str, default='',
392
+ help='Get inception metrics with a range of variances? '
393
+ 'To use this, specify a startpoint, step, and endpoint, e.g. '
394
+ '--sample_trunc_curves 0.2_0.1_1.0 for a startpoint of 0.2, '
395
+ 'endpoint of 1.0, and stepsize of 0.1. Note that this is '
396
+ 'not exactly identical to using tf.truncated_normal, but should '
397
+ 'have approximately the same effect. (default: %(default)s)')
398
+ parser.add_argument(
399
+ '--sample_inception_metrics', action='store_true', default=False,
400
+ help='Calculate Inception metrics with sample.py? (default: %(default)s)')
401
+ return parser
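+
+ # Illustrative use (not in the upstream file): the scripts turn these parsers
+ # into a plain config dict, roughly as follows; the values shown are the
+ # defaults declared above.
+ # >>> parser = add_sample_parser(prepare_parser())  # sample parser only for sample.py
+ # >>> config = vars(parser.parse_args([]))
+ # >>> config['dataset'], config['batch_size']
+ # ('I128_hdf5', 64)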
402
+
403
+ # Convenience dicts
404
+ dset_dict = {'I32': dset.ImageFolder, 'I64': dset.ImageFolder,
405
+ 'I128': dset.ImageFolder, 'I256': dset.ImageFolder,
406
+ 'I32_hdf5': dset.ILSVRC_HDF5, 'I64_hdf5': dset.ILSVRC_HDF5,
407
+ 'I128_hdf5': dset.ILSVRC_HDF5, 'I256_hdf5': dset.ILSVRC_HDF5,
408
+ 'C10': dset.CIFAR10, 'C100': dset.CIFAR100}
409
+ imsize_dict = {'I32': 32, 'I32_hdf5': 32,
410
+ 'I64': 64, 'I64_hdf5': 64,
411
+ 'I128': 128, 'I128_hdf5': 128,
412
+ 'I256': 256, 'I256_hdf5': 256,
413
+ 'C10': 32, 'C100': 32}
414
+ root_dict = {'I32': 'ImageNet', 'I32_hdf5': 'ILSVRC32.hdf5',
415
+ 'I64': 'ImageNet', 'I64_hdf5': 'ILSVRC64.hdf5',
416
+ 'I128': 'ImageNet', 'I128_hdf5': 'ILSVRC128.hdf5',
417
+ 'I256': 'ImageNet', 'I256_hdf5': 'ILSVRC256.hdf5',
418
+ 'C10': 'cifar', 'C100': 'cifar'}
419
+ nclass_dict = {'I32': 1000, 'I32_hdf5': 1000,
420
+ 'I64': 1000, 'I64_hdf5': 1000,
421
+ 'I128': 1000, 'I128_hdf5': 1000,
422
+ 'I256': 1000, 'I256_hdf5': 1000,
423
+ 'C10': 10, 'C100': 100}
424
+ # Number of classes to put per sample sheet
425
+ classes_per_sheet_dict = {'I32': 50, 'I32_hdf5': 50,
426
+ 'I64': 50, 'I64_hdf5': 50,
427
+ 'I128': 20, 'I128_hdf5': 20,
428
+ 'I256': 20, 'I256_hdf5': 20,
429
+ 'C10': 10, 'C100': 100}
430
+ activation_dict = {'inplace_relu': nn.ReLU(inplace=True),
431
+ 'relu': nn.ReLU(inplace=False),
432
+ 'ir': nn.ReLU(inplace=True),}
433
+
434
+ class CenterCropLongEdge(object):
435
+ """Crops the given PIL Image on the long edge.
436
+ The crop size is always (min(w, h), min(w, h)), i.e. a square crop on the
438
+ short edge; this transform takes no constructor arguments.
440
+ """
441
+ def __call__(self, img):
442
+ """
443
+ Args:
444
+ img (PIL Image): Image to be cropped.
445
+ Returns:
446
+ PIL Image: Cropped image.
447
+ """
448
+ return transforms.functional.center_crop(img, min(img.size))
449
+
450
+ def __repr__(self):
451
+ return self.__class__.__name__
452
+
453
+ class RandomCropLongEdge(object):
454
+ """Crops the given PIL Image on the long edge with a random start point.
455
+ The crop size is always (min(w, h), min(w, h)), i.e. a square crop on the
457
+ short edge; this transform takes no constructor arguments.
459
+ """
460
+ def __call__(self, img):
461
+ """
462
+ Args:
463
+ img (PIL Image): Image to be cropped.
464
+ Returns:
465
+ PIL Image: Cropped image.
466
+ """
467
+ size = (min(img.size), min(img.size))
468
+ # Only step forward along this edge if it's the long edge
469
+ i = (0 if size[0] == img.size[0]
470
+ else np.random.randint(low=0,high=img.size[0] - size[0]))
471
+ j = (0 if size[1] == img.size[1]
472
+ else np.random.randint(low=0,high=img.size[1] - size[1]))
473
+ return transforms.functional.crop(img, i, j, size[0], size[1])
474
+
475
+ def __repr__(self):
476
+ return self.__class__.__name__
477
+
478
+
479
+ # multi-epoch Dataset sampler to avoid memory leakage and enable resumption of
480
+ # training from the same sample regardless of whether we stop mid-epoch
481
+ class MultiEpochSampler(torch.utils.data.Sampler):
482
+ r"""Samples elements randomly over multiple epochs
483
+
484
+ Arguments:
485
+ data_source (Dataset): dataset to sample from
486
+ num_epochs (int) : Number of times to loop over the dataset
487
+ start_itr (int) : which iteration to begin from
488
+ """
489
+
490
+ def __init__(self, data_source, num_epochs, start_itr=0, batch_size=128):
491
+ self.data_source = data_source
492
+ self.num_samples = len(self.data_source)
493
+ self.num_epochs = num_epochs
494
+ self.start_itr = start_itr
495
+ self.batch_size = batch_size
496
+
497
+ if not isinstance(self.num_samples, int) or self.num_samples <= 0:
498
+ raise ValueError("num_samples should be a positive integeral "
499
+ "value, but got num_samples={}".format(self.num_samples))
500
+
501
+ def __iter__(self):
502
+ n = len(self.data_source)
503
+ # Determine number of epochs
504
+ num_epochs = int(np.ceil((n * self.num_epochs
505
+ - (self.start_itr * self.batch_size)) / float(n)))
506
+ # Sample all the indices, and then grab the last num_epochs index sets;
507
+ # This ensures if we're starting at epoch 4, we're still grabbing epoch 4's
508
+ # indices
509
+ out = [torch.randperm(n) for epoch in range(self.num_epochs)][-num_epochs:]
510
+ # Ignore the first start_itr % n indices of the first epoch
511
+ out[0] = out[0][(self.start_itr * self.batch_size % n):]
512
+ # if self.replacement:
513
+ # return iter(torch.randint(high=n, size=(self.num_samples,), dtype=torch.int64).tolist())
514
+ # return iter(.tolist())
515
+ output = torch.cat(out).tolist()
516
+ print('Length dataset output is %d' % len(output))
517
+ return iter(output)
518
+
519
+ def __len__(self):
520
+ return len(self.data_source) * self.num_epochs - self.start_itr * self.batch_size
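+
+ # A worked example (illustrative): with 1000 samples, batch_size=100 and
+ # start_itr=25, the first 25 * 100 = 2500 indices are skipped, so resumption
+ # lands mid-epoch and exactly 2500 of the original 5000 indices remain:
+ # >>> s = MultiEpochSampler(range(1000), num_epochs=5, start_itr=25, batch_size=100)
+ # >>> len(s)
+ # 2500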
521
+
522
+
523
+ # Convenience function to centralize all data loaders
524
+ def get_data_loaders(dataset, data_root=None, augment=False, batch_size=64,
525
+ num_workers=8, shuffle=True, load_in_mem=False, hdf5=False,
526
+ pin_memory=True, drop_last=True, start_itr=0,
527
+ num_epochs=500, use_multiepoch_sampler=False,
528
+ **kwargs):
529
+
530
+ # Append /FILENAME.hdf5 to root if using hdf5
531
+ data_root += '/%s' % root_dict[dataset]
532
+ print('Using dataset root location %s' % data_root)
533
+
534
+ which_dataset = dset_dict[dataset]
535
+ norm_mean = [0.5,0.5,0.5]
536
+ norm_std = [0.5,0.5,0.5]
537
+ image_size = imsize_dict[dataset]
538
+ # For image folder datasets, name of the file where we store the precomputed
539
+ # image locations to avoid having to walk the dirs every time we load.
540
+ dataset_kwargs = {'index_filename': '%s_imgs.npz' % dataset}
541
+
542
+ # HDF5 datasets have their own inbuilt transform; no need for a train_transform
543
+ if 'hdf5' in dataset:
544
+ train_transform = None
545
+ else:
546
+ if augment:
547
+ print('Data will be augmented...')
548
+ if dataset in ['C10', 'C100']:
549
+ train_transform = [transforms.RandomCrop(32, padding=4),
550
+ transforms.RandomHorizontalFlip()]
551
+ else:
552
+ train_transform = [RandomCropLongEdge(),
553
+ transforms.Resize(image_size),
554
+ transforms.RandomHorizontalFlip()]
555
+ else:
556
+ print('Data will not be augmented...')
557
+ if dataset in ['C10', 'C100']:
558
+ train_transform = []
559
+ else:
560
+ train_transform = [CenterCropLongEdge(), transforms.Resize(image_size)]
561
+ # train_transform = [transforms.Resize(image_size), transforms.CenterCrop]
562
+ train_transform = transforms.Compose(train_transform + [
563
+ transforms.ToTensor(),
564
+ transforms.Normalize(norm_mean, norm_std)])
565
+ train_set = which_dataset(root=data_root, transform=train_transform,
566
+ load_in_mem=load_in_mem, **dataset_kwargs)
567
+
568
+ # Prepare loader; the loaders list is for forward compatibility with
569
+ # using validation / test splits.
570
+ loaders = []
571
+ if use_multiepoch_sampler:
572
+ print('Using multiepoch sampler from start_itr %d...' % start_itr)
573
+ loader_kwargs = {'num_workers': num_workers, 'pin_memory': pin_memory}
574
+ sampler = MultiEpochSampler(train_set, num_epochs, start_itr, batch_size)
575
+ train_loader = DataLoader(train_set, batch_size=batch_size,
576
+ sampler=sampler, **loader_kwargs)
577
+ else:
578
+ loader_kwargs = {'num_workers': num_workers, 'pin_memory': pin_memory,
579
+ 'drop_last': drop_last} # Default, drop last incomplete batch
580
+ train_loader = DataLoader(train_set, batch_size=batch_size,
581
+ shuffle=shuffle, **loader_kwargs)
582
+ loaders.append(train_loader)
583
+ return loaders
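+
+ # Illustrative call (not in the upstream file), roughly how a training script
+ # might request a CIFAR-10 loader via the convenience dicts above:
+ # >>> loaders = get_data_loaders('C10', data_root='data', augment=True,
+ # ...                            batch_size=64, shuffle=True)
+ # >>> train_loader = loaders[0]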
584
+
585
+
586
+ # Utility function to seed RNGs
587
+ def seed_rng(seed):
588
+ torch.manual_seed(seed)
589
+ torch.cuda.manual_seed(seed)
590
+ np.random.seed(seed)
591
+
592
+
593
+ # Utility to peg all roots to a base root
594
+ # If a base root folder is provided, peg all other root folders to it.
595
+ def update_config_roots(config):
596
+ if config['base_root']:
597
+ print('Pegging all root folders to base root %s' % config['base_root'])
598
+ for key in ['data', 'weights', 'logs', 'samples']:
599
+ config['%s_root' % key] = '%s/%s' % (config['base_root'], key)
600
+ return config
601
+
602
+
603
+ # Utility to prepare root folders if they don't exist; parent folder must exist
604
+ def prepare_root(config):
605
+ for key in ['weights_root', 'logs_root', 'samples_root']:
606
+ if not os.path.exists(config[key]):
607
+ print('Making directory %s for %s...' % (config[key], key))
608
+ os.mkdir(config[key])
609
+
610
+
611
+ # Simple wrapper that applies EMA to a model. Could be better done in 1.0 using
612
+ # the parameters() and buffers() module functions, but for now this works
613
+ # with state_dicts using .copy_
614
+ class ema(object):
615
+ def __init__(self, source, target, decay=0.9999, start_itr=0):
616
+ self.source = source
617
+ self.target = target
618
+ self.decay = decay
619
+ # Optional parameter indicating what iteration to start the decay at
620
+ self.start_itr = start_itr
621
+ # Initialize target's params to be source's
622
+ self.source_dict = self.source.state_dict()
623
+ self.target_dict = self.target.state_dict()
624
+ print('Initializing EMA parameters to be source parameters...')
625
+ with torch.no_grad():
626
+ for key in self.source_dict:
627
+ self.target_dict[key].data.copy_(self.source_dict[key].data)
628
+ # target_dict[key].data = source_dict[key].data # Doesn't work!
629
+
630
+ def update(self, itr=None):
631
+ # If an iteration counter is provided and itr is less than the start itr,
632
+ # peg the ema weights to the underlying weights.
633
+ if itr and itr < self.start_itr:
634
+ decay = 0.0
635
+ else:
636
+ decay = self.decay
637
+ with torch.no_grad():
638
+ for key in self.source_dict:
639
+ self.target_dict[key].data.copy_(self.target_dict[key].data * decay
640
+ + self.source_dict[key].data * (1 - decay))
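+
+ # Illustrative wiring (not in the upstream file): the EMA copy shadows the
+ # live generator and is nudged toward it once per G update; before
+ # `start_itr` it is pegged to the source exactly. `G_shadow` and `itr` are
+ # hypothetical names for a second generator copy and the iteration counter.
+ # >>> G_ema = ema(source=G, target=G_shadow, decay=0.9999, start_itr=1000)
+ # >>> G_ema.update(itr)  # call after each optimizer step on G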
641
+
642
+
643
+ # Apply modified ortho reg to a model
644
+ # This function is an optimized version that directly computes the gradient,
645
+ # instead of computing and then differentiating the loss.
646
+ def ortho(model, strength=1e-4, blacklist=[]):
647
+ with torch.no_grad():
648
+ for param in model.parameters():
649
+ # Only apply this to parameters with at least 2 axes, and not in the blacklist
650
+ if len(param.shape) < 2 or any([param is item for item in blacklist]):
651
+ continue
652
+ w = param.view(param.shape[0], -1)
653
+ grad = (2 * torch.mm(torch.mm(w, w.t())
654
+ * (1. - torch.eye(w.shape[0], device=w.device)), w))
655
+ param.grad.data += strength * grad.view(param.shape)
656
+
657
+
658
+ # Default ortho reg
659
+ # This function is an optimized version that directly computes the gradient,
660
+ # instead of computing and then differentiating the loss.
661
+ def default_ortho(model, strength=1e-4, blacklist=[]):
662
+ with torch.no_grad():
663
+ for param in model.parameters():
664
+ # Only apply this to parameters with at least 2 axes & not in blacklist
665
+ if len(param.shape) < 2 or param in blacklist:
666
+ continue
667
+ w = param.view(param.shape[0], -1)
668
+ grad = (2 * torch.mm(torch.mm(w, w.t())
669
+ - torch.eye(w.shape[0], device=w.device), w))
670
+ param.grad.data += strength * grad.view(param.shape)
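+
+ # Illustrative call order (not in the upstream file): both ortho variants add
+ # the regularizer directly to existing gradients, so they belong between
+ # backward() and the optimizer step:
+ # >>> loss.backward()
+ # >>> ortho(G, strength=1e-4)  # or default_ortho(...)
+ # >>> G.optim.step()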
671
+
672
+
673
+ # Convenience utility to switch off requires_grad
674
+ def toggle_grad(model, on_or_off):
675
+ for param in model.parameters():
676
+ param.requires_grad = on_or_off
677
+
678
+
679
+ # Function to join strings or ignore them
680
+ # Base string is the string to link "strings," while strings
681
+ # is a list of strings or Nones.
682
+ def join_strings(base_string, strings):
683
+ return base_string.join([item for item in strings if item])
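+ # e.g. join_strings('_', ['G_optim', 'best0']) -> 'G_optim_best0', while
+ # Nones and empty strings are dropped: join_strings('_', ['G', None]) -> 'G'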
684
+
685
+
686
+ # Save a model's weights, optimizer, and the state_dict
687
+ def save_weights(G, D, state_dict, weights_root, experiment_name,
688
+ name_suffix=None, G_ema=None):
689
+ root = '/'.join([weights_root, experiment_name])
690
+ if not os.path.exists(root):
691
+ os.mkdir(root)
692
+ if name_suffix:
693
+ print('Saving weights to %s/%s...' % (root, name_suffix))
694
+ else:
695
+ print('Saving weights to %s...' % root)
696
+ torch.save(G.state_dict(),
697
+ '%s/%s.pth' % (root, join_strings('_', ['G', name_suffix])))
698
+ torch.save(G.optim.state_dict(),
699
+ '%s/%s.pth' % (root, join_strings('_', ['G_optim', name_suffix])))
700
+ torch.save(D.state_dict(),
701
+ '%s/%s.pth' % (root, join_strings('_', ['D', name_suffix])))
702
+ torch.save(D.optim.state_dict(),
703
+ '%s/%s.pth' % (root, join_strings('_', ['D_optim', name_suffix])))
704
+ torch.save(state_dict,
705
+ '%s/%s.pth' % (root, join_strings('_', ['state_dict', name_suffix])))
706
+ if G_ema is not None:
707
+ torch.save(G_ema.state_dict(),
708
+ '%s/%s.pth' % (root, join_strings('_', ['G_ema', name_suffix])))
709
+
710
+
711
+ # Load a model's weights, optimizer, and the state_dict
712
+ def load_weights(G, D, state_dict, weights_root, experiment_name,
713
+ name_suffix=None, G_ema=None, strict=True, load_optim=True):
714
+ root = '/'.join([weights_root, experiment_name])
715
+ if name_suffix:
716
+ print('Loading %s weights from %s...' % (name_suffix, root))
717
+ else:
718
+ print('Loading weights from %s...' % root)
719
+ if G is not None:
720
+ G.load_state_dict(
721
+ torch.load('%s/%s.pth' % (root, join_strings('_', ['G', name_suffix]))),
722
+ strict=strict)
723
+ if load_optim:
724
+ G.optim.load_state_dict(
725
+ torch.load('%s/%s.pth' % (root, join_strings('_', ['G_optim', name_suffix]))))
726
+ if D is not None:
727
+ D.load_state_dict(
728
+ torch.load('%s/%s.pth' % (root, join_strings('_', ['D', name_suffix]))),
729
+ strict=strict)
730
+ if load_optim:
731
+ D.optim.load_state_dict(
732
+ torch.load('%s/%s.pth' % (root, join_strings('_', ['D_optim', name_suffix]))))
733
+ # Load state dict
734
+ for item in state_dict:
735
+ state_dict[item] = torch.load('%s/%s.pth' % (root, join_strings('_', ['state_dict', name_suffix])))[item]
736
+ if G_ema is not None:
737
+ G_ema.load_state_dict(
738
+ torch.load('%s/%s.pth' % (root, join_strings('_', ['G_ema', name_suffix]))),
739
+ strict=strict)
740
+
741
+
742
+ ''' MetricsLogger originally stolen from VoxNet source code.
743
+ Used for logging inception metrics'''
744
+ class MetricsLogger(object):
745
+ def __init__(self, fname, reinitialize=False):
746
+ self.fname = fname
747
+ self.reinitialize = reinitialize
748
+ if os.path.exists(self.fname):
749
+ if self.reinitialize:
750
+ print('{} exists, deleting...'.format(self.fname))
751
+ os.remove(self.fname)
752
+
753
+ def log(self, record=None, **kwargs):
754
+ """
755
+ Assumption: no newlines in the input.
756
+ """
757
+ if record is None:
758
+ record = {}
759
+ record.update(kwargs)
760
+ record['_stamp'] = time.time()
761
+ with open(self.fname, 'a') as f:
762
+ f.write(json.dumps(record, ensure_ascii=True) + '\n')
763
+
764
+
765
+ # Logstyle is either:
766
+ # '%#.#f' for floating point representation in text
767
+ # '%#.#e' for exponent representation in text
768
+ # 'npz' for output to npz # NOT YET SUPPORTED
769
+ # 'pickle' for output to a python pickle # NOT YET SUPPORTED
770
+ # 'mat' for output to a MATLAB .mat file # NOT YET SUPPORTED
771
+ class MyLogger(object):
772
+ def __init__(self, fname, reinitialize=False, logstyle='%3.3f'):
773
+ self.root = fname
774
+ if not os.path.exists(self.root):
775
+ os.mkdir(self.root)
776
+ self.reinitialize = reinitialize
777
+ self.metrics = []
778
+ self.logstyle = logstyle # One of '%3.3f' or like '%3.3e'
779
+
780
+ # Delete log if re-starting and log already exists
781
+ def reinit(self, item):
782
+ if os.path.exists('%s/%s.log' % (self.root, item)):
783
+ if self.reinitialize:
784
+ # Only print the removal message once for singular value logs
785
+ if 'sv' in item:
786
+ if not any('sv' in item for item in self.metrics):
787
+ print('Deleting singular value logs...')
788
+ else:
789
+ print('{} exists, deleting...'.format('%s_%s.log' % (self.root, item)))
790
+ os.remove('%s/%s.log' % (self.root, item))
791
+
792
+ # Log in plaintext; this is designed for being read in MATLAB (sorry not sorry)
793
+ def log(self, itr, **kwargs):
794
+ for arg in kwargs:
795
+ if arg not in self.metrics:
796
+ if self.reinitialize:
797
+ self.reinit(arg)
798
+ self.metrics += [arg]
799
+ if self.logstyle == 'pickle':
800
+ print('Pickle not currently supported...')
801
+ # with open('%s/%s.log' % (self.root, arg), 'a') as f:
802
+ # pickle.dump(kwargs[arg], f)
803
+ elif self.logstyle == 'mat':
804
+ print('.mat logstyle not currently supported...')
805
+ else:
806
+ with open('%s/%s.log' % (self.root, arg), 'a') as f:
807
+ f.write('%d: %s\n' % (itr, self.logstyle % kwargs[arg]))
808
+
809
+
810
+ # Write some metadata to the logs directory
811
+ def write_metadata(logs_root, experiment_name, config, state_dict):
812
+ with open(('%s/%s/metalog.txt' %
813
+ (logs_root, experiment_name)), 'w') as writefile:
814
+ writefile.write('datetime: %s\n' % str(datetime.datetime.now()))
815
+ writefile.write('config: %s\n' % str(config))
816
+ writefile.write('state: %s\n' %str(state_dict))
817
+
818
+
819
+ """
820
+ Very basic progress indicator to wrap an iterable in.
821
+
822
+ Author: Jan SchlΓΌter
823
+ Andy's adds: time elapsed in addition to ETA, makes it possible to add
824
+ estimated time to 1k iters instead of estimated time to completion.
825
+ """
826
+ def progress(items, desc='', total=None, min_delay=0.1, displaytype='s1k'):
827
+ """
828
+ Returns a generator over `items`, printing the number and percentage of
829
+ items processed and the estimated remaining processing time before yielding
830
+ the next item. `total` gives the total number of items (required if `items`
831
+ has no length), and `min_delay` gives the minimum time in seconds between
832
+ subsequent prints. `desc` gives an optional prefix text (end with a space).
833
+ """
834
+ total = total or len(items)
835
+ t_start = time.time()
836
+ t_last = 0
837
+ for n, item in enumerate(items):
838
+ t_now = time.time()
839
+ if t_now - t_last > min_delay:
840
+ print("\r%s%d/%d (%6.2f%%)" % (
841
+ desc, n+1, total, n / float(total) * 100), end=" ")
842
+ if n > 0:
843
+
844
+ if displaytype == 's1k': # minutes/seconds for 1000 iters
845
+ next_1000 = n + (1000 - n%1000)
846
+ t_done = t_now - t_start
847
+ t_1k = t_done / n * next_1000
848
+ outlist = list(divmod(t_done, 60)) + list(divmod(t_1k - t_done, 60))
849
+ print("(TE/ET1k: %d:%02d / %d:%02d)" % tuple(outlist), end=" ")
850
+ else:# displaytype == 'eta':
851
+ t_done = t_now - t_start
852
+ t_total = t_done / n * total
853
+ outlist = list(divmod(t_done, 60)) + list(divmod(t_total - t_done, 60))
854
+ print("(TE/ETA: %d:%02d / %d:%02d)" % tuple(outlist), end=" ")
855
+
856
+ sys.stdout.flush()
857
+ t_last = t_now
858
+ yield item
859
+ t_total = time.time() - t_start
860
+ print("\r%s%d/%d (100.00%%) (took %d:%02d)" % ((desc, total, total) +
861
+ divmod(t_total, 60)))
862
+
863
+
864
+ # Sample function for use with inception metrics
865
+ def sample(G, z_, y_, config):
866
+ with torch.no_grad():
867
+ z_.sample_()
868
+ y_.sample_()
869
+ if config['parallel']:
870
+ G_z = nn.parallel.data_parallel(G, (z_, G.shared(y_)))
871
+ else:
872
+ G_z = G(z_, G.shared(y_))
873
+ return G_z, y_
874
+
875
+
876
+ # Sample function for sample sheets
877
+ def sample_sheet(G, classes_per_sheet, num_classes, samples_per_class, parallel,
878
+ samples_root, experiment_name, folder_number, z_=None):
879
+ # Prepare sample directory
880
+ if not os.path.isdir('%s/%s' % (samples_root, experiment_name)):
881
+ os.mkdir('%s/%s' % (samples_root, experiment_name))
882
+ if not os.path.isdir('%s/%s/%d' % (samples_root, experiment_name, folder_number)):
883
+ os.mkdir('%s/%s/%d' % (samples_root, experiment_name, folder_number))
884
+ # loop over total number of sheets
885
+ for i in range(num_classes // classes_per_sheet):
886
+ ims = []
887
+ y = torch.arange(i * classes_per_sheet, (i + 1) * classes_per_sheet, device='cuda')
888
+ for j in range(samples_per_class):
889
+ if (z_ is not None) and hasattr(z_, 'sample_') and classes_per_sheet <= z_.size(0):
890
+ z_.sample_()
891
+ else:
892
+ z_ = torch.randn(classes_per_sheet, G.dim_z, device='cuda')
893
+ with torch.no_grad():
894
+ if parallel:
895
+ o = nn.parallel.data_parallel(G, (z_[:classes_per_sheet], G.shared(y)))
896
+ else:
897
+ o = G(z_[:classes_per_sheet], G.shared(y))
898
+
899
+ ims += [o.data.cpu()]
900
+ # This line should properly unroll the images
901
+ out_ims = torch.stack(ims, 1).view(-1, ims[0].shape[1], ims[0].shape[2],
902
+ ims[0].shape[3]).data.float().cpu()
903
+ # The path for the samples
904
+ image_filename = '%s/%s/%d/samples%d.jpg' % (samples_root, experiment_name,
905
+ folder_number, i)
906
+ torchvision.utils.save_image(out_ims, image_filename,
907
+ nrow=samples_per_class, normalize=True)
908
+
909
+
910
+ # Interp function; expects x0 and x1 to be of shape (shape0, 1, rest_of_shape..)
+ def interp(x0, x1, num_midpoints):
+     lerp = torch.linspace(0, 1.0, num_midpoints + 2, device='cuda').to(x0.dtype)
+     return ((x0 * (1 - lerp.view(1, -1, 1))) + (x1 * lerp.view(1, -1, 1)))
+
+
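+ # Example (sketch): linearly interpolate two batches of latents; with
+ # num_midpoints=8 each row has 10 points, endpoints included.
+ #     x0 = torch.randn(4, 1, 128, device='cuda')
+ #     x1 = torch.randn(4, 1, 128, device='cuda')
+ #     path = interp(x0, x1, num_midpoints=8)  # shape: (4, 10, 128)
+
+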
+ # Interp sheet function
+ # Supports full, class-wise and intra-class interpolation
+ def interp_sheet(G, num_per_sheet, num_midpoints, num_classes, parallel,
+                  samples_root, experiment_name, folder_number, sheet_number=0,
+                  fix_z=False, fix_y=False, device='cuda'):
+     # Prepare zs and ys
+     if fix_z:  # If fix z, only sample 1 z per row
+         zs = torch.randn(num_per_sheet, 1, G.dim_z, device=device)
+         zs = zs.repeat(1, num_midpoints + 2, 1).view(-1, G.dim_z)
+     else:
+         zs = interp(torch.randn(num_per_sheet, 1, G.dim_z, device=device),
+                     torch.randn(num_per_sheet, 1, G.dim_z, device=device),
+                     num_midpoints).view(-1, G.dim_z)
+     if fix_y:  # If fix y, only sample 1 y per row
+         ys = sample_1hot(num_per_sheet, num_classes)
+         ys = G.shared(ys).view(num_per_sheet, 1, -1)
+         ys = ys.repeat(1, num_midpoints + 2, 1).view(num_per_sheet * (num_midpoints + 2), -1)
+     else:
+         ys = interp(G.shared(sample_1hot(num_per_sheet, num_classes)).view(num_per_sheet, 1, -1),
+                     G.shared(sample_1hot(num_per_sheet, num_classes)).view(num_per_sheet, 1, -1),
+                     num_midpoints).view(num_per_sheet * (num_midpoints + 2), -1)
+     # Run the net -- note that we've already passed y through G.shared.
+     if G.fp16:
+         zs = zs.half()
+     with torch.no_grad():
+         if parallel:
+             out_ims = nn.parallel.data_parallel(G, (zs, ys)).data.cpu()
+         else:
+             out_ims = G(zs, ys).data.cpu()
+     interp_style = '' + ('Z' if not fix_z else '') + ('Y' if not fix_y else '')
+     image_filename = '%s/%s/%d/interp%s%d.jpg' % (samples_root, experiment_name,
+                                                   folder_number, interp_style,
+                                                   sheet_number)
+     torchvision.utils.save_image(out_ims, image_filename,
+                                  nrow=num_midpoints + 2, normalize=True)
+
+
+ # Convenience debugging function to print out gradnorms and shape from each layer
+ # May need to rewrite this so we can actually see which parameter is which
+ def print_grad_norms(net):
+     gradsums = [[float(torch.norm(param.grad).item()),
+                  float(torch.norm(param).item()), param.shape]
+                 for param in net.parameters()]
+     order = np.argsort([item[0] for item in gradsums])
+     print(['%3.3e,%3.3e, %s' % (gradsums[item_index][0],
+                                 gradsums[item_index][1],
+                                 str(gradsums[item_index][2]))
+            for item_index in order])
+
+
+ # Get singular values to log. This will use the state dict to find them
+ # and substitute underscores for dots.
+ def get_SVs(net, prefix):
+     d = net.state_dict()
+     return {('%s_%s' % (prefix, key)).replace('.', '_'):
+             float(d[key].item())
+             for key in d if 'sv' in key}
+
+
+ # Name an experiment based on its config
+ def name_from_config(config):
+     name = '_'.join([
+         item for item in [
+             'Big%s' % config['which_train_fn'],
+             config['dataset'],
+             config['model'] if config['model'] != 'BigGAN' else None,
+             'seed%d' % config['seed'],
+             'Gch%d' % config['G_ch'],
+             'Dch%d' % config['D_ch'],
+             'Gd%d' % config['G_depth'] if config['G_depth'] > 1 else None,
+             'Dd%d' % config['D_depth'] if config['D_depth'] > 1 else None,
+             'bs%d' % config['batch_size'],
+             'Gfp16' if config['G_fp16'] else None,
+             'Dfp16' if config['D_fp16'] else None,
+             'nDs%d' % config['num_D_steps'] if config['num_D_steps'] > 1 else None,
+             'nDa%d' % config['num_D_accumulations'] if config['num_D_accumulations'] > 1 else None,
+             'nGa%d' % config['num_G_accumulations'] if config['num_G_accumulations'] > 1 else None,
+             'Glr%2.1e' % config['G_lr'],
+             'Dlr%2.1e' % config['D_lr'],
+             'GB%3.3f' % config['G_B1'] if config['G_B1'] != 0.0 else None,
+             'GBB%3.3f' % config['G_B2'] if config['G_B2'] != 0.999 else None,
+             'DB%3.3f' % config['D_B1'] if config['D_B1'] != 0.0 else None,
+             'DBB%3.3f' % config['D_B2'] if config['D_B2'] != 0.999 else None,
+             'Gnl%s' % config['G_nl'],
+             'Dnl%s' % config['D_nl'],
+             'Ginit%s' % config['G_init'],
+             'Dinit%s' % config['D_init'],
+             'G%s' % config['G_param'] if config['G_param'] != 'SN' else None,
+             'D%s' % config['D_param'] if config['D_param'] != 'SN' else None,
+             'Gattn%s' % config['G_attn'] if config['G_attn'] != '0' else None,
+             'Dattn%s' % config['D_attn'] if config['D_attn'] != '0' else None,
+             'Gortho%2.1e' % config['G_ortho'] if config['G_ortho'] > 0.0 else None,
+             'Dortho%2.1e' % config['D_ortho'] if config['D_ortho'] > 0.0 else None,
+             config['norm_style'] if config['norm_style'] != 'bn' else None,
+             'cr' if config['cross_replica'] else None,
+             'Gshared' if config['G_shared'] else None,
+             'hier' if config['hier'] else None,
+             'ema' if config['ema'] else None,
+             config['name_suffix'] if config['name_suffix'] else None,
+         ]
+         if item is not None])
+     # dogball
+     if config['hashname']:
+         return hashname(name)
+     else:
+         return name
+
+
+ # A simple function to produce a unique experiment name from the animal hashes.
+ def hashname(name):
+     h = hash(name)
+     a = h % len(animal_hash.a)
+     h = h // len(animal_hash.a)
+     b = h % len(animal_hash.b)
+     h = h // len(animal_hash.b)  # divide out the b-list before indexing the c-list
+     c = h % len(animal_hash.c)
+     return animal_hash.a[a] + animal_hash.b[b] + animal_hash.c[c]
+
+
+ # Get free GPU memory; `indices` selects which devices to query via nvidia-smi -i
+ def query_gpu(indices):
+     os.system('nvidia-smi -i %s --query-gpu=memory.free --format=csv'
+               % ','.join(str(i) for i in indices))
+
+
+ # Convenience function to count the number of parameters in a module
+ def count_parameters(module):
+     print('Number of parameters: {}'.format(
+         sum([p.data.nelement() for p in module.parameters()])))
+
+
+ # Convenience function to sample an index; not actually a 1-hot
+ def sample_1hot(batch_size, num_classes, device='cuda'):
+     return torch.randint(low=0, high=num_classes, size=(batch_size,),
+                          device=device, dtype=torch.int64, requires_grad=False)
+
+
+ # A highly simplified convenience class for sampling from distributions.
+ # One could also use PyTorch's inbuilt distributions package.
+ # Note that this class requires initialization to proceed as
+ #     x = Distribution(torch.randn(size))
+ #     x.init_distribution(dist_type, **dist_kwargs)
+ #     x = x.to(device, dtype)
+ # This is partially based on https://discuss.pytorch.org/t/subclassing-torch-tensor/23754/2
+ class Distribution(torch.Tensor):
+     # Init the params of the distribution
+     def init_distribution(self, dist_type, **kwargs):
+         self.dist_type = dist_type
+         self.dist_kwargs = kwargs
+         if self.dist_type == 'normal':
+             self.mean, self.var = kwargs['mean'], kwargs['var']
+         elif self.dist_type == 'categorical':
+             self.num_categories = kwargs['num_categories']
+
+     def sample_(self):
+         if self.dist_type == 'normal':
+             self.normal_(self.mean, self.var)
+         elif self.dist_type == 'categorical':
+             self.random_(0, self.num_categories)
+
+     # Silly hack: overwrite the to() method to wrap the new object
+     # in a distribution as well
+     def to(self, *args, **kwargs):
+         new_obj = Distribution(self)
+         new_obj.init_distribution(self.dist_type, **self.dist_kwargs)
+         new_obj.data = super().to(*args, **kwargs)
+         return new_obj
+
+
+ # Convenience function to prepare a z and y vector
+ def prepare_z_y(G_batch_size, dim_z, nclasses, device='cuda',
+                 fp16=False, z_var=1.0):
+     z_ = Distribution(torch.randn(G_batch_size, dim_z, requires_grad=False))
+     z_.init_distribution('normal', mean=0, var=z_var)
+     z_ = z_.to(device, torch.float16 if fp16 else torch.float32)
+
+     if fp16:
+         z_ = z_.half()
+
+     y_ = Distribution(torch.zeros(G_batch_size, requires_grad=False))
+     y_.init_distribution('categorical', num_categories=nclasses)
+     y_ = y_.to(device, torch.int64)
+     return z_, y_
+
+
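+ # Example usage (sketch; sizes are placeholders): build the latent/label
+ # samplers once, then redraw them in place before each generator call.
+ #     z_, y_ = prepare_z_y(G_batch_size=16, dim_z=120, nclasses=1000)
+ #     z_.sample_(); y_.sample_()
+
+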
+ def initiate_standing_stats(net):
+     for module in net.modules():
+         if hasattr(module, 'accumulate_standing'):
+             module.reset_stats()
+             module.accumulate_standing = True
+
+
+ def accumulate_standing_stats(net, z, y, nclasses, num_accumulations=16):
+     initiate_standing_stats(net)
+     net.train()
+     for i in range(num_accumulations):
+         with torch.no_grad():
+             z.normal_()
+             y.random_(0, nclasses)
+             x = net(z, net.shared(y))  # No need to parallelize here unless using syncbn
+     # Set to eval mode
+     net.eval()
+
+
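+ # Example usage (sketch; `G`, `z_` and `y_` are placeholders, e.g. from
+ # prepare_z_y above): refresh batchnorm statistics before evaluation.
+ #     accumulate_standing_stats(G, z_, y_, nclasses=1000)
+
+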
+ # This version of Adam keeps an fp32 copy of the parameters and
+ # does all of the parameter updates in fp32, while still doing the
+ # forwards and backwards passes using fp16 (i.e. fp16 copies of the
+ # parameters and fp16 activations).
+ #
+ # Note that this calls .float().cuda() on the params.
+ import math
+ from torch.optim.optimizer import Optimizer
+
+
+ class Adam16(Optimizer):
+     def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0):
+         defaults = dict(lr=lr, betas=betas, eps=eps,
+                         weight_decay=weight_decay)
+         params = list(params)
+         super(Adam16, self).__init__(params, defaults)
+
+     # Safety modification to make sure we floatify our state
+     def load_state_dict(self, state_dict):
+         super(Adam16, self).load_state_dict(state_dict)
+         for group in self.param_groups:
+             for p in group['params']:
+                 self.state[p]['exp_avg'] = self.state[p]['exp_avg'].float()
+                 self.state[p]['exp_avg_sq'] = self.state[p]['exp_avg_sq'].float()
+                 self.state[p]['fp32_p'] = self.state[p]['fp32_p'].float()
+
+     def step(self, closure=None):
+         """Performs a single optimization step.
+         Arguments:
+             closure (callable, optional): A closure that reevaluates the model
+                 and returns the loss.
+         """
+         loss = None
+         if closure is not None:
+             loss = closure()
+
+         for group in self.param_groups:
+             for p in group['params']:
+                 if p.grad is None:
+                     continue
+
+                 grad = p.grad.data.float()
+                 state = self.state[p]
+
+                 # State initialization
+                 if len(state) == 0:
+                     state['step'] = 0
+                     # Exponential moving average of gradient values
+                     state['exp_avg'] = grad.new().resize_as_(grad).zero_()
+                     # Exponential moving average of squared gradient values
+                     state['exp_avg_sq'] = grad.new().resize_as_(grad).zero_()
+                     # Fp32 copy of the weights
+                     state['fp32_p'] = p.data.float()
+
+                 exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
+                 beta1, beta2 = group['betas']
+
+                 state['step'] += 1
+
+                 if group['weight_decay'] != 0:
+                     grad = grad.add(state['fp32_p'], alpha=group['weight_decay'])
+
+                 # Decay the first and second moment running average coefficient
+                 exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
+                 exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
+
+                 denom = exp_avg_sq.sqrt().add_(group['eps'])
+
+                 bias_correction1 = 1 - beta1 ** state['step']
+                 bias_correction2 = 1 - beta2 ** state['step']
+                 step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1
+
+                 # Update the fp32 master copy, then write back an fp16 view
+                 state['fp32_p'].addcdiv_(exp_avg, denom, value=-step_size)
+                 p.data = state['fp32_p'].half()
+
+         return loss
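+
+
+ # Example usage (sketch; `net`, `x` are placeholders):
+ #     net = net.half().cuda()
+ #     optim = Adam16(net.parameters(), lr=1e-4)
+ #     loss = net(x).mean(); loss.backward(); optim.step()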
src/models/cvae.py ADDED
@@ -0,0 +1,75 @@
+ from torch import cat
+ from torch.optim import Adam
+ from torch.nn import Sequential, ModuleList, \
+     Conv2d, Linear, \
+     LeakyReLU, Tanh, \
+     BatchNorm1d, BatchNorm2d, \
+     ConvTranspose2d, UpsamplingBilinear2d
+
+ from .neuralnetwork import NeuralNetwork
+
+
+ # parameters for cVAE
+ colors_dim = 3
+ labels_dim = 37
+ momentum = 0.99  # BatchNorm
+ negative_slope = 0.2  # LeakyReLU
+ optimizer = Adam
+ betas = (0.5, 0.999)
+
+ # hyperparameters
+ learning_rate = 2e-4
+ latent_dim = 128
+
+
+ def genUpsample(input_channels, output_channels, stride, pad):
+     return Sequential(
+         ConvTranspose2d(input_channels, output_channels, 4, stride, pad, bias=False),
+         BatchNorm2d(output_channels),
+         LeakyReLU(negative_slope=negative_slope))
+
+
+ def genUpsample2(input_channels, output_channels, kernel_size):
+     return Sequential(
+         Conv2d(input_channels, output_channels, kernel_size=kernel_size, stride=1, padding=(kernel_size - 1) // 2),
+         BatchNorm2d(output_channels),
+         LeakyReLU(negative_slope=negative_slope),
+         Conv2d(output_channels, output_channels, kernel_size=kernel_size, stride=1, padding=(kernel_size - 1) // 2),
+         BatchNorm2d(output_channels),
+         LeakyReLU(negative_slope=negative_slope),
+         UpsamplingBilinear2d(scale_factor=2))
+
+
+ class ConditionalDecoder(NeuralNetwork):
+     def __init__(self, ll_scaling=1.0, dim_z=latent_dim):
+         super(ConditionalDecoder, self).__init__()
+         self.dim_z = dim_z
+         ngf = 32
+         self.init = genUpsample(self.dim_z, ngf * 16, 1, 0)
+         self.embedding = Sequential(
+             Linear(labels_dim, self.dim_z),
+             BatchNorm1d(self.dim_z, momentum=momentum),
+             LeakyReLU(negative_slope=negative_slope),
+         )
+         self.dense_init = Sequential(
+             Linear(self.dim_z * 2, self.dim_z),
+             BatchNorm1d(self.dim_z, momentum=momentum),
+             LeakyReLU(negative_slope=negative_slope),
+         )
+         self.m_modules = ModuleList()  # to 4x4
+         self.c_modules = ModuleList()
+         for i in range(4):
+             self.m_modules.append(genUpsample2(ngf * 2**(4-i), ngf * 2**(3-i), 3))
+             self.c_modules.append(Sequential(Conv2d(ngf * 2**(3-i), colors_dim, 3, 1, 1, bias=False), Tanh()))
+         self.set_optimizer(optimizer, lr=learning_rate * ll_scaling, betas=betas)
+
+     def forward(self, latent, labels, step=3):
+         y = self.embedding(labels)
+         out = cat((latent, y), dim=1)
+         out = self.dense_init(out)
+         out = out.unsqueeze(2).unsqueeze(3)
+         out = self.init(out)
+         for i in range(step):
+             out = self.m_modules[i](out)
+         out = self.c_modules[step](self.m_modules[step](out))
+         return out
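+
+
+ # Example usage (sketch; assumes `import torch`): with the default step=3 the
+ # decoder maps a latent + one-hot label to a 64x64 RGB image.
+ #     dec = ConditionalDecoder()
+ #     z = torch.randn(8, latent_dim)
+ #     y = torch.zeros(8, labels_dim); y[:, 0] = 1.0
+ #     img = dec(z, y)  # shape: (8, 3, 64, 64)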
src/models/infoscc_gan.py ADDED
@@ -0,0 +1,272 @@
+ from typing import Optional
+ from functools import partial
+ import math
+
+ import torch
+ import torch.nn as nn
+
+
+ def get_activation(activation: str = "lrelu"):
+     actv_layers = {
+         "relu": nn.ReLU,
+         "lrelu": partial(nn.LeakyReLU, 0.2),
+     }
+     assert activation in actv_layers, f"activation [{activation}] not implemented"
+     return actv_layers[activation]
+
+
+ def get_normalization(normalization: str = "batch_norm"):
+     # Every entry accepts a `num_channels` keyword so callers can stay uniform
+     # (BatchNorm2d/InstanceNorm2d natively name this argument `num_features`).
+     norm_layers = {
+         "instance_norm": lambda num_channels: nn.InstanceNorm2d(num_channels),
+         "batch_norm": lambda num_channels: nn.BatchNorm2d(num_channels),
+         "group_norm": partial(nn.GroupNorm, num_groups=8),
+         "layer_norm": partial(nn.GroupNorm, num_groups=1),
+     }
+     assert normalization in norm_layers, f"normalization [{normalization}] not implemented"
+     return norm_layers[normalization]
+
+
+ class ConvLayer(nn.Sequential):
+     def __init__(
+         self,
+         in_channels: int,
+         out_channels: int,
+         kernel_size: int = 3,
+         stride: int = 1,
+         padding: Optional[int] = 1,
+         padding_mode: str = "zeros",
+         groups: int = 1,
+         bias: bool = True,
+         transposed: bool = False,
+         normalization: Optional[str] = None,
+         activation: Optional[str] = "lrelu",
+         pre_activate: bool = False,
+     ):
+         if transposed:
+             conv = partial(nn.ConvTranspose2d, output_padding=stride - 1)
+             padding_mode = "zeros"
+         else:
+             conv = nn.Conv2d
+         layers = [
+             conv(
+                 in_channels,
+                 out_channels,
+                 kernel_size=kernel_size,
+                 stride=stride,
+                 padding=padding,
+                 padding_mode=padding_mode,
+                 groups=groups,
+                 bias=bias,
+             )
+         ]
+
+         norm_actv = []
+         if normalization is not None:
+             norm_actv.append(
+                 get_normalization(normalization)(
+                     num_channels=in_channels if pre_activate else out_channels
+                 )
+             )
+         if activation is not None:
+             norm_actv.append(
+                 get_activation(activation)(inplace=True)
+             )
+
+         if pre_activate:
+             layers = norm_actv + layers
+         else:
+             layers = layers + norm_actv
+
+         super().__init__(*layers)
+
+
+ class SubspaceLayer(nn.Module):
+     def __init__(
+         self,
+         dim: int,
+         n_basis: int,
+     ):
+         super().__init__()
+
+         self.U = nn.Parameter(torch.empty(n_basis, dim))
+         nn.init.orthogonal_(self.U)
+         self.L = nn.Parameter(torch.FloatTensor([3 * i for i in range(n_basis, 0, -1)]))
+         self.mu = nn.Parameter(torch.zeros(dim))
+
+     def forward(self, z):
+         return (self.L * z) @ self.U + self.mu
+
+
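+ # In EigenGAN terms, the layer computes phi = (L * z) @ U + mu: `U` is a learned
+ # orthogonal basis, `L` holds per-direction scales, and `mu` is the subspace
+ # origin; z of shape (batch, n_basis) maps to an offset of shape (batch, dim).
+
+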
+ class EigenBlock(nn.Module):
+     def __init__(
+         self,
+         width: int,
+         height: int,
+         in_channels: int,
+         out_channels: int,
+         n_basis: int,
+     ):
+         super().__init__()
+
+         self.projection = SubspaceLayer(dim=width * height * in_channels, n_basis=n_basis)
+         self.subspace_conv1 = ConvLayer(
+             in_channels,
+             in_channels,
+             kernel_size=1,
+             stride=1,
+             padding=0,
+             transposed=True,
+             activation=None,
+             normalization=None,
+         )
+         self.subspace_conv2 = ConvLayer(
+             in_channels,
+             out_channels,
+             kernel_size=3,
+             stride=2,
+             padding=1,
+             transposed=True,
+             activation=None,
+             normalization=None,
+         )
+
+         self.feature_conv1 = ConvLayer(
+             in_channels,
+             out_channels,
+             kernel_size=3,
+             stride=2,
+             transposed=True,
+             pre_activate=True,
+         )
+         self.feature_conv2 = ConvLayer(
+             out_channels,
+             out_channels,
+             kernel_size=3,
+             stride=1,
+             transposed=True,
+             pre_activate=True,
+         )
+
+     def forward(self, z, h):
+         phi = self.projection(z).view(h.shape)
+         h = self.feature_conv1(h + self.subspace_conv1(phi))
+         h = self.feature_conv2(h + self.subspace_conv2(phi))
+         return h
+
+
+ class ConditionalGenerator(nn.Module):
+     """Conditional generator.
+
+     Generates images from a one-hot label plus noise sampled from N(0, 1),
+     with an explorable z-injection space at each resolution. Based on EigenGAN.
+     """
+
+     def __init__(self,
+                  size: int,
+                  y_size: int,
+                  z_size: int,
+                  out_channels: int = 3,
+                  n_basis: int = 6,
+                  noise_dim: int = 512,
+                  base_channels: int = 16,
+                  max_channels: int = 512,
+                  y_type: str = 'one_hot'):
+
+         if y_type not in ['one_hot', 'multi_label', 'mixed', 'real']:
+             raise ValueError('Unsupported `y_type`')
+
+         super(ConditionalGenerator, self).__init__()
+
+         assert (size & (size - 1) == 0) and size != 0, "img size should be a power of 2"
+
+         self.y_type = y_type
+         self.y_size = y_size
+         self.eps_size = z_size
+
+         self.noise_dim = noise_dim
+         self.n_basis = n_basis
+         self.n_blocks = int(math.log(size, 2)) - 2
+
+         def get_channels(i_block):
+             return min(max_channels, base_channels * (2 ** (self.n_blocks - i_block)))
+
+         self.y_fc = nn.Linear(self.y_size, self.y_size)
+         self.concat_fc = nn.Linear(self.y_size + self.eps_size, self.noise_dim)
+
+         self.fc = nn.Linear(self.noise_dim, 4 * 4 * get_channels(0))
+
+         self.blocks = nn.ModuleList()
+         for i in range(self.n_blocks):
+             self.blocks.append(
+                 EigenBlock(
+                     width=4 * (2 ** i),
+                     height=4 * (2 ** i),
+                     in_channels=get_channels(i),
+                     out_channels=get_channels(i + 1),
+                     n_basis=self.n_basis,
+                 )
+             )
+
+         self.out = nn.Sequential(
+             ConvLayer(base_channels, out_channels, kernel_size=7, stride=1, padding=3, pre_activate=True),
+             nn.Tanh(),
+         )
+
+     def forward(self,
+                 y: torch.Tensor,
+                 eps: Optional[torch.Tensor] = None,
+                 zs: Optional[torch.Tensor] = None,
+                 return_eps: bool = False):
+
+         bs = y.size(0)
+
+         if eps is None:
+             eps = self.sample_eps(bs)
+
+         if zs is None:
+             zs = self.sample_zs(bs)
+
+         y_out = self.y_fc(y)
+         concat = torch.cat((y_out, eps), dim=1)
+         concat = self.concat_fc(concat)
+
+         out = self.fc(concat).view(len(eps), -1, 4, 4)
+         for block, z in zip(self.blocks, zs.permute(1, 0, 2)):
+             out = block(z, out)
+         out = self.out(out)
+
+         if return_eps:
+             return out, concat
+
+         return out
+
+     def sample_zs(self, batch: int, truncation: float = 1.):
+         device = self.get_device()
+         zs = torch.randn(batch, self.n_blocks, self.n_basis, device=device)
+
+         if truncation < 1.:
+             zs = zs * truncation  # shrink samples toward the zero mean
+         return zs
+
+     def sample_eps(self, batch: int, truncation: float = 1.):
+         device = self.get_device()
+         eps = torch.randn(batch, self.eps_size, device=device)
+
+         if truncation < 1.:
+             eps = eps * truncation  # shrink samples toward the zero mean
+         return eps
+
+     def get_device(self):
+         return self.fc.weight.device
+
+     def orthogonal_regularizer(self):
+         reg = []
+         for layer in self.modules():
+             if isinstance(layer, SubspaceLayer):
+                 UUT = layer.U @ layer.U.t()
+                 reg.append(
+                     ((UUT - torch.eye(UUT.shape[0], device=UUT.device)) ** 2).mean()
+                 )
+         return sum(reg) / len(reg)
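+
+
+ # Example usage (sketch; sizes are placeholders):
+ #     G = ConditionalGenerator(size=128, y_size=37, z_size=512)
+ #     y = torch.eye(37)[torch.randint(37, (4,))]  # one-hot labels
+ #     img = G(y)  # shape: (4, 3, 128, 128)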
src/models/neuralnetwork.py ADDED
@@ -0,0 +1,46 @@
+ import torch
+
+
+ class NeuralNetwork(torch.nn.Module):
+     """ base class with convenient procedures used by all NNs """
+     def __init__(self):
+         super(NeuralNetwork, self).__init__()
+         self.parameter_file = f"parameter_state_dict_{self._get_name()}.pth"
+         # self.cuda()  ## all NNs shall run on cuda; doesn't seem to work here
+
+     def save(self) -> None:
+         """ save learned parameters to parameter_file """
+         torch.save(self.state_dict(), self.parameter_file)
+
+     def load(self) -> None:
+         """ load learned parameters from parameter_file """
+         self.load_state_dict(torch.load(self.parameter_file))
+         self.eval()
+
+     @staticmethod
+     def same_padding(kernel_size=1) -> int:
+         """ return padding required to mimic 'same' padding in tensorflow """
+         return (kernel_size - 1) // 2
+
+     def set_optimizer(self, optimizer, **kwargs) -> None:
+         self.optimizer = optimizer(self.parameters(), **kwargs)
+
+     def get_total_number_parameters(self) -> int:
+         """ return total number of parameters """
+         return sum(p.numel() for p in self.parameters())
+
+     def zero_grad(self):
+         """ faster implementation of zero_grad """
+         for p in self.parameters():
+             p.grad = None
+         # self.zero_grad(set_to_none=True)
+
+
+ def update_networks_on_loss(loss: torch.Tensor, *networks) -> None:
+     if not loss:
+         return
+     for network in networks:
+         network.zero_grad()
+     loss.backward()
+     for network in networks:
+         network.optimizer.step()
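+
+
+ # Example usage (sketch; `encoder`, `decoder` and `loss` are placeholders):
+ #     update_networks_on_loss(loss, encoder, decoder)
+ # This zeroes each network's grads, backpropagates once, then steps each
+ # network's optimizer (set beforehand via set_optimizer).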
src/models/parameter.py ADDED
@@ -0,0 +1,53 @@
+ """ hardcoded parameter
2
+
3
+ these can be changed in a jupyter notebook during runtime via
4
+
5
+ >>> import parameter
6
+ >>> parameter.parameter = new_value
7
+
8
+ """
9
+
10
+ from torch.optim import Adam
11
+
12
+ ###############
13
+ ## hardcoded ##
14
+ ###############
15
+
16
+
17
+ # Input
18
+ image_dim = 64
19
+ colors_dim = 3
20
+ labels_dim = 37 #3
21
+ input_size = (colors_dim,image_dim,image_dim)
22
+
23
+
24
+ #############
25
+ ## mutable ##
26
+ #############
27
+
28
+ class Parameter:
29
+ """ container for hyperparameters"""
30
+
31
+ def __init__(self):
32
+ # Encoder/Decoder
33
+ self.latent_dim = 8
34
+ self.decoder_dim = self.latent_dim # differs from latent_dim if PCA applied before decoder
35
+
36
+ # General
37
+ self.learning_rate = 0.0002
38
+ self.betas = (0.5,0.999) ## 0.999 is default beta2 in tensorflow
39
+ self.optimizer = Adam
40
+ self.negative_slope = 0.2 # for LeakyReLU
41
+ self.momentum = 0.99 # for BatchNorm
42
+
43
+ # Loss weights
44
+ self.alpha = 1 # switch VAE (1) / AE (0)
45
+ self.beta = 1 # weight for KL-loss
46
+ self.gamma = 1024 # weight for learned-metric-loss (https://arxiv.org/pdf/1512.09300.pdf)
47
+ self.delta = 1 # weight for class-loss
48
+ self.zeta = 0.5 # weight for MSE-loss
49
+
50
+ def return_parameter_dict(self):
51
+ return(self.__dict__)
52
+
53
+ parameter = Parameter()
src/utils/__init__.py ADDED
@@ -0,0 +1,2 @@
+ from .utils import download_file
+ from .utils import sample_labels
src/utils/utils.py ADDED
@@ -0,0 +1,14 @@
+ import numpy as np
+ import gdown
+
+ import torch
+
+
+ def download_file(file_id: str, output_path: str):
+     gdown.download(f'https://drive.google.com/uc?id={file_id}', output_path)
+
+
+ def sample_labels(labels: torch.Tensor, n: int) -> torch.Tensor:
+     high = labels.shape[0]
+     idx = np.random.randint(0, high, size=n)
+     return labels[idx]
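+
+
+ # Example usage (sketch; the file id and path are placeholders):
+ #     download_file('FILE_ID', 'labels.pt')
+ #     labels = torch.load('labels.pt')
+ #     batch = sample_labels(labels, n=16)  # 16 rows drawn with replacement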