diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..6488b1112f406c87d4de154974d8144ce7055bdb --- /dev/null +++ b/.gitignore @@ -0,0 +1,15 @@ +ucf101_stride4x4x4 +__pycache__ +*.mp4 +.ipynb_checkpoints +*.pth +UCF-101/ +results/ +vae +build/ +opensora.egg-info/ +wandb/ +.idea +*.ipynb +*.jpg +*.mp3 diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..b01a7d6ec35639c2167bd20b005dd09655a71a68 --- /dev/null +++ b/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2024 PKU-YUAN's Group (袁粒课题组-北大信工) + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/README.md b/README.md index 57f17dabfb4e1a910ff151adbce6b7cf31e2b368..3696edf081d800e0fe0cf1f1a5b1353dac35a785 100644 --- a/README.md +++ b/README.md @@ -4,7 +4,7 @@ emoji: 🦀 colorFrom: indigo colorTo: red sdk: gradio -sdk_version: 4.25.0 +sdk_version: 3.37.0 app_file: app.py pinned: false license: mit diff --git a/docker/LICENSE b/docker/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..c4bc32c8006b48bcc0724e7b7c64731664edd7cd --- /dev/null +++ b/docker/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2024 SimonLee + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/docker/README.md b/docker/README.md new file mode 100644 index 0000000000000000000000000000000000000000..624d845f3e3415a49c997112c16fe843bb403479 --- /dev/null +++ b/docker/README.md @@ -0,0 +1,87 @@ +# Docker4ML + +Useful docker scripts for ML developement. +[https://github.com/SimonLeeGit/Docker4ML](https://github.com/SimonLeeGit/Docker4ML) + +## Build Docker Image + +```bash +bash docker_build.sh +``` + +![build_docker](build_docker.png) + +## Run Docker Container as Development Envirnoment + +```bash +bash docker_run.sh +``` + +![run_docker](run_docker.png) + +## Custom Docker Config + +### Config [setup_env.sh](./setup_env.sh) + +You can modify this file to custom your settings. + +```bash +TAG=ml:dev +BASE_TAG=nvcr.io/nvidia/pytorch:23.12-py3 +``` + +#### TAG + +Your built docker image tag, you can set it as what you what. + +#### BASE_TAG + +The base docker image tag for your built docker image, here we use nvidia pytorch images. +You can check it from [https://catalog.ngc.nvidia.com/orgs/nvidia/containers/pytorch/tags](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/pytorch/tags) + +Also, you can use other docker image as base, such as: [ubuntu](https://hub.docker.com/_/ubuntu/tags) + +### USER_NAME + +Your user name used in docker container. + +### USER_PASSWD + +Your user password used in docker container. + +### Config [requriements.txt](./requirements.txt) + +You can add your default installed python libraries here. + +```txt +transformers==4.27.1 +``` + +By default, it has some libs installed, you can check it from [https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes/rel-24-01.html](https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes/rel-24-01.html) + +### Config [packages.txt](./packages.txt) + +You can add your default apt-get installed packages here. + +```txt +wget +curl +git +``` + +### Config [ports.txt](./ports.txt) + +You can add some ports enabled for docker container here. + +```txt +-p 6006:6006 +-p 8080:8080 +``` + +### Config [postinstallscript.sh](./postinstallscript.sh) + +You can add your custom script to run when build docker image. + +## Q&A + +If you have any use problems, please contact to . diff --git a/docker/build_docker.png b/docker/build_docker.png new file mode 100644 index 0000000000000000000000000000000000000000..a5dbf2c1426ecad3e2a18e18f59319c0de86f766 Binary files /dev/null and b/docker/build_docker.png differ diff --git a/docker/docker_build.sh b/docker/docker_build.sh new file mode 100644 index 0000000000000000000000000000000000000000..6e82a157f448c248cd4241eea033f13b9ad34a0d --- /dev/null +++ b/docker/docker_build.sh @@ -0,0 +1,8 @@ +#!/usr/bin/env bash + +WORK_DIR=$(dirname "$(readlink -f "$0")") +cd $WORK_DIR + +source setup_env.sh + +docker build -t $TAG --build-arg BASE_TAG=$BASE_TAG --build-arg USER_NAME=$USER_NAME --build-arg USER_PASSWD=$USER_PASSWD . -f dockerfile.base diff --git a/docker/docker_run.sh b/docker/docker_run.sh new file mode 100644 index 0000000000000000000000000000000000000000..5fa4b407d1a18bf9b05ceb107571932b25624eda --- /dev/null +++ b/docker/docker_run.sh @@ -0,0 +1,45 @@ +#!/usr/bin/env bash + +WORK_DIR=$(dirname "$(readlink -f "$0")") +source $WORK_DIR/setup_env.sh + +RUNNING_IDS="$(docker ps --filter ancestor=$TAG --format "{{.ID}}")" + +if [ -n "$RUNNING_IDS" ]; then + # Initialize an array to hold the container IDs + declare -a container_ids=($RUNNING_IDS) + + # Get the first container ID using array indexing + ID=${container_ids[0]} + + # Print the first container ID + echo ' ' + echo "The running container ID is: $ID, enter it!" +else + echo ' ' + echo "Not found running containers, run it!" + + # Run a new docker container instance + ID=$(docker run \ + --rm \ + --gpus all \ + -itd \ + --ipc=host \ + --ulimit memlock=-1 \ + --ulimit stack=67108864 \ + -e DISPLAY=$DISPLAY \ + -v /tmp/.X11-unix/:/tmp/.X11-unix/ \ + -v $PWD:/home/$USER_NAME/workspace \ + -w /home/$USER_NAME/workspace \ + $(cat $WORK_DIR/ports.txt) \ + $TAG) +fi + +docker logs $ID + +echo ' ' +echo ' ' +echo '=========================================' +echo ' ' + +docker exec -it $ID bash diff --git a/docker/dockerfile.base b/docker/dockerfile.base new file mode 100644 index 0000000000000000000000000000000000000000..b83488d48a55df343fb97688acf91f5693d59edb --- /dev/null +++ b/docker/dockerfile.base @@ -0,0 +1,24 @@ +ARG BASE_TAG +FROM ${BASE_TAG} +ARG USER_NAME=myuser +ARG USER_PASSWD=111111 +ARG DEBIAN_FRONTEND=noninteractive + +# Pre-install packages, pip install requirements and run post install script. +COPY packages.txt . +COPY requirements.txt . +COPY postinstallscript.sh . +RUN apt-get update && apt-get install -y sudo $(cat packages.txt) +RUN pip install --no-cache-dir -r requirements.txt +RUN bash postinstallscript.sh + +# Create a new user and group using the username argument +RUN groupadd -r ${USER_NAME} && useradd -r -m -g${USER_NAME} ${USER_NAME} +RUN echo "${USER_NAME}:${USER_PASSWD}" | chpasswd +RUN usermod -aG sudo ${USER_NAME} +USER ${USER_NAME} +ENV USER=${USER_NAME} +WORKDIR /home/${USER_NAME}/workspace + +# Set the prompt to highlight the username +RUN echo "export PS1='\[\033[01;32m\]\u\[\033[00m\]@\[\033[01;34m\]\h\[\033[00m\]:\[\033[01;36m\]\w\[\033[00m\]\$'" >> /home/${USER_NAME}/.bashrc diff --git a/docker/packages.txt b/docker/packages.txt new file mode 100644 index 0000000000000000000000000000000000000000..1ec90dd5b12bc325f93ea1c3b04cef2b774e8c7b --- /dev/null +++ b/docker/packages.txt @@ -0,0 +1,3 @@ +wget +curl +git \ No newline at end of file diff --git a/docker/ports.txt b/docker/ports.txt new file mode 100644 index 0000000000000000000000000000000000000000..2fd9c4567f3b75cd223bfe3c88c9a3deaefd342d --- /dev/null +++ b/docker/ports.txt @@ -0,0 +1 @@ +-p 6006:6006 \ No newline at end of file diff --git a/docker/postinstallscript.sh b/docker/postinstallscript.sh new file mode 100644 index 0000000000000000000000000000000000000000..457f47c177f9f3865720dc4b54c0df01260d9f98 --- /dev/null +++ b/docker/postinstallscript.sh @@ -0,0 +1,3 @@ +#!/usr/bin/env bash +# this script will run when build docker image. + diff --git a/docker/requirements.txt b/docker/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..d2c490c3b0e0f639828cfb5f7b58cc8ad1802a4f --- /dev/null +++ b/docker/requirements.txt @@ -0,0 +1,40 @@ +setuptools>=61.0 +torch==2.0.1 +torchvision==0.15.2 +transformers==4.32.0 +albumentations==1.4.0 +av==11.0.0 +decord==0.6.0 +einops==0.3.0 +fastapi==0.110.0 +accelerate==0.21.0 +gdown==5.1.0 +h5py==3.10.0 +idna==3.6 +imageio==2.34.0 +matplotlib==3.7.5 +numpy==1.24.4 +omegaconf==2.1.1 +opencv-python==4.9.0.80 +opencv-python-headless==4.9.0.80 +pandas==2.0.3 +pillow==10.2.0 +pydub==0.25.1 +pytorch-lightning==1.4.2 +pytorchvideo==0.1.5 +PyYAML==6.0.1 +regex==2023.12.25 +requests==2.31.0 +scikit-learn==1.3.2 +scipy==1.10.1 +six==1.16.0 +tensorboard==2.14.0 +test-tube==0.7.5 +timm==0.9.16 +torchdiffeq==0.2.3 +torchmetrics==0.5.0 +tqdm==4.66.2 +urllib3==2.2.1 +uvicorn==0.27.1 +diffusers==0.24.0 +scikit-video==1.1.11 diff --git a/docker/run_docker.png b/docker/run_docker.png new file mode 100644 index 0000000000000000000000000000000000000000..dbce9b9c42ffbc77d908daf4329b11c4864ba5cd Binary files /dev/null and b/docker/run_docker.png differ diff --git a/docker/setup_env.sh b/docker/setup_env.sh new file mode 100644 index 0000000000000000000000000000000000000000..e7b1ecc778cb7113b513b2879cf2321754fb5b2f --- /dev/null +++ b/docker/setup_env.sh @@ -0,0 +1,11 @@ +# Docker tag for new build image +TAG=open_sora_plan:dev + +# Base docker image tag used by docker build +BASE_TAG=nvcr.io/nvidia/pytorch:23.05-py3 + +# User name used in docker container +USER_NAME=developer + +# User password used in docker container +USER_PASSWD=666666 \ No newline at end of file diff --git a/docs/CausalVideoVAE.md b/docs/CausalVideoVAE.md new file mode 100644 index 0000000000000000000000000000000000000000..2d4b9457e7c02474ca4efee65cb2e714c65e358f --- /dev/null +++ b/docs/CausalVideoVAE.md @@ -0,0 +1,36 @@ +# CausalVideoVAE Report + +## Examples + +### Image Reconstruction + +Resconstruction in **1536×1024**. + + + + + +### Video Reconstruction + +We reconstruct two videos with **720×1280**. Since github can't put too big video, we put it here: [1](https://streamable.com/gqojal), [2](https://streamable.com/6nu3j8). + +https://github.com/PKU-YuanGroup/Open-Sora-Plan/assets/88202804/c100bb02-2420-48a3-9d7b-4608a41f14aa + +https://github.com/PKU-YuanGroup/Open-Sora-Plan/assets/88202804/8aa8f587-d9f1-4e8b-8a82-d3bf9ba91d68 + +## Model Structure + +![image](https://github.com/PKU-YuanGroup/Open-Sora-Plan/assets/88202804/e3c8b35d-a217-4d96-b2e9-5c248a2859c8) + + +The Causal Video VAE architecture inherits from the [Stable-Diffusion Image VAE](https://github.com/CompVis/stable-diffusion/tree/main). To ensure that the pretrained weights of the Image VAE can be seamlessly applied to the Video VAE, the model structure has been designed as follows: + +**1. CausalConv3D**: Converting Conv2D to CausalConv3D enables joint training of image and video data. CausalConv3D applies a special treatment to the first frame, as it does not have access to subsequent frames. For more specific details, please refer to https://github.com/PKU-YuanGroup/Open-Sora-Plan/pull/145 + +**2. Initialization**: There are two common [methods](https://github.com/hassony2/inflated_convnets_pytorch/blob/master/src/inflate.py#L5) to expand Conv2D to Conv3D: average initialization and center initialization. But we employ a specific initialization method (tail initialization). This initialization method ensures that without any training, the model is capable of directly reconstructing images, and even videos. + +## Training Details + +image + +We present the loss curves for two distinct initialization methods under 17×256×256. The yellow curve represents the loss using tail init, while the blue curve corresponds to the loss from center initialization. As shown in the graph, tail initialization demonstrates better performance on the loss curve. Additionally, **we found that center initialization leads to error accumulation**, causing the collapse over extended durations. diff --git a/docs/Contribution_Guidelines.md b/docs/Contribution_Guidelines.md new file mode 100644 index 0000000000000000000000000000000000000000..d7f163a3d66f62767c264ab86894c6bf74de132c --- /dev/null +++ b/docs/Contribution_Guidelines.md @@ -0,0 +1,87 @@ +# Contributing to the Open-Sora Plan Community + +The Open-Sora Plan open-source community is a collaborative initiative driven by the community, emphasizing a commitment to being free and void of exploitation. Organized spontaneously by community members, we invite you to contribute to the Open-Sora Plan open-source community and help elevate it to new heights! + +## Submitting a Pull Request (PR) + +As a contributor, before submitting your request, kindly follow these guidelines: + +1. Start by checking the [Open-Sora Plan GitHub](https://github.com/PKU-YuanGroup/Open-Sora-Plan/pulls) to see if there are any open or closed pull requests related to your intended submission. Avoid duplicating existing work. + +2. [Fork](https://github.com/PKU-YuanGroup/Open-Sora-Plan/fork) the [open-sora plan](https://github.com/PKU-YuanGroup/Open-Sora-Plan) repository and download your forked repository to your local machine. + + ```bash + git clone [your-forked-repository-url] + ``` + +3. Add the original Open-Sora Plan repository as a remote to sync with the latest updates: + + ```bash + git remote add upstream https://github.com/PKU-YuanGroup/Open-Sora-Plan + ``` + +4. Sync the code from the main repository to your local machine, and then push it back to your forked remote repository. + + ``` + # Pull the latest code from the upstream branch + git fetch upstream + + # Switch to the main branch + git checkout main + + # Merge the updates from the upstream branch into main, synchronizing the local main branch with the upstream + git merge upstream/main + + # Additionally, sync the local main branch to the remote branch of your forked repository + git push origin main + ``` + + + > Note: Sync the code from the main repository before each submission. + +5. Create a branch in your forked repository for your changes, ensuring the branch name is meaningful. + + ```bash + git checkout -b my-docs-branch main + ``` + +6. While making modifications and committing changes, adhere to our [Commit Message Format](#Commit-Message-Format). + + ```bash + git commit -m "[docs]: xxxx" + ``` + +7. Push your changes to your GitHub repository. + + ```bash + git push origin my-docs-branch + ``` + +8. Submit a pull request to `Open-Sora-Plan:main` on the GitHub repository page. + +## Commit Message Format + +Commit messages must include both `` and `` sections. + +```bash +[]: + │ │ + │ └─⫸ Briefly describe your changes, without ending with a period. + │ + └─⫸ Commit Type: |docs|feat|fix|refactor| +``` + +### Type + +* **docs**: Modify or add documents. +* **feat**: Introduce a new feature. +* **fix**: Fix a bug. +* **refactor**: Restructure code, excluding new features or bug fixes. + +### Summary + +Describe modifications in English, without ending with a period. + +> e.g., git commit -m "[docs]: add a contributing.md file" + +This guideline is borrowed by [minisora](https://github.com/mini-sora/minisora). We sincerely appreciate MiniSora authors for their awesome templates. diff --git a/docs/Data.md b/docs/Data.md new file mode 100644 index 0000000000000000000000000000000000000000..3a20c6f401b6b8b775bf9b1f3a89a51369e8f12e --- /dev/null +++ b/docs/Data.md @@ -0,0 +1,35 @@ + +**We need more dataset**, please refer to the [open-sora-Dataset](https://github.com/shaodong233/open-sora-Dataset) for details. + + +## Sky + + +This is an un-condition datasets. [Link](https://drive.google.com/open?id=1xWLiU-MBGN7MrsFHQm4_yXmfHBsMbJQo) + +``` +sky_timelapse +├── readme +├── sky_test +├── sky_train +├── test_videofolder.py +└── video_folder.py +``` + +## UCF101 + +We test the code with UCF-101 dataset. In order to download UCF-101 dataset, you can download the necessary files in [here](https://www.crcv.ucf.edu/data/UCF101.php). The code assumes a `ucf101` directory with the following structure +``` +UCF-101/ + ApplyEyeMakeup/ + v1.avi + ... + ... + YoYo/ + v1.avi + ... +``` + + +## Offline feature extraction +Coming soon... diff --git a/docs/EVAL.md b/docs/EVAL.md new file mode 100644 index 0000000000000000000000000000000000000000..3796c79093b8e354c2df1d3fbc8f5f443f325f58 --- /dev/null +++ b/docs/EVAL.md @@ -0,0 +1,110 @@ +# Evaluate the generated videos quality + +You can easily calculate the following video quality metrics, which supports the batch-wise process. +- **CLIP-SCORE**: It uses the pretrained CLIP model to measure the cosine similarity between two modalities. +- **FVD**: Frechét Video Distance +- **SSIM**: structural similarity index measure +- **LPIPS**: learned perceptual image patch similarity +- **PSNR**: peak-signal-to-noise ratio + +# Requirement +## Environment +- install Pytorch (torch>=1.7.1) +- install CLIP + ``` + pip install git+https://github.com/openai/CLIP.git + ``` +- install clip-cose from PyPi + ``` + pip install clip-score + ``` +- Other package + ``` + pip install lpips + pip install scipy (scipy==1.7.3/1.9.3, if you use 1.11.3, **you will calculate a WRONG FVD VALUE!!!**) + pip install numpy + pip install pillow + pip install torchvision>=0.8.2 + pip install ftfy + pip install regex + pip install tqdm + ``` +## Pretrain model +- FVD + Before you cacluate FVD, you should first download the FVD pre-trained model. You can manually download any of the following and put it into FVD folder. + - `i3d_torchscript.pt` from [here](https://www.dropbox.com/s/ge9e5ujwgetktms/i3d_torchscript.pt) + - `i3d_pretrained_400.pt` from [here](https://onedrive.live.com/download?cid=78EEF3EB6AE7DBCB&resid=78EEF3EB6AE7DBCB%21199&authkey=AApKdFHPXzWLNyI) + + +## Other Notices +1. Make sure the pixel value of videos should be in [0, 1]. +2. We average SSIM when images have 3 channels, ssim is the only metric extremely sensitive to gray being compared to b/w. +3. Because the i3d model downsamples in the time dimension, `frames_num` should > 10 when calculating FVD, so FVD calculation begins from 10-th frame, like upper example. +4. For grayscale videos, we multiply to 3 channels +5. data input specifications for clip_score +> - Image Files:All images should be stored in a single directory. The image files can be in either .png or .jpg format. +> +> - Text Files: All text data should be contained in plain text files in a separate directory. These text files should have the extension .txt. +> +> Note: The number of files in the image directory should be exactly equal to the number of files in the text directory. Additionally, the files in the image directory and text directory should be paired by file name. For instance, if there is a cat.png in the image directory, there should be a corresponding cat.txt in the text directory. +> +> Directory Structure Example: +> ``` +> ├── path/to/image +> │ ├── cat.png +> │ ├── dog.png +> │ └── bird.jpg +> └── path/to/text +> ├── cat.txt +> ├── dog.txt +> └── bird.txt +> ``` + +6. data input specifications for fvd, psnr, ssim, lpips + +> Directory Structure Example: +> ``` +> ├── path/to/generated_image +> │ ├── cat.mp4 +> │ ├── dog.mp4 +> │ └── bird.mp4 +> └── path/to/real_image +> ├── cat.mp4 +> ├── dog.mp4 +> └── bird.mp4 +> ``` + + + +# Usage + +``` +# you change the file path and need to set the frame_num, resolution etc... + +# clip_score cross modality +cd opensora/eval +bash script/cal_clip_score.sh + + + +# fvd +cd opensora/eval +bash script/cal_fvd.sh + +# psnr +cd opensora/eval +bash eval/script/cal_psnr.sh + + +# ssim +cd opensora/eval +bash eval/script/cal_ssim.sh + + +# lpips +cd opensora/eval +bash eval/script/cal_lpips.sh +``` + +# Acknowledgement +The evaluation codebase refers to [clip-score](https://github.com/Taited/clip-score) and [common_metrics](https://github.com/JunyaoHu/common_metrics_on_video_quality). \ No newline at end of file diff --git a/docs/VQVAE.md b/docs/VQVAE.md new file mode 100644 index 0000000000000000000000000000000000000000..a1af36364a9b374637151f76e8600d0d1395092d --- /dev/null +++ b/docs/VQVAE.md @@ -0,0 +1,57 @@ +# VQVAE Documentation + +# Introduction + +Vector Quantized Variational AutoEncoders (VQ-VAE) is a type of autoencoder that uses a discrete latent representation. It is particularly useful for tasks that require discrete latent variables, such as text-to-speech and video generation. + +# Usage + +## Initialization + +To initialize a VQVAE model, you can use the `VideoGPTVQVAE` class. This class is a part of the `opensora.models.ae` module. + +```python +from opensora.models.ae import VideoGPTVQVAE + +vqvae = VideoGPTVQVAE() +``` + +### Training + +To train the VQVAE model, you can use the `train_videogpt.sh` script. This script will train the model using the parameters specified in the script. + +```bash +bash scripts/videogpt/train_videogpt.sh +``` + +### Loading Pretrained Models + +You can load a pretrained model using the `download_and_load_model` method. This method will download the checkpoint file and load the model. + +```python +vqvae = VideoGPTVQVAE.download_and_load_model("bair_stride4x2x2") +``` + +Alternatively, you can load a model from a checkpoint using the `load_from_checkpoint` method. + +```python +vqvae = VQVAEModel.load_from_checkpoint("results/VQVAE/checkpoint-1000") +``` + +### Encoding and Decoding + +You can encode a video using the `encode` method. This method will return the encodings and embeddings of the video. + +```python +encodings, embeddings = vqvae.encode(x_vae, include_embeddings=True) +``` + +You can reconstruct a video from its encodings using the decode method. + +```python +video_recon = vqvae.decode(encodings) +``` + +## Testing + +You can test the VQVAE model by reconstructing a video. The `examples/rec_video.py` script provides an example of how to do this. \ No newline at end of file diff --git a/examples/get_latents_std.py b/examples/get_latents_std.py new file mode 100644 index 0000000000000000000000000000000000000000..5f9937024251882c1d310dfd4ce5c03c1b3088ba --- /dev/null +++ b/examples/get_latents_std.py @@ -0,0 +1,38 @@ +import torch +from torch.utils.data import DataLoader, Subset +import sys +sys.path.append(".") +from opensora.models.ae.videobase import CausalVAEModel, CausalVAEDataset + +num_workers = 4 +batch_size = 12 + +torch.manual_seed(0) +torch.set_grad_enabled(False) + +device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + +pretrained_model_name_or_path = 'results/causalvae/checkpoint-26000' +data_path = '/remote-home1/dataset/UCF-101' +video_num_frames = 17 +resolution = 128 +sample_rate = 10 + +vae = CausalVAEModel.load_from_checkpoint(pretrained_model_name_or_path) +vae.to(device) + +dataset = CausalVAEDataset(data_path, sequence_length=video_num_frames, resolution=resolution, sample_rate=sample_rate) +subset_indices = list(range(1000)) +subset_dataset = Subset(dataset, subset_indices) +loader = DataLoader(subset_dataset, batch_size=8, pin_memory=True) + +all_latents = [] +for video_data in loader: + video_data = video_data['video'].to(device) + latents = vae.encode(video_data).sample() + all_latents.append(video_data.cpu()) + +all_latents_tensor = torch.cat(all_latents) +std = all_latents_tensor.std().item() +normalizer = 1 / std +print(f'{normalizer = }') \ No newline at end of file diff --git a/examples/prompt_list_0.txt b/examples/prompt_list_0.txt new file mode 100644 index 0000000000000000000000000000000000000000..6c91d3c3e119f30d660bbef82061c3ccaf6e0e6a --- /dev/null +++ b/examples/prompt_list_0.txt @@ -0,0 +1,16 @@ +A quiet beach at dawn, the waves gently lapping at the shore and the sky painted in pastel hues. +A quiet beach at dawn, the waves softly lapping at the shore, pink and orange hues painting the sky, offering a moment of solitude and reflection. +The majestic beauty of a waterfall cascading down a cliff into a serene lake. +Sunset over the sea. +a cat wearing sunglasses and working as a lifeguard at pool. +Slow pan upward of blazing oak fire in an indoor fireplace. +Yellow and black tropical fish dart through the sea. +a serene winter scene in a forest. The forest is blanketed in a thick layer of snow, which has settled on the branches of the trees, creating a canopy of white. The trees, a mix of evergreens and deciduous, stand tall and silent, their forms partially obscured by the snow. The ground is a uniform white, with no visible tracks or signs of human activity. The sun is low in the sky, casting a warm glow that contrasts with the cool tones of the snow. The light filters through the trees, creating a soft, diffused illumination that highlights the texture of the snow and the contours of the trees. The overall style of the scene is naturalistic, with a focus on the tranquility and beauty of the winter landscape. +a dynamic interaction between the ocean and a large rock. The rock, with its rough texture and jagged edges, is partially submerged in the water, suggesting it is a natural feature of the coastline. The water around the rock is in motion, with white foam and waves crashing against the rock, indicating the force of the ocean's movement. The background is a vast expanse of the ocean, with small ripples and waves, suggesting a moderate sea state. The overall style of the scene is a realistic depiction of a natural landscape, with a focus on the interplay between the rock and the water. +A serene waterfall cascading down moss-covered rocks, its soothing sound creating a harmonious symphony with nature. +A soaring drone footage captures the majestic beauty of a coastal cliff, its red and yellow stratified rock faces rich in color and against the vibrant turquoise of the sea. Seabirds can be seen taking flight around the cliff's precipices. As the drone slowly moves from different angles, the changing sunlight casts shifting shadows that highlight the rugged textures of the cliff and the surrounding calm sea. The water gently laps at the rock base and the greenery that clings to the top of the cliff, and the scene gives a sense of peaceful isolation at the fringes of the ocean. The video captures the essence of pristine natural beauty untouched by human structures. +The video captures the majestic beauty of a waterfall cascading down a cliff into a serene lake. The waterfall, with its powerful flow, is the central focus of the video. The surrounding landscape is lush and green, with trees and foliage adding to the natural beauty of the scene. The camera angle provides a bird's eye view of the waterfall, allowing viewers to appreciate the full height and grandeur of the waterfall. The video is a stunning representation of nature's power and beauty. +A vibrant scene of a snowy mountain landscape. The sky is filled with a multitude of colorful hot air balloons, each floating at different heights, creating a dynamic and lively atmosphere. The balloons are scattered across the sky, some closer to the viewer, others further away, adding depth to the scene. Below, the mountainous terrain is blanketed in a thick layer of snow, with a few patches of bare earth visible here and there. The snow-covered mountains provide a stark contrast to the colorful balloons, enhancing the visual appeal of the scene. +A serene underwater scene featuring a sea turtle swimming through a coral reef. The turtle, with its greenish-brown shell, is the main focus of the video, swimming gracefully towards the right side of the frame. The coral reef, teeming with life, is visible in the background, providing a vibrant and colorful backdrop to the turtle's journey. Several small fish, darting around the turtle, add a sense of movement and dynamism to the scene. +A snowy forest landscape with a dirt road running through it. The road is flanked by trees covered in snow, and the ground is also covered in snow. The sun is shining, creating a bright and serene atmosphere. The road appears to be empty, and there are no people or animals visible in the video. The style of the video is a natural landscape shot, with a focus on the beauty of the snowy forest and the peacefulness of the road. +The dynamic movement of tall, wispy grasses swaying in the wind. The sky above is filled with clouds, creating a dramatic backdrop. The sunlight pierces through the clouds, casting a warm glow on the scene. The grasses are a mix of green and brown, indicating a change in seasons. The overall style of the video is naturalistic, capturing the beauty of the landscape in a realistic manner. The focus is on the grasses and their movement, with the sky serving as a secondary element. The video does not contain any human or animal elements. \ No newline at end of file diff --git a/examples/rec_image.py b/examples/rec_image.py new file mode 100644 index 0000000000000000000000000000000000000000..a8e95cb3028283ff317c0aeef9bfe9d17860b89a --- /dev/null +++ b/examples/rec_image.py @@ -0,0 +1,57 @@ +import sys +sys.path.append(".") +from PIL import Image +import torch +from torchvision.transforms import ToTensor, Compose, Resize, Normalize +from torch.nn import functional as F +from opensora.models.ae.videobase import CausalVAEModel +import argparse +import numpy as np + +def preprocess(video_data: torch.Tensor, short_size: int = 128) -> torch.Tensor: + transform = Compose( + [ + ToTensor(), + Normalize((0.5), (0.5)), + Resize(size=short_size), + ] + ) + outputs = transform(video_data) + outputs = outputs.unsqueeze(0).unsqueeze(2) + return outputs + +def main(args: argparse.Namespace): + image_path = args.image_path + resolution = args.resolution + device = args.device + + vqvae = CausalVAEModel.load_from_checkpoint(args.ckpt) + vqvae.eval() + vqvae = vqvae.to(device) + + with torch.no_grad(): + x_vae = preprocess(Image.open(image_path), resolution) + x_vae = x_vae.to(device) + latents = vqvae.encode(x_vae) + recon = vqvae.decode(latents.sample()) + x = recon[0, :, 0, :, :] + x = x.squeeze() + x = x.detach().cpu().numpy() + x = np.clip(x, -1, 1) + x = (x + 1) / 2 + x = (255*x).astype(np.uint8) + x = x.transpose(1,2,0) + image = Image.fromarray(x) + image.save(args.rec_path) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--image-path', type=str, default='') + parser.add_argument('--rec-path', type=str, default='') + parser.add_argument('--ckpt', type=str, default='') + parser.add_argument('--resolution', type=int, default=336) + parser.add_argument('--device', type=str, default='cuda') + + args = parser.parse_args() + main(args) diff --git a/examples/rec_imvi_vae.py b/examples/rec_imvi_vae.py new file mode 100644 index 0000000000000000000000000000000000000000..7e216c253925462126637a814f33e9bbf75ecb9c --- /dev/null +++ b/examples/rec_imvi_vae.py @@ -0,0 +1,159 @@ +import math +import random +import argparse +from typing import Optional + +import cv2 +import numpy as np +import numpy.typing as npt +import torch +from PIL import Image +from decord import VideoReader, cpu +from torch.nn import functional as F +from pytorchvideo.transforms import ShortSideScale +from torchvision.transforms import Lambda, Compose + +import sys +sys.path.append(".") +from opensora.dataset.transform import CenterCropVideo, resize +from opensora.models.ae.videobase import CausalVAEModel + + +def array_to_video(image_array: npt.NDArray, fps: float = 30.0, output_file: str = 'output_video.mp4') -> None: + height, width, channels = image_array[0].shape + fourcc = cv2.VideoWriter_fourcc(*'mp4v') + video_writer = cv2.VideoWriter(output_file, fourcc, float(fps), (width, height)) + + for image in image_array: + image_rgb = cv2.cvtColor(image, cv2.COLOR_RGB2BGR) + video_writer.write(image_rgb) + + video_writer.release() + +def custom_to_video(x: torch.Tensor, fps: float = 2.0, output_file: str = 'output_video.mp4') -> None: + x = x.detach().cpu() + x = torch.clamp(x, -1, 1) + x = (x + 1) / 2 + x = x.permute(1, 2, 3, 0).numpy() + x = (255*x).astype(np.uint8) + array_to_video(x, fps=fps, output_file=output_file) + return + +def read_video(video_path: str, num_frames: int, sample_rate: int) -> torch.Tensor: + decord_vr = VideoReader(video_path, ctx=cpu(0)) + total_frames = len(decord_vr) + sample_frames_len = sample_rate * num_frames + + if total_frames > sample_frames_len: + s = random.randint(0, total_frames - sample_frames_len - 1) + s = 0 + e = s + sample_frames_len + num_frames = num_frames + else: + s = 0 + e = total_frames + num_frames = int(total_frames / sample_frames_len * num_frames) + print(f'sample_frames_len {sample_frames_len}, only can sample {num_frames * sample_rate}', video_path, + total_frames) + + + frame_id_list = np.linspace(s, e - 1, num_frames, dtype=int) + video_data = decord_vr.get_batch(frame_id_list).asnumpy() + video_data = torch.from_numpy(video_data) + video_data = video_data.permute(3, 0, 1, 2) # (T, H, W, C) -> (C, T, H, W) + return video_data + + +class ResizeVideo: + def __init__( + self, + size, + interpolation_mode="bilinear", + ): + self.size = size + + self.interpolation_mode = interpolation_mode + + def __call__(self, clip): + _, _, h, w = clip.shape + if w < h: + new_h = int(math.floor((float(h) / w) * self.size)) + new_w = self.size + else: + new_h = self.size + new_w = int(math.floor((float(w) / h) * self.size)) + return torch.nn.functional.interpolate( + clip, size=(new_h, new_w), mode=self.interpolation_mode, align_corners=False, antialias=True + ) + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode}" + +def preprocess(video_data: torch.Tensor, short_size: int = 128, crop_size: Optional[int] = None) -> torch.Tensor: + + transform = Compose( + [ + Lambda(lambda x: ((x / 255.0) * 2 - 1)), + ResizeVideo(size=short_size), + CenterCropVideo(crop_size) if crop_size is not None else Lambda(lambda x: x), + ] + ) + + video_outputs = transform(video_data) + video_outputs = torch.unsqueeze(video_outputs, 0) + + return video_outputs + + +def main(args: argparse.Namespace): + video_path = args.video_path + num_frames = args.num_frames + resolution = args.resolution + crop_size = args.crop_size + sample_fps = args.sample_fps + sample_rate = args.sample_rate + device = args.device + vqvae = CausalVAEModel.from_pretrained(args.ckpt) + if args.enable_tiling: + vqvae.enable_tiling() + vqvae.tile_overlap_factor = args.tile_overlap_factor + vqvae.eval() + vqvae = vqvae.to(device) + vqvae = vqvae # .to(torch.float16) + + with torch.no_grad(): + x_vae = preprocess(read_video(video_path, num_frames, sample_rate), resolution, crop_size) + x_vae = x_vae.to(device) # b c t h w + x_vae = x_vae # .to(torch.float16) + latents = vqvae.encode(x_vae).sample() # .to(torch.float16) + video_recon = vqvae.decode(latents) + + if video_recon.shape[2] == 1: + x = video_recon[0, :, 0, :, :] + x = x.squeeze() + x = x.detach().cpu().numpy() + x = np.clip(x, -1, 1) + x = (x + 1) / 2 + x = (255 * x).astype(np.uint8) + x = x.transpose(1, 2, 0) + image = Image.fromarray(x) + image.save(args.rec_path.replace('mp4', 'jpg')) + else: + custom_to_video(video_recon[0], fps=sample_fps/sample_rate, output_file=args.rec_path) + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--video-path', type=str, default='') + parser.add_argument('--rec-path', type=str, default='') + parser.add_argument('--ckpt', type=str, default='results/pretrained') + parser.add_argument('--sample-fps', type=int, default=30) + parser.add_argument('--resolution', type=int, default=336) + parser.add_argument('--crop-size', type=int, default=None) + parser.add_argument('--num-frames', type=int, default=100) + parser.add_argument('--sample-rate', type=int, default=1) + parser.add_argument('--device', type=str, default="cuda") + parser.add_argument('--tile_overlap_factor', type=float, default=0.25) + parser.add_argument('--enable_tiling', action='store_true') + + args = parser.parse_args() + main(args) diff --git a/examples/rec_video.py b/examples/rec_video.py new file mode 100644 index 0000000000000000000000000000000000000000..5acbeb6fb753a17481fe5274fd98ad3760128582 --- /dev/null +++ b/examples/rec_video.py @@ -0,0 +1,120 @@ +import random +import argparse +from typing import Optional + +import cv2 +import imageio +import numpy as np +import numpy.typing as npt +import torch +from decord import VideoReader, cpu +from torch.nn import functional as F +from pytorchvideo.transforms import ShortSideScale +from torchvision.transforms import Lambda, Compose +from torchvision.transforms._transforms_video import RandomCropVideo + +import sys +sys.path.append(".") +from opensora.models.ae import VQVAEModel + + +def array_to_video(image_array: npt.NDArray, fps: float = 30.0, output_file: str = 'output_video.mp4') -> None: + height, width, channels = image_array[0].shape + fourcc = cv2.VideoWriter_fourcc(*'mp4v') # type: ignore + video_writer = cv2.VideoWriter(output_file, fourcc, float(fps), (width, height)) + + for image in image_array: + image_rgb = cv2.cvtColor(image, cv2.COLOR_RGB2BGR) + video_writer.write(image_rgb) + + video_writer.release() + +def custom_to_video(x: torch.Tensor, fps: float = 2.0, output_file: str = 'output_video.mp4') -> None: + x = x.detach().cpu() + x = torch.clamp(x, -0.5, 0.5) + x = (x + 0.5) + x = x.permute(1, 2, 3, 0).numpy() # (C, T, H, W) -> (T, H, W, C) + x = (255*x).astype(np.uint8) + # array_to_video(x, fps=fps, output_file=output_file) + imageio.mimwrite(output_file, x, fps=fps, quality=9) + return + +def read_video(video_path: str, num_frames: int, sample_rate: int) -> torch.Tensor: + decord_vr = VideoReader(video_path, ctx=cpu(0)) + total_frames = len(decord_vr) + sample_frames_len = sample_rate * num_frames + + if total_frames > sample_frames_len: + s = random.randint(0, total_frames - sample_frames_len - 1) + e = s + sample_frames_len + num_frames = num_frames + else: + s = 0 + e = total_frames + num_frames = int(total_frames / sample_frames_len * num_frames) + print(f'sample_frames_len {sample_frames_len}, only can sample {num_frames * sample_rate}', video_path, + total_frames) + + + frame_id_list = np.linspace(s, e - 1, num_frames, dtype=int) + video_data = decord_vr.get_batch(frame_id_list).asnumpy() + video_data = torch.from_numpy(video_data) + video_data = video_data.permute(3, 0, 1, 2) # (T, H, W, C) -> (C, T, H, W) + return video_data + +def preprocess(video_data: torch.Tensor, short_size: int = 128, crop_size: Optional[int] = None) -> torch.Tensor: + + transform = Compose( + [ + # UniformTemporalSubsample(num_frames), + Lambda(lambda x: ((x / 255.0) - 0.5)), + # NormalizeVideo(mean=OPENAI_DATASET_MEAN, std=OPENAI_DATASET_STD), + ShortSideScale(size=short_size), + RandomCropVideo(size=crop_size) if crop_size is not None else Lambda(lambda x: x), + # RandomHorizontalFlipVideo(p=0.5), + ] + ) + + video_outputs = transform(video_data) + video_outputs = torch.unsqueeze(video_outputs, 0) + + return video_outputs + + +def main(args: argparse.Namespace): + video_path = args.video_path + num_frames = args.num_frames + resolution = args.resolution + crop_size = args.crop_size + sample_fps = args.sample_fps + sample_rate = args.sample_rate + device = torch.device('cuda') + if args.ckpt in ['bair_stride4x2x2', 'ucf101_stride4x4x4', 'kinetics_stride4x4x4', 'kinetics_stride2x4x4']: + vqvae = VQVAEModel.download_and_load_model(args.ckpt) + else: + vqvae = VQVAEModel.load_from_checkpoint(args.ckpt) + vqvae.eval() + vqvae = vqvae.to(device) + + with torch.no_grad(): + x_vae = preprocess(read_video(video_path, num_frames, sample_rate), resolution, crop_size) + x_vae = x_vae.to(device) + encodings, embeddings = vqvae.encode(x_vae, include_embeddings=True) + video_recon = vqvae.decode(encodings) + + # custom_to_video(x_vae[0], fps=sample_fps/sample_rate, output_file='origin_input.mp4') + custom_to_video(video_recon[0], fps=sample_fps/sample_rate, output_file=args.rec_path) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--video-path', type=str, default='') + parser.add_argument('--rec-path', type=str, default='') + parser.add_argument('--ckpt', type=str, default='ucf101_stride4x4x4') + parser.add_argument('--sample-fps', type=int, default=30) + parser.add_argument('--resolution', type=int, default=336) + parser.add_argument('--crop-size', type=int, default=None) + parser.add_argument('--num-frames', type=int, default=100) + parser.add_argument('--sample-rate', type=int, default=1) + args = parser.parse_args() + main(args) diff --git a/examples/rec_video_ae.py b/examples/rec_video_ae.py new file mode 100644 index 0000000000000000000000000000000000000000..5acbeb6fb753a17481fe5274fd98ad3760128582 --- /dev/null +++ b/examples/rec_video_ae.py @@ -0,0 +1,120 @@ +import random +import argparse +from typing import Optional + +import cv2 +import imageio +import numpy as np +import numpy.typing as npt +import torch +from decord import VideoReader, cpu +from torch.nn import functional as F +from pytorchvideo.transforms import ShortSideScale +from torchvision.transforms import Lambda, Compose +from torchvision.transforms._transforms_video import RandomCropVideo + +import sys +sys.path.append(".") +from opensora.models.ae import VQVAEModel + + +def array_to_video(image_array: npt.NDArray, fps: float = 30.0, output_file: str = 'output_video.mp4') -> None: + height, width, channels = image_array[0].shape + fourcc = cv2.VideoWriter_fourcc(*'mp4v') # type: ignore + video_writer = cv2.VideoWriter(output_file, fourcc, float(fps), (width, height)) + + for image in image_array: + image_rgb = cv2.cvtColor(image, cv2.COLOR_RGB2BGR) + video_writer.write(image_rgb) + + video_writer.release() + +def custom_to_video(x: torch.Tensor, fps: float = 2.0, output_file: str = 'output_video.mp4') -> None: + x = x.detach().cpu() + x = torch.clamp(x, -0.5, 0.5) + x = (x + 0.5) + x = x.permute(1, 2, 3, 0).numpy() # (C, T, H, W) -> (T, H, W, C) + x = (255*x).astype(np.uint8) + # array_to_video(x, fps=fps, output_file=output_file) + imageio.mimwrite(output_file, x, fps=fps, quality=9) + return + +def read_video(video_path: str, num_frames: int, sample_rate: int) -> torch.Tensor: + decord_vr = VideoReader(video_path, ctx=cpu(0)) + total_frames = len(decord_vr) + sample_frames_len = sample_rate * num_frames + + if total_frames > sample_frames_len: + s = random.randint(0, total_frames - sample_frames_len - 1) + e = s + sample_frames_len + num_frames = num_frames + else: + s = 0 + e = total_frames + num_frames = int(total_frames / sample_frames_len * num_frames) + print(f'sample_frames_len {sample_frames_len}, only can sample {num_frames * sample_rate}', video_path, + total_frames) + + + frame_id_list = np.linspace(s, e - 1, num_frames, dtype=int) + video_data = decord_vr.get_batch(frame_id_list).asnumpy() + video_data = torch.from_numpy(video_data) + video_data = video_data.permute(3, 0, 1, 2) # (T, H, W, C) -> (C, T, H, W) + return video_data + +def preprocess(video_data: torch.Tensor, short_size: int = 128, crop_size: Optional[int] = None) -> torch.Tensor: + + transform = Compose( + [ + # UniformTemporalSubsample(num_frames), + Lambda(lambda x: ((x / 255.0) - 0.5)), + # NormalizeVideo(mean=OPENAI_DATASET_MEAN, std=OPENAI_DATASET_STD), + ShortSideScale(size=short_size), + RandomCropVideo(size=crop_size) if crop_size is not None else Lambda(lambda x: x), + # RandomHorizontalFlipVideo(p=0.5), + ] + ) + + video_outputs = transform(video_data) + video_outputs = torch.unsqueeze(video_outputs, 0) + + return video_outputs + + +def main(args: argparse.Namespace): + video_path = args.video_path + num_frames = args.num_frames + resolution = args.resolution + crop_size = args.crop_size + sample_fps = args.sample_fps + sample_rate = args.sample_rate + device = torch.device('cuda') + if args.ckpt in ['bair_stride4x2x2', 'ucf101_stride4x4x4', 'kinetics_stride4x4x4', 'kinetics_stride2x4x4']: + vqvae = VQVAEModel.download_and_load_model(args.ckpt) + else: + vqvae = VQVAEModel.load_from_checkpoint(args.ckpt) + vqvae.eval() + vqvae = vqvae.to(device) + + with torch.no_grad(): + x_vae = preprocess(read_video(video_path, num_frames, sample_rate), resolution, crop_size) + x_vae = x_vae.to(device) + encodings, embeddings = vqvae.encode(x_vae, include_embeddings=True) + video_recon = vqvae.decode(encodings) + + # custom_to_video(x_vae[0], fps=sample_fps/sample_rate, output_file='origin_input.mp4') + custom_to_video(video_recon[0], fps=sample_fps/sample_rate, output_file=args.rec_path) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--video-path', type=str, default='') + parser.add_argument('--rec-path', type=str, default='') + parser.add_argument('--ckpt', type=str, default='ucf101_stride4x4x4') + parser.add_argument('--sample-fps', type=int, default=30) + parser.add_argument('--resolution', type=int, default=336) + parser.add_argument('--crop-size', type=int, default=None) + parser.add_argument('--num-frames', type=int, default=100) + parser.add_argument('--sample-rate', type=int, default=1) + args = parser.parse_args() + main(args) diff --git a/examples/rec_video_vae.py b/examples/rec_video_vae.py new file mode 100644 index 0000000000000000000000000000000000000000..7e277017f6a913656e16bccc185caf7c77baf491 --- /dev/null +++ b/examples/rec_video_vae.py @@ -0,0 +1,274 @@ +import random +import argparse +import cv2 +from tqdm import tqdm +import numpy as np +import numpy.typing as npt +import torch +from decord import VideoReader, cpu +from torch.nn import functional as F +from pytorchvideo.transforms import ShortSideScale +from torchvision.transforms import Lambda, Compose +from torchvision.transforms._transforms_video import CenterCropVideo +import sys +from torch.utils.data import Dataset, DataLoader, Subset +import os + +sys.path.append(".") +from opensora.models.ae.videobase import CausalVAEModel +import torch.nn as nn + +def array_to_video( + image_array: npt.NDArray, fps: float = 30.0, output_file: str = "output_video.mp4" +) -> None: + height, width, channels = image_array[0].shape + fourcc = cv2.VideoWriter_fourcc(*"mp4v") + video_writer = cv2.VideoWriter(output_file, fourcc, float(fps), (width, height)) + + for image in image_array: + image_rgb = cv2.cvtColor(image, cv2.COLOR_RGB2BGR) + video_writer.write(image_rgb) + + video_writer.release() + + +def custom_to_video( + x: torch.Tensor, fps: float = 2.0, output_file: str = "output_video.mp4" +) -> None: + x = x.detach().cpu() + x = torch.clamp(x, -1, 1) + x = (x + 1) / 2 + x = x.permute(1, 2, 3, 0).float().numpy() + x = (255 * x).astype(np.uint8) + array_to_video(x, fps=fps, output_file=output_file) + return + + +def read_video(video_path: str, num_frames: int, sample_rate: int) -> torch.Tensor: + decord_vr = VideoReader(video_path, ctx=cpu(0), num_threads=8) + total_frames = len(decord_vr) + sample_frames_len = sample_rate * num_frames + + if total_frames > sample_frames_len: + s = 0 + e = s + sample_frames_len + num_frames = num_frames + else: + s = 0 + e = total_frames + num_frames = int(total_frames / sample_frames_len * num_frames) + print( + f"sample_frames_len {sample_frames_len}, only can sample {num_frames * sample_rate}", + video_path, + total_frames, + ) + + frame_id_list = np.linspace(s, e - 1, num_frames, dtype=int) + video_data = decord_vr.get_batch(frame_id_list).asnumpy() + video_data = torch.from_numpy(video_data) + video_data = video_data.permute(3, 0, 1, 2) # (T, H, W, C) -> (C, T, H, W) + return video_data + + +class RealVideoDataset(Dataset): + def __init__( + self, + real_video_dir, + num_frames, + sample_rate=1, + crop_size=None, + resolution=128, + ) -> None: + super().__init__() + self.real_video_files = self._combine_without_prefix(real_video_dir) + self.num_frames = num_frames + self.sample_rate = sample_rate + self.crop_size = crop_size + self.short_size = resolution + + def __len__(self): + return len(self.real_video_files) + + def __getitem__(self, index): + if index >= len(self): + raise IndexError + real_video_file = self.real_video_files[index] + real_video_tensor = self._load_video(real_video_file) + video_name = os.path.basename(real_video_file) + return {'video': real_video_tensor, 'file_name': video_name } + + def _load_video(self, video_path): + num_frames = self.num_frames + sample_rate = self.sample_rate + decord_vr = VideoReader(video_path, ctx=cpu(0)) + total_frames = len(decord_vr) + sample_frames_len = sample_rate * num_frames + + if total_frames > sample_frames_len: + s = 0 + e = s + sample_frames_len + num_frames = num_frames + else: + s = 0 + e = total_frames + num_frames = int(total_frames / sample_frames_len * num_frames) + print( + f"sample_frames_len {sample_frames_len}, only can sample {num_frames * sample_rate}", + video_path, + total_frames, + ) + + frame_id_list = np.linspace(s, e - 1, num_frames, dtype=int) + video_data = decord_vr.get_batch(frame_id_list).asnumpy() + video_data = torch.from_numpy(video_data) + video_data = video_data.permute(3, 0, 1, 2) + return _preprocess( + video_data, short_size=self.short_size, crop_size=self.crop_size + ) + + def _combine_without_prefix(self, folder_path, prefix="."): + folder = [] + for name in os.listdir(folder_path): + if name[0] == prefix: + continue + folder.append(os.path.join(folder_path, name)) + folder.sort() + return folder + +def resize(x, resolution): + height, width = x.shape[-2:] + aspect_ratio = width / height + if width <= height: + new_width = resolution + new_height = int(resolution / aspect_ratio) + else: + new_height = resolution + new_width = int(resolution * aspect_ratio) + resized_x = F.interpolate(x, size=(new_height, new_width), mode='bilinear', align_corners=True, antialias=True) + return resized_x + +def _preprocess(video_data, short_size=128, crop_size=None): + transform = Compose( + [ + Lambda(lambda x: ((x / 255.0) * 2 - 1)), + Lambda(lambda x: resize(x, short_size)), + ( + CenterCropVideo(crop_size=crop_size) + if crop_size is not None + else Lambda(lambda x: x) + ), + ] + ) + video_outputs = transform(video_data) + video_outputs = _format_video_shape(video_outputs) + return video_outputs + + +def _format_video_shape(video, time_compress=4, spatial_compress=8): + time = video.shape[1] + height = video.shape[2] + width = video.shape[3] + new_time = ( + (time - (time - 1) % time_compress) + if (time - 1) % time_compress != 0 + else time + ) + new_height = ( + (height - (height) % spatial_compress) + if height % spatial_compress != 0 + else height + ) + new_width = ( + (width - (width) % spatial_compress) if width % spatial_compress != 0 else width + ) + return video[:, :new_time, :new_height, :new_width] + + +@torch.no_grad() +def main(args: argparse.Namespace): + real_video_dir = args.real_video_dir + generated_video_dir = args.generated_video_dir + ckpt = args.ckpt + sample_rate = args.sample_rate + resolution = args.resolution + crop_size = args.crop_size + num_frames = args.num_frames + sample_rate = args.sample_rate + device = args.device + sample_fps = args.sample_fps + batch_size = args.batch_size + num_workers = args.num_workers + subset_size = args.subset_size + + if not os.path.exists(args.generated_video_dir): + os.makedirs(args.generated_video_dir, exist_ok=True) + + data_type = torch.bfloat16 + + # ---- Load Model ---- + device = args.device + vqvae = CausalVAEModel.from_pretrained(args.ckpt) + vqvae = vqvae.to(device).to(data_type) + if args.enable_tiling: + vqvae.enable_tiling() + vqvae.tile_overlap_factor = args.tile_overlap_factor + # ---- Load Model ---- + + # ---- Prepare Dataset ---- + dataset = RealVideoDataset( + real_video_dir=real_video_dir, + num_frames=num_frames, + sample_rate=sample_rate, + crop_size=crop_size, + resolution=resolution, + ) + + if subset_size: + indices = range(subset_size) + dataset = Subset(dataset, indices=indices) + + dataloader = DataLoader( + dataset, batch_size=batch_size, pin_memory=True, num_workers=num_workers + ) + # ---- Prepare Dataset + + # ---- Inference ---- + for batch in tqdm(dataloader): + x, file_names = batch['video'], batch['file_name'] + x = x.to(device=device, dtype=data_type) # b c t h w + latents = vqvae.encode(x).sample().to(data_type) + video_recon = vqvae.decode(latents) + for idx, video in enumerate(video_recon): + output_path = os.path.join(generated_video_dir, file_names[idx]) + if args.output_origin: + os.makedirs(os.path.join(generated_video_dir, "origin/"), exist_ok=True) + origin_output_path = os.path.join(generated_video_dir, "origin/", file_names[idx]) + custom_to_video( + x[idx], fps=sample_fps / sample_rate, output_file=origin_output_path + ) + custom_to_video( + video, fps=sample_fps / sample_rate, output_file=output_path + ) + # ---- Inference ---- + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--real_video_dir", type=str, default="") + parser.add_argument("--generated_video_dir", type=str, default="") + parser.add_argument("--ckpt", type=str, default="") + parser.add_argument("--sample_fps", type=int, default=30) + parser.add_argument("--resolution", type=int, default=336) + parser.add_argument("--crop_size", type=int, default=None) + parser.add_argument("--num_frames", type=int, default=17) + parser.add_argument("--sample_rate", type=int, default=1) + parser.add_argument("--batch_size", type=int, default=1) + parser.add_argument("--num_workers", type=int, default=8) + parser.add_argument("--subset_size", type=int, default=None) + parser.add_argument("--tile_overlap_factor", type=float, default=0.25) + parser.add_argument('--enable_tiling', action='store_true') + parser.add_argument('--output_origin', action='store_true') + parser.add_argument("--device", type=str, default="cuda") + + args = parser.parse_args() + main(args) + diff --git a/opensora/__init__.py b/opensora/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4e768b56d84e02be46b30d93be818972220b8462 --- /dev/null +++ b/opensora/__init__.py @@ -0,0 +1 @@ +# \ No newline at end of file diff --git a/opensora/dataset/__init__.py b/opensora/dataset/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..363cf04eaeaea18f2e03030b7db99d2f43ce5093 --- /dev/null +++ b/opensora/dataset/__init__.py @@ -0,0 +1,99 @@ +from torchvision.transforms import Compose +from transformers import AutoTokenizer + +from .feature_datasets import T2V_Feature_dataset, T2V_T5_Feature_dataset +from torchvision import transforms +from torchvision.transforms import Lambda + +from .landscope import Landscope +from .t2v_datasets import T2V_dataset +from .transform import ToTensorVideo, TemporalRandomCrop, RandomHorizontalFlipVideo, CenterCropResizeVideo +from .ucf101 import UCF101 +from .sky_datasets import Sky + +ae_norm = { + 'CausalVAEModel_4x8x8': Lambda(lambda x: 2. * x - 1.), + 'CausalVQVAEModel_4x4x4': Lambda(lambda x: x - 0.5), + 'CausalVQVAEModel_4x8x8': Lambda(lambda x: x - 0.5), + 'VQVAEModel_4x4x4': Lambda(lambda x: x - 0.5), + 'VQVAEModel_4x8x8': Lambda(lambda x: x - 0.5), + "bair_stride4x2x2": Lambda(lambda x: x - 0.5), + "ucf101_stride4x4x4": Lambda(lambda x: x - 0.5), + "kinetics_stride4x4x4": Lambda(lambda x: x - 0.5), + "kinetics_stride2x4x4": Lambda(lambda x: x - 0.5), + 'stabilityai/sd-vae-ft-mse': transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True), + 'stabilityai/sd-vae-ft-ema': transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True), + 'vqgan_imagenet_f16_1024': Lambda(lambda x: 2. * x - 1.), + 'vqgan_imagenet_f16_16384': Lambda(lambda x: 2. * x - 1.), + 'vqgan_gumbel_f8': Lambda(lambda x: 2. * x - 1.), + +} +ae_denorm = { + 'CausalVAEModel_4x8x8': lambda x: (x + 1.) / 2., + 'CausalVQVAEModel_4x4x4': lambda x: x + 0.5, + 'CausalVQVAEModel_4x8x8': lambda x: x + 0.5, + 'VQVAEModel_4x4x4': lambda x: x + 0.5, + 'VQVAEModel_4x8x8': lambda x: x + 0.5, + "bair_stride4x2x2": lambda x: x + 0.5, + "ucf101_stride4x4x4": lambda x: x + 0.5, + "kinetics_stride4x4x4": lambda x: x + 0.5, + "kinetics_stride2x4x4": lambda x: x + 0.5, + 'stabilityai/sd-vae-ft-mse': lambda x: 0.5 * x + 0.5, + 'stabilityai/sd-vae-ft-ema': lambda x: 0.5 * x + 0.5, + 'vqgan_imagenet_f16_1024': lambda x: (x + 1.) / 2., + 'vqgan_imagenet_f16_16384': lambda x: (x + 1.) / 2., + 'vqgan_gumbel_f8': lambda x: (x + 1.) / 2., +} + +def getdataset(args): + temporal_sample = TemporalRandomCrop(args.num_frames * args.sample_rate) # 16 x + norm_fun = ae_norm[args.ae] + if args.dataset == 'ucf101': + transform = Compose( + [ + ToTensorVideo(), # TCHW + CenterCropResizeVideo(size=args.max_image_size), + RandomHorizontalFlipVideo(p=0.5), + norm_fun, + ] + ) + return UCF101(args, transform=transform, temporal_sample=temporal_sample) + if args.dataset == 'landscope': + transform = Compose( + [ + ToTensorVideo(), # TCHW + CenterCropResizeVideo(size=args.max_image_size), + RandomHorizontalFlipVideo(p=0.5), + norm_fun, + ] + ) + return Landscope(args, transform=transform, temporal_sample=temporal_sample) + elif args.dataset == 'sky': + transform = transforms.Compose([ + ToTensorVideo(), + CenterCropResizeVideo(args.max_image_size), + RandomHorizontalFlipVideo(p=0.5), + norm_fun + ]) + return Sky(args, transform=transform, temporal_sample=temporal_sample) + elif args.dataset == 't2v': + transform = transforms.Compose([ + ToTensorVideo(), + CenterCropResizeVideo(args.max_image_size), + RandomHorizontalFlipVideo(p=0.5), + norm_fun + ]) + tokenizer = AutoTokenizer.from_pretrained(args.text_encoder_name, cache_dir='./cache_dir') + return T2V_dataset(args, transform=transform, temporal_sample=temporal_sample, tokenizer=tokenizer) + elif args.dataset == 't2v_feature': + return T2V_Feature_dataset(args, temporal_sample) + elif args.dataset == 't2v_t5_feature': + transform = transforms.Compose([ + ToTensorVideo(), + CenterCropResizeVideo(args.max_image_size), + RandomHorizontalFlipVideo(p=0.5), + norm_fun + ]) + return T2V_T5_Feature_dataset(args, transform, temporal_sample) + else: + raise NotImplementedError(args.dataset) diff --git a/opensora/dataset/extract_feature_dataset.py b/opensora/dataset/extract_feature_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..2fafd4ed7f347d3a1e31cb06815c1b7b5c88c63b --- /dev/null +++ b/opensora/dataset/extract_feature_dataset.py @@ -0,0 +1,64 @@ +import os +from glob import glob + +import numpy as np +import torch +import torchvision +from PIL import Image +from torch.utils.data import Dataset + +from opensora.utils.dataset_utils import DecordInit, is_image_file + + +class ExtractVideo2Feature(Dataset): + def __init__(self, args, transform): + self.data_path = args.data_path + self.transform = transform + self.v_decoder = DecordInit() + self.samples = list(glob(f'{self.data_path}')) + + def __len__(self): + return len(self.samples) + + def __getitem__(self, idx): + video_path = self.samples[idx] + video = self.decord_read(video_path) + video = self.transform(video) # T C H W -> T C H W + return video, video_path + + def tv_read(self, path): + vframes, aframes, info = torchvision.io.read_video(filename=path, pts_unit='sec', output_format='TCHW') + total_frames = len(vframes) + frame_indice = list(range(total_frames)) + video = vframes[frame_indice] + return video + + def decord_read(self, path): + decord_vr = self.v_decoder(path) + total_frames = len(decord_vr) + frame_indice = list(range(total_frames)) + video_data = decord_vr.get_batch(frame_indice).asnumpy() + video_data = torch.from_numpy(video_data) + video_data = video_data.permute(0, 3, 1, 2) # (T, H, W, C) -> (T C H W) + return video_data + + + +class ExtractImage2Feature(Dataset): + def __init__(self, args, transform): + self.data_path = args.data_path + self.transform = transform + self.data_all = list(glob(f'{self.data_path}')) + + def __len__(self): + return len(self.data_all) + + def __getitem__(self, index): + path = self.data_all[index] + video_frame = torch.as_tensor(np.array(Image.open(path), dtype=np.uint8, copy=True)).unsqueeze(0) + video_frame = video_frame.permute(0, 3, 1, 2) + video_frame = self.transform(video_frame) # T C H W + # video_frame = video_frame.transpose(0, 1) # T C H W -> C T H W + + return video_frame, path + diff --git a/opensora/dataset/feature_datasets.py b/opensora/dataset/feature_datasets.py new file mode 100644 index 0000000000000000000000000000000000000000..74362528a68e8f73a9781a13fafb8854688dd29e --- /dev/null +++ b/opensora/dataset/feature_datasets.py @@ -0,0 +1,213 @@ +import json +import os +import torch +import random +import torch.utils.data as data + +import numpy as np +from glob import glob +from PIL import Image +from torch.utils.data import Dataset +from tqdm import tqdm + +from opensora.dataset.transform import center_crop, RandomCropVideo +from opensora.utils.dataset_utils import DecordInit + + +class T2V_Feature_dataset(Dataset): + def __init__(self, args, temporal_sample): + + self.video_folder = args.video_folder + self.num_frames = args.video_length + self.temporal_sample = temporal_sample + + print('Building dataset...') + if os.path.exists('samples_430k.json'): + with open('samples_430k.json', 'r') as f: + self.samples = json.load(f) + else: + self.samples = self._make_dataset() + with open('samples_430k.json', 'w') as f: + json.dump(self.samples, f, indent=2) + + self.use_image_num = args.use_image_num + self.use_img_from_vid = args.use_img_from_vid + if self.use_image_num != 0 and not self.use_img_from_vid: + self.img_cap_list = self.get_img_cap_list() + + def _make_dataset(self): + all_mp4 = list(glob(os.path.join(self.video_folder, '**', '*.mp4'), recursive=True)) + # all_mp4 = all_mp4[:1000] + samples = [] + for i in tqdm(all_mp4): + video_id = os.path.basename(i).split('.')[0] + ae = os.path.split(i)[0].replace('data_split_tt', 'lb_causalvideovae444_feature') + ae = os.path.join(ae, f'{video_id}_causalvideovae444.npy') + if not os.path.exists(ae): + continue + + t5 = os.path.split(i)[0].replace('data_split_tt', 'lb_t5_feature') + cond_list = [] + cond_llava = os.path.join(t5, f'{video_id}_t5_llava_fea.npy') + mask_llava = os.path.join(t5, f'{video_id}_t5_llava_mask.npy') + if os.path.exists(cond_llava) and os.path.exists(mask_llava): + llava = dict(cond=cond_llava, mask=mask_llava) + cond_list.append(llava) + cond_sharegpt4v = os.path.join(t5, f'{video_id}_t5_sharegpt4v_fea.npy') + mask_sharegpt4v = os.path.join(t5, f'{video_id}_t5_sharegpt4v_mask.npy') + if os.path.exists(cond_sharegpt4v) and os.path.exists(mask_sharegpt4v): + sharegpt4v = dict(cond=cond_sharegpt4v, mask=mask_sharegpt4v) + cond_list.append(sharegpt4v) + if len(cond_list) > 0: + sample = dict(ae=ae, t5=cond_list) + samples.append(sample) + return samples + + def __len__(self): + return len(self.samples) + + def __getitem__(self, idx): + # try: + sample = self.samples[idx] + ae, t5 = sample['ae'], sample['t5'] + t5 = random.choice(t5) + video_origin = np.load(ae)[0] # C T H W + _, total_frames, _, _ = video_origin.shape + # Sampling video frames + start_frame_ind, end_frame_ind = self.temporal_sample(total_frames) + assert end_frame_ind - start_frame_ind >= self.num_frames + select_video_idx = np.linspace(start_frame_ind, end_frame_ind - 1, num=self.num_frames, dtype=int) # start, stop, num=50 + # print('select_video_idx', total_frames, select_video_idx) + video = video_origin[:, select_video_idx] # C num_frames H W + video = torch.from_numpy(video) + + cond = torch.from_numpy(np.load(t5['cond']))[0] # L + cond_mask = torch.from_numpy(np.load(t5['mask']))[0] # L D + + if self.use_image_num != 0 and self.use_img_from_vid: + select_image_idx = np.random.randint(0, total_frames, self.use_image_num) + # print('select_image_idx', total_frames, self.use_image_num, select_image_idx) + images = video_origin[:, select_image_idx] # c, num_img, h, w + images = torch.from_numpy(images) + video = torch.cat([video, images], dim=1) # c, num_frame+num_img, h, w + cond = torch.stack([cond] * (1+self.use_image_num)) # 1+self.use_image_num, l + cond_mask = torch.stack([cond_mask] * (1+self.use_image_num)) # 1+self.use_image_num, l + elif self.use_image_num != 0 and not self.use_img_from_vid: + images, captions = self.img_cap_list[idx] + raise NotImplementedError + else: + pass + + return video, cond, cond_mask + # except Exception as e: + # print(f'Error with {e}, {sample}') + # return self.__getitem__(random.randint(0, self.__len__() - 1)) + + def get_img_cap_list(self): + raise NotImplementedError + + + + +class T2V_T5_Feature_dataset(Dataset): + def __init__(self, args, transform, temporal_sample): + + self.video_folder = args.video_folder + self.num_frames = args.num_frames + self.transform = transform + self.temporal_sample = temporal_sample + self.v_decoder = DecordInit() + + print('Building dataset...') + if os.path.exists('samples_430k.json'): + with open('samples_430k.json', 'r') as f: + self.samples = json.load(f) + self.samples = [dict(ae=i['ae'].replace('lb_causalvideovae444_feature', 'data_split_1024').replace('_causalvideovae444.npy', '.mp4'), t5=i['t5']) for i in self.samples] + else: + self.samples = self._make_dataset() + with open('samples_430k.json', 'w') as f: + json.dump(self.samples, f, indent=2) + + self.use_image_num = args.use_image_num + self.use_img_from_vid = args.use_img_from_vid + if self.use_image_num != 0 and not self.use_img_from_vid: + self.img_cap_list = self.get_img_cap_list() + + def _make_dataset(self): + all_mp4 = list(glob(os.path.join(self.video_folder, '**', '*.mp4'), recursive=True)) + # all_mp4 = all_mp4[:1000] + samples = [] + for i in tqdm(all_mp4): + video_id = os.path.basename(i).split('.')[0] + # ae = os.path.split(i)[0].replace('data_split', 'lb_causalvideovae444_feature') + # ae = os.path.join(ae, f'{video_id}_causalvideovae444.npy') + ae = i + if not os.path.exists(ae): + continue + + t5 = os.path.split(i)[0].replace('data_split_1024', 'lb_t5_feature') + cond_list = [] + cond_llava = os.path.join(t5, f'{video_id}_t5_llava_fea.npy') + mask_llava = os.path.join(t5, f'{video_id}_t5_llava_mask.npy') + if os.path.exists(cond_llava) and os.path.exists(mask_llava): + llava = dict(cond=cond_llava, mask=mask_llava) + cond_list.append(llava) + cond_sharegpt4v = os.path.join(t5, f'{video_id}_t5_sharegpt4v_fea.npy') + mask_sharegpt4v = os.path.join(t5, f'{video_id}_t5_sharegpt4v_mask.npy') + if os.path.exists(cond_sharegpt4v) and os.path.exists(mask_sharegpt4v): + sharegpt4v = dict(cond=cond_sharegpt4v, mask=mask_sharegpt4v) + cond_list.append(sharegpt4v) + if len(cond_list) > 0: + sample = dict(ae=ae, t5=cond_list) + samples.append(sample) + return samples + + def __len__(self): + return len(self.samples) + + def __getitem__(self, idx): + try: + sample = self.samples[idx] + ae, t5 = sample['ae'], sample['t5'] + t5 = random.choice(t5) + + video = self.decord_read(ae) + video = self.transform(video) # T C H W -> T C H W + video = video.transpose(0, 1) # T C H W -> C T H W + total_frames = video.shape[1] + cond = torch.from_numpy(np.load(t5['cond']))[0] # L + cond_mask = torch.from_numpy(np.load(t5['mask']))[0] # L D + + if self.use_image_num != 0 and self.use_img_from_vid: + select_image_idx = np.random.randint(0, total_frames, self.use_image_num) + # print('select_image_idx', total_frames, self.use_image_num, select_image_idx) + images = video.numpy()[:, select_image_idx] # c, num_img, h, w + images = torch.from_numpy(images) + video = torch.cat([video, images], dim=1) # c, num_frame+num_img, h, w + cond = torch.stack([cond] * (1+self.use_image_num)) # 1+self.use_image_num, l + cond_mask = torch.stack([cond_mask] * (1+self.use_image_num)) # 1+self.use_image_num, l + elif self.use_image_num != 0 and not self.use_img_from_vid: + images, captions = self.img_cap_list[idx] + raise NotImplementedError + else: + pass + + return video, cond, cond_mask + except Exception as e: + print(f'Error with {e}, {sample}') + return self.__getitem__(random.randint(0, self.__len__() - 1)) + + def decord_read(self, path): + decord_vr = self.v_decoder(path) + total_frames = len(decord_vr) + # Sampling video frames + start_frame_ind, end_frame_ind = self.temporal_sample(total_frames) + # assert end_frame_ind - start_frame_ind >= self.num_frames + frame_indice = np.linspace(start_frame_ind, end_frame_ind - 1, self.num_frames, dtype=int) + video_data = decord_vr.get_batch(frame_indice).asnumpy() + video_data = torch.from_numpy(video_data) + video_data = video_data.permute(0, 3, 1, 2) # (T, H, W, C) -> (T C H W) + return video_data + + def get_img_cap_list(self): + raise NotImplementedError \ No newline at end of file diff --git a/opensora/dataset/landscope.py b/opensora/dataset/landscope.py new file mode 100644 index 0000000000000000000000000000000000000000..a1314170e3b4decc4186ee98a44be06218b570da --- /dev/null +++ b/opensora/dataset/landscope.py @@ -0,0 +1,90 @@ +import math +import os +from glob import glob + +import decord +import numpy as np +import torch +import torchvision +from decord import VideoReader, cpu +from torch.utils.data import Dataset +from torchvision.transforms import Compose, Lambda, ToTensor +from torchvision.transforms._transforms_video import NormalizeVideo, RandomCropVideo, RandomHorizontalFlipVideo +from pytorchvideo.transforms import ApplyTransformToKey, ShortSideScale, UniformTemporalSubsample +from torch.nn import functional as F +import random + +from opensora.utils.dataset_utils import DecordInit + + +class Landscope(Dataset): + def __init__(self, args, transform, temporal_sample): + self.data_path = args.data_path + self.num_frames = args.num_frames + self.transform = transform + self.temporal_sample = temporal_sample + self.v_decoder = DecordInit() + + self.samples = self._make_dataset() + self.use_image_num = args.use_image_num + self.use_img_from_vid = args.use_img_from_vid + if self.use_image_num != 0 and not self.use_img_from_vid: + self.img_cap_list = self.get_img_cap_list() + + + def _make_dataset(self): + paths = list(glob(os.path.join(self.data_path, '**', '*.mp4'), recursive=True)) + + return paths + + def __len__(self): + return len(self.samples) + + def __getitem__(self, idx): + video_path = self.samples[idx] + try: + video = self.tv_read(video_path) + video = self.transform(video) # T C H W -> T C H W + video = video.transpose(0, 1) # T C H W -> C T H W + if self.use_image_num != 0 and self.use_img_from_vid: + select_image_idx = np.linspace(0, self.num_frames - 1, self.use_image_num, dtype=int) + assert self.num_frames >= self.use_image_num + images = video[:, select_image_idx] # c, num_img, h, w + video = torch.cat([video, images], dim=1) # c, num_frame+num_img, h, w + elif self.use_image_num != 0 and not self.use_img_from_vid: + images, captions = self.img_cap_list[idx] + raise NotImplementedError + else: + pass + return video, 1 + except Exception as e: + print(f'Error with {e}, {video_path}') + return self.__getitem__(random.randint(0, self.__len__()-1)) + + def tv_read(self, path): + vframes, aframes, info = torchvision.io.read_video(filename=path, pts_unit='sec', output_format='TCHW') + total_frames = len(vframes) + + # Sampling video frames + start_frame_ind, end_frame_ind = self.temporal_sample(total_frames) + # assert end_frame_ind - start_frame_ind >= self.num_frames + frame_indice = np.linspace(start_frame_ind, end_frame_ind - 1, self.num_frames, dtype=int) + video = vframes[frame_indice] # (T, C, H, W) + + return video + + def decord_read(self, path): + decord_vr = self.v_decoder(path) + total_frames = len(decord_vr) + # Sampling video frames + start_frame_ind, end_frame_ind = self.temporal_sample(total_frames) + # assert end_frame_ind - start_frame_ind >= self.num_frames + frame_indice = np.linspace(start_frame_ind, end_frame_ind - 1, self.num_frames, dtype=int) + + video_data = decord_vr.get_batch(frame_indice).asnumpy() + video_data = torch.from_numpy(video_data) + video_data = video_data.permute(0, 3, 1, 2) # (T, H, W, C) -> (T C H W) + return video_data + + def get_img_cap_list(self): + raise NotImplementedError diff --git a/opensora/dataset/sky_datasets.py b/opensora/dataset/sky_datasets.py new file mode 100644 index 0000000000000000000000000000000000000000..23bbfd22656697f90ad547c893b01ee0faee74b9 --- /dev/null +++ b/opensora/dataset/sky_datasets.py @@ -0,0 +1,128 @@ +import os +import torch +import random +import torch.utils.data as data + +import numpy as np + +from PIL import Image + +from opensora.utils.dataset_utils import is_image_file + + +class Sky(data.Dataset): + def __init__(self, args, transform, temporal_sample=None, train=True): + + self.args = args + self.data_path = args.data_path + self.transform = transform + self.temporal_sample = temporal_sample + self.num_frames = self.args.num_frames + self.sample_rate = self.args.sample_rate + self.data_all = self.load_video_frames(self.data_path) + self.use_image_num = args.use_image_num + self.use_img_from_vid = args.use_img_from_vid + if self.use_image_num != 0 and not self.use_img_from_vid: + self.img_cap_list = self.get_img_cap_list() + + def __getitem__(self, index): + + vframes = self.data_all[index] + total_frames = len(vframes) + + # Sampling video frames + start_frame_ind, end_frame_ind = self.temporal_sample(total_frames) + assert end_frame_ind - start_frame_ind >= self.num_frames + frame_indice = np.linspace(start_frame_ind, end_frame_ind-1, num=self.num_frames, dtype=int) # start, stop, num=50 + + select_video_frames = vframes[frame_indice[0]: frame_indice[-1]+1: self.sample_rate] + + video_frames = [] + for path in select_video_frames: + video_frame = torch.as_tensor(np.array(Image.open(path), dtype=np.uint8, copy=True)).unsqueeze(0) + video_frames.append(video_frame) + video_clip = torch.cat(video_frames, dim=0).permute(0, 3, 1, 2) + video_clip = self.transform(video_clip) + video_clip = video_clip.transpose(0, 1) # T C H W -> C T H W + + if self.use_image_num != 0 and self.use_img_from_vid: + select_image_idx = np.linspace(0, self.num_frames - 1, self.use_image_num, dtype=int) + assert self.num_frames >= self.use_image_num + images = video_clip[:, select_image_idx] # c, num_img, h, w + video_clip = torch.cat([video_clip, images], dim=1) # c, num_frame+num_img, h, w + elif self.use_image_num != 0 and not self.use_img_from_vid: + images, captions = self.img_cap_list[index] + raise NotImplementedError + else: + pass + + return video_clip, 1 + + def __len__(self): + return self.video_num + + def load_video_frames(self, dataroot): + data_all = [] + frame_list = os.walk(dataroot) + for _, meta in enumerate(frame_list): + root = meta[0] + try: + frames = [i for i in meta[2] if is_image_file(i)] + frames = sorted(frames, key=lambda item: int(item.split('.')[0].split('_')[-1])) + except: + pass + # print(meta[0]) # root + # print(meta[2]) # files + frames = [os.path.join(root, item) for item in frames if is_image_file(item)] + if len(frames) > max(0, self.num_frames * self.sample_rate): # need all > (16 * frame-interval) videos + # if len(frames) >= max(0, self.target_video_len): # need all > 16 frames videos + data_all.append(frames) + self.video_num = len(data_all) + return data_all + + def get_img_cap_list(self): + raise NotImplementedError + +if __name__ == '__main__': + + import argparse + import torchvision + import video_transforms + import torch.utils.data as data + + from torchvision import transforms + from torchvision.utils import save_image + + + parser = argparse.ArgumentParser() + parser.add_argument("--num_frames", type=int, default=16) + parser.add_argument("--frame_interval", type=int, default=4) + parser.add_argument("--data-path", type=str, default="/path/to/datasets/sky_timelapse/sky_train/") + config = parser.parse_args() + + + target_video_len = config.num_frames + + temporal_sample = video_transforms.TemporalRandomCrop(target_video_len * config.frame_interval) + trans = transforms.Compose([ + video_transforms.ToTensorVideo(), + # video_transforms.CenterCropVideo(256), + video_transforms.CenterCropResizeVideo(256), + # video_transforms.RandomHorizontalFlipVideo(), + transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True) + ]) + + taichi_dataset = Sky(config, transform=trans, temporal_sample=temporal_sample) + print(len(taichi_dataset)) + taichi_dataloader = data.DataLoader(dataset=taichi_dataset, batch_size=1, shuffle=False, num_workers=1) + + for i, video_data in enumerate(taichi_dataloader): + print(video_data['video'].shape) + + # print(video_data.dtype) + # for i in range(target_video_len): + # save_image(video_data[0][i], os.path.join('./test_data', '%04d.png' % i), normalize=True, value_range=(-1, 1)) + + # video_ = ((video_data[0] * 0.5 + 0.5) * 255).add_(0.5).clamp_(0, 255).to(dtype=torch.uint8).cpu().permute(0, 2, 3, 1) + # torchvision.io.write_video('./test_data' + 'test.mp4', video_, fps=8) + # exit() \ No newline at end of file diff --git a/opensora/dataset/t2v_datasets.py b/opensora/dataset/t2v_datasets.py new file mode 100644 index 0000000000000000000000000000000000000000..6741d87e14f36e85d43394dd9400370b49352103 --- /dev/null +++ b/opensora/dataset/t2v_datasets.py @@ -0,0 +1,111 @@ +import json +import os, io, csv, math, random +import numpy as np +import torchvision +from einops import rearrange +from decord import VideoReader + +import torch +import torchvision.transforms as transforms +from torch.utils.data.dataset import Dataset +from tqdm import tqdm + +from opensora.utils.dataset_utils import DecordInit +from opensora.utils.utils import text_preprocessing + + + +class T2V_dataset(Dataset): + def __init__(self, args, transform, temporal_sample, tokenizer): + + # with open(args.data_path, 'r') as csvfile: + # self.samples = list(csv.DictReader(csvfile)) + self.video_folder = args.video_folder + self.num_frames = args.num_frames + self.transform = transform + self.temporal_sample = temporal_sample + self.tokenizer = tokenizer + self.model_max_length = args.model_max_length + self.v_decoder = DecordInit() + + with open(args.data_path, 'r') as f: + self.samples = json.load(f) + self.use_image_num = args.use_image_num + self.use_img_from_vid = args.use_img_from_vid + if self.use_image_num != 0 and not self.use_img_from_vid: + self.img_cap_list = self.get_img_cap_list() + + def __len__(self): + return len(self.samples) + + def __getitem__(self, idx): + try: + # video = torch.randn(3, 16, 128, 128) + # input_ids = torch.ones(1, 120).to(torch.long).squeeze(0) + # cond_mask = torch.cat([torch.ones(1, 60).to(torch.long), torch.ones(1, 60).to(torch.long)], dim=1).squeeze(0) + # return video, input_ids, cond_mask + video_path = self.samples[idx]['path'] + video = self.decord_read(video_path) + video = self.transform(video) # T C H W -> T C H W + video = video.transpose(0, 1) # T C H W -> C T H W + text = self.samples[idx]['cap'][0] + + text = text_preprocessing(text) + text_tokens_and_mask = self.tokenizer( + text, + max_length=self.model_max_length, + padding='max_length', + truncation=True, + return_attention_mask=True, + add_special_tokens=True, + return_tensors='pt' + ) + input_ids = text_tokens_and_mask['input_ids'].squeeze(0) + cond_mask = text_tokens_and_mask['attention_mask'].squeeze(0) + + if self.use_image_num != 0 and self.use_img_from_vid: + select_image_idx = np.linspace(0, self.num_frames-1, self.use_image_num, dtype=int) + assert self.num_frames >= self.use_image_num + images = video[:, select_image_idx] # c, num_img, h, w + video = torch.cat([video, images], dim=1) # c, num_frame+num_img, h, w + input_ids = torch.stack([input_ids] * (1+self.use_image_num)) # 1+self.use_image_num, l + cond_mask = torch.stack([cond_mask] * (1+self.use_image_num)) # 1+self.use_image_num, l + elif self.use_image_num != 0 and not self.use_img_from_vid: + images, captions = self.img_cap_list[idx] + raise NotImplementedError + else: + pass + + return video, input_ids, cond_mask + except Exception as e: + print(f'Error with {e}, {self.samples[idx]}') + return self.__getitem__(random.randint(0, self.__len__() - 1)) + + def tv_read(self, path): + vframes, aframes, info = torchvision.io.read_video(filename=path, pts_unit='sec', output_format='TCHW') + total_frames = len(vframes) + + # Sampling video frames + start_frame_ind, end_frame_ind = self.temporal_sample(total_frames) + # assert end_frame_ind - start_frame_ind >= self.num_frames + frame_indice = np.linspace(start_frame_ind, end_frame_ind - 1, self.num_frames, dtype=int) + + video = vframes[frame_indice] # (T, C, H, W) + + return video + + def decord_read(self, path): + decord_vr = self.v_decoder(path) + total_frames = len(decord_vr) + # Sampling video frames + start_frame_ind, end_frame_ind = self.temporal_sample(total_frames) + # assert end_frame_ind - start_frame_ind >= self.num_frames + frame_indice = np.linspace(start_frame_ind, end_frame_ind - 1, self.num_frames, dtype=int) + + video_data = decord_vr.get_batch(frame_indice).asnumpy() + video_data = torch.from_numpy(video_data) + video_data = video_data.permute(0, 3, 1, 2) # (T, H, W, C) -> (T C H W) + return video_data + + def get_img_cap_list(self): + raise NotImplementedError \ No newline at end of file diff --git a/opensora/dataset/transform.py b/opensora/dataset/transform.py new file mode 100644 index 0000000000000000000000000000000000000000..afcd611a6fcbfa2e8106a47a2139d7fa793b3af5 --- /dev/null +++ b/opensora/dataset/transform.py @@ -0,0 +1,489 @@ +import torch +import random +import numbers +from torchvision.transforms import RandomCrop, RandomResizedCrop + + +def _is_tensor_video_clip(clip): + if not torch.is_tensor(clip): + raise TypeError("clip should be Tensor. Got %s" % type(clip)) + + if not clip.ndimension() == 4: + raise ValueError("clip should be 4D. Got %dD" % clip.dim()) + + return True + + +def center_crop_arr(pil_image, image_size): + """ + Center cropping implementation from ADM. + https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126 + """ + while min(*pil_image.size) >= 2 * image_size: + pil_image = pil_image.resize( + tuple(x // 2 for x in pil_image.size), resample=Image.BOX + ) + + scale = image_size / min(*pil_image.size) + pil_image = pil_image.resize( + tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC + ) + + arr = np.array(pil_image) + crop_y = (arr.shape[0] - image_size) // 2 + crop_x = (arr.shape[1] - image_size) // 2 + return Image.fromarray(arr[crop_y: crop_y + image_size, crop_x: crop_x + image_size]) + + +def crop(clip, i, j, h, w): + """ + Args: + clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W) + """ + if len(clip.size()) != 4: + raise ValueError("clip should be a 4D tensor") + return clip[..., i: i + h, j: j + w] + + +def resize(clip, target_size, interpolation_mode): + if len(target_size) != 2: + raise ValueError(f"target size should be tuple (height, width), instead got {target_size}") + return torch.nn.functional.interpolate(clip, size=target_size, mode=interpolation_mode, align_corners=True, antialias=True) + + +def resize_scale(clip, target_size, interpolation_mode): + if len(target_size) != 2: + raise ValueError(f"target size should be tuple (height, width), instead got {target_size}") + H, W = clip.size(-2), clip.size(-1) + scale_ = target_size[0] / min(H, W) + return torch.nn.functional.interpolate(clip, scale_factor=scale_, mode=interpolation_mode, align_corners=True, antialias=True) + + +def resized_crop(clip, i, j, h, w, size, interpolation_mode="bilinear"): + """ + Do spatial cropping and resizing to the video clip + Args: + clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W) + i (int): i in (i,j) i.e coordinates of the upper left corner. + j (int): j in (i,j) i.e coordinates of the upper left corner. + h (int): Height of the cropped region. + w (int): Width of the cropped region. + size (tuple(int, int)): height and width of resized clip + Returns: + clip (torch.tensor): Resized and cropped clip. Size is (T, C, H, W) + """ + if not _is_tensor_video_clip(clip): + raise ValueError("clip should be a 4D torch.tensor") + clip = crop(clip, i, j, h, w) + clip = resize(clip, size, interpolation_mode) + return clip + + +def center_crop(clip, crop_size): + if not _is_tensor_video_clip(clip): + raise ValueError("clip should be a 4D torch.tensor") + h, w = clip.size(-2), clip.size(-1) + th, tw = crop_size + if h < th or w < tw: + raise ValueError("height and width must be no smaller than crop_size") + + i = int(round((h - th) / 2.0)) + j = int(round((w - tw) / 2.0)) + return crop(clip, i, j, th, tw) + + +def center_crop_using_short_edge(clip): + if not _is_tensor_video_clip(clip): + raise ValueError("clip should be a 4D torch.tensor") + h, w = clip.size(-2), clip.size(-1) + if h < w: + th, tw = h, h + i = 0 + j = int(round((w - tw) / 2.0)) + else: + th, tw = w, w + i = int(round((h - th) / 2.0)) + j = 0 + return crop(clip, i, j, th, tw) + + +def random_shift_crop(clip): + ''' + Slide along the long edge, with the short edge as crop size + ''' + if not _is_tensor_video_clip(clip): + raise ValueError("clip should be a 4D torch.tensor") + h, w = clip.size(-2), clip.size(-1) + + if h <= w: + long_edge = w + short_edge = h + else: + long_edge = h + short_edge = w + + th, tw = short_edge, short_edge + + i = torch.randint(0, h - th + 1, size=(1,)).item() + j = torch.randint(0, w - tw + 1, size=(1,)).item() + return crop(clip, i, j, th, tw) + + +def to_tensor(clip): + """ + Convert tensor data type from uint8 to float, divide value by 255.0 and + permute the dimensions of clip tensor + Args: + clip (torch.tensor, dtype=torch.uint8): Size is (T, C, H, W) + Return: + clip (torch.tensor, dtype=torch.float): Size is (T, C, H, W) + """ + _is_tensor_video_clip(clip) + if not clip.dtype == torch.uint8: + raise TypeError("clip tensor should have data type uint8. Got %s" % str(clip.dtype)) + # return clip.float().permute(3, 0, 1, 2) / 255.0 + return clip.float() / 255.0 + + +def normalize(clip, mean, std, inplace=False): + """ + Args: + clip (torch.tensor): Video clip to be normalized. Size is (T, C, H, W) + mean (tuple): pixel RGB mean. Size is (3) + std (tuple): pixel standard deviation. Size is (3) + Returns: + normalized clip (torch.tensor): Size is (T, C, H, W) + """ + if not _is_tensor_video_clip(clip): + raise ValueError("clip should be a 4D torch.tensor") + if not inplace: + clip = clip.clone() + mean = torch.as_tensor(mean, dtype=clip.dtype, device=clip.device) + # print(mean) + std = torch.as_tensor(std, dtype=clip.dtype, device=clip.device) + clip.sub_(mean[:, None, None, None]).div_(std[:, None, None, None]) + return clip + + +def hflip(clip): + """ + Args: + clip (torch.tensor): Video clip to be normalized. Size is (T, C, H, W) + Returns: + flipped clip (torch.tensor): Size is (T, C, H, W) + """ + if not _is_tensor_video_clip(clip): + raise ValueError("clip should be a 4D torch.tensor") + return clip.flip(-1) + + +class RandomCropVideo: + def __init__(self, size): + if isinstance(size, numbers.Number): + self.size = (int(size), int(size)) + else: + self.size = size + + def __call__(self, clip): + """ + Args: + clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W) + Returns: + torch.tensor: randomly cropped video clip. + size is (T, C, OH, OW) + """ + i, j, h, w = self.get_params(clip) + return crop(clip, i, j, h, w) + + def get_params(self, clip): + h, w = clip.shape[-2:] + th, tw = self.size + + if h < th or w < tw: + raise ValueError(f"Required crop size {(th, tw)} is larger than input image size {(h, w)}") + + if w == tw and h == th: + return 0, 0, h, w + + i = torch.randint(0, h - th + 1, size=(1,)).item() + j = torch.randint(0, w - tw + 1, size=(1,)).item() + + return i, j, th, tw + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(size={self.size})" + + +class CenterCropResizeVideo: + ''' + First use the short side for cropping length, + center crop video, then resize to the specified size + ''' + + def __init__( + self, + size, + interpolation_mode="bilinear", + ): + if isinstance(size, tuple): + if len(size) != 2: + raise ValueError(f"size should be tuple (height, width), instead got {size}") + self.size = size + else: + self.size = (size, size) + + self.interpolation_mode = interpolation_mode + + def __call__(self, clip): + """ + Args: + clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W) + Returns: + torch.tensor: scale resized / center cropped video clip. + size is (T, C, crop_size, crop_size) + """ + clip_center_crop = center_crop_using_short_edge(clip) + clip_center_crop_resize = resize(clip_center_crop, target_size=self.size, + interpolation_mode=self.interpolation_mode) + return clip_center_crop_resize + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode}" + + +class UCFCenterCropVideo: + ''' + First scale to the specified size in equal proportion to the short edge, + then center cropping + ''' + + def __init__( + self, + size, + interpolation_mode="bilinear", + ): + if isinstance(size, tuple): + if len(size) != 2: + raise ValueError(f"size should be tuple (height, width), instead got {size}") + self.size = size + else: + self.size = (size, size) + + self.interpolation_mode = interpolation_mode + + def __call__(self, clip): + """ + Args: + clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W) + Returns: + torch.tensor: scale resized / center cropped video clip. + size is (T, C, crop_size, crop_size) + """ + clip_resize = resize_scale(clip=clip, target_size=self.size, interpolation_mode=self.interpolation_mode) + clip_center_crop = center_crop(clip_resize, self.size) + return clip_center_crop + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode}" + + +class KineticsRandomCropResizeVideo: + ''' + Slide along the long edge, with the short edge as crop size. And resie to the desired size. + ''' + + def __init__( + self, + size, + interpolation_mode="bilinear", + ): + if isinstance(size, tuple): + if len(size) != 2: + raise ValueError(f"size should be tuple (height, width), instead got {size}") + self.size = size + else: + self.size = (size, size) + + self.interpolation_mode = interpolation_mode + + def __call__(self, clip): + clip_random_crop = random_shift_crop(clip) + clip_resize = resize(clip_random_crop, self.size, self.interpolation_mode) + return clip_resize + + +class CenterCropVideo: + def __init__( + self, + size, + interpolation_mode="bilinear", + ): + if isinstance(size, tuple): + if len(size) != 2: + raise ValueError(f"size should be tuple (height, width), instead got {size}") + self.size = size + else: + self.size = (size, size) + + self.interpolation_mode = interpolation_mode + + def __call__(self, clip): + """ + Args: + clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W) + Returns: + torch.tensor: center cropped video clip. + size is (T, C, crop_size, crop_size) + """ + clip_center_crop = center_crop(clip, self.size) + return clip_center_crop + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode}" + + +class NormalizeVideo: + """ + Normalize the video clip by mean subtraction and division by standard deviation + Args: + mean (3-tuple): pixel RGB mean + std (3-tuple): pixel RGB standard deviation + inplace (boolean): whether do in-place normalization + """ + + def __init__(self, mean, std, inplace=False): + self.mean = mean + self.std = std + self.inplace = inplace + + def __call__(self, clip): + """ + Args: + clip (torch.tensor): video clip must be normalized. Size is (C, T, H, W) + """ + return normalize(clip, self.mean, self.std, self.inplace) + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(mean={self.mean}, std={self.std}, inplace={self.inplace})" + + +class ToTensorVideo: + """ + Convert tensor data type from uint8 to float, divide value by 255.0 and + permute the dimensions of clip tensor + """ + + def __init__(self): + pass + + def __call__(self, clip): + """ + Args: + clip (torch.tensor, dtype=torch.uint8): Size is (T, C, H, W) + Return: + clip (torch.tensor, dtype=torch.float): Size is (T, C, H, W) + """ + return to_tensor(clip) + + def __repr__(self) -> str: + return self.__class__.__name__ + + +class RandomHorizontalFlipVideo: + """ + Flip the video clip along the horizontal direction with a given probability + Args: + p (float): probability of the clip being flipped. Default value is 0.5 + """ + + def __init__(self, p=0.5): + self.p = p + + def __call__(self, clip): + """ + Args: + clip (torch.tensor): Size is (T, C, H, W) + Return: + clip (torch.tensor): Size is (T, C, H, W) + """ + if random.random() < self.p: + clip = hflip(clip) + return clip + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(p={self.p})" + + +# ------------------------------------------------------------ +# --------------------- Sampling --------------------------- +# ------------------------------------------------------------ +class TemporalRandomCrop(object): + """Temporally crop the given frame indices at a random location. + + Args: + size (int): Desired length of frames will be seen in the model. + """ + + def __init__(self, size): + self.size = size + + def __call__(self, total_frames): + rand_end = max(0, total_frames - self.size - 1) + begin_index = random.randint(0, rand_end) + end_index = min(begin_index + self.size, total_frames) + return begin_index, end_index + + +if __name__ == '__main__': + from torchvision import transforms + import torchvision.io as io + import numpy as np + from torchvision.utils import save_image + import os + + vframes, aframes, info = io.read_video( + filename='./v_Archery_g01_c03.avi', + pts_unit='sec', + output_format='TCHW' + ) + + trans = transforms.Compose([ + ToTensorVideo(), + RandomHorizontalFlipVideo(), + UCFCenterCropVideo(512), + # NormalizeVideo(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True), + transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True) + ]) + + target_video_len = 32 + frame_interval = 1 + total_frames = len(vframes) + print(total_frames) + + temporal_sample = TemporalRandomCrop(target_video_len * frame_interval) + + # Sampling video frames + start_frame_ind, end_frame_ind = temporal_sample(total_frames) + # print(start_frame_ind) + # print(end_frame_ind) + assert end_frame_ind - start_frame_ind >= target_video_len + frame_indice = np.linspace(start_frame_ind, end_frame_ind - 1, target_video_len, dtype=int) + print(frame_indice) + + select_vframes = vframes[frame_indice] + print(select_vframes.shape) + print(select_vframes.dtype) + + select_vframes_trans = trans(select_vframes) + print(select_vframes_trans.shape) + print(select_vframes_trans.dtype) + + select_vframes_trans_int = ((select_vframes_trans * 0.5 + 0.5) * 255).to(dtype=torch.uint8) + print(select_vframes_trans_int.dtype) + print(select_vframes_trans_int.permute(0, 2, 3, 1).shape) + + io.write_video('./test.avi', select_vframes_trans_int.permute(0, 2, 3, 1), fps=8) + + for i in range(target_video_len): + save_image(select_vframes_trans[i], os.path.join('./test000', '%04d.png' % i), normalize=True, + value_range=(-1, 1)) diff --git a/opensora/dataset/ucf101.py b/opensora/dataset/ucf101.py new file mode 100644 index 0000000000000000000000000000000000000000..c368976fafec8aa7bf5450fcf3990b21ebc4e121 --- /dev/null +++ b/opensora/dataset/ucf101.py @@ -0,0 +1,80 @@ +import math +import os + +import decord +import numpy as np +import torch +import torchvision +from decord import VideoReader, cpu +from torch.utils.data import Dataset +from torchvision.transforms import Compose, Lambda, ToTensor +from torchvision.transforms._transforms_video import NormalizeVideo, RandomCropVideo, RandomHorizontalFlipVideo +from pytorchvideo.transforms import ApplyTransformToKey, ShortSideScale, UniformTemporalSubsample +from torch.nn import functional as F +import random + +from opensora.utils.dataset_utils import DecordInit + + +class UCF101(Dataset): + def __init__(self, args, transform, temporal_sample): + self.data_path = args.data_path + self.num_frames = args.num_frames + self.transform = transform + self.temporal_sample = temporal_sample + self.v_decoder = DecordInit() + + self.classes = sorted(os.listdir(self.data_path)) + self.class_to_idx = {cls_name: idx for idx, cls_name in enumerate(self.classes)} + self.samples = self._make_dataset() + + + def _make_dataset(self): + dataset = [] + for class_name in self.classes: + class_path = os.path.join(self.data_path, class_name) + for fname in os.listdir(class_path): + if fname.endswith('.avi'): + item = (os.path.join(class_path, fname), self.class_to_idx[class_name]) + dataset.append(item) + return dataset + + def __len__(self): + return len(self.samples) + + def __getitem__(self, idx): + video_path, label = self.samples[idx] + try: + video = self.tv_read(video_path) + video = self.transform(video) # T C H W -> T C H W + video = video.transpose(0, 1) # T C H W -> C T H W + return video, label + except Exception as e: + print(f'Error with {e}, {video_path}') + return self.__getitem__(random.randint(0, self.__len__()-1)) + + def tv_read(self, path): + vframes, aframes, info = torchvision.io.read_video(filename=path, pts_unit='sec', output_format='TCHW') + total_frames = len(vframes) + + # Sampling video frames + start_frame_ind, end_frame_ind = self.temporal_sample(total_frames) + # assert end_frame_ind - start_frame_ind >= self.num_frames + frame_indice = np.linspace(start_frame_ind, end_frame_ind - 1, self.num_frames, dtype=int) + video = vframes[frame_indice] # (T, C, H, W) + + return video + + def decord_read(self, path): + decord_vr = self.v_decoder(path) + total_frames = len(decord_vr) + # Sampling video frames + start_frame_ind, end_frame_ind = self.temporal_sample(total_frames) + # assert end_frame_ind - start_frame_ind >= self.num_frames + frame_indice = np.linspace(start_frame_ind, end_frame_ind - 1, self.num_frames, dtype=int) + + video_data = decord_vr.get_batch(frame_indice).asnumpy() + video_data = torch.from_numpy(video_data) + video_data = video_data.permute(0, 3, 1, 2) # (T, H, W, C) -> (T C H W) + return video_data + diff --git a/opensora/eval/cal_flolpips.py b/opensora/eval/cal_flolpips.py new file mode 100644 index 0000000000000000000000000000000000000000..eaf7ca9e17f059aad99157881de6304cc0a14d44 --- /dev/null +++ b/opensora/eval/cal_flolpips.py @@ -0,0 +1,83 @@ +import numpy as np +import torch +from tqdm import tqdm +import math +from einops import rearrange +import sys +sys.path.append(".") +from opensora.eval.flolpips.pwcnet import Network as PWCNet +from opensora.eval.flolpips.flolpips import FloLPIPS + +loss_fn = FloLPIPS(net='alex', version='0.1').eval().requires_grad_(False) +flownet = PWCNet().eval().requires_grad_(False) + +def trans(x): + return x + + +def calculate_flolpips(videos1, videos2, device): + global loss_fn, flownet + + print("calculate_flowlpips...") + loss_fn = loss_fn.to(device) + flownet = flownet.to(device) + + if videos1.shape != videos2.shape: + print("Warning: the shape of videos are not equal.") + min_frames = min(videos1.shape[1], videos2.shape[1]) + videos1 = videos1[:, :min_frames] + videos2 = videos2[:, :min_frames] + + videos1 = trans(videos1) + videos2 = trans(videos2) + + flolpips_results = [] + for video_num in tqdm(range(videos1.shape[0])): + video1 = videos1[video_num].to(device) + video2 = videos2[video_num].to(device) + frames_rec = video1[:-1] + frames_rec_next = video1[1:] + frames_gt = video2[:-1] + frames_gt_next = video2[1:] + t, c, h, w = frames_gt.shape + flow_gt = flownet(frames_gt, frames_gt_next) + flow_dis = flownet(frames_rec, frames_rec_next) + flow_diff = flow_gt - flow_dis + flolpips = loss_fn.forward(frames_gt, frames_rec, flow_diff, normalize=True) + flolpips_results.append(flolpips.cpu().numpy().tolist()) + + flolpips_results = np.array(flolpips_results) # [batch_size, num_frames] + flolpips = {} + flolpips_std = {} + + for clip_timestamp in range(flolpips_results.shape[1]): + flolpips[clip_timestamp] = np.mean(flolpips_results[:,clip_timestamp], axis=-1) + flolpips_std[clip_timestamp] = np.std(flolpips_results[:,clip_timestamp], axis=-1) + + result = { + "value": flolpips, + "value_std": flolpips_std, + "video_setting": video1.shape, + "video_setting_name": "time, channel, heigth, width", + "result": flolpips_results, + "details": flolpips_results.tolist() + } + + return result + +# test code / using example + +def main(): + NUMBER_OF_VIDEOS = 8 + VIDEO_LENGTH = 50 + CHANNEL = 3 + SIZE = 64 + videos1 = torch.zeros(NUMBER_OF_VIDEOS, VIDEO_LENGTH, CHANNEL, SIZE, SIZE, requires_grad=False) + videos2 = torch.zeros(NUMBER_OF_VIDEOS, VIDEO_LENGTH, CHANNEL, SIZE, SIZE, requires_grad=False) + + import json + result = calculate_flolpips(videos1, videos2, "cuda:0") + print(json.dumps(result, indent=4)) + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/opensora/eval/cal_fvd.py b/opensora/eval/cal_fvd.py new file mode 100644 index 0000000000000000000000000000000000000000..1f1a9806eb55d067a48b222ec718460222205898 --- /dev/null +++ b/opensora/eval/cal_fvd.py @@ -0,0 +1,85 @@ +import numpy as np +import torch +from tqdm import tqdm + +def trans(x): + # if greyscale images add channel + if x.shape[-3] == 1: + x = x.repeat(1, 1, 3, 1, 1) + + # permute BTCHW -> BCTHW + x = x.permute(0, 2, 1, 3, 4) + + return x + +def calculate_fvd(videos1, videos2, device, method='styleganv'): + + if method == 'styleganv': + from fvd.styleganv.fvd import get_fvd_feats, frechet_distance, load_i3d_pretrained + elif method == 'videogpt': + from fvd.videogpt.fvd import load_i3d_pretrained + from fvd.videogpt.fvd import get_fvd_logits as get_fvd_feats + from fvd.videogpt.fvd import frechet_distance + + print("calculate_fvd...") + + # videos [batch_size, timestamps, channel, h, w] + + assert videos1.shape == videos2.shape + + i3d = load_i3d_pretrained(device=device) + fvd_results = [] + + # support grayscale input, if grayscale -> channel*3 + # BTCHW -> BCTHW + # videos -> [batch_size, channel, timestamps, h, w] + + videos1 = trans(videos1) + videos2 = trans(videos2) + + fvd_results = {} + + # for calculate FVD, each clip_timestamp must >= 10 + for clip_timestamp in tqdm(range(10, videos1.shape[-3]+1)): + + # get a video clip + # videos_clip [batch_size, channel, timestamps[:clip], h, w] + videos_clip1 = videos1[:, :, : clip_timestamp] + videos_clip2 = videos2[:, :, : clip_timestamp] + + # get FVD features + feats1 = get_fvd_feats(videos_clip1, i3d=i3d, device=device) + feats2 = get_fvd_feats(videos_clip2, i3d=i3d, device=device) + + # calculate FVD when timestamps[:clip] + fvd_results[clip_timestamp] = frechet_distance(feats1, feats2) + + result = { + "value": fvd_results, + "video_setting": videos1.shape, + "video_setting_name": "batch_size, channel, time, heigth, width", + } + + return result + +# test code / using example + +def main(): + NUMBER_OF_VIDEOS = 8 + VIDEO_LENGTH = 50 + CHANNEL = 3 + SIZE = 64 + videos1 = torch.zeros(NUMBER_OF_VIDEOS, VIDEO_LENGTH, CHANNEL, SIZE, SIZE, requires_grad=False) + videos2 = torch.ones(NUMBER_OF_VIDEOS, VIDEO_LENGTH, CHANNEL, SIZE, SIZE, requires_grad=False) + device = torch.device("cuda") + # device = torch.device("cpu") + + import json + result = calculate_fvd(videos1, videos2, device, method='videogpt') + print(json.dumps(result, indent=4)) + + result = calculate_fvd(videos1, videos2, device, method='styleganv') + print(json.dumps(result, indent=4)) + +if __name__ == "__main__": + main() diff --git a/opensora/eval/cal_lpips.py b/opensora/eval/cal_lpips.py new file mode 100644 index 0000000000000000000000000000000000000000..10edc610b1977b859245d2307f1bebfb0f277216 --- /dev/null +++ b/opensora/eval/cal_lpips.py @@ -0,0 +1,97 @@ +import numpy as np +import torch +from tqdm import tqdm +import math + +import torch +import lpips + +spatial = True # Return a spatial map of perceptual distance. + +# Linearly calibrated models (LPIPS) +loss_fn = lpips.LPIPS(net='alex', spatial=spatial) # Can also set net = 'squeeze' or 'vgg' +# loss_fn = lpips.LPIPS(net='alex', spatial=spatial, lpips=False) # Can also set net = 'squeeze' or 'vgg' + +def trans(x): + # if greyscale images add channel + if x.shape[-3] == 1: + x = x.repeat(1, 1, 3, 1, 1) + + # value range [0, 1] -> [-1, 1] + x = x * 2 - 1 + + return x + +def calculate_lpips(videos1, videos2, device): + # image should be RGB, IMPORTANT: normalized to [-1,1] + print("calculate_lpips...") + + assert videos1.shape == videos2.shape + + # videos [batch_size, timestamps, channel, h, w] + + # support grayscale input, if grayscale -> channel*3 + # value range [0, 1] -> [-1, 1] + videos1 = trans(videos1) + videos2 = trans(videos2) + + lpips_results = [] + + for video_num in tqdm(range(videos1.shape[0])): + # get a video + # video [timestamps, channel, h, w] + video1 = videos1[video_num] + video2 = videos2[video_num] + + lpips_results_of_a_video = [] + for clip_timestamp in range(len(video1)): + # get a img + # img [timestamps[x], channel, h, w] + # img [channel, h, w] tensor + + img1 = video1[clip_timestamp].unsqueeze(0).to(device) + img2 = video2[clip_timestamp].unsqueeze(0).to(device) + + loss_fn.to(device) + + # calculate lpips of a video + lpips_results_of_a_video.append(loss_fn.forward(img1, img2).mean().detach().cpu().tolist()) + lpips_results.append(lpips_results_of_a_video) + + lpips_results = np.array(lpips_results) + + lpips = {} + lpips_std = {} + + for clip_timestamp in range(len(video1)): + lpips[clip_timestamp] = np.mean(lpips_results[:,clip_timestamp]) + lpips_std[clip_timestamp] = np.std(lpips_results[:,clip_timestamp]) + + + result = { + "value": lpips, + "value_std": lpips_std, + "video_setting": video1.shape, + "video_setting_name": "time, channel, heigth, width", + } + + return result + +# test code / using example + +def main(): + NUMBER_OF_VIDEOS = 8 + VIDEO_LENGTH = 50 + CHANNEL = 3 + SIZE = 64 + videos1 = torch.zeros(NUMBER_OF_VIDEOS, VIDEO_LENGTH, CHANNEL, SIZE, SIZE, requires_grad=False) + videos2 = torch.ones(NUMBER_OF_VIDEOS, VIDEO_LENGTH, CHANNEL, SIZE, SIZE, requires_grad=False) + device = torch.device("cuda") + # device = torch.device("cpu") + + import json + result = calculate_lpips(videos1, videos2, device) + print(json.dumps(result, indent=4)) + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/opensora/eval/cal_psnr.py b/opensora/eval/cal_psnr.py new file mode 100644 index 0000000000000000000000000000000000000000..b325106c6cc70b1194f0ba64ae5e8a11ad3a3740 --- /dev/null +++ b/opensora/eval/cal_psnr.py @@ -0,0 +1,84 @@ +import numpy as np +import torch +from tqdm import tqdm +import math + +def img_psnr(img1, img2): + # [0,1] + # compute mse + # mse = np.mean((img1-img2)**2) + mse = np.mean((img1 / 1.0 - img2 / 1.0) ** 2) + # compute psnr + if mse < 1e-10: + return 100 + psnr = 20 * math.log10(1 / math.sqrt(mse)) + return psnr + +def trans(x): + return x + +def calculate_psnr(videos1, videos2): + print("calculate_psnr...") + + # videos [batch_size, timestamps, channel, h, w] + + assert videos1.shape == videos2.shape + + videos1 = trans(videos1) + videos2 = trans(videos2) + + psnr_results = [] + + for video_num in tqdm(range(videos1.shape[0])): + # get a video + # video [timestamps, channel, h, w] + video1 = videos1[video_num] + video2 = videos2[video_num] + + psnr_results_of_a_video = [] + for clip_timestamp in range(len(video1)): + # get a img + # img [timestamps[x], channel, h, w] + # img [channel, h, w] numpy + + img1 = video1[clip_timestamp].numpy() + img2 = video2[clip_timestamp].numpy() + + # calculate psnr of a video + psnr_results_of_a_video.append(img_psnr(img1, img2)) + + psnr_results.append(psnr_results_of_a_video) + + psnr_results = np.array(psnr_results) # [batch_size, num_frames] + psnr = {} + psnr_std = {} + + for clip_timestamp in range(len(video1)): + psnr[clip_timestamp] = np.mean(psnr_results[:,clip_timestamp]) + psnr_std[clip_timestamp] = np.std(psnr_results[:,clip_timestamp]) + + result = { + "value": psnr, + "value_std": psnr_std, + "video_setting": video1.shape, + "video_setting_name": "time, channel, heigth, width", + } + + return result + +# test code / using example + +def main(): + NUMBER_OF_VIDEOS = 8 + VIDEO_LENGTH = 50 + CHANNEL = 3 + SIZE = 64 + videos1 = torch.zeros(NUMBER_OF_VIDEOS, VIDEO_LENGTH, CHANNEL, SIZE, SIZE, requires_grad=False) + videos2 = torch.zeros(NUMBER_OF_VIDEOS, VIDEO_LENGTH, CHANNEL, SIZE, SIZE, requires_grad=False) + + import json + result = calculate_psnr(videos1, videos2) + print(json.dumps(result, indent=4)) + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/opensora/eval/cal_ssim.py b/opensora/eval/cal_ssim.py new file mode 100644 index 0000000000000000000000000000000000000000..b2de811b6c508667177e0c58a1b8a654063fff6b --- /dev/null +++ b/opensora/eval/cal_ssim.py @@ -0,0 +1,113 @@ +import numpy as np +import torch +from tqdm import tqdm +import cv2 + +def ssim(img1, img2): + C1 = 0.01 ** 2 + C2 = 0.03 ** 2 + img1 = img1.astype(np.float64) + img2 = img2.astype(np.float64) + kernel = cv2.getGaussianKernel(11, 1.5) + window = np.outer(kernel, kernel.transpose()) + mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5] # valid + mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5] + mu1_sq = mu1 ** 2 + mu2_sq = mu2 ** 2 + mu1_mu2 = mu1 * mu2 + sigma1_sq = cv2.filter2D(img1 ** 2, -1, window)[5:-5, 5:-5] - mu1_sq + sigma2_sq = cv2.filter2D(img2 ** 2, -1, window)[5:-5, 5:-5] - mu2_sq + sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2 + ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * + (sigma1_sq + sigma2_sq + C2)) + return ssim_map.mean() + + +def calculate_ssim_function(img1, img2): + # [0,1] + # ssim is the only metric extremely sensitive to gray being compared to b/w + if not img1.shape == img2.shape: + raise ValueError('Input images must have the same dimensions.') + if img1.ndim == 2: + return ssim(img1, img2) + elif img1.ndim == 3: + if img1.shape[0] == 3: + ssims = [] + for i in range(3): + ssims.append(ssim(img1[i], img2[i])) + return np.array(ssims).mean() + elif img1.shape[0] == 1: + return ssim(np.squeeze(img1), np.squeeze(img2)) + else: + raise ValueError('Wrong input image dimensions.') + +def trans(x): + return x + +def calculate_ssim(videos1, videos2): + print("calculate_ssim...") + + # videos [batch_size, timestamps, channel, h, w] + + assert videos1.shape == videos2.shape + + videos1 = trans(videos1) + videos2 = trans(videos2) + + ssim_results = [] + + for video_num in tqdm(range(videos1.shape[0])): + # get a video + # video [timestamps, channel, h, w] + video1 = videos1[video_num] + video2 = videos2[video_num] + + ssim_results_of_a_video = [] + for clip_timestamp in range(len(video1)): + # get a img + # img [timestamps[x], channel, h, w] + # img [channel, h, w] numpy + + img1 = video1[clip_timestamp].numpy() + img2 = video2[clip_timestamp].numpy() + + # calculate ssim of a video + ssim_results_of_a_video.append(calculate_ssim_function(img1, img2)) + + ssim_results.append(ssim_results_of_a_video) + + ssim_results = np.array(ssim_results) + + ssim = {} + ssim_std = {} + + for clip_timestamp in range(len(video1)): + ssim[clip_timestamp] = np.mean(ssim_results[:,clip_timestamp]) + ssim_std[clip_timestamp] = np.std(ssim_results[:,clip_timestamp]) + + result = { + "value": ssim, + "value_std": ssim_std, + "video_setting": video1.shape, + "video_setting_name": "time, channel, heigth, width", + } + + return result + +# test code / using example + +def main(): + NUMBER_OF_VIDEOS = 8 + VIDEO_LENGTH = 50 + CHANNEL = 3 + SIZE = 64 + videos1 = torch.zeros(NUMBER_OF_VIDEOS, VIDEO_LENGTH, CHANNEL, SIZE, SIZE, requires_grad=False) + videos2 = torch.zeros(NUMBER_OF_VIDEOS, VIDEO_LENGTH, CHANNEL, SIZE, SIZE, requires_grad=False) + device = torch.device("cuda") + + import json + result = calculate_ssim(videos1, videos2) + print(json.dumps(result, indent=4)) + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/opensora/eval/eval_clip_score.py b/opensora/eval/eval_clip_score.py new file mode 100644 index 0000000000000000000000000000000000000000..ed2f5108233a5bc65d0e945f74edc9599d1d5557 --- /dev/null +++ b/opensora/eval/eval_clip_score.py @@ -0,0 +1,225 @@ +"""Calculates the CLIP Scores + +The CLIP model is a contrasitively learned language-image model. There is +an image encoder and a text encoder. It is believed that the CLIP model could +measure the similarity of cross modalities. Please find more information from +https://github.com/openai/CLIP. + +The CLIP Score measures the Cosine Similarity between two embedded features. +This repository utilizes the pretrained CLIP Model to calculate +the mean average of cosine similarities. + +See --help to see further details. + +Code apapted from https://github.com/mseitzer/pytorch-fid and https://github.com/openai/CLIP. + +Copyright 2023 The Hong Kong Polytechnic University + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" +import os +import os.path as osp +from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser + +import clip +import torch +from PIL import Image +from torch.utils.data import Dataset, DataLoader + +try: + from tqdm import tqdm +except ImportError: + # If tqdm is not available, provide a mock version of it + def tqdm(x): + return x + + +IMAGE_EXTENSIONS = {'bmp', 'jpg', 'jpeg', 'pgm', 'png', 'ppm', + 'tif', 'tiff', 'webp'} + +TEXT_EXTENSIONS = {'txt'} + + +class DummyDataset(Dataset): + + FLAGS = ['img', 'txt'] + def __init__(self, real_path, generated_path, + real_flag: str = 'img', + generated_flag: str = 'img', + transform = None, + tokenizer = None) -> None: + super().__init__() + assert real_flag in self.FLAGS and generated_flag in self.FLAGS, \ + 'CLIP Score only support modality of {}. However, get {} and {}'.format( + self.FLAGS, real_flag, generated_flag + ) + self.real_folder = self._combine_without_prefix(real_path) + self.real_flag = real_flag + self.fake_foler = self._combine_without_prefix(generated_path) + self.generated_flag = generated_flag + self.transform = transform + self.tokenizer = tokenizer + # assert self._check() + + def __len__(self): + return len(self.real_folder) + + def __getitem__(self, index): + if index >= len(self): + raise IndexError + real_path = self.real_folder[index] + generated_path = self.fake_foler[index] + real_data = self._load_modality(real_path, self.real_flag) + fake_data = self._load_modality(generated_path, self.generated_flag) + + sample = dict(real=real_data, fake=fake_data) + return sample + + def _load_modality(self, path, modality): + if modality == 'img': + data = self._load_img(path) + elif modality == 'txt': + data = self._load_txt(path) + else: + raise TypeError("Got unexpected modality: {}".format(modality)) + return data + + def _load_img(self, path): + img = Image.open(path) + if self.transform is not None: + img = self.transform(img) + return img + + def _load_txt(self, path): + with open(path, 'r') as fp: + data = fp.read() + fp.close() + if self.tokenizer is not None: + data = self.tokenizer(data).squeeze() + return data + + def _check(self): + for idx in range(len(self)): + real_name = self.real_folder[idx].split('.') + fake_name = self.fake_folder[idx].split('.') + if fake_name != real_name: + return False + return True + + def _combine_without_prefix(self, folder_path, prefix='.'): + folder = [] + for name in os.listdir(folder_path): + if name[0] == prefix: + continue + folder.append(osp.join(folder_path, name)) + folder.sort() + return folder + + +@torch.no_grad() +def calculate_clip_score(dataloader, model, real_flag, generated_flag): + score_acc = 0. + sample_num = 0. + logit_scale = model.logit_scale.exp() + for batch_data in tqdm(dataloader): + real = batch_data['real'] + real_features = forward_modality(model, real, real_flag) + fake = batch_data['fake'] + fake_features = forward_modality(model, fake, generated_flag) + + # normalize features + real_features = real_features / real_features.norm(dim=1, keepdim=True).to(torch.float32) + fake_features = fake_features / fake_features.norm(dim=1, keepdim=True).to(torch.float32) + + # calculate scores + # score = logit_scale * real_features @ fake_features.t() + # score_acc += torch.diag(score).sum() + score = logit_scale * (fake_features * real_features).sum() + score_acc += score + sample_num += real.shape[0] + + return score_acc / sample_num + + +def forward_modality(model, data, flag): + device = next(model.parameters()).device + if flag == 'img': + features = model.encode_image(data.to(device)) + elif flag == 'txt': + features = model.encode_text(data.to(device)) + else: + raise TypeError + return features + + +def main(): + parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter) + parser.add_argument('--batch-size', type=int, default=50, + help='Batch size to use') + parser.add_argument('--clip-model', type=str, default='ViT-B/32', + help='CLIP model to use') + parser.add_argument('--num-workers', type=int, default=8, + help=('Number of processes to use for data loading. ' + 'Defaults to `min(8, num_cpus)`')) + parser.add_argument('--device', type=str, default=None, + help='Device to use. Like cuda, cuda:0 or cpu') + parser.add_argument('--real_flag', type=str, default='img', + help=('The modality of real path. ' + 'Default to img')) + parser.add_argument('--generated_flag', type=str, default='txt', + help=('The modality of generated path. ' + 'Default to txt')) + parser.add_argument('--real_path', type=str, + help=('Paths to the real images or ' + 'to .npz statistic files')) + parser.add_argument('--generated_path', type=str, + help=('Paths to the generated images or ' + 'to .npz statistic files')) + args = parser.parse_args() + + if args.device is None: + device = torch.device('cuda' if (torch.cuda.is_available()) else 'cpu') + else: + device = torch.device(args.device) + + if args.num_workers is None: + try: + num_cpus = len(os.sched_getaffinity(0)) + except AttributeError: + # os.sched_getaffinity is not available under Windows, use + # os.cpu_count instead (which may not return the *available* number + # of CPUs). + num_cpus = os.cpu_count() + + num_workers = min(num_cpus, 8) if num_cpus is not None else 0 + else: + num_workers = args.num_workers + + print('Loading CLIP model: {}'.format(args.clip_model)) + model, preprocess = clip.load(args.clip_model, device=device) + + dataset = DummyDataset(args.real_path, args.generated_path, + args.real_flag, args.generated_flag, + transform=preprocess, tokenizer=clip.tokenize) + dataloader = DataLoader(dataset, args.batch_size, + num_workers=num_workers, pin_memory=True) + + print('Calculating CLIP Score:') + clip_score = calculate_clip_score(dataloader, model, + args.real_flag, args.generated_flag) + clip_score = clip_score.cpu().item() + print('CLIP Score: ', clip_score) + + +if __name__ == '__main__': + main() diff --git a/opensora/eval/eval_common_metric.py b/opensora/eval/eval_common_metric.py new file mode 100644 index 0000000000000000000000000000000000000000..452c03209615026e0acdbf1765df92b017ea04fe --- /dev/null +++ b/opensora/eval/eval_common_metric.py @@ -0,0 +1,224 @@ +"""Calculates the CLIP Scores + +The CLIP model is a contrasitively learned language-image model. There is +an image encoder and a text encoder. It is believed that the CLIP model could +measure the similarity of cross modalities. Please find more information from +https://github.com/openai/CLIP. + +The CLIP Score measures the Cosine Similarity between two embedded features. +This repository utilizes the pretrained CLIP Model to calculate +the mean average of cosine similarities. + +See --help to see further details. + +Code apapted from https://github.com/mseitzer/pytorch-fid and https://github.com/openai/CLIP. + +Copyright 2023 The Hong Kong Polytechnic University + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +import os +import os.path as osp +from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser +import numpy as np +import torch +from torch.utils.data import Dataset, DataLoader, Subset +from decord import VideoReader, cpu +import random +from pytorchvideo.transforms import ShortSideScale +from torchvision.io import read_video +from torchvision.transforms import Lambda, Compose +from torchvision.transforms._transforms_video import CenterCropVideo +import sys +sys.path.append(".") +from opensora.eval.cal_lpips import calculate_lpips +from opensora.eval.cal_fvd import calculate_fvd +from opensora.eval.cal_psnr import calculate_psnr +from opensora.eval.cal_flolpips import calculate_flolpips +from opensora.eval.cal_ssim import calculate_ssim + +try: + from tqdm import tqdm +except ImportError: + # If tqdm is not available, provide a mock version of it + def tqdm(x): + return x + +class VideoDataset(Dataset): + def __init__(self, + real_video_dir, + generated_video_dir, + num_frames, + sample_rate = 1, + crop_size=None, + resolution=128, + ) -> None: + super().__init__() + self.real_video_files = self._combine_without_prefix(real_video_dir) + self.generated_video_files = self._combine_without_prefix(generated_video_dir) + self.num_frames = num_frames + self.sample_rate = sample_rate + self.crop_size = crop_size + self.short_size = resolution + + + def __len__(self): + return len(self.real_video_files) + + def __getitem__(self, index): + if index >= len(self): + raise IndexError + real_video_file = self.real_video_files[index] + generated_video_file = self.generated_video_files[index] + print(real_video_file, generated_video_file) + real_video_tensor = self._load_video(real_video_file) + generated_video_tensor = self._load_video(generated_video_file) + return {'real': real_video_tensor, 'generated':generated_video_tensor } + + + def _load_video(self, video_path): + num_frames = self.num_frames + sample_rate = self.sample_rate + decord_vr = VideoReader(video_path, ctx=cpu(0)) + total_frames = len(decord_vr) + sample_frames_len = sample_rate * num_frames + + if total_frames >= sample_frames_len: + s = 0 + e = s + sample_frames_len + num_frames = num_frames + else: + s = 0 + e = total_frames + num_frames = int(total_frames / sample_frames_len * num_frames) + print(f'sample_frames_len {sample_frames_len}, only can sample {num_frames * sample_rate}', video_path, + total_frames) + + + frame_id_list = np.linspace(s, e - 1, num_frames, dtype=int) + video_data = decord_vr.get_batch(frame_id_list).asnumpy() + video_data = torch.from_numpy(video_data) + video_data = video_data.permute(0, 3, 1, 2) # (T, H, W, C) -> (C, T, H, W) + return _preprocess(video_data, short_size=self.short_size, crop_size = self.crop_size) + + + def _combine_without_prefix(self, folder_path, prefix='.'): + folder = [] + os.makedirs(folder_path, exist_ok=True) + for name in os.listdir(folder_path): + if name[0] == prefix: + continue + if osp.isfile(osp.join(folder_path, name)): + folder.append(osp.join(folder_path, name)) + folder.sort() + return folder + +def _preprocess(video_data, short_size=128, crop_size=None): + transform = Compose( + [ + Lambda(lambda x: x / 255.0), + ShortSideScale(size=short_size), + CenterCropVideo(crop_size=crop_size), + ] + ) + video_outputs = transform(video_data) + # video_outputs = torch.unsqueeze(video_outputs, 0) # (bz,c,t,h,w) + return video_outputs + + +def calculate_common_metric(args, dataloader, device): + + score_list = [] + for batch_data in tqdm(dataloader): # {'real': real_video_tensor, 'generated':generated_video_tensor } + real_videos = batch_data['real'] + generated_videos = batch_data['generated'] + assert real_videos.shape[2] == generated_videos.shape[2] + if args.metric == 'fvd': + tmp_list = list(calculate_fvd(real_videos, generated_videos, args.device, method=args.fvd_method)['value'].values()) + elif args.metric == 'ssim': + tmp_list = list(calculate_ssim(real_videos, generated_videos)['value'].values()) + elif args.metric == 'psnr': + tmp_list = list(calculate_psnr(real_videos, generated_videos)['value'].values()) + elif args.metric == 'flolpips': + result = calculate_flolpips(real_videos, generated_videos, args.device) + tmp_list = list(result['value'].values()) + else: + tmp_list = list(calculate_lpips(real_videos, generated_videos, args.device)['value'].values()) + score_list += tmp_list + return np.mean(score_list) + +def main(): + parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter) + parser.add_argument('--batch_size', type=int, default=2, + help='Batch size to use') + parser.add_argument('--real_video_dir', type=str, + help=('the path of real videos`')) + parser.add_argument('--generated_video_dir', type=str, + help=('the path of generated videos`')) + parser.add_argument('--device', type=str, default=None, + help='Device to use. Like cuda, cuda:0 or cpu') + parser.add_argument('--num_workers', type=int, default=8, + help=('Number of processes to use for data loading. ' + 'Defaults to `min(8, num_cpus)`')) + parser.add_argument('--sample_fps', type=int, default=30) + parser.add_argument('--resolution', type=int, default=336) + parser.add_argument('--crop_size', type=int, default=None) + parser.add_argument('--num_frames', type=int, default=100) + parser.add_argument('--sample_rate', type=int, default=1) + parser.add_argument('--subset_size', type=int, default=None) + parser.add_argument("--metric", type=str, default="fvd",choices=['fvd','psnr','ssim','lpips', 'flolpips']) + parser.add_argument("--fvd_method", type=str, default='styleganv',choices=['styleganv','videogpt']) + + + args = parser.parse_args() + + if args.device is None: + device = torch.device('cuda' if (torch.cuda.is_available()) else 'cpu') + else: + device = torch.device(args.device) + + if args.num_workers is None: + try: + num_cpus = len(os.sched_getaffinity(0)) + except AttributeError: + # os.sched_getaffinity is not available under Windows, use + # os.cpu_count instead (which may not return the *available* number + # of CPUs). + num_cpus = os.cpu_count() + + num_workers = min(num_cpus, 8) if num_cpus is not None else 0 + else: + num_workers = args.num_workers + + + dataset = VideoDataset(args.real_video_dir, + args.generated_video_dir, + num_frames = args.num_frames, + sample_rate = args.sample_rate, + crop_size=args.crop_size, + resolution=args.resolution) + + if args.subset_size: + indices = range(args.subset_size) + dataset = Subset(dataset, indices=indices) + + dataloader = DataLoader(dataset, args.batch_size, + num_workers=num_workers, pin_memory=True) + + + metric_score = calculate_common_metric(args, dataloader,device) + print('metric: ', args.metric, " ",metric_score) + +if __name__ == '__main__': + main() diff --git a/opensora/eval/flolpips/correlation/correlation.py b/opensora/eval/flolpips/correlation/correlation.py new file mode 100644 index 0000000000000000000000000000000000000000..c2c6f4941fd202d8ce601622723670f52dfeee1a --- /dev/null +++ b/opensora/eval/flolpips/correlation/correlation.py @@ -0,0 +1,397 @@ +#!/usr/bin/env python + +import torch + +import cupy +import re + +kernel_Correlation_rearrange = ''' + extern "C" __global__ void kernel_Correlation_rearrange( + const int n, + const float* input, + float* output + ) { + int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; + + if (intIndex >= n) { + return; + } + + int intSample = blockIdx.z; + int intChannel = blockIdx.y; + + float fltValue = input[(((intSample * SIZE_1(input)) + intChannel) * SIZE_2(input) * SIZE_3(input)) + intIndex]; + + __syncthreads(); + + int intPaddedY = (intIndex / SIZE_3(input)) + 4; + int intPaddedX = (intIndex % SIZE_3(input)) + 4; + int intRearrange = ((SIZE_3(input) + 8) * intPaddedY) + intPaddedX; + + output[(((intSample * SIZE_1(output) * SIZE_2(output)) + intRearrange) * SIZE_1(input)) + intChannel] = fltValue; + } +''' + +kernel_Correlation_updateOutput = ''' + extern "C" __global__ void kernel_Correlation_updateOutput( + const int n, + const float* rbot0, + const float* rbot1, + float* top + ) { + extern __shared__ char patch_data_char[]; + + float *patch_data = (float *)patch_data_char; + + // First (upper left) position of kernel upper-left corner in current center position of neighborhood in image 1 + int x1 = blockIdx.x + 4; + int y1 = blockIdx.y + 4; + int item = blockIdx.z; + int ch_off = threadIdx.x; + + // Load 3D patch into shared shared memory + for (int j = 0; j < 1; j++) { // HEIGHT + for (int i = 0; i < 1; i++) { // WIDTH + int ji_off = (j + i) * SIZE_3(rbot0); + for (int ch = ch_off; ch < SIZE_3(rbot0); ch += 32) { // CHANNELS + int idx1 = ((item * SIZE_1(rbot0) + y1+j) * SIZE_2(rbot0) + x1+i) * SIZE_3(rbot0) + ch; + int idxPatchData = ji_off + ch; + patch_data[idxPatchData] = rbot0[idx1]; + } + } + } + + __syncthreads(); + + __shared__ float sum[32]; + + // Compute correlation + for (int top_channel = 0; top_channel < SIZE_1(top); top_channel++) { + sum[ch_off] = 0; + + int s2o = top_channel % 9 - 4; + int s2p = top_channel / 9 - 4; + + for (int j = 0; j < 1; j++) { // HEIGHT + for (int i = 0; i < 1; i++) { // WIDTH + int ji_off = (j + i) * SIZE_3(rbot0); + for (int ch = ch_off; ch < SIZE_3(rbot0); ch += 32) { // CHANNELS + int x2 = x1 + s2o; + int y2 = y1 + s2p; + + int idxPatchData = ji_off + ch; + int idx2 = ((item * SIZE_1(rbot0) + y2+j) * SIZE_2(rbot0) + x2+i) * SIZE_3(rbot0) + ch; + + sum[ch_off] += patch_data[idxPatchData] * rbot1[idx2]; + } + } + } + + __syncthreads(); + + if (ch_off == 0) { + float total_sum = 0; + for (int idx = 0; idx < 32; idx++) { + total_sum += sum[idx]; + } + const int sumelems = SIZE_3(rbot0); + const int index = ((top_channel*SIZE_2(top) + blockIdx.y)*SIZE_3(top))+blockIdx.x; + top[index + item*SIZE_1(top)*SIZE_2(top)*SIZE_3(top)] = total_sum / (float)sumelems; + } + } + } +''' + +kernel_Correlation_updateGradFirst = ''' + #define ROUND_OFF 50000 + + extern "C" __global__ void kernel_Correlation_updateGradFirst( + const int n, + const int intSample, + const float* rbot0, + const float* rbot1, + const float* gradOutput, + float* gradFirst, + float* gradSecond + ) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) { + int n = intIndex % SIZE_1(gradFirst); // channels + int l = (intIndex / SIZE_1(gradFirst)) % SIZE_3(gradFirst) + 4; // w-pos + int m = (intIndex / SIZE_1(gradFirst) / SIZE_3(gradFirst)) % SIZE_2(gradFirst) + 4; // h-pos + + // round_off is a trick to enable integer division with ceil, even for negative numbers + // We use a large offset, for the inner part not to become negative. + const int round_off = ROUND_OFF; + const int round_off_s1 = round_off; + + // We add round_off before_s1 the int division and subtract round_off after it, to ensure the formula matches ceil behavior: + int xmin = (l - 4 + round_off_s1 - 1) + 1 - round_off; // ceil (l - 4) + int ymin = (m - 4 + round_off_s1 - 1) + 1 - round_off; // ceil (l - 4) + + // Same here: + int xmax = (l - 4 + round_off_s1) - round_off; // floor (l - 4) + int ymax = (m - 4 + round_off_s1) - round_off; // floor (m - 4) + + float sum = 0; + if (xmax>=0 && ymax>=0 && (xmin<=SIZE_3(gradOutput)-1) && (ymin<=SIZE_2(gradOutput)-1)) { + xmin = max(0,xmin); + xmax = min(SIZE_3(gradOutput)-1,xmax); + + ymin = max(0,ymin); + ymax = min(SIZE_2(gradOutput)-1,ymax); + + for (int p = -4; p <= 4; p++) { + for (int o = -4; o <= 4; o++) { + // Get rbot1 data: + int s2o = o; + int s2p = p; + int idxbot1 = ((intSample * SIZE_1(rbot0) + (m+s2p)) * SIZE_2(rbot0) + (l+s2o)) * SIZE_3(rbot0) + n; + float bot1tmp = rbot1[idxbot1]; // rbot1[l+s2o,m+s2p,n] + + // Index offset for gradOutput in following loops: + int op = (p+4) * 9 + (o+4); // index[o,p] + int idxopoffset = (intSample * SIZE_1(gradOutput) + op); + + for (int y = ymin; y <= ymax; y++) { + for (int x = xmin; x <= xmax; x++) { + int idxgradOutput = (idxopoffset * SIZE_2(gradOutput) + y) * SIZE_3(gradOutput) + x; // gradOutput[x,y,o,p] + sum += gradOutput[idxgradOutput] * bot1tmp; + } + } + } + } + } + const int sumelems = SIZE_1(gradFirst); + const int bot0index = ((n * SIZE_2(gradFirst)) + (m-4)) * SIZE_3(gradFirst) + (l-4); + gradFirst[bot0index + intSample*SIZE_1(gradFirst)*SIZE_2(gradFirst)*SIZE_3(gradFirst)] = sum / (float)sumelems; + } } +''' + +kernel_Correlation_updateGradSecond = ''' + #define ROUND_OFF 50000 + + extern "C" __global__ void kernel_Correlation_updateGradSecond( + const int n, + const int intSample, + const float* rbot0, + const float* rbot1, + const float* gradOutput, + float* gradFirst, + float* gradSecond + ) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) { + int n = intIndex % SIZE_1(gradSecond); // channels + int l = (intIndex / SIZE_1(gradSecond)) % SIZE_3(gradSecond) + 4; // w-pos + int m = (intIndex / SIZE_1(gradSecond) / SIZE_3(gradSecond)) % SIZE_2(gradSecond) + 4; // h-pos + + // round_off is a trick to enable integer division with ceil, even for negative numbers + // We use a large offset, for the inner part not to become negative. + const int round_off = ROUND_OFF; + const int round_off_s1 = round_off; + + float sum = 0; + for (int p = -4; p <= 4; p++) { + for (int o = -4; o <= 4; o++) { + int s2o = o; + int s2p = p; + + //Get X,Y ranges and clamp + // We add round_off before_s1 the int division and subtract round_off after it, to ensure the formula matches ceil behavior: + int xmin = (l - 4 - s2o + round_off_s1 - 1) + 1 - round_off; // ceil (l - 4 - s2o) + int ymin = (m - 4 - s2p + round_off_s1 - 1) + 1 - round_off; // ceil (l - 4 - s2o) + + // Same here: + int xmax = (l - 4 - s2o + round_off_s1) - round_off; // floor (l - 4 - s2o) + int ymax = (m - 4 - s2p + round_off_s1) - round_off; // floor (m - 4 - s2p) + + if (xmax>=0 && ymax>=0 && (xmin<=SIZE_3(gradOutput)-1) && (ymin<=SIZE_2(gradOutput)-1)) { + xmin = max(0,xmin); + xmax = min(SIZE_3(gradOutput)-1,xmax); + + ymin = max(0,ymin); + ymax = min(SIZE_2(gradOutput)-1,ymax); + + // Get rbot0 data: + int idxbot0 = ((intSample * SIZE_1(rbot0) + (m-s2p)) * SIZE_2(rbot0) + (l-s2o)) * SIZE_3(rbot0) + n; + float bot0tmp = rbot0[idxbot0]; // rbot1[l+s2o,m+s2p,n] + + // Index offset for gradOutput in following loops: + int op = (p+4) * 9 + (o+4); // index[o,p] + int idxopoffset = (intSample * SIZE_1(gradOutput) + op); + + for (int y = ymin; y <= ymax; y++) { + for (int x = xmin; x <= xmax; x++) { + int idxgradOutput = (idxopoffset * SIZE_2(gradOutput) + y) * SIZE_3(gradOutput) + x; // gradOutput[x,y,o,p] + sum += gradOutput[idxgradOutput] * bot0tmp; + } + } + } + } + } + const int sumelems = SIZE_1(gradSecond); + const int bot1index = ((n * SIZE_2(gradSecond)) + (m-4)) * SIZE_3(gradSecond) + (l-4); + gradSecond[bot1index + intSample*SIZE_1(gradSecond)*SIZE_2(gradSecond)*SIZE_3(gradSecond)] = sum / (float)sumelems; + } } +''' + +def cupy_kernel(strFunction, objVariables): + strKernel = globals()[strFunction] + + while True: + objMatch = re.search('(SIZE_)([0-4])(\()([^\)]*)(\))', strKernel) + + if objMatch is None: + break + # end + + intArg = int(objMatch.group(2)) + + strTensor = objMatch.group(4) + intSizes = objVariables[strTensor].size() + + strKernel = strKernel.replace(objMatch.group(), str(intSizes[intArg])) + # end + + while True: + objMatch = re.search('(VALUE_)([0-4])(\()([^\)]+)(\))', strKernel) + + if objMatch is None: + break + # end + + intArgs = int(objMatch.group(2)) + strArgs = objMatch.group(4).split(',') + + strTensor = strArgs[0] + intStrides = objVariables[strTensor].stride() + strIndex = [ '((' + strArgs[intArg + 1].replace('{', '(').replace('}', ')').strip() + ')*' + str(intStrides[intArg]) + ')' for intArg in range(intArgs) ] + + strKernel = strKernel.replace(objMatch.group(0), strTensor + '[' + str.join('+', strIndex) + ']') + # end + + return strKernel +# end + +@cupy.memoize(for_each_device=True) +def cupy_launch(strFunction, strKernel): + return cupy.RawKernel(strKernel, strFunction) +# end + +class _FunctionCorrelation(torch.autograd.Function): + @staticmethod + def forward(self, first, second): + rbot0 = first.new_zeros([ first.shape[0], first.shape[2] + 8, first.shape[3] + 8, first.shape[1] ]) + rbot1 = first.new_zeros([ first.shape[0], first.shape[2] + 8, first.shape[3] + 8, first.shape[1] ]) + + self.save_for_backward(first, second, rbot0, rbot1) + + first = first.contiguous(); assert(first.is_cuda == True) + second = second.contiguous(); assert(second.is_cuda == True) + + output = first.new_zeros([ first.shape[0], 81, first.shape[2], first.shape[3] ]) + + if first.is_cuda == True: + n = first.shape[2] * first.shape[3] + cupy_launch('kernel_Correlation_rearrange', cupy_kernel('kernel_Correlation_rearrange', { + 'input': first, + 'output': rbot0 + }))( + grid=tuple([ int((n + 16 - 1) / 16), first.shape[1], first.shape[0] ]), + block=tuple([ 16, 1, 1 ]), + args=[ n, first.data_ptr(), rbot0.data_ptr() ] + ) + + n = second.shape[2] * second.shape[3] + cupy_launch('kernel_Correlation_rearrange', cupy_kernel('kernel_Correlation_rearrange', { + 'input': second, + 'output': rbot1 + }))( + grid=tuple([ int((n + 16 - 1) / 16), second.shape[1], second.shape[0] ]), + block=tuple([ 16, 1, 1 ]), + args=[ n, second.data_ptr(), rbot1.data_ptr() ] + ) + + n = output.shape[1] * output.shape[2] * output.shape[3] + cupy_launch('kernel_Correlation_updateOutput', cupy_kernel('kernel_Correlation_updateOutput', { + 'rbot0': rbot0, + 'rbot1': rbot1, + 'top': output + }))( + grid=tuple([ output.shape[3], output.shape[2], output.shape[0] ]), + block=tuple([ 32, 1, 1 ]), + shared_mem=first.shape[1] * 4, + args=[ n, rbot0.data_ptr(), rbot1.data_ptr(), output.data_ptr() ] + ) + + elif first.is_cuda == False: + raise NotImplementedError() + + # end + + return output + # end + + @staticmethod + def backward(self, gradOutput): + first, second, rbot0, rbot1 = self.saved_tensors + + gradOutput = gradOutput.contiguous(); assert(gradOutput.is_cuda == True) + + gradFirst = first.new_zeros([ first.shape[0], first.shape[1], first.shape[2], first.shape[3] ]) if self.needs_input_grad[0] == True else None + gradSecond = first.new_zeros([ first.shape[0], first.shape[1], first.shape[2], first.shape[3] ]) if self.needs_input_grad[1] == True else None + + if first.is_cuda == True: + if gradFirst is not None: + for intSample in range(first.shape[0]): + n = first.shape[1] * first.shape[2] * first.shape[3] + cupy_launch('kernel_Correlation_updateGradFirst', cupy_kernel('kernel_Correlation_updateGradFirst', { + 'rbot0': rbot0, + 'rbot1': rbot1, + 'gradOutput': gradOutput, + 'gradFirst': gradFirst, + 'gradSecond': None + }))( + grid=tuple([ int((n + 512 - 1) / 512), 1, 1 ]), + block=tuple([ 512, 1, 1 ]), + args=[ n, intSample, rbot0.data_ptr(), rbot1.data_ptr(), gradOutput.data_ptr(), gradFirst.data_ptr(), None ] + ) + # end + # end + + if gradSecond is not None: + for intSample in range(first.shape[0]): + n = first.shape[1] * first.shape[2] * first.shape[3] + cupy_launch('kernel_Correlation_updateGradSecond', cupy_kernel('kernel_Correlation_updateGradSecond', { + 'rbot0': rbot0, + 'rbot1': rbot1, + 'gradOutput': gradOutput, + 'gradFirst': None, + 'gradSecond': gradSecond + }))( + grid=tuple([ int((n + 512 - 1) / 512), 1, 1 ]), + block=tuple([ 512, 1, 1 ]), + args=[ n, intSample, rbot0.data_ptr(), rbot1.data_ptr(), gradOutput.data_ptr(), None, gradSecond.data_ptr() ] + ) + # end + # end + + elif first.is_cuda == False: + raise NotImplementedError() + + # end + + return gradFirst, gradSecond + # end +# end + +def FunctionCorrelation(tenFirst, tenSecond): + return _FunctionCorrelation.apply(tenFirst, tenSecond) +# end + +class ModuleCorrelation(torch.nn.Module): + def __init__(self): + super(ModuleCorrelation, self).__init__() + # end + + def forward(self, tenFirst, tenSecond): + return _FunctionCorrelation.apply(tenFirst, tenSecond) + # end +# end \ No newline at end of file diff --git a/opensora/eval/flolpips/flolpips.py b/opensora/eval/flolpips/flolpips.py new file mode 100644 index 0000000000000000000000000000000000000000..f237cab25bbb1e7c733d523256de88222c418511 --- /dev/null +++ b/opensora/eval/flolpips/flolpips.py @@ -0,0 +1,308 @@ + +from __future__ import absolute_import +import os +import numpy as np +import torch +import torch.nn as nn +from torch.autograd import Variable +from .pretrained_networks import vgg16, alexnet, squeezenet +import torch.nn +import torch.nn.functional as F +import torchvision.transforms.functional as TF +import cv2 + +from .pwcnet import Network as PWCNet +from .utils import * + +def spatial_average(in_tens, keepdim=True): + return in_tens.mean([2,3],keepdim=keepdim) + +def mw_spatial_average(in_tens, flow, keepdim=True): + _,_,h,w = in_tens.shape + flow = F.interpolate(flow, (h,w), align_corners=False, mode='bilinear') + flow_mag = torch.sqrt(flow[:,0:1]**2 + flow[:,1:2]**2) + flow_mag = flow_mag / torch.sum(flow_mag, dim=[1,2,3], keepdim=True) + return torch.sum(in_tens*flow_mag, dim=[2,3],keepdim=keepdim) + + +def mtw_spatial_average(in_tens, flow, texture, keepdim=True): + _,_,h,w = in_tens.shape + flow = F.interpolate(flow, (h,w), align_corners=False, mode='bilinear') + texture = F.interpolate(texture, (h,w), align_corners=False, mode='bilinear') + flow_mag = torch.sqrt(flow[:,0:1]**2 + flow[:,1:2]**2) + flow_mag = (flow_mag - flow_mag.min()) / (flow_mag.max() - flow_mag.min()) + 1e-6 + texture = (texture - texture.min()) / (texture.max() - texture.min()) + 1e-6 + weight = flow_mag / texture + weight /= torch.sum(weight) + return torch.sum(in_tens*weight, dim=[2,3],keepdim=keepdim) + + + +def m2w_spatial_average(in_tens, flow, keepdim=True): + _,_,h,w = in_tens.shape + flow = F.interpolate(flow, (h,w), align_corners=False, mode='bilinear') + flow_mag = flow[:,0:1]**2 + flow[:,1:2]**2 # B,1,H,W + flow_mag = flow_mag / torch.sum(flow_mag) + return torch.sum(in_tens*flow_mag, dim=[2,3],keepdim=keepdim) + +def upsample(in_tens, out_HW=(64,64)): # assumes scale factor is same for H and W + in_H, in_W = in_tens.shape[2], in_tens.shape[3] + return nn.Upsample(size=out_HW, mode='bilinear', align_corners=False)(in_tens) + +# Learned perceptual metric +class LPIPS(nn.Module): + def __init__(self, pretrained=True, net='alex', version='0.1', lpips=True, spatial=False, + pnet_rand=False, pnet_tune=False, use_dropout=True, model_path=None, eval_mode=True, verbose=False): + # lpips - [True] means with linear calibration on top of base network + # pretrained - [True] means load linear weights + + super(LPIPS, self).__init__() + if(verbose): + print('Setting up [%s] perceptual loss: trunk [%s], v[%s], spatial [%s]'% + ('LPIPS' if lpips else 'baseline', net, version, 'on' if spatial else 'off')) + + self.pnet_type = net + self.pnet_tune = pnet_tune + self.pnet_rand = pnet_rand + self.spatial = spatial + self.lpips = lpips # false means baseline of just averaging all layers + self.version = version + self.scaling_layer = ScalingLayer() + + if(self.pnet_type in ['vgg','vgg16']): + net_type = vgg16 + self.chns = [64,128,256,512,512] + elif(self.pnet_type=='alex'): + net_type = alexnet + self.chns = [64,192,384,256,256] + elif(self.pnet_type=='squeeze'): + net_type = squeezenet + self.chns = [64,128,256,384,384,512,512] + self.L = len(self.chns) + + self.net = net_type(pretrained=not self.pnet_rand, requires_grad=self.pnet_tune) + + if(lpips): + self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout) + self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout) + self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout) + self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout) + self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout) + self.lins = [self.lin0,self.lin1,self.lin2,self.lin3,self.lin4] + if(self.pnet_type=='squeeze'): # 7 layers for squeezenet + self.lin5 = NetLinLayer(self.chns[5], use_dropout=use_dropout) + self.lin6 = NetLinLayer(self.chns[6], use_dropout=use_dropout) + self.lins+=[self.lin5,self.lin6] + self.lins = nn.ModuleList(self.lins) + + if(pretrained): + if(model_path is None): + import inspect + import os + model_path = os.path.abspath(os.path.join(inspect.getfile(self.__init__), '..', 'weights/v%s/%s.pth'%(version,net))) + + if(verbose): + print('Loading model from: %s'%model_path) + self.load_state_dict(torch.load(model_path, map_location='cpu'), strict=False) + + if(eval_mode): + self.eval() + + def forward(self, in0, in1, retPerLayer=False, normalize=False): + if normalize: # turn on this flag if input is [0,1] so it can be adjusted to [-1, +1] + in0 = 2 * in0 - 1 + in1 = 2 * in1 - 1 + + # v0.0 - original release had a bug, where input was not scaled + in0_input, in1_input = (self.scaling_layer(in0), self.scaling_layer(in1)) if self.version=='0.1' else (in0, in1) + outs0, outs1 = self.net.forward(in0_input), self.net.forward(in1_input) + feats0, feats1, diffs = {}, {}, {} + + for kk in range(self.L): + feats0[kk], feats1[kk] = normalize_tensor(outs0[kk]), normalize_tensor(outs1[kk]) + diffs[kk] = (feats0[kk]-feats1[kk])**2 + + if(self.lpips): + if(self.spatial): + res = [upsample(self.lins[kk](diffs[kk]), out_HW=in0.shape[2:]) for kk in range(self.L)] + else: + res = [spatial_average(self.lins[kk](diffs[kk]), keepdim=True) for kk in range(self.L)] + else: + if(self.spatial): + res = [upsample(diffs[kk].sum(dim=1,keepdim=True), out_HW=in0.shape[2:]) for kk in range(self.L)] + else: + res = [spatial_average(diffs[kk].sum(dim=1,keepdim=True), keepdim=True) for kk in range(self.L)] + + # val = res[0] + # for l in range(1,self.L): + # val += res[l] + # print(val) + + # a = spatial_average(self.lins[kk](diffs[kk]), keepdim=True) + # b = torch.max(self.lins[kk](feats0[kk]**2)) + # for kk in range(self.L): + # a += spatial_average(self.lins[kk](diffs[kk]), keepdim=True) + # b = torch.max(b,torch.max(self.lins[kk](feats0[kk]**2))) + # a = a/self.L + # from IPython import embed + # embed() + # return 10*torch.log10(b/a) + + # if(retPerLayer): + # return (val, res) + # else: + return torch.sum(torch.cat(res, 1), dim=(1,2,3), keepdims=False) + + +class ScalingLayer(nn.Module): + def __init__(self): + super(ScalingLayer, self).__init__() + self.register_buffer('shift', torch.Tensor([-.030,-.088,-.188])[None,:,None,None]) + self.register_buffer('scale', torch.Tensor([.458,.448,.450])[None,:,None,None]) + + def forward(self, inp): + return (inp - self.shift) / self.scale + + +class NetLinLayer(nn.Module): + ''' A single linear layer which does a 1x1 conv ''' + def __init__(self, chn_in, chn_out=1, use_dropout=False): + super(NetLinLayer, self).__init__() + + layers = [nn.Dropout(),] if(use_dropout) else [] + layers += [nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False),] + self.model = nn.Sequential(*layers) + + def forward(self, x): + return self.model(x) + +class Dist2LogitLayer(nn.Module): + ''' takes 2 distances, puts through fc layers, spits out value between [0,1] (if use_sigmoid is True) ''' + def __init__(self, chn_mid=32, use_sigmoid=True): + super(Dist2LogitLayer, self).__init__() + + layers = [nn.Conv2d(5, chn_mid, 1, stride=1, padding=0, bias=True),] + layers += [nn.LeakyReLU(0.2,True),] + layers += [nn.Conv2d(chn_mid, chn_mid, 1, stride=1, padding=0, bias=True),] + layers += [nn.LeakyReLU(0.2,True),] + layers += [nn.Conv2d(chn_mid, 1, 1, stride=1, padding=0, bias=True),] + if(use_sigmoid): + layers += [nn.Sigmoid(),] + self.model = nn.Sequential(*layers) + + def forward(self,d0,d1,eps=0.1): + return self.model.forward(torch.cat((d0,d1,d0-d1,d0/(d1+eps),d1/(d0+eps)),dim=1)) + +class BCERankingLoss(nn.Module): + def __init__(self, chn_mid=32): + super(BCERankingLoss, self).__init__() + self.net = Dist2LogitLayer(chn_mid=chn_mid) + # self.parameters = list(self.net.parameters()) + self.loss = torch.nn.BCELoss() + + def forward(self, d0, d1, judge): + per = (judge+1.)/2. + self.logit = self.net.forward(d0,d1) + return self.loss(self.logit, per) + +# L2, DSSIM metrics +class FakeNet(nn.Module): + def __init__(self, use_gpu=True, colorspace='Lab'): + super(FakeNet, self).__init__() + self.use_gpu = use_gpu + self.colorspace = colorspace + +class L2(FakeNet): + def forward(self, in0, in1, retPerLayer=None): + assert(in0.size()[0]==1) # currently only supports batchSize 1 + + if(self.colorspace=='RGB'): + (N,C,X,Y) = in0.size() + value = torch.mean(torch.mean(torch.mean((in0-in1)**2,dim=1).view(N,1,X,Y),dim=2).view(N,1,1,Y),dim=3).view(N) + return value + elif(self.colorspace=='Lab'): + value = l2(tensor2np(tensor2tensorlab(in0.data,to_norm=False)), + tensor2np(tensor2tensorlab(in1.data,to_norm=False)), range=100.).astype('float') + ret_var = Variable( torch.Tensor((value,) ) ) + if(self.use_gpu): + ret_var = ret_var.cuda() + return ret_var + +class DSSIM(FakeNet): + + def forward(self, in0, in1, retPerLayer=None): + assert(in0.size()[0]==1) # currently only supports batchSize 1 + + if(self.colorspace=='RGB'): + value = dssim(1.*tensor2im(in0.data), 1.*tensor2im(in1.data), range=255.).astype('float') + elif(self.colorspace=='Lab'): + value = dssim(tensor2np(tensor2tensorlab(in0.data,to_norm=False)), + tensor2np(tensor2tensorlab(in1.data,to_norm=False)), range=100.).astype('float') + ret_var = Variable( torch.Tensor((value,) ) ) + if(self.use_gpu): + ret_var = ret_var.cuda() + return ret_var + +def print_network(net): + num_params = 0 + for param in net.parameters(): + num_params += param.numel() + print('Network',net) + print('Total number of parameters: %d' % num_params) + + +class FloLPIPS(LPIPS): + def __init__(self, pretrained=True, net='alex', version='0.1', lpips=True, spatial=False, pnet_rand=False, pnet_tune=False, use_dropout=True, model_path=None, eval_mode=True, verbose=False): + super(FloLPIPS, self).__init__(pretrained, net, version, lpips, spatial, pnet_rand, pnet_tune, use_dropout, model_path, eval_mode, verbose) + + def forward(self, in0, in1, flow, retPerLayer=False, normalize=False): + if normalize: # turn on this flag if input is [0,1] so it can be adjusted to [-1, +1] + in0 = 2 * in0 - 1 + in1 = 2 * in1 - 1 + + in0_input, in1_input = (self.scaling_layer(in0), self.scaling_layer(in1)) if self.version=='0.1' else (in0, in1) + outs0, outs1 = self.net.forward(in0_input), self.net.forward(in1_input) + feats0, feats1, diffs = {}, {}, {} + + for kk in range(self.L): + feats0[kk], feats1[kk] = normalize_tensor(outs0[kk]), normalize_tensor(outs1[kk]) + diffs[kk] = (feats0[kk]-feats1[kk])**2 + + res = [mw_spatial_average(self.lins[kk](diffs[kk]), flow, keepdim=True) for kk in range(self.L)] + + return torch.sum(torch.cat(res, 1), dim=(1,2,3), keepdims=False) + + + + + +class Flolpips(nn.Module): + def __init__(self): + super(Flolpips, self).__init__() + self.loss_fn = FloLPIPS(net='alex',version='0.1') + self.flownet = PWCNet() + + @torch.no_grad() + def forward(self, I0, I1, frame_dis, frame_ref): + """ + args: + I0: first frame of the triplet, shape: [B, C, H, W] + I1: third frame of the triplet, shape: [B, C, H, W] + frame_dis: prediction of the intermediate frame, shape: [B, C, H, W] + frame_ref: ground-truth of the intermediate frame, shape: [B, C, H, W] + """ + assert I0.size() == I1.size() == frame_dis.size() == frame_ref.size(), \ + "the 4 input tensors should have same size" + + flow_ref = self.flownet(frame_ref, I0) + flow_dis = self.flownet(frame_dis, I0) + flow_diff = flow_ref - flow_dis + flolpips_wrt_I0 = self.loss_fn.forward(frame_ref, frame_dis, flow_diff, normalize=True) + + flow_ref = self.flownet(frame_ref, I1) + flow_dis = self.flownet(frame_dis, I1) + flow_diff = flow_ref - flow_dis + flolpips_wrt_I1 = self.loss_fn.forward(frame_ref, frame_dis, flow_diff, normalize=True) + + flolpips = (flolpips_wrt_I0 + flolpips_wrt_I1) / 2 + return flolpips \ No newline at end of file diff --git a/opensora/eval/flolpips/pretrained_networks.py b/opensora/eval/flolpips/pretrained_networks.py new file mode 100644 index 0000000000000000000000000000000000000000..a70ebbeab1618da4fe2538833f049dc569f1eea1 --- /dev/null +++ b/opensora/eval/flolpips/pretrained_networks.py @@ -0,0 +1,180 @@ +from collections import namedtuple +import torch +from torchvision import models as tv + +class squeezenet(torch.nn.Module): + def __init__(self, requires_grad=False, pretrained=True): + super(squeezenet, self).__init__() + pretrained_features = tv.squeezenet1_1(pretrained=pretrained).features + self.slice1 = torch.nn.Sequential() + self.slice2 = torch.nn.Sequential() + self.slice3 = torch.nn.Sequential() + self.slice4 = torch.nn.Sequential() + self.slice5 = torch.nn.Sequential() + self.slice6 = torch.nn.Sequential() + self.slice7 = torch.nn.Sequential() + self.N_slices = 7 + for x in range(2): + self.slice1.add_module(str(x), pretrained_features[x]) + for x in range(2,5): + self.slice2.add_module(str(x), pretrained_features[x]) + for x in range(5, 8): + self.slice3.add_module(str(x), pretrained_features[x]) + for x in range(8, 10): + self.slice4.add_module(str(x), pretrained_features[x]) + for x in range(10, 11): + self.slice5.add_module(str(x), pretrained_features[x]) + for x in range(11, 12): + self.slice6.add_module(str(x), pretrained_features[x]) + for x in range(12, 13): + self.slice7.add_module(str(x), pretrained_features[x]) + if not requires_grad: + for param in self.parameters(): + param.requires_grad = False + + def forward(self, X): + h = self.slice1(X) + h_relu1 = h + h = self.slice2(h) + h_relu2 = h + h = self.slice3(h) + h_relu3 = h + h = self.slice4(h) + h_relu4 = h + h = self.slice5(h) + h_relu5 = h + h = self.slice6(h) + h_relu6 = h + h = self.slice7(h) + h_relu7 = h + vgg_outputs = namedtuple("SqueezeOutputs", ['relu1','relu2','relu3','relu4','relu5','relu6','relu7']) + out = vgg_outputs(h_relu1,h_relu2,h_relu3,h_relu4,h_relu5,h_relu6,h_relu7) + + return out + + +class alexnet(torch.nn.Module): + def __init__(self, requires_grad=False, pretrained=True): + super(alexnet, self).__init__() + alexnet_pretrained_features = tv.alexnet(pretrained=pretrained).features + self.slice1 = torch.nn.Sequential() + self.slice2 = torch.nn.Sequential() + self.slice3 = torch.nn.Sequential() + self.slice4 = torch.nn.Sequential() + self.slice5 = torch.nn.Sequential() + self.N_slices = 5 + for x in range(2): + self.slice1.add_module(str(x), alexnet_pretrained_features[x]) + for x in range(2, 5): + self.slice2.add_module(str(x), alexnet_pretrained_features[x]) + for x in range(5, 8): + self.slice3.add_module(str(x), alexnet_pretrained_features[x]) + for x in range(8, 10): + self.slice4.add_module(str(x), alexnet_pretrained_features[x]) + for x in range(10, 12): + self.slice5.add_module(str(x), alexnet_pretrained_features[x]) + if not requires_grad: + for param in self.parameters(): + param.requires_grad = False + + def forward(self, X): + h = self.slice1(X) + h_relu1 = h + h = self.slice2(h) + h_relu2 = h + h = self.slice3(h) + h_relu3 = h + h = self.slice4(h) + h_relu4 = h + h = self.slice5(h) + h_relu5 = h + alexnet_outputs = namedtuple("AlexnetOutputs", ['relu1', 'relu2', 'relu3', 'relu4', 'relu5']) + out = alexnet_outputs(h_relu1, h_relu2, h_relu3, h_relu4, h_relu5) + + return out + +class vgg16(torch.nn.Module): + def __init__(self, requires_grad=False, pretrained=True): + super(vgg16, self).__init__() + vgg_pretrained_features = tv.vgg16(pretrained=pretrained).features + self.slice1 = torch.nn.Sequential() + self.slice2 = torch.nn.Sequential() + self.slice3 = torch.nn.Sequential() + self.slice4 = torch.nn.Sequential() + self.slice5 = torch.nn.Sequential() + self.N_slices = 5 + for x in range(4): + self.slice1.add_module(str(x), vgg_pretrained_features[x]) + for x in range(4, 9): + self.slice2.add_module(str(x), vgg_pretrained_features[x]) + for x in range(9, 16): + self.slice3.add_module(str(x), vgg_pretrained_features[x]) + for x in range(16, 23): + self.slice4.add_module(str(x), vgg_pretrained_features[x]) + for x in range(23, 30): + self.slice5.add_module(str(x), vgg_pretrained_features[x]) + if not requires_grad: + for param in self.parameters(): + param.requires_grad = False + + def forward(self, X): + h = self.slice1(X) + h_relu1_2 = h + h = self.slice2(h) + h_relu2_2 = h + h = self.slice3(h) + h_relu3_3 = h + h = self.slice4(h) + h_relu4_3 = h + h = self.slice5(h) + h_relu5_3 = h + vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3']) + out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3) + + return out + + + +class resnet(torch.nn.Module): + def __init__(self, requires_grad=False, pretrained=True, num=18): + super(resnet, self).__init__() + if(num==18): + self.net = tv.resnet18(pretrained=pretrained) + elif(num==34): + self.net = tv.resnet34(pretrained=pretrained) + elif(num==50): + self.net = tv.resnet50(pretrained=pretrained) + elif(num==101): + self.net = tv.resnet101(pretrained=pretrained) + elif(num==152): + self.net = tv.resnet152(pretrained=pretrained) + self.N_slices = 5 + + self.conv1 = self.net.conv1 + self.bn1 = self.net.bn1 + self.relu = self.net.relu + self.maxpool = self.net.maxpool + self.layer1 = self.net.layer1 + self.layer2 = self.net.layer2 + self.layer3 = self.net.layer3 + self.layer4 = self.net.layer4 + + def forward(self, X): + h = self.conv1(X) + h = self.bn1(h) + h = self.relu(h) + h_relu1 = h + h = self.maxpool(h) + h = self.layer1(h) + h_conv2 = h + h = self.layer2(h) + h_conv3 = h + h = self.layer3(h) + h_conv4 = h + h = self.layer4(h) + h_conv5 = h + + outputs = namedtuple("Outputs", ['relu1','conv2','conv3','conv4','conv5']) + out = outputs(h_relu1, h_conv2, h_conv3, h_conv4, h_conv5) + + return out diff --git a/opensora/eval/flolpips/pwcnet.py b/opensora/eval/flolpips/pwcnet.py new file mode 100644 index 0000000000000000000000000000000000000000..f1dc368c4dfb8793d8b13767c927b330e31f5c68 --- /dev/null +++ b/opensora/eval/flolpips/pwcnet.py @@ -0,0 +1,344 @@ +#!/usr/bin/env python + +import torch + +import getopt +import math +import numpy +import os +import PIL +import PIL.Image +import sys + +# try: +from .correlation import correlation # the custom cost volume layer +# except: +# sys.path.insert(0, './correlation'); import correlation # you should consider upgrading python +# end + +########################################################## + +# assert(int(str('').join(torch.__version__.split('.')[0:2])) >= 13) # requires at least pytorch version 1.3.0 + +# torch.set_grad_enabled(False) # make sure to not compute gradients for computational performance + +# torch.backends.cudnn.enabled = True # make sure to use cudnn for computational performance + +# ########################################################## + +# arguments_strModel = 'default' # 'default', or 'chairs-things' +# arguments_strFirst = './images/first.png' +# arguments_strSecond = './images/second.png' +# arguments_strOut = './out.flo' + +# for strOption, strArgument in getopt.getopt(sys.argv[1:], '', [ strParameter[2:] + '=' for strParameter in sys.argv[1::2] ])[0]: +# if strOption == '--model' and strArgument != '': arguments_strModel = strArgument # which model to use +# if strOption == '--first' and strArgument != '': arguments_strFirst = strArgument # path to the first frame +# if strOption == '--second' and strArgument != '': arguments_strSecond = strArgument # path to the second frame +# if strOption == '--out' and strArgument != '': arguments_strOut = strArgument # path to where the output should be stored +# end + +########################################################## + + + +def backwarp(tenInput, tenFlow): + backwarp_tenGrid = {} + backwarp_tenPartial = {} + if str(tenFlow.shape) not in backwarp_tenGrid: + tenHor = torch.linspace(-1.0 + (1.0 / tenFlow.shape[3]), 1.0 - (1.0 / tenFlow.shape[3]), tenFlow.shape[3]).view(1, 1, 1, -1).expand(-1, -1, tenFlow.shape[2], -1) + tenVer = torch.linspace(-1.0 + (1.0 / tenFlow.shape[2]), 1.0 - (1.0 / tenFlow.shape[2]), tenFlow.shape[2]).view(1, 1, -1, 1).expand(-1, -1, -1, tenFlow.shape[3]) + + backwarp_tenGrid[str(tenFlow.shape)] = torch.cat([ tenHor, tenVer ], 1).cuda() + # end + + if str(tenFlow.shape) not in backwarp_tenPartial: + backwarp_tenPartial[str(tenFlow.shape)] = tenFlow.new_ones([ tenFlow.shape[0], 1, tenFlow.shape[2], tenFlow.shape[3] ]) + # end + + tenFlow = torch.cat([ tenFlow[:, 0:1, :, :] / ((tenInput.shape[3] - 1.0) / 2.0), tenFlow[:, 1:2, :, :] / ((tenInput.shape[2] - 1.0) / 2.0) ], 1) + tenInput = torch.cat([ tenInput, backwarp_tenPartial[str(tenFlow.shape)] ], 1) + + tenOutput = torch.nn.functional.grid_sample(input=tenInput, grid=(backwarp_tenGrid[str(tenFlow.shape)] + tenFlow).permute(0, 2, 3, 1), mode='bilinear', padding_mode='zeros', align_corners=False) + + tenMask = tenOutput[:, -1:, :, :]; tenMask[tenMask > 0.999] = 1.0; tenMask[tenMask < 1.0] = 0.0 + + return tenOutput[:, :-1, :, :] * tenMask +# end + +########################################################## + +class Network(torch.nn.Module): + def __init__(self): + super(Network, self).__init__() + + class Extractor(torch.nn.Module): + def __init__(self): + super(Extractor, self).__init__() + + self.netOne = torch.nn.Sequential( + torch.nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, stride=2, padding=1), + torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), + torch.nn.Conv2d(in_channels=16, out_channels=16, kernel_size=3, stride=1, padding=1), + torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), + torch.nn.Conv2d(in_channels=16, out_channels=16, kernel_size=3, stride=1, padding=1), + torch.nn.LeakyReLU(inplace=False, negative_slope=0.1) + ) + + self.netTwo = torch.nn.Sequential( + torch.nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, stride=2, padding=1), + torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), + torch.nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, stride=1, padding=1), + torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), + torch.nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, stride=1, padding=1), + torch.nn.LeakyReLU(inplace=False, negative_slope=0.1) + ) + + self.netThr = torch.nn.Sequential( + torch.nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=2, padding=1), + torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), + torch.nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1), + torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), + torch.nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1), + torch.nn.LeakyReLU(inplace=False, negative_slope=0.1) + ) + + self.netFou = torch.nn.Sequential( + torch.nn.Conv2d(in_channels=64, out_channels=96, kernel_size=3, stride=2, padding=1), + torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), + torch.nn.Conv2d(in_channels=96, out_channels=96, kernel_size=3, stride=1, padding=1), + torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), + torch.nn.Conv2d(in_channels=96, out_channels=96, kernel_size=3, stride=1, padding=1), + torch.nn.LeakyReLU(inplace=False, negative_slope=0.1) + ) + + self.netFiv = torch.nn.Sequential( + torch.nn.Conv2d(in_channels=96, out_channels=128, kernel_size=3, stride=2, padding=1), + torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), + torch.nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1), + torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), + torch.nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1), + torch.nn.LeakyReLU(inplace=False, negative_slope=0.1) + ) + + self.netSix = torch.nn.Sequential( + torch.nn.Conv2d(in_channels=128, out_channels=196, kernel_size=3, stride=2, padding=1), + torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), + torch.nn.Conv2d(in_channels=196, out_channels=196, kernel_size=3, stride=1, padding=1), + torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), + torch.nn.Conv2d(in_channels=196, out_channels=196, kernel_size=3, stride=1, padding=1), + torch.nn.LeakyReLU(inplace=False, negative_slope=0.1) + ) + # end + + def forward(self, tenInput): + tenOne = self.netOne(tenInput) + tenTwo = self.netTwo(tenOne) + tenThr = self.netThr(tenTwo) + tenFou = self.netFou(tenThr) + tenFiv = self.netFiv(tenFou) + tenSix = self.netSix(tenFiv) + + return [ tenOne, tenTwo, tenThr, tenFou, tenFiv, tenSix ] + # end + # end + + class Decoder(torch.nn.Module): + def __init__(self, intLevel): + super(Decoder, self).__init__() + + intPrevious = [ None, None, 81 + 32 + 2 + 2, 81 + 64 + 2 + 2, 81 + 96 + 2 + 2, 81 + 128 + 2 + 2, 81, None ][intLevel + 1] + intCurrent = [ None, None, 81 + 32 + 2 + 2, 81 + 64 + 2 + 2, 81 + 96 + 2 + 2, 81 + 128 + 2 + 2, 81, None ][intLevel + 0] + + if intLevel < 6: self.netUpflow = torch.nn.ConvTranspose2d(in_channels=2, out_channels=2, kernel_size=4, stride=2, padding=1) + if intLevel < 6: self.netUpfeat = torch.nn.ConvTranspose2d(in_channels=intPrevious + 128 + 128 + 96 + 64 + 32, out_channels=2, kernel_size=4, stride=2, padding=1) + if intLevel < 6: self.fltBackwarp = [ None, None, None, 5.0, 2.5, 1.25, 0.625, None ][intLevel + 1] + + self.netOne = torch.nn.Sequential( + torch.nn.Conv2d(in_channels=intCurrent, out_channels=128, kernel_size=3, stride=1, padding=1), + torch.nn.LeakyReLU(inplace=False, negative_slope=0.1) + ) + + self.netTwo = torch.nn.Sequential( + torch.nn.Conv2d(in_channels=intCurrent + 128, out_channels=128, kernel_size=3, stride=1, padding=1), + torch.nn.LeakyReLU(inplace=False, negative_slope=0.1) + ) + + self.netThr = torch.nn.Sequential( + torch.nn.Conv2d(in_channels=intCurrent + 128 + 128, out_channels=96, kernel_size=3, stride=1, padding=1), + torch.nn.LeakyReLU(inplace=False, negative_slope=0.1) + ) + + self.netFou = torch.nn.Sequential( + torch.nn.Conv2d(in_channels=intCurrent + 128 + 128 + 96, out_channels=64, kernel_size=3, stride=1, padding=1), + torch.nn.LeakyReLU(inplace=False, negative_slope=0.1) + ) + + self.netFiv = torch.nn.Sequential( + torch.nn.Conv2d(in_channels=intCurrent + 128 + 128 + 96 + 64, out_channels=32, kernel_size=3, stride=1, padding=1), + torch.nn.LeakyReLU(inplace=False, negative_slope=0.1) + ) + + self.netSix = torch.nn.Sequential( + torch.nn.Conv2d(in_channels=intCurrent + 128 + 128 + 96 + 64 + 32, out_channels=2, kernel_size=3, stride=1, padding=1) + ) + # end + + def forward(self, tenFirst, tenSecond, objPrevious): + tenFlow = None + tenFeat = None + + if objPrevious is None: + tenFlow = None + tenFeat = None + + tenVolume = torch.nn.functional.leaky_relu(input=correlation.FunctionCorrelation(tenFirst=tenFirst, tenSecond=tenSecond), negative_slope=0.1, inplace=False) + + tenFeat = torch.cat([ tenVolume ], 1) + + elif objPrevious is not None: + tenFlow = self.netUpflow(objPrevious['tenFlow']) + tenFeat = self.netUpfeat(objPrevious['tenFeat']) + + tenVolume = torch.nn.functional.leaky_relu(input=correlation.FunctionCorrelation(tenFirst=tenFirst, tenSecond=backwarp(tenInput=tenSecond, tenFlow=tenFlow * self.fltBackwarp)), negative_slope=0.1, inplace=False) + + tenFeat = torch.cat([ tenVolume, tenFirst, tenFlow, tenFeat ], 1) + + # end + + tenFeat = torch.cat([ self.netOne(tenFeat), tenFeat ], 1) + tenFeat = torch.cat([ self.netTwo(tenFeat), tenFeat ], 1) + tenFeat = torch.cat([ self.netThr(tenFeat), tenFeat ], 1) + tenFeat = torch.cat([ self.netFou(tenFeat), tenFeat ], 1) + tenFeat = torch.cat([ self.netFiv(tenFeat), tenFeat ], 1) + + tenFlow = self.netSix(tenFeat) + + return { + 'tenFlow': tenFlow, + 'tenFeat': tenFeat + } + # end + # end + + class Refiner(torch.nn.Module): + def __init__(self): + super(Refiner, self).__init__() + + self.netMain = torch.nn.Sequential( + torch.nn.Conv2d(in_channels=81 + 32 + 2 + 2 + 128 + 128 + 96 + 64 + 32, out_channels=128, kernel_size=3, stride=1, padding=1, dilation=1), + torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), + torch.nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=2, dilation=2), + torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), + torch.nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=4, dilation=4), + torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), + torch.nn.Conv2d(in_channels=128, out_channels=96, kernel_size=3, stride=1, padding=8, dilation=8), + torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), + torch.nn.Conv2d(in_channels=96, out_channels=64, kernel_size=3, stride=1, padding=16, dilation=16), + torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), + torch.nn.Conv2d(in_channels=64, out_channels=32, kernel_size=3, stride=1, padding=1, dilation=1), + torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), + torch.nn.Conv2d(in_channels=32, out_channels=2, kernel_size=3, stride=1, padding=1, dilation=1) + ) + # end + + def forward(self, tenInput): + return self.netMain(tenInput) + # end + # end + + self.netExtractor = Extractor() + + self.netTwo = Decoder(2) + self.netThr = Decoder(3) + self.netFou = Decoder(4) + self.netFiv = Decoder(5) + self.netSix = Decoder(6) + + self.netRefiner = Refiner() + + self.load_state_dict({ strKey.replace('module', 'net'): tenWeight for strKey, tenWeight in torch.hub.load_state_dict_from_url(url='http://content.sniklaus.com/github/pytorch-pwc/network-' + 'default' + '.pytorch').items() }) + # end + + def forward(self, tenFirst, tenSecond): + intWidth = tenFirst.shape[3] + intHeight = tenFirst.shape[2] + + intPreprocessedWidth = int(math.floor(math.ceil(intWidth / 64.0) * 64.0)) + intPreprocessedHeight = int(math.floor(math.ceil(intHeight / 64.0) * 64.0)) + + tenPreprocessedFirst = torch.nn.functional.interpolate(input=tenFirst, size=(intPreprocessedHeight, intPreprocessedWidth), mode='bilinear', align_corners=False) + tenPreprocessedSecond = torch.nn.functional.interpolate(input=tenSecond, size=(intPreprocessedHeight, intPreprocessedWidth), mode='bilinear', align_corners=False) + + tenFirst = self.netExtractor(tenPreprocessedFirst) + tenSecond = self.netExtractor(tenPreprocessedSecond) + + + objEstimate = self.netSix(tenFirst[-1], tenSecond[-1], None) + objEstimate = self.netFiv(tenFirst[-2], tenSecond[-2], objEstimate) + objEstimate = self.netFou(tenFirst[-3], tenSecond[-3], objEstimate) + objEstimate = self.netThr(tenFirst[-4], tenSecond[-4], objEstimate) + objEstimate = self.netTwo(tenFirst[-5], tenSecond[-5], objEstimate) + + tenFlow = objEstimate['tenFlow'] + self.netRefiner(objEstimate['tenFeat']) + tenFlow = 20.0 * torch.nn.functional.interpolate(input=tenFlow, size=(intHeight, intWidth), mode='bilinear', align_corners=False) + tenFlow[:, 0, :, :] *= float(intWidth) / float(intPreprocessedWidth) + tenFlow[:, 1, :, :] *= float(intHeight) / float(intPreprocessedHeight) + + return tenFlow + # end +# end + +netNetwork = None + +########################################################## + +def estimate(tenFirst, tenSecond): + global netNetwork + + if netNetwork is None: + netNetwork = Network().cuda().eval() + # end + + assert(tenFirst.shape[1] == tenSecond.shape[1]) + assert(tenFirst.shape[2] == tenSecond.shape[2]) + + intWidth = tenFirst.shape[2] + intHeight = tenFirst.shape[1] + + assert(intWidth == 1024) # remember that there is no guarantee for correctness, comment this line out if you acknowledge this and want to continue + assert(intHeight == 436) # remember that there is no guarantee for correctness, comment this line out if you acknowledge this and want to continue + + tenPreprocessedFirst = tenFirst.cuda().view(1, 3, intHeight, intWidth) + tenPreprocessedSecond = tenSecond.cuda().view(1, 3, intHeight, intWidth) + + intPreprocessedWidth = int(math.floor(math.ceil(intWidth / 64.0) * 64.0)) + intPreprocessedHeight = int(math.floor(math.ceil(intHeight / 64.0) * 64.0)) + + tenPreprocessedFirst = torch.nn.functional.interpolate(input=tenPreprocessedFirst, size=(intPreprocessedHeight, intPreprocessedWidth), mode='bilinear', align_corners=False) + tenPreprocessedSecond = torch.nn.functional.interpolate(input=tenPreprocessedSecond, size=(intPreprocessedHeight, intPreprocessedWidth), mode='bilinear', align_corners=False) + + tenFlow = 20.0 * torch.nn.functional.interpolate(input=netNetwork(tenPreprocessedFirst, tenPreprocessedSecond), size=(intHeight, intWidth), mode='bilinear', align_corners=False) + + tenFlow[:, 0, :, :] *= float(intWidth) / float(intPreprocessedWidth) + tenFlow[:, 1, :, :] *= float(intHeight) / float(intPreprocessedHeight) + + return tenFlow[0, :, :, :].cpu() +# end + +########################################################## + +# if __name__ == '__main__': +# tenFirst = torch.FloatTensor(numpy.ascontiguousarray(numpy.array(PIL.Image.open(arguments_strFirst))[:, :, ::-1].transpose(2, 0, 1).astype(numpy.float32) * (1.0 / 255.0))) +# tenSecond = torch.FloatTensor(numpy.ascontiguousarray(numpy.array(PIL.Image.open(arguments_strSecond))[:, :, ::-1].transpose(2, 0, 1).astype(numpy.float32) * (1.0 / 255.0))) + +# tenOutput = estimate(tenFirst, tenSecond) + +# objOutput = open(arguments_strOut, 'wb') + +# numpy.array([ 80, 73, 69, 72 ], numpy.uint8).tofile(objOutput) +# numpy.array([ tenOutput.shape[2], tenOutput.shape[1] ], numpy.int32).tofile(objOutput) +# numpy.array(tenOutput.numpy().transpose(1, 2, 0), numpy.float32).tofile(objOutput) + +# objOutput.close() +# end \ No newline at end of file diff --git a/opensora/eval/flolpips/utils.py b/opensora/eval/flolpips/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..337c48a6483c546d7493e36b3a8e59c47518fe04 --- /dev/null +++ b/opensora/eval/flolpips/utils.py @@ -0,0 +1,95 @@ +import numpy as np +import cv2 +import torch + + +def normalize_tensor(in_feat,eps=1e-10): + norm_factor = torch.sqrt(torch.sum(in_feat**2,dim=1,keepdim=True)) + return in_feat/(norm_factor+eps) + +def l2(p0, p1, range=255.): + return .5*np.mean((p0 / range - p1 / range)**2) + +def dssim(p0, p1, range=255.): + from skimage.measure import compare_ssim + return (1 - compare_ssim(p0, p1, data_range=range, multichannel=True)) / 2. + +def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=255./2.): + image_numpy = image_tensor[0].cpu().float().numpy() + image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + cent) * factor + return image_numpy.astype(imtype) + +def tensor2np(tensor_obj): + # change dimension of a tensor object into a numpy array + return tensor_obj[0].cpu().float().numpy().transpose((1,2,0)) + +def np2tensor(np_obj): + # change dimenion of np array into tensor array + return torch.Tensor(np_obj[:, :, :, np.newaxis].transpose((3, 2, 0, 1))) + +def tensor2tensorlab(image_tensor,to_norm=True,mc_only=False): + # image tensor to lab tensor + from skimage import color + + img = tensor2im(image_tensor) + img_lab = color.rgb2lab(img) + if(mc_only): + img_lab[:,:,0] = img_lab[:,:,0]-50 + if(to_norm and not mc_only): + img_lab[:,:,0] = img_lab[:,:,0]-50 + img_lab = img_lab/100. + + return np2tensor(img_lab) + +def read_frame_yuv2rgb(stream, width, height, iFrame, bit_depth, pix_fmt='420'): + if pix_fmt == '420': + multiplier = 1 + uv_factor = 2 + elif pix_fmt == '444': + multiplier = 2 + uv_factor = 1 + else: + print('Pixel format {} is not supported'.format(pix_fmt)) + return + + if bit_depth == 8: + datatype = np.uint8 + stream.seek(iFrame*1.5*width*height*multiplier) + Y = np.fromfile(stream, dtype=datatype, count=width*height).reshape((height, width)) + + # read chroma samples and upsample since original is 4:2:0 sampling + U = np.fromfile(stream, dtype=datatype, count=(width//uv_factor)*(height//uv_factor)).\ + reshape((height//uv_factor, width//uv_factor)) + V = np.fromfile(stream, dtype=datatype, count=(width//uv_factor)*(height//uv_factor)).\ + reshape((height//uv_factor, width//uv_factor)) + + else: + datatype = np.uint16 + stream.seek(iFrame*3*width*height*multiplier) + Y = np.fromfile(stream, dtype=datatype, count=width*height).reshape((height, width)) + + U = np.fromfile(stream, dtype=datatype, count=(width//uv_factor)*(height//uv_factor)).\ + reshape((height//uv_factor, width//uv_factor)) + V = np.fromfile(stream, dtype=datatype, count=(width//uv_factor)*(height//uv_factor)).\ + reshape((height//uv_factor, width//uv_factor)) + + if pix_fmt == '420': + yuv = np.empty((height*3//2, width), dtype=datatype) + yuv[0:height,:] = Y + + yuv[height:height+height//4,:] = U.reshape(-1, width) + yuv[height+height//4:,:] = V.reshape(-1, width) + + if bit_depth != 8: + yuv = (yuv/(2**bit_depth-1)*255).astype(np.uint8) + + #convert to rgb + rgb = cv2.cvtColor(yuv, cv2.COLOR_YUV2RGB_I420) + + else: + yvu = np.stack([Y,V,U],axis=2) + if bit_depth != 8: + yvu = (yvu/(2**bit_depth-1)*255).astype(np.uint8) + rgb = cv2.cvtColor(yvu, cv2.COLOR_YCrCb2RGB) + + return rgb diff --git a/opensora/eval/fvd/styleganv/fvd.py b/opensora/eval/fvd/styleganv/fvd.py new file mode 100644 index 0000000000000000000000000000000000000000..3043a2a4a4c4fc48ca97aedba074c2c2379685e4 --- /dev/null +++ b/opensora/eval/fvd/styleganv/fvd.py @@ -0,0 +1,90 @@ +import torch +import os +import math +import torch.nn.functional as F + +# https://github.com/universome/fvd-comparison + + +def load_i3d_pretrained(device=torch.device('cpu')): + i3D_WEIGHTS_URL = "https://www.dropbox.com/s/ge9e5ujwgetktms/i3d_torchscript.pt" + filepath = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'i3d_torchscript.pt') + print(filepath) + if not os.path.exists(filepath): + print(f"preparing for download {i3D_WEIGHTS_URL}, you can download it by yourself.") + os.system(f"wget {i3D_WEIGHTS_URL} -O {filepath}") + i3d = torch.jit.load(filepath).eval().to(device) + i3d = torch.nn.DataParallel(i3d) + return i3d + + +def get_feats(videos, detector, device, bs=10): + # videos : torch.tensor BCTHW [0, 1] + detector_kwargs = dict(rescale=False, resize=False, return_features=True) # Return raw features before the softmax layer. + feats = np.empty((0, 400)) + with torch.no_grad(): + for i in range((len(videos)-1)//bs + 1): + feats = np.vstack([feats, detector(torch.stack([preprocess_single(video) for video in videos[i*bs:(i+1)*bs]]).to(device), **detector_kwargs).detach().cpu().numpy()]) + return feats + + +def get_fvd_feats(videos, i3d, device, bs=10): + # videos in [0, 1] as torch tensor BCTHW + # videos = [preprocess_single(video) for video in videos] + embeddings = get_feats(videos, i3d, device, bs) + return embeddings + + +def preprocess_single(video, resolution=224, sequence_length=None): + # video: CTHW, [0, 1] + c, t, h, w = video.shape + + # temporal crop + if sequence_length is not None: + assert sequence_length <= t + video = video[:, :sequence_length] + + # scale shorter side to resolution + scale = resolution / min(h, w) + if h < w: + target_size = (resolution, math.ceil(w * scale)) + else: + target_size = (math.ceil(h * scale), resolution) + video = F.interpolate(video, size=target_size, mode='bilinear', align_corners=False) + + # center crop + c, t, h, w = video.shape + w_start = (w - resolution) // 2 + h_start = (h - resolution) // 2 + video = video[:, :, h_start:h_start + resolution, w_start:w_start + resolution] + + # [0, 1] -> [-1, 1] + video = (video - 0.5) * 2 + + return video.contiguous() + + +""" +Copy-pasted from https://github.com/cvpr2022-stylegan-v/stylegan-v/blob/main/src/metrics/frechet_video_distance.py +""" +from typing import Tuple +from scipy.linalg import sqrtm +import numpy as np + + +def compute_stats(feats: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: + mu = feats.mean(axis=0) # [d] + sigma = np.cov(feats, rowvar=False) # [d, d] + return mu, sigma + + +def frechet_distance(feats_fake: np.ndarray, feats_real: np.ndarray) -> float: + mu_gen, sigma_gen = compute_stats(feats_fake) + mu_real, sigma_real = compute_stats(feats_real) + m = np.square(mu_gen - mu_real).sum() + if feats_fake.shape[0]>1: + s, _ = sqrtm(np.dot(sigma_gen, sigma_real), disp=False) # pylint: disable=no-member + fid = np.real(m + np.trace(sigma_gen + sigma_real - s * 2)) + else: + fid = np.real(m) + return float(fid) \ No newline at end of file diff --git a/opensora/eval/fvd/styleganv/i3d_torchscript.pt b/opensora/eval/fvd/styleganv/i3d_torchscript.pt new file mode 100644 index 0000000000000000000000000000000000000000..b62542778dd3a064d69bbf7a56409053e9beea51 --- /dev/null +++ b/opensora/eval/fvd/styleganv/i3d_torchscript.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:bec6519f66ea534e953026b4ae2c65553c17bf105611c746d904657e5860a5e2 +size 51235320 diff --git a/opensora/eval/fvd/videogpt/fvd.py b/opensora/eval/fvd/videogpt/fvd.py new file mode 100644 index 0000000000000000000000000000000000000000..e81c8290456a5f54de6d31bf0f946f87a12eda9c --- /dev/null +++ b/opensora/eval/fvd/videogpt/fvd.py @@ -0,0 +1,137 @@ +import torch +import os +import math +import torch.nn.functional as F +import numpy as np +import einops + +def load_i3d_pretrained(device=torch.device('cpu')): + i3D_WEIGHTS_URL = "https://onedrive.live.com/download?cid=78EEF3EB6AE7DBCB&resid=78EEF3EB6AE7DBCB%21199&authkey=AApKdFHPXzWLNyI" + filepath = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'i3d_pretrained_400.pt') + print(filepath) + if not os.path.exists(filepath): + print(f"preparing for download {i3D_WEIGHTS_URL}, you can download it by yourself.") + os.system(f"wget {i3D_WEIGHTS_URL} -O {filepath}") + from .pytorch_i3d import InceptionI3d + i3d = InceptionI3d(400, in_channels=3).eval().to(device) + i3d.load_state_dict(torch.load(filepath, map_location=device)) + i3d = torch.nn.DataParallel(i3d) + return i3d + +def preprocess_single(video, resolution, sequence_length=None): + # video: THWC, {0, ..., 255} + video = video.permute(0, 3, 1, 2).float() / 255. # TCHW + t, c, h, w = video.shape + + # temporal crop + if sequence_length is not None: + assert sequence_length <= t + video = video[:sequence_length] + + # scale shorter side to resolution + scale = resolution / min(h, w) + if h < w: + target_size = (resolution, math.ceil(w * scale)) + else: + target_size = (math.ceil(h * scale), resolution) + video = F.interpolate(video, size=target_size, mode='bilinear', + align_corners=False) + + # center crop + t, c, h, w = video.shape + w_start = (w - resolution) // 2 + h_start = (h - resolution) // 2 + video = video[:, :, h_start:h_start + resolution, w_start:w_start + resolution] + video = video.permute(1, 0, 2, 3).contiguous() # CTHW + + video -= 0.5 + + return video + +def preprocess(videos, target_resolution=224): + # we should tras videos in [0-1] [b c t h w] as th.float + # -> videos in {0, ..., 255} [b t h w c] as np.uint8 array + videos = einops.rearrange(videos, 'b c t h w -> b t h w c') + videos = (videos*255).numpy().astype(np.uint8) + + b, t, h, w, c = videos.shape + videos = torch.from_numpy(videos) + videos = torch.stack([preprocess_single(video, target_resolution) for video in videos]) + return videos * 2 # [-0.5, 0.5] -> [-1, 1] + +def get_fvd_logits(videos, i3d, device, bs=10): + videos = preprocess(videos) + embeddings = get_logits(i3d, videos, device, bs=10) + return embeddings + +# https://github.com/tensorflow/gan/blob/de4b8da3853058ea380a6152bd3bd454013bf619/tensorflow_gan/python/eval/classifier_metrics.py#L161 +def _symmetric_matrix_square_root(mat, eps=1e-10): + u, s, v = torch.svd(mat) + si = torch.where(s < eps, s, torch.sqrt(s)) + return torch.matmul(torch.matmul(u, torch.diag(si)), v.t()) + +# https://github.com/tensorflow/gan/blob/de4b8da3853058ea380a6152bd3bd454013bf619/tensorflow_gan/python/eval/classifier_metrics.py#L400 +def trace_sqrt_product(sigma, sigma_v): + sqrt_sigma = _symmetric_matrix_square_root(sigma) + sqrt_a_sigmav_a = torch.matmul(sqrt_sigma, torch.matmul(sigma_v, sqrt_sigma)) + return torch.trace(_symmetric_matrix_square_root(sqrt_a_sigmav_a)) + +# https://discuss.pytorch.org/t/covariance-and-gradient-support/16217/2 +def cov(m, rowvar=False): + '''Estimate a covariance matrix given data. + + Covariance indicates the level to which two variables vary together. + If we examine N-dimensional samples, `X = [x_1, x_2, ... x_N]^T`, + then the covariance matrix element `C_{ij}` is the covariance of + `x_i` and `x_j`. The element `C_{ii}` is the variance of `x_i`. + + Args: + m: A 1-D or 2-D array containing multiple variables and observations. + Each row of `m` represents a variable, and each column a single + observation of all those variables. + rowvar: If `rowvar` is True, then each row represents a + variable, with observations in the columns. Otherwise, the + relationship is transposed: each column represents a variable, + while the rows contain observations. + + Returns: + The covariance matrix of the variables. + ''' + if m.dim() > 2: + raise ValueError('m has more than 2 dimensions') + if m.dim() < 2: + m = m.view(1, -1) + if not rowvar and m.size(0) != 1: + m = m.t() + + fact = 1.0 / (m.size(1) - 1) # unbiased estimate + m -= torch.mean(m, dim=1, keepdim=True) + mt = m.t() # if complex: mt = m.t().conj() + return fact * m.matmul(mt).squeeze() + + +def frechet_distance(x1, x2): + x1 = x1.flatten(start_dim=1) + x2 = x2.flatten(start_dim=1) + m, m_w = x1.mean(dim=0), x2.mean(dim=0) + sigma, sigma_w = cov(x1, rowvar=False), cov(x2, rowvar=False) + mean = torch.sum((m - m_w) ** 2) + if x1.shape[0]>1: + sqrt_trace_component = trace_sqrt_product(sigma, sigma_w) + trace = torch.trace(sigma + sigma_w) - 2.0 * sqrt_trace_component + fd = trace + mean + else: + fd = np.real(mean) + return float(fd) + + +def get_logits(i3d, videos, device, bs=10): + # assert videos.shape[0] % 16 == 0 + with torch.no_grad(): + logits = [] + for i in range(0, videos.shape[0], bs): + batch = videos[i:i + bs].to(device) + # logits.append(i3d.module.extract_features(batch)) # wrong + logits.append(i3d(batch)) # right + logits = torch.cat(logits, dim=0) + return logits diff --git a/opensora/eval/fvd/videogpt/i3d_pretrained_400.pt b/opensora/eval/fvd/videogpt/i3d_pretrained_400.pt new file mode 100644 index 0000000000000000000000000000000000000000..c8f67ab6a50dc2e9623e64a37e9e59a31ed88441 --- /dev/null +++ b/opensora/eval/fvd/videogpt/i3d_pretrained_400.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:55095f049e706479d48e221adcdb145b2b9dc930ba28b081ed72367ffaa32343 +size 50939526 diff --git a/opensora/eval/fvd/videogpt/pytorch_i3d.py b/opensora/eval/fvd/videogpt/pytorch_i3d.py new file mode 100644 index 0000000000000000000000000000000000000000..58a16cdd797e5a0e9d711bb0ba14281486d44d03 --- /dev/null +++ b/opensora/eval/fvd/videogpt/pytorch_i3d.py @@ -0,0 +1,322 @@ +# Original code from https://github.com/piergiaj/pytorch-i3d +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np + +class MaxPool3dSamePadding(nn.MaxPool3d): + + def compute_pad(self, dim, s): + if s % self.stride[dim] == 0: + return max(self.kernel_size[dim] - self.stride[dim], 0) + else: + return max(self.kernel_size[dim] - (s % self.stride[dim]), 0) + + def forward(self, x): + # compute 'same' padding + (batch, channel, t, h, w) = x.size() + out_t = np.ceil(float(t) / float(self.stride[0])) + out_h = np.ceil(float(h) / float(self.stride[1])) + out_w = np.ceil(float(w) / float(self.stride[2])) + pad_t = self.compute_pad(0, t) + pad_h = self.compute_pad(1, h) + pad_w = self.compute_pad(2, w) + + pad_t_f = pad_t // 2 + pad_t_b = pad_t - pad_t_f + pad_h_f = pad_h // 2 + pad_h_b = pad_h - pad_h_f + pad_w_f = pad_w // 2 + pad_w_b = pad_w - pad_w_f + + pad = (pad_w_f, pad_w_b, pad_h_f, pad_h_b, pad_t_f, pad_t_b) + x = F.pad(x, pad) + return super(MaxPool3dSamePadding, self).forward(x) + + +class Unit3D(nn.Module): + + def __init__(self, in_channels, + output_channels, + kernel_shape=(1, 1, 1), + stride=(1, 1, 1), + padding=0, + activation_fn=F.relu, + use_batch_norm=True, + use_bias=False, + name='unit_3d'): + + """Initializes Unit3D module.""" + super(Unit3D, self).__init__() + + self._output_channels = output_channels + self._kernel_shape = kernel_shape + self._stride = stride + self._use_batch_norm = use_batch_norm + self._activation_fn = activation_fn + self._use_bias = use_bias + self.name = name + self.padding = padding + + self.conv3d = nn.Conv3d(in_channels=in_channels, + out_channels=self._output_channels, + kernel_size=self._kernel_shape, + stride=self._stride, + padding=0, # we always want padding to be 0 here. We will dynamically pad based on input size in forward function + bias=self._use_bias) + + if self._use_batch_norm: + self.bn = nn.BatchNorm3d(self._output_channels, eps=1e-5, momentum=0.001) + + def compute_pad(self, dim, s): + if s % self._stride[dim] == 0: + return max(self._kernel_shape[dim] - self._stride[dim], 0) + else: + return max(self._kernel_shape[dim] - (s % self._stride[dim]), 0) + + + def forward(self, x): + # compute 'same' padding + (batch, channel, t, h, w) = x.size() + out_t = np.ceil(float(t) / float(self._stride[0])) + out_h = np.ceil(float(h) / float(self._stride[1])) + out_w = np.ceil(float(w) / float(self._stride[2])) + pad_t = self.compute_pad(0, t) + pad_h = self.compute_pad(1, h) + pad_w = self.compute_pad(2, w) + + pad_t_f = pad_t // 2 + pad_t_b = pad_t - pad_t_f + pad_h_f = pad_h // 2 + pad_h_b = pad_h - pad_h_f + pad_w_f = pad_w // 2 + pad_w_b = pad_w - pad_w_f + + pad = (pad_w_f, pad_w_b, pad_h_f, pad_h_b, pad_t_f, pad_t_b) + x = F.pad(x, pad) + + x = self.conv3d(x) + if self._use_batch_norm: + x = self.bn(x) + if self._activation_fn is not None: + x = self._activation_fn(x) + return x + + + +class InceptionModule(nn.Module): + def __init__(self, in_channels, out_channels, name): + super(InceptionModule, self).__init__() + + self.b0 = Unit3D(in_channels=in_channels, output_channels=out_channels[0], kernel_shape=[1, 1, 1], padding=0, + name=name+'/Branch_0/Conv3d_0a_1x1') + self.b1a = Unit3D(in_channels=in_channels, output_channels=out_channels[1], kernel_shape=[1, 1, 1], padding=0, + name=name+'/Branch_1/Conv3d_0a_1x1') + self.b1b = Unit3D(in_channels=out_channels[1], output_channels=out_channels[2], kernel_shape=[3, 3, 3], + name=name+'/Branch_1/Conv3d_0b_3x3') + self.b2a = Unit3D(in_channels=in_channels, output_channels=out_channels[3], kernel_shape=[1, 1, 1], padding=0, + name=name+'/Branch_2/Conv3d_0a_1x1') + self.b2b = Unit3D(in_channels=out_channels[3], output_channels=out_channels[4], kernel_shape=[3, 3, 3], + name=name+'/Branch_2/Conv3d_0b_3x3') + self.b3a = MaxPool3dSamePadding(kernel_size=[3, 3, 3], + stride=(1, 1, 1), padding=0) + self.b3b = Unit3D(in_channels=in_channels, output_channels=out_channels[5], kernel_shape=[1, 1, 1], padding=0, + name=name+'/Branch_3/Conv3d_0b_1x1') + self.name = name + + def forward(self, x): + b0 = self.b0(x) + b1 = self.b1b(self.b1a(x)) + b2 = self.b2b(self.b2a(x)) + b3 = self.b3b(self.b3a(x)) + return torch.cat([b0,b1,b2,b3], dim=1) + + +class InceptionI3d(nn.Module): + """Inception-v1 I3D architecture. + The model is introduced in: + Quo Vadis, Action Recognition? A New Model and the Kinetics Dataset + Joao Carreira, Andrew Zisserman + https://arxiv.org/pdf/1705.07750v1.pdf. + See also the Inception architecture, introduced in: + Going deeper with convolutions + Christian Szegedy, Wei Liu, Yangqing Jia, Pierre Sermanet, Scott Reed, + Dragomir Anguelov, Dumitru Erhan, Vincent Vanhoucke, Andrew Rabinovich. + http://arxiv.org/pdf/1409.4842v1.pdf. + """ + + # Endpoints of the model in order. During construction, all the endpoints up + # to a designated `final_endpoint` are returned in a dictionary as the + # second return value. + VALID_ENDPOINTS = ( + 'Conv3d_1a_7x7', + 'MaxPool3d_2a_3x3', + 'Conv3d_2b_1x1', + 'Conv3d_2c_3x3', + 'MaxPool3d_3a_3x3', + 'Mixed_3b', + 'Mixed_3c', + 'MaxPool3d_4a_3x3', + 'Mixed_4b', + 'Mixed_4c', + 'Mixed_4d', + 'Mixed_4e', + 'Mixed_4f', + 'MaxPool3d_5a_2x2', + 'Mixed_5b', + 'Mixed_5c', + 'Logits', + 'Predictions', + ) + + def __init__(self, num_classes=400, spatial_squeeze=True, + final_endpoint='Logits', name='inception_i3d', in_channels=3, dropout_keep_prob=0.5): + """Initializes I3D model instance. + Args: + num_classes: The number of outputs in the logit layer (default 400, which + matches the Kinetics dataset). + spatial_squeeze: Whether to squeeze the spatial dimensions for the logits + before returning (default True). + final_endpoint: The model contains many possible endpoints. + `final_endpoint` specifies the last endpoint for the model to be built + up to. In addition to the output at `final_endpoint`, all the outputs + at endpoints up to `final_endpoint` will also be returned, in a + dictionary. `final_endpoint` must be one of + InceptionI3d.VALID_ENDPOINTS (default 'Logits'). + name: A string (optional). The name of this module. + Raises: + ValueError: if `final_endpoint` is not recognized. + """ + + if final_endpoint not in self.VALID_ENDPOINTS: + raise ValueError('Unknown final endpoint %s' % final_endpoint) + + super(InceptionI3d, self).__init__() + self._num_classes = num_classes + self._spatial_squeeze = spatial_squeeze + self._final_endpoint = final_endpoint + self.logits = None + + if self._final_endpoint not in self.VALID_ENDPOINTS: + raise ValueError('Unknown final endpoint %s' % self._final_endpoint) + + self.end_points = {} + end_point = 'Conv3d_1a_7x7' + self.end_points[end_point] = Unit3D(in_channels=in_channels, output_channels=64, kernel_shape=[7, 7, 7], + stride=(2, 2, 2), padding=(3,3,3), name=name+end_point) + if self._final_endpoint == end_point: return + + end_point = 'MaxPool3d_2a_3x3' + self.end_points[end_point] = MaxPool3dSamePadding(kernel_size=[1, 3, 3], stride=(1, 2, 2), + padding=0) + if self._final_endpoint == end_point: return + + end_point = 'Conv3d_2b_1x1' + self.end_points[end_point] = Unit3D(in_channels=64, output_channels=64, kernel_shape=[1, 1, 1], padding=0, + name=name+end_point) + if self._final_endpoint == end_point: return + + end_point = 'Conv3d_2c_3x3' + self.end_points[end_point] = Unit3D(in_channels=64, output_channels=192, kernel_shape=[3, 3, 3], padding=1, + name=name+end_point) + if self._final_endpoint == end_point: return + + end_point = 'MaxPool3d_3a_3x3' + self.end_points[end_point] = MaxPool3dSamePadding(kernel_size=[1, 3, 3], stride=(1, 2, 2), + padding=0) + if self._final_endpoint == end_point: return + + end_point = 'Mixed_3b' + self.end_points[end_point] = InceptionModule(192, [64,96,128,16,32,32], name+end_point) + if self._final_endpoint == end_point: return + + end_point = 'Mixed_3c' + self.end_points[end_point] = InceptionModule(256, [128,128,192,32,96,64], name+end_point) + if self._final_endpoint == end_point: return + + end_point = 'MaxPool3d_4a_3x3' + self.end_points[end_point] = MaxPool3dSamePadding(kernel_size=[3, 3, 3], stride=(2, 2, 2), + padding=0) + if self._final_endpoint == end_point: return + + end_point = 'Mixed_4b' + self.end_points[end_point] = InceptionModule(128+192+96+64, [192,96,208,16,48,64], name+end_point) + if self._final_endpoint == end_point: return + + end_point = 'Mixed_4c' + self.end_points[end_point] = InceptionModule(192+208+48+64, [160,112,224,24,64,64], name+end_point) + if self._final_endpoint == end_point: return + + end_point = 'Mixed_4d' + self.end_points[end_point] = InceptionModule(160+224+64+64, [128,128,256,24,64,64], name+end_point) + if self._final_endpoint == end_point: return + + end_point = 'Mixed_4e' + self.end_points[end_point] = InceptionModule(128+256+64+64, [112,144,288,32,64,64], name+end_point) + if self._final_endpoint == end_point: return + + end_point = 'Mixed_4f' + self.end_points[end_point] = InceptionModule(112+288+64+64, [256,160,320,32,128,128], name+end_point) + if self._final_endpoint == end_point: return + + end_point = 'MaxPool3d_5a_2x2' + self.end_points[end_point] = MaxPool3dSamePadding(kernel_size=[2, 2, 2], stride=(2, 2, 2), + padding=0) + if self._final_endpoint == end_point: return + + end_point = 'Mixed_5b' + self.end_points[end_point] = InceptionModule(256+320+128+128, [256,160,320,32,128,128], name+end_point) + if self._final_endpoint == end_point: return + + end_point = 'Mixed_5c' + self.end_points[end_point] = InceptionModule(256+320+128+128, [384,192,384,48,128,128], name+end_point) + if self._final_endpoint == end_point: return + + end_point = 'Logits' + self.avg_pool = nn.AvgPool3d(kernel_size=[2, 7, 7], + stride=(1, 1, 1)) + self.dropout = nn.Dropout(dropout_keep_prob) + self.logits = Unit3D(in_channels=384+384+128+128, output_channels=self._num_classes, + kernel_shape=[1, 1, 1], + padding=0, + activation_fn=None, + use_batch_norm=False, + use_bias=True, + name='logits') + + self.build() + + + def replace_logits(self, num_classes): + self._num_classes = num_classes + self.logits = Unit3D(in_channels=384+384+128+128, output_channels=self._num_classes, + kernel_shape=[1, 1, 1], + padding=0, + activation_fn=None, + use_batch_norm=False, + use_bias=True, + name='logits') + + + def build(self): + for k in self.end_points.keys(): + self.add_module(k, self.end_points[k]) + + def forward(self, x): + for end_point in self.VALID_ENDPOINTS: + if end_point in self.end_points: + x = self._modules[end_point](x) # use _modules to work with dataparallel + + x = self.logits(self.dropout(self.avg_pool(x))) + if self._spatial_squeeze: + logits = x.squeeze(3).squeeze(3) + logits = logits.mean(dim=2) + # logits is batch X time X classes, which is what we want to work with + return logits + + + def extract_features(self, x): + for end_point in self.VALID_ENDPOINTS: + if end_point in self.end_points: + x = self._modules[end_point](x) + return self.avg_pool(x) \ No newline at end of file diff --git a/opensora/eval/script/cal_clip_score.sh b/opensora/eval/script/cal_clip_score.sh new file mode 100644 index 0000000000000000000000000000000000000000..8bc241c154a72805282b511dcfb2164a5f6a282e --- /dev/null +++ b/opensora/eval/script/cal_clip_score.sh @@ -0,0 +1,23 @@ +# clip_score cross modality +python eval_clip_score.py \ + --real_path path/to/image \ + --generated_path path/to/text \ + --batch-size 50 \ + --device "cuda" + +# clip_score within the same modality +python eval_clip_score.py \ + --real_path path/to/textA \ + --generated_path path/to/textB \ + --real_flag txt \ + --generated_flag txt \ + --batch-size 50 \ + --device "cuda" + +python eval_clip_score.py \ + --real_path path/to/imageA \ + --generated_path path/to/imageB \ + --real_flag img \ + --generated_flag img \ + --batch-size 50 \ + --device "cuda" diff --git a/opensora/eval/script/cal_fvd.sh b/opensora/eval/script/cal_fvd.sh new file mode 100644 index 0000000000000000000000000000000000000000..e344242a0ec2bf4d62f390312d6a396591c049a8 --- /dev/null +++ b/opensora/eval/script/cal_fvd.sh @@ -0,0 +1,9 @@ +python eval_common_metric.py \ + --real_video_dir path/to/imageA\ + --generated_video_dir path/to/imageB \ + --batch_size 10 \ + --crop_size 64 \ + --num_frames 20 \ + --device 'cuda' \ + --metric 'fvd' \ + --fvd_method 'styleganv' diff --git a/opensora/eval/script/cal_lpips.sh b/opensora/eval/script/cal_lpips.sh new file mode 100644 index 0000000000000000000000000000000000000000..9ebaf26e856166fed4573a8164f7c5b557bfd229 --- /dev/null +++ b/opensora/eval/script/cal_lpips.sh @@ -0,0 +1,8 @@ +python eval_common_metric.py \ + --real_video_dir path/to/imageA\ + --generated_video_dir path/to/imageB \ + --batch_size 10 \ + --num_frames 20 \ + --crop_size 64 \ + --device 'cuda' \ + --metric 'lpips' \ No newline at end of file diff --git a/opensora/eval/script/cal_psnr.sh b/opensora/eval/script/cal_psnr.sh new file mode 100644 index 0000000000000000000000000000000000000000..a60a5d1c167f2dcd4a251b31af0d8a534bfee195 --- /dev/null +++ b/opensora/eval/script/cal_psnr.sh @@ -0,0 +1,9 @@ + +python eval_common_metric.py \ + --real_video_dir /data/xiaogeng_liu/data/video1 \ + --generated_video_dir /data/xiaogeng_liu/data/video2 \ + --batch_size 10 \ + --num_frames 20 \ + --crop_size 64 \ + --device 'cuda' \ + --metric 'psnr' \ No newline at end of file diff --git a/opensora/eval/script/cal_ssim.sh b/opensora/eval/script/cal_ssim.sh new file mode 100644 index 0000000000000000000000000000000000000000..404d8da9a12a28a3c173dc87652a380e7a754b25 --- /dev/null +++ b/opensora/eval/script/cal_ssim.sh @@ -0,0 +1,8 @@ +python eval_common_metric.py \ + --real_video_dir /data/xiaogeng_liu/data/video1 \ + --generated_video_dir /data/xiaogeng_liu/data/video2 \ + --batch_size 10 \ + --num_frames 20 \ + --crop_size 64 \ + --device 'cuda' \ + --metric 'ssim' \ No newline at end of file diff --git a/opensora/models/__init__.py b/opensora/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/opensora/models/ae/__init__.py b/opensora/models/ae/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..43a6ef0d4c13a48c6a4c303ee955333c5e8eac61 --- /dev/null +++ b/opensora/models/ae/__init__.py @@ -0,0 +1,30 @@ +from .imagebase import imagebase_ae, imagebase_ae_stride, imagebase_ae_channel +from .videobase import videobase_ae, videobase_ae_stride, videobase_ae_channel +from .videobase import ( + VQVAEConfiguration, + VQVAEModel, + VQVAETrainer, + CausalVQVAEModel, + CausalVQVAEConfiguration, + CausalVQVAETrainer +) + +ae_stride_config = {} +ae_stride_config.update(imagebase_ae_stride) +ae_stride_config.update(videobase_ae_stride) + +ae_channel_config = {} +ae_channel_config.update(imagebase_ae_channel) +ae_channel_config.update(videobase_ae_channel) + +def getae(args): + """deprecation""" + ae = imagebase_ae.get(args.ae, None) or videobase_ae.get(args.ae, None) + assert ae is not None + return ae(args.ae) + +def getae_wrapper(ae): + """deprecation""" + ae = imagebase_ae.get(ae, None) or videobase_ae.get(ae, None) + assert ae is not None + return ae \ No newline at end of file diff --git a/opensora/models/ae/imagebase/__init__.py b/opensora/models/ae/imagebase/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..12eeb327ffdec68d152adfcbc0da8191d76ca0f5 --- /dev/null +++ b/opensora/models/ae/imagebase/__init__.py @@ -0,0 +1,30 @@ +from .vae.vae import HFVAEWrapper +from .vae.vae import SDVAEWrapper +from .vqvae.vqvae import SDVQVAEWrapper + +vae = ['stabilityai/sd-vae-ft-mse', 'stabilityai/sd-vae-ft-ema'] +vqvae = ['vqgan_imagenet_f16_1024', 'vqgan_imagenet_f16_16384', 'vqgan_gumbel_f8'] + +imagebase_ae_stride = { + 'stabilityai/sd-vae-ft-mse': [1, 8, 8], + 'stabilityai/sd-vae-ft-ema': [1, 8, 8], + 'vqgan_imagenet_f16_1024': [1, 16, 16], + 'vqgan_imagenet_f16_16384': [1, 16, 16], + 'vqgan_gumbel_f8': [1, 8, 8], +} + +imagebase_ae_channel = { + 'stabilityai/sd-vae-ft-mse': 4, + 'stabilityai/sd-vae-ft-ema': 4, + 'vqgan_imagenet_f16_1024': -1, + 'vqgan_imagenet_f16_16384': -1, + 'vqgan_gumbel_f8': -1, +} + +imagebase_ae = { + 'stabilityai/sd-vae-ft-mse': HFVAEWrapper, + 'stabilityai/sd-vae-ft-ema': HFVAEWrapper, + 'vqgan_imagenet_f16_1024': SDVQVAEWrapper, + 'vqgan_imagenet_f16_16384': SDVQVAEWrapper, + 'vqgan_gumbel_f8': SDVQVAEWrapper, +} \ No newline at end of file diff --git a/opensora/models/ae/imagebase/vqvae/model.py b/opensora/models/ae/imagebase/vqvae/model.py new file mode 100644 index 0000000000000000000000000000000000000000..7c9a757f742019649f0173235c2f4e04fe042929 --- /dev/null +++ b/opensora/models/ae/imagebase/vqvae/model.py @@ -0,0 +1,775 @@ +# pytorch_diffusion + derived encoder decoder +import math +import torch +import torch.nn as nn +import numpy as np + + +def get_timestep_embedding(timesteps, embedding_dim): + """ + This matches the implementation in Denoising Diffusion Probabilistic Models: + From Fairseq. + Build sinusoidal embeddings. + This matches the implementation in tensor2tensor, but differs slightly + from the description in Section 3.5 of "Attention Is All You Need". + """ + assert len(timesteps.shape) == 1 + + half_dim = embedding_dim // 2 + emb = math.log(10000) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb) + emb = emb.to(device=timesteps.device) + emb = timesteps.float()[:, None] * emb[None, :] + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) + if embedding_dim % 2 == 1: # zero pad + emb = torch.nn.functional.pad(emb, (0,1,0,0)) + return emb + + +def nonlinearity(x): + # swish + return x*torch.sigmoid(x) + + +def Normalize(in_channels): + return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) + + +class Upsample(nn.Module): + def __init__(self, in_channels, with_conv): + super().__init__() + self.with_conv = with_conv + if self.with_conv: + self.conv = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=3, + stride=1, + padding=1) + + def forward(self, x): + x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") + if self.with_conv: + x = self.conv(x) + return x + + +class Downsample(nn.Module): + def __init__(self, in_channels, with_conv): + super().__init__() + self.with_conv = with_conv + if self.with_conv: + # no asymmetric padding in torch conv, must do it ourselves + self.conv = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=3, + stride=2, + padding=0) + + def forward(self, x): + if self.with_conv: + pad = (0,1,0,1) + x = torch.nn.functional.pad(x, pad, mode="constant", value=0) + x = self.conv(x) + else: + x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2) + return x + + +class ResnetBlock(nn.Module): + def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, + dropout, temb_channels=512): + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + self.use_conv_shortcut = conv_shortcut + + self.norm1 = Normalize(in_channels) + self.conv1 = torch.nn.Conv2d(in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1) + if temb_channels > 0: + self.temb_proj = torch.nn.Linear(temb_channels, + out_channels) + self.norm2 = Normalize(out_channels) + self.dropout = torch.nn.Dropout(dropout) + self.conv2 = torch.nn.Conv2d(out_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1) + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + self.conv_shortcut = torch.nn.Conv2d(in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1) + else: + self.nin_shortcut = torch.nn.Conv2d(in_channels, + out_channels, + kernel_size=1, + stride=1, + padding=0) + + def forward(self, x, temb): + h = x + h = self.norm1(h) + h = nonlinearity(h) + h = self.conv1(h) + + if temb is not None: + h = h + self.temb_proj(nonlinearity(temb))[:,:,None,None] + + h = self.norm2(h) + h = nonlinearity(h) + h = self.dropout(h) + h = self.conv2(h) + + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + x = self.conv_shortcut(x) + else: + x = self.nin_shortcut(x) + + return x+h + + +class AttnBlock(nn.Module): + def __init__(self, in_channels): + super().__init__() + self.in_channels = in_channels + + self.norm = Normalize(in_channels) + self.q = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.k = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.v = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.proj_out = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + + + def forward(self, x): + h_ = x + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + # compute attention + b,c,h,w = q.shape + q = q.reshape(b,c,h*w) + q = q.permute(0,2,1) # b,hw,c + k = k.reshape(b,c,h*w) # b,c,hw + w_ = torch.bmm(q,k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] + w_ = w_ * (int(c)**(-0.5)) + w_ = torch.nn.functional.softmax(w_, dim=2) + + # attend to values + v = v.reshape(b,c,h*w) + w_ = w_.permute(0,2,1) # b,hw,hw (first hw of k, second of q) + h_ = torch.bmm(v,w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] + h_ = h_.reshape(b,c,h,w) + + h_ = self.proj_out(h_) + + return x+h_ + + +class Model(nn.Module): + def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks, + attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels, + resolution, use_timestep=True): + super().__init__() + self.ch = ch + self.temb_ch = self.ch*4 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + + self.use_timestep = use_timestep + if self.use_timestep: + # timestep embedding + self.temb = nn.Module() + self.temb.dense = nn.ModuleList([ + torch.nn.Linear(self.ch, + self.temb_ch), + torch.nn.Linear(self.temb_ch, + self.temb_ch), + ]) + + # downsampling + self.conv_in = torch.nn.Conv2d(in_channels, + self.ch, + kernel_size=3, + stride=1, + padding=1) + + curr_res = resolution + in_ch_mult = (1,)+tuple(ch_mult) + self.down = nn.ModuleList() + for i_level in range(self.num_resolutions): + block = nn.ModuleList() + attn = nn.ModuleList() + block_in = ch*in_ch_mult[i_level] + block_out = ch*ch_mult[i_level] + for i_block in range(self.num_res_blocks): + block.append(ResnetBlock(in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(AttnBlock(block_in)) + down = nn.Module() + down.block = block + down.attn = attn + if i_level != self.num_resolutions-1: + down.downsample = Downsample(block_in, resamp_with_conv) + curr_res = curr_res // 2 + self.down.append(down) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + self.mid.attn_1 = AttnBlock(block_in) + self.mid.block_2 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + + # upsampling + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + block_out = ch*ch_mult[i_level] + skip_in = ch*ch_mult[i_level] + for i_block in range(self.num_res_blocks+1): + if i_block == self.num_res_blocks: + skip_in = ch*in_ch_mult[i_level] + block.append(ResnetBlock(in_channels=block_in+skip_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(AttnBlock(block_in)) + up = nn.Module() + up.block = block + up.attn = attn + if i_level != 0: + up.upsample = Upsample(block_in, resamp_with_conv) + curr_res = curr_res * 2 + self.up.insert(0, up) # prepend to get consistent order + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d(block_in, + out_ch, + kernel_size=3, + stride=1, + padding=1) + + + def forward(self, x, t=None): + #assert x.shape[2] == x.shape[3] == self.resolution + + if self.use_timestep: + # timestep embedding + assert t is not None + temb = get_timestep_embedding(t, self.ch) + temb = self.temb.dense[0](temb) + temb = nonlinearity(temb) + temb = self.temb.dense[1](temb) + else: + temb = None + + # downsampling + hs = [self.conv_in(x)] + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h = self.down[i_level].block[i_block](hs[-1], temb) + if len(self.down[i_level].attn) > 0: + h = self.down[i_level].attn[i_block](h) + hs.append(h) + if i_level != self.num_resolutions-1: + hs.append(self.down[i_level].downsample(hs[-1])) + + # middle + h = hs[-1] + h = self.mid.block_1(h, temb) + h = self.mid.attn_1(h) + h = self.mid.block_2(h, temb) + + # upsampling + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks+1): + h = self.up[i_level].block[i_block]( + torch.cat([h, hs.pop()], dim=1), temb) + if len(self.up[i_level].attn) > 0: + h = self.up[i_level].attn[i_block](h) + if i_level != 0: + h = self.up[i_level].upsample(h) + + # end + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + + +class Encoder(nn.Module): + def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks, + attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels, + resolution, z_channels, double_z=True, **ignore_kwargs): + super().__init__() + self.ch = ch + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + + # downsampling + self.conv_in = torch.nn.Conv2d(in_channels, + self.ch, + kernel_size=3, + stride=1, + padding=1) + + curr_res = resolution + in_ch_mult = (1,)+tuple(ch_mult) + self.down = nn.ModuleList() + for i_level in range(self.num_resolutions): + block = nn.ModuleList() + attn = nn.ModuleList() + block_in = ch*in_ch_mult[i_level] + block_out = ch*ch_mult[i_level] + for i_block in range(self.num_res_blocks): + block.append(ResnetBlock(in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(AttnBlock(block_in)) + down = nn.Module() + down.block = block + down.attn = attn + if i_level != self.num_resolutions-1: + down.downsample = Downsample(block_in, resamp_with_conv) + curr_res = curr_res // 2 + self.down.append(down) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + self.mid.attn_1 = AttnBlock(block_in) + self.mid.block_2 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d(block_in, + 2*z_channels if double_z else z_channels, + kernel_size=3, + stride=1, + padding=1) + + + def forward(self, x): + #assert x.shape[2] == x.shape[3] == self.resolution, "{}, {}, {}".format(x.shape[2], x.shape[3], self.resolution) + + # timestep embedding + temb = None + + # downsampling + hs = [self.conv_in(x)] + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h = self.down[i_level].block[i_block](hs[-1], temb) + if len(self.down[i_level].attn) > 0: + h = self.down[i_level].attn[i_block](h) + hs.append(h) + if i_level != self.num_resolutions-1: + hs.append(self.down[i_level].downsample(hs[-1])) + + # middle + h = hs[-1] + h = self.mid.block_1(h, temb) + h = self.mid.attn_1(h) + h = self.mid.block_2(h, temb) + + # end + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + + +class Decoder(nn.Module): + def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks, + attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels, + resolution, z_channels, give_pre_end=False, **ignorekwargs): + super().__init__() + self.ch = ch + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + self.give_pre_end = give_pre_end + + # compute in_ch_mult, block_in and curr_res at lowest res + in_ch_mult = (1,)+tuple(ch_mult) + block_in = ch*ch_mult[self.num_resolutions-1] + curr_res = resolution // 2**(self.num_resolutions-1) + self.z_shape = (1,z_channels,curr_res,curr_res) + print("Working with z of shape {} = {} dimensions.".format( + self.z_shape, np.prod(self.z_shape))) + + # z to block_in + self.conv_in = torch.nn.Conv2d(z_channels, + block_in, + kernel_size=3, + stride=1, + padding=1) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + self.mid.attn_1 = AttnBlock(block_in) + self.mid.block_2 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + + # upsampling + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + block_out = ch*ch_mult[i_level] + for i_block in range(self.num_res_blocks+1): + block.append(ResnetBlock(in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(AttnBlock(block_in)) + up = nn.Module() + up.block = block + up.attn = attn + if i_level != 0: + up.upsample = Upsample(block_in, resamp_with_conv) + curr_res = curr_res * 2 + self.up.insert(0, up) # prepend to get consistent order + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d(block_in, + out_ch, + kernel_size=3, + stride=1, + padding=1) + + def forward(self, z): + #assert z.shape[1:] == self.z_shape[1:] + self.last_z_shape = z.shape + + # timestep embedding + temb = None + + # z to block_in + h = self.conv_in(z) + + # middle + h = self.mid.block_1(h, temb) + h = self.mid.attn_1(h) + h = self.mid.block_2(h, temb) + + # upsampling + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks+1): + h = self.up[i_level].block[i_block](h, temb) + if len(self.up[i_level].attn) > 0: + h = self.up[i_level].attn[i_block](h) + if i_level != 0: + h = self.up[i_level].upsample(h) + + # end + if self.give_pre_end: + return h + + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + + +class VUNet(nn.Module): + def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks, + attn_resolutions, dropout=0.0, resamp_with_conv=True, + in_channels, c_channels, + resolution, z_channels, use_timestep=False, **ignore_kwargs): + super().__init__() + self.ch = ch + self.temb_ch = self.ch*4 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + + self.use_timestep = use_timestep + if self.use_timestep: + # timestep embedding + self.temb = nn.Module() + self.temb.dense = nn.ModuleList([ + torch.nn.Linear(self.ch, + self.temb_ch), + torch.nn.Linear(self.temb_ch, + self.temb_ch), + ]) + + # downsampling + self.conv_in = torch.nn.Conv2d(c_channels, + self.ch, + kernel_size=3, + stride=1, + padding=1) + + curr_res = resolution + in_ch_mult = (1,)+tuple(ch_mult) + self.down = nn.ModuleList() + for i_level in range(self.num_resolutions): + block = nn.ModuleList() + attn = nn.ModuleList() + block_in = ch*in_ch_mult[i_level] + block_out = ch*ch_mult[i_level] + for i_block in range(self.num_res_blocks): + block.append(ResnetBlock(in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(AttnBlock(block_in)) + down = nn.Module() + down.block = block + down.attn = attn + if i_level != self.num_resolutions-1: + down.downsample = Downsample(block_in, resamp_with_conv) + curr_res = curr_res // 2 + self.down.append(down) + + self.z_in = torch.nn.Conv2d(z_channels, + block_in, + kernel_size=1, + stride=1, + padding=0) + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock(in_channels=2*block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + self.mid.attn_1 = AttnBlock(block_in) + self.mid.block_2 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + + # upsampling + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + block_out = ch*ch_mult[i_level] + skip_in = ch*ch_mult[i_level] + for i_block in range(self.num_res_blocks+1): + if i_block == self.num_res_blocks: + skip_in = ch*in_ch_mult[i_level] + block.append(ResnetBlock(in_channels=block_in+skip_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(AttnBlock(block_in)) + up = nn.Module() + up.block = block + up.attn = attn + if i_level != 0: + up.upsample = Upsample(block_in, resamp_with_conv) + curr_res = curr_res * 2 + self.up.insert(0, up) # prepend to get consistent order + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d(block_in, + out_ch, + kernel_size=3, + stride=1, + padding=1) + + + def forward(self, x, z): + #assert x.shape[2] == x.shape[3] == self.resolution + + if self.use_timestep: + # timestep embedding + assert t is not None + temb = get_timestep_embedding(t, self.ch) + temb = self.temb.dense[0](temb) + temb = nonlinearity(temb) + temb = self.temb.dense[1](temb) + else: + temb = None + + # downsampling + hs = [self.conv_in(x)] + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h = self.down[i_level].block[i_block](hs[-1], temb) + if len(self.down[i_level].attn) > 0: + h = self.down[i_level].attn[i_block](h) + hs.append(h) + if i_level != self.num_resolutions-1: + hs.append(self.down[i_level].downsample(hs[-1])) + + # middle + h = hs[-1] + z = self.z_in(z) + h = torch.cat((h,z),dim=1) + h = self.mid.block_1(h, temb) + h = self.mid.attn_1(h) + h = self.mid.block_2(h, temb) + + # upsampling + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks+1): + h = self.up[i_level].block[i_block]( + torch.cat([h, hs.pop()], dim=1), temb) + if len(self.up[i_level].attn) > 0: + h = self.up[i_level].attn[i_block](h) + if i_level != 0: + h = self.up[i_level].upsample(h) + + # end + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + + +class SimpleDecoder(nn.Module): + def __init__(self, in_channels, out_channels, *args, **kwargs): + super().__init__() + self.model = nn.ModuleList([nn.Conv2d(in_channels, in_channels, 1), + ResnetBlock(in_channels=in_channels, + out_channels=2 * in_channels, + temb_channels=0, dropout=0.0), + ResnetBlock(in_channels=2 * in_channels, + out_channels=4 * in_channels, + temb_channels=0, dropout=0.0), + ResnetBlock(in_channels=4 * in_channels, + out_channels=2 * in_channels, + temb_channels=0, dropout=0.0), + nn.Conv2d(2*in_channels, in_channels, 1), + Upsample(in_channels, with_conv=True)]) + # end + self.norm_out = Normalize(in_channels) + self.conv_out = torch.nn.Conv2d(in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1) + + def forward(self, x): + for i, layer in enumerate(self.model): + if i in [1,2,3]: + x = layer(x, None) + else: + x = layer(x) + + h = self.norm_out(x) + h = nonlinearity(h) + x = self.conv_out(h) + return x + + +class UpsampleDecoder(nn.Module): + def __init__(self, in_channels, out_channels, ch, num_res_blocks, resolution, + ch_mult=(2,2), dropout=0.0): + super().__init__() + # upsampling + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + block_in = in_channels + curr_res = resolution // 2 ** (self.num_resolutions - 1) + self.res_blocks = nn.ModuleList() + self.upsample_blocks = nn.ModuleList() + for i_level in range(self.num_resolutions): + res_block = [] + block_out = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks + 1): + res_block.append(ResnetBlock(in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) + block_in = block_out + self.res_blocks.append(nn.ModuleList(res_block)) + if i_level != self.num_resolutions - 1: + self.upsample_blocks.append(Upsample(block_in, True)) + curr_res = curr_res * 2 + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d(block_in, + out_channels, + kernel_size=3, + stride=1, + padding=1) + + def forward(self, x): + # upsampling + h = x + for k, i_level in enumerate(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks + 1): + h = self.res_blocks[i_level][i_block](h, None) + if i_level != self.num_resolutions - 1: + h = self.upsample_blocks[k](h) + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h \ No newline at end of file diff --git a/opensora/models/ae/imagebase/vqvae/quantize.py b/opensora/models/ae/imagebase/vqvae/quantize.py new file mode 100644 index 0000000000000000000000000000000000000000..dd258efe9d06f057c020a6914058a88e234c2505 --- /dev/null +++ b/opensora/models/ae/imagebase/vqvae/quantize.py @@ -0,0 +1,447 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np +from torch import einsum +from einops import rearrange + + +class VectorQuantizer(nn.Module): + """ + see https://github.com/MishaLaskin/vqvae/blob/d761a999e2267766400dc646d82d3ac3657771d4/models/quantizer.py + ____________________________________________ + Discretization bottleneck part of the VQ-VAE. + Inputs: + - n_e : number of embeddings + - e_dim : dimension of embedding + - beta : commitment cost used in loss term, beta * ||z_e(x)-sg[e]||^2 + _____________________________________________ + """ + + # NOTE: this class contains a bug regarding beta; see VectorQuantizer2 for + # a fix and use legacy=False to apply that fix. VectorQuantizer2 can be + # used wherever VectorQuantizer has been used before and is additionally + # more efficient. + def __init__(self, n_e, e_dim, beta): + super(VectorQuantizer, self).__init__() + self.n_e = n_e + self.e_dim = e_dim + self.beta = beta + + self.embedding = nn.Embedding(self.n_e, self.e_dim) + self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e) + + def forward(self, z): + """ + Inputs the output of the encoder network z and maps it to a discrete + one-hot vector that is the index of the closest embedding vector e_j + z (continuous) -> z_q (discrete) + z.shape = (batch, channel, height, width) + quantization pipeline: + 1. get encoder input (B,C,H,W) + 2. flatten input to (B*H*W,C) + """ + # reshape z -> (batch, height, width, channel) and flatten + z = z.permute(0, 2, 3, 1).contiguous() + z_flattened = z.view(-1, self.e_dim) + # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z + + d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \ + torch.sum(self.embedding.weight ** 2, dim=1) - 2 * \ + torch.matmul(z_flattened, self.embedding.weight.t()) + + ## could possible replace this here + # #\start... + # find closest encodings + min_encoding_indices = torch.argmin(d, dim=1).unsqueeze(1) + + min_encodings = torch.zeros( + min_encoding_indices.shape[0], self.n_e).to(z) + min_encodings.scatter_(1, min_encoding_indices, 1) + + # dtype min encodings: torch.float32 + # min_encodings shape: torch.Size([2048, 512]) + # min_encoding_indices.shape: torch.Size([2048, 1]) + + # get quantized latent vectors + z_q = torch.matmul(min_encodings, self.embedding.weight).view(z.shape) + # .........\end + + # with: + # .........\start + # min_encoding_indices = torch.argmin(d, dim=1) + # z_q = self.embedding(min_encoding_indices) + # ......\end......... (TODO) + + # compute loss for embedding + loss = torch.mean((z_q.detach() - z) ** 2) + self.beta * \ + torch.mean((z_q - z.detach()) ** 2) + + # preserve gradients + z_q = z + (z_q - z).detach() + + # perplexity + e_mean = torch.mean(min_encodings, dim=0) + perplexity = torch.exp(-torch.sum(e_mean * torch.log(e_mean + 1e-10))) + + # reshape back to match original input shape + z_q = z_q.permute(0, 3, 1, 2).contiguous() + + return z_q, loss, (perplexity, min_encodings, min_encoding_indices) + + def get_codebook_entry(self, indices, shape): + # shape specifying (batch, height, width, channel) + # TODO: check for more easy handling with nn.Embedding + min_encodings = torch.zeros(indices.shape[0], self.n_e).to(indices) + min_encodings.scatter_(1, indices[:, None], 1) + + # get quantized latent vectors + z_q = torch.matmul(min_encodings.float(), self.embedding.weight) + + if shape is not None: + z_q = z_q.view(shape) + + # reshape back to match original input shape + z_q = z_q.permute(0, 3, 1, 2).contiguous() + + return z_q + + +class GumbelQuantize(nn.Module): + """ + credit to @karpathy: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py (thanks!) + Gumbel Softmax trick quantizer + Categorical Reparameterization with Gumbel-Softmax, Jang et al. 2016 + https://arxiv.org/abs/1611.01144 + """ + + def __init__(self, num_hiddens, embedding_dim, n_embed, straight_through=True, + kl_weight=5e-4, temp_init=1.0, use_vqinterface=True, + remap=None, unknown_index="random"): + super().__init__() + + self.embedding_dim = embedding_dim + self.n_embed = n_embed + + self.straight_through = straight_through + self.temperature = temp_init + self.kl_weight = kl_weight + + self.proj = nn.Conv2d(num_hiddens, n_embed, 1) + self.embed = nn.Embedding(n_embed, embedding_dim) + + self.use_vqinterface = use_vqinterface + + self.remap = remap + if self.remap is not None: + self.register_buffer("used", torch.tensor(np.load(self.remap))) + self.re_embed = self.used.shape[0] + self.unknown_index = unknown_index # "random" or "extra" or integer + if self.unknown_index == "extra": + self.unknown_index = self.re_embed + self.re_embed = self.re_embed + 1 + print(f"Remapping {self.n_embed} indices to {self.re_embed} indices. " + f"Using {self.unknown_index} for unknown indices.") + else: + self.re_embed = n_embed + + def remap_to_used(self, inds): + ishape = inds.shape + assert len(ishape) > 1 + inds = inds.reshape(ishape[0], -1) + used = self.used.to(inds) + match = (inds[:, :, None] == used[None, None, ...]).long() + new = match.argmax(-1) + unknown = match.sum(2) < 1 + if self.unknown_index == "random": + new[unknown] = torch.randint(0, self.re_embed, size=new[unknown].shape).to(device=new.device) + else: + new[unknown] = self.unknown_index + return new.reshape(ishape) + + def unmap_to_all(self, inds): + ishape = inds.shape + assert len(ishape) > 1 + inds = inds.reshape(ishape[0], -1) + used = self.used.to(inds) + if self.re_embed > self.used.shape[0]: # extra token + inds[inds >= self.used.shape[0]] = 0 # simply set to zero + back = torch.gather(used[None, :][inds.shape[0] * [0], :], 1, inds) + return back.reshape(ishape) + + def forward(self, z, temp=None, return_logits=False): + # force hard = True when we are in eval mode, as we must quantize. actually, always true seems to work + hard = self.straight_through if self.training else True + temp = self.temperature if temp is None else temp + + logits = self.proj(z) + if self.remap is not None: + # continue only with used logits + full_zeros = torch.zeros_like(logits) + logits = logits[:, self.used, ...] + + soft_one_hot = F.gumbel_softmax(logits, tau=temp, dim=1, hard=hard) + if self.remap is not None: + # go back to all entries but unused set to zero + full_zeros[:, self.used, ...] = soft_one_hot + soft_one_hot = full_zeros + z_q = einsum('b n h w, n d -> b d h w', soft_one_hot, self.embed.weight) + + # + kl divergence to the prior loss + qy = F.softmax(logits, dim=1) + diff = self.kl_weight * torch.sum(qy * torch.log(qy * self.n_embed + 1e-10), dim=1).mean() + + ind = soft_one_hot.argmax(dim=1) + if self.remap is not None: + ind = self.remap_to_used(ind) + if self.use_vqinterface: + if return_logits: + return z_q, diff, (None, None, ind), logits + return z_q, diff, (None, None, ind) + return z_q, diff, ind + + def get_codebook_entry(self, indices, shape): + b, h, w, c = shape + assert b * h * w == indices.shape[0] + indices = rearrange(indices, '(b h w) -> b h w', b=b, h=h, w=w) + if self.remap is not None: + indices = self.unmap_to_all(indices) + one_hot = F.one_hot(indices, num_classes=self.n_embed).permute(0, 3, 1, 2).float() + z_q = einsum('b n h w, n d -> b d h w', one_hot, self.embed.weight) + return z_q + + +class VectorQuantizer2(nn.Module): + """ + Improved version over VectorQuantizer, can be used as a drop-in replacement. Mostly + avoids costly matrix multiplications and allows for post-hoc remapping of indices. + """ + + # NOTE: due to a bug the beta term was applied to the wrong term. for + # backwards compatibility we use the buggy version by default, but you can + # specify legacy=False to fix it. + def __init__(self, n_e, e_dim, beta, remap=None, unknown_index="random", + sane_index_shape=False, legacy=True): + super().__init__() + self.n_e = n_e + self.e_dim = e_dim + self.beta = beta + self.legacy = legacy + + self.embedding = nn.Embedding(self.n_e, self.e_dim) + self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e) + + self.remap = remap + if self.remap is not None: + self.register_buffer("used", torch.tensor(np.load(self.remap))) + self.re_embed = self.used.shape[0] + self.unknown_index = unknown_index # "random" or "extra" or integer + if self.unknown_index == "extra": + self.unknown_index = self.re_embed + self.re_embed = self.re_embed + 1 + print(f"Remapping {self.n_e} indices to {self.re_embed} indices. " + f"Using {self.unknown_index} for unknown indices.") + else: + self.re_embed = n_e + + self.sane_index_shape = sane_index_shape + + def remap_to_used(self, inds): + ishape = inds.shape + assert len(ishape) > 1 + inds = inds.reshape(ishape[0], -1) + used = self.used.to(inds) + match = (inds[:, :, None] == used[None, None, ...]).long() + new = match.argmax(-1) + unknown = match.sum(2) < 1 + if self.unknown_index == "random": + new[unknown] = torch.randint(0, self.re_embed, size=new[unknown].shape).to(device=new.device) + else: + new[unknown] = self.unknown_index + return new.reshape(ishape) + + def unmap_to_all(self, inds): + ishape = inds.shape + assert len(ishape) > 1 + inds = inds.reshape(ishape[0], -1) + used = self.used.to(inds) + if self.re_embed > self.used.shape[0]: # extra token + inds[inds >= self.used.shape[0]] = 0 # simply set to zero + back = torch.gather(used[None, :][inds.shape[0] * [0], :], 1, inds) + return back.reshape(ishape) + + def forward(self, z, temp=None, rescale_logits=False, return_logits=False): + assert temp is None or temp == 1.0, "Only for interface compatible with Gumbel" + assert rescale_logits == False, "Only for interface compatible with Gumbel" + assert return_logits == False, "Only for interface compatible with Gumbel" + # reshape z -> (batch, height, width, channel) and flatten + z = rearrange(z, 'b c h w -> b h w c').contiguous() + z_flattened = z.view(-1, self.e_dim) + # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z + + d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \ + torch.sum(self.embedding.weight ** 2, dim=1) - 2 * \ + torch.einsum('bd,dn->bn', z_flattened, rearrange(self.embedding.weight, 'n d -> d n')) + + min_encoding_indices = torch.argmin(d, dim=1) + z_q = self.embedding(min_encoding_indices).view(z.shape) + perplexity = None + min_encodings = None + + # compute loss for embedding + if not self.legacy: + loss = self.beta * torch.mean((z_q.detach() - z) ** 2) + \ + torch.mean((z_q - z.detach()) ** 2) + else: + loss = torch.mean((z_q.detach() - z) ** 2) + self.beta * \ + torch.mean((z_q - z.detach()) ** 2) + + # preserve gradients + z_q = z + (z_q - z).detach() + + # reshape back to match original input shape + z_q = rearrange(z_q, 'b h w c -> b c h w').contiguous() + + if self.remap is not None: + min_encoding_indices = min_encoding_indices.reshape(z.shape[0], -1) # add batch axis + min_encoding_indices = self.remap_to_used(min_encoding_indices) + min_encoding_indices = min_encoding_indices.reshape(-1, 1) # flatten + + if self.sane_index_shape: + min_encoding_indices = min_encoding_indices.reshape( + z_q.shape[0], z_q.shape[2], z_q.shape[3]) + + return z_q, loss, (perplexity, min_encodings, min_encoding_indices) + + def get_codebook_entry(self, indices, shape): + # shape specifying (batch, height, width, channel) + if self.remap is not None: + indices = indices.reshape(shape[0], -1) # add batch axis + indices = self.unmap_to_all(indices) + indices = indices.reshape(-1) # flatten again + + # get quantized latent vectors + z_q = self.embedding(indices) + + if shape is not None: + z_q = z_q.view(shape) + # reshape back to match original input shape + z_q = z_q.permute(0, 3, 1, 2).contiguous() + + return z_q + + +class EmbeddingEMA(nn.Module): + def __init__(self, num_tokens, codebook_dim, decay=0.99, eps=1e-5): + super().__init__() + self.decay = decay + self.eps = eps + weight = torch.randn(num_tokens, codebook_dim) + self.weight = nn.Parameter(weight, requires_grad=False) + self.cluster_size = nn.Parameter(torch.zeros(num_tokens), requires_grad=False) + self.embed_avg = nn.Parameter(weight.clone(), requires_grad=False) + self.update = True + + def forward(self, embed_id): + return F.embedding(embed_id, self.weight) + + def cluster_size_ema_update(self, new_cluster_size): + self.cluster_size.data.mul_(self.decay).add_(new_cluster_size, alpha=1 - self.decay) + + def embed_avg_ema_update(self, new_embed_avg): + self.embed_avg.data.mul_(self.decay).add_(new_embed_avg, alpha=1 - self.decay) + + def weight_update(self, num_tokens): + n = self.cluster_size.sum() + smoothed_cluster_size = ( + (self.cluster_size + self.eps) / (n + num_tokens * self.eps) * n + ) + # normalize embedding average with smoothed cluster size + embed_normalized = self.embed_avg / smoothed_cluster_size.unsqueeze(1) + self.weight.data.copy_(embed_normalized) + + +class EMAVectorQuantizer(nn.Module): + def __init__(self, n_embed, embedding_dim, beta, decay=0.99, eps=1e-5, + remap=None, unknown_index="random"): + super().__init__() + self.codebook_dim = codebook_dim + self.num_tokens = num_tokens + self.beta = beta + self.embedding = EmbeddingEMA(self.num_tokens, self.codebook_dim, decay, eps) + + self.remap = remap + if self.remap is not None: + self.register_buffer("used", torch.tensor(np.load(self.remap))) + self.re_embed = self.used.shape[0] + self.unknown_index = unknown_index # "random" or "extra" or integer + if self.unknown_index == "extra": + self.unknown_index = self.re_embed + self.re_embed = self.re_embed + 1 + print(f"Remapping {self.n_embed} indices to {self.re_embed} indices. " + f"Using {self.unknown_index} for unknown indices.") + else: + self.re_embed = n_embed + + def remap_to_used(self, inds): + ishape = inds.shape + assert len(ishape) > 1 + inds = inds.reshape(ishape[0], -1) + used = self.used.to(inds) + match = (inds[:, :, None] == used[None, None, ...]).long() + new = match.argmax(-1) + unknown = match.sum(2) < 1 + if self.unknown_index == "random": + new[unknown] = torch.randint(0, self.re_embed, size=new[unknown].shape).to(device=new.device) + else: + new[unknown] = self.unknown_index + return new.reshape(ishape) + + def unmap_to_all(self, inds): + ishape = inds.shape + assert len(ishape) > 1 + inds = inds.reshape(ishape[0], -1) + used = self.used.to(inds) + if self.re_embed > self.used.shape[0]: # extra token + inds[inds >= self.used.shape[0]] = 0 # simply set to zero + back = torch.gather(used[None, :][inds.shape[0] * [0], :], 1, inds) + return back.reshape(ishape) + + def forward(self, z): + # reshape z -> (batch, height, width, channel) and flatten + # z, 'b c h w -> b h w c' + z = rearrange(z, 'b c h w -> b h w c') + z_flattened = z.reshape(-1, self.codebook_dim) + + # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z + d = z_flattened.pow(2).sum(dim=1, keepdim=True) + \ + self.embedding.weight.pow(2).sum(dim=1) - 2 * \ + torch.einsum('bd,nd->bn', z_flattened, self.embedding.weight) # 'n d -> d n' + + encoding_indices = torch.argmin(d, dim=1) + + z_q = self.embedding(encoding_indices).view(z.shape) + encodings = F.one_hot(encoding_indices, self.num_tokens).type(z.dtype) + avg_probs = torch.mean(encodings, dim=0) + perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10))) + + if self.training and self.embedding.update: + # EMA cluster size + encodings_sum = encodings.sum(0) + self.embedding.cluster_size_ema_update(encodings_sum) + # EMA embedding average + embed_sum = encodings.transpose(0, 1) @ z_flattened + self.embedding.embed_avg_ema_update(embed_sum) + # normalize embed_avg and update weight + self.embedding.weight_update(self.num_tokens) + + # compute loss for embedding + loss = self.beta * F.mse_loss(z_q.detach(), z) + + # preserve gradients + z_q = z + (z_q - z).detach() + + # reshape back to match original input shape + # z_q, 'b h w c -> b c h w' + z_q = rearrange(z_q, 'b h w c -> b c h w') + return z_q, loss, (perplexity, encodings, encoding_indices) \ No newline at end of file diff --git a/opensora/models/ae/imagebase/vqvae/vqgan.py b/opensora/models/ae/imagebase/vqvae/vqgan.py new file mode 100644 index 0000000000000000000000000000000000000000..9e9125be141193d2a0f988d4fcecddfa7cfd4a39 --- /dev/null +++ b/opensora/models/ae/imagebase/vqvae/vqgan.py @@ -0,0 +1,419 @@ +import torch +import torch.nn.functional as F +import pytorch_lightning as pl +import argparse, os, sys, datetime, glob, importlib + +from .model import Encoder, Decoder +from .quantize import VectorQuantizer2 as VectorQuantizer +from .quantize import GumbelQuantize +from .quantize import EMAVectorQuantizer + + + +def get_obj_from_str(string, reload=False): + module, cls = string.rsplit(".", 1) + if reload: + module_imp = importlib.import_module(module) + importlib.reload(module_imp) + return getattr(importlib.import_module(module, package=None), cls) + + +def instantiate_from_config(config): + if not "target" in config: + raise KeyError("Expected key `target` to instantiate.") + return get_obj_from_str(config["target"])(**config.get("params", dict())) + + +class VQModel(pl.LightningModule): + def __init__(self, + ddconfig, + lossconfig, + n_embed, + embed_dim, + ckpt_path=None, + ignore_keys=[], + image_key="image", + colorize_nlabels=None, + monitor=None, + remap=None, + sane_index_shape=False, # tell vector quantizer to return indices as bhw + ): + super().__init__() + self.image_key = image_key + self.encoder = Encoder(**ddconfig) + self.decoder = Decoder(**ddconfig) + self.loss = instantiate_from_config(lossconfig) + self.quantize = VectorQuantizer(n_embed, embed_dim, beta=0.25, + remap=remap, sane_index_shape=sane_index_shape) + self.quant_conv = torch.nn.Conv2d(ddconfig["z_channels"], embed_dim, 1) + self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1) + if ckpt_path is not None: + self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) + self.image_key = image_key + if colorize_nlabels is not None: + assert type(colorize_nlabels)==int + self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1)) + if monitor is not None: + self.monitor = monitor + + def init_from_ckpt(self, path, ignore_keys=list()): + sd = torch.load(path, map_location="cpu")["state_dict"] + keys = list(sd.keys()) + for k in keys: + for ik in ignore_keys: + if k.startswith(ik): + print("Deleting key {} from state_dict.".format(k)) + del sd[k] + self.load_state_dict(sd, strict=False) + print(f"Restored from {path}") + + def encode(self, x): + h = self.encoder(x) + h = self.quant_conv(h) + quant, emb_loss, info = self.quantize(h) + return quant, emb_loss, info + + def decode(self, quant): + quant = self.post_quant_conv(quant) + dec = self.decoder(quant) + return dec + + def decode_code(self, code_b): + quant_b = self.quantize.embed_code(code_b) + dec = self.decode(quant_b) + return dec + + def forward(self, input): + quant, diff, _ = self.encode(input) + dec = self.decode(quant) + return dec, diff + + def get_input(self, batch, k): + x = batch[k] + if len(x.shape) == 3: + x = x[..., None] + x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format) + return x.float() + + def training_step(self, batch, batch_idx, optimizer_idx): + x = self.get_input(batch, self.image_key) + xrec, qloss = self(x) + + if optimizer_idx == 0: + # autoencode + aeloss, log_dict_ae = self.loss(qloss, x, xrec, optimizer_idx, self.global_step, + last_layer=self.get_last_layer(), split="train") + + self.log("train/aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True) + self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True) + return aeloss + + if optimizer_idx == 1: + # discriminator + discloss, log_dict_disc = self.loss(qloss, x, xrec, optimizer_idx, self.global_step, + last_layer=self.get_last_layer(), split="train") + self.log("train/discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True) + self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True) + return discloss + + def validation_step(self, batch, batch_idx): + x = self.get_input(batch, self.image_key) + xrec, qloss = self(x) + aeloss, log_dict_ae = self.loss(qloss, x, xrec, 0, self.global_step, + last_layer=self.get_last_layer(), split="val") + + discloss, log_dict_disc = self.loss(qloss, x, xrec, 1, self.global_step, + last_layer=self.get_last_layer(), split="val") + rec_loss = log_dict_ae["val/rec_loss"] + self.log("val/rec_loss", rec_loss, + prog_bar=True, logger=True, on_step=True, on_epoch=True, sync_dist=True) + self.log("val/aeloss", aeloss, + prog_bar=True, logger=True, on_step=True, on_epoch=True, sync_dist=True) + self.log_dict(log_dict_ae) + self.log_dict(log_dict_disc) + return self.log_dict + + def configure_optimizers(self): + lr = self.learning_rate + opt_ae = torch.optim.Adam(list(self.encoder.parameters())+ + list(self.decoder.parameters())+ + list(self.quantize.parameters())+ + list(self.quant_conv.parameters())+ + list(self.post_quant_conv.parameters()), + lr=lr, betas=(0.5, 0.9)) + opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(), + lr=lr, betas=(0.5, 0.9)) + return [opt_ae, opt_disc], [] + + def get_last_layer(self): + return self.decoder.conv_out.weight + + def log_images(self, batch, **kwargs): + log = dict() + x = self.get_input(batch, self.image_key) + x = x.to(self.device) + xrec, _ = self(x) + if x.shape[1] > 3: + # colorize with random projection + assert xrec.shape[1] > 3 + x = self.to_rgb(x) + xrec = self.to_rgb(xrec) + log["inputs"] = x + log["reconstructions"] = xrec + return log + + def to_rgb(self, x): + assert self.image_key == "segmentation" + if not hasattr(self, "colorize"): + self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x)) + x = F.conv2d(x, weight=self.colorize) + x = 2.*(x-x.min())/(x.max()-x.min()) - 1. + return x + + +class VQSegmentationModel(VQModel): + def __init__(self, n_labels, *args, **kwargs): + super().__init__(*args, **kwargs) + self.register_buffer("colorize", torch.randn(3, n_labels, 1, 1)) + + def configure_optimizers(self): + lr = self.learning_rate + opt_ae = torch.optim.Adam(list(self.encoder.parameters())+ + list(self.decoder.parameters())+ + list(self.quantize.parameters())+ + list(self.quant_conv.parameters())+ + list(self.post_quant_conv.parameters()), + lr=lr, betas=(0.5, 0.9)) + return opt_ae + + def training_step(self, batch, batch_idx): + x = self.get_input(batch, self.image_key) + xrec, qloss = self(x) + aeloss, log_dict_ae = self.loss(qloss, x, xrec, split="train") + self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True) + return aeloss + + def validation_step(self, batch, batch_idx): + x = self.get_input(batch, self.image_key) + xrec, qloss = self(x) + aeloss, log_dict_ae = self.loss(qloss, x, xrec, split="val") + self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True) + total_loss = log_dict_ae["val/total_loss"] + self.log("val/total_loss", total_loss, + prog_bar=True, logger=True, on_step=True, on_epoch=True, sync_dist=True) + return aeloss + + @torch.no_grad() + def log_images(self, batch, **kwargs): + log = dict() + x = self.get_input(batch, self.image_key) + x = x.to(self.device) + xrec, _ = self(x) + if x.shape[1] > 3: + # colorize with random projection + assert xrec.shape[1] > 3 + # convert logits to indices + xrec = torch.argmax(xrec, dim=1, keepdim=True) + xrec = F.one_hot(xrec, num_classes=x.shape[1]) + xrec = xrec.squeeze(1).permute(0, 3, 1, 2).float() + x = self.to_rgb(x) + xrec = self.to_rgb(xrec) + log["inputs"] = x + log["reconstructions"] = xrec + return log + + +class VQNoDiscModel(VQModel): + def __init__(self, + ddconfig, + lossconfig, + n_embed, + embed_dim, + ckpt_path=None, + ignore_keys=[], + image_key="image", + colorize_nlabels=None + ): + super().__init__(ddconfig=ddconfig, lossconfig=lossconfig, n_embed=n_embed, embed_dim=embed_dim, + ckpt_path=ckpt_path, ignore_keys=ignore_keys, image_key=image_key, + colorize_nlabels=colorize_nlabels) + + def training_step(self, batch, batch_idx): + x = self.get_input(batch, self.image_key) + xrec, qloss = self(x) + # autoencode + aeloss, log_dict_ae = self.loss(qloss, x, xrec, self.global_step, split="train") + output = pl.TrainResult(minimize=aeloss) + output.log("train/aeloss", aeloss, + prog_bar=True, logger=True, on_step=True, on_epoch=True) + output.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True) + return output + + def validation_step(self, batch, batch_idx): + x = self.get_input(batch, self.image_key) + xrec, qloss = self(x) + aeloss, log_dict_ae = self.loss(qloss, x, xrec, self.global_step, split="val") + rec_loss = log_dict_ae["val/rec_loss"] + output = pl.EvalResult(checkpoint_on=rec_loss) + output.log("val/rec_loss", rec_loss, + prog_bar=True, logger=True, on_step=True, on_epoch=True) + output.log("val/aeloss", aeloss, + prog_bar=True, logger=True, on_step=True, on_epoch=True) + output.log_dict(log_dict_ae) + + return output + + def configure_optimizers(self): + optimizer = torch.optim.Adam(list(self.encoder.parameters())+ + list(self.decoder.parameters())+ + list(self.quantize.parameters())+ + list(self.quant_conv.parameters())+ + list(self.post_quant_conv.parameters()), + lr=self.learning_rate, betas=(0.5, 0.9)) + return optimizer + + +class GumbelVQ(VQModel): + def __init__(self, + ddconfig, + lossconfig, + n_embed, + embed_dim, + temperature_scheduler_config, + ckpt_path=None, + ignore_keys=[], + image_key="image", + colorize_nlabels=None, + monitor=None, + kl_weight=1e-8, + remap=None, + ): + + z_channels = ddconfig["z_channels"] + super().__init__(ddconfig, + lossconfig, + n_embed, + embed_dim, + ckpt_path=None, + ignore_keys=ignore_keys, + image_key=image_key, + colorize_nlabels=colorize_nlabels, + monitor=monitor, + ) + + self.loss.n_classes = n_embed + self.vocab_size = n_embed + + self.quantize = GumbelQuantize(z_channels, embed_dim, + n_embed=n_embed, + kl_weight=kl_weight, temp_init=1.0, + remap=remap) + + self.temperature_scheduler = instantiate_from_config(temperature_scheduler_config) # annealing of temp + + if ckpt_path is not None: + self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) + + def temperature_scheduling(self): + self.quantize.temperature = self.temperature_scheduler(self.global_step) + + def encode_to_prequant(self, x): + h = self.encoder(x) + h = self.quant_conv(h) + return h + + def decode_code(self, code_b): + raise NotImplementedError + + def training_step(self, batch, batch_idx, optimizer_idx): + self.temperature_scheduling() + x = self.get_input(batch, self.image_key) + xrec, qloss = self(x) + + if optimizer_idx == 0: + # autoencode + aeloss, log_dict_ae = self.loss(qloss, x, xrec, optimizer_idx, self.global_step, + last_layer=self.get_last_layer(), split="train") + + self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True) + self.log("temperature", self.quantize.temperature, prog_bar=False, logger=True, on_step=True, on_epoch=True) + return aeloss + + if optimizer_idx == 1: + # discriminator + discloss, log_dict_disc = self.loss(qloss, x, xrec, optimizer_idx, self.global_step, + last_layer=self.get_last_layer(), split="train") + self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True) + return discloss + + def validation_step(self, batch, batch_idx): + x = self.get_input(batch, self.image_key) + xrec, qloss = self(x, return_pred_indices=True) + aeloss, log_dict_ae = self.loss(qloss, x, xrec, 0, self.global_step, + last_layer=self.get_last_layer(), split="val") + + discloss, log_dict_disc = self.loss(qloss, x, xrec, 1, self.global_step, + last_layer=self.get_last_layer(), split="val") + rec_loss = log_dict_ae["val/rec_loss"] + self.log("val/rec_loss", rec_loss, + prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True) + self.log("val/aeloss", aeloss, + prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True) + self.log_dict(log_dict_ae) + self.log_dict(log_dict_disc) + return self.log_dict + + def log_images(self, batch, **kwargs): + log = dict() + x = self.get_input(batch, self.image_key) + x = x.to(self.device) + # encode + h = self.encoder(x) + h = self.quant_conv(h) + quant, _, _ = self.quantize(h) + # decode + x_rec = self.decode(quant) + log["inputs"] = x + log["reconstructions"] = x_rec + return log + + +class EMAVQ(VQModel): + def __init__(self, + ddconfig, + lossconfig, + n_embed, + embed_dim, + ckpt_path=None, + ignore_keys=[], + image_key="image", + colorize_nlabels=None, + monitor=None, + remap=None, + sane_index_shape=False, # tell vector quantizer to return indices as bhw + ): + super().__init__(ddconfig, + lossconfig, + n_embed, + embed_dim, + ckpt_path=None, + ignore_keys=ignore_keys, + image_key=image_key, + colorize_nlabels=colorize_nlabels, + monitor=monitor, + ) + self.quantize = EMAVectorQuantizer(n_embed=n_embed, + embedding_dim=embed_dim, + beta=0.25, + remap=remap) + def configure_optimizers(self): + lr = self.learning_rate + #Remove self.quantize from parameter list since it is updated via EMA + opt_ae = torch.optim.Adam(list(self.encoder.parameters())+ + list(self.decoder.parameters())+ + list(self.quant_conv.parameters())+ + list(self.post_quant_conv.parameters()), + lr=lr, betas=(0.5, 0.9)) + opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(), + lr=lr, betas=(0.5, 0.9)) + return [opt_ae, opt_disc], [] \ No newline at end of file diff --git a/opensora/models/ae/imagebase/vqvae/vqvae.py b/opensora/models/ae/imagebase/vqvae/vqvae.py new file mode 100644 index 0000000000000000000000000000000000000000..c43f8c42b5d883da6069e3d92cc50e7a7d7264c3 --- /dev/null +++ b/opensora/models/ae/imagebase/vqvae/vqvae.py @@ -0,0 +1,34 @@ +from torch import nn +import yaml +import torch +from omegaconf import OmegaConf +from .vqgan import VQModel, GumbelVQ + +def load_config(config_path, display=False): + config = OmegaConf.load(config_path) + if display: + print(yaml.dump(OmegaConf.to_container(config))) + return config + + +def load_vqgan(config, ckpt_path=None, is_gumbel=False): + if is_gumbel: + model = GumbelVQ(**config.model.params) + else: + model = VQModel(**config.model.params) + if ckpt_path is not None: + sd = torch.load(ckpt_path, map_location="cpu")["state_dict"] + missing, unexpected = model.load_state_dict(sd, strict=False) + return model.eval() + + +class SDVQVAEWrapper(nn.Module): + def __init__(self, name): + super(SDVQVAEWrapper, self).__init__() + raise NotImplementedError + + def encode(self, x): # b c h w + raise NotImplementedError + + def decode(self, x): + raise NotImplementedError diff --git a/opensora/models/ae/videobase/__init__.py b/opensora/models/ae/videobase/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d6cd834cd5d8072ed808e218ddbca5130a94f6db --- /dev/null +++ b/opensora/models/ae/videobase/__init__.py @@ -0,0 +1,54 @@ +from .vqvae import ( + VQVAEConfiguration, + VQVAEModel, + VQVAETrainer, + VQVAEModelWrapper +) +from .causal_vqvae import ( + CausalVQVAEConfiguration, + CausalVQVAETrainer, + CausalVQVAEModel, CausalVQVAEModelWrapper +) +from .causal_vae import ( + CausalVAETrainer, + CausalVAEModel, CausalVAEModelWrapper +) + + +videobase_ae_stride = { + 'CausalVAEModel_4x8x8': [4, 8, 8], + 'CausalVQVAEModel_4x4x4': [4, 4, 4], + 'CausalVQVAEModel_4x8x8': [4, 8, 8], + 'VQVAEModel_4x4x4': [4, 4, 4], + 'OpenVQVAEModel_4x4x4': [4, 4, 4], + 'VQVAEModel_4x8x8': [4, 8, 8], + 'bair_stride4x2x2': [4, 2, 2], + 'ucf101_stride4x4x4': [4, 4, 4], + 'kinetics_stride4x4x4': [4, 4, 4], + 'kinetics_stride2x4x4': [2, 4, 4], +} + +videobase_ae_channel = { + 'CausalVAEModel_4x8x8': 4, + 'CausalVQVAEModel_4x4x4': 4, + 'CausalVQVAEModel_4x8x8': 4, + 'VQVAEModel_4x4x4': 4, + 'OpenVQVAEModel_4x4x4': 4, + 'VQVAEModel_4x8x8': 4, + 'bair_stride4x2x2': 256, + 'ucf101_stride4x4x4': 256, + 'kinetics_stride4x4x4': 256, + 'kinetics_stride2x4x4': 256, +} + +videobase_ae = { + 'CausalVAEModel_4x8x8': CausalVAEModelWrapper, + 'CausalVQVAEModel_4x4x4': CausalVQVAEModelWrapper, + 'CausalVQVAEModel_4x8x8': CausalVQVAEModelWrapper, + 'VQVAEModel_4x4x4': VQVAEModelWrapper, + 'VQVAEModel_4x8x8': VQVAEModelWrapper, + "bair_stride4x2x2": VQVAEModelWrapper, + "ucf101_stride4x4x4": VQVAEModelWrapper, + "kinetics_stride4x4x4": VQVAEModelWrapper, + "kinetics_stride2x4x4": VQVAEModelWrapper, +} diff --git a/opensora/models/ae/videobase/causal_vae/__init__.py b/opensora/models/ae/videobase/causal_vae/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..120d86b76b23021238ba42487768499c39404d8c --- /dev/null +++ b/opensora/models/ae/videobase/causal_vae/__init__.py @@ -0,0 +1,27 @@ +from .modeling_causalvae import CausalVAEModel +from .trainer_causalvae import CausalVAETrainer + +from einops import rearrange +from torch import nn + +class CausalVAEModelWrapper(nn.Module): + def __init__(self, model_path, subfolder=None, cache_dir=None): + super(CausalVAEModelWrapper, self).__init__() + # if os.path.exists(ckpt): + # self.vae = CausalVAEModel.load_from_checkpoint(ckpt) + self.vae = CausalVAEModel.from_pretrained(model_path, subfolder=subfolder, cache_dir=cache_dir) + def encode(self, x): # b c t h w + # x = self.vae.encode(x).sample() + x = self.vae.encode(x).sample().mul_(0.18215) + return x + def decode(self, x): + # x = self.vae.decode(x) + x = self.vae.decode(x / 0.18215) + x = rearrange(x, 'b c t h w -> b t c h w').contiguous() + return x + + def dtype(self): + return self.vae.dtype + # + # def device(self): + # return self.vae.device \ No newline at end of file diff --git a/opensora/models/ae/videobase/causal_vae/modeling_causalvae.py b/opensora/models/ae/videobase/causal_vae/modeling_causalvae.py new file mode 100644 index 0000000000000000000000000000000000000000..cee91ac047450ed41fbcfe70a909ab94d314c7a6 --- /dev/null +++ b/opensora/models/ae/videobase/causal_vae/modeling_causalvae.py @@ -0,0 +1,651 @@ +from ..modeling_videobase import VideoBaseAE_PL +from ..modules import Normalize +from ..modules.ops import nonlinearity +from typing import List, Tuple +import torch.nn as nn +from ..utils.module_utils import resolve_str_to_obj, Module +from ..utils.distrib_utils import DiagonalGaussianDistribution +from ..utils.scheduler_utils import cosine_scheduler +import torch +from diffusers.configuration_utils import register_to_config + + +class Encoder(nn.Module): + def __init__( + self, + z_channels: int, + hidden_size: int, + hidden_size_mult: Tuple[int] = (1, 2, 4, 4), + attn_resolutions: Tuple[int] = (16,), + conv_in: Module = "Conv2d", + conv_out: Module = "CasualConv3d", + attention: Module = "AttnBlock", + resnet_blocks: Tuple[Module] = ( + "ResnetBlock2D", + "ResnetBlock2D", + "ResnetBlock2D", + "ResnetBlock3D", + ), + spatial_downsample: Tuple[Module] = ( + "Downsample", + "Downsample", + "Downsample", + "", + ), + temporal_downsample: Tuple[Module] = ("", "", "TimeDownsampleRes2x", ""), + mid_resnet: Module = "ResnetBlock3D", + dropout: float = 0.0, + resolution: int = 256, + num_res_blocks: int = 2, + double_z: bool = True, + ) -> None: + super().__init__() + assert len(resnet_blocks) == len(hidden_size_mult), print( + hidden_size_mult, resnet_blocks + ) + # ---- Config ---- + self.num_resolutions = len(hidden_size_mult) + self.resolution = resolution + self.num_res_blocks = num_res_blocks + + # ---- In ---- + self.conv_in = resolve_str_to_obj(conv_in)( + 3, hidden_size, kernel_size=3, stride=1, padding=1 + ) + + # ---- Downsample ---- + curr_res = resolution + in_ch_mult = (1,) + tuple(hidden_size_mult) + self.in_ch_mult = in_ch_mult + self.down = nn.ModuleList() + for i_level in range(self.num_resolutions): + block = nn.ModuleList() + attn = nn.ModuleList() + block_in = hidden_size * in_ch_mult[i_level] + block_out = hidden_size * hidden_size_mult[i_level] + for i_block in range(self.num_res_blocks): + block.append( + resolve_str_to_obj(resnet_blocks[i_level])( + in_channels=block_in, + out_channels=block_out, + dropout=dropout, + ) + ) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(resolve_str_to_obj(attention)(block_in)) + down = nn.Module() + down.block = block + down.attn = attn + if spatial_downsample[i_level]: + down.downsample = resolve_str_to_obj(spatial_downsample[i_level])( + block_in, block_in + ) + curr_res = curr_res // 2 + if temporal_downsample[i_level]: + down.time_downsample = resolve_str_to_obj(temporal_downsample[i_level])( + block_in, block_in + ) + self.down.append(down) + + # ---- Mid ---- + self.mid = nn.Module() + self.mid.block_1 = resolve_str_to_obj(mid_resnet)( + in_channels=block_in, + out_channels=block_in, + dropout=dropout, + ) + self.mid.attn_1 = resolve_str_to_obj(attention)(block_in) + self.mid.block_2 = resolve_str_to_obj(mid_resnet)( + in_channels=block_in, + out_channels=block_in, + dropout=dropout, + ) + # ---- Out ---- + self.norm_out = Normalize(block_in) + self.conv_out = resolve_str_to_obj(conv_out)( + block_in, + 2 * z_channels if double_z else z_channels, + kernel_size=3, + stride=1, + padding=1, + ) + + def forward(self, x): + hs = [self.conv_in(x)] + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h = self.down[i_level].block[i_block](hs[-1]) + if len(self.down[i_level].attn) > 0: + h = self.down[i_level].attn[i_block](h) + hs.append(h) + if hasattr(self.down[i_level], "downsample"): + hs.append(self.down[i_level].downsample(hs[-1])) + if hasattr(self.down[i_level], "time_downsample"): + hs_down = self.down[i_level].time_downsample(hs[-1]) + hs.append(hs_down) + + h = self.mid.block_1(h) + h = self.mid.attn_1(h) + h = self.mid.block_2(h) + + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + + +class Decoder(nn.Module): + def __init__( + self, + z_channels: int, + hidden_size: int, + hidden_size_mult: Tuple[int] = (1, 2, 4, 4), + attn_resolutions: Tuple[int] = (16,), + conv_in: Module = "Conv2d", + conv_out: Module = "CasualConv3d", + attention: Module = "AttnBlock", + resnet_blocks: Tuple[Module] = ( + "ResnetBlock3D", + "ResnetBlock3D", + "ResnetBlock3D", + "ResnetBlock3D", + ), + spatial_upsample: Tuple[Module] = ( + "", + "SpatialUpsample2x", + "SpatialUpsample2x", + "SpatialUpsample2x", + ), + temporal_upsample: Tuple[Module] = ("", "", "", "TimeUpsampleRes2x"), + mid_resnet: Module = "ResnetBlock3D", + dropout: float = 0.0, + resolution: int = 256, + num_res_blocks: int = 2, + ): + super().__init__() + # ---- Config ---- + self.num_resolutions = len(hidden_size_mult) + self.resolution = resolution + self.num_res_blocks = num_res_blocks + + # ---- In ---- + block_in = hidden_size * hidden_size_mult[self.num_resolutions - 1] + curr_res = resolution // 2 ** (self.num_resolutions - 1) + self.conv_in = resolve_str_to_obj(conv_in)( + z_channels, block_in, kernel_size=3, padding=1 + ) + + # ---- Mid ---- + self.mid = nn.Module() + self.mid.block_1 = resolve_str_to_obj(mid_resnet)( + in_channels=block_in, + out_channels=block_in, + dropout=dropout, + ) + self.mid.attn_1 = resolve_str_to_obj(attention)(block_in) + self.mid.block_2 = resolve_str_to_obj(mid_resnet)( + in_channels=block_in, + out_channels=block_in, + dropout=dropout, + ) + + # ---- Upsample ---- + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + block_out = hidden_size * hidden_size_mult[i_level] + for i_block in range(self.num_res_blocks + 1): + block.append( + resolve_str_to_obj(resnet_blocks[i_level])( + in_channels=block_in, + out_channels=block_out, + dropout=dropout, + ) + ) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(resolve_str_to_obj(attention)(block_in)) + up = nn.Module() + up.block = block + up.attn = attn + if spatial_upsample[i_level]: + up.upsample = resolve_str_to_obj(spatial_upsample[i_level])( + block_in, block_in + ) + curr_res = curr_res * 2 + if temporal_upsample[i_level]: + up.time_upsample = resolve_str_to_obj(temporal_upsample[i_level])( + block_in, block_in + ) + self.up.insert(0, up) + + # ---- Out ---- + self.norm_out = Normalize(block_in) + self.conv_out = resolve_str_to_obj(conv_out)( + block_in, 3, kernel_size=3, padding=1 + ) + + def forward(self, z): + h = self.conv_in(z) + h = self.mid.block_1(h) + h = self.mid.attn_1(h) + h = self.mid.block_2(h) + + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks + 1): + h = self.up[i_level].block[i_block](h) + if len(self.up[i_level].attn) > 0: + h = self.up[i_level].attn[i_block](h) + if hasattr(self.up[i_level], "upsample"): + h = self.up[i_level].upsample(h) + if hasattr(self.up[i_level], "time_upsample"): + h = self.up[i_level].time_upsample(h) + + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + + +class CausalVAEModel(VideoBaseAE_PL): + + @register_to_config + def __init__( + self, + lr: float = 1e-5, + hidden_size: int = 128, + z_channels: int = 4, + hidden_size_mult: Tuple[int] = (1, 2, 4, 4), + attn_resolutions: Tuple[int] = [], + dropout: float = 0.0, + resolution: int = 256, + double_z: bool = True, + embed_dim: int = 4, + num_res_blocks: int = 2, + loss_type: str = "opensora.models.ae.videobase.losses.LPIPSWithDiscriminator", + loss_params: dict = { + "kl_weight": 0.000001, + "logvar_init": 0.0, + "disc_start": 2001, + "disc_weight": 0.5, + }, + q_conv: str = "CausalConv3d", + encoder_conv_in: Module = "CausalConv3d", + encoder_conv_out: Module = "CausalConv3d", + encoder_attention: Module = "AttnBlock3D", + encoder_resnet_blocks: Tuple[Module] = ( + "ResnetBlock3D", + "ResnetBlock3D", + "ResnetBlock3D", + "ResnetBlock3D", + ), + encoder_spatial_downsample: Tuple[Module] = ( + "SpatialDownsample2x", + "SpatialDownsample2x", + "SpatialDownsample2x", + "", + ), + encoder_temporal_downsample: Tuple[Module] = ( + "", + "TimeDownsample2x", + "TimeDownsample2x", + "", + ), + encoder_mid_resnet: Module = "ResnetBlock3D", + decoder_conv_in: Module = "CausalConv3d", + decoder_conv_out: Module = "CausalConv3d", + decoder_attention: Module = "AttnBlock3D", + decoder_resnet_blocks: Tuple[Module] = ( + "ResnetBlock3D", + "ResnetBlock3D", + "ResnetBlock3D", + "ResnetBlock3D", + ), + decoder_spatial_upsample: Tuple[Module] = ( + "", + "SpatialUpsample2x", + "SpatialUpsample2x", + "SpatialUpsample2x", + ), + decoder_temporal_upsample: Tuple[Module] = ("", "", "TimeUpsample2x", "TimeUpsample2x"), + decoder_mid_resnet: Module = "ResnetBlock3D", + ) -> None: + super().__init__() + self.tile_sample_min_size = 256 + self.tile_sample_min_size_t = 65 + self.tile_latent_min_size = int(self.tile_sample_min_size / (2 ** (len(hidden_size_mult) - 1))) + # self.tile_latent_min_size_t = int((self.tile_sample_min_size_t-1) / (2 ** self.time_compress)) + 1 + self.tile_overlap_factor = 0.25 + self.use_tiling = False + + self.learning_rate = lr + self.lr_g_factor = 1.0 + + self.loss = resolve_str_to_obj(loss_type, append=False)( + **loss_params + ) + + self.encoder = Encoder( + z_channels=z_channels, + hidden_size=hidden_size, + hidden_size_mult=hidden_size_mult, + attn_resolutions=attn_resolutions, + conv_in=encoder_conv_in, + conv_out=encoder_conv_out, + attention=encoder_attention, + resnet_blocks=encoder_resnet_blocks, + spatial_downsample=encoder_spatial_downsample, + temporal_downsample=encoder_temporal_downsample, + mid_resnet=encoder_mid_resnet, + dropout=dropout, + resolution=resolution, + num_res_blocks=num_res_blocks, + double_z=double_z, + ) + + self.decoder = Decoder( + z_channels=z_channels, + hidden_size=hidden_size, + hidden_size_mult=hidden_size_mult, + attn_resolutions=attn_resolutions, + conv_in=decoder_conv_in, + conv_out=decoder_conv_out, + attention=decoder_attention, + resnet_blocks=decoder_resnet_blocks, + spatial_upsample=decoder_spatial_upsample, + temporal_upsample=decoder_temporal_upsample, + mid_resnet=decoder_mid_resnet, + dropout=dropout, + resolution=resolution, + num_res_blocks=num_res_blocks, + ) + + quant_conv_cls = resolve_str_to_obj(q_conv) + self.quant_conv = quant_conv_cls(2 * z_channels, 2 * embed_dim, 1) + self.post_quant_conv = quant_conv_cls(embed_dim, z_channels, 1) + if hasattr(self.loss, "discriminator"): + self.automatic_optimization = False + + def encode(self, x): + if self.use_tiling and ( + x.shape[-1] > self.tile_sample_min_size + or x.shape[-2] > self.tile_sample_min_size + ): + return self.tiled_encode2d(x) + h = self.encoder(x) + moments = self.quant_conv(h) + posterior = DiagonalGaussianDistribution(moments) + return posterior + + def decode(self, z): + if self.use_tiling and ( + z.shape[-1] > self.tile_latent_min_size + or z.shape[-2] > self.tile_latent_min_size + ): + return self.tiled_decode2d(z) + z = self.post_quant_conv(z) + dec = self.decoder(z) + return dec + + def forward(self, input, sample_posterior=True): + posterior = self.encode(input) + if sample_posterior: + z = posterior.sample() + else: + z = posterior.mode() + dec = self.decode(z) + return dec, posterior + + def get_input(self, batch, k): + x = batch[k] + if len(x.shape) == 3: + x = x[..., None] + x = x.to(memory_format=torch.contiguous_format).float() + return x + + def training_step(self, batch, batch_idx): + if hasattr(self.loss, "discriminator"): + return self._training_step_gan(batch, batch_idx=batch_idx) + else: + return self._training_step(batch, batch_idx=batch_idx) + + def _training_step(self, batch, batch_idx): + inputs = self.get_input(batch, "video") + reconstructions, posterior = self(inputs) + aeloss, log_dict_ae = self.loss( + inputs, + reconstructions, + posterior, + split="train", + ) + self.log( + "aeloss", + aeloss, + prog_bar=True, + logger=True, + on_step=True, + on_epoch=True, + ) + self.log_dict( + log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False + ) + return aeloss + + def _training_step_gan(self, batch, batch_idx): + inputs = self.get_input(batch, "video") + reconstructions, posterior = self(inputs) + opt1, opt2 = self.optimizers() + + # ---- AE Loss ---- + opt1.zero_grad() + aeloss, log_dict_ae = self.loss( + inputs, + reconstructions, + posterior, + 0, + self.global_step, + last_layer=self.get_last_layer(), + split="train", + ) + self.log( + "aeloss", + aeloss, + prog_bar=True, + logger=True, + on_step=True, + on_epoch=True, + ) + self.manual_backward(aeloss) + self.clip_gradients(opt1, gradient_clip_val=1, gradient_clip_algorithm="norm") + opt1.step() + # ---- GAN Loss ---- + opt2.zero_grad() + discloss, log_dict_disc = self.loss( + inputs, + reconstructions, + posterior, + 1, + self.global_step, + last_layer=self.get_last_layer(), + split="train", + ) + self.log( + "discloss", + discloss, + prog_bar=True, + logger=True, + on_step=True, + on_epoch=True, + ) + self.manual_backward(discloss) + self.clip_gradients(opt2, gradient_clip_val=1, gradient_clip_algorithm="norm") + opt2.step() + self.log_dict( + {**log_dict_ae, **log_dict_disc}, + prog_bar=False, + logger=True, + on_step=True, + on_epoch=False, + ) + + def configure_optimizers(self): + from itertools import chain + + lr = self.learning_rate + modules_to_train = [ + self.encoder.named_parameters(), + self.decoder.named_parameters(), + self.post_quant_conv.named_parameters(), + self.quant_conv.named_parameters(), + ] + params_with_time = [] + params_without_time = [] + for name, param in chain(*modules_to_train): + if "time" in name: + params_with_time.append(param) + else: + params_without_time.append(param) + optimizers = [] + opt_ae = torch.optim.Adam( + [ + {"params": params_with_time, "lr": 0.0001}, + {"params": params_without_time, "lr": 0.00001}, + ], + lr=lr, + betas=(0.5, 0.9), + ) + optimizers.append(opt_ae) + + if hasattr(self.loss, "discriminator"): + opt_disc = torch.optim.Adam( + self.loss.discriminator.parameters(), lr=lr, betas=(0.5, 0.9) + ) + optimizers.append(opt_disc) + + return optimizers, [] + + def get_last_layer(self): + if hasattr(self.decoder.conv_out, "conv"): + return self.decoder.conv_out.conv.weight + else: + return self.decoder.conv_out.weight + + def blend_v( + self, a: torch.Tensor, b: torch.Tensor, blend_extent: int + ) -> torch.Tensor: + blend_extent = min(a.shape[3], b.shape[3], blend_extent) + for y in range(blend_extent): + b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * ( + 1 - y / blend_extent + ) + b[:, :, :, y, :] * (y / blend_extent) + return b + + def blend_h( + self, a: torch.Tensor, b: torch.Tensor, blend_extent: int + ) -> torch.Tensor: + blend_extent = min(a.shape[4], b.shape[4], blend_extent) + for x in range(blend_extent): + b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * ( + 1 - x / blend_extent + ) + b[:, :, :, :, x] * (x / blend_extent) + return b + + def tiled_encode2d(self, x): + overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor)) + blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor) + row_limit = self.tile_latent_min_size - blend_extent + + # Split the image into 512x512 tiles and encode them separately. + rows = [] + for i in range(0, x.shape[3], overlap_size): + row = [] + for j in range(0, x.shape[4], overlap_size): + tile = x[ + :, + :, + :, + i : i + self.tile_sample_min_size, + j : j + self.tile_sample_min_size, + ] + tile = self.encoder(tile) + tile = self.quant_conv(tile) + row.append(tile) + rows.append(row) + result_rows = [] + for i, row in enumerate(rows): + result_row = [] + for j, tile in enumerate(row): + # blend the above tile and the left tile + # to the current tile and add the current tile to the result row + if i > 0: + tile = self.blend_v(rows[i - 1][j], tile, blend_extent) + if j > 0: + tile = self.blend_h(row[j - 1], tile, blend_extent) + result_row.append(tile[:, :, :, :row_limit, :row_limit]) + result_rows.append(torch.cat(result_row, dim=4)) + + moments = torch.cat(result_rows, dim=3) + posterior = DiagonalGaussianDistribution(moments) + + return posterior + + def tiled_decode2d(self, z): + + overlap_size = int(self.tile_latent_min_size * (1 - self.tile_overlap_factor)) + blend_extent = int(self.tile_sample_min_size * self.tile_overlap_factor) + row_limit = self.tile_sample_min_size - blend_extent + + # Split z into overlapping 64x64 tiles and decode them separately. + # The tiles have an overlap to avoid seams between tiles. + rows = [] + for i in range(0, z.shape[3], overlap_size): + row = [] + for j in range(0, z.shape[4], overlap_size): + tile = z[ + :, + :, + :, + i : i + self.tile_latent_min_size, + j : j + self.tile_latent_min_size, + ] + tile = self.post_quant_conv(tile) + decoded = self.decoder(tile) + row.append(decoded) + rows.append(row) + result_rows = [] + for i, row in enumerate(rows): + result_row = [] + for j, tile in enumerate(row): + # blend the above tile and the left tile + # to the current tile and add the current tile to the result row + if i > 0: + tile = self.blend_v(rows[i - 1][j], tile, blend_extent) + if j > 0: + tile = self.blend_h(row[j - 1], tile, blend_extent) + result_row.append(tile[:, :, :, :row_limit, :row_limit]) + result_rows.append(torch.cat(result_row, dim=4)) + + dec = torch.cat(result_rows, dim=3) + return dec + + def enable_tiling(self, use_tiling: bool = True): + self.use_tiling = use_tiling + + def disable_tiling(self): + self.enable_tiling(False) + + def init_from_ckpt(self, path, ignore_keys=list(), remove_loss=True): + sd = torch.load(path, map_location="cpu") + if "state_dict" in sd: + sd = sd["state_dict"] + keys = list(sd.keys()) + for k in keys: + for ik in ignore_keys: + if k.startswith(ik): + print("Deleting key {} from state_dict.".format(k)) + del sd[k] + if remove_loss and "loss" in k: + del sd[k] + self.load_state_dict(sd, strict=False) \ No newline at end of file diff --git a/opensora/models/ae/videobase/causal_vae/trainer_causalvae.py b/opensora/models/ae/videobase/causal_vae/trainer_causalvae.py new file mode 100644 index 0000000000000000000000000000000000000000..3530085be16ad86c1829d33ceb141a28ff9e26bf --- /dev/null +++ b/opensora/models/ae/videobase/causal_vae/trainer_causalvae.py @@ -0,0 +1,20 @@ +from ..trainer_videobase import VideoBaseTrainer +from typing import Optional +import os +import torch +from transformers.utils import WEIGHTS_NAME +import json + +class CausalVAETrainer(VideoBaseTrainer): + + def compute_loss(self, model, inputs, return_outputs=False): + model = model.module + x = inputs.get("video") + reconstructions, posterior = model(x) + aeloss, _ = model.loss( + x, + reconstructions, + posterior, + split="train", + ) + return aeloss \ No newline at end of file diff --git a/opensora/models/ae/videobase/causal_vqvae/__init__.py b/opensora/models/ae/videobase/causal_vqvae/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..116d8956482484564e3e7b9f1b8ccd423ca93819 --- /dev/null +++ b/opensora/models/ae/videobase/causal_vqvae/__init__.py @@ -0,0 +1,20 @@ +from .configuration_causalvqvae import CausalVQVAEConfiguration +from .modeling_causalvqvae import CausalVQVAEModel +from .trainer_causalvqvae import CausalVQVAETrainer + + +from einops import rearrange +from torch import nn + +class CausalVQVAEModelWrapper(nn.Module): + def __init__(self, ckpt): + super(CausalVQVAEModelWrapper, self).__init__() + self.vqvae = CausalVQVAEModel.load_from_checkpoint(ckpt) + def encode(self, x): # b c t h w + x = self.vqvae.pre_vq_conv(self.vqvae.encoder(x)) + return x + def decode(self, x): + vq_output = self.vqvae.codebook(x) + x = self.vqvae.decoder(self.vqvae.post_vq_conv(vq_output['embeddings'])) + x = rearrange(x, 'b c t h w -> b t c h w').contiguous() + return x \ No newline at end of file diff --git a/opensora/models/ae/videobase/causal_vqvae/configuration_causalvqvae.py b/opensora/models/ae/videobase/causal_vqvae/configuration_causalvqvae.py new file mode 100644 index 0000000000000000000000000000000000000000..3f18be533fa298aa30c201f16e949898f73e64b5 --- /dev/null +++ b/opensora/models/ae/videobase/causal_vqvae/configuration_causalvqvae.py @@ -0,0 +1,30 @@ +from ..configuration_videobase import VideoBaseConfiguration +from typing import Union, Tuple + +class CausalVQVAEConfiguration(VideoBaseConfiguration): + def __init__( + self, + embedding_dim: int = 256, + n_codes: int = 2048, + n_hiddens: int = 240, + n_res_layers: int = 4, + resolution: int = 128, + sequence_length: int = 16, + time_downsample: int = 4, + spatial_downsample: int = 8, + no_pos_embd: bool = True, + **kwargs, + ): + super().__init__(**kwargs) + + self.embedding_dim = embedding_dim + self.n_codes = n_codes + self.n_hiddens = n_hiddens + self.n_res_layers = n_res_layers + self.resolution = resolution + self.sequence_length = sequence_length + self.time_downsample = time_downsample + self.spatial_downsample = spatial_downsample + self.no_pos_embd = no_pos_embd + + self.hidden_size = n_hiddens diff --git a/opensora/models/ae/videobase/causal_vqvae/modeling_causalvqvae.py b/opensora/models/ae/videobase/causal_vqvae/modeling_causalvqvae.py new file mode 100644 index 0000000000000000000000000000000000000000..65f3f98a97f80b83782fc026070ea79fcf68004a --- /dev/null +++ b/opensora/models/ae/videobase/causal_vqvae/modeling_causalvqvae.py @@ -0,0 +1,848 @@ +from ..modeling_videobase import VideoBaseAE +import torch +from torch import nn, Tensor +import numpy as np +import torch.distributed as dist +import torch.nn.functional as F +import math +import os +import json +from typing import Tuple, Dict, Union +from .configuration_causalvqvae import CausalVQVAEConfiguration +from einops import rearrange, pack, unpack + +# Copied from https://github.com/wilson1yan/VideoGPT +def view_range(x, i, j, shape): + shape = tuple(shape) + + n_dims = len(x.shape) + if i < 0: + i = n_dims + i + + if j is None: + j = n_dims + elif j < 0: + j = n_dims + j + + assert 0 <= i < j <= n_dims + + x_shape = x.shape + target_shape = x_shape[:i] + shape + x_shape[j:] + return x.view(target_shape) + + +# Copied from https://github.com/wilson1yan/VideoGPT +def shift_dim(x, src_dim=-1, dest_dim=-1, make_contiguous=True): + n_dims = len(x.shape) + if src_dim < 0: + src_dim = n_dims + src_dim + if dest_dim < 0: + dest_dim = n_dims + dest_dim + assert 0 <= src_dim < n_dims and 0 <= dest_dim < n_dims + dims = list(range(n_dims)) + del dims[src_dim] + permutation = [] + ctr = 0 + for i in range(n_dims): + if i == dest_dim: + permutation.append(src_dim) + else: + permutation.append(dims[ctr]) + ctr += 1 + x = x.permute(permutation) + if make_contiguous: + x = x.contiguous() + return x + + +# Copied from https://github.com/wilson1yan/VideoGPT +def scaled_dot_product_attention(q, k, v, mask=None, attn_dropout=0.0, training=True): + # Performs scaled dot-product attention over the second to last dimension dn + + # (b, n_head, d1, ..., dn, d) + attn = torch.matmul(q, k.transpose(-1, -2)) + attn = attn / np.sqrt(q.shape[-1]) + if mask is not None: + attn = attn.masked_fill(mask == 0, float("-inf")) + attn_float = F.softmax(attn, dim=-1) + attn = attn_float.type_as(attn) # b x n_head x d1 x ... x dn x d + attn = F.dropout(attn, p=attn_dropout, training=training) + + a = torch.matmul(attn, v) # b x n_head x d1 x ... x dn x d + + return a + +def is_odd(n): + return not n % 2 == 0 + +def maybe_del_attr_(o, attr): + if hasattr(o, attr): + delattr(o, attr) + +def cast_tuple(t, length = 1): + return t if isinstance(t, tuple) else ((t,) * length) + +class SpatialDownsample2x(torch.nn.Module): + def __init__( + self, + chan_in, + chan_out, + kernel_size: Union[int, Tuple[int]] = (4,4), + stride: Union[int, Tuple[int]] = (2,2) + ): + super().__init__() + kernel_size = cast_tuple(kernel_size, 2) + stride = cast_tuple(stride, 2) + + self.chan_in = chan_in + self.chan_out = chan_out + self.kernel_size = kernel_size + + total_pad = tuple([k - s for k, s in zip(kernel_size, stride)]) + pad_input = [] + for p in total_pad[::-1]: + pad_input.append((p // 2 + p % 2, p // 2)) + pad_input = sum(pad_input, tuple()) + self.pad_input = pad_input + + self.conv = torch.nn.Conv2d(self.chan_in, self.chan_out, self.kernel_size, stride=stride) + + def forward(self, x): + x = F.pad(x, self.pad_input) + x = rearrange(x, "b c f h w -> b f c h w") + x, ps = pack([x], "* c h w") + x = self.conv(x) + x = unpack(x, ps, "* c h w")[0] + x = rearrange(x, "b f c h w -> b c f h w") + return x + +class SpatialUpsample2x(torch.nn.Module): + def __init__( + self, + chan_in, + chan_out, + kernel_size: Union[int, Tuple[int]] = (3,3), + stride: Union[int, Tuple[int]] = (1,1) + ): + super().__init__() + self.chan_in = chan_in + self.chan_out = chan_out + self.kernel_size = kernel_size + self.conv = torch.nn.Conv2d(self.chan_in, self.chan_out, self.kernel_size, stride=stride, padding=tuple([(k - 1) // 2 for k in kernel_size])) + + def forward(self, x): + x = rearrange(x, "b c f h w -> b f c h w") + x, ps = pack([x], "* c h w") + x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") + x = self.conv(x) + x = unpack(x, ps, "* c h w")[0] + x = rearrange(x, "b f c h w -> b c f h w") + return x + +class TimeDownsample2x(nn.Module): + def __init__( + self, + chan_in, + chan_out, + kernel_size: int = 4, + ): + super().__init__() + self.chan_in = chan_in + self.chan_out = chan_out + self.kernel_size = kernel_size + self.conv = CausalConv3d(chan_in, chan_out, kernel_size, stride=2) + + def forward(self, x): + return self.conv(x) + +class TimeUpsample2x(nn.Module): + def __init__( + self, + chan_in, + chan_out, + kernel_size: int = 3, + ): + super().__init__() + self.chan_in = chan_in + self.chan_out = chan_out + self.kernel_size = kernel_size + self.conv = CausalConv3d(chan_in, chan_out, kernel_size, stride=1) + + def forward(self, x): + x = rearrange(x, "b c f h w -> b c h w f") + x, ps = pack([x], "b * f") + if x.size(-1) > 1: + x = torch.concat((x[:,:,:1], F.interpolate(x[:,:,1:], scale_factor=2.0, mode="linear")), dim=-1) + else: + x = x + x = unpack(x, ps, "b * f")[0] + x = rearrange(x, "b c h w f -> b c f h w") + x = self.conv(x) + return x + +class CausalConv3d(nn.Module): + def __init__( + self, + chan_in, + chan_out, + kernel_size: Union[int, Tuple[int, int, int]], + **kwargs + ): + super().__init__() + kernel_size = cast_tuple(kernel_size, 3) + self.time_kernel_size = kernel_size[0] + stride = kwargs.pop('stride', 1) + stride = (stride, 1, 1) + total_pad = tuple([k - s for k, s in zip(kernel_size[1:], stride[1:])]) + pad_input = [] + for p in total_pad[::-1]: + pad_input.append((p // 2 + p % 2, p // 2)) + pad_input = sum(pad_input, tuple()) + pad_input += (0, 0) + self.padding = pad_input + self.conv = nn.Conv3d(chan_in, chan_out, kernel_size, stride = stride, **kwargs) + + def forward(self, x): + x = F.pad(x, self.padding) + first_frame_pad = x[:, :, :1, : ,:].repeat((1,1,self.time_kernel_size - 1,1,1)) + x = torch.concatenate((first_frame_pad, x), dim=2) + return self.conv(x) + +# Modified from https://github.com/wilson1yan/VideoGPT +class AxialBlock(nn.Module): + def __init__(self, n_hiddens, n_head): + super().__init__() + kwargs = dict( + shape=(0,) * 3, + dim_q=n_hiddens, + dim_kv=n_hiddens, + n_head=n_head, + n_layer=1, + causal=False, + attn_type="axial", + ) + self.attn_w = MultiHeadAttention(attn_kwargs=dict(axial_dim=-2), **kwargs) + self.attn_h = MultiHeadAttention(attn_kwargs=dict(axial_dim=-3), **kwargs) + kwargs['causal'] = True + self.attn_t = MultiHeadAttention(attn_kwargs=dict(axial_dim=-4), **kwargs) + + def forward(self, x): + x = shift_dim(x, 1, -1) + x = self.attn_w(x, x, x) + self.attn_h(x, x, x) + self.attn_t(x, x, x) + x = shift_dim(x, -1, 1) + return x + +# Copied from https://github.com/wilson1yan/VideoGPT +class AttentionResidualBlock(nn.Module): + def __init__(self, n_hiddens, n_heads: int = 2): + super().__init__() + self.block = nn.Sequential( + nn.BatchNorm3d(n_hiddens), + nn.ReLU(), + CausalConv3d(n_hiddens, n_hiddens // 2, 3, bias=False), + nn.BatchNorm3d(n_hiddens // 2), + nn.ReLU(), + CausalConv3d(n_hiddens // 2, n_hiddens, 1, bias=False), + nn.BatchNorm3d(n_hiddens), + nn.ReLU(), + AxialBlock(n_hiddens, n_heads), + ) + + def forward(self, x): + return x + self.block(x) + +# Copied from https://github.com/wilson1yan/VideoGPT +class Codebook(nn.Module): + def __init__(self, n_codes, embedding_dim): + super().__init__() + self.register_buffer("embeddings", torch.randn(n_codes, embedding_dim)) + self.register_buffer("N", torch.zeros(n_codes)) + self.register_buffer("z_avg", self.embeddings.data.clone()) + + self.n_codes = n_codes + self.embedding_dim = embedding_dim + self._need_init = True + + def _tile(self, x): + d, ew = x.shape + if d < self.n_codes: + n_repeats = (self.n_codes + d - 1) // d + std = 0.01 / np.sqrt(ew) + x = x.repeat(n_repeats, 1) + x = x + torch.randn_like(x) * std + return x + + def _init_embeddings(self, z): + # z: [b, c, t, h, w] + self._need_init = False + flat_inputs = shift_dim(z, 1, -1).flatten(end_dim=-2) + y = self._tile(flat_inputs) + + d = y.shape[0] + _k_rand = y[torch.randperm(y.shape[0])][: self.n_codes] + if dist.is_initialized(): + dist.broadcast(_k_rand, 0) + self.embeddings.data.copy_(_k_rand) + self.z_avg.data.copy_(_k_rand) + self.N.data.copy_(torch.ones(self.n_codes)) + + def forward(self, z): + # z: [b, c, t, h, w] + if self._need_init and self.training: + self._init_embeddings(z) + flat_inputs = shift_dim(z, 1, -1).flatten(end_dim=-2) + distances = ( + (flat_inputs**2).sum(dim=1, keepdim=True) + - 2 * flat_inputs @ self.embeddings.t() + + (self.embeddings.t() ** 2).sum(dim=0, keepdim=True) + ) + + encoding_indices = torch.argmin(distances, dim=1) + encode_onehot = F.one_hot(encoding_indices, self.n_codes).type_as(flat_inputs) + encoding_indices = encoding_indices.view(z.shape[0], *z.shape[2:]) + + embeddings = F.embedding(encoding_indices, self.embeddings) + embeddings = shift_dim(embeddings, -1, 1) + + commitment_loss = 0.25 * F.mse_loss(z, embeddings.detach()) + + # EMA codebook update + if self.training: + n_total = encode_onehot.sum(dim=0) + encode_sum = flat_inputs.t() @ encode_onehot + if dist.is_initialized(): + dist.all_reduce(n_total) + dist.all_reduce(encode_sum) + + self.N.data.mul_(0.99).add_(n_total, alpha=0.01) + self.z_avg.data.mul_(0.99).add_(encode_sum.t(), alpha=0.01) + + n = self.N.sum() + weights = (self.N + 1e-7) / (n + self.n_codes * 1e-7) * n + encode_normalized = self.z_avg / weights.unsqueeze(1) + self.embeddings.data.copy_(encode_normalized) + + y = self._tile(flat_inputs) + _k_rand = y[torch.randperm(y.shape[0])][: self.n_codes] + if dist.is_initialized(): + dist.broadcast(_k_rand, 0) + + usage = (self.N.view(self.n_codes, 1) >= 1).float() + self.embeddings.data.mul_(usage).add_(_k_rand * (1 - usage)) + + embeddings_st = (embeddings - z).detach() + z + + avg_probs = torch.mean(encode_onehot, dim=0) + perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10))) + + return dict( + embeddings=embeddings_st, + encodings=encoding_indices, + commitment_loss=commitment_loss, + perplexity=perplexity, + ) + + def dictionary_lookup(self, encodings): + embeddings = F.embedding(encodings, self.embeddings) + return embeddings + +# Modified from https://github.com/wilson1yan/VideoGPT +class Encoder(nn.Module): + def __init__(self, n_hiddens, n_res_layers, time_downsample, spatial_downsample): + super().__init__() + spatial_downsample = int(math.log2(spatial_downsample)) + self.spatial_conv = nn.ModuleList() + for i in range(spatial_downsample): + in_channels = 3 if i == 0 else n_hiddens + conv = SpatialDownsample2x(in_channels, n_hiddens) + self.spatial_conv.append(conv) + self.spatial_res_stack = nn.Sequential( + *[AttentionResidualBlock(n_hiddens) for _ in range(n_res_layers)], + nn.BatchNorm3d(n_hiddens), + nn.ReLU(), + ) + time_downsample = int(math.log2(time_downsample)) + self.time_conv = nn.ModuleList() + for i in range(time_downsample): + conv = TimeDownsample2x(n_hiddens, n_hiddens) + self.time_conv.append(conv) + self.time_res_stack = nn.Sequential( + *[AttentionResidualBlock(n_hiddens) for _ in range(n_res_layers)], + nn.BatchNorm3d(n_hiddens), + nn.ReLU(), + ) + + def forward(self, x): + h = x + for conv in self.spatial_conv: + h = F.relu(conv(h)) + h = self.spatial_res_stack(h) + for conv in self.time_conv: + h = F.relu(conv(h)) + h = self.time_res_stack(h) + return h + +# Copied from https://github.com/wilson1yan/VideoGPT +class MultiHeadAttention(nn.Module): + def __init__( + self, shape, dim_q, dim_kv, n_head, n_layer, causal, attn_type, attn_kwargs + ): + super().__init__() + self.causal = causal + self.shape = shape + + self.d_k = dim_q // n_head + self.d_v = dim_kv // n_head + self.n_head = n_head + + self.w_qs = nn.Linear(dim_q, n_head * self.d_k, bias=False) # q + self.w_qs.weight.data.normal_(std=1.0 / np.sqrt(dim_q)) + + self.w_ks = nn.Linear(dim_kv, n_head * self.d_k, bias=False) # k + self.w_ks.weight.data.normal_(std=1.0 / np.sqrt(dim_kv)) + + self.w_vs = nn.Linear(dim_kv, n_head * self.d_v, bias=False) # v + self.w_vs.weight.data.normal_(std=1.0 / np.sqrt(dim_kv)) + + self.fc = nn.Linear(n_head * self.d_v, dim_q, bias=True) # c + self.fc.weight.data.normal_(std=1.0 / np.sqrt(dim_q * n_layer)) + + if attn_type == "full": + self.attn = FullAttention(shape, causal, **attn_kwargs) + elif attn_type == "axial": + self.attn = AxialAttention(len(shape), causal=causal, **attn_kwargs) + elif attn_type == "sparse": + self.attn = SparseAttention(shape, n_head, causal, **attn_kwargs) + + self.cache = None + + def forward(self, q, k, v, decode_step=None, decode_idx=None): + """Compute multi-head attention + Args + q, k, v: a [b, d1, ..., dn, c] tensor or + a [b, 1, ..., 1, c] tensor if decode_step is not None + + Returns + The output after performing attention + """ + + # compute k, q, v + d_k, d_v, n_head = self.d_k, self.d_v, self.n_head + q = view_range(self.w_qs(q), -1, None, (n_head, d_k)) + k = view_range(self.w_ks(k), -1, None, (n_head, d_k)) + v = view_range(self.w_vs(v), -1, None, (n_head, d_v)) + + # b x n_head x seq_len x d + # (b, *d_shape, n_head, d) -> (b, n_head, *d_shape, d) + q = shift_dim(q, -2, 1) + k = shift_dim(k, -2, 1) + v = shift_dim(v, -2, 1) + + # fast decoding + if decode_step is not None: + if decode_step == 0: + if self.causal: + k_shape = (q.shape[0], n_head, *self.shape, self.d_k) + v_shape = (q.shape[0], n_head, *self.shape, self.d_v) + self.cache = dict( + k=torch.zeros(k_shape, dtype=k.dtype, device=q.device), + v=torch.zeros(v_shape, dtype=v.dtype, device=q.device), + ) + else: + # cache only once in the non-causal case + self.cache = dict(k=k.clone(), v=v.clone()) + if self.causal: + idx = ( + slice(None, None), + slice(None, None), + *[slice(i, i + 1) for i in decode_idx], + ) + self.cache["k"][idx] = k + self.cache["v"][idx] = v + k, v = self.cache["k"], self.cache["v"] + + a = self.attn(q, k, v, decode_step, decode_idx) + + # (b, *d_shape, n_head, d) -> (b, *d_shape, n_head * d) + a = shift_dim(a, 1, -2).flatten(start_dim=-2) + a = self.fc(a) # (b x seq_len x embd_dim) + + return a + +# Copied from https://github.com/wilson1yan/VideoGPT +class Decoder(nn.Module): + def __init__(self, n_hiddens, n_res_layers, time_downsample, spatial_downsample): + super().__init__() + self.time_res_stack = nn.Sequential( + *[AttentionResidualBlock(n_hiddens) for _ in range(n_res_layers)], + nn.BatchNorm3d(n_hiddens), + nn.ReLU(), + ) + time_downsample = int(math.log2(time_downsample)) + self.time_conv = nn.ModuleList() + for i in range(time_downsample): + convt = TimeUpsample2x(n_hiddens, n_hiddens) + self.time_conv.append(convt) + self.spatial_res_stack = nn.Sequential( + *[AttentionResidualBlock(n_hiddens) for _ in range(n_res_layers)], + nn.BatchNorm3d(n_hiddens), + nn.ReLU(), + ) + spatial_downsample = int(math.log2(spatial_downsample)) + self.spatial_conv = nn.ModuleList() + for i in range(spatial_downsample): + out_channels = 3 if i == spatial_downsample - 1 else n_hiddens + convt = SpatialUpsample2x(n_hiddens, out_channels) + self.spatial_conv.append(convt) + + def forward(self, x): + h = self.time_res_stack(x) + for conv in self.time_conv: + h = F.relu(conv(h)) + h = self.spatial_res_stack(h) + for i, conv in enumerate(self.spatial_conv): + h = conv(h) + if i < len(self.spatial_conv) - 1: + h = F.relu(h) + return h + +# Copied from https://github.com/wilson1yan/VideoGPT +class FullAttention(nn.Module): + def __init__(self, shape, causal, attn_dropout): + super().__init__() + self.causal = causal + self.attn_dropout = attn_dropout + + seq_len = np.prod(shape) + if self.causal: + self.register_buffer("mask", torch.tril(torch.ones(seq_len, seq_len))) + + def forward(self, q, k, v, decode_step, decode_idx): + mask = self.mask if self.causal else None + if decode_step is not None and mask is not None: + mask = mask[[decode_step]] + + old_shape = q.shape[2:-1] + q = q.flatten(start_dim=2, end_dim=-2) + k = k.flatten(start_dim=2, end_dim=-2) + v = v.flatten(start_dim=2, end_dim=-2) + + out = scaled_dot_product_attention( + q, k, v, mask=mask, attn_dropout=self.attn_dropout, training=self.training + ) + + return view_range(out, 2, 3, old_shape) + +# Copied from https://github.com/wilson1yan/VideoGPT +class AxialAttention(nn.Module): + def __init__(self, n_dim, axial_dim, causal=False): + super().__init__() + if axial_dim < 0: + axial_dim = 2 + n_dim + 1 + axial_dim + else: + axial_dim += 2 # account for batch, head, dim + self.causal = causal + self.axial_dim = axial_dim + + def forward(self, q, k, v, decode_step, decode_idx): + # batch, head, frame, height, width, dim + q = shift_dim(q, self.axial_dim, -2).flatten(end_dim=-3) + k = shift_dim(k, self.axial_dim, -2).flatten(end_dim=-3) + v = shift_dim(v, self.axial_dim, -2) + + old_shape = list(v.shape) + v = v.flatten(end_dim=-3) + + if self.causal: + mask = torch.tril(torch.ones(q.shape[-2], q.shape[-2])) if self.causal else None + if decode_step is not None and mask is not None: + mask = mask[[decode_step]] + mask = mask.to(q.device) + else: + mask = None + + out = scaled_dot_product_attention(q, k, v, mask=mask, training=self.training) + out = out.view(*old_shape) + out = shift_dim(out, -2, self.axial_dim) + return out + +# Copied from https://github.com/wilson1yan/VideoGPT +class StridedSparsityConfig(object): + """ + Strided Sparse configuration specified in https://arxiv.org/abs/1904.10509 that + generalizes to arbitrary dimensions + """ + + def __init__(self, shape, n_head, causal, block, num_local_blocks): + self.n_head = n_head + self.shape = shape + self.causal = causal + self.block = block + self.num_local_blocks = num_local_blocks + + assert self.num_local_blocks >= 1, "Must have at least 1 local block" + assert self.seq_len % self.block == 0, "seq len must be divisible by block size" + + self._block_shape = self._compute_block_shape() + self._block_shape_cum = self._block_shape_cum_sizes() + + @property + def seq_len(self): + return np.prod(self.shape) + + @property + def num_blocks(self): + return self.seq_len // self.block + + def set_local_layout(self, layout): + num_blocks = self.num_blocks + for row in range(0, num_blocks): + end = min(row + self.num_local_blocks, num_blocks) + for col in range( + max(0, row - self.num_local_blocks), (row + 1 if self.causal else end) + ): + layout[:, row, col] = 1 + return layout + + def set_global_layout(self, layout): + num_blocks = self.num_blocks + n_dim = len(self._block_shape) + for row in range(num_blocks): + assert self._to_flattened_idx(self._to_unflattened_idx(row)) == row + cur_idx = self._to_unflattened_idx(row) + # no strided attention over last dim + for d in range(n_dim - 1): + end = self._block_shape[d] + for i in range(0, (cur_idx[d] + 1 if self.causal else end)): + new_idx = list(cur_idx) + new_idx[d] = i + new_idx = tuple(new_idx) + + col = self._to_flattened_idx(new_idx) + layout[:, row, col] = 1 + + return layout + + def make_layout(self): + layout = torch.zeros( + (self.n_head, self.num_blocks, self.num_blocks), dtype=torch.int64 + ) + layout = self.set_local_layout(layout) + layout = self.set_global_layout(layout) + return layout + + def make_sparse_attn_mask(self): + block_layout = self.make_layout() + assert block_layout.shape[1] == block_layout.shape[2] == self.num_blocks + + num_dense_blocks = block_layout.sum().item() + attn_mask = torch.ones(num_dense_blocks, self.block, self.block) + counter = 0 + for h in range(self.n_head): + for i in range(self.num_blocks): + for j in range(self.num_blocks): + elem = block_layout[h, i, j].item() + if elem == 1: + assert i >= j + if i == j: # need to mask within block on diagonals + attn_mask[counter] = torch.tril(attn_mask[counter]) + counter += 1 + assert counter == num_dense_blocks + + return attn_mask.unsqueeze(0) + + def get_non_block_layout_row(self, block_layout, row): + block_row = row // self.block + block_row = block_layout[:, [block_row]] # n_head x 1 x n_blocks + block_row = block_row.repeat_interleave(self.block, dim=-1) + block_row[:, :, row + 1 :] = 0.0 + return block_row + + ############# Helper functions ########################## + + def _compute_block_shape(self): + n_dim = len(self.shape) + cum_prod = 1 + for i in range(n_dim - 1, -1, -1): + cum_prod *= self.shape[i] + if cum_prod > self.block: + break + assert cum_prod % self.block == 0 + new_shape = (*self.shape[:i], cum_prod // self.block) + + assert np.prod(new_shape) == np.prod(self.shape) // self.block + + return new_shape + + def _block_shape_cum_sizes(self): + bs = np.flip(np.array(self._block_shape)) + return tuple(np.flip(np.cumprod(bs)[:-1])) + (1,) + + def _to_flattened_idx(self, idx): + assert len(idx) == len( + self._block_shape + ), f"{len(idx)} != {len(self._block_shape)}" + flat_idx = 0 + for i in range(len(self._block_shape)): + flat_idx += idx[i] * self._block_shape_cum[i] + return flat_idx + + def _to_unflattened_idx(self, flat_idx): + assert flat_idx < np.prod(self._block_shape) + idx = [] + for i in range(len(self._block_shape)): + idx.append(flat_idx // self._block_shape_cum[i]) + flat_idx %= self._block_shape_cum[i] + return tuple(idx) + +# Copied from https://github.com/wilson1yan/VideoGPT +class SparseAttention(nn.Module): + ops = dict() + attn_mask = dict() + block_layout = dict() + + def __init__( + self, shape, n_head, causal, num_local_blocks=4, block=32, attn_dropout=0.0 + ): # does not use attn_dropout + super().__init__() + self.causal = causal + self.shape = shape + + self.sparsity_config = StridedSparsityConfig( + shape=shape, + n_head=n_head, + causal=causal, + block=block, + num_local_blocks=num_local_blocks, + ) + + if self.shape not in SparseAttention.block_layout: + SparseAttention.block_layout[self.shape] = ( + self.sparsity_config.make_layout() + ) + if causal and self.shape not in SparseAttention.attn_mask: + SparseAttention.attn_mask[self.shape] = ( + self.sparsity_config.make_sparse_attn_mask() + ) + + def get_ops(self): + try: + from deepspeed.ops.sparse_attention import MatMul, Softmax + except: + raise Exception( + "Error importing deepspeed. Please install using `DS_BUILD_SPARSE_ATTN=1 pip install deepspeed`" + ) + if self.shape not in SparseAttention.ops: + sparsity_layout = self.sparsity_config.make_layout() + sparse_dot_sdd_nt = MatMul( + sparsity_layout, + self.sparsity_config.block, + "sdd", + trans_a=False, + trans_b=True, + ) + + sparse_dot_dsd_nn = MatMul( + sparsity_layout, + self.sparsity_config.block, + "dsd", + trans_a=False, + trans_b=False, + ) + + sparse_softmax = Softmax(sparsity_layout, self.sparsity_config.block) + + SparseAttention.ops[self.shape] = ( + sparse_dot_sdd_nt, + sparse_dot_dsd_nn, + sparse_softmax, + ) + return SparseAttention.ops[self.shape] + + def forward(self, q, k, v, decode_step, decode_idx): + if self.training and self.shape not in SparseAttention.ops: + self.get_ops() + + SparseAttention.block_layout[self.shape] = SparseAttention.block_layout[ + self.shape + ].to(q) + if self.causal: + SparseAttention.attn_mask[self.shape] = ( + SparseAttention.attn_mask[self.shape].to(q).type_as(q) + ) + attn_mask = SparseAttention.attn_mask[self.shape] if self.causal else None + + old_shape = q.shape[2:-1] + q = q.flatten(start_dim=2, end_dim=-2) + k = k.flatten(start_dim=2, end_dim=-2) + v = v.flatten(start_dim=2, end_dim=-2) + + if decode_step is not None: + mask = self.sparsity_config.get_non_block_layout_row( + SparseAttention.block_layout[self.shape], decode_step + ) + out = scaled_dot_product_attention( + q, k, v, mask=mask, training=self.training + ) + else: + if q.shape != k.shape or k.shape != v.shape: + raise Exception("SparseAttention only support self-attention") + sparse_dot_sdd_nt, sparse_dot_dsd_nn, sparse_softmax = self.get_ops() + scaling = float(q.shape[-1]) ** -0.5 + + attn_output_weights = sparse_dot_sdd_nt(q, k) + if attn_mask is not None: + attn_output_weights = attn_output_weights.masked_fill( + attn_mask == 0, float("-inf") + ) + attn_output_weights = sparse_softmax(attn_output_weights, scale=scaling) + + out = sparse_dot_dsd_nn(attn_output_weights, v) + + return view_range(out, 2, 3, old_shape) + +class CausalVQVAEModel(VideoBaseAE): + + def __init__(self, config: CausalVQVAEConfiguration): + super().__init__() + self.config = config + self.embedding_dim = config.embedding_dim + self.n_codes = config.n_codes + self.encoder = Encoder(config.n_hiddens, config.n_res_layers, config.time_downsample, config.spatial_downsample) + self.decoder = Decoder(config.n_hiddens, config.n_res_layers, config.time_downsample, config.spatial_downsample) + self.pre_vq_conv = CausalConv3d(config.n_hiddens, config.embedding_dim, 1) + self.post_vq_conv = CausalConv3d(config.embedding_dim, config.n_hiddens, 1) + self.codebook = Codebook(config.n_codes, config.embedding_dim) + + def forward(self, x): + z = self.pre_vq_conv(self.encoder(x)) + vq_output = self.codebook(z) + x_recon = self.decoder(self.post_vq_conv(vq_output["embeddings"])) + recon_loss = F.mse_loss(x_recon, x) / 0.06 + return recon_loss, x_recon, vq_output + + def encode(self, x: Tensor, include_embeddings: bool = False) -> Union[Tuple[Tensor, Tensor], Tensor]: + h = self.pre_vq_conv(self.encoder(x)) + vq_output: Dict[str, Tensor] = self.codebook(h) + if include_embeddings: + return vq_output["encodings"], vq_output["embeddings"] + else: + return vq_output["encodings"] + + def decode(self, encodings: Tensor) -> Tensor: + h = F.embedding(encodings, self.codebook.embeddings) + h = self.post_vq_conv(shift_dim(h, -1, 1)) + return self.decoder(h) + + @classmethod + def load_from_checkpoint(cls, model_path): + with open(os.path.join(model_path, "config.json"), "r") as file: + config = json.load(file) + state_dict = torch.load(os.path.join(model_path, "pytorch_model.bin"), map_location="cpu") + model = cls(config=CausalVQVAEConfiguration(**config)) + model.load_state_dict(state_dict) + return model + + @classmethod + def download_and_load_model(cls, model_name, cache_dir=None): + raise NotImplementedError() diff --git a/opensora/models/ae/videobase/causal_vqvae/trainer_causalvqvae.py b/opensora/models/ae/videobase/causal_vqvae/trainer_causalvqvae.py new file mode 100644 index 0000000000000000000000000000000000000000..f819bce363c4cfb10ec7bff3a3e754a30e09dec7 --- /dev/null +++ b/opensora/models/ae/videobase/causal_vqvae/trainer_causalvqvae.py @@ -0,0 +1,21 @@ +from ..trainer_videobase import VideoBaseTrainer +import torch.nn.functional as F +from typing import Optional +import os +import torch +from transformers.utils import WEIGHTS_NAME +import json + +class CausalVQVAETrainer(VideoBaseTrainer): + + def compute_loss(self, model, inputs, return_outputs=False): + model = model.module + x = inputs.get("video") + x = x / 2 + z = model.pre_vq_conv(model.encoder(x)) + vq_output = model.codebook(z) + x_recon = model.decoder(model.post_vq_conv(vq_output["embeddings"])) + recon_loss = F.mse_loss(x_recon, x) / 0.06 + commitment_loss = vq_output['commitment_loss'] + loss = recon_loss + commitment_loss + return loss diff --git a/opensora/models/ae/videobase/configuration_videobase.py b/opensora/models/ae/videobase/configuration_videobase.py new file mode 100644 index 0000000000000000000000000000000000000000..f25a25c5b37da8aa8f9e939112ced0b9f861a4b6 --- /dev/null +++ b/opensora/models/ae/videobase/configuration_videobase.py @@ -0,0 +1,44 @@ +import json +import yaml +from typing import TypeVar, Dict, Any +from diffusers import ConfigMixin + +T = TypeVar('T', bound='VideoBaseConfiguration') +class VideoBaseConfiguration(ConfigMixin): + config_name = "VideoBaseConfiguration" + _nested_config_fields: Dict[str, Any] = {} + + def __init__(self, **kwargs): + pass + + def to_dict(self) -> Dict[str, Any]: + d = {} + for key, value in vars(self).items(): + if isinstance(value, VideoBaseConfiguration): + d[key] = value.to_dict() # Serialize nested VideoBaseConfiguration instances + elif isinstance(value, tuple): + d[key] = list(value) + else: + d[key] = value + return d + + def to_yaml_file(self, yaml_path: str): + with open(yaml_path, 'w') as yaml_file: + yaml.dump(self.to_dict(), yaml_file, default_flow_style=False) + + @classmethod + def load_from_yaml(cls: T, yaml_path: str) -> T: + with open(yaml_path, 'r') as yaml_file: + config_dict = yaml.safe_load(yaml_file) + for field, field_type in cls._nested_config_fields.items(): + if field in config_dict: + config_dict[field] = field_type.load_from_dict(config_dict[field]) + return cls(**config_dict) + + @classmethod + def load_from_dict(cls: T, config_dict: Dict[str, Any]) -> T: + # Process nested configuration objects + for field, field_type in cls._nested_config_fields.items(): + if field in config_dict: + config_dict[field] = field_type.load_from_dict(config_dict[field]) + return cls(**config_dict) \ No newline at end of file diff --git a/opensora/models/ae/videobase/dataset_videobase.py b/opensora/models/ae/videobase/dataset_videobase.py new file mode 100644 index 0000000000000000000000000000000000000000..32f842f63a310c9ee8e2d1dde3c3e695bdaa582a --- /dev/null +++ b/opensora/models/ae/videobase/dataset_videobase.py @@ -0,0 +1,107 @@ +import os.path as osp +import random +from glob import glob + +from torchvision import transforms +import numpy as np +import torch +import torch.utils.data as data +import torch.nn.functional as F +from torchvision.transforms import Lambda + +from ....dataset.transform import ToTensorVideo, CenterCropVideo +from ....utils.dataset_utils import DecordInit + +def TemporalRandomCrop(total_frames, size): + """ + Performs a random temporal crop on a video sequence. + + This function randomly selects a continuous frame sequence of length `size` from a video sequence. + `total_frames` indicates the total number of frames in the video sequence, and `size` represents the length of the frame sequence to be cropped. + + Parameters: + - total_frames (int): The total number of frames in the video sequence. + - size (int): The length of the frame sequence to be cropped. + + Returns: + - (int, int): A tuple containing two integers. The first integer is the starting frame index of the cropped sequence, + and the second integer is the ending frame index (inclusive) of the cropped sequence. + """ + rand_end = max(0, total_frames - size - 1) + begin_index = random.randint(0, rand_end) + end_index = min(begin_index + size, total_frames) + return begin_index, end_index + +def resize(x, resolution): + height, width = x.shape[-2:] + resolution = min(2 * resolution, height, width) + aspect_ratio = width / height + if width <= height: + new_width = resolution + new_height = int(resolution / aspect_ratio) + else: + new_height = resolution + new_width = int(resolution * aspect_ratio) + resized_x = F.interpolate(x, size=(new_height, new_width), mode='bilinear', align_corners=True, antialias=True) + return resized_x + +class VideoDataset(data.Dataset): + """ Generic dataset for videos files stored in folders + Returns BCTHW videos in the range [-0.5, 0.5] """ + video_exts = ['avi', 'mp4', 'webm'] + def __init__(self, video_folder, sequence_length, image_folder=None, train=True, resolution=64, sample_rate=1, dynamic_sample=True): + + self.train = train + self.sequence_length = sequence_length + self.sample_rate = sample_rate + self.resolution = resolution + self.v_decoder = DecordInit() + self.video_folder = video_folder + self.dynamic_sample = dynamic_sample + + self.transform = transforms.Compose([ + ToTensorVideo(), + # Lambda(lambda x: resize(x, self.resolution)), + CenterCropVideo(self.resolution), + Lambda(lambda x: 2.0 * x - 1.0) + ]) + print('Building datasets...') + self.samples = self._make_dataset() + + def _make_dataset(self): + samples = [] + samples += sum([glob(osp.join(self.video_folder, '**', f'*.{ext}'), recursive=True) + for ext in self.video_exts], []) + return samples + + def __len__(self): + return len(self.samples) + + def __getitem__(self, idx): + video_path = self.samples[idx] + try: + video = self.decord_read(video_path) + video = self.transform(video) # T C H W -> T C H W + video = video.transpose(0, 1) # T C H W -> C T H W + return dict(video=video, label="") + except Exception as e: + print(f'Error with {e}, {video_path}') + return self.__getitem__(random.randint(0, self.__len__()-1)) + + def decord_read(self, path): + decord_vr = self.v_decoder(path) + total_frames = len(decord_vr) + # Sampling video frames + if self.dynamic_sample: + sample_rate = random.randint(1, self.sample_rate) + else: + sample_rate = self.sample_rate + size = self.sequence_length * sample_rate + start_frame_ind, end_frame_ind = TemporalRandomCrop(total_frames, size) + # assert end_frame_ind - start_frame_ind >= self.num_frames + frame_indice = np.linspace(start_frame_ind, end_frame_ind - 1, self.sequence_length, dtype=int) + + video_data = decord_vr.get_batch(frame_indice).asnumpy() + video_data = torch.from_numpy(video_data) + video_data = video_data.permute(0, 3, 1, 2) # (T, H, W, C) -> (T C H W) + return video_data \ No newline at end of file diff --git a/opensora/models/ae/videobase/losses/__init__.py b/opensora/models/ae/videobase/losses/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..91e59504e7094dc422bba4f3d12f02305ce9b30f --- /dev/null +++ b/opensora/models/ae/videobase/losses/__init__.py @@ -0,0 +1 @@ +from .perceptual_loss import SimpleLPIPS, LPIPSWithDiscriminator, LPIPSWithDiscriminator3D \ No newline at end of file diff --git a/opensora/models/ae/videobase/losses/discriminator.py b/opensora/models/ae/videobase/losses/discriminator.py new file mode 100644 index 0000000000000000000000000000000000000000..a6ffb2528a5f4918b4604f1e2aaab719a758d9d8 --- /dev/null +++ b/opensora/models/ae/videobase/losses/discriminator.py @@ -0,0 +1,115 @@ +import functools +import torch.nn as nn +from ..modules.normalize import ActNorm +from einops import rearrange + +def weights_init(m): + classname = m.__class__.__name__ + if classname.find('Conv') != -1: + nn.init.normal_(m.weight.data, 0.0, 0.02) + elif classname.find('BatchNorm') != -1: + nn.init.normal_(m.weight.data, 1.0, 0.02) + nn.init.constant_(m.bias.data, 0) + +class NLayerDiscriminator(nn.Module): + """Defines a PatchGAN discriminator as in Pix2Pix + --> see https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py + """ + def __init__(self, input_nc=3, ndf=64, n_layers=3, use_actnorm=False): + """Construct a PatchGAN discriminator + Parameters: + input_nc (int) -- the number of channels in input images + ndf (int) -- the number of filters in the last conv layer + n_layers (int) -- the number of conv layers in the discriminator + norm_layer -- normalization layer + """ + super(NLayerDiscriminator, self).__init__() + if not use_actnorm: + norm_layer = nn.BatchNorm2d + else: + norm_layer = ActNorm + if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters + use_bias = norm_layer.func != nn.BatchNorm2d + else: + use_bias = norm_layer != nn.BatchNorm2d + + kw = 4 + padw = 1 + sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)] + nf_mult = 1 + nf_mult_prev = 1 + for n in range(1, n_layers): # gradually increase the number of filters + nf_mult_prev = nf_mult + nf_mult = min(2 ** n, 8) + sequence += [ + nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias), + norm_layer(ndf * nf_mult), + nn.LeakyReLU(0.2, True) + ] + + nf_mult_prev = nf_mult + nf_mult = min(2 ** n_layers, 8) + sequence += [ + nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias), + norm_layer(ndf * nf_mult), + nn.LeakyReLU(0.2, True) + ] + + sequence += [ + nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)] # output 1 channel prediction map + self.main = nn.Sequential(*sequence) + + def forward(self, input): + """Standard forward.""" + return self.main(input) + +class NLayerDiscriminator3D(nn.Module): + """Defines a 3D PatchGAN discriminator as in Pix2Pix but for 3D inputs.""" + def __init__(self, input_nc=1, ndf=64, n_layers=3, use_actnorm=False): + """ + Construct a 3D PatchGAN discriminator + + Parameters: + input_nc (int) -- the number of channels in input volumes + ndf (int) -- the number of filters in the last conv layer + n_layers (int) -- the number of conv layers in the discriminator + use_actnorm (bool) -- flag to use actnorm instead of batchnorm + """ + super(NLayerDiscriminator3D, self).__init__() + if not use_actnorm: + norm_layer = nn.BatchNorm3d + else: + raise NotImplementedError("Not implemented.") + if type(norm_layer) == functools.partial: + use_bias = norm_layer.func != nn.BatchNorm3d + else: + use_bias = norm_layer != nn.BatchNorm3d + + kw = 4 + padw = 1 + sequence = [nn.Conv3d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)] + nf_mult = 1 + nf_mult_prev = 1 + for n in range(1, n_layers): # gradually increase the number of filters + nf_mult_prev = nf_mult + nf_mult = min(2 ** n, 8) + sequence += [ + nn.Conv3d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=(kw, kw, kw), stride=(1,2,2), padding=padw, bias=use_bias), + norm_layer(ndf * nf_mult), + nn.LeakyReLU(0.2, True) + ] + + nf_mult_prev = nf_mult + nf_mult = min(2 ** n_layers, 8) + sequence += [ + nn.Conv3d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=(kw, kw, kw), stride=1, padding=padw, bias=use_bias), + norm_layer(ndf * nf_mult), + nn.LeakyReLU(0.2, True) + ] + + sequence += [nn.Conv3d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)] # output 1 channel prediction map + self.main = nn.Sequential(*sequence) + + def forward(self, input): + """Standard forward.""" + return self.main(input) diff --git a/opensora/models/ae/videobase/losses/lpips.py b/opensora/models/ae/videobase/losses/lpips.py new file mode 100644 index 0000000000000000000000000000000000000000..8b7062cdd0e9b65e6eb268a94ab3fe139e074bb3 --- /dev/null +++ b/opensora/models/ae/videobase/losses/lpips.py @@ -0,0 +1,120 @@ +"""Stripped version of https://github.com/richzhang/PerceptualSimilarity/tree/master/models""" + +import torch +import torch.nn as nn +from torchvision import models +from collections import namedtuple +from .....utils.taming_download import get_ckpt_path + +class LPIPS(nn.Module): + # Learned perceptual metric + def __init__(self, use_dropout=True): + super().__init__() + self.scaling_layer = ScalingLayer() + self.chns = [64, 128, 256, 512, 512] # vg16 features + self.net = vgg16(pretrained=True, requires_grad=False) + self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout) + self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout) + self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout) + self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout) + self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout) + self.load_from_pretrained() + for param in self.parameters(): + param.requires_grad = False + + def load_from_pretrained(self, name="vgg_lpips"): + ckpt = get_ckpt_path(name, "taming/modules/autoencoder/lpips") + self.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu")), strict=False) + print("loaded pretrained LPIPS loss from {}".format(ckpt)) + + @classmethod + def from_pretrained(cls, name="vgg_lpips"): + if name != "vgg_lpips": + raise NotImplementedError + model = cls() + ckpt = get_ckpt_path(name) + model.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu")), strict=False) + return model + + def forward(self, input, target): + in0_input, in1_input = (self.scaling_layer(input), self.scaling_layer(target)) + outs0, outs1 = self.net(in0_input), self.net(in1_input) + feats0, feats1, diffs = {}, {}, {} + lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4] + for kk in range(len(self.chns)): + feats0[kk], feats1[kk] = normalize_tensor(outs0[kk]), normalize_tensor(outs1[kk]) + diffs[kk] = (feats0[kk] - feats1[kk]) ** 2 + + res = [spatial_average(lins[kk].model(diffs[kk]), keepdim=True) for kk in range(len(self.chns))] + val = res[0] + for l in range(1, len(self.chns)): + val += res[l] + return val + + +class ScalingLayer(nn.Module): + def __init__(self): + super(ScalingLayer, self).__init__() + self.register_buffer('shift', torch.Tensor([-.030, -.088, -.188])[None, :, None, None]) + self.register_buffer('scale', torch.Tensor([.458, .448, .450])[None, :, None, None]) + + def forward(self, inp): + return (inp - self.shift) / self.scale + + +class NetLinLayer(nn.Module): + """ A single linear layer which does a 1x1 conv """ + def __init__(self, chn_in, chn_out=1, use_dropout=False): + super(NetLinLayer, self).__init__() + layers = [nn.Dropout(), ] if (use_dropout) else [] + layers += [nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False), ] + self.model = nn.Sequential(*layers) + + +class vgg16(torch.nn.Module): + def __init__(self, requires_grad=False, pretrained=True): + super(vgg16, self).__init__() + vgg_pretrained_features = models.vgg16(pretrained=pretrained).features + self.slice1 = torch.nn.Sequential() + self.slice2 = torch.nn.Sequential() + self.slice3 = torch.nn.Sequential() + self.slice4 = torch.nn.Sequential() + self.slice5 = torch.nn.Sequential() + self.N_slices = 5 + for x in range(4): + self.slice1.add_module(str(x), vgg_pretrained_features[x]) + for x in range(4, 9): + self.slice2.add_module(str(x), vgg_pretrained_features[x]) + for x in range(9, 16): + self.slice3.add_module(str(x), vgg_pretrained_features[x]) + for x in range(16, 23): + self.slice4.add_module(str(x), vgg_pretrained_features[x]) + for x in range(23, 30): + self.slice5.add_module(str(x), vgg_pretrained_features[x]) + if not requires_grad: + for param in self.parameters(): + param.requires_grad = False + + def forward(self, X): + h = self.slice1(X) + h_relu1_2 = h + h = self.slice2(h) + h_relu2_2 = h + h = self.slice3(h) + h_relu3_3 = h + h = self.slice4(h) + h_relu4_3 = h + h = self.slice5(h) + h_relu5_3 = h + vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3']) + out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3) + return out + + +def normalize_tensor(x,eps=1e-10): + norm_factor = torch.sqrt(torch.sum(x**2,dim=1,keepdim=True)) + return x/(norm_factor+eps) + + +def spatial_average(x, keepdim=True): + return x.mean([2,3],keepdim=keepdim) diff --git a/opensora/models/ae/videobase/losses/perceptual_loss.py b/opensora/models/ae/videobase/losses/perceptual_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..e4042d0528b6c175581e79934c46f468dc95f0c4 --- /dev/null +++ b/opensora/models/ae/videobase/losses/perceptual_loss.py @@ -0,0 +1,414 @@ +import torch +from torch import nn +import torch.nn.functional as F +from .lpips import LPIPS +from einops import rearrange +from .discriminator import NLayerDiscriminator, weights_init, NLayerDiscriminator3D + + +def hinge_d_loss(logits_real, logits_fake): + loss_real = torch.mean(F.relu(1.0 - logits_real)) + loss_fake = torch.mean(F.relu(1.0 + logits_fake)) + d_loss = 0.5 * (loss_real + loss_fake) + return d_loss + + +def vanilla_d_loss(logits_real, logits_fake): + d_loss = 0.5 * ( + torch.mean(torch.nn.functional.softplus(-logits_real)) + + torch.mean(torch.nn.functional.softplus(logits_fake)) + ) + return d_loss + + +def hinge_d_loss_with_exemplar_weights(logits_real, logits_fake, weights): + assert weights.shape[0] == logits_real.shape[0] == logits_fake.shape[0] + loss_real = torch.mean(F.relu(1.0 - logits_real), dim=[1, 2, 3]) + loss_fake = torch.mean(F.relu(1.0 + logits_fake), dim=[1, 2, 3]) + loss_real = (weights * loss_real).sum() / weights.sum() + loss_fake = (weights * loss_fake).sum() / weights.sum() + d_loss = 0.5 * (loss_real + loss_fake) + return d_loss + + +def adopt_weight(weight, global_step, threshold=0, value=0.0): + if global_step < threshold: + weight = value + return weight + + +def measure_perplexity(predicted_indices, n_embed): + # src: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py + # eval cluster perplexity. when perplexity == num_embeddings then all clusters are used exactly equally + encodings = F.one_hot(predicted_indices, n_embed).float().reshape(-1, n_embed) + avg_probs = encodings.mean(0) + perplexity = (-(avg_probs * torch.log(avg_probs + 1e-10)).sum()).exp() + cluster_use = torch.sum(avg_probs > 0) + return perplexity, cluster_use + + +def l1(x, y): + return torch.abs(x - y) + + +def l2(x, y): + return torch.pow((x - y), 2) + + +class LPIPSWithDiscriminator(nn.Module): + def __init__( + self, + disc_start, + logvar_init=0.0, + kl_weight=1.0, + pixelloss_weight=1.0, + perceptual_weight=1.0, + # --- Discriminator Loss --- + disc_num_layers=3, + disc_in_channels=3, + disc_factor=1.0, + disc_weight=1.0, + use_actnorm=False, + disc_conditional=False, + disc_loss="hinge", + ): + + super().__init__() + assert disc_loss in ["hinge", "vanilla"] + self.kl_weight = kl_weight + self.pixel_weight = pixelloss_weight + self.perceptual_loss = LPIPS().eval() + self.perceptual_weight = perceptual_weight + self.logvar = nn.Parameter(torch.ones(size=()) * logvar_init) + + self.discriminator = NLayerDiscriminator( + input_nc=disc_in_channels, n_layers=disc_num_layers, use_actnorm=use_actnorm + ).apply(weights_init) + self.discriminator_iter_start = disc_start + self.disc_loss = hinge_d_loss if disc_loss == "hinge" else vanilla_d_loss + self.disc_factor = disc_factor + self.discriminator_weight = disc_weight + self.disc_conditional = disc_conditional + + def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None): + if last_layer is not None: + nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0] + g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0] + else: + nll_grads = torch.autograd.grad( + nll_loss, self.last_layer[0], retain_graph=True + )[0] + g_grads = torch.autograd.grad( + g_loss, self.last_layer[0], retain_graph=True + )[0] + + d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4) + d_weight = torch.clamp(d_weight, 0.0, 1e4).detach() + d_weight = d_weight * self.discriminator_weight + return d_weight + + def forward( + self, + inputs, + reconstructions, + posteriors, + optimizer_idx, + global_step, + split="train", + weights=None, + last_layer=None, + cond=None, + ): + inputs = rearrange(inputs, "b c t h w -> (b t) c h w").contiguous() + reconstructions = rearrange( + reconstructions, "b c t h w -> (b t) c h w" + ).contiguous() + rec_loss = torch.abs(inputs - reconstructions) + if self.perceptual_weight > 0: + p_loss = self.perceptual_loss(inputs, reconstructions) + rec_loss = rec_loss + self.perceptual_weight * p_loss + nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar + weighted_nll_loss = nll_loss + if weights is not None: + weighted_nll_loss = weights * nll_loss + weighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0] + nll_loss = torch.sum(nll_loss) / nll_loss.shape[0] + kl_loss = posteriors.kl() + kl_loss = torch.sum(kl_loss) / kl_loss.shape[0] + + # GAN Part + if optimizer_idx == 0: + # generator update + if cond is None: + assert not self.disc_conditional + logits_fake = self.discriminator(reconstructions.contiguous()) + else: + assert self.disc_conditional + logits_fake = self.discriminator( + torch.cat((reconstructions.contiguous(), cond), dim=1) + ) + g_loss = -torch.mean(logits_fake) + + if self.disc_factor > 0.0: + try: + d_weight = self.calculate_adaptive_weight( + nll_loss, g_loss, last_layer=last_layer + ) + except RuntimeError: + assert not self.training + d_weight = torch.tensor(0.0) + else: + d_weight = torch.tensor(0.0) + + disc_factor = adopt_weight( + self.disc_factor, global_step, threshold=self.discriminator_iter_start + ) + loss = ( + weighted_nll_loss + + self.kl_weight * kl_loss + + d_weight * disc_factor * g_loss + ) + log = { + "{}/total_loss".format(split): loss.clone().detach().mean(), + "{}/logvar".format(split): self.logvar.detach(), + "{}/kl_loss".format(split): kl_loss.detach().mean(), + "{}/nll_loss".format(split): nll_loss.detach().mean(), + "{}/rec_loss".format(split): rec_loss.detach().mean(), + "{}/d_weight".format(split): d_weight.detach(), + "{}/disc_factor".format(split): torch.tensor(disc_factor), + "{}/g_loss".format(split): g_loss.detach().mean(), + } + return loss, log + + if optimizer_idx == 1: + if cond is None: + logits_real = self.discriminator(inputs.contiguous().detach()) + logits_fake = self.discriminator(reconstructions.contiguous().detach()) + else: + logits_real = self.discriminator( + torch.cat((inputs.contiguous().detach(), cond), dim=1) + ) + logits_fake = self.discriminator( + torch.cat((reconstructions.contiguous().detach(), cond), dim=1) + ) + + disc_factor = adopt_weight( + self.disc_factor, global_step, threshold=self.discriminator_iter_start + ) + d_loss = disc_factor * self.disc_loss(logits_real, logits_fake) + + log = { + "{}/disc_loss".format(split): d_loss.clone().detach().mean(), + "{}/logits_real".format(split): logits_real.detach().mean(), + "{}/logits_fake".format(split): logits_fake.detach().mean(), + } + return d_loss, log + + +class LPIPSWithDiscriminator3D(nn.Module): + def __init__( + self, + disc_start, + logvar_init=0.0, + kl_weight=1.0, + pixelloss_weight=1.0, + perceptual_weight=1.0, + # --- Discriminator Loss --- + disc_num_layers=3, + disc_in_channels=3, + disc_factor=1.0, + disc_weight=1.0, + use_actnorm=False, + disc_conditional=False, + disc_loss="hinge", + ): + + super().__init__() + assert disc_loss in ["hinge", "vanilla"] + self.kl_weight = kl_weight + self.pixel_weight = pixelloss_weight + self.perceptual_loss = LPIPS().eval() + self.perceptual_weight = perceptual_weight + self.logvar = nn.Parameter(torch.ones(size=()) * logvar_init) + + self.discriminator = NLayerDiscriminator3D( + input_nc=disc_in_channels, n_layers=disc_num_layers, use_actnorm=use_actnorm + ).apply(weights_init) + self.discriminator_iter_start = disc_start + self.disc_loss = hinge_d_loss if disc_loss == "hinge" else vanilla_d_loss + self.disc_factor = disc_factor + self.discriminator_weight = disc_weight + self.disc_conditional = disc_conditional + + def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None): + if last_layer is not None: + nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0] + g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0] + else: + nll_grads = torch.autograd.grad( + nll_loss, self.last_layer[0], retain_graph=True + )[0] + g_grads = torch.autograd.grad( + g_loss, self.last_layer[0], retain_graph=True + )[0] + + d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4) + d_weight = torch.clamp(d_weight, 0.0, 1e4).detach() + d_weight = d_weight * self.discriminator_weight + return d_weight + + def forward( + self, + inputs, + reconstructions, + posteriors, + optimizer_idx, + global_step, + split="train", + weights=None, + last_layer=None, + cond=None, + ): + t = inputs.shape[2] + inputs = rearrange(inputs, "b c t h w -> (b t) c h w").contiguous() + reconstructions = rearrange( + reconstructions, "b c t h w -> (b t) c h w" + ).contiguous() + rec_loss = torch.abs(inputs - reconstructions) + if self.perceptual_weight > 0: + p_loss = self.perceptual_loss(inputs, reconstructions) + rec_loss = rec_loss + self.perceptual_weight * p_loss + nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar + weighted_nll_loss = nll_loss + if weights is not None: + weighted_nll_loss = weights * nll_loss + weighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0] + nll_loss = torch.sum(nll_loss) / nll_loss.shape[0] + kl_loss = posteriors.kl() + kl_loss = torch.sum(kl_loss) / kl_loss.shape[0] + inputs = rearrange(inputs, "(b t) c h w -> b c t h w", t=t).contiguous() + reconstructions = rearrange( + reconstructions, "(b t) c h w -> b c t h w", t=t + ).contiguous() + # GAN Part + if optimizer_idx == 0: + # generator update + if cond is None: + assert not self.disc_conditional + logits_fake = self.discriminator(reconstructions) + else: + assert self.disc_conditional + logits_fake = self.discriminator( + torch.cat((reconstructions, cond), dim=1) + ) + g_loss = -torch.mean(logits_fake) + + if self.disc_factor > 0.0: + try: + d_weight = self.calculate_adaptive_weight( + nll_loss, g_loss, last_layer=last_layer + ) + except RuntimeError as e: + assert not self.training, print(e) + d_weight = torch.tensor(0.0) + else: + d_weight = torch.tensor(0.0) + + disc_factor = adopt_weight( + self.disc_factor, global_step, threshold=self.discriminator_iter_start + ) + loss = ( + weighted_nll_loss + + self.kl_weight * kl_loss + + d_weight * disc_factor * g_loss + ) + log = { + "{}/total_loss".format(split): loss.clone().detach().mean(), + "{}/logvar".format(split): self.logvar.detach(), + "{}/kl_loss".format(split): kl_loss.detach().mean(), + "{}/nll_loss".format(split): nll_loss.detach().mean(), + "{}/rec_loss".format(split): rec_loss.detach().mean(), + "{}/d_weight".format(split): d_weight.detach(), + "{}/disc_factor".format(split): torch.tensor(disc_factor), + "{}/g_loss".format(split): g_loss.detach().mean(), + } + return loss, log + + if optimizer_idx == 1: + if cond is None: + logits_real = self.discriminator(inputs.contiguous().detach()) + logits_fake = self.discriminator(reconstructions.contiguous().detach()) + else: + logits_real = self.discriminator( + torch.cat((inputs.contiguous().detach(), cond), dim=1) + ) + logits_fake = self.discriminator( + torch.cat((reconstructions.contiguous().detach(), cond), dim=1) + ) + + disc_factor = adopt_weight( + self.disc_factor, global_step, threshold=self.discriminator_iter_start + ) + d_loss = disc_factor * self.disc_loss(logits_real, logits_fake) + + log = { + "{}/disc_loss".format(split): d_loss.clone().detach().mean(), + "{}/logits_real".format(split): logits_real.detach().mean(), + "{}/logits_fake".format(split): logits_fake.detach().mean(), + } + return d_loss, log + + +class SimpleLPIPS(nn.Module): + def __init__( + self, + logvar_init=0.0, + kl_weight=1.0, + pixelloss_weight=1.0, + perceptual_weight=1.0, + disc_loss="hinge", + ): + + super().__init__() + assert disc_loss in ["hinge", "vanilla"] + self.kl_weight = kl_weight + self.pixel_weight = pixelloss_weight + self.perceptual_loss = LPIPS().eval() + self.perceptual_weight = perceptual_weight + self.logvar = nn.Parameter(torch.ones(size=()) * logvar_init) + + def forward( + self, + inputs, + reconstructions, + posteriors, + split="train", + weights=None, + ): + inputs = rearrange(inputs, "b c t h w -> (b t) c h w").contiguous() + reconstructions = rearrange( + reconstructions, "b c t h w -> (b t) c h w" + ).contiguous() + rec_loss = torch.abs(inputs - reconstructions) + if self.perceptual_weight > 0: + p_loss = self.perceptual_loss(inputs, reconstructions) + rec_loss = rec_loss + self.perceptual_weight * p_loss + nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar + weighted_nll_loss = nll_loss + if weights is not None: + weighted_nll_loss = weights * nll_loss + weighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0] + nll_loss = torch.sum(nll_loss) / nll_loss.shape[0] + kl_loss = posteriors.kl() + kl_loss = torch.sum(kl_loss) / kl_loss.shape[0] + loss = weighted_nll_loss + self.kl_weight * kl_loss + log = { + "{}/total_loss".format(split): loss.clone().detach().mean(), + "{}/logvar".format(split): self.logvar.detach(), + "{}/kl_loss".format(split): kl_loss.detach().mean(), + "{}/nll_loss".format(split): nll_loss.detach().mean(), + "{}/rec_loss".format(split): rec_loss.detach().mean(), + } + if self.perceptual_weight > 0: + log.update({"{}/p_loss".format(split): p_loss.detach().mean()}) + return loss, log diff --git a/opensora/models/ae/videobase/modeling_videobase.py b/opensora/models/ae/videobase/modeling_videobase.py new file mode 100644 index 0000000000000000000000000000000000000000..4da5ae4c880ed0f13f6e9ab7b3b3d9e88b7ed604 --- /dev/null +++ b/opensora/models/ae/videobase/modeling_videobase.py @@ -0,0 +1,65 @@ +import torch +from diffusers import ModelMixin, ConfigMixin +from torch import nn +import os +import json +import pytorch_lightning as pl +from diffusers.configuration_utils import ConfigMixin +from diffusers.models.modeling_utils import ModelMixin + + +class VideoBaseAE(ModelMixin, ConfigMixin): + _supports_gradient_checkpointing = False + + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + + @classmethod + def load_from_checkpoint(cls, model_path): + with open(os.path.join(model_path, "config.json"), "r") as file: + config = json.load(file) + state_dict = torch.load(os.path.join(model_path, "pytorch_model.bin"), map_location="cpu") + if 'state_dict' in state_dict: + state_dict = state_dict['state_dict'] + model = cls(config=cls.CONFIGURATION_CLS(**config)) + model.load_state_dict(state_dict) + return model + + @classmethod + def download_and_load_model(cls, model_name, cache_dir=None): + pass + + def encode(self, x: torch.Tensor, *args, **kwargs): + pass + + def decode(self, encoding: torch.Tensor, *args, **kwargs): + pass + +class VideoBaseAE_PL(pl.LightningModule, ModelMixin, ConfigMixin): + config_name = "config.json" + + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + + def encode(self, x: torch.Tensor, *args, **kwargs): + pass + + def decode(self, encoding: torch.Tensor, *args, **kwargs): + pass + + @property + def num_training_steps(self) -> int: + """Total training steps inferred from datamodule and devices.""" + if self.trainer.max_steps: + return self.trainer.max_steps + + limit_batches = self.trainer.limit_train_batches + batches = len(self.train_dataloader()) + batches = min(batches, limit_batches) if isinstance(limit_batches, int) else int(limit_batches * batches) + + num_devices = max(1, self.trainer.num_gpus, self.trainer.num_processes) + if self.trainer.tpu_cores: + num_devices = max(num_devices, self.trainer.tpu_cores) + + effective_accum = self.trainer.accumulate_grad_batches * num_devices + return (batches // effective_accum) * self.trainer.max_epochs \ No newline at end of file diff --git a/opensora/models/ae/videobase/modules/__init__.py b/opensora/models/ae/videobase/modules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f25379ef2724d9cabe1fb479e16e4fea884582cc --- /dev/null +++ b/opensora/models/ae/videobase/modules/__init__.py @@ -0,0 +1,23 @@ +from .block import Block +from .attention import ( + AttnBlock3D, + AttnBlock, + LinAttnBlock, + LinearAttention, + TemporalAttnBlock, +) +from .conv import CausalConv3d, Conv2d +from .normalize import GroupNorm, Normalize +from .resnet_block import ResnetBlock2D, ResnetBlock3D +from .updownsample import ( + SpatialDownsample2x, + SpatialUpsample2x, + TimeDownsample2x, + TimeUpsample2x, + Upsample, + Downsample, + TimeDownsampleRes2x, + TimeUpsampleRes2x, + TimeDownsampleResAdv2x, + TimeUpsampleResAdv2x +) diff --git a/opensora/models/ae/videobase/modules/attention.py b/opensora/models/ae/videobase/modules/attention.py new file mode 100644 index 0000000000000000000000000000000000000000..ed33ff0e902494e67f25a48f93f7b599c2a96234 --- /dev/null +++ b/opensora/models/ae/videobase/modules/attention.py @@ -0,0 +1,176 @@ +import torch.nn as nn +from .normalize import Normalize +from .conv import CausalConv3d +import torch +import numpy as np +from einops import rearrange +from .block import Block +from .ops import video_to_image + +class LinearAttention(Block): + def __init__(self, dim, heads=4, dim_head=32): + super().__init__() + self.heads = heads + hidden_dim = dim_head * heads + self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False) + self.to_out = nn.Conv2d(hidden_dim, dim, 1) + + def forward(self, x): + b, c, h, w = x.shape + qkv = self.to_qkv(x) + q, k, v = rearrange( + qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=self.heads, qkv=3 + ) + k = k.softmax(dim=-1) + context = torch.einsum("bhdn,bhen->bhde", k, v) + out = torch.einsum("bhde,bhdn->bhen", context, q) + out = rearrange( + out, "b heads c (h w) -> b (heads c) h w", heads=self.heads, h=h, w=w + ) + return self.to_out(out) + + +class LinAttnBlock(LinearAttention): + """to match AttnBlock usage""" + + def __init__(self, in_channels): + super().__init__(dim=in_channels, heads=1, dim_head=in_channels) + + +class AttnBlock3D(Block): + def __init__(self, in_channels): + super().__init__() + self.in_channels = in_channels + + self.norm = Normalize(in_channels) + self.q = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1) + self.k = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1) + self.v = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1) + self.proj_out = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1) + + def forward(self, x): + h_ = x + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + # compute attention + b, c, t, h, w = q.shape + q = q.reshape(b * t, c, h * w) + q = q.permute(0, 2, 1) # b,hw,c + k = k.reshape(b * t, c, h * w) # b,c,hw + w_ = torch.bmm(q, k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] + w_ = w_ * (int(c) ** (-0.5)) + w_ = torch.nn.functional.softmax(w_, dim=2) + + # attend to values + v = v.reshape(b * t, c, h * w) + w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q) + h_ = torch.bmm(v, w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] + h_ = h_.reshape(b, c, t, h, w) + + h_ = self.proj_out(h_) + + return x + h_ + + +class AttnBlock(Block): + def __init__(self, in_channels): + super().__init__() + self.in_channels = in_channels + + self.norm = Normalize(in_channels) + self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.k = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + self.v = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + self.proj_out = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + + @video_to_image + def forward(self, x): + h_ = x + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + # compute attention + b, c, h, w = q.shape + q = q.reshape(b, c, h * w) + q = q.permute(0, 2, 1) # b,hw,c + k = k.reshape(b, c, h * w) # b,c,hw + w_ = torch.bmm(q, k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] + w_ = w_ * (int(c) ** (-0.5)) + w_ = torch.nn.functional.softmax(w_, dim=2) + + # attend to values + v = v.reshape(b, c, h * w) + w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q) + h_ = torch.bmm(v, w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] + h_ = h_.reshape(b, c, h, w) + + h_ = self.proj_out(h_) + + return x + h_ + + +class TemporalAttnBlock(Block): + def __init__(self, in_channels): + super().__init__() + self.in_channels = in_channels + + self.norm = Normalize(in_channels) + self.q = torch.nn.Conv3d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.k = torch.nn.Conv3d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + self.v = torch.nn.Conv3d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + self.proj_out = torch.nn.Conv3d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + + def forward(self, x): + h_ = x + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + # compute attention + b, c, t, h, w = q.shape + q = rearrange(q, "b c t h w -> (b h w) t c") + k = rearrange(k, "b c t h w -> (b h w) c t") + v = rearrange(v, "b c t h w -> (b h w) c t") + w_ = torch.bmm(q, k) + w_ = w_ * (int(c) ** (-0.5)) + w_ = torch.nn.functional.softmax(w_, dim=2) + + # attend to values + w_ = w_.permute(0, 2, 1) + h_ = torch.bmm(v, w_) + h_ = rearrange(h_, "(b h w) c t -> b c t h w", h=h, w=w) + h_ = self.proj_out(h_) + + return x + h_ + + +def make_attn(in_channels, attn_type="vanilla"): + assert attn_type in ["vanilla", "linear", "none", "vanilla3D"], f"attn_type {attn_type} unknown" + print(f"making attention of type '{attn_type}' with {in_channels} in_channels") + print(attn_type) + if attn_type == "vanilla": + return AttnBlock(in_channels) + elif attn_type == "vanilla3D": + return AttnBlock3D(in_channels) + elif attn_type == "none": + return nn.Identity(in_channels) + else: + return LinAttnBlock(in_channels) \ No newline at end of file diff --git a/opensora/models/ae/videobase/modules/block.py b/opensora/models/ae/videobase/modules/block.py new file mode 100644 index 0000000000000000000000000000000000000000..e93672d20b22cbaaa36b379473149c55f0f44001 --- /dev/null +++ b/opensora/models/ae/videobase/modules/block.py @@ -0,0 +1,5 @@ +import torch.nn as nn + +class Block(nn.Module): + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) \ No newline at end of file diff --git a/opensora/models/ae/videobase/modules/conv.py b/opensora/models/ae/videobase/modules/conv.py new file mode 100644 index 0000000000000000000000000000000000000000..4159fdb0ba52188dc8ea0528c315600a6e43c708 --- /dev/null +++ b/opensora/models/ae/videobase/modules/conv.py @@ -0,0 +1,99 @@ +import torch.nn as nn +from typing import Union, Tuple +import torch.nn.functional as F +import torch +from .block import Block +from .ops import cast_tuple +from einops import rearrange +from .ops import video_to_image + +class Conv2d(nn.Conv2d): + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: Union[int, Tuple[int]] = 3, + stride: Union[int, Tuple[int]] = 1, + padding: Union[str, int, Tuple[int]] = 0, + dilation: Union[int, Tuple[int]] = 1, + groups: int = 1, + bias: bool = True, + padding_mode: str = "zeros", + device=None, + dtype=None, + ) -> None: + super().__init__( + in_channels, + out_channels, + kernel_size, + stride, + padding, + dilation, + groups, + bias, + padding_mode, + device, + dtype, + ) + + @video_to_image + def forward(self, x): + return super().forward(x) + + +class CausalConv3d(nn.Module): + def __init__( + self, chan_in, chan_out, kernel_size: Union[int, Tuple[int, int, int]], init_method="random", **kwargs + ): + super().__init__() + self.kernel_size = cast_tuple(kernel_size, 3) + self.time_kernel_size = self.kernel_size[0] + self.chan_in = chan_in + self.chan_out = chan_out + stride = kwargs.pop("stride", 1) + padding = kwargs.pop("padding", 0) + padding = list(cast_tuple(padding, 3)) + padding[0] = 0 + stride = cast_tuple(stride, 3) + self.conv = nn.Conv3d(chan_in, chan_out, self.kernel_size, stride=stride, padding=padding) + self._init_weights(init_method) + + def _init_weights(self, init_method): + ks = torch.tensor(self.kernel_size) + if init_method == "avg": + assert ( + self.kernel_size[1] == 1 and self.kernel_size[2] == 1 + ), "only support temporal up/down sample" + assert self.chan_in == self.chan_out, "chan_in must be equal to chan_out" + weight = torch.zeros((self.chan_out, self.chan_in, *self.kernel_size)) + + eyes = torch.concat( + [ + torch.eye(self.chan_in).unsqueeze(-1) * 1/3, + torch.eye(self.chan_in).unsqueeze(-1) * 1/3, + torch.eye(self.chan_in).unsqueeze(-1) * 1/3, + ], + dim=-1, + ) + weight[:, :, :, 0, 0] = eyes + + self.conv.weight = nn.Parameter( + weight, + requires_grad=True, + ) + elif init_method == "zero": + self.conv.weight = nn.Parameter( + torch.zeros((self.chan_out, self.chan_in, *self.kernel_size)), + requires_grad=True, + ) + if self.conv.bias is not None: + nn.init.constant_(self.conv.bias, 0) + + def forward(self, x): + # 1 + 16 16 as video, 1 as image + first_frame_pad = x[:, :, :1, :, :].repeat( + (1, 1, self.time_kernel_size - 1, 1, 1) + ) # b c t h w + x = torch.concatenate((first_frame_pad, x), dim=2) + return self.conv(x) + \ No newline at end of file diff --git a/opensora/models/ae/videobase/modules/normalize.py b/opensora/models/ae/videobase/modules/normalize.py new file mode 100644 index 0000000000000000000000000000000000000000..7c8c05f0fa3459214dd077e09a157c588adb637b --- /dev/null +++ b/opensora/models/ae/videobase/modules/normalize.py @@ -0,0 +1,101 @@ +import torch +import torch.nn as nn +from .block import Block + +class GroupNorm(Block): + def __init__(self, num_channels, num_groups=32, eps=1e-6, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + self.norm = torch.nn.GroupNorm( + num_groups=num_groups, num_channels=num_channels, eps=1e-6, affine=True + ) + def forward(self, x): + return self.norm(x) + +def Normalize(in_channels, num_groups=32): + return torch.nn.GroupNorm( + num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True + ) + +class ActNorm(nn.Module): + def __init__(self, num_features, logdet=False, affine=True, + allow_reverse_init=False): + assert affine + super().__init__() + self.logdet = logdet + self.loc = nn.Parameter(torch.zeros(1, num_features, 1, 1)) + self.scale = nn.Parameter(torch.ones(1, num_features, 1, 1)) + self.allow_reverse_init = allow_reverse_init + + self.register_buffer('initialized', torch.tensor(0, dtype=torch.uint8)) + + def initialize(self, input): + with torch.no_grad(): + flatten = input.permute(1, 0, 2, 3).contiguous().view(input.shape[1], -1) + mean = ( + flatten.mean(1) + .unsqueeze(1) + .unsqueeze(2) + .unsqueeze(3) + .permute(1, 0, 2, 3) + ) + std = ( + flatten.std(1) + .unsqueeze(1) + .unsqueeze(2) + .unsqueeze(3) + .permute(1, 0, 2, 3) + ) + + self.loc.data.copy_(-mean) + self.scale.data.copy_(1 / (std + 1e-6)) + + def forward(self, input, reverse=False): + if reverse: + return self.reverse(input) + if len(input.shape) == 2: + input = input[:,:,None,None] + squeeze = True + else: + squeeze = False + + _, _, height, width = input.shape + + if self.training and self.initialized.item() == 0: + self.initialize(input) + self.initialized.fill_(1) + + h = self.scale * (input + self.loc) + + if squeeze: + h = h.squeeze(-1).squeeze(-1) + + if self.logdet: + log_abs = torch.log(torch.abs(self.scale)) + logdet = height*width*torch.sum(log_abs) + logdet = logdet * torch.ones(input.shape[0]).to(input) + return h, logdet + + return h + + def reverse(self, output): + if self.training and self.initialized.item() == 0: + if not self.allow_reverse_init: + raise RuntimeError( + "Initializing ActNorm in reverse direction is " + "disabled by default. Use allow_reverse_init=True to enable." + ) + else: + self.initialize(output) + self.initialized.fill_(1) + + if len(output.shape) == 2: + output = output[:,:,None,None] + squeeze = True + else: + squeeze = False + + h = output / self.scale - self.loc + + if squeeze: + h = h.squeeze(-1).squeeze(-1) + return h diff --git a/opensora/models/ae/videobase/modules/ops.py b/opensora/models/ae/videobase/modules/ops.py new file mode 100644 index 0000000000000000000000000000000000000000..fdd262ad7145eb343838561f119767534761d4d3 --- /dev/null +++ b/opensora/models/ae/videobase/modules/ops.py @@ -0,0 +1,40 @@ +import torch +from einops import rearrange + +def video_to_image(func): + def wrapper(self, x, *args, **kwargs): + if x.dim() == 5: + t = x.shape[2] + x = rearrange(x, "b c t h w -> (b t) c h w") + x = func(self, x, *args, **kwargs) + x = rearrange(x, "(b t) c h w -> b c t h w", t=t) + return x + return wrapper + +def nonlinearity(x): + return x * torch.sigmoid(x) + +def cast_tuple(t, length=1): + return t if isinstance(t, tuple) else ((t,) * length) + +def shift_dim(x, src_dim=-1, dest_dim=-1, make_contiguous=True): + n_dims = len(x.shape) + if src_dim < 0: + src_dim = n_dims + src_dim + if dest_dim < 0: + dest_dim = n_dims + dest_dim + assert 0 <= src_dim < n_dims and 0 <= dest_dim < n_dims + dims = list(range(n_dims)) + del dims[src_dim] + permutation = [] + ctr = 0 + for i in range(n_dims): + if i == dest_dim: + permutation.append(src_dim) + else: + permutation.append(dims[ctr]) + ctr += 1 + x = x.permute(permutation) + if make_contiguous: + x = x.contiguous() + return x \ No newline at end of file diff --git a/opensora/models/ae/videobase/modules/quant.py b/opensora/models/ae/videobase/modules/quant.py new file mode 100644 index 0000000000000000000000000000000000000000..bb702cee5f594e73be3efe9f573f1fbb8032e70c --- /dev/null +++ b/opensora/models/ae/videobase/modules/quant.py @@ -0,0 +1,100 @@ +import torch +import torch.nn as nn +import torch.distributed as dist +import numpy as np +import torch.nn.functional as F +from .ops import shift_dim + +class Codebook(nn.Module): + def __init__(self, n_codes, embedding_dim): + super().__init__() + self.register_buffer("embeddings", torch.randn(n_codes, embedding_dim)) + self.register_buffer("N", torch.zeros(n_codes)) + self.register_buffer("z_avg", self.embeddings.data.clone()) + + self.n_codes = n_codes + self.embedding_dim = embedding_dim + self._need_init = True + + def _tile(self, x): + d, ew = x.shape + if d < self.n_codes: + n_repeats = (self.n_codes + d - 1) // d + std = 0.01 / np.sqrt(ew) + x = x.repeat(n_repeats, 1) + x = x + torch.randn_like(x) * std + return x + + def _init_embeddings(self, z): + # z: [b, c, t, h, w] + self._need_init = False + flat_inputs = shift_dim(z, 1, -1).flatten(end_dim=-2) + y = self._tile(flat_inputs) + + d = y.shape[0] + _k_rand = y[torch.randperm(y.shape[0])][: self.n_codes] + if dist.is_initialized(): + dist.broadcast(_k_rand, 0) + self.embeddings.data.copy_(_k_rand) + self.z_avg.data.copy_(_k_rand) + self.N.data.copy_(torch.ones(self.n_codes)) + + def forward(self, z): + # z: [b, c, t, h, w] + if self._need_init and self.training: + self._init_embeddings(z) + flat_inputs = shift_dim(z, 1, -1).flatten(end_dim=-2) + distances = ( + (flat_inputs**2).sum(dim=1, keepdim=True) + - 2 * flat_inputs @ self.embeddings.t() + + (self.embeddings.t() ** 2).sum(dim=0, keepdim=True) + ) + + encoding_indices = torch.argmin(distances, dim=1) + encode_onehot = F.one_hot(encoding_indices, self.n_codes).type_as(flat_inputs) + encoding_indices = encoding_indices.view(z.shape[0], *z.shape[2:]) + + embeddings = F.embedding(encoding_indices, self.embeddings) + embeddings = shift_dim(embeddings, -1, 1) + + commitment_loss = 0.25 * F.mse_loss(z, embeddings.detach()) + + # EMA codebook update + if self.training: + n_total = encode_onehot.sum(dim=0) + encode_sum = flat_inputs.t() @ encode_onehot + if dist.is_initialized(): + dist.all_reduce(n_total) + dist.all_reduce(encode_sum) + + self.N.data.mul_(0.99).add_(n_total, alpha=0.01) + self.z_avg.data.mul_(0.99).add_(encode_sum.t(), alpha=0.01) + + n = self.N.sum() + weights = (self.N + 1e-7) / (n + self.n_codes * 1e-7) * n + encode_normalized = self.z_avg / weights.unsqueeze(1) + self.embeddings.data.copy_(encode_normalized) + + y = self._tile(flat_inputs) + _k_rand = y[torch.randperm(y.shape[0])][: self.n_codes] + if dist.is_initialized(): + dist.broadcast(_k_rand, 0) + + usage = (self.N.view(self.n_codes, 1) >= 1).float() + self.embeddings.data.mul_(usage).add_(_k_rand * (1 - usage)) + + embeddings_st = (embeddings - z).detach() + z + + avg_probs = torch.mean(encode_onehot, dim=0) + perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10))) + + return dict( + embeddings=embeddings_st, + encodings=encoding_indices, + commitment_loss=commitment_loss, + perplexity=perplexity, + ) + + def dictionary_lookup(self, encodings): + embeddings = F.embedding(encodings, self.embeddings) + return embeddings \ No newline at end of file diff --git a/opensora/models/ae/videobase/modules/resnet_block.py b/opensora/models/ae/videobase/modules/resnet_block.py new file mode 100644 index 0000000000000000000000000000000000000000..189766a5bfc9c1943dab284ef808496710b24994 --- /dev/null +++ b/opensora/models/ae/videobase/modules/resnet_block.py @@ -0,0 +1,86 @@ +import torch +import torch.nn as nn +from einops import rearrange, pack, unpack +from .normalize import Normalize +from .ops import nonlinearity, video_to_image +from .conv import CausalConv3d +from .block import Block + +class ResnetBlock2D(Block): + def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, + dropout): + super().__init__() + self.in_channels = in_channels + self.out_channels = in_channels if out_channels is None else out_channels + self.use_conv_shortcut = conv_shortcut + + self.norm1 = Normalize(in_channels) + self.conv1 = torch.nn.Conv2d( + in_channels, out_channels, kernel_size=3, stride=1, padding=1 + ) + self.norm2 = Normalize(out_channels) + self.dropout = torch.nn.Dropout(dropout) + self.conv2 = torch.nn.Conv2d( + out_channels, out_channels, kernel_size=3, stride=1, padding=1 + ) + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + self.conv_shortcut = torch.nn.Conv2d( + in_channels, out_channels, kernel_size=3, stride=1, padding=1 + ) + else: + self.nin_shortcut = torch.nn.Conv2d( + in_channels, out_channels, kernel_size=1, stride=1, padding=0 + ) + + @video_to_image + def forward(self, x): + h = x + h = self.norm1(h) + h = nonlinearity(h) + h = self.conv1(h) + h = self.norm2(h) + h = nonlinearity(h) + h = self.dropout(h) + h = self.conv2(h) + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + x = self.conv_shortcut(x) + else: + x = self.nin_shortcut(x) + x = x + h + return x + +class ResnetBlock3D(Block): + def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, dropout): + super().__init__() + self.in_channels = in_channels + self.out_channels = in_channels if out_channels is None else out_channels + self.use_conv_shortcut = conv_shortcut + + self.norm1 = Normalize(in_channels) + self.conv1 = CausalConv3d(in_channels, out_channels, 3, padding=1) + self.norm2 = Normalize(out_channels) + self.dropout = torch.nn.Dropout(dropout) + self.conv2 = CausalConv3d(out_channels, out_channels, 3, padding=1) + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + self.conv_shortcut = CausalConv3d(in_channels, out_channels, 3, padding=1) + else: + self.nin_shortcut = CausalConv3d(in_channels, out_channels, 1, padding=0) + + def forward(self, x): + h = x + h = self.norm1(h) + h = nonlinearity(h) + h = self.conv1(h) + h = self.norm2(h) + h = nonlinearity(h) + h = self.dropout(h) + h = self.conv2(h) + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + x = self.conv_shortcut(x) + else: + x = self.nin_shortcut(x) + return x + h \ No newline at end of file diff --git a/opensora/models/ae/videobase/modules/updownsample.py b/opensora/models/ae/videobase/modules/updownsample.py new file mode 100644 index 0000000000000000000000000000000000000000..34e265be379234739128ee6b9167e4c8d49fcb9b --- /dev/null +++ b/opensora/models/ae/videobase/modules/updownsample.py @@ -0,0 +1,232 @@ +from typing import Union, Tuple +import torch +import torch.nn as nn +import torch.nn.functional as F +from .resnet_block import ResnetBlock3D +from .attention import TemporalAttnBlock +from .normalize import Normalize +from .ops import cast_tuple, video_to_image +from .conv import CausalConv3d +from einops import rearrange +from .block import Block + +class Upsample(Block): + def __init__(self, in_channels, out_channels): + super().__init__() + self.with_conv = True + if self.with_conv: + self.conv = torch.nn.Conv2d(in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1) + + @video_to_image + def forward(self, x): + x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") + if self.with_conv: + x = self.conv(x) + return x + +class Downsample(Block): + def __init__(self, in_channels, out_channels): + super().__init__() + self.with_conv = True + if self.with_conv: + # no asymmetric padding in torch conv, must do it ourselves + self.conv = torch.nn.Conv2d(in_channels, + out_channels, + kernel_size=3, + stride=2, + padding=0) + @video_to_image + def forward(self, x): + if self.with_conv: + pad = (0,1,0,1) + x = torch.nn.functional.pad(x, pad, mode="constant", value=0) + x = self.conv(x) + else: + x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2) + return x + +class SpatialDownsample2x(Block): + def __init__( + self, + chan_in, + chan_out, + kernel_size: Union[int, Tuple[int]] = (3, 3), + stride: Union[int, Tuple[int]] = (2, 2), + ): + super().__init__() + kernel_size = cast_tuple(kernel_size, 2) + stride = cast_tuple(stride, 2) + self.chan_in = chan_in + self.chan_out = chan_out + self.kernel_size = kernel_size + self.conv = CausalConv3d( + self.chan_in, + self.chan_out, + (1,) + self.kernel_size, + stride=(1, ) + stride, + padding=0 + ) + + def forward(self, x): + pad = (0,1,0,1,0,0) + x = torch.nn.functional.pad(x, pad, mode="constant", value=0) + x = self.conv(x) + return x + + +class SpatialUpsample2x(Block): + def __init__( + self, + chan_in, + chan_out, + kernel_size: Union[int, Tuple[int]] = (3, 3), + stride: Union[int, Tuple[int]] = (1, 1), + ): + super().__init__() + self.chan_in = chan_in + self.chan_out = chan_out + self.kernel_size = kernel_size + self.conv = CausalConv3d( + self.chan_in, + self.chan_out, + (1,) + self.kernel_size, + stride=(1, ) + stride, + padding=1 + ) + + def forward(self, x): + t = x.shape[2] + x = rearrange(x, "b c t h w -> b (c t) h w") + x = F.interpolate(x, scale_factor=(2,2), mode="nearest") + x = rearrange(x, "b (c t) h w -> b c t h w", t=t) + x = self.conv(x) + return x + +class TimeDownsample2x(Block): + def __init__( + self, + chan_in, + chan_out, + kernel_size: int = 3 + ): + super().__init__() + self.kernel_size = kernel_size + self.conv = nn.AvgPool3d((kernel_size,1,1), stride=(2,1,1)) + + def forward(self, x): + first_frame_pad = x[:, :, :1, :, :].repeat( + (1, 1, self.kernel_size - 1, 1, 1) + ) + x = torch.concatenate((first_frame_pad, x), dim=2) + return self.conv(x) + +class TimeUpsample2x(Block): + def __init__( + self, + chan_in, + chan_out + ): + super().__init__() + def forward(self, x): + if x.size(2) > 1: + x,x_= x[:,:,:1],x[:,:,1:] + x_= F.interpolate(x_, scale_factor=(2,1,1), mode='trilinear') + x = torch.concat([x, x_], dim=2) + return x + +class TimeDownsampleRes2x(nn.Module): + def __init__( + self, + in_channels, + out_channels, + kernel_size: int = 3, + mix_factor: float = 2, + ): + super().__init__() + self.kernel_size = cast_tuple(kernel_size, 3) + self.avg_pool = nn.AvgPool3d((kernel_size,1,1), stride=(2,1,1)) + self.conv = nn.Conv3d( + in_channels, out_channels, self.kernel_size, stride=(2,1,1), padding=(0,1,1) + ) + self.mix_factor = torch.nn.Parameter(torch.Tensor([mix_factor])) + + def forward(self, x): + alpha = torch.sigmoid(self.mix_factor) + first_frame_pad = x[:, :, :1, :, :].repeat( + (1, 1, self.kernel_size[0] - 1, 1, 1) + ) + x = torch.concatenate((first_frame_pad, x), dim=2) + return alpha * self.avg_pool(x) + (1 - alpha) * self.conv(x) + +class TimeUpsampleRes2x(nn.Module): + def __init__( + self, + in_channels, + out_channels, + kernel_size: int = 3, + mix_factor: float = 2, + ): + super().__init__() + self.conv = CausalConv3d( + in_channels, out_channels, kernel_size, padding=1 + ) + self.mix_factor = torch.nn.Parameter(torch.Tensor([mix_factor])) + + def forward(self, x): + alpha = torch.sigmoid(self.mix_factor) + if x.size(2) > 1: + x,x_= x[:,:,:1],x[:,:,1:] + x_= F.interpolate(x_, scale_factor=(2,1,1), mode='trilinear') + x = torch.concat([x, x_], dim=2) + return alpha * x + (1-alpha) * self.conv(x) + +class TimeDownsampleResAdv2x(nn.Module): + def __init__( + self, + in_channels, + out_channels, + kernel_size: int = 3 + ): + super().__init__() + self.kernel_size = cast_tuple(kernel_size, 3) + self.avg_pool = nn.AvgPool3d((kernel_size,1,1), stride=(2,1,1)) + self.attn = TemporalAttnBlock(in_channels) + self.res = ResnetBlock3D(in_channels=in_channels, out_channels=in_channels, dropout=0.0) + self.conv = nn.Conv3d( + in_channels, out_channels, self.kernel_size, stride=(2,1,1), padding=(0,1,1) + ) + self.mix_factor = 1 + + def forward(self, x): + first_frame_pad = x[:, :, :1, :, :].repeat( + (1, 1, self.kernel_size[0] - 1, 1, 1) + ) + x = torch.concatenate((first_frame_pad, x), dim=2) + return self.mix_factor * self.avg_pool(x) + (1 - self.mix_factor) * self.conv(self.attn((self.res(x)))) + +class TimeUpsampleResAdv2x(nn.Module): + def __init__( + self, + in_channels, + out_channels, + kernel_size: int = 3, + ): + super().__init__() + self.res = ResnetBlock3D(in_channels=in_channels, out_channels=in_channels, dropout=0.0) + self.attn = TemporalAttnBlock(in_channels) + self.norm = Normalize(in_channels=in_channels) + self.conv = CausalConv3d( + in_channels, out_channels, kernel_size, padding=1 + ) + self.mix_factor = 1 + + def forward(self, x): + if x.size(2) > 1: + x,x_= x[:,:,:1],x[:,:,1:] + x_= F.interpolate(x_, scale_factor=(2,1,1), mode='trilinear') + x = torch.concat([x, x_], dim=2) + return self.mix_factor * x + (1-self.mix_factor) * self.conv(self.attn(self.res(x))) diff --git a/opensora/models/ae/videobase/trainer_videobase.py b/opensora/models/ae/videobase/trainer_videobase.py new file mode 100644 index 0000000000000000000000000000000000000000..b6c4b0c7d2a2fdec0488cfeedbfa9844416e4053 --- /dev/null +++ b/opensora/models/ae/videobase/trainer_videobase.py @@ -0,0 +1,26 @@ +from transformers import Trainer +import torch.nn.functional as F +from typing import Optional +import os +import torch +from transformers.utils import WEIGHTS_NAME +import json + +class VideoBaseTrainer(Trainer): + + def _save(self, output_dir: Optional[str] = None, state_dict=None): + output_dir = output_dir if output_dir is not None else self.args.output_dir + os.makedirs(output_dir, exist_ok=True) + if state_dict is None: + state_dict = self.model.state_dict() + + # get model config + model_config = self.model.config.to_dict() + + # add more information + model_config['model'] = self.model.__class__.__name__ + + with open(os.path.join(output_dir, "config.json"), "w") as file: + json.dump(self.model.config.to_dict(), file) + torch.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME)) + torch.save(self.args, os.path.join(output_dir, "training_args.bin")) diff --git a/opensora/models/ae/videobase/utils/distrib_utils.py b/opensora/models/ae/videobase/utils/distrib_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..760c0673fe5d8afa663eb1ea5cd7683dbf5dd9f8 --- /dev/null +++ b/opensora/models/ae/videobase/utils/distrib_utils.py @@ -0,0 +1,42 @@ +import torch +import numpy as np + +class DiagonalGaussianDistribution(object): + def __init__(self, parameters, deterministic=False): + self.parameters = parameters + self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) + self.logvar = torch.clamp(self.logvar, -30.0, 20.0) + self.deterministic = deterministic + self.std = torch.exp(0.5 * self.logvar) + self.var = torch.exp(self.logvar) + if self.deterministic: + self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device) + + def sample(self): + x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device) + return x + + def kl(self, other=None): + if self.deterministic: + return torch.Tensor([0.]) + else: + if other is None: + return 0.5 * torch.sum(torch.pow(self.mean, 2) + + self.var - 1.0 - self.logvar, + dim=[1, 2, 3]) + else: + return 0.5 * torch.sum( + torch.pow(self.mean - other.mean, 2) / other.var + + self.var / other.var - 1.0 - self.logvar + other.logvar, + dim=[1, 2, 3]) + + def nll(self, sample, dims=[1,2,3]): + if self.deterministic: + return torch.Tensor([0.]) + logtwopi = np.log(2.0 * np.pi) + return 0.5 * torch.sum( + logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, + dim=dims) + + def mode(self): + return self.mean diff --git a/opensora/models/ae/videobase/utils/module_utils.py b/opensora/models/ae/videobase/utils/module_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..187888aa16a813b866ea23047d37c3a525a03823 --- /dev/null +++ b/opensora/models/ae/videobase/utils/module_utils.py @@ -0,0 +1,17 @@ +import importlib + +Module = str +MODULES_BASE = "opensora.models.ae.videobase.modules." + +def resolve_str_to_obj(str_val, append=True): + if append: + str_val = MODULES_BASE + str_val + module_name, class_name = str_val.rsplit('.', 1) + module = importlib.import_module(module_name) + return getattr(module, class_name) + +def create_instance(module_class_str: str, **kwargs): + module_name, class_name = module_class_str.rsplit('.', 1) + module = importlib.import_module(module_name) + class_ = getattr(module, class_name) + return class_(**kwargs) \ No newline at end of file diff --git a/opensora/models/ae/videobase/utils/scheduler_utils.py b/opensora/models/ae/videobase/utils/scheduler_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..0650a3f79c3bc4f1bd5bf4a556995dc538b2b23a --- /dev/null +++ b/opensora/models/ae/videobase/utils/scheduler_utils.py @@ -0,0 +1,7 @@ +import torch + +def cosine_scheduler(step, max_steps, value_base=1, value_end=0): + step = torch.tensor(step) + cosine_value = 0.5 * (1 + torch.cos(torch.pi * step / max_steps)) + value = value_end + (value_base - value_end) * cosine_value + return value \ No newline at end of file diff --git a/opensora/models/ae/videobase/vqvae/__init__.py b/opensora/models/ae/videobase/vqvae/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ec138ecc722fd5a8a540f363ed383dcb10e93695 --- /dev/null +++ b/opensora/models/ae/videobase/vqvae/__init__.py @@ -0,0 +1,30 @@ +from einops import rearrange +from torch import nn + +from .configuration_vqvae import VQVAEConfiguration +from .modeling_vqvae import VQVAEModel +from .trainer_vqvae import VQVAETrainer + +videovqvae = [ + "bair_stride4x2x2", + "ucf101_stride4x4x4", + "kinetics_stride4x4x4", + "kinetics_stride2x4x4", +] +videovae = [] + +class VQVAEModelWrapper(nn.Module): + def __init__(self, ckpt='kinetics_stride4x4x4'): + super(VQVAEModelWrapper, self).__init__() + if ckpt in videovqvae: + self.vqvae = VQVAEModel.download_and_load_model(ckpt) + else: + self.vqvae = VQVAEModel.load_from_checkpoint(ckpt) + def encode(self, x): # b c t h w + x = self.vqvae.pre_vq_conv(self.vqvae.encoder(x)) + return x + def decode(self, x): + vq_output = self.vqvae.codebook(x) + x = self.vqvae.decoder(self.vqvae.post_vq_conv(vq_output['embeddings'])) + x = rearrange(x, 'b c t h w -> b t c h w').contiguous() + return x diff --git a/opensora/models/ae/videobase/vqvae/configuration_vqvae.py b/opensora/models/ae/videobase/vqvae/configuration_vqvae.py new file mode 100644 index 0000000000000000000000000000000000000000..90ac29cfa5f3c899b3c78b11868cfb30e4908812 --- /dev/null +++ b/opensora/models/ae/videobase/vqvae/configuration_vqvae.py @@ -0,0 +1,33 @@ +from ..configuration_videobase import VideoBaseConfiguration +from typing import Union, Tuple + +class VQVAEConfiguration(VideoBaseConfiguration): + def __init__( + self, + embedding_dim: int = 256, + n_codes: int = 2048, + n_hiddens: int = 240, + n_res_layers: int = 4, + resolution: int = 128, + sequence_length: int = 16, + downsample: Union[Tuple[int, int, int], str] = (4, 4, 4), + no_pos_embd: bool = True, + **kwargs, + ): + super().__init__(**kwargs) + + self.embedding_dim = embedding_dim + self.n_codes = n_codes + self.n_hiddens = n_hiddens + self.n_res_layers = n_res_layers + self.resolution = resolution + self.sequence_length = sequence_length + + if isinstance(downsample, str): + self.downsample = tuple(map(int, downsample.split(","))) + else: + self.downsample = downsample + + self.no_pos_embd = no_pos_embd + + self.hidden_size = n_hiddens diff --git a/opensora/models/ae/videobase/vqvae/modeling_vqvae.py b/opensora/models/ae/videobase/vqvae/modeling_vqvae.py new file mode 100644 index 0000000000000000000000000000000000000000..6a51677e6c947bc52ee991d6d7b0a62813583769 --- /dev/null +++ b/opensora/models/ae/videobase/vqvae/modeling_vqvae.py @@ -0,0 +1,775 @@ +from ..modeling_videobase import VideoBaseAE +import torch +from torch import nn, Tensor +import numpy as np +import torch.distributed as dist +import torch.nn.functional as F +import math +import os +import json +from typing import Tuple, Dict, Union +from .configuration_vqvae import VQVAEConfiguration + + +# Copied from https://github.com/wilson1yan/VideoGPT +def view_range(x, i, j, shape): + shape = tuple(shape) + + n_dims = len(x.shape) + if i < 0: + i = n_dims + i + + if j is None: + j = n_dims + elif j < 0: + j = n_dims + j + + assert 0 <= i < j <= n_dims + + x_shape = x.shape + target_shape = x_shape[:i] + shape + x_shape[j:] + return x.view(target_shape) + + +# Copied from https://github.com/wilson1yan/VideoGPT +def shift_dim(x, src_dim=-1, dest_dim=-1, make_contiguous=True): + n_dims = len(x.shape) + if src_dim < 0: + src_dim = n_dims + src_dim + if dest_dim < 0: + dest_dim = n_dims + dest_dim + assert 0 <= src_dim < n_dims and 0 <= dest_dim < n_dims + dims = list(range(n_dims)) + del dims[src_dim] + permutation = [] + ctr = 0 + for i in range(n_dims): + if i == dest_dim: + permutation.append(src_dim) + else: + permutation.append(dims[ctr]) + ctr += 1 + x = x.permute(permutation) + if make_contiguous: + x = x.contiguous() + return x + + +# Copied from https://github.com/wilson1yan/VideoGPT +def scaled_dot_product_attention(q, k, v, mask=None, attn_dropout=0.0, training=True): + # Performs scaled dot-product attention over the second to last dimension dn + + # (b, n_head, d1, ..., dn, d) + attn = torch.matmul(q, k.transpose(-1, -2)) + attn = attn / np.sqrt(q.shape[-1]) + if mask is not None: + attn = attn.masked_fill(mask == 0, float("-inf")) + attn_float = F.softmax(attn, dim=-1) + attn = attn_float.type_as(attn) # b x n_head x d1 x ... x dn x d + attn = F.dropout(attn, p=attn_dropout, training=training) + + a = torch.matmul(attn, v) # b x n_head x d1 x ... x dn x d + + return a + + +# Copied from https://github.com/wilson1yan/VideoGPT +class AxialBlock(nn.Module): + def __init__(self, n_hiddens, n_head): + super().__init__() + kwargs = dict( + shape=(0,) * 3, + dim_q=n_hiddens, + dim_kv=n_hiddens, + n_head=n_head, + n_layer=1, + causal=False, + attn_type="axial", + ) + self.attn_w = MultiHeadAttention(attn_kwargs=dict(axial_dim=-2), **kwargs) + self.attn_h = MultiHeadAttention(attn_kwargs=dict(axial_dim=-3), **kwargs) + self.attn_t = MultiHeadAttention(attn_kwargs=dict(axial_dim=-4), **kwargs) + + def forward(self, x): + x = shift_dim(x, 1, -1) + x = self.attn_w(x, x, x) + self.attn_h(x, x, x) + self.attn_t(x, x, x) + x = shift_dim(x, -1, 1) + return x + + +# Copied from https://github.com/wilson1yan/VideoGPT +class AttentionResidualBlock(nn.Module): + def __init__(self, n_hiddens): + super().__init__() + self.block = nn.Sequential( + nn.BatchNorm3d(n_hiddens), + nn.ReLU(), + SamePadConv3d(n_hiddens, n_hiddens // 2, 3, bias=False), + nn.BatchNorm3d(n_hiddens // 2), + nn.ReLU(), + SamePadConv3d(n_hiddens // 2, n_hiddens, 1, bias=False), + nn.BatchNorm3d(n_hiddens), + nn.ReLU(), + AxialBlock(n_hiddens, 2), + ) + + def forward(self, x): + return x + self.block(x) + + +# Copied from https://github.com/wilson1yan/VideoGPT +class Codebook(nn.Module): + def __init__(self, n_codes, embedding_dim): + super().__init__() + self.register_buffer("embeddings", torch.randn(n_codes, embedding_dim)) + self.register_buffer("N", torch.zeros(n_codes)) + self.register_buffer("z_avg", self.embeddings.data.clone()) + + self.n_codes = n_codes + self.embedding_dim = embedding_dim + self._need_init = True + + def _tile(self, x): + d, ew = x.shape + if d < self.n_codes: + n_repeats = (self.n_codes + d - 1) // d + std = 0.01 / np.sqrt(ew) + x = x.repeat(n_repeats, 1) + x = x + torch.randn_like(x) * std + return x + + def _init_embeddings(self, z): + # z: [b, c, t, h, w] + self._need_init = False + flat_inputs = shift_dim(z, 1, -1).flatten(end_dim=-2) + y = self._tile(flat_inputs) + + d = y.shape[0] + _k_rand = y[torch.randperm(y.shape[0])][: self.n_codes] + if dist.is_initialized(): + dist.broadcast(_k_rand, 0) + self.embeddings.data.copy_(_k_rand) + self.z_avg.data.copy_(_k_rand) + self.N.data.copy_(torch.ones(self.n_codes)) + + def forward(self, z): + # z: [b, c, t, h, w] + if self._need_init and self.training: + self._init_embeddings(z) + flat_inputs = shift_dim(z, 1, -1).flatten(end_dim=-2) + distances = ( + (flat_inputs**2).sum(dim=1, keepdim=True) + - 2 * flat_inputs @ self.embeddings.t() + + (self.embeddings.t() ** 2).sum(dim=0, keepdim=True) + ) + + encoding_indices = torch.argmin(distances, dim=1) + encode_onehot = F.one_hot(encoding_indices, self.n_codes).type_as(flat_inputs) + encoding_indices = encoding_indices.view(z.shape[0], *z.shape[2:]) + + embeddings = F.embedding(encoding_indices, self.embeddings) + embeddings = shift_dim(embeddings, -1, 1) + + commitment_loss = 0.25 * F.mse_loss(z, embeddings.detach()) + + # EMA codebook update + if self.training: + n_total = encode_onehot.sum(dim=0) + encode_sum = flat_inputs.t() @ encode_onehot + if dist.is_initialized(): + dist.all_reduce(n_total) + dist.all_reduce(encode_sum) + + self.N.data.mul_(0.99).add_(n_total, alpha=0.01) + self.z_avg.data.mul_(0.99).add_(encode_sum.t(), alpha=0.01) + + n = self.N.sum() + weights = (self.N + 1e-7) / (n + self.n_codes * 1e-7) * n + encode_normalized = self.z_avg / weights.unsqueeze(1) + self.embeddings.data.copy_(encode_normalized) + + y = self._tile(flat_inputs) + _k_rand = y[torch.randperm(y.shape[0])][: self.n_codes] + if dist.is_initialized(): + dist.broadcast(_k_rand, 0) + + usage = (self.N.view(self.n_codes, 1) >= 1).float() + self.embeddings.data.mul_(usage).add_(_k_rand * (1 - usage)) + + embeddings_st = (embeddings - z).detach() + z + + avg_probs = torch.mean(encode_onehot, dim=0) + perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10))) + + return dict( + embeddings=embeddings_st, + encodings=encoding_indices, + commitment_loss=commitment_loss, + perplexity=perplexity, + ) + + def dictionary_lookup(self, encodings): + embeddings = F.embedding(encodings, self.embeddings) + return embeddings + + +# Copied from https://github.com/wilson1yan/VideoGPT +class Encoder(nn.Module): + def __init__(self, n_hiddens, n_res_layers, downsample): + super().__init__() + n_times_downsample = np.array([int(math.log2(d)) for d in downsample]) + self.convs = nn.ModuleList() + max_ds = n_times_downsample.max() + for i in range(max_ds): + in_channels = 3 if i == 0 else n_hiddens + stride = tuple([2 if d > 0 else 1 for d in n_times_downsample]) + conv = SamePadConv3d(in_channels, n_hiddens, 4, stride=stride) + self.convs.append(conv) + n_times_downsample -= 1 + self.conv_last = SamePadConv3d(in_channels, n_hiddens, kernel_size=3) + + self.res_stack = nn.Sequential( + *[AttentionResidualBlock(n_hiddens) for _ in range(n_res_layers)], + nn.BatchNorm3d(n_hiddens), + nn.ReLU(), + ) + + def forward(self, x): + h = x + for conv in self.convs: + h = F.relu(conv(h)) + h = self.conv_last(h) + h = self.res_stack(h) + return h + + +# Copied from https://github.com/wilson1yan/VideoGPT +class MultiHeadAttention(nn.Module): + def __init__( + self, shape, dim_q, dim_kv, n_head, n_layer, causal, attn_type, attn_kwargs + ): + super().__init__() + self.causal = causal + self.shape = shape + + self.d_k = dim_q // n_head + self.d_v = dim_kv // n_head + self.n_head = n_head + + self.w_qs = nn.Linear(dim_q, n_head * self.d_k, bias=False) # q + self.w_qs.weight.data.normal_(std=1.0 / np.sqrt(dim_q)) + + self.w_ks = nn.Linear(dim_kv, n_head * self.d_k, bias=False) # k + self.w_ks.weight.data.normal_(std=1.0 / np.sqrt(dim_kv)) + + self.w_vs = nn.Linear(dim_kv, n_head * self.d_v, bias=False) # v + self.w_vs.weight.data.normal_(std=1.0 / np.sqrt(dim_kv)) + + self.fc = nn.Linear(n_head * self.d_v, dim_q, bias=True) # c + self.fc.weight.data.normal_(std=1.0 / np.sqrt(dim_q * n_layer)) + + if attn_type == "full": + self.attn = FullAttention(shape, causal, **attn_kwargs) + elif attn_type == "axial": + assert not causal, "causal axial attention is not supported" + self.attn = AxialAttention(len(shape), **attn_kwargs) + elif attn_type == "sparse": + self.attn = SparseAttention(shape, n_head, causal, **attn_kwargs) + + self.cache = None + + def forward(self, q, k, v, decode_step=None, decode_idx=None): + """Compute multi-head attention + Args + q, k, v: a [b, d1, ..., dn, c] tensor or + a [b, 1, ..., 1, c] tensor if decode_step is not None + + Returns + The output after performing attention + """ + + # compute k, q, v + d_k, d_v, n_head = self.d_k, self.d_v, self.n_head + q = view_range(self.w_qs(q), -1, None, (n_head, d_k)) + k = view_range(self.w_ks(k), -1, None, (n_head, d_k)) + v = view_range(self.w_vs(v), -1, None, (n_head, d_v)) + + # b x n_head x seq_len x d + # (b, *d_shape, n_head, d) -> (b, n_head, *d_shape, d) + q = shift_dim(q, -2, 1) + k = shift_dim(k, -2, 1) + v = shift_dim(v, -2, 1) + + # fast decoding + if decode_step is not None: + if decode_step == 0: + if self.causal: + k_shape = (q.shape[0], n_head, *self.shape, self.d_k) + v_shape = (q.shape[0], n_head, *self.shape, self.d_v) + self.cache = dict( + k=torch.zeros(k_shape, dtype=k.dtype, device=q.device), + v=torch.zeros(v_shape, dtype=v.dtype, device=q.device), + ) + else: + # cache only once in the non-causal case + self.cache = dict(k=k.clone(), v=v.clone()) + if self.causal: + idx = ( + slice(None, None), + slice(None, None), + *[slice(i, i + 1) for i in decode_idx], + ) + self.cache["k"][idx] = k + self.cache["v"][idx] = v + k, v = self.cache["k"], self.cache["v"] + + a = self.attn(q, k, v, decode_step, decode_idx) + + # (b, *d_shape, n_head, d) -> (b, *d_shape, n_head * d) + a = shift_dim(a, 1, -2).flatten(start_dim=-2) + a = self.fc(a) # (b x seq_len x embd_dim) + + return a + + +# Copied from https://github.com/wilson1yan/VideoGPT +class Decoder(nn.Module): + def __init__(self, n_hiddens, n_res_layers, upsample): + super().__init__() + self.res_stack = nn.Sequential( + *[AttentionResidualBlock(n_hiddens) for _ in range(n_res_layers)], + nn.BatchNorm3d(n_hiddens), + nn.ReLU(), + ) + + n_times_upsample = np.array([int(math.log2(d)) for d in upsample]) + max_us = n_times_upsample.max() + self.convts = nn.ModuleList() + for i in range(max_us): + out_channels = 3 if i == max_us - 1 else n_hiddens + us = tuple([2 if d > 0 else 1 for d in n_times_upsample]) + convt = SamePadConvTranspose3d(n_hiddens, out_channels, 4, stride=us) + self.convts.append(convt) + n_times_upsample -= 1 + + def forward(self, x): + h = self.res_stack(x) + for i, convt in enumerate(self.convts): + h = convt(h) + if i < len(self.convts) - 1: + h = F.relu(h) + return h + + +# Copied from https://github.com/wilson1yan/VideoGPT +class SamePadConv3d(nn.Module): + def __init__(self, in_channels, out_channels, kernel_size, stride=1, bias=True): + super().__init__() + if isinstance(kernel_size, int): + kernel_size = (kernel_size,) * 3 + if isinstance(stride, int): + stride = (stride,) * 3 + + # assumes that the input shape is divisible by stride + total_pad = tuple([k - s for k, s in zip(kernel_size, stride)]) + pad_input = [] + for p in total_pad[::-1]: # reverse since F.pad starts from last dim + pad_input.append((p // 2 + p % 2, p // 2)) + pad_input = sum(pad_input, tuple()) + self.pad_input = pad_input + + self.conv = nn.Conv3d( + in_channels, out_channels, kernel_size, stride=stride, padding=0, bias=bias + ) + + def forward(self, x): + return self.conv(F.pad(x, self.pad_input)) + + +# Copied from https://github.com/wilson1yan/VideoGPT +class SamePadConvTranspose3d(nn.Module): + def __init__(self, in_channels, out_channels, kernel_size, stride=1, bias=True): + super().__init__() + if isinstance(kernel_size, int): + kernel_size = (kernel_size,) * 3 + if isinstance(stride, int): + stride = (stride,) * 3 + + total_pad = tuple([k - s for k, s in zip(kernel_size, stride)]) + pad_input = [] + for p in total_pad[::-1]: # reverse since F.pad starts from last dim + pad_input.append((p // 2 + p % 2, p // 2)) + pad_input = sum(pad_input, tuple()) + self.pad_input = pad_input + + self.convt = nn.ConvTranspose3d( + in_channels, + out_channels, + kernel_size, + stride=stride, + bias=bias, + padding=tuple([k - 1 for k in kernel_size]), + ) + + def forward(self, x): + return self.convt(F.pad(x, self.pad_input)) + + +# Copied from https://github.com/wilson1yan/VideoGPT +class FullAttention(nn.Module): + def __init__(self, shape, causal, attn_dropout): + super().__init__() + self.causal = causal + self.attn_dropout = attn_dropout + + seq_len = np.prod(shape) + if self.causal: + self.register_buffer("mask", torch.tril(torch.ones(seq_len, seq_len))) + + def forward(self, q, k, v, decode_step, decode_idx): + mask = self.mask if self.causal else None + if decode_step is not None and mask is not None: + mask = mask[[decode_step]] + + old_shape = q.shape[2:-1] + q = q.flatten(start_dim=2, end_dim=-2) + k = k.flatten(start_dim=2, end_dim=-2) + v = v.flatten(start_dim=2, end_dim=-2) + + out = scaled_dot_product_attention( + q, k, v, mask=mask, attn_dropout=self.attn_dropout, training=self.training + ) + + return view_range(out, 2, 3, old_shape) + + +# Copied from https://github.com/wilson1yan/VideoGPT +class AxialAttention(nn.Module): + def __init__(self, n_dim, axial_dim): + super().__init__() + if axial_dim < 0: + axial_dim = 2 + n_dim + 1 + axial_dim + else: + axial_dim += 2 # account for batch, head, dim + self.axial_dim = axial_dim + + def forward(self, q, k, v, decode_step, decode_idx): + q = shift_dim(q, self.axial_dim, -2).flatten(end_dim=-3) + k = shift_dim(k, self.axial_dim, -2).flatten(end_dim=-3) + v = shift_dim(v, self.axial_dim, -2) + old_shape = list(v.shape) + v = v.flatten(end_dim=-3) + + out = scaled_dot_product_attention(q, k, v, training=self.training) + out = out.view(*old_shape) + out = shift_dim(out, -2, self.axial_dim) + return out + + +# Copied from https://github.com/wilson1yan/VideoGPT +class StridedSparsityConfig(object): + """ + Strided Sparse configuration specified in https://arxiv.org/abs/1904.10509 that + generalizes to arbitrary dimensions + """ + + def __init__(self, shape, n_head, causal, block, num_local_blocks): + self.n_head = n_head + self.shape = shape + self.causal = causal + self.block = block + self.num_local_blocks = num_local_blocks + + assert self.num_local_blocks >= 1, "Must have at least 1 local block" + assert self.seq_len % self.block == 0, "seq len must be divisible by block size" + + self._block_shape = self._compute_block_shape() + self._block_shape_cum = self._block_shape_cum_sizes() + + @property + def seq_len(self): + return np.prod(self.shape) + + @property + def num_blocks(self): + return self.seq_len // self.block + + def set_local_layout(self, layout): + num_blocks = self.num_blocks + for row in range(0, num_blocks): + end = min(row + self.num_local_blocks, num_blocks) + for col in range( + max(0, row - self.num_local_blocks), (row + 1 if self.causal else end) + ): + layout[:, row, col] = 1 + return layout + + def set_global_layout(self, layout): + num_blocks = self.num_blocks + n_dim = len(self._block_shape) + for row in range(num_blocks): + assert self._to_flattened_idx(self._to_unflattened_idx(row)) == row + cur_idx = self._to_unflattened_idx(row) + # no strided attention over last dim + for d in range(n_dim - 1): + end = self._block_shape[d] + for i in range(0, (cur_idx[d] + 1 if self.causal else end)): + new_idx = list(cur_idx) + new_idx[d] = i + new_idx = tuple(new_idx) + + col = self._to_flattened_idx(new_idx) + layout[:, row, col] = 1 + + return layout + + def make_layout(self): + layout = torch.zeros( + (self.n_head, self.num_blocks, self.num_blocks), dtype=torch.int64 + ) + layout = self.set_local_layout(layout) + layout = self.set_global_layout(layout) + return layout + + def make_sparse_attn_mask(self): + block_layout = self.make_layout() + assert block_layout.shape[1] == block_layout.shape[2] == self.num_blocks + + num_dense_blocks = block_layout.sum().item() + attn_mask = torch.ones(num_dense_blocks, self.block, self.block) + counter = 0 + for h in range(self.n_head): + for i in range(self.num_blocks): + for j in range(self.num_blocks): + elem = block_layout[h, i, j].item() + if elem == 1: + assert i >= j + if i == j: # need to mask within block on diagonals + attn_mask[counter] = torch.tril(attn_mask[counter]) + counter += 1 + assert counter == num_dense_blocks + + return attn_mask.unsqueeze(0) + + def get_non_block_layout_row(self, block_layout, row): + block_row = row // self.block + block_row = block_layout[:, [block_row]] # n_head x 1 x n_blocks + block_row = block_row.repeat_interleave(self.block, dim=-1) + block_row[:, :, row + 1 :] = 0.0 + return block_row + + ############# Helper functions ########################## + + def _compute_block_shape(self): + n_dim = len(self.shape) + cum_prod = 1 + for i in range(n_dim - 1, -1, -1): + cum_prod *= self.shape[i] + if cum_prod > self.block: + break + assert cum_prod % self.block == 0 + new_shape = (*self.shape[:i], cum_prod // self.block) + + assert np.prod(new_shape) == np.prod(self.shape) // self.block + + return new_shape + + def _block_shape_cum_sizes(self): + bs = np.flip(np.array(self._block_shape)) + return tuple(np.flip(np.cumprod(bs)[:-1])) + (1,) + + def _to_flattened_idx(self, idx): + assert len(idx) == len( + self._block_shape + ), f"{len(idx)} != {len(self._block_shape)}" + flat_idx = 0 + for i in range(len(self._block_shape)): + flat_idx += idx[i] * self._block_shape_cum[i] + return flat_idx + + def _to_unflattened_idx(self, flat_idx): + assert flat_idx < np.prod(self._block_shape) + idx = [] + for i in range(len(self._block_shape)): + idx.append(flat_idx // self._block_shape_cum[i]) + flat_idx %= self._block_shape_cum[i] + return tuple(idx) + + +# Copied from https://github.com/wilson1yan/VideoGPT +class SparseAttention(nn.Module): + ops = dict() + attn_mask = dict() + block_layout = dict() + + def __init__( + self, shape, n_head, causal, num_local_blocks=4, block=32, attn_dropout=0.0 + ): # does not use attn_dropout + super().__init__() + self.causal = causal + self.shape = shape + + self.sparsity_config = StridedSparsityConfig( + shape=shape, + n_head=n_head, + causal=causal, + block=block, + num_local_blocks=num_local_blocks, + ) + + if self.shape not in SparseAttention.block_layout: + SparseAttention.block_layout[self.shape] = ( + self.sparsity_config.make_layout() + ) + if causal and self.shape not in SparseAttention.attn_mask: + SparseAttention.attn_mask[self.shape] = ( + self.sparsity_config.make_sparse_attn_mask() + ) + + def get_ops(self): + try: + from deepspeed.ops.sparse_attention import MatMul, Softmax + except: + raise Exception( + "Error importing deepspeed. Please install using `DS_BUILD_SPARSE_ATTN=1 pip install deepspeed`" + ) + if self.shape not in SparseAttention.ops: + sparsity_layout = self.sparsity_config.make_layout() + sparse_dot_sdd_nt = MatMul( + sparsity_layout, + self.sparsity_config.block, + "sdd", + trans_a=False, + trans_b=True, + ) + + sparse_dot_dsd_nn = MatMul( + sparsity_layout, + self.sparsity_config.block, + "dsd", + trans_a=False, + trans_b=False, + ) + + sparse_softmax = Softmax(sparsity_layout, self.sparsity_config.block) + + SparseAttention.ops[self.shape] = ( + sparse_dot_sdd_nt, + sparse_dot_dsd_nn, + sparse_softmax, + ) + return SparseAttention.ops[self.shape] + + def forward(self, q, k, v, decode_step, decode_idx): + if self.training and self.shape not in SparseAttention.ops: + self.get_ops() + + SparseAttention.block_layout[self.shape] = SparseAttention.block_layout[ + self.shape + ].to(q) + if self.causal: + SparseAttention.attn_mask[self.shape] = ( + SparseAttention.attn_mask[self.shape].to(q).type_as(q) + ) + attn_mask = SparseAttention.attn_mask[self.shape] if self.causal else None + + old_shape = q.shape[2:-1] + q = q.flatten(start_dim=2, end_dim=-2) + k = k.flatten(start_dim=2, end_dim=-2) + v = v.flatten(start_dim=2, end_dim=-2) + + if decode_step is not None: + mask = self.sparsity_config.get_non_block_layout_row( + SparseAttention.block_layout[self.shape], decode_step + ) + out = scaled_dot_product_attention( + q, k, v, mask=mask, training=self.training + ) + else: + if q.shape != k.shape or k.shape != v.shape: + raise Exception("SparseAttention only support self-attention") + sparse_dot_sdd_nt, sparse_dot_dsd_nn, sparse_softmax = self.get_ops() + scaling = float(q.shape[-1]) ** -0.5 + + attn_output_weights = sparse_dot_sdd_nt(q, k) + if attn_mask is not None: + attn_output_weights = attn_output_weights.masked_fill( + attn_mask == 0, float("-inf") + ) + attn_output_weights = sparse_softmax(attn_output_weights, scale=scaling) + + out = sparse_dot_dsd_nn(attn_output_weights, v) + + return view_range(out, 2, 3, old_shape) + + +# Modified from https://github.com/wilson1yan/VideoGPT +class VQVAEModel(VideoBaseAE): + + DOWNLOADED_VQVAE = { + "bair_stride4x2x2": "1iIAYJ2Qqrx5Q94s5eIXQYJgAydzvT_8L", + "ucf101_stride4x4x4": "1uuB_8WzHP_bbBmfuaIV7PK_Itl3DyHY5", + "kinetics_stride4x4x4": "1DOvOZnFAIQmux6hG7pN_HkyJZy3lXbCB", + "kinetics_stride2x4x4": "1jvtjjtrtE4cy6pl7DK_zWFEPY3RZt2pB", + } + + def __init__(self, config: VQVAEConfiguration): + super().__init__() + self.config = config + self.embedding_dim = config.embedding_dim + self.n_codes = config.n_codes + self.encoder = Encoder(config.n_hiddens, config.n_res_layers, config.downsample) + self.decoder = Decoder(config.n_hiddens, config.n_res_layers, config.downsample) + self.pre_vq_conv = SamePadConv3d(config.n_hiddens, config.embedding_dim, 1) + self.post_vq_conv = SamePadConv3d(config.embedding_dim, config.n_hiddens, 1) + self.codebook = Codebook(config.n_codes, config.embedding_dim) + + def forward(self, x): + z = self.pre_vq_conv(self.encoder(x)) + vq_output = self.codebook(z) + x_recon = self.decoder(self.post_vq_conv(vq_output["embeddings"])) + recon_loss = F.mse_loss(x_recon, x) / 0.06 + return recon_loss, x_recon, vq_output + + def encode(self, x: Tensor, include_embeddings: bool = False) -> Union[Tuple[Tensor, Tensor], Tensor]: + h = self.pre_vq_conv(self.encoder(x)) + vq_output: Dict[str, Tensor] = self.codebook(h) + if include_embeddings: + return vq_output["encodings"], vq_output["embeddings"] + else: + return vq_output["encodings"] + + def decode(self, encodings: Tensor) -> Tensor: + h = F.embedding(encodings, self.codebook.embeddings) + h = self.post_vq_conv(shift_dim(h, -1, 1)) + return self.decoder(h) + + @classmethod + def load_from_checkpoint(cls, model_path): + if not os.path.isdir(model_path): + """model downloaded from internet""" + model_cpkt = torch.load(model_path) + # Compatible with old videogpt model formats. + if "hyper_parameters" in model_cpkt: + hyper_parameters = vars(model_cpkt.get("hyper_parameters").get("args")) + state_dict = model_cpkt.get("state_dict") + model = cls(config=VQVAEConfiguration(**hyper_parameters)) + model.load_state_dict(state_dict) + return model + else: + raise RuntimeError("Model checkpoint has a wrong format.") + else: + with open(os.path.join(model_path, "config.json"), "r") as file: + config = json.load(file) + state_dict = torch.load(os.path.join(model_path, "pytorch_model.bin"), map_location="cpu") + model = cls(config=VQVAEConfiguration(**config)) + model.load_state_dict(state_dict) + return model + + @classmethod + def download_and_load_model(cls, model_name, cache_dir=None): + from .....utils.downloader import gdown_download + path = gdown_download( + cls.DOWNLOADED_VQVAE[model_name], model_name, cache_dir=cache_dir + ) + return cls.load_from_checkpoint(path) diff --git a/opensora/models/ae/videobase/vqvae/trainer_vqvae.py b/opensora/models/ae/videobase/vqvae/trainer_vqvae.py new file mode 100644 index 0000000000000000000000000000000000000000..df3f866eeeea3c50d8de372608ddfd795586efc6 --- /dev/null +++ b/opensora/models/ae/videobase/vqvae/trainer_vqvae.py @@ -0,0 +1,22 @@ +from ..trainer_videobase import VideoBaseTrainer +import torch.nn.functional as F +from typing import Optional +import os +import torch +from transformers.utils import WEIGHTS_NAME +import json + +class VQVAETrainer(VideoBaseTrainer): + + def compute_loss(self, model, inputs, return_outputs=False): + model = model.module + x = inputs.get("video") + x = x / 2 + z = model.pre_vq_conv(model.encoder(x)) + vq_output = model.codebook(z) + x_recon = model.decoder(model.post_vq_conv(vq_output["embeddings"])) + recon_loss = F.mse_loss(x_recon, x) / 0.06 + commitment_loss = vq_output['commitment_loss'] + loss = recon_loss + commitment_loss + return loss + diff --git a/opensora/models/captioner/caption_refiner/README.md b/opensora/models/captioner/caption_refiner/README.md new file mode 100644 index 0000000000000000000000000000000000000000..cf1ae266a19d1a80a609b6b3a5e789bfa4956f55 --- /dev/null +++ b/opensora/models/captioner/caption_refiner/README.md @@ -0,0 +1,38 @@ +# Refiner for Video Caption + +Transform the short caption annotations from video datasets into the long and detailed caption annotations. + +* Add detailed description for background scene. +* Add detailed description for object attributes, including color, material, pose. +* Add detailed description for object-level spatial relationship. + +## 🛠️ Extra Requirements and Installation + +* openai == 0.28.0 +* jsonlines == 4.0.0 +* nltk == 3.8.1 +* Install the LLaMA-Accessory: + +you also need to download the weight of SPHINX to ./ckpt/ folder + +## 🗝️ Refining + +The refining instruction is in [demo_for_refiner.py](demo_for_refiner.py). + +```bash +python demo_for_refiner.py --root_path $path_to_repo$ --api_key $openai_api_key$ +``` + +### Refining Demos + +```bash +[original caption]: A red mustang parked in a showroom with american flags hanging from the ceiling. +``` + +```bash +[refine caption]: This scene depicts a red Mustang parked in a showroom with American flags hanging from the ceiling. The showroom likely serves as a space for showcasing and purchasing cars, and the Mustang is displayed prominently near the flags and ceiling. The scene also features a large window and other objects. Overall, it seems to take place in a car show or dealership. +``` + +- [ ] Add GPT-3.5-Turbo for caption summarization. ⌛ [WIP] +- [ ] Add LLAVA-1.6. ⌛ [WIP] +- [ ] More descriptions. ⌛ [WIP] \ No newline at end of file diff --git a/opensora/models/captioner/caption_refiner/caption_refiner.py b/opensora/models/captioner/caption_refiner/caption_refiner.py new file mode 100644 index 0000000000000000000000000000000000000000..23952f6d9ce504945151619e4e6295360db67d00 --- /dev/null +++ b/opensora/models/captioner/caption_refiner/caption_refiner.py @@ -0,0 +1,122 @@ +import itertools +import numpy as np +from PIL import Image +from PIL import ImageSequence +from nltk import pos_tag, word_tokenize + +from LLaMA2_Accessory.SPHINX import SPHINXModel +from gpt_combinator import caption_summary + +class CaptionRefiner(): + def __init__(self, sample_num, add_detect=True, add_pos=True, add_attr=True, + openai_api_key=None, openai_api_base=None, + ): + self.sample_num = sample_num + self.ADD_DETECTION_OBJ = add_detect + self.ADD_POS = add_pos + self.ADD_ATTR = add_attr + self.openai_api_key = openai_api_key + self.openai_api_base =openai_api_base + + def video_load_split(self, video_path=None): + frame_img_list, sampled_img_list = [], [] + + if ".gif" in video_path: + img = Image.open(video_path) + # process every frame in GIF from to + for frame in ImageSequence.Iterator(img): + frame_np = np.array(frame.copy().convert('RGB').getdata(),dtype=np.uint8).reshape(frame.size[1],frame.size[0],3) + frame_img = Image.fromarray(np.uint8(frame_np)) + frame_img_list.append(frame_img) + elif ".mp4" in video_path: + pass + + # sample frames from the mp4/gif + for i in range(0, len(frame_img_list), int(len(frame_img_list)/self.sample_num)): + sampled_img_list.append(frame_img_list[i]) + + return sampled_img_list # [, ...] + + def caption_refine(self, video_path, org_caption, model_path): + sampled_imgs = self.video_load_split(video_path) + + model = SPHINXModel.from_pretrained( + pretrained_path=model_path, + with_visual=True + ) + + existing_objects, scene_description = [], [] + text = word_tokenize(org_caption) + existing_objects = [word for word,tag in pos_tag(text) if tag in ["NN", "NNS", "NNP"]] + if self.ADD_DETECTION_OBJ: + # Detect the objects and scene in the sampled images + + qas = [["Where is this scene in the picture most likely to take place?", None]] + sc_response = model.generate_response(qas, sampled_imgs[0], max_gen_len=1024, temperature=0.9, top_p=0.5, seed=0) + scene_description.append(sc_response) + + # # Lacking accuracy + # for img in sampled_imgs: + # qas = [["Please detect the objects in the image.", None]] + # response = model.generate_response(qas, img, max_gen_len=1024, temperature=0.9, top_p=0.5, seed=0) + # print(response) + + object_attrs = [] + if self.ADD_ATTR: + # Detailed Description for all the objects in the sampled images + for obj in existing_objects: + obj_attr = [] + for img in sampled_imgs: + qas = [["Please describe the attribute of the {}, including color, position, etc".format(obj), None]] + response = model.generate_response(qas, img, max_gen_len=1024, temperature=0.9, top_p=0.5, seed=0) + obj_attr.append(response) + object_attrs.append({obj : obj_attr}) + + space_relations = [] + if self.ADD_POS: + obj_pairs = list(itertools.combinations(existing_objects, 2)) + # Description for the relationship between each object in the sample images + for obj_pair in obj_pairs: + qas = [["What is the spatial relationship between the {} and the {}? Please describe in lease than twenty words".format(obj_pair[0], obj_pair[1]), None]] + response = model.generate_response(qas, img, max_gen_len=1024, temperature=0.9, top_p=0.5, seed=0) + space_relations.append(response) + + return dict( + org_caption = org_caption, + scene_description = scene_description, + existing_objects = existing_objects, + object_attrs = object_attrs, + space_relations = space_relations, + ) + + def gpt_summary(self, total_captions): + # combine all captions into a detailed long caption + detailed_caption = "" + + if "org_caption" in total_captions.keys(): + detailed_caption += "In summary, "+ total_captions['org_caption'] + + if "scene_description" in total_captions.keys(): + detailed_caption += "We first describe the whole scene. "+total_captions['scene_description'][-1] + + if "existing_objects" in total_captions.keys(): + tmp_sentence = "There are multiple objects in the video, including " + for obj in total_captions['existing_objects']: + tmp_sentence += obj+", " + detailed_caption += tmp_sentence + + # if "object_attrs" in total_captions.keys(): + # caption_summary( + # caption_list="", + # api_key=self.openai_api_key, + # api_base=self.openai_api_base, + # ) + + if "space_relations" in total_captions.keys(): + tmp_sentence = "As for the spatial relationship. " + for sentence in total_captions['space_relations']: tmp_sentence += sentence + detailed_caption += tmp_sentence + + detailed_caption = caption_summary(detailed_caption, self.open_api_key, self.open_api_base) + + return detailed_caption \ No newline at end of file diff --git a/opensora/models/captioner/caption_refiner/dataset/test_videos/captions.json b/opensora/models/captioner/caption_refiner/dataset/test_videos/captions.json new file mode 100644 index 0000000000000000000000000000000000000000..098a352f2e3a6eaebf4ccf7885bb7b2718d44176 --- /dev/null +++ b/opensora/models/captioner/caption_refiner/dataset/test_videos/captions.json @@ -0,0 +1 @@ +{"video1.gif": "A red mustang parked in a showroom with american flags hanging from the ceiling.", "video2.gif": "An aerial view of a city with a river running through it."} \ No newline at end of file diff --git a/opensora/models/captioner/caption_refiner/demo_for_refiner.py b/opensora/models/captioner/caption_refiner/demo_for_refiner.py new file mode 100644 index 0000000000000000000000000000000000000000..c7c0bfc5eb5b42b7da0e3f9ce13373c69cb39bae --- /dev/null +++ b/opensora/models/captioner/caption_refiner/demo_for_refiner.py @@ -0,0 +1,28 @@ +import argparse +from caption_refiner import CaptionRefiner +from gpt_combinator import caption_summary, caption_qa + +def parse_args(): + parser = argparse.ArgumentParser(description="question-answer-generation-using-gpt-3") + parser.add_argument("--root_path", required=True, help="The path to repo.") + parser.add_argument("--api_key", required=True, help="OpenAI API key.") + args = parser.parse_args() + return args + +if __name__ == "__main__": + args = parse_args() + myrefiner = CaptionRefiner( + sample_num=6, add_detect=True, add_pos=True, add_attr=True, + openai_api_key = args.api_key, + openai_api_base = "https://one-api.bltcy.top/v1", + ) + + results = myrefiner.caption_refine( + video_path="./dataset/test_videos/video1.gif", + org_caption="A red mustang parked in a showroom with american flags hanging from the ceiling.", + model_path = args.root_path + "/ckpts/SPHINX-Tiny", + ) + + final_caption = myrefiner.gpt_summary(results) + + print(final_caption) diff --git a/opensora/models/captioner/caption_refiner/gpt_combinator.py b/opensora/models/captioner/caption_refiner/gpt_combinator.py new file mode 100644 index 0000000000000000000000000000000000000000..c0a6f0dff9b2b198533c9741028a75961720fc6e --- /dev/null +++ b/opensora/models/captioner/caption_refiner/gpt_combinator.py @@ -0,0 +1,93 @@ +import openai +import ast + +def caption_qa(caption_list, api_key, api_base): + openai.api_key = api_key + openai.api_base = api_base + + question = "What is the color of a red apple" + answer = "red" + pred = "green" + try: + # Compute the correctness score + completion = openai.ChatCompletion.create( + model="gpt-3.5-turbo", + # model="gpt-4", + # model="gpt-4-vision-compatible", + messages=[ + { + "role": "system", + "content": + "You are an intelligent chatbot designed for evaluating the correctness of generative outputs for question-answer pairs. " + "Your task is to compare the predicted answer with the correct answer and determine if they match meaningfully. Here's how you can accomplish the task:" + "------" + "##INSTRUCTIONS: " + "- Focus on the meaningful match between the predicted answer and the correct answer.\n" + "- Consider synonyms or paraphrases as valid matches.\n" + "- Evaluate the correctness of the prediction compared to the answer." + }, + { + "role": "user", + "content": + "Please evaluate the following video-based question-answer pair:\n\n" + f"Question: {question}\n" + f"Correct Answer: {answer}\n" + f"Predicted Answer: {pred}\n\n" + "Provide your evaluation only as a yes/no and score where the score is an integer value between 0 and 5, with 5 indicating the highest meaningful match. " + "Please generate the response in the form of a Python dictionary string with keys 'pred' and 'score', where value of 'pred' is a string of 'yes' or 'no' and value of 'score' is in INTEGER, not STRING." + "DO NOT PROVIDE ANY OTHER OUTPUT TEXT OR EXPLANATION. Only provide the Python dictionary string. " + "For example, your response should look like this: {'pred': 'yes', 'score': 4.8}." + } + ] + ) + # Convert response to a Python dictionary. + response_message = completion["choices"][0]["message"]["content"] + response_dict = ast.literal_eval(response_message) + print(response_dict) + + except Exception as e: + print(f"Error processing file : {e}") + + +def caption_summary(long_caption, api_key, api_base): + """ + apply GPT3-Turbo as the combination for original caption and the prompted captions for a video + """ + openai.api_key = api_key + openai.api_base = api_base + + try: + # Compute the correctness score + completion = openai.ChatCompletion.create( + model="gpt-3.5-turbo", + messages=[ + { + "role": "system", + "content": + "You are an intelligent chatbot designed for summarizing from a long sentence. " + }, + { + "role": "user", + "content": + "Please summarize the following sentences. Make it shorter than 70 words." + f"the long sentence: {long_caption}\n" + "Provide your summarization with less than 70 words. " + "DO NOT PROVIDE ANY OTHER TEXT OR EXPLANATION. Only provide the summary sentence. " + } + ] + ) + # "Please generate the response in the form of a Python dictionary string with keys 'pred' and 'score', where value of 'pred' is a string of 'yes' or 'no' and value of 'score' is in INTEGER, not STRING." + # "DO NOT PROVIDE ANY OTHER OUTPUT TEXT OR EXPLANATION. Only provide the Python dictionary string. " + # "For example, your response should look like this: {'summary': 'your summary sentence'}." + + # Convert response to a Python dictionary. + response_message = completion["choices"][0]["message"]["content"] + response_dict = ast.literal_eval(response_message) + + except Exception as e: + print(f"Error processing file : {e}") + + return response_dict + +if __name__ == "__main__": + caption_summary() \ No newline at end of file diff --git a/opensora/models/diffusion/__init__.py b/opensora/models/diffusion/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..bf3abffa98328aa734ae824b3902765debec8587 --- /dev/null +++ b/opensora/models/diffusion/__init__.py @@ -0,0 +1,7 @@ + +from .latte.modeling_latte import Latte_models + +Diffusion_models = {} +Diffusion_models.update(Latte_models) + + \ No newline at end of file diff --git a/opensora/models/diffusion/diffusion/__init__.py b/opensora/models/diffusion/diffusion/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..04b2bd3d875d0e9ea9e0059b5f7dc3cea30795dc --- /dev/null +++ b/opensora/models/diffusion/diffusion/__init__.py @@ -0,0 +1,87 @@ +# Modified from OpenAI's diffusion repos +# GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py +# ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion +# IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py + +from .respace import SpacedDiffusion, space_timesteps, SpacedDiffusion_T + + +def create_diffusion( + timestep_respacing, + noise_schedule="linear", + use_kl=False, + sigma_small=False, + predict_xstart=False, + learn_sigma=True, + # learn_sigma=False, + rescale_learned_sigmas=False, + diffusion_steps=1000 +): + from . import gaussian_diffusion as gd + betas = gd.get_named_beta_schedule(noise_schedule, diffusion_steps) + if use_kl: + loss_type = gd.LossType.RESCALED_KL + elif rescale_learned_sigmas: + loss_type = gd.LossType.RESCALED_MSE + else: + loss_type = gd.LossType.MSE + if timestep_respacing is None or timestep_respacing == "": + timestep_respacing = [diffusion_steps] + return SpacedDiffusion( + use_timesteps=space_timesteps(diffusion_steps, timestep_respacing), + betas=betas, + model_mean_type=( + gd.ModelMeanType.EPSILON if not predict_xstart else gd.ModelMeanType.START_X + ), + model_var_type=( + ( + gd.ModelVarType.FIXED_LARGE + if not sigma_small + else gd.ModelVarType.FIXED_SMALL + ) + if not learn_sigma + else gd.ModelVarType.LEARNED_RANGE + ), + loss_type=loss_type + # rescale_timesteps=rescale_timesteps, + ) + +def create_diffusion_T( + timestep_respacing, + noise_schedule="linear", + use_kl=False, + sigma_small=False, + predict_xstart=False, + learn_sigma=True, + # learn_sigma=False, + rescale_learned_sigmas=False, + diffusion_steps=1000 +): + from . import gaussian_diffusion_t2v as gd + betas = gd.get_named_beta_schedule(noise_schedule, diffusion_steps) + if use_kl: + loss_type = gd.LossType.RESCALED_KL + elif rescale_learned_sigmas: + loss_type = gd.LossType.RESCALED_MSE + else: + loss_type = gd.LossType.MSE + if timestep_respacing is None or timestep_respacing == "": + timestep_respacing = [diffusion_steps] + return SpacedDiffusion_T( + use_timesteps=space_timesteps(diffusion_steps, timestep_respacing), + betas=betas, + model_mean_type=( + gd.ModelMeanType.EPSILON if not predict_xstart else gd.ModelMeanType.START_X + ), + model_var_type=( + ( + gd.ModelVarType.FIXED_LARGE + if not sigma_small + else gd.ModelVarType.FIXED_SMALL + ) + if not learn_sigma + else gd.ModelVarType.LEARNED_RANGE + ), + loss_type=loss_type + # rescale_timesteps=rescale_timesteps, + ) diff --git a/opensora/models/diffusion/diffusion/diffusion_utils.py b/opensora/models/diffusion/diffusion/diffusion_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..e493a6a3ecb91e553a53cc7eadee5cc0d1753060 --- /dev/null +++ b/opensora/models/diffusion/diffusion/diffusion_utils.py @@ -0,0 +1,88 @@ +# Modified from OpenAI's diffusion repos +# GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py +# ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion +# IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py + +import torch as th +import numpy as np + + +def normal_kl(mean1, logvar1, mean2, logvar2): + """ + Compute the KL divergence between two gaussians. + Shapes are automatically broadcasted, so batches can be compared to + scalars, among other use cases. + """ + tensor = None + for obj in (mean1, logvar1, mean2, logvar2): + if isinstance(obj, th.Tensor): + tensor = obj + break + assert tensor is not None, "at least one argument must be a Tensor" + + # Force variances to be Tensors. Broadcasting helps convert scalars to + # Tensors, but it does not work for th.exp(). + logvar1, logvar2 = [ + x if isinstance(x, th.Tensor) else th.tensor(x).to(tensor) + for x in (logvar1, logvar2) + ] + + return 0.5 * ( + -1.0 + + logvar2 + - logvar1 + + th.exp(logvar1 - logvar2) + + ((mean1 - mean2) ** 2) * th.exp(-logvar2) + ) + + +def approx_standard_normal_cdf(x): + """ + A fast approximation of the cumulative distribution function of the + standard normal. + """ + return 0.5 * (1.0 + th.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * th.pow(x, 3)))) + + +def continuous_gaussian_log_likelihood(x, *, means, log_scales): + """ + Compute the log-likelihood of a continuous Gaussian distribution. + :param x: the targets + :param means: the Gaussian mean Tensor. + :param log_scales: the Gaussian log stddev Tensor. + :return: a tensor like x of log probabilities (in nats). + """ + centered_x = x - means + inv_stdv = th.exp(-log_scales) + normalized_x = centered_x * inv_stdv + log_probs = th.distributions.Normal(th.zeros_like(x), th.ones_like(x)).log_prob(normalized_x) + return log_probs + + +def discretized_gaussian_log_likelihood(x, *, means, log_scales): + """ + Compute the log-likelihood of a Gaussian distribution discretizing to a + given image. + :param x: the target images. It is assumed that this was uint8 values, + rescaled to the range [-1, 1]. + :param means: the Gaussian mean Tensor. + :param log_scales: the Gaussian log stddev Tensor. + :return: a tensor like x of log probabilities (in nats). + """ + assert x.shape == means.shape == log_scales.shape + centered_x = x - means + inv_stdv = th.exp(-log_scales) + plus_in = inv_stdv * (centered_x + 1.0 / 255.0) + cdf_plus = approx_standard_normal_cdf(plus_in) + min_in = inv_stdv * (centered_x - 1.0 / 255.0) + cdf_min = approx_standard_normal_cdf(min_in) + log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12)) + log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12)) + cdf_delta = cdf_plus - cdf_min + log_probs = th.where( + x < -0.999, + log_cdf_plus, + th.where(x > 0.999, log_one_minus_cdf_min, th.log(cdf_delta.clamp(min=1e-12))), + ) + assert log_probs.shape == x.shape + return log_probs diff --git a/opensora/models/diffusion/diffusion/gaussian_diffusion.py b/opensora/models/diffusion/diffusion/gaussian_diffusion.py new file mode 100644 index 0000000000000000000000000000000000000000..7fc3d43a7fb627e0738d46c0f03cf1ab29b9258f --- /dev/null +++ b/opensora/models/diffusion/diffusion/gaussian_diffusion.py @@ -0,0 +1,881 @@ +# Modified from OpenAI's diffusion repos +# GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py +# ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion +# IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py + + +import math + +import numpy as np +import torch as th +import enum + +from .diffusion_utils import discretized_gaussian_log_likelihood, normal_kl + + +def mean_flat(tensor): + """ + Take the mean over all non-batch dimensions. + """ + return tensor.mean(dim=list(range(1, len(tensor.shape)))) + + +class ModelMeanType(enum.Enum): + """ + Which type of output the model predicts. + """ + + PREVIOUS_X = enum.auto() # the model predicts x_{t-1} + START_X = enum.auto() # the model predicts x_0 + EPSILON = enum.auto() # the model predicts epsilon + + +class ModelVarType(enum.Enum): + """ + What is used as the model's output variance. + The LEARNED_RANGE option has been added to allow the model to predict + values between FIXED_SMALL and FIXED_LARGE, making its job easier. + """ + + LEARNED = enum.auto() + FIXED_SMALL = enum.auto() + FIXED_LARGE = enum.auto() + LEARNED_RANGE = enum.auto() + + +class LossType(enum.Enum): + MSE = enum.auto() # use raw MSE loss (and KL when learning variances) + RESCALED_MSE = ( + enum.auto() + ) # use raw MSE loss (with RESCALED_KL when learning variances) + KL = enum.auto() # use the variational lower-bound + RESCALED_KL = enum.auto() # like KL, but rescale to estimate the full VLB + + def is_vb(self): + return self == LossType.KL or self == LossType.RESCALED_KL + + +def _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, warmup_frac): + betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64) + warmup_time = int(num_diffusion_timesteps * warmup_frac) + betas[:warmup_time] = np.linspace(beta_start, beta_end, warmup_time, dtype=np.float64) + return betas + + +def get_beta_schedule(beta_schedule, *, beta_start, beta_end, num_diffusion_timesteps): + """ + This is the deprecated API for creating beta schedules. + See get_named_beta_schedule() for the new library of schedules. + """ + if beta_schedule == "quad": + betas = ( + np.linspace( + beta_start ** 0.5, + beta_end ** 0.5, + num_diffusion_timesteps, + dtype=np.float64, + ) + ** 2 + ) + elif beta_schedule == "linear": + betas = np.linspace(beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64) + elif beta_schedule == "warmup10": + betas = _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, 0.1) + elif beta_schedule == "warmup50": + betas = _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, 0.5) + elif beta_schedule == "const": + betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64) + elif beta_schedule == "jsd": # 1/T, 1/(T-1), 1/(T-2), ..., 1 + betas = 1.0 / np.linspace( + num_diffusion_timesteps, 1, num_diffusion_timesteps, dtype=np.float64 + ) + else: + raise NotImplementedError(beta_schedule) + assert betas.shape == (num_diffusion_timesteps,) + return betas + + +def get_named_beta_schedule(schedule_name, num_diffusion_timesteps): + """ + Get a pre-defined beta schedule for the given name. + The beta schedule library consists of beta schedules which remain similar + in the limit of num_diffusion_timesteps. + Beta schedules may be added, but should not be removed or changed once + they are committed to maintain backwards compatibility. + """ + if schedule_name == "linear": + # Linear schedule from Ho et al, extended to work for any number of + # diffusion steps. + scale = 1000 / num_diffusion_timesteps + return get_beta_schedule( + "linear", + beta_start=scale * 0.0001, + beta_end=scale * 0.02, + num_diffusion_timesteps=num_diffusion_timesteps, + ) + elif schedule_name == "squaredcos_cap_v2": + return betas_for_alpha_bar( + num_diffusion_timesteps, + lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2, + ) + else: + raise NotImplementedError(f"unknown beta schedule: {schedule_name}") + + +def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999): + """ + Create a beta schedule that discretizes the given alpha_t_bar function, + which defines the cumulative product of (1-beta) over time from t = [0,1]. + :param num_diffusion_timesteps: the number of betas to produce. + :param alpha_bar: a lambda that takes an argument t from 0 to 1 and + produces the cumulative product of (1-beta) up to that + part of the diffusion process. + :param max_beta: the maximum beta to use; use values lower than 1 to + prevent singularities. + """ + betas = [] + for i in range(num_diffusion_timesteps): + t1 = i / num_diffusion_timesteps + t2 = (i + 1) / num_diffusion_timesteps + betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) + return np.array(betas) + + +class GaussianDiffusion: + """ + Utilities for training and sampling diffusion models. + Original ported from this codebase: + https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py#L42 + :param betas: a 1-D numpy array of betas for each diffusion timestep, + starting at T and going to 1. + """ + + def __init__( + self, + *, + betas, + model_mean_type, + model_var_type, + loss_type + ): + + self.model_mean_type = model_mean_type + self.model_var_type = model_var_type + self.loss_type = loss_type + + # Use float64 for accuracy. + betas = np.array(betas, dtype=np.float64) + self.betas = betas + assert len(betas.shape) == 1, "betas must be 1-D" + assert (betas > 0).all() and (betas <= 1).all() + + self.num_timesteps = int(betas.shape[0]) + + alphas = 1.0 - betas + self.alphas_cumprod = np.cumprod(alphas, axis=0) + self.alphas_cumprod_prev = np.append(1.0, self.alphas_cumprod[:-1]) + self.alphas_cumprod_next = np.append(self.alphas_cumprod[1:], 0.0) + assert self.alphas_cumprod_prev.shape == (self.num_timesteps,) + + # calculations for diffusion q(x_t | x_{t-1}) and others + self.sqrt_alphas_cumprod = np.sqrt(self.alphas_cumprod) + self.sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - self.alphas_cumprod) + self.log_one_minus_alphas_cumprod = np.log(1.0 - self.alphas_cumprod) + self.sqrt_recip_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod) + self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod - 1) + + # calculations for posterior q(x_{t-1} | x_t, x_0) + self.posterior_variance = ( + betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod) + ) + # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain + self.posterior_log_variance_clipped = np.log( + np.append(self.posterior_variance[1], self.posterior_variance[1:]) + ) if len(self.posterior_variance) > 1 else np.array([]) + + self.posterior_mean_coef1 = ( + betas * np.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod) + ) + self.posterior_mean_coef2 = ( + (1.0 - self.alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - self.alphas_cumprod) + ) + + def q_mean_variance(self, x_start, t): + """ + Get the distribution q(x_t | x_0). + :param x_start: the [N x C x ...] tensor of noiseless inputs. + :param t: the number of diffusion steps (minus 1). Here, 0 means one step. + :return: A tuple (mean, variance, log_variance), all of x_start's shape. + """ + mean = _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + variance = _extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape) + log_variance = _extract_into_tensor(self.log_one_minus_alphas_cumprod, t, x_start.shape) + return mean, variance, log_variance + + def q_sample(self, x_start, t, noise=None): + """ + Diffuse the data for a given number of diffusion steps. + In other words, sample from q(x_t | x_0). + :param x_start: the initial data batch. + :param t: the number of diffusion steps (minus 1). Here, 0 means one step. + :param noise: if specified, the split-out normal noise. + :return: A noisy version of x_start. + """ + if noise is None: + noise = th.randn_like(x_start) + assert noise.shape == x_start.shape + return ( + _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + + _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise + ) + + def q_posterior_mean_variance(self, x_start, x_t, t): + """ + Compute the mean and variance of the diffusion posterior: + q(x_{t-1} | x_t, x_0) + """ + assert x_start.shape == x_t.shape + posterior_mean = ( + _extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start + + _extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t + ) + posterior_variance = _extract_into_tensor(self.posterior_variance, t, x_t.shape) + posterior_log_variance_clipped = _extract_into_tensor( + self.posterior_log_variance_clipped, t, x_t.shape + ) + assert ( + posterior_mean.shape[0] + == posterior_variance.shape[0] + == posterior_log_variance_clipped.shape[0] + == x_start.shape[0] + ) + return posterior_mean, posterior_variance, posterior_log_variance_clipped + + def p_mean_variance(self, model, x, t, clip_denoised=True, denoised_fn=None, model_kwargs=None): + """ + Apply the model to get p(x_{t-1} | x_t), as well as a prediction of + the initial x, x_0. + :param model: the model, which takes a signal and a batch of timesteps + as input. + :param x: the [N x C x ...] tensor at time t. + :param t: a 1-D Tensor of timesteps. + :param clip_denoised: if True, clip the denoised signal into [-1, 1]. + :param denoised_fn: if not None, a function which applies to the + x_start prediction before it is used to sample. Applies before + clip_denoised. + :param model_kwargs: if not None, a dict of extra keyword arguments to + pass to the model. This can be used for conditioning. + :return: a dict with the following keys: + - 'mean': the model mean output. + - 'variance': the model variance output. + - 'log_variance': the log of 'variance'. + - 'pred_xstart': the prediction for x_0. + """ + if model_kwargs is None: + model_kwargs = {} + + B, F, C = x.shape[:3] + assert t.shape == (B,) + model_output = model(x, t, **model_kwargs) + # try: + # model_output = model_output.sample # for tav unet + # except: + # model_output = model(x, t, **model_kwargs) + if isinstance(model_output, tuple): + model_output, extra = model_output + else: + extra = None + + if self.model_var_type in [ModelVarType.LEARNED, ModelVarType.LEARNED_RANGE]: + assert model_output.shape == (B, F, C * 2, *x.shape[3:]) + model_output, model_var_values = th.split(model_output, C, dim=2) + min_log = _extract_into_tensor(self.posterior_log_variance_clipped, t, x.shape) + max_log = _extract_into_tensor(np.log(self.betas), t, x.shape) + # The model_var_values is [-1, 1] for [min_var, max_var]. + frac = (model_var_values + 1) / 2 + model_log_variance = frac * max_log + (1 - frac) * min_log + model_variance = th.exp(model_log_variance) + else: + model_variance, model_log_variance = { + # for fixedlarge, we set the initial (log-)variance like so + # to get a better decoder log likelihood. + ModelVarType.FIXED_LARGE: ( + np.append(self.posterior_variance[1], self.betas[1:]), + np.log(np.append(self.posterior_variance[1], self.betas[1:])), + ), + ModelVarType.FIXED_SMALL: ( + self.posterior_variance, + self.posterior_log_variance_clipped, + ), + }[self.model_var_type] + model_variance = _extract_into_tensor(model_variance, t, x.shape) + model_log_variance = _extract_into_tensor(model_log_variance, t, x.shape) + + def process_xstart(x): + if denoised_fn is not None: + x = denoised_fn(x) + if clip_denoised: + return x.clamp(-1, 1) + return x + + if self.model_mean_type == ModelMeanType.START_X: + pred_xstart = process_xstart(model_output) + else: + pred_xstart = process_xstart( + self._predict_xstart_from_eps(x_t=x, t=t, eps=model_output) + ) + model_mean, _, _ = self.q_posterior_mean_variance(x_start=pred_xstart, x_t=x, t=t) + + assert model_mean.shape == model_log_variance.shape == pred_xstart.shape == x.shape + return { + "mean": model_mean, + "variance": model_variance, + "log_variance": model_log_variance, + "pred_xstart": pred_xstart, + "extra": extra, + } + + def _predict_xstart_from_eps(self, x_t, t, eps): + assert x_t.shape == eps.shape + return ( + _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t + - _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps + ) + + def _predict_eps_from_xstart(self, x_t, t, pred_xstart): + return ( + _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart + ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) + + def condition_mean(self, cond_fn, p_mean_var, x, t, model_kwargs=None): + """ + Compute the mean for the previous step, given a function cond_fn that + computes the gradient of a conditional log probability with respect to + x. In particular, cond_fn computes grad(log(p(y|x))), and we want to + condition on y. + This uses the conditioning strategy from Sohl-Dickstein et al. (2015). + """ + gradient = cond_fn(x, t, **model_kwargs) + new_mean = p_mean_var["mean"].float() + p_mean_var["variance"] * gradient.float() + return new_mean + + def condition_score(self, cond_fn, p_mean_var, x, t, model_kwargs=None): + """ + Compute what the p_mean_variance output would have been, should the + model's score function be conditioned by cond_fn. + See condition_mean() for details on cond_fn. + Unlike condition_mean(), this instead uses the conditioning strategy + from Song et al (2020). + """ + alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape) + + eps = self._predict_eps_from_xstart(x, t, p_mean_var["pred_xstart"]) + eps = eps - (1 - alpha_bar).sqrt() * cond_fn(x, t, **model_kwargs) + + out = p_mean_var.copy() + out["pred_xstart"] = self._predict_xstart_from_eps(x, t, eps) + out["mean"], _, _ = self.q_posterior_mean_variance(x_start=out["pred_xstart"], x_t=x, t=t) + return out + + def p_sample( + self, + model, + x, + t, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + ): + """ + Sample x_{t-1} from the model at the given timestep. + :param model: the model to sample from. + :param x: the current tensor at x_{t-1}. + :param t: the value of t, starting at 0 for the first diffusion step. + :param clip_denoised: if True, clip the x_start prediction to [-1, 1]. + :param denoised_fn: if not None, a function which applies to the + x_start prediction before it is used to sample. + :param cond_fn: if not None, this is a gradient function that acts + similarly to the model. + :param model_kwargs: if not None, a dict of extra keyword arguments to + pass to the model. This can be used for conditioning. + :return: a dict containing the following keys: + - 'sample': a random sample from the model. + - 'pred_xstart': a prediction of x_0. + """ + out = self.p_mean_variance( + model, + x, + t, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + model_kwargs=model_kwargs, + ) + noise = th.randn_like(x) + nonzero_mask = ( + (t != 0).float().view(-1, *([1] * (len(x.shape) - 1))) + ) # no noise when t == 0 + if cond_fn is not None: + out["mean"] = self.condition_mean(cond_fn, out, x, t, model_kwargs=model_kwargs) + sample = out["mean"] + nonzero_mask * th.exp(0.5 * out["log_variance"]) * noise + return {"sample": sample, "pred_xstart": out["pred_xstart"]} + + def p_sample_loop( + self, + model, + shape, + noise=None, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + device=None, + progress=False, + ): + """ + Generate samples from the model. + :param model: the model module. + :param shape: the shape of the samples, (N, C, H, W). + :param noise: if specified, the noise from the encoder to sample. + Should be of the same shape as `shape`. + :param clip_denoised: if True, clip x_start predictions to [-1, 1]. + :param denoised_fn: if not None, a function which applies to the + x_start prediction before it is used to sample. + :param cond_fn: if not None, this is a gradient function that acts + similarly to the model. + :param model_kwargs: if not None, a dict of extra keyword arguments to + pass to the model. This can be used for conditioning. + :param device: if specified, the device to create the samples on. + If not specified, use a model parameter's device. + :param progress: if True, show a tqdm progress bar. + :return: a non-differentiable batch of samples. + """ + final = None + for sample in self.p_sample_loop_progressive( + model, + shape, + noise=noise, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + cond_fn=cond_fn, + model_kwargs=model_kwargs, + device=device, + progress=progress, + ): + final = sample + return final["sample"] + + def p_sample_loop_progressive( + self, + model, + shape, + noise=None, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + device=None, + progress=False, + ): + """ + Generate samples from the model and yield intermediate samples from + each timestep of diffusion. + Arguments are the same as p_sample_loop(). + Returns a generator over dicts, where each dict is the return value of + p_sample(). + """ + if device is None: + device = next(model.parameters()).device + assert isinstance(shape, (tuple, list)) + if noise is not None: + img = noise + else: + img = th.randn(*shape, device=device) + indices = list(range(self.num_timesteps))[::-1] + + if progress: + # Lazy import so that we don't depend on tqdm. + from tqdm.auto import tqdm + + indices = tqdm(indices) + + for i in indices: + t = th.tensor([i] * shape[0], device=device) + with th.no_grad(): + out = self.p_sample( + model, + img, + t, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + cond_fn=cond_fn, + model_kwargs=model_kwargs, + ) + yield out + img = out["sample"] + + def ddim_sample( + self, + model, + x, + t, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + eta=0.0, + ): + """ + Sample x_{t-1} from the model using DDIM. + Same usage as p_sample(). + """ + out = self.p_mean_variance( + model, + x, + t, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + model_kwargs=model_kwargs, + ) + if cond_fn is not None: + out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs) + + # Usually our model outputs epsilon, but we re-derive it + # in case we used x_start or x_prev prediction. + eps = self._predict_eps_from_xstart(x, t, out["pred_xstart"]) + + alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape) + alpha_bar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t, x.shape) + sigma = ( + eta + * th.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar)) + * th.sqrt(1 - alpha_bar / alpha_bar_prev) + ) + # Equation 12. + noise = th.randn_like(x) + mean_pred = ( + out["pred_xstart"] * th.sqrt(alpha_bar_prev) + + th.sqrt(1 - alpha_bar_prev - sigma ** 2) * eps + ) + nonzero_mask = ( + (t != 0).float().view(-1, *([1] * (len(x.shape) - 1))) + ) # no noise when t == 0 + sample = mean_pred + nonzero_mask * sigma * noise + return {"sample": sample, "pred_xstart": out["pred_xstart"]} + + def ddim_reverse_sample( + self, + model, + x, + t, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + eta=0.0, + ): + """ + Sample x_{t+1} from the model using DDIM reverse ODE. + """ + assert eta == 0.0, "Reverse ODE only for deterministic path" + out = self.p_mean_variance( + model, + x, + t, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + model_kwargs=model_kwargs, + ) + if cond_fn is not None: + out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs) + # Usually our model outputs epsilon, but we re-derive it + # in case we used x_start or x_prev prediction. + eps = ( + _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x.shape) * x + - out["pred_xstart"] + ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x.shape) + alpha_bar_next = _extract_into_tensor(self.alphas_cumprod_next, t, x.shape) + + # Equation 12. reversed + mean_pred = out["pred_xstart"] * th.sqrt(alpha_bar_next) + th.sqrt(1 - alpha_bar_next) * eps + + return {"sample": mean_pred, "pred_xstart": out["pred_xstart"]} + + def ddim_sample_loop( + self, + model, + shape, + noise=None, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + device=None, + progress=False, + eta=0.0, + ): + """ + Generate samples from the model using DDIM. + Same usage as p_sample_loop(). + """ + final = None + for sample in self.ddim_sample_loop_progressive( + model, + shape, + noise=noise, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + cond_fn=cond_fn, + model_kwargs=model_kwargs, + device=device, + progress=progress, + eta=eta, + ): + final = sample + return final["sample"] + + def ddim_sample_loop_progressive( + self, + model, + shape, + noise=None, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + device=None, + progress=False, + eta=0.0, + ): + """ + Use DDIM to sample from the model and yield intermediate samples from + each timestep of DDIM. + Same usage as p_sample_loop_progressive(). + """ + if device is None: + device = next(model.parameters()).device + assert isinstance(shape, (tuple, list)) + if noise is not None: + img = noise + else: + img = th.randn(*shape, device=device) + indices = list(range(self.num_timesteps))[::-1] + + if progress: + # Lazy import so that we don't depend on tqdm. + from tqdm.auto import tqdm + + indices = tqdm(indices) + + for i in indices: + t = th.tensor([i] * shape[0], device=device) + with th.no_grad(): + out = self.ddim_sample( + model, + img, + t, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + cond_fn=cond_fn, + model_kwargs=model_kwargs, + eta=eta, + ) + yield out + img = out["sample"] + + def _vb_terms_bpd( + self, model, x_start, x_t, t, clip_denoised=True, model_kwargs=None + ): + """ + Get a term for the variational lower-bound. + The resulting units are bits (rather than nats, as one might expect). + This allows for comparison to other papers. + :return: a dict with the following keys: + - 'output': a shape [N] tensor of NLLs or KLs. + - 'pred_xstart': the x_0 predictions. + """ + true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance( + x_start=x_start, x_t=x_t, t=t + ) + out = self.p_mean_variance( + model, x_t, t, clip_denoised=clip_denoised, model_kwargs=model_kwargs + ) + kl = normal_kl( + true_mean, true_log_variance_clipped, out["mean"], out["log_variance"] + ) + kl = mean_flat(kl) / np.log(2.0) + + decoder_nll = -discretized_gaussian_log_likelihood( + x_start, means=out["mean"], log_scales=0.5 * out["log_variance"] + ) + assert decoder_nll.shape == x_start.shape + decoder_nll = mean_flat(decoder_nll) / np.log(2.0) + + # At the first timestep return the decoder NLL, + # otherwise return KL(q(x_{t-1}|x_t,x_0) || p(x_{t-1}|x_t)) + output = th.where((t == 0), decoder_nll, kl) + return {"output": output, "pred_xstart": out["pred_xstart"]} + + def training_losses(self, model, x_start, t, model_kwargs=None, noise=None): + """ + Compute training losses for a single timestep. + :param model: the model to evaluate loss on. + :param x_start: the [N x C x ...] tensor of inputs. + :param t: a batch of timestep indices. + :param model_kwargs: if not None, a dict of extra keyword arguments to + pass to the model. This can be used for conditioning. + :param noise: if specified, the specific Gaussian noise to try to remove. + :return: a dict with the key "loss" containing a tensor of shape [N]. + Some mean or variance settings may also have other keys. + """ + if model_kwargs is None: + model_kwargs = {} + if noise is None: + noise = th.randn_like(x_start) + x_t = self.q_sample(x_start, t, noise=noise) + + terms = {} + + if self.loss_type == LossType.KL or self.loss_type == LossType.RESCALED_KL: + terms["loss"] = self._vb_terms_bpd( + model=model, + x_start=x_start, + x_t=x_t, + t=t, + clip_denoised=False, + model_kwargs=model_kwargs, + )["output"] + if self.loss_type == LossType.RESCALED_KL: + terms["loss"] *= self.num_timesteps + elif self.loss_type == LossType.MSE or self.loss_type == LossType.RESCALED_MSE: + model_output = model(x_t, t, **model_kwargs) + # try: + # model_output = model(x_t, t, **model_kwargs).sample # for tav unet + # except: + # model_output = model(x_t, t, **model_kwargs) + + if self.model_var_type in [ + ModelVarType.LEARNED, + ModelVarType.LEARNED_RANGE, + ]: + B, F, C = x_t.shape[:3] + assert model_output.shape == (B, F, C * 2, *x_t.shape[3:]) + model_output, model_var_values = th.split(model_output, C, dim=2) + # Learn the variance using the variational bound, but don't let + # it affect our mean prediction. + frozen_out = th.cat([model_output.detach(), model_var_values], dim=2) + terms["vb"] = self._vb_terms_bpd( + model=lambda *args, r=frozen_out: r, + x_start=x_start, + x_t=x_t, + t=t, + clip_denoised=False, + )["output"] + if self.loss_type == LossType.RESCALED_MSE: + # Divide by 1000 for equivalence with initial implementation. + # Without a factor of 1/1000, the VB term hurts the MSE term. + terms["vb"] *= self.num_timesteps / 1000.0 + + target = { + ModelMeanType.PREVIOUS_X: self.q_posterior_mean_variance( + x_start=x_start, x_t=x_t, t=t + )[0], + ModelMeanType.START_X: x_start, + ModelMeanType.EPSILON: noise, + }[self.model_mean_type] + assert model_output.shape == target.shape == x_start.shape + terms["mse"] = mean_flat((target - model_output) ** 2) + if "vb" in terms: + terms["loss"] = terms["mse"] + terms["vb"] + else: + terms["loss"] = terms["mse"] + else: + raise NotImplementedError(self.loss_type) + + return terms + + def _prior_bpd(self, x_start): + """ + Get the prior KL term for the variational lower-bound, measured in + bits-per-dim. + This term can't be optimized, as it only depends on the encoder. + :param x_start: the [N x C x ...] tensor of inputs. + :return: a batch of [N] KL values (in bits), one per batch element. + """ + batch_size = x_start.shape[0] + t = th.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device) + qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t) + kl_prior = normal_kl( + mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0 + ) + return mean_flat(kl_prior) / np.log(2.0) + + def calc_bpd_loop(self, model, x_start, clip_denoised=True, model_kwargs=None): + """ + Compute the entire variational lower-bound, measured in bits-per-dim, + as well as other related quantities. + :param model: the model to evaluate loss on. + :param x_start: the [N x C x ...] tensor of inputs. + :param clip_denoised: if True, clip denoised samples. + :param model_kwargs: if not None, a dict of extra keyword arguments to + pass to the model. This can be used for conditioning. + :return: a dict containing the following keys: + - total_bpd: the total variational lower-bound, per batch element. + - prior_bpd: the prior term in the lower-bound. + - vb: an [N x T] tensor of terms in the lower-bound. + - xstart_mse: an [N x T] tensor of x_0 MSEs for each timestep. + - mse: an [N x T] tensor of epsilon MSEs for each timestep. + """ + device = x_start.device + batch_size = x_start.shape[0] + + vb = [] + xstart_mse = [] + mse = [] + for t in list(range(self.num_timesteps))[::-1]: + t_batch = th.tensor([t] * batch_size, device=device) + noise = th.randn_like(x_start) + x_t = self.q_sample(x_start=x_start, t=t_batch, noise=noise) + # Calculate VLB term at the current timestep + with th.no_grad(): + out = self._vb_terms_bpd( + model, + x_start=x_start, + x_t=x_t, + t=t_batch, + clip_denoised=clip_denoised, + model_kwargs=model_kwargs, + ) + vb.append(out["output"]) + xstart_mse.append(mean_flat((out["pred_xstart"] - x_start) ** 2)) + eps = self._predict_eps_from_xstart(x_t, t_batch, out["pred_xstart"]) + mse.append(mean_flat((eps - noise) ** 2)) + + vb = th.stack(vb, dim=1) + xstart_mse = th.stack(xstart_mse, dim=1) + mse = th.stack(mse, dim=1) + + prior_bpd = self._prior_bpd(x_start) + total_bpd = vb.sum(dim=1) + prior_bpd + return { + "total_bpd": total_bpd, + "prior_bpd": prior_bpd, + "vb": vb, + "xstart_mse": xstart_mse, + "mse": mse, + } + + +def _extract_into_tensor(arr, timesteps, broadcast_shape): + """ + Extract values from a 1-D numpy array for a batch of indices. + :param arr: the 1-D numpy array. + :param timesteps: a tensor of indices into the array to extract. + :param broadcast_shape: a larger shape of K dimensions with the batch + dimension equal to the length of timesteps. + :return: a tensor of shape [batch_size, 1, ...] where the shape has K dims. + """ + res = th.from_numpy(arr).to(device=timesteps.device)[timesteps].float() + while len(res.shape) < len(broadcast_shape): + res = res[..., None] + return res + th.zeros(broadcast_shape, device=timesteps.device) diff --git a/opensora/models/diffusion/diffusion/gaussian_diffusion_t2v.py b/opensora/models/diffusion/diffusion/gaussian_diffusion_t2v.py new file mode 100644 index 0000000000000000000000000000000000000000..f14377fed0677ed850467dd6f76451fb394b2c49 --- /dev/null +++ b/opensora/models/diffusion/diffusion/gaussian_diffusion_t2v.py @@ -0,0 +1,898 @@ +# Modified from OpenAI's diffusion repos +# GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py +# ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion +# IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py + + +import math + +import numpy as np +import torch as th +import enum + +from .diffusion_utils import discretized_gaussian_log_likelihood, normal_kl + + +def mean_flat(tensor): + """ + Take the mean over all non-batch dimensions. + """ + return tensor.mean(dim=list(range(1, len(tensor.shape)))) + + +class ModelMeanType(enum.Enum): + """ + Which type of output the model predicts. + """ + + PREVIOUS_X = enum.auto() # the model predicts x_{t-1} + START_X = enum.auto() # the model predicts x_0 + EPSILON = enum.auto() # the model predicts epsilon + + +class ModelVarType(enum.Enum): + """ + What is used as the model's output variance. + The LEARNED_RANGE option has been added to allow the model to predict + values between FIXED_SMALL and FIXED_LARGE, making its job easier. + """ + + LEARNED = enum.auto() + FIXED_SMALL = enum.auto() + FIXED_LARGE = enum.auto() + LEARNED_RANGE = enum.auto() + + +class LossType(enum.Enum): + MSE = enum.auto() # use raw MSE loss (and KL when learning variances) + RESCALED_MSE = ( + enum.auto() + ) # use raw MSE loss (with RESCALED_KL when learning variances) + KL = enum.auto() # use the variational lower-bound + RESCALED_KL = enum.auto() # like KL, but rescale to estimate the full VLB + + def is_vb(self): + return self == LossType.KL or self == LossType.RESCALED_KL + + +def _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, warmup_frac): + betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64) + warmup_time = int(num_diffusion_timesteps * warmup_frac) + betas[:warmup_time] = np.linspace(beta_start, beta_end, warmup_time, dtype=np.float64) + return betas + + +def get_beta_schedule(beta_schedule, *, beta_start, beta_end, num_diffusion_timesteps): + """ + This is the deprecated API for creating beta schedules. + See get_named_beta_schedule() for the new library of schedules. + """ + if beta_schedule == "quad": + betas = ( + np.linspace( + beta_start ** 0.5, + beta_end ** 0.5, + num_diffusion_timesteps, + dtype=np.float64, + ) + ** 2 + ) + elif beta_schedule == "linear": + betas = np.linspace(beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64) + elif beta_schedule == "warmup10": + betas = _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, 0.1) + elif beta_schedule == "warmup50": + betas = _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, 0.5) + elif beta_schedule == "const": + betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64) + elif beta_schedule == "jsd": # 1/T, 1/(T-1), 1/(T-2), ..., 1 + betas = 1.0 / np.linspace( + num_diffusion_timesteps, 1, num_diffusion_timesteps, dtype=np.float64 + ) + else: + raise NotImplementedError(beta_schedule) + assert betas.shape == (num_diffusion_timesteps,) + return betas + + +def get_named_beta_schedule(schedule_name, num_diffusion_timesteps): + """ + Get a pre-defined beta schedule for the given name. + The beta schedule library consists of beta schedules which remain similar + in the limit of num_diffusion_timesteps. + Beta schedules may be added, but should not be removed or changed once + they are committed to maintain backwards compatibility. + """ + if schedule_name == "linear": + # Linear schedule from Ho et al, extended to work for any number of + # diffusion steps. + scale = 1000 / num_diffusion_timesteps + return get_beta_schedule( + "linear", + beta_start=scale * 0.0001, + beta_end=scale * 0.02, + num_diffusion_timesteps=num_diffusion_timesteps, + ) + elif schedule_name == "squaredcos_cap_v2": + return betas_for_alpha_bar( + num_diffusion_timesteps, + lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2, + ) + else: + raise NotImplementedError(f"unknown beta schedule: {schedule_name}") + + +def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999): + """ + Create a beta schedule that discretizes the given alpha_t_bar function, + which defines the cumulative product of (1-beta) over time from t = [0,1]. + :param num_diffusion_timesteps: the number of betas to produce. + :param alpha_bar: a lambda that takes an argument t from 0 to 1 and + produces the cumulative product of (1-beta) up to that + part of the diffusion process. + :param max_beta: the maximum beta to use; use values lower than 1 to + prevent singularities. + """ + betas = [] + for i in range(num_diffusion_timesteps): + t1 = i / num_diffusion_timesteps + t2 = (i + 1) / num_diffusion_timesteps + betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) + return np.array(betas) + + +class GaussianDiffusion_T: + """ + Utilities for training and sampling diffusion models. + Original ported from this codebase: + https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py#L42 + :param betas: a 1-D numpy array of betas for each diffusion timestep, + starting at T and going to 1. + """ + + def __init__( + self, + *, + betas, + model_mean_type, + model_var_type, + loss_type + ): + + self.model_mean_type = model_mean_type + self.model_var_type = model_var_type + self.loss_type = loss_type + + # Use float64 for accuracy. + betas = np.array(betas, dtype=np.float64) + self.betas = betas + assert len(betas.shape) == 1, "betas must be 1-D" + assert (betas > 0).all() and (betas <= 1).all() + + self.num_timesteps = int(betas.shape[0]) + + alphas = 1.0 - betas + self.alphas_cumprod = np.cumprod(alphas, axis=0) + self.alphas_cumprod_prev = np.append(1.0, self.alphas_cumprod[:-1]) + self.alphas_cumprod_next = np.append(self.alphas_cumprod[1:], 0.0) + assert self.alphas_cumprod_prev.shape == (self.num_timesteps,) + + # calculations for diffusion q(x_t | x_{t-1}) and others + self.sqrt_alphas_cumprod = np.sqrt(self.alphas_cumprod) + self.sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - self.alphas_cumprod) + self.log_one_minus_alphas_cumprod = np.log(1.0 - self.alphas_cumprod) + self.sqrt_recip_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod) + self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod - 1) + + # calculations for posterior q(x_{t-1} | x_t, x_0) + self.posterior_variance = ( + betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod) + ) + # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain + self.posterior_log_variance_clipped = np.log( + np.append(self.posterior_variance[1], self.posterior_variance[1:]) + ) if len(self.posterior_variance) > 1 else np.array([]) + + self.posterior_mean_coef1 = ( + betas * np.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod) + ) + self.posterior_mean_coef2 = ( + (1.0 - self.alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - self.alphas_cumprod) + ) + + def q_mean_variance(self, x_start, t): + """ + Get the distribution q(x_t | x_0). + :param x_start: the [N x C x ...] tensor of noiseless inputs. + :param t: the number of diffusion steps (minus 1). Here, 0 means one step. + :return: A tuple (mean, variance, log_variance), all of x_start's shape. + """ + mean = _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + variance = _extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape) + log_variance = _extract_into_tensor(self.log_one_minus_alphas_cumprod, t, x_start.shape) + return mean, variance, log_variance + + def q_sample(self, x_start, t, noise=None): + """ + Diffuse the data for a given number of diffusion steps. + In other words, sample from q(x_t | x_0). + :param x_start: the initial data batch. + :param t: the number of diffusion steps (minus 1). Here, 0 means one step. + :param noise: if specified, the split-out normal noise. + :return: A noisy version of x_start. + """ + if noise is None: + noise = th.randn_like(x_start) + assert noise.shape == x_start.shape + return ( + _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + + _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise + ) + + def q_posterior_mean_variance(self, x_start, x_t, t): + """ + Compute the mean and variance of the diffusion posterior: + q(x_{t-1} | x_t, x_0) + """ + assert x_start.shape == x_t.shape + posterior_mean = ( + _extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start + + _extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t + ) + posterior_variance = _extract_into_tensor(self.posterior_variance, t, x_t.shape) + posterior_log_variance_clipped = _extract_into_tensor( + self.posterior_log_variance_clipped, t, x_t.shape + ) + assert ( + posterior_mean.shape[0] + == posterior_variance.shape[0] + == posterior_log_variance_clipped.shape[0] + == x_start.shape[0] + ) + return posterior_mean, posterior_variance, posterior_log_variance_clipped + + def p_mean_variance(self, model, x, t, clip_denoised=True, denoised_fn=None, model_kwargs=None): + """ + Apply the model to get p(x_{t-1} | x_t), as well as a prediction of + the initial x, x_0. + :param model: the model, which takes a signal and a batch of timesteps + as input. + :param x: the [N x C x ...] tensor at time t. + :param t: a 1-D Tensor of timesteps. + :param clip_denoised: if True, clip the denoised signal into [-1, 1]. + :param denoised_fn: if not None, a function which applies to the + x_start prediction before it is used to sample. Applies before + clip_denoised. + :param model_kwargs: if not None, a dict of extra keyword arguments to + pass to the model. This can be used for conditioning. + :return: a dict with the following keys: + - 'mean': the model mean output. + - 'variance': the model variance output. + - 'log_variance': the log of 'variance'. + - 'pred_xstart': the prediction for x_0. + """ + if model_kwargs is None: + model_kwargs = {} + + #B, F, C = x.shape[:3] + B, C, F = x.shape[:3] + assert t.shape == (B,) + model_output = model(x, t, **model_kwargs) + + try: + model_output.shape + except: + model_output = model_output[0] + # try: + # model_output = model_output.sample # for tav unet + # except: + # model_output = model(x, t, **model_kwargs) + if isinstance(model_output, tuple): + model_output, extra = model_output + else: + extra = None + + if self.model_var_type in [ModelVarType.LEARNED, ModelVarType.LEARNED_RANGE]: + #assert model_output.shape == (B, F, C * 2, *x.shape[3:]) + #model_output, model_var_values = th.split(model_output, C, dim=2) + #the output shape of uncondition or class condition latte is not the same as the latte_t2v + #BFCHW vs BCFHW + assert model_output.shape == (B, C * 2, F, *x.shape[3:]), f'model_output.shape ({model_output.shape}), != {(B, C * 2, F, *x.shape[3:])}' + model_output, model_var_values = th.split(model_output, C, dim=1) + min_log = _extract_into_tensor(self.posterior_log_variance_clipped, t, x.shape) + max_log = _extract_into_tensor(np.log(self.betas), t, x.shape) + # The model_var_values is [-1, 1] for [min_var, max_var]. + frac = (model_var_values + 1) / 2 + model_log_variance = frac * max_log + (1 - frac) * min_log + model_variance = th.exp(model_log_variance) + else: + model_variance, model_log_variance = { + # for fixedlarge, we set the initial (log-)variance like so + # to get a better decoder log likelihood. + ModelVarType.FIXED_LARGE: ( + np.append(self.posterior_variance[1], self.betas[1:]), + np.log(np.append(self.posterior_variance[1], self.betas[1:])), + ), + ModelVarType.FIXED_SMALL: ( + self.posterior_variance, + self.posterior_log_variance_clipped, + ), + }[self.model_var_type] + model_variance = _extract_into_tensor(model_variance, t, x.shape) + model_log_variance = _extract_into_tensor(model_log_variance, t, x.shape) + + def process_xstart(x): + if denoised_fn is not None: + x = denoised_fn(x) + if clip_denoised: + return x.clamp(-1, 1) + return x + + if self.model_mean_type == ModelMeanType.START_X: + pred_xstart = process_xstart(model_output) + else: + pred_xstart = process_xstart( + self._predict_xstart_from_eps(x_t=x, t=t, eps=model_output) + ) + model_mean, _, _ = self.q_posterior_mean_variance(x_start=pred_xstart, x_t=x, t=t) + + assert model_mean.shape == model_log_variance.shape == pred_xstart.shape == x.shape + return { + "mean": model_mean, + "variance": model_variance, + "log_variance": model_log_variance, + "pred_xstart": pred_xstart, + "extra": extra, + } + + def _predict_xstart_from_eps(self, x_t, t, eps): + assert x_t.shape == eps.shape + return ( + _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t + - _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps + ) + + def _predict_eps_from_xstart(self, x_t, t, pred_xstart): + return ( + _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart + ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) + + def condition_mean(self, cond_fn, p_mean_var, x, t, model_kwargs=None): + """ + Compute the mean for the previous step, given a function cond_fn that + computes the gradient of a conditional log probability with respect to + x. In particular, cond_fn computes grad(log(p(y|x))), and we want to + condition on y. + This uses the conditioning strategy from Sohl-Dickstein et al. (2015). + """ + gradient = cond_fn(x, t, **model_kwargs) + new_mean = p_mean_var["mean"].float() + p_mean_var["variance"] * gradient.float() + return new_mean + + def condition_score(self, cond_fn, p_mean_var, x, t, model_kwargs=None): + """ + Compute what the p_mean_variance output would have been, should the + model's score function be conditioned by cond_fn. + See condition_mean() for details on cond_fn. + Unlike condition_mean(), this instead uses the conditioning strategy + from Song et al (2020). + """ + alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape) + + eps = self._predict_eps_from_xstart(x, t, p_mean_var["pred_xstart"]) + eps = eps - (1 - alpha_bar).sqrt() * cond_fn(x, t, **model_kwargs) + + out = p_mean_var.copy() + out["pred_xstart"] = self._predict_xstart_from_eps(x, t, eps) + out["mean"], _, _ = self.q_posterior_mean_variance(x_start=out["pred_xstart"], x_t=x, t=t) + return out + + def p_sample( + self, + model, + x, + t, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + ): + """ + Sample x_{t-1} from the model at the given timestep. + :param model: the model to sample from. + :param x: the current tensor at x_{t-1}. + :param t: the value of t, starting at 0 for the first diffusion step. + :param clip_denoised: if True, clip the x_start prediction to [-1, 1]. + :param denoised_fn: if not None, a function which applies to the + x_start prediction before it is used to sample. + :param cond_fn: if not None, this is a gradient function that acts + similarly to the model. + :param model_kwargs: if not None, a dict of extra keyword arguments to + pass to the model. This can be used for conditioning. + :return: a dict containing the following keys: + - 'sample': a random sample from the model. + - 'pred_xstart': a prediction of x_0. + """ + out = self.p_mean_variance( + model, + x, + t, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + model_kwargs=model_kwargs, + ) + noise = th.randn_like(x) + nonzero_mask = ( + (t != 0).float().view(-1, *([1] * (len(x.shape) - 1))) + ) # no noise when t == 0 + if cond_fn is not None: + out["mean"] = self.condition_mean(cond_fn, out, x, t, model_kwargs=model_kwargs) + sample = out["mean"] + nonzero_mask * th.exp(0.5 * out["log_variance"]) * noise + return {"sample": sample, "pred_xstart": out["pred_xstart"]} + + def p_sample_loop( + self, + model, + shape, + noise=None, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + device=None, + progress=False, + ): + """ + Generate samples from the model. + :param model: the model module. + :param shape: the shape of the samples, (N, C, H, W). + :param noise: if specified, the noise from the encoder to sample. + Should be of the same shape as `shape`. + :param clip_denoised: if True, clip x_start predictions to [-1, 1]. + :param denoised_fn: if not None, a function which applies to the + x_start prediction before it is used to sample. + :param cond_fn: if not None, this is a gradient function that acts + similarly to the model. + :param model_kwargs: if not None, a dict of extra keyword arguments to + pass to the model. This can be used for conditioning. + :param device: if specified, the device to create the samples on. + If not specified, use a model parameter's device. + :param progress: if True, show a tqdm progress bar. + :return: a non-differentiable batch of samples. + """ + final = None + for sample in self.p_sample_loop_progressive( + model, + shape, + noise=noise, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + cond_fn=cond_fn, + model_kwargs=model_kwargs, + device=device, + progress=progress, + ): + final = sample + return final["sample"] + + def p_sample_loop_progressive( + self, + model, + shape, + noise=None, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + device=None, + progress=False, + ): + """ + Generate samples from the model and yield intermediate samples from + each timestep of diffusion. + Arguments are the same as p_sample_loop(). + Returns a generator over dicts, where each dict is the return value of + p_sample(). + """ + if device is None: + device = next(model.parameters()).device + assert isinstance(shape, (tuple, list)) + if noise is not None: + img = noise + else: + img = th.randn(*shape, device=device) + indices = list(range(self.num_timesteps))[::-1] + + if progress: + # Lazy import so that we don't depend on tqdm. + from tqdm.auto import tqdm + + indices = tqdm(indices) + + for i in indices: + t = th.tensor([i] * shape[0], device=device) + with th.no_grad(): + out = self.p_sample( + model, + img, + t, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + cond_fn=cond_fn, + model_kwargs=model_kwargs, + ) + yield out + img = out["sample"] + + def ddim_sample( + self, + model, + x, + t, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + eta=0.0, + ): + """ + Sample x_{t-1} from the model using DDIM. + Same usage as p_sample(). + """ + out = self.p_mean_variance( + model, + x, + t, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + model_kwargs=model_kwargs, + ) + if cond_fn is not None: + out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs) + + # Usually our model outputs epsilon, but we re-derive it + # in case we used x_start or x_prev prediction. + eps = self._predict_eps_from_xstart(x, t, out["pred_xstart"]) + + alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape) + alpha_bar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t, x.shape) + sigma = ( + eta + * th.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar)) + * th.sqrt(1 - alpha_bar / alpha_bar_prev) + ) + # Equation 12. + noise = th.randn_like(x) + mean_pred = ( + out["pred_xstart"] * th.sqrt(alpha_bar_prev) + + th.sqrt(1 - alpha_bar_prev - sigma ** 2) * eps + ) + nonzero_mask = ( + (t != 0).float().view(-1, *([1] * (len(x.shape) - 1))) + ) # no noise when t == 0 + sample = mean_pred + nonzero_mask * sigma * noise + return {"sample": sample, "pred_xstart": out["pred_xstart"]} + + def ddim_reverse_sample( + self, + model, + x, + t, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + eta=0.0, + ): + """ + Sample x_{t+1} from the model using DDIM reverse ODE. + """ + assert eta == 0.0, "Reverse ODE only for deterministic path" + out = self.p_mean_variance( + model, + x, + t, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + model_kwargs=model_kwargs, + ) + if cond_fn is not None: + out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs) + # Usually our model outputs epsilon, but we re-derive it + # in case we used x_start or x_prev prediction. + eps = ( + _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x.shape) * x + - out["pred_xstart"] + ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x.shape) + alpha_bar_next = _extract_into_tensor(self.alphas_cumprod_next, t, x.shape) + + # Equation 12. reversed + mean_pred = out["pred_xstart"] * th.sqrt(alpha_bar_next) + th.sqrt(1 - alpha_bar_next) * eps + + return {"sample": mean_pred, "pred_xstart": out["pred_xstart"]} + + def ddim_sample_loop( + self, + model, + shape, + noise=None, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + device=None, + progress=False, + eta=0.0, + ): + """ + Generate samples from the model using DDIM. + Same usage as p_sample_loop(). + """ + final = None + for sample in self.ddim_sample_loop_progressive( + model, + shape, + noise=noise, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + cond_fn=cond_fn, + model_kwargs=model_kwargs, + device=device, + progress=progress, + eta=eta, + ): + final = sample + return final["sample"] + + def ddim_sample_loop_progressive( + self, + model, + shape, + noise=None, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + device=None, + progress=False, + eta=0.0, + ): + """ + Use DDIM to sample from the model and yield intermediate samples from + each timestep of DDIM. + Same usage as p_sample_loop_progressive(). + """ + if device is None: + device = next(model.parameters()).device + assert isinstance(shape, (tuple, list)) + if noise is not None: + img = noise + else: + img = th.randn(*shape, device=device) + indices = list(range(self.num_timesteps))[::-1] + + if progress: + # Lazy import so that we don't depend on tqdm. + from tqdm.auto import tqdm + + indices = tqdm(indices) + + for i in indices: + t = th.tensor([i] * shape[0], device=device) + with th.no_grad(): + out = self.ddim_sample( + model, + img, + t, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + cond_fn=cond_fn, + model_kwargs=model_kwargs, + eta=eta, + ) + yield out + img = out["sample"] + + def _vb_terms_bpd( + self, model, x_start, x_t, t, clip_denoised=True, model_kwargs=None + ): + """ + Get a term for the variational lower-bound. + The resulting units are bits (rather than nats, as one might expect). + This allows for comparison to other papers. + :return: a dict with the following keys: + - 'output': a shape [N] tensor of NLLs or KLs. + - 'pred_xstart': the x_0 predictions. + """ + true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance( + x_start=x_start, x_t=x_t, t=t + ) + out = self.p_mean_variance( + model, x_t, t, clip_denoised=clip_denoised, model_kwargs=model_kwargs + ) + kl = normal_kl( + true_mean, true_log_variance_clipped, out["mean"], out["log_variance"] + ) + kl = mean_flat(kl) / np.log(2.0) + + decoder_nll = -discretized_gaussian_log_likelihood( + x_start, means=out["mean"], log_scales=0.5 * out["log_variance"] + ) + assert decoder_nll.shape == x_start.shape + decoder_nll = mean_flat(decoder_nll) / np.log(2.0) + + # At the first timestep return the decoder NLL, + # otherwise return KL(q(x_{t-1}|x_t,x_0) || p(x_{t-1}|x_t)) + output = th.where((t == 0), decoder_nll, kl) + return {"output": output, "pred_xstart": out["pred_xstart"]} + + def training_losses(self, model, x_start, t, model_kwargs=None, noise=None): + """ + Compute training losses for a single timestep. + :param model: the model to evaluate loss on. + :param x_start: the [N x C x ...] tensor of inputs. + :param t: a batch of timestep indices. + :param model_kwargs: if not None, a dict of extra keyword arguments to + pass to the model. This can be used for conditioning. + :param noise: if specified, the specific Gaussian noise to try to remove. + :return: a dict with the key "loss" containing a tensor of shape [N]. + Some mean or variance settings may also have other keys. + """ + if model_kwargs is None: + model_kwargs = {} + if noise is None: + noise = th.randn_like(x_start) + x_t = self.q_sample(x_start, t, noise=noise) + + terms = {} + + if self.loss_type == LossType.KL or self.loss_type == LossType.RESCALED_KL: + terms["loss"] = self._vb_terms_bpd( + model=model, + x_start=x_start, + x_t=x_t, + t=t, + clip_denoised=False, + model_kwargs=model_kwargs, + )["output"] + if self.loss_type == LossType.RESCALED_KL: + terms["loss"] *= self.num_timesteps + elif self.loss_type == LossType.MSE or self.loss_type == LossType.RESCALED_MSE: + model_output = model(x_t, t, **model_kwargs) + # try: + # model_output = model(x_t, t, **model_kwargs).sample # for tav unet + # except: + # model_output = model(x_t, t, **model_kwargs) + + if self.model_var_type in [ + ModelVarType.LEARNED, + ModelVarType.LEARNED_RANGE, + ]: + #B, F, C = x_t.shape[:3] + #assert model_output.shape == (B, F, C * 2, *x_t.shape[3:]) + #the output shape of uncondition or class condition latte is not the same as the latte_t2v + #BFCHW vs BCFHW + B, C, F = x_t.shape[:3] + assert model_output[0].shape == (B, C * 2, F, *x_t.shape[3:]) + #model_output, model_var_values = th.split(model_output, C, dim=2) + model_output, model_var_values = th.split(model_output[0], C, dim=1) + + # Learn the variance using the variational bound, but don't let + # it affect our mean prediction. + #frozen_out = th.cat([model_output.detach(), model_var_values], dim=2) + frozen_out = th.cat([model_output.detach(), model_var_values], dim=1) + terms["vb"] = self._vb_terms_bpd( + model=lambda *args, r=frozen_out: r, + x_start=x_start, + x_t=x_t, + t=t, + clip_denoised=False, + )["output"] + if self.loss_type == LossType.RESCALED_MSE: + # Divide by 1000 for equivalence with initial implementation. + # Without a factor of 1/1000, the VB term hurts the MSE term. + terms["vb"] *= self.num_timesteps / 1000.0 + + target = { + ModelMeanType.PREVIOUS_X: self.q_posterior_mean_variance( + x_start=x_start, x_t=x_t, t=t + )[0], + ModelMeanType.START_X: x_start, + ModelMeanType.EPSILON: noise, + }[self.model_mean_type] + assert model_output.shape == target.shape == x_start.shape + terms["mse"] = mean_flat((target - model_output) ** 2) + if "vb" in terms: + terms["loss"] = terms["mse"] + terms["vb"] + else: + terms["loss"] = terms["mse"] + else: + raise NotImplementedError(self.loss_type) + + return terms + + def _prior_bpd(self, x_start): + """ + Get the prior KL term for the variational lower-bound, measured in + bits-per-dim. + This term can't be optimized, as it only depends on the encoder. + :param x_start: the [N x C x ...] tensor of inputs. + :return: a batch of [N] KL values (in bits), one per batch element. + """ + batch_size = x_start.shape[0] + t = th.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device) + qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t) + kl_prior = normal_kl( + mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0 + ) + return mean_flat(kl_prior) / np.log(2.0) + + def calc_bpd_loop(self, model, x_start, clip_denoised=True, model_kwargs=None): + """ + Compute the entire variational lower-bound, measured in bits-per-dim, + as well as other related quantities. + :param model: the model to evaluate loss on. + :param x_start: the [N x C x ...] tensor of inputs. + :param clip_denoised: if True, clip denoised samples. + :param model_kwargs: if not None, a dict of extra keyword arguments to + pass to the model. This can be used for conditioning. + :return: a dict containing the following keys: + - total_bpd: the total variational lower-bound, per batch element. + - prior_bpd: the prior term in the lower-bound. + - vb: an [N x T] tensor of terms in the lower-bound. + - xstart_mse: an [N x T] tensor of x_0 MSEs for each timestep. + - mse: an [N x T] tensor of epsilon MSEs for each timestep. + """ + device = x_start.device + batch_size = x_start.shape[0] + + vb = [] + xstart_mse = [] + mse = [] + for t in list(range(self.num_timesteps))[::-1]: + t_batch = th.tensor([t] * batch_size, device=device) + noise = th.randn_like(x_start) + x_t = self.q_sample(x_start=x_start, t=t_batch, noise=noise) + # Calculate VLB term at the current timestep + with th.no_grad(): + out = self._vb_terms_bpd( + model, + x_start=x_start, + x_t=x_t, + t=t_batch, + clip_denoised=clip_denoised, + model_kwargs=model_kwargs, + ) + vb.append(out["output"]) + xstart_mse.append(mean_flat((out["pred_xstart"] - x_start) ** 2)) + eps = self._predict_eps_from_xstart(x_t, t_batch, out["pred_xstart"]) + mse.append(mean_flat((eps - noise) ** 2)) + + vb = th.stack(vb, dim=1) + xstart_mse = th.stack(xstart_mse, dim=1) + mse = th.stack(mse, dim=1) + + prior_bpd = self._prior_bpd(x_start) + total_bpd = vb.sum(dim=1) + prior_bpd + return { + "total_bpd": total_bpd, + "prior_bpd": prior_bpd, + "vb": vb, + "xstart_mse": xstart_mse, + "mse": mse, + } + + +def _extract_into_tensor(arr, timesteps, broadcast_shape): + """ + Extract values from a 1-D numpy array for a batch of indices. + :param arr: the 1-D numpy array. + :param timesteps: a tensor of indices into the array to extract. + :param broadcast_shape: a larger shape of K dimensions with the batch + dimension equal to the length of timesteps. + :return: a tensor of shape [batch_size, 1, ...] where the shape has K dims. + """ + res = th.from_numpy(arr).to(device=timesteps.device)[timesteps].float() + while len(res.shape) < len(broadcast_shape): + res = res[..., None] + return res + th.zeros(broadcast_shape, device=timesteps.device) diff --git a/opensora/models/diffusion/diffusion/respace.py b/opensora/models/diffusion/diffusion/respace.py new file mode 100644 index 0000000000000000000000000000000000000000..aed6ed77f3dd6d38f15e450058ce7fc13d5dc3dc --- /dev/null +++ b/opensora/models/diffusion/diffusion/respace.py @@ -0,0 +1,198 @@ +# Modified from OpenAI's diffusion repos +# GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py +# ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion +# IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py +import torch +import numpy as np +import torch as th + +from .gaussian_diffusion import GaussianDiffusion +from .gaussian_diffusion_t2v import GaussianDiffusion_T + + +def space_timesteps(num_timesteps, section_counts): + """ + Create a list of timesteps to use from an original diffusion process, + given the number of timesteps we want to take from equally-sized portions + of the original process. + For example, if there's 300 timesteps and the section counts are [10,15,20] + then the first 100 timesteps are strided to be 10 timesteps, the second 100 + are strided to be 15 timesteps, and the final 100 are strided to be 20. + If the stride is a string starting with "ddim", then the fixed striding + from the DDIM paper is used, and only one section is allowed. + :param num_timesteps: the number of diffusion steps in the original + process to divide up. + :param section_counts: either a list of numbers, or a string containing + comma-separated numbers, indicating the step count + per section. As a special case, use "ddimN" where N + is a number of steps to use the striding from the + DDIM paper. + :return: a set of diffusion steps from the original process to use. + """ + if isinstance(section_counts, str): + if section_counts.startswith("ddim"): + desired_count = int(section_counts[len("ddim") :]) + for i in range(1, num_timesteps): + if len(range(0, num_timesteps, i)) == desired_count: + return set(range(0, num_timesteps, i)) + raise ValueError( + f"cannot create exactly {num_timesteps} steps with an integer stride" + ) + section_counts = [int(x) for x in section_counts.split(",")] + size_per = num_timesteps // len(section_counts) + extra = num_timesteps % len(section_counts) + start_idx = 0 + all_steps = [] + for i, section_count in enumerate(section_counts): + size = size_per + (1 if i < extra else 0) + if size < section_count: + raise ValueError( + f"cannot divide section of {size} steps into {section_count}" + ) + if section_count <= 1: + frac_stride = 1 + else: + frac_stride = (size - 1) / (section_count - 1) + cur_idx = 0.0 + taken_steps = [] + for _ in range(section_count): + taken_steps.append(start_idx + round(cur_idx)) + cur_idx += frac_stride + all_steps += taken_steps + start_idx += size + return set(all_steps) + + +class SpacedDiffusion(GaussianDiffusion): + """ + A diffusion process which can skip steps in a base diffusion process. + :param use_timesteps: a collection (sequence or set) of timesteps from the + original diffusion process to retain. + :param kwargs: the kwargs to create the base diffusion process. + """ + + def __init__(self, use_timesteps, **kwargs): + self.use_timesteps = set(use_timesteps) + self.timestep_map = [] + self.original_num_steps = len(kwargs["betas"]) + + base_diffusion = GaussianDiffusion(**kwargs) # pylint: disable=missing-kwoa + last_alpha_cumprod = 1.0 + new_betas = [] + for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod): + if i in self.use_timesteps: + new_betas.append(1 - alpha_cumprod / last_alpha_cumprod) + last_alpha_cumprod = alpha_cumprod + self.timestep_map.append(i) + kwargs["betas"] = np.array(new_betas) + super().__init__(**kwargs) + + def p_mean_variance( + self, model, *args, **kwargs + ): # pylint: disable=signature-differs + return super().p_mean_variance(self._wrap_model(model), *args, **kwargs) + + # @torch.compile + def training_losses( + self, model, *args, **kwargs + ): # pylint: disable=signature-differs + return super().training_losses(self._wrap_model(model), *args, **kwargs) + + def condition_mean(self, cond_fn, *args, **kwargs): + return super().condition_mean(self._wrap_model(cond_fn), *args, **kwargs) + + def condition_score(self, cond_fn, *args, **kwargs): + return super().condition_score(self._wrap_model(cond_fn), *args, **kwargs) + + def _wrap_model(self, model): + if isinstance(model, _WrappedModel): + return model + return _WrappedModel( + model, self.timestep_map, self.original_num_steps + ) + + def _scale_timesteps(self, t): + # Scaling is done by the wrapped model. + return t + + +class _WrappedModel: + def __init__(self, model, timestep_map, original_num_steps): + self.model = model + self.timestep_map = timestep_map + # self.rescale_timesteps = rescale_timesteps + self.original_num_steps = original_num_steps + + def __call__(self, x, ts, **kwargs): + map_tensor = th.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype) + new_ts = map_tensor[ts] + # if self.rescale_timesteps: + # new_ts = new_ts.float() * (1000.0 / self.original_num_steps) + return self.model(x, new_ts, **kwargs) + +class SpacedDiffusion_T(GaussianDiffusion_T): + """ + A diffusion process which can skip steps in a base diffusion process. + :param use_timesteps: a collection (sequence or set) of timesteps from the + original diffusion process to retain. + :param kwargs: the kwargs to create the base diffusion process. + """ + + def __init__(self, use_timesteps, **kwargs): + self.use_timesteps = set(use_timesteps) + self.timestep_map = [] + self.original_num_steps = len(kwargs["betas"]) + + base_diffusion = GaussianDiffusion(**kwargs) # pylint: disable=missing-kwoa + last_alpha_cumprod = 1.0 + new_betas = [] + for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod): + if i in self.use_timesteps: + new_betas.append(1 - alpha_cumprod / last_alpha_cumprod) + last_alpha_cumprod = alpha_cumprod + self.timestep_map.append(i) + kwargs["betas"] = np.array(new_betas) + super().__init__(**kwargs) + + def p_mean_variance( + self, model, *args, **kwargs + ): # pylint: disable=signature-differs + return super().p_mean_variance(self._wrap_model(model), *args, **kwargs) + + # @torch.compile + def training_losses( + self, model, *args, **kwargs + ): # pylint: disable=signature-differs + return super().training_losses(self._wrap_model(model), *args, **kwargs) + + def condition_mean(self, cond_fn, *args, **kwargs): + return super().condition_mean(self._wrap_model(cond_fn), *args, **kwargs) + + def condition_score(self, cond_fn, *args, **kwargs): + return super().condition_score(self._wrap_model(cond_fn), *args, **kwargs) + + def _wrap_model(self, model): + if isinstance(model, _WrappedModel): + return model + return _WrappedModel( + model, self.timestep_map, self.original_num_steps + ) + + def _scale_timesteps(self, t): + # Scaling is done by the wrapped model. + return t + + +class _WrappedModel: + def __init__(self, model, timestep_map, original_num_steps): + self.model = model + self.timestep_map = timestep_map + # self.rescale_timesteps = rescale_timesteps + self.original_num_steps = original_num_steps + + def __call__(self, x, ts, **kwargs): + map_tensor = th.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype) + new_ts = map_tensor[ts] + # if self.rescale_timesteps: + # new_ts = new_ts.float() * (1000.0 / self.original_num_steps) + return self.model(x, new_ts, **kwargs) \ No newline at end of file diff --git a/opensora/models/diffusion/diffusion/timestep_sampler.py b/opensora/models/diffusion/diffusion/timestep_sampler.py new file mode 100644 index 0000000000000000000000000000000000000000..a3f369847677d8dbaaadb8297691b1be92cf189f --- /dev/null +++ b/opensora/models/diffusion/diffusion/timestep_sampler.py @@ -0,0 +1,150 @@ +# Modified from OpenAI's diffusion repos +# GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py +# ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion +# IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py + +from abc import ABC, abstractmethod + +import numpy as np +import torch as th +import torch.distributed as dist + + +def create_named_schedule_sampler(name, diffusion): + """ + Create a ScheduleSampler from a library of pre-defined samplers. + :param name: the name of the sampler. + :param diffusion: the diffusion object to sample for. + """ + if name == "uniform": + return UniformSampler(diffusion) + elif name == "loss-second-moment": + return LossSecondMomentResampler(diffusion) + else: + raise NotImplementedError(f"unknown schedule sampler: {name}") + + +class ScheduleSampler(ABC): + """ + A distribution over timesteps in the diffusion process, intended to reduce + variance of the objective. + By default, samplers perform unbiased importance sampling, in which the + objective's mean is unchanged. + However, subclasses may override sample() to change how the resampled + terms are reweighted, allowing for actual changes in the objective. + """ + + @abstractmethod + def weights(self): + """ + Get a numpy array of weights, one per diffusion step. + The weights needn't be normalized, but must be positive. + """ + + def sample(self, batch_size, device): + """ + Importance-sample timesteps for a batch. + :param batch_size: the number of timesteps. + :param device: the torch device to save to. + :return: a tuple (timesteps, weights): + - timesteps: a tensor of timestep indices. + - weights: a tensor of weights to scale the resulting losses. + """ + w = self.weights() + p = w / np.sum(w) + indices_np = np.random.choice(len(p), size=(batch_size,), p=p) + indices = th.from_numpy(indices_np).long().to(device) + weights_np = 1 / (len(p) * p[indices_np]) + weights = th.from_numpy(weights_np).float().to(device) + return indices, weights + + +class UniformSampler(ScheduleSampler): + def __init__(self, diffusion): + self.diffusion = diffusion + self._weights = np.ones([diffusion.num_timesteps]) + + def weights(self): + return self._weights + + +class LossAwareSampler(ScheduleSampler): + def update_with_local_losses(self, local_ts, local_losses): + """ + Update the reweighting using losses from a model. + Call this method from each rank with a batch of timesteps and the + corresponding losses for each of those timesteps. + This method will perform synchronization to make sure all of the ranks + maintain the exact same reweighting. + :param local_ts: an integer Tensor of timesteps. + :param local_losses: a 1D Tensor of losses. + """ + batch_sizes = [ + th.tensor([0], dtype=th.int32, device=local_ts.device) + for _ in range(dist.get_world_size()) + ] + dist.all_gather( + batch_sizes, + th.tensor([len(local_ts)], dtype=th.int32, device=local_ts.device), + ) + + # Pad all_gather batches to be the maximum batch size. + batch_sizes = [x.item() for x in batch_sizes] + max_bs = max(batch_sizes) + + timestep_batches = [th.zeros(max_bs).to(local_ts) for bs in batch_sizes] + loss_batches = [th.zeros(max_bs).to(local_losses) for bs in batch_sizes] + dist.all_gather(timestep_batches, local_ts) + dist.all_gather(loss_batches, local_losses) + timesteps = [ + x.item() for y, bs in zip(timestep_batches, batch_sizes) for x in y[:bs] + ] + losses = [x.item() for y, bs in zip(loss_batches, batch_sizes) for x in y[:bs]] + self.update_with_all_losses(timesteps, losses) + + @abstractmethod + def update_with_all_losses(self, ts, losses): + """ + Update the reweighting using losses from a model. + Sub-classes should override this method to update the reweighting + using losses from the model. + This method directly updates the reweighting without synchronizing + between workers. It is called by update_with_local_losses from all + ranks with identical arguments. Thus, it should have deterministic + behavior to maintain state across workers. + :param ts: a list of int timesteps. + :param losses: a list of float losses, one per timestep. + """ + + +class LossSecondMomentResampler(LossAwareSampler): + def __init__(self, diffusion, history_per_term=10, uniform_prob=0.001): + self.diffusion = diffusion + self.history_per_term = history_per_term + self.uniform_prob = uniform_prob + self._loss_history = np.zeros( + [diffusion.num_timesteps, history_per_term], dtype=np.float64 + ) + self._loss_counts = np.zeros([diffusion.num_timesteps], dtype=np.int) + + def weights(self): + if not self._warmed_up(): + return np.ones([self.diffusion.num_timesteps], dtype=np.float64) + weights = np.sqrt(np.mean(self._loss_history ** 2, axis=-1)) + weights /= np.sum(weights) + weights *= 1 - self.uniform_prob + weights += self.uniform_prob / len(weights) + return weights + + def update_with_all_losses(self, ts, losses): + for t, loss in zip(ts, losses): + if self._loss_counts[t] == self.history_per_term: + # Shift out the oldest loss term. + self._loss_history[t, :-1] = self._loss_history[t, 1:] + self._loss_history[t, -1] = loss + else: + self._loss_history[t, self._loss_counts[t]] = loss + self._loss_counts[t] += 1 + + def _warmed_up(self): + return (self._loss_counts == self.history_per_term).all() diff --git a/opensora/models/diffusion/latte/modeling_latte.py b/opensora/models/diffusion/latte/modeling_latte.py new file mode 100644 index 0000000000000000000000000000000000000000..3993dae8121f9090889c161b9e2162a49ee2a41b --- /dev/null +++ b/opensora/models/diffusion/latte/modeling_latte.py @@ -0,0 +1,1183 @@ +import torch + +import os +import json + +from dataclasses import dataclass +from einops import rearrange, repeat +from typing import Any, Dict, Optional, Tuple +from diffusers.models import Transformer2DModel +from diffusers.utils import USE_PEFT_BACKEND, BaseOutput, deprecate +from diffusers.models.embeddings import get_1d_sincos_pos_embed_from_grid, ImagePositionalEmbeddings, CaptionProjection +# from diffusers.models.embeddings import PatchEmbed, CombinedTimestepSizeEmbeddings +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.models.modeling_utils import ModelMixin +from diffusers.models.lora import LoRACompatibleConv, LoRACompatibleLinear + +import torch +import torch.nn.functional as F +from torch import nn + +from opensora.models.diffusion.utils.pos_embed import get_1d_sincos_pos_embed +from .modules import PatchEmbed, BasicTransformerBlock, BasicTransformerBlock_, AdaLayerNormSingle, Transformer3DModelOutput + + +class Latte(ModelMixin, ConfigMixin): + _supports_gradient_checkpointing = True + + """ + A 2D Transformer model for image-like data. + + Parameters: + num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention. + attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head. + in_channels (`int`, *optional*): + The number of channels in the input and output (specify if the input is **continuous**). + num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use. + sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**). + This is fixed during training since it is used to learn a number of position embeddings. + num_vector_embeds (`int`, *optional*): + The number of classes of the vector embeddings of the latent pixels (specify if the input is **discrete**). + Includes the class for the masked latent pixel. + activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward. + num_embeds_ada_norm ( `int`, *optional*): + The number of diffusion steps used during training. Pass if at least one of the norm_layers is + `AdaLayerNorm`. This is fixed during training since it is used to learn a number of embeddings that are + added to the hidden states. + + During inference, you can denoise for up to but not more steps than `num_embeds_ada_norm`. + attention_bias (`bool`, *optional*): + Configure if the `TransformerBlocks` attention should contain a bias parameter. + """ + + @register_to_config + def __init__( + self, + num_attention_heads: int = 16, + patch_size_t: int = 1, + attention_head_dim: int = 88, + in_channels: Optional[int] = None, + out_channels: Optional[int] = None, + num_layers: int = 1, + dropout: float = 0.0, + norm_num_groups: int = 32, + cross_attention_dim: Optional[int] = None, + attention_bias: bool = False, + sample_size: Optional[int] = None, + num_vector_embeds: Optional[int] = None, + patch_size: Optional[int] = None, + activation_fn: str = "geglu", + num_embeds_ada_norm: Optional[int] = None, + use_linear_projection: bool = False, + only_cross_attention: bool = False, + double_self_attention: bool = False, + upcast_attention: bool = False, + norm_type: str = "layer_norm", + norm_elementwise_affine: bool = True, + norm_eps: float = 1e-5, + attention_type: str = "default", + caption_channels: int = None, + video_length: int = 16, + attention_mode: str = 'flash' + ): + super().__init__() + self.use_linear_projection = use_linear_projection + self.num_attention_heads = num_attention_heads + self.attention_head_dim = attention_head_dim + inner_dim = num_attention_heads * attention_head_dim + self.video_length = video_length + + conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv + linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear + + # 1. Transformer2DModel can process both standard continuous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)` + # Define whether input is continuous or discrete depending on configuration + self.is_input_continuous = (in_channels is not None) and (patch_size is None) + self.is_input_vectorized = num_vector_embeds is not None + self.is_input_patches = in_channels is not None and patch_size is not None + + if norm_type == "layer_norm" and num_embeds_ada_norm is not None: + deprecation_message = ( + f"The configuration file of this model: {self.__class__} is outdated. `norm_type` is either not set or" + " incorrectly set to `'layer_norm'`.Make sure to set `norm_type` to `'ada_norm'` in the config." + " Please make sure to update the config accordingly as leaving `norm_type` might led to incorrect" + " results in future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it" + " would be very nice if you could open a Pull request for the `transformer/config.json` file" + ) + deprecate("norm_type!=num_embeds_ada_norm", "1.0.0", deprecation_message, standard_warn=False) + norm_type = "ada_norm" + + if self.is_input_continuous and self.is_input_vectorized: + raise ValueError( + f"Cannot define both `in_channels`: {in_channels} and `num_vector_embeds`: {num_vector_embeds}. Make" + " sure that either `in_channels` or `num_vector_embeds` is None." + ) + elif self.is_input_vectorized and self.is_input_patches: + raise ValueError( + f"Cannot define both `num_vector_embeds`: {num_vector_embeds} and `patch_size`: {patch_size}. Make" + " sure that either `num_vector_embeds` or `num_patches` is None." + ) + elif not self.is_input_continuous and not self.is_input_vectorized and not self.is_input_patches: + raise ValueError( + f"Has to define `in_channels`: {in_channels}, `num_vector_embeds`: {num_vector_embeds}, or patch_size:" + f" {patch_size}. Make sure that `in_channels`, `num_vector_embeds` or `num_patches` is not None." + ) + + # 2. Define input layers + if self.is_input_continuous: + self.in_channels = in_channels + + self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True) + if use_linear_projection: + self.proj_in = linear_cls(in_channels, inner_dim) + else: + self.proj_in = conv_cls(in_channels, inner_dim, kernel_size=1, stride=1, padding=0) + elif self.is_input_vectorized: + assert sample_size is not None, "Transformer2DModel over discrete input must provide sample_size" + assert num_vector_embeds is not None, "Transformer2DModel over discrete input must provide num_embed" + + self.height = sample_size[0] + self.width = sample_size[1] + self.num_vector_embeds = num_vector_embeds + self.num_latent_pixels = self.height * self.width + + self.latent_image_embedding = ImagePositionalEmbeddings( + num_embed=num_vector_embeds, embed_dim=inner_dim, height=self.height, width=self.width + ) + elif self.is_input_patches: + assert sample_size is not None, "Transformer2DModel over patched input must provide sample_size" + + self.height = sample_size[0] + self.width = sample_size[1] + + self.patch_size = patch_size + interpolation_scale = self.config.sample_size[0] // 64 # => 64 (= 512 pixart) has interpolation scale 1 + interpolation_scale = max(interpolation_scale, 1) + self.pos_embed = PatchEmbed( + height=sample_size[0], + width=sample_size[1], + patch_size=patch_size, + in_channels=in_channels, + embed_dim=inner_dim, + interpolation_scale=interpolation_scale, + ) + + # 3. Define transformers blocks + self.transformer_blocks = nn.ModuleList( + [ + BasicTransformerBlock_( + inner_dim, + num_attention_heads, + attention_head_dim, + dropout=dropout, + cross_attention_dim=None, ############## unconditon do not need cross attn + activation_fn=activation_fn, + num_embeds_ada_norm=num_embeds_ada_norm, + attention_bias=attention_bias, + only_cross_attention=only_cross_attention, + double_self_attention=False, + upcast_attention=upcast_attention, + norm_type=norm_type, + norm_elementwise_affine=norm_elementwise_affine, + norm_eps=norm_eps, + attention_type=attention_type, + attention_mode=attention_mode, + ) + for d in range(num_layers) + ] + ) + + # Define temporal transformers blocks + self.temporal_transformer_blocks = nn.ModuleList( + [ + BasicTransformerBlock_( # one attention + inner_dim, + num_attention_heads, # num_attention_heads + attention_head_dim, # attention_head_dim 72 + dropout=dropout, + cross_attention_dim=None, + activation_fn=activation_fn, + num_embeds_ada_norm=num_embeds_ada_norm, + attention_bias=attention_bias, + only_cross_attention=only_cross_attention, + double_self_attention=False, + upcast_attention=upcast_attention, + norm_type=norm_type, + norm_elementwise_affine=norm_elementwise_affine, + norm_eps=norm_eps, + attention_type=attention_type, + attention_mode=attention_mode, + ) + for d in range(num_layers) + ] + ) + + # 4. Define output layers + self.out_channels = in_channels if out_channels is None else out_channels + if self.is_input_continuous: + # TODO: should use out_channels for continuous projections + if use_linear_projection: + self.proj_out = linear_cls(inner_dim, in_channels) + else: + self.proj_out = conv_cls(inner_dim, in_channels, kernel_size=1, stride=1, padding=0) + elif self.is_input_vectorized: + self.norm_out = nn.LayerNorm(inner_dim) + self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1) + elif self.is_input_patches and norm_type != "ada_norm_single": + self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6) + self.proj_out_1 = nn.Linear(inner_dim, 2 * inner_dim) + self.proj_out_2 = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels) + elif self.is_input_patches and norm_type == "ada_norm_single": + self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6) + self.scale_shift_table = nn.Parameter(torch.randn(2, inner_dim) / inner_dim ** 0.5) + self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels) + + # 5. PixArt-Alpha blocks. + self.adaln_single = None + self.use_additional_conditions = False + if norm_type == "ada_norm_single": + # self.use_additional_conditions = self.config.sample_size[0] == 128 # False, 128 -> 1024 + # TODO(Sayak, PVP) clean this, for now we use sample size to determine whether to use + # additional conditions until we find better name + self.adaln_single = AdaLayerNormSingle(inner_dim, use_additional_conditions=self.use_additional_conditions) + + self.caption_projection = None + if caption_channels is not None: + self.caption_projection = CaptionProjection(in_features=caption_channels, hidden_size=inner_dim) + + self.gradient_checkpointing = False + + # define temporal positional embedding + # temp_pos_embed = self.get_1d_sincos_temp_embed(inner_dim, video_length) # 1152 hidden size + # self.register_buffer("temp_pos_embed", torch.from_numpy(temp_pos_embed).float().unsqueeze(0), persistent=False) + + interpolation_scale = self.config.video_length // 5 # => 5 (= 5 our causalvideovae) has interpolation scale 1 + interpolation_scale = max(interpolation_scale, 1) + temp_pos_embed = get_1d_sincos_pos_embed(inner_dim, video_length, interpolation_scale=interpolation_scale) # 1152 hidden size + self.register_buffer("temp_pos_embed", torch.from_numpy(temp_pos_embed).float().unsqueeze(0), persistent=False) + + def _set_gradient_checkpointing(self, module, value=False): + self.gradient_checkpointing = value + + def forward( + self, + hidden_states: torch.Tensor, + timestep: Optional[torch.LongTensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + added_cond_kwargs: Dict[str, torch.Tensor] = None, + class_labels: Optional[torch.LongTensor] = None, + cross_attention_kwargs: Dict[str, Any] = None, + attention_mask: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + use_image_num: int = 0, + enable_temporal_attentions: bool = True, + return_dict: bool = True, + ): + """ + The [`Transformer2DModel`] forward method. + + Args: + hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, frame, channel, height, width)` if continuous): + Input `hidden_states`. + encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*): + Conditional embeddings for cross attention layer. If not given, cross-attention defaults to + self-attention. + timestep ( `torch.LongTensor`, *optional*): + Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`. + class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*): + Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in + `AdaLayerZeroNorm`. + cross_attention_kwargs ( `Dict[str, Any]`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + attention_mask ( `torch.Tensor`, *optional*): + An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask + is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large + negative values to the attention scores corresponding to "discard" tokens. + encoder_attention_mask ( `torch.Tensor`, *optional*): + Cross-attention mask applied to `encoder_hidden_states`. Two formats supported: + + * Mask `(batch, sequence_length)` True = keep, False = discard. + * Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard. + + If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format + above. This bias will be added to the cross-attention scores. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain + tuple. + + Returns: + If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a + `tuple` where the first element is the sample tensor. + """ + input_batch_size, c, frame, h, w = hidden_states.shape + frame = frame - use_image_num + hidden_states = rearrange(hidden_states, 'b c f h w -> (b f) c h w').contiguous() + + # ensure attention_mask is a bias, and give it a singleton query_tokens dimension. + # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward. + # we can tell by counting dims; if ndim == 2: it's a mask rather than a bias. + # expects mask of shape: + # [batch, key_tokens] + # adds singleton query_tokens dimension: + # [batch, 1, key_tokens] + # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes: + # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn) + # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn) + if attention_mask is not None and attention_mask.ndim == 2: + # assume that mask is expressed as: + # (1 = keep, 0 = discard) + # convert mask into a bias that can be added to attention scores: + # (keep = +0, discard = -10000.0) + attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0 + attention_mask = attention_mask.unsqueeze(1) + + # convert encoder_attention_mask to a bias the same way we do for attention_mask + # if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2: # ndim == 2 means no image joint + # encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0 + # encoder_attention_mask = encoder_attention_mask.unsqueeze(1) + # encoder_attention_mask = repeat(encoder_attention_mask, 'b 1 l -> (b f) 1 l', f=frame).contiguous() + # elif encoder_attention_mask is not None and encoder_attention_mask.ndim == 3: # ndim == 3 means image joint + # encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0 + # encoder_attention_mask_video = encoder_attention_mask[:, :1, ...] + # encoder_attention_mask_video = repeat(encoder_attention_mask_video, 'b 1 l -> b (1 f) l', + # f=frame).contiguous() + # encoder_attention_mask_image = encoder_attention_mask[:, 1:, ...] + # encoder_attention_mask = torch.cat([encoder_attention_mask_video, encoder_attention_mask_image], dim=1) + # encoder_attention_mask = rearrange(encoder_attention_mask, 'b n l -> (b n) l').contiguous().unsqueeze(1) + + # Retrieve lora scale. + lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0 + + # 1. Input + if self.is_input_patches: # here + height, width = hidden_states.shape[-2] // self.patch_size, hidden_states.shape[-1] // self.patch_size + num_patches = height * width + + hidden_states = self.pos_embed(hidden_states.to(self.dtype)) # alrady add positional embeddings + + if self.adaln_single is not None: + if self.use_additional_conditions and added_cond_kwargs is None: + raise ValueError( + "`added_cond_kwargs` cannot be None when using additional conditions for `adaln_single`." + ) + # batch_size = hidden_states.shape[0] + batch_size = input_batch_size + timestep, embedded_timestep = self.adaln_single( + timestep, added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_states.dtype + ) + + # 2. Blocks + # if self.caption_projection is not None: + # batch_size = hidden_states.shape[0] + # encoder_hidden_states = self.caption_projection(encoder_hidden_states) # 3 120 1152 + # + # if use_image_num != 0 and self.training: + # encoder_hidden_states_video = encoder_hidden_states[:, :1, ...] + # encoder_hidden_states_video = repeat(encoder_hidden_states_video, 'b 1 t d -> b (1 f) t d', + # f=frame).contiguous() + # encoder_hidden_states_image = encoder_hidden_states[:, 1:, ...] + # encoder_hidden_states = torch.cat([encoder_hidden_states_video, encoder_hidden_states_image], dim=1) + # encoder_hidden_states_spatial = rearrange(encoder_hidden_states, 'b f t d -> (b f) t d').contiguous() + # else: + # encoder_hidden_states_spatial = repeat(encoder_hidden_states, 'b t d -> (b f) t d', + # f=frame).contiguous() + + # prepare timesteps for spatial and temporal block + timestep_spatial = repeat(timestep, 'b d -> (b f) d', f=frame + use_image_num).contiguous() + timestep_temp = repeat(timestep, 'b d -> (b p) d', p=num_patches).contiguous() + + for i, (spatial_block, temp_block) in enumerate(zip(self.transformer_blocks, self.temporal_transformer_blocks)): + + if self.training and self.gradient_checkpointing: + hidden_states = torch.utils.checkpoint.checkpoint( + spatial_block, + hidden_states, + attention_mask, + None, # encoder_hidden_states_spatial + None, # encoder_attention_mask + timestep_spatial, + cross_attention_kwargs, + class_labels, + use_reentrant=False, + ) + + if enable_temporal_attentions: + hidden_states = rearrange(hidden_states, '(b f) t d -> (b t) f d', b=input_batch_size).contiguous() + + if use_image_num != 0: # image-video joitn training + hidden_states_video = hidden_states[:, :frame, ...] + hidden_states_image = hidden_states[:, frame:, ...] + + if i == 0: + hidden_states_video = hidden_states_video + self.temp_pos_embed + + hidden_states_video = torch.utils.checkpoint.checkpoint( + temp_block, + hidden_states_video, + None, # attention_mask + None, # encoder_hidden_states + None, # encoder_attention_mask + timestep_temp, + cross_attention_kwargs, + class_labels, + use_reentrant=False, + ) + + hidden_states = torch.cat([hidden_states_video, hidden_states_image], dim=1) + hidden_states = rearrange(hidden_states, '(b t) f d -> (b f) t d', + b=input_batch_size).contiguous() + + else: + if i == 0: + hidden_states = hidden_states + self.temp_pos_embed + + hidden_states = torch.utils.checkpoint.checkpoint( + temp_block, + hidden_states, + None, # attention_mask + None, # encoder_hidden_states + None, # encoder_attention_mask + timestep_temp, + cross_attention_kwargs, + class_labels, + use_reentrant=False, + ) + + hidden_states = rearrange(hidden_states, '(b t) f d -> (b f) t d', + b=input_batch_size).contiguous() + else: + hidden_states = spatial_block( + hidden_states, + attention_mask, + None, # encoder_hidden_states_spatial + None, # encoder_attention_mask + timestep_spatial, + cross_attention_kwargs, + class_labels, + ) + + if enable_temporal_attentions: + + hidden_states = rearrange(hidden_states, '(b f) t d -> (b t) f d', b=input_batch_size).contiguous() + + if use_image_num != 0 and self.training: + hidden_states_video = hidden_states[:, :frame, ...] + hidden_states_image = hidden_states[:, frame:, ...] + + hidden_states_video = temp_block( + hidden_states_video, + None, # attention_mask + None, # encoder_hidden_states + None, # encoder_attention_mask + timestep_temp, + cross_attention_kwargs, + class_labels, + ) + + hidden_states = torch.cat([hidden_states_video, hidden_states_image], dim=1) + hidden_states = rearrange(hidden_states, '(b t) f d -> (b f) t d', + b=input_batch_size).contiguous() + + else: + if i == 0: + hidden_states = hidden_states + self.temp_pos_embed + + hidden_states = temp_block( + hidden_states, + None, # attention_mask + None, # encoder_hidden_states + None, # encoder_attention_mask + timestep_temp, + cross_attention_kwargs, + class_labels, + ) + + hidden_states = rearrange(hidden_states, '(b t) f d -> (b f) t d', + b=input_batch_size).contiguous() + + if self.is_input_patches: + if self.config.norm_type != "ada_norm_single": + conditioning = self.transformer_blocks[0].norm1.emb( + timestep, class_labels, hidden_dtype=hidden_states.dtype + ) + shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1) + hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None] + hidden_states = self.proj_out_2(hidden_states) + elif self.config.norm_type == "ada_norm_single": + embedded_timestep = repeat(embedded_timestep, 'b d -> (b f) d', f=frame + use_image_num).contiguous() + shift, scale = (self.scale_shift_table[None] + embedded_timestep[:, None]).chunk(2, dim=1) + hidden_states = self.norm_out(hidden_states) + # Modulation + hidden_states = hidden_states * (1 + scale) + shift + hidden_states = self.proj_out(hidden_states) + + # unpatchify + if self.adaln_single is None: + height = width = int(hidden_states.shape[1] ** 0.5) + hidden_states = hidden_states.reshape( + shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels) + ) + hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states) + output = hidden_states.reshape( + shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size) + ) + output = rearrange(output, '(b f) c h w -> b c f h w', b=input_batch_size).contiguous() + + if not return_dict: + return (output,) + + return Transformer3DModelOutput(sample=output) + + # def get_1d_sincos_temp_embed(self, embed_dim, length): + # pos = torch.arange(0, length).unsqueeze(1) + # return get_1d_sincos_pos_embed_from_grid(embed_dim, pos) + + @classmethod + def from_pretrained_2d(cls, pretrained_model_path, subfolder=None, **kwargs): + if subfolder is not None: + pretrained_model_path = os.path.join(pretrained_model_path, subfolder) + + config_file = os.path.join(pretrained_model_path, 'config.json') + if not os.path.isfile(config_file): + raise RuntimeError(f"{config_file} does not exist") + with open(config_file, "r") as f: + config = json.load(f) + + model = cls.from_config(config, **kwargs) + + # model_files = [ + # os.path.join(pretrained_model_path, 'diffusion_pytorch_model.bin'), + # os.path.join(pretrained_model_path, 'diffusion_pytorch_model.safetensors') + # ] + + # model_file = None + + # for fp in model_files: + # if os.path.exists(fp): + # model_file = fp + + # if not model_file: + # raise RuntimeError(f"{model_file} does not exist") + + # if model_file.split(".")[-1] == "safetensors": + # from safetensors import safe_open + # state_dict = {} + # with safe_open(model_file, framework="pt", device="cpu") as f: + # for key in f.keys(): + # state_dict[key] = f.get_tensor(key) + # else: + # state_dict = torch.load(model_file, map_location="cpu") + + # for k, v in model.state_dict().items(): + # if 'temporal_transformer_blocks' in k: + # state_dict.update({k: v}) + + # model.load_state_dict(state_dict) + + return model + + def forward_with_cfg(self, x, timestep, class_labels=None, cfg_scale=7.0, attention_mask=None): + """ + Forward pass of Latte, but also batches the unconditional forward pass for classifier-free guidance. + """ + # https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb + half = x[: len(x) // 2] + combined = torch.cat([half, half], dim=0) + model_out = self.forward(combined, timestep, class_labels=class_labels, attention_mask=attention_mask) + # For exact reproducibility reasons, we apply classifier-free guidance on only + # three channels by default. The standard approach to cfg applies it to all channels. + # This can be done by uncommenting the following line and commenting-out the line following that. + eps, rest = model_out[:, :, :self.in_channels], model_out[:, :, self.in_channels:] + cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0) + half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps) + eps = torch.cat([half_eps, half_eps], dim=0) + return torch.cat([eps, rest], dim=2) + +class LatteT2V(ModelMixin, ConfigMixin): + _supports_gradient_checkpointing = True + + """ + A 2D Transformer model for image-like data. + + Parameters: + num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention. + attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head. + in_channels (`int`, *optional*): + The number of channels in the input and output (specify if the input is **continuous**). + num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use. + sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**). + This is fixed during training since it is used to learn a number of position embeddings. + num_vector_embeds (`int`, *optional*): + The number of classes of the vector embeddings of the latent pixels (specify if the input is **discrete**). + Includes the class for the masked latent pixel. + activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward. + num_embeds_ada_norm ( `int`, *optional*): + The number of diffusion steps used during training. Pass if at least one of the norm_layers is + `AdaLayerNorm`. This is fixed during training since it is used to learn a number of embeddings that are + added to the hidden states. + + During inference, you can denoise for up to but not more steps than `num_embeds_ada_norm`. + attention_bias (`bool`, *optional*): + Configure if the `TransformerBlocks` attention should contain a bias parameter. + """ + + @register_to_config + def __init__( + self, + num_attention_heads: int = 16, + patch_size_t: int = 1, + attention_head_dim: int = 88, + in_channels: Optional[int] = None, + out_channels: Optional[int] = None, + num_layers: int = 1, + dropout: float = 0.0, + norm_num_groups: int = 32, + cross_attention_dim: Optional[int] = None, + attention_bias: bool = False, + sample_size: Optional[int] = None, + num_vector_embeds: Optional[int] = None, + patch_size: Optional[int] = None, + activation_fn: str = "geglu", + num_embeds_ada_norm: Optional[int] = None, + use_linear_projection: bool = False, + only_cross_attention: bool = False, + double_self_attention: bool = False, + upcast_attention: bool = False, + norm_type: str = "layer_norm", + norm_elementwise_affine: bool = True, + norm_eps: float = 1e-5, + attention_type: str = "default", + caption_channels: int = None, + video_length: int = 16, + attention_mode: str = 'flash' + ): + super().__init__() + self.use_linear_projection = use_linear_projection + self.num_attention_heads = num_attention_heads + self.attention_head_dim = attention_head_dim + inner_dim = num_attention_heads * attention_head_dim + self.video_length = video_length + + conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv + linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear + + # 1. Transformer2DModel can process both standard continuous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)` + # Define whether input is continuous or discrete depending on configuration + self.is_input_continuous = (in_channels is not None) and (patch_size is None) + self.is_input_vectorized = num_vector_embeds is not None + self.is_input_patches = in_channels is not None and patch_size is not None + + if norm_type == "layer_norm" and num_embeds_ada_norm is not None: + deprecation_message = ( + f"The configuration file of this model: {self.__class__} is outdated. `norm_type` is either not set or" + " incorrectly set to `'layer_norm'`.Make sure to set `norm_type` to `'ada_norm'` in the config." + " Please make sure to update the config accordingly as leaving `norm_type` might led to incorrect" + " results in future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it" + " would be very nice if you could open a Pull request for the `transformer/config.json` file" + ) + deprecate("norm_type!=num_embeds_ada_norm", "1.0.0", deprecation_message, standard_warn=False) + norm_type = "ada_norm" + + if self.is_input_continuous and self.is_input_vectorized: + raise ValueError( + f"Cannot define both `in_channels`: {in_channels} and `num_vector_embeds`: {num_vector_embeds}. Make" + " sure that either `in_channels` or `num_vector_embeds` is None." + ) + elif self.is_input_vectorized and self.is_input_patches: + raise ValueError( + f"Cannot define both `num_vector_embeds`: {num_vector_embeds} and `patch_size`: {patch_size}. Make" + " sure that either `num_vector_embeds` or `num_patches` is None." + ) + elif not self.is_input_continuous and not self.is_input_vectorized and not self.is_input_patches: + raise ValueError( + f"Has to define `in_channels`: {in_channels}, `num_vector_embeds`: {num_vector_embeds}, or patch_size:" + f" {patch_size}. Make sure that `in_channels`, `num_vector_embeds` or `num_patches` is not None." + ) + + # 2. Define input layers + if self.is_input_continuous: + self.in_channels = in_channels + + self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True) + if use_linear_projection: + self.proj_in = linear_cls(in_channels, inner_dim) + else: + self.proj_in = conv_cls(in_channels, inner_dim, kernel_size=1, stride=1, padding=0) + elif self.is_input_vectorized: + assert sample_size is not None, "Transformer2DModel over discrete input must provide sample_size" + assert num_vector_embeds is not None, "Transformer2DModel over discrete input must provide num_embed" + + self.height = sample_size[0] + self.width = sample_size[1] + self.num_vector_embeds = num_vector_embeds + self.num_latent_pixels = self.height * self.width + + self.latent_image_embedding = ImagePositionalEmbeddings( + num_embed=num_vector_embeds, embed_dim=inner_dim, height=self.height, width=self.width + ) + elif self.is_input_patches: + assert sample_size is not None, "Transformer2DModel over patched input must provide sample_size" + + self.height = sample_size[0] + self.width = sample_size[1] + + self.patch_size = patch_size + interpolation_scale = self.config.sample_size[0] // 64 # => 64 (= 512 pixart) has interpolation scale 1 + interpolation_scale = max(interpolation_scale, 1) + self.pos_embed = PatchEmbed( + height=sample_size[0], + width=sample_size[1], + patch_size=patch_size, + in_channels=in_channels, + embed_dim=inner_dim, + interpolation_scale=interpolation_scale, + ) + + # 3. Define transformers blocks, spatial attention + self.transformer_blocks = nn.ModuleList( + [ + BasicTransformerBlock( + inner_dim, + num_attention_heads, + attention_head_dim, + dropout=dropout, + cross_attention_dim=cross_attention_dim, + activation_fn=activation_fn, + num_embeds_ada_norm=num_embeds_ada_norm, + attention_bias=attention_bias, + only_cross_attention=only_cross_attention, + double_self_attention=double_self_attention, + upcast_attention=upcast_attention, + norm_type=norm_type, + norm_elementwise_affine=norm_elementwise_affine, + norm_eps=norm_eps, + attention_type=attention_type, + attention_mode=attention_mode + ) + for d in range(num_layers) + ] + ) + + # Define temporal transformers blocks + self.temporal_transformer_blocks = nn.ModuleList( + [ + BasicTransformerBlock_( # one attention + inner_dim, + num_attention_heads, # num_attention_heads + attention_head_dim, # attention_head_dim 72 + dropout=dropout, + cross_attention_dim=None, + activation_fn=activation_fn, + num_embeds_ada_norm=num_embeds_ada_norm, + attention_bias=attention_bias, + only_cross_attention=only_cross_attention, + double_self_attention=False, + upcast_attention=upcast_attention, + norm_type=norm_type, + norm_elementwise_affine=norm_elementwise_affine, + norm_eps=norm_eps, + attention_type=attention_type, + attention_mode=attention_mode + ) + for d in range(num_layers) + ] + ) + + # 4. Define output layers + self.out_channels = in_channels if out_channels is None else out_channels + if self.is_input_continuous: + # TODO: should use out_channels for continuous projections + if use_linear_projection: + self.proj_out = linear_cls(inner_dim, in_channels) + else: + self.proj_out = conv_cls(inner_dim, in_channels, kernel_size=1, stride=1, padding=0) + elif self.is_input_vectorized: + self.norm_out = nn.LayerNorm(inner_dim) + self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1) + elif self.is_input_patches and norm_type != "ada_norm_single": + self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6) + self.proj_out_1 = nn.Linear(inner_dim, 2 * inner_dim) + self.proj_out_2 = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels) + elif self.is_input_patches and norm_type == "ada_norm_single": + self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6) + self.scale_shift_table = nn.Parameter(torch.randn(2, inner_dim) / inner_dim ** 0.5) + self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels) + + # 5. PixArt-Alpha blocks. + self.adaln_single = None + self.use_additional_conditions = False + if norm_type == "ada_norm_single": + # self.use_additional_conditions = self.config.sample_size[0] == 128 # False, 128 -> 1024 + # TODO(Sayak, PVP) clean this, for now we use sample size to determine whether to use + # additional conditions until we find better name + self.adaln_single = AdaLayerNormSingle(inner_dim, use_additional_conditions=self.use_additional_conditions) + + self.caption_projection = None + if caption_channels is not None: + self.caption_projection = CaptionProjection(in_features=caption_channels, hidden_size=inner_dim) + + self.gradient_checkpointing = False + + # define temporal positional embedding + # temp_pos_embed = self.get_1d_sincos_temp_embed(inner_dim, video_length) # 1152 hidden size + + interpolation_scale = self.config.video_length // 5 # => 5 (= 5 our causalvideovae) has interpolation scale 1 + interpolation_scale = max(interpolation_scale, 1) + temp_pos_embed = get_1d_sincos_pos_embed(inner_dim, video_length, interpolation_scale=interpolation_scale) # 1152 hidden size + self.register_buffer("temp_pos_embed", torch.from_numpy(temp_pos_embed).float().unsqueeze(0), persistent=False) + + def _set_gradient_checkpointing(self, module, value=False): + self.gradient_checkpointing = value + + def forward( + self, + hidden_states: torch.Tensor, + timestep: Optional[torch.LongTensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + added_cond_kwargs: Dict[str, torch.Tensor] = None, + class_labels: Optional[torch.LongTensor] = None, + cross_attention_kwargs: Dict[str, Any] = None, + attention_mask: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + use_image_num: int = 0, + enable_temporal_attentions: bool = True, + return_dict: bool = True, + ): + """ + The [`Transformer2DModel`] forward method. + + Args: + hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, frame, channel, height, width)` if continuous): + Input `hidden_states`. + encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*): + Conditional embeddings for cross attention layer. If not given, cross-attention defaults to + self-attention. + timestep ( `torch.LongTensor`, *optional*): + Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`. + class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*): + Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in + `AdaLayerZeroNorm`. + cross_attention_kwargs ( `Dict[str, Any]`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + attention_mask ( `torch.Tensor`, *optional*): + An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask + is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large + negative values to the attention scores corresponding to "discard" tokens. + encoder_attention_mask ( `torch.Tensor`, *optional*): + Cross-attention mask applied to `encoder_hidden_states`. Two formats supported: + + * Mask `(batch, sequence_length)` True = keep, False = discard. + * Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard. + + If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format + above. This bias will be added to the cross-attention scores. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain + tuple. + + Returns: + If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a + `tuple` where the first element is the sample tensor. + """ + input_batch_size, c, frame, h, w = hidden_states.shape + # print(hidden_states.shape, input_batch_size, c, frame, h, w, use_image_num) + # print(timestep) + # print(encoder_hidden_states.shape) + # print(encoder_attention_mask.shape) + frame = frame - use_image_num # 20-4=16 + hidden_states = rearrange(hidden_states, 'b c f h w -> (b f) c h w').contiguous() + # ensure attention_mask is a bias, and give it a singleton query_tokens dimension. + # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward. + # we can tell by counting dims; if ndim == 2: it's a mask rather than a bias. + # expects mask of shape: + # [batch, key_tokens] + # adds singleton query_tokens dimension: + # [batch, 1, key_tokens] + # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes: + # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn) + # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn) + if attention_mask is not None and attention_mask.ndim == 2: + # assume that mask is expressed as: + # (1 = keep, 0 = discard) + # convert mask into a bias that can be added to attention scores: + # (keep = +0, discard = -10000.0) + attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0 + attention_mask = attention_mask.unsqueeze(1) + attention_mask = attention_mask.to(self.dtype) + # 1 + 4, 1 -> video condition, 4 -> image condition + # convert encoder_attention_mask to a bias the same way we do for attention_mask + if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2: # ndim == 2 means no image joint + encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0 + encoder_attention_mask = encoder_attention_mask.unsqueeze(1) + encoder_attention_mask = repeat(encoder_attention_mask, 'b 1 l -> (b f) 1 l', f=frame).contiguous() + encoder_attention_mask = encoder_attention_mask.to(self.dtype) + elif encoder_attention_mask is not None and encoder_attention_mask.ndim == 3: # ndim == 3 means image joint + encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0 + encoder_attention_mask_video = encoder_attention_mask[:, :1, ...] + encoder_attention_mask_video = repeat(encoder_attention_mask_video, 'b 1 l -> b (1 f) l', + f=frame).contiguous() + encoder_attention_mask_image = encoder_attention_mask[:, 1:, ...] + encoder_attention_mask = torch.cat([encoder_attention_mask_video, encoder_attention_mask_image], dim=1) + encoder_attention_mask = rearrange(encoder_attention_mask, 'b n l -> (b n) l').contiguous().unsqueeze(1) + encoder_attention_mask = encoder_attention_mask.to(self.dtype) + + # Retrieve lora scale. + lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0 + + # 1. Input + if self.is_input_patches: # here + height, width = hidden_states.shape[-2] // self.patch_size, hidden_states.shape[-1] // self.patch_size + num_patches = height * width + + hidden_states = self.pos_embed(hidden_states.to(self.dtype)) # alrady add positional embeddings + + if self.adaln_single is not None: + if self.use_additional_conditions and added_cond_kwargs is None: + raise ValueError( + "`added_cond_kwargs` cannot be None when using additional conditions for `adaln_single`." + ) + # batch_size = hidden_states.shape[0] + batch_size = input_batch_size + timestep, embedded_timestep = self.adaln_single( + timestep, added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_states.dtype + ) + + # 2. Blocks + if self.caption_projection is not None: + batch_size = hidden_states.shape[0] + encoder_hidden_states = self.caption_projection(encoder_hidden_states.to(self.dtype)) # 3 120 1152 + + if use_image_num != 0 and self.training: + encoder_hidden_states_video = encoder_hidden_states[:, :1, ...] + encoder_hidden_states_video = repeat(encoder_hidden_states_video, 'b 1 t d -> b (1 f) t d', f=frame).contiguous() + encoder_hidden_states_image = encoder_hidden_states[:, 1:, ...] + encoder_hidden_states = torch.cat([encoder_hidden_states_video, encoder_hidden_states_image], dim=1) + encoder_hidden_states_spatial = rearrange(encoder_hidden_states, 'b f t d -> (b f) t d').contiguous() + else: + encoder_hidden_states_spatial = repeat(encoder_hidden_states, 'b t d -> (b f) t d', f=frame).contiguous() + + # prepare timesteps for spatial and temporal block + timestep_spatial = repeat(timestep, 'b d -> (b f) d', f=frame + use_image_num).contiguous() + timestep_temp = repeat(timestep, 'b d -> (b p) d', p=num_patches).contiguous() + + for i, (spatial_block, temp_block) in enumerate(zip(self.transformer_blocks, self.temporal_transformer_blocks)): + + if self.training and self.gradient_checkpointing: + hidden_states = torch.utils.checkpoint.checkpoint( + spatial_block, + hidden_states, + attention_mask, + encoder_hidden_states_spatial, + encoder_attention_mask, + timestep_spatial, + cross_attention_kwargs, + class_labels, + use_reentrant=False, + ) + + if enable_temporal_attentions: + hidden_states = rearrange(hidden_states, '(b f) t d -> (b t) f d', b=input_batch_size).contiguous() + + if use_image_num != 0: # image-video joitn training + hidden_states_video = hidden_states[:, :frame, ...] + hidden_states_image = hidden_states[:, frame:, ...] + + if i == 0: + hidden_states_video = hidden_states_video + self.temp_pos_embed + + hidden_states_video = torch.utils.checkpoint.checkpoint( + temp_block, + hidden_states_video, + None, # attention_mask + None, # encoder_hidden_states + None, # encoder_attention_mask + timestep_temp, + cross_attention_kwargs, + class_labels, + use_reentrant=False, + ) + + hidden_states = torch.cat([hidden_states_video, hidden_states_image], dim=1) + hidden_states = rearrange(hidden_states, '(b t) f d -> (b f) t d', + b=input_batch_size).contiguous() + + else: + if i == 0: + hidden_states = hidden_states + self.temp_pos_embed + + hidden_states = torch.utils.checkpoint.checkpoint( + temp_block, + hidden_states, + None, # attention_mask + None, # encoder_hidden_states + None, # encoder_attention_mask + timestep_temp, + cross_attention_kwargs, + class_labels, + use_reentrant=False, + ) + + hidden_states = rearrange(hidden_states, '(b t) f d -> (b f) t d', + b=input_batch_size).contiguous() + else: + hidden_states = spatial_block( + hidden_states, + attention_mask, + encoder_hidden_states_spatial, + encoder_attention_mask, + timestep_spatial, + cross_attention_kwargs, + class_labels, + ) + + if enable_temporal_attentions: + # b c f h w, f = 16 + 4 + hidden_states = rearrange(hidden_states, '(b f) t d -> (b t) f d', b=input_batch_size).contiguous() + + if use_image_num != 0 and self.training: + hidden_states_video = hidden_states[:, :frame, ...] + hidden_states_image = hidden_states[:, frame:, ...] + + hidden_states_video = temp_block( + hidden_states_video, + None, # attention_mask + None, # encoder_hidden_states + None, # encoder_attention_mask + timestep_temp, + cross_attention_kwargs, + class_labels, + ) + + hidden_states = torch.cat([hidden_states_video, hidden_states_image], dim=1) + hidden_states = rearrange(hidden_states, '(b t) f d -> (b f) t d', + b=input_batch_size).contiguous() + + else: + if i == 0: + hidden_states = hidden_states + self.temp_pos_embed + + hidden_states = temp_block( + hidden_states, + None, # attention_mask + None, # encoder_hidden_states + None, # encoder_attention_mask + timestep_temp, + cross_attention_kwargs, + class_labels, + ) + + hidden_states = rearrange(hidden_states, '(b t) f d -> (b f) t d', + b=input_batch_size).contiguous() + + if self.is_input_patches: + if self.config.norm_type != "ada_norm_single": + conditioning = self.transformer_blocks[0].norm1.emb( + timestep, class_labels, hidden_dtype=hidden_states.dtype + ) + shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1) + hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None] + hidden_states = self.proj_out_2(hidden_states) + elif self.config.norm_type == "ada_norm_single": + embedded_timestep = repeat(embedded_timestep, 'b d -> (b f) d', f=frame + use_image_num).contiguous() + shift, scale = (self.scale_shift_table[None] + embedded_timestep[:, None]).chunk(2, dim=1) + hidden_states = self.norm_out(hidden_states) + # Modulation + hidden_states = hidden_states * (1 + scale) + shift + hidden_states = self.proj_out(hidden_states) + + # unpatchify + if self.adaln_single is None: + height = width = int(hidden_states.shape[1] ** 0.5) + hidden_states = hidden_states.reshape( + shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels) + ) + hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states) + output = hidden_states.reshape( + shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size) + ) + output = rearrange(output, '(b f) c h w -> b c f h w', b=input_batch_size).contiguous() + + if not return_dict: + return (output,) + + return Transformer3DModelOutput(sample=output) + + # def get_1d_sincos_temp_embed(self, embed_dim, length): + # pos = torch.arange(0, length).unsqueeze(1) + # return get_1d_sincos_pos_embed_from_grid(embed_dim, pos) + + @classmethod + def from_pretrained_2d(cls, pretrained_model_path, subfolder=None, **kwargs): + if subfolder is not None: + pretrained_model_path = os.path.join(pretrained_model_path, subfolder) + + config_file = os.path.join(pretrained_model_path, 'config.json') + if not os.path.isfile(config_file): + raise RuntimeError(f"{config_file} does not exist") + with open(config_file, "r") as f: + config = json.load(f) + + model = cls.from_config(config, **kwargs) + + # model_files = [ + # os.path.join(pretrained_model_path, 'diffusion_pytorch_model.bin'), + # os.path.join(pretrained_model_path, 'diffusion_pytorch_model.safetensors') + # ] + + # model_file = None + + # for fp in model_files: + # if os.path.exists(fp): + # model_file = fp + + # if not model_file: + # raise RuntimeError(f"{model_file} does not exist") + + # if model_file.split(".")[-1] == "safetensors": + # from safetensors import safe_open + # state_dict = {} + # with safe_open(model_file, framework="pt", device="cpu") as f: + # for key in f.keys(): + # state_dict[key] = f.get_tensor(key) + # else: + # state_dict = torch.load(model_file, map_location="cpu") + + # for k, v in model.state_dict().items(): + # if 'temporal_transformer_blocks' in k: + # state_dict.update({k: v}) + + # model.load_state_dict(state_dict) + + return model + +# depth = num_layers * 2 +def Latte_XL_122(**kwargs): + return Latte(num_layers=28, attention_head_dim=72, num_attention_heads=16, patch_size_t=1, patch_size=2, norm_type="ada_norm_single", **kwargs) + +def LatteClass_XL_122(**kwargs): + return Latte(num_layers=28, attention_head_dim=72, num_attention_heads=16, patch_size_t=1, patch_size=2, norm_type="ada_norm_zero", **kwargs) + +def LatteT2V_XL_122(**kwargs): + return LatteT2V(num_layers=28, attention_head_dim=72, num_attention_heads=16, patch_size_t=1, patch_size=2, + norm_type="ada_norm_single", caption_channels=4096, cross_attention_dim=1152, **kwargs) + +Latte_models = { + "Latte-XL/122": Latte_XL_122, + "LatteClass-XL/122": LatteClass_XL_122, + "LatteT2V-XL/122": LatteT2V_XL_122, +} + +if __name__ == '__main__': + a = json.load(open('./config.json', 'r')) + model = LatteT2V.from_config(a) + ckpt = torch.load(r"E:\下载\t2v.pt", map_location='cpu')['model'] + model.load_state_dict(ckpt) + print(model) \ No newline at end of file diff --git a/opensora/models/diffusion/latte/modules.py b/opensora/models/diffusion/latte/modules.py new file mode 100644 index 0000000000000000000000000000000000000000..786045dc71e25381e24fd5cd1141789d68e8a972 --- /dev/null +++ b/opensora/models/diffusion/latte/modules.py @@ -0,0 +1,1560 @@ +from importlib import import_module + +import numpy as np +import torch + +import os +import json + +from dataclasses import dataclass +from einops import rearrange, repeat +from typing import Any, Dict, Optional, Tuple, Callable +from diffusers.models import Transformer2DModel +from diffusers.utils import USE_PEFT_BACKEND, BaseOutput, deprecate, is_xformers_available +from diffusers.models.embeddings import get_1d_sincos_pos_embed_from_grid, ImagePositionalEmbeddings, CaptionProjection, \ + PatchEmbed, CombinedTimestepSizeEmbeddings +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.models.modeling_utils import ModelMixin +from diffusers.models.attention import BasicTransformerBlock +from diffusers.models.lora import LoRACompatibleConv, LoRACompatibleLinear + +import torch +import torch.nn.functional as F +from torch import nn +from diffusers.utils.torch_utils import maybe_allow_in_graph +from diffusers.models.embeddings import SinusoidalPositionalEmbedding +from diffusers.models.normalization import AdaLayerNorm, AdaLayerNormZero +from diffusers.models.attention_processor import SpatialNorm, LORA_ATTENTION_PROCESSORS, \ + CustomDiffusionAttnProcessor, CustomDiffusionXFormersAttnProcessor, CustomDiffusionAttnProcessor2_0, \ + AttnAddedKVProcessor, AttnAddedKVProcessor2_0, SlicedAttnAddedKVProcessor, XFormersAttnAddedKVProcessor, \ + LoRAAttnAddedKVProcessor, LoRAXFormersAttnProcessor, XFormersAttnProcessor, LoRAAttnProcessor2_0, LoRAAttnProcessor, \ + AttnProcessor, SlicedAttnProcessor, logger +from diffusers.models.activations import GEGLU, GELU, ApproximateGELU + +from dataclasses import dataclass + +from torch import nn + +from opensora.models.diffusion.utils.pos_embed import get_2d_sincos_pos_embed + +if is_xformers_available(): + import xformers + import xformers.ops +else: + xformers = None + + +class PatchEmbed(nn.Module): + """2D Image to Patch Embedding""" + + def __init__( + self, + height=224, + width=224, + patch_size=16, + in_channels=3, + embed_dim=768, + layer_norm=False, + flatten=True, + bias=True, + interpolation_scale=1, + ): + super().__init__() + + num_patches = (height // patch_size) * (width // patch_size) + self.flatten = flatten + self.layer_norm = layer_norm + + self.proj = nn.Conv2d( + in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias + ) + if layer_norm: + self.norm = nn.LayerNorm(embed_dim, elementwise_affine=False, eps=1e-6) + else: + self.norm = None + + self.patch_size = patch_size + # See: + # https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L161 + self.height, self.width = height // patch_size, width // patch_size + self.base_size = height // patch_size + self.interpolation_scale = interpolation_scale + pos_embed = get_2d_sincos_pos_embed( + embed_dim, int(num_patches**0.5), base_size=self.base_size, interpolation_scale=self.interpolation_scale + ) + self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float().unsqueeze(0), persistent=False) + + def forward(self, latent): + height, width = latent.shape[-2] // self.patch_size, latent.shape[-1] // self.patch_size + + latent = self.proj(latent) + if self.flatten: + latent = latent.flatten(2).transpose(1, 2) # BCHW -> BNC + if self.layer_norm: + latent = self.norm(latent) + + # Interpolate positional embeddings if needed. + # (For PixArt-Alpha: https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L162C151-L162C160) + if self.height != height or self.width != width: + pos_embed = get_2d_sincos_pos_embed( + embed_dim=self.pos_embed.shape[-1], + grid_size=(height, width), + base_size=self.base_size, + interpolation_scale=self.interpolation_scale, + ) + pos_embed = torch.from_numpy(pos_embed) + pos_embed = pos_embed.float().unsqueeze(0).to(latent.device) + else: + pos_embed = self.pos_embed + + return (latent + pos_embed).to(latent.dtype) + + +@maybe_allow_in_graph +class Attention(nn.Module): + r""" + A cross attention layer. + + Parameters: + query_dim (`int`): + The number of channels in the query. + cross_attention_dim (`int`, *optional*): + The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`. + heads (`int`, *optional*, defaults to 8): + The number of heads to use for multi-head attention. + dim_head (`int`, *optional*, defaults to 64): + The number of channels in each head. + dropout (`float`, *optional*, defaults to 0.0): + The dropout probability to use. + bias (`bool`, *optional*, defaults to False): + Set to `True` for the query, key, and value linear layers to contain a bias parameter. + upcast_attention (`bool`, *optional*, defaults to False): + Set to `True` to upcast the attention computation to `float32`. + upcast_softmax (`bool`, *optional*, defaults to False): + Set to `True` to upcast the softmax computation to `float32`. + cross_attention_norm (`str`, *optional*, defaults to `None`): + The type of normalization to use for the cross attention. Can be `None`, `layer_norm`, or `group_norm`. + cross_attention_norm_num_groups (`int`, *optional*, defaults to 32): + The number of groups to use for the group norm in the cross attention. + added_kv_proj_dim (`int`, *optional*, defaults to `None`): + The number of channels to use for the added key and value projections. If `None`, no projection is used. + norm_num_groups (`int`, *optional*, defaults to `None`): + The number of groups to use for the group norm in the attention. + spatial_norm_dim (`int`, *optional*, defaults to `None`): + The number of channels to use for the spatial normalization. + out_bias (`bool`, *optional*, defaults to `True`): + Set to `True` to use a bias in the output linear layer. + scale_qk (`bool`, *optional*, defaults to `True`): + Set to `True` to scale the query and key by `1 / sqrt(dim_head)`. + only_cross_attention (`bool`, *optional*, defaults to `False`): + Set to `True` to only use cross attention and not added_kv_proj_dim. Can only be set to `True` if + `added_kv_proj_dim` is not `None`. + eps (`float`, *optional*, defaults to 1e-5): + An additional value added to the denominator in group normalization that is used for numerical stability. + rescale_output_factor (`float`, *optional*, defaults to 1.0): + A factor to rescale the output by dividing it with this value. + residual_connection (`bool`, *optional*, defaults to `False`): + Set to `True` to add the residual connection to the output. + _from_deprecated_attn_block (`bool`, *optional*, defaults to `False`): + Set to `True` if the attention block is loaded from a deprecated state dict. + processor (`AttnProcessor`, *optional*, defaults to `None`): + The attention processor to use. If `None`, defaults to `AttnProcessor2_0` if `torch 2.x` is used and + `AttnProcessor` otherwise. + """ + + def __init__( + self, + query_dim: int, + cross_attention_dim: Optional[int] = None, + heads: int = 8, + dim_head: int = 64, + dropout: float = 0.0, + bias: bool = False, + upcast_attention: bool = False, + upcast_softmax: bool = False, + cross_attention_norm: Optional[str] = None, + cross_attention_norm_num_groups: int = 32, + added_kv_proj_dim: Optional[int] = None, + norm_num_groups: Optional[int] = None, + spatial_norm_dim: Optional[int] = None, + out_bias: bool = True, + scale_qk: bool = True, + only_cross_attention: bool = False, + eps: float = 1e-5, + rescale_output_factor: float = 1.0, + residual_connection: bool = False, + _from_deprecated_attn_block: bool = False, + processor: Optional["AttnProcessor"] = None, + attention_mode: str = 'xformers', + ): + super().__init__() + self.inner_dim = dim_head * heads + self.cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim + self.upcast_attention = upcast_attention + self.upcast_softmax = upcast_softmax + self.rescale_output_factor = rescale_output_factor + self.residual_connection = residual_connection + self.dropout = dropout + + # we make use of this private variable to know whether this class is loaded + # with an deprecated state dict so that we can convert it on the fly + self._from_deprecated_attn_block = _from_deprecated_attn_block + + self.scale_qk = scale_qk + self.scale = dim_head**-0.5 if self.scale_qk else 1.0 + + self.heads = heads + # for slice_size > 0 the attention score computation + # is split across the batch axis to save memory + # You can set slice_size with `set_attention_slice` + self.sliceable_head_dim = heads + + self.added_kv_proj_dim = added_kv_proj_dim + self.only_cross_attention = only_cross_attention + + if self.added_kv_proj_dim is None and self.only_cross_attention: + raise ValueError( + "`only_cross_attention` can only be set to True if `added_kv_proj_dim` is not None. Make sure to set either `only_cross_attention=False` or define `added_kv_proj_dim`." + ) + + if norm_num_groups is not None: + self.group_norm = nn.GroupNorm(num_channels=query_dim, num_groups=norm_num_groups, eps=eps, affine=True) + else: + self.group_norm = None + + if spatial_norm_dim is not None: + self.spatial_norm = SpatialNorm(f_channels=query_dim, zq_channels=spatial_norm_dim) + else: + self.spatial_norm = None + + if cross_attention_norm is None: + self.norm_cross = None + elif cross_attention_norm == "layer_norm": + self.norm_cross = nn.LayerNorm(self.cross_attention_dim) + elif cross_attention_norm == "group_norm": + if self.added_kv_proj_dim is not None: + # The given `encoder_hidden_states` are initially of shape + # (batch_size, seq_len, added_kv_proj_dim) before being projected + # to (batch_size, seq_len, cross_attention_dim). The norm is applied + # before the projection, so we need to use `added_kv_proj_dim` as + # the number of channels for the group norm. + norm_cross_num_channels = added_kv_proj_dim + else: + norm_cross_num_channels = self.cross_attention_dim + + self.norm_cross = nn.GroupNorm( + num_channels=norm_cross_num_channels, num_groups=cross_attention_norm_num_groups, eps=1e-5, affine=True + ) + else: + raise ValueError( + f"unknown cross_attention_norm: {cross_attention_norm}. Should be None, 'layer_norm' or 'group_norm'" + ) + + if USE_PEFT_BACKEND: + linear_cls = nn.Linear + else: + linear_cls = LoRACompatibleLinear + + self.to_q = linear_cls(query_dim, self.inner_dim, bias=bias) + + if not self.only_cross_attention: + # only relevant for the `AddedKVProcessor` classes + self.to_k = linear_cls(self.cross_attention_dim, self.inner_dim, bias=bias) + self.to_v = linear_cls(self.cross_attention_dim, self.inner_dim, bias=bias) + else: + self.to_k = None + self.to_v = None + + if self.added_kv_proj_dim is not None: + self.add_k_proj = linear_cls(added_kv_proj_dim, self.inner_dim) + self.add_v_proj = linear_cls(added_kv_proj_dim, self.inner_dim) + + self.to_out = nn.ModuleList([]) + self.to_out.append(linear_cls(self.inner_dim, query_dim, bias=out_bias)) + self.to_out.append(nn.Dropout(dropout)) + + # set attention processor + # We use the AttnProcessor2_0 by default when torch 2.x is used which uses + # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention + # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1 + if processor is None: + processor = ( + AttnProcessor2_0(attention_mode) if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor() + ) + self.set_processor(processor) + + def set_use_memory_efficient_attention_xformers( + self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None + ) -> None: + r""" + Set whether to use memory efficient attention from `xformers` or not. + + Args: + use_memory_efficient_attention_xformers (`bool`): + Whether to use memory efficient attention from `xformers` or not. + attention_op (`Callable`, *optional*): + The attention operation to use. Defaults to `None` which uses the default attention operation from + `xformers`. + """ + is_lora = hasattr(self, "processor") and isinstance( + self.processor, + LORA_ATTENTION_PROCESSORS, + ) + is_custom_diffusion = hasattr(self, "processor") and isinstance( + self.processor, + (CustomDiffusionAttnProcessor, CustomDiffusionXFormersAttnProcessor, CustomDiffusionAttnProcessor2_0), + ) + is_added_kv_processor = hasattr(self, "processor") and isinstance( + self.processor, + ( + AttnAddedKVProcessor, + AttnAddedKVProcessor2_0, + SlicedAttnAddedKVProcessor, + XFormersAttnAddedKVProcessor, + LoRAAttnAddedKVProcessor, + ), + ) + + if use_memory_efficient_attention_xformers: + if is_added_kv_processor and (is_lora or is_custom_diffusion): + raise NotImplementedError( + f"Memory efficient attention is currently not supported for LoRA or custom diffusion for attention processor type {self.processor}" + ) + if not is_xformers_available(): + raise ModuleNotFoundError( + ( + "Refer to https://github.com/facebookresearch/xformers for more information on how to install" + " xformers" + ), + name="xformers", + ) + elif not torch.cuda.is_available(): + raise ValueError( + "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is" + " only available for GPU " + ) + else: + try: + # Make sure we can run the memory efficient attention + _ = xformers.ops.memory_efficient_attention( + torch.randn((1, 2, 40), device="cuda"), + torch.randn((1, 2, 40), device="cuda"), + torch.randn((1, 2, 40), device="cuda"), + ) + except Exception as e: + raise e + + if is_lora: + # TODO (sayakpaul): should we throw a warning if someone wants to use the xformers + # variant when using PT 2.0 now that we have LoRAAttnProcessor2_0? + processor = LoRAXFormersAttnProcessor( + hidden_size=self.processor.hidden_size, + cross_attention_dim=self.processor.cross_attention_dim, + rank=self.processor.rank, + attention_op=attention_op, + ) + processor.load_state_dict(self.processor.state_dict()) + processor.to(self.processor.to_q_lora.up.weight.device) + elif is_custom_diffusion: + processor = CustomDiffusionXFormersAttnProcessor( + train_kv=self.processor.train_kv, + train_q_out=self.processor.train_q_out, + hidden_size=self.processor.hidden_size, + cross_attention_dim=self.processor.cross_attention_dim, + attention_op=attention_op, + ) + processor.load_state_dict(self.processor.state_dict()) + if hasattr(self.processor, "to_k_custom_diffusion"): + processor.to(self.processor.to_k_custom_diffusion.weight.device) + elif is_added_kv_processor: + # TODO(Patrick, Suraj, William) - currently xformers doesn't work for UnCLIP + # which uses this type of cross attention ONLY because the attention mask of format + # [0, ..., -10.000, ..., 0, ...,] is not supported + # throw warning + logger.info( + "Memory efficient attention with `xformers` might currently not work correctly if an attention mask is required for the attention operation." + ) + processor = XFormersAttnAddedKVProcessor(attention_op=attention_op) + else: + processor = XFormersAttnProcessor(attention_op=attention_op) + else: + if is_lora: + attn_processor_class = ( + LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor + ) + processor = attn_processor_class( + hidden_size=self.processor.hidden_size, + cross_attention_dim=self.processor.cross_attention_dim, + rank=self.processor.rank, + ) + processor.load_state_dict(self.processor.state_dict()) + processor.to(self.processor.to_q_lora.up.weight.device) + elif is_custom_diffusion: + attn_processor_class = ( + CustomDiffusionAttnProcessor2_0 + if hasattr(F, "scaled_dot_product_attention") + else CustomDiffusionAttnProcessor + ) + processor = attn_processor_class( + train_kv=self.processor.train_kv, + train_q_out=self.processor.train_q_out, + hidden_size=self.processor.hidden_size, + cross_attention_dim=self.processor.cross_attention_dim, + ) + processor.load_state_dict(self.processor.state_dict()) + if hasattr(self.processor, "to_k_custom_diffusion"): + processor.to(self.processor.to_k_custom_diffusion.weight.device) + else: + # set attention processor + # We use the AttnProcessor2_0 by default when torch 2.x is used which uses + # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention + # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1 + processor = ( + AttnProcessor2_0() + if hasattr(F, "scaled_dot_product_attention") and self.scale_qk + else AttnProcessor() + ) + + self.set_processor(processor) + + def set_attention_slice(self, slice_size: int) -> None: + r""" + Set the slice size for attention computation. + + Args: + slice_size (`int`): + The slice size for attention computation. + """ + if slice_size is not None and slice_size > self.sliceable_head_dim: + raise ValueError(f"slice_size {slice_size} has to be smaller or equal to {self.sliceable_head_dim}.") + + if slice_size is not None and self.added_kv_proj_dim is not None: + processor = SlicedAttnAddedKVProcessor(slice_size) + elif slice_size is not None: + processor = SlicedAttnProcessor(slice_size) + elif self.added_kv_proj_dim is not None: + processor = AttnAddedKVProcessor() + else: + # set attention processor + # We use the AttnProcessor2_0 by default when torch 2.x is used which uses + # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention + # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1 + processor = ( + AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor() + ) + + self.set_processor(processor) + + def set_processor(self, processor: "AttnProcessor", _remove_lora: bool = False) -> None: + r""" + Set the attention processor to use. + + Args: + processor (`AttnProcessor`): + The attention processor to use. + _remove_lora (`bool`, *optional*, defaults to `False`): + Set to `True` to remove LoRA layers from the model. + """ + if not USE_PEFT_BACKEND and hasattr(self, "processor") and _remove_lora and self.to_q.lora_layer is not None: + deprecate( + "set_processor to offload LoRA", + "0.26.0", + "In detail, removing LoRA layers via calling `set_default_attn_processor` is deprecated. Please make sure to call `pipe.unload_lora_weights()` instead.", + ) + # TODO(Patrick, Sayak) - this can be deprecated once PEFT LoRA integration is complete + # We need to remove all LoRA layers + # Don't forget to remove ALL `_remove_lora` from the codebase + for module in self.modules(): + if hasattr(module, "set_lora_layer"): + module.set_lora_layer(None) + + # if current processor is in `self._modules` and if passed `processor` is not, we need to + # pop `processor` from `self._modules` + if ( + hasattr(self, "processor") + and isinstance(self.processor, torch.nn.Module) + and not isinstance(processor, torch.nn.Module) + ): + logger.info(f"You are removing possibly trained weights of {self.processor} with {processor}") + self._modules.pop("processor") + + self.processor = processor + + def get_processor(self, return_deprecated_lora: bool = False) -> "AttentionProcessor": + r""" + Get the attention processor in use. + + Args: + return_deprecated_lora (`bool`, *optional*, defaults to `False`): + Set to `True` to return the deprecated LoRA attention processor. + + Returns: + "AttentionProcessor": The attention processor in use. + """ + if not return_deprecated_lora: + return self.processor + + # TODO(Sayak, Patrick). The rest of the function is needed to ensure backwards compatible + # serialization format for LoRA Attention Processors. It should be deleted once the integration + # with PEFT is completed. + is_lora_activated = { + name: module.lora_layer is not None + for name, module in self.named_modules() + if hasattr(module, "lora_layer") + } + + # 1. if no layer has a LoRA activated we can return the processor as usual + if not any(is_lora_activated.values()): + return self.processor + + # If doesn't apply LoRA do `add_k_proj` or `add_v_proj` + is_lora_activated.pop("add_k_proj", None) + is_lora_activated.pop("add_v_proj", None) + # 2. else it is not posssible that only some layers have LoRA activated + if not all(is_lora_activated.values()): + raise ValueError( + f"Make sure that either all layers or no layers have LoRA activated, but have {is_lora_activated}" + ) + + # 3. And we need to merge the current LoRA layers into the corresponding LoRA attention processor + non_lora_processor_cls_name = self.processor.__class__.__name__ + lora_processor_cls = getattr(import_module(__name__), "LoRA" + non_lora_processor_cls_name) + + hidden_size = self.inner_dim + + # now create a LoRA attention processor from the LoRA layers + if lora_processor_cls in [LoRAAttnProcessor, LoRAAttnProcessor2_0, LoRAXFormersAttnProcessor]: + kwargs = { + "cross_attention_dim": self.cross_attention_dim, + "rank": self.to_q.lora_layer.rank, + "network_alpha": self.to_q.lora_layer.network_alpha, + "q_rank": self.to_q.lora_layer.rank, + "q_hidden_size": self.to_q.lora_layer.out_features, + "k_rank": self.to_k.lora_layer.rank, + "k_hidden_size": self.to_k.lora_layer.out_features, + "v_rank": self.to_v.lora_layer.rank, + "v_hidden_size": self.to_v.lora_layer.out_features, + "out_rank": self.to_out[0].lora_layer.rank, + "out_hidden_size": self.to_out[0].lora_layer.out_features, + } + + if hasattr(self.processor, "attention_op"): + kwargs["attention_op"] = self.processor.attention_op + + lora_processor = lora_processor_cls(hidden_size, **kwargs) + lora_processor.to_q_lora.load_state_dict(self.to_q.lora_layer.state_dict()) + lora_processor.to_k_lora.load_state_dict(self.to_k.lora_layer.state_dict()) + lora_processor.to_v_lora.load_state_dict(self.to_v.lora_layer.state_dict()) + lora_processor.to_out_lora.load_state_dict(self.to_out[0].lora_layer.state_dict()) + elif lora_processor_cls == LoRAAttnAddedKVProcessor: + lora_processor = lora_processor_cls( + hidden_size, + cross_attention_dim=self.add_k_proj.weight.shape[0], + rank=self.to_q.lora_layer.rank, + network_alpha=self.to_q.lora_layer.network_alpha, + ) + lora_processor.to_q_lora.load_state_dict(self.to_q.lora_layer.state_dict()) + lora_processor.to_k_lora.load_state_dict(self.to_k.lora_layer.state_dict()) + lora_processor.to_v_lora.load_state_dict(self.to_v.lora_layer.state_dict()) + lora_processor.to_out_lora.load_state_dict(self.to_out[0].lora_layer.state_dict()) + + # only save if used + if self.add_k_proj.lora_layer is not None: + lora_processor.add_k_proj_lora.load_state_dict(self.add_k_proj.lora_layer.state_dict()) + lora_processor.add_v_proj_lora.load_state_dict(self.add_v_proj.lora_layer.state_dict()) + else: + lora_processor.add_k_proj_lora = None + lora_processor.add_v_proj_lora = None + else: + raise ValueError(f"{lora_processor_cls} does not exist.") + + return lora_processor + + def forward( + self, + hidden_states: torch.FloatTensor, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + **cross_attention_kwargs, + ) -> torch.Tensor: + r""" + The forward method of the `Attention` class. + + Args: + hidden_states (`torch.Tensor`): + The hidden states of the query. + encoder_hidden_states (`torch.Tensor`, *optional*): + The hidden states of the encoder. + attention_mask (`torch.Tensor`, *optional*): + The attention mask to use. If `None`, no mask is applied. + **cross_attention_kwargs: + Additional keyword arguments to pass along to the cross attention. + + Returns: + `torch.Tensor`: The output of the attention layer. + """ + # The `Attention` class can call different attention processors / attention functions + # here we simply pass along all tensors to the selected processor class + # For standard processors that are defined here, `**cross_attention_kwargs` is empty + return self.processor( + self, + hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + **cross_attention_kwargs, + ) + + def batch_to_head_dim(self, tensor: torch.Tensor) -> torch.Tensor: + r""" + Reshape the tensor from `[batch_size, seq_len, dim]` to `[batch_size // heads, seq_len, dim * heads]`. `heads` + is the number of heads initialized while constructing the `Attention` class. + + Args: + tensor (`torch.Tensor`): The tensor to reshape. + + Returns: + `torch.Tensor`: The reshaped tensor. + """ + head_size = self.heads + batch_size, seq_len, dim = tensor.shape + tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim) + tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size) + return tensor + + def head_to_batch_dim(self, tensor: torch.Tensor, out_dim: int = 3) -> torch.Tensor: + r""" + Reshape the tensor from `[batch_size, seq_len, dim]` to `[batch_size, seq_len, heads, dim // heads]` `heads` is + the number of heads initialized while constructing the `Attention` class. + + Args: + tensor (`torch.Tensor`): The tensor to reshape. + out_dim (`int`, *optional*, defaults to `3`): The output dimension of the tensor. If `3`, the tensor is + reshaped to `[batch_size * heads, seq_len, dim // heads]`. + + Returns: + `torch.Tensor`: The reshaped tensor. + """ + head_size = self.heads + batch_size, seq_len, dim = tensor.shape + tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size) + tensor = tensor.permute(0, 2, 1, 3) + + if out_dim == 3: + tensor = tensor.reshape(batch_size * head_size, seq_len, dim // head_size) + + return tensor + + def get_attention_scores( + self, query: torch.Tensor, key: torch.Tensor, attention_mask: torch.Tensor = None + ) -> torch.Tensor: + r""" + Compute the attention scores. + + Args: + query (`torch.Tensor`): The query tensor. + key (`torch.Tensor`): The key tensor. + attention_mask (`torch.Tensor`, *optional*): The attention mask to use. If `None`, no mask is applied. + + Returns: + `torch.Tensor`: The attention probabilities/scores. + """ + dtype = query.dtype + if self.upcast_attention: + query = query.float() + key = key.float() + + if attention_mask is None: + baddbmm_input = torch.empty( + query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device + ) + beta = 0 + else: + baddbmm_input = attention_mask + beta = 1 + + attention_scores = torch.baddbmm( + baddbmm_input, + query, + key.transpose(-1, -2), + beta=beta, + alpha=self.scale, + ) + del baddbmm_input + + if self.upcast_softmax: + attention_scores = attention_scores.float() + + attention_probs = attention_scores.softmax(dim=-1) + del attention_scores + + attention_probs = attention_probs.to(dtype) + + return attention_probs + + def prepare_attention_mask( + self, attention_mask: torch.Tensor, target_length: int, batch_size: int, out_dim: int = 3 + ) -> torch.Tensor: + r""" + Prepare the attention mask for the attention computation. + + Args: + attention_mask (`torch.Tensor`): + The attention mask to prepare. + target_length (`int`): + The target length of the attention mask. This is the length of the attention mask after padding. + batch_size (`int`): + The batch size, which is used to repeat the attention mask. + out_dim (`int`, *optional*, defaults to `3`): + The output dimension of the attention mask. Can be either `3` or `4`. + + Returns: + `torch.Tensor`: The prepared attention mask. + """ + head_size = self.heads + if attention_mask is None: + return attention_mask + + current_length: int = attention_mask.shape[-1] + if current_length != target_length: + if attention_mask.device.type == "mps": + # HACK: MPS: Does not support padding by greater than dimension of input tensor. + # Instead, we can manually construct the padding tensor. + padding_shape = (attention_mask.shape[0], attention_mask.shape[1], target_length) + padding = torch.zeros(padding_shape, dtype=attention_mask.dtype, device=attention_mask.device) + attention_mask = torch.cat([attention_mask, padding], dim=2) + else: + # TODO: for pipelines such as stable-diffusion, padding cross-attn mask: + # we want to instead pad by (0, remaining_length), where remaining_length is: + # remaining_length: int = target_length - current_length + # TODO: re-enable tests/models/test_models_unet_2d_condition.py#test_model_xattn_padding + attention_mask = F.pad(attention_mask, (0, target_length), value=0.0) + + if out_dim == 3: + if attention_mask.shape[0] < batch_size * head_size: + attention_mask = attention_mask.repeat_interleave(head_size, dim=0) + elif out_dim == 4: + attention_mask = attention_mask.unsqueeze(1) + attention_mask = attention_mask.repeat_interleave(head_size, dim=1) + + return attention_mask + + def norm_encoder_hidden_states(self, encoder_hidden_states: torch.Tensor) -> torch.Tensor: + r""" + Normalize the encoder hidden states. Requires `self.norm_cross` to be specified when constructing the + `Attention` class. + + Args: + encoder_hidden_states (`torch.Tensor`): Hidden states of the encoder. + + Returns: + `torch.Tensor`: The normalized encoder hidden states. + """ + assert self.norm_cross is not None, "self.norm_cross must be defined to call self.norm_encoder_hidden_states" + + if isinstance(self.norm_cross, nn.LayerNorm): + encoder_hidden_states = self.norm_cross(encoder_hidden_states) + elif isinstance(self.norm_cross, nn.GroupNorm): + # Group norm norms along the channels dimension and expects + # input to be in the shape of (N, C, *). In this case, we want + # to norm along the hidden dimension, so we need to move + # (batch_size, sequence_length, hidden_size) -> + # (batch_size, hidden_size, sequence_length) + encoder_hidden_states = encoder_hidden_states.transpose(1, 2) + encoder_hidden_states = self.norm_cross(encoder_hidden_states) + encoder_hidden_states = encoder_hidden_states.transpose(1, 2) + else: + assert False + + return encoder_hidden_states + +class AttnProcessor2_0: + r""" + Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). + """ + + def __init__(self, attention_mode='xformers'): + self.attention_mode = attention_mode + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") + + def __call__( + self, + attn: Attention, + hidden_states: torch.FloatTensor, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + temb: Optional[torch.FloatTensor] = None, + scale: float = 1.0, + ) -> torch.FloatTensor: + residual = hidden_states + + args = () if USE_PEFT_BACKEND else (scale,) + + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + + if attention_mask is not None: + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + # scaled_dot_product_attention expects attention_mask shape to be + # (batch, heads, source_length, target_length) + attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + args = () if USE_PEFT_BACKEND else (scale,) + query = attn.to_q(hidden_states, *args) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + key = attn.to_k(encoder_hidden_states, *args) + value = attn.to_v(encoder_hidden_states, *args) + + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + # the output of sdp = (batch, num_heads, seq_len, head_dim) + # TODO: add support for attn.scale when we move to Torch 2.1 + if self.attention_mode == 'flash': + assert attention_mask is None or torch.all(attention_mask.bool()), 'flash-attn do not support attention_mask' + with torch.backends.cuda.sdp_kernel(enable_math=False, enable_flash=True, enable_mem_efficient=False): + hidden_states = F.scaled_dot_product_attention( + query, key, value, dropout_p=0.0, is_causal=False + ) + elif self.attention_mode == 'xformers': + with torch.backends.cuda.sdp_kernel(enable_math=False, enable_flash=False, enable_mem_efficient=True): + hidden_states = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) + elif self.attention_mode == 'math': + hidden_states = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) + else: + raise NotImplementedError(f'Found attention_mode: {self.attention_mode}') + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states = hidden_states.to(query.dtype) + + # linear proj + hidden_states = attn.to_out[0](hidden_states, *args) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states + +@maybe_allow_in_graph +class GatedSelfAttentionDense(nn.Module): + r""" + A gated self-attention dense layer that combines visual features and object features. + + Parameters: + query_dim (`int`): The number of channels in the query. + context_dim (`int`): The number of channels in the context. + n_heads (`int`): The number of heads to use for attention. + d_head (`int`): The number of channels in each head. + """ + + def __init__(self, query_dim: int, context_dim: int, n_heads: int, d_head: int): + super().__init__() + + # we need a linear projection since we need cat visual feature and obj feature + self.linear = nn.Linear(context_dim, query_dim) + + self.attn = Attention(query_dim=query_dim, heads=n_heads, dim_head=d_head) + self.ff = FeedForward(query_dim, activation_fn="geglu") + + self.norm1 = nn.LayerNorm(query_dim) + self.norm2 = nn.LayerNorm(query_dim) + + self.register_parameter("alpha_attn", nn.Parameter(torch.tensor(0.0))) + self.register_parameter("alpha_dense", nn.Parameter(torch.tensor(0.0))) + + self.enabled = True + + def forward(self, x: torch.Tensor, objs: torch.Tensor) -> torch.Tensor: + if not self.enabled: + return x + + n_visual = x.shape[1] + objs = self.linear(objs) + + x = x + self.alpha_attn.tanh() * self.attn(self.norm1(torch.cat([x, objs], dim=1)))[:, :n_visual, :] + x = x + self.alpha_dense.tanh() * self.ff(self.norm2(x)) + + return x + + +class FeedForward(nn.Module): + r""" + A feed-forward layer. + + Parameters: + dim (`int`): The number of channels in the input. + dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`. + mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. + final_dropout (`bool` *optional*, defaults to False): Apply a final dropout. + """ + + def __init__( + self, + dim: int, + dim_out: Optional[int] = None, + mult: int = 4, + dropout: float = 0.0, + activation_fn: str = "geglu", + final_dropout: bool = False, + ): + super().__init__() + inner_dim = int(dim * mult) + dim_out = dim_out if dim_out is not None else dim + linear_cls = LoRACompatibleLinear if not USE_PEFT_BACKEND else nn.Linear + + if activation_fn == "gelu": + act_fn = GELU(dim, inner_dim) + if activation_fn == "gelu-approximate": + act_fn = GELU(dim, inner_dim, approximate="tanh") + elif activation_fn == "geglu": + act_fn = GEGLU(dim, inner_dim) + elif activation_fn == "geglu-approximate": + act_fn = ApproximateGELU(dim, inner_dim) + + self.net = nn.ModuleList([]) + # project in + self.net.append(act_fn) + # project dropout + self.net.append(nn.Dropout(dropout)) + # project out + self.net.append(linear_cls(inner_dim, dim_out)) + # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout + if final_dropout: + self.net.append(nn.Dropout(dropout)) + + def forward(self, hidden_states: torch.Tensor, scale: float = 1.0) -> torch.Tensor: + compatible_cls = (GEGLU,) if USE_PEFT_BACKEND else (GEGLU, LoRACompatibleLinear) + for module in self.net: + if isinstance(module, compatible_cls): + hidden_states = module(hidden_states, scale) + else: + hidden_states = module(hidden_states) + return hidden_states + + +@maybe_allow_in_graph +class BasicTransformerBlock_(nn.Module): + r""" + A basic Transformer block. + + Parameters: + dim (`int`): The number of channels in the input and output. + num_attention_heads (`int`): The number of heads to use for multi-head attention. + attention_head_dim (`int`): The number of channels in each head. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention. + activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. + num_embeds_ada_norm (: + obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`. + attention_bias (: + obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter. + only_cross_attention (`bool`, *optional*): + Whether to use only cross-attention layers. In this case two cross attention layers are used. + double_self_attention (`bool`, *optional*): + Whether to use two self-attention layers. In this case no cross attention layers are used. + upcast_attention (`bool`, *optional*): + Whether to upcast the attention computation to float32. This is useful for mixed precision training. + norm_elementwise_affine (`bool`, *optional*, defaults to `True`): + Whether to use learnable elementwise affine parameters for normalization. + norm_type (`str`, *optional*, defaults to `"layer_norm"`): + The normalization layer to use. Can be `"layer_norm"`, `"ada_norm"` or `"ada_norm_zero"`. + final_dropout (`bool` *optional*, defaults to False): + Whether to apply a final dropout after the last feed-forward layer. + attention_type (`str`, *optional*, defaults to `"default"`): + The type of attention to use. Can be `"default"` or `"gated"` or `"gated-text-image"`. + positional_embeddings (`str`, *optional*, defaults to `None`): + The type of positional embeddings to apply to. + num_positional_embeddings (`int`, *optional*, defaults to `None`): + The maximum number of positional embeddings to apply. + """ + + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + dropout=0.0, + cross_attention_dim: Optional[int] = None, + activation_fn: str = "geglu", + num_embeds_ada_norm: Optional[int] = None, + attention_bias: bool = False, + only_cross_attention: bool = False, + double_self_attention: bool = False, + upcast_attention: bool = False, + norm_elementwise_affine: bool = True, + norm_type: str = "layer_norm", # 'layer_norm', 'ada_norm', 'ada_norm_zero', 'ada_norm_single' + norm_eps: float = 1e-5, + final_dropout: bool = False, + attention_type: str = "default", + positional_embeddings: Optional[str] = None, + num_positional_embeddings: Optional[int] = None, + attention_mode: str = "xformers", + ): + super().__init__() + self.only_cross_attention = only_cross_attention + + self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero" + self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm" + self.use_ada_layer_norm_single = norm_type == "ada_norm_single" + self.use_layer_norm = norm_type == "layer_norm" + + if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None: + raise ValueError( + f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to" + f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}." + ) + + if positional_embeddings and (num_positional_embeddings is None): + raise ValueError( + "If `positional_embedding` type is defined, `num_positition_embeddings` must also be defined." + ) + + if positional_embeddings == "sinusoidal": + self.pos_embed = SinusoidalPositionalEmbedding(dim, max_seq_length=num_positional_embeddings) + else: + self.pos_embed = None + + # Define 3 blocks. Each block has its own normalization layer. + # 1. Self-Attn + if self.use_ada_layer_norm: + self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) + elif self.use_ada_layer_norm_zero: + self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm) + else: + self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps) + + self.attn1 = Attention( + query_dim=dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + cross_attention_dim=cross_attention_dim if only_cross_attention else None, + upcast_attention=upcast_attention, + attention_mode=attention_mode + ) + + # # 2. Cross-Attn + # if cross_attention_dim is not None or double_self_attention: + # # We currently only use AdaLayerNormZero for self attention where there will only be one attention block. + # # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during + # # the second cross attention block. + # self.norm2 = ( + # AdaLayerNorm(dim, num_embeds_ada_norm) + # if self.use_ada_layer_norm + # else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps) + # ) + # self.attn2 = Attention( + # query_dim=dim, + # cross_attention_dim=cross_attention_dim if not double_self_attention else None, + # heads=num_attention_heads, + # dim_head=attention_head_dim, + # dropout=dropout, + # bias=attention_bias, + # upcast_attention=upcast_attention, + # ) # is self-attn if encoder_hidden_states is none + # else: + # self.norm2 = None + # self.attn2 = None + + # 3. Feed-forward + # if not self.use_ada_layer_norm_single: + # self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps) + self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps) + + self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout) + + # 4. Fuser + if attention_type == "gated" or attention_type == "gated-text-image": + self.fuser = GatedSelfAttentionDense(dim, cross_attention_dim, num_attention_heads, attention_head_dim) + + # 5. Scale-shift for PixArt-Alpha. + if self.use_ada_layer_norm_single: + self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim ** 0.5) + + # let chunk size default to None + self._chunk_size = None + self._chunk_dim = 0 + + def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int): + # Sets chunk feed-forward + self._chunk_size = chunk_size + self._chunk_dim = dim + + def forward( + self, + hidden_states: torch.FloatTensor, + attention_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + timestep: Optional[torch.LongTensor] = None, + cross_attention_kwargs: Dict[str, Any] = None, + class_labels: Optional[torch.LongTensor] = None, + ) -> torch.FloatTensor: + # Notice that normalization is always applied before the real computation in the following blocks. + # 0. Self-Attention + batch_size = hidden_states.shape[0] + + if self.use_ada_layer_norm: + norm_hidden_states = self.norm1(hidden_states, timestep) + elif self.use_ada_layer_norm_zero: + norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1( + hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype + ) + elif self.use_layer_norm: + norm_hidden_states = self.norm1(hidden_states) + elif self.use_ada_layer_norm_single: + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ( + self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1) + ).chunk(6, dim=1) + norm_hidden_states = self.norm1(hidden_states) + norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa + norm_hidden_states = norm_hidden_states.squeeze(1) + else: + raise ValueError("Incorrect norm used") + + if self.pos_embed is not None: + norm_hidden_states = self.pos_embed(norm_hidden_states) + + # 1. Retrieve lora scale. + lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0 + + # 2. Prepare GLIGEN inputs + cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {} + gligen_kwargs = cross_attention_kwargs.pop("gligen", None) + + attn_output = self.attn1( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None, + attention_mask=attention_mask, + **cross_attention_kwargs, + ) + if self.use_ada_layer_norm_zero: + attn_output = gate_msa.unsqueeze(1) * attn_output + elif self.use_ada_layer_norm_single: + attn_output = gate_msa * attn_output + + hidden_states = attn_output + hidden_states + if hidden_states.ndim == 4: + hidden_states = hidden_states.squeeze(1) + + # 2.5 GLIGEN Control + if gligen_kwargs is not None: + hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"]) + + # # 3. Cross-Attention + # if self.attn2 is not None: + # if self.use_ada_layer_norm: + # norm_hidden_states = self.norm2(hidden_states, timestep) + # elif self.use_ada_layer_norm_zero or self.use_layer_norm: + # norm_hidden_states = self.norm2(hidden_states) + # elif self.use_ada_layer_norm_single: + # # For PixArt norm2 isn't applied here: + # # https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L70C1-L76C103 + # norm_hidden_states = hidden_states + # else: + # raise ValueError("Incorrect norm") + + # if self.pos_embed is not None and self.use_ada_layer_norm_single is False: + # norm_hidden_states = self.pos_embed(norm_hidden_states) + + # attn_output = self.attn2( + # norm_hidden_states, + # encoder_hidden_states=encoder_hidden_states, + # attention_mask=encoder_attention_mask, + # **cross_attention_kwargs, + # ) + # hidden_states = attn_output + hidden_states + + # 4. Feed-forward + # if not self.use_ada_layer_norm_single: + # norm_hidden_states = self.norm3(hidden_states) + + if self.use_ada_layer_norm_zero: + norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] + + if self.use_ada_layer_norm_single: + # norm_hidden_states = self.norm2(hidden_states) + norm_hidden_states = self.norm3(hidden_states) + norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp + + if self._chunk_size is not None: + # "feed_forward_chunk_size" can be used to save memory + if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0: + raise ValueError( + f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: {self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`." + ) + + num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size + ff_output = torch.cat( + [ + self.ff(hid_slice, scale=lora_scale) + for hid_slice in norm_hidden_states.chunk(num_chunks, dim=self._chunk_dim) + ], + dim=self._chunk_dim, + ) + else: + ff_output = self.ff(norm_hidden_states, scale=lora_scale) + + if self.use_ada_layer_norm_zero: + ff_output = gate_mlp.unsqueeze(1) * ff_output + elif self.use_ada_layer_norm_single: + ff_output = gate_mlp * ff_output + + hidden_states = ff_output + hidden_states + if hidden_states.ndim == 4: + hidden_states = hidden_states.squeeze(1) + + return hidden_states + + +@maybe_allow_in_graph +class BasicTransformerBlock(nn.Module): + r""" + A basic Transformer block. + + Parameters: + dim (`int`): The number of channels in the input and output. + num_attention_heads (`int`): The number of heads to use for multi-head attention. + attention_head_dim (`int`): The number of channels in each head. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention. + activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. + num_embeds_ada_norm (: + obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`. + attention_bias (: + obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter. + only_cross_attention (`bool`, *optional*): + Whether to use only cross-attention layers. In this case two cross attention layers are used. + double_self_attention (`bool`, *optional*): + Whether to use two self-attention layers. In this case no cross attention layers are used. + upcast_attention (`bool`, *optional*): + Whether to upcast the attention computation to float32. This is useful for mixed precision training. + norm_elementwise_affine (`bool`, *optional*, defaults to `True`): + Whether to use learnable elementwise affine parameters for normalization. + norm_type (`str`, *optional*, defaults to `"layer_norm"`): + The normalization layer to use. Can be `"layer_norm"`, `"ada_norm"` or `"ada_norm_zero"`. + final_dropout (`bool` *optional*, defaults to False): + Whether to apply a final dropout after the last feed-forward layer. + attention_type (`str`, *optional*, defaults to `"default"`): + The type of attention to use. Can be `"default"` or `"gated"` or `"gated-text-image"`. + positional_embeddings (`str`, *optional*, defaults to `None`): + The type of positional embeddings to apply to. + num_positional_embeddings (`int`, *optional*, defaults to `None`): + The maximum number of positional embeddings to apply. + """ + + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + dropout=0.0, + cross_attention_dim: Optional[int] = None, + activation_fn: str = "geglu", + num_embeds_ada_norm: Optional[int] = None, + attention_bias: bool = False, + only_cross_attention: bool = False, + double_self_attention: bool = False, + upcast_attention: bool = False, + norm_elementwise_affine: bool = True, + norm_type: str = "layer_norm", # 'layer_norm', 'ada_norm', 'ada_norm_zero', 'ada_norm_single' + norm_eps: float = 1e-5, + final_dropout: bool = False, + attention_type: str = "default", + positional_embeddings: Optional[str] = None, + num_positional_embeddings: Optional[int] = None, + attention_mode: str = "xformers" + ): + super().__init__() + self.only_cross_attention = only_cross_attention + + self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero" + self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm" + self.use_ada_layer_norm_single = norm_type == "ada_norm_single" + self.use_layer_norm = norm_type == "layer_norm" + + if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None: + raise ValueError( + f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to" + f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}." + ) + + if positional_embeddings and (num_positional_embeddings is None): + raise ValueError( + "If `positional_embedding` type is defined, `num_positition_embeddings` must also be defined." + ) + + if positional_embeddings == "sinusoidal": + self.pos_embed = SinusoidalPositionalEmbedding(dim, max_seq_length=num_positional_embeddings) + else: + self.pos_embed = None + + # Define 3 blocks. Each block has its own normalization layer. + # 1. Self-Attn + if self.use_ada_layer_norm: + self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) + elif self.use_ada_layer_norm_zero: + self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm) + else: + self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps) + + self.attn1 = Attention( + query_dim=dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + cross_attention_dim=cross_attention_dim if only_cross_attention else None, + upcast_attention=upcast_attention, + attention_mode=attention_mode + ) + + # 2. Cross-Attn + if cross_attention_dim is not None or double_self_attention: + # We currently only use AdaLayerNormZero for self attention where there will only be one attention block. + # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during + # the second cross attention block. + self.norm2 = ( + AdaLayerNorm(dim, num_embeds_ada_norm) + if self.use_ada_layer_norm + else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps) + ) + self.attn2 = Attention( + query_dim=dim, + cross_attention_dim=cross_attention_dim if not double_self_attention else None, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + upcast_attention=upcast_attention, + attention_mode='xformers', # only xformers support attention_mask + ) # is self-attn if encoder_hidden_states is none + else: + self.norm2 = None + self.attn2 = None + + # 3. Feed-forward + if not self.use_ada_layer_norm_single: + self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps) + + self.ff = FeedForward( + dim, + dropout=dropout, + activation_fn=activation_fn, + final_dropout=final_dropout, + ) + + # 4. Fuser + if attention_type == "gated" or attention_type == "gated-text-image": + self.fuser = GatedSelfAttentionDense(dim, cross_attention_dim, num_attention_heads, attention_head_dim) + + # 5. Scale-shift for PixArt-Alpha. + if self.use_ada_layer_norm_single: + self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5) + + # let chunk size default to None + self._chunk_size = None + self._chunk_dim = 0 + + def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0): + # Sets chunk feed-forward + self._chunk_size = chunk_size + self._chunk_dim = dim + + def forward( + self, + hidden_states: torch.FloatTensor, + attention_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + timestep: Optional[torch.LongTensor] = None, + cross_attention_kwargs: Dict[str, Any] = None, + class_labels: Optional[torch.LongTensor] = None, + ) -> torch.FloatTensor: + # Notice that normalization is always applied before the real computation in the following blocks. + # 0. Self-Attention + batch_size = hidden_states.shape[0] + + if self.use_ada_layer_norm: + norm_hidden_states = self.norm1(hidden_states, timestep) + elif self.use_ada_layer_norm_zero: + norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1( + hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype + ) + elif self.use_layer_norm: + norm_hidden_states = self.norm1(hidden_states) + elif self.use_ada_layer_norm_single: + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ( + self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1) + ).chunk(6, dim=1) + norm_hidden_states = self.norm1(hidden_states) + norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa + norm_hidden_states = norm_hidden_states.squeeze(1) + else: + raise ValueError("Incorrect norm used") + + if self.pos_embed is not None: + norm_hidden_states = self.pos_embed(norm_hidden_states) + + # 1. Retrieve lora scale. + lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0 + + # 2. Prepare GLIGEN inputs + cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {} + gligen_kwargs = cross_attention_kwargs.pop("gligen", None) + + attn_output = self.attn1( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None, + attention_mask=attention_mask, + **cross_attention_kwargs, + ) + if self.use_ada_layer_norm_zero: + attn_output = gate_msa.unsqueeze(1) * attn_output + elif self.use_ada_layer_norm_single: + attn_output = gate_msa * attn_output + + hidden_states = attn_output + hidden_states + if hidden_states.ndim == 4: + hidden_states = hidden_states.squeeze(1) + + # 2.5 GLIGEN Control + if gligen_kwargs is not None: + hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"]) + + # 3. Cross-Attention + if self.attn2 is not None: + if self.use_ada_layer_norm: + norm_hidden_states = self.norm2(hidden_states, timestep) + elif self.use_ada_layer_norm_zero or self.use_layer_norm: + norm_hidden_states = self.norm2(hidden_states) + elif self.use_ada_layer_norm_single: + # For PixArt norm2 isn't applied here: + # https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L70C1-L76C103 + norm_hidden_states = hidden_states + else: + raise ValueError("Incorrect norm") + + if self.pos_embed is not None and self.use_ada_layer_norm_single is False: + norm_hidden_states = self.pos_embed(norm_hidden_states) + + attn_output = self.attn2( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=encoder_attention_mask, + **cross_attention_kwargs, + ) + hidden_states = attn_output + hidden_states + + # 4. Feed-forward + if not self.use_ada_layer_norm_single: + norm_hidden_states = self.norm3(hidden_states) + + if self.use_ada_layer_norm_zero: + norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] + + if self.use_ada_layer_norm_single: + norm_hidden_states = self.norm2(hidden_states) + norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp + + if self._chunk_size is not None: + # "feed_forward_chunk_size" can be used to save memory + ff_output = _chunked_feed_forward( + self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size, lora_scale=lora_scale + ) + else: + ff_output = self.ff(norm_hidden_states, scale=lora_scale) + + if self.use_ada_layer_norm_zero: + ff_output = gate_mlp.unsqueeze(1) * ff_output + elif self.use_ada_layer_norm_single: + ff_output = gate_mlp * ff_output + + hidden_states = ff_output + hidden_states + if hidden_states.ndim == 4: + hidden_states = hidden_states.squeeze(1) + + return hidden_states + +class AdaLayerNormSingle(nn.Module): + r""" + Norm layer adaptive layer norm single (adaLN-single). + + As proposed in PixArt-Alpha (see: https://arxiv.org/abs/2310.00426; Section 2.3). + + Parameters: + embedding_dim (`int`): The size of each embedding vector. + use_additional_conditions (`bool`): To use additional conditions for normalization or not. + """ + + def __init__(self, embedding_dim: int, use_additional_conditions: bool = False): + super().__init__() + + self.emb = CombinedTimestepSizeEmbeddings( + embedding_dim, size_emb_dim=embedding_dim // 3, use_additional_conditions=use_additional_conditions + ) + + self.silu = nn.SiLU() + self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=True) + + def forward( + self, + timestep: torch.Tensor, + added_cond_kwargs: Dict[str, torch.Tensor] = None, + batch_size: int = None, + hidden_dtype: Optional[torch.dtype] = None, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + # No modulation happening here. + embedded_timestep = self.emb(timestep, batch_size=batch_size, hidden_dtype=hidden_dtype, resolution=None, + aspect_ratio=None) + return self.linear(self.silu(embedded_timestep)), embedded_timestep + + +@dataclass +class Transformer3DModelOutput(BaseOutput): + """ + The output of [`Transformer2DModel`]. + + Args: + sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete): + The hidden states output conditioned on the `encoder_hidden_states` input. If discrete, returns probability + distributions for the unnoised latent pixels. + """ + + sample: torch.FloatTensor diff --git a/opensora/models/diffusion/latte/pos.py b/opensora/models/diffusion/latte/pos.py new file mode 100644 index 0000000000000000000000000000000000000000..6a735a6acabaf5fa7b97a718653676c4f37a99ac --- /dev/null +++ b/opensora/models/diffusion/latte/pos.py @@ -0,0 +1,142 @@ +# import numpy as np +# import torch +# import torch.nn as nn +# from math import pi +# from einops import rearrange, repeat +# +# ################################################################################# +# # Sine/Cosine Positional Embedding Functions # +# ################################################################################# +# # https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py +# +# def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0): +# """ +# grid_size: int of the grid height and width +# return: +# pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) +# """ +# grid_h = np.arange(grid_size, dtype=np.float32) +# grid_w = np.arange(grid_size, dtype=np.float32) +# grid = np.meshgrid(grid_w, grid_h) # here w goes first +# grid = np.stack(grid, axis=0) +# +# grid = grid.reshape([2, 1, grid_size, grid_size]) +# pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) +# if cls_token and extra_tokens > 0: +# pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0) +# return pos_embed +# +# +# def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): +# assert embed_dim % 2 == 0 +# +# # use half of dimensions to encode grid_h +# emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) +# emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) +# +# emb = np.concatenate([emb_h, emb_w], axis=1) +# return emb +# +# +# def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): +# """ +# embed_dim: output dimension for each position +# pos: a list of positions to be encoded: size (M,) +# out: (M, D) +# """ +# assert embed_dim % 2 == 0 +# omega = np.arange(embed_dim // 2, dtype=np.float64) +# omega /= embed_dim / 2. +# omega = 1. / 10000**omega +# +# pos = pos.reshape(-1) +# out = np.einsum('m,d->md', pos, omega) +# +# emb_sin = np.sin(out) +# emb_cos = np.cos(out) +# +# emb = np.concatenate([emb_sin, emb_cos], axis=1) +# return emb +# +# def broadcat(tensors, dim=-1): +# num_tensors = len(tensors) +# shape_lens = set(list(map(lambda t: len(t.shape), tensors))) +# assert len(shape_lens) == 1, 'tensors must all have the same number of dimensions' +# shape_len = list(shape_lens)[0] +# dim = (dim + shape_len) if dim < 0 else dim +# dims = list(zip(*map(lambda t: list(t.shape), tensors))) +# expandable_dims = [(i, val) for i, val in enumerate(dims) if i != dim] +# assert all([*map(lambda t: len(set(t[1])) <= 2, expandable_dims)]), 'invalid dimensions for broadcastable concatentation' +# max_dims = list(map(lambda t: (t[0], max(t[1])), expandable_dims)) +# expanded_dims = list(map(lambda t: (t[0], (t[1],) * num_tensors), max_dims)) +# expanded_dims.insert(dim, (dim, dims[dim])) +# expandable_shapes = list(zip(*map(lambda t: t[1], expanded_dims))) +# tensors = list(map(lambda t: t[0].expand(*t[1]), zip(tensors, expandable_shapes))) +# return torch.cat(tensors, dim=dim) +# +# +# def rotate_half(x): +# x = rearrange(x, '... (d r) -> ... d r', r=2) +# x1, x2 = x.unbind(dim=-1) +# x = torch.stack((-x2, x1), dim=-1) +# return rearrange(x, '... d r -> ... (d r)') +# +# ################################################################################# +# # VisionRotary # +# ################################################################################# +# # References: +# # EVA: https://github.com/baaivision/EVA +# # Transformer升级之路:2、博采众长的旋转式位置编码: https://spaces.ac.cn/archives/8265 +# # Transformer升级之路:4、二维位置的旋转式位置编码: https://spaces.ac.cn/archives/8397 +# +# class VisionRotaryEmbeddingFast(nn.Module): +# def __init__( +# self, +# dim, +# pt_hw=(int, int), # (H, W) +# ft_hw=None, +# custom_freqs = None, +# freqs_for = 'lang', +# theta = 10000, +# max_freq = 10, +# num_freqs = 1, +# ): +# super().__init__() +# # Unlike a 1d RoPE, a 2d RoPE requires splitting the dimension into four parts +# # References: https://spaces.ac.cn/archives/8397 +# +# if custom_freqs: +# freqs = custom_freqs +# elif freqs_for == 'lang': +# freqs = 1. / (theta ** (torch.arange(0, dim, 2)[:(dim // 2)].float() / dim)) +# elif freqs_for == 'pixel': +# freqs = torch.linspace(1., max_freq / 2, dim // 2) * pi +# elif freqs_for == 'constant': +# freqs = torch.ones(num_freqs).float() +# else: +# raise ValueError(f'unknown modality {freqs_for}') +# +# if ft_hw is None: ft_hw = pt_hw +# h_t = torch.arange(ft_hw[0]) / ft_hw[0] * pt_hw[0] +# w_t = torch.arange(ft_hw[1]) / ft_hw[1] * pt_hw[1] +# +# h_freqs = torch.einsum('..., f -> ... f', h_t, freqs) +# w_freqs = torch.einsum('..., f -> ... f', w_t, freqs) +# +# h_freqs = repeat(h_freqs, '... n -> ... (n r)', r=2) +# w_freqs = repeat(w_freqs, '... n -> ... (n r)', r=2) +# +# freqs = broadcat((h_freqs[:, None, :], w_freqs[None, :, :]), dim=-1) +# freqs_cos = freqs.cos().view(-1, freqs.shape[-1]) +# freqs_sin = freqs.sin().view(-1, freqs.shape[-1]) +# +# self.register_buffer("freqs_cos", freqs_cos) +# self.register_buffer("freqs_sin", freqs_sin) +# +# def forward(self, t): +# # 2d RoPE: [[cos(h*theta), -sin(h*theta), 0, 0 ], +# # [sin(h*theta), cos(h*theta), 0, 0 ], +# # [0, 0, cos(w*theta), -sin(w*theta)], +# # [0, 0, sin(w*theta), cos(w*theta) ],] +# +# return t * self.freqs_cos + rotate_half(t) * self.freqs_sin \ No newline at end of file diff --git a/opensora/models/diffusion/transport/__init__.py b/opensora/models/diffusion/transport/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..db68edd20c9716e74ef1c853e968227efe45be29 --- /dev/null +++ b/opensora/models/diffusion/transport/__init__.py @@ -0,0 +1,63 @@ +from .transport import Transport, ModelType, WeightType, PathType, Sampler + +def create_transport( + path_type='Linear', + prediction="velocity", + loss_weight=None, + train_eps=None, + sample_eps=None, +): + """function for creating Transport object + **Note**: model prediction defaults to velocity + Args: + - path_type: type of path to use; default to linear + - learn_score: set model prediction to score + - learn_noise: set model prediction to noise + - velocity_weighted: weight loss by velocity weight + - likelihood_weighted: weight loss by likelihood weight + - train_eps: small epsilon for avoiding instability during training + - sample_eps: small epsilon for avoiding instability during sampling + """ + + if prediction == "noise": + model_type = ModelType.NOISE + elif prediction == "score": + model_type = ModelType.SCORE + else: + model_type = ModelType.VELOCITY + + if loss_weight == "velocity": + loss_type = WeightType.VELOCITY + elif loss_weight == "likelihood": + loss_type = WeightType.LIKELIHOOD + else: + loss_type = WeightType.NONE + + path_choice = { + "Linear": PathType.LINEAR, + "GVP": PathType.GVP, + "VP": PathType.VP, + } + + path_type = path_choice[path_type] + + if (path_type in [PathType.VP]): + train_eps = 1e-5 if train_eps is None else train_eps + sample_eps = 1e-3 if train_eps is None else sample_eps + elif (path_type in [PathType.GVP, PathType.LINEAR] and model_type != ModelType.VELOCITY): + train_eps = 1e-3 if train_eps is None else train_eps + sample_eps = 1e-3 if train_eps is None else sample_eps + else: # velocity & [GVP, LINEAR] is stable everywhere + train_eps = 0 + sample_eps = 0 + + # create flow state + state = Transport( + model_type=model_type, + path_type=path_type, + loss_type=loss_type, + train_eps=train_eps, + sample_eps=sample_eps, + ) + + return state \ No newline at end of file diff --git a/opensora/models/diffusion/transport/integrators.py b/opensora/models/diffusion/transport/integrators.py new file mode 100644 index 0000000000000000000000000000000000000000..adf7c7b4c50b6ff6c63973e0ddaa65b9759274c0 --- /dev/null +++ b/opensora/models/diffusion/transport/integrators.py @@ -0,0 +1,117 @@ +import numpy as np +import torch as th +import torch.nn as nn +from torchdiffeq import odeint +from functools import partial +from tqdm import tqdm + +class sde: + """SDE solver class""" + def __init__( + self, + drift, + diffusion, + *, + t0, + t1, + num_steps, + sampler_type, + ): + assert t0 < t1, "SDE sampler has to be in forward time" + + self.num_timesteps = num_steps + self.t = th.linspace(t0, t1, num_steps) + self.dt = self.t[1] - self.t[0] + self.drift = drift + self.diffusion = diffusion + self.sampler_type = sampler_type + + def __Euler_Maruyama_step(self, x, mean_x, t, model, **model_kwargs): + w_cur = th.randn(x.size()).to(x) + t = th.ones(x.size(0)).to(x) * t + dw = w_cur * th.sqrt(self.dt) + drift = self.drift(x, t, model, **model_kwargs) + diffusion = self.diffusion(x, t) + mean_x = x + drift * self.dt + x = mean_x + th.sqrt(2 * diffusion) * dw + return x, mean_x + + def __Heun_step(self, x, _, t, model, **model_kwargs): + w_cur = th.randn(x.size()).to(x) + dw = w_cur * th.sqrt(self.dt) + t_cur = th.ones(x.size(0)).to(x) * t + diffusion = self.diffusion(x, t_cur) + xhat = x + th.sqrt(2 * diffusion) * dw + K1 = self.drift(xhat, t_cur, model, **model_kwargs) + xp = xhat + self.dt * K1 + K2 = self.drift(xp, t_cur + self.dt, model, **model_kwargs) + return xhat + 0.5 * self.dt * (K1 + K2), xhat # at last time point we do not perform the heun step + + def __forward_fn(self): + """TODO: generalize here by adding all private functions ending with steps to it""" + sampler_dict = { + "Euler": self.__Euler_Maruyama_step, + "Heun": self.__Heun_step, + } + + try: + sampler = sampler_dict[self.sampler_type] + except: + raise NotImplementedError("Smapler type not implemented.") + + return sampler + + def sample(self, init, model, **model_kwargs): + """forward loop of sde""" + x = init + mean_x = init + samples = [] + sampler = self.__forward_fn() + for ti in self.t[:-1]: + with th.no_grad(): + x, mean_x = sampler(x, mean_x, ti, model, **model_kwargs) + samples.append(x) + + return samples + +class ode: + """ODE solver class""" + def __init__( + self, + drift, + *, + t0, + t1, + sampler_type, + num_steps, + atol, + rtol, + ): + assert t0 < t1, "ODE sampler has to be in forward time" + + self.drift = drift + self.t = th.linspace(t0, t1, num_steps) + self.atol = atol + self.rtol = rtol + self.sampler_type = sampler_type + + def sample(self, x, model, **model_kwargs): + + device = x[0].device if isinstance(x, tuple) else x.device + def _fn(t, x): + t = th.ones(x[0].size(0)).to(device) * t if isinstance(x, tuple) else th.ones(x.size(0)).to(device) * t + model_output = self.drift(x, t, model, **model_kwargs) + return model_output + + t = self.t.to(device) + atol = [self.atol] * len(x) if isinstance(x, tuple) else [self.atol] + rtol = [self.rtol] * len(x) if isinstance(x, tuple) else [self.rtol] + samples = odeint( + _fn, + x, + t, + method=self.sampler_type, + atol=atol, + rtol=rtol + ) + return samples \ No newline at end of file diff --git a/opensora/models/diffusion/transport/path.py b/opensora/models/diffusion/transport/path.py new file mode 100644 index 0000000000000000000000000000000000000000..156a7b0dea03497a85306ebbeedfe4fbedf87c27 --- /dev/null +++ b/opensora/models/diffusion/transport/path.py @@ -0,0 +1,192 @@ +import torch as th +import numpy as np +from functools import partial + +def expand_t_like_x(t, x): + """Function to reshape time t to broadcastable dimension of x + Args: + t: [batch_dim,], time vector + x: [batch_dim,...], data point + """ + dims = [1] * (len(x.size()) - 1) + t = t.view(t.size(0), *dims) + return t + + +#################### Coupling Plans #################### + +class ICPlan: + """Linear Coupling Plan""" + def __init__(self, sigma=0.0): + self.sigma = sigma + + def compute_alpha_t(self, t): + """Compute the data coefficient along the path""" + return t, 1 + + def compute_sigma_t(self, t): + """Compute the noise coefficient along the path""" + return 1 - t, -1 + + def compute_d_alpha_alpha_ratio_t(self, t): + """Compute the ratio between d_alpha and alpha""" + return 1 / t + + def compute_drift(self, x, t): + """We always output sde according to score parametrization; """ + t = expand_t_like_x(t, x) + alpha_ratio = self.compute_d_alpha_alpha_ratio_t(t) + sigma_t, d_sigma_t = self.compute_sigma_t(t) + drift = alpha_ratio * x + diffusion = alpha_ratio * (sigma_t ** 2) - sigma_t * d_sigma_t + + return -drift, diffusion + + def compute_diffusion(self, x, t, form="constant", norm=1.0): + """Compute the diffusion term of the SDE + Args: + x: [batch_dim, ...], data point + t: [batch_dim,], time vector + form: str, form of the diffusion term + norm: float, norm of the diffusion term + """ + t = expand_t_like_x(t, x) + choices = { + "constant": norm, + "SBDM": norm * self.compute_drift(x, t)[1], + "sigma": norm * self.compute_sigma_t(t)[0], + "linear": norm * (1 - t), + "decreasing": 0.25 * (norm * th.cos(np.pi * t) + 1) ** 2, + "inccreasing-decreasing": norm * th.sin(np.pi * t) ** 2, + } + + try: + diffusion = choices[form] + except KeyError: + raise NotImplementedError(f"Diffusion form {form} not implemented") + + return diffusion + + def get_score_from_velocity(self, velocity, x, t): + """Wrapper function: transfrom velocity prediction model to score + Args: + velocity: [batch_dim, ...] shaped tensor; velocity model output + x: [batch_dim, ...] shaped tensor; x_t data point + t: [batch_dim,] time tensor + """ + t = expand_t_like_x(t, x) + alpha_t, d_alpha_t = self.compute_alpha_t(t) + sigma_t, d_sigma_t = self.compute_sigma_t(t) + mean = x + reverse_alpha_ratio = alpha_t / d_alpha_t + var = sigma_t**2 - reverse_alpha_ratio * d_sigma_t * sigma_t + score = (reverse_alpha_ratio * velocity - mean) / var + return score + + def get_noise_from_velocity(self, velocity, x, t): + """Wrapper function: transfrom velocity prediction model to denoiser + Args: + velocity: [batch_dim, ...] shaped tensor; velocity model output + x: [batch_dim, ...] shaped tensor; x_t data point + t: [batch_dim,] time tensor + """ + t = expand_t_like_x(t, x) + alpha_t, d_alpha_t = self.compute_alpha_t(t) + sigma_t, d_sigma_t = self.compute_sigma_t(t) + mean = x + reverse_alpha_ratio = alpha_t / d_alpha_t + var = reverse_alpha_ratio * d_sigma_t - sigma_t + noise = (reverse_alpha_ratio * velocity - mean) / var + return noise + + def get_velocity_from_score(self, score, x, t): + """Wrapper function: transfrom score prediction model to velocity + Args: + score: [batch_dim, ...] shaped tensor; score model output + x: [batch_dim, ...] shaped tensor; x_t data point + t: [batch_dim,] time tensor + """ + t = expand_t_like_x(t, x) + drift, var = self.compute_drift(x, t) + velocity = var * score - drift + return velocity + + def compute_mu_t(self, t, x0, x1): + """Compute the mean of time-dependent density p_t""" + t = expand_t_like_x(t, x1) + alpha_t, _ = self.compute_alpha_t(t) + sigma_t, _ = self.compute_sigma_t(t) + return alpha_t * x1 + sigma_t * x0 + + def compute_xt(self, t, x0, x1): + """Sample xt from time-dependent density p_t; rng is required""" + xt = self.compute_mu_t(t, x0, x1) + return xt + + def compute_ut(self, t, x0, x1, xt): + """Compute the vector field corresponding to p_t""" + t = expand_t_like_x(t, x1) + _, d_alpha_t = self.compute_alpha_t(t) + _, d_sigma_t = self.compute_sigma_t(t) + return d_alpha_t * x1 + d_sigma_t * x0 + + def plan(self, t, x0, x1): + xt = self.compute_xt(t, x0, x1) + ut = self.compute_ut(t, x0, x1, xt) + return t, xt, ut + + +class VPCPlan(ICPlan): + """class for VP path flow matching""" + + def __init__(self, sigma_min=0.1, sigma_max=20.0): + self.sigma_min = sigma_min + self.sigma_max = sigma_max + self.log_mean_coeff = lambda t: -0.25 * ((1 - t) ** 2) * (self.sigma_max - self.sigma_min) - 0.5 * (1 - t) * self.sigma_min + self.d_log_mean_coeff = lambda t: 0.5 * (1 - t) * (self.sigma_max - self.sigma_min) + 0.5 * self.sigma_min + + + def compute_alpha_t(self, t): + """Compute coefficient of x1""" + alpha_t = self.log_mean_coeff(t) + alpha_t = th.exp(alpha_t) + d_alpha_t = alpha_t * self.d_log_mean_coeff(t) + return alpha_t, d_alpha_t + + def compute_sigma_t(self, t): + """Compute coefficient of x0""" + p_sigma_t = 2 * self.log_mean_coeff(t) + sigma_t = th.sqrt(1 - th.exp(p_sigma_t)) + d_sigma_t = th.exp(p_sigma_t) * (2 * self.d_log_mean_coeff(t)) / (-2 * sigma_t) + return sigma_t, d_sigma_t + + def compute_d_alpha_alpha_ratio_t(self, t): + """Special purposed function for computing numerical stabled d_alpha_t / alpha_t""" + return self.d_log_mean_coeff(t) + + def compute_drift(self, x, t): + """Compute the drift term of the SDE""" + t = expand_t_like_x(t, x) + beta_t = self.sigma_min + (1 - t) * (self.sigma_max - self.sigma_min) + return -0.5 * beta_t * x, beta_t / 2 + + +class GVPCPlan(ICPlan): + def __init__(self, sigma=0.0): + super().__init__(sigma) + + def compute_alpha_t(self, t): + """Compute coefficient of x1""" + alpha_t = th.sin(t * np.pi / 2) + d_alpha_t = np.pi / 2 * th.cos(t * np.pi / 2) + return alpha_t, d_alpha_t + + def compute_sigma_t(self, t): + """Compute coefficient of x0""" + sigma_t = th.cos(t * np.pi / 2) + d_sigma_t = -np.pi / 2 * th.sin(t * np.pi / 2) + return sigma_t, d_sigma_t + + def compute_d_alpha_alpha_ratio_t(self, t): + """Special purposed function for computing numerical stabled d_alpha_t / alpha_t""" + return np.pi / (2 * th.tan(t * np.pi / 2)) \ No newline at end of file diff --git a/opensora/models/diffusion/transport/transport.py b/opensora/models/diffusion/transport/transport.py new file mode 100644 index 0000000000000000000000000000000000000000..396c516cfc64516a39212d95ff895c98135eef17 --- /dev/null +++ b/opensora/models/diffusion/transport/transport.py @@ -0,0 +1,443 @@ +import torch as th +import numpy as np +import logging + +import enum + +from . import path +from .utils import EasyDict, log_state, mean_flat +from .integrators import ode, sde + +class ModelType(enum.Enum): + """ + Which type of output the model predicts. + """ + + NOISE = enum.auto() # the model predicts epsilon + SCORE = enum.auto() # the model predicts \nabla \log p(x) + VELOCITY = enum.auto() # the model predicts v(x) + +class PathType(enum.Enum): + """ + Which type of path to use. + """ + + LINEAR = enum.auto() + GVP = enum.auto() + VP = enum.auto() + +class WeightType(enum.Enum): + """ + Which type of weighting to use. + """ + + NONE = enum.auto() + VELOCITY = enum.auto() + LIKELIHOOD = enum.auto() + + +class Transport: + + def __init__( + self, + *, + model_type, + path_type, + loss_type, + train_eps, + sample_eps, + ): + path_options = { + PathType.LINEAR: path.ICPlan, + PathType.GVP: path.GVPCPlan, + PathType.VP: path.VPCPlan, + } + + self.loss_type = loss_type + self.model_type = model_type + self.path_sampler = path_options[path_type]() + self.train_eps = train_eps + self.sample_eps = sample_eps + + def prior_logp(self, z): + ''' + Standard multivariate normal prior + Assume z is batched + ''' + shape = th.tensor(z.size()) + N = th.prod(shape[1:]) + _fn = lambda x: -N / 2. * np.log(2 * np.pi) - th.sum(x ** 2) / 2. + return th.vmap(_fn)(z) + + + def check_interval( + self, + train_eps, + sample_eps, + *, + diffusion_form="SBDM", + sde=False, + reverse=False, + eval=False, + last_step_size=0.0, + ): + t0 = 0 + t1 = 1 + eps = train_eps if not eval else sample_eps + if (type(self.path_sampler) in [path.VPCPlan]): + + t1 = 1 - eps if (not sde or last_step_size == 0) else 1 - last_step_size + + elif (type(self.path_sampler) in [path.ICPlan, path.GVPCPlan]) \ + and (self.model_type != ModelType.VELOCITY or sde): # avoid numerical issue by taking a first semi-implicit step + + t0 = eps if (diffusion_form == "SBDM" and sde) or self.model_type != ModelType.VELOCITY else 0 + t1 = 1 - eps if (not sde or last_step_size == 0) else 1 - last_step_size + + if reverse: + t0, t1 = 1 - t0, 1 - t1 + + return t0, t1 + + + def sample(self, x1): + """Sampling x0 & t based on shape of x1 (if needed) + Args: + x1 - data point; [batch, *dim] + """ + + x0 = th.randn_like(x1) + t0, t1 = self.check_interval(self.train_eps, self.sample_eps) + t = th.rand((x1.shape[0],)) * (t1 - t0) + t0 + t = t.to(x1) + return t, x0, x1 + + + def training_losses( + self, + model, + x1, + model_kwargs=None + ): + """Loss for training the score model + Args: + - model: backbone model; could be score, noise, or velocity + - x1: datapoint + - model_kwargs: additional arguments for the model + """ + if model_kwargs == None: + model_kwargs = {} + + t, x0, x1 = self.sample(x1) + t, xt, ut = self.path_sampler.plan(t, x0, x1) + model_output = model(xt, t, **model_kwargs) + B, *_, C = xt.shape + assert model_output.size() == (B, *xt.size()[1:-1], C) + + terms = {} + terms['pred'] = model_output + if self.model_type == ModelType.VELOCITY: + terms['loss'] = mean_flat(((model_output - ut) ** 2)) + else: + _, drift_var = self.path_sampler.compute_drift(xt, t) + sigma_t, _ = self.path_sampler.compute_sigma_t(path.expand_t_like_x(t, xt)) + if self.loss_type in [WeightType.VELOCITY]: + weight = (drift_var / sigma_t) ** 2 + elif self.loss_type in [WeightType.LIKELIHOOD]: + weight = drift_var / (sigma_t ** 2) + elif self.loss_type in [WeightType.NONE]: + weight = 1 + else: + raise NotImplementedError() + + if self.model_type == ModelType.NOISE: + terms['loss'] = mean_flat(weight * ((model_output - x0) ** 2)) + else: + terms['loss'] = mean_flat(weight * ((model_output * sigma_t + x0) ** 2)) + + return terms + + + def get_drift( + self + ): + """member function for obtaining the drift of the probability flow ODE""" + def score_ode(x, t, model, **model_kwargs): + drift_mean, drift_var = self.path_sampler.compute_drift(x, t) + model_output = model(x, t, **model_kwargs) + return (-drift_mean + drift_var * model_output) # by change of variable + + def noise_ode(x, t, model, **model_kwargs): + drift_mean, drift_var = self.path_sampler.compute_drift(x, t) + sigma_t, _ = self.path_sampler.compute_sigma_t(path.expand_t_like_x(t, x)) + model_output = model(x, t, **model_kwargs) + score = model_output / -sigma_t + return (-drift_mean + drift_var * score) + + def velocity_ode(x, t, model, **model_kwargs): + model_output = model(x, t, **model_kwargs) + return model_output + + if self.model_type == ModelType.NOISE: + drift_fn = noise_ode + elif self.model_type == ModelType.SCORE: + drift_fn = score_ode + else: + drift_fn = velocity_ode + + def body_fn(x, t, model, **model_kwargs): + model_output = drift_fn(x, t, model, **model_kwargs) + assert model_output.shape == x.shape, "Output shape from ODE solver must match input shape" + return model_output + + return body_fn + + + def get_score( + self, + ): + """member function for obtaining score of + x_t = alpha_t * x + sigma_t * eps""" + if self.model_type == ModelType.NOISE: + score_fn = lambda x, t, model, **kwargs: model(x, t, **kwargs) / -self.path_sampler.compute_sigma_t(path.expand_t_like_x(t, x))[0] + elif self.model_type == ModelType.SCORE: + score_fn = lambda x, t, model, **kwagrs: model(x, t, **kwagrs) + elif self.model_type == ModelType.VELOCITY: + score_fn = lambda x, t, model, **kwargs: self.path_sampler.get_score_from_velocity(model(x, t, **kwargs), x, t) + else: + raise NotImplementedError() + + return score_fn + + +class Sampler: + """Sampler class for the transport model""" + def __init__( + self, + transport, + ): + """Constructor for a general sampler; supporting different sampling methods + Args: + - transport: an tranport object specify model prediction & interpolant type + """ + + self.transport = transport + self.drift = self.transport.get_drift() + self.score = self.transport.get_score() + + def __get_sde_diffusion_and_drift( + self, + *, + diffusion_form="SBDM", + diffusion_norm=1.0, + ): + + def diffusion_fn(x, t): + diffusion = self.transport.path_sampler.compute_diffusion(x, t, form=diffusion_form, norm=diffusion_norm) + return diffusion + + sde_drift = \ + lambda x, t, model, **kwargs: \ + self.drift(x, t, model, **kwargs) + diffusion_fn(x, t) * self.score(x, t, model, **kwargs) + + sde_diffusion = diffusion_fn + + return sde_drift, sde_diffusion + + def __get_last_step( + self, + sde_drift, + *, + last_step, + last_step_size, + ): + """Get the last step function of the SDE solver""" + + if last_step is None: + last_step_fn = \ + lambda x, t, model, **model_kwargs: \ + x + elif last_step == "Mean": + last_step_fn = \ + lambda x, t, model, **model_kwargs: \ + x + sde_drift(x, t, model, **model_kwargs) * last_step_size + elif last_step == "Tweedie": + alpha = self.transport.path_sampler.compute_alpha_t # simple aliasing; the original name was too long + sigma = self.transport.path_sampler.compute_sigma_t + last_step_fn = \ + lambda x, t, model, **model_kwargs: \ + x / alpha(t)[0][0] + (sigma(t)[0][0] ** 2) / alpha(t)[0][0] * self.score(x, t, model, **model_kwargs) + elif last_step == "Euler": + last_step_fn = \ + lambda x, t, model, **model_kwargs: \ + x + self.drift(x, t, model, **model_kwargs) * last_step_size + else: + raise NotImplementedError() + + return last_step_fn + + def sample_sde( + self, + *, + sampling_method="Euler", + diffusion_form="SBDM", + diffusion_norm=1.0, + last_step="Mean", + last_step_size=0.04, + num_steps=250, + ): + """returns a sampling function with given SDE settings + Args: + - sampling_method: type of sampler used in solving the SDE; default to be Euler-Maruyama + - diffusion_form: function form of diffusion coefficient; default to be matching SBDM + - diffusion_norm: function magnitude of diffusion coefficient; default to 1 + - last_step: type of the last step; default to identity + - last_step_size: size of the last step; default to match the stride of 250 steps over [0,1] + - num_steps: total integration step of SDE + """ + + if last_step is None: + last_step_size = 0.0 + + sde_drift, sde_diffusion = self.__get_sde_diffusion_and_drift( + diffusion_form=diffusion_form, + diffusion_norm=diffusion_norm, + ) + + t0, t1 = self.transport.check_interval( + self.transport.train_eps, + self.transport.sample_eps, + diffusion_form=diffusion_form, + sde=True, + eval=True, + reverse=False, + last_step_size=last_step_size, + ) + + _sde = sde( + sde_drift, + sde_diffusion, + t0=t0, + t1=t1, + num_steps=num_steps, + sampler_type=sampling_method + ) + + last_step_fn = self.__get_last_step(sde_drift, last_step=last_step, last_step_size=last_step_size) + + + def _sample(init, model, **model_kwargs): + xs = _sde.sample(init, model, **model_kwargs) + ts = th.ones(init.size(0), device=init.device) * t1 + x = last_step_fn(xs[-1], ts, model, **model_kwargs) + xs.append(x) + + assert len(xs) == num_steps, "Samples does not match the number of steps" + + return xs + + return _sample + + def sample_ode( + self, + *, + sampling_method="dopri5", + num_steps=50, + atol=1e-6, + rtol=1e-3, + reverse=False, + ): + """returns a sampling function with given ODE settings + Args: + - sampling_method: type of sampler used in solving the ODE; default to be Dopri5 + - num_steps: + - fixed solver (Euler, Heun): the actual number of integration steps performed + - adaptive solver (Dopri5): the number of datapoints saved during integration; produced by interpolation + - atol: absolute error tolerance for the solver + - rtol: relative error tolerance for the solver + - reverse: whether solving the ODE in reverse (data to noise); default to False + """ + if reverse: + drift = lambda x, t, model, **kwargs: self.drift(x, th.ones_like(t) * (1 - t), model, **kwargs) + else: + drift = self.drift + + t0, t1 = self.transport.check_interval( + self.transport.train_eps, + self.transport.sample_eps, + sde=False, + eval=True, + reverse=reverse, + last_step_size=0.0, + ) + + _ode = ode( + drift=drift, + t0=t0, + t1=t1, + sampler_type=sampling_method, + num_steps=num_steps, + atol=atol, + rtol=rtol, + ) + + return _ode.sample + + def sample_ode_likelihood( + self, + *, + sampling_method="dopri5", + num_steps=50, + atol=1e-6, + rtol=1e-3, + ): + + """returns a sampling function for calculating likelihood with given ODE settings + Args: + - sampling_method: type of sampler used in solving the ODE; default to be Dopri5 + - num_steps: + - fixed solver (Euler, Heun): the actual number of integration steps performed + - adaptive solver (Dopri5): the number of datapoints saved during integration; produced by interpolation + - atol: absolute error tolerance for the solver + - rtol: relative error tolerance for the solver + """ + def _likelihood_drift(x, t, model, **model_kwargs): + x, _ = x + eps = th.randint(2, x.size(), dtype=th.float, device=x.device) * 2 - 1 + t = th.ones_like(t) * (1 - t) + with th.enable_grad(): + x.requires_grad = True + grad = th.autograd.grad(th.sum(self.drift(x, t, model, **model_kwargs) * eps), x)[0] + logp_grad = th.sum(grad * eps, dim=tuple(range(1, len(x.size())))) + drift = self.drift(x, t, model, **model_kwargs) + return (-drift, logp_grad) + + t0, t1 = self.transport.check_interval( + self.transport.train_eps, + self.transport.sample_eps, + sde=False, + eval=True, + reverse=False, + last_step_size=0.0, + ) + + _ode = ode( + drift=_likelihood_drift, + t0=t0, + t1=t1, + sampler_type=sampling_method, + num_steps=num_steps, + atol=atol, + rtol=rtol, + ) + + def _sample_fn(x, model, **model_kwargs): + init_logp = th.zeros(x.size(0)).to(x) + input = (x, init_logp) + drift, delta_logp = _ode.sample(input, model, **model_kwargs) + drift, delta_logp = drift[-1], delta_logp[-1] + prior_logp = self.transport.prior_logp(drift) + logp = prior_logp - delta_logp + return logp, drift + + return _sample_fn \ No newline at end of file diff --git a/opensora/models/diffusion/transport/utils.py b/opensora/models/diffusion/transport/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..44646035531326b81883727f973900edb4eac494 --- /dev/null +++ b/opensora/models/diffusion/transport/utils.py @@ -0,0 +1,29 @@ +import torch as th + +class EasyDict: + + def __init__(self, sub_dict): + for k, v in sub_dict.items(): + setattr(self, k, v) + + def __getitem__(self, key): + return getattr(self, key) + +def mean_flat(x): + """ + Take the mean over all non-batch dimensions. + """ + return th.mean(x, dim=list(range(1, len(x.size())))) + +def log_state(state): + result = [] + + sorted_state = dict(sorted(state.items())) + for key, value in sorted_state.items(): + # Check if the value is an instance of a class + if " + +// forward declaration +void rope_2d_cuda( torch::Tensor tokens, const torch::Tensor pos, const float base, const float fwd ); + +void rope_2d_cpu( torch::Tensor tokens, const torch::Tensor positions, const float base, const float fwd ) +{ + const int B = tokens.size(0); + const int N = tokens.size(1); + const int H = tokens.size(2); + const int D = tokens.size(3) / 4; + + auto tok = tokens.accessor(); + auto pos = positions.accessor(); + + for (int b = 0; b < B; b++) { + for (int x = 0; x < 2; x++) { // y and then x (2d) + for (int n = 0; n < N; n++) { + + // grab the token position + const int p = pos[b][n][x]; + + for (int h = 0; h < H; h++) { + for (int d = 0; d < D; d++) { + // grab the two values + float u = tok[b][n][h][d+0+x*2*D]; + float v = tok[b][n][h][d+D+x*2*D]; + + // grab the cos,sin + const float inv_freq = fwd * p / powf(base, d/float(D)); + float c = cosf(inv_freq); + float s = sinf(inv_freq); + + // write the result + tok[b][n][h][d+0+x*2*D] = u*c - v*s; + tok[b][n][h][d+D+x*2*D] = v*c + u*s; + } + } + } + } + } +} + +void rope_2d( torch::Tensor tokens, // B,N,H,D + const torch::Tensor positions, // B,N,2 + const float base, + const float fwd ) +{ + TORCH_CHECK(tokens.dim() == 4, "tokens must have 4 dimensions"); + TORCH_CHECK(positions.dim() == 3, "positions must have 3 dimensions"); + TORCH_CHECK(tokens.size(0) == positions.size(0), "batch size differs between tokens & positions"); + TORCH_CHECK(tokens.size(1) == positions.size(1), "seq_length differs between tokens & positions"); + TORCH_CHECK(positions.size(2) == 2, "positions.shape[2] must be equal to 2"); + TORCH_CHECK(tokens.is_cuda() == positions.is_cuda(), "tokens and positions are not on the same device" ); + + if (tokens.is_cuda()) + rope_2d_cuda( tokens, positions, base, fwd ); + else + rope_2d_cpu( tokens, positions, base, fwd ); +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("rope_2d", &rope_2d, "RoPE 2d forward/backward"); +} diff --git a/opensora/models/diffusion/utils/curope/curope2d.py b/opensora/models/diffusion/utils/curope/curope2d.py new file mode 100644 index 0000000000000000000000000000000000000000..a49c12f8c529e9a889b5ac20c5767158f238e17d --- /dev/null +++ b/opensora/models/diffusion/utils/curope/curope2d.py @@ -0,0 +1,40 @@ +# Copyright (C) 2022-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). + +import torch + +try: + import curope as _kernels # run `python setup.py install` +except ModuleNotFoundError: + from . import curope as _kernels # run `python setup.py build_ext --inplace` + + +class cuRoPE2D_func (torch.autograd.Function): + + @staticmethod + def forward(ctx, tokens, positions, base, F0=1): + ctx.save_for_backward(positions) + ctx.saved_base = base + ctx.saved_F0 = F0 + # tokens = tokens.clone() # uncomment this if inplace doesn't work + _kernels.rope_2d( tokens, positions, base, F0 ) + ctx.mark_dirty(tokens) + return tokens + + @staticmethod + def backward(ctx, grad_res): + positions, base, F0 = ctx.saved_tensors[0], ctx.saved_base, ctx.saved_F0 + _kernels.rope_2d( grad_res, positions, base, -F0 ) + ctx.mark_dirty(grad_res) + return grad_res, None, None, None + + +class cuRoPE2D(torch.nn.Module): + def __init__(self, freq=100.0, F0=1.0): + super().__init__() + self.base = freq + self.F0 = F0 + + def forward(self, tokens, positions): + cuRoPE2D_func.apply( tokens.transpose(1,2), positions, self.base, self.F0 ) + return tokens \ No newline at end of file diff --git a/opensora/models/diffusion/utils/curope/kernels.cu b/opensora/models/diffusion/utils/curope/kernels.cu new file mode 100644 index 0000000000000000000000000000000000000000..7156cd1bb935cb1f0be45e58add53f9c21505c20 --- /dev/null +++ b/opensora/models/diffusion/utils/curope/kernels.cu @@ -0,0 +1,108 @@ +/* + Copyright (C) 2022-present Naver Corporation. All rights reserved. + Licensed under CC BY-NC-SA 4.0 (non-commercial use only). +*/ + +#include +#include +#include +#include + +#define CHECK_CUDA(tensor) {\ + TORCH_CHECK((tensor).is_cuda(), #tensor " is not in cuda memory"); \ + TORCH_CHECK((tensor).is_contiguous(), #tensor " is not contiguous"); } +void CHECK_KERNEL() {auto error = cudaGetLastError(); TORCH_CHECK( error == cudaSuccess, cudaGetErrorString(error));} + + +template < typename scalar_t > +__global__ void rope_2d_cuda_kernel( + //scalar_t* __restrict__ tokens, + torch::PackedTensorAccessor32 tokens, + const int64_t* __restrict__ pos, + const float base, + const float fwd ) + // const int N, const int H, const int D ) +{ + // tokens shape = (B, N, H, D) + const int N = tokens.size(1); + const int H = tokens.size(2); + const int D = tokens.size(3); + + // each block update a single token, for all heads + // each thread takes care of a single output + extern __shared__ float shared[]; + float* shared_inv_freq = shared + D; + + const int b = blockIdx.x / N; + const int n = blockIdx.x % N; + + const int Q = D / 4; + // one token = [0..Q : Q..2Q : 2Q..3Q : 3Q..D] + // u_Y v_Y u_X v_X + + // shared memory: first, compute inv_freq + if (threadIdx.x < Q) + shared_inv_freq[threadIdx.x] = fwd / powf(base, threadIdx.x/float(Q)); + __syncthreads(); + + // start of X or Y part + const int X = threadIdx.x < D/2 ? 0 : 1; + const int m = (X*D/2) + (threadIdx.x % Q); // index of u_Y or u_X + + // grab the cos,sin appropriate for me + const float freq = pos[blockIdx.x*2+X] * shared_inv_freq[threadIdx.x % Q]; + const float cos = cosf(freq); + const float sin = sinf(freq); + /* + float* shared_cos_sin = shared + D + D/4; + if ((threadIdx.x % (D/2)) < Q) + shared_cos_sin[m+0] = cosf(freq); + else + shared_cos_sin[m+Q] = sinf(freq); + __syncthreads(); + const float cos = shared_cos_sin[m+0]; + const float sin = shared_cos_sin[m+Q]; + */ + + for (int h = 0; h < H; h++) + { + // then, load all the token for this head in shared memory + shared[threadIdx.x] = tokens[b][n][h][threadIdx.x]; + __syncthreads(); + + const float u = shared[m]; + const float v = shared[m+Q]; + + // write output + if ((threadIdx.x % (D/2)) < Q) + tokens[b][n][h][threadIdx.x] = u*cos - v*sin; + else + tokens[b][n][h][threadIdx.x] = v*cos + u*sin; + } +} + +void rope_2d_cuda( torch::Tensor tokens, const torch::Tensor pos, const float base, const float fwd ) +{ + const int B = tokens.size(0); // batch size + const int N = tokens.size(1); // sequence length + const int H = tokens.size(2); // number of heads + const int D = tokens.size(3); // dimension per head + + TORCH_CHECK(tokens.stride(3) == 1 && tokens.stride(2) == D, "tokens are not contiguous"); + TORCH_CHECK(pos.is_contiguous(), "positions are not contiguous"); + TORCH_CHECK(pos.size(0) == B && pos.size(1) == N && pos.size(2) == 2, "bad pos.shape"); + TORCH_CHECK(D % 4 == 0, "token dim must be multiple of 4"); + + // one block for each layer, one thread per local-max + const int THREADS_PER_BLOCK = D; + const int N_BLOCKS = B * N; // each block takes care of H*D values + const int SHARED_MEM = sizeof(float) * (D + D/4); + + AT_DISPATCH_FLOATING_TYPES_AND_HALF(tokens.type(), "rope_2d_cuda", ([&] { + rope_2d_cuda_kernel <<>> ( + //tokens.data_ptr(), + tokens.packed_accessor32(), + pos.data_ptr(), + base, fwd); //, N, H, D ); + })); +} diff --git a/opensora/models/diffusion/utils/curope/setup.py b/opensora/models/diffusion/utils/curope/setup.py new file mode 100644 index 0000000000000000000000000000000000000000..230632ed05e309200e8f93a3a852072333975009 --- /dev/null +++ b/opensora/models/diffusion/utils/curope/setup.py @@ -0,0 +1,34 @@ +# Copyright (C) 2022-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). + +from setuptools import setup +from torch import cuda +from torch.utils.cpp_extension import BuildExtension, CUDAExtension + +# compile for all possible CUDA architectures +all_cuda_archs = cuda.get_gencode_flags().replace('compute=','arch=').split() +# alternatively, you can list cuda archs that you want, eg: +# all_cuda_archs = [ + # '-gencode', 'arch=compute_70,code=sm_70', + # '-gencode', 'arch=compute_75,code=sm_75', + # '-gencode', 'arch=compute_80,code=sm_80', + # '-gencode', 'arch=compute_86,code=sm_86' +# ] + +setup( + name = 'curope', + ext_modules = [ + CUDAExtension( + name='curope', + sources=[ + "curope.cpp", + "kernels.cu", + ], + extra_compile_args = dict( + nvcc=['-O3','--ptxas-options=-v',"--use_fast_math"]+all_cuda_archs, + cxx=['-O3']) + ) + ], + cmdclass = { + 'build_ext': BuildExtension + }) diff --git a/opensora/models/diffusion/utils/pos_embed.py b/opensora/models/diffusion/utils/pos_embed.py new file mode 100644 index 0000000000000000000000000000000000000000..b7dac72746517005da9131bb83ebcbbfc43f40db --- /dev/null +++ b/opensora/models/diffusion/utils/pos_embed.py @@ -0,0 +1,135 @@ +# Copyright (C) 2022-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). + +# croco: https://github.com/naver/croco +# diffusers: https://github.com/huggingface/diffusers +# -------------------------------------------------------- +# Position embedding utils +# -------------------------------------------------------- + +import numpy as np +import torch + + +def get_2d_sincos_pos_embed( + embed_dim, grid_size, cls_token=False, extra_tokens=0, interpolation_scale=1.0, base_size=16 +): + """ + grid_size: int of the grid height and width return: pos_embed: [grid_size*grid_size, embed_dim] or + [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) + """ + if isinstance(grid_size, int): + grid_size = (grid_size, grid_size) + + grid_h = np.arange(grid_size[0], dtype=np.float32) / (grid_size[0] / base_size) / interpolation_scale + grid_w = np.arange(grid_size[1], dtype=np.float32) / (grid_size[1] / base_size) / interpolation_scale + grid = np.meshgrid(grid_w, grid_h) # here w goes first + grid = np.stack(grid, axis=0) + + grid = grid.reshape([2, 1, grid_size[1], grid_size[0]]) + pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) + if cls_token and extra_tokens > 0: + pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0) + return pos_embed + + +def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): + if embed_dim % 2 != 0: + raise ValueError("embed_dim must be divisible by 2") + + # use half of dimensions to encode grid_h + emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) + emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) + + emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) + return emb + + +def get_1d_sincos_pos_embed( + embed_dim, length, interpolation_scale=1.0, base_size=16 +): + pos = torch.arange(0, length).unsqueeze(1) / interpolation_scale + pos_embed = get_1d_sincos_pos_embed_from_grid(embed_dim, pos) + return pos_embed + +def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): + """ + embed_dim: output dimension for each position pos: a list of positions to be encoded: size (M,) out: (M, D) + """ + if embed_dim % 2 != 0: + raise ValueError("embed_dim must be divisible by 2") + + omega = np.arange(embed_dim // 2, dtype=np.float64) + omega /= embed_dim / 2.0 + omega = 1.0 / 10000 ** omega # (D/2,) + + pos = pos.reshape(-1) # (M,) + out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product + + emb_sin = np.sin(out) # (M, D/2) + emb_cos = np.cos(out) # (M, D/2) + + emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) + return emb + + +# ---------------------------------------------------------- +# RoPE2D: RoPE implementation in 2D +# ---------------------------------------------------------- + +try: + from .curope import cuRoPE2D + + RoPE2D = cuRoPE2D +except ImportError: + print('Warning, cannot find cuda-compiled version of RoPE2D, using a slow pytorch version instead') + + + class RoPE2D(torch.nn.Module): + + def __init__(self, freq=100.0, F0=1.0): + super().__init__() + self.base = freq + self.F0 = F0 + self.cache = {} + + def get_cos_sin(self, D, seq_len, device, dtype): + if (D, seq_len, device, dtype) not in self.cache: + inv_freq = 1.0 / (self.base ** (torch.arange(0, D, 2).float().to(device) / D)) + t = torch.arange(seq_len, device=device, dtype=inv_freq.dtype) + freqs = torch.einsum("i,j->ij", t, inv_freq).to(dtype) + freqs = torch.cat((freqs, freqs), dim=-1) + cos = freqs.cos() # (Seq, Dim) + sin = freqs.sin() + self.cache[D, seq_len, device, dtype] = (cos, sin) + return self.cache[D, seq_len, device, dtype] + + @staticmethod + def rotate_half(x): + x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2:] + return torch.cat((-x2, x1), dim=-1) + + def apply_rope1d(self, tokens, pos1d, cos, sin): + assert pos1d.ndim == 2 + cos = torch.nn.functional.embedding(pos1d, cos)[:, None, :, :] + sin = torch.nn.functional.embedding(pos1d, sin)[:, None, :, :] + return (tokens * cos) + (self.rotate_half(tokens) * sin) + + def forward(self, tokens, positions): + """ + input: + * tokens: batch_size x nheads x ntokens x dim + * positions: batch_size x ntokens x 2 (y and x position of each token) + output: + * tokens after appplying RoPE2D (batch_size x nheads x ntokens x dim) + """ + assert tokens.size(3) % 2 == 0, "number of dimensions should be a multiple of two" + D = tokens.size(3) // 2 + assert positions.ndim == 3 and positions.shape[-1] == 2 # Batch, Seq, 2 + cos, sin = self.get_cos_sin(D, int(positions.max()) + 1, tokens.device, tokens.dtype) + # split features into two along the feature dimension, and apply rope1d on each half + y, x = tokens.chunk(2, dim=-1) + y = self.apply_rope1d(y, positions[:, :, 0], cos, sin) + x = self.apply_rope1d(x, positions[:, :, 1], cos, sin) + tokens = torch.cat((y, x), dim=-1) + return tokens diff --git a/opensora/models/frame_interpolation/cfgs/AMT-G.yaml b/opensora/models/frame_interpolation/cfgs/AMT-G.yaml new file mode 100644 index 0000000000000000000000000000000000000000..d259d4fe97336d2a3e9c9e34a0067ca5ddfae1f0 --- /dev/null +++ b/opensora/models/frame_interpolation/cfgs/AMT-G.yaml @@ -0,0 +1,9 @@ + +seed: 2023 + +network: + name: networks.AMT-G.Model + params: + corr_radius: 3 + corr_lvls: 4 + num_flows: 5 \ No newline at end of file diff --git a/opensora/models/frame_interpolation/interpolation.py b/opensora/models/frame_interpolation/interpolation.py new file mode 100644 index 0000000000000000000000000000000000000000..0c60bed8a4dc7747448156a18d93efacaaa837e7 --- /dev/null +++ b/opensora/models/frame_interpolation/interpolation.py @@ -0,0 +1,197 @@ +# this script is modified from https://github.com/MCG-NKU/AMT/blob/main/demos/demo_2x.py +from json import load +import os +import cv2 +import sys +import glob +import torch +import argparse +import numpy as np +import os.path as osp +from warnings import warn +from omegaconf import OmegaConf +from torchvision.utils import make_grid +sys.path.append('.') +from utils.utils import ( + read, write, + img2tensor, tensor2img, + check_dim_and_resize + ) +from utils.build_utils import build_from_cfg +from utils.utils import InputPadder + + +AMT_G = { + 'name': 'networks.AMT-G.Model', + 'params':{ + 'corr_radius': 3, + 'corr_lvls': 4, + 'num_flows': 5, + } +} + + + +def init(device="cuda"): + + ''' + initialize the device and the anchor resolution. + ''' + + if device == 'cuda': + anchor_resolution = 1024 * 512 + anchor_memory = 1500 * 1024**2 + anchor_memory_bias = 2500 * 1024**2 + vram_avail = torch.cuda.get_device_properties(device).total_memory + print("VRAM available: {:.1f} MB".format(vram_avail / 1024 ** 2)) + else: + # Do not resize in cpu mode + anchor_resolution = 8192*8192 + anchor_memory = 1 + anchor_memory_bias = 0 + vram_avail = 1 + + return anchor_resolution, anchor_memory, anchor_memory_bias, vram_avail + +def get_input_video_from_path(input_path, device="cuda"): + + ''' + Get the input video from the input_path. + + params: + input_path: str, the path of the input video. + devices: str, the device to run the model. + returns: + inputs: list, the list of the input frames. + scale: float, the scale of the input frames. + padder: InputPadder, the padder to pad the input frames. + ''' + + anchor_resolution, anchor_memory, anchor_memory_bias, vram_avail = init(device) + + if osp.splitext(input_path)[-1] in ['.mp4', '.avi', '.mov', '.mkv', '.flv', '.wmv', + '.webm', '.MP4', '.AVI', '.MOV', '.MKV', '.FLV', + '.WMV', '.WEBM']: + + vcap = cv2.VideoCapture(input_path) + + inputs = [] + w = int(vcap.get(cv2.CAP_PROP_FRAME_WIDTH)) + h = int(vcap.get(cv2.CAP_PROP_FRAME_HEIGHT)) + scale = anchor_resolution / (h * w) * np.sqrt((vram_avail - anchor_memory_bias) / anchor_memory) + scale = 1 if scale > 1 else scale + scale = 1 / np.floor(1 / np.sqrt(scale) * 16) * 16 + if scale < 1: + print(f"Due to the limited VRAM, the video will be scaled by {scale:.2f}") + padding = int(16 / scale) + padder = InputPadder((h, w), padding) + while True: + ret, frame = vcap.read() + if ret is False: + break + frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) + frame_t = img2tensor(frame).to(device) + frame_t = padder.pad(frame_t) + inputs.append(frame_t) + print(f'Loading the [video] from {input_path}, the number of frames [{len(inputs)}]') + else: + raise TypeError("Input should be a video.") + + return inputs, scale, padder + + +def load_model(ckpt_path, device="cuda"): + + ''' + load the frame interpolation model. + ''' + network_cfg = AMT_G + network_name = network_cfg['name'] + print(f'Loading [{network_name}] from [{ckpt_path}]...') + model = build_from_cfg(network_cfg) + ckpt = torch.load(ckpt_path) + model.load_state_dict(ckpt['state_dict']) + model = model.to(device) + model.eval() + return model + +def interpolater(model, inputs, scale, padder, iters=1): + + ''' + interpolating with the interpolation model. + + params: + model: nn.Module, the frame interpolation model. + inputs: list, the list of the input frames. + scale: float, the scale of the input frames. + iters: int, the number of iterations of interpolation. The final frames model generating is 2 ** iters * (m - 1) + 1 and m is input frames. + returns: + outputs: list, the list of the output frames. + ''' + + print(f'Start frame interpolation:') + embt = torch.tensor(1/2).float().view(1, 1, 1, 1).to(device) + + for i in range(iters): + print(f'Iter {i+1}. input_frames={len(inputs)} output_frames={2*len(inputs)-1}') + outputs = [inputs[0]] + for in_0, in_1 in zip(inputs[:-1], inputs[1:]): + in_0 = in_0.to(device) + in_1 = in_1.to(device) + with torch.no_grad(): + imgt_pred = model(in_0, in_1, embt, scale_factor=scale, eval=True)['imgt_pred'] + outputs += [imgt_pred.cpu(), in_1.cpu()] + inputs = outputs + + outputs = padder.unpad(*outputs) + + return outputs + +def write(outputs, input_path, output_path, frame_rate=30): + ''' + write results to the output_path. + ''' + + if osp.exists(output_path) is False: + os.makedirs(output_path) + + + size = outputs[0].shape[2:][::-1] + + _, file_name_with_extension = os.path.split(input_path) + file_name, _ = os.path.splitext(file_name_with_extension) + + save_video_path = f'{output_path}/output_{file_name}.mp4' + writer = cv2.VideoWriter(save_video_path, cv2.VideoWriter_fourcc(*"mp4v"), + frame_rate, size) + + for i, imgt_pred in enumerate(outputs): + imgt_pred = tensor2img(imgt_pred) + imgt_pred = cv2.cvtColor(imgt_pred, cv2.COLOR_RGB2BGR) + writer.write(imgt_pred) + print(f"Demo video is saved to [{save_video_path}]") + + writer.release() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument('--ckpt', type=str, default='amt-g.pth', help="The pretrained model.") + parser.add_argument('--niters', type=int, default=1, help="Iter of Interpolation. The number of frames will be double after per iter.") + parser.add_argument('--input', default="test.mp4", help="Input video.") + parser.add_argument('--output_path', type=str, default='results', help="Output path.") + parser.add_argument('--frame_rate', type=int, default=30, help="Frames rate of the output video.") + + args = parser.parse_args() + + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + ckpt_path = args.ckpt + input_path = args.input + output_path = args.output_path + iters = int(args.niters) + frame_rate = int(args.frame_rate) + + inputs, scale, padder = get_input_video_from_path(input_path, device) + model = load_model(ckpt_path, device) + outputs = interpolater(model, inputs, scale, padder, iters) + write(outputs, input_path, output_path, frame_rate) diff --git a/opensora/models/frame_interpolation/networks/AMT-G.py b/opensora/models/frame_interpolation/networks/AMT-G.py new file mode 100644 index 0000000000000000000000000000000000000000..a24cb1a3704984418788bb1f8f0e9946c87886e3 --- /dev/null +++ b/opensora/models/frame_interpolation/networks/AMT-G.py @@ -0,0 +1,172 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from networks.blocks.raft import ( + coords_grid, + BasicUpdateBlock, BidirCorrBlock +) +from networks.blocks.feat_enc import ( + LargeEncoder +) +from networks.blocks.ifrnet import ( + resize, + Encoder, + InitDecoder, + IntermediateDecoder +) +from networks.blocks.multi_flow import ( + multi_flow_combine, + MultiFlowDecoder +) + + +class Model(nn.Module): + def __init__(self, + corr_radius=3, + corr_lvls=4, + num_flows=5, + channels=[84, 96, 112, 128], + skip_channels=84): + super(Model, self).__init__() + self.radius = corr_radius + self.corr_levels = corr_lvls + self.num_flows = num_flows + + self.feat_encoder = LargeEncoder(output_dim=128, norm_fn='instance', dropout=0.) + self.encoder = Encoder(channels, large=True) + self.decoder4 = InitDecoder(channels[3], channels[2], skip_channels) + self.decoder3 = IntermediateDecoder(channels[2], channels[1], skip_channels) + self.decoder2 = IntermediateDecoder(channels[1], channels[0], skip_channels) + self.decoder1 = MultiFlowDecoder(channels[0], skip_channels, num_flows) + + self.update4 = self._get_updateblock(112, None) + self.update3_low = self._get_updateblock(96, 2.0) + self.update2_low = self._get_updateblock(84, 4.0) + + self.update3_high = self._get_updateblock(96, None) + self.update2_high = self._get_updateblock(84, None) + + self.comb_block = nn.Sequential( + nn.Conv2d(3*self.num_flows, 6*self.num_flows, 7, 1, 3), + nn.PReLU(6*self.num_flows), + nn.Conv2d(6*self.num_flows, 3, 7, 1, 3), + ) + + def _get_updateblock(self, cdim, scale_factor=None): + return BasicUpdateBlock(cdim=cdim, hidden_dim=192, flow_dim=64, + corr_dim=256, corr_dim2=192, fc_dim=188, + scale_factor=scale_factor, corr_levels=self.corr_levels, + radius=self.radius) + + def _corr_scale_lookup(self, corr_fn, coord, flow0, flow1, embt, downsample=1): + # convert t -> 0 to 0 -> 1 | convert t -> 1 to 1 -> 0 + # based on linear assumption + t1_scale = 1. / embt + t0_scale = 1. / (1. - embt) + if downsample != 1: + inv = 1 / downsample + flow0 = inv * resize(flow0, scale_factor=inv) + flow1 = inv * resize(flow1, scale_factor=inv) + + corr0, corr1 = corr_fn(coord + flow1 * t1_scale, coord + flow0 * t0_scale) + corr = torch.cat([corr0, corr1], dim=1) + flow = torch.cat([flow0, flow1], dim=1) + return corr, flow + + def forward(self, img0, img1, embt, scale_factor=1.0, eval=False, **kwargs): + mean_ = torch.cat([img0, img1], 2).mean(1, keepdim=True).mean(2, keepdim=True).mean(3, keepdim=True) + img0 = img0 - mean_ + img1 = img1 - mean_ + img0_ = resize(img0, scale_factor) if scale_factor != 1.0 else img0 + img1_ = resize(img1, scale_factor) if scale_factor != 1.0 else img1 + b, _, h, w = img0_.shape + coord = coords_grid(b, h // 8, w // 8, img0.device) + + fmap0, fmap1 = self.feat_encoder([img0_, img1_]) # [1, 128, H//8, W//8] + corr_fn = BidirCorrBlock(fmap0, fmap1, radius=self.radius, num_levels=self.corr_levels) + + # f0_1: [1, c0, H//2, W//2] | f0_2: [1, c1, H//4, W//4] + # f0_3: [1, c2, H//8, W//8] | f0_4: [1, c3, H//16, W//16] + f0_1, f0_2, f0_3, f0_4 = self.encoder(img0_) + f1_1, f1_2, f1_3, f1_4 = self.encoder(img1_) + + ######################################### the 4th decoder ######################################### + up_flow0_4, up_flow1_4, ft_3_ = self.decoder4(f0_4, f1_4, embt) + corr_4, flow_4 = self._corr_scale_lookup(corr_fn, coord, + up_flow0_4, up_flow1_4, + embt, downsample=1) + + # residue update with lookup corr + delta_ft_3_, delta_flow_4 = self.update4(ft_3_, flow_4, corr_4) + delta_flow0_4, delta_flow1_4 = torch.chunk(delta_flow_4, 2, 1) + up_flow0_4 = up_flow0_4 + delta_flow0_4 + up_flow1_4 = up_flow1_4 + delta_flow1_4 + ft_3_ = ft_3_ + delta_ft_3_ + + ######################################### the 3rd decoder ######################################### + up_flow0_3, up_flow1_3, ft_2_ = self.decoder3(ft_3_, f0_3, f1_3, up_flow0_4, up_flow1_4) + corr_3, flow_3 = self._corr_scale_lookup(corr_fn, + coord, up_flow0_3, up_flow1_3, + embt, downsample=2) + + # residue update with lookup corr + delta_ft_2_, delta_flow_3 = self.update3_low(ft_2_, flow_3, corr_3) + delta_flow0_3, delta_flow1_3 = torch.chunk(delta_flow_3, 2, 1) + up_flow0_3 = up_flow0_3 + delta_flow0_3 + up_flow1_3 = up_flow1_3 + delta_flow1_3 + ft_2_ = ft_2_ + delta_ft_2_ + + # residue update with lookup corr (hr) + corr_3 = resize(corr_3, scale_factor=2.0) + up_flow_3 = torch.cat([up_flow0_3, up_flow1_3], dim=1) + delta_ft_2_, delta_up_flow_3 = self.update3_high(ft_2_, up_flow_3, corr_3) + ft_2_ += delta_ft_2_ + up_flow0_3 += delta_up_flow_3[:, 0:2] + up_flow1_3 += delta_up_flow_3[:, 2:4] + + ######################################### the 2nd decoder ######################################### + up_flow0_2, up_flow1_2, ft_1_ = self.decoder2(ft_2_, f0_2, f1_2, up_flow0_3, up_flow1_3) + corr_2, flow_2 = self._corr_scale_lookup(corr_fn, + coord, up_flow0_2, up_flow1_2, + embt, downsample=4) + + # residue update with lookup corr + delta_ft_1_, delta_flow_2 = self.update2_low(ft_1_, flow_2, corr_2) + delta_flow0_2, delta_flow1_2 = torch.chunk(delta_flow_2, 2, 1) + up_flow0_2 = up_flow0_2 + delta_flow0_2 + up_flow1_2 = up_flow1_2 + delta_flow1_2 + ft_1_ = ft_1_ + delta_ft_1_ + + # residue update with lookup corr (hr) + corr_2 = resize(corr_2, scale_factor=4.0) + up_flow_2 = torch.cat([up_flow0_2, up_flow1_2], dim=1) + delta_ft_1_, delta_up_flow_2 = self.update2_high(ft_1_, up_flow_2, corr_2) + ft_1_ += delta_ft_1_ + up_flow0_2 += delta_up_flow_2[:, 0:2] + up_flow1_2 += delta_up_flow_2[:, 2:4] + + ######################################### the 1st decoder ######################################### + up_flow0_1, up_flow1_1, mask, img_res = self.decoder1(ft_1_, f0_1, f1_1, up_flow0_2, up_flow1_2) + + if scale_factor != 1.0: + up_flow0_1 = resize(up_flow0_1, scale_factor=(1.0/scale_factor)) * (1.0/scale_factor) + up_flow1_1 = resize(up_flow1_1, scale_factor=(1.0/scale_factor)) * (1.0/scale_factor) + mask = resize(mask, scale_factor=(1.0/scale_factor)) + img_res = resize(img_res, scale_factor=(1.0/scale_factor)) + + # Merge multiple predictions + imgt_pred = multi_flow_combine(self.comb_block, img0, img1, up_flow0_1, up_flow1_1, + mask, img_res, mean_) + imgt_pred = torch.clamp(imgt_pred, 0, 1) + + if eval: + return { 'imgt_pred': imgt_pred, } + else: + up_flow0_1 = up_flow0_1.reshape(b, self.num_flows, 2, h, w) + up_flow1_1 = up_flow1_1.reshape(b, self.num_flows, 2, h, w) + return { + 'imgt_pred': imgt_pred, + 'flow0_pred': [up_flow0_1, up_flow0_2, up_flow0_3, up_flow0_4], + 'flow1_pred': [up_flow1_1, up_flow1_2, up_flow1_3, up_flow1_4], + 'ft_pred': [ft_1_, ft_2_, ft_3_], + } diff --git a/opensora/models/frame_interpolation/networks/__init__.py b/opensora/models/frame_interpolation/networks/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/opensora/models/frame_interpolation/networks/blocks/__init__.py b/opensora/models/frame_interpolation/networks/blocks/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/opensora/models/frame_interpolation/networks/blocks/feat_enc.py b/opensora/models/frame_interpolation/networks/blocks/feat_enc.py new file mode 100644 index 0000000000000000000000000000000000000000..3805bd315422703c19bf6a4d0962ee75002d92aa --- /dev/null +++ b/opensora/models/frame_interpolation/networks/blocks/feat_enc.py @@ -0,0 +1,343 @@ +import torch +import torch.nn as nn + + +class BottleneckBlock(nn.Module): + def __init__(self, in_planes, planes, norm_fn='group', stride=1): + super(BottleneckBlock, self).__init__() + + self.conv1 = nn.Conv2d(in_planes, planes//4, kernel_size=1, padding=0) + self.conv2 = nn.Conv2d(planes//4, planes//4, kernel_size=3, padding=1, stride=stride) + self.conv3 = nn.Conv2d(planes//4, planes, kernel_size=1, padding=0) + self.relu = nn.ReLU(inplace=True) + + num_groups = planes // 8 + + if norm_fn == 'group': + self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes//4) + self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes//4) + self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) + if not stride == 1: + self.norm4 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) + + elif norm_fn == 'batch': + self.norm1 = nn.BatchNorm2d(planes//4) + self.norm2 = nn.BatchNorm2d(planes//4) + self.norm3 = nn.BatchNorm2d(planes) + if not stride == 1: + self.norm4 = nn.BatchNorm2d(planes) + + elif norm_fn == 'instance': + self.norm1 = nn.InstanceNorm2d(planes//4) + self.norm2 = nn.InstanceNorm2d(planes//4) + self.norm3 = nn.InstanceNorm2d(planes) + if not stride == 1: + self.norm4 = nn.InstanceNorm2d(planes) + + elif norm_fn == 'none': + self.norm1 = nn.Sequential() + self.norm2 = nn.Sequential() + self.norm3 = nn.Sequential() + if not stride == 1: + self.norm4 = nn.Sequential() + + if stride == 1: + self.downsample = None + + else: + self.downsample = nn.Sequential( + nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm4) + + + def forward(self, x): + y = x + y = self.relu(self.norm1(self.conv1(y))) + y = self.relu(self.norm2(self.conv2(y))) + y = self.relu(self.norm3(self.conv3(y))) + + if self.downsample is not None: + x = self.downsample(x) + + return self.relu(x+y) + + +class ResidualBlock(nn.Module): + def __init__(self, in_planes, planes, norm_fn='group', stride=1): + super(ResidualBlock, self).__init__() + + self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, padding=1, stride=stride) + self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1) + self.relu = nn.ReLU(inplace=True) + + num_groups = planes // 8 + + if norm_fn == 'group': + self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) + self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) + if not stride == 1: + self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) + + elif norm_fn == 'batch': + self.norm1 = nn.BatchNorm2d(planes) + self.norm2 = nn.BatchNorm2d(planes) + if not stride == 1: + self.norm3 = nn.BatchNorm2d(planes) + + elif norm_fn == 'instance': + self.norm1 = nn.InstanceNorm2d(planes) + self.norm2 = nn.InstanceNorm2d(planes) + if not stride == 1: + self.norm3 = nn.InstanceNorm2d(planes) + + elif norm_fn == 'none': + self.norm1 = nn.Sequential() + self.norm2 = nn.Sequential() + if not stride == 1: + self.norm3 = nn.Sequential() + + if stride == 1: + self.downsample = None + + else: + self.downsample = nn.Sequential( + nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3) + + + def forward(self, x): + y = x + y = self.relu(self.norm1(self.conv1(y))) + y = self.relu(self.norm2(self.conv2(y))) + + if self.downsample is not None: + x = self.downsample(x) + + return self.relu(x+y) + + +class SmallEncoder(nn.Module): + def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0): + super(SmallEncoder, self).__init__() + self.norm_fn = norm_fn + + if self.norm_fn == 'group': + self.norm1 = nn.GroupNorm(num_groups=8, num_channels=32) + + elif self.norm_fn == 'batch': + self.norm1 = nn.BatchNorm2d(32) + + elif self.norm_fn == 'instance': + self.norm1 = nn.InstanceNorm2d(32) + + elif self.norm_fn == 'none': + self.norm1 = nn.Sequential() + + self.conv1 = nn.Conv2d(3, 32, kernel_size=7, stride=2, padding=3) + self.relu1 = nn.ReLU(inplace=True) + + self.in_planes = 32 + self.layer1 = self._make_layer(32, stride=1) + self.layer2 = self._make_layer(64, stride=2) + self.layer3 = self._make_layer(96, stride=2) + + self.dropout = None + if dropout > 0: + self.dropout = nn.Dropout2d(p=dropout) + + self.conv2 = nn.Conv2d(96, output_dim, kernel_size=1) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)): + if m.weight is not None: + nn.init.constant_(m.weight, 1) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + def _make_layer(self, dim, stride=1): + layer1 = BottleneckBlock(self.in_planes, dim, self.norm_fn, stride=stride) + layer2 = BottleneckBlock(dim, dim, self.norm_fn, stride=1) + layers = (layer1, layer2) + + self.in_planes = dim + return nn.Sequential(*layers) + + + def forward(self, x): + + # if input is list, combine batch dimension + is_list = isinstance(x, tuple) or isinstance(x, list) + if is_list: + batch_dim = x[0].shape[0] + x = torch.cat(x, dim=0) + + x = self.conv1(x) + x = self.norm1(x) + x = self.relu1(x) + + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.conv2(x) + + if self.training and self.dropout is not None: + x = self.dropout(x) + + if is_list: + x = torch.split(x, [batch_dim, batch_dim], dim=0) + + return x + +class BasicEncoder(nn.Module): + def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0): + super(BasicEncoder, self).__init__() + self.norm_fn = norm_fn + + if self.norm_fn == 'group': + self.norm1 = nn.GroupNorm(num_groups=8, num_channels=64) + + elif self.norm_fn == 'batch': + self.norm1 = nn.BatchNorm2d(64) + + elif self.norm_fn == 'instance': + self.norm1 = nn.InstanceNorm2d(64) + + elif self.norm_fn == 'none': + self.norm1 = nn.Sequential() + + self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3) + self.relu1 = nn.ReLU(inplace=True) + + self.in_planes = 64 + self.layer1 = self._make_layer(64, stride=1) + self.layer2 = self._make_layer(72, stride=2) + self.layer3 = self._make_layer(128, stride=2) + + # output convolution + self.conv2 = nn.Conv2d(128, output_dim, kernel_size=1) + + self.dropout = None + if dropout > 0: + self.dropout = nn.Dropout2d(p=dropout) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)): + if m.weight is not None: + nn.init.constant_(m.weight, 1) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + def _make_layer(self, dim, stride=1): + layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride) + layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1) + layers = (layer1, layer2) + + self.in_planes = dim + return nn.Sequential(*layers) + + + def forward(self, x): + + # if input is list, combine batch dimension + is_list = isinstance(x, tuple) or isinstance(x, list) + if is_list: + batch_dim = x[0].shape[0] + x = torch.cat(x, dim=0) + + x = self.conv1(x) + x = self.norm1(x) + x = self.relu1(x) + + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + + x = self.conv2(x) + + if self.training and self.dropout is not None: + x = self.dropout(x) + + if is_list: + x = torch.split(x, [batch_dim, batch_dim], dim=0) + + return x + +class LargeEncoder(nn.Module): + def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0): + super(LargeEncoder, self).__init__() + self.norm_fn = norm_fn + + if self.norm_fn == 'group': + self.norm1 = nn.GroupNorm(num_groups=8, num_channels=64) + + elif self.norm_fn == 'batch': + self.norm1 = nn.BatchNorm2d(64) + + elif self.norm_fn == 'instance': + self.norm1 = nn.InstanceNorm2d(64) + + elif self.norm_fn == 'none': + self.norm1 = nn.Sequential() + + self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3) + self.relu1 = nn.ReLU(inplace=True) + + self.in_planes = 64 + self.layer1 = self._make_layer(64, stride=1) + self.layer2 = self._make_layer(112, stride=2) + self.layer3 = self._make_layer(160, stride=2) + self.layer3_2 = self._make_layer(160, stride=1) + + # output convolution + self.conv2 = nn.Conv2d(self.in_planes, output_dim, kernel_size=1) + + self.dropout = None + if dropout > 0: + self.dropout = nn.Dropout2d(p=dropout) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)): + if m.weight is not None: + nn.init.constant_(m.weight, 1) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + def _make_layer(self, dim, stride=1): + layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride) + layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1) + layers = (layer1, layer2) + + self.in_planes = dim + return nn.Sequential(*layers) + + + def forward(self, x): + + # if input is list, combine batch dimension + is_list = isinstance(x, tuple) or isinstance(x, list) + if is_list: + batch_dim = x[0].shape[0] + x = torch.cat(x, dim=0) + + x = self.conv1(x) + x = self.norm1(x) + x = self.relu1(x) + + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer3_2(x) + + x = self.conv2(x) + + if self.training and self.dropout is not None: + x = self.dropout(x) + + if is_list: + x = torch.split(x, [batch_dim, batch_dim], dim=0) + + return x diff --git a/opensora/models/frame_interpolation/networks/blocks/ifrnet.py b/opensora/models/frame_interpolation/networks/blocks/ifrnet.py new file mode 100644 index 0000000000000000000000000000000000000000..586ae61036191a52337a791f3e7442899fdf5fc9 --- /dev/null +++ b/opensora/models/frame_interpolation/networks/blocks/ifrnet.py @@ -0,0 +1,111 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from utils.flow_utils import warp + + +def resize(x, scale_factor): + return F.interpolate(x, scale_factor=scale_factor, mode="bilinear", align_corners=False) + +def convrelu(in_channels, out_channels, kernel_size=3, stride=1, padding=1, dilation=1, groups=1, bias=True): + return nn.Sequential( + nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias=bias), + nn.PReLU(out_channels) + ) + +class ResBlock(nn.Module): + def __init__(self, in_channels, side_channels, bias=True): + super(ResBlock, self).__init__() + self.side_channels = side_channels + self.conv1 = nn.Sequential( + nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1, bias=bias), + nn.PReLU(in_channels) + ) + self.conv2 = nn.Sequential( + nn.Conv2d(side_channels, side_channels, kernel_size=3, stride=1, padding=1, bias=bias), + nn.PReLU(side_channels) + ) + self.conv3 = nn.Sequential( + nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1, bias=bias), + nn.PReLU(in_channels) + ) + self.conv4 = nn.Sequential( + nn.Conv2d(side_channels, side_channels, kernel_size=3, stride=1, padding=1, bias=bias), + nn.PReLU(side_channels) + ) + self.conv5 = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1, bias=bias) + self.prelu = nn.PReLU(in_channels) + + def forward(self, x): + out = self.conv1(x) + + res_feat = out[:, :-self.side_channels, ...] + side_feat = out[:, -self.side_channels:, :, :] + side_feat = self.conv2(side_feat) + out = self.conv3(torch.cat([res_feat, side_feat], 1)) + + res_feat = out[:, :-self.side_channels, ...] + side_feat = out[:, -self.side_channels:, :, :] + side_feat = self.conv4(side_feat) + out = self.conv5(torch.cat([res_feat, side_feat], 1)) + + out = self.prelu(x + out) + return out + +class Encoder(nn.Module): + def __init__(self, channels, large=False): + super(Encoder, self).__init__() + self.channels = channels + prev_ch = 3 + for idx, ch in enumerate(channels, 1): + k = 7 if large and idx == 1 else 3 + p = 3 if k ==7 else 1 + self.register_module(f'pyramid{idx}', + nn.Sequential( + convrelu(prev_ch, ch, k, 2, p), + convrelu(ch, ch, 3, 1, 1) + )) + prev_ch = ch + + def forward(self, in_x): + fs = [] + for idx in range(len(self.channels)): + out_x = getattr(self, f'pyramid{idx+1}')(in_x) + fs.append(out_x) + in_x = out_x + return fs + +class InitDecoder(nn.Module): + def __init__(self, in_ch, out_ch, skip_ch) -> None: + super().__init__() + self.convblock = nn.Sequential( + convrelu(in_ch*2+1, in_ch*2), + ResBlock(in_ch*2, skip_ch), + nn.ConvTranspose2d(in_ch*2, out_ch+4, 4, 2, 1, bias=True) + ) + def forward(self, f0, f1, embt): + h, w = f0.shape[2:] + embt = embt.repeat(1, 1, h, w) + out = self.convblock(torch.cat([f0, f1, embt], 1)) + flow0, flow1 = torch.chunk(out[:, :4, ...], 2, 1) + ft_ = out[:, 4:, ...] + return flow0, flow1, ft_ + +class IntermediateDecoder(nn.Module): + def __init__(self, in_ch, out_ch, skip_ch) -> None: + super().__init__() + self.convblock = nn.Sequential( + convrelu(in_ch*3+4, in_ch*3), + ResBlock(in_ch*3, skip_ch), + nn.ConvTranspose2d(in_ch*3, out_ch+4, 4, 2, 1, bias=True) + ) + def forward(self, ft_, f0, f1, flow0_in, flow1_in): + f0_warp = warp(f0, flow0_in) + f1_warp = warp(f1, flow1_in) + f_in = torch.cat([ft_, f0_warp, f1_warp, flow0_in, flow1_in], 1) + out = self.convblock(f_in) + flow0, flow1 = torch.chunk(out[:, :4, ...], 2, 1) + ft_ = out[:, 4:, ...] + flow0 = flow0 + 2.0 * resize(flow0_in, scale_factor=2.0) + flow1 = flow1 + 2.0 * resize(flow1_in, scale_factor=2.0) + return flow0, flow1, ft_ \ No newline at end of file diff --git a/opensora/models/frame_interpolation/networks/blocks/multi_flow.py b/opensora/models/frame_interpolation/networks/blocks/multi_flow.py new file mode 100644 index 0000000000000000000000000000000000000000..4563c3262b980ec4489ac96177dea522caa84f21 --- /dev/null +++ b/opensora/models/frame_interpolation/networks/blocks/multi_flow.py @@ -0,0 +1,69 @@ +import torch +import torch.nn as nn +from utils.flow_utils import warp +from networks.blocks.ifrnet import ( + convrelu, resize, + ResBlock, +) + + +def multi_flow_combine(comb_block, img0, img1, flow0, flow1, + mask=None, img_res=None, mean=None): + ''' + A parallel implementation of multiple flow field warping + comb_block: An nn.Seqential object. + img shape: [b, c, h, w] + flow shape: [b, 2*num_flows, h, w] + mask (opt): + If 'mask' is None, the function conduct a simple average. + img_res (opt): + If 'img_res' is None, the function adds zero instead. + mean (opt): + If 'mean' is None, the function adds zero instead. + ''' + b, c, h, w = flow0.shape + num_flows = c // 2 + flow0 = flow0.reshape(b, num_flows, 2, h, w).reshape(-1, 2, h, w) + flow1 = flow1.reshape(b, num_flows, 2, h, w).reshape(-1, 2, h, w) + + mask = mask.reshape(b, num_flows, 1, h, w + ).reshape(-1, 1, h, w) if mask is not None else None + img_res = img_res.reshape(b, num_flows, 3, h, w + ).reshape(-1, 3, h, w) if img_res is not None else 0 + img0 = torch.stack([img0] * num_flows, 1).reshape(-1, 3, h, w) + img1 = torch.stack([img1] * num_flows, 1).reshape(-1, 3, h, w) + mean = torch.stack([mean] * num_flows, 1).reshape(-1, 1, 1, 1 + ) if mean is not None else 0 + + img0_warp = warp(img0, flow0) + img1_warp = warp(img1, flow1) + img_warps = mask * img0_warp + (1 - mask) * img1_warp + mean + img_res + img_warps = img_warps.reshape(b, num_flows, 3, h, w) + imgt_pred = img_warps.mean(1) + comb_block(img_warps.view(b, -1, h, w)) + return imgt_pred + + +class MultiFlowDecoder(nn.Module): + def __init__(self, in_ch, skip_ch, num_flows=3): + super(MultiFlowDecoder, self).__init__() + self.num_flows = num_flows + self.convblock = nn.Sequential( + convrelu(in_ch*3+4, in_ch*3), + ResBlock(in_ch*3, skip_ch), + nn.ConvTranspose2d(in_ch*3, 8*num_flows, 4, 2, 1, bias=True) + ) + + def forward(self, ft_, f0, f1, flow0, flow1): + n = self.num_flows + f0_warp = warp(f0, flow0) + f1_warp = warp(f1, flow1) + out = self.convblock(torch.cat([ft_, f0_warp, f1_warp, flow0, flow1], 1)) + delta_flow0, delta_flow1, mask, img_res = torch.split(out, [2*n, 2*n, n, 3*n], 1) + mask = torch.sigmoid(mask) + + flow0 = delta_flow0 + 2.0 * resize(flow0, scale_factor=2.0 + ).repeat(1, self.num_flows, 1, 1) + flow1 = delta_flow1 + 2.0 * resize(flow1, scale_factor=2.0 + ).repeat(1, self.num_flows, 1, 1) + + return flow0, flow1, mask, img_res \ No newline at end of file diff --git a/opensora/models/frame_interpolation/networks/blocks/raft.py b/opensora/models/frame_interpolation/networks/blocks/raft.py new file mode 100644 index 0000000000000000000000000000000000000000..9fb85ad6556a28f5b80034c595be539fd700ad48 --- /dev/null +++ b/opensora/models/frame_interpolation/networks/blocks/raft.py @@ -0,0 +1,207 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + + +def resize(x, scale_factor): + return F.interpolate(x, scale_factor=scale_factor, mode="bilinear", align_corners=False) + + +def bilinear_sampler(img, coords, mask=False): + """ Wrapper for grid_sample, uses pixel coordinates """ + H, W = img.shape[-2:] + xgrid, ygrid = coords.split([1,1], dim=-1) + xgrid = 2*xgrid/(W-1) - 1 + ygrid = 2*ygrid/(H-1) - 1 + + grid = torch.cat([xgrid, ygrid], dim=-1) + img = F.grid_sample(img, grid, align_corners=True) + + if mask: + mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1) + return img, mask.float() + + return img + + +def coords_grid(batch, ht, wd, device): + coords = torch.meshgrid(torch.arange(ht, device=device), + torch.arange(wd, device=device), + indexing='ij') + coords = torch.stack(coords[::-1], dim=0).float() + return coords[None].repeat(batch, 1, 1, 1) + + +class SmallUpdateBlock(nn.Module): + def __init__(self, cdim, hidden_dim, flow_dim, corr_dim, fc_dim, + corr_levels=4, radius=3, scale_factor=None): + super(SmallUpdateBlock, self).__init__() + cor_planes = corr_levels * (2 * radius + 1) **2 + self.scale_factor = scale_factor + + self.convc1 = nn.Conv2d(2 * cor_planes, corr_dim, 1, padding=0) + self.convf1 = nn.Conv2d(4, flow_dim*2, 7, padding=3) + self.convf2 = nn.Conv2d(flow_dim*2, flow_dim, 3, padding=1) + self.conv = nn.Conv2d(corr_dim+flow_dim, fc_dim, 3, padding=1) + + self.gru = nn.Sequential( + nn.Conv2d(fc_dim+4+cdim, hidden_dim, 3, padding=1), + nn.LeakyReLU(negative_slope=0.1, inplace=True), + nn.Conv2d(hidden_dim, hidden_dim, 3, padding=1), + ) + + self.feat_head = nn.Sequential( + nn.Conv2d(hidden_dim, hidden_dim, 3, padding=1), + nn.LeakyReLU(negative_slope=0.1, inplace=True), + nn.Conv2d(hidden_dim, cdim, 3, padding=1), + ) + + self.flow_head = nn.Sequential( + nn.Conv2d(hidden_dim, hidden_dim, 3, padding=1), + nn.LeakyReLU(negative_slope=0.1, inplace=True), + nn.Conv2d(hidden_dim, 4, 3, padding=1), + ) + + self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True) + + def forward(self, net, flow, corr): + net = resize(net, 1 / self.scale_factor + ) if self.scale_factor is not None else net + cor = self.lrelu(self.convc1(corr)) + flo = self.lrelu(self.convf1(flow)) + flo = self.lrelu(self.convf2(flo)) + cor_flo = torch.cat([cor, flo], dim=1) + inp = self.lrelu(self.conv(cor_flo)) + inp = torch.cat([inp, flow, net], dim=1) + + out = self.gru(inp) + delta_net = self.feat_head(out) + delta_flow = self.flow_head(out) + + if self.scale_factor is not None: + delta_net = resize(delta_net, scale_factor=self.scale_factor) + delta_flow = self.scale_factor * resize(delta_flow, scale_factor=self.scale_factor) + + return delta_net, delta_flow + + +class BasicUpdateBlock(nn.Module): + def __init__(self, cdim, hidden_dim, flow_dim, corr_dim, corr_dim2, + fc_dim, corr_levels=4, radius=3, scale_factor=None, out_num=1): + super(BasicUpdateBlock, self).__init__() + cor_planes = corr_levels * (2 * radius + 1) **2 + + self.scale_factor = scale_factor + self.convc1 = nn.Conv2d(2 * cor_planes, corr_dim, 1, padding=0) + self.convc2 = nn.Conv2d(corr_dim, corr_dim2, 3, padding=1) + self.convf1 = nn.Conv2d(4, flow_dim*2, 7, padding=3) + self.convf2 = nn.Conv2d(flow_dim*2, flow_dim, 3, padding=1) + self.conv = nn.Conv2d(flow_dim+corr_dim2, fc_dim, 3, padding=1) + + self.gru = nn.Sequential( + nn.Conv2d(fc_dim+4+cdim, hidden_dim, 3, padding=1), + nn.LeakyReLU(negative_slope=0.1, inplace=True), + nn.Conv2d(hidden_dim, hidden_dim, 3, padding=1), + ) + + self.feat_head = nn.Sequential( + nn.Conv2d(hidden_dim, hidden_dim, 3, padding=1), + nn.LeakyReLU(negative_slope=0.1, inplace=True), + nn.Conv2d(hidden_dim, cdim, 3, padding=1), + ) + + self.flow_head = nn.Sequential( + nn.Conv2d(hidden_dim, hidden_dim, 3, padding=1), + nn.LeakyReLU(negative_slope=0.1, inplace=True), + nn.Conv2d(hidden_dim, 4*out_num, 3, padding=1), + ) + + self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True) + + def forward(self, net, flow, corr): + net = resize(net, 1 / self.scale_factor + ) if self.scale_factor is not None else net + cor = self.lrelu(self.convc1(corr)) + cor = self.lrelu(self.convc2(cor)) + flo = self.lrelu(self.convf1(flow)) + flo = self.lrelu(self.convf2(flo)) + cor_flo = torch.cat([cor, flo], dim=1) + inp = self.lrelu(self.conv(cor_flo)) + inp = torch.cat([inp, flow, net], dim=1) + + out = self.gru(inp) + delta_net = self.feat_head(out) + delta_flow = self.flow_head(out) + + if self.scale_factor is not None: + delta_net = resize(delta_net, scale_factor=self.scale_factor) + delta_flow = self.scale_factor * resize(delta_flow, scale_factor=self.scale_factor) + return delta_net, delta_flow + + +class BidirCorrBlock: + def __init__(self, fmap1, fmap2, num_levels=4, radius=4): + self.num_levels = num_levels + self.radius = radius + self.corr_pyramid = [] + self.corr_pyramid_T = [] + + corr = BidirCorrBlock.corr(fmap1, fmap2) + batch, h1, w1, dim, h2, w2 = corr.shape + corr_T = corr.clone().permute(0, 4, 5, 3, 1, 2) + + corr = corr.reshape(batch*h1*w1, dim, h2, w2) + corr_T = corr_T.reshape(batch*h2*w2, dim, h1, w1) + + self.corr_pyramid.append(corr) + self.corr_pyramid_T.append(corr_T) + + for _ in range(self.num_levels-1): + corr = F.avg_pool2d(corr, 2, stride=2) + corr_T = F.avg_pool2d(corr_T, 2, stride=2) + self.corr_pyramid.append(corr) + self.corr_pyramid_T.append(corr_T) + + def __call__(self, coords0, coords1): + r = self.radius + coords0 = coords0.permute(0, 2, 3, 1) + coords1 = coords1.permute(0, 2, 3, 1) + assert coords0.shape == coords1.shape, f"coords0 shape: [{coords0.shape}] is not equal to [{coords1.shape}]" + batch, h1, w1, _ = coords0.shape + + out_pyramid = [] + out_pyramid_T = [] + for i in range(self.num_levels): + corr = self.corr_pyramid[i] + corr_T = self.corr_pyramid_T[i] + + dx = torch.linspace(-r, r, 2*r+1, device=coords0.device) + dy = torch.linspace(-r, r, 2*r+1, device=coords0.device) + delta = torch.stack(torch.meshgrid(dy, dx, indexing='ij'), axis=-1) + delta_lvl = delta.view(1, 2*r+1, 2*r+1, 2) + + centroid_lvl_0 = coords0.reshape(batch*h1*w1, 1, 1, 2) / 2**i + centroid_lvl_1 = coords1.reshape(batch*h1*w1, 1, 1, 2) / 2**i + coords_lvl_0 = centroid_lvl_0 + delta_lvl + coords_lvl_1 = centroid_lvl_1 + delta_lvl + + corr = bilinear_sampler(corr, coords_lvl_0) + corr_T = bilinear_sampler(corr_T, coords_lvl_1) + corr = corr.view(batch, h1, w1, -1) + corr_T = corr_T.view(batch, h1, w1, -1) + out_pyramid.append(corr) + out_pyramid_T.append(corr_T) + + out = torch.cat(out_pyramid, dim=-1) + out_T = torch.cat(out_pyramid_T, dim=-1) + return out.permute(0, 3, 1, 2).contiguous().float(), out_T.permute(0, 3, 1, 2).contiguous().float() + + @staticmethod + def corr(fmap1, fmap2): + batch, dim, ht, wd = fmap1.shape + fmap1 = fmap1.view(batch, dim, ht*wd) + fmap2 = fmap2.view(batch, dim, ht*wd) + + corr = torch.matmul(fmap1.transpose(1,2), fmap2) + corr = corr.view(batch, ht, wd, 1, ht, wd) + return corr / torch.sqrt(torch.tensor(dim).float()) \ No newline at end of file diff --git a/opensora/models/frame_interpolation/readme.md b/opensora/models/frame_interpolation/readme.md new file mode 100644 index 0000000000000000000000000000000000000000..1f35e92a600f33d52858c6ab9d0b4a7d650908cd --- /dev/null +++ b/opensora/models/frame_interpolation/readme.md @@ -0,0 +1,17 @@ +#### Frame Interpolation + +We use AMT as our frame interpolation model. (Thanks [AMT](https://github.com/MCG-NKU/AMT)) After sampling, you can use frame interpolation model to interpolate your video smoothly. + +1. Download the pretrained weights from [AMT](https://github.com/MCG-NKU/AMT), we recommend using the largest model AMT-G to achieve the best performance. +2. Run the script of frame interpolation. +``` +python opensora/models/frame_interpolation/interpolation.py --ckpt /path/to/ckpt --niters 1 --input /path/to/input/video.mp4 --output_path /path/to/output/floder --frame_rate 30 +``` +3. The output video will be stored at output_path and its duration time is equal `the total number of frames after frame interpolation / the frame rate` +##### Frame Interpolation Specific Settings + +* `--ckpt`: Pretrained model of [AMT](https://github.com/MCG-NKU/AMT). We use AMT-G as our frame interpolation model. +* `--niter`: Iterations of interpolation. With $m$ input frames, `[N_ITER]` $=n$ corresponds to $2^n\times (m-1)+1$ output frames. +* `--input`: Path of the input video. +* `--output_path`: Folder Path of the output video. +* `--frame_rate"`: Frame rate of the output video. diff --git a/opensora/models/frame_interpolation/utils/__init__.py b/opensora/models/frame_interpolation/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/opensora/models/frame_interpolation/utils/build_utils.py b/opensora/models/frame_interpolation/utils/build_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..2d34c3b8d45d97961a175784b1c0a362bed3a508 --- /dev/null +++ b/opensora/models/frame_interpolation/utils/build_utils.py @@ -0,0 +1,12 @@ +import importlib + + +def base_build_fn(module, cls, params): + return getattr(importlib.import_module( + module, package=None), cls)(**params) + + +def build_from_cfg(config): + module, cls = config['name'].rsplit(".", 1) + params = config.get('params', {}) + return base_build_fn(module, cls, params) diff --git a/opensora/models/frame_interpolation/utils/dist_utils.py b/opensora/models/frame_interpolation/utils/dist_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..6337f9991fc510cfb6cbc7da18574eb443ec1dac --- /dev/null +++ b/opensora/models/frame_interpolation/utils/dist_utils.py @@ -0,0 +1,48 @@ +import os +import torch + + +def get_world_size(): + """Find OMPI world size without calling mpi functions + :rtype: int + """ + if os.environ.get('PMI_SIZE') is not None: + return int(os.environ.get('PMI_SIZE') or 1) + elif os.environ.get('OMPI_COMM_WORLD_SIZE') is not None: + return int(os.environ.get('OMPI_COMM_WORLD_SIZE') or 1) + else: + return torch.cuda.device_count() + + +def get_global_rank(): + """Find OMPI world rank without calling mpi functions + :rtype: int + """ + if os.environ.get('PMI_RANK') is not None: + return int(os.environ.get('PMI_RANK') or 0) + elif os.environ.get('OMPI_COMM_WORLD_RANK') is not None: + return int(os.environ.get('OMPI_COMM_WORLD_RANK') or 0) + else: + return 0 + + +def get_local_rank(): + """Find OMPI local rank without calling mpi functions + :rtype: int + """ + if os.environ.get('MPI_LOCALRANKID') is not None: + return int(os.environ.get('MPI_LOCALRANKID') or 0) + elif os.environ.get('OMPI_COMM_WORLD_LOCAL_RANK') is not None: + return int(os.environ.get('OMPI_COMM_WORLD_LOCAL_RANK') or 0) + else: + return 0 + + +def get_master_ip(): + if os.environ.get('AZ_BATCH_MASTER_NODE') is not None: + return os.environ.get('AZ_BATCH_MASTER_NODE').split(':')[0] + elif os.environ.get('AZ_BATCHAI_MPI_MASTER_NODE') is not None: + return os.environ.get('AZ_BATCHAI_MPI_MASTER_NODE') + else: + return "127.0.0.1" + diff --git a/opensora/models/frame_interpolation/utils/flow_utils.py b/opensora/models/frame_interpolation/utils/flow_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..84fca2049783b22175e0d1e024a19a5f9a79906e --- /dev/null +++ b/opensora/models/frame_interpolation/utils/flow_utils.py @@ -0,0 +1,122 @@ +import numpy as np +import torch +from PIL import ImageFile +import torch.nn.functional as F +ImageFile.LOAD_TRUNCATED_IMAGES = True + + +def warp(img, flow): + B, _, H, W = flow.shape + xx = torch.linspace(-1.0, 1.0, W).view(1, 1, 1, W).expand(B, -1, H, -1) + yy = torch.linspace(-1.0, 1.0, H).view(1, 1, H, 1).expand(B, -1, -1, W) + grid = torch.cat([xx, yy], 1).to(img) + flow_ = torch.cat([flow[:, 0:1, :, :] / ((W - 1.0) / 2.0), flow[:, 1:2, :, :] / ((H - 1.0) / 2.0)], 1) + grid_ = (grid + flow_).permute(0, 2, 3, 1) + output = F.grid_sample(input=img, grid=grid_, mode='bilinear', padding_mode='border', align_corners=True) + return output + + +def make_colorwheel(): + """ + Generates a color wheel for optical flow visualization as presented in: + Baker et al. "A Database and Evaluation Methodology for Optical Flow" (ICCV, 2007) + URL: http://vision.middlebury.edu/flow/flowEval-iccv07.pdf + Code follows the original C++ source code of Daniel Scharstein. + Code follows the the Matlab source code of Deqing Sun. + Returns: + np.ndarray: Color wheel + """ + + RY = 15 + YG = 6 + GC = 4 + CB = 11 + BM = 13 + MR = 6 + + ncols = RY + YG + GC + CB + BM + MR + colorwheel = np.zeros((ncols, 3)) + col = 0 + + # RY + colorwheel[0:RY, 0] = 255 + colorwheel[0:RY, 1] = np.floor(255*np.arange(0,RY)/RY) + col = col+RY + # YG + colorwheel[col:col+YG, 0] = 255 - np.floor(255*np.arange(0,YG)/YG) + colorwheel[col:col+YG, 1] = 255 + col = col+YG + # GC + colorwheel[col:col+GC, 1] = 255 + colorwheel[col:col+GC, 2] = np.floor(255*np.arange(0,GC)/GC) + col = col+GC + # CB + colorwheel[col:col+CB, 1] = 255 - np.floor(255*np.arange(CB)/CB) + colorwheel[col:col+CB, 2] = 255 + col = col+CB + # BM + colorwheel[col:col+BM, 2] = 255 + colorwheel[col:col+BM, 0] = np.floor(255*np.arange(0,BM)/BM) + col = col+BM + # MR + colorwheel[col:col+MR, 2] = 255 - np.floor(255*np.arange(MR)/MR) + colorwheel[col:col+MR, 0] = 255 + return colorwheel + +def flow_uv_to_colors(u, v, convert_to_bgr=False): + """ + Applies the flow color wheel to (possibly clipped) flow components u and v. + According to the C++ source code of Daniel Scharstein + According to the Matlab source code of Deqing Sun + Args: + u (np.ndarray): Input horizontal flow of shape [H,W] + v (np.ndarray): Input vertical flow of shape [H,W] + convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False. + Returns: + np.ndarray: Flow visualization image of shape [H,W,3] + """ + flow_image = np.zeros((u.shape[0], u.shape[1], 3), np.uint8) + colorwheel = make_colorwheel() # shape [55x3] + ncols = colorwheel.shape[0] + rad = np.sqrt(np.square(u) + np.square(v)) + a = np.arctan2(-v, -u)/np.pi + fk = (a+1) / 2*(ncols-1) + k0 = np.floor(fk).astype(np.int32) + k1 = k0 + 1 + k1[k1 == ncols] = 0 + f = fk - k0 + for i in range(colorwheel.shape[1]): + tmp = colorwheel[:,i] + col0 = tmp[k0] / 255.0 + col1 = tmp[k1] / 255.0 + col = (1-f)*col0 + f*col1 + idx = (rad <= 1) + col[idx] = 1 - rad[idx] * (1-col[idx]) + col[~idx] = col[~idx] * 0.75 # out of range + # Note the 2-i => BGR instead of RGB + ch_idx = 2-i if convert_to_bgr else i + flow_image[:,:,ch_idx] = np.floor(255 * col) + return flow_image + +def flow_to_image(flow_uv, clip_flow=None, convert_to_bgr=False): + """ + Expects a two dimensional flow image of shape. + Args: + flow_uv (np.ndarray): Flow UV image of shape [H,W,2] + clip_flow (float, optional): Clip maximum of flow values. Defaults to None. + convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False. + Returns: + np.ndarray: Flow visualization image of shape [H,W,3] + """ + assert flow_uv.ndim == 3, 'input flow must have three dimensions' + assert flow_uv.shape[2] == 2, 'input flow must have shape [H,W,2]' + if clip_flow is not None: + flow_uv = np.clip(flow_uv, 0, clip_flow) + u = flow_uv[:,:,0] + v = flow_uv[:,:,1] + rad = np.sqrt(np.square(u) + np.square(v)) + rad_max = np.max(rad) + epsilon = 1e-5 + u = u / (rad_max + epsilon) + v = v / (rad_max + epsilon) + return flow_uv_to_colors(u, v, convert_to_bgr) \ No newline at end of file diff --git a/opensora/models/frame_interpolation/utils/utils.py b/opensora/models/frame_interpolation/utils/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..0473226d4eaf98e41e7ae3ee81b722308765e96c --- /dev/null +++ b/opensora/models/frame_interpolation/utils/utils.py @@ -0,0 +1,297 @@ +import re +import sys +import torch +import random +import numpy as np +from PIL import ImageFile +import torch.nn.functional as F +from imageio import imread, imwrite +ImageFile.LOAD_TRUNCATED_IMAGES = True + + +class AverageMeter(): + def __init__(self): + self.reset() + + def reset(self): + self.val = 0. + self.avg = 0. + self.sum = 0. + self.count = 0 + + def update(self, val, n=1): + self.val = val + self.sum += val * n + self.count += n + self.avg = self.sum / self.count + + +class AverageMeterGroups: + def __init__(self) -> None: + self.meter_dict = dict() + + def update(self, dict, n=1): + for name, val in dict.items(): + if self.meter_dict.get(name) is None: + self.meter_dict[name] = AverageMeter() + self.meter_dict[name].update(val, n) + + def reset(self, name=None): + if name is None: + for v in self.meter_dict.values(): + v.reset() + else: + meter = self.meter_dict.get(name) + if meter is not None: + meter.reset() + + def avg(self, name): + meter = self.meter_dict.get(name) + if meter is not None: + return meter.avg + + +class InputPadder: + """ Pads images such that dimensions are divisible by divisor """ + def __init__(self, dims, divisor=16): + self.ht, self.wd = dims[-2:] + pad_ht = (((self.ht // divisor) + 1) * divisor - self.ht) % divisor + pad_wd = (((self.wd // divisor) + 1) * divisor - self.wd) % divisor + self._pad = [pad_wd//2, pad_wd - pad_wd//2, pad_ht//2, pad_ht - pad_ht//2] + + def pad(self, *inputs): + if len(inputs) == 1: + return F.pad(inputs[0], self._pad, mode='replicate') + else: + return [F.pad(x, self._pad, mode='replicate') for x in inputs] + + def unpad(self, *inputs): + if len(inputs) == 1: + return self._unpad(inputs[0]) + else: + return [self._unpad(x) for x in inputs] + + def _unpad(self, x): + ht, wd = x.shape[-2:] + c = [self._pad[2], ht-self._pad[3], self._pad[0], wd-self._pad[1]] + return x[..., c[0]:c[1], c[2]:c[3]] + + +def img2tensor(img): + if img.shape[-1] > 3: + img = img[:,:,:3] + return torch.tensor(img).permute(2, 0, 1).unsqueeze(0) / 255.0 + + +def tensor2img(img_t): + return (img_t * 255.).detach( + ).squeeze(0).permute(1, 2, 0).cpu().numpy( + ).clip(0, 255).astype(np.uint8) + +def seed_all(seed): + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + + +def read(file): + if file.endswith('.float3'): return readFloat(file) + elif file.endswith('.flo'): return readFlow(file) + elif file.endswith('.ppm'): return readImage(file) + elif file.endswith('.pgm'): return readImage(file) + elif file.endswith('.png'): return readImage(file) + elif file.endswith('.jpg'): return readImage(file) + elif file.endswith('.pfm'): return readPFM(file)[0] + else: raise Exception('don\'t know how to read %s' % file) + + +def write(file, data): + if file.endswith('.float3'): return writeFloat(file, data) + elif file.endswith('.flo'): return writeFlow(file, data) + elif file.endswith('.ppm'): return writeImage(file, data) + elif file.endswith('.pgm'): return writeImage(file, data) + elif file.endswith('.png'): return writeImage(file, data) + elif file.endswith('.jpg'): return writeImage(file, data) + elif file.endswith('.pfm'): return writePFM(file, data) + else: raise Exception('don\'t know how to write %s' % file) + + +def readPFM(file): + file = open(file, 'rb') + + color = None + width = None + height = None + scale = None + endian = None + + header = file.readline().rstrip() + if header.decode("ascii") == 'PF': + color = True + elif header.decode("ascii") == 'Pf': + color = False + else: + raise Exception('Not a PFM file.') + + dim_match = re.match(r'^(\d+)\s(\d+)\s$', file.readline().decode("ascii")) + if dim_match: + width, height = list(map(int, dim_match.groups())) + else: + raise Exception('Malformed PFM header.') + + scale = float(file.readline().decode("ascii").rstrip()) + if scale < 0: + endian = '<' + scale = -scale + else: + endian = '>' + + data = np.fromfile(file, endian + 'f') + shape = (height, width, 3) if color else (height, width) + + data = np.reshape(data, shape) + data = np.flipud(data) + return data, scale + + +def writePFM(file, image, scale=1): + file = open(file, 'wb') + + color = None + + if image.dtype.name != 'float32': + raise Exception('Image dtype must be float32.') + + image = np.flipud(image) + + if len(image.shape) == 3 and image.shape[2] == 3: + color = True + elif len(image.shape) == 2 or len(image.shape) == 3 and image.shape[2] == 1: + color = False + else: + raise Exception('Image must have H x W x 3, H x W x 1 or H x W dimensions.') + + file.write('PF\n' if color else 'Pf\n'.encode()) + file.write('%d %d\n'.encode() % (image.shape[1], image.shape[0])) + + endian = image.dtype.byteorder + + if endian == '<' or endian == '=' and sys.byteorder == 'little': + scale = -scale + + file.write('%f\n'.encode() % scale) + + image.tofile(file) + + +def readFlow(name): + if name.endswith('.pfm') or name.endswith('.PFM'): + return readPFM(name)[0][:,:,0:2] + + f = open(name, 'rb') + + header = f.read(4) + if header.decode("utf-8") != 'PIEH': + raise Exception('Flow file header does not contain PIEH') + + width = np.fromfile(f, np.int32, 1).squeeze() + height = np.fromfile(f, np.int32, 1).squeeze() + + flow = np.fromfile(f, np.float32, width * height * 2).reshape((height, width, 2)) + + return flow.astype(np.float32) + + +def readImage(name): + if name.endswith('.pfm') or name.endswith('.PFM'): + data = readPFM(name)[0] + if len(data.shape)==3: + return data[:,:,0:3] + else: + return data + return imread(name) + + +def writeImage(name, data): + if name.endswith('.pfm') or name.endswith('.PFM'): + return writePFM(name, data, 1) + return imwrite(name, data) + + +def writeFlow(name, flow): + f = open(name, 'wb') + f.write('PIEH'.encode('utf-8')) + np.array([flow.shape[1], flow.shape[0]], dtype=np.int32).tofile(f) + flow = flow.astype(np.float32) + flow.tofile(f) + + +def readFloat(name): + f = open(name, 'rb') + + if(f.readline().decode("utf-8")) != 'float\n': + raise Exception('float file %s did not contain keyword' % name) + + dim = int(f.readline()) + + dims = [] + count = 1 + for i in range(0, dim): + d = int(f.readline()) + dims.append(d) + count *= d + + dims = list(reversed(dims)) + + data = np.fromfile(f, np.float32, count).reshape(dims) + if dim > 2: + data = np.transpose(data, (2, 1, 0)) + data = np.transpose(data, (1, 0, 2)) + + return data + + +def writeFloat(name, data): + f = open(name, 'wb') + + dim=len(data.shape) + if dim>3: + raise Exception('bad float file dimension: %d' % dim) + + f.write(('float\n').encode('ascii')) + f.write(('%d\n' % dim).encode('ascii')) + + if dim == 1: + f.write(('%d\n' % data.shape[0]).encode('ascii')) + else: + f.write(('%d\n' % data.shape[1]).encode('ascii')) + f.write(('%d\n' % data.shape[0]).encode('ascii')) + for i in range(2, dim): + f.write(('%d\n' % data.shape[i]).encode('ascii')) + + data = data.astype(np.float32) + if dim==2: + data.tofile(f) + + else: + np.transpose(data, (2, 0, 1)).tofile(f) + + +def check_dim_and_resize(tensor_list): + shape_list = [] + for t in tensor_list: + shape_list.append(t.shape[2:]) + + if len(set(shape_list)) > 1: + desired_shape = shape_list[0] + print(f'Inconsistent size of input video frames. All frames will be resized to {desired_shape}') + + resize_tensor_list = [] + for t in tensor_list: + resize_tensor_list.append(torch.nn.functional.interpolate(t, size=tuple(desired_shape), mode='bilinear')) + + tensor_list = resize_tensor_list + + return tensor_list + diff --git a/opensora/models/super_resolution/README.md b/opensora/models/super_resolution/README.md new file mode 100644 index 0000000000000000000000000000000000000000..869cee78d438905cce91d852d20b8a858a2cf25b --- /dev/null +++ b/opensora/models/super_resolution/README.md @@ -0,0 +1,25 @@ + +## Environment Preparation + +For video super resolution, please prepare your own python envirment from [RGT](https://github.com/zhengchen1999/RGT) and down the [ckpt](https://drive.google.com/drive/folders/1zxrr31Kp2D_N9a-OUAPaJEn_yTaSXTfZ) into the folder like +```bash +./experiments/pretrained_models/RGT_x2.pth +``` + +## Video Super Resolution +The inferencing instruction is in [run.py](run.py). +```bash +python run.py --SR x4 --root_path /path_to_root --input_dir /path_to_input_dir --output_dir /path_to_video_output +``` +You can configure some more detailed parameters in [run.py](run.py) such as . +```bash +--mul_numwork 16 --use_chop False +``` +We recommend using `` --use_chop = False `` when memory allows. +Note that in our tests. + +A single frame of 256x256 requires about 3G RAM-Usage, and a single 4090 card can process about one frame per second. + +A single frame of 512x512 takes about 19G RAM-Usage, and a single 4090 takes about 5 seconds to process a frame. + + diff --git a/opensora/models/super_resolution/basicsr/__init__.py b/opensora/models/super_resolution/basicsr/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3360401f72406374b62ec8da2625d2f293e37687 --- /dev/null +++ b/opensora/models/super_resolution/basicsr/__init__.py @@ -0,0 +1,6 @@ +from .archs import * +from .data import * +from .metrics import * +from .models import * +from .test import * +from .utils import * diff --git a/opensora/models/super_resolution/basicsr/archs/__init__.py b/opensora/models/super_resolution/basicsr/archs/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..cfb1e4d7bb221c429082bd389d9140e5b1cc07b0 --- /dev/null +++ b/opensora/models/super_resolution/basicsr/archs/__init__.py @@ -0,0 +1,25 @@ +import importlib +from copy import deepcopy +from os import path as osp + +from basicsr.utils import get_root_logger, scandir +from basicsr.utils.registry import ARCH_REGISTRY + +__all__ = ['build_network'] + +# automatically scan and import arch modules for registry +# scan all the files under the 'archs' folder and collect files ending with +# '_arch.py' +arch_folder = osp.dirname(osp.abspath(__file__)) +arch_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(arch_folder) if v.endswith('_arch.py')] +# import all the arch modules +_arch_modules = [importlib.import_module(f'basicsr.archs.{file_name}') for file_name in arch_filenames] + + +def build_network(opt): + opt = deepcopy(opt) + network_type = opt.pop('type') + net = ARCH_REGISTRY.get(network_type)(**opt) + logger = get_root_logger() + logger.info(f'Network [{net.__class__.__name__}] is created.') + return net diff --git a/opensora/models/super_resolution/basicsr/archs/arch_util.py b/opensora/models/super_resolution/basicsr/archs/arch_util.py new file mode 100644 index 0000000000000000000000000000000000000000..1719e8e9fe66cd0adc667a76646bcd9dfe588d5e --- /dev/null +++ b/opensora/models/super_resolution/basicsr/archs/arch_util.py @@ -0,0 +1,318 @@ +import collections.abc +import math +import torch +import torchvision +import warnings +from distutils.version import LooseVersion +from itertools import repeat +from torch import nn as nn +from torch.nn import functional as F +from torch.nn import init as init +from torch.nn.modules.batchnorm import _BatchNorm + +# from basicsr.ops.dcn import ModulatedDeformConvPack, modulated_deform_conv +from basicsr.utils import get_root_logger + + +@torch.no_grad() +def default_init_weights(module_list, scale=1, bias_fill=0, **kwargs): + """Initialize network weights. + + Args: + module_list (list[nn.Module] | nn.Module): Modules to be initialized. + scale (float): Scale initialized weights, especially for residual + blocks. Default: 1. + bias_fill (float): The value to fill bias. Default: 0 + kwargs (dict): Other arguments for initialization function. + """ + if not isinstance(module_list, list): + module_list = [module_list] + for module in module_list: + for m in module.modules(): + if isinstance(m, nn.Conv2d): + init.kaiming_normal_(m.weight, **kwargs) + m.weight.data *= scale + if m.bias is not None: + m.bias.data.fill_(bias_fill) + elif isinstance(m, nn.Linear): + init.kaiming_normal_(m.weight, **kwargs) + m.weight.data *= scale + if m.bias is not None: + m.bias.data.fill_(bias_fill) + elif isinstance(m, _BatchNorm): + init.constant_(m.weight, 1) + if m.bias is not None: + m.bias.data.fill_(bias_fill) + + +def make_layer(basic_block, num_basic_block, **kwarg): + """Make layers by stacking the same blocks. + + Args: + basic_block (nn.module): nn.module class for basic block. + num_basic_block (int): number of blocks. + + Returns: + nn.Sequential: Stacked blocks in nn.Sequential. + """ + layers = [] + for _ in range(num_basic_block): + layers.append(basic_block(**kwarg)) + return nn.Sequential(*layers) + + +class ResidualBlockNoBN(nn.Module): + """Residual block without BN. + + It has a style of: + ---Conv-ReLU-Conv-+- + |________________| + + Args: + num_feat (int): Channel number of intermediate features. + Default: 64. + res_scale (float): Residual scale. Default: 1. + pytorch_init (bool): If set to True, use pytorch default init, + otherwise, use default_init_weights. Default: False. + """ + + def __init__(self, num_feat=64, res_scale=1, pytorch_init=False): + super(ResidualBlockNoBN, self).__init__() + self.res_scale = res_scale + self.conv1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True) + self.conv2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True) + self.relu = nn.ReLU(inplace=True) + + if not pytorch_init: + default_init_weights([self.conv1, self.conv2], 0.1) + + def forward(self, x): + identity = x + out = self.conv2(self.relu(self.conv1(x))) + return identity + out * self.res_scale + + +class Upsample(nn.Sequential): + """Upsample module. + + Args: + scale (int): Scale factor. Supported scales: 2^n and 3. + num_feat (int): Channel number of intermediate features. + """ + + def __init__(self, scale, num_feat): + m = [] + if (scale & (scale - 1)) == 0: # scale = 2^n + for _ in range(int(math.log(scale, 2))): + m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1)) + m.append(nn.PixelShuffle(2)) + elif scale == 3: + m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1)) + m.append(nn.PixelShuffle(3)) + else: + raise ValueError(f'scale {scale} is not supported. Supported scales: 2^n and 3.') + super(Upsample, self).__init__(*m) + + +def flow_warp(x, flow, interp_mode='bilinear', padding_mode='zeros', align_corners=True): + """Warp an image or feature map with optical flow. + + Args: + x (Tensor): Tensor with size (n, c, h, w). + flow (Tensor): Tensor with size (n, h, w, 2), normal value. + interp_mode (str): 'nearest' or 'bilinear'. Default: 'bilinear'. + padding_mode (str): 'zeros' or 'border' or 'reflection'. + Default: 'zeros'. + align_corners (bool): Before pytorch 1.3, the default value is + align_corners=True. After pytorch 1.3, the default value is + align_corners=False. Here, we use the True as default. + + Returns: + Tensor: Warped image or feature map. + """ + assert x.size()[-2:] == flow.size()[1:3] + _, _, h, w = x.size() + # create mesh grid + grid_y, grid_x = torch.meshgrid(torch.arange(0, h).type_as(x), torch.arange(0, w).type_as(x)) + grid = torch.stack((grid_x, grid_y), 2).float() # W(x), H(y), 2 + grid.requires_grad = False + + vgrid = grid + flow + # scale grid to [-1,1] + vgrid_x = 2.0 * vgrid[:, :, :, 0] / max(w - 1, 1) - 1.0 + vgrid_y = 2.0 * vgrid[:, :, :, 1] / max(h - 1, 1) - 1.0 + vgrid_scaled = torch.stack((vgrid_x, vgrid_y), dim=3) + output = F.grid_sample(x, vgrid_scaled, mode=interp_mode, padding_mode=padding_mode, align_corners=align_corners) + + # TODO, what if align_corners=False + return output + + +def resize_flow(flow, size_type, sizes, interp_mode='bilinear', align_corners=False): + """Resize a flow according to ratio or shape. + + Args: + flow (Tensor): Precomputed flow. shape [N, 2, H, W]. + size_type (str): 'ratio' or 'shape'. + sizes (list[int | float]): the ratio for resizing or the final output + shape. + 1) The order of ratio should be [ratio_h, ratio_w]. For + downsampling, the ratio should be smaller than 1.0 (i.e., ratio + < 1.0). For upsampling, the ratio should be larger than 1.0 (i.e., + ratio > 1.0). + 2) The order of output_size should be [out_h, out_w]. + interp_mode (str): The mode of interpolation for resizing. + Default: 'bilinear'. + align_corners (bool): Whether align corners. Default: False. + + Returns: + Tensor: Resized flow. + """ + _, _, flow_h, flow_w = flow.size() + if size_type == 'ratio': + output_h, output_w = int(flow_h * sizes[0]), int(flow_w * sizes[1]) + elif size_type == 'shape': + output_h, output_w = sizes[0], sizes[1] + else: + raise ValueError(f'Size type should be ratio or shape, but got type {size_type}.') + + input_flow = flow.clone() + ratio_h = output_h / flow_h + ratio_w = output_w / flow_w + input_flow[:, 0, :, :] *= ratio_w + input_flow[:, 1, :, :] *= ratio_h + resized_flow = F.interpolate( + input=input_flow, size=(output_h, output_w), mode=interp_mode, align_corners=align_corners) + return resized_flow + + +# TODO: may write a cpp file +def pixel_unshuffle(x, scale): + """ Pixel unshuffle. + + Args: + x (Tensor): Input feature with shape (b, c, hh, hw). + scale (int): Downsample ratio. + + Returns: + Tensor: the pixel unshuffled feature. + """ + b, c, hh, hw = x.size() + out_channel = c * (scale**2) + assert hh % scale == 0 and hw % scale == 0 + h = hh // scale + w = hw // scale + x_view = x.view(b, c, h, scale, w, scale) + return x_view.permute(0, 1, 3, 5, 2, 4).reshape(b, out_channel, h, w) + + +# class DCNv2Pack(ModulatedDeformConvPack): +# """Modulated deformable conv for deformable alignment. +# +# Different from the official DCNv2Pack, which generates offsets and masks +# from the preceding features, this DCNv2Pack takes another different +# features to generate offsets and masks. +# +# Ref: +# Delving Deep into Deformable Alignment in Video Super-Resolution. +# """ +# +# def forward(self, x, feat): +# out = self.conv_offset(feat) +# o1, o2, mask = torch.chunk(out, 3, dim=1) +# offset = torch.cat((o1, o2), dim=1) +# mask = torch.sigmoid(mask) +# +# offset_absmean = torch.mean(torch.abs(offset)) +# if offset_absmean > 50: +# logger = get_root_logger() +# logger.warning(f'Offset abs mean is {offset_absmean}, larger than 50.') +# +# if LooseVersion(torchvision.__version__) >= LooseVersion('0.9.0'): +# return torchvision.ops.deform_conv2d(x, offset, self.weight, self.bias, self.stride, self.padding, +# self.dilation, mask) +# else: +# return modulated_deform_conv(x, offset, mask, self.weight, self.bias, self.stride, self.padding, +# self.dilation, self.groups, self.deformable_groups) + + +def _no_grad_trunc_normal_(tensor, mean, std, a, b): + # From: https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/weight_init.py + # Cut & paste from PyTorch official master until it's in a few official releases - RW + # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf + def norm_cdf(x): + # Computes standard normal cumulative distribution function + return (1. + math.erf(x / math.sqrt(2.))) / 2. + + if (mean < a - 2 * std) or (mean > b + 2 * std): + warnings.warn( + 'mean is more than 2 std from [a, b] in nn.init.trunc_normal_. ' + 'The distribution of values may be incorrect.', + stacklevel=2) + + with torch.no_grad(): + # Values are generated by using a truncated uniform distribution and + # then using the inverse CDF for the normal distribution. + # Get upper and lower cdf values + low = norm_cdf((a - mean) / std) + up = norm_cdf((b - mean) / std) + + # Uniformly fill tensor with values from [low, up], then translate to + # [2l-1, 2u-1]. + tensor.uniform_(2 * low - 1, 2 * up - 1) + + # Use inverse cdf transform for normal distribution to get truncated + # standard normal + tensor.erfinv_() + + # Transform to proper mean, std + tensor.mul_(std * math.sqrt(2.)) + tensor.add_(mean) + + # Clamp to ensure it's in the proper range + tensor.clamp_(min=a, max=b) + return tensor + + +def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.): + r"""Fills the input Tensor with values drawn from a truncated + normal distribution. + + From: https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/weight_init.py + + The values are effectively drawn from the + normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)` + with values outside :math:`[a, b]` redrawn until they are within + the bounds. The method used for generating the random values works + best when :math:`a \leq \text{mean} \leq b`. + + Args: + tensor: an n-dimensional `torch.Tensor` + mean: the mean of the normal distribution + std: the standard deviation of the normal distribution + a: the minimum cutoff value + b: the maximum cutoff value + + Examples: + >>> w = torch.empty(3, 5) + >>> nn.init.trunc_normal_(w) + """ + return _no_grad_trunc_normal_(tensor, mean, std, a, b) + + +# From PyTorch +def _ntuple(n): + + def parse(x): + if isinstance(x, collections.abc.Iterable): + return x + return tuple(repeat(x, n)) + + return parse + + +to_1tuple = _ntuple(1) +to_2tuple = _ntuple(2) +to_3tuple = _ntuple(3) +to_4tuple = _ntuple(4) +to_ntuple = _ntuple diff --git a/opensora/models/super_resolution/basicsr/archs/rgt_arch.py b/opensora/models/super_resolution/basicsr/archs/rgt_arch.py new file mode 100644 index 0000000000000000000000000000000000000000..810fd765a8a59bd1ddf8ff010441b0cd7ee6126e --- /dev/null +++ b/opensora/models/super_resolution/basicsr/archs/rgt_arch.py @@ -0,0 +1,757 @@ +import torch +import torch.nn as nn +import torch.utils.checkpoint as checkpoint +from torch import Tensor +from torch.nn import functional as F + +from timm.models.layers import DropPath, trunc_normal_ +from einops.layers.torch import Rearrange +from einops import rearrange, repeat + +import math +import numpy as np + +import random + +from basicsr.utils.registry import ARCH_REGISTRY + + +def img2windows(img, H_sp, W_sp): + """ + Input: Image (B, C, H, W) + Output: Window Partition (B', N, C) + """ + B, C, H, W = img.shape + img_reshape = img.view(B, C, H // H_sp, H_sp, W // W_sp, W_sp) + img_perm = img_reshape.permute(0, 2, 4, 3, 5, 1).contiguous().reshape(-1, H_sp* W_sp, C) + return img_perm + + +def windows2img(img_splits_hw, H_sp, W_sp, H, W): + """ + Input: Window Partition (B', N, C) + Output: Image (B, H, W, C) + """ + B = int(img_splits_hw.shape[0] / (H * W / H_sp / W_sp)) + + img = img_splits_hw.view(B, H // H_sp, W // W_sp, H_sp, W_sp, -1) + img = img.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) + return img + + +class Gate(nn.Module): + def __init__(self, dim): + super().__init__() + self.norm = nn.LayerNorm(dim) + self.conv = nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1, groups=dim) # DW Conv + + def forward(self, x, H, W): + # Split + x1, x2 = x.chunk(2, dim = -1) + B, N, C = x.shape + x2 = self.conv(self.norm(x2).transpose(1, 2).contiguous().view(B, C//2, H, W)).flatten(2).transpose(-1, -2).contiguous() + + return x1 * x2 + + +class MLP(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.sg = Gate(hidden_features//2) + self.fc2 = nn.Linear(hidden_features//2, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x, H, W): + """ + Input: x: (B, H*W, C), H, W + Output: x: (B, H*W, C) + """ + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + + x = self.sg(x, H, W) + x = self.drop(x) + + x = self.fc2(x) + x = self.drop(x) + return x + + +class DynamicPosBias(nn.Module): + # The implementation builds on Crossformer code https://github.com/cheerss/CrossFormer/blob/main/models/crossformer.py + """ Dynamic Relative Position Bias. + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads. + residual (bool): If True, use residual strage to connect conv. + """ + def __init__(self, dim, num_heads, residual): + super().__init__() + self.residual = residual + self.num_heads = num_heads + self.pos_dim = dim // 4 + self.pos_proj = nn.Linear(2, self.pos_dim) + self.pos1 = nn.Sequential( + nn.LayerNorm(self.pos_dim), + nn.ReLU(inplace=True), + nn.Linear(self.pos_dim, self.pos_dim), + ) + self.pos2 = nn.Sequential( + nn.LayerNorm(self.pos_dim), + nn.ReLU(inplace=True), + nn.Linear(self.pos_dim, self.pos_dim) + ) + self.pos3 = nn.Sequential( + nn.LayerNorm(self.pos_dim), + nn.ReLU(inplace=True), + nn.Linear(self.pos_dim, self.num_heads) + ) + def forward(self, biases): + if self.residual: + pos = self.pos_proj(biases) # 2Gh-1 * 2Gw-1, heads + pos = pos + self.pos1(pos) + pos = pos + self.pos2(pos) + pos = self.pos3(pos) + else: + pos = self.pos3(self.pos2(self.pos1(self.pos_proj(biases)))) + return pos + + +class WindowAttention(nn.Module): + def __init__(self, dim, idx, split_size=[8,8], dim_out=None, num_heads=6, attn_drop=0., proj_drop=0., qk_scale=None, position_bias=True): + super().__init__() + self.dim = dim + self.dim_out = dim_out or dim + self.split_size = split_size + self.num_heads = num_heads + self.idx = idx + self.position_bias = position_bias + + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + + if idx == 0: + H_sp, W_sp = self.split_size[0], self.split_size[1] + elif idx == 1: + W_sp, H_sp = self.split_size[0], self.split_size[1] + else: + print ("ERROR MODE", idx) + exit(0) + self.H_sp = H_sp + self.W_sp = W_sp + + if self.position_bias: + self.pos = DynamicPosBias(self.dim // 4, self.num_heads, residual=False) + # generate mother-set + position_bias_h = torch.arange(1 - self.H_sp, self.H_sp) + position_bias_w = torch.arange(1 - self.W_sp, self.W_sp) + biases = torch.stack(torch.meshgrid([position_bias_h, position_bias_w])) + biases = biases.flatten(1).transpose(0, 1).contiguous().float() + self.register_buffer('rpe_biases', biases) + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(self.H_sp) + coords_w = torch.arange(self.W_sp) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) + coords_flatten = torch.flatten(coords, 1) + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] + relative_coords = relative_coords.permute(1, 2, 0).contiguous() + relative_coords[:, :, 0] += self.H_sp - 1 + relative_coords[:, :, 1] += self.W_sp - 1 + relative_coords[:, :, 0] *= 2 * self.W_sp - 1 + relative_position_index = relative_coords.sum(-1) + self.register_buffer('relative_position_index', relative_position_index) + + self.attn_drop = nn.Dropout(attn_drop) + + def im2win(self, x, H, W): + B, N, C = x.shape + x = x.transpose(-2,-1).contiguous().view(B, C, H, W) + x = img2windows(x, self.H_sp, self.W_sp) + x = x.reshape(-1, self.H_sp* self.W_sp, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3).contiguous() + return x + + def forward(self, qkv, H, W, mask=None): + """ + Input: qkv: (B, 3*L, C), H, W, mask: (B, N, N), N is the window size + Output: x (B, H, W, C) + """ + q,k,v = qkv[0], qkv[1], qkv[2] + + B, L, C = q.shape + assert L == H * W, "flatten img_tokens has wrong size" + + # partition the q,k,v, image to window + q = self.im2win(q, H, W) + k = self.im2win(k, H, W) + v = self.im2win(v, H, W) + + q = q * self.scale + attn = (q @ k.transpose(-2, -1)) # B head N C @ B head C N --> B head N N + + # calculate drpe + if self.position_bias: + pos = self.pos(self.rpe_biases) + # select position bias + relative_position_bias = pos[self.relative_position_index.view(-1)].view( + self.H_sp * self.W_sp, self.H_sp * self.W_sp, -1) + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() + attn = attn + relative_position_bias.unsqueeze(0) + + N = attn.shape[3] + + # use mask for shift window + if mask is not None: + nW = mask.shape[0] + attn = attn.view(B, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) + attn = attn.view(-1, self.num_heads, N, N) + + attn = nn.functional.softmax(attn, dim=-1, dtype=attn.dtype) + attn = self.attn_drop(attn) + + x = (attn @ v) + x = x.transpose(1, 2).reshape(-1, self.H_sp* self.W_sp, C) # B head N N @ B head N C + + # merge the window, window to image + x = windows2img(x, self.H_sp, self.W_sp, H, W) # B H' W' C + + return x + + +class L_SA(nn.Module): + # The implementation builds on CAT code https://github.com/zhengchen1999/CAT/blob/main/basicsr/archs/cat_arch.py + def __init__(self, dim, num_heads, + split_size=[2,4], shift_size=[1,2], qkv_bias=False, qk_scale=None, + drop=0., attn_drop=0., idx=0, reso=64, rs_id=0): + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.split_size = split_size + self.shift_size = shift_size + self.idx = idx + self.rs_id = rs_id + self.patches_resolution = reso + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + + assert 0 <= self.shift_size[0] < self.split_size[0], "shift_size must in 0-split_size0" + assert 0 <= self.shift_size[1] < self.split_size[1], "shift_size must in 0-split_size1" + + self.branch_num = 2 + + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(drop) + + self.attns = nn.ModuleList([ + WindowAttention( + dim//2, idx = i, + split_size=split_size, num_heads=num_heads//2, dim_out=dim//2, + qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop, position_bias=True) + for i in range(self.branch_num)]) + + if (self.rs_id % 2 == 0 and self.idx > 0 and (self.idx - 2) % 4 == 0) or (self.rs_id % 2 != 0 and self.idx % 4 == 0): + attn_mask = self.calculate_mask(self.patches_resolution, self.patches_resolution) + + self.register_buffer("attn_mask_0", attn_mask[0]) + self.register_buffer("attn_mask_1", attn_mask[1]) + else: + attn_mask = None + + self.register_buffer("attn_mask_0", None) + self.register_buffer("attn_mask_1", None) + + self.get_v = nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1,groups=dim) # DW Conv + + def calculate_mask(self, H, W): + # The implementation builds on Swin Transformer code https://github.com/microsoft/Swin-Transformer/blob/main/models/swin_transformer.py + # calculate attention mask for Rwin + img_mask_0 = torch.zeros((1, H, W, 1)) # 1 H W 1 idx=0 + img_mask_1 = torch.zeros((1, H, W, 1)) # 1 H W 1 idx=1 + h_slices_0 = (slice(0, -self.split_size[0]), + slice(-self.split_size[0], -self.shift_size[0]), + slice(-self.shift_size[0], None)) + w_slices_0 = (slice(0, -self.split_size[1]), + slice(-self.split_size[1], -self.shift_size[1]), + slice(-self.shift_size[1], None)) + + h_slices_1 = (slice(0, -self.split_size[1]), + slice(-self.split_size[1], -self.shift_size[1]), + slice(-self.shift_size[1], None)) + w_slices_1 = (slice(0, -self.split_size[0]), + slice(-self.split_size[0], -self.shift_size[0]), + slice(-self.shift_size[0], None)) + cnt = 0 + for h in h_slices_0: + for w in w_slices_0: + img_mask_0[:, h, w, :] = cnt + cnt += 1 + cnt = 0 + for h in h_slices_1: + for w in w_slices_1: + img_mask_1[:, h, w, :] = cnt + cnt += 1 + + # calculate mask for H-Shift + img_mask_0 = img_mask_0.view(1, H // self.split_size[0], self.split_size[0], W // self.split_size[1], self.split_size[1], 1) + img_mask_0 = img_mask_0.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, self.split_size[0], self.split_size[1], 1) # nW, sw[0], sw[1], 1 + mask_windows_0 = img_mask_0.view(-1, self.split_size[0] * self.split_size[1]) + attn_mask_0 = mask_windows_0.unsqueeze(1) - mask_windows_0.unsqueeze(2) + attn_mask_0 = attn_mask_0.masked_fill(attn_mask_0 != 0, float(-100.0)).masked_fill(attn_mask_0 == 0, float(0.0)) + + # calculate mask for V-Shift + img_mask_1 = img_mask_1.view(1, H // self.split_size[1], self.split_size[1], W // self.split_size[0], self.split_size[0], 1) + img_mask_1 = img_mask_1.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, self.split_size[1], self.split_size[0], 1) # nW, sw[1], sw[0], 1 + mask_windows_1 = img_mask_1.view(-1, self.split_size[1] * self.split_size[0]) + attn_mask_1 = mask_windows_1.unsqueeze(1) - mask_windows_1.unsqueeze(2) + attn_mask_1 = attn_mask_1.masked_fill(attn_mask_1 != 0, float(-100.0)).masked_fill(attn_mask_1 == 0, float(0.0)) + + return attn_mask_0, attn_mask_1 + + def forward(self, x, H, W): + """ + Input: x: (B, H*W, C), x_size: (H, W) + Output: x: (B, H*W, C) + """ + + B, L, C = x.shape + assert L == H * W, "flatten img_tokens has wrong size" + + qkv = self.qkv(x).reshape(B, -1, 3, C).permute(2, 0, 1, 3) # 3, B, HW, C + # v without partition + v = qkv[2].transpose(-2,-1).contiguous().view(B, C, H, W) + + + max_split_size = max(self.split_size[0], self.split_size[1]) + pad_l = pad_t = 0 + pad_r = (max_split_size - W % max_split_size) % max_split_size + pad_b = (max_split_size - H % max_split_size) % max_split_size + + qkv = qkv.reshape(3*B, H, W, C).permute(0, 3, 1, 2) # 3B C H W + qkv = F.pad(qkv, (pad_l, pad_r, pad_t, pad_b)).reshape(3, B, C, -1).transpose(-2, -1) # l r t b + _H = pad_b + H + _W = pad_r + W + _L = _H * _W + + if (self.rs_id % 2 == 0 and self.idx > 0 and (self.idx - 2) % 4 == 0) or (self.rs_id % 2 != 0 and self.idx % 4 == 0): + qkv = qkv.view(3, B, _H, _W, C) + # H-Shift + qkv_0 = torch.roll(qkv[:,:,:,:,:C//2], shifts=(-self.shift_size[0], -self.shift_size[1]), dims=(2, 3)) + qkv_0 = qkv_0.view(3, B, _L, C//2) + # V-Shift + qkv_1 = torch.roll(qkv[:,:,:,:,C//2:], shifts=(-self.shift_size[1], -self.shift_size[0]), dims=(2, 3)) + qkv_1 = qkv_1.view(3, B, _L, C//2) + + if self.patches_resolution != _H or self.patches_resolution != _W: + mask_tmp = self.calculate_mask(_H, _W) + # H-Rwin + x1_shift = self.attns[0](qkv_0, _H, _W, mask=mask_tmp[0].to(x.device)) + # V-Rwin + x2_shift = self.attns[1](qkv_1, _H, _W, mask=mask_tmp[1].to(x.device)) + + else: + # H-Rwin + x1_shift = self.attns[0](qkv_0, _H, _W, mask=self.attn_mask_0) + # V-Rwin + x2_shift = self.attns[1](qkv_1, _H, _W, mask=self.attn_mask_1) + + x1 = torch.roll(x1_shift, shifts=(self.shift_size[0], self.shift_size[1]), dims=(1, 2)) + x2 = torch.roll(x2_shift, shifts=(self.shift_size[1], self.shift_size[0]), dims=(1, 2)) + x1 = x1[:, :H, :W, :].reshape(B, L, C//2) + x2 = x2[:, :H, :W, :].reshape(B, L, C//2) + # Concat + attened_x = torch.cat([x1,x2], dim=2) + else: + # V-Rwin + x1 = self.attns[0](qkv[:,:,:,:C//2], _H, _W)[:, :H, :W, :].reshape(B, L, C//2) + # H-Rwin + x2 = self.attns[1](qkv[:,:,:,C//2:], _H, _W)[:, :H, :W, :].reshape(B, L, C//2) + # Concat + attened_x = torch.cat([x1,x2], dim=2) + + # mix + lcm = self.get_v(v) + lcm = lcm.permute(0, 2, 3, 1).contiguous().view(B, L, C) + + x = attened_x + lcm + + x = self.proj(x) + x = self.proj_drop(x) + + return x + + +class RG_SA(nn.Module): + """ + Recursive-Generalization Self-Attention (RG-SA). + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set + attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 + proj_drop (float, optional): Dropout ratio of output. Default: 0.0 + c_ratio (float): channel adjustment factor. + """ + def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., c_ratio=0.5): + super(RG_SA, self).__init__() + assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}." + self.num_heads = num_heads + head_dim = dim // num_heads + + self.cr = int(dim * c_ratio) # scaled channel dimension + + # self.scale = qk_scale or head_dim ** -0.5 + self.scale = qk_scale or (head_dim * c_ratio) ** -0.5 + + # RGM + self.reduction1 = nn.Conv2d(dim, dim, kernel_size=4, stride=4, groups=dim) + self.dwconv = nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1,groups=dim) + self.conv = nn.Conv2d(dim, self.cr, kernel_size=1, stride=1) + self.norm_act = nn.Sequential( + nn.LayerNorm(self.cr), + nn.GELU()) + # CA + self.q = nn.Linear(dim, self.cr, bias=qkv_bias) + self.k = nn.Linear(self.cr, self.cr, bias=qkv_bias) + self.v = nn.Linear(self.cr, dim, bias=qkv_bias) + + # CPE + self.cpe = nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1, groups=dim) + + self.proj = nn.Linear(dim, dim) + self.attn_drop = nn.Dropout(attn_drop) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x, H, W): + B, N, C = x.shape + + _scale = 1 + + # reduction + _x = x.permute(0, 2, 1).reshape(B, C, H, W).contiguous() + + if self.training: + _time = max(int(math.log(H//4, 4)), int(math.log(W//4, 4))) + else: + _time = max(int(math.log(H//16, 4)), int(math.log(W//16, 4))) + if _time < 2: _time = 2 # testing _time must equal or larger than training _time (2) + + _scale = 4 ** _time + + # Recursion xT + for _ in range(_time): + _x = self.reduction1(_x) + + _x = self.conv(self.dwconv(_x)).reshape(B, self.cr, -1).permute(0, 2, 1).contiguous() # shape=(B, N', C') + _x = self.norm_act(_x) + + # q, k, v, where q_shape=(B, N, C'), k_shape=(B, N', C'), v_shape=(B, N', C) + q = self.q(x).reshape(B, N, self.num_heads, int(self.cr / self.num_heads)).permute(0, 2, 1, 3) + k = self.k(_x).reshape(B, -1, self.num_heads, int(self.cr / self.num_heads)).permute(0, 2, 1, 3) + v = self.v(_x).reshape(B, -1, self.num_heads, int(C / self.num_heads)).permute(0, 2, 1, 3) + + # corss-attention + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + # CPE + # v_shape=(B, H, N', C//H) + v = v + self.cpe(v.transpose(1, 2).reshape(B, -1, C).transpose(1, 2).contiguous().view(B, C, H // _scale, W // _scale)).view(B, C, -1).view(B, self.num_heads, int(C / self.num_heads), -1).transpose(-1, -2) + + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + + return x + + +class Block(nn.Module): + def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., + attn_drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, idx=0, + rs_id=0, split_size=[2,4], shift_size=[1,2], reso=64, c_ratio=0.5, layerscale_value=1e-4): + super().__init__() + self.norm1 = norm_layer(dim) + if idx % 2 == 0: + self.attn = L_SA( + dim, split_size=split_size, shift_size=shift_size, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, + drop=drop, idx=idx, reso=reso, rs_id=rs_id + ) + else: + self.attn = RG_SA( + dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, + proj_drop=drop, c_ratio=c_ratio + ) + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = MLP(in_features=dim, hidden_features=mlp_hidden_dim, out_features=dim, act_layer=act_layer) + self.norm2 = norm_layer(dim) + + # HAI + self.gamma = nn.Parameter(layerscale_value * torch.ones((dim)), requires_grad=True) + + def forward(self, x, x_size): + H , W = x_size + + res = x + + x = x + self.drop_path(self.attn(self.norm1(x), H, W)) + x = x + self.drop_path(self.mlp(self.norm2(x), H, W)) + + # HAI + x = x + (res * self.gamma) + + return x + + +class ResidualGroup(nn.Module): + + def __init__( self, + dim, + reso, + num_heads, + mlp_ratio=4., + qkv_bias=False, + qk_scale=None, + drop=0., + attn_drop=0., + drop_paths=None, + act_layer=nn.GELU, + norm_layer=nn.LayerNorm, + depth=2, + use_chk=False, + resi_connection='1conv', + rs_id=0, + split_size=[8,8], + c_ratio = 0.5): + super().__init__() + self.use_chk = use_chk + self.reso = reso + + self.blocks = nn.ModuleList([ + Block( + dim=dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop, + attn_drop=attn_drop, + drop_path=drop_paths[i], + act_layer=act_layer, + norm_layer=norm_layer, + idx = i, + rs_id = rs_id, + split_size = split_size, + shift_size = [split_size[0]//2, split_size[1]//2], + c_ratio = c_ratio, + )for i in range(depth)]) + + + if resi_connection == '1conv': + self.conv = nn.Conv2d(dim, dim, 3, 1, 1) + elif resi_connection == '3conv': + self.conv = nn.Sequential( + nn.Conv2d(dim, dim // 4, 3, 1, 1), nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Conv2d(dim // 4, dim // 4, 1, 1, 0), nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Conv2d(dim // 4, dim, 3, 1, 1)) + + def forward(self, x, x_size): + """ + Input: x: (B, H*W, C), x_size: (H, W) + Output: x: (B, H*W, C) + """ + H, W = x_size + res = x + for blk in self.blocks: + if self.use_chk: + x = checkpoint.checkpoint(blk, x, x_size) + else: + x = blk(x, x_size) + x = rearrange(x, "b (h w) c -> b c h w", h=H, w=W) + x = self.conv(x) + x = rearrange(x, "b c h w -> b (h w) c") + x = res + x + + return x + + +class Upsample(nn.Sequential): + """Upsample module. + Args: + scale (int): Scale factor. Supported scales: 2^n and 3. + num_feat (int): Channel number of intermediate features. + """ + def __init__(self, scale, num_feat): + m = [] + if (scale & (scale - 1)) == 0: # scale = 2^n + for _ in range(int(math.log(scale, 2))): + m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1)) + m.append(nn.PixelShuffle(2)) + elif scale == 3: + m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1)) + m.append(nn.PixelShuffle(3)) + else: + raise ValueError(f'scale {scale} is not supported. ' 'Supported scales: 2^n and 3.') + super(Upsample, self).__init__(*m) + + +@ARCH_REGISTRY.register() +class RGT(nn.Module): + + def __init__(self, + img_size=64, + in_chans=3, + embed_dim=180, + depth=[2,2,2,2], + num_heads=[2,2,2,2], + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0.1, + act_layer=nn.GELU, + norm_layer=nn.LayerNorm, + use_chk=False, + upscale=2, + img_range=1., + resi_connection='1conv', + split_size=[8,8], + c_ratio=0.5, + **kwargs): + super().__init__() + + num_in_ch = in_chans + num_out_ch = in_chans + num_feat = 64 + self.img_range = img_range + if in_chans == 3: + rgb_mean = (0.4488, 0.4371, 0.4040) + self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1) + else: + self.mean = torch.zeros(1, 1, 1, 1) + self.upscale = upscale + + # ------------------------- 1, Shallow Feature Extraction ------------------------- # + self.conv_first = nn.Conv2d(num_in_ch, embed_dim, 3, 1, 1) + + # ------------------------- 2, Deep Feature Extraction ------------------------- # + self.num_layers = len(depth) + self.use_chk = use_chk + self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models + heads=num_heads + + self.before_RG = nn.Sequential( + Rearrange('b c h w -> b (h w) c'), + nn.LayerNorm(embed_dim) + ) + + curr_dim = embed_dim + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, np.sum(depth))] # stochastic depth decay rule + + self.layers = nn.ModuleList() + for i in range(self.num_layers): + layer = ResidualGroup( + dim=embed_dim, + num_heads=heads[i], + reso=img_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_paths=dpr[sum(depth[:i]):sum(depth[:i + 1])], + act_layer=act_layer, + norm_layer=norm_layer, + depth=depth[i], + use_chk=use_chk, + resi_connection=resi_connection, + rs_id=i, + split_size = split_size, + c_ratio = c_ratio + ) + self.layers.append(layer) + + self.norm = norm_layer(curr_dim) + # build the last conv layer in deep feature extraction + if resi_connection == '1conv': + self.conv_after_body = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1) + elif resi_connection == '3conv': + # to save parameters and memory + self.conv_after_body = nn.Sequential( + nn.Conv2d(embed_dim, embed_dim // 4, 3, 1, 1), nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Conv2d(embed_dim // 4, embed_dim // 4, 1, 1, 0), nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Conv2d(embed_dim // 4, embed_dim, 3, 1, 1)) + + # ------------------------- 3, Reconstruction ------------------------- # + self.conv_before_upsample = nn.Sequential( + nn.Conv2d(embed_dim, num_feat, 3, 1, 1), nn.LeakyReLU(inplace=True)) + self.upsample = Upsample(upscale, num_feat) + self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, (nn.LayerNorm, nn.BatchNorm2d, nn.GroupNorm, nn.InstanceNorm2d)): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + def forward_features(self, x): + _, _, H, W = x.shape + x_size = [H, W] + x = self.before_RG(x) + for layer in self.layers: + x = layer(x, x_size) + x = self.norm(x) + x = rearrange(x, "b (h w) c -> b c h w", h=H, w=W) + + return x + + def forward(self, x): + """ + Input: x: (B, C, H, W) + """ + self.mean = self.mean.type_as(x) + x = (x - self.mean) * self.img_range + + x = self.conv_first(x) + x = self.conv_after_body(self.forward_features(x)) + x + x = self.conv_before_upsample(x) + x = self.conv_last(self.upsample(x)) + + x = x / self.img_range + self.mean + return x + + +if __name__ == '__main__': + upscale = 1 + height = 62 + width = 66 + model = RGT( + upscale=2, + in_chans=3, + img_size=64, + img_range=1., + depth=[6,6,6,6,6,6], + embed_dim=180, + num_heads=[6,6,6,6,6,6], + mlp_ratio=2, + resi_connection='1conv', + split_size=[8, 8], + upsampler='pixelshuffle').cuda() + # print(model) + print(height, width) + + x = torch.randn((1, 3, height, width)).cuda() + x = model(x) + print(x.shape) \ No newline at end of file diff --git a/opensora/models/super_resolution/basicsr/archs/vgg_arch.py b/opensora/models/super_resolution/basicsr/archs/vgg_arch.py new file mode 100644 index 0000000000000000000000000000000000000000..1e44c0a11c355fde5847ed9b42dae2ad7e42ea9d --- /dev/null +++ b/opensora/models/super_resolution/basicsr/archs/vgg_arch.py @@ -0,0 +1,161 @@ +import os +import torch +from collections import OrderedDict +from torch import nn as nn +from torchvision.models import vgg as vgg + +from basicsr.utils.registry import ARCH_REGISTRY + +VGG_PRETRAIN_PATH = 'experiments/pretrained_models/vgg19-dcbb9e9d.pth' +NAMES = { + 'vgg11': [ + 'conv1_1', 'relu1_1', 'pool1', 'conv2_1', 'relu2_1', 'pool2', 'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', + 'pool3', 'conv4_1', 'relu4_1', 'conv4_2', 'relu4_2', 'pool4', 'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2', + 'pool5' + ], + 'vgg13': [ + 'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1', 'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2', + 'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'pool3', 'conv4_1', 'relu4_1', 'conv4_2', 'relu4_2', 'pool4', + 'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2', 'pool5' + ], + 'vgg16': [ + 'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1', 'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2', + 'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'conv3_3', 'relu3_3', 'pool3', 'conv4_1', 'relu4_1', 'conv4_2', + 'relu4_2', 'conv4_3', 'relu4_3', 'pool4', 'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2', 'conv5_3', 'relu5_3', + 'pool5' + ], + 'vgg19': [ + 'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1', 'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2', + 'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'conv3_3', 'relu3_3', 'conv3_4', 'relu3_4', 'pool3', 'conv4_1', + 'relu4_1', 'conv4_2', 'relu4_2', 'conv4_3', 'relu4_3', 'conv4_4', 'relu4_4', 'pool4', 'conv5_1', 'relu5_1', + 'conv5_2', 'relu5_2', 'conv5_3', 'relu5_3', 'conv5_4', 'relu5_4', 'pool5' + ] +} + + +def insert_bn(names): + """Insert bn layer after each conv. + + Args: + names (list): The list of layer names. + + Returns: + list: The list of layer names with bn layers. + """ + names_bn = [] + for name in names: + names_bn.append(name) + if 'conv' in name: + position = name.replace('conv', '') + names_bn.append('bn' + position) + return names_bn + + +@ARCH_REGISTRY.register() +class VGGFeatureExtractor(nn.Module): + """VGG network for feature extraction. + + In this implementation, we allow users to choose whether use normalization + in the input feature and the type of vgg network. Note that the pretrained + path must fit the vgg type. + + Args: + layer_name_list (list[str]): Forward function returns the corresponding + features according to the layer_name_list. + Example: {'relu1_1', 'relu2_1', 'relu3_1'}. + vgg_type (str): Set the type of vgg network. Default: 'vgg19'. + use_input_norm (bool): If True, normalize the input image. Importantly, + the input feature must in the range [0, 1]. Default: True. + range_norm (bool): If True, norm images with range [-1, 1] to [0, 1]. + Default: False. + requires_grad (bool): If true, the parameters of VGG network will be + optimized. Default: False. + remove_pooling (bool): If true, the max pooling operations in VGG net + will be removed. Default: False. + pooling_stride (int): The stride of max pooling operation. Default: 2. + """ + + def __init__(self, + layer_name_list, + vgg_type='vgg19', + use_input_norm=True, + range_norm=False, + requires_grad=False, + remove_pooling=False, + pooling_stride=2): + super(VGGFeatureExtractor, self).__init__() + + self.layer_name_list = layer_name_list + self.use_input_norm = use_input_norm + self.range_norm = range_norm + + self.names = NAMES[vgg_type.replace('_bn', '')] + if 'bn' in vgg_type: + self.names = insert_bn(self.names) + + # only borrow layers that will be used to avoid unused params + max_idx = 0 + for v in layer_name_list: + idx = self.names.index(v) + if idx > max_idx: + max_idx = idx + + if os.path.exists(VGG_PRETRAIN_PATH): + vgg_net = getattr(vgg, vgg_type)(pretrained=False) + state_dict = torch.load(VGG_PRETRAIN_PATH, map_location=lambda storage, loc: storage) + vgg_net.load_state_dict(state_dict) + else: + vgg_net = getattr(vgg, vgg_type)(pretrained=True) + + features = vgg_net.features[:max_idx + 1] + + modified_net = OrderedDict() + for k, v in zip(self.names, features): + if 'pool' in k: + # if remove_pooling is true, pooling operation will be removed + if remove_pooling: + continue + else: + # in some cases, we may want to change the default stride + modified_net[k] = nn.MaxPool2d(kernel_size=2, stride=pooling_stride) + else: + modified_net[k] = v + + self.vgg_net = nn.Sequential(modified_net) + + if not requires_grad: + self.vgg_net.eval() + for param in self.parameters(): + param.requires_grad = False + else: + self.vgg_net.train() + for param in self.parameters(): + param.requires_grad = True + + if self.use_input_norm: + # the mean is for image with range [0, 1] + self.register_buffer('mean', torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)) + # the std is for image with range [0, 1] + self.register_buffer('std', torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)) + + def forward(self, x): + """Forward function. + + Args: + x (Tensor): Input tensor with shape (n, c, h, w). + + Returns: + Tensor: Forward results. + """ + if self.range_norm: + x = (x + 1) / 2 + if self.use_input_norm: + x = (x - self.mean) / self.std + + output = {} + for key, layer in self.vgg_net._modules.items(): + x = layer(x) + if key in self.layer_name_list: + output[key] = x.clone() + + return output \ No newline at end of file diff --git a/opensora/models/super_resolution/basicsr/data/__init__.py b/opensora/models/super_resolution/basicsr/data/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..12a7ea0538bb3ff33755a7440bb7aa963c67efdf --- /dev/null +++ b/opensora/models/super_resolution/basicsr/data/__init__.py @@ -0,0 +1,101 @@ +import importlib +import numpy as np +import random +import torch +import torch.utils.data +from copy import deepcopy +from functools import partial +from os import path as osp + +from basicsr.data.prefetch_dataloader import PrefetchDataLoader +from basicsr.utils import get_root_logger, scandir +from basicsr.utils.dist_util import get_dist_info +from basicsr.utils.registry import DATASET_REGISTRY + +__all__ = ['build_dataset', 'build_dataloader'] + +# automatically scan and import dataset modules for registry +# scan all the files under the data folder with '_dataset' in file names +data_folder = osp.dirname(osp.abspath(__file__)) +dataset_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(data_folder) if v.endswith('_dataset.py')] +# import all the dataset modules +_dataset_modules = [importlib.import_module(f'basicsr.data.{file_name}') for file_name in dataset_filenames] + + +def build_dataset(dataset_opt): + """Build dataset from options. + + Args: + dataset_opt (dict): Configuration for dataset. It must contain: + name (str): Dataset name. + type (str): Dataset type. + """ + dataset_opt = deepcopy(dataset_opt) + dataset = DATASET_REGISTRY.get(dataset_opt['type'])(dataset_opt) + # logger = get_root_logger() + # logger.info(f'Dataset [{dataset.__class__.__name__}] - {dataset_opt["name"]} is built.') + return dataset + + +def build_dataloader(dataset, dataset_opt, num_gpu=1, dist=False, sampler=None, seed=None): + """Build dataloader. + + Args: + dataset (torch.utils.data.Dataset): Dataset. + dataset_opt (dict): Dataset options. It contains the following keys: + phase (str): 'train' or 'val'. + num_worker_per_gpu (int): Number of workers for each GPU. + batch_size_per_gpu (int): Training batch size for each GPU. + num_gpu (int): Number of GPUs. Used only in the train phase. + Default: 1. + dist (bool): Whether in distributed training. Used only in the train + phase. Default: False. + sampler (torch.utils.data.sampler): Data sampler. Default: None. + seed (int | None): Seed. Default: None + """ + phase = dataset_opt['phase'] + rank, _ = get_dist_info() + if phase == 'train': + if dist: # distributed training + batch_size = dataset_opt['batch_size_per_gpu'] + num_workers = dataset_opt['num_worker_per_gpu'] + else: # non-distributed training + multiplier = 1 if num_gpu == 0 else num_gpu + batch_size = dataset_opt['batch_size_per_gpu'] * multiplier + num_workers = dataset_opt['num_worker_per_gpu'] * multiplier + dataloader_args = dict( + dataset=dataset, + batch_size=batch_size, + shuffle=False, + num_workers=num_workers, + sampler=sampler, + drop_last=True) + if sampler is None: + dataloader_args['shuffle'] = True + dataloader_args['worker_init_fn'] = partial( + worker_init_fn, num_workers=num_workers, rank=rank, seed=seed) if seed is not None else None + elif phase in ['val', 'test']: # validation + dataloader_args = dict(dataset=dataset, batch_size=1, shuffle=False, num_workers=0) + else: + raise ValueError(f"Wrong dataset phase: {phase}. Supported ones are 'train', 'val' and 'test'.") + + dataloader_args['pin_memory'] = dataset_opt.get('pin_memory', False) + dataloader_args['persistent_workers'] = dataset_opt.get('persistent_workers', False) + + prefetch_mode = dataset_opt.get('prefetch_mode') + if prefetch_mode == 'cpu': # CPUPrefetcher + num_prefetch_queue = dataset_opt.get('num_prefetch_queue', 1) + logger = get_root_logger() + logger.info(f'Use {prefetch_mode} prefetch dataloader: num_prefetch_queue = {num_prefetch_queue}') + return PrefetchDataLoader(num_prefetch_queue=num_prefetch_queue, **dataloader_args) + else: + # prefetch_mode=None: Normal dataloader + # prefetch_mode='cuda': dataloader for CUDAPrefetcher + return torch.utils.data.DataLoader(**dataloader_args) + + +def worker_init_fn(worker_id, num_workers, rank, seed): + # Set the worker seed to num_workers * rank + worker_id + seed + worker_seed = num_workers * rank + worker_id + seed + np.random.seed(worker_seed) + random.seed(worker_seed) diff --git a/opensora/models/super_resolution/basicsr/data/data_sampler.py b/opensora/models/super_resolution/basicsr/data/data_sampler.py new file mode 100644 index 0000000000000000000000000000000000000000..575452d9f844a928f7f42296c81635cfbadec7c2 --- /dev/null +++ b/opensora/models/super_resolution/basicsr/data/data_sampler.py @@ -0,0 +1,48 @@ +import math +import torch +from torch.utils.data.sampler import Sampler + + +class EnlargedSampler(Sampler): + """Sampler that restricts data loading to a subset of the dataset. + + Modified from torch.utils.data.distributed.DistributedSampler + Support enlarging the dataset for iteration-based training, for saving + time when restart the dataloader after each epoch + + Args: + dataset (torch.utils.data.Dataset): Dataset used for sampling. + num_replicas (int | None): Number of processes participating in + the training. It is usually the world_size. + rank (int | None): Rank of the current process within num_replicas. + ratio (int): Enlarging ratio. Default: 1. + """ + + def __init__(self, dataset, num_replicas, rank, ratio=1): + self.dataset = dataset + self.num_replicas = num_replicas + self.rank = rank + self.epoch = 0 + self.num_samples = math.ceil(len(self.dataset) * ratio / self.num_replicas) + self.total_size = self.num_samples * self.num_replicas + + def __iter__(self): + # deterministically shuffle based on epoch + g = torch.Generator() + g.manual_seed(self.epoch) + indices = torch.randperm(self.total_size, generator=g).tolist() + + dataset_size = len(self.dataset) + indices = [v % dataset_size for v in indices] + + # subsample + indices = indices[self.rank:self.total_size:self.num_replicas] + assert len(indices) == self.num_samples + + return iter(indices) + + def __len__(self): + return self.num_samples + + def set_epoch(self, epoch): + self.epoch = epoch diff --git a/opensora/models/super_resolution/basicsr/data/data_util.py b/opensora/models/super_resolution/basicsr/data/data_util.py new file mode 100644 index 0000000000000000000000000000000000000000..de4811dad4ad295be6c992c997412c46920d6583 --- /dev/null +++ b/opensora/models/super_resolution/basicsr/data/data_util.py @@ -0,0 +1,280 @@ +import cv2 +import numpy as np +import torch +from os import path as osp +from torch.nn import functional as F + +from basicsr.utils import img2tensor, scandir + + +def generate_frame_indices(crt_idx, max_frame_num, num_frames, padding='reflection'): + """Generate an index list for reading `num_frames` frames from a sequence + of images. + + Args: + crt_idx (int): Current center index. + max_frame_num (int): Max number of the sequence of images (from 1). + num_frames (int): Reading num_frames frames. + padding (str): Padding mode, one of + 'replicate' | 'reflection' | 'reflection_circle' | 'circle' + Examples: current_idx = 0, num_frames = 5 + The generated frame indices under different padding mode: + replicate: [0, 0, 0, 1, 2] + reflection: [2, 1, 0, 1, 2] + reflection_circle: [4, 3, 0, 1, 2] + circle: [3, 4, 0, 1, 2] + + Returns: + list[int]: A list of indices. + """ + assert num_frames % 2 == 1, 'num_frames should be an odd number.' + assert padding in ('replicate', 'reflection', 'reflection_circle', 'circle'), f'Wrong padding mode: {padding}.' + + max_frame_num = max_frame_num - 1 # start from 0 + num_pad = num_frames // 2 + + indices = [] + for i in range(crt_idx - num_pad, crt_idx + num_pad + 1): + if i < 0: + if padding == 'replicate': + pad_idx = 0 + elif padding == 'reflection': + pad_idx = -i + elif padding == 'reflection_circle': + pad_idx = crt_idx + num_pad - i + else: + pad_idx = num_frames + i + elif i > max_frame_num: + if padding == 'replicate': + pad_idx = max_frame_num + elif padding == 'reflection': + pad_idx = max_frame_num * 2 - i + elif padding == 'reflection_circle': + pad_idx = (crt_idx - num_pad) - (i - max_frame_num) + else: + pad_idx = i - num_frames + else: + pad_idx = i + indices.append(pad_idx) + return indices + + +def paired_paths_from_lmdb(folders, keys): + """Generate paired paths from lmdb files. + + Contents of lmdb. Taking the `lq.lmdb` for example, the file structure is: + + lq.lmdb + ├── data.mdb + ├── lock.mdb + ├── meta_info.txt + + The data.mdb and lock.mdb are standard lmdb files and you can refer to + https://lmdb.readthedocs.io/en/release/ for more details. + + The meta_info.txt is a specified txt file to record the meta information + of our datasets. It will be automatically created when preparing + datasets by our provided dataset tools. + Each line in the txt file records + 1)image name (with extension), + 2)image shape, + 3)compression level, separated by a white space. + Example: `baboon.png (120,125,3) 1` + + We use the image name without extension as the lmdb key. + Note that we use the same key for the corresponding lq and gt images. + + Args: + folders (list[str]): A list of folder path. The order of list should + be [input_folder, gt_folder]. + keys (list[str]): A list of keys identifying folders. The order should + be in consistent with folders, e.g., ['lq', 'gt']. + Note that this key is different from lmdb keys. + + Returns: + list[str]: Returned path list. + """ + assert len(folders) == 2, ('The len of folders should be 2 with [input_folder, gt_folder]. ' + f'But got {len(folders)}') + assert len(keys) == 2, f'The len of keys should be 2 with [input_key, gt_key]. But got {len(keys)}' + input_folder, gt_folder = folders + input_key, gt_key = keys + + if not (input_folder.endswith('.lmdb') and gt_folder.endswith('.lmdb')): + raise ValueError(f'{input_key} folder and {gt_key} folder should both in lmdb ' + f'formats. But received {input_key}: {input_folder}; ' + f'{gt_key}: {gt_folder}') + # ensure that the two meta_info files are the same + with open(osp.join(input_folder, 'meta_info.txt')) as fin: + input_lmdb_keys = [line.split('.')[0] for line in fin] + with open(osp.join(gt_folder, 'meta_info.txt')) as fin: + gt_lmdb_keys = [line.split('.')[0] for line in fin] + if set(input_lmdb_keys) != set(gt_lmdb_keys): + raise ValueError(f'Keys in {input_key}_folder and {gt_key}_folder are different.') + else: + paths = [] + for lmdb_key in sorted(input_lmdb_keys): + paths.append(dict([(f'{input_key}_path', lmdb_key), (f'{gt_key}_path', lmdb_key)])) + return paths + + +def paired_paths_from_meta_info_file(folders, keys, meta_info_file, filename_tmpl): + """Generate paired paths from an meta information file. + + Each line in the meta information file contains the image names and + image shape (usually for gt), separated by a white space. + + Example of an meta information file: + ``` + 0001_s001.png (480,480,3) + 0001_s002.png (480,480,3) + ``` + + Args: + folders (list[str]): A list of folder path. The order of list should + be [input_folder, gt_folder]. + keys (list[str]): A list of keys identifying folders. The order should + be in consistent with folders, e.g., ['lq', 'gt']. + meta_info_file (str): Path to the meta information file. + filename_tmpl (str): Template for each filename. Note that the + template excludes the file extension. Usually the filename_tmpl is + for files in the input folder. + + Returns: + list[str]: Returned path list. + """ + assert len(folders) == 2, ('The len of folders should be 2 with [input_folder, gt_folder]. ' + f'But got {len(folders)}') + assert len(keys) == 2, f'The len of keys should be 2 with [input_key, gt_key]. But got {len(keys)}' + input_folder, gt_folder = folders + input_key, gt_key = keys + + with open(meta_info_file, 'r') as fin: + gt_names = [line.strip().split(' ')[0] for line in fin] + + paths = [] + for gt_name in gt_names: + basename, ext = osp.splitext(osp.basename(gt_name)) + input_name = f'{filename_tmpl.format(basename)}{ext}' + input_path = osp.join(input_folder, input_name) + gt_path = osp.join(gt_folder, gt_name) + paths.append(dict([(f'{input_key}_path', input_path), (f'{gt_key}_path', gt_path)])) + return paths + + +def paired_paths_from_folder(folders, keys, filename_tmpl): + """Generate paired paths from folders. + + Args: + folders (list[str]): A list of folder path. The order of list should + be [input_folder, gt_folder]. + keys (list[str]): A list of keys identifying folders. The order should + be in consistent with folders, e.g., ['lq', 'gt']. + filename_tmpl (str): Template for each filename. Note that the + template excludes the file extension. Usually the filename_tmpl is + for files in the input folder. + + Returns: + list[str]: Returned path list. + """ + assert len(folders) == 2, ('The len of folders should be 2 with [input_folder, gt_folder]. ' + f'But got {len(folders)}') + assert len(keys) == 2, f'The len of keys should be 2 with [input_key, gt_key]. But got {len(keys)}' + input_folder, gt_folder = folders + input_key, gt_key = keys + + input_paths = list(scandir(input_folder)) + gt_paths = list(scandir(gt_folder)) + assert len(input_paths) == len(gt_paths), (f'{input_key} and {gt_key} datasets have different number of images: ' + f'{len(input_paths)}, {len(gt_paths)}.') + paths = [] + for gt_path in gt_paths: + basename, ext = osp.splitext(osp.basename(gt_path)) + input_name = f'{filename_tmpl.format(basename)}{ext}' + input_path = osp.join(input_folder, input_name) + assert input_name in input_paths, f'{input_name} is not in {input_key}_paths.' + gt_path = osp.join(gt_folder, gt_path) + paths.append(dict([(f'{input_key}_path', input_path), (f'{gt_key}_path', gt_path)])) + return paths + + +def paths_from_folder(folder): + """Generate paths from folder. + + Args: + folder (str): Folder path. + + Returns: + list[str]: Returned path list. + """ + + paths = list(scandir(folder)) + paths = [osp.join(folder, path) for path in paths] + return paths + + +def paths_from_lmdb(folder): + """Generate paths from lmdb. + + Args: + folder (str): Folder path. + + Returns: + list[str]: Returned path list. + """ + if not folder.endswith('.lmdb'): + raise ValueError(f'Folder {folder}folder should in lmdb format.') + with open(osp.join(folder, 'meta_info.txt')) as fin: + paths = [line.split('.')[0] for line in fin] + return paths + + +def generate_gaussian_kernel(kernel_size=13, sigma=1.6): + """Generate Gaussian kernel used in `duf_downsample`. + + Args: + kernel_size (int): Kernel size. Default: 13. + sigma (float): Sigma of the Gaussian kernel. Default: 1.6. + + Returns: + np.array: The Gaussian kernel. + """ + from scipy.ndimage import filters as filters + kernel = np.zeros((kernel_size, kernel_size)) + # set element at the middle to one, a dirac delta + kernel[kernel_size // 2, kernel_size // 2] = 1 + # gaussian-smooth the dirac, resulting in a gaussian filter + return filters.gaussian_filter(kernel, sigma) + + +def duf_downsample(x, kernel_size=13, scale=4): + """Downsamping with Gaussian kernel used in the DUF official code. + + Args: + x (Tensor): Frames to be downsampled, with shape (b, t, c, h, w). + kernel_size (int): Kernel size. Default: 13. + scale (int): Downsampling factor. Supported scale: (2, 3, 4). + Default: 4. + + Returns: + Tensor: DUF downsampled frames. + """ + assert scale in (2, 3, 4), f'Only support scale (2, 3, 4), but got {scale}.' + + squeeze_flag = False + if x.ndim == 4: + squeeze_flag = True + x = x.unsqueeze(0) + b, t, c, h, w = x.size() + x = x.view(-1, 1, h, w) + pad_w, pad_h = kernel_size // 2 + scale * 2, kernel_size // 2 + scale * 2 + x = F.pad(x, (pad_w, pad_w, pad_h, pad_h), 'reflect') + + gaussian_filter = generate_gaussian_kernel(kernel_size, 0.4 * scale) + gaussian_filter = torch.from_numpy(gaussian_filter).type_as(x).unsqueeze(0).unsqueeze(0) + x = F.conv2d(x, gaussian_filter, stride=scale) + x = x[:, :, 2:-2, 2:-2] + x = x.view(b, t, c, x.size(2), x.size(3)) + if squeeze_flag: + x = x.squeeze(0) + return x diff --git a/opensora/models/super_resolution/basicsr/data/paired_image_dataset.py b/opensora/models/super_resolution/basicsr/data/paired_image_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..23749f8d5d314b90482903cd6430658356b6f1fa --- /dev/null +++ b/opensora/models/super_resolution/basicsr/data/paired_image_dataset.py @@ -0,0 +1,113 @@ +from torch.utils import data as data +from torchvision.transforms.functional import normalize + +from basicsr.data.data_util import paired_paths_from_folder, paired_paths_from_lmdb, paired_paths_from_meta_info_file +from basicsr.data.transforms import augment, paired_random_crop +from basicsr.utils import FileClient, imfrombytes, img2tensor +from basicsr.utils.matlab_functions import bgr2ycbcr +from basicsr.utils.registry import DATASET_REGISTRY + +import numpy as np + +@DATASET_REGISTRY.register() +class PairedImageDataset(data.Dataset): + """Paired image dataset for image restoration. + + Read LQ (Low Quality, e.g. LR (Low Resolution), blurry, noisy, etc) and GT image pairs. + + There are three modes: + 1. 'lmdb': Use lmdb files. + If opt['io_backend'] == lmdb. + 2. 'meta_info_file': Use meta information file to generate paths. + If opt['io_backend'] != lmdb and opt['meta_info_file'] is not None. + 3. 'folder': Scan folders to generate paths. + The rest. + + Args: + opt (dict): Config for train datasets. It contains the following keys: + dataroot_gt (str): Data root path for gt. + dataroot_lq (str): Data root path for lq. + meta_info_file (str): Path for meta information file. + io_backend (dict): IO backend type and other kwarg. + filename_tmpl (str): Template for each filename. Note that the template excludes the file extension. + Default: '{}'. + gt_size (int): Cropped patched size for gt patches. + use_hflip (bool): Use horizontal flips. + use_rot (bool): Use rotation (use vertical flip and transposing h and w for implementation). + + scale (bool): Scale, which will be added automatically. + phase (str): 'train' or 'val'. + """ + + def __init__(self, opt): + super(PairedImageDataset, self).__init__() + self.opt = opt + # file client (io backend) + self.file_client = None + self.io_backend_opt = opt['io_backend'] + self.mean = opt['mean'] if 'mean' in opt else None + self.std = opt['std'] if 'std' in opt else None + + self.gt_folder, self.lq_folder = opt['dataroot_gt'], opt['dataroot_lq'] + if 'filename_tmpl' in opt: + self.filename_tmpl = opt['filename_tmpl'] + else: + self.filename_tmpl = '{}' + + if self.io_backend_opt['type'] == 'lmdb': + self.io_backend_opt['db_paths'] = [self.lq_folder, self.gt_folder] + self.io_backend_opt['client_keys'] = ['lq', 'gt'] + self.paths = paired_paths_from_lmdb([self.lq_folder, self.gt_folder], ['lq', 'gt']) + elif 'meta_info_file' in self.opt and self.opt['meta_info_file'] is not None: + self.paths = paired_paths_from_meta_info_file([self.lq_folder, self.gt_folder], ['lq', 'gt'], + self.opt['meta_info_file'], self.filename_tmpl) + else: + self.paths = paired_paths_from_folder([self.lq_folder, self.gt_folder], ['lq', 'gt'], self.filename_tmpl) + + def __getitem__(self, index): + if self.file_client is None: + self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt) + + scale = self.opt['scale'] + + # Load gt and lq images. Dimension order: HWC; channel order: BGR; + + # image range: [0, 1], float32., H W 3 + gt_path = self.paths[index]['gt_path'] + img_bytes = self.file_client.get(gt_path, 'gt') + img_gt = imfrombytes(img_bytes, float32=True) + lq_path = self.paths[index]['lq_path'] + img_bytes = self.file_client.get(lq_path, 'lq') + img_lq = imfrombytes(img_bytes, float32=True) + + # augmentation for training + if self.opt['phase'] == 'train': + gt_size = self.opt['gt_size'] + # random crop + img_gt, img_lq = paired_random_crop(img_gt, img_lq, gt_size, scale, gt_path) + # flip, rotation + img_gt, img_lq = augment([img_gt, img_lq], self.opt['use_hflip'], self.opt['use_rot']) + + # color space transform + if 'color' in self.opt and self.opt['color'] == 'y': + img_gt = bgr2ycbcr(img_gt, y_only=True)[..., None] + img_lq = bgr2ycbcr(img_lq, y_only=True)[..., None] + + # crop the unmatched GT images during validation or testing, especially for SR benchmark datasets + # TODO: It is better to update the datasets, rather than force to crop + if self.opt['phase'] != 'train': + img_gt = img_gt[0:img_lq.shape[0] * scale, 0:img_lq.shape[1] * scale, :] + + # BGR to RGB, HWC to CHW, numpy to tensor + img_gt, img_lq = img2tensor([img_gt, img_lq], bgr2rgb=True, float32=True) + # normalize + if self.mean is not None or self.std is not None: + normalize(img_lq, self.mean, self.std, inplace=True) + normalize(img_gt, self.mean, self.std, inplace=True) + + # print(img_lq.shape,img_gt.shape,img_lq.min(),img_gt.min(),img_lq.max(),img_gt.max(),lq_path,gt_path) + + return {'lq': img_lq, 'gt': img_gt, 'lq_path': lq_path, 'gt_path': gt_path} + + def __len__(self): + return len(self.paths) diff --git a/opensora/models/super_resolution/basicsr/data/prefetch_dataloader.py b/opensora/models/super_resolution/basicsr/data/prefetch_dataloader.py new file mode 100644 index 0000000000000000000000000000000000000000..5088425050d4cc98114a9b93eb50ea60273f35a0 --- /dev/null +++ b/opensora/models/super_resolution/basicsr/data/prefetch_dataloader.py @@ -0,0 +1,125 @@ +import queue as Queue +import threading +import torch +from torch.utils.data import DataLoader + + +class PrefetchGenerator(threading.Thread): + """A general prefetch generator. + + Ref: + https://stackoverflow.com/questions/7323664/python-generator-pre-fetch + + Args: + generator: Python generator. + num_prefetch_queue (int): Number of prefetch queue. + """ + + def __init__(self, generator, num_prefetch_queue): + threading.Thread.__init__(self) + self.queue = Queue.Queue(num_prefetch_queue) + self.generator = generator + self.daemon = True + self.start() + + def run(self): + for item in self.generator: + self.queue.put(item) + self.queue.put(None) + + def __next__(self): + next_item = self.queue.get() + if next_item is None: + raise StopIteration + return next_item + + def __iter__(self): + return self + + +class PrefetchDataLoader(DataLoader): + """Prefetch version of dataloader. + + Ref: + https://github.com/IgorSusmelj/pytorch-styleguide/issues/5# + + TODO: + Need to test on single gpu and ddp (multi-gpu). There is a known issue in + ddp. + + Args: + num_prefetch_queue (int): Number of prefetch queue. + kwargs (dict): Other arguments for dataloader. + """ + + def __init__(self, num_prefetch_queue, **kwargs): + self.num_prefetch_queue = num_prefetch_queue + super(PrefetchDataLoader, self).__init__(**kwargs) + + def __iter__(self): + return PrefetchGenerator(super().__iter__(), self.num_prefetch_queue) + + +class CPUPrefetcher(): + """CPU prefetcher. + + Args: + loader: Dataloader. + """ + + def __init__(self, loader): + self.ori_loader = loader + self.loader = iter(loader) + + def next(self): + try: + return next(self.loader) + except StopIteration: + return None + + def reset(self): + self.loader = iter(self.ori_loader) + + +class CUDAPrefetcher(): + """CUDA prefetcher. + + Ref: + https://github.com/NVIDIA/apex/issues/304# + + It may consums more GPU memory. + + Args: + loader: Dataloader. + opt (dict): Options. + """ + + def __init__(self, loader, opt): + self.ori_loader = loader + self.loader = iter(loader) + self.opt = opt + self.stream = torch.cuda.Stream() + self.device = torch.device('cuda' if opt['num_gpu'] != 0 else 'cpu') + self.preload() + + def preload(self): + try: + self.batch = next(self.loader) # self.batch is a dict + except StopIteration: + self.batch = None + return None + # put tensors to gpu + with torch.cuda.stream(self.stream): + for k, v in self.batch.items(): + if torch.is_tensor(v): + self.batch[k] = self.batch[k].to(device=self.device, non_blocking=True) + + def next(self): + torch.cuda.current_stream().wait_stream(self.stream) + batch = self.batch + self.preload() + return batch + + def reset(self): + self.loader = iter(self.ori_loader) + self.preload() diff --git a/opensora/models/super_resolution/basicsr/data/single_image_dataset.py b/opensora/models/super_resolution/basicsr/data/single_image_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..795803a10f02c649834c1daed7a87804a8426305 --- /dev/null +++ b/opensora/models/super_resolution/basicsr/data/single_image_dataset.py @@ -0,0 +1,69 @@ +from os import path as osp +from torch.utils import data as data +from torchvision.transforms.functional import normalize + +from basicsr.data.data_util import paths_from_lmdb +from basicsr.utils import FileClient, imfrombytes, img2tensor, scandir +from basicsr.utils.matlab_functions import rgb2ycbcr +from basicsr.utils.registry import DATASET_REGISTRY + + +@DATASET_REGISTRY.register() +class SingleImageDataset(data.Dataset): + """Read only lq images in the test phase. + + Read LQ (Low Quality, e.g. LR (Low Resolution), blurry, noisy, etc). + + There are two modes: + 1. 'meta_info_file': Use meta information file to generate paths. + 2. 'folder': Scan folders to generate paths. + + Args: + opt (dict): Config for train datasets. It contains the following keys: + dataroot_lq (str): Data root path for lq. + meta_info_file (str): Path for meta information file. + io_backend (dict): IO backend type and other kwarg. + """ + + def __init__(self, opt): + super(SingleImageDataset, self).__init__() + self.opt = opt + # file client (io backend) + self.file_client = None + self.io_backend_opt = opt['io_backend'] + self.mean = opt['mean'] if 'mean' in opt else None + self.std = opt['std'] if 'std' in opt else None + self.lq_folder = opt['dataroot_lq'] + + if self.io_backend_opt['type'] == 'lmdb': + self.io_backend_opt['db_paths'] = [self.lq_folder] + self.io_backend_opt['client_keys'] = ['lq'] + self.paths = paths_from_lmdb(self.lq_folder) + elif 'meta_info_file' in self.opt: + with open(self.opt['meta_info_file'], 'r') as fin: + self.paths = [osp.join(self.lq_folder, line.rstrip().split(' ')[0]) for line in fin] + else: + self.paths = sorted(list(scandir(self.lq_folder, full_path=True))) + + def __getitem__(self, index): + if self.file_client is None: + self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt) + + # load lq image + lq_path = self.paths[index] + img_bytes = self.file_client.get(lq_path, 'lq') + img_lq = imfrombytes(img_bytes, float32=True) + + # color space transform + if 'color' in self.opt and self.opt['color'] == 'y': + img_lq = rgb2ycbcr(img_lq, y_only=True)[..., None] + + # BGR to RGB, HWC to CHW, numpy to tensor + img_lq = img2tensor(img_lq, bgr2rgb=True, float32=True) + # normalize + if self.mean is not None or self.std is not None: + normalize(img_lq, self.mean, self.std, inplace=True) + return {'lq': img_lq, 'lq_path': lq_path} + + def __len__(self): + return len(self.paths) diff --git a/opensora/models/super_resolution/basicsr/data/transforms.py b/opensora/models/super_resolution/basicsr/data/transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..d9bbb5fb7daef5edfb425fafb4d67d471b3001e6 --- /dev/null +++ b/opensora/models/super_resolution/basicsr/data/transforms.py @@ -0,0 +1,179 @@ +import cv2 +import random +import torch + + +def mod_crop(img, scale): + """Mod crop images, used during testing. + + Args: + img (ndarray): Input image. + scale (int): Scale factor. + + Returns: + ndarray: Result image. + """ + img = img.copy() + if img.ndim in (2, 3): + h, w = img.shape[0], img.shape[1] + h_remainder, w_remainder = h % scale, w % scale + img = img[:h - h_remainder, :w - w_remainder, ...] + else: + raise ValueError(f'Wrong img ndim: {img.ndim}.') + return img + + +def paired_random_crop(img_gts, img_lqs, gt_patch_size, scale, gt_path=None): + """Paired random crop. Support Numpy array and Tensor inputs. + + It crops lists of lq and gt images with corresponding locations. + + Args: + img_gts (list[ndarray] | ndarray | list[Tensor] | Tensor): GT images. Note that all images + should have the same shape. If the input is an ndarray, it will + be transformed to a list containing itself. + img_lqs (list[ndarray] | ndarray): LQ images. Note that all images + should have the same shape. If the input is an ndarray, it will + be transformed to a list containing itself. + gt_patch_size (int): GT patch size. + scale (int): Scale factor. + gt_path (str): Path to ground-truth. Default: None. + + Returns: + list[ndarray] | ndarray: GT images and LQ images. If returned results + only have one element, just return ndarray. + """ + + if not isinstance(img_gts, list): + img_gts = [img_gts] + if not isinstance(img_lqs, list): + img_lqs = [img_lqs] + + # determine input type: Numpy array or Tensor + input_type = 'Tensor' if torch.is_tensor(img_gts[0]) else 'Numpy' + + if input_type == 'Tensor': + h_lq, w_lq = img_lqs[0].size()[-2:] + h_gt, w_gt = img_gts[0].size()[-2:] + else: + h_lq, w_lq = img_lqs[0].shape[0:2] + h_gt, w_gt = img_gts[0].shape[0:2] + lq_patch_size = gt_patch_size // scale + + if h_gt != h_lq * scale or w_gt != w_lq * scale: + raise ValueError(f'Scale mismatches. GT ({h_gt}, {w_gt}) is not {scale}x ', + f'multiplication of LQ ({h_lq}, {w_lq}).') + if h_lq < lq_patch_size or w_lq < lq_patch_size: + raise ValueError(f'LQ ({h_lq}, {w_lq}) is smaller than patch size ' + f'({lq_patch_size}, {lq_patch_size}). ' + f'Please remove {gt_path}.') + + # randomly choose top and left coordinates for lq patch + top = random.randint(0, h_lq - lq_patch_size) + left = random.randint(0, w_lq - lq_patch_size) + + # crop lq patch + if input_type == 'Tensor': + img_lqs = [v[:, :, top:top + lq_patch_size, left:left + lq_patch_size] for v in img_lqs] + else: + img_lqs = [v[top:top + lq_patch_size, left:left + lq_patch_size, ...] for v in img_lqs] + + # crop corresponding gt patch + top_gt, left_gt = int(top * scale), int(left * scale) + if input_type == 'Tensor': + img_gts = [v[:, :, top_gt:top_gt + gt_patch_size, left_gt:left_gt + gt_patch_size] for v in img_gts] + else: + img_gts = [v[top_gt:top_gt + gt_patch_size, left_gt:left_gt + gt_patch_size, ...] for v in img_gts] + if len(img_gts) == 1: + img_gts = img_gts[0] + if len(img_lqs) == 1: + img_lqs = img_lqs[0] + return img_gts, img_lqs + + +def augment(imgs, hflip=True, rotation=True, flows=None, return_status=False): + """Augment: horizontal flips OR rotate (0, 90, 180, 270 degrees). + + We use vertical flip and transpose for rotation implementation. + All the images in the list use the same augmentation. + + Args: + imgs (list[ndarray] | ndarray): Images to be augmented. If the input + is an ndarray, it will be transformed to a list. + hflip (bool): Horizontal flip. Default: True. + rotation (bool): Ratotation. Default: True. + flows (list[ndarray]: Flows to be augmented. If the input is an + ndarray, it will be transformed to a list. + Dimension is (h, w, 2). Default: None. + return_status (bool): Return the status of flip and rotation. + Default: False. + + Returns: + list[ndarray] | ndarray: Augmented images and flows. If returned + results only have one element, just return ndarray. + + """ + hflip = hflip and random.random() < 0.5 + vflip = rotation and random.random() < 0.5 + rot90 = rotation and random.random() < 0.5 + + def _augment(img): + if hflip: # horizontal + cv2.flip(img, 1, img) + if vflip: # vertical + cv2.flip(img, 0, img) + if rot90: + img = img.transpose(1, 0, 2) + return img + + def _augment_flow(flow): + if hflip: # horizontal + cv2.flip(flow, 1, flow) + flow[:, :, 0] *= -1 + if vflip: # vertical + cv2.flip(flow, 0, flow) + flow[:, :, 1] *= -1 + if rot90: + flow = flow.transpose(1, 0, 2) + flow = flow[:, :, [1, 0]] + return flow + + if not isinstance(imgs, list): + imgs = [imgs] + imgs = [_augment(img) for img in imgs] + if len(imgs) == 1: + imgs = imgs[0] + + if flows is not None: + if not isinstance(flows, list): + flows = [flows] + flows = [_augment_flow(flow) for flow in flows] + if len(flows) == 1: + flows = flows[0] + return imgs, flows + else: + if return_status: + return imgs, (hflip, vflip, rot90) + else: + return imgs + + +def img_rotate(img, angle, center=None, scale=1.0): + """Rotate image. + + Args: + img (ndarray): Image to be rotated. + angle (float): Rotation angle in degrees. Positive values mean + counter-clockwise rotation. + center (tuple[int]): Rotation center. If the center is None, + initialize it as the center of the image. Default: None. + scale (float): Isotropic scale factor. Default: 1.0. + """ + (h, w) = img.shape[:2] + + if center is None: + center = (w // 2, h // 2) + + matrix = cv2.getRotationMatrix2D(center, angle, scale) + rotated_img = cv2.warpAffine(img, matrix, (w, h)) + return rotated_img diff --git a/opensora/models/super_resolution/basicsr/losses/__init__.py b/opensora/models/super_resolution/basicsr/losses/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..14942900c1abc657d6ea649446ec13c2fbb39387 --- /dev/null +++ b/opensora/models/super_resolution/basicsr/losses/__init__.py @@ -0,0 +1,26 @@ +from copy import deepcopy + +from basicsr.utils import get_root_logger +from basicsr.utils.registry import LOSS_REGISTRY +from .losses import (CharbonnierLoss, GANLoss, L1Loss, MSELoss, WeightedTVLoss, g_path_regularize, + gradient_penalty_loss, r1_penalty) + +__all__ = [ + 'L1Loss', 'MSELoss', 'CharbonnierLoss', 'WeightedTVLoss', 'GANLoss', 'gradient_penalty_loss', + 'r1_penalty', 'g_path_regularize' +] + + +def build_loss(opt): + """Build loss from options. + + Args: + opt (dict): Configuration. It must contain: + type (str): Model type. + """ + opt = deepcopy(opt) + loss_type = opt.pop('type') + loss = LOSS_REGISTRY.get(loss_type)(**opt) + logger = get_root_logger() + logger.info(f'Loss [{loss.__class__.__name__}] is created.') + return loss diff --git a/opensora/models/super_resolution/basicsr/losses/loss_util.py b/opensora/models/super_resolution/basicsr/losses/loss_util.py new file mode 100644 index 0000000000000000000000000000000000000000..744eeb46d1f3b5a7b4553ca23237ddd9c899a698 --- /dev/null +++ b/opensora/models/super_resolution/basicsr/losses/loss_util.py @@ -0,0 +1,95 @@ +import functools +from torch.nn import functional as F + + +def reduce_loss(loss, reduction): + """Reduce loss as specified. + + Args: + loss (Tensor): Elementwise loss tensor. + reduction (str): Options are 'none', 'mean' and 'sum'. + + Returns: + Tensor: Reduced loss tensor. + """ + reduction_enum = F._Reduction.get_enum(reduction) + # none: 0, elementwise_mean:1, sum: 2 + if reduction_enum == 0: + return loss + elif reduction_enum == 1: + return loss.mean() + else: + return loss.sum() + + +def weight_reduce_loss(loss, weight=None, reduction='mean'): + """Apply element-wise weight and reduce loss. + + Args: + loss (Tensor): Element-wise loss. + weight (Tensor): Element-wise weights. Default: None. + reduction (str): Same as built-in losses of PyTorch. Options are + 'none', 'mean' and 'sum'. Default: 'mean'. + + Returns: + Tensor: Loss values. + """ + # if weight is specified, apply element-wise weight + if weight is not None: + assert weight.dim() == loss.dim() + assert weight.size(1) == 1 or weight.size(1) == loss.size(1) + loss = loss * weight + + # if weight is not specified or reduction is sum, just reduce the loss + if weight is None or reduction == 'sum': + loss = reduce_loss(loss, reduction) + # if reduction is mean, then compute mean over weight region + elif reduction == 'mean': + if weight.size(1) > 1: + weight = weight.sum() + else: + weight = weight.sum() * loss.size(1) + loss = loss.sum() / weight + + return loss + + +def weighted_loss(loss_func): + """Create a weighted version of a given loss function. + + To use this decorator, the loss function must have the signature like + `loss_func(pred, target, **kwargs)`. The function only needs to compute + element-wise loss without any reduction. This decorator will add weight + and reduction arguments to the function. The decorated function will have + the signature like `loss_func(pred, target, weight=None, reduction='mean', + **kwargs)`. + + :Example: + + >>> import torch + >>> @weighted_loss + >>> def l1_loss(pred, target): + >>> return (pred - target).abs() + + >>> pred = torch.Tensor([0, 2, 3]) + >>> target = torch.Tensor([1, 1, 1]) + >>> weight = torch.Tensor([1, 0, 1]) + + >>> l1_loss(pred, target) + tensor(1.3333) + >>> l1_loss(pred, target, weight) + tensor(1.5000) + >>> l1_loss(pred, target, reduction='none') + tensor([1., 1., 2.]) + >>> l1_loss(pred, target, weight, reduction='sum') + tensor(3.) + """ + + @functools.wraps(loss_func) + def wrapper(pred, target, weight=None, reduction='mean', **kwargs): + # get element-wise loss + loss = loss_func(pred, target, **kwargs) + loss = weight_reduce_loss(loss, weight, reduction) + return loss + + return wrapper diff --git a/opensora/models/super_resolution/basicsr/losses/losses.py b/opensora/models/super_resolution/basicsr/losses/losses.py new file mode 100644 index 0000000000000000000000000000000000000000..55436902eb67b456a39742c7eda75b73471e5f5b --- /dev/null +++ b/opensora/models/super_resolution/basicsr/losses/losses.py @@ -0,0 +1,492 @@ +import math +import torch +from torch import autograd as autograd +from torch import nn as nn +from torch.nn import functional as F + +from basicsr.archs.vgg_arch import VGGFeatureExtractor +from basicsr.utils.registry import LOSS_REGISTRY +from .loss_util import weighted_loss + +_reduction_modes = ['none', 'mean', 'sum'] + + +@weighted_loss +def l1_loss(pred, target): + return F.l1_loss(pred, target, reduction='none') + + +@weighted_loss +def mse_loss(pred, target): + return F.mse_loss(pred, target, reduction='none') + + +@weighted_loss +def charbonnier_loss(pred, target, eps=1e-12): + return torch.sqrt((pred - target)**2 + eps) + + +@LOSS_REGISTRY.register() +class L1Loss(nn.Module): + """L1 (mean absolute error, MAE) loss. + + Args: + loss_weight (float): Loss weight for L1 loss. Default: 1.0. + reduction (str): Specifies the reduction to apply to the output. + Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'. + """ + + def __init__(self, loss_weight=1.0, reduction='mean'): + super(L1Loss, self).__init__() + if reduction not in ['none', 'mean', 'sum']: + raise ValueError(f'Unsupported reduction mode: {reduction}. Supported ones are: {_reduction_modes}') + + self.loss_weight = loss_weight + self.reduction = reduction + + def forward(self, pred, target, weight=None, **kwargs): + """ + Args: + pred (Tensor): of shape (N, C, H, W). Predicted tensor. + target (Tensor): of shape (N, C, H, W). Ground truth tensor. + weight (Tensor, optional): of shape (N, C, H, W). Element-wise weights. Default: None. + """ + return self.loss_weight * l1_loss(pred, target, weight, reduction=self.reduction) + + +@LOSS_REGISTRY.register() +class MSELoss(nn.Module): + """MSE (L2) loss. + + Args: + loss_weight (float): Loss weight for MSE loss. Default: 1.0. + reduction (str): Specifies the reduction to apply to the output. + Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'. + """ + + def __init__(self, loss_weight=1.0, reduction='mean'): + super(MSELoss, self).__init__() + if reduction not in ['none', 'mean', 'sum']: + raise ValueError(f'Unsupported reduction mode: {reduction}. Supported ones are: {_reduction_modes}') + + self.loss_weight = loss_weight + self.reduction = reduction + + def forward(self, pred, target, weight=None, **kwargs): + """ + Args: + pred (Tensor): of shape (N, C, H, W). Predicted tensor. + target (Tensor): of shape (N, C, H, W). Ground truth tensor. + weight (Tensor, optional): of shape (N, C, H, W). Element-wise weights. Default: None. + """ + return self.loss_weight * mse_loss(pred, target, weight, reduction=self.reduction) + + +@LOSS_REGISTRY.register() +class CharbonnierLoss(nn.Module): + """Charbonnier loss (one variant of Robust L1Loss, a differentiable + variant of L1Loss). + + Described in "Deep Laplacian Pyramid Networks for Fast and Accurate + Super-Resolution". + + Args: + loss_weight (float): Loss weight for L1 loss. Default: 1.0. + reduction (str): Specifies the reduction to apply to the output. + Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'. + eps (float): A value used to control the curvature near zero. Default: 1e-12. + """ + + def __init__(self, loss_weight=1.0, reduction='mean', eps=1e-12): + super(CharbonnierLoss, self).__init__() + if reduction not in ['none', 'mean', 'sum']: + raise ValueError(f'Unsupported reduction mode: {reduction}. Supported ones are: {_reduction_modes}') + + self.loss_weight = loss_weight + self.reduction = reduction + self.eps = eps + + def forward(self, pred, target, weight=None, **kwargs): + """ + Args: + pred (Tensor): of shape (N, C, H, W). Predicted tensor. + target (Tensor): of shape (N, C, H, W). Ground truth tensor. + weight (Tensor, optional): of shape (N, C, H, W). Element-wise weights. Default: None. + """ + return self.loss_weight * charbonnier_loss(pred, target, weight, eps=self.eps, reduction=self.reduction) + + +@LOSS_REGISTRY.register() +class WeightedTVLoss(L1Loss): + """Weighted TV loss. + + Args: + loss_weight (float): Loss weight. Default: 1.0. + """ + + def __init__(self, loss_weight=1.0, reduction='mean'): + if reduction not in ['mean', 'sum']: + raise ValueError(f'Unsupported reduction mode: {reduction}. Supported ones are: mean | sum') + super(WeightedTVLoss, self).__init__(loss_weight=loss_weight, reduction=reduction) + + def forward(self, pred, weight=None): + if weight is None: + y_weight = None + x_weight = None + else: + y_weight = weight[:, :, :-1, :] + x_weight = weight[:, :, :, :-1] + + y_diff = super().forward(pred[:, :, :-1, :], pred[:, :, 1:, :], weight=y_weight) + x_diff = super().forward(pred[:, :, :, :-1], pred[:, :, :, 1:], weight=x_weight) + + loss = x_diff + y_diff + + return loss + + +@LOSS_REGISTRY.register() +class PerceptualLoss(nn.Module): + """Perceptual loss with commonly used style loss. + + Args: + layer_weights (dict): The weight for each layer of vgg feature. + Here is an example: {'conv5_4': 1.}, which means the conv5_4 + feature layer (before relu5_4) will be extracted with weight + 1.0 in calculating losses. + vgg_type (str): The type of vgg network used as feature extractor. + Default: 'vgg19'. + use_input_norm (bool): If True, normalize the input image in vgg. + Default: True. + range_norm (bool): If True, norm images with range [-1, 1] to [0, 1]. + Default: False. + perceptual_weight (float): If `perceptual_weight > 0`, the perceptual + loss will be calculated and the loss will multiplied by the + weight. Default: 1.0. + style_weight (float): If `style_weight > 0`, the style loss will be + calculated and the loss will multiplied by the weight. + Default: 0. + criterion (str): Criterion used for perceptual loss. Default: 'l1'. + """ + + def __init__(self, + layer_weights, + vgg_type='vgg19', + use_input_norm=True, + range_norm=False, + perceptual_weight=1.0, + style_weight=0., + criterion='l1'): + super(PerceptualLoss, self).__init__() + self.perceptual_weight = perceptual_weight + self.style_weight = style_weight + self.layer_weights = layer_weights + self.vgg = VGGFeatureExtractor( + layer_name_list=list(layer_weights.keys()), + vgg_type=vgg_type, + use_input_norm=use_input_norm, + range_norm=range_norm) + + self.criterion_type = criterion + if self.criterion_type == 'l1': + self.criterion = torch.nn.L1Loss() + elif self.criterion_type == 'l2': + self.criterion = torch.nn.L2loss() + elif self.criterion_type == 'fro': + self.criterion = None + else: + raise NotImplementedError(f'{criterion} criterion has not been supported.') + + def forward(self, x, gt): + """Forward function. + + Args: + x (Tensor): Input tensor with shape (n, c, h, w). + gt (Tensor): Ground-truth tensor with shape (n, c, h, w). + + Returns: + Tensor: Forward results. + """ + # extract vgg features + x_features = self.vgg(x) + gt_features = self.vgg(gt.detach()) + + # calculate perceptual loss + if self.perceptual_weight > 0: + percep_loss = 0 + for k in x_features.keys(): + if self.criterion_type == 'fro': + percep_loss += torch.norm(x_features[k] - gt_features[k], p='fro') * self.layer_weights[k] + else: + percep_loss += self.criterion(x_features[k], gt_features[k]) * self.layer_weights[k] + percep_loss *= self.perceptual_weight + else: + percep_loss = None + + # calculate style loss + if self.style_weight > 0: + style_loss = 0 + for k in x_features.keys(): + if self.criterion_type == 'fro': + style_loss += torch.norm( + self._gram_mat(x_features[k]) - self._gram_mat(gt_features[k]), p='fro') * self.layer_weights[k] + else: + style_loss += self.criterion(self._gram_mat(x_features[k]), self._gram_mat( + gt_features[k])) * self.layer_weights[k] + style_loss *= self.style_weight + else: + style_loss = None + + return percep_loss, style_loss + + def _gram_mat(self, x): + """Calculate Gram matrix. + + Args: + x (torch.Tensor): Tensor with shape of (n, c, h, w). + + Returns: + torch.Tensor: Gram matrix. + """ + n, c, h, w = x.size() + features = x.view(n, c, w * h) + features_t = features.transpose(1, 2) + gram = features.bmm(features_t) / (c * h * w) + return gram + + +@LOSS_REGISTRY.register() +class GANLoss(nn.Module): + """Define GAN loss. + + Args: + gan_type (str): Support 'vanilla', 'lsgan', 'wgan', 'hinge'. + real_label_val (float): The value for real label. Default: 1.0. + fake_label_val (float): The value for fake label. Default: 0.0. + loss_weight (float): Loss weight. Default: 1.0. + Note that loss_weight is only for generators; and it is always 1.0 + for discriminators. + """ + + def __init__(self, gan_type, real_label_val=1.0, fake_label_val=0.0, loss_weight=1.0): + super(GANLoss, self).__init__() + self.gan_type = gan_type + self.loss_weight = loss_weight + self.real_label_val = real_label_val + self.fake_label_val = fake_label_val + + if self.gan_type == 'vanilla': + self.loss = nn.BCEWithLogitsLoss() + elif self.gan_type == 'lsgan': + self.loss = nn.MSELoss() + elif self.gan_type == 'wgan': + self.loss = self._wgan_loss + elif self.gan_type == 'wgan_softplus': + self.loss = self._wgan_softplus_loss + elif self.gan_type == 'hinge': + self.loss = nn.ReLU() + else: + raise NotImplementedError(f'GAN type {self.gan_type} is not implemented.') + + def _wgan_loss(self, input, target): + """wgan loss. + + Args: + input (Tensor): Input tensor. + target (bool): Target label. + + Returns: + Tensor: wgan loss. + """ + return -input.mean() if target else input.mean() + + def _wgan_softplus_loss(self, input, target): + """wgan loss with soft plus. softplus is a smooth approximation to the + ReLU function. + + In StyleGAN2, it is called: + Logistic loss for discriminator; + Non-saturating loss for generator. + + Args: + input (Tensor): Input tensor. + target (bool): Target label. + + Returns: + Tensor: wgan loss. + """ + return F.softplus(-input).mean() if target else F.softplus(input).mean() + + def get_target_label(self, input, target_is_real): + """Get target label. + + Args: + input (Tensor): Input tensor. + target_is_real (bool): Whether the target is real or fake. + + Returns: + (bool | Tensor): Target tensor. Return bool for wgan, otherwise, + return Tensor. + """ + + if self.gan_type in ['wgan', 'wgan_softplus']: + return target_is_real + target_val = (self.real_label_val if target_is_real else self.fake_label_val) + return input.new_ones(input.size()) * target_val + + def forward(self, input, target_is_real, is_disc=False): + """ + Args: + input (Tensor): The input for the loss module, i.e., the network + prediction. + target_is_real (bool): Whether the targe is real or fake. + is_disc (bool): Whether the loss for discriminators or not. + Default: False. + + Returns: + Tensor: GAN loss value. + """ + target_label = self.get_target_label(input, target_is_real) + if self.gan_type == 'hinge': + if is_disc: # for discriminators in hinge-gan + input = -input if target_is_real else input + loss = self.loss(1 + input).mean() + else: # for generators in hinge-gan + loss = -input.mean() + else: # other gan types + loss = self.loss(input, target_label) + + # loss_weight is always 1.0 for discriminators + return loss if is_disc else loss * self.loss_weight + + +@LOSS_REGISTRY.register() +class MultiScaleGANLoss(GANLoss): + """ + MultiScaleGANLoss accepts a list of predictions + """ + + def __init__(self, gan_type, real_label_val=1.0, fake_label_val=0.0, loss_weight=1.0): + super(MultiScaleGANLoss, self).__init__(gan_type, real_label_val, fake_label_val, loss_weight) + + def forward(self, input, target_is_real, is_disc=False): + """ + The input is a list of tensors, or a list of (a list of tensors) + """ + if isinstance(input, list): + loss = 0 + for pred_i in input: + if isinstance(pred_i, list): + # Only compute GAN loss for the last layer + # in case of multiscale feature matching + pred_i = pred_i[-1] + # Safe operation: 0-dim tensor calling self.mean() does nothing + loss_tensor = super().forward(pred_i, target_is_real, is_disc).mean() + loss += loss_tensor + return loss / len(input) + else: + return super().forward(input, target_is_real, is_disc) + + +def r1_penalty(real_pred, real_img): + """R1 regularization for discriminator. The core idea is to + penalize the gradient on real data alone: when the + generator distribution produces the true data distribution + and the discriminator is equal to 0 on the data manifold, the + gradient penalty ensures that the discriminator cannot create + a non-zero gradient orthogonal to the data manifold without + suffering a loss in the GAN game. + + Ref: + Eq. 9 in Which training methods for GANs do actually converge. + """ + grad_real = autograd.grad(outputs=real_pred.sum(), inputs=real_img, create_graph=True)[0] + grad_penalty = grad_real.pow(2).view(grad_real.shape[0], -1).sum(1).mean() + return grad_penalty + + +def g_path_regularize(fake_img, latents, mean_path_length, decay=0.01): + noise = torch.randn_like(fake_img) / math.sqrt(fake_img.shape[2] * fake_img.shape[3]) + grad = autograd.grad(outputs=(fake_img * noise).sum(), inputs=latents, create_graph=True)[0] + path_lengths = torch.sqrt(grad.pow(2).sum(2).mean(1)) + + path_mean = mean_path_length + decay * (path_lengths.mean() - mean_path_length) + + path_penalty = (path_lengths - path_mean).pow(2).mean() + + return path_penalty, path_lengths.detach().mean(), path_mean.detach() + + +def gradient_penalty_loss(discriminator, real_data, fake_data, weight=None): + """Calculate gradient penalty for wgan-gp. + + Args: + discriminator (nn.Module): Network for the discriminator. + real_data (Tensor): Real input data. + fake_data (Tensor): Fake input data. + weight (Tensor): Weight tensor. Default: None. + + Returns: + Tensor: A tensor for gradient penalty. + """ + + batch_size = real_data.size(0) + alpha = real_data.new_tensor(torch.rand(batch_size, 1, 1, 1)) + + # interpolate between real_data and fake_data + interpolates = alpha * real_data + (1. - alpha) * fake_data + interpolates = autograd.Variable(interpolates, requires_grad=True) + + disc_interpolates = discriminator(interpolates) + gradients = autograd.grad( + outputs=disc_interpolates, + inputs=interpolates, + grad_outputs=torch.ones_like(disc_interpolates), + create_graph=True, + retain_graph=True, + only_inputs=True)[0] + + if weight is not None: + gradients = gradients * weight + + gradients_penalty = ((gradients.norm(2, dim=1) - 1)**2).mean() + if weight is not None: + gradients_penalty /= torch.mean(weight) + + return gradients_penalty + + +@LOSS_REGISTRY.register() +class GANFeatLoss(nn.Module): + """Define feature matching loss for gans + + Args: + criterion (str): Support 'l1', 'l2', 'charbonnier'. + loss_weight (float): Loss weight. Default: 1.0. + reduction (str): Specifies the reduction to apply to the output. + Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'. + """ + + def __init__(self, criterion='l1', loss_weight=1.0, reduction='mean'): + super(GANFeatLoss, self).__init__() + if criterion == 'l1': + self.loss_op = L1Loss(loss_weight, reduction) + elif criterion == 'l2': + self.loss_op = MSELoss(loss_weight, reduction) + elif criterion == 'charbonnier': + self.loss_op = CharbonnierLoss(loss_weight, reduction) + else: + raise ValueError(f'Unsupported loss mode: {criterion}. Supported ones are: l1|l2|charbonnier') + + self.loss_weight = loss_weight + + def forward(self, pred_fake, pred_real): + num_d = len(pred_fake) + loss = 0 + for i in range(num_d): # for each discriminator + # last output is the final prediction, exclude it + num_intermediate_outputs = len(pred_fake[i]) - 1 + for j in range(num_intermediate_outputs): # for each layer output + unweighted_loss = self.loss_op(pred_fake[i][j], pred_real[i][j].detach()) + loss += unweighted_loss / num_d + return loss * self.loss_weight diff --git a/opensora/models/super_resolution/basicsr/metrics/__init__.py b/opensora/models/super_resolution/basicsr/metrics/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c65580aec1f868da07a9b2c0237214e4899a2736 --- /dev/null +++ b/opensora/models/super_resolution/basicsr/metrics/__init__.py @@ -0,0 +1,19 @@ +from copy import deepcopy + +from basicsr.utils.registry import METRIC_REGISTRY +from .psnr_ssim import calculate_psnr, calculate_ssim + +__all__ = ['calculate_psnr', 'calculate_ssim'] + + +def calculate_metric(data, opt): + """Calculate metric from data and options. + + Args: + opt (dict): Configuration. It must contain: + type (str): Model type. + """ + opt = deepcopy(opt) + metric_type = opt.pop('type') + metric = METRIC_REGISTRY.get(metric_type)(**data, **opt) + return metric diff --git a/opensora/models/super_resolution/basicsr/metrics/metric_util.py b/opensora/models/super_resolution/basicsr/metrics/metric_util.py new file mode 100644 index 0000000000000000000000000000000000000000..0b21777874f18a7e87c67153ee92dec4d7b599e8 --- /dev/null +++ b/opensora/models/super_resolution/basicsr/metrics/metric_util.py @@ -0,0 +1,45 @@ +import numpy as np + +from basicsr.utils.matlab_functions import bgr2ycbcr + + +def reorder_image(img, input_order='HWC'): + """Reorder images to 'HWC' order. + + If the input_order is (h, w), return (h, w, 1); + If the input_order is (c, h, w), return (h, w, c); + If the input_order is (h, w, c), return as it is. + + Args: + img (ndarray): Input image. + input_order (str): Whether the input order is 'HWC' or 'CHW'. + If the input image shape is (h, w), input_order will not have + effects. Default: 'HWC'. + + Returns: + ndarray: reordered image. + """ + + if input_order not in ['HWC', 'CHW']: + raise ValueError(f"Wrong input_order {input_order}. Supported input_orders are 'HWC' and 'CHW'") + if len(img.shape) == 2: + img = img[..., None] + if input_order == 'CHW': + img = img.transpose(1, 2, 0) + return img + + +def to_y_channel(img): + """Change to Y channel of YCbCr. + + Args: + img (ndarray): Images with range [0, 255]. + + Returns: + (ndarray): Images with range [0, 255] (float type) without round. + """ + img = img.astype(np.float32) / 255. + if img.ndim == 3 and img.shape[2] == 3: + img = bgr2ycbcr(img, y_only=True) + img = img[..., None] + return img * 255. diff --git a/opensora/models/super_resolution/basicsr/metrics/psnr_ssim.py b/opensora/models/super_resolution/basicsr/metrics/psnr_ssim.py new file mode 100644 index 0000000000000000000000000000000000000000..cb00426b91e200f9458eb9863ed7594d430002ae --- /dev/null +++ b/opensora/models/super_resolution/basicsr/metrics/psnr_ssim.py @@ -0,0 +1,128 @@ +import cv2 +import numpy as np + +from basicsr.metrics.metric_util import reorder_image, to_y_channel +from basicsr.utils.registry import METRIC_REGISTRY + + +@METRIC_REGISTRY.register() +def calculate_psnr(img, img2, crop_border, input_order='HWC', test_y_channel=False, **kwargs): + """Calculate PSNR (Peak Signal-to-Noise Ratio). + + Ref: https://en.wikipedia.org/wiki/Peak_signal-to-noise_ratio + + Args: + img (ndarray): Images with range [0, 255]. + img2 (ndarray): Images with range [0, 255]. + crop_border (int): Cropped pixels in each edge of an image. These + pixels are not involved in the PSNR calculation. + input_order (str): Whether the input order is 'HWC' or 'CHW'. + Default: 'HWC'. + test_y_channel (bool): Test on Y channel of YCbCr. Default: False. + + Returns: + float: psnr result. + """ + + assert img.shape == img2.shape, (f'Image shapes are different: {img.shape}, {img2.shape}.') + if input_order not in ['HWC', 'CHW']: + raise ValueError(f'Wrong input_order {input_order}. Supported input_orders are "HWC" and "CHW"') + img = reorder_image(img, input_order=input_order) + img2 = reorder_image(img2, input_order=input_order) + img = img.astype(np.float64) + img2 = img2.astype(np.float64) + + if crop_border != 0: + img = img[crop_border:-crop_border, crop_border:-crop_border, ...] + img2 = img2[crop_border:-crop_border, crop_border:-crop_border, ...] + + if test_y_channel: + img = to_y_channel(img) + img2 = to_y_channel(img2) + + mse = np.mean((img - img2)**2) + if mse == 0: + return float('inf') + return 20. * np.log10(255. / np.sqrt(mse)) + + +def _ssim(img, img2): + """Calculate SSIM (structural similarity) for one channel images. + + It is called by func:`calculate_ssim`. + + Args: + img (ndarray): Images with range [0, 255] with order 'HWC'. + img2 (ndarray): Images with range [0, 255] with order 'HWC'. + + Returns: + float: ssim result. + """ + + c1 = (0.01 * 255)**2 + c2 = (0.03 * 255)**2 + + img = img.astype(np.float64) + img2 = img2.astype(np.float64) + kernel = cv2.getGaussianKernel(11, 1.5) + window = np.outer(kernel, kernel.transpose()) + + mu1 = cv2.filter2D(img, -1, window)[5:-5, 5:-5] + mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5] + mu1_sq = mu1**2 + mu2_sq = mu2**2 + mu1_mu2 = mu1 * mu2 + sigma1_sq = cv2.filter2D(img**2, -1, window)[5:-5, 5:-5] - mu1_sq + sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq + sigma12 = cv2.filter2D(img * img2, -1, window)[5:-5, 5:-5] - mu1_mu2 + + ssim_map = ((2 * mu1_mu2 + c1) * (2 * sigma12 + c2)) / ((mu1_sq + mu2_sq + c1) * (sigma1_sq + sigma2_sq + c2)) + return ssim_map.mean() + + +@METRIC_REGISTRY.register() +def calculate_ssim(img, img2, crop_border, input_order='HWC', test_y_channel=False, **kwargs): + """Calculate SSIM (structural similarity). + + Ref: + Image quality assessment: From error visibility to structural similarity + + The results are the same as that of the official released MATLAB code in + https://ece.uwaterloo.ca/~z70wang/research/ssim/. + + For three-channel images, SSIM is calculated for each channel and then + averaged. + + Args: + img (ndarray): Images with range [0, 255]. + img2 (ndarray): Images with range [0, 255]. + crop_border (int): Cropped pixels in each edge of an image. These + pixels are not involved in the SSIM calculation. + input_order (str): Whether the input order is 'HWC' or 'CHW'. + Default: 'HWC'. + test_y_channel (bool): Test on Y channel of YCbCr. Default: False. + + Returns: + float: ssim result. + """ + + assert img.shape == img2.shape, (f'Image shapes are different: {img.shape}, {img2.shape}.') + if input_order not in ['HWC', 'CHW']: + raise ValueError(f'Wrong input_order {input_order}. Supported input_orders are "HWC" and "CHW"') + img = reorder_image(img, input_order=input_order) + img2 = reorder_image(img2, input_order=input_order) + img = img.astype(np.float64) + img2 = img2.astype(np.float64) + + if crop_border != 0: + img = img[crop_border:-crop_border, crop_border:-crop_border, ...] + img2 = img2[crop_border:-crop_border, crop_border:-crop_border, ...] + + if test_y_channel: + img = to_y_channel(img) + img2 = to_y_channel(img2) + + ssims = [] + for i in range(img.shape[2]): + ssims.append(_ssim(img[..., i], img2[..., i])) + return np.array(ssims).mean() diff --git a/opensora/models/super_resolution/basicsr/models/__init__.py b/opensora/models/super_resolution/basicsr/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..285ce3ef90550f5cd6cb61467388f8ae4b73f14a --- /dev/null +++ b/opensora/models/super_resolution/basicsr/models/__init__.py @@ -0,0 +1,30 @@ +import importlib +from copy import deepcopy +from os import path as osp + +from basicsr.utils import get_root_logger, scandir +from basicsr.utils.registry import MODEL_REGISTRY + +__all__ = ['build_model'] + +# automatically scan and import model modules for registry +# scan all the files under the 'models' folder and collect files ending with +# '_model.py' +model_folder = osp.dirname(osp.abspath(__file__)) +model_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(model_folder) if v.endswith('_model.py')] +# import all the model modules +_model_modules = [importlib.import_module(f'basicsr.models.{file_name}') for file_name in model_filenames] + + +def build_model(opt): + """Build model from options. + + Args: + opt (dict): Configuration. It must contain: + model_type (str): Model type. + """ + opt = deepcopy(opt) + model = MODEL_REGISTRY.get(opt['model_type'])(opt) + logger = get_root_logger() + logger.info(f'Model [{model.__class__.__name__}] is created.') + return model diff --git a/opensora/models/super_resolution/basicsr/models/base_model.py b/opensora/models/super_resolution/basicsr/models/base_model.py new file mode 100644 index 0000000000000000000000000000000000000000..f06f9ca2ca213f1a7c400355e9c66eaa12b1b1c4 --- /dev/null +++ b/opensora/models/super_resolution/basicsr/models/base_model.py @@ -0,0 +1,380 @@ +import os +import time +import torch +from collections import OrderedDict +from copy import deepcopy +from torch.nn.parallel import DataParallel, DistributedDataParallel + +from basicsr.models import lr_scheduler as lr_scheduler +from basicsr.utils import get_root_logger +from basicsr.utils.dist_util import master_only + + +class BaseModel(): + """Base model.""" + + def __init__(self, opt): + self.opt = opt + self.device = torch.device('cuda' if opt['num_gpu'] != 0 else 'cpu') + self.is_train = opt['is_train'] + self.schedulers = [] + self.optimizers = [] + + def feed_data(self, data): + pass + + def optimize_parameters(self): + pass + + def get_current_visuals(self): + pass + + def save(self, epoch, current_iter): + """Save networks and training state.""" + pass + + def validation(self, dataloader, current_iter, tb_logger, save_img=False): + """Validation function. + + Args: + dataloader (torch.utils.data.DataLoader): Validation dataloader. + current_iter (int): Current iteration. + tb_logger (tensorboard logger): Tensorboard logger. + save_img (bool): Whether to save images. Default: False. + """ + if self.opt['dist']: + self.dist_validation(dataloader, current_iter, tb_logger, save_img) + else: + self.nondist_validation(dataloader, current_iter, tb_logger, save_img) + + def _initialize_best_metric_results(self, dataset_name): + """Initialize the best metric results dict for recording the best metric value and iteration.""" + if hasattr(self, 'best_metric_results') and dataset_name in self.best_metric_results: + return + elif not hasattr(self, 'best_metric_results'): + self.best_metric_results = dict() + + # add a dataset record + record = dict() + for metric, content in self.opt['val']['metrics'].items(): + better = content.get('better', 'higher') + init_val = float('-inf') if better == 'higher' else float('inf') + record[metric] = dict(better=better, val=init_val, iter=-1) + self.best_metric_results[dataset_name] = record + + def _update_best_metric_result(self, dataset_name, metric, val, current_iter): + if self.best_metric_results[dataset_name][metric]['better'] == 'higher': + if val >= self.best_metric_results[dataset_name][metric]['val']: + self.best_metric_results[dataset_name][metric]['val'] = val + self.best_metric_results[dataset_name][metric]['iter'] = current_iter + else: + if val <= self.best_metric_results[dataset_name][metric]['val']: + self.best_metric_results[dataset_name][metric]['val'] = val + self.best_metric_results[dataset_name][metric]['iter'] = current_iter + + def model_ema(self, decay=0.999): + net_g = self.get_bare_model(self.net_g) + + net_g_params = dict(net_g.named_parameters()) + net_g_ema_params = dict(self.net_g_ema.named_parameters()) + + for k in net_g_ema_params.keys(): + net_g_ema_params[k].data.mul_(decay).add_(net_g_params[k].data, alpha=1 - decay) + + def get_current_log(self): + return self.log_dict + + def model_to_device(self, net): + """Model to device. It also warps models with DistributedDataParallel + or DataParallel. + + Args: + net (nn.Module) + """ + net = net.to(self.device) + if self.opt['dist']: + find_unused_parameters = self.opt.get('find_unused_parameters', False) + net = DistributedDataParallel( + net, device_ids=[torch.cuda.current_device()], find_unused_parameters=find_unused_parameters) + elif self.opt['num_gpu'] > 1: + net = DataParallel(net) + return net + + def get_optimizer(self, optim_type, params, lr, **kwargs): + if optim_type == 'Adam': + optimizer = torch.optim.Adam(params, lr, **kwargs) + else: + raise NotImplementedError(f'optimizer {optim_type} is not supperted yet.') + return optimizer + + def setup_schedulers(self): + """Set up schedulers.""" + train_opt = self.opt['train'] + scheduler_type = train_opt['scheduler'].pop('type') + if scheduler_type in ['MultiStepLR', 'MultiStepRestartLR']: + for optimizer in self.optimizers: + self.schedulers.append(lr_scheduler.MultiStepRestartLR(optimizer, **train_opt['scheduler'])) + elif scheduler_type == 'CosineAnnealingRestartLR': + for optimizer in self.optimizers: + self.schedulers.append(lr_scheduler.CosineAnnealingRestartLR(optimizer, **train_opt['scheduler'])) + else: + raise NotImplementedError(f'Scheduler {scheduler_type} is not implemented yet.') + + def get_bare_model(self, net): + """Get bare model, especially under wrapping with + DistributedDataParallel or DataParallel. + """ + if isinstance(net, (DataParallel, DistributedDataParallel)): + net = net.module + return net + + @master_only + def print_network(self, net): + """Print the str and parameter number of a network. + + Args: + net (nn.Module) + """ + if isinstance(net, (DataParallel, DistributedDataParallel)): + net_cls_str = f'{net.__class__.__name__} - {net.module.__class__.__name__}' + else: + net_cls_str = f'{net.__class__.__name__}' + + net = self.get_bare_model(net) + net_str = str(net) + net_params = sum(map(lambda x: x.numel(), net.parameters())) + + logger = get_root_logger() + logger.info(f'Network: {net_cls_str}, with parameters: {net_params:,d}') + logger.info(net_str) + + def _set_lr(self, lr_groups_l): + """Set learning rate for warmup. + + Args: + lr_groups_l (list): List for lr_groups, each for an optimizer. + """ + for optimizer, lr_groups in zip(self.optimizers, lr_groups_l): + for param_group, lr in zip(optimizer.param_groups, lr_groups): + param_group['lr'] = lr + + def _get_init_lr(self): + """Get the initial lr, which is set by the scheduler. + """ + init_lr_groups_l = [] + for optimizer in self.optimizers: + init_lr_groups_l.append([v['initial_lr'] for v in optimizer.param_groups]) + return init_lr_groups_l + + def update_learning_rate(self, current_iter, warmup_iter=-1): + """Update learning rate. + + Args: + current_iter (int): Current iteration. + warmup_iter (int): Warmup iter numbers. -1 for no warmup. + Default: -1. + """ + if current_iter > 1: + for scheduler in self.schedulers: + scheduler.step() + # set up warm-up learning rate + if current_iter < warmup_iter: + # get initial lr for each group + init_lr_g_l = self._get_init_lr() + # modify warming-up learning rates + # currently only support linearly warm up + warm_up_lr_l = [] + for init_lr_g in init_lr_g_l: + warm_up_lr_l.append([v / warmup_iter * current_iter for v in init_lr_g]) + # set learning rate + self._set_lr(warm_up_lr_l) + + def get_current_learning_rate(self): + return [param_group['lr'] for param_group in self.optimizers[0].param_groups] + + @master_only + def save_network(self, net, net_label, current_iter, param_key='params'): + """Save networks. + + Args: + net (nn.Module | list[nn.Module]): Network(s) to be saved. + net_label (str): Network label. + current_iter (int): Current iter number. + param_key (str | list[str]): The parameter key(s) to save network. + Default: 'params'. + """ + if current_iter == -1: + current_iter = 'latest' + save_filename = f'{net_label}_{current_iter}.pth' + save_path = os.path.join(self.opt['path']['models'], save_filename) + + net = net if isinstance(net, list) else [net] + param_key = param_key if isinstance(param_key, list) else [param_key] + assert len(net) == len(param_key), 'The lengths of net and param_key should be the same.' + + save_dict = {} + for net_, param_key_ in zip(net, param_key): + net_ = self.get_bare_model(net_) + state_dict = net_.state_dict() + for key, param in state_dict.items(): + if key.startswith('module.'): # remove unnecessary 'module.' + key = key[7:] + state_dict[key] = param.cpu() + save_dict[param_key_] = state_dict + + # avoid occasional writing errors + retry = 3 + while retry > 0: + try: + torch.save(save_dict, save_path) + except Exception as e: + logger = get_root_logger() + logger.warning(f'Save model error: {e}, remaining retry times: {retry - 1}') + time.sleep(1) + else: + break + finally: + retry -= 1 + if retry == 0: + logger.warning(f'Still cannot save {save_path}. Just ignore it.') + # raise IOError(f'Cannot save {save_path}.') + + def _print_different_keys_loading(self, crt_net, load_net, strict=True): + """Print keys with different name or different size when loading models. + + 1. Print keys with different names. + 2. If strict=False, print the same key but with different tensor size. + It also ignore these keys with different sizes (not load). + + Args: + crt_net (torch model): Current network. + load_net (dict): Loaded network. + strict (bool): Whether strictly loaded. Default: True. + """ + crt_net = self.get_bare_model(crt_net) + crt_net = crt_net.state_dict() + crt_net_keys = set(crt_net.keys()) + load_net_keys = set(load_net.keys()) + + logger = get_root_logger() + if crt_net_keys != load_net_keys: + logger.warning('Current net - loaded net:') + for v in sorted(list(crt_net_keys - load_net_keys)): + logger.warning(f' {v}') + logger.warning('Loaded net - current net:') + for v in sorted(list(load_net_keys - crt_net_keys)): + logger.warning(f' {v}') + + # check the size for the same keys + if not strict: + common_keys = crt_net_keys & load_net_keys + for k in common_keys: + if crt_net[k].size() != load_net[k].size(): + logger.warning(f'Size different, ignore [{k}]: crt_net: ' + f'{crt_net[k].shape}; load_net: {load_net[k].shape}') + load_net[k + '.ignore'] = load_net.pop(k) + + def load_network(self, net, load_path, strict=True, param_key='params'): + """Load network. + + Args: + load_path (str): The path of networks to be loaded. + net (nn.Module): Network. + strict (bool): Whether strictly loaded. + param_key (str): The parameter key of loaded network. If set to + None, use the root 'path'. + Default: 'params'. + """ + logger = get_root_logger() + net = self.get_bare_model(net) + load_net = torch.load(load_path, map_location=lambda storage, loc: storage) + if param_key is not None: + if param_key not in load_net and 'params' in load_net: + param_key = 'params' + logger.info('Loading: params_ema does not exist, use params.') + load_net = load_net[param_key] + logger.info(f'Loading {net.__class__.__name__} model from {load_path}, with param key: [{param_key}].') + # remove unnecessary 'module.' + for k, v in deepcopy(load_net).items(): + if k.startswith('module.'): + load_net[k[7:]] = v + load_net.pop(k) + self._print_different_keys_loading(net, load_net, strict) + net.load_state_dict(load_net, strict=strict) + + @master_only + def save_training_state(self, epoch, current_iter): + """Save training states during training, which will be used for + resuming. + + Args: + epoch (int): Current epoch. + current_iter (int): Current iteration. + """ + if current_iter != -1: + state = {'epoch': epoch, 'iter': current_iter, 'optimizers': [], 'schedulers': []} + for o in self.optimizers: + state['optimizers'].append(o.state_dict()) + for s in self.schedulers: + state['schedulers'].append(s.state_dict()) + save_filename = f'{current_iter}.state' + save_path = os.path.join(self.opt['path']['training_states'], save_filename) + + # avoid occasional writing errors + retry = 3 + while retry > 0: + try: + torch.save(state, save_path) + except Exception as e: + logger = get_root_logger() + logger.warning(f'Save training state error: {e}, remaining retry times: {retry - 1}') + time.sleep(1) + else: + break + finally: + retry -= 1 + if retry == 0: + logger.warning(f'Still cannot save {save_path}. Just ignore it.') + # raise IOError(f'Cannot save {save_path}.') + + def resume_training(self, resume_state): + """Reload the optimizers and schedulers for resumed training. + + Args: + resume_state (dict): Resume state. + """ + resume_optimizers = resume_state['optimizers'] + resume_schedulers = resume_state['schedulers'] + assert len(resume_optimizers) == len(self.optimizers), 'Wrong lengths of optimizers' + assert len(resume_schedulers) == len(self.schedulers), 'Wrong lengths of schedulers' + for i, o in enumerate(resume_optimizers): + self.optimizers[i].load_state_dict(o) + for i, s in enumerate(resume_schedulers): + self.schedulers[i].load_state_dict(s) + + def reduce_loss_dict(self, loss_dict): + """reduce loss dict. + + In distributed training, it averages the losses among different GPUs . + + Args: + loss_dict (OrderedDict): Loss dict. + """ + with torch.no_grad(): + if self.opt['dist']: + keys = [] + losses = [] + for name, value in loss_dict.items(): + keys.append(name) + losses.append(value) + losses = torch.stack(losses, 0) + torch.distributed.reduce(losses, dst=0) + if self.opt['rank'] == 0: + losses /= self.opt['world_size'] + loss_dict = {key: loss for key, loss in zip(keys, losses)} + + log_dict = OrderedDict() + for name, value in loss_dict.items(): + log_dict[name] = value.mean().item() + + return log_dict diff --git a/opensora/models/super_resolution/basicsr/models/lr_scheduler.py b/opensora/models/super_resolution/basicsr/models/lr_scheduler.py new file mode 100644 index 0000000000000000000000000000000000000000..11e1c6c7a74f5233accda52370f92681d3d3cecf --- /dev/null +++ b/opensora/models/super_resolution/basicsr/models/lr_scheduler.py @@ -0,0 +1,96 @@ +import math +from collections import Counter +from torch.optim.lr_scheduler import _LRScheduler + + +class MultiStepRestartLR(_LRScheduler): + """ MultiStep with restarts learning rate scheme. + + Args: + optimizer (torch.nn.optimizer): Torch optimizer. + milestones (list): Iterations that will decrease learning rate. + gamma (float): Decrease ratio. Default: 0.1. + restarts (list): Restart iterations. Default: [0]. + restart_weights (list): Restart weights at each restart iteration. + Default: [1]. + last_epoch (int): Used in _LRScheduler. Default: -1. + """ + + def __init__(self, optimizer, milestones, gamma=0.1, restarts=(0, ), restart_weights=(1, ), last_epoch=-1): + self.milestones = Counter(milestones) + self.gamma = gamma + self.restarts = restarts + self.restart_weights = restart_weights + assert len(self.restarts) == len(self.restart_weights), 'restarts and their weights do not match.' + super(MultiStepRestartLR, self).__init__(optimizer, last_epoch) + + def get_lr(self): + if self.last_epoch in self.restarts: + weight = self.restart_weights[self.restarts.index(self.last_epoch)] + return [group['initial_lr'] * weight for group in self.optimizer.param_groups] + if self.last_epoch not in self.milestones: + return [group['lr'] for group in self.optimizer.param_groups] + return [group['lr'] * self.gamma**self.milestones[self.last_epoch] for group in self.optimizer.param_groups] + + +def get_position_from_periods(iteration, cumulative_period): + """Get the position from a period list. + + It will return the index of the right-closest number in the period list. + For example, the cumulative_period = [100, 200, 300, 400], + if iteration == 50, return 0; + if iteration == 210, return 2; + if iteration == 300, return 2. + + Args: + iteration (int): Current iteration. + cumulative_period (list[int]): Cumulative period list. + + Returns: + int: The position of the right-closest number in the period list. + """ + for i, period in enumerate(cumulative_period): + if iteration <= period: + return i + + +class CosineAnnealingRestartLR(_LRScheduler): + """ Cosine annealing with restarts learning rate scheme. + + An example of config: + periods = [10, 10, 10, 10] + restart_weights = [1, 0.5, 0.5, 0.5] + eta_min=1e-7 + + It has four cycles, each has 10 iterations. At 10th, 20th, 30th, the + scheduler will restart with the weights in restart_weights. + + Args: + optimizer (torch.nn.optimizer): Torch optimizer. + periods (list): Period for each cosine anneling cycle. + restart_weights (list): Restart weights at each restart iteration. + Default: [1]. + eta_min (float): The minimum lr. Default: 0. + last_epoch (int): Used in _LRScheduler. Default: -1. + """ + + def __init__(self, optimizer, periods, restart_weights=(1, ), eta_min=0, last_epoch=-1): + self.periods = periods + self.restart_weights = restart_weights + self.eta_min = eta_min + assert (len(self.periods) == len( + self.restart_weights)), 'periods and restart_weights should have the same length.' + self.cumulative_period = [sum(self.periods[0:i + 1]) for i in range(0, len(self.periods))] + super(CosineAnnealingRestartLR, self).__init__(optimizer, last_epoch) + + def get_lr(self): + idx = get_position_from_periods(self.last_epoch, self.cumulative_period) + current_weight = self.restart_weights[idx] + nearest_restart = 0 if idx == 0 else self.cumulative_period[idx - 1] + current_period = self.periods[idx] + + return [ + self.eta_min + current_weight * 0.5 * (base_lr - self.eta_min) * + (1 + math.cos(math.pi * ((self.last_epoch - nearest_restart) / current_period))) + for base_lr in self.base_lrs + ] diff --git a/opensora/models/super_resolution/basicsr/models/rgt_model.py b/opensora/models/super_resolution/basicsr/models/rgt_model.py new file mode 100644 index 0000000000000000000000000000000000000000..f63b2351380c679b42cf21ffb0092566c5bd28cd --- /dev/null +++ b/opensora/models/super_resolution/basicsr/models/rgt_model.py @@ -0,0 +1,127 @@ +import torch +from torch.nn import functional as F + +from basicsr.utils.registry import MODEL_REGISTRY +from basicsr.models.sr_model import SRModel + + +@MODEL_REGISTRY.register() +class RGTModel(SRModel): + + def test(self): + self.use_chop = self.opt['val']['use_chop'] if 'use_chop' in self.opt['val'] else False + if not self.use_chop: + if hasattr(self, 'net_g_ema'): + self.net_g_ema.eval() + with torch.no_grad(): + self.output = self.net_g_ema(self.lq) + else: + self.net_g.eval() + with torch.no_grad(): + self.output = self.net_g(self.lq) + self.net_g.train() + + # test by partitioning + else: + _, C, h, w = self.lq.size() + split_token_h = h // 200 + 1 # number of horizontal cut sections + split_token_w = w // 200 + 1 # number of vertical cut sections + + patch_size_tmp_h = split_token_h + patch_size_tmp_w = split_token_w + + # padding + mod_pad_h, mod_pad_w = 0, 0 + if h % patch_size_tmp_h != 0: + mod_pad_h = patch_size_tmp_h - h % patch_size_tmp_h + if w % patch_size_tmp_w != 0: + mod_pad_w = patch_size_tmp_w - w % patch_size_tmp_w + + img = self.lq + img = torch.cat([img, torch.flip(img, [2])], 2)[:, :, :h+mod_pad_h, :] + img = torch.cat([img, torch.flip(img, [3])], 3)[:, :, :, :w+mod_pad_w] + + _, _, H, W = img.size() + split_h = H // split_token_h # height of each partition + split_w = W // split_token_w # width of each partition + + # overlapping + shave_h = 16 + shave_w = 16 + scale = self.opt.get('scale', 1) + ral = H // split_h + row = W // split_w + slices = [] # list of partition borders + for i in range(ral): + for j in range(row): + if i == 0 and i == ral - 1: + top = slice(i * split_h, (i + 1) * split_h) + elif i == 0: + top = slice(i*split_h, (i+1)*split_h+shave_h) + elif i == ral - 1: + top = slice(i*split_h-shave_h, (i+1)*split_h) + else: + top = slice(i*split_h-shave_h, (i+1)*split_h+shave_h) + if j == 0 and j == row - 1: + left = slice(j*split_w, (j+1)*split_w) + elif j == 0: + left = slice(j*split_w, (j+1)*split_w+shave_w) + elif j == row - 1: + left = slice(j*split_w-shave_w, (j+1)*split_w) + else: + left = slice(j*split_w-shave_w, (j+1)*split_w+shave_w) + temp = (top, left) + slices.append(temp) + img_chops = [] # list of partitions + for temp in slices: + top, left = temp + img_chops.append(img[..., top, left]) + if hasattr(self, 'net_g_ema'): + self.net_g_ema.eval() + with torch.no_grad(): + outputs = [] + for chop in img_chops: + out = self.net_g_ema(chop) # image processing of each partition + outputs.append(out) + _img = torch.zeros(1, C, H * scale, W * scale) + # merge + for i in range(ral): + for j in range(row): + top = slice(i * split_h * scale, (i + 1) * split_h * scale) + left = slice(j * split_w * scale, (j + 1) * split_w * scale) + if i == 0: + _top = slice(0, split_h * scale) + else: + _top = slice(shave_h*scale, (shave_h+split_h)*scale) + if j == 0: + _left = slice(0, split_w*scale) + else: + _left = slice(shave_w*scale, (shave_w+split_w)*scale) + _img[..., top, left] = outputs[i * row + j][..., _top, _left] + self.output = _img + else: + self.net_g.eval() + with torch.no_grad(): + outputs = [] + for chop in img_chops: + out = self.net_g(chop) # image processing of each partition + outputs.append(out) + _img = torch.zeros(1, C, H * scale, W * scale) + # merge + for i in range(ral): + for j in range(row): + top = slice(i * split_h * scale, (i + 1) * split_h * scale) + left = slice(j * split_w * scale, (j + 1) * split_w * scale) + if i == 0: + _top = slice(0, split_h * scale) + else: + _top = slice(shave_h * scale, (shave_h + split_h) * scale) + if j == 0: + _left = slice(0, split_w * scale) + else: + _left = slice(shave_w * scale, (shave_w + split_w) * scale) + _img[..., top, left] = outputs[i * row + j][..., _top, _left] + self.output = _img + self.net_g.train() + _, _, h, w = self.output.size() + self.output = self.output[:, :, 0:h - mod_pad_h * scale, 0:w - mod_pad_w * scale] diff --git a/opensora/models/super_resolution/basicsr/models/sr_model.py b/opensora/models/super_resolution/basicsr/models/sr_model.py new file mode 100644 index 0000000000000000000000000000000000000000..e076553c25cb21ffc9cb0786cb89a5e9348576ff --- /dev/null +++ b/opensora/models/super_resolution/basicsr/models/sr_model.py @@ -0,0 +1,235 @@ +import torch +from collections import OrderedDict +from os import path as osp +from tqdm import tqdm + +from basicsr.archs import build_network +from basicsr.losses import build_loss +from basicsr.metrics import calculate_metric +from basicsr.utils import get_root_logger, imwrite, tensor2img +from basicsr.utils.registry import MODEL_REGISTRY +from .base_model import BaseModel + + +@MODEL_REGISTRY.register() +class SRModel(BaseModel): + """Base SR model for single image super-resolution.""" + + def __init__(self, opt): + super(SRModel, self).__init__(opt) + + # define network + self.net_g = build_network(opt['network_g']) + self.net_g = self.model_to_device(self.net_g) + self.print_network(self.net_g) + + # load pretrained models + load_path = self.opt['path'].get('pretrain_network_g', None) + if load_path is not None: + param_key = self.opt['path'].get('param_key_g', 'params') + self.load_network(self.net_g, load_path, self.opt['path'].get('strict_load_g', True), param_key) + + if self.is_train: + self.init_training_settings() + + def init_training_settings(self): + self.net_g.train() + train_opt = self.opt['train'] + + self.ema_decay = train_opt.get('ema_decay', 0) + if self.ema_decay > 0: + logger = get_root_logger() + logger.info(f'Use Exponential Moving Average with decay: {self.ema_decay}') + # define network net_g with Exponential Moving Average (EMA) + # net_g_ema is used only for testing on one GPU and saving + # There is no need to wrap with DistributedDataParallel + self.net_g_ema = build_network(self.opt['network_g']).to(self.device) + # load pretrained model + load_path = self.opt['path'].get('pretrain_network_g', None) + if load_path is not None: + self.load_network(self.net_g_ema, load_path, self.opt['path'].get('strict_load_g', True), 'params_ema') + else: + self.model_ema(0) # copy net_g weight + self.net_g_ema.eval() + + # define losses + if train_opt.get('pixel_opt'): + self.cri_pix = build_loss(train_opt['pixel_opt']).to(self.device) + else: + self.cri_pix = None + + if train_opt.get('perceptual_opt'): + self.cri_perceptual = build_loss(train_opt['perceptual_opt']).to(self.device) + else: + self.cri_perceptual = None + + if self.cri_pix is None and self.cri_perceptual is None: + raise ValueError('Both pixel and perceptual losses are None.') + + # set up optimizers and schedulers + self.setup_optimizers() + self.setup_schedulers() + + def setup_optimizers(self): + train_opt = self.opt['train'] + optim_params = [] + for k, v in self.net_g.named_parameters(): + if v.requires_grad: + optim_params.append(v) + else: + logger = get_root_logger() + logger.warning(f'Params {k} will not be optimized.') + + optim_type = train_opt['optim_g'].pop('type') + self.optimizer_g = self.get_optimizer(optim_type, optim_params, **train_opt['optim_g']) + self.optimizers.append(self.optimizer_g) + + def feed_data(self, data): + self.lq = data['lq'].to(self.device) + if 'gt' in data: + self.gt = data['gt'].to(self.device) + + def optimize_parameters(self, current_iter): + self.optimizer_g.zero_grad() + self.output = self.net_g(self.lq) + + l_total = 0 + loss_dict = OrderedDict() + # pixel loss + if self.cri_pix: + l_pix = self.cri_pix(self.output, self.gt) + l_total += l_pix + loss_dict['l_pix'] = l_pix + # perceptual loss + if self.cri_perceptual: + l_percep, l_style = self.cri_perceptual(self.output, self.gt) + if l_percep is not None: + l_total += l_percep + loss_dict['l_percep'] = l_percep + if l_style is not None: + l_total += l_style + loss_dict['l_style'] = l_style + + l_total.backward() + self.optimizer_g.step() + + self.log_dict = self.reduce_loss_dict(loss_dict) + + if self.ema_decay > 0: + self.model_ema(decay=self.ema_decay) + + def test(self): + if hasattr(self, 'net_g_ema'): + self.net_g_ema.eval() + with torch.no_grad(): + self.output = self.net_g_ema(self.lq) + else: + self.net_g.eval() + with torch.no_grad(): + self.output = self.net_g(self.lq) + self.net_g.train() + + def dist_validation(self, dataloader, current_iter, tb_logger, save_img): + if self.opt['rank'] == 0: + self.nondist_validation(dataloader, current_iter, tb_logger, save_img) + + def nondist_validation(self, dataloader, current_iter, tb_logger, save_img): + dataset_name = dataloader.dataset.opt['name'] + with_metrics = self.opt['val'].get('metrics') is not None + use_pbar = self.opt['val'].get('pbar', False) + # print(with_metrics,use_pbar) + + if with_metrics: + if not hasattr(self, 'metric_results'): # only execute in the first run + self.metric_results = {metric: 0 for metric in self.opt['val']['metrics'].keys()} + # initialize the best metric results for each dataset_name (supporting multiple validation datasets) + self._initialize_best_metric_results(dataset_name) + # zero self.metric_results + if with_metrics: + self.metric_results = {metric: 0 for metric in self.metric_results} + + metric_data = dict() + if use_pbar: + pbar = tqdm(total=len(dataloader), unit='image') + + for idx, val_data in enumerate(dataloader): + img_name = osp.splitext(osp.basename(val_data['lq_path'][0]))[0] + self.feed_data(val_data) + self.test() + + # this is img data + visuals = self.get_current_visuals() + sr_img = tensor2img([visuals['result']]) + + metric_data['img'] = sr_img + if 'gt' in visuals: + gt_img = tensor2img([visuals['gt']]) + metric_data['img2'] = gt_img + del self.gt + + # tentative for out of GPU memory + del self.lq + del self.output + torch.cuda.empty_cache() + + if save_img: + if self.opt['is_train']: + save_img_path = osp.join(self.opt['path']['visualization'], img_name, + f'{img_name}_{current_iter}.png') + else: + if self.opt['val']['suffix']: + save_img_path = osp.join(self.opt['path']['visualization'], dataset_name, + f'{img_name}_{self.opt["val"]["suffix"]}.png') + else: + save_img_path = osp.join(self.opt['path']['visualization'], dataset_name, + f'{img_name}_{self.opt["name"]}.png') + # save img + imwrite(sr_img, save_img_path) + + if with_metrics: + # calculate metrics + for name, opt_ in self.opt['val']['metrics'].items(): + self.metric_results[name] += calculate_metric(metric_data, opt_) + if use_pbar: + pbar.update(1) + pbar.set_description(f'Test {img_name}') + if use_pbar: + pbar.close() + + if with_metrics: + for metric in self.metric_results.keys(): + self.metric_results[metric] /= (idx + 1) + # update the best metric result + self._update_best_metric_result(dataset_name, metric, self.metric_results[metric], current_iter) + + self._log_validation_metric_values(current_iter, dataset_name, tb_logger) + + def _log_validation_metric_values(self, current_iter, dataset_name, tb_logger): + log_str = f'Validation {dataset_name}\n' + for metric, value in self.metric_results.items(): + log_str += f'\t # {metric}: {value:.4f}' + if hasattr(self, 'best_metric_results'): + log_str += (f'\tBest: {self.best_metric_results[dataset_name][metric]["val"]:.4f} @ ' + f'{self.best_metric_results[dataset_name][metric]["iter"]} iter') + log_str += '\n' + + logger = get_root_logger() + logger.info(log_str) + if tb_logger: + for metric, value in self.metric_results.items(): + tb_logger.add_scalar(f'metrics/{dataset_name}/{metric}', value, current_iter) + + def get_current_visuals(self): + out_dict = OrderedDict() + out_dict['lq'] = self.lq.detach().cpu() + out_dict['result'] = self.output.detach().cpu() + if hasattr(self, 'gt'): + out_dict['gt'] = self.gt.detach().cpu() + return out_dict + + def save(self, epoch, current_iter): + if hasattr(self, 'net_g_ema'): + self.save_network([self.net_g, self.net_g_ema], 'net_g', current_iter, param_key=['params', 'params_ema']) + else: + self.save_network(self.net_g, 'net_g', current_iter) + self.save_training_state(epoch, current_iter) diff --git a/opensora/models/super_resolution/basicsr/test_img.py b/opensora/models/super_resolution/basicsr/test_img.py new file mode 100644 index 0000000000000000000000000000000000000000..4ac2930178acd25aec4b69127f5a3ce3ebac1345 --- /dev/null +++ b/opensora/models/super_resolution/basicsr/test_img.py @@ -0,0 +1,48 @@ +import logging +import torch +from os import path as osp +from basicsr.data import build_dataloader, build_dataset +from basicsr.models import build_model +from basicsr.utils import get_root_logger, get_time_str, make_exp_dirs +from basicsr.utils.options import dict2str, parse_options + + +def image_sr(args): + # parse options, set distributed setting, set ramdom seed + opt, _ = parse_options(args.root_path, is_train=False) + torch.backends.cudnn.benchmark = True + # torch.backends.cudnn.deterministic = True + + # create test dataset and dataloader + test_loaders = [] + for _, dataset_opt in sorted(opt['datasets'].items()): + dataset_opt['dataroot_lq'] = osp.join(args.output_dir, f'temp_LR') + if args.SR == 'x4': + opt['upscale'] = opt['network_g']['upscale'] = 4 + opt['val']['suffix'] = 'x4' + opt['path']['pretrain_network_g'] = osp.join(args.root_path, f'experiments/pretrained_models/RGT_x4.pth') + if args.SR == 'x2': + opt['upscale'] = opt['network_g']['upscale'] = 2 + opt['val']['suffix'] = 'x2' + + test_set = build_dataset(dataset_opt) + test_loader = build_dataloader( + test_set, dataset_opt, num_gpu=opt['num_gpu'], dist=opt['dist'], sampler=None, seed=opt['manual_seed']) + test_loaders.append(test_loader) + + opt['path']['pretrain_network_g'] = args.ckpt_path + opt['val']['use_chop'] = args.use_chop + opt['path']['visualization'] = osp.join(args.output_dir, f'temp_results') + opt['path']['results_root'] = osp.join(args.output_dir, f'temp_results') + + # create model + model = build_model(opt) + for test_loader in test_loaders: + test_set_name = test_loader.dataset.opt['name'] + model.validation(test_loader, current_iter=opt['name'], tb_logger=None, save_img=opt['val']['save_img']) + + +if __name__ == '__main__': + root_path = osp.abspath(osp.join(__file__, osp.pardir, osp.pardir)) + # print(root_path) + # image_sr(root_path) diff --git a/opensora/models/super_resolution/basicsr/utils/__init__.py b/opensora/models/super_resolution/basicsr/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4860d526119edc892ea348ae212ad4ed65cd0019 --- /dev/null +++ b/opensora/models/super_resolution/basicsr/utils/__init__.py @@ -0,0 +1,30 @@ +from .file_client import FileClient +from .img_util import crop_border, imfrombytes, img2tensor, imwrite, tensor2img +from .logger import AvgTimer, MessageLogger, get_env_info, get_root_logger, init_tb_logger, init_wandb_logger +from .misc import check_resume, get_time_str, make_exp_dirs, mkdir_and_rename, scandir, set_random_seed, sizeof_fmt + +__all__ = [ + # file_client.py + 'FileClient', + # img_util.py + 'img2tensor', + 'tensor2img', + 'imfrombytes', + 'imwrite', + 'crop_border', + # logger.py + 'MessageLogger', + 'AvgTimer', + 'init_tb_logger', + 'init_wandb_logger', + 'get_root_logger', + 'get_env_info', + # misc.py + 'set_random_seed', + 'get_time_str', + 'mkdir_and_rename', + 'make_exp_dirs', + 'scandir', + 'check_resume', + 'sizeof_fmt', +] diff --git a/opensora/models/super_resolution/basicsr/utils/dist_util.py b/opensora/models/super_resolution/basicsr/utils/dist_util.py new file mode 100644 index 0000000000000000000000000000000000000000..0fab887b2cb1ce8533d2e8fdee72ae0c24f68fd0 --- /dev/null +++ b/opensora/models/super_resolution/basicsr/utils/dist_util.py @@ -0,0 +1,82 @@ +# Modified from https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/dist_utils.py # noqa: E501 +import functools +import os +import subprocess +import torch +import torch.distributed as dist +import torch.multiprocessing as mp + + +def init_dist(launcher, backend='nccl', **kwargs): + if mp.get_start_method(allow_none=True) is None: + mp.set_start_method('spawn') + if launcher == 'pytorch': + _init_dist_pytorch(backend, **kwargs) + elif launcher == 'slurm': + _init_dist_slurm(backend, **kwargs) + else: + raise ValueError(f'Invalid launcher type: {launcher}') + + +def _init_dist_pytorch(backend, **kwargs): + rank = int(os.environ['RANK']) + num_gpus = torch.cuda.device_count() + torch.cuda.set_device(rank % num_gpus) + dist.init_process_group(backend=backend, **kwargs) + + +def _init_dist_slurm(backend, port=None): + """Initialize slurm distributed training environment. + + If argument ``port`` is not specified, then the master port will be system + environment variable ``MASTER_PORT``. If ``MASTER_PORT`` is not in system + environment variable, then a default port ``29500`` will be used. + + Args: + backend (str): Backend of torch.distributed. + port (int, optional): Master port. Defaults to None. + """ + proc_id = int(os.environ['SLURM_PROCID']) + ntasks = int(os.environ['SLURM_NTASKS']) + node_list = os.environ['SLURM_NODELIST'] + num_gpus = torch.cuda.device_count() + torch.cuda.set_device(proc_id % num_gpus) + addr = subprocess.getoutput(f'scontrol show hostname {node_list} | head -n1') + # specify master port + if port is not None: + os.environ['MASTER_PORT'] = str(port) + elif 'MASTER_PORT' in os.environ: + pass # use MASTER_PORT in the environment variable + else: + # 29500 is torch.distributed default port + os.environ['MASTER_PORT'] = '29500' + os.environ['MASTER_ADDR'] = addr + os.environ['WORLD_SIZE'] = str(ntasks) + os.environ['LOCAL_RANK'] = str(proc_id % num_gpus) + os.environ['RANK'] = str(proc_id) + dist.init_process_group(backend=backend) + + +def get_dist_info(): + if dist.is_available(): + initialized = dist.is_initialized() + else: + initialized = False + if initialized: + rank = dist.get_rank() + world_size = dist.get_world_size() + else: + rank = 0 + world_size = 1 + return rank, world_size + + +def master_only(func): + + @functools.wraps(func) + def wrapper(*args, **kwargs): + rank, _ = get_dist_info() + if rank == 0: + return func(*args, **kwargs) + + return wrapper diff --git a/opensora/models/super_resolution/basicsr/utils/file_client.py b/opensora/models/super_resolution/basicsr/utils/file_client.py new file mode 100644 index 0000000000000000000000000000000000000000..89d83ab9e0d4314f8cdf2393908a561c6d1dca92 --- /dev/null +++ b/opensora/models/super_resolution/basicsr/utils/file_client.py @@ -0,0 +1,167 @@ +# Modified from https://github.com/open-mmlab/mmcv/blob/master/mmcv/fileio/file_client.py # noqa: E501 +from abc import ABCMeta, abstractmethod + + +class BaseStorageBackend(metaclass=ABCMeta): + """Abstract class of storage backends. + + All backends need to implement two apis: ``get()`` and ``get_text()``. + ``get()`` reads the file as a byte stream and ``get_text()`` reads the file + as texts. + """ + + @abstractmethod + def get(self, filepath): + pass + + @abstractmethod + def get_text(self, filepath): + pass + + +class MemcachedBackend(BaseStorageBackend): + """Memcached storage backend. + + Attributes: + server_list_cfg (str): Config file for memcached server list. + client_cfg (str): Config file for memcached client. + sys_path (str | None): Additional path to be appended to `sys.path`. + Default: None. + """ + + def __init__(self, server_list_cfg, client_cfg, sys_path=None): + if sys_path is not None: + import sys + sys.path.append(sys_path) + try: + import mc + except ImportError: + raise ImportError('Please install memcached to enable MemcachedBackend.') + + self.server_list_cfg = server_list_cfg + self.client_cfg = client_cfg + self._client = mc.MemcachedClient.GetInstance(self.server_list_cfg, self.client_cfg) + # mc.pyvector servers as a point which points to a memory cache + self._mc_buffer = mc.pyvector() + + def get(self, filepath): + filepath = str(filepath) + import mc + self._client.Get(filepath, self._mc_buffer) + value_buf = mc.ConvertBuffer(self._mc_buffer) + return value_buf + + def get_text(self, filepath): + raise NotImplementedError + + +class HardDiskBackend(BaseStorageBackend): + """Raw hard disks storage backend.""" + + def get(self, filepath): + filepath = str(filepath) + with open(filepath, 'rb') as f: + value_buf = f.read() + return value_buf + + def get_text(self, filepath): + filepath = str(filepath) + with open(filepath, 'r') as f: + value_buf = f.read() + return value_buf + + +class LmdbBackend(BaseStorageBackend): + """Lmdb storage backend. + + Args: + db_paths (str | list[str]): Lmdb database paths. + client_keys (str | list[str]): Lmdb client keys. Default: 'default'. + readonly (bool, optional): Lmdb environment parameter. If True, + disallow any write operations. Default: True. + lock (bool, optional): Lmdb environment parameter. If False, when + concurrent access occurs, do not lock the database. Default: False. + readahead (bool, optional): Lmdb environment parameter. If False, + disable the OS filesystem readahead mechanism, which may improve + random read performance when a database is larger than RAM. + Default: False. + + Attributes: + db_paths (list): Lmdb database path. + _client (list): A list of several lmdb envs. + """ + + def __init__(self, db_paths, client_keys='default', readonly=True, lock=False, readahead=False, **kwargs): + try: + import lmdb + except ImportError: + raise ImportError('Please install lmdb to enable LmdbBackend.') + + if isinstance(client_keys, str): + client_keys = [client_keys] + + if isinstance(db_paths, list): + self.db_paths = [str(v) for v in db_paths] + elif isinstance(db_paths, str): + self.db_paths = [str(db_paths)] + assert len(client_keys) == len(self.db_paths), ('client_keys and db_paths should have the same length, ' + f'but received {len(client_keys)} and {len(self.db_paths)}.') + + self._client = {} + for client, path in zip(client_keys, self.db_paths): + self._client[client] = lmdb.open(path, readonly=readonly, lock=lock, readahead=readahead, **kwargs) + + def get(self, filepath, client_key): + """Get values according to the filepath from one lmdb named client_key. + + Args: + filepath (str | obj:`Path`): Here, filepath is the lmdb key. + client_key (str): Used for distinguishing different lmdb envs. + """ + filepath = str(filepath) + assert client_key in self._client, (f'client_key {client_key} is not in lmdb clients.') + client = self._client[client_key] + with client.begin(write=False) as txn: + value_buf = txn.get(filepath.encode('ascii')) + return value_buf + + def get_text(self, filepath): + raise NotImplementedError + + +class FileClient(object): + """A general file client to access files in different backend. + + The client loads a file or text in a specified backend from its path + and return it as a binary file. it can also register other backend + accessor with a given name and backend class. + + Attributes: + backend (str): The storage backend type. Options are "disk", + "memcached" and "lmdb". + client (:obj:`BaseStorageBackend`): The backend object. + """ + + _backends = { + 'disk': HardDiskBackend, + 'memcached': MemcachedBackend, + 'lmdb': LmdbBackend, + } + + def __init__(self, backend='disk', **kwargs): + if backend not in self._backends: + raise ValueError(f'Backend {backend} is not supported. Currently supported ones' + f' are {list(self._backends.keys())}') + self.backend = backend + self.client = self._backends[backend](**kwargs) + + def get(self, filepath, client_key='default'): + # client_key is used only for lmdb, where different fileclients have + # different lmdb environments. + if self.backend == 'lmdb': + return self.client.get(filepath, client_key) + else: + return self.client.get(filepath) + + def get_text(self, filepath): + return self.client.get_text(filepath) diff --git a/opensora/models/super_resolution/basicsr/utils/img_util.py b/opensora/models/super_resolution/basicsr/utils/img_util.py new file mode 100644 index 0000000000000000000000000000000000000000..3a5f1da0911d9b12f9c6164df6c6e14e3c1aef88 --- /dev/null +++ b/opensora/models/super_resolution/basicsr/utils/img_util.py @@ -0,0 +1,172 @@ +import cv2 +import math +import numpy as np +import os +import torch +from torchvision.utils import make_grid + + +def img2tensor(imgs, bgr2rgb=True, float32=True): + """Numpy array to tensor. + + Args: + imgs (list[ndarray] | ndarray): Input images. + bgr2rgb (bool): Whether to change bgr to rgb. + float32 (bool): Whether to change to float32. + + Returns: + list[tensor] | tensor: Tensor images. If returned results only have + one element, just return tensor. + """ + + def _totensor(img, bgr2rgb, float32): + if img.shape[2] == 3 and bgr2rgb: + if img.dtype == 'float64': + img = img.astype('float32') + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + img = torch.from_numpy(img.transpose(2, 0, 1)) + if float32: + img = img.float() + return img + + if isinstance(imgs, list): + return [_totensor(img, bgr2rgb, float32) for img in imgs] + else: + return _totensor(imgs, bgr2rgb, float32) + + +def tensor2img(tensor, rgb2bgr=True, out_type=np.uint8, min_max=(0, 1)): + """Convert torch Tensors into image numpy arrays. + + After clamping to [min, max], values will be normalized to [0, 1]. + + Args: + tensor (Tensor or list[Tensor]): Accept shapes: + 1) 4D mini-batch Tensor of shape (B x 3/1 x H x W); + 2) 3D Tensor of shape (3/1 x H x W); + 3) 2D Tensor of shape (H x W). + Tensor channel should be in RGB order. + rgb2bgr (bool): Whether to change rgb to bgr. + out_type (numpy type): output types. If ``np.uint8``, transform outputs + to uint8 type with range [0, 255]; otherwise, float type with + range [0, 1]. Default: ``np.uint8``. + min_max (tuple[int]): min and max values for clamp. + + Returns: + (Tensor or list): 3D ndarray of shape (H x W x C) OR 2D ndarray of + shape (H x W). The channel order is BGR. + """ + if not (torch.is_tensor(tensor) or (isinstance(tensor, list) and all(torch.is_tensor(t) for t in tensor))): + raise TypeError(f'tensor or list of tensors expected, got {type(tensor)}') + + if torch.is_tensor(tensor): + tensor = [tensor] + result = [] + for _tensor in tensor: + _tensor = _tensor.squeeze(0).float().detach().cpu().clamp_(*min_max) + _tensor = (_tensor - min_max[0]) / (min_max[1] - min_max[0]) + + n_dim = _tensor.dim() + if n_dim == 4: + img_np = make_grid(_tensor, nrow=int(math.sqrt(_tensor.size(0))), normalize=False).numpy() + img_np = img_np.transpose(1, 2, 0) + if rgb2bgr: + img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR) + elif n_dim == 3: + img_np = _tensor.numpy() + img_np = img_np.transpose(1, 2, 0) + if img_np.shape[2] == 1: # gray image + img_np = np.squeeze(img_np, axis=2) + else: + if rgb2bgr: + img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR) + elif n_dim == 2: + img_np = _tensor.numpy() + else: + raise TypeError(f'Only support 4D, 3D or 2D tensor. But received with dimension: {n_dim}') + if out_type == np.uint8: + # Unlike MATLAB, numpy.unit8() WILL NOT round by default. + img_np = (img_np * 255.0).round() + img_np = img_np.astype(out_type) + result.append(img_np) + if len(result) == 1: + result = result[0] + return result + + +def tensor2img_fast(tensor, rgb2bgr=True, min_max=(0, 1)): + """This implementation is slightly faster than tensor2img. + It now only supports torch tensor with shape (1, c, h, w). + + Args: + tensor (Tensor): Now only support torch tensor with (1, c, h, w). + rgb2bgr (bool): Whether to change rgb to bgr. Default: True. + min_max (tuple[int]): min and max values for clamp. + """ + output = tensor.squeeze(0).detach().clamp_(*min_max).permute(1, 2, 0) + output = (output - min_max[0]) / (min_max[1] - min_max[0]) * 255 + output = output.type(torch.uint8).cpu().numpy() + if rgb2bgr: + output = cv2.cvtColor(output, cv2.COLOR_RGB2BGR) + return output + + +def imfrombytes(content, flag='color', float32=False): + """Read an image from bytes. + + Args: + content (bytes): Image bytes got from files or other streams. + flag (str): Flags specifying the color type of a loaded image, + candidates are `color`, `grayscale` and `unchanged`. + float32 (bool): Whether to change to float32., If True, will also norm + to [0, 1]. Default: False. + + Returns: + ndarray: Loaded image array. + """ + img_np = np.frombuffer(content, np.uint8) + imread_flags = {'color': cv2.IMREAD_COLOR, 'grayscale': cv2.IMREAD_GRAYSCALE, 'unchanged': cv2.IMREAD_UNCHANGED} + img = cv2.imdecode(img_np, imread_flags[flag]) + if float32: + img = img.astype(np.float32) / 255. + return img + + +def imwrite(img, file_path, params=None, auto_mkdir=True): + """Write image to file. + + Args: + img (ndarray): Image array to be written. + file_path (str): Image file path. + params (None or list): Same as opencv's :func:`imwrite` interface. + auto_mkdir (bool): If the parent folder of `file_path` does not exist, + whether to create it automatically. + + Returns: + bool: Successful or not. + """ + if auto_mkdir: + dir_name = os.path.abspath(os.path.dirname(file_path)) + os.makedirs(dir_name, exist_ok=True) + ok = cv2.imwrite(file_path, img, params) + if not ok: + raise IOError('Failed in writing images.') + + +def crop_border(imgs, crop_border): + """Crop borders of images. + + Args: + imgs (list[ndarray] | ndarray): Images with shape (h, w, c). + crop_border (int): Crop border for each end of height and weight. + + Returns: + list[ndarray]: Cropped images. + """ + if crop_border == 0: + return imgs + else: + if isinstance(imgs, list): + return [v[crop_border:-crop_border, crop_border:-crop_border, ...] for v in imgs] + else: + return imgs[crop_border:-crop_border, crop_border:-crop_border, ...] diff --git a/opensora/models/super_resolution/basicsr/utils/logger.py b/opensora/models/super_resolution/basicsr/utils/logger.py new file mode 100644 index 0000000000000000000000000000000000000000..73553dc664781a061737e94880ea1c6788c09043 --- /dev/null +++ b/opensora/models/super_resolution/basicsr/utils/logger.py @@ -0,0 +1,213 @@ +import datetime +import logging +import time + +from .dist_util import get_dist_info, master_only + +initialized_logger = {} + + +class AvgTimer(): + + def __init__(self, window=200): + self.window = window # average window + self.current_time = 0 + self.total_time = 0 + self.count = 0 + self.avg_time = 0 + self.start() + + def start(self): + self.start_time = self.tic = time.time() + + def record(self): + self.count += 1 + self.toc = time.time() + self.current_time = self.toc - self.tic + self.total_time += self.current_time + # calculate average time + self.avg_time = self.total_time / self.count + + # reset + if self.count > self.window: + self.count = 0 + self.total_time = 0 + + self.tic = time.time() + + def get_current_time(self): + return self.current_time + + def get_avg_time(self): + return self.avg_time + + +class MessageLogger(): + """Message logger for printing. + + Args: + opt (dict): Config. It contains the following keys: + name (str): Exp name. + logger (dict): Contains 'print_freq' (str) for logger interval. + train (dict): Contains 'total_iter' (int) for total iters. + use_tb_logger (bool): Use tensorboard logger. + start_iter (int): Start iter. Default: 1. + tb_logger (obj:`tb_logger`): Tensorboard logger. Default: None. + """ + + def __init__(self, opt, start_iter=1, tb_logger=None): + self.exp_name = opt['name'] + self.interval = opt['logger']['print_freq'] + self.start_iter = start_iter + self.max_iters = opt['train']['total_iter'] + self.use_tb_logger = opt['logger']['use_tb_logger'] + self.tb_logger = tb_logger + self.start_time = time.time() + self.logger = get_root_logger() + + def reset_start_time(self): + self.start_time = time.time() + + @master_only + def __call__(self, log_vars): + """Format logging message. + + Args: + log_vars (dict): It contains the following keys: + epoch (int): Epoch number. + iter (int): Current iter. + lrs (list): List for learning rates. + + time (float): Iter time. + data_time (float): Data time for each iter. + """ + # epoch, iter, learning rates + epoch = log_vars.pop('epoch') + current_iter = log_vars.pop('iter') + lrs = log_vars.pop('lrs') + + message = (f'[{self.exp_name[:5]}..][epoch:{epoch:3d}, iter:{current_iter:8,d}, lr:(') + for v in lrs: + message += f'{v:.3e},' + message += ')] ' + + # time and estimated time + if 'time' in log_vars.keys(): + iter_time = log_vars.pop('time') + data_time = log_vars.pop('data_time') + + total_time = time.time() - self.start_time + time_sec_avg = total_time / (current_iter - self.start_iter + 1) + eta_sec = time_sec_avg * (self.max_iters - current_iter - 1) + eta_str = str(datetime.timedelta(seconds=int(eta_sec))) + message += f'[eta: {eta_str}, ' + message += f'time (data): {iter_time:.3f} ({data_time:.3f})] ' + + # other items, especially losses + for k, v in log_vars.items(): + message += f'{k}: {v:.4e} ' + # tensorboard logger + if self.use_tb_logger and 'debug' not in self.exp_name: + if k.startswith('l_'): + self.tb_logger.add_scalar(f'losses/{k}', v, current_iter) + else: + self.tb_logger.add_scalar(k, v, current_iter) + self.logger.info(message) + + +@master_only +def init_tb_logger(log_dir): + from torch.utils.tensorboard import SummaryWriter + tb_logger = SummaryWriter(log_dir=log_dir) + return tb_logger + + +@master_only +def init_wandb_logger(opt): + """We now only use wandb to sync tensorboard log.""" + import wandb + logger = get_root_logger() + + project = opt['logger']['wandb']['project'] + resume_id = opt['logger']['wandb'].get('resume_id') + if resume_id: + wandb_id = resume_id + resume = 'allow' + logger.warning(f'Resume wandb logger with id={wandb_id}.') + else: + wandb_id = wandb.util.generate_id() + resume = 'never' + + wandb.init(id=wandb_id, resume=resume, name=opt['name'], config=opt, project=project, sync_tensorboard=True) + + logger.info(f'Use wandb logger with id={wandb_id}; project={project}.') + + +def get_root_logger(logger_name='basicsr', log_level=logging.INFO, log_file=None): + """Get the root logger. + + The logger will be initialized if it has not been initialized. By default a + StreamHandler will be added. If `log_file` is specified, a FileHandler will + also be added. + + Args: + logger_name (str): root logger name. Default: 'basicsr'. + log_file (str | None): The log filename. If specified, a FileHandler + will be added to the root logger. + log_level (int): The root logger level. Note that only the process of + rank 0 is affected, while other processes will set the level to + "Error" and be silent most of the time. + + Returns: + logging.Logger: The root logger. + """ + logger = logging.getLogger(logger_name) + # if the logger has been initialized, just return it + if logger_name in initialized_logger: + return logger + + format_str = '%(asctime)s %(levelname)s: %(message)s' + stream_handler = logging.StreamHandler() + stream_handler.setFormatter(logging.Formatter(format_str)) + logger.addHandler(stream_handler) + logger.propagate = False + rank, _ = get_dist_info() + if rank != 0: + logger.setLevel('ERROR') + elif log_file is not None: + logger.setLevel(log_level) + # add file handler + file_handler = logging.FileHandler(log_file, 'w') + file_handler.setFormatter(logging.Formatter(format_str)) + file_handler.setLevel(log_level) + logger.addHandler(file_handler) + initialized_logger[logger_name] = True + return logger + + +def get_env_info(): + """Get environment information. + + Currently, only log the software version. + """ + import torch + import torchvision + + from basicsr.version import __version__ + msg = r""" + ____ _ _____ ____ + / __ ) ____ _ _____ (_)_____/ ___/ / __ \ + / __ |/ __ `// ___// // ___/\__ \ / /_/ / + / /_/ // /_/ /(__ )/ // /__ ___/ // _, _/ + /_____/ \__,_//____//_/ \___//____//_/ |_| + ______ __ __ __ __ + / ____/____ ____ ____/ / / / __ __ _____ / /__ / / + / / __ / __ \ / __ \ / __ / / / / / / // ___// //_/ / / + / /_/ // /_/ // /_/ // /_/ / / /___/ /_/ // /__ / /< /_/ + \____/ \____/ \____/ \____/ /_____/\____/ \___//_/|_| (_) + """ + msg += ('\nVersion Information: ' + f'\n\tBasicSR: {__version__}' + f'\n\tPyTorch: {torch.__version__}' + f'\n\tTorchVision: {torchvision.__version__}') + return msg diff --git a/opensora/models/super_resolution/basicsr/utils/matlab_functions.py b/opensora/models/super_resolution/basicsr/utils/matlab_functions.py new file mode 100644 index 0000000000000000000000000000000000000000..f9f1a83bc8beee468dd7c9ca734966e926fd9fde --- /dev/null +++ b/opensora/models/super_resolution/basicsr/utils/matlab_functions.py @@ -0,0 +1,359 @@ +import math +import numpy as np +import torch + + +def cubic(x): + """cubic function used for calculate_weights_indices.""" + absx = torch.abs(x) + absx2 = absx**2 + absx3 = absx**3 + return (1.5 * absx3 - 2.5 * absx2 + 1) * ( + (absx <= 1).type_as(absx)) + (-0.5 * absx3 + 2.5 * absx2 - 4 * absx + 2) * (((absx > 1) * + (absx <= 2)).type_as(absx)) + + +def calculate_weights_indices(in_length, out_length, scale, kernel, kernel_width, antialiasing): + """Calculate weights and indices, used for imresize function. + + Args: + in_length (int): Input length. + out_length (int): Output length. + scale (float): Scale factor. + kernel_width (int): Kernel width. + antialisaing (bool): Whether to apply anti-aliasing when downsampling. + """ + + if (scale < 1) and antialiasing: + # Use a modified kernel (larger kernel width) to simultaneously + # interpolate and antialias + kernel_width = kernel_width / scale + + # Output-space coordinates + x = torch.linspace(1, out_length, out_length) + + # Input-space coordinates. Calculate the inverse mapping such that 0.5 + # in output space maps to 0.5 in input space, and 0.5 + scale in output + # space maps to 1.5 in input space. + u = x / scale + 0.5 * (1 - 1 / scale) + + # What is the left-most pixel that can be involved in the computation? + left = torch.floor(u - kernel_width / 2) + + # What is the maximum number of pixels that can be involved in the + # computation? Note: it's OK to use an extra pixel here; if the + # corresponding weights are all zero, it will be eliminated at the end + # of this function. + p = math.ceil(kernel_width) + 2 + + # The indices of the input pixels involved in computing the k-th output + # pixel are in row k of the indices matrix. + indices = left.view(out_length, 1).expand(out_length, p) + torch.linspace(0, p - 1, p).view(1, p).expand( + out_length, p) + + # The weights used to compute the k-th output pixel are in row k of the + # weights matrix. + distance_to_center = u.view(out_length, 1).expand(out_length, p) - indices + + # apply cubic kernel + if (scale < 1) and antialiasing: + weights = scale * cubic(distance_to_center * scale) + else: + weights = cubic(distance_to_center) + + # Normalize the weights matrix so that each row sums to 1. + weights_sum = torch.sum(weights, 1).view(out_length, 1) + weights = weights / weights_sum.expand(out_length, p) + + # If a column in weights is all zero, get rid of it. only consider the + # first and last column. + weights_zero_tmp = torch.sum((weights == 0), 0) + if not math.isclose(weights_zero_tmp[0], 0, rel_tol=1e-6): + indices = indices.narrow(1, 1, p - 2) + weights = weights.narrow(1, 1, p - 2) + if not math.isclose(weights_zero_tmp[-1], 0, rel_tol=1e-6): + indices = indices.narrow(1, 0, p - 2) + weights = weights.narrow(1, 0, p - 2) + weights = weights.contiguous() + indices = indices.contiguous() + sym_len_s = -indices.min() + 1 + sym_len_e = indices.max() - in_length + indices = indices + sym_len_s - 1 + return weights, indices, int(sym_len_s), int(sym_len_e) + + +@torch.no_grad() +def imresize(img, scale, antialiasing=True): + """imresize function same as MATLAB. + + It now only supports bicubic. + The same scale applies for both height and width. + + Args: + img (Tensor | Numpy array): + Tensor: Input image with shape (c, h, w), [0, 1] range. + Numpy: Input image with shape (h, w, c), [0, 1] range. + scale (float): Scale factor. The same scale applies for both height + and width. + antialisaing (bool): Whether to apply anti-aliasing when downsampling. + Default: True. + + Returns: + Tensor: Output image with shape (c, h, w), [0, 1] range, w/o round. + """ + squeeze_flag = False + if type(img).__module__ == np.__name__: # numpy type + numpy_type = True + if img.ndim == 2: + img = img[:, :, None] + squeeze_flag = True + img = torch.from_numpy(img.transpose(2, 0, 1)).float() + else: + numpy_type = False + if img.ndim == 2: + img = img.unsqueeze(0) + squeeze_flag = True + + in_c, in_h, in_w = img.size() + out_h, out_w = math.ceil(in_h * scale), math.ceil(in_w * scale) + kernel_width = 4 + kernel = 'cubic' + + # get weights and indices + weights_h, indices_h, sym_len_hs, sym_len_he = calculate_weights_indices(in_h, out_h, scale, kernel, kernel_width, + antialiasing) + weights_w, indices_w, sym_len_ws, sym_len_we = calculate_weights_indices(in_w, out_w, scale, kernel, kernel_width, + antialiasing) + # process H dimension + # symmetric copying + img_aug = torch.FloatTensor(in_c, in_h + sym_len_hs + sym_len_he, in_w) + img_aug.narrow(1, sym_len_hs, in_h).copy_(img) + + sym_patch = img[:, :sym_len_hs, :] + inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(1, inv_idx) + img_aug.narrow(1, 0, sym_len_hs).copy_(sym_patch_inv) + + sym_patch = img[:, -sym_len_he:, :] + inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(1, inv_idx) + img_aug.narrow(1, sym_len_hs + in_h, sym_len_he).copy_(sym_patch_inv) + + out_1 = torch.FloatTensor(in_c, out_h, in_w) + kernel_width = weights_h.size(1) + for i in range(out_h): + idx = int(indices_h[i][0]) + for j in range(in_c): + out_1[j, i, :] = img_aug[j, idx:idx + kernel_width, :].transpose(0, 1).mv(weights_h[i]) + + # process W dimension + # symmetric copying + out_1_aug = torch.FloatTensor(in_c, out_h, in_w + sym_len_ws + sym_len_we) + out_1_aug.narrow(2, sym_len_ws, in_w).copy_(out_1) + + sym_patch = out_1[:, :, :sym_len_ws] + inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(2, inv_idx) + out_1_aug.narrow(2, 0, sym_len_ws).copy_(sym_patch_inv) + + sym_patch = out_1[:, :, -sym_len_we:] + inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(2, inv_idx) + out_1_aug.narrow(2, sym_len_ws + in_w, sym_len_we).copy_(sym_patch_inv) + + out_2 = torch.FloatTensor(in_c, out_h, out_w) + kernel_width = weights_w.size(1) + for i in range(out_w): + idx = int(indices_w[i][0]) + for j in range(in_c): + out_2[j, :, i] = out_1_aug[j, :, idx:idx + kernel_width].mv(weights_w[i]) + + if squeeze_flag: + out_2 = out_2.squeeze(0) + if numpy_type: + out_2 = out_2.numpy() + if not squeeze_flag: + out_2 = out_2.transpose(1, 2, 0) + + return out_2 + + +def rgb2ycbcr(img, y_only=False): + """Convert a RGB image to YCbCr image. + + This function produces the same results as Matlab's `rgb2ycbcr` function. + It implements the ITU-R BT.601 conversion for standard-definition + television. See more details in + https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion. + + It differs from a similar function in cv2.cvtColor: `RGB <-> YCrCb`. + In OpenCV, it implements a JPEG conversion. See more details in + https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion. + + Args: + img (ndarray): The input image. It accepts: + 1. np.uint8 type with range [0, 255]; + 2. np.float32 type with range [0, 1]. + y_only (bool): Whether to only return Y channel. Default: False. + + Returns: + ndarray: The converted YCbCr image. The output image has the same type + and range as input image. + """ + img_type = img.dtype + img = _convert_input_type_range(img) + if y_only: + out_img = np.dot(img, [65.481, 128.553, 24.966]) + 16.0 + else: + out_img = np.matmul( + img, [[65.481, -37.797, 112.0], [128.553, -74.203, -93.786], [24.966, 112.0, -18.214]]) + [16, 128, 128] + out_img = _convert_output_type_range(out_img, img_type) + return out_img + + +def bgr2ycbcr(img, y_only=False): + """Convert a BGR image to YCbCr image. + + The bgr version of rgb2ycbcr. + It implements the ITU-R BT.601 conversion for standard-definition + television. See more details in + https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion. + + It differs from a similar function in cv2.cvtColor: `BGR <-> YCrCb`. + In OpenCV, it implements a JPEG conversion. See more details in + https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion. + + Args: + img (ndarray): The input image. It accepts: + 1. np.uint8 type with range [0, 255]; + 2. np.float32 type with range [0, 1]. + y_only (bool): Whether to only return Y channel. Default: False. + + Returns: + ndarray: The converted YCbCr image. The output image has the same type + and range as input image. + """ + img_type = img.dtype + img = _convert_input_type_range(img) + if y_only: + out_img = np.dot(img, [24.966, 128.553, 65.481]) + 16.0 + else: + out_img = np.matmul( + img, [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786], [65.481, -37.797, 112.0]]) + [16, 128, 128] + out_img = _convert_output_type_range(out_img, img_type) + return out_img + + +def ycbcr2rgb(img): + """Convert a YCbCr image to RGB image. + + This function produces the same results as Matlab's ycbcr2rgb function. + It implements the ITU-R BT.601 conversion for standard-definition + television. See more details in + https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion. + + It differs from a similar function in cv2.cvtColor: `YCrCb <-> RGB`. + In OpenCV, it implements a JPEG conversion. See more details in + https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion. + + Args: + img (ndarray): The input image. It accepts: + 1. np.uint8 type with range [0, 255]; + 2. np.float32 type with range [0, 1]. + + Returns: + ndarray: The converted RGB image. The output image has the same type + and range as input image. + """ + img_type = img.dtype + img = _convert_input_type_range(img) * 255 + out_img = np.matmul(img, [[0.00456621, 0.00456621, 0.00456621], [0, -0.00153632, 0.00791071], + [0.00625893, -0.00318811, 0]]) * 255.0 + [-222.921, 135.576, -276.836] # noqa: E126 + out_img = _convert_output_type_range(out_img, img_type) + return out_img + + +def ycbcr2bgr(img): + """Convert a YCbCr image to BGR image. + + The bgr version of ycbcr2rgb. + It implements the ITU-R BT.601 conversion for standard-definition + television. See more details in + https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion. + + It differs from a similar function in cv2.cvtColor: `YCrCb <-> BGR`. + In OpenCV, it implements a JPEG conversion. See more details in + https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion. + + Args: + img (ndarray): The input image. It accepts: + 1. np.uint8 type with range [0, 255]; + 2. np.float32 type with range [0, 1]. + + Returns: + ndarray: The converted BGR image. The output image has the same type + and range as input image. + """ + img_type = img.dtype + img = _convert_input_type_range(img) * 255 + out_img = np.matmul(img, [[0.00456621, 0.00456621, 0.00456621], [0.00791071, -0.00153632, 0], + [0, -0.00318811, 0.00625893]]) * 255.0 + [-276.836, 135.576, -222.921] # noqa: E126 + out_img = _convert_output_type_range(out_img, img_type) + return out_img + + +def _convert_input_type_range(img): + """Convert the type and range of the input image. + + It converts the input image to np.float32 type and range of [0, 1]. + It is mainly used for pre-processing the input image in colorspace + conversion functions such as rgb2ycbcr and ycbcr2rgb. + + Args: + img (ndarray): The input image. It accepts: + 1. np.uint8 type with range [0, 255]; + 2. np.float32 type with range [0, 1]. + + Returns: + (ndarray): The converted image with type of np.float32 and range of + [0, 1]. + """ + img_type = img.dtype + img = img.astype(np.float32) + if img_type == np.float32: + pass + elif img_type == np.uint8: + img /= 255. + else: + raise TypeError(f'The img type should be np.float32 or np.uint8, but got {img_type}') + return img + + +def _convert_output_type_range(img, dst_type): + """Convert the type and range of the image according to dst_type. + + It converts the image to desired type and range. If `dst_type` is np.uint8, + images will be converted to np.uint8 type with range [0, 255]. If + `dst_type` is np.float32, it converts the image to np.float32 type with + range [0, 1]. + It is mainly used for post-processing images in colorspace conversion + functions such as rgb2ycbcr and ycbcr2rgb. + + Args: + img (ndarray): The image to be converted with np.float32 type and + range [0, 255]. + dst_type (np.uint8 | np.float32): If dst_type is np.uint8, it + converts the image to np.uint8 type with range [0, 255]. If + dst_type is np.float32, it converts the image to np.float32 type + with range [0, 1]. + + Returns: + (ndarray): The converted image with desired type and range. + """ + if dst_type not in (np.uint8, np.float32): + raise TypeError(f'The dst_type should be np.float32 or np.uint8, but got {dst_type}') + if dst_type == np.uint8: + img = img.round() + else: + img /= 255. + return img.astype(dst_type) diff --git a/opensora/models/super_resolution/basicsr/utils/misc.py b/opensora/models/super_resolution/basicsr/utils/misc.py new file mode 100644 index 0000000000000000000000000000000000000000..728fef857d0071875c82ffcbc8c74b6fbe029e22 --- /dev/null +++ b/opensora/models/super_resolution/basicsr/utils/misc.py @@ -0,0 +1,141 @@ +import numpy as np +import os +import random +import time +import torch +from os import path as osp + +from .dist_util import master_only + + +def set_random_seed(seed): + """Set random seeds.""" + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + + +def get_time_str(): + return time.strftime('%Y%m%d_%H%M%S', time.localtime()) + + +def mkdir_and_rename(path): + """mkdirs. If path exists, rename it with timestamp and create a new one. + + Args: + path (str): Folder path. + """ + if osp.exists(path): + new_name = path + '_archived_' + get_time_str() + print(f'Path already exists. Rename it to {new_name}', flush=True) + os.rename(path, new_name) + os.makedirs(path, exist_ok=True) + + +@master_only +def make_exp_dirs(opt): + """Make dirs for experiments.""" + path_opt = opt['path'].copy() + if opt['is_train']: + mkdir_and_rename(path_opt.pop('experiments_root')) + else: + mkdir_and_rename(path_opt.pop('results_root')) + for key, path in path_opt.items(): + if ('strict_load' in key) or ('pretrain_network' in key) or ('resume' in key) or ('param_key' in key): + continue + else: + os.makedirs(path, exist_ok=True) + + +def scandir(dir_path, suffix=None, recursive=False, full_path=False): + """Scan a directory to find the interested files. + + Args: + dir_path (str): Path of the directory. + suffix (str | tuple(str), optional): File suffix that we are + interested in. Default: None. + recursive (bool, optional): If set to True, recursively scan the + directory. Default: False. + full_path (bool, optional): If set to True, include the dir_path. + Default: False. + + Returns: + A generator for all the interested files with relative paths. + """ + + if (suffix is not None) and not isinstance(suffix, (str, tuple)): + raise TypeError('"suffix" must be a string or tuple of strings') + + root = dir_path + + def _scandir(dir_path, suffix, recursive): + for entry in os.scandir(dir_path): + if not entry.name.startswith('.') and entry.is_file(): + if full_path: + return_path = entry.path + else: + return_path = osp.relpath(entry.path, root) + + if suffix is None: + yield return_path + elif return_path.endswith(suffix): + yield return_path + else: + if recursive: + yield from _scandir(entry.path, suffix=suffix, recursive=recursive) + else: + continue + + return _scandir(dir_path, suffix=suffix, recursive=recursive) + + +def check_resume(opt, resume_iter): + """Check resume states and pretrain_network paths. + + Args: + opt (dict): Options. + resume_iter (int): Resume iteration. + """ + if opt['path']['resume_state']: + # get all the networks + networks = [key for key in opt.keys() if key.startswith('network_')] + flag_pretrain = False + for network in networks: + if opt['path'].get(f'pretrain_{network}') is not None: + flag_pretrain = True + if flag_pretrain: + print('pretrain_network path will be ignored during resuming.') + # set pretrained model paths + for network in networks: + name = f'pretrain_{network}' + basename = network.replace('network_', '') + if opt['path'].get('ignore_resume_networks') is None or (network + not in opt['path']['ignore_resume_networks']): + opt['path'][name] = osp.join(opt['path']['models'], f'net_{basename}_{resume_iter}.pth') + print(f"Set {name} to {opt['path'][name]}") + + # change param_key to params in resume + param_keys = [key for key in opt['path'].keys() if key.startswith('param_key')] + for param_key in param_keys: + if opt['path'][param_key] == 'params_ema': + opt['path'][param_key] = 'params' + print(f'Set {param_key} to params') + + +def sizeof_fmt(size, suffix='B'): + """Get human readable file size. + + Args: + size (int): File size. + suffix (str): Suffix. Default: 'B'. + + Return: + str: Formatted file siz. + """ + for unit in ['', 'K', 'M', 'G', 'T', 'P', 'E', 'Z']: + if abs(size) < 1024.0: + return f'{size:3.1f} {unit}{suffix}' + size /= 1024.0 + return f'{size:3.1f} Y{suffix}' diff --git a/opensora/models/super_resolution/basicsr/utils/options.py b/opensora/models/super_resolution/basicsr/utils/options.py new file mode 100644 index 0000000000000000000000000000000000000000..1925644dd62fc7b6b1e47bf2641f9d11251f3142 --- /dev/null +++ b/opensora/models/super_resolution/basicsr/utils/options.py @@ -0,0 +1,200 @@ +import argparse +import random +import torch +import yaml +from collections import OrderedDict +from os import path as osp + +from basicsr.utils import set_random_seed +from basicsr.utils.dist_util import get_dist_info, init_dist, master_only + + +def ordered_yaml(): + """Support OrderedDict for yaml. + + Returns: + yaml Loader and Dumper. + """ + try: + from yaml import CDumper as Dumper + from yaml import CLoader as Loader + except ImportError: + from yaml import Dumper, Loader + + _mapping_tag = yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG + + def dict_representer(dumper, data): + return dumper.represent_dict(data.items()) + + def dict_constructor(loader, node): + return OrderedDict(loader.construct_pairs(node)) + + Dumper.add_representer(OrderedDict, dict_representer) + Loader.add_constructor(_mapping_tag, dict_constructor) + return Loader, Dumper + + +def dict2str(opt, indent_level=1): + """dict to string for printing options. + + Args: + opt (dict): Option dict. + indent_level (int): Indent level. Default: 1. + + Return: + (str): Option string for printing. + """ + msg = '\n' + for k, v in opt.items(): + if isinstance(v, dict): + msg += ' ' * (indent_level * 2) + k + ':[' + msg += dict2str(v, indent_level + 1) + msg += ' ' * (indent_level * 2) + ']\n' + else: + msg += ' ' * (indent_level * 2) + k + ': ' + str(v) + '\n' + return msg + + +def _postprocess_yml_value(value): + # None + if value == '~' or value.lower() == 'none': + return None + # bool + if value.lower() == 'true': + return True + elif value.lower() == 'false': + return False + # !!float number + if value.startswith('!!float'): + return float(value.replace('!!float', '')) + # number + if value.isdigit(): + return int(value) + elif value.replace('.', '', 1).isdigit() and value.count('.') < 2: + return float(value) + # list + if value.startswith('['): + return eval(value) + # str + return value + + +def parse_options(root_path, SR, is_train=True): + parser = argparse.ArgumentParser() + # parser.add_argument('-opt', type=str, default = 'options/test/test_RGT_S_x2.yml',required=True, help='Path to option YAML file.') + if SR == 'x4': + file_path = osp.join(root_path,'options/test/test_RGT_x4.yml') + if SR == 'x2': + file_path = osp.join(root_path,'options/test/test_RGT_x2.yml') + parser.add_argument('-opt', type=str, default = file_path, help='Path to option YAML file.') + parser.add_argument('--launcher', choices=['none', 'pytorch', 'slurm'], default='none', help='job launcher') + parser.add_argument('--auto_resume', action='store_true') + parser.add_argument('--debug', action='store_true') + parser.add_argument('--local_rank', type=int, default=0) + parser.add_argument( + '--force_yml', nargs='+', default=None, help='Force to update yml files. Examples: train:ema_decay=0.999') + args = parser.parse_args() + + # parse yml to dict + with open(args.opt, mode='r') as f: + # print(args.opt) + opt = yaml.load(f, Loader=ordered_yaml()[0]) + + # distributed settings + if args.launcher == 'none': + opt['dist'] = False + print('Disable distributed.', flush=True) + else: + opt['dist'] = True + if args.launcher == 'slurm' and 'dist_params' in opt: + init_dist(args.launcher, **opt['dist_params']) + else: + init_dist(args.launcher) + opt['rank'], opt['world_size'] = get_dist_info() + + # random seed + seed = opt.get('manual_seed') + if seed is None: + seed = random.randint(1, 10000) + opt['manual_seed'] = seed + set_random_seed(seed + opt['rank']) + + # force to update yml options + if args.force_yml is not None: + for entry in args.force_yml: + # now do not support creating new keys + keys, value = entry.split('=') + keys, value = keys.strip(), value.strip() + value = _postprocess_yml_value(value) + eval_str = 'opt' + for key in keys.split(':'): + eval_str += f'["{key}"]' + eval_str += '=value' + # using exec function + exec(eval_str) + + opt['auto_resume'] = args.auto_resume + opt['is_train'] = is_train + + # debug setting + if args.debug and not opt['name'].startswith('debug'): + opt['name'] = 'debug_' + opt['name'] + + if opt['num_gpu'] == 'auto': + opt['num_gpu'] = torch.cuda.device_count() + + # datasets + for phase, dataset in opt['datasets'].items(): + # for multiple datasets, e.g., val_1, val_2; test_1, test_2 + phase = phase.split('_')[0] + dataset['phase'] = phase + if 'scale' in opt: + dataset['scale'] = opt['scale'] + if dataset.get('dataroot_gt') is not None: + dataset['dataroot_gt'] = osp.expanduser(dataset['dataroot_gt']) + if dataset.get('dataroot_lq') is not None: + dataset['dataroot_lq'] = osp.expanduser(dataset['dataroot_lq']) + + # paths + for key, val in opt['path'].items(): + if (val is not None) and ('resume_state' in key or 'pretrain_network' in key): + opt['path'][key] = osp.expanduser(val) + + if is_train: + experiments_root = osp.join(root_path, 'experiments', opt['name']) + opt['path']['experiments_root'] = experiments_root + opt['path']['models'] = osp.join(experiments_root, 'models') + opt['path']['training_states'] = osp.join(experiments_root, 'training_states') + opt['path']['log'] = experiments_root + opt['path']['visualization'] = osp.join(experiments_root, 'visualization') + + # change some options for debug mode + if 'debug' in opt['name']: + if 'val' in opt: + opt['val']['val_freq'] = 8 + opt['logger']['print_freq'] = 1 + opt['logger']['save_checkpoint_freq'] = 8 + else: # test + results_root = osp.join(root_path, 'results', opt['name']) + opt['path']['results_root'] = results_root + opt['path']['log'] = results_root + opt['path']['visualization'] = osp.join(results_root, 'visualization') + + return opt, args + + +@master_only +def copy_opt_file(opt_file, experiments_root): + # copy the yml file to the experiment root + import sys + import time + from shutil import copyfile + cmd = ' '.join(sys.argv) + filename = osp.join(experiments_root, osp.basename(opt_file)) + copyfile(opt_file, filename) + + with open(filename, 'r+') as f: + lines = f.readlines() + lines.insert(0, f'# GENERATE TIME: {time.asctime()}\n# CMD:\n# {cmd}\n\n') + f.seek(0) + f.writelines(lines) diff --git a/opensora/models/super_resolution/basicsr/utils/registry.py b/opensora/models/super_resolution/basicsr/utils/registry.py new file mode 100644 index 0000000000000000000000000000000000000000..655753b3b9cbd0cfe73fe93a77cf1fcc3db6d827 --- /dev/null +++ b/opensora/models/super_resolution/basicsr/utils/registry.py @@ -0,0 +1,82 @@ +# Modified from: https://github.com/facebookresearch/fvcore/blob/master/fvcore/common/registry.py # noqa: E501 + + +class Registry(): + """ + The registry that provides name -> object mapping, to support third-party + users' custom modules. + + To create a registry (e.g. a backbone registry): + + .. code-block:: python + + BACKBONE_REGISTRY = Registry('BACKBONE') + + To register an object: + + .. code-block:: python + + @BACKBONE_REGISTRY.register() + class MyBackbone(): + ... + + Or: + + .. code-block:: python + + BACKBONE_REGISTRY.register(MyBackbone) + """ + + def __init__(self, name): + """ + Args: + name (str): the name of this registry + """ + self._name = name + self._obj_map = {} + + def _do_register(self, name, obj): + assert (name not in self._obj_map), (f"An object named '{name}' was already registered " + f"in '{self._name}' registry!") + self._obj_map[name] = obj + + def register(self, obj=None): + """ + Register the given object under the the name `obj.__name__`. + Can be used as either a decorator or not. + See docstring of this class for usage. + """ + if obj is None: + # used as a decorator + def deco(func_or_class): + name = func_or_class.__name__ + self._do_register(name, func_or_class) + return func_or_class + + return deco + + # used as a function call + name = obj.__name__ + self._do_register(name, obj) + + def get(self, name): + ret = self._obj_map.get(name) + if ret is None: + raise KeyError(f"No object named '{name}' found in '{self._name}' registry!") + return ret + + def __contains__(self, name): + return name in self._obj_map + + def __iter__(self): + return iter(self._obj_map.items()) + + def keys(self): + return self._obj_map.keys() + + +DATASET_REGISTRY = Registry('dataset') +ARCH_REGISTRY = Registry('arch') +MODEL_REGISTRY = Registry('model') +LOSS_REGISTRY = Registry('loss') +METRIC_REGISTRY = Registry('metric') diff --git a/opensora/models/super_resolution/options/test/test_RGT_x2.yml b/opensora/models/super_resolution/options/test/test_RGT_x2.yml new file mode 100644 index 0000000000000000000000000000000000000000..4f88a04d696328ce21641fc733900cd2bb4bf262 --- /dev/null +++ b/opensora/models/super_resolution/options/test/test_RGT_x2.yml @@ -0,0 +1,94 @@ +# general settings +name: test_RGT_x2 +model_type: RGTModel +scale: 2 +num_gpu: 1 +manual_seed: 10 + +datasets: + test_1: # the 1st test dataset + task: SR + name: Set5 + type: PairedImageDataset + dataroot_gt: datasets/benchmark/Set5/HR + dataroot_lq: datasets/benchmark/Set5/LR_bicubic/X2 + filename_tmpl: '{}x2' + io_backend: + type: disk + + # test_2: # the 2st test dataset + # task: SR + # name: Set14 + # type: PairedImageDataset + # dataroot_gt: datasets/benchmark/Set14/HR + # dataroot_lq: datasets/benchmark/Set14/LR_bicubic/X2 + # filename_tmpl: '{}x2' + # io_backend: + # type: disk + + # test_3: # the 3st test dataset + # task: SR + # name: B100 + # type: PairedImageDataset + # dataroot_gt: datasets/benchmark/B100/HR + # dataroot_lq: datasets/benchmark/B100/LR_bicubic/X2 + # filename_tmpl: '{}x2' + # io_backend: + # type: disk + + # test_4: # the 4st test dataset + # task: SR + # name: Urban100 + # type: PairedImageDataset + # dataroot_gt: datasets/benchmark/Urban100/HR + # dataroot_lq: datasets/benchmark/Urban100/LR_bicubic/X2 + # filename_tmpl: '{}x2' + # io_backend: + # type: disk + + # test_5: # the 5st test dataset + # task: SR + # name: Manga109 + # type: PairedImageDataset + # dataroot_gt: datasets/benchmark/Manga109/HR + # dataroot_lq: datasets/benchmark/Manga109/LR_bicubic/X2 + # filename_tmpl: '{}_LRBI_x2' + # io_backend: + # type: disk + + +# network structures +network_g: + type: RGT + upscale: 2 + in_chans: 3 + img_size: 64 + img_range: 1. + depth: [6,6,6,6,6,6,6,6] + embed_dim: 180 + num_heads: [6,6,6,6,6,6,6,6] + mlp_ratio: 2 + resi_connection: '1conv' + split_size: [8,32] + c_ratio: 0.5 + +# path +path: + pretrain_network_g: /remote-home/lzy/RGT/experiments/pretrained_models/RGT_x2.pth + strict_load_g: True + +# validation settings +val: + save_img: True + suffix: ~ # add suffix to saved images, if None, use exp name + use_chop: False # True to save memory, if img too large + + metrics: + psnr: # metric name, can be arbitrary + type: calculate_psnr + crop_border: 2 + test_y_channel: True + ssim: + type: calculate_ssim + crop_border: 2 + test_y_channel: True \ No newline at end of file diff --git a/opensora/models/super_resolution/options/test/test_RGT_x4.yml b/opensora/models/super_resolution/options/test/test_RGT_x4.yml new file mode 100644 index 0000000000000000000000000000000000000000..a776fa5f82b1b5f8c57241c060219c1489a14a2f --- /dev/null +++ b/opensora/models/super_resolution/options/test/test_RGT_x4.yml @@ -0,0 +1,94 @@ +# general settings +name: test_RGT_x4 +model_type: RGTModel +scale: 4 +num_gpu: 1 +manual_seed: 10 + +datasets: + test_1: # the 1st test dataset + task: SR + name: Set5 + type: PairedImageDataset + dataroot_gt: datasets/benchmark/Set5/HR + dataroot_lq: datasets/benchmark/Set5/LR_bicubic/X4 + filename_tmpl: '{}x4' + io_backend: + type: disk + + # test_2: # the 2st test dataset + # task: SR + # name: Set14 + # type: PairedImageDataset + # dataroot_gt: datasets/benchmark/Set14/HR + # dataroot_lq: datasets/benchmark/Set14/LR_bicubic/X4 + # filename_tmpl: '{}x4' + # io_backend: + # type: disk + + # test_3: # the 3st test dataset + # task: SR + # name: B100 + # type: PairedImageDataset + # dataroot_gt: datasets/benchmark/B100/HR + # dataroot_lq: datasets/benchmark/B100/LR_bicubic/X4 + # filename_tmpl: '{}x4' + # io_backend: + # type: disk + + # test_4: # the 4st test dataset + # task: SR + # name: Urban100 + # type: PairedImageDataset + # dataroot_gt: datasets/benchmark/Urban100/HR + # dataroot_lq: datasets/benchmark/Urban100/LR_bicubic/X4 + # filename_tmpl: '{}x4' + # io_backend: + # type: disk + + # test_5: # the 5st test dataset + # task: SR + # name: Manga109 + # type: PairedImageDataset + # dataroot_gt: datasets/benchmark/Manga109/HR + # dataroot_lq: datasets/benchmark/Manga109/LR_bicubic/X4 + # filename_tmpl: '{}_LRBI_x4' + # io_backend: + # type: disk + + +# network structures +network_g: + type: RGT + upscale: 4 + in_chans: 3 + img_size: 64 + img_range: 1. + depth: [6,6,6,6,6,6,6,6] + embed_dim: 180 + num_heads: [6,6,6,6,6,6,6,6] + mlp_ratio: 2 + resi_connection: '1conv' + split_size: [8,32] + c_ratio: 0.5 + +# path +path: + pretrain_network_g: /remote-home/lzy/RGT/experiments/pretrained_models/RGT_x4.pth + strict_load_g: True + +# validation settings +val: + save_img: True + suffix: ~ # add suffix to saved images, if None, use exp name + use_chop: False # True to save memory, if img too large + + metrics: + psnr: # metric name, can be arbitrary + type: calculate_psnr + crop_border: 4 + test_y_channel: True + ssim: + type: calculate_ssim + crop_border: 4 + test_y_channel: True \ No newline at end of file diff --git a/opensora/models/super_resolution/options/test/test_single_config.yml b/opensora/models/super_resolution/options/test/test_single_config.yml new file mode 100644 index 0000000000000000000000000000000000000000..58a0753db579a1ebc6b5e6382c6e091e098880eb --- /dev/null +++ b/opensora/models/super_resolution/options/test/test_single_config.yml @@ -0,0 +1,41 @@ +# general settings +name: test_single +model_type: RGTModel +scale: 2 +num_gpu: 1 +manual_seed: 10 + +datasets: + test_1: # the 1st test dataset + name: Single + type: SingleImageDataset + dataroot_lq: /test + io_backend: + type: disk + + +# network structures +network_g: + type: RGT + upscale: 2 + in_chans: 3 + img_size: 64 + img_range: 1. + depth: [6,6,6,6,6,6,6,6] + embed_dim: 180 + num_heads: [6,6,6,6,6,6,6,6] + mlp_ratio: 2 + resi_connection: '1conv' + split_size: [8,32] + c_ratio: 0.5 + +# path +path: + pretrain_network_g: /test + strict_load_g: True + +# validation settings +val: + save_img: True + suffix: ~ # add suffix to saved images, if None, use exp name + use_chop: False # True to save memory, if img too large \ No newline at end of file diff --git a/opensora/models/super_resolution/run.py b/opensora/models/super_resolution/run.py new file mode 100644 index 0000000000000000000000000000000000000000..ac8e2bfcb38196c2262a96d270282d547f95f488 --- /dev/null +++ b/opensora/models/super_resolution/run.py @@ -0,0 +1,138 @@ +import cv2 +import argparse +from basicsr.test_img import image_sr +from os import path as osp +import os +import shutil +from PIL import Image +import re +import imageio.v2 as imageio +import threading +from concurrent.futures import ThreadPoolExecutor +import time + +def replace_filename(original_path, suffix): + + directory = os.path.dirname(original_path) + old_filename = os.path.basename(original_path) + name_part, file_extension = os.path.splitext(old_filename) + new_filename = f"{name_part}{suffix}{file_extension}" + new_path = os.path.join(directory, new_filename) + + return new_path + +def create_temp_folder(folder_path): + + if os.path.exists(folder_path): + shutil.rmtree(folder_path) + os.makedirs(folder_path) + +def delete_temp_folder(folder_path): + shutil.rmtree(folder_path) + +def extract_number(filename): + s = re.findall(r'\d+', filename) + return int(s[0]) if s else -1 + +def bicubic_upsample_opencv(input_image_path, output_image_path, scale_factor): + + img = cv2.imread(input_image_path) + + original_height, original_width = img.shape[:2] + + new_width = int(original_width * scale_factor) + new_height = int(original_height * scale_factor) + + upsampled_img = cv2.resize(img, (new_width, new_height), interpolation=cv2.INTER_CUBIC) + cv2.imwrite(output_image_path, upsampled_img) + + +def process_frame(frame_count, frame, temp_LR_folder_path, temp_HR_folder_path, SR): + frame_path = os.path.join(temp_LR_folder_path, f"frame_{frame_count}{SR}.png") + cv2.imwrite(frame_path, frame) + HR_frame_path = os.path.join(temp_HR_folder_path, f"frame_{frame_count}.png") + + if SR == 'x4': + bicubic_upsample_opencv(frame_path, HR_frame_path, 4) + elif SR == 'x2': + bicubic_upsample_opencv(frame_path, HR_frame_path, 2) + +def video_sr(args): + file_name = os.path.basename(args.input_dir) + video_output_path = os.path.join(args.output_dir,file_name) + + if args.SR == 'x4': + temp_LR_folder_path = os.path.join(args.output_dir, f'temp_LR/X4') + video_output_path = replace_filename(video_output_path, '_x4') + result_temp = osp.join(args.root_path, f'results/test_RGT_x4/visualization/Set5') + if args.SR == 'x2': + temp_LR_folder_path = os.path.join(args.output_dir, f'temp_LR/X2') + video_output_path = replace_filename(video_output_path, '_x2') + result_temp = osp.join(args.root_path, f'results/test_RGT_x2/visualization/Set5') + + temp_HR_folder_path = os.path.join(args.output_dir, f'temp_HR') + + # create_temp_folder(result_temp) + create_temp_folder(temp_LR_folder_path) + create_temp_folder(temp_HR_folder_path) + + cap = cv2.VideoCapture(args.input_dir) + if not cap.isOpened(): + print("Error opening video file.") + return + + t1 = time.time() + frame_count = 0 + frames_to_process = [] + while cap.isOpened(): + ret, frame = cap.read() + if not ret: + break + frames_to_process.append((frame_count, frame)) + frame_count += 1 + + with ThreadPoolExecutor(max_workers = args.mul_numwork) as executor: + for frame_count, frame in frames_to_process: + executor.submit(process_frame, frame_count, frame, temp_LR_folder_path, temp_HR_folder_path, args.SR) + + print("total frames:",frame_count) + print("fps :",cap.get(cv2.CAP_PROP_FPS)) + + t2 = time.time() + print('mul threads: ',t2 - t1,'s') + # progress all frames in video + image_sr(args) + + t3 = time.time() + print('image super resolution: ',t3 - t2,'s') + # recover video form all frames + frame_files = sorted(os.listdir(result_temp), key=extract_number) + video_frames = [imageio.imread(os.path.join(result_temp, frame_file)) for frame_file in frame_files] + fps = cap.get(cv2.CAP_PROP_FPS) + imageio.mimwrite(video_output_path, video_frames, fps=fps, quality=9) + + t4 = time.time() + print('tranformer frames to video: ',t4 - t3,'s') + # release all resources + cap.release() + delete_temp_folder(os.path.dirname(temp_LR_folder_path)) + delete_temp_folder(temp_HR_folder_path) + delete_temp_folder(os.path.join(args.root_path, f'results')) + + t5 = time.time() + print('delete time: ',t5 - t4,'s') + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="RGT for Video Super-Resolution") + # make sure you SR is match with the ckpt_path + parser.add_argument("--SR", type=str, choices=['x2', 'x4'], default='x4', help='image resolution') + parser.add_argument("--ckpt_path", type=str, default = "/remote-home/lzy/RGT/experiments/pretrained_models/RGT_x4.pth") + + parser.add_argument("--root_path", type=str, default = "/remote-home/lzy/RGT") + parser.add_argument("--input_dir", type=str, default= "/remote-home/lzy/RGT/datasets/video/video_test1.mp4") + parser.add_argument("--output_dir", type=str, default= "/remote-home/lzy/RGT/datasets/video_output") + + parser.add_argument("--mul_numwork", type=int, default = 16, help ='max_workers to execute Multi') + parser.add_argument("--use_chop", type= bool, default = True, help ='use_chop: True # True to save memory, if img too large') + args = parser.parse_args() + video_sr(args) diff --git a/opensora/models/text_encoder/__init__.py b/opensora/models/text_encoder/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5e3f47c6e8ac86e5cc4faaf52393f92f48b78d0c --- /dev/null +++ b/opensora/models/text_encoder/__init__.py @@ -0,0 +1,44 @@ +import torch +from torch import nn +from transformers import T5EncoderModel, CLIPModel, CLIPProcessor + +from opensora.utils.utils import get_precision + + +class T5Wrapper(nn.Module): + def __init__(self, args): + super(T5Wrapper, self).__init__() + self.model_name = args.text_encoder_name + dtype = get_precision(args) + t5_model_kwargs = {'cache_dir': './cache_dir', 'low_cpu_mem_usage': True, 'torch_dtype': dtype} + self.text_enc = T5EncoderModel.from_pretrained(self.model_name, **t5_model_kwargs).eval() + + def forward(self, input_ids, attention_mask): + text_encoder_embs = self.text_enc(input_ids=input_ids, attention_mask=attention_mask)['last_hidden_state'] + return text_encoder_embs.detach() + +class CLIPWrapper(nn.Module): + def __init__(self, args): + super(CLIPWrapper, self).__init__() + self.model_name = args.text_encoder_name + dtype = get_precision(args) + model_kwargs = {'cache_dir': './cache_dir', 'low_cpu_mem_usage': True, 'torch_dtype': dtype} + self.text_enc = CLIPModel.from_pretrained(self.model_name, **model_kwargs).eval() + + def forward(self, input_ids, attention_mask): + text_encoder_embs = self.text_enc.get_text_features(input_ids=input_ids, attention_mask=attention_mask) + return text_encoder_embs.detach() + + + +text_encoder = { + 'DeepFloyd/t5-v1_1-xxl': T5Wrapper, + 'openai/clip-vit-large-patch14': CLIPWrapper +} + + +def get_text_enc(args): + """deprecation""" + text_enc = text_encoder.get(args.text_encoder_name, None) + assert text_enc is not None + return text_enc(args) diff --git a/opensora/models/text_encoder/clip.py b/opensora/models/text_encoder/clip.py new file mode 100644 index 0000000000000000000000000000000000000000..11bf4041e217858eef767b913921d1b173e28bd4 --- /dev/null +++ b/opensora/models/text_encoder/clip.py @@ -0,0 +1,222 @@ +# -*- coding: utf-8 -*- +import os +import re +import ftfy +import torch +import html +from PIL import Image +from transformers import CLIPProcessor, CLIPModel, CLIPTokenizer, CLIPTextModel + +class CLIPEmbedder: + """ + A class for embedding texts and images using a pretrained CLIP model. + """ + + def __init__(self, device='cuda', model_name='openai/clip-vit-base-patch32', cache_dir='./cache_dir', use_text_preprocessing=True, max_length=77): + """ + Initializes the CLIPEmbedder with specified model and configurations. + """ + self.device = torch.device(device) + self.model_name = model_name + self.cache_dir = cache_dir + self.use_text_preprocessing = use_text_preprocessing + self.max_length = max_length + + os.makedirs(self.cache_dir, exist_ok=True) + + self.processor = CLIPProcessor.from_pretrained(model_name, cache_dir=self.cache_dir) + self.model = CLIPModel.from_pretrained(model_name, cache_dir=self.cache_dir).to(self.device).eval() + self.tokenizer = CLIPTokenizer.from_pretrained(model_name) + self.text_model = CLIPTextModel.from_pretrained(model_name, cache_dir=self.cache_dir).to(self.device).eval() + + for param in self.text_model.parameters(): + param.requires_grad = False + + def get_text_embeddings(self, texts): + """ + Generates embeddings for a list of text prompts. + """ + self._validate_input_list(texts, str) + + if self.use_text_preprocessing: + texts = [self._clean_text(text) for text in texts] + + inputs = self.processor(text=texts, return_tensors="pt", padding=True, truncation=True, max_length=self.max_length).to(self.device) + + with torch.no_grad(): + embeddings = self.model.get_text_features(**inputs) + + return embeddings + + def encode_text(self, texts): + """ + Encodes texts into embeddings and returns the last hidden state and pooled output. + """ + self._validate_input_list(texts, str) + + batch_encoding = self.tokenizer(texts, return_tensors="pt", truncation=True, max_length=self.max_length, padding="max_length").to(self.device) + + with torch.no_grad(): + outputs = self.text_model(**batch_encoding) + + return outputs.last_hidden_state, outputs.pooler_output + + def get_image_embeddings(self, image_paths): + """ + Generates embeddings for a list of image file paths. + """ + self._validate_input_list(image_paths, str) + images = [self._load_image(path) for path in image_paths] + + inputs = self.processor(images=images, return_tensors="pt").to(self.device) + + with torch.no_grad(): + embeddings = self.model.get_image_features(**inputs) + + return embeddings + + def _validate_input_list(self, input_list, expected_type): + """ + Validates that the input is a list of expected type. + """ + if not isinstance(input_list, list) or not all(isinstance(item, expected_type) for item in input_list): + raise ValueError(f"Input must be a list of {expected_type.__name__}.") + + def _clean_text(self, text): + """ + Applies basic cleaning and formatting to a text string. + """ + text = ftfy.fix_text(text) + text = html.unescape(text) + return text.strip() + + def _load_image(self, image_path): + """ + Loads and preprocesses an image from a file path. + """ + try: + image = Image.open(image_path).convert("RGB") + except FileNotFoundError: + raise FileNotFoundError(f"Image file not found: {image_path}") + except Exception as e: + raise Exception(f"Error loading image {image_path}: {e}") + return image + + def clean_caption(self, caption): + caption = str(caption) + caption = ul.unquote_plus(caption) + caption = caption.strip().lower() + caption = re.sub('', 'person', caption) + # urls: + caption = re.sub( + r'\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))', # noqa + '', caption) # regex for urls + caption = re.sub( + r'\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))', # noqa + '', caption) # regex for urls + + caption = BeautifulSoup(caption, features='html.parser').text + + + caption = re.sub(r'@[\w\d]+\b', '', caption) + + caption = re.sub(r'[\u31c0-\u31ef]+', '', caption) + caption = re.sub(r'[\u31f0-\u31ff]+', '', caption) + caption = re.sub(r'[\u3200-\u32ff]+', '', caption) + caption = re.sub(r'[\u3300-\u33ff]+', '', caption) + caption = re.sub(r'[\u3400-\u4dbf]+', '', caption) + caption = re.sub(r'[\u4dc0-\u4dff]+', '', caption) + caption = re.sub(r'[\u4e00-\u9fff]+', '', caption) + + caption = re.sub( + r'[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+', # noqa + '-', caption) + + + caption = re.sub(r'[`´«»“”¨]', '"', caption) + caption = re.sub(r'[‘’]', "'", caption) + + + caption = re.sub(r'"?', '', caption) + + caption = re.sub(r'&', '', caption) + + + caption = re.sub(r'\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}', ' ', caption) + + + caption = re.sub(r'\d:\d\d\s+$', '', caption) + + + caption = re.sub(r'\\n', ' ', caption) + + + caption = re.sub(r'#\d{1,3}\b', '', caption) + + caption = re.sub(r'#\d{5,}\b', '', caption) + caption = re.sub(r'\b\d{6,}\b', '', caption) + caption = re.sub(r'[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)', '', caption) + caption = re.sub(r'[\"\']{2,}', r'"', caption) + caption = re.sub(r'[\.]{2,}', r' ', caption) + + caption = re.sub(self.bad_punct_regex, r' ', caption) + caption = re.sub(r'\s+\.\s+', r' ', caption) + regex2 = re.compile(r'(?:\-|\_)') + if len(re.findall(regex2, caption)) > 3: + caption = re.sub(regex2, ' ', caption) + caption = self.basic_clean(caption) + caption = re.sub(r'\b[a-zA-Z]{1,3}\d{3,15}\b', '', caption) # jc6640 + caption = re.sub(r'\b[a-zA-Z]+\d+[a-zA-Z]+\b', '', caption) # jc6640vc + caption = re.sub(r'\b\d+[a-zA-Z]+\d+\b', '', caption) # 6640vc231 + + caption = re.sub(r'(worldwide\s+)?(free\s+)?shipping', '', caption) + caption = re.sub(r'(free\s)?download(\sfree)?', '', caption) + caption = re.sub(r'\bclick\b\s(?:for|on)\s\w+', '', caption) + caption = re.sub(r'\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?', '', caption) + caption = re.sub(r'\bpage\s+\d+\b', '', caption) + + caption = re.sub(r'\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b', r' ', caption) # j2d1a2a... + + caption = re.sub(r'\b\d+\.?\d*[xх×]\d+\.?\d*\b', '', caption) + + caption = re.sub(r'\b\s+\:\s+', r': ', caption) + caption = re.sub(r'(\D[,\./])\b', r'\1 ', caption) + caption = re.sub(r'\s+', ' ', caption) + + caption.strip() + + caption = re.sub(r'^[\"\']([\w\W]+)[\"\']$', r'\1', caption) + caption = re.sub(r'^[\'\_,\-\:;]', r'', caption) + caption = re.sub(r'[\'\_,\-\:\-\+]$', r'', caption) + caption = re.sub(r'^\.\S+$', '', caption) + + return caption.strip() + + @staticmethod + def basic_clean(text): + text = ftfy.fix_text(text) + text = html.unescape(html.unescape(text)) + return text.strip() + +if __name__ == '__main__': + + clip_embedder = CLIPEmbedder() + + # Example + text_prompts = [ + 'A photo of a cute puppy playing with a ball.', + 'An image of a beautiful sunset over the ocean.', + 'A scene depicting a busy city street.' + ] + text_embeddings = clip_embedder.get_text_embeddings(text_prompts) + print(f"Text embeddings shape: {text_embeddings.shape}") + + image_paths = ['image1.jpg', 'image2.png'] + try: + image_embeddings = clip_embedder.get_image_embeddings(image_paths) + print(f"Image embeddings shape: {image_embeddings.shape}") + except FileNotFoundError as e: + print(e) + except Exception as e: + print(f"An error occurred: {e}") + diff --git a/opensora/models/text_encoder/t5.py b/opensora/models/text_encoder/t5.py new file mode 100644 index 0000000000000000000000000000000000000000..cbc782072e58d9fd3041d36a2a1ded089b61329d --- /dev/null +++ b/opensora/models/text_encoder/t5.py @@ -0,0 +1,197 @@ +# -*- coding: utf-8 -*- +import os +import re +import html +import urllib.parse as ul + +import ftfy +import torch +from bs4 import BeautifulSoup +from transformers import T5EncoderModel, AutoTokenizer +from huggingface_hub import hf_hub_download + +class T5Embedder: + + available_models = ['t5-v1_1-xxl'] + bad_punct_regex = re.compile(r'['+'#®•©™&@·º½¾¿¡§~'+'\)'+'\('+'\]'+'\['+'\}'+'\{'+'\|'+'\\'+'\/'+'\*' + r']{1,}') # noqa + + def __init__(self, device, dir_or_name='t5-v1_1-xxl', *, cache_dir='./cache_dir', hf_token=None, use_text_preprocessing=True, + t5_model_kwargs=None, torch_dtype=None, model_max_length=120): + self.device = torch.device(device) + self.torch_dtype = torch_dtype or torch.bfloat16 + if t5_model_kwargs is None: + t5_model_kwargs = {'low_cpu_mem_usage': True, 'torch_dtype': self.torch_dtype} + t5_model_kwargs['device_map'] = {'shared': self.device, 'encoder': self.device} + + self.use_text_preprocessing = use_text_preprocessing + self.hf_token = hf_token + self.cache_dir = cache_dir + self.dir_or_name = dir_or_name + cache_dir = os.path.join(self.cache_dir, 't5-v1_1-xxl') + for filename in ['config.json', 'special_tokens_map.json', 'spiece.model', 'tokenizer_config.json', + 'pytorch_model-00001-of-00002.bin', 'pytorch_model-00002-of-00002.bin', 'pytorch_model.bin.index.json']: + hf_hub_download(repo_id='DeepFloyd/t5-v1_1-xxl', filename=filename, cache_dir=cache_dir, + force_filename=filename, token=self.hf_token) + + print(cache_dir) + self.tokenizer = AutoTokenizer.from_pretrained(cache_dir) + self.model = T5EncoderModel.from_pretrained(cache_dir, **t5_model_kwargs).eval() + self.model_max_length = model_max_length + + def get_text_embeddings(self, texts): + texts = [self.text_preprocessing(text) for text in texts] + + text_tokens_and_mask = self.tokenizer( + texts, + max_length=self.model_max_length, + padding='max_length', + truncation=True, + return_attention_mask=True, + add_special_tokens=True, + return_tensors='pt' + ) + + text_tokens_and_mask['input_ids'] = text_tokens_and_mask['input_ids'] + text_tokens_and_mask['attention_mask'] = text_tokens_and_mask['attention_mask'] + + with torch.no_grad(): + text_encoder_embs = self.model( + input_ids=text_tokens_and_mask['input_ids'].to(self.device), + attention_mask=text_tokens_and_mask['attention_mask'].to(self.device), + )['last_hidden_state'].detach() + return text_encoder_embs, text_tokens_and_mask['attention_mask'].to(self.device) + + def text_preprocessing(self, text): + if self.use_text_preprocessing: + # The exact text cleaning as was in the training stage: + text = self.clean_caption(text) + text = self.clean_caption(text) + return text + else: + return text.lower().strip() + + @staticmethod + def basic_clean(text): + text = ftfy.fix_text(text) + text = html.unescape(html.unescape(text)) + return text.strip() + + def clean_caption(self, caption): + caption = str(caption) + caption = ul.unquote_plus(caption) + caption = caption.strip().lower() + caption = re.sub('', 'person', caption) + # urls: + caption = re.sub( + r'\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))', # noqa + '', caption) # regex for urls + caption = re.sub( + r'\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))', # noqa + '', caption) # regex for urls + # html: + caption = BeautifulSoup(caption, features='html.parser').text + + # @ + caption = re.sub(r'@[\w\d]+\b', '', caption) + + # 31C0—31EF CJK Strokes + # 31F0—31FF Katakana Phonetic Extensions + # 3200—32FF Enclosed CJK Letters and Months + # 3300—33FF CJK Compatibility + # 3400—4DBF CJK Unified Ideographs Extension A + # 4DC0—4DFF Yijing Hexagram Symbols + # 4E00—9FFF CJK Unified Ideographs + caption = re.sub(r'[\u31c0-\u31ef]+', '', caption) + caption = re.sub(r'[\u31f0-\u31ff]+', '', caption) + caption = re.sub(r'[\u3200-\u32ff]+', '', caption) + caption = re.sub(r'[\u3300-\u33ff]+', '', caption) + caption = re.sub(r'[\u3400-\u4dbf]+', '', caption) + caption = re.sub(r'[\u4dc0-\u4dff]+', '', caption) + caption = re.sub(r'[\u4e00-\u9fff]+', '', caption) + ####################################################### + + # все виды тире / all types of dash --> "-" + caption = re.sub( + r'[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+', # noqa + '-', caption) + + # кавычки к одному стандарту + caption = re.sub(r'[`´«»“”¨]', '"', caption) + caption = re.sub(r'[‘’]', "'", caption) + + # " + caption = re.sub(r'"?', '', caption) + # & + caption = re.sub(r'&', '', caption) + + # ip adresses: + caption = re.sub(r'\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}', ' ', caption) + + # article ids: + caption = re.sub(r'\d:\d\d\s+$', '', caption) + + # \n + caption = re.sub(r'\\n', ' ', caption) + + # "#123" + caption = re.sub(r'#\d{1,3}\b', '', caption) + # "#12345.." + caption = re.sub(r'#\d{5,}\b', '', caption) + # "123456.." + caption = re.sub(r'\b\d{6,}\b', '', caption) + # filenames: + caption = re.sub(r'[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)', '', caption) + + # + caption = re.sub(r'[\"\']{2,}', r'"', caption) # """AUSVERKAUFT""" + caption = re.sub(r'[\.]{2,}', r' ', caption) # """AUSVERKAUFT""" + + caption = re.sub(self.bad_punct_regex, r' ', caption) # ***AUSVERKAUFT***, #AUSVERKAUFT + caption = re.sub(r'\s+\.\s+', r' ', caption) # " . " + + # this-is-my-cute-cat / this_is_my_cute_cat + regex2 = re.compile(r'(?:\-|\_)') + if len(re.findall(regex2, caption)) > 3: + caption = re.sub(regex2, ' ', caption) + + caption = self.basic_clean(caption) + + caption = re.sub(r'\b[a-zA-Z]{1,3}\d{3,15}\b', '', caption) # jc6640 + caption = re.sub(r'\b[a-zA-Z]+\d+[a-zA-Z]+\b', '', caption) # jc6640vc + caption = re.sub(r'\b\d+[a-zA-Z]+\d+\b', '', caption) # 6640vc231 + + caption = re.sub(r'(worldwide\s+)?(free\s+)?shipping', '', caption) + caption = re.sub(r'(free\s)?download(\sfree)?', '', caption) + caption = re.sub(r'\bclick\b\s(?:for|on)\s\w+', '', caption) + caption = re.sub(r'\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?', '', caption) + caption = re.sub(r'\bpage\s+\d+\b', '', caption) + + caption = re.sub(r'\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b', r' ', caption) # j2d1a2a... + + caption = re.sub(r'\b\d+\.?\d*[xх×]\d+\.?\d*\b', '', caption) + + caption = re.sub(r'\b\s+\:\s+', r': ', caption) + caption = re.sub(r'(\D[,\./])\b', r'\1 ', caption) + caption = re.sub(r'\s+', ' ', caption) + + caption.strip() + + caption = re.sub(r'^[\"\']([\w\W]+)[\"\']$', r'\1', caption) + caption = re.sub(r'^[\'\_,\-\:;]', r'', caption) + caption = re.sub(r'[\'\_,\-\:\-\+]$', r'', caption) + caption = re.sub(r'^\.\S+$', '', caption) + + return caption.strip() + +if __name__ == '__main__': + t5 = T5Embedder(device="cuda", cache_dir='./cache_dir', torch_dtype=torch.float) + device = t5.device + prompts = ['I am a test caption', 'Test twice'] + with torch.no_grad(): + caption_embs, emb_masks = t5.get_text_embeddings(prompts) + emb_dict = { + 'caption_feature': caption_embs.float().cpu().data.numpy(), + 'attention_mask': emb_masks.cpu().data.numpy(), + } + import ipdb;ipdb.set_trace() + print() \ No newline at end of file diff --git a/opensora/sample/pipeline_videogen.py b/opensora/sample/pipeline_videogen.py new file mode 100644 index 0000000000000000000000000000000000000000..303414b8c61b8938d600cd81923f63e52e71dc5b --- /dev/null +++ b/opensora/sample/pipeline_videogen.py @@ -0,0 +1,759 @@ +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import html +import inspect +import re +import urllib.parse as ul +from typing import Callable, List, Optional, Tuple, Union + +import torch +import einops +from einops import rearrange +from transformers import T5EncoderModel, T5Tokenizer + +from diffusers.image_processor import VaeImageProcessor +from diffusers.models import AutoencoderKL, Transformer2DModel +from diffusers.schedulers import DPMSolverMultistepScheduler +from diffusers.utils import ( + BACKENDS_MAPPING, + is_bs4_available, + is_ftfy_available, + logging, + replace_example_docstring, +) +from diffusers.utils.torch_utils import randn_tensor +from diffusers.pipelines.pipeline_utils import DiffusionPipeline +from diffusers.utils import BaseOutput +from dataclasses import dataclass + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +if is_bs4_available(): + from bs4 import BeautifulSoup + +if is_ftfy_available(): + import ftfy + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import PixArtAlphaPipeline + + >>> # You can replace the checkpoint id with "PixArt-alpha/PixArt-XL-2-512x512" too. + >>> pipe = PixArtAlphaPipeline.from_pretrained("PixArt-alpha/PixArt-XL-2-1024-MS", torch_dtype=torch.float16) + >>> # Enable memory optimizations. + >>> pipe.enable_model_cpu_offload() + + >>> prompt = "A small cactus with a happy face in the Sahara desert." + >>> image = pipe(prompt).images[0] + ``` +""" + + +@dataclass +class VideoPipelineOutput(BaseOutput): + video: torch.Tensor + + +class VideoGenPipeline(DiffusionPipeline): + r""" + Pipeline for text-to-image generation using PixArt-Alpha. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`T5EncoderModel`]): + Frozen text-encoder. PixArt-Alpha uses + [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel), specifically the + [t5-v1_1-xxl](https://huggingface.co/PixArt-alpha/PixArt-alpha/tree/main/t5-v1_1-xxl) variant. + tokenizer (`T5Tokenizer`): + Tokenizer of class + [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer). + transformer ([`Transformer2DModel`]): + A text conditioned `Transformer2DModel` to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + """ + bad_punct_regex = re.compile( + r"[" + "#®•©™&@·º½¾¿¡§~" + "\)" + "\(" + "\]" + "\[" + "\}" + "\{" + "\|" + "\\" + "\/" + "\*" + r"]{1,}" + ) # noqa + + _optional_components = ["tokenizer", "text_encoder"] + model_cpu_offload_seq = "text_encoder->transformer->vae" + + def __init__( + self, + tokenizer: T5Tokenizer, + text_encoder: T5EncoderModel, + vae: AutoencoderKL, + transformer: Transformer2DModel, + scheduler: DPMSolverMultistepScheduler, + ): + super().__init__() + + self.register_modules( + tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler + ) + + # self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + + # Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/utils.py + def mask_text_embeddings(self, emb, mask): + if emb.shape[0] == 1: + keep_index = mask.sum().item() + return emb[:, :, :keep_index, :], keep_index # 1, 120, 4096 -> 1 7 4096 + else: + masked_feature = emb * mask[:, None, :, None] # 1 120 4096 + return masked_feature, emb.shape[2] + + # Adapted from diffusers.pipelines.deepfloyd_if.pipeline_if.encode_prompt + def encode_prompt( + self, + prompt: Union[str, List[str]], + do_classifier_free_guidance: bool = True, + negative_prompt: str = "", + num_images_per_prompt: int = 1, + device: Optional[torch.device] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + clean_caption: bool = False, + mask_feature: bool = True, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + negative_prompt (`str` or `List[str]`, *optional*): + The prompt not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds` + instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). For + PixArt-Alpha, this should be "". + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + whether to use classifier free guidance or not + num_images_per_prompt (`int`, *optional*, defaults to 1): + number of images that should be generated per prompt + device: (`torch.device`, *optional*): + torch device to place the resulting embeddings on + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. For PixArt-Alpha, it's should be the embeddings of the "" + string. + clean_caption (bool, defaults to `False`): + If `True`, the function will preprocess and clean the provided caption before encoding. + mask_feature: (bool, defaults to `True`): + If `True`, the function will mask the text embeddings. + """ + embeds_initially_provided = prompt_embeds is not None and negative_prompt_embeds is not None + + if device is None: + device = self.text_encoder.device or self._execution_device + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # See Section 3.1. of the paper. + max_length = 300 + + if prompt_embeds is None: + prompt = self._text_preprocessing(prompt, clean_caption=clean_caption) + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_length, + truncation=True, + return_attention_mask=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_length - 1: -1]) + logger.warning( + "The following part of your input was truncated because the model can only handle sequences up to" + f" {max_length} tokens: {removed_text}" + ) + + attention_mask = text_inputs.attention_mask.to(device) + prompt_embeds_attention_mask = attention_mask + + prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask) + prompt_embeds = prompt_embeds[0] + else: + prompt_embeds_attention_mask = torch.ones_like(prompt_embeds) + + if self.text_encoder is not None: + dtype = self.text_encoder.dtype + elif self.transformer is not None: + dtype = self.transformer.dtype + else: + dtype = None + + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + prompt_embeds_attention_mask = prompt_embeds_attention_mask.view(bs_embed, -1) + prompt_embeds_attention_mask = prompt_embeds_attention_mask.repeat(num_images_per_prompt, 1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens = [negative_prompt] * batch_size + uncond_tokens = self._text_preprocessing(uncond_tokens, clean_caption=clean_caption) + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_attention_mask=True, + add_special_tokens=True, + return_tensors="pt", + ) + attention_mask = uncond_input.attention_mask.to(device) + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + else: + negative_prompt_embeds = None + + # print(prompt_embeds.shape) # 1 120 4096 + # print(negative_prompt_embeds.shape) # 1 120 4096 + + # Perform additional masking. + if mask_feature and not embeds_initially_provided: + prompt_embeds = prompt_embeds.unsqueeze(1) + masked_prompt_embeds, keep_indices = self.mask_text_embeddings(prompt_embeds, prompt_embeds_attention_mask) + masked_prompt_embeds = masked_prompt_embeds.squeeze(1) + masked_negative_prompt_embeds = ( + negative_prompt_embeds[:, :keep_indices, :] if negative_prompt_embeds is not None else None + ) + + # import torch.nn.functional as F + + # padding = (0, 0, 0, 113) # (左, 右, 下, 上) + # masked_prompt_embeds_ = F.pad(masked_prompt_embeds, padding, "constant", 0) + # masked_negative_prompt_embeds_ = F.pad(masked_negative_prompt_embeds, padding, "constant", 0) + + # print(masked_prompt_embeds == masked_prompt_embeds_[:, :masked_negative_prompt_embeds.shape[1], ...]) + + return masked_prompt_embeds, masked_negative_prompt_embeds + # return masked_prompt_embeds_, masked_negative_prompt_embeds_ + + return prompt_embeds, negative_prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + height, + width, + negative_prompt, + callback_steps, + prompt_embeds=None, + negative_prompt_embeds=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._text_preprocessing + def _text_preprocessing(self, text, clean_caption=False): + if clean_caption and not is_bs4_available(): + logger.warn(BACKENDS_MAPPING["bs4"][-1].format("Setting `clean_caption=True`")) + logger.warn("Setting `clean_caption` to False...") + clean_caption = False + + if clean_caption and not is_ftfy_available(): + logger.warn(BACKENDS_MAPPING["ftfy"][-1].format("Setting `clean_caption=True`")) + logger.warn("Setting `clean_caption` to False...") + clean_caption = False + + if not isinstance(text, (tuple, list)): + text = [text] + + def process(text: str): + if clean_caption: + text = self._clean_caption(text) + text = self._clean_caption(text) + else: + text = text.lower().strip() + return text + + return [process(t) for t in text] + + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._clean_caption + def _clean_caption(self, caption): + caption = str(caption) + caption = ul.unquote_plus(caption) + caption = caption.strip().lower() + caption = re.sub("", "person", caption) + # urls: + caption = re.sub( + r"\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", + # noqa + "", + caption, + ) # regex for urls + caption = re.sub( + r"\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", + # noqa + "", + caption, + ) # regex for urls + # html: + caption = BeautifulSoup(caption, features="html.parser").text + + # @ + caption = re.sub(r"@[\w\d]+\b", "", caption) + + # 31C0—31EF CJK Strokes + # 31F0—31FF Katakana Phonetic Extensions + # 3200—32FF Enclosed CJK Letters and Months + # 3300—33FF CJK Compatibility + # 3400—4DBF CJK Unified Ideographs Extension A + # 4DC0—4DFF Yijing Hexagram Symbols + # 4E00—9FFF CJK Unified Ideographs + caption = re.sub(r"[\u31c0-\u31ef]+", "", caption) + caption = re.sub(r"[\u31f0-\u31ff]+", "", caption) + caption = re.sub(r"[\u3200-\u32ff]+", "", caption) + caption = re.sub(r"[\u3300-\u33ff]+", "", caption) + caption = re.sub(r"[\u3400-\u4dbf]+", "", caption) + caption = re.sub(r"[\u4dc0-\u4dff]+", "", caption) + caption = re.sub(r"[\u4e00-\u9fff]+", "", caption) + ####################################################### + + # все виды тире / all types of dash --> "-" + caption = re.sub( + r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+", + # noqa + "-", + caption, + ) + + # кавычки к одному стандарту + caption = re.sub(r"[`´«»“”¨]", '"', caption) + caption = re.sub(r"[‘’]", "'", caption) + + # " + caption = re.sub(r""?", "", caption) + # & + caption = re.sub(r"&", "", caption) + + # ip adresses: + caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption) + + # article ids: + caption = re.sub(r"\d:\d\d\s+$", "", caption) + + # \n + caption = re.sub(r"\\n", " ", caption) + + # "#123" + caption = re.sub(r"#\d{1,3}\b", "", caption) + # "#12345.." + caption = re.sub(r"#\d{5,}\b", "", caption) + # "123456.." + caption = re.sub(r"\b\d{6,}\b", "", caption) + # filenames: + caption = re.sub(r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", caption) + + # + caption = re.sub(r"[\"\']{2,}", r'"', caption) # """AUSVERKAUFT""" + caption = re.sub(r"[\.]{2,}", r" ", caption) # """AUSVERKAUFT""" + + caption = re.sub(self.bad_punct_regex, r" ", caption) # ***AUSVERKAUFT***, #AUSVERKAUFT + caption = re.sub(r"\s+\.\s+", r" ", caption) # " . " + + # this-is-my-cute-cat / this_is_my_cute_cat + regex2 = re.compile(r"(?:\-|\_)") + if len(re.findall(regex2, caption)) > 3: + caption = re.sub(regex2, " ", caption) + + caption = ftfy.fix_text(caption) + caption = html.unescape(html.unescape(caption)) + + caption = re.sub(r"\b[a-zA-Z]{1,3}\d{3,15}\b", "", caption) # jc6640 + caption = re.sub(r"\b[a-zA-Z]+\d+[a-zA-Z]+\b", "", caption) # jc6640vc + caption = re.sub(r"\b\d+[a-zA-Z]+\d+\b", "", caption) # 6640vc231 + + caption = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", caption) + caption = re.sub(r"(free\s)?download(\sfree)?", "", caption) + caption = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", caption) + caption = re.sub(r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", caption) + caption = re.sub(r"\bpage\s+\d+\b", "", caption) + + caption = re.sub(r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", caption) # j2d1a2a... + + caption = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", caption) + + caption = re.sub(r"\b\s+\:\s+", r": ", caption) + caption = re.sub(r"(\D[,\./])\b", r"\1 ", caption) + caption = re.sub(r"\s+", " ", caption) + + caption.strip() + + caption = re.sub(r"^[\"\']([\w\W]+)[\"\']$", r"\1", caption) + caption = re.sub(r"^[\'\_,\-\:;]", r"", caption) + caption = re.sub(r"[\'\_,\-\:\-\+]$", r"", caption) + caption = re.sub(r"^\.\S+$", "", caption) + + return caption.strip() + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def prepare_latents(self, batch_size, num_channels_latents, video_length, height, width, dtype, device, generator, + latents=None): + shape = ( + batch_size, num_channels_latents, video_length, self.vae.latent_size[0], self.vae.latent_size[1]) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + negative_prompt: str = "", + num_inference_steps: int = 20, + timesteps: List[int] = None, + guidance_scale: float = 4.5, + num_images_per_prompt: Optional[int] = 1, + video_length: Optional[int] = None, + height: Optional[int] = None, + width: Optional[int] = None, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: int = 1, + clean_caption: bool = True, + mask_feature: bool = True, + enable_temporal_attentions: bool = True, + ) -> Union[VideoPipelineOutput, Tuple]: + """ + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + num_inference_steps (`int`, *optional*, defaults to 100): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`List[int]`, *optional*): + Custom timesteps to use for the denoising process. If not defined, equal spaced `num_inference_steps` + timesteps are used. Must be in descending order. + guidance_scale (`float`, *optional*, defaults to 7.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + height (`int`, *optional*, defaults to self.unet.config.sample_size): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to self.unet.config.sample_size): + The width in pixels of the generated image. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. For PixArt-Alpha this negative prompt should be "". If not + provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.IFPipelineOutput`] instead of a plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + clean_caption (`bool`, *optional*, defaults to `True`): + Whether or not to clean the caption before creating embeddings. Requires `beautifulsoup4` and `ftfy` to + be installed. If the dependencies are not installed, the embeddings will be created from the raw + prompt. + mask_feature (`bool` defaults to `True`): If set to `True`, the text embeddings will be masked. + + Examples: + + Returns: + [`~pipelines.ImagePipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.ImagePipelineOutput`] is returned, otherwise a `tuple` is + returned where the first element is a list with the generated images + """ + # 1. Check inputs. Raise error if not correct + # height = height or self.transformer.config.sample_size * self.vae_scale_factor + # width = width or self.transformer.config.sample_size * self.vae_scale_factor + self.check_inputs( + prompt, height, width, negative_prompt, callback_steps, prompt_embeds, negative_prompt_embeds + ) + + # 2. Default height and width to transformer + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self.text_encoder.device or self._execution_device + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + do_classifier_free_guidance, + negative_prompt=negative_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + clean_caption=clean_caption, + mask_feature=mask_feature, + ) + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 5. Prepare latents. + latent_channels = self.transformer.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + latent_channels, + video_length, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 6.1 Prepare micro-conditions. + added_cond_kwargs = {"resolution": None, "aspect_ratio": None} + # if self.transformer.config.sample_size == 128: + # resolution = torch.tensor([height, width]).repeat(batch_size * num_images_per_prompt, 1) + # aspect_ratio = torch.tensor([float(height / width)]).repeat(batch_size * num_images_per_prompt, 1) + # resolution = resolution.to(dtype=prompt_embeds.dtype, device=device) + # aspect_ratio = aspect_ratio.to(dtype=prompt_embeds.dtype, device=device) + # added_cond_kwargs = {"resolution": resolution, "aspect_ratio": aspect_ratio} + + # 7. Denoising loop + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + current_timestep = t + if not torch.is_tensor(current_timestep): + # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can + # This would be a good case for the `match` statement (Python 3.10+) + is_mps = latent_model_input.device.type == "mps" + if isinstance(current_timestep, float): + dtype = torch.float32 if is_mps else torch.float64 + else: + dtype = torch.int32 if is_mps else torch.int64 + current_timestep = torch.tensor([current_timestep], dtype=dtype, device=latent_model_input.device) + elif len(current_timestep.shape) == 0: + current_timestep = current_timestep[None].to(latent_model_input.device) + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + current_timestep = current_timestep.expand(latent_model_input.shape[0]) + + # predict noise model_output + noise_pred = self.transformer( + latent_model_input, + encoder_hidden_states=prompt_embeds, + timestep=current_timestep, + added_cond_kwargs=added_cond_kwargs, + enable_temporal_attentions=enable_temporal_attentions, + return_dict=False, + )[0] + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # learned sigma + if self.transformer.config.out_channels // 2 == latent_channels: + noise_pred = noise_pred.chunk(2, dim=1)[0] + else: + noise_pred = noise_pred + + # compute previous image: x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + + if not output_type == 'latents': + video = self.decode_latents(latents) + else: + video = latents + return VideoPipelineOutput(video=video) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (video,) + + return VideoPipelineOutput(video=video) + + def decode_latents(self, latents): + video = self.vae.decode(latents) + # video = self.vae.decode(latents / 0.18215) + # video = rearrange(video, 'b c t h w -> b t c h w').contiguous() + video = ((video / 2.0 + 0.5).clamp(0, 1) * 255).to(dtype=torch.uint8).cpu().permute(0, 1, 3, 4, 2).contiguous() + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 + return video diff --git a/opensora/sample/sample.py b/opensora/sample/sample.py new file mode 100644 index 0000000000000000000000000000000000000000..84c7466aa6b0cc9fd6a22c3566ddee43cbbf46fa --- /dev/null +++ b/opensora/sample/sample.py @@ -0,0 +1,128 @@ +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +""" +Sample new images from a pre-trained Latte. +""" +import os +import sys + +from accelerate import Accelerator +from tqdm import tqdm + +from opensora.dataset import ae_denorm +from opensora.models.ae import ae_channel_config, getae, ae_stride_config +from opensora.models.ae.videobase import CausalVQVAEModelWrapper +from opensora.models.diffusion import Diffusion_models +from opensora.models.diffusion.diffusion import create_diffusion_T as create_diffusion +from opensora.models.diffusion.latte.modeling_latte import Latte +from opensora.utils.utils import find_model + +import torch +import argparse + +from einops import rearrange +import imageio + +torch.backends.cuda.matmul.allow_tf32 = True +torch.backends.cudnn.allow_tf32 = True + +def main(args): + # Setup PyTorch: + # torch.manual_seed(args.seed) + torch.set_grad_enabled(False) + assert torch.cuda.is_available(), "Training currently requires at least one GPU." + + # Setup accelerator: + accelerator = Accelerator(mixed_precision=args.mixed_precision) + device = accelerator.device + + using_cfg = args.cfg_scale > 1.0 + + # Load model: + latent_size = (args.image_size // ae_stride_config[args.ae][1], args.image_size // ae_stride_config[args.ae][2]) + args.latent_size = latent_size + model = Latte.from_pretrained(args.ckpt, subfolder="model").to(device) + + model.eval() # important! + + model = accelerator.prepare(model) + + diffusion = create_diffusion(str(args.num_sampling_steps)) + ae = getae(args).to(device) + if isinstance(ae, CausalVQVAEModelWrapper): + video_length = args.num_frames // ae_stride_config[args.ae][0] + 1 + else: + video_length = args.num_frames // ae_stride_config[args.ae][0] + bar = tqdm(range(args.num_sample)) + for i in bar: + # Create sampling noise: + z = torch.randn(1, model.module.in_channels, video_length, latent_size[0], latent_size[1], device=device) + + # Setup classifier-free guidance: + if using_cfg and args.train_classcondition: + z = torch.cat([z, z], 0) + y = torch.randint(0, args.num_classes, (1,), device=device) + cls_id = str(int(y.detach().cpu())) + y_null = torch.tensor([args.num_classes] * 1, device=device) + y = torch.cat([y, y_null], dim=0) + model_kwargs = dict(class_labels=y, cfg_scale=args.cfg_scale) + sample_fn = model.module.forward_with_cfg + else: + if args.train_classcondition: + sample_fn = model.forward + y = torch.randint(0, args.num_classes, (1,), device=device) + cls_id = str(int(y.detach().cpu())) + model_kwargs = dict(class_labels=y) + else: + sample_fn = model.forward + model_kwargs = dict(class_labels=None) + + # Sample images: + if args.sample_method == 'ddim': + samples = diffusion.ddim_sample_loop( + sample_fn, z.shape, z, clip_denoised=False, model_kwargs=model_kwargs, progress=True, device=device + ) + elif args.sample_method == 'ddpm': + samples = diffusion.p_sample_loop( + sample_fn, z.shape, z, clip_denoised=False, model_kwargs=model_kwargs, progress=True, device=device + ) + + with torch.no_grad(): + samples = ae.decode(samples) + # Save and display images: + + if not os.path.exists(args.save_video_path): + os.makedirs(args.save_video_path) + + video_ = (ae_denorm[args.ae](samples[0]) * 255).add_(0.5).clamp_(0, 255).to(dtype=torch.uint8).cpu().permute(0, 2, 3, 1).contiguous() + if args.train_classcondition: + video_save_path = os.path.join(args.save_video_path, f"sample_{i:03d}_cls" + str(cls_id) + '.mp4') + else: + video_save_path = os.path.join(args.save_video_path, f"sample_{i:03d}" + '.mp4') + print(video_save_path) + imageio.mimwrite(video_save_path, video_, fps=args.fps, quality=9) + print('save path {}'.format(args.save_video_path)) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--ckpt", type=str, default="") + parser.add_argument("--model", type=str, default='Latte-XL/122') + parser.add_argument("--ae", type=str, default='stabilityai/sd-vae-ft-mse') + parser.add_argument("--save_video_path", type=str, default="./sample_videos/") + parser.add_argument("--fps", type=int, default=10) + parser.add_argument("--num_classes", type=int, default=101) + parser.add_argument("--num_frames", type=int, default=16) + parser.add_argument("--image_size", type=int, default=256) + parser.add_argument("--train_classcondition", action="store_true") + parser.add_argument("--num_sampling_steps", type=int, default=250) + parser.add_argument("--num_sample", type=int, default=1) + parser.add_argument("--cfg_scale", type=float, default=1.0) + parser.add_argument("--sample_method", type=str, default='ddpm') + parser.add_argument("--mixed_precision", type=str, default=None, choices=[None, "fp16", "bf16"]) + parser.add_argument("--attention_mode", type=str, choices=['xformers', 'math', 'flash'], default="math") + args = parser.parse_args() + main(args) diff --git a/opensora/sample/sample_t2v.py b/opensora/sample/sample_t2v.py new file mode 100644 index 0000000000000000000000000000000000000000..045e4809719fde136f95f187c478c1a9246f3744 --- /dev/null +++ b/opensora/sample/sample_t2v.py @@ -0,0 +1,161 @@ +import math +import os +import torch +import argparse +import torchvision + +from diffusers.schedulers import (DDIMScheduler, DDPMScheduler, PNDMScheduler, + EulerDiscreteScheduler, DPMSolverMultistepScheduler, + HeunDiscreteScheduler, EulerAncestralDiscreteScheduler, + DEISMultistepScheduler, KDPM2AncestralDiscreteScheduler) +from diffusers.schedulers.scheduling_dpmsolver_singlestep import DPMSolverSinglestepScheduler +from diffusers.models import AutoencoderKL, AutoencoderKLTemporalDecoder +from omegaconf import OmegaConf +from torchvision.utils import save_image +from transformers import T5EncoderModel, T5Tokenizer, AutoTokenizer + +import os, sys + +from opensora.models.ae import ae_stride_config, getae, getae_wrapper +from opensora.models.ae.videobase import CausalVQVAEModelWrapper, CausalVAEModelWrapper +from opensora.models.diffusion.latte.modeling_latte import LatteT2V +from opensora.models.text_encoder import get_text_enc +from opensora.utils.utils import save_video_grid + +sys.path.append(os.path.split(sys.path[0])[0]) +from pipeline_videogen import VideoGenPipeline + +import imageio + + +def main(args): + # torch.manual_seed(args.seed) + torch.set_grad_enabled(False) + device = "cuda" if torch.cuda.is_available() else "cpu" + + vae = getae_wrapper(args.ae)(args.model_path, subfolder="vae", cache_dir='cache_dir').to(device, dtype=torch.float16) + if args.enable_tiling: + vae.vae.enable_tiling() + vae.vae.tile_overlap_factor = args.tile_overlap_factor + + # Load model: + transformer_model = LatteT2V.from_pretrained(args.model_path, subfolder=args.version, cache_dir="cache_dir", torch_dtype=torch.float16).to(device) + transformer_model.force_images = args.force_images + tokenizer = T5Tokenizer.from_pretrained(args.text_encoder_name, cache_dir="cache_dir") + text_encoder = T5EncoderModel.from_pretrained(args.text_encoder_name, cache_dir="cache_dir", torch_dtype=torch.float16).to(device) + + video_length, image_size = transformer_model.config.video_length, int(args.version.split('x')[1]) + latent_size = (image_size // ae_stride_config[args.ae][1], image_size // ae_stride_config[args.ae][2]) + vae.latent_size = latent_size + if args.force_images: + video_length = 1 + ext = 'jpg' + else: + ext = 'mp4' + + # set eval mode + transformer_model.eval() + vae.eval() + text_encoder.eval() + + if args.sample_method == 'DDIM': ######### + scheduler = DDIMScheduler() + elif args.sample_method == 'EulerDiscrete': + scheduler = EulerDiscreteScheduler() + elif args.sample_method == 'DDPM': ############# + scheduler = DDPMScheduler() + elif args.sample_method == 'DPMSolverMultistep': + scheduler = DPMSolverMultistepScheduler() + elif args.sample_method == 'DPMSolverSinglestep': + scheduler = DPMSolverSinglestepScheduler() + elif args.sample_method == 'PNDM': + scheduler = PNDMScheduler() + elif args.sample_method == 'HeunDiscrete': ######## + scheduler = HeunDiscreteScheduler() + elif args.sample_method == 'EulerAncestralDiscrete': + scheduler = EulerAncestralDiscreteScheduler() + elif args.sample_method == 'DEISMultistep': + scheduler = DEISMultistepScheduler() + elif args.sample_method == 'KDPM2AncestralDiscrete': ######### + scheduler = KDPM2AncestralDiscreteScheduler() + print('videogen_pipeline', device) + videogen_pipeline = VideoGenPipeline(vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + scheduler=scheduler, + transformer=transformer_model).to(device=device) + # videogen_pipeline.enable_xformers_memory_efficient_attention() + + if not os.path.exists(args.save_img_path): + os.makedirs(args.save_img_path) + + video_grids = [] + if not isinstance(args.text_prompt, list): + args.text_prompt = [args.text_prompt] + if len(args.text_prompt) == 1 and args.text_prompt[0].endswith('txt'): + text_prompt = open(args.text_prompt[0], 'r').readlines() + args.text_prompt = [i.strip() for i in text_prompt] + for prompt in args.text_prompt: + print('Processing the ({}) prompt'.format(prompt)) + videos = videogen_pipeline(prompt, + video_length=video_length, + height=image_size, + width=image_size, + num_inference_steps=args.num_sampling_steps, + guidance_scale=args.guidance_scale, + enable_temporal_attentions=not args.force_images, + num_images_per_prompt=1, + mask_feature=True, + ).video + try: + if args.force_images: + videos = videos[:, 0].permute(0, 3, 1, 2) # b t h w c -> b c h w + save_image(videos / 255.0, os.path.join(args.save_img_path, + prompt.replace(' ', '_')[:100] + f'{args.sample_method}_gs{args.guidance_scale}_s{args.num_sampling_steps}.{ext}'), + nrow=1, normalize=True, value_range=(0, 1)) # t c h w + + else: + imageio.mimwrite( + os.path.join( + args.save_img_path, + prompt.replace(' ', '_')[:100] + f'{args.sample_method}_gs{args.guidance_scale}_s{args.num_sampling_steps}.{ext}' + ), videos[0], + fps=args.fps, quality=9) # highest quality is 10, lowest is 0 + except: + print('Error when saving {}'.format(prompt)) + video_grids.append(videos) + video_grids = torch.cat(video_grids, dim=0) + + + # torchvision.io.write_video(args.save_img_path + '_%04d' % args.run_time + '-.mp4', video_grids, fps=6) + if args.force_images: + save_image(video_grids / 255.0, os.path.join(args.save_img_path, f'{args.sample_method}_gs{args.guidance_scale}_s{args.num_sampling_steps}.{ext}'), + nrow=math.ceil(math.sqrt(len(video_grids))), normalize=True, value_range=(0, 1)) + else: + video_grids = save_video_grid(video_grids) + imageio.mimwrite(os.path.join(args.save_img_path, f'{args.sample_method}_gs{args.guidance_scale}_s{args.num_sampling_steps}.{ext}'), video_grids, fps=args.fps, quality=9) + + print('save path {}'.format(args.save_img_path)) + + # save_videos_grid(video, f"./{prompt}.gif") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--model_path", type=str, default='LanguageBind/Open-Sora-Plan-v1.0.0') + parser.add_argument("--version", type=str, default='65x512x512', choices=['65x512x512', '65x256x256', '17x256x256']) + parser.add_argument("--ae", type=str, default='CausalVAEModel_4x8x8') + parser.add_argument("--text_encoder_name", type=str, default='DeepFloyd/t5-v1_1-xxl') + parser.add_argument("--save_img_path", type=str, default="./sample_videos/t2v") + parser.add_argument("--guidance_scale", type=float, default=7.5) + parser.add_argument("--sample_method", type=str, default="PNDM") + parser.add_argument("--num_sampling_steps", type=int, default=50) + parser.add_argument("--fps", type=int, default=24) + parser.add_argument("--run_time", type=int, default=0) + parser.add_argument("--text_prompt", nargs='+') + parser.add_argument('--force_images', action='store_true') + parser.add_argument('--tile_overlap_factor', type=float, default=0.25) + parser.add_argument('--enable_tiling', action='store_true') + args = parser.parse_args() + + main(args) \ No newline at end of file diff --git a/opensora/sample/transport_sample.py b/opensora/sample/transport_sample.py new file mode 100644 index 0000000000000000000000000000000000000000..b88b552061e2c6a023ece46e0d586c0248470099 --- /dev/null +++ b/opensora/sample/transport_sample.py @@ -0,0 +1,203 @@ +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +""" +Sample new images from a pre-trained SiT. +""" +import os +import sys + +from opensora.dataset import ae_denorm +from opensora.models.ae import ae_channel_config, getae, ae_stride_config +from opensora.models.diffusion import Diffusion_models +from opensora.models.diffusion.transport import create_transport, Sampler +from opensora.utils.utils import find_model + +import torch +import argparse + +from einops import rearrange +import imageio + +torch.backends.cuda.matmul.allow_tf32 = True +torch.backends.cudnn.allow_tf32 = True + + + +def main(mode, args): + # Setup PyTorch: + # torch.manual_seed(args.seed) + torch.set_grad_enabled(False) + device = "cuda" if torch.cuda.is_available() else "cpu" + + using_cfg = args.cfg_scale > 1.0 + + # Load model: + latent_size = args.image_size // ae_stride_config[args.ae][1] + args.latent_size = latent_size + model = Diffusion_models[args.model]( + input_size=latent_size, + num_classes=args.num_classes, + in_channels=ae_channel_config[args.ae], + extras=args.extras + ).to(device) + + if args.use_compile: + model = torch.compile(model) + + # a pre-trained model or load a custom Latte checkpoint from train.py: + ckpt_path = args.ckpt + state_dict = find_model(ckpt_path) + model.load_state_dict(state_dict) + + model.eval() # important! + transport = create_transport( + args.path_type, + args.prediction, + args.loss_weight, + args.train_eps, + args.sample_eps + ) + sampler = Sampler(transport) + if mode == "ODE": + if args.likelihood: + assert args.cfg_scale == 1, "Likelihood is incompatible with guidance" + sample_fn = sampler.sample_ode_likelihood( + sampling_method=args.sampling_method, + num_steps=args.num_sampling_steps, + atol=args.atol, + rtol=args.rtol, + ) + else: + sample_fn = sampler.sample_ode( + sampling_method=args.sampling_method, + num_steps=args.num_sampling_steps, + atol=args.atol, + rtol=args.rtol, + reverse=args.reverse + ) + elif mode == "SDE": + sample_fn = sampler.sample_sde( + sampling_method=args.sampling_method, + diffusion_form=args.diffusion_form, + diffusion_norm=args.diffusion_norm, + last_step=args.last_step, + last_step_size=args.last_step_size, + num_steps=args.num_sampling_steps, + ) + + ae = getae(args).to(device) + + if args.use_fp16: + print('WARNING: using half percision for inferencing!') + ae.to(dtype=torch.float16) + model.to(dtype=torch.float16) + + # Labels to condition the model with (feel free to change): + + # Create sampling noise: + if args.use_fp16: + z = torch.randn(1, args.num_frames // ae_stride_config[args.ae][0], model.in_channels, latent_size, latent_size, dtype=torch.float16, device=device) # b c f h w + else: + z = torch.randn(1, args.num_frames // ae_stride_config[args.ae][0], model.in_channels, latent_size, latent_size, device=device) + + # Setup classifier-free guidance: + if using_cfg: + z = torch.cat([z, z], 0) + y = torch.randint(0, args.num_classes, (1,), device=device) + y_null = torch.tensor([args.num_classes] * 1, device=device) + y = torch.cat([y, y_null], dim=0) + model_kwargs = dict(y=y, cfg_scale=args.cfg_scale, use_fp16=args.use_fp16) + forward_fn = model.forward_with_cfg + else: + forward_fn = model.forward + model_kwargs = dict(y=None, use_fp16=args.use_fp16) + + # Sample images: + samples = sample_fn(z, forward_fn, **model_kwargs)[-1] + + if args.use_fp16: + samples = samples.to(dtype=torch.float16) + samples = ae.decode(samples) + + # Save and display images: + if not os.path.exists(args.save_video_path): + os.makedirs(args.save_video_path) + + + video_ = (ae_denorm[args.ae](samples[0]) * 255).add_(0.5).clamp_(0, 255).to(dtype=torch.uint8).cpu().permute(0, 2, 3, 1).contiguous() + video_save_path = os.path.join(args.save_video_path, 'sample' + '.mp4') + print(video_save_path) + imageio.mimwrite(video_save_path, video_, fps=args.fps, quality=9) + print('save path {}'.format(args.save_video_path)) + + +def none_or_str(value): + if value == 'None': + return None + return value + +def parse_transport_args(parser): + group = parser.add_argument_group("Transport arguments") + group.add_argument("--path-type", type=str, default="Linear", choices=["Linear", "GVP", "VP"]) + group.add_argument("--prediction", type=str, default="velocity", choices=["velocity", "score", "noise"]) + group.add_argument("--loss-weight", type=none_or_str, default=None, choices=[None, "velocity", "likelihood"]) + group.add_argument("--sample-eps", type=float) + group.add_argument("--train-eps", type=float) + +def parse_ode_args(parser): + group = parser.add_argument_group("ODE arguments") + group.add_argument("--sampling-method", type=str, default="dopri5", help="blackbox ODE solver methods; for full list check https://github.com/rtqichen/torchdiffeq") + group.add_argument("--atol", type=float, default=1e-6, help="Absolute tolerance") + group.add_argument("--rtol", type=float, default=1e-3, help="Relative tolerance") + group.add_argument("--reverse", action="store_true") + group.add_argument("--likelihood", action="store_true") + +def parse_sde_args(parser): + group = parser.add_argument_group("SDE arguments") + group.add_argument("--sampling-method", type=str, default="Euler", choices=["Euler", "Heun"]) + group.add_argument("--diffusion-form", type=str, default="sigma", \ + choices=["constant", "SBDM", "sigma", "linear", "decreasing", "increasing-decreasing"],\ + help="form of diffusion coefficient in the SDE") + group.add_argument("--diffusion-norm", type=float, default=1.0) + group.add_argument("--last-step", type=none_or_str, default="Mean", choices=[None, "Mean", "Tweedie", "Euler"],\ + help="form of last step taken in the SDE") + group.add_argument("--last-step-size", type=float, default=0.04, \ + help="size of the last step taken") + +if __name__ == "__main__": + if len(sys.argv) < 2: + print("Usage: program.py [options]") + sys.exit(1) + + mode = sys.argv[1] + + assert mode[:2] != "--", "Usage: program.py [options]" + assert mode in ["ODE", "SDE"], "Invalid mode. Please choose 'ODE' or 'SDE'" + + parser = argparse.ArgumentParser() + parser.add_argument("--ckpt", type=str, default="") + parser.add_argument("--model", type=str, default='Latte-XL/122') + parser.add_argument("--ae", type=str, default='stabilityai/sd-vae-ft-mse') + parser.add_argument("--save-video-path", type=str, default="./sample_videos/") + parser.add_argument("--fps", type=int, default=10) + parser.add_argument("--num-classes", type=int, default=101) + parser.add_argument("--num-frames", type=int, default=16) + parser.add_argument("--image-size", type=int, default=256, choices=[256, 512]) + parser.add_argument("--extras", type=int, default=1) + parser.add_argument("--num-sampling-steps", type=int, default=250) + parser.add_argument("--cfg-scale", type=float, default=1.0) + parser.add_argument("--use-fp16", action="store_true") + parser.add_argument("--use-compile", action="store_true") + parser.add_argument("--sample-method", type=str, default='ddpm') + + parse_transport_args(parser) + if mode == "ODE": + parse_ode_args(parser) + # Further processing for ODE + elif mode == "SDE": + parse_sde_args(parser) + # Further processing for SDE + + args = parser.parse_known_args()[0] + main(mode, args) diff --git a/opensora/serve/gradio_utils.py b/opensora/serve/gradio_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..5b6a898770970dcc3c0b7a2e419a905d946fce5f --- /dev/null +++ b/opensora/serve/gradio_utils.py @@ -0,0 +1,65 @@ +import random + +import torch + + +def set_env(seed=0): + torch.manual_seed(seed) + torch.set_grad_enabled(False) + + +def randomize_seed_fn(seed: int, randomize_seed: bool) -> int: + if randomize_seed: + seed = random.randint(0, 203279) + return seed + +title_markdown = (""" +
+ +
+""") +DESCRIPTION = """ +# Open-Sora-Plan v1.0.0 +## If Open-Sora-Plan is helpful, please help to ✨ the [Github Repo](https://github.com/PKU-YuanGroup/Open-Sora-Plan) and recommend it to your friends 😊' +#### [Open-Sora-Plan v1.0.0](https://github.com/PKU-YuanGroup/Open-Sora-Plan) is a transformer-based text-to-video diffusion system trained on text embeddings from T5. +#### This demo is only trained on 40k videos, when creating videos, please be aware that it has the potential to generate harmful videos. For more details read our [report](). +#### Image generation is typically 50 steps, video generation maybe 250 steps will yield good results, but this may take 2-3 minutes. +#### Feel free to enjoy the examples. +#### English prompts ONLY; 提示词仅限英文 +#### +""" + +#
+#
+#
+# +# +#
+#
+# """) + +block_css = """ +#buttons button { + min-width: min(120px,100%); +} +""" + + +examples = [ + ["A quiet beach at dawn, the waves gently lapping at the shore and the sky painted in pastel hues.", 50, 10.0], + ["A quiet beach at dawn, the waves softly lapping at the shore, pink and orange hues painting the sky, offering a moment of solitude and reflection.", 50, 10.0], + ["The majestic beauty of a waterfall cascading down a cliff into a serene lake.", 50, 10.0], + ["Sunset over the sea.", 50, 10.0], + ["a cat wearing sunglasses and working as a lifeguard at pool.", 50, 10.0], + ["Slow pan upward of blazing oak fire in an indoor fireplace.", 50, 10.0], + ["Yellow and black tropical fish dart through the sea.", 50, 10.0], + ["a serene winter scene in a forest. The forest is blanketed in a thick layer of snow, which has settled on the branches of the trees, creating a canopy of white. The trees, a mix of evergreens and deciduous, stand tall and silent, their forms partially obscured by the snow. The ground is a uniform white, with no visible tracks or signs of human activity. The sun is low in the sky, casting a warm glow that contrasts with the cool tones of the snow. The light filters through the trees, creating a soft, diffused illumination that highlights the texture of the snow and the contours of the trees. The overall style of the scene is naturalistic, with a focus on the tranquility and beauty of the winter landscape.", 50, 10.0], + ["a dynamic interaction between the ocean and a large rock. The rock, with its rough texture and jagged edges, is partially submerged in the water, suggesting it is a natural feature of the coastline. The water around the rock is in motion, with white foam and waves crashing against the rock, indicating the force of the ocean's movement. The background is a vast expanse of the ocean, with small ripples and waves, suggesting a moderate sea state. The overall style of the scene is a realistic depiction of a natural landscape, with a focus on the interplay between the rock and the water.", 50, 10.0], + ["A serene waterfall cascading down moss-covered rocks, its soothing sound creating a harmonious symphony with nature.", 50, 10.0], + ["A soaring drone footage captures the majestic beauty of a coastal cliff, its red and yellow stratified rock faces rich in color and against the vibrant turquoise of the sea. Seabirds can be seen taking flight around the cliff's precipices. As the drone slowly moves from different angles, the changing sunlight casts shifting shadows that highlight the rugged textures of the cliff and the surrounding calm sea. The water gently laps at the rock base and the greenery that clings to the top of the cliff, and the scene gives a sense of peaceful isolation at the fringes of the ocean. The video captures the essence of pristine natural beauty untouched by human structures.", 50, 10.0], + ["The video captures the majestic beauty of a waterfall cascading down a cliff into a serene lake. The waterfall, with its powerful flow, is the central focus of the video. The surrounding landscape is lush and green, with trees and foliage adding to the natural beauty of the scene. The camera angle provides a bird's eye view of the waterfall, allowing viewers to appreciate the full height and grandeur of the waterfall. The video is a stunning representation of nature's power and beauty.", 50, 10.0], + ["A vibrant scene of a snowy mountain landscape. The sky is filled with a multitude of colorful hot air balloons, each floating at different heights, creating a dynamic and lively atmosphere. The balloons are scattered across the sky, some closer to the viewer, others further away, adding depth to the scene. Below, the mountainous terrain is blanketed in a thick layer of snow, with a few patches of bare earth visible here and there. The snow-covered mountains provide a stark contrast to the colorful balloons, enhancing the visual appeal of the scene.", 50, 10.0], + ["A serene underwater scene featuring a sea turtle swimming through a coral reef. The turtle, with its greenish-brown shell, is the main focus of the video, swimming gracefully towards the right side of the frame. The coral reef, teeming with life, is visible in the background, providing a vibrant and colorful backdrop to the turtle's journey. Several small fish, darting around the turtle, add a sense of movement and dynamism to the scene.", 50, 10.0], + ["A snowy forest landscape with a dirt road running through it. The road is flanked by trees covered in snow, and the ground is also covered in snow. The sun is shining, creating a bright and serene atmosphere. The road appears to be empty, and there are no people or animals visible in the video. The style of the video is a natural landscape shot, with a focus on the beauty of the snowy forest and the peacefulness of the road.", 50, 10.0], + ["The dynamic movement of tall, wispy grasses swaying in the wind. The sky above is filled with clouds, creating a dramatic backdrop. The sunlight pierces through the clouds, casting a warm glow on the scene. The grasses are a mix of green and brown, indicating a change in seasons. The overall style of the video is naturalistic, capturing the beauty of the landscape in a realistic manner. The focus is on the grasses and their movement, with the sky serving as a secondary element. The video does not contain any human or animal elements.", 50, 10.0], +] \ No newline at end of file diff --git a/opensora/serve/gradio_web_server.py b/opensora/serve/gradio_web_server.py new file mode 100644 index 0000000000000000000000000000000000000000..479b1140b4cf129f74af918528191d68c1d3246e --- /dev/null +++ b/opensora/serve/gradio_web_server.py @@ -0,0 +1,123 @@ + + +import argparse +import sys +import os +import random + +import imageio +import torch +from diffusers import PNDMScheduler +from huggingface_hub import hf_hub_download +from torchvision.utils import save_image +from diffusers.models import AutoencoderKL +from datetime import datetime +from typing import List, Union +import gradio as gr +import numpy as np +from gradio.components import Textbox, Video, Image +from transformers import T5Tokenizer, T5EncoderModel + +from opensora.models.ae import ae_stride_config, getae, getae_wrapper +from opensora.models.ae.videobase import CausalVQVAEModelWrapper, CausalVAEModelWrapper +from opensora.models.diffusion.latte.modeling_latte import LatteT2V +from opensora.sample.pipeline_videogen import VideoGenPipeline +from opensora.serve.gradio_utils import block_css, title_markdown, randomize_seed_fn, set_env, examples, DESCRIPTION + + +@torch.inference_mode() +def generate_img(prompt, sample_steps, scale, seed=0, randomize_seed=False, force_images=False): + seed = int(randomize_seed_fn(seed, randomize_seed)) + set_env(seed) + video_length = transformer_model.config.video_length if not force_images else 1 + height, width = int(args.version.split('x')[1]), int(args.version.split('x')[2]) + num_frames = 1 if video_length == 1 else int(args.version.split('x')[0]) + videos = videogen_pipeline(prompt, + video_length=video_length, + height=height, + width=width, + num_inference_steps=sample_steps, + guidance_scale=scale, + enable_temporal_attentions=not force_images, + num_images_per_prompt=1, + mask_feature=True, + ).video + + torch.cuda.empty_cache() + videos = videos[0] + tmp_save_path = 'tmp.mp4' + imageio.mimwrite(tmp_save_path, videos, fps=24, quality=9) # highest quality is 10, lowest is 0 + display_model_info = f"Video size: {num_frames}×{height}×{width}, \nSampling Step: {sample_steps}, \nGuidance Scale: {scale}" + return tmp_save_path, prompt, display_model_info, seed + +if __name__ == '__main__': + args = type('args', (), { + 'ae': 'CausalVAEModel_4x8x8', + 'force_images': False, + 'model_path': 'LanguageBind/Open-Sora-Plan-v1.0.0', + 'text_encoder_name': 'DeepFloyd/t5-v1_1-xxl', + 'version': '65x512x512' + }) + device = torch.device('cuda:0') + + # Load model: + transformer_model = LatteT2V.from_pretrained(args.model_path, subfolder=args.version, torch_dtype=torch.float16, cache_dir='cache_dir').to(device) + + vae = getae_wrapper(args.ae)(args.model_path, subfolder="vae", cache_dir='cache_dir').to(device, dtype=torch.float16) + vae.vae.enable_tiling() + image_size = int(args.version.split('x')[1]) + latent_size = (image_size // ae_stride_config[args.ae][1], image_size // ae_stride_config[args.ae][2]) + vae.latent_size = latent_size + transformer_model.force_images = args.force_images + tokenizer = T5Tokenizer.from_pretrained(args.text_encoder_name, cache_dir="cache_dir") + text_encoder = T5EncoderModel.from_pretrained(args.text_encoder_name, cache_dir="cache_dir", + torch_dtype=torch.float16).to(device) + + # set eval mode + transformer_model.eval() + vae.eval() + text_encoder.eval() + scheduler = PNDMScheduler() + videogen_pipeline = VideoGenPipeline(vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + scheduler=scheduler, + transformer=transformer_model).to(device=device) + + + demo = gr.Interface( + fn=generate_img, + inputs=[Textbox(label="", + placeholder="Please enter your prompt. \n"), + gr.Slider( + label='Sample Steps', + minimum=1, + maximum=500, + value=50, + step=10 + ), + gr.Slider( + label='Guidance Scale', + minimum=0.1, + maximum=30.0, + value=10.0, + step=0.1 + ), + gr.Slider( + label="Seed", + minimum=0, + maximum=203279, + step=1, + value=0, + ), + gr.Checkbox(label="Randomize seed", value=True), + gr.Checkbox(label="Generate image (1 frame video)", value=False), + ], + outputs=[Video(label="Vid", width=512, height=512), + Textbox(label="input prompt"), + Textbox(label="model info"), + gr.Slider(label='seed')], + title=title_markdown, description=DESCRIPTION, theme=gr.themes.Default(), css=block_css, + examples=examples, + ) + demo.launch() \ No newline at end of file diff --git a/opensora/train/train.py b/opensora/train/train.py new file mode 100644 index 0000000000000000000000000000000000000000..70bfb3faafafbc5498969075b6afd3f8878b293a --- /dev/null +++ b/opensora/train/train.py @@ -0,0 +1,775 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +""" +A minimal training script for DiT using PyTorch DDP. +""" +import argparse +import logging +import math +import os +import shutil +from pathlib import Path +from typing import Optional + +import numpy as np +from einops import rearrange +from tqdm import tqdm +from dataclasses import field, dataclass +from torch.utils.data import DataLoader +from copy import deepcopy + +import accelerate +import torch +from torch.nn import functional as F +import transformers +from accelerate import Accelerator +from accelerate.logging import get_logger +from accelerate.utils import ProjectConfiguration, set_seed +from huggingface_hub import create_repo +from packaging import version +from tqdm.auto import tqdm +from transformers import HfArgumentParser, TrainingArguments + +import diffusers +from diffusers import DDPMScheduler, PNDMScheduler +from diffusers.optimization import get_scheduler +from diffusers.training_utils import EMAModel, compute_snr +from diffusers.utils import check_min_version, is_wandb_available + +from opensora.dataset import getdataset, ae_denorm +from opensora.models.ae import getae, getae_wrapper +from opensora.models.ae.videobase import CausalVQVAEModelWrapper, CausalVAEModelWrapper +from opensora.models.diffusion.diffusion import create_diffusion_T as create_diffusion +from opensora.models.diffusion.latte.modeling_latte import Latte +from opensora.utils.dataset_utils import Collate +from opensora.models.ae import ae_stride_config, ae_channel_config +from opensora.models.diffusion import Diffusion_models + + + +# Will error if the minimal version of diffusers is not installed. Remove at your own risks. +check_min_version("0.24.0") +logger = get_logger(__name__) + + + + +def generate_timestep_weights(args, num_timesteps): + weights = torch.ones(num_timesteps) + + # Determine the indices to bias + num_to_bias = int(args.timestep_bias_portion * num_timesteps) + + if args.timestep_bias_strategy == "later": + bias_indices = slice(-num_to_bias, None) + elif args.timestep_bias_strategy == "earlier": + bias_indices = slice(0, num_to_bias) + elif args.timestep_bias_strategy == "range": + # Out of the possible 1000 timesteps, we might want to focus on eg. 200-500. + range_begin = args.timestep_bias_begin + range_end = args.timestep_bias_end + if range_begin < 0: + raise ValueError( + "When using the range strategy for timestep bias, you must provide a beginning timestep greater or equal to zero." + ) + if range_end > num_timesteps: + raise ValueError( + "When using the range strategy for timestep bias, you must provide an ending timestep smaller than the number of timesteps." + ) + bias_indices = slice(range_begin, range_end) + else: # 'none' or any other string + return weights + if args.timestep_bias_multiplier <= 0: + return ValueError( + "The parameter --timestep_bias_multiplier is not intended to be used to disable the training of specific timesteps." + " If it was intended to disable timestep bias, use `--timestep_bias_strategy none` instead." + " A timestep bias multiplier less than or equal to 0 is not allowed." + ) + + # Apply the bias + weights[bias_indices] *= args.timestep_bias_multiplier + + # Normalize + weights /= weights.sum() + + return weights + +################################################################################# +# Training Loop # +################################################################################# + +def main(args): + logging_dir = Path(args.output_dir, args.logging_dir) + + accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir) + + accelerator = Accelerator( + gradient_accumulation_steps=args.gradient_accumulation_steps, + mixed_precision=args.mixed_precision, + log_with=args.report_to, + project_config=accelerator_project_config, + ) + + if args.report_to == "wandb": + if not is_wandb_available(): + raise ImportError("Make sure to install wandb if you want to use it for logging during training.") + import wandb + + # Make one log on every process with the configuration for debugging. + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, + ) + logger.info(accelerator.state, main_process_only=False) + if accelerator.is_local_main_process: + transformers.utils.logging.set_verbosity_warning() + diffusers.utils.logging.set_verbosity_info() + else: + transformers.utils.logging.set_verbosity_error() + diffusers.utils.logging.set_verbosity_error() + + # If passed along, set the training seed now. + if args.seed is not None: + set_seed(args.seed) + + # Handle the repository creation + if accelerator.is_main_process: + if args.output_dir is not None: + os.makedirs(args.output_dir, exist_ok=True) + + # if args.push_to_hub: + # repo_id = create_repo( + # repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token + # ).repo_id + + + # Create model: + + diffusion = create_diffusion(timestep_respacing="") # default: 1000 steps, linear noise schedule + ae = getae_wrapper(args.ae)(args.ae_path).eval() + if args.enable_tiling: + ae.vae.enable_tiling() + ae.vae.tile_overlap_factor = args.tile_overlap_factor + + ae_stride_t, ae_stride_h, ae_stride_w = ae_stride_config[args.ae] + args.ae_stride_t, args.ae_stride_h, args.ae_stride_w = ae_stride_t, ae_stride_h, ae_stride_w + args.ae_stride = args.ae_stride_h + patch_size = args.model[-3:] + patch_size_t, patch_size_h, patch_size_w = int(patch_size[0]), int(patch_size[1]), int(patch_size[2]) + args.patch_size = patch_size_h + args.patch_size_t, args.patch_size_h, args.patch_size_w = patch_size_t, patch_size_h, patch_size_w + assert ae_stride_h == ae_stride_w, f"Support only ae_stride_h == ae_stride_w now, but found ae_stride_h ({ae_stride_h}), ae_stride_w ({ae_stride_w})" + assert patch_size_h == patch_size_w, f"Support only patch_size_h == patch_size_w now, but found patch_size_h ({patch_size_h}), patch_size_w ({patch_size_w})" + # assert args.num_frames % ae_stride_t == 0, f"Num_frames must be divisible by ae_stride_t, but found num_frames ({args.num_frames}), ae_stride_t ({ae_stride_t})." + assert args.max_image_size % ae_stride_h == 0, f"Image size must be divisible by ae_stride_h, but found max_image_size ({args.max_image_size}), ae_stride_h ({ae_stride_h})." + + latent_size = (args.max_image_size // ae_stride_h, args.max_image_size // ae_stride_w) + + if getae_wrapper(args.ae) == CausalVQVAEModelWrapper or getae_wrapper(args.ae) == CausalVAEModelWrapper: + args.video_length = video_length = args.num_frames // ae_stride_t + 1 + else: + args.video_length = video_length = args.num_frames // ae_stride_t + model = Diffusion_models[args.model]( + in_channels=ae_channel_config[args.ae], + out_channels=ae_channel_config[args.ae] * 2, + caption_channels=None, # unconditon + cross_attention_dim=None, # unconditon + attention_bias=True, + sample_size=latent_size, + num_vector_embeds=None, + activation_fn="gelu-approximate", + num_embeds_ada_norm=args.num_classes if args.train_classcondition else 1000, + use_linear_projection=False, + only_cross_attention=False, + double_self_attention=False, + upcast_attention=False, + # norm_type="ada_norm_single", + norm_elementwise_affine=False, + norm_eps=1e-6, + attention_type='default', + video_length=video_length, + attention_mode=args.attention_mode, + # compress_kv=args.compress_kv + ) + model.gradient_checkpointing = args.gradient_checkpointing + + # # use pretrained model? + if args.pretrained: + checkpoint = torch.load(args.pretrained, map_location='cpu')['model'] + model_state_dict = model.state_dict() + missing_keys, unexpected_keys = model.load_state_dict(checkpoint, strict=False) + logger.info(f'missing_keys {len(missing_keys)}, unexpected_keys {len(unexpected_keys)}') + logger.info(f'Successfully load {len(model.state_dict()) - len(missing_keys)}/{len(model_state_dict)} keys from {args.pretrained}!') + ''' + pretrained_state_dict = torch.load(args.pretrained, map_location='cpu')['model'] + model_state_dict = model.state_dict() + load_state_dict = {k: v for k, v in model_state_dict.items() if pretrained_state_dict.get(k, None) is not None and v.numel() == pretrained_state_dict[k].numel()} + missing_keys, unexpected_keys = model.load_state_dict(load_state_dict, strict=False) + logger.info(f'missing_keys {missing_keys}, unexpected_keys {unexpected_keys}') + logger.info(f'Successfully load {len(model_state_dict) - len(missing_keys)}/{len(model_state_dict)} keys from {args.pretrained}!') + ''' + + # Freeze vae and text encoders. + ae.requires_grad_(False) + # Set model as trainable. + model.train() + + # For mixed precision training we cast all non-trainable weigths to half-precision + # as these weights are only used for inference, keeping weights in full precision is not required. + weight_dtype = torch.float32 + if accelerator.mixed_precision == "fp16": + weight_dtype = torch.float16 + elif accelerator.mixed_precision == "bf16": + weight_dtype = torch.bfloat16 + + # Move unet, vae and text_encoder to device and cast to weight_dtype + # The VAE is in float32 to avoid NaN losses. + ae.to(accelerator.device, dtype=torch.float32) + + # Create EMA for the unet. + if args.use_ema: + ema_model = deepcopy(model) + ema_model = EMAModel(ema_model.parameters(), model_cls=Latte, model_config=ema_model.config) + + # `accelerate` 0.16.0 will have better support for customized saving + if version.parse(accelerate.__version__) >= version.parse("0.16.0"): + # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format + def save_model_hook(models, weights, output_dir): + if accelerator.is_main_process: + if args.use_ema: + ema_model.save_pretrained(os.path.join(output_dir, "model_ema")) + + for i, model in enumerate(models): + model.save_pretrained(os.path.join(output_dir, "model")) + + if weights: + # make sure to pop weight so that corresponding model is not saved again + weights.pop() + + def load_model_hook(models, input_dir): + if args.use_ema: + load_model = EMAModel.from_pretrained(os.path.join(input_dir, "model_ema"), Latte) + ema_model.load_state_dict(load_model.state_dict()) + ema_model.to(accelerator.device) + del load_model + + for i in range(len(models)): + # pop models so that they are not loaded again + model = models.pop() + + # load diffusers style into model + load_model = Latte.from_pretrained(input_dir, subfolder="model") + model.register_to_config(**load_model.config) + + model.load_state_dict(load_model.state_dict()) + del load_model + + accelerator.register_save_state_pre_hook(save_model_hook) + accelerator.register_load_state_pre_hook(load_model_hook) + + # Enable TF32 for faster training on Ampere GPUs, + # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices + if args.allow_tf32: + torch.backends.cuda.matmul.allow_tf32 = True + + if args.scale_lr: + args.learning_rate = ( + args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes + ) + + # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs + if args.use_8bit_adam: + try: + import bitsandbytes as bnb + except ImportError: + raise ImportError( + "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`." + ) + + optimizer_class = bnb.optim.AdamW8bit + else: + optimizer_class = torch.optim.AdamW + + # Optimizer creation + params_to_optimize = model.parameters() + optimizer = optimizer_class( + params_to_optimize, + lr=args.learning_rate, + betas=(args.adam_beta1, args.adam_beta2), + weight_decay=args.adam_weight_decay, + eps=args.adam_epsilon, + ) + + # Setup data: + train_dataset = getdataset(args) + train_dataloader = torch.utils.data.DataLoader( + train_dataset, + shuffle=True, + # collate_fn=Collate(args), # TODO: do not enable dynamic mask in this point + batch_size=args.train_batch_size, + num_workers=args.dataloader_num_workers, + ) + + + # Scheduler and math around the number of training steps. + overrode_max_train_steps = False + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if args.max_train_steps is None: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + overrode_max_train_steps = True + + lr_scheduler = get_scheduler( + args.lr_scheduler, + optimizer=optimizer, + num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps, + num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, + ) + + # Prepare everything with our `accelerator`. + model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + model, optimizer, train_dataloader, lr_scheduler + ) + + # We need to recalculate our total training steps as the size of the training dataloader may have changed. + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if overrode_max_train_steps: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + # Afterwards we recalculate our number of training epochs + args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + + # We need to initialize the trackers we use, and also store our configuration. + # The trackers initializes automatically on the main process. + if accelerator.is_main_process: + accelerator.init_trackers(args.output_dir, config=vars(args)) + + # Train! + total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps + + logger.info("***** Running training *****") + logger.info(f" Num examples = {len(train_dataset)}") + logger.info(f" Num Epochs = {args.num_train_epochs}") + logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") + logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") + logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") + logger.info(f" Total optimization steps = {args.max_train_steps}") + global_step = 0 + first_epoch = 0 + + # Potentially load in the weights and states from a previous save + if args.resume_from_checkpoint: + if args.resume_from_checkpoint != "latest": + path = os.path.basename(args.resume_from_checkpoint) + else: + # Get the most recent checkpoint + dirs = os.listdir(args.output_dir) + dirs = [d for d in dirs if d.startswith("checkpoint")] + dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) + path = dirs[-1] if len(dirs) > 0 else None + + if path is None: + accelerator.print( + f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run." + ) + args.resume_from_checkpoint = None + initial_global_step = 0 + else: + accelerator.print(f"Resuming from checkpoint {path}") + accelerator.load_state(os.path.join(args.output_dir, path)) + global_step = int(path.split("-")[1]) + + initial_global_step = global_step + first_epoch = global_step // num_update_steps_per_epoch + + else: + initial_global_step = 0 + + progress_bar = tqdm( + range(0, args.max_train_steps), + initial=initial_global_step, + desc="Steps", + # Only show the progress bar once on each machine. + disable=not accelerator.is_local_main_process, + ) + + for epoch in range(first_epoch, args.num_train_epochs): + train_loss = 0.0 + for step, (x, y) in enumerate(train_dataloader): + with accelerator.accumulate(model): + # Sample noise that we'll add to the latents + x = x.to(accelerator.device) # B C T H W + if args.train_classcondition: + y = y.to(accelerator.device) + # attn_mask = attn_mask.to(device) # B T H W + # assert torch.all(attn_mask.bool()), 'do not enable dynamic input' + attn_mask = None + + model_kwargs = dict(class_labels=y if args.train_classcondition else None, + attention_mask=attn_mask, + use_image_num=args.use_image_num) + with torch.no_grad(): + # Map input images to latent space + normalize latents + if args.use_image_num == 0: + x = ae.encode(x) # B C T H W + else: + videos, images = x[:, :, :-args.use_image_num], x[:, :, -args.use_image_num:] + videos = ae.encode(videos) # B C T H W + images = rearrange(images, 'b c t h w -> (b t) c 1 h w') + images = ae.encode(images) + images = rearrange(images, '(b t) c 1 h w -> b c t h w', t=args.use_image_num) + x = torch.cat([videos, images], dim=2) + t = torch.randint(0, diffusion.num_timesteps, (x.shape[0],), device=accelerator.device) + loss_dict = diffusion.training_losses(model, x, t, model_kwargs) + loss = loss_dict["loss"].mean() + + # Gather the losses across all processes for logging (if we use distributed training). + avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean() + train_loss += avg_loss.item() / args.gradient_accumulation_steps + + # Backpropagate + accelerator.backward(loss) + if accelerator.sync_gradients: + params_to_clip = model.parameters() + accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad() + + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + progress_bar.update(1) + global_step += 1 + accelerator.log({"train_loss": train_loss}, step=global_step) + train_loss = 0.0 + + if args.use_deepspeed or accelerator.is_main_process: + if global_step % args.checkpointing_steps == 0: + # _before_ saving state, check if this save would set us over the `checkpoints_total_limit` + if args.checkpoints_total_limit is not None: + checkpoints = os.listdir(args.output_dir) + checkpoints = [d for d in checkpoints if d.startswith("checkpoint")] + checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1])) + + # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints + if len(checkpoints) >= args.checkpoints_total_limit: + num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1 + removing_checkpoints = checkpoints[0:num_to_remove] + + logger.info( + f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints" + ) + logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}") + + for removing_checkpoint in removing_checkpoints: + removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint) + shutil.rmtree(removing_checkpoint) + + save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") + accelerator.save_state(save_path) + logger.info(f"Saved state to {save_path}") + + logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} + progress_bar.set_postfix(**logs) + + if global_step >= args.max_train_steps: + break + + if accelerator.is_main_process: + if global_step % args.checkpointing_steps == 0: + logger.info(f"Running validation... \n" + f"Generating {args.num_validation_videos} videos") + if args.use_ema: + # Store the UNet parameters temporarily and load the EMA parameters to perform inference. + ema_model.store(model.parameters()) + ema_model.copy_to(model.parameters()) + if args.enable_tracker: + with torch.no_grad(): + # create pipeline + ae_ = getae_wrapper(args.ae)(args.ae_path).to(accelerator.device).eval() + model_ = Latte.from_pretrained(save_path, subfolder="model").to(accelerator.device).eval() + diffusion_ = create_diffusion(str(500)) + videos = [] + ys = [] + for _ in range(args.num_validation_videos): + with torch.autocast(device_type='cuda', dtype=weight_dtype): + z = torch.randn(1, model_.in_channels, video_length, + latent_size[0], latent_size[1], device=accelerator.device) + if args.train_classcondition: + y = torch.randint(0, args.num_classes, (1,), device=accelerator.device) + ys.append(y.detach().cpu.items()) + sample_fn = model_.forward + model_kwargs = dict(class_labels=y if args.train_classcondition else None, attention_mask=None) + # Sample images: + samples = diffusion_.p_sample_loop( + sample_fn, z.shape, z, clip_denoised=False, model_kwargs=model_kwargs, progress=True, + device=accelerator.device + ) + samples = ae_.decode(samples) + # Save and display images: + video = (ae_denorm[args.ae](samples[0]) * 255).add_(0.5).clamp_(0, 255).to( + dtype=torch.uint8).cpu().contiguous() # t c h w + videos.append(video) + + videos = torch.stack(videos).numpy() + for tracker in accelerator.trackers: + if tracker.name == "tensorboard": + np_videos = np.stack([np.asarray(vid) for vid in videos]) + tracker.writer.add_video("validation", np_videos, global_step, fps=10) + if tracker.name == "wandb": + tracker.log( + { + "validation": [ + wandb.Video(video, caption=f"{i}: {str(ys[i])}" if args.train_classcondition else f"{i}", fps=10) + for i, video in enumerate(videos) + ] + } + ) + + del ae_, model_, diffusion_ + torch.cuda.empty_cache() + + accelerator.wait_for_everyone() + accelerator.end_training() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--dataset", type=str, required=True) + parser.add_argument("--data_path", type=str, required=True) + parser.add_argument("--model", type=str, choices=list(Diffusion_models.keys()), default="DiT-XL/122") + parser.add_argument("--num_classes", type=int, default=1000) + parser.add_argument("--ae", type=str, default="stabilityai/sd-vae-ft-mse") + parser.add_argument("--sample_rate", type=int, default=4) + parser.add_argument("--num_frames", type=int, default=16) + parser.add_argument("--max_image_size", type=int, default=128) + parser.add_argument("--dynamic_frames", action="store_true") + parser.add_argument("--compress_kv", action="store_true") + parser.add_argument("--attention_mode", type=str, choices=['xformers', 'math', 'flash'], default="math") + + + parser.add_argument('--tile_overlap_factor', type=float, default=0.25) + parser.add_argument('--enable_tiling', action='store_true') + + parser.add_argument("--pretrained", type=str, default=None) + parser.add_argument("--train_classcondition", action="store_true") + + parser.add_argument("--enable_tracker", action="store_true") + parser.add_argument("--use_image_num", type=int, default=0) + parser.add_argument("--use_img_from_vid", action="store_true") + parser.add_argument("--use_deepspeed", action="store_true") + parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") + parser.add_argument( + "--num_validation_videos", + type=int, + default=2, + help="Number of images that should be generated during validation with `validation_prompt`.", + ) + parser.add_argument( + "--output_dir", + type=str, + default=None, + help="The output directory where the model predictions and checkpoints will be written.", + ) + parser.add_argument( + "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader." + ) + parser.add_argument("--num_train_epochs", type=int, default=100) + parser.add_argument( + "--max_train_steps", + type=int, + default=None, + help="Total number of training steps to perform. If provided, overrides num_train_epochs.", + ) + parser.add_argument( + "--checkpointing_steps", + type=int, + default=500, + help=( + "Save a checkpoint of the training state every X updates. These checkpoints can be used both as final" + " checkpoints in case they are better than the last checkpoint, and are also suitable for resuming" + " training using `--resume_from_checkpoint`." + ), + ) + parser.add_argument( + "--checkpoints_total_limit", + type=int, + default=None, + help=("Max number of checkpoints to store."), + ) + parser.add_argument( + "--resume_from_checkpoint", + type=str, + default=None, + help=( + "Whether training should be resumed from a previous checkpoint. Use a path saved by" + ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.' + ), + ) + parser.add_argument( + "--gradient_accumulation_steps", + type=int, + default=1, + help="Number of updates steps to accumulate before performing a backward/update pass.", + ) + parser.add_argument( + "--gradient_checkpointing", + action="store_true", + help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", + ) + parser.add_argument( + "--learning_rate", + type=float, + default=1e-4, + help="Initial learning rate (after the potential warmup period) to use.", + ) + parser.add_argument( + "--scale_lr", + action="store_true", + default=False, + help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", + ) + parser.add_argument( + "--lr_scheduler", + type=str, + default="constant", + help=( + 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' + ' "constant", "constant_with_warmup"]' + ), + ) + parser.add_argument( + "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler." + ) + parser.add_argument( + "--timestep_bias_strategy", + type=str, + default="none", + choices=["earlier", "later", "range", "none"], + help=( + "The timestep bias strategy, which may help direct the model toward learning low or high frequency details." + " Choices: ['earlier', 'later', 'range', 'none']." + " The default is 'none', which means no bias is applied, and training proceeds normally." + " The value of 'later' will increase the frequency of the model's final training timesteps." + ), + ) + parser.add_argument( + "--timestep_bias_multiplier", + type=float, + default=1.0, + help=( + "The multiplier for the bias. Defaults to 1.0, which means no bias is applied." + " A value of 2.0 will double the weight of the bias, and a value of 0.5 will halve it." + ), + ) + parser.add_argument( + "--timestep_bias_begin", + type=int, + default=0, + help=( + "When using `--timestep_bias_strategy=range`, the beginning (inclusive) timestep to bias." + " Defaults to zero, which equates to having no specific bias." + ), + ) + parser.add_argument( + "--timestep_bias_end", + type=int, + default=1000, + help=( + "When using `--timestep_bias_strategy=range`, the final timestep (inclusive) to bias." + " Defaults to 1000, which is the number of timesteps that Stable Diffusion is trained on." + ), + ) + parser.add_argument( + "--timestep_bias_portion", + type=float, + default=0.25, + help=( + "The portion of timesteps to bias. Defaults to 0.25, which 25% of timesteps will be biased." + " A value of 0.5 will bias one half of the timesteps. The value provided for `--timestep_bias_strategy` determines" + " whether the biased portions are in the earlier or later timesteps." + ), + ) + parser.add_argument( + "--snr_gamma", + type=float, + default=None, + help="SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. " + "More details here: https://arxiv.org/abs/2303.09556.", + ) + parser.add_argument("--use_ema", action="store_true", help="Whether to use EMA model.") + parser.add_argument( + "--allow_tf32", + action="store_true", + help=( + "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see" + " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices" + ), + ) + parser.add_argument( + "--dataloader_num_workers", + type=int, + default=10, + help=( + "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process." + ), + ) + parser.add_argument( + "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes." + ) + parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.") + parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.") + parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.") + parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer") + parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") + parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") + parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.") + parser.add_argument( + "--prediction_type", + type=str, + default=None, + help="The prediction_type that shall be used for training. Choose between 'epsilon' or 'v_prediction' or leave `None`. If left to `None` the default prediction type of the scheduler: `noise_scheduler.config.prediciton_type` is chosen.", + ) + parser.add_argument( + "--hub_model_id", + type=str, + default=None, + help="The name of the repository to keep in sync with the local `output_dir`.", + ) + parser.add_argument( + "--logging_dir", + type=str, + default="logs", + help=( + "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" + " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***." + ), + ) + parser.add_argument( + "--report_to", + type=str, + default="tensorboard", + help=( + 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`' + ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.' + ), + ) + parser.add_argument( + "--mixed_precision", + type=str, + default=None, + choices=["no", "fp16", "bf16"], + help=( + "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" + " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the" + " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config." + ), + ) + parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") + parser.add_argument("--noise_offset", type=float, default=0, help="The scale of noise offset.") + + args = parser.parse_args() + main(args) \ No newline at end of file diff --git a/opensora/train/train_causalvae.py b/opensora/train/train_causalvae.py new file mode 100644 index 0000000000000000000000000000000000000000..731e2cfa23f050cc3cd410acdb824cb41056931b --- /dev/null +++ b/opensora/train/train_causalvae.py @@ -0,0 +1,104 @@ +import sys +sys.path.append(".") +import torch +import random +import numpy as np +from opensora.models.ae.videobase import ( + CausalVAEModel, +) +from torch.utils.data import DataLoader +from opensora.models.ae.videobase.dataset_videobase import VideoDataset +import argparse +from transformers import HfArgumentParser +from dataclasses import dataclass, field, asdict +import torch.distributed as dist +import os +import pytorch_lightning as pl +from pytorch_lightning.loggers import WandbLogger +from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor + +@dataclass +class TrainingArguments: + exp_name: str = field(default="causalvae") + batch_size: int = field(default=1) + precision: str = field(default="bf16") + max_steps: int = field(default=100000) + save_steps: int = field(default=2000) + output_dir: str = field(default="results/causalvae") + video_path: str = field(default="/remote-home1/dataset/data_split_tt") + video_num_frames: int = field(default=17) + sample_rate: int = field(default=1) + dynamic_sample: bool = field(default=False) + model_config: str = field(default="scripts/causalvae/288.yaml") + n_nodes: int = field(default=1) + devices: int = field(default=8) + resolution: int = field(default=64) + num_workers: int = field(default=8) + resume_from_checkpoint: str = field(default=None) + +def set_seed(seed=1006): + torch.manual_seed(seed) + np.random.seed(seed) + random.seed(seed) + +def load_callbacks_and_logger(args): + checkpoint_callback = ModelCheckpoint( + dirpath=args.output_dir, + filename="model-{epoch:02d}-{step}", + every_n_train_steps=args.save_steps, + save_top_k=-1, + save_on_train_epoch_end=False, + ) + lr_monitor = LearningRateMonitor(logging_interval="step") + logger = WandbLogger(name=args.exp_name, log_model=False) + return [checkpoint_callback, lr_monitor], logger + +def train(args): + set_seed() + # Load Config + model = CausalVAEModel() + if args.resume_from_checkpoint is not None: + model = CausalVAEModel.from_pretrained(args.resume_from_checkpoint) + else: + model = CausalVAEModel.from_config(args.model_config) + + if (dist.is_initialized() and dist.get_rank() == 0) or not dist.is_initialized(): + print(model) + + # Load Dataset + dataset = VideoDataset(args.video_path, sequence_length=args.video_num_frames, resolution=args.resolution, sample_rate=args.sample_rate, dynamic_sample=args.dynamic_sample) + train_loader = DataLoader( + dataset, + shuffle=True, + num_workers=args.num_workers, + batch_size=args.batch_size, + pin_memory=True, + ) + # Load Callbacks and Logger + callbacks, logger = load_callbacks_and_logger(args) + # Load Trainer + trainer = pl.Trainer( + accelerator="cuda", + devices=args.devices, + num_nodes=args.n_nodes, + callbacks=callbacks, + logger=logger, + log_every_n_steps=5, + precision=args.precision, + max_steps=args.max_steps, + strategy="ddp_find_unused_parameters_true" + ) + trainer_kwargs = {} + if args.resume_from_checkpoint: + trainer_kwargs['ckpt_path'] = args.resume_from_checkpoint + + trainer.fit( + model, + train_loader, + **trainer_kwargs + ) + +if __name__ == "__main__": + parser = HfArgumentParser(TrainingArguments) + args = parser.parse_args_into_dataclasses() + train(args[0]) diff --git a/opensora/train/train_t2v.py b/opensora/train/train_t2v.py new file mode 100644 index 0000000000000000000000000000000000000000..29fbfd7b2d0c13ae8e44c23e5b40cb0ca8abdde3 --- /dev/null +++ b/opensora/train/train_t2v.py @@ -0,0 +1,803 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +""" +A minimal training script for DiT using PyTorch DDP. +""" +import argparse +import logging +import math +import os +import shutil +from pathlib import Path +from typing import Optional + +import numpy as np +from einops import rearrange +from tqdm import tqdm +from dataclasses import field, dataclass +from torch.utils.data import DataLoader +from copy import deepcopy + +import accelerate +import torch +from torch.nn import functional as F +import transformers +from accelerate import Accelerator +from accelerate.logging import get_logger +from accelerate.utils import ProjectConfiguration, set_seed +from huggingface_hub import create_repo +from packaging import version +from tqdm.auto import tqdm +from transformers import HfArgumentParser, TrainingArguments, AutoTokenizer + +import diffusers +from diffusers import DDPMScheduler, PNDMScheduler +from diffusers.optimization import get_scheduler +from diffusers.training_utils import EMAModel, compute_snr +from diffusers.utils import check_min_version, is_wandb_available + +from opensora.dataset import getdataset, ae_denorm +from opensora.models.ae import getae, getae_wrapper +from opensora.models.ae.videobase import CausalVQVAEModelWrapper, CausalVAEModelWrapper +from opensora.models.diffusion.diffusion import create_diffusion_T as create_diffusion +from opensora.models.diffusion.latte.modeling_latte import LatteT2V +from opensora.models.text_encoder import get_text_enc +from opensora.utils.dataset_utils import Collate +from opensora.models.ae import ae_stride_config, ae_channel_config +from opensora.models.diffusion import Diffusion_models + +# Will error if the minimal version of diffusers is not installed. Remove at your own risks. +check_min_version("0.24.0") +logger = get_logger(__name__) + + +def generate_timestep_weights(args, num_timesteps): + weights = torch.ones(num_timesteps) + + # Determine the indices to bias + num_to_bias = int(args.timestep_bias_portion * num_timesteps) + + if args.timestep_bias_strategy == "later": + bias_indices = slice(-num_to_bias, None) + elif args.timestep_bias_strategy == "earlier": + bias_indices = slice(0, num_to_bias) + elif args.timestep_bias_strategy == "range": + # Out of the possible 1000 timesteps, we might want to focus on eg. 200-500. + range_begin = args.timestep_bias_begin + range_end = args.timestep_bias_end + if range_begin < 0: + raise ValueError( + "When using the range strategy for timestep bias, you must provide a beginning timestep greater or equal to zero." + ) + if range_end > num_timesteps: + raise ValueError( + "When using the range strategy for timestep bias, you must provide an ending timestep smaller than the number of timesteps." + ) + bias_indices = slice(range_begin, range_end) + else: # 'none' or any other string + return weights + if args.timestep_bias_multiplier <= 0: + return ValueError( + "The parameter --timestep_bias_multiplier is not intended to be used to disable the training of specific timesteps." + " If it was intended to disable timestep bias, use `--timestep_bias_strategy none` instead." + " A timestep bias multiplier less than or equal to 0 is not allowed." + ) + + # Apply the bias + weights[bias_indices] *= args.timestep_bias_multiplier + + # Normalize + weights /= weights.sum() + + return weights + + +################################################################################# +# Training Loop # +################################################################################# + +def main(args): + logging_dir = Path(args.output_dir, args.logging_dir) + + accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir) + + accelerator = Accelerator( + gradient_accumulation_steps=args.gradient_accumulation_steps, + mixed_precision=args.mixed_precision, + log_with=args.report_to, + project_config=accelerator_project_config, + ) + + if args.report_to == "wandb": + if not is_wandb_available(): + raise ImportError("Make sure to install wandb if you want to use it for logging during training.") + import wandb + + # Make one log on every process with the configuration for debugging. + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, + ) + logger.info(accelerator.state, main_process_only=False) + if accelerator.is_local_main_process: + transformers.utils.logging.set_verbosity_warning() + diffusers.utils.logging.set_verbosity_info() + else: + transformers.utils.logging.set_verbosity_error() + diffusers.utils.logging.set_verbosity_error() + + # If passed along, set the training seed now. + if args.seed is not None: + set_seed(args.seed) + + # Handle the repository creation + if accelerator.is_main_process: + if args.output_dir is not None: + os.makedirs(args.output_dir, exist_ok=True) + + # if args.push_to_hub: + # repo_id = create_repo( + # repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token + # ).repo_id + + # Create model: + + diffusion = create_diffusion(timestep_respacing="") # default: 1000 steps, linear noise schedule + ae = getae_wrapper(args.ae)(args.ae_path).eval() + if args.enable_tiling: + ae.vae.enable_tiling() + ae.vae.tile_overlap_factor = args.tile_overlap_factor + text_enc = get_text_enc(args).eval() + + ae_stride_t, ae_stride_h, ae_stride_w = ae_stride_config[args.ae] + args.ae_stride_t, args.ae_stride_h, args.ae_stride_w = ae_stride_t, ae_stride_h, ae_stride_w + args.ae_stride = args.ae_stride_h + patch_size = args.model[-3:] + patch_size_t, patch_size_h, patch_size_w = int(patch_size[0]), int(patch_size[1]), int(patch_size[2]) + args.patch_size = patch_size_h + args.patch_size_t, args.patch_size_h, args.patch_size_w = patch_size_t, patch_size_h, patch_size_w + assert ae_stride_h == ae_stride_w, f"Support only ae_stride_h == ae_stride_w now, but found ae_stride_h ({ae_stride_h}), ae_stride_w ({ae_stride_w})" + assert patch_size_h == patch_size_w, f"Support only patch_size_h == patch_size_w now, but found patch_size_h ({patch_size_h}), patch_size_w ({patch_size_w})" + # assert args.num_frames % ae_stride_t == 0, f"Num_frames must be divisible by ae_stride_t, but found num_frames ({args.num_frames}), ae_stride_t ({ae_stride_t})." + assert args.max_image_size % ae_stride_h == 0, f"Image size must be divisible by ae_stride_h, but found max_image_size ({args.max_image_size}), ae_stride_h ({ae_stride_h})." + + latent_size = (args.max_image_size // ae_stride_h, args.max_image_size // ae_stride_w) + + if getae_wrapper(args.ae) == CausalVQVAEModelWrapper or getae_wrapper(args.ae) == CausalVAEModelWrapper: + args.video_length = video_length = args.num_frames // ae_stride_t + 1 + else: + video_length = args.num_frames // ae_stride_t + model = Diffusion_models[args.model]( + in_channels=ae_channel_config[args.ae], + out_channels=ae_channel_config[args.ae] * 2, + # caption_channels=4096, + # cross_attention_dim=1152, + attention_bias=True, + sample_size=latent_size, + num_vector_embeds=None, + activation_fn="gelu-approximate", + num_embeds_ada_norm=1000, + use_linear_projection=False, + only_cross_attention=False, + double_self_attention=False, + upcast_attention=False, + # norm_type="ada_norm_single", + norm_elementwise_affine=False, + norm_eps=1e-6, + attention_type='default', + video_length=video_length, + attention_mode=args.attention_mode, + # compress_kv=args.compress_kv + ) + model.gradient_checkpointing = args.gradient_checkpointing + + # # use pretrained model? + if args.pretrained: + checkpoint = torch.load(args.pretrained, map_location='cpu')['model'] + model_state_dict = model.state_dict() + missing_keys, unexpected_keys = model.load_state_dict(checkpoint, strict=False) + logger.info(f'missing_keys {len(missing_keys)}, unexpected_keys {len(unexpected_keys)}') + logger.info(f'Successfully load {len(model.state_dict()) - len(missing_keys)}/{len(model_state_dict)} keys from {args.pretrained}!') + # load from pixart-alpha + # pixelart_alpha = torch.load(args.pretrained, map_location='cpu')['state_dict'] + # checkpoint = {} + # for k, v in pixelart_alpha.items(): + # if 'x_embedder' in k or 't_embedder' in k or 'y_embedder' in k: + # checkpoint[k] = v + # if k.startswith('blocks'): + # k_spilt = k.split('.') + # blk_id = str(int(k_spilt[1]) * 2) + # k_spilt[1] = blk_id + # new_k = '.'.join(k_spilt) + # checkpoint[new_k] = v + # missing_keys, unexpected_keys = model.load_state_dict(checkpoint, strict=False) + # logger.info(f'Successfully load {len(model.state_dict()) - len(missing_keys)} keys from {args.pretrained}!') + + # Freeze vae and text encoders. + ae.requires_grad_(False) + text_enc.requires_grad_(False) + # Set model as trainable. + model.train() + + # For mixed precision training we cast all non-trainable weigths to half-precision + # as these weights are only used for inference, keeping weights in full precision is not required. + weight_dtype = torch.float32 + if accelerator.mixed_precision == "fp16": + weight_dtype = torch.float16 + elif accelerator.mixed_precision == "bf16": + weight_dtype = torch.bfloat16 + + # Move unet, vae and text_encoder to device and cast to weight_dtype + # The VAE is in float32 to avoid NaN losses. + ae.to(accelerator.device, dtype=torch.float32) + text_enc.to(accelerator.device, dtype=weight_dtype) + + # Create EMA for the unet. + if args.use_ema: + ema_model = deepcopy(model) + ema_model = EMAModel(ema_model.parameters(), model_cls=LatteT2V, model_config=ema_model.config) + + # `accelerate` 0.16.0 will have better support for customized saving + if version.parse(accelerate.__version__) >= version.parse("0.16.0"): + # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format + def save_model_hook(models, weights, output_dir): + if accelerator.is_main_process: + if args.use_ema: + ema_model.save_pretrained(os.path.join(output_dir, "model_ema")) + + for i, model in enumerate(models): + model.save_pretrained(os.path.join(output_dir, "model")) + if weights: # Don't pop if empty + # make sure to pop weight so that corresponding model is not saved again + weights.pop() + + def load_model_hook(models, input_dir): + if args.use_ema: + load_model = EMAModel.from_pretrained(os.path.join(input_dir, "model_ema"), LatteT2V) + ema_model.load_state_dict(load_model.state_dict()) + ema_model.to(accelerator.device) + del load_model + + for i in range(len(models)): + # pop models so that they are not loaded again + model = models.pop() + + # load diffusers style into model + load_model = LatteT2V.from_pretrained(input_dir, subfolder="model") + model.register_to_config(**load_model.config) + + model.load_state_dict(load_model.state_dict()) + del load_model + + accelerator.register_save_state_pre_hook(save_model_hook) + accelerator.register_load_state_pre_hook(load_model_hook) + + # Enable TF32 for faster training on Ampere GPUs, + # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices + if args.allow_tf32: + torch.backends.cuda.matmul.allow_tf32 = True + + if args.scale_lr: + args.learning_rate = ( + args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes + ) + + # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs + if args.use_8bit_adam: + try: + import bitsandbytes as bnb + except ImportError: + raise ImportError( + "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`." + ) + + optimizer_class = bnb.optim.AdamW8bit + else: + optimizer_class = torch.optim.AdamW + + # Optimizer creation + params_to_optimize = model.parameters() + optimizer = optimizer_class( + params_to_optimize, + lr=args.learning_rate, + betas=(args.adam_beta1, args.adam_beta2), + weight_decay=args.adam_weight_decay, + eps=args.adam_epsilon, + ) + + # Setup data: + train_dataset = getdataset(args) + train_dataloader = torch.utils.data.DataLoader( + train_dataset, + shuffle=True, + # collate_fn=Collate(args), # TODO: do not enable dynamic mask in this point + batch_size=args.train_batch_size, + num_workers=args.dataloader_num_workers, + ) + + # Scheduler and math around the number of training steps. + overrode_max_train_steps = False + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if args.max_train_steps is None: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + overrode_max_train_steps = True + + lr_scheduler = get_scheduler( + args.lr_scheduler, + optimizer=optimizer, + num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps, + num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, + ) + + # Prepare everything with our `accelerator`. + model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + model, optimizer, train_dataloader, lr_scheduler + ) + + # We need to recalculate our total training steps as the size of the training dataloader may have changed. + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if overrode_max_train_steps: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + # Afterwards we recalculate our number of training epochs + args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + + # We need to initialize the trackers we use, and also store our configuration. + # The trackers initializes automatically on the main process. + if accelerator.is_main_process: + accelerator.init_trackers(args.output_dir, config=vars(args)) + + # Train! + total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps + + logger.info("***** Running training *****") + logger.info(f" Num examples = {len(train_dataset)}") + logger.info(f" Num Epochs = {args.num_train_epochs}") + logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") + logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") + logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") + logger.info(f" Total optimization steps = {args.max_train_steps}") + global_step = 0 + first_epoch = 0 + + # Potentially load in the weights and states from a previous save + if args.resume_from_checkpoint: + if args.resume_from_checkpoint != "latest": + path = os.path.basename(args.resume_from_checkpoint) + else: + # Get the most recent checkpoint + dirs = os.listdir(args.output_dir) + dirs = [d for d in dirs if d.startswith("checkpoint")] + dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) + path = dirs[-1] if len(dirs) > 0 else None + + if path is None: + accelerator.print( + f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run." + ) + args.resume_from_checkpoint = None + initial_global_step = 0 + else: + accelerator.print(f"Resuming from checkpoint {path}") + accelerator.load_state(os.path.join(args.output_dir, path)) + global_step = int(path.split("-")[1]) + + initial_global_step = global_step + first_epoch = global_step // num_update_steps_per_epoch + + else: + initial_global_step = 0 + + progress_bar = tqdm( + range(0, args.max_train_steps), + initial=initial_global_step, + desc="Steps", + # Only show the progress bar once on each machine. + disable=not accelerator.is_local_main_process, + ) + + for epoch in range(first_epoch, args.num_train_epochs): + train_loss = 0.0 + for step, (x, input_ids, cond_mask) in enumerate(train_dataloader): + with accelerator.accumulate(model): + # Sample noise that we'll add to the latents + x = x.to(accelerator.device) # B C T+num_images H W, 16 + 4 + # attn_mask = attn_mask.to(device) # B T H W + # assert torch.all(attn_mask.bool()), 'do not enable dynamic input' + attn_mask = None + input_ids = input_ids.to(accelerator.device) # B L or B 1+num_images L + cond_mask = cond_mask.to(accelerator.device) # B L or B 1+num_images L + + with torch.no_grad(): + # Map input images to latent space + normalize latents + if args.use_image_num == 0: + x = ae.encode(x) # B C T H W + cond = text_enc(input_ids, cond_mask) # B L -> B L D + else: + videos, images = x[:, :, :-args.use_image_num], x[:, :, -args.use_image_num:] + videos = ae.encode(videos) # B C T H W + images = rearrange(images, 'b c t h w -> (b t) c 1 h w') + images = ae.encode(images) + images = rearrange(images, '(b t) c 1 h w -> b c t h w', t=args.use_image_num) + x = torch.cat([videos, images], dim=2) # b c 16+4, h, w + + # use for loop to avoid OOM, because T5 is too huge... + B, _, _ = input_ids.shape # B T+num_images L b 1+4, L + cond = torch.stack([text_enc(input_ids[i], cond_mask[i]) for i in range(B)]) # B 1+num_images L D + + model_kwargs = dict(encoder_hidden_states=cond, attention_mask=attn_mask, + encoder_attention_mask=cond_mask, use_image_num=args.use_image_num) + t = torch.randint(0, diffusion.num_timesteps, (x.shape[0],), device=accelerator.device) + loss_dict = diffusion.training_losses(model, x, t, model_kwargs) + loss = loss_dict["loss"].mean() + + # Gather the losses across all processes for logging (if we use distributed training). + avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean() + train_loss += avg_loss.item() / args.gradient_accumulation_steps + + # Backpropagate + accelerator.backward(loss) + if accelerator.sync_gradients: + params_to_clip = model.parameters() + accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad() + + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + progress_bar.update(1) + global_step += 1 + accelerator.log({"train_loss": train_loss}, step=global_step) + train_loss = 0.0 + + if args.use_deepspeed or accelerator.is_main_process: + if global_step % args.checkpointing_steps == 0: + # _before_ saving state, check if this save would set us over the `checkpoints_total_limit` + if args.checkpoints_total_limit is not None: + checkpoints = os.listdir(args.output_dir) + checkpoints = [d for d in checkpoints if d.startswith("checkpoint")] + checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1])) + + # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints + if len(checkpoints) >= args.checkpoints_total_limit: + num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1 + removing_checkpoints = checkpoints[0:num_to_remove] + + logger.info( + f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints" + ) + logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}") + + for removing_checkpoint in removing_checkpoints: + removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint) + shutil.rmtree(removing_checkpoint) + + save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") + accelerator.save_state(save_path) + logger.info(f"Saved state to {save_path}") + + logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} + progress_bar.set_postfix(**logs) + + if global_step >= args.max_train_steps: + break + + if accelerator.is_main_process: + validation_prompt = "The majestic beauty of a waterfall cascading down a cliff into a serene lake. The camera angle provides a bird's eye view of the waterfall." + if global_step % args.checkpointing_steps == 0: + logger.info(f"Running validation... \n" + f"Generating {args.num_validation_videos} videos with prompt: {validation_prompt}") + if args.use_ema: + # Store the UNet parameters temporarily and load the EMA parameters to perform inference. + ema_model.store(model.parameters()) + ema_model.copy_to(model.parameters()) + + if args.enable_tracker: + with torch.no_grad(): + # create pipeline + ae_ = getae_wrapper(args.ae)(args.ae_path).to(accelerator.device).eval() + if args.enable_tiling: + ae_.vae.enable_tiling() + ae_.vae.tile_overlap_factor = args.tile_overlap_factor + # text_enc_ = get_text_enc(args).to(accelerator.device).eval() + model_ = LatteT2V.from_pretrained(save_path, subfolder="model").to(accelerator.device).eval() + diffusion_ = create_diffusion(str(250)) + tokenizer_ = AutoTokenizer.from_pretrained(args.text_encoder_name, cache_dir='./cache_dir') + videos = [] + for idx in range(args.num_validation_videos): + with torch.autocast(device_type='cuda', dtype=weight_dtype): + z = torch.randn(1, model_.in_channels, video_length, + latent_size[0], latent_size[1], device=accelerator.device) + text_tokens_and_mask = tokenizer_( + validation_prompt, + max_length=args.model_max_length, + padding='max_length', + truncation=True, + return_attention_mask=True, + add_special_tokens=True, + return_tensors='pt' + ) + input_ids = text_tokens_and_mask['input_ids'].to(accelerator.device) + cond_mask = text_tokens_and_mask['attention_mask'].to(accelerator.device) + # cond = text_enc_(input_ids, cond_mask) # B L D + cond = text_enc(input_ids, cond_mask) # B L D + model_kwargs = dict(encoder_hidden_states=cond, attention_mask=None, encoder_attention_mask=cond_mask) + sample_fn = model_.forward + # Sample images: + samples = diffusion_.p_sample_loop( + sample_fn, z.shape, z, clip_denoised=False, model_kwargs=model_kwargs, progress=True, + device=accelerator.device + ) + samples = ae_.decode(samples) + # Save and display images: + video = (ae_denorm[args.ae](samples[0]) * 255).add_(0.5).clamp_(0, 255).to( + dtype=torch.uint8).cpu().contiguous() # t c h w + videos.append(video) + + videos = torch.stack(videos).numpy() + for tracker in accelerator.trackers: + if tracker.name == "tensorboard": + np_videos = np.stack([np.asarray(vid) for vid in videos]) + tracker.writer.add_video("validation", np_videos, global_step, fps=10) + if tracker.name == "wandb": + tracker.log( + { + "validation": [ + wandb.Video(video, caption=f"{i}: {validation_prompt}", fps=10) + for i, video in enumerate(videos) + ] + } + ) + + # del ae_, text_enc_, model_, diffusion_, tokenizer_ + del ae_, model_, diffusion_, tokenizer_ + torch.cuda.empty_cache() + + accelerator.wait_for_everyone() + accelerator.end_training() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--dataset", type=str, required=True) + parser.add_argument("--data_path", type=str, required=True) + parser.add_argument("--model", type=str, choices=list(Diffusion_models.keys()), default="DiT-XL/122") + parser.add_argument("--num_classes", type=int, default=1000) + parser.add_argument("--ae", type=str, default="stabilityai/sd-vae-ft-mse") + parser.add_argument("--ae_path", type=str, default="stabilityai/sd-vae-ft-mse") + parser.add_argument("--sample_rate", type=int, default=4) + parser.add_argument("--num_frames", type=int, default=16) + parser.add_argument("--max_image_size", type=int, default=128) + parser.add_argument("--dynamic_frames", action="store_true") + parser.add_argument("--compress_kv", action="store_true") + parser.add_argument("--attention_mode", type=str, choices=['xformers', 'math', 'flash'], default="math") + parser.add_argument("--pretrained", type=str, default=None) + + parser.add_argument('--tile_overlap_factor', type=float, default=0.25) + parser.add_argument('--enable_tiling', action='store_true') + + parser.add_argument("--video_folder", type=str, default='') + parser.add_argument("--text_encoder_name", type=str, default='DeepFloyd/t5-v1_1-xxl') + parser.add_argument("--model_max_length", type=int, default=120) + + parser.add_argument("--enable_tracker", action="store_true") + parser.add_argument("--use_image_num", type=int, default=0) + parser.add_argument("--use_img_from_vid", action="store_true") + parser.add_argument("--use_deepspeed", action="store_true") + parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") + parser.add_argument( + "--num_validation_videos", + type=int, + default=2, + help="Number of images that should be generated during validation with `validation_prompt`.", + ) + parser.add_argument( + "--output_dir", + type=str, + default=None, + help="The output directory where the model predictions and checkpoints will be written.", + ) + parser.add_argument( + "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader." + ) + parser.add_argument("--num_train_epochs", type=int, default=100) + parser.add_argument( + "--max_train_steps", + type=int, + default=None, + help="Total number of training steps to perform. If provided, overrides num_train_epochs.", + ) + parser.add_argument( + "--checkpointing_steps", + type=int, + default=500, + help=( + "Save a checkpoint of the training state every X updates. These checkpoints can be used both as final" + " checkpoints in case they are better than the last checkpoint, and are also suitable for resuming" + " training using `--resume_from_checkpoint`." + ), + ) + parser.add_argument( + "--checkpoints_total_limit", + type=int, + default=None, + help=("Max number of checkpoints to store."), + ) + parser.add_argument( + "--resume_from_checkpoint", + type=str, + default=None, + help=( + "Whether training should be resumed from a previous checkpoint. Use a path saved by" + ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.' + ), + ) + parser.add_argument( + "--gradient_accumulation_steps", + type=int, + default=1, + help="Number of updates steps to accumulate before performing a backward/update pass.", + ) + parser.add_argument( + "--gradient_checkpointing", + action="store_true", + help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", + ) + parser.add_argument( + "--learning_rate", + type=float, + default=1e-4, + help="Initial learning rate (after the potential warmup period) to use.", + ) + parser.add_argument( + "--scale_lr", + action="store_true", + default=False, + help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", + ) + parser.add_argument( + "--lr_scheduler", + type=str, + default="constant", + help=( + 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' + ' "constant", "constant_with_warmup"]' + ), + ) + parser.add_argument( + "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler." + ) + parser.add_argument( + "--timestep_bias_strategy", + type=str, + default="none", + choices=["earlier", "later", "range", "none"], + help=( + "The timestep bias strategy, which may help direct the model toward learning low or high frequency details." + " Choices: ['earlier', 'later', 'range', 'none']." + " The default is 'none', which means no bias is applied, and training proceeds normally." + " The value of 'later' will increase the frequency of the model's final training timesteps." + ), + ) + parser.add_argument( + "--timestep_bias_multiplier", + type=float, + default=1.0, + help=( + "The multiplier for the bias. Defaults to 1.0, which means no bias is applied." + " A value of 2.0 will double the weight of the bias, and a value of 0.5 will halve it." + ), + ) + parser.add_argument( + "--timestep_bias_begin", + type=int, + default=0, + help=( + "When using `--timestep_bias_strategy=range`, the beginning (inclusive) timestep to bias." + " Defaults to zero, which equates to having no specific bias." + ), + ) + parser.add_argument( + "--timestep_bias_end", + type=int, + default=1000, + help=( + "When using `--timestep_bias_strategy=range`, the final timestep (inclusive) to bias." + " Defaults to 1000, which is the number of timesteps that Stable Diffusion is trained on." + ), + ) + parser.add_argument( + "--timestep_bias_portion", + type=float, + default=0.25, + help=( + "The portion of timesteps to bias. Defaults to 0.25, which 25% of timesteps will be biased." + " A value of 0.5 will bias one half of the timesteps. The value provided for `--timestep_bias_strategy` determines" + " whether the biased portions are in the earlier or later timesteps." + ), + ) + parser.add_argument( + "--snr_gamma", + type=float, + default=None, + help="SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. " + "More details here: https://arxiv.org/abs/2303.09556.", + ) + parser.add_argument("--use_ema", action="store_true", help="Whether to use EMA model.") + parser.add_argument( + "--allow_tf32", + action="store_true", + help=( + "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see" + " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices" + ), + ) + parser.add_argument( + "--dataloader_num_workers", + type=int, + default=10, + help=( + "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process." + ), + ) + parser.add_argument( + "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes." + ) + parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.") + parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.") + parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.") + parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer") + parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") + parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") + parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.") + parser.add_argument( + "--prediction_type", + type=str, + default=None, + help="The prediction_type that shall be used for training. Choose between 'epsilon' or 'v_prediction' or leave `None`. If left to `None` the default prediction type of the scheduler: `noise_scheduler.config.prediciton_type` is chosen.", + ) + parser.add_argument( + "--hub_model_id", + type=str, + default=None, + help="The name of the repository to keep in sync with the local `output_dir`.", + ) + parser.add_argument( + "--logging_dir", + type=str, + default="logs", + help=( + "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" + " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***." + ), + ) + parser.add_argument( + "--report_to", + type=str, + default="tensorboard", + help=( + 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`' + ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.' + ), + ) + parser.add_argument( + "--mixed_precision", + type=str, + default=None, + choices=["no", "fp16", "bf16"], + help=( + "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" + " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the" + " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config." + ), + ) + parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") + parser.add_argument("--noise_offset", type=float, default=0, help="The scale of noise offset.") + + args = parser.parse_args() + main(args) \ No newline at end of file diff --git a/opensora/train/train_t2v_feature.py b/opensora/train/train_t2v_feature.py new file mode 100644 index 0000000000000000000000000000000000000000..8d46d00092caf49d3c45445de63b823e7c914104 --- /dev/null +++ b/opensora/train/train_t2v_feature.py @@ -0,0 +1,783 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +""" +A minimal training script for DiT using PyTorch DDP. +""" +import argparse +import logging +import math +import os +import shutil +from pathlib import Path +from typing import Optional + +import numpy as np +from einops import rearrange +from tqdm import tqdm +from dataclasses import field, dataclass +from torch.utils.data import DataLoader +from copy import deepcopy + +import accelerate +import torch +from torch.nn import functional as F +import transformers +from accelerate import Accelerator +from accelerate.logging import get_logger +from accelerate.utils import ProjectConfiguration, set_seed +from huggingface_hub import create_repo +from packaging import version +from tqdm.auto import tqdm +from transformers import HfArgumentParser, TrainingArguments, AutoTokenizer + +import diffusers +from diffusers import DDPMScheduler, PNDMScheduler +from diffusers.optimization import get_scheduler +from diffusers.training_utils import EMAModel, compute_snr +from diffusers.utils import check_min_version, is_wandb_available + +from opensora.dataset import getdataset, ae_denorm +from opensora.models.ae import getae, getae_wrapper +from opensora.models.ae.videobase import CausalVQVAEModelWrapper, CausalVAEModelWrapper +from opensora.models.diffusion.diffusion import create_diffusion_T as create_diffusion +from opensora.models.diffusion.latte.modeling_latte import LatteT2V +from opensora.models.text_encoder import get_text_enc +from opensora.utils.dataset_utils import Collate +from opensora.models.ae import ae_stride_config, ae_channel_config +from opensora.models.diffusion import Diffusion_models + +# Will error if the minimal version of diffusers is not installed. Remove at your own risks. +check_min_version("0.24.0") +logger = get_logger(__name__) + + +def generate_timestep_weights(args, num_timesteps): + weights = torch.ones(num_timesteps) + + # Determine the indices to bias + num_to_bias = int(args.timestep_bias_portion * num_timesteps) + + if args.timestep_bias_strategy == "later": + bias_indices = slice(-num_to_bias, None) + elif args.timestep_bias_strategy == "earlier": + bias_indices = slice(0, num_to_bias) + elif args.timestep_bias_strategy == "range": + # Out of the possible 1000 timesteps, we might want to focus on eg. 200-500. + range_begin = args.timestep_bias_begin + range_end = args.timestep_bias_end + if range_begin < 0: + raise ValueError( + "When using the range strategy for timestep bias, you must provide a beginning timestep greater or equal to zero." + ) + if range_end > num_timesteps: + raise ValueError( + "When using the range strategy for timestep bias, you must provide an ending timestep smaller than the number of timesteps." + ) + bias_indices = slice(range_begin, range_end) + else: # 'none' or any other string + return weights + if args.timestep_bias_multiplier <= 0: + return ValueError( + "The parameter --timestep_bias_multiplier is not intended to be used to disable the training of specific timesteps." + " If it was intended to disable timestep bias, use `--timestep_bias_strategy none` instead." + " A timestep bias multiplier less than or equal to 0 is not allowed." + ) + + # Apply the bias + weights[bias_indices] *= args.timestep_bias_multiplier + + # Normalize + weights /= weights.sum() + + return weights + + +################################################################################# +# Training Loop # +################################################################################# + +def main(args): + logging_dir = Path(args.output_dir, args.logging_dir) + + accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir) + + accelerator = Accelerator( + gradient_accumulation_steps=args.gradient_accumulation_steps, + mixed_precision=args.mixed_precision, + log_with=args.report_to, + project_config=accelerator_project_config, + ) + + if args.report_to == "wandb": + if not is_wandb_available(): + raise ImportError("Make sure to install wandb if you want to use it for logging during training.") + import wandb + + # Make one log on every process with the configuration for debugging. + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, + ) + logger.info(accelerator.state, main_process_only=False) + if accelerator.is_local_main_process: + transformers.utils.logging.set_verbosity_warning() + diffusers.utils.logging.set_verbosity_info() + else: + transformers.utils.logging.set_verbosity_error() + diffusers.utils.logging.set_verbosity_error() + + # If passed along, set the training seed now. + if args.seed is not None: + set_seed(args.seed) + + # Handle the repository creation + if accelerator.is_main_process: + if args.output_dir is not None: + os.makedirs(args.output_dir, exist_ok=True) + + # if args.push_to_hub: + # repo_id = create_repo( + # repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token + # ).repo_id + + # Create model: + + diffusion = create_diffusion(timestep_respacing="") # default: 1000 steps, linear noise schedule + # ae = getae(args).eval() + # text_enc = get_text_enc(args).eval() + + ae_stride_t, ae_stride_h, ae_stride_w = ae_stride_config[args.ae] + args.ae_stride_t, args.ae_stride_h, args.ae_stride_w = ae_stride_t, ae_stride_h, ae_stride_w + args.ae_stride = args.ae_stride_h + patch_size = args.model[-3:] + patch_size_t, patch_size_h, patch_size_w = int(patch_size[0]), int(patch_size[1]), int(patch_size[2]) + args.patch_size = patch_size_h + args.patch_size_t, args.patch_size_h, args.patch_size_w = patch_size_t, patch_size_h, patch_size_w + assert ae_stride_h == ae_stride_w, f"Support only ae_stride_h == ae_stride_w now, but found ae_stride_h ({ae_stride_h}), ae_stride_w ({ae_stride_w})" + assert patch_size_h == patch_size_w, f"Support only patch_size_h == patch_size_w now, but found patch_size_h ({patch_size_h}), patch_size_w ({patch_size_w})" + # assert args.num_frames % ae_stride_t == 0, f"Num_frames must be divisible by ae_stride_t, but found num_frames ({args.num_frames}), ae_stride_t ({ae_stride_t})." + assert args.max_image_size % ae_stride_h == 0, f"Image size must be divisible by ae_stride_h, but found max_image_size ({args.max_image_size}), ae_stride_h ({ae_stride_h})." + + latent_size = (args.max_image_size // ae_stride_h, args.max_image_size // ae_stride_w) + + if getae_wrapper(args.ae) == CausalVQVAEModelWrapper or getae_wrapper(args.ae) == CausalVAEModelWrapper: + args.video_length = video_length = args.num_frames // ae_stride_t + 1 + else: + args.video_length = video_length = args.num_frames // ae_stride_t + model = Diffusion_models[args.model]( + in_channels=ae_channel_config[args.ae], + out_channels=ae_channel_config[args.ae] * 2, + # caption_channels=4096, + # cross_attention_dim=1152, + attention_bias=True, + sample_size=latent_size, + num_vector_embeds=None, + activation_fn="gelu-approximate", + num_embeds_ada_norm=1000, + use_linear_projection=False, + only_cross_attention=False, + double_self_attention=False, + upcast_attention=False, + # norm_type="ada_norm_single", + norm_elementwise_affine=False, + norm_eps=1e-6, + attention_type='default', + video_length=video_length, + attention_mode=args.attention_mode, + # compress_kv=args.compress_kv + ) + model.gradient_checkpointing = args.gradient_checkpointing + + # # use pretrained model? + if args.pretrained: + checkpoint = torch.load(args.pretrained, map_location='cpu')['model'] + model_state_dict = model.state_dict() + missing_keys, unexpected_keys = model.load_state_dict(checkpoint, strict=False) + logger.info(f'missing_keys {len(missing_keys)} {missing_keys}, unexpected_keys {len(unexpected_keys)}') + logger.info(f'Successfully load {len(model.state_dict()) - len(missing_keys)}/{len(model_state_dict)} keys from {args.pretrained}!') + # load from pixart-alpha + # pixelart_alpha = torch.load(args.pretrained, map_location='cpu')['state_dict'] + # checkpoint = {} + # for k, v in pixelart_alpha.items(): + # if 'x_embedder' in k or 't_embedder' in k or 'y_embedder' in k: + # checkpoint[k] = v + # if k.startswith('blocks'): + # k_spilt = k.split('.') + # blk_id = str(int(k_spilt[1]) * 2) + # k_spilt[1] = blk_id + # new_k = '.'.join(k_spilt) + # checkpoint[new_k] = v + # missing_keys, unexpected_keys = model.load_state_dict(checkpoint, strict=False) + # logger.info(f'Successfully load {len(model.state_dict()) - len(missing_keys)} keys from {args.pretrained}!') + + # Freeze vae and text encoders. + # ae.requires_grad_(False) + # text_enc.requires_grad_(False) + # Set model as trainable. + model.train() + + # For mixed precision training we cast all non-trainable weigths to half-precision + # as these weights are only used for inference, keeping weights in full precision is not required. + weight_dtype = torch.float32 + if accelerator.mixed_precision == "fp16": + weight_dtype = torch.float16 + elif accelerator.mixed_precision == "bf16": + weight_dtype = torch.bfloat16 + + # Move unet, vae and text_encoder to device and cast to weight_dtype + # The VAE is in float32 to avoid NaN losses. + # ae.to(accelerator.device, dtype=torch.float32) + # text_enc.to(accelerator.device, dtype=weight_dtype) + + # Create EMA for the unet. + if args.use_ema: + ema_model = deepcopy(model) + ema_model = EMAModel(ema_model.parameters(), model_cls=LatteT2V, model_config=ema_model.config) + + # `accelerate` 0.16.0 will have better support for customized saving + if version.parse(accelerate.__version__) >= version.parse("0.16.0"): + # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format + def save_model_hook(models, weights, output_dir): + if accelerator.is_main_process: + if args.use_ema: + ema_model.save_pretrained(os.path.join(output_dir, "model_ema")) + + for i, model in enumerate(models): + model.save_pretrained(os.path.join(output_dir, "model")) + if weights: # Don't pop if empty + # make sure to pop weight so that corresponding model is not saved again + weights.pop() + + def load_model_hook(models, input_dir): + if args.use_ema: + load_model = EMAModel.from_pretrained(os.path.join(input_dir, "model_ema"), LatteT2V) + ema_model.load_state_dict(load_model.state_dict()) + ema_model.to(accelerator.device) + del load_model + + for i in range(len(models)): + # pop models so that they are not loaded again + model = models.pop() + + # load diffusers style into model + load_model = LatteT2V.from_pretrained(input_dir, subfolder="model") + model.register_to_config(**load_model.config) + + model.load_state_dict(load_model.state_dict()) + del load_model + + accelerator.register_save_state_pre_hook(save_model_hook) + accelerator.register_load_state_pre_hook(load_model_hook) + + # Enable TF32 for faster training on Ampere GPUs, + # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices + if args.allow_tf32: + torch.backends.cuda.matmul.allow_tf32 = True + + if args.scale_lr: + args.learning_rate = ( + args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes + ) + + # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs + if args.use_8bit_adam: + try: + import bitsandbytes as bnb + except ImportError: + raise ImportError( + "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`." + ) + + optimizer_class = bnb.optim.AdamW8bit + else: + optimizer_class = torch.optim.AdamW + + # Optimizer creation + params_to_optimize = model.parameters() + optimizer = optimizer_class( + params_to_optimize, + lr=args.learning_rate, + betas=(args.adam_beta1, args.adam_beta2), + weight_decay=args.adam_weight_decay, + eps=args.adam_epsilon, + ) + + # Setup data: + train_dataset = getdataset(args) + train_dataloader = torch.utils.data.DataLoader( + train_dataset, + shuffle=True, + # collate_fn=Collate(args), # TODO: do not enable dynamic mask in this point + batch_size=args.train_batch_size, + num_workers=args.dataloader_num_workers, + ) + + # Scheduler and math around the number of training steps. + overrode_max_train_steps = False + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if args.max_train_steps is None: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + overrode_max_train_steps = True + + lr_scheduler = get_scheduler( + args.lr_scheduler, + optimizer=optimizer, + num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps, + num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, + ) + + # Prepare everything with our `accelerator`. + model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + model, optimizer, train_dataloader, lr_scheduler + ) + + # We need to recalculate our total training steps as the size of the training dataloader may have changed. + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if overrode_max_train_steps: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + # Afterwards we recalculate our number of training epochs + args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + + # We need to initialize the trackers we use, and also store our configuration. + # The trackers initializes automatically on the main process. + if accelerator.is_main_process: + accelerator.init_trackers(args.output_dir, config=vars(args)) + + # Train! + total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps + + logger.info("***** Running training *****") + logger.info(f" Num examples = {len(train_dataset)}") + logger.info(f" Num Epochs = {args.num_train_epochs}") + logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") + logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") + logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") + logger.info(f" Total optimization steps = {args.max_train_steps}") + global_step = 0 + first_epoch = 0 + + # Potentially load in the weights and states from a previous save + if args.resume_from_checkpoint: + if args.resume_from_checkpoint != "latest": + path = os.path.basename(args.resume_from_checkpoint) + else: + # Get the most recent checkpoint + dirs = os.listdir(args.output_dir) + dirs = [d for d in dirs if d.startswith("checkpoint")] + dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) + path = dirs[-1] if len(dirs) > 0 else None + + if path is None: + accelerator.print( + f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run." + ) + args.resume_from_checkpoint = None + initial_global_step = 0 + else: + accelerator.print(f"Resuming from checkpoint {path}") + accelerator.load_state(os.path.join(args.output_dir, path)) + global_step = int(path.split("-")[1]) + + initial_global_step = global_step + first_epoch = global_step // num_update_steps_per_epoch + + else: + initial_global_step = 0 + + progress_bar = tqdm( + range(0, args.max_train_steps), + initial=initial_global_step, + desc="Steps", + # Only show the progress bar once on each machine. + disable=not accelerator.is_local_main_process, + ) + + for epoch in range(first_epoch, args.num_train_epochs): + train_loss = 0.0 + for step, (x, cond, cond_mask) in enumerate(train_dataloader): + with accelerator.accumulate(model): + # Sample noise that we'll add to the latents + x = x.to(accelerator.device) # B C T H W + # attn_mask = attn_mask.to(device) # B T H W + # assert torch.all(attn_mask.bool()), 'do not enable dynamic input' + attn_mask = None + cond = cond.to(accelerator.device) # B L or B 1+num_images L + cond_mask = cond_mask.to(accelerator.device) # B L or B 1+num_images L + # print(args.use_image_num, x.shape, cond.shape, cond_mask.shape, cond_mask) + model_kwargs = dict(encoder_hidden_states=cond, attention_mask=attn_mask, + encoder_attention_mask=cond_mask, use_image_num=args.use_image_num) + t = torch.randint(0, diffusion.num_timesteps, (x.shape[0],), device=accelerator.device) + loss_dict = diffusion.training_losses(model, x, t, model_kwargs) + loss = loss_dict["loss"].mean() + + # Gather the losses across all processes for logging (if we use distributed training). + avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean() + train_loss += avg_loss.item() / args.gradient_accumulation_steps + + # Backpropagate + accelerator.backward(loss) + if accelerator.sync_gradients: + params_to_clip = model.parameters() + accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad() + + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + progress_bar.update(1) + global_step += 1 + accelerator.log({"train_loss": train_loss}, step=global_step) + train_loss = 0.0 + + if args.use_deepspeed or accelerator.is_main_process: + if global_step % args.checkpointing_steps == 0: + # _before_ saving state, check if this save would set us over the `checkpoints_total_limit` + if args.checkpoints_total_limit is not None: + checkpoints = os.listdir(args.output_dir) + checkpoints = [d for d in checkpoints if d.startswith("checkpoint")] + checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1])) + + # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints + if len(checkpoints) >= args.checkpoints_total_limit: + num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1 + removing_checkpoints = checkpoints[0:num_to_remove] + + logger.info( + f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints" + ) + logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}") + + for removing_checkpoint in removing_checkpoints: + removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint) + shutil.rmtree(removing_checkpoint) + + save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") + accelerator.save_state(save_path) + logger.info(f"Saved state to {save_path}") + + logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} + progress_bar.set_postfix(**logs) + + if global_step >= args.max_train_steps: + break + + if accelerator.is_main_process: + validation_prompt = "The majestic beauty of a waterfall cascading down a cliff into a serene lake. The camera angle provides a bird's eye view of the waterfall." + if global_step % args.checkpointing_steps == 0: + logger.info(f"Running validation... \n" + f"Generating {args.num_validation_videos} videos with prompt: {validation_prompt}") + if args.use_ema: + # Store the UNet parameters temporarily and load the EMA parameters to perform inference. + ema_model.store(model.parameters()) + ema_model.copy_to(model.parameters()) + + if args.enable_tracker: + with torch.no_grad(): + # create pipeline + ae_ = getae_wrapper(args.ae)(args.ae_path).to(accelerator.device).eval() + if args.enable_tiling: + ae_.vae.enable_tiling() + ae_.vae.tile_overlap_factor = args.tile_overlap_factor + text_enc_ = get_text_enc(args).to(accelerator.device).eval() + model_ = LatteT2V.from_pretrained(save_path, subfolder="model").to(accelerator.device).eval() + diffusion_ = create_diffusion(str(250)) + tokenizer_ = AutoTokenizer.from_pretrained(args.text_encoder_name, cache_dir='./cache_dir') + videos = [] + for idx in range(args.num_validation_videos): + with torch.autocast(device_type='cuda', dtype=weight_dtype): + z = torch.randn(1, model_.in_channels, video_length, + latent_size[0], latent_size[1], device=accelerator.device) + text_tokens_and_mask = tokenizer_( + validation_prompt, + max_length=args.model_max_length, + padding='max_length', + truncation=True, + return_attention_mask=True, + add_special_tokens=True, + return_tensors='pt' + ) + input_ids = text_tokens_and_mask['input_ids'].to(accelerator.device) + cond_mask = text_tokens_and_mask['attention_mask'].to(accelerator.device) + cond = text_enc_(input_ids, cond_mask) # B L D + # cond = text_enc(input_ids, cond_mask) # B L D + model_kwargs = dict(encoder_hidden_states=cond, attention_mask=None, encoder_attention_mask=cond_mask) + sample_fn = model_.forward + # Sample images: + samples = diffusion_.p_sample_loop( + sample_fn, z.shape, z, clip_denoised=False, model_kwargs=model_kwargs, progress=True, + device=accelerator.device + ) + samples = ae_.decode(samples) + # Save and display images: + video = (ae_denorm[args.ae](samples[0]) * 255).add_(0.5).clamp_(0, 255).to( + dtype=torch.uint8).cpu().contiguous() # t c h w + videos.append(video) + + videos = torch.stack(videos).numpy() + for tracker in accelerator.trackers: + if tracker.name == "tensorboard": + np_videos = np.stack([np.asarray(vid) for vid in videos]) + tracker.writer.add_video("validation", np_videos, global_step, fps=10) + if tracker.name == "wandb": + tracker.log( + { + "validation": [ + wandb.Video(video, caption=f"{i}: {validation_prompt}", fps=10) + for i, video in enumerate(videos) + ] + } + ) + + del ae_, text_enc_, model_, diffusion_, tokenizer_ + # del ae_, model_, diffusion_, tokenizer_ + torch.cuda.empty_cache() + + accelerator.wait_for_everyone() + accelerator.end_training() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--dataset", type=str, required=True) + parser.add_argument("--data_path", type=str, required=True) + parser.add_argument("--model", type=str, choices=list(Diffusion_models.keys()), default="DiT-XL/122") + parser.add_argument("--num_classes", type=int, default=1000) + parser.add_argument("--ae", type=str, default="stabilityai/sd-vae-ft-mse") + parser.add_argument("--ae_path", type=str, default="stabilityai/sd-vae-ft-mse") + parser.add_argument("--sample_rate", type=int, default=4) + parser.add_argument("--num_frames", type=int, default=16) + parser.add_argument("--max_image_size", type=int, default=128) + parser.add_argument("--dynamic_frames", action="store_true") + parser.add_argument("--compress_kv", action="store_true") + parser.add_argument("--attention_mode", type=str, choices=['xformers', 'math', 'flash'], default="math") + parser.add_argument("--pretrained", type=str, default=None) + + parser.add_argument('--tile_overlap_factor', type=float, default=0.25) + parser.add_argument('--enable_tiling', action='store_true') + + parser.add_argument("--video_folder", type=str, default='') + parser.add_argument("--text_encoder_name", type=str, default='DeepFloyd/t5-v1_1-xxl') + parser.add_argument("--model_max_length", type=int, default=120) + + parser.add_argument("--enable_tracker", action="store_true") + parser.add_argument("--use_image_num", type=int, default=0) + parser.add_argument("--use_img_from_vid", action="store_true") + parser.add_argument("--use_deepspeed", action="store_true") + parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") + parser.add_argument( + "--num_validation_videos", + type=int, + default=2, + help="Number of images that should be generated during validation with `validation_prompt`.", + ) + parser.add_argument( + "--output_dir", + type=str, + default=None, + help="The output directory where the model predictions and checkpoints will be written.", + ) + parser.add_argument( + "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader." + ) + parser.add_argument("--num_train_epochs", type=int, default=100) + parser.add_argument( + "--max_train_steps", + type=int, + default=None, + help="Total number of training steps to perform. If provided, overrides num_train_epochs.", + ) + parser.add_argument( + "--checkpointing_steps", + type=int, + default=500, + help=( + "Save a checkpoint of the training state every X updates. These checkpoints can be used both as final" + " checkpoints in case they are better than the last checkpoint, and are also suitable for resuming" + " training using `--resume_from_checkpoint`." + ), + ) + parser.add_argument( + "--checkpoints_total_limit", + type=int, + default=None, + help=("Max number of checkpoints to store."), + ) + parser.add_argument( + "--resume_from_checkpoint", + type=str, + default=None, + help=( + "Whether training should be resumed from a previous checkpoint. Use a path saved by" + ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.' + ), + ) + parser.add_argument( + "--gradient_accumulation_steps", + type=int, + default=1, + help="Number of updates steps to accumulate before performing a backward/update pass.", + ) + parser.add_argument( + "--gradient_checkpointing", + action="store_true", + help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", + ) + parser.add_argument( + "--learning_rate", + type=float, + default=1e-4, + help="Initial learning rate (after the potential warmup period) to use.", + ) + parser.add_argument( + "--scale_lr", + action="store_true", + default=False, + help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", + ) + parser.add_argument( + "--lr_scheduler", + type=str, + default="constant", + help=( + 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' + ' "constant", "constant_with_warmup"]' + ), + ) + parser.add_argument( + "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler." + ) + parser.add_argument( + "--timestep_bias_strategy", + type=str, + default="none", + choices=["earlier", "later", "range", "none"], + help=( + "The timestep bias strategy, which may help direct the model toward learning low or high frequency details." + " Choices: ['earlier', 'later', 'range', 'none']." + " The default is 'none', which means no bias is applied, and training proceeds normally." + " The value of 'later' will increase the frequency of the model's final training timesteps." + ), + ) + parser.add_argument( + "--timestep_bias_multiplier", + type=float, + default=1.0, + help=( + "The multiplier for the bias. Defaults to 1.0, which means no bias is applied." + " A value of 2.0 will double the weight of the bias, and a value of 0.5 will halve it." + ), + ) + parser.add_argument( + "--timestep_bias_begin", + type=int, + default=0, + help=( + "When using `--timestep_bias_strategy=range`, the beginning (inclusive) timestep to bias." + " Defaults to zero, which equates to having no specific bias." + ), + ) + parser.add_argument( + "--timestep_bias_end", + type=int, + default=1000, + help=( + "When using `--timestep_bias_strategy=range`, the final timestep (inclusive) to bias." + " Defaults to 1000, which is the number of timesteps that Stable Diffusion is trained on." + ), + ) + parser.add_argument( + "--timestep_bias_portion", + type=float, + default=0.25, + help=( + "The portion of timesteps to bias. Defaults to 0.25, which 25% of timesteps will be biased." + " A value of 0.5 will bias one half of the timesteps. The value provided for `--timestep_bias_strategy` determines" + " whether the biased portions are in the earlier or later timesteps." + ), + ) + parser.add_argument( + "--snr_gamma", + type=float, + default=None, + help="SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. " + "More details here: https://arxiv.org/abs/2303.09556.", + ) + parser.add_argument("--use_ema", action="store_true", help="Whether to use EMA model.") + parser.add_argument( + "--allow_tf32", + action="store_true", + help=( + "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see" + " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices" + ), + ) + parser.add_argument( + "--dataloader_num_workers", + type=int, + default=10, + help=( + "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process." + ), + ) + parser.add_argument( + "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes." + ) + parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.") + parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.") + parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.") + parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer") + parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") + parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") + parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.") + parser.add_argument( + "--prediction_type", + type=str, + default=None, + help="The prediction_type that shall be used for training. Choose between 'epsilon' or 'v_prediction' or leave `None`. If left to `None` the default prediction type of the scheduler: `noise_scheduler.config.prediciton_type` is chosen.", + ) + parser.add_argument( + "--hub_model_id", + type=str, + default=None, + help="The name of the repository to keep in sync with the local `output_dir`.", + ) + parser.add_argument( + "--logging_dir", + type=str, + default="logs", + help=( + "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" + " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***." + ), + ) + parser.add_argument( + "--report_to", + type=str, + default="tensorboard", + help=( + 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`' + ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.' + ), + ) + parser.add_argument( + "--mixed_precision", + type=str, + default=None, + choices=["no", "fp16", "bf16"], + help=( + "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" + " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the" + " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config." + ), + ) + parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") + parser.add_argument("--noise_offset", type=float, default=0, help="The scale of noise offset.") + + args = parser.parse_args() + main(args) \ No newline at end of file diff --git a/opensora/train/train_t2v_t5_feature.py b/opensora/train/train_t2v_t5_feature.py new file mode 100644 index 0000000000000000000000000000000000000000..a8f3543e08eb4dec0038cb64a96dbf4dd7168a94 --- /dev/null +++ b/opensora/train/train_t2v_t5_feature.py @@ -0,0 +1,825 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +""" +A minimal training script for DiT using PyTorch DDP. +""" +import argparse +import logging +import math +import os +import shutil +import sys +from pathlib import Path +from typing import Optional + +import numpy as np +from PIL import Image +from einops import rearrange +from tqdm import tqdm +from dataclasses import field, dataclass +from torch.utils.data import DataLoader +from copy import deepcopy + +import accelerate +import torch +from torch.nn import functional as F +import transformers +from accelerate import Accelerator +from accelerate.logging import get_logger +from accelerate.utils import ProjectConfiguration, set_seed +from huggingface_hub import create_repo +from packaging import version +from tqdm.auto import tqdm +from transformers import HfArgumentParser, TrainingArguments, AutoTokenizer + +import diffusers +from diffusers import DDPMScheduler, PNDMScheduler +from diffusers.optimization import get_scheduler +from diffusers.training_utils import EMAModel, compute_snr +from diffusers.utils import check_min_version, is_wandb_available + +from examples.rec_imvi_vae import custom_to_video +from opensora.dataset import getdataset, ae_denorm +from opensora.models.ae import getae, getae_wrapper +from opensora.models.ae.videobase import CausalVQVAEModelWrapper, CausalVAEModelWrapper +from opensora.models.diffusion.diffusion import create_diffusion_T as create_diffusion +from opensora.models.diffusion.latte.modeling_latte import LatteT2V +from opensora.models.text_encoder import get_text_enc +from opensora.utils.dataset_utils import Collate +from opensora.models.ae import ae_stride_config, ae_channel_config +from opensora.models.diffusion import Diffusion_models + +# Will error if the minimal version of diffusers is not installed. Remove at your own risks. +check_min_version("0.24.0") +logger = get_logger(__name__) + + +def generate_timestep_weights(args, num_timesteps): + weights = torch.ones(num_timesteps) + + # Determine the indices to bias + num_to_bias = int(args.timestep_bias_portion * num_timesteps) + + if args.timestep_bias_strategy == "later": + bias_indices = slice(-num_to_bias, None) + elif args.timestep_bias_strategy == "earlier": + bias_indices = slice(0, num_to_bias) + elif args.timestep_bias_strategy == "range": + # Out of the possible 1000 timesteps, we might want to focus on eg. 200-500. + range_begin = args.timestep_bias_begin + range_end = args.timestep_bias_end + if range_begin < 0: + raise ValueError( + "When using the range strategy for timestep bias, you must provide a beginning timestep greater or equal to zero." + ) + if range_end > num_timesteps: + raise ValueError( + "When using the range strategy for timestep bias, you must provide an ending timestep smaller than the number of timesteps." + ) + bias_indices = slice(range_begin, range_end) + else: # 'none' or any other string + return weights + if args.timestep_bias_multiplier <= 0: + return ValueError( + "The parameter --timestep_bias_multiplier is not intended to be used to disable the training of specific timesteps." + " If it was intended to disable timestep bias, use `--timestep_bias_strategy none` instead." + " A timestep bias multiplier less than or equal to 0 is not allowed." + ) + + # Apply the bias + weights[bias_indices] *= args.timestep_bias_multiplier + + # Normalize + weights /= weights.sum() + + return weights + + +################################################################################# +# Training Loop # +################################################################################# + +def main(args): + logging_dir = Path(args.output_dir, args.logging_dir) + + accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir) + + accelerator = Accelerator( + gradient_accumulation_steps=args.gradient_accumulation_steps, + mixed_precision=args.mixed_precision, + log_with=args.report_to, + project_config=accelerator_project_config, + ) + + if args.report_to == "wandb": + if not is_wandb_available(): + raise ImportError("Make sure to install wandb if you want to use it for logging during training.") + import wandb + + # Make one log on every process with the configuration for debugging. + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, + ) + logger.info(accelerator.state, main_process_only=False) + if accelerator.is_local_main_process: + transformers.utils.logging.set_verbosity_warning() + diffusers.utils.logging.set_verbosity_info() + else: + transformers.utils.logging.set_verbosity_error() + diffusers.utils.logging.set_verbosity_error() + + # If passed along, set the training seed now. + if args.seed is not None: + set_seed(args.seed) + + # Handle the repository creation + if accelerator.is_main_process: + if args.output_dir is not None: + os.makedirs(args.output_dir, exist_ok=True) + + # if args.push_to_hub: + # repo_id = create_repo( + # repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token + # ).repo_id + + # Create model: + + diffusion = create_diffusion(timestep_respacing="") # default: 1000 steps, linear noise schedule + ae = getae_wrapper(args.ae)(args.ae_path).eval() + if args.enable_tiling: + ae.vae.enable_tiling() + ae.vae.tile_overlap_factor = args.tile_overlap_factor + # text_enc = get_text_enc(args).eval() + + ae_stride_t, ae_stride_h, ae_stride_w = ae_stride_config[args.ae] + args.ae_stride_t, args.ae_stride_h, args.ae_stride_w = ae_stride_t, ae_stride_h, ae_stride_w + args.ae_stride = args.ae_stride_h + patch_size = args.model[-3:] + patch_size_t, patch_size_h, patch_size_w = int(patch_size[0]), int(patch_size[1]), int(patch_size[2]) + args.patch_size = patch_size_h + args.patch_size_t, args.patch_size_h, args.patch_size_w = patch_size_t, patch_size_h, patch_size_w + assert ae_stride_h == ae_stride_w, f"Support only ae_stride_h == ae_stride_w now, but found ae_stride_h ({ae_stride_h}), ae_stride_w ({ae_stride_w})" + assert patch_size_h == patch_size_w, f"Support only patch_size_h == patch_size_w now, but found patch_size_h ({patch_size_h}), patch_size_w ({patch_size_w})" + # assert args.num_frames % ae_stride_t == 0, f"Num_frames must be divisible by ae_stride_t, but found num_frames ({args.num_frames}), ae_stride_t ({ae_stride_t})." + assert args.max_image_size % ae_stride_h == 0, f"Image size must be divisible by ae_stride_h, but found max_image_size ({args.max_image_size}), ae_stride_h ({ae_stride_h})." + + latent_size = (args.max_image_size // ae_stride_h, args.max_image_size // ae_stride_w) + + if getae_wrapper(args.ae) == CausalVQVAEModelWrapper or getae_wrapper(args.ae) == CausalVAEModelWrapper: + args.video_length = video_length = args.num_frames // ae_stride_t + 1 + else: + args.video_length = video_length = args.num_frames // ae_stride_t + model = Diffusion_models[args.model]( + in_channels=ae_channel_config[args.ae], + out_channels=ae_channel_config[args.ae] * 2, + # caption_channels=4096, + # cross_attention_dim=1152, + attention_bias=True, + sample_size=latent_size, + num_vector_embeds=None, + activation_fn="gelu-approximate", + num_embeds_ada_norm=1000, + use_linear_projection=False, + only_cross_attention=False, + double_self_attention=False, + upcast_attention=False, + # norm_type="ada_norm_single", + norm_elementwise_affine=False, + norm_eps=1e-6, + attention_type='default', + video_length=video_length, + attention_mode=args.attention_mode, + # compress_kv=args.compress_kv + ) + model.gradient_checkpointing = args.gradient_checkpointing + + # # use pretrained model? + if args.pretrained: + if 'safetensors' in args.pretrained: + from safetensors.torch import load_file as safe_load + checkpoint = safe_load(args.pretrained, device="cpu") + else: + checkpoint = torch.load(args.pretrained, map_location='cpu')['model'] + model_state_dict = model.state_dict() + missing_keys, unexpected_keys = model.load_state_dict(checkpoint, strict=False) + logger.info(f'missing_keys {len(missing_keys)} {missing_keys}, unexpected_keys {len(unexpected_keys)}') + logger.info(f'Successfully load {len(model.state_dict()) - len(missing_keys)}/{len(model_state_dict)} keys from {args.pretrained}!') + # load from pixart-alpha + # pixelart_alpha = torch.load(args.pretrained, map_location='cpu')['state_dict'] + # checkpoint = {} + # for k, v in pixelart_alpha.items(): + # if 'x_embedder' in k or 't_embedder' in k or 'y_embedder' in k: + # checkpoint[k] = v + # if k.startswith('blocks'): + # k_spilt = k.split('.') + # blk_id = str(int(k_spilt[1]) * 2) + # k_spilt[1] = blk_id + # new_k = '.'.join(k_spilt) + # checkpoint[new_k] = v + # missing_keys, unexpected_keys = model.load_state_dict(checkpoint, strict=False) + # logger.info(f'Successfully load {len(model.state_dict()) - len(missing_keys)} keys from {args.pretrained}!') + + # Freeze vae and text encoders. + ae.requires_grad_(False) + # text_enc.requires_grad_(False) + # Set model as trainable. + model.train() + + # For mixed precision training we cast all non-trainable weigths to half-precision + # as these weights are only used for inference, keeping weights in full precision is not required. + weight_dtype = torch.float32 + if accelerator.mixed_precision == "fp16": + weight_dtype = torch.float16 + elif accelerator.mixed_precision == "bf16": + weight_dtype = torch.bfloat16 + + # Move unet, vae and text_encoder to device and cast to weight_dtype + # The VAE is in float32 to avoid NaN losses. + # ae.to(accelerator.device, dtype=torch.float32) + ae.to(accelerator.device, dtype=weight_dtype) + model.to(accelerator.device, dtype=weight_dtype) + # text_enc.to(accelerator.device, dtype=weight_dtype) + + # Create EMA for the unet. + if args.use_ema: + ema_model = deepcopy(model) + ema_model = EMAModel(ema_model.parameters(), model_cls=LatteT2V, model_config=ema_model.config) + + # `accelerate` 0.16.0 will have better support for customized saving + if version.parse(accelerate.__version__) >= version.parse("0.16.0"): + # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format + def save_model_hook(models, weights, output_dir): + if accelerator.is_main_process: + if args.use_ema: + ema_model.save_pretrained(os.path.join(output_dir, "model_ema")) + + for i, model in enumerate(models): + model.save_pretrained(os.path.join(output_dir, "model")) + if weights: # Don't pop if empty + # make sure to pop weight so that corresponding model is not saved again + weights.pop() + + def load_model_hook(models, input_dir): + if args.use_ema: + load_model = EMAModel.from_pretrained(os.path.join(input_dir, "model_ema"), LatteT2V) + ema_model.load_state_dict(load_model.state_dict()) + ema_model.to(accelerator.device) + del load_model + + for i in range(len(models)): + # pop models so that they are not loaded again + model = models.pop() + + # load diffusers style into model + load_model = LatteT2V.from_pretrained(input_dir, subfolder="model") + model.register_to_config(**load_model.config) + + model.load_state_dict(load_model.state_dict()) + del load_model + + accelerator.register_save_state_pre_hook(save_model_hook) + accelerator.register_load_state_pre_hook(load_model_hook) + + # Enable TF32 for faster training on Ampere GPUs, + # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices + if args.allow_tf32: + torch.backends.cuda.matmul.allow_tf32 = True + + if args.scale_lr: + args.learning_rate = ( + args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes + ) + + # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs + if args.use_8bit_adam: + try: + import bitsandbytes as bnb + except ImportError: + raise ImportError( + "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`." + ) + + optimizer_class = bnb.optim.AdamW8bit + else: + optimizer_class = torch.optim.AdamW + + # Optimizer creation + params_to_optimize = model.parameters() + optimizer = optimizer_class( + params_to_optimize, + lr=args.learning_rate, + betas=(args.adam_beta1, args.adam_beta2), + weight_decay=args.adam_weight_decay, + eps=args.adam_epsilon, + ) + + # Setup data: + train_dataset = getdataset(args) + train_dataloader = torch.utils.data.DataLoader( + train_dataset, + shuffle=True, + # collate_fn=Collate(args), # TODO: do not enable dynamic mask in this point + batch_size=args.train_batch_size, + num_workers=args.dataloader_num_workers, + ) + + # Scheduler and math around the number of training steps. + overrode_max_train_steps = False + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if args.max_train_steps is None: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + overrode_max_train_steps = True + + lr_scheduler = get_scheduler( + args.lr_scheduler, + optimizer=optimizer, + num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps, + num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, + ) + + # Prepare everything with our `accelerator`. + model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + model, optimizer, train_dataloader, lr_scheduler + ) + + # We need to recalculate our total training steps as the size of the training dataloader may have changed. + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if overrode_max_train_steps: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + # Afterwards we recalculate our number of training epochs + args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + + # We need to initialize the trackers we use, and also store our configuration. + # The trackers initializes automatically on the main process. + if accelerator.is_main_process: + accelerator.init_trackers(args.output_dir, config=vars(args)) + + # Train! + total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps + + logger.info("***** Running training *****") + logger.info(f" Num examples = {len(train_dataset)}") + logger.info(f" Num Epochs = {args.num_train_epochs}") + logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") + logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") + logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") + logger.info(f" Total optimization steps = {args.max_train_steps}") + global_step = 0 + first_epoch = 0 + + # Potentially load in the weights and states from a previous save + if args.resume_from_checkpoint: + if args.resume_from_checkpoint != "latest": + path = os.path.basename(args.resume_from_checkpoint) + else: + # Get the most recent checkpoint + dirs = os.listdir(args.output_dir) + dirs = [d for d in dirs if d.startswith("checkpoint")] + dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) + path = dirs[-1] if len(dirs) > 0 else None + + if path is None: + accelerator.print( + f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run." + ) + args.resume_from_checkpoint = None + initial_global_step = 0 + else: + accelerator.print(f"Resuming from checkpoint {path}") + accelerator.load_state(os.path.join(args.output_dir, path)) + global_step = int(path.split("-")[1]) + + initial_global_step = global_step + first_epoch = global_step // num_update_steps_per_epoch + + else: + initial_global_step = 0 + + progress_bar = tqdm( + range(0, args.max_train_steps), + initial=initial_global_step, + desc="Steps", + # Only show the progress bar once on each machine. + disable=not accelerator.is_local_main_process, + ) + + for epoch in range(first_epoch, args.num_train_epochs): + train_loss = 0.0 + for step, (x, cond, cond_mask) in enumerate(train_dataloader): + with accelerator.accumulate(model): + # Sample noise that we'll add to the latents + x = x.to(accelerator.device) # B C T H W + # print(x.dtype) + # attn_mask = attn_mask.to(device) # B T H W + # assert torch.all(attn_mask.bool()), 'do not enable dynamic input' + attn_mask = None + cond = cond.to(accelerator.device, dtype=weight_dtype) # B L or B 1+num_images L + cond_mask = cond_mask.to(accelerator.device) # B L or B 1+num_images L + + with torch.no_grad(): + # Map input images to latent space + normalize latents + if args.use_image_num == 0: + x = ae.encode(x.to(dtype=weight_dtype)) # B C T H W + else: + videos, images = x[:, :, :-args.use_image_num], x[:, :, -args.use_image_num:] + videos = ae.encode(videos.to(dtype=weight_dtype)) # B C T H W + + # videos = ae.decode(videos.to(dtype=weight_dtype))[0] + # videos = videos.transpose(0, 1) + # custom_to_video(videos.to(torch.float32), fps=24, output_file='tmp.mp4') + + images = rearrange(images, 'b c t h w -> (b t) c 1 h w') + images = ae.encode(images.to(dtype=weight_dtype)) + + # images = ae.decode(images.to(dtype=weight_dtype)) + # x = images[0, 0, :, :, :].to(torch.float32) + # x = x.squeeze() + # x = x.detach().cpu().numpy() + # x = np.clip(x, -1, 1) + # x = (x + 1) / 2 + # x = (255 * x).astype(np.uint8) + # x = x.transpose(1, 2, 0) + # image = Image.fromarray(x) + # image.save('tmp.jpg') + # sys.exit() + + images = rearrange(images, '(b t) c 1 h w -> b c t h w', t=args.use_image_num) + x = torch.cat([videos, images], dim=2) + + # print(args.use_image_num, x.shape, cond.shape, cond_mask.shape, cond_mask) + model_kwargs = dict(encoder_hidden_states=cond, attention_mask=attn_mask, + encoder_attention_mask=cond_mask, use_image_num=args.use_image_num) + t = torch.randint(0, diffusion.num_timesteps, (x.shape[0],), device=accelerator.device) + loss_dict = diffusion.training_losses(model, x, t, model_kwargs) + loss = loss_dict["loss"].mean() + + # Gather the losses across all processes for logging (if we use distributed training). + avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean() + train_loss += avg_loss.item() / args.gradient_accumulation_steps + + # Backpropagate + accelerator.backward(loss) + if accelerator.sync_gradients: + params_to_clip = model.parameters() + accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad() + + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + progress_bar.update(1) + global_step += 1 + accelerator.log({"train_loss": train_loss}, step=global_step) + train_loss = 0.0 + + if args.use_deepspeed or accelerator.is_main_process: + if global_step % args.checkpointing_steps == 0: + # _before_ saving state, check if this save would set us over the `checkpoints_total_limit` + if args.checkpoints_total_limit is not None: + checkpoints = os.listdir(args.output_dir) + checkpoints = [d for d in checkpoints if d.startswith("checkpoint")] + checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1])) + + # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints + if len(checkpoints) >= args.checkpoints_total_limit: + num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1 + removing_checkpoints = checkpoints[0:num_to_remove] + + logger.info( + f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints" + ) + logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}") + + for removing_checkpoint in removing_checkpoints: + removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint) + shutil.rmtree(removing_checkpoint) + + save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") + accelerator.save_state(save_path) + logger.info(f"Saved state to {save_path}") + + logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} + progress_bar.set_postfix(**logs) + + if global_step >= args.max_train_steps: + break + + if accelerator.is_main_process: + validation_prompt = "The majestic beauty of a waterfall cascading down a cliff into a serene lake. The camera angle provides a bird's eye view of the waterfall." + if global_step % args.checkpointing_steps == 0: + logger.info(f"Running validation... \n" + f"Generating {args.num_validation_videos} videos with prompt: {validation_prompt}") + if args.use_ema: + # Store the UNet parameters temporarily and load the EMA parameters to perform inference. + ema_model.store(model.parameters()) + ema_model.copy_to(model.parameters()) + if args.enable_tracker: + with torch.no_grad(): + # create pipeline + ae_ = getae_wrapper(args.ae)(args.ae_path).to(accelerator.device).eval() + if args.enable_tiling: + ae_.vae.enable_tiling() + ae_.vae.tile_overlap_factor = args.tile_overlap_factor + text_enc_ = get_text_enc(args).to(accelerator.device).eval() + model_ = LatteT2V.from_pretrained(save_path, subfolder="model").to(accelerator.device).eval() + diffusion_ = create_diffusion(str(500)) + tokenizer_ = AutoTokenizer.from_pretrained(args.text_encoder_name, cache_dir='./cache_dir') + videos = [] + for idx in range(args.num_validation_videos): + with torch.autocast(device_type='cuda', dtype=weight_dtype): + z = torch.randn(1, model_.in_channels, video_length, + latent_size[0], latent_size[1], device=accelerator.device) + text_tokens_and_mask = tokenizer_( + validation_prompt, + max_length=args.model_max_length, + padding='max_length', + truncation=True, + return_attention_mask=True, + add_special_tokens=True, + return_tensors='pt' + ) + input_ids = text_tokens_and_mask['input_ids'].to(accelerator.device) + cond_mask = text_tokens_and_mask['attention_mask'].to(accelerator.device) + cond = text_enc_(input_ids, cond_mask) # B L D + # cond = text_enc(input_ids, cond_mask) # B L D + model_kwargs = dict(encoder_hidden_states=cond, attention_mask=None, encoder_attention_mask=cond_mask) + sample_fn = model_.forward + # Sample images: + samples = diffusion_.p_sample_loop( + sample_fn, z.shape, z, clip_denoised=False, model_kwargs=model_kwargs, progress=True, + device=accelerator.device + ) + samples = ae_.decode(samples) + # Save and display images: + video = (ae_denorm[args.ae](samples[0]) * 255).add_(0.5).clamp_(0, 255).to(dtype=torch.uint8).cpu().contiguous() # t c h w + videos.append(video) + + videos = torch.stack(videos).numpy() + for tracker in accelerator.trackers: + if tracker.name == "tensorboard": + np_videos = np.stack([np.asarray(vid) for vid in videos]) + tracker.writer.add_video("validation", np_videos, global_step, fps=24) + if tracker.name == "wandb": + tracker.log( + { + "validation": [ + wandb.Video(video, caption=f"{i}: {validation_prompt}", fps=24) + for i, video in enumerate(videos) + ] + } + ) + + del ae_, text_enc_, model_, diffusion_, tokenizer_ + # del ae_, model_, diffusion_, tokenizer_ + torch.cuda.empty_cache() + + accelerator.wait_for_everyone() + accelerator.end_training() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--dataset", type=str, required=True) + parser.add_argument("--data_path", type=str, required=True) + parser.add_argument("--model", type=str, choices=list(Diffusion_models.keys()), default="DiT-XL/122") + parser.add_argument("--num_classes", type=int, default=1000) + parser.add_argument("--ae", type=str, default="stabilityai/sd-vae-ft-mse") + parser.add_argument("--ae_path", type=str, default="stabilityai/sd-vae-ft-mse") + parser.add_argument("--sample_rate", type=int, default=4) + parser.add_argument("--num_frames", type=int, default=16) + parser.add_argument("--max_image_size", type=int, default=128) + parser.add_argument("--dynamic_frames", action="store_true") + parser.add_argument("--compress_kv", action="store_true") + parser.add_argument("--attention_mode", type=str, choices=['xformers', 'math', 'flash'], default="math") + parser.add_argument("--pretrained", type=str, default=None) + + parser.add_argument('--tile_overlap_factor', type=float, default=0.25) + parser.add_argument('--enable_tiling', action='store_true') + + parser.add_argument("--video_folder", type=str, default='') + parser.add_argument("--text_encoder_name", type=str, default='DeepFloyd/t5-v1_1-xxl') + parser.add_argument("--model_max_length", type=int, default=120) + + parser.add_argument("--use_image_num", type=int, default=0) + parser.add_argument("--use_img_from_vid", action="store_true") + parser.add_argument("--enable_tracker", action="store_true") + parser.add_argument("--use_deepspeed", action="store_true") + parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") + parser.add_argument( + "--num_validation_videos", + type=int, + default=2, + help="Number of images that should be generated during validation with `validation_prompt`.", + ) + parser.add_argument( + "--output_dir", + type=str, + default=None, + help="The output directory where the model predictions and checkpoints will be written.", + ) + parser.add_argument( + "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader." + ) + parser.add_argument("--num_train_epochs", type=int, default=100) + parser.add_argument( + "--max_train_steps", + type=int, + default=None, + help="Total number of training steps to perform. If provided, overrides num_train_epochs.", + ) + parser.add_argument( + "--checkpointing_steps", + type=int, + default=500, + help=( + "Save a checkpoint of the training state every X updates. These checkpoints can be used both as final" + " checkpoints in case they are better than the last checkpoint, and are also suitable for resuming" + " training using `--resume_from_checkpoint`." + ), + ) + parser.add_argument( + "--checkpoints_total_limit", + type=int, + default=None, + help=("Max number of checkpoints to store."), + ) + parser.add_argument( + "--resume_from_checkpoint", + type=str, + default=None, + help=( + "Whether training should be resumed from a previous checkpoint. Use a path saved by" + ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.' + ), + ) + parser.add_argument( + "--gradient_accumulation_steps", + type=int, + default=1, + help="Number of updates steps to accumulate before performing a backward/update pass.", + ) + parser.add_argument( + "--gradient_checkpointing", + action="store_true", + help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", + ) + parser.add_argument( + "--learning_rate", + type=float, + default=1e-4, + help="Initial learning rate (after the potential warmup period) to use.", + ) + parser.add_argument( + "--scale_lr", + action="store_true", + default=False, + help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", + ) + parser.add_argument( + "--lr_scheduler", + type=str, + default="constant", + help=( + 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' + ' "constant", "constant_with_warmup"]' + ), + ) + parser.add_argument( + "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler." + ) + parser.add_argument( + "--timestep_bias_strategy", + type=str, + default="none", + choices=["earlier", "later", "range", "none"], + help=( + "The timestep bias strategy, which may help direct the model toward learning low or high frequency details." + " Choices: ['earlier', 'later', 'range', 'none']." + " The default is 'none', which means no bias is applied, and training proceeds normally." + " The value of 'later' will increase the frequency of the model's final training timesteps." + ), + ) + parser.add_argument( + "--timestep_bias_multiplier", + type=float, + default=1.0, + help=( + "The multiplier for the bias. Defaults to 1.0, which means no bias is applied." + " A value of 2.0 will double the weight of the bias, and a value of 0.5 will halve it." + ), + ) + parser.add_argument( + "--timestep_bias_begin", + type=int, + default=0, + help=( + "When using `--timestep_bias_strategy=range`, the beginning (inclusive) timestep to bias." + " Defaults to zero, which equates to having no specific bias." + ), + ) + parser.add_argument( + "--timestep_bias_end", + type=int, + default=1000, + help=( + "When using `--timestep_bias_strategy=range`, the final timestep (inclusive) to bias." + " Defaults to 1000, which is the number of timesteps that Stable Diffusion is trained on." + ), + ) + parser.add_argument( + "--timestep_bias_portion", + type=float, + default=0.25, + help=( + "The portion of timesteps to bias. Defaults to 0.25, which 25% of timesteps will be biased." + " A value of 0.5 will bias one half of the timesteps. The value provided for `--timestep_bias_strategy` determines" + " whether the biased portions are in the earlier or later timesteps." + ), + ) + parser.add_argument( + "--snr_gamma", + type=float, + default=None, + help="SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. " + "More details here: https://arxiv.org/abs/2303.09556.", + ) + parser.add_argument("--use_ema", action="store_true", help="Whether to use EMA model.") + parser.add_argument( + "--allow_tf32", + action="store_true", + help=( + "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see" + " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices" + ), + ) + parser.add_argument( + "--dataloader_num_workers", + type=int, + default=10, + help=( + "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process." + ), + ) + parser.add_argument( + "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes." + ) + parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.") + parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.") + parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.") + parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer") + parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") + parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") + parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.") + parser.add_argument( + "--prediction_type", + type=str, + default=None, + help="The prediction_type that shall be used for training. Choose between 'epsilon' or 'v_prediction' or leave `None`. If left to `None` the default prediction type of the scheduler: `noise_scheduler.config.prediciton_type` is chosen.", + ) + parser.add_argument( + "--hub_model_id", + type=str, + default=None, + help="The name of the repository to keep in sync with the local `output_dir`.", + ) + parser.add_argument( + "--logging_dir", + type=str, + default="logs", + help=( + "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" + " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***." + ), + ) + parser.add_argument( + "--report_to", + type=str, + default="tensorboard", + help=( + 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`' + ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.' + ), + ) + parser.add_argument( + "--mixed_precision", + type=str, + default=None, + choices=["no", "fp16", "bf16"], + help=( + "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" + " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the" + " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config." + ), + ) + parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") + parser.add_argument("--noise_offset", type=float, default=0, help="The scale of noise offset.") + + args = parser.parse_args() + main(args) \ No newline at end of file diff --git a/opensora/train/train_videogpt.py b/opensora/train/train_videogpt.py new file mode 100644 index 0000000000000000000000000000000000000000..0cf3579131fa1ece702661a8c5b78cb49e104dc2 --- /dev/null +++ b/opensora/train/train_videogpt.py @@ -0,0 +1,53 @@ +import sys +sys.path.append(".") + +from opensora.models.ae.videobase.dataset_videobase import VideoDataset +from opensora.models.ae.videobase import ( + VQVAEModel, + VQVAEConfiguration, + VQVAETrainer, +) +import argparse +from typing import Optional +from accelerate.utils import set_seed +from transformers import HfArgumentParser, TrainingArguments +from dataclasses import dataclass, field, asdict + + +@dataclass +class VQVAEArgument: + embedding_dim: int = field(default=256), + n_codes: int = field(default=2048), + n_hiddens: int = field(default=240), + n_res_layers: int = field(default=4), + resolution: int = field(default=128), + sequence_length: int = field(default=16), + downsample: str = field(default="4,4,4"), + no_pos_embd: bool = True, + data_path: str = field(default=None, metadata={"help": "data path"}) + +@dataclass +class VQVAETrainingArgument(TrainingArguments): + remove_unused_columns: Optional[bool] = field( + default=False, metadata={"help": "Remove columns not required by the model when using an nlp.Dataset."} + ) + +def train(args, vqvae_args, training_args): + # Load Config + config = VQVAEConfiguration(**asdict(vqvae_args)) + # Load Model + model = VQVAEModel(config) + # Load Dataset + dataset = VideoDataset(args.data_path, sequence_length=args.sequence_length, resolution=config.resolution) + # Load Trainer + trainer = VQVAETrainer(model, training_args, train_dataset=dataset) + trainer.train() + + +if __name__ == "__main__": + parser = HfArgumentParser((VQVAEArgument, VQVAETrainingArgument)) + vqvae_args, training_args = parser.parse_args_into_dataclasses() + args = argparse.Namespace(**vars(vqvae_args), **vars(training_args)) + set_seed(args.seed) + + train(args, vqvae_args, training_args) diff --git a/opensora/utils/dataset_utils.py b/opensora/utils/dataset_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..8923adc7e5b9e3647277711f6721f69bf817cfba --- /dev/null +++ b/opensora/utils/dataset_utils.py @@ -0,0 +1,106 @@ +import math + +import decord +from torch.nn import functional as F +import torch + + +IMG_EXTENSIONS = ['.jpg', '.JPG', '.jpeg', '.JPEG', '.png', '.PNG'] + +def is_image_file(filename): + return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) + +class DecordInit(object): + """Using Decord(https://github.com/dmlc/decord) to initialize the video_reader.""" + + def __init__(self, num_threads=1): + self.num_threads = num_threads + self.ctx = decord.cpu(0) + + def __call__(self, filename): + """Perform the Decord initialization. + Args: + results (dict): The resulting dict to be modified and passed + to the next transform in pipeline. + """ + reader = decord.VideoReader(filename, + ctx=self.ctx, + num_threads=self.num_threads) + return reader + + def __repr__(self): + repr_str = (f'{self.__class__.__name__}(' + f'sr={self.sr},' + f'num_threads={self.num_threads})') + return repr_str + +def pad_to_multiple(number, ds_stride): + remainder = number % ds_stride + if remainder == 0: + return number + else: + padding = ds_stride - remainder + return number + padding + +class Collate: + def __init__(self, args): + self.max_image_size = args.max_image_size + self.ae_stride = args.ae_stride + self.ae_stride_t = args.ae_stride_t + self.patch_size = args.patch_size + self.patch_size_t = args.patch_size_t + self.num_frames = args.num_frames + + def __call__(self, batch): + unzip = tuple(zip(*batch)) + if len(unzip) == 2: + batch_tubes, labels = unzip + labels = torch.as_tensor(labels).to(torch.long) + elif len(unzip) == 3: + batch_tubes, input_ids, cond_mask = unzip + input_ids = torch.stack(input_ids).squeeze(1) + cond_mask = torch.stack(cond_mask).squeeze(1) + else: + raise NotImplementedError + ds_stride = self.ae_stride * self.patch_size + t_ds_stride = self.ae_stride_t * self.patch_size_t + + # pad to max multiple of ds_stride + batch_input_size = [i.shape for i in batch_tubes] + max_t, max_h, max_w = self.num_frames, \ + self.max_image_size, \ + self.max_image_size + pad_max_t, pad_max_h, pad_max_w = pad_to_multiple(max_t, t_ds_stride), \ + pad_to_multiple(max_h, ds_stride), \ + pad_to_multiple(max_w, ds_stride) + each_pad_t_h_w = [[pad_max_t - i.shape[1], + pad_max_h - i.shape[2], + pad_max_w - i.shape[3]] for i in batch_tubes] + pad_batch_tubes = [F.pad(im, + (0, pad_w, + 0, pad_h, + 0, pad_t), value=0) for (pad_t, pad_h, pad_w), im in zip(each_pad_t_h_w, batch_tubes)] + pad_batch_tubes = torch.stack(pad_batch_tubes, dim=0) + + # make attention_mask + max_tube_size = [pad_max_t, pad_max_h, pad_max_w] + max_latent_size = [max_tube_size[0] // self.ae_stride_t, + max_tube_size[1] // self.ae_stride, + max_tube_size[2] // self.ae_stride] + max_patchify_latent_size = [max_latent_size[0] // self.patch_size_t, + max_latent_size[1] // self.patch_size, + max_latent_size[2] // self.patch_size] + valid_patchify_latent_size = [[int(math.ceil(i[1] / t_ds_stride)), + int(math.ceil(i[2] / ds_stride)), + int(math.ceil(i[3] / ds_stride))] for i in batch_input_size] + attention_mask = [F.pad(torch.ones(i), + (0, max_patchify_latent_size[2] - i[2], + 0, max_patchify_latent_size[1] - i[1], + 0, max_patchify_latent_size[0] - i[0]), value=0) for i in valid_patchify_latent_size] + attention_mask = torch.stack(attention_mask) + + if len(unzip) == 2: + return pad_batch_tubes, labels, attention_mask + elif len(unzip) == 3: + return pad_batch_tubes, attention_mask, input_ids, cond_mask + diff --git a/opensora/utils/downloader.py b/opensora/utils/downloader.py new file mode 100644 index 0000000000000000000000000000000000000000..f2ac4017b10033f67b5affce0906c0defadb38cf --- /dev/null +++ b/opensora/utils/downloader.py @@ -0,0 +1,18 @@ +import gdown +import os + +opensora_cache_home = os.path.expanduser( + os.getenv("OPENSORA_HOME", os.path.join("~/.cache", "opensora")) +) + + +def gdown_download(id, fname, cache_dir=None): + cache_dir = opensora_cache_home if not cache_dir else cache_dir + + os.makedirs(cache_dir, exist_ok=True) + destination = os.path.join(cache_dir, fname) + if os.path.exists(destination): + return destination + + gdown.download(id=id, output=destination, quiet=False) + return destination diff --git a/opensora/utils/taming_download.py b/opensora/utils/taming_download.py new file mode 100644 index 0000000000000000000000000000000000000000..5a62be7781ae7eb166d5a80383f30a83af093895 --- /dev/null +++ b/opensora/utils/taming_download.py @@ -0,0 +1,145 @@ +"""Modified from https://github.com/CompVis/taming-transformers.git""" + +import os, hashlib +import requests +from tqdm import tqdm + +URL_MAP = { + "vgg_lpips": "https://heibox.uni-heidelberg.de/f/607503859c864bc1b30b/?dl=1" +} + +CKPT_MAP = { + "vgg_lpips": "vgg.pth" +} + +MD5_MAP = { + "vgg_lpips": "d507d7349b931f0638a25a48a722f98a" +} + + +def download(url, local_path, chunk_size=1024): + os.makedirs(os.path.split(local_path)[0], exist_ok=True) + with requests.get(url, stream=True) as r: + total_size = int(r.headers.get("content-length", 0)) + with tqdm(total=total_size, unit="B", unit_scale=True) as pbar: + with open(local_path, "wb") as f: + for data in r.iter_content(chunk_size=chunk_size): + if data: + f.write(data) + pbar.update(chunk_size) + + +def md5_hash(path): + with open(path, "rb") as f: + content = f.read() + return hashlib.md5(content).hexdigest() + + +def get_ckpt_path(name, root, check=False): + assert name in URL_MAP + path = os.path.join(root, CKPT_MAP[name]) + if not os.path.exists(path) or (check and not md5_hash(path) == MD5_MAP[name]): + print("Downloading {} model from {} to {}".format(name, URL_MAP[name], path)) + download(URL_MAP[name], path) + md5 = md5_hash(path) + assert md5 == MD5_MAP[name], md5 + return path + + +class KeyNotFoundError(Exception): + def __init__(self, cause, keys=None, visited=None): + self.cause = cause + self.keys = keys + self.visited = visited + messages = list() + if keys is not None: + messages.append("Key not found: {}".format(keys)) + if visited is not None: + messages.append("Visited: {}".format(visited)) + messages.append("Cause:\n{}".format(cause)) + message = "\n".join(messages) + super().__init__(message) + + +def retrieve( + list_or_dict, key, splitval="/", default=None, expand=True, pass_success=False +): + """Given a nested list or dict return the desired value at key expanding + callable nodes if necessary and :attr:`expand` is ``True``. The expansion + is done in-place. + + Parameters + ---------- + list_or_dict : list or dict + Possibly nested list or dictionary. + key : str + key/to/value, path like string describing all keys necessary to + consider to get to the desired value. List indices can also be + passed here. + splitval : str + String that defines the delimiter between keys of the + different depth levels in `key`. + default : obj + Value returned if :attr:`key` is not found. + expand : bool + Whether to expand callable nodes on the path or not. + + Returns + ------- + The desired value or if :attr:`default` is not ``None`` and the + :attr:`key` is not found returns ``default``. + + Raises + ------ + Exception if ``key`` not in ``list_or_dict`` and :attr:`default` is + ``None``. + """ + + keys = key.split(splitval) + + success = True + try: + visited = [] + parent = None + last_key = None + for key in keys: + if callable(list_or_dict): + if not expand: + raise KeyNotFoundError( + ValueError( + "Trying to get past callable node with expand=False." + ), + keys=keys, + visited=visited, + ) + list_or_dict = list_or_dict() + parent[last_key] = list_or_dict + + last_key = key + parent = list_or_dict + + try: + if isinstance(list_or_dict, dict): + list_or_dict = list_or_dict[key] + else: + list_or_dict = list_or_dict[int(key)] + except (KeyError, IndexError, ValueError) as e: + raise KeyNotFoundError(e, keys=keys, visited=visited) + + visited += [key] + # final expansion of retrieved value + if expand and callable(list_or_dict): + list_or_dict = list_or_dict() + parent[last_key] = list_or_dict + except KeyNotFoundError as e: + if default is None: + raise e + else: + list_or_dict = default + success = False + + if not pass_success: + return list_or_dict + else: + return list_or_dict, success + diff --git a/opensora/utils/utils.py b/opensora/utils/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..ba94b6aa12d2560c4a074645a864000bda0f94e5 --- /dev/null +++ b/opensora/utils/utils.py @@ -0,0 +1,459 @@ +import os + +import torch + +import os +import math +import torch +import logging +import random +import subprocess +import numpy as np +import torch.distributed as dist + +# from torch._six import inf +from torch import inf +from PIL import Image +from typing import Union, Iterable +from collections import OrderedDict +from torch.utils.tensorboard import SummaryWriter + +from diffusers.utils import is_bs4_available, is_ftfy_available + +import html +import re +import urllib.parse as ul + +if is_bs4_available(): + from bs4 import BeautifulSoup + +if is_ftfy_available(): + import ftfy + +_tensor_or_tensors = Union[torch.Tensor, Iterable[torch.Tensor]] + +def find_model(model_name): + """ + Finds a pre-trained Latte model, downloading it if necessary. Alternatively, loads a model from a local path. + """ + assert os.path.isfile(model_name), f'Could not find Latte checkpoint at {model_name}' + checkpoint = torch.load(model_name, map_location=lambda storage, loc: storage) + + # if "ema" in checkpoint: # supports checkpoints from train.py + # print('Using Ema!') + # checkpoint = checkpoint["ema"] + # else: + print('Using model!') + checkpoint = checkpoint['model'] + return checkpoint + +################################################################################# +# Training Clip Gradients # +################################################################################# + +def get_grad_norm( + parameters: _tensor_or_tensors, norm_type: float = 2.0) -> torch.Tensor: + r""" + Copy from torch.nn.utils.clip_grad_norm_ + + Clips gradient norm of an iterable of parameters. + + The norm is computed over all gradients together, as if they were + concatenated into a single vector. Gradients are modified in-place. + + Args: + parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a + single Tensor that will have gradients normalized + max_norm (float or int): max norm of the gradients + norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for + infinity norm. + error_if_nonfinite (bool): if True, an error is thrown if the total + norm of the gradients from :attr:`parameters` is ``nan``, + ``inf``, or ``-inf``. Default: False (will switch to True in the future) + + Returns: + Total norm of the parameter gradients (viewed as a single vector). + """ + if isinstance(parameters, torch.Tensor): + parameters = [parameters] + grads = [p.grad for p in parameters if p.grad is not None] + norm_type = float(norm_type) + if len(grads) == 0: + return torch.tensor(0.) + device = grads[0].device + if norm_type == inf: + norms = [g.detach().abs().max().to(device) for g in grads] + total_norm = norms[0] if len(norms) == 1 else torch.max(torch.stack(norms)) + else: + total_norm = torch.norm(torch.stack([torch.norm(g.detach(), norm_type).to(device) for g in grads]), norm_type) + return total_norm + + +def clip_grad_norm_( + parameters: _tensor_or_tensors, max_norm: float, norm_type: float = 2.0, + error_if_nonfinite: bool = False, clip_grad=True) -> torch.Tensor: + r""" + Copy from torch.nn.utils.clip_grad_norm_ + + Clips gradient norm of an iterable of parameters. + + The norm is computed over all gradients together, as if they were + concatenated into a single vector. Gradients are modified in-place. + + Args: + parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a + single Tensor that will have gradients normalized + max_norm (float or int): max norm of the gradients + norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for + infinity norm. + error_if_nonfinite (bool): if True, an error is thrown if the total + norm of the gradients from :attr:`parameters` is ``nan``, + ``inf``, or ``-inf``. Default: False (will switch to True in the future) + + Returns: + Total norm of the parameter gradients (viewed as a single vector). + """ + if isinstance(parameters, torch.Tensor): + parameters = [parameters] + grads = [p.grad for p in parameters if p.grad is not None] + max_norm = float(max_norm) + norm_type = float(norm_type) + if len(grads) == 0: + return torch.tensor(0.) + device = grads[0].device + if norm_type == inf: + norms = [g.detach().abs().max().to(device) for g in grads] + total_norm = norms[0] if len(norms) == 1 else torch.max(torch.stack(norms)) + else: + total_norm = torch.norm(torch.stack([torch.norm(g.detach(), norm_type).to(device) for g in grads]), norm_type) + + if clip_grad: + if error_if_nonfinite and torch.logical_or(total_norm.isnan(), total_norm.isinf()): + raise RuntimeError( + f'The total norm of order {norm_type} for gradients from ' + '`parameters` is non-finite, so it cannot be clipped. To disable ' + 'this error and scale the gradients by the non-finite norm anyway, ' + 'set `error_if_nonfinite=False`') + clip_coef = max_norm / (total_norm + 1e-6) + # Note: multiplying by the clamped coef is redundant when the coef is clamped to 1, but doing so + # avoids a `if clip_coef < 1:` conditional which can require a CPU <=> device synchronization + # when the gradients do not reside in CPU memory. + clip_coef_clamped = torch.clamp(clip_coef, max=1.0) + for g in grads: + g.detach().mul_(clip_coef_clamped.to(g.device)) + # gradient_cliped = torch.norm(torch.stack([torch.norm(g.detach(), norm_type).to(device) for g in grads]), norm_type) + # print(gradient_cliped) + return total_norm + + +def get_experiment_dir(root_dir, args): + # if args.pretrained is not None and 'Latte-XL-2-256x256.pt' not in args.pretrained: + # root_dir += '-WOPRE' + if args.use_compile: + root_dir += '-Compile' # speedup by torch compile + if args.attention_mode: + root_dir += f'-{args.attention_mode.upper()}' + # if args.enable_xformers_memory_efficient_attention: + # root_dir += '-Xfor' + if args.gradient_checkpointing: + root_dir += '-Gc' + if args.mixed_precision: + root_dir += f'-{args.mixed_precision.upper()}' + root_dir += f'-{args.max_image_size}' + return root_dir + +def get_precision(args): + if args.mixed_precision == "bf16": + dtype = torch.bfloat16 + elif args.mixed_precision == "fp16": + dtype = torch.float16 + else: + dtype = torch.float32 + return dtype + +################################################################################# +# Training Logger # +################################################################################# + +def create_logger(logging_dir): + """ + Create a logger that writes to a log file and stdout. + """ + if dist.get_rank() == 0: # real logger + logging.basicConfig( + level=logging.INFO, + # format='[\033[34m%(asctime)s\033[0m] %(message)s', + format='[%(asctime)s] %(message)s', + datefmt='%Y-%m-%d %H:%M:%S', + handlers=[logging.StreamHandler(), logging.FileHandler(f"{logging_dir}/log.txt")] + ) + logger = logging.getLogger(__name__) + + else: # dummy logger (does nothing) + logger = logging.getLogger(__name__) + logger.addHandler(logging.NullHandler()) + return logger + + +def create_tensorboard(tensorboard_dir): + """ + Create a tensorboard that saves losses. + """ + if dist.get_rank() == 0: # real tensorboard + # tensorboard + writer = SummaryWriter(tensorboard_dir) + + return writer + + +def write_tensorboard(writer, *args): + ''' + write the loss information to a tensorboard file. + Only for pytorch DDP mode. + ''' + if dist.get_rank() == 0: # real tensorboard + writer.add_scalar(args[0], args[1], args[2]) + + +################################################################################# +# EMA Update/ DDP Training Utils # +################################################################################# + +@torch.no_grad() +def update_ema(ema_model, model, decay=0.9999): + """ + Step the EMA model towards the current model. + """ + ema_params = OrderedDict(ema_model.named_parameters()) + model_params = OrderedDict(model.named_parameters()) + + for name, param in model_params.items(): + # TODO: Consider applying only to params that require_grad to avoid small numerical changes of pos_embed + ema_params[name].mul_(decay).add_(param.data, alpha=1 - decay) + + +def requires_grad(model, flag=True): + """ + Set requires_grad flag for all parameters in a model. + """ + for p in model.parameters(): + p.requires_grad = flag + + +def cleanup(): + """ + End DDP training. + """ + dist.destroy_process_group() + + +def setup_distributed(backend="nccl", port=None): + """Initialize distributed training environment. + support both slurm and torch.distributed.launch + see torch.distributed.init_process_group() for more details + """ + num_gpus = torch.cuda.device_count() + + if "SLURM_JOB_ID" in os.environ: + rank = int(os.environ["SLURM_PROCID"]) + world_size = int(os.environ["SLURM_NTASKS"]) + node_list = os.environ["SLURM_NODELIST"] + addr = subprocess.getoutput(f"scontrol show hostname {node_list} | head -n1") + # specify master port + if port is not None: + os.environ["MASTER_PORT"] = str(port) + elif "MASTER_PORT" not in os.environ: + # os.environ["MASTER_PORT"] = "29566" + os.environ["MASTER_PORT"] = str(29567 + num_gpus) + if "MASTER_ADDR" not in os.environ: + os.environ["MASTER_ADDR"] = addr + os.environ["WORLD_SIZE"] = str(world_size) + os.environ["LOCAL_RANK"] = str(rank % num_gpus) + os.environ["RANK"] = str(rank) + else: + rank = int(os.environ["RANK"]) + world_size = int(os.environ["WORLD_SIZE"]) + + # torch.cuda.set_device(rank % num_gpus) + + dist.init_process_group( + backend=backend, + world_size=world_size, + rank=rank, + ) + + +################################################################################# +# Testing Utils # +################################################################################# + +def save_video_grid(video, nrow=None): + b, t, h, w, c = video.shape + + if nrow is None: + nrow = math.ceil(math.sqrt(b)) + ncol = math.ceil(b / nrow) + padding = 1 + video_grid = torch.zeros((t, (padding + h) * nrow + padding, + (padding + w) * ncol + padding, c), dtype=torch.uint8) + + print(video_grid.shape) + for i in range(b): + r = i // ncol + c = i % ncol + start_r = (padding + h) * r + start_c = (padding + w) * c + video_grid[:, start_r:start_r + h, start_c:start_c + w] = video[i] + + return video_grid + + +################################################################################# +# MMCV Utils # +################################################################################# + + +def collect_env(): + # Copyright (c) OpenMMLab. All rights reserved. + from mmcv.utils import collect_env as collect_base_env + from mmcv.utils import get_git_hash + """Collect the information of the running environments.""" + + env_info = collect_base_env() + env_info['MMClassification'] = get_git_hash()[:7] + + for name, val in env_info.items(): + print(f'{name}: {val}') + + print(torch.cuda.get_arch_list()) + print(torch.version.cuda) + + +################################################################################# +# Pixart-alpha Utils # +################################################################################# + + +bad_punct_regex = re.compile(r'['+'#®•©™&@·º½¾¿¡§~'+'\)'+'\('+'\]'+'\['+'\}'+'\{'+'\|'+'\\'+'\/'+'\*' + r']{1,}') # noqa + +def text_preprocessing(text): + # The exact text cleaning as was in the training stage: + text = clean_caption(text) + text = clean_caption(text) + return text + +def basic_clean(text): + text = ftfy.fix_text(text) + text = html.unescape(html.unescape(text)) + return text.strip() + +def clean_caption(caption): + caption = str(caption) + caption = ul.unquote_plus(caption) + caption = caption.strip().lower() + caption = re.sub('', 'person', caption) + # urls: + caption = re.sub( + r'\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))', # noqa + '', caption) # regex for urls + caption = re.sub( + r'\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))', # noqa + '', caption) # regex for urls + # html: + caption = BeautifulSoup(caption, features='html.parser').text + + # @ + caption = re.sub(r'@[\w\d]+\b', '', caption) + + # 31C0—31EF CJK Strokes + # 31F0—31FF Katakana Phonetic Extensions + # 3200—32FF Enclosed CJK Letters and Months + # 3300—33FF CJK Compatibility + # 3400—4DBF CJK Unified Ideographs Extension A + # 4DC0—4DFF Yijing Hexagram Symbols + # 4E00—9FFF CJK Unified Ideographs + caption = re.sub(r'[\u31c0-\u31ef]+', '', caption) + caption = re.sub(r'[\u31f0-\u31ff]+', '', caption) + caption = re.sub(r'[\u3200-\u32ff]+', '', caption) + caption = re.sub(r'[\u3300-\u33ff]+', '', caption) + caption = re.sub(r'[\u3400-\u4dbf]+', '', caption) + caption = re.sub(r'[\u4dc0-\u4dff]+', '', caption) + caption = re.sub(r'[\u4e00-\u9fff]+', '', caption) + ####################################################### + + # все виды тире / all types of dash --> "-" + caption = re.sub( + r'[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+', # noqa + '-', caption) + + # кавычки к одному стандарту + caption = re.sub(r'[`´«»“”¨]', '"', caption) + caption = re.sub(r'[‘’]', "'", caption) + + # " + caption = re.sub(r'"?', '', caption) + # & + caption = re.sub(r'&', '', caption) + + # ip adresses: + caption = re.sub(r'\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}', ' ', caption) + + # article ids: + caption = re.sub(r'\d:\d\d\s+$', '', caption) + + # \n + caption = re.sub(r'\\n', ' ', caption) + + # "#123" + caption = re.sub(r'#\d{1,3}\b', '', caption) + # "#12345.." + caption = re.sub(r'#\d{5,}\b', '', caption) + # "123456.." + caption = re.sub(r'\b\d{6,}\b', '', caption) + # filenames: + caption = re.sub(r'[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)', '', caption) + + # + caption = re.sub(r'[\"\']{2,}', r'"', caption) # """AUSVERKAUFT""" + caption = re.sub(r'[\.]{2,}', r' ', caption) # """AUSVERKAUFT""" + + caption = re.sub(bad_punct_regex, r' ', caption) # ***AUSVERKAUFT***, #AUSVERKAUFT + caption = re.sub(r'\s+\.\s+', r' ', caption) # " . " + + # this-is-my-cute-cat / this_is_my_cute_cat + regex2 = re.compile(r'(?:\-|\_)') + if len(re.findall(regex2, caption)) > 3: + caption = re.sub(regex2, ' ', caption) + + caption = basic_clean(caption) + + caption = re.sub(r'\b[a-zA-Z]{1,3}\d{3,15}\b', '', caption) # jc6640 + caption = re.sub(r'\b[a-zA-Z]+\d+[a-zA-Z]+\b', '', caption) # jc6640vc + caption = re.sub(r'\b\d+[a-zA-Z]+\d+\b', '', caption) # 6640vc231 + + caption = re.sub(r'(worldwide\s+)?(free\s+)?shipping', '', caption) + caption = re.sub(r'(free\s)?download(\sfree)?', '', caption) + caption = re.sub(r'\bclick\b\s(?:for|on)\s\w+', '', caption) + caption = re.sub(r'\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?', '', caption) + caption = re.sub(r'\bpage\s+\d+\b', '', caption) + + caption = re.sub(r'\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b', r' ', caption) # j2d1a2a... + + caption = re.sub(r'\b\d+\.?\d*[xх×]\d+\.?\d*\b', '', caption) + + caption = re.sub(r'\b\s+\:\s+', r': ', caption) + caption = re.sub(r'(\D[,\./])\b', r'\1 ', caption) + caption = re.sub(r'\s+', ' ', caption) + + caption.strip() + + caption = re.sub(r'^[\"\']([\w\W]+)[\"\']$', r'\1', caption) + caption = re.sub(r'^[\'\_,\-\:;]', r'', caption) + caption = re.sub(r'[\'\_,\-\:\-\+]$', r'', caption) + caption = re.sub(r'^\.\S+$', '', caption) + + return caption.strip() + + + + diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000000000000000000000000000000000000..6c42594ff7c6b0aadc2e53fc7219036cabe15e69 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,49 @@ +[build-system] +requires = ["setuptools>=61.0"] +build-backend = "setuptools.build_meta" + +[project] +name = "opensora" +version = "1.0.0" +description = "Reproduce OpenAI's Sora." +readme = "README.md" +requires-python = ">=3.8" +classifiers = [ + "Programming Language :: Python :: 3", + "License :: OSI Approved :: Apache Software License", +] +dependencies = [ + "torch==2.1.2", "torchvision==0.16.2", + "transformers==4.39.1", "accelerate==0.28.0", + "albumentations==1.4.0", "av==11.0.0", "decord==0.6.0", "einops==0.7.0", "fastapi==0.110.0", + "gdown==5.1.0", "h5py==3.10.0", "idna==3.6", 'imageio==2.34.0', "matplotlib==3.7.5", "numpy==1.24.4", + "omegaconf==2.1.1", "opencv-python==4.9.0.80", "opencv-python-headless==4.9.0.80", "pandas==2.0.3", "pillow==10.2.0", + "pydub==0.25.1", "pytorch-lightning==1.4.2", "pytorchvideo==0.1.5", "PyYAML==6.0.1", "regex==2023.12.25", + "requests==2.31.0", "scikit-learn==1.3.2", "scipy==1.10.1", "six==1.16.0", "test-tube==0.7.5", + "timm==0.9.16", "torchdiffeq==0.2.3", "torchmetrics==0.5.0", "tqdm==4.66.2", "urllib3==2.2.1", "uvicorn==0.27.1", + "diffusers==0.24.0", "scikit-video==1.1.11", "imageio-ffmpeg==0.4.9", "sentencepiece==0.1.99", "beautifulsoup4==4.12.3", + "ftfy==6.1.3", "moviepy==1.0.3", "wandb==0.16.3", "tensorboard==2.14.0" +] + +[project.optional-dependencies] +train = ["deepspeed==0.9.5", "pydantic==1.10.13"] +dev = ["mypy==1.8.0"] + + +[project.urls] +"Homepage" = "https://github.com/PKU-YuanGroup/Open-Sora-Plan" +"Bug Tracker" = "https://github.com/PKU-YuanGroup/Open-Sora-Plan/issues" + +[tool.setuptools.packages.find] +exclude = ["assets*", "docker*", "docs", "scripts*"] + +[tool.wheel] +exclude = ["assets*", "docker*", "docs", "scripts*"] + +[tool.mypy] +warn_return_any = true +warn_unused_configs = true +ignore_missing_imports = true +disallow_untyped_calls = true +check_untyped_defs = true +no_implicit_optional = true diff --git a/scripts/accelerate_configs/ddp_config.yaml b/scripts/accelerate_configs/ddp_config.yaml new file mode 100644 index 0000000000000000000000000000000000000000..fa6cfbdd6b16df740418da3d21de91ea4252282f --- /dev/null +++ b/scripts/accelerate_configs/ddp_config.yaml @@ -0,0 +1,11 @@ +compute_environment: LOCAL_MACHINE +distributed_type: MULTI_GPU +fsdp_config: {} +machine_rank: 0 +main_process_ip: null +main_process_port: 29501 +main_training_function: main +num_machines: 1 +num_processes: 8 +gpu_ids: 0,1,2,3,4,5,6,7 +use_cpu: false \ No newline at end of file diff --git a/scripts/accelerate_configs/deepspeed_zero2_config.yaml b/scripts/accelerate_configs/deepspeed_zero2_config.yaml new file mode 100644 index 0000000000000000000000000000000000000000..43ec1e80cc1ad01a012da0354e50e089ac2b865d --- /dev/null +++ b/scripts/accelerate_configs/deepspeed_zero2_config.yaml @@ -0,0 +1,13 @@ +compute_environment: LOCAL_MACHINE +distributed_type: DEEPSPEED +deepspeed_config: + deepspeed_config_file: scripts/accelerate_configs/zero2.json +fsdp_config: {} +machine_rank: 0 +main_process_ip: null +main_process_port: 29501 +main_training_function: main +num_machines: 1 +num_processes: 8 +gpu_ids: 0,1,2,3,4,5,6,7 +use_cpu: false \ No newline at end of file diff --git a/scripts/accelerate_configs/deepspeed_zero2_offload_config.yaml b/scripts/accelerate_configs/deepspeed_zero2_offload_config.yaml new file mode 100644 index 0000000000000000000000000000000000000000..9c96bae6751d461100bfd6ba0f2c22f9450d726b --- /dev/null +++ b/scripts/accelerate_configs/deepspeed_zero2_offload_config.yaml @@ -0,0 +1,13 @@ +compute_environment: LOCAL_MACHINE +distributed_type: DEEPSPEED +deepspeed_config: + deepspeed_config_file: scripts/accelerate_configs/zero2_offload.json +fsdp_config: {} +machine_rank: 0 +main_process_ip: null +main_process_port: 29501 +main_training_function: main +num_machines: 1 +num_processes: 8 +gpu_ids: 0,1,2,3,4,5,6,7 +use_cpu: false \ No newline at end of file diff --git a/scripts/accelerate_configs/deepspeed_zero3_config.yaml b/scripts/accelerate_configs/deepspeed_zero3_config.yaml new file mode 100644 index 0000000000000000000000000000000000000000..4a4a0935e048766b150062a75090355e0aa0d302 --- /dev/null +++ b/scripts/accelerate_configs/deepspeed_zero3_config.yaml @@ -0,0 +1,13 @@ +compute_environment: LOCAL_MACHINE +distributed_type: DEEPSPEED +deepspeed_config: + deepspeed_config_file: scripts/accelerate_configs/zero3.json +fsdp_config: {} +machine_rank: 0 +main_process_ip: null +main_process_port: 29501 +main_training_function: main +num_machines: 1 +num_processes: 8 +gpu_ids: 0,1,2,3,4,5,6,7 +use_cpu: false \ No newline at end of file diff --git a/scripts/accelerate_configs/deepspeed_zero3_offload_config.yaml b/scripts/accelerate_configs/deepspeed_zero3_offload_config.yaml new file mode 100644 index 0000000000000000000000000000000000000000..b01631caf6a2bae2bc30659c89931917e769a083 --- /dev/null +++ b/scripts/accelerate_configs/deepspeed_zero3_offload_config.yaml @@ -0,0 +1,13 @@ +compute_environment: LOCAL_MACHINE +distributed_type: DEEPSPEED +deepspeed_config: + deepspeed_config_file: scripts/accelerate_configs/zero3_offload.json +fsdp_config: {} +machine_rank: 0 +main_process_ip: null +main_process_port: 29501 +main_training_function: main +num_machines: 1 +num_processes: 8 +gpu_ids: 0,1,2,3,4,5,6,7 +use_cpu: false \ No newline at end of file diff --git a/scripts/accelerate_configs/default_config.yaml b/scripts/accelerate_configs/default_config.yaml new file mode 100644 index 0000000000000000000000000000000000000000..ac3bb99b568c237d6d2da8db473ef20f94520996 --- /dev/null +++ b/scripts/accelerate_configs/default_config.yaml @@ -0,0 +1,12 @@ +compute_environment: LOCAL_MACHINE +distributed_type: MULTI_GPU +fsdp_config: {} +machine_rank: 0 +main_process_ip: null +main_process_port: 29501 +main_training_function: main +mixed_precision: bf16 +num_machines: 1 +num_processes: 8 +gpu_ids: 0,1,2,3,4,5,6,7 +use_cpu: false \ No newline at end of file diff --git a/scripts/accelerate_configs/hostfile b/scripts/accelerate_configs/hostfile new file mode 100644 index 0000000000000000000000000000000000000000..a226930999c83d9c35595ec64d0ee66f7c774349 --- /dev/null +++ b/scripts/accelerate_configs/hostfile @@ -0,0 +1,2 @@ +gpu55 slots=8 # your server name and GPU in total +gpu117 slots=8 diff --git a/scripts/accelerate_configs/multi_node_example.yaml b/scripts/accelerate_configs/multi_node_example.yaml new file mode 100644 index 0000000000000000000000000000000000000000..4e3cac79c6fc17d84e1c58dd1f2d321ad714ddd4 --- /dev/null +++ b/scripts/accelerate_configs/multi_node_example.yaml @@ -0,0 +1,18 @@ +compute_environment: LOCAL_MACHINE +distributed_type: DEEPSPEED +deepspeed_config: + deepspeed_config_file: scripts/accelerate_configs/zero2.json + deepspeed_hostfile: /remote-home1/yeyang/Open-Sora-Plan/scripts/accelerate_configs/hostfile +fsdp_config: {} +machine_rank: 0 +main_process_ip: 10.10.10.55 +main_process_port: 29501 +main_training_function: main +num_machines: 2 +num_processes: 16 +rdzv_backend: static +same_network: true +tpu_env: [] +tpu_use_cluster: false +tpu_use_sudo: false +use_cpu: false diff --git a/scripts/accelerate_configs/zero2.json b/scripts/accelerate_configs/zero2.json new file mode 100644 index 0000000000000000000000000000000000000000..bc26a0a5fb44b2bac1659ba7b38d32099738232d --- /dev/null +++ b/scripts/accelerate_configs/zero2.json @@ -0,0 +1,23 @@ +{ + "fp16": { + "enabled": "auto", + "loss_scale": 0, + "loss_scale_window": 1000, + "initial_scale_power": 16, + "hysteresis": 2, + "min_loss_scale": 1 + }, + "bf16": { + "enabled": "auto" + }, + "train_micro_batch_size_per_gpu": "auto", + "train_batch_size": "auto", + "gradient_accumulation_steps": "auto", + "zero_optimization": { + "stage": 2, + "overlap_comm": true, + "contiguous_gradients": true, + "sub_group_size": 1e9, + "reduce_bucket_size": 5e8 + } +} \ No newline at end of file diff --git a/scripts/accelerate_configs/zero2_offload.json b/scripts/accelerate_configs/zero2_offload.json new file mode 100644 index 0000000000000000000000000000000000000000..c315ba5fb10a3247104a606c6331a9643813bb99 --- /dev/null +++ b/scripts/accelerate_configs/zero2_offload.json @@ -0,0 +1,26 @@ +{ + "fp16": { + "enabled": "auto", + "loss_scale": 0, + "loss_scale_window": 1000, + "initial_scale_power": 16, + "hysteresis": 2, + "min_loss_scale": 1 + }, + "bf16": { + "enabled": "auto" + }, + "train_micro_batch_size_per_gpu": "auto", + "train_batch_size": "auto", + "gradient_accumulation_steps": "auto", + "zero_optimization": { + "stage": 2, + "offload_optimizer": { + "device": "cpu" + }, + "overlap_comm": true, + "contiguous_gradients": true, + "sub_group_size": 1e9, + "reduce_bucket_size": 5e8 + } +} \ No newline at end of file diff --git a/scripts/accelerate_configs/zero3.json b/scripts/accelerate_configs/zero3.json new file mode 100644 index 0000000000000000000000000000000000000000..89dc78e0b020634183cbe51a584f55b0e828efae --- /dev/null +++ b/scripts/accelerate_configs/zero3.json @@ -0,0 +1,28 @@ +{ + "fp16": { + "enabled": "auto", + "loss_scale": 0, + "loss_scale_window": 1000, + "initial_scale_power": 16, + "hysteresis": 2, + "min_loss_scale": 1 + }, + "bf16": { + "enabled": "auto" + }, + "train_micro_batch_size_per_gpu": "auto", + "train_batch_size": "auto", + "gradient_accumulation_steps": "auto", + "zero_optimization": { + "stage": 3, + "overlap_comm": true, + "contiguous_gradients": true, + "sub_group_size": 1e9, + "reduce_bucket_size": 5e8, + "stage3_prefetch_bucket_size": "auto", + "stage3_param_persistence_threshold": "auto", + "stage3_max_live_parameters": 1e9, + "stage3_max_reuse_distance": 1e9, + "stage3_gather_16bit_weights_on_model_save": true + } +} \ No newline at end of file diff --git a/scripts/accelerate_configs/zero3_offload.json b/scripts/accelerate_configs/zero3_offload.json new file mode 100644 index 0000000000000000000000000000000000000000..c15bcf91b1ea0ffbade5e07b1d4d8b7455729e9e --- /dev/null +++ b/scripts/accelerate_configs/zero3_offload.json @@ -0,0 +1,39 @@ +{ + "fp16": { + "enabled": "auto", + "loss_scale": 0, + "loss_scale_window": 1000, + "initial_scale_power": 16, + "hysteresis": 2, + "min_loss_scale": 1 + }, + "bf16": { + "enabled": "auto" + }, + "zero_optimization": { + "stage": 3, + "offload_optimizer": { + "device": "cpu", + "pin_memory": true + }, + "offload_param": { + "device": "cpu", + "pin_memory": true + }, + "overlap_comm": true, + "contiguous_gradients": true, + "sub_group_size": 1e9, + "reduce_bucket_size": 5e8, + "stage3_prefetch_bucket_size": "auto", + "stage3_param_persistence_threshold": "auto", + "stage3_max_live_parameters": 1e9, + "stage3_max_reuse_distance": 1e9, + "gather_16bit_weights_on_model_save": true + }, + "gradient_accumulation_steps": "auto", + "gradient_clipping": "auto", + "train_batch_size": "auto", + "train_micro_batch_size_per_gpu": "auto", + "steps_per_print": 1e5, + "wall_clock_breakdown": false +} \ No newline at end of file diff --git a/scripts/causalvae/eval.sh b/scripts/causalvae/eval.sh new file mode 100644 index 0000000000000000000000000000000000000000..ed8bd3eda34357499b6c475ae1115f5e4604c1ee --- /dev/null +++ b/scripts/causalvae/eval.sh @@ -0,0 +1,12 @@ +python opensora/eval/eval_common_metric.py \ + --batch_size 2 \ + --real_video_dir ..//test_eval/release/origin \ + --generated_video_dir ../test_eval/release \ + --device cuda \ + --sample_fps 10 \ + --crop_size 256 \ + --resolution 256 \ + --num_frames 17 \ + --sample_rate 1 \ + --subset_size 100 \ + --metric ssim \ No newline at end of file diff --git a/scripts/causalvae/gen_video.sh b/scripts/causalvae/gen_video.sh new file mode 100644 index 0000000000000000000000000000000000000000..2f340d62e17f0a19d9a0785b1455a003940d0126 --- /dev/null +++ b/scripts/causalvae/gen_video.sh @@ -0,0 +1,13 @@ +python examples/rec_video_vae.py \ + --batch_size 1 \ + --real_video_dir ../test_eval/eyes_test \ + --generated_video_dir ../test_eval/eyes_gen \ + --device cuda \ + --sample_fps 10 \ + --sample_rate 1 \ + --num_frames 17 \ + --resolution 512 \ + --crop_size 512 \ + --num_workers 8 \ + --ckpt results/pretrained_488 \ + --enable_tiling \ No newline at end of file diff --git a/scripts/causalvae/release.json b/scripts/causalvae/release.json new file mode 100644 index 0000000000000000000000000000000000000000..62211d2849a9bc34e2137d7b2399559be94ade70 --- /dev/null +++ b/scripts/causalvae/release.json @@ -0,0 +1,71 @@ +{ + "_class_name": "CausalVAEModel", + "_diffusers_version": "0.27.2", + "attn_resolutions": [], + "decoder_attention": "AttnBlock3D", + "decoder_conv_in": "CausalConv3d", + "decoder_conv_out": "CausalConv3d", + "decoder_mid_resnet": "ResnetBlock3D", + "decoder_resnet_blocks": [ + "ResnetBlock3D", + "ResnetBlock3D", + "ResnetBlock3D", + "ResnetBlock3D" + ], + "decoder_spatial_upsample": [ + "", + "SpatialUpsample2x", + "SpatialUpsample2x", + "SpatialUpsample2x" + ], + "decoder_temporal_upsample": [ + "", + "", + "TimeUpsample2x", + "TimeUpsample2x" + ], + "double_z": true, + "dropout": 0.0, + "embed_dim": 4, + "encoder_attention": "AttnBlock3D", + "encoder_conv_in": "CausalConv3d", + "encoder_conv_out": "CausalConv3d", + "encoder_mid_resnet": "ResnetBlock3D", + "encoder_resnet_blocks": [ + "ResnetBlock3D", + "ResnetBlock3D", + "ResnetBlock3D", + "ResnetBlock3D" + ], + "encoder_spatial_downsample": [ + "SpatialDownsample2x", + "SpatialDownsample2x", + "SpatialDownsample2x", + "" + ], + "encoder_temporal_downsample": [ + "", + "TimeDownsample2x", + "TimeDownsample2x", + "" + ], + "hidden_size": 128, + "hidden_size_mult": [ + 1, + 2, + 4, + 4 + ], + "loss_params": { + "disc_start": 2001, + "disc_weight": 0.5, + "kl_weight": 1e-06, + "logvar_init": 0.0 + }, + "loss_type": "opensora.models.ae.videobase.losses.LPIPSWithDiscriminator", + "lr": 1e-05, + "num_res_blocks": 2, + "q_conv": "CausalConv3d", + "resolution": 256, + "z_channels": 4 +} diff --git a/scripts/causalvae/train.sh b/scripts/causalvae/train.sh new file mode 100644 index 0000000000000000000000000000000000000000..31826d77a8c6cd1219354ab1c60137c8688b1045 --- /dev/null +++ b/scripts/causalvae/train.sh @@ -0,0 +1,15 @@ +python opensora/train/train_causalvae.py \ + --exp_name "exp_name" \ + --batch_size 1 \ + --precision bf16 \ + --max_steps 40000 \ + --save_steps 2000 \ + --output_dir results/causalvae \ + --video_path /remote-home1/dataset/data_split_tt \ + --video_num_frames 5 \ + --resolution 32 \ + --sample_rate 1 \ + --n_nodes 1 \ + --devices 1 \ + --num_workers 8 \ + --model_config scripts/causalvae/release.json \ No newline at end of file diff --git a/scripts/class_condition/sample.sh b/scripts/class_condition/sample.sh new file mode 100644 index 0000000000000000000000000000000000000000..eedee9bdce9f9d0b96a07285452468266da6db54 --- /dev/null +++ b/scripts/class_condition/sample.sh @@ -0,0 +1,15 @@ +accelerate launch \ + --num_processes 1 \ + --main_process_port 29502 \ + opensora/sample/sample.py \ + --model Latte-XL/122 \ + --ae stabilityai/sd-vae-ft-mse \ + --ckpt ucf101-f16s3-128-imgvae188-bf16-ckpt-flash/checkpoint-98500 \ + --train_classcondition \ + --num_classes 101 \ + --fps 10 \ + --num_frames 16 \ + --image_size 128 \ + --num_sampling_steps 500 \ + --attention_mode flash \ + --mixed_precision bf16 diff --git a/scripts/class_condition/train_imgae.sh b/scripts/class_condition/train_imgae.sh new file mode 100644 index 0000000000000000000000000000000000000000..d48cfe6af130f0552a881ba623770d1e09c3d05d --- /dev/null +++ b/scripts/class_condition/train_imgae.sh @@ -0,0 +1,26 @@ +export WANDB_KEY="" +export ENTITY="" +export PROJECT="ucf101-f16s3-128-imgvae188-bf16-ckpt-flash" +accelerate launch \ + --config_file scripts/accelerate_configs/ddp_config.yaml \ + opensora/train/train.py \ + --model Latte-XL/122 \ + --dataset ucf101 \ + --ae stabilityai/sd-vae-ft-mse \ + --data_path /remote-home/yeyang/UCF-101 \ + --train_classcondition \ + --num_classes 101 \ + --sample_rate 3 \ + --num_frames 16 \ + --max_image_size 128 \ + --gradient_checkpointing \ + --attention_mode flash \ + --train_batch_size=8 --dataloader_num_workers 10 \ + --gradient_accumulation_steps=1 \ + --max_train_steps=1000000 \ + --learning_rate=1e-04 --lr_scheduler="constant" --lr_warmup_steps=0 \ + --mixed_precision="bf16" \ + --report_to="wandb" \ + --checkpointing_steps=500 \ + --output_dir="ucf101-f16s3-128-imgvae188-bf16-ckpt-flash" \ + --allow_tf32 diff --git a/scripts/class_condition/train_vidae.sh b/scripts/class_condition/train_vidae.sh new file mode 100644 index 0000000000000000000000000000000000000000..cedfcd4ab9a8a5f88a1e14119d6b0b44a8f9e852 --- /dev/null +++ b/scripts/class_condition/train_vidae.sh @@ -0,0 +1,26 @@ +export WANDB_KEY="" +export ENTITY="" +export PROJECT="ucf101-f16s3-128-causalvideovae444-bf16-ckpt-flash" +accelerate launch \ + --config_file scripts/accelerate_configs/ddp_config.yaml \ + opensora/train/train.py \ + --model Latte-XL/122 \ + --dataset ucf101 \ + --ae CausalVQVAEModel_4x4x4 \ + --data_path /remote-home/yeyang/UCF-101 \ + --train_classcondition \ + --num_classes 101 \ + --sample_rate 3 \ + --num_frames 16 \ + --max_image_size 128 \ + --gradient_checkpointing \ + --attention_mode flash \ + --train_batch_size=8 --dataloader_num_workers 10 \ + --gradient_accumulation_steps=1 \ + --max_train_steps=1000000 \ + --learning_rate=1e-04 --lr_scheduler="constant" --lr_warmup_steps=0 \ + --mixed_precision="bf16" \ + --report_to="wandb" \ + --checkpointing_steps=500 \ + --output_dir="ucf101-f16s3-128-causalvideovae444-bf16-ckpt-flash" \ + --allow_tf32 diff --git a/scripts/slurm/placeholder b/scripts/slurm/placeholder new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scripts/text_condition/sample_image.sh b/scripts/text_condition/sample_image.sh new file mode 100644 index 0000000000000000000000000000000000000000..303dac6d717dc8325f4753a49d02fe7beda8e855 --- /dev/null +++ b/scripts/text_condition/sample_image.sh @@ -0,0 +1,12 @@ +CUDA_VISIBLE_DEVICES=0 python opensora/sample/sample_t2v.py \ + --model_path LanguageBind/Open-Sora-Plan-v1.0.0 \ + --text_encoder_name DeepFloyd/t5-v1_1-xxl \ + --text_prompt examples/prompt_list_0.txt \ + --ae CausalVAEModel_4x8x8 \ + --version 65x512x512 \ + --save_img_path "./sample_images/prompt_list_0" \ + --fps 24 \ + --guidance_scale 7.5 \ + --num_sampling_steps 250 \ + --enable_tiling \ + --force_images diff --git a/scripts/text_condition/sample_video.sh b/scripts/text_condition/sample_video.sh new file mode 100644 index 0000000000000000000000000000000000000000..a7f8f20d5d4faf57b210801ab663280e46c3dc94 --- /dev/null +++ b/scripts/text_condition/sample_video.sh @@ -0,0 +1,11 @@ +CUDA_VISIBLE_DEVICES=0 python opensora/sample/sample_t2v.py \ + --model_path LanguageBind/Open-Sora-Plan-v1.0.0 \ + --text_encoder_name DeepFloyd/t5-v1_1-xxl \ + --text_prompt examples/prompt_list_0.txt \ + --ae CausalVAEModel_4x8x8 \ + --version 65x512x512 \ + --save_img_path "./sample_videos/prompt_0" \ + --fps 24 \ + --guidance_scale 10.0 \ + --num_sampling_steps 100 \ + --enable_tiling diff --git a/scripts/text_condition/train_imageae.sh b/scripts/text_condition/train_imageae.sh new file mode 100644 index 0000000000000000000000000000000000000000..fff9c8afe7a8d2406bee6c435ba22e09f0ef5915 --- /dev/null +++ b/scripts/text_condition/train_imageae.sh @@ -0,0 +1,34 @@ +export WANDB_KEY="" +export ENTITY="" +export PROJECT="t2v-f16s3-img4-128-imgvae188-bf16-gc-xformers" +accelerate launch \ + --config_file scripts/accelerate_configs/deepspeed_zero2_config.yaml \ + opensora/train/train_t2v.py \ + --model LatteT2V-XL/122 \ + --text_encoder_name DeepFloyd/t5-v1_1-xxl \ + --dataset t2v \ + --ae stabilityai/sd-vae-ft-mse \ + --data_path /remote-home1/dataset/sharegpt4v_path_cap_.json \ + --video_folder /remote-home1/dataset/data_split \ + --sample_rate 1 \ + --num_frames 17 \ + --max_image_size 256 \ + --gradient_checkpointing \ + --attention_mode xformers \ + --train_batch_size=4 \ + --dataloader_num_workers 10 \ + --gradient_accumulation_steps=1 \ + --max_train_steps=1000000 \ + --learning_rate=2e-05 \ + --lr_scheduler="constant" \ + --lr_warmup_steps=0 \ + --mixed_precision="bf16" \ + --report_to="wandb" \ + --checkpointing_steps=500 \ + --output_dir="t2v-f17-256-img4-imagevae488-bf16-ckpt-xformers-bs4-lr2e-5-t5" \ + --allow_tf32 \ + --pretrained t2v.pt \ + --use_deepspeed \ + --model_max_length 300 \ + --use_image_num 4 \ + --use_img_from_vid diff --git a/scripts/text_condition/train_videoae_17x256x256.sh b/scripts/text_condition/train_videoae_17x256x256.sh new file mode 100644 index 0000000000000000000000000000000000000000..f7cfb1729a6040a64ec2e51d0812b9e1f8002118 --- /dev/null +++ b/scripts/text_condition/train_videoae_17x256x256.sh @@ -0,0 +1,35 @@ +export WANDB_KEY="" +export ENTITY="" +export PROJECT="t2v-f17-256-img4-videovae488-bf16-ckpt-xformers-bs4-lr2e-5-t5" +accelerate launch \ + --config_file scripts/accelerate_configs/deepspeed_zero2_config.yaml \ + opensora/train/train_t2v.py \ + --model LatteT2V-XL/122 \ + --text_encoder_name DeepFloyd/t5-v1_1-xxl \ + --dataset t2v \ + --ae CausalVAEModel_4x8x8 \ + --ae_path CausalVAEModel_4x8x8 \ + --data_path /remote-home1/dataset/sharegpt4v_path_cap_64x512x512.json \ + --video_folder /remote-home1/dataset/data_split_tt \ + --sample_rate 1 \ + --num_frames 17 \ + --max_image_size 256 \ + --gradient_checkpointing \ + --attention_mode xformers \ + --train_batch_size=4 \ + --dataloader_num_workers 10 \ + --gradient_accumulation_steps=1 \ + --max_train_steps=1000000 \ + --learning_rate=2e-05 \ + --lr_scheduler="constant" \ + --lr_warmup_steps=0 \ + --mixed_precision="bf16" \ + --report_to="wandb" \ + --checkpointing_steps=500 \ + --output_dir="t2v-f17-256-img4-videovae488-bf16-ckpt-xformers-bs4-lr2e-5-t5" \ + --allow_tf32 \ + --pretrained t2v.pt \ + --use_deepspeed \ + --model_max_length 300 \ + --use_image_num 4 \ + --use_img_from_vid diff --git a/scripts/text_condition/train_videoae_65x256x256.sh b/scripts/text_condition/train_videoae_65x256x256.sh new file mode 100644 index 0000000000000000000000000000000000000000..45d497afa9c2ab9bd95812335be7191067ea02c1 --- /dev/null +++ b/scripts/text_condition/train_videoae_65x256x256.sh @@ -0,0 +1,35 @@ +export WANDB_KEY="" +export ENTITY="" +export PROJECT="t2v-f65-256-img4-videovae488-bf16-ckpt-xformers-bs4-lr2e-5-t5" +accelerate launch \ + --config_file scripts/accelerate_configs/deepspeed_zero2_config.yaml \ + opensora/train/train_t2v.py \ + --model LatteT2V-XL/122 \ + --text_encoder_name DeepFloyd/t5-v1_1-xxl \ + --dataset t2v \ + --ae CausalVAEModel_4x8x8 \ + --ae_path CausalVAEModel_4x8x8 \ + --data_path /remote-home1/dataset/sharegpt4v_path_cap_.json \ + --video_folder /remote-home1/dataset/data_split_tt \ + --sample_rate 1 \ + --num_frames 65 \ + --max_image_size 256 \ + --gradient_checkpointing \ + --attention_mode xformers \ + --train_batch_size=4 \ + --dataloader_num_workers 10 \ + --gradient_accumulation_steps=1 \ + --max_train_steps=1000000 \ + --learning_rate=2e-05 \ + --lr_scheduler="constant" \ + --lr_warmup_steps=0 \ + --mixed_precision="bf16" \ + --report_to="wandb" \ + --checkpointing_steps=500 \ + --output_dir="t2v-f65-256-img4-videovae488-bf16-ckpt-xformers-bs4-lr2e-5-t5" \ + --allow_tf32 \ + --pretrained t2v.pt \ + --use_deepspeed \ + --model_max_length 300 \ + --use_image_num 4 \ + --use_img_from_vid diff --git a/scripts/text_condition/train_videoae_65x512x512.sh b/scripts/text_condition/train_videoae_65x512x512.sh new file mode 100644 index 0000000000000000000000000000000000000000..c8823823ee2d6ce07db9c92377ec9922cde61408 --- /dev/null +++ b/scripts/text_condition/train_videoae_65x512x512.sh @@ -0,0 +1,36 @@ +export WANDB_KEY="" +export ENTITY="" +export PROJECT="t2v-f65-256-img16-videovae488-bf16-ckpt-xformers-bs4-lr2e-5-t5" +accelerate launch \ + --config_file scripts/accelerate_configs/deepspeed_zero2_config.yaml \ + opensora/train/train_t2v.py \ + --model LatteT2V-XL/122 \ + --text_encoder_name DeepFloyd/t5-v1_1-xxl \ + --dataset t2v \ + --ae CausalVAEModel_4x8x8 \ + --ae_path CausalVAEModel_4x8x8 \ + --data_path /remote-home1/dataset/sharegpt4v_path_cap_.json \ + --video_folder /remote-home1/dataset/data_split_tt \ + --sample_rate 1 \ + --num_frames 65 \ + --max_image_size 512 \ + --gradient_checkpointing \ + --attention_mode xformers \ + --train_batch_size=2 \ + --dataloader_num_workers 10 \ + --gradient_accumulation_steps=1 \ + --max_train_steps=1000000 \ + --learning_rate=2e-05 \ + --lr_scheduler="constant" \ + --lr_warmup_steps=0 \ + --mixed_precision="bf16" \ + --report_to="wandb" \ + --checkpointing_steps=500 \ + --output_dir="t2v-f65-512-img16-videovae488-bf16-ckpt-xformers-bs4-lr2e-5-t5" \ + --allow_tf32 \ + --pretrained t2v.pt \ + --use_deepspeed \ + --model_max_length 300 \ + --use_image_num 16 \ + --use_img_from_vid \ + --enable_tiling diff --git a/scripts/un_condition/sample.sh b/scripts/un_condition/sample.sh new file mode 100644 index 0000000000000000000000000000000000000000..5241d2cf5ffb4d23b6b331ebce5b10aa4e2b38e8 --- /dev/null +++ b/scripts/un_condition/sample.sh @@ -0,0 +1,15 @@ + +accelerate launch \ + --num_processes 1 \ + --main_process_port 29502 \ + opensora/sample/sample.py \ + --model Latte-XL/122 \ + --ae CausalVQVAEModel \ + --ckpt sky-f17s3-128-causalvideovae488-bf16-ckpt-flash-log/checkpoint-45500 \ + --fps 10 \ + --num_frames 17 \ + --image_size 128 \ + --num_sampling_steps 250 \ + --attention_mode flash \ + --mixed_precision bf16 \ + --num_sample 10 \ No newline at end of file diff --git a/scripts/un_condition/train_imgae.sh b/scripts/un_condition/train_imgae.sh new file mode 100644 index 0000000000000000000000000000000000000000..1442865f6b66a19cd9794cd5e7cc00878b15e367 --- /dev/null +++ b/scripts/un_condition/train_imgae.sh @@ -0,0 +1,25 @@ +export WANDB_KEY="" +export ENTITY="" +export PROJECT="sky-f16s3-128-imgae188-bf16-ckpt-flash-log" +HF_DATASETS_OFFLINE=1 TRANSFORMERS_OFFLINE=1 accelerate launch \ + --config_file scripts/accelerate_configs/ddp_config.yaml \ + opensora/train/train.py \ + --model Latte-XL/122 \ + --dataset sky \ + --ae stabilityai/sd-vae-ft-mse \ + --data_path /remote-home/yeyang/sky_timelapse/sky_train/ \ + --sample_rate 3 \ + --num_frames 16 \ + --max_image_size 128 \ + --gradient_checkpointing \ + --attention_mode flash \ + --train_batch_size=8 --dataloader_num_workers 10 \ + --gradient_accumulation_steps=1 \ + --max_train_steps=1000000 \ + --learning_rate=1e-04 --lr_scheduler="constant" --lr_warmup_steps=0 \ + --mixed_precision="bf16" \ + --report_to="wandb" \ + --checkpointing_steps=500 \ + --output_dir="sky-f16s3-128-imgae188-bf16-ckpt-flash-log" \ + --allow_tf32 + diff --git a/scripts/un_condition/train_vidae.sh b/scripts/un_condition/train_vidae.sh new file mode 100644 index 0000000000000000000000000000000000000000..b071dd1e699339785de83a0200aca8a6e8b81679 --- /dev/null +++ b/scripts/un_condition/train_vidae.sh @@ -0,0 +1,25 @@ +export WANDB_KEY="" +export ENTITY="" +export PROJECT="sky-f17s3-128-causalvideovae444-bf16-ckpt-flash-log" +HF_DATASETS_OFFLINE=1 TRANSFORMERS_OFFLINE=1 accelerate launch \ + --config_file scripts/accelerate_configs/ddp_config.yaml \ + opensora/train/train.py \ + --model Latte-XL/122 \ + --dataset sky \ + --ae CausalVQVAEModel_4x4x4 \ + --data_path /remote-home/yeyang/sky_timelapse/sky_train/ \ + --sample_rate 3 \ + --num_frames 17 \ + --max_image_size 128 \ + --gradient_checkpointing \ + --attention_mode flash \ + --train_batch_size=8 --dataloader_num_workers 10 \ + --gradient_accumulation_steps=1 \ + --max_train_steps=1000000 \ + --learning_rate=1e-04 --lr_scheduler="constant" --lr_warmup_steps=0 \ + --mixed_precision="bf16" \ + --report_to="wandb" \ + --checkpointing_steps=500 \ + --output_dir="sky-f17s3-128-causalvideovae444-bf16-ckpt-flash-log" \ + --allow_tf32 + diff --git a/scripts/videogpt/train_videogpt.sh b/scripts/videogpt/train_videogpt.sh new file mode 100644 index 0000000000000000000000000000000000000000..6c06e0a64104669d133bcfdad73adc7ef7d1b6c4 --- /dev/null +++ b/scripts/videogpt/train_videogpt.sh @@ -0,0 +1,30 @@ + +accelerate launch \ + --config_file scripts/accelerate_configs/ddp_config.yaml \ + opensora/train/train_videogpt.py \ + --do_train \ + --seed 1234 \ + --data_path "/remote-home/yeyang/UCF-101/" \ + --per_device_train_batch_size 1 \ + --gradient_accumulation_steps 1 \ + --learning_rate 7e-4 \ + --weight_decay 0. \ + --max_steps 20000 \ + --lr_scheduler_type cosine \ + --max_grad_norm 1.0 \ + --save_strategy steps \ + --save_total_limit 5 \ + --logging_steps 5 \ + --save_steps 1000 \ + --n_codes 2048 \ + --n_hiddens 240 \ + --embedding_dim 4 \ + --n_res_layers 4 \ + --downsample "4,4,4" \ + --resolution 240 \ + --sequence_length 16 \ + --output_dir results/videogpt_488_256_16 \ + --bf16 True \ + --fp16 False \ + --report_to tensorboard \ + --dataloader_num_workers 10 diff --git a/scripts/videogpt/train_videogpt_dsz2.sh b/scripts/videogpt/train_videogpt_dsz2.sh new file mode 100644 index 0000000000000000000000000000000000000000..7875333eb436f6f50e951d858f064d1ca83af95c --- /dev/null +++ b/scripts/videogpt/train_videogpt_dsz2.sh @@ -0,0 +1,30 @@ +export ACCELERATE_GRADIENT_ACCUMULATION_STEPS=1 + +accelerate launch \ + --config_file scripts/accelerate_configs/deepspeed_zero2_config.yaml \ + opensora/train/train_videogpt.py \ + --do_train \ + --seed 1234 \ + --data_path "datasets/UCF-101/" \ + --per_device_train_batch_size 32 \ + --gradient_accumulation_steps $ACCELERATE_GRADIENT_ACCUMULATION_STEPS \ + --learning_rate 7e-4 \ + --weight_decay 0. \ + --num_train_epochs 2 \ + --lr_scheduler_type cosine \ + --max_grad_norm 1.0 \ + --save_strategy steps \ + --save_total_limit 5 \ + --logging_steps 5 \ + --save_steps 10000 \ + --n_codes 1024 \ + --n_hiddens 240 \ + --embedding_dim 4 \ + --n_res_layers 4 \ + --downsample "4,4,4" \ + --resolution 128 \ + --sequence_length 16 \ + --output_dir results/videogpt_444_128 \ + --bf16 True \ + --fp16 False \ + --report_to tensorboard diff --git a/scripts/videogpt/train_videogpt_dsz3.sh b/scripts/videogpt/train_videogpt_dsz3.sh new file mode 100644 index 0000000000000000000000000000000000000000..4840842362f94e93d2b0acc2cd7f6c9bfe15c464 --- /dev/null +++ b/scripts/videogpt/train_videogpt_dsz3.sh @@ -0,0 +1,30 @@ +export ACCELERATE_GRADIENT_ACCUMULATION_STEPS=1 + +accelerate launch \ + --config_file scripts/accelerate_configs/deepspeed_zero3_config.yaml \ + opensora/train/train_videogpt.py \ + --do_train \ + --seed 1234 \ + --data_path "datasets/UCF-101/" \ + --per_device_train_batch_size 32 \ + --gradient_accumulation_steps $ACCELERATE_GRADIENT_ACCUMULATION_STEPS \ + --learning_rate 7e-4 \ + --weight_decay 0. \ + --num_train_epochs 2 \ + --lr_scheduler_type cosine \ + --max_grad_norm 1.0 \ + --save_strategy steps \ + --save_total_limit 5 \ + --logging_steps 5 \ + --save_steps 10000 \ + --n_codes 1024 \ + --n_hiddens 240 \ + --embedding_dim 4 \ + --n_res_layers 4 \ + --downsample "4,4,4" \ + --resolution 128 \ + --sequence_length 16 \ + --output_dir results/videogpt_444_128 \ + --bf16 True \ + --fp16 False \ + --report_to tensorboard