diff --git a/.dockerignore b/.dockerignore new file mode 100644 index 0000000000000000000000000000000000000000..e0f8a1ad05d801589aba1c31e8d726d60e9965fe --- /dev/null +++ b/.dockerignore @@ -0,0 +1,4 @@ +.git +data +checkpoints +logs \ No newline at end of file diff --git a/.gitattributes b/.gitattributes index c7d9f3332a950355d5a77d85000f05e6f45435ea..cd967db1727d14bc6db15f653c414b44dfbdfb79 100644 --- a/.gitattributes +++ b/.gitattributes @@ -32,3 +32,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text *.zip filter=lfs diff=lfs merge=lfs -text *.zst filter=lfs diff=lfs merge=lfs -text *tfevents* filter=lfs diff=lfs merge=lfs -text +models/vgg19/imagenet-vgg-verydeep-19.mat filter=lfs diff=lfs merge=lfs -text diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..00208b5ac7a6f5b8296c86b3e8e2b624e4439795 --- /dev/null +++ b/.gitignore @@ -0,0 +1,208 @@ + +# Created by https://www.toptal.com/developers/gitignore/api/visualstudiocode,python,venv +# Edit at https://www.toptal.com/developers/gitignore?templates=visualstudiocode,python,venv + +### Python ### +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +#poetry.lock + +# PEP 582; used by e.g. 
github.com/David-OConnor/pyflow +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# PyCharm +# JetBrains specific template is maintained in a separate JetBrains.gitignore that can +# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore +# and can be added to the global gitignore or merged into this file. For a more nuclear +# option (not recommended) you can uncomment the following to ignore the entire idea folder. +#.idea/ + +### venv ### +# Virtualenv +# http://iamzed.com/2009/05/07/a-primer-on-virtualenv/ +[Bb]in +[Ii]nclude +[Ll]ib +[Ll]ib64 +[Ll]ocal +#[Ss]cripts +pyvenv.cfg +pip-selfcheck.json + +### VisualStudioCode ### +.vscode/* +# !.vscode/settings.json +# !.vscode/tasks.json +# !.vscode/launch.json +# !.vscode/extensions.json +# !.vscode/*.code-snippets + +# Local History for Visual Studio Code +.history/ + +# Built Visual Studio Code Extensions +*.vsix + +### VisualStudioCode Patch ### +# Ignore all local history of files +.history +.ionide + +# Support for Project snippet scope + +# End of https://www.toptal.com/developers/gitignore/api/visualstudiocode,python,venv + +*.npy +checkpoints/* +ganime_results/* +data/* +*.avi +*.out +notebooks/model/p2p_v2/* +logs/* +interesting_logs/* +notebooks/model/vq-gan/train_output/* +notebooks/model/vq-gan/validation_output/* +notebooks/model/vq-gan/test_output/* +*.zip +flagged/* +notebooks/model/vq-gan/gpt_kny_light_large_256/* \ No newline at end of file diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000000000000000000000000000000000000..2f96b937d6d91bfb21d2ce0bc0e92c4fb5588899 --- /dev/null +++ b/Dockerfile @@ -0,0 +1,21 @@ +FROM tensorflow/tensorflow:2.7.0-gpu-jupyter +# Because of https://developer.nvidia.com/blog/updating-the-cuda-linux-gpg-repository-key/ and https://github.com/NVIDIA/nvidia-docker/issues/1631#issuecomment-1112828208 +RUN rm /etc/apt/sources.list.d/cuda.list +RUN rm /etc/apt/sources.list.d/nvidia-ml.list +RUN apt-key del 7fa2af80 +RUN apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2004/x86_64/3bf863cc.pub +RUN apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/machine-learning/repos/ubuntu2004/x86_64/7fa2af80.pub + +# Update and install ffmpeg +RUN apt-get -y update +RUN apt-get -y upgrade +RUN apt-get install -y ffmpeg + +# Setup environment +WORKDIR /GANime +ENV PROJECT_DIR=/GANime +COPY requirements.txt /GANime/requirements.txt +RUN pip install -r requirements.txt +COPY . . +RUN pip install -e . 
+EXPOSE 8888 \ No newline at end of file diff --git a/configs/colab.yaml b/configs/colab.yaml new file mode 100644 index 0000000000000000000000000000000000000000..957f93d5a87d9460e0df6a93bd2bcf07eada2eb3 --- /dev/null +++ b/configs/colab.yaml @@ -0,0 +1,50 @@ +model: + transformer_config: + #checkpoint_path: GANime/checkpoints/kny_video_full_gpt2_medium/checkpoint + remaining_frames_method: "own_embeddings" + transformer_type: "gpt2-medium" + first_stage_config: + checkpoint_path: GANime/checkpoints/kny_image_full_vgg19/checkpoint + vqvae_config: + beta: 0.25 + num_embeddings: 50257 + embedding_dim: 128 + autoencoder_config: + z_channels: 512 + channels: 32 + channels_multiplier: + - 2 + - 4 + - 8 + - 8 + num_res_blocks: 1 + attention_resolution: + - 16 + resolution: 128 + dropout: 0.0 + discriminator_config: + num_layers: 3 + filters: 64 + + loss_config: + discriminator: + loss: "hinge" + factor: 1.0 + iter_start: 16200 + weight: 0.3 + vqvae: + codebook_weight: 1.0 + perceptual_weight: 4.0 + perceptual_loss: "vgg19" + +train: + batch_size: 64 + accumulation_size: 1 + n_epochs: 2000 + len_x_train: 8000 + warmup_epoch_percentage: 0.15 + lr_start: 1e-5 + lr_max: 2.5e-4 + perceptual_loss_weight: 1.0 + n_frames_before: 1 + stop_ground_truth_after_epoch: 50 diff --git a/configs/kny_image.yaml b/configs/kny_image.yaml new file mode 100644 index 0000000000000000000000000000000000000000..f88afafe1591f47be2c2ae4b786c4b67fe2ae06a --- /dev/null +++ b/configs/kny_image.yaml @@ -0,0 +1,47 @@ +model: + checkpoint_path: ../../../checkpoints/kny_image_full_no_disc/checkpoint + vqvae_config: + beta: 0.25 + num_embeddings: 50257 + embedding_dim: 128 + autoencoder_config: + z_channels: 512 + channels: 32 + channels_multiplier: + - 2 + - 4 + - 8 + - 8 + num_res_blocks: 1 + attention_resolution: + - 16 + resolution: 128 + dropout: 0.0 + discriminator_config: + num_layers: 3 + filters: 64 + + loss_config: + discriminator: + loss: "hinge" + factor: 1.0 + iter_start: 5000 + weight: 0.8 + vqvae: + codebook_weight: 1.0 + perceptual_weight: 4.0 + perceptual_loss: "vgg19" # "vgg16", "vgg19", "style" + +trainer: + batch_size: 32 + n_epochs: 10000 + gen_lr: 3e-5 + disc_lr: 3e-5 + gen_beta_1: 0.5 + gen_beta_2: 0.9 + disc_beta_1: 0.5 + disc_beta_2: 0.9 + gen_clip_norm: 1.0 + disc_clip_norm: 1.0 + + diff --git a/configs/kny_image_full_style.yaml b/configs/kny_image_full_style.yaml new file mode 100644 index 0000000000000000000000000000000000000000..e56bfb225a8f846b139fc85ff802e7cb3cced385 --- /dev/null +++ b/configs/kny_image_full_style.yaml @@ -0,0 +1,47 @@ +model: + checkpoint_path: ../../../checkpoints/kny_image_full_style/checkpoint + vqvae_config: + beta: 0.25 + num_embeddings: 50257 + embedding_dim: 128 + autoencoder_config: + z_channels: 512 + channels: 32 + channels_multiplier: + - 2 + - 4 + - 8 + - 8 + num_res_blocks: 1 + attention_resolution: + - 16 + resolution: 128 + dropout: 0.0 + discriminator_config: + num_layers: 3 + filters: 64 + + loss_config: + discriminator: + loss: "hinge" + factor: 1.0 + iter_start: 50000000 + weight: 0.8 + vqvae: + codebook_weight: 1.0 + perceptual_weight: 4.0 + perceptual_loss: "style" # "vgg16", "vgg19", "style" + +trainer: + batch_size: 32 + n_epochs: 10000 + gen_lr: 8e-5 + disc_lr: 8e-5 + gen_beta_1: 0.5 + gen_beta_2: 0.9 + disc_beta_1: 0.5 + disc_beta_2: 0.9 + gen_clip_norm: 1.0 + disc_clip_norm: 1.0 + + diff --git a/configs/kny_image_full_vgg19.yaml b/configs/kny_image_full_vgg19.yaml new file mode 100644 index 
0000000000000000000000000000000000000000..ddff48a8e701cb9b0df6fa08e2324656285a1568 --- /dev/null +++ b/configs/kny_image_full_vgg19.yaml @@ -0,0 +1,47 @@ +model: + checkpoint_path: ../../../checkpoints/kny_image_full_vgg19/checkpoint + vqvae_config: + beta: 0.25 + num_embeddings: 50257 + embedding_dim: 128 + autoencoder_config: + z_channels: 512 + channels: 32 + channels_multiplier: + - 2 + - 4 + - 8 + - 8 + num_res_blocks: 1 + attention_resolution: + - 16 + resolution: 128 + dropout: 0.0 + discriminator_config: + num_layers: 3 + filters: 64 + + loss_config: + discriminator: + loss: "hinge" + factor: 1.0 + iter_start: 50000000 + weight: 0.8 + vqvae: + codebook_weight: 1.0 + perceptual_weight: 4.0 + perceptual_loss: "vgg19" # "vgg16", "vgg19", "style" + +trainer: + batch_size: 64 + n_epochs: 10000 + gen_lr: 3e-5 + disc_lr: 5e-5 + gen_beta_1: 0.5 + gen_beta_2: 0.9 + disc_beta_1: 0.5 + disc_beta_2: 0.9 + gen_clip_norm: 1.0 + disc_clip_norm: 1.0 + + diff --git a/configs/kny_transformer_light.yaml b/configs/kny_transformer_light.yaml new file mode 100644 index 0000000000000000000000000000000000000000..a2a73efce8a5eefa7f68ae07c8b8d244d8d78110 --- /dev/null +++ b/configs/kny_transformer_light.yaml @@ -0,0 +1,60 @@ +model: + transformer_config: + checkpoint_path: ../../../checkpoints/kny_video_light/checkpoint + # vocab_size: 50257 + # n_positions: 1024 + # n_embd: 1024 #1280 #768 + # n_layer: 24 #36 #12 + # n_head: 16 #20 #12 + # resid_pdrop: 0.1 + # embd_pdrop: 0.1 + # attn_pdrop: 0.1 + # remaining_frames_method: "concat" + # remaining_frames_method: "token_type_ids" + remaining_frames_method: "own_embeddings" + first_stage_config: + checkpoint_path: ../../../checkpoints/kny_image_light_discriminator/checkpoint + vqvae_config: + beta: 0.25 + num_embeddings: 64 + embedding_dim: 256 + autoencoder_config: + z_channels: 128 + channels: 64 + channels_multiplier: + - 1 + - 1 + - 2 + - 2 + - 4 + num_res_blocks: 1 + attention_resolution: + - 16 + resolution: 128 + dropout: 0.0 + discriminator_config: + num_layers: 3 + filters: 64 + + loss_config: + discriminator: + loss: "hinge" + factor: 1.0 + iter_start: 16200 + weight: 0.3 + vqvae: + codebook_weight: 1.0 + perceptual_weight: 4.0 + perceptual_loss: "style" + +train: + batch_size: 8 + accumulation_size: 8 + n_epochs: 2000 + len_x_train: 631 + warmup_epoch_percentage: 0.15 + lr_start: 1e-5 + lr_max: 2.5e-4 + perceptual_loss_weight: 1.0 + n_frames_before: 5 + stop_ground_truth_after_epoch: 100 diff --git a/configs/kny_video_gpt2_large.yaml b/configs/kny_video_gpt2_large.yaml new file mode 100644 index 0000000000000000000000000000000000000000..abbcce77dc44dd6c9685c681063a34c52a7317a1 --- /dev/null +++ b/configs/kny_video_gpt2_large.yaml @@ -0,0 +1,50 @@ +model: + transformer_config: + checkpoint_path: ../../../checkpoints/kny_video_full_gpt2_large_final/checkpoint + remaining_frames_method: "own_embeddings" + transformer_type: "gpt2-large" + first_stage_config: + checkpoint_path: ../../../checkpoints/kny_image_full_vgg19/checkpoint + vqvae_config: + beta: 0.25 + num_embeddings: 50257 + embedding_dim: 128 + autoencoder_config: + z_channels: 512 + channels: 32 + channels_multiplier: + - 2 + - 4 + - 8 + - 8 + num_res_blocks: 1 + attention_resolution: + - 16 + resolution: 128 + dropout: 0.0 + discriminator_config: + num_layers: 3 + filters: 64 + + loss_config: + discriminator: + loss: "hinge" + factor: 1.0 + iter_start: 16200 + weight: 0.3 + vqvae: + codebook_weight: 1.0 + perceptual_weight: 4.0 + perceptual_loss: "vgg19" + +train: + batch_size: 64 + 
accumulation_size: 1 + n_epochs: 10000 + len_x_train: 28213 + warmup_epoch_percentage: 0.15 + lr_start: 1e-5 + lr_max: 2.5e-4 + perceptual_loss_weight: 1.0 + n_frames_before: 1 + stop_ground_truth_after_epoch: 1000 diff --git a/configs/kny_video_gpt2_large_gradio.yaml b/configs/kny_video_gpt2_large_gradio.yaml new file mode 100644 index 0000000000000000000000000000000000000000..1309dc7eaa1fea4e931fef159d31941fbe3d8db0 --- /dev/null +++ b/configs/kny_video_gpt2_large_gradio.yaml @@ -0,0 +1,50 @@ +model: + transformer_config: + checkpoint_path: ./checkpoints/kny_video_full_gpt2_large_final/checkpoint + remaining_frames_method: "own_embeddings" + transformer_type: "gpt2-large" + first_stage_config: + checkpoint_path: ./checkpoints/kny_image_full_vgg19/checkpoint + vqvae_config: + beta: 0.25 + num_embeddings: 50257 + embedding_dim: 128 + autoencoder_config: + z_channels: 512 + channels: 32 + channels_multiplier: + - 2 + - 4 + - 8 + - 8 + num_res_blocks: 1 + attention_resolution: + - 16 + resolution: 128 + dropout: 0.0 + discriminator_config: + num_layers: 3 + filters: 64 + + loss_config: + discriminator: + loss: "hinge" + factor: 1.0 + iter_start: 16200 + weight: 0.3 + vqvae: + codebook_weight: 1.0 + perceptual_weight: 4.0 + perceptual_loss: "vgg19" + +train: + batch_size: 64 + accumulation_size: 1 + n_epochs: 10000 + len_x_train: 28213 + warmup_epoch_percentage: 0.15 + lr_start: 1e-5 + lr_max: 2.5e-4 + perceptual_loss_weight: 1.0 + n_frames_before: 1 + stop_ground_truth_after_epoch: 1000 diff --git a/configs/kny_video_gpt2_medium.yaml b/configs/kny_video_gpt2_medium.yaml new file mode 100644 index 0000000000000000000000000000000000000000..e71dd7a2c15d7630330f49599bad2ed0f2a085e6 --- /dev/null +++ b/configs/kny_video_gpt2_medium.yaml @@ -0,0 +1,50 @@ +model: + transformer_config: + checkpoint_path: ./checkpoints/kny_video_full_gpt2_medium/checkpoint + remaining_frames_method: "own_embeddings" + transformer_type: "gpt2-medium" + first_stage_config: + checkpoint_path: ./checkpoints/kny_image_full_vgg19/checkpoint + vqvae_config: + beta: 0.25 + num_embeddings: 50257 + embedding_dim: 128 + autoencoder_config: + z_channels: 512 + channels: 32 + channels_multiplier: + - 2 + - 4 + - 8 + - 8 + num_res_blocks: 1 + attention_resolution: + - 16 + resolution: 128 + dropout: 0.0 + discriminator_config: + num_layers: 3 + filters: 64 + + loss_config: + discriminator: + loss: "hinge" + factor: 1.0 + iter_start: 16200 + weight: 0.3 + vqvae: + codebook_weight: 1.0 + perceptual_weight: 4.0 + perceptual_loss: "vgg19" + +train: + batch_size: 64 + accumulation_size: 1 + n_epochs: 500 + len_x_train: 28213 + warmup_epoch_percentage: 0.15 + lr_start: 5e-6 + lr_max: 1e-4 + perceptual_loss_weight: 1.0 + n_frames_before: 5 + stop_ground_truth_after_epoch: 200 diff --git a/configs/kny_video_gpt2_xl.yaml b/configs/kny_video_gpt2_xl.yaml new file mode 100644 index 0000000000000000000000000000000000000000..7c8c6276dc0e344b1f921181772e2ad11216f3ac --- /dev/null +++ b/configs/kny_video_gpt2_xl.yaml @@ -0,0 +1,50 @@ +model: + transformer_config: + # checkpoint_path: ../../../checkpoints/kny_video_full_gpt2_xl/checkpoint + remaining_frames_method: "own_embeddings" + transformer_type: "gpt2-xl" + first_stage_config: + checkpoint_path: ../../../checkpoints/kny_image_full_vgg19/checkpoint + vqvae_config: + beta: 0.25 + num_embeddings: 50257 + embedding_dim: 128 + autoencoder_config: + z_channels: 512 + channels: 32 + channels_multiplier: + - 2 + - 4 + - 8 + - 8 + num_res_blocks: 1 + attention_resolution: + - 16 + resolution: 128 + 
dropout: 0.0 + discriminator_config: + num_layers: 3 + filters: 64 + + loss_config: + discriminator: + loss: "hinge" + factor: 1.0 + iter_start: 16200 + weight: 0.3 + vqvae: + codebook_weight: 1.0 + perceptual_weight: 4.0 + perceptual_loss: "vgg19" + +train: + batch_size: 64 + accumulation_size: 1 + n_epochs: 500 + len_x_train: 28213 + warmup_epoch_percentage: 0.15 + lr_start: 5e-6 + lr_max: 1e-4 + perceptual_loss_weight: 1.0 + n_frames_before: 1 + stop_ground_truth_after_epoch: 200 diff --git a/ganime/__main__.py b/ganime/__main__.py new file mode 100644 index 0000000000000000000000000000000000000000..3474935a7850cb2ecb4c16e83ac410350020ba1f --- /dev/null +++ b/ganime/__main__.py @@ -0,0 +1,4 @@ +from ganime import app + +if __name__ == "__main__": + app.run() diff --git a/ganime/app.py b/ganime/app.py new file mode 100644 index 0000000000000000000000000000000000000000..4661d151aa74938e4d661645b399e2fd732f6996 --- /dev/null +++ b/ganime/app.py @@ -0,0 +1,212 @@ +import os + +import click +import omegaconf +import ray +from pyprojroot.pyprojroot import here +from ray import tune +from ray.train import Trainer +from ray.tune.schedulers import AsyncHyperBandScheduler +from ray.tune.suggest import ConcurrencyLimiter +from ray.tune.suggest.optuna import OptunaSearch + +from ganime.trainer.ganime import TrainableGANime + +import os + +os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" # see issue #152 +os.environ["CUDA_VISIBLE_DEVICES"] = "1, 2, 3, 4, 5, 6" + + +def get_metric_direction(metric: str): + if "loss" in metric: + return "min" + else: + raise ValueError(f"Unknown metric: {metric}") + + +def trial_name_id(trial): + return f"{trial.trainable_name}" + + +def trial_dirname_creator(trial): + return f"{trial.trial_id}" + + +def get_search_space(model): + if model == "vqgan": + return { + # "beta": tune.uniform(0.1, 1.0), + "num_embeddings": tune.choice([64, 128, 256]), + "embedding_dim": tune.choice([128, 256, 512, 1024]), + "z_channels": tune.choice([64, 128, 256]), + "channels": tune.choice([64, 128, 256]), + "channels_multiplier": tune.choice( + [ + [1, 2, 4], + [1, 1, 2, 2], + [1, 2, 2, 4], + [1, 1, 2, 2, 4], + ] + ), + "attention_resolution": tune.choice([[16], [32], [16, 32]]), + "batch_size": tune.choice([8, 16]), + "dropout": tune.choice([0.0, 0.1, 0.2]), + "weight": tune.quniform(0.1, 1.0, 0.1), + "codebook_weight": tune.quniform(0.2, 2.0, 0.2), + "perceptual_weight": tune.quniform(0.5, 5.0, 0.5), + "gen_lr": tune.qloguniform(1e-5, 1e-3, 1e-5), + "disc_lr": tune.qloguniform(1e-5, 1e-3, 1e-5), + "gen_beta_1": tune.quniform(0.5, 0.9, 0.1), + "gen_beta_2": tune.quniform(0.9, 0.999, 0.001), + "disc_beta_1": tune.quniform(0.5, 0.9, 0.1), + "disc_beta_2": tune.quniform(0.9, 0.999, 0.001), + "gen_clip_norm": tune.choice([1.0, None]), + "disc_clip_norm": tune.choice([1.0, None]), + } + elif model == "gpt": + return { + "remaining_frames_method": tune.choice( + ["concat", "token_type_ids", "own_embeddings"] + ), + # "batch_size": tune.choice([8, 16]), + "lr_max": tune.qloguniform(1e-5, 1e-3, 5e-5), + "lr_start": tune.sample_from(lambda spec: spec.config.lr_max / 10), + "perceptual_loss_weight": tune.quniform(0.0, 1.0, 0.1), + "n_frames_before": tune.randint(1, 10), + } + + +def tune_ganime( + experiment_name: str, + dataset_name: str, + config_file: str, + model: str, + metric: str, + epochs: int, + num_samples: int, + num_cpus: int, + num_gpus: int, + max_concurrent_trials: int, +): + + dataset_path = here("data") + analysis = tune.run( + TrainableGANime, + name=experiment_name, + 
search_alg=ConcurrencyLimiter( + OptunaSearch(), max_concurrent=max_concurrent_trials + ), + scheduler=AsyncHyperBandScheduler(max_t=epochs, grace_period=5), + metric=metric, + mode=get_metric_direction(metric), + num_samples=num_samples, + stop={"training_iteration": epochs}, + local_dir="./ganime_results", + config={ + "dataset_name": dataset_name, + "dataset_path": dataset_path, + "model": model, + "config_file": config_file, + "hyperparameters": get_search_space(model), + }, + resources_per_trial={ + "cpu": num_cpus // max_concurrent_trials, + "gpu": num_gpus / max_concurrent_trials, + }, + trial_name_creator=trial_name_id, + trial_dirname_creator=trial_dirname_creator, + ) + best_loss = analysis.get_best_config(metric="total_loss", mode="min") + # best_accuracy = analysis.get_best_config(metric="accuracy", mode="max") + print(f"Best loss config: {best_loss}") + # print(f"Best accuracy config: {best_accuracy}") + return analysis + + +@click.command() +@click.option( + "--dataset", + type=click.Choice( + ["moving_mnist_images", "kny_images", "kny_images_light"], case_sensitive=False + ), + default="kny_images_light", + help="Dataset to use", +) +@click.option( + "--model", + type=click.Choice(["vqgan", "gpt"], case_sensitive=False), + default="vqgan", + help="Model to use", +) +@click.option( + "--epochs", + default=500, + help="Number of epochs to run", +) +@click.option( + "--num_samples", + default=100, + help="Total number of trials to run", +) +@click.option( + "--num_cpus", + default=64, + help="Number of cpus to use", +) +@click.option( + "--num_gpus", + default=6, + help="Number of gpus to use", +) +@click.option( + "--max_concurrent_trials", + default=6, + help="Maximum number of concurrent trials", +) +@click.option( + "--metric", + type=click.Choice( + ["total_loss", "reconstruction_loss", "vq_loss", "disc_loss"], + case_sensitive=False, + ), + default="total_loss", + help="The metric used to select the best trial", +) +@click.option( + "--experiment_name", + default="kny_images_light_v2", + help="The name of the experiment for logging in Tensorboard", +) +@click.option( + "--config_file", + default="kny_image.yaml", + help="The name of the config file located inside ./config", +) +def run( + experiment_name: str, + config_file: str, + dataset: str, + model: str, + epochs: int, + num_samples: int, + num_cpus: int, + num_gpus: int, + max_concurrent_trials: int, + metric: str, +): + config_file = here(os.path.join("configs", config_file)) + + ray.init(num_cpus=num_cpus, num_gpus=num_gpus) + tune_ganime( + experiment_name=experiment_name, + dataset_name=dataset, + config_file=config_file, + model=model, + epochs=epochs, + num_samples=num_samples, + num_cpus=num_cpus, + num_gpus=num_gpus, + max_concurrent_trials=max_concurrent_trials, + metric=metric, + ) diff --git a/ganime/configs/__init__.py b/ganime/configs/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/ganime/configs/model_configs.py b/ganime/configs/model_configs.py new file mode 100644 index 0000000000000000000000000000000000000000..07479aa8844f4ec0dce18453fad034162628f360 --- /dev/null +++ b/ganime/configs/model_configs.py @@ -0,0 +1,70 @@ +from dataclasses import dataclass +from typing import List +try: + from typing import Literal +except ImportError: + from typing_extensions import Literal + + +@dataclass +class GPTConfig: + n_layer: int + n_head: int + n_embedding: int + vocab_size: int + block_size: int + 
embedding_percentage_drop: float + attention_percentage_drop: float + + +@dataclass +class VQVAEConfig: + beta: float + num_embeddings: int + embedding_dim: int + + +@dataclass +class AutoencoderConfig: + z_channels: int + channels: int + channels_multiplier: List[int] + num_res_blocks: int + attention_resolution: List[int] + resolution: int + dropout: float + + +@dataclass +class DiscriminatorConfig: + num_layers: int + filters: int + + +@dataclass +class DiscriminatorLossConfig: + loss: Literal["hinge", "vanilla"] + factor: float + iter_start: int + weight: float + + +@dataclass +class VQVAELossConfig: + codebook_weight: float + perceptual_weight: float + + +@dataclass +class LossConfig: + discriminator: DiscriminatorLossConfig + vqvae: VQVAELossConfig + perceptual_loss: str + + +@dataclass +class ModelConfig: + vqvae_config: VQVAEConfig + autoencoder_config: AutoencoderConfig + discriminator_config: DiscriminatorConfig + loss_config: LossConfig diff --git a/ganime/data/__init__.py b/ganime/data/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/ganime/data/base.py b/ganime/data/base.py new file mode 100644 index 0000000000000000000000000000000000000000..a36556ec75ec7484c847c5ccefa324fdcd358bd3 --- /dev/null +++ b/ganime/data/base.py @@ -0,0 +1,282 @@ +from typing import Tuple +import numpy as np +import tensorflow as tf +import os +from tensorflow.keras.utils import Sequence +from abc import ABC, abstractmethod +from typing import Literal +import math +from ganime.data.experimental import ImageDataset + + +# class SequenceDataset(Sequence): +# def __init__( +# self, +# dataset_path: str, +# batch_size: int, +# split: Literal["train", "validation", "test"] = "train", +# ): +# self.batch_size = batch_size +# self.split = split +# self.data = self.load_data(dataset_path, split) +# self.data = self.preprocess_data(self.data) + +# self.indices = np.arange(self.data.shape[0]) +# self.on_epoch_end() + +# @abstractmethod +# def load_data(self, dataset_path: str, split: str) -> np.ndarray: +# pass + +# def preprocess_data(self, data: np.ndarray) -> np.ndarray: +# return data + +# def __len__(self): +# return math.ceil(len(self.data) / self.batch_size) + +# def __getitem__(self, idx): +# inds = self.indices[idx * self.batch_size : (idx + 1) * self.batch_size] +# batch_x = self.data[inds] +# batch_y = batch_x + +# return batch_x, batch_y + +# def get_fixed_batch(self, idx): +# self.fixed_indices = ( +# self.fixed_indices +# if hasattr(self, "fixed_indices") +# else self.indices[ +# idx * self.batch_size : (idx + 1) * self.batch_size +# ].copy() +# ) +# batch_x = self.data[self.fixed_indices] +# batch_y = batch_x + +# return batch_x, batch_y + +# def on_epoch_end(self): +# np.random.shuffle(self.indices) + + +# def load_kny_images( +# dataset_path: str, batch_size: int +# ) -> Tuple[tf.data.Dataset, tf.data.Dataset, tuple]: +# import skvideo.io + +# if os.path.exists(os.path.join(dataset_path, "kny", "kny_images.npy")): +# data = np.load(os.path.join(dataset_path, "kny", "kny_images.npy")) +# else: +# data = skvideo.io.vread(os.path.join(dataset_path, "kny", "01.mp4")) +# np.random.shuffle(data) + +# def _preprocess(sample): +# image = tf.cast(sample, tf.float32) / 255.0 # Scale to unit interval. +# # video = video < tf.random.uniform(tf.shape(video)) # Randomly binarize.
+# image = tf.image.resize(image, [64, 64]) + +# return image, image + +# train_dataset = ( +# tf.data.Dataset.from_tensor_slices(data[:5000]) +# .map(_preprocess) +# .batch(batch_size) +# .prefetch(tf.data.AUTOTUNE) +# .shuffle(int(10e3)) +# ) +# test_dataset = ( +# tf.data.Dataset.from_tensor_slices(data[5000:6000]) +# .map(_preprocess) +# .batch(batch_size) +# .prefetch(tf.data.AUTOTUNE) +# .shuffle(int(10e3)) +# ) + +# return train_dataset, test_dataset, data.shape[1:] + + +# def load_moving_mnist_vae( +# dataset_path: str, batch_size: int +# ) -> Tuple[tf.data.Dataset, tf.data.Dataset, tuple]: +# data = np.load(os.path.join(dataset_path, "moving_mnist", "mnist_test_seq.npy")) +# data.shape + +# # We can see that data is of shape (window, n_samples, width, height) +# # But we want for keras something of shape (n_samples, window, width, height) +# data = np.moveaxis(data, 0, 1) +# # Also expand dimensions to have channels at the end (n_samples, window, width, height, channels) +# data = np.expand_dims(data, axis=-1) + +# def _preprocess(sample): +# video = tf.cast(sample, tf.float32) / 255.0 # Scale to unit interval. +# # video = video < tf.random.uniform(tf.shape(video)) # Randomly binarize. +# return video, video + +# train_dataset = ( +# tf.data.Dataset.from_tensor_slices(data[:9000]) +# .map(_preprocess) +# .batch(batch_size) +# .prefetch(tf.data.AUTOTUNE) +# .shuffle(int(10e3)) +# ) +# test_dataset = ( +# tf.data.Dataset.from_tensor_slices(data[9000:]) +# .map(_preprocess) +# .batch(batch_size) +# .prefetch(tf.data.AUTOTUNE) +# .shuffle(int(10e3)) +# ) + +# return train_dataset, test_dataset, data.shape[1:] + + +# def load_moving_mnist( +# dataset_path: str, batch_size: int +# ) -> Tuple[tf.data.Dataset, tf.data.Dataset, tuple]: +# data = np.load(os.path.join(dataset_path, "moving_mnist", "mnist_test_seq.npy")) +# data.shape + +# # We can see that data is of shape (window, n_samples, width, height) +# # But we want for keras something of shape (n_samples, window, width, height) +# data = np.moveaxis(data, 0, 1) +# # Also expand dimensions to have channels at the end (n_samples, window, width, height, channels) +# data = np.expand_dims(data, axis=-1) + +# def _preprocess(sample): +# video = tf.cast(sample, tf.float32) / 255.0 # Scale to unit interval. +# # video = video < tf.random.uniform(tf.shape(video)) # Randomly binarize. +# first_frame = video[0:1, ...] +# last_frame = video[-1:, ...] 
+# first_last = tf.concat([first_frame, last_frame], axis=0) + +# return first_last, video + +# train_dataset = ( +# tf.data.Dataset.from_tensor_slices(data[:9000]) +# .map(_preprocess) +# .batch(batch_size) +# .prefetch(tf.data.AUTOTUNE) +# .shuffle(int(10e3)) +# ) +# test_dataset = ( +# tf.data.Dataset.from_tensor_slices(data[9000:]) +# .map(_preprocess) +# .batch(batch_size) +# .prefetch(tf.data.AUTOTUNE) +# .shuffle(int(10e3)) +# ) + +# return train_dataset, test_dataset, data.shape[1:] + + +# def load_mnist( +# dataset_path: str, batch_size: int +# ) -> Tuple[tf.data.Dataset, tf.data.Dataset, tuple]: +# data = np.load(os.path.join(dataset_path, "moving_mnist", "mnist_test_seq.npy")) +# data.shape + +# # We can see that data is of shape (window, n_samples, width, height) +# # But we want for keras something of shape (n_samples, window, width, height) +# data = np.moveaxis(data, 0, 1) +# # Also expand dimensions to have channels at the end (n_samples, window, width, height, channels) +# data = np.expand_dims(data, axis=-1) + +# def _preprocess(sample): +# video = tf.cast(sample, tf.float32) / 255.0 # Scale to unit interval. +# # video = video < tf.random.uniform(tf.shape(video)) # Randomly binarize. +# first_frame = video[0, ...] + +# first_frame = tf.image.grayscale_to_rgb(first_frame) + +# return first_frame, first_frame + +# train_dataset = ( +# tf.data.Dataset.from_tensor_slices(data[:9000]) +# .map(_preprocess) +# .batch(batch_size) +# .prefetch(tf.data.AUTOTUNE) +# .shuffle(int(10e3)) +# ) +# test_dataset = ( +# tf.data.Dataset.from_tensor_slices(data[9000:]) +# .map(_preprocess) +# .batch(batch_size) +# .prefetch(tf.data.AUTOTUNE) +# .shuffle(int(10e3)) +# ) + +# return train_dataset, test_dataset, data.shape[1:] +def preprocess_image(element): + element = tf.reshape(element, (tf.shape(element)[0], tf.shape(element)[1], 3)) + element = tf.cast(element, tf.float32) / 255.0 + return element, element + + +def load_kny_images_light(dataset_path, batch_size): + dataset_length = 34045 + path = os.path.join(dataset_path, "kny", "images_tfrecords_light") + dataset = ImageDataset(path).load() + dataset = dataset.shuffle( + dataset_length, reshuffle_each_iteration=True, seed=10 + ).map(preprocess_image, num_parallel_calls=tf.data.AUTOTUNE) + + train_size = int(dataset_length * 0.8) + validation_size = int(dataset_length * 0.1) + + train_ds = dataset.take(train_size) + validation_ds = dataset.skip(train_size).take(validation_size) + test_ds = dataset.skip(train_size + validation_size).take(validation_size) + + train_ds = train_ds.batch(batch_size, drop_remainder=True).prefetch( + tf.data.AUTOTUNE + ) + validation_ds = validation_ds.batch(batch_size, drop_remainder=True).prefetch( + tf.data.AUTOTUNE + ) + test_ds = test_ds.batch(batch_size, drop_remainder=True).prefetch(tf.data.AUTOTUNE) + + return train_ds, validation_ds, test_ds + + +def load_kny_images(dataset_path, batch_size): + dataset_length = 52014 + path = os.path.join(dataset_path, "kny", "images_tfrecords") + dataset = ImageDataset(path).load() + dataset = dataset.shuffle(dataset_length, reshuffle_each_iteration=True).map( + preprocess_image, num_parallel_calls=tf.data.AUTOTUNE + ) + + train_size = int(dataset_length * 0.8) + validation_size = int(dataset_length * 0.1) + + train_ds = dataset.take(train_size) + validation_ds = dataset.skip(train_size).take(validation_size) + test_ds = dataset.skip(train_size + validation_size).take(validation_size) + + train_ds = train_ds.batch(batch_size, drop_remainder=True).prefetch( + 
tf.data.AUTOTUNE + ) + validation_ds = validation_ds.batch(batch_size, drop_remainder=True).prefetch( + tf.data.AUTOTUNE + ) + test_ds = test_ds.batch(batch_size, drop_remainder=True).prefetch(tf.data.AUTOTUNE) + + return train_ds, validation_ds, test_ds + + +def load_dataset( + dataset_name: str, dataset_path: str, batch_size: int +) -> Tuple[tf.data.Dataset, tf.data.Dataset, tf.data.Dataset]: + # if dataset_name == "moving_mnist_vae": + # return load_moving_mnist_vae(dataset_path, batch_size) + # elif dataset_name == "moving_mnist": + # return load_moving_mnist(dataset_path, batch_size) + # elif dataset_name == "mnist": + # return load_mnist(dataset_path, batch_size) + # elif dataset_name == "kny_images": + # return load_kny_images(dataset_path, batch_size) + if dataset_name == "kny_images": + return load_kny_images(dataset_path, batch_size) + if dataset_name == "kny_images_light": + return load_kny_images_light(dataset_path, batch_size) + else: + raise ValueError(f"Unknown dataset: {dataset_name}") diff --git a/ganime/data/experimental.py b/ganime/data/experimental.py new file mode 100644 index 0000000000000000000000000000000000000000..8eb055629a553cd99b7ea9964d59f8e8b6e9e0ca --- /dev/null +++ b/ganime/data/experimental.py @@ -0,0 +1,222 @@ +from abc import ABC, abstractclassmethod, abstractmethod +import glob +import math +import os +from typing import Dict +from typing_extensions import dataclass_transform + +import numpy as np +import tensorflow as tf +from tqdm.auto import tqdm + + +def _bytes_feature(value): + """Returns a bytes_list from a string / byte.""" + if isinstance(value, type(tf.constant(0))): # if value is a tensor + value = value.numpy() # get value of tensor + return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value])) + + +def _float_feature(value): + """Returns a float_list from a float / double.""" + return tf.train.Feature(float_list=tf.train.FloatList(value=[value])) + + +def _int64_feature(value): + """Returns an int64_list from a bool / enum / int / uint.""" + return tf.train.Feature(int64_list=tf.train.Int64List(value=[value])) + + +def serialize_array(array): + array = tf.io.serialize_tensor(array) + return array + + +class Dataset(ABC): + def __init__(self, dataset_path: str): + self.dataset_path = dataset_path + + @classmethod + def _parse_single_element(cls, element) -> tf.train.Example: + + features = tf.train.Features(feature=cls._get_features(element)) + + return tf.train.Example(features=features) + + @abstractclassmethod + def _get_features(cls, element) -> Dict[str, tf.train.Feature]: + pass + + @abstractclassmethod + def _parse_tfr_element(cls, element): + pass + + @classmethod + def write_to_tfr(cls, data: np.ndarray, out_dir: str, filename: str): + if not os.path.exists(out_dir): + os.makedirs(out_dir) + + # Write all elements to a single tfrecord file + single_file_name = cls.__write_to_single_tfr(data, out_dir, filename) + + # The optimal size for a single tfrecord file is around 100 MB.
Get the number of files that need to be created + number_splits = cls.__get_number_splits(single_file_name) + + if number_splits > 1: + os.remove(single_file_name) + cls.__write_to_multiple_tfr(data, out_dir, filename, number_splits) + + @classmethod + def __write_to_multiple_tfr( + cls, data: np.array, out_dir: str, filename: str, n_splits: int + ): + + file_count = 0 + + max_files = math.ceil(data.shape[0] / n_splits) + + print(f"Creating {n_splits} files with {max_files} elements each.") + + for i in tqdm(range(n_splits)): + current_shard_name = os.path.join( + out_dir, + f"{filename}.tfrecords-{str(i).zfill(len(str(n_splits)))}-of-{n_splits}", + ) + writer = tf.io.TFRecordWriter(current_shard_name) + + current_shard_count = 0 + while current_shard_count < max_files: # as long as our shard is not full + # get the index of the file that we want to parse now + index = i * max_files + current_shard_count + if index >= len( + data + ): # when we have consumed the whole data, preempt generation + break + + current_element = data[index] + + # create the required Example representation + out = cls._parse_single_element(element=current_element) + + writer.write(out.SerializeToString()) + current_shard_count += 1 + file_count += 1 + + writer.close() + print(f"\nWrote {file_count} elements to TFRecord") + return file_count + + @classmethod + def __get_number_splits(cls, filename: str): + target_size = 100 * 1024 * 1024 # 100mb + + single_file_size = os.path.getsize(filename) + number_splits = math.ceil(single_file_size / target_size) + return number_splits + + @classmethod + def __write_to_single_tfr(cls, data: np.array, out_dir: str, filename: str): + + current_path_name = os.path.join( + out_dir, + f"{filename}.tfrecords-0-of-1", + ) + + writer = tf.io.TFRecordWriter(current_path_name) + for element in tqdm(data): + writer.write(cls._parse_single_element(element).SerializeToString()) + writer.close() + + return current_path_name + + def load(self) -> tf.data.TFRecordDataset: + path = self.dataset_path + dataset = None + + if os.path.isdir(path): + dataset = self._load_folder(path) + elif os.path.isfile(path): + dataset = self._load_file(path) + else: + raise ValueError(f"Path {path} is not a valid file or folder.") + + dataset = dataset.map(self._parse_tfr_element) + return dataset + + def _load_file(self, path) -> tf.data.TFRecordDataset: + return tf.data.TFRecordDataset(path) + + def _load_folder(self, path) -> tf.data.TFRecordDataset: + + return tf.data.TFRecordDataset( + glob.glob(os.path.join(path, "**/*.tfrecords*"), recursive=True) + ) + + +class VideoDataset(Dataset): + @classmethod + def _get_features(cls, element) -> Dict[str, tf.train.Feature]: + return { + "frames": _int64_feature(element.shape[0]), + "height": _int64_feature(element.shape[1]), + "width": _int64_feature(element.shape[2]), + "depth": _int64_feature(element.shape[3]), + "raw_video": _bytes_feature(serialize_array(element)), + } + + @classmethod + def _parse_tfr_element(cls, element): + # use the same structure as above; it's kinda an outline of the structure we now want to create + data = { + "frames": tf.io.FixedLenFeature([], tf.int64), + "height": tf.io.FixedLenFeature([], tf.int64), + "width": tf.io.FixedLenFeature([], tf.int64), + "raw_video": tf.io.FixedLenFeature([], tf.string), + "depth": tf.io.FixedLenFeature([], tf.int64), + } + + content = tf.io.parse_single_example(element, data) + + frames = content["frames"] + height = content["height"] + width = content["width"] + depth = content["depth"] + raw_video = 
content["raw_video"] + + # get our 'feature'-- our image -- and reshape it appropriately + feature = tf.io.parse_tensor(raw_video, out_type=tf.uint8) + feature = tf.reshape(feature, shape=[frames, height, width, depth]) + return feature + + +class ImageDataset(Dataset): + @classmethod + def _get_features(cls, element) -> Dict[str, tf.train.Feature]: + return { + "height": _int64_feature(element.shape[0]), + "width": _int64_feature(element.shape[1]), + "depth": _int64_feature(element.shape[2]), + "raw_image": _bytes_feature(serialize_array(element)), + } + + @classmethod + def _parse_tfr_element(cls, element): + # use the same structure as above; it's kinda an outline of the structure we now want to create + data = { + "height": tf.io.FixedLenFeature([], tf.int64), + "width": tf.io.FixedLenFeature([], tf.int64), + "raw_image": tf.io.FixedLenFeature([], tf.string), + "depth": tf.io.FixedLenFeature([], tf.int64), + } + + content = tf.io.parse_single_example(element, data) + + height = content["height"] + width = content["width"] + depth = content["depth"] + raw_image = content["raw_image"] + + # get our 'feature'-- our image -- and reshape it appropriately + feature = tf.io.parse_tensor(raw_image, out_type=tf.uint8) + feature = tf.reshape(feature, shape=[height, width, depth]) + return feature diff --git a/ganime/data/kny.py b/ganime/data/kny.py new file mode 100644 index 0000000000000000000000000000000000000000..54223116721c5ac11d759359e8645f506fb85869 --- /dev/null +++ b/ganime/data/kny.py @@ -0,0 +1,19 @@ +import os + +import numpy as np + +from .base import SequenceDataset + + +class KNYImage(SequenceDataset): + def load_data(self, dataset_path: str, split: str) -> np.ndarray: + data = np.load(os.path.join(dataset_path, "kny", "kny_images_64x128.npy")) + if split == "train": + data = data[:-5000] + else: + data = data[-5000:] + + return data + + def preprocess_data(self, data: np.ndarray) -> np.ndarray: + return data / 255 diff --git a/ganime/data/mnist.py b/ganime/data/mnist.py new file mode 100644 index 0000000000000000000000000000000000000000..bc4e62a6b5e844a7bc6725c85f22b12e537b3f6f --- /dev/null +++ b/ganime/data/mnist.py @@ -0,0 +1,103 @@ +import glob +import os +from typing import Literal + +import numpy as np + +from .base import SequenceDataset +import math + + +class MovingMNISTImage(SequenceDataset): + def load_data(self, dataset_path: str, split: str) -> np.ndarray: + data = np.load(os.path.join(dataset_path, "moving_mnist", "mnist_test_seq.npy")) + # Data is of shape (window, n_samples, width, height) + # But we want for keras something of shape (n_samples, window, width, height) + data = np.moveaxis(data, 0, 1) + # Also expand dimensions to have channels at the end (n_samples, window, width, height, channels) + data = np.expand_dims(data, axis=-1) + if split == "train": + data = data[:-1000] + else: + data = data[-1000:] + + data = np.concatenate([data, data, data], axis=-1) + + return data + + def __getitem__(self, idx): + inds = self.indices[idx * self.batch_size : (idx + 1) * self.batch_size] + batch_x = self.data[inds, 0, ...] + batch_y = self.data[inds, 1, ...] 
+ + return batch_x, batch_y + + def preprocess_data(self, data: np.ndarray) -> np.ndarray: + return data / 255 + + +class MovingMNIST(SequenceDataset): + def __init__( + self, + dataset_path: str, + batch_size: int, + split: Literal["train", "validation", "test"] = "train", + ): + self.batch_size = batch_size + self.split = split + root_path = os.path.join(dataset_path, "moving_mnist", split) + self.paths = glob.glob(os.path.join(root_path, "*.npy")) + # self.data = self.preprocess_data(self.data) + + self.indices = np.arange(len(self.paths)) + self.on_epoch_end() + + # def load_data(self, dataset_path: str, split: str) -> np.ndarray: + # data = np.load(os.path.join(dataset_path, "moving_mnist", "mnist_test_seq.npy")) + # # Data is of shape (window, n_samples, width, height) + # # But we want for keras something of shape (n_samples, window, width, height) + # data = np.moveaxis(data, 0, 1) + # # Also expand dimensions to have channels at the end (n_samples, window, width, height, channels) + # data = np.expand_dims(data, axis=-1) + # if split == "train": + # data = data[:100] + # else: + # data = data[100:110] + + # data = np.concatenate([data, data, data], axis=-1) + + # return data + + def __len__(self): + return math.ceil(len(self.paths) / self.batch_size) + + def __getitem__(self, idx): + inds = self.indices[idx * self.batch_size : (idx + 1) * self.batch_size] + data = self.load_indices(inds) + batch_x = np.concatenate([data[:, 0:1, ...], data[:, -1:, ...]], axis=1) + batch_y = data[:, 1:, ...] + + return batch_x, batch_y + + def get_fixed_batch(self, idx): + self.fixed_indices = ( + self.fixed_indices + if hasattr(self, "fixed_indices") + else self.indices[ + idx * self.batch_size : (idx + 1) * self.batch_size + ].copy() + ) + data = self.load_indices(self.fixed_indices) + batch_x = np.concatenate([data[:, 0:1, ...], data[:, -1:, ...]], axis=1) + batch_y = data[:, 1:, ...] 
+ + return batch_x, batch_y + + def load_indices(self, indices): + paths_to_load = [self.paths[index] for index in indices] + data = [np.load(path) for path in paths_to_load] + data = np.array(data) + return self.preprocess_data(data) + + def preprocess_data(self, data: np.ndarray) -> np.ndarray: + return data / 255 diff --git a/ganime/metrics/image.py b/ganime/metrics/image.py new file mode 100644 index 0000000000000000000000000000000000000000..ff411d2b1d59a1ebebf3a3b8ea6253e64ffe2ec2 --- /dev/null +++ b/ganime/metrics/image.py @@ -0,0 +1,70 @@ +import numpy as np +import tensorflow as tf +from scipy import linalg +from tensorflow.keras.applications.inception_v3 import InceptionV3, preprocess_input +from tqdm.auto import tqdm + +inceptionv3 = InceptionV3(include_top=False, weights="imagenet", pooling="avg") + + +def resize_images(images, new_shape): + images = tf.image.resize(images, new_shape) + return images + + +def calculate_fid(real_embeddings, generated_embeddings): + # calculate mean and covariance statistics + mu1, sigma1 = real_embeddings.mean(axis=0), np.cov(real_embeddings, rowvar=False) + mu2, sigma2 = generated_embeddings.mean(axis=0), np.cov( + generated_embeddings, rowvar=False + ) + # calculate sum squared difference between means + ssdiff = np.sum((mu1 - mu2) ** 2.0) + # calculate sqrt of product between cov + covmean = linalg.sqrtm(sigma1.dot(sigma2)) + # check and correct imaginary numbers from sqrt + if np.iscomplexobj(covmean): + covmean = covmean.real + # calculate score + fid = ssdiff + np.trace(sigma1 + sigma2 - 2.0 * covmean) + return fid + + +def calculate_images_metrics(dataset, model, total_length): + fake_embeddings = [] + real_embeddings = [] + + psnrs = [] + ssims = [] + + for sample in tqdm(dataset, total=total_length): + generated = model(sample[0], training=False)[0] + generated, real = generated, sample[0] + + real_resized = resize_images(real, (299, 299)) + generated_resized = resize_images(generated, (299, 299)) + + real_activations = inceptionv3(real_resized, training=False) + generated_activations = inceptionv3(generated_resized, training=False) + fake_embeddings.append(generated_activations) + real_embeddings.append(real_activations) + + fake_scaled = tf.cast(((generated * 0.5) + 1) * 255, tf.uint8) + real_scaled = tf.cast(((real * 0.5) + 1) * 255, tf.uint8) + + psnrs.append(tf.image.psnr(fake_scaled, real_scaled, 255).numpy()) + ssims.append(tf.image.ssim(fake_scaled, real_scaled, 255).numpy()) + + fid = calculate_fid( + tf.concat(fake_embeddings, axis=0).numpy(), + tf.concat(real_embeddings, axis=0).numpy(), + ) + + # kid = calculate_kid( + # tf.concat(fake_embeddings, axis=0).numpy(), + # tf.concat(real_embeddings, axis=0).numpy(), + # ) + + psnr = np.array(psnrs).mean() + ssim = np.array(ssims).mean() + return {"fid": fid, "ssim": ssim, "psnr": psnr} diff --git a/ganime/metrics/video.py b/ganime/metrics/video.py new file mode 100644 index 0000000000000000000000000000000000000000..5a0b3fc101dcc9bd190baa705b62242ffaecd55b --- /dev/null +++ b/ganime/metrics/video.py @@ -0,0 +1,98 @@ +import numpy as np +import tensorflow as tf +import tensorflow_gan as tfgan +import tensorflow_hub as hub +from sklearn.metrics.pairwise import polynomial_kernel +from tqdm.auto import tqdm + +i3d = hub.KerasLayer("https://tfhub.dev/deepmind/i3d-kinetics-400/1") + + +def resize_videos(videos, target_resolution): + """Runs some preprocessing on the videos for I3D model. + Args: + videos: [batch_size, num_frames, height, width, depth] The videos to be + preprocessed. 
We don't care about the specific dtype of the videos, it can + be anything that tf.image.resize_bilinear accepts. Values are expected to + be in [-1, 1]. + target_resolution: (width, height): target video resolution + Returns: + videos: [batch_size, num_frames, height, width, depth] + """ + min_frames = 9 + B, T, H, W, C = videos.shape + videos = tf.transpose(videos, (1, 0, 2, 3, 4)) + if T < min_frames: + videos = tf.concat([tf.zeros((min_frames - T, B, H, W, C)), videos], axis=0) + scaled_videos = tf.map_fn(lambda x: tf.image.resize(x, target_resolution), videos) + scaled_videos = tf.transpose(scaled_videos, (1, 0, 2, 3, 4)) + return scaled_videos + + +def polynomial_mmd(X, Y): + m = X.shape[0] + n = Y.shape[0] + # compute kernels + K_XX = polynomial_kernel(X) + K_YY = polynomial_kernel(Y) + K_XY = polynomial_kernel(X, Y) + # compute mmd distance + K_XX_sum = (K_XX.sum() - np.diagonal(K_XX).sum()) / (m * (m - 1)) + K_YY_sum = (K_YY.sum() - np.diagonal(K_YY).sum()) / (n * (n - 1)) + K_XY_sum = K_XY.sum() / (m * n) + mmd = K_XX_sum + K_YY_sum - 2 * K_XY_sum + return mmd + + +def calculate_ssim_videos(fake, real): + fake = tf.cast(((fake * 0.5) + 1) * 255, tf.uint8) + real = tf.cast(((real * 0.5) + 1) * 255, tf.uint8) + ssims = [] + for i in range(fake.shape[0]): + ssims.append(tf.image.ssim(fake[i], real[i], 255).numpy().mean()) + + return np.array(ssims).mean() + + +def calculate_psnr_videos(fake, real): + fake = tf.cast(((fake * 0.5) + 1) * 255, tf.uint8) + real = tf.cast(((real * 0.5) + 1) * 255, tf.uint8) + psnrs = [] + for i in range(fake.shape[0]): + psnrs.append(tf.image.psnr(fake[i], real[i], 255).numpy().mean()) + + return np.array(psnrs).mean() + + +def calculate_videos_metrics(dataset, model, total_length): + fake_embeddings = [] + real_embeddings = [] + + psnrs = [] + ssims = [] + + for sample in tqdm(dataset, total=total_length): + generated = model(sample, training=False) + generated, real = generated[:, 1:], sample["y"][:, 1:] # ignore first frame + + real_resized = resize_videos(real, (224, 224)) + generated_resized = resize_videos(generated, (224, 224)) + + real_activations = i3d(real_resized) + generated_activations = i3d(generated_resized) + fake_embeddings.append(generated_activations) + real_embeddings.append(real_activations) + + psnrs.append(calculate_psnr_videos(generated, real)) + ssims.append(calculate_ssim_videos(generated, real)) + + # fake_concat, real_concat = tf.concat(fake_embeddings, axis=0), tf.concat(real_embeddings, axis=0) + fvd = tfgan.eval.frechet_classifier_distance_from_activations( + tf.concat(fake_embeddings, axis=0), tf.concat(real_embeddings, axis=0) + ) + kvd = polynomial_mmd( + tf.concat(fake_embeddings, axis=0), tf.concat(real_embeddings, axis=0) + ) + psnr = np.array(psnrs).mean() + ssim = np.array(ssims).mean() + return {"fvd": fvd, "kvd": kvd, "ssim": ssim, "psnr": psnr} diff --git a/ganime/model/__init__.py b/ganime/model/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/ganime/model/base.py b/ganime/model/base.py new file mode 100644 index 0000000000000000000000000000000000000000..ba3ba2b02fa078e67affade66ff67c3b639356a3 --- /dev/null +++ b/ganime/model/base.py @@ -0,0 +1,45 @@ +import tensorflow as tf +from ganime.model.vqgan_clean.vqgan import VQGAN + + +def load_model( + model: str, config: dict, strategy: tf.distribute.Strategy +) -> tf.keras.Model: + + if model == "vqgan": + with strategy.scope(): + print(config["model"]) + model = 
VQGAN(**config["model"]) + + gen_optimizer = tf.keras.optimizers.Adam( + learning_rate=config["trainer"]["gen_lr"], + beta_1=config["trainer"]["gen_beta_1"], + beta_2=config["trainer"]["gen_beta_2"], + clipnorm=config["trainer"]["gen_clip_norm"], + ) + disc_optimizer = tf.keras.optimizers.Adam( + learning_rate=config["trainer"]["disc_lr"], + beta_1=config["trainer"]["disc_beta_1"], + beta_2=config["trainer"]["disc_beta_2"], + clipnorm=config["trainer"]["disc_clip_norm"], + ) + model.compile(gen_optimizer=gen_optimizer, disc_optimizer=disc_optimizer) + return model + else: + raise ValueError(f"Unknown model: {model}") + + # if model == "moving_vae": + # from ganime.model.moving_vae import MovingVAE + + # with strategy.scope(): + # model = MovingVAE(input_shape=input_shape) + + # negloglik = lambda x, rv_x: -rv_x.log_prob(x) + # model.compile( + # optimizer=tf.optimizers.Adam(learning_rate=config["lr"]), + # loss=negloglik, + # ) + # # model.build(input_shape=(None, *input_shape)) + # # model.summary() + + # return model diff --git a/ganime/model/moving_vae.py b/ganime/model/moving_vae.py new file mode 100644 index 0000000000000000000000000000000000000000..95b41568f5fb8d8c935c73fa4016cdb39d9bfa9a --- /dev/null +++ b/ganime/model/moving_vae.py @@ -0,0 +1,126 @@ +from tensorflow.keras import Model + +import tensorflow as tf +import tensorflow_probability as tfp + + +class MovingVAE(Model): + def __init__(self, input_shape, encoded_size=64, base_depth=32): + super().__init__() + + self.encoded_size = encoded_size + self.base_depth = base_depth + + self.prior = tfp.distributions.Independent( + tfp.distributions.Normal(loc=tf.zeros(encoded_size), scale=1), + reinterpreted_batch_ndims=1, + ) + + self.encoder = tf.keras.Sequential( + [ + tf.keras.layers.InputLayer(input_shape=input_shape), + tf.keras.layers.Lambda(lambda x: tf.cast(x, tf.float32) - 0.5), + tf.keras.layers.Conv3D( + self.base_depth, + 5, + strides=1, + padding="same", + activation=tf.nn.leaky_relu, + ), + tf.keras.layers.Conv3D( + self.base_depth, + 5, + strides=2, + padding="same", + activation=tf.nn.leaky_relu, + ), + tf.keras.layers.Conv3D( + 2 * self.base_depth, + 5, + strides=1, + padding="same", + activation=tf.nn.leaky_relu, + ), + tf.keras.layers.Conv3D( + 2 * self.base_depth, + 5, + strides=2, + padding="same", + activation=tf.nn.leaky_relu, + ), + # tf.keras.layers.Conv3D(4 * encoded_size, 7, strides=1, + # padding='valid', activation=tf.nn.leaky_relu), + tf.keras.layers.Flatten(), + tf.keras.layers.Dense( + tfp.layers.MultivariateNormalTriL.params_size(self.encoded_size), + activation=None, + ), + tfp.layers.MultivariateNormalTriL( + self.encoded_size, + activity_regularizer=tfp.layers.KLDivergenceRegularizer(self.prior), + ), + ] + ) + + self.decoder = tf.keras.Sequential( + [ + tf.keras.layers.InputLayer(input_shape=[self.encoded_size]), + tf.keras.layers.Reshape([1, 1, 1, self.encoded_size]), + tf.keras.layers.Conv3DTranspose( + self.base_depth, + (5, 4, 4), + strides=1, + padding="valid", + activation=tf.nn.leaky_relu, + ), + tf.keras.layers.Conv3DTranspose( + 2 * self.base_depth, + (5, 4, 4), + strides=(1, 2, 2), + padding="same", + activation=tf.nn.leaky_relu, + ), + tf.keras.layers.Conv3DTranspose( + 2 * self.base_depth, + (5, 4, 4), + strides=2, + padding="same", + activation=tf.nn.leaky_relu, + ), + tf.keras.layers.Conv3DTranspose( + self.base_depth, + (5, 4, 4), + strides=(1, 2, 2), + padding="same", + activation=tf.nn.leaky_relu, + ), + tf.keras.layers.Conv3DTranspose( + self.base_depth, + (5, 4, 4), + 
strides=2, + padding="same", + activation=tf.nn.leaky_relu, + ), + tf.keras.layers.Conv3DTranspose( + self.base_depth, + (5, 4, 4), + strides=1, + padding="same", + activation=tf.nn.leaky_relu, + ), + tf.keras.layers.Conv2D( + filters=1, kernel_size=5, strides=1, padding="same", activation=None + ), + tf.keras.layers.Flatten(), + tfp.layers.IndependentBernoulli( + input_shape, tfp.distributions.Bernoulli.logits + ), + ] + ) + + self.model = tf.keras.Model( + inputs=self.encoder.inputs, outputs=self.decoder(self.encoder.outputs[0]) + ) + + def call(self, inputs): + return self.model(inputs) diff --git a/ganime/model/p2p/__init__.py b/ganime/model/p2p/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/ganime/model/p2p/p2p.py b/ganime/model/p2p/p2p.py new file mode 100644 index 0000000000000000000000000000000000000000..aa39b9a32b9010e5f52d8eab00c9769b63beb02c --- /dev/null +++ b/ganime/model/p2p/p2p.py @@ -0,0 +1,543 @@ +from statistics import mode +import numpy as np +import tensorflow as tf +from tensorflow.python.keras import Model, Sequential +from tensorflow.python.keras.layers import Dense, LSTMCell, RNN, Conv2D, Conv2DTranspose +from tensorflow.keras.layers import BatchNormalization, TimeDistributed +from tensorflow.python.keras.layers.advanced_activations import LeakyReLU +from tensorflow.keras.layers import Activation + +# from tensorflow_probability.python.layers.dense_variational import ( +# DenseReparameterization, +# ) +# import tensorflow_probability as tfp +from tensorflow.keras.losses import Loss + + +class KLCriterion(Loss): + def call(self, y_true, y_pred): + (mu1, logvar1), (mu2, logvar2) = y_true, y_pred + + """KL( N(mu_1, sigma2_1) || N(mu_2, sigma2_2))""" + sigma1 = tf.exp(tf.math.multiply(logvar1, 0.5)) + sigma2 = tf.exp(tf.math.multiply(logvar2, 0.5)) + + kld = ( + tf.math.log(sigma2 / sigma1) + + (tf.exp(logvar1) + tf.square(mu1 - mu2)) / (2 * tf.exp(logvar2)) + - 0.5 + ) + return tf.reduce_sum(kld) / 22 + + +class Encoder(Model): + def __init__(self, dim, nc=1): + super().__init__() + self.dim = dim + self.c1 = Sequential( + [ + Conv2D(64, kernel_size=4, strides=2, padding="same"), + BatchNormalization(), + LeakyReLU(alpha=0.2), + ] + ) + self.c2 = Sequential( + [ + Conv2D(128, kernel_size=4, strides=2, padding="same"), + BatchNormalization(), + LeakyReLU(alpha=0.2), + ] + ) + self.c3 = Sequential( + [ + Conv2D(256, kernel_size=4, strides=2, padding="same"), + BatchNormalization(), + LeakyReLU(alpha=0.2), + ] + ) + self.c4 = Sequential( + [ + Conv2D(512, kernel_size=4, strides=2, padding="same"), + BatchNormalization(), + LeakyReLU(alpha=0.2), + ] + ) + self.c5 = Sequential( + [ + Conv2D(self.dim, kernel_size=4, strides=1, padding="valid"), + BatchNormalization(), + Activation("tanh"), + ] + ) + + def call(self, input): + h1 = self.c1(input) + h2 = self.c2(h1) + h3 = self.c3(h2) + h4 = self.c4(h3) + h5 = self.c5(h4) + return tf.reshape(h5, (-1, self.dim)), [h1, h2, h3, h4, h5] + + +class Decoder(Model): + def __init__(self, dim, nc=1): + super().__init__() + self.dim = dim + self.upc1 = Sequential( + [ + Conv2DTranspose(512, kernel_size=4, strides=1, padding="valid"), + BatchNormalization(), + LeakyReLU(alpha=0.2), + ] + ) + self.upc2 = Sequential( + [ + Conv2DTranspose(256, kernel_size=4, strides=2, padding="same"), + BatchNormalization(), + LeakyReLU(alpha=0.2), + ] + ) + self.upc3 = Sequential( + [ + Conv2DTranspose(128, kernel_size=4, strides=2, padding="same"), + 
BatchNormalization(), + LeakyReLU(alpha=0.2), + ] + ) + self.upc4 = Sequential( + [ + Conv2DTranspose(64, kernel_size=4, strides=2, padding="same"), + BatchNormalization(), + LeakyReLU(alpha=0.2), + ] + ) + self.upc5 = Sequential( + [ + Conv2DTranspose(1, kernel_size=4, strides=2, padding="same"), + Activation("sigmoid"), + ] + ) + + def call(self, input): + vec, skip = input + d1 = self.upc1(tf.reshape(vec, (-1, 1, 1, self.dim))) + d2 = self.upc2(tf.concat([d1, skip[3]], axis=-1)) + d3 = self.upc3(tf.concat([d2, skip[2]], axis=-1)) + d4 = self.upc4(tf.concat([d3, skip[1]], axis=-1)) + output = self.upc5(tf.concat([d4, skip[0]], axis=-1)) + return output + + +class MyLSTM(Model): + def __init__(self, input_shape, hidden_size, output_size, n_layers): + super().__init__() + self.hidden_size = hidden_size + self.n_layers = n_layers + self.embed = Dense(hidden_size, input_dim=input_shape) + # self.lstm = Sequential( + # [LSTMCell(hidden_size) for _ in range(n_layers)], name="lstm" + # ) + # self.lstm = self.create_lstm(hidden_size, n_layers) + self.lstm = LSTMCell(hidden_size) + self.out = Dense(output_size) + + def init_hidden(self, batch_size): + hidden = [] + for i in range(self.n_layers): + hidden.append( + ( + tf.Variable(tf.zeros([batch_size, self.hidden_size])), + tf.Variable(tf.zeros([batch_size, self.hidden_size])), + ) + ) + self.__dict__["hidden"] = hidden + + def build(self, input_shape): + self.init_hidden(input_shape[0]) + + def call(self, inputs): + h_in = self.embed(inputs) + for i in range(self.n_layers): + _, self.hidden[i] = self.lstm(h_in, self.hidden[i]) + h_in = self.hidden[i][0] + + return self.out(h_in) + + +class MyGaussianLSTM(Model): + def __init__(self, input_shape, hidden_size, output_size, n_layers): + super().__init__() + self.hidden_size = hidden_size + self.n_layers = n_layers + self.embed = Dense(hidden_size, input_dim=input_shape) + # self.lstm = Sequential( + # [LSTMCell(hidden_size) for _ in range(n_layers)], name="lstm" + # ) + self.lstm = LSTMCell(hidden_size) + self.mu_net = Dense(output_size) + self.logvar_net = Dense(output_size) + # self.out = Sequential( + # [ + # tf.keras.layers.Dense( + # tfp.layers.MultivariateNormalTriL.params_size(output_size), + # activation=None, + # ), + # tfp.layers.MultivariateNormalTriL(output_size), + # ] + # ) + + def reparameterize(self, mu, logvar: tf.Tensor): + logvar = tf.math.exp(logvar * 0.5) + eps = tf.random.normal(logvar.shape) + return tf.add(tf.math.multiply(eps, logvar), mu) + + def init_hidden(self, batch_size): + hidden = [] + for i in range(self.n_layers): + hidden.append( + ( + tf.Variable(tf.zeros([batch_size, self.hidden_size])), + tf.Variable(tf.zeros([batch_size, self.hidden_size])), + ) + ) + self.__dict__["hidden"] = hidden + + def build(self, input_shape): + self.init_hidden(input_shape[0]) + + def call(self, inputs): + h_in = self.embed(inputs) + for i in range(self.n_layers): + # print(h_in.shape, self.hidden[i][0].shape, self.hidden[i][0].shape) + + _, self.hidden[i] = self.lstm(h_in, self.hidden[i]) + h_in = self.hidden[i][0] + mu = self.mu_net(h_in) + logvar = self.logvar_net(h_in) + z = self.reparameterize(mu, logvar) + return z, mu, logvar + + +class P2P(Model): + def __init__( + self, + channels: int = 1, + g_dim: int = 128, + z_dim: int = 10, + rnn_size: int = 256, + prior_rnn_layers: int = 1, + posterior_rnn_layers: int = 1, + predictor_rnn_layers: float = 1, + skip_prob: float = 0.5, + n_past: int = 1, + last_frame_skip: bool = False, + beta: float = 0.0001, + weight_align: float = 0.1, 
+ weight_cpc: float = 100, + ): + super().__init__() + self.channels = channels + self.g_dim = g_dim + self.z_dim = z_dim + self.rnn_size = rnn_size + self.prior_rnn_layers = prior_rnn_layers + self.posterior_rnn_layers = posterior_rnn_layers + self.predictor_rnn_layers = predictor_rnn_layers + + self.skip_prob = skip_prob + self.n_past = n_past + self.last_frame_skip = last_frame_skip + self.beta = beta + self.weight_align = weight_align + self.weight_cpc = weight_cpc + + self.frame_predictor = MyLSTM( + self.g_dim + self.z_dim + 1 + 1, + self.rnn_size, + self.g_dim, + self.predictor_rnn_layers, + ) + + self.prior = MyGaussianLSTM( + self.g_dim + self.g_dim + 1 + 1, + self.rnn_size, + self.z_dim, + self.prior_rnn_layers, + ) + + self.posterior = MyGaussianLSTM( + self.g_dim + self.g_dim + 1 + 1, + self.rnn_size, + self.z_dim, + self.posterior_rnn_layers, + ) + + self.encoder = Encoder(self.g_dim, self.channels) + self.decoder = Decoder(self.g_dim, self.channels) + + # criterions + self.mse_criterion = tf.keras.losses.MeanSquaredError() + self.kl_criterion = KLCriterion() + self.align_criterion = tf.keras.losses.MeanSquaredError() + + # optimizers + self.frame_predictor_optimizer = tf.keras.optimizers.Adam( + learning_rate=0.0001 # , beta_1=0.9, beta_2=0.999, epsilon=1e-8 + ) + self.posterior_optimizer = tf.keras.optimizers.Adam( + learning_rate=0.0001 # , beta_1=0.9, beta_2=0.999, epsilon=1e-8 + ) + self.prior_optimizer = tf.keras.optimizers.Adam( + learning_rate=0.0001 # , beta_1=0.9, beta_2=0.999, epsilon=1e-8 + ) + self.encoder_optimizer = tf.keras.optimizers.Adam( + learning_rate=0.0001 # , beta_1=0.9, beta_2=0.999, epsilon=1e-8 + ) + self.decoder_optimizer = tf.keras.optimizers.Adam( + learning_rate=0.0001 # , beta_1=0.9, beta_2=0.999, epsilon=1e-8 + ) + + def get_global_descriptor(self, x, start_ix=0, cp_ix=None): + """Get the global descriptor based on x, start_ix, cp_ix.""" + if cp_ix is None: + cp_ix = x.shape[1] - 1 + + x_cp = x[:, cp_ix, ...] 
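# NOTE: KLCriterion above is the closed-form KL between two diagonal Gaussians,
#   KL(N(mu1, sigma1^2) || N(mu2, sigma2^2))
#     = log(sigma2 / sigma1) + (sigma1^2 + (mu1 - mu2)^2) / (2 * sigma2^2) - 1/2,
# summed and divided by a hard-coded 22 (apparently a batch-size normalisation).
# A minimal sanity check against tensorflow_probability, assuming tfp is
# installed (it is already imported by moving_vae.py):
import tensorflow as tf
import tensorflow_probability as tfp

mu1, logvar1 = tf.constant([0.0, 1.0]), tf.constant([0.0, -1.0])
mu2, logvar2 = tf.constant([0.5, 0.0]), tf.constant([0.2, 0.0])
sigma1, sigma2 = tf.exp(0.5 * logvar1), tf.exp(0.5 * logvar2)

closed_form = (
    tf.math.log(sigma2 / sigma1)
    + (tf.exp(logvar1) + tf.square(mu1 - mu2)) / (2.0 * tf.exp(logvar2))
    - 0.5
)
reference = tfp.distributions.kl_divergence(
    tfp.distributions.Normal(mu1, sigma1), tfp.distributions.Normal(mu2, sigma2)
)
# closed_form and reference agree element-wise up to float error.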
+ h_cp = self.encoder(x_cp)[0] # 1 is input for skip-connection + + return x_cp, h_cp + + def call(self, x, start_ix=0, cp_ix=-1): + batch_size = x.shape[0] + + with tf.GradientTape(persistent=True) as tape: + mse_loss = 0 + kld_loss = 0 + cpc_loss = 0 + align_loss = 0 + + seq_len = x.shape[1] + start_ix = 0 + cp_ix = seq_len - 1 + x_cp, global_z = self.get_global_descriptor( + x, start_ix, cp_ix + ) # here global_z is h_cp + + skip_prob = self.skip_prob + + prev_i = 0 + max_skip_count = seq_len * skip_prob + skip_count = 0 + probs = np.random.uniform(low=0, high=1, size=seq_len - 1) + + for i in range(1, seq_len): + if ( + probs[i - 1] <= skip_prob + and i >= self.n_past + and skip_count < max_skip_count + and i != 1 + and i != cp_ix + ): + skip_count += 1 + continue + + time_until_cp = tf.fill([batch_size, 1], (cp_ix - i + 1) / cp_ix) + delta_time = tf.fill([batch_size, 1], ((i - prev_i) / cp_ix)) + prev_i = i + + h = self.encoder(x[:, i - 1, ...]) + h_target = self.encoder(x[:, i, ...])[0] + + if self.last_frame_skip or i <= self.n_past: + h, skip = h + else: + h = h[0] + + # Control Point Aware + h_cpaw = tf.concat([h, global_z, time_until_cp, delta_time], axis=1) + h_target_cpaw = tf.concat( + [h_target, global_z, time_until_cp, delta_time], axis=1 + ) + zt, mu, logvar = self.posterior(h_target_cpaw) + zt_p, mu_p, logvar_p = self.prior(h_cpaw) + + concat = tf.concat([h, zt, time_until_cp, delta_time], axis=1) + h_pred = self.frame_predictor(concat) + x_pred = self.decoder([h_pred, skip]) + + if i == cp_ix: # the gen-cp-frame should be exactly as x_cp + h_pred_p = self.frame_predictor( + tf.concat([h, zt_p, time_until_cp, delta_time], axis=1) + ) + x_pred_p = self.decoder([h_pred_p, skip]) + cpc_loss = self.mse_criterion(x_pred_p, x_cp) + + if i > 1: + align_loss += self.align_criterion(h[0], h_pred) + + mse_loss += self.mse_criterion(x_pred, x[:, i, ...]) + kld_loss += self.kl_criterion((mu, logvar), (mu_p, logvar_p)) + + # backward + loss = mse_loss + kld_loss * self.beta + align_loss * self.weight_align + + prior_loss = kld_loss + cpc_loss * self.weight_cpc + + var_list_frame_predictor = self.frame_predictor.trainable_variables + var_list_posterior = self.posterior.trainable_variables + var_list_prior = self.prior.trainable_variables + var_list_encoder = self.encoder.trainable_variables + var_list_decoder = self.decoder.trainable_variables + + # mse: frame_predictor + decoder + # align: frame_predictor + encoder + # kld: posterior + prior + encoder + + var_list_without_prior = ( + var_list_frame_predictor + + var_list_posterior + + var_list_encoder + + var_list_decoder + ) + + gradients_without_prior = tape.gradient( + loss, + var_list_without_prior, + ) + gradients_prior = tape.gradient( + prior_loss, + var_list_prior, + ) + + self.update_model_without_prior( + gradients_without_prior, + var_list_without_prior, + ) + self.update_prior(gradients_prior, var_list_prior) + del tape + + return ( + mse_loss / seq_len, + kld_loss / seq_len, + cpc_loss / seq_len, + align_loss / seq_len, + ) + + def p2p_generate( + self, + x, + len_output, + eval_cp_ix, + start_ix=0, + cp_ix=-1, + model_mode="full", + skip_frame=False, + init_hidden=True, + ): + batch_size, num_frames, h, w, channels = x.shape + dim_shape = (h, w, channels) + + gen_seq = [x[:, 0, ...]] + x_in = x[:, 0, ...] 
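# NOTE: the training pass above opens tf.GradientTape(persistent=True) so that
# two losses (`loss` and `prior_loss`) can each be differentiated against their
# own variable lists from a single forward pass; also note that
# update_model_without_prior applies the same gradient list through four Adam
# optimizers over the combined variable list, so those variables appear to be
# stepped four times per batch. A minimal sketch of the persistent-tape
# pattern, with toy variables standing in for the sub-models:
import tensorflow as tf

w_main = tf.Variable(1.0)   # stands in for encoder/decoder/predictor/posterior
w_prior = tf.Variable(2.0)  # stands in for the prior network
opt_main = tf.keras.optimizers.Adam(1e-4)
opt_prior = tf.keras.optimizers.Adam(1e-4)

with tf.GradientTape(persistent=True) as tape:
    loss = tf.square(w_main) + w_prior        # analogue of the full loss
    prior_loss = tf.square(w_prior)           # analogue of the kld + cpc terms

grads_main = tape.gradient(loss, [w_main])    # the tape can be queried twice
grads_prior = tape.gradient(prior_loss, [w_prior])
opt_main.apply_gradients(zip(grads_main, [w_main]))
opt_prior.apply_gradients(zip(grads_prior, [w_prior]))
del tape                                      # release the persistent tape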
+ + seq_len = x.shape[1] + cp_ix = seq_len - 1 + + x_cp, global_z = self.get_global_descriptor( + x, cp_ix=cp_ix + ) # here global_z is h_cp + + skip_prob = self.skip_prob + + prev_i = 0 + max_skip_count = seq_len * skip_prob + skip_count = 0 + probs = np.random.uniform(0, 1, len_output - 1) + + for i in range(1, len_output): + if ( + probs[i - 1] <= skip_prob + and i >= self.n_past + and skip_count < max_skip_count + and i != 1 + and i != (len_output - 1) + and skip_frame + ): + skip_count += 1 + gen_seq.append(tf.zeros_like(x_in)) + continue + + time_until_cp = tf.fill([batch_size, 1], (eval_cp_ix - i + 1) / eval_cp_ix) + + delta_time = tf.fill([batch_size, 1], ((i - prev_i) / eval_cp_ix)) + + prev_i = i + + h = self.encoder(x_in) + + if self.last_frame_skip or i == 1 or i < self.n_past: + h, skip = h + else: + h, _ = h + + h_cpaw = tf.concat([h, global_z, time_until_cp, delta_time], axis=1) + + if i < self.n_past: + h_target = self.encoder(x[:, i, ...])[0] + h_target_cpaw = tf.concat( + [h_target, global_z, time_until_cp, delta_time], axis=1 + ) + + zt, _, _ = self.posterior(h_target_cpaw) + zt_p, _, _ = self.prior(h_cpaw) + + if model_mode == "posterior" or model_mode == "full": + self.frame_predictor( + tf.concat([h, zt, time_until_cp, delta_time], axis=1) + ) + elif model_mode == "prior": + self.frame_predictor( + tf.concat([h, zt_p, time_until_cp, delta_time], axis=1) + ) + + x_in = x[:, i, ...] + gen_seq.append(x_in) + else: + if i < num_frames: + h_target = self.encoder(x[:, i, ...])[0] + h_target_cpaw = tf.concat( + [h_target, global_z, time_until_cp, delta_time], axis=1 + ) + else: + h_target_cpaw = h_cpaw + + zt, _, _ = self.posterior(h_target_cpaw) + zt_p, _, _ = self.prior(h_cpaw) + + if model_mode == "posterior": + h = self.frame_predictor( + tf.concat([h, zt, time_until_cp, delta_time], axis=1) + ) + elif model_mode == "prior" or model_mode == "full": + h = self.frame_predictor( + tf.concat([h, zt_p, time_until_cp, delta_time], axis=1) + ) + + x_in = self.decoder([h, skip]) + gen_seq.append(x_in) + return tf.stack(gen_seq, axis=1) + + def update_model_without_prior(self, gradients, var_list): + self.frame_predictor_optimizer.apply_gradients(zip(gradients, var_list)) + self.posterior_optimizer.apply_gradients(zip(gradients, var_list)) + self.encoder_optimizer.apply_gradients(zip(gradients, var_list)) + self.decoder_optimizer.apply_gradients(zip(gradients, var_list)) + + def update_prior(self, gradients, var_list): + self.prior_optimizer.apply_gradients(zip(gradients, var_list)) + + # def update_model_without_prior(self): + # self.frame_predictor_optimizer.step() + # self.posterior_optimizer.step() + # self.encoder_optimizer.step() + # self.decoder_optimizer.step() diff --git a/ganime/model/p2p/p2p_test.py b/ganime/model/p2p/p2p_test.py new file mode 100644 index 0000000000000000000000000000000000000000..b50af1a1010afe4d0bc0e4e3f2b467b57c80993c --- /dev/null +++ b/ganime/model/p2p/p2p_test.py @@ -0,0 +1,713 @@ +from tqdm.auto import tqdm +import numpy as np +import tensorflow as tf +from tensorflow.keras import Model, Sequential +from tensorflow.keras.layers import ( + LSTM, + LSTMCell, + Activation, + BatchNormalization, + Conv2D, + Conv2DTranspose, + Conv3D, + Conv3DTranspose, + Dense, + Flatten, + Input, + Layer, + LeakyReLU, + MaxPooling2D, + Reshape, + TimeDistributed, + UpSampling2D, +) +from tensorflow.keras.losses import Loss +from tensorflow.keras.losses import KLDivergence, MeanSquaredError + +# from tensorflow_probability.python.layers.dense_variational import ( +# 
DenseReparameterization, +# ) +# import tensorflow_probability as tfp +from tensorflow.keras.losses import Loss + +initializer_conv_dense = tf.keras.initializers.RandomNormal(mean=0.0, stddev=0.02) +initializer_batch_norm = tf.keras.initializers.RandomNormal(mean=1.0, stddev=0.02) + + +class KLCriterion(Loss): + def call(self, y_true, y_pred): + (mu1, logvar1), (mu2, logvar2) = y_true, y_pred + + """KL( N(mu_1, sigma2_1) || N(mu_2, sigma2_2))""" + sigma1 = tf.exp(tf.math.multiply(logvar1, 0.5)) + sigma2 = tf.exp(tf.math.multiply(logvar2, 0.5)) + + kld = ( + tf.math.log(sigma2 / sigma1) + + (tf.exp(logvar1) + tf.square(mu1 - mu2)) / (2 * tf.exp(logvar2)) + - 0.5 + ) + return tf.reduce_sum(kld) / 100 + + +class Encoder(Model): + def __init__(self, dim, nc=1): + super().__init__() + self.dim = dim + self.c1 = Sequential( + [ + Conv2D( + 64, + kernel_size=4, + strides=2, + padding="same", + kernel_initializer=initializer_conv_dense, + ), + # BatchNormalization(), + LeakyReLU(alpha=0.2), + ] + ) + self.c2 = Sequential( + [ + Conv2D( + 128, + kernel_size=4, + strides=2, + padding="same", + kernel_initializer=initializer_conv_dense, + ), + # BatchNormalization(), + LeakyReLU(alpha=0.2), + ] + ) + self.c3 = Sequential( + [ + Conv2D( + 256, + kernel_size=4, + strides=2, + padding="same", + kernel_initializer=initializer_conv_dense, + ), + # BatchNormalization(), + LeakyReLU(alpha=0.2), + ] + ) + self.c4 = Sequential( + [ + Conv2D( + 512, + kernel_size=4, + strides=2, + padding="same", + kernel_initializer=initializer_conv_dense, + ), + # BatchNormalization(), + LeakyReLU(alpha=0.2), + ] + ) + self.c5 = Sequential( + [ + Conv2D( + self.dim, + kernel_size=4, + strides=1, + padding="valid", + kernel_initializer=initializer_conv_dense, + ), + # BatchNormalization(), + Activation("tanh"), + ] + ) + + def call(self, input): + h1 = self.c1(input) + h2 = self.c2(h1) + h3 = self.c3(h2) + h4 = self.c4(h3) + h5 = self.c5(h4) + return tf.reshape(h5, (-1, self.dim)), [h1, h2, h3, h4, h5] + + +class Decoder(Model): + def __init__(self, dim, nc=1): + super().__init__() + self.dim = dim + self.upc1 = Sequential( + [ + Conv2DTranspose( + 512, + kernel_size=4, + strides=1, + padding="valid", + kernel_initializer=initializer_conv_dense, + ), + # BatchNormalization(), + LeakyReLU(alpha=0.2), + ] + ) + self.upc2 = Sequential( + [ + Conv2DTranspose( + 256, + kernel_size=4, + strides=2, + padding="same", + kernel_initializer=initializer_conv_dense, + ), + # BatchNormalization(), + LeakyReLU(alpha=0.2), + ] + ) + self.upc3 = Sequential( + [ + Conv2DTranspose( + 128, + kernel_size=4, + strides=2, + padding="same", + kernel_initializer=initializer_conv_dense, + ), + # BatchNormalization(), + LeakyReLU(alpha=0.2), + ] + ) + self.upc4 = Sequential( + [ + Conv2DTranspose( + 64, + kernel_size=4, + strides=2, + padding="same", + kernel_initializer=initializer_conv_dense, + ), + # BatchNormalization(), + LeakyReLU(alpha=0.2), + ] + ) + self.upc5 = Sequential( + [ + Conv2DTranspose( + 1, + kernel_size=4, + strides=2, + padding="same", + kernel_initializer=initializer_conv_dense, + ), + Activation("sigmoid"), + ] + ) + + def call(self, input): + vec, skip = input + d1 = self.upc1(tf.reshape(vec, (-1, 1, 1, self.dim))) + d2 = self.upc2(tf.concat([d1, skip[3]], axis=-1)) + d3 = self.upc3(tf.concat([d2, skip[2]], axis=-1)) + d4 = self.upc4(tf.concat([d3, skip[1]], axis=-1)) + output = self.upc5(tf.concat([d4, skip[0]], axis=-1)) + return output + + +class MyLSTM(Model): + def __init__(self, input_shape, hidden_size, output_size, 
n_layers): + super().__init__() + self.hidden_size = hidden_size + self.n_layers = n_layers + self.embed = Dense( + hidden_size, + input_dim=input_shape, + kernel_initializer=initializer_conv_dense, + ) + # self.lstm = Sequential( + # [LSTMCell(hidden_size) for _ in range(n_layers)], name="lstm" + # ) + # self.lstm = self.create_lstm(hidden_size, n_layers) + self.lstm = [ + LSTMCell( + hidden_size # , return_sequences=False if i == self.n_layers - 1 else True + ) + for i in range(self.n_layers) + ] # LSTMCell(hidden_size) + self.lstm_rnn = tf.keras.layers.RNN(self.lstm[0], return_state=True) + self.out = Dense(output_size, kernel_initializer=initializer_conv_dense) + + def init_hidden(self, batch_size): + hidden = [] + for i in range(self.n_layers): + hidden.append( + ( + tf.Variable(tf.zeros([batch_size, self.hidden_size])), + tf.Variable(tf.zeros([batch_size, self.hidden_size])), + ) + ) + self.__dict__["hidden"] = hidden + + def build(self, input_shape): + self.init_hidden(input_shape[0]) + + def call(self, inputs): + h_in = self.embed(inputs) + h_in = tf.reshape(h_in, (-1, 1, self.hidden_size)) + h_in, *state = self.lstm_rnn(h_in) + for i in range(self.n_layers): + h_in, state = self.lstm[i](h_in, state) + return self.out(h_in) + + +class MyGaussianLSTM(Model): + def __init__(self, input_shape, hidden_size, output_size, n_layers): + super().__init__() + self.hidden_size = hidden_size + self.n_layers = n_layers + self.embed = Dense( + hidden_size, + input_dim=input_shape, + kernel_initializer=initializer_conv_dense, + ) + # self.lstm = Sequential( + # [LSTMCell(hidden_size) for _ in range(n_layers)], name="lstm" + # ) + self.lstm = [ + LSTMCell( + hidden_size # , return_sequences=False if i == self.n_layers - 1 else True + ) + for i in range(self.n_layers) + ] # LSTMCell(hidden_size) + self.lstm_rnn = tf.keras.layers.RNN(self.lstm[0], return_state=True) + self.mu_net = Dense(output_size, kernel_initializer=initializer_conv_dense) + self.logvar_net = Dense(output_size, kernel_initializer=initializer_conv_dense) + # self.out = Sequential( + # [ + # tf.keras.layers.Dense( + # tfp.layers.MultivariateNormalTriL.params_size(output_size), + # activation=None, + # ), + # tfp.layers.MultivariateNormalTriL(output_size), + # ] + # ) + + def reparameterize(self, mu, logvar: tf.Tensor): + logvar = tf.math.exp(logvar * 0.5) + eps = tf.random.normal(logvar.shape) + return tf.add(tf.math.multiply(eps, logvar), mu) + + def init_hidden(self, batch_size): + hidden = [] + for i in range(self.n_layers): + hidden.append( + ( + tf.Variable(tf.zeros([batch_size, self.hidden_size])), + tf.Variable(tf.zeros([batch_size, self.hidden_size])), + ) + ) + self.__dict__["hidden"] = hidden + + def build(self, input_shape): + self.init_hidden(input_shape[0]) + + def call(self, inputs): + h_in = self.embed(inputs) + # for i in range(self.n_layers): + # # print(h_in.shape, self.hidden[i][0].shape, self.hidden[i][0].shape) + + # _, self.hidden[i] = self.lstm(h_in, self.hidden[i]) + # h_in = self.hidden[i][0] + h_in = tf.reshape(h_in, (-1, 1, self.hidden_size)) + h_in, *state = self.lstm_rnn(h_in) + for i in range(self.n_layers): + h_in, state = self.lstm[i](h_in, state) + + mu = self.mu_net(h_in) + logvar = self.logvar_net(h_in) + z = self.reparameterize(mu, logvar) + return z, mu, logvar + + +class P2P(Model): + def __init__( + self, + channels: int = 1, + g_dim: int = 128, + z_dim: int = 10, + rnn_size: int = 256, + prior_rnn_layers: int = 1, + posterior_rnn_layers: int = 1, + predictor_rnn_layers: float = 2, + 
skip_prob: float = 0.5, + n_past: int = 1, + last_frame_skip: bool = False, + beta: float = 0.0001, + weight_align: float = 0.1, + weight_cpc: float = 100, + ): + super().__init__() + self.channels = channels + self.g_dim = g_dim + self.z_dim = z_dim + self.rnn_size = rnn_size + self.prior_rnn_layers = prior_rnn_layers + self.posterior_rnn_layers = posterior_rnn_layers + self.predictor_rnn_layers = predictor_rnn_layers + + self.skip_prob = skip_prob + self.n_past = n_past + self.last_frame_skip = last_frame_skip + self.beta = beta + self.weight_align = weight_align + self.weight_cpc = weight_cpc + + self.frame_predictor = MyLSTM( + self.g_dim + self.z_dim + 1 + 1, + self.rnn_size, + self.g_dim, + self.predictor_rnn_layers, + ) + + self.prior = MyGaussianLSTM( + self.g_dim + self.g_dim + 1 + 1, + self.rnn_size, + self.z_dim, + self.prior_rnn_layers, + ) + + self.posterior = MyGaussianLSTM( + self.g_dim + self.g_dim + 1 + 1, + self.rnn_size, + self.z_dim, + self.posterior_rnn_layers, + ) + + self.encoder = Encoder(self.g_dim, self.channels) + self.decoder = Decoder(self.g_dim, self.channels) + + # criterions + self.mse_criterion = tf.keras.losses.MeanSquaredError() + self.kl_criterion = KLCriterion() + self.align_criterion = tf.keras.losses.MeanSquaredError() + + self.total_loss_tracker = tf.keras.metrics.Mean(name="total_loss") + self.reconstruction_loss_tracker = tf.keras.metrics.Mean( + name="reconstruction_loss" + ) + self.kl_loss_tracker = tf.keras.metrics.Mean(name="kl_loss") + self.align_loss_tracker = tf.keras.metrics.Mean(name="align_loss") + self.cpc_loss_tracker = tf.keras.metrics.Mean(name="align_loss") + + # optimizers + # self.frame_predictor_optimizer = tf.keras.optimizers.Adam( + # learning_rate=0.001, beta_1=0.9, beta_2=0.999, epsilon=1e-8 + # ) + # self.posterior_optimizer = tf.keras.optimizers.Adam( + # learning_rate=0.001, beta_1=0.9, beta_2=0.999, epsilon=1e-8 + # ) + # self.prior_optimizer = tf.keras.optimizers.Adam( + # learning_rate=0.001, beta_1=0.9, beta_2=0.999, epsilon=1e-8 + # ) + # self.encoder_optimizer = tf.keras.optimizers.Adam( + # learning_rate=0.001, beta_1=0.9, beta_2=0.999, epsilon=1e-8 + # ) + # self.decoder_optimizer = tf.keras.optimizers.Adam( + # learning_rate=0.001, beta_1=0.9, beta_2=0.999, epsilon=1e-8 + # ) + + @property + def metrics(self): + return [ + self.total_loss_tracker, + self.reconstruction_loss_tracker, + self.kl_loss_tracker, + self.align_loss_tracker, + self.cpc_loss_tracker, + ] + + def get_global_descriptor(self, x, start_ix=0, cp_ix=None): + """Get the global descriptor based on x, start_ix, cp_ix.""" + if cp_ix is None: + cp_ix = x.shape[1] - 1 + + x_cp = x[:, cp_ix, ...] 
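# NOTE: the tf.keras.metrics.Mean trackers created in __init__ above and listed
# in the `metrics` property are reset by Keras automatically between epochs and
# reported by fit(); cpc_loss_tracker is constructed with name="align_loss",
# which looks like a copy-paste slip (name="cpc_loss" would match the key that
# train_step returns). Minimal illustration of the tracker pattern:
import tensorflow as tf

tracker = tf.keras.metrics.Mean(name="cpc_loss")
tracker.update_state(0.5)
tracker.update_state(1.5)
assert float(tracker.result()) == 1.0  # running mean of the recorded values
tracker.reset_state()                  # what Keras does at the start of an epoch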
+ h_cp = self.encoder(x_cp)[0] # 1 is input for skip-connection + + return x_cp, h_cp + + def compile( + self, + frame_predictor_optimizer, + prior_optimizer, + posterior_optimizer, + encoder_optimizer, + decoder_optimizer, + ): + super().compile() + self.frame_predictor_optimizer = frame_predictor_optimizer + self.prior_optimizer = prior_optimizer + self.posterior_optimizer = posterior_optimizer + self.encoder_optimizer = encoder_optimizer + self.decoder_optimizer = decoder_optimizer + + def train_step(self, data): + y, x = data + batch_size = 100 + + mse_loss = 0 + kld_loss = 0 + cpc_loss = 0 + align_loss = 0 + + seq_len = x.shape[1] + start_ix = 0 + cp_ix = seq_len - 1 + x_cp, global_z = self.get_global_descriptor( + x, start_ix, cp_ix + ) # here global_z is h_cp + + skip_prob = self.skip_prob + + prev_i = 0 + max_skip_count = seq_len * skip_prob + skip_count = 0 + probs = np.random.uniform(low=0, high=1, size=seq_len - 1) + + with tf.GradientTape(persistent=True) as tape: + for i in tqdm(range(1, seq_len)): + if ( + probs[i - 1] <= skip_prob + and i >= self.n_past + and skip_count < max_skip_count + and i != 1 + and i != cp_ix + ): + skip_count += 1 + continue + + if i > 1: + align_loss += self.align_criterion(h, h_pred) + + time_until_cp = tf.fill( + [batch_size, 1], + (cp_ix - i + 1) / cp_ix, + ) + delta_time = tf.fill([batch_size, 1], ((i - prev_i) / cp_ix)) + prev_i = i + + h = self.encoder(x[:, i - 1, ...]) + h_target = self.encoder(x[:, i, ...])[0] + + if self.last_frame_skip or i <= self.n_past: + h, skip = h + else: + h = h[0] + + # Control Point Aware + h_cpaw = tf.concat([h, global_z, time_until_cp, delta_time], axis=-1) + h_target_cpaw = tf.concat( + [h_target, global_z, time_until_cp, delta_time], axis=-1 + ) + + zt, mu, logvar = self.posterior(h_target_cpaw) + zt_p, mu_p, logvar_p = self.prior(h_cpaw) + + frame_predictor_input = tf.concat( + [h, zt, time_until_cp, delta_time], axis=-1 + ) + h_pred = self.frame_predictor(frame_predictor_input) + x_pred = self.decoder([h_pred, skip]) + + if i == cp_ix: # the gen-cp-frame should be exactly as x_cp + h_pred_p = self.frame_predictor( + tf.concat([h, zt_p, time_until_cp, delta_time], axis=-1) + ) + x_pred_p = self.decoder([h_pred_p, skip]) + cpc_loss = self.mse_criterion(x_pred_p, x_cp) + + mse_loss += self.mse_criterion(x_pred, x[:, i, ...]) + kld_loss += self.kl_criterion((mu, logvar), (mu_p, logvar_p)) + + # backward + loss = ( + mse_loss + + kld_loss * self.beta + + align_loss * self.weight_align + # + cpc_loss * self.weight_cpc + ) + + prior_loss = kld_loss + cpc_loss * self.weight_cpc + + var_list_frame_predictor = self.frame_predictor.trainable_variables + var_list_posterior = self.posterior.trainable_variables + var_list_prior = self.prior.trainable_variables + var_list_encoder = self.encoder.trainable_variables + var_list_decoder = self.decoder.trainable_variables + + # mse: frame_predictor + decoder + # align: frame_predictor + encoder + # kld: posterior + prior + encoder + + var_list = ( + var_list_frame_predictor + + var_list_posterior + + var_list_encoder + + var_list_decoder + + var_list_prior + ) + + gradients = tape.gradient( + loss, + var_list, + ) + gradients_prior = tape.gradient( + prior_loss, + var_list_prior, + ) + + self.update_model( + gradients, + var_list, + ) + self.update_prior(gradients_prior, var_list_prior) + del tape + + self.total_loss_tracker.update_state(loss) + self.kl_loss_tracker.update_state(kld_loss) + self.align_loss_tracker.update_state(align_loss) + 
self.reconstruction_loss_tracker.update_state(mse_loss) + self.cpc_loss_tracker.update_state(cpc_loss) + + return { + "loss": self.total_loss_tracker.result(), + "reconstruction_loss": self.reconstruction_loss_tracker.result(), + "kl_loss": self.kl_loss_tracker.result(), + "align_loss": self.align_loss_tracker.result(), + "cpc_loss": self.cpc_loss_tracker.result(), + } + + def call( + self, + inputs, + training=None, + mask=None + # len_output, + # eval_cp_ix, + # start_ix=0, + # cp_ix=-1, + # model_mode="full", + # skip_frame=False, + # init_hidden=True, + ): + len_output = 20 + eval_cp_ix = len_output - 1 + start_ix = 0 + cp_ix = -1 + model_mode = "full" + skip_frame = False + init_hidden = True + + batch_size, num_frames, h, w, channels = inputs.shape + dim_shape = (h, w, channels) + + gen_seq = [inputs[:, 0, ...]] + x_in = inputs[:, 0, ...] + + seq_len = inputs.shape[1] + cp_ix = seq_len - 1 + + x_cp, global_z = self.get_global_descriptor( + inputs, cp_ix=cp_ix + ) # here global_z is h_cp + + skip_prob = self.skip_prob + + prev_i = 0 + max_skip_count = seq_len * skip_prob + skip_count = 0 + probs = np.random.uniform(0, 1, len_output - 1) + + for i in range(1, len_output): + if ( + probs[i - 1] <= skip_prob + and i >= self.n_past + and skip_count < max_skip_count + and i != 1 + and i != (len_output - 1) + and skip_frame + ): + skip_count += 1 + gen_seq.append(tf.zeros_like(x_in)) + continue + + time_until_cp = tf.fill([100, 1], (eval_cp_ix - i + 1) / eval_cp_ix) + + delta_time = tf.fill([100, 1], ((i - prev_i) / eval_cp_ix)) + + prev_i = i + + h = self.encoder(x_in) + + if self.last_frame_skip or i == 1 or i < self.n_past: + h, skip = h + else: + h, _ = h + + h_cpaw = tf.stop_gradient(tf.concat([h, global_z, time_until_cp, delta_time], axis=-1)) + + if i < self.n_past: + h_target = self.encoder(inputs[:, i, ...])[0] + h_target_cpaw = tf.stop_gradient(tf.concat( + [h_target, global_z, time_until_cp, delta_time], axis=1 + )) + + zt, _, _ = self.posterior(h_target_cpaw) + zt_p, _, _ = self.prior(h_cpaw) + + if model_mode == "posterior" or model_mode == "full": + self.frame_predictor( + tf.concat([h, zt, time_until_cp, delta_time], axis=-1) + ) + elif model_mode == "prior": + self.frame_predictor( + tf.concat([h, zt_p, time_until_cp, delta_time], axis=-1) + ) + + x_in = inputs[:, i, ...] 
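# NOTE: the latents zt / zt_p used above come from MyGaussianLSTM.reparameterize,
# i.e. the reparameterization trick z = mu + exp(0.5 * logvar) * eps with
# eps ~ N(0, I), which keeps sampling differentiable with respect to mu and
# logvar. A stand-alone sketch of that sampling step (sample_z is a
# hypothetical helper, not part of this module):
import tensorflow as tf

def sample_z(mu: tf.Tensor, logvar: tf.Tensor) -> tf.Tensor:
    eps = tf.random.normal(tf.shape(mu))
    return mu + tf.exp(0.5 * logvar) * eps

z = sample_z(tf.zeros([4, 10]), tf.zeros([4, 10]))  # batch of 4, z_dim of 10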
+ gen_seq.append(x_in) + else: + if i < num_frames: + h_target = self.encoder(inputs[:, i, ...])[0] + h_target_cpaw = tf.stop_gradient(tf.concat( + [h_target, global_z, time_until_cp, delta_time], axis=-1 + )) + else: + h_target_cpaw = h_cpaw + + zt, _, _ = self.posterior(h_target_cpaw) + zt_p, _, _ = self.prior(h_cpaw) + + if model_mode == "posterior": + h = self.frame_predictor( + tf.concat([h, zt, time_until_cp, delta_time], axis=-1) + ) + elif model_mode == "prior" or model_mode == "full": + h = self.frame_predictor( + tf.concat([h, zt_p, time_until_cp, delta_time], axis=-1) + ) + + x_in = tf.stop_gradient(self.decoder([h, skip])) + gen_seq.append(x_in) + + return tf.stack(gen_seq, axis=1) + + def update_model(self, gradients, var_list): + self.frame_predictor_optimizer.apply_gradients(zip(gradients, var_list)) + self.posterior_optimizer.apply_gradients(zip(gradients, var_list)) + self.encoder_optimizer.apply_gradients(zip(gradients, var_list)) + self.decoder_optimizer.apply_gradients(zip(gradients, var_list)) + #self.prior_optimizer.apply_gradients(zip(gradients, var_list)) + + def update_prior(self, gradients, var_list): + self.prior_optimizer.apply_gradients(zip(gradients, var_list)) + + # def update_model_without_prior(self): + # self.frame_predictor_optimizer.step() + # self.posterior_optimizer.step() + # self.encoder_optimizer.step() + # self.decoder_optimizer.step() diff --git a/ganime/model/p2p/p2p_v2.py b/ganime/model/p2p/p2p_v2.py new file mode 100644 index 0000000000000000000000000000000000000000..765bca7c22edee525171a6910275626e920de9e8 --- /dev/null +++ b/ganime/model/p2p/p2p_v2.py @@ -0,0 +1,498 @@ +import numpy as np +import tensorflow as tf +from tensorflow.keras import Model, Sequential +from tensorflow.keras.layers import ( + LSTM, + Activation, + BatchNormalization, + Conv2D, + Conv2DTranspose, + Conv3D, + Conv3DTranspose, + Dense, + Flatten, + Input, + Layer, + LeakyReLU, + MaxPooling2D, + Reshape, + TimeDistributed, + UpSampling2D, +) +from tensorflow.keras.losses import Loss +from tensorflow.keras.losses import KLDivergence, MeanSquaredError +from tqdm.auto import tqdm + + +class KLCriterion(Loss): + def call(self, y_true, y_pred): + (mu1, logvar1), (mu2, logvar2) = y_true, y_pred + + """KL( N(mu_1, sigma2_1) || N(mu_2, sigma2_2))""" + sigma1 = tf.exp(tf.math.multiply(logvar1, 0.5)) + sigma2 = tf.exp(tf.math.multiply(logvar2, 0.5)) + + kld = ( + tf.math.log(sigma2 / sigma1) + + (tf.exp(logvar1) + tf.square(mu1 - mu2)) / (2 * tf.exp(logvar2)) + - 0.5 + ) + return kld + + +class Decoder(Model): + def __init__(self, dim, nc=1): + super().__init__() + self.dim = dim + self.upc1 = Sequential( + [ + TimeDistributed( + Conv2DTranspose(512, kernel_size=4, strides=1, padding="valid") + ), + BatchNormalization(), + LeakyReLU(alpha=0.2), + ] + ) + self.upc2 = Sequential( + [ + TimeDistributed( + Conv2DTranspose(256, kernel_size=4, strides=2, padding="same") + ), + BatchNormalization(), + LeakyReLU(alpha=0.2), + ] + ) + self.upc3 = Sequential( + [ + TimeDistributed( + Conv2DTranspose(128, kernel_size=4, strides=2, padding="same") + ), + BatchNormalization(), + LeakyReLU(alpha=0.2), + ] + ) + self.upc4 = Sequential( + [ + TimeDistributed( + Conv2DTranspose(64, kernel_size=4, strides=2, padding="same") + ), + BatchNormalization(), + LeakyReLU(alpha=0.2), + ] + ) + self.upc5 = Sequential( + [ + TimeDistributed( + Conv2DTranspose(1, kernel_size=4, strides=2, padding="same") + ), + Activation("sigmoid"), + ] + ) + + def call(self, input): + vec, skip = input + d1 = 
self.upc1(tf.reshape(vec, (-1, 1, 1, 1, self.dim))) + d2 = self.upc2(tf.concat([d1, skip[3]], axis=-1)) + d3 = self.upc3(tf.concat([d2, skip[2]], axis=-1)) + d4 = self.upc4(tf.concat([d3, skip[1]], axis=-1)) + output = self.upc5(tf.concat([d4, skip[0]], axis=-1)) + return output + + +class Sampling(Layer): + """Uses (z_mean, z_log_var) to sample z, the vector encoding a digit.""" + + def call(self, inputs): + z_mean, z_log_var = inputs + batch = tf.shape(z_mean)[0] + dim = tf.shape(z_mean)[1] + epsilon = tf.keras.backend.random_normal(shape=(batch, dim)) + return z_mean + tf.exp(0.5 * z_log_var) * epsilon + + def compute_output_shape(self, input_shape): + return input_shape[0] + + +class P2P(Model): + def __init__( + self, + channels: int = 1, + g_dim: int = 128, + z_dim: int = 10, + rnn_size: int = 256, + prior_rnn_layers: int = 1, + posterior_rnn_layers: int = 1, + predictor_rnn_layers: float = 1, + skip_prob: float = 0.1, + n_past: int = 1, + last_frame_skip: bool = False, + beta: float = 0.0001, + weight_align: float = 0.1, + weight_cpc: float = 100, + ): + super().__init__() + # Models parameters + self.channels = channels + self.g_dim = g_dim + self.z_dim = z_dim + self.rnn_size = rnn_size + self.prior_rnn_layers = prior_rnn_layers + self.posterior_rnn_layers = posterior_rnn_layers + self.predictor_rnn_layers = predictor_rnn_layers + + # Training parameters + self.skip_prob = skip_prob + self.n_past = n_past + self.last_frame_skip = last_frame_skip + self.beta = beta + self.weight_align = weight_align + self.weight_cpc = weight_cpc + + self.frame_predictor = self.build_lstm() + self.prior = self.build_gaussian_lstm() + self.posterior = self.build_gaussian_lstm() + self.encoder = self.build_encoder() + self.decoder = self.build_decoder() + + self.total_loss_tracker = tf.keras.metrics.Mean(name="total_loss") + self.reconstruction_loss_tracker = tf.keras.metrics.Mean( + name="reconstruction_loss" + ) + self.kl_loss_tracker = tf.keras.metrics.Mean(name="kl_loss") + self.align_loss_tracker = tf.keras.metrics.Mean(name="align_loss") + self.cpc_loss_tracker = tf.keras.metrics.Mean(name="align_loss") + + self.kl_loss = KLCriterion( + reduction=tf.keras.losses.Reduction.NONE + ) # KLDivergence(reduction=tf.keras.losses.Reduction.NONE) + self.mse = MeanSquaredError(reduction=tf.keras.losses.Reduction.NONE) + self.align_loss = MeanSquaredError(reduction=tf.keras.losses.Reduction.NONE) + + # self.optimizer = tf.keras.optimizers.Adam( + # learning_rate=0.001, beta_1=0.9, beta_2=0.999, epsilon=1e-8 + # ) + # self.prior_optimizer = tf.keras.optimizers.Adam( + # learning_rate=0.001, beta_1=0.9, beta_2=0.999, epsilon=1e-8 + # ) + + # region Model building + def build_lstm(self): + input = Input(shape=(None, self.g_dim + self.z_dim)) + embed = TimeDistributed(Dense(self.rnn_size))(input) + lstm = LSTM(self.rnn_size)(embed) + output = Dense(self.g_dim)(lstm) + output = (tf.expand_dims(output, axis=1),) + + return Model(inputs=input, outputs=output, name="frame_predictor") + + def build_gaussian_lstm(self): + + input = Input(shape=(None, self.g_dim)) + embed = TimeDistributed(Dense(self.rnn_size))(input) + lstm = LSTM(self.rnn_size)(embed) + mu = Dense(self.z_dim)(lstm) + logvar = Dense(self.z_dim)(lstm) + z = Sampling()([mu, logvar]) + + return Model(inputs=input, outputs=[mu, logvar, z]) + + def build_encoder(self): + + input = Input(shape=(1, 64, 64, 1)) + + h = TimeDistributed(Conv2D(64, kernel_size=4, strides=2, padding="same"))(input) + h = BatchNormalization()(h) + h1 = LeakyReLU(alpha=0.2)(h) + 
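# NOTE: the intermediate activations h1..h4 collected in this encoder are
# returned as skip connections; Decoder.call above concatenates each of them
# channel-wise with the upsampled features at the matching resolution, in the
# style of a U-Net decoder. A minimal sketch of one such skip concatenation
# (shapes correspond to the 8x8 stage for 64x64 inputs):
import tensorflow as tf

upsampled = tf.zeros([2, 1, 8, 8, 256])  # (batch, time, H, W, C) after a Conv2DTranspose
skip_feat = tf.zeros([2, 1, 8, 8, 256])  # encoder activation at the same H x W
merged = tf.concat([upsampled, skip_feat], axis=-1)  # channel count doubles to 512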
# h = TimeDistributed(MaxPooling2D(pool_size=2, strides=2, padding="same"))(h) + + h = TimeDistributed(Conv2D(128, kernel_size=4, strides=2, padding="same"))(h1) + h = BatchNormalization()(h) + h2 = LeakyReLU(alpha=0.2)(h) + # h = TimeDistributed(MaxPooling2D(pool_size=2, strides=2, padding="same"))(h) + + h = TimeDistributed(Conv2D(256, kernel_size=4, strides=2, padding="same"))(h2) + h = BatchNormalization()(h) + h3 = LeakyReLU(alpha=0.2)(h) + # h = TimeDistributed(MaxPooling2D(pool_size=2, strides=2, padding="same"))(h) + + h = TimeDistributed(Conv2D(512, kernel_size=4, strides=2, padding="same"))(h3) + h = BatchNormalization()(h) + h4 = LeakyReLU(alpha=0.2)(h) + # h = TimeDistributed(MaxPooling2D(pool_size=2, strides=2, padding="same"))(h) + + h = TimeDistributed( + Conv2D(self.g_dim, kernel_size=4, strides=1, padding="valid") + )(h4) + h = BatchNormalization()(h) + h5 = Activation("tanh")(h) + + output = tf.reshape(h5, (-1, 1, self.g_dim)) + # h = Flatten()(h) + # output = Dense(self.g_dim)(h) + # output = tf.expand_dims(output, axis=1) + return Model(inputs=input, outputs=[output, [h1, h2, h3, h4]], name="encoder") + + def build_decoder(self): + return Decoder(self.g_dim) + + # def build_decoder(self): + # latent_inputs = Input( + # shape=( + # 1, + # self.g_dim, + # ) + # ) + # x = Dense(1 * 1 * 1 * 128, activation="relu")(latent_inputs) + # x = Reshape((1, 1, 1, 128))(x) + # x = TimeDistributed( + # Conv2DTranspose(512, kernel_size=4, strides=1, padding="valid") + # )(x) + # x = BatchNormalization()(x) + # x1 = LeakyReLU(alpha=0.2)(x) + + # x = TimeDistributed( + # Conv2DTranspose(256, kernel_size=4, strides=2, padding="same") + # )(x1) + # x = BatchNormalization()(x) + # x2 = LeakyReLU(alpha=0.2)(x) + + # x = TimeDistributed( + # Conv2DTranspose(128, kernel_size=4, strides=2, padding="same") + # )(x2) + # x = BatchNormalization()(x) + # x3 = LeakyReLU(alpha=0.2)(x) + + # x = TimeDistributed( + # Conv2DTranspose(64, kernel_size=4, strides=2, padding="same") + # )(x3) + # x = BatchNormalization()(x) + # x4 = LeakyReLU(alpha=0.2)(x) + + # x = TimeDistributed( + # Conv2DTranspose(1, kernel_size=4, strides=2, padding="same") + # )(x4) + # x5 = Activation("sigmoid")(x) + + # return Model(inputs=latent_inputs, outputs=x5, name="decoder") + + # endregion + + @property + def metrics(self): + return [ + self.total_loss_tracker, + self.reconstruction_loss_tracker, + self.kl_loss_tracker, + self.align_loss_tracker, + self.cpc_loss_tracker, + ] + + def call(self, inputs, training=None, mask=None): + first_frame = inputs[:, 0:1, ...] + last_frame = inputs[:, -1:, ...] + + desired_length = 20 + previous_frame = first_frame + generated = [first_frame] + + z_last, _ = self.encoder(last_frame) + for i in range(1, desired_length): + + z_prev = self.encoder(previous_frame) + + if self.last_frame_skip or i == 1 or i < self.n_past: + z_prev, skip = z_prev + else: + z_prev = z_prev[0] + + prior_input = tf.concat([z_prev, z_last], axis=1) + + z_mean_prior, z_log_var_prior, z_prior = self.prior(prior_input) + + predictor_input = tf.concat( + (z_prev, tf.expand_dims(z_prior, axis=1)), axis=-1 + ) + z_pred = self.frame_predictor(predictor_input) + + current_frame = self.decoder([z_pred, skip]) + generated.append(current_frame) + previous_frame = current_frame + return tf.concat(generated, axis=1) + + def train_step(self, data): + global_batch_size = 100 # * 8 + x, y = data + + first_frame = x[:, 0:1, ...] + last_frame = x[:, -1:, ...] 
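# NOTE: the losses are built in __init__ with Reduction.NONE and reduced in this
# train_step as tf.reduce_sum(per_example) * (1 / global_batch_size), the
# pattern recommended under a distribution strategy; tf.nn.compute_average_loss
# is the equivalent helper. global_batch_size is hard-coded to 100 above, so it
# only matches runs whose effective batch size is actually 100. Minimal check:
import tensorflow as tf

per_example = tf.constant([0.2, 0.4, 0.6, 0.8])  # one loss value per sample
manual = tf.reduce_sum(per_example) * (1.0 / 4)
helper = tf.nn.compute_average_loss(per_example, global_batch_size=4)
# manual == helper == 0.5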
+ desired_length = y.shape[1] + previous_frame = first_frame + + reconstruction_loss = 0 + kl_loss = 0 + align_loss = 0 + cpc_loss = 0 + + with tf.GradientTape(persistent=True) as tape: + z_last, _ = self.encoder(last_frame) + for i in tqdm(range(1, desired_length)): + current_frame = y[:, i : i + 1, ...] + + z_prev = self.encoder(previous_frame) + + if self.last_frame_skip or i <= self.n_past: + z_prev, skip = z_prev + else: + z_prev = z_prev[0] + + z_curr, _ = self.encoder(current_frame) + + prior_input = tf.concat([z_prev, z_last], axis=1) + posterior_input = tf.concat([z_curr, z_last], axis=1) + + z_mean_prior, z_log_var_prior, z_prior = self.prior(prior_input) + z_mean_posterior, z_log_var_posterior, z_posterior = self.posterior( + posterior_input + ) + + # predictor_input = z_prev + predictor_input = tf.concat( + (z_prev, tf.expand_dims(z_posterior, axis=1)), axis=-1 + ) + + z_pred = self.frame_predictor(predictor_input) + + kl_loss += tf.reduce_sum( + self.kl_loss( + (z_mean_prior, z_log_var_prior), + (z_mean_posterior, z_log_var_posterior), + ) + ) * (1.0 / global_batch_size) + + if i > 1: + align_loss += tf.reduce_sum(self.align_loss(z_pred, z_curr)) * ( + 1.0 / global_batch_size + ) + + if i == desired_length - 1: + h_pred_p = self.frame_predictor( + tf.concat([z_prev, tf.expand_dims(z_prior, axis=1)], axis=-1) + ) + x_pred_p = self.decoder([h_pred_p, skip]) + cpc_loss = tf.reduce_sum(self.mse(x_pred_p, current_frame)) * ( + 1.0 / global_batch_size + ) + + prediction = self.decoder([z_pred, skip]) + reconstruction_loss += tf.reduce_sum( + self.mse(prediction, current_frame) + ) * (1.0 / global_batch_size) + + previous_frame = current_frame + + loss = ( + reconstruction_loss + + kl_loss * self.beta + + align_loss * self.weight_align + + cpc_loss * self.weight_cpc + ) + + prior_loss = kl_loss + cpc_loss * self.weight_cpc + + grads_without_prior = tape.gradient( + loss, + ( + self.encoder.trainable_weights + + self.decoder.trainable_weights + + self.posterior.trainable_weights + + self.frame_predictor.trainable_weights + ), + ) + self.optimizer.apply_gradients( + zip( + grads_without_prior, + ( + self.encoder.trainable_weights + + self.decoder.trainable_weights + + self.posterior.trainable_weights + + self.frame_predictor.trainable_weights + ), + ) + ) + + grads_prior = tape.gradient( + prior_loss, + self.prior.trainable_weights, + ) + + self.optimizer.apply_gradients( + zip( + grads_prior, + self.prior.trainable_weights, + ) + ) + del tape + + self.total_loss_tracker.update_state(loss) + self.kl_loss_tracker.update_state(kl_loss) + self.align_loss_tracker.update_state(align_loss) + self.reconstruction_loss_tracker.update_state(reconstruction_loss) + self.cpc_loss_tracker.update_state(cpc_loss) + + return { + "loss": self.total_loss_tracker.result(), + "reconstruction_loss": self.reconstruction_loss_tracker.result(), + "kl_loss": self.kl_loss_tracker.result(), + "align_loss": self.align_loss_tracker.result(), + "cpc_loss": self.cpc_loss_tracker.result(), + } + + # print("KL_LOSS") + # print(kl_loss) + # print("ALIGN_LOSS") + # print(align_loss) + # print("RECONSTRUCTION_LOSS") + # print(reconstruction_loss) + + # with tf.GradientTape() as tape: + # z_mean, z_log_var, z = self.encoder(x) + # reconstruction = self.decoder(z) + # reconstruction_loss = tf.reduce_mean( + # tf.reduce_sum( + # tf.keras.losses.binary_crossentropy(y, reconstruction), + # axis=(1, 2), + # ) + # ) + # kl_loss = -0.5 * (1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var)) + # kl_loss = 
tf.reduce_mean(tf.reduce_sum(kl_loss, axis=1)) + # total_loss = reconstruction_loss + self.kl_beta * kl_loss + # grads = tape.gradient(total_loss, self.trainable_weights) + # self.optimizer.apply_gradients(zip(grads, self.trainable_weights)) + # self.total_loss_tracker.update_state(total_loss) + # self.reconstruction_loss_tracker.update_state(reconstruction_loss) + # self.kl_loss_tracker.update_state(kl_loss) + # return { + # "loss": self.total_loss_tracker.result(), + # "reconstruction_loss": self.reconstruction_loss_tracker.result(), + # "kl_loss": self.kl_loss_tracker.result(), + # } + + # def test_step(self, data): + # if isinstance(data, tuple): + # data = data[0] + + # z_mean, z_log_var, z = self.encoder(data) + # reconstruction = self.decoder(z) + # reconstruction_loss = tf.reduce_mean( + # tf.keras.losses.binary_crossentropy(data, reconstruction) + # ) + # reconstruction_loss *= 28 * 28 + # kl_loss = 1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var) + # kl_loss = tf.reduce_mean(kl_loss) + # kl_loss *= -0.5 + # total_loss = reconstruction_loss + kl_loss + # return { + # "loss": total_loss, + # "reconstruction_loss": reconstruction_loss, + # "kl_loss": kl_loss, + # } diff --git a/ganime/model/p2p/p2p_v3.py b/ganime/model/p2p/p2p_v3.py new file mode 100644 index 0000000000000000000000000000000000000000..2ffd64dd01c5daf4e7e346daf06cb912b432071f --- /dev/null +++ b/ganime/model/p2p/p2p_v3.py @@ -0,0 +1,237 @@ +import numpy as np +import tensorflow as tf +from tensorflow.keras import Model, Sequential +from tensorflow.keras.layers import ( + LSTM, + Activation, + BatchNormalization, + Conv2D, + Conv2DTranspose, + Conv3D, + Conv3DTranspose, + Dense, + Flatten, + Input, + Layer, + LeakyReLU, + MaxPooling2D, + Reshape, + TimeDistributed, + UpSampling2D, +) + + +SEQ_LEN = 20 + + +class Sampling(Layer): + """Uses (z_mean, z_log_var) to sample z, the vector encoding a digit.""" + + def call(self, inputs): + z_mean, z_log_var = inputs + batch = tf.shape(z_mean)[0] + dim = tf.shape(z_mean)[1] + epsilon = tf.keras.backend.random_normal(shape=(batch, dim)) + return z_mean + tf.exp(0.5 * z_log_var) * epsilon + + def compute_output_shape(self, input_shape): + return input_shape[0] + + +class P2P(Model): + def __init__( + self, + channels: int = 1, + g_dim: int = 128, + z_dim: int = 10, + rnn_size: int = 256, + prior_rnn_layers: int = 1, + posterior_rnn_layers: int = 1, + predictor_rnn_layers: float = 1, + skip_prob: float = 0.1, + n_past: int = 1, + last_frame_skip: bool = False, + beta: float = 0.0001, + weight_align: float = 0.1, + weight_cpc: float = 100, + ): + super().__init__() + # Models parameters + self.channels = channels + self.g_dim = g_dim + self.z_dim = z_dim + self.rnn_size = rnn_size + self.prior_rnn_layers = prior_rnn_layers + self.posterior_rnn_layers = posterior_rnn_layers + self.predictor_rnn_layers = predictor_rnn_layers + + # Training parameters + self.skip_prob = skip_prob + self.n_past = n_past + self.last_frame_skip = last_frame_skip + self.beta = beta + self.weight_align = weight_align + self.weight_cpc = weight_cpc + + self.frame_predictor = self.build_lstm() + self.prior = self.build_gaussian_lstm() + self.posterior = self.build_gaussian_lstm() + self.encoder = self.build_encoder() + self.decoder = self.build_decoder() + + self.total_loss_tracker = tf.keras.metrics.Mean(name="total_loss") + self.reconstruction_loss_tracker = tf.keras.metrics.Mean( + name="reconstruction_loss" + ) + self.kl_loss_tracker = tf.keras.metrics.Mean(name="kl_loss") + + # region Model 
building + def build_lstm(self): + input = Input(shape=(20, self.g_dim + self.z_dim + 1)) + embed = TimeDistributed(Dense(self.rnn_size))(input) + lstm = LSTM(self.rnn_size, return_sequences=True)(embed) + output = TimeDistributed(Dense(self.g_dim))(lstm) + + return Model(inputs=input, outputs=output, name="frame_predictor") + + def build_gaussian_lstm(self): + + input = Input(shape=(20, self.g_dim)) + embed = TimeDistributed(Dense(self.rnn_size))(input) + lstm = LSTM(self.rnn_size, return_sequences=True)(embed) + mu = TimeDistributed(Dense(self.z_dim))(lstm) + logvar = TimeDistributed(Dense(self.z_dim))(lstm) + z = TimeDistributed(Sampling())([mu, logvar]) + + return Model(inputs=input, outputs=[mu, logvar, z]) + + def build_encoder(self): + + input = Input(shape=(2, 64, 64, 1)) + + h = TimeDistributed(Conv2D(64, kernel_size=4, strides=2, padding="same"))(input) + h = BatchNormalization()(h) + h = LeakyReLU(alpha=0.2)(h) + # h = TimeDistributed(MaxPooling2D(pool_size=2, strides=2, padding="same"))(h) + + h = TimeDistributed(Conv2D(128, kernel_size=4, strides=2, padding="same"))(h) + h = BatchNormalization()(h) + h = LeakyReLU(alpha=0.2)(h) + # h = TimeDistributed(MaxPooling2D(pool_size=2, strides=2, padding="same"))(h) + + h = TimeDistributed(Conv2D(256, kernel_size=4, strides=2, padding="same"))(h) + h = BatchNormalization()(h) + h = LeakyReLU(alpha=0.2)(h) + # h = TimeDistributed(MaxPooling2D(pool_size=2, strides=2, padding="same"))(h) + + h = TimeDistributed(Conv2D(512, kernel_size=4, strides=2, padding="same"))(h) + h = BatchNormalization()(h) + h = LeakyReLU(alpha=0.2)(h) + # h = TimeDistributed(MaxPooling2D(pool_size=2, strides=2, padding="same"))(h) + + h = Flatten()(h) + # mu = Dense(self.g_dim)(h) + # logvar = Dense(self.g_dim)(h) + + # z = Sampling()([mu, logvar]) + lstm_input = Dense(self.g_dim * SEQ_LEN)(h) + lstm_input = Reshape((SEQ_LEN, self.g_dim))(lstm_input) + mu, logvar, z = self.posterior(lstm_input) + + return Model(inputs=input, outputs=[mu, logvar, z], name="encoder") + + def build_decoder(self): + latent_inputs = Input(shape=(SEQ_LEN, self.z_dim)) + x = Dense(1 * 1 * 1 * 512, activation="relu")(latent_inputs) + x = Reshape((SEQ_LEN, 1, 1, 512))(x) + x = TimeDistributed( + Conv2DTranspose(512, kernel_size=4, strides=1, padding="valid") + )(x) + x = BatchNormalization()(x) + x = LeakyReLU(alpha=0.2)(x) + + x = TimeDistributed( + Conv2DTranspose(256, kernel_size=4, strides=2, padding="same") + )(x) + x = BatchNormalization()(x) + x = LeakyReLU(alpha=0.2)(x) + + x = TimeDistributed( + Conv2DTranspose(128, kernel_size=4, strides=2, padding="same") + )(x) + x = BatchNormalization()(x) + x = LeakyReLU(alpha=0.2)(x) + + x = TimeDistributed( + Conv2DTranspose(64, kernel_size=4, strides=2, padding="same") + )(x) + x = BatchNormalization()(x) + x = LeakyReLU(alpha=0.2)(x) + + x = TimeDistributed( + Conv2DTranspose(1, kernel_size=4, strides=2, padding="same") + )(x) + x = Activation("sigmoid")(x) + + return Model(inputs=latent_inputs, outputs=x, name="decoder") + + # endregion + + @property + def metrics(self): + return [ + self.total_loss_tracker, + self.reconstruction_loss_tracker, + self.kl_loss_tracker, + ] + + def call(self, inputs, training=None, mask=None): + z_mean, z_log_var, z = self.encoder(inputs) + pred = self.decoder(z) + return pred + + def train_step(self, data): + x, y = data + + with tf.GradientTape() as tape: + z_mean, z_log_var, z = self.encoder(x) + reconstruction = self.decoder(z) + reconstruction_loss = tf.reduce_mean( + tf.reduce_sum( + 
tf.keras.losses.binary_crossentropy(y, reconstruction), + axis=(1, 2), + ) + ) + kl_loss = -0.5 * (1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var)) + kl_loss = tf.reduce_mean(tf.reduce_sum(kl_loss, axis=1)) + total_loss = reconstruction_loss + self.beta * kl_loss + grads = tape.gradient(total_loss, self.trainable_weights) + self.optimizer.apply_gradients(zip(grads, self.trainable_weights)) + self.total_loss_tracker.update_state(total_loss) + self.reconstruction_loss_tracker.update_state(reconstruction_loss) + self.kl_loss_tracker.update_state(kl_loss) + return { + "loss": self.total_loss_tracker.result(), + "reconstruction_loss": self.reconstruction_loss_tracker.result(), + "kl_loss": self.kl_loss_tracker.result(), + } + + def test_step(self, data): + if isinstance(data, tuple): + data = data[0] + + z_mean, z_log_var, z = self.encoder(data) + reconstruction = self.decoder(z) + reconstruction_loss = tf.reduce_mean( + tf.keras.losses.binary_crossentropy(data, reconstruction) + ) + reconstruction_loss *= 28 * 28 + kl_loss = 1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var) + kl_loss = tf.reduce_mean(kl_loss) + kl_loss *= -0.5 + total_loss = reconstruction_loss + kl_loss + return { + "loss": total_loss, + "reconstruction_loss": reconstruction_loss, + "kl_loss": kl_loss, + } diff --git a/ganime/model/vae/vae.py b/ganime/model/vae/vae.py new file mode 100644 index 0000000000000000000000000000000000000000..edc7d1d9b7dd8b458384238a1fe0f5c06b68da3d --- /dev/null +++ b/ganime/model/vae/vae.py @@ -0,0 +1,98 @@ +import numpy as np +import matplotlib.pyplot as plt + +from tensorflow import keras +from tensorflow.keras import layers +import tensorflow as tf + +input_shape = (20, 64, 64, 1) + +class Sampling(keras.layers.Layer): + """Uses (z_mean, z_log_var) to sample z, the vector encoding a digit.""" + + def call(self, inputs): + z_mean, z_log_var = inputs + batch = tf.shape(z_mean)[0] + dim = z_mean.shape[1:] + epsilon = tf.keras.backend.random_normal(shape=(batch, *dim)) + return z_mean + tf.exp(0.5 * z_log_var) * epsilon + + def compute_output_shape(self, input_shape): + return input_shape[0] + + +class VAE(keras.Model): + def __init__(self, latent_dim:int=32, num_embeddings:int=128, beta:float = 0.5, **kwargs): + super().__init__(**kwargs) + self.latent_dim = latent_dim + self.num_embeddings = num_embeddings + self.beta = beta + + self.encoder = self.get_encoder() + self.decoder = self.get_decoder() + + self.total_loss_tracker = tf.keras.metrics.Mean(name="total_loss") + self.reconstruction_loss_tracker = tf.keras.metrics.Mean( + name="reconstruction_loss" + ) + self.kl_loss_tracker = tf.keras.metrics.Mean(name="kl_loss") + + + def get_encoder(self): + encoder_inputs = keras.Input(shape=input_shape) + x = layers.TimeDistributed(layers.Conv2D(32, 3, activation="relu", strides=2, padding="same"))( + encoder_inputs + ) + x = layers.TimeDistributed(layers.Conv2D(64, 3, activation="relu", strides=2, padding="same"))(x) + x = layers.TimeDistributed(layers.Conv2D(self.latent_dim, 1, padding="same"))(x) + + x = layers.TimeDistributed(layers.Flatten())(x) + mu = layers.TimeDistributed(layers.Dense(self.num_embeddings))(x) + logvar = layers.TimeDistributed(layers.Dense(self.num_embeddings))(x) + z = Sampling()([mu, logvar]) + + return keras.Model(encoder_inputs, [mu, logvar, z], name="encoder") + + + def get_decoder(self): + latent_inputs = keras.Input(shape=self.encoder.output[2].shape[1:]) + + x = layers.TimeDistributed(layers.Dense(16 * 16 * 32, activation="relu"))(latent_inputs) + x = 
layers.TimeDistributed(layers.Reshape((16, 16, 32)))(x) + x = layers.TimeDistributed(layers.Conv2DTranspose(64, 3, activation="relu", strides=2, padding="same"))( + x + ) + x = layers.TimeDistributed(layers.Conv2DTranspose(32, 3, activation="relu", strides=2, padding="same"))(x) + decoder_outputs = layers.TimeDistributed(layers.Conv2DTranspose(1, 3, padding="same"))(x) + return keras.Model(latent_inputs, decoder_outputs, name="decoder") + + def train_step(self, data): + x, y = data + + with tf.GradientTape() as tape: + mu, logvar, z = self.encoder(x) + reconstruction = self.decoder(z) + reconstruction_loss = tf.reduce_mean( + tf.reduce_sum( + tf.keras.losses.binary_crossentropy(y, reconstruction), + axis=(1, 2), + ) + ) + kl_loss = -0.5 * (1 + logvar - tf.square(mu) - tf.exp(logvar)) + kl_loss = tf.reduce_mean(tf.reduce_sum(kl_loss, axis=1)) + total_loss = reconstruction_loss + self.beta * kl_loss + grads = tape.gradient(total_loss, self.trainable_weights) + self.optimizer.apply_gradients(zip(grads, self.trainable_weights)) + self.total_loss_tracker.update_state(total_loss) + self.reconstruction_loss_tracker.update_state(reconstruction_loss) + self.kl_loss_tracker.update_state(kl_loss) + return { + "loss": self.total_loss_tracker.result(), + "reconstruction_loss": self.reconstruction_loss_tracker.result(), + "kl_loss": self.kl_loss_tracker.result(), + } + + def call(self, inputs, training=False, mask=None): + z_mean, z_log_var, z = self.encoder(inputs) + pred = self.decoder(z) + return pred diff --git a/ganime/model/vq_vae/vq_vae.py b/ganime/model/vq_vae/vq_vae.py new file mode 100644 index 0000000000000000000000000000000000000000..589afce16401db6e569810d733e231c951294cc8 --- /dev/null +++ b/ganime/model/vq_vae/vq_vae.py @@ -0,0 +1,143 @@ +import numpy as np +import matplotlib.pyplot as plt + +from tensorflow import keras +from tensorflow.keras import layers +import tensorflow as tf + +input_shape = (20, 64, 64, 1) + +class VectorQuantizer(layers.Layer): + def __init__(self, num_embeddings, embedding_dim, beta=0.25, **kwargs): + super().__init__(**kwargs) + self.embedding_dim = embedding_dim + self.num_embeddings = num_embeddings + self.beta = ( + beta # This parameter is best kept between [0.25, 2] as per the paper. + ) + + # Initialize the embeddings which we will quantize. + w_init = tf.random_uniform_initializer() + self.embeddings = tf.Variable( + initial_value=w_init( + shape=(self.embedding_dim, self.num_embeddings), dtype="float32" + ), + trainable=True, + name="embeddings_vqvae", + ) + + def call(self, x): + # Calculate the input shape of the inputs and + # then flatten the inputs keeping `embedding_dim` intact. + input_shape = tf.shape(x) + flattened = tf.reshape(x, [-1, self.embedding_dim]) + + # Quantization. + encoding_indices = self.get_code_indices(flattened) + encodings = tf.one_hot(encoding_indices, self.num_embeddings) + quantized = tf.matmul(encodings, self.embeddings, transpose_b=True) + quantized = tf.reshape(quantized, input_shape) + + # Calculate vector quantization loss and add that to the layer. You can learn more + # about adding losses to different layers here: + # https://keras.io/guides/making_new_layers_and_models_via_subclassing/. Check + # the original paper to get a handle on the formulation of the loss function. + commitment_loss = self.beta * tf.reduce_mean( + (tf.stop_gradient(quantized) - x) ** 2 + ) + codebook_loss = tf.reduce_mean((quantized - tf.stop_gradient(x)) ** 2) + self.add_loss(commitment_loss + codebook_loss) + + # Straight-through estimator. 
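# NOTE: the next line is the straight-through estimator: the forward value is
# `quantized`, but the gradient flows to `x` as if the quantization were the
# identity, because the tf.stop_gradient term contributes nothing to the
# backward pass. Stand-alone illustration with a toy quantizer (tf.round):
#   x = tf.Variable([0.3, 1.7])
#   with tf.GradientTape() as tape:
#       q_st = x + tf.stop_gradient(tf.round(x) - x)  # forward value: round(x)
#       loss = tf.reduce_sum(q_st ** 2)
#   tape.gradient(loss, x)  # == 2 * tf.round(x), i.e. gradients reach x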
+ quantized = x + tf.stop_gradient(quantized - x) + return quantized + + def get_code_indices(self, flattened_inputs): + # Calculate L2-normalized distance between the inputs and the codes. + similarity = tf.matmul(flattened_inputs, self.embeddings) + distances = ( + tf.reduce_sum(flattened_inputs ** 2, axis=1, keepdims=True) + + tf.reduce_sum(self.embeddings ** 2, axis=0) + - 2 * similarity + ) + + # Derive the indices for minimum distances. + encoding_indices = tf.argmin(distances, axis=1) + return encoding_indices + + +class VQVAE(keras.Model): + def __init__(self, train_variance:float, latent_dim:int=32, num_embeddings:int=128, **kwargs): + super().__init__(**kwargs) + self.train_variance = train_variance + self.latent_dim = latent_dim + self.num_embeddings = num_embeddings + + self.vqvae = self.get_vqvae() + + self.total_loss_tracker = keras.metrics.Mean(name="total_loss") + self.reconstruction_loss_tracker = keras.metrics.Mean( + name="reconstruction_loss" + ) + self.vq_loss_tracker = keras.metrics.Mean(name="vq_loss") + + + def get_encoder(self): + encoder_inputs = keras.Input(shape=input_shape) + x = layers.TimeDistributed(layers.Conv2D(32, 3, activation="relu", strides=2, padding="same"))( + encoder_inputs + ) + x = layers.TimeDistributed(layers.Conv2D(64, 3, activation="relu", strides=2, padding="same"))(x) + encoder_outputs = layers.TimeDistributed(layers.Conv2D(self.latent_dim, 1, padding="same"))(x) + return keras.Model(encoder_inputs, encoder_outputs, name="encoder") + + + def get_decoder(self): + latent_inputs = keras.Input(shape=self.get_encoder().output.shape[1:]) + x = layers.TimeDistributed(layers.Conv2DTranspose(64, 3, activation="relu", strides=2, padding="same"))( + latent_inputs + ) + x = layers.TimeDistributed(layers.Conv2DTranspose(32, 3, activation="relu", strides=2, padding="same"))(x) + decoder_outputs = layers.TimeDistributed(layers.Conv2DTranspose(1, 3, padding="same"))(x) + return keras.Model(latent_inputs, decoder_outputs, name="decoder") + + def get_vqvae(self): + self.vq_layer = VectorQuantizer(self.num_embeddings, self.latent_dim, name="vector_quantizer") + self.encoder = self.get_encoder() + self.decoder = self.get_decoder() + inputs = keras.Input(shape=input_shape) + encoder_outputs = self.encoder(inputs) + quantized_latents = self.vq_layer(encoder_outputs) + reconstructions = self.decoder(quantized_latents) + return keras.Model(inputs, reconstructions, name="vq_vae") + + def train_step(self, data): + x, y = data + with tf.GradientTape() as tape: + # Outputs from the VQ-VAE. + reconstructions = self.vqvae(x) + + # Calculate the losses. + reconstruction_loss = ( + tf.reduce_mean((y - reconstructions) ** 2) / self.train_variance + ) + total_loss = reconstruction_loss + sum(self.vqvae.losses) + + # Backpropagation. + grads = tape.gradient(total_loss, self.vqvae.trainable_variables) + self.optimizer.apply_gradients(zip(grads, self.vqvae.trainable_variables)) + + # Loss tracking. + self.total_loss_tracker.update_state(total_loss) + self.reconstruction_loss_tracker.update_state(reconstruction_loss) + self.vq_loss_tracker.update_state(sum(self.vqvae.losses)) + + # Log results. 
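# NOTE: `train_variance` is not computed anywhere in this file; presumably it is
# the pixel variance of the training set (e.g. float(np.var(train_videos)))
# passed to the constructor, as in the standard Keras VQ-VAE example, so that
# the reconstruction MSE above is scaled to a magnitude comparable with the
# codebook and commitment losses accumulated in self.vqvae.losses.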
+ return { + "loss": self.total_loss_tracker.result(), + "reconstruction_loss": self.reconstruction_loss_tracker.result(), + "vqvae_loss": self.vq_loss_tracker.result(), + } + + def call(self, inputs, training=False, mask=None): + return self.vqvae(inputs) diff --git a/ganime/model/vqgan/__init__.py b/ganime/model/vqgan/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/ganime/model/vqgan/discriminator/__init__.py b/ganime/model/vqgan/discriminator/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/ganime/model/vqgan/discriminator/model.py b/ganime/model/vqgan/discriminator/model.py new file mode 100644 index 0000000000000000000000000000000000000000..3e175b46b2c5ec8cebeafc107147d7f3f6c9737f --- /dev/null +++ b/ganime/model/vqgan/discriminator/model.py @@ -0,0 +1,64 @@ +from typing import List +import numpy as np +import tensorflow as tf +from tensorflow import keras +from tensorflow.keras import Model, Sequential +from tensorflow.keras import layers + + +class NLayerDiscriminator(Model): + """Defines a PatchGAN discriminator as in Pix2Pix + --> see https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py + """ + + def __init__(self, input_channels: int = 3, filters: int = 64, n_layers: int = 3): + super().__init__() + + kernel_size = 4 + self.sequence = [ + layers.Conv2D(filters, kernel_size=kernel_size, padding="same"), + layers.LeakyReLU(alpha=0.2), + ] + + filters_mult = 1 + for n in range(1, n_layers): + filters_mult = min(2**n, 8) + + self.sequence += [ + layers.AveragePooling2D(pool_size=2), + layers.Conv2D( + filters * filters_mult, + kernel_size=kernel_size, + strides=1, # 2, + padding="same", + use_bias=False, + ), + layers.BatchNormalization(), + layers.LeakyReLU(alpha=0.2), + ] + + filters_mult = min(2**n_layers, 8) + self.sequence += [ + layers.Conv2D( + filters * filters_mult, + kernel_size=kernel_size, + strides=1, + padding="same", + use_bias=False, + ), + layers.BatchNormalization(), + layers.LeakyReLU(alpha=0.2), + ] + + self.sequence += [ + layers.Conv2D(1, kernel_size=kernel_size, strides=1, padding="same") + ] + + # self.main = Sequential(sequence) + + def call(self, inputs, training=True, mask=None): + h = inputs + for seq in self.sequence: + h = seq(h) + return h + # return self.main(inputs) diff --git a/ganime/model/vqgan/losses/__init__.py b/ganime/model/vqgan/losses/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/ganime/model/vqgan/losses/lpips.py b/ganime/model/vqgan/losses/lpips.py new file mode 100644 index 0000000000000000000000000000000000000000..9cbcc458c99ecd9c1bfa767f3aca0495603a4705 --- /dev/null +++ b/ganime/model/vqgan/losses/lpips.py @@ -0,0 +1,134 @@ +import os +import numpy as np +import tensorflow as tf +import torchvision.models as models +from tensorflow import keras +from tensorflow.keras import Model, Sequential +from tensorflow.keras import backend as K +from tensorflow.keras import layers +from tensorflow.keras.applications.vgg16 import VGG16, preprocess_input +from tensorflow.keras.losses import Loss +from pyprojroot.pyprojroot import here + + +def normalize_tensor(x, eps=1e-10): + norm_factor = tf.sqrt(tf.reduce_sum(x**2, axis=-1, keepdims=True)) + return x / (norm_factor + eps) + + +class LPIPS(Loss): + def __init__(self, use_dropout=True, **kwargs): + 
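+        # LPIPS ("Learned Perceptual Image Patch Similarity"): both images are
+        # shifted and scaled by ScalingLayer, passed through a frozen VGG16,
+        # and the squared differences of unit-normalized activations from five
+        # conv layers are weighted by 1x1 "lin" layers and summed to a scalar.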
super().__init__(**kwargs) + + self.scaling_layer = ScalingLayer() # preprocess_input + selected_layers = [ + "block1_conv2", + "block2_conv2", + "block3_conv3", + "block4_conv3", + "block5_conv3", + ] + + # TODO here we load the same weights as pytorch, try with tensorflow weights + self.net = self.load_vgg16() # VGG16(weights="imagenet", include_top=False) + self.net.trainable = False + outputs = [self.net.get_layer(layer).output for layer in selected_layers] + + self.model = Model(self.net.input, outputs) + self.lins = [NetLinLayer(use_dropout=use_dropout) for _ in selected_layers] + + # TODO: here we use the pytorch weights of the linear layers, try without these layers, or without initializing the weights + self(tf.zeros((1, 16, 16, 1)), tf.zeros((1, 16, 16, 1))) + self.init_lin_layers() + + def load_vgg16(self) -> Model: + """Load a VGG16 model with the same weights as PyTorch + https://github.com/ezavarygin/vgg16_pytorch2keras + """ + pytorch_model = models.vgg16(pretrained=True) + # select weights in the conv2d layers and transpose them to keras dim ordering: + wblist_torch = list(pytorch_model.parameters())[:26] + wblist_keras = [] + for i in range(len(wblist_torch)): + if wblist_torch[i].dim() == 4: + w = np.transpose(wblist_torch[i].detach().numpy(), axes=[2, 3, 1, 0]) + wblist_keras.append(w) + elif wblist_torch[i].dim() == 1: + b = wblist_torch[i].detach().numpy() + wblist_keras.append(b) + else: + raise Exception("Fully connected layers are not implemented.") + + keras_model = VGG16(include_top=False, weights=None) + keras_model.set_weights(wblist_keras) + return keras_model + + def init_lin_layers(self): + for i in range(5): + weights = np.load( + os.path.join(here(), "models", "NetLinLayer", f"numpy_{i}.npy") + ) + weights = np.moveaxis(weights, 1, 2) + self.lins[i].model.layers[1].set_weights([weights]) + + def call(self, y_true, y_pred): + + scaled_true = self.scaling_layer(y_true) + scaled_pred = self.scaling_layer(y_pred) + + outputs_true, outputs_pred = self.model(scaled_true), self.model(scaled_pred) + features_true, features_pred, diffs = {}, {}, {} + + for kk in range(len(outputs_true)): + features_true[kk], features_pred[kk] = normalize_tensor( + outputs_true[kk] + ), normalize_tensor(outputs_pred[kk]) + + diffs[kk] = (features_true[kk] - features_pred[kk]) ** 2 + + res = [ + tf.reduce_mean(self.lins[kk](diffs[kk]), axis=(-3, -2), keepdims=True) + for kk in range(len(outputs_true)) + ] + + return tf.reduce_sum(res) + + # h1_list = self.model(self.scaling_layer(y_true)) + # h2_list = self.model(self.scaling_layer(y_pred)) + + # rc_loss = 0.0 + # for h1, h2 in zip(h1_list, h2_list): + # h1 = K.batch_flatten(h1) + # h2 = K.batch_flatten(h2) + # rc_loss += K.sum(K.square(h1 - h2), axis=-1) + + # return rc_loss + + +class ScalingLayer(layers.Layer): + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.shift = tf.Variable([-0.030, -0.088, -0.188]) + self.scale = tf.Variable([0.458, 0.448, 0.450]) + + def call(self, inputs): + return (inputs - self.shift) / self.scale + + +class NetLinLayer(layers.Layer): + def __init__(self, channels_out=1, use_dropout=False): + super().__init__() + sequence = ( + [ + layers.Dropout(0.5), + ] + if use_dropout + else [] + ) + sequence += [ + layers.Conv2D(channels_out, 1, padding="same", use_bias=False), + ] + self.model = Sequential(sequence) + + def call(self, inputs): + return self.model(inputs) diff --git a/ganime/model/vqgan/losses/vqperceptual.py b/ganime/model/vqgan/losses/vqperceptual.py new file mode 100644 index 
0000000000000000000000000000000000000000..b9ec5ad7391106d8ffffb153c2b02d746896a051 --- /dev/null +++ b/ganime/model/vqgan/losses/vqperceptual.py @@ -0,0 +1,47 @@ +from typing import List, Literal + +import numpy as np +import tensorflow as tf +from tensorflow import keras +from tensorflow.keras import Model, layers +from tensorflow.keras.losses import Loss + +from .lpips import LPIPS + +from ..discriminator.model import NLayerDiscriminator + + +class VQLPIPSWithDiscriminator(Loss): + def __init__( + self, *, pixelloss_weight: float = 1.0, perceptual_weight: float = 1.0, **kwargs + ): + super().__init__(**kwargs) + self.pixelloss_weight = pixelloss_weight + self.perceptual_loss = LPIPS(reduction=tf.keras.losses.Reduction.NONE) + self.perceptual_weight = perceptual_weight + + def call( + self, + y_true, + y_pred, + ): + reconstruction_loss = tf.abs(y_true - y_pred) + if self.perceptual_weight > 0: + perceptual_loss = self.perceptual_loss(y_true, y_pred) + reconstruction_loss += self.perceptual_weight * perceptual_loss + else: + perceptual_loss = 0.0 + + neg_log_likelihood = tf.reduce_mean(reconstruction_loss) + + return neg_log_likelihood + + # # GAN part + # if optimizer_idx == 0: + # if cond is None: + # assert not self.disc_conditional + # logits_fake = self.discriminator(y_pred) + # else: + # assert self.disc_conditional + # logits_fake = self.discriminator(tf.concat([y_pred, cond], axis=-1)) + # g_loss = -tf.reduce_mean(logits_fake) diff --git a/ganime/model/vqgan/vqgan.py b/ganime/model/vqgan/vqgan.py new file mode 100644 index 0000000000000000000000000000000000000000..3fbfcc1e7f78d3249b536d1a6705dbeda5507f46 --- /dev/null +++ b/ganime/model/vqgan/vqgan.py @@ -0,0 +1,722 @@ +from typing import List, Literal + +import numpy as np +import tensorflow as tf +from .discriminator.model import NLayerDiscriminator +from .losses.vqperceptual import VQLPIPSWithDiscriminator +from tensorflow import keras +from tensorflow.keras import Model, layers, Sequential +from tensorflow.keras.optimizers import Optimizer +from tensorflow_addons.layers import GroupNormalization + +INPUT_SHAPE = (64, 128, 3) +ENCODER_OUTPUT_SHAPE = (8, 8, 128) + + +@tf.function +def hinge_d_loss(logits_real, logits_fake): + loss_real = tf.reduce_mean(keras.activations.relu(1.0 - logits_real)) + loss_fake = tf.reduce_mean(keras.activations.relu(1.0 + logits_fake)) + d_loss = 0.5 * (loss_real + loss_fake) + return d_loss + + +@tf.function +def vanilla_d_loss(logits_real, logits_fake): + d_loss = 0.5 * ( + tf.reduce_mean(keras.activations.softplus(-logits_real)) + + tf.reduce_mean(keras.activations.softplus(logits_fake)) + ) + return d_loss + + +class VQGAN(keras.Model): + def __init__( + self, + train_variance: float, + num_embeddings: int, + embedding_dim: int, + beta: float = 0.25, + z_channels: int = 128, # 256, + codebook_weight: float = 1.0, + disc_num_layers: int = 3, + disc_factor: float = 1.0, + disc_iter_start: int = 0, + disc_conditional: bool = False, + disc_in_channels: int = 3, + disc_weight: float = 0.3, + disc_filters: int = 64, + disc_loss: Literal["hinge", "vanilla"] = "hinge", + **kwargs, + ): + super().__init__(**kwargs) + self.train_variance = train_variance + self.codebook_weight = codebook_weight + + self.encoder = Encoder() + self.decoder = Decoder() + self.quantize = VectorQuantizer(num_embeddings, embedding_dim, beta=beta) + + self.quant_conv = layers.Conv2D(embedding_dim, kernel_size=1) + self.post_quant_conv = layers.Conv2D(z_channels, kernel_size=1) + + self.vqvae = self.get_vqvae() + + 
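+        # Training couples two networks: the autoencoder (encoder -> quantizer
+        # -> decoder, wrapped in `self.vqvae`) and a PatchGAN discriminator.
+        # Each gets its own optimizer via `compile(gen_optimizer, disc_optimizer)`
+        # and its own gradient tape in `train_step` below.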
self.perceptual_loss = VQLPIPSWithDiscriminator( + reduction=tf.keras.losses.Reduction.NONE + ) + + self.discriminator = NLayerDiscriminator( + input_channels=disc_in_channels, + filters=disc_filters, + n_layers=disc_num_layers, + ) + self.discriminator_iter_start = disc_iter_start + + if disc_loss == "hinge": + self.disc_loss = hinge_d_loss + elif disc_loss == "vanilla": + self.disc_loss = vanilla_d_loss + else: + raise ValueError(f"Unknown GAN loss '{disc_loss}'.") + + print(f"VQLPIPSWithDiscriminator running with {disc_loss} loss.") + self.disc_factor = disc_factor + self.discriminator_weight = disc_weight + self.disc_conditional = disc_conditional + + self.total_loss_tracker = keras.metrics.Mean(name="total_loss") + self.reconstruction_loss_tracker = keras.metrics.Mean( + name="reconstruction_loss" + ) + self.vq_loss_tracker = keras.metrics.Mean(name="vq_loss") + self.disc_loss_tracker = keras.metrics.Mean(name="disc_loss") + + self.gen_optimizer: Optimizer = None + self.disc_optimizer: Optimizer = None + + def get_vqvae(self): + inputs = keras.Input(shape=INPUT_SHAPE) + quant = self.encode(inputs) + reconstructed = self.decode(quant) + return keras.Model(inputs, reconstructed, name="vq_vae") + + def encode(self, x): + h = self.encoder(x) + h = self.quant_conv(h) + return self.quantize(h) + + def decode(self, quant): + quant = self.post_quant_conv(quant) + dec = self.decoder(quant) + return dec + + def call(self, inputs, training=True, mask=None): + return self.vqvae(inputs) + + def calculate_adaptive_weight( + self, nll_loss, g_loss, tape, trainable_vars, discriminator_weight + ): + nll_grads = tape.gradient(nll_loss, trainable_vars)[0] + g_grads = tape.gradient(g_loss, trainable_vars)[0] + + d_weight = tf.norm(nll_grads) / (tf.norm(g_grads) + 1e-4) + d_weight = tf.stop_gradient(tf.clip_by_value(d_weight, 0.0, 1e4)) + return d_weight * discriminator_weight + + @tf.function + def adopt_weight(self, weight, global_step, threshold=0, value=0.0): + if global_step < threshold: + weight = value + return weight + + def get_global_step(self, optimizer): + return optimizer.iterations + + def compile( + self, + gen_optimizer, + disc_optimizer, + ): + super().compile() + self.gen_optimizer = gen_optimizer + self.disc_optimizer = disc_optimizer + + def train_step(self, data): + x, y = data + + # Autoencode + with tf.GradientTape() as tape: + with tf.GradientTape(persistent=True) as adaptive_tape: + reconstructions = self(x, training=True) + + # Calculate the losses. + # reconstruction_loss = ( + # tf.reduce_mean((y - reconstructions) ** 2) / self.train_variance + # ) + + logits_fake = self.discriminator(reconstructions, training=False) + + g_loss = -tf.reduce_mean(logits_fake) + nll_loss = self.perceptual_loss(y, reconstructions) + + d_weight = self.calculate_adaptive_weight( + nll_loss, + g_loss, + adaptive_tape, + self.decoder.conv_out.trainable_variables, + self.discriminator_weight, + ) + del adaptive_tape + + disc_factor = self.adopt_weight( + weight=self.disc_factor, + global_step=self.get_global_step(self.gen_optimizer), + threshold=self.discriminator_iter_start, + ) + + # total_loss = reconstruction_loss + sum(self.vqvae.losses) + total_loss = ( + nll_loss + + d_weight * disc_factor * g_loss + # + self.codebook_weight * tf.reduce_mean(self.vqvae.losses) + + self.codebook_weight * sum(self.vqvae.losses) + ) + + # Backpropagation. 
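+        # `d_weight` above follows the VQGAN recipe: the ratio of the gradient
+        # norms of the reconstruction (nll) loss and the generator loss w.r.t.
+        # the decoder's last conv layer, clipped to [0, 1e4], so the
+        # adversarial term never overwhelms the reconstruction term.
+        # The update below only touches the autoencoder weights; the
+        # discriminator is trained in its own tape right after.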
+ grads = tape.gradient(total_loss, self.vqvae.trainable_variables) + self.gen_optimizer.apply_gradients(zip(grads, self.vqvae.trainable_variables)) + + # Discriminator + with tf.GradientTape() as disc_tape: + logits_real = self.discriminator(y, training=True) + logits_fake = self.discriminator(reconstructions, training=True) + + disc_factor = self.adopt_weight( + weight=self.disc_factor, + global_step=self.get_global_step(self.disc_optimizer), + threshold=self.discriminator_iter_start, + ) + d_loss = disc_factor * self.disc_loss(logits_real, logits_fake) + + disc_grads = disc_tape.gradient(d_loss, self.discriminator.trainable_variables) + self.disc_optimizer.apply_gradients( + zip(disc_grads, self.discriminator.trainable_variables) + ) + + # Loss tracking. + self.total_loss_tracker.update_state(total_loss) + self.reconstruction_loss_tracker.update_state(nll_loss) + self.vq_loss_tracker.update_state(sum(self.vqvae.losses)) + self.disc_loss_tracker.update_state(d_loss) + + # Log results. + return { + "loss": self.total_loss_tracker.result(), + "reconstruction_loss": self.reconstruction_loss_tracker.result(), + "vqvae_loss": self.vq_loss_tracker.result(), + "disc_loss": self.disc_loss_tracker.result(), + } + + +class VectorQuantizer(layers.Layer): + def __init__(self, num_embeddings, embedding_dim, beta=0.25, **kwargs): + super().__init__(**kwargs) + self.embedding_dim = embedding_dim + self.num_embeddings = num_embeddings + self.beta = ( + beta # This parameter is best kept between [0.25, 2] as per the paper. + ) + + # Initialize the embeddings which we will quantize. + w_init = tf.random_uniform_initializer() + self.embeddings = tf.Variable( + initial_value=w_init( + shape=(self.embedding_dim, self.num_embeddings) # , dtype="float32" + ), + trainable=True, + name="embeddings_vqvae", + ) + + def call(self, x): + # Calculate the input shape of the inputs and + # then flatten the inputs keeping `embedding_dim` intact. + input_shape = tf.shape(x) + flattened = tf.reshape(x, [-1, self.embedding_dim]) + + # Quantization. + encoding_indices = self.get_code_indices(flattened) + encodings = tf.one_hot(encoding_indices, self.num_embeddings) + quantized = tf.matmul(encodings, self.embeddings, transpose_b=True) + quantized = tf.reshape(quantized, input_shape) + + # Calculate vector quantization loss and add that to the layer. You can learn more + # about adding losses to different layers here: + # https://keras.io/guides/making_new_layers_and_models_via_subclassing/. Check + # the original paper to get a handle on the formulation of the loss function. + commitment_loss = self.beta * tf.reduce_mean( + (tf.stop_gradient(quantized) - x) ** 2 + ) + codebook_loss = tf.reduce_mean((quantized - tf.stop_gradient(x)) ** 2) + self.add_loss(commitment_loss + codebook_loss) + + # Straight-through estimator. + quantized = x + tf.stop_gradient(quantized - x) + return quantized + + def get_code_indices(self, flattened_inputs): + # Calculate L2-normalized distance between the inputs and the codes. + similarity = tf.matmul(flattened_inputs, self.embeddings) + distances = ( + tf.reduce_sum(flattened_inputs**2, axis=1, keepdims=True) + + tf.reduce_sum(self.embeddings**2, axis=0) + - 2 * similarity + ) + + # Derive the indices for minimum distances. 
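+        # `distances` expands the squared Euclidean distance
+        # ||f - e||^2 = ||f||^2 - 2 f.e + ||e||^2 between every flattened
+        # feature vector f and every codebook entry e; argmin over the
+        # codebook axis returns the index of the nearest embedding.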
+ encoding_indices = tf.argmin(distances, axis=1) + return encoding_indices + + +class Encoder(Model): + def __init__( + self, + *, + channels: int = 128, + output_channels: int = 3, + channels_multiplier: List[int] = [1, 1, 2, 2], # [1, 1, 2, 2, 4], + num_res_blocks: int = 1, # 2, + attention_resolution: List[int] = [16], + resolution: int = 64, # 256, + z_channels=128, # 256, + dropout=0.0, + double_z=False, + resamp_with_conv=True, + ): + super().__init__() + + self.channels = channels + self.timestep_embeddings_channel = 0 + self.num_resolutions = len(channels_multiplier) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + + self.conv_in = layers.Conv2D( + self.channels, kernel_size=3, strides=1, padding="same" + ) + + current_resolution = resolution + + in_channels_multiplier = (1,) + tuple(channels_multiplier) + + self.downsampling_list = [] + + for i_level in range(self.num_resolutions): + block_in = channels * in_channels_multiplier[i_level] + block_out = channels * channels_multiplier[i_level] + for i_block in range(self.num_res_blocks): + self.downsampling_list.append( + ResnetBlock( + in_channels=block_in, + out_channels=block_out, + timestep_embedding_channels=self.timestep_embeddings_channel, + dropout=dropout, + ) + ) + block_in = block_out + + if current_resolution in attention_resolution: + # attentions.append(layers.Attention()) + self.downsampling_list.append(AttentionBlock(block_in)) + + if i_level != self.num_resolutions - 1: + self.downsampling_list.append(Downsample(block_in, resamp_with_conv)) + + # self.downsampling = [] + + # for i_level in range(self.num_resolutions): + # block = [] + # attentions = [] + # block_in = channels * in_channels_multiplier[i_level] + # block_out = channels * channels_multiplier[i_level] + # for i_block in range(self.num_res_blocks): + # block.append( + # ResnetBlock( + # in_channels=block_in, + # out_channels=block_out, + # timestep_embedding_channels=self.timestep_embeddings_channel, + # dropout=dropout, + # ) + # ) + # block_in = block_out + + # if current_resolution in attention_resolution: + # # attentions.append(layers.Attention()) + # attentions.append(AttentionBlock(block_in)) + + # down = {} + # down["block"] = block + # down["attention"] = attentions + # if i_level != self.num_resolutions - 1: + # down["downsample"] = Downsample(block_in, resamp_with_conv) + # self.downsampling.append(down) + + # middle + self.mid = {} + self.mid["block_1"] = ResnetBlock( + in_channels=block_in, + out_channels=block_in, + timestep_embedding_channels=self.timestep_embeddings_channel, + dropout=dropout, + ) + self.mid["attn_1"] = AttentionBlock(block_in) + self.mid["block_2"] = ResnetBlock( + in_channels=block_in, + out_channels=block_in, + timestep_embedding_channels=self.timestep_embeddings_channel, + dropout=dropout, + ) + + # end + self.norm_out = GroupNormalization(groups=32, epsilon=1e-6) + self.conv_out = layers.Conv2D( + 2 * z_channels if double_z else z_channels, + kernel_size=3, + strides=1, + padding="same", + ) + + def summary(self): + x = layers.Input(shape=INPUT_SHAPE) + model = Model(inputs=[x], outputs=self.call(x)) + return model.summary() + + def call(self, inputs, training=True, mask=None): + h = self.conv_in(inputs) + for downsampling in self.downsampling_list: + h = downsampling(h) + # for i_level in range(self.num_resolutions): + # for i_block in range(self.num_res_blocks): + # h = self.downsampling[i_level]["block"][i_block](hs[-1]) + # if len(self.downsampling[i_level]["attention"]) > 0: + # h = 
self.downsampling[i_level]["attention"][i_block](h) + # hs.append(h) + # if i_level != self.num_resolutions - 1: + # hs.append(self.downsampling[i_level]["downsample"](hs[-1])) + + # h = hs[-1] + h = self.mid["block_1"](h) + h = self.mid["attn_1"](h) + h = self.mid["block_2"](h) + + # end + h = self.norm_out(h) + h = keras.activations.swish(h) + h = self.conv_out(h) + return h + + +class Decoder(Model): + def __init__( + self, + *, + channels: int = 128, + output_channels: int = 3, + channels_multiplier: List[int] = [1, 1, 2, 2], # [1, 1, 2, 2, 4], + num_res_blocks: int = 1, # 2, + attention_resolution: List[int] = [16], + resolution: int = 64, # 256, + z_channels=128, # 256, + dropout=0.0, + give_pre_end=False, + resamp_with_conv=True, + ): + super().__init__() + + self.channels = channels + self.timestep_embeddings_channel = 0 + self.num_resolutions = len(channels_multiplier) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.give_pre_end = give_pre_end + + in_channels_multiplier = (1,) + tuple(channels_multiplier) + block_in = channels * channels_multiplier[-1] + current_resolution = resolution // 2 ** (self.num_resolutions - 1) + self.z_shape = (1, z_channels, current_resolution, current_resolution) + + print( + "Working with z of shape {} = {} dimensions.".format( + self.z_shape, np.prod(self.z_shape) + ) + ) + + self.conv_in = layers.Conv2D(block_in, kernel_size=3, strides=1, padding="same") + + # middle + self.mid = {} + self.mid["block_1"] = ResnetBlock( + in_channels=block_in, + out_channels=block_in, + timestep_embedding_channels=self.timestep_embeddings_channel, + dropout=dropout, + ) + self.mid["attn_1"] = AttentionBlock(block_in) + self.mid["block_2"] = ResnetBlock( + in_channels=block_in, + out_channels=block_in, + timestep_embedding_channels=self.timestep_embeddings_channel, + dropout=dropout, + ) + + # upsampling + + self.upsampling_list = [] + + for i_level in reversed(range(self.num_resolutions)): + block_out = channels * channels_multiplier[i_level] + for i_block in range(self.num_res_blocks + 1): + self.upsampling_list.append( + ResnetBlock( + in_channels=block_in, + out_channels=block_out, + timestep_embedding_channels=self.timestep_embeddings_channel, + dropout=dropout, + ) + ) + block_in = block_out + + if current_resolution in attention_resolution: + # attentions.append(layers.Attention()) + self.upsampling_list.append(AttentionBlock(block_in)) + + if i_level != 0: + self.upsampling_list.append(Upsample(block_in, resamp_with_conv)) + current_resolution *= 2 + # self.upsampling.insert(0, upsampling) + + # self.upsampling = [] + + # for i_level in reversed(range(self.num_resolutions)): + # block = [] + # attentions = [] + # block_out = channels * channels_multiplier[i_level] + # for i_block in range(self.num_res_blocks + 1): + # block.append( + # ResnetBlock( + # in_channels=block_in, + # out_channels=block_out, + # timestep_embedding_channels=self.timestep_embeddings_channel, + # dropout=dropout, + # ) + # ) + # block_in = block_out + + # if current_resolution in attention_resolution: + # # attentions.append(layers.Attention()) + # attentions.append(AttentionBlock(block_in)) + + # upsampling = {} + # upsampling["block"] = block + # upsampling["attention"] = attentions + # if i_level != 0: + # upsampling["upsample"] = Upsample(block_in, resamp_with_conv) + # current_resolution *= 2 + # self.upsampling.insert(0, upsampling) + + # end + self.norm_out = GroupNormalization(groups=32, epsilon=1e-6) + self.conv_out = layers.Conv2D( + 
output_channels, + kernel_size=3, + strides=1, + activation="sigmoid", + padding="same", + ) + + def summary(self): + x = layers.Input(shape=ENCODER_OUTPUT_SHAPE) + model = Model(inputs=[x], outputs=self.call(x)) + return model.summary() + + def call(self, inputs, training=True, mask=None): + + h = self.conv_in(inputs) + + # middle + h = self.mid["block_1"](h) + h = self.mid["attn_1"](h) + h = self.mid["block_2"](h) + + for upsampling in self.upsampling_list: + h = upsampling(h) + + # for i_level in reversed(range(self.num_resolutions)): + # for i_block in range(self.num_res_blocks + 1): + # h = self.upsampling[i_level]["block"][i_block](h) + # if len(self.upsampling[i_level]["attention"]) > 0: + # h = self.upsampling[i_level]["attention"][i_block](h) + # if i_level != 0: + # h = self.upsampling[i_level]["upsample"](h) + + # end + if self.give_pre_end: + return h + + h = self.norm_out(h) + h = keras.activations.swish(h) + h = self.conv_out(h) + return h + + +class ResnetBlock(layers.Layer): + def __init__( + self, + *, + in_channels, + dropout=0.0, + out_channels=None, + conv_shortcut=False, + timestep_embedding_channels=512, + ): + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + self.use_conv_shortcut = conv_shortcut + + self.norm1 = GroupNormalization(groups=32, epsilon=1e-6) + + self.conv1 = layers.Conv2D( + out_channels, kernel_size=3, strides=1, padding="same" + ) + + if timestep_embedding_channels > 0: + self.timestep_embedding_projection = layers.Dense(out_channels) + + self.norm2 = GroupNormalization(groups=32, epsilon=1e-6) + self.dropout = layers.Dropout(dropout) + + self.conv2 = layers.Conv2D( + out_channels, kernel_size=3, strides=1, padding="same" + ) + + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + self.conv_shortcut = layers.Conv2D( + out_channels, kernel_size=3, strides=1, padding="same" + ) + else: + self.nin_shortcut = layers.Conv2D( + out_channels, kernel_size=1, strides=1, padding="valid" + ) + + def call(self, x): + h = x + h = self.norm1(h) + h = keras.activations.swish(h) + h = self.conv1(h) + + # if timestamp_embedding is not None: + # h = h + self.timestep_embedding_projection(keras.activations.swish(timestamp_embedding)) + + h = self.norm2(h) + h = keras.activations.swish(h) + h = self.dropout(h) + h = self.conv2(h) + + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + x = self.conv_shortcut(x) + else: + x = self.nin_shortcut(x) + + return x + h + + +class AttentionBlock(layers.Layer): + def __init__(self, channels): + super().__init__() + + self.norm = GroupNormalization(groups=32, epsilon=1e-6) + self.q = layers.Conv2D(channels, kernel_size=1, strides=1, padding="valid") + self.k = layers.Conv2D(channels, kernel_size=1, strides=1, padding="valid") + self.v = layers.Conv2D(channels, kernel_size=1, strides=1, padding="valid") + self.proj_out = layers.Conv2D( + channels, kernel_size=1, strides=1, padding="valid" + ) + + def call(self, x): + h_ = x + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + # compute attention + ( + b, + h, + w, + c, + ) = q.shape + if b is None: + b = -1 + q = tf.reshape(q, [b, h * w, c]) + k = tf.reshape(k, [b, h * w, c]) + w_ = tf.matmul( + q, k, transpose_b=True + ) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] + w_ = w_ * (int(c) ** (-0.5)) + w_ = keras.activations.softmax(w_) + + # attend to values + v = tf.reshape(v, [b, h * w, c]) + # w_ = w_.permute(0, 
2, 1) # b,hw,hw (first hw of k, second of q) + h_ = tf.matmul( + v, w_, transpose_a=True + ) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] + # h_ = h_.reshape(b, c, h, w) + h_ = tf.reshape(h_, [b, h, w, c]) + + h_ = self.proj_out(h_) + + return x + h_ + + +class Downsample(layers.Layer): + def __init__(self, channels, with_conv=True): + super().__init__() + self.with_conv = with_conv + if self.with_conv: + # no asymmetric padding in torch conv, must do it ourselves + self.down_sample = layers.Conv2D( + channels, kernel_size=3, strides=2, padding="same" + ) + else: + self.down_sample = layers.AveragePooling2D(pool_size=2, strides=2) + + def call(self, x): + x = self.down_sample(x) + return x + + +class Upsample(layers.Layer): + def __init__(self, channels, with_conv=False): + super().__init__() + self.with_conv = with_conv + if False: # self.with_conv: + self.up_sample = layers.Conv2DTranspose( + channels, kernel_size=3, strides=2, padding="same" + ) + else: + self.up_sample = Sequential( + [ + layers.UpSampling2D(size=2, interpolation="nearest"), + layers.Conv2D(channels, kernel_size=3, strides=1, padding="same"), + ] + ) + + def call(self, x): + x = self.up_sample(x) + return x diff --git a/ganime/model/vqgan_clean/__init__.py b/ganime/model/vqgan_clean/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/ganime/model/vqgan_clean/diffusion/__init__.py b/ganime/model/vqgan_clean/diffusion/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/ganime/model/vqgan_clean/diffusion/decoder.py b/ganime/model/vqgan_clean/diffusion/decoder.py new file mode 100644 index 0000000000000000000000000000000000000000..a3284f48e0a7a142f708584ceab5e1f42c829b80 --- /dev/null +++ b/ganime/model/vqgan_clean/diffusion/decoder.py @@ -0,0 +1,115 @@ +from typing import List + +import numpy as np +import tensorflow as tf +from tensorflow import keras +from tensorflow.keras import Model, layers +from tensorflow_addons.layers import GroupNormalization + +from .layers import AttentionBlock, ResnetBlock, Upsample + + +# @tf.keras.utils.register_keras_serializable() +class Decoder(layers.Layer): + def __init__( + self, + *, + channels: int, + output_channels: int = 3, + channels_multiplier: List[int], + num_res_blocks: int, + attention_resolution: List[int], + resolution: int, + z_channels: int, + dropout: float, + **kwargs + ): + super().__init__(**kwargs) + + self.channels = channels + self.output_channels = output_channels + self.channels_multiplier = channels_multiplier + self.num_resolutions = len(channels_multiplier) + self.num_res_blocks = num_res_blocks + self.attention_resolution = attention_resolution + self.resolution = resolution + self.z_channels = z_channels + self.dropout = dropout + + block_in = channels * channels_multiplier[-1] + current_resolution = resolution // 2 ** (self.num_resolutions - 1) + self.z_shape = (1, z_channels, current_resolution, current_resolution) + + print( + "Working with z of shape {} = {} dimensions.".format( + self.z_shape, np.prod(self.z_shape) + ) + ) + + self.conv_in = layers.Conv2D(block_in, kernel_size=3, strides=1, padding="same") + + # middle + self.mid = {} + self.mid["block_1"] = ResnetBlock( + in_channels=block_in, + out_channels=block_in, + dropout=dropout, + ) + self.mid["attn_1"] = AttentionBlock(block_in) + self.mid["block_2"] = ResnetBlock( + in_channels=block_in, + out_channels=block_in, + 
dropout=dropout, + ) + + # upsampling + + self.upsampling_list = [] + + for i_level in reversed(range(self.num_resolutions)): + block_out = channels * channels_multiplier[i_level] + for i_block in range(self.num_res_blocks + 1): + self.upsampling_list.append( + ResnetBlock( + in_channels=block_in, + out_channels=block_out, + dropout=dropout, + ) + ) + block_in = block_out + + if current_resolution in attention_resolution: + # attentions.append(layers.Attention()) + self.upsampling_list.append(AttentionBlock(block_in)) + + if i_level != 0: + self.upsampling_list.append(Upsample(block_in)) + current_resolution *= 2 + + # end + self.norm_out = GroupNormalization(groups=32, epsilon=1e-6) + self.conv_out = layers.Conv2D( + output_channels, + kernel_size=3, + strides=1, + activation="tanh", + padding="same", + ) + + def call(self, inputs, training=True, mask=None): + + h = self.conv_in(inputs) + + # middle + h = self.mid["block_1"](h) + h = self.mid["attn_1"](h) + h = self.mid["block_2"](h) + + for upsampling in self.upsampling_list: + h = upsampling(h) + + # end + h = self.norm_out(h) + h = keras.activations.swish(h) + h = self.conv_out(h) + return h diff --git a/ganime/model/vqgan_clean/diffusion/encoder.py b/ganime/model/vqgan_clean/diffusion/encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..2685bc1a13767ecde59f7f3356d819fec6c19f10 --- /dev/null +++ b/ganime/model/vqgan_clean/diffusion/encoder.py @@ -0,0 +1,125 @@ +from typing import List +import tensorflow as tf +from tensorflow import keras +from tensorflow.keras import layers, Model +from tensorflow_addons.layers import GroupNormalization +from .layers import ResnetBlock, AttentionBlock, Downsample + + +# @tf.keras.utils.register_keras_serializable() +class Encoder(layers.Layer): + def __init__( + self, + *, + channels: int, + channels_multiplier: List[int], + num_res_blocks: int, + attention_resolution: List[int], + resolution: int, + z_channels: int, + dropout: float, + **kwargs + ): + """Encode an image into a latent vector. The encoder will be constitued of multiple levels (lenght of `channels_multiplier`) with for each level `num_res_blocks` ResnetBlock. + Args: + channels (int, optional): The number of channel for the first layer. Defaults to 128. + channels_multiplier (List[int], optional): The channel multiplier for each level (previous level channels X multipler). Defaults to [1, 1, 2, 2]. + num_res_blocks (int, optional): Number of ResnetBlock at each level. Defaults to 1. + attention_resolution (List[int], optional): Add an attention block if the current resolution is in this array. Defaults to [16]. + resolution (int, optional): The starting resolution. Defaults to 64. + z_channels (int, optional): The number of channel at the end of the encoder. Defaults to 128. + dropout (float, optional): The dropout ratio for each ResnetBlock. Defaults to 0.0. 
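+        Example (illustrative sketch; the argument values are the defaults listed above):
+            encoder = Encoder(channels=128, channels_multiplier=[1, 1, 2, 2],
+                              num_res_blocks=1, attention_resolution=[16],
+                              resolution=64, z_channels=128, dropout=0.0)
+            latent = encoder(tf.zeros((1, 64, 64, 3)))  # -> (1, 8, 8, 128)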
+ """ + super().__init__(**kwargs) + + self.channels = channels + self.channels_multiplier = channels_multiplier + self.num_resolutions = len(channels_multiplier) + self.num_res_blocks = num_res_blocks + self.attention_resolution = attention_resolution + self.resolution = resolution + self.z_channels = z_channels + self.dropout = dropout + + self.conv_in = layers.Conv2D( + self.channels, kernel_size=3, strides=1, padding="same" + ) + + current_resolution = resolution + + in_channels_multiplier = (1,) + tuple(channels_multiplier) + + self.downsampling_list = [] + + for i_level in range(self.num_resolutions): + block_in = channels * in_channels_multiplier[i_level] + block_out = channels * channels_multiplier[i_level] + for i_block in range(self.num_res_blocks): + self.downsampling_list.append( + ResnetBlock( + in_channels=block_in, + out_channels=block_out, + dropout=dropout, + ) + ) + block_in = block_out + + if current_resolution in attention_resolution: + self.downsampling_list.append(AttentionBlock(block_in)) + + if i_level != self.num_resolutions - 1: + self.downsampling_list.append(Downsample(block_in)) + current_resolution = current_resolution // 2 + + # middle + self.mid = {} + self.mid["block_1"] = ResnetBlock( + in_channels=block_in, + out_channels=block_in, + dropout=dropout, + ) + self.mid["attn_1"] = AttentionBlock(block_in) + self.mid["block_2"] = ResnetBlock( + in_channels=block_in, + out_channels=block_in, + dropout=dropout, + ) + + # end + self.norm_out = GroupNormalization(groups=32, epsilon=1e-6) + self.conv_out = layers.Conv2D( + z_channels, + kernel_size=3, + strides=1, + padding="same", + ) + + # def get_config(self): + # config = super().get_config() + # config.update( + # { + # "channels": self.channels, + # "channels_multiplier": self.channels_multiplier, + # "num_res_blocks": self.num_res_blocks, + # "attention_resolution": self.attention_resolution, + # "resolution": self.resolution, + # "z_channels": self.z_channels, + # "dropout": self.dropout, + # } + # ) + # return config + + def call(self, inputs, training=True, mask=None): + h = self.conv_in(inputs) + for downsampling in self.downsampling_list: + h = downsampling(h) + + h = self.mid["block_1"](h) + h = self.mid["attn_1"](h) + h = self.mid["block_2"](h) + + # end + h = self.norm_out(h) + h = keras.activations.swish(h) + h = self.conv_out(h) + return h \ No newline at end of file diff --git a/ganime/model/vqgan_clean/diffusion/layers.py b/ganime/model/vqgan_clean/diffusion/layers.py new file mode 100644 index 0000000000000000000000000000000000000000..bd8f93c6e07aee4b85156f60658a38b4c6968a5c --- /dev/null +++ b/ganime/model/vqgan_clean/diffusion/layers.py @@ -0,0 +1,179 @@ +import tensorflow as tf +from tensorflow import keras +from tensorflow.keras import layers, Sequential +from tensorflow_addons.layers import GroupNormalization + + +@tf.keras.utils.register_keras_serializable() +class ResnetBlock(layers.Layer): + def __init__( + self, + *, + in_channels, + dropout=0.0, + out_channels=None, + conv_shortcut=False, + **kwargs + ): + super().__init__(**kwargs) + self.in_channels = in_channels + self.dropout_rate = dropout + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + self.use_conv_shortcut = conv_shortcut + + self.norm1 = GroupNormalization(groups=32, epsilon=1e-6) + + self.conv1 = layers.Conv2D( + out_channels, kernel_size=3, strides=1, padding="same" + ) + + self.norm2 = GroupNormalization(groups=32, epsilon=1e-6) + self.dropout = 
layers.Dropout(dropout) + + self.conv2 = layers.Conv2D( + out_channels, kernel_size=3, strides=1, padding="same" + ) + + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + self.conv_shortcut = layers.Conv2D( + out_channels, kernel_size=3, strides=1, padding="same" + ) + else: + self.nin_shortcut = layers.Conv2D( + out_channels, kernel_size=1, strides=1, padding="valid" + ) + + def get_config(self): + config = super().get_config() + config.update( + { + "in_channels": self.in_channels, + "dropout": self.dropout_rate, + "out_channels": self.out_channels, + "conv_shortcut": self.use_conv_shortcut, + } + ) + return config + + def call(self, x): + h = x + h = self.norm1(h) + h = keras.activations.swish(h) + h = self.conv1(h) + + h = self.norm2(h) + h = keras.activations.swish(h) + h = self.dropout(h) + h = self.conv2(h) + + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + x = self.conv_shortcut(x) + else: + x = self.nin_shortcut(x) + + return x + h + + +@tf.keras.utils.register_keras_serializable() +class AttentionBlock(layers.Layer): + def __init__(self, channels, **kwargs): + super().__init__(**kwargs) + self.channels = channels + self.norm = GroupNormalization(groups=32, epsilon=1e-6) + self.q = layers.Conv2D(channels, kernel_size=1, strides=1, padding="valid") + self.k = layers.Conv2D(channels, kernel_size=1, strides=1, padding="valid") + self.v = layers.Conv2D(channels, kernel_size=1, strides=1, padding="valid") + self.proj_out = layers.Conv2D( + channels, kernel_size=1, strides=1, padding="valid" + ) + + self.attention = layers.Attention() + + def get_config(self): + config = super().get_config() + config.update( + { + "channels": self.channels, + } + ) + return config + + def call(self, x): + h_ = x + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + # compute attention + (b, h, w, c,) = ( + tf.shape(q)[0], + tf.shape(q)[1], + tf.shape(q)[2], + tf.shape(q)[3], + ) + + if b is None: + b = -1 + q = tf.reshape(q, [b, h * w, c]) + k = tf.reshape(k, [b, h * w, c]) + v = tf.reshape(v, [b, h * w, c]) + + h_ = self.attention([q, v, k]) + + h_ = tf.reshape(h_, [b, h, w, c]) + + h_ = self.proj_out(h_) + + return x + h_ + + +@tf.keras.utils.register_keras_serializable() +class Downsample(layers.Layer): + def __init__(self, channels, **kwargs): + super().__init__(**kwargs) + self.channels = channels + self.down_sample = self.down_sample = layers.AveragePooling2D( + pool_size=2, strides=2 + ) + self.conv = layers.Conv2D(channels, kernel_size=3, strides=1, padding="same") + + def get_config(self): + config = super().get_config() + config.update( + { + "channels": self.channels, + } + ) + return config + + def call(self, x): + x = self.down_sample(x) + x = self.conv(x) + return x + + +@tf.keras.utils.register_keras_serializable() +class Upsample(layers.Layer): + def __init__(self, channels, **kwargs): + super().__init__(**kwargs) + self.channels = channels + self.up_sample = layers.UpSampling2D(size=2, interpolation="bilinear") + self.conv = layers.Conv2D(channels, kernel_size=3, strides=1, padding="same") + + def get_config(self): + config = super().get_config() + config.update( + { + "channels": self.channels, + } + ) + return config + + def call(self, x): + x = self.up_sample(x) + x = self.conv(x) + return x diff --git a/ganime/model/vqgan_clean/discriminator/__init__.py b/ganime/model/vqgan_clean/discriminator/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 
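The blocks defined in ganime/model/vqgan_clean/diffusion/layers.py are easiest to sanity-check by their shapes. A minimal sketch, assuming the ganime package has been installed with `pip install -e .` so the import path below resolves:

import tensorflow as tf
from ganime.model.vqgan_clean.diffusion.layers import (
    AttentionBlock,
    Downsample,
    ResnetBlock,
    Upsample,
)

x = tf.random.normal((2, 16, 16, 128))                 # (batch, height, width, channels)
x = ResnetBlock(in_channels=128, out_channels=128)(x)  # shape preserved: (2, 16, 16, 128)
x = AttentionBlock(128)(x)                             # self-attention over the 16*16 positions
x = Downsample(128)(x)                                 # halves the resolution: (2, 8, 8, 128)
x = Upsample(128)(x)                                   # back up to (2, 16, 16, 128)

Note that Downsample and Upsample use pooling / UpSampling2D followed by a stride-1 convolution rather than strided or transposed convolutions, a common way to avoid checkerboard artifacts.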
diff --git a/ganime/model/vqgan_clean/discriminator/model.py b/ganime/model/vqgan_clean/discriminator/model.py new file mode 100644 index 0000000000000000000000000000000000000000..986d29dd1c5cbad6e7e7fe6c6e5564a636832557 --- /dev/null +++ b/ganime/model/vqgan_clean/discriminator/model.py @@ -0,0 +1,88 @@ +from typing import List +import numpy as np +import tensorflow as tf +from tensorflow import keras +from tensorflow.keras import Model, Sequential +from tensorflow.keras import layers +from tensorflow.keras.initializers import RandomNormal + + +class NLayerDiscriminator(Model): + """Defines a PatchGAN discriminator as in Pix2Pix + --> see https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py + """ + + def __init__(self, filters: int = 64, n_layers: int = 3, **kwargs): + super().__init__(**kwargs) + + init = RandomNormal(stddev=0.02) + self.filters = filters + self.n_layers = n_layers + + kernel_size = 4 + + inp = tf.keras.layers.Input(shape=[256, 512, 3], name="input_image") + tar = tf.keras.layers.Input(shape=[256, 512, 3], name="target_image") + + x = tf.keras.layers.concatenate([inp, tar]) + + x = layers.Conv2D( + filters, + kernel_size=kernel_size, + strides=2, + # strides=1, + padding="same", + kernel_initializer=init, + )(x) + x = layers.LeakyReLU(alpha=0.2)(x) + + filters_mult = 1 + for n in range(1, n_layers): + filters_mult = min(2**n, 8) + + x = layers.Conv2D( + filters * filters_mult, + kernel_size=kernel_size, + # strides=1, # 2, + strides=2, + padding="same", + use_bias=False, + kernel_initializer=init, + )(x) + x = layers.BatchNormalization()(x) + x = layers.LeakyReLU(alpha=0.2)(x) + + filters_mult = min(2**n_layers, 8) + x = layers.Conv2D( + filters * filters_mult, + kernel_size=kernel_size, + strides=1, + padding="same", + use_bias=False, + kernel_initializer=init, + )(x) + x = layers.BatchNormalization()(x) + x = layers.LeakyReLU(alpha=0.2)(x) + + x = layers.Conv2D( + 1, + kernel_size=kernel_size, + strides=1, + padding="same", + # activation="sigmoid", + kernel_initializer=init, + )(x) + self.model = tf.keras.Model(inputs=[inp, tar], outputs=x) + + def call(self, inputs, training=True, mask=None): + return self.model(inputs) + + def get_config(self): + config = super().get_config() + config.update( + { + "filters": self.filters, + "n_layers": self.n_layers, + } + ) + return config diff --git a/ganime/model/vqgan_clean/discriminator/model_bkp.py b/ganime/model/vqgan_clean/discriminator/model_bkp.py new file mode 100644 index 0000000000000000000000000000000000000000..e0599ef65a406c6c5d0afe250438018e4e8ca124 --- /dev/null +++ b/ganime/model/vqgan_clean/discriminator/model_bkp.py @@ -0,0 +1,76 @@ +from typing import List +import numpy as np +import tensorflow as tf +from tensorflow import keras +from tensorflow.keras import Model, Sequential +from tensorflow.keras import layers + + +class NLayerDiscriminator(Model): + """Defines a PatchGAN discriminator as in Pix2Pix + --> see https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py + """ + + def __init__(self, filters: int = 64, n_layers: int = 3, **kwargs): + super().__init__(**kwargs) + + self.filters = filters + self.n_layers = n_layers + + kernel_size = 4 + self.sequence = [ + layers.Conv2D(filters, kernel_size=kernel_size, strides=1, padding="same"), + layers.LeakyReLU(alpha=0.2), + ] + + filters_mult = 1 + for n in range(1, n_layers): + filters_mult = min(2**n, 8) + + self.sequence += [ + layers.AveragePooling2D(pool_size=2), + layers.Conv2D( + filters * 
filters_mult, + kernel_size=kernel_size, + strides=1, # 2, + # strides=2, + padding="same", + use_bias=False, + ), + layers.BatchNormalization(), + layers.LeakyReLU(alpha=0.2), + ] + + filters_mult = min(2**n_layers, 8) + self.sequence += [ + layers.AveragePooling2D(pool_size=2), + layers.Conv2D( + filters * filters_mult, + kernel_size=kernel_size, + strides=1, + padding="same", + use_bias=False, + ), + layers.BatchNormalization(), + layers.LeakyReLU(alpha=0.2), + ] + + self.sequence += [ + layers.Conv2D(1, kernel_size=kernel_size, strides=1, padding="same") + ] + + def call(self, inputs, training=True, mask=None): + h = inputs + for seq in self.sequence: + h = seq(h) + return h + + def get_config(self): + config = super().get_config() + config.update( + { + "filters": self.filters, + "n_layers": self.n_layers, + } + ) + return config diff --git a/ganime/model/vqgan_clean/experimental/gpt2_embedding.py b/ganime/model/vqgan_clean/experimental/gpt2_embedding.py new file mode 100644 index 0000000000000000000000000000000000000000..65261b6a4605371aa6cb430a99f159d3882f7f99 --- /dev/null +++ b/ganime/model/vqgan_clean/experimental/gpt2_embedding.py @@ -0,0 +1,1127 @@ +# coding=utf-8 +# Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team. +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
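+# This file is a modified copy of Hugging Face's TF GPT-2 implementation
+# (modeling_tf_gpt2.py). The visible change is an extra shared embedding,
+# `wte_remaining_frames`, whose output is added to the token, position and
+# token-type embeddings, presumably so generation can be conditioned on the
+# number of frames left to produce.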
+""" TF 2.0 OpenAI GPT-2 model.""" + +from dataclasses import dataclass +from typing import List, Optional, Tuple, Union + +import numpy as np +import tensorflow as tf +from tensorflow.compiler.tf2xla.python.xla import dynamic_update_slice + + +from transformers.activations_tf import get_tf_activation +from transformers.modeling_tf_outputs import ( + TFBaseModelOutputWithPastAndCrossAttentions, + TFCausalLMOutputWithCrossAttentions, + TFSequenceClassifierOutputWithPast, +) +from transformers.modeling_tf_utils import ( + TFCausalLanguageModelingLoss, + TFConv1D, + TFModelInputType, + TFPreTrainedModel, + TFSequenceClassificationLoss, + TFSequenceSummary, + TFSharedEmbeddings, + get_initializer, + keras_serializable, + unpack_inputs, +) +from transformers.tf_utils import shape_list, stable_softmax +from transformers.utils import ( + DUMMY_INPUTS, + ModelOutput, + add_code_sample_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward, + logging, + replace_return_docstrings, +) +from transformers import GPT2Config + + +logger = logging.get_logger(__name__) + +_CHECKPOINT_FOR_DOC = "gpt2" +_CONFIG_FOR_DOC = "GPT2Config" +_TOKENIZER_FOR_DOC = "GPT2Tokenizer" + +TF_GPT2_PRETRAINED_MODEL_ARCHIVE_LIST = [ + "gpt2", + "gpt2-medium", + "gpt2-large", + "gpt2-xl", + "distilgpt2", + # See all GPT-2 models at https://huggingface.co/models?filter=gpt2 +] + + +class TFAttention(tf.keras.layers.Layer): + def __init__(self, nx, config, scale=False, is_cross_attention=False, **kwargs): + super().__init__(**kwargs) + + n_state = nx # in Attention: n_state=768 (nx=n_embd) + # [switch nx => n_state from Block to Attention to keep identical to TF implementation] + assert n_state % config.n_head == 0 + self.n_head = config.n_head + self.split_size = n_state + self.scale = scale + self.output_attentions = config.output_attentions + + self.is_cross_attention = is_cross_attention + + if self.is_cross_attention: + self.c_attn = TFConv1D( + n_state * 2, + nx, + initializer_range=config.initializer_range, + name="c_attn", + ) + self.q_attn = TFConv1D( + n_state, nx, initializer_range=config.initializer_range, name="q_attn" + ) + else: + self.c_attn = TFConv1D( + n_state * 3, + nx, + initializer_range=config.initializer_range, + name="c_attn", + ) + + self.c_proj = TFConv1D( + n_state, nx, initializer_range=config.initializer_range, name="c_proj" + ) + self.attn_dropout = tf.keras.layers.Dropout(config.attn_pdrop) + self.resid_dropout = tf.keras.layers.Dropout(config.resid_pdrop) + self.pruned_heads = set() + + def prune_heads(self, heads): + pass + + @staticmethod + def causal_attention_mask(nd, ns, dtype): + """ + 1's in the lower triangle, counting from the lower right corner. Same as tf.matrix_band_part(tf.ones([nd, ns]), + -1, ns-nd), but doesn't produce garbage on TPUs. + """ + i = tf.range(nd)[:, None] + j = tf.range(ns) + m = i >= j - ns + nd + return tf.cast(m, dtype) + + def _attn( + self, q, k, v, attention_mask, head_mask, output_attentions, training=False + ): + # q, k, v have shape [batch, heads, sequence, features] + w = tf.matmul(q, k, transpose_b=True) + if self.scale: + dk = tf.cast(shape_list(k)[-1], dtype=w.dtype) # scale attention_scores + w = w / tf.math.sqrt(dk) + + if not self.is_cross_attention: + # if only "normal" attention layer implements causal mask + + # w has shape [batch, heads, dst_sequence, src_sequence], where information flows from src to dst. 
+ _, _, nd, ns = shape_list(w) + b = self.causal_attention_mask(nd, ns, dtype=w.dtype) + b = tf.reshape(b, [1, 1, nd, ns]) + w = w * b - 1e4 * (1 - b) + + if attention_mask is not None: + # Apply the attention mask + attention_mask = tf.cast(attention_mask, dtype=w.dtype) + w = w + attention_mask + + w = stable_softmax(w, axis=-1) + w = self.attn_dropout(w, training=training) + + # Mask heads if we want to + if head_mask is not None: + w = w * head_mask + + outputs = [tf.matmul(w, v)] + if output_attentions: + outputs.append(w) + return outputs + + def merge_heads(self, x): + x = tf.transpose(x, [0, 2, 1, 3]) + x_shape = shape_list(x) + new_x_shape = x_shape[:-2] + [x_shape[-2] * x_shape[-1]] + return tf.reshape(x, new_x_shape) + + def split_heads(self, x): + x_shape = shape_list(x) + new_x_shape = x_shape[:-1] + [self.n_head, x_shape[-1] // self.n_head] + x = tf.reshape(x, new_x_shape) + return tf.transpose(x, (0, 2, 1, 3)) # (batch, head, seq_length, head_features) + + def call( + self, + x, + layer_past, + attention_mask, + head_mask, + encoder_hidden_states, + encoder_attention_mask, + use_cache, + output_attentions, + training=False, + ): + + if encoder_hidden_states is not None: + if not hasattr(self, "q_attn"): + raise ValueError( + "If class is used as cross attention, the weights `q_attn` have to be defined. " + "Please make sure to instantiate class with `GPT2Attention(..., is_cross_attention=True)`." + ) + + query = self.q_attn(x) + kv_out = self.c_attn(encoder_hidden_states) + key, value = tf.split(kv_out, 2, axis=2) + attention_mask = encoder_attention_mask + else: + x = self.c_attn(x) + query, key, value = tf.split(x, 3, axis=2) + + query = self.split_heads(query) + key = self.split_heads(key) + value = self.split_heads(value) + if layer_past is not None: + past_key, past_value = tf.unstack(layer_past, axis=0) + key = tf.concat([past_key, key], axis=-2) + value = tf.concat([past_value, value], axis=-2) + + # to cope with keras serialization + if use_cache: + present = tf.stack([key, value], axis=0) + else: + present = (None,) + + attn_outputs = self._attn( + query, + key, + value, + attention_mask, + head_mask, + output_attentions, + training=training, + ) + a = attn_outputs[0] + + a = self.merge_heads(a) + a = self.c_proj(a) + a = self.resid_dropout(a, training=training) + + outputs = [a, present] + attn_outputs[1:] + return outputs # a, present, (attentions) + + +class TFMLP(tf.keras.layers.Layer): + def __init__(self, n_state, config, **kwargs): + super().__init__(**kwargs) + nx = config.n_embd + self.c_fc = TFConv1D( + n_state, nx, initializer_range=config.initializer_range, name="c_fc" + ) + self.c_proj = TFConv1D( + nx, n_state, initializer_range=config.initializer_range, name="c_proj" + ) + self.act = get_tf_activation(config.activation_function) + self.dropout = tf.keras.layers.Dropout(config.resid_pdrop) + + def call(self, x, training=False): + h = self.act(self.c_fc(x)) + h2 = self.c_proj(h) + h2 = self.dropout(h2, training=training) + return h2 + + +class TFBlock(tf.keras.layers.Layer): + def __init__(self, config, scale=False, **kwargs): + super().__init__(**kwargs) + nx = config.n_embd + inner_dim = config.n_inner if config.n_inner is not None else 4 * nx + self.ln_1 = tf.keras.layers.LayerNormalization( + epsilon=config.layer_norm_epsilon, name="ln_1" + ) + self.attn = TFAttention(nx, config, scale, name="attn") + self.ln_2 = tf.keras.layers.LayerNormalization( + epsilon=config.layer_norm_epsilon, name="ln_2" + ) + + if config.add_cross_attention: + + 
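+            # Cross-attention block: queries are projected from the decoder
+            # stream (`q_attn`), while keys and values come from
+            # `encoder_hidden_states` via `c_attn` (split in two in
+            # TFAttention.call above), and no causal mask is applied there.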
self.crossattention = TFAttention( + nx, config, scale, name="crossattention", is_cross_attention=True + ) + self.ln_cross_attn = tf.keras.layers.LayerNormalization( + epsilon=config.layer_norm_epsilon, name="ln_cross_attn" + ) + + self.mlp = TFMLP(inner_dim, config, name="mlp") + + def call( + self, + x, + layer_past, + attention_mask, + head_mask, + encoder_hidden_states, + encoder_attention_mask, + use_cache, + output_attentions, + training=False, + ): + a = self.ln_1(x) + output_attn = self.attn( + a, + layer_past=layer_past, + attention_mask=attention_mask, + head_mask=head_mask, + encoder_hidden_states=None, + encoder_attention_mask=None, + use_cache=use_cache, + output_attentions=output_attentions, + training=training, + ) + a = output_attn[0] # output_attn: a, present, (attentions) + outputs = output_attn[1:] + x = x + a + + # Cross-Attention Block + if encoder_hidden_states is not None: + # add one self-attention block for cross-attention + if not hasattr(self, "crossattention"): + raise ValueError( + f"If `encoder_hidden_states` are passed, {self} has to be instantiated with " + "cross-attention layers by setting `config.add_cross_attention=True`" + ) + + ca = self.ln_cross_attn(x) + output_cross_attn = self.crossattention( + ca, + layer_past=None, + attention_mask=attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + use_cache=False, + output_attentions=output_attentions, + training=training, + ) + ca = output_cross_attn[0] # output_attn: a, present, (cross_attentions) + x = x + ca + outputs = ( + outputs + output_cross_attn[2:] + ) # add cross attentions if we output attention weights + + m = self.ln_2(x) + m = self.mlp(m, training=training) + x = x + m + + outputs = [x] + outputs + return outputs # x, present, (attentions, cross_attentions) + + +@keras_serializable +class TFGPT2MainLayer(tf.keras.layers.Layer): + config_class = GPT2Config + + def __init__(self, config, *inputs, **kwargs): + super().__init__(*inputs, **kwargs) + + self.config = config + self.output_attentions = config.output_attentions + self.output_hidden_states = config.output_hidden_states + self.use_cache = config.use_cache + self.return_dict = config.use_return_dict + + self.num_hidden_layers = config.n_layer + self.vocab_size = config.vocab_size + self.n_embd = config.n_embd + self.n_positions = config.n_positions + self.initializer_range = config.initializer_range + + self.wte = TFSharedEmbeddings( + config.vocab_size, + config.hidden_size, + initializer_range=config.initializer_range, + name="wte", + ) + + self.wte_remaining_frames = TFSharedEmbeddings( + config.vocab_size, + config.hidden_size, + initializer_range=config.initializer_range, + name="wte_remaining_frames", + ) + self.drop = tf.keras.layers.Dropout(config.embd_pdrop) + self.h = [ + TFBlock(config, scale=True, name=f"h_._{i}") for i in range(config.n_layer) + ] + self.ln_f = tf.keras.layers.LayerNormalization( + epsilon=config.layer_norm_epsilon, name="ln_f" + ) + + def build(self, input_shape): + with tf.name_scope("wpe"): + self.wpe = self.add_weight( + name="embeddings", + shape=[self.n_positions, self.n_embd], + initializer=get_initializer(self.initializer_range), + ) + self.wte_remaining_frames.build(input_shape) + + super().build(input_shape) + + def get_input_embeddings(self): + return self.wte + + def get_remaining_frames_embeddings(self): + return self.wte_remaining_frames + + def set_input_embeddings(self, value): + self.wte.weight = value + 
self.wte.vocab_size = shape_list(value)[0] + + def set_remaining_frames_embeddings(self, value): + self.wte_remaining_frames.weight = value + self.wte_remaining_frames.vocab_size = shape_list(value)[0] + + def _prune_heads(self, heads_to_prune): + """ + Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} + """ + raise NotImplementedError + + @unpack_inputs + def call( + self, + input_ids: Optional[TFModelInputType] = None, + remaining_frames_ids: Optional[Union[np.ndarray, tf.Tensor]] = None, + past: Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]] = None, + attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None, + token_type_ids: Optional[Union[np.ndarray, tf.Tensor]] = None, + position_ids: Optional[Union[np.ndarray, tf.Tensor]] = None, + head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None, + inputs_embeds: Optional[Union[np.ndarray, tf.Tensor]] = None, + encoder_hidden_states: Optional[Union[np.ndarray, tf.Tensor]] = None, + encoder_attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + training: Optional[bool] = False, + ) -> Union[TFBaseModelOutputWithPastAndCrossAttentions, Tuple[tf.Tensor]]: + + if input_ids is not None and inputs_embeds is not None: + raise ValueError( + "You cannot specify both input_ids and inputs_embeds at the same time" + ) + elif input_ids is not None: + input_shape = shape_list(input_ids) + input_ids = tf.reshape(input_ids, [-1, input_shape[-1]]) + elif inputs_embeds is not None: + input_shape = shape_list(inputs_embeds)[:-1] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + if past is None: + past_length = 0 + past = [None] * len(self.h) + else: + past_length = shape_list(past[0][0])[-2] + + if position_ids is None: + position_ids = tf.expand_dims( + tf.range(past_length, input_shape[-1] + past_length), axis=0 + ) + + if attention_mask is not None: + # We create a 3D attention mask from a 2D tensor mask. + # Sizes are [batch_size, 1, 1, to_seq_length] + # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] + # this attention mask is more simple than the triangular masking of causal attention + # used in OpenAI GPT, we just need to prepare the broadcast dimension here. + attention_mask_shape = shape_list(attention_mask) + attention_mask = tf.reshape( + attention_mask, (attention_mask_shape[0], 1, 1, attention_mask_shape[1]) + ) + + # Since attention_mask is 1.0 for positions we want to attend and 0.0 for + # masked positions, this operation will create a tensor which is 0.0 for + # positions we want to attend and -10000.0 for masked positions. + # Since we are adding it to the raw scores before the softmax, this is + # effectively the same as removing these entirely. 
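+            # Worked example (illustrative): a padding mask of [1, 1, 0] becomes
+            # (1 - [1, 1, 0]) * -10000.0 = [0.0, 0.0, -10000.0], which is then simply
+            # added to the attention scores before the softmax.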
+ one_cst = tf.constant(1.0) + attention_mask = tf.cast(attention_mask, dtype=one_cst.dtype) + attention_mask = tf.multiply( + tf.subtract(one_cst, attention_mask), tf.constant(-10000.0) + ) + + # Copied from `modeling_tf_t5.py` with -1e9 -> -10000 + if self.config.add_cross_attention and encoder_attention_mask is not None: + # If a 2D ou 3D attention mask is provided for the cross-attention + # we need to make broadcastable to [batch_size, num_heads, mask_seq_length, mask_seq_length] + # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] + encoder_attention_mask = tf.cast( + encoder_attention_mask, dtype=encoder_hidden_states.dtype + ) + num_dims_encoder_attention_mask = len(shape_list(encoder_attention_mask)) + if num_dims_encoder_attention_mask == 3: + encoder_extended_attention_mask = encoder_attention_mask[:, None, :, :] + if num_dims_encoder_attention_mask == 2: + encoder_extended_attention_mask = encoder_attention_mask[ + :, None, None, : + ] + + # T5 has a mask that can compare sequence ids, we can simulate this here with this transposition + # Cf. https://github.com/tensorflow/mesh/blob/8d2465e9bc93129b913b5ccc6a59aa97abd96ec6/mesh_tensorflow/transformer/transformer_layers.py#L270 + # encoder_extended_attention_mask = tf.math.equal(encoder_extended_attention_mask, + # tf.transpose(encoder_extended_attention_mask, perm=(-1, -2))) + + encoder_extended_attention_mask = ( + 1.0 - encoder_extended_attention_mask + ) * -10000.0 + else: + encoder_extended_attention_mask = None + + encoder_attention_mask = encoder_extended_attention_mask + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] + # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] + if head_mask is not None: + raise NotImplementedError + else: + head_mask = [None] * self.num_hidden_layers + # head_mask = tf.constant([0] * self.num_hidden_layers) + + position_ids = tf.reshape(position_ids, [-1, shape_list(position_ids)[-1]]) + + if inputs_embeds is None: + inputs_embeds = self.wte(input_ids, mode="embedding") + + position_embeds = tf.gather(self.wpe, position_ids) + + if token_type_ids is not None: + token_type_ids = tf.reshape( + token_type_ids, [-1, shape_list(token_type_ids)[-1]] + ) + token_type_embeds = self.wte(token_type_ids, mode="embedding") + else: + token_type_embeds = tf.constant(0.0) + + if remaining_frames_ids is not None: + remaining_frames_ids = tf.reshape( + remaining_frames_ids, [-1, shape_list(remaining_frames_ids)[-1]] + ) + remaining_frames_embeds = self.wte_remaining_frames( + remaining_frames_ids, mode="embedding" + ) + else: + remaining_frames_embeds = tf.constant(0.0) + + position_embeds = tf.cast(position_embeds, dtype=inputs_embeds.dtype) + token_type_embeds = tf.cast(token_type_embeds, dtype=inputs_embeds.dtype) + remaining_frames_embeds = tf.cast( + remaining_frames_embeds, dtype=inputs_embeds.dtype + ) + hidden_states = ( + inputs_embeds + + position_embeds + + token_type_embeds + + remaining_frames_embeds + ) + hidden_states = self.drop(hidden_states, training=training) + + output_shape = input_shape + [shape_list(hidden_states)[-1]] + + presents = () if use_cache else None + all_attentions = () if output_attentions else None + all_cross_attentions = ( + () if output_attentions and self.config.add_cross_attention else None + ) + all_hidden_states = () if 
output_hidden_states else None + for i, (block, layer_past) in enumerate(zip(self.h, past)): + if output_hidden_states: + all_hidden_states = all_hidden_states + ( + tf.reshape(hidden_states, output_shape), + ) + + outputs = block( + hidden_states, + layer_past, + attention_mask, + head_mask[i], + encoder_hidden_states, + encoder_attention_mask, + use_cache, + output_attentions, + training=training, + ) + + hidden_states, present = outputs[:2] + if use_cache: + presents = presents + (present,) + + if output_attentions: + all_attentions = all_attentions + (outputs[2],) + if ( + self.config.add_cross_attention + and encoder_hidden_states is not None + ): + all_cross_attentions = all_cross_attentions + (outputs[3],) + + hidden_states = self.ln_f(hidden_states) + + hidden_states = tf.reshape(hidden_states, output_shape) + # Add last hidden state + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if output_attentions: + # let the number of heads free (-1) so we can extract attention even after head pruning + attention_output_shape = ( + input_shape[:-1] + [-1] + shape_list(all_attentions[0])[-2:] + ) + all_attentions = tuple( + tf.reshape(t, attention_output_shape) for t in all_attentions + ) + + if not return_dict: + return tuple( + v + for v in [ + hidden_states, + presents, + all_hidden_states, + all_attentions, + all_cross_attentions, + ] + if v is not None + ) + + return TFBaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=presents, + hidden_states=all_hidden_states, + attentions=all_attentions, + cross_attentions=all_cross_attentions, + ) + + +class TFGPT2PreTrainedModel(TFPreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = GPT2Config + base_model_prefix = "transformer" + # names with a '.' represents the authorized unexpected/missing layers when a TF model is loaded from a PT model + _keys_to_ignore_on_load_unexpected = [ + r"h.\d+.attn.bias", + r"h.\d+.crossattention.bias", + ] + + @property + def dummy_inputs(self): + """ + Dummy inputs to build the network. + + Returns: + `Dict[str, tf.Tensor]`: The dummy inputs. + """ + dummy = {"input_ids": tf.constant(DUMMY_INPUTS)} + # Add `encoder_hidden_states` to make the cross-attention layers' weights initialized + if self.config.add_cross_attention: + batch_size, seq_len = tf.constant(DUMMY_INPUTS).shape + shape = (batch_size, seq_len) + (self.config.hidden_size,) + h = tf.random.uniform(shape=shape) + dummy["encoder_hidden_states"] = h + + return dummy + + @tf.function( + input_signature=[ + { + "input_ids": tf.TensorSpec((None, None), tf.int32, name="input_ids"), + "attention_mask": tf.TensorSpec( + (None, None), tf.int32, name="attention_mask" + ), + } + ] + ) + def serving(self, inputs): + output = self.call(inputs) + + return self.serving_output(output) + + +GPT2_START_DOCSTRING = r""" + + This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a [tf.keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) subclass. Use it + as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matter related to general usage and + behavior. 
+ + + + TF 2.0 models accepts two formats as inputs: + + - having all inputs as keyword arguments (like PyTorch models), or + - having all inputs as a list, tuple or dict in the first positional arguments. + + This second option is useful when using [`tf.keras.Model.fit`] method which currently requires having all the + tensors in the first argument of the model call function: `model(inputs)`. + + If you choose this second option, there are three possibilities you can use to gather all the input Tensors in the + first positional argument : + + - a single Tensor with `input_ids` only and nothing else: `model(inputs_ids)` + - a list of varying length with one or several input Tensors IN THE ORDER given in the docstring: + `model([input_ids, attention_mask])` or `model([input_ids, attention_mask, token_type_ids])` + - a dictionary with one or several input Tensors associated to the input names given in the docstring: + `model({"input_ids": input_ids, "token_type_ids": token_type_ids})` + + + + Parameters: + config ([`GPT2Config`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + +GPT2_INPUTS_DOCSTRING = r""" + Args: + input_ids (`Numpy array` or `tf.Tensor` of shape `(batch_size, input_ids_length)`): + `input_ids_length` = `sequence_length` if `past` is `None` else `past[0].shape[-2]` (`sequence_length` of + input past key value states). Indices of input sequence tokens in the vocabulary. + + If `past` is used, only input IDs that do not have their past calculated should be passed as `input_ids`. + + Indices can be obtained using [`GPT2Tokenizer`]. See [`PreTrainedTokenizer.__call__`] and + [`PreTrainedTokenizer.encode`] for details. + + [What are input IDs?](../glossary#input-ids) + past (`List[tf.Tensor]` of length `config.n_layers`): + Contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model (see + `past` output below). Can be used to speed up sequential decoding. The token ids which have their past + given to this model should not be passed as input ids as they have already been computed. + attention_mask (`tf.Tensor` or `Numpy array` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + If `past_key_values` is used, `attention_mask` needs to contain the masking strategy that was used for + `past_key_values`. In other words, the `attention_mask` always has to have the length: + `len(past_key_values) + len(input_ids)` + + [What are attention masks?](../glossary#attention-mask) + token_type_ids (`tf.Tensor` or `Numpy array` of shape `(batch_size, sequence_length)`, *optional*): + Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0, + 1]`: + + - 0 corresponds to a *sentence A* token, + - 1 corresponds to a *sentence B* token. + + [What are token type IDs?](../glossary#token-type-ids) + position_ids (`tf.Tensor` or `Numpy array` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.max_position_embeddings - 1]`. 
+ + [What are position IDs?](../glossary#position-ids) + head_mask (`Numpy array` or `tf.Tensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): + Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + inputs_embeds (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. This argument can be used only in eager mode, in graph mode the value in the + config will be used instead. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. This argument can be used only in eager mode, in graph mode the value in the config will be + used instead. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. This argument can be used in + eager mode, in graph mode the value will always be set to True. + training (`bool`, *optional*, defaults to `False`): + Whether or not to use the model in training mode (some modules like dropout modules have different + behaviors between training and evaluation). +""" + + +@add_start_docstrings( + "The bare GPT2 Model transformer outputting raw hidden-states without any specific head on top.", + GPT2_START_DOCSTRING, +) +class TFGPT2Model(TFGPT2PreTrainedModel): + def __init__(self, config, *inputs, **kwargs): + super().__init__(config, *inputs, **kwargs) + self.transformer = TFGPT2MainLayer(config, name="transformer") + + @unpack_inputs + @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + processor_class=_TOKENIZER_FOR_DOC, + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=TFBaseModelOutputWithPastAndCrossAttentions, + config_class=_CONFIG_FOR_DOC, + ) + def call( + self, + input_ids: Optional[TFModelInputType] = None, + remaining_frames_ids: Optional[Union[np.ndarray, tf.Tensor]] = None, + past: Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]] = None, + attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None, + token_type_ids: Optional[Union[np.ndarray, tf.Tensor]] = None, + position_ids: Optional[Union[np.ndarray, tf.Tensor]] = None, + head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None, + inputs_embeds: Optional[Union[np.ndarray, tf.Tensor]] = None, + encoder_hidden_states: Optional[Union[np.ndarray, tf.Tensor]] = None, + encoder_attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + training: Optional[bool] = False, + ) -> Union[TFBaseModelOutputWithPastAndCrossAttentions, Tuple[tf.Tensor]]: + r""" + encoder_hidden_states (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if + the model is configured as a decoder. 
+ encoder_attention_mask (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in + the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + past (`Tuple[Tuple[tf.Tensor]]` of length `config.n_layers`) + contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. + If `past` are used, the user can optionally input only the last `decoder_input_ids` (those that don't have + their past key value states given to this model) of shape `(batch_size, 1)` instead of all + `decoder_input_ids` of shape `(batch_size, sequence_length)`. + use_cache (`bool`, *optional*, defaults to `True`): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past`). Set to `False` during training, `True` during generation + """ + + outputs = self.transformer( + input_ids=input_ids, + remaining_frames_ids=remaining_frames_ids, + past=past, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + ) + + return outputs + + def serving_output(self, output): + pkv = ( + tf.convert_to_tensor(output.past_key_values) + if self.config.use_cache + else None + ) + hs = ( + tf.convert_to_tensor(output.hidden_states) + if self.config.output_hidden_states + else None + ) + attns = ( + tf.convert_to_tensor(output.attentions) + if self.config.output_attentions + else None + ) + cross_attns = ( + tf.convert_to_tensor(output.cross_attentions) + if self.config.output_attentions + and self.config.add_cross_attention + and output.cross_attentions is not None + else None + ) + + return TFBaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=output.last_hidden_state, + past_key_values=pkv, + hidden_states=hs, + attentions=attns, + cross_attentions=cross_attns, + ) + + +@add_start_docstrings( + """ + The GPT2 Model transformer with a language modeling head on top (linear layer with weights tied to the input + embeddings). + """, + GPT2_START_DOCSTRING, +) +class TFGPT2LMHeadModel(TFGPT2PreTrainedModel, TFCausalLanguageModelingLoss): + def __init__(self, config, *inputs, **kwargs): + super().__init__(config, *inputs, **kwargs) + self.transformer = TFGPT2MainLayer(config, name="transformer") + + def get_output_embeddings(self): + return self.get_input_embeddings() + + def set_output_embeddings(self, value): + self.set_input_embeddings(value) + + def prepare_inputs_for_generation( + self, inputs, past=None, use_cache=None, use_xla=False, **kwargs + ): + # TODO: (Joao) after the TF generator is complete, update GPT2 TF generation to match PT's. NB -- some GPT2 + # tests will need to be fixed after the change + + # only last token for inputs_ids if past is defined in kwargs + if past: + inputs = tf.expand_dims(inputs[:, -1], -1) + + # TODO(pvp, Joao) - this `if use_xla` statement can be removed, but is left + # for a future PR to not change too many things for now. 
+ # All statements in this if case apply for both xla and non-xla (as they already do in PyTorch) + position_ids = None + attention_mask = None + if use_xla: + attention_mask = kwargs.get("attention_mask", None) + if past is not None and attention_mask is not None: + position_ids = tf.reduce_sum(attention_mask, axis=1, keepdims=True) - 1 + elif attention_mask is not None: + position_ids = tf.math.cumsum(attention_mask, axis=1, exclusive=True) + + return { + "input_ids": inputs, + "attention_mask": attention_mask, + "position_ids": position_ids, + "past": past, + "use_cache": use_cache, + } + + def _update_model_kwargs_for_xla_generation( + self, outputs, model_kwargs, current_pos, max_length + ): + # TODO(Pvp, Joao, Matt) - this function can be cleaned a bit and refactored + # quite some duplicated code patterns it seems + # also the `attention_mask` is currently used in a somewhat hacky to + # correctly influence the `past_key_values` - not sure if this is the way to go + # Let's keep that for a future PR. + past = outputs.past_key_values + is_past_initialized = model_kwargs.pop("past", None) is not None + attention_mask = model_kwargs.pop("attention_mask") + batch_size = attention_mask.shape[0] + + if not is_past_initialized: + # past[0].shape[3] is seq_length of prompt + num_padding_values = max_length - past[0].shape[3] - 1 + + padding_values = np.zeros((5, 2), dtype=np.int32) + padding_values[3, 1] = num_padding_values + padding_values = tf.constant(padding_values) + + new_past = list(past) + for i in range(len(past)): + new_past[i] = tf.pad(past[i], padding_values) + + # Zeros for the currently-unfilled locations in the past tensor, ones for the actual input_ids + attention_mask = tf.concat( + [ + attention_mask, + tf.zeros( + (batch_size, num_padding_values), dtype=attention_mask.dtype + ), + tf.ones((batch_size, 1), dtype=attention_mask.dtype), + ], + axis=1, + ) + else: + new_past = [None for _ in range(len(past))] + slice_start_base = tf.constant([0, 0, 0, 1, 0]) + attention_mask_update_slice = tf.ones( + (batch_size, 1), dtype=attention_mask.dtype + ) + # correct 5 here + new_past_index = current_pos - 1 + + for i in range(len(past)): + update_slice = past[i][:, :, :, -1:] + # Write the last slice to the first open location in the padded past array + # and then truncate the last slice off the array + new_past[i] = dynamic_update_slice( + past[i][:, :, :, :-1], + update_slice, + slice_start_base * new_past_index, + ) + + update_start = tf.constant([0, 1], dtype=tf.int32) * new_past_index + attention_mask = dynamic_update_slice( + attention_mask, attention_mask_update_slice, update_start + ) + + # set `attention_mask` and `past` + model_kwargs["attention_mask"] = attention_mask + model_kwargs["past"] = tuple(new_past) + + return model_kwargs + + @unpack_inputs + @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + processor_class=_TOKENIZER_FOR_DOC, + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=TFCausalLMOutputWithCrossAttentions, + config_class=_CONFIG_FOR_DOC, + ) + def call( + self, + input_ids: Optional[TFModelInputType] = None, + remaining_frames_ids: Optional[Union[np.ndarray, tf.Tensor]] = None, + past: Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]] = None, + attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None, + token_type_ids: Optional[Union[np.ndarray, tf.Tensor]] = None, + position_ids: Optional[Union[np.ndarray, tf.Tensor]] = None, + head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None, + inputs_embeds: 
Optional[Union[np.ndarray, tf.Tensor]] = None, + encoder_hidden_states: Optional[Union[np.ndarray, tf.Tensor]] = None, + encoder_attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + labels: Optional[Union[np.ndarray, tf.Tensor]] = None, + training: Optional[bool] = False, + ) -> Union[TFCausalLMOutputWithCrossAttentions, Tuple[tf.Tensor]]: + r""" + encoder_hidden_states (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if + the model is configured as a decoder. + encoder_attention_mask (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in + the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + past (`Tuple[Tuple[tf.Tensor]]` of length `config.n_layers`) + contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. + If `past` are used, the user can optionally input only the last `decoder_input_ids` (those that don't have + their past key value states given to this model) of shape `(batch_size, 1)` instead of all + `decoder_input_ids` of shape `(batch_size, sequence_length)`. + use_cache (`bool`, *optional*, defaults to `True`): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past`). Set to `False` during training, `True` during generation + labels (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the cross entropy classification loss. Indices should be in `[0, ..., + config.vocab_size - 1]`. 
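+
+        Example (a minimal sketch of driving this modified head with VQGAN codebook indices;
+        the shapes and values here are illustrative and the `remaining_frames_ids` input is
+        specific to this fork):
+
+        ```python
+        import tensorflow as tf
+
+        from ganime.model.vqgan_clean.experimental.gpt2_embedding import TFGPT2LMHeadModel
+
+        # `wte_remaining_frames` has no pretrained counterpart and is freshly initialised
+        model = TFGPT2LMHeadModel.from_pretrained("gpt2")
+
+        # one sequence of 32 codebook indices plus a "frames still to generate" counter
+        input_ids = tf.random.uniform((1, 32), maxval=1024, dtype=tf.int32)
+        remaining_frames_ids = tf.fill((1, 32), 7)
+
+        outputs = model(input_ids, remaining_frames_ids=remaining_frames_ids, labels=input_ids)
+        loss, logits = outputs.loss, outputs.logits
+        ```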
+ """ + + transformer_outputs = self.transformer( + input_ids=input_ids, + remaining_frames_ids=remaining_frames_ids, + past=past, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + ) + hidden_states = transformer_outputs[0] + logits = self.transformer.wte(hidden_states, mode="linear") + + loss = None + if labels is not None: + # shift labels to the left and cut last logit token + shifted_logits = logits[:, :-1] + labels = labels[:, 1:] + loss = self.hf_compute_loss(labels, shifted_logits) + + if not return_dict: + output = (logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return TFCausalLMOutputWithCrossAttentions( + loss=loss, + logits=logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + cross_attentions=transformer_outputs.cross_attentions, + ) + + def serving_output(self, output): + pkv = ( + tf.convert_to_tensor(output.past_key_values) + if self.config.use_cache + else None + ) + hs = ( + tf.convert_to_tensor(output.hidden_states) + if self.config.output_hidden_states + else None + ) + attns = ( + tf.convert_to_tensor(output.attentions) + if self.config.output_attentions + else None + ) + cross_attns = ( + tf.convert_to_tensor(output.cross_attentions) + if self.config.output_attentions + and self.config.add_cross_attention + and output.cross_attentions is not None + else None + ) + + return TFCausalLMOutputWithCrossAttentions( + logits=output.logits, + past_key_values=pkv, + hidden_states=hs, + attentions=attns, + cross_attentions=cross_attns, + ) diff --git a/ganime/model/vqgan_clean/experimental/net2net.py b/ganime/model/vqgan_clean/experimental/net2net.py new file mode 100644 index 0000000000000000000000000000000000000000..ea2167583bb00d482e0e2bc3ffa07313b6acb204 --- /dev/null +++ b/ganime/model/vqgan_clean/experimental/net2net.py @@ -0,0 +1,289 @@ +import numpy as np +import tensorflow as tf +import tensorflow_addons as tfa +from ganime.configs.model_configs import GPTConfig, ModelConfig +from ganime.model.vqgan_clean.vqgan import VQGAN +from ganime.trainer.warmup.cosine import WarmUpCosine +from tensorflow import keras +from tensorflow.keras import Model, layers +from transformers import TFGPT2Model, GPT2Config +from tensorflow.keras import mixed_precision + + +class Net2Net(Model): + def __init__( + self, + transformer_config: GPTConfig, + first_stage_config: ModelConfig, + trainer_config, + **kwargs, + ): + super().__init__(**kwargs) + self.first_stage_model = VQGAN(**first_stage_config) + + # configuration = GPT2Config(**transformer_config) + # self.transformer = TFGPT2Model(configuration)#.from_pretrained("gpt2", **self.transformer_config) + # configuration = GPT2Config(**transformer_config) + self.transformer = TFGPT2Model.from_pretrained( + "gpt2-medium" + ) # , **transformer_config) + if "checkpoint_path" in transformer_config: + print(f"Restoring weights from {transformer_config['checkpoint_path']}") + self.load_weights(transformer_config["checkpoint_path"]) + + self.loss_fn = tf.keras.losses.SparseCategoricalCrossentropy( + from_logits=True, 
reduction=tf.keras.losses.Reduction.NONE + ) + + self.loss_tracker = keras.metrics.Mean(name="loss") + + self.scheduled_lrs = self.create_warmup_scheduler(trainer_config) + + optimizer = tfa.optimizers.AdamW( + learning_rate=self.scheduled_lrs, weight_decay=1e-4 + ) + self.compile( + optimizer=optimizer, + loss=self.loss_fn, + # run_eagerly=True, + ) + + # Gradient accumulation + self.gradient_accumulation = [ + tf.Variable(tf.zeros_like(v, dtype=tf.float32), trainable=False) + for v in self.transformer.trainable_variables + ] + + def create_warmup_scheduler(self, trainer_config): + len_x_train = trainer_config["len_x_train"] + batch_size = trainer_config["batch_size"] + n_epochs = trainer_config["n_epochs"] + + total_steps = int(len_x_train / batch_size * n_epochs) + warmup_epoch_percentage = trainer_config["warmup_epoch_percentage"] + warmup_steps = int(total_steps * warmup_epoch_percentage) + + scheduled_lrs = WarmUpCosine( + lr_start=trainer_config["lr_start"], + lr_max=trainer_config["lr_max"], + warmup_steps=warmup_steps, + total_steps=total_steps, + ) + + return scheduled_lrs + + def apply_accu_gradients(self): + # apply accumulated gradients + self.optimizer.apply_gradients( + zip(self.gradient_accumulation, self.transformer.trainable_variables) + ) + + # reset + for i in range(len(self.gradient_accumulation)): + self.gradient_accumulation[i].assign( + tf.zeros_like(self.transformer.trainable_variables[i], dtype=tf.float32) + ) + + @property + def metrics(self): + # We list our `Metric` objects here so that `reset_states()` can be + # called automatically at the start of each epoch + # or at the start of `evaluate()`. + # If you don't implement this property, you have to call + # `reset_states()` yourself at the time of your choosing. + return [ + self.loss_tracker, + ] + + @tf.function( + # input_signature=[ + # tf.TensorSpec(shape=[None, None, None, 3], dtype=tf.float32), + # ] + ) + def encode_to_z(self, x): + quant_z, indices, quantized_loss = self.first_stage_model.encode(x) + + batch_size = tf.shape(quant_z)[0] + + indices = tf.reshape(indices, shape=(batch_size, -1)) + return quant_z, indices + + def call(self, inputs, training=None, mask=None): + + first_frame = inputs["first_frame"] + last_frame = inputs["last_frame"] + n_frames = inputs["n_frames"] + + return self.generate_video(first_frame, last_frame, n_frames) + + + + @tf.function() + def predict_next_indices(self, inputs, example_indices): + logits = self.transformer(inputs) + logits = logits.last_hidden_state + logits = tf.cast(logits, dtype=tf.float32) + # Remove the conditioned part + logits = logits[ + :, tf.shape(example_indices)[1] - 1 : + ] # -1 here 'cause -1 above + # logits = tf.reshape(logits, shape=(-1, tf.shape(logits)[-1])) + return logits + + @tf.function() + def body(self, total_loss, frames, index, last_frame_indices): + + previous_frame_indices = self.encode_to_z(frames[:, index - 1, ...])[1] + cz_indices = tf.concat((last_frame_indices, previous_frame_indices), axis=1) + target_indices = self.encode_to_z(frames[:, index, ...])[1] + # target_indices = tf.reshape(target_indices, shape=(-1,)) + + with tf.GradientTape() as tape: + logits = self.predict_next_indices( + cz_indices[:, :-1], last_frame_indices + ) # don't know why -1 + + frame_loss = tf.cast( + tf.reduce_mean(self.loss_fn(target_indices, logits)), + dtype=tf.float32, + ) + + # Calculate batch gradients + gradients = tape.gradient(frame_loss, self.transformer.trainable_variables) + + # Accumulate batch gradients + for i in 
range(len(self.gradient_accumulation)): + self.gradient_accumulation[i].assign_add(tf.cast(gradients[i], tf.float32)) + + index = tf.add(index, 1) + total_loss = tf.add(total_loss, frame_loss) + return total_loss, frames, index, last_frame_indices + + def cond(self, total_loss, frames, index, last_frame_indices): + return tf.less(index, tf.shape(frames)[1]) + + def train_step(self, data): + first_frame = data["first_frame"] + last_frame = data["last_frame"] + frames = data["y"] + n_frames = data["n_frames"] + + last_frame_indices = self.encode_to_z(last_frame)[1] + total_loss = 0.0 + + total_loss, _, _, _ = tf.while_loop( + cond=self.cond, + body=self.body, + loop_vars=(tf.constant(0.0), frames, tf.constant(1), last_frame_indices), + ) + + self.apply_accu_gradients() + self.loss_tracker.update_state(total_loss) + return {m.name: m.result() for m in self.metrics} + + def cond_test_step(self, total_loss, frames, index, last_frame_indices): + return tf.less(index, tf.shape(frames)[1]) + + @tf.function() + def body_test_step(self, total_loss, frames, index, predicted_logits): + target_indices = self.encode_to_z(frames[:, index, ...])[1] + # target_indices = tf.reshape(target_indices, shape=(-1,)) + logits = predicted_logits[index] + + frame_loss = tf.cast( + tf.reduce_mean(self.loss_fn(target_indices, logits)), + dtype=tf.float32, + ) + + index = tf.add(index, 1) + total_loss = tf.add(total_loss, frame_loss) + return total_loss, frames, index, predicted_logits + + def test_step(self, data): + first_frame = data["first_frame"] + last_frame = data["last_frame"] + frames = data["y"] + n_frames = data["n_frames"] + + predicted_logits, _, _ = self.predict_logits(first_frame, last_frame, n_frames) + + total_loss, _, _, _ = tf.while_loop( + cond=self.cond_test_step, + body=self.body_test_step, + loop_vars=(tf.constant(0.0), frames, tf.constant(1), predicted_logits), + ) + + + self.loss_tracker.update_state(total_loss) + return {m.name: m.result() for m in self.metrics} + + @tf.function() + def convert_logits_to_indices(self, logits, shape): + probs = tf.keras.activations.softmax(logits) + _, generated_indices = tf.math.top_k(probs) + generated_indices = tf.reshape( + generated_indices, + shape, # , self.first_stage_model.quantize.num_embeddings) + ) + return generated_indices + # quant = self.first_stage_model.quantize.get_codebook_entry( + # generated_indices, shape=shape + # ) + + # return self.first_stage_model.decode(quant) + + @tf.function() + def predict_logits(self, first_frame, last_frame, n_frames): + quant_first, indices_first = self.encode_to_z(first_frame) + quant_last, indices_last = self.encode_to_z(last_frame) + + indices_previous = indices_first + + predicted_logits = tf.TensorArray( + tf.float32, size=0, dynamic_size=True, clear_after_read=False + ) + + index = tf.constant(1) + while tf.less(index, tf.reduce_max(n_frames)): + tf.autograph.experimental.set_loop_options( + shape_invariants=[(indices_previous, tf.TensorShape([None, None]))] + ) + cz_indices = tf.concat((indices_last, indices_previous), axis=1) + logits = self.predict_next_indices(cz_indices[:, :-1], indices_last) + + # generated_indices = self.convert_logits_to_indices( + # logits, tf.shape(indices_last) + # ) + predicted_logits = predicted_logits.write(index, logits) + indices_previous = self.convert_logits_to_indices( + logits, tf.shape(indices_first) + ) + index = tf.add(index, 1) + + return predicted_logits.stack(), tf.shape(quant_first), tf.shape(indices_first) + + @tf.function() + def generate_video(self, 
first_frame, last_frame, n_frames): + predicted_logits, quant_shape, indices_shape = self.predict_logits( + first_frame, last_frame, n_frames + ) + + generated_images = tf.TensorArray( + tf.float32, size=0, dynamic_size=True, clear_after_read=False + ) + generated_images = generated_images.write(0, first_frame) + + index = tf.constant(1) + while tf.less(index, tf.reduce_max(n_frames)): + indices = self.convert_logits_to_indices(predicted_logits[index], indices_shape) + quant = self.first_stage_model.quantize.get_codebook_entry( + indices, + shape=quant_shape, + ) + decoded = self.first_stage_model.decode(quant) + generated_images = generated_images.write(index, decoded) + index = tf.add(index, 1) + + stacked_images = generated_images.stack() + videos = tf.transpose(stacked_images, (1, 0, 2, 3, 4)) + return videos diff --git a/ganime/model/vqgan_clean/experimental/net2net_v2.py b/ganime/model/vqgan_clean/experimental/net2net_v2.py new file mode 100644 index 0000000000000000000000000000000000000000..a69a24c0f46ab7e256790b71dcf14d650408d892 --- /dev/null +++ b/ganime/model/vqgan_clean/experimental/net2net_v2.py @@ -0,0 +1,253 @@ +import numpy as np +import tensorflow as tf +import tensorflow_addons as tfa +from ganime.configs.model_configs import GPTConfig, ModelConfig +from ganime.model.vqgan_clean.experimental.transformer import Transformer +from ganime.model.vqgan_clean.vqgan import VQGAN +from ganime.trainer.warmup.cosine import WarmUpCosine +from tensorflow import keras +from tensorflow.keras import Model, layers +from ganime.model.vqgan_clean.losses.losses import Losses +from ganime.trainer.warmup.base import create_warmup_scheduler + + +class Net2Net(Model): + def __init__( + self, + transformer_config: GPTConfig, + first_stage_config: ModelConfig, + trainer_config, + num_replicas: int = 1, + **kwargs, + ): + super().__init__(**kwargs) + self.first_stage_model = VQGAN(**first_stage_config) + self.transformer = Transformer(transformer_config) + + if "checkpoint_path" in transformer_config: + print(f"Restoring weights from {transformer_config['checkpoint_path']}") + self.load_weights(transformer_config["checkpoint_path"]) + + losses = Losses(num_replicas=num_replicas) + self.scce_loss = losses.scce_loss + + self.scheduled_lrs = create_warmup_scheduler( + trainer_config, num_devices=num_replicas + ) + + optimizer = tfa.optimizers.AdamW( + learning_rate=self.scheduled_lrs, weight_decay=1e-4 + ) + self.compile( + optimizer=optimizer, + loss=self.loss_fn, + # run_eagerly=True, + ) + + # Gradient accumulation + self.gradient_accumulation = [ + tf.Variable(tf.zeros_like(v, dtype=tf.float32), trainable=False) + for v in self.transformer.trainable_variables + ] + + self.loss_tracker = keras.metrics.Mean(name="loss") + + def loss_fn(self, logits_true, logits_pred): + frame_loss = self.scce_loss(logits_true, logits_pred) + return frame_loss + + def apply_accu_gradients(self): + # apply accumulated gradients + self.optimizer.apply_gradients( + zip(self.gradient_accumulation, self.transformer.trainable_variables) + ) + + # reset + for i in range(len(self.gradient_accumulation)): + self.gradient_accumulation[i].assign( + tf.zeros_like(self.transformer.trainable_variables[i], dtype=tf.float32) + ) + + @property + def metrics(self): + # We list our `Metric` objects here so that `reset_states()` can be + # called automatically at the start of each epoch + # or at the start of `evaluate()`. 
+ # If you don't implement this property, you have to call + # `reset_states()` yourself at the time of your choosing. + return [ + self.loss_tracker, + ] + + @tf.function() + def encode_to_z(self, x): + quant_z, indices, quantized_loss = self.first_stage_model.encode(x) + + batch_size = tf.shape(quant_z)[0] + + indices = tf.reshape(indices, shape=(batch_size, -1)) + return quant_z, indices + + def call(self, inputs, training=None, mask=None): + + first_frame = inputs["first_frame"] + last_frame = inputs["last_frame"] + n_frames = inputs["n_frames"] + + return self.generate_video(first_frame, last_frame, n_frames) + + @tf.function() + def predict_next_indices(self, inputs): + logits = self.transformer(inputs) + return logits + + @tf.function() + def body(self, total_loss, frames, index, last_frame_indices): + + previous_frame_indices = self.encode_to_z(frames[:, index - 1, ...])[1] + target_indices = self.encode_to_z(frames[:, index, ...])[1] + # target_indices = tf.reshape(target_indices, shape=(-1,)) + + with tf.GradientTape() as tape: + logits = self.predict_next_indices( + (last_frame_indices, previous_frame_indices) + ) + + frame_loss = self.loss_fn(target_indices, logits) + + # Calculate batch gradients + gradients = tape.gradient(frame_loss, self.transformer.trainable_variables) + + # Accumulate batch gradients + for i in range(len(self.gradient_accumulation)): + self.gradient_accumulation[i].assign_add(tf.cast(gradients[i], tf.float32)) + + index = tf.add(index, 1) + total_loss = tf.add(total_loss, frame_loss) + return total_loss, frames, index, last_frame_indices + + def cond(self, total_loss, frames, index, last_frame_indices): + return tf.less(index, tf.shape(frames)[1]) + + def train_step(self, data): + first_frame = data["first_frame"] + last_frame = data["last_frame"] + frames = data["y"] + n_frames = data["n_frames"] + + last_frame_indices = self.encode_to_z(last_frame)[1] + total_loss = 0.0 + + total_loss, _, _, _ = tf.while_loop( + cond=self.cond, + body=self.body, + loop_vars=(tf.constant(0.0), frames, tf.constant(1), last_frame_indices), + ) + + self.apply_accu_gradients() + self.loss_tracker.update_state(total_loss) + return {m.name: m.result() for m in self.metrics} + + def cond_test_step(self, total_loss, frames, index, last_frame_indices): + return tf.less(index, tf.shape(frames)[1]) + + @tf.function() + def body_test_step(self, total_loss, frames, index, predicted_logits): + target_indices = self.encode_to_z(frames[:, index, ...])[1] + # target_indices = tf.reshape(target_indices, shape=(-1,)) + logits = predicted_logits[index] + + frame_loss = self.loss_fn(target_indices, logits) + + index = tf.add(index, 1) + total_loss = tf.add(total_loss, frame_loss) + return total_loss, frames, index, predicted_logits + + def test_step(self, data): + first_frame = data["first_frame"] + last_frame = data["last_frame"] + frames = data["y"] + n_frames = data["n_frames"] + + predicted_logits, _, _ = self.predict_logits(first_frame, last_frame, n_frames) + + total_loss, _, _, _ = tf.while_loop( + cond=self.cond_test_step, + body=self.body_test_step, + loop_vars=(tf.constant(0.0), frames, tf.constant(1), predicted_logits), + ) + + self.loss_tracker.update_state(total_loss) + return {m.name: m.result() for m in self.metrics} + + @tf.function() + def convert_logits_to_indices(self, logits, shape): + probs = tf.keras.activations.softmax(logits) + _, generated_indices = tf.math.top_k(probs) + generated_indices = tf.reshape( + generated_indices, + shape, # , 
self.first_stage_model.quantize.num_embeddings) + ) + return generated_indices + # quant = self.first_stage_model.quantize.get_codebook_entry( + # generated_indices, shape=shape + # ) + + # return self.first_stage_model.decode(quant) + + @tf.function() + def predict_logits(self, first_frame, last_frame, n_frames): + quant_first, indices_first = self.encode_to_z(first_frame) + quant_last, indices_last = self.encode_to_z(last_frame) + + indices_previous = indices_first + + predicted_logits = tf.TensorArray( + tf.float32, size=0, dynamic_size=True, clear_after_read=False + ) + + index = tf.constant(1) + while tf.less(index, tf.reduce_max(n_frames)): + tf.autograph.experimental.set_loop_options( + shape_invariants=[(indices_previous, tf.TensorShape([None, None]))] + ) + logits = self.predict_next_indices((indices_last, indices_previous)) + + # generated_indices = self.convert_logits_to_indices( + # logits, tf.shape(indices_last) + # ) + predicted_logits = predicted_logits.write(index, logits) + indices_previous = self.convert_logits_to_indices( + logits, tf.shape(indices_first) + ) + index = tf.add(index, 1) + + return predicted_logits.stack(), tf.shape(quant_first), tf.shape(indices_first) + + @tf.function() + def generate_video(self, first_frame, last_frame, n_frames): + predicted_logits, quant_shape, indices_shape = self.predict_logits( + first_frame, last_frame, n_frames + ) + + generated_images = tf.TensorArray( + tf.float32, size=0, dynamic_size=True, clear_after_read=False + ) + generated_images = generated_images.write(0, first_frame) + + index = tf.constant(1) + while tf.less(index, tf.reduce_max(n_frames)): + indices = self.convert_logits_to_indices( + predicted_logits[index], indices_shape + ) + quant = self.first_stage_model.quantize.get_codebook_entry( + indices, + shape=quant_shape, + ) + decoded = self.first_stage_model.decode(quant) + generated_images = generated_images.write(index, decoded) + index = tf.add(index, 1) + + stacked_images = generated_images.stack() + videos = tf.transpose(stacked_images, (1, 0, 2, 3, 4)) + return videos diff --git a/ganime/model/vqgan_clean/experimental/net2net_v3.py b/ganime/model/vqgan_clean/experimental/net2net_v3.py new file mode 100644 index 0000000000000000000000000000000000000000..1d54340ff3d1f698907c3b10aa16577c79cf2001 --- /dev/null +++ b/ganime/model/vqgan_clean/experimental/net2net_v3.py @@ -0,0 +1,406 @@ +import numpy as np +import tensorflow as tf +import tensorflow_addons as tfa +from ganime.configs.model_configs import GPTConfig, ModelConfig +from ganime.model.vqgan_clean.experimental.transformer import Transformer +from ganime.model.vqgan_clean.vqgan import VQGAN +from ganime.trainer.warmup.cosine import WarmUpCosine +from tensorflow import keras +from tensorflow.keras import Model, layers +from ganime.model.vqgan_clean.losses.losses import Losses +from ganime.trainer.warmup.base import create_warmup_scheduler +from ganime.visualization.images import unnormalize_if_necessary + + +class Net2Net(Model): + def __init__( + self, + transformer_config: GPTConfig, + first_stage_config: ModelConfig, + trainer_config, + num_replicas: int = 1, + **kwargs, + ): + super().__init__(**kwargs) + self.first_stage_model = VQGAN(**first_stage_config) + self.transformer = Transformer(transformer_config) + + if "checkpoint_path" in transformer_config: + print(f"Restoring weights from {transformer_config['checkpoint_path']}") + self.load_weights(transformer_config["checkpoint_path"]) + + self.scheduled_lrs = create_warmup_scheduler( + trainer_config, 
num_devices=num_replicas + ) + + optimizer = tfa.optimizers.AdamW( + learning_rate=self.scheduled_lrs, weight_decay=1e-4 + ) + self.compile( + optimizer=optimizer, + # loss=self.loss_fn, + # run_eagerly=True, + ) + + self.n_frames_before = trainer_config["n_frames_before"] + + # Gradient accumulation + self.gradient_accumulation = [ + tf.Variable(tf.zeros_like(v, dtype=tf.float32), trainable=False) + for v in self.transformer.trainable_variables + ] + self.accumulation_size = trainer_config["accumulation_size"] + + # Losses + self.perceptual_loss_weight = trainer_config["perceptual_loss_weight"] + losses = Losses(num_replicas=num_replicas) + self.scce_loss = losses.scce_loss + self.perceptual_loss = losses.perceptual_loss + + self.total_loss_tracker = keras.metrics.Mean(name="total_loss") + self.scce_loss_tracker = keras.metrics.Mean(name="scce_loss") + self.perceptual_loss_tracker = keras.metrics.Mean(name="perceptual_loss") + + self.epoch = 0 + self.stop_ground_truth_after_epoch = trainer_config[ + "stop_ground_truth_after_epoch" + ] + + def apply_accu_gradients(self): + # apply accumulated gradients + self.optimizer.apply_gradients( + zip(self.gradient_accumulation, self.transformer.trainable_variables) + ) + + # reset + for i in range(len(self.gradient_accumulation)): + self.gradient_accumulation[i].assign( + tf.zeros_like(self.transformer.trainable_variables[i], dtype=tf.float32) + ) + + @property + def metrics(self): + # We list our `Metric` objects here so that `reset_states()` can be + # called automatically at the start of each epoch + # or at the start of `evaluate()`. + # If you don't implement this property, you have to call + # `reset_states()` yourself at the time of your choosing. + return [ + self.total_loss_tracker, + self.scce_loss_tracker, + self.perceptual_loss_tracker, + ] + + @tf.function() + def encode_to_z(self, x): + quant_z, indices, quantized_loss = self.first_stage_model.encode(x) + + batch_size = tf.shape(quant_z)[0] + + indices = tf.reshape(indices, shape=(batch_size, -1)) + return quant_z, indices + + def call(self, inputs, training=False, mask=None, return_losses=False): + + return self.predict_video(inputs, training, return_losses) + + def predict(self, data, sample=False, temperature=1.0): + video = self.predict_video( + data, + training=False, + return_losses=False, + sample=sample, + temperature=temperature, + ) + video = unnormalize_if_necessary(video) + return video + + def get_remaining_frames(self, inputs): + if "remaining_frames" in inputs: + remaining_frames = inputs["remaining_frames"] + else: + raise NotImplementedError + remaining_frames = tf.cast(remaining_frames, tf.int64) + return remaining_frames + + # @tf.function() + def predict_video( + self, inputs, training=False, return_losses=False, sample=False, temperature=1.0 + ): + first_frame = inputs["first_frame"] + last_frame = inputs["last_frame"] + n_frames = tf.reduce_max(inputs["n_frames"]) + remaining_frames = self.get_remaining_frames(inputs) + + try: + ground_truth = inputs["y"] + except AttributeError: + ground_truth = None + + previous_frames = tf.expand_dims(first_frame, axis=1) + + predictions = tf.TensorArray( + tf.float32, size=0, dynamic_size=True, clear_after_read=False + ) + + quant_last, indices_last = self.encode_to_z(last_frame) + + total_loss = tf.constant(0.0) + scce_loss = tf.constant(0.0) + perceptual_loss = tf.constant(0.0) + + current_frame_index = tf.constant(1) + while tf.less(current_frame_index, n_frames): + tf.autograph.experimental.set_loop_options( + 
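+                # `previous_frames` is rebuilt each iteration with a varying number of
+                # frames, so AutoGraph needs this explicit shape invariant for the loop.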
shape_invariants=[ + (previous_frames, tf.TensorShape([None, None, None, None, 3])) + ], + ) + + if ground_truth is not None: + target_frame = ground_truth[:, current_frame_index] + else: + target_frame = None + + y_pred, losses = self.predict_next_frame( + remaining_frames[:, current_frame_index], + previous_frames, + last_frame, + indices_last, + quant_last, + target_frame=target_frame, + training=training, + sample=sample, + temperature=temperature, + ) + predictions = predictions.write(current_frame_index, y_pred) + + if training and self.epoch < self.stop_ground_truth_after_epoch: + start_index = tf.math.maximum( + 0, current_frame_index - self.n_frames_before + ) + previous_frames = ground_truth[ + :, start_index + 1 : current_frame_index + 1 + ] + else: + previous_frames = predictions.stack() + previous_frames = tf.transpose(previous_frames, (1, 0, 2, 3, 4)) + previous_frames = previous_frames[:, -self.n_frames_before :] + + current_frame_index = tf.add(current_frame_index, 1) + total_loss = tf.add(total_loss, losses[0]) + scce_loss = tf.add(scce_loss, losses[1]) + perceptual_loss = tf.add(perceptual_loss, losses[2]) + + predictions = predictions.stack() + predictions = tf.transpose(predictions, (1, 0, 2, 3, 4)) + + total_loss = tf.divide(total_loss, tf.cast(n_frames, tf.float32)) + scce_loss = tf.divide(scce_loss, tf.cast(n_frames, tf.float32)) + perceptual_loss = tf.divide(perceptual_loss, tf.cast(n_frames, tf.float32)) + + if return_losses: + return predictions, total_loss, scce_loss, perceptual_loss + else: + return predictions + + def predict_next_frame( + self, + remaining_frames, + previous_frames, + last_frame, + indices_last, + quant_last, + target_frame=None, + training=False, + sample=False, + temperature=1.0, + ): + # previous frames is of shape (batch_size, n_frames, height, width, 3) + previous_frames = tf.transpose(previous_frames, (1, 0, 2, 3, 4)) + # previous frames is now of shape (n_frames, batch_size, height, width, 3) + + indices_previous = tf.map_fn( + lambda x: self.encode_to_z(x)[1], + previous_frames, + fn_output_signature=tf.int64, + ) + + # indices is of shape (n_frames, batch_size, n_z) + indices_previous = tf.transpose(indices_previous, (1, 0, 2)) + # indices is now of shape (batch_size, n_frames, n_z) + batch_size, n_frames, n_z = ( + tf.shape(indices_previous)[0], + tf.shape(indices_previous)[1], + tf.shape(indices_previous)[2], + ) + indices_previous = tf.reshape( + indices_previous, shape=(batch_size, n_frames * n_z) + ) + + if target_frame is not None: + _, target_indices = self.encode_to_z(target_frame) + else: + target_indices = None + + if training: + next_frame, losses = self.train_predict_next_frame( + remaining_frames, + indices_last, + indices_previous, + target_indices=target_indices, + target_frame=target_frame, + quant_shape=tf.shape(quant_last), + indices_shape=tf.shape(indices_last), + ) + else: + next_frame, losses = self.predict_next_frame_body( + remaining_frames, + indices_last, + indices_previous, + target_indices=target_indices, + target_frame=target_frame, + quant_shape=tf.shape(quant_last), + indices_shape=tf.shape(indices_last), + sample=sample, + temperature=temperature, + ) + + return next_frame, losses + + def predict_next_frame_body( + self, + remaining_frames, + last_frame_indices, + previous_frame_indices, + quant_shape, + indices_shape, + target_indices=None, + target_frame=None, + sample=False, + temperature=1.0, + ): + logits = self.transformer( + (remaining_frames, last_frame_indices, previous_frame_indices) + ) + 
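+        # `logits` has shape (batch, n_z, vocab_size): one distribution over codebook
+        # entries for every spatial token of the next frame; they are decoded back to an
+        # image below and, when targets are available, scored with the two losses.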
next_frame = self.convert_logits_to_image( + logits, + quant_shape=quant_shape, + indices_shape=indices_shape, + sample=sample, + temperature=temperature, + ) + if target_indices is not None: + scce_loss = self.scce_loss(target_indices, logits) + else: + scce_loss = 0.0 + + if target_frame is not None: + perceptual_loss = 1.0 * self.perceptual_loss(target_frame, next_frame) + else: + perceptual_loss = 0.0 + + frame_loss = scce_loss + perceptual_loss + + # self.total_loss_tracker.update_state(frame_loss) + # self.scce_loss_tracker.update_state(scce_loss) + # self.perceptual_loss_tracker.update_state(perceptual_loss) + + return next_frame, (frame_loss, scce_loss, perceptual_loss) + + def train_predict_next_frame( + self, + remaining_frames, + last_frame_indices, + previous_frame_indices, + quant_shape, + indices_shape, + target_indices, + target_frame, + ): + with tf.GradientTape() as tape: + next_frame, losses = self.predict_next_frame_body( + remaining_frames=remaining_frames, + last_frame_indices=last_frame_indices, + previous_frame_indices=previous_frame_indices, + target_indices=target_indices, + quant_shape=quant_shape, + indices_shape=indices_shape, + target_frame=target_frame, + sample=False, + ) + frame_loss = losses[0] + # Calculate batch gradients + gradients = tape.gradient(frame_loss, self.transformer.trainable_variables) + + # Accumulate batch gradients + for i in range(len(self.gradient_accumulation)): + self.gradient_accumulation[i].assign_add(tf.cast(gradients[i], tf.float32)) + + return next_frame, losses + + def convert_logits_to_image( + self, logits, quant_shape, indices_shape, sample=False, temperature=1.0 + ): + if sample: + array = [] + for i in range(logits.shape[1]): + sub_logits = logits[:, i] + sub_logits = sub_logits / temperature + # sub_logits, _ = tf.math.top_k(sub_logits, k=1) + probs = tf.keras.activations.softmax(sub_logits) + probs, probs_index = tf.math.top_k(probs, k=50) + selection_index = tf.random.categorical( + tf.math.log(probs), num_samples=1 + ) + ix = tf.gather_nd(probs_index, selection_index, batch_dims=1) + ix = tf.reshape(ix, (-1, 1)) + array.append(ix) + generated_indices = tf.concat(array, axis=-1) + else: + probs = tf.keras.activations.softmax(logits) + _, generated_indices = tf.math.top_k(probs) + + generated_indices = tf.reshape( + generated_indices, + indices_shape, + ) + quant = self.first_stage_model.quantize.get_codebook_entry( + generated_indices, shape=quant_shape + ) + + return self.first_stage_model.decode(quant) + + def train_step(self, data): + + batch_total_loss, batch_scce_loss, batch_perceptual_loss = 0.0, 0.0, 0.0 + for i in range(self.accumulation_size): + sub_data = { + key: value[ + self.accumulation_size * i : self.accumulation_size * (i + 1) + ] + for key, value in data.items() + } + _, total_loss, scce_loss, perceptual_loss = self( + sub_data, training=True, return_losses=True + ) + batch_total_loss += total_loss + batch_scce_loss += scce_loss + batch_perceptual_loss += perceptual_loss + + self.apply_accu_gradients() + self.total_loss_tracker.update_state(batch_total_loss) + self.scce_loss_tracker.update_state(batch_scce_loss) + self.perceptual_loss_tracker.update_state(batch_perceptual_loss) + self.epoch += 1 + return {m.name: m.result() for m in self.metrics} + + def test_step(self, data): + _, total_loss, scce_loss, perceptual_loss = self( + data, training=False, return_losses=True + ) + + self.total_loss_tracker.update_state(total_loss) + self.scce_loss_tracker.update_state(scce_loss) + 
self.perceptual_loss_tracker.update_state(perceptual_loss) + return {m.name: m.result() for m in self.metrics} diff --git a/ganime/model/vqgan_clean/experimental/transformer.py b/ganime/model/vqgan_clean/experimental/transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..e7af1074bb47d0ef3bbfd9f93dcc3da064b1aef1 --- /dev/null +++ b/ganime/model/vqgan_clean/experimental/transformer.py @@ -0,0 +1,109 @@ +from tensorflow.keras import layers +from tensorflow.keras import Model +import tensorflow as tf +from transformers import TFPreTrainedModel + +valid_types = ["gpt2", "gpt2", "gpt2-medium", "gpt2-large", "gpt2-xl"] + + +class Transformer(Model): + def __init__(self, config): + super().__init__() + self.config = config + self.remaining_frames_method = self.get_remaining_frames_method(config) + self.transformer_type = self.get_transformer_type(config) + self.transformer = self.load_transformer( + self.remaining_frames_method, self.transformer_type + ) + + def get_transformer_type(self, config): + if "transformer_type" in config: + transformer_type = config["transformer_type"] + if transformer_type not in valid_types: + raise ValueError( + f"transformer_type {transformer_type} is not valid. Valid types are {valid_types}" + ) + return transformer_type + else: + return valid_types[0] + + def get_remaining_frames_method(self, config) -> str: + """Get the method to use for remaining frames. + Check if the method is set inside the configuration, otherwise use concat as the default. + """ + if "remaining_frames_method" in config: + return config["remaining_frames_method"] + else: + return "concat" + + def load_transformer(self, method: str, transformer_type: str) -> TFPreTrainedModel: + print("using method ", method) + if method == "own_embeddings": + from ganime.model.vqgan_clean.experimental.gpt2_embedding import ( + TFGPT2LMHeadModel, + ) + + transformer = TFGPT2LMHeadModel.from_pretrained(transformer_type) + + else: + from transformers import TFGPT2LMHeadModel + + transformer = TFGPT2LMHeadModel.from_pretrained(transformer_type) + return transformer + + def concatenate_inputs( + self, remaining_frames, last_frame_indices, previous_frame_indices + ) -> tf.Tensor: + if self.remaining_frames_method == "concat": + return tf.concat( + [remaining_frames, last_frame_indices, previous_frame_indices], axis=1 + ) + else: + return tf.concat([last_frame_indices, previous_frame_indices], axis=1) + + def call_transformer( + self, transformer_input, remaining_frames, training, attention_mask + ): + if self.remaining_frames_method == "concat": + return self.transformer( + transformer_input, training=training, attention_mask=attention_mask + ) + elif self.remaining_frames_method == "token_type_ids": + return self.transformer( + transformer_input, + token_type_ids=remaining_frames, + training=training, + attention_mask=attention_mask, + ) + elif self.remaining_frames_method == "own_embeddings": + return self.transformer( + transformer_input, + remaining_frames_ids=remaining_frames, + training=training, + attention_mask=attention_mask, + ) + else: + raise ValueError( + f"Unknown remaining_frames_method {self.remaining_frames_method}" + ) + + def call(self, inputs, training=True, mask=None): + remaining_frames, last_frame_indices, previous_frame_indices = inputs + remaining_frames = tf.expand_dims(remaining_frames, axis=1) + shape_to_keep = tf.shape(last_frame_indices)[1] + + h = self.concatenate_inputs( + remaining_frames, last_frame_indices, previous_frame_indices + ) + + # 
transformer_input = h[:, :-1] + transformer_input = h + mask = tf.ones_like(transformer_input) * tf.cast( + tf.cast(remaining_frames, dtype=tf.bool), dtype=remaining_frames.dtype + ) + + h = self.call_transformer(transformer_input, remaining_frames, training, mask) + h = h.logits + # h = self.transformer.transformer.wte(h, mode="linear") + h = h[:, -shape_to_keep:] + return h diff --git a/ganime/model/vqgan_clean/losses/__init__.py b/ganime/model/vqgan_clean/losses/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/ganime/model/vqgan_clean/losses/losses.py b/ganime/model/vqgan_clean/losses/losses.py new file mode 100644 index 0000000000000000000000000000000000000000..089d33eecc89fd78a24343619520e06e468e78ed --- /dev/null +++ b/ganime/model/vqgan_clean/losses/losses.py @@ -0,0 +1,106 @@ +import os + +import tensorflow as tf +from pyprojroot.pyprojroot import here +from tensorflow import reduce_mean +from tensorflow.keras import Model +from tensorflow.keras.applications import VGG19 +from tensorflow.keras.applications.vgg19 import preprocess_input +from tensorflow.keras.losses import ( + Loss, + MeanSquaredError, + Reduction, + SparseCategoricalCrossentropy, + BinaryCrossentropy, +) + +from . import vgg19_loss as vgg19 + + +class Losses: + def __init__(self, num_replicas: int = 1, vgg_model_file: str = None): + self.num_replicas = num_replicas + self.SCCE = SparseCategoricalCrossentropy( + from_logits=True, reduction=Reduction.NONE + ) + self.MSE = MeanSquaredError(reduction=Reduction.NONE) + self.MAE = tf.keras.losses.MeanAbsoluteError(reduction=Reduction.NONE) + self.BCE = BinaryCrossentropy(from_logits=True, reduction=Reduction.NONE) + + self.vgg = VGG.build() + self.preprocess = preprocess_input + try: + root_dir = here() + except RecursionError: + root_dir = "GANime" + + self.vgg_model_file = ( + os.path.join(root_dir, "models", "vgg19", "imagenet-vgg-verydeep-19.mat") + if vgg_model_file is None + else vgg_model_file + ) + + def bce_loss(self, real, pred): + # compute binary cross entropy loss without reduction + loss = self.BCE(real, pred) + # compute reduced mean over the entire batch + loss = reduce_mean(loss) * (1.0 / self.num_replicas) + # return reduced bce loss + return loss + + def perceptual_loss(self, real, pred): + y_true_preprocessed = self.preprocess(real) + y_pred_preprocessed = self.preprocess(pred) + y_true_scaled = y_true_preprocessed / 12.75 + y_pred_scaled = y_pred_preprocessed / 12.75 + + loss = self.mse_loss(y_true_scaled, y_pred_scaled) * 5e3 + + return loss + + def scce_loss(self, real, pred): + # compute categorical cross entropy loss without reduction + loss = self.SCCE(real, pred) + # compute reduced mean over the entire batch + loss = reduce_mean(loss) * (1.0 / self.num_replicas) + # return reduced scce loss + return loss + + def mse_loss(self, real, pred): + # compute mean squared error without reduction + loss = self.MSE(real, pred) + # compute reduced mean over the entire batch + loss = reduce_mean(loss) * (1.0 / self.num_replicas) + # return reduced mse loss + return loss + + def mae_loss(self, real, pred): + # compute mean absolute error without reduction + loss = self.MAE(real, pred) + # compute reduced mean over the entire batch + loss = reduce_mean(loss) * (1.0 / self.num_replicas) + # return reduced mae loss + return loss + + def vgg_loss(self, real, pred): + loss = vgg19.vgg_loss(pred, real, vgg_model_file=self.vgg_model_file) + return loss + + def style_loss(self, 
real, pred): + loss = vgg19.style_loss( + pred, + real, + vgg_model_file=self.vgg_model_file, + ) + return loss + + +class VGG: + @staticmethod + def build(): + # initialize the pre-trained VGG19 model + vgg = VGG19(input_shape=(None, None, 3), weights="imagenet", include_top=False) + # slicing the VGG19 model till layer #20 + model = Model(vgg.input, vgg.layers[20].output) + # return the sliced VGG19 model + return model diff --git a/ganime/model/vqgan_clean/losses/lpips.py b/ganime/model/vqgan_clean/losses/lpips.py new file mode 100644 index 0000000000000000000000000000000000000000..66139e028a33e2ba7db4a37f2d4633c0b48e61b5 --- /dev/null +++ b/ganime/model/vqgan_clean/losses/lpips.py @@ -0,0 +1,140 @@ +import os +import numpy as np +import tensorflow as tf +import torchvision.models as models +from tensorflow import keras +from tensorflow.keras import Model, Sequential +from tensorflow.keras import backend as K +from tensorflow.keras import layers +from tensorflow.keras.applications.vgg16 import VGG16, preprocess_input +from tensorflow.keras.losses import Loss +from pyprojroot.pyprojroot import here + + +def normalize_tensor(x, eps=1e-10): + norm_factor = tf.sqrt(tf.reduce_sum(x**2, axis=-1, keepdims=True)) + return x / (norm_factor + eps) + + +class LPIPS(Loss): + def __init__(self, use_dropout=True, **kwargs): + super().__init__(**kwargs) + + self.use_dropout = use_dropout + self.scaling_layer = ScalingLayer() # preprocess_input + selected_layers = [ + "block1_conv2", + "block2_conv2", + "block3_conv3", + "block4_conv3", + "block5_conv3", + ] + + # TODO here we load the same weights as pytorch, try with tensorflow weights + self.net = self.load_vgg16() # VGG16(weights="imagenet", include_top=False) + self.net.trainable = False + outputs = [self.net.get_layer(layer).output for layer in selected_layers] + + self.model = Model(self.net.input, outputs) + self.lins = [ + NetLinLayer(input_shape=output.shape[1:], use_dropout=use_dropout) + for output in outputs + ] + + # TODO: here we use the pytorch weights of the linear layers, try without these layers, or without initializing the weights + self.init_lin_layers() + + def load_vgg16(self) -> Model: + """Load a VGG16 model with the same weights as PyTorch + https://github.com/ezavarygin/vgg16_pytorch2keras + """ + pytorch_model = models.vgg16(pretrained=True) + # select weights in the conv2d layers and transpose them to keras dim ordering: + wblist_torch = list(pytorch_model.parameters())[:26] + wblist_keras = [] + for i in range(len(wblist_torch)): + if wblist_torch[i].dim() == 4: + w = np.transpose(wblist_torch[i].detach().numpy(), axes=[2, 3, 1, 0]) + wblist_keras.append(w) + elif wblist_torch[i].dim() == 1: + b = wblist_torch[i].detach().numpy() + wblist_keras.append(b) + else: + raise Exception("Fully connected layers are not implemented.") + + keras_model = VGG16(include_top=False, weights=None) + keras_model.set_weights(wblist_keras) + return keras_model + + def init_lin_layers(self): + for i in range(5): + weights = np.load( + os.path.join(here(), "models", "NetLinLayer", f"numpy_{i}.npy") + ) + weights = np.moveaxis(weights, 1, 2) + self.lins[i].set_weights([weights]) + + def call(self, y_true, y_pred): + scaled_true = self.scaling_layer(y_true) + scaled_pred = self.scaling_layer(y_pred) + + outputs_true, outputs_pred = self.model(scaled_true), self.model(scaled_pred) + features_true, features_pred, diffs = {}, {}, {} + + for kk in range(len(outputs_true)): + features_true[kk], features_pred[kk] = normalize_tensor( + outputs_true[kk] 
+ ), normalize_tensor(outputs_pred[kk]) + + diffs[kk] = (features_true[kk] - features_pred[kk]) ** 2 + + res = [ + tf.reduce_mean(self.lins[kk](diffs[kk]), axis=(-3, -2), keepdims=True) + for kk in range(len(outputs_true)) + ] + + # return tf.cast(tf.reduce_sum(res), tf.float32) + return tf.reduce_sum(res) + + +class ScalingLayer(layers.Layer): + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.shift = tf.Variable([-0.030, -0.088, -0.188]) + self.scale = tf.Variable([0.458, 0.448, 0.450]) + + def call(self, inputs): + if inputs.dtype == tf.float16: + inputs = tf.cast(inputs, tf.float32) + # self.shift = tf.cast(self.shift, tf.float16) + # self.scale = tf.cast(self.scale, tf.float16) + return (inputs - self.shift) / self.scale + + +class NetLinLayer(layers.Layer): + def __init__(self, input_shape, channels_out=1, use_dropout=False): + super().__init__() + inputs = tf.keras.Input(shape=input_shape) + x = inputs + if use_dropout: + x = layers.Dropout(0.5)(x) + x = layers.Conv2D(channels_out, 1, padding="same", use_bias=False)(x) + x = layers.Activation("linear", dtype="float32")(x) + self.model = Model(inputs=inputs, outputs=x) + + # sequence = [layers.Input(input_shape)] + # sequence += ( + # [ + # layers.Dropout(0.5), + # ] + # if use_dropout + # else [] + # ) + # sequence += [ + # layers.Conv2D(channels_out, 1, padding="same", use_bias=False), + # layers.Activation("linear", dtype="float32"), + # ] + # self.model = Sequential(sequence) + + def call(self, inputs): + return self.model(inputs) diff --git a/ganime/model/vqgan_clean/losses/vgg19_loss.py b/ganime/model/vqgan_clean/losses/vgg19_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..1eb428cd9b810bb0d3eece6f6f32a78df24d5291 --- /dev/null +++ b/ganime/model/vqgan_clean/losses/vgg19_loss.py @@ -0,0 +1,421 @@ +# Copyright 2022 Google LLC + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# https://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Feature loss based on 19 layer VGG network. + + +The network layers in the feature loss is weighted as described in +'Stereo Magnification: Learning View Synthesis using Multiplane Images', +Tinghui Zhou, Richard Tucker, Flynn, Graham Fyffe, Noah Snavely, SIGGRAPH 2018. +""" + +from typing import Any, Callable, Dict, Optional, Sequence, Tuple + +import numpy as np +import scipy.io as sio +import tensorflow.compat.v1 as tf + + +def _build_net( + layer_type: str, + input_tensor: tf.Tensor, + weight_bias: Optional[Tuple[tf.Tensor, tf.Tensor]] = None, + name: Optional[str] = None, +) -> Callable[[Any], Any]: + """Build a layer of the VGG network. + + Args: + layer_type: A string, type of this layer. + input_tensor: A tensor. + weight_bias: A tuple of weight and bias. + name: A string, name of this layer. + + Returns: + A callable function of the tensorflow layer. + + Raises: + ValueError: If layer_type is not conv or pool. 
+ """ + + if layer_type == "conv": + return tf.nn.relu( + tf.nn.conv2d( + input_tensor, + weight_bias[0], + strides=[1, 1, 1, 1], + padding="SAME", + name=name, + ) + + weight_bias[1] + ) + elif layer_type == "pool": + return tf.nn.avg_pool( + input_tensor, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding="SAME" + ) + else: + raise ValueError("Unsupported layer %s" % layer_type) + + +def _get_weight_and_bias( + vgg_layers: np.ndarray, index: int +) -> Tuple[tf.Tensor, tf.Tensor]: + """Get the weight and bias of a specific layer from the VGG pretrained model. + + Args: + vgg_layers: An array, the VGG pretrained model. + index: An integer, index of the layer. + + Returns: + weights: A tensor. + bias: A tensor. + """ + + weights = vgg_layers[index][0][0][2][0][0] + weights = tf.constant(weights) + bias = vgg_layers[index][0][0][2][0][1] + bias = tf.constant(np.reshape(bias, (bias.size))) + + return weights, bias + + +def _build_vgg19(image: tf.Tensor, model_filepath: str) -> Dict[str, tf.Tensor]: + """Builds the VGG network given the model weights. + + The weights are loaded only for the first time this code is invoked. + + Args: + image: A tensor, input image. + model_filepath: A string, path to the VGG pretrained model. + + Returns: + net: A dict mapping a layer name to a tensor. + """ + + with tf.variable_scope("vgg", reuse=True): + net = {} + if not hasattr(_build_vgg19, "vgg_rawnet"): + with tf.io.gfile.GFile(model_filepath, "rb") as f: + _build_vgg19.vgg_rawnet = sio.loadmat(f) + vgg_layers = _build_vgg19.vgg_rawnet["layers"][0] + imagenet_mean = tf.constant([123.6800, 116.7790, 103.9390], shape=[1, 1, 1, 3]) + net["input"] = image - imagenet_mean + net["conv1_1"] = _build_net( + "conv", + net["input"], + _get_weight_and_bias(vgg_layers, 0), + name="vgg_conv1_1", + ) + net["conv1_2"] = _build_net( + "conv", + net["conv1_1"], + _get_weight_and_bias(vgg_layers, 2), + name="vgg_conv1_2", + ) + net["pool1"] = _build_net("pool", net["conv1_2"]) + net["conv2_1"] = _build_net( + "conv", + net["pool1"], + _get_weight_and_bias(vgg_layers, 5), + name="vgg_conv2_1", + ) + net["conv2_2"] = _build_net( + "conv", + net["conv2_1"], + _get_weight_and_bias(vgg_layers, 7), + name="vgg_conv2_2", + ) + net["pool2"] = _build_net("pool", net["conv2_2"]) + net["conv3_1"] = _build_net( + "conv", + net["pool2"], + _get_weight_and_bias(vgg_layers, 10), + name="vgg_conv3_1", + ) + net["conv3_2"] = _build_net( + "conv", + net["conv3_1"], + _get_weight_and_bias(vgg_layers, 12), + name="vgg_conv3_2", + ) + net["conv3_3"] = _build_net( + "conv", + net["conv3_2"], + _get_weight_and_bias(vgg_layers, 14), + name="vgg_conv3_3", + ) + net["conv3_4"] = _build_net( + "conv", + net["conv3_3"], + _get_weight_and_bias(vgg_layers, 16), + name="vgg_conv3_4", + ) + net["pool3"] = _build_net("pool", net["conv3_4"]) + net["conv4_1"] = _build_net( + "conv", + net["pool3"], + _get_weight_and_bias(vgg_layers, 19), + name="vgg_conv4_1", + ) + net["conv4_2"] = _build_net( + "conv", + net["conv4_1"], + _get_weight_and_bias(vgg_layers, 21), + name="vgg_conv4_2", + ) + net["conv4_3"] = _build_net( + "conv", + net["conv4_2"], + _get_weight_and_bias(vgg_layers, 23), + name="vgg_conv4_3", + ) + net["conv4_4"] = _build_net( + "conv", + net["conv4_3"], + _get_weight_and_bias(vgg_layers, 25), + name="vgg_conv4_4", + ) + net["pool4"] = _build_net("pool", net["conv4_4"]) + net["conv5_1"] = _build_net( + "conv", + net["pool4"], + _get_weight_and_bias(vgg_layers, 28), + name="vgg_conv5_1", + ) + net["conv5_2"] = _build_net( + "conv", + 
net["conv5_1"], + _get_weight_and_bias(vgg_layers, 30), + name="vgg_conv5_2", + ) + + return net + + +def _compute_error( + fake: tf.Tensor, real: tf.Tensor, mask: Optional[tf.Tensor] = None +) -> tf.Tensor: + """Computes the L1 loss and reweights by the mask.""" + if mask is None: + return tf.reduce_mean(tf.abs(fake - real)) + else: + # Resizes mask to the same size as the input. + size = (tf.shape(fake)[1], tf.shape(fake)[2]) + resized_mask = tf.image.resize( + mask, size, method=tf.image.ResizeMethod.BILINEAR + ) + return tf.reduce_mean(tf.abs(fake - real) * resized_mask) + + +# Normalized VGG loss (from +# https://github.com/CQFIO/PhotographicImageSynthesis) +def vgg_loss( + image: tf.Tensor, + reference: tf.Tensor, + vgg_model_file: str, + weights: Optional[Sequence[float]] = None, + mask: Optional[tf.Tensor] = None, +) -> tf.Tensor: + """Computes the VGG loss for an image pair. + + The VGG loss is the average feature vector difference between the two images. + + The input images must be in [0, 1] range in (B, H, W, 3) RGB format and + the recommendation seems to be to have them in gamma space. + + The pretrained weights are publicly available in + http://www.vlfeat.org/matconvnet/models/imagenet-vgg-verydeep-19.mat + + Args: + image: A tensor, typically the prediction from a network. + reference: A tensor, the image to compare against, i.e. the golden image. + vgg_model_file: A string, filename for the VGG 19 network weights in MATLAB + format. + weights: A list of float, optional weights for the layers. The defaults are + from Qifeng Chen and Vladlen Koltun, "Photographic image synthesis with + cascaded refinement networks," ICCV 2017. + mask: An optional image-shape and single-channel tensor, the mask values are + per-pixel weights to be applied on the losses. The mask will be resized to + the same spatial resolution with the feature maps before been applied to + the losses. When the mask value is zero, pixels near the boundary of the + mask can still influence the loss if they fall into the receptive field of + the VGG convolutional layers. + + Returns: + vgg_loss: The linear combination of losses from five VGG layers. + """ + + if not weights: + weights = [1.0 / 2.6, 1.0 / 4.8, 1.0 / 3.7, 1.0 / 5.6, 10.0 / 1.5] + + vgg_ref = _build_vgg19(reference * 255.0, vgg_model_file) + vgg_img = _build_vgg19(image * 255.0, vgg_model_file) + p1 = _compute_error(vgg_ref["conv1_2"], vgg_img["conv1_2"], mask) * weights[0] + p2 = _compute_error(vgg_ref["conv2_2"], vgg_img["conv2_2"], mask) * weights[1] + p3 = _compute_error(vgg_ref["conv3_2"], vgg_img["conv3_2"], mask) * weights[2] + p4 = _compute_error(vgg_ref["conv4_2"], vgg_img["conv4_2"], mask) * weights[3] + p5 = _compute_error(vgg_ref["conv5_2"], vgg_img["conv5_2"], mask) * weights[4] + + final_loss = p1 + p2 + p3 + p4 + p5 + + # Scale to range [0..1]. + final_loss /= 255.0 + + return final_loss + + +def _compute_gram_matrix(input_features: tf.Tensor, mask: tf.Tensor) -> tf.Tensor: + """Computes Gram matrix of `input_features`. + + Gram matrix described in https://en.wikipedia.org/wiki/Gramian_matrix. + + Args: + input_features: A tf.Tensor of shape (B, H, W, C) representing a feature map + obtained by a convolutional layer of a VGG network. + mask: A tf.Tensor of shape (B, H, W, 1) representing the per-pixel weights + to be applied on the `input_features`. The mask will be resized to the + same spatial resolution as the `input_featues`. 
When the mask value is + zero, pixels near the boundary of the mask can still influence the loss if + they fall into the receptive field of the VGG convolutional layers. + + Returns: + A tf.Tensor of shape (B, C, C) representing the gram matrix of the masked + `input_features`. + """ + # _, h, w, c = tuple( + # [ + # i if (isinstance(i, int) or i is None) else i.value + # for i in tf.shape(input_features) + # ] + # ) + _, h, w, c = ( + tf.shape(input_features)[0], + tf.shape(input_features)[1], + tf.shape(input_features)[2], + tf.shape(input_features)[3], + ) + + if mask is None: + reshaped_features = tf.reshape(input_features, (-1, h * w, c)) + else: + # Resize mask to match the shape of `input_features` + resized_mask = tf.image.resize( + mask, (h, w), method=tf.image.ResizeMethod.BILINEAR + ) + reshaped_features = tf.reshape(input_features * resized_mask, (-1, h * w, c)) + return tf.matmul(reshaped_features, reshaped_features, transpose_a=True) / tf.cast( + tf.multiply(h, w), tf.float32 + ) + + +def style_loss( + image: tf.Tensor, + reference: tf.Tensor, + vgg_model_file: str, + weights: Optional[Sequence[float]] = None, + mask: Optional[tf.Tensor] = None, +) -> tf.Tensor: + """Computes style loss as used in `A Neural Algorithm of Artistic Style`. + + Based on the work in https://github.com/cysmith/neural-style-tf. Weights are + first initilaized to the inverse of the number of elements in each VGG layer + considerd. After 1.5M iterations, they are rescaled to normalize the + contribution of the Style loss to be equal to other losses (L1/VGG). This is + based on the works of image inpainting (https://arxiv.org/abs/1804.07723) + and frame prediction (https://arxiv.org/abs/1811.00684). + + The style loss is the average gram matrix difference between `image` and + `reference`. The gram matrix is the inner product of a feature map of shape + (B, H*W, C) with itself. Results in a symmetric gram matrix shaped (B, C, C). + + The input images must be in [0, 1] range in (B, H, W, 3) RGB format and + the recommendation seems to be to have them in gamma space. + + The pretrained weights are publicly available in + http://www.vlfeat.org/matconvnet/models/imagenet-vgg-verydeep-19.mat + + Args: + image: A tensor, typically the prediction from a network. + reference: A tensor, the image to compare against, i.e. the golden image. + vgg_model_file: A string, filename for the VGG 19 network weights in MATLAB + format. + weights: A list of float, optional weights for the layers. The defaults are + from Qifeng Chen and Vladlen Koltun, "Photographic image synthesis with + cascaded refinement networks," ICCV 2017. + mask: An optional image-shape and single-channel tensor, the mask values are + per-pixel weights to be applied on the losses. The mask will be resized to + the same spatial resolution with the feature maps before been applied to + the losses. When the mask value is zero, pixels near the boundary of the + mask can still influence the loss if they fall into the receptive field of + the VGG convolutional layers. + + Returns: + Style loss, a linear combination of gram matrix L2 differences of from five + VGG layer features. 
+ """ + + if not weights: + weights = [1.0 / 2.6, 1.0 / 4.8, 1.0 / 3.7, 1.0 / 5.6, 10.0 / 1.5] + + vgg_ref = _build_vgg19(reference * 255.0, vgg_model_file) + vgg_img = _build_vgg19(image * 255.0, vgg_model_file) + + p1 = ( + tf.reduce_mean( + tf.squared_difference( + _compute_gram_matrix(vgg_ref["conv1_2"] / 255.0, mask), + _compute_gram_matrix(vgg_img["conv1_2"] / 255.0, mask), + ) + ) + * weights[0] + ) + p2 = ( + tf.reduce_mean( + tf.squared_difference( + _compute_gram_matrix(vgg_ref["conv2_2"] / 255.0, mask), + _compute_gram_matrix(vgg_img["conv2_2"] / 255.0, mask), + ) + ) + * weights[1] + ) + p3 = ( + tf.reduce_mean( + tf.squared_difference( + _compute_gram_matrix(vgg_ref["conv3_2"] / 255.0, mask), + _compute_gram_matrix(vgg_img["conv3_2"] / 255.0, mask), + ) + ) + * weights[2] + ) + p4 = ( + tf.reduce_mean( + tf.squared_difference( + _compute_gram_matrix(vgg_ref["conv4_2"] / 255.0, mask), + _compute_gram_matrix(vgg_img["conv4_2"] / 255.0, mask), + ) + ) + * weights[3] + ) + p5 = ( + tf.reduce_mean( + tf.squared_difference( + _compute_gram_matrix(vgg_ref["conv5_2"] / 255.0, mask), + _compute_gram_matrix(vgg_img["conv5_2"] / 255.0, mask), + ) + ) + * weights[4] + ) + + final_loss = p1 + p2 + p3 + p4 + p5 + + return final_loss diff --git a/ganime/model/vqgan_clean/losses/vqperceptual.py b/ganime/model/vqgan_clean/losses/vqperceptual.py new file mode 100644 index 0000000000000000000000000000000000000000..6539d10d0fb1a000fe517e43cb0c05069caaeb8c --- /dev/null +++ b/ganime/model/vqgan_clean/losses/vqperceptual.py @@ -0,0 +1,51 @@ +from typing import List +try: + from typing import Literal +except ImportError: + from typing_extensions import Literal + +import numpy as np +import tensorflow as tf +from tensorflow import keras +from tensorflow.keras.losses import Loss + +from .lpips import LPIPS + + +class PerceptualLoss(Loss): + def __init__(self, *, perceptual_weight: float = 1.0, **kwargs): + """Perceptual loss based on the LPIPS metric. + + Args: + perceptual_weight (float, optional): The weight of the perceptual loss. Defaults to 1.0. 
+ """ + super().__init__(**kwargs) + + self.perceptual_loss = LPIPS(reduction=tf.keras.losses.Reduction.NONE) + self.perceptual_weight = perceptual_weight + + def get_config(self): + config = super().get_config() + config.update( + { + "perceptual_weight": self.perceptual_weight, + } + ) + return config + + def call( + self, + y_true, + y_pred, + ): + reconstruction_loss = tf.abs(y_true - y_pred) + if self.perceptual_weight > 0: + + perceptual_loss = self.perceptual_loss(y_true, y_pred) + reconstruction_loss += self.perceptual_weight * perceptual_loss + else: + perceptual_loss = 0.0 + + neg_log_likelihood = tf.reduce_mean(reconstruction_loss) + + return neg_log_likelihood diff --git a/ganime/model/vqgan_clean/net2net.py b/ganime/model/vqgan_clean/net2net.py new file mode 100644 index 0000000000000000000000000000000000000000..bb3e8103d79b8d452b4804355b89d5fdbf5b5344 --- /dev/null +++ b/ganime/model/vqgan_clean/net2net.py @@ -0,0 +1,305 @@ +import math + +import numpy as np +import tensorflow as tf +import tensorflow_addons as tfa +from ganime.configs.model_configs import GPTConfig, ModelConfig +from ganime.model.vqgan_clean.transformer.mingpt import GPT +from ganime.model.vqgan_clean.vqgan import VQGAN +from tensorflow import keras +from tensorflow.keras import Model, layers + + +class WarmUpCosine(keras.optimizers.schedules.LearningRateSchedule): + """A LearningRateSchedule that uses a warmup cosine decay schedule.""" + + def __init__(self, lr_start, lr_max, warmup_steps, total_steps): + """ + Args: + lr_start: The initial learning rate + lr_max: The maximum learning rate to which lr should increase to in + the warmup steps + warmup_steps: The number of steps for which the model warms up + total_steps: The total number of steps for the model training + """ + super().__init__() + self.lr_start = lr_start + self.lr_max = lr_max + self.warmup_steps = warmup_steps + self.total_steps = total_steps + self.pi = tf.constant(np.pi) + + def __call__(self, step): + # Check whether the total number of steps is larger than the warmup + # steps. If not, then throw a value error. + if self.total_steps < self.warmup_steps: + raise ValueError( + f"Total number of steps {self.total_steps} must be" + + f"larger or equal to warmup steps {self.warmup_steps}." + ) + + # `cos_annealed_lr` is a graph that increases to 1 from the initial + # step to the warmup step. After that this graph decays to -1 at the + # final step mark. + cos_annealed_lr = tf.cos( + self.pi + * (tf.cast(step, tf.float32) - self.warmup_steps) + / tf.cast(self.total_steps - self.warmup_steps, tf.float32) + ) + + # Shift the mean of the `cos_annealed_lr` graph to 1. Now the grpah goes + # from 0 to 2. Normalize the graph with 0.5 so that now it goes from 0 + # to 1. With the normalized graph we scale it with `lr_max` such that + # it goes from 0 to `lr_max` + learning_rate = 0.5 * self.lr_max * (1 + cos_annealed_lr) + + # Check whether warmup_steps is more than 0. + if self.warmup_steps > 0: + # Check whether lr_max is larger that lr_start. If not, throw a value + # error. + if self.lr_max < self.lr_start: + raise ValueError( + f"lr_start {self.lr_start} must be smaller or" + + f"equal to lr_max {self.lr_max}." + ) + + # Calculate the slope with which the learning rate should increase + # in the warumup schedule. 
The formula for slope is m = ((b-a)/steps) + slope = (self.lr_max - self.lr_start) / self.warmup_steps + + # With the formula for a straight line (y = mx+c) build the warmup + # schedule + warmup_rate = slope * tf.cast(step, tf.float32) + self.lr_start + + # When the current step is lesser that warmup steps, get the line + # graph. When the current step is greater than the warmup steps, get + # the scaled cos graph. + learning_rate = tf.where( + step < self.warmup_steps, warmup_rate, learning_rate + ) + + # When the current step is more that the total steps, return 0 else return + # the calculated graph. + return tf.where( + step > self.total_steps, 0.0, learning_rate, name="learning_rate" + ) + + +LEN_X_TRAIN = 8000 +BATCH_SIZE = 16 +N_EPOCHS = 500 +TOTAL_STEPS = int(LEN_X_TRAIN / BATCH_SIZE * N_EPOCHS) +WARMUP_EPOCH_PERCENTAGE = 0.15 +WARMUP_STEPS = int(TOTAL_STEPS * WARMUP_EPOCH_PERCENTAGE) + + +class Net2Net(Model): + def __init__( + self, + transformer_config: GPTConfig, + first_stage_config: ModelConfig, + cond_stage_config: ModelConfig, + ): + super().__init__() + self.transformer = GPT(**transformer_config) + self.first_stage_model = VQGAN(**first_stage_config) + self.cond_stage_model = self.first_stage_model # VQGAN(**cond_stage_config) + + self.loss_fn = tf.keras.losses.SparseCategoricalCrossentropy( + from_logits=True, reduction=tf.keras.losses.Reduction.NONE + ) + + self.loss_tracker = keras.metrics.Mean(name="loss") + # self.compile( + # "adam", + # loss=self.loss_fn, + # ) + + # Calculate the number of steps for warmup. + + # Initialize the warmupcosine schedule. + self.scheduled_lrs = WarmUpCosine( + lr_start=1e-5, + lr_max=2.5e-4, + warmup_steps=WARMUP_STEPS, + total_steps=TOTAL_STEPS, + ) + + self.compile( + optimizer=tfa.optimizers.AdamW( + learning_rate=self.scheduled_lrs, weight_decay=1e-4 + ), + loss=[self.loss_fn, None], + ) + + @property + def metrics(self): + # We list our `Metric` objects here so that `reset_states()` can be + # called automatically at the start of each epoch + # or at the start of `evaluate()`. + # If you don't implement this property, you have to call + # `reset_states()` yourself at the time of your choosing. 
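# Aside: a quick usage sketch (not part of this patch) for the WarmUpCosine
# schedule defined at the top of this file -- the learning rate should climb
# linearly from lr_start to lr_max over the warmup steps, then follow a cosine
# decay down to 0 at total_steps. The step values below are illustrative, and
# the import assumes the ganime package is installed.
import tensorflow as tf
from ganime.model.vqgan_clean.net2net import WarmUpCosine

schedule = WarmUpCosine(lr_start=1e-5, lr_max=2.5e-4, warmup_steps=1000, total_steps=10000)
for step in (0, 500, 1000, 5000, 10000):
    print(step, float(schedule(tf.constant(step))))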
+ return [ + self.loss_tracker, + ] + + def encode_to_z(self, x): + quant_z, indices, quantized_loss = self.first_stage_model.encode(x) + + batch_size = tf.shape(quant_z)[0] + + indices = tf.reshape(indices, shape=(batch_size, -1)) + return quant_z, indices + + def encode_to_c(self, c): + quant_c, indices, quantized_loss = self.cond_stage_model.encode(c) + + batch_size = tf.shape(quant_c)[0] + + indices = tf.reshape(indices, shape=(batch_size, -1)) + return quant_c, indices + + # def build(self, input_shape): + # self.first_stage_model.build(input_shape) + # self.cond_stage_model.build(input_shape) + # return super().build(input_shape) + + def call(self, inputs, training=None, mask=None): + # x, c = inputs + + # # one step to produce the logits + # _, z_indices = self.encode_to_z(x) + # _, c_indices = self.encode_to_c(c) + + # cz_indices = tf.concat((c_indices, z_indices), axis=1) + + # target = z_indices + # logits = self.transformer( + # cz_indices[:, :-1] # , training=training + # ) # don't know why -1 + + # logits = logits[:, tf.shape(c_indices)[1] - 1 :] # -1 here 'cause -1 above + + # logits = tf.reshape(logits, shape=(-1, logits.shape[-1])) + # target = tf.reshape(target, shape=(-1,)) + + # return logits, target + if isinstance(inputs, tuple) and len(inputs) == 2: + first_last_frame, y = inputs + else: + first_last_frame, y = inputs, None + + return self.process_video(first_last_frame, y) + + @tf.function() + def process_image(self, x, c, target_image=None): + + frame_loss = 0 + + # one step to produce the logits + quant_z, z_indices = self.encode_to_z(x) + _, c_indices = self.encode_to_c(c) + + cz_indices = tf.concat((c_indices, z_indices), axis=1) + + logits = self.transformer( + cz_indices[:, :-1] # , training=training + ) # don't know why -1 + + # Remove the conditioned part + logits = logits[:, tf.shape(c_indices)[1] - 1 :] # -1 here 'cause -1 above + + logits = tf.reshape(logits, shape=(-1, logits.shape[-1])) + + if target_image is not None: + _, target_indices = self.encode_to_z(target_image) + + target_indices = tf.reshape(target_indices, shape=(-1,)) + + frame_loss = tf.reduce_mean( + self.loss_fn(y_true=target_indices, y_pred=logits) + ) + + image = self.get_image(logits, tf.shape(quant_z)) + + return image, frame_loss + + # @tf.function() + def process_video(self, first_last_frame, target_video=None): + + first_frame = first_last_frame[:, 0] + last_frame = first_last_frame[:, -1] + + x = first_frame + c = last_frame + + total_loss = 0 + generated_video = [x] + + for i in range(19): # TODO change 19 to the number of frame in the video + + if target_video is not None: + + with tf.GradientTape() as tape: + target = target_video[:, i, ...] if target_video is not None else None + generated_image, frame_loss = self.process_image(x, c, target_image=target) + x = generated_image + generated_video.append(generated_image) + + grads = tape.gradient( + frame_loss, + self.transformer.trainable_variables, + ) + self.optimizer.apply_gradients(zip(grads, self.trainable_variables)) + total_loss += frame_loss + + else: + target = target_video[:, i, ...] 
if target_video is not None else None + generated_image, frame_loss = self.process_image(x, c, target_image=target) + x = generated_image + generated_video.append(generated_image) + + if target_video is not None: + return tf.stack(generated_video, axis=1), total_loss + else: + return tf.stack(generated_video, axis=1) + + def train_step(self, data): + + first_last_frame, y = data + + generated_video, loss = self.process_video(first_last_frame, y) + self.loss_tracker.update_state(loss) + + # Log results. + return {m.name: m.result() for m in self.metrics} + + def get_image(self, logits, shape): + probs = tf.keras.activations.softmax(logits) + _, generated_indices = tf.math.top_k(probs) + generated_indices = tf.reshape( + generated_indices, + (-1,), # , self.first_stage_model.quantize.num_embeddings) + ) + quant = self.first_stage_model.quantize.get_codebook_entry( + generated_indices, shape=shape + ) + return self.first_stage_model.decode(quant) + + def test_step(self, data): + + first_last_frame, y = data + + generated_video, loss = self.process_video(first_last_frame, y) + + self.loss_tracker.update_state(loss) + + # Log results. + return {m.name: m.result() for m in self.metrics} + + def decode_to_img(self, index, zshape): + quant_z = self.first_stage_model.quantize.get_codebook_entry( + tf.reshape(index, -1), shape=zshape + ) + x = self.first_stage_model.decode(quant_z) + return x diff --git a/ganime/model/vqgan_clean/transformer/mingpt.py b/ganime/model/vqgan_clean/transformer/mingpt.py new file mode 100644 index 0000000000000000000000000000000000000000..7d4c0b241424a908b1c9700329264761d52ea6ec --- /dev/null +++ b/ganime/model/vqgan_clean/transformer/mingpt.py @@ -0,0 +1,167 @@ +import tensorflow as tf +from tensorflow import keras +from tensorflow.keras import Model, Sequential, layers + + +class TransformerBlock(layers.Layer): + def __init__(self, n_embedding, n_head, attention_percentage_drop): + super().__init__() + self.att = layers.MultiHeadAttention(n_head, n_embedding) + self.ffn = Sequential( + [ + layers.Dense(n_embedding * 4, activation="relu"), + layers.Dense(n_embedding), + ] + ) + self.layernorm1 = layers.LayerNormalization(epsilon=1e-6) + self.layernorm2 = layers.LayerNormalization(epsilon=1e-6) + self.dropout1 = layers.Dropout(attention_percentage_drop) + self.dropout2 = layers.Dropout(attention_percentage_drop) + + def causal_attention_mask(self, batch_size, n_dest, n_src, dtype): + """ + Mask the upper half of the dot product matrix in self attention. + This prevents flow of information from future tokens to current token. + 1's in the lower triangle, counting from the lower right corner. 
+ """ + i = tf.range(n_dest)[:, None] + j = tf.range(n_src) + m = i >= j - n_src + n_dest + mask = tf.cast(m, dtype) + mask = tf.reshape(mask, [1, n_dest, n_src]) + mult = tf.concat( + [tf.expand_dims(batch_size, -1), tf.constant([1, 1], dtype=tf.int32)], 0 + ) + return tf.tile(mask, mult) + + def call(self, inputs): + input_shape = tf.shape(inputs) + batch_size = input_shape[0] + seq_len = input_shape[1] + causal_mask = self.causal_attention_mask(batch_size, seq_len, seq_len, tf.bool) + attention_output = self.att(inputs, inputs, attention_mask=causal_mask) + attention_output = self.dropout1(attention_output) + out1 = self.layernorm1(inputs + attention_output) + ffn_output = self.ffn(out1) + ffn_output = self.dropout2(ffn_output) + return self.layernorm2(out1 + ffn_output) + + +class TransformerBlockV2(layers.Layer): + def __init__(self, n_embedding, n_head, attention_percentage_drop): + super().__init__() + self.att = layers.MultiHeadAttention(n_head, n_embedding) + self.mlp = Sequential( + [ + layers.Dense(n_embedding * 4), + layers.Activation("gelu"), + layers.Dense(n_embedding), + layers.Dropout(attention_percentage_drop), + ] + ) + self.layernorm1 = layers.LayerNormalization(epsilon=1e-6) + self.layernorm2 = layers.LayerNormalization(epsilon=1e-6) + + def causal_attention_mask(self, batch_size, n_dest, n_src, dtype): + """ + Mask the upper half of the dot product matrix in self attention. + This prevents flow of information from future tokens to current token. + 1's in the lower triangle, counting from the lower right corner. + """ + i = tf.range(n_dest)[:, None] + j = tf.range(n_src) + m = i >= j - n_src + n_dest + mask = tf.cast(m, dtype) + mask = tf.reshape(mask, [1, n_dest, n_src]) + mult = tf.concat( + [tf.expand_dims(batch_size, -1), tf.constant([1, 1], dtype=tf.int32)], 0 + ) + return tf.tile(mask, mult) + + def call(self, inputs): + input_shape = tf.shape(inputs) + batch_size = input_shape[0] + seq_len = input_shape[1] + causal_mask = self.causal_attention_mask(batch_size, seq_len, seq_len, tf.bool) + + h = inputs + h = self.layernorm1(h) + h = self.att(h, h, attention_mask=causal_mask) + + h = inputs + h + h = h + self.mlp(self.layernorm2(h)) + return h + + +class TokenAndPositionEmbedding(layers.Layer): + def __init__(self, block_size, vocab_size, n_embedding, embedding_percentage_drop): + super(TokenAndPositionEmbedding, self).__init__() + self.token_emb = layers.Embedding(input_dim=vocab_size, output_dim=n_embedding) + self.pos_emb = layers.Embedding(input_dim=block_size, output_dim=n_embedding) + self.dropout = layers.Dropout(embedding_percentage_drop) + + def call(self, x, training=None, mask=None): + maxlen = tf.shape(x)[-1] + positions = tf.range(start=0, limit=maxlen, delta=1) + positions = self.pos_emb(positions) + x = self.token_emb(x) + return self.dropout(x + positions, training=training) + + +class GPT(Model): + def __init__( + self, + vocab_size, + block_size, + n_layer, + n_head, + n_embedding, + embedding_percentage_drop, + attention_percentage_drop, + ): + super().__init__() + self.block_size = block_size + self.embedding_layer = TokenAndPositionEmbedding( + block_size=block_size, + vocab_size=vocab_size, + n_embedding=n_embedding, + embedding_percentage_drop=embedding_percentage_drop, + ) + self.blocks = [ + TransformerBlock( + n_embedding=n_embedding, + n_head=n_head, + attention_percentage_drop=attention_percentage_drop, + ) + for _ in range(n_layer) + ] + self.layer_norm = layers.LayerNormalization(epsilon=1e-6) + self.outputs = layers.Dense(vocab_size) 
+ + # loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True) + # self.compile( + # "adam", + # loss=loss_fn, + # ) # No loss and optimization based on word embeddings from transformer block + + # def build(self, input_shape): + # self.input_shape = input_shape + + def summary(self): + x = layers.Input(shape=self.input_shape[1:]) + model = Model(inputs=[x], outputs=self.call(x)) + return model.summary() + + def build_graph(self, raw_shape): + x = tf.keras.layers.Input(shape=(raw_shape), ragged=True) + return tf.keras.Model(inputs=[x], outputs=self.call(x)) + + def call(self, inputs, training=True, mask=None): + token_embeddings = self.embedding_layer(inputs) + + h = token_embeddings + for block in self.blocks: + h = block(h) + h = self.layer_norm(h) + logits = self.outputs(h) + return logits diff --git a/ganime/model/vqgan_clean/vqgan.py b/ganime/model/vqgan_clean/vqgan.py new file mode 100644 index 0000000000000000000000000000000000000000..f5b331b3c3c613cf314ac743937bd5c04995bde3 --- /dev/null +++ b/ganime/model/vqgan_clean/vqgan.py @@ -0,0 +1,691 @@ +from typing import List, Optional, Tuple +try: + from typing import Literal +except ImportError: + from typing_extensions import Literal + +import numpy as np +import tensorflow as tf + +from ganime.model.vqgan_clean.losses.losses import Losses +from .discriminator.model import NLayerDiscriminator +from .losses.vqperceptual import PerceptualLoss +from .vqvae.quantize import VectorQuantizer +from .diffusion.encoder import Encoder +from .diffusion.decoder import Decoder +from tensorflow import keras +from tensorflow.keras import layers +from tensorflow.keras.optimizers import Optimizer +from ganime.configs.model_configs import ( + VQVAEConfig, + AutoencoderConfig, + DiscriminatorConfig, + LossConfig, +) + + +@tf.function +def hinge_d_loss(logits_real, logits_fake): + loss_real = tf.reduce_mean(keras.activations.relu(1.0 - logits_real)) + loss_fake = tf.reduce_mean(keras.activations.relu(1.0 + logits_fake)) + d_loss = 0.5 * (loss_real + loss_fake) + return d_loss + + +@tf.function +def vanilla_d_loss(logits_real, logits_fake): + d_loss = 0.5 * ( + tf.reduce_mean(keras.activations.softplus(-logits_real)) + + tf.reduce_mean(keras.activations.softplus(logits_fake)) + ) + return d_loss + + +class VQGAN(keras.Model): + def __init__( + self, + vqvae_config: VQVAEConfig, + autoencoder_config: AutoencoderConfig, + discriminator_config: DiscriminatorConfig, + loss_config: LossConfig, + checkpoint_path: Optional[str] = None, + num_replicas: int = 1, + **kwargs, + ): + """Create a VQ-GAN model. + Args: + vqvae (VQVAEConfig): The configuration of the VQ-VAE + autoencoder (AutoencoderConfig): The configuration of the autoencoder + discriminator (DiscriminatorConfig): The configuration of the discriminator + loss_config (LossConfig): The configuration of the loss + Raises: + ValueError: The specified loss type is not supported. 
+ """ + super().__init__(**kwargs) + self.perceptual_weight = loss_config.vqvae.perceptual_weight + self.codebook_weight = loss_config.vqvae.codebook_weight + self.vqvae_config = vqvae_config + self.autoencoder_config = autoencoder_config + self.discriminator_config = discriminator_config + self.loss_config = loss_config + self.num_replicas = num_replicas + # self.num_embeddings = num_embeddings + # self.embedding_dim = embedding_dim + # self.codebook_weight = codebook_weight + # self.beta = beta + # self.z_channels = z_channels + # self.ae_channels = ae_channels + # self.ae_channels_multiplier = ae_channels_multiplier + # self.ae_num_res_blocks = ae_num_res_blocks + # self.ae_attention_resolution = ae_attention_resolution + # self.ae_resolution = ae_resolution + # self.ae_dropout = ae_dropout + # self.disc_num_layers = disc_num_layers + # self.disc_filters = disc_filters + # self.disc_loss_str = disc_loss + + # Create the encoder - quant_conv - vector quantizer - post quant_conv - decoder + self.encoder = Encoder(**autoencoder_config) + + self.quant_conv = layers.Conv2D( + vqvae_config.embedding_dim, kernel_size=1, name="pre_quant_conv" + ) + + self.quantize = VectorQuantizer( + vqvae_config.num_embeddings, + vqvae_config.embedding_dim, + beta=vqvae_config.beta, + ) + + self.post_quant_conv = layers.Conv2D( + autoencoder_config.z_channels, kernel_size=1, name="post_quant_conv" + ) + + self.decoder = Decoder(**autoencoder_config) + + self.perceptual_loss = self.get_perceptual_loss( + loss_config.perceptual_loss + ) # PerceptualLoss(reduction=tf.keras.losses.Reduction.NONE) + + # Setup discriminator and params + self.discriminator = NLayerDiscriminator( + filters=discriminator_config.filters, + n_layers=discriminator_config.num_layers, + ) + self.discriminator_iter_start = loss_config.discriminator.iter_start + self.disc_loss = self._get_discriminator_loss(loss_config.discriminator.loss) + self.disc_factor = loss_config.discriminator.factor + self.discriminator_weight = loss_config.discriminator.weight + # self.disc_conditional = disc_conditional + + # Setup loss trackers + self.total_loss_tracker = keras.metrics.Mean(name="total_loss") + self.reconstruction_loss_tracker = keras.metrics.Mean( + name="reconstruction_loss" + ) + self.vq_loss_tracker = keras.metrics.Mean(name="vq_loss") + self.disc_loss_tracker = keras.metrics.Mean(name="disc_loss") + + # Setup optimizer (will be given with the compile method) + self.gen_optimizer: Optimizer = None + self.disc_optimizer: Optimizer = None + + self.checkpoint_path = checkpoint_path + + self.cross_entropy = Losses(self.num_replicas).bce_loss + self.reconstruction_loss = self.get_reconstruction_loss("mae") + + def get_perceptual_loss(self, loss_type: str): + if loss_type == "vgg16": + return PerceptualLoss(reduction=tf.keras.losses.Reduction.NONE) + elif loss_type == "vgg19": + return Losses(self.num_replicas).vgg_loss + elif loss_type == "style": + return Losses(self.num_replicas).style_loss + else: + raise ValueError(f"Unknown loss type: {loss_type}") + + def get_reconstruction_loss(self, loss_type: str): + if loss_type == "mse": + return Losses(self.num_replicas).mse_loss + elif loss_type == "mae": + return Losses(self.num_replicas).mae_loss + else: + raise ValueError(f"Unknown loss type: {loss_type}") + + def load_from_checkpoint(self, path): + self.load_weights(path) + + @property + def metrics(self): + # We list our `Metric` objects here so that `reset_states()` can be + # called automatically at the start of each epoch + # or at the start 
of `evaluate()`. + # If you don't implement this property, you have to call + # `reset_states()` yourself at the time of your choosing. + return [ + self.total_loss_tracker, + self.reconstruction_loss_tracker, + self.vq_loss_tracker, + self.disc_loss_tracker, + ] + + # def get_config(self): + # config = super().get_config() + # config.update( + # { + # "train_variance": self.train_variance, + # "vqvae_config": self.vqvae_config, + # # "autoencoder_config": self.autoencoder_config, + # "discriminator_config": self.discriminator_config, + # "loss_config": self.loss_config, + # } + # ) + # return config + + def _get_discriminator_loss(self, disc_loss): + if disc_loss == "hinge": + loss = hinge_d_loss + elif disc_loss == "vanilla": + loss = vanilla_d_loss + else: + raise ValueError(f"Unknown GAN loss '{disc_loss}'.") + + print(f"VQLPIPSWithDiscriminator running with {disc_loss} loss.") + return loss + + def build(self, input_shape): + # Defer the shape initialization + # self.vqvae = self.get_vqvae(input_shape) + + super().build(input_shape) + self.built = True + if self.checkpoint_path is not None: + self.load_from_checkpoint(self.checkpoint_path) + + # def get_vqvae(self, input_shape): + # inputs = keras.Input(shape=input_shape[1:]) + # quant, indices, loss = self.encode(inputs) + # reconstructed = self.decode(quant) + # return keras.Model(inputs, reconstructed, name="vq_vae") + + def encode(self, x): + h = self.encoder(x) + h = self.quant_conv(h) + return self.quantize(h) + + def decode(self, quant): + quant = self.post_quant_conv(quant) + dec = self.decoder(quant) + return dec + + def call(self, inputs, training=True, mask=None): + quantized, encoding_indices, loss = self.encode(inputs) + reconstructed = self.decode(quantized) + return reconstructed, loss + + def predict(self, inputs): + output, loss = self(inputs) + output = (output + 1.0) * 127.5 / 255 + return output + + def calculate_adaptive_weight( + self, + nll_loss: tf.Tensor, + g_loss: tf.Tensor, + tape: tf.GradientTape, + trainable_vars: list, + discriminator_weight: float, + ) -> tf.Tensor: + """Calculate the adaptive weight for the discriminator which prevents mode collapse (https://arxiv.org/abs/2012.03149). + Args: + nll_loss (tf.Tensor): Negative log likelihood loss (the reconstruction loss). + g_loss (tf.Tensor): Generator loss (compared to the discriminator). + tape (tf.GradientTape): Gradient tape used to compute the nll_loss and g_loss + trainable_vars (list): List of trainable vars of the last layer (conv_out of the decoder) + discriminator_weight (float): Weight of the discriminator + Returns: + tf.Tensor: Discriminator weights used for the discriminator loss to benefits best the generator or discriminator and avoiding mode collapse. + """ + nll_grads = tape.gradient(nll_loss, trainable_vars)[0] + g_grads = tape.gradient(g_loss, trainable_vars)[0] + + d_weight = tf.norm(nll_grads) / (tf.norm(g_grads) + 1e-4) + d_weight = tf.stop_gradient(tf.clip_by_value(d_weight, 0.0, 1e4)) + return d_weight * discriminator_weight + + @tf.function + def adapt_weight( + self, weight: float, global_step: int, threshold: int = 0, value: float = 0.0 + ) -> float: + """Adapt the weight depending on the global step. If the global_step is lower than the threshold, the weight is set to value. Used to reduce the weight of the discriminator during the first iterations. + Args: + weight (float): The weight to adapt. 
+ global_step (int): The global step of the optimizer + threshold (int, optional): The threshold under which the weight will be set to `value`. Defaults to 0. + value (float, optional): The value of the weight. Defaults to 0.0. + Returns: + float: The adapted weight + """ + if global_step < threshold: + weight = value + return weight + + def _get_global_step(self, optimizer: Optimizer): + """Get the global step of the optimizer.""" + return optimizer.iterations + + def compile( + self, + gen_optimizer, + disc_optimizer, + ): + super().compile() + self.gen_optimizer = gen_optimizer + self.disc_optimizer = disc_optimizer + + def get_vqvae_trainable_vars(self): + return ( + self.encoder.trainable_variables + + self.quant_conv.trainable_variables + + self.quantize.trainable_variables + + self.post_quant_conv.trainable_variables + + self.decoder.trainable_variables + ) + + # def gradient_penalty(self, real, f): + # def interpolate(a): + # beta = tf.random.uniform(shape=tf.shape(a), minval=0.0, maxval=1.0) + # _, variance = tf.nn.moments(a, list(range(a.shape.ndims))) + # b = a + 0.5 * tf.sqrt(variance) * beta + + # shape = tf.concat( + # (tf.shape(a)[0:1], tf.tile([1], [a.shape.ndims - 1])), axis=0 + # ) + # alpha = tf.random.uniform(shape=shape, minval=0.0, maxval=1.0) + # inter = a + alpha * (b - a) + # inter.set_shape(a.get_shape().as_list()) + + # return inter + + # x = interpolate(real) + # pred = f(x) + # gradients = tf.gradients(pred, x)[0] + # slopes = tf.sqrt( + # tf.reduce_sum(tf.square(gradients), axis=list(range(1, x.shape.ndims))) + # ) + # gp = tf.reduce_mean((slopes - 1.0) ** 2) + # return gp + + # def discriminator_loss(self, real_images, real_output, fake_output, discriminator): + # real_loss = self.cross_entropy( + # tf.ones_like(real_output), real_output + # ) # tf.losses.sigmoid_cross_entropy(tf.ones_like(real_output), real_output) + # fake_loss = self.cross_entropy( + # tf.zeros_like(fake_output), fake_output + # ) # tf.losses.sigmoid_cross_entropy(tf.zeros_like(fake_output), fake_output) + # gp = self.gradient_penalty(real_images, discriminator) + # total_loss = real_loss + fake_loss + 10.0 * gp + # return total_loss + + def generator_loss(self, fake_output): + return self.cross_entropy(tf.ones_like(fake_output), fake_output) + + def discriminator_loss(self, disc_real_output, disc_generated_output): + real_loss = self.cross_entropy(tf.ones_like(disc_real_output), disc_real_output) + + generated_loss = self.cross_entropy( + tf.zeros_like(disc_generated_output), disc_generated_output + ) + + total_disc_loss = real_loss + generated_loss + + return total_disc_loss + + # def train_step(self, data: Tuple[tf.Tensor, tf.Tensor]): + # x, y = data + + # # Train the generator + # with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape: # Gradient tape for the final loss + # reconstructions, quantized_loss = self(x, training=True) + + # logits_fake = self.discriminator(reconstructions, training=True) + # logits_true = self.discriminator(x, training=True) + + # g_loss = self.generator_loss(logits_fake)#-tf.reduce_mean(logits_fake) + # disc_loss = self.discriminator_loss(x, logits_true, logits_fake, self.discriminator) + + # nll_loss = self.perceptual_loss(y, reconstructions) + + # disc_factor = self.adapt_weight( + # weight=self.disc_factor, + # global_step=self._get_global_step(self.gen_optimizer), + # threshold=self.discriminator_iter_start, + # ) + + # total_loss = ( + # self.perceptual_weight * nll_loss + # + disc_factor * g_loss + # + self.codebook_weight * 
quantized_loss + # ) + + # d_loss = disc_factor * disc_loss + + # # Backpropagation. + # gen_grads = gen_tape.gradient(total_loss, self.get_vqvae_trainable_vars()) + # self.gen_optimizer.apply_gradients(zip(gen_grads, self.get_vqvae_trainable_vars())) + + # disc_grads = disc_tape.gradient(d_loss, self.discriminator.trainable_variables) + # self.disc_optimizer.apply_gradients(zip(disc_grads, self.discriminator.trainable_variables)) + + # # Loss tracking. + # self.total_loss_tracker.update_state(total_loss) + # self.reconstruction_loss_tracker.update_state(nll_loss) + # self.vq_loss_tracker.update_state(quantized_loss) + # self.disc_loss_tracker.update_state(d_loss) + + # # Log results. + # return {m.name: m.result() for m in self.metrics} + + def train_step(self, data: Tuple[tf.Tensor, tf.Tensor]): + x, y = data + + # Train the generator + with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape: # Gradient tape for the final loss + with tf.GradientTape( + persistent=True + ) as adaptive_tape: # Gradient tape for the adaptive weights + reconstructions, quantized_loss = self(x, training=True) + + disc_real_input = tf.image.resize( + x, [256, 512], method=tf.image.ResizeMethod.NEAREST_NEIGHBOR + ) + disc_gen_input = tf.image.resize( + reconstructions, + [256, 512], + method=tf.image.ResizeMethod.NEAREST_NEIGHBOR, + ) + logits_real = self.discriminator( + (disc_real_input, disc_real_input), + training=True, + ) + logits_fake = self.discriminator( + (disc_real_input, disc_gen_input), + training=True, + ) + + reconstruction_loss = self.reconstruction_loss(y, reconstructions) + if self.perceptual_weight > 0.0: + perceptual_loss = self.perceptual_weight * self.perceptual_loss( + y, reconstructions + ) + else: + perceptual_loss = 0.0 + + nll_loss = reconstruction_loss + perceptual_loss + + g_loss = -tf.reduce_mean(logits_fake) + # g_loss = self.generator_loss(logits_fake) + + d_weight = self.calculate_adaptive_weight( + nll_loss, + g_loss, + adaptive_tape, + self.decoder.conv_out.trainable_variables, + self.discriminator_weight, + ) + del adaptive_tape # Since persistent tape, important to delete it + # d_weight = 1.0 + + disc_factor = self.adapt_weight( + weight=self.disc_factor, + global_step=self._get_global_step(self.gen_optimizer), + threshold=self.discriminator_iter_start, + ) + + total_loss = ( + nll_loss + + d_weight * disc_factor * g_loss + + self.codebook_weight * quantized_loss + ) + + d_loss = disc_factor * self.disc_loss(logits_real, logits_fake) + # d_loss = disc_factor * self.discriminator_loss( + # tf.concat((disc_real_input, disc_real_input), axis=-1), + # logits_real, + # logits_fake, + # self.discriminator, + # ) + # d_loss = disc_factor * self.discriminator_loss( + # logits_real, + # logits_fake, + # ) + + # Backpropagation. + grads = gen_tape.gradient(total_loss, self.get_vqvae_trainable_vars()) + self.gen_optimizer.apply_gradients(zip(grads, self.get_vqvae_trainable_vars())) + + # Backpropagation. + disc_grads = disc_tape.gradient(d_loss, self.discriminator.trainable_variables) + self.disc_optimizer.apply_gradients( + zip(disc_grads, self.discriminator.trainable_variables) + ) + + # Loss tracking. + self.total_loss_tracker.update_state(total_loss) + self.reconstruction_loss_tracker.update_state(nll_loss) + self.vq_loss_tracker.update_state(quantized_loss) + self.disc_loss_tracker.update_state(d_loss) + + # Log results. 
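# Aside: a toy, self-contained sketch (not part of this patch) of the adaptive
# discriminator weight computed in `train_step` above: the ratio of the gradient
# norms of the reconstruction (NLL) loss and the generator loss w.r.t. the
# decoder's last layer decides how strongly the adversarial term is weighted.
# The variable below is a stand-in for the real conv_out weights.
import tensorflow as tf

last_layer = tf.Variable(tf.random.normal((3, 3)))
with tf.GradientTape(persistent=True) as tape:
    nll_loss = tf.reduce_sum(last_layer ** 2)      # stands in for the reconstruction loss
    g_loss = tf.reduce_sum(tf.abs(last_layer))     # stands in for -mean(logits_fake)
nll_grads = tape.gradient(nll_loss, [last_layer])[0]
g_grads = tape.gradient(g_loss, [last_layer])[0]
del tape  # persistent tape must be deleted explicitly

d_weight = tf.norm(nll_grads) / (tf.norm(g_grads) + 1e-4)
d_weight = tf.stop_gradient(tf.clip_by_value(d_weight, 0.0, 1e4))
print(float(d_weight))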
+ return {m.name: m.result() for m in self.metrics} + + # def train_step(self, data: Tuple[tf.Tensor, tf.Tensor]): + # x, y = data + + # # Train the generator + # with tf.GradientTape() as tape: # Gradient tape for the final loss + # with tf.GradientTape( + # persistent=True + # ) as adaptive_tape: # Gradient tape for the adaptive weights + # reconstructions, quantized_loss = self(x, training=True) + + # logits_fake = self.discriminator(reconstructions, training=False) + + # g_loss = -tf.reduce_mean(logits_fake) + # nll_loss = self.perceptual_loss(y, reconstructions) + + # d_weight = self.calculate_adaptive_weight( + # nll_loss, + # g_loss, + # adaptive_tape, + # self.decoder.conv_out.trainable_variables, + # self.discriminator_weight, + # ) + # del adaptive_tape # Since persistent tape, important to delete it + + # disc_factor = self.adapt_weight( + # weight=self.disc_factor, + # global_step=self._get_global_step(self.gen_optimizer), + # threshold=self.discriminator_iter_start, + # ) + + # total_loss = ( + # self.perceptual_weight * nll_loss + # + d_weight * disc_factor * g_loss + # + self.codebook_weight * quantized_loss + # ) + + # # total_loss = ( + # # nll_loss + # # + d_weight * disc_factor * g_loss + # # # + self.codebook_weight * tf.reduce_mean(self.vqvae.losses) + # # + self.codebook_weight * sum(self.vqvae.losses) + # # ) + + # # Backpropagation. + # grads = tape.gradient(total_loss, self.get_vqvae_trainable_vars()) + # self.gen_optimizer.apply_gradients(zip(grads, self.get_vqvae_trainable_vars())) + + # # Discriminator + # with tf.GradientTape() as disc_tape: + # logits_real = self.discriminator(y, training=True) + # logits_fake = self.discriminator(reconstructions, training=True) + + # disc_factor = self.adapt_weight( + # weight=self.disc_factor, + # global_step=self._get_global_step(self.disc_optimizer), + # threshold=self.discriminator_iter_start, + # ) + # d_loss = disc_factor * self.disc_loss(logits_real, logits_fake) + + # # Backpropagation. + # disc_grads = disc_tape.gradient(d_loss, self.discriminator.trainable_variables) + # self.disc_optimizer.apply_gradients( + # zip(disc_grads, self.discriminator.trainable_variables) + # ) + + # # Loss tracking. + # self.total_loss_tracker.update_state(total_loss) + # self.reconstruction_loss_tracker.update_state(nll_loss) + # self.vq_loss_tracker.update_state(quantized_loss) + # self.disc_loss_tracker.update_state(d_loss) + + # # Log results. 
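# A minimal sketch, assuming the usual VQGAN formulation: `train_step` and `test_step`
# above call `self.calculate_adaptive_weight(...)` and `self.disc_loss(...)`, which are
# defined outside this hunk. The two helpers below show how such implementations commonly
# define them; the epsilon, the clipping range and the function names here are assumptions,
# not code taken from this repository.
import tensorflow as tf


def calculate_adaptive_weight_sketch(
    nll_loss, g_loss, tape, last_layer_vars, discriminator_weight=1.0
):
    """Balance reconstruction and adversarial gradients at the last decoder layer."""
    # Both gradients must come from the same persistent tape that recorded the
    # forward pass (the `adaptive_tape` used in `train_step`).
    nll_grads = tape.gradient(nll_loss, last_layer_vars)[0]
    g_grads = tape.gradient(g_loss, last_layer_vars)[0]
    d_weight = tf.norm(nll_grads) / (tf.norm(g_grads) + 1e-4)
    d_weight = tf.clip_by_value(d_weight, 0.0, 1e4)
    # Treat the weight as a constant so it does not contribute gradients itself.
    return tf.stop_gradient(d_weight) * discriminator_weight


def hinge_d_loss_sketch(logits_real, logits_fake):
    """Hinge discriminator loss, the usual companion of g_loss = -mean(logits_fake)."""
    loss_real = tf.reduce_mean(tf.nn.relu(1.0 - logits_real))
    loss_fake = tf.reduce_mean(tf.nn.relu(1.0 + logits_fake))
    return 0.5 * (loss_real + loss_fake)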
+ # return {m.name: m.result() for m in self.metrics} + + # def test_step(self, data: Tuple[tf.Tensor, tf.Tensor]): + # x, y = data + + # # Train the generator + # with tf.GradientTape( + # persistent=True + # ) as adaptive_tape: # Gradient tape for the adaptive weights + # reconstructions, quantized_loss = self(x, training=False) + + # logits_fake = self.discriminator(reconstructions, training=False) + + # g_loss = -tf.reduce_mean(logits_fake) + # nll_loss = self.perceptual_loss(y, reconstructions) + + # d_weight = self.calculate_adaptive_weight( + # nll_loss, + # g_loss, + # adaptive_tape, + # self.decoder.conv_out.trainable_variables, + # self.discriminator_weight, + # ) + # del adaptive_tape # Since persistent tape, important to delete it + + # disc_factor = self.adapt_weight( + # weight=self.disc_factor, + # global_step=self._get_global_step(self.gen_optimizer), + # threshold=self.discriminator_iter_start, + # ) + + # total_loss = ( + # nll_loss + # + d_weight * disc_factor * g_loss + # # + self.codebook_weight * tf.reduce_mean(self.vqvae.losses) + # + self.codebook_weight * quantized_loss + # ) + + # # Discriminator + # logits_real = self.discriminator(y, training=False) + # logits_fake = self.discriminator(reconstructions, training=False) + + # disc_factor = self.adapt_weight( + # weight=self.disc_factor, + # global_step=self._get_global_step(self.disc_optimizer), + # threshold=self.discriminator_iter_start, + # ) + # d_loss = disc_factor * self.disc_loss(logits_real, logits_fake) + + # # Loss tracking. + # self.total_loss_tracker.update_state(total_loss) + # self.reconstruction_loss_tracker.update_state(nll_loss) + # self.vq_loss_tracker.update_state(quantized_loss) + # self.disc_loss_tracker.update_state(d_loss) + + # # Log results. + # return {m.name: m.result() for m in self.metrics} + + def test_step(self, data: Tuple[tf.Tensor, tf.Tensor]): + x, y = data + + with tf.GradientTape( + persistent=True + ) as adaptive_tape: # Gradient tape for the adaptive weights + reconstructions, quantized_loss = self(x, training=False) + + disc_real_input = tf.image.resize( + x, [256, 512], method=tf.image.ResizeMethod.NEAREST_NEIGHBOR + ) + disc_gen_input = tf.image.resize( + reconstructions, + [256, 512], + method=tf.image.ResizeMethod.NEAREST_NEIGHBOR, + ) + logits_real = self.discriminator( + (disc_real_input, disc_real_input), + training=False, + ) + logits_fake = self.discriminator( + (disc_real_input, disc_gen_input), + training=False, + ) + + reconstruction_loss = self.reconstruction_loss(y, reconstructions) + if self.perceptual_weight > 0.0: + perceptual_loss = self.perceptual_weight * self.perceptual_loss( + y, reconstructions + ) + else: + perceptual_loss = 0.0 + + nll_loss = reconstruction_loss + perceptual_loss + g_loss = -tf.reduce_mean(logits_fake) + # g_loss = self.generator_loss(logits_fake) + + d_weight = self.calculate_adaptive_weight( + nll_loss, + g_loss, + adaptive_tape, + self.decoder.conv_out.trainable_variables, + self.discriminator_weight, + ) + del adaptive_tape # Since persistent tape, important to delete it + # d_weight = 1.0 + + disc_factor = self.adapt_weight( + weight=self.disc_factor, + global_step=self._get_global_step(self.gen_optimizer), + threshold=self.discriminator_iter_start, + ) + + total_loss = ( + nll_loss + + d_weight * disc_factor * g_loss + + self.codebook_weight * quantized_loss + ) + + d_loss = disc_factor * self.disc_loss(logits_real, logits_fake) + # d_loss = disc_factor * self.discriminator_loss( + # tf.concat((disc_real_input, 
disc_real_input), axis=-1), + # logits_real, + # logits_fake, + # self.discriminator, + # ) + # d_loss = disc_factor * self.discriminator_loss( + # logits_real, + # logits_fake, + # ) + + # Loss tracking. + self.total_loss_tracker.update_state(total_loss) + self.reconstruction_loss_tracker.update_state(nll_loss) + self.vq_loss_tracker.update_state(quantized_loss) + self.disc_loss_tracker.update_state(d_loss) + + # Log results. + return {m.name: m.result() for m in self.metrics} diff --git a/ganime/model/vqgan_clean/vqvae/__init__.py b/ganime/model/vqgan_clean/vqvae/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/ganime/model/vqgan_clean/vqvae/quantize.py b/ganime/model/vqgan_clean/vqvae/quantize.py new file mode 100644 index 0000000000000000000000000000000000000000..ecc739295910de52dd177c0b1ceab0cdb826db14 --- /dev/null +++ b/ganime/model/vqgan_clean/vqvae/quantize.py @@ -0,0 +1,79 @@ +import tensorflow as tf +from tensorflow.keras import layers + + +@tf.keras.utils.register_keras_serializable() +class VectorQuantizer(layers.Layer): + def __init__(self, num_embeddings, embedding_dim, beta=0.25, **kwargs): + super().__init__(**kwargs) + self.embedding_dim = embedding_dim + self.num_embeddings = num_embeddings + # This parameter is best kept between [0.25, 2] as per the paper. + self.beta = beta + + # Initialize the embeddings which we will quantize. + w_init = tf.random_uniform_initializer() + self.embeddings = tf.Variable( + initial_value=w_init( + shape=(self.embedding_dim, self.num_embeddings), dtype="float32" + ), + trainable=True, + name="embeddings_vqvae", + ) + + def get_config(self): + config = super().get_config() + config.update( + { + "embedding_dim": self.embedding_dim, + "num_embeddings": self.num_embeddings, + "beta": self.beta, + } + ) + return config + + def call(self, x): + # Calculate the input shape of the inputs and + # then flatten the inputs keeping `embedding_dim` intact. + input_shape = tf.shape(x) + flattened = tf.reshape(x, [-1, self.embedding_dim]) + + # Quantization. + encoding_indices = self.get_code_indices(flattened) + encodings = tf.one_hot(encoding_indices, self.num_embeddings) + quantized = tf.matmul(encodings, self.embeddings, transpose_b=True) + quantized = tf.reshape(quantized, input_shape) + + # Calculate vector quantization loss and add that to the layer. You can learn more + # about adding losses to different layers here: + # https://keras.io/guides/making_new_layers_and_models_via_subclassing/. Check + # the original paper to get a handle on the formulation of the loss function. + commitment_loss = self.beta * tf.reduce_mean( + (tf.stop_gradient(quantized) - x) ** 2 + ) + codebook_loss = tf.reduce_mean((quantized - tf.stop_gradient(x)) ** 2) + loss = commitment_loss + codebook_loss + # self.add_loss(commitment_loss + codebook_loss) + + # Straight-through estimator. + quantized = x + tf.stop_gradient(quantized - x) + return quantized, encoding_indices, loss + + def get_code_indices(self, flattened_inputs): + # Calculate L2-normalized distance between the inputs and the codes. + similarity = tf.matmul(flattened_inputs, self.embeddings) + distances = ( + tf.reduce_sum(flattened_inputs**2, axis=1, keepdims=True) + + tf.reduce_sum(self.embeddings**2, axis=0) + - 2 * similarity + ) + + # Derive the indices for minimum distances. 
+ encoding_indices = tf.argmin(distances, axis=1) + return encoding_indices + + def get_codebook_entry(self, indices, shape): + encodings = tf.one_hot(indices, self.num_embeddings) + quantized = tf.matmul(encodings, self.embeddings, transpose_b=True) + quantized = tf.reshape(quantized, shape) + return quantized \ No newline at end of file diff --git a/ganime/trainer/__init__.py b/ganime/trainer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/ganime/trainer/ganime.py b/ganime/trainer/ganime.py new file mode 100644 index 0000000000000000000000000000000000000000..fc66a67a5b3891cfec8b7222db153b4c011192ba --- /dev/null +++ b/ganime/trainer/ganime.py @@ -0,0 +1,86 @@ +import os +import numpy as np +from omegaconf import OmegaConf +import omegaconf +from ray.tune import Trainable +from ganime.model.base import load_model +from ganime.data.base import load_dataset +import tensorflow as tf + +from ganime.utils.callbacks import TensorboardImage + + +class TrainableGANime(Trainable): + def setup(self, config): + strategy = tf.distribute.MirroredStrategy() + + tune_config = self.load_config_file_and_replace(config) + self.batch_size = tune_config["trainer"]["batch_size"] + + self.n_devices = strategy.num_replicas_in_sync + self.global_batch_size = self.batch_size * self.n_devices + + self.train_dataset, self.validation_dataset, self.test_dataset = load_dataset( + dataset_name=config["dataset_name"], + dataset_path=config["dataset_path"], + batch_size=self.global_batch_size, + ) + + self.model = load_model(config["model"], config=tune_config, strategy=strategy) + + for data in self.train_dataset.take(1): + train_sample_data = data + for data in self.validation_dataset.take(1): + validation_sample_data = data + + tensorboard_image_callback = TensorboardImage( + self.logdir, train_sample_data, validation_sample_data + ) + checkpointing = tf.keras.callbacks.ModelCheckpoint( + os.path.join(self.logdir, "checkpoint", "checkpoint"), + monitor="total_loss", + save_best_only=True, + save_weights_only=True, + ) + self.callbacks = [tensorboard_image_callback, checkpointing] + + def load_config_file_and_replace(self, config): + cfg = OmegaConf.load(config["config_file"]) + hyperparameters = config["hyperparameters"] + + for hp_key, hp_value in hyperparameters.items(): + cfg = self.replace_item(cfg, hp_key, hp_value) + return cfg + + def replace_item(self, obj, key, replace_value): + for k, v in obj.items(): + if isinstance(v, dict) or isinstance(v, omegaconf.dictconfig.DictConfig): + obj[k] = self.replace_item(v, key, replace_value) + if key in obj: + obj[key] = replace_value + return obj + + def step(self): + + self.model.fit( + self.train_dataset, + initial_epoch=self.training_iteration, + epochs=self.training_iteration + 1, + callbacks=self.callbacks, + verbose=0, + ) + scores = self.model.evaluate(self.validation_dataset, verbose=0) + if np.nan in scores: + self.stop() + return dict(zip(self.model.metrics_names, scores)) + + def save_checkpoint(self, tmp_checkpoint_dir): + # checkpoint_path = os.path.join(tmp_checkpoint_dir, "model.pth") + # torch.save(self.model.state_dict(), checkpoint_path) + # return tmp_checkpoint_dir + pass + + def load_checkpoint(self, tmp_checkpoint_dir): + # checkpoint_path = os.path.join(tmp_checkpoint_dir, "model.pth") + # self.model.load_state_dict(torch.load(checkpoint_path)) + pass diff --git a/ganime/trainer/warmup/__init__.py b/ganime/trainer/warmup/__init__.py new file mode 100644 index 
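# A minimal usage sketch, assuming the Ray Tune class-based Trainable API that
# `TrainableGANime` implements. The config keys mirror the ones read in `setup()`;
# the paths, dataset name, model key and the searched hyperparameter are
# placeholders, not values taken from this repository.
from ray import tune

from ganime.trainer.ganime import TrainableGANime

analysis = tune.run(
    TrainableGANime,
    config={
        "model": "vqgan",                       # forwarded to load_model() (placeholder key)
        "dataset_name": "my_dataset",           # placeholder
        "dataset_path": "./data",               # placeholder
        "config_file": "configs/example.yaml",  # placeholder OmegaConf file
        # Each entry below is substituted into the loaded config by
        # `load_config_file_and_replace` / `replace_item`.
        "hyperparameters": {"lr_max": tune.grid_search([1e-4, 3e-4])},
    },
    resources_per_trial={"gpu": 1},
    stop={"training_iteration": 50},
)
print(analysis.get_best_config(metric="total_loss", mode="min"))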
0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/ganime/trainer/warmup/base.py b/ganime/trainer/warmup/base.py new file mode 100644 index 0000000000000000000000000000000000000000..afc8f4a72bb9606b340600583b0550cd7816ce2b --- /dev/null +++ b/ganime/trainer/warmup/base.py @@ -0,0 +1,20 @@ +from ganime.trainer.warmup.cosine import WarmUpCosine + + +def create_warmup_scheduler(trainer_config, num_devices): + len_x_train = trainer_config["len_x_train"] + batch_size = trainer_config["batch_size"] + n_epochs = trainer_config["n_epochs"] + + total_steps = int(len_x_train / batch_size * n_epochs / num_devices) + warmup_epoch_percentage = trainer_config["warmup_epoch_percentage"] + warmup_steps = int(total_steps * warmup_epoch_percentage) + + scheduled_lrs = WarmUpCosine( + lr_start=trainer_config["lr_start"], + lr_max=trainer_config["lr_max"], + warmup_steps=warmup_steps, + total_steps=total_steps, + ) + + return scheduled_lrs
diff --git a/ganime/trainer/warmup/cosine.py b/ganime/trainer/warmup/cosine.py new file mode 100644 index 0000000000000000000000000000000000000000..45d608566cb8f2f29237bf8acee408d71a2f85dd --- /dev/null +++ b/ganime/trainer/warmup/cosine.py @@ -0,0 +1,78 @@ +import numpy as np +import tensorflow as tf +from tensorflow import keras + + +class WarmUpCosine(keras.optimizers.schedules.LearningRateSchedule): + """A LearningRateSchedule that uses a warmup cosine decay schedule.""" + + def __init__(self, lr_start, lr_max, warmup_steps, total_steps): + """ + Args: + lr_start: The initial learning rate + lr_max: The maximum learning rate to which lr should increase in + the warmup steps + warmup_steps: The number of steps for which the model warms up + total_steps: The total number of steps for the model training + """ + super().__init__() + self.lr_start = lr_start + self.lr_max = lr_max + self.warmup_steps = warmup_steps + self.total_steps = total_steps + self.pi = tf.constant(np.pi) + + def __call__(self, step): + # Check whether the total number of steps is larger than the warmup + # steps. If not, then throw a value error. + if self.total_steps < self.warmup_steps: + raise ValueError( + f"Total number of steps {self.total_steps} must be " + + f"larger than or equal to warmup steps {self.warmup_steps}." + ) + + # `cos_annealed_lr` is a graph that increases to 1 from the initial + # step to the warmup step. After that this graph decays to -1 at the + # final step mark. + cos_annealed_lr = tf.cos( + self.pi + * (tf.cast(step, tf.float32) - self.warmup_steps) + / tf.cast(self.total_steps - self.warmup_steps, tf.float32) + ) + + # Shift the mean of the `cos_annealed_lr` graph to 1. Now the graph goes + # from 0 to 2. Normalize the graph with 0.5 so that now it goes from 0 + # to 1. With the normalized graph we scale it with `lr_max` such that + # it goes from 0 to `lr_max`. + learning_rate = 0.5 * self.lr_max * (1 + cos_annealed_lr) + + # Check whether warmup_steps is more than 0. + if self.warmup_steps > 0: + # Check whether lr_max is larger than lr_start. If not, throw a value + # error. + if self.lr_max < self.lr_start: + raise ValueError( + f"lr_start {self.lr_start} must be smaller than or " + + f"equal to lr_max {self.lr_max}." + ) + + # Calculate the slope with which the learning rate should increase + # in the warmup schedule. The formula for slope is m = ((b-a)/steps) + slope = (self.lr_max - self.lr_start) / self.warmup_steps + + # With the formula for a straight line (y = mx+c) build the warmup + # schedule + warmup_rate = slope * tf.cast(step, tf.float32) + self.lr_start + + # When the current step is less than the warmup steps, get the line + # graph. When the current step is greater than the warmup steps, get + # the scaled cos graph. + learning_rate = tf.where( + step < self.warmup_steps, warmup_rate, learning_rate + ) + + # When the current step is more than the total steps, return 0, else return + # the calculated graph. + return tf.where( + step > self.total_steps, 0.0, learning_rate, name="learning_rate" + )
diff --git a/ganime/utils/callbacks.py b/ganime/utils/callbacks.py new file mode 100644 index 0000000000000000000000000000000000000000..2235207ab9d55d18ba53d62ae21170d4eb3643f0 --- /dev/null +++ b/ganime/utils/callbacks.py @@ -0,0 +1,224 @@ +import io +import os +from datetime import datetime +from typing import Optional + +import matplotlib.pyplot as plt +import numpy as np +import tensorflow as tf +from ganime.visualization.images import display_true_pred, unnormalize_if_necessary + + +def get_logdir(parent_folder: str, experiment_name: Optional[str] = None) -> str: + """Get the logdir used for logging in TensorBoard. The logdir is the parent folder joined with the experiment name and the current date and time. + + Args: + parent_folder (str): The parent folder of the logdir + experiment_name (str, optional): Optional name of the experiment. Defaults to None. + + Returns: + str: The path of the logdir that can be used by TensorBoard + """ + current_time = datetime.now().strftime("%Y-%m-%d_%H-%M-%S") + sub_folder = ( + f"{experiment_name}_{current_time}" if experiment_name else current_time + ) + logdir = os.path.join(parent_folder, sub_folder) + return logdir + + +def plot_to_image(figure): + """Converts the matplotlib plot specified by 'figure' to a PNG image and + returns it. The supplied figure is closed and inaccessible after this call.""" + # Save the plot to a PNG in memory. + buf = io.BytesIO() + plt.savefig(buf, format="png") + # Closing the figure prevents it from being displayed directly inside + # the notebook.
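# A minimal sketch of how the warmup-cosine schedule defined above is typically
# handed to a Keras optimizer. Only the config keys come from
# `create_warmup_scheduler`; the numbers below are placeholders.
import tensorflow as tf

from ganime.trainer.warmup.base import create_warmup_scheduler

trainer_config = {
    "len_x_train": 10000,             # number of training samples (placeholder)
    "batch_size": 16,
    "n_epochs": 100,
    "warmup_epoch_percentage": 0.15,  # fraction of steps spent warming up (placeholder)
    "lr_start": 1e-5,
    "lr_max": 3e-4,
}
scheduled_lrs = create_warmup_scheduler(trainer_config, num_devices=1)

# A LearningRateSchedule is evaluated at every optimizer step, so it can be
# passed directly as the learning rate.
gen_optimizer = tf.keras.optimizers.Adam(learning_rate=scheduled_lrs)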
+ plt.close(figure) + buf.seek(0) + # Convert PNG buffer to TF image + image = tf.image.decode_png(buf.getvalue(), channels=4) + # Add the batch dimension + image = tf.expand_dims(image, 0) + return image + + +class TensorboardImage(tf.keras.callbacks.Callback): + def __init__( + self, + logdir: str, + train: np.array, + validation: np.array = None, + n_images: int = 8, + ): + super().__init__() + self.logdir = logdir + self.train = train + self.validation = validation + self.file_writer = tf.summary.create_file_writer(logdir) + self.n_images = n_images + + def on_epoch_end(self, epoch, logs): + train_X, train_y = self.train + train_X, train_y = self.truncate_X_y(train_X, train_y, self.n_images) + train_pred = self.model.predict(train_X) + self.write_to_tensorboard(train_y, train_pred, "Training data", epoch) + + if self.validation is not None: + validation_X, validation_y = self.validation + validation_X, validation_y = self.truncate_X_y( + validation_X, validation_y, self.n_images + ) + validation_pred = self.model.predict(validation_X) + self.write_to_tensorboard( + validation_y, validation_pred, "Validation data", epoch + ) + + def truncate_X_y(self, X, y, n_images): + """Truncate the X and y arrays to the first n_images.""" + X = X[:n_images] + y = y[:n_images] + return X, y + + def write_to_tensorboard(self, y_true, y_pred, tag, step): + with self.file_writer.as_default(): + tf.summary.image( + tag, + plot_to_image(display_true_pred(y_true, y_pred, n_cols=len(y_true))), + step=step, + ) + + +class TensorboardVideo(tf.keras.callbacks.Callback): + def __init__( + self, + logdir: str, + train: np.array, + validation: np.array = None, + n_videos: int = 3, + ): + super().__init__() + self.logdir = logdir + self.train = train + self.validation = validation + self.file_writer = tf.summary.create_file_writer(logdir) + self.n_videos = n_videos + + def on_epoch_end(self, epoch, logs): + + # train_X, train_y = self.train + # train_X, train_y = self.truncate_X_y(train_X, train_y, self.n_videos) + train_pred = self.model.predict(self.train) + self.write_to_tensorboard( + unnormalize_if_necessary(self.train["y"]), + train_pred, + "Training data", + epoch, + ) + + if self.validation is not None: + # validation_X, validation_y = self.validation + # validation_X, validation_y = self.truncate_X_y( + # validation_X, validation_y, self.n_videos + # ) + validation_pred = self.model.predict(self.validation) + self.write_to_tensorboard( + unnormalize_if_necessary(self.validation["y"]), + validation_pred, + "Validation data", + epoch, + ) + + def truncate_X_y(self, X, y, n_videos): + """Truncate the X and y arrays to the first n_videos.""" + X = X[:n_videos] + y = y[:n_videos] + return X, y + + def write_to_tensorboard(self, y_true, y_pred, tag, step): + stacked = tf.concat([y_pred, y_true], axis=2) + self.video_summary(tag, stacked, step) + self.image_summary(tag + "/images", y_true, y_pred, step) + + def image_summary(self, tag, y_true, y_pred, step): + batch, n_frames, height, width, channels = y_true.shape + images = np.empty( + (batch * 2, n_frames, height, width, channels), dtype=np.float32 + ) + + images[0::2] = y_pred + images[1::2] = y_true + images = tf.transpose(images, (0, 2, 1, 3, 4)) + images = tf.reshape(images, (height * batch * 2, width * n_frames, channels)) + + with self.file_writer.as_default(): + tf.summary.image(tag, [images], step=step) + + def add_red_border(self, image_batch): + image_batch = image_batch.copy() + dtype = image_batch.dtype + min_value = 0 + max_value = 1 if dtype in 
[np.float16, np.float32, np.float64] else 255 + # top + image_batch[:, 0:2, :, 0] = max_value + image_batch[:, 0:2, :, 1] = min_value + image_batch[:, 0:2, :, 2] = min_value + # bottom + image_batch[:, -2:, :, 0] = max_value + image_batch[:, -2:, :, 1] = min_value + image_batch[:, -2:, :, 2] = min_value + # left + image_batch[:, :, 0:2, 0] = max_value + image_batch[:, :, 0:2, 1] = min_value + image_batch[:, :, 0:2, 2] = min_value + # right + image_batch[:, :, -2:, 0] = max_value + image_batch[:, :, -2:, 1] = min_value + image_batch[:, :, -2:, 2] = min_value + return image_batch + + def video_summary(self, name, video, step=None, fps=10): + name = tf.constant(name).numpy().decode("utf-8") + video = np.array(video) + if video.dtype in (np.float32, np.float64): + video = np.clip(255 * video, 0, 255).astype(np.uint8) + B, T, H, W, C = video.shape + # video[:, 0] = self.add_red_border(video[:, 0]) + + with self.file_writer.as_default(): + try: + frames = video.transpose((1, 2, 0, 3, 4)).reshape((T, H, B * W, C)) + summary = tf.compat.v1.Summary() + image = tf.compat.v1.Summary.Image( + height=B * H, width=T * W, colorspace=C + ) + image.encoded_image_string = self.encode_gif(frames, fps) + summary.value.add(tag=name + "/gif", image=image) + tf.summary.experimental.write_raw_pb(summary.SerializeToString(), step) + except (IOError, OSError) as e: + print("GIF summaries require ffmpeg in $PATH.", e) + frames = video.transpose((0, 2, 1, 3, 4)).reshape((1, B * H, T * W, C)) + tf.summary.image(name + "/grid", frames, step) + + def encode_gif(self, frames, fps): + from subprocess import PIPE, Popen + + h, w, c = frames[0].shape + pxfmt = {1: "gray", 3: "rgb24"}[c] + cmd = " ".join( + [ + f"ffmpeg -y -f rawvideo -vcodec rawvideo", + f"-r {fps:.02f} -s {w}x{h} -pix_fmt {pxfmt} -i - -filter_complex", + f"[0:v]split[x][z];[z]palettegen[y];[x]fifo[x];[x][y]paletteuse", + f"-r {fps:.02f} -f gif -", + ] + ) + proc = Popen(cmd.split(" "), stdin=PIPE, stdout=PIPE, stderr=PIPE) + for image in frames: + proc.stdin.write(image.tostring()) + out, err = proc.communicate() + if proc.returncode: + raise IOError("\n".join([" ".join(cmd), err.decode("utf8")])) + del proc + return out diff --git a/ganime/utils/recompute_grad.py b/ganime/utils/recompute_grad.py new file mode 100644 index 0000000000000000000000000000000000000000..22ec1ed3a4acf79fef1470f00e1a10ac3373065b --- /dev/null +++ b/ganime/utils/recompute_grad.py @@ -0,0 +1,238 @@ +# coding=utf-8 +# Copyright 2022 The Google Research Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Library for rematerialization. +Incubates a version of tf.recompute_grad that is XLA compatible. 
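# A minimal sketch of how the callbacks defined in ganime/utils/callbacks.py are
# typically wired into training (`TrainableGANime.setup` follows the same pattern).
# `train_dataset`, `validation_dataset` and `model` are placeholders for whatever
# `load_dataset` / `load_model` return, and the experiment name is arbitrary.
import os
import tensorflow as tf

from ganime.utils.callbacks import TensorboardImage, get_logdir

logdir = get_logdir("logs", experiment_name="vqgan_baseline")  # placeholder name

# One (x, y) batch is enough for the image callback.
train_sample = next(iter(train_dataset))            # placeholder dataset
validation_sample = next(iter(validation_dataset))  # placeholder dataset

callbacks = [
    TensorboardImage(logdir, train_sample, validation_sample, n_images=8),
    tf.keras.callbacks.ModelCheckpoint(
        os.path.join(logdir, "checkpoint", "checkpoint"),
        monitor="total_loss",
        save_best_only=True,
        save_weights_only=True,
    ),
]

model.fit(train_dataset, validation_data=validation_dataset, callbacks=callbacks)  # placeholder model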
+""" +import collections +import os +import threading +from typing import Deque, List, NamedTuple, Optional, Sequence + +from absl import logging +import numpy as np +import tensorflow as tf + + +class RecomputeContext( + NamedTuple( + "RecomputeContext", + [ + ("is_recomputing", bool), + ("seed", tf.Tensor), + ("children", Deque["RecomputeContext"]), + ], + ) +): + """Context for recomputation. + Attributes: + is_recomputing: Whether we are in a recomputation phase. + seed: Scalar integer tensor that should be used with stateless random ops + for deterministic behavior and correct computation of the gradient. + children: Nested `RecomputeContext` instances. Used internally by + `recompute_grad` to track nested instances of `RecomputeContext`. + """ + + def __enter__(self): + return _context_stack.push(self) + + def __exit__(self, exc_type, exc_value, traceback): + _context_stack.pop(self) + + +# Simplified version of `_DefaultStack` in +# https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/framework/ops.py. +class _ContextStack(threading.local): + """A thread-local stack for providing implicit recompute contexts.""" + + def __init__(self): + super(_ContextStack, self).__init__() + self._stack = [] + + def top(self) -> Optional[RecomputeContext]: + return self._stack[-1] if self._stack else None + + def push(self, context: RecomputeContext): + self._stack.append(context) + return context + + def pop(self, context: RecomputeContext): + if self._stack[-1] is not context: + raise AssertionError("Nesting violated for RecomputeContext.") + self._stack.pop() + + +_context_stack = _ContextStack() + + +def get_recompute_context() -> Optional[RecomputeContext]: + """Returns the current recomputing context if it exists.""" + return _context_stack.top() + + +# Adapted from +# https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/ops/control_flow_util.py. +def _get_containing_xla_context(graph: tf.Graph) -> Optional[object]: + """Returns the first ancestor `XLAControlFlowContext` in the `graph`.""" + ctxt = graph._get_control_flow_context() # pylint: disable=protected-access + while ctxt: + if ctxt.IsXLAContext(): + return ctxt + ctxt = ctxt.outer_context + return None + + +def _in_xla_context(graph: Optional[tf.Graph] = None) -> bool: + """Detects whether we are in an XLA context.""" + if "--tf_xla_auto_jit=2" in os.environ.get("TF_XLA_FLAGS", ""): + return True + graph = tf.compat.v1.get_default_graph() if graph is None else graph + while True: + if _get_containing_xla_context(graph) is not None: + return True + try: + graph = graph.outer_graph + except AttributeError: + return False + + +def _force_data_dependency( + first_compute: Sequence[tf.Tensor], then_compute: Sequence[tf.Tensor] +) -> List[tf.Tensor]: + """Force all of `then_compute` to depend on all of `first_compute`. + Uses a dummy data dependency, which is useful when running on TPUs because + XLA ignores control dependencies. Only supports float arguments. + Args: + first_compute: Sequence of `Tensor`s to be executed before `then_compute`. + then_compute: Sequence of `Tensor`s to executed after `first_compute`. + Returns: + Sequence of `Tensor`s with same length of `then_compute`. + Raises: + ValueError: if ranks are unknown or types are not floating. 
+ """ + + def _first_element(x): + if x.shape.ndims is None: + raise ValueError("Rank of Tensor %s must be known" % x) + ndims = x.shape.ndims + begin = tf.zeros(ndims, dtype=tf.int32) + size = tf.ones(ndims, dtype=tf.int32) + return tf.reshape(tf.slice(x, begin, size), []) + + first_compute_sum = tf.add_n( + [_first_element(x) for x in first_compute if x is not None] + ) + dtype = first_compute_sum.dtype + if not dtype.is_floating: + raise ValueError("_force_data_dependency only supports floating dtypes.") + zero = np.finfo(dtype.as_numpy_dtype).tiny * first_compute_sum + return [x + tf.cast(zero, x.dtype) if x is not None else None for x in then_compute] + + +def _make_seed_if_none(seed: Optional[tf.Tensor]) -> tf.Tensor: + """Uses the global generator to make a seed if necessary.""" + if seed is not None: + return seed + generator = tf.random.experimental.get_global_generator() + # The two seeds for stateless random ops don't have individual semantics and + # are scrambled together, so providing one seed is fine. This makes it easier + # for users to provide a local seed without worrying about integer overflow. + # See `make_seeds` in + # https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/ops/stateful_random_ops.py. + try: + return generator.uniform_full_int([], tf.int32, name="recompute_grad_seed") + except (RuntimeError, TypeError, ValueError, tf.errors.NotFoundError) as e: + # For a number of reasons, the above operation can fail like using multiple + # graphs or toggling between eager and graph modes. Reset the generator. + logging.warn("Resetting the generator. %s: %s", type(e), e) + tf.random.experimental.set_global_generator(None) + generator = tf.random.experimental.get_global_generator() + return generator.uniform_full_int([], tf.int32, name="recompute_grad_seed") + + +def recompute_grad(f, seed=None): + """An eager-compatible version of recompute_grad. + For f(*args, **kwargs), this supports gradients with respect to args, or to + gradients with respect to any variables residing in the kwarg 'variables'. + Note that for keras layer and model objects, this is handled automatically. + Warning: If `f` was originally a tf.keras Model or Layer object, `g` will not + be able to access the member variables of that object, because `g` returns + through the wrapper function `inner`. When recomputing gradients through + objects that inherit from keras, we suggest keeping a reference to the + underlying object around for the purpose of accessing these variables. + Args: + f: function `f(*x)` that returns a `Tensor` or sequence of `Tensor` outputs. + seed: Optional seed for random ops. `seed` should an integer scalar + `Tensor`. When compiling to XLA, `seed` must have dtype `tf.int32`. If + `seed` is not provided one will be generated. + Returns: + A function `g` that wraps `f`, but which recomputes `f` on the backwards + pass of a gradient call. + """ + + @tf.custom_gradient + def inner(*args, **kwargs): + """Inner function closure for calculating gradients.""" + # Detect when we're nested and in the backwards pass, so we don't generate + # an additional seed. + parent_context = get_recompute_context() + if parent_context is not None and parent_context.is_recomputing: + # Use the cached context in the recomputation phase. 
+ with parent_context.children.popleft()._replace( + is_recomputing=True + ) as context: + result = f(*args, **kwargs) + else: + with RecomputeContext( + is_recomputing=False, + seed=_make_seed_if_none(seed), + children=collections.deque(), + ) as context: + result = f(*args, **kwargs) + # In the forward pass, build up a tree of recomputation contexts. + if parent_context is not None and not parent_context.is_recomputing: + parent_context.children.append(context) + + def grad(*dresult, **grad_kwargs): + """Gradient function calculation for inner function.""" + variables = grad_kwargs.pop("variables", None) + if grad_kwargs: + raise ValueError( + "Found unexpected kwargs for `grad`: ", list(grad_kwargs.keys()) + ) + inputs, seed = list(args), context.seed + if _in_xla_context(): + inputs = _force_data_dependency( + tf.nest.flatten(dresult), inputs + [seed] + ) + seed = inputs.pop() + with tf.GradientTape() as tape: + tape.watch(inputs) + if variables is not None: + tape.watch(variables) + with tf.control_dependencies(dresult): + with context._replace(is_recomputing=True, seed=seed): + result = f(*inputs, **kwargs) + kw_vars = [] + if variables is not None: + kw_vars = list(variables) + grads = tape.gradient( + result, list(inputs) + kw_vars, output_gradients=dresult + ) + return grads[: len(inputs)], grads[len(inputs) :] + + return result, grad + + return inner diff --git a/ganime/utils/statistics.py b/ganime/utils/statistics.py new file mode 100644 index 0000000000000000000000000000000000000000..903b6e85d66c7ecb9b202fe4600ab0096bf1db37 --- /dev/null +++ b/ganime/utils/statistics.py @@ -0,0 +1,18 @@ +import numpy as np +from tqdm.auto import tqdm +import tensorflow as tf +import tensorflow_datasets as tfds + + +def dataset_statistics(ds): + if isinstance(ds, tf.data.Dataset): + ds_numpy = tfds.as_numpy(ds) + elif isinstance(ds, tf.keras.utils.Sequence): + ds_numpy = ds + data = [] + + for da in tqdm(ds_numpy): + X, y = da + data.append(X) + all_data = np.concatenate(data) + return np.mean(all_data), np.var(all_data), np.std(all_data) diff --git a/ganime/visualization/__init__.py b/ganime/visualization/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/ganime/visualization/images.py b/ganime/visualization/images.py new file mode 100644 index 0000000000000000000000000000000000000000..ae4bdc05d54b7e2ffe7854f1549245811b18d6c1 --- /dev/null +++ b/ganime/visualization/images.py @@ -0,0 +1,53 @@ +import matplotlib.pyplot as plt +import numpy as np +import tensorflow as tf + + +def display_images(data, n_rows=3, n_cols=3): + figure, axs = plt.subplots(n_rows, n_cols, figsize=(24, 12)) + + axs = axs.flatten() + + plt.setp(axs, xticks=[], yticks=[]) + plt.subplots_adjust(wspace=0, hspace=0) + + for img, ax in zip(data, axs): + img = unnormalize_if_necessary(img) + ax.imshow(img) + + return figure + + +def unnormalize_if_necessary(x): + if isinstance(x, np.ndarray): + if x.min() < 0: + return (x * 0.5) + 0.5 + elif isinstance(x, tf.Tensor): + if x.numpy().min() < 0: + return (x * 0.5) + 0.5 + return x + + +def display_true_pred(y_true, y_pred, n_cols=3): + + fig = plt.figure(constrained_layout=True, figsize=(24, 12)) + + y_true = unnormalize_if_necessary(y_true) + y_pred = unnormalize_if_necessary(y_pred) + + images = [y_pred, y_true] + + # create 2x1 subfigs + subfigs = fig.subfigures(nrows=2, ncols=1) + for row, subfig in enumerate(subfigs): + subfig.suptitle("Prediction" if row == 0 else "Ground truth", fontsize=24) 
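# A minimal usage sketch of `dataset_statistics`: it iterates over (X, y) batches
# and returns the mean, variance and standard deviation of the inputs, which is
# useful for picking normalization constants. The dataset below is a random
# placeholder with the expected (X, y) structure.
import tensorflow as tf

from ganime.utils.statistics import dataset_statistics

images = tf.random.uniform([32, 64, 64, 3])
ds = tf.data.Dataset.from_tensor_slices((images, images)).batch(8)

mean, var, std = dataset_statistics(ds)
print(f"mean={mean:.3f} var={var:.3f} std={std:.3f}")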
+ + # create 1xn_cols subplots per subfig + axs = subfig.subplots(nrows=1, ncols=n_cols) + for col, ax in enumerate(axs): + if row == 0: + ax.imshow(images[row][col]) + else: + ax.imshow(images[row][col]) + + return fig diff --git a/ganime/visualization/videos.py b/ganime/visualization/videos.py new file mode 100644 index 0000000000000000000000000000000000000000..069ef04257e0b7e92bcd765db9c1dcb750fb3959 --- /dev/null +++ b/ganime/visualization/videos.py @@ -0,0 +1,64 @@ +import matplotlib.pyplot as plt +import numpy as np +from IPython.display import HTML +from matplotlib import animation +from ganime.visualization.images import unnormalize_if_necessary + + +def display_videos(data, ground_truth=None, n_rows=3, n_cols=3): + + if ground_truth is not None: + data = np.concatenate((data, ground_truth), axis=2) + + fig, axs = plt.subplots(nrows=n_rows, ncols=n_cols, squeeze=False, figsize=(16, 9)) + + # remove grid and ticks + plt.setp(axs, xticks=[], yticks=[]) + plt.subplots_adjust(wspace=0, hspace=0) + + ims = [] + + for i in range(n_rows): + for j in range(n_cols): + idx = i * n_cols + j + video = data[idx] + frame = unnormalize_if_necessary(video[0]) + im = axs[i][j].imshow(frame, animated=True) + ims.append(im) + + plt.close() # this is required to not display the generated image + + def init(): + for i in range(n_rows): + for j in range(n_cols): + idx = i * n_cols + j + video = data[idx] + im = ims[idx] + frame = unnormalize_if_necessary(video[0]) + im.set_data(frame) + return ims + + def animate(frame_id): + for i in range(n_rows): + for j in range(n_cols): + idx = i * n_cols + j + video = data[idx] + frame = video[frame_id, :, :, :] + frame = unnormalize_if_necessary(frame) + # if frame_id % 2 == 0: + # d[0:2, :, 0] = 255 + # d[0:2, :, 1] = 0 + # d[0:2, :, 2] = 0 + # d[-2:, :, 0] = 255 + # d[-2:, :, 1] = 0 + # d[-2:, :, 2] = 0 + ims[idx].set_data(frame) + return ims + + anim = animation.FuncAnimation( + fig, animate, init_func=init, frames=data.shape[1], blit=True, interval=200 + ) + # FFwriter = animation.FFMpegWriter(fps=10, codec="libx264") + # anim.save("basic_animation1.mp4", writer=FFwriter) + + return HTML(anim.to_html5_video()) diff --git a/models/NetLinLayer/numpy_0.npy b/models/NetLinLayer/numpy_0.npy new file mode 100644 index 0000000000000000000000000000000000000000..0ef44d56a413f34aa961fb8a2f53080247d67a2a --- /dev/null +++ b/models/NetLinLayer/numpy_0.npy @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ee550fccfec1d637a09b856b50541d7b4e4f8989d13352d9a3fec36887037b47 +size 384 diff --git a/models/NetLinLayer/numpy_1.npy b/models/NetLinLayer/numpy_1.npy new file mode 100644 index 0000000000000000000000000000000000000000..98c2ee242b32a309b1b526e884635c85a8020160 --- /dev/null +++ b/models/NetLinLayer/numpy_1.npy @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:bd481af594f9a3f8813f97cfd5c278de2d3fe71b4e567b03bbc1543a2d5902a5 +size 640 diff --git a/models/NetLinLayer/numpy_2.npy b/models/NetLinLayer/numpy_2.npy new file mode 100644 index 0000000000000000000000000000000000000000..b4c36a9b9488c26f605b8bbd7000cc5297109bf9 --- /dev/null +++ b/models/NetLinLayer/numpy_2.npy @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2ac795997435fac6e95b83339c9e832896e316a52caed695cdb0cf2d98987463 +size 1152 diff --git a/models/NetLinLayer/numpy_3.npy b/models/NetLinLayer/numpy_3.npy new file mode 100644 index 0000000000000000000000000000000000000000..16a21eecdc4d389d3c397a984cd014d18441fb24 --- /dev/null +++ 
b/models/NetLinLayer/numpy_3.npy @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:99dd32066383ac347b32c5e060a1d3287021619b51a566f4973dd3ef2e08f098 +size 2176 diff --git a/models/NetLinLayer/numpy_4.npy b/models/NetLinLayer/numpy_4.npy new file mode 100644 index 0000000000000000000000000000000000000000..50428f5735b530a78d529e014678fae71992def5 --- /dev/null +++ b/models/NetLinLayer/numpy_4.npy @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a9fd9f1b87ee177d4ec156627892ef766b9c06be569bf47cf68d33b301ea4403 +size 2176 diff --git a/models/vgg19/imagenet-vgg-verydeep-19.mat b/models/vgg19/imagenet-vgg-verydeep-19.mat new file mode 100644 index 0000000000000000000000000000000000000000..2044699a3ec35ff78ade13b5c8c58420bfb4e21d --- /dev/null +++ b/models/vgg19/imagenet-vgg-verydeep-19.mat @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:abdb57167f82a2a1fbab1e1c16ad9373411883f262a1a37ee5db2e6fb0044695 +size 534904783 diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..9ae622d0b54f51370585b93bc7c96ef57fd27007 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,206 @@ +# absl-py==0.15.0 +# aiohttp==3.8.3 +# aiosignal==1.2.0 +# antlr4-python3-runtime==4.9.3 +# anyio==3.6.1 +# argon2-cffi==21.1.0 +# astunparse==1.6.3 +# async-timeout==4.0.2 +# attrs==21.2.0 +# Babel==2.10.3 +# backcall==0.2.0 +# bcrypt==4.0.1 +# beautifulsoup4==4.11.1 +# bleach==4.1.0 +# cachetools==4.2.4 +# certifi==2019.11.28 +# cffi==1.15.0 +# chardet==3.0.4 +# charset-normalizer==2.1.1 +# click==8.1.3 +# cryptography==38.0.1 +# cycler==0.11.0 +# dbus-python==1.2.16 +# debugpy==1.5.1 +# decorator==5.1.0 +# defusedxml==0.7.1 +# Deprecated==1.2.13 +# entrypoints==0.3 +# fastapi==0.85.0 +# fastjsonschema==2.16.2 +# ffmpegio==0.7.0 +# ffmpegio-core==0.7.0 +# ffmpy==0.3.0 +# filelock==3.8.0 +# flatbuffers==2.0 +# fonttools==4.37.4 +# frozenlist==1.3.1 +# fsspec==2022.8.2 +# gast==0.4.0 +# google-auth==2.3.3 +# google-auth-oauthlib==0.4.6 +# google-pasta==0.2.0 +# gradio==3.4.1 +# grpcio==1.41.1 +# h11==0.12.0 +# h5py==3.5.0 +# httpcore==0.15.0 +# httpx==0.23.0 +# huggingface-hub==0.10.1 +# idna==2.8 +# imageio==2.22.1 +# importlib-metadata==5.0.0 +# importlib-resources==5.4.0 +# ipykernel==5.1.1 +# ipython==7.29.0 +# ipython-genutils==0.2.0 +# ipywidgets==7.6.5 +# jedi==0.17.2 +# Jinja2==3.0.2 +# joblib==1.2.0 +# json5==0.9.10 +# jsonschema==4.2.0 +# jupyter==1.0.0 +# jupyter-client==7.0.6 +# jupyter-console==6.4.0 +# jupyter-core==4.9.1 +# jupyter-http-over-ws==0.0.8 +# jupyter-server==1.21.0 +# jupyterlab==3.3.2 +# jupyterlab-pygments==0.1.2 +# jupyterlab-server==2.16.0 +# jupyterlab-widgets==1.0.2 +# keras==2.7.0 +# Keras-Preprocessing==1.1.2 +# kiwisolver==1.3.2 +# libclang==12.0.0 +# linkify-it-py==1.0.3 +# Markdown==3.3.4 +# markdown-it-py==2.1.0 +# MarkupSafe==2.0.1 +# matplotlib==3.5.1 +# matplotlib-inline==0.1.3 +# mdit-py-plugins==0.3.1 +# mdurl==0.1.2 +# mistune==2.0.4 +# msgpack==1.0.4 +# multidict==6.0.2 +# nbclassic==0.4.6 +# nbclient==0.5.4 +# nbconvert==7.2.1 +# nbformat==5.7.0 +# nest-asyncio==1.5.1 +# networkx==2.8.7 +# notebook==6.4.5 +# notebook-shim==0.1.0 +# numpy==1.21.3 +# oauthlib==3.1.1 +# omegaconf==2.2.3 +# opencv-python==4.6.0.66 +# opt-einsum==3.3.0 +# orjson==3.8.0 +# packaging==21.2 +# pandas==1.5.0 +# pandocfilters==1.5.0 +# paramiko==2.11.0 +# parso==0.7.1 +# pexpect==4.8.0 +# pickleshare==0.7.5 +# Pillow==8.4.0 +# pluggy==1.0.0 +# prometheus-client==0.12.0 +# 
prompt-toolkit==3.0.22 +# protobuf==3.19.1 +# ptyprocess==0.7.0 +# pyasn1==0.4.8 +# pyasn1-modules==0.2.8 +# pycparser==2.20 +# pycryptodome==3.15.0 +# pydantic==1.10.2 +# pydub==0.25.1 +# Pygments==2.10.0 +# PyGObject==3.36.0 +# PyNaCl==1.5.0 +# pyparsing==2.4.7 +# pyprojroot==0.2.0 +# pyrsistent==0.18.0 +# python-apt==2.0.0+ubuntu0.20.4.8 +# python-dateutil==2.8.2 +# python-multipart==0.0.5 +# pytz==2022.4 +# PyWavelets==1.4.1 +# PyYAML==6.0 +# pyzmq==22.3.0 +# qtconsole==5.1.1 +# QtPy==1.11.2 +# ray==1.11.0 +# redis==4.3.4 +# regex==2022.9.13 +# requests==2.22.0 +# requests-oauthlib==1.3.0 +# requests-unixsocket==0.2.0 +# rfc3986==1.5.0 +# rsa==4.7.2 +# scenedetect==0.5.6.1 +# scikit-image==0.19.2 +# scipy==1.9.2 +# Send2Trash==1.8.0 +# six==1.14.0 +# sk-video==1.1.10 +# sniffio==1.3.0 +# soupsieve==2.3.2.post1 +# starlette==0.20.4 +# tensorboard==2.7.0 +# tensorboard-data-server==0.6.1 +# tensorboard-plugin-wit==1.8.0 +# tensorflow==2.7.0 +# tensorflow-addons==0.18.0 +# tensorflow-estimator==2.7.0 +# tensorflow-io-gcs-filesystem==0.21.0 +# termcolor==1.1.0 +# terminado==0.12.1 +# testpath==0.5.0 +# tifffile==2022.10.10 +# tinycss2==1.1.1 +# tokenizers==0.13.1 +# torch==1.12.1 +# torch-vision==0.1.6.dev0 +# tornado==6.1 +# tqdm==4.64.1 +# traitlets==5.1.1 +# transformers==4.23.1 +# typeguard==2.13.3 +# typing-extensions==4.4.0 +# uc-micro-py==1.0.1 +# urllib3==1.25.8 +# uvicorn==0.18.3 +# wcwidth==0.2.5 +# webencodings==0.5.1 +# websocket-client==1.4.1 +# websockets==10.3 +# Werkzeug==2.0.2 +# widgetsnbextension==3.5.2 +# wrapt==1.13.3 +# yarl==1.8.1 +# zipp==3.6.0 + +tensorflow==2.7.0 +tensorflow-addons +matplotlib==3.5.1 +omegaconf +ray==1.11.0 +scikit-image==0.19.2 +scenedetect==0.5.6.1 +transformers +tqdm +jupyterlab==3.3.2 +typing-extensions>=4.1.0 +pyprojroot +torch +torch-vision +gradio +opencv-python +joblib +sk-video +ffmpegio +h11==0.12.0 \ No newline at end of file diff --git a/setup.py b/setup.py new file mode 100644 index 0000000000000000000000000000000000000000..91d0058e12f72bcbf1460c606425aa670569489d --- /dev/null +++ b/setup.py @@ -0,0 +1,18 @@ +"""Setup file to install a package with pip""" +import setuptools + + +setuptools.setup( + name="ganime", + version="0.1", + author="farid.abdalla", + author_email="farid.abdalla.13@gmail.com", + packages=setuptools.find_packages(), + license="", + description="", + long_description=open("README.md").read(), + install_requires=open("requirements.txt").readlines(), + # extras_require={ + # "dev": open("requirements_dev.txt").readlines() + # } +)
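# A minimal end-to-end sketch of how the pieces added in this patch fit together
# for a plain (non-Ray) training run. The call signatures mirror the ones used in
# ganime/trainer/ganime.py; the config path, dataset name and model key are
# placeholders, and the exact layout of the OmegaConf file (e.g. `trainer.n_epochs`)
# is an assumption.
import tensorflow as tf
from omegaconf import OmegaConf

from ganime.data.base import load_dataset
from ganime.model.base import load_model
from ganime.utils.callbacks import TensorboardImage, get_logdir

strategy = tf.distribute.MirroredStrategy()
cfg = OmegaConf.load("configs/example.yaml")  # placeholder config file

global_batch_size = cfg["trainer"]["batch_size"] * strategy.num_replicas_in_sync
train_ds, validation_ds, test_ds = load_dataset(
    dataset_name="my_dataset",  # placeholder
    dataset_path="./data",      # placeholder
    batch_size=global_batch_size,
)

model = load_model("vqgan", config=cfg, strategy=strategy)  # placeholder model key

logdir = get_logdir("logs", experiment_name="local_run")
callbacks = [TensorboardImage(logdir, next(iter(train_ds)), next(iter(validation_ds)))]

model.fit(
    train_ds,
    validation_data=validation_ds,
    epochs=cfg["trainer"]["n_epochs"],  # assumed config key
    callbacks=callbacks,
)
model.evaluate(test_ds)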