Kurokabe commited on
Commit
3be620b
1 Parent(s): 2a06e99

Upload 84 files

Browse files

Add application files

This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .dockerignore +4 -0
  2. .gitattributes +1 -0
  3. .gitignore +208 -0
  4. Dockerfile +21 -0
  5. configs/colab.yaml +50 -0
  6. configs/kny_image.yaml +47 -0
  7. configs/kny_image_full_style.yaml +47 -0
  8. configs/kny_image_full_vgg19.yaml +47 -0
  9. configs/kny_transformer_light.yaml +60 -0
  10. configs/kny_video_gpt2_large.yaml +50 -0
  11. configs/kny_video_gpt2_large_gradio.yaml +50 -0
  12. configs/kny_video_gpt2_medium.yaml +50 -0
  13. configs/kny_video_gpt2_xl.yaml +50 -0
  14. ganime/__main__.py +4 -0
  15. ganime/app.py +212 -0
  16. ganime/configs/__init__.py +0 -0
  17. ganime/configs/model_configs.py +70 -0
  18. ganime/data/__init__.py +0 -0
  19. ganime/data/base.py +282 -0
  20. ganime/data/experimental.py +222 -0
  21. ganime/data/kny.py +19 -0
  22. ganime/data/mnist.py +103 -0
  23. ganime/metrics/image.py +70 -0
  24. ganime/metrics/video.py +98 -0
  25. ganime/model/__init__.py +0 -0
  26. ganime/model/base.py +45 -0
  27. ganime/model/moving_vae.py +126 -0
  28. ganime/model/p2p/__init__.py +0 -0
  29. ganime/model/p2p/p2p.py +543 -0
  30. ganime/model/p2p/p2p_test.py +713 -0
  31. ganime/model/p2p/p2p_v2.py +498 -0
  32. ganime/model/p2p/p2p_v3.py +237 -0
  33. ganime/model/vae/vae.py +98 -0
  34. ganime/model/vq_vae/vq_vae.py +143 -0
  35. ganime/model/vqgan/__init__.py +0 -0
  36. ganime/model/vqgan/discriminator/__init__.py +0 -0
  37. ganime/model/vqgan/discriminator/model.py +64 -0
  38. ganime/model/vqgan/losses/__init__.py +0 -0
  39. ganime/model/vqgan/losses/lpips.py +134 -0
  40. ganime/model/vqgan/losses/vqperceptual.py +47 -0
  41. ganime/model/vqgan/vqgan.py +722 -0
  42. ganime/model/vqgan_clean/__init__.py +0 -0
  43. ganime/model/vqgan_clean/diffusion/__init__.py +0 -0
  44. ganime/model/vqgan_clean/diffusion/decoder.py +115 -0
  45. ganime/model/vqgan_clean/diffusion/encoder.py +125 -0
  46. ganime/model/vqgan_clean/diffusion/layers.py +179 -0
  47. ganime/model/vqgan_clean/discriminator/__init__.py +0 -0
  48. ganime/model/vqgan_clean/discriminator/model.py +88 -0
  49. ganime/model/vqgan_clean/discriminator/model_bkp.py +76 -0
  50. ganime/model/vqgan_clean/experimental/gpt2_embedding.py +1127 -0
.dockerignore ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ .git
2
+ data
3
+ checkpoints
4
+ logs
.gitattributes CHANGED
@@ -32,3 +32,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
32
  *.zip filter=lfs diff=lfs merge=lfs -text
33
  *.zst filter=lfs diff=lfs merge=lfs -text
34
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
32
  *.zip filter=lfs diff=lfs merge=lfs -text
33
  *.zst filter=lfs diff=lfs merge=lfs -text
34
  *tfevents* filter=lfs diff=lfs merge=lfs -text
35
+ models/vgg19/imagenet-vgg-verydeep-19.mat filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,208 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ # Created by https://www.toptal.com/developers/gitignore/api/visualstudiocode,python,venv
3
+ # Edit at https://www.toptal.com/developers/gitignore?templates=visualstudiocode,python,venv
4
+
5
+ ### Python ###
6
+ # Byte-compiled / optimized / DLL files
7
+ __pycache__/
8
+ *.py[cod]
9
+ *$py.class
10
+
11
+ # C extensions
12
+ *.so
13
+
14
+ # Distribution / packaging
15
+ .Python
16
+ build/
17
+ develop-eggs/
18
+ dist/
19
+ downloads/
20
+ eggs/
21
+ .eggs/
22
+ lib/
23
+ lib64/
24
+ parts/
25
+ sdist/
26
+ var/
27
+ wheels/
28
+ share/python-wheels/
29
+ *.egg-info/
30
+ .installed.cfg
31
+ *.egg
32
+ MANIFEST
33
+
34
+ # PyInstaller
35
+ # Usually these files are written by a python script from a template
36
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
37
+ *.manifest
38
+ *.spec
39
+
40
+ # Installer logs
41
+ pip-log.txt
42
+ pip-delete-this-directory.txt
43
+
44
+ # Unit test / coverage reports
45
+ htmlcov/
46
+ .tox/
47
+ .nox/
48
+ .coverage
49
+ .coverage.*
50
+ .cache
51
+ nosetests.xml
52
+ coverage.xml
53
+ *.cover
54
+ *.py,cover
55
+ .hypothesis/
56
+ .pytest_cache/
57
+ cover/
58
+
59
+ # Translations
60
+ *.mo
61
+ *.pot
62
+
63
+ # Django stuff:
64
+ *.log
65
+ local_settings.py
66
+ db.sqlite3
67
+ db.sqlite3-journal
68
+
69
+ # Flask stuff:
70
+ instance/
71
+ .webassets-cache
72
+
73
+ # Scrapy stuff:
74
+ .scrapy
75
+
76
+ # Sphinx documentation
77
+ docs/_build/
78
+
79
+ # PyBuilder
80
+ .pybuilder/
81
+ target/
82
+
83
+ # Jupyter Notebook
84
+ .ipynb_checkpoints
85
+
86
+ # IPython
87
+ profile_default/
88
+ ipython_config.py
89
+
90
+ # pyenv
91
+ # For a library or package, you might want to ignore these files since the code is
92
+ # intended to run in multiple environments; otherwise, check them in:
93
+ # .python-version
94
+
95
+ # pipenv
96
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
97
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
98
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
99
+ # install all needed dependencies.
100
+ #Pipfile.lock
101
+
102
+ # poetry
103
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
104
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
105
+ # commonly ignored for libraries.
106
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
107
+ #poetry.lock
108
+
109
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow
110
+ __pypackages__/
111
+
112
+ # Celery stuff
113
+ celerybeat-schedule
114
+ celerybeat.pid
115
+
116
+ # SageMath parsed files
117
+ *.sage.py
118
+
119
+ # Environments
120
+ .env
121
+ .venv
122
+ env/
123
+ venv/
124
+ ENV/
125
+ env.bak/
126
+ venv.bak/
127
+
128
+ # Spyder project settings
129
+ .spyderproject
130
+ .spyproject
131
+
132
+ # Rope project settings
133
+ .ropeproject
134
+
135
+ # mkdocs documentation
136
+ /site
137
+
138
+ # mypy
139
+ .mypy_cache/
140
+ .dmypy.json
141
+ dmypy.json
142
+
143
+ # Pyre type checker
144
+ .pyre/
145
+
146
+ # pytype static type analyzer
147
+ .pytype/
148
+
149
+ # Cython debug symbols
150
+ cython_debug/
151
+
152
+ # PyCharm
153
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
154
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
155
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
156
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
157
+ #.idea/
158
+
159
+ ### venv ###
160
+ # Virtualenv
161
+ # http://iamzed.com/2009/05/07/a-primer-on-virtualenv/
162
+ [Bb]in
163
+ [Ii]nclude
164
+ [Ll]ib
165
+ [Ll]ib64
166
+ [Ll]ocal
167
+ #[Ss]cripts
168
+ pyvenv.cfg
169
+ pip-selfcheck.json
170
+
171
+ ### VisualStudioCode ###
172
+ .vscode/*
173
+ # !.vscode/settings.json
174
+ # !.vscode/tasks.json
175
+ # !.vscode/launch.json
176
+ # !.vscode/extensions.json
177
+ # !.vscode/*.code-snippets
178
+
179
+ # Local History for Visual Studio Code
180
+ .history/
181
+
182
+ # Built Visual Studio Code Extensions
183
+ *.vsix
184
+
185
+ ### VisualStudioCode Patch ###
186
+ # Ignore all local history of files
187
+ .history
188
+ .ionide
189
+
190
+ # Support for Project snippet scope
191
+
192
+ # End of https://www.toptal.com/developers/gitignore/api/visualstudiocode,python,venv
193
+
194
+ *.npy
195
+ checkpoints/*
196
+ ganime_results/*
197
+ data/*
198
+ *.avi
199
+ *.out
200
+ notebooks/model/p2p_v2/*
201
+ logs/*
202
+ interesting_logs/*
203
+ notebooks/model/vq-gan/train_output/*
204
+ notebooks/model/vq-gan/validation_output/*
205
+ notebooks/model/vq-gan/test_output/*
206
+ *.zip
207
+ flagged/*
208
+ notebooks/model/vq-gan/gpt_kny_light_large_256/*
Dockerfile ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM tensorflow/tensorflow:2.7.0-gpu-jupyter
2
+ # Because of https://developer.nvidia.com/blog/updating-the-cuda-linux-gpg-repository-key/ and https://github.com/NVIDIA/nvidia-docker/issues/1631#issuecomment-1112828208
3
+ RUN rm /etc/apt/sources.list.d/cuda.list
4
+ RUN rm /etc/apt/sources.list.d/nvidia-ml.list
5
+ RUN apt-key del 7fa2af80
6
+ RUN apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2004/x86_64/3bf863cc.pub
7
+ RUN apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/machine-learning/repos/ubuntu2004/x86_64/7fa2af80.pub
8
+
9
+ # Update and install ffmpeg
10
+ RUN apt-get -y update
11
+ RUN apt-get -y upgrade
12
+ RUN apt-get install -y ffmpeg
13
+
14
+ # Setup environment
15
+ WORKDIR /GANime
16
+ ENV PROJECT_DIR=/GANime
17
+ COPY requirements.txt /GANime/requirements.txt
18
+ RUN pip install -r requirements.txt
19
+ COPY . .
20
+ RUN pip install -e .
21
+ EXPOSE 8888
configs/colab.yaml ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model:
2
+ transformer_config:
3
+ #checkpoint_path: GANime/checkpoints/kny_video_full_gpt2_medium/checkpoint
4
+ remaining_frames_method: "own_embeddings"
5
+ transformer_type: "gpt2-medium"
6
+ first_stage_config:
7
+ checkpoint_path: GANime/checkpoints/kny_image_full_vgg19/checkpoint
8
+ vqvae_config:
9
+ beta: 0.25
10
+ num_embeddings: 50257
11
+ embedding_dim: 128
12
+ autoencoder_config:
13
+ z_channels: 512
14
+ channels: 32
15
+ channels_multiplier:
16
+ - 2
17
+ - 4
18
+ - 8
19
+ - 8
20
+ num_res_blocks: 1
21
+ attention_resolution:
22
+ - 16
23
+ resolution: 128
24
+ dropout: 0.0
25
+ discriminator_config:
26
+ num_layers: 3
27
+ filters: 64
28
+
29
+ loss_config:
30
+ discriminator:
31
+ loss: "hinge"
32
+ factor: 1.0
33
+ iter_start: 16200
34
+ weight: 0.3
35
+ vqvae:
36
+ codebook_weight: 1.0
37
+ perceptual_weight: 4.0
38
+ perceptual_loss: "vgg19"
39
+
40
+ train:
41
+ batch_size: 64
42
+ accumulation_size: 1
43
+ n_epochs: 2000
44
+ len_x_train: 8000
45
+ warmup_epoch_percentage: 0.15
46
+ lr_start: 1e-5
47
+ lr_max: 2.5e-4
48
+ perceptual_loss_weight: 1.0
49
+ n_frames_before: 1
50
+ stop_ground_truth_after_epoch: 50
configs/kny_image.yaml ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model:
2
+ checkpoint_path: ../../../checkpoints/kny_image_full_no_disc/checkpoint
3
+ vqvae_config:
4
+ beta: 0.25
5
+ num_embeddings: 50257
6
+ embedding_dim: 128
7
+ autoencoder_config:
8
+ z_channels: 512
9
+ channels: 32
10
+ channels_multiplier:
11
+ - 2
12
+ - 4
13
+ - 8
14
+ - 8
15
+ num_res_blocks: 1
16
+ attention_resolution:
17
+ - 16
18
+ resolution: 128
19
+ dropout: 0.0
20
+ discriminator_config:
21
+ num_layers: 3
22
+ filters: 64
23
+
24
+ loss_config:
25
+ discriminator:
26
+ loss: "hinge"
27
+ factor: 1.0
28
+ iter_start: 5000
29
+ weight: 0.8
30
+ vqvae:
31
+ codebook_weight: 1.0
32
+ perceptual_weight: 4.0
33
+ perceptual_loss: "vgg19" # "vgg16", "vgg19", "style"
34
+
35
+ trainer:
36
+ batch_size: 32
37
+ n_epochs: 10000
38
+ gen_lr: 3e-5
39
+ disc_lr: 3e-5
40
+ gen_beta_1: 0.5
41
+ gen_beta_2: 0.9
42
+ disc_beta_1: 0.5
43
+ disc_beta_2: 0.9
44
+ gen_clip_norm: 1.0
45
+ disc_clip_norm: 1.0
46
+
47
+
configs/kny_image_full_style.yaml ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model:
2
+ checkpoint_path: ../../../checkpoints/kny_image_full_style/checkpoint
3
+ vqvae_config:
4
+ beta: 0.25
5
+ num_embeddings: 50257
6
+ embedding_dim: 128
7
+ autoencoder_config:
8
+ z_channels: 512
9
+ channels: 32
10
+ channels_multiplier:
11
+ - 2
12
+ - 4
13
+ - 8
14
+ - 8
15
+ num_res_blocks: 1
16
+ attention_resolution:
17
+ - 16
18
+ resolution: 128
19
+ dropout: 0.0
20
+ discriminator_config:
21
+ num_layers: 3
22
+ filters: 64
23
+
24
+ loss_config:
25
+ discriminator:
26
+ loss: "hinge"
27
+ factor: 1.0
28
+ iter_start: 50000000
29
+ weight: 0.8
30
+ vqvae:
31
+ codebook_weight: 1.0
32
+ perceptual_weight: 4.0
33
+ perceptual_loss: "style" # "vgg16", "vgg19", "style"
34
+
35
+ trainer:
36
+ batch_size: 32
37
+ n_epochs: 10000
38
+ gen_lr: 8e-5
39
+ disc_lr: 8e-5
40
+ gen_beta_1: 0.5
41
+ gen_beta_2: 0.9
42
+ disc_beta_1: 0.5
43
+ disc_beta_2: 0.9
44
+ gen_clip_norm: 1.0
45
+ disc_clip_norm: 1.0
46
+
47
+
configs/kny_image_full_vgg19.yaml ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model:
2
+ checkpoint_path: ../../../checkpoints/kny_image_full_vgg19/checkpoint
3
+ vqvae_config:
4
+ beta: 0.25
5
+ num_embeddings: 50257
6
+ embedding_dim: 128
7
+ autoencoder_config:
8
+ z_channels: 512
9
+ channels: 32
10
+ channels_multiplier:
11
+ - 2
12
+ - 4
13
+ - 8
14
+ - 8
15
+ num_res_blocks: 1
16
+ attention_resolution:
17
+ - 16
18
+ resolution: 128
19
+ dropout: 0.0
20
+ discriminator_config:
21
+ num_layers: 3
22
+ filters: 64
23
+
24
+ loss_config:
25
+ discriminator:
26
+ loss: "hinge"
27
+ factor: 1.0
28
+ iter_start: 50000000
29
+ weight: 0.8
30
+ vqvae:
31
+ codebook_weight: 1.0
32
+ perceptual_weight: 4.0
33
+ perceptual_loss: "vgg19" # "vgg16", "vgg19", "style"
34
+
35
+ trainer:
36
+ batch_size: 64
37
+ n_epochs: 10000
38
+ gen_lr: 3e-5
39
+ disc_lr: 5e-5
40
+ gen_beta_1: 0.5
41
+ gen_beta_2: 0.9
42
+ disc_beta_1: 0.5
43
+ disc_beta_2: 0.9
44
+ gen_clip_norm: 1.0
45
+ disc_clip_norm: 1.0
46
+
47
+
configs/kny_transformer_light.yaml ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model:
2
+ transformer_config:
3
+ checkpoint_path: ../../../checkpoints/kny_video_light/checkpoint
4
+ # vocab_size: 50257
5
+ # n_positions: 1024
6
+ # n_embd: 1024 #1280 #768
7
+ # n_layer: 24 #36 #12
8
+ # n_head: 16 #20 #12
9
+ # resid_pdrop: 0.1
10
+ # embd_pdrop: 0.1
11
+ # attn_pdrop: 0.1
12
+ # remaining_frames_method: "concat"
13
+ # remaining_frames_method: "token_type_ids"
14
+ remaining_frames_method: "own_embeddings"
15
+ first_stage_config:
16
+ checkpoint_path: ../../../checkpoints/kny_image_light_discriminator/checkpoint
17
+ vqvae_config:
18
+ beta: 0.25
19
+ num_embeddings: 64
20
+ embedding_dim: 256
21
+ autoencoder_config:
22
+ z_channels: 128
23
+ channels: 64
24
+ channels_multiplier:
25
+ - 1
26
+ - 1
27
+ - 2
28
+ - 2
29
+ - 4
30
+ num_res_blocks: 1
31
+ attention_resolution:
32
+ - 16
33
+ resolution: 128
34
+ dropout: 0.0
35
+ discriminator_config:
36
+ num_layers: 3
37
+ filters: 64
38
+
39
+ loss_config:
40
+ discriminator:
41
+ loss: "hinge"
42
+ factor: 1.0
43
+ iter_start: 16200
44
+ weight: 0.3
45
+ vqvae:
46
+ codebook_weight: 1.0
47
+ perceptual_weight: 4.0
48
+ perceptual_loss: "style"
49
+
50
+ train:
51
+ batch_size: 8
52
+ accumulation_size: 8
53
+ n_epochs: 2000
54
+ len_x_train: 631
55
+ warmup_epoch_percentage: 0.15
56
+ lr_start: 1e-5
57
+ lr_max: 2.5e-4
58
+ perceptual_loss_weight: 1.0
59
+ n_frames_before: 5
60
+ stop_ground_truth_after_epoch: 100
configs/kny_video_gpt2_large.yaml ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model:
2
+ transformer_config:
3
+ checkpoint_path: ../../../checkpoints/kny_video_full_gpt2_large_final/checkpoint
4
+ remaining_frames_method: "own_embeddings"
5
+ transformer_type: "gpt2-large"
6
+ first_stage_config:
7
+ checkpoint_path: ../../../checkpoints/kny_image_full_vgg19/checkpoint
8
+ vqvae_config:
9
+ beta: 0.25
10
+ num_embeddings: 50257
11
+ embedding_dim: 128
12
+ autoencoder_config:
13
+ z_channels: 512
14
+ channels: 32
15
+ channels_multiplier:
16
+ - 2
17
+ - 4
18
+ - 8
19
+ - 8
20
+ num_res_blocks: 1
21
+ attention_resolution:
22
+ - 16
23
+ resolution: 128
24
+ dropout: 0.0
25
+ discriminator_config:
26
+ num_layers: 3
27
+ filters: 64
28
+
29
+ loss_config:
30
+ discriminator:
31
+ loss: "hinge"
32
+ factor: 1.0
33
+ iter_start: 16200
34
+ weight: 0.3
35
+ vqvae:
36
+ codebook_weight: 1.0
37
+ perceptual_weight: 4.0
38
+ perceptual_loss: "vgg19"
39
+
40
+ train:
41
+ batch_size: 64
42
+ accumulation_size: 1
43
+ n_epochs: 10000
44
+ len_x_train: 28213
45
+ warmup_epoch_percentage: 0.15
46
+ lr_start: 1e-5
47
+ lr_max: 2.5e-4
48
+ perceptual_loss_weight: 1.0
49
+ n_frames_before: 1
50
+ stop_ground_truth_after_epoch: 1000
configs/kny_video_gpt2_large_gradio.yaml ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model:
2
+ transformer_config:
3
+ checkpoint_path: ./checkpoints/kny_video_full_gpt2_large_final/checkpoint
4
+ remaining_frames_method: "own_embeddings"
5
+ transformer_type: "gpt2-large"
6
+ first_stage_config:
7
+ checkpoint_path: ./checkpoints/kny_image_full_vgg19/checkpoint
8
+ vqvae_config:
9
+ beta: 0.25
10
+ num_embeddings: 50257
11
+ embedding_dim: 128
12
+ autoencoder_config:
13
+ z_channels: 512
14
+ channels: 32
15
+ channels_multiplier:
16
+ - 2
17
+ - 4
18
+ - 8
19
+ - 8
20
+ num_res_blocks: 1
21
+ attention_resolution:
22
+ - 16
23
+ resolution: 128
24
+ dropout: 0.0
25
+ discriminator_config:
26
+ num_layers: 3
27
+ filters: 64
28
+
29
+ loss_config:
30
+ discriminator:
31
+ loss: "hinge"
32
+ factor: 1.0
33
+ iter_start: 16200
34
+ weight: 0.3
35
+ vqvae:
36
+ codebook_weight: 1.0
37
+ perceptual_weight: 4.0
38
+ perceptual_loss: "vgg19"
39
+
40
+ train:
41
+ batch_size: 64
42
+ accumulation_size: 1
43
+ n_epochs: 10000
44
+ len_x_train: 28213
45
+ warmup_epoch_percentage: 0.15
46
+ lr_start: 1e-5
47
+ lr_max: 2.5e-4
48
+ perceptual_loss_weight: 1.0
49
+ n_frames_before: 1
50
+ stop_ground_truth_after_epoch: 1000
configs/kny_video_gpt2_medium.yaml ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model:
2
+ transformer_config:
3
+ checkpoint_path: ./checkpoints/kny_video_full_gpt2_medium/checkpoint
4
+ remaining_frames_method: "own_embeddings"
5
+ transformer_type: "gpt2-medium"
6
+ first_stage_config:
7
+ checkpoint_path: ./checkpoints/kny_image_full_vgg19/checkpoint
8
+ vqvae_config:
9
+ beta: 0.25
10
+ num_embeddings: 50257
11
+ embedding_dim: 128
12
+ autoencoder_config:
13
+ z_channels: 512
14
+ channels: 32
15
+ channels_multiplier:
16
+ - 2
17
+ - 4
18
+ - 8
19
+ - 8
20
+ num_res_blocks: 1
21
+ attention_resolution:
22
+ - 16
23
+ resolution: 128
24
+ dropout: 0.0
25
+ discriminator_config:
26
+ num_layers: 3
27
+ filters: 64
28
+
29
+ loss_config:
30
+ discriminator:
31
+ loss: "hinge"
32
+ factor: 1.0
33
+ iter_start: 16200
34
+ weight: 0.3
35
+ vqvae:
36
+ codebook_weight: 1.0
37
+ perceptual_weight: 4.0
38
+ perceptual_loss: "vgg19"
39
+
40
+ train:
41
+ batch_size: 64
42
+ accumulation_size: 1
43
+ n_epochs: 500
44
+ len_x_train: 28213
45
+ warmup_epoch_percentage: 0.15
46
+ lr_start: 5e-6
47
+ lr_max: 1e-4
48
+ perceptual_loss_weight: 1.0
49
+ n_frames_before: 5
50
+ stop_ground_truth_after_epoch: 200
configs/kny_video_gpt2_xl.yaml ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model:
2
+ transformer_config:
3
+ # checkpoint_path: ../../../checkpoints/kny_video_full_gpt2_xl/checkpoint
4
+ remaining_frames_method: "own_embeddings"
5
+ transformer_type: "gpt2-xl"
6
+ first_stage_config:
7
+ checkpoint_path: ../../../checkpoints/kny_image_full_vgg19/checkpoint
8
+ vqvae_config:
9
+ beta: 0.25
10
+ num_embeddings: 50257
11
+ embedding_dim: 128
12
+ autoencoder_config:
13
+ z_channels: 512
14
+ channels: 32
15
+ channels_multiplier:
16
+ - 2
17
+ - 4
18
+ - 8
19
+ - 8
20
+ num_res_blocks: 1
21
+ attention_resolution:
22
+ - 16
23
+ resolution: 128
24
+ dropout: 0.0
25
+ discriminator_config:
26
+ num_layers: 3
27
+ filters: 64
28
+
29
+ loss_config:
30
+ discriminator:
31
+ loss: "hinge"
32
+ factor: 1.0
33
+ iter_start: 16200
34
+ weight: 0.3
35
+ vqvae:
36
+ codebook_weight: 1.0
37
+ perceptual_weight: 4.0
38
+ perceptual_loss: "vgg19"
39
+
40
+ train:
41
+ batch_size: 64
42
+ accumulation_size: 1
43
+ n_epochs: 500
44
+ len_x_train: 28213
45
+ warmup_epoch_percentage: 0.15
46
+ lr_start: 5e-6
47
+ lr_max: 1e-4
48
+ perceptual_loss_weight: 1.0
49
+ n_frames_before: 1
50
+ stop_ground_truth_after_epoch: 200
ganime/__main__.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ from ganime import app
2
+
3
+ if __name__ == "__main__":
4
+ app.run()
ganime/app.py ADDED
@@ -0,0 +1,212 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ import click
4
+ import omegaconf
5
+ import ray
6
+ from pyprojroot.pyprojroot import here
7
+ from ray import tune
8
+ from ray.train import Trainer
9
+ from ray.tune.schedulers import AsyncHyperBandScheduler
10
+ from ray.tune.suggest import ConcurrencyLimiter
11
+ from ray.tune.suggest.optuna import OptunaSearch
12
+
13
+ from ganime.trainer.ganime import TrainableGANime
14
+
15
+ import os
16
+
17
+ os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" # see issue #152
18
+ os.environ["CUDA_VISIBLE_DEVICES"] = "1, 2, 3, 4, 5, 6"
19
+
20
+
21
+ def get_metric_direction(metric: str):
22
+ if "loss" in metric:
23
+ return "min"
24
+ else:
25
+ raise ValueError(f"Unknown metric: {metric}")
26
+
27
+
28
+ def trial_name_id(trial):
29
+ return f"{trial.trainable_name}"
30
+
31
+
32
+ def trial_dirname_creator(trial):
33
+ return f"{trial.trial_id}"
34
+
35
+
36
+ def get_search_space(model):
37
+ if model == "vqgan":
38
+ return {
39
+ # "beta": tune.uniform(0.1, 1.0),
40
+ "num_embeddings": tune.choice([64, 128, 256]),
41
+ "embedding_dim": tune.choice([128, 256, 512, 1024]),
42
+ "z_channels": tune.choice([64, 128, 256]),
43
+ "channels": tune.choice([64, 128, 256]),
44
+ "channels_multiplier": tune.choice(
45
+ [
46
+ [1, 2, 4],
47
+ [1, 1, 2, 2],
48
+ [1, 2, 2, 4],
49
+ [1, 1, 2, 2, 4],
50
+ ]
51
+ ),
52
+ "attention_resolution": tune.choice([[16], [32], [16, 32]]),
53
+ "batch_size": tune.choice([8, 16]),
54
+ "dropout": tune.choice([0.0, 0.1, 0.2]),
55
+ "weight": tune.quniform(0.1, 1.0, 0.1),
56
+ "codebook_weight": tune.quniform(0.2, 2.0, 0.2),
57
+ "perceptual_weight": tune.quniform(0.5, 5.0, 0.5),
58
+ "gen_lr": tune.qloguniform(1e-5, 1e-3, 1e-5),
59
+ "disc_lr": tune.qloguniform(1e-5, 1e-3, 1e-5),
60
+ "gen_beta_1": tune.quniform(0.5, 0.9, 0.1),
61
+ "gen_beta_2": tune.quniform(0.9, 0.999, 0.001),
62
+ "disc_beta_1": tune.quniform(0.5, 0.9, 0.1),
63
+ "disc_beta_2": tune.quniform(0.9, 0.999, 0.001),
64
+ "gen_clip_norm": tune.choice([1.0, None]),
65
+ "disc_clip_norm": tune.choice([1.0, None]),
66
+ }
67
+ elif model == "gpt":
68
+ return {
69
+ "remaining_frames_method": tune.choice(
70
+ ["concat", "token_type_ids", "own_embeddings"]
71
+ ),
72
+ # "batch_size": tune.choice([8, 16]),
73
+ "lr_max": tune.qloguniform(1e-5, 1e-3, 5e-5),
74
+ "lr_start": tune.sample_from(lambda spec: spec.config.lr_max / 10),
75
+ "perceptual_loss_weight": tune.quniform(0.0, 1.0, 0.1),
76
+ "n_frames_before": tune.randint(1, 10),
77
+ }
78
+
79
+
80
+ def tune_ganime(
81
+ experiment_name: str,
82
+ dataset_name: str,
83
+ config_file: str,
84
+ model: str,
85
+ metric: str,
86
+ epochs: int,
87
+ num_samples: int,
88
+ num_cpus: int,
89
+ num_gpus: int,
90
+ max_concurrent_trials: int,
91
+ ):
92
+
93
+ dataset_path = here("data")
94
+ analysis = tune.run(
95
+ TrainableGANime,
96
+ name=experiment_name,
97
+ search_alg=ConcurrencyLimiter(
98
+ OptunaSearch(), max_concurrent=max_concurrent_trials
99
+ ),
100
+ scheduler=AsyncHyperBandScheduler(max_t=epochs, grace_period=5),
101
+ metric=metric,
102
+ mode=get_metric_direction(metric),
103
+ num_samples=num_samples,
104
+ stop={"training_iteration": epochs},
105
+ local_dir="./ganime_results",
106
+ config={
107
+ "dataset_name": dataset_name,
108
+ "dataset_path": dataset_path,
109
+ "model": model,
110
+ "config_file": config_file,
111
+ "hyperparameters": get_search_space(model),
112
+ },
113
+ resources_per_trial={
114
+ "cpu": num_cpus // max_concurrent_trials,
115
+ "gpu": num_gpus / max_concurrent_trials,
116
+ },
117
+ trial_name_creator=trial_name_id,
118
+ trial_dirname_creator=trial_dirname_creator,
119
+ )
120
+ best_loss = analysis.get_best_config(metric="total_loss", mode="min")
121
+ # best_accuracy = analysis.get_best_config(metric="accuracy", mode="max")
122
+ print(f"Best loss config: {best_loss}")
123
+ # print(f"Best accuracy config: {best_accuracy}")
124
+ return analysis
125
+
126
+
127
+ @click.command()
128
+ @click.option(
129
+ "--dataset",
130
+ type=click.Choice(
131
+ ["moving_mnist_images", "kny_images", "kny_images_light"], case_sensitive=False
132
+ ),
133
+ default="kny_images_light",
134
+ help="Dataset to use",
135
+ )
136
+ @click.option(
137
+ "--model",
138
+ type=click.Choice(["vqgan", "gpt"], case_sensitive=False),
139
+ default="vqgan",
140
+ help="Model to use",
141
+ )
142
+ @click.option(
143
+ "--epochs",
144
+ default=500,
145
+ help="Number of epochs to run",
146
+ )
147
+ @click.option(
148
+ "--num_samples",
149
+ default=100,
150
+ help="Total number of trials to run",
151
+ )
152
+ @click.option(
153
+ "--num_cpus",
154
+ default=64,
155
+ help="Number of cpus to use",
156
+ )
157
+ @click.option(
158
+ "--num_gpus",
159
+ default=6,
160
+ help="Number of gpus to use",
161
+ )
162
+ @click.option(
163
+ "--max_concurrent_trials",
164
+ default=6,
165
+ help="Maximum number of concurrent trials",
166
+ )
167
+ @click.option(
168
+ "--metric",
169
+ type=click.Choice(
170
+ ["total_loss", "reconstruction_loss", "vq_loss", "disc_loss"],
171
+ case_sensitive=False,
172
+ ),
173
+ default="total_loss",
174
+ help="The metric used to select the best trial",
175
+ )
176
+ @click.option(
177
+ "--experiment_name",
178
+ default="kny_images_light_v2",
179
+ help="The name of the experiment for logging in Tensorboard",
180
+ )
181
+ @click.option(
182
+ "--config_file",
183
+ default="kny_image.yaml",
184
+ help="The name of the config file located inside ./config",
185
+ )
186
+ def run(
187
+ experiment_name: str,
188
+ config_file: str,
189
+ dataset: str,
190
+ model: str,
191
+ epochs: int,
192
+ num_samples: int,
193
+ num_cpus: int,
194
+ num_gpus: int,
195
+ max_concurrent_trials: int,
196
+ metric: str,
197
+ ):
198
+ config_file = here(os.path.join("configs", config_file))
199
+
200
+ ray.init(num_cpus=num_cpus, num_gpus=num_gpus)
201
+ tune_ganime(
202
+ experiment_name=experiment_name,
203
+ dataset_name=dataset,
204
+ config_file=config_file,
205
+ model=model,
206
+ epochs=epochs,
207
+ num_samples=num_samples,
208
+ num_cpus=num_cpus,
209
+ num_gpus=num_gpus,
210
+ max_concurrent_trials=max_concurrent_trials,
211
+ metric=metric,
212
+ )
ganime/configs/__init__.py ADDED
File without changes
ganime/configs/model_configs.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+ from typing import List
3
+ try:
4
+ from typing import Literal
5
+ except ImportError:
6
+ from typing_extensions import Literal
7
+
8
+
9
+ @dataclass
10
+ class GPTConfig:
11
+ n_layer: int
12
+ n_head: int
13
+ n_embedding: int
14
+ vocab_size: int
15
+ block_size: int
16
+ embedding_percentage_drop: float
17
+ attention_percentage_drop: float
18
+
19
+
20
+ @dataclass
21
+ class VQVAEConfig:
22
+ beta: float
23
+ num_embeddings: int
24
+ embedding_dim: int
25
+
26
+
27
+ @dataclass
28
+ class AutoencoderConfig:
29
+ z_channels: int
30
+ channels: int
31
+ channels_multiplier: List[int]
32
+ num_res_blocks: int
33
+ attention_resolution: List[int]
34
+ resolution: int
35
+ dropout: float
36
+
37
+
38
+ @dataclass
39
+ class DiscriminatorConfig:
40
+ num_layers: int
41
+ filters: int
42
+
43
+
44
+ @dataclass
45
+ class DiscriminatorLossConfig:
46
+ loss: Literal["hinge, vanilla"]
47
+ factor: float
48
+ iter_start: int
49
+ weight: float
50
+
51
+
52
+ @dataclass
53
+ class VQVAELossConfig:
54
+ codebook_weight: float
55
+ perceptual_weight: float
56
+
57
+
58
+ @dataclass
59
+ class LossConfig:
60
+ discriminator: DiscriminatorLossConfig
61
+ vqvae: VQVAELossConfig
62
+ perceptual_loss: str
63
+
64
+
65
+ @dataclass
66
+ class ModelConfig:
67
+ vqvae_config: VQVAEConfig
68
+ autoencoder_config: AutoencoderConfig
69
+ discriminator_config: DiscriminatorConfig
70
+ loss_config: LossConfig
ganime/data/__init__.py ADDED
File without changes
ganime/data/base.py ADDED
@@ -0,0 +1,282 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Tuple
2
+ import numpy as np
3
+ import tensorflow as tf
4
+ import os
5
+ from tensorflow.keras.utils import Sequence
6
+ from abc import ABC, abstractmethod
7
+ from typing import Literal
8
+ import math
9
+ from ganime.data.experimental import ImageDataset
10
+
11
+
12
+ # class SequenceDataset(Sequence):
13
+ # def __init__(
14
+ # self,
15
+ # dataset_path: str,
16
+ # batch_size: int,
17
+ # split: Literal["train", "validation", "test"] = "train",
18
+ # ):
19
+ # self.batch_size = batch_size
20
+ # self.split = split
21
+ # self.data = self.load_data(dataset_path, split)
22
+ # self.data = self.preprocess_data(self.data)
23
+
24
+ # self.indices = np.arange(self.data.shape[0])
25
+ # self.on_epoch_end()
26
+
27
+ # @abstractmethod
28
+ # def load_data(self, dataset_path: str, split: str) -> np.ndarray:
29
+ # pass
30
+
31
+ # def preprocess_data(self, data: np.ndarray) -> np.ndarray:
32
+ # return data
33
+
34
+ # def __len__(self):
35
+ # return math.ceil(len(self.data) / self.batch_size)
36
+
37
+ # def __getitem__(self, idx):
38
+ # inds = self.indices[idx * self.batch_size : (idx + 1) * self.batch_size]
39
+ # batch_x = self.data[inds]
40
+ # batch_y = batch_x
41
+
42
+ # return batch_x, batch_y
43
+
44
+ # def get_fixed_batch(self, idx):
45
+ # self.fixed_indices = (
46
+ # self.fixed_indices
47
+ # if hasattr(self, "fixed_indices")
48
+ # else self.indices[
49
+ # idx * self.batch_size : (idx + 1) * self.batch_size
50
+ # ].copy()
51
+ # )
52
+ # batch_x = self.data[self.fixed_indices]
53
+ # batch_y = batch_x
54
+
55
+ # return batch_x, batch_y
56
+
57
+ # def on_epoch_end(self):
58
+ # np.random.shuffle(self.indices)
59
+
60
+
61
+ # def load_kny_images(
62
+ # dataset_path: str, batch_size: int
63
+ # ) -> Tuple[tf.data.Dataset, tf.data.Dataset, tuple]:
64
+ # import skvideo.io
65
+
66
+ # if os.path.exists(os.path.join(dataset_path, "kny", "kny_images.npy")):
67
+ # data = np.load(os.path.join(dataset_path, "kny", "kny_images.npy"))
68
+ # else:
69
+ # data = skvideo.io.vread(os.path.join(dataset_path, "kny", "01.mp4"))
70
+ # np.random.shuffle(data)
71
+
72
+ # def _preprocess(sample):
73
+ # image = tf.cast(sample, tf.float32) / 255.0 # Scale to unit interval.
74
+ # # video = video < tf.random.uniform(tf.shape(video)) # Randomly binarize.
75
+ # image = tf.image.resize(image, [64, 64])
76
+
77
+ # return image, image
78
+
79
+ # train_dataset = (
80
+ # tf.data.Dataset.from_tensor_slices(data[:5000])
81
+ # .map(_preprocess)
82
+ # .batch(batch_size)
83
+ # .prefetch(tf.data.AUTOTUNE)
84
+ # .shuffle(int(10e3))
85
+ # )
86
+ # test_dataset = (
87
+ # tf.data.Dataset.from_tensor_slices(data[5000:6000])
88
+ # .map(_preprocess)
89
+ # .batch(batch_size)
90
+ # .prefetch(tf.data.AUTOTUNE)
91
+ # .shuffle(int(10e3))
92
+ # )
93
+
94
+ # return train_dataset, test_dataset, data.shape[1:]
95
+
96
+
97
+ # def load_moving_mnist_vae(
98
+ # dataset_path: str, batch_size: int
99
+ # ) -> Tuple[tf.data.Dataset, tf.data.Dataset, tuple]:
100
+ # data = np.load(os.path.join(dataset_path, "moving_mnist", "mnist_test_seq.npy"))
101
+ # data.shape
102
+
103
+ # # We can see that data is of shape (window, n_samples, width, height)
104
+ # # But we want for keras something of shape (n_samples, window, width, height)
105
+ # data = np.moveaxis(data, 0, 1)
106
+ # # Also expand dimensions to have channels at the end (n_samples, window, width, height, channels)
107
+ # data = np.expand_dims(data, axis=-1)
108
+
109
+ # def _preprocess(sample):
110
+ # video = tf.cast(sample, tf.float32) / 255.0 # Scale to unit interval.
111
+ # # video = video < tf.random.uniform(tf.shape(video)) # Randomly binarize.
112
+ # return video, video
113
+
114
+ # train_dataset = (
115
+ # tf.data.Dataset.from_tensor_slices(data[:9000])
116
+ # .map(_preprocess)
117
+ # .batch(batch_size)
118
+ # .prefetch(tf.data.AUTOTUNE)
119
+ # .shuffle(int(10e3))
120
+ # )
121
+ # test_dataset = (
122
+ # tf.data.Dataset.from_tensor_slices(data[9000:])
123
+ # .map(_preprocess)
124
+ # .batch(batch_size)
125
+ # .prefetch(tf.data.AUTOTUNE)
126
+ # .shuffle(int(10e3))
127
+ # )
128
+
129
+ # return train_dataset, test_dataset, data.shape[1:]
130
+
131
+
132
+ # def load_moving_mnist(
133
+ # dataset_path: str, batch_size: int
134
+ # ) -> Tuple[tf.data.Dataset, tf.data.Dataset, tuple]:
135
+ # data = np.load(os.path.join(dataset_path, "moving_mnist", "mnist_test_seq.npy"))
136
+ # data.shape
137
+
138
+ # # We can see that data is of shape (window, n_samples, width, height)
139
+ # # But we want for keras something of shape (n_samples, window, width, height)
140
+ # data = np.moveaxis(data, 0, 1)
141
+ # # Also expand dimensions to have channels at the end (n_samples, window, width, height, channels)
142
+ # data = np.expand_dims(data, axis=-1)
143
+
144
+ # def _preprocess(sample):
145
+ # video = tf.cast(sample, tf.float32) / 255.0 # Scale to unit interval.
146
+ # # video = video < tf.random.uniform(tf.shape(video)) # Randomly binarize.
147
+ # first_frame = video[0:1, ...]
148
+ # last_frame = video[-1:, ...]
149
+ # first_last = tf.concat([first_frame, last_frame], axis=0)
150
+
151
+ # return first_last, video
152
+
153
+ # train_dataset = (
154
+ # tf.data.Dataset.from_tensor_slices(data[:9000])
155
+ # .map(_preprocess)
156
+ # .batch(batch_size)
157
+ # .prefetch(tf.data.AUTOTUNE)
158
+ # .shuffle(int(10e3))
159
+ # )
160
+ # test_dataset = (
161
+ # tf.data.Dataset.from_tensor_slices(data[9000:])
162
+ # .map(_preprocess)
163
+ # .batch(batch_size)
164
+ # .prefetch(tf.data.AUTOTUNE)
165
+ # .shuffle(int(10e3))
166
+ # )
167
+
168
+ # return train_dataset, test_dataset, data.shape[1:]
169
+
170
+
171
+ # def load_mnist(
172
+ # dataset_path: str, batch_size: int
173
+ # ) -> Tuple[tf.data.Dataset, tf.data.Dataset, tuple]:
174
+ # data = np.load(os.path.join(dataset_path, "moving_mnist", "mnist_test_seq.npy"))
175
+ # data.shape
176
+
177
+ # # We can see that data is of shape (window, n_samples, width, height)
178
+ # # But we want for keras something of shape (n_samples, window, width, height)
179
+ # data = np.moveaxis(data, 0, 1)
180
+ # # Also expand dimensions to have channels at the end (n_samples, window, width, height, channels)
181
+ # data = np.expand_dims(data, axis=-1)
182
+
183
+ # def _preprocess(sample):
184
+ # video = tf.cast(sample, tf.float32) / 255.0 # Scale to unit interval.
185
+ # # video = video < tf.random.uniform(tf.shape(video)) # Randomly binarize.
186
+ # first_frame = video[0, ...]
187
+
188
+ # first_frame = tf.image.grayscale_to_rgb(first_frame)
189
+
190
+ # return first_frame, first_frame
191
+
192
+ # train_dataset = (
193
+ # tf.data.Dataset.from_tensor_slices(data[:9000])
194
+ # .map(_preprocess)
195
+ # .batch(batch_size)
196
+ # .prefetch(tf.data.AUTOTUNE)
197
+ # .shuffle(int(10e3))
198
+ # )
199
+ # test_dataset = (
200
+ # tf.data.Dataset.from_tensor_slices(data[9000:])
201
+ # .map(_preprocess)
202
+ # .batch(batch_size)
203
+ # .prefetch(tf.data.AUTOTUNE)
204
+ # .shuffle(int(10e3))
205
+ # )
206
+
207
+ # return train_dataset, test_dataset, data.shape[1:]
208
+ def preprocess_image(element):
209
+ element = tf.reshape(element, (tf.shape(element)[0], tf.shape(element)[1], 3))
210
+ element = tf.cast(element, tf.float32) / 255.0
211
+ return element, element
212
+
213
+
214
+ def load_kny_images_light(dataset_path, batch_size):
215
+ dataset_length = 34045
216
+ path = os.path.join(dataset_path, "kny", "images_tfrecords_light")
217
+ dataset = ImageDataset(path).load()
218
+ dataset = dataset.shuffle(
219
+ dataset_length, reshuffle_each_iteration=True, seed=10
220
+ ).map(preprocess_image, num_parallel_calls=tf.data.AUTOTUNE)
221
+
222
+ train_size = int(dataset_length * 0.8)
223
+ validation_size = int(dataset_length * 0.1)
224
+
225
+ train_ds = dataset.take(train_size)
226
+ validation_ds = dataset.skip(train_size).take(validation_size)
227
+ test_ds = dataset.skip(train_size + validation_size).take(validation_size)
228
+
229
+ train_ds = train_ds.batch(batch_size, drop_remainder=True).prefetch(
230
+ tf.data.AUTOTUNE
231
+ )
232
+ validation_ds = validation_ds.batch(batch_size, drop_remainder=True).prefetch(
233
+ tf.data.AUTOTUNE
234
+ )
235
+ test_ds = test_ds.batch(batch_size, drop_remainder=True).prefetch(tf.data.AUTOTUNE)
236
+
237
+ return train_ds, validation_ds, test_ds
238
+
239
+
240
+ def load_kny_images(dataset_path, batch_size):
241
+ dataset_length = 52014
242
+ path = os.path.join(dataset_path, "kny", "images_tfrecords")
243
+ dataset = ImageDataset(path).load()
244
+ dataset = dataset.shuffle(dataset_length, reshuffle_each_iteration=True).map(
245
+ preprocess_image, num_parallel_calls=tf.data.AUTOTUNE
246
+ )
247
+
248
+ train_size = int(dataset_length * 0.8)
249
+ validation_size = int(dataset_length * 0.1)
250
+
251
+ train_ds = dataset.take(train_size)
252
+ validation_ds = dataset.skip(train_size).take(validation_size)
253
+ test_ds = dataset.skip(train_size + validation_size).take(validation_size)
254
+
255
+ train_ds = train_ds.batch(batch_size, drop_remainder=True).prefetch(
256
+ tf.data.AUTOTUNE
257
+ )
258
+ validation_ds = validation_ds.batch(batch_size, drop_remainder=True).prefetch(
259
+ tf.data.AUTOTUNE
260
+ )
261
+ test_ds = test_ds.batch(batch_size, drop_remainder=True).prefetch(tf.data.AUTOTUNE)
262
+
263
+ return train_ds, validation_ds, test_ds
264
+
265
+
266
+ def load_dataset(
267
+ dataset_name: str, dataset_path: str, batch_size: int
268
+ ) -> Tuple[tf.data.Dataset, tf.data.Dataset, tf.data.Dataset]:
269
+ # if dataset_name == "moving_mnist_vae":
270
+ # return load_moving_mnist_vae(dataset_path, batch_size)
271
+ # elif dataset_name == "moving_mnist":
272
+ # return load_moving_mnist(dataset_path, batch_size)
273
+ # elif dataset_name == "mnist":
274
+ # return load_mnist(dataset_path, batch_size)
275
+ # elif dataset_name == "kny_images":
276
+ # return load_kny_images(dataset_path, batch_size)
277
+ if dataset_name == "kny_images":
278
+ return load_kny_images(dataset_path, batch_size)
279
+ if dataset_name == "kny_images_light":
280
+ return load_kny_images_light(dataset_path, batch_size)
281
+ else:
282
+ raise ValueError(f"Unknown dataset: {dataset_name}")
ganime/data/experimental.py ADDED
@@ -0,0 +1,222 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from abc import ABC, abstractclassmethod, abstractmethod
2
+ import glob
3
+ import math
4
+ import os
5
+ from typing import Dict
6
+ from typing_extensions import dataclass_transform
7
+
8
+ import numpy as np
9
+ import tensorflow as tf
10
+ from tqdm.auto import tqdm
11
+
12
+
13
+ def _bytes_feature(value):
14
+ """Returns a bytes_list from a string / byte."""
15
+ if isinstance(value, type(tf.constant(0))): # if value ist tensor
16
+ value = value.numpy() # get value of tensor
17
+ return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
18
+
19
+
20
+ def _float_feature(value):
21
+ """Returns a floast_list from a float / double."""
22
+ return tf.train.Feature(float_list=tf.train.FloatList(value=[value]))
23
+
24
+
25
+ def _int64_feature(value):
26
+ """Returns an int64_list from a bool / enum / int / uint."""
27
+ return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
28
+
29
+
30
+ def serialize_array(array):
31
+ array = tf.io.serialize_tensor(array)
32
+ return array
33
+
34
+
35
+ class Dataset(ABC):
36
+ def __init__(self, dataset_path: str):
37
+ self.dataset_path = dataset_path
38
+
39
+ @classmethod
40
+ def _parse_single_element(cls, element) -> tf.train.Example:
41
+
42
+ features = tf.train.Features(feature=cls._get_features(element))
43
+
44
+ return tf.train.Example(features=features)
45
+
46
+ @abstractclassmethod
47
+ def _get_features(cls, element) -> Dict[str, tf.train.Feature]:
48
+ pass
49
+
50
+ @abstractclassmethod
51
+ def _parse_tfr_element(cls, element):
52
+ pass
53
+
54
+ @classmethod
55
+ def write_to_tfr(cls, data: np.ndarray, out_dir: str, filename: str):
56
+ if not os.path.exists(out_dir):
57
+ os.makedirs(out_dir)
58
+
59
+ # Write all elements to a single tfrecord file
60
+ single_file_name = cls.__write_to_single_tfr(data, out_dir, filename)
61
+
62
+ # The optimal size for a single tfrecord file is around 100 MB. Get the number of files that need to be created
63
+ number_splits = cls.__get_number_splits(single_file_name)
64
+
65
+ if number_splits > 1:
66
+ os.remove(single_file_name)
67
+ cls.__write_to_multiple_tfr(data, out_dir, filename, number_splits)
68
+
69
+ @classmethod
70
+ def __write_to_multiple_tfr(
71
+ cls, data: np.array, out_dir: str, filename: str, n_splits: int
72
+ ):
73
+
74
+ file_count = 0
75
+
76
+ max_files = math.ceil(data.shape[0] / n_splits)
77
+
78
+ print(f"Creating {n_splits} files with {max_files} elements each.")
79
+
80
+ for i in tqdm(range(n_splits)):
81
+ current_shard_name = os.path.join(
82
+ out_dir,
83
+ f"{filename}.tfrecords-{str(i).zfill(len(str(n_splits)))}-of-{n_splits}",
84
+ )
85
+ writer = tf.io.TFRecordWriter(current_shard_name)
86
+
87
+ current_shard_count = 0
88
+ while current_shard_count < max_files: # as long as our shard is not full
89
+ # get the index of the file that we want to parse now
90
+ index = i * max_files + current_shard_count
91
+ if index >= len(
92
+ data
93
+ ): # when we have consumed the whole data, preempt generation
94
+ break
95
+
96
+ current_element = data[index]
97
+
98
+ # create the required Example representation
99
+ out = cls._parse_single_element(element=current_element)
100
+
101
+ writer.write(out.SerializeToString())
102
+ current_shard_count += 1
103
+ file_count += 1
104
+
105
+ writer.close()
106
+ print(f"\nWrote {file_count} elements to TFRecord")
107
+ return file_count
108
+
109
+ @classmethod
110
+ def __get_number_splits(cls, filename: str):
111
+ target_size = 100 * 1024 * 1024 # 100mb
112
+
113
+ single_file_size = os.path.getsize(filename)
114
+ number_splits = math.ceil(single_file_size / target_size)
115
+ return number_splits
116
+
117
+ @classmethod
118
+ def __write_to_single_tfr(cls, data: np.array, out_dir: str, filename: str):
119
+
120
+ current_path_name = os.path.join(
121
+ out_dir,
122
+ f"{filename}.tfrecords-0-of-1",
123
+ )
124
+
125
+ writer = tf.io.TFRecordWriter(current_path_name)
126
+ for element in tqdm(data):
127
+ writer.write(cls._parse_single_element(element).SerializeToString())
128
+ writer.close()
129
+
130
+ return current_path_name
131
+
132
+ def load(self) -> tf.data.TFRecordDataset:
133
+ path = self.dataset_path
134
+ dataset = None
135
+
136
+ if os.path.isdir(path):
137
+ dataset = self._load_folder(path)
138
+ elif os.path.isfile(path):
139
+ dataset = self._load_file(path)
140
+ else:
141
+ raise ValueError(f"Path {path} is not a valid file or folder.")
142
+
143
+ dataset = dataset.map(self._parse_tfr_element)
144
+ return dataset
145
+
146
+ def _load_file(self, path) -> tf.data.TFRecordDataset:
147
+ return tf.data.TFRecordDataset(path)
148
+
149
+ def _load_folder(self, path) -> tf.data.TFRecordDataset:
150
+
151
+ return tf.data.TFRecordDataset(
152
+ glob.glob(os.path.join(path, "**/*.tfrecords*"), recursive=True)
153
+ )
154
+
155
+
156
+ class VideoDataset(Dataset):
157
+ @classmethod
158
+ def _get_features(cls, element) -> Dict[str, tf.train.Feature]:
159
+ return {
160
+ "frames": _int64_feature(element.shape[0]),
161
+ "height": _int64_feature(element.shape[1]),
162
+ "width": _int64_feature(element.shape[2]),
163
+ "depth": _int64_feature(element.shape[3]),
164
+ "raw_video": _bytes_feature(serialize_array(element)),
165
+ }
166
+
167
+ @classmethod
168
+ def _parse_tfr_element(cls, element):
169
+ # use the same structure as above; it's kinda an outline of the structure we now want to create
170
+ data = {
171
+ "frames": tf.io.FixedLenFeature([], tf.int64),
172
+ "height": tf.io.FixedLenFeature([], tf.int64),
173
+ "width": tf.io.FixedLenFeature([], tf.int64),
174
+ "raw_video": tf.io.FixedLenFeature([], tf.string),
175
+ "depth": tf.io.FixedLenFeature([], tf.int64),
176
+ }
177
+
178
+ content = tf.io.parse_single_example(element, data)
179
+
180
+ frames = content["frames"]
181
+ height = content["height"]
182
+ width = content["width"]
183
+ depth = content["depth"]
184
+ raw_video = content["raw_video"]
185
+
186
+ # get our 'feature'-- our image -- and reshape it appropriately
187
+ feature = tf.io.parse_tensor(raw_video, out_type=tf.uint8)
188
+ feature = tf.reshape(feature, shape=[frames, height, width, depth])
189
+ return feature
190
+
191
+
192
+ class ImageDataset(Dataset):
193
+ @classmethod
194
+ def _get_features(cls, element) -> Dict[str, tf.train.Feature]:
195
+ return {
196
+ "height": _int64_feature(element.shape[0]),
197
+ "width": _int64_feature(element.shape[1]),
198
+ "depth": _int64_feature(element.shape[2]),
199
+ "raw_image": _bytes_feature(serialize_array(element)),
200
+ }
201
+
202
+ @classmethod
203
+ def _parse_tfr_element(cls, element):
204
+ # use the same structure as above; it's kinda an outline of the structure we now want to create
205
+ data = {
206
+ "height": tf.io.FixedLenFeature([], tf.int64),
207
+ "width": tf.io.FixedLenFeature([], tf.int64),
208
+ "raw_image": tf.io.FixedLenFeature([], tf.string),
209
+ "depth": tf.io.FixedLenFeature([], tf.int64),
210
+ }
211
+
212
+ content = tf.io.parse_single_example(element, data)
213
+
214
+ height = content["height"]
215
+ width = content["width"]
216
+ depth = content["depth"]
217
+ raw_image = content["raw_image"]
218
+
219
+ # get our 'feature'-- our image -- and reshape it appropriately
220
+ feature = tf.io.parse_tensor(raw_image, out_type=tf.uint8)
221
+ feature = tf.reshape(feature, shape=[height, width, depth])
222
+ return feature
ganime/data/kny.py ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ import numpy as np
4
+
5
+ from .base import SequenceDataset
6
+
7
+
8
+ class KNYImage(SequenceDataset):
9
+ def load_data(self, dataset_path: str, split: str) -> np.ndarray:
10
+ data = np.load(os.path.join(dataset_path, "kny", "kny_images_64x128.npy"))
11
+ if split == "train":
12
+ data = data[:-5000]
13
+ else:
14
+ data = data[-5000:]
15
+
16
+ return data
17
+
18
+ def preprocess_data(self, data: np.ndarray) -> np.ndarray:
19
+ return data / 255
ganime/data/mnist.py ADDED
@@ -0,0 +1,103 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import glob
2
+ import os
3
+ from typing import Literal
4
+
5
+ import numpy as np
6
+
7
+ from .base import SequenceDataset
8
+ import math
9
+
10
+
11
+ class MovingMNISTImage(SequenceDataset):
12
+ def load_data(self, dataset_path: str, split: str) -> np.ndarray:
13
+ data = np.load(os.path.join(dataset_path, "moving_mnist", "mnist_test_seq.npy"))
14
+ # Data is of shape (window, n_samples, width, height)
15
+ # But we want for keras something of shape (n_samples, window, width, height)
16
+ data = np.moveaxis(data, 0, 1)
17
+ # Also expand dimensions to have channels at the end (n_samples, window, width, height, channels)
18
+ data = np.expand_dims(data, axis=-1)
19
+ if split == "train":
20
+ data = data[:-1000]
21
+ else:
22
+ data = data[-1000:]
23
+
24
+ data = np.concatenate([data, data, data], axis=-1)
25
+
26
+ return data
27
+
28
+ def __getitem__(self, idx):
29
+ inds = self.indices[idx * self.batch_size : (idx + 1) * self.batch_size]
30
+ batch_x = self.data[inds, 0, ...]
31
+ batch_y = self.data[inds, 1, ...]
32
+
33
+ return batch_x, batch_y
34
+
35
+ def preprocess_data(self, data: np.ndarray) -> np.ndarray:
36
+ return data / 255
37
+
38
+
39
+ class MovingMNIST(SequenceDataset):
40
+ def __init__(
41
+ self,
42
+ dataset_path: str,
43
+ batch_size: int,
44
+ split: Literal["train", "validation", "test"] = "train",
45
+ ):
46
+ self.batch_size = batch_size
47
+ self.split = split
48
+ root_path = os.path.join(dataset_path, "moving_mnist", split)
49
+ self.paths = glob.glob(os.path.join(root_path, "*.npy"))
50
+ # self.data = self.preprocess_data(self.data)
51
+
52
+ self.indices = np.arange(len(self.paths))
53
+ self.on_epoch_end()
54
+
55
+ # def load_data(self, dataset_path: str, split: str) -> np.ndarray:
56
+ # data = np.load(os.path.join(dataset_path, "moving_mnist", "mnist_test_seq.npy"))
57
+ # # Data is of shape (window, n_samples, width, height)
58
+ # # But we want for keras something of shape (n_samples, window, width, height)
59
+ # data = np.moveaxis(data, 0, 1)
60
+ # # Also expand dimensions to have channels at the end (n_samples, window, width, height, channels)
61
+ # data = np.expand_dims(data, axis=-1)
62
+ # if split == "train":
63
+ # data = data[:100]
64
+ # else:
65
+ # data = data[100:110]
66
+
67
+ # data = np.concatenate([data, data, data], axis=-1)
68
+
69
+ # return data
70
+
71
+ def __len__(self):
72
+ return math.ceil(len(self.paths) / self.batch_size)
73
+
74
+ def __getitem__(self, idx):
75
+ inds = self.indices[idx * self.batch_size : (idx + 1) * self.batch_size]
76
+ data = self.load_indices(inds)
77
+ batch_x = np.concatenate([data[:, 0:1, ...], data[:, -1:, ...]], axis=1)
78
+ batch_y = data[:, 1:, ...]
79
+
80
+ return batch_x, batch_y
81
+
82
+ def get_fixed_batch(self, idx):
83
+ self.fixed_indices = (
84
+ self.fixed_indices
85
+ if hasattr(self, "fixed_indices")
86
+ else self.indices[
87
+ idx * self.batch_size : (idx + 1) * self.batch_size
88
+ ].copy()
89
+ )
90
+ data = self.load_indices(self.fixed_indices)
91
+ batch_x = np.concatenate([data[:, 0:1, ...], data[:, -1:, ...]], axis=1)
92
+ batch_y = data[:, 1:, ...]
93
+
94
+ return batch_x, batch_y
95
+
96
+ def load_indices(self, indices):
97
+ paths_to_load = [self.paths[index] for index in indices]
98
+ data = [np.load(path) for path in paths_to_load]
99
+ data = np.array(data)
100
+ return self.preprocess_data(data)
101
+
102
+ def preprocess_data(self, data: np.ndarray) -> np.ndarray:
103
+ return data / 255
ganime/metrics/image.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import tensorflow as tf
3
+ from scipy import linalg
4
+ from tensorflow.keras.applications.inception_v3 import InceptionV3, preprocess_input
5
+ from tqdm.auto import tqdm
6
+
7
+ inceptionv3 = InceptionV3(include_top=False, weights="imagenet", pooling="avg")
8
+
9
+
10
+ def resize_images(images, new_shape):
11
+ images = tf.image.resize(images, new_shape)
12
+ return images
13
+
14
+
15
+ def calculate_fid(real_embeddings, generated_embeddings):
16
+ # calculate mean and covariance statistics
17
+ mu1, sigma1 = real_embeddings.mean(axis=0), np.cov(real_embeddings, rowvar=False)
18
+ mu2, sigma2 = generated_embeddings.mean(axis=0), np.cov(
19
+ generated_embeddings, rowvar=False
20
+ )
21
+ # calculate sum squared difference between means
22
+ ssdiff = np.sum((mu1 - mu2) ** 2.0)
23
+ # calculate sqrt of product between cov
24
+ covmean = linalg.sqrtm(sigma1.dot(sigma2))
25
+ # check and correct imaginary numbers from sqrt
26
+ if np.iscomplexobj(covmean):
27
+ covmean = covmean.real
28
+ # calculate score
29
+ fid = ssdiff + np.trace(sigma1 + sigma2 - 2.0 * covmean)
30
+ return fid
31
+
32
+
33
+ def calculate_images_metrics(dataset, model, total_length):
34
+ fake_embeddings = []
35
+ real_embeddings = []
36
+
37
+ psnrs = []
38
+ ssims = []
39
+
40
+ for sample in tqdm(dataset, total=total_length):
41
+ generated = model(sample[0], training=False)[0]
42
+ generated, real = generated, sample[0]
43
+
44
+ real_resized = resize_images(real, (299, 299))
45
+ generated_resized = resize_images(generated, (299, 299))
46
+
47
+ real_activations = inceptionv3(real_resized, training=False)
48
+ generated_activations = inceptionv3(generated_resized, training=False)
49
+ fake_embeddings.append(generated_activations)
50
+ real_embeddings.append(real_activations)
51
+
52
+ fake_scaled = tf.cast(((generated * 0.5) + 1) * 255, tf.uint8)
53
+ real_scaled = tf.cast(((real * 0.5) + 1) * 255, tf.uint8)
54
+
55
+ psnrs.append(tf.image.psnr(fake_scaled, real_scaled, 255).numpy())
56
+ ssims.append(tf.image.ssim(fake_scaled, real_scaled, 255).numpy())
57
+
58
+ fid = calculate_fid(
59
+ tf.concat(fake_embeddings, axis=0).numpy(),
60
+ tf.concat(real_embeddings, axis=0).numpy(),
61
+ )
62
+
63
+ # kid = calculate_kid(
64
+ # tf.concat(fake_embeddings, axis=0).numpy(),
65
+ # tf.concat(real_embeddings, axis=0).numpy(),
66
+ # )
67
+
68
+ psnr = np.array(psnrs).mean()
69
+ ssim = np.array(ssims).mean()
70
+ return {"fid": fid, "ssim": ssim, "psnr": psnr}
ganime/metrics/video.py ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import tensorflow as tf
3
+ import tensorflow_gan as tfgan
4
+ import tensorflow_hub as hub
5
+ from sklearn.metrics.pairwise import polynomial_kernel
6
+ from tqdm.auto import tqdm
7
+
8
+ i3d = hub.KerasLayer("https://tfhub.dev/deepmind/i3d-kinetics-400/1")
9
+
10
+
11
+ def resize_videos(videos, target_resolution):
12
+ """Runs some preprocessing on the videos for I3D model.
13
+ Args:
14
+ videos: <T>[batch_size, num_frames, height, width, depth] The videos to be
15
+ preprocessed. We don't care about the specific dtype of the videos, it can
16
+ be anything that tf.image.resize_bilinear accepts. Values are expected to
17
+ be in [-1, 1].
18
+ target_resolution: (width, height): target video resolution
19
+ Returns:
20
+ videos: <float32>[batch_size, num_frames, height, width, depth]
21
+ """
22
+ min_frames = 9
23
+ B, T, H, W, C = videos.shape
24
+ videos = tf.transpose(videos, (1, 0, 2, 3, 4))
25
+ if T < min_frames:
26
+ videos = tf.concat([tf.zeros((min_frames - T, B, H, W, C)), videos], axis=0)
27
+ scaled_videos = tf.map_fn(lambda x: tf.image.resize(x, target_resolution), videos)
28
+ scaled_videos = tf.transpose(scaled_videos, (1, 0, 2, 3, 4))
29
+ return scaled_videos
30
+
31
+
32
+ def polynomial_mmd(X, Y):
33
+ m = X.shape[0]
34
+ n = Y.shape[0]
35
+ # compute kernels
36
+ K_XX = polynomial_kernel(X)
37
+ K_YY = polynomial_kernel(Y)
38
+ K_XY = polynomial_kernel(X, Y)
39
+ # compute mmd distance
40
+ K_XX_sum = (K_XX.sum() - np.diagonal(K_XX).sum()) / (m * (m - 1))
41
+ K_YY_sum = (K_YY.sum() - np.diagonal(K_YY).sum()) / (n * (n - 1))
42
+ K_XY_sum = K_XY.sum() / (m * n)
43
+ mmd = K_XX_sum + K_YY_sum - 2 * K_XY_sum
44
+ return mmd
45
+
46
+
47
+ def calculate_ssim_videos(fake, real):
48
+ fake = tf.cast(((fake * 0.5) + 0.5) * 255, tf.uint8)
49
+ real = tf.cast(((real * 0.5) + 0.5) * 255, tf.uint8)
50
+ ssims = []
51
+ for i in range(fake.shape[0]):
52
+ ssims.append(tf.image.ssim(fake[i], real[i], 255).numpy().mean())
53
+
54
+ return np.array(ssims).mean()
55
+
56
+
57
+ def calculate_psnr_videos(fake, real):
58
+ fake = tf.cast(((fake * 0.5) + 0.5) * 255, tf.uint8)
59
+ real = tf.cast(((real * 0.5) + 0.5) * 255, tf.uint8)
60
+ psnrs = []
61
+ for i in range(fake.shape[0]):
62
+ psnrs.append(tf.image.psnr(fake[i], real[i], 255).numpy().mean())
63
+
64
+ return np.array(psnrs).mean()
65
+
66
+
67
+ def calculate_videos_metrics(dataset, model, total_length):
68
+ fake_embeddings = []
69
+ real_embeddings = []
70
+
71
+ psnrs = []
72
+ ssims = []
73
+
74
+ for sample in tqdm(dataset, total=total_length):
75
+ generated = model(sample, training=False)
76
+ generated, real = generated[:, 1:], sample["y"][:, 1:] # ignore first frame
77
+
78
+ real_resized = resize_videos(real, (224, 224))
79
+ generated_resized = resize_videos(generated, (224, 224))
80
+
81
+ real_activations = i3d(real_resized)
82
+ generated_activations = i3d(generated_resized)
83
+ fake_embeddings.append(generated_activations)
84
+ real_embeddings.append(real_activations)
85
+
86
+ psnrs.append(calculate_psnr_videos(generated, real))
87
+ ssims.append(calculate_ssim_videos(generated, real))
88
+
89
+ # fake_concat, real_concat = tf.concat(fake_embeddings, axis=0), tf.concat(real_embeddings, axis=0)
90
+ fvd = tfgan.eval.frechet_classifier_distance_from_activations(
91
+ tf.concat(fake_embeddings, axis=0), tf.concat(real_embeddings, axis=0)
92
+ )
93
+ kvd = polynomial_mmd(
94
+ tf.concat(fake_embeddings, axis=0), tf.concat(real_embeddings, axis=0)
95
+ )
96
+ psnr = np.array(psnrs).mean()
97
+ ssim = np.array(ssims).mean()
98
+ return {"fvd": fvd, "kvd": kvd, "ssim": ssim, "psnr": psnr}
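The kvd value above is the unbiased polynomial-kernel MMD estimate \widehat{\mathrm{MMD}}^2 = \frac{1}{m(m-1)}\sum_{i \ne j} k(x_i, x_j) + \frac{1}{n(n-1)}\sum_{i \ne j} k(y_i, y_j) - \frac{2}{mn}\sum_{i,j} k(x_i, y_j), which polynomial_mmd computes with scikit-learn's default degree-3 polynomial kernel. A quick sanity-check sketch on random features (importing this module also loads the I3D hub model, and the 400-dimensional feature size is only an assumption mirroring the Kinetics-400 logits):

    import numpy as np
    from ganime.metrics.video import polynomial_mmd

    rng = np.random.default_rng(seed=0)
    feats_real = rng.normal(size=(64, 400))        # assumed I3D embedding size
    feats_same = rng.normal(size=(64, 400))        # drawn from the same distribution
    feats_shifted = feats_real + 3.0               # clearly shifted distribution

    print(polynomial_mmd(feats_real, feats_same))     # close to zero
    print(polynomial_mmd(feats_real, feats_shifted))  # much larger than zero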
ganime/model/__init__.py ADDED
File without changes
ganime/model/base.py ADDED
@@ -0,0 +1,45 @@
1
+ import tensorflow as tf
2
+ from ganime.model.vqgan_clean.vqgan import VQGAN
3
+
4
+
5
+ def load_model(
6
+ model: str, config: dict, strategy: tf.distribute.Strategy
7
+ ) -> tf.keras.Model:
8
+
9
+ if model == "vqgan":
10
+ with strategy.scope():
11
+ print(config["model"])
12
+ model = VQGAN(**config["model"])
13
+
14
+ gen_optimizer = tf.keras.optimizers.Adam(
15
+ learning_rate=config["trainer"]["gen_lr"],
16
+ beta_1=config["trainer"]["gen_beta_1"],
17
+ beta_2=config["trainer"]["gen_beta_2"],
18
+ clipnorm=config["trainer"]["gen_clip_norm"],
19
+ )
20
+ disc_optimizer = tf.keras.optimizers.Adam(
21
+ learning_rate=config["trainer"]["disc_lr"],
22
+ beta_1=config["trainer"]["disc_beta_1"],
23
+ beta_2=config["trainer"]["disc_beta_2"],
24
+ clipnorm=config["trainer"]["disc_clip_norm"],
25
+ )
26
+ model.compile(gen_optimizer=gen_optimizer, disc_optimizer=disc_optimizer)
27
+ return model
28
+ else:
29
+ raise ValueError(f"Unknown model: {model}")
30
+
31
+ # if model == "moving_vae":
32
+ # from ganime.model.moving_vae import MovingVAE
33
+
34
+ # with strategy.scope():
35
+ # model = MovingVAE(input_shape=input_shape)
36
+
37
+ # negloglik = lambda x, rv_x: -rv_x.log_prob(x)
38
+ # model.compile(
39
+ # optimizer=tf.optimizers.Adam(learning_rate=config["lr"]),
40
+ # loss=negloglik,
41
+ # )
42
+ # # model.build(input_shape=(None, *input_shape))
43
+ # # model.summary()
44
+
45
+ # return model
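A minimal sketch of driving load_model from a YAML file. The path below is hypothetical, and the config is assumed to contain a "model" section with VQGAN keyword arguments plus a "trainer" section with the keys read above (gen_lr, gen_beta_1, gen_beta_2, gen_clip_norm and their disc_* counterparts):

    import tensorflow as tf
    import yaml

    from ganime.model.base import load_model

    with open("configs/my_experiment.yaml") as f:  # hypothetical config path
        config = yaml.safe_load(f)

    strategy = tf.distribute.MirroredStrategy()  # any tf.distribute.Strategy fits the signature
    model = load_model("vqgan", config, strategy)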
ganime/model/moving_vae.py ADDED
@@ -0,0 +1,126 @@
1
+ from tensorflow.keras import Model
2
+
3
+ import tensorflow as tf
4
+ import tensorflow_probability as tfp
5
+
6
+
7
+ class MovingVAE(Model):
8
+ def __init__(self, input_shape, encoded_size=64, base_depth=32):
9
+ super().__init__()
10
+
11
+ self.encoded_size = encoded_size
12
+ self.base_depth = base_depth
13
+
14
+ self.prior = tfp.distributions.Independent(
15
+ tfp.distributions.Normal(loc=tf.zeros(encoded_size), scale=1),
16
+ reinterpreted_batch_ndims=1,
17
+ )
18
+
19
+ self.encoder = tf.keras.Sequential(
20
+ [
21
+ tf.keras.layers.InputLayer(input_shape=input_shape),
22
+ tf.keras.layers.Lambda(lambda x: tf.cast(x, tf.float32) - 0.5),
23
+ tf.keras.layers.Conv3D(
24
+ self.base_depth,
25
+ 5,
26
+ strides=1,
27
+ padding="same",
28
+ activation=tf.nn.leaky_relu,
29
+ ),
30
+ tf.keras.layers.Conv3D(
31
+ self.base_depth,
32
+ 5,
33
+ strides=2,
34
+ padding="same",
35
+ activation=tf.nn.leaky_relu,
36
+ ),
37
+ tf.keras.layers.Conv3D(
38
+ 2 * self.base_depth,
39
+ 5,
40
+ strides=1,
41
+ padding="same",
42
+ activation=tf.nn.leaky_relu,
43
+ ),
44
+ tf.keras.layers.Conv3D(
45
+ 2 * self.base_depth,
46
+ 5,
47
+ strides=2,
48
+ padding="same",
49
+ activation=tf.nn.leaky_relu,
50
+ ),
51
+ # tf.keras.layers.Conv3D(4 * encoded_size, 7, strides=1,
52
+ # padding='valid', activation=tf.nn.leaky_relu),
53
+ tf.keras.layers.Flatten(),
54
+ tf.keras.layers.Dense(
55
+ tfp.layers.MultivariateNormalTriL.params_size(self.encoded_size),
56
+ activation=None,
57
+ ),
58
+ tfp.layers.MultivariateNormalTriL(
59
+ self.encoded_size,
60
+ activity_regularizer=tfp.layers.KLDivergenceRegularizer(self.prior),
61
+ ),
62
+ ]
63
+ )
64
+
65
+ self.decoder = tf.keras.Sequential(
66
+ [
67
+ tf.keras.layers.InputLayer(input_shape=[self.encoded_size]),
68
+ tf.keras.layers.Reshape([1, 1, 1, self.encoded_size]),
69
+ tf.keras.layers.Conv3DTranspose(
70
+ self.base_depth,
71
+ (5, 4, 4),
72
+ strides=1,
73
+ padding="valid",
74
+ activation=tf.nn.leaky_relu,
75
+ ),
76
+ tf.keras.layers.Conv3DTranspose(
77
+ 2 * self.base_depth,
78
+ (5, 4, 4),
79
+ strides=(1, 2, 2),
80
+ padding="same",
81
+ activation=tf.nn.leaky_relu,
82
+ ),
83
+ tf.keras.layers.Conv3DTranspose(
84
+ 2 * self.base_depth,
85
+ (5, 4, 4),
86
+ strides=2,
87
+ padding="same",
88
+ activation=tf.nn.leaky_relu,
89
+ ),
90
+ tf.keras.layers.Conv3DTranspose(
91
+ self.base_depth,
92
+ (5, 4, 4),
93
+ strides=(1, 2, 2),
94
+ padding="same",
95
+ activation=tf.nn.leaky_relu,
96
+ ),
97
+ tf.keras.layers.Conv3DTranspose(
98
+ self.base_depth,
99
+ (5, 4, 4),
100
+ strides=2,
101
+ padding="same",
102
+ activation=tf.nn.leaky_relu,
103
+ ),
104
+ tf.keras.layers.Conv3DTranspose(
105
+ self.base_depth,
106
+ (5, 4, 4),
107
+ strides=1,
108
+ padding="same",
109
+ activation=tf.nn.leaky_relu,
110
+ ),
111
+ tf.keras.layers.Conv2D(
112
+ filters=1, kernel_size=5, strides=1, padding="same", activation=None
113
+ ),
114
+ tf.keras.layers.Flatten(),
115
+ tfp.layers.IndependentBernoulli(
116
+ input_shape, tfp.distributions.Bernoulli.logits
117
+ ),
118
+ ]
119
+ )
120
+
121
+ self.model = tf.keras.Model(
122
+ inputs=self.encoder.inputs, outputs=self.decoder(self.encoder.outputs[0])
123
+ )
124
+
125
+ def call(self, inputs):
126
+ return self.model(inputs)
ganime/model/p2p/__init__.py ADDED
File without changes
ganime/model/p2p/p2p.py ADDED
@@ -0,0 +1,543 @@
1
+ from statistics import mode
2
+ import numpy as np
3
+ import tensorflow as tf
4
+ from tensorflow.python.keras import Model, Sequential
5
+ from tensorflow.python.keras.layers import Dense, LSTMCell, RNN, Conv2D, Conv2DTranspose
6
+ from tensorflow.keras.layers import BatchNormalization, TimeDistributed
7
+ from tensorflow.python.keras.layers.advanced_activations import LeakyReLU
8
+ from tensorflow.keras.layers import Activation
9
+
10
+ # from tensorflow_probability.python.layers.dense_variational import (
11
+ # DenseReparameterization,
12
+ # )
13
+ # import tensorflow_probability as tfp
14
+ from tensorflow.keras.losses import Loss
15
+
16
+
17
+ class KLCriterion(Loss):
18
+ def call(self, y_true, y_pred):
19
+ (mu1, logvar1), (mu2, logvar2) = y_true, y_pred
20
+
21
+ """KL( N(mu_1, sigma2_1) || N(mu_2, sigma2_2))"""
22
+ sigma1 = tf.exp(tf.math.multiply(logvar1, 0.5))
23
+ sigma2 = tf.exp(tf.math.multiply(logvar2, 0.5))
24
+
25
+ kld = (
26
+ tf.math.log(sigma2 / sigma1)
27
+ + (tf.exp(logvar1) + tf.square(mu1 - mu2)) / (2 * tf.exp(logvar2))
28
+ - 0.5
29
+ )
30
+ return tf.reduce_sum(kld) / 22
31
+
32
+
33
+ class Encoder(Model):
34
+ def __init__(self, dim, nc=1):
35
+ super().__init__()
36
+ self.dim = dim
37
+ self.c1 = Sequential(
38
+ [
39
+ Conv2D(64, kernel_size=4, strides=2, padding="same"),
40
+ BatchNormalization(),
41
+ LeakyReLU(alpha=0.2),
42
+ ]
43
+ )
44
+ self.c2 = Sequential(
45
+ [
46
+ Conv2D(128, kernel_size=4, strides=2, padding="same"),
47
+ BatchNormalization(),
48
+ LeakyReLU(alpha=0.2),
49
+ ]
50
+ )
51
+ self.c3 = Sequential(
52
+ [
53
+ Conv2D(256, kernel_size=4, strides=2, padding="same"),
54
+ BatchNormalization(),
55
+ LeakyReLU(alpha=0.2),
56
+ ]
57
+ )
58
+ self.c4 = Sequential(
59
+ [
60
+ Conv2D(512, kernel_size=4, strides=2, padding="same"),
61
+ BatchNormalization(),
62
+ LeakyReLU(alpha=0.2),
63
+ ]
64
+ )
65
+ self.c5 = Sequential(
66
+ [
67
+ Conv2D(self.dim, kernel_size=4, strides=1, padding="valid"),
68
+ BatchNormalization(),
69
+ Activation("tanh"),
70
+ ]
71
+ )
72
+
73
+ def call(self, input):
74
+ h1 = self.c1(input)
75
+ h2 = self.c2(h1)
76
+ h3 = self.c3(h2)
77
+ h4 = self.c4(h3)
78
+ h5 = self.c5(h4)
79
+ return tf.reshape(h5, (-1, self.dim)), [h1, h2, h3, h4, h5]
80
+
81
+
82
+ class Decoder(Model):
83
+ def __init__(self, dim, nc=1):
84
+ super().__init__()
85
+ self.dim = dim
86
+ self.upc1 = Sequential(
87
+ [
88
+ Conv2DTranspose(512, kernel_size=4, strides=1, padding="valid"),
89
+ BatchNormalization(),
90
+ LeakyReLU(alpha=0.2),
91
+ ]
92
+ )
93
+ self.upc2 = Sequential(
94
+ [
95
+ Conv2DTranspose(256, kernel_size=4, strides=2, padding="same"),
96
+ BatchNormalization(),
97
+ LeakyReLU(alpha=0.2),
98
+ ]
99
+ )
100
+ self.upc3 = Sequential(
101
+ [
102
+ Conv2DTranspose(128, kernel_size=4, strides=2, padding="same"),
103
+ BatchNormalization(),
104
+ LeakyReLU(alpha=0.2),
105
+ ]
106
+ )
107
+ self.upc4 = Sequential(
108
+ [
109
+ Conv2DTranspose(64, kernel_size=4, strides=2, padding="same"),
110
+ BatchNormalization(),
111
+ LeakyReLU(alpha=0.2),
112
+ ]
113
+ )
114
+ self.upc5 = Sequential(
115
+ [
116
+ Conv2DTranspose(1, kernel_size=4, strides=2, padding="same"),
117
+ Activation("sigmoid"),
118
+ ]
119
+ )
120
+
121
+ def call(self, input):
122
+ vec, skip = input
123
+ d1 = self.upc1(tf.reshape(vec, (-1, 1, 1, self.dim)))
124
+ d2 = self.upc2(tf.concat([d1, skip[3]], axis=-1))
125
+ d3 = self.upc3(tf.concat([d2, skip[2]], axis=-1))
126
+ d4 = self.upc4(tf.concat([d3, skip[1]], axis=-1))
127
+ output = self.upc5(tf.concat([d4, skip[0]], axis=-1))
128
+ return output
129
+
130
+
131
+ class MyLSTM(Model):
132
+ def __init__(self, input_shape, hidden_size, output_size, n_layers):
133
+ super().__init__()
134
+ self.hidden_size = hidden_size
135
+ self.n_layers = n_layers
136
+ self.embed = Dense(hidden_size, input_dim=input_shape)
137
+ # self.lstm = Sequential(
138
+ # [LSTMCell(hidden_size) for _ in range(n_layers)], name="lstm"
139
+ # )
140
+ # self.lstm = self.create_lstm(hidden_size, n_layers)
141
+ self.lstm = LSTMCell(hidden_size)
142
+ self.out = Dense(output_size)
143
+
144
+ def init_hidden(self, batch_size):
145
+ hidden = []
146
+ for i in range(self.n_layers):
147
+ hidden.append(
148
+ (
149
+ tf.Variable(tf.zeros([batch_size, self.hidden_size])),
150
+ tf.Variable(tf.zeros([batch_size, self.hidden_size])),
151
+ )
152
+ )
153
+ self.__dict__["hidden"] = hidden
154
+
155
+ def build(self, input_shape):
156
+ self.init_hidden(input_shape[0])
157
+
158
+ def call(self, inputs):
159
+ h_in = self.embed(inputs)
160
+ for i in range(self.n_layers):
161
+ _, self.hidden[i] = self.lstm(h_in, self.hidden[i])
162
+ h_in = self.hidden[i][0]
163
+
164
+ return self.out(h_in)
165
+
166
+
167
+ class MyGaussianLSTM(Model):
168
+ def __init__(self, input_shape, hidden_size, output_size, n_layers):
169
+ super().__init__()
170
+ self.hidden_size = hidden_size
171
+ self.n_layers = n_layers
172
+ self.embed = Dense(hidden_size, input_dim=input_shape)
173
+ # self.lstm = Sequential(
174
+ # [LSTMCell(hidden_size) for _ in range(n_layers)], name="lstm"
175
+ # )
176
+ self.lstm = LSTMCell(hidden_size)
177
+ self.mu_net = Dense(output_size)
178
+ self.logvar_net = Dense(output_size)
179
+ # self.out = Sequential(
180
+ # [
181
+ # tf.keras.layers.Dense(
182
+ # tfp.layers.MultivariateNormalTriL.params_size(output_size),
183
+ # activation=None,
184
+ # ),
185
+ # tfp.layers.MultivariateNormalTriL(output_size),
186
+ # ]
187
+ # )
188
+
189
+ def reparameterize(self, mu, logvar: tf.Tensor):
190
+ logvar = tf.math.exp(logvar * 0.5)
191
+ eps = tf.random.normal(logvar.shape)
192
+ return tf.add(tf.math.multiply(eps, logvar), mu)
193
+
194
+ def init_hidden(self, batch_size):
195
+ hidden = []
196
+ for i in range(self.n_layers):
197
+ hidden.append(
198
+ (
199
+ tf.Variable(tf.zeros([batch_size, self.hidden_size])),
200
+ tf.Variable(tf.zeros([batch_size, self.hidden_size])),
201
+ )
202
+ )
203
+ self.__dict__["hidden"] = hidden
204
+
205
+ def build(self, input_shape):
206
+ self.init_hidden(input_shape[0])
207
+
208
+ def call(self, inputs):
209
+ h_in = self.embed(inputs)
210
+ for i in range(self.n_layers):
211
+ # print(h_in.shape, self.hidden[i][0].shape, self.hidden[i][0].shape)
212
+
213
+ _, self.hidden[i] = self.lstm(h_in, self.hidden[i])
214
+ h_in = self.hidden[i][0]
215
+ mu = self.mu_net(h_in)
216
+ logvar = self.logvar_net(h_in)
217
+ z = self.reparameterize(mu, logvar)
218
+ return z, mu, logvar
219
+
220
+
221
+ class P2P(Model):
222
+ def __init__(
223
+ self,
224
+ channels: int = 1,
225
+ g_dim: int = 128,
226
+ z_dim: int = 10,
227
+ rnn_size: int = 256,
228
+ prior_rnn_layers: int = 1,
229
+ posterior_rnn_layers: int = 1,
230
+ predictor_rnn_layers: float = 1,
231
+ skip_prob: float = 0.5,
232
+ n_past: int = 1,
233
+ last_frame_skip: bool = False,
234
+ beta: float = 0.0001,
235
+ weight_align: float = 0.1,
236
+ weight_cpc: float = 100,
237
+ ):
238
+ super().__init__()
239
+ self.channels = channels
240
+ self.g_dim = g_dim
241
+ self.z_dim = z_dim
242
+ self.rnn_size = rnn_size
243
+ self.prior_rnn_layers = prior_rnn_layers
244
+ self.posterior_rnn_layers = posterior_rnn_layers
245
+ self.predictor_rnn_layers = predictor_rnn_layers
246
+
247
+ self.skip_prob = skip_prob
248
+ self.n_past = n_past
249
+ self.last_frame_skip = last_frame_skip
250
+ self.beta = beta
251
+ self.weight_align = weight_align
252
+ self.weight_cpc = weight_cpc
253
+
254
+ self.frame_predictor = MyLSTM(
255
+ self.g_dim + self.z_dim + 1 + 1,
256
+ self.rnn_size,
257
+ self.g_dim,
258
+ self.predictor_rnn_layers,
259
+ )
260
+
261
+ self.prior = MyGaussianLSTM(
262
+ self.g_dim + self.g_dim + 1 + 1,
263
+ self.rnn_size,
264
+ self.z_dim,
265
+ self.prior_rnn_layers,
266
+ )
267
+
268
+ self.posterior = MyGaussianLSTM(
269
+ self.g_dim + self.g_dim + 1 + 1,
270
+ self.rnn_size,
271
+ self.z_dim,
272
+ self.posterior_rnn_layers,
273
+ )
274
+
275
+ self.encoder = Encoder(self.g_dim, self.channels)
276
+ self.decoder = Decoder(self.g_dim, self.channels)
277
+
278
+ # criterions
279
+ self.mse_criterion = tf.keras.losses.MeanSquaredError()
280
+ self.kl_criterion = KLCriterion()
281
+ self.align_criterion = tf.keras.losses.MeanSquaredError()
282
+
283
+ # optimizers
284
+ self.frame_predictor_optimizer = tf.keras.optimizers.Adam(
285
+ learning_rate=0.0001 # , beta_1=0.9, beta_2=0.999, epsilon=1e-8
286
+ )
287
+ self.posterior_optimizer = tf.keras.optimizers.Adam(
288
+ learning_rate=0.0001 # , beta_1=0.9, beta_2=0.999, epsilon=1e-8
289
+ )
290
+ self.prior_optimizer = tf.keras.optimizers.Adam(
291
+ learning_rate=0.0001 # , beta_1=0.9, beta_2=0.999, epsilon=1e-8
292
+ )
293
+ self.encoder_optimizer = tf.keras.optimizers.Adam(
294
+ learning_rate=0.0001 # , beta_1=0.9, beta_2=0.999, epsilon=1e-8
295
+ )
296
+ self.decoder_optimizer = tf.keras.optimizers.Adam(
297
+ learning_rate=0.0001 # , beta_1=0.9, beta_2=0.999, epsilon=1e-8
298
+ )
299
+
300
+ def get_global_descriptor(self, x, start_ix=0, cp_ix=None):
301
+ """Get the global descriptor based on x, start_ix, cp_ix."""
302
+ if cp_ix is None:
303
+ cp_ix = x.shape[1] - 1
304
+
305
+ x_cp = x[:, cp_ix, ...]
306
+ h_cp = self.encoder(x_cp)[0] # 1 is input for skip-connection
307
+
308
+ return x_cp, h_cp
309
+
310
+ def call(self, x, start_ix=0, cp_ix=-1):
311
+ batch_size = x.shape[0]
312
+
313
+ with tf.GradientTape(persistent=True) as tape:
314
+ mse_loss = 0
315
+ kld_loss = 0
316
+ cpc_loss = 0
317
+ align_loss = 0
318
+
319
+ seq_len = x.shape[1]
320
+ start_ix = 0
321
+ cp_ix = seq_len - 1
322
+ x_cp, global_z = self.get_global_descriptor(
323
+ x, start_ix, cp_ix
324
+ ) # here global_z is h_cp
325
+
326
+ skip_prob = self.skip_prob
327
+
328
+ prev_i = 0
329
+ max_skip_count = seq_len * skip_prob
330
+ skip_count = 0
331
+ probs = np.random.uniform(low=0, high=1, size=seq_len - 1)
332
+
333
+ for i in range(1, seq_len):
334
+ if (
335
+ probs[i - 1] <= skip_prob
336
+ and i >= self.n_past
337
+ and skip_count < max_skip_count
338
+ and i != 1
339
+ and i != cp_ix
340
+ ):
341
+ skip_count += 1
342
+ continue
343
+
344
+ time_until_cp = tf.fill([batch_size, 1], (cp_ix - i + 1) / cp_ix)
345
+ delta_time = tf.fill([batch_size, 1], ((i - prev_i) / cp_ix))
346
+ prev_i = i
347
+
348
+ h = self.encoder(x[:, i - 1, ...])
349
+ h_target = self.encoder(x[:, i, ...])[0]
350
+
351
+ if self.last_frame_skip or i <= self.n_past:
352
+ h, skip = h
353
+ else:
354
+ h = h[0]
355
+
356
+ # Control Point Aware
357
+ h_cpaw = tf.concat([h, global_z, time_until_cp, delta_time], axis=1)
358
+ h_target_cpaw = tf.concat(
359
+ [h_target, global_z, time_until_cp, delta_time], axis=1
360
+ )
361
+ zt, mu, logvar = self.posterior(h_target_cpaw)
362
+ zt_p, mu_p, logvar_p = self.prior(h_cpaw)
363
+
364
+ concat = tf.concat([h, zt, time_until_cp, delta_time], axis=1)
365
+ h_pred = self.frame_predictor(concat)
366
+ x_pred = self.decoder([h_pred, skip])
367
+
368
+ if i == cp_ix: # the gen-cp-frame should be exactly as x_cp
369
+ h_pred_p = self.frame_predictor(
370
+ tf.concat([h, zt_p, time_until_cp, delta_time], axis=1)
371
+ )
372
+ x_pred_p = self.decoder([h_pred_p, skip])
373
+ cpc_loss = self.mse_criterion(x_pred_p, x_cp)
374
+
375
+ if i > 1:
376
+ align_loss += self.align_criterion(h[0], h_pred)
377
+
378
+ mse_loss += self.mse_criterion(x_pred, x[:, i, ...])
379
+ kld_loss += self.kl_criterion((mu, logvar), (mu_p, logvar_p))
380
+
381
+ # backward
382
+ loss = mse_loss + kld_loss * self.beta + align_loss * self.weight_align
383
+
384
+ prior_loss = kld_loss + cpc_loss * self.weight_cpc
385
+
386
+ var_list_frame_predictor = self.frame_predictor.trainable_variables
387
+ var_list_posterior = self.posterior.trainable_variables
388
+ var_list_prior = self.prior.trainable_variables
389
+ var_list_encoder = self.encoder.trainable_variables
390
+ var_list_decoder = self.decoder.trainable_variables
391
+
392
+ # mse: frame_predictor + decoder
393
+ # align: frame_predictor + encoder
394
+ # kld: posterior + prior + encoder
395
+
396
+ var_list_without_prior = (
397
+ var_list_frame_predictor
398
+ + var_list_posterior
399
+ + var_list_encoder
400
+ + var_list_decoder
401
+ )
402
+
403
+ gradients_without_prior = tape.gradient(
404
+ loss,
405
+ var_list_without_prior,
406
+ )
407
+ gradients_prior = tape.gradient(
408
+ prior_loss,
409
+ var_list_prior,
410
+ )
411
+
412
+ self.update_model_without_prior(
413
+ gradients_without_prior,
414
+ var_list_without_prior,
415
+ )
416
+ self.update_prior(gradients_prior, var_list_prior)
417
+ del tape
418
+
419
+ return (
420
+ mse_loss / seq_len,
421
+ kld_loss / seq_len,
422
+ cpc_loss / seq_len,
423
+ align_loss / seq_len,
424
+ )
425
+
426
+ def p2p_generate(
427
+ self,
428
+ x,
429
+ len_output,
430
+ eval_cp_ix,
431
+ start_ix=0,
432
+ cp_ix=-1,
433
+ model_mode="full",
434
+ skip_frame=False,
435
+ init_hidden=True,
436
+ ):
437
+ batch_size, num_frames, h, w, channels = x.shape
438
+ dim_shape = (h, w, channels)
439
+
440
+ gen_seq = [x[:, 0, ...]]
441
+ x_in = x[:, 0, ...]
442
+
443
+ seq_len = x.shape[1]
444
+ cp_ix = seq_len - 1
445
+
446
+ x_cp, global_z = self.get_global_descriptor(
447
+ x, cp_ix=cp_ix
448
+ ) # here global_z is h_cp
449
+
450
+ skip_prob = self.skip_prob
451
+
452
+ prev_i = 0
453
+ max_skip_count = seq_len * skip_prob
454
+ skip_count = 0
455
+ probs = np.random.uniform(0, 1, len_output - 1)
456
+
457
+ for i in range(1, len_output):
458
+ if (
459
+ probs[i - 1] <= skip_prob
460
+ and i >= self.n_past
461
+ and skip_count < max_skip_count
462
+ and i != 1
463
+ and i != (len_output - 1)
464
+ and skip_frame
465
+ ):
466
+ skip_count += 1
467
+ gen_seq.append(tf.zeros_like(x_in))
468
+ continue
469
+
470
+ time_until_cp = tf.fill([batch_size, 1], (eval_cp_ix - i + 1) / eval_cp_ix)
471
+
472
+ delta_time = tf.fill([batch_size, 1], ((i - prev_i) / eval_cp_ix))
473
+
474
+ prev_i = i
475
+
476
+ h = self.encoder(x_in)
477
+
478
+ if self.last_frame_skip or i == 1 or i < self.n_past:
479
+ h, skip = h
480
+ else:
481
+ h, _ = h
482
+
483
+ h_cpaw = tf.concat([h, global_z, time_until_cp, delta_time], axis=1)
484
+
485
+ if i < self.n_past:
486
+ h_target = self.encoder(x[:, i, ...])[0]
487
+ h_target_cpaw = tf.concat(
488
+ [h_target, global_z, time_until_cp, delta_time], axis=1
489
+ )
490
+
491
+ zt, _, _ = self.posterior(h_target_cpaw)
492
+ zt_p, _, _ = self.prior(h_cpaw)
493
+
494
+ if model_mode == "posterior" or model_mode == "full":
495
+ self.frame_predictor(
496
+ tf.concat([h, zt, time_until_cp, delta_time], axis=1)
497
+ )
498
+ elif model_mode == "prior":
499
+ self.frame_predictor(
500
+ tf.concat([h, zt_p, time_until_cp, delta_time], axis=1)
501
+ )
502
+
503
+ x_in = x[:, i, ...]
504
+ gen_seq.append(x_in)
505
+ else:
506
+ if i < num_frames:
507
+ h_target = self.encoder(x[:, i, ...])[0]
508
+ h_target_cpaw = tf.concat(
509
+ [h_target, global_z, time_until_cp, delta_time], axis=1
510
+ )
511
+ else:
512
+ h_target_cpaw = h_cpaw
513
+
514
+ zt, _, _ = self.posterior(h_target_cpaw)
515
+ zt_p, _, _ = self.prior(h_cpaw)
516
+
517
+ if model_mode == "posterior":
518
+ h = self.frame_predictor(
519
+ tf.concat([h, zt, time_until_cp, delta_time], axis=1)
520
+ )
521
+ elif model_mode == "prior" or model_mode == "full":
522
+ h = self.frame_predictor(
523
+ tf.concat([h, zt_p, time_until_cp, delta_time], axis=1)
524
+ )
525
+
526
+ x_in = self.decoder([h, skip])
527
+ gen_seq.append(x_in)
528
+ return tf.stack(gen_seq, axis=1)
529
+
530
+ def update_model_without_prior(self, gradients, var_list):
531
+ self.frame_predictor_optimizer.apply_gradients(zip(gradients, var_list))
532
+ self.posterior_optimizer.apply_gradients(zip(gradients, var_list))
533
+ self.encoder_optimizer.apply_gradients(zip(gradients, var_list))
534
+ self.decoder_optimizer.apply_gradients(zip(gradients, var_list))
535
+
536
+ def update_prior(self, gradients, var_list):
537
+ self.prior_optimizer.apply_gradients(zip(gradients, var_list))
538
+
539
+ # def update_model_without_prior(self):
540
+ # self.frame_predictor_optimizer.step()
541
+ # self.posterior_optimizer.step()
542
+ # self.encoder_optimizer.step()
543
+ # self.decoder_optimizer.step()
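For reference, KLCriterion above evaluates, element-wise with logvar_k standing for \log\sigma_k^2,

    \mathrm{KL}\!\left(\mathcal{N}(\mu_1,\sigma_1^2)\,\middle\|\,\mathcal{N}(\mu_2,\sigma_2^2)\right) = \log\frac{\sigma_2}{\sigma_1} + \frac{\sigma_1^2 + (\mu_1 - \mu_2)^2}{2\sigma_2^2} - \frac{1}{2},

summed over all elements and divided by a hard-coded constant that stands in for the batch size (22 in this file, 100 in p2p_test.py below).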
ganime/model/p2p/p2p_test.py ADDED
@@ -0,0 +1,713 @@
1
+ from tqdm.auto import tqdm
2
+ import numpy as np
3
+ import tensorflow as tf
4
+ from tensorflow.keras import Model, Sequential
5
+ from tensorflow.keras.layers import (
6
+ LSTM,
7
+ LSTMCell,
8
+ Activation,
9
+ BatchNormalization,
10
+ Conv2D,
11
+ Conv2DTranspose,
12
+ Conv3D,
13
+ Conv3DTranspose,
14
+ Dense,
15
+ Flatten,
16
+ Input,
17
+ Layer,
18
+ LeakyReLU,
19
+ MaxPooling2D,
20
+ Reshape,
21
+ TimeDistributed,
22
+ UpSampling2D,
23
+ )
24
+ from tensorflow.keras.losses import Loss
25
+ from tensorflow.keras.losses import KLDivergence, MeanSquaredError
26
+
27
+ # from tensorflow_probability.python.layers.dense_variational import (
28
+ # DenseReparameterization,
29
+ # )
30
+ # import tensorflow_probability as tfp
31
+ from tensorflow.keras.losses import Loss
32
+
33
+ initializer_conv_dense = tf.keras.initializers.RandomNormal(mean=0.0, stddev=0.02)
34
+ initializer_batch_norm = tf.keras.initializers.RandomNormal(mean=1.0, stddev=0.02)
35
+
36
+
37
+ class KLCriterion(Loss):
38
+ def call(self, y_true, y_pred):
39
+ (mu1, logvar1), (mu2, logvar2) = y_true, y_pred
40
+
41
+ """KL( N(mu_1, sigma2_1) || N(mu_2, sigma2_2))"""
42
+ sigma1 = tf.exp(tf.math.multiply(logvar1, 0.5))
43
+ sigma2 = tf.exp(tf.math.multiply(logvar2, 0.5))
44
+
45
+ kld = (
46
+ tf.math.log(sigma2 / sigma1)
47
+ + (tf.exp(logvar1) + tf.square(mu1 - mu2)) / (2 * tf.exp(logvar2))
48
+ - 0.5
49
+ )
50
+ return tf.reduce_sum(kld) / 100
51
+
52
+
53
+ class Encoder(Model):
54
+ def __init__(self, dim, nc=1):
55
+ super().__init__()
56
+ self.dim = dim
57
+ self.c1 = Sequential(
58
+ [
59
+ Conv2D(
60
+ 64,
61
+ kernel_size=4,
62
+ strides=2,
63
+ padding="same",
64
+ kernel_initializer=initializer_conv_dense,
65
+ ),
66
+ # BatchNormalization(),
67
+ LeakyReLU(alpha=0.2),
68
+ ]
69
+ )
70
+ self.c2 = Sequential(
71
+ [
72
+ Conv2D(
73
+ 128,
74
+ kernel_size=4,
75
+ strides=2,
76
+ padding="same",
77
+ kernel_initializer=initializer_conv_dense,
78
+ ),
79
+ # BatchNormalization(),
80
+ LeakyReLU(alpha=0.2),
81
+ ]
82
+ )
83
+ self.c3 = Sequential(
84
+ [
85
+ Conv2D(
86
+ 256,
87
+ kernel_size=4,
88
+ strides=2,
89
+ padding="same",
90
+ kernel_initializer=initializer_conv_dense,
91
+ ),
92
+ # BatchNormalization(),
93
+ LeakyReLU(alpha=0.2),
94
+ ]
95
+ )
96
+ self.c4 = Sequential(
97
+ [
98
+ Conv2D(
99
+ 512,
100
+ kernel_size=4,
101
+ strides=2,
102
+ padding="same",
103
+ kernel_initializer=initializer_conv_dense,
104
+ ),
105
+ # BatchNormalization(),
106
+ LeakyReLU(alpha=0.2),
107
+ ]
108
+ )
109
+ self.c5 = Sequential(
110
+ [
111
+ Conv2D(
112
+ self.dim,
113
+ kernel_size=4,
114
+ strides=1,
115
+ padding="valid",
116
+ kernel_initializer=initializer_conv_dense,
117
+ ),
118
+ # BatchNormalization(),
119
+ Activation("tanh"),
120
+ ]
121
+ )
122
+
123
+ def call(self, input):
124
+ h1 = self.c1(input)
125
+ h2 = self.c2(h1)
126
+ h3 = self.c3(h2)
127
+ h4 = self.c4(h3)
128
+ h5 = self.c5(h4)
129
+ return tf.reshape(h5, (-1, self.dim)), [h1, h2, h3, h4, h5]
130
+
131
+
132
+ class Decoder(Model):
133
+ def __init__(self, dim, nc=1):
134
+ super().__init__()
135
+ self.dim = dim
136
+ self.upc1 = Sequential(
137
+ [
138
+ Conv2DTranspose(
139
+ 512,
140
+ kernel_size=4,
141
+ strides=1,
142
+ padding="valid",
143
+ kernel_initializer=initializer_conv_dense,
144
+ ),
145
+ # BatchNormalization(),
146
+ LeakyReLU(alpha=0.2),
147
+ ]
148
+ )
149
+ self.upc2 = Sequential(
150
+ [
151
+ Conv2DTranspose(
152
+ 256,
153
+ kernel_size=4,
154
+ strides=2,
155
+ padding="same",
156
+ kernel_initializer=initializer_conv_dense,
157
+ ),
158
+ # BatchNormalization(),
159
+ LeakyReLU(alpha=0.2),
160
+ ]
161
+ )
162
+ self.upc3 = Sequential(
163
+ [
164
+ Conv2DTranspose(
165
+ 128,
166
+ kernel_size=4,
167
+ strides=2,
168
+ padding="same",
169
+ kernel_initializer=initializer_conv_dense,
170
+ ),
171
+ # BatchNormalization(),
172
+ LeakyReLU(alpha=0.2),
173
+ ]
174
+ )
175
+ self.upc4 = Sequential(
176
+ [
177
+ Conv2DTranspose(
178
+ 64,
179
+ kernel_size=4,
180
+ strides=2,
181
+ padding="same",
182
+ kernel_initializer=initializer_conv_dense,
183
+ ),
184
+ # BatchNormalization(),
185
+ LeakyReLU(alpha=0.2),
186
+ ]
187
+ )
188
+ self.upc5 = Sequential(
189
+ [
190
+ Conv2DTranspose(
191
+ 1,
192
+ kernel_size=4,
193
+ strides=2,
194
+ padding="same",
195
+ kernel_initializer=initializer_conv_dense,
196
+ ),
197
+ Activation("sigmoid"),
198
+ ]
199
+ )
200
+
201
+ def call(self, input):
202
+ vec, skip = input
203
+ d1 = self.upc1(tf.reshape(vec, (-1, 1, 1, self.dim)))
204
+ d2 = self.upc2(tf.concat([d1, skip[3]], axis=-1))
205
+ d3 = self.upc3(tf.concat([d2, skip[2]], axis=-1))
206
+ d4 = self.upc4(tf.concat([d3, skip[1]], axis=-1))
207
+ output = self.upc5(tf.concat([d4, skip[0]], axis=-1))
208
+ return output
209
+
210
+
211
+ class MyLSTM(Model):
212
+ def __init__(self, input_shape, hidden_size, output_size, n_layers):
213
+ super().__init__()
214
+ self.hidden_size = hidden_size
215
+ self.n_layers = n_layers
216
+ self.embed = Dense(
217
+ hidden_size,
218
+ input_dim=input_shape,
219
+ kernel_initializer=initializer_conv_dense,
220
+ )
221
+ # self.lstm = Sequential(
222
+ # [LSTMCell(hidden_size) for _ in range(n_layers)], name="lstm"
223
+ # )
224
+ # self.lstm = self.create_lstm(hidden_size, n_layers)
225
+ self.lstm = [
226
+ LSTMCell(
227
+ hidden_size # , return_sequences=False if i == self.n_layers - 1 else True
228
+ )
229
+ for i in range(self.n_layers)
230
+ ] # LSTMCell(hidden_size)
231
+ self.lstm_rnn = tf.keras.layers.RNN(self.lstm[0], return_state=True)
232
+ self.out = Dense(output_size, kernel_initializer=initializer_conv_dense)
233
+
234
+ def init_hidden(self, batch_size):
235
+ hidden = []
236
+ for i in range(self.n_layers):
237
+ hidden.append(
238
+ (
239
+ tf.Variable(tf.zeros([batch_size, self.hidden_size])),
240
+ tf.Variable(tf.zeros([batch_size, self.hidden_size])),
241
+ )
242
+ )
243
+ self.__dict__["hidden"] = hidden
244
+
245
+ def build(self, input_shape):
246
+ self.init_hidden(input_shape[0])
247
+
248
+ def call(self, inputs):
249
+ h_in = self.embed(inputs)
250
+ h_in = tf.reshape(h_in, (-1, 1, self.hidden_size))
251
+ h_in, *state = self.lstm_rnn(h_in)
252
+ for i in range(self.n_layers):
253
+ h_in, state = self.lstm[i](h_in, state)
254
+ return self.out(h_in)
255
+
256
+
257
+ class MyGaussianLSTM(Model):
258
+ def __init__(self, input_shape, hidden_size, output_size, n_layers):
259
+ super().__init__()
260
+ self.hidden_size = hidden_size
261
+ self.n_layers = n_layers
262
+ self.embed = Dense(
263
+ hidden_size,
264
+ input_dim=input_shape,
265
+ kernel_initializer=initializer_conv_dense,
266
+ )
267
+ # self.lstm = Sequential(
268
+ # [LSTMCell(hidden_size) for _ in range(n_layers)], name="lstm"
269
+ # )
270
+ self.lstm = [
271
+ LSTMCell(
272
+ hidden_size # , return_sequences=False if i == self.n_layers - 1 else True
273
+ )
274
+ for i in range(self.n_layers)
275
+ ] # LSTMCell(hidden_size)
276
+ self.lstm_rnn = tf.keras.layers.RNN(self.lstm[0], return_state=True)
277
+ self.mu_net = Dense(output_size, kernel_initializer=initializer_conv_dense)
278
+ self.logvar_net = Dense(output_size, kernel_initializer=initializer_conv_dense)
279
+ # self.out = Sequential(
280
+ # [
281
+ # tf.keras.layers.Dense(
282
+ # tfp.layers.MultivariateNormalTriL.params_size(output_size),
283
+ # activation=None,
284
+ # ),
285
+ # tfp.layers.MultivariateNormalTriL(output_size),
286
+ # ]
287
+ # )
288
+
289
+ def reparameterize(self, mu, logvar: tf.Tensor):
290
+ logvar = tf.math.exp(logvar * 0.5)
291
+ eps = tf.random.normal(logvar.shape)
292
+ return tf.add(tf.math.multiply(eps, logvar), mu)
293
+
294
+ def init_hidden(self, batch_size):
295
+ hidden = []
296
+ for i in range(self.n_layers):
297
+ hidden.append(
298
+ (
299
+ tf.Variable(tf.zeros([batch_size, self.hidden_size])),
300
+ tf.Variable(tf.zeros([batch_size, self.hidden_size])),
301
+ )
302
+ )
303
+ self.__dict__["hidden"] = hidden
304
+
305
+ def build(self, input_shape):
306
+ self.init_hidden(input_shape[0])
307
+
308
+ def call(self, inputs):
309
+ h_in = self.embed(inputs)
310
+ # for i in range(self.n_layers):
311
+ # # print(h_in.shape, self.hidden[i][0].shape, self.hidden[i][0].shape)
312
+
313
+ # _, self.hidden[i] = self.lstm(h_in, self.hidden[i])
314
+ # h_in = self.hidden[i][0]
315
+ h_in = tf.reshape(h_in, (-1, 1, self.hidden_size))
316
+ h_in, *state = self.lstm_rnn(h_in)
317
+ for i in range(self.n_layers):
318
+ h_in, state = self.lstm[i](h_in, state)
319
+
320
+ mu = self.mu_net(h_in)
321
+ logvar = self.logvar_net(h_in)
322
+ z = self.reparameterize(mu, logvar)
323
+ return z, mu, logvar
324
+
325
+
326
+ class P2P(Model):
327
+ def __init__(
328
+ self,
329
+ channels: int = 1,
330
+ g_dim: int = 128,
331
+ z_dim: int = 10,
332
+ rnn_size: int = 256,
333
+ prior_rnn_layers: int = 1,
334
+ posterior_rnn_layers: int = 1,
335
+ predictor_rnn_layers: float = 2,
336
+ skip_prob: float = 0.5,
337
+ n_past: int = 1,
338
+ last_frame_skip: bool = False,
339
+ beta: float = 0.0001,
340
+ weight_align: float = 0.1,
341
+ weight_cpc: float = 100,
342
+ ):
343
+ super().__init__()
344
+ self.channels = channels
345
+ self.g_dim = g_dim
346
+ self.z_dim = z_dim
347
+ self.rnn_size = rnn_size
348
+ self.prior_rnn_layers = prior_rnn_layers
349
+ self.posterior_rnn_layers = posterior_rnn_layers
350
+ self.predictor_rnn_layers = predictor_rnn_layers
351
+
352
+ self.skip_prob = skip_prob
353
+ self.n_past = n_past
354
+ self.last_frame_skip = last_frame_skip
355
+ self.beta = beta
356
+ self.weight_align = weight_align
357
+ self.weight_cpc = weight_cpc
358
+
359
+ self.frame_predictor = MyLSTM(
360
+ self.g_dim + self.z_dim + 1 + 1,
361
+ self.rnn_size,
362
+ self.g_dim,
363
+ self.predictor_rnn_layers,
364
+ )
365
+
366
+ self.prior = MyGaussianLSTM(
367
+ self.g_dim + self.g_dim + 1 + 1,
368
+ self.rnn_size,
369
+ self.z_dim,
370
+ self.prior_rnn_layers,
371
+ )
372
+
373
+ self.posterior = MyGaussianLSTM(
374
+ self.g_dim + self.g_dim + 1 + 1,
375
+ self.rnn_size,
376
+ self.z_dim,
377
+ self.posterior_rnn_layers,
378
+ )
379
+
380
+ self.encoder = Encoder(self.g_dim, self.channels)
381
+ self.decoder = Decoder(self.g_dim, self.channels)
382
+
383
+ # criterions
384
+ self.mse_criterion = tf.keras.losses.MeanSquaredError()
385
+ self.kl_criterion = KLCriterion()
386
+ self.align_criterion = tf.keras.losses.MeanSquaredError()
387
+
388
+ self.total_loss_tracker = tf.keras.metrics.Mean(name="total_loss")
389
+ self.reconstruction_loss_tracker = tf.keras.metrics.Mean(
390
+ name="reconstruction_loss"
391
+ )
392
+ self.kl_loss_tracker = tf.keras.metrics.Mean(name="kl_loss")
393
+ self.align_loss_tracker = tf.keras.metrics.Mean(name="align_loss")
394
+ self.cpc_loss_tracker = tf.keras.metrics.Mean(name="cpc_loss")
395
+
396
+ # optimizers
397
+ # self.frame_predictor_optimizer = tf.keras.optimizers.Adam(
398
+ # learning_rate=0.001, beta_1=0.9, beta_2=0.999, epsilon=1e-8
399
+ # )
400
+ # self.posterior_optimizer = tf.keras.optimizers.Adam(
401
+ # learning_rate=0.001, beta_1=0.9, beta_2=0.999, epsilon=1e-8
402
+ # )
403
+ # self.prior_optimizer = tf.keras.optimizers.Adam(
404
+ # learning_rate=0.001, beta_1=0.9, beta_2=0.999, epsilon=1e-8
405
+ # )
406
+ # self.encoder_optimizer = tf.keras.optimizers.Adam(
407
+ # learning_rate=0.001, beta_1=0.9, beta_2=0.999, epsilon=1e-8
408
+ # )
409
+ # self.decoder_optimizer = tf.keras.optimizers.Adam(
410
+ # learning_rate=0.001, beta_1=0.9, beta_2=0.999, epsilon=1e-8
411
+ # )
412
+
413
+ @property
414
+ def metrics(self):
415
+ return [
416
+ self.total_loss_tracker,
417
+ self.reconstruction_loss_tracker,
418
+ self.kl_loss_tracker,
419
+ self.align_loss_tracker,
420
+ self.cpc_loss_tracker,
421
+ ]
422
+
423
+ def get_global_descriptor(self, x, start_ix=0, cp_ix=None):
424
+ """Get the global descriptor based on x, start_ix, cp_ix."""
425
+ if cp_ix is None:
426
+ cp_ix = x.shape[1] - 1
427
+
428
+ x_cp = x[:, cp_ix, ...]
429
+ h_cp = self.encoder(x_cp)[0] # 1 is input for skip-connection
430
+
431
+ return x_cp, h_cp
432
+
433
+ def compile(
434
+ self,
435
+ frame_predictor_optimizer,
436
+ prior_optimizer,
437
+ posterior_optimizer,
438
+ encoder_optimizer,
439
+ decoder_optimizer,
440
+ ):
441
+ super().compile()
442
+ self.frame_predictor_optimizer = frame_predictor_optimizer
443
+ self.prior_optimizer = prior_optimizer
444
+ self.posterior_optimizer = posterior_optimizer
445
+ self.encoder_optimizer = encoder_optimizer
446
+ self.decoder_optimizer = decoder_optimizer
447
+
448
+ def train_step(self, data):
449
+ y, x = data
450
+ batch_size = 100
451
+
452
+ mse_loss = 0
453
+ kld_loss = 0
454
+ cpc_loss = 0
455
+ align_loss = 0
456
+
457
+ seq_len = x.shape[1]
458
+ start_ix = 0
459
+ cp_ix = seq_len - 1
460
+ x_cp, global_z = self.get_global_descriptor(
461
+ x, start_ix, cp_ix
462
+ ) # here global_z is h_cp
463
+
464
+ skip_prob = self.skip_prob
465
+
466
+ prev_i = 0
467
+ max_skip_count = seq_len * skip_prob
468
+ skip_count = 0
469
+ probs = np.random.uniform(low=0, high=1, size=seq_len - 1)
470
+
471
+ with tf.GradientTape(persistent=True) as tape:
472
+ for i in tqdm(range(1, seq_len)):
473
+ if (
474
+ probs[i - 1] <= skip_prob
475
+ and i >= self.n_past
476
+ and skip_count < max_skip_count
477
+ and i != 1
478
+ and i != cp_ix
479
+ ):
480
+ skip_count += 1
481
+ continue
482
+
483
+ if i > 1:
484
+ align_loss += self.align_criterion(h, h_pred)
485
+
486
+ time_until_cp = tf.fill(
487
+ [batch_size, 1],
488
+ (cp_ix - i + 1) / cp_ix,
489
+ )
490
+ delta_time = tf.fill([batch_size, 1], ((i - prev_i) / cp_ix))
491
+ prev_i = i
492
+
493
+ h = self.encoder(x[:, i - 1, ...])
494
+ h_target = self.encoder(x[:, i, ...])[0]
495
+
496
+ if self.last_frame_skip or i <= self.n_past:
497
+ h, skip = h
498
+ else:
499
+ h = h[0]
500
+
501
+ # Control Point Aware
502
+ h_cpaw = tf.concat([h, global_z, time_until_cp, delta_time], axis=-1)
503
+ h_target_cpaw = tf.concat(
504
+ [h_target, global_z, time_until_cp, delta_time], axis=-1
505
+ )
506
+
507
+ zt, mu, logvar = self.posterior(h_target_cpaw)
508
+ zt_p, mu_p, logvar_p = self.prior(h_cpaw)
509
+
510
+ frame_predictor_input = tf.concat(
511
+ [h, zt, time_until_cp, delta_time], axis=-1
512
+ )
513
+ h_pred = self.frame_predictor(frame_predictor_input)
514
+ x_pred = self.decoder([h_pred, skip])
515
+
516
+ if i == cp_ix: # the gen-cp-frame should be exactly as x_cp
517
+ h_pred_p = self.frame_predictor(
518
+ tf.concat([h, zt_p, time_until_cp, delta_time], axis=-1)
519
+ )
520
+ x_pred_p = self.decoder([h_pred_p, skip])
521
+ cpc_loss = self.mse_criterion(x_pred_p, x_cp)
522
+
523
+ mse_loss += self.mse_criterion(x_pred, x[:, i, ...])
524
+ kld_loss += self.kl_criterion((mu, logvar), (mu_p, logvar_p))
525
+
526
+ # backward
527
+ loss = (
528
+ mse_loss
529
+ + kld_loss * self.beta
530
+ + align_loss * self.weight_align
531
+ # + cpc_loss * self.weight_cpc
532
+ )
533
+
534
+ prior_loss = kld_loss + cpc_loss * self.weight_cpc
535
+
536
+ var_list_frame_predictor = self.frame_predictor.trainable_variables
537
+ var_list_posterior = self.posterior.trainable_variables
538
+ var_list_prior = self.prior.trainable_variables
539
+ var_list_encoder = self.encoder.trainable_variables
540
+ var_list_decoder = self.decoder.trainable_variables
541
+
542
+ # mse: frame_predictor + decoder
543
+ # align: frame_predictor + encoder
544
+ # kld: posterior + prior + encoder
545
+
546
+ var_list = (
547
+ var_list_frame_predictor
548
+ + var_list_posterior
549
+ + var_list_encoder
550
+ + var_list_decoder
551
+ + var_list_prior
552
+ )
553
+
554
+ gradients = tape.gradient(
555
+ loss,
556
+ var_list,
557
+ )
558
+ gradients_prior = tape.gradient(
559
+ prior_loss,
560
+ var_list_prior,
561
+ )
562
+
563
+ self.update_model(
564
+ gradients,
565
+ var_list,
566
+ )
567
+ self.update_prior(gradients_prior, var_list_prior)
568
+ del tape
569
+
570
+ self.total_loss_tracker.update_state(loss)
571
+ self.kl_loss_tracker.update_state(kld_loss)
572
+ self.align_loss_tracker.update_state(align_loss)
573
+ self.reconstruction_loss_tracker.update_state(mse_loss)
574
+ self.cpc_loss_tracker.update_state(cpc_loss)
575
+
576
+ return {
577
+ "loss": self.total_loss_tracker.result(),
578
+ "reconstruction_loss": self.reconstruction_loss_tracker.result(),
579
+ "kl_loss": self.kl_loss_tracker.result(),
580
+ "align_loss": self.align_loss_tracker.result(),
581
+ "cpc_loss": self.cpc_loss_tracker.result(),
582
+ }
583
+
584
+ def call(
585
+ self,
586
+ inputs,
587
+ training=None,
588
+ mask=None
589
+ # len_output,
590
+ # eval_cp_ix,
591
+ # start_ix=0,
592
+ # cp_ix=-1,
593
+ # model_mode="full",
594
+ # skip_frame=False,
595
+ # init_hidden=True,
596
+ ):
597
+ len_output = 20
598
+ eval_cp_ix = len_output - 1
599
+ start_ix = 0
600
+ cp_ix = -1
601
+ model_mode = "full"
602
+ skip_frame = False
603
+ init_hidden = True
604
+
605
+ batch_size, num_frames, h, w, channels = inputs.shape
606
+ dim_shape = (h, w, channels)
607
+
608
+ gen_seq = [inputs[:, 0, ...]]
609
+ x_in = inputs[:, 0, ...]
610
+
611
+ seq_len = inputs.shape[1]
612
+ cp_ix = seq_len - 1
613
+
614
+ x_cp, global_z = self.get_global_descriptor(
615
+ inputs, cp_ix=cp_ix
616
+ ) # here global_z is h_cp
617
+
618
+ skip_prob = self.skip_prob
619
+
620
+ prev_i = 0
621
+ max_skip_count = seq_len * skip_prob
622
+ skip_count = 0
623
+ probs = np.random.uniform(0, 1, len_output - 1)
624
+
625
+ for i in range(1, len_output):
626
+ if (
627
+ probs[i - 1] <= skip_prob
628
+ and i >= self.n_past
629
+ and skip_count < max_skip_count
630
+ and i != 1
631
+ and i != (len_output - 1)
632
+ and skip_frame
633
+ ):
634
+ skip_count += 1
635
+ gen_seq.append(tf.zeros_like(x_in))
636
+ continue
637
+
638
+ time_until_cp = tf.fill([100, 1], (eval_cp_ix - i + 1) / eval_cp_ix)
639
+
640
+ delta_time = tf.fill([100, 1], ((i - prev_i) / eval_cp_ix))
641
+
642
+ prev_i = i
643
+
644
+ h = self.encoder(x_in)
645
+
646
+ if self.last_frame_skip or i == 1 or i < self.n_past:
647
+ h, skip = h
648
+ else:
649
+ h, _ = h
650
+
651
+ h_cpaw = tf.stop_gradient(tf.concat([h, global_z, time_until_cp, delta_time], axis=-1))
652
+
653
+ if i < self.n_past:
654
+ h_target = self.encoder(inputs[:, i, ...])[0]
655
+ h_target_cpaw = tf.stop_gradient(tf.concat(
656
+ [h_target, global_z, time_until_cp, delta_time], axis=1
657
+ ))
658
+
659
+ zt, _, _ = self.posterior(h_target_cpaw)
660
+ zt_p, _, _ = self.prior(h_cpaw)
661
+
662
+ if model_mode == "posterior" or model_mode == "full":
663
+ self.frame_predictor(
664
+ tf.concat([h, zt, time_until_cp, delta_time], axis=-1)
665
+ )
666
+ elif model_mode == "prior":
667
+ self.frame_predictor(
668
+ tf.concat([h, zt_p, time_until_cp, delta_time], axis=-1)
669
+ )
670
+
671
+ x_in = inputs[:, i, ...]
672
+ gen_seq.append(x_in)
673
+ else:
674
+ if i < num_frames:
675
+ h_target = self.encoder(inputs[:, i, ...])[0]
676
+ h_target_cpaw = tf.stop_gradient(tf.concat(
677
+ [h_target, global_z, time_until_cp, delta_time], axis=-1
678
+ ))
679
+ else:
680
+ h_target_cpaw = h_cpaw
681
+
682
+ zt, _, _ = self.posterior(h_target_cpaw)
683
+ zt_p, _, _ = self.prior(h_cpaw)
684
+
685
+ if model_mode == "posterior":
686
+ h = self.frame_predictor(
687
+ tf.concat([h, zt, time_until_cp, delta_time], axis=-1)
688
+ )
689
+ elif model_mode == "prior" or model_mode == "full":
690
+ h = self.frame_predictor(
691
+ tf.concat([h, zt_p, time_until_cp, delta_time], axis=-1)
692
+ )
693
+
694
+ x_in = tf.stop_gradient(self.decoder([h, skip]))
695
+ gen_seq.append(x_in)
696
+
697
+ return tf.stack(gen_seq, axis=1)
698
+
699
+ def update_model(self, gradients, var_list):
700
+ self.frame_predictor_optimizer.apply_gradients(zip(gradients, var_list))
701
+ self.posterior_optimizer.apply_gradients(zip(gradients, var_list))
702
+ self.encoder_optimizer.apply_gradients(zip(gradients, var_list))
703
+ self.decoder_optimizer.apply_gradients(zip(gradients, var_list))
704
+ #self.prior_optimizer.apply_gradients(zip(gradients, var_list))
705
+
706
+ def update_prior(self, gradients, var_list):
707
+ self.prior_optimizer.apply_gradients(zip(gradients, var_list))
708
+
709
+ # def update_model_without_prior(self):
710
+ # self.frame_predictor_optimizer.step()
711
+ # self.posterior_optimizer.step()
712
+ # self.encoder_optimizer.step()
713
+ # self.decoder_optimizer.step()
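This experimental variant overrides compile() with one optimizer per sub-network, so wiring it up looks roughly like the sketch below; the Adam settings mirror the commented-out values above and are placeholders rather than a recommendation:

    import tensorflow as tf
    from ganime.model.p2p.p2p_test import P2P

    model = P2P(channels=1, g_dim=128, z_dim=10, rnn_size=256)

    def adam():
        # Placeholder hyper-parameters matching the commented-out optimizer settings above.
        return tf.keras.optimizers.Adam(
            learning_rate=0.001, beta_1=0.9, beta_2=0.999, epsilon=1e-8
        )

    model.compile(
        frame_predictor_optimizer=adam(),
        prior_optimizer=adam(),
        posterior_optimizer=adam(),
        encoder_optimizer=adam(),
        decoder_optimizer=adam(),
    )
    # model.fit(dataset) would then exercise the custom train_step defined above.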
ganime/model/p2p/p2p_v2.py ADDED
@@ -0,0 +1,498 @@
1
+ import numpy as np
2
+ import tensorflow as tf
3
+ from tensorflow.keras import Model, Sequential
4
+ from tensorflow.keras.layers import (
5
+ LSTM,
6
+ Activation,
7
+ BatchNormalization,
8
+ Conv2D,
9
+ Conv2DTranspose,
10
+ Conv3D,
11
+ Conv3DTranspose,
12
+ Dense,
13
+ Flatten,
14
+ Input,
15
+ Layer,
16
+ LeakyReLU,
17
+ MaxPooling2D,
18
+ Reshape,
19
+ TimeDistributed,
20
+ UpSampling2D,
21
+ )
22
+ from tensorflow.keras.losses import Loss
23
+ from tensorflow.keras.losses import KLDivergence, MeanSquaredError
24
+ from tqdm.auto import tqdm
25
+
26
+
27
+ class KLCriterion(Loss):
28
+ def call(self, y_true, y_pred):
29
+ (mu1, logvar1), (mu2, logvar2) = y_true, y_pred
30
+
31
+ """KL( N(mu_1, sigma2_1) || N(mu_2, sigma2_2))"""
32
+ sigma1 = tf.exp(tf.math.multiply(logvar1, 0.5))
33
+ sigma2 = tf.exp(tf.math.multiply(logvar2, 0.5))
34
+
35
+ kld = (
36
+ tf.math.log(sigma2 / sigma1)
37
+ + (tf.exp(logvar1) + tf.square(mu1 - mu2)) / (2 * tf.exp(logvar2))
38
+ - 0.5
39
+ )
40
+ return kld
41
+
42
+
43
+ class Decoder(Model):
44
+ def __init__(self, dim, nc=1):
45
+ super().__init__()
46
+ self.dim = dim
47
+ self.upc1 = Sequential(
48
+ [
49
+ TimeDistributed(
50
+ Conv2DTranspose(512, kernel_size=4, strides=1, padding="valid")
51
+ ),
52
+ BatchNormalization(),
53
+ LeakyReLU(alpha=0.2),
54
+ ]
55
+ )
56
+ self.upc2 = Sequential(
57
+ [
58
+ TimeDistributed(
59
+ Conv2DTranspose(256, kernel_size=4, strides=2, padding="same")
60
+ ),
61
+ BatchNormalization(),
62
+ LeakyReLU(alpha=0.2),
63
+ ]
64
+ )
65
+ self.upc3 = Sequential(
66
+ [
67
+ TimeDistributed(
68
+ Conv2DTranspose(128, kernel_size=4, strides=2, padding="same")
69
+ ),
70
+ BatchNormalization(),
71
+ LeakyReLU(alpha=0.2),
72
+ ]
73
+ )
74
+ self.upc4 = Sequential(
75
+ [
76
+ TimeDistributed(
77
+ Conv2DTranspose(64, kernel_size=4, strides=2, padding="same")
78
+ ),
79
+ BatchNormalization(),
80
+ LeakyReLU(alpha=0.2),
81
+ ]
82
+ )
83
+ self.upc5 = Sequential(
84
+ [
85
+ TimeDistributed(
86
+ Conv2DTranspose(1, kernel_size=4, strides=2, padding="same")
87
+ ),
88
+ Activation("sigmoid"),
89
+ ]
90
+ )
91
+
92
+ def call(self, input):
93
+ vec, skip = input
94
+ d1 = self.upc1(tf.reshape(vec, (-1, 1, 1, 1, self.dim)))
95
+ d2 = self.upc2(tf.concat([d1, skip[3]], axis=-1))
96
+ d3 = self.upc3(tf.concat([d2, skip[2]], axis=-1))
97
+ d4 = self.upc4(tf.concat([d3, skip[1]], axis=-1))
98
+ output = self.upc5(tf.concat([d4, skip[0]], axis=-1))
99
+ return output
100
+
101
+
102
+ class Sampling(Layer):
103
+ """Uses (z_mean, z_log_var) to sample z, the vector encoding a digit."""
104
+
105
+ def call(self, inputs):
106
+ z_mean, z_log_var = inputs
107
+ batch = tf.shape(z_mean)[0]
108
+ dim = tf.shape(z_mean)[1]
109
+ epsilon = tf.keras.backend.random_normal(shape=(batch, dim))
110
+ return z_mean + tf.exp(0.5 * z_log_var) * epsilon
111
+
112
+ def compute_output_shape(self, input_shape):
113
+ return input_shape[0]
114
+
115
+
116
+ class P2P(Model):
117
+ def __init__(
118
+ self,
119
+ channels: int = 1,
120
+ g_dim: int = 128,
121
+ z_dim: int = 10,
122
+ rnn_size: int = 256,
123
+ prior_rnn_layers: int = 1,
124
+ posterior_rnn_layers: int = 1,
125
+ predictor_rnn_layers: float = 1,
126
+ skip_prob: float = 0.1,
127
+ n_past: int = 1,
128
+ last_frame_skip: bool = False,
129
+ beta: float = 0.0001,
130
+ weight_align: float = 0.1,
131
+ weight_cpc: float = 100,
132
+ ):
133
+ super().__init__()
134
+ # Models parameters
135
+ self.channels = channels
136
+ self.g_dim = g_dim
137
+ self.z_dim = z_dim
138
+ self.rnn_size = rnn_size
139
+ self.prior_rnn_layers = prior_rnn_layers
140
+ self.posterior_rnn_layers = posterior_rnn_layers
141
+ self.predictor_rnn_layers = predictor_rnn_layers
142
+
143
+ # Training parameters
144
+ self.skip_prob = skip_prob
145
+ self.n_past = n_past
146
+ self.last_frame_skip = last_frame_skip
147
+ self.beta = beta
148
+ self.weight_align = weight_align
149
+ self.weight_cpc = weight_cpc
150
+
151
+ self.frame_predictor = self.build_lstm()
152
+ self.prior = self.build_gaussian_lstm()
153
+ self.posterior = self.build_gaussian_lstm()
154
+ self.encoder = self.build_encoder()
155
+ self.decoder = self.build_decoder()
156
+
157
+ self.total_loss_tracker = tf.keras.metrics.Mean(name="total_loss")
158
+ self.reconstruction_loss_tracker = tf.keras.metrics.Mean(
159
+ name="reconstruction_loss"
160
+ )
161
+ self.kl_loss_tracker = tf.keras.metrics.Mean(name="kl_loss")
162
+ self.align_loss_tracker = tf.keras.metrics.Mean(name="align_loss")
163
+ self.cpc_loss_tracker = tf.keras.metrics.Mean(name="align_loss")
164
+
165
+ self.kl_loss = KLCriterion(
166
+ reduction=tf.keras.losses.Reduction.NONE
167
+ ) # KLDivergence(reduction=tf.keras.losses.Reduction.NONE)
168
+ self.mse = MeanSquaredError(reduction=tf.keras.losses.Reduction.NONE)
169
+ self.align_loss = MeanSquaredError(reduction=tf.keras.losses.Reduction.NONE)
170
+
171
+ # self.optimizer = tf.keras.optimizers.Adam(
172
+ # learning_rate=0.001, beta_1=0.9, beta_2=0.999, epsilon=1e-8
173
+ # )
174
+ # self.prior_optimizer = tf.keras.optimizers.Adam(
175
+ # learning_rate=0.001, beta_1=0.9, beta_2=0.999, epsilon=1e-8
176
+ # )
177
+
178
+ # region Model building
179
+ def build_lstm(self):
180
+ input = Input(shape=(None, self.g_dim + self.z_dim))
181
+ embed = TimeDistributed(Dense(self.rnn_size))(input)
182
+ lstm = LSTM(self.rnn_size)(embed)
183
+ output = Dense(self.g_dim)(lstm)
184
+ output = (tf.expand_dims(output, axis=1),)
185
+
186
+ return Model(inputs=input, outputs=output, name="frame_predictor")
187
+
188
+ def build_gaussian_lstm(self):
189
+
190
+ input = Input(shape=(None, self.g_dim))
191
+ embed = TimeDistributed(Dense(self.rnn_size))(input)
192
+ lstm = LSTM(self.rnn_size)(embed)
193
+ mu = Dense(self.z_dim)(lstm)
194
+ logvar = Dense(self.z_dim)(lstm)
195
+ z = Sampling()([mu, logvar])
196
+
197
+ return Model(inputs=input, outputs=[mu, logvar, z])
198
+
199
+ def build_encoder(self):
200
+
201
+ input = Input(shape=(1, 64, 64, 1))
202
+
203
+ h = TimeDistributed(Conv2D(64, kernel_size=4, strides=2, padding="same"))(input)
204
+ h = BatchNormalization()(h)
205
+ h1 = LeakyReLU(alpha=0.2)(h)
206
+ # h = TimeDistributed(MaxPooling2D(pool_size=2, strides=2, padding="same"))(h)
207
+
208
+ h = TimeDistributed(Conv2D(128, kernel_size=4, strides=2, padding="same"))(h1)
209
+ h = BatchNormalization()(h)
210
+ h2 = LeakyReLU(alpha=0.2)(h)
211
+ # h = TimeDistributed(MaxPooling2D(pool_size=2, strides=2, padding="same"))(h)
212
+
213
+ h = TimeDistributed(Conv2D(256, kernel_size=4, strides=2, padding="same"))(h2)
214
+ h = BatchNormalization()(h)
215
+ h3 = LeakyReLU(alpha=0.2)(h)
216
+ # h = TimeDistributed(MaxPooling2D(pool_size=2, strides=2, padding="same"))(h)
217
+
218
+ h = TimeDistributed(Conv2D(512, kernel_size=4, strides=2, padding="same"))(h3)
219
+ h = BatchNormalization()(h)
220
+ h4 = LeakyReLU(alpha=0.2)(h)
221
+ # h = TimeDistributed(MaxPooling2D(pool_size=2, strides=2, padding="same"))(h)
222
+
223
+ h = TimeDistributed(
224
+ Conv2D(self.g_dim, kernel_size=4, strides=1, padding="valid")
225
+ )(h4)
226
+ h = BatchNormalization()(h)
227
+ h5 = Activation("tanh")(h)
228
+
229
+ output = tf.reshape(h5, (-1, 1, self.g_dim))
230
+ # h = Flatten()(h)
231
+ # output = Dense(self.g_dim)(h)
232
+ # output = tf.expand_dims(output, axis=1)
233
+ return Model(inputs=input, outputs=[output, [h1, h2, h3, h4]], name="encoder")
234
+
235
+ def build_decoder(self):
236
+ return Decoder(self.g_dim)
237
+
238
+ # def build_decoder(self):
239
+ # latent_inputs = Input(
240
+ # shape=(
241
+ # 1,
242
+ # self.g_dim,
243
+ # )
244
+ # )
245
+ # x = Dense(1 * 1 * 1 * 128, activation="relu")(latent_inputs)
246
+ # x = Reshape((1, 1, 1, 128))(x)
247
+ # x = TimeDistributed(
248
+ # Conv2DTranspose(512, kernel_size=4, strides=1, padding="valid")
249
+ # )(x)
250
+ # x = BatchNormalization()(x)
251
+ # x1 = LeakyReLU(alpha=0.2)(x)
252
+
253
+ # x = TimeDistributed(
254
+ # Conv2DTranspose(256, kernel_size=4, strides=2, padding="same")
255
+ # )(x1)
256
+ # x = BatchNormalization()(x)
257
+ # x2 = LeakyReLU(alpha=0.2)(x)
258
+
259
+ # x = TimeDistributed(
260
+ # Conv2DTranspose(128, kernel_size=4, strides=2, padding="same")
261
+ # )(x2)
262
+ # x = BatchNormalization()(x)
263
+ # x3 = LeakyReLU(alpha=0.2)(x)
264
+
265
+ # x = TimeDistributed(
266
+ # Conv2DTranspose(64, kernel_size=4, strides=2, padding="same")
267
+ # )(x3)
268
+ # x = BatchNormalization()(x)
269
+ # x4 = LeakyReLU(alpha=0.2)(x)
270
+
271
+ # x = TimeDistributed(
272
+ # Conv2DTranspose(1, kernel_size=4, strides=2, padding="same")
273
+ # )(x4)
274
+ # x5 = Activation("sigmoid")(x)
275
+
276
+ # return Model(inputs=latent_inputs, outputs=x5, name="decoder")
277
+
278
+ # endregion
279
+
280
+ @property
281
+ def metrics(self):
282
+ return [
283
+ self.total_loss_tracker,
284
+ self.reconstruction_loss_tracker,
285
+ self.kl_loss_tracker,
286
+ self.align_loss_tracker,
287
+ self.cpc_loss_tracker,
288
+ ]
289
+
290
+ def call(self, inputs, training=None, mask=None):
291
+ first_frame = inputs[:, 0:1, ...]
292
+ last_frame = inputs[:, -1:, ...]
293
+
294
+ desired_length = 20
295
+ previous_frame = first_frame
296
+ generated = [first_frame]
297
+
298
+ z_last, _ = self.encoder(last_frame)
299
+ for i in range(1, desired_length):
300
+
301
+ z_prev = self.encoder(previous_frame)
302
+
303
+ if self.last_frame_skip or i == 1 or i < self.n_past:
304
+ z_prev, skip = z_prev
305
+ else:
306
+ z_prev = z_prev[0]
307
+
308
+ prior_input = tf.concat([z_prev, z_last], axis=1)
309
+
310
+ z_mean_prior, z_log_var_prior, z_prior = self.prior(prior_input)
311
+
312
+ predictor_input = tf.concat(
313
+ (z_prev, tf.expand_dims(z_prior, axis=1)), axis=-1
314
+ )
315
+ z_pred = self.frame_predictor(predictor_input)
316
+
317
+ current_frame = self.decoder([z_pred, skip])
318
+ generated.append(current_frame)
319
+ previous_frame = current_frame
320
+ return tf.concat(generated, axis=1)
321
+
322
+ def train_step(self, data):
323
+ global_batch_size = 100 # * 8
324
+ x, y = data
325
+
326
+ first_frame = x[:, 0:1, ...]
327
+ last_frame = x[:, -1:, ...]
328
+ desired_length = y.shape[1]
329
+ previous_frame = first_frame
330
+
331
+ reconstruction_loss = 0
332
+ kl_loss = 0
333
+ align_loss = 0
334
+ cpc_loss = 0
335
+
336
+ with tf.GradientTape(persistent=True) as tape:
337
+ z_last, _ = self.encoder(last_frame)
338
+ for i in tqdm(range(1, desired_length)):
339
+ current_frame = y[:, i : i + 1, ...]
340
+
341
+ z_prev = self.encoder(previous_frame)
342
+
343
+ if self.last_frame_skip or i <= self.n_past:
344
+ z_prev, skip = z_prev
345
+ else:
346
+ z_prev = z_prev[0]
347
+
348
+ z_curr, _ = self.encoder(current_frame)
349
+
350
+ prior_input = tf.concat([z_prev, z_last], axis=1)
351
+ posterior_input = tf.concat([z_curr, z_last], axis=1)
352
+
353
+ z_mean_prior, z_log_var_prior, z_prior = self.prior(prior_input)
354
+ z_mean_posterior, z_log_var_posterior, z_posterior = self.posterior(
355
+ posterior_input
356
+ )
357
+
358
+ # predictor_input = z_prev
359
+ predictor_input = tf.concat(
360
+ (z_prev, tf.expand_dims(z_posterior, axis=1)), axis=-1
361
+ )
362
+
363
+ z_pred = self.frame_predictor(predictor_input)
364
+
365
+ kl_loss += tf.reduce_sum(
366
+ self.kl_loss(
367
+ (z_mean_prior, z_log_var_prior),
368
+ (z_mean_posterior, z_log_var_posterior),
369
+ )
370
+ ) * (1.0 / global_batch_size)
371
+
372
+ if i > 1:
373
+ align_loss += tf.reduce_sum(self.align_loss(z_pred, z_curr)) * (
374
+ 1.0 / global_batch_size
375
+ )
376
+
377
+ if i == desired_length - 1:
378
+ h_pred_p = self.frame_predictor(
379
+ tf.concat([z_prev, tf.expand_dims(z_prior, axis=1)], axis=-1)
380
+ )
381
+ x_pred_p = self.decoder([h_pred_p, skip])
382
+ cpc_loss = tf.reduce_sum(self.mse(x_pred_p, current_frame)) * (
383
+ 1.0 / global_batch_size
384
+ )
385
+
386
+ prediction = self.decoder([z_pred, skip])
387
+ reconstruction_loss += tf.reduce_sum(
388
+ self.mse(prediction, current_frame)
389
+ ) * (1.0 / global_batch_size)
390
+
391
+ previous_frame = current_frame
392
+
393
+ loss = (
394
+ reconstruction_loss
395
+ + kl_loss * self.beta
396
+ + align_loss * self.weight_align
397
+ + cpc_loss * self.weight_cpc
398
+ )
399
+
400
+ prior_loss = kl_loss + cpc_loss * self.weight_cpc
401
+
402
+ grads_without_prior = tape.gradient(
403
+ loss,
404
+ (
405
+ self.encoder.trainable_weights
406
+ + self.decoder.trainable_weights
407
+ + self.posterior.trainable_weights
408
+ + self.frame_predictor.trainable_weights
409
+ ),
410
+ )
411
+ self.optimizer.apply_gradients(
412
+ zip(
413
+ grads_without_prior,
414
+ (
415
+ self.encoder.trainable_weights
416
+ + self.decoder.trainable_weights
417
+ + self.posterior.trainable_weights
418
+ + self.frame_predictor.trainable_weights
419
+ ),
420
+ )
421
+ )
422
+
423
+ grads_prior = tape.gradient(
424
+ prior_loss,
425
+ self.prior.trainable_weights,
426
+ )
427
+
428
+ self.optimizer.apply_gradients(
429
+ zip(
430
+ grads_prior,
431
+ self.prior.trainable_weights,
432
+ )
433
+ )
434
+ del tape
435
+
436
+ self.total_loss_tracker.update_state(loss)
437
+ self.kl_loss_tracker.update_state(kl_loss)
438
+ self.align_loss_tracker.update_state(align_loss)
439
+ self.reconstruction_loss_tracker.update_state(reconstruction_loss)
440
+ self.cpc_loss_tracker.update_state(cpc_loss)
441
+
442
+ return {
443
+ "loss": self.total_loss_tracker.result(),
444
+ "reconstruction_loss": self.reconstruction_loss_tracker.result(),
445
+ "kl_loss": self.kl_loss_tracker.result(),
446
+ "align_loss": self.align_loss_tracker.result(),
447
+ "cpc_loss": self.cpc_loss_tracker.result(),
448
+ }
449
+
450
+ # print("KL_LOSS")
451
+ # print(kl_loss)
452
+ # print("ALIGN_LOSS")
453
+ # print(align_loss)
454
+ # print("RECONSTRUCTION_LOSS")
455
+ # print(reconstruction_loss)
456
+
457
+ # with tf.GradientTape() as tape:
458
+ # z_mean, z_log_var, z = self.encoder(x)
459
+ # reconstruction = self.decoder(z)
460
+ # reconstruction_loss = tf.reduce_mean(
461
+ # tf.reduce_sum(
462
+ # tf.keras.losses.binary_crossentropy(y, reconstruction),
463
+ # axis=(1, 2),
464
+ # )
465
+ # )
466
+ # kl_loss = -0.5 * (1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var))
467
+ # kl_loss = tf.reduce_mean(tf.reduce_sum(kl_loss, axis=1))
468
+ # total_loss = reconstruction_loss + self.kl_beta * kl_loss
469
+ # grads = tape.gradient(total_loss, self.trainable_weights)
470
+ # self.optimizer.apply_gradients(zip(grads, self.trainable_weights))
471
+ # self.total_loss_tracker.update_state(total_loss)
472
+ # self.reconstruction_loss_tracker.update_state(reconstruction_loss)
473
+ # self.kl_loss_tracker.update_state(kl_loss)
474
+ # return {
475
+ # "loss": self.total_loss_tracker.result(),
476
+ # "reconstruction_loss": self.reconstruction_loss_tracker.result(),
477
+ # "kl_loss": self.kl_loss_tracker.result(),
478
+ # }
479
+
480
+ # def test_step(self, data):
481
+ # if isinstance(data, tuple):
482
+ # data = data[0]
483
+
484
+ # z_mean, z_log_var, z = self.encoder(data)
485
+ # reconstruction = self.decoder(z)
486
+ # reconstruction_loss = tf.reduce_mean(
487
+ # tf.keras.losses.binary_crossentropy(data, reconstruction)
488
+ # )
489
+ # reconstruction_loss *= 28 * 28
490
+ # kl_loss = 1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var)
491
+ # kl_loss = tf.reduce_mean(kl_loss)
492
+ # kl_loss *= -0.5
493
+ # total_loss = reconstruction_loss + kl_loss
494
+ # return {
495
+ # "loss": total_loss,
496
+ # "reconstruction_loss": reconstruction_loss,
497
+ # "kl_loss": kl_loss,
498
+ # }
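
Note on the training loop above: the prior network is optimised separately from the encoder/decoder/posterior/frame-predictor group, which is why a persistent `tf.GradientTape` is opened and `apply_gradients` is called twice. The following standalone sketch (toy layers and losses, not code from this repository) shows the same two-group update pattern in isolation:

import tensorflow as tf

# Two toy sub-networks standing in for the encoder/decoder/posterior/predictor group
# and the prior network.
main_net = tf.keras.Sequential([tf.keras.layers.Dense(8, activation="relu"), tf.keras.layers.Dense(1)])
prior_net = tf.keras.Sequential([tf.keras.layers.Dense(1)])
optimizer = tf.keras.optimizers.Adam(1e-3)

x = tf.random.normal((4, 3))
y = tf.random.normal((4, 1))

with tf.GradientTape(persistent=True) as tape:
    pred = main_net(x) + prior_net(x)              # both groups contribute to the output
    full_loss = tf.reduce_mean((pred - y) ** 2)    # analogue of reconstruction + kl + align + cpc
    prior_loss = 0.1 * full_loss                   # analogue of kl_loss + cpc_loss * weight_cpc

main_vars = main_net.trainable_weights
prior_vars = prior_net.trainable_weights
optimizer.apply_gradients(zip(tape.gradient(full_loss, main_vars), main_vars))
optimizer.apply_gradients(zip(tape.gradient(prior_loss, prior_vars), prior_vars))
del tape  # a persistent tape must be released explicitly, as in train_step above
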
ganime/model/p2p/p2p_v3.py ADDED
@@ -0,0 +1,237 @@
1
+ import numpy as np
2
+ import tensorflow as tf
3
+ from tensorflow.keras import Model, Sequential
4
+ from tensorflow.keras.layers import (
5
+ LSTM,
6
+ Activation,
7
+ BatchNormalization,
8
+ Conv2D,
9
+ Conv2DTranspose,
10
+ Conv3D,
11
+ Conv3DTranspose,
12
+ Dense,
13
+ Flatten,
14
+ Input,
15
+ Layer,
16
+ LeakyReLU,
17
+ MaxPooling2D,
18
+ Reshape,
19
+ TimeDistributed,
20
+ UpSampling2D,
21
+ )
22
+
23
+
24
+ SEQ_LEN = 20
25
+
26
+
27
+ class Sampling(Layer):
28
+ """Uses (z_mean, z_log_var) to sample z, the vector encoding a digit."""
29
+
30
+ def call(self, inputs):
31
+ z_mean, z_log_var = inputs
32
+ # sample with the full shape of z_mean so this also works on (batch, time, z_dim) tensors
+ epsilon = tf.keras.backend.random_normal(shape=tf.shape(z_mean))
35
+ return z_mean + tf.exp(0.5 * z_log_var) * epsilon
36
+
37
+ def compute_output_shape(self, input_shape):
38
+ return input_shape[0]
39
+
40
+
41
+ class P2P(Model):
42
+ def __init__(
43
+ self,
44
+ channels: int = 1,
45
+ g_dim: int = 128,
46
+ z_dim: int = 10,
47
+ rnn_size: int = 256,
48
+ prior_rnn_layers: int = 1,
49
+ posterior_rnn_layers: int = 1,
50
+ predictor_rnn_layers: int = 1,
51
+ skip_prob: float = 0.1,
52
+ n_past: int = 1,
53
+ last_frame_skip: bool = False,
54
+ beta: float = 0.0001,
55
+ weight_align: float = 0.1,
56
+ weight_cpc: float = 100,
57
+ ):
58
+ super().__init__()
59
+ # Models parameters
60
+ self.channels = channels
61
+ self.g_dim = g_dim
62
+ self.z_dim = z_dim
63
+ self.rnn_size = rnn_size
64
+ self.prior_rnn_layers = prior_rnn_layers
65
+ self.posterior_rnn_layers = posterior_rnn_layers
66
+ self.predictor_rnn_layers = predictor_rnn_layers
67
+
68
+ # Training parameters
69
+ self.skip_prob = skip_prob
70
+ self.n_past = n_past
71
+ self.last_frame_skip = last_frame_skip
72
+ self.beta = beta
73
+ self.weight_align = weight_align
74
+ self.weight_cpc = weight_cpc
75
+
76
+ self.frame_predictor = self.build_lstm()
77
+ self.prior = self.build_gaussian_lstm()
78
+ self.posterior = self.build_gaussian_lstm()
79
+ self.encoder = self.build_encoder()
80
+ self.decoder = self.build_decoder()
81
+
82
+ self.total_loss_tracker = tf.keras.metrics.Mean(name="total_loss")
83
+ self.reconstruction_loss_tracker = tf.keras.metrics.Mean(
84
+ name="reconstruction_loss"
85
+ )
86
+ self.kl_loss_tracker = tf.keras.metrics.Mean(name="kl_loss")
87
+
88
+ # region Model building
89
+ def build_lstm(self):
90
+ input = Input(shape=(20, self.g_dim + self.z_dim + 1))
91
+ embed = TimeDistributed(Dense(self.rnn_size))(input)
92
+ lstm = LSTM(self.rnn_size, return_sequences=True)(embed)
93
+ output = TimeDistributed(Dense(self.g_dim))(lstm)
94
+
95
+ return Model(inputs=input, outputs=output, name="frame_predictor")
96
+
97
+ def build_gaussian_lstm(self):
98
+
99
+ input = Input(shape=(20, self.g_dim))
100
+ embed = TimeDistributed(Dense(self.rnn_size))(input)
101
+ lstm = LSTM(self.rnn_size, return_sequences=True)(embed)
102
+ mu = TimeDistributed(Dense(self.z_dim))(lstm)
103
+ logvar = TimeDistributed(Dense(self.z_dim))(lstm)
104
+ # Sampling broadcasts over the time axis itself; TimeDistributed cannot wrap a multi-input layer
+ z = Sampling()([mu, logvar])
105
+
106
+ return Model(inputs=input, outputs=[mu, logvar, z])
107
+
108
+ def build_encoder(self):
109
+
110
+ input = Input(shape=(2, 64, 64, 1))
111
+
112
+ h = TimeDistributed(Conv2D(64, kernel_size=4, strides=2, padding="same"))(input)
113
+ h = BatchNormalization()(h)
114
+ h = LeakyReLU(alpha=0.2)(h)
115
+ # h = TimeDistributed(MaxPooling2D(pool_size=2, strides=2, padding="same"))(h)
116
+
117
+ h = TimeDistributed(Conv2D(128, kernel_size=4, strides=2, padding="same"))(h)
118
+ h = BatchNormalization()(h)
119
+ h = LeakyReLU(alpha=0.2)(h)
120
+ # h = TimeDistributed(MaxPooling2D(pool_size=2, strides=2, padding="same"))(h)
121
+
122
+ h = TimeDistributed(Conv2D(256, kernel_size=4, strides=2, padding="same"))(h)
123
+ h = BatchNormalization()(h)
124
+ h = LeakyReLU(alpha=0.2)(h)
125
+ # h = TimeDistributed(MaxPooling2D(pool_size=2, strides=2, padding="same"))(h)
126
+
127
+ h = TimeDistributed(Conv2D(512, kernel_size=4, strides=2, padding="same"))(h)
128
+ h = BatchNormalization()(h)
129
+ h = LeakyReLU(alpha=0.2)(h)
130
+ # h = TimeDistributed(MaxPooling2D(pool_size=2, strides=2, padding="same"))(h)
131
+
132
+ h = Flatten()(h)
133
+ # mu = Dense(self.g_dim)(h)
134
+ # logvar = Dense(self.g_dim)(h)
135
+
136
+ # z = Sampling()([mu, logvar])
137
+ lstm_input = Dense(self.g_dim * SEQ_LEN)(h)
138
+ lstm_input = Reshape((SEQ_LEN, self.g_dim))(lstm_input)
139
+ mu, logvar, z = self.posterior(lstm_input)
140
+
141
+ return Model(inputs=input, outputs=[mu, logvar, z], name="encoder")
142
+
143
+ def build_decoder(self):
144
+ latent_inputs = Input(shape=(SEQ_LEN, self.z_dim))
145
+ x = Dense(1 * 1 * 1 * 512, activation="relu")(latent_inputs)
146
+ x = Reshape((SEQ_LEN, 1, 1, 512))(x)
147
+ x = TimeDistributed(
148
+ Conv2DTranspose(512, kernel_size=4, strides=1, padding="valid")
149
+ )(x)
150
+ x = BatchNormalization()(x)
151
+ x = LeakyReLU(alpha=0.2)(x)
152
+
153
+ x = TimeDistributed(
154
+ Conv2DTranspose(256, kernel_size=4, strides=2, padding="same")
155
+ )(x)
156
+ x = BatchNormalization()(x)
157
+ x = LeakyReLU(alpha=0.2)(x)
158
+
159
+ x = TimeDistributed(
160
+ Conv2DTranspose(128, kernel_size=4, strides=2, padding="same")
161
+ )(x)
162
+ x = BatchNormalization()(x)
163
+ x = LeakyReLU(alpha=0.2)(x)
164
+
165
+ x = TimeDistributed(
166
+ Conv2DTranspose(64, kernel_size=4, strides=2, padding="same")
167
+ )(x)
168
+ x = BatchNormalization()(x)
169
+ x = LeakyReLU(alpha=0.2)(x)
170
+
171
+ x = TimeDistributed(
172
+ Conv2DTranspose(1, kernel_size=4, strides=2, padding="same")
173
+ )(x)
174
+ x = Activation("sigmoid")(x)
175
+
176
+ return Model(inputs=latent_inputs, outputs=x, name="decoder")
177
+
178
+ # endregion
179
+
180
+ @property
181
+ def metrics(self):
182
+ return [
183
+ self.total_loss_tracker,
184
+ self.reconstruction_loss_tracker,
185
+ self.kl_loss_tracker,
186
+ ]
187
+
188
+ def call(self, inputs, training=None, mask=None):
189
+ z_mean, z_log_var, z = self.encoder(inputs)
190
+ pred = self.decoder(z)
191
+ return pred
192
+
193
+ def train_step(self, data):
194
+ x, y = data
195
+
196
+ with tf.GradientTape() as tape:
197
+ z_mean, z_log_var, z = self.encoder(x)
198
+ reconstruction = self.decoder(z)
199
+ reconstruction_loss = tf.reduce_mean(
200
+ tf.reduce_sum(
201
+ tf.keras.losses.binary_crossentropy(y, reconstruction),
202
+ axis=(1, 2),
203
+ )
204
+ )
205
+ kl_loss = -0.5 * (1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var))
206
+ kl_loss = tf.reduce_mean(tf.reduce_sum(kl_loss, axis=1))
207
+ total_loss = reconstruction_loss + self.beta * kl_loss
208
+ grads = tape.gradient(total_loss, self.trainable_weights)
209
+ self.optimizer.apply_gradients(zip(grads, self.trainable_weights))
210
+ self.total_loss_tracker.update_state(total_loss)
211
+ self.reconstruction_loss_tracker.update_state(reconstruction_loss)
212
+ self.kl_loss_tracker.update_state(kl_loss)
213
+ return {
214
+ "loss": self.total_loss_tracker.result(),
215
+ "reconstruction_loss": self.reconstruction_loss_tracker.result(),
216
+ "kl_loss": self.kl_loss_tracker.result(),
217
+ }
218
+
219
+ def test_step(self, data):
220
+ if isinstance(data, tuple):
221
+ data = data[0]
222
+
223
+ z_mean, z_log_var, z = self.encoder(data)
224
+ reconstruction = self.decoder(z)
225
+ reconstruction_loss = tf.reduce_mean(
226
+ tf.keras.losses.binary_crossentropy(data, reconstruction)
227
+ )
228
+ reconstruction_loss *= 28 * 28
229
+ kl_loss = 1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var)
230
+ kl_loss = tf.reduce_mean(kl_loss)
231
+ kl_loss *= -0.5
232
+ total_loss = reconstruction_loss + kl_loss
233
+ return {
234
+ "loss": total_loss,
235
+ "reconstruction_loss": reconstruction_loss,
236
+ "kl_loss": kl_loss,
237
+ }
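
Unlike the frame-by-frame variant above, this version behaves as a plain sequence beta-VAE: the encoder maps a (first frame, last frame) pair to a 20-step latent sequence through the posterior LSTM, and the decoder reconstructs all 20 frames at once. A minimal smoke test might look like the sketch below (shapes follow the hard-coded `Input` layers, and it assumes the `Sampling` layer samples with the full tensor shape as noted above):

import numpy as np
import tensorflow as tf

model = P2P(g_dim=128, z_dim=10, rnn_size=256, beta=1e-4)
model.compile(optimizer=tf.keras.optimizers.Adam(1e-3))

x_dummy = np.random.rand(8, 2, 64, 64, 1).astype("float32")   # first and last frame of each clip
y_dummy = np.random.rand(8, 20, 64, 64, 1).astype("float32")  # full 20-frame target clip
model.fit(x_dummy, y_dummy, batch_size=4, epochs=1)            # exercises train_step above
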
ganime/model/vae/vae.py ADDED
@@ -0,0 +1,98 @@
1
+ import numpy as np
2
+ import matplotlib.pyplot as plt
3
+
4
+ from tensorflow import keras
5
+ from tensorflow.keras import layers
6
+ import tensorflow as tf
7
+
8
+ input_shape = (20, 64, 64, 1)
9
+
10
+ class Sampling(keras.layers.Layer):
11
+ """Uses (z_mean, z_log_var) to sample z, the vector encoding a digit."""
12
+
13
+ def call(self, inputs):
14
+ z_mean, z_log_var = inputs
15
+ batch = tf.shape(z_mean)[0]
16
+ dim = z_mean.shape[1:]
17
+ epsilon = tf.keras.backend.random_normal(shape=(batch, *dim))
18
+ return z_mean + tf.exp(0.5 * z_log_var) * epsilon
19
+
20
+ def compute_output_shape(self, input_shape):
21
+ return input_shape[0]
22
+
23
+
24
+ class VAE(keras.Model):
25
+ def __init__(self, latent_dim:int=32, num_embeddings:int=128, beta:float = 0.5, **kwargs):
26
+ super().__init__(**kwargs)
27
+ self.latent_dim = latent_dim
28
+ self.num_embeddings = num_embeddings
29
+ self.beta = beta
30
+
31
+ self.encoder = self.get_encoder()
32
+ self.decoder = self.get_decoder()
33
+
34
+ self.total_loss_tracker = tf.keras.metrics.Mean(name="total_loss")
35
+ self.reconstruction_loss_tracker = tf.keras.metrics.Mean(
36
+ name="reconstruction_loss"
37
+ )
38
+ self.kl_loss_tracker = tf.keras.metrics.Mean(name="kl_loss")
39
+
40
+
41
+ def get_encoder(self):
42
+ encoder_inputs = keras.Input(shape=input_shape)
43
+ x = layers.TimeDistributed(layers.Conv2D(32, 3, activation="relu", strides=2, padding="same"))(
44
+ encoder_inputs
45
+ )
46
+ x = layers.TimeDistributed(layers.Conv2D(64, 3, activation="relu", strides=2, padding="same"))(x)
47
+ x = layers.TimeDistributed(layers.Conv2D(self.latent_dim, 1, padding="same"))(x)
48
+
49
+ x = layers.TimeDistributed(layers.Flatten())(x)
50
+ mu = layers.TimeDistributed(layers.Dense(self.num_embeddings))(x)
51
+ logvar = layers.TimeDistributed(layers.Dense(self.num_embeddings))(x)
52
+ z = Sampling()([mu, logvar])
53
+
54
+ return keras.Model(encoder_inputs, [mu, logvar, z], name="encoder")
55
+
56
+
57
+ def get_decoder(self):
58
+ latent_inputs = keras.Input(shape=self.encoder.output[2].shape[1:])
59
+
60
+ x = layers.TimeDistributed(layers.Dense(16 * 16 * 32, activation="relu"))(latent_inputs)
61
+ x = layers.TimeDistributed(layers.Reshape((16, 16, 32)))(x)
62
+ x = layers.TimeDistributed(layers.Conv2DTranspose(64, 3, activation="relu", strides=2, padding="same"))(
63
+ x
64
+ )
65
+ x = layers.TimeDistributed(layers.Conv2DTranspose(32, 3, activation="relu", strides=2, padding="same"))(x)
66
+ decoder_outputs = layers.TimeDistributed(layers.Conv2DTranspose(1, 3, padding="same"))(x)
67
+ return keras.Model(latent_inputs, decoder_outputs, name="decoder")
68
+
69
+ def train_step(self, data):
70
+ x, y = data
71
+
72
+ with tf.GradientTape() as tape:
73
+ mu, logvar, z = self.encoder(x)
74
+ reconstruction = self.decoder(z)
75
+ reconstruction_loss = tf.reduce_mean(
76
+ tf.reduce_sum(
77
+ tf.keras.losses.binary_crossentropy(y, reconstruction),
78
+ axis=(1, 2),
79
+ )
80
+ )
81
+ kl_loss = -0.5 * (1 + logvar - tf.square(mu) - tf.exp(logvar))
82
+ kl_loss = tf.reduce_mean(tf.reduce_sum(kl_loss, axis=1))
83
+ total_loss = reconstruction_loss + self.beta * kl_loss
84
+ grads = tape.gradient(total_loss, self.trainable_weights)
85
+ self.optimizer.apply_gradients(zip(grads, self.trainable_weights))
86
+ self.total_loss_tracker.update_state(total_loss)
87
+ self.reconstruction_loss_tracker.update_state(reconstruction_loss)
88
+ self.kl_loss_tracker.update_state(kl_loss)
89
+ return {
90
+ "loss": self.total_loss_tracker.result(),
91
+ "reconstruction_loss": self.reconstruction_loss_tracker.result(),
92
+ "kl_loss": self.kl_loss_tracker.result(),
93
+ }
94
+
95
+ def call(self, inputs, training=False, mask=None):
96
+ z_mean, z_log_var, z = self.encoder(inputs)
97
+ pred = self.decoder(z)
98
+ return pred
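
Since the decoder ends in a plain `Conv2DTranspose` with no output activation, sampling new clips after training requires squashing the outputs explicitly. A hedged generation sketch (class and shapes as defined above, training data omitted):

import tensorflow as tf

vae = VAE(latent_dim=32, num_embeddings=128, beta=0.5)
vae.compile(optimizer=tf.keras.optimizers.Adam(1e-3))
# vae.fit(train_x, train_y, ...)   # clips of shape (20, 64, 64, 1)

z = tf.random.normal((1, 20, 128))   # one latent vector per time step, dim = num_embeddings
frames = vae.decoder(z)              # (1, 20, 64, 64, 1), unbounded values
frames = tf.sigmoid(frames)          # map to [0, 1] for visualisation
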
ganime/model/vq_vae/vq_vae.py ADDED
@@ -0,0 +1,143 @@
1
+ import numpy as np
2
+ import matplotlib.pyplot as plt
3
+
4
+ from tensorflow import keras
5
+ from tensorflow.keras import layers
6
+ import tensorflow as tf
7
+
8
+ input_shape = (20, 64, 64, 1)
9
+
10
+ class VectorQuantizer(layers.Layer):
11
+ def __init__(self, num_embeddings, embedding_dim, beta=0.25, **kwargs):
12
+ super().__init__(**kwargs)
13
+ self.embedding_dim = embedding_dim
14
+ self.num_embeddings = num_embeddings
15
+ self.beta = (
16
+ beta # This parameter is best kept between [0.25, 2] as per the paper.
17
+ )
18
+
19
+ # Initialize the embeddings which we will quantize.
20
+ w_init = tf.random_uniform_initializer()
21
+ self.embeddings = tf.Variable(
22
+ initial_value=w_init(
23
+ shape=(self.embedding_dim, self.num_embeddings), dtype="float32"
24
+ ),
25
+ trainable=True,
26
+ name="embeddings_vqvae",
27
+ )
28
+
29
+ def call(self, x):
30
+ # Calculate the input shape of the inputs and
31
+ # then flatten the inputs keeping `embedding_dim` intact.
32
+ input_shape = tf.shape(x)
33
+ flattened = tf.reshape(x, [-1, self.embedding_dim])
34
+
35
+ # Quantization.
36
+ encoding_indices = self.get_code_indices(flattened)
37
+ encodings = tf.one_hot(encoding_indices, self.num_embeddings)
38
+ quantized = tf.matmul(encodings, self.embeddings, transpose_b=True)
39
+ quantized = tf.reshape(quantized, input_shape)
40
+
41
+ # Calculate vector quantization loss and add that to the layer. You can learn more
42
+ # about adding losses to different layers here:
43
+ # https://keras.io/guides/making_new_layers_and_models_via_subclassing/. Check
44
+ # the original paper to get a handle on the formulation of the loss function.
45
+ commitment_loss = self.beta * tf.reduce_mean(
46
+ (tf.stop_gradient(quantized) - x) ** 2
47
+ )
48
+ codebook_loss = tf.reduce_mean((quantized - tf.stop_gradient(x)) ** 2)
49
+ self.add_loss(commitment_loss + codebook_loss)
50
+
51
+ # Straight-through estimator.
52
+ quantized = x + tf.stop_gradient(quantized - x)
53
+ return quantized
54
+
55
+ def get_code_indices(self, flattened_inputs):
56
+ # Calculate L2-normalized distance between the inputs and the codes.
57
+ similarity = tf.matmul(flattened_inputs, self.embeddings)
58
+ distances = (
59
+ tf.reduce_sum(flattened_inputs ** 2, axis=1, keepdims=True)
60
+ + tf.reduce_sum(self.embeddings ** 2, axis=0)
61
+ - 2 * similarity
62
+ )
63
+
64
+ # Derive the indices for minimum distances.
65
+ encoding_indices = tf.argmin(distances, axis=1)
66
+ return encoding_indices
67
+
68
+
69
+ class VQVAE(keras.Model):
70
+ def __init__(self, train_variance:float, latent_dim:int=32, num_embeddings:int=128, **kwargs):
71
+ super().__init__(**kwargs)
72
+ self.train_variance = train_variance
73
+ self.latent_dim = latent_dim
74
+ self.num_embeddings = num_embeddings
75
+
76
+ self.vqvae = self.get_vqvae()
77
+
78
+ self.total_loss_tracker = keras.metrics.Mean(name="total_loss")
79
+ self.reconstruction_loss_tracker = keras.metrics.Mean(
80
+ name="reconstruction_loss"
81
+ )
82
+ self.vq_loss_tracker = keras.metrics.Mean(name="vq_loss")
83
+
84
+
85
+ def get_encoder(self):
86
+ encoder_inputs = keras.Input(shape=input_shape)
87
+ x = layers.TimeDistributed(layers.Conv2D(32, 3, activation="relu", strides=2, padding="same"))(
88
+ encoder_inputs
89
+ )
90
+ x = layers.TimeDistributed(layers.Conv2D(64, 3, activation="relu", strides=2, padding="same"))(x)
91
+ encoder_outputs = layers.TimeDistributed(layers.Conv2D(self.latent_dim, 1, padding="same"))(x)
92
+ return keras.Model(encoder_inputs, encoder_outputs, name="encoder")
93
+
94
+
95
+ def get_decoder(self):
96
+ latent_inputs = keras.Input(shape=self.get_encoder().output.shape[1:])
97
+ x = layers.TimeDistributed(layers.Conv2DTranspose(64, 3, activation="relu", strides=2, padding="same"))(
98
+ latent_inputs
99
+ )
100
+ x = layers.TimeDistributed(layers.Conv2DTranspose(32, 3, activation="relu", strides=2, padding="same"))(x)
101
+ decoder_outputs = layers.TimeDistributed(layers.Conv2DTranspose(1, 3, padding="same"))(x)
102
+ return keras.Model(latent_inputs, decoder_outputs, name="decoder")
103
+
104
+ def get_vqvae(self):
105
+ self.vq_layer = VectorQuantizer(self.num_embeddings, self.latent_dim, name="vector_quantizer")
106
+ self.encoder = self.get_encoder()
107
+ self.decoder = self.get_decoder()
108
+ inputs = keras.Input(shape=input_shape)
109
+ encoder_outputs = self.encoder(inputs)
110
+ quantized_latents = self.vq_layer(encoder_outputs)
111
+ reconstructions = self.decoder(quantized_latents)
112
+ return keras.Model(inputs, reconstructions, name="vq_vae")
113
+
114
+ def train_step(self, data):
115
+ x, y = data
116
+ with tf.GradientTape() as tape:
117
+ # Outputs from the VQ-VAE.
118
+ reconstructions = self.vqvae(x)
119
+
120
+ # Calculate the losses.
121
+ reconstruction_loss = (
122
+ tf.reduce_mean((y - reconstructions) ** 2) / self.train_variance
123
+ )
124
+ total_loss = reconstruction_loss + sum(self.vqvae.losses)
125
+
126
+ # Backpropagation.
127
+ grads = tape.gradient(total_loss, self.vqvae.trainable_variables)
128
+ self.optimizer.apply_gradients(zip(grads, self.vqvae.trainable_variables))
129
+
130
+ # Loss tracking.
131
+ self.total_loss_tracker.update_state(total_loss)
132
+ self.reconstruction_loss_tracker.update_state(reconstruction_loss)
133
+ self.vq_loss_tracker.update_state(sum(self.vqvae.losses))
134
+
135
+ # Log results.
136
+ return {
137
+ "loss": self.total_loss_tracker.result(),
138
+ "reconstruction_loss": self.reconstruction_loss_tracker.result(),
139
+ "vqvae_loss": self.vq_loss_tracker.result(),
140
+ }
141
+
142
+ def call(self, inputs, training=False, mask=None):
143
+ return self.vqvae(inputs)
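
The `get_code_indices` method relies on the expansion ||x - e||^2 = ||x||^2 + ||e||^2 - 2 x·e to avoid materialising all pairwise differences. A tiny NumPy check (illustrative only) that this matches a brute-force nearest-codebook search:

import numpy as np

rng = np.random.default_rng(0)
flat = rng.normal(size=(5, 4))       # 5 flattened vectors, embedding_dim = 4
codebook = rng.normal(size=(4, 8))   # (embedding_dim, num_embeddings)

similarity = flat @ codebook
distances = (flat ** 2).sum(axis=1, keepdims=True) + (codebook ** 2).sum(axis=0) - 2 * similarity
indices = distances.argmin(axis=1)

brute = np.array([((flat[i][:, None] - codebook) ** 2).sum(axis=0).argmin() for i in range(5)])
assert (indices == brute).all()      # same nearest code either way
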
ganime/model/vqgan/__init__.py ADDED
File without changes
ganime/model/vqgan/discriminator/__init__.py ADDED
File without changes
ganime/model/vqgan/discriminator/model.py ADDED
@@ -0,0 +1,64 @@
1
+ from typing import List
2
+ import numpy as np
3
+ import tensorflow as tf
4
+ from tensorflow import keras
5
+ from tensorflow.keras import Model, Sequential
6
+ from tensorflow.keras import layers
7
+
8
+
9
+ class NLayerDiscriminator(Model):
10
+ """Defines a PatchGAN discriminator as in Pix2Pix
11
+ --> see https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py
12
+ """
13
+
14
+ def __init__(self, input_channels: int = 3, filters: int = 64, n_layers: int = 3):
15
+ super().__init__()
16
+
17
+ kernel_size = 4
18
+ self.sequence = [
19
+ layers.Conv2D(filters, kernel_size=kernel_size, padding="same"),
20
+ layers.LeakyReLU(alpha=0.2),
21
+ ]
22
+
23
+ filters_mult = 1
24
+ for n in range(1, n_layers):
25
+ filters_mult = min(2**n, 8)
26
+
27
+ self.sequence += [
28
+ layers.AveragePooling2D(pool_size=2),
29
+ layers.Conv2D(
30
+ filters * filters_mult,
31
+ kernel_size=kernel_size,
32
+ strides=1, # 2,
33
+ padding="same",
34
+ use_bias=False,
35
+ ),
36
+ layers.BatchNormalization(),
37
+ layers.LeakyReLU(alpha=0.2),
38
+ ]
39
+
40
+ filters_mult = min(2**n_layers, 8)
41
+ self.sequence += [
42
+ layers.Conv2D(
43
+ filters * filters_mult,
44
+ kernel_size=kernel_size,
45
+ strides=1,
46
+ padding="same",
47
+ use_bias=False,
48
+ ),
49
+ layers.BatchNormalization(),
50
+ layers.LeakyReLU(alpha=0.2),
51
+ ]
52
+
53
+ self.sequence += [
54
+ layers.Conv2D(1, kernel_size=kernel_size, strides=1, padding="same")
55
+ ]
56
+
57
+ # self.main = Sequential(sequence)
58
+
59
+ def call(self, inputs, training=True, mask=None):
60
+ h = inputs
61
+ for seq in self.sequence:
62
+ h = seq(h)
63
+ return h
64
+ # return self.main(inputs)
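
As a PatchGAN, the discriminator outputs one logit per spatial patch rather than a single score per image; with the two average-pooling stages above, the logit grid is the input resolution divided by four. A quick shape check (illustrative):

import tensorflow as tf

disc = NLayerDiscriminator(input_channels=3, filters=64, n_layers=3)
logits = disc(tf.random.normal((2, 64, 128, 3)))
print(logits.shape)   # expected (2, 16, 32, 1): one real/fake logit per patch
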
ganime/model/vqgan/losses/__init__.py ADDED
File without changes
ganime/model/vqgan/losses/lpips.py ADDED
@@ -0,0 +1,134 @@
1
+ import os
2
+ import numpy as np
3
+ import tensorflow as tf
4
+ import torchvision.models as models
5
+ from tensorflow import keras
6
+ from tensorflow.keras import Model, Sequential
7
+ from tensorflow.keras import backend as K
8
+ from tensorflow.keras import layers
9
+ from tensorflow.keras.applications.vgg16 import VGG16, preprocess_input
10
+ from tensorflow.keras.losses import Loss
11
+ from pyprojroot.pyprojroot import here
12
+
13
+
14
+ def normalize_tensor(x, eps=1e-10):
15
+ norm_factor = tf.sqrt(tf.reduce_sum(x**2, axis=-1, keepdims=True))
16
+ return x / (norm_factor + eps)
17
+
18
+
19
+ class LPIPS(Loss):
20
+ def __init__(self, use_dropout=True, **kwargs):
21
+ super().__init__(**kwargs)
22
+
23
+ self.scaling_layer = ScalingLayer() # preprocess_input
24
+ selected_layers = [
25
+ "block1_conv2",
26
+ "block2_conv2",
27
+ "block3_conv3",
28
+ "block4_conv3",
29
+ "block5_conv3",
30
+ ]
31
+
32
+ # TODO here we load the same weights as pytorch, try with tensorflow weights
33
+ self.net = self.load_vgg16() # VGG16(weights="imagenet", include_top=False)
34
+ self.net.trainable = False
35
+ outputs = [self.net.get_layer(layer).output for layer in selected_layers]
36
+
37
+ self.model = Model(self.net.input, outputs)
38
+ self.lins = [NetLinLayer(use_dropout=use_dropout) for _ in selected_layers]
39
+
40
+ # TODO: here we use the pytorch weights of the linear layers, try without these layers, or without initializing the weights
41
+ self(tf.zeros((1, 16, 16, 1)), tf.zeros((1, 16, 16, 1)))
42
+ self.init_lin_layers()
43
+
44
+ def load_vgg16(self) -> Model:
45
+ """Load a VGG16 model with the same weights as PyTorch
46
+ https://github.com/ezavarygin/vgg16_pytorch2keras
47
+ """
48
+ pytorch_model = models.vgg16(pretrained=True)
49
+ # select weights in the conv2d layers and transpose them to keras dim ordering:
50
+ wblist_torch = list(pytorch_model.parameters())[:26]
51
+ wblist_keras = []
52
+ for i in range(len(wblist_torch)):
53
+ if wblist_torch[i].dim() == 4:
54
+ w = np.transpose(wblist_torch[i].detach().numpy(), axes=[2, 3, 1, 0])
55
+ wblist_keras.append(w)
56
+ elif wblist_torch[i].dim() == 1:
57
+ b = wblist_torch[i].detach().numpy()
58
+ wblist_keras.append(b)
59
+ else:
60
+ raise Exception("Fully connected layers are not implemented.")
61
+
62
+ keras_model = VGG16(include_top=False, weights=None)
63
+ keras_model.set_weights(wblist_keras)
64
+ return keras_model
65
+
66
+ def init_lin_layers(self):
67
+ for i in range(5):
68
+ weights = np.load(
69
+ os.path.join(here(), "models", "NetLinLayer", f"numpy_{i}.npy")
70
+ )
71
+ weights = np.moveaxis(weights, 1, 2)
72
+ self.lins[i].model.layers[1].set_weights([weights])
73
+
74
+ def call(self, y_true, y_pred):
75
+
76
+ scaled_true = self.scaling_layer(y_true)
77
+ scaled_pred = self.scaling_layer(y_pred)
78
+
79
+ outputs_true, outputs_pred = self.model(scaled_true), self.model(scaled_pred)
80
+ features_true, features_pred, diffs = {}, {}, {}
81
+
82
+ for kk in range(len(outputs_true)):
83
+ features_true[kk], features_pred[kk] = normalize_tensor(
84
+ outputs_true[kk]
85
+ ), normalize_tensor(outputs_pred[kk])
86
+
87
+ diffs[kk] = (features_true[kk] - features_pred[kk]) ** 2
88
+
89
+ res = [
90
+ tf.reduce_mean(self.lins[kk](diffs[kk]), axis=(-3, -2), keepdims=True)
91
+ for kk in range(len(outputs_true))
92
+ ]
93
+
94
+ return tf.reduce_sum(res)
95
+
96
+ # h1_list = self.model(self.scaling_layer(y_true))
97
+ # h2_list = self.model(self.scaling_layer(y_pred))
98
+
99
+ # rc_loss = 0.0
100
+ # for h1, h2 in zip(h1_list, h2_list):
101
+ # h1 = K.batch_flatten(h1)
102
+ # h2 = K.batch_flatten(h2)
103
+ # rc_loss += K.sum(K.square(h1 - h2), axis=-1)
104
+
105
+ # return rc_loss
106
+
107
+
108
+ class ScalingLayer(layers.Layer):
109
+ def __init__(self, **kwargs):
110
+ super().__init__(**kwargs)
111
+ self.shift = tf.Variable([-0.030, -0.088, -0.188])
112
+ self.scale = tf.Variable([0.458, 0.448, 0.450])
113
+
114
+ def call(self, inputs):
115
+ return (inputs - self.shift) / self.scale
116
+
117
+
118
+ class NetLinLayer(layers.Layer):
119
+ def __init__(self, channels_out=1, use_dropout=False):
120
+ super().__init__()
121
+ sequence = (
122
+ [
123
+ layers.Dropout(0.5),
124
+ ]
125
+ if use_dropout
126
+ else []
127
+ )
128
+ sequence += [
129
+ layers.Conv2D(channels_out, 1, padding="same", use_bias=False),
130
+ ]
131
+ self.model = Sequential(sequence)
132
+
133
+ def call(self, inputs):
134
+ return self.model(inputs)
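
Once the PyTorch VGG16 weights and the bundled `NetLinLayer` arrays are available, the loss can be used as a plain perceptual distance; the `ScalingLayer` statistics suggest inputs are expected roughly in [-1, 1]. A hedged usage sketch:

import tensorflow as tf

lpips = LPIPS(reduction=tf.keras.losses.Reduction.NONE)
a = tf.random.uniform((1, 64, 64, 3), minval=-1.0, maxval=1.0)
b = tf.random.uniform((1, 64, 64, 3), minval=-1.0, maxval=1.0)
print(float(lpips(a, b)))   # larger value => more perceptually different
print(float(lpips(a, a)))   # identical inputs score (close to) zero
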
ganime/model/vqgan/losses/vqperceptual.py ADDED
@@ -0,0 +1,47 @@
1
+ from typing import List, Literal
2
+
3
+ import numpy as np
4
+ import tensorflow as tf
5
+ from tensorflow import keras
6
+ from tensorflow.keras import Model, layers
7
+ from tensorflow.keras.losses import Loss
8
+
9
+ from .lpips import LPIPS
10
+
11
+ from ..discriminator.model import NLayerDiscriminator
12
+
13
+
14
+ class VQLPIPSWithDiscriminator(Loss):
15
+ def __init__(
16
+ self, *, pixelloss_weight: float = 1.0, perceptual_weight: float = 1.0, **kwargs
17
+ ):
18
+ super().__init__(**kwargs)
19
+ self.pixelloss_weight = pixelloss_weight
20
+ self.perceptual_loss = LPIPS(reduction=tf.keras.losses.Reduction.NONE)
21
+ self.perceptual_weight = perceptual_weight
22
+
23
+ def call(
24
+ self,
25
+ y_true,
26
+ y_pred,
27
+ ):
28
+ reconstruction_loss = tf.abs(y_true - y_pred)
29
+ if self.perceptual_weight > 0:
30
+ perceptual_loss = self.perceptual_loss(y_true, y_pred)
31
+ reconstruction_loss += self.perceptual_weight * perceptual_loss
32
+ else:
33
+ perceptual_loss = 0.0
34
+
35
+ neg_log_likelihood = tf.reduce_mean(reconstruction_loss)
36
+
37
+ return neg_log_likelihood
38
+
39
+ # # GAN part
40
+ # if optimizer_idx == 0:
41
+ # if cond is None:
42
+ # assert not self.disc_conditional
43
+ # logits_fake = self.discriminator(y_pred)
44
+ # else:
45
+ # assert self.disc_conditional
46
+ # logits_fake = self.discriminator(tf.concat([y_pred, cond], axis=-1))
47
+ # g_loss = -tf.reduce_mean(logits_fake)
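
In its current form the loss reduces to mean(|y_true - y_pred| + perceptual_weight * LPIPS(y_true, y_pred)); the commented-out block sketches the adversarial generator term that the VQGAN model below computes itself. A hedged usage sketch (it builds an LPIPS instance, so the same weight files are required):

import tensorflow as tf

loss_fn = VQLPIPSWithDiscriminator(reduction=tf.keras.losses.Reduction.NONE)
y_true = tf.random.uniform((2, 64, 128, 3))
y_pred = tf.random.uniform((2, 64, 128, 3))
nll = loss_fn(y_true, y_pred)   # scalar pixel + perceptual reconstruction term
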
ganime/model/vqgan/vqgan.py ADDED
@@ -0,0 +1,722 @@
1
+ from typing import List, Literal
2
+
3
+ import numpy as np
4
+ import tensorflow as tf
5
+ from .discriminator.model import NLayerDiscriminator
6
+ from .losses.vqperceptual import VQLPIPSWithDiscriminator
7
+ from tensorflow import keras
8
+ from tensorflow.keras import Model, layers, Sequential
9
+ from tensorflow.keras.optimizers import Optimizer
10
+ from tensorflow_addons.layers import GroupNormalization
11
+
12
+ INPUT_SHAPE = (64, 128, 3)
13
+ ENCODER_OUTPUT_SHAPE = (8, 8, 128)
14
+
15
+
16
+ @tf.function
17
+ def hinge_d_loss(logits_real, logits_fake):
18
+ loss_real = tf.reduce_mean(keras.activations.relu(1.0 - logits_real))
19
+ loss_fake = tf.reduce_mean(keras.activations.relu(1.0 + logits_fake))
20
+ d_loss = 0.5 * (loss_real + loss_fake)
21
+ return d_loss
22
+
23
+
24
+ @tf.function
25
+ def vanilla_d_loss(logits_real, logits_fake):
26
+ d_loss = 0.5 * (
27
+ tf.reduce_mean(keras.activations.softplus(-logits_real))
28
+ + tf.reduce_mean(keras.activations.softplus(logits_fake))
29
+ )
30
+ return d_loss
31
+
32
+
33
+ class VQGAN(keras.Model):
34
+ def __init__(
35
+ self,
36
+ train_variance: float,
37
+ num_embeddings: int,
38
+ embedding_dim: int,
39
+ beta: float = 0.25,
40
+ z_channels: int = 128, # 256,
41
+ codebook_weight: float = 1.0,
42
+ disc_num_layers: int = 3,
43
+ disc_factor: float = 1.0,
44
+ disc_iter_start: int = 0,
45
+ disc_conditional: bool = False,
46
+ disc_in_channels: int = 3,
47
+ disc_weight: float = 0.3,
48
+ disc_filters: int = 64,
49
+ disc_loss: Literal["hinge", "vanilla"] = "hinge",
50
+ **kwargs,
51
+ ):
52
+ super().__init__(**kwargs)
53
+ self.train_variance = train_variance
54
+ self.codebook_weight = codebook_weight
55
+
56
+ self.encoder = Encoder()
57
+ self.decoder = Decoder()
58
+ self.quantize = VectorQuantizer(num_embeddings, embedding_dim, beta=beta)
59
+
60
+ self.quant_conv = layers.Conv2D(embedding_dim, kernel_size=1)
61
+ self.post_quant_conv = layers.Conv2D(z_channels, kernel_size=1)
62
+
63
+ self.vqvae = self.get_vqvae()
64
+
65
+ self.perceptual_loss = VQLPIPSWithDiscriminator(
66
+ reduction=tf.keras.losses.Reduction.NONE
67
+ )
68
+
69
+ self.discriminator = NLayerDiscriminator(
70
+ input_channels=disc_in_channels,
71
+ filters=disc_filters,
72
+ n_layers=disc_num_layers,
73
+ )
74
+ self.discriminator_iter_start = disc_iter_start
75
+
76
+ if disc_loss == "hinge":
77
+ self.disc_loss = hinge_d_loss
78
+ elif disc_loss == "vanilla":
79
+ self.disc_loss = vanilla_d_loss
80
+ else:
81
+ raise ValueError(f"Unknown GAN loss '{disc_loss}'.")
82
+
83
+ print(f"VQLPIPSWithDiscriminator running with {disc_loss} loss.")
84
+ self.disc_factor = disc_factor
85
+ self.discriminator_weight = disc_weight
86
+ self.disc_conditional = disc_conditional
87
+
88
+ self.total_loss_tracker = keras.metrics.Mean(name="total_loss")
89
+ self.reconstruction_loss_tracker = keras.metrics.Mean(
90
+ name="reconstruction_loss"
91
+ )
92
+ self.vq_loss_tracker = keras.metrics.Mean(name="vq_loss")
93
+ self.disc_loss_tracker = keras.metrics.Mean(name="disc_loss")
94
+
95
+ self.gen_optimizer: Optimizer = None
96
+ self.disc_optimizer: Optimizer = None
97
+
98
+ def get_vqvae(self):
99
+ inputs = keras.Input(shape=INPUT_SHAPE)
100
+ quant = self.encode(inputs)
101
+ reconstructed = self.decode(quant)
102
+ return keras.Model(inputs, reconstructed, name="vq_vae")
103
+
104
+ def encode(self, x):
105
+ h = self.encoder(x)
106
+ h = self.quant_conv(h)
107
+ return self.quantize(h)
108
+
109
+ def decode(self, quant):
110
+ quant = self.post_quant_conv(quant)
111
+ dec = self.decoder(quant)
112
+ return dec
113
+
114
+ def call(self, inputs, training=True, mask=None):
115
+ return self.vqvae(inputs)
116
+
117
+ def calculate_adaptive_weight(
118
+ self, nll_loss, g_loss, tape, trainable_vars, discriminator_weight
119
+ ):
120
+ nll_grads = tape.gradient(nll_loss, trainable_vars)[0]
121
+ g_grads = tape.gradient(g_loss, trainable_vars)[0]
122
+
123
+ d_weight = tf.norm(nll_grads) / (tf.norm(g_grads) + 1e-4)
124
+ d_weight = tf.stop_gradient(tf.clip_by_value(d_weight, 0.0, 1e4))
125
+ return d_weight * discriminator_weight
126
+
127
+ @tf.function
128
+ def adopt_weight(self, weight, global_step, threshold=0, value=0.0):
129
+ if global_step < threshold:
130
+ weight = value
131
+ return weight
132
+
133
+ def get_global_step(self, optimizer):
134
+ return optimizer.iterations
135
+
136
+ def compile(
137
+ self,
138
+ gen_optimizer,
139
+ disc_optimizer,
140
+ ):
141
+ super().compile()
142
+ self.gen_optimizer = gen_optimizer
143
+ self.disc_optimizer = disc_optimizer
144
+
145
+ def train_step(self, data):
146
+ x, y = data
147
+
148
+ # Autoencode
149
+ with tf.GradientTape() as tape:
150
+ with tf.GradientTape(persistent=True) as adaptive_tape:
151
+ reconstructions = self(x, training=True)
152
+
153
+ # Calculate the losses.
154
+ # reconstruction_loss = (
155
+ # tf.reduce_mean((y - reconstructions) ** 2) / self.train_variance
156
+ # )
157
+
158
+ logits_fake = self.discriminator(reconstructions, training=False)
159
+
160
+ g_loss = -tf.reduce_mean(logits_fake)
161
+ nll_loss = self.perceptual_loss(y, reconstructions)
162
+
163
+ d_weight = self.calculate_adaptive_weight(
164
+ nll_loss,
165
+ g_loss,
166
+ adaptive_tape,
167
+ self.decoder.conv_out.trainable_variables,
168
+ self.discriminator_weight,
169
+ )
170
+ del adaptive_tape
171
+
172
+ disc_factor = self.adopt_weight(
173
+ weight=self.disc_factor,
174
+ global_step=self.get_global_step(self.gen_optimizer),
175
+ threshold=self.discriminator_iter_start,
176
+ )
177
+
178
+ # total_loss = reconstruction_loss + sum(self.vqvae.losses)
179
+ total_loss = (
180
+ nll_loss
181
+ + d_weight * disc_factor * g_loss
182
+ # + self.codebook_weight * tf.reduce_mean(self.vqvae.losses)
183
+ + self.codebook_weight * sum(self.vqvae.losses)
184
+ )
185
+
186
+ # Backpropagation.
187
+ grads = tape.gradient(total_loss, self.vqvae.trainable_variables)
188
+ self.gen_optimizer.apply_gradients(zip(grads, self.vqvae.trainable_variables))
189
+
190
+ # Discriminator
191
+ with tf.GradientTape() as disc_tape:
192
+ logits_real = self.discriminator(y, training=True)
193
+ logits_fake = self.discriminator(reconstructions, training=True)
194
+
195
+ disc_factor = self.adopt_weight(
196
+ weight=self.disc_factor,
197
+ global_step=self.get_global_step(self.disc_optimizer),
198
+ threshold=self.discriminator_iter_start,
199
+ )
200
+ d_loss = disc_factor * self.disc_loss(logits_real, logits_fake)
201
+
202
+ disc_grads = disc_tape.gradient(d_loss, self.discriminator.trainable_variables)
203
+ self.disc_optimizer.apply_gradients(
204
+ zip(disc_grads, self.discriminator.trainable_variables)
205
+ )
206
+
207
+ # Loss tracking.
208
+ self.total_loss_tracker.update_state(total_loss)
209
+ self.reconstruction_loss_tracker.update_state(nll_loss)
210
+ self.vq_loss_tracker.update_state(sum(self.vqvae.losses))
211
+ self.disc_loss_tracker.update_state(d_loss)
212
+
213
+ # Log results.
214
+ return {
215
+ "loss": self.total_loss_tracker.result(),
216
+ "reconstruction_loss": self.reconstruction_loss_tracker.result(),
217
+ "vqvae_loss": self.vq_loss_tracker.result(),
218
+ "disc_loss": self.disc_loss_tracker.result(),
219
+ }
220
+
221
+
222
+ class VectorQuantizer(layers.Layer):
223
+ def __init__(self, num_embeddings, embedding_dim, beta=0.25, **kwargs):
224
+ super().__init__(**kwargs)
225
+ self.embedding_dim = embedding_dim
226
+ self.num_embeddings = num_embeddings
227
+ self.beta = (
228
+ beta # This parameter is best kept between [0.25, 2] as per the paper.
229
+ )
230
+
231
+ # Initialize the embeddings which we will quantize.
232
+ w_init = tf.random_uniform_initializer()
233
+ self.embeddings = tf.Variable(
234
+ initial_value=w_init(
235
+ shape=(self.embedding_dim, self.num_embeddings) # , dtype="float32"
236
+ ),
237
+ trainable=True,
238
+ name="embeddings_vqvae",
239
+ )
240
+
241
+ def call(self, x):
242
+ # Calculate the input shape of the inputs and
243
+ # then flatten the inputs keeping `embedding_dim` intact.
244
+ input_shape = tf.shape(x)
245
+ flattened = tf.reshape(x, [-1, self.embedding_dim])
246
+
247
+ # Quantization.
248
+ encoding_indices = self.get_code_indices(flattened)
249
+ encodings = tf.one_hot(encoding_indices, self.num_embeddings)
250
+ quantized = tf.matmul(encodings, self.embeddings, transpose_b=True)
251
+ quantized = tf.reshape(quantized, input_shape)
252
+
253
+ # Calculate vector quantization loss and add that to the layer. You can learn more
254
+ # about adding losses to different layers here:
255
+ # https://keras.io/guides/making_new_layers_and_models_via_subclassing/. Check
256
+ # the original paper to get a handle on the formulation of the loss function.
257
+ commitment_loss = self.beta * tf.reduce_mean(
258
+ (tf.stop_gradient(quantized) - x) ** 2
259
+ )
260
+ codebook_loss = tf.reduce_mean((quantized - tf.stop_gradient(x)) ** 2)
261
+ self.add_loss(commitment_loss + codebook_loss)
262
+
263
+ # Straight-through estimator.
264
+ quantized = x + tf.stop_gradient(quantized - x)
265
+ return quantized
266
+
267
+ def get_code_indices(self, flattened_inputs):
268
+ # Calculate L2-normalized distance between the inputs and the codes.
269
+ similarity = tf.matmul(flattened_inputs, self.embeddings)
270
+ distances = (
271
+ tf.reduce_sum(flattened_inputs**2, axis=1, keepdims=True)
272
+ + tf.reduce_sum(self.embeddings**2, axis=0)
273
+ - 2 * similarity
274
+ )
275
+
276
+ # Derive the indices for minimum distances.
277
+ encoding_indices = tf.argmin(distances, axis=1)
278
+ return encoding_indices
279
+
280
+
281
+ class Encoder(Model):
282
+ def __init__(
283
+ self,
284
+ *,
285
+ channels: int = 128,
286
+ output_channels: int = 3,
287
+ channels_multiplier: List[int] = [1, 1, 2, 2], # [1, 1, 2, 2, 4],
288
+ num_res_blocks: int = 1, # 2,
289
+ attention_resolution: List[int] = [16],
290
+ resolution: int = 64, # 256,
291
+ z_channels=128, # 256,
292
+ dropout=0.0,
293
+ double_z=False,
294
+ resamp_with_conv=True,
295
+ ):
296
+ super().__init__()
297
+
298
+ self.channels = channels
299
+ self.timestep_embeddings_channel = 0
300
+ self.num_resolutions = len(channels_multiplier)
301
+ self.num_res_blocks = num_res_blocks
302
+ self.resolution = resolution
303
+
304
+ self.conv_in = layers.Conv2D(
305
+ self.channels, kernel_size=3, strides=1, padding="same"
306
+ )
307
+
308
+ current_resolution = resolution
309
+
310
+ in_channels_multiplier = (1,) + tuple(channels_multiplier)
311
+
312
+ self.downsampling_list = []
313
+
314
+ for i_level in range(self.num_resolutions):
315
+ block_in = channels * in_channels_multiplier[i_level]
316
+ block_out = channels * channels_multiplier[i_level]
317
+ for i_block in range(self.num_res_blocks):
318
+ self.downsampling_list.append(
319
+ ResnetBlock(
320
+ in_channels=block_in,
321
+ out_channels=block_out,
322
+ timestep_embedding_channels=self.timestep_embeddings_channel,
323
+ dropout=dropout,
324
+ )
325
+ )
326
+ block_in = block_out
327
+
328
+ if current_resolution in attention_resolution:
329
+ # attentions.append(layers.Attention())
330
+ self.downsampling_list.append(AttentionBlock(block_in))
331
+
332
+ if i_level != self.num_resolutions - 1:
333
+ self.downsampling_list.append(Downsample(block_in, resamp_with_conv))
334
+
335
+ # self.downsampling = []
336
+
337
+ # for i_level in range(self.num_resolutions):
338
+ # block = []
339
+ # attentions = []
340
+ # block_in = channels * in_channels_multiplier[i_level]
341
+ # block_out = channels * channels_multiplier[i_level]
342
+ # for i_block in range(self.num_res_blocks):
343
+ # block.append(
344
+ # ResnetBlock(
345
+ # in_channels=block_in,
346
+ # out_channels=block_out,
347
+ # timestep_embedding_channels=self.timestep_embeddings_channel,
348
+ # dropout=dropout,
349
+ # )
350
+ # )
351
+ # block_in = block_out
352
+
353
+ # if current_resolution in attention_resolution:
354
+ # # attentions.append(layers.Attention())
355
+ # attentions.append(AttentionBlock(block_in))
356
+
357
+ # down = {}
358
+ # down["block"] = block
359
+ # down["attention"] = attentions
360
+ # if i_level != self.num_resolutions - 1:
361
+ # down["downsample"] = Downsample(block_in, resamp_with_conv)
362
+ # self.downsampling.append(down)
363
+
364
+ # middle
365
+ self.mid = {}
366
+ self.mid["block_1"] = ResnetBlock(
367
+ in_channels=block_in,
368
+ out_channels=block_in,
369
+ timestep_embedding_channels=self.timestep_embeddings_channel,
370
+ dropout=dropout,
371
+ )
372
+ self.mid["attn_1"] = AttentionBlock(block_in)
373
+ self.mid["block_2"] = ResnetBlock(
374
+ in_channels=block_in,
375
+ out_channels=block_in,
376
+ timestep_embedding_channels=self.timestep_embeddings_channel,
377
+ dropout=dropout,
378
+ )
379
+
380
+ # end
381
+ self.norm_out = GroupNormalization(groups=32, epsilon=1e-6)
382
+ self.conv_out = layers.Conv2D(
383
+ 2 * z_channels if double_z else z_channels,
384
+ kernel_size=3,
385
+ strides=1,
386
+ padding="same",
387
+ )
388
+
389
+ def summary(self):
390
+ x = layers.Input(shape=INPUT_SHAPE)
391
+ model = Model(inputs=[x], outputs=self.call(x))
392
+ return model.summary()
393
+
394
+ def call(self, inputs, training=True, mask=None):
395
+ h = self.conv_in(inputs)
396
+ for downsampling in self.downsampling_list:
397
+ h = downsampling(h)
398
+ # for i_level in range(self.num_resolutions):
399
+ # for i_block in range(self.num_res_blocks):
400
+ # h = self.downsampling[i_level]["block"][i_block](hs[-1])
401
+ # if len(self.downsampling[i_level]["attention"]) > 0:
402
+ # h = self.downsampling[i_level]["attention"][i_block](h)
403
+ # hs.append(h)
404
+ # if i_level != self.num_resolutions - 1:
405
+ # hs.append(self.downsampling[i_level]["downsample"](hs[-1]))
406
+
407
+ # h = hs[-1]
408
+ h = self.mid["block_1"](h)
409
+ h = self.mid["attn_1"](h)
410
+ h = self.mid["block_2"](h)
411
+
412
+ # end
413
+ h = self.norm_out(h)
414
+ h = keras.activations.swish(h)
415
+ h = self.conv_out(h)
416
+ return h
417
+
418
+
419
+ class Decoder(Model):
420
+ def __init__(
421
+ self,
422
+ *,
423
+ channels: int = 128,
424
+ output_channels: int = 3,
425
+ channels_multiplier: List[int] = [1, 1, 2, 2], # [1, 1, 2, 2, 4],
426
+ num_res_blocks: int = 1, # 2,
427
+ attention_resolution: List[int] = [16],
428
+ resolution: int = 64, # 256,
429
+ z_channels=128, # 256,
430
+ dropout=0.0,
431
+ give_pre_end=False,
432
+ resamp_with_conv=True,
433
+ ):
434
+ super().__init__()
435
+
436
+ self.channels = channels
437
+ self.timestep_embeddings_channel = 0
438
+ self.num_resolutions = len(channels_multiplier)
439
+ self.num_res_blocks = num_res_blocks
440
+ self.resolution = resolution
441
+ self.give_pre_end = give_pre_end
442
+
443
+ in_channels_multiplier = (1,) + tuple(channels_multiplier)
444
+ block_in = channels * channels_multiplier[-1]
445
+ current_resolution = resolution // 2 ** (self.num_resolutions - 1)
446
+ self.z_shape = (1, z_channels, current_resolution, current_resolution)
447
+
448
+ print(
449
+ "Working with z of shape {} = {} dimensions.".format(
450
+ self.z_shape, np.prod(self.z_shape)
451
+ )
452
+ )
453
+
454
+ self.conv_in = layers.Conv2D(block_in, kernel_size=3, strides=1, padding="same")
455
+
456
+ # middle
457
+ self.mid = {}
458
+ self.mid["block_1"] = ResnetBlock(
459
+ in_channels=block_in,
460
+ out_channels=block_in,
461
+ timestep_embedding_channels=self.timestep_embeddings_channel,
462
+ dropout=dropout,
463
+ )
464
+ self.mid["attn_1"] = AttentionBlock(block_in)
465
+ self.mid["block_2"] = ResnetBlock(
466
+ in_channels=block_in,
467
+ out_channels=block_in,
468
+ timestep_embedding_channels=self.timestep_embeddings_channel,
469
+ dropout=dropout,
470
+ )
471
+
472
+ # upsampling
473
+
474
+ self.upsampling_list = []
475
+
476
+ for i_level in reversed(range(self.num_resolutions)):
477
+ block_out = channels * channels_multiplier[i_level]
478
+ for i_block in range(self.num_res_blocks + 1):
479
+ self.upsampling_list.append(
480
+ ResnetBlock(
481
+ in_channels=block_in,
482
+ out_channels=block_out,
483
+ timestep_embedding_channels=self.timestep_embeddings_channel,
484
+ dropout=dropout,
485
+ )
486
+ )
487
+ block_in = block_out
488
+
489
+ if current_resolution in attention_resolution:
490
+ # attentions.append(layers.Attention())
491
+ self.upsampling_list.append(AttentionBlock(block_in))
492
+
493
+ if i_level != 0:
494
+ self.upsampling_list.append(Upsample(block_in, resamp_with_conv))
495
+ current_resolution *= 2
496
+ # self.upsampling.insert(0, upsampling)
497
+
498
+ # self.upsampling = []
499
+
500
+ # for i_level in reversed(range(self.num_resolutions)):
501
+ # block = []
502
+ # attentions = []
503
+ # block_out = channels * channels_multiplier[i_level]
504
+ # for i_block in range(self.num_res_blocks + 1):
505
+ # block.append(
506
+ # ResnetBlock(
507
+ # in_channels=block_in,
508
+ # out_channels=block_out,
509
+ # timestep_embedding_channels=self.timestep_embeddings_channel,
510
+ # dropout=dropout,
511
+ # )
512
+ # )
513
+ # block_in = block_out
514
+
515
+ # if current_resolution in attention_resolution:
516
+ # # attentions.append(layers.Attention())
517
+ # attentions.append(AttentionBlock(block_in))
518
+
519
+ # upsampling = {}
520
+ # upsampling["block"] = block
521
+ # upsampling["attention"] = attentions
522
+ # if i_level != 0:
523
+ # upsampling["upsample"] = Upsample(block_in, resamp_with_conv)
524
+ # current_resolution *= 2
525
+ # self.upsampling.insert(0, upsampling)
526
+
527
+ # end
528
+ self.norm_out = GroupNormalization(groups=32, epsilon=1e-6)
529
+ self.conv_out = layers.Conv2D(
530
+ output_channels,
531
+ kernel_size=3,
532
+ strides=1,
533
+ activation="sigmoid",
534
+ padding="same",
535
+ )
536
+
537
+ def summary(self):
538
+ x = layers.Input(shape=ENCODER_OUTPUT_SHAPE)
539
+ model = Model(inputs=[x], outputs=self.call(x))
540
+ return model.summary()
541
+
542
+ def call(self, inputs, training=True, mask=None):
543
+
544
+ h = self.conv_in(inputs)
545
+
546
+ # middle
547
+ h = self.mid["block_1"](h)
548
+ h = self.mid["attn_1"](h)
549
+ h = self.mid["block_2"](h)
550
+
551
+ for upsampling in self.upsampling_list:
552
+ h = upsampling(h)
553
+
554
+ # for i_level in reversed(range(self.num_resolutions)):
555
+ # for i_block in range(self.num_res_blocks + 1):
556
+ # h = self.upsampling[i_level]["block"][i_block](h)
557
+ # if len(self.upsampling[i_level]["attention"]) > 0:
558
+ # h = self.upsampling[i_level]["attention"][i_block](h)
559
+ # if i_level != 0:
560
+ # h = self.upsampling[i_level]["upsample"](h)
561
+
562
+ # end
563
+ if self.give_pre_end:
564
+ return h
565
+
566
+ h = self.norm_out(h)
567
+ h = keras.activations.swish(h)
568
+ h = self.conv_out(h)
569
+ return h
570
+
571
+
572
+ class ResnetBlock(layers.Layer):
573
+ def __init__(
574
+ self,
575
+ *,
576
+ in_channels,
577
+ dropout=0.0,
578
+ out_channels=None,
579
+ conv_shortcut=False,
580
+ timestep_embedding_channels=512,
581
+ ):
582
+ super().__init__()
583
+ self.in_channels = in_channels
584
+ out_channels = in_channels if out_channels is None else out_channels
585
+ self.out_channels = out_channels
586
+ self.use_conv_shortcut = conv_shortcut
587
+
588
+ self.norm1 = GroupNormalization(groups=32, epsilon=1e-6)
589
+
590
+ self.conv1 = layers.Conv2D(
591
+ out_channels, kernel_size=3, strides=1, padding="same"
592
+ )
593
+
594
+ if timestep_embedding_channels > 0:
595
+ self.timestep_embedding_projection = layers.Dense(out_channels)
596
+
597
+ self.norm2 = GroupNormalization(groups=32, epsilon=1e-6)
598
+ self.dropout = layers.Dropout(dropout)
599
+
600
+ self.conv2 = layers.Conv2D(
601
+ out_channels, kernel_size=3, strides=1, padding="same"
602
+ )
603
+
604
+ if self.in_channels != self.out_channels:
605
+ if self.use_conv_shortcut:
606
+ self.conv_shortcut = layers.Conv2D(
607
+ out_channels, kernel_size=3, strides=1, padding="same"
608
+ )
609
+ else:
610
+ self.nin_shortcut = layers.Conv2D(
611
+ out_channels, kernel_size=1, strides=1, padding="valid"
612
+ )
613
+
614
+ def call(self, x):
615
+ h = x
616
+ h = self.norm1(h)
617
+ h = keras.activations.swish(h)
618
+ h = self.conv1(h)
619
+
620
+ # if timestamp_embedding is not None:
621
+ # h = h + self.timestep_embedding_projection(keras.activations.swish(timestamp_embedding))
622
+
623
+ h = self.norm2(h)
624
+ h = keras.activations.swish(h)
625
+ h = self.dropout(h)
626
+ h = self.conv2(h)
627
+
628
+ if self.in_channels != self.out_channels:
629
+ if self.use_conv_shortcut:
630
+ x = self.conv_shortcut(x)
631
+ else:
632
+ x = self.nin_shortcut(x)
633
+
634
+ return x + h
635
+
636
+
637
+ class AttentionBlock(layers.Layer):
638
+ def __init__(self, channels):
639
+ super().__init__()
640
+
641
+ self.norm = GroupNormalization(groups=32, epsilon=1e-6)
642
+ self.q = layers.Conv2D(channels, kernel_size=1, strides=1, padding="valid")
643
+ self.k = layers.Conv2D(channels, kernel_size=1, strides=1, padding="valid")
644
+ self.v = layers.Conv2D(channels, kernel_size=1, strides=1, padding="valid")
645
+ self.proj_out = layers.Conv2D(
646
+ channels, kernel_size=1, strides=1, padding="valid"
647
+ )
648
+
649
+ def call(self, x):
650
+ h_ = x
651
+ h_ = self.norm(h_)
652
+ q = self.q(h_)
653
+ k = self.k(h_)
654
+ v = self.v(h_)
655
+
656
+ # compute attention
657
+ (
658
+ b,
659
+ h,
660
+ w,
661
+ c,
662
+ ) = q.shape
663
+ if b is None:
664
+ b = -1
665
+ q = tf.reshape(q, [b, h * w, c])
666
+ k = tf.reshape(k, [b, h * w, c])
667
+ w_ = tf.matmul(
668
+ q, k, transpose_b=True
669
+ ) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
670
+ w_ = w_ * (int(c) ** (-0.5))
671
+ w_ = keras.activations.softmax(w_)
672
+
673
+ # attend to values
674
+ v = tf.reshape(v, [b, h * w, c])
675
+ # w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q)
676
+ h_ = tf.matmul(
677
+ v, w_, transpose_a=True
678
+ ) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
679
+ # h_ = h_.reshape(b, c, h, w)
680
+ h_ = tf.reshape(h_, [b, h, w, c])
681
+
682
+ h_ = self.proj_out(h_)
683
+
684
+ return x + h_
685
+
686
+
687
+ class Downsample(layers.Layer):
688
+ def __init__(self, channels, with_conv=True):
689
+ super().__init__()
690
+ self.with_conv = with_conv
691
+ if self.with_conv:
692
+ # no asymmetric padding in torch conv, must do it ourselves
693
+ self.down_sample = layers.Conv2D(
694
+ channels, kernel_size=3, strides=2, padding="same"
695
+ )
696
+ else:
697
+ self.down_sample = layers.AveragePooling2D(pool_size=2, strides=2)
698
+
699
+ def call(self, x):
700
+ x = self.down_sample(x)
701
+ return x
702
+
703
+
704
+ class Upsample(layers.Layer):
705
+ def __init__(self, channels, with_conv=False):
706
+ super().__init__()
707
+ self.with_conv = with_conv
708
+ if False:  # Conv2DTranspose path disabled for now; self.with_conv is currently ignored
709
+ self.up_sample = layers.Conv2DTranspose(
710
+ channels, kernel_size=3, strides=2, padding="same"
711
+ )
712
+ else:
713
+ self.up_sample = Sequential(
714
+ [
715
+ layers.UpSampling2D(size=2, interpolation="nearest"),
716
+ layers.Conv2D(channels, kernel_size=3, strides=1, padding="same"),
717
+ ]
718
+ )
719
+
720
+ def call(self, x):
721
+ x = self.up_sample(x)
722
+ return x
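A minimal shape sketch for the ResnetBlock defined above, assuming it is importable from this module; the sizes are illustrative rather than taken from the project's configs. When in_channels differs from out_channels and conv_shortcut is False, the 1x1 nin_shortcut projects the input so the residual add lines up:

import tensorflow as tf

# Illustrative only: assumes the ResnetBlock defined above is in scope.
block = ResnetBlock(in_channels=64, out_channels=128, dropout=0.1)
x = tf.random.normal((2, 32, 32, 64))
y = block(x)        # nin_shortcut projects x to 128 channels before the skip add
print(y.shape)      # (2, 32, 32, 128)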
ganime/model/vqgan_clean/__init__.py ADDED
File without changes
ganime/model/vqgan_clean/diffusion/__init__.py ADDED
File without changes
ganime/model/vqgan_clean/diffusion/decoder.py ADDED
@@ -0,0 +1,115 @@
1
+ from typing import List
2
+
3
+ import numpy as np
4
+ import tensorflow as tf
5
+ from tensorflow import keras
6
+ from tensorflow.keras import Model, layers
7
+ from tensorflow_addons.layers import GroupNormalization
8
+
9
+ from .layers import AttentionBlock, ResnetBlock, Upsample
10
+
11
+
12
+ # @tf.keras.utils.register_keras_serializable()
13
+ class Decoder(layers.Layer):
14
+ def __init__(
15
+ self,
16
+ *,
17
+ channels: int,
18
+ output_channels: int = 3,
19
+ channels_multiplier: List[int],
20
+ num_res_blocks: int,
21
+ attention_resolution: List[int],
22
+ resolution: int,
23
+ z_channels: int,
24
+ dropout: float,
25
+ **kwargs
26
+ ):
27
+ super().__init__(**kwargs)
28
+
29
+ self.channels = channels
30
+ self.output_channels = output_channels
31
+ self.channels_multiplier = channels_multiplier
32
+ self.num_resolutions = len(channels_multiplier)
33
+ self.num_res_blocks = num_res_blocks
34
+ self.attention_resolution = attention_resolution
35
+ self.resolution = resolution
36
+ self.z_channels = z_channels
37
+ self.dropout = dropout
38
+
39
+ block_in = channels * channels_multiplier[-1]
40
+ current_resolution = resolution // 2 ** (self.num_resolutions - 1)
41
+ self.z_shape = (1, z_channels, current_resolution, current_resolution)
42
+
43
+ print(
44
+ "Working with z of shape {} = {} dimensions.".format(
45
+ self.z_shape, np.prod(self.z_shape)
46
+ )
47
+ )
48
+
49
+ self.conv_in = layers.Conv2D(block_in, kernel_size=3, strides=1, padding="same")
50
+
51
+ # middle
52
+ self.mid = {}
53
+ self.mid["block_1"] = ResnetBlock(
54
+ in_channels=block_in,
55
+ out_channels=block_in,
56
+ dropout=dropout,
57
+ )
58
+ self.mid["attn_1"] = AttentionBlock(block_in)
59
+ self.mid["block_2"] = ResnetBlock(
60
+ in_channels=block_in,
61
+ out_channels=block_in,
62
+ dropout=dropout,
63
+ )
64
+
65
+ # upsampling
66
+
67
+ self.upsampling_list = []
68
+
69
+ for i_level in reversed(range(self.num_resolutions)):
70
+ block_out = channels * channels_multiplier[i_level]
71
+ for i_block in range(self.num_res_blocks + 1):
72
+ self.upsampling_list.append(
73
+ ResnetBlock(
74
+ in_channels=block_in,
75
+ out_channels=block_out,
76
+ dropout=dropout,
77
+ )
78
+ )
79
+ block_in = block_out
80
+
81
+ if current_resolution in attention_resolution:
82
+ # attentions.append(layers.Attention())
83
+ self.upsampling_list.append(AttentionBlock(block_in))
84
+
85
+ if i_level != 0:
86
+ self.upsampling_list.append(Upsample(block_in))
87
+ current_resolution *= 2
88
+
89
+ # end
90
+ self.norm_out = GroupNormalization(groups=32, epsilon=1e-6)
91
+ self.conv_out = layers.Conv2D(
92
+ output_channels,
93
+ kernel_size=3,
94
+ strides=1,
95
+ activation="tanh",
96
+ padding="same",
97
+ )
98
+
99
+ def call(self, inputs, training=True, mask=None):
100
+
101
+ h = self.conv_in(inputs)
102
+
103
+ # middle
104
+ h = self.mid["block_1"](h)
105
+ h = self.mid["attn_1"](h)
106
+ h = self.mid["block_2"](h)
107
+
108
+ for upsampling in self.upsampling_list:
109
+ h = upsampling(h)
110
+
111
+ # end
112
+ h = self.norm_out(h)
113
+ h = keras.activations.swish(h)
114
+ h = self.conv_out(h)
115
+ return h
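As a hedged usage sketch of the Decoder above (the hyperparameter values are illustrative, not taken from the repository's configs): with a channels_multiplier of length 4 the decoder upsamples three times, so an 8x8 latent is decoded back to a 64x64, tanh-bounded image.

import tensorflow as tf
from ganime.model.vqgan_clean.diffusion.decoder import Decoder  # path added in this commit

decoder = Decoder(
    channels=128,
    output_channels=3,
    channels_multiplier=[1, 1, 2, 2],  # illustrative values
    num_res_blocks=1,
    attention_resolution=[16],
    resolution=64,
    z_channels=128,
    dropout=0.0,
)
z = tf.random.normal((1, 8, 8, 128))  # 8 = resolution // 2 ** (num_resolutions - 1)
x_rec = decoder(z)                    # (1, 64, 64, 3), bounded to [-1, 1] by the tanh in conv_out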
ganime/model/vqgan_clean/diffusion/encoder.py ADDED
@@ -0,0 +1,125 @@
1
+ from typing import List
2
+ import tensorflow as tf
3
+ from tensorflow import keras
4
+ from tensorflow.keras import layers, Model
5
+ from tensorflow_addons.layers import GroupNormalization
6
+ from .layers import ResnetBlock, AttentionBlock, Downsample
7
+
8
+
9
+ # @tf.keras.utils.register_keras_serializable()
10
+ class Encoder(layers.Layer):
11
+ def __init__(
12
+ self,
13
+ *,
14
+ channels: int,
15
+ channels_multiplier: List[int],
16
+ num_res_blocks: int,
17
+ attention_resolution: List[int],
18
+ resolution: int,
19
+ z_channels: int,
20
+ dropout: float,
21
+ **kwargs
22
+ ):
23
+ """Encode an image into a latent vector. The encoder is built from multiple levels (one per entry of `channels_multiplier`), each containing `num_res_blocks` ResnetBlocks.
24
+ Args:
25
+ channels (int, optional): The number of channel for the first layer. Defaults to 128.
26
+ channels_multiplier (List[int], optional): The channel multiplier for each level (previous level channels x multiplier). Defaults to [1, 1, 2, 2].
27
+ num_res_blocks (int, optional): Number of ResnetBlock at each level. Defaults to 1.
28
+ attention_resolution (List[int], optional): Add an attention block if the current resolution is in this array. Defaults to [16].
29
+ resolution (int, optional): The starting resolution. Defaults to 64.
30
+ z_channels (int, optional): The number of channel at the end of the encoder. Defaults to 128.
31
+ dropout (float, optional): The dropout ratio for each ResnetBlock. Defaults to 0.0.
32
+ """
33
+ super().__init__(**kwargs)
34
+
35
+ self.channels = channels
36
+ self.channels_multiplier = channels_multiplier
37
+ self.num_resolutions = len(channels_multiplier)
38
+ self.num_res_blocks = num_res_blocks
39
+ self.attention_resolution = attention_resolution
40
+ self.resolution = resolution
41
+ self.z_channels = z_channels
42
+ self.dropout = dropout
43
+
44
+ self.conv_in = layers.Conv2D(
45
+ self.channels, kernel_size=3, strides=1, padding="same"
46
+ )
47
+
48
+ current_resolution = resolution
49
+
50
+ in_channels_multiplier = (1,) + tuple(channels_multiplier)
51
+
52
+ self.downsampling_list = []
53
+
54
+ for i_level in range(self.num_resolutions):
55
+ block_in = channels * in_channels_multiplier[i_level]
56
+ block_out = channels * channels_multiplier[i_level]
57
+ for i_block in range(self.num_res_blocks):
58
+ self.downsampling_list.append(
59
+ ResnetBlock(
60
+ in_channels=block_in,
61
+ out_channels=block_out,
62
+ dropout=dropout,
63
+ )
64
+ )
65
+ block_in = block_out
66
+
67
+ if current_resolution in attention_resolution:
68
+ self.downsampling_list.append(AttentionBlock(block_in))
69
+
70
+ if i_level != self.num_resolutions - 1:
71
+ self.downsampling_list.append(Downsample(block_in))
72
+ current_resolution = current_resolution // 2
73
+
74
+ # middle
75
+ self.mid = {}
76
+ self.mid["block_1"] = ResnetBlock(
77
+ in_channels=block_in,
78
+ out_channels=block_in,
79
+ dropout=dropout,
80
+ )
81
+ self.mid["attn_1"] = AttentionBlock(block_in)
82
+ self.mid["block_2"] = ResnetBlock(
83
+ in_channels=block_in,
84
+ out_channels=block_in,
85
+ dropout=dropout,
86
+ )
87
+
88
+ # end
89
+ self.norm_out = GroupNormalization(groups=32, epsilon=1e-6)
90
+ self.conv_out = layers.Conv2D(
91
+ z_channels,
92
+ kernel_size=3,
93
+ strides=1,
94
+ padding="same",
95
+ )
96
+
97
+ # def get_config(self):
98
+ # config = super().get_config()
99
+ # config.update(
100
+ # {
101
+ # "channels": self.channels,
102
+ # "channels_multiplier": self.channels_multiplier,
103
+ # "num_res_blocks": self.num_res_blocks,
104
+ # "attention_resolution": self.attention_resolution,
105
+ # "resolution": self.resolution,
106
+ # "z_channels": self.z_channels,
107
+ # "dropout": self.dropout,
108
+ # }
109
+ # )
110
+ # return config
111
+
112
+ def call(self, inputs, training=True, mask=None):
113
+ h = self.conv_in(inputs)
114
+ for downsampling in self.downsampling_list:
115
+ h = downsampling(h)
116
+
117
+ h = self.mid["block_1"](h)
118
+ h = self.mid["attn_1"](h)
119
+ h = self.mid["block_2"](h)
120
+
121
+ # end
122
+ h = self.norm_out(h)
123
+ h = keras.activations.swish(h)
124
+ h = self.conv_out(h)
125
+ return h
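A minimal sketch of the shape contract described in the docstring above, using the values it mentions as defaults (channels=128, channels_multiplier=[1, 1, 2, 2], resolution=64); only the Encoder defined in this file is assumed. Three Downsample steps take a 64x64 image to an 8x8 latent with z_channels channels:

import tensorflow as tf
from ganime.model.vqgan_clean.diffusion.encoder import Encoder  # path added in this commit

encoder = Encoder(
    channels=128,
    channels_multiplier=[1, 1, 2, 2],
    num_res_blocks=1,
    attention_resolution=[16],
    resolution=64,
    z_channels=128,
    dropout=0.0,
)
images = tf.random.normal((1, 64, 64, 3))
z = encoder(images)  # (1, 8, 8, 128): one Downsample per level except the last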
ganime/model/vqgan_clean/diffusion/layers.py ADDED
@@ -0,0 +1,179 @@
1
+ import tensorflow as tf
2
+ from tensorflow import keras
3
+ from tensorflow.keras import layers, Sequential
4
+ from tensorflow_addons.layers import GroupNormalization
5
+
6
+
7
+ @tf.keras.utils.register_keras_serializable()
8
+ class ResnetBlock(layers.Layer):
9
+ def __init__(
10
+ self,
11
+ *,
12
+ in_channels,
13
+ dropout=0.0,
14
+ out_channels=None,
15
+ conv_shortcut=False,
16
+ **kwargs
17
+ ):
18
+ super().__init__(**kwargs)
19
+ self.in_channels = in_channels
20
+ self.dropout_rate = dropout
21
+ out_channels = in_channels if out_channels is None else out_channels
22
+ self.out_channels = out_channels
23
+ self.use_conv_shortcut = conv_shortcut
24
+
25
+ self.norm1 = GroupNormalization(groups=32, epsilon=1e-6)
26
+
27
+ self.conv1 = layers.Conv2D(
28
+ out_channels, kernel_size=3, strides=1, padding="same"
29
+ )
30
+
31
+ self.norm2 = GroupNormalization(groups=32, epsilon=1e-6)
32
+ self.dropout = layers.Dropout(dropout)
33
+
34
+ self.conv2 = layers.Conv2D(
35
+ out_channels, kernel_size=3, strides=1, padding="same"
36
+ )
37
+
38
+ if self.in_channels != self.out_channels:
39
+ if self.use_conv_shortcut:
40
+ self.conv_shortcut = layers.Conv2D(
41
+ out_channels, kernel_size=3, strides=1, padding="same"
42
+ )
43
+ else:
44
+ self.nin_shortcut = layers.Conv2D(
45
+ out_channels, kernel_size=1, strides=1, padding="valid"
46
+ )
47
+
48
+ def get_config(self):
49
+ config = super().get_config()
50
+ config.update(
51
+ {
52
+ "in_channels": self.in_channels,
53
+ "dropout": self.dropout_rate,
54
+ "out_channels": self.out_channels,
55
+ "conv_shortcut": self.use_conv_shortcut,
56
+ }
57
+ )
58
+ return config
59
+
60
+ def call(self, x):
61
+ h = x
62
+ h = self.norm1(h)
63
+ h = keras.activations.swish(h)
64
+ h = self.conv1(h)
65
+
66
+ h = self.norm2(h)
67
+ h = keras.activations.swish(h)
68
+ h = self.dropout(h)
69
+ h = self.conv2(h)
70
+
71
+ if self.in_channels != self.out_channels:
72
+ if self.use_conv_shortcut:
73
+ x = self.conv_shortcut(x)
74
+ else:
75
+ x = self.nin_shortcut(x)
76
+
77
+ return x + h
78
+
79
+
80
+ @tf.keras.utils.register_keras_serializable()
81
+ class AttentionBlock(layers.Layer):
82
+ def __init__(self, channels, **kwargs):
83
+ super().__init__(**kwargs)
84
+ self.channels = channels
85
+ self.norm = GroupNormalization(groups=32, epsilon=1e-6)
86
+ self.q = layers.Conv2D(channels, kernel_size=1, strides=1, padding="valid")
87
+ self.k = layers.Conv2D(channels, kernel_size=1, strides=1, padding="valid")
88
+ self.v = layers.Conv2D(channels, kernel_size=1, strides=1, padding="valid")
89
+ self.proj_out = layers.Conv2D(
90
+ channels, kernel_size=1, strides=1, padding="valid"
91
+ )
92
+
93
+ self.attention = layers.Attention()
94
+
95
+ def get_config(self):
96
+ config = super().get_config()
97
+ config.update(
98
+ {
99
+ "channels": self.channels,
100
+ }
101
+ )
102
+ return config
103
+
104
+ def call(self, x):
105
+ h_ = x
106
+ h_ = self.norm(h_)
107
+ q = self.q(h_)
108
+ k = self.k(h_)
109
+ v = self.v(h_)
110
+
111
+ # compute attention
112
+ (b, h, w, c,) = (
113
+ tf.shape(q)[0],
114
+ tf.shape(q)[1],
115
+ tf.shape(q)[2],
116
+ tf.shape(q)[3],
117
+ )
118
+
119
+ if b is None:
120
+ b = -1
121
+ q = tf.reshape(q, [b, h * w, c])
122
+ k = tf.reshape(k, [b, h * w, c])
123
+ v = tf.reshape(v, [b, h * w, c])
124
+
125
+ h_ = self.attention([q, v, k])
126
+
127
+ h_ = tf.reshape(h_, [b, h, w, c])
128
+
129
+ h_ = self.proj_out(h_)
130
+
131
+ return x + h_
132
+
133
+
134
+ @tf.keras.utils.register_keras_serializable()
135
+ class Downsample(layers.Layer):
136
+ def __init__(self, channels, **kwargs):
137
+ super().__init__(**kwargs)
138
+ self.channels = channels
139
+ self.down_sample = layers.AveragePooling2D(
140
+ pool_size=2, strides=2
141
+ )
142
+ self.conv = layers.Conv2D(channels, kernel_size=3, strides=1, padding="same")
143
+
144
+ def get_config(self):
145
+ config = super().get_config()
146
+ config.update(
147
+ {
148
+ "channels": self.channels,
149
+ }
150
+ )
151
+ return config
152
+
153
+ def call(self, x):
154
+ x = self.down_sample(x)
155
+ x = self.conv(x)
156
+ return x
157
+
158
+
159
+ @tf.keras.utils.register_keras_serializable()
160
+ class Upsample(layers.Layer):
161
+ def __init__(self, channels, **kwargs):
162
+ super().__init__(**kwargs)
163
+ self.channels = channels
164
+ self.up_sample = layers.UpSampling2D(size=2, interpolation="bilinear")
165
+ self.conv = layers.Conv2D(channels, kernel_size=3, strides=1, padding="same")
166
+
167
+ def get_config(self):
168
+ config = super().get_config()
169
+ config.update(
170
+ {
171
+ "channels": self.channels,
172
+ }
173
+ )
174
+ return config
175
+
176
+ def call(self, x):
177
+ x = self.up_sample(x)
178
+ x = self.conv(x)
179
+ return x
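The Downsample and Upsample layers above pair a fixed resize (average pooling or bilinear upsampling) with a learned 3x3 convolution. A quick, illustrative shape check:

import tensorflow as tf
from ganime.model.vqgan_clean.diffusion.layers import Downsample, Upsample  # paths from this commit

x = tf.random.normal((1, 16, 16, 64))
down = Downsample(64)(x)   # average pool /2 then 3x3 conv -> (1, 8, 8, 64)
up = Upsample(64)(down)    # bilinear x2 then 3x3 conv     -> (1, 16, 16, 64)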
ganime/model/vqgan_clean/discriminator/__init__.py ADDED
File without changes
ganime/model/vqgan_clean/discriminator/model.py ADDED
@@ -0,0 +1,88 @@
1
+ from typing import List
2
+ import numpy as np
3
+ import tensorflow as tf
4
+ from tensorflow import keras
5
+ from tensorflow.keras import Model, Sequential
6
+ from tensorflow.keras import layers
7
+ from tensorflow.keras.initializers import RandomNormal
8
+
9
+
10
+ class NLayerDiscriminator(Model):
11
+ """Defines a PatchGAN discriminator as in Pix2Pix
12
+ --> see https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py
13
+ """
14
+
15
+ def __init__(self, filters: int = 64, n_layers: int = 3, **kwargs):
16
+ super().__init__(**kwargs)
17
+
18
+ init = RandomNormal(stddev=0.02)
19
+ self.filters = filters
20
+ self.n_layers = n_layers
21
+
22
+ kernel_size = 4
23
+
24
+ inp = tf.keras.layers.Input(shape=[256, 512, 3], name="input_image")
25
+ tar = tf.keras.layers.Input(shape=[256, 512, 3], name="target_image")
26
+
27
+ x = tf.keras.layers.concatenate([inp, tar])
28
+
29
+ x = layers.Conv2D(
30
+ filters,
31
+ kernel_size=kernel_size,
32
+ strides=2,
33
+ # strides=1,
34
+ padding="same",
35
+ kernel_initializer=init,
36
+ )(x)
37
+ x = layers.LeakyReLU(alpha=0.2)(x)
38
+
39
+ filters_mult = 1
40
+ for n in range(1, n_layers):
41
+ filters_mult = min(2**n, 8)
42
+
43
+ x = layers.Conv2D(
44
+ filters * filters_mult,
45
+ kernel_size=kernel_size,
46
+ # strides=1, # 2,
47
+ strides=2,
48
+ padding="same",
49
+ use_bias=False,
50
+ kernel_initializer=init,
51
+ )(x)
52
+ x = layers.BatchNormalization()(x)
53
+ x = layers.LeakyReLU(alpha=0.2)(x)
54
+
55
+ filters_mult = min(2**n_layers, 8)
56
+ x = layers.Conv2D(
57
+ filters * filters_mult,
58
+ kernel_size=kernel_size,
59
+ strides=1,
60
+ padding="same",
61
+ use_bias=False,
62
+ kernel_initializer=init,
63
+ )(x)
64
+ x = layers.BatchNormalization()(x)
65
+ x = layers.LeakyReLU(alpha=0.2)(x)
66
+
67
+ x = layers.Conv2D(
68
+ 1,
69
+ kernel_size=kernel_size,
70
+ strides=1,
71
+ padding="same",
72
+ # activation="sigmoid",
73
+ kernel_initializer=init,
74
+ )(x)
75
+ self.model = tf.keras.Model(inputs=[inp, tar], outputs=x)
76
+
77
+ def call(self, inputs, training=True, mask=None):
78
+ return self.model(inputs)
79
+
80
+ def get_config(self):
81
+ config = super().get_config()
82
+ config.update(
83
+ {
84
+ "filters": self.filters,
85
+ "n_layers": self.n_layers,
86
+ }
87
+ )
88
+ return config
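The PatchGAN discriminator above classifies overlapping patches rather than whole frames: with the hard-coded 256x512 inputs and n_layers=3 (three stride-2 convolutions), it returns a 32x64 grid of raw logits (no sigmoid). A hedged usage sketch:

import tensorflow as tf
from ganime.model.vqgan_clean.discriminator.model import NLayerDiscriminator  # path from this commit

disc = NLayerDiscriminator(filters=64, n_layers=3)
input_image = tf.random.normal((1, 256, 512, 3))   # matches the "input_image" Input above
target_image = tf.random.normal((1, 256, 512, 3))  # matches the "target_image" Input above
patch_logits = disc([input_image, target_image])   # (1, 32, 64, 1) per-patch logits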
ganime/model/vqgan_clean/discriminator/model_bkp.py ADDED
@@ -0,0 +1,76 @@
1
+ from typing import List
2
+ import numpy as np
3
+ import tensorflow as tf
4
+ from tensorflow import keras
5
+ from tensorflow.keras import Model, Sequential
6
+ from tensorflow.keras import layers
7
+
8
+
9
+ class NLayerDiscriminator(Model):
10
+ """Defines a PatchGAN discriminator as in Pix2Pix
11
+ --> see https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py
12
+ """
13
+
14
+ def __init__(self, filters: int = 64, n_layers: int = 3, **kwargs):
15
+ super().__init__(**kwargs)
16
+
17
+ self.filters = filters
18
+ self.n_layers = n_layers
19
+
20
+ kernel_size = 4
21
+ self.sequence = [
22
+ layers.Conv2D(filters, kernel_size=kernel_size, strides=1, padding="same"),
23
+ layers.LeakyReLU(alpha=0.2),
24
+ ]
25
+
26
+ filters_mult = 1
27
+ for n in range(1, n_layers):
28
+ filters_mult = min(2**n, 8)
29
+
30
+ self.sequence += [
31
+ layers.AveragePooling2D(pool_size=2),
32
+ layers.Conv2D(
33
+ filters * filters_mult,
34
+ kernel_size=kernel_size,
35
+ strides=1, # 2,
36
+ # strides=2,
37
+ padding="same",
38
+ use_bias=False,
39
+ ),
40
+ layers.BatchNormalization(),
41
+ layers.LeakyReLU(alpha=0.2),
42
+ ]
43
+
44
+ filters_mult = min(2**n_layers, 8)
45
+ self.sequence += [
46
+ layers.AveragePooling2D(pool_size=2),
47
+ layers.Conv2D(
48
+ filters * filters_mult,
49
+ kernel_size=kernel_size,
50
+ strides=1,
51
+ padding="same",
52
+ use_bias=False,
53
+ ),
54
+ layers.BatchNormalization(),
55
+ layers.LeakyReLU(alpha=0.2),
56
+ ]
57
+
58
+ self.sequence += [
59
+ layers.Conv2D(1, kernel_size=kernel_size, strides=1, padding="same")
60
+ ]
61
+
62
+ def call(self, inputs, training=True, mask=None):
63
+ h = inputs
64
+ for seq in self.sequence:
65
+ h = seq(h)
66
+ return h
67
+
68
+ def get_config(self):
69
+ config = super().get_config()
70
+ config.update(
71
+ {
72
+ "filters": self.filters,
73
+ "n_layers": self.n_layers,
74
+ }
75
+ )
76
+ return config
ganime/model/vqgan_clean/experimental/gpt2_embedding.py ADDED
@@ -0,0 +1,1127 @@
1
+ # coding=utf-8
2
+ # Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team.
3
+ # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """ TF 2.0 OpenAI GPT-2 model."""
17
+
18
+ from dataclasses import dataclass
19
+ from typing import List, Optional, Tuple, Union
20
+
21
+ import numpy as np
22
+ import tensorflow as tf
23
+ from tensorflow.compiler.tf2xla.python.xla import dynamic_update_slice
24
+
25
+
26
+ from transformers.activations_tf import get_tf_activation
27
+ from transformers.modeling_tf_outputs import (
28
+ TFBaseModelOutputWithPastAndCrossAttentions,
29
+ TFCausalLMOutputWithCrossAttentions,
30
+ TFSequenceClassifierOutputWithPast,
31
+ )
32
+ from transformers.modeling_tf_utils import (
33
+ TFCausalLanguageModelingLoss,
34
+ TFConv1D,
35
+ TFModelInputType,
36
+ TFPreTrainedModel,
37
+ TFSequenceClassificationLoss,
38
+ TFSequenceSummary,
39
+ TFSharedEmbeddings,
40
+ get_initializer,
41
+ keras_serializable,
42
+ unpack_inputs,
43
+ )
44
+ from transformers.tf_utils import shape_list, stable_softmax
45
+ from transformers.utils import (
46
+ DUMMY_INPUTS,
47
+ ModelOutput,
48
+ add_code_sample_docstrings,
49
+ add_start_docstrings,
50
+ add_start_docstrings_to_model_forward,
51
+ logging,
52
+ replace_return_docstrings,
53
+ )
54
+ from transformers import GPT2Config
55
+
56
+
57
+ logger = logging.get_logger(__name__)
58
+
59
+ _CHECKPOINT_FOR_DOC = "gpt2"
60
+ _CONFIG_FOR_DOC = "GPT2Config"
61
+ _TOKENIZER_FOR_DOC = "GPT2Tokenizer"
62
+
63
+ TF_GPT2_PRETRAINED_MODEL_ARCHIVE_LIST = [
64
+ "gpt2",
65
+ "gpt2-medium",
66
+ "gpt2-large",
67
+ "gpt2-xl",
68
+ "distilgpt2",
69
+ # See all GPT-2 models at https://huggingface.co/models?filter=gpt2
70
+ ]
71
+
72
+
73
+ class TFAttention(tf.keras.layers.Layer):
74
+ def __init__(self, nx, config, scale=False, is_cross_attention=False, **kwargs):
75
+ super().__init__(**kwargs)
76
+
77
+ n_state = nx # in Attention: n_state=768 (nx=n_embd)
78
+ # [switch nx => n_state from Block to Attention to keep identical to TF implementation]
79
+ assert n_state % config.n_head == 0
80
+ self.n_head = config.n_head
81
+ self.split_size = n_state
82
+ self.scale = scale
83
+ self.output_attentions = config.output_attentions
84
+
85
+ self.is_cross_attention = is_cross_attention
86
+
87
+ if self.is_cross_attention:
88
+ self.c_attn = TFConv1D(
89
+ n_state * 2,
90
+ nx,
91
+ initializer_range=config.initializer_range,
92
+ name="c_attn",
93
+ )
94
+ self.q_attn = TFConv1D(
95
+ n_state, nx, initializer_range=config.initializer_range, name="q_attn"
96
+ )
97
+ else:
98
+ self.c_attn = TFConv1D(
99
+ n_state * 3,
100
+ nx,
101
+ initializer_range=config.initializer_range,
102
+ name="c_attn",
103
+ )
104
+
105
+ self.c_proj = TFConv1D(
106
+ n_state, nx, initializer_range=config.initializer_range, name="c_proj"
107
+ )
108
+ self.attn_dropout = tf.keras.layers.Dropout(config.attn_pdrop)
109
+ self.resid_dropout = tf.keras.layers.Dropout(config.resid_pdrop)
110
+ self.pruned_heads = set()
111
+
112
+ def prune_heads(self, heads):
113
+ pass
114
+
115
+ @staticmethod
116
+ def causal_attention_mask(nd, ns, dtype):
117
+ """
118
+ 1's in the lower triangle, counting from the lower right corner. Same as tf.matrix_band_part(tf.ones([nd, ns]),
119
+ -1, ns-nd), but doesn't produce garbage on TPUs.
120
+ """
121
+ i = tf.range(nd)[:, None]
122
+ j = tf.range(ns)
123
+ m = i >= j - ns + nd
124
+ return tf.cast(m, dtype)
125
+
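To make the docstring above concrete: for nd=3 query positions attending over ns=5 total positions, row i can see keys up to index i + (ns - nd). Illustrative only:

import tensorflow as tf

# Static method, so it can be called without instantiating TFAttention.
mask = TFAttention.causal_attention_mask(3, 5, tf.float32)
# [[1. 1. 1. 0. 0.]
#  [1. 1. 1. 1. 0.]
#  [1. 1. 1. 1. 1.]]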
126
+ def _attn(
127
+ self, q, k, v, attention_mask, head_mask, output_attentions, training=False
128
+ ):
129
+ # q, k, v have shape [batch, heads, sequence, features]
130
+ w = tf.matmul(q, k, transpose_b=True)
131
+ if self.scale:
132
+ dk = tf.cast(shape_list(k)[-1], dtype=w.dtype) # scale attention_scores
133
+ w = w / tf.math.sqrt(dk)
134
+
135
+ if not self.is_cross_attention:
136
+ # if only "normal" attention layer implements causal mask
137
+
138
+ # w has shape [batch, heads, dst_sequence, src_sequence], where information flows from src to dst.
139
+ _, _, nd, ns = shape_list(w)
140
+ b = self.causal_attention_mask(nd, ns, dtype=w.dtype)
141
+ b = tf.reshape(b, [1, 1, nd, ns])
142
+ w = w * b - 1e4 * (1 - b)
143
+
144
+ if attention_mask is not None:
145
+ # Apply the attention mask
146
+ attention_mask = tf.cast(attention_mask, dtype=w.dtype)
147
+ w = w + attention_mask
148
+
149
+ w = stable_softmax(w, axis=-1)
150
+ w = self.attn_dropout(w, training=training)
151
+
152
+ # Mask heads if we want to
153
+ if head_mask is not None:
154
+ w = w * head_mask
155
+
156
+ outputs = [tf.matmul(w, v)]
157
+ if output_attentions:
158
+ outputs.append(w)
159
+ return outputs
160
+
161
+ def merge_heads(self, x):
162
+ x = tf.transpose(x, [0, 2, 1, 3])
163
+ x_shape = shape_list(x)
164
+ new_x_shape = x_shape[:-2] + [x_shape[-2] * x_shape[-1]]
165
+ return tf.reshape(x, new_x_shape)
166
+
167
+ def split_heads(self, x):
168
+ x_shape = shape_list(x)
169
+ new_x_shape = x_shape[:-1] + [self.n_head, x_shape[-1] // self.n_head]
170
+ x = tf.reshape(x, new_x_shape)
171
+ return tf.transpose(x, (0, 2, 1, 3)) # (batch, head, seq_length, head_features)
172
+
173
+ def call(
174
+ self,
175
+ x,
176
+ layer_past,
177
+ attention_mask,
178
+ head_mask,
179
+ encoder_hidden_states,
180
+ encoder_attention_mask,
181
+ use_cache,
182
+ output_attentions,
183
+ training=False,
184
+ ):
185
+
186
+ if encoder_hidden_states is not None:
187
+ if not hasattr(self, "q_attn"):
188
+ raise ValueError(
189
+ "If class is used as cross attention, the weights `q_attn` have to be defined. "
190
+ "Please make sure to instantiate class with `GPT2Attention(..., is_cross_attention=True)`."
191
+ )
192
+
193
+ query = self.q_attn(x)
194
+ kv_out = self.c_attn(encoder_hidden_states)
195
+ key, value = tf.split(kv_out, 2, axis=2)
196
+ attention_mask = encoder_attention_mask
197
+ else:
198
+ x = self.c_attn(x)
199
+ query, key, value = tf.split(x, 3, axis=2)
200
+
201
+ query = self.split_heads(query)
202
+ key = self.split_heads(key)
203
+ value = self.split_heads(value)
204
+ if layer_past is not None:
205
+ past_key, past_value = tf.unstack(layer_past, axis=0)
206
+ key = tf.concat([past_key, key], axis=-2)
207
+ value = tf.concat([past_value, value], axis=-2)
208
+
209
+ # to cope with keras serialization
210
+ if use_cache:
211
+ present = tf.stack([key, value], axis=0)
212
+ else:
213
+ present = (None,)
214
+
215
+ attn_outputs = self._attn(
216
+ query,
217
+ key,
218
+ value,
219
+ attention_mask,
220
+ head_mask,
221
+ output_attentions,
222
+ training=training,
223
+ )
224
+ a = attn_outputs[0]
225
+
226
+ a = self.merge_heads(a)
227
+ a = self.c_proj(a)
228
+ a = self.resid_dropout(a, training=training)
229
+
230
+ outputs = [a, present] + attn_outputs[1:]
231
+ return outputs # a, present, (attentions)
232
+
233
+
234
+ class TFMLP(tf.keras.layers.Layer):
235
+ def __init__(self, n_state, config, **kwargs):
236
+ super().__init__(**kwargs)
237
+ nx = config.n_embd
238
+ self.c_fc = TFConv1D(
239
+ n_state, nx, initializer_range=config.initializer_range, name="c_fc"
240
+ )
241
+ self.c_proj = TFConv1D(
242
+ nx, n_state, initializer_range=config.initializer_range, name="c_proj"
243
+ )
244
+ self.act = get_tf_activation(config.activation_function)
245
+ self.dropout = tf.keras.layers.Dropout(config.resid_pdrop)
246
+
247
+ def call(self, x, training=False):
248
+ h = self.act(self.c_fc(x))
249
+ h2 = self.c_proj(h)
250
+ h2 = self.dropout(h2, training=training)
251
+ return h2
252
+
253
+
254
+ class TFBlock(tf.keras.layers.Layer):
255
+ def __init__(self, config, scale=False, **kwargs):
256
+ super().__init__(**kwargs)
257
+ nx = config.n_embd
258
+ inner_dim = config.n_inner if config.n_inner is not None else 4 * nx
259
+ self.ln_1 = tf.keras.layers.LayerNormalization(
260
+ epsilon=config.layer_norm_epsilon, name="ln_1"
261
+ )
262
+ self.attn = TFAttention(nx, config, scale, name="attn")
263
+ self.ln_2 = tf.keras.layers.LayerNormalization(
264
+ epsilon=config.layer_norm_epsilon, name="ln_2"
265
+ )
266
+
267
+ if config.add_cross_attention:
268
+
269
+ self.crossattention = TFAttention(
270
+ nx, config, scale, name="crossattention", is_cross_attention=True
271
+ )
272
+ self.ln_cross_attn = tf.keras.layers.LayerNormalization(
273
+ epsilon=config.layer_norm_epsilon, name="ln_cross_attn"
274
+ )
275
+
276
+ self.mlp = TFMLP(inner_dim, config, name="mlp")
277
+
278
+ def call(
279
+ self,
280
+ x,
281
+ layer_past,
282
+ attention_mask,
283
+ head_mask,
284
+ encoder_hidden_states,
285
+ encoder_attention_mask,
286
+ use_cache,
287
+ output_attentions,
288
+ training=False,
289
+ ):
290
+ a = self.ln_1(x)
291
+ output_attn = self.attn(
292
+ a,
293
+ layer_past=layer_past,
294
+ attention_mask=attention_mask,
295
+ head_mask=head_mask,
296
+ encoder_hidden_states=None,
297
+ encoder_attention_mask=None,
298
+ use_cache=use_cache,
299
+ output_attentions=output_attentions,
300
+ training=training,
301
+ )
302
+ a = output_attn[0] # output_attn: a, present, (attentions)
303
+ outputs = output_attn[1:]
304
+ x = x + a
305
+
306
+ # Cross-Attention Block
307
+ if encoder_hidden_states is not None:
308
+ # add one self-attention block for cross-attention
309
+ if not hasattr(self, "crossattention"):
310
+ raise ValueError(
311
+ f"If `encoder_hidden_states` are passed, {self} has to be instantiated with "
312
+ "cross-attention layers by setting `config.add_cross_attention=True`"
313
+ )
314
+
315
+ ca = self.ln_cross_attn(x)
316
+ output_cross_attn = self.crossattention(
317
+ ca,
318
+ layer_past=None,
319
+ attention_mask=attention_mask,
320
+ head_mask=head_mask,
321
+ encoder_hidden_states=encoder_hidden_states,
322
+ encoder_attention_mask=encoder_attention_mask,
323
+ use_cache=False,
324
+ output_attentions=output_attentions,
325
+ training=training,
326
+ )
327
+ ca = output_cross_attn[0] # output_attn: a, present, (cross_attentions)
328
+ x = x + ca
329
+ outputs = (
330
+ outputs + output_cross_attn[2:]
331
+ ) # add cross attentions if we output attention weights
332
+
333
+ m = self.ln_2(x)
334
+ m = self.mlp(m, training=training)
335
+ x = x + m
336
+
337
+ outputs = [x] + outputs
338
+ return outputs # x, present, (attentions, cross_attentions)
339
+
340
+
341
+ @keras_serializable
342
+ class TFGPT2MainLayer(tf.keras.layers.Layer):
343
+ config_class = GPT2Config
344
+
345
+ def __init__(self, config, *inputs, **kwargs):
346
+ super().__init__(*inputs, **kwargs)
347
+
348
+ self.config = config
349
+ self.output_attentions = config.output_attentions
350
+ self.output_hidden_states = config.output_hidden_states
351
+ self.use_cache = config.use_cache
352
+ self.return_dict = config.use_return_dict
353
+
354
+ self.num_hidden_layers = config.n_layer
355
+ self.vocab_size = config.vocab_size
356
+ self.n_embd = config.n_embd
357
+ self.n_positions = config.n_positions
358
+ self.initializer_range = config.initializer_range
359
+
360
+ self.wte = TFSharedEmbeddings(
361
+ config.vocab_size,
362
+ config.hidden_size,
363
+ initializer_range=config.initializer_range,
364
+ name="wte",
365
+ )
366
+
367
+ self.wte_remaining_frames = TFSharedEmbeddings(
368
+ config.vocab_size,
369
+ config.hidden_size,
370
+ initializer_range=config.initializer_range,
371
+ name="wte_remaining_frames",
372
+ )
373
+ self.drop = tf.keras.layers.Dropout(config.embd_pdrop)
374
+ self.h = [
375
+ TFBlock(config, scale=True, name=f"h_._{i}") for i in range(config.n_layer)
376
+ ]
377
+ self.ln_f = tf.keras.layers.LayerNormalization(
378
+ epsilon=config.layer_norm_epsilon, name="ln_f"
379
+ )
380
+
381
+ def build(self, input_shape):
382
+ with tf.name_scope("wpe"):
383
+ self.wpe = self.add_weight(
384
+ name="embeddings",
385
+ shape=[self.n_positions, self.n_embd],
386
+ initializer=get_initializer(self.initializer_range),
387
+ )
388
+ self.wte_remaining_frames.build(input_shape)
389
+
390
+ super().build(input_shape)
391
+
392
+ def get_input_embeddings(self):
393
+ return self.wte
394
+
395
+ def get_remaining_frames_embeddings(self):
396
+ return self.wte_remaining_frames
397
+
398
+ def set_input_embeddings(self, value):
399
+ self.wte.weight = value
400
+ self.wte.vocab_size = shape_list(value)[0]
401
+
402
+ def set_remaining_frames_embeddings(self, value):
403
+ self.wte_remaining_frames.weight = value
404
+ self.wte_remaining_frames.vocab_size = shape_list(value)[0]
405
+
406
+ def _prune_heads(self, heads_to_prune):
407
+ """
408
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
409
+ """
410
+ raise NotImplementedError
411
+
412
+ @unpack_inputs
413
+ def call(
414
+ self,
415
+ input_ids: Optional[TFModelInputType] = None,
416
+ remaining_frames_ids: Optional[Union[np.ndarray, tf.Tensor]] = None,
417
+ past: Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]] = None,
418
+ attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
419
+ token_type_ids: Optional[Union[np.ndarray, tf.Tensor]] = None,
420
+ position_ids: Optional[Union[np.ndarray, tf.Tensor]] = None,
421
+ head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
422
+ inputs_embeds: Optional[Union[np.ndarray, tf.Tensor]] = None,
423
+ encoder_hidden_states: Optional[Union[np.ndarray, tf.Tensor]] = None,
424
+ encoder_attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
425
+ use_cache: Optional[bool] = None,
426
+ output_attentions: Optional[bool] = None,
427
+ output_hidden_states: Optional[bool] = None,
428
+ return_dict: Optional[bool] = None,
429
+ training: Optional[bool] = False,
430
+ ) -> Union[TFBaseModelOutputWithPastAndCrossAttentions, Tuple[tf.Tensor]]:
431
+
432
+ if input_ids is not None and inputs_embeds is not None:
433
+ raise ValueError(
434
+ "You cannot specify both input_ids and inputs_embeds at the same time"
435
+ )
436
+ elif input_ids is not None:
437
+ input_shape = shape_list(input_ids)
438
+ input_ids = tf.reshape(input_ids, [-1, input_shape[-1]])
439
+ elif inputs_embeds is not None:
440
+ input_shape = shape_list(inputs_embeds)[:-1]
441
+ else:
442
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
443
+
444
+ if past is None:
445
+ past_length = 0
446
+ past = [None] * len(self.h)
447
+ else:
448
+ past_length = shape_list(past[0][0])[-2]
449
+
450
+ if position_ids is None:
451
+ position_ids = tf.expand_dims(
452
+ tf.range(past_length, input_shape[-1] + past_length), axis=0
453
+ )
454
+
455
+ if attention_mask is not None:
456
+ # We create a 3D attention mask from a 2D tensor mask.
457
+ # Sizes are [batch_size, 1, 1, to_seq_length]
458
+ # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
459
+ # this attention mask is more simple than the triangular masking of causal attention
460
+ # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
461
+ attention_mask_shape = shape_list(attention_mask)
462
+ attention_mask = tf.reshape(
463
+ attention_mask, (attention_mask_shape[0], 1, 1, attention_mask_shape[1])
464
+ )
465
+
466
+ # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
467
+ # masked positions, this operation will create a tensor which is 0.0 for
468
+ # positions we want to attend and -10000.0 for masked positions.
469
+ # Since we are adding it to the raw scores before the softmax, this is
470
+ # effectively the same as removing these entirely.
471
+ one_cst = tf.constant(1.0)
472
+ attention_mask = tf.cast(attention_mask, dtype=one_cst.dtype)
473
+ attention_mask = tf.multiply(
474
+ tf.subtract(one_cst, attention_mask), tf.constant(-10000.0)
475
+ )
476
+
477
+ # Copied from `modeling_tf_t5.py` with -1e9 -> -10000
478
+ if self.config.add_cross_attention and encoder_attention_mask is not None:
479
+ # If a 2D or 3D attention mask is provided for the cross-attention
480
+ # we need to make broadcastable to [batch_size, num_heads, mask_seq_length, mask_seq_length]
481
+ # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
482
+ encoder_attention_mask = tf.cast(
483
+ encoder_attention_mask, dtype=encoder_hidden_states.dtype
484
+ )
485
+ num_dims_encoder_attention_mask = len(shape_list(encoder_attention_mask))
486
+ if num_dims_encoder_attention_mask == 3:
487
+ encoder_extended_attention_mask = encoder_attention_mask[:, None, :, :]
488
+ if num_dims_encoder_attention_mask == 2:
489
+ encoder_extended_attention_mask = encoder_attention_mask[
490
+ :, None, None, :
491
+ ]
492
+
493
+ # T5 has a mask that can compare sequence ids, we can simulate this here with this transposition
494
+ # Cf. https://github.com/tensorflow/mesh/blob/8d2465e9bc93129b913b5ccc6a59aa97abd96ec6/mesh_tensorflow/transformer/transformer_layers.py#L270
495
+ # encoder_extended_attention_mask = tf.math.equal(encoder_extended_attention_mask,
496
+ # tf.transpose(encoder_extended_attention_mask, perm=(-1, -2)))
497
+
498
+ encoder_extended_attention_mask = (
499
+ 1.0 - encoder_extended_attention_mask
500
+ ) * -10000.0
501
+ else:
502
+ encoder_extended_attention_mask = None
503
+
504
+ encoder_attention_mask = encoder_extended_attention_mask
505
+
506
+ # Prepare head mask if needed
507
+ # 1.0 in head_mask indicate we keep the head
508
+ # attention_probs has shape bsz x n_heads x N x N
509
+ # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
510
+ # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
511
+ if head_mask is not None:
512
+ raise NotImplementedError
513
+ else:
514
+ head_mask = [None] * self.num_hidden_layers
515
+ # head_mask = tf.constant([0] * self.num_hidden_layers)
516
+
517
+ position_ids = tf.reshape(position_ids, [-1, shape_list(position_ids)[-1]])
518
+
519
+ if inputs_embeds is None:
520
+ inputs_embeds = self.wte(input_ids, mode="embedding")
521
+
522
+ position_embeds = tf.gather(self.wpe, position_ids)
523
+
524
+ if token_type_ids is not None:
525
+ token_type_ids = tf.reshape(
526
+ token_type_ids, [-1, shape_list(token_type_ids)[-1]]
527
+ )
528
+ token_type_embeds = self.wte(token_type_ids, mode="embedding")
529
+ else:
530
+ token_type_embeds = tf.constant(0.0)
531
+
532
+ if remaining_frames_ids is not None:
533
+ remaining_frames_ids = tf.reshape(
534
+ remaining_frames_ids, [-1, shape_list(remaining_frames_ids)[-1]]
535
+ )
536
+ remaining_frames_embeds = self.wte_remaining_frames(
537
+ remaining_frames_ids, mode="embedding"
538
+ )
539
+ else:
540
+ remaining_frames_embeds = tf.constant(0.0)
541
+
542
+ position_embeds = tf.cast(position_embeds, dtype=inputs_embeds.dtype)
543
+ token_type_embeds = tf.cast(token_type_embeds, dtype=inputs_embeds.dtype)
544
+ remaining_frames_embeds = tf.cast(
545
+ remaining_frames_embeds, dtype=inputs_embeds.dtype
546
+ )
547
+ hidden_states = (
548
+ inputs_embeds
549
+ + position_embeds
550
+ + token_type_embeds
551
+ + remaining_frames_embeds
552
+ )
553
+ hidden_states = self.drop(hidden_states, training=training)
554
+
555
+ output_shape = input_shape + [shape_list(hidden_states)[-1]]
556
+
557
+ presents = () if use_cache else None
558
+ all_attentions = () if output_attentions else None
559
+ all_cross_attentions = (
560
+ () if output_attentions and self.config.add_cross_attention else None
561
+ )
562
+ all_hidden_states = () if output_hidden_states else None
563
+ for i, (block, layer_past) in enumerate(zip(self.h, past)):
564
+ if output_hidden_states:
565
+ all_hidden_states = all_hidden_states + (
566
+ tf.reshape(hidden_states, output_shape),
567
+ )
568
+
569
+ outputs = block(
570
+ hidden_states,
571
+ layer_past,
572
+ attention_mask,
573
+ head_mask[i],
574
+ encoder_hidden_states,
575
+ encoder_attention_mask,
576
+ use_cache,
577
+ output_attentions,
578
+ training=training,
579
+ )
580
+
581
+ hidden_states, present = outputs[:2]
582
+ if use_cache:
583
+ presents = presents + (present,)
584
+
585
+ if output_attentions:
586
+ all_attentions = all_attentions + (outputs[2],)
587
+ if (
588
+ self.config.add_cross_attention
589
+ and encoder_hidden_states is not None
590
+ ):
591
+ all_cross_attentions = all_cross_attentions + (outputs[3],)
592
+
593
+ hidden_states = self.ln_f(hidden_states)
594
+
595
+ hidden_states = tf.reshape(hidden_states, output_shape)
596
+ # Add last hidden state
597
+ if output_hidden_states:
598
+ all_hidden_states = all_hidden_states + (hidden_states,)
599
+
600
+ if output_attentions:
601
+ # let the number of heads free (-1) so we can extract attention even after head pruning
602
+ attention_output_shape = (
603
+ input_shape[:-1] + [-1] + shape_list(all_attentions[0])[-2:]
604
+ )
605
+ all_attentions = tuple(
606
+ tf.reshape(t, attention_output_shape) for t in all_attentions
607
+ )
608
+
609
+ if not return_dict:
610
+ return tuple(
611
+ v
612
+ for v in [
613
+ hidden_states,
614
+ presents,
615
+ all_hidden_states,
616
+ all_attentions,
617
+ all_cross_attentions,
618
+ ]
619
+ if v is not None
620
+ )
621
+
622
+ return TFBaseModelOutputWithPastAndCrossAttentions(
623
+ last_hidden_state=hidden_states,
624
+ past_key_values=presents,
625
+ hidden_states=all_hidden_states,
626
+ attentions=all_attentions,
627
+ cross_attentions=all_cross_attentions,
628
+ )
629
+
630
+
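The main change versus stock TFGPT2MainLayer is the second embedding table, wte_remaining_frames: the hidden state entering the first block is the sum of token, position, optional token-type and remaining-frames embeddings. A simplified, self-contained illustration of that sum (plain Keras Embedding layers stand in for TFSharedEmbeddings and wpe; the sizes are made up):

import tensorflow as tf

vocab_size, n_positions, n_embd = 1024, 64, 16  # made-up sizes
wte = tf.keras.layers.Embedding(vocab_size, n_embd)                   # token codes
wte_remaining_frames = tf.keras.layers.Embedding(vocab_size, n_embd)  # frames left to generate
wpe = tf.keras.layers.Embedding(n_positions, n_embd)                  # positions

input_ids = tf.constant([[5, 9, 2]])
remaining_frames_ids = tf.constant([[12, 12, 12]])  # same count repeated at every position
position_ids = tf.range(3)[tf.newaxis, :]

hidden_states = (
    wte(input_ids) + wpe(position_ids) + wte_remaining_frames(remaining_frames_ids)
)  # (1, 3, 16); dropout and the TFBlock stack then run on this sum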
631
+ class TFGPT2PreTrainedModel(TFPreTrainedModel):
632
+ """
633
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
634
+ models.
635
+ """
636
+
637
+ config_class = GPT2Config
638
+ base_model_prefix = "transformer"
639
+ # names with a '.' represents the authorized unexpected/missing layers when a TF model is loaded from a PT model
640
+ _keys_to_ignore_on_load_unexpected = [
641
+ r"h.\d+.attn.bias",
642
+ r"h.\d+.crossattention.bias",
643
+ ]
644
+
645
+ @property
646
+ def dummy_inputs(self):
647
+ """
648
+ Dummy inputs to build the network.
649
+
650
+ Returns:
651
+ `Dict[str, tf.Tensor]`: The dummy inputs.
652
+ """
653
+ dummy = {"input_ids": tf.constant(DUMMY_INPUTS)}
654
+ # Add `encoder_hidden_states` to make the cross-attention layers' weights initialized
655
+ if self.config.add_cross_attention:
656
+ batch_size, seq_len = tf.constant(DUMMY_INPUTS).shape
657
+ shape = (batch_size, seq_len) + (self.config.hidden_size,)
658
+ h = tf.random.uniform(shape=shape)
659
+ dummy["encoder_hidden_states"] = h
660
+
661
+ return dummy
662
+
663
+ @tf.function(
664
+ input_signature=[
665
+ {
666
+ "input_ids": tf.TensorSpec((None, None), tf.int32, name="input_ids"),
667
+ "attention_mask": tf.TensorSpec(
668
+ (None, None), tf.int32, name="attention_mask"
669
+ ),
670
+ }
671
+ ]
672
+ )
673
+ def serving(self, inputs):
674
+ output = self.call(inputs)
675
+
676
+ return self.serving_output(output)
677
+
678
+
679
+ GPT2_START_DOCSTRING = r"""
680
+
681
+ This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the
682
+ library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
683
+ etc.)
684
+
685
+ This model is also a [tf.keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) subclass. Use it
686
+ as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matter related to general usage and
687
+ behavior.
688
+
689
+ <Tip>
690
+
691
+ TF 2.0 models accept two formats as inputs:
692
+
693
+ - having all inputs as keyword arguments (like PyTorch models), or
694
+ - having all inputs as a list, tuple or dict in the first positional arguments.
695
+
696
+ This second option is useful when using [`tf.keras.Model.fit`] method which currently requires having all the
697
+ tensors in the first argument of the model call function: `model(inputs)`.
698
+
699
+ If you choose this second option, there are three possibilities you can use to gather all the input Tensors in the
700
+ first positional argument :
701
+
702
+ - a single Tensor with `input_ids` only and nothing else: `model(inputs_ids)`
703
+ - a list of varying length with one or several input Tensors IN THE ORDER given in the docstring:
704
+ `model([input_ids, attention_mask])` or `model([input_ids, attention_mask, token_type_ids])`
705
+ - a dictionary with one or several input Tensors associated to the input names given in the docstring:
706
+ `model({"input_ids": input_ids, "token_type_ids": token_type_ids})`
707
+
708
+ </Tip>
709
+
710
+ Parameters:
711
+ config ([`GPT2Config`]): Model configuration class with all the parameters of the model.
712
+ Initializing with a config file does not load the weights associated with the model, only the
713
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
714
+ """
715
+
716
+ GPT2_INPUTS_DOCSTRING = r"""
717
+ Args:
718
+ input_ids (`Numpy array` or `tf.Tensor` of shape `(batch_size, input_ids_length)`):
719
+ `input_ids_length` = `sequence_length` if `past` is `None` else `past[0].shape[-2]` (`sequence_length` of
720
+ input past key value states). Indices of input sequence tokens in the vocabulary.
721
+
722
+ If `past` is used, only input IDs that do not have their past calculated should be passed as `input_ids`.
723
+
724
+ Indices can be obtained using [`GPT2Tokenizer`]. See [`PreTrainedTokenizer.__call__`] and
725
+ [`PreTrainedTokenizer.encode`] for details.
726
+
727
+ [What are input IDs?](../glossary#input-ids)
728
+ past (`List[tf.Tensor]` of length `config.n_layers`):
729
+ Contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model (see
730
+ `past` output below). Can be used to speed up sequential decoding. The token ids which have their past
731
+ given to this model should not be passed as input ids as they have already been computed.
732
+ attention_mask (`tf.Tensor` or `Numpy array` of shape `(batch_size, sequence_length)`, *optional*):
733
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
734
+
735
+ - 1 for tokens that are **not masked**,
736
+ - 0 for tokens that are **masked**.
737
+
738
+ If `past_key_values` is used, `attention_mask` needs to contain the masking strategy that was used for
739
+ `past_key_values`. In other words, the `attention_mask` always has to have the length:
740
+ `len(past_key_values) + len(input_ids)`
741
+
742
+ [What are attention masks?](../glossary#attention-mask)
743
+ token_type_ids (`tf.Tensor` or `Numpy array` of shape `(batch_size, sequence_length)`, *optional*):
744
+ Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
745
+ 1]`:
746
+
747
+ - 0 corresponds to a *sentence A* token,
748
+ - 1 corresponds to a *sentence B* token.
749
+
750
+ [What are token type IDs?](../glossary#token-type-ids)
751
+ position_ids (`tf.Tensor` or `Numpy array` of shape `(batch_size, sequence_length)`, *optional*):
752
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
753
+ config.max_position_embeddings - 1]`.
754
+
755
+ [What are position IDs?](../glossary#position-ids)
756
+ head_mask (`Numpy array` or `tf.Tensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
757
+ Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
758
+
759
+ - 1 indicates the head is **not masked**,
760
+ - 0 indicates the head is **masked**.
761
+
762
+ inputs_embeds (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
763
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
764
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
765
+ model's internal embedding lookup matrix.
766
+ output_attentions (`bool`, *optional*):
767
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
768
+ tensors for more detail. This argument can be used only in eager mode, in graph mode the value in the
769
+ config will be used instead.
770
+ output_hidden_states (`bool`, *optional*):
771
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
772
+ more detail. This argument can be used only in eager mode, in graph mode the value in the config will be
773
+ used instead.
774
+ return_dict (`bool`, *optional*):
775
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. This argument can be used in
776
+ eager mode, in graph mode the value will always be set to True.
777
+ training (`bool`, *optional*, defaults to `False`):
778
+ Whether or not to use the model in training mode (some modules like dropout modules have different
779
+ behaviors between training and evaluation).
780
+ """
781
+
782
+
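A hedged sketch of the cached-decoding contract the docstring above describes: pass the full prompt once with use_cache=True, then feed only the newest token along with the returned past. The tiny config values are made up, and TFGPT2Model is the class defined just below.

import tensorflow as tf
from transformers import GPT2Config

config = GPT2Config(vocab_size=1024, n_positions=128, n_embd=64, n_layer=2, n_head=2)
model = TFGPT2Model(config)

prompt = tf.constant([[11, 22, 33]])
out = model(input_ids=prompt, use_cache=True)

# Only the token without cached state is passed; `past` carries the key/value cache.
next_out = model(
    input_ids=tf.constant([[44]]),
    past=out.past_key_values,
    use_cache=True,
)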
783
+ @add_start_docstrings(
784
+ "The bare GPT2 Model transformer outputting raw hidden-states without any specific head on top.",
785
+ GPT2_START_DOCSTRING,
786
+ )
787
+ class TFGPT2Model(TFGPT2PreTrainedModel):
788
+ def __init__(self, config, *inputs, **kwargs):
789
+ super().__init__(config, *inputs, **kwargs)
790
+ self.transformer = TFGPT2MainLayer(config, name="transformer")
791
+
792
+ @unpack_inputs
793
+ @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING)
794
+ @add_code_sample_docstrings(
795
+ processor_class=_TOKENIZER_FOR_DOC,
796
+ checkpoint=_CHECKPOINT_FOR_DOC,
797
+ output_type=TFBaseModelOutputWithPastAndCrossAttentions,
798
+ config_class=_CONFIG_FOR_DOC,
799
+ )
800
+ def call(
801
+ self,
802
+ input_ids: Optional[TFModelInputType] = None,
803
+ remaining_frames_ids: Optional[Union[np.ndarray, tf.Tensor]] = None,
804
+ past: Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]] = None,
805
+ attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
806
+ token_type_ids: Optional[Union[np.ndarray, tf.Tensor]] = None,
807
+ position_ids: Optional[Union[np.ndarray, tf.Tensor]] = None,
808
+ head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
809
+ inputs_embeds: Optional[Union[np.ndarray, tf.Tensor]] = None,
810
+ encoder_hidden_states: Optional[Union[np.ndarray, tf.Tensor]] = None,
811
+ encoder_attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
812
+ use_cache: Optional[bool] = None,
813
+ output_attentions: Optional[bool] = None,
814
+ output_hidden_states: Optional[bool] = None,
815
+ return_dict: Optional[bool] = None,
816
+ training: Optional[bool] = False,
817
+ ) -> Union[TFBaseModelOutputWithPastAndCrossAttentions, Tuple[tf.Tensor]]:
818
+ r"""
819
+ encoder_hidden_states (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
820
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
821
+ the model is configured as a decoder.
822
+ encoder_attention_mask (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
823
+ Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
824
+ the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`:
825
+
826
+ - 1 for tokens that are **not masked**,
827
+ - 0 for tokens that are **masked**.
828
+
829
+ past (`Tuple[Tuple[tf.Tensor]]` of length `config.n_layers`)
830
+ contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
831
+ If `past` are used, the user can optionally input only the last `decoder_input_ids` (those that don't have
832
+ their past key value states given to this model) of shape `(batch_size, 1)` instead of all
833
+ `decoder_input_ids` of shape `(batch_size, sequence_length)`.
834
+ use_cache (`bool`, *optional*, defaults to `True`):
835
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
836
+ `past`). Set to `False` during training, `True` during generation
837
+ """
838
+
839
+ outputs = self.transformer(
840
+ input_ids=input_ids,
841
+ remaining_frames_ids=remaining_frames_ids,
842
+ past=past,
843
+ attention_mask=attention_mask,
844
+ token_type_ids=token_type_ids,
845
+ position_ids=position_ids,
846
+ head_mask=head_mask,
847
+ inputs_embeds=inputs_embeds,
848
+ encoder_hidden_states=encoder_hidden_states,
849
+ encoder_attention_mask=encoder_attention_mask,
850
+ use_cache=use_cache,
851
+ output_attentions=output_attentions,
852
+ output_hidden_states=output_hidden_states,
853
+ return_dict=return_dict,
854
+ training=training,
855
+ )
856
+
857
+ return outputs
858
+
859
+ def serving_output(self, output):
860
+ pkv = (
861
+ tf.convert_to_tensor(output.past_key_values)
862
+ if self.config.use_cache
863
+ else None
864
+ )
865
+ hs = (
866
+ tf.convert_to_tensor(output.hidden_states)
867
+ if self.config.output_hidden_states
868
+ else None
869
+ )
870
+ attns = (
871
+ tf.convert_to_tensor(output.attentions)
872
+ if self.config.output_attentions
873
+ else None
874
+ )
875
+ cross_attns = (
876
+ tf.convert_to_tensor(output.cross_attentions)
877
+ if self.config.output_attentions
878
+ and self.config.add_cross_attention
879
+ and output.cross_attentions is not None
880
+ else None
881
+ )
882
+
883
+ return TFBaseModelOutputWithPastAndCrossAttentions(
884
+ last_hidden_state=output.last_hidden_state,
885
+ past_key_values=pkv,
886
+ hidden_states=hs,
887
+ attentions=attns,
888
+ cross_attentions=cross_attns,
889
+ )
890
+
891
+
892
+ @add_start_docstrings(
+     """
+     The GPT2 Model transformer with a language modeling head on top (linear layer with weights tied to the input
+     embeddings).
+     """,
+     GPT2_START_DOCSTRING,
+ )
+ class TFGPT2LMHeadModel(TFGPT2PreTrainedModel, TFCausalLanguageModelingLoss):
+     def __init__(self, config, *inputs, **kwargs):
+         super().__init__(config, *inputs, **kwargs)
+         self.transformer = TFGPT2MainLayer(config, name="transformer")
+
+     def get_output_embeddings(self):
+         return self.get_input_embeddings()
+
+     def set_output_embeddings(self, value):
+         self.set_input_embeddings(value)
+
+     def prepare_inputs_for_generation(
+         self, inputs, past=None, use_cache=None, use_xla=False, **kwargs
+     ):
+         # TODO: (Joao) after the TF generator is complete, update GPT2 TF generation to match PT's. NB -- some GPT2
+         # tests will need to be fixed after the change
+
+         # only last token for input_ids if past is defined in kwargs
+         if past:
+             inputs = tf.expand_dims(inputs[:, -1], -1)
+
+         # TODO(pvp, Joao) - this `if use_xla` statement can be removed, but is left
+         # for a future PR to not change too many things for now.
+         # All statements in this if case apply for both xla and non-xla (as they already do in PyTorch)
+         position_ids = None
+         attention_mask = None
+         if use_xla:
+             attention_mask = kwargs.get("attention_mask", None)
+             if past is not None and attention_mask is not None:
+                 position_ids = tf.reduce_sum(attention_mask, axis=1, keepdims=True) - 1
+             elif attention_mask is not None:
+                 position_ids = tf.math.cumsum(attention_mask, axis=1, exclusive=True)
+
+         return {
+             "input_ids": inputs,
+             "attention_mask": attention_mask,
+             "position_ids": position_ids,
+             "past": past,
+             "use_cache": use_cache,
+         }
+
+     def _update_model_kwargs_for_xla_generation(
+         self, outputs, model_kwargs, current_pos, max_length
+     ):
+         # TODO(Pvp, Joao, Matt) - this function can be cleaned a bit and refactored
+         # quite some duplicated code patterns it seems
+         # also the `attention_mask` is currently used in a somewhat hacky way to
+         # correctly influence the `past_key_values` - not sure if this is the way to go
+         # Let's keep that for a future PR.
+         past = outputs.past_key_values
+         is_past_initialized = model_kwargs.pop("past", None) is not None
+         attention_mask = model_kwargs.pop("attention_mask")
+         batch_size = attention_mask.shape[0]
+
+         if not is_past_initialized:
+             # past[0].shape[3] is seq_length of prompt
+             num_padding_values = max_length - past[0].shape[3] - 1
+
+             padding_values = np.zeros((5, 2), dtype=np.int32)
+             padding_values[3, 1] = num_padding_values
+             padding_values = tf.constant(padding_values)
+
+             new_past = list(past)
+             for i in range(len(past)):
+                 new_past[i] = tf.pad(past[i], padding_values)
+
+             # Zeros for the currently-unfilled locations in the past tensor, ones for the actual input_ids
+             attention_mask = tf.concat(
+                 [
+                     attention_mask,
+                     tf.zeros(
+                         (batch_size, num_padding_values), dtype=attention_mask.dtype
+                     ),
+                     tf.ones((batch_size, 1), dtype=attention_mask.dtype),
+                 ],
+                 axis=1,
+             )
+         else:
+             new_past = [None for _ in range(len(past))]
+             slice_start_base = tf.constant([0, 0, 0, 1, 0])
+             attention_mask_update_slice = tf.ones(
+                 (batch_size, 1), dtype=attention_mask.dtype
+             )
+             # correct 5 here
+             new_past_index = current_pos - 1
+
+             for i in range(len(past)):
+                 update_slice = past[i][:, :, :, -1:]
+                 # Write the last slice to the first open location in the padded past array
+                 # and then truncate the last slice off the array
+                 new_past[i] = dynamic_update_slice(
+                     past[i][:, :, :, :-1],
+                     update_slice,
+                     slice_start_base * new_past_index,
+                 )
+
+             update_start = tf.constant([0, 1], dtype=tf.int32) * new_past_index
+             attention_mask = dynamic_update_slice(
+                 attention_mask, attention_mask_update_slice, update_start
+             )
+
+         # set `attention_mask` and `past`
+         model_kwargs["attention_mask"] = attention_mask
+         model_kwargs["past"] = tuple(new_past)
+
+         return model_kwargs
+
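The `_update_model_kwargs_for_xla_generation` method above avoids XLA re-tracing by giving the cache a fixed shape up front (padding the sequence axis of each `past` tensor out to `max_length`) and then writing each new key/value slice in place with `dynamic_update_slice` instead of concatenating. The standalone sketch below reproduces just those two operations on a dummy single-layer cache; it is not part of the uploaded file, and the toy shapes and the `dynamic_update_slice` import path are assumptions taken from the upstream transformers code rather than anything defined in this repository.

```python
# Illustrative sketch only -- not part of gpt2_embedding.py.
import numpy as np
import tensorflow as tf
from tensorflow.compiler.tf2xla.python.xla import dynamic_update_slice  # assumed import, as in upstream transformers

batch, heads, head_dim, max_length, prompt_len = 1, 2, 4, 8, 3
# One layer's cache: (2, batch, heads, seq, head_dim) -- keys and values stacked on axis 0.
past_layer = tf.random.normal((2, batch, heads, prompt_len, head_dim))

# Step 1 (first call): pad the sequence axis (axis 3) so the cache already has its
# final, static length; XLA needs shapes that do not change between decoding steps.
padding_values = np.zeros((5, 2), dtype=np.int32)
padding_values[3, 1] = max_length - prompt_len - 1
padded_past = tf.pad(past_layer, tf.constant(padding_values))  # seq axis is now max_length - 1

# Step 2 (every later call): overwrite the next free position with the newest
# key/value slice instead of concatenating, keeping the shape constant.
new_slice = tf.random.normal((2, batch, heads, 1, head_dim))
current_pos = prompt_len + 1
start_indices = tf.constant([0, 0, 0, 1, 0]) * (current_pos - 1)
padded_past = dynamic_update_slice(padded_past, new_slice, start_indices)
print(padded_past.shape)  # (2, 1, 2, 7, 4) -- unchanged across steps
```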
+     @unpack_inputs
+     @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING)
+     @add_code_sample_docstrings(
+         processor_class=_TOKENIZER_FOR_DOC,
+         checkpoint=_CHECKPOINT_FOR_DOC,
+         output_type=TFCausalLMOutputWithCrossAttentions,
+         config_class=_CONFIG_FOR_DOC,
+     )
+     def call(
+         self,
+         input_ids: Optional[TFModelInputType] = None,
+         remaining_frames_ids: Optional[Union[np.ndarray, tf.Tensor]] = None,
+         past: Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]] = None,
+         attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
+         token_type_ids: Optional[Union[np.ndarray, tf.Tensor]] = None,
+         position_ids: Optional[Union[np.ndarray, tf.Tensor]] = None,
+         head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
+         inputs_embeds: Optional[Union[np.ndarray, tf.Tensor]] = None,
+         encoder_hidden_states: Optional[Union[np.ndarray, tf.Tensor]] = None,
+         encoder_attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
+         use_cache: Optional[bool] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+         labels: Optional[Union[np.ndarray, tf.Tensor]] = None,
+         training: Optional[bool] = False,
+     ) -> Union[TFCausalLMOutputWithCrossAttentions, Tuple[tf.Tensor]]:
+         r"""
+         encoder_hidden_states (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+             Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
+             the model is configured as a decoder.
+         encoder_attention_mask (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+             Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
+             the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`:
+
+             - 1 for tokens that are **not masked**,
+             - 0 for tokens that are **masked**.
+
+         past (`Tuple[Tuple[tf.Tensor]]` of length `config.n_layers`):
+             contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
+             If `past` is used, the user can optionally input only the last `decoder_input_ids` (those that don't have
+             their past key value states given to this model) of shape `(batch_size, 1)` instead of all
+             `decoder_input_ids` of shape `(batch_size, sequence_length)`.
+         use_cache (`bool`, *optional*, defaults to `True`):
+             If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
+             `past`). Set to `False` during training and `True` during generation.
+         labels (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+             Labels for computing the cross entropy classification loss. Indices should be in `[0, ...,
+             config.vocab_size - 1]`.
+         """
+
+         transformer_outputs = self.transformer(
+             input_ids=input_ids,
+             remaining_frames_ids=remaining_frames_ids,
+             past=past,
+             attention_mask=attention_mask,
+             token_type_ids=token_type_ids,
+             position_ids=position_ids,
+             head_mask=head_mask,
+             inputs_embeds=inputs_embeds,
+             encoder_hidden_states=encoder_hidden_states,
+             encoder_attention_mask=encoder_attention_mask,
+             use_cache=use_cache,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict,
+             training=training,
+         )
+         hidden_states = transformer_outputs[0]
+         logits = self.transformer.wte(hidden_states, mode="linear")
+
+         loss = None
+         if labels is not None:
+             # shift labels to the left and cut last logit token
+             shifted_logits = logits[:, :-1]
+             labels = labels[:, 1:]
+             loss = self.hf_compute_loss(labels, shifted_logits)
+
+         if not return_dict:
+             output = (logits,) + transformer_outputs[1:]
+             return ((loss,) + output) if loss is not None else output
+
+         return TFCausalLMOutputWithCrossAttentions(
+             loss=loss,
+             logits=logits,
+             past_key_values=transformer_outputs.past_key_values,
+             hidden_states=transformer_outputs.hidden_states,
+             attentions=transformer_outputs.attentions,
+             cross_attentions=transformer_outputs.cross_attentions,
+         )
+
+     def serving_output(self, output):
+         pkv = (
+             tf.convert_to_tensor(output.past_key_values)
+             if self.config.use_cache
+             else None
+         )
+         hs = (
+             tf.convert_to_tensor(output.hidden_states)
+             if self.config.output_hidden_states
+             else None
+         )
+         attns = (
+             tf.convert_to_tensor(output.attentions)
+             if self.config.output_attentions
+             else None
+         )
+         cross_attns = (
+             tf.convert_to_tensor(output.cross_attentions)
+             if self.config.output_attentions
+             and self.config.add_cross_attention
+             and output.cross_attentions is not None
+             else None
+         )
+
+         return TFCausalLMOutputWithCrossAttentions(
+             logits=output.logits,
+             past_key_values=pkv,
+             hidden_states=hs,
+             attentions=attns,
+             cross_attentions=cross_attns,
+         )
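To close the loop on the `past`/`use_cache` and `labels` semantics documented in `call` above, here is a minimal usage sketch. It is not part of the uploaded file: the import path, the small `GPT2Config` values, and the assumption that the custom `remaining_frames_ids` input can be omitted are all unverified guesses and may need adjusting to this repository's actual training setup.

```python
# Illustrative sketch only -- not part of gpt2_embedding.py. Import path and config
# values are assumptions; `remaining_frames_ids` is omitted on the assumption that
# the custom TFGPT2MainLayer treats it as optional.
import tensorflow as tf
from transformers import GPT2Config

from ganime.model.vqgan_clean.experimental.gpt2_embedding import TFGPT2LMHeadModel

config = GPT2Config(vocab_size=512, n_positions=64, n_embd=64, n_layer=2, n_head=2)
model = TFGPT2LMHeadModel(config)

tokens = tf.random.uniform((2, 16), maxval=config.vocab_size, dtype=tf.int32)

# Training-style call: `labels` are shifted left inside `call`, so the logits at
# position t are scored against the token at position t + 1.
out = model(input_ids=tokens, labels=tokens, training=True)
print(out.loss.shape, out.logits.shape)  # loss tensor and (2, 16, 512) logits

# Incremental decoding: keep the cache and feed only the newest token back in.
out = model(input_ids=tokens, use_cache=True)
past = out.past_key_values
next_token = tf.argmax(out.logits[:, -1, :], axis=-1, output_type=tf.int32)[:, None]
out = model(input_ids=next_token, past=past, use_cache=True)
```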