showgan commited on
Commit
09b13b3
1 Parent(s): 05c5cfd

Training in progress, step 1000

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +4 -0
  2. .gitignore +166 -0
  3. README.md +16 -0
  4. added_tokens.json +1609 -0
  5. computer-vision-study-group/Notebooks/HuggingFace_vision_ecosystem_overview_(June_2022).ipynb +0 -0
  6. computer-vision-study-group/README.md +15 -0
  7. computer-vision-study-group/Sessions/Blip2.md +25 -0
  8. computer-vision-study-group/Sessions/Fiber.md +24 -0
  9. computer-vision-study-group/Sessions/FlexiViT.md +23 -0
  10. computer-vision-study-group/Sessions/HFVisionEcosystem.md +10 -0
  11. computer-vision-study-group/Sessions/HowDoVisionTransformersWork.md +27 -0
  12. computer-vision-study-group/Sessions/MaskedAutoEncoders.md +24 -0
  13. computer-vision-study-group/Sessions/NeuralRadianceFields.md +19 -0
  14. computer-vision-study-group/Sessions/PolarizedSelfAttention.md +14 -0
  15. computer-vision-study-group/Sessions/SwinTransformer.md +25 -0
  16. config.json +52 -0
  17. gradio-blocks/README.md +123 -0
  18. huggan/README.md +487 -0
  19. huggan/__init__.py +3 -0
  20. huggan/assets/cyclegan.png +3 -0
  21. huggan/assets/dcgan_mnist.png +0 -0
  22. huggan/assets/example_model.png +0 -0
  23. huggan/assets/example_space.png +0 -0
  24. huggan/assets/huggan_banner.png +0 -0
  25. huggan/assets/lightweight_gan_wandb.png +3 -0
  26. huggan/assets/metfaces.png +0 -0
  27. huggan/assets/pix2pix_maps.png +3 -0
  28. huggan/assets/wandb.png +3 -0
  29. huggan/model_card_template.md +50 -0
  30. huggan/pytorch/README.md +19 -0
  31. huggan/pytorch/__init__.py +0 -0
  32. huggan/pytorch/cyclegan/README.md +81 -0
  33. huggan/pytorch/cyclegan/__init__.py +0 -0
  34. huggan/pytorch/cyclegan/modeling_cyclegan.py +108 -0
  35. huggan/pytorch/cyclegan/train.py +354 -0
  36. huggan/pytorch/cyclegan/utils.py +44 -0
  37. huggan/pytorch/dcgan/README.md +155 -0
  38. huggan/pytorch/dcgan/__init__.py +0 -0
  39. huggan/pytorch/dcgan/modeling_dcgan.py +80 -0
  40. huggan/pytorch/dcgan/train.py +346 -0
  41. huggan/pytorch/huggan_mixin.py +131 -0
  42. huggan/pytorch/lightweight_gan/README.md +89 -0
  43. huggan/pytorch/lightweight_gan/__init__.py +0 -0
  44. huggan/pytorch/lightweight_gan/cli.py +178 -0
  45. huggan/pytorch/lightweight_gan/diff_augment.py +102 -0
  46. huggan/pytorch/lightweight_gan/lightweight_gan.py +1598 -0
  47. huggan/pytorch/metrics/README.md +39 -0
  48. huggan/pytorch/metrics/__init__.py +0 -0
  49. huggan/pytorch/metrics/fid_score.py +80 -0
  50. huggan/pytorch/metrics/inception.py +328 -0
.gitattributes CHANGED
@@ -33,3 +33,7 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ huggan/assets/cyclegan.png filter=lfs diff=lfs merge=lfs -text
37
+ huggan/assets/lightweight_gan_wandb.png filter=lfs diff=lfs merge=lfs -text
38
+ huggan/assets/pix2pix_maps.png filter=lfs diff=lfs merge=lfs -text
39
+ huggan/assets/wandb.png filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,166 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Initially taken from Github's Python gitignore file
2
+
3
+ # Byte-compiled / optimized / DLL files
4
+ __pycache__/
5
+ *.py[cod]
6
+ *$py.class
7
+
8
+ # C extensions
9
+ *.so
10
+
11
+ # tests and logs
12
+ tests/fixtures/cached_*_text.txt
13
+ logs/
14
+ lightning_logs/
15
+ lang_code_data/
16
+
17
+ # Distribution / packaging
18
+ .Python
19
+ build/
20
+ develop-eggs/
21
+ dist/
22
+ downloads/
23
+ eggs/
24
+ .eggs/
25
+ lib/
26
+ lib64/
27
+ parts/
28
+ sdist/
29
+ var/
30
+ wheels/
31
+ *.egg-info/
32
+ .installed.cfg
33
+ *.egg
34
+ MANIFEST
35
+
36
+ # PyInstaller
37
+ # Usually these files are written by a python script from a template
38
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
39
+ *.manifest
40
+ *.spec
41
+
42
+ # Installer logs
43
+ pip-log.txt
44
+ pip-delete-this-directory.txt
45
+
46
+ # Unit test / coverage reports
47
+ htmlcov/
48
+ .tox/
49
+ .nox/
50
+ .coverage
51
+ .coverage.*
52
+ .cache
53
+ nosetests.xml
54
+ coverage.xml
55
+ *.cover
56
+ .hypothesis/
57
+ .pytest_cache/
58
+
59
+ # Translations
60
+ *.mo
61
+ *.pot
62
+
63
+ # Django stuff:
64
+ *.log
65
+ local_settings.py
66
+ db.sqlite3
67
+
68
+ # Flask stuff:
69
+ instance/
70
+ .webassets-cache
71
+
72
+ # Scrapy stuff:
73
+ .scrapy
74
+
75
+ # Sphinx documentation
76
+ docs/_build/
77
+
78
+ # PyBuilder
79
+ target/
80
+
81
+ # Jupyter Notebook
82
+ .ipynb_checkpoints
83
+
84
+ # IPython
85
+ profile_default/
86
+ ipython_config.py
87
+
88
+ # pyenv
89
+ .python-version
90
+
91
+ # celery beat schedule file
92
+ celerybeat-schedule
93
+
94
+ # SageMath parsed files
95
+ *.sage.py
96
+
97
+ # Environments
98
+ .env
99
+ .venv
100
+ env/
101
+ venv/
102
+ ENV/
103
+ env.bak/
104
+ venv.bak/
105
+
106
+ # Spyder project settings
107
+ .spyderproject
108
+ .spyproject
109
+
110
+ # Rope project settings
111
+ .ropeproject
112
+
113
+ # mkdocs documentation
114
+ /site
115
+
116
+ # mypy
117
+ .mypy_cache/
118
+ .dmypy.json
119
+ dmypy.json
120
+
121
+ # Pyre type checker
122
+ .pyre/
123
+
124
+ # vscode
125
+ .vs
126
+ .vscode
127
+
128
+ # Pycharm
129
+ .idea
130
+
131
+ # TF code
132
+ tensorflow_code
133
+
134
+ # Models
135
+ proc_data
136
+
137
+ # examples
138
+ runs
139
+ /runs_old
140
+ /wandb
141
+ /examples/runs
142
+ /examples/**/*.args
143
+ /examples/rag/sweep
144
+
145
+ # data
146
+ /data
147
+ serialization_dir
148
+
149
+ # emacs
150
+ *.*~
151
+ debug.env
152
+
153
+ # vim
154
+ .*.swp
155
+
156
+ #ctags
157
+ tags
158
+
159
+ # pre-commit
160
+ .pre-commit*
161
+
162
+ # .lock
163
+ *.lock
164
+
165
+ # DS_Store (MacOS)
166
+ .DS_Store
README.md ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Community Events @ 🤗
2
+
3
+ A central repository for all community events organized by 🤗 HuggingFace. Come one, come all!
4
+ We're constantly finding ways to democratise the use of ML across modalities and languages. This repo contains information about all past, present and upcoming events.
5
+
6
+ ## Hugging Events
7
+
8
+ | **Event Name** | **Dates** | **Status** |
9
+ |-------------------------------------------------------------------------|-----------------|--------------------------------------------------------------------------------------------------------------|
10
+ | [Open Source AI Game Jam 🎮 (First Edition)](/open-source-ai-game-jam) | July 7th - 9th, 2023 | Finished |
11
+ | [Whisper Fine Tuning Event](/whisper-fine-tuning-event) | Dec 5th - 19th, 2022 | Finished |
12
+ | [Computer Vision Study Group](/computer-vision-study-group) | Ongoing | Monthly |
13
+ | [ML for Audio Study Group](https://github.com/Vaibhavs10/ml-with-audio) | Ongoing | Monthly |
14
+ | [Gradio Blocks](/gradio-blocks) | May 16th - 31st, 2022 | Finished |
15
+ | [HugGAN](/huggan) | Apr 4th - 17th, 2022 | Finished |
16
+ | [Keras Sprint](keras-sprint) | June, 2022 | Finished |
added_tokens.json ADDED
@@ -0,0 +1,1609 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "<|0.00|>": 50364,
3
+ "<|0.02|>": 50365,
4
+ "<|0.04|>": 50366,
5
+ "<|0.06|>": 50367,
6
+ "<|0.08|>": 50368,
7
+ "<|0.10|>": 50369,
8
+ "<|0.12|>": 50370,
9
+ "<|0.14|>": 50371,
10
+ "<|0.16|>": 50372,
11
+ "<|0.18|>": 50373,
12
+ "<|0.20|>": 50374,
13
+ "<|0.22|>": 50375,
14
+ "<|0.24|>": 50376,
15
+ "<|0.26|>": 50377,
16
+ "<|0.28|>": 50378,
17
+ "<|0.30|>": 50379,
18
+ "<|0.32|>": 50380,
19
+ "<|0.34|>": 50381,
20
+ "<|0.36|>": 50382,
21
+ "<|0.38|>": 50383,
22
+ "<|0.40|>": 50384,
23
+ "<|0.42|>": 50385,
24
+ "<|0.44|>": 50386,
25
+ "<|0.46|>": 50387,
26
+ "<|0.48|>": 50388,
27
+ "<|0.50|>": 50389,
28
+ "<|0.52|>": 50390,
29
+ "<|0.54|>": 50391,
30
+ "<|0.56|>": 50392,
31
+ "<|0.58|>": 50393,
32
+ "<|0.60|>": 50394,
33
+ "<|0.62|>": 50395,
34
+ "<|0.64|>": 50396,
35
+ "<|0.66|>": 50397,
36
+ "<|0.68|>": 50398,
37
+ "<|0.70|>": 50399,
38
+ "<|0.72|>": 50400,
39
+ "<|0.74|>": 50401,
40
+ "<|0.76|>": 50402,
41
+ "<|0.78|>": 50403,
42
+ "<|0.80|>": 50404,
43
+ "<|0.82|>": 50405,
44
+ "<|0.84|>": 50406,
45
+ "<|0.86|>": 50407,
46
+ "<|0.88|>": 50408,
47
+ "<|0.90|>": 50409,
48
+ "<|0.92|>": 50410,
49
+ "<|0.94|>": 50411,
50
+ "<|0.96|>": 50412,
51
+ "<|0.98|>": 50413,
52
+ "<|1.00|>": 50414,
53
+ "<|1.02|>": 50415,
54
+ "<|1.04|>": 50416,
55
+ "<|1.06|>": 50417,
56
+ "<|1.08|>": 50418,
57
+ "<|1.10|>": 50419,
58
+ "<|1.12|>": 50420,
59
+ "<|1.14|>": 50421,
60
+ "<|1.16|>": 50422,
61
+ "<|1.18|>": 50423,
62
+ "<|1.20|>": 50424,
63
+ "<|1.22|>": 50425,
64
+ "<|1.24|>": 50426,
65
+ "<|1.26|>": 50427,
66
+ "<|1.28|>": 50428,
67
+ "<|1.30|>": 50429,
68
+ "<|1.32|>": 50430,
69
+ "<|1.34|>": 50431,
70
+ "<|1.36|>": 50432,
71
+ "<|1.38|>": 50433,
72
+ "<|1.40|>": 50434,
73
+ "<|1.42|>": 50435,
74
+ "<|1.44|>": 50436,
75
+ "<|1.46|>": 50437,
76
+ "<|1.48|>": 50438,
77
+ "<|1.50|>": 50439,
78
+ "<|1.52|>": 50440,
79
+ "<|1.54|>": 50441,
80
+ "<|1.56|>": 50442,
81
+ "<|1.58|>": 50443,
82
+ "<|1.60|>": 50444,
83
+ "<|1.62|>": 50445,
84
+ "<|1.64|>": 50446,
85
+ "<|1.66|>": 50447,
86
+ "<|1.68|>": 50448,
87
+ "<|1.70|>": 50449,
88
+ "<|1.72|>": 50450,
89
+ "<|1.74|>": 50451,
90
+ "<|1.76|>": 50452,
91
+ "<|1.78|>": 50453,
92
+ "<|1.80|>": 50454,
93
+ "<|1.82|>": 50455,
94
+ "<|1.84|>": 50456,
95
+ "<|1.86|>": 50457,
96
+ "<|1.88|>": 50458,
97
+ "<|1.90|>": 50459,
98
+ "<|1.92|>": 50460,
99
+ "<|1.94|>": 50461,
100
+ "<|1.96|>": 50462,
101
+ "<|1.98|>": 50463,
102
+ "<|10.00|>": 50864,
103
+ "<|10.02|>": 50865,
104
+ "<|10.04|>": 50866,
105
+ "<|10.06|>": 50867,
106
+ "<|10.08|>": 50868,
107
+ "<|10.10|>": 50869,
108
+ "<|10.12|>": 50870,
109
+ "<|10.14|>": 50871,
110
+ "<|10.16|>": 50872,
111
+ "<|10.18|>": 50873,
112
+ "<|10.20|>": 50874,
113
+ "<|10.22|>": 50875,
114
+ "<|10.24|>": 50876,
115
+ "<|10.26|>": 50877,
116
+ "<|10.28|>": 50878,
117
+ "<|10.30|>": 50879,
118
+ "<|10.32|>": 50880,
119
+ "<|10.34|>": 50881,
120
+ "<|10.36|>": 50882,
121
+ "<|10.38|>": 50883,
122
+ "<|10.40|>": 50884,
123
+ "<|10.42|>": 50885,
124
+ "<|10.44|>": 50886,
125
+ "<|10.46|>": 50887,
126
+ "<|10.48|>": 50888,
127
+ "<|10.50|>": 50889,
128
+ "<|10.52|>": 50890,
129
+ "<|10.54|>": 50891,
130
+ "<|10.56|>": 50892,
131
+ "<|10.58|>": 50893,
132
+ "<|10.60|>": 50894,
133
+ "<|10.62|>": 50895,
134
+ "<|10.64|>": 50896,
135
+ "<|10.66|>": 50897,
136
+ "<|10.68|>": 50898,
137
+ "<|10.70|>": 50899,
138
+ "<|10.72|>": 50900,
139
+ "<|10.74|>": 50901,
140
+ "<|10.76|>": 50902,
141
+ "<|10.78|>": 50903,
142
+ "<|10.80|>": 50904,
143
+ "<|10.82|>": 50905,
144
+ "<|10.84|>": 50906,
145
+ "<|10.86|>": 50907,
146
+ "<|10.88|>": 50908,
147
+ "<|10.90|>": 50909,
148
+ "<|10.92|>": 50910,
149
+ "<|10.94|>": 50911,
150
+ "<|10.96|>": 50912,
151
+ "<|10.98|>": 50913,
152
+ "<|11.00|>": 50914,
153
+ "<|11.02|>": 50915,
154
+ "<|11.04|>": 50916,
155
+ "<|11.06|>": 50917,
156
+ "<|11.08|>": 50918,
157
+ "<|11.10|>": 50919,
158
+ "<|11.12|>": 50920,
159
+ "<|11.14|>": 50921,
160
+ "<|11.16|>": 50922,
161
+ "<|11.18|>": 50923,
162
+ "<|11.20|>": 50924,
163
+ "<|11.22|>": 50925,
164
+ "<|11.24|>": 50926,
165
+ "<|11.26|>": 50927,
166
+ "<|11.28|>": 50928,
167
+ "<|11.30|>": 50929,
168
+ "<|11.32|>": 50930,
169
+ "<|11.34|>": 50931,
170
+ "<|11.36|>": 50932,
171
+ "<|11.38|>": 50933,
172
+ "<|11.40|>": 50934,
173
+ "<|11.42|>": 50935,
174
+ "<|11.44|>": 50936,
175
+ "<|11.46|>": 50937,
176
+ "<|11.48|>": 50938,
177
+ "<|11.50|>": 50939,
178
+ "<|11.52|>": 50940,
179
+ "<|11.54|>": 50941,
180
+ "<|11.56|>": 50942,
181
+ "<|11.58|>": 50943,
182
+ "<|11.60|>": 50944,
183
+ "<|11.62|>": 50945,
184
+ "<|11.64|>": 50946,
185
+ "<|11.66|>": 50947,
186
+ "<|11.68|>": 50948,
187
+ "<|11.70|>": 50949,
188
+ "<|11.72|>": 50950,
189
+ "<|11.74|>": 50951,
190
+ "<|11.76|>": 50952,
191
+ "<|11.78|>": 50953,
192
+ "<|11.80|>": 50954,
193
+ "<|11.82|>": 50955,
194
+ "<|11.84|>": 50956,
195
+ "<|11.86|>": 50957,
196
+ "<|11.88|>": 50958,
197
+ "<|11.90|>": 50959,
198
+ "<|11.92|>": 50960,
199
+ "<|11.94|>": 50961,
200
+ "<|11.96|>": 50962,
201
+ "<|11.98|>": 50963,
202
+ "<|12.00|>": 50964,
203
+ "<|12.02|>": 50965,
204
+ "<|12.04|>": 50966,
205
+ "<|12.06|>": 50967,
206
+ "<|12.08|>": 50968,
207
+ "<|12.10|>": 50969,
208
+ "<|12.12|>": 50970,
209
+ "<|12.14|>": 50971,
210
+ "<|12.16|>": 50972,
211
+ "<|12.18|>": 50973,
212
+ "<|12.20|>": 50974,
213
+ "<|12.22|>": 50975,
214
+ "<|12.24|>": 50976,
215
+ "<|12.26|>": 50977,
216
+ "<|12.28|>": 50978,
217
+ "<|12.30|>": 50979,
218
+ "<|12.32|>": 50980,
219
+ "<|12.34|>": 50981,
220
+ "<|12.36|>": 50982,
221
+ "<|12.38|>": 50983,
222
+ "<|12.40|>": 50984,
223
+ "<|12.42|>": 50985,
224
+ "<|12.44|>": 50986,
225
+ "<|12.46|>": 50987,
226
+ "<|12.48|>": 50988,
227
+ "<|12.50|>": 50989,
228
+ "<|12.52|>": 50990,
229
+ "<|12.54|>": 50991,
230
+ "<|12.56|>": 50992,
231
+ "<|12.58|>": 50993,
232
+ "<|12.60|>": 50994,
233
+ "<|12.62|>": 50995,
234
+ "<|12.64|>": 50996,
235
+ "<|12.66|>": 50997,
236
+ "<|12.68|>": 50998,
237
+ "<|12.70|>": 50999,
238
+ "<|12.72|>": 51000,
239
+ "<|12.74|>": 51001,
240
+ "<|12.76|>": 51002,
241
+ "<|12.78|>": 51003,
242
+ "<|12.80|>": 51004,
243
+ "<|12.82|>": 51005,
244
+ "<|12.84|>": 51006,
245
+ "<|12.86|>": 51007,
246
+ "<|12.88|>": 51008,
247
+ "<|12.90|>": 51009,
248
+ "<|12.92|>": 51010,
249
+ "<|12.94|>": 51011,
250
+ "<|12.96|>": 51012,
251
+ "<|12.98|>": 51013,
252
+ "<|13.00|>": 51014,
253
+ "<|13.02|>": 51015,
254
+ "<|13.04|>": 51016,
255
+ "<|13.06|>": 51017,
256
+ "<|13.08|>": 51018,
257
+ "<|13.10|>": 51019,
258
+ "<|13.12|>": 51020,
259
+ "<|13.14|>": 51021,
260
+ "<|13.16|>": 51022,
261
+ "<|13.18|>": 51023,
262
+ "<|13.20|>": 51024,
263
+ "<|13.22|>": 51025,
264
+ "<|13.24|>": 51026,
265
+ "<|13.26|>": 51027,
266
+ "<|13.28|>": 51028,
267
+ "<|13.30|>": 51029,
268
+ "<|13.32|>": 51030,
269
+ "<|13.34|>": 51031,
270
+ "<|13.36|>": 51032,
271
+ "<|13.38|>": 51033,
272
+ "<|13.40|>": 51034,
273
+ "<|13.42|>": 51035,
274
+ "<|13.44|>": 51036,
275
+ "<|13.46|>": 51037,
276
+ "<|13.48|>": 51038,
277
+ "<|13.50|>": 51039,
278
+ "<|13.52|>": 51040,
279
+ "<|13.54|>": 51041,
280
+ "<|13.56|>": 51042,
281
+ "<|13.58|>": 51043,
282
+ "<|13.60|>": 51044,
283
+ "<|13.62|>": 51045,
284
+ "<|13.64|>": 51046,
285
+ "<|13.66|>": 51047,
286
+ "<|13.68|>": 51048,
287
+ "<|13.70|>": 51049,
288
+ "<|13.72|>": 51050,
289
+ "<|13.74|>": 51051,
290
+ "<|13.76|>": 51052,
291
+ "<|13.78|>": 51053,
292
+ "<|13.80|>": 51054,
293
+ "<|13.82|>": 51055,
294
+ "<|13.84|>": 51056,
295
+ "<|13.86|>": 51057,
296
+ "<|13.88|>": 51058,
297
+ "<|13.90|>": 51059,
298
+ "<|13.92|>": 51060,
299
+ "<|13.94|>": 51061,
300
+ "<|13.96|>": 51062,
301
+ "<|13.98|>": 51063,
302
+ "<|14.00|>": 51064,
303
+ "<|14.02|>": 51065,
304
+ "<|14.04|>": 51066,
305
+ "<|14.06|>": 51067,
306
+ "<|14.08|>": 51068,
307
+ "<|14.10|>": 51069,
308
+ "<|14.12|>": 51070,
309
+ "<|14.14|>": 51071,
310
+ "<|14.16|>": 51072,
311
+ "<|14.18|>": 51073,
312
+ "<|14.20|>": 51074,
313
+ "<|14.22|>": 51075,
314
+ "<|14.24|>": 51076,
315
+ "<|14.26|>": 51077,
316
+ "<|14.28|>": 51078,
317
+ "<|14.30|>": 51079,
318
+ "<|14.32|>": 51080,
319
+ "<|14.34|>": 51081,
320
+ "<|14.36|>": 51082,
321
+ "<|14.38|>": 51083,
322
+ "<|14.40|>": 51084,
323
+ "<|14.42|>": 51085,
324
+ "<|14.44|>": 51086,
325
+ "<|14.46|>": 51087,
326
+ "<|14.48|>": 51088,
327
+ "<|14.50|>": 51089,
328
+ "<|14.52|>": 51090,
329
+ "<|14.54|>": 51091,
330
+ "<|14.56|>": 51092,
331
+ "<|14.58|>": 51093,
332
+ "<|14.60|>": 51094,
333
+ "<|14.62|>": 51095,
334
+ "<|14.64|>": 51096,
335
+ "<|14.66|>": 51097,
336
+ "<|14.68|>": 51098,
337
+ "<|14.70|>": 51099,
338
+ "<|14.72|>": 51100,
339
+ "<|14.74|>": 51101,
340
+ "<|14.76|>": 51102,
341
+ "<|14.78|>": 51103,
342
+ "<|14.80|>": 51104,
343
+ "<|14.82|>": 51105,
344
+ "<|14.84|>": 51106,
345
+ "<|14.86|>": 51107,
346
+ "<|14.88|>": 51108,
347
+ "<|14.90|>": 51109,
348
+ "<|14.92|>": 51110,
349
+ "<|14.94|>": 51111,
350
+ "<|14.96|>": 51112,
351
+ "<|14.98|>": 51113,
352
+ "<|15.00|>": 51114,
353
+ "<|15.02|>": 51115,
354
+ "<|15.04|>": 51116,
355
+ "<|15.06|>": 51117,
356
+ "<|15.08|>": 51118,
357
+ "<|15.10|>": 51119,
358
+ "<|15.12|>": 51120,
359
+ "<|15.14|>": 51121,
360
+ "<|15.16|>": 51122,
361
+ "<|15.18|>": 51123,
362
+ "<|15.20|>": 51124,
363
+ "<|15.22|>": 51125,
364
+ "<|15.24|>": 51126,
365
+ "<|15.26|>": 51127,
366
+ "<|15.28|>": 51128,
367
+ "<|15.30|>": 51129,
368
+ "<|15.32|>": 51130,
369
+ "<|15.34|>": 51131,
370
+ "<|15.36|>": 51132,
371
+ "<|15.38|>": 51133,
372
+ "<|15.40|>": 51134,
373
+ "<|15.42|>": 51135,
374
+ "<|15.44|>": 51136,
375
+ "<|15.46|>": 51137,
376
+ "<|15.48|>": 51138,
377
+ "<|15.50|>": 51139,
378
+ "<|15.52|>": 51140,
379
+ "<|15.54|>": 51141,
380
+ "<|15.56|>": 51142,
381
+ "<|15.58|>": 51143,
382
+ "<|15.60|>": 51144,
383
+ "<|15.62|>": 51145,
384
+ "<|15.64|>": 51146,
385
+ "<|15.66|>": 51147,
386
+ "<|15.68|>": 51148,
387
+ "<|15.70|>": 51149,
388
+ "<|15.72|>": 51150,
389
+ "<|15.74|>": 51151,
390
+ "<|15.76|>": 51152,
391
+ "<|15.78|>": 51153,
392
+ "<|15.80|>": 51154,
393
+ "<|15.82|>": 51155,
394
+ "<|15.84|>": 51156,
395
+ "<|15.86|>": 51157,
396
+ "<|15.88|>": 51158,
397
+ "<|15.90|>": 51159,
398
+ "<|15.92|>": 51160,
399
+ "<|15.94|>": 51161,
400
+ "<|15.96|>": 51162,
401
+ "<|15.98|>": 51163,
402
+ "<|16.00|>": 51164,
403
+ "<|16.02|>": 51165,
404
+ "<|16.04|>": 51166,
405
+ "<|16.06|>": 51167,
406
+ "<|16.08|>": 51168,
407
+ "<|16.10|>": 51169,
408
+ "<|16.12|>": 51170,
409
+ "<|16.14|>": 51171,
410
+ "<|16.16|>": 51172,
411
+ "<|16.18|>": 51173,
412
+ "<|16.20|>": 51174,
413
+ "<|16.22|>": 51175,
414
+ "<|16.24|>": 51176,
415
+ "<|16.26|>": 51177,
416
+ "<|16.28|>": 51178,
417
+ "<|16.30|>": 51179,
418
+ "<|16.32|>": 51180,
419
+ "<|16.34|>": 51181,
420
+ "<|16.36|>": 51182,
421
+ "<|16.38|>": 51183,
422
+ "<|16.40|>": 51184,
423
+ "<|16.42|>": 51185,
424
+ "<|16.44|>": 51186,
425
+ "<|16.46|>": 51187,
426
+ "<|16.48|>": 51188,
427
+ "<|16.50|>": 51189,
428
+ "<|16.52|>": 51190,
429
+ "<|16.54|>": 51191,
430
+ "<|16.56|>": 51192,
431
+ "<|16.58|>": 51193,
432
+ "<|16.60|>": 51194,
433
+ "<|16.62|>": 51195,
434
+ "<|16.64|>": 51196,
435
+ "<|16.66|>": 51197,
436
+ "<|16.68|>": 51198,
437
+ "<|16.70|>": 51199,
438
+ "<|16.72|>": 51200,
439
+ "<|16.74|>": 51201,
440
+ "<|16.76|>": 51202,
441
+ "<|16.78|>": 51203,
442
+ "<|16.80|>": 51204,
443
+ "<|16.82|>": 51205,
444
+ "<|16.84|>": 51206,
445
+ "<|16.86|>": 51207,
446
+ "<|16.88|>": 51208,
447
+ "<|16.90|>": 51209,
448
+ "<|16.92|>": 51210,
449
+ "<|16.94|>": 51211,
450
+ "<|16.96|>": 51212,
451
+ "<|16.98|>": 51213,
452
+ "<|17.00|>": 51214,
453
+ "<|17.02|>": 51215,
454
+ "<|17.04|>": 51216,
455
+ "<|17.06|>": 51217,
456
+ "<|17.08|>": 51218,
457
+ "<|17.10|>": 51219,
458
+ "<|17.12|>": 51220,
459
+ "<|17.14|>": 51221,
460
+ "<|17.16|>": 51222,
461
+ "<|17.18|>": 51223,
462
+ "<|17.20|>": 51224,
463
+ "<|17.22|>": 51225,
464
+ "<|17.24|>": 51226,
465
+ "<|17.26|>": 51227,
466
+ "<|17.28|>": 51228,
467
+ "<|17.30|>": 51229,
468
+ "<|17.32|>": 51230,
469
+ "<|17.34|>": 51231,
470
+ "<|17.36|>": 51232,
471
+ "<|17.38|>": 51233,
472
+ "<|17.40|>": 51234,
473
+ "<|17.42|>": 51235,
474
+ "<|17.44|>": 51236,
475
+ "<|17.46|>": 51237,
476
+ "<|17.48|>": 51238,
477
+ "<|17.50|>": 51239,
478
+ "<|17.52|>": 51240,
479
+ "<|17.54|>": 51241,
480
+ "<|17.56|>": 51242,
481
+ "<|17.58|>": 51243,
482
+ "<|17.60|>": 51244,
483
+ "<|17.62|>": 51245,
484
+ "<|17.64|>": 51246,
485
+ "<|17.66|>": 51247,
486
+ "<|17.68|>": 51248,
487
+ "<|17.70|>": 51249,
488
+ "<|17.72|>": 51250,
489
+ "<|17.74|>": 51251,
490
+ "<|17.76|>": 51252,
491
+ "<|17.78|>": 51253,
492
+ "<|17.80|>": 51254,
493
+ "<|17.82|>": 51255,
494
+ "<|17.84|>": 51256,
495
+ "<|17.86|>": 51257,
496
+ "<|17.88|>": 51258,
497
+ "<|17.90|>": 51259,
498
+ "<|17.92|>": 51260,
499
+ "<|17.94|>": 51261,
500
+ "<|17.96|>": 51262,
501
+ "<|17.98|>": 51263,
502
+ "<|18.00|>": 51264,
503
+ "<|18.02|>": 51265,
504
+ "<|18.04|>": 51266,
505
+ "<|18.06|>": 51267,
506
+ "<|18.08|>": 51268,
507
+ "<|18.10|>": 51269,
508
+ "<|18.12|>": 51270,
509
+ "<|18.14|>": 51271,
510
+ "<|18.16|>": 51272,
511
+ "<|18.18|>": 51273,
512
+ "<|18.20|>": 51274,
513
+ "<|18.22|>": 51275,
514
+ "<|18.24|>": 51276,
515
+ "<|18.26|>": 51277,
516
+ "<|18.28|>": 51278,
517
+ "<|18.30|>": 51279,
518
+ "<|18.32|>": 51280,
519
+ "<|18.34|>": 51281,
520
+ "<|18.36|>": 51282,
521
+ "<|18.38|>": 51283,
522
+ "<|18.40|>": 51284,
523
+ "<|18.42|>": 51285,
524
+ "<|18.44|>": 51286,
525
+ "<|18.46|>": 51287,
526
+ "<|18.48|>": 51288,
527
+ "<|18.50|>": 51289,
528
+ "<|18.52|>": 51290,
529
+ "<|18.54|>": 51291,
530
+ "<|18.56|>": 51292,
531
+ "<|18.58|>": 51293,
532
+ "<|18.60|>": 51294,
533
+ "<|18.62|>": 51295,
534
+ "<|18.64|>": 51296,
535
+ "<|18.66|>": 51297,
536
+ "<|18.68|>": 51298,
537
+ "<|18.70|>": 51299,
538
+ "<|18.72|>": 51300,
539
+ "<|18.74|>": 51301,
540
+ "<|18.76|>": 51302,
541
+ "<|18.78|>": 51303,
542
+ "<|18.80|>": 51304,
543
+ "<|18.82|>": 51305,
544
+ "<|18.84|>": 51306,
545
+ "<|18.86|>": 51307,
546
+ "<|18.88|>": 51308,
547
+ "<|18.90|>": 51309,
548
+ "<|18.92|>": 51310,
549
+ "<|18.94|>": 51311,
550
+ "<|18.96|>": 51312,
551
+ "<|18.98|>": 51313,
552
+ "<|19.00|>": 51314,
553
+ "<|19.02|>": 51315,
554
+ "<|19.04|>": 51316,
555
+ "<|19.06|>": 51317,
556
+ "<|19.08|>": 51318,
557
+ "<|19.10|>": 51319,
558
+ "<|19.12|>": 51320,
559
+ "<|19.14|>": 51321,
560
+ "<|19.16|>": 51322,
561
+ "<|19.18|>": 51323,
562
+ "<|19.20|>": 51324,
563
+ "<|19.22|>": 51325,
564
+ "<|19.24|>": 51326,
565
+ "<|19.26|>": 51327,
566
+ "<|19.28|>": 51328,
567
+ "<|19.30|>": 51329,
568
+ "<|19.32|>": 51330,
569
+ "<|19.34|>": 51331,
570
+ "<|19.36|>": 51332,
571
+ "<|19.38|>": 51333,
572
+ "<|19.40|>": 51334,
573
+ "<|19.42|>": 51335,
574
+ "<|19.44|>": 51336,
575
+ "<|19.46|>": 51337,
576
+ "<|19.48|>": 51338,
577
+ "<|19.50|>": 51339,
578
+ "<|19.52|>": 51340,
579
+ "<|19.54|>": 51341,
580
+ "<|19.56|>": 51342,
581
+ "<|19.58|>": 51343,
582
+ "<|19.60|>": 51344,
583
+ "<|19.62|>": 51345,
584
+ "<|19.64|>": 51346,
585
+ "<|19.66|>": 51347,
586
+ "<|19.68|>": 51348,
587
+ "<|19.70|>": 51349,
588
+ "<|19.72|>": 51350,
589
+ "<|19.74|>": 51351,
590
+ "<|19.76|>": 51352,
591
+ "<|19.78|>": 51353,
592
+ "<|19.80|>": 51354,
593
+ "<|19.82|>": 51355,
594
+ "<|19.84|>": 51356,
595
+ "<|19.86|>": 51357,
596
+ "<|19.88|>": 51358,
597
+ "<|19.90|>": 51359,
598
+ "<|19.92|>": 51360,
599
+ "<|19.94|>": 51361,
600
+ "<|19.96|>": 51362,
601
+ "<|19.98|>": 51363,
602
+ "<|2.00|>": 50464,
603
+ "<|2.02|>": 50465,
604
+ "<|2.04|>": 50466,
605
+ "<|2.06|>": 50467,
606
+ "<|2.08|>": 50468,
607
+ "<|2.10|>": 50469,
608
+ "<|2.12|>": 50470,
609
+ "<|2.14|>": 50471,
610
+ "<|2.16|>": 50472,
611
+ "<|2.18|>": 50473,
612
+ "<|2.20|>": 50474,
613
+ "<|2.22|>": 50475,
614
+ "<|2.24|>": 50476,
615
+ "<|2.26|>": 50477,
616
+ "<|2.28|>": 50478,
617
+ "<|2.30|>": 50479,
618
+ "<|2.32|>": 50480,
619
+ "<|2.34|>": 50481,
620
+ "<|2.36|>": 50482,
621
+ "<|2.38|>": 50483,
622
+ "<|2.40|>": 50484,
623
+ "<|2.42|>": 50485,
624
+ "<|2.44|>": 50486,
625
+ "<|2.46|>": 50487,
626
+ "<|2.48|>": 50488,
627
+ "<|2.50|>": 50489,
628
+ "<|2.52|>": 50490,
629
+ "<|2.54|>": 50491,
630
+ "<|2.56|>": 50492,
631
+ "<|2.58|>": 50493,
632
+ "<|2.60|>": 50494,
633
+ "<|2.62|>": 50495,
634
+ "<|2.64|>": 50496,
635
+ "<|2.66|>": 50497,
636
+ "<|2.68|>": 50498,
637
+ "<|2.70|>": 50499,
638
+ "<|2.72|>": 50500,
639
+ "<|2.74|>": 50501,
640
+ "<|2.76|>": 50502,
641
+ "<|2.78|>": 50503,
642
+ "<|2.80|>": 50504,
643
+ "<|2.82|>": 50505,
644
+ "<|2.84|>": 50506,
645
+ "<|2.86|>": 50507,
646
+ "<|2.88|>": 50508,
647
+ "<|2.90|>": 50509,
648
+ "<|2.92|>": 50510,
649
+ "<|2.94|>": 50511,
650
+ "<|2.96|>": 50512,
651
+ "<|2.98|>": 50513,
652
+ "<|20.00|>": 51364,
653
+ "<|20.02|>": 51365,
654
+ "<|20.04|>": 51366,
655
+ "<|20.06|>": 51367,
656
+ "<|20.08|>": 51368,
657
+ "<|20.10|>": 51369,
658
+ "<|20.12|>": 51370,
659
+ "<|20.14|>": 51371,
660
+ "<|20.16|>": 51372,
661
+ "<|20.18|>": 51373,
662
+ "<|20.20|>": 51374,
663
+ "<|20.22|>": 51375,
664
+ "<|20.24|>": 51376,
665
+ "<|20.26|>": 51377,
666
+ "<|20.28|>": 51378,
667
+ "<|20.30|>": 51379,
668
+ "<|20.32|>": 51380,
669
+ "<|20.34|>": 51381,
670
+ "<|20.36|>": 51382,
671
+ "<|20.38|>": 51383,
672
+ "<|20.40|>": 51384,
673
+ "<|20.42|>": 51385,
674
+ "<|20.44|>": 51386,
675
+ "<|20.46|>": 51387,
676
+ "<|20.48|>": 51388,
677
+ "<|20.50|>": 51389,
678
+ "<|20.52|>": 51390,
679
+ "<|20.54|>": 51391,
680
+ "<|20.56|>": 51392,
681
+ "<|20.58|>": 51393,
682
+ "<|20.60|>": 51394,
683
+ "<|20.62|>": 51395,
684
+ "<|20.64|>": 51396,
685
+ "<|20.66|>": 51397,
686
+ "<|20.68|>": 51398,
687
+ "<|20.70|>": 51399,
688
+ "<|20.72|>": 51400,
689
+ "<|20.74|>": 51401,
690
+ "<|20.76|>": 51402,
691
+ "<|20.78|>": 51403,
692
+ "<|20.80|>": 51404,
693
+ "<|20.82|>": 51405,
694
+ "<|20.84|>": 51406,
695
+ "<|20.86|>": 51407,
696
+ "<|20.88|>": 51408,
697
+ "<|20.90|>": 51409,
698
+ "<|20.92|>": 51410,
699
+ "<|20.94|>": 51411,
700
+ "<|20.96|>": 51412,
701
+ "<|20.98|>": 51413,
702
+ "<|21.00|>": 51414,
703
+ "<|21.02|>": 51415,
704
+ "<|21.04|>": 51416,
705
+ "<|21.06|>": 51417,
706
+ "<|21.08|>": 51418,
707
+ "<|21.10|>": 51419,
708
+ "<|21.12|>": 51420,
709
+ "<|21.14|>": 51421,
710
+ "<|21.16|>": 51422,
711
+ "<|21.18|>": 51423,
712
+ "<|21.20|>": 51424,
713
+ "<|21.22|>": 51425,
714
+ "<|21.24|>": 51426,
715
+ "<|21.26|>": 51427,
716
+ "<|21.28|>": 51428,
717
+ "<|21.30|>": 51429,
718
+ "<|21.32|>": 51430,
719
+ "<|21.34|>": 51431,
720
+ "<|21.36|>": 51432,
721
+ "<|21.38|>": 51433,
722
+ "<|21.40|>": 51434,
723
+ "<|21.42|>": 51435,
724
+ "<|21.44|>": 51436,
725
+ "<|21.46|>": 51437,
726
+ "<|21.48|>": 51438,
727
+ "<|21.50|>": 51439,
728
+ "<|21.52|>": 51440,
729
+ "<|21.54|>": 51441,
730
+ "<|21.56|>": 51442,
731
+ "<|21.58|>": 51443,
732
+ "<|21.60|>": 51444,
733
+ "<|21.62|>": 51445,
734
+ "<|21.64|>": 51446,
735
+ "<|21.66|>": 51447,
736
+ "<|21.68|>": 51448,
737
+ "<|21.70|>": 51449,
738
+ "<|21.72|>": 51450,
739
+ "<|21.74|>": 51451,
740
+ "<|21.76|>": 51452,
741
+ "<|21.78|>": 51453,
742
+ "<|21.80|>": 51454,
743
+ "<|21.82|>": 51455,
744
+ "<|21.84|>": 51456,
745
+ "<|21.86|>": 51457,
746
+ "<|21.88|>": 51458,
747
+ "<|21.90|>": 51459,
748
+ "<|21.92|>": 51460,
749
+ "<|21.94|>": 51461,
750
+ "<|21.96|>": 51462,
751
+ "<|21.98|>": 51463,
752
+ "<|22.00|>": 51464,
753
+ "<|22.02|>": 51465,
754
+ "<|22.04|>": 51466,
755
+ "<|22.06|>": 51467,
756
+ "<|22.08|>": 51468,
757
+ "<|22.10|>": 51469,
758
+ "<|22.12|>": 51470,
759
+ "<|22.14|>": 51471,
760
+ "<|22.16|>": 51472,
761
+ "<|22.18|>": 51473,
762
+ "<|22.20|>": 51474,
763
+ "<|22.22|>": 51475,
764
+ "<|22.24|>": 51476,
765
+ "<|22.26|>": 51477,
766
+ "<|22.28|>": 51478,
767
+ "<|22.30|>": 51479,
768
+ "<|22.32|>": 51480,
769
+ "<|22.34|>": 51481,
770
+ "<|22.36|>": 51482,
771
+ "<|22.38|>": 51483,
772
+ "<|22.40|>": 51484,
773
+ "<|22.42|>": 51485,
774
+ "<|22.44|>": 51486,
775
+ "<|22.46|>": 51487,
776
+ "<|22.48|>": 51488,
777
+ "<|22.50|>": 51489,
778
+ "<|22.52|>": 51490,
779
+ "<|22.54|>": 51491,
780
+ "<|22.56|>": 51492,
781
+ "<|22.58|>": 51493,
782
+ "<|22.60|>": 51494,
783
+ "<|22.62|>": 51495,
784
+ "<|22.64|>": 51496,
785
+ "<|22.66|>": 51497,
786
+ "<|22.68|>": 51498,
787
+ "<|22.70|>": 51499,
788
+ "<|22.72|>": 51500,
789
+ "<|22.74|>": 51501,
790
+ "<|22.76|>": 51502,
791
+ "<|22.78|>": 51503,
792
+ "<|22.80|>": 51504,
793
+ "<|22.82|>": 51505,
794
+ "<|22.84|>": 51506,
795
+ "<|22.86|>": 51507,
796
+ "<|22.88|>": 51508,
797
+ "<|22.90|>": 51509,
798
+ "<|22.92|>": 51510,
799
+ "<|22.94|>": 51511,
800
+ "<|22.96|>": 51512,
801
+ "<|22.98|>": 51513,
802
+ "<|23.00|>": 51514,
803
+ "<|23.02|>": 51515,
804
+ "<|23.04|>": 51516,
805
+ "<|23.06|>": 51517,
806
+ "<|23.08|>": 51518,
807
+ "<|23.10|>": 51519,
808
+ "<|23.12|>": 51520,
809
+ "<|23.14|>": 51521,
810
+ "<|23.16|>": 51522,
811
+ "<|23.18|>": 51523,
812
+ "<|23.20|>": 51524,
813
+ "<|23.22|>": 51525,
814
+ "<|23.24|>": 51526,
815
+ "<|23.26|>": 51527,
816
+ "<|23.28|>": 51528,
817
+ "<|23.30|>": 51529,
818
+ "<|23.32|>": 51530,
819
+ "<|23.34|>": 51531,
820
+ "<|23.36|>": 51532,
821
+ "<|23.38|>": 51533,
822
+ "<|23.40|>": 51534,
823
+ "<|23.42|>": 51535,
824
+ "<|23.44|>": 51536,
825
+ "<|23.46|>": 51537,
826
+ "<|23.48|>": 51538,
827
+ "<|23.50|>": 51539,
828
+ "<|23.52|>": 51540,
829
+ "<|23.54|>": 51541,
830
+ "<|23.56|>": 51542,
831
+ "<|23.58|>": 51543,
832
+ "<|23.60|>": 51544,
833
+ "<|23.62|>": 51545,
834
+ "<|23.64|>": 51546,
835
+ "<|23.66|>": 51547,
836
+ "<|23.68|>": 51548,
837
+ "<|23.70|>": 51549,
838
+ "<|23.72|>": 51550,
839
+ "<|23.74|>": 51551,
840
+ "<|23.76|>": 51552,
841
+ "<|23.78|>": 51553,
842
+ "<|23.80|>": 51554,
843
+ "<|23.82|>": 51555,
844
+ "<|23.84|>": 51556,
845
+ "<|23.86|>": 51557,
846
+ "<|23.88|>": 51558,
847
+ "<|23.90|>": 51559,
848
+ "<|23.92|>": 51560,
849
+ "<|23.94|>": 51561,
850
+ "<|23.96|>": 51562,
851
+ "<|23.98|>": 51563,
852
+ "<|24.00|>": 51564,
853
+ "<|24.02|>": 51565,
854
+ "<|24.04|>": 51566,
855
+ "<|24.06|>": 51567,
856
+ "<|24.08|>": 51568,
857
+ "<|24.10|>": 51569,
858
+ "<|24.12|>": 51570,
859
+ "<|24.14|>": 51571,
860
+ "<|24.16|>": 51572,
861
+ "<|24.18|>": 51573,
862
+ "<|24.20|>": 51574,
863
+ "<|24.22|>": 51575,
864
+ "<|24.24|>": 51576,
865
+ "<|24.26|>": 51577,
866
+ "<|24.28|>": 51578,
867
+ "<|24.30|>": 51579,
868
+ "<|24.32|>": 51580,
869
+ "<|24.34|>": 51581,
870
+ "<|24.36|>": 51582,
871
+ "<|24.38|>": 51583,
872
+ "<|24.40|>": 51584,
873
+ "<|24.42|>": 51585,
874
+ "<|24.44|>": 51586,
875
+ "<|24.46|>": 51587,
876
+ "<|24.48|>": 51588,
877
+ "<|24.50|>": 51589,
878
+ "<|24.52|>": 51590,
879
+ "<|24.54|>": 51591,
880
+ "<|24.56|>": 51592,
881
+ "<|24.58|>": 51593,
882
+ "<|24.60|>": 51594,
883
+ "<|24.62|>": 51595,
884
+ "<|24.64|>": 51596,
885
+ "<|24.66|>": 51597,
886
+ "<|24.68|>": 51598,
887
+ "<|24.70|>": 51599,
888
+ "<|24.72|>": 51600,
889
+ "<|24.74|>": 51601,
890
+ "<|24.76|>": 51602,
891
+ "<|24.78|>": 51603,
892
+ "<|24.80|>": 51604,
893
+ "<|24.82|>": 51605,
894
+ "<|24.84|>": 51606,
895
+ "<|24.86|>": 51607,
896
+ "<|24.88|>": 51608,
897
+ "<|24.90|>": 51609,
898
+ "<|24.92|>": 51610,
899
+ "<|24.94|>": 51611,
900
+ "<|24.96|>": 51612,
901
+ "<|24.98|>": 51613,
902
+ "<|25.00|>": 51614,
903
+ "<|25.02|>": 51615,
904
+ "<|25.04|>": 51616,
905
+ "<|25.06|>": 51617,
906
+ "<|25.08|>": 51618,
907
+ "<|25.10|>": 51619,
908
+ "<|25.12|>": 51620,
909
+ "<|25.14|>": 51621,
910
+ "<|25.16|>": 51622,
911
+ "<|25.18|>": 51623,
912
+ "<|25.20|>": 51624,
913
+ "<|25.22|>": 51625,
914
+ "<|25.24|>": 51626,
915
+ "<|25.26|>": 51627,
916
+ "<|25.28|>": 51628,
917
+ "<|25.30|>": 51629,
918
+ "<|25.32|>": 51630,
919
+ "<|25.34|>": 51631,
920
+ "<|25.36|>": 51632,
921
+ "<|25.38|>": 51633,
922
+ "<|25.40|>": 51634,
923
+ "<|25.42|>": 51635,
924
+ "<|25.44|>": 51636,
925
+ "<|25.46|>": 51637,
926
+ "<|25.48|>": 51638,
927
+ "<|25.50|>": 51639,
928
+ "<|25.52|>": 51640,
929
+ "<|25.54|>": 51641,
930
+ "<|25.56|>": 51642,
931
+ "<|25.58|>": 51643,
932
+ "<|25.60|>": 51644,
933
+ "<|25.62|>": 51645,
934
+ "<|25.64|>": 51646,
935
+ "<|25.66|>": 51647,
936
+ "<|25.68|>": 51648,
937
+ "<|25.70|>": 51649,
938
+ "<|25.72|>": 51650,
939
+ "<|25.74|>": 51651,
940
+ "<|25.76|>": 51652,
941
+ "<|25.78|>": 51653,
942
+ "<|25.80|>": 51654,
943
+ "<|25.82|>": 51655,
944
+ "<|25.84|>": 51656,
945
+ "<|25.86|>": 51657,
946
+ "<|25.88|>": 51658,
947
+ "<|25.90|>": 51659,
948
+ "<|25.92|>": 51660,
949
+ "<|25.94|>": 51661,
950
+ "<|25.96|>": 51662,
951
+ "<|25.98|>": 51663,
952
+ "<|26.00|>": 51664,
953
+ "<|26.02|>": 51665,
954
+ "<|26.04|>": 51666,
955
+ "<|26.06|>": 51667,
956
+ "<|26.08|>": 51668,
957
+ "<|26.10|>": 51669,
958
+ "<|26.12|>": 51670,
959
+ "<|26.14|>": 51671,
960
+ "<|26.16|>": 51672,
961
+ "<|26.18|>": 51673,
962
+ "<|26.20|>": 51674,
963
+ "<|26.22|>": 51675,
964
+ "<|26.24|>": 51676,
965
+ "<|26.26|>": 51677,
966
+ "<|26.28|>": 51678,
967
+ "<|26.30|>": 51679,
968
+ "<|26.32|>": 51680,
969
+ "<|26.34|>": 51681,
970
+ "<|26.36|>": 51682,
971
+ "<|26.38|>": 51683,
972
+ "<|26.40|>": 51684,
973
+ "<|26.42|>": 51685,
974
+ "<|26.44|>": 51686,
975
+ "<|26.46|>": 51687,
976
+ "<|26.48|>": 51688,
977
+ "<|26.50|>": 51689,
978
+ "<|26.52|>": 51690,
979
+ "<|26.54|>": 51691,
980
+ "<|26.56|>": 51692,
981
+ "<|26.58|>": 51693,
982
+ "<|26.60|>": 51694,
983
+ "<|26.62|>": 51695,
984
+ "<|26.64|>": 51696,
985
+ "<|26.66|>": 51697,
986
+ "<|26.68|>": 51698,
987
+ "<|26.70|>": 51699,
988
+ "<|26.72|>": 51700,
989
+ "<|26.74|>": 51701,
990
+ "<|26.76|>": 51702,
991
+ "<|26.78|>": 51703,
992
+ "<|26.80|>": 51704,
993
+ "<|26.82|>": 51705,
994
+ "<|26.84|>": 51706,
995
+ "<|26.86|>": 51707,
996
+ "<|26.88|>": 51708,
997
+ "<|26.90|>": 51709,
998
+ "<|26.92|>": 51710,
999
+ "<|26.94|>": 51711,
1000
+ "<|26.96|>": 51712,
1001
+ "<|26.98|>": 51713,
1002
+ "<|27.00|>": 51714,
1003
+ "<|27.02|>": 51715,
1004
+ "<|27.04|>": 51716,
1005
+ "<|27.06|>": 51717,
1006
+ "<|27.08|>": 51718,
1007
+ "<|27.10|>": 51719,
1008
+ "<|27.12|>": 51720,
1009
+ "<|27.14|>": 51721,
1010
+ "<|27.16|>": 51722,
1011
+ "<|27.18|>": 51723,
1012
+ "<|27.20|>": 51724,
1013
+ "<|27.22|>": 51725,
1014
+ "<|27.24|>": 51726,
1015
+ "<|27.26|>": 51727,
1016
+ "<|27.28|>": 51728,
1017
+ "<|27.30|>": 51729,
1018
+ "<|27.32|>": 51730,
1019
+ "<|27.34|>": 51731,
1020
+ "<|27.36|>": 51732,
1021
+ "<|27.38|>": 51733,
1022
+ "<|27.40|>": 51734,
1023
+ "<|27.42|>": 51735,
1024
+ "<|27.44|>": 51736,
1025
+ "<|27.46|>": 51737,
1026
+ "<|27.48|>": 51738,
1027
+ "<|27.50|>": 51739,
1028
+ "<|27.52|>": 51740,
1029
+ "<|27.54|>": 51741,
1030
+ "<|27.56|>": 51742,
1031
+ "<|27.58|>": 51743,
1032
+ "<|27.60|>": 51744,
1033
+ "<|27.62|>": 51745,
1034
+ "<|27.64|>": 51746,
1035
+ "<|27.66|>": 51747,
1036
+ "<|27.68|>": 51748,
1037
+ "<|27.70|>": 51749,
1038
+ "<|27.72|>": 51750,
1039
+ "<|27.74|>": 51751,
1040
+ "<|27.76|>": 51752,
1041
+ "<|27.78|>": 51753,
1042
+ "<|27.80|>": 51754,
1043
+ "<|27.82|>": 51755,
1044
+ "<|27.84|>": 51756,
1045
+ "<|27.86|>": 51757,
1046
+ "<|27.88|>": 51758,
1047
+ "<|27.90|>": 51759,
1048
+ "<|27.92|>": 51760,
1049
+ "<|27.94|>": 51761,
1050
+ "<|27.96|>": 51762,
1051
+ "<|27.98|>": 51763,
1052
+ "<|28.00|>": 51764,
1053
+ "<|28.02|>": 51765,
1054
+ "<|28.04|>": 51766,
1055
+ "<|28.06|>": 51767,
1056
+ "<|28.08|>": 51768,
1057
+ "<|28.10|>": 51769,
1058
+ "<|28.12|>": 51770,
1059
+ "<|28.14|>": 51771,
1060
+ "<|28.16|>": 51772,
1061
+ "<|28.18|>": 51773,
1062
+ "<|28.20|>": 51774,
1063
+ "<|28.22|>": 51775,
1064
+ "<|28.24|>": 51776,
1065
+ "<|28.26|>": 51777,
1066
+ "<|28.28|>": 51778,
1067
+ "<|28.30|>": 51779,
1068
+ "<|28.32|>": 51780,
1069
+ "<|28.34|>": 51781,
1070
+ "<|28.36|>": 51782,
1071
+ "<|28.38|>": 51783,
1072
+ "<|28.40|>": 51784,
1073
+ "<|28.42|>": 51785,
1074
+ "<|28.44|>": 51786,
1075
+ "<|28.46|>": 51787,
1076
+ "<|28.48|>": 51788,
1077
+ "<|28.50|>": 51789,
1078
+ "<|28.52|>": 51790,
1079
+ "<|28.54|>": 51791,
1080
+ "<|28.56|>": 51792,
1081
+ "<|28.58|>": 51793,
1082
+ "<|28.60|>": 51794,
1083
+ "<|28.62|>": 51795,
1084
+ "<|28.64|>": 51796,
1085
+ "<|28.66|>": 51797,
1086
+ "<|28.68|>": 51798,
1087
+ "<|28.70|>": 51799,
1088
+ "<|28.72|>": 51800,
1089
+ "<|28.74|>": 51801,
1090
+ "<|28.76|>": 51802,
1091
+ "<|28.78|>": 51803,
1092
+ "<|28.80|>": 51804,
1093
+ "<|28.82|>": 51805,
1094
+ "<|28.84|>": 51806,
1095
+ "<|28.86|>": 51807,
1096
+ "<|28.88|>": 51808,
1097
+ "<|28.90|>": 51809,
1098
+ "<|28.92|>": 51810,
1099
+ "<|28.94|>": 51811,
1100
+ "<|28.96|>": 51812,
1101
+ "<|28.98|>": 51813,
1102
+ "<|29.00|>": 51814,
1103
+ "<|29.02|>": 51815,
1104
+ "<|29.04|>": 51816,
1105
+ "<|29.06|>": 51817,
1106
+ "<|29.08|>": 51818,
1107
+ "<|29.10|>": 51819,
1108
+ "<|29.12|>": 51820,
1109
+ "<|29.14|>": 51821,
1110
+ "<|29.16|>": 51822,
1111
+ "<|29.18|>": 51823,
1112
+ "<|29.20|>": 51824,
1113
+ "<|29.22|>": 51825,
1114
+ "<|29.24|>": 51826,
1115
+ "<|29.26|>": 51827,
1116
+ "<|29.28|>": 51828,
1117
+ "<|29.30|>": 51829,
1118
+ "<|29.32|>": 51830,
1119
+ "<|29.34|>": 51831,
1120
+ "<|29.36|>": 51832,
1121
+ "<|29.38|>": 51833,
1122
+ "<|29.40|>": 51834,
1123
+ "<|29.42|>": 51835,
1124
+ "<|29.44|>": 51836,
1125
+ "<|29.46|>": 51837,
1126
+ "<|29.48|>": 51838,
1127
+ "<|29.50|>": 51839,
1128
+ "<|29.52|>": 51840,
1129
+ "<|29.54|>": 51841,
1130
+ "<|29.56|>": 51842,
1131
+ "<|29.58|>": 51843,
1132
+ "<|29.60|>": 51844,
1133
+ "<|29.62|>": 51845,
1134
+ "<|29.64|>": 51846,
1135
+ "<|29.66|>": 51847,
1136
+ "<|29.68|>": 51848,
1137
+ "<|29.70|>": 51849,
1138
+ "<|29.72|>": 51850,
1139
+ "<|29.74|>": 51851,
1140
+ "<|29.76|>": 51852,
1141
+ "<|29.78|>": 51853,
1142
+ "<|29.80|>": 51854,
1143
+ "<|29.82|>": 51855,
1144
+ "<|29.84|>": 51856,
1145
+ "<|29.86|>": 51857,
1146
+ "<|29.88|>": 51858,
1147
+ "<|29.90|>": 51859,
1148
+ "<|29.92|>": 51860,
1149
+ "<|29.94|>": 51861,
1150
+ "<|29.96|>": 51862,
1151
+ "<|29.98|>": 51863,
1152
+ "<|3.00|>": 50514,
1153
+ "<|3.02|>": 50515,
1154
+ "<|3.04|>": 50516,
1155
+ "<|3.06|>": 50517,
1156
+ "<|3.08|>": 50518,
1157
+ "<|3.10|>": 50519,
1158
+ "<|3.12|>": 50520,
1159
+ "<|3.14|>": 50521,
1160
+ "<|3.16|>": 50522,
1161
+ "<|3.18|>": 50523,
1162
+ "<|3.20|>": 50524,
1163
+ "<|3.22|>": 50525,
1164
+ "<|3.24|>": 50526,
1165
+ "<|3.26|>": 50527,
1166
+ "<|3.28|>": 50528,
1167
+ "<|3.30|>": 50529,
1168
+ "<|3.32|>": 50530,
1169
+ "<|3.34|>": 50531,
1170
+ "<|3.36|>": 50532,
1171
+ "<|3.38|>": 50533,
1172
+ "<|3.40|>": 50534,
1173
+ "<|3.42|>": 50535,
1174
+ "<|3.44|>": 50536,
1175
+ "<|3.46|>": 50537,
1176
+ "<|3.48|>": 50538,
1177
+ "<|3.50|>": 50539,
1178
+ "<|3.52|>": 50540,
1179
+ "<|3.54|>": 50541,
1180
+ "<|3.56|>": 50542,
1181
+ "<|3.58|>": 50543,
1182
+ "<|3.60|>": 50544,
1183
+ "<|3.62|>": 50545,
1184
+ "<|3.64|>": 50546,
1185
+ "<|3.66|>": 50547,
1186
+ "<|3.68|>": 50548,
1187
+ "<|3.70|>": 50549,
1188
+ "<|3.72|>": 50550,
1189
+ "<|3.74|>": 50551,
1190
+ "<|3.76|>": 50552,
1191
+ "<|3.78|>": 50553,
1192
+ "<|3.80|>": 50554,
1193
+ "<|3.82|>": 50555,
1194
+ "<|3.84|>": 50556,
1195
+ "<|3.86|>": 50557,
1196
+ "<|3.88|>": 50558,
1197
+ "<|3.90|>": 50559,
1198
+ "<|3.92|>": 50560,
1199
+ "<|3.94|>": 50561,
1200
+ "<|3.96|>": 50562,
1201
+ "<|3.98|>": 50563,
1202
+ "<|30.00|>": 51864,
1203
+ "<|4.00|>": 50564,
1204
+ "<|4.02|>": 50565,
1205
+ "<|4.04|>": 50566,
1206
+ "<|4.06|>": 50567,
1207
+ "<|4.08|>": 50568,
1208
+ "<|4.10|>": 50569,
1209
+ "<|4.12|>": 50570,
1210
+ "<|4.14|>": 50571,
1211
+ "<|4.16|>": 50572,
1212
+ "<|4.18|>": 50573,
1213
+ "<|4.20|>": 50574,
1214
+ "<|4.22|>": 50575,
1215
+ "<|4.24|>": 50576,
1216
+ "<|4.26|>": 50577,
1217
+ "<|4.28|>": 50578,
1218
+ "<|4.30|>": 50579,
1219
+ "<|4.32|>": 50580,
1220
+ "<|4.34|>": 50581,
1221
+ "<|4.36|>": 50582,
1222
+ "<|4.38|>": 50583,
1223
+ "<|4.40|>": 50584,
1224
+ "<|4.42|>": 50585,
1225
+ "<|4.44|>": 50586,
1226
+ "<|4.46|>": 50587,
1227
+ "<|4.48|>": 50588,
1228
+ "<|4.50|>": 50589,
1229
+ "<|4.52|>": 50590,
1230
+ "<|4.54|>": 50591,
1231
+ "<|4.56|>": 50592,
1232
+ "<|4.58|>": 50593,
1233
+ "<|4.60|>": 50594,
1234
+ "<|4.62|>": 50595,
1235
+ "<|4.64|>": 50596,
1236
+ "<|4.66|>": 50597,
1237
+ "<|4.68|>": 50598,
1238
+ "<|4.70|>": 50599,
1239
+ "<|4.72|>": 50600,
1240
+ "<|4.74|>": 50601,
1241
+ "<|4.76|>": 50602,
1242
+ "<|4.78|>": 50603,
1243
+ "<|4.80|>": 50604,
1244
+ "<|4.82|>": 50605,
1245
+ "<|4.84|>": 50606,
1246
+ "<|4.86|>": 50607,
1247
+ "<|4.88|>": 50608,
1248
+ "<|4.90|>": 50609,
1249
+ "<|4.92|>": 50610,
1250
+ "<|4.94|>": 50611,
1251
+ "<|4.96|>": 50612,
1252
+ "<|4.98|>": 50613,
1253
+ "<|5.00|>": 50614,
1254
+ "<|5.02|>": 50615,
1255
+ "<|5.04|>": 50616,
1256
+ "<|5.06|>": 50617,
1257
+ "<|5.08|>": 50618,
1258
+ "<|5.10|>": 50619,
1259
+ "<|5.12|>": 50620,
1260
+ "<|5.14|>": 50621,
1261
+ "<|5.16|>": 50622,
1262
+ "<|5.18|>": 50623,
1263
+ "<|5.20|>": 50624,
1264
+ "<|5.22|>": 50625,
1265
+ "<|5.24|>": 50626,
1266
+ "<|5.26|>": 50627,
1267
+ "<|5.28|>": 50628,
1268
+ "<|5.30|>": 50629,
1269
+ "<|5.32|>": 50630,
1270
+ "<|5.34|>": 50631,
1271
+ "<|5.36|>": 50632,
1272
+ "<|5.38|>": 50633,
1273
+ "<|5.40|>": 50634,
1274
+ "<|5.42|>": 50635,
1275
+ "<|5.44|>": 50636,
1276
+ "<|5.46|>": 50637,
1277
+ "<|5.48|>": 50638,
1278
+ "<|5.50|>": 50639,
1279
+ "<|5.52|>": 50640,
1280
+ "<|5.54|>": 50641,
1281
+ "<|5.56|>": 50642,
1282
+ "<|5.58|>": 50643,
1283
+ "<|5.60|>": 50644,
1284
+ "<|5.62|>": 50645,
1285
+ "<|5.64|>": 50646,
1286
+ "<|5.66|>": 50647,
1287
+ "<|5.68|>": 50648,
1288
+ "<|5.70|>": 50649,
1289
+ "<|5.72|>": 50650,
1290
+ "<|5.74|>": 50651,
1291
+ "<|5.76|>": 50652,
1292
+ "<|5.78|>": 50653,
1293
+ "<|5.80|>": 50654,
1294
+ "<|5.82|>": 50655,
1295
+ "<|5.84|>": 50656,
1296
+ "<|5.86|>": 50657,
1297
+ "<|5.88|>": 50658,
1298
+ "<|5.90|>": 50659,
1299
+ "<|5.92|>": 50660,
1300
+ "<|5.94|>": 50661,
1301
+ "<|5.96|>": 50662,
1302
+ "<|5.98|>": 50663,
1303
+ "<|6.00|>": 50664,
1304
+ "<|6.02|>": 50665,
1305
+ "<|6.04|>": 50666,
1306
+ "<|6.06|>": 50667,
1307
+ "<|6.08|>": 50668,
1308
+ "<|6.10|>": 50669,
1309
+ "<|6.12|>": 50670,
1310
+ "<|6.14|>": 50671,
1311
+ "<|6.16|>": 50672,
1312
+ "<|6.18|>": 50673,
1313
+ "<|6.20|>": 50674,
1314
+ "<|6.22|>": 50675,
1315
+ "<|6.24|>": 50676,
1316
+ "<|6.26|>": 50677,
1317
+ "<|6.28|>": 50678,
1318
+ "<|6.30|>": 50679,
1319
+ "<|6.32|>": 50680,
1320
+ "<|6.34|>": 50681,
1321
+ "<|6.36|>": 50682,
1322
+ "<|6.38|>": 50683,
1323
+ "<|6.40|>": 50684,
1324
+ "<|6.42|>": 50685,
1325
+ "<|6.44|>": 50686,
1326
+ "<|6.46|>": 50687,
1327
+ "<|6.48|>": 50688,
1328
+ "<|6.50|>": 50689,
1329
+ "<|6.52|>": 50690,
1330
+ "<|6.54|>": 50691,
1331
+ "<|6.56|>": 50692,
1332
+ "<|6.58|>": 50693,
1333
+ "<|6.60|>": 50694,
1334
+ "<|6.62|>": 50695,
1335
+ "<|6.64|>": 50696,
1336
+ "<|6.66|>": 50697,
1337
+ "<|6.68|>": 50698,
1338
+ "<|6.70|>": 50699,
1339
+ "<|6.72|>": 50700,
1340
+ "<|6.74|>": 50701,
1341
+ "<|6.76|>": 50702,
1342
+ "<|6.78|>": 50703,
1343
+ "<|6.80|>": 50704,
1344
+ "<|6.82|>": 50705,
1345
+ "<|6.84|>": 50706,
1346
+ "<|6.86|>": 50707,
1347
+ "<|6.88|>": 50708,
1348
+ "<|6.90|>": 50709,
1349
+ "<|6.92|>": 50710,
1350
+ "<|6.94|>": 50711,
1351
+ "<|6.96|>": 50712,
1352
+ "<|6.98|>": 50713,
1353
+ "<|7.00|>": 50714,
1354
+ "<|7.02|>": 50715,
1355
+ "<|7.04|>": 50716,
1356
+ "<|7.06|>": 50717,
1357
+ "<|7.08|>": 50718,
1358
+ "<|7.10|>": 50719,
1359
+ "<|7.12|>": 50720,
1360
+ "<|7.14|>": 50721,
1361
+ "<|7.16|>": 50722,
1362
+ "<|7.18|>": 50723,
1363
+ "<|7.20|>": 50724,
1364
+ "<|7.22|>": 50725,
1365
+ "<|7.24|>": 50726,
1366
+ "<|7.26|>": 50727,
1367
+ "<|7.28|>": 50728,
1368
+ "<|7.30|>": 50729,
1369
+ "<|7.32|>": 50730,
1370
+ "<|7.34|>": 50731,
1371
+ "<|7.36|>": 50732,
1372
+ "<|7.38|>": 50733,
1373
+ "<|7.40|>": 50734,
1374
+ "<|7.42|>": 50735,
1375
+ "<|7.44|>": 50736,
1376
+ "<|7.46|>": 50737,
1377
+ "<|7.48|>": 50738,
1378
+ "<|7.50|>": 50739,
1379
+ "<|7.52|>": 50740,
1380
+ "<|7.54|>": 50741,
1381
+ "<|7.56|>": 50742,
1382
+ "<|7.58|>": 50743,
1383
+ "<|7.60|>": 50744,
1384
+ "<|7.62|>": 50745,
1385
+ "<|7.64|>": 50746,
1386
+ "<|7.66|>": 50747,
1387
+ "<|7.68|>": 50748,
1388
+ "<|7.70|>": 50749,
1389
+ "<|7.72|>": 50750,
1390
+ "<|7.74|>": 50751,
1391
+ "<|7.76|>": 50752,
1392
+ "<|7.78|>": 50753,
1393
+ "<|7.80|>": 50754,
1394
+ "<|7.82|>": 50755,
1395
+ "<|7.84|>": 50756,
1396
+ "<|7.86|>": 50757,
1397
+ "<|7.88|>": 50758,
1398
+ "<|7.90|>": 50759,
1399
+ "<|7.92|>": 50760,
1400
+ "<|7.94|>": 50761,
1401
+ "<|7.96|>": 50762,
1402
+ "<|7.98|>": 50763,
1403
+ "<|8.00|>": 50764,
1404
+ "<|8.02|>": 50765,
1405
+ "<|8.04|>": 50766,
1406
+ "<|8.06|>": 50767,
1407
+ "<|8.08|>": 50768,
1408
+ "<|8.10|>": 50769,
1409
+ "<|8.12|>": 50770,
1410
+ "<|8.14|>": 50771,
1411
+ "<|8.16|>": 50772,
1412
+ "<|8.18|>": 50773,
1413
+ "<|8.20|>": 50774,
1414
+ "<|8.22|>": 50775,
1415
+ "<|8.24|>": 50776,
1416
+ "<|8.26|>": 50777,
1417
+ "<|8.28|>": 50778,
1418
+ "<|8.30|>": 50779,
1419
+ "<|8.32|>": 50780,
1420
+ "<|8.34|>": 50781,
1421
+ "<|8.36|>": 50782,
1422
+ "<|8.38|>": 50783,
1423
+ "<|8.40|>": 50784,
1424
+ "<|8.42|>": 50785,
1425
+ "<|8.44|>": 50786,
1426
+ "<|8.46|>": 50787,
1427
+ "<|8.48|>": 50788,
1428
+ "<|8.50|>": 50789,
1429
+ "<|8.52|>": 50790,
1430
+ "<|8.54|>": 50791,
1431
+ "<|8.56|>": 50792,
1432
+ "<|8.58|>": 50793,
1433
+ "<|8.60|>": 50794,
1434
+ "<|8.62|>": 50795,
1435
+ "<|8.64|>": 50796,
1436
+ "<|8.66|>": 50797,
1437
+ "<|8.68|>": 50798,
1438
+ "<|8.70|>": 50799,
1439
+ "<|8.72|>": 50800,
1440
+ "<|8.74|>": 50801,
1441
+ "<|8.76|>": 50802,
1442
+ "<|8.78|>": 50803,
1443
+ "<|8.80|>": 50804,
1444
+ "<|8.82|>": 50805,
1445
+ "<|8.84|>": 50806,
1446
+ "<|8.86|>": 50807,
1447
+ "<|8.88|>": 50808,
1448
+ "<|8.90|>": 50809,
1449
+ "<|8.92|>": 50810,
1450
+ "<|8.94|>": 50811,
1451
+ "<|8.96|>": 50812,
1452
+ "<|8.98|>": 50813,
1453
+ "<|9.00|>": 50814,
1454
+ "<|9.02|>": 50815,
1455
+ "<|9.04|>": 50816,
1456
+ "<|9.06|>": 50817,
1457
+ "<|9.08|>": 50818,
1458
+ "<|9.10|>": 50819,
1459
+ "<|9.12|>": 50820,
1460
+ "<|9.14|>": 50821,
1461
+ "<|9.16|>": 50822,
1462
+ "<|9.18|>": 50823,
1463
+ "<|9.20|>": 50824,
1464
+ "<|9.22|>": 50825,
1465
+ "<|9.24|>": 50826,
1466
+ "<|9.26|>": 50827,
1467
+ "<|9.28|>": 50828,
1468
+ "<|9.30|>": 50829,
1469
+ "<|9.32|>": 50830,
1470
+ "<|9.34|>": 50831,
1471
+ "<|9.36|>": 50832,
1472
+ "<|9.38|>": 50833,
1473
+ "<|9.40|>": 50834,
1474
+ "<|9.42|>": 50835,
1475
+ "<|9.44|>": 50836,
1476
+ "<|9.46|>": 50837,
1477
+ "<|9.48|>": 50838,
1478
+ "<|9.50|>": 50839,
1479
+ "<|9.52|>": 50840,
1480
+ "<|9.54|>": 50841,
1481
+ "<|9.56|>": 50842,
1482
+ "<|9.58|>": 50843,
1483
+ "<|9.60|>": 50844,
1484
+ "<|9.62|>": 50845,
1485
+ "<|9.64|>": 50846,
1486
+ "<|9.66|>": 50847,
1487
+ "<|9.68|>": 50848,
1488
+ "<|9.70|>": 50849,
1489
+ "<|9.72|>": 50850,
1490
+ "<|9.74|>": 50851,
1491
+ "<|9.76|>": 50852,
1492
+ "<|9.78|>": 50853,
1493
+ "<|9.80|>": 50854,
1494
+ "<|9.82|>": 50855,
1495
+ "<|9.84|>": 50856,
1496
+ "<|9.86|>": 50857,
1497
+ "<|9.88|>": 50858,
1498
+ "<|9.90|>": 50859,
1499
+ "<|9.92|>": 50860,
1500
+ "<|9.94|>": 50861,
1501
+ "<|9.96|>": 50862,
1502
+ "<|9.98|>": 50863,
1503
+ "<|af|>": 50327,
1504
+ "<|am|>": 50334,
1505
+ "<|ar|>": 50272,
1506
+ "<|as|>": 50350,
1507
+ "<|az|>": 50304,
1508
+ "<|ba|>": 50355,
1509
+ "<|be|>": 50330,
1510
+ "<|bg|>": 50292,
1511
+ "<|bn|>": 50302,
1512
+ "<|bo|>": 50347,
1513
+ "<|br|>": 50309,
1514
+ "<|bs|>": 50315,
1515
+ "<|ca|>": 50270,
1516
+ "<|cs|>": 50283,
1517
+ "<|cy|>": 50297,
1518
+ "<|da|>": 50285,
1519
+ "<|de|>": 50261,
1520
+ "<|el|>": 50281,
1521
+ "<|en|>": 50259,
1522
+ "<|es|>": 50262,
1523
+ "<|et|>": 50307,
1524
+ "<|eu|>": 50310,
1525
+ "<|fa|>": 50300,
1526
+ "<|fi|>": 50277,
1527
+ "<|fo|>": 50338,
1528
+ "<|fr|>": 50265,
1529
+ "<|gl|>": 50319,
1530
+ "<|gu|>": 50333,
1531
+ "<|haw|>": 50352,
1532
+ "<|ha|>": 50354,
1533
+ "<|he|>": 50279,
1534
+ "<|hi|>": 50276,
1535
+ "<|hr|>": 50291,
1536
+ "<|ht|>": 50339,
1537
+ "<|hu|>": 50286,
1538
+ "<|hy|>": 50312,
1539
+ "<|id|>": 50275,
1540
+ "<|is|>": 50311,
1541
+ "<|it|>": 50274,
1542
+ "<|ja|>": 50266,
1543
+ "<|jw|>": 50356,
1544
+ "<|ka|>": 50329,
1545
+ "<|kk|>": 50316,
1546
+ "<|km|>": 50323,
1547
+ "<|kn|>": 50306,
1548
+ "<|ko|>": 50264,
1549
+ "<|la|>": 50294,
1550
+ "<|lb|>": 50345,
1551
+ "<|ln|>": 50353,
1552
+ "<|lo|>": 50336,
1553
+ "<|lt|>": 50293,
1554
+ "<|lv|>": 50301,
1555
+ "<|mg|>": 50349,
1556
+ "<|mi|>": 50295,
1557
+ "<|mk|>": 50308,
1558
+ "<|ml|>": 50296,
1559
+ "<|mn|>": 50314,
1560
+ "<|mr|>": 50320,
1561
+ "<|ms|>": 50282,
1562
+ "<|mt|>": 50343,
1563
+ "<|my|>": 50346,
1564
+ "<|ne|>": 50313,
1565
+ "<|nl|>": 50271,
1566
+ "<|nn|>": 50342,
1567
+ "<|nocaptions|>": 50362,
1568
+ "<|notimestamps|>": 50363,
1569
+ "<|no|>": 50288,
1570
+ "<|oc|>": 50328,
1571
+ "<|pa|>": 50321,
1572
+ "<|pl|>": 50269,
1573
+ "<|ps|>": 50340,
1574
+ "<|pt|>": 50267,
1575
+ "<|ro|>": 50284,
1576
+ "<|ru|>": 50263,
1577
+ "<|sa|>": 50344,
1578
+ "<|sd|>": 50332,
1579
+ "<|si|>": 50322,
1580
+ "<|sk|>": 50298,
1581
+ "<|sl|>": 50305,
1582
+ "<|sn|>": 50324,
1583
+ "<|so|>": 50326,
1584
+ "<|sq|>": 50317,
1585
+ "<|sr|>": 50303,
1586
+ "<|startoflm|>": 50360,
1587
+ "<|startofprev|>": 50361,
1588
+ "<|startoftranscript|>": 50258,
1589
+ "<|su|>": 50357,
1590
+ "<|sv|>": 50273,
1591
+ "<|sw|>": 50318,
1592
+ "<|ta|>": 50287,
1593
+ "<|te|>": 50299,
1594
+ "<|tg|>": 50331,
1595
+ "<|th|>": 50289,
1596
+ "<|tk|>": 50341,
1597
+ "<|tl|>": 50348,
1598
+ "<|transcribe|>": 50359,
1599
+ "<|translate|>": 50358,
1600
+ "<|tr|>": 50268,
1601
+ "<|tt|>": 50351,
1602
+ "<|uk|>": 50280,
1603
+ "<|ur|>": 50290,
1604
+ "<|uz|>": 50337,
1605
+ "<|vi|>": 50278,
1606
+ "<|yi|>": 50335,
1607
+ "<|yo|>": 50325,
1608
+ "<|zh|>": 50260
1609
+ }
computer-vision-study-group/Notebooks/HuggingFace_vision_ecosystem_overview_(June_2022).ipynb ADDED
The diff for this file is too large to render. See raw diff
 
computer-vision-study-group/README.md ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Computer Vision Study Group
2
+
3
+ This is a collection of all past sessions that have been held as part of the Hugging Face Computer Vision Study Group.
4
+
5
+ | |Session Name | Session Link |
6
+ |--- |--- | --- |
7
+ |❓|How Do Vision Transformers Work? | [Session Sheet](Sessions/HowDoVisionTransformersWork.md) |
8
+ |🔅|Polarized Self-Attention | [Session Sheet](Sessions/PolarizedSelfAttention.md)|
9
+ |🍄|Swin Transformer | [Session Sheet](Sessions/SwinTransformer.md)|
10
+ |🔮|Introduction to Neural Radiance Fields | [Session Sheet](Sessions/NeuralRadianceFields.md)|
11
+ |🌐|Hugging Face Vision Ecosystem Overview (June 2022) | [Session Sheet](Sessions/HFVisionEcosystem.md)|
12
+ |🪂|Masked Autoencoders Are Scalable Vision Learners | [Session Sheet](Sessions/MaskedAutoEncoders.md)|
13
+ |🦊|Fiber: Coarse-to-Fine Vision-Language Pre-Training | [Session Sheet](Sessions/Fiber.md)|
14
+ |⚔️ |FlexiViT: One Model for All Patch Sizes| [Session Sheet](Sessions/FlexiViT.md)|
15
+ |🤖|BLIP-2: Bootstrapping Language-Image Pre-training| [Session Sheet](Sessions/Blip2.md)|
computer-vision-study-group/Sessions/Blip2.md ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # BLIP-2: Bootstrapping Language-Image Pre-training with Frozen Image Encoders and Large Language Models
2
+ Session by [johko](https://github.com/johko)
3
+
4
+
5
+ ## Recording 📺
6
+ [YouTube](https://www.youtube.com/watch?v=k0DAtZCCl1w&pp=ygUdaHVnZ2luZyBmYWNlIHN0dWR5IGdyb3VwIHN3aW4%3D)
7
+
8
+
9
+ ## Session Slides 🖥️
10
+ [Google Drive](https://docs.google.com/presentation/d/1Y_8Qu0CMlt7jvCd8Jw0c_ILh8LHB0XgnlrvXObe5FYs/edit?usp=sharing)
11
+
12
+
13
+ ## Original Paper 📄
14
+ [Hugging Face](https://huggingface.co/papers/2301.12597) /
15
+ [arxiv](https://arxiv.org/abs/2301.12597)
16
+
17
+
18
+ ## GitHub Repo 🧑🏽‍💻
19
+ https://github.com/salesforce/lavis
20
+
21
+
22
+ ## Additional Resources 📚
23
+ - [BLIP-2 Demo Space](https://huggingface.co/spaces/hysts/BLIP2-with-transformers)
24
+ - [BLIP-2 Transformers Example Notebooks](https://github.com/NielsRogge/Transformers-Tutorials/tree/master/BLIP-2) by Niels Rogge
25
+ - [BLIP-2 Transformers Docs](https://huggingface.co/docs/transformers/model_doc/blip-2)
computer-vision-study-group/Sessions/Fiber.md ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Fiber: Coarse-to-Fine Vision-Language Pre-Training with Fusion in the Backbone
2
+ Session by [johko](https://github.com/johko)
3
+
4
+
5
+ ## Recording 📺
6
+ [YouTube](https://www.youtube.com/watch?v=m9qhNGuWE2g&t=20s&pp=ygUdaHVnZ2luZyBmYWNlIHN0dWR5IGdyb3VwIHN3aW4%3D)
7
+
8
+
9
+ ## Session Slides 🖥️
10
+ [Google Drive](https://docs.google.com/presentation/d/1vSu27tE87ZM103_CkgqsW7JeIp2mrmyl/edit?usp=sharing&ouid=107717747412022342990&rtpof=true&sd=true)
11
+
12
+
13
+ ## Original Paper 📄
14
+ [Hugging Face](https://huggingface.co/papers/2206.07643) /
15
+ [arxiv](https://arxiv.org/abs/2206.07643)
16
+
17
+
18
+ ## GitHub Repo 🧑🏽‍💻
19
+ https://github.com/microsoft/fiber
20
+
21
+
22
+ ## Additional Resources 📚
23
+ - [Text to Pokemon](https://huggingface.co/spaces/lambdalabs/text-to-pokemon) HF Space to create your own Pokemon
24
+ - [Paper to Pokemon](https://huggingface.co/spaces/hugging-fellows/paper-to-pokemon) derived from the above space - create your own Pokemon from a paper
computer-vision-study-group/Sessions/FlexiViT.md ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # FlexiViT: One Model for All Patch Sizes
2
+ Session by [johko](https://github.com/johko)
3
+
4
+
5
+ ## Recording 📺
6
+ [YouTube](https://www.youtube.com/watch?v=TlRYBgsl7Q8&t=977s&pp=ygUdaHVnZ2luZyBmYWNlIHN0dWR5IGdyb3VwIHN3aW4%3D)
7
+
8
+
9
+ ## Session Slides 🖥️
10
+ [Google Drive](https://docs.google.com/presentation/d/1rLAYr160COYQMUN0FDH7D9pP8qe1_QyXGvfbHkutOt8/edit?usp=sharing)
11
+
12
+
13
+ ## Original Paper 📄
14
+ [Hugging Face](https://huggingface.co/papers/2212.08013) /
15
+ [arxiv](https://arxiv.org/abs/2212.08013)
16
+
17
+
18
+ ## GitHub Repo 🧑🏽‍💻
19
+ https://github.com/google-research/big_vision
20
+
21
+
22
+ ## Additional Resources 📚
23
+ - [FlexiViT PR](https://github.com/google-research/big_vision/pull/24)
computer-vision-study-group/Sessions/HFVisionEcosystem.md ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ # Hugging Face Vision Ecosystem Overview (June 2022)
2
+ Session by [Niels Rogge](https://github.com/NielsRogge)
3
+
4
+
5
+ ## Recording 📺
6
+ [YouTube](https://www.youtube.com/watch?v=oL-xmufhZM8&pp=ygUdaHVnZ2luZyBmYWNlIHN0dWR5IGdyb3VwIHN3aW4%3D)
7
+
8
+
9
+ ## Additional Resources 📚
10
+ - [Accompanying Notebook](../Notebooks/HuggingFace_vision_ecosystem_overview_(June_2022).ipynb)
computer-vision-study-group/Sessions/HowDoVisionTransformersWork.md ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # How Do Vision Transformers Work
2
+ Session by [johko](https://github.com/johko)
3
+
4
+
5
+ ## Session Slides 🖥️
6
+ [Google Drive](https://docs.google.com/presentation/d/1PewOHVABkxx0jO9PoJSQi8to_WNlL4HdDp4M9e4L8hs/edit?usp=drivesdks)
7
+
8
+
9
+ ## Original Paper 📄
10
+ [Hugging Face](https://huggingface.co/papers/2202.06709) /
11
+ [arxiv](https://arxiv.org/pdf/2202.06709.pdf)
12
+
13
+
14
+ ## GitHub Repo 🧑🏽‍💻
15
+ https://github.com/microsoft/Swin-Transformer
16
+
17
+
18
+ ## Additional Resources 📚
19
+ Hessian Matrices:
20
+
21
+ - https://stackoverflow.com/questions/23297090/how-calculating-hessian-works-for-neural-network-learning
22
+ - https://machinelearningmastery.com/a-gentle-introduction-to-hessian-matrices/
23
+
24
+ Loss Landscape Visualization:
25
+
26
+ - https://mathformachines.com/posts/visualizing-the-loss-landscape/
27
+ - https://github.com/tomgoldstein/loss-landscape
computer-vision-study-group/Sessions/MaskedAutoEncoders.md ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Masked Autoencoders are Scalable Vision Learners
2
+ Session by [johko](https://github.com/johko)
3
+
4
+
5
+ ## Recording 📺
6
+ [YouTube](https://www.youtube.com/watch?v=AC6flxUFLrg&pp=ygUdaHVnZ2luZyBmYWNlIHN0dWR5IGdyb3VwIHN3aW4%3D)
7
+
8
+
9
+ ## Session Slides 🖥️
10
+ [Google Drive](https://docs.google.com/presentation/d/10ZZ-Rl1D57VX005a58OmqNeOB6gPnE54/edit?usp=sharing&ouid=107717747412022342990&rtpof=true&sd=true)
11
+
12
+
13
+ ## Original Paper 📄
14
+ [Hugging Face](https://huggingface.co/papers/2111.06377) /
15
+ [arxiv](https://arxiv.org/abs/2111.06377)
16
+
17
+
18
+ ## GitHub Repo 🧑🏽‍💻
19
+ https://github.com/facebookresearch/mae
20
+
21
+
22
+ ## Additional Resources 📚
23
+ - [Transformers Docs ViTMAE](https://huggingface.co/docs/transformers/model_doc/vit_mae)
24
+ - [Transformers ViTMAE Demo Notebook](https://github.com/NielsRogge/Transformers-Tutorials/tree/master/ViTMAE) by Niels Rogge
computer-vision-study-group/Sessions/NeuralRadianceFields.md ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Introduction to Neural Radiance Fields
2
+ Session by [Aritra](https://arig23498.github.io/) and [Ritwik](ritwikraha.github.io)
3
+
4
+
5
+ ## Recording 📺
6
+ [YouTube](https://www.youtube.com/watch?v=U2XS7SxOy2s)
7
+
8
+
9
+ ## Session Slides 🖥️
10
+ [Google Drive](https://docs.google.com/presentation/d/e/2PACX-1vTQVnoTJGhRxDscNV1Mg2aYhvXP8cKODpB5Ii72NWoetCGrTLBJWx_UD1oPXHrzPtj7xO8MS_3TQaSH/pub?start=false&loop=false&delayms=3000)
11
+
12
+
13
+ ## Original Paper 📄
14
+ [Hugging Face](https://huggingface.co/papers/2003.08934) /
15
+ [arxiv](https://arxiv.org/abs/2003.08934)
16
+
17
+
18
+ ## GitHub Repo 🧑🏽‍💻
19
+ https://github.com/bmild/nerf
computer-vision-study-group/Sessions/PolarizedSelfAttention.md ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Polarized Self-Attention
2
+ Session by [Satpal](https://github.com/satpalsr)
3
+
4
+ ## Session Slides 🖥️
5
+ [GitHub PDF](https://github.com/satpalsr/Talks/blob/main/PSA_discussion.pdf)
6
+
7
+
8
+ ## Original Paper 📄
9
+ [Hugging Face](https://huggingface.co/papers/2107.00782) /
10
+ [arxiv](https://arxiv.org/pdf/2107.00782.pdf)
11
+
12
+
13
+ ## GitHub Repo 🧑🏽‍💻
14
+ https://github.com/DeLightCMU/PSA
computer-vision-study-group/Sessions/SwinTransformer.md ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Swin Transformer
2
+ Session by [johko](https://github.com/johko)
3
+
4
+
5
+ ## Recording 📺
6
+ [YouTube](https://www.youtube.com/watch?v=Ngikt-K1Ecc&t=305s&pp=ygUdaHVnZ2luZyBmYWNlIHN0dWR5IGdyb3VwIHN3aW4%3D)
7
+
8
+
9
+ ## Session Slides 🖥️
10
+ [Google Drive](https://docs.google.com/presentation/d/1RoFIC6vE55RS4WNqSlzNu3ljB6F-_8edtprAFXpGvKs/edit?usp=sharing)
11
+
12
+
13
+ ## Original Paper 📄
14
+ [Hugging Face](https://huggingface.co/papers/2103.14030) /
15
+ [arxiv](https://arxiv.org/pdf/2103.14030.pdf)
16
+
17
+
18
+ ## GitHub Repo 🧑🏽‍💻
19
+ https://github.com/xxxnell/how-do-vits-work
20
+
21
+
22
+ ## Additional Resources 📚
23
+ - [Transformers Docs Swin v1](https://huggingface.co/docs/transformers/model_doc/swin)
24
+ - [Transformers Docs Swin v2](https://huggingface.co/docs/transformers/model_doc/swinv2)
25
+ - [Transformers Docs Swin Super Resolution](https://huggingface.co/docs/transformers/model_doc/swin2sr)
config.json ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_name_or_path": "openai/whisper-small",
3
+ "activation_dropout": 0.0,
4
+ "activation_function": "gelu",
5
+ "apply_spec_augment": false,
6
+ "architectures": [
7
+ "WhisperForConditionalGeneration"
8
+ ],
9
+ "attention_dropout": 0.0,
10
+ "begin_suppress_tokens": [
11
+ 220,
12
+ 50257
13
+ ],
14
+ "bos_token_id": 50257,
15
+ "classifier_proj_size": 256,
16
+ "d_model": 768,
17
+ "decoder_attention_heads": 12,
18
+ "decoder_ffn_dim": 3072,
19
+ "decoder_layerdrop": 0.0,
20
+ "decoder_layers": 12,
21
+ "decoder_start_token_id": 50258,
22
+ "dropout": 0.0,
23
+ "encoder_attention_heads": 12,
24
+ "encoder_ffn_dim": 3072,
25
+ "encoder_layerdrop": 0.0,
26
+ "encoder_layers": 12,
27
+ "eos_token_id": 50257,
28
+ "forced_decoder_ids": null,
29
+ "init_std": 0.02,
30
+ "is_encoder_decoder": true,
31
+ "mask_feature_length": 10,
32
+ "mask_feature_min_masks": 0,
33
+ "mask_feature_prob": 0.0,
34
+ "mask_time_length": 10,
35
+ "mask_time_min_masks": 2,
36
+ "mask_time_prob": 0.05,
37
+ "max_length": 448,
38
+ "max_source_positions": 1500,
39
+ "max_target_positions": 448,
40
+ "median_filter_width": 7,
41
+ "model_type": "whisper",
42
+ "num_hidden_layers": 12,
43
+ "num_mel_bins": 80,
44
+ "pad_token_id": 50257,
45
+ "scale_embedding": false,
46
+ "suppress_tokens": [],
47
+ "torch_dtype": "float32",
48
+ "transformers_version": "4.40.0.dev0",
49
+ "use_cache": false,
50
+ "use_weighted_layer_sum": false,
51
+ "vocab_size": 51865
52
+ }
gradio-blocks/README.md ADDED
@@ -0,0 +1,123 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Welcome to the [Gradio](https://gradio.app/) Blocks Party 🥳
2
+
3
+ ![image (1)](https://user-images.githubusercontent.com/81195143/167954125-9854bf6b-4ae5-4735-8fdd-830fec41efa1.png)
4
+
5
+
6
+ _**Timeline**: May 17th, 2022 - May 31st, 2022_
7
+
8
+ ---
9
+
10
+ We are happy to invite you to the Gradio Blocks Party - a community event in which we will create **interactive demos** for state-of-the-art machine learning models. Demos are powerful because they allow anyone — not just ML engineers — to try out models in the browser, give feedback on predictions, identify trustworthy models. The event will take place from **May 17th to 31st**. We will be organizing this event on [Github](https://github.com/huggingface/community-events) and the [Hugging Face discord channel](https://discord.com/invite/feTf9x3ZSB). Prizes will be given at the end of the event, see: [Prizes](#prizes)
11
+
12
+ <img src="https://user-images.githubusercontent.com/81195143/168656398-ace7acc9-ef7a-4e90-a9cd-c7d15dd800e1.gif" width="1160" height="600"/>
13
+
14
+ ## What is Gradio?
15
+
16
+ Gradio is a Python library that allows you to quickly build web-based machine learning demos, data science dashboards, or other kinds of web apps, entirely in Python. These web apps can be launched from wherever you use Python (jupyter notebooks, colab notebooks, Python terminal, etc.) and shared with anyone instantly using Gradio's auto-generated share links. To learn more about Gradio see the Getting Started Guide: https://gradio.app/getting_started/ and the new Course on Huggingface about Gradio: [Gradio Course](https://huggingface.co/course/chapter9/1?fw=pt).
17
+
18
+ Gradio can be installed via pip and comes preinstalled in Hugging Face Spaces, the latest version of Gradio can be set in the README in spaces by setting the sdk_version for example `sdk_version: 3.0b8`
19
+
20
+ `pip install gradio` to install gradio locally
21
+
22
+
23
+ ## What is Blocks?
24
+
25
+ `gradio.Blocks` is a low-level API that allows you to have full control over the data flows and layout of your application. You can build very complex, multi-step applications using Blocks. If you have already used `gradio.Interface`, you know that you can easily create fully-fledged machine learning demos with just a few lines of code. The Interface API is very convenient but in some cases may not be sufficiently flexible for your needs. For example, you might want to:
26
+
27
+ * Group together related demos as multiple tabs in one web app.
28
+ * Change the layout of your demo instead of just having all of the inputs on the left and outputs on the right.
29
+ * Have multi-step interfaces, in which the output of one model becomes the input to the next model, or have more flexible data flows in general.
30
+ * Change a component's properties (for example, the choices in a Dropdown) or its visibility based on user input.
31
+
32
+ To learn more about Blocks, see the [official guide](https://www.gradio.app/introduction_to_blocks/) and the [docs](https://gradio.app/docs/).
33
+
34
+ ## What is Hugging Face Spaces?
35
+
36
+ Spaces are a simple way to host ML demo apps directly on your profile or your organization’s profile on Hugging Face. This allows you to create your ML portfolio, showcase your projects at conferences or to stakeholders, and work collaboratively with other people in the ML ecosystem. Learn more about Spaces in the [docs](https://huggingface.co/docs/hub/spaces).
37
+
38
+ ## How Do Gradio and Hugging Face work together?
39
+
40
+ Hugging Face Spaces is a free hosting option for Gradio demos. Spaces comes with 3 SDK options: Gradio, Streamlit and Static HTML demos. Spaces can be public or private and the workflow is similar to github repos. There are over 2000+ Gradio spaces currently on Hugging Face. Learn more about spaces and gradio: https://huggingface.co/docs/hub/spaces
41
+
42
+ ## Event Plan
43
+
44
+ main components of the event consist of:
45
+
46
+ 1. Learning about Gradio and the new Blocks Feature
47
+ 2. Building your own Blocks demo using Gradio and Hugging Face Spaces
48
+ 3. Submitting your demo on Spaces to the Gradio Blocks Party Organization
49
+ 4. Share your blocks demo with a permanent shareable link
50
+ 5. Win Prizes
51
+
52
+
53
+ ## Example spaces using Blocks
54
+
55
+ <img width="1180" alt="mindseye-lite" src="https://user-images.githubusercontent.com/81195143/168619604-cf1ac733-c10e-487f-add4-8da48002dcff.png">
56
+
57
+ - [dalle-mini](https://huggingface.co/spaces/dalle-mini/dalle-mini)([Code](https://huggingface.co/spaces/dalle-mini/dalle-mini/blob/main/app/gradio/app.py))
58
+ - [mindseye-lite](https://huggingface.co/spaces/multimodalart/mindseye-lite)([Code](https://huggingface.co/spaces/multimodalart/mindseye-lite/blob/main/app.py))
59
+ - [ArcaneGAN-blocks](https://huggingface.co/spaces/akhaliq/ArcaneGAN-blocks)([Code](https://huggingface.co/spaces/akhaliq/ArcaneGAN-blocks/blob/main/app.py))
60
+ - [gr-blocks](https://huggingface.co/spaces/merve/gr-blocks)([Code](https://huggingface.co/spaces/merve/gr-blocks/blob/main/app.py))
61
+ - [tortoisse-tts](https://huggingface.co/spaces/osanseviero/tortoisse-tts)([Code](https://huggingface.co/spaces/osanseviero/tortoisse-tts/blob/main/app.py))
62
+ - [CaptchaCracker](https://huggingface.co/spaces/osanseviero/tortoisse-tts)([Code](https://huggingface.co/spaces/akhaliq/CaptchaCracker/blob/main/app.py))
63
+
64
+
65
+ ## To participate in the event
66
+
67
+ - Join the organization for Blocks event
68
+ - [https://huggingface.co/Gradio-Blocks](https://huggingface.co/Gradio-Blocks)
69
+ - Join the discord
70
+ - [discord](https://discord.com/invite/feTf9x3ZSB)
71
+
72
+
73
+ Participants will be building and sharing Gradio demos using the Blocks feature. We will share a list of ideas of spaces that can be created using blocks or participants are free to try out their own ideas. At the end of the event, spaces will be evaluated and prizes will be given.
74
+
75
+
76
+ ## Potential ideas for creating spaces:
77
+
78
+
79
+ - Trending papers from https://paperswithcode.com/
80
+ - Models from huggingface model hub: https://huggingface.co/models
81
+ - Models from other model hubs
82
+ - Tensorflow Hub: see example Gradio demos at https://huggingface.co/tensorflow
83
+ - Pytorch Hub: see example Gradio demos at https://huggingface.co/pytorch
84
+ - ONNX model Hub: see example Gradio demos at https://huggingface.co/onnx
85
+ - PaddlePaddle Model Hub: see example Gradio demos at https://huggingface.co/PaddlePaddle
86
+ - participant ideas, try out your own ideas
87
+
88
+
89
+ ## Prizes
90
+ - 1st place winner based on likes
91
+ - [Hugging Face PRO subscription](https://huggingface.co/pricing) for 1 year
92
+ - Embedding your Gradio Blocks demo in the Gradio Blog
93
+ - top 10 winners based on likes
94
+ - Swag from [Hugging Face merch shop](https://huggingface.myshopify.com/): t-shirts, hoodies, mugs of your choice
95
+ - top 25 winners based on likes
96
+ - [Hugging Face PRO subscription](https://huggingface.co/pricing) for 1 month
97
+ - Blocks event badge on HF for all participants!
98
+
99
+ ## Prizes Criteria
100
+
101
+ - Staff Picks
102
+ - Most liked Spaces
103
+ - Community Pick (voting)
104
+ - Most Creative Space (voting)
105
+ - Most Educational Space (voting)
106
+ - CEO's pick (one prize for a particularly impactful demo), picked by @clem
107
+ - CTO's pick (one prize for a particularly technically impressive demo), picked by @julien
108
+
109
+
110
+ ## Creating a Gradio demo on Hugging Face Spaces
111
+
112
+ Once a model has been picked from the choices above or feel free to try your own idea, you can share a model in a Space using Gradio
113
+
114
+ Read more about how to add [Gradio spaces](https://huggingface.co/blog/gradio-spaces).
115
+
116
+ Steps to add Gradio Spaces to the Gradio Blocks Party org
117
+ 1. Create an account on Hugging Face
118
+ 2. Join the Gradio Blocks Party Organization by clicking "Join Organization" button in the organization page or using the shared link above
119
+ 3. Once your request is approved, add your space using the Gradio SDK and share the link with the community!
120
+
121
+ ## LeaderBoard for Most Popular Blocks Event Spaces based on Likes
122
+
123
+ - See Leaderboard: https://huggingface.co/spaces/Gradio-Blocks/Leaderboard
huggan/README.md ADDED
@@ -0,0 +1,487 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # HugGAN Sprint
2
+
3
+ ![Banner](assets/huggan_banner.png?raw=true "Banner")
4
+
5
+ _**Timeline**: April 4th, 2022 - April 17th, 2022_
6
+
7
+ ---
8
+
9
+ Welcome to HugGAN Sprint! The goal of this sprint is to add more GANs and GAN-based demos to the Hugging Face Hub 🤗.
10
+
11
+ During the sprint, we’ll be bringing in some awesome speakers to talk about GANs and the future of generative models. Oh, and if you need access to compute for your project, we’ll help you there too! As an added bonus, if you choose to participate, we’ll send you a gift (specific details TBD). We encourage you to form teams of ~2-3 people! Make friends in the Discord :)
12
+
13
+ To join:
14
+
15
+ 1. Fill out [this form](https://forms.gle/goq41UgzsvuKKTFFA), so we can keep track of who’s joining.
16
+ 2. Send a reaction in the [#join-sprint channel](https://discord.com/channels/879548962464493619/954070850645135462) under the HugGAN category in Discord. This will add you to the rest of the related channels. If you haven't joined our discord yet, [click here](https://discord.gg/H3bUrDPTfS).
17
+ 3. Once you’ve decided what you want to work on, add your project’s information to [this sheet](https://docs.google.com/spreadsheets/d/1aAHqOOk2SOw4j6mrJLkLT6ZyKyLDOvGF5D9tuUqnoG8/edit#gid=0), where you can describe your project and let us know if you need additional compute. Still brainstorming? Feel free to propose ideas in #sprint-discussions.
18
+
19
+ ## Table of Contents
20
+
21
+ - [Important dates](#important-dates)
22
+ - [How to install relevant libraries](#how-to-install-relevant-libraries)
23
+ - [General workflow](#general-workflow)
24
+ - [Datasets to add](#datasets-to-add)
25
+ - [Links to check out](#links-to-check-out)
26
+ - [GAN metrics](#gan-metrics)
27
+ - [Evaluation](#evaluation)
28
+ - [Prizes](#prizes)
29
+ - [Communication and Problems](#communication-and-problems)
30
+ - [Talks](#talks)
31
+ - [General Tips & Tricks](#general-tips-and-tricks)
32
+
33
+ ## Important dates
34
+
35
+ | Date | Description |
36
+ | ----------- | ----------- |
37
+ | April 4th | Sprint Kickoff 🚀 |
38
+ | April 15th | Submission Deadline 🛑 |
39
+ | April 22nd | Prizes Announced for Participants 🎁 |
40
+
41
+ ## How to install relevant libraries
42
+
43
+ You'll need the following dependencies installed to use this repo:
44
+
45
+ - [PyTorch](https://pytorch.org/) or [Keras](https://keras.io/) - depending on which framework you prefer ;)
46
+ - [🤗 Datasets](https://huggingface.co/docs/datasets/index)
47
+ - [🤗 Accelerate](https://huggingface.co/docs/accelerate/index) - in case you're planning to train a PyTorch model and you want it to be run effortlessly
48
+
49
+ We recommend installing the above libraries in a [virtual environment](https://docs.python.org/3/library/venv.html).
50
+ If you're unfamiliar with Python virtual environments, check out the [user guide](https://packaging.python.org/guides/installing-using-pip-and-virtual-environments/). Create a virtual environment with the version of Python you're going to use and activate it.
51
+
52
+ You should be able to run the command:
53
+
54
+ ```bash
55
+ python3 -m venv <your-venv-name>
56
+ ```
57
+
58
+ You can activate your venv by running
59
+
60
+ ```bash
61
+ source ~/<your-venv-name>/bin/activate
62
+ ```
63
+
64
+ ### Install Dependencies
65
+
66
+ We've packaged up the example scripts here into a simple Python package. To install it, just pip install it
67
+
68
+ ```
69
+ git clone https://github.com/huggingface/community-events.git
70
+ cd community-events
71
+ pip install .
72
+ ```
73
+
74
+ If you use `pip install -e .` instead of `pip install`, it will install the package in development mode, which can be useful if you are planning on contributing any changes here 🤗.
75
+
76
+ ## General workflow
77
+
78
+ The process to follow is outlined below. It consists of 3 steps:
79
+
80
+ 1. Get a dataset and push to the Hub
81
+ 2. Train a model and push to the Hub
82
+ 3. Create a demo (🤗 Space)
83
+
84
+ These steps are explained in more detail below.
85
+
86
+ ### 1. Get a dataset and push to Hub
87
+
88
+ The first step is the most obvious one: to train a GAN (or any neural network), we need a dataset. This could be either a dataset that is already available on the [Hub](https://huggingface.co/datasets), or one that isn't already. Below we'll explain how to load the data in both cases.
89
+
90
+ Note that we maintain a list of interesting datasets to add to the Hub [here](#datasets-to-add).
91
+
92
+ #### 1.1 Use a dataset already available on the Hub
93
+
94
+ Most famous computer vision datasets are already available on the [Hub](https://huggingface.co/datasets?task_categories=task_categories:image-classification) (such as [MNIST](https://huggingface.co/datasets/mnist), [Fashion MNIST](https://huggingface.co/datasets/fashion_mnist), [CIFAR-10](https://huggingface.co/datasets/cifar10), [CIFAR-100](https://huggingface.co/datasets/cifar100), etc.).
95
+
96
+ Loading a dataset can be done as follows:
97
+
98
+ ```python
99
+ from datasets import load_dataset
100
+
101
+ # a general one ...
102
+ dataset = load_dataset("mnist")
103
+
104
+ # ... or one that's part of the huggan organization
105
+ dataset = load_dataset("huggan/edges2shoes")
106
+ ```
107
+
108
+ In a notebook, you can **directly see** the images by selecting a split and then the appropriate column:
109
+
110
+ ```python
111
+ example = dataset['train'][0]
112
+ print(example['image'])
113
+ ```
114
+
115
+ #### 1.2 Upload a new dataset to the Hub
116
+
117
+ In case your dataset is not already on the Hub, you can upload it to the `huggan` [organization](https://huggingface.co/huggan). If you've signed up for the event by filling in the [spreadsheet]((https://docs.google.com/spreadsheets/d/1aAHqOOk2SOw4j6mrJLkLT6ZyKyLDOvGF5D9tuUqnoG8/edit#gid=0)), your Hugging Face account should be part of it.
118
+
119
+ Let's illustrate with an example how this was done for NVIDIA's [MetFaces dataset](https://github.com/NVlabs/metfaces-dataset):
120
+
121
+ <p align="center">
122
+ <img src="https://github.com/NVlabs/metfaces-dataset/blob/master/img/metfaces-teaser.png" alt="drawing" width="700"/>
123
+ </p>
124
+
125
+ Previously, this dataset was only hosted on [Google Drive](https://github.com/NVlabs/metfaces-dataset#overview), and not really easily accessible.
126
+
127
+ To begin with, one should check that one is correctly logged in and that `git-lfs` is installed so that the dataset can be uploaded.
128
+
129
+ Run:
130
+
131
+ ```bash
132
+ huggingface-cli login
133
+ ```
134
+
135
+ in a terminal, or case you're working in a notebook
136
+
137
+ ```python
138
+ from huggingface_hub import notebook_login
139
+
140
+ notebook_login()
141
+ ```
142
+
143
+ It is recommended to login with your access token that can be found under your HuggingFace profile (icon in the top right corner on [hf.co](http://hf.co/), then Settings -> Access Tokens -> User Access Tokens -> New Token (if you haven't generated one already). Alternatively, you can go to [your token settings](https://huggingface.co/settings/tokens) directly.
144
+
145
+ You can then copy-paste this token to log in locally.
146
+
147
+ Next, let's make sure that `git-lfs` is correctly installed. To so, simply run:
148
+
149
+ ```bash
150
+ git-lfs -v
151
+ ```
152
+
153
+ The output should show something like `git-lfs/2.13.2 (GitHub; linux amd64; go 1.15.4)`. If your console states that the `git-lfs` command was not found, please make sure to install it [here](https://git-lfs.github.com/) or simply via:
154
+
155
+ ```bash
156
+ sudo apt-get install git-lfs
157
+ git config --global user.email "you@example.com"
158
+ git config --global user.name "Your Name"
159
+ ```
160
+
161
+ Next, one can leverage the [`ImageFolder`](https://huggingface.co/docs/datasets/v2.0.0/en/image_process#imagefolder) builder to very easily upload an image dataset to the hub. In case the dataset you're uploading has a direct download URL, you can simply provide it to the `data_files` argument as shown below. Otherwise, you'll need to go to the link of the dataset and manually download it first as a zip/tar (which was the case for MetFaces), and provide the file through the `data_files` argument. Alternatively, it may be that you have a folder with images, in which case you can provide it using the `data_dir` argument. Note that the latter assumes a [particular structure](https://huggingface.co/docs/datasets/v2.0.0/en/image_process#imagefolder).
162
+
163
+ ```python
164
+ from datasets import load_dataset
165
+
166
+ # option 1: local folder
167
+ dataset = load_dataset("imagefolder", data_dir="path_to_folder")
168
+ # option 2: local or remote file(s), supporting the following extensions: tar, gzip, zip, xz, rar, zstd
169
+ dataset = load_dataset("imagefolder", data_files="path_to_file_or_direct_download_link")
170
+
171
+ # note that you can also provide them as separate splits, like so:
172
+ dataset = load_dataset("imagefolder", data_files={"train": ["path/to/file1", "path/to/file2"], "test": ["path/to/file3", "path/to/file4"]})
173
+ ```
174
+
175
+ Once you've loaded your dataset, you can push it to the Hub with a single line of code:
176
+
177
+ ```python
178
+ dataset.push_to_hub("huggan/name-of-your-dataset")
179
+ ```
180
+
181
+ Et voila! Your dataset is now available on the Hub :) If you wait a bit, the Dataset viewer should be able to preview images in the browser. The MetFaces dataset can be seen here: https://huggingface.co/datasets/huggan/metfaces.
182
+
183
+ <p align="center">
184
+ <img src="https://github.com/huggingface/community-events/blob/main/huggan/assets/metfaces.png" alt="drawing" width="700"/>
185
+ </p>
186
+
187
+ The cool thing is that anyone can now access this dataset from anywhere, using `load_dataset` 🎉🥳 this means that you can easily load the dataset on another computer for instance, or in a different environment. Amazing, isn't it?
188
+
189
+ ❗ Note: When uploading a dataset, make sure that it has appropriate column names. The `ImageFolder` utility automatically creates `image` and `label` columns, however if there's only one image class, it makes sense to remove the `label` column before pushing to the hub. This can be done as follows:
190
+
191
+ ```python
192
+ dataset = dataset.remove_columns("label")
193
+ ```
194
+
195
+ Note that you can always update a dataset by simply calling `push_to_hub` again (providing the same name).
196
+
197
+ #### 1.3 Processing the data
198
+
199
+ Once you've uploaded your dataset, you can load it and create a dataloader for it. The code example below shows how to apply some data augmentation and creating a PyTorch Dataloader (the [PyTorch example scripts](pytorch) all leverage this). More info can also be found in the [docs](https://huggingface.co/docs/datasets/v2.0.0/en/image_process#process-image-data).
200
+
201
+ ```python
202
+ from datasets import load_dataset
203
+ from torchvision.transforms import CenterCrop, Compose, Normalize, Resize, ToTensor
204
+ from torch.utils.data import DataLoader
205
+
206
+ # load your data
207
+ dataset = load_dataset("dataset_name")
208
+
209
+ image_size = 256
210
+
211
+ # define image transformations (e.g. using torchvision)
212
+ transform = Compose(
213
+ [
214
+ Resize(image_size),
215
+ CenterCrop(image_size),
216
+ ToTensor(),
217
+ Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
218
+ ]
219
+ )
220
+
221
+ # define function
222
+ def transforms(examples):
223
+ examples["image"] = [transform(image.convert("RGB")) for image in examples["image"]]
224
+
225
+ return examples
226
+
227
+ transformed_dataset = dataset.with_transform(transforms)
228
+
229
+ # create dataloader
230
+ dataloader = DataLoader(
231
+ transformed_dataset["train"], batch_size="your batch size", shuffle=True, num_workers="your number of CPU cores"
232
+ )
233
+ ```
234
+
235
+ As can be seen, we leverage the [`with_transform`](https://huggingface.co/docs/datasets/v2.0.0/en/package_reference/main_classes#datasets.Dataset.with_transform) method here, which will make sure the image transformations will only be performed when iterating over the data (i.e. data augmentation is performed on-the-fly, making it very RAM-friendly) rather than performing it on the entire dataset in one go (which would be the case if you use [`map`](https://huggingface.co/docs/datasets/v2.0.0/en/package_reference/main_classes#datasets.Dataset.map)). The `with_transform` method does the same thing as [`set_transform`](https://huggingface.co/docs/datasets/v2.0.0/en/package_reference/main_classes#datasets.Dataset.set_transform), except that it does return a new `Dataset` rather than performing the operation in-place.
236
+
237
+ ### 2. Train a model and push to Hub
238
+
239
+ Next, one can start training a model. This could be any model you'd like. However, we provide some example scripts to help you get started, in both [PyTorch](pytorch) and [Tensorflow](tensorflow). An example is the [DCGAN](pytorch/dcgan) model for unconditional image generation. Simply follow the README that explains all the details of the relevant implementation, and run it in your environment.
240
+
241
+ The PyTorch example scripts all leverage 🤗 [Accelerate](https://huggingface.co/docs/accelerate/index), which provides an easy API to make your scripts run on any kind of distributed setting (multi-GPUs, TPUs etc.) and with mixed precision, while still letting you write your own training loop.
242
+
243
+ Alternatively, we also provide a [Links to Check Out](#links-to-check-out) section to give you some inspiration.
244
+
245
+ Below, we explain in more detail how to upload your model to the Hub, depending on the framework you're using (sections [2.1](#21-pytorch) and [2.2](#22-keras)). In section [2.3](#33-alternative-ways-to-upload-a-model-to-the-hub), we'll explain how to write a nice model card. In section [2.4](24-model-cards), we'll illustrate alternative ways to upload (and re-use) a model to (and from) the hub. Finally, in section [2.5](25-accelerate), we explain 🤗 [Accelerate](https://huggingface.co/docs/accelerate/index), the awesome library that makes training PyTorch models on any kind of environment a breeze. Be sure to check it out!
246
+
247
+ #### 2.1 PyTorch
248
+
249
+ If you're planning to train a custom PyTorch model, it's recommended to make it inherit from `PyTorchModelHubMixin`. This makes sure you can push it to the Hub at the end of training, and reload it afterwards using `from_pretrained`, as shown in the code example below:
250
+
251
+ ```python
252
+ from huggingface_hub import PyTorchModelHubMixin
253
+
254
+ class MyGenerator(nn.Module, PyTorchModelHubMixin):
255
+ def __init__(self, **kwargs):
256
+ super().__init__()
257
+ self.config = kwargs.pop("config", None)
258
+ self.layer = ...
259
+ def forward(self, ...):
260
+ return ...
261
+
262
+ # Create model
263
+ model = MyGenerator()
264
+
265
+ # Push to HuggingFace Hub
266
+ model.push_to_hub("huggan/name-of-your-model").
267
+
268
+ # Reload from HuggingFace Hub
269
+ reloaded = MyGenerator.from_pretrained("huggan/name-of-your-model").
270
+ ```
271
+
272
+ This `PyTorchModelHubMixin` class is available in the [`huggingface_hub` library](https://github.com/huggingface/huggingface_hub), which comes pre-installed if you install `datasets` (or `transformers`) in your environment.
273
+
274
+ #### 2.2 Keras
275
+
276
+ In Keras, one can leverage the `push_to_hub_keras` and `from_pretrained_keras` methods:
277
+
278
+ ```python
279
+ import tensorflow as tf
280
+ from huggingface_hub import push_to_hub_keras, from_pretrained_keras
281
+
282
+ # Build a Keras model
283
+ inputs = tf.keras.layers.Input(shape=(2,))
284
+ x = tf.keras.layers.Dense(2, activation="relu")(inputs)
285
+ model = tf.keras.models.Model(inputs=inputs, outputs=x)
286
+ model.compile(optimizer="adam", loss="mse")
287
+
288
+ # Push to HuggingFace Hub
289
+ push_to_hub_keras(model, "huggan/my-cool-model")
290
+
291
+ # Reload from HuggingFace Hub
292
+ reloaded = from_pretrained_keras("huggan/my-cool-model")
293
+ ```
294
+
295
+ These methods are available in the [`huggingface_hub` library](https://github.com/huggingface/huggingface_hub), which comes pre-installed if you install `datasets` (or `transformers`) in your environment. Note that the `push_to_hub_keras` method supports pushing several models (such as a generator and discriminator) to the same repo, as illustrated [here](https://github.com/huggingface/huggingface_hub/issues/533#issuecomment-1058093158).
296
+
297
+ #### 2.3 Alternative ways to upload a model to the Hub
298
+
299
+ Besides the methods explained in sections 2.1 and 2.2 above, you can also share model assets directly from git, which is explained in depth in [this guide](https://huggingface.co/docs/hub/adding-a-model#uploading-your-files).
300
+
301
+ #### 2.4 Model cards
302
+
303
+ When uploading a model to the Hub, it's important to include a so-called [model card](https://huggingface.co/course/chapter4/4?fw=pt) with it. This is just a README (in Markdown) 🃏 that includes:
304
+ - license,
305
+ - task,
306
+ - `huggan` and `gan` tags,
307
+ - dataset metadata,
308
+ - information related to the model,
309
+ - information on dataset, intended uses,
310
+ - a model output.
311
+
312
+ If you trained one of the example models, this model card will be automatically generated for you. If you didn’t train the model yourself, be sure to both credit the original authors and include the associated license in your model card! Here is an [example model repo](https://huggingface.co/merve/anime-faces-generator).
313
+
314
+ You can also use this [template model card](model_card_template.md)
315
+ as a guide to build your own.
316
+
317
+ ![Alt text](assets/example_model.png?raw=true "Title")
318
+
319
+ #### 2.5 Accelerate
320
+
321
+ HuggingFace `accelerate` is an awesome library for training PyTorch models. Here we show why.
322
+
323
+ Basically, the library requires to replace this:
324
+
325
+ ```
326
+ my_model.to(device)
327
+
328
+ for batch in my_training_dataloader:
329
+ my_optimizer.zero_grad()
330
+ inputs, targets = batch
331
+ inputs = inputs.to(device)
332
+ targets = targets.to(device)
333
+ outputs = my_model(inputs)
334
+ loss = my_loss_function(outputs, targets)
335
+ loss.backward()
336
+ my_optimizer.step()
337
+ ```
338
+
339
+ by this:
340
+
341
+ ```diff
342
+ + from accelerate import Accelerator
343
+
344
+ + accelerator = Accelerator()
345
+ - my_model.to(device)
346
+ # Pass every important object (model, optimizer, dataloader) to *accelerator.prepare*
347
+ + my_model, my_optimizer, my_training_dataloader = accelerate.prepare(
348
+ + my_model, my_optimizer, my_training_dataloader
349
+ + )
350
+
351
+ for batch in my_training_dataloader:
352
+ my_optimizer.zero_grad()
353
+ inputs, targets = batch
354
+ - inputs = inputs.to(device)
355
+ - targets = targets.to(device)
356
+ outputs = my_model(inputs)
357
+ loss = my_loss_function(outputs, targets)
358
+ # Just a small change for the backward instruction
359
+ - loss.backward()
360
+ + accelerator.backward(loss)
361
+ my_optimizer.step()
362
+ ```
363
+
364
+ and BOOM, your script runs on **any kind of hardware**, including CPU, multi-CPU, GPU, multi-GPU and TPU. It also supports things like [DeepSpeed](https://github.com/microsoft/DeepSpeed) and [mixed precision](https://arxiv.org/abs/1710.03740) for training efficiently.
365
+
366
+ You can now run your script as follows:
367
+
368
+ ```bash
369
+ accelerate config
370
+ ```
371
+
372
+ => Accelerate will ask what kind of environment you'd like to run your script on, simply answer the questions being asked. Next:
373
+
374
+ ```bash
375
+ accelerate launch <your script.py>
376
+ ```
377
+
378
+ This will run your script on the environment you asked for. You can always check the environment settings by typing:
379
+
380
+ ```bash
381
+ accelerate env
382
+ ```
383
+
384
+ You can of course change the environment by running `accelerate config` again.
385
+
386
+ ### 3. Create a demo
387
+
388
+ Once you share a model, you then should share a [Space](https://huggingface.co/spaces) based on your SDK of choice (Gradio or Streamlit) or as a static page. 🌌
389
+
390
+ ![Alt text](assets/example_space.png?raw=true "Title")
391
+
392
+ Here is an [example Space](https://huggingface.co/spaces/merve/anime-face-generator) corresponding to the model example shared above. Don’t know how to create a space? Read more about how to add spaces [here](https://huggingface.co/docs/hub/spaces).
393
+
394
+ Below, we list some other great example GAN Spaces:
395
+ - AnimeGANv2: https://huggingface.co/spaces/akhaliq/AnimeGANv2
396
+ - ArcaneGAN: https://huggingface.co/spaces/akhaliq/ArcaneGAN
397
+ - This Pokemon does not exist: https://huggingface.co/spaces/ronvolutional/ai-pokemon-card
398
+ - GFP-GAN: https://huggingface.co/spaces/akhaliq/GFPGAN
399
+ - DualStyleGAN: https://huggingface.co/spaces/hysts/DualStyleGAN
400
+
401
+ ## Example Scripts
402
+
403
+ In this repo, we have provided some example scripts you can use to train your own GANs. Below is a table of the available scripts:
404
+
405
+ | Name | Paper |
406
+ | ----------- | ----------- |
407
+ | [DCGAN](pytorch/dcgan) | [Unsupervised Representation Learning with Deep Convolutional Generative Adversarial Networks](https://arxiv.org/abs/1511.06434) |
408
+ | [pix2pix](pytorch/pix2pix) | [Image-to-Image Translation with Conditional Adversarial Networks](https://arxiv.org/abs/1611.07004) |
409
+ | [CycleGAN](pytorch/cyclegan) | [Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks](https://arxiv.org/abs/1703.10593)
410
+
411
+ ## Datasets to add
412
+
413
+ Below, we list some datasets which could be added to the Hub (feel free to add on one of these, or open a PR to add more datasets!):
414
+
415
+ - DeepFashion: https://mmlab.ie.cuhk.edu.hk/projects/DeepFashion.html
416
+ - Flowers: https://www.robots.ox.ac.uk/~vgg/data/flowers/
417
+ - LSUN: https://www.yf.io/p/lsun
418
+
419
+ ## Links to Check Out
420
+
421
+ Below, we list some possible awesome project ideas (feel free to work on one of these, or open a PR to add more project ideas!):
422
+
423
+ PyTorch:
424
+ - Lightweight-GAN: https://github.com/lucidrains/lightweight-gan
425
+ - StyleGAN2: https://github.com/lucidrains/stylegan2-pytorch
426
+ - StyleGAN2-ada: https://github.com/NVlabs/stylegan2-ada
427
+ - StyleGAN3 (alias-free GAN): https://github.com/NVlabs/stylegan3
428
+ - BigGAN: https://github.com/ajbrock/BigGAN-PyTorch, https://github.com/huggingface/pytorch-pretrained-BigGAN
429
+ - ADGAN: https://github.com/menyifang/ADGAN
430
+ - ICGAN: https://github.com/facebookresearch/ic_gan
431
+ - StarGANv2: https://github.com/clovaai/stargan-v2
432
+ - Progressive Growing GAN: https://github.com/Maggiking/PGGAN-PyTorch
433
+ - Vision Aided GAN: https://github.com/nupurkmr9/vision-aided-gan
434
+ - DiffAugment (for training data-efficient GANs): https://github.com/mit-han-lab/data-efficient-gans
435
+ - StyleGAN-XL: https://github.com/autonomousvision/stylegan_xl
436
+ - CUT: https://github.com/taesungp/contrastive-unpaired-translation
437
+ - studioGAN (library with many GAN implementations): https://github.com/POSTECH-CVLab/PyTorch-StudioGAN
438
+ - MMGeneration (library with many GAN implementations): https://github.com/open-mmlab/mmgeneration
439
+ - Deformable GAN: https://github.com/ssfootball04/pose-transfer
440
+ - Denoising Diffusion GAN: https://github.com/NVlabs/denoising-diffusion-gan
441
+
442
+ Keras:
443
+ - WGAN-GP: https://keras.io/examples/generative/wgan_gp/
444
+ - Conditional GAN: https://keras.io/examples/generative/conditional_gan/
445
+ - CycleGAN, DiscoGAN etc.: https://github.com/eriklindernoren/Keras-GAN
446
+ - Neural Style Transfer: https://www.tensorflow.org/tutorials/generative/style_transfer
447
+ - Image Super Resolution: https://github.com/idealo/image-super-resolution
448
+ - Deformable GAN: https://github.com/AliaksandrSiarohin/pose-gan
449
+
450
+ General links & tutorials:
451
+ - https://github.com/yhlleo/GAN-Metrics
452
+ - https://paperswithcode.com/task/image-generation
453
+
454
+ ## GAN metrics
455
+
456
+ There have been several quantitative measures defined for assessing the quality of GANs (and other generative models). Refer to [this page](pytorch/metrics) for more info.
457
+
458
+ ## Evaluation
459
+
460
+ For each submission, you are expected to submit:
461
+
462
+ 1. A model repository
463
+ 2. A space made with the model repository you created
464
+
465
+ ## Prizes
466
+
467
+ TODO
468
+
469
+ ## Communication and Problems
470
+
471
+ If you encounter any problems or have any questions, you should use one of the following platforms depending on your type of problem. Hugging Face is an "open-source-first" organization meaning that we'll try to solve all problems in the most public and most transparent way possible so that everybody in the community profits.
472
+
473
+ The following table summarizes what platform to use for which problem.
474
+
475
+ - Problem/question/bug with the 🤗 Datasets library that you think is a general problem that also impacts other people, please open an [Issues on Datasets](https://github.com/huggingface/datasets/issues/new?assignees=&labels=bug&template=bug-report.md&title=) and ping @nielsrogge.
476
+ - Problem/question with a modified, customized training script that is less likely to impact other people, please post your problem/question [on the forum](https://discuss.huggingface.co/) and ping @nielsrogge.
477
+ - Other questions regarding the event, rules of the event, or if you are not sure where to post your question, please ask in the Discord channel [**#sprint-discussions**](https://discord.com/channels/879548962464493619/954111918895943720).
478
+
479
+ ## Talks
480
+
481
+ TODO
482
+
483
+ ## General Tips and Tricks
484
+
485
+ - Memory efficient training:
486
+
487
+ In case, you are getting out-of-memory errors on your GPU, we recommend to use [bitsandbytes](https://github.com/facebookresearch/bitsandbytes) to replace the native memory-intensive Adam optimizer with the one of `bitsandbytes`. It can be used to both train the generator and the discriminator in case you're training a GAN.
huggan/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ from pathlib import Path
2
+
3
+ TEMPLATE_MODEL_CARD_PATH = Path(__file__).parent.absolute() / 'model_card_template.md'
huggan/assets/cyclegan.png ADDED

Git LFS Details

  • SHA256: 6ae90c6e39d2675e1059d60c4d1af1da4895eaef8665f5b3b70189b1b96d348b
  • Pointer size: 132 Bytes
  • Size of remote file: 2.74 MB
huggan/assets/dcgan_mnist.png ADDED
huggan/assets/example_model.png ADDED
huggan/assets/example_space.png ADDED
huggan/assets/huggan_banner.png ADDED
huggan/assets/lightweight_gan_wandb.png ADDED

Git LFS Details

  • SHA256: 3f11f0a72781708bd2b2f3f9beaa87b859b501765f7782112ec758afd347dc2f
  • Pointer size: 132 Bytes
  • Size of remote file: 2.95 MB
huggan/assets/metfaces.png ADDED
huggan/assets/pix2pix_maps.png ADDED

Git LFS Details

  • SHA256: ef74c7a85d56e5a4819b84bf6c1916f4b99090252f469379cf110885073b1508
  • Pointer size: 132 Bytes
  • Size of remote file: 2.93 MB
huggan/assets/wandb.png ADDED

Git LFS Details

  • SHA256: cd973bc2b323d414c7757ec6e792e5a2794fab27c3481312bac6e77d5e75ea4d
  • Pointer size: 132 Bytes
  • Size of remote file: 1.85 MB
huggan/model_card_template.md ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ tags:
3
+ - huggan
4
+ - gan
5
+ # See a list of available tags here:
6
+ # https://github.com/huggingface/hub-docs/blob/main/js/src/lib/interfaces/Types.ts#L12
7
+ # task: unconditional-image-generation or conditional-image-generation or image-to-image
8
+ license: mit
9
+ ---
10
+
11
+ # MyModelName
12
+
13
+ ## Model description
14
+
15
+ Describe the model here (what it does, what it's used for, etc.)
16
+
17
+ ## Intended uses & limitations
18
+
19
+ #### How to use
20
+
21
+ ```python
22
+ # You can include sample code which will be formatted
23
+ ```
24
+
25
+ #### Limitations and bias
26
+
27
+ Provide examples of latent issues and potential remediations.
28
+
29
+ ## Training data
30
+
31
+ Describe the data you used to train the model.
32
+ If you initialized it with pre-trained weights, add a link to the pre-trained model card or repository with description of the pre-training data.
33
+
34
+ ## Training procedure
35
+
36
+ Preprocessing, hardware used, hyperparameters...
37
+
38
+ ## Eval results
39
+
40
+ ## Generated Images
41
+
42
+ You can embed local or remote images using `![](...)`
43
+
44
+ ### BibTeX entry and citation info
45
+
46
+ ```bibtex
47
+ @inproceedings{...,
48
+ year={2020}
49
+ }
50
+ ```
huggan/pytorch/README.md ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Example scripts (PyTorch)
2
+
3
+ This directory contains a few example scripts that allow you to train famous GANs on your own data using a bit of 🤗 magic.
4
+
5
+ More concretely, these scripts:
6
+ - leverage 🤗 [Datasets](https://huggingface.co/docs/datasets/index) to load any image dataset from the hub (including your own, possibly private, dataset)
7
+ - leverage 🤗 [Accelerate](https://huggingface.co/docs/accelerate/index) to instantly run the script on (multi-) CPU, (multi-) GPU, TPU environments, supporting fp16 and mixed precision as well as DeepSpeed
8
+ - leverage 🤗 [Hub](https://huggingface.co/) to push the model to the hub at the end of training, allowing to easily create a demo for it afterwards
9
+
10
+ Currently, it contains the following examples:
11
+
12
+ | Name | Paper |
13
+ | ----------- | ----------- |
14
+ | [DCGAN](dcgan) | [Unsupervised Representation Learning with Deep Convolutional Generative Adversarial Networks](https://arxiv.org/abs/1511.06434) |
15
+ | [pix2pix](pix2pix) | [Image-to-Image Translation with Conditional Adversarial Networks](https://arxiv.org/abs/1611.07004) |
16
+ | [CycleGAN](cyclegan) | [Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks](https://arxiv.org/abs/1703.10593)
17
+ | [Lightweight GAN](lightweight_gan) | [Towards Faster and Stabilized GAN Training for High-fidelity Few-shot Image Synthesis](https://openreview.net/forum?id=1Fqg133qRaI)
18
+
19
+
huggan/pytorch/__init__.py ADDED
File without changes
huggan/pytorch/cyclegan/README.md ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Training CycleGAN on your own data
2
+
3
+ This folder contains a script to train [CycleGAN](https://arxiv.org/abs/1703.10593), leveraging the [Hugging Face](https://huggingface.co/) ecosystem for processing data and pushing the model to the Hub.
4
+
5
+ <p align="center">
6
+ <img src="https://camo.githubusercontent.com/16fa02525bf502bec1aac77a3eb5b96928b0f25d73f7d9dedcc041ba28c38751/68747470733a2f2f6a756e79616e7a2e6769746875622e696f2f4379636c6547414e2f696d616765732f7465617365725f686967685f7265732e6a7067" alt="drawing" width="700"/>
7
+ </p>
8
+
9
+ Example applications of CycleGAN. Taken from [this repo](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix).
10
+
11
+ The script leverages 🤗 Datasets for loading and processing data, and 🤗 Accelerate for instantly running on CPU, single, multi-GPUs or TPU, also supporting mixed precision.
12
+
13
+ ## Launching the script
14
+
15
+ To train the model with the default parameters (200 epochs, 256x256 images, etc.) on [huggan/facades](https://huggingface.co/datasets/huggan/facades) on your environment, first run:
16
+
17
+ ```bash
18
+ accelerate config
19
+ ```
20
+
21
+ and answer the questions asked. Next, launch the script as follows:
22
+
23
+ ```
24
+ accelerate launch train.py
25
+ ```
26
+
27
+ This will create local "images" and "saved_models" directories, containing generated images and saved checkpoints over the course of the training.
28
+
29
+ To train on another dataset available on the hub, simply do:
30
+
31
+ ```
32
+ accelerate launch train.py --dataset huggan/edges2shoes
33
+ ```
34
+
35
+ Make sure to pick a dataset which has "imageA" and "imageB" columns defined. One can always tweak the script in case the column names are different.
36
+
37
+ ## Training on your own data
38
+
39
+ You can of course also train on your own images. For this, one can leverage Datasets' [ImageFolder](https://huggingface.co/docs/datasets/v2.0.0/en/image_process#imagefolder). Make sure to authenticate with the hub first, by running the `huggingface-cli login` command in a terminal, or the following in case you're working in a notebook:
40
+
41
+ ```python
42
+ from huggingface_hub import notebook_login
43
+
44
+ notebook_login()
45
+ ```
46
+
47
+ Next, run the following in a notebook/script:
48
+
49
+ ```python
50
+ from datasets import load_dataset
51
+
52
+ # first: load dataset
53
+ # option 1: from local folder
54
+ dataset = load_dataset("imagefolder", data_dir="path_to_folder")
55
+ # option 2: from remote URL (e.g. a zip file)
56
+ dataset = load_dataset("imagefolder", data_files="URL to .zip file")
57
+
58
+ # next: push to the hub (assuming git-LFS is installed)
59
+ dataset.push_to_hub("huggan/my-awesome-dataset")
60
+ ```
61
+
62
+ You can then simply pass the name of the dataset to the script:
63
+
64
+ ```
65
+ accelerate launch train.py --dataset huggan/my-awesome-dataset
66
+ ```
67
+
68
+ ## Pushing model to the Hub
69
+
70
+ You can push your trained generator to the hub after training by specifying the `push_to_hub` flag.
71
+ Then, you can run the script as follows:
72
+
73
+ ```
74
+ accelerate launch train.py --push_to_hub --model_name cyclegan-horse2zebra
75
+ ```
76
+
77
+ This is made possible by making the generator inherit from `PyTorchModelHubMixin`available in the `huggingface_hub` library.
78
+
79
+ # Citation
80
+
81
+ This repo is entirely based on Erik Linder-Norén's [PyTorch-GAN repo](https://github.com/eriklindernoren/PyTorch-GAN), but with added HuggingFace goodies.
huggan/pytorch/cyclegan/__init__.py ADDED
File without changes
huggan/pytorch/cyclegan/modeling_cyclegan.py ADDED
@@ -0,0 +1,108 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.nn as nn
2
+ import torch.nn.functional as F
3
+ import torch
4
+
5
+ from huggan.pytorch.huggan_mixin import HugGANModelHubMixin
6
+
7
+
8
+ ##############################
9
+ # RESNET
10
+ ##############################
11
+
12
+
13
+ class ResidualBlock(nn.Module):
14
+ def __init__(self, in_features):
15
+ super(ResidualBlock, self).__init__()
16
+
17
+ self.block = nn.Sequential(
18
+ nn.ReflectionPad2d(1),
19
+ nn.Conv2d(in_features, in_features, 3),
20
+ nn.InstanceNorm2d(in_features),
21
+ nn.ReLU(inplace=True),
22
+ nn.ReflectionPad2d(1),
23
+ nn.Conv2d(in_features, in_features, 3),
24
+ nn.InstanceNorm2d(in_features),
25
+ )
26
+
27
+ def forward(self, x):
28
+ return x + self.block(x)
29
+
30
+
31
+ class GeneratorResNet(nn.Module, HugGANModelHubMixin):
32
+ def __init__(self, input_shape, num_residual_blocks):
33
+ super(GeneratorResNet, self).__init__()
34
+
35
+ channels = input_shape[0]
36
+
37
+ # Initial convolution block
38
+ out_features = 64
39
+ model = [
40
+ nn.ReflectionPad2d(channels),
41
+ nn.Conv2d(channels, out_features, 7),
42
+ nn.InstanceNorm2d(out_features),
43
+ nn.ReLU(inplace=True),
44
+ ]
45
+ in_features = out_features
46
+
47
+ # Downsampling
48
+ for _ in range(2):
49
+ out_features *= 2
50
+ model += [
51
+ nn.Conv2d(in_features, out_features, 3, stride=2, padding=1),
52
+ nn.InstanceNorm2d(out_features),
53
+ nn.ReLU(inplace=True),
54
+ ]
55
+ in_features = out_features
56
+
57
+ # Residual blocks
58
+ for _ in range(num_residual_blocks):
59
+ model += [ResidualBlock(out_features)]
60
+
61
+ # Upsampling
62
+ for _ in range(2):
63
+ out_features //= 2
64
+ model += [
65
+ nn.Upsample(scale_factor=2),
66
+ nn.Conv2d(in_features, out_features, 3, stride=1, padding=1),
67
+ nn.InstanceNorm2d(out_features),
68
+ nn.ReLU(inplace=True),
69
+ ]
70
+ in_features = out_features
71
+
72
+ # Output layer
73
+ model += [nn.ReflectionPad2d(channels), nn.Conv2d(out_features, channels, 7), nn.Tanh()]
74
+
75
+ self.model = nn.Sequential(*model)
76
+
77
+ def forward(self, x):
78
+ return self.model(x)
79
+
80
+
81
+ ##############################
82
+ # Discriminator
83
+ ##############################
84
+
85
+
86
+ class Discriminator(nn.Module):
87
+ def __init__(self, channels):
88
+ super(Discriminator, self).__init__()
89
+
90
+ def discriminator_block(in_filters, out_filters, normalize=True):
91
+ """Returns downsampling layers of each discriminator block"""
92
+ layers = [nn.Conv2d(in_filters, out_filters, 4, stride=2, padding=1)]
93
+ if normalize:
94
+ layers.append(nn.InstanceNorm2d(out_filters))
95
+ layers.append(nn.LeakyReLU(0.2, inplace=True))
96
+ return layers
97
+
98
+ self.model = nn.Sequential(
99
+ *discriminator_block(channels, 64, normalize=False),
100
+ *discriminator_block(64, 128),
101
+ *discriminator_block(128, 256),
102
+ *discriminator_block(256, 512),
103
+ nn.ZeroPad2d((1, 0, 1, 0)),
104
+ nn.Conv2d(512, 1, 4, padding=1)
105
+ )
106
+
107
+ def forward(self, img):
108
+ return self.model(img)
huggan/pytorch/cyclegan/train.py ADDED
@@ -0,0 +1,354 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import os
3
+ import numpy as np
4
+ import itertools
5
+ from pathlib import Path
6
+ import datetime
7
+ import time
8
+ import sys
9
+
10
+ from PIL import Image
11
+
12
+ from torchvision.transforms import Compose, Resize, ToTensor, Normalize, RandomCrop, RandomHorizontalFlip
13
+ from torchvision.utils import save_image, make_grid
14
+
15
+ from torch.utils.data import DataLoader
16
+
17
+ from modeling_cyclegan import GeneratorResNet, Discriminator
18
+
19
+ from utils import ReplayBuffer, LambdaLR
20
+
21
+ from datasets import load_dataset
22
+
23
+ from accelerate import Accelerator
24
+
25
+ import torch.nn as nn
26
+ import torch
27
+
28
+ def parse_args(args=None):
29
+ parser = argparse.ArgumentParser()
30
+ parser.add_argument("--epoch", type=int, default=0, help="epoch to start training from")
31
+ parser.add_argument("--num_epochs", type=int, default=200, help="number of epochs of training")
32
+ parser.add_argument("--dataset_name", type=str, default="huggan/facades", help="name of the dataset")
33
+ parser.add_argument("--batch_size", type=int, default=1, help="size of the batches")
34
+ parser.add_argument("--lr", type=float, default=0.0002, help="adam: learning rate")
35
+ parser.add_argument("--beta1", type=float, default=0.5, help="adam: decay of first order momentum of gradient")
36
+ parser.add_argument("--beta2", type=float, default=0.999, help="adam: decay of first order momentum of gradient")
37
+ parser.add_argument("--decay_epoch", type=int, default=100, help="epoch from which to start lr decay")
38
+ parser.add_argument("--num_workers", type=int, default=8, help="Number of CPU threads to use during batch generation")
39
+ parser.add_argument("--image_size", type=int, default=256, help="Size of images for training")
40
+ parser.add_argument("--channels", type=int, default=3, help="Number of image channels")
41
+ parser.add_argument("--sample_interval", type=int, default=100, help="interval between saving generator outputs")
42
+ parser.add_argument("--checkpoint_interval", type=int, default=-1, help="interval between saving model checkpoints")
43
+ parser.add_argument("--n_residual_blocks", type=int, default=9, help="number of residual blocks in generator")
44
+ parser.add_argument("--lambda_cyc", type=float, default=10.0, help="cycle loss weight")
45
+ parser.add_argument("--lambda_id", type=float, default=5.0, help="identity loss weight")
46
+ parser.add_argument("--fp16", action="store_true", help="If passed, will use FP16 training.")
47
+ parser.add_argument(
48
+ "--mixed_precision",
49
+ type=str,
50
+ default="no",
51
+ choices=["no", "fp16", "bf16"],
52
+ help="Whether to use mixed precision. Choose"
53
+ "between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10."
54
+ "and an Nvidia Ampere GPU.",
55
+ )
56
+ parser.add_argument("--cpu", action="store_true", help="If passed, will train on the CPU.")
57
+ parser.add_argument(
58
+ "--push_to_hub",
59
+ action="store_true",
60
+ help="Whether to push the model to the HuggingFace hub after training.",
61
+ )
62
+ parser.add_argument(
63
+ "--pytorch_dump_folder_path",
64
+ required="--push_to_hub" in sys.argv,
65
+ type=Path,
66
+ help="Path to save the model. Will be created if it doesn't exist already.",
67
+ )
68
+ parser.add_argument(
69
+ "--model_name",
70
+ required="--push_to_hub" in sys.argv,
71
+ type=str,
72
+ help="Name of the model on the hub.",
73
+ )
74
+ parser.add_argument(
75
+ "--organization_name",
76
+ required=False,
77
+ default="huggan",
78
+ type=str,
79
+ help="Organization name to push to, in case args.push_to_hub is specified.",
80
+ )
81
+ return parser.parse_args(args=args)
82
+
83
+
84
+ def weights_init_normal(m):
85
+ classname = m.__class__.__name__
86
+ if classname.find("Conv") != -1:
87
+ torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
88
+ if hasattr(m, "bias") and m.bias is not None:
89
+ torch.nn.init.constant_(m.bias.data, 0.0)
90
+ elif classname.find("BatchNorm2d") != -1:
91
+ torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
92
+ torch.nn.init.constant_(m.bias.data, 0.0)
93
+
94
+
95
+ def training_function(config, args):
96
+ accelerator = Accelerator(fp16=args.fp16, cpu=args.cpu, mixed_precision=args.mixed_precision)
97
+
98
+ # Create sample and checkpoint directories
99
+ os.makedirs("images/%s" % args.dataset_name, exist_ok=True)
100
+ os.makedirs("saved_models/%s" % args.dataset_name, exist_ok=True)
101
+
102
+ # Losses
103
+ criterion_GAN = torch.nn.MSELoss()
104
+ criterion_cycle = torch.nn.L1Loss()
105
+ criterion_identity = torch.nn.L1Loss()
106
+
107
+ input_shape = (args.channels, args.image_size, args.image_size)
108
+ # Calculate output shape of image discriminator (PatchGAN)
109
+ output_shape = (1, args.image_size // 2 ** 4, args.image_size // 2 ** 4)
110
+
111
+ # Initialize generator and discriminator
112
+ G_AB = GeneratorResNet(input_shape, args.n_residual_blocks)
113
+ G_BA = GeneratorResNet(input_shape, args.n_residual_blocks)
114
+ D_A = Discriminator(args.channels)
115
+ D_B = Discriminator(args.channels)
116
+
117
+ if args.epoch != 0:
118
+ # Load pretrained models
119
+ G_AB.load_state_dict(torch.load("saved_models/%s/G_AB_%d.pth" % (args.dataset_name, args.epoch)))
120
+ G_BA.load_state_dict(torch.load("saved_models/%s/G_BA_%d.pth" % (args.dataset_name, args.epoch)))
121
+ D_A.load_state_dict(torch.load("saved_models/%s/D_A_%d.pth" % (args.dataset_name, args.epoch)))
122
+ D_B.load_state_dict(torch.load("saved_models/%s/D_B_%d.pth" % (args.dataset_name, args.epoch)))
123
+ else:
124
+ # Initialize weights
125
+ G_AB.apply(weights_init_normal)
126
+ G_BA.apply(weights_init_normal)
127
+ D_A.apply(weights_init_normal)
128
+ D_B.apply(weights_init_normal)
129
+
130
+ # Optimizers
131
+ optimizer_G = torch.optim.Adam(
132
+ itertools.chain(G_AB.parameters(), G_BA.parameters()), lr=args.lr, betas=(args.beta1, args.beta2)
133
+ )
134
+ optimizer_D_A = torch.optim.Adam(D_A.parameters(), lr=args.lr, betas=(args.beta1, args.beta2))
135
+ optimizer_D_B = torch.optim.Adam(D_B.parameters(), lr=args.lr, betas=(args.beta1, args.beta2))
136
+
137
+ # Learning rate update schedulers
138
+ lr_scheduler_G = torch.optim.lr_scheduler.LambdaLR(
139
+ optimizer_G, lr_lambda=LambdaLR(args.num_epochs, args.epoch, args.decay_epoch).step
140
+ )
141
+ lr_scheduler_D_A = torch.optim.lr_scheduler.LambdaLR(
142
+ optimizer_D_A, lr_lambda=LambdaLR(args.num_epochs, args.epoch, args.decay_epoch).step
143
+ )
144
+ lr_scheduler_D_B = torch.optim.lr_scheduler.LambdaLR(
145
+ optimizer_D_B, lr_lambda=LambdaLR(args.num_epochs, args.epoch, args.decay_epoch).step
146
+ )
147
+
148
+ # Buffers of previously generated samples
149
+ fake_A_buffer = ReplayBuffer()
150
+ fake_B_buffer = ReplayBuffer()
151
+
152
+ # Image transformations
153
+ transform = Compose([
154
+ Resize(int(args.image_size * 1.12), Image.BICUBIC),
155
+ RandomCrop((args.image_size, args.image_size)),
156
+ RandomHorizontalFlip(),
157
+ ToTensor(),
158
+ Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
159
+ ])
160
+
161
+ def transforms(examples):
162
+ examples["A"] = [transform(image.convert("RGB")) for image in examples["imageA"]]
163
+ examples["B"] = [transform(image.convert("RGB")) for image in examples["imageB"]]
164
+
165
+ del examples["imageA"]
166
+ del examples["imageB"]
167
+
168
+ return examples
169
+
170
+ dataset = load_dataset(args.dataset_name)
171
+ transformed_dataset = dataset.with_transform(transforms)
172
+
173
+ splits = transformed_dataset['train'].train_test_split(test_size=0.1)
174
+ train_ds = splits['train']
175
+ val_ds = splits['test']
176
+
177
+ dataloader = DataLoader(train_ds, shuffle=True, batch_size=args.batch_size, num_workers=args.num_workers)
178
+ val_dataloader = DataLoader(val_ds, batch_size=5, shuffle=True, num_workers=1)
179
+
180
+ def sample_images(batches_done):
181
+ """Saves a generated sample from the test set"""
182
+ batch = next(iter(val_dataloader))
183
+ G_AB.eval()
184
+ G_BA.eval()
185
+ real_A = batch["A"]
186
+ fake_B = G_AB(real_A)
187
+ real_B = batch["B"]
188
+ fake_A = G_BA(real_B)
189
+ # Arange images along x-axis
190
+ real_A = make_grid(real_A, nrow=5, normalize=True)
191
+ real_B = make_grid(real_B, nrow=5, normalize=True)
192
+ fake_A = make_grid(fake_A, nrow=5, normalize=True)
193
+ fake_B = make_grid(fake_B, nrow=5, normalize=True)
194
+ # Arange images along y-axis
195
+ image_grid = torch.cat((real_A, fake_B, real_B, fake_A), 1)
196
+ save_image(image_grid, "images/%s/%s.png" % (args.dataset_name, batches_done), normalize=False)
197
+
198
+ G_AB, G_BA, D_A, D_B, optimizer_G, optimizer_D_A, optimizer_D_B, dataloader, val_dataloader = accelerator.prepare(G_AB, G_BA, D_A, D_B, optimizer_G, optimizer_D_A, optimizer_D_B, dataloader, val_dataloader)
199
+
200
+ # ----------
201
+ # Training
202
+ # ----------
203
+
204
+ prev_time = time.time()
205
+ for epoch in range(args.epoch, args.num_epochs):
206
+ for i, batch in enumerate(dataloader):
207
+
208
+ # Set model input
209
+ real_A = batch["A"]
210
+ real_B = batch["B"]
211
+
212
+ # Adversarial ground truths
213
+ valid = torch.ones((real_A.size(0), *output_shape), device=accelerator.device)
214
+ fake = torch.zeros((real_A.size(0), *output_shape), device=accelerator.device)
215
+
216
+ # ------------------
217
+ # Train Generators
218
+ # ------------------
219
+
220
+ G_AB.train()
221
+ G_BA.train()
222
+
223
+ optimizer_G.zero_grad()
224
+
225
+ # Identity loss
226
+ loss_id_A = criterion_identity(G_BA(real_A), real_A)
227
+ loss_id_B = criterion_identity(G_AB(real_B), real_B)
228
+
229
+ loss_identity = (loss_id_A + loss_id_B) / 2
230
+
231
+ # GAN loss
232
+ fake_B = G_AB(real_A)
233
+ loss_GAN_AB = criterion_GAN(D_B(fake_B), valid)
234
+ fake_A = G_BA(real_B)
235
+ loss_GAN_BA = criterion_GAN(D_A(fake_A), valid)
236
+
237
+ loss_GAN = (loss_GAN_AB + loss_GAN_BA) / 2
238
+
239
+ # Cycle loss
240
+ recov_A = G_BA(fake_B)
241
+ loss_cycle_A = criterion_cycle(recov_A, real_A)
242
+ recov_B = G_AB(fake_A)
243
+ loss_cycle_B = criterion_cycle(recov_B, real_B)
244
+
245
+ loss_cycle = (loss_cycle_A + loss_cycle_B) / 2
246
+
247
+ # Total loss
248
+ loss_G = loss_GAN + args.lambda_cyc * loss_cycle + args.lambda_id * loss_identity
249
+
250
+ accelerator.backward(loss_G)
251
+ optimizer_G.step()
252
+
253
+ # -----------------------
254
+ # Train Discriminator A
255
+ # -----------------------
256
+
257
+ optimizer_D_A.zero_grad()
258
+
259
+ # Real loss
260
+ loss_real = criterion_GAN(D_A(real_A), valid)
261
+ # Fake loss (on batch of previously generated samples)
262
+ fake_A_ = fake_A_buffer.push_and_pop(fake_A)
263
+ loss_fake = criterion_GAN(D_A(fake_A_.detach()), fake)
264
+ # Total loss
265
+ loss_D_A = (loss_real + loss_fake) / 2
266
+
267
+ accelerator.backward(loss_D_A)
268
+ optimizer_D_A.step()
269
+
270
+ # -----------------------
271
+ # Train Discriminator B
272
+ # -----------------------
273
+
274
+ optimizer_D_B.zero_grad()
275
+
276
+ # Real loss
277
+ loss_real = criterion_GAN(D_B(real_B), valid)
278
+ # Fake loss (on batch of previously generated samples)
279
+ fake_B_ = fake_B_buffer.push_and_pop(fake_B)
280
+ loss_fake = criterion_GAN(D_B(fake_B_.detach()), fake)
281
+ # Total loss
282
+ loss_D_B = (loss_real + loss_fake) / 2
283
+
284
+ accelerator.backward(loss_D_B)
285
+ optimizer_D_B.step()
286
+
287
+ loss_D = (loss_D_A + loss_D_B) / 2
288
+
289
+ # --------------
290
+ # Log Progress
291
+ # --------------
292
+
293
+ # Determine approximate time left
294
+ batches_done = epoch * len(dataloader) + i
295
+ batches_left = args.num_epochs * len(dataloader) - batches_done
296
+ time_left = datetime.timedelta(seconds=batches_left * (time.time() - prev_time))
297
+ prev_time = time.time()
298
+
299
+ # Print log
300
+ sys.stdout.write(
301
+ "\r[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f, adv: %f, cycle: %f, identity: %f] ETA: %s"
302
+ % (
303
+ epoch,
304
+ args.num_epochs,
305
+ i,
306
+ len(dataloader),
307
+ loss_D.item(),
308
+ loss_G.item(),
309
+ loss_GAN.item(),
310
+ loss_cycle.item(),
311
+ loss_identity.item(),
312
+ time_left,
313
+ )
314
+ )
315
+
316
+ # If at sample interval save image
317
+ if batches_done % args.sample_interval == 0:
318
+ sample_images(batches_done)
319
+
320
+ # Update learning rates
321
+ lr_scheduler_G.step()
322
+ lr_scheduler_D_A.step()
323
+ lr_scheduler_D_B.step()
324
+
325
+ if args.checkpoint_interval != -1 and epoch % args.checkpoint_interval == 0:
326
+ # Save model checkpoints
327
+ torch.save(G_AB.state_dict(), "saved_models/%s/G_AB_%d.pth" % (args.dataset_name, epoch))
328
+ torch.save(G_BA.state_dict(), "saved_models/%s/G_BA_%d.pth" % (args.dataset_name, epoch))
329
+ torch.save(D_A.state_dict(), "saved_models/%s/D_A_%d.pth" % (args.dataset_name, epoch))
330
+ torch.save(D_B.state_dict(), "saved_models/%s/D_B_%d.pth" % (args.dataset_name, epoch))
331
+
332
+ # Optionally push to hub
333
+ if args.push_to_hub:
334
+ save_directory = args.pytorch_dump_folder_path
335
+ if not save_directory.exists():
336
+ save_directory.mkdir(parents=True)
337
+
338
+ G_AB.push_to_hub(
339
+ repo_path_or_name=save_directory / args.model_name,
340
+ organization=args.organization_name,
341
+ )
342
+
343
+ def main():
344
+ args = parse_args()
345
+ print(args)
346
+
347
+ # Make directory for saving generated images
348
+ os.makedirs("images", exist_ok=True)
349
+
350
+ training_function({}, args)
351
+
352
+
353
+ if __name__ == "__main__":
354
+ main()
huggan/pytorch/cyclegan/utils.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import random
2
+ import time
3
+ import datetime
4
+ import sys
5
+
6
+ from torch.autograd import Variable
7
+ import torch
8
+ import numpy as np
9
+
10
+ from torchvision.utils import save_image
11
+
12
+
13
+ class ReplayBuffer:
14
+ def __init__(self, max_size=50):
15
+ assert max_size > 0, "Empty buffer or trying to create a black hole. Be careful."
16
+ self.max_size = max_size
17
+ self.data = []
18
+
19
+ def push_and_pop(self, data):
20
+ to_return = []
21
+ for element in data.data:
22
+ element = torch.unsqueeze(element, 0)
23
+ if len(self.data) < self.max_size:
24
+ self.data.append(element)
25
+ to_return.append(element)
26
+ else:
27
+ if random.uniform(0, 1) > 0.5:
28
+ i = random.randint(0, self.max_size - 1)
29
+ to_return.append(self.data[i].clone())
30
+ self.data[i] = element
31
+ else:
32
+ to_return.append(element)
33
+ return Variable(torch.cat(to_return))
34
+
35
+
36
+ class LambdaLR:
37
+ def __init__(self, n_epochs, offset, decay_start_epoch):
38
+ assert (n_epochs - decay_start_epoch) > 0, "Decay must start before the training session ends!"
39
+ self.n_epochs = n_epochs
40
+ self.offset = offset
41
+ self.decay_start_epoch = decay_start_epoch
42
+
43
+ def step(self, epoch):
44
+ return 1.0 - max(0, epoch + self.offset - self.decay_start_epoch) / (self.n_epochs - self.decay_start_epoch)
huggan/pytorch/dcgan/README.md ADDED
@@ -0,0 +1,155 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Train DCGAN on your custom data
2
+
3
+ This folder contains a script to train [DCGAN](https://arxiv.org/abs/1511.06434) for unconditional image generation, leveraging the [Hugging Face](https://huggingface.co/) ecosystem for processing your data and pushing the model to the Hub.
4
+
5
+ The script leverages 🤗 Datasets for loading and processing data, and 🤗 Accelerate for instantly running on CPU, single, multi-GPUs or TPU, also supporting fp16/mixed precision.
6
+
7
+ <p align="center">
8
+ <img src="https://raw.githubusercontent.com/huggingface/community-events/main/huggan/assets/dcgan_mnist.png" alt="drawing" width="300"/>
9
+ </p>
10
+
11
+
12
+ ## Launching the script
13
+
14
+ To train the model with the default parameters (5 epochs, 64x64 images, etc.) on [MNIST](https://huggingface.co/datasets/mnist), first run:
15
+
16
+ ```bash
17
+ accelerate config
18
+ ```
19
+
20
+ and answer the questions asked about your environment. Next, launch the script as follows:
21
+
22
+ ```bash
23
+ accelerate launch train.py
24
+ ```
25
+
26
+ This will create a local "images" directory, containing generated images over the course of the training.
27
+
28
+ To train on another dataset available on the hub, simply do (for instance):
29
+
30
+ ```bash
31
+ python train.py --dataset cifar-10
32
+ ```
33
+
34
+ In case you'd like to tweak the script to your liking, first fork the "community-events" [repo](https://github.com/huggingface/community-events) (see the button on the top right), then clone it locally:
35
+
36
+ ```bash
37
+ git clone https://github.com/<your Github username>/community-events.git
38
+ ```
39
+
40
+ and edit to your liking.
41
+
42
+ ## Training on your own data
43
+
44
+ You can of course also train on your own images. For this, one can leverage Datasets' [ImageFolder](https://huggingface.co/docs/datasets/v2.0.0/en/image_process#imagefolder). Make sure to authenticate with the hub first, by running the `huggingface-cli login` command in a terminal, or the following in case you're working in a notebook:
45
+
46
+ ```python
47
+ from huggingface_hub import notebook_login
48
+
49
+ notebook_login()
50
+ ```
51
+
52
+ Next, run the following in a notebook/script:
53
+
54
+ ```python
55
+ from datasets import load_dataset
56
+
57
+ # first: load dataset
58
+ # option 1: from local folder
59
+ dataset = load_dataset("imagefolder", data_dir="path_to_folder")
60
+ # option 2: from remote URL (e.g. a zip file)
61
+ dataset = load_dataset("imagefolder", data_files="URL to .zip file")
62
+
63
+ # next: push to the hub (assuming git-LFS is installed)
64
+ dataset.push_to_hub("huggan/my-awesome-dataset")
65
+ ```
66
+
67
+ You can then simply pass the name of the dataset to the script:
68
+
69
+ ```bash
70
+ accelerate launch train.py --dataset huggan/my-awesome-dataset
71
+ ```
72
+
73
+ ## Pushing model to the Hub
74
+
75
+ You can push your trained generator to the hub after training by specifying the `push_to_hub` flag, along with a `model_name` and `pytorch_dump_folder_path`.
76
+
77
+ ```bash
78
+ accelerate launch train.py --push_to_hub --model_name dcgan-mnist
79
+ ```
80
+
81
+ This is made possible by making the generator inherit from `PyTorchModelHubMixin`available in the `huggingface_hub` library.
82
+
83
+ This means that after training, generating a new image can be done as follows:
84
+
85
+ ```python
86
+ import torch
87
+ import torch.nn as nn
88
+ from torchvision.transforms import ToPILImage
89
+ from huggingface_hub import PyTorchModelHubMixin
90
+
91
+ class Generator(nn.Module, PyTorchModelHubMixin):
92
+ def __init__(self, num_channels=3, latent_dim=100, hidden_size=64):
93
+ super(Generator, self).__init__()
94
+ self.model = nn.Sequential(
95
+ # input is Z, going into a convolution
96
+ nn.ConvTranspose2d(latent_dim, hidden_size * 8, 4, 1, 0, bias=False),
97
+ nn.BatchNorm2d(hidden_size * 8),
98
+ nn.ReLU(True),
99
+ # state size. (hidden_size*8) x 4 x 4
100
+ nn.ConvTranspose2d(hidden_size * 8, hidden_size * 4, 4, 2, 1, bias=False),
101
+ nn.BatchNorm2d(hidden_size * 4),
102
+ nn.ReLU(True),
103
+ # state size. (hidden_size*4) x 8 x 8
104
+ nn.ConvTranspose2d(hidden_size * 4, hidden_size * 2, 4, 2, 1, bias=False),
105
+ nn.BatchNorm2d(hidden_size * 2),
106
+ nn.ReLU(True),
107
+ # state size. (hidden_size*2) x 16 x 16
108
+ nn.ConvTranspose2d(hidden_size * 2, hidden_size, 4, 2, 1, bias=False),
109
+ nn.BatchNorm2d(hidden_size),
110
+ nn.ReLU(True),
111
+ # state size. (hidden_size) x 32 x 32
112
+ nn.ConvTranspose2d(hidden_size, num_channels, 4, 2, 1, bias=False),
113
+ nn.Tanh()
114
+ # state size. (num_channels) x 64 x 64
115
+ )
116
+
117
+ def forward(self, noise):
118
+ pixel_values = self.model(noise)
119
+
120
+ return pixel_values
121
+
122
+ model = Generator.from_pretrained("huggan/dcgan-mnist")
123
+
124
+ device = "cuda" if torch.cuda.is_available() else "cpu
125
+ model.to(device)
126
+
127
+ with torch.no_grad():
128
+ z = torch.randn(1, 100, 1, 1, device=device)
129
+ pixel_values = model(z)
130
+
131
+ # turn into actual image
132
+ image = pixel_values[0]
133
+ image = (image + 1) /2
134
+ image = ToPILImage()(image)
135
+ image.save("generated.png")
136
+ ```
137
+
138
+ ## Weights and Biases integration
139
+
140
+ You can easily add logging to [Weights and Biases](https://wandb.ai/site) by passing the `--wandb` flag:
141
+
142
+ ```bash
143
+ accelerate launch train.py --wandb
144
+ ````
145
+
146
+ You can then follow the progress of your GAN in a browser:
147
+
148
+ <p align="center">
149
+ <img src="https://raw.githubusercontent.com/huggingface/community-events/main/huggan/assets/wandb.png" alt="drawing" width="700"/>
150
+ </p>
151
+
152
+
153
+ # Citation
154
+
155
+ This repo is entirely based on PyTorch's official [DCGAN tutorial](https://pytorch.org/tutorials/beginner/dcgan_faces_tutorial.html), but with added HuggingFace goodies.
huggan/pytorch/dcgan/__init__.py ADDED
File without changes
huggan/pytorch/dcgan/modeling_dcgan.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # coding=utf-8
3
+ # Copyright (c) 2022 PyTorch contributors and The HuggingFace Inc. team. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions.
15
+
16
+ import torch.nn as nn
17
+
18
+ from huggan.pytorch.huggan_mixin import HugGANModelHubMixin
19
+
20
+
21
+ class Generator(nn.Module, HugGANModelHubMixin):
22
+ def __init__(self, num_channels=3, latent_dim=100, hidden_size=64):
23
+ super(Generator, self).__init__()
24
+ self.model = nn.Sequential(
25
+ # input is Z, going into a convolution
26
+ nn.ConvTranspose2d(latent_dim, hidden_size * 8, 4, 1, 0, bias=False),
27
+ nn.BatchNorm2d(hidden_size * 8),
28
+ nn.ReLU(True),
29
+ # state size. (hidden_size*8) x 4 x 4
30
+ nn.ConvTranspose2d(hidden_size * 8, hidden_size * 4, 4, 2, 1, bias=False),
31
+ nn.BatchNorm2d(hidden_size * 4),
32
+ nn.ReLU(True),
33
+ # state size. (hidden_size*4) x 8 x 8
34
+ nn.ConvTranspose2d(hidden_size * 4, hidden_size * 2, 4, 2, 1, bias=False),
35
+ nn.BatchNorm2d(hidden_size * 2),
36
+ nn.ReLU(True),
37
+ # state size. (hidden_size*2) x 16 x 16
38
+ nn.ConvTranspose2d(hidden_size * 2, hidden_size, 4, 2, 1, bias=False),
39
+ nn.BatchNorm2d(hidden_size),
40
+ nn.ReLU(True),
41
+ # state size. (hidden_size) x 32 x 32
42
+ nn.ConvTranspose2d(hidden_size, num_channels, 4, 2, 1, bias=False),
43
+ nn.Tanh()
44
+ # state size. (num_channels) x 64 x 64
45
+ )
46
+
47
+ def forward(self, noise):
48
+ pixel_values = self.model(noise)
49
+
50
+ return pixel_values
51
+
52
+
53
+ class Discriminator(nn.Module):
54
+ def __init__(self, num_channels=3, hidden_size=64):
55
+ super(Discriminator, self).__init__()
56
+ self.model = nn.Sequential(
57
+ # input is (num_channels) x 64 x 64
58
+ nn.Conv2d(num_channels, hidden_size, 4, 2, 1, bias=False),
59
+ nn.LeakyReLU(0.2, inplace=True),
60
+ # state size. (hidden_size) x 32 x 32
61
+ nn.Conv2d(hidden_size, hidden_size * 2, 4, 2, 1, bias=False),
62
+ nn.BatchNorm2d(hidden_size * 2),
63
+ nn.LeakyReLU(0.2, inplace=True),
64
+ # state size. (hidden_size*2) x 16 x 16
65
+ nn.Conv2d(hidden_size * 2, hidden_size * 4, 4, 2, 1, bias=False),
66
+ nn.BatchNorm2d(hidden_size * 4),
67
+ nn.LeakyReLU(0.2, inplace=True),
68
+ # state size. (hidden_size*4) x 8 x 8
69
+ nn.Conv2d(hidden_size * 4, hidden_size * 8, 4, 2, 1, bias=False),
70
+ nn.BatchNorm2d(hidden_size * 8),
71
+ nn.LeakyReLU(0.2, inplace=True),
72
+ # state size. (hidden_size*8) x 4 x 4
73
+ nn.Conv2d(hidden_size * 8, 1, 4, 1, 0, bias=False),
74
+ nn.Sigmoid(),
75
+ )
76
+
77
+ def forward(self, pixel_values):
78
+ logits = self.model(pixel_values)
79
+
80
+ return logits
huggan/pytorch/dcgan/train.py ADDED
@@ -0,0 +1,346 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # coding=utf-8
3
+ # Copyright (c) 2022 PyTorch contributors and The HuggingFace Inc. team. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions.
15
+
16
+ """ Training a Deep Convolutional Generative Adversarial Network (DCGAN) leveraging the 🤗 ecosystem.
17
+ Paper: https://arxiv.org/abs/1511.06434.
18
+ Based on PyTorch's official tutorial: https://pytorch.org/tutorials/beginner/dcgan_faces_tutorial.html.
19
+ """
20
+
21
+
22
+ import argparse
23
+ import logging
24
+ import os
25
+ import sys
26
+ from pathlib import Path
27
+
28
+ import torch
29
+ import torch.nn as nn
30
+ from torch.utils.data import DataLoader
31
+ from torchvision.transforms import (CenterCrop, Compose, Normalize, Resize,
32
+ ToTensor, ToPILImage)
33
+ from torchvision.utils import save_image
34
+
35
+ from PIL import Image, ImageFile
36
+ ImageFile.LOAD_TRUNCATED_IMAGES = True
37
+
38
+ from accelerate import Accelerator
39
+
40
+ from modeling_dcgan import Discriminator, Generator
41
+
42
+ from datasets import load_dataset
43
+
44
+ from huggan.pytorch.metrics.inception import InceptionV3
45
+ from huggan.pytorch.metrics.fid_score import calculate_fretchet
46
+
47
+ import wandb
48
+
49
+ logger = logging.getLogger(__name__)
50
+
51
+
52
+ def parse_args(args=None):
53
+ parser = argparse.ArgumentParser()
54
+ parser.add_argument("--dataset", type=str, default="mnist", help="Dataset to load from the HuggingFace hub.")
55
+ parser.add_argument("--num_workers", type=int, default=0, help="Number of workers when loading data")
56
+ parser.add_argument("--batch_size", type=int, default=128, help="Batch size to use during training")
57
+ parser.add_argument(
58
+ "--image_size",
59
+ type=int,
60
+ default=64,
61
+ help="Spatial size to use when resizing images for training.",
62
+ )
63
+ parser.add_argument(
64
+ "--num_channels",
65
+ type=int,
66
+ default=3,
67
+ help="Number of channels in the training images. For color images this is 3.",
68
+ )
69
+ parser.add_argument("--latent_dim", type=int, default=100, help="Dimensionality of the latent space.")
70
+ parser.add_argument(
71
+ "--generator_hidden_size",
72
+ type=int,
73
+ default=64,
74
+ help="Hidden size of the generator's feature maps.",
75
+ )
76
+ parser.add_argument(
77
+ "--discriminator_hidden_size",
78
+ type=int,
79
+ default=64,
80
+ help="Hidden size of the discriminator's feature maps.",
81
+ )
82
+ parser.add_argument("--num_epochs", type=int, default=5, help="number of epochs of training")
83
+ parser.add_argument("--lr", type=float, default=0.0002, help="adam: learning rate")
84
+ parser.add_argument(
85
+ "--beta1",
86
+ type=float,
87
+ default=0.5,
88
+ help="adam: decay of first order momentum of gradient",
89
+ )
90
+ parser.add_argument("--fp16", action="store_true", help="If passed, will use FP16 training.")
91
+ parser.add_argument(
92
+ "--mixed_precision",
93
+ type=str,
94
+ default="no",
95
+ choices=["no", "fp16", "bf16"],
96
+ help="Whether to use mixed precision. Choose"
97
+ "between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10."
98
+ "and an Nvidia Ampere GPU.",
99
+ )
100
+ parser.add_argument("--cpu", action="store_true", help="If passed, will train on the CPU.")
101
+ parser.add_argument("--output_dir", type=Path, default=Path("./output"), help="Name of the directory to dump generated images during training.")
102
+ parser.add_argument("--wandb", action="store_true", help="If passed, will log to Weights and Biases.")
103
+ parser.add_argument(
104
+ "--logging_steps",
105
+ type=int,
106
+ default=50,
107
+ help="Number of steps between each logging",
108
+ )
109
+ parser.add_argument(
110
+ "--push_to_hub",
111
+ action="store_true",
112
+ help="Whether to push the model to the HuggingFace hub after training.",
113
+ )
114
+ parser.add_argument(
115
+ "--model_name",
116
+ default=None,
117
+ type=str,
118
+ help="Name of the model on the hub.",
119
+ )
120
+ parser.add_argument(
121
+ "--organization_name",
122
+ default="huggan",
123
+ type=str,
124
+ help="Organization name to push to, in case args.push_to_hub is specified.",
125
+ )
126
+ args = parser.parse_args()
127
+
128
+ if args.push_to_hub:
129
+ assert args.output_dir is not None, "Need an `output_dir` to create a repo when `--push_to_hub` is passed."
130
+ assert args.model_name is not None, "Need a `model_name` to create a repo when `--push_to_hub` is passed."
131
+
132
+ if args.output_dir is not None:
133
+ os.makedirs(args.output_dir, exist_ok=True)
134
+
135
+ return args
136
+
137
+
138
+ # Custom weights initialization called on Generator and Discriminator
139
+ def weights_init(m):
140
+ classname = m.__class__.__name__
141
+ if classname.find("Conv") != -1:
142
+ nn.init.normal_(m.weight.data, 0.0, 0.02)
143
+ elif classname.find("BatchNorm") != -1:
144
+ nn.init.normal_(m.weight.data, 1.0, 0.02)
145
+ nn.init.constant_(m.bias.data, 0)
146
+
147
+
148
+ def training_function(config, args):
149
+
150
+ # Initialize the accelerator. We will let the accelerator handle device placement for us in this example.
151
+ accelerator = Accelerator(fp16=args.fp16, cpu=args.cpu, mixed_precision=args.mixed_precision)
152
+
153
+ # Setup logging, we only want one process per machine to log things on the screen.
154
+ # accelerator.is_local_main_process is only True for one process per machine.
155
+ logger.setLevel(logging.INFO if accelerator.is_local_main_process else logging.ERROR)
156
+ if accelerator.is_local_main_process:
157
+ # set up Weights and Biases if requested
158
+ if args.wandb:
159
+ import wandb
160
+
161
+ wandb.init(project=str(args.output_dir).split("/")[-1])
162
+
163
+ # Loss function
164
+ criterion = nn.BCELoss()
165
+
166
+ # Initialize generator and discriminator
167
+ generator = Generator(
168
+ num_channels=args.num_channels,
169
+ latent_dim=args.latent_dim,
170
+ hidden_size=args.generator_hidden_size,
171
+ )
172
+ discriminator = Discriminator(num_channels=args.num_channels, hidden_size=args.discriminator_hidden_size)
173
+
174
+ # Initialize weights
175
+ generator.apply(weights_init)
176
+ discriminator.apply(weights_init)
177
+
178
+ # Initialize Inceptionv3 (for FID metric)
179
+ model = InceptionV3()
180
+
181
+ # Initialize Inceptionv3 (for FID metric)
182
+ model = InceptionV3()
183
+
184
+ # Create batch of latent vectors that we will use to visualize
185
+ # the progression of the generator
186
+ fixed_noise = torch.randn(64, args.latent_dim, 1, 1, device=accelerator.device)
187
+
188
+ # Establish convention for real and fake labels during training
189
+ real_label = 1.0
190
+ fake_label = 0.0
191
+
192
+ # Setup Adam optimizers for both G and D
193
+ discriminator_optimizer = torch.optim.Adam(discriminator.parameters(), lr=args.lr, betas=(args.beta1, 0.999))
194
+ generator_optimizer = torch.optim.Adam(generator.parameters(), lr=args.lr, betas=(args.beta1, 0.999))
195
+
196
+ # Configure data loader
197
+ dataset = load_dataset(args.dataset)
198
+
199
+ transform = Compose(
200
+ [
201
+ Resize(args.image_size),
202
+ CenterCrop(args.image_size),
203
+ ToTensor(),
204
+ Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
205
+ ]
206
+ )
207
+
208
+ def transforms(examples):
209
+ examples["pixel_values"] = [transform(image.convert("RGB")) for image in examples["image"]]
210
+
211
+ del examples["image"]
212
+
213
+ return examples
214
+
215
+ transformed_dataset = dataset.with_transform(transforms)
216
+
217
+ dataloader = DataLoader(
218
+ transformed_dataset["train"], batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers
219
+ )
220
+
221
+ generator, discriminator, generator_optimizer, discriminator_optimizer, dataloader = accelerator.prepare(generator, discriminator, generator_optimizer, discriminator_optimizer, dataloader)
222
+
223
+ # ----------
224
+ # Training
225
+ # ----------
226
+
227
+ # Training Loop
228
+
229
+ # Lists to keep track of progress
230
+ img_list = []
231
+
232
+ logger.info("***** Running training *****")
233
+ logger.info(f" Num Epochs = {args.num_epochs}")
234
+ # For each epoch
235
+ for epoch in range(args.num_epochs):
236
+ # For each batch in the dataloader
237
+ for step, batch in enumerate(dataloader, 0):
238
+
239
+ ############################
240
+ # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
241
+ ###########################
242
+ ## Train with all-real batch
243
+ discriminator.zero_grad()
244
+ # Format batch
245
+ real_cpu = batch["pixel_values"]
246
+ batch_size = real_cpu.size(0)
247
+ label = torch.full((batch_size,), real_label, dtype=torch.float, device=accelerator.device)
248
+ # Forward pass real batch through D
249
+ output = discriminator(real_cpu).view(-1)
250
+ # Calculate loss on all-real batch
251
+ errD_real = criterion(output, label)
252
+ # Calculate gradients for D in backward pass
253
+ accelerator.backward(errD_real)
254
+ D_x = output.mean().item()
255
+
256
+ ## Train with all-fake batch
257
+ # Generate batch of latent vectors
258
+ noise = torch.randn(batch_size, args.latent_dim, 1, 1, device=accelerator.device)
259
+ # Generate fake image batch with G
260
+ fake = generator(noise)
261
+ label.fill_(fake_label)
262
+ # Classify all fake batch with D
263
+ output = discriminator(fake.detach()).view(-1)
264
+ # Calculate D's loss on the all-fake batch
265
+ errD_fake = criterion(output, label)
266
+ # Calculate the gradients for this batch, accumulated (summed) with previous gradients
267
+ accelerator.backward(errD_fake)
268
+ D_G_z1 = output.mean().item()
269
+ # Compute error of D as sum over the fake and the real batches
270
+ errD = errD_real + errD_fake
271
+ # Update D
272
+ discriminator_optimizer.step()
273
+
274
+ ############################
275
+ # (2) Update G network: maximize log(D(G(z)))
276
+ ###########################
277
+ generator.zero_grad()
278
+ label.fill_(real_label) # fake labels are real for generator cost
279
+ # Since we just updated D, perform another forward pass of all-fake batch through D
280
+ output = discriminator(fake).view(-1)
281
+ # Calculate G's loss based on this output
282
+ errG = criterion(output, label)
283
+ # Calculate gradients for G
284
+ accelerator.backward(errG)
285
+ D_G_z2 = output.mean().item()
286
+ # Update G
287
+ generator_optimizer.step()
288
+
289
+ # Log all results
290
+ if (step + 1) % args.logging_steps == 0:
291
+ errD.detach()
292
+ errG.detach()
293
+
294
+ if accelerator.state.num_processes > 1:
295
+ errD = accelerator.gather(errD).sum() / accelerator.state.num_processes
296
+ errG = accelerator.gather(errG).sum() / accelerator.state.num_processes
297
+
298
+ train_logs = {
299
+ "epoch": epoch,
300
+ "discriminator_loss": errD,
301
+ "generator_loss": errG,
302
+ "D_x": D_x,
303
+ "D_G_z1": D_G_z1,
304
+ "D_G_z2": D_G_z2,
305
+ }
306
+ log_str = ""
307
+ for k, v in train_logs.items():
308
+ log_str += "| {}: {:.3e}".format(k, v)
309
+
310
+ if accelerator.is_local_main_process:
311
+ logger.info(log_str)
312
+ if args.wandb:
313
+ wandb.log(train_logs)
314
+
315
+ # Check how the generator is doing by saving G's output on fixed_noise
316
+ if (step % 500 == 0) or ((epoch == args.num_epochs - 1) and (step == len(dataloader) - 1)):
317
+ with torch.no_grad():
318
+ fake_images = generator(fixed_noise).detach().cpu()
319
+ file_name = args.output_dir/f"iter_{step}.png"
320
+ save_image(fake_images.data[:25], file_name, nrow=5, normalize=True)
321
+ if accelerator.is_local_main_process and args.wandb:
322
+ wandb.log({'generated_examples': wandb.Image(str(file_name)) })
323
+
324
+ # Calculate FID metric
325
+ fid = calculate_fretchet(real_cpu, fake, model.to(accelerator.device))
326
+ logger.info(f"FID: {fid}")
327
+ if accelerator.is_local_main_process and args.wandb:
328
+ wandb.log({"FID": fid})
329
+
330
+ # Optionally push to hub
331
+ if accelerator.is_main_process and args.push_to_hub:
332
+ generator.module.push_to_hub(
333
+ repo_path_or_name=args.output_dir / args.model_name,
334
+ organization=args.organization_name,
335
+ )
336
+
337
+
338
+ def main():
339
+ args = parse_args()
340
+ print(args)
341
+
342
+ training_function({}, args)
343
+
344
+
345
+ if __name__ == "__main__":
346
+ main()
huggan/pytorch/huggan_mixin.py ADDED
@@ -0,0 +1,131 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pathlib import Path
2
+ from re import TEMPLATE
3
+ from typing import Optional, Union
4
+ import os
5
+
6
+ from huggingface_hub import PyTorchModelHubMixin, HfApi, HfFolder, Repository
7
+
8
+ from huggan import TEMPLATE_MODEL_CARD_PATH
9
+
10
+
11
+ class HugGANModelHubMixin(PyTorchModelHubMixin):
12
+ """A mixin to push PyTorch Models to the Hugging Face Hub. This
13
+ mixin was adapted from the PyTorchModelHubMixin to also push a template
14
+ README.md for the HugGAN sprint.
15
+ """
16
+
17
+ def push_to_hub(
18
+ self,
19
+ repo_path_or_name: Optional[str] = None,
20
+ repo_url: Optional[str] = None,
21
+ commit_message: Optional[str] = "Add model",
22
+ organization: Optional[str] = None,
23
+ private: Optional[bool] = None,
24
+ api_endpoint: Optional[str] = None,
25
+ use_auth_token: Optional[Union[bool, str]] = None,
26
+ git_user: Optional[str] = None,
27
+ git_email: Optional[str] = None,
28
+ config: Optional[dict] = None,
29
+ skip_lfs_files: bool = False,
30
+ default_model_card: Optional[str] = TEMPLATE_MODEL_CARD_PATH
31
+ ) -> str:
32
+ """
33
+ Upload model checkpoint or tokenizer files to the Hub while
34
+ synchronizing a local clone of the repo in `repo_path_or_name`.
35
+ Parameters:
36
+ repo_path_or_name (`str`, *optional*):
37
+ Can either be a repository name for your model or tokenizer in
38
+ the Hub or a path to a local folder (in which case the
39
+ repository will have the name of that local folder). If not
40
+ specified, will default to the name given by `repo_url` and a
41
+ local directory with that name will be created.
42
+ repo_url (`str`, *optional*):
43
+ Specify this in case you want to push to an existing repository
44
+ in the hub. If unspecified, a new repository will be created in
45
+ your namespace (unless you specify an `organization`) with
46
+ `repo_name`.
47
+ commit_message (`str`, *optional*):
48
+ Message to commit while pushing. Will default to `"add config"`,
49
+ `"add tokenizer"` or `"add model"` depending on the type of the
50
+ class.
51
+ organization (`str`, *optional*):
52
+ Organization in which you want to push your model or tokenizer
53
+ (you must be a member of this organization).
54
+ private (`bool`, *optional*):
55
+ Whether the repository created should be private.
56
+ api_endpoint (`str`, *optional*):
57
+ The API endpoint to use when pushing the model to the hub.
58
+ use_auth_token (`bool` or `str`, *optional*):
59
+ The token to use as HTTP bearer authorization for remote files.
60
+ If `True`, will use the token generated when running
61
+ `transformers-cli login` (stored in `~/.huggingface`). Will
62
+ default to `True` if `repo_url` is not specified.
63
+ git_user (`str`, *optional*):
64
+ will override the `git config user.name` for committing and
65
+ pushing files to the hub.
66
+ git_email (`str`, *optional*):
67
+ will override the `git config user.email` for committing and
68
+ pushing files to the hub.
69
+ config (`dict`, *optional*):
70
+ Configuration object to be saved alongside the model weights.
71
+ default_model_card (`str`, *optional*):
72
+ Path to a markdown file to use as your default model card.
73
+ Returns:
74
+ The url of the commit of your model in the given repository.
75
+ """
76
+
77
+ if repo_path_or_name is None and repo_url is None:
78
+ raise ValueError(
79
+ "You need to specify a `repo_path_or_name` or a `repo_url`."
80
+ )
81
+
82
+ if use_auth_token is None and repo_url is None:
83
+ token = HfFolder.get_token()
84
+ if token is None:
85
+ raise ValueError(
86
+ "You must login to the Hugging Face hub on this computer by typing `huggingface-cli login` and "
87
+ "entering your credentials to use `use_auth_token=True`. Alternatively, you can pass your own "
88
+ "token as the `use_auth_token` argument."
89
+ )
90
+ elif isinstance(use_auth_token, str):
91
+ token = use_auth_token
92
+ else:
93
+ token = None
94
+
95
+ if repo_path_or_name is None:
96
+ repo_path_or_name = repo_url.split("/")[-1]
97
+
98
+ # If no URL is passed and there's no path to a directory containing files, create a repo
99
+ if repo_url is None and not os.path.exists(repo_path_or_name):
100
+ repo_id = Path(repo_path_or_name).name
101
+ if organization:
102
+ repo_id = f"{organization}/{repo_id}"
103
+ repo_url = HfApi(endpoint=api_endpoint).create_repo(
104
+ repo_id=repo_id,
105
+ token=token,
106
+ private=private,
107
+ repo_type=None,
108
+ exist_ok=True,
109
+ )
110
+
111
+ repo = Repository(
112
+ repo_path_or_name,
113
+ clone_from=repo_url,
114
+ use_auth_token=use_auth_token,
115
+ git_user=git_user,
116
+ git_email=git_email,
117
+ skip_lfs_files=skip_lfs_files
118
+ )
119
+ repo.git_pull(rebase=True)
120
+
121
+ # Save the files in the cloned repo
122
+ self.save_pretrained(repo_path_or_name, config=config)
123
+
124
+ model_card_path = Path(repo_path_or_name) / 'README.md'
125
+ if not model_card_path.exists():
126
+ model_card_path.write_text(TEMPLATE_MODEL_CARD_PATH.read_text())
127
+
128
+ # Commit and push!
129
+ repo.git_add()
130
+ repo.git_commit(commit_message)
131
+ return repo.git_push()
huggan/pytorch/lightweight_gan/README.md ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Train Lightweight GAN on your custom data
2
+
3
+ This folder contains a script to train ['Lightweight' GAN](https://openreview.net/forum?id=1Fqg133qRaI) for unconditional image generation, leveraging the [Hugging Face](https://huggingface.co/) ecosystem for processing your data and pushing the model to the Hub.
4
+
5
+ The script leverages 🤗 Datasets for loading and processing data, and 🤗 Accelerate for instantly running on CPU, single, multi-GPUs or TPU, also supporting mixed precision.
6
+
7
+ <p align="center">
8
+ <img src="https://raw.githubusercontent.com/lucidrains/lightweight-gan/main/images/pizza-512.jpg" alt="drawing" width="300"/>
9
+ </p>
10
+
11
+ Pizza's that don't exist. Courtesy of Phil Wang.
12
+
13
+ ## Launching the script
14
+
15
+ To train the model with the default parameters on [huggan/CelebA-faces](https://huggingface.co/datasets/huggan/CelebA-faces), first run:
16
+
17
+ ```bash
18
+ accelerate config
19
+ ```
20
+
21
+ and answer the questions asked about your environment. Next, launch the script as follows:
22
+
23
+ ```bash
24
+ accelerate launch cli.py
25
+ ```
26
+
27
+ This will instantly run on multi-GPUs (if you asked for that). To train on another dataset available on the hub, simply do (for instance):
28
+
29
+ ```bash
30
+ accelerate launch cli.py --dataset_name huggan/pokemon
31
+ ```
32
+
33
+ In case you'd like to tweak the script to your liking, first fork the "community-events" [repo](https://github.com/huggingface/community-events) (see the button on the top right), then clone it locally:
34
+
35
+ ```bash
36
+ git clone https://github.com/<your Github username>/community-events.git
37
+ ```
38
+
39
+ and edit to your liking.
40
+
41
+ ## Training on your own data
42
+
43
+ You can of course also train on your own images. For this, one can leverage Datasets' [ImageFolder](https://huggingface.co/docs/datasets/v2.0.0/en/image_process#imagefolder). Make sure to authenticate with the hub first, by running the `huggingface-cli login` command in a terminal, or the following in case you're working in a notebook:
44
+
45
+ ```python
46
+ from huggingface_hub import notebook_login
47
+
48
+ notebook_login()
49
+ ```
50
+
51
+ Next, run the following in a notebook/script:
52
+
53
+ ```python
54
+ from datasets import load_dataset
55
+
56
+ # first: load dataset
57
+ # option 1: from local folder
58
+ dataset = load_dataset("imagefolder", data_dir="path_to_folder")
59
+ # option 2: from remote URL (e.g. a zip file)
60
+ dataset = load_dataset("imagefolder", data_files="URL to .zip file")
61
+
62
+ # next: push to the hub (assuming git-LFS is installed)
63
+ dataset.push_to_hub("huggan/my-awesome-dataset")
64
+ ```
65
+
66
+ You can then simply pass the name of the dataset to the script:
67
+
68
+ ```bash
69
+ accelerate launch cli.py --dataset huggan/my-awesome-dataset
70
+ ```
71
+
72
+ ## Weights and Biases integration
73
+
74
+ You can easily add logging to [Weights and Biases](https://wandb.ai/site) by passing the `--wandb` flag:
75
+
76
+ ```bash
77
+ accelerate launch cli.py --wandb
78
+ ````
79
+
80
+ You can then follow the progress of your GAN in a browser:
81
+
82
+ <p align="center">
83
+ <img src="https://raw.githubusercontent.com/huggingface/community-events/main/huggan/assets/lightweight_gan_wandb.png" alt="drawing" width="700"/>
84
+ </p>
85
+
86
+
87
+ # Citation
88
+
89
+ This repo is entirely based on lucidrains' [Pytorch implementation](https://github.com/lucidrains/lightweight-gan), but with added HuggingFace goodies.
huggan/pytorch/lightweight_gan/__init__.py ADDED
File without changes
huggan/pytorch/lightweight_gan/cli.py ADDED
@@ -0,0 +1,178 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import fire
2
+ import random
3
+ from retry.api import retry_call
4
+ from tqdm import tqdm
5
+ from datetime import datetime
6
+ from pathlib import Path
7
+ from lightweight_gan import Trainer, NanException
8
+
9
+ import torch
10
+ import torch.multiprocessing as mp
11
+
12
+ import numpy as np
13
+
14
+ def exists(val):
15
+ return val is not None
16
+
17
+ def default(val, d):
18
+ return val if exists(val) else d
19
+
20
+ def cast_list(el):
21
+ return el if isinstance(el, list) else [el]
22
+
23
+ def timestamped_filename(prefix = 'generated-'):
24
+ now = datetime.now()
25
+ timestamp = now.strftime("%m-%d-%Y_%H-%M-%S")
26
+ return f'{prefix}{timestamp}'
27
+
28
+ def set_seed(seed):
29
+ torch.manual_seed(seed)
30
+ torch.backends.cudnn.deterministic = True
31
+ torch.backends.cudnn.benchmark = False
32
+ np.random.seed(seed)
33
+ random.seed(seed)
34
+
35
+ def run_training(model_args, data, load_from, new, num_train_steps, name, seed):
36
+
37
+ if seed is not None:
38
+ set_seed(seed)
39
+
40
+ model = Trainer(**model_args)
41
+
42
+ if not new:
43
+ model.load(load_from)
44
+ else:
45
+ model.clear()
46
+
47
+ progress_bar = tqdm(initial = model.steps, total = num_train_steps, mininterval=10., desc=f'{name}<{data}>')
48
+ G, D, D_aug = model.init_accelerator()
49
+
50
+ # model.set_data_src(data)
51
+
52
+ while model.steps < num_train_steps:
53
+ # retry_call(model.train, tries=3, exceptions=NanException)
54
+ model.train(G, D, D_aug)
55
+ progress_bar.n = model.steps
56
+ progress_bar.refresh()
57
+ if model.accelerator.is_local_main_process and model.steps % 50 == 0:
58
+ model.print_log()
59
+
60
+ model.save(model.checkpoint_num)
61
+
62
+ def train_from_folder(
63
+ dataset_name = 'huggan/CelebA-faces',
64
+ data = './data',
65
+ results_dir = './results',
66
+ models_dir = './models',
67
+ name = 'default',
68
+ new = False,
69
+ load_from = -1,
70
+ image_size = 256,
71
+ optimizer = 'adam',
72
+ fmap_max = 512,
73
+ transparent = False,
74
+ greyscale = False,
75
+ batch_size = 10,
76
+ gradient_accumulate_every = 4,
77
+ num_train_steps = 150000,
78
+ learning_rate = 2e-4,
79
+ save_every = 10000,
80
+ evaluate_every = 1000,
81
+ generate = False,
82
+ generate_types = ['default', 'ema'],
83
+ generate_interpolation = False,
84
+ aug_test = False,
85
+ aug_prob=None,
86
+ aug_types=['cutout', 'translation'],
87
+ dataset_aug_prob=0.,
88
+ attn_res_layers = [32],
89
+ freq_chan_attn = False,
90
+ disc_output_size = 1,
91
+ dual_contrast_loss = False,
92
+ antialias = False,
93
+ interpolation_num_steps = 100,
94
+ save_frames = False,
95
+ num_image_tiles = None,
96
+ calculate_fid_every = None,
97
+ calculate_fid_num_images = 12800,
98
+ clear_fid_cache = False,
99
+ seed = 42,
100
+ cpu = False,
101
+ mixed_precision = "no",
102
+ show_progress = False,
103
+ wandb = False,
104
+ push_to_hub = False,
105
+ organization_name = None,
106
+ ):
107
+ if push_to_hub:
108
+ if name == 'default':
109
+ raise RuntimeError(
110
+ "You've chosen to push to hub, but have left the --name flag as 'default'."
111
+ " You should name your model something other than 'default'!"
112
+ )
113
+
114
+ num_image_tiles = default(num_image_tiles, 4 if image_size > 512 else 8)
115
+
116
+ model_args = dict(
117
+ dataset_name = dataset_name,
118
+ name = name,
119
+ results_dir = results_dir,
120
+ models_dir = models_dir,
121
+ batch_size = batch_size,
122
+ gradient_accumulate_every = gradient_accumulate_every,
123
+ attn_res_layers = cast_list(attn_res_layers),
124
+ freq_chan_attn = freq_chan_attn,
125
+ disc_output_size = disc_output_size,
126
+ dual_contrast_loss = dual_contrast_loss,
127
+ antialias = antialias,
128
+ image_size = image_size,
129
+ num_image_tiles = num_image_tiles,
130
+ optimizer = optimizer,
131
+ fmap_max = fmap_max,
132
+ transparent = transparent,
133
+ greyscale = greyscale,
134
+ lr = learning_rate,
135
+ save_every = save_every,
136
+ evaluate_every = evaluate_every,
137
+ aug_prob = aug_prob,
138
+ aug_types = cast_list(aug_types),
139
+ dataset_aug_prob = dataset_aug_prob,
140
+ calculate_fid_every = calculate_fid_every,
141
+ calculate_fid_num_images = calculate_fid_num_images,
142
+ clear_fid_cache = clear_fid_cache,
143
+ cpu = cpu,
144
+ mixed_precision = mixed_precision,
145
+ wandb = wandb,
146
+ push_to_hub = push_to_hub,
147
+ organization_name = organization_name
148
+ )
149
+
150
+ if generate:
151
+ model = Trainer(**model_args)
152
+ model.load(load_from)
153
+ samples_name = timestamped_filename()
154
+ checkpoint = model.checkpoint_num
155
+ dir_result = model.generate(samples_name, num_image_tiles, checkpoint, generate_types)
156
+ print(f'sample images generated at {dir_result}')
157
+ return
158
+
159
+ if generate_interpolation:
160
+ model = Trainer(**model_args)
161
+ model.load(load_from)
162
+ samples_name = timestamped_filename()
163
+ model.generate_interpolation(samples_name, num_image_tiles, num_steps = interpolation_num_steps, save_frames = save_frames)
164
+ print(f'interpolation generated at {results_dir}/{name}/{samples_name}')
165
+ return
166
+
167
+ if show_progress:
168
+ model = Trainer(**model_args)
169
+ model.show_progress(num_images=num_image_tiles, types=generate_types)
170
+ return
171
+
172
+ run_training(model_args, data, load_from, new, num_train_steps, name, seed)
173
+
174
+ def main():
175
+ fire.Fire(train_from_folder)
176
+
177
+ if __name__ == "__main__":
178
+ main()
huggan/pytorch/lightweight_gan/diff_augment.py ADDED
@@ -0,0 +1,102 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import random
2
+
3
+ import torch
4
+ import torch.nn.functional as F
5
+
6
+
7
+ def DiffAugment(x, types=[]):
8
+ for p in types:
9
+ for f in AUGMENT_FNS[p]:
10
+ x = f(x)
11
+ return x.contiguous()
12
+
13
+
14
+ # """
15
+ # Augmentation functions got images as `x`
16
+ # where `x` is tensor with this dimensions:
17
+ # 0 - count of images
18
+ # 1 - channels
19
+ # 2 - width
20
+ # 3 - height of image
21
+ # """
22
+
23
+ def rand_brightness(x):
24
+ x = x + (torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) - 0.5)
25
+ return x
26
+
27
+ def rand_saturation(x):
28
+ x_mean = x.mean(dim=1, keepdim=True)
29
+ x = (x - x_mean) * (torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) * 2) + x_mean
30
+ return x
31
+
32
+ def rand_contrast(x):
33
+ x_mean = x.mean(dim=[1, 2, 3], keepdim=True)
34
+ x = (x - x_mean) * (torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) + 0.5) + x_mean
35
+ return x
36
+
37
+ def rand_translation(x, ratio=0.125):
38
+ shift_x, shift_y = int(x.size(2) * ratio + 0.5), int(x.size(3) * ratio + 0.5)
39
+ translation_x = torch.randint(-shift_x, shift_x + 1, size=[x.size(0), 1, 1], device=x.device)
40
+ translation_y = torch.randint(-shift_y, shift_y + 1, size=[x.size(0), 1, 1], device=x.device)
41
+ grid_batch, grid_x, grid_y = torch.meshgrid(
42
+ torch.arange(x.size(0), dtype=torch.long, device=x.device),
43
+ torch.arange(x.size(2), dtype=torch.long, device=x.device),
44
+ torch.arange(x.size(3), dtype=torch.long, device=x.device),
45
+ indexing = 'ij')
46
+ grid_x = torch.clamp(grid_x + translation_x + 1, 0, x.size(2) + 1)
47
+ grid_y = torch.clamp(grid_y + translation_y + 1, 0, x.size(3) + 1)
48
+ x_pad = F.pad(x, [1, 1, 1, 1, 0, 0, 0, 0])
49
+ x = x_pad.permute(0, 2, 3, 1).contiguous()[grid_batch, grid_x, grid_y].permute(0, 3, 1, 2)
50
+ return x
51
+
52
+ def rand_offset(x, ratio=1, ratio_h=1, ratio_v=1):
53
+ w, h = x.size(2), x.size(3)
54
+
55
+ imgs = []
56
+ for img in x.unbind(dim = 0):
57
+ max_h = int(w * ratio * ratio_h)
58
+ max_v = int(h * ratio * ratio_v)
59
+
60
+ value_h = random.randint(0, max_h) * 2 - max_h
61
+ value_v = random.randint(0, max_v) * 2 - max_v
62
+
63
+ if abs(value_h) > 0:
64
+ img = torch.roll(img, value_h, 2)
65
+
66
+ if abs(value_v) > 0:
67
+ img = torch.roll(img, value_v, 1)
68
+
69
+ imgs.append(img)
70
+
71
+ return torch.stack(imgs)
72
+
73
+ def rand_offset_h(x, ratio=1):
74
+ return rand_offset(x, ratio=1, ratio_h=ratio, ratio_v=0)
75
+
76
+ def rand_offset_v(x, ratio=1):
77
+ return rand_offset(x, ratio=1, ratio_h=0, ratio_v=ratio)
78
+
79
+ def rand_cutout(x, ratio=0.5):
80
+ cutout_size = int(x.size(2) * ratio + 0.5), int(x.size(3) * ratio + 0.5)
81
+ offset_x = torch.randint(0, x.size(2) + (1 - cutout_size[0] % 2), size=[x.size(0), 1, 1], device=x.device)
82
+ offset_y = torch.randint(0, x.size(3) + (1 - cutout_size[1] % 2), size=[x.size(0), 1, 1], device=x.device)
83
+ grid_batch, grid_x, grid_y = torch.meshgrid(
84
+ torch.arange(x.size(0), dtype=torch.long, device=x.device),
85
+ torch.arange(cutout_size[0], dtype=torch.long, device=x.device),
86
+ torch.arange(cutout_size[1], dtype=torch.long, device=x.device),
87
+ indexing = 'ij')
88
+ grid_x = torch.clamp(grid_x + offset_x - cutout_size[0] // 2, min=0, max=x.size(2) - 1)
89
+ grid_y = torch.clamp(grid_y + offset_y - cutout_size[1] // 2, min=0, max=x.size(3) - 1)
90
+ mask = torch.ones(x.size(0), x.size(2), x.size(3), dtype=x.dtype, device=x.device)
91
+ mask[grid_batch, grid_x, grid_y] = 0
92
+ x = x * mask.unsqueeze(1)
93
+ return x
94
+
95
+ AUGMENT_FNS = {
96
+ 'color': [rand_brightness, rand_saturation, rand_contrast],
97
+ 'offset': [rand_offset],
98
+ 'offset_h': [rand_offset_h],
99
+ 'offset_v': [rand_offset_v],
100
+ 'translation': [rand_translation],
101
+ 'cutout': [rand_cutout],
102
+ }
huggan/pytorch/lightweight_gan/lightweight_gan.py ADDED
@@ -0,0 +1,1598 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import tempfile
4
+ from random import random
5
+ import math
6
+ from math import log2, floor
7
+ from pathlib import Path
8
+ from functools import partial
9
+ from contextlib import contextmanager, ExitStack
10
+ from pathlib import Path
11
+ from shutil import rmtree
12
+
13
+ import torch
14
+ from torch.optim import Adam
15
+ from torch import nn, einsum
16
+ import torch.nn.functional as F
17
+ from torch.utils.data import Dataset, DataLoader
18
+ from torch.autograd import grad as torch_grad
19
+
20
+ from PIL import Image
21
+ import torchvision
22
+ from torchvision import transforms
23
+ from torchvision.utils import save_image
24
+ from kornia.filters import filter2d
25
+
26
+ from huggan.pytorch.lightweight_gan.diff_augment import DiffAugment
27
+
28
+ from tqdm import tqdm
29
+ from einops import rearrange, reduce, repeat
30
+
31
+ from datasets import load_dataset
32
+
33
+ from accelerate import Accelerator, DistributedDataParallelKwargs
34
+ from huggingface_hub import hf_hub_download, create_repo
35
+
36
+ from huggan.pytorch.huggan_mixin import HugGANModelHubMixin
37
+ from huggan.utils.hub import get_full_repo_name
38
+
39
+ # constants
40
+
41
+ # NUM_CORES = multiprocessing.cpu_count()
42
+ EXTS = ['jpg', 'jpeg', 'png']
43
+ PYTORCH_WEIGHTS_NAME = 'model.pt'
44
+
45
+
46
+ # helpers
47
+
48
+ def exists(val):
49
+ return val is not None
50
+
51
+
52
+ @contextmanager
53
+ def null_context():
54
+ yield
55
+
56
+
57
+ def is_power_of_two(val):
58
+ return log2(val).is_integer()
59
+
60
+
61
+ def default(val, d):
62
+ return val if exists(val) else d
63
+
64
+
65
+ def set_requires_grad(model, bool):
66
+ for p in model.parameters():
67
+ p.requires_grad = bool
68
+
69
+
70
+ def cycle(iterable):
71
+ while True:
72
+ for i in iterable:
73
+ yield i
74
+
75
+
76
+ def raise_if_nan(t):
77
+ if torch.isnan(t):
78
+ raise NanException
79
+
80
+
81
+ def evaluate_in_chunks(max_batch_size, model, *args):
82
+ split_args = list(zip(*list(map(lambda x: x.split(max_batch_size, dim=0), args))))
83
+ chunked_outputs = [model(*i) for i in split_args]
84
+ if len(chunked_outputs) == 1:
85
+ return chunked_outputs[0]
86
+ return torch.cat(chunked_outputs, dim=0)
87
+
88
+
89
+ def slerp(val, low, high):
90
+ low_norm = low / torch.norm(low, dim=1, keepdim=True)
91
+ high_norm = high / torch.norm(high, dim=1, keepdim=True)
92
+ omega = torch.acos((low_norm * high_norm).sum(1))
93
+ so = torch.sin(omega)
94
+ res = (torch.sin((1.0 - val) * omega) / so).unsqueeze(1) * low + (torch.sin(val * omega) / so).unsqueeze(1) * high
95
+ return res
96
+
97
+
98
+ def safe_div(n, d):
99
+ try:
100
+ res = n / d
101
+ except ZeroDivisionError:
102
+ prefix = '' if int(n >= 0) else '-'
103
+ res = float(f'{prefix}inf')
104
+ return res
105
+
106
+
107
+ # loss functions
108
+
109
+ def gen_hinge_loss(fake, real):
110
+ return fake.mean()
111
+
112
+
113
+ def hinge_loss(real, fake):
114
+ return (F.relu(1 + real) + F.relu(1 - fake)).mean()
115
+
116
+
117
+ def dual_contrastive_loss(real_logits, fake_logits):
118
+ device = real_logits.device
119
+ real_logits, fake_logits = map(lambda t: rearrange(t, '... -> (...)'), (real_logits, fake_logits))
120
+
121
+ def loss_half(t1, t2):
122
+ t1 = rearrange(t1, 'i -> i ()')
123
+ t2 = repeat(t2, 'j -> i j', i=t1.shape[0])
124
+ t = torch.cat((t1, t2), dim=-1)
125
+ return F.cross_entropy(t, torch.zeros(t1.shape[0], device=device, dtype=torch.long))
126
+
127
+ return loss_half(real_logits, fake_logits) + loss_half(-fake_logits, -real_logits)
128
+
129
+
130
+ # helper classes
131
+
132
+ class NanException(Exception):
133
+ pass
134
+
135
+
136
+ class EMA():
137
+ def __init__(self, beta):
138
+ super().__init__()
139
+ self.beta = beta
140
+
141
+ def update_average(self, old, new):
142
+ if not exists(old):
143
+ return new
144
+ return old * self.beta + (1 - self.beta) * new
145
+
146
+
147
+ class RandomApply(nn.Module):
148
+ def __init__(self, prob, fn, fn_else=lambda x: x):
149
+ super().__init__()
150
+ self.fn = fn
151
+ self.fn_else = fn_else
152
+ self.prob = prob
153
+
154
+ def forward(self, x):
155
+ fn = self.fn if random() < self.prob else self.fn_else
156
+ return fn(x)
157
+
158
+
159
+ class ChanNorm(nn.Module):
160
+ def __init__(self, dim, eps=1e-5):
161
+ super().__init__()
162
+ self.eps = eps
163
+ self.g = nn.Parameter(torch.ones(1, dim, 1, 1))
164
+ self.b = nn.Parameter(torch.zeros(1, dim, 1, 1))
165
+
166
+ def forward(self, x):
167
+ var = torch.var(x, dim=1, unbiased=False, keepdim=True)
168
+ mean = torch.mean(x, dim=1, keepdim=True)
169
+ return (x - mean) / (var + self.eps).sqrt() * self.g + self.b
170
+
171
+
172
+ class PreNorm(nn.Module):
173
+ def __init__(self, dim, fn):
174
+ super().__init__()
175
+ self.fn = fn
176
+ self.norm = ChanNorm(dim)
177
+
178
+ def forward(self, x):
179
+ return self.fn(self.norm(x))
180
+
181
+
182
+ class Residual(nn.Module):
183
+ def __init__(self, fn):
184
+ super().__init__()
185
+ self.fn = fn
186
+
187
+ def forward(self, x):
188
+ return self.fn(x) + x
189
+
190
+
191
+ class SumBranches(nn.Module):
192
+ def __init__(self, branches):
193
+ super().__init__()
194
+ self.branches = nn.ModuleList(branches)
195
+
196
+ def forward(self, x):
197
+ return sum(map(lambda fn: fn(x), self.branches))
198
+
199
+
200
+ class Fuzziness(nn.Module):
201
+ def __init__(self):
202
+ super().__init__()
203
+ f = torch.Tensor([1, 2, 1])
204
+ self.register_buffer('f', f)
205
+
206
+ def forward(self, x):
207
+ f = self.f
208
+ f = f[None, None, :] * f[None, :, None]
209
+ return filter2d(x, f, normalized=True)
210
+
211
+
212
+ Blur = nn.Identity
213
+
214
+
215
+ # attention
216
+
217
+ class DepthWiseConv2d(nn.Module):
218
+ def __init__(self, dim_in, dim_out, kernel_size, padding=0, stride=1, bias=True):
219
+ super().__init__()
220
+ self.net = nn.Sequential(
221
+ nn.Conv2d(dim_in, dim_in, kernel_size=kernel_size, padding=padding, groups=dim_in, stride=stride,
222
+ bias=bias),
223
+ nn.Conv2d(dim_in, dim_out, kernel_size=1, bias=bias)
224
+ )
225
+
226
+ def forward(self, x):
227
+ return self.net(x)
228
+
229
+
230
+ class LinearAttention(nn.Module):
231
+ def __init__(self, dim, dim_head=64, heads=8):
232
+ super().__init__()
233
+ self.scale = dim_head ** -0.5
234
+ self.heads = heads
235
+ inner_dim = dim_head * heads
236
+
237
+ self.nonlin = nn.GELU()
238
+ self.to_q = nn.Conv2d(dim, inner_dim, 1, bias=False)
239
+ self.to_kv = DepthWiseConv2d(dim, inner_dim * 2, 3, padding=1, bias=False)
240
+ self.to_out = nn.Conv2d(inner_dim, dim, 1)
241
+
242
+ def forward(self, fmap):
243
+ h, x, y = self.heads, *fmap.shape[-2:]
244
+ q, k, v = (self.to_q(fmap), *self.to_kv(fmap).chunk(2, dim=1))
245
+ q, k, v = map(lambda t: rearrange(t, 'b (h c) x y -> (b h) (x y) c', h=h), (q, k, v))
246
+
247
+ q = q.softmax(dim=-1)
248
+ k = k.softmax(dim=-2)
249
+
250
+ q = q * self.scale
251
+
252
+ context = einsum('b n d, b n e -> b d e', k, v)
253
+ out = einsum('b n d, b d e -> b n e', q, context)
254
+ out = rearrange(out, '(b h) (x y) d -> b (h d) x y', h=h, x=x, y=y)
255
+
256
+ out = self.nonlin(out)
257
+ return self.to_out(out)
258
+
259
+
260
+ # dataset
261
+
262
+ def convert_image_to(img_type, image):
263
+ if image.mode != img_type:
264
+ return image.convert(img_type)
265
+ return image
266
+
267
+
268
+ class identity(object):
269
+ def __call__(self, tensor):
270
+ return tensor
271
+
272
+
273
+ class expand_greyscale(object):
274
+ def __init__(self, transparent):
275
+ self.transparent = transparent
276
+
277
+ def __call__(self, tensor):
278
+ channels = tensor.shape[0]
279
+ num_target_channels = 4 if self.transparent else 3
280
+
281
+ if channels == num_target_channels:
282
+ return tensor
283
+
284
+ alpha = None
285
+ if channels == 1:
286
+ color = tensor.expand(3, -1, -1)
287
+ elif channels == 2:
288
+ color = tensor[:1].expand(3, -1, -1)
289
+ alpha = tensor[1:]
290
+ else:
291
+ raise Exception(f'image with invalid number of channels given {channels}')
292
+
293
+ if not exists(alpha) and self.transparent:
294
+ alpha = torch.ones(1, *tensor.shape[1:], device=tensor.device)
295
+
296
+ return color if not self.transparent else torch.cat((color, alpha))
297
+
298
+
299
+ def resize_to_minimum_size(min_size, image):
300
+ if max(*image.size) < min_size:
301
+ return torchvision.transforms.functional.resize(image, min_size)
302
+ return image
303
+
304
+
305
+ # augmentations
306
+
307
+ def random_hflip(tensor, prob):
308
+ if prob > random():
309
+ return tensor
310
+ return torch.flip(tensor, dims=(3,))
311
+
312
+
313
+ class AugWrapper(nn.Module):
314
+ def __init__(self, D, image_size):
315
+ super().__init__()
316
+ self.D = D
317
+
318
+ def forward(self, images, prob=0., types=[], detach=False, **kwargs):
319
+ context = torch.no_grad if detach else null_context
320
+
321
+ with context():
322
+ if random() < prob:
323
+ images = random_hflip(images, prob=0.5)
324
+ images = DiffAugment(images, types=types)
325
+
326
+ return self.D(images, **kwargs)
327
+
328
+
329
+ # modifiable global variables
330
+
331
+ norm_class = nn.BatchNorm2d
332
+
333
+
334
+ def upsample(scale_factor=2):
335
+ return nn.Upsample(scale_factor=scale_factor)
336
+
337
+
338
+ # squeeze excitation classes
339
+
340
+ # global context network
341
+ # https://arxiv.org/abs/2012.13375
342
+ # similar to squeeze-excite, but with a simplified attention pooling and a subsequent layer norm
343
+
344
+ class GlobalContext(nn.Module):
345
+ def __init__(
346
+ self,
347
+ *,
348
+ chan_in,
349
+ chan_out
350
+ ):
351
+ super().__init__()
352
+ self.to_k = nn.Conv2d(chan_in, 1, 1)
353
+ chan_intermediate = max(3, chan_out // 2)
354
+
355
+ self.net = nn.Sequential(
356
+ nn.Conv2d(chan_in, chan_intermediate, 1),
357
+ nn.LeakyReLU(0.1),
358
+ nn.Conv2d(chan_intermediate, chan_out, 1),
359
+ nn.Sigmoid()
360
+ )
361
+
362
+ def forward(self, x):
363
+ context = self.to_k(x)
364
+ context = context.flatten(2).softmax(dim=-1)
365
+ out = einsum('b i n, b c n -> b c i', context, x.flatten(2))
366
+ out = out.unsqueeze(-1)
367
+ return self.net(out)
368
+
369
+
370
+ # frequency channel attention
371
+ # https://arxiv.org/abs/2012.11879
372
+
373
+ def get_1d_dct(i, freq, L):
374
+ result = math.cos(math.pi * freq * (i + 0.5) / L) / math.sqrt(L)
375
+ return result * (1 if freq == 0 else math.sqrt(2))
376
+
377
+
378
+ def get_dct_weights(width, channel, fidx_u, fidx_v):
379
+ dct_weights = torch.zeros(1, channel, width, width)
380
+ c_part = channel // len(fidx_u)
381
+
382
+ for i, (u_x, v_y) in enumerate(zip(fidx_u, fidx_v)):
383
+ for x in range(width):
384
+ for y in range(width):
385
+ coor_value = get_1d_dct(x, u_x, width) * get_1d_dct(y, v_y, width)
386
+ dct_weights[:, i * c_part: (i + 1) * c_part, x, y] = coor_value
387
+
388
+ return dct_weights
389
+
390
+
391
+ class FCANet(nn.Module):
392
+ def __init__(
393
+ self,
394
+ *,
395
+ chan_in,
396
+ chan_out,
397
+ reduction=4,
398
+ width
399
+ ):
400
+ super().__init__()
401
+
402
+ freq_w, freq_h = ([0] * 8), list(range(8)) # in paper, it seems 16 frequencies was ideal
403
+ dct_weights = get_dct_weights(width, chan_in, [*freq_w, *freq_h], [*freq_h, *freq_w])
404
+ self.register_buffer('dct_weights', dct_weights)
405
+
406
+ chan_intermediate = max(3, chan_out // reduction)
407
+
408
+ self.net = nn.Sequential(
409
+ nn.Conv2d(chan_in, chan_intermediate, 1),
410
+ nn.LeakyReLU(0.1),
411
+ nn.Conv2d(chan_intermediate, chan_out, 1),
412
+ nn.Sigmoid()
413
+ )
414
+
415
+ def forward(self, x):
416
+ x = reduce(x * self.dct_weights, 'b c (h h1) (w w1) -> b c h1 w1', 'sum', h1=1, w1=1)
417
+ return self.net(x)
418
+
419
+
420
+ # generative adversarial network
421
+
422
+ class Generator(nn.Module):
423
+ def __init__(
424
+ self,
425
+ *,
426
+ image_size,
427
+ latent_dim=256,
428
+ fmap_max=512,
429
+ fmap_inverse_coef=12,
430
+ transparent=False,
431
+ greyscale=False,
432
+ attn_res_layers=[],
433
+ freq_chan_attn=False
434
+ ):
435
+ super().__init__()
436
+ resolution = log2(image_size)
437
+ assert is_power_of_two(image_size), 'image size must be a power of 2'
438
+
439
+ if transparent:
440
+ init_channel = 4
441
+ elif greyscale:
442
+ init_channel = 1
443
+ else:
444
+ init_channel = 3
445
+
446
+ fmap_max = default(fmap_max, latent_dim)
447
+
448
+ self.initial_conv = nn.Sequential(
449
+ nn.ConvTranspose2d(latent_dim, latent_dim * 2, 4),
450
+ norm_class(latent_dim * 2),
451
+ nn.GLU(dim=1)
452
+ )
453
+
454
+ num_layers = int(resolution) - 2
455
+ features = list(map(lambda n: (n, 2 ** (fmap_inverse_coef - n)), range(2, num_layers + 2)))
456
+ features = list(map(lambda n: (n[0], min(n[1], fmap_max)), features))
457
+ features = list(map(lambda n: 3 if n[0] >= 8 else n[1], features))
458
+ features = [latent_dim, *features]
459
+
460
+ in_out_features = list(zip(features[:-1], features[1:]))
461
+
462
+ self.res_layers = range(2, num_layers + 2)
463
+ self.layers = nn.ModuleList([])
464
+ self.res_to_feature_map = dict(zip(self.res_layers, in_out_features))
465
+
466
+ self.sle_map = ((3, 7), (4, 8), (5, 9), (6, 10))
467
+ self.sle_map = list(filter(lambda t: t[0] <= resolution and t[1] <= resolution, self.sle_map))
468
+ self.sle_map = dict(self.sle_map)
469
+
470
+ self.num_layers_spatial_res = 1
471
+
472
+ for (res, (chan_in, chan_out)) in zip(self.res_layers, in_out_features):
473
+ image_width = 2 ** res
474
+
475
+ attn = None
476
+ if image_width in attn_res_layers:
477
+ attn = PreNorm(chan_in, LinearAttention(chan_in))
478
+
479
+ sle = None
480
+ if res in self.sle_map:
481
+ residual_layer = self.sle_map[res]
482
+ sle_chan_out = self.res_to_feature_map[residual_layer - 1][-1]
483
+
484
+ if freq_chan_attn:
485
+ sle = FCANet(
486
+ chan_in=chan_out,
487
+ chan_out=sle_chan_out,
488
+ width=2 ** (res + 1)
489
+ )
490
+ else:
491
+ sle = GlobalContext(
492
+ chan_in=chan_out,
493
+ chan_out=sle_chan_out
494
+ )
495
+
496
+ layer = nn.ModuleList([
497
+ nn.Sequential(
498
+ upsample(),
499
+ Blur(),
500
+ nn.Conv2d(chan_in, chan_out * 2, 3, padding=1),
501
+ norm_class(chan_out * 2),
502
+ nn.GLU(dim=1)
503
+ ),
504
+ sle,
505
+ attn
506
+ ])
507
+ self.layers.append(layer)
508
+
509
+ self.out_conv = nn.Conv2d(features[-1], init_channel, 3, padding=1)
510
+
511
+ def forward(self, x):
512
+ x = rearrange(x, 'b c -> b c () ()')
513
+ x = self.initial_conv(x)
514
+ x = F.normalize(x, dim=1)
515
+
516
+ residuals = dict()
517
+
518
+ for (res, (up, sle, attn)) in zip(self.res_layers, self.layers):
519
+ if exists(attn):
520
+ x = attn(x) + x
521
+
522
+ x = up(x)
523
+
524
+ if exists(sle):
525
+ out_res = self.sle_map[res]
526
+ residual = sle(x)
527
+ residuals[out_res] = residual
528
+
529
+ next_res = res + 1
530
+ if next_res in residuals:
531
+ x = x * residuals[next_res]
532
+
533
+ return self.out_conv(x)
534
+
535
+
536
+ class SimpleDecoder(nn.Module):
537
+ def __init__(
538
+ self,
539
+ *,
540
+ chan_in,
541
+ chan_out=3,
542
+ num_upsamples=4,
543
+ ):
544
+ super().__init__()
545
+
546
+ self.layers = nn.ModuleList([])
547
+ final_chan = chan_out
548
+ chans = chan_in
549
+
550
+ for ind in range(num_upsamples):
551
+ last_layer = ind == (num_upsamples - 1)
552
+ chan_out = chans if not last_layer else final_chan * 2
553
+ layer = nn.Sequential(
554
+ upsample(),
555
+ nn.Conv2d(chans, chan_out, 3, padding=1),
556
+ nn.GLU(dim=1)
557
+ )
558
+ self.layers.append(layer)
559
+ chans //= 2
560
+
561
+ def forward(self, x):
562
+ for layer in self.layers:
563
+ x = layer(x)
564
+ return x
565
+
566
+
567
+ class Discriminator(nn.Module):
568
+ def __init__(
569
+ self,
570
+ *,
571
+ image_size,
572
+ fmap_max=512,
573
+ fmap_inverse_coef=12,
574
+ transparent=False,
575
+ greyscale=False,
576
+ disc_output_size=5,
577
+ attn_res_layers=[]
578
+ ):
579
+ super().__init__()
580
+ resolution = log2(image_size)
581
+ assert is_power_of_two(image_size), 'image size must be a power of 2'
582
+ assert disc_output_size in {1, 5}, 'discriminator output dimensions can only be 5x5 or 1x1'
583
+
584
+ resolution = int(resolution)
585
+
586
+ if transparent:
587
+ init_channel = 4
588
+ elif greyscale:
589
+ init_channel = 1
590
+ else:
591
+ init_channel = 3
592
+
593
+ num_non_residual_layers = max(0, int(resolution) - 8)
594
+ num_residual_layers = 8 - 3
595
+
596
+ non_residual_resolutions = range(min(8, resolution), 2, -1)
597
+ features = list(map(lambda n: (n, 2 ** (fmap_inverse_coef - n)), non_residual_resolutions))
598
+ features = list(map(lambda n: (n[0], min(n[1], fmap_max)), features))
599
+
600
+ if num_non_residual_layers == 0:
601
+ res, _ = features[0]
602
+ features[0] = (res, init_channel)
603
+
604
+ chan_in_out = list(zip(features[:-1], features[1:]))
605
+
606
+ self.non_residual_layers = nn.ModuleList([])
607
+ for ind in range(num_non_residual_layers):
608
+ first_layer = ind == 0
609
+ last_layer = ind == (num_non_residual_layers - 1)
610
+ chan_out = features[0][-1] if last_layer else init_channel
611
+
612
+ self.non_residual_layers.append(nn.Sequential(
613
+ Blur(),
614
+ nn.Conv2d(init_channel, chan_out, 4, stride=2, padding=1),
615
+ nn.LeakyReLU(0.1)
616
+ ))
617
+
618
+ self.residual_layers = nn.ModuleList([])
619
+
620
+ for (res, ((_, chan_in), (_, chan_out))) in zip(non_residual_resolutions, chan_in_out):
621
+ image_width = 2 ** res
622
+
623
+ attn = None
624
+ if image_width in attn_res_layers:
625
+ attn = PreNorm(chan_in, LinearAttention(chan_in))
626
+
627
+ self.residual_layers.append(nn.ModuleList([
628
+ SumBranches([
629
+ nn.Sequential(
630
+ Blur(),
631
+ nn.Conv2d(chan_in, chan_out, 4, stride=2, padding=1),
632
+ nn.LeakyReLU(0.1),
633
+ nn.Conv2d(chan_out, chan_out, 3, padding=1),
634
+ nn.LeakyReLU(0.1)
635
+ ),
636
+ nn.Sequential(
637
+ Blur(),
638
+ nn.AvgPool2d(2),
639
+ nn.Conv2d(chan_in, chan_out, 1),
640
+ nn.LeakyReLU(0.1),
641
+ )
642
+ ]),
643
+ attn
644
+ ]))
645
+
646
+ last_chan = features[-1][-1]
647
+ if disc_output_size == 5:
648
+ self.to_logits = nn.Sequential(
649
+ nn.Conv2d(last_chan, last_chan, 1),
650
+ nn.LeakyReLU(0.1),
651
+ nn.Conv2d(last_chan, 1, 4)
652
+ )
653
+ elif disc_output_size == 1:
654
+ self.to_logits = nn.Sequential(
655
+ Blur(),
656
+ nn.Conv2d(last_chan, last_chan, 3, stride=2, padding=1),
657
+ nn.LeakyReLU(0.1),
658
+ nn.Conv2d(last_chan, 1, 4)
659
+ )
660
+
661
+ self.to_shape_disc_out = nn.Sequential(
662
+ nn.Conv2d(init_channel, 64, 3, padding=1),
663
+ Residual(PreNorm(64, LinearAttention(64))),
664
+ SumBranches([
665
+ nn.Sequential(
666
+ Blur(),
667
+ nn.Conv2d(64, 32, 4, stride=2, padding=1),
668
+ nn.LeakyReLU(0.1),
669
+ nn.Conv2d(32, 32, 3, padding=1),
670
+ nn.LeakyReLU(0.1)
671
+ ),
672
+ nn.Sequential(
673
+ Blur(),
674
+ nn.AvgPool2d(2),
675
+ nn.Conv2d(64, 32, 1),
676
+ nn.LeakyReLU(0.1),
677
+ )
678
+ ]),
679
+ Residual(PreNorm(32, LinearAttention(32))),
680
+ nn.AdaptiveAvgPool2d((4, 4)),
681
+ nn.Conv2d(32, 1, 4)
682
+ )
683
+
684
+ self.decoder1 = SimpleDecoder(chan_in=last_chan, chan_out=init_channel)
685
+ self.decoder2 = SimpleDecoder(chan_in=features[-2][-1], chan_out=init_channel) if resolution >= 9 else None
686
+
687
+ def forward(self, x, calc_aux_loss=False):
688
+ orig_img = x
689
+
690
+ for layer in self.non_residual_layers:
691
+ x = layer(x)
692
+
693
+ layer_outputs = []
694
+
695
+ for (net, attn) in self.residual_layers:
696
+ if exists(attn):
697
+ x = attn(x) + x
698
+
699
+ x = net(x)
700
+ layer_outputs.append(x)
701
+
702
+ out = self.to_logits(x).flatten(1)
703
+
704
+ img_32x32 = F.interpolate(orig_img, size=(32, 32))
705
+ out_32x32 = self.to_shape_disc_out(img_32x32)
706
+
707
+ if not calc_aux_loss:
708
+ return out, out_32x32, None
709
+
710
+ # self-supervised auto-encoding loss
711
+
712
+ layer_8x8 = layer_outputs[-1]
713
+ layer_16x16 = layer_outputs[-2]
714
+
715
+ recon_img_8x8 = self.decoder1(layer_8x8)
716
+
717
+ aux_loss = F.mse_loss(
718
+ recon_img_8x8,
719
+ F.interpolate(orig_img, size=recon_img_8x8.shape[2:])
720
+ )
721
+
722
+ if exists(self.decoder2):
723
+ select_random_quadrant = lambda rand_quadrant, img: \
724
+ rearrange(img, 'b c (m h) (n w) -> (m n) b c h w', m=2, n=2)[rand_quadrant]
725
+ crop_image_fn = partial(select_random_quadrant, floor(random() * 4))
726
+ img_part, layer_16x16_part = map(crop_image_fn, (orig_img, layer_16x16))
727
+
728
+ recon_img_16x16 = self.decoder2(layer_16x16_part)
729
+
730
+ aux_loss_16x16 = F.mse_loss(
731
+ recon_img_16x16,
732
+ F.interpolate(img_part, size=recon_img_16x16.shape[2:])
733
+ )
734
+
735
+ aux_loss = aux_loss + aux_loss_16x16
736
+
737
+ return out, out_32x32, aux_loss
738
+
739
+
740
+ class LightweightGAN(nn.Module, HugGANModelHubMixin):
741
+ def __init__(
742
+ self,
743
+ *,
744
+ latent_dim,
745
+ image_size,
746
+ optimizer="adam",
747
+ fmap_max=512,
748
+ fmap_inverse_coef=12,
749
+ transparent=False,
750
+ greyscale=False,
751
+ disc_output_size=5,
752
+ attn_res_layers=[],
753
+ freq_chan_attn=False,
754
+ ttur_mult=1.,
755
+ lr=2e-4,
756
+ ):
757
+ super().__init__()
758
+
759
+ self.config = {
760
+ 'latent_dim': latent_dim,
761
+ 'image_size': image_size,
762
+ 'optimizer': optimizer,
763
+ 'fmap_max': fmap_max,
764
+ 'fmap_inverse_coef': fmap_inverse_coef,
765
+ 'transparent': transparent,
766
+ 'greyscale': greyscale,
767
+ 'disc_output_size': disc_output_size,
768
+ 'attn_res_layers': attn_res_layers,
769
+ 'freq_chan_attn': freq_chan_attn,
770
+ 'ttur_mult': ttur_mult,
771
+ 'lr': lr
772
+ }
773
+
774
+ self.latent_dim = latent_dim
775
+ self.image_size = image_size
776
+
777
+ G_kwargs = dict(
778
+ image_size=image_size,
779
+ latent_dim=latent_dim,
780
+ fmap_max=fmap_max,
781
+ fmap_inverse_coef=fmap_inverse_coef,
782
+ transparent=transparent,
783
+ greyscale=greyscale,
784
+ attn_res_layers=attn_res_layers,
785
+ freq_chan_attn=freq_chan_attn
786
+ )
787
+
788
+ self.G = Generator(**G_kwargs)
789
+
790
+ self.D = Discriminator(
791
+ image_size=image_size,
792
+ fmap_max=fmap_max,
793
+ fmap_inverse_coef=fmap_inverse_coef,
794
+ transparent=transparent,
795
+ greyscale=greyscale,
796
+ attn_res_layers=attn_res_layers,
797
+ disc_output_size=disc_output_size
798
+ )
799
+
800
+ self.ema_updater = EMA(0.995)
801
+ self.GE = Generator(**G_kwargs)
802
+ set_requires_grad(self.GE, False)
803
+
804
+ if optimizer == "adam":
805
+ self.G_opt = Adam(self.G.parameters(), lr=lr, betas=(0.5, 0.9))
806
+ self.D_opt = Adam(self.D.parameters(), lr=lr * ttur_mult, betas=(0.5, 0.9))
807
+ elif optimizer == "adabelief":
808
+ from adabelief_pytorch import AdaBelief
809
+
810
+ self.G_opt = AdaBelief(self.G.parameters(), lr=lr, betas=(0.5, 0.9))
811
+ self.D_opt = AdaBelief(self.D.parameters(), lr=lr * ttur_mult, betas=(0.5, 0.9))
812
+ else:
813
+ assert False, "No valid optimizer is given"
814
+
815
+ self.apply(self._init_weights)
816
+ self.reset_parameter_averaging()
817
+
818
+ self.D_aug = AugWrapper(self.D, image_size)
819
+
820
+ def _init_weights(self, m):
821
+ if type(m) in {nn.Conv2d, nn.Linear}:
822
+ nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in', nonlinearity='leaky_relu')
823
+
824
+ def EMA(self):
825
+ def update_moving_average(ma_model, current_model):
826
+ for current_params, ma_params in zip(current_model.parameters(), ma_model.parameters()):
827
+ old_weight, up_weight = ma_params.data, current_params.data
828
+ ma_params.data = self.ema_updater.update_average(old_weight, up_weight)
829
+
830
+ for current_buffer, ma_buffer in zip(current_model.buffers(), ma_model.buffers()):
831
+ new_buffer_value = self.ema_updater.update_average(ma_buffer, current_buffer)
832
+ ma_buffer.copy_(new_buffer_value)
833
+
834
+ update_moving_average(self.GE, self.G)
835
+
836
+ def reset_parameter_averaging(self):
837
+ self.GE.load_state_dict(self.G.state_dict())
838
+
839
+ def forward(self, x):
840
+ raise NotImplemented
841
+
842
+ def _save_pretrained(self, save_directory):
843
+ """
844
+ Overwrite this method in case you don't want to save complete model,
845
+ rather some specific layers
846
+ """
847
+ path = os.path.join(save_directory, PYTORCH_WEIGHTS_NAME)
848
+ model_to_save = self.module if hasattr(self, "module") else self
849
+
850
+ # We update this to be a dict containing 'GAN', as that's what is expected
851
+ torch.save({'GAN': model_to_save.state_dict()}, path)
852
+
853
+ @classmethod
854
+ def _from_pretrained(
855
+ cls,
856
+ model_id,
857
+ revision,
858
+ cache_dir,
859
+ force_download,
860
+ proxies,
861
+ resume_download,
862
+ local_files_only,
863
+ token,
864
+ map_location="cpu",
865
+ strict=False,
866
+ **model_kwargs,
867
+ ):
868
+ """
869
+ Overwrite this method in case you wish to initialize your model in a
870
+ different way.
871
+ """
872
+ map_location = torch.device(map_location)
873
+
874
+ if os.path.isdir(model_id):
875
+ print("Loading weights from local directory")
876
+ model_file = os.path.join(model_id, PYTORCH_WEIGHTS_NAME)
877
+ else:
878
+ model_file = hf_hub_download(
879
+ repo_id=model_id,
880
+ filename=PYTORCH_WEIGHTS_NAME,
881
+ revision=revision,
882
+ cache_dir=cache_dir,
883
+ force_download=force_download,
884
+ proxies=proxies,
885
+ resume_download=resume_download,
886
+ token=token,
887
+ local_files_only=local_files_only,
888
+ )
889
+
890
+ # We update here to directly unpack config
891
+ model = cls(**model_kwargs['config'])
892
+
893
+ state_dict = torch.load(model_file, map_location=map_location)
894
+ model.load_state_dict(state_dict["GAN"], strict=strict)
895
+ model.eval()
896
+
897
+ return model
898
+
899
+
900
+ # trainer
901
+
902
+ class Trainer():
903
+ def __init__(
904
+ self,
905
+ dataset_name="huggan/CelebA-faces",
906
+ name='default',
907
+ results_dir='results',
908
+ models_dir='models',
909
+ base_dir='./',
910
+ optimizer='adam',
911
+ latent_dim=256,
912
+ image_size=128,
913
+ num_image_tiles=8,
914
+ fmap_max=512,
915
+ transparent=False,
916
+ greyscale=False,
917
+ batch_size=4,
918
+ gp_weight=10,
919
+ gradient_accumulate_every=1,
920
+ attn_res_layers=[],
921
+ freq_chan_attn=False,
922
+ disc_output_size=5,
923
+ dual_contrast_loss=False,
924
+ antialias=False,
925
+ lr=2e-4,
926
+ lr_mlp=1.,
927
+ ttur_mult=1.,
928
+ save_every=10000,
929
+ evaluate_every=1000,
930
+ aug_prob=None,
931
+ aug_types=['translation', 'cutout'],
932
+ dataset_aug_prob=0.,
933
+ calculate_fid_every=None,
934
+ calculate_fid_num_images=12800,
935
+ clear_fid_cache=False,
936
+ log=False,
937
+ cpu=False,
938
+ mixed_precision="no",
939
+ wandb=False,
940
+ push_to_hub=False,
941
+ organization_name=None,
942
+ *args,
943
+ **kwargs
944
+ ):
945
+ self.GAN_params = [args, kwargs]
946
+ self.GAN = None
947
+
948
+ self.dataset_name = dataset_name
949
+
950
+ self.name = name
951
+
952
+ base_dir = Path(base_dir)
953
+ self.base_dir = base_dir
954
+ self.results_dir = base_dir / results_dir
955
+ self.models_dir = base_dir / models_dir
956
+ self.fid_dir = base_dir / 'fid' / name
957
+
958
+ # Note - in original repo config is private - ".config.json", but here, we make it public
959
+ self.config_path = self.models_dir / name / 'config.json'
960
+
961
+ assert is_power_of_two(image_size), 'image size must be a power of 2 (64, 128, 256, 512, 1024)'
962
+ assert all(map(is_power_of_two,
963
+ attn_res_layers)), 'resolution layers of attention must all be powers of 2 (16, 32, 64, 128, 256, 512)'
964
+
965
+ assert not (
966
+ dual_contrast_loss and disc_output_size > 1), 'discriminator output size cannot be greater than 1 if using dual contrastive loss'
967
+
968
+ self.image_size = image_size
969
+ self.num_image_tiles = num_image_tiles
970
+
971
+ self.latent_dim = latent_dim
972
+ self.fmap_max = fmap_max
973
+ self.transparent = transparent
974
+ self.greyscale = greyscale
975
+
976
+ assert (int(self.transparent) + int(self.greyscale)) < 2, 'you can only set either transparency or greyscale'
977
+
978
+ self.aug_prob = aug_prob
979
+ self.aug_types = aug_types
980
+
981
+ self.lr = lr
982
+ self.optimizer = optimizer
983
+ self.ttur_mult = ttur_mult
984
+ self.batch_size = batch_size
985
+ self.gradient_accumulate_every = gradient_accumulate_every
986
+
987
+ self.gp_weight = gp_weight
988
+
989
+ self.evaluate_every = evaluate_every
990
+ self.save_every = save_every
991
+ self.steps = 0
992
+
993
+ self.attn_res_layers = attn_res_layers
994
+ self.freq_chan_attn = freq_chan_attn
995
+
996
+ self.disc_output_size = disc_output_size
997
+ self.antialias = antialias
998
+
999
+ self.dual_contrast_loss = dual_contrast_loss
1000
+
1001
+ self.d_loss = 0
1002
+ self.g_loss = 0
1003
+ self.last_gp_loss = None
1004
+ self.last_recon_loss = None
1005
+ self.last_fid = None
1006
+
1007
+ self.init_folders()
1008
+
1009
+ self.loader = None
1010
+ self.dataset_aug_prob = dataset_aug_prob
1011
+
1012
+ self.calculate_fid_every = calculate_fid_every
1013
+ self.calculate_fid_num_images = calculate_fid_num_images
1014
+ self.clear_fid_cache = clear_fid_cache
1015
+
1016
+ self.syncbatchnorm = torch.cuda.device_count() > 1 and not cpu
1017
+
1018
+ self.cpu = cpu
1019
+ self.mixed_precision = mixed_precision
1020
+
1021
+ self.wandb = wandb
1022
+
1023
+ self.push_to_hub = push_to_hub
1024
+ self.organization_name = organization_name
1025
+ self.repo_name = get_full_repo_name(self.name, self.organization_name)
1026
+ if self.push_to_hub:
1027
+ self.repo_url = create_repo(self.repo_name, exist_ok=True)
1028
+
1029
+ @property
1030
+ def image_extension(self):
1031
+ return 'jpg' if not self.transparent else 'png'
1032
+
1033
+ @property
1034
+ def checkpoint_num(self):
1035
+ return floor(self.steps // self.save_every)
1036
+
1037
+ def init_GAN(self):
1038
+ args, kwargs = self.GAN_params
1039
+
1040
+ # set some global variables before instantiating GAN
1041
+
1042
+ global norm_class
1043
+ global Blur
1044
+
1045
+ norm_class = nn.SyncBatchNorm if self.syncbatchnorm else nn.BatchNorm2d
1046
+ Blur = nn.Identity if not self.antialias else Fuzziness
1047
+
1048
+ # instantiate GAN
1049
+
1050
+ self.GAN = LightweightGAN(
1051
+ optimizer=self.optimizer,
1052
+ lr=self.lr,
1053
+ latent_dim=self.latent_dim,
1054
+ attn_res_layers=self.attn_res_layers,
1055
+ freq_chan_attn=self.freq_chan_attn,
1056
+ image_size=self.image_size,
1057
+ ttur_mult=self.ttur_mult,
1058
+ fmap_max=self.fmap_max,
1059
+ disc_output_size=self.disc_output_size,
1060
+ transparent=self.transparent,
1061
+ greyscale=self.greyscale,
1062
+ *args,
1063
+ **kwargs
1064
+ )
1065
+
1066
+ def write_config(self):
1067
+ self.config_path.write_text(json.dumps(self.config()))
1068
+
1069
+ def load_config(self):
1070
+ config = self.config() if not self.config_path.exists() else json.loads(self.config_path.read_text())
1071
+ self.image_size = config['image_size']
1072
+ self.transparent = config['transparent']
1073
+ self.syncbatchnorm = config['syncbatchnorm']
1074
+ self.disc_output_size = config['disc_output_size']
1075
+ self.greyscale = config.pop('greyscale', False)
1076
+ self.attn_res_layers = config.pop('attn_res_layers', [])
1077
+ self.freq_chan_attn = config.pop('freq_chan_attn', False)
1078
+ self.optimizer = config.pop('optimizer', 'adam')
1079
+ self.fmap_max = config.pop('fmap_max', 512)
1080
+ del self.GAN
1081
+ self.init_GAN()
1082
+
1083
+ def config(self):
1084
+ return {
1085
+ 'image_size': self.image_size,
1086
+ 'transparent': self.transparent,
1087
+ 'greyscale': self.greyscale,
1088
+ 'syncbatchnorm': self.syncbatchnorm,
1089
+ 'disc_output_size': self.disc_output_size,
1090
+ 'optimizer': self.optimizer,
1091
+ 'attn_res_layers': self.attn_res_layers,
1092
+ 'freq_chan_attn': self.freq_chan_attn
1093
+ }
1094
+
1095
+ def set_data_src(self):
1096
+ # start of using HuggingFace dataset
1097
+ dataset = load_dataset(self.dataset_name)
1098
+
1099
+ if self.transparent:
1100
+ num_channels = 4
1101
+ pillow_mode = 'RGBA'
1102
+ expand_fn = expand_greyscale(self.transparent)
1103
+ elif self.greyscale:
1104
+ num_channels = 1
1105
+ pillow_mode = 'L'
1106
+ expand_fn = identity()
1107
+ else:
1108
+ num_channels = 3
1109
+ pillow_mode = 'RGB'
1110
+ expand_fn = expand_greyscale(self.transparent)
1111
+
1112
+ convert_image_fn = partial(convert_image_to, pillow_mode)
1113
+
1114
+ transform = transforms.Compose([
1115
+ transforms.Lambda(convert_image_fn),
1116
+ transforms.Lambda(partial(resize_to_minimum_size, self.image_size)),
1117
+ transforms.Resize(self.image_size),
1118
+ RandomApply(0., transforms.RandomResizedCrop(self.image_size, scale=(0.5, 1.0), ratio=(0.98, 1.02)),
1119
+ transforms.CenterCrop(self.image_size)),
1120
+ transforms.ToTensor(),
1121
+ transforms.Lambda(expand_fn)
1122
+ ])
1123
+
1124
+ def transform_images(examples):
1125
+ transformed_images = [transform(image.convert("RGB")) for image in examples["image"]]
1126
+
1127
+ examples["image"] = torch.stack(transformed_images)
1128
+
1129
+ return examples
1130
+
1131
+ transformed_dataset = dataset.with_transform(transform_images)
1132
+
1133
+ per_device_batch_size = math.ceil(self.batch_size / self.accelerator.num_processes)
1134
+ dataloader = DataLoader(transformed_dataset["train"], per_device_batch_size, sampler=None, shuffle=False,
1135
+ drop_last=True, pin_memory=True)
1136
+ num_samples = len(transformed_dataset)
1137
+ ## end of HuggingFace dataset
1138
+
1139
+ # Note - in original repo, this is wrapped with cycle, but we will do that after accelerator prepares
1140
+ self.loader = dataloader
1141
+
1142
+ # auto set augmentation prob for user if dataset is detected to be low
1143
+ # num_samples = len(self.dataset)
1144
+ if not exists(self.aug_prob) and num_samples < 1e5:
1145
+ self.aug_prob = min(0.5, (1e5 - num_samples) * 3e-6)
1146
+ print(f'autosetting augmentation probability to {round(self.aug_prob * 100)}%')
1147
+
1148
+ def init_accelerator(self):
1149
+ # Initialize the accelerator. We will let the accelerator handle device placement.
1150
+ ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
1151
+ self.accelerator = Accelerator(kwargs_handlers=[ddp_kwargs], mixed_precision=self.mixed_precision, cpu=self.cpu)
1152
+
1153
+ if self.accelerator.is_local_main_process:
1154
+ # set up Weights and Biases if requested
1155
+ if self.wandb:
1156
+ import wandb
1157
+
1158
+ wandb.init(project=str(self.results_dir).split("/")[-1])
1159
+
1160
+ if not exists(self.GAN):
1161
+ self.init_GAN()
1162
+
1163
+ G = self.GAN.G
1164
+ D = self.GAN.D
1165
+ D_aug = self.GAN.D_aug
1166
+
1167
+ # discriminator loss fn
1168
+
1169
+ self.set_data_src()
1170
+
1171
+ # prepare
1172
+ G, D, D_aug, self.GAN.D_opt, self.GAN.G_opt, self.loader = self.accelerator.prepare(G, D, D_aug, self.GAN.D_opt,
1173
+ self.GAN.G_opt, self.loader)
1174
+ self.loader = cycle(self.loader)
1175
+
1176
+ return G, D, D_aug
1177
+
1178
+ def train(self, G, D, D_aug):
1179
+ assert exists(self.loader), 'You must first initialize the data source with `.set_data_src(<folder of images>)`'
1180
+
1181
+ self.GAN.train()
1182
+ total_disc_loss = torch.zeros([], device=self.accelerator.device)
1183
+ total_gen_loss = torch.zeros([], device=self.accelerator.device)
1184
+
1185
+ batch_size = math.ceil(self.batch_size / self.accelerator.num_processes)
1186
+
1187
+ image_size = self.GAN.image_size
1188
+ latent_dim = self.GAN.latent_dim
1189
+
1190
+ aug_prob = default(self.aug_prob, 0)
1191
+ aug_types = self.aug_types
1192
+ aug_kwargs = {'prob': aug_prob, 'types': aug_types}
1193
+
1194
+ apply_gradient_penalty = self.steps % 4 == 0
1195
+
1196
+ # discriminator loss fn
1197
+
1198
+ if self.dual_contrast_loss:
1199
+ D_loss_fn = dual_contrastive_loss
1200
+ else:
1201
+ D_loss_fn = hinge_loss
1202
+
1203
+ # train discriminator
1204
+
1205
+ self.GAN.D_opt.zero_grad()
1206
+ for i in range(self.gradient_accumulate_every):
1207
+ latents = torch.randn(batch_size, latent_dim, device=self.accelerator.device)
1208
+ image_batch = next(self.loader)["image"]
1209
+ image_batch.requires_grad_()
1210
+
1211
+ with torch.no_grad():
1212
+ generated_images = G(latents)
1213
+
1214
+ fake_output, fake_output_32x32, _ = D_aug(generated_images, detach=True, **aug_kwargs)
1215
+
1216
+ real_output, real_output_32x32, real_aux_loss = D_aug(image_batch, calc_aux_loss=True, **aug_kwargs)
1217
+
1218
+ real_output_loss = real_output
1219
+ fake_output_loss = fake_output
1220
+
1221
+ divergence = D_loss_fn(real_output_loss, fake_output_loss)
1222
+ divergence_32x32 = D_loss_fn(real_output_32x32, fake_output_32x32)
1223
+ disc_loss = divergence + divergence_32x32
1224
+
1225
+ aux_loss = real_aux_loss
1226
+ disc_loss = disc_loss + aux_loss
1227
+
1228
+ if apply_gradient_penalty:
1229
+ outputs = [real_output, real_output_32x32]
1230
+ if self.accelerator.scaler is not None:
1231
+ outputs = list(map(self.accelerator.scaler.scale, outputs))
1232
+
1233
+ scaled_gradients = torch_grad(outputs=outputs, inputs=image_batch,
1234
+ grad_outputs=list(
1235
+ map(lambda t: torch.ones(t.size(), device=self.accelerator.device),
1236
+ outputs)),
1237
+ create_graph=True, retain_graph=True, only_inputs=True)[0]
1238
+
1239
+ inv_scale = 1.
1240
+ if self.accelerator.scaler is not None:
1241
+ inv_scale = safe_div(1., self.accelerator.scaler.get_scale())
1242
+
1243
+ if inv_scale != float('inf'):
1244
+ gradients = scaled_gradients * inv_scale
1245
+
1246
+ gradients = gradients.reshape(batch_size, -1)
1247
+ gp = self.gp_weight * ((gradients.norm(2, dim=1) - 1) ** 2).mean()
1248
+
1249
+ if not torch.isnan(gp):
1250
+ disc_loss = disc_loss + gp
1251
+ self.last_gp_loss = gp.clone().detach().item()
1252
+
1253
+ # divide loss by gradient accumulation steps since gradients
1254
+ # are accumulated for multiple backward passes in PyTorch
1255
+ disc_loss = disc_loss / self.gradient_accumulate_every
1256
+
1257
+ disc_loss.register_hook(raise_if_nan)
1258
+ self.accelerator.backward(disc_loss)
1259
+ total_disc_loss += divergence
1260
+
1261
+ self.last_recon_loss = aux_loss.item()
1262
+ self.d_loss = float(total_disc_loss.item() / self.gradient_accumulate_every)
1263
+ self.GAN.D_opt.step()
1264
+
1265
+ # generator loss fn
1266
+
1267
+ if self.dual_contrast_loss:
1268
+ G_loss_fn = dual_contrastive_loss
1269
+ G_requires_calc_real = True
1270
+ else:
1271
+ G_loss_fn = gen_hinge_loss
1272
+ G_requires_calc_real = False
1273
+
1274
+ # train generator
1275
+
1276
+ self.GAN.G_opt.zero_grad()
1277
+
1278
+ for i in range(self.gradient_accumulate_every):
1279
+ latents = torch.randn(batch_size, latent_dim, device=self.accelerator.device)
1280
+
1281
+ if G_requires_calc_real:
1282
+ image_batch = next(self.loader)["image"]
1283
+ image_batch.requires_grad_()
1284
+
1285
+ generated_images = G(latents)
1286
+
1287
+ fake_output, fake_output_32x32, _ = D_aug(generated_images, **aug_kwargs)
1288
+ real_output, real_output_32x32, _ = D_aug(image_batch, **aug_kwargs) if G_requires_calc_real else (
1289
+ None, None, None)
1290
+
1291
+ loss = G_loss_fn(fake_output, real_output)
1292
+ loss_32x32 = G_loss_fn(fake_output_32x32, real_output_32x32)
1293
+
1294
+ gen_loss = loss + loss_32x32
1295
+
1296
+ gen_loss = gen_loss / self.gradient_accumulate_every
1297
+
1298
+ gen_loss.register_hook(raise_if_nan)
1299
+ self.accelerator.backward(gen_loss)
1300
+ total_gen_loss += loss
1301
+
1302
+ # divide loss by gradient accumulation steps since gradients
1303
+ # are accumulated for multiple backward passes in PyTorch
1304
+ self.g_loss = float(total_gen_loss.item() / self.gradient_accumulate_every)
1305
+ self.GAN.G_opt.step()
1306
+
1307
+ # calculate moving averages
1308
+ if self.accelerator.is_main_process and self.steps % 10 == 0 and self.steps > 20000:
1309
+ self.GAN.EMA()
1310
+
1311
+ if self.accelerator.is_main_process and self.steps <= 25000 and self.steps % 1000 == 2:
1312
+ self.GAN.reset_parameter_averaging()
1313
+
1314
+ # save from NaN errors
1315
+
1316
+ if any(torch.isnan(l) for l in (total_gen_loss, total_disc_loss)):
1317
+ print(f'NaN detected for generator or discriminator. Loading from checkpoint #{self.checkpoint_num}')
1318
+ self.load(self.checkpoint_num)
1319
+ raise NanException
1320
+
1321
+ del total_disc_loss
1322
+ del total_gen_loss
1323
+
1324
+ # periodically save results
1325
+
1326
+ if self.accelerator.is_main_process:
1327
+ if self.steps % self.save_every == 0:
1328
+ self.save(self.checkpoint_num)
1329
+
1330
+ if self.push_to_hub:
1331
+ with tempfile.TemporaryDirectory() as temp_dir:
1332
+ self.GAN.push_to_hub(temp_dir, self.repo_url, config=self.GAN.config, skip_lfs_files=True)
1333
+
1334
+ if self.steps % self.evaluate_every == 0 or (self.steps % 100 == 0 and self.steps < 20000):
1335
+ self.evaluate(floor(self.steps / self.evaluate_every), num_image_tiles=self.num_image_tiles)
1336
+
1337
+ if exists(self.calculate_fid_every) and self.steps % self.calculate_fid_every == 0 and self.steps != 0:
1338
+ num_batches = math.ceil(self.calculate_fid_num_images / self.batch_size)
1339
+ fid = self.calculate_fid(num_batches)
1340
+ self.last_fid = fid
1341
+
1342
+ with open(str(self.results_dir / self.name / f'fid_scores.txt'), 'a') as f:
1343
+ f.write(f'{self.steps},{fid}\n')
1344
+
1345
+ self.steps += 1
1346
+
1347
+ @torch.no_grad()
1348
+ def evaluate(self, num=0, num_image_tiles=4):
1349
+ self.GAN.eval()
1350
+
1351
+ ext = self.image_extension
1352
+ num_rows = num_image_tiles
1353
+
1354
+ latent_dim = self.GAN.latent_dim
1355
+ image_size = self.GAN.image_size
1356
+
1357
+ # latents and noise
1358
+
1359
+ latents = torch.randn(num_rows ** 2, latent_dim, device=self.accelerator.device)
1360
+
1361
+ # regular
1362
+
1363
+ generated_images = self.generate_(self.GAN.G, latents)
1364
+ file_name = str(self.results_dir / self.name / f'{str(num)}.{ext}')
1365
+ save_image(generated_images, file_name, nrow=num_rows)
1366
+
1367
+ # moving averages
1368
+
1369
+ generated_images = self.generate_(self.GAN.GE.to(self.accelerator.device), latents)
1370
+ file_name_ema = str(self.results_dir / self.name / f'{str(num)}-ema.{ext}')
1371
+ save_image(generated_images, file_name_ema, nrow=num_rows)
1372
+
1373
+ if self.accelerator.is_local_main_process and self.wandb:
1374
+ import wandb
1375
+
1376
+ wandb.log({'generated_examples': wandb.Image(str(file_name))})
1377
+ wandb.log({'generated_examples_ema': wandb.Image(str(file_name_ema))})
1378
+
1379
+ @torch.no_grad()
1380
+ def generate(self, num=0, num_image_tiles=4, checkpoint=None, types=['default', 'ema']):
1381
+ self.GAN.eval()
1382
+
1383
+ latent_dim = self.GAN.latent_dim
1384
+ dir_name = self.name + str('-generated-') + str(checkpoint)
1385
+ dir_full = Path().absolute() / self.results_dir / dir_name
1386
+ ext = self.image_extension
1387
+
1388
+ if not dir_full.exists():
1389
+ os.mkdir(dir_full)
1390
+
1391
+ # regular
1392
+ if 'default' in types:
1393
+ for i in tqdm(range(num_image_tiles), desc='Saving generated default images'):
1394
+ latents = torch.randn(1, latent_dim, device=self.accelerator.device)
1395
+ generated_image = self.generate_(self.GAN.G, latents)
1396
+ path = str(self.results_dir / dir_name / f'{str(num)}-{str(i)}.{ext}')
1397
+ save_image(generated_image[0], path, nrow=1)
1398
+
1399
+ # moving averages
1400
+ if 'ema' in types:
1401
+ for i in tqdm(range(num_image_tiles), desc='Saving generated EMA images'):
1402
+ latents = torch.randn(1, latent_dim, device=self.accelerator.device)
1403
+ generated_image = self.generate_(self.GAN.GE, latents)
1404
+ path = str(self.results_dir / dir_name / f'{str(num)}-{str(i)}-ema.{ext}')
1405
+ save_image(generated_image[0], path, nrow=1)
1406
+
1407
+ return dir_full
1408
+
1409
+ @torch.no_grad()
1410
+ def show_progress(self, num_images=4, types=['default', 'ema']):
1411
+ checkpoints = self.get_checkpoints()
1412
+ assert exists(checkpoints), 'cannot find any checkpoints to create a training progress video for'
1413
+
1414
+ dir_name = self.name + str('-progress')
1415
+ dir_full = Path().absolute() / self.results_dir / dir_name
1416
+ ext = self.image_extension
1417
+ latents = None
1418
+
1419
+ zfill_length = math.ceil(math.log10(len(checkpoints)))
1420
+
1421
+ if not dir_full.exists():
1422
+ os.mkdir(dir_full)
1423
+
1424
+ for checkpoint in tqdm(checkpoints, desc='Generating progress images'):
1425
+ self.load(checkpoint, print_version=False)
1426
+ self.GAN.eval()
1427
+
1428
+ if checkpoint == 0:
1429
+ latents = torch.randn(num_images, self.GAN.latent_dim, self.accelerator.device)
1430
+
1431
+ # regular
1432
+ if 'default' in types:
1433
+ generated_image = self.generate_(self.GAN.G, latents)
1434
+ path = str(self.results_dir / dir_name / f'{str(checkpoint).zfill(zfill_length)}.{ext}')
1435
+ save_image(generated_image, path, nrow=num_images)
1436
+
1437
+ # moving averages
1438
+ if 'ema' in types:
1439
+ generated_image = self.generate_(self.GAN.GE, latents)
1440
+ path = str(self.results_dir / dir_name / f'{str(checkpoint).zfill(zfill_length)}-ema.{ext}')
1441
+ save_image(generated_image, path, nrow=num_images)
1442
+
1443
+ @torch.no_grad()
1444
+ def calculate_fid(self, num_batches):
1445
+ from pytorch_fid import fid_score
1446
+ real_path = self.fid_dir / 'real'
1447
+ fake_path = self.fid_dir / 'fake'
1448
+
1449
+ # remove any existing files used for fid calculation and recreate directories
1450
+ if not real_path.exists() or self.clear_fid_cache:
1451
+ rmtree(real_path, ignore_errors=True)
1452
+ os.makedirs(real_path)
1453
+
1454
+ for batch_num in tqdm(range(num_batches), desc='calculating FID - saving reals'):
1455
+ real_batch = next(self.loader)["image"]
1456
+ for k, image in enumerate(real_batch.unbind(0)):
1457
+ ind = k + batch_num * self.batch_size
1458
+ save_image(image, real_path / f'{ind}.png')
1459
+
1460
+ # generate a bunch of fake images in results / name / fid_fake
1461
+ rmtree(fake_path, ignore_errors=True)
1462
+ os.makedirs(fake_path)
1463
+
1464
+ self.GAN.eval()
1465
+ ext = self.image_extension
1466
+
1467
+ latent_dim = self.GAN.latent_dim
1468
+ image_size = self.GAN.image_size
1469
+
1470
+ for batch_num in tqdm(range(num_batches), desc='calculating FID - saving generated'):
1471
+ # latents and noise
1472
+ latents = torch.randn(self.batch_size, latent_dim, device=self.accelerator.device)
1473
+
1474
+ # moving averages
1475
+ generated_images = self.generate_(self.GAN.GE, latents)
1476
+
1477
+ for j, image in enumerate(generated_images.unbind(0)):
1478
+ ind = j + batch_num * self.batch_size
1479
+ save_image(image, str(fake_path / f'{str(ind)}-ema.{ext}'))
1480
+
1481
+ return fid_score.calculate_fid_given_paths([str(real_path), str(fake_path)], 256, latents.device, 2048)
1482
+
1483
+ @torch.no_grad()
1484
+ def generate_(self, G, style, num_image_tiles=8):
1485
+ generated_images = evaluate_in_chunks(self.batch_size, G, style)
1486
+ return generated_images.clamp_(0., 1.)
1487
+
1488
+ @torch.no_grad()
1489
+ def generate_interpolation(self, num=0, num_image_tiles=8, num_steps=100, save_frames=False):
1490
+ self.GAN.eval()
1491
+ ext = self.image_extension
1492
+ num_rows = num_image_tiles
1493
+
1494
+ latent_dim = self.GAN.latent_dim
1495
+ image_size = self.GAN.image_size
1496
+
1497
+ # latents and noise
1498
+ latents_low = torch.randn(num_rows ** 2, latent_dim, device=self.accelerator.device)
1499
+ latents_high = torch.randn(num_rows ** 2, latent_dim, device=self.accelerator.device)
1500
+
1501
+ ratios = torch.linspace(0., 8., num_steps)
1502
+
1503
+ frames = []
1504
+ for ratio in tqdm(ratios):
1505
+ interp_latents = slerp(ratio, latents_low, latents_high)
1506
+ generated_images = self.generate_(self.GAN.GE, interp_latents)
1507
+ images_grid = torchvision.utils.make_grid(generated_images, nrow=num_rows)
1508
+ pil_image = transforms.ToPILImage()(images_grid.cpu())
1509
+
1510
+ if self.transparent:
1511
+ background = Image.new('RGBA', pil_image.size, (255, 255, 255))
1512
+ pil_image = Image.alpha_composite(background, pil_image)
1513
+
1514
+ frames.append(pil_image)
1515
+
1516
+ frames[0].save(str(self.results_dir / self.name / f'{str(num)}.gif'), save_all=True, append_images=frames[1:],
1517
+ duration=80, loop=0, optimize=True)
1518
+
1519
+ if save_frames:
1520
+ folder_path = (self.results_dir / self.name / f'{str(num)}')
1521
+ folder_path.mkdir(parents=True, exist_ok=True)
1522
+ for ind, frame in enumerate(frames):
1523
+ frame.save(str(folder_path / f'{str(ind)}.{ext}'))
1524
+
1525
+ def print_log(self):
1526
+ data = [
1527
+ ('G', self.g_loss),
1528
+ ('D', self.d_loss),
1529
+ ('GP', self.last_gp_loss),
1530
+ ('SS', self.last_recon_loss),
1531
+ ('FID', self.last_fid)
1532
+ ]
1533
+
1534
+ data = [d for d in data if exists(d[1])]
1535
+ log = ' | '.join(map(lambda n: f'{n[0]}: {n[1]:.2f}', data))
1536
+ print(log)
1537
+
1538
+ if self.accelerator.is_local_main_process:
1539
+ log_dict = {v[0]: v[1] for v in data}
1540
+ if self.wandb:
1541
+ import wandb
1542
+
1543
+ wandb.log(log_dict)
1544
+
1545
+ def model_name(self, num):
1546
+ return str(self.models_dir / self.name / f'model_{num}.pt')
1547
+
1548
+ def init_folders(self):
1549
+ (self.results_dir / self.name).mkdir(parents=True, exist_ok=True)
1550
+ (self.models_dir / self.name).mkdir(parents=True, exist_ok=True)
1551
+
1552
+ def clear(self):
1553
+ rmtree(str(self.models_dir / self.name), True)
1554
+ rmtree(str(self.results_dir / self.name), True)
1555
+ rmtree(str(self.fid_dir), True)
1556
+ rmtree(str(self.config_path), True)
1557
+ self.init_folders()
1558
+
1559
+ def save(self, num):
1560
+ save_data = {
1561
+ 'GAN': self.GAN.state_dict(),
1562
+ }
1563
+
1564
+ torch.save(save_data, self.model_name(num))
1565
+ self.write_config()
1566
+
1567
+ def load(self, num=-1):
1568
+ self.load_config()
1569
+
1570
+ name = num
1571
+ if num == -1:
1572
+ checkpoints = self.get_checkpoints()
1573
+
1574
+ if not exists(checkpoints):
1575
+ return
1576
+
1577
+ name = checkpoints[-1]
1578
+ print(f'continuing from previous epoch - {name}')
1579
+
1580
+ self.steps = name * self.save_every
1581
+
1582
+ load_data = torch.load(self.model_name(name))
1583
+
1584
+ try:
1585
+ self.GAN.load_state_dict(load_data['GAN'])
1586
+ except Exception as e:
1587
+ print(
1588
+ 'unable to load save model. please try downgrading the package to the version specified by the saved model')
1589
+ raise e
1590
+
1591
+ def get_checkpoints(self):
1592
+ file_paths = [p for p in Path(self.models_dir / self.name).glob('model_*.pt')]
1593
+ saved_nums = sorted(map(lambda x: int(x.stem.split('_')[1]), file_paths))
1594
+
1595
+ if len(saved_nums) == 0:
1596
+ return None
1597
+
1598
+ return saved_nums
huggan/pytorch/metrics/README.md ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # GAN metrics
2
+
3
+ In order to track progress 📈 in (un)conditional image generation, a few quantitative metrics have been proposed. Below, we explain the most popular ones. For a more extensive overview, we refer the reader to [Borji, 2021](https://arxiv.org/abs/2103.09396) - which is an up-to-date version of [Borji, 2018](https://arxiv.org/abs/1802.03446). The TLDR is that, despite the use of many popular metrics, objective and comprehensive evaluation of generative models is still an open problem 🤷‍♂️.
4
+
5
+ Quantitative metrics are of course just a proxy of image quality. The most widely used (Inception Score and FID) have several drawbacks [Barratt et al., 2018](https://arxiv.org/abs/1801.01973), [Sajjadi et al., 2018](https://arxiv.org/abs/1806.00035), [Kynkäänniemi et al., 2019](https://arxiv.org/abs/1904.06991).
6
+
7
+ ## Inception score
8
+
9
+ The Inception score was proposed in [Salimans et al., 2016](https://arxiv.org/abs/1606.03498). The authors used a pre-trained Inceptionv3 neural net to classify the images generated by a GAN, and computed a score based on the class probablities of the neural net. The authors claimed that the score correlates well with subjective human evaluation. For an extensive explanation of the metric (as well as an implementation in Numpy and Keras), we refer the reader to [this blog post](https://machinelearningmastery.com/how-to-implement-the-inception-score-from-scratch-for-evaluating-generated-images/#:~:text=The%20Inception%20Score%2C%20or%20IS%20for%20short%2C%20is%20an%20objective,Improved%20Techniques%20for%20Training%20GANs.%E2%80%9D).
10
+
11
+ ## Fréchet Inception Distance (FID)
12
+
13
+ The FID metric was proposed in [Heusel et al., 2018](https://arxiv.org/abs/1706.08500), and is currently the most widely used metric for evaluating image generation. Rather than only evaluating the generated images (as the Inception score), the FID metric compares the generated images to real images.
14
+
15
+ The Fréchet distance meaures the distance between 2 multivariate Gaussian distributions. What does that mean? Concretely, the FID metric uses a pre-trained neural network (the same one as the one of the Inception score, Inceptionv3), and first forwards both real and generated images through it in order to get feature maps. Next, one computes statistics (namely, the mean and standard deviation) of the feature maps for both distributions (generated and real images). Finally, the distance between both distributions is computed based on these statistics.
16
+
17
+ The FID metric assumes that feature maps of a pre-trained neural net extracted on real vs. fake images should be similar (the authors argue that this is a good quantitative metric for assessing image quality, correlating well with human judgement).
18
+
19
+ An important disadvantage of the FID metric is that is has an issue of generalization; a model that simply memorizes the training data can obtain a perfect score on these metrics [Razavi et al., 2019](https://arxiv.org/abs/1906.00446).
20
+
21
+ Variants have been proposed for other modalities, such as the Fréchet Audio Distance [Kilgour et al., 2018](https://arxiv.org/abs/1812.08466) and the Fréchet Video Distance [Unterthiner et al., 2018](https://arxiv.org/abs/1812.01717).
22
+
23
+ The official implementation is in Tensorflow and can be found [here](https://github.com/bioinf-jku/TTUR). A PyTorch implementation can be found [here](https://github.com/mseitzer/pytorch-fid).
24
+
25
+ ## Clean FID
26
+
27
+ In 2021, a paper by [Parmar et al.](https://arxiv.org/abs/2104.11222) indicated that the FID metric is often poorly computed, due to incorrect implementations of low-level image preprocessing (such as resizing of images) in popular frameworks such as PyTorch and TensorFlow. This can produce widely different values for the FID metric.
28
+
29
+ The official implementation of the cleaner FID version can be found [here](https://github.com/GaParmar/clean-fid).
30
+
31
+ Note that FID has many, many other variants including spatial FID (sFID), class-aware FID (CAFD) and conditional FID, Fast FID, Memorization-informed FID (MiFID), Unbiased FID, etc.
32
+
33
+ ## Precision and Recall
34
+
35
+ Despite the FID metric being popular and correlating well with human evaluation, [Sajjadi et al., 2018](https://arxiv.org/abs/1806.00035) pointed out that, due to the fact that the FID score is just a scalar number, it is unable to distinguish between different failure cases. Two generative models could obtain the same FID score while generating images that look entirely different. Hence, the authors proposed a novel approach, defining precision (P) and recall (R) for distributions.
36
+
37
+ Precision measures the similarity of generated instances to the real ones and recall measures the ability of a generator to synthesize all instances found in the training set. Hence, precision measures the quality and recall the coverage.
38
+
39
+ These metrics were then further improved by [Kynkäänniemi et al., 2019](https://arxiv.org/abs/1904.06991).
huggan/pytorch/metrics/__init__.py ADDED
File without changes
huggan/pytorch/metrics/fid_score.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # sources:
2
+ # https://www.kaggle.com/code/ibtesama/gan-in-pytorch-with-fid/notebook
3
+ # https://github.com/mseitzer/pytorch-fid/blob/master/src/pytorch_fid/fid_score.py
4
+
5
+ import numpy as np
6
+ from scipy import linalg
7
+ from torch.nn.functional import adaptive_avg_pool2d
8
+
9
+
10
+ def calculate_activation_statistics(images, model, batch_size=128, dims=2048):
11
+ model.eval()
12
+ act = np.empty((len(images), dims))
13
+
14
+ batch = images
15
+ pred = model(batch)[0]
16
+
17
+ # If model output is not scalar, apply global spatial average pooling.
18
+ # This happens if you choose a dimensionality not equal 2048.
19
+ if pred.size(2) != 1 or pred.size(3) != 1:
20
+ pred = adaptive_avg_pool2d(pred, output_size=(1, 1))
21
+
22
+ act = pred.cpu().data.numpy().reshape(pred.size(0), -1)
23
+
24
+ mu = np.mean(act, axis=0)
25
+ sigma = np.cov(act, rowvar=False)
26
+ return mu, sigma
27
+
28
+
29
+ def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
30
+ """Numpy implementation of the Frechet Distance.
31
+ The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1)
32
+ and X_2 ~ N(mu_2, C_2) is
33
+ d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)).
34
+ """
35
+
36
+ mu1 = np.atleast_1d(mu1)
37
+ mu2 = np.atleast_1d(mu2)
38
+
39
+ sigma1 = np.atleast_2d(sigma1)
40
+ sigma2 = np.atleast_2d(sigma2)
41
+
42
+ assert mu1.shape == mu2.shape, \
43
+ 'Training and test mean vectors have different lengths'
44
+ assert sigma1.shape == sigma2.shape, \
45
+ 'Training and test covariances have different dimensions'
46
+
47
+ diff = mu1 - mu2
48
+
49
+
50
+ covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
51
+ if not np.isfinite(covmean).all():
52
+ msg = ('fid calculation produces singular product; '
53
+ 'adding %s to diagonal of cov estimates') % eps
54
+ print(msg)
55
+ offset = np.eye(sigma1.shape[0]) * eps
56
+ covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))
57
+
58
+
59
+ if np.iscomplexobj(covmean):
60
+ if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
61
+ m = np.max(np.abs(covmean.imag))
62
+ raise ValueError('Imaginary component {}'.format(m))
63
+ covmean = covmean.real
64
+
65
+ tr_covmean = np.trace(covmean)
66
+
67
+ return (diff.dot(diff) + np.trace(sigma1) +
68
+ np.trace(sigma2) - 2 * tr_covmean)
69
+
70
+
71
+ def calculate_fretchet(images_real, images_fake, model):
72
+ """Calculate the fretched distance."""
73
+
74
+ # calculate statistics (mean + std)
75
+ mu_1, std_1 = calculate_activation_statistics(images_real, model)
76
+ mu_2, std_2 = calculate_activation_statistics(images_fake, model)
77
+
78
+ # compute distance
79
+ fid_value = calculate_frechet_distance(mu_1, std_1, mu_2, std_2)
80
+ return fid_value
huggan/pytorch/metrics/inception.py ADDED
@@ -0,0 +1,328 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import torchvision
5
+
6
+ try:
7
+ from torchvision.models.utils import load_state_dict_from_url
8
+ except ImportError:
9
+ from torch.utils.model_zoo import load_url as load_state_dict_from_url
10
+
11
+ # Inception weights ported to Pytorch from
12
+ # http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz
13
+ FID_WEIGHTS_URL = 'https://github.com/mseitzer/pytorch-fid/releases/download/fid_weights/pt_inception-2015-12-05-6726825d.pth' # noqa: E501
14
+
15
+
16
+ class InceptionV3(nn.Module):
17
+ """Pretrained InceptionV3 network returning feature maps"""
18
+
19
+ # Index of default block of inception to return,
20
+ # corresponds to output of final average pooling
21
+ DEFAULT_BLOCK_INDEX = 3
22
+
23
+ # Maps feature dimensionality to their output blocks indices
24
+ BLOCK_INDEX_BY_DIM = {
25
+ 64: 0, # First max pooling features
26
+ 192: 1, # Second max pooling featurs
27
+ 768: 2, # Pre-aux classifier features
28
+ 2048: 3 # Final average pooling features
29
+ }
30
+
31
+ def __init__(self,
32
+ output_blocks=(DEFAULT_BLOCK_INDEX,),
33
+ resize_input=True,
34
+ normalize_input=True,
35
+ requires_grad=False,
36
+ use_fid_inception=True):
37
+ """Build pretrained InceptionV3
38
+
39
+ Parameters
40
+ ----------
41
+ output_blocks : list of int
42
+ Indices of blocks to return features of. Possible values are:
43
+ - 0: corresponds to output of first max pooling
44
+ - 1: corresponds to output of second max pooling
45
+ - 2: corresponds to output which is fed to aux classifier
46
+ - 3: corresponds to output of final average pooling
47
+ resize_input : bool
48
+ If true, bilinearly resizes input to width and height 299 before
49
+ feeding input to model. As the network without fully connected
50
+ layers is fully convolutional, it should be able to handle inputs
51
+ of arbitrary size, so resizing might not be strictly needed
52
+ normalize_input : bool
53
+ If true, scales the input from range (0, 1) to the range the
54
+ pretrained Inception network expects, namely (-1, 1)
55
+ requires_grad : bool
56
+ If true, parameters of the model require gradients. Possibly useful
57
+ for finetuning the network
58
+ use_fid_inception : bool
59
+ If true, uses the pretrained Inception model used in Tensorflow's
60
+ FID implementation. If false, uses the pretrained Inception model
61
+ available in torchvision. The FID Inception model has different
62
+ weights and a slightly different structure from torchvision's
63
+ Inception model. If you want to compute FID scores, you are
64
+ strongly advised to set this parameter to true to get comparable
65
+ results.
66
+ """
67
+ super(InceptionV3, self).__init__()
68
+
69
+ self.resize_input = resize_input
70
+ self.normalize_input = normalize_input
71
+ self.output_blocks = sorted(output_blocks)
72
+ self.last_needed_block = max(output_blocks)
73
+
74
+ assert self.last_needed_block <= 3, \
75
+ 'Last possible output block index is 3'
76
+
77
+ self.blocks = nn.ModuleList()
78
+
79
+ if use_fid_inception:
80
+ inception = fid_inception_v3()
81
+ else:
82
+ inception = _inception_v3(pretrained=True)
83
+
84
+ # Block 0: input to maxpool1
85
+ block0 = [
86
+ inception.Conv2d_1a_3x3,
87
+ inception.Conv2d_2a_3x3,
88
+ inception.Conv2d_2b_3x3,
89
+ nn.MaxPool2d(kernel_size=3, stride=2)
90
+ ]
91
+ self.blocks.append(nn.Sequential(*block0))
92
+
93
+ # Block 1: maxpool1 to maxpool2
94
+ if self.last_needed_block >= 1:
95
+ block1 = [
96
+ inception.Conv2d_3b_1x1,
97
+ inception.Conv2d_4a_3x3,
98
+ nn.MaxPool2d(kernel_size=3, stride=2)
99
+ ]
100
+ self.blocks.append(nn.Sequential(*block1))
101
+
102
+ # Block 2: maxpool2 to aux classifier
103
+ if self.last_needed_block >= 2:
104
+ block2 = [
105
+ inception.Mixed_5b,
106
+ inception.Mixed_5c,
107
+ inception.Mixed_5d,
108
+ inception.Mixed_6a,
109
+ inception.Mixed_6b,
110
+ inception.Mixed_6c,
111
+ inception.Mixed_6d,
112
+ inception.Mixed_6e,
113
+ ]
114
+ self.blocks.append(nn.Sequential(*block2))
115
+
116
+ # Block 3: aux classifier to final avgpool
117
+ if self.last_needed_block >= 3:
118
+ block3 = [
119
+ inception.Mixed_7a,
120
+ inception.Mixed_7b,
121
+ inception.Mixed_7c,
122
+ nn.AdaptiveAvgPool2d(output_size=(1, 1))
123
+ ]
124
+ self.blocks.append(nn.Sequential(*block3))
125
+
126
+ for param in self.parameters():
127
+ param.requires_grad = requires_grad
128
+
129
+ def forward(self, inp):
130
+ """Get Inception feature maps
131
+
132
+ Parameters
133
+ ----------
134
+ inp : torch.autograd.Variable
135
+ Input tensor of shape Bx3xHxW. Values are expected to be in
136
+ range (0, 1)
137
+
138
+ Returns
139
+ -------
140
+ List of torch.autograd.Variable, corresponding to the selected output
141
+ block, sorted ascending by index
142
+ """
143
+ outp = []
144
+ x = inp
145
+
146
+ if self.resize_input:
147
+ x = F.interpolate(x,
148
+ size=(299, 299),
149
+ mode='bilinear',
150
+ align_corners=False)
151
+
152
+ if self.normalize_input:
153
+ x = 2 * x - 1 # Scale from range (0, 1) to range (-1, 1)
154
+
155
+ for idx, block in enumerate(self.blocks):
156
+ x = block(x)
157
+ if idx in self.output_blocks:
158
+ outp.append(x)
159
+
160
+ if idx == self.last_needed_block:
161
+ break
162
+
163
+ return outp
164
+
165
+
166
+ def _inception_v3(*args, **kwargs):
167
+ """Wraps `torchvision.models.inception_v3`
168
+
169
+ Skips default weight inititialization if supported by torchvision version.
170
+ See https://github.com/mseitzer/pytorch-fid/issues/28.
171
+ """
172
+ try:
173
+ version = tuple(map(int, torchvision.__version__.split('.')[:2]))
174
+ except ValueError:
175
+ # Just a caution against weird version strings
176
+ version = (0,)
177
+
178
+ if version >= (0, 6):
179
+ kwargs['init_weights'] = False
180
+
181
+ return torchvision.models.inception_v3(*args, **kwargs)
182
+
183
+
184
+ def fid_inception_v3():
185
+ """Build pretrained Inception model for FID computation
186
+
187
+ The Inception model for FID computation uses a different set of weights
188
+ and has a slightly different structure than torchvision's Inception.
189
+
190
+ This method first constructs torchvision's Inception and then patches the
191
+ necessary parts that are different in the FID Inception model.
192
+ """
193
+ inception = _inception_v3(num_classes=1008,
194
+ aux_logits=False,
195
+ pretrained=False)
196
+ inception.Mixed_5b = FIDInceptionA(192, pool_features=32)
197
+ inception.Mixed_5c = FIDInceptionA(256, pool_features=64)
198
+ inception.Mixed_5d = FIDInceptionA(288, pool_features=64)
199
+ inception.Mixed_6b = FIDInceptionC(768, channels_7x7=128)
200
+ inception.Mixed_6c = FIDInceptionC(768, channels_7x7=160)
201
+ inception.Mixed_6d = FIDInceptionC(768, channels_7x7=160)
202
+ inception.Mixed_6e = FIDInceptionC(768, channels_7x7=192)
203
+ inception.Mixed_7b = FIDInceptionE_1(1280)
204
+ inception.Mixed_7c = FIDInceptionE_2(2048)
205
+
206
+ state_dict = load_state_dict_from_url(FID_WEIGHTS_URL, progress=True)
207
+ inception.load_state_dict(state_dict)
208
+ return inception
209
+
210
+
211
+ class FIDInceptionA(torchvision.models.inception.InceptionA):
212
+ """InceptionA block patched for FID computation"""
213
+ def __init__(self, in_channels, pool_features):
214
+ super(FIDInceptionA, self).__init__(in_channels, pool_features)
215
+
216
+ def forward(self, x):
217
+ branch1x1 = self.branch1x1(x)
218
+
219
+ branch5x5 = self.branch5x5_1(x)
220
+ branch5x5 = self.branch5x5_2(branch5x5)
221
+
222
+ branch3x3dbl = self.branch3x3dbl_1(x)
223
+ branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
224
+ branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl)
225
+
226
+ # Patch: Tensorflow's average pool does not use the padded zero's in
227
+ # its average calculation
228
+ branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1,
229
+ count_include_pad=False)
230
+ branch_pool = self.branch_pool(branch_pool)
231
+
232
+ outputs = [branch1x1, branch5x5, branch3x3dbl, branch_pool]
233
+ return torch.cat(outputs, 1)
234
+
235
+
236
+ class FIDInceptionC(torchvision.models.inception.InceptionC):
237
+ """InceptionC block patched for FID computation"""
238
+ def __init__(self, in_channels, channels_7x7):
239
+ super(FIDInceptionC, self).__init__(in_channels, channels_7x7)
240
+
241
+ def forward(self, x):
242
+ branch1x1 = self.branch1x1(x)
243
+
244
+ branch7x7 = self.branch7x7_1(x)
245
+ branch7x7 = self.branch7x7_2(branch7x7)
246
+ branch7x7 = self.branch7x7_3(branch7x7)
247
+
248
+ branch7x7dbl = self.branch7x7dbl_1(x)
249
+ branch7x7dbl = self.branch7x7dbl_2(branch7x7dbl)
250
+ branch7x7dbl = self.branch7x7dbl_3(branch7x7dbl)
251
+ branch7x7dbl = self.branch7x7dbl_4(branch7x7dbl)
252
+ branch7x7dbl = self.branch7x7dbl_5(branch7x7dbl)
253
+
254
+ # Patch: Tensorflow's average pool does not use the padded zero's in
255
+ # its average calculation
256
+ branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1,
257
+ count_include_pad=False)
258
+ branch_pool = self.branch_pool(branch_pool)
259
+
260
+ outputs = [branch1x1, branch7x7, branch7x7dbl, branch_pool]
261
+ return torch.cat(outputs, 1)
262
+
263
+
264
+ class FIDInceptionE_1(torchvision.models.inception.InceptionE):
265
+ """First InceptionE block patched for FID computation"""
266
+ def __init__(self, in_channels):
267
+ super(FIDInceptionE_1, self).__init__(in_channels)
268
+
269
+ def forward(self, x):
270
+ branch1x1 = self.branch1x1(x)
271
+
272
+ branch3x3 = self.branch3x3_1(x)
273
+ branch3x3 = [
274
+ self.branch3x3_2a(branch3x3),
275
+ self.branch3x3_2b(branch3x3),
276
+ ]
277
+ branch3x3 = torch.cat(branch3x3, 1)
278
+
279
+ branch3x3dbl = self.branch3x3dbl_1(x)
280
+ branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
281
+ branch3x3dbl = [
282
+ self.branch3x3dbl_3a(branch3x3dbl),
283
+ self.branch3x3dbl_3b(branch3x3dbl),
284
+ ]
285
+ branch3x3dbl = torch.cat(branch3x3dbl, 1)
286
+
287
+ # Patch: Tensorflow's average pool does not use the padded zero's in
288
+ # its average calculation
289
+ branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1,
290
+ count_include_pad=False)
291
+ branch_pool = self.branch_pool(branch_pool)
292
+
293
+ outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]
294
+ return torch.cat(outputs, 1)
295
+
296
+
297
+ class FIDInceptionE_2(torchvision.models.inception.InceptionE):
298
+ """Second InceptionE block patched for FID computation"""
299
+ def __init__(self, in_channels):
300
+ super(FIDInceptionE_2, self).__init__(in_channels)
301
+
302
+ def forward(self, x):
303
+ branch1x1 = self.branch1x1(x)
304
+
305
+ branch3x3 = self.branch3x3_1(x)
306
+ branch3x3 = [
307
+ self.branch3x3_2a(branch3x3),
308
+ self.branch3x3_2b(branch3x3),
309
+ ]
310
+ branch3x3 = torch.cat(branch3x3, 1)
311
+
312
+ branch3x3dbl = self.branch3x3dbl_1(x)
313
+ branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
314
+ branch3x3dbl = [
315
+ self.branch3x3dbl_3a(branch3x3dbl),
316
+ self.branch3x3dbl_3b(branch3x3dbl),
317
+ ]
318
+ branch3x3dbl = torch.cat(branch3x3dbl, 1)
319
+
320
+ # Patch: The FID Inception model uses max pooling instead of average
321
+ # pooling. This is likely an error in this specific Inception
322
+ # implementation, as other Inception models use average pooling here
323
+ # (which matches the description in the paper).
324
+ branch_pool = F.max_pool2d(x, kernel_size=3, stride=1, padding=1)
325
+ branch_pool = self.branch_pool(branch_pool)
326
+
327
+ outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]
328
+ return torch.cat(outputs, 1)