Upload 84 files
This view is limited to 50 files because it contains too many changes.
- .dockerignore +4 -0
- .gitattributes +1 -0
- .gitignore +208 -0
- Dockerfile +21 -0
- configs/colab.yaml +50 -0
- configs/kny_image.yaml +47 -0
- configs/kny_image_full_style.yaml +47 -0
- configs/kny_image_full_vgg19.yaml +47 -0
- configs/kny_transformer_light.yaml +60 -0
- configs/kny_video_gpt2_large.yaml +50 -0
- configs/kny_video_gpt2_large_gradio.yaml +50 -0
- configs/kny_video_gpt2_medium.yaml +50 -0
- configs/kny_video_gpt2_xl.yaml +50 -0
- ganime/__main__.py +4 -0
- ganime/app.py +212 -0
- ganime/configs/__init__.py +0 -0
- ganime/configs/model_configs.py +70 -0
- ganime/data/__init__.py +0 -0
- ganime/data/base.py +282 -0
- ganime/data/experimental.py +222 -0
- ganime/data/kny.py +19 -0
- ganime/data/mnist.py +103 -0
- ganime/metrics/image.py +70 -0
- ganime/metrics/video.py +98 -0
- ganime/model/__init__.py +0 -0
- ganime/model/base.py +45 -0
- ganime/model/moving_vae.py +126 -0
- ganime/model/p2p/__init__.py +0 -0
- ganime/model/p2p/p2p.py +543 -0
- ganime/model/p2p/p2p_test.py +713 -0
- ganime/model/p2p/p2p_v2.py +498 -0
- ganime/model/p2p/p2p_v3.py +237 -0
- ganime/model/vae/vae.py +98 -0
- ganime/model/vq_vae/vq_vae.py +143 -0
- ganime/model/vqgan/__init__.py +0 -0
- ganime/model/vqgan/discriminator/__init__.py +0 -0
- ganime/model/vqgan/discriminator/model.py +64 -0
- ganime/model/vqgan/losses/__init__.py +0 -0
- ganime/model/vqgan/losses/lpips.py +134 -0
- ganime/model/vqgan/losses/vqperceptual.py +47 -0
- ganime/model/vqgan/vqgan.py +722 -0
- ganime/model/vqgan_clean/__init__.py +0 -0
- ganime/model/vqgan_clean/diffusion/__init__.py +0 -0
- ganime/model/vqgan_clean/diffusion/decoder.py +115 -0
- ganime/model/vqgan_clean/diffusion/encoder.py +125 -0
- ganime/model/vqgan_clean/diffusion/layers.py +179 -0
- ganime/model/vqgan_clean/discriminator/__init__.py +0 -0
- ganime/model/vqgan_clean/discriminator/model.py +88 -0
- ganime/model/vqgan_clean/discriminator/model_bkp.py +76 -0
- ganime/model/vqgan_clean/experimental/gpt2_embedding.py +1127 -0
.dockerignore
ADDED
@@ -0,0 +1,4 @@
.git
data
checkpoints
logs
.gitattributes
CHANGED
@@ -32,3 +32,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+models/vgg19/imagenet-vgg-verydeep-19.mat filter=lfs diff=lfs merge=lfs -text
.gitignore
ADDED
@@ -0,0 +1,208 @@

# Created by https://www.toptal.com/developers/gitignore/api/visualstudiocode,python,venv
# Edit at https://www.toptal.com/developers/gitignore?templates=visualstudiocode,python,venv

### Python ###
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/

### venv ###
# Virtualenv
# http://iamzed.com/2009/05/07/a-primer-on-virtualenv/
[Bb]in
[Ii]nclude
[Ll]ib
[Ll]ib64
[Ll]ocal
#[Ss]cripts
pyvenv.cfg
pip-selfcheck.json

### VisualStudioCode ###
.vscode/*
# !.vscode/settings.json
# !.vscode/tasks.json
# !.vscode/launch.json
# !.vscode/extensions.json
# !.vscode/*.code-snippets

# Local History for Visual Studio Code
.history/

# Built Visual Studio Code Extensions
*.vsix

### VisualStudioCode Patch ###
# Ignore all local history of files
.history
.ionide

# Support for Project snippet scope

# End of https://www.toptal.com/developers/gitignore/api/visualstudiocode,python,venv

*.npy
checkpoints/*
ganime_results/*
data/*
*.avi
*.out
notebooks/model/p2p_v2/*
logs/*
interesting_logs/*
notebooks/model/vq-gan/train_output/*
notebooks/model/vq-gan/validation_output/*
notebooks/model/vq-gan/test_output/*
*.zip
flagged/*
notebooks/model/vq-gan/gpt_kny_light_large_256/*
Dockerfile
ADDED
@@ -0,0 +1,21 @@
FROM tensorflow/tensorflow:2.7.0-gpu-jupyter
# Because of https://developer.nvidia.com/blog/updating-the-cuda-linux-gpg-repository-key/ and https://github.com/NVIDIA/nvidia-docker/issues/1631#issuecomment-1112828208
RUN rm /etc/apt/sources.list.d/cuda.list
RUN rm /etc/apt/sources.list.d/nvidia-ml.list
RUN apt-key del 7fa2af80
RUN apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2004/x86_64/3bf863cc.pub
RUN apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/machine-learning/repos/ubuntu2004/x86_64/7fa2af80.pub

# Update and install ffmpeg
RUN apt-get -y update
RUN apt-get -y upgrade
RUN apt-get install -y ffmpeg

# Setup environment
WORKDIR /GANime
ENV PROJECT_DIR=/GANime
COPY requirements.txt /GANime/requirements.txt
RUN pip install -r requirements.txt
COPY . .
RUN pip install -e .
EXPOSE 8888
configs/colab.yaml
ADDED
@@ -0,0 +1,50 @@
model:
  transformer_config:
    #checkpoint_path: GANime/checkpoints/kny_video_full_gpt2_medium/checkpoint
    remaining_frames_method: "own_embeddings"
    transformer_type: "gpt2-medium"
  first_stage_config:
    checkpoint_path: GANime/checkpoints/kny_image_full_vgg19/checkpoint
    vqvae_config:
      beta: 0.25
      num_embeddings: 50257
      embedding_dim: 128
    autoencoder_config:
      z_channels: 512
      channels: 32
      channels_multiplier:
        - 2
        - 4
        - 8
        - 8
      num_res_blocks: 1
      attention_resolution:
        - 16
      resolution: 128
      dropout: 0.0
    discriminator_config:
      num_layers: 3
      filters: 64

    loss_config:
      discriminator:
        loss: "hinge"
        factor: 1.0
        iter_start: 16200
        weight: 0.3
      vqvae:
        codebook_weight: 1.0
        perceptual_weight: 4.0
      perceptual_loss: "vgg19"

train:
  batch_size: 64
  accumulation_size: 1
  n_epochs: 2000
  len_x_train: 8000
  warmup_epoch_percentage: 0.15
  lr_start: 1e-5
  lr_max: 2.5e-4
  perceptual_loss_weight: 1.0
  n_frames_before: 1
  stop_ground_truth_after_epoch: 50
configs/kny_image.yaml
ADDED
@@ -0,0 +1,47 @@
model:
  checkpoint_path: ../../../checkpoints/kny_image_full_no_disc/checkpoint
  vqvae_config:
    beta: 0.25
    num_embeddings: 50257
    embedding_dim: 128
  autoencoder_config:
    z_channels: 512
    channels: 32
    channels_multiplier:
      - 2
      - 4
      - 8
      - 8
    num_res_blocks: 1
    attention_resolution:
      - 16
    resolution: 128
    dropout: 0.0
  discriminator_config:
    num_layers: 3
    filters: 64

  loss_config:
    discriminator:
      loss: "hinge"
      factor: 1.0
      iter_start: 5000
      weight: 0.8
    vqvae:
      codebook_weight: 1.0
      perceptual_weight: 4.0
    perceptual_loss: "vgg19" # "vgg16", "vgg19", "style"

trainer:
  batch_size: 32
  n_epochs: 10000
  gen_lr: 3e-5
  disc_lr: 3e-5
  gen_beta_1: 0.5
  gen_beta_2: 0.9
  disc_beta_1: 0.5
  disc_beta_2: 0.9
  gen_clip_norm: 1.0
  disc_clip_norm: 1.0
configs/kny_image_full_style.yaml
ADDED
@@ -0,0 +1,47 @@
model:
  checkpoint_path: ../../../checkpoints/kny_image_full_style/checkpoint
  vqvae_config:
    beta: 0.25
    num_embeddings: 50257
    embedding_dim: 128
  autoencoder_config:
    z_channels: 512
    channels: 32
    channels_multiplier:
      - 2
      - 4
      - 8
      - 8
    num_res_blocks: 1
    attention_resolution:
      - 16
    resolution: 128
    dropout: 0.0
  discriminator_config:
    num_layers: 3
    filters: 64

  loss_config:
    discriminator:
      loss: "hinge"
      factor: 1.0
      iter_start: 50000000
      weight: 0.8
    vqvae:
      codebook_weight: 1.0
      perceptual_weight: 4.0
    perceptual_loss: "style" # "vgg16", "vgg19", "style"

trainer:
  batch_size: 32
  n_epochs: 10000
  gen_lr: 8e-5
  disc_lr: 8e-5
  gen_beta_1: 0.5
  gen_beta_2: 0.9
  disc_beta_1: 0.5
  disc_beta_2: 0.9
  gen_clip_norm: 1.0
  disc_clip_norm: 1.0
configs/kny_image_full_vgg19.yaml
ADDED
@@ -0,0 +1,47 @@
model:
  checkpoint_path: ../../../checkpoints/kny_image_full_vgg19/checkpoint
  vqvae_config:
    beta: 0.25
    num_embeddings: 50257
    embedding_dim: 128
  autoencoder_config:
    z_channels: 512
    channels: 32
    channels_multiplier:
      - 2
      - 4
      - 8
      - 8
    num_res_blocks: 1
    attention_resolution:
      - 16
    resolution: 128
    dropout: 0.0
  discriminator_config:
    num_layers: 3
    filters: 64

  loss_config:
    discriminator:
      loss: "hinge"
      factor: 1.0
      iter_start: 50000000
      weight: 0.8
    vqvae:
      codebook_weight: 1.0
      perceptual_weight: 4.0
    perceptual_loss: "vgg19" # "vgg16", "vgg19", "style"

trainer:
  batch_size: 64
  n_epochs: 10000
  gen_lr: 3e-5
  disc_lr: 5e-5
  gen_beta_1: 0.5
  gen_beta_2: 0.9
  disc_beta_1: 0.5
  disc_beta_2: 0.9
  gen_clip_norm: 1.0
  disc_clip_norm: 1.0
configs/kny_transformer_light.yaml
ADDED
@@ -0,0 +1,60 @@
model:
  transformer_config:
    checkpoint_path: ../../../checkpoints/kny_video_light/checkpoint
    # vocab_size: 50257
    # n_positions: 1024
    # n_embd: 1024 #1280 #768
    # n_layer: 24 #36 #12
    # n_head: 16 #20 #12
    # resid_pdrop: 0.1
    # embd_pdrop: 0.1
    # attn_pdrop: 0.1
    # remaining_frames_method: "concat"
    # remaining_frames_method: "token_type_ids"
    remaining_frames_method: "own_embeddings"
  first_stage_config:
    checkpoint_path: ../../../checkpoints/kny_image_light_discriminator/checkpoint
    vqvae_config:
      beta: 0.25
      num_embeddings: 64
      embedding_dim: 256
    autoencoder_config:
      z_channels: 128
      channels: 64
      channels_multiplier:
        - 1
        - 1
        - 2
        - 2
        - 4
      num_res_blocks: 1
      attention_resolution:
        - 16
      resolution: 128
      dropout: 0.0
    discriminator_config:
      num_layers: 3
      filters: 64

    loss_config:
      discriminator:
        loss: "hinge"
        factor: 1.0
        iter_start: 16200
        weight: 0.3
      vqvae:
        codebook_weight: 1.0
        perceptual_weight: 4.0
      perceptual_loss: "style"

train:
  batch_size: 8
  accumulation_size: 8
  n_epochs: 2000
  len_x_train: 631
  warmup_epoch_percentage: 0.15
  lr_start: 1e-5
  lr_max: 2.5e-4
  perceptual_loss_weight: 1.0
  n_frames_before: 5
  stop_ground_truth_after_epoch: 100
configs/kny_video_gpt2_large.yaml
ADDED
@@ -0,0 +1,50 @@
model:
  transformer_config:
    checkpoint_path: ../../../checkpoints/kny_video_full_gpt2_large_final/checkpoint
    remaining_frames_method: "own_embeddings"
    transformer_type: "gpt2-large"
  first_stage_config:
    checkpoint_path: ../../../checkpoints/kny_image_full_vgg19/checkpoint
    vqvae_config:
      beta: 0.25
      num_embeddings: 50257
      embedding_dim: 128
    autoencoder_config:
      z_channels: 512
      channels: 32
      channels_multiplier:
        - 2
        - 4
        - 8
        - 8
      num_res_blocks: 1
      attention_resolution:
        - 16
      resolution: 128
      dropout: 0.0
    discriminator_config:
      num_layers: 3
      filters: 64

    loss_config:
      discriminator:
        loss: "hinge"
        factor: 1.0
        iter_start: 16200
        weight: 0.3
      vqvae:
        codebook_weight: 1.0
        perceptual_weight: 4.0
      perceptual_loss: "vgg19"

train:
  batch_size: 64
  accumulation_size: 1
  n_epochs: 10000
  len_x_train: 28213
  warmup_epoch_percentage: 0.15
  lr_start: 1e-5
  lr_max: 2.5e-4
  perceptual_loss_weight: 1.0
  n_frames_before: 1
  stop_ground_truth_after_epoch: 1000
configs/kny_video_gpt2_large_gradio.yaml
ADDED
@@ -0,0 +1,50 @@
model:
  transformer_config:
    checkpoint_path: ./checkpoints/kny_video_full_gpt2_large_final/checkpoint
    remaining_frames_method: "own_embeddings"
    transformer_type: "gpt2-large"
  first_stage_config:
    checkpoint_path: ./checkpoints/kny_image_full_vgg19/checkpoint
    vqvae_config:
      beta: 0.25
      num_embeddings: 50257
      embedding_dim: 128
    autoencoder_config:
      z_channels: 512
      channels: 32
      channels_multiplier:
        - 2
        - 4
        - 8
        - 8
      num_res_blocks: 1
      attention_resolution:
        - 16
      resolution: 128
      dropout: 0.0
    discriminator_config:
      num_layers: 3
      filters: 64

    loss_config:
      discriminator:
        loss: "hinge"
        factor: 1.0
        iter_start: 16200
        weight: 0.3
      vqvae:
        codebook_weight: 1.0
        perceptual_weight: 4.0
      perceptual_loss: "vgg19"

train:
  batch_size: 64
  accumulation_size: 1
  n_epochs: 10000
  len_x_train: 28213
  warmup_epoch_percentage: 0.15
  lr_start: 1e-5
  lr_max: 2.5e-4
  perceptual_loss_weight: 1.0
  n_frames_before: 1
  stop_ground_truth_after_epoch: 1000
configs/kny_video_gpt2_medium.yaml
ADDED
@@ -0,0 +1,50 @@
model:
  transformer_config:
    checkpoint_path: ./checkpoints/kny_video_full_gpt2_medium/checkpoint
    remaining_frames_method: "own_embeddings"
    transformer_type: "gpt2-medium"
  first_stage_config:
    checkpoint_path: ./checkpoints/kny_image_full_vgg19/checkpoint
    vqvae_config:
      beta: 0.25
      num_embeddings: 50257
      embedding_dim: 128
    autoencoder_config:
      z_channels: 512
      channels: 32
      channels_multiplier:
        - 2
        - 4
        - 8
        - 8
      num_res_blocks: 1
      attention_resolution:
        - 16
      resolution: 128
      dropout: 0.0
    discriminator_config:
      num_layers: 3
      filters: 64

    loss_config:
      discriminator:
        loss: "hinge"
        factor: 1.0
        iter_start: 16200
        weight: 0.3
      vqvae:
        codebook_weight: 1.0
        perceptual_weight: 4.0
      perceptual_loss: "vgg19"

train:
  batch_size: 64
  accumulation_size: 1
  n_epochs: 500
  len_x_train: 28213
  warmup_epoch_percentage: 0.15
  lr_start: 5e-6
  lr_max: 1e-4
  perceptual_loss_weight: 1.0
  n_frames_before: 5
  stop_ground_truth_after_epoch: 200
configs/kny_video_gpt2_xl.yaml
ADDED
@@ -0,0 +1,50 @@
model:
  transformer_config:
    # checkpoint_path: ../../../checkpoints/kny_video_full_gpt2_xl/checkpoint
    remaining_frames_method: "own_embeddings"
    transformer_type: "gpt2-xl"
  first_stage_config:
    checkpoint_path: ../../../checkpoints/kny_image_full_vgg19/checkpoint
    vqvae_config:
      beta: 0.25
      num_embeddings: 50257
      embedding_dim: 128
    autoencoder_config:
      z_channels: 512
      channels: 32
      channels_multiplier:
        - 2
        - 4
        - 8
        - 8
      num_res_blocks: 1
      attention_resolution:
        - 16
      resolution: 128
      dropout: 0.0
    discriminator_config:
      num_layers: 3
      filters: 64

    loss_config:
      discriminator:
        loss: "hinge"
        factor: 1.0
        iter_start: 16200
        weight: 0.3
      vqvae:
        codebook_weight: 1.0
        perceptual_weight: 4.0
      perceptual_loss: "vgg19"

train:
  batch_size: 64
  accumulation_size: 1
  n_epochs: 500
  len_x_train: 28213
  warmup_epoch_percentage: 0.15
  lr_start: 5e-6
  lr_max: 1e-4
  perceptual_loss_weight: 1.0
  n_frames_before: 1
  stop_ground_truth_after_epoch: 200
ganime/__main__.py
ADDED
@@ -0,0 +1,4 @@
from ganime import app

if __name__ == "__main__":
    app.run()
ganime/app.py
ADDED
@@ -0,0 +1,212 @@
import os

import click
import omegaconf
import ray
from pyprojroot.pyprojroot import here
from ray import tune
from ray.train import Trainer
from ray.tune.schedulers import AsyncHyperBandScheduler
from ray.tune.suggest import ConcurrencyLimiter
from ray.tune.suggest.optuna import OptunaSearch

from ganime.trainer.ganime import TrainableGANime

import os

os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"  # see issue #152
os.environ["CUDA_VISIBLE_DEVICES"] = "1, 2, 3, 4, 5, 6"


def get_metric_direction(metric: str):
    if "loss" in metric:
        return "min"
    else:
        raise ValueError(f"Unknown metric: {metric}")


def trial_name_id(trial):
    return f"{trial.trainable_name}"


def trial_dirname_creator(trial):
    return f"{trial.trial_id}"


def get_search_space(model):
    if model == "vqgan":
        return {
            # "beta": tune.uniform(0.1, 1.0),
            "num_embeddings": tune.choice([64, 128, 256]),
            "embedding_dim": tune.choice([128, 256, 512, 1024]),
            "z_channels": tune.choice([64, 128, 256]),
            "channels": tune.choice([64, 128, 256]),
            "channels_multiplier": tune.choice(
                [
                    [1, 2, 4],
                    [1, 1, 2, 2],
                    [1, 2, 2, 4],
                    [1, 1, 2, 2, 4],
                ]
            ),
            "attention_resolution": tune.choice([[16], [32], [16, 32]]),
            "batch_size": tune.choice([8, 16]),
            "dropout": tune.choice([0.0, 0.1, 0.2]),
            "weight": tune.quniform(0.1, 1.0, 0.1),
            "codebook_weight": tune.quniform(0.2, 2.0, 0.2),
            "perceptual_weight": tune.quniform(0.5, 5.0, 0.5),
            "gen_lr": tune.qloguniform(1e-5, 1e-3, 1e-5),
            "disc_lr": tune.qloguniform(1e-5, 1e-3, 1e-5),
            "gen_beta_1": tune.quniform(0.5, 0.9, 0.1),
            "gen_beta_2": tune.quniform(0.9, 0.999, 0.001),
            "disc_beta_1": tune.quniform(0.5, 0.9, 0.1),
            "disc_beta_2": tune.quniform(0.9, 0.999, 0.001),
            "gen_clip_norm": tune.choice([1.0, None]),
            "disc_clip_norm": tune.choice([1.0, None]),
        }
    elif model == "gpt":
        return {
            "remaining_frames_method": tune.choice(
                ["concat", "token_type_ids", "own_embeddings"]
            ),
            # "batch_size": tune.choice([8, 16]),
            "lr_max": tune.qloguniform(1e-5, 1e-3, 5e-5),
            "lr_start": tune.sample_from(lambda spec: spec.config.lr_max / 10),
            "perceptual_loss_weight": tune.quniform(0.0, 1.0, 0.1),
            "n_frames_before": tune.randint(1, 10),
        }


def tune_ganime(
    experiment_name: str,
    dataset_name: str,
    config_file: str,
    model: str,
    metric: str,
    epochs: int,
    num_samples: int,
    num_cpus: int,
    num_gpus: int,
    max_concurrent_trials: int,
):

    dataset_path = here("data")
    analysis = tune.run(
        TrainableGANime,
        name=experiment_name,
        search_alg=ConcurrencyLimiter(
            OptunaSearch(), max_concurrent=max_concurrent_trials
        ),
        scheduler=AsyncHyperBandScheduler(max_t=epochs, grace_period=5),
        metric=metric,
        mode=get_metric_direction(metric),
        num_samples=num_samples,
        stop={"training_iteration": epochs},
        local_dir="./ganime_results",
        config={
            "dataset_name": dataset_name,
            "dataset_path": dataset_path,
            "model": model,
            "config_file": config_file,
            "hyperparameters": get_search_space(model),
        },
        resources_per_trial={
            "cpu": num_cpus // max_concurrent_trials,
            "gpu": num_gpus / max_concurrent_trials,
        },
        trial_name_creator=trial_name_id,
        trial_dirname_creator=trial_dirname_creator,
    )
    best_loss = analysis.get_best_config(metric="total_loss", mode="min")
    # best_accuracy = analysis.get_best_config(metric="accuracy", mode="max")
    print(f"Best loss config: {best_loss}")
    # print(f"Best accuracy config: {best_accuracy}")
    return analysis


@click.command()
@click.option(
    "--dataset",
    type=click.Choice(
        ["moving_mnist_images", "kny_images", "kny_images_light"], case_sensitive=False
    ),
    default="kny_images_light",
    help="Dataset to use",
)
@click.option(
    "--model",
    type=click.Choice(["vqgan", "gpt"], case_sensitive=False),
    default="vqgan",
    help="Model to use",
)
@click.option(
    "--epochs",
    default=500,
    help="Number of epochs to run",
)
@click.option(
    "--num_samples",
    default=100,
    help="Total number of trials to run",
)
@click.option(
    "--num_cpus",
    default=64,
    help="Number of cpus to use",
)
@click.option(
    "--num_gpus",
    default=6,
    help="Number of gpus to use",
)
@click.option(
    "--max_concurrent_trials",
    default=6,
    help="Maximum number of concurrent trials",
)
@click.option(
    "--metric",
    type=click.Choice(
        ["total_loss", "reconstruction_loss", "vq_loss", "disc_loss"],
        case_sensitive=False,
    ),
    default="total_loss",
    help="The metric used to select the best trial",
)
@click.option(
    "--experiment_name",
    default="kny_images_light_v2",
    help="The name of the experiment for logging in Tensorboard",
)
@click.option(
    "--config_file",
    default="kny_image.yaml",
    help="The name of the config file located inside ./config",
)
def run(
    experiment_name: str,
    config_file: str,
    dataset: str,
    model: str,
    epochs: int,
    num_samples: int,
    num_cpus: int,
    num_gpus: int,
    max_concurrent_trials: int,
    metric: str,
):
    config_file = here(os.path.join("configs", config_file))

    ray.init(num_cpus=num_cpus, num_gpus=num_gpus)
    tune_ganime(
        experiment_name=experiment_name,
        dataset_name=dataset,
        config_file=config_file,
        model=model,
        epochs=epochs,
        num_samples=num_samples,
        num_cpus=num_cpus,
        num_gpus=num_gpus,
        max_concurrent_trials=max_concurrent_trials,
        metric=metric,
    )
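
Editor's note, not part of the commit: ganime/__main__.py above hands off to this click command, so the hyperparameter search is driven entirely by the options defined here. The sketch below only exercises the option parsing with click's test runner; the option values are examples, and actually running it assumes Ray, Optuna, and the installed ganime package (including ganime.trainer.ganime.TrainableGANime) plus the datasets are available.

from click.testing import CliRunner

from ganime.app import run

# Invoke the CLI in-process with example options; this starts a real Ray Tune run.
runner = CliRunner()
result = runner.invoke(run, ["--model", "gpt", "--dataset", "kny_images_light", "--epochs", "10"])
print(result.exit_code, result.output)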
ganime/configs/__init__.py
ADDED
File without changes
ganime/configs/model_configs.py
ADDED
@@ -0,0 +1,70 @@
from dataclasses import dataclass
from typing import List
try:
    from typing import Literal
except ImportError:
    from typing_extensions import Literal


@dataclass
class GPTConfig:
    n_layer: int
    n_head: int
    n_embedding: int
    vocab_size: int
    block_size: int
    embedding_percentage_drop: float
    attention_percentage_drop: float


@dataclass
class VQVAEConfig:
    beta: float
    num_embeddings: int
    embedding_dim: int


@dataclass
class AutoencoderConfig:
    z_channels: int
    channels: int
    channels_multiplier: List[int]
    num_res_blocks: int
    attention_resolution: List[int]
    resolution: int
    dropout: float


@dataclass
class DiscriminatorConfig:
    num_layers: int
    filters: int


@dataclass
class DiscriminatorLossConfig:
    loss: Literal["hinge, vanilla"]
    factor: float
    iter_start: int
    weight: float


@dataclass
class VQVAELossConfig:
    codebook_weight: float
    perceptual_weight: float


@dataclass
class LossConfig:
    discriminator: DiscriminatorLossConfig
    vqvae: VQVAELossConfig
    perceptual_loss: str


@dataclass
class ModelConfig:
    vqvae_config: VQVAEConfig
    autoencoder_config: AutoencoderConfig
    discriminator_config: DiscriminatorConfig
    loss_config: LossConfig
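
Editor's note, not part of the commit: these dataclasses mirror the nesting used in the configs/*.yaml files added above (vqvae_config, autoencoder_config, discriminator_config, loss_config under the model key). As a minimal sketch only, assuming the YAML is read with OmegaConf (which ganime/app.py imports; the actual loading code is not shown in this commit), such a file can be accessed attribute-style:

from omegaconf import OmegaConf

cfg = OmegaConf.load("configs/kny_image.yaml")       # example path from this commit
model_cfg = cfg.model
print(model_cfg.vqvae_config.num_embeddings)         # 50257
print(model_cfg.loss_config.discriminator.loss)      # "hinge"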
ganime/data/__init__.py
ADDED
File without changes
ganime/data/base.py
ADDED
@@ -0,0 +1,282 @@
from typing import Tuple
import numpy as np
import tensorflow as tf
import os
from tensorflow.keras.utils import Sequence
from abc import ABC, abstractmethod
from typing import Literal
import math
from ganime.data.experimental import ImageDataset


# class SequenceDataset(Sequence):
#     def __init__(
#         self,
#         dataset_path: str,
#         batch_size: int,
#         split: Literal["train", "validation", "test"] = "train",
#     ):
#         self.batch_size = batch_size
#         self.split = split
#         self.data = self.load_data(dataset_path, split)
#         self.data = self.preprocess_data(self.data)

#         self.indices = np.arange(self.data.shape[0])
#         self.on_epoch_end()

#     @abstractmethod
#     def load_data(self, dataset_path: str, split: str) -> np.ndarray:
#         pass

#     def preprocess_data(self, data: np.ndarray) -> np.ndarray:
#         return data

#     def __len__(self):
#         return math.ceil(len(self.data) / self.batch_size)

#     def __getitem__(self, idx):
#         inds = self.indices[idx * self.batch_size : (idx + 1) * self.batch_size]
#         batch_x = self.data[inds]
#         batch_y = batch_x

#         return batch_x, batch_y

#     def get_fixed_batch(self, idx):
#         self.fixed_indices = (
#             self.fixed_indices
#             if hasattr(self, "fixed_indices")
#             else self.indices[
#                 idx * self.batch_size : (idx + 1) * self.batch_size
#             ].copy()
#         )
#         batch_x = self.data[self.fixed_indices]
#         batch_y = batch_x

#         return batch_x, batch_y

#     def on_epoch_end(self):
#         np.random.shuffle(self.indices)


# def load_kny_images(
#     dataset_path: str, batch_size: int
# ) -> Tuple[tf.data.Dataset, tf.data.Dataset, tuple]:
#     import skvideo.io

#     if os.path.exists(os.path.join(dataset_path, "kny", "kny_images.npy")):
#         data = np.load(os.path.join(dataset_path, "kny", "kny_images.npy"))
#     else:
#         data = skvideo.io.vread(os.path.join(dataset_path, "kny", "01.mp4"))
#     np.random.shuffle(data)

#     def _preprocess(sample):
#         image = tf.cast(sample, tf.float32) / 255.0  # Scale to unit interval.
#         # video = video < tf.random.uniform(tf.shape(video))  # Randomly binarize.
#         image = tf.image.resize(image, [64, 64])

#         return image, image

#     train_dataset = (
#         tf.data.Dataset.from_tensor_slices(data[:5000])
#         .map(_preprocess)
#         .batch(batch_size)
#         .prefetch(tf.data.AUTOTUNE)
#         .shuffle(int(10e3))
#     )
#     test_dataset = (
#         tf.data.Dataset.from_tensor_slices(data[5000:6000])
#         .map(_preprocess)
#         .batch(batch_size)
#         .prefetch(tf.data.AUTOTUNE)
#         .shuffle(int(10e3))
#     )

#     return train_dataset, test_dataset, data.shape[1:]


# def load_moving_mnist_vae(
#     dataset_path: str, batch_size: int
# ) -> Tuple[tf.data.Dataset, tf.data.Dataset, tuple]:
#     data = np.load(os.path.join(dataset_path, "moving_mnist", "mnist_test_seq.npy"))
#     data.shape

#     # We can see that data is of shape (window, n_samples, width, height)
#     # But we want for keras something of shape (n_samples, window, width, height)
#     data = np.moveaxis(data, 0, 1)
#     # Also expand dimensions to have channels at the end (n_samples, window, width, height, channels)
#     data = np.expand_dims(data, axis=-1)

#     def _preprocess(sample):
#         video = tf.cast(sample, tf.float32) / 255.0  # Scale to unit interval.
#         # video = video < tf.random.uniform(tf.shape(video))  # Randomly binarize.
#         return video, video

#     train_dataset = (
#         tf.data.Dataset.from_tensor_slices(data[:9000])
#         .map(_preprocess)
#         .batch(batch_size)
#         .prefetch(tf.data.AUTOTUNE)
#         .shuffle(int(10e3))
#     )
#     test_dataset = (
#         tf.data.Dataset.from_tensor_slices(data[9000:])
#         .map(_preprocess)
#         .batch(batch_size)
#         .prefetch(tf.data.AUTOTUNE)
#         .shuffle(int(10e3))
#     )

#     return train_dataset, test_dataset, data.shape[1:]


# def load_moving_mnist(
#     dataset_path: str, batch_size: int
# ) -> Tuple[tf.data.Dataset, tf.data.Dataset, tuple]:
#     data = np.load(os.path.join(dataset_path, "moving_mnist", "mnist_test_seq.npy"))
#     data.shape

#     # We can see that data is of shape (window, n_samples, width, height)
#     # But we want for keras something of shape (n_samples, window, width, height)
#     data = np.moveaxis(data, 0, 1)
#     # Also expand dimensions to have channels at the end (n_samples, window, width, height, channels)
#     data = np.expand_dims(data, axis=-1)

#     def _preprocess(sample):
#         video = tf.cast(sample, tf.float32) / 255.0  # Scale to unit interval.
#         # video = video < tf.random.uniform(tf.shape(video))  # Randomly binarize.
#         first_frame = video[0:1, ...]
#         last_frame = video[-1:, ...]
#         first_last = tf.concat([first_frame, last_frame], axis=0)

#         return first_last, video

#     train_dataset = (
#         tf.data.Dataset.from_tensor_slices(data[:9000])
#         .map(_preprocess)
#         .batch(batch_size)
#         .prefetch(tf.data.AUTOTUNE)
#         .shuffle(int(10e3))
#     )
#     test_dataset = (
#         tf.data.Dataset.from_tensor_slices(data[9000:])
#         .map(_preprocess)
#         .batch(batch_size)
#         .prefetch(tf.data.AUTOTUNE)
#         .shuffle(int(10e3))
#     )

#     return train_dataset, test_dataset, data.shape[1:]


# def load_mnist(
#     dataset_path: str, batch_size: int
# ) -> Tuple[tf.data.Dataset, tf.data.Dataset, tuple]:
#     data = np.load(os.path.join(dataset_path, "moving_mnist", "mnist_test_seq.npy"))
#     data.shape

#     # We can see that data is of shape (window, n_samples, width, height)
#     # But we want for keras something of shape (n_samples, window, width, height)
#     data = np.moveaxis(data, 0, 1)
#     # Also expand dimensions to have channels at the end (n_samples, window, width, height, channels)
#     data = np.expand_dims(data, axis=-1)

#     def _preprocess(sample):
#         video = tf.cast(sample, tf.float32) / 255.0  # Scale to unit interval.
#         # video = video < tf.random.uniform(tf.shape(video))  # Randomly binarize.
#         first_frame = video[0, ...]

#         first_frame = tf.image.grayscale_to_rgb(first_frame)

#         return first_frame, first_frame

#     train_dataset = (
#         tf.data.Dataset.from_tensor_slices(data[:9000])
#         .map(_preprocess)
#         .batch(batch_size)
#         .prefetch(tf.data.AUTOTUNE)
#         .shuffle(int(10e3))
#     )
#     test_dataset = (
#         tf.data.Dataset.from_tensor_slices(data[9000:])
#         .map(_preprocess)
#         .batch(batch_size)
#         .prefetch(tf.data.AUTOTUNE)
#         .shuffle(int(10e3))
#     )

#     return train_dataset, test_dataset, data.shape[1:]
def preprocess_image(element):
    element = tf.reshape(element, (tf.shape(element)[0], tf.shape(element)[1], 3))
    element = tf.cast(element, tf.float32) / 255.0
    return element, element


def load_kny_images_light(dataset_path, batch_size):
    dataset_length = 34045
    path = os.path.join(dataset_path, "kny", "images_tfrecords_light")
    dataset = ImageDataset(path).load()
    dataset = dataset.shuffle(
        dataset_length, reshuffle_each_iteration=True, seed=10
    ).map(preprocess_image, num_parallel_calls=tf.data.AUTOTUNE)

    train_size = int(dataset_length * 0.8)
    validation_size = int(dataset_length * 0.1)

    train_ds = dataset.take(train_size)
    validation_ds = dataset.skip(train_size).take(validation_size)
    test_ds = dataset.skip(train_size + validation_size).take(validation_size)

    train_ds = train_ds.batch(batch_size, drop_remainder=True).prefetch(
        tf.data.AUTOTUNE
    )
    validation_ds = validation_ds.batch(batch_size, drop_remainder=True).prefetch(
        tf.data.AUTOTUNE
    )
    test_ds = test_ds.batch(batch_size, drop_remainder=True).prefetch(tf.data.AUTOTUNE)

    return train_ds, validation_ds, test_ds


def load_kny_images(dataset_path, batch_size):
    dataset_length = 52014
    path = os.path.join(dataset_path, "kny", "images_tfrecords")
    dataset = ImageDataset(path).load()
    dataset = dataset.shuffle(dataset_length, reshuffle_each_iteration=True).map(
        preprocess_image, num_parallel_calls=tf.data.AUTOTUNE
    )

    train_size = int(dataset_length * 0.8)
    validation_size = int(dataset_length * 0.1)

    train_ds = dataset.take(train_size)
    validation_ds = dataset.skip(train_size).take(validation_size)
    test_ds = dataset.skip(train_size + validation_size).take(validation_size)

    train_ds = train_ds.batch(batch_size, drop_remainder=True).prefetch(
        tf.data.AUTOTUNE
    )
    validation_ds = validation_ds.batch(batch_size, drop_remainder=True).prefetch(
        tf.data.AUTOTUNE
    )
    test_ds = test_ds.batch(batch_size, drop_remainder=True).prefetch(tf.data.AUTOTUNE)

    return train_ds, validation_ds, test_ds


def load_dataset(
    dataset_name: str, dataset_path: str, batch_size: int
) -> Tuple[tf.data.Dataset, tf.data.Dataset, tf.data.Dataset]:
    # if dataset_name == "moving_mnist_vae":
    #     return load_moving_mnist_vae(dataset_path, batch_size)
    # elif dataset_name == "moving_mnist":
    #     return load_moving_mnist(dataset_path, batch_size)
    # elif dataset_name == "mnist":
    #     return load_mnist(dataset_path, batch_size)
    # elif dataset_name == "kny_images":
    #     return load_kny_images(dataset_path, batch_size)
    if dataset_name == "kny_images":
        return load_kny_images(dataset_path, batch_size)
    if dataset_name == "kny_images_light":
        return load_kny_images_light(dataset_path, batch_size)
    else:
        raise ValueError(f"Unknown dataset: {dataset_name}")
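
Editor's note, not part of the commit: a minimal usage sketch for the loader above, assuming the TFRecords already exist under <dataset_path>/kny/images_tfrecords_light (the directory layout expected by load_kny_images_light) and an arbitrary batch size.

from ganime.data.base import load_dataset

train_ds, validation_ds, test_ds = load_dataset(
    dataset_name="kny_images_light",
    dataset_path="data",   # folder that contains kny/images_tfrecords_light
    batch_size=32,
)

for images, targets in train_ds.take(1):
    # preprocess_image yields (image, image) pairs scaled to [0, 1]
    print(images.shape)    # (32, height, width, 3)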
ganime/data/experimental.py
ADDED
@@ -0,0 +1,222 @@
from abc import ABC, abstractclassmethod, abstractmethod
import glob
import math
import os
from typing import Dict
from typing_extensions import dataclass_transform

import numpy as np
import tensorflow as tf
from tqdm.auto import tqdm


def _bytes_feature(value):
    """Returns a bytes_list from a string / byte."""
    if isinstance(value, type(tf.constant(0))):  # if value ist tensor
        value = value.numpy()  # get value of tensor
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))


def _float_feature(value):
    """Returns a floast_list from a float / double."""
    return tf.train.Feature(float_list=tf.train.FloatList(value=[value]))


def _int64_feature(value):
    """Returns an int64_list from a bool / enum / int / uint."""
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))


def serialize_array(array):
    array = tf.io.serialize_tensor(array)
    return array


class Dataset(ABC):
    def __init__(self, dataset_path: str):
        self.dataset_path = dataset_path

    @classmethod
    def _parse_single_element(cls, element) -> tf.train.Example:

        features = tf.train.Features(feature=cls._get_features(element))

        return tf.train.Example(features=features)

    @abstractclassmethod
    def _get_features(cls, element) -> Dict[str, tf.train.Feature]:
        pass

    @abstractclassmethod
    def _parse_tfr_element(cls, element):
        pass

    @classmethod
    def write_to_tfr(cls, data: np.ndarray, out_dir: str, filename: str):
        if not os.path.exists(out_dir):
            os.makedirs(out_dir)

        # Write all elements to a single tfrecord file
        single_file_name = cls.__write_to_single_tfr(data, out_dir, filename)

        # The optimal size for a single tfrecord file is around 100 MB. Get the number of files that need to be created
        number_splits = cls.__get_number_splits(single_file_name)

        if number_splits > 1:
            os.remove(single_file_name)
            cls.__write_to_multiple_tfr(data, out_dir, filename, number_splits)

    @classmethod
    def __write_to_multiple_tfr(
        cls, data: np.array, out_dir: str, filename: str, n_splits: int
    ):

        file_count = 0

        max_files = math.ceil(data.shape[0] / n_splits)

        print(f"Creating {n_splits} files with {max_files} elements each.")

        for i in tqdm(range(n_splits)):
            current_shard_name = os.path.join(
                out_dir,
                f"{filename}.tfrecords-{str(i).zfill(len(str(n_splits)))}-of-{n_splits}",
            )
            writer = tf.io.TFRecordWriter(current_shard_name)

            current_shard_count = 0
            while current_shard_count < max_files:  # as long as our shard is not full
                # get the index of the file that we want to parse now
                index = i * max_files + current_shard_count
                if index >= len(
                    data
                ):  # when we have consumed the whole data, preempt generation
                    break

                current_element = data[index]

                # create the required Example representation
                out = cls._parse_single_element(element=current_element)

                writer.write(out.SerializeToString())
                current_shard_count += 1
                file_count += 1

            writer.close()
        print(f"\nWrote {file_count} elements to TFRecord")
        return file_count

    @classmethod
    def __get_number_splits(cls, filename: str):
        target_size = 100 * 1024 * 1024  # 100mb

        single_file_size = os.path.getsize(filename)
        number_splits = math.ceil(single_file_size / target_size)
        return number_splits

    @classmethod
    def __write_to_single_tfr(cls, data: np.array, out_dir: str, filename: str):

        current_path_name = os.path.join(
            out_dir,
            f"{filename}.tfrecords-0-of-1",
        )

        writer = tf.io.TFRecordWriter(current_path_name)
        for element in tqdm(data):
            writer.write(cls._parse_single_element(element).SerializeToString())
        writer.close()

        return current_path_name

    def load(self) -> tf.data.TFRecordDataset:
        path = self.dataset_path
        dataset = None

        if os.path.isdir(path):
            dataset = self._load_folder(path)
        elif os.path.isfile(path):
            dataset = self._load_file(path)
        else:
            raise ValueError(f"Path {path} is not a valid file or folder.")

        dataset = dataset.map(self._parse_tfr_element)
        return dataset

    def _load_file(self, path) -> tf.data.TFRecordDataset:
        return tf.data.TFRecordDataset(path)

    def _load_folder(self, path) -> tf.data.TFRecordDataset:

        return tf.data.TFRecordDataset(
            glob.glob(os.path.join(path, "**/*.tfrecords*"), recursive=True)
        )


class VideoDataset(Dataset):
    @classmethod
    def _get_features(cls, element) -> Dict[str, tf.train.Feature]:
        return {
            "frames": _int64_feature(element.shape[0]),
            "height": _int64_feature(element.shape[1]),
            "width": _int64_feature(element.shape[2]),
            "depth": _int64_feature(element.shape[3]),
            "raw_video": _bytes_feature(serialize_array(element)),
        }

    @classmethod
    def _parse_tfr_element(cls, element):
        # use the same structure as above; it's kinda an outline of the structure we now want to create
        data = {
            "frames": tf.io.FixedLenFeature([], tf.int64),
            "height": tf.io.FixedLenFeature([], tf.int64),
            "width": tf.io.FixedLenFeature([], tf.int64),
            "raw_video": tf.io.FixedLenFeature([], tf.string),
            "depth": tf.io.FixedLenFeature([], tf.int64),
        }

        content = tf.io.parse_single_example(element, data)

        frames = content["frames"]
        height = content["height"]
        width = content["width"]
        depth = content["depth"]
        raw_video = content["raw_video"]

        # get our 'feature'-- our image -- and reshape it appropriately
        feature = tf.io.parse_tensor(raw_video, out_type=tf.uint8)
        feature = tf.reshape(feature, shape=[frames, height, width, depth])
        return feature


class ImageDataset(Dataset):
    @classmethod
    def _get_features(cls, element) -> Dict[str, tf.train.Feature]:
        return {
            "height": _int64_feature(element.shape[0]),
            "width": _int64_feature(element.shape[1]),
            "depth": _int64_feature(element.shape[2]),
            "raw_image": _bytes_feature(serialize_array(element)),
        }

    @classmethod
    def _parse_tfr_element(cls, element):
        # use the same structure as above; it's kinda an outline of the structure we now want to create
        data = {
            "height": tf.io.FixedLenFeature([], tf.int64),
            "width": tf.io.FixedLenFeature([], tf.int64),
            "raw_image": tf.io.FixedLenFeature([], tf.string),
            "depth": tf.io.FixedLenFeature([], tf.int64),
        }

        content = tf.io.parse_single_example(element, data)

        height = content["height"]
        width = content["width"]
        depth = content["depth"]
        raw_image = content["raw_image"]

        # get our 'feature'-- our image -- and reshape it appropriately
        feature = tf.io.parse_tensor(raw_image, out_type=tf.uint8)
        feature = tf.reshape(feature, shape=[height, width, depth])
        return feature
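
Editor's note, not part of the commit: a short round-trip sketch for the TFRecord helpers above; the array shape and output directory are made up for illustration.

import numpy as np

from ganime.data.experimental import ImageDataset

# Serialize a toy batch of uint8 images (N, H, W, C); write_to_tfr re-shards
# automatically if the single file would exceed roughly 100 MB.
images = np.random.randint(0, 255, size=(100, 128, 128, 3), dtype=np.uint8)
ImageDataset.write_to_tfr(images, out_dir="data/demo_tfrecords", filename="images")

# Read them back as a tf.data pipeline of (H, W, C) uint8 tensors.
dataset = ImageDataset("data/demo_tfrecords").load()
for image in dataset.take(1):
    print(image.shape)   # (128, 128, 3)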
ganime/data/kny.py
ADDED
@@ -0,0 +1,19 @@
import os

import numpy as np

from .base import SequenceDataset


class KNYImage(SequenceDataset):
    def load_data(self, dataset_path: str, split: str) -> np.ndarray:
        data = np.load(os.path.join(dataset_path, "kny", "kny_images_64x128.npy"))
        if split == "train":
            data = data[:-5000]
        else:
            data = data[-5000:]

        return data

    def preprocess_data(self, data: np.ndarray) -> np.ndarray:
        return data / 255
ganime/data/mnist.py
ADDED
@@ -0,0 +1,103 @@
import glob
import os
from typing import Literal

import numpy as np

from .base import SequenceDataset
import math


class MovingMNISTImage(SequenceDataset):
    def load_data(self, dataset_path: str, split: str) -> np.ndarray:
        data = np.load(os.path.join(dataset_path, "moving_mnist", "mnist_test_seq.npy"))
        # Data is of shape (window, n_samples, width, height)
        # But we want for keras something of shape (n_samples, window, width, height)
        data = np.moveaxis(data, 0, 1)
        # Also expand dimensions to have channels at the end (n_samples, window, width, height, channels)
        data = np.expand_dims(data, axis=-1)
        if split == "train":
            data = data[:-1000]
        else:
            data = data[-1000:]

        data = np.concatenate([data, data, data], axis=-1)

        return data

    def __getitem__(self, idx):
        inds = self.indices[idx * self.batch_size : (idx + 1) * self.batch_size]
        batch_x = self.data[inds, 0, ...]
        batch_y = self.data[inds, 1, ...]

        return batch_x, batch_y

    def preprocess_data(self, data: np.ndarray) -> np.ndarray:
        return data / 255


class MovingMNIST(SequenceDataset):
    def __init__(
        self,
        dataset_path: str,
        batch_size: int,
        split: Literal["train", "validation", "test"] = "train",
    ):
        self.batch_size = batch_size
        self.split = split
        root_path = os.path.join(dataset_path, "moving_mnist", split)
        self.paths = glob.glob(os.path.join(root_path, "*.npy"))
        # self.data = self.preprocess_data(self.data)

        self.indices = np.arange(len(self.paths))
        self.on_epoch_end()

    # def load_data(self, dataset_path: str, split: str) -> np.ndarray:
    #     data = np.load(os.path.join(dataset_path, "moving_mnist", "mnist_test_seq.npy"))
    #     # Data is of shape (window, n_samples, width, height)
    #     # But we want for keras something of shape (n_samples, window, width, height)
    #     data = np.moveaxis(data, 0, 1)
    #     # Also expand dimensions to have channels at the end (n_samples, window, width, height, channels)
    #     data = np.expand_dims(data, axis=-1)
    #     if split == "train":
    #         data = data[:100]
    #     else:
    #         data = data[100:110]

    #     data = np.concatenate([data, data, data], axis=-1)

    #     return data

    def __len__(self):
        return math.ceil(len(self.paths) / self.batch_size)

    def __getitem__(self, idx):
        inds = self.indices[idx * self.batch_size : (idx + 1) * self.batch_size]
        data = self.load_indices(inds)
        batch_x = np.concatenate([data[:, 0:1, ...], data[:, -1:, ...]], axis=1)
        batch_y = data[:, 1:, ...]

        return batch_x, batch_y

    def get_fixed_batch(self, idx):
        self.fixed_indices = (
            self.fixed_indices
            if hasattr(self, "fixed_indices")
            else self.indices[
                idx * self.batch_size : (idx + 1) * self.batch_size
            ].copy()
        )
        data = self.load_indices(self.fixed_indices)
        batch_x = np.concatenate([data[:, 0:1, ...], data[:, -1:, ...]], axis=1)
        batch_y = data[:, 1:, ...]

        return batch_x, batch_y

    def load_indices(self, indices):
        paths_to_load = [self.paths[index] for index in indices]
        data = [np.load(path) for path in paths_to_load]
        data = np.array(data)
        return self.preprocess_data(data)

    def preprocess_data(self, data: np.ndarray) -> np.ndarray:
        return data / 255
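MovingMNIST.__getitem__ conditions on the first and last frame of each clip and asks the model to reproduce everything in between. A small numpy-only sketch of that slicing, using a dummy batch rather than the real .npy shards:

import numpy as np

# Dummy batch: (n_samples, window, height, width, channels)
data = np.random.rand(4, 20, 64, 64, 1)

# Inputs: first and last frame of each sequence, stacked along the time axis.
batch_x = np.concatenate([data[:, 0:1, ...], data[:, -1:, ...]], axis=1)
# Targets: every frame except the first one.
batch_y = data[:, 1:, ...]

print(batch_x.shape)  # (4, 2, 64, 64, 1)
print(batch_y.shape)  # (4, 19, 64, 64, 1)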
ganime/metrics/image.py
ADDED
@@ -0,0 +1,70 @@
import numpy as np
import tensorflow as tf
from scipy import linalg
from tensorflow.keras.applications.inception_v3 import InceptionV3, preprocess_input
from tqdm.auto import tqdm

inceptionv3 = InceptionV3(include_top=False, weights="imagenet", pooling="avg")


def resize_images(images, new_shape):
    images = tf.image.resize(images, new_shape)
    return images


def calculate_fid(real_embeddings, generated_embeddings):
    # calculate mean and covariance statistics
    mu1, sigma1 = real_embeddings.mean(axis=0), np.cov(real_embeddings, rowvar=False)
    mu2, sigma2 = generated_embeddings.mean(axis=0), np.cov(
        generated_embeddings, rowvar=False
    )
    # calculate sum squared difference between means
    ssdiff = np.sum((mu1 - mu2) ** 2.0)
    # calculate sqrt of product between cov
    covmean = linalg.sqrtm(sigma1.dot(sigma2))
    # check and correct imaginary numbers from sqrt
    if np.iscomplexobj(covmean):
        covmean = covmean.real
    # calculate score
    fid = ssdiff + np.trace(sigma1 + sigma2 - 2.0 * covmean)
    return fid


def calculate_images_metrics(dataset, model, total_length):
    fake_embeddings = []
    real_embeddings = []

    psnrs = []
    ssims = []

    for sample in tqdm(dataset, total=total_length):
        generated = model(sample[0], training=False)[0]
        generated, real = generated, sample[0]

        real_resized = resize_images(real, (299, 299))
        generated_resized = resize_images(generated, (299, 299))

        real_activations = inceptionv3(real_resized, training=False)
        generated_activations = inceptionv3(generated_resized, training=False)
        fake_embeddings.append(generated_activations)
        real_embeddings.append(real_activations)

        fake_scaled = tf.cast(((generated * 0.5) + 1) * 255, tf.uint8)
        real_scaled = tf.cast(((real * 0.5) + 1) * 255, tf.uint8)

        psnrs.append(tf.image.psnr(fake_scaled, real_scaled, 255).numpy())
        ssims.append(tf.image.ssim(fake_scaled, real_scaled, 255).numpy())

    fid = calculate_fid(
        tf.concat(fake_embeddings, axis=0).numpy(),
        tf.concat(real_embeddings, axis=0).numpy(),
    )

    # kid = calculate_kid(
    #     tf.concat(fake_embeddings, axis=0).numpy(),
    #     tf.concat(real_embeddings, axis=0).numpy(),
    # )

    psnr = np.array(psnrs).mean()
    ssim = np.array(ssims).mean()
    return {"fid": fid, "ssim": ssim, "psnr": psnr}
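calculate_fid only needs two embedding matrices, so it can be sanity-checked with random activations: matched distributions should give a score near zero, distinct ones a much larger score. A short sketch assuming calculate_fid from this module is in scope (the features are random stand-ins, not real InceptionV3 activations):

import numpy as np

rng = np.random.default_rng(0)
real = rng.normal(size=(256, 64))                     # stand-in for pooled Inception features
close = real + rng.normal(scale=0.01, size=real.shape)  # nearly identical distribution
far = rng.normal(loc=3.0, size=(256, 64))             # clearly different distribution

print(calculate_fid(real, close))  # small value, close to 0
print(calculate_fid(real, far))    # much larger value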
ganime/metrics/video.py
ADDED
@@ -0,0 +1,98 @@
import numpy as np
import tensorflow as tf
import tensorflow_gan as tfgan
import tensorflow_hub as hub
from sklearn.metrics.pairwise import polynomial_kernel
from tqdm.auto import tqdm

i3d = hub.KerasLayer("https://tfhub.dev/deepmind/i3d-kinetics-400/1")


def resize_videos(videos, target_resolution):
    """Runs some preprocessing on the videos for I3D model.

    Args:
        videos: <T>[batch_size, num_frames, height, width, depth] The videos to be
            preprocessed. We don't care about the specific dtype of the videos, it can
            be anything that tf.image.resize_bilinear accepts. Values are expected to
            be in [-1, 1].
        target_resolution: (width, height): target video resolution
    Returns:
        videos: <float32>[batch_size, num_frames, height, width, depth]
    """
    min_frames = 9
    B, T, H, W, C = videos.shape
    videos = tf.transpose(videos, (1, 0, 2, 3, 4))
    if T < min_frames:
        videos = tf.concat([tf.zeros((min_frames - T, B, H, W, C)), videos], axis=0)
    scaled_videos = tf.map_fn(lambda x: tf.image.resize(x, target_resolution), videos)
    scaled_videos = tf.transpose(scaled_videos, (1, 0, 2, 3, 4))
    return scaled_videos


def polynomial_mmd(X, Y):
    m = X.shape[0]
    n = Y.shape[0]
    # compute kernels
    K_XX = polynomial_kernel(X)
    K_YY = polynomial_kernel(Y)
    K_XY = polynomial_kernel(X, Y)
    # compute mmd distance
    K_XX_sum = (K_XX.sum() - np.diagonal(K_XX).sum()) / (m * (m - 1))
    K_YY_sum = (K_YY.sum() - np.diagonal(K_YY).sum()) / (n * (n - 1))
    K_XY_sum = K_XY.sum() / (m * n)
    mmd = K_XX_sum + K_YY_sum - 2 * K_XY_sum
    return mmd


def calculate_ssim_videos(fake, real):
    fake = tf.cast(((fake * 0.5) + 1) * 255, tf.uint8)
    real = tf.cast(((real * 0.5) + 1) * 255, tf.uint8)
    ssims = []
    for i in range(fake.shape[0]):
        ssims.append(tf.image.ssim(fake[i], real[i], 255).numpy().mean())

    return np.array(ssims).mean()


def calculate_psnr_videos(fake, real):
    fake = tf.cast(((fake * 0.5) + 1) * 255, tf.uint8)
    real = tf.cast(((real * 0.5) + 1) * 255, tf.uint8)
    psnrs = []
    for i in range(fake.shape[0]):
        psnrs.append(tf.image.psnr(fake[i], real[i], 255).numpy().mean())

    return np.array(psnrs).mean()


def calculate_videos_metrics(dataset, model, total_length):
    fake_embeddings = []
    real_embeddings = []

    psnrs = []
    ssims = []

    for sample in tqdm(dataset, total=total_length):
        generated = model(sample, training=False)
        generated, real = generated[:, 1:], sample["y"][:, 1:]  # ignore first frame

        real_resized = resize_videos(real, (224, 224))
        generated_resized = resize_videos(generated, (224, 224))

        real_activations = i3d(real_resized)
        generated_activations = i3d(generated_resized)
        fake_embeddings.append(generated_activations)
        real_embeddings.append(real_activations)

        psnrs.append(calculate_psnr_videos(generated, real))
        ssims.append(calculate_ssim_videos(generated, real))

    # fake_concat, real_concat = tf.concat(fake_embeddings, axis=0), tf.concat(real_embeddings, axis=0)
    fvd = tfgan.eval.frechet_classifier_distance_from_activations(
        tf.concat(fake_embeddings, axis=0), tf.concat(real_embeddings, axis=0)
    )
    kvd = polynomial_mmd(
        tf.concat(fake_embeddings, axis=0), tf.concat(real_embeddings, axis=0)
    )
    psnr = np.array(psnrs).mean()
    ssim = np.array(ssims).mean()
    return {"fvd": fvd, "kvd": kvd, "ssim": ssim, "psnr": psnr}
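The kernel video distance (kvd) here is a polynomial-kernel MMD between I3D activations. Like FID, it can be exercised with random features; matched distributions should give a value near zero. A short sketch assuming polynomial_mmd from this module is in scope (random stand-in data, not real I3D activations):

import numpy as np

rng = np.random.default_rng(0)
X = rng.normal(size=(128, 400))        # stand-in for I3D activations of real videos
Y_close = rng.normal(size=(128, 400))  # same distribution as X
Y_far = rng.normal(loc=2.0, size=(128, 400))  # shifted distribution

print(polynomial_mmd(X, Y_close))  # close to 0
print(polynomial_mmd(X, Y_far))    # clearly positive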
ganime/model/__init__.py
ADDED
File without changes
ganime/model/base.py
ADDED
@@ -0,0 +1,45 @@
import tensorflow as tf
from ganime.model.vqgan_clean.vqgan import VQGAN


def load_model(
    model: str, config: dict, strategy: tf.distribute.Strategy
) -> tf.keras.Model:

    if model == "vqgan":
        with strategy.scope():
            print(config["model"])
            model = VQGAN(**config["model"])

            gen_optimizer = tf.keras.optimizers.Adam(
                learning_rate=config["trainer"]["gen_lr"],
                beta_1=config["trainer"]["gen_beta_1"],
                beta_2=config["trainer"]["gen_beta_2"],
                clipnorm=config["trainer"]["gen_clip_norm"],
            )
            disc_optimizer = tf.keras.optimizers.Adam(
                learning_rate=config["trainer"]["disc_lr"],
                beta_1=config["trainer"]["disc_beta_1"],
                beta_2=config["trainer"]["disc_beta_2"],
                clipnorm=config["trainer"]["disc_clip_norm"],
            )
            model.compile(gen_optimizer=gen_optimizer, disc_optimizer=disc_optimizer)
        return model
    else:
        raise ValueError(f"Unknown model: {model}")

    # if model == "moving_vae":
    #     from ganime.model.moving_vae import MovingVAE

    #     with strategy.scope():
    #         model = MovingVAE(input_shape=input_shape)

    #         negloglik = lambda x, rv_x: -rv_x.log_prob(x)
    #         model.compile(
    #             optimizer=tf.optimizers.Adam(learning_rate=config["lr"]),
    #             loss=negloglik,
    #         )
    #         # model.build(input_shape=(None, *input_shape))
    #         # model.summary()

    #     return model
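load_model reads a config dict with a "model" section (forwarded as VQGAN keyword arguments) and a "trainer" section holding the optimizer hyperparameters above. A hypothetical usage sketch, assuming the YAML files under configs/ follow that model/trainer layout and that load_model is imported from this module:

import tensorflow as tf
import yaml

# Load one of the repository configs (kny_image.yaml is just an example choice).
with open("configs/kny_image.yaml") as f:
    config = yaml.safe_load(f)

strategy = tf.distribute.get_strategy()  # default no-op strategy; swap in MirroredStrategy for multi-GPU
model = load_model("vqgan", config, strategy)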
ganime/model/moving_vae.py
ADDED
@@ -0,0 +1,126 @@
from tensorflow.keras import Model

import tensorflow as tf
import tensorflow_probability as tfp


class MovingVAE(Model):
    def __init__(self, input_shape, encoded_size=64, base_depth=32):
        super().__init__()

        self.encoded_size = encoded_size
        self.base_depth = base_depth

        self.prior = tfp.distributions.Independent(
            tfp.distributions.Normal(loc=tf.zeros(encoded_size), scale=1),
            reinterpreted_batch_ndims=1,
        )

        self.encoder = tf.keras.Sequential(
            [
                tf.keras.layers.InputLayer(input_shape=input_shape),
                tf.keras.layers.Lambda(lambda x: tf.cast(x, tf.float32) - 0.5),
                tf.keras.layers.Conv3D(
                    self.base_depth,
                    5,
                    strides=1,
                    padding="same",
                    activation=tf.nn.leaky_relu,
                ),
                tf.keras.layers.Conv3D(
                    self.base_depth,
                    5,
                    strides=2,
                    padding="same",
                    activation=tf.nn.leaky_relu,
                ),
                tf.keras.layers.Conv3D(
                    2 * self.base_depth,
                    5,
                    strides=1,
                    padding="same",
                    activation=tf.nn.leaky_relu,
                ),
                tf.keras.layers.Conv3D(
                    2 * self.base_depth,
                    5,
                    strides=2,
                    padding="same",
                    activation=tf.nn.leaky_relu,
                ),
                # tf.keras.layers.Conv3D(4 * encoded_size, 7, strides=1,
                #                        padding='valid', activation=tf.nn.leaky_relu),
                tf.keras.layers.Flatten(),
                tf.keras.layers.Dense(
                    tfp.layers.MultivariateNormalTriL.params_size(self.encoded_size),
                    activation=None,
                ),
                tfp.layers.MultivariateNormalTriL(
                    self.encoded_size,
                    activity_regularizer=tfp.layers.KLDivergenceRegularizer(self.prior),
                ),
            ]
        )

        self.decoder = tf.keras.Sequential(
            [
                tf.keras.layers.InputLayer(input_shape=[self.encoded_size]),
                tf.keras.layers.Reshape([1, 1, 1, self.encoded_size]),
                tf.keras.layers.Conv3DTranspose(
                    self.base_depth,
                    (5, 4, 4),
                    strides=1,
                    padding="valid",
                    activation=tf.nn.leaky_relu,
                ),
                tf.keras.layers.Conv3DTranspose(
                    2 * self.base_depth,
                    (5, 4, 4),
                    strides=(1, 2, 2),
                    padding="same",
                    activation=tf.nn.leaky_relu,
                ),
                tf.keras.layers.Conv3DTranspose(
                    2 * self.base_depth,
                    (5, 4, 4),
                    strides=2,
                    padding="same",
                    activation=tf.nn.leaky_relu,
                ),
                tf.keras.layers.Conv3DTranspose(
                    self.base_depth,
                    (5, 4, 4),
                    strides=(1, 2, 2),
                    padding="same",
                    activation=tf.nn.leaky_relu,
                ),
                tf.keras.layers.Conv3DTranspose(
                    self.base_depth,
                    (5, 4, 4),
                    strides=2,
                    padding="same",
                    activation=tf.nn.leaky_relu,
                ),
                tf.keras.layers.Conv3DTranspose(
                    self.base_depth,
                    (5, 4, 4),
                    strides=1,
                    padding="same",
                    activation=tf.nn.leaky_relu,
                ),
                tf.keras.layers.Conv2D(
                    filters=1, kernel_size=5, strides=1, padding="same", activation=None
                ),
                tf.keras.layers.Flatten(),
                tfp.layers.IndependentBernoulli(
                    input_shape, tfp.distributions.Bernoulli.logits
                ),
            ]
        )

        self.model = tf.keras.Model(
            inputs=self.encoder.inputs, outputs=self.decoder(self.encoder.outputs[0])
        )

    def call(self, inputs):
        return self.model(inputs)
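The probabilistic pieces of this model come from TensorFlow Probability: the encoder ends in a MultivariateNormalTriL layer regularized toward a standard-normal prior, and the decoder ends in an IndependentBernoulli layer, which is why the commented-out moving_vae branch in ganime/model/base.py trains it with a negative log-likelihood loss. A minimal standalone sketch of that latent-head pattern, with small made-up sizes (an 8-dim code and a 32-dim input) rather than the real encoder:

import tensorflow as tf
import tensorflow_probability as tfp

encoded_size = 8
prior = tfp.distributions.Independent(
    tfp.distributions.Normal(loc=tf.zeros(encoded_size), scale=1),
    reinterpreted_batch_ndims=1,
)

# Dense layer parameterizes a multivariate normal; the KL regularizer pulls it toward the prior.
head = tf.keras.Sequential(
    [
        tf.keras.layers.InputLayer(input_shape=(32,)),
        tf.keras.layers.Dense(tfp.layers.MultivariateNormalTriL.params_size(encoded_size)),
        tfp.layers.MultivariateNormalTriL(
            encoded_size,
            activity_regularizer=tfp.layers.KLDivergenceRegularizer(prior),
        ),
    ]
)

z = head(tf.random.normal((4, 32)))  # a distribution-valued output; z.sample() and z.mean() are tensors
negloglik = lambda x, rv_x: -rv_x.log_prob(x)  # the loss used with such distribution outputs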
ganime/model/p2p/__init__.py
ADDED
File without changes
ganime/model/p2p/p2p.py
ADDED
@@ -0,0 +1,543 @@
from statistics import mode
import numpy as np
import tensorflow as tf
from tensorflow.python.keras import Model, Sequential
from tensorflow.python.keras.layers import Dense, LSTMCell, RNN, Conv2D, Conv2DTranspose
from tensorflow.keras.layers import BatchNormalization, TimeDistributed
from tensorflow.python.keras.layers.advanced_activations import LeakyReLU
from tensorflow.keras.layers import Activation

# from tensorflow_probability.python.layers.dense_variational import (
#     DenseReparameterization,
# )
# import tensorflow_probability as tfp
from tensorflow.keras.losses import Loss


class KLCriterion(Loss):
    def call(self, y_true, y_pred):
        (mu1, logvar1), (mu2, logvar2) = y_true, y_pred

        """KL( N(mu_1, sigma2_1) || N(mu_2, sigma2_2))"""
        sigma1 = tf.exp(tf.math.multiply(logvar1, 0.5))
        sigma2 = tf.exp(tf.math.multiply(logvar2, 0.5))

        kld = (
            tf.math.log(sigma2 / sigma1)
            + (tf.exp(logvar1) + tf.square(mu1 - mu2)) / (2 * tf.exp(logvar2))
            - 0.5
        )
        return tf.reduce_sum(kld) / 22


class Encoder(Model):
    def __init__(self, dim, nc=1):
        super().__init__()
        self.dim = dim
        self.c1 = Sequential(
            [
                Conv2D(64, kernel_size=4, strides=2, padding="same"),
                BatchNormalization(),
                LeakyReLU(alpha=0.2),
            ]
        )
        self.c2 = Sequential(
            [
                Conv2D(128, kernel_size=4, strides=2, padding="same"),
                BatchNormalization(),
                LeakyReLU(alpha=0.2),
            ]
        )
        self.c3 = Sequential(
            [
                Conv2D(256, kernel_size=4, strides=2, padding="same"),
                BatchNormalization(),
                LeakyReLU(alpha=0.2),
            ]
        )
        self.c4 = Sequential(
            [
                Conv2D(512, kernel_size=4, strides=2, padding="same"),
                BatchNormalization(),
                LeakyReLU(alpha=0.2),
            ]
        )
        self.c5 = Sequential(
            [
                Conv2D(self.dim, kernel_size=4, strides=1, padding="valid"),
                BatchNormalization(),
                Activation("tanh"),
            ]
        )

    def call(self, input):
        h1 = self.c1(input)
        h2 = self.c2(h1)
        h3 = self.c3(h2)
        h4 = self.c4(h3)
        h5 = self.c5(h4)
        return tf.reshape(h5, (-1, self.dim)), [h1, h2, h3, h4, h5]


class Decoder(Model):
    def __init__(self, dim, nc=1):
        super().__init__()
        self.dim = dim
        self.upc1 = Sequential(
            [
                Conv2DTranspose(512, kernel_size=4, strides=1, padding="valid"),
                BatchNormalization(),
                LeakyReLU(alpha=0.2),
            ]
        )
        self.upc2 = Sequential(
            [
                Conv2DTranspose(256, kernel_size=4, strides=2, padding="same"),
                BatchNormalization(),
                LeakyReLU(alpha=0.2),
            ]
        )
        self.upc3 = Sequential(
            [
                Conv2DTranspose(128, kernel_size=4, strides=2, padding="same"),
                BatchNormalization(),
                LeakyReLU(alpha=0.2),
            ]
        )
        self.upc4 = Sequential(
            [
                Conv2DTranspose(64, kernel_size=4, strides=2, padding="same"),
                BatchNormalization(),
                LeakyReLU(alpha=0.2),
            ]
        )
        self.upc5 = Sequential(
            [
                Conv2DTranspose(1, kernel_size=4, strides=2, padding="same"),
                Activation("sigmoid"),
            ]
        )

    def call(self, input):
        vec, skip = input
        d1 = self.upc1(tf.reshape(vec, (-1, 1, 1, self.dim)))
        d2 = self.upc2(tf.concat([d1, skip[3]], axis=-1))
        d3 = self.upc3(tf.concat([d2, skip[2]], axis=-1))
        d4 = self.upc4(tf.concat([d3, skip[1]], axis=-1))
        output = self.upc5(tf.concat([d4, skip[0]], axis=-1))
        return output


class MyLSTM(Model):
    def __init__(self, input_shape, hidden_size, output_size, n_layers):
        super().__init__()
        self.hidden_size = hidden_size
        self.n_layers = n_layers
        self.embed = Dense(hidden_size, input_dim=input_shape)
        # self.lstm = Sequential(
        #     [LSTMCell(hidden_size) for _ in range(n_layers)], name="lstm"
        # )
        # self.lstm = self.create_lstm(hidden_size, n_layers)
        self.lstm = LSTMCell(hidden_size)
        self.out = Dense(output_size)

    def init_hidden(self, batch_size):
        hidden = []
        for i in range(self.n_layers):
            hidden.append(
                (
                    tf.Variable(tf.zeros([batch_size, self.hidden_size])),
                    tf.Variable(tf.zeros([batch_size, self.hidden_size])),
                )
            )
        self.__dict__["hidden"] = hidden

    def build(self, input_shape):
        self.init_hidden(input_shape[0])

    def call(self, inputs):
        h_in = self.embed(inputs)
        for i in range(self.n_layers):
            _, self.hidden[i] = self.lstm(h_in, self.hidden[i])
            h_in = self.hidden[i][0]

        return self.out(h_in)


class MyGaussianLSTM(Model):
    def __init__(self, input_shape, hidden_size, output_size, n_layers):
        super().__init__()
        self.hidden_size = hidden_size
        self.n_layers = n_layers
        self.embed = Dense(hidden_size, input_dim=input_shape)
        # self.lstm = Sequential(
        #     [LSTMCell(hidden_size) for _ in range(n_layers)], name="lstm"
        # )
        self.lstm = LSTMCell(hidden_size)
        self.mu_net = Dense(output_size)
        self.logvar_net = Dense(output_size)
        # self.out = Sequential(
        #     [
        #         tf.keras.layers.Dense(
        #             tfp.layers.MultivariateNormalTriL.params_size(output_size),
        #             activation=None,
        #         ),
        #         tfp.layers.MultivariateNormalTriL(output_size),
        #     ]
        # )

    def reparameterize(self, mu, logvar: tf.Tensor):
        logvar = tf.math.exp(logvar * 0.5)
        eps = tf.random.normal(logvar.shape)
        return tf.add(tf.math.multiply(eps, logvar), mu)

    def init_hidden(self, batch_size):
        hidden = []
        for i in range(self.n_layers):
            hidden.append(
                (
                    tf.Variable(tf.zeros([batch_size, self.hidden_size])),
                    tf.Variable(tf.zeros([batch_size, self.hidden_size])),
                )
            )
        self.__dict__["hidden"] = hidden

    def build(self, input_shape):
        self.init_hidden(input_shape[0])

    def call(self, inputs):
        h_in = self.embed(inputs)
        for i in range(self.n_layers):
            # print(h_in.shape, self.hidden[i][0].shape, self.hidden[i][0].shape)

            _, self.hidden[i] = self.lstm(h_in, self.hidden[i])
            h_in = self.hidden[i][0]
        mu = self.mu_net(h_in)
        logvar = self.logvar_net(h_in)
        z = self.reparameterize(mu, logvar)
        return z, mu, logvar


class P2P(Model):
    def __init__(
        self,
        channels: int = 1,
        g_dim: int = 128,
        z_dim: int = 10,
        rnn_size: int = 256,
        prior_rnn_layers: int = 1,
        posterior_rnn_layers: int = 1,
        predictor_rnn_layers: float = 1,
        skip_prob: float = 0.5,
        n_past: int = 1,
        last_frame_skip: bool = False,
        beta: float = 0.0001,
        weight_align: float = 0.1,
        weight_cpc: float = 100,
    ):
        super().__init__()
        self.channels = channels
        self.g_dim = g_dim
        self.z_dim = z_dim
        self.rnn_size = rnn_size
        self.prior_rnn_layers = prior_rnn_layers
        self.posterior_rnn_layers = posterior_rnn_layers
        self.predictor_rnn_layers = predictor_rnn_layers

        self.skip_prob = skip_prob
        self.n_past = n_past
        self.last_frame_skip = last_frame_skip
        self.beta = beta
        self.weight_align = weight_align
        self.weight_cpc = weight_cpc

        self.frame_predictor = MyLSTM(
            self.g_dim + self.z_dim + 1 + 1,
            self.rnn_size,
            self.g_dim,
            self.predictor_rnn_layers,
        )

        self.prior = MyGaussianLSTM(
            self.g_dim + self.g_dim + 1 + 1,
            self.rnn_size,
            self.z_dim,
            self.prior_rnn_layers,
        )

        self.posterior = MyGaussianLSTM(
            self.g_dim + self.g_dim + 1 + 1,
            self.rnn_size,
            self.z_dim,
            self.posterior_rnn_layers,
        )

        self.encoder = Encoder(self.g_dim, self.channels)
        self.decoder = Decoder(self.g_dim, self.channels)

        # criterions
        self.mse_criterion = tf.keras.losses.MeanSquaredError()
        self.kl_criterion = KLCriterion()
        self.align_criterion = tf.keras.losses.MeanSquaredError()

        # optimizers
        self.frame_predictor_optimizer = tf.keras.optimizers.Adam(
            learning_rate=0.0001  # , beta_1=0.9, beta_2=0.999, epsilon=1e-8
        )
        self.posterior_optimizer = tf.keras.optimizers.Adam(
            learning_rate=0.0001  # , beta_1=0.9, beta_2=0.999, epsilon=1e-8
        )
        self.prior_optimizer = tf.keras.optimizers.Adam(
            learning_rate=0.0001  # , beta_1=0.9, beta_2=0.999, epsilon=1e-8
        )
        self.encoder_optimizer = tf.keras.optimizers.Adam(
            learning_rate=0.0001  # , beta_1=0.9, beta_2=0.999, epsilon=1e-8
        )
        self.decoder_optimizer = tf.keras.optimizers.Adam(
            learning_rate=0.0001  # , beta_1=0.9, beta_2=0.999, epsilon=1e-8
        )

    def get_global_descriptor(self, x, start_ix=0, cp_ix=None):
        """Get the global descriptor based on x, start_ix, cp_ix."""
        if cp_ix is None:
            cp_ix = x.shape[1] - 1

        x_cp = x[:, cp_ix, ...]
        h_cp = self.encoder(x_cp)[0]  # 1 is input for skip-connection

        return x_cp, h_cp

    def call(self, x, start_ix=0, cp_ix=-1):
        batch_size = x.shape[0]

        with tf.GradientTape(persistent=True) as tape:
            mse_loss = 0
            kld_loss = 0
            cpc_loss = 0
            align_loss = 0

            seq_len = x.shape[1]
            start_ix = 0
            cp_ix = seq_len - 1
            x_cp, global_z = self.get_global_descriptor(
                x, start_ix, cp_ix
            )  # here global_z is h_cp

            skip_prob = self.skip_prob

            prev_i = 0
            max_skip_count = seq_len * skip_prob
            skip_count = 0
            probs = np.random.uniform(low=0, high=1, size=seq_len - 1)

            for i in range(1, seq_len):
                if (
                    probs[i - 1] <= skip_prob
                    and i >= self.n_past
                    and skip_count < max_skip_count
                    and i != 1
                    and i != cp_ix
                ):
                    skip_count += 1
                    continue

                time_until_cp = tf.fill([batch_size, 1], (cp_ix - i + 1) / cp_ix)
                delta_time = tf.fill([batch_size, 1], ((i - prev_i) / cp_ix))
                prev_i = i

                h = self.encoder(x[:, i - 1, ...])
                h_target = self.encoder(x[:, i, ...])[0]

                if self.last_frame_skip or i <= self.n_past:
                    h, skip = h
                else:
                    h = h[0]

                # Control Point Aware
                h_cpaw = tf.concat([h, global_z, time_until_cp, delta_time], axis=1)
                h_target_cpaw = tf.concat(
                    [h_target, global_z, time_until_cp, delta_time], axis=1
                )
                zt, mu, logvar = self.posterior(h_target_cpaw)
                zt_p, mu_p, logvar_p = self.prior(h_cpaw)

                concat = tf.concat([h, zt, time_until_cp, delta_time], axis=1)
                h_pred = self.frame_predictor(concat)
                x_pred = self.decoder([h_pred, skip])

                if i == cp_ix:  # the gen-cp-frame should be exactly as x_cp
                    h_pred_p = self.frame_predictor(
                        tf.concat([h, zt_p, time_until_cp, delta_time], axis=1)
                    )
                    x_pred_p = self.decoder([h_pred_p, skip])
                    cpc_loss = self.mse_criterion(x_pred_p, x_cp)

                if i > 1:
                    align_loss += self.align_criterion(h[0], h_pred)

                mse_loss += self.mse_criterion(x_pred, x[:, i, ...])
                kld_loss += self.kl_criterion((mu, logvar), (mu_p, logvar_p))

            # backward
            loss = mse_loss + kld_loss * self.beta + align_loss * self.weight_align

            prior_loss = kld_loss + cpc_loss * self.weight_cpc

        var_list_frame_predictor = self.frame_predictor.trainable_variables
        var_list_posterior = self.posterior.trainable_variables
        var_list_prior = self.prior.trainable_variables
        var_list_encoder = self.encoder.trainable_variables
        var_list_decoder = self.decoder.trainable_variables

        # mse: frame_predictor + decoder
        # align: frame_predictor + encoder
        # kld: posterior + prior + encoder

        var_list_without_prior = (
            var_list_frame_predictor
            + var_list_posterior
            + var_list_encoder
            + var_list_decoder
        )

        gradients_without_prior = tape.gradient(
            loss,
            var_list_without_prior,
        )
        gradients_prior = tape.gradient(
            prior_loss,
            var_list_prior,
        )

        self.update_model_without_prior(
            gradients_without_prior,
            var_list_without_prior,
        )
        self.update_prior(gradients_prior, var_list_prior)
        del tape

        return (
            mse_loss / seq_len,
            kld_loss / seq_len,
            cpc_loss / seq_len,
            align_loss / seq_len,
        )

    def p2p_generate(
        self,
        x,
        len_output,
        eval_cp_ix,
        start_ix=0,
        cp_ix=-1,
        model_mode="full",
        skip_frame=False,
        init_hidden=True,
    ):
        batch_size, num_frames, h, w, channels = x.shape
        dim_shape = (h, w, channels)

        gen_seq = [x[:, 0, ...]]
        x_in = x[:, 0, ...]

        seq_len = x.shape[1]
        cp_ix = seq_len - 1

        x_cp, global_z = self.get_global_descriptor(
            x, cp_ix=cp_ix
        )  # here global_z is h_cp

        skip_prob = self.skip_prob

        prev_i = 0
        max_skip_count = seq_len * skip_prob
        skip_count = 0
        probs = np.random.uniform(0, 1, len_output - 1)

        for i in range(1, len_output):
            if (
                probs[i - 1] <= skip_prob
                and i >= self.n_past
                and skip_count < max_skip_count
                and i != 1
                and i != (len_output - 1)
                and skip_frame
            ):
                skip_count += 1
                gen_seq.append(tf.zeros_like(x_in))
                continue

            time_until_cp = tf.fill([batch_size, 1], (eval_cp_ix - i + 1) / eval_cp_ix)

            delta_time = tf.fill([batch_size, 1], ((i - prev_i) / eval_cp_ix))

            prev_i = i

            h = self.encoder(x_in)

            if self.last_frame_skip or i == 1 or i < self.n_past:
                h, skip = h
            else:
                h, _ = h

            h_cpaw = tf.concat([h, global_z, time_until_cp, delta_time], axis=1)

            if i < self.n_past:
                h_target = self.encoder(x[:, i, ...])[0]
                h_target_cpaw = tf.concat(
                    [h_target, global_z, time_until_cp, delta_time], axis=1
                )

                zt, _, _ = self.posterior(h_target_cpaw)
                zt_p, _, _ = self.prior(h_cpaw)

                if model_mode == "posterior" or model_mode == "full":
                    self.frame_predictor(
                        tf.concat([h, zt, time_until_cp, delta_time], axis=1)
                    )
                elif model_mode == "prior":
                    self.frame_predictor(
                        tf.concat([h, zt_p, time_until_cp, delta_time], axis=1)
                    )

                x_in = x[:, i, ...]
                gen_seq.append(x_in)
            else:
                if i < num_frames:
                    h_target = self.encoder(x[:, i, ...])[0]
                    h_target_cpaw = tf.concat(
                        [h_target, global_z, time_until_cp, delta_time], axis=1
                    )
                else:
                    h_target_cpaw = h_cpaw

                zt, _, _ = self.posterior(h_target_cpaw)
                zt_p, _, _ = self.prior(h_cpaw)

                if model_mode == "posterior":
                    h = self.frame_predictor(
                        tf.concat([h, zt, time_until_cp, delta_time], axis=1)
                    )
                elif model_mode == "prior" or model_mode == "full":
                    h = self.frame_predictor(
                        tf.concat([h, zt_p, time_until_cp, delta_time], axis=1)
                    )

                x_in = self.decoder([h, skip])
                gen_seq.append(x_in)
        return tf.stack(gen_seq, axis=1)

    def update_model_without_prior(self, gradients, var_list):
        self.frame_predictor_optimizer.apply_gradients(zip(gradients, var_list))
        self.posterior_optimizer.apply_gradients(zip(gradients, var_list))
        self.encoder_optimizer.apply_gradients(zip(gradients, var_list))
        self.decoder_optimizer.apply_gradients(zip(gradients, var_list))

    def update_prior(self, gradients, var_list):
        self.prior_optimizer.apply_gradients(zip(gradients, var_list))

    # def update_model_without_prior(self):
    #     self.frame_predictor_optimizer.step()
    #     self.posterior_optimizer.step()
    #     self.encoder_optimizer.step()
    #     self.decoder_optimizer.step()
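For reference, KLCriterion above implements the closed-form KL divergence between two diagonal Gaussians parameterized by (mu, logvar), summed over all dimensions and then divided by a hard-coded constant (22 here, 100 in the test variant below):

\[
D_{\mathrm{KL}}\big(\mathcal{N}(\mu_1,\sigma_1^2)\,\|\,\mathcal{N}(\mu_2,\sigma_2^2)\big)
= \log\frac{\sigma_2}{\sigma_1}
+ \frac{\sigma_1^2 + (\mu_1-\mu_2)^2}{2\sigma_2^2}
- \frac{1}{2},
\qquad
\sigma_k = \exp\!\big(\tfrac{1}{2}\,\mathrm{logvar}_k\big).
\]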
ganime/model/p2p/p2p_test.py
ADDED
@@ -0,0 +1,713 @@
from tqdm.auto import tqdm
import numpy as np
import tensorflow as tf
from tensorflow.keras import Model, Sequential
from tensorflow.keras.layers import (
    LSTM,
    LSTMCell,
    Activation,
    BatchNormalization,
    Conv2D,
    Conv2DTranspose,
    Conv3D,
    Conv3DTranspose,
    Dense,
    Flatten,
    Input,
    Layer,
    LeakyReLU,
    MaxPooling2D,
    Reshape,
    TimeDistributed,
    UpSampling2D,
)
from tensorflow.keras.losses import Loss
from tensorflow.keras.losses import KLDivergence, MeanSquaredError

# from tensorflow_probability.python.layers.dense_variational import (
#     DenseReparameterization,
# )
# import tensorflow_probability as tfp
from tensorflow.keras.losses import Loss

initializer_conv_dense = tf.keras.initializers.RandomNormal(mean=0.0, stddev=0.02)
initializer_batch_norm = tf.keras.initializers.RandomNormal(mean=1.0, stddev=0.02)


class KLCriterion(Loss):
    def call(self, y_true, y_pred):
        (mu1, logvar1), (mu2, logvar2) = y_true, y_pred

        """KL( N(mu_1, sigma2_1) || N(mu_2, sigma2_2))"""
        sigma1 = tf.exp(tf.math.multiply(logvar1, 0.5))
        sigma2 = tf.exp(tf.math.multiply(logvar2, 0.5))

        kld = (
            tf.math.log(sigma2 / sigma1)
            + (tf.exp(logvar1) + tf.square(mu1 - mu2)) / (2 * tf.exp(logvar2))
            - 0.5
        )
        return tf.reduce_sum(kld) / 100


class Encoder(Model):
    def __init__(self, dim, nc=1):
        super().__init__()
        self.dim = dim
        self.c1 = Sequential(
            [
                Conv2D(
                    64,
                    kernel_size=4,
                    strides=2,
                    padding="same",
                    kernel_initializer=initializer_conv_dense,
                ),
                # BatchNormalization(),
                LeakyReLU(alpha=0.2),
            ]
        )
        self.c2 = Sequential(
            [
                Conv2D(
                    128,
                    kernel_size=4,
                    strides=2,
                    padding="same",
                    kernel_initializer=initializer_conv_dense,
                ),
                # BatchNormalization(),
                LeakyReLU(alpha=0.2),
            ]
        )
        self.c3 = Sequential(
            [
                Conv2D(
                    256,
                    kernel_size=4,
                    strides=2,
                    padding="same",
                    kernel_initializer=initializer_conv_dense,
                ),
                # BatchNormalization(),
                LeakyReLU(alpha=0.2),
            ]
        )
        self.c4 = Sequential(
            [
                Conv2D(
                    512,
                    kernel_size=4,
                    strides=2,
                    padding="same",
                    kernel_initializer=initializer_conv_dense,
                ),
                # BatchNormalization(),
                LeakyReLU(alpha=0.2),
            ]
        )
        self.c5 = Sequential(
            [
                Conv2D(
                    self.dim,
                    kernel_size=4,
                    strides=1,
                    padding="valid",
                    kernel_initializer=initializer_conv_dense,
                ),
                # BatchNormalization(),
                Activation("tanh"),
            ]
        )

    def call(self, input):
        h1 = self.c1(input)
        h2 = self.c2(h1)
        h3 = self.c3(h2)
        h4 = self.c4(h3)
        h5 = self.c5(h4)
        return tf.reshape(h5, (-1, self.dim)), [h1, h2, h3, h4, h5]


class Decoder(Model):
    def __init__(self, dim, nc=1):
        super().__init__()
        self.dim = dim
        self.upc1 = Sequential(
            [
                Conv2DTranspose(
                    512,
                    kernel_size=4,
                    strides=1,
                    padding="valid",
                    kernel_initializer=initializer_conv_dense,
                ),
                # BatchNormalization(),
                LeakyReLU(alpha=0.2),
            ]
        )
        self.upc2 = Sequential(
            [
                Conv2DTranspose(
                    256,
                    kernel_size=4,
                    strides=2,
                    padding="same",
                    kernel_initializer=initializer_conv_dense,
                ),
                # BatchNormalization(),
                LeakyReLU(alpha=0.2),
            ]
        )
        self.upc3 = Sequential(
            [
                Conv2DTranspose(
                    128,
                    kernel_size=4,
                    strides=2,
                    padding="same",
                    kernel_initializer=initializer_conv_dense,
                ),
                # BatchNormalization(),
                LeakyReLU(alpha=0.2),
            ]
        )
        self.upc4 = Sequential(
            [
                Conv2DTranspose(
                    64,
                    kernel_size=4,
                    strides=2,
                    padding="same",
                    kernel_initializer=initializer_conv_dense,
                ),
                # BatchNormalization(),
                LeakyReLU(alpha=0.2),
            ]
        )
        self.upc5 = Sequential(
            [
                Conv2DTranspose(
                    1,
                    kernel_size=4,
                    strides=2,
                    padding="same",
                    kernel_initializer=initializer_conv_dense,
                ),
                Activation("sigmoid"),
            ]
        )

    def call(self, input):
        vec, skip = input
        d1 = self.upc1(tf.reshape(vec, (-1, 1, 1, self.dim)))
        d2 = self.upc2(tf.concat([d1, skip[3]], axis=-1))
        d3 = self.upc3(tf.concat([d2, skip[2]], axis=-1))
        d4 = self.upc4(tf.concat([d3, skip[1]], axis=-1))
        output = self.upc5(tf.concat([d4, skip[0]], axis=-1))
        return output


class MyLSTM(Model):
    def __init__(self, input_shape, hidden_size, output_size, n_layers):
        super().__init__()
        self.hidden_size = hidden_size
        self.n_layers = n_layers
        self.embed = Dense(
            hidden_size,
            input_dim=input_shape,
            kernel_initializer=initializer_conv_dense,
        )
        # self.lstm = Sequential(
        #     [LSTMCell(hidden_size) for _ in range(n_layers)], name="lstm"
        # )
        # self.lstm = self.create_lstm(hidden_size, n_layers)
        self.lstm = [
            LSTMCell(
                hidden_size  # , return_sequences=False if i == self.n_layers - 1 else True
            )
            for i in range(self.n_layers)
        ]  # LSTMCell(hidden_size)
        self.lstm_rnn = tf.keras.layers.RNN(self.lstm[0], return_state=True)
        self.out = Dense(output_size, kernel_initializer=initializer_conv_dense)

    def init_hidden(self, batch_size):
        hidden = []
        for i in range(self.n_layers):
            hidden.append(
                (
                    tf.Variable(tf.zeros([batch_size, self.hidden_size])),
                    tf.Variable(tf.zeros([batch_size, self.hidden_size])),
                )
            )
        self.__dict__["hidden"] = hidden

    def build(self, input_shape):
        self.init_hidden(input_shape[0])

    def call(self, inputs):
        h_in = self.embed(inputs)
        h_in = tf.reshape(h_in, (-1, 1, self.hidden_size))
        h_in, *state = self.lstm_rnn(h_in)
        for i in range(self.n_layers):
            h_in, state = self.lstm[i](h_in, state)
        return self.out(h_in)


class MyGaussianLSTM(Model):
    def __init__(self, input_shape, hidden_size, output_size, n_layers):
        super().__init__()
        self.hidden_size = hidden_size
        self.n_layers = n_layers
        self.embed = Dense(
            hidden_size,
            input_dim=input_shape,
            kernel_initializer=initializer_conv_dense,
        )
        # self.lstm = Sequential(
        #     [LSTMCell(hidden_size) for _ in range(n_layers)], name="lstm"
        # )
        self.lstm = [
            LSTMCell(
                hidden_size  # , return_sequences=False if i == self.n_layers - 1 else True
            )
            for i in range(self.n_layers)
        ]  # LSTMCell(hidden_size)
        self.lstm_rnn = tf.keras.layers.RNN(self.lstm[0], return_state=True)
        self.mu_net = Dense(output_size, kernel_initializer=initializer_conv_dense)
        self.logvar_net = Dense(output_size, kernel_initializer=initializer_conv_dense)
        # self.out = Sequential(
        #     [
        #         tf.keras.layers.Dense(
        #             tfp.layers.MultivariateNormalTriL.params_size(output_size),
        #             activation=None,
        #         ),
        #         tfp.layers.MultivariateNormalTriL(output_size),
        #     ]
        # )

    def reparameterize(self, mu, logvar: tf.Tensor):
        logvar = tf.math.exp(logvar * 0.5)
        eps = tf.random.normal(logvar.shape)
        return tf.add(tf.math.multiply(eps, logvar), mu)

    def init_hidden(self, batch_size):
        hidden = []
        for i in range(self.n_layers):
            hidden.append(
                (
                    tf.Variable(tf.zeros([batch_size, self.hidden_size])),
                    tf.Variable(tf.zeros([batch_size, self.hidden_size])),
                )
            )
        self.__dict__["hidden"] = hidden

    def build(self, input_shape):
        self.init_hidden(input_shape[0])

    def call(self, inputs):
        h_in = self.embed(inputs)
        # for i in range(self.n_layers):
        #     # print(h_in.shape, self.hidden[i][0].shape, self.hidden[i][0].shape)

        #     _, self.hidden[i] = self.lstm(h_in, self.hidden[i])
        #     h_in = self.hidden[i][0]
        h_in = tf.reshape(h_in, (-1, 1, self.hidden_size))
        h_in, *state = self.lstm_rnn(h_in)
        for i in range(self.n_layers):
            h_in, state = self.lstm[i](h_in, state)

        mu = self.mu_net(h_in)
        logvar = self.logvar_net(h_in)
        z = self.reparameterize(mu, logvar)
        return z, mu, logvar


class P2P(Model):
    def __init__(
        self,
        channels: int = 1,
        g_dim: int = 128,
        z_dim: int = 10,
        rnn_size: int = 256,
        prior_rnn_layers: int = 1,
        posterior_rnn_layers: int = 1,
        predictor_rnn_layers: float = 2,
        skip_prob: float = 0.5,
        n_past: int = 1,
        last_frame_skip: bool = False,
        beta: float = 0.0001,
        weight_align: float = 0.1,
        weight_cpc: float = 100,
    ):
        super().__init__()
        self.channels = channels
        self.g_dim = g_dim
        self.z_dim = z_dim
        self.rnn_size = rnn_size
        self.prior_rnn_layers = prior_rnn_layers
        self.posterior_rnn_layers = posterior_rnn_layers
        self.predictor_rnn_layers = predictor_rnn_layers

        self.skip_prob = skip_prob
        self.n_past = n_past
        self.last_frame_skip = last_frame_skip
        self.beta = beta
        self.weight_align = weight_align
        self.weight_cpc = weight_cpc

        self.frame_predictor = MyLSTM(
            self.g_dim + self.z_dim + 1 + 1,
            self.rnn_size,
            self.g_dim,
            self.predictor_rnn_layers,
        )

        self.prior = MyGaussianLSTM(
            self.g_dim + self.g_dim + 1 + 1,
            self.rnn_size,
            self.z_dim,
            self.prior_rnn_layers,
        )

        self.posterior = MyGaussianLSTM(
            self.g_dim + self.g_dim + 1 + 1,
            self.rnn_size,
            self.z_dim,
            self.posterior_rnn_layers,
        )

        self.encoder = Encoder(self.g_dim, self.channels)
        self.decoder = Decoder(self.g_dim, self.channels)

        # criterions
        self.mse_criterion = tf.keras.losses.MeanSquaredError()
        self.kl_criterion = KLCriterion()
        self.align_criterion = tf.keras.losses.MeanSquaredError()

        self.total_loss_tracker = tf.keras.metrics.Mean(name="total_loss")
        self.reconstruction_loss_tracker = tf.keras.metrics.Mean(
            name="reconstruction_loss"
        )
        self.kl_loss_tracker = tf.keras.metrics.Mean(name="kl_loss")
        self.align_loss_tracker = tf.keras.metrics.Mean(name="align_loss")
        self.cpc_loss_tracker = tf.keras.metrics.Mean(name="align_loss")

        # optimizers
        # self.frame_predictor_optimizer = tf.keras.optimizers.Adam(
        #     learning_rate=0.001, beta_1=0.9, beta_2=0.999, epsilon=1e-8
        # )
        # self.posterior_optimizer = tf.keras.optimizers.Adam(
        #     learning_rate=0.001, beta_1=0.9, beta_2=0.999, epsilon=1e-8
        # )
        # self.prior_optimizer = tf.keras.optimizers.Adam(
        #     learning_rate=0.001, beta_1=0.9, beta_2=0.999, epsilon=1e-8
        # )
        # self.encoder_optimizer = tf.keras.optimizers.Adam(
        #     learning_rate=0.001, beta_1=0.9, beta_2=0.999, epsilon=1e-8
        # )
        # self.decoder_optimizer = tf.keras.optimizers.Adam(
        #     learning_rate=0.001, beta_1=0.9, beta_2=0.999, epsilon=1e-8
        # )

    @property
    def metrics(self):
        return [
            self.total_loss_tracker,
            self.reconstruction_loss_tracker,
            self.kl_loss_tracker,
            self.align_loss_tracker,
            self.cpc_loss_tracker,
        ]

    def get_global_descriptor(self, x, start_ix=0, cp_ix=None):
        """Get the global descriptor based on x, start_ix, cp_ix."""
        if cp_ix is None:
            cp_ix = x.shape[1] - 1

        x_cp = x[:, cp_ix, ...]
        h_cp = self.encoder(x_cp)[0]  # 1 is input for skip-connection

        return x_cp, h_cp

    def compile(
        self,
        frame_predictor_optimizer,
        prior_optimizer,
        posterior_optimizer,
        encoder_optimizer,
        decoder_optimizer,
    ):
        super().compile()
        self.frame_predictor_optimizer = frame_predictor_optimizer
        self.prior_optimizer = prior_optimizer
        self.posterior_optimizer = posterior_optimizer
        self.encoder_optimizer = encoder_optimizer
        self.decoder_optimizer = decoder_optimizer

    def train_step(self, data):
        y, x = data
        batch_size = 100

        mse_loss = 0
        kld_loss = 0
        cpc_loss = 0
        align_loss = 0

        seq_len = x.shape[1]
        start_ix = 0
        cp_ix = seq_len - 1
        x_cp, global_z = self.get_global_descriptor(
            x, start_ix, cp_ix
        )  # here global_z is h_cp

        skip_prob = self.skip_prob

        prev_i = 0
        max_skip_count = seq_len * skip_prob
        skip_count = 0
        probs = np.random.uniform(low=0, high=1, size=seq_len - 1)

        with tf.GradientTape(persistent=True) as tape:
            for i in tqdm(range(1, seq_len)):
                if (
                    probs[i - 1] <= skip_prob
                    and i >= self.n_past
                    and skip_count < max_skip_count
                    and i != 1
                    and i != cp_ix
                ):
                    skip_count += 1
                    continue

                if i > 1:
                    align_loss += self.align_criterion(h, h_pred)

                time_until_cp = tf.fill(
                    [batch_size, 1],
                    (cp_ix - i + 1) / cp_ix,
                )
                delta_time = tf.fill([batch_size, 1], ((i - prev_i) / cp_ix))
                prev_i = i

                h = self.encoder(x[:, i - 1, ...])
                h_target = self.encoder(x[:, i, ...])[0]

                if self.last_frame_skip or i <= self.n_past:
                    h, skip = h
                else:
                    h = h[0]

                # Control Point Aware
                h_cpaw = tf.concat([h, global_z, time_until_cp, delta_time], axis=-1)
                h_target_cpaw = tf.concat(
                    [h_target, global_z, time_until_cp, delta_time], axis=-1
                )

                zt, mu, logvar = self.posterior(h_target_cpaw)
                zt_p, mu_p, logvar_p = self.prior(h_cpaw)

                frame_predictor_input = tf.concat(
                    [h, zt, time_until_cp, delta_time], axis=-1
                )
                h_pred = self.frame_predictor(frame_predictor_input)
                x_pred = self.decoder([h_pred, skip])

                if i == cp_ix:  # the gen-cp-frame should be exactly as x_cp
                    h_pred_p = self.frame_predictor(
                        tf.concat([h, zt_p, time_until_cp, delta_time], axis=-1)
                    )
                    x_pred_p = self.decoder([h_pred_p, skip])
                    cpc_loss = self.mse_criterion(x_pred_p, x_cp)

                mse_loss += self.mse_criterion(x_pred, x[:, i, ...])
                kld_loss += self.kl_criterion((mu, logvar), (mu_p, logvar_p))

            # backward
            loss = (
                mse_loss
                + kld_loss * self.beta
                + align_loss * self.weight_align
                # + cpc_loss * self.weight_cpc
            )

            prior_loss = kld_loss + cpc_loss * self.weight_cpc

        var_list_frame_predictor = self.frame_predictor.trainable_variables
        var_list_posterior = self.posterior.trainable_variables
        var_list_prior = self.prior.trainable_variables
        var_list_encoder = self.encoder.trainable_variables
        var_list_decoder = self.decoder.trainable_variables

        # mse: frame_predictor + decoder
        # align: frame_predictor + encoder
        # kld: posterior + prior + encoder

        var_list = (
            var_list_frame_predictor
            + var_list_posterior
            + var_list_encoder
            + var_list_decoder
            + var_list_prior
        )

        gradients = tape.gradient(
            loss,
            var_list,
        )
        gradients_prior = tape.gradient(
            prior_loss,
            var_list_prior,
        )

        self.update_model(
            gradients,
            var_list,
        )
        self.update_prior(gradients_prior, var_list_prior)
self.update_prior(gradients_prior, var_list_prior)
|
568 |
+
del tape
|
569 |
+
|
570 |
+
self.total_loss_tracker.update_state(loss)
|
571 |
+
self.kl_loss_tracker.update_state(kld_loss)
|
572 |
+
self.align_loss_tracker.update_state(align_loss)
|
573 |
+
self.reconstruction_loss_tracker.update_state(mse_loss)
|
574 |
+
self.cpc_loss_tracker.update_state(cpc_loss)
|
575 |
+
|
576 |
+
return {
|
577 |
+
"loss": self.total_loss_tracker.result(),
|
578 |
+
"reconstruction_loss": self.reconstruction_loss_tracker.result(),
|
579 |
+
"kl_loss": self.kl_loss_tracker.result(),
|
580 |
+
"align_loss": self.align_loss_tracker.result(),
|
581 |
+
"cpc_loss": self.cpc_loss_tracker.result(),
|
582 |
+
}
|
583 |
+
|
584 |
+
def call(
|
585 |
+
self,
|
586 |
+
inputs,
|
587 |
+
training=None,
|
588 |
+
mask=None
|
589 |
+
# len_output,
|
590 |
+
# eval_cp_ix,
|
591 |
+
# start_ix=0,
|
592 |
+
# cp_ix=-1,
|
593 |
+
# model_mode="full",
|
594 |
+
# skip_frame=False,
|
595 |
+
# init_hidden=True,
|
596 |
+
):
|
597 |
+
len_output = 20
|
598 |
+
eval_cp_ix = len_output - 1
|
599 |
+
start_ix = 0
|
600 |
+
cp_ix = -1
|
601 |
+
model_mode = "full"
|
602 |
+
skip_frame = False
|
603 |
+
init_hidden = True
|
604 |
+
|
605 |
+
batch_size, num_frames, h, w, channels = inputs.shape
|
606 |
+
dim_shape = (h, w, channels)
|
607 |
+
|
608 |
+
gen_seq = [inputs[:, 0, ...]]
|
609 |
+
x_in = inputs[:, 0, ...]
|
610 |
+
|
611 |
+
seq_len = inputs.shape[1]
|
612 |
+
cp_ix = seq_len - 1
|
613 |
+
|
614 |
+
x_cp, global_z = self.get_global_descriptor(
|
615 |
+
inputs, cp_ix=cp_ix
|
616 |
+
) # here global_z is h_cp
|
617 |
+
|
618 |
+
skip_prob = self.skip_prob
|
619 |
+
|
620 |
+
prev_i = 0
|
621 |
+
max_skip_count = seq_len * skip_prob
|
622 |
+
skip_count = 0
|
623 |
+
probs = np.random.uniform(0, 1, len_output - 1)
|
624 |
+
|
625 |
+
for i in range(1, len_output):
|
626 |
+
if (
|
627 |
+
probs[i - 1] <= skip_prob
|
628 |
+
and i >= self.n_past
|
629 |
+
and skip_count < max_skip_count
|
630 |
+
and i != 1
|
631 |
+
and i != (len_output - 1)
|
632 |
+
and skip_frame
|
633 |
+
):
|
634 |
+
skip_count += 1
|
635 |
+
gen_seq.append(tf.zeros_like(x_in))
|
636 |
+
continue
|
637 |
+
|
638 |
+
time_until_cp = tf.fill([100, 1], (eval_cp_ix - i + 1) / eval_cp_ix)
|
639 |
+
|
640 |
+
delta_time = tf.fill([100, 1], ((i - prev_i) / eval_cp_ix))
|
641 |
+
|
642 |
+
prev_i = i
|
643 |
+
|
644 |
+
h = self.encoder(x_in)
|
645 |
+
|
646 |
+
if self.last_frame_skip or i == 1 or i < self.n_past:
|
647 |
+
h, skip = h
|
648 |
+
else:
|
649 |
+
h, _ = h
|
650 |
+
|
651 |
+
h_cpaw = tf.stop_gradient(tf.concat([h, global_z, time_until_cp, delta_time], axis=-1))
|
652 |
+
|
653 |
+
if i < self.n_past:
|
654 |
+
h_target = self.encoder(inputs[:, i, ...])[0]
|
655 |
+
h_target_cpaw = tf.stop_gradient(tf.concat(
|
656 |
+
[h_target, global_z, time_until_cp, delta_time], axis=1
|
657 |
+
))
|
658 |
+
|
659 |
+
zt, _, _ = self.posterior(h_target_cpaw)
|
660 |
+
zt_p, _, _ = self.prior(h_cpaw)
|
661 |
+
|
662 |
+
if model_mode == "posterior" or model_mode == "full":
|
663 |
+
self.frame_predictor(
|
664 |
+
tf.concat([h, zt, time_until_cp, delta_time], axis=-1)
|
665 |
+
)
|
666 |
+
elif model_mode == "prior":
|
667 |
+
self.frame_predictor(
|
668 |
+
tf.concat([h, zt_p, time_until_cp, delta_time], axis=-1)
|
669 |
+
)
|
670 |
+
|
671 |
+
x_in = inputs[:, i, ...]
|
672 |
+
gen_seq.append(x_in)
|
673 |
+
else:
|
674 |
+
if i < num_frames:
|
675 |
+
h_target = self.encoder(inputs[:, i, ...])[0]
|
676 |
+
h_target_cpaw = tf.stop_gradient(tf.concat(
|
677 |
+
[h_target, global_z, time_until_cp, delta_time], axis=-1
|
678 |
+
))
|
679 |
+
else:
|
680 |
+
h_target_cpaw = h_cpaw
|
681 |
+
|
682 |
+
zt, _, _ = self.posterior(h_target_cpaw)
|
683 |
+
zt_p, _, _ = self.prior(h_cpaw)
|
684 |
+
|
685 |
+
if model_mode == "posterior":
|
686 |
+
h = self.frame_predictor(
|
687 |
+
tf.concat([h, zt, time_until_cp, delta_time], axis=-1)
|
688 |
+
)
|
689 |
+
elif model_mode == "prior" or model_mode == "full":
|
690 |
+
h = self.frame_predictor(
|
691 |
+
tf.concat([h, zt_p, time_until_cp, delta_time], axis=-1)
|
692 |
+
)
|
693 |
+
|
694 |
+
x_in = tf.stop_gradient(self.decoder([h, skip]))
|
695 |
+
gen_seq.append(x_in)
|
696 |
+
|
697 |
+
return tf.stack(gen_seq, axis=1)
|
698 |
+
|
699 |
+
def update_model(self, gradients, var_list):
|
700 |
+
self.frame_predictor_optimizer.apply_gradients(zip(gradients, var_list))
|
701 |
+
self.posterior_optimizer.apply_gradients(zip(gradients, var_list))
|
702 |
+
self.encoder_optimizer.apply_gradients(zip(gradients, var_list))
|
703 |
+
self.decoder_optimizer.apply_gradients(zip(gradients, var_list))
|
704 |
+
#self.prior_optimizer.apply_gradients(zip(gradients, var_list))
|
705 |
+
|
706 |
+
def update_prior(self, gradients, var_list):
|
707 |
+
self.prior_optimizer.apply_gradients(zip(gradients, var_list))
|
708 |
+
|
709 |
+
# def update_model_without_prior(self):
|
710 |
+
# self.frame_predictor_optimizer.step()
|
711 |
+
# self.posterior_optimizer.step()
|
712 |
+
# self.encoder_optimizer.step()
|
713 |
+
# self.decoder_optimizer.step()
|
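Because the custom compile above replaces the single Keras optimizer with five separate ones, a minimal usage sketch may help. It assumes the surrounding model class is importable (the import path and class name below are hypothetical; adjust them to the actual module), and it reuses the Adam settings from the commented-out optimizer block above. Note that train_step hard-codes a batch size of 100 in its tf.fill calls.

    import tensorflow as tf
    from ganime.model.p2p.p2p_test import P2P  # hypothetical import; adjust to the real module/class

    def make_adam():
        # mirrors the commented-out Adam settings above
        return tf.keras.optimizers.Adam(
            learning_rate=0.001, beta_1=0.9, beta_2=0.999, epsilon=1e-8
        )

    model = P2P()  # assumed constructor defaults from earlier in the file
    model.compile(
        frame_predictor_optimizer=make_adam(),
        prior_optimizer=make_adam(),
        posterior_optimizer=make_adam(),
        encoder_optimizer=make_adam(),
        decoder_optimizer=make_adam(),
    )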
ganime/model/p2p/p2p_v2.py
ADDED
@@ -0,0 +1,498 @@
import numpy as np
import tensorflow as tf
from tensorflow.keras import Model, Sequential
from tensorflow.keras.layers import (
    LSTM,
    Activation,
    BatchNormalization,
    Conv2D,
    Conv2DTranspose,
    Conv3D,
    Conv3DTranspose,
    Dense,
    Flatten,
    Input,
    Layer,
    LeakyReLU,
    MaxPooling2D,
    Reshape,
    TimeDistributed,
    UpSampling2D,
)
from tensorflow.keras.losses import Loss
from tensorflow.keras.losses import KLDivergence, MeanSquaredError
from tqdm.auto import tqdm


class KLCriterion(Loss):
    def call(self, y_true, y_pred):
        (mu1, logvar1), (mu2, logvar2) = y_true, y_pred

        """KL( N(mu_1, sigma2_1) || N(mu_2, sigma2_2))"""
        sigma1 = tf.exp(tf.math.multiply(logvar1, 0.5))
        sigma2 = tf.exp(tf.math.multiply(logvar2, 0.5))

        kld = (
            tf.math.log(sigma2 / sigma1)
            + (tf.exp(logvar1) + tf.square(mu1 - mu2)) / (2 * tf.exp(logvar2))
            - 0.5
        )
        return kld


class Decoder(Model):
    def __init__(self, dim, nc=1):
        super().__init__()
        self.dim = dim
        self.upc1 = Sequential(
            [
                TimeDistributed(
                    Conv2DTranspose(512, kernel_size=4, strides=1, padding="valid")
                ),
                BatchNormalization(),
                LeakyReLU(alpha=0.2),
            ]
        )
        self.upc2 = Sequential(
            [
                TimeDistributed(
                    Conv2DTranspose(256, kernel_size=4, strides=2, padding="same")
                ),
                BatchNormalization(),
                LeakyReLU(alpha=0.2),
            ]
        )
        self.upc3 = Sequential(
            [
                TimeDistributed(
                    Conv2DTranspose(128, kernel_size=4, strides=2, padding="same")
                ),
                BatchNormalization(),
                LeakyReLU(alpha=0.2),
            ]
        )
        self.upc4 = Sequential(
            [
                TimeDistributed(
                    Conv2DTranspose(64, kernel_size=4, strides=2, padding="same")
                ),
                BatchNormalization(),
                LeakyReLU(alpha=0.2),
            ]
        )
        self.upc5 = Sequential(
            [
                TimeDistributed(
                    Conv2DTranspose(1, kernel_size=4, strides=2, padding="same")
                ),
                Activation("sigmoid"),
            ]
        )

    def call(self, input):
        vec, skip = input
        d1 = self.upc1(tf.reshape(vec, (-1, 1, 1, 1, self.dim)))
        d2 = self.upc2(tf.concat([d1, skip[3]], axis=-1))
        d3 = self.upc3(tf.concat([d2, skip[2]], axis=-1))
        d4 = self.upc4(tf.concat([d3, skip[1]], axis=-1))
        output = self.upc5(tf.concat([d4, skip[0]], axis=-1))
        return output


class Sampling(Layer):
    """Uses (z_mean, z_log_var) to sample z, the vector encoding a digit."""

    def call(self, inputs):
        z_mean, z_log_var = inputs
        batch = tf.shape(z_mean)[0]
        dim = tf.shape(z_mean)[1]
        epsilon = tf.keras.backend.random_normal(shape=(batch, dim))
        return z_mean + tf.exp(0.5 * z_log_var) * epsilon

    def compute_output_shape(self, input_shape):
        return input_shape[0]


class P2P(Model):
    def __init__(
        self,
        channels: int = 1,
        g_dim: int = 128,
        z_dim: int = 10,
        rnn_size: int = 256,
        prior_rnn_layers: int = 1,
        posterior_rnn_layers: int = 1,
        predictor_rnn_layers: float = 1,
        skip_prob: float = 0.1,
        n_past: int = 1,
        last_frame_skip: bool = False,
        beta: float = 0.0001,
        weight_align: float = 0.1,
        weight_cpc: float = 100,
    ):
        super().__init__()
        # Models parameters
        self.channels = channels
        self.g_dim = g_dim
        self.z_dim = z_dim
        self.rnn_size = rnn_size
        self.prior_rnn_layers = prior_rnn_layers
        self.posterior_rnn_layers = posterior_rnn_layers
        self.predictor_rnn_layers = predictor_rnn_layers

        # Training parameters
        self.skip_prob = skip_prob
        self.n_past = n_past
        self.last_frame_skip = last_frame_skip
        self.beta = beta
        self.weight_align = weight_align
        self.weight_cpc = weight_cpc

        self.frame_predictor = self.build_lstm()
        self.prior = self.build_gaussian_lstm()
        self.posterior = self.build_gaussian_lstm()
        self.encoder = self.build_encoder()
        self.decoder = self.build_decoder()

        self.total_loss_tracker = tf.keras.metrics.Mean(name="total_loss")
        self.reconstruction_loss_tracker = tf.keras.metrics.Mean(
            name="reconstruction_loss"
        )
        self.kl_loss_tracker = tf.keras.metrics.Mean(name="kl_loss")
        self.align_loss_tracker = tf.keras.metrics.Mean(name="align_loss")
        self.cpc_loss_tracker = tf.keras.metrics.Mean(name="cpc_loss")

        self.kl_loss = KLCriterion(
            reduction=tf.keras.losses.Reduction.NONE
        )  # KLDivergence(reduction=tf.keras.losses.Reduction.NONE)
        self.mse = MeanSquaredError(reduction=tf.keras.losses.Reduction.NONE)
        self.align_loss = MeanSquaredError(reduction=tf.keras.losses.Reduction.NONE)

        # self.optimizer = tf.keras.optimizers.Adam(
        #     learning_rate=0.001, beta_1=0.9, beta_2=0.999, epsilon=1e-8
        # )
        # self.prior_optimizer = tf.keras.optimizers.Adam(
        #     learning_rate=0.001, beta_1=0.9, beta_2=0.999, epsilon=1e-8
        # )

    # region Model building
    def build_lstm(self):
        input = Input(shape=(None, self.g_dim + self.z_dim))
        embed = TimeDistributed(Dense(self.rnn_size))(input)
        lstm = LSTM(self.rnn_size)(embed)
        output = Dense(self.g_dim)(lstm)
        output = (tf.expand_dims(output, axis=1),)

        return Model(inputs=input, outputs=output, name="frame_predictor")

    def build_gaussian_lstm(self):

        input = Input(shape=(None, self.g_dim))
        embed = TimeDistributed(Dense(self.rnn_size))(input)
        lstm = LSTM(self.rnn_size)(embed)
        mu = Dense(self.z_dim)(lstm)
        logvar = Dense(self.z_dim)(lstm)
        z = Sampling()([mu, logvar])

        return Model(inputs=input, outputs=[mu, logvar, z])

    def build_encoder(self):

        input = Input(shape=(1, 64, 64, 1))

        h = TimeDistributed(Conv2D(64, kernel_size=4, strides=2, padding="same"))(input)
        h = BatchNormalization()(h)
        h1 = LeakyReLU(alpha=0.2)(h)
        # h = TimeDistributed(MaxPooling2D(pool_size=2, strides=2, padding="same"))(h)

        h = TimeDistributed(Conv2D(128, kernel_size=4, strides=2, padding="same"))(h1)
        h = BatchNormalization()(h)
        h2 = LeakyReLU(alpha=0.2)(h)
        # h = TimeDistributed(MaxPooling2D(pool_size=2, strides=2, padding="same"))(h)

        h = TimeDistributed(Conv2D(256, kernel_size=4, strides=2, padding="same"))(h2)
        h = BatchNormalization()(h)
        h3 = LeakyReLU(alpha=0.2)(h)
        # h = TimeDistributed(MaxPooling2D(pool_size=2, strides=2, padding="same"))(h)

        h = TimeDistributed(Conv2D(512, kernel_size=4, strides=2, padding="same"))(h3)
        h = BatchNormalization()(h)
        h4 = LeakyReLU(alpha=0.2)(h)
        # h = TimeDistributed(MaxPooling2D(pool_size=2, strides=2, padding="same"))(h)

        h = TimeDistributed(
            Conv2D(self.g_dim, kernel_size=4, strides=1, padding="valid")
        )(h4)
        h = BatchNormalization()(h)
        h5 = Activation("tanh")(h)

        output = tf.reshape(h5, (-1, 1, self.g_dim))
        # h = Flatten()(h)
        # output = Dense(self.g_dim)(h)
        # output = tf.expand_dims(output, axis=1)
        return Model(inputs=input, outputs=[output, [h1, h2, h3, h4]], name="encoder")

    def build_decoder(self):
        return Decoder(self.g_dim)

    # def build_decoder(self):
    #     latent_inputs = Input(
    #         shape=(
    #             1,
    #             self.g_dim,
    #         )
    #     )
    #     x = Dense(1 * 1 * 1 * 128, activation="relu")(latent_inputs)
    #     x = Reshape((1, 1, 1, 128))(x)
    #     x = TimeDistributed(
    #         Conv2DTranspose(512, kernel_size=4, strides=1, padding="valid")
    #     )(x)
    #     x = BatchNormalization()(x)
    #     x1 = LeakyReLU(alpha=0.2)(x)

    #     x = TimeDistributed(
    #         Conv2DTranspose(256, kernel_size=4, strides=2, padding="same")
    #     )(x1)
    #     x = BatchNormalization()(x)
    #     x2 = LeakyReLU(alpha=0.2)(x)

    #     x = TimeDistributed(
    #         Conv2DTranspose(128, kernel_size=4, strides=2, padding="same")
    #     )(x2)
    #     x = BatchNormalization()(x)
    #     x3 = LeakyReLU(alpha=0.2)(x)

    #     x = TimeDistributed(
    #         Conv2DTranspose(64, kernel_size=4, strides=2, padding="same")
    #     )(x3)
    #     x = BatchNormalization()(x)
    #     x4 = LeakyReLU(alpha=0.2)(x)

    #     x = TimeDistributed(
    #         Conv2DTranspose(1, kernel_size=4, strides=2, padding="same")
    #     )(x4)
    #     x5 = Activation("sigmoid")(x)

    #     return Model(inputs=latent_inputs, outputs=x5, name="decoder")

    # endregion

    @property
    def metrics(self):
        return [
            self.total_loss_tracker,
            self.reconstruction_loss_tracker,
            self.kl_loss_tracker,
            self.align_loss_tracker,
            self.cpc_loss_tracker,
        ]

    def call(self, inputs, training=None, mask=None):
        first_frame = inputs[:, 0:1, ...]
        last_frame = inputs[:, -1:, ...]

        desired_length = 20
        previous_frame = first_frame
        generated = [first_frame]

        z_last, _ = self.encoder(last_frame)
        for i in range(1, desired_length):

            z_prev = self.encoder(previous_frame)

            if self.last_frame_skip or i == 1 or i < self.n_past:
                z_prev, skip = z_prev
            else:
                z_prev = z_prev[0]

            prior_input = tf.concat([z_prev, z_last], axis=1)

            z_mean_prior, z_log_var_prior, z_prior = self.prior(prior_input)

            predictor_input = tf.concat(
                (z_prev, tf.expand_dims(z_prior, axis=1)), axis=-1
            )
            z_pred = self.frame_predictor(predictor_input)

            current_frame = self.decoder([z_pred, skip])
            generated.append(current_frame)
            previous_frame = current_frame
        return tf.concat(generated, axis=1)

    def train_step(self, data):
        global_batch_size = 100  # * 8
        x, y = data

        first_frame = x[:, 0:1, ...]
        last_frame = x[:, -1:, ...]
        desired_length = y.shape[1]
        previous_frame = first_frame

        reconstruction_loss = 0
        kl_loss = 0
        align_loss = 0
        cpc_loss = 0

        with tf.GradientTape(persistent=True) as tape:
            z_last, _ = self.encoder(last_frame)
            for i in tqdm(range(1, desired_length)):
                current_frame = y[:, i : i + 1, ...]

                z_prev = self.encoder(previous_frame)

                if self.last_frame_skip or i <= self.n_past:
                    z_prev, skip = z_prev
                else:
                    z_prev = z_prev[0]

                z_curr, _ = self.encoder(current_frame)

                prior_input = tf.concat([z_prev, z_last], axis=1)
                posterior_input = tf.concat([z_curr, z_last], axis=1)

                z_mean_prior, z_log_var_prior, z_prior = self.prior(prior_input)
                z_mean_posterior, z_log_var_posterior, z_posterior = self.posterior(
                    posterior_input
                )

                # predictor_input = z_prev
                predictor_input = tf.concat(
                    (z_prev, tf.expand_dims(z_posterior, axis=1)), axis=-1
                )

                z_pred = self.frame_predictor(predictor_input)

                kl_loss += tf.reduce_sum(
                    self.kl_loss(
                        (z_mean_prior, z_log_var_prior),
                        (z_mean_posterior, z_log_var_posterior),
                    )
                ) * (1.0 / global_batch_size)

                if i > 1:
                    align_loss += tf.reduce_sum(self.align_loss(z_pred, z_curr)) * (
                        1.0 / global_batch_size
                    )

                if i == desired_length - 1:
                    h_pred_p = self.frame_predictor(
                        tf.concat([z_prev, tf.expand_dims(z_prior, axis=1)], axis=-1)
                    )
                    x_pred_p = self.decoder([h_pred_p, skip])
                    cpc_loss = tf.reduce_sum(self.mse(x_pred_p, current_frame)) * (
                        1.0 / global_batch_size
                    )

                prediction = self.decoder([z_pred, skip])
                reconstruction_loss += tf.reduce_sum(
                    self.mse(prediction, current_frame)
                ) * (1.0 / global_batch_size)

                previous_frame = current_frame

            loss = (
                reconstruction_loss
                + kl_loss * self.beta
                + align_loss * self.weight_align
                + cpc_loss * self.weight_cpc
            )

            prior_loss = kl_loss + cpc_loss * self.weight_cpc

        grads_without_prior = tape.gradient(
            loss,
            (
                self.encoder.trainable_weights
                + self.decoder.trainable_weights
                + self.posterior.trainable_weights
                + self.frame_predictor.trainable_weights
            ),
        )
        self.optimizer.apply_gradients(
            zip(
                grads_without_prior,
                (
                    self.encoder.trainable_weights
                    + self.decoder.trainable_weights
                    + self.posterior.trainable_weights
                    + self.frame_predictor.trainable_weights
                ),
            )
        )

        grads_prior = tape.gradient(
            prior_loss,
            self.prior.trainable_weights,
        )

        self.optimizer.apply_gradients(
            zip(
                grads_prior,
                self.prior.trainable_weights,
            )
        )
        del tape

        self.total_loss_tracker.update_state(loss)
        self.kl_loss_tracker.update_state(kl_loss)
        self.align_loss_tracker.update_state(align_loss)
        self.reconstruction_loss_tracker.update_state(reconstruction_loss)
        self.cpc_loss_tracker.update_state(cpc_loss)

        return {
            "loss": self.total_loss_tracker.result(),
            "reconstruction_loss": self.reconstruction_loss_tracker.result(),
            "kl_loss": self.kl_loss_tracker.result(),
            "align_loss": self.align_loss_tracker.result(),
            "cpc_loss": self.cpc_loss_tracker.result(),
        }

        # print("KL_LOSS")
        # print(kl_loss)
        # print("ALIGN_LOSS")
        # print(align_loss)
        # print("RECONSTRUCTION_LOSS")
        # print(reconstruction_loss)

        # with tf.GradientTape() as tape:
        #     z_mean, z_log_var, z = self.encoder(x)
        #     reconstruction = self.decoder(z)
        #     reconstruction_loss = tf.reduce_mean(
        #         tf.reduce_sum(
        #             tf.keras.losses.binary_crossentropy(y, reconstruction),
        #             axis=(1, 2),
        #         )
        #     )
        #     kl_loss = -0.5 * (1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var))
        #     kl_loss = tf.reduce_mean(tf.reduce_sum(kl_loss, axis=1))
        #     total_loss = reconstruction_loss + self.kl_beta * kl_loss
        # grads = tape.gradient(total_loss, self.trainable_weights)
        # self.optimizer.apply_gradients(zip(grads, self.trainable_weights))
        # self.total_loss_tracker.update_state(total_loss)
        # self.reconstruction_loss_tracker.update_state(reconstruction_loss)
        # self.kl_loss_tracker.update_state(kl_loss)
        # return {
        #     "loss": self.total_loss_tracker.result(),
        #     "reconstruction_loss": self.reconstruction_loss_tracker.result(),
        #     "kl_loss": self.kl_loss_tracker.result(),
        # }

    # def test_step(self, data):
    #     if isinstance(data, tuple):
    #         data = data[0]

    #     z_mean, z_log_var, z = self.encoder(data)
    #     reconstruction = self.decoder(z)
    #     reconstruction_loss = tf.reduce_mean(
    #         tf.keras.losses.binary_crossentropy(data, reconstruction)
    #     )
    #     reconstruction_loss *= 28 * 28
    #     kl_loss = 1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var)
    #     kl_loss = tf.reduce_mean(kl_loss)
    #     kl_loss *= -0.5
    #     total_loss = reconstruction_loss + kl_loss
    #     return {
    #         "loss": total_loss,
    #         "reconstruction_loss": reconstruction_loss,
    #         "kl_loss": kl_loss,
    #     }
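KLCriterion above implements the closed-form KL divergence between two diagonal Gaussians parameterised by (mu, logvar). A small standalone sketch of the same formula, written directly in terms of logvar (using the identity log(sigma2 / sigma1) = 0.5 * (logvar2 - logvar1)), which can be handy for sanity-checking the loss term:

    import tensorflow as tf

    def gaussian_kl(mu1, logvar1, mu2, logvar2):
        """KL( N(mu1, exp(logvar1)) || N(mu2, exp(logvar2)) ), element-wise."""
        return (
            0.5 * (logvar2 - logvar1)  # log(sigma2 / sigma1)
            + (tf.exp(logvar1) + tf.square(mu1 - mu2)) / (2.0 * tf.exp(logvar2))
            - 0.5
        )

    mu = tf.zeros((4, 10))
    logvar = tf.zeros((4, 10))
    print(gaussian_kl(mu, logvar, mu, logvar))  # all zeros: identical distributions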
ganime/model/p2p/p2p_v3.py
ADDED
@@ -0,0 +1,237 @@
import numpy as np
import tensorflow as tf
from tensorflow.keras import Model, Sequential
from tensorflow.keras.layers import (
    LSTM,
    Activation,
    BatchNormalization,
    Conv2D,
    Conv2DTranspose,
    Conv3D,
    Conv3DTranspose,
    Dense,
    Flatten,
    Input,
    Layer,
    LeakyReLU,
    MaxPooling2D,
    Reshape,
    TimeDistributed,
    UpSampling2D,
)


SEQ_LEN = 20


class Sampling(Layer):
    """Uses (z_mean, z_log_var) to sample z, the vector encoding a digit."""

    def call(self, inputs):
        z_mean, z_log_var = inputs
        batch = tf.shape(z_mean)[0]
        dim = tf.shape(z_mean)[1]
        epsilon = tf.keras.backend.random_normal(shape=(batch, dim))
        return z_mean + tf.exp(0.5 * z_log_var) * epsilon

    def compute_output_shape(self, input_shape):
        return input_shape[0]


class P2P(Model):
    def __init__(
        self,
        channels: int = 1,
        g_dim: int = 128,
        z_dim: int = 10,
        rnn_size: int = 256,
        prior_rnn_layers: int = 1,
        posterior_rnn_layers: int = 1,
        predictor_rnn_layers: float = 1,
        skip_prob: float = 0.1,
        n_past: int = 1,
        last_frame_skip: bool = False,
        beta: float = 0.0001,
        weight_align: float = 0.1,
        weight_cpc: float = 100,
    ):
        super().__init__()
        # Models parameters
        self.channels = channels
        self.g_dim = g_dim
        self.z_dim = z_dim
        self.rnn_size = rnn_size
        self.prior_rnn_layers = prior_rnn_layers
        self.posterior_rnn_layers = posterior_rnn_layers
        self.predictor_rnn_layers = predictor_rnn_layers

        # Training parameters
        self.skip_prob = skip_prob
        self.n_past = n_past
        self.last_frame_skip = last_frame_skip
        self.beta = beta
        self.weight_align = weight_align
        self.weight_cpc = weight_cpc

        self.frame_predictor = self.build_lstm()
        self.prior = self.build_gaussian_lstm()
        self.posterior = self.build_gaussian_lstm()
        self.encoder = self.build_encoder()
        self.decoder = self.build_decoder()

        self.total_loss_tracker = tf.keras.metrics.Mean(name="total_loss")
        self.reconstruction_loss_tracker = tf.keras.metrics.Mean(
            name="reconstruction_loss"
        )
        self.kl_loss_tracker = tf.keras.metrics.Mean(name="kl_loss")

    # region Model building
    def build_lstm(self):
        input = Input(shape=(20, self.g_dim + self.z_dim + 1))
        embed = TimeDistributed(Dense(self.rnn_size))(input)
        lstm = LSTM(self.rnn_size, return_sequences=True)(embed)
        output = TimeDistributed(Dense(self.g_dim))(lstm)

        return Model(inputs=input, outputs=output, name="frame_predictor")

    def build_gaussian_lstm(self):

        input = Input(shape=(20, self.g_dim))
        embed = TimeDistributed(Dense(self.rnn_size))(input)
        lstm = LSTM(self.rnn_size, return_sequences=True)(embed)
        mu = TimeDistributed(Dense(self.z_dim))(lstm)
        logvar = TimeDistributed(Dense(self.z_dim))(lstm)
        z = TimeDistributed(Sampling())([mu, logvar])

        return Model(inputs=input, outputs=[mu, logvar, z])

    def build_encoder(self):

        input = Input(shape=(2, 64, 64, 1))

        h = TimeDistributed(Conv2D(64, kernel_size=4, strides=2, padding="same"))(input)
        h = BatchNormalization()(h)
        h = LeakyReLU(alpha=0.2)(h)
        # h = TimeDistributed(MaxPooling2D(pool_size=2, strides=2, padding="same"))(h)

        h = TimeDistributed(Conv2D(128, kernel_size=4, strides=2, padding="same"))(h)
        h = BatchNormalization()(h)
        h = LeakyReLU(alpha=0.2)(h)
        # h = TimeDistributed(MaxPooling2D(pool_size=2, strides=2, padding="same"))(h)

        h = TimeDistributed(Conv2D(256, kernel_size=4, strides=2, padding="same"))(h)
        h = BatchNormalization()(h)
        h = LeakyReLU(alpha=0.2)(h)
        # h = TimeDistributed(MaxPooling2D(pool_size=2, strides=2, padding="same"))(h)

        h = TimeDistributed(Conv2D(512, kernel_size=4, strides=2, padding="same"))(h)
        h = BatchNormalization()(h)
        h = LeakyReLU(alpha=0.2)(h)
        # h = TimeDistributed(MaxPooling2D(pool_size=2, strides=2, padding="same"))(h)

        h = Flatten()(h)
        # mu = Dense(self.g_dim)(h)
        # logvar = Dense(self.g_dim)(h)

        # z = Sampling()([mu, logvar])
        lstm_input = Dense(self.g_dim * SEQ_LEN)(h)
        lstm_input = Reshape((SEQ_LEN, self.g_dim))(lstm_input)
        mu, logvar, z = self.posterior(lstm_input)

        return Model(inputs=input, outputs=[mu, logvar, z], name="encoder")

    def build_decoder(self):
        latent_inputs = Input(shape=(SEQ_LEN, self.z_dim))
        x = Dense(1 * 1 * 1 * 512, activation="relu")(latent_inputs)
        x = Reshape((SEQ_LEN, 1, 1, 512))(x)
        x = TimeDistributed(
            Conv2DTranspose(512, kernel_size=4, strides=1, padding="valid")
        )(x)
        x = BatchNormalization()(x)
        x = LeakyReLU(alpha=0.2)(x)

        x = TimeDistributed(
            Conv2DTranspose(256, kernel_size=4, strides=2, padding="same")
        )(x)
        x = BatchNormalization()(x)
        x = LeakyReLU(alpha=0.2)(x)

        x = TimeDistributed(
            Conv2DTranspose(128, kernel_size=4, strides=2, padding="same")
        )(x)
        x = BatchNormalization()(x)
        x = LeakyReLU(alpha=0.2)(x)

        x = TimeDistributed(
            Conv2DTranspose(64, kernel_size=4, strides=2, padding="same")
        )(x)
        x = BatchNormalization()(x)
        x = LeakyReLU(alpha=0.2)(x)

        x = TimeDistributed(
            Conv2DTranspose(1, kernel_size=4, strides=2, padding="same")
        )(x)
        x = Activation("sigmoid")(x)

        return Model(inputs=latent_inputs, outputs=x, name="decoder")

    # endregion

    @property
    def metrics(self):
        return [
            self.total_loss_tracker,
            self.reconstruction_loss_tracker,
            self.kl_loss_tracker,
        ]

    def call(self, inputs, training=None, mask=None):
        z_mean, z_log_var, z = self.encoder(inputs)
        pred = self.decoder(z)
        return pred

    def train_step(self, data):
        x, y = data

        with tf.GradientTape() as tape:
            z_mean, z_log_var, z = self.encoder(x)
            reconstruction = self.decoder(z)
            reconstruction_loss = tf.reduce_mean(
                tf.reduce_sum(
                    tf.keras.losses.binary_crossentropy(y, reconstruction),
                    axis=(1, 2),
                )
            )
            kl_loss = -0.5 * (1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var))
            kl_loss = tf.reduce_mean(tf.reduce_sum(kl_loss, axis=1))
            total_loss = reconstruction_loss + self.beta * kl_loss
        grads = tape.gradient(total_loss, self.trainable_weights)
        self.optimizer.apply_gradients(zip(grads, self.trainable_weights))
        self.total_loss_tracker.update_state(total_loss)
        self.reconstruction_loss_tracker.update_state(reconstruction_loss)
        self.kl_loss_tracker.update_state(kl_loss)
        return {
            "loss": self.total_loss_tracker.result(),
            "reconstruction_loss": self.reconstruction_loss_tracker.result(),
            "kl_loss": self.kl_loss_tracker.result(),
        }

    def test_step(self, data):
        if isinstance(data, tuple):
            data = data[0]

        z_mean, z_log_var, z = self.encoder(data)
        reconstruction = self.decoder(z)
        reconstruction_loss = tf.reduce_mean(
            tf.keras.losses.binary_crossentropy(data, reconstruction)
        )
        reconstruction_loss *= 28 * 28
        kl_loss = 1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var)
        kl_loss = tf.reduce_mean(kl_loss)
        kl_loss *= -0.5
        total_loss = reconstruction_loss + kl_loss
        return {
            "loss": total_loss,
            "reconstruction_loss": reconstruction_loss,
            "kl_loss": kl_loss,
        }
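The Sampling layer used throughout these model variants is the standard reparameterisation trick: z = mu + exp(0.5 * logvar) * eps with eps drawn from a standard normal, which keeps the sampling step differentiable with respect to mu and logvar. A minimal standalone illustration of that computation:

    import tensorflow as tf

    mu = tf.constant([[0.0, 1.0]])
    logvar = tf.constant([[0.0, -2.0]])       # sigma = exp(0.5 * logvar)
    epsilon = tf.random.normal(tf.shape(mu))  # eps ~ N(0, I)
    z = mu + tf.exp(0.5 * logvar) * epsilon   # differentiable in mu and logvar
    print(z)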
ganime/model/vae/vae.py
ADDED
@@ -0,0 +1,98 @@
import numpy as np
import matplotlib.pyplot as plt

from tensorflow import keras
from tensorflow.keras import layers
import tensorflow as tf

input_shape = (20, 64, 64, 1)


class Sampling(keras.layers.Layer):
    """Uses (z_mean, z_log_var) to sample z, the vector encoding a digit."""

    def call(self, inputs):
        z_mean, z_log_var = inputs
        batch = tf.shape(z_mean)[0]
        dim = z_mean.shape[1:]
        epsilon = tf.keras.backend.random_normal(shape=(batch, *dim))
        return z_mean + tf.exp(0.5 * z_log_var) * epsilon

    def compute_output_shape(self, input_shape):
        return input_shape[0]


class VAE(keras.Model):
    def __init__(self, latent_dim: int = 32, num_embeddings: int = 128, beta: float = 0.5, **kwargs):
        super().__init__(**kwargs)
        self.latent_dim = latent_dim
        self.num_embeddings = num_embeddings
        self.beta = beta

        self.encoder = self.get_encoder()
        self.decoder = self.get_decoder()

        self.total_loss_tracker = tf.keras.metrics.Mean(name="total_loss")
        self.reconstruction_loss_tracker = tf.keras.metrics.Mean(
            name="reconstruction_loss"
        )
        self.kl_loss_tracker = tf.keras.metrics.Mean(name="kl_loss")

    def get_encoder(self):
        encoder_inputs = keras.Input(shape=input_shape)
        x = layers.TimeDistributed(layers.Conv2D(32, 3, activation="relu", strides=2, padding="same"))(
            encoder_inputs
        )
        x = layers.TimeDistributed(layers.Conv2D(64, 3, activation="relu", strides=2, padding="same"))(x)
        x = layers.TimeDistributed(layers.Conv2D(self.latent_dim, 1, padding="same"))(x)

        x = layers.TimeDistributed(layers.Flatten())(x)
        mu = layers.TimeDistributed(layers.Dense(self.num_embeddings))(x)
        logvar = layers.TimeDistributed(layers.Dense(self.num_embeddings))(x)
        z = Sampling()([mu, logvar])

        return keras.Model(encoder_inputs, [mu, logvar, z], name="encoder")

    def get_decoder(self):
        latent_inputs = keras.Input(shape=self.encoder.output[2].shape[1:])

        x = layers.TimeDistributed(layers.Dense(16 * 16 * 32, activation="relu"))(latent_inputs)
        x = layers.TimeDistributed(layers.Reshape((16, 16, 32)))(x)
        x = layers.TimeDistributed(layers.Conv2DTranspose(64, 3, activation="relu", strides=2, padding="same"))(
            x
        )
        x = layers.TimeDistributed(layers.Conv2DTranspose(32, 3, activation="relu", strides=2, padding="same"))(x)
        decoder_outputs = layers.TimeDistributed(layers.Conv2DTranspose(1, 3, padding="same"))(x)
        return keras.Model(latent_inputs, decoder_outputs, name="decoder")

    def train_step(self, data):
        x, y = data

        with tf.GradientTape() as tape:
            mu, logvar, z = self.encoder(x)
            reconstruction = self.decoder(z)
            reconstruction_loss = tf.reduce_mean(
                tf.reduce_sum(
                    tf.keras.losses.binary_crossentropy(y, reconstruction),
                    axis=(1, 2),
                )
            )
            kl_loss = -0.5 * (1 + logvar - tf.square(mu) - tf.exp(logvar))
            kl_loss = tf.reduce_mean(tf.reduce_sum(kl_loss, axis=1))
            total_loss = reconstruction_loss + self.beta * kl_loss
        grads = tape.gradient(total_loss, self.trainable_weights)
        self.optimizer.apply_gradients(zip(grads, self.trainable_weights))
        self.total_loss_tracker.update_state(total_loss)
        self.reconstruction_loss_tracker.update_state(reconstruction_loss)
        self.kl_loss_tracker.update_state(kl_loss)
        return {
            "loss": self.total_loss_tracker.result(),
            "reconstruction_loss": self.reconstruction_loss_tracker.result(),
            "kl_loss": self.kl_loss_tracker.result(),
        }

    def call(self, inputs, training=False, mask=None):
        z_mean, z_log_var, z = self.encoder(inputs)
        pred = self.decoder(z)
        return pred
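A minimal training sketch for this VAE, assuming the module is importable as ganime.model.vae.vae (matching the file path) and that the data matches the hard-coded input_shape of (20, 64, 64, 1); random tensors stand in for real video clips here:

    import numpy as np
    import tensorflow as tf
    from ganime.model.vae.vae import VAE  # assumed import path

    video = np.random.rand(8, 20, 64, 64, 1).astype("float32")  # placeholder clips
    model = VAE(latent_dim=32, num_embeddings=128, beta=0.5)
    model.compile(optimizer=tf.keras.optimizers.Adam(1e-3))
    model.fit(video, video, batch_size=4, epochs=1)  # train_step unpacks (x, y) pairs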
ganime/model/vq_vae/vq_vae.py
ADDED
@@ -0,0 +1,143 @@
import numpy as np
import matplotlib.pyplot as plt

from tensorflow import keras
from tensorflow.keras import layers
import tensorflow as tf

input_shape = (20, 64, 64, 1)


class VectorQuantizer(layers.Layer):
    def __init__(self, num_embeddings, embedding_dim, beta=0.25, **kwargs):
        super().__init__(**kwargs)
        self.embedding_dim = embedding_dim
        self.num_embeddings = num_embeddings
        self.beta = (
            beta  # This parameter is best kept between [0.25, 2] as per the paper.
        )

        # Initialize the embeddings which we will quantize.
        w_init = tf.random_uniform_initializer()
        self.embeddings = tf.Variable(
            initial_value=w_init(
                shape=(self.embedding_dim, self.num_embeddings), dtype="float32"
            ),
            trainable=True,
            name="embeddings_vqvae",
        )

    def call(self, x):
        # Calculate the input shape of the inputs and
        # then flatten the inputs keeping `embedding_dim` intact.
        input_shape = tf.shape(x)
        flattened = tf.reshape(x, [-1, self.embedding_dim])

        # Quantization.
        encoding_indices = self.get_code_indices(flattened)
        encodings = tf.one_hot(encoding_indices, self.num_embeddings)
        quantized = tf.matmul(encodings, self.embeddings, transpose_b=True)
        quantized = tf.reshape(quantized, input_shape)

        # Calculate vector quantization loss and add that to the layer. You can learn more
        # about adding losses to different layers here:
        # https://keras.io/guides/making_new_layers_and_models_via_subclassing/. Check
        # the original paper to get a handle on the formulation of the loss function.
        commitment_loss = self.beta * tf.reduce_mean(
            (tf.stop_gradient(quantized) - x) ** 2
        )
        codebook_loss = tf.reduce_mean((quantized - tf.stop_gradient(x)) ** 2)
        self.add_loss(commitment_loss + codebook_loss)

        # Straight-through estimator.
        quantized = x + tf.stop_gradient(quantized - x)
        return quantized

    def get_code_indices(self, flattened_inputs):
        # Calculate L2-normalized distance between the inputs and the codes.
        similarity = tf.matmul(flattened_inputs, self.embeddings)
        distances = (
            tf.reduce_sum(flattened_inputs ** 2, axis=1, keepdims=True)
            + tf.reduce_sum(self.embeddings ** 2, axis=0)
            - 2 * similarity
        )

        # Derive the indices for minimum distances.
        encoding_indices = tf.argmin(distances, axis=1)
        return encoding_indices


class VQVAE(keras.Model):
    def __init__(self, train_variance: float, latent_dim: int = 32, num_embeddings: int = 128, **kwargs):
        super().__init__(**kwargs)
        self.train_variance = train_variance
        self.latent_dim = latent_dim
        self.num_embeddings = num_embeddings

        self.vqvae = self.get_vqvae()

        self.total_loss_tracker = keras.metrics.Mean(name="total_loss")
        self.reconstruction_loss_tracker = keras.metrics.Mean(
            name="reconstruction_loss"
        )
        self.vq_loss_tracker = keras.metrics.Mean(name="vq_loss")

    def get_encoder(self):
        encoder_inputs = keras.Input(shape=input_shape)
        x = layers.TimeDistributed(layers.Conv2D(32, 3, activation="relu", strides=2, padding="same"))(
            encoder_inputs
        )
        x = layers.TimeDistributed(layers.Conv2D(64, 3, activation="relu", strides=2, padding="same"))(x)
        encoder_outputs = layers.TimeDistributed(layers.Conv2D(self.latent_dim, 1, padding="same"))(x)
        return keras.Model(encoder_inputs, encoder_outputs, name="encoder")

    def get_decoder(self):
        latent_inputs = keras.Input(shape=self.get_encoder().output.shape[1:])
        x = layers.TimeDistributed(layers.Conv2DTranspose(64, 3, activation="relu", strides=2, padding="same"))(
            latent_inputs
        )
        x = layers.TimeDistributed(layers.Conv2DTranspose(32, 3, activation="relu", strides=2, padding="same"))(x)
        decoder_outputs = layers.TimeDistributed(layers.Conv2DTranspose(1, 3, padding="same"))(x)
        return keras.Model(latent_inputs, decoder_outputs, name="decoder")

    def get_vqvae(self):
        self.vq_layer = VectorQuantizer(self.num_embeddings, self.latent_dim, name="vector_quantizer")
        self.encoder = self.get_encoder()
        self.decoder = self.get_decoder()
        inputs = keras.Input(shape=input_shape)
        encoder_outputs = self.encoder(inputs)
        quantized_latents = self.vq_layer(encoder_outputs)
        reconstructions = self.decoder(quantized_latents)
        return keras.Model(inputs, reconstructions, name="vq_vae")

    def train_step(self, data):
        x, y = data
        with tf.GradientTape() as tape:
            # Outputs from the VQ-VAE.
            reconstructions = self.vqvae(x)

            # Calculate the losses.
            reconstruction_loss = (
                tf.reduce_mean((y - reconstructions) ** 2) / self.train_variance
            )
            total_loss = reconstruction_loss + sum(self.vqvae.losses)

        # Backpropagation.
        grads = tape.gradient(total_loss, self.vqvae.trainable_variables)
        self.optimizer.apply_gradients(zip(grads, self.vqvae.trainable_variables))

        # Loss tracking.
        self.total_loss_tracker.update_state(total_loss)
        self.reconstruction_loss_tracker.update_state(reconstruction_loss)
        self.vq_loss_tracker.update_state(sum(self.vqvae.losses))

        # Log results.
        return {
            "loss": self.total_loss_tracker.result(),
            "reconstruction_loss": self.reconstruction_loss_tracker.result(),
            "vqvae_loss": self.vq_loss_tracker.result(),
        }

    def call(self, inputs, training=False, mask=None):
        return self.vqvae(inputs)
ganime/model/vqgan/__init__.py
ADDED
File without changes
ganime/model/vqgan/discriminator/__init__.py
ADDED
File without changes
ganime/model/vqgan/discriminator/model.py
ADDED
@@ -0,0 +1,64 @@
+
from typing import List
|
2 |
+
import numpy as np
|
3 |
+
import tensorflow as tf
|
4 |
+
from tensorflow import keras
|
5 |
+
from tensorflow.keras import Model, Sequential
|
6 |
+
from tensorflow.keras import layers
|
7 |
+
|
8 |
+
|
9 |
+
class NLayerDiscriminator(Model):
|
10 |
+
"""Defines a PatchGAN discriminator as in Pix2Pix
|
11 |
+
--> see https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py
|
12 |
+
"""
|
13 |
+
|
14 |
+
def __init__(self, input_channels: int = 3, filters: int = 64, n_layers: int = 3):
|
15 |
+
super().__init__()
|
16 |
+
|
17 |
+
kernel_size = 4
|
18 |
+
self.sequence = [
|
19 |
+
layers.Conv2D(filters, kernel_size=kernel_size, padding="same"),
|
20 |
+
layers.LeakyReLU(alpha=0.2),
|
21 |
+
]
|
22 |
+
|
23 |
+
filters_mult = 1
|
24 |
+
for n in range(1, n_layers):
|
25 |
+
filters_mult = min(2**n, 8)
|
26 |
+
|
27 |
+
self.sequence += [
|
28 |
+
layers.AveragePooling2D(pool_size=2),
|
29 |
+
layers.Conv2D(
|
30 |
+
filters * filters_mult,
|
31 |
+
kernel_size=kernel_size,
|
32 |
+
strides=1, # 2,
|
33 |
+
padding="same",
|
34 |
+
use_bias=False,
|
35 |
+
),
|
36 |
+
layers.BatchNormalization(),
|
37 |
+
layers.LeakyReLU(alpha=0.2),
|
38 |
+
]
|
39 |
+
|
40 |
+
filters_mult = min(2**n_layers, 8)
|
41 |
+
self.sequence += [
|
42 |
+
layers.Conv2D(
|
43 |
+
filters * filters_mult,
|
44 |
+
kernel_size=kernel_size,
|
45 |
+
strides=1,
|
46 |
+
padding="same",
|
47 |
+
use_bias=False,
|
48 |
+
),
|
49 |
+
layers.BatchNormalization(),
|
50 |
+
layers.LeakyReLU(alpha=0.2),
|
51 |
+
]
|
52 |
+
|
53 |
+
self.sequence += [
|
54 |
+
layers.Conv2D(1, kernel_size=kernel_size, strides=1, padding="same")
|
55 |
+
]
|
56 |
+
|
57 |
+
# self.main = Sequential(sequence)
|
58 |
+
|
59 |
+
def call(self, inputs, training=True, mask=None):
|
60 |
+
h = inputs
|
61 |
+
for seq in self.sequence:
|
62 |
+
h = seq(h)
|
63 |
+
return h
|
64 |
+
# return self.main(inputs)
|
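Because this is a PatchGAN, the discriminator returns a grid of per-patch logits rather than a single real/fake score. A quick shape check; the (4, 16, 16, 1) output assumes 64x64 inputs and the default n_layers=3, whose two pooling stages downsample by a factor of 4, and the import path is taken from the file location:

    import tensorflow as tf
    from ganime.model.vqgan.discriminator.model import NLayerDiscriminator  # assumed import path

    disc = NLayerDiscriminator(input_channels=3, filters=64, n_layers=3)
    images = tf.random.uniform((4, 64, 64, 3))
    logits = disc(images)
    print(logits.shape)  # (4, 16, 16, 1): a 16x16 grid of patch logits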
ganime/model/vqgan/losses/__init__.py
ADDED
File without changes
ganime/model/vqgan/losses/lpips.py
ADDED
@@ -0,0 +1,134 @@
1 |
+
import os
|
2 |
+
import numpy as np
|
3 |
+
import tensorflow as tf
|
4 |
+
import torchvision.models as models
|
5 |
+
from tensorflow import keras
|
6 |
+
from tensorflow.keras import Model, Sequential
|
7 |
+
from tensorflow.keras import backend as K
|
8 |
+
from tensorflow.keras import layers
|
9 |
+
from tensorflow.keras.applications.vgg16 import VGG16, preprocess_input
|
10 |
+
from tensorflow.keras.losses import Loss
|
11 |
+
from pyprojroot.pyprojroot import here
|
12 |
+
|
13 |
+
|
14 |
+
def normalize_tensor(x, eps=1e-10):
|
15 |
+
norm_factor = tf.sqrt(tf.reduce_sum(x**2, axis=-1, keepdims=True))
|
16 |
+
return x / (norm_factor + eps)
|
17 |
+
|
18 |
+
|
19 |
+
class LPIPS(Loss):
|
20 |
+
    def __init__(self, use_dropout=True, **kwargs):
        super().__init__(**kwargs)

        self.scaling_layer = ScalingLayer()  # preprocess_input
        selected_layers = [
            "block1_conv2",
            "block2_conv2",
            "block3_conv3",
            "block4_conv3",
            "block5_conv3",
        ]

        # TODO here we load the same weights as pytorch, try with tensorflow weights
        self.net = self.load_vgg16()  # VGG16(weights="imagenet", include_top=False)
        self.net.trainable = False
        outputs = [self.net.get_layer(layer).output for layer in selected_layers]

        self.model = Model(self.net.input, outputs)
        self.lins = [NetLinLayer(use_dropout=use_dropout) for _ in selected_layers]

        # TODO: here we use the pytorch weights of the linear layers, try without these layers, or without initializing the weights
        self(tf.zeros((1, 16, 16, 1)), tf.zeros((1, 16, 16, 1)))
        self.init_lin_layers()

    def load_vgg16(self) -> Model:
        """Load a VGG16 model with the same weights as PyTorch
        https://github.com/ezavarygin/vgg16_pytorch2keras
        """
        pytorch_model = models.vgg16(pretrained=True)
        # select weights in the conv2d layers and transpose them to keras dim ordering:
        wblist_torch = list(pytorch_model.parameters())[:26]
        wblist_keras = []
        for i in range(len(wblist_torch)):
            if wblist_torch[i].dim() == 4:
                w = np.transpose(wblist_torch[i].detach().numpy(), axes=[2, 3, 1, 0])
                wblist_keras.append(w)
            elif wblist_torch[i].dim() == 1:
                b = wblist_torch[i].detach().numpy()
                wblist_keras.append(b)
            else:
                raise Exception("Fully connected layers are not implemented.")

        keras_model = VGG16(include_top=False, weights=None)
        keras_model.set_weights(wblist_keras)
        return keras_model

    def init_lin_layers(self):
        for i in range(5):
            weights = np.load(
                os.path.join(here(), "models", "NetLinLayer", f"numpy_{i}.npy")
            )
            weights = np.moveaxis(weights, 1, 2)
            self.lins[i].model.layers[1].set_weights([weights])

    def call(self, y_true, y_pred):

        scaled_true = self.scaling_layer(y_true)
        scaled_pred = self.scaling_layer(y_pred)

        outputs_true, outputs_pred = self.model(scaled_true), self.model(scaled_pred)
        features_true, features_pred, diffs = {}, {}, {}

        for kk in range(len(outputs_true)):
            features_true[kk], features_pred[kk] = normalize_tensor(
                outputs_true[kk]
            ), normalize_tensor(outputs_pred[kk])

            diffs[kk] = (features_true[kk] - features_pred[kk]) ** 2

        res = [
            tf.reduce_mean(self.lins[kk](diffs[kk]), axis=(-3, -2), keepdims=True)
            for kk in range(len(outputs_true))
        ]

        return tf.reduce_sum(res)

        # h1_list = self.model(self.scaling_layer(y_true))
        # h2_list = self.model(self.scaling_layer(y_pred))

        # rc_loss = 0.0
        # for h1, h2 in zip(h1_list, h2_list):
        #     h1 = K.batch_flatten(h1)
        #     h2 = K.batch_flatten(h2)
        #     rc_loss += K.sum(K.square(h1 - h2), axis=-1)

        # return rc_loss


class ScalingLayer(layers.Layer):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.shift = tf.Variable([-0.030, -0.088, -0.188])
        self.scale = tf.Variable([0.458, 0.448, 0.450])

    def call(self, inputs):
        return (inputs - self.shift) / self.scale


class NetLinLayer(layers.Layer):
    def __init__(self, channels_out=1, use_dropout=False):
        super().__init__()
        sequence = (
            [
                layers.Dropout(0.5),
            ]
            if use_dropout
            else []
        )
        sequence += [
            layers.Conv2D(channels_out, 1, padding="same", use_bias=False),
        ]
        self.model = Sequential(sequence)

    def call(self, inputs):
        return self.model(inputs)
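Note (illustrative sketch, not part of the uploaded file): a minimal way to call the LPIPS loss above. It assumes the torchvision VGG16 weights and the models/NetLinLayer/numpy_*.npy files referenced in init_lin_layers are available locally; shapes are illustrative only.

import tensorflow as tf
from ganime.model.vqgan.losses.lpips import LPIPS

# Build the loss once; this loads the VGG16 weights and the linear-layer weights.
lpips = LPIPS(reduction=tf.keras.losses.Reduction.NONE)

y_true = tf.random.uniform((4, 64, 128, 3))  # reference frames in [0, 1]
y_pred = tf.random.uniform((4, 64, 128, 3))  # reconstructed frames
distance = lpips(y_true, y_pred)  # scalar: sum of the five per-layer feature distances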
ganime/model/vqgan/losses/vqperceptual.py
ADDED
@@ -0,0 +1,47 @@
from typing import List, Literal

import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import Model, layers
from tensorflow.keras.losses import Loss

from .lpips import LPIPS

from ..discriminator.model import NLayerDiscriminator


class VQLPIPSWithDiscriminator(Loss):
    def __init__(
        self, *, pixelloss_weight: float = 1.0, perceptual_weight: float = 1.0, **kwargs
    ):
        super().__init__(**kwargs)
        self.pixelloss_weight = pixelloss_weight
        self.perceptual_loss = LPIPS(reduction=tf.keras.losses.Reduction.NONE)
        self.perceptual_weight = perceptual_weight

    def call(
        self,
        y_true,
        y_pred,
    ):
        reconstruction_loss = tf.abs(y_true - y_pred)
        if self.perceptual_weight > 0:
            perceptual_loss = self.perceptual_loss(y_true, y_pred)
            reconstruction_loss += self.perceptual_weight * perceptual_loss
        else:
            perceptual_loss = 0.0

        neg_log_likelihood = tf.reduce_mean(reconstruction_loss)

        return neg_log_likelihood

        # # GAN part
        # if optimizer_idx == 0:
        #     if cond is None:
        #         assert not self.disc_conditional
        #         logits_fake = self.discriminator(y_pred)
        #     else:
        #         assert self.disc_conditional
        #         logits_fake = self.discriminator(tf.concat([y_pred, cond], axis=-1))
        #     g_loss = -tf.reduce_mean(logits_fake)
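Note (illustrative, not from the upload): the combined loss above reduces to a pixel L1 term plus a weighted LPIPS term, averaged into a single scalar. A hedged sketch of calling it directly, under the same asset assumptions as lpips.py:

import tensorflow as tf
from ganime.model.vqgan.losses.vqperceptual import VQLPIPSWithDiscriminator

loss_fn = VQLPIPSWithDiscriminator(reduction=tf.keras.losses.Reduction.NONE)
y_true = tf.random.uniform((2, 64, 128, 3))
y_pred = tf.random.uniform((2, 64, 128, 3))
# mean over the batch of |y_true - y_pred| + perceptual_weight * LPIPS(y_true, y_pred)
nll = loss_fn(y_true, y_pred)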
ganime/model/vqgan/vqgan.py
ADDED
@@ -0,0 +1,722 @@
from typing import List, Literal

import numpy as np
import tensorflow as tf
from .discriminator.model import NLayerDiscriminator
from .losses.vqperceptual import VQLPIPSWithDiscriminator
from tensorflow import keras
from tensorflow.keras import Model, layers, Sequential
from tensorflow.keras.optimizers import Optimizer
from tensorflow_addons.layers import GroupNormalization

INPUT_SHAPE = (64, 128, 3)
ENCODER_OUTPUT_SHAPE = (8, 8, 128)


@tf.function
def hinge_d_loss(logits_real, logits_fake):
    loss_real = tf.reduce_mean(keras.activations.relu(1.0 - logits_real))
    loss_fake = tf.reduce_mean(keras.activations.relu(1.0 + logits_fake))
    d_loss = 0.5 * (loss_real + loss_fake)
    return d_loss


@tf.function
def vanilla_d_loss(logits_real, logits_fake):
    d_loss = 0.5 * (
        tf.reduce_mean(keras.activations.softplus(-logits_real))
        + tf.reduce_mean(keras.activations.softplus(logits_fake))
    )
    return d_loss


class VQGAN(keras.Model):
    def __init__(
        self,
        train_variance: float,
        num_embeddings: int,
        embedding_dim: int,
        beta: float = 0.25,
        z_channels: int = 128,  # 256,
        codebook_weight: float = 1.0,
        disc_num_layers: int = 3,
        disc_factor: float = 1.0,
        disc_iter_start: int = 0,
        disc_conditional: bool = False,
        disc_in_channels: int = 3,
        disc_weight: float = 0.3,
        disc_filters: int = 64,
        disc_loss: Literal["hinge", "vanilla"] = "hinge",
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.train_variance = train_variance
        self.codebook_weight = codebook_weight

        self.encoder = Encoder()
        self.decoder = Decoder()
        self.quantize = VectorQuantizer(num_embeddings, embedding_dim, beta=beta)

        self.quant_conv = layers.Conv2D(embedding_dim, kernel_size=1)
        self.post_quant_conv = layers.Conv2D(z_channels, kernel_size=1)

        self.vqvae = self.get_vqvae()

        self.perceptual_loss = VQLPIPSWithDiscriminator(
            reduction=tf.keras.losses.Reduction.NONE
        )

        self.discriminator = NLayerDiscriminator(
            input_channels=disc_in_channels,
            filters=disc_filters,
            n_layers=disc_num_layers,
        )
        self.discriminator_iter_start = disc_iter_start

        if disc_loss == "hinge":
            self.disc_loss = hinge_d_loss
        elif disc_loss == "vanilla":
            self.disc_loss = vanilla_d_loss
        else:
            raise ValueError(f"Unknown GAN loss '{disc_loss}'.")

        print(f"VQLPIPSWithDiscriminator running with {disc_loss} loss.")
        self.disc_factor = disc_factor
        self.discriminator_weight = disc_weight
        self.disc_conditional = disc_conditional

        self.total_loss_tracker = keras.metrics.Mean(name="total_loss")
        self.reconstruction_loss_tracker = keras.metrics.Mean(
            name="reconstruction_loss"
        )
        self.vq_loss_tracker = keras.metrics.Mean(name="vq_loss")
        self.disc_loss_tracker = keras.metrics.Mean(name="disc_loss")

        self.gen_optimizer: Optimizer = None
        self.disc_optimizer: Optimizer = None

    def get_vqvae(self):
        inputs = keras.Input(shape=INPUT_SHAPE)
        quant = self.encode(inputs)
        reconstructed = self.decode(quant)
        return keras.Model(inputs, reconstructed, name="vq_vae")

    def encode(self, x):
        h = self.encoder(x)
        h = self.quant_conv(h)
        return self.quantize(h)

    def decode(self, quant):
        quant = self.post_quant_conv(quant)
        dec = self.decoder(quant)
        return dec

    def call(self, inputs, training=True, mask=None):
        return self.vqvae(inputs)

    def calculate_adaptive_weight(
        self, nll_loss, g_loss, tape, trainable_vars, discriminator_weight
    ):
        nll_grads = tape.gradient(nll_loss, trainable_vars)[0]
        g_grads = tape.gradient(g_loss, trainable_vars)[0]

        d_weight = tf.norm(nll_grads) / (tf.norm(g_grads) + 1e-4)
        d_weight = tf.stop_gradient(tf.clip_by_value(d_weight, 0.0, 1e4))
        return d_weight * discriminator_weight

    @tf.function
    def adopt_weight(self, weight, global_step, threshold=0, value=0.0):
        if global_step < threshold:
            weight = value
        return weight

    def get_global_step(self, optimizer):
        return optimizer.iterations

    def compile(
        self,
        gen_optimizer,
        disc_optimizer,
    ):
        super().compile()
        self.gen_optimizer = gen_optimizer
        self.disc_optimizer = disc_optimizer

    def train_step(self, data):
        x, y = data

        # Autoencode
        with tf.GradientTape() as tape:
            with tf.GradientTape(persistent=True) as adaptive_tape:
                reconstructions = self(x, training=True)

                # Calculate the losses.
                # reconstruction_loss = (
                #     tf.reduce_mean((y - reconstructions) ** 2) / self.train_variance
                # )

                logits_fake = self.discriminator(reconstructions, training=False)

                g_loss = -tf.reduce_mean(logits_fake)
                nll_loss = self.perceptual_loss(y, reconstructions)

            d_weight = self.calculate_adaptive_weight(
                nll_loss,
                g_loss,
                adaptive_tape,
                self.decoder.conv_out.trainable_variables,
                self.discriminator_weight,
            )
            del adaptive_tape

            disc_factor = self.adopt_weight(
                weight=self.disc_factor,
                global_step=self.get_global_step(self.gen_optimizer),
                threshold=self.discriminator_iter_start,
            )

            # total_loss = reconstruction_loss + sum(self.vqvae.losses)
            total_loss = (
                nll_loss
                + d_weight * disc_factor * g_loss
                # + self.codebook_weight * tf.reduce_mean(self.vqvae.losses)
                + self.codebook_weight * sum(self.vqvae.losses)
            )

        # Backpropagation.
        grads = tape.gradient(total_loss, self.vqvae.trainable_variables)
        self.gen_optimizer.apply_gradients(zip(grads, self.vqvae.trainable_variables))

        # Discriminator
        with tf.GradientTape() as disc_tape:
            logits_real = self.discriminator(y, training=True)
            logits_fake = self.discriminator(reconstructions, training=True)

            disc_factor = self.adopt_weight(
                weight=self.disc_factor,
                global_step=self.get_global_step(self.disc_optimizer),
                threshold=self.discriminator_iter_start,
            )
            d_loss = disc_factor * self.disc_loss(logits_real, logits_fake)

        disc_grads = disc_tape.gradient(d_loss, self.discriminator.trainable_variables)
        self.disc_optimizer.apply_gradients(
            zip(disc_grads, self.discriminator.trainable_variables)
        )

        # Loss tracking.
        self.total_loss_tracker.update_state(total_loss)
        self.reconstruction_loss_tracker.update_state(nll_loss)
        self.vq_loss_tracker.update_state(sum(self.vqvae.losses))
        self.disc_loss_tracker.update_state(d_loss)

        # Log results.
        return {
            "loss": self.total_loss_tracker.result(),
            "reconstruction_loss": self.reconstruction_loss_tracker.result(),
            "vqvae_loss": self.vq_loss_tracker.result(),
            "disc_loss": self.disc_loss_tracker.result(),
        }


class VectorQuantizer(layers.Layer):
    def __init__(self, num_embeddings, embedding_dim, beta=0.25, **kwargs):
        super().__init__(**kwargs)
        self.embedding_dim = embedding_dim
        self.num_embeddings = num_embeddings
        self.beta = (
            beta  # This parameter is best kept between [0.25, 2] as per the paper.
        )

        # Initialize the embeddings which we will quantize.
        w_init = tf.random_uniform_initializer()
        self.embeddings = tf.Variable(
            initial_value=w_init(
                shape=(self.embedding_dim, self.num_embeddings)  # , dtype="float32"
            ),
            trainable=True,
            name="embeddings_vqvae",
        )

    def call(self, x):
        # Calculate the input shape of the inputs and
        # then flatten the inputs keeping `embedding_dim` intact.
        input_shape = tf.shape(x)
        flattened = tf.reshape(x, [-1, self.embedding_dim])

        # Quantization.
        encoding_indices = self.get_code_indices(flattened)
        encodings = tf.one_hot(encoding_indices, self.num_embeddings)
        quantized = tf.matmul(encodings, self.embeddings, transpose_b=True)
        quantized = tf.reshape(quantized, input_shape)

        # Calculate vector quantization loss and add that to the layer. You can learn more
        # about adding losses to different layers here:
        # https://keras.io/guides/making_new_layers_and_models_via_subclassing/. Check
        # the original paper to get a handle on the formulation of the loss function.
        commitment_loss = self.beta * tf.reduce_mean(
            (tf.stop_gradient(quantized) - x) ** 2
        )
        codebook_loss = tf.reduce_mean((quantized - tf.stop_gradient(x)) ** 2)
        self.add_loss(commitment_loss + codebook_loss)

        # Straight-through estimator.
        quantized = x + tf.stop_gradient(quantized - x)
        return quantized

    def get_code_indices(self, flattened_inputs):
        # Calculate L2-normalized distance between the inputs and the codes.
        similarity = tf.matmul(flattened_inputs, self.embeddings)
        distances = (
            tf.reduce_sum(flattened_inputs**2, axis=1, keepdims=True)
            + tf.reduce_sum(self.embeddings**2, axis=0)
            - 2 * similarity
        )

        # Derive the indices for minimum distances.
        encoding_indices = tf.argmin(distances, axis=1)
        return encoding_indices


class Encoder(Model):
    def __init__(
        self,
        *,
        channels: int = 128,
        output_channels: int = 3,
        channels_multiplier: List[int] = [1, 1, 2, 2],  # [1, 1, 2, 2, 4],
        num_res_blocks: int = 1,  # 2,
        attention_resolution: List[int] = [16],
        resolution: int = 64,  # 256,
        z_channels=128,  # 256,
        dropout=0.0,
        double_z=False,
        resamp_with_conv=True,
    ):
        super().__init__()

        self.channels = channels
        self.timestep_embeddings_channel = 0
        self.num_resolutions = len(channels_multiplier)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution

        self.conv_in = layers.Conv2D(
            self.channels, kernel_size=3, strides=1, padding="same"
        )

        current_resolution = resolution

        in_channels_multiplier = (1,) + tuple(channels_multiplier)

        self.downsampling_list = []

        for i_level in range(self.num_resolutions):
            block_in = channels * in_channels_multiplier[i_level]
            block_out = channels * channels_multiplier[i_level]
            for i_block in range(self.num_res_blocks):
                self.downsampling_list.append(
                    ResnetBlock(
                        in_channels=block_in,
                        out_channels=block_out,
                        timestep_embedding_channels=self.timestep_embeddings_channel,
                        dropout=dropout,
                    )
                )
                block_in = block_out

            if current_resolution in attention_resolution:
                # attentions.append(layers.Attention())
                self.downsampling_list.append(AttentionBlock(block_in))

            if i_level != self.num_resolutions - 1:
                self.downsampling_list.append(Downsample(block_in, resamp_with_conv))

        # self.downsampling = []

        # for i_level in range(self.num_resolutions):
        #     block = []
        #     attentions = []
        #     block_in = channels * in_channels_multiplier[i_level]
        #     block_out = channels * channels_multiplier[i_level]
        #     for i_block in range(self.num_res_blocks):
        #         block.append(
        #             ResnetBlock(
        #                 in_channels=block_in,
        #                 out_channels=block_out,
        #                 timestep_embedding_channels=self.timestep_embeddings_channel,
        #                 dropout=dropout,
        #             )
        #         )
        #         block_in = block_out

        #     if current_resolution in attention_resolution:
        #         # attentions.append(layers.Attention())
        #         attentions.append(AttentionBlock(block_in))

        #     down = {}
        #     down["block"] = block
        #     down["attention"] = attentions
        #     if i_level != self.num_resolutions - 1:
        #         down["downsample"] = Downsample(block_in, resamp_with_conv)
        #     self.downsampling.append(down)

        # middle
        self.mid = {}
        self.mid["block_1"] = ResnetBlock(
            in_channels=block_in,
            out_channels=block_in,
            timestep_embedding_channels=self.timestep_embeddings_channel,
            dropout=dropout,
        )
        self.mid["attn_1"] = AttentionBlock(block_in)
        self.mid["block_2"] = ResnetBlock(
            in_channels=block_in,
            out_channels=block_in,
            timestep_embedding_channels=self.timestep_embeddings_channel,
            dropout=dropout,
        )

        # end
        self.norm_out = GroupNormalization(groups=32, epsilon=1e-6)
        self.conv_out = layers.Conv2D(
            2 * z_channels if double_z else z_channels,
            kernel_size=3,
            strides=1,
            padding="same",
        )

    def summary(self):
        x = layers.Input(shape=INPUT_SHAPE)
        model = Model(inputs=[x], outputs=self.call(x))
        return model.summary()

    def call(self, inputs, training=True, mask=None):
        h = self.conv_in(inputs)
        for downsampling in self.downsampling_list:
            h = downsampling(h)
        # for i_level in range(self.num_resolutions):
        #     for i_block in range(self.num_res_blocks):
        #         h = self.downsampling[i_level]["block"][i_block](hs[-1])
        #         if len(self.downsampling[i_level]["attention"]) > 0:
        #             h = self.downsampling[i_level]["attention"][i_block](h)
        #         hs.append(h)
        #     if i_level != self.num_resolutions - 1:
        #         hs.append(self.downsampling[i_level]["downsample"](hs[-1]))

        # h = hs[-1]
        h = self.mid["block_1"](h)
        h = self.mid["attn_1"](h)
        h = self.mid["block_2"](h)

        # end
        h = self.norm_out(h)
        h = keras.activations.swish(h)
        h = self.conv_out(h)
        return h


class Decoder(Model):
    def __init__(
        self,
        *,
        channels: int = 128,
        output_channels: int = 3,
        channels_multiplier: List[int] = [1, 1, 2, 2],  # [1, 1, 2, 2, 4],
        num_res_blocks: int = 1,  # 2,
        attention_resolution: List[int] = [16],
        resolution: int = 64,  # 256,
        z_channels=128,  # 256,
        dropout=0.0,
        give_pre_end=False,
        resamp_with_conv=True,
    ):
        super().__init__()

        self.channels = channels
        self.timestep_embeddings_channel = 0
        self.num_resolutions = len(channels_multiplier)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.give_pre_end = give_pre_end

        in_channels_multiplier = (1,) + tuple(channels_multiplier)
        block_in = channels * channels_multiplier[-1]
        current_resolution = resolution // 2 ** (self.num_resolutions - 1)
        self.z_shape = (1, z_channels, current_resolution, current_resolution)

        print(
            "Working with z of shape {} = {} dimensions.".format(
                self.z_shape, np.prod(self.z_shape)
            )
        )

        self.conv_in = layers.Conv2D(block_in, kernel_size=3, strides=1, padding="same")

        # middle
        self.mid = {}
        self.mid["block_1"] = ResnetBlock(
            in_channels=block_in,
            out_channels=block_in,
            timestep_embedding_channels=self.timestep_embeddings_channel,
            dropout=dropout,
        )
        self.mid["attn_1"] = AttentionBlock(block_in)
        self.mid["block_2"] = ResnetBlock(
            in_channels=block_in,
            out_channels=block_in,
            timestep_embedding_channels=self.timestep_embeddings_channel,
            dropout=dropout,
        )

        # upsampling

        self.upsampling_list = []

        for i_level in reversed(range(self.num_resolutions)):
            block_out = channels * channels_multiplier[i_level]
            for i_block in range(self.num_res_blocks + 1):
                self.upsampling_list.append(
                    ResnetBlock(
                        in_channels=block_in,
                        out_channels=block_out,
                        timestep_embedding_channels=self.timestep_embeddings_channel,
                        dropout=dropout,
                    )
                )
                block_in = block_out

            if current_resolution in attention_resolution:
                # attentions.append(layers.Attention())
                self.upsampling_list.append(AttentionBlock(block_in))

            if i_level != 0:
                self.upsampling_list.append(Upsample(block_in, resamp_with_conv))
                current_resolution *= 2
            # self.upsampling.insert(0, upsampling)

        # self.upsampling = []

        # for i_level in reversed(range(self.num_resolutions)):
        #     block = []
        #     attentions = []
        #     block_out = channels * channels_multiplier[i_level]
        #     for i_block in range(self.num_res_blocks + 1):
        #         block.append(
        #             ResnetBlock(
        #                 in_channels=block_in,
        #                 out_channels=block_out,
        #                 timestep_embedding_channels=self.timestep_embeddings_channel,
        #                 dropout=dropout,
        #             )
        #         )
        #         block_in = block_out

        #     if current_resolution in attention_resolution:
        #         # attentions.append(layers.Attention())
        #         attentions.append(AttentionBlock(block_in))

        #     upsampling = {}
        #     upsampling["block"] = block
        #     upsampling["attention"] = attentions
        #     if i_level != 0:
        #         upsampling["upsample"] = Upsample(block_in, resamp_with_conv)
        #         current_resolution *= 2
        #     self.upsampling.insert(0, upsampling)

        # end
        self.norm_out = GroupNormalization(groups=32, epsilon=1e-6)
        self.conv_out = layers.Conv2D(
            output_channels,
            kernel_size=3,
            strides=1,
            activation="sigmoid",
            padding="same",
        )

    def summary(self):
        x = layers.Input(shape=ENCODER_OUTPUT_SHAPE)
        model = Model(inputs=[x], outputs=self.call(x))
        return model.summary()

    def call(self, inputs, training=True, mask=None):

        h = self.conv_in(inputs)

        # middle
        h = self.mid["block_1"](h)
        h = self.mid["attn_1"](h)
        h = self.mid["block_2"](h)

        for upsampling in self.upsampling_list:
            h = upsampling(h)

        # for i_level in reversed(range(self.num_resolutions)):
        #     for i_block in range(self.num_res_blocks + 1):
        #         h = self.upsampling[i_level]["block"][i_block](h)
        #         if len(self.upsampling[i_level]["attention"]) > 0:
        #             h = self.upsampling[i_level]["attention"][i_block](h)
        #     if i_level != 0:
        #         h = self.upsampling[i_level]["upsample"](h)

        # end
        if self.give_pre_end:
            return h

        h = self.norm_out(h)
        h = keras.activations.swish(h)
        h = self.conv_out(h)
        return h


class ResnetBlock(layers.Layer):
    def __init__(
        self,
        *,
        in_channels,
        dropout=0.0,
        out_channels=None,
        conv_shortcut=False,
        timestep_embedding_channels=512,
    ):
        super().__init__()
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels
        self.use_conv_shortcut = conv_shortcut

        self.norm1 = GroupNormalization(groups=32, epsilon=1e-6)

        self.conv1 = layers.Conv2D(
            out_channels, kernel_size=3, strides=1, padding="same"
        )

        if timestep_embedding_channels > 0:
            self.timestep_embedding_projection = layers.Dense(out_channels)

        self.norm2 = GroupNormalization(groups=32, epsilon=1e-6)
        self.dropout = layers.Dropout(dropout)

        self.conv2 = layers.Conv2D(
            out_channels, kernel_size=3, strides=1, padding="same"
        )

        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                self.conv_shortcut = layers.Conv2D(
                    out_channels, kernel_size=3, strides=1, padding="same"
                )
            else:
                self.nin_shortcut = layers.Conv2D(
                    out_channels, kernel_size=1, strides=1, padding="valid"
                )

    def call(self, x):
        h = x
        h = self.norm1(h)
        h = keras.activations.swish(h)
        h = self.conv1(h)

        # if timestamp_embedding is not None:
        #     h = h + self.timestep_embedding_projection(keras.activations.swish(timestamp_embedding))

        h = self.norm2(h)
        h = keras.activations.swish(h)
        h = self.dropout(h)
        h = self.conv2(h)

        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                x = self.conv_shortcut(x)
            else:
                x = self.nin_shortcut(x)

        return x + h


class AttentionBlock(layers.Layer):
    def __init__(self, channels):
        super().__init__()

        self.norm = GroupNormalization(groups=32, epsilon=1e-6)
        self.q = layers.Conv2D(channels, kernel_size=1, strides=1, padding="valid")
        self.k = layers.Conv2D(channels, kernel_size=1, strides=1, padding="valid")
        self.v = layers.Conv2D(channels, kernel_size=1, strides=1, padding="valid")
        self.proj_out = layers.Conv2D(
            channels, kernel_size=1, strides=1, padding="valid"
        )

    def call(self, x):
        h_ = x
        h_ = self.norm(h_)
        q = self.q(h_)
        k = self.k(h_)
        v = self.v(h_)

        # compute attention
        (
            b,
            h,
            w,
            c,
        ) = q.shape
        if b is None:
            b = -1
        q = tf.reshape(q, [b, h * w, c])
        k = tf.reshape(k, [b, h * w, c])
        w_ = tf.matmul(
            q, k, transpose_b=True
        )  # b,hw,hw  w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
        w_ = w_ * (int(c) ** (-0.5))
        w_ = keras.activations.softmax(w_)

        # attend to values
        v = tf.reshape(v, [b, h * w, c])
        # w_ = w_.permute(0, 2, 1)  # b,hw,hw (first hw of k, second of q)
        h_ = tf.matmul(
            v, w_, transpose_a=True
        )  # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
        # h_ = h_.reshape(b, c, h, w)
        h_ = tf.reshape(h_, [b, h, w, c])

        h_ = self.proj_out(h_)

        return x + h_


class Downsample(layers.Layer):
    def __init__(self, channels, with_conv=True):
        super().__init__()
        self.with_conv = with_conv
        if self.with_conv:
            # no asymmetric padding in torch conv, must do it ourselves
            self.down_sample = layers.Conv2D(
                channels, kernel_size=3, strides=2, padding="same"
            )
        else:
            self.down_sample = layers.AveragePooling2D(pool_size=2, strides=2)

    def call(self, x):
        x = self.down_sample(x)
        return x


class Upsample(layers.Layer):
    def __init__(self, channels, with_conv=False):
        super().__init__()
        self.with_conv = with_conv
        if False:  # self.with_conv:
            self.up_sample = layers.Conv2DTranspose(
                channels, kernel_size=3, strides=2, padding="same"
            )
        else:
            self.up_sample = Sequential(
                [
                    layers.UpSampling2D(size=2, interpolation="nearest"),
                    layers.Conv2D(channels, kernel_size=3, strides=1, padding="same"),
                ]
            )

    def call(self, x):
        x = self.up_sample(x)
        return x
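Note (illustrative training sketch, not from the upload): how the VQGAN above could be wired to its two optimizers and trained on (input, target) image batches shaped like INPUT_SHAPE = (64, 128, 3). It assumes the LPIPS assets are available and that the ganime.model.vqgan.discriminator.model.NLayerDiscriminator constructor accepts the keyword arguments used in __init__ above; everything else is illustrative.

import tensorflow as tf
from ganime.model.vqgan.vqgan import VQGAN

model = VQGAN(train_variance=1.0, num_embeddings=512, embedding_dim=128)
model.compile(
    gen_optimizer=tf.keras.optimizers.Adam(1e-4),
    disc_optimizer=tf.keras.optimizers.Adam(1e-4),
)

# Dummy (x, y) pairs; train_step runs one generator and one discriminator update per batch.
images = tf.random.uniform((8, 64, 128, 3))
dataset = tf.data.Dataset.from_tensor_slices((images, images)).batch(4)
model.fit(dataset, epochs=1)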
ganime/model/vqgan_clean/__init__.py
ADDED
File without changes
ganime/model/vqgan_clean/diffusion/__init__.py
ADDED
File without changes
ganime/model/vqgan_clean/diffusion/decoder.py
ADDED
@@ -0,0 +1,115 @@
from typing import List

import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import Model, layers
from tensorflow_addons.layers import GroupNormalization

from .layers import AttentionBlock, ResnetBlock, Upsample


# @tf.keras.utils.register_keras_serializable()
class Decoder(layers.Layer):
    def __init__(
        self,
        *,
        channels: int,
        output_channels: int = 3,
        channels_multiplier: List[int],
        num_res_blocks: int,
        attention_resolution: List[int],
        resolution: int,
        z_channels: int,
        dropout: float,
        **kwargs
    ):
        super().__init__(**kwargs)

        self.channels = channels
        self.output_channels = output_channels
        self.channels_multiplier = channels_multiplier
        self.num_resolutions = len(channels_multiplier)
        self.num_res_blocks = num_res_blocks
        self.attention_resolution = attention_resolution
        self.resolution = resolution
        self.z_channels = z_channels
        self.dropout = dropout

        block_in = channels * channels_multiplier[-1]
        current_resolution = resolution // 2 ** (self.num_resolutions - 1)
        self.z_shape = (1, z_channels, current_resolution, current_resolution)

        print(
            "Working with z of shape {} = {} dimensions.".format(
                self.z_shape, np.prod(self.z_shape)
            )
        )

        self.conv_in = layers.Conv2D(block_in, kernel_size=3, strides=1, padding="same")

        # middle
        self.mid = {}
        self.mid["block_1"] = ResnetBlock(
            in_channels=block_in,
            out_channels=block_in,
            dropout=dropout,
        )
        self.mid["attn_1"] = AttentionBlock(block_in)
        self.mid["block_2"] = ResnetBlock(
            in_channels=block_in,
            out_channels=block_in,
            dropout=dropout,
        )

        # upsampling

        self.upsampling_list = []

        for i_level in reversed(range(self.num_resolutions)):
            block_out = channels * channels_multiplier[i_level]
            for i_block in range(self.num_res_blocks + 1):
                self.upsampling_list.append(
                    ResnetBlock(
                        in_channels=block_in,
                        out_channels=block_out,
                        dropout=dropout,
                    )
                )
                block_in = block_out

            if current_resolution in attention_resolution:
                # attentions.append(layers.Attention())
                self.upsampling_list.append(AttentionBlock(block_in))

            if i_level != 0:
                self.upsampling_list.append(Upsample(block_in))
                current_resolution *= 2

        # end
        self.norm_out = GroupNormalization(groups=32, epsilon=1e-6)
        self.conv_out = layers.Conv2D(
            output_channels,
            kernel_size=3,
            strides=1,
            activation="tanh",
            padding="same",
        )

    def call(self, inputs, training=True, mask=None):

        h = self.conv_in(inputs)

        # middle
        h = self.mid["block_1"](h)
        h = self.mid["attn_1"](h)
        h = self.mid["block_2"](h)

        for upsampling in self.upsampling_list:
            h = upsampling(h)

        # end
        h = self.norm_out(h)
        h = keras.activations.swish(h)
        h = self.conv_out(h)
        return h
ganime/model/vqgan_clean/diffusion/encoder.py
ADDED
@@ -0,0 +1,125 @@
from typing import List
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers, Model
from tensorflow_addons.layers import GroupNormalization
from .layers import ResnetBlock, AttentionBlock, Downsample


# @tf.keras.utils.register_keras_serializable()
class Encoder(layers.Layer):
    def __init__(
        self,
        *,
        channels: int,
        channels_multiplier: List[int],
        num_res_blocks: int,
        attention_resolution: List[int],
        resolution: int,
        z_channels: int,
        dropout: float,
        **kwargs
    ):
        """Encode an image into a latent vector. The encoder is built from multiple levels (the length of `channels_multiplier`), with `num_res_blocks` ResnetBlocks at each level.

        Args:
            channels (int, optional): The number of channels for the first layer. Defaults to 128.
            channels_multiplier (List[int], optional): The channel multiplier for each level (previous level channels x multiplier). Defaults to [1, 1, 2, 2].
            num_res_blocks (int, optional): Number of ResnetBlocks at each level. Defaults to 1.
            attention_resolution (List[int], optional): Add an attention block if the current resolution is in this list. Defaults to [16].
            resolution (int, optional): The starting resolution. Defaults to 64.
            z_channels (int, optional): The number of channels at the end of the encoder. Defaults to 128.
            dropout (float, optional): The dropout ratio for each ResnetBlock. Defaults to 0.0.
        """
        super().__init__(**kwargs)

        self.channels = channels
        self.channels_multiplier = channels_multiplier
        self.num_resolutions = len(channels_multiplier)
        self.num_res_blocks = num_res_blocks
        self.attention_resolution = attention_resolution
        self.resolution = resolution
        self.z_channels = z_channels
        self.dropout = dropout

        self.conv_in = layers.Conv2D(
            self.channels, kernel_size=3, strides=1, padding="same"
        )

        current_resolution = resolution

        in_channels_multiplier = (1,) + tuple(channels_multiplier)

        self.downsampling_list = []

        for i_level in range(self.num_resolutions):
            block_in = channels * in_channels_multiplier[i_level]
            block_out = channels * channels_multiplier[i_level]
            for i_block in range(self.num_res_blocks):
                self.downsampling_list.append(
                    ResnetBlock(
                        in_channels=block_in,
                        out_channels=block_out,
                        dropout=dropout,
                    )
                )
                block_in = block_out

            if current_resolution in attention_resolution:
                self.downsampling_list.append(AttentionBlock(block_in))

            if i_level != self.num_resolutions - 1:
                self.downsampling_list.append(Downsample(block_in))
                current_resolution = current_resolution // 2

        # middle
        self.mid = {}
        self.mid["block_1"] = ResnetBlock(
            in_channels=block_in,
            out_channels=block_in,
            dropout=dropout,
        )
        self.mid["attn_1"] = AttentionBlock(block_in)
        self.mid["block_2"] = ResnetBlock(
            in_channels=block_in,
            out_channels=block_in,
            dropout=dropout,
        )

        # end
        self.norm_out = GroupNormalization(groups=32, epsilon=1e-6)
        self.conv_out = layers.Conv2D(
            z_channels,
            kernel_size=3,
            strides=1,
            padding="same",
        )

    # def get_config(self):
    #     config = super().get_config()
    #     config.update(
    #         {
    #             "channels": self.channels,
    #             "channels_multiplier": self.channels_multiplier,
    #             "num_res_blocks": self.num_res_blocks,
    #             "attention_resolution": self.attention_resolution,
    #             "resolution": self.resolution,
    #             "z_channels": self.z_channels,
    #             "dropout": self.dropout,
    #         }
    #     )
    #     return config

    def call(self, inputs, training=True, mask=None):
        h = self.conv_in(inputs)
        for downsampling in self.downsampling_list:
            h = downsampling(h)

        h = self.mid["block_1"](h)
        h = self.mid["attn_1"](h)
        h = self.mid["block_2"](h)

        # end
        h = self.norm_out(h)
        h = keras.activations.swish(h)
        h = self.conv_out(h)
        return h
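Note (illustrative shape check, not from the upload): with the values named in the docstring (channels=128, channels_multiplier=[1, 1, 2, 2], resolution=64), the encoder halves the spatial resolution once per level except the last, so a 64x64 input comes out at 8x8 with z_channels filters.

import tensorflow as tf
from ganime.model.vqgan_clean.diffusion.encoder import Encoder

encoder = Encoder(
    channels=128,
    channels_multiplier=[1, 1, 2, 2],
    num_res_blocks=1,
    attention_resolution=[16],
    resolution=64,
    z_channels=128,
    dropout=0.0,
)
x = tf.random.uniform((2, 64, 64, 3))
z = encoder(x)
print(z.shape)  # expected: (2, 8, 8, 128) after three Downsample steps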
ganime/model/vqgan_clean/diffusion/layers.py
ADDED
@@ -0,0 +1,179 @@
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers, Sequential
from tensorflow_addons.layers import GroupNormalization


@tf.keras.utils.register_keras_serializable()
class ResnetBlock(layers.Layer):
    def __init__(
        self,
        *,
        in_channels,
        dropout=0.0,
        out_channels=None,
        conv_shortcut=False,
        **kwargs
    ):
        super().__init__(**kwargs)
        self.in_channels = in_channels
        self.dropout_rate = dropout
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels
        self.use_conv_shortcut = conv_shortcut

        self.norm1 = GroupNormalization(groups=32, epsilon=1e-6)

        self.conv1 = layers.Conv2D(
            out_channels, kernel_size=3, strides=1, padding="same"
        )

        self.norm2 = GroupNormalization(groups=32, epsilon=1e-6)
        self.dropout = layers.Dropout(dropout)

        self.conv2 = layers.Conv2D(
            out_channels, kernel_size=3, strides=1, padding="same"
        )

        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                self.conv_shortcut = layers.Conv2D(
                    out_channels, kernel_size=3, strides=1, padding="same"
                )
            else:
                self.nin_shortcut = layers.Conv2D(
                    out_channels, kernel_size=1, strides=1, padding="valid"
                )

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "in_channels": self.in_channels,
                "dropout": self.dropout_rate,
                "out_channels": self.out_channels,
                "conv_shortcut": self.use_conv_shortcut,
            }
        )
        return config

    def call(self, x):
        h = x
        h = self.norm1(h)
        h = keras.activations.swish(h)
        h = self.conv1(h)

        h = self.norm2(h)
        h = keras.activations.swish(h)
        h = self.dropout(h)
        h = self.conv2(h)

        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                x = self.conv_shortcut(x)
            else:
                x = self.nin_shortcut(x)

        return x + h


@tf.keras.utils.register_keras_serializable()
class AttentionBlock(layers.Layer):
    def __init__(self, channels, **kwargs):
        super().__init__(**kwargs)
        self.channels = channels
        self.norm = GroupNormalization(groups=32, epsilon=1e-6)
        self.q = layers.Conv2D(channels, kernel_size=1, strides=1, padding="valid")
        self.k = layers.Conv2D(channels, kernel_size=1, strides=1, padding="valid")
        self.v = layers.Conv2D(channels, kernel_size=1, strides=1, padding="valid")
        self.proj_out = layers.Conv2D(
            channels, kernel_size=1, strides=1, padding="valid"
        )

        self.attention = layers.Attention()

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "channels": self.channels,
            }
        )
        return config

    def call(self, x):
        h_ = x
        h_ = self.norm(h_)
        q = self.q(h_)
        k = self.k(h_)
        v = self.v(h_)

        # compute attention
        (b, h, w, c,) = (
            tf.shape(q)[0],
            tf.shape(q)[1],
            tf.shape(q)[2],
            tf.shape(q)[3],
        )

        if b is None:
            b = -1
        q = tf.reshape(q, [b, h * w, c])
        k = tf.reshape(k, [b, h * w, c])
        v = tf.reshape(v, [b, h * w, c])

        h_ = self.attention([q, v, k])

        h_ = tf.reshape(h_, [b, h, w, c])

        h_ = self.proj_out(h_)

        return x + h_


@tf.keras.utils.register_keras_serializable()
class Downsample(layers.Layer):
    def __init__(self, channels, **kwargs):
        super().__init__(**kwargs)
        self.channels = channels
        self.down_sample = self.down_sample = layers.AveragePooling2D(
            pool_size=2, strides=2
        )
        self.conv = layers.Conv2D(channels, kernel_size=3, strides=1, padding="same")

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "channels": self.channels,
            }
        )
        return config

    def call(self, x):
        x = self.down_sample(x)
        x = self.conv(x)
        return x


@tf.keras.utils.register_keras_serializable()
class Upsample(layers.Layer):
    def __init__(self, channels, **kwargs):
        super().__init__(**kwargs)
        self.channels = channels
        self.up_sample = layers.UpSampling2D(size=2, interpolation="bilinear")
        self.conv = layers.Conv2D(channels, kernel_size=3, strides=1, padding="same")

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "channels": self.channels,
            }
        )
        return config

    def call(self, x):
        x = self.up_sample(x)
        x = self.conv(x)
        return x
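Note (illustrative, not from the upload): because these layers are registered as Keras-serializable and expose get_config, they can be rebuilt from their config dict, which is what lets models that use them be saved and reloaded. A small sketch:

from ganime.model.vqgan_clean.diffusion.layers import ResnetBlock

block = ResnetBlock(in_channels=128, out_channels=256, dropout=0.1)
config = block.get_config()  # includes in_channels, out_channels, dropout, conv_shortcut
clone = ResnetBlock.from_config(config)  # an equivalent, freshly initialized layer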
ganime/model/vqgan_clean/discriminator/__init__.py
ADDED
File without changes
ganime/model/vqgan_clean/discriminator/model.py
ADDED
@@ -0,0 +1,88 @@
from typing import List
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import Model, Sequential
from tensorflow.keras import layers
from tensorflow.keras.initializers import RandomNormal


class NLayerDiscriminator(Model):
    """Defines a PatchGAN discriminator as in Pix2Pix
    --> see https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py
    """

    def __init__(self, filters: int = 64, n_layers: int = 3, **kwargs):
        super().__init__(**kwargs)

        init = RandomNormal(stddev=0.02)
        self.filters = filters
        self.n_layers = n_layers

        kernel_size = 4

        inp = tf.keras.layers.Input(shape=[256, 512, 3], name="input_image")
        tar = tf.keras.layers.Input(shape=[256, 512, 3], name="target_image")

        x = tf.keras.layers.concatenate([inp, tar])

        x = layers.Conv2D(
            filters,
            kernel_size=kernel_size,
            strides=2,
            # strides=1,
            padding="same",
            kernel_initializer=init,
        )(x)
        x = layers.LeakyReLU(alpha=0.2)(x)

        filters_mult = 1
        for n in range(1, n_layers):
            filters_mult = min(2**n, 8)

            x = layers.Conv2D(
                filters * filters_mult,
                kernel_size=kernel_size,
                # strides=1,  # 2,
                strides=2,
                padding="same",
                use_bias=False,
                kernel_initializer=init,
            )(x)
            x = layers.BatchNormalization()(x)
            x = layers.LeakyReLU(alpha=0.2)(x)

        filters_mult = min(2**n_layers, 8)
        x = layers.Conv2D(
            filters * filters_mult,
            kernel_size=kernel_size,
            strides=1,
            padding="same",
            use_bias=False,
            kernel_initializer=init,
        )(x)
        x = layers.BatchNormalization()(x)
        x = layers.LeakyReLU(alpha=0.2)(x)

        x = layers.Conv2D(
            1,
            kernel_size=kernel_size,
            strides=1,
            padding="same",
            # activation="sigmoid",
            kernel_initializer=init,
        )(x)
        self.model = tf.keras.Model(inputs=[inp, tar], outputs=x)

    def call(self, inputs, training=True, mask=None):
        return self.model(inputs)

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "filters": self.filters,
                "n_layers": self.n_layers,
            }
        )
        return config
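Note (illustrative, not from the upload): the PatchGAN above scores an (input, target) pair with a grid of logits, one per image patch, rather than a single real/fake score. The fixed Input(shape=[256, 512, 3]) pair passes through three stride-2 convolutions, so the output grid is 1/8 of the input size per side.

import tensorflow as tf
from ganime.model.vqgan_clean.discriminator.model import NLayerDiscriminator

disc = NLayerDiscriminator(filters=64, n_layers=3)
inp = tf.random.uniform((1, 256, 512, 3))
tar = tf.random.uniform((1, 256, 512, 3))
logits = disc([inp, tar])
print(logits.shape)  # expected: (1, 32, 64, 1), one logit per patch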
ganime/model/vqgan_clean/discriminator/model_bkp.py
ADDED
@@ -0,0 +1,76 @@
from typing import List
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import Model, Sequential
from tensorflow.keras import layers


class NLayerDiscriminator(Model):
    """Defines a PatchGAN discriminator as in Pix2Pix
    --> see https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py
    """

    def __init__(self, filters: int = 64, n_layers: int = 3, **kwargs):
        super().__init__(**kwargs)

        self.filters = filters
        self.n_layers = n_layers

        kernel_size = 4
        self.sequence = [
            layers.Conv2D(filters, kernel_size=kernel_size, strides=1, padding="same"),
            layers.LeakyReLU(alpha=0.2),
        ]

        filters_mult = 1
        for n in range(1, n_layers):
            filters_mult = min(2**n, 8)

            self.sequence += [
                layers.AveragePooling2D(pool_size=2),
                layers.Conv2D(
                    filters * filters_mult,
                    kernel_size=kernel_size,
                    strides=1,  # 2,
                    # strides=2,
                    padding="same",
                    use_bias=False,
                ),
                layers.BatchNormalization(),
                layers.LeakyReLU(alpha=0.2),
            ]

        filters_mult = min(2**n_layers, 8)
        self.sequence += [
            layers.AveragePooling2D(pool_size=2),
            layers.Conv2D(
                filters * filters_mult,
                kernel_size=kernel_size,
                strides=1,
                padding="same",
                use_bias=False,
            ),
            layers.BatchNormalization(),
            layers.LeakyReLU(alpha=0.2),
        ]

        self.sequence += [
            layers.Conv2D(1, kernel_size=kernel_size, strides=1, padding="same")
        ]

    def call(self, inputs, training=True, mask=None):
        h = inputs
        for seq in self.sequence:
            h = seq(h)
        return h

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "filters": self.filters,
                "n_layers": self.n_layers,
            }
        )
        return config
ganime/model/vqgan_clean/experimental/gpt2_embedding.py
ADDED
@@ -0,0 +1,1127 @@
# coding=utf-8
# Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" TF 2.0 OpenAI GPT-2 model."""

from dataclasses import dataclass
from typing import List, Optional, Tuple, Union

import numpy as np
import tensorflow as tf
from tensorflow.compiler.tf2xla.python.xla import dynamic_update_slice


from transformers.activations_tf import get_tf_activation
from transformers.modeling_tf_outputs import (
    TFBaseModelOutputWithPastAndCrossAttentions,
    TFCausalLMOutputWithCrossAttentions,
    TFSequenceClassifierOutputWithPast,
)
from transformers.modeling_tf_utils import (
    TFCausalLanguageModelingLoss,
    TFConv1D,
    TFModelInputType,
    TFPreTrainedModel,
    TFSequenceClassificationLoss,
    TFSequenceSummary,
    TFSharedEmbeddings,
    get_initializer,
    keras_serializable,
    unpack_inputs,
)
from transformers.tf_utils import shape_list, stable_softmax
from transformers.utils import (
    DUMMY_INPUTS,
    ModelOutput,
    add_code_sample_docstrings,
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    logging,
    replace_return_docstrings,
)
from transformers import GPT2Config


logger = logging.get_logger(__name__)

_CHECKPOINT_FOR_DOC = "gpt2"
_CONFIG_FOR_DOC = "GPT2Config"
_TOKENIZER_FOR_DOC = "GPT2Tokenizer"

TF_GPT2_PRETRAINED_MODEL_ARCHIVE_LIST = [
    "gpt2",
    "gpt2-medium",
    "gpt2-large",
    "gpt2-xl",
    "distilgpt2",
    # See all GPT-2 models at https://huggingface.co/models?filter=gpt2
]


class TFAttention(tf.keras.layers.Layer):
    def __init__(self, nx, config, scale=False, is_cross_attention=False, **kwargs):
        super().__init__(**kwargs)

        n_state = nx  # in Attention: n_state=768 (nx=n_embd)
        # [switch nx => n_state from Block to Attention to keep identical to TF implementation]
        assert n_state % config.n_head == 0
        self.n_head = config.n_head
        self.split_size = n_state
        self.scale = scale
        self.output_attentions = config.output_attentions

        self.is_cross_attention = is_cross_attention

        if self.is_cross_attention:
            self.c_attn = TFConv1D(
                n_state * 2,
                nx,
                initializer_range=config.initializer_range,
                name="c_attn",
            )
            self.q_attn = TFConv1D(
                n_state, nx, initializer_range=config.initializer_range, name="q_attn"
            )
        else:
            self.c_attn = TFConv1D(
                n_state * 3,
                nx,
                initializer_range=config.initializer_range,
                name="c_attn",
            )

        self.c_proj = TFConv1D(
            n_state, nx, initializer_range=config.initializer_range, name="c_proj"
        )
        self.attn_dropout = tf.keras.layers.Dropout(config.attn_pdrop)
        self.resid_dropout = tf.keras.layers.Dropout(config.resid_pdrop)
        self.pruned_heads = set()

    def prune_heads(self, heads):
        pass

    @staticmethod
    def causal_attention_mask(nd, ns, dtype):
        """
        1's in the lower triangle, counting from the lower right corner. Same as tf.matrix_band_part(tf.ones([nd, ns]),
        -1, ns-nd), but doesn't produce garbage on TPUs.
        """
        i = tf.range(nd)[:, None]
        j = tf.range(ns)
        m = i >= j - ns + nd
        return tf.cast(m, dtype)

    def _attn(
        self, q, k, v, attention_mask, head_mask, output_attentions, training=False
    ):
        # q, k, v have shape [batch, heads, sequence, features]
        w = tf.matmul(q, k, transpose_b=True)
        if self.scale:
            dk = tf.cast(shape_list(k)[-1], dtype=w.dtype)  # scale attention_scores
            w = w / tf.math.sqrt(dk)

        if not self.is_cross_attention:
            # if only "normal" attention layer implements causal mask

            # w has shape [batch, heads, dst_sequence, src_sequence], where information flows from src to dst.
            _, _, nd, ns = shape_list(w)
            b = self.causal_attention_mask(nd, ns, dtype=w.dtype)
            b = tf.reshape(b, [1, 1, nd, ns])
            w = w * b - 1e4 * (1 - b)

        if attention_mask is not None:
            # Apply the attention mask
            attention_mask = tf.cast(attention_mask, dtype=w.dtype)
            w = w + attention_mask

        w = stable_softmax(w, axis=-1)
        w = self.attn_dropout(w, training=training)

        # Mask heads if we want to
        if head_mask is not None:
            w = w * head_mask

        outputs = [tf.matmul(w, v)]
        if output_attentions:
            outputs.append(w)
        return outputs

    def merge_heads(self, x):
        x = tf.transpose(x, [0, 2, 1, 3])
        x_shape = shape_list(x)
        new_x_shape = x_shape[:-2] + [x_shape[-2] * x_shape[-1]]
        return tf.reshape(x, new_x_shape)

    def split_heads(self, x):
        x_shape = shape_list(x)
        new_x_shape = x_shape[:-1] + [self.n_head, x_shape[-1] // self.n_head]
        x = tf.reshape(x, new_x_shape)
        return tf.transpose(x, (0, 2, 1, 3))  # (batch, head, seq_length, head_features)

    def call(
        self,
        x,
        layer_past,
        attention_mask,
        head_mask,
        encoder_hidden_states,
        encoder_attention_mask,
        use_cache,
        output_attentions,
        training=False,
    ):

        if encoder_hidden_states is not None:
            if not hasattr(self, "q_attn"):
                raise ValueError(
                    "If class is used as cross attention, the weights `q_attn` have to be defined. "
                    "Please make sure to instantiate class with `GPT2Attention(..., is_cross_attention=True)`."
                )

            query = self.q_attn(x)
            kv_out = self.c_attn(encoder_hidden_states)
            key, value = tf.split(kv_out, 2, axis=2)
            attention_mask = encoder_attention_mask
        else:
            x = self.c_attn(x)
            query, key, value = tf.split(x, 3, axis=2)

        query = self.split_heads(query)
        key = self.split_heads(key)
        value = self.split_heads(value)
        if layer_past is not None:
            past_key, past_value = tf.unstack(layer_past, axis=0)
            key = tf.concat([past_key, key], axis=-2)
            value = tf.concat([past_value, value], axis=-2)

        # to cope with keras serialization
        if use_cache:
            present = tf.stack([key, value], axis=0)
        else:
            present = (None,)

        attn_outputs = self._attn(
            query,
            key,
            value,
            attention_mask,
            head_mask,
            output_attentions,
            training=training,
        )
        a = attn_outputs[0]

        a = self.merge_heads(a)
        a = self.c_proj(a)
        a = self.resid_dropout(a, training=training)

        outputs = [a, present] + attn_outputs[1:]
        return outputs  # a, present, (attentions)
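
# Worked example of the mask above: `TFAttention.causal_attention_mask(3, 5, tf.float32)`
# evaluates `i >= j - ns + nd` and yields
#   [[1., 1., 1., 0., 0.],
#    [1., 1., 1., 1., 0.],
#    [1., 1., 1., 1., 1.]]
# so each of the last `nd` query positions can attend to every key position up to and
# including its own, which is the "lower triangle counting from the lower right corner"
# described in the docstring.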


class TFMLP(tf.keras.layers.Layer):
    def __init__(self, n_state, config, **kwargs):
        super().__init__(**kwargs)
        nx = config.n_embd
        self.c_fc = TFConv1D(
            n_state, nx, initializer_range=config.initializer_range, name="c_fc"
        )
        self.c_proj = TFConv1D(
            nx, n_state, initializer_range=config.initializer_range, name="c_proj"
        )
        self.act = get_tf_activation(config.activation_function)
        self.dropout = tf.keras.layers.Dropout(config.resid_pdrop)

    def call(self, x, training=False):
        h = self.act(self.c_fc(x))
        h2 = self.c_proj(h)
        h2 = self.dropout(h2, training=training)
        return h2


class TFBlock(tf.keras.layers.Layer):
    def __init__(self, config, scale=False, **kwargs):
        super().__init__(**kwargs)
        nx = config.n_embd
        inner_dim = config.n_inner if config.n_inner is not None else 4 * nx
        self.ln_1 = tf.keras.layers.LayerNormalization(
            epsilon=config.layer_norm_epsilon, name="ln_1"
        )
        self.attn = TFAttention(nx, config, scale, name="attn")
        self.ln_2 = tf.keras.layers.LayerNormalization(
            epsilon=config.layer_norm_epsilon, name="ln_2"
        )

        if config.add_cross_attention:

            self.crossattention = TFAttention(
                nx, config, scale, name="crossattention", is_cross_attention=True
            )
            self.ln_cross_attn = tf.keras.layers.LayerNormalization(
                epsilon=config.layer_norm_epsilon, name="ln_cross_attn"
            )

        self.mlp = TFMLP(inner_dim, config, name="mlp")

    def call(
        self,
        x,
        layer_past,
        attention_mask,
        head_mask,
        encoder_hidden_states,
        encoder_attention_mask,
        use_cache,
        output_attentions,
        training=False,
    ):
        a = self.ln_1(x)
        output_attn = self.attn(
            a,
            layer_past=layer_past,
            attention_mask=attention_mask,
            head_mask=head_mask,
            encoder_hidden_states=None,
            encoder_attention_mask=None,
            use_cache=use_cache,
            output_attentions=output_attentions,
            training=training,
        )
        a = output_attn[0]  # output_attn: a, present, (attentions)
        outputs = output_attn[1:]
        x = x + a

        # Cross-Attention Block
        if encoder_hidden_states is not None:
            # add one self-attention block for cross-attention
            if not hasattr(self, "crossattention"):
                raise ValueError(
                    f"If `encoder_hidden_states` are passed, {self} has to be instantiated with "
                    "cross-attention layers by setting `config.add_cross_attention=True`"
                )

            ca = self.ln_cross_attn(x)
            output_cross_attn = self.crossattention(
                ca,
                layer_past=None,
                attention_mask=attention_mask,
                head_mask=head_mask,
                encoder_hidden_states=encoder_hidden_states,
                encoder_attention_mask=encoder_attention_mask,
                use_cache=False,
                output_attentions=output_attentions,
                training=training,
            )
            ca = output_cross_attn[0]  # output_attn: a, present, (cross_attentions)
            x = x + ca
            outputs = (
                outputs + output_cross_attn[2:]
            )  # add cross attentions if we output attention weights

        m = self.ln_2(x)
        m = self.mlp(m, training=training)
        x = x + m

        outputs = [x] + outputs
        return outputs  # x, present, (attentions, cross_attentions)
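
# Data flow through TFBlock: layer-norm -> self-attention -> residual add, then (only
# when `config.add_cross_attention` is set and `encoder_hidden_states` is passed)
# layer-norm -> cross-attention -> residual add, and finally layer-norm -> MLP ->
# residual add. The returned list is [hidden_states, present, (attentions, cross_attentions)].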


@keras_serializable
class TFGPT2MainLayer(tf.keras.layers.Layer):
    config_class = GPT2Config

    def __init__(self, config, *inputs, **kwargs):
        super().__init__(*inputs, **kwargs)

        self.config = config
        self.output_attentions = config.output_attentions
        self.output_hidden_states = config.output_hidden_states
        self.use_cache = config.use_cache
        self.return_dict = config.use_return_dict

        self.num_hidden_layers = config.n_layer
        self.vocab_size = config.vocab_size
        self.n_embd = config.n_embd
        self.n_positions = config.n_positions
        self.initializer_range = config.initializer_range

        self.wte = TFSharedEmbeddings(
            config.vocab_size,
            config.hidden_size,
            initializer_range=config.initializer_range,
            name="wte",
        )

        self.wte_remaining_frames = TFSharedEmbeddings(
            config.vocab_size,
            config.hidden_size,
            initializer_range=config.initializer_range,
            name="wte_remaining_frames",
        )
        self.drop = tf.keras.layers.Dropout(config.embd_pdrop)
        self.h = [
            TFBlock(config, scale=True, name=f"h_._{i}") for i in range(config.n_layer)
        ]
        self.ln_f = tf.keras.layers.LayerNormalization(
            epsilon=config.layer_norm_epsilon, name="ln_f"
        )

    def build(self, input_shape):
        with tf.name_scope("wpe"):
            self.wpe = self.add_weight(
                name="embeddings",
                shape=[self.n_positions, self.n_embd],
                initializer=get_initializer(self.initializer_range),
            )
        self.wte_remaining_frames.build(input_shape)

        super().build(input_shape)

    def get_input_embeddings(self):
        return self.wte

    def get_remaining_frames_embeddings(self):
        return self.wte_remaining_frames

    def set_input_embeddings(self, value):
        self.wte.weight = value
        self.wte.vocab_size = shape_list(value)[0]

    def set_remaining_frames_embeddings(self, value):
        self.wte_remaining_frames.weight = value
        self.wte_remaining_frames.vocab_size = shape_list(value)[0]

    def _prune_heads(self, heads_to_prune):
        """
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
        """
        raise NotImplementedError

    @unpack_inputs
    def call(
        self,
        input_ids: Optional[TFModelInputType] = None,
        remaining_frames_ids: Optional[Union[np.ndarray, tf.Tensor]] = None,
        past: Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]] = None,
        attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
        token_type_ids: Optional[Union[np.ndarray, tf.Tensor]] = None,
        position_ids: Optional[Union[np.ndarray, tf.Tensor]] = None,
        head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
        inputs_embeds: Optional[Union[np.ndarray, tf.Tensor]] = None,
        encoder_hidden_states: Optional[Union[np.ndarray, tf.Tensor]] = None,
        encoder_attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        training: Optional[bool] = False,
    ) -> Union[TFBaseModelOutputWithPastAndCrossAttentions, Tuple[tf.Tensor]]:

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError(
                "You cannot specify both input_ids and inputs_embeds at the same time"
            )
        elif input_ids is not None:
            input_shape = shape_list(input_ids)
            input_ids = tf.reshape(input_ids, [-1, input_shape[-1]])
        elif inputs_embeds is not None:
            input_shape = shape_list(inputs_embeds)[:-1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        if past is None:
            past_length = 0
            past = [None] * len(self.h)
        else:
            past_length = shape_list(past[0][0])[-2]

        if position_ids is None:
            position_ids = tf.expand_dims(
                tf.range(past_length, input_shape[-1] + past_length), axis=0
            )

        if attention_mask is not None:
            # We create a 3D attention mask from a 2D tensor mask.
            # Sizes are [batch_size, 1, 1, to_seq_length]
            # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
            # this attention mask is more simple than the triangular masking of causal attention
            # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
            attention_mask_shape = shape_list(attention_mask)
            attention_mask = tf.reshape(
                attention_mask, (attention_mask_shape[0], 1, 1, attention_mask_shape[1])
            )

            # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
            # masked positions, this operation will create a tensor which is 0.0 for
            # positions we want to attend and -10000.0 for masked positions.
            # Since we are adding it to the raw scores before the softmax, this is
            # effectively the same as removing these entirely.
            one_cst = tf.constant(1.0)
            attention_mask = tf.cast(attention_mask, dtype=one_cst.dtype)
            attention_mask = tf.multiply(
                tf.subtract(one_cst, attention_mask), tf.constant(-10000.0)
            )

        # Copied from `modeling_tf_t5.py` with -1e9 -> -10000
        if self.config.add_cross_attention and encoder_attention_mask is not None:
            # If a 2D or 3D attention mask is provided for the cross-attention
            # we need to make broadcastable to [batch_size, num_heads, mask_seq_length, mask_seq_length]
            # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
            encoder_attention_mask = tf.cast(
                encoder_attention_mask, dtype=encoder_hidden_states.dtype
            )
            num_dims_encoder_attention_mask = len(shape_list(encoder_attention_mask))
            if num_dims_encoder_attention_mask == 3:
                encoder_extended_attention_mask = encoder_attention_mask[:, None, :, :]
            if num_dims_encoder_attention_mask == 2:
                encoder_extended_attention_mask = encoder_attention_mask[
                    :, None, None, :
                ]

            # T5 has a mask that can compare sequence ids, we can simulate this here with this transposition
            # Cf. https://github.com/tensorflow/mesh/blob/8d2465e9bc93129b913b5ccc6a59aa97abd96ec6/mesh_tensorflow/transformer/transformer_layers.py#L270
            # encoder_extended_attention_mask = tf.math.equal(encoder_extended_attention_mask,
            #                                                 tf.transpose(encoder_extended_attention_mask, perm=(-1, -2)))

            encoder_extended_attention_mask = (
                1.0 - encoder_extended_attention_mask
            ) * -10000.0
        else:
            encoder_extended_attention_mask = None

        encoder_attention_mask = encoder_extended_attention_mask

        # Prepare head mask if needed
        # 1.0 in head_mask indicate we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
        if head_mask is not None:
            raise NotImplementedError
        else:
            head_mask = [None] * self.num_hidden_layers
            # head_mask = tf.constant([0] * self.num_hidden_layers)

        position_ids = tf.reshape(position_ids, [-1, shape_list(position_ids)[-1]])

        if inputs_embeds is None:
            inputs_embeds = self.wte(input_ids, mode="embedding")

        position_embeds = tf.gather(self.wpe, position_ids)

        if token_type_ids is not None:
            token_type_ids = tf.reshape(
                token_type_ids, [-1, shape_list(token_type_ids)[-1]]
            )
            token_type_embeds = self.wte(token_type_ids, mode="embedding")
        else:
            token_type_embeds = tf.constant(0.0)

        if remaining_frames_ids is not None:
            remaining_frames_ids = tf.reshape(
                remaining_frames_ids, [-1, shape_list(remaining_frames_ids)[-1]]
            )
            remaining_frames_embeds = self.wte_remaining_frames(
                remaining_frames_ids, mode="embedding"
            )
        else:
            remaining_frames_embeds = tf.constant(0.0)

        position_embeds = tf.cast(position_embeds, dtype=inputs_embeds.dtype)
        token_type_embeds = tf.cast(token_type_embeds, dtype=inputs_embeds.dtype)
        remaining_frames_embeds = tf.cast(
            remaining_frames_embeds, dtype=inputs_embeds.dtype
        )
        hidden_states = (
            inputs_embeds
            + position_embeds
            + token_type_embeds
            + remaining_frames_embeds
        )
        hidden_states = self.drop(hidden_states, training=training)

        output_shape = input_shape + [shape_list(hidden_states)[-1]]

        presents = () if use_cache else None
        all_attentions = () if output_attentions else None
        all_cross_attentions = (
            () if output_attentions and self.config.add_cross_attention else None
        )
        all_hidden_states = () if output_hidden_states else None
        for i, (block, layer_past) in enumerate(zip(self.h, past)):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (
                    tf.reshape(hidden_states, output_shape),
                )

            outputs = block(
                hidden_states,
                layer_past,
                attention_mask,
                head_mask[i],
                encoder_hidden_states,
                encoder_attention_mask,
                use_cache,
                output_attentions,
                training=training,
            )

            hidden_states, present = outputs[:2]
            if use_cache:
                presents = presents + (present,)

            if output_attentions:
                all_attentions = all_attentions + (outputs[2],)
                if (
                    self.config.add_cross_attention
                    and encoder_hidden_states is not None
                ):
                    all_cross_attentions = all_cross_attentions + (outputs[3],)

        hidden_states = self.ln_f(hidden_states)

        hidden_states = tf.reshape(hidden_states, output_shape)
        # Add last hidden state
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if output_attentions:
            # let the number of heads free (-1) so we can extract attention even after head pruning
            attention_output_shape = (
                input_shape[:-1] + [-1] + shape_list(all_attentions[0])[-2:]
            )
            all_attentions = tuple(
                tf.reshape(t, attention_output_shape) for t in all_attentions
            )

        if not return_dict:
            return tuple(
                v
                for v in [
                    hidden_states,
                    presents,
                    all_hidden_states,
                    all_attentions,
                    all_cross_attentions,
                ]
                if v is not None
            )

        return TFBaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=presents,
            hidden_states=all_hidden_states,
            attentions=all_attentions,
            cross_attentions=all_cross_attentions,
        )
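
# The main deviation from the stock Hugging Face TFGPT2MainLayer is the second shared
# embedding table `wte_remaining_frames` and the extra `remaining_frames_ids` input:
# when provided, its embeddings are summed into the hidden states together with the
# token, position and token-type embeddings, presumably so the decoder can be
# conditioned on how many frames remain to be generated.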


class TFGPT2PreTrainedModel(TFPreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = GPT2Config
    base_model_prefix = "transformer"
    # names with a '.' represents the authorized unexpected/missing layers when a TF model is loaded from a PT model
    _keys_to_ignore_on_load_unexpected = [
        r"h.\d+.attn.bias",
        r"h.\d+.crossattention.bias",
    ]

    @property
    def dummy_inputs(self):
        """
        Dummy inputs to build the network.

        Returns:
            `Dict[str, tf.Tensor]`: The dummy inputs.
        """
        dummy = {"input_ids": tf.constant(DUMMY_INPUTS)}
        # Add `encoder_hidden_states` to make the cross-attention layers' weights initialized
        if self.config.add_cross_attention:
            batch_size, seq_len = tf.constant(DUMMY_INPUTS).shape
            shape = (batch_size, seq_len) + (self.config.hidden_size,)
            h = tf.random.uniform(shape=shape)
            dummy["encoder_hidden_states"] = h

        return dummy

    @tf.function(
        input_signature=[
            {
                "input_ids": tf.TensorSpec((None, None), tf.int32, name="input_ids"),
                "attention_mask": tf.TensorSpec(
                    (None, None), tf.int32, name="attention_mask"
                ),
            }
        ]
    )
    def serving(self, inputs):
        output = self.call(inputs)

        return self.serving_output(output)


GPT2_START_DOCSTRING = r"""

    This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)

    This model is also a [tf.keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) subclass. Use it
    as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matter related to general usage and
    behavior.

    <Tip>

    TF 2.0 models accepts two formats as inputs:

    - having all inputs as keyword arguments (like PyTorch models), or
    - having all inputs as a list, tuple or dict in the first positional arguments.

    This second option is useful when using [`tf.keras.Model.fit`] method which currently requires having all the
    tensors in the first argument of the model call function: `model(inputs)`.

    If you choose this second option, there are three possibilities you can use to gather all the input Tensors in the
    first positional argument :

    - a single Tensor with `input_ids` only and nothing else: `model(inputs_ids)`
    - a list of varying length with one or several input Tensors IN THE ORDER given in the docstring:
      `model([input_ids, attention_mask])` or `model([input_ids, attention_mask, token_type_ids])`
    - a dictionary with one or several input Tensors associated to the input names given in the docstring:
      `model({"input_ids": input_ids, "token_type_ids": token_type_ids})`

    </Tip>

    Parameters:
        config ([`GPT2Config`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""

GPT2_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`Numpy array` or `tf.Tensor` of shape `(batch_size, input_ids_length)`):
            `input_ids_length` = `sequence_length` if `past` is `None` else `past[0].shape[-2]` (`sequence_length` of
            input past key value states). Indices of input sequence tokens in the vocabulary.

            If `past` is used, only input IDs that do not have their past calculated should be passed as `input_ids`.

            Indices can be obtained using [`GPT2Tokenizer`]. See [`PreTrainedTokenizer.__call__`] and
            [`PreTrainedTokenizer.encode`] for details.

            [What are input IDs?](../glossary#input-ids)
        past (`List[tf.Tensor]` of length `config.n_layers`):
            Contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model (see
            `past` output below). Can be used to speed up sequential decoding. The token ids which have their past
            given to this model should not be passed as input ids as they have already been computed.
        attention_mask (`tf.Tensor` or `Numpy array` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            If `past_key_values` is used, `attention_mask` needs to contain the masking strategy that was used for
            `past_key_values`. In other words, the `attention_mask` always has to have the length:
            `len(past_key_values) + len(input_ids)`

            [What are attention masks?](../glossary#attention-mask)
        token_type_ids (`tf.Tensor` or `Numpy array` of shape `(batch_size, sequence_length)`, *optional*):
            Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
            1]`:

            - 0 corresponds to a *sentence A* token,
            - 1 corresponds to a *sentence B* token.

            [What are token type IDs?](../glossary#token-type-ids)
        position_ids (`tf.Tensor` or `Numpy array` of shape `(batch_size, sequence_length)`, *optional*):
            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
            config.max_position_embeddings - 1]`.

            [What are position IDs?](../glossary#position-ids)
        head_mask (`Numpy array` or `tf.Tensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
            Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.

        inputs_embeds (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
            model's internal embedding lookup matrix.
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail. This argument can be used only in eager mode, in graph mode the value in the
            config will be used instead.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail. This argument can be used only in eager mode, in graph mode the value in the config will be
            used instead.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. This argument can be used in
            eager mode, in graph mode the value will always be set to True.
        training (`bool`, *optional*, defaults to `False`):
            Whether or not to use the model in training mode (some modules like dropout modules have different
            behaviors between training and evaluation).
"""


@add_start_docstrings(
    "The bare GPT2 Model transformer outputting raw hidden-states without any specific head on top.",
    GPT2_START_DOCSTRING,
)
class TFGPT2Model(TFGPT2PreTrainedModel):
    def __init__(self, config, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)
        self.transformer = TFGPT2MainLayer(config, name="transformer")

    @unpack_inputs
    @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        processor_class=_TOKENIZER_FOR_DOC,
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=TFBaseModelOutputWithPastAndCrossAttentions,
        config_class=_CONFIG_FOR_DOC,
    )
    def call(
        self,
        input_ids: Optional[TFModelInputType] = None,
        remaining_frames_ids: Optional[Union[np.ndarray, tf.Tensor]] = None,
        past: Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]] = None,
        attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
        token_type_ids: Optional[Union[np.ndarray, tf.Tensor]] = None,
        position_ids: Optional[Union[np.ndarray, tf.Tensor]] = None,
        head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
        inputs_embeds: Optional[Union[np.ndarray, tf.Tensor]] = None,
        encoder_hidden_states: Optional[Union[np.ndarray, tf.Tensor]] = None,
        encoder_attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        training: Optional[bool] = False,
    ) -> Union[TFBaseModelOutputWithPastAndCrossAttentions, Tuple[tf.Tensor]]:
        r"""
        encoder_hidden_states (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
            the model is configured as a decoder.
        encoder_attention_mask (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
            the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

        past (`Tuple[Tuple[tf.Tensor]]` of length `config.n_layers`)
            contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
            If `past` are used, the user can optionally input only the last `decoder_input_ids` (those that don't have
            their past key value states given to this model) of shape `(batch_size, 1)` instead of all
            `decoder_input_ids` of shape `(batch_size, sequence_length)`.
        use_cache (`bool`, *optional*, defaults to `True`):
            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
            `past`). Set to `False` during training, `True` during generation
        """

        outputs = self.transformer(
            input_ids=input_ids,
            remaining_frames_ids=remaining_frames_ids,
            past=past,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            training=training,
        )

        return outputs

    def serving_output(self, output):
        pkv = (
            tf.convert_to_tensor(output.past_key_values)
            if self.config.use_cache
            else None
        )
        hs = (
            tf.convert_to_tensor(output.hidden_states)
            if self.config.output_hidden_states
            else None
        )
        attns = (
            tf.convert_to_tensor(output.attentions)
            if self.config.output_attentions
            else None
        )
        cross_attns = (
            tf.convert_to_tensor(output.cross_attentions)
            if self.config.output_attentions
            and self.config.add_cross_attention
            and output.cross_attentions is not None
            else None
        )

        return TFBaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=output.last_hidden_state,
            past_key_values=pkv,
            hidden_states=hs,
            attentions=attns,
            cross_attentions=cross_attns,
        )


@add_start_docstrings(
    """
    The GPT2 Model transformer with a language modeling head on top (linear layer with weights tied to the input
    embeddings).
    """,
    GPT2_START_DOCSTRING,
)
class TFGPT2LMHeadModel(TFGPT2PreTrainedModel, TFCausalLanguageModelingLoss):
    def __init__(self, config, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)
        self.transformer = TFGPT2MainLayer(config, name="transformer")

    def get_output_embeddings(self):
        return self.get_input_embeddings()

    def set_output_embeddings(self, value):
        self.set_input_embeddings(value)

    def prepare_inputs_for_generation(
        self, inputs, past=None, use_cache=None, use_xla=False, **kwargs
    ):
        # TODO: (Joao) after the TF generator is complete, update GPT2 TF generation to match PT's. NB -- some GPT2
        # tests will need to be fixed after the change

        # only last token for inputs_ids if past is defined in kwargs
        if past:
            inputs = tf.expand_dims(inputs[:, -1], -1)

        # TODO(pvp, Joao) - this `if use_xla` statement can be removed, but is left
        # for a future PR to not change too many things for now.
        # All statements in this if case apply for both xla and non-xla (as they already do in PyTorch)
        position_ids = None
        attention_mask = None
        if use_xla:
            attention_mask = kwargs.get("attention_mask", None)
            if past is not None and attention_mask is not None:
                position_ids = tf.reduce_sum(attention_mask, axis=1, keepdims=True) - 1
            elif attention_mask is not None:
                position_ids = tf.math.cumsum(attention_mask, axis=1, exclusive=True)

        return {
            "input_ids": inputs,
            "attention_mask": attention_mask,
            "position_ids": position_ids,
            "past": past,
            "use_cache": use_cache,
        }

    def _update_model_kwargs_for_xla_generation(
        self, outputs, model_kwargs, current_pos, max_length
    ):
        # TODO(Pvp, Joao, Matt) - this function can be cleaned a bit and refactored
        # quite some duplicated code patterns it seems
        # also the `attention_mask` is currently used in a somewhat hacky way to
        # correctly influence the `past_key_values` - not sure if this is the way to go
        # Let's keep that for a future PR.
        past = outputs.past_key_values
        is_past_initialized = model_kwargs.pop("past", None) is not None
        attention_mask = model_kwargs.pop("attention_mask")
        batch_size = attention_mask.shape[0]

        if not is_past_initialized:
            # past[0].shape[3] is seq_length of prompt
            num_padding_values = max_length - past[0].shape[3] - 1

            padding_values = np.zeros((5, 2), dtype=np.int32)
            padding_values[3, 1] = num_padding_values
            padding_values = tf.constant(padding_values)

            new_past = list(past)
            for i in range(len(past)):
                new_past[i] = tf.pad(past[i], padding_values)

            # Zeros for the currently-unfilled locations in the past tensor, ones for the actual input_ids
            attention_mask = tf.concat(
                [
                    attention_mask,
                    tf.zeros(
                        (batch_size, num_padding_values), dtype=attention_mask.dtype
                    ),
                    tf.ones((batch_size, 1), dtype=attention_mask.dtype),
                ],
                axis=1,
            )
        else:
            new_past = [None for _ in range(len(past))]
            slice_start_base = tf.constant([0, 0, 0, 1, 0])
            attention_mask_update_slice = tf.ones(
                (batch_size, 1), dtype=attention_mask.dtype
            )
            # correct 5 here
            new_past_index = current_pos - 1

            for i in range(len(past)):
                update_slice = past[i][:, :, :, -1:]
                # Write the last slice to the first open location in the padded past array
                # and then truncate the last slice off the array
                new_past[i] = dynamic_update_slice(
                    past[i][:, :, :, :-1],
                    update_slice,
                    slice_start_base * new_past_index,
                )

            update_start = tf.constant([0, 1], dtype=tf.int32) * new_past_index
            attention_mask = dynamic_update_slice(
                attention_mask, attention_mask_update_slice, update_start
            )

        # set `attention_mask` and `past`
        model_kwargs["attention_mask"] = attention_mask
        model_kwargs["past"] = tuple(new_past)

        return model_kwargs

    @unpack_inputs
    @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        processor_class=_TOKENIZER_FOR_DOC,
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=TFCausalLMOutputWithCrossAttentions,
        config_class=_CONFIG_FOR_DOC,
    )
    def call(
        self,
        input_ids: Optional[TFModelInputType] = None,
        remaining_frames_ids: Optional[Union[np.ndarray, tf.Tensor]] = None,
        past: Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]] = None,
        attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
        token_type_ids: Optional[Union[np.ndarray, tf.Tensor]] = None,
        position_ids: Optional[Union[np.ndarray, tf.Tensor]] = None,
        head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
        inputs_embeds: Optional[Union[np.ndarray, tf.Tensor]] = None,
        encoder_hidden_states: Optional[Union[np.ndarray, tf.Tensor]] = None,
        encoder_attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        labels: Optional[Union[np.ndarray, tf.Tensor]] = None,
        training: Optional[bool] = False,
    ) -> Union[TFCausalLMOutputWithCrossAttentions, Tuple[tf.Tensor]]:
        r"""
        encoder_hidden_states (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
            the model is configured as a decoder.
        encoder_attention_mask (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
            the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

        past (`Tuple[Tuple[tf.Tensor]]` of length `config.n_layers`)
            contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
            If `past` are used, the user can optionally input only the last `decoder_input_ids` (those that don't have
            their past key value states given to this model) of shape `(batch_size, 1)` instead of all
            `decoder_input_ids` of shape `(batch_size, sequence_length)`.
        use_cache (`bool`, *optional*, defaults to `True`):
            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
            `past`). Set to `False` during training, `True` during generation
        labels (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the cross entropy classification loss. Indices should be in `[0, ...,
            config.vocab_size - 1]`.
        """

        transformer_outputs = self.transformer(
            input_ids=input_ids,
            remaining_frames_ids=remaining_frames_ids,
            past=past,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            training=training,
        )
        hidden_states = transformer_outputs[0]
        logits = self.transformer.wte(hidden_states, mode="linear")

        loss = None
        if labels is not None:
            # shift labels to the left and cut last logit token
            shifted_logits = logits[:, :-1]
            labels = labels[:, 1:]
            loss = self.hf_compute_loss(labels, shifted_logits)

        if not return_dict:
            output = (logits,) + transformer_outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return TFCausalLMOutputWithCrossAttentions(
            loss=loss,
            logits=logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
            cross_attentions=transformer_outputs.cross_attentions,
        )

    def serving_output(self, output):
        pkv = (
            tf.convert_to_tensor(output.past_key_values)
            if self.config.use_cache
            else None
        )
        hs = (
            tf.convert_to_tensor(output.hidden_states)
            if self.config.output_hidden_states
            else None
        )
        attns = (
            tf.convert_to_tensor(output.attentions)
            if self.config.output_attentions
            else None
        )
        cross_attns = (
            tf.convert_to_tensor(output.cross_attentions)
            if self.config.output_attentions
            and self.config.add_cross_attention
            and output.cross_attentions is not None
            else None
        )

        return TFCausalLMOutputWithCrossAttentions(
            logits=output.logits,
            past_key_values=pkv,
            hidden_states=hs,
            attentions=attns,
            cross_attentions=cross_attns,
        )
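

if __name__ == "__main__":
    # Minimal smoke-test sketch, not part of the model definition. The configuration
    # values below are illustrative assumptions (tiny sizes, arbitrary vocabulary),
    # not values taken from this repository's configs; the point is only to show how
    # the extra `remaining_frames_ids` input and the cross-attention conditioning on
    # `encoder_hidden_states` are passed to the language-model head.
    config = GPT2Config(
        vocab_size=1024,
        n_positions=256,
        n_embd=64,
        n_layer=2,
        n_head=4,
        add_cross_attention=True,
    )
    model = TFGPT2LMHeadModel(config)

    batch_size, seq_len, enc_len = 2, 16, 8
    input_ids = tf.random.uniform(
        (batch_size, seq_len), maxval=config.vocab_size, dtype=tf.int32
    )
    remaining_frames_ids = tf.random.uniform(
        (batch_size, seq_len), maxval=config.vocab_size, dtype=tf.int32
    )
    encoder_hidden_states = tf.random.uniform((batch_size, enc_len, config.n_embd))

    outputs = model(
        input_ids=input_ids,
        remaining_frames_ids=remaining_frames_ids,
        encoder_hidden_states=encoder_hidden_states,
    )
    print(outputs.logits.shape)  # (batch_size, seq_len, vocab_size)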