LinB203 commited on
Commit
a220803
·
1 Parent(s): bb863bd
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitignore +15 -0
  2. LICENSE +21 -0
  3. README.md +1 -1
  4. docker/LICENSE +21 -0
  5. docker/README.md +87 -0
  6. docker/build_docker.png +0 -0
  7. docker/docker_build.sh +8 -0
  8. docker/docker_run.sh +45 -0
  9. docker/dockerfile.base +24 -0
  10. docker/packages.txt +3 -0
  11. docker/ports.txt +1 -0
  12. docker/postinstallscript.sh +3 -0
  13. docker/requirements.txt +40 -0
  14. docker/run_docker.png +0 -0
  15. docker/setup_env.sh +11 -0
  16. docs/CausalVideoVAE.md +36 -0
  17. docs/Contribution_Guidelines.md +87 -0
  18. docs/Data.md +35 -0
  19. docs/EVAL.md +110 -0
  20. docs/VQVAE.md +57 -0
  21. examples/get_latents_std.py +38 -0
  22. examples/prompt_list_0.txt +16 -0
  23. examples/rec_image.py +57 -0
  24. examples/rec_imvi_vae.py +159 -0
  25. examples/rec_video.py +120 -0
  26. examples/rec_video_ae.py +120 -0
  27. examples/rec_video_vae.py +274 -0
  28. opensora/__init__.py +1 -0
  29. opensora/dataset/__init__.py +99 -0
  30. opensora/dataset/extract_feature_dataset.py +64 -0
  31. opensora/dataset/feature_datasets.py +213 -0
  32. opensora/dataset/landscope.py +90 -0
  33. opensora/dataset/sky_datasets.py +128 -0
  34. opensora/dataset/t2v_datasets.py +111 -0
  35. opensora/dataset/transform.py +489 -0
  36. opensora/dataset/ucf101.py +80 -0
  37. opensora/eval/cal_flolpips.py +83 -0
  38. opensora/eval/cal_fvd.py +85 -0
  39. opensora/eval/cal_lpips.py +97 -0
  40. opensora/eval/cal_psnr.py +84 -0
  41. opensora/eval/cal_ssim.py +113 -0
  42. opensora/eval/eval_clip_score.py +225 -0
  43. opensora/eval/eval_common_metric.py +224 -0
  44. opensora/eval/flolpips/correlation/correlation.py +397 -0
  45. opensora/eval/flolpips/flolpips.py +308 -0
  46. opensora/eval/flolpips/pretrained_networks.py +180 -0
  47. opensora/eval/flolpips/pwcnet.py +344 -0
  48. opensora/eval/flolpips/utils.py +95 -0
  49. opensora/eval/fvd/styleganv/fvd.py +90 -0
  50. opensora/eval/fvd/styleganv/i3d_torchscript.pt +3 -0
.gitignore ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ucf101_stride4x4x4
2
+ __pycache__
3
+ *.mp4
4
+ .ipynb_checkpoints
5
+ *.pth
6
+ UCF-101/
7
+ results/
8
+ vae
9
+ build/
10
+ opensora.egg-info/
11
+ wandb/
12
+ .idea
13
+ *.ipynb
14
+ *.jpg
15
+ *.mp3
LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2024 PKU-YUAN's Group (袁粒课题组-北大信工)
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
README.md CHANGED
@@ -4,7 +4,7 @@ emoji: 🦀
4
  colorFrom: indigo
5
  colorTo: red
6
  sdk: gradio
7
- sdk_version: 4.25.0
8
  app_file: app.py
9
  pinned: false
10
  license: mit
 
4
  colorFrom: indigo
5
  colorTo: red
6
  sdk: gradio
7
+ sdk_version: 3.37.0
8
  app_file: app.py
9
  pinned: false
10
  license: mit
docker/LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2024 SimonLee
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
docker/README.md ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Docker4ML
2
+
3
+ Useful docker scripts for ML developement.
4
+ [https://github.com/SimonLeeGit/Docker4ML](https://github.com/SimonLeeGit/Docker4ML)
5
+
6
+ ## Build Docker Image
7
+
8
+ ```bash
9
+ bash docker_build.sh
10
+ ```
11
+
12
+ ![build_docker](build_docker.png)
13
+
14
+ ## Run Docker Container as Development Envirnoment
15
+
16
+ ```bash
17
+ bash docker_run.sh
18
+ ```
19
+
20
+ ![run_docker](run_docker.png)
21
+
22
+ ## Custom Docker Config
23
+
24
+ ### Config [setup_env.sh](./setup_env.sh)
25
+
26
+ You can modify this file to custom your settings.
27
+
28
+ ```bash
29
+ TAG=ml:dev
30
+ BASE_TAG=nvcr.io/nvidia/pytorch:23.12-py3
31
+ ```
32
+
33
+ #### TAG
34
+
35
+ Your built docker image tag, you can set it as what you what.
36
+
37
+ #### BASE_TAG
38
+
39
+ The base docker image tag for your built docker image, here we use nvidia pytorch images.
40
+ You can check it from [https://catalog.ngc.nvidia.com/orgs/nvidia/containers/pytorch/tags](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/pytorch/tags)
41
+
42
+ Also, you can use other docker image as base, such as: [ubuntu](https://hub.docker.com/_/ubuntu/tags)
43
+
44
+ ### USER_NAME
45
+
46
+ Your user name used in docker container.
47
+
48
+ ### USER_PASSWD
49
+
50
+ Your user password used in docker container.
51
+
52
+ ### Config [requriements.txt](./requirements.txt)
53
+
54
+ You can add your default installed python libraries here.
55
+
56
+ ```txt
57
+ transformers==4.27.1
58
+ ```
59
+
60
+ By default, it has some libs installed, you can check it from [https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes/rel-24-01.html](https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes/rel-24-01.html)
61
+
62
+ ### Config [packages.txt](./packages.txt)
63
+
64
+ You can add your default apt-get installed packages here.
65
+
66
+ ```txt
67
+ wget
68
+ curl
69
+ git
70
+ ```
71
+
72
+ ### Config [ports.txt](./ports.txt)
73
+
74
+ You can add some ports enabled for docker container here.
75
+
76
+ ```txt
77
+ -p 6006:6006
78
+ -p 8080:8080
79
+ ```
80
+
81
+ ### Config [postinstallscript.sh](./postinstallscript.sh)
82
+
83
+ You can add your custom script to run when build docker image.
84
+
85
+ ## Q&A
86
+
87
+ If you have any use problems, please contact to <simonlee235@gmail.com>.
docker/build_docker.png ADDED
docker/docker_build.sh ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env bash
2
+
3
+ WORK_DIR=$(dirname "$(readlink -f "$0")")
4
+ cd $WORK_DIR
5
+
6
+ source setup_env.sh
7
+
8
+ docker build -t $TAG --build-arg BASE_TAG=$BASE_TAG --build-arg USER_NAME=$USER_NAME --build-arg USER_PASSWD=$USER_PASSWD . -f dockerfile.base
docker/docker_run.sh ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env bash
2
+
3
+ WORK_DIR=$(dirname "$(readlink -f "$0")")
4
+ source $WORK_DIR/setup_env.sh
5
+
6
+ RUNNING_IDS="$(docker ps --filter ancestor=$TAG --format "{{.ID}}")"
7
+
8
+ if [ -n "$RUNNING_IDS" ]; then
9
+ # Initialize an array to hold the container IDs
10
+ declare -a container_ids=($RUNNING_IDS)
11
+
12
+ # Get the first container ID using array indexing
13
+ ID=${container_ids[0]}
14
+
15
+ # Print the first container ID
16
+ echo ' '
17
+ echo "The running container ID is: $ID, enter it!"
18
+ else
19
+ echo ' '
20
+ echo "Not found running containers, run it!"
21
+
22
+ # Run a new docker container instance
23
+ ID=$(docker run \
24
+ --rm \
25
+ --gpus all \
26
+ -itd \
27
+ --ipc=host \
28
+ --ulimit memlock=-1 \
29
+ --ulimit stack=67108864 \
30
+ -e DISPLAY=$DISPLAY \
31
+ -v /tmp/.X11-unix/:/tmp/.X11-unix/ \
32
+ -v $PWD:/home/$USER_NAME/workspace \
33
+ -w /home/$USER_NAME/workspace \
34
+ $(cat $WORK_DIR/ports.txt) \
35
+ $TAG)
36
+ fi
37
+
38
+ docker logs $ID
39
+
40
+ echo ' '
41
+ echo ' '
42
+ echo '========================================='
43
+ echo ' '
44
+
45
+ docker exec -it $ID bash
docker/dockerfile.base ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ARG BASE_TAG
2
+ FROM ${BASE_TAG}
3
+ ARG USER_NAME=myuser
4
+ ARG USER_PASSWD=111111
5
+ ARG DEBIAN_FRONTEND=noninteractive
6
+
7
+ # Pre-install packages, pip install requirements and run post install script.
8
+ COPY packages.txt .
9
+ COPY requirements.txt .
10
+ COPY postinstallscript.sh .
11
+ RUN apt-get update && apt-get install -y sudo $(cat packages.txt)
12
+ RUN pip install --no-cache-dir -r requirements.txt
13
+ RUN bash postinstallscript.sh
14
+
15
+ # Create a new user and group using the username argument
16
+ RUN groupadd -r ${USER_NAME} && useradd -r -m -g${USER_NAME} ${USER_NAME}
17
+ RUN echo "${USER_NAME}:${USER_PASSWD}" | chpasswd
18
+ RUN usermod -aG sudo ${USER_NAME}
19
+ USER ${USER_NAME}
20
+ ENV USER=${USER_NAME}
21
+ WORKDIR /home/${USER_NAME}/workspace
22
+
23
+ # Set the prompt to highlight the username
24
+ RUN echo "export PS1='\[\033[01;32m\]\u\[\033[00m\]@\[\033[01;34m\]\h\[\033[00m\]:\[\033[01;36m\]\w\[\033[00m\]\$'" >> /home/${USER_NAME}/.bashrc
docker/packages.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ wget
2
+ curl
3
+ git
docker/ports.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ -p 6006:6006
docker/postinstallscript.sh ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ #!/usr/bin/env bash
2
+ # this script will run when build docker image.
3
+
docker/requirements.txt ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ setuptools>=61.0
2
+ torch==2.0.1
3
+ torchvision==0.15.2
4
+ transformers==4.32.0
5
+ albumentations==1.4.0
6
+ av==11.0.0
7
+ decord==0.6.0
8
+ einops==0.3.0
9
+ fastapi==0.110.0
10
+ accelerate==0.21.0
11
+ gdown==5.1.0
12
+ h5py==3.10.0
13
+ idna==3.6
14
+ imageio==2.34.0
15
+ matplotlib==3.7.5
16
+ numpy==1.24.4
17
+ omegaconf==2.1.1
18
+ opencv-python==4.9.0.80
19
+ opencv-python-headless==4.9.0.80
20
+ pandas==2.0.3
21
+ pillow==10.2.0
22
+ pydub==0.25.1
23
+ pytorch-lightning==1.4.2
24
+ pytorchvideo==0.1.5
25
+ PyYAML==6.0.1
26
+ regex==2023.12.25
27
+ requests==2.31.0
28
+ scikit-learn==1.3.2
29
+ scipy==1.10.1
30
+ six==1.16.0
31
+ tensorboard==2.14.0
32
+ test-tube==0.7.5
33
+ timm==0.9.16
34
+ torchdiffeq==0.2.3
35
+ torchmetrics==0.5.0
36
+ tqdm==4.66.2
37
+ urllib3==2.2.1
38
+ uvicorn==0.27.1
39
+ diffusers==0.24.0
40
+ scikit-video==1.1.11
docker/run_docker.png ADDED
docker/setup_env.sh ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Docker tag for new build image
2
+ TAG=open_sora_plan:dev
3
+
4
+ # Base docker image tag used by docker build
5
+ BASE_TAG=nvcr.io/nvidia/pytorch:23.05-py3
6
+
7
+ # User name used in docker container
8
+ USER_NAME=developer
9
+
10
+ # User password used in docker container
11
+ USER_PASSWD=666666
docs/CausalVideoVAE.md ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # CausalVideoVAE Report
2
+
3
+ ## Examples
4
+
5
+ ### Image Reconstruction
6
+
7
+ Resconstruction in **1536×1024**.
8
+
9
+ <img src="https://github.com/PKU-YuanGroup/Open-Sora-Plan/assets/88202804/1684c3ec-245d-4a60-865c-b8946d788eb9" width="45%"/> <img src="https://github.com/PKU-YuanGroup/Open-Sora-Plan/assets/88202804/46ef714e-3e5b-492c-aec4-3793cb2260b5" width="45%"/>
10
+
11
+
12
+
13
+ ### Video Reconstruction
14
+
15
+ We reconstruct two videos with **720×1280**. Since github can't put too big video, we put it here: [1](https://streamable.com/gqojal), [2](https://streamable.com/6nu3j8).
16
+
17
+ https://github.com/PKU-YuanGroup/Open-Sora-Plan/assets/88202804/c100bb02-2420-48a3-9d7b-4608a41f14aa
18
+
19
+ https://github.com/PKU-YuanGroup/Open-Sora-Plan/assets/88202804/8aa8f587-d9f1-4e8b-8a82-d3bf9ba91d68
20
+
21
+ ## Model Structure
22
+
23
+ ![image](https://github.com/PKU-YuanGroup/Open-Sora-Plan/assets/88202804/e3c8b35d-a217-4d96-b2e9-5c248a2859c8)
24
+
25
+
26
+ The Causal Video VAE architecture inherits from the [Stable-Diffusion Image VAE](https://github.com/CompVis/stable-diffusion/tree/main). To ensure that the pretrained weights of the Image VAE can be seamlessly applied to the Video VAE, the model structure has been designed as follows:
27
+
28
+ **1. CausalConv3D**: Converting Conv2D to CausalConv3D enables joint training of image and video data. CausalConv3D applies a special treatment to the first frame, as it does not have access to subsequent frames. For more specific details, please refer to https://github.com/PKU-YuanGroup/Open-Sora-Plan/pull/145
29
+
30
+ **2. Initialization**: There are two common [methods](https://github.com/hassony2/inflated_convnets_pytorch/blob/master/src/inflate.py#L5) to expand Conv2D to Conv3D: average initialization and center initialization. But we employ a specific initialization method (tail initialization). This initialization method ensures that without any training, the model is capable of directly reconstructing images, and even videos.
31
+
32
+ ## Training Details
33
+
34
+ <img width="833" alt="image" src="https://github.com/PKU-YuanGroup/Open-Sora-Plan/assets/88202804/9ffb6dc4-23f6-4274-a066-bbebc7522a14">
35
+
36
+ We present the loss curves for two distinct initialization methods under 17×256×256. The yellow curve represents the loss using tail init, while the blue curve corresponds to the loss from center initialization. As shown in the graph, tail initialization demonstrates better performance on the loss curve. Additionally, **we found that center initialization leads to error accumulation**, causing the collapse over extended durations.
docs/Contribution_Guidelines.md ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Contributing to the Open-Sora Plan Community
2
+
3
+ The Open-Sora Plan open-source community is a collaborative initiative driven by the community, emphasizing a commitment to being free and void of exploitation. Organized spontaneously by community members, we invite you to contribute to the Open-Sora Plan open-source community and help elevate it to new heights!
4
+
5
+ ## Submitting a Pull Request (PR)
6
+
7
+ As a contributor, before submitting your request, kindly follow these guidelines:
8
+
9
+ 1. Start by checking the [Open-Sora Plan GitHub](https://github.com/PKU-YuanGroup/Open-Sora-Plan/pulls) to see if there are any open or closed pull requests related to your intended submission. Avoid duplicating existing work.
10
+
11
+ 2. [Fork](https://github.com/PKU-YuanGroup/Open-Sora-Plan/fork) the [open-sora plan](https://github.com/PKU-YuanGroup/Open-Sora-Plan) repository and download your forked repository to your local machine.
12
+
13
+ ```bash
14
+ git clone [your-forked-repository-url]
15
+ ```
16
+
17
+ 3. Add the original Open-Sora Plan repository as a remote to sync with the latest updates:
18
+
19
+ ```bash
20
+ git remote add upstream https://github.com/PKU-YuanGroup/Open-Sora-Plan
21
+ ```
22
+
23
+ 4. Sync the code from the main repository to your local machine, and then push it back to your forked remote repository.
24
+
25
+ ```
26
+ # Pull the latest code from the upstream branch
27
+ git fetch upstream
28
+
29
+ # Switch to the main branch
30
+ git checkout main
31
+
32
+ # Merge the updates from the upstream branch into main, synchronizing the local main branch with the upstream
33
+ git merge upstream/main
34
+
35
+ # Additionally, sync the local main branch to the remote branch of your forked repository
36
+ git push origin main
37
+ ```
38
+
39
+
40
+ > Note: Sync the code from the main repository before each submission.
41
+
42
+ 5. Create a branch in your forked repository for your changes, ensuring the branch name is meaningful.
43
+
44
+ ```bash
45
+ git checkout -b my-docs-branch main
46
+ ```
47
+
48
+ 6. While making modifications and committing changes, adhere to our [Commit Message Format](#Commit-Message-Format).
49
+
50
+ ```bash
51
+ git commit -m "[docs]: xxxx"
52
+ ```
53
+
54
+ 7. Push your changes to your GitHub repository.
55
+
56
+ ```bash
57
+ git push origin my-docs-branch
58
+ ```
59
+
60
+ 8. Submit a pull request to `Open-Sora-Plan:main` on the GitHub repository page.
61
+
62
+ ## Commit Message Format
63
+
64
+ Commit messages must include both `<type>` and `<summary>` sections.
65
+
66
+ ```bash
67
+ [<type>]: <summary>
68
+ │ │
69
+ │ └─⫸ Briefly describe your changes, without ending with a period.
70
+
71
+ └─⫸ Commit Type: |docs|feat|fix|refactor|
72
+ ```
73
+
74
+ ### Type
75
+
76
+ * **docs**: Modify or add documents.
77
+ * **feat**: Introduce a new feature.
78
+ * **fix**: Fix a bug.
79
+ * **refactor**: Restructure code, excluding new features or bug fixes.
80
+
81
+ ### Summary
82
+
83
+ Describe modifications in English, without ending with a period.
84
+
85
+ > e.g., git commit -m "[docs]: add a contributing.md file"
86
+
87
+ This guideline is borrowed by [minisora](https://github.com/mini-sora/minisora). We sincerely appreciate MiniSora authors for their awesome templates.
docs/Data.md ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ **We need more dataset**, please refer to the [open-sora-Dataset](https://github.com/shaodong233/open-sora-Dataset) for details.
3
+
4
+
5
+ ## Sky
6
+
7
+
8
+ This is an un-condition datasets. [Link](https://drive.google.com/open?id=1xWLiU-MBGN7MrsFHQm4_yXmfHBsMbJQo)
9
+
10
+ ```
11
+ sky_timelapse
12
+ ├── readme
13
+ ├── sky_test
14
+ ├── sky_train
15
+ ├── test_videofolder.py
16
+ └── video_folder.py
17
+ ```
18
+
19
+ ## UCF101
20
+
21
+ We test the code with UCF-101 dataset. In order to download UCF-101 dataset, you can download the necessary files in [here](https://www.crcv.ucf.edu/data/UCF101.php). The code assumes a `ucf101` directory with the following structure
22
+ ```
23
+ UCF-101/
24
+ ApplyEyeMakeup/
25
+ v1.avi
26
+ ...
27
+ ...
28
+ YoYo/
29
+ v1.avi
30
+ ...
31
+ ```
32
+
33
+
34
+ ## Offline feature extraction
35
+ Coming soon...
docs/EVAL.md ADDED
@@ -0,0 +1,110 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Evaluate the generated videos quality
2
+
3
+ You can easily calculate the following video quality metrics, which supports the batch-wise process.
4
+ - **CLIP-SCORE**: It uses the pretrained CLIP model to measure the cosine similarity between two modalities.
5
+ - **FVD**: Frechét Video Distance
6
+ - **SSIM**: structural similarity index measure
7
+ - **LPIPS**: learned perceptual image patch similarity
8
+ - **PSNR**: peak-signal-to-noise ratio
9
+
10
+ # Requirement
11
+ ## Environment
12
+ - install Pytorch (torch>=1.7.1)
13
+ - install CLIP
14
+ ```
15
+ pip install git+https://github.com/openai/CLIP.git
16
+ ```
17
+ - install clip-cose from PyPi
18
+ ```
19
+ pip install clip-score
20
+ ```
21
+ - Other package
22
+ ```
23
+ pip install lpips
24
+ pip install scipy (scipy==1.7.3/1.9.3, if you use 1.11.3, **you will calculate a WRONG FVD VALUE!!!**)
25
+ pip install numpy
26
+ pip install pillow
27
+ pip install torchvision>=0.8.2
28
+ pip install ftfy
29
+ pip install regex
30
+ pip install tqdm
31
+ ```
32
+ ## Pretrain model
33
+ - FVD
34
+ Before you cacluate FVD, you should first download the FVD pre-trained model. You can manually download any of the following and put it into FVD folder.
35
+ - `i3d_torchscript.pt` from [here](https://www.dropbox.com/s/ge9e5ujwgetktms/i3d_torchscript.pt)
36
+ - `i3d_pretrained_400.pt` from [here](https://onedrive.live.com/download?cid=78EEF3EB6AE7DBCB&resid=78EEF3EB6AE7DBCB%21199&authkey=AApKdFHPXzWLNyI)
37
+
38
+
39
+ ## Other Notices
40
+ 1. Make sure the pixel value of videos should be in [0, 1].
41
+ 2. We average SSIM when images have 3 channels, ssim is the only metric extremely sensitive to gray being compared to b/w.
42
+ 3. Because the i3d model downsamples in the time dimension, `frames_num` should > 10 when calculating FVD, so FVD calculation begins from 10-th frame, like upper example.
43
+ 4. For grayscale videos, we multiply to 3 channels
44
+ 5. data input specifications for clip_score
45
+ > - Image Files:All images should be stored in a single directory. The image files can be in either .png or .jpg format.
46
+ >
47
+ > - Text Files: All text data should be contained in plain text files in a separate directory. These text files should have the extension .txt.
48
+ >
49
+ > Note: The number of files in the image directory should be exactly equal to the number of files in the text directory. Additionally, the files in the image directory and text directory should be paired by file name. For instance, if there is a cat.png in the image directory, there should be a corresponding cat.txt in the text directory.
50
+ >
51
+ > Directory Structure Example:
52
+ > ```
53
+ > ├── path/to/image
54
+ > │ ├── cat.png
55
+ > │ ├── dog.png
56
+ > │ └── bird.jpg
57
+ > └── path/to/text
58
+ > ├── cat.txt
59
+ > ├── dog.txt
60
+ > └── bird.txt
61
+ > ```
62
+
63
+ 6. data input specifications for fvd, psnr, ssim, lpips
64
+
65
+ > Directory Structure Example:
66
+ > ```
67
+ > ├── path/to/generated_image
68
+ > │ ├── cat.mp4
69
+ > │ ├── dog.mp4
70
+ > │ └── bird.mp4
71
+ > └── path/to/real_image
72
+ > ├── cat.mp4
73
+ > ├── dog.mp4
74
+ > └── bird.mp4
75
+ > ```
76
+
77
+
78
+
79
+ # Usage
80
+
81
+ ```
82
+ # you change the file path and need to set the frame_num, resolution etc...
83
+
84
+ # clip_score cross modality
85
+ cd opensora/eval
86
+ bash script/cal_clip_score.sh
87
+
88
+
89
+
90
+ # fvd
91
+ cd opensora/eval
92
+ bash script/cal_fvd.sh
93
+
94
+ # psnr
95
+ cd opensora/eval
96
+ bash eval/script/cal_psnr.sh
97
+
98
+
99
+ # ssim
100
+ cd opensora/eval
101
+ bash eval/script/cal_ssim.sh
102
+
103
+
104
+ # lpips
105
+ cd opensora/eval
106
+ bash eval/script/cal_lpips.sh
107
+ ```
108
+
109
+ # Acknowledgement
110
+ The evaluation codebase refers to [clip-score](https://github.com/Taited/clip-score) and [common_metrics](https://github.com/JunyaoHu/common_metrics_on_video_quality).
docs/VQVAE.md ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # VQVAE Documentation
2
+
3
+ # Introduction
4
+
5
+ Vector Quantized Variational AutoEncoders (VQ-VAE) is a type of autoencoder that uses a discrete latent representation. It is particularly useful for tasks that require discrete latent variables, such as text-to-speech and video generation.
6
+
7
+ # Usage
8
+
9
+ ## Initialization
10
+
11
+ To initialize a VQVAE model, you can use the `VideoGPTVQVAE` class. This class is a part of the `opensora.models.ae` module.
12
+
13
+ ```python
14
+ from opensora.models.ae import VideoGPTVQVAE
15
+
16
+ vqvae = VideoGPTVQVAE()
17
+ ```
18
+
19
+ ### Training
20
+
21
+ To train the VQVAE model, you can use the `train_videogpt.sh` script. This script will train the model using the parameters specified in the script.
22
+
23
+ ```bash
24
+ bash scripts/videogpt/train_videogpt.sh
25
+ ```
26
+
27
+ ### Loading Pretrained Models
28
+
29
+ You can load a pretrained model using the `download_and_load_model` method. This method will download the checkpoint file and load the model.
30
+
31
+ ```python
32
+ vqvae = VideoGPTVQVAE.download_and_load_model("bair_stride4x2x2")
33
+ ```
34
+
35
+ Alternatively, you can load a model from a checkpoint using the `load_from_checkpoint` method.
36
+
37
+ ```python
38
+ vqvae = VQVAEModel.load_from_checkpoint("results/VQVAE/checkpoint-1000")
39
+ ```
40
+
41
+ ### Encoding and Decoding
42
+
43
+ You can encode a video using the `encode` method. This method will return the encodings and embeddings of the video.
44
+
45
+ ```python
46
+ encodings, embeddings = vqvae.encode(x_vae, include_embeddings=True)
47
+ ```
48
+
49
+ You can reconstruct a video from its encodings using the decode method.
50
+
51
+ ```python
52
+ video_recon = vqvae.decode(encodings)
53
+ ```
54
+
55
+ ## Testing
56
+
57
+ You can test the VQVAE model by reconstructing a video. The `examples/rec_video.py` script provides an example of how to do this.
examples/get_latents_std.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch.utils.data import DataLoader, Subset
3
+ import sys
4
+ sys.path.append(".")
5
+ from opensora.models.ae.videobase import CausalVAEModel, CausalVAEDataset
6
+
7
+ num_workers = 4
8
+ batch_size = 12
9
+
10
+ torch.manual_seed(0)
11
+ torch.set_grad_enabled(False)
12
+
13
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
14
+
15
+ pretrained_model_name_or_path = 'results/causalvae/checkpoint-26000'
16
+ data_path = '/remote-home1/dataset/UCF-101'
17
+ video_num_frames = 17
18
+ resolution = 128
19
+ sample_rate = 10
20
+
21
+ vae = CausalVAEModel.load_from_checkpoint(pretrained_model_name_or_path)
22
+ vae.to(device)
23
+
24
+ dataset = CausalVAEDataset(data_path, sequence_length=video_num_frames, resolution=resolution, sample_rate=sample_rate)
25
+ subset_indices = list(range(1000))
26
+ subset_dataset = Subset(dataset, subset_indices)
27
+ loader = DataLoader(subset_dataset, batch_size=8, pin_memory=True)
28
+
29
+ all_latents = []
30
+ for video_data in loader:
31
+ video_data = video_data['video'].to(device)
32
+ latents = vae.encode(video_data).sample()
33
+ all_latents.append(video_data.cpu())
34
+
35
+ all_latents_tensor = torch.cat(all_latents)
36
+ std = all_latents_tensor.std().item()
37
+ normalizer = 1 / std
38
+ print(f'{normalizer = }')
examples/prompt_list_0.txt ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ A quiet beach at dawn, the waves gently lapping at the shore and the sky painted in pastel hues.
2
+ A quiet beach at dawn, the waves softly lapping at the shore, pink and orange hues painting the sky, offering a moment of solitude and reflection.
3
+ The majestic beauty of a waterfall cascading down a cliff into a serene lake.
4
+ Sunset over the sea.
5
+ a cat wearing sunglasses and working as a lifeguard at pool.
6
+ Slow pan upward of blazing oak fire in an indoor fireplace.
7
+ Yellow and black tropical fish dart through the sea.
8
+ a serene winter scene in a forest. The forest is blanketed in a thick layer of snow, which has settled on the branches of the trees, creating a canopy of white. The trees, a mix of evergreens and deciduous, stand tall and silent, their forms partially obscured by the snow. The ground is a uniform white, with no visible tracks or signs of human activity. The sun is low in the sky, casting a warm glow that contrasts with the cool tones of the snow. The light filters through the trees, creating a soft, diffused illumination that highlights the texture of the snow and the contours of the trees. The overall style of the scene is naturalistic, with a focus on the tranquility and beauty of the winter landscape.
9
+ a dynamic interaction between the ocean and a large rock. The rock, with its rough texture and jagged edges, is partially submerged in the water, suggesting it is a natural feature of the coastline. The water around the rock is in motion, with white foam and waves crashing against the rock, indicating the force of the ocean's movement. The background is a vast expanse of the ocean, with small ripples and waves, suggesting a moderate sea state. The overall style of the scene is a realistic depiction of a natural landscape, with a focus on the interplay between the rock and the water.
10
+ A serene waterfall cascading down moss-covered rocks, its soothing sound creating a harmonious symphony with nature.
11
+ A soaring drone footage captures the majestic beauty of a coastal cliff, its red and yellow stratified rock faces rich in color and against the vibrant turquoise of the sea. Seabirds can be seen taking flight around the cliff's precipices. As the drone slowly moves from different angles, the changing sunlight casts shifting shadows that highlight the rugged textures of the cliff and the surrounding calm sea. The water gently laps at the rock base and the greenery that clings to the top of the cliff, and the scene gives a sense of peaceful isolation at the fringes of the ocean. The video captures the essence of pristine natural beauty untouched by human structures.
12
+ The video captures the majestic beauty of a waterfall cascading down a cliff into a serene lake. The waterfall, with its powerful flow, is the central focus of the video. The surrounding landscape is lush and green, with trees and foliage adding to the natural beauty of the scene. The camera angle provides a bird's eye view of the waterfall, allowing viewers to appreciate the full height and grandeur of the waterfall. The video is a stunning representation of nature's power and beauty.
13
+ A vibrant scene of a snowy mountain landscape. The sky is filled with a multitude of colorful hot air balloons, each floating at different heights, creating a dynamic and lively atmosphere. The balloons are scattered across the sky, some closer to the viewer, others further away, adding depth to the scene. Below, the mountainous terrain is blanketed in a thick layer of snow, with a few patches of bare earth visible here and there. The snow-covered mountains provide a stark contrast to the colorful balloons, enhancing the visual appeal of the scene.
14
+ A serene underwater scene featuring a sea turtle swimming through a coral reef. The turtle, with its greenish-brown shell, is the main focus of the video, swimming gracefully towards the right side of the frame. The coral reef, teeming with life, is visible in the background, providing a vibrant and colorful backdrop to the turtle's journey. Several small fish, darting around the turtle, add a sense of movement and dynamism to the scene.
15
+ A snowy forest landscape with a dirt road running through it. The road is flanked by trees covered in snow, and the ground is also covered in snow. The sun is shining, creating a bright and serene atmosphere. The road appears to be empty, and there are no people or animals visible in the video. The style of the video is a natural landscape shot, with a focus on the beauty of the snowy forest and the peacefulness of the road.
16
+ The dynamic movement of tall, wispy grasses swaying in the wind. The sky above is filled with clouds, creating a dramatic backdrop. The sunlight pierces through the clouds, casting a warm glow on the scene. The grasses are a mix of green and brown, indicating a change in seasons. The overall style of the video is naturalistic, capturing the beauty of the landscape in a realistic manner. The focus is on the grasses and their movement, with the sky serving as a secondary element. The video does not contain any human or animal elements.
examples/rec_image.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ sys.path.append(".")
3
+ from PIL import Image
4
+ import torch
5
+ from torchvision.transforms import ToTensor, Compose, Resize, Normalize
6
+ from torch.nn import functional as F
7
+ from opensora.models.ae.videobase import CausalVAEModel
8
+ import argparse
9
+ import numpy as np
10
+
11
+ def preprocess(video_data: torch.Tensor, short_size: int = 128) -> torch.Tensor:
12
+ transform = Compose(
13
+ [
14
+ ToTensor(),
15
+ Normalize((0.5), (0.5)),
16
+ Resize(size=short_size),
17
+ ]
18
+ )
19
+ outputs = transform(video_data)
20
+ outputs = outputs.unsqueeze(0).unsqueeze(2)
21
+ return outputs
22
+
23
+ def main(args: argparse.Namespace):
24
+ image_path = args.image_path
25
+ resolution = args.resolution
26
+ device = args.device
27
+
28
+ vqvae = CausalVAEModel.load_from_checkpoint(args.ckpt)
29
+ vqvae.eval()
30
+ vqvae = vqvae.to(device)
31
+
32
+ with torch.no_grad():
33
+ x_vae = preprocess(Image.open(image_path), resolution)
34
+ x_vae = x_vae.to(device)
35
+ latents = vqvae.encode(x_vae)
36
+ recon = vqvae.decode(latents.sample())
37
+ x = recon[0, :, 0, :, :]
38
+ x = x.squeeze()
39
+ x = x.detach().cpu().numpy()
40
+ x = np.clip(x, -1, 1)
41
+ x = (x + 1) / 2
42
+ x = (255*x).astype(np.uint8)
43
+ x = x.transpose(1,2,0)
44
+ image = Image.fromarray(x)
45
+ image.save(args.rec_path)
46
+
47
+
48
+ if __name__ == '__main__':
49
+ parser = argparse.ArgumentParser()
50
+ parser.add_argument('--image-path', type=str, default='')
51
+ parser.add_argument('--rec-path', type=str, default='')
52
+ parser.add_argument('--ckpt', type=str, default='')
53
+ parser.add_argument('--resolution', type=int, default=336)
54
+ parser.add_argument('--device', type=str, default='cuda')
55
+
56
+ args = parser.parse_args()
57
+ main(args)
examples/rec_imvi_vae.py ADDED
@@ -0,0 +1,159 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import random
3
+ import argparse
4
+ from typing import Optional
5
+
6
+ import cv2
7
+ import numpy as np
8
+ import numpy.typing as npt
9
+ import torch
10
+ from PIL import Image
11
+ from decord import VideoReader, cpu
12
+ from torch.nn import functional as F
13
+ from pytorchvideo.transforms import ShortSideScale
14
+ from torchvision.transforms import Lambda, Compose
15
+
16
+ import sys
17
+ sys.path.append(".")
18
+ from opensora.dataset.transform import CenterCropVideo, resize
19
+ from opensora.models.ae.videobase import CausalVAEModel
20
+
21
+
22
+ def array_to_video(image_array: npt.NDArray, fps: float = 30.0, output_file: str = 'output_video.mp4') -> None:
23
+ height, width, channels = image_array[0].shape
24
+ fourcc = cv2.VideoWriter_fourcc(*'mp4v')
25
+ video_writer = cv2.VideoWriter(output_file, fourcc, float(fps), (width, height))
26
+
27
+ for image in image_array:
28
+ image_rgb = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
29
+ video_writer.write(image_rgb)
30
+
31
+ video_writer.release()
32
+
33
+ def custom_to_video(x: torch.Tensor, fps: float = 2.0, output_file: str = 'output_video.mp4') -> None:
34
+ x = x.detach().cpu()
35
+ x = torch.clamp(x, -1, 1)
36
+ x = (x + 1) / 2
37
+ x = x.permute(1, 2, 3, 0).numpy()
38
+ x = (255*x).astype(np.uint8)
39
+ array_to_video(x, fps=fps, output_file=output_file)
40
+ return
41
+
42
+ def read_video(video_path: str, num_frames: int, sample_rate: int) -> torch.Tensor:
43
+ decord_vr = VideoReader(video_path, ctx=cpu(0))
44
+ total_frames = len(decord_vr)
45
+ sample_frames_len = sample_rate * num_frames
46
+
47
+ if total_frames > sample_frames_len:
48
+ s = random.randint(0, total_frames - sample_frames_len - 1)
49
+ s = 0
50
+ e = s + sample_frames_len
51
+ num_frames = num_frames
52
+ else:
53
+ s = 0
54
+ e = total_frames
55
+ num_frames = int(total_frames / sample_frames_len * num_frames)
56
+ print(f'sample_frames_len {sample_frames_len}, only can sample {num_frames * sample_rate}', video_path,
57
+ total_frames)
58
+
59
+
60
+ frame_id_list = np.linspace(s, e - 1, num_frames, dtype=int)
61
+ video_data = decord_vr.get_batch(frame_id_list).asnumpy()
62
+ video_data = torch.from_numpy(video_data)
63
+ video_data = video_data.permute(3, 0, 1, 2) # (T, H, W, C) -> (C, T, H, W)
64
+ return video_data
65
+
66
+
67
+ class ResizeVideo:
68
+ def __init__(
69
+ self,
70
+ size,
71
+ interpolation_mode="bilinear",
72
+ ):
73
+ self.size = size
74
+
75
+ self.interpolation_mode = interpolation_mode
76
+
77
+ def __call__(self, clip):
78
+ _, _, h, w = clip.shape
79
+ if w < h:
80
+ new_h = int(math.floor((float(h) / w) * self.size))
81
+ new_w = self.size
82
+ else:
83
+ new_h = self.size
84
+ new_w = int(math.floor((float(w) / h) * self.size))
85
+ return torch.nn.functional.interpolate(
86
+ clip, size=(new_h, new_w), mode=self.interpolation_mode, align_corners=False, antialias=True
87
+ )
88
+
89
+ def __repr__(self) -> str:
90
+ return f"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode}"
91
+
92
+ def preprocess(video_data: torch.Tensor, short_size: int = 128, crop_size: Optional[int] = None) -> torch.Tensor:
93
+
94
+ transform = Compose(
95
+ [
96
+ Lambda(lambda x: ((x / 255.0) * 2 - 1)),
97
+ ResizeVideo(size=short_size),
98
+ CenterCropVideo(crop_size) if crop_size is not None else Lambda(lambda x: x),
99
+ ]
100
+ )
101
+
102
+ video_outputs = transform(video_data)
103
+ video_outputs = torch.unsqueeze(video_outputs, 0)
104
+
105
+ return video_outputs
106
+
107
+
108
+ def main(args: argparse.Namespace):
109
+ video_path = args.video_path
110
+ num_frames = args.num_frames
111
+ resolution = args.resolution
112
+ crop_size = args.crop_size
113
+ sample_fps = args.sample_fps
114
+ sample_rate = args.sample_rate
115
+ device = args.device
116
+ vqvae = CausalVAEModel.from_pretrained(args.ckpt)
117
+ if args.enable_tiling:
118
+ vqvae.enable_tiling()
119
+ vqvae.tile_overlap_factor = args.tile_overlap_factor
120
+ vqvae.eval()
121
+ vqvae = vqvae.to(device)
122
+ vqvae = vqvae # .to(torch.float16)
123
+
124
+ with torch.no_grad():
125
+ x_vae = preprocess(read_video(video_path, num_frames, sample_rate), resolution, crop_size)
126
+ x_vae = x_vae.to(device) # b c t h w
127
+ x_vae = x_vae # .to(torch.float16)
128
+ latents = vqvae.encode(x_vae).sample() # .to(torch.float16)
129
+ video_recon = vqvae.decode(latents)
130
+
131
+ if video_recon.shape[2] == 1:
132
+ x = video_recon[0, :, 0, :, :]
133
+ x = x.squeeze()
134
+ x = x.detach().cpu().numpy()
135
+ x = np.clip(x, -1, 1)
136
+ x = (x + 1) / 2
137
+ x = (255 * x).astype(np.uint8)
138
+ x = x.transpose(1, 2, 0)
139
+ image = Image.fromarray(x)
140
+ image.save(args.rec_path.replace('mp4', 'jpg'))
141
+ else:
142
+ custom_to_video(video_recon[0], fps=sample_fps/sample_rate, output_file=args.rec_path)
143
+
144
+ if __name__ == '__main__':
145
+ parser = argparse.ArgumentParser()
146
+ parser.add_argument('--video-path', type=str, default='')
147
+ parser.add_argument('--rec-path', type=str, default='')
148
+ parser.add_argument('--ckpt', type=str, default='results/pretrained')
149
+ parser.add_argument('--sample-fps', type=int, default=30)
150
+ parser.add_argument('--resolution', type=int, default=336)
151
+ parser.add_argument('--crop-size', type=int, default=None)
152
+ parser.add_argument('--num-frames', type=int, default=100)
153
+ parser.add_argument('--sample-rate', type=int, default=1)
154
+ parser.add_argument('--device', type=str, default="cuda")
155
+ parser.add_argument('--tile_overlap_factor', type=float, default=0.25)
156
+ parser.add_argument('--enable_tiling', action='store_true')
157
+
158
+ args = parser.parse_args()
159
+ main(args)
examples/rec_video.py ADDED
@@ -0,0 +1,120 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import random
2
+ import argparse
3
+ from typing import Optional
4
+
5
+ import cv2
6
+ import imageio
7
+ import numpy as np
8
+ import numpy.typing as npt
9
+ import torch
10
+ from decord import VideoReader, cpu
11
+ from torch.nn import functional as F
12
+ from pytorchvideo.transforms import ShortSideScale
13
+ from torchvision.transforms import Lambda, Compose
14
+ from torchvision.transforms._transforms_video import RandomCropVideo
15
+
16
+ import sys
17
+ sys.path.append(".")
18
+ from opensora.models.ae import VQVAEModel
19
+
20
+
21
+ def array_to_video(image_array: npt.NDArray, fps: float = 30.0, output_file: str = 'output_video.mp4') -> None:
22
+ height, width, channels = image_array[0].shape
23
+ fourcc = cv2.VideoWriter_fourcc(*'mp4v') # type: ignore
24
+ video_writer = cv2.VideoWriter(output_file, fourcc, float(fps), (width, height))
25
+
26
+ for image in image_array:
27
+ image_rgb = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
28
+ video_writer.write(image_rgb)
29
+
30
+ video_writer.release()
31
+
32
+ def custom_to_video(x: torch.Tensor, fps: float = 2.0, output_file: str = 'output_video.mp4') -> None:
33
+ x = x.detach().cpu()
34
+ x = torch.clamp(x, -0.5, 0.5)
35
+ x = (x + 0.5)
36
+ x = x.permute(1, 2, 3, 0).numpy() # (C, T, H, W) -> (T, H, W, C)
37
+ x = (255*x).astype(np.uint8)
38
+ # array_to_video(x, fps=fps, output_file=output_file)
39
+ imageio.mimwrite(output_file, x, fps=fps, quality=9)
40
+ return
41
+
42
+ def read_video(video_path: str, num_frames: int, sample_rate: int) -> torch.Tensor:
43
+ decord_vr = VideoReader(video_path, ctx=cpu(0))
44
+ total_frames = len(decord_vr)
45
+ sample_frames_len = sample_rate * num_frames
46
+
47
+ if total_frames > sample_frames_len:
48
+ s = random.randint(0, total_frames - sample_frames_len - 1)
49
+ e = s + sample_frames_len
50
+ num_frames = num_frames
51
+ else:
52
+ s = 0
53
+ e = total_frames
54
+ num_frames = int(total_frames / sample_frames_len * num_frames)
55
+ print(f'sample_frames_len {sample_frames_len}, only can sample {num_frames * sample_rate}', video_path,
56
+ total_frames)
57
+
58
+
59
+ frame_id_list = np.linspace(s, e - 1, num_frames, dtype=int)
60
+ video_data = decord_vr.get_batch(frame_id_list).asnumpy()
61
+ video_data = torch.from_numpy(video_data)
62
+ video_data = video_data.permute(3, 0, 1, 2) # (T, H, W, C) -> (C, T, H, W)
63
+ return video_data
64
+
65
+ def preprocess(video_data: torch.Tensor, short_size: int = 128, crop_size: Optional[int] = None) -> torch.Tensor:
66
+
67
+ transform = Compose(
68
+ [
69
+ # UniformTemporalSubsample(num_frames),
70
+ Lambda(lambda x: ((x / 255.0) - 0.5)),
71
+ # NormalizeVideo(mean=OPENAI_DATASET_MEAN, std=OPENAI_DATASET_STD),
72
+ ShortSideScale(size=short_size),
73
+ RandomCropVideo(size=crop_size) if crop_size is not None else Lambda(lambda x: x),
74
+ # RandomHorizontalFlipVideo(p=0.5),
75
+ ]
76
+ )
77
+
78
+ video_outputs = transform(video_data)
79
+ video_outputs = torch.unsqueeze(video_outputs, 0)
80
+
81
+ return video_outputs
82
+
83
+
84
+ def main(args: argparse.Namespace):
85
+ video_path = args.video_path
86
+ num_frames = args.num_frames
87
+ resolution = args.resolution
88
+ crop_size = args.crop_size
89
+ sample_fps = args.sample_fps
90
+ sample_rate = args.sample_rate
91
+ device = torch.device('cuda')
92
+ if args.ckpt in ['bair_stride4x2x2', 'ucf101_stride4x4x4', 'kinetics_stride4x4x4', 'kinetics_stride2x4x4']:
93
+ vqvae = VQVAEModel.download_and_load_model(args.ckpt)
94
+ else:
95
+ vqvae = VQVAEModel.load_from_checkpoint(args.ckpt)
96
+ vqvae.eval()
97
+ vqvae = vqvae.to(device)
98
+
99
+ with torch.no_grad():
100
+ x_vae = preprocess(read_video(video_path, num_frames, sample_rate), resolution, crop_size)
101
+ x_vae = x_vae.to(device)
102
+ encodings, embeddings = vqvae.encode(x_vae, include_embeddings=True)
103
+ video_recon = vqvae.decode(encodings)
104
+
105
+ # custom_to_video(x_vae[0], fps=sample_fps/sample_rate, output_file='origin_input.mp4')
106
+ custom_to_video(video_recon[0], fps=sample_fps/sample_rate, output_file=args.rec_path)
107
+
108
+
109
+ if __name__ == '__main__':
110
+ parser = argparse.ArgumentParser()
111
+ parser.add_argument('--video-path', type=str, default='')
112
+ parser.add_argument('--rec-path', type=str, default='')
113
+ parser.add_argument('--ckpt', type=str, default='ucf101_stride4x4x4')
114
+ parser.add_argument('--sample-fps', type=int, default=30)
115
+ parser.add_argument('--resolution', type=int, default=336)
116
+ parser.add_argument('--crop-size', type=int, default=None)
117
+ parser.add_argument('--num-frames', type=int, default=100)
118
+ parser.add_argument('--sample-rate', type=int, default=1)
119
+ args = parser.parse_args()
120
+ main(args)
examples/rec_video_ae.py ADDED
@@ -0,0 +1,120 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import random
2
+ import argparse
3
+ from typing import Optional
4
+
5
+ import cv2
6
+ import imageio
7
+ import numpy as np
8
+ import numpy.typing as npt
9
+ import torch
10
+ from decord import VideoReader, cpu
11
+ from torch.nn import functional as F
12
+ from pytorchvideo.transforms import ShortSideScale
13
+ from torchvision.transforms import Lambda, Compose
14
+ from torchvision.transforms._transforms_video import RandomCropVideo
15
+
16
+ import sys
17
+ sys.path.append(".")
18
+ from opensora.models.ae import VQVAEModel
19
+
20
+
21
+ def array_to_video(image_array: npt.NDArray, fps: float = 30.0, output_file: str = 'output_video.mp4') -> None:
22
+ height, width, channels = image_array[0].shape
23
+ fourcc = cv2.VideoWriter_fourcc(*'mp4v') # type: ignore
24
+ video_writer = cv2.VideoWriter(output_file, fourcc, float(fps), (width, height))
25
+
26
+ for image in image_array:
27
+ image_rgb = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
28
+ video_writer.write(image_rgb)
29
+
30
+ video_writer.release()
31
+
32
+ def custom_to_video(x: torch.Tensor, fps: float = 2.0, output_file: str = 'output_video.mp4') -> None:
33
+ x = x.detach().cpu()
34
+ x = torch.clamp(x, -0.5, 0.5)
35
+ x = (x + 0.5)
36
+ x = x.permute(1, 2, 3, 0).numpy() # (C, T, H, W) -> (T, H, W, C)
37
+ x = (255*x).astype(np.uint8)
38
+ # array_to_video(x, fps=fps, output_file=output_file)
39
+ imageio.mimwrite(output_file, x, fps=fps, quality=9)
40
+ return
41
+
42
+ def read_video(video_path: str, num_frames: int, sample_rate: int) -> torch.Tensor:
43
+ decord_vr = VideoReader(video_path, ctx=cpu(0))
44
+ total_frames = len(decord_vr)
45
+ sample_frames_len = sample_rate * num_frames
46
+
47
+ if total_frames > sample_frames_len:
48
+ s = random.randint(0, total_frames - sample_frames_len - 1)
49
+ e = s + sample_frames_len
50
+ num_frames = num_frames
51
+ else:
52
+ s = 0
53
+ e = total_frames
54
+ num_frames = int(total_frames / sample_frames_len * num_frames)
55
+ print(f'sample_frames_len {sample_frames_len}, only can sample {num_frames * sample_rate}', video_path,
56
+ total_frames)
57
+
58
+
59
+ frame_id_list = np.linspace(s, e - 1, num_frames, dtype=int)
60
+ video_data = decord_vr.get_batch(frame_id_list).asnumpy()
61
+ video_data = torch.from_numpy(video_data)
62
+ video_data = video_data.permute(3, 0, 1, 2) # (T, H, W, C) -> (C, T, H, W)
63
+ return video_data
64
+
65
+ def preprocess(video_data: torch.Tensor, short_size: int = 128, crop_size: Optional[int] = None) -> torch.Tensor:
66
+
67
+ transform = Compose(
68
+ [
69
+ # UniformTemporalSubsample(num_frames),
70
+ Lambda(lambda x: ((x / 255.0) - 0.5)),
71
+ # NormalizeVideo(mean=OPENAI_DATASET_MEAN, std=OPENAI_DATASET_STD),
72
+ ShortSideScale(size=short_size),
73
+ RandomCropVideo(size=crop_size) if crop_size is not None else Lambda(lambda x: x),
74
+ # RandomHorizontalFlipVideo(p=0.5),
75
+ ]
76
+ )
77
+
78
+ video_outputs = transform(video_data)
79
+ video_outputs = torch.unsqueeze(video_outputs, 0)
80
+
81
+ return video_outputs
82
+
83
+
84
+ def main(args: argparse.Namespace):
85
+ video_path = args.video_path
86
+ num_frames = args.num_frames
87
+ resolution = args.resolution
88
+ crop_size = args.crop_size
89
+ sample_fps = args.sample_fps
90
+ sample_rate = args.sample_rate
91
+ device = torch.device('cuda')
92
+ if args.ckpt in ['bair_stride4x2x2', 'ucf101_stride4x4x4', 'kinetics_stride4x4x4', 'kinetics_stride2x4x4']:
93
+ vqvae = VQVAEModel.download_and_load_model(args.ckpt)
94
+ else:
95
+ vqvae = VQVAEModel.load_from_checkpoint(args.ckpt)
96
+ vqvae.eval()
97
+ vqvae = vqvae.to(device)
98
+
99
+ with torch.no_grad():
100
+ x_vae = preprocess(read_video(video_path, num_frames, sample_rate), resolution, crop_size)
101
+ x_vae = x_vae.to(device)
102
+ encodings, embeddings = vqvae.encode(x_vae, include_embeddings=True)
103
+ video_recon = vqvae.decode(encodings)
104
+
105
+ # custom_to_video(x_vae[0], fps=sample_fps/sample_rate, output_file='origin_input.mp4')
106
+ custom_to_video(video_recon[0], fps=sample_fps/sample_rate, output_file=args.rec_path)
107
+
108
+
109
+ if __name__ == '__main__':
110
+ parser = argparse.ArgumentParser()
111
+ parser.add_argument('--video-path', type=str, default='')
112
+ parser.add_argument('--rec-path', type=str, default='')
113
+ parser.add_argument('--ckpt', type=str, default='ucf101_stride4x4x4')
114
+ parser.add_argument('--sample-fps', type=int, default=30)
115
+ parser.add_argument('--resolution', type=int, default=336)
116
+ parser.add_argument('--crop-size', type=int, default=None)
117
+ parser.add_argument('--num-frames', type=int, default=100)
118
+ parser.add_argument('--sample-rate', type=int, default=1)
119
+ args = parser.parse_args()
120
+ main(args)
examples/rec_video_vae.py ADDED
@@ -0,0 +1,274 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import random
2
+ import argparse
3
+ import cv2
4
+ from tqdm import tqdm
5
+ import numpy as np
6
+ import numpy.typing as npt
7
+ import torch
8
+ from decord import VideoReader, cpu
9
+ from torch.nn import functional as F
10
+ from pytorchvideo.transforms import ShortSideScale
11
+ from torchvision.transforms import Lambda, Compose
12
+ from torchvision.transforms._transforms_video import CenterCropVideo
13
+ import sys
14
+ from torch.utils.data import Dataset, DataLoader, Subset
15
+ import os
16
+
17
+ sys.path.append(".")
18
+ from opensora.models.ae.videobase import CausalVAEModel
19
+ import torch.nn as nn
20
+
21
+ def array_to_video(
22
+ image_array: npt.NDArray, fps: float = 30.0, output_file: str = "output_video.mp4"
23
+ ) -> None:
24
+ height, width, channels = image_array[0].shape
25
+ fourcc = cv2.VideoWriter_fourcc(*"mp4v")
26
+ video_writer = cv2.VideoWriter(output_file, fourcc, float(fps), (width, height))
27
+
28
+ for image in image_array:
29
+ image_rgb = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
30
+ video_writer.write(image_rgb)
31
+
32
+ video_writer.release()
33
+
34
+
35
+ def custom_to_video(
36
+ x: torch.Tensor, fps: float = 2.0, output_file: str = "output_video.mp4"
37
+ ) -> None:
38
+ x = x.detach().cpu()
39
+ x = torch.clamp(x, -1, 1)
40
+ x = (x + 1) / 2
41
+ x = x.permute(1, 2, 3, 0).float().numpy()
42
+ x = (255 * x).astype(np.uint8)
43
+ array_to_video(x, fps=fps, output_file=output_file)
44
+ return
45
+
46
+
47
+ def read_video(video_path: str, num_frames: int, sample_rate: int) -> torch.Tensor:
48
+ decord_vr = VideoReader(video_path, ctx=cpu(0), num_threads=8)
49
+ total_frames = len(decord_vr)
50
+ sample_frames_len = sample_rate * num_frames
51
+
52
+ if total_frames > sample_frames_len:
53
+ s = 0
54
+ e = s + sample_frames_len
55
+ num_frames = num_frames
56
+ else:
57
+ s = 0
58
+ e = total_frames
59
+ num_frames = int(total_frames / sample_frames_len * num_frames)
60
+ print(
61
+ f"sample_frames_len {sample_frames_len}, only can sample {num_frames * sample_rate}",
62
+ video_path,
63
+ total_frames,
64
+ )
65
+
66
+ frame_id_list = np.linspace(s, e - 1, num_frames, dtype=int)
67
+ video_data = decord_vr.get_batch(frame_id_list).asnumpy()
68
+ video_data = torch.from_numpy(video_data)
69
+ video_data = video_data.permute(3, 0, 1, 2) # (T, H, W, C) -> (C, T, H, W)
70
+ return video_data
71
+
72
+
73
+ class RealVideoDataset(Dataset):
74
+ def __init__(
75
+ self,
76
+ real_video_dir,
77
+ num_frames,
78
+ sample_rate=1,
79
+ crop_size=None,
80
+ resolution=128,
81
+ ) -> None:
82
+ super().__init__()
83
+ self.real_video_files = self._combine_without_prefix(real_video_dir)
84
+ self.num_frames = num_frames
85
+ self.sample_rate = sample_rate
86
+ self.crop_size = crop_size
87
+ self.short_size = resolution
88
+
89
+ def __len__(self):
90
+ return len(self.real_video_files)
91
+
92
+ def __getitem__(self, index):
93
+ if index >= len(self):
94
+ raise IndexError
95
+ real_video_file = self.real_video_files[index]
96
+ real_video_tensor = self._load_video(real_video_file)
97
+ video_name = os.path.basename(real_video_file)
98
+ return {'video': real_video_tensor, 'file_name': video_name }
99
+
100
+ def _load_video(self, video_path):
101
+ num_frames = self.num_frames
102
+ sample_rate = self.sample_rate
103
+ decord_vr = VideoReader(video_path, ctx=cpu(0))
104
+ total_frames = len(decord_vr)
105
+ sample_frames_len = sample_rate * num_frames
106
+
107
+ if total_frames > sample_frames_len:
108
+ s = 0
109
+ e = s + sample_frames_len
110
+ num_frames = num_frames
111
+ else:
112
+ s = 0
113
+ e = total_frames
114
+ num_frames = int(total_frames / sample_frames_len * num_frames)
115
+ print(
116
+ f"sample_frames_len {sample_frames_len}, only can sample {num_frames * sample_rate}",
117
+ video_path,
118
+ total_frames,
119
+ )
120
+
121
+ frame_id_list = np.linspace(s, e - 1, num_frames, dtype=int)
122
+ video_data = decord_vr.get_batch(frame_id_list).asnumpy()
123
+ video_data = torch.from_numpy(video_data)
124
+ video_data = video_data.permute(3, 0, 1, 2)
125
+ return _preprocess(
126
+ video_data, short_size=self.short_size, crop_size=self.crop_size
127
+ )
128
+
129
+ def _combine_without_prefix(self, folder_path, prefix="."):
130
+ folder = []
131
+ for name in os.listdir(folder_path):
132
+ if name[0] == prefix:
133
+ continue
134
+ folder.append(os.path.join(folder_path, name))
135
+ folder.sort()
136
+ return folder
137
+
138
+ def resize(x, resolution):
139
+ height, width = x.shape[-2:]
140
+ aspect_ratio = width / height
141
+ if width <= height:
142
+ new_width = resolution
143
+ new_height = int(resolution / aspect_ratio)
144
+ else:
145
+ new_height = resolution
146
+ new_width = int(resolution * aspect_ratio)
147
+ resized_x = F.interpolate(x, size=(new_height, new_width), mode='bilinear', align_corners=True, antialias=True)
148
+ return resized_x
149
+
150
+ def _preprocess(video_data, short_size=128, crop_size=None):
151
+ transform = Compose(
152
+ [
153
+ Lambda(lambda x: ((x / 255.0) * 2 - 1)),
154
+ Lambda(lambda x: resize(x, short_size)),
155
+ (
156
+ CenterCropVideo(crop_size=crop_size)
157
+ if crop_size is not None
158
+ else Lambda(lambda x: x)
159
+ ),
160
+ ]
161
+ )
162
+ video_outputs = transform(video_data)
163
+ video_outputs = _format_video_shape(video_outputs)
164
+ return video_outputs
165
+
166
+
167
+ def _format_video_shape(video, time_compress=4, spatial_compress=8):
168
+ time = video.shape[1]
169
+ height = video.shape[2]
170
+ width = video.shape[3]
171
+ new_time = (
172
+ (time - (time - 1) % time_compress)
173
+ if (time - 1) % time_compress != 0
174
+ else time
175
+ )
176
+ new_height = (
177
+ (height - (height) % spatial_compress)
178
+ if height % spatial_compress != 0
179
+ else height
180
+ )
181
+ new_width = (
182
+ (width - (width) % spatial_compress) if width % spatial_compress != 0 else width
183
+ )
184
+ return video[:, :new_time, :new_height, :new_width]
185
+
186
+
187
+ @torch.no_grad()
188
+ def main(args: argparse.Namespace):
189
+ real_video_dir = args.real_video_dir
190
+ generated_video_dir = args.generated_video_dir
191
+ ckpt = args.ckpt
192
+ sample_rate = args.sample_rate
193
+ resolution = args.resolution
194
+ crop_size = args.crop_size
195
+ num_frames = args.num_frames
196
+ sample_rate = args.sample_rate
197
+ device = args.device
198
+ sample_fps = args.sample_fps
199
+ batch_size = args.batch_size
200
+ num_workers = args.num_workers
201
+ subset_size = args.subset_size
202
+
203
+ if not os.path.exists(args.generated_video_dir):
204
+ os.makedirs(args.generated_video_dir, exist_ok=True)
205
+
206
+ data_type = torch.bfloat16
207
+
208
+ # ---- Load Model ----
209
+ device = args.device
210
+ vqvae = CausalVAEModel.from_pretrained(args.ckpt)
211
+ vqvae = vqvae.to(device).to(data_type)
212
+ if args.enable_tiling:
213
+ vqvae.enable_tiling()
214
+ vqvae.tile_overlap_factor = args.tile_overlap_factor
215
+ # ---- Load Model ----
216
+
217
+ # ---- Prepare Dataset ----
218
+ dataset = RealVideoDataset(
219
+ real_video_dir=real_video_dir,
220
+ num_frames=num_frames,
221
+ sample_rate=sample_rate,
222
+ crop_size=crop_size,
223
+ resolution=resolution,
224
+ )
225
+
226
+ if subset_size:
227
+ indices = range(subset_size)
228
+ dataset = Subset(dataset, indices=indices)
229
+
230
+ dataloader = DataLoader(
231
+ dataset, batch_size=batch_size, pin_memory=True, num_workers=num_workers
232
+ )
233
+ # ---- Prepare Dataset
234
+
235
+ # ---- Inference ----
236
+ for batch in tqdm(dataloader):
237
+ x, file_names = batch['video'], batch['file_name']
238
+ x = x.to(device=device, dtype=data_type) # b c t h w
239
+ latents = vqvae.encode(x).sample().to(data_type)
240
+ video_recon = vqvae.decode(latents)
241
+ for idx, video in enumerate(video_recon):
242
+ output_path = os.path.join(generated_video_dir, file_names[idx])
243
+ if args.output_origin:
244
+ os.makedirs(os.path.join(generated_video_dir, "origin/"), exist_ok=True)
245
+ origin_output_path = os.path.join(generated_video_dir, "origin/", file_names[idx])
246
+ custom_to_video(
247
+ x[idx], fps=sample_fps / sample_rate, output_file=origin_output_path
248
+ )
249
+ custom_to_video(
250
+ video, fps=sample_fps / sample_rate, output_file=output_path
251
+ )
252
+ # ---- Inference ----
253
+
254
+ if __name__ == "__main__":
255
+ parser = argparse.ArgumentParser()
256
+ parser.add_argument("--real_video_dir", type=str, default="")
257
+ parser.add_argument("--generated_video_dir", type=str, default="")
258
+ parser.add_argument("--ckpt", type=str, default="")
259
+ parser.add_argument("--sample_fps", type=int, default=30)
260
+ parser.add_argument("--resolution", type=int, default=336)
261
+ parser.add_argument("--crop_size", type=int, default=None)
262
+ parser.add_argument("--num_frames", type=int, default=17)
263
+ parser.add_argument("--sample_rate", type=int, default=1)
264
+ parser.add_argument("--batch_size", type=int, default=1)
265
+ parser.add_argument("--num_workers", type=int, default=8)
266
+ parser.add_argument("--subset_size", type=int, default=None)
267
+ parser.add_argument("--tile_overlap_factor", type=float, default=0.25)
268
+ parser.add_argument('--enable_tiling', action='store_true')
269
+ parser.add_argument('--output_origin', action='store_true')
270
+ parser.add_argument("--device", type=str, default="cuda")
271
+
272
+ args = parser.parse_args()
273
+ main(args)
274
+
opensora/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ #
opensora/dataset/__init__.py ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torchvision.transforms import Compose
2
+ from transformers import AutoTokenizer
3
+
4
+ from .feature_datasets import T2V_Feature_dataset, T2V_T5_Feature_dataset
5
+ from torchvision import transforms
6
+ from torchvision.transforms import Lambda
7
+
8
+ from .landscope import Landscope
9
+ from .t2v_datasets import T2V_dataset
10
+ from .transform import ToTensorVideo, TemporalRandomCrop, RandomHorizontalFlipVideo, CenterCropResizeVideo
11
+ from .ucf101 import UCF101
12
+ from .sky_datasets import Sky
13
+
14
+ ae_norm = {
15
+ 'CausalVAEModel_4x8x8': Lambda(lambda x: 2. * x - 1.),
16
+ 'CausalVQVAEModel_4x4x4': Lambda(lambda x: x - 0.5),
17
+ 'CausalVQVAEModel_4x8x8': Lambda(lambda x: x - 0.5),
18
+ 'VQVAEModel_4x4x4': Lambda(lambda x: x - 0.5),
19
+ 'VQVAEModel_4x8x8': Lambda(lambda x: x - 0.5),
20
+ "bair_stride4x2x2": Lambda(lambda x: x - 0.5),
21
+ "ucf101_stride4x4x4": Lambda(lambda x: x - 0.5),
22
+ "kinetics_stride4x4x4": Lambda(lambda x: x - 0.5),
23
+ "kinetics_stride2x4x4": Lambda(lambda x: x - 0.5),
24
+ 'stabilityai/sd-vae-ft-mse': transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
25
+ 'stabilityai/sd-vae-ft-ema': transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
26
+ 'vqgan_imagenet_f16_1024': Lambda(lambda x: 2. * x - 1.),
27
+ 'vqgan_imagenet_f16_16384': Lambda(lambda x: 2. * x - 1.),
28
+ 'vqgan_gumbel_f8': Lambda(lambda x: 2. * x - 1.),
29
+
30
+ }
31
+ ae_denorm = {
32
+ 'CausalVAEModel_4x8x8': lambda x: (x + 1.) / 2.,
33
+ 'CausalVQVAEModel_4x4x4': lambda x: x + 0.5,
34
+ 'CausalVQVAEModel_4x8x8': lambda x: x + 0.5,
35
+ 'VQVAEModel_4x4x4': lambda x: x + 0.5,
36
+ 'VQVAEModel_4x8x8': lambda x: x + 0.5,
37
+ "bair_stride4x2x2": lambda x: x + 0.5,
38
+ "ucf101_stride4x4x4": lambda x: x + 0.5,
39
+ "kinetics_stride4x4x4": lambda x: x + 0.5,
40
+ "kinetics_stride2x4x4": lambda x: x + 0.5,
41
+ 'stabilityai/sd-vae-ft-mse': lambda x: 0.5 * x + 0.5,
42
+ 'stabilityai/sd-vae-ft-ema': lambda x: 0.5 * x + 0.5,
43
+ 'vqgan_imagenet_f16_1024': lambda x: (x + 1.) / 2.,
44
+ 'vqgan_imagenet_f16_16384': lambda x: (x + 1.) / 2.,
45
+ 'vqgan_gumbel_f8': lambda x: (x + 1.) / 2.,
46
+ }
47
+
48
+ def getdataset(args):
49
+ temporal_sample = TemporalRandomCrop(args.num_frames * args.sample_rate) # 16 x
50
+ norm_fun = ae_norm[args.ae]
51
+ if args.dataset == 'ucf101':
52
+ transform = Compose(
53
+ [
54
+ ToTensorVideo(), # TCHW
55
+ CenterCropResizeVideo(size=args.max_image_size),
56
+ RandomHorizontalFlipVideo(p=0.5),
57
+ norm_fun,
58
+ ]
59
+ )
60
+ return UCF101(args, transform=transform, temporal_sample=temporal_sample)
61
+ if args.dataset == 'landscope':
62
+ transform = Compose(
63
+ [
64
+ ToTensorVideo(), # TCHW
65
+ CenterCropResizeVideo(size=args.max_image_size),
66
+ RandomHorizontalFlipVideo(p=0.5),
67
+ norm_fun,
68
+ ]
69
+ )
70
+ return Landscope(args, transform=transform, temporal_sample=temporal_sample)
71
+ elif args.dataset == 'sky':
72
+ transform = transforms.Compose([
73
+ ToTensorVideo(),
74
+ CenterCropResizeVideo(args.max_image_size),
75
+ RandomHorizontalFlipVideo(p=0.5),
76
+ norm_fun
77
+ ])
78
+ return Sky(args, transform=transform, temporal_sample=temporal_sample)
79
+ elif args.dataset == 't2v':
80
+ transform = transforms.Compose([
81
+ ToTensorVideo(),
82
+ CenterCropResizeVideo(args.max_image_size),
83
+ RandomHorizontalFlipVideo(p=0.5),
84
+ norm_fun
85
+ ])
86
+ tokenizer = AutoTokenizer.from_pretrained(args.text_encoder_name, cache_dir='./cache_dir')
87
+ return T2V_dataset(args, transform=transform, temporal_sample=temporal_sample, tokenizer=tokenizer)
88
+ elif args.dataset == 't2v_feature':
89
+ return T2V_Feature_dataset(args, temporal_sample)
90
+ elif args.dataset == 't2v_t5_feature':
91
+ transform = transforms.Compose([
92
+ ToTensorVideo(),
93
+ CenterCropResizeVideo(args.max_image_size),
94
+ RandomHorizontalFlipVideo(p=0.5),
95
+ norm_fun
96
+ ])
97
+ return T2V_T5_Feature_dataset(args, transform, temporal_sample)
98
+ else:
99
+ raise NotImplementedError(args.dataset)
opensora/dataset/extract_feature_dataset.py ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from glob import glob
3
+
4
+ import numpy as np
5
+ import torch
6
+ import torchvision
7
+ from PIL import Image
8
+ from torch.utils.data import Dataset
9
+
10
+ from opensora.utils.dataset_utils import DecordInit, is_image_file
11
+
12
+
13
+ class ExtractVideo2Feature(Dataset):
14
+ def __init__(self, args, transform):
15
+ self.data_path = args.data_path
16
+ self.transform = transform
17
+ self.v_decoder = DecordInit()
18
+ self.samples = list(glob(f'{self.data_path}'))
19
+
20
+ def __len__(self):
21
+ return len(self.samples)
22
+
23
+ def __getitem__(self, idx):
24
+ video_path = self.samples[idx]
25
+ video = self.decord_read(video_path)
26
+ video = self.transform(video) # T C H W -> T C H W
27
+ return video, video_path
28
+
29
+ def tv_read(self, path):
30
+ vframes, aframes, info = torchvision.io.read_video(filename=path, pts_unit='sec', output_format='TCHW')
31
+ total_frames = len(vframes)
32
+ frame_indice = list(range(total_frames))
33
+ video = vframes[frame_indice]
34
+ return video
35
+
36
+ def decord_read(self, path):
37
+ decord_vr = self.v_decoder(path)
38
+ total_frames = len(decord_vr)
39
+ frame_indice = list(range(total_frames))
40
+ video_data = decord_vr.get_batch(frame_indice).asnumpy()
41
+ video_data = torch.from_numpy(video_data)
42
+ video_data = video_data.permute(0, 3, 1, 2) # (T, H, W, C) -> (T C H W)
43
+ return video_data
44
+
45
+
46
+
47
+ class ExtractImage2Feature(Dataset):
48
+ def __init__(self, args, transform):
49
+ self.data_path = args.data_path
50
+ self.transform = transform
51
+ self.data_all = list(glob(f'{self.data_path}'))
52
+
53
+ def __len__(self):
54
+ return len(self.data_all)
55
+
56
+ def __getitem__(self, index):
57
+ path = self.data_all[index]
58
+ video_frame = torch.as_tensor(np.array(Image.open(path), dtype=np.uint8, copy=True)).unsqueeze(0)
59
+ video_frame = video_frame.permute(0, 3, 1, 2)
60
+ video_frame = self.transform(video_frame) # T C H W
61
+ # video_frame = video_frame.transpose(0, 1) # T C H W -> C T H W
62
+
63
+ return video_frame, path
64
+
opensora/dataset/feature_datasets.py ADDED
@@ -0,0 +1,213 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import os
3
+ import torch
4
+ import random
5
+ import torch.utils.data as data
6
+
7
+ import numpy as np
8
+ from glob import glob
9
+ from PIL import Image
10
+ from torch.utils.data import Dataset
11
+ from tqdm import tqdm
12
+
13
+ from opensora.dataset.transform import center_crop, RandomCropVideo
14
+ from opensora.utils.dataset_utils import DecordInit
15
+
16
+
17
+ class T2V_Feature_dataset(Dataset):
18
+ def __init__(self, args, temporal_sample):
19
+
20
+ self.video_folder = args.video_folder
21
+ self.num_frames = args.video_length
22
+ self.temporal_sample = temporal_sample
23
+
24
+ print('Building dataset...')
25
+ if os.path.exists('samples_430k.json'):
26
+ with open('samples_430k.json', 'r') as f:
27
+ self.samples = json.load(f)
28
+ else:
29
+ self.samples = self._make_dataset()
30
+ with open('samples_430k.json', 'w') as f:
31
+ json.dump(self.samples, f, indent=2)
32
+
33
+ self.use_image_num = args.use_image_num
34
+ self.use_img_from_vid = args.use_img_from_vid
35
+ if self.use_image_num != 0 and not self.use_img_from_vid:
36
+ self.img_cap_list = self.get_img_cap_list()
37
+
38
+ def _make_dataset(self):
39
+ all_mp4 = list(glob(os.path.join(self.video_folder, '**', '*.mp4'), recursive=True))
40
+ # all_mp4 = all_mp4[:1000]
41
+ samples = []
42
+ for i in tqdm(all_mp4):
43
+ video_id = os.path.basename(i).split('.')[0]
44
+ ae = os.path.split(i)[0].replace('data_split_tt', 'lb_causalvideovae444_feature')
45
+ ae = os.path.join(ae, f'{video_id}_causalvideovae444.npy')
46
+ if not os.path.exists(ae):
47
+ continue
48
+
49
+ t5 = os.path.split(i)[0].replace('data_split_tt', 'lb_t5_feature')
50
+ cond_list = []
51
+ cond_llava = os.path.join(t5, f'{video_id}_t5_llava_fea.npy')
52
+ mask_llava = os.path.join(t5, f'{video_id}_t5_llava_mask.npy')
53
+ if os.path.exists(cond_llava) and os.path.exists(mask_llava):
54
+ llava = dict(cond=cond_llava, mask=mask_llava)
55
+ cond_list.append(llava)
56
+ cond_sharegpt4v = os.path.join(t5, f'{video_id}_t5_sharegpt4v_fea.npy')
57
+ mask_sharegpt4v = os.path.join(t5, f'{video_id}_t5_sharegpt4v_mask.npy')
58
+ if os.path.exists(cond_sharegpt4v) and os.path.exists(mask_sharegpt4v):
59
+ sharegpt4v = dict(cond=cond_sharegpt4v, mask=mask_sharegpt4v)
60
+ cond_list.append(sharegpt4v)
61
+ if len(cond_list) > 0:
62
+ sample = dict(ae=ae, t5=cond_list)
63
+ samples.append(sample)
64
+ return samples
65
+
66
+ def __len__(self):
67
+ return len(self.samples)
68
+
69
+ def __getitem__(self, idx):
70
+ # try:
71
+ sample = self.samples[idx]
72
+ ae, t5 = sample['ae'], sample['t5']
73
+ t5 = random.choice(t5)
74
+ video_origin = np.load(ae)[0] # C T H W
75
+ _, total_frames, _, _ = video_origin.shape
76
+ # Sampling video frames
77
+ start_frame_ind, end_frame_ind = self.temporal_sample(total_frames)
78
+ assert end_frame_ind - start_frame_ind >= self.num_frames
79
+ select_video_idx = np.linspace(start_frame_ind, end_frame_ind - 1, num=self.num_frames, dtype=int) # start, stop, num=50
80
+ # print('select_video_idx', total_frames, select_video_idx)
81
+ video = video_origin[:, select_video_idx] # C num_frames H W
82
+ video = torch.from_numpy(video)
83
+
84
+ cond = torch.from_numpy(np.load(t5['cond']))[0] # L
85
+ cond_mask = torch.from_numpy(np.load(t5['mask']))[0] # L D
86
+
87
+ if self.use_image_num != 0 and self.use_img_from_vid:
88
+ select_image_idx = np.random.randint(0, total_frames, self.use_image_num)
89
+ # print('select_image_idx', total_frames, self.use_image_num, select_image_idx)
90
+ images = video_origin[:, select_image_idx] # c, num_img, h, w
91
+ images = torch.from_numpy(images)
92
+ video = torch.cat([video, images], dim=1) # c, num_frame+num_img, h, w
93
+ cond = torch.stack([cond] * (1+self.use_image_num)) # 1+self.use_image_num, l
94
+ cond_mask = torch.stack([cond_mask] * (1+self.use_image_num)) # 1+self.use_image_num, l
95
+ elif self.use_image_num != 0 and not self.use_img_from_vid:
96
+ images, captions = self.img_cap_list[idx]
97
+ raise NotImplementedError
98
+ else:
99
+ pass
100
+
101
+ return video, cond, cond_mask
102
+ # except Exception as e:
103
+ # print(f'Error with {e}, {sample}')
104
+ # return self.__getitem__(random.randint(0, self.__len__() - 1))
105
+
106
+ def get_img_cap_list(self):
107
+ raise NotImplementedError
108
+
109
+
110
+
111
+
112
+ class T2V_T5_Feature_dataset(Dataset):
113
+ def __init__(self, args, transform, temporal_sample):
114
+
115
+ self.video_folder = args.video_folder
116
+ self.num_frames = args.num_frames
117
+ self.transform = transform
118
+ self.temporal_sample = temporal_sample
119
+ self.v_decoder = DecordInit()
120
+
121
+ print('Building dataset...')
122
+ if os.path.exists('samples_430k.json'):
123
+ with open('samples_430k.json', 'r') as f:
124
+ self.samples = json.load(f)
125
+ self.samples = [dict(ae=i['ae'].replace('lb_causalvideovae444_feature', 'data_split_1024').replace('_causalvideovae444.npy', '.mp4'), t5=i['t5']) for i in self.samples]
126
+ else:
127
+ self.samples = self._make_dataset()
128
+ with open('samples_430k.json', 'w') as f:
129
+ json.dump(self.samples, f, indent=2)
130
+
131
+ self.use_image_num = args.use_image_num
132
+ self.use_img_from_vid = args.use_img_from_vid
133
+ if self.use_image_num != 0 and not self.use_img_from_vid:
134
+ self.img_cap_list = self.get_img_cap_list()
135
+
136
+ def _make_dataset(self):
137
+ all_mp4 = list(glob(os.path.join(self.video_folder, '**', '*.mp4'), recursive=True))
138
+ # all_mp4 = all_mp4[:1000]
139
+ samples = []
140
+ for i in tqdm(all_mp4):
141
+ video_id = os.path.basename(i).split('.')[0]
142
+ # ae = os.path.split(i)[0].replace('data_split', 'lb_causalvideovae444_feature')
143
+ # ae = os.path.join(ae, f'{video_id}_causalvideovae444.npy')
144
+ ae = i
145
+ if not os.path.exists(ae):
146
+ continue
147
+
148
+ t5 = os.path.split(i)[0].replace('data_split_1024', 'lb_t5_feature')
149
+ cond_list = []
150
+ cond_llava = os.path.join(t5, f'{video_id}_t5_llava_fea.npy')
151
+ mask_llava = os.path.join(t5, f'{video_id}_t5_llava_mask.npy')
152
+ if os.path.exists(cond_llava) and os.path.exists(mask_llava):
153
+ llava = dict(cond=cond_llava, mask=mask_llava)
154
+ cond_list.append(llava)
155
+ cond_sharegpt4v = os.path.join(t5, f'{video_id}_t5_sharegpt4v_fea.npy')
156
+ mask_sharegpt4v = os.path.join(t5, f'{video_id}_t5_sharegpt4v_mask.npy')
157
+ if os.path.exists(cond_sharegpt4v) and os.path.exists(mask_sharegpt4v):
158
+ sharegpt4v = dict(cond=cond_sharegpt4v, mask=mask_sharegpt4v)
159
+ cond_list.append(sharegpt4v)
160
+ if len(cond_list) > 0:
161
+ sample = dict(ae=ae, t5=cond_list)
162
+ samples.append(sample)
163
+ return samples
164
+
165
+ def __len__(self):
166
+ return len(self.samples)
167
+
168
+ def __getitem__(self, idx):
169
+ try:
170
+ sample = self.samples[idx]
171
+ ae, t5 = sample['ae'], sample['t5']
172
+ t5 = random.choice(t5)
173
+
174
+ video = self.decord_read(ae)
175
+ video = self.transform(video) # T C H W -> T C H W
176
+ video = video.transpose(0, 1) # T C H W -> C T H W
177
+ total_frames = video.shape[1]
178
+ cond = torch.from_numpy(np.load(t5['cond']))[0] # L
179
+ cond_mask = torch.from_numpy(np.load(t5['mask']))[0] # L D
180
+
181
+ if self.use_image_num != 0 and self.use_img_from_vid:
182
+ select_image_idx = np.random.randint(0, total_frames, self.use_image_num)
183
+ # print('select_image_idx', total_frames, self.use_image_num, select_image_idx)
184
+ images = video.numpy()[:, select_image_idx] # c, num_img, h, w
185
+ images = torch.from_numpy(images)
186
+ video = torch.cat([video, images], dim=1) # c, num_frame+num_img, h, w
187
+ cond = torch.stack([cond] * (1+self.use_image_num)) # 1+self.use_image_num, l
188
+ cond_mask = torch.stack([cond_mask] * (1+self.use_image_num)) # 1+self.use_image_num, l
189
+ elif self.use_image_num != 0 and not self.use_img_from_vid:
190
+ images, captions = self.img_cap_list[idx]
191
+ raise NotImplementedError
192
+ else:
193
+ pass
194
+
195
+ return video, cond, cond_mask
196
+ except Exception as e:
197
+ print(f'Error with {e}, {sample}')
198
+ return self.__getitem__(random.randint(0, self.__len__() - 1))
199
+
200
+ def decord_read(self, path):
201
+ decord_vr = self.v_decoder(path)
202
+ total_frames = len(decord_vr)
203
+ # Sampling video frames
204
+ start_frame_ind, end_frame_ind = self.temporal_sample(total_frames)
205
+ # assert end_frame_ind - start_frame_ind >= self.num_frames
206
+ frame_indice = np.linspace(start_frame_ind, end_frame_ind - 1, self.num_frames, dtype=int)
207
+ video_data = decord_vr.get_batch(frame_indice).asnumpy()
208
+ video_data = torch.from_numpy(video_data)
209
+ video_data = video_data.permute(0, 3, 1, 2) # (T, H, W, C) -> (T C H W)
210
+ return video_data
211
+
212
+ def get_img_cap_list(self):
213
+ raise NotImplementedError
opensora/dataset/landscope.py ADDED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import os
3
+ from glob import glob
4
+
5
+ import decord
6
+ import numpy as np
7
+ import torch
8
+ import torchvision
9
+ from decord import VideoReader, cpu
10
+ from torch.utils.data import Dataset
11
+ from torchvision.transforms import Compose, Lambda, ToTensor
12
+ from torchvision.transforms._transforms_video import NormalizeVideo, RandomCropVideo, RandomHorizontalFlipVideo
13
+ from pytorchvideo.transforms import ApplyTransformToKey, ShortSideScale, UniformTemporalSubsample
14
+ from torch.nn import functional as F
15
+ import random
16
+
17
+ from opensora.utils.dataset_utils import DecordInit
18
+
19
+
20
+ class Landscope(Dataset):
21
+ def __init__(self, args, transform, temporal_sample):
22
+ self.data_path = args.data_path
23
+ self.num_frames = args.num_frames
24
+ self.transform = transform
25
+ self.temporal_sample = temporal_sample
26
+ self.v_decoder = DecordInit()
27
+
28
+ self.samples = self._make_dataset()
29
+ self.use_image_num = args.use_image_num
30
+ self.use_img_from_vid = args.use_img_from_vid
31
+ if self.use_image_num != 0 and not self.use_img_from_vid:
32
+ self.img_cap_list = self.get_img_cap_list()
33
+
34
+
35
+ def _make_dataset(self):
36
+ paths = list(glob(os.path.join(self.data_path, '**', '*.mp4'), recursive=True))
37
+
38
+ return paths
39
+
40
+ def __len__(self):
41
+ return len(self.samples)
42
+
43
+ def __getitem__(self, idx):
44
+ video_path = self.samples[idx]
45
+ try:
46
+ video = self.tv_read(video_path)
47
+ video = self.transform(video) # T C H W -> T C H W
48
+ video = video.transpose(0, 1) # T C H W -> C T H W
49
+ if self.use_image_num != 0 and self.use_img_from_vid:
50
+ select_image_idx = np.linspace(0, self.num_frames - 1, self.use_image_num, dtype=int)
51
+ assert self.num_frames >= self.use_image_num
52
+ images = video[:, select_image_idx] # c, num_img, h, w
53
+ video = torch.cat([video, images], dim=1) # c, num_frame+num_img, h, w
54
+ elif self.use_image_num != 0 and not self.use_img_from_vid:
55
+ images, captions = self.img_cap_list[idx]
56
+ raise NotImplementedError
57
+ else:
58
+ pass
59
+ return video, 1
60
+ except Exception as e:
61
+ print(f'Error with {e}, {video_path}')
62
+ return self.__getitem__(random.randint(0, self.__len__()-1))
63
+
64
+ def tv_read(self, path):
65
+ vframes, aframes, info = torchvision.io.read_video(filename=path, pts_unit='sec', output_format='TCHW')
66
+ total_frames = len(vframes)
67
+
68
+ # Sampling video frames
69
+ start_frame_ind, end_frame_ind = self.temporal_sample(total_frames)
70
+ # assert end_frame_ind - start_frame_ind >= self.num_frames
71
+ frame_indice = np.linspace(start_frame_ind, end_frame_ind - 1, self.num_frames, dtype=int)
72
+ video = vframes[frame_indice] # (T, C, H, W)
73
+
74
+ return video
75
+
76
+ def decord_read(self, path):
77
+ decord_vr = self.v_decoder(path)
78
+ total_frames = len(decord_vr)
79
+ # Sampling video frames
80
+ start_frame_ind, end_frame_ind = self.temporal_sample(total_frames)
81
+ # assert end_frame_ind - start_frame_ind >= self.num_frames
82
+ frame_indice = np.linspace(start_frame_ind, end_frame_ind - 1, self.num_frames, dtype=int)
83
+
84
+ video_data = decord_vr.get_batch(frame_indice).asnumpy()
85
+ video_data = torch.from_numpy(video_data)
86
+ video_data = video_data.permute(0, 3, 1, 2) # (T, H, W, C) -> (T C H W)
87
+ return video_data
88
+
89
+ def get_img_cap_list(self):
90
+ raise NotImplementedError
opensora/dataset/sky_datasets.py ADDED
@@ -0,0 +1,128 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import random
4
+ import torch.utils.data as data
5
+
6
+ import numpy as np
7
+
8
+ from PIL import Image
9
+
10
+ from opensora.utils.dataset_utils import is_image_file
11
+
12
+
13
+ class Sky(data.Dataset):
14
+ def __init__(self, args, transform, temporal_sample=None, train=True):
15
+
16
+ self.args = args
17
+ self.data_path = args.data_path
18
+ self.transform = transform
19
+ self.temporal_sample = temporal_sample
20
+ self.num_frames = self.args.num_frames
21
+ self.sample_rate = self.args.sample_rate
22
+ self.data_all = self.load_video_frames(self.data_path)
23
+ self.use_image_num = args.use_image_num
24
+ self.use_img_from_vid = args.use_img_from_vid
25
+ if self.use_image_num != 0 and not self.use_img_from_vid:
26
+ self.img_cap_list = self.get_img_cap_list()
27
+
28
+ def __getitem__(self, index):
29
+
30
+ vframes = self.data_all[index]
31
+ total_frames = len(vframes)
32
+
33
+ # Sampling video frames
34
+ start_frame_ind, end_frame_ind = self.temporal_sample(total_frames)
35
+ assert end_frame_ind - start_frame_ind >= self.num_frames
36
+ frame_indice = np.linspace(start_frame_ind, end_frame_ind-1, num=self.num_frames, dtype=int) # start, stop, num=50
37
+
38
+ select_video_frames = vframes[frame_indice[0]: frame_indice[-1]+1: self.sample_rate]
39
+
40
+ video_frames = []
41
+ for path in select_video_frames:
42
+ video_frame = torch.as_tensor(np.array(Image.open(path), dtype=np.uint8, copy=True)).unsqueeze(0)
43
+ video_frames.append(video_frame)
44
+ video_clip = torch.cat(video_frames, dim=0).permute(0, 3, 1, 2)
45
+ video_clip = self.transform(video_clip)
46
+ video_clip = video_clip.transpose(0, 1) # T C H W -> C T H W
47
+
48
+ if self.use_image_num != 0 and self.use_img_from_vid:
49
+ select_image_idx = np.linspace(0, self.num_frames - 1, self.use_image_num, dtype=int)
50
+ assert self.num_frames >= self.use_image_num
51
+ images = video_clip[:, select_image_idx] # c, num_img, h, w
52
+ video_clip = torch.cat([video_clip, images], dim=1) # c, num_frame+num_img, h, w
53
+ elif self.use_image_num != 0 and not self.use_img_from_vid:
54
+ images, captions = self.img_cap_list[index]
55
+ raise NotImplementedError
56
+ else:
57
+ pass
58
+
59
+ return video_clip, 1
60
+
61
+ def __len__(self):
62
+ return self.video_num
63
+
64
+ def load_video_frames(self, dataroot):
65
+ data_all = []
66
+ frame_list = os.walk(dataroot)
67
+ for _, meta in enumerate(frame_list):
68
+ root = meta[0]
69
+ try:
70
+ frames = [i for i in meta[2] if is_image_file(i)]
71
+ frames = sorted(frames, key=lambda item: int(item.split('.')[0].split('_')[-1]))
72
+ except:
73
+ pass
74
+ # print(meta[0]) # root
75
+ # print(meta[2]) # files
76
+ frames = [os.path.join(root, item) for item in frames if is_image_file(item)]
77
+ if len(frames) > max(0, self.num_frames * self.sample_rate): # need all > (16 * frame-interval) videos
78
+ # if len(frames) >= max(0, self.target_video_len): # need all > 16 frames videos
79
+ data_all.append(frames)
80
+ self.video_num = len(data_all)
81
+ return data_all
82
+
83
+ def get_img_cap_list(self):
84
+ raise NotImplementedError
85
+
86
+ if __name__ == '__main__':
87
+
88
+ import argparse
89
+ import torchvision
90
+ import video_transforms
91
+ import torch.utils.data as data
92
+
93
+ from torchvision import transforms
94
+ from torchvision.utils import save_image
95
+
96
+
97
+ parser = argparse.ArgumentParser()
98
+ parser.add_argument("--num_frames", type=int, default=16)
99
+ parser.add_argument("--frame_interval", type=int, default=4)
100
+ parser.add_argument("--data-path", type=str, default="/path/to/datasets/sky_timelapse/sky_train/")
101
+ config = parser.parse_args()
102
+
103
+
104
+ target_video_len = config.num_frames
105
+
106
+ temporal_sample = video_transforms.TemporalRandomCrop(target_video_len * config.frame_interval)
107
+ trans = transforms.Compose([
108
+ video_transforms.ToTensorVideo(),
109
+ # video_transforms.CenterCropVideo(256),
110
+ video_transforms.CenterCropResizeVideo(256),
111
+ # video_transforms.RandomHorizontalFlipVideo(),
112
+ transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)
113
+ ])
114
+
115
+ taichi_dataset = Sky(config, transform=trans, temporal_sample=temporal_sample)
116
+ print(len(taichi_dataset))
117
+ taichi_dataloader = data.DataLoader(dataset=taichi_dataset, batch_size=1, shuffle=False, num_workers=1)
118
+
119
+ for i, video_data in enumerate(taichi_dataloader):
120
+ print(video_data['video'].shape)
121
+
122
+ # print(video_data.dtype)
123
+ # for i in range(target_video_len):
124
+ # save_image(video_data[0][i], os.path.join('./test_data', '%04d.png' % i), normalize=True, value_range=(-1, 1))
125
+
126
+ # video_ = ((video_data[0] * 0.5 + 0.5) * 255).add_(0.5).clamp_(0, 255).to(dtype=torch.uint8).cpu().permute(0, 2, 3, 1)
127
+ # torchvision.io.write_video('./test_data' + 'test.mp4', video_, fps=8)
128
+ # exit()
opensora/dataset/t2v_datasets.py ADDED
@@ -0,0 +1,111 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import os, io, csv, math, random
3
+ import numpy as np
4
+ import torchvision
5
+ from einops import rearrange
6
+ from decord import VideoReader
7
+
8
+ import torch
9
+ import torchvision.transforms as transforms
10
+ from torch.utils.data.dataset import Dataset
11
+ from tqdm import tqdm
12
+
13
+ from opensora.utils.dataset_utils import DecordInit
14
+ from opensora.utils.utils import text_preprocessing
15
+
16
+
17
+
18
+ class T2V_dataset(Dataset):
19
+ def __init__(self, args, transform, temporal_sample, tokenizer):
20
+
21
+ # with open(args.data_path, 'r') as csvfile:
22
+ # self.samples = list(csv.DictReader(csvfile))
23
+ self.video_folder = args.video_folder
24
+ self.num_frames = args.num_frames
25
+ self.transform = transform
26
+ self.temporal_sample = temporal_sample
27
+ self.tokenizer = tokenizer
28
+ self.model_max_length = args.model_max_length
29
+ self.v_decoder = DecordInit()
30
+
31
+ with open(args.data_path, 'r') as f:
32
+ self.samples = json.load(f)
33
+ self.use_image_num = args.use_image_num
34
+ self.use_img_from_vid = args.use_img_from_vid
35
+ if self.use_image_num != 0 and not self.use_img_from_vid:
36
+ self.img_cap_list = self.get_img_cap_list()
37
+
38
+ def __len__(self):
39
+ return len(self.samples)
40
+
41
+ def __getitem__(self, idx):
42
+ try:
43
+ # video = torch.randn(3, 16, 128, 128)
44
+ # input_ids = torch.ones(1, 120).to(torch.long).squeeze(0)
45
+ # cond_mask = torch.cat([torch.ones(1, 60).to(torch.long), torch.ones(1, 60).to(torch.long)], dim=1).squeeze(0)
46
+ # return video, input_ids, cond_mask
47
+ video_path = self.samples[idx]['path']
48
+ video = self.decord_read(video_path)
49
+ video = self.transform(video) # T C H W -> T C H W
50
+ video = video.transpose(0, 1) # T C H W -> C T H W
51
+ text = self.samples[idx]['cap'][0]
52
+
53
+ text = text_preprocessing(text)
54
+ text_tokens_and_mask = self.tokenizer(
55
+ text,
56
+ max_length=self.model_max_length,
57
+ padding='max_length',
58
+ truncation=True,
59
+ return_attention_mask=True,
60
+ add_special_tokens=True,
61
+ return_tensors='pt'
62
+ )
63
+ input_ids = text_tokens_and_mask['input_ids'].squeeze(0)
64
+ cond_mask = text_tokens_and_mask['attention_mask'].squeeze(0)
65
+
66
+ if self.use_image_num != 0 and self.use_img_from_vid:
67
+ select_image_idx = np.linspace(0, self.num_frames-1, self.use_image_num, dtype=int)
68
+ assert self.num_frames >= self.use_image_num
69
+ images = video[:, select_image_idx] # c, num_img, h, w
70
+ video = torch.cat([video, images], dim=1) # c, num_frame+num_img, h, w
71
+ input_ids = torch.stack([input_ids] * (1+self.use_image_num)) # 1+self.use_image_num, l
72
+ cond_mask = torch.stack([cond_mask] * (1+self.use_image_num)) # 1+self.use_image_num, l
73
+ elif self.use_image_num != 0 and not self.use_img_from_vid:
74
+ images, captions = self.img_cap_list[idx]
75
+ raise NotImplementedError
76
+ else:
77
+ pass
78
+
79
+ return video, input_ids, cond_mask
80
+ except Exception as e:
81
+ print(f'Error with {e}, {self.samples[idx]}')
82
+ return self.__getitem__(random.randint(0, self.__len__() - 1))
83
+
84
+ def tv_read(self, path):
85
+ vframes, aframes, info = torchvision.io.read_video(filename=path, pts_unit='sec', output_format='TCHW')
86
+ total_frames = len(vframes)
87
+
88
+ # Sampling video frames
89
+ start_frame_ind, end_frame_ind = self.temporal_sample(total_frames)
90
+ # assert end_frame_ind - start_frame_ind >= self.num_frames
91
+ frame_indice = np.linspace(start_frame_ind, end_frame_ind - 1, self.num_frames, dtype=int)
92
+
93
+ video = vframes[frame_indice] # (T, C, H, W)
94
+
95
+ return video
96
+
97
+ def decord_read(self, path):
98
+ decord_vr = self.v_decoder(path)
99
+ total_frames = len(decord_vr)
100
+ # Sampling video frames
101
+ start_frame_ind, end_frame_ind = self.temporal_sample(total_frames)
102
+ # assert end_frame_ind - start_frame_ind >= self.num_frames
103
+ frame_indice = np.linspace(start_frame_ind, end_frame_ind - 1, self.num_frames, dtype=int)
104
+
105
+ video_data = decord_vr.get_batch(frame_indice).asnumpy()
106
+ video_data = torch.from_numpy(video_data)
107
+ video_data = video_data.permute(0, 3, 1, 2) # (T, H, W, C) -> (T C H W)
108
+ return video_data
109
+
110
+ def get_img_cap_list(self):
111
+ raise NotImplementedError
opensora/dataset/transform.py ADDED
@@ -0,0 +1,489 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import random
3
+ import numbers
4
+ from torchvision.transforms import RandomCrop, RandomResizedCrop
5
+
6
+
7
+ def _is_tensor_video_clip(clip):
8
+ if not torch.is_tensor(clip):
9
+ raise TypeError("clip should be Tensor. Got %s" % type(clip))
10
+
11
+ if not clip.ndimension() == 4:
12
+ raise ValueError("clip should be 4D. Got %dD" % clip.dim())
13
+
14
+ return True
15
+
16
+
17
+ def center_crop_arr(pil_image, image_size):
18
+ """
19
+ Center cropping implementation from ADM.
20
+ https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126
21
+ """
22
+ while min(*pil_image.size) >= 2 * image_size:
23
+ pil_image = pil_image.resize(
24
+ tuple(x // 2 for x in pil_image.size), resample=Image.BOX
25
+ )
26
+
27
+ scale = image_size / min(*pil_image.size)
28
+ pil_image = pil_image.resize(
29
+ tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC
30
+ )
31
+
32
+ arr = np.array(pil_image)
33
+ crop_y = (arr.shape[0] - image_size) // 2
34
+ crop_x = (arr.shape[1] - image_size) // 2
35
+ return Image.fromarray(arr[crop_y: crop_y + image_size, crop_x: crop_x + image_size])
36
+
37
+
38
+ def crop(clip, i, j, h, w):
39
+ """
40
+ Args:
41
+ clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
42
+ """
43
+ if len(clip.size()) != 4:
44
+ raise ValueError("clip should be a 4D tensor")
45
+ return clip[..., i: i + h, j: j + w]
46
+
47
+
48
+ def resize(clip, target_size, interpolation_mode):
49
+ if len(target_size) != 2:
50
+ raise ValueError(f"target size should be tuple (height, width), instead got {target_size}")
51
+ return torch.nn.functional.interpolate(clip, size=target_size, mode=interpolation_mode, align_corners=True, antialias=True)
52
+
53
+
54
+ def resize_scale(clip, target_size, interpolation_mode):
55
+ if len(target_size) != 2:
56
+ raise ValueError(f"target size should be tuple (height, width), instead got {target_size}")
57
+ H, W = clip.size(-2), clip.size(-1)
58
+ scale_ = target_size[0] / min(H, W)
59
+ return torch.nn.functional.interpolate(clip, scale_factor=scale_, mode=interpolation_mode, align_corners=True, antialias=True)
60
+
61
+
62
+ def resized_crop(clip, i, j, h, w, size, interpolation_mode="bilinear"):
63
+ """
64
+ Do spatial cropping and resizing to the video clip
65
+ Args:
66
+ clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
67
+ i (int): i in (i,j) i.e coordinates of the upper left corner.
68
+ j (int): j in (i,j) i.e coordinates of the upper left corner.
69
+ h (int): Height of the cropped region.
70
+ w (int): Width of the cropped region.
71
+ size (tuple(int, int)): height and width of resized clip
72
+ Returns:
73
+ clip (torch.tensor): Resized and cropped clip. Size is (T, C, H, W)
74
+ """
75
+ if not _is_tensor_video_clip(clip):
76
+ raise ValueError("clip should be a 4D torch.tensor")
77
+ clip = crop(clip, i, j, h, w)
78
+ clip = resize(clip, size, interpolation_mode)
79
+ return clip
80
+
81
+
82
+ def center_crop(clip, crop_size):
83
+ if not _is_tensor_video_clip(clip):
84
+ raise ValueError("clip should be a 4D torch.tensor")
85
+ h, w = clip.size(-2), clip.size(-1)
86
+ th, tw = crop_size
87
+ if h < th or w < tw:
88
+ raise ValueError("height and width must be no smaller than crop_size")
89
+
90
+ i = int(round((h - th) / 2.0))
91
+ j = int(round((w - tw) / 2.0))
92
+ return crop(clip, i, j, th, tw)
93
+
94
+
95
+ def center_crop_using_short_edge(clip):
96
+ if not _is_tensor_video_clip(clip):
97
+ raise ValueError("clip should be a 4D torch.tensor")
98
+ h, w = clip.size(-2), clip.size(-1)
99
+ if h < w:
100
+ th, tw = h, h
101
+ i = 0
102
+ j = int(round((w - tw) / 2.0))
103
+ else:
104
+ th, tw = w, w
105
+ i = int(round((h - th) / 2.0))
106
+ j = 0
107
+ return crop(clip, i, j, th, tw)
108
+
109
+
110
+ def random_shift_crop(clip):
111
+ '''
112
+ Slide along the long edge, with the short edge as crop size
113
+ '''
114
+ if not _is_tensor_video_clip(clip):
115
+ raise ValueError("clip should be a 4D torch.tensor")
116
+ h, w = clip.size(-2), clip.size(-1)
117
+
118
+ if h <= w:
119
+ long_edge = w
120
+ short_edge = h
121
+ else:
122
+ long_edge = h
123
+ short_edge = w
124
+
125
+ th, tw = short_edge, short_edge
126
+
127
+ i = torch.randint(0, h - th + 1, size=(1,)).item()
128
+ j = torch.randint(0, w - tw + 1, size=(1,)).item()
129
+ return crop(clip, i, j, th, tw)
130
+
131
+
132
+ def to_tensor(clip):
133
+ """
134
+ Convert tensor data type from uint8 to float, divide value by 255.0 and
135
+ permute the dimensions of clip tensor
136
+ Args:
137
+ clip (torch.tensor, dtype=torch.uint8): Size is (T, C, H, W)
138
+ Return:
139
+ clip (torch.tensor, dtype=torch.float): Size is (T, C, H, W)
140
+ """
141
+ _is_tensor_video_clip(clip)
142
+ if not clip.dtype == torch.uint8:
143
+ raise TypeError("clip tensor should have data type uint8. Got %s" % str(clip.dtype))
144
+ # return clip.float().permute(3, 0, 1, 2) / 255.0
145
+ return clip.float() / 255.0
146
+
147
+
148
+ def normalize(clip, mean, std, inplace=False):
149
+ """
150
+ Args:
151
+ clip (torch.tensor): Video clip to be normalized. Size is (T, C, H, W)
152
+ mean (tuple): pixel RGB mean. Size is (3)
153
+ std (tuple): pixel standard deviation. Size is (3)
154
+ Returns:
155
+ normalized clip (torch.tensor): Size is (T, C, H, W)
156
+ """
157
+ if not _is_tensor_video_clip(clip):
158
+ raise ValueError("clip should be a 4D torch.tensor")
159
+ if not inplace:
160
+ clip = clip.clone()
161
+ mean = torch.as_tensor(mean, dtype=clip.dtype, device=clip.device)
162
+ # print(mean)
163
+ std = torch.as_tensor(std, dtype=clip.dtype, device=clip.device)
164
+ clip.sub_(mean[:, None, None, None]).div_(std[:, None, None, None])
165
+ return clip
166
+
167
+
168
+ def hflip(clip):
169
+ """
170
+ Args:
171
+ clip (torch.tensor): Video clip to be normalized. Size is (T, C, H, W)
172
+ Returns:
173
+ flipped clip (torch.tensor): Size is (T, C, H, W)
174
+ """
175
+ if not _is_tensor_video_clip(clip):
176
+ raise ValueError("clip should be a 4D torch.tensor")
177
+ return clip.flip(-1)
178
+
179
+
180
+ class RandomCropVideo:
181
+ def __init__(self, size):
182
+ if isinstance(size, numbers.Number):
183
+ self.size = (int(size), int(size))
184
+ else:
185
+ self.size = size
186
+
187
+ def __call__(self, clip):
188
+ """
189
+ Args:
190
+ clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
191
+ Returns:
192
+ torch.tensor: randomly cropped video clip.
193
+ size is (T, C, OH, OW)
194
+ """
195
+ i, j, h, w = self.get_params(clip)
196
+ return crop(clip, i, j, h, w)
197
+
198
+ def get_params(self, clip):
199
+ h, w = clip.shape[-2:]
200
+ th, tw = self.size
201
+
202
+ if h < th or w < tw:
203
+ raise ValueError(f"Required crop size {(th, tw)} is larger than input image size {(h, w)}")
204
+
205
+ if w == tw and h == th:
206
+ return 0, 0, h, w
207
+
208
+ i = torch.randint(0, h - th + 1, size=(1,)).item()
209
+ j = torch.randint(0, w - tw + 1, size=(1,)).item()
210
+
211
+ return i, j, th, tw
212
+
213
+ def __repr__(self) -> str:
214
+ return f"{self.__class__.__name__}(size={self.size})"
215
+
216
+
217
+ class CenterCropResizeVideo:
218
+ '''
219
+ First use the short side for cropping length,
220
+ center crop video, then resize to the specified size
221
+ '''
222
+
223
+ def __init__(
224
+ self,
225
+ size,
226
+ interpolation_mode="bilinear",
227
+ ):
228
+ if isinstance(size, tuple):
229
+ if len(size) != 2:
230
+ raise ValueError(f"size should be tuple (height, width), instead got {size}")
231
+ self.size = size
232
+ else:
233
+ self.size = (size, size)
234
+
235
+ self.interpolation_mode = interpolation_mode
236
+
237
+ def __call__(self, clip):
238
+ """
239
+ Args:
240
+ clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
241
+ Returns:
242
+ torch.tensor: scale resized / center cropped video clip.
243
+ size is (T, C, crop_size, crop_size)
244
+ """
245
+ clip_center_crop = center_crop_using_short_edge(clip)
246
+ clip_center_crop_resize = resize(clip_center_crop, target_size=self.size,
247
+ interpolation_mode=self.interpolation_mode)
248
+ return clip_center_crop_resize
249
+
250
+ def __repr__(self) -> str:
251
+ return f"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode}"
252
+
253
+
254
+ class UCFCenterCropVideo:
255
+ '''
256
+ First scale to the specified size in equal proportion to the short edge,
257
+ then center cropping
258
+ '''
259
+
260
+ def __init__(
261
+ self,
262
+ size,
263
+ interpolation_mode="bilinear",
264
+ ):
265
+ if isinstance(size, tuple):
266
+ if len(size) != 2:
267
+ raise ValueError(f"size should be tuple (height, width), instead got {size}")
268
+ self.size = size
269
+ else:
270
+ self.size = (size, size)
271
+
272
+ self.interpolation_mode = interpolation_mode
273
+
274
+ def __call__(self, clip):
275
+ """
276
+ Args:
277
+ clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
278
+ Returns:
279
+ torch.tensor: scale resized / center cropped video clip.
280
+ size is (T, C, crop_size, crop_size)
281
+ """
282
+ clip_resize = resize_scale(clip=clip, target_size=self.size, interpolation_mode=self.interpolation_mode)
283
+ clip_center_crop = center_crop(clip_resize, self.size)
284
+ return clip_center_crop
285
+
286
+ def __repr__(self) -> str:
287
+ return f"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode}"
288
+
289
+
290
+ class KineticsRandomCropResizeVideo:
291
+ '''
292
+ Slide along the long edge, with the short edge as crop size. And resie to the desired size.
293
+ '''
294
+
295
+ def __init__(
296
+ self,
297
+ size,
298
+ interpolation_mode="bilinear",
299
+ ):
300
+ if isinstance(size, tuple):
301
+ if len(size) != 2:
302
+ raise ValueError(f"size should be tuple (height, width), instead got {size}")
303
+ self.size = size
304
+ else:
305
+ self.size = (size, size)
306
+
307
+ self.interpolation_mode = interpolation_mode
308
+
309
+ def __call__(self, clip):
310
+ clip_random_crop = random_shift_crop(clip)
311
+ clip_resize = resize(clip_random_crop, self.size, self.interpolation_mode)
312
+ return clip_resize
313
+
314
+
315
+ class CenterCropVideo:
316
+ def __init__(
317
+ self,
318
+ size,
319
+ interpolation_mode="bilinear",
320
+ ):
321
+ if isinstance(size, tuple):
322
+ if len(size) != 2:
323
+ raise ValueError(f"size should be tuple (height, width), instead got {size}")
324
+ self.size = size
325
+ else:
326
+ self.size = (size, size)
327
+
328
+ self.interpolation_mode = interpolation_mode
329
+
330
+ def __call__(self, clip):
331
+ """
332
+ Args:
333
+ clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
334
+ Returns:
335
+ torch.tensor: center cropped video clip.
336
+ size is (T, C, crop_size, crop_size)
337
+ """
338
+ clip_center_crop = center_crop(clip, self.size)
339
+ return clip_center_crop
340
+
341
+ def __repr__(self) -> str:
342
+ return f"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode}"
343
+
344
+
345
+ class NormalizeVideo:
346
+ """
347
+ Normalize the video clip by mean subtraction and division by standard deviation
348
+ Args:
349
+ mean (3-tuple): pixel RGB mean
350
+ std (3-tuple): pixel RGB standard deviation
351
+ inplace (boolean): whether do in-place normalization
352
+ """
353
+
354
+ def __init__(self, mean, std, inplace=False):
355
+ self.mean = mean
356
+ self.std = std
357
+ self.inplace = inplace
358
+
359
+ def __call__(self, clip):
360
+ """
361
+ Args:
362
+ clip (torch.tensor): video clip must be normalized. Size is (C, T, H, W)
363
+ """
364
+ return normalize(clip, self.mean, self.std, self.inplace)
365
+
366
+ def __repr__(self) -> str:
367
+ return f"{self.__class__.__name__}(mean={self.mean}, std={self.std}, inplace={self.inplace})"
368
+
369
+
370
+ class ToTensorVideo:
371
+ """
372
+ Convert tensor data type from uint8 to float, divide value by 255.0 and
373
+ permute the dimensions of clip tensor
374
+ """
375
+
376
+ def __init__(self):
377
+ pass
378
+
379
+ def __call__(self, clip):
380
+ """
381
+ Args:
382
+ clip (torch.tensor, dtype=torch.uint8): Size is (T, C, H, W)
383
+ Return:
384
+ clip (torch.tensor, dtype=torch.float): Size is (T, C, H, W)
385
+ """
386
+ return to_tensor(clip)
387
+
388
+ def __repr__(self) -> str:
389
+ return self.__class__.__name__
390
+
391
+
392
+ class RandomHorizontalFlipVideo:
393
+ """
394
+ Flip the video clip along the horizontal direction with a given probability
395
+ Args:
396
+ p (float): probability of the clip being flipped. Default value is 0.5
397
+ """
398
+
399
+ def __init__(self, p=0.5):
400
+ self.p = p
401
+
402
+ def __call__(self, clip):
403
+ """
404
+ Args:
405
+ clip (torch.tensor): Size is (T, C, H, W)
406
+ Return:
407
+ clip (torch.tensor): Size is (T, C, H, W)
408
+ """
409
+ if random.random() < self.p:
410
+ clip = hflip(clip)
411
+ return clip
412
+
413
+ def __repr__(self) -> str:
414
+ return f"{self.__class__.__name__}(p={self.p})"
415
+
416
+
417
+ # ------------------------------------------------------------
418
+ # --------------------- Sampling ---------------------------
419
+ # ------------------------------------------------------------
420
+ class TemporalRandomCrop(object):
421
+ """Temporally crop the given frame indices at a random location.
422
+
423
+ Args:
424
+ size (int): Desired length of frames will be seen in the model.
425
+ """
426
+
427
+ def __init__(self, size):
428
+ self.size = size
429
+
430
+ def __call__(self, total_frames):
431
+ rand_end = max(0, total_frames - self.size - 1)
432
+ begin_index = random.randint(0, rand_end)
433
+ end_index = min(begin_index + self.size, total_frames)
434
+ return begin_index, end_index
435
+
436
+
437
+ if __name__ == '__main__':
438
+ from torchvision import transforms
439
+ import torchvision.io as io
440
+ import numpy as np
441
+ from torchvision.utils import save_image
442
+ import os
443
+
444
+ vframes, aframes, info = io.read_video(
445
+ filename='./v_Archery_g01_c03.avi',
446
+ pts_unit='sec',
447
+ output_format='TCHW'
448
+ )
449
+
450
+ trans = transforms.Compose([
451
+ ToTensorVideo(),
452
+ RandomHorizontalFlipVideo(),
453
+ UCFCenterCropVideo(512),
454
+ # NormalizeVideo(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
455
+ transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)
456
+ ])
457
+
458
+ target_video_len = 32
459
+ frame_interval = 1
460
+ total_frames = len(vframes)
461
+ print(total_frames)
462
+
463
+ temporal_sample = TemporalRandomCrop(target_video_len * frame_interval)
464
+
465
+ # Sampling video frames
466
+ start_frame_ind, end_frame_ind = temporal_sample(total_frames)
467
+ # print(start_frame_ind)
468
+ # print(end_frame_ind)
469
+ assert end_frame_ind - start_frame_ind >= target_video_len
470
+ frame_indice = np.linspace(start_frame_ind, end_frame_ind - 1, target_video_len, dtype=int)
471
+ print(frame_indice)
472
+
473
+ select_vframes = vframes[frame_indice]
474
+ print(select_vframes.shape)
475
+ print(select_vframes.dtype)
476
+
477
+ select_vframes_trans = trans(select_vframes)
478
+ print(select_vframes_trans.shape)
479
+ print(select_vframes_trans.dtype)
480
+
481
+ select_vframes_trans_int = ((select_vframes_trans * 0.5 + 0.5) * 255).to(dtype=torch.uint8)
482
+ print(select_vframes_trans_int.dtype)
483
+ print(select_vframes_trans_int.permute(0, 2, 3, 1).shape)
484
+
485
+ io.write_video('./test.avi', select_vframes_trans_int.permute(0, 2, 3, 1), fps=8)
486
+
487
+ for i in range(target_video_len):
488
+ save_image(select_vframes_trans[i], os.path.join('./test000', '%04d.png' % i), normalize=True,
489
+ value_range=(-1, 1))
opensora/dataset/ucf101.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import os
3
+
4
+ import decord
5
+ import numpy as np
6
+ import torch
7
+ import torchvision
8
+ from decord import VideoReader, cpu
9
+ from torch.utils.data import Dataset
10
+ from torchvision.transforms import Compose, Lambda, ToTensor
11
+ from torchvision.transforms._transforms_video import NormalizeVideo, RandomCropVideo, RandomHorizontalFlipVideo
12
+ from pytorchvideo.transforms import ApplyTransformToKey, ShortSideScale, UniformTemporalSubsample
13
+ from torch.nn import functional as F
14
+ import random
15
+
16
+ from opensora.utils.dataset_utils import DecordInit
17
+
18
+
19
+ class UCF101(Dataset):
20
+ def __init__(self, args, transform, temporal_sample):
21
+ self.data_path = args.data_path
22
+ self.num_frames = args.num_frames
23
+ self.transform = transform
24
+ self.temporal_sample = temporal_sample
25
+ self.v_decoder = DecordInit()
26
+
27
+ self.classes = sorted(os.listdir(self.data_path))
28
+ self.class_to_idx = {cls_name: idx for idx, cls_name in enumerate(self.classes)}
29
+ self.samples = self._make_dataset()
30
+
31
+
32
+ def _make_dataset(self):
33
+ dataset = []
34
+ for class_name in self.classes:
35
+ class_path = os.path.join(self.data_path, class_name)
36
+ for fname in os.listdir(class_path):
37
+ if fname.endswith('.avi'):
38
+ item = (os.path.join(class_path, fname), self.class_to_idx[class_name])
39
+ dataset.append(item)
40
+ return dataset
41
+
42
+ def __len__(self):
43
+ return len(self.samples)
44
+
45
+ def __getitem__(self, idx):
46
+ video_path, label = self.samples[idx]
47
+ try:
48
+ video = self.tv_read(video_path)
49
+ video = self.transform(video) # T C H W -> T C H W
50
+ video = video.transpose(0, 1) # T C H W -> C T H W
51
+ return video, label
52
+ except Exception as e:
53
+ print(f'Error with {e}, {video_path}')
54
+ return self.__getitem__(random.randint(0, self.__len__()-1))
55
+
56
+ def tv_read(self, path):
57
+ vframes, aframes, info = torchvision.io.read_video(filename=path, pts_unit='sec', output_format='TCHW')
58
+ total_frames = len(vframes)
59
+
60
+ # Sampling video frames
61
+ start_frame_ind, end_frame_ind = self.temporal_sample(total_frames)
62
+ # assert end_frame_ind - start_frame_ind >= self.num_frames
63
+ frame_indice = np.linspace(start_frame_ind, end_frame_ind - 1, self.num_frames, dtype=int)
64
+ video = vframes[frame_indice] # (T, C, H, W)
65
+
66
+ return video
67
+
68
+ def decord_read(self, path):
69
+ decord_vr = self.v_decoder(path)
70
+ total_frames = len(decord_vr)
71
+ # Sampling video frames
72
+ start_frame_ind, end_frame_ind = self.temporal_sample(total_frames)
73
+ # assert end_frame_ind - start_frame_ind >= self.num_frames
74
+ frame_indice = np.linspace(start_frame_ind, end_frame_ind - 1, self.num_frames, dtype=int)
75
+
76
+ video_data = decord_vr.get_batch(frame_indice).asnumpy()
77
+ video_data = torch.from_numpy(video_data)
78
+ video_data = video_data.permute(0, 3, 1, 2) # (T, H, W, C) -> (T C H W)
79
+ return video_data
80
+
opensora/eval/cal_flolpips.py ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ from tqdm import tqdm
4
+ import math
5
+ from einops import rearrange
6
+ import sys
7
+ sys.path.append(".")
8
+ from opensora.eval.flolpips.pwcnet import Network as PWCNet
9
+ from opensora.eval.flolpips.flolpips import FloLPIPS
10
+
11
+ loss_fn = FloLPIPS(net='alex', version='0.1').eval().requires_grad_(False)
12
+ flownet = PWCNet().eval().requires_grad_(False)
13
+
14
+ def trans(x):
15
+ return x
16
+
17
+
18
+ def calculate_flolpips(videos1, videos2, device):
19
+ global loss_fn, flownet
20
+
21
+ print("calculate_flowlpips...")
22
+ loss_fn = loss_fn.to(device)
23
+ flownet = flownet.to(device)
24
+
25
+ if videos1.shape != videos2.shape:
26
+ print("Warning: the shape of videos are not equal.")
27
+ min_frames = min(videos1.shape[1], videos2.shape[1])
28
+ videos1 = videos1[:, :min_frames]
29
+ videos2 = videos2[:, :min_frames]
30
+
31
+ videos1 = trans(videos1)
32
+ videos2 = trans(videos2)
33
+
34
+ flolpips_results = []
35
+ for video_num in tqdm(range(videos1.shape[0])):
36
+ video1 = videos1[video_num].to(device)
37
+ video2 = videos2[video_num].to(device)
38
+ frames_rec = video1[:-1]
39
+ frames_rec_next = video1[1:]
40
+ frames_gt = video2[:-1]
41
+ frames_gt_next = video2[1:]
42
+ t, c, h, w = frames_gt.shape
43
+ flow_gt = flownet(frames_gt, frames_gt_next)
44
+ flow_dis = flownet(frames_rec, frames_rec_next)
45
+ flow_diff = flow_gt - flow_dis
46
+ flolpips = loss_fn.forward(frames_gt, frames_rec, flow_diff, normalize=True)
47
+ flolpips_results.append(flolpips.cpu().numpy().tolist())
48
+
49
+ flolpips_results = np.array(flolpips_results) # [batch_size, num_frames]
50
+ flolpips = {}
51
+ flolpips_std = {}
52
+
53
+ for clip_timestamp in range(flolpips_results.shape[1]):
54
+ flolpips[clip_timestamp] = np.mean(flolpips_results[:,clip_timestamp], axis=-1)
55
+ flolpips_std[clip_timestamp] = np.std(flolpips_results[:,clip_timestamp], axis=-1)
56
+
57
+ result = {
58
+ "value": flolpips,
59
+ "value_std": flolpips_std,
60
+ "video_setting": video1.shape,
61
+ "video_setting_name": "time, channel, heigth, width",
62
+ "result": flolpips_results,
63
+ "details": flolpips_results.tolist()
64
+ }
65
+
66
+ return result
67
+
68
+ # test code / using example
69
+
70
+ def main():
71
+ NUMBER_OF_VIDEOS = 8
72
+ VIDEO_LENGTH = 50
73
+ CHANNEL = 3
74
+ SIZE = 64
75
+ videos1 = torch.zeros(NUMBER_OF_VIDEOS, VIDEO_LENGTH, CHANNEL, SIZE, SIZE, requires_grad=False)
76
+ videos2 = torch.zeros(NUMBER_OF_VIDEOS, VIDEO_LENGTH, CHANNEL, SIZE, SIZE, requires_grad=False)
77
+
78
+ import json
79
+ result = calculate_flolpips(videos1, videos2, "cuda:0")
80
+ print(json.dumps(result, indent=4))
81
+
82
+ if __name__ == "__main__":
83
+ main()
opensora/eval/cal_fvd.py ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ from tqdm import tqdm
4
+
5
+ def trans(x):
6
+ # if greyscale images add channel
7
+ if x.shape[-3] == 1:
8
+ x = x.repeat(1, 1, 3, 1, 1)
9
+
10
+ # permute BTCHW -> BCTHW
11
+ x = x.permute(0, 2, 1, 3, 4)
12
+
13
+ return x
14
+
15
+ def calculate_fvd(videos1, videos2, device, method='styleganv'):
16
+
17
+ if method == 'styleganv':
18
+ from fvd.styleganv.fvd import get_fvd_feats, frechet_distance, load_i3d_pretrained
19
+ elif method == 'videogpt':
20
+ from fvd.videogpt.fvd import load_i3d_pretrained
21
+ from fvd.videogpt.fvd import get_fvd_logits as get_fvd_feats
22
+ from fvd.videogpt.fvd import frechet_distance
23
+
24
+ print("calculate_fvd...")
25
+
26
+ # videos [batch_size, timestamps, channel, h, w]
27
+
28
+ assert videos1.shape == videos2.shape
29
+
30
+ i3d = load_i3d_pretrained(device=device)
31
+ fvd_results = []
32
+
33
+ # support grayscale input, if grayscale -> channel*3
34
+ # BTCHW -> BCTHW
35
+ # videos -> [batch_size, channel, timestamps, h, w]
36
+
37
+ videos1 = trans(videos1)
38
+ videos2 = trans(videos2)
39
+
40
+ fvd_results = {}
41
+
42
+ # for calculate FVD, each clip_timestamp must >= 10
43
+ for clip_timestamp in tqdm(range(10, videos1.shape[-3]+1)):
44
+
45
+ # get a video clip
46
+ # videos_clip [batch_size, channel, timestamps[:clip], h, w]
47
+ videos_clip1 = videos1[:, :, : clip_timestamp]
48
+ videos_clip2 = videos2[:, :, : clip_timestamp]
49
+
50
+ # get FVD features
51
+ feats1 = get_fvd_feats(videos_clip1, i3d=i3d, device=device)
52
+ feats2 = get_fvd_feats(videos_clip2, i3d=i3d, device=device)
53
+
54
+ # calculate FVD when timestamps[:clip]
55
+ fvd_results[clip_timestamp] = frechet_distance(feats1, feats2)
56
+
57
+ result = {
58
+ "value": fvd_results,
59
+ "video_setting": videos1.shape,
60
+ "video_setting_name": "batch_size, channel, time, heigth, width",
61
+ }
62
+
63
+ return result
64
+
65
+ # test code / using example
66
+
67
+ def main():
68
+ NUMBER_OF_VIDEOS = 8
69
+ VIDEO_LENGTH = 50
70
+ CHANNEL = 3
71
+ SIZE = 64
72
+ videos1 = torch.zeros(NUMBER_OF_VIDEOS, VIDEO_LENGTH, CHANNEL, SIZE, SIZE, requires_grad=False)
73
+ videos2 = torch.ones(NUMBER_OF_VIDEOS, VIDEO_LENGTH, CHANNEL, SIZE, SIZE, requires_grad=False)
74
+ device = torch.device("cuda")
75
+ # device = torch.device("cpu")
76
+
77
+ import json
78
+ result = calculate_fvd(videos1, videos2, device, method='videogpt')
79
+ print(json.dumps(result, indent=4))
80
+
81
+ result = calculate_fvd(videos1, videos2, device, method='styleganv')
82
+ print(json.dumps(result, indent=4))
83
+
84
+ if __name__ == "__main__":
85
+ main()
opensora/eval/cal_lpips.py ADDED
@@ -0,0 +1,97 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ from tqdm import tqdm
4
+ import math
5
+
6
+ import torch
7
+ import lpips
8
+
9
+ spatial = True # Return a spatial map of perceptual distance.
10
+
11
+ # Linearly calibrated models (LPIPS)
12
+ loss_fn = lpips.LPIPS(net='alex', spatial=spatial) # Can also set net = 'squeeze' or 'vgg'
13
+ # loss_fn = lpips.LPIPS(net='alex', spatial=spatial, lpips=False) # Can also set net = 'squeeze' or 'vgg'
14
+
15
+ def trans(x):
16
+ # if greyscale images add channel
17
+ if x.shape[-3] == 1:
18
+ x = x.repeat(1, 1, 3, 1, 1)
19
+
20
+ # value range [0, 1] -> [-1, 1]
21
+ x = x * 2 - 1
22
+
23
+ return x
24
+
25
+ def calculate_lpips(videos1, videos2, device):
26
+ # image should be RGB, IMPORTANT: normalized to [-1,1]
27
+ print("calculate_lpips...")
28
+
29
+ assert videos1.shape == videos2.shape
30
+
31
+ # videos [batch_size, timestamps, channel, h, w]
32
+
33
+ # support grayscale input, if grayscale -> channel*3
34
+ # value range [0, 1] -> [-1, 1]
35
+ videos1 = trans(videos1)
36
+ videos2 = trans(videos2)
37
+
38
+ lpips_results = []
39
+
40
+ for video_num in tqdm(range(videos1.shape[0])):
41
+ # get a video
42
+ # video [timestamps, channel, h, w]
43
+ video1 = videos1[video_num]
44
+ video2 = videos2[video_num]
45
+
46
+ lpips_results_of_a_video = []
47
+ for clip_timestamp in range(len(video1)):
48
+ # get a img
49
+ # img [timestamps[x], channel, h, w]
50
+ # img [channel, h, w] tensor
51
+
52
+ img1 = video1[clip_timestamp].unsqueeze(0).to(device)
53
+ img2 = video2[clip_timestamp].unsqueeze(0).to(device)
54
+
55
+ loss_fn.to(device)
56
+
57
+ # calculate lpips of a video
58
+ lpips_results_of_a_video.append(loss_fn.forward(img1, img2).mean().detach().cpu().tolist())
59
+ lpips_results.append(lpips_results_of_a_video)
60
+
61
+ lpips_results = np.array(lpips_results)
62
+
63
+ lpips = {}
64
+ lpips_std = {}
65
+
66
+ for clip_timestamp in range(len(video1)):
67
+ lpips[clip_timestamp] = np.mean(lpips_results[:,clip_timestamp])
68
+ lpips_std[clip_timestamp] = np.std(lpips_results[:,clip_timestamp])
69
+
70
+
71
+ result = {
72
+ "value": lpips,
73
+ "value_std": lpips_std,
74
+ "video_setting": video1.shape,
75
+ "video_setting_name": "time, channel, heigth, width",
76
+ }
77
+
78
+ return result
79
+
80
+ # test code / using example
81
+
82
+ def main():
83
+ NUMBER_OF_VIDEOS = 8
84
+ VIDEO_LENGTH = 50
85
+ CHANNEL = 3
86
+ SIZE = 64
87
+ videos1 = torch.zeros(NUMBER_OF_VIDEOS, VIDEO_LENGTH, CHANNEL, SIZE, SIZE, requires_grad=False)
88
+ videos2 = torch.ones(NUMBER_OF_VIDEOS, VIDEO_LENGTH, CHANNEL, SIZE, SIZE, requires_grad=False)
89
+ device = torch.device("cuda")
90
+ # device = torch.device("cpu")
91
+
92
+ import json
93
+ result = calculate_lpips(videos1, videos2, device)
94
+ print(json.dumps(result, indent=4))
95
+
96
+ if __name__ == "__main__":
97
+ main()
opensora/eval/cal_psnr.py ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ from tqdm import tqdm
4
+ import math
5
+
6
+ def img_psnr(img1, img2):
7
+ # [0,1]
8
+ # compute mse
9
+ # mse = np.mean((img1-img2)**2)
10
+ mse = np.mean((img1 / 1.0 - img2 / 1.0) ** 2)
11
+ # compute psnr
12
+ if mse < 1e-10:
13
+ return 100
14
+ psnr = 20 * math.log10(1 / math.sqrt(mse))
15
+ return psnr
16
+
17
+ def trans(x):
18
+ return x
19
+
20
+ def calculate_psnr(videos1, videos2):
21
+ print("calculate_psnr...")
22
+
23
+ # videos [batch_size, timestamps, channel, h, w]
24
+
25
+ assert videos1.shape == videos2.shape
26
+
27
+ videos1 = trans(videos1)
28
+ videos2 = trans(videos2)
29
+
30
+ psnr_results = []
31
+
32
+ for video_num in tqdm(range(videos1.shape[0])):
33
+ # get a video
34
+ # video [timestamps, channel, h, w]
35
+ video1 = videos1[video_num]
36
+ video2 = videos2[video_num]
37
+
38
+ psnr_results_of_a_video = []
39
+ for clip_timestamp in range(len(video1)):
40
+ # get a img
41
+ # img [timestamps[x], channel, h, w]
42
+ # img [channel, h, w] numpy
43
+
44
+ img1 = video1[clip_timestamp].numpy()
45
+ img2 = video2[clip_timestamp].numpy()
46
+
47
+ # calculate psnr of a video
48
+ psnr_results_of_a_video.append(img_psnr(img1, img2))
49
+
50
+ psnr_results.append(psnr_results_of_a_video)
51
+
52
+ psnr_results = np.array(psnr_results) # [batch_size, num_frames]
53
+ psnr = {}
54
+ psnr_std = {}
55
+
56
+ for clip_timestamp in range(len(video1)):
57
+ psnr[clip_timestamp] = np.mean(psnr_results[:,clip_timestamp])
58
+ psnr_std[clip_timestamp] = np.std(psnr_results[:,clip_timestamp])
59
+
60
+ result = {
61
+ "value": psnr,
62
+ "value_std": psnr_std,
63
+ "video_setting": video1.shape,
64
+ "video_setting_name": "time, channel, heigth, width",
65
+ }
66
+
67
+ return result
68
+
69
+ # test code / using example
70
+
71
+ def main():
72
+ NUMBER_OF_VIDEOS = 8
73
+ VIDEO_LENGTH = 50
74
+ CHANNEL = 3
75
+ SIZE = 64
76
+ videos1 = torch.zeros(NUMBER_OF_VIDEOS, VIDEO_LENGTH, CHANNEL, SIZE, SIZE, requires_grad=False)
77
+ videos2 = torch.zeros(NUMBER_OF_VIDEOS, VIDEO_LENGTH, CHANNEL, SIZE, SIZE, requires_grad=False)
78
+
79
+ import json
80
+ result = calculate_psnr(videos1, videos2)
81
+ print(json.dumps(result, indent=4))
82
+
83
+ if __name__ == "__main__":
84
+ main()
opensora/eval/cal_ssim.py ADDED
@@ -0,0 +1,113 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ from tqdm import tqdm
4
+ import cv2
5
+
6
+ def ssim(img1, img2):
7
+ C1 = 0.01 ** 2
8
+ C2 = 0.03 ** 2
9
+ img1 = img1.astype(np.float64)
10
+ img2 = img2.astype(np.float64)
11
+ kernel = cv2.getGaussianKernel(11, 1.5)
12
+ window = np.outer(kernel, kernel.transpose())
13
+ mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5] # valid
14
+ mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5]
15
+ mu1_sq = mu1 ** 2
16
+ mu2_sq = mu2 ** 2
17
+ mu1_mu2 = mu1 * mu2
18
+ sigma1_sq = cv2.filter2D(img1 ** 2, -1, window)[5:-5, 5:-5] - mu1_sq
19
+ sigma2_sq = cv2.filter2D(img2 ** 2, -1, window)[5:-5, 5:-5] - mu2_sq
20
+ sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2
21
+ ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) *
22
+ (sigma1_sq + sigma2_sq + C2))
23
+ return ssim_map.mean()
24
+
25
+
26
+ def calculate_ssim_function(img1, img2):
27
+ # [0,1]
28
+ # ssim is the only metric extremely sensitive to gray being compared to b/w
29
+ if not img1.shape == img2.shape:
30
+ raise ValueError('Input images must have the same dimensions.')
31
+ if img1.ndim == 2:
32
+ return ssim(img1, img2)
33
+ elif img1.ndim == 3:
34
+ if img1.shape[0] == 3:
35
+ ssims = []
36
+ for i in range(3):
37
+ ssims.append(ssim(img1[i], img2[i]))
38
+ return np.array(ssims).mean()
39
+ elif img1.shape[0] == 1:
40
+ return ssim(np.squeeze(img1), np.squeeze(img2))
41
+ else:
42
+ raise ValueError('Wrong input image dimensions.')
43
+
44
+ def trans(x):
45
+ return x
46
+
47
+ def calculate_ssim(videos1, videos2):
48
+ print("calculate_ssim...")
49
+
50
+ # videos [batch_size, timestamps, channel, h, w]
51
+
52
+ assert videos1.shape == videos2.shape
53
+
54
+ videos1 = trans(videos1)
55
+ videos2 = trans(videos2)
56
+
57
+ ssim_results = []
58
+
59
+ for video_num in tqdm(range(videos1.shape[0])):
60
+ # get a video
61
+ # video [timestamps, channel, h, w]
62
+ video1 = videos1[video_num]
63
+ video2 = videos2[video_num]
64
+
65
+ ssim_results_of_a_video = []
66
+ for clip_timestamp in range(len(video1)):
67
+ # get a img
68
+ # img [timestamps[x], channel, h, w]
69
+ # img [channel, h, w] numpy
70
+
71
+ img1 = video1[clip_timestamp].numpy()
72
+ img2 = video2[clip_timestamp].numpy()
73
+
74
+ # calculate ssim of a video
75
+ ssim_results_of_a_video.append(calculate_ssim_function(img1, img2))
76
+
77
+ ssim_results.append(ssim_results_of_a_video)
78
+
79
+ ssim_results = np.array(ssim_results)
80
+
81
+ ssim = {}
82
+ ssim_std = {}
83
+
84
+ for clip_timestamp in range(len(video1)):
85
+ ssim[clip_timestamp] = np.mean(ssim_results[:,clip_timestamp])
86
+ ssim_std[clip_timestamp] = np.std(ssim_results[:,clip_timestamp])
87
+
88
+ result = {
89
+ "value": ssim,
90
+ "value_std": ssim_std,
91
+ "video_setting": video1.shape,
92
+ "video_setting_name": "time, channel, heigth, width",
93
+ }
94
+
95
+ return result
96
+
97
+ # test code / using example
98
+
99
+ def main():
100
+ NUMBER_OF_VIDEOS = 8
101
+ VIDEO_LENGTH = 50
102
+ CHANNEL = 3
103
+ SIZE = 64
104
+ videos1 = torch.zeros(NUMBER_OF_VIDEOS, VIDEO_LENGTH, CHANNEL, SIZE, SIZE, requires_grad=False)
105
+ videos2 = torch.zeros(NUMBER_OF_VIDEOS, VIDEO_LENGTH, CHANNEL, SIZE, SIZE, requires_grad=False)
106
+ device = torch.device("cuda")
107
+
108
+ import json
109
+ result = calculate_ssim(videos1, videos2)
110
+ print(json.dumps(result, indent=4))
111
+
112
+ if __name__ == "__main__":
113
+ main()
opensora/eval/eval_clip_score.py ADDED
@@ -0,0 +1,225 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Calculates the CLIP Scores
2
+
3
+ The CLIP model is a contrasitively learned language-image model. There is
4
+ an image encoder and a text encoder. It is believed that the CLIP model could
5
+ measure the similarity of cross modalities. Please find more information from
6
+ https://github.com/openai/CLIP.
7
+
8
+ The CLIP Score measures the Cosine Similarity between two embedded features.
9
+ This repository utilizes the pretrained CLIP Model to calculate
10
+ the mean average of cosine similarities.
11
+
12
+ See --help to see further details.
13
+
14
+ Code apapted from https://github.com/mseitzer/pytorch-fid and https://github.com/openai/CLIP.
15
+
16
+ Copyright 2023 The Hong Kong Polytechnic University
17
+
18
+ Licensed under the Apache License, Version 2.0 (the "License");
19
+ you may not use this file except in compliance with the License.
20
+ You may obtain a copy of the License at
21
+
22
+ http://www.apache.org/licenses/LICENSE-2.0
23
+
24
+ Unless required by applicable law or agreed to in writing, software
25
+ distributed under the License is distributed on an "AS IS" BASIS,
26
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
27
+ See the License for the specific language governing permissions and
28
+ limitations under the License.
29
+ """
30
+ import os
31
+ import os.path as osp
32
+ from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser
33
+
34
+ import clip
35
+ import torch
36
+ from PIL import Image
37
+ from torch.utils.data import Dataset, DataLoader
38
+
39
+ try:
40
+ from tqdm import tqdm
41
+ except ImportError:
42
+ # If tqdm is not available, provide a mock version of it
43
+ def tqdm(x):
44
+ return x
45
+
46
+
47
+ IMAGE_EXTENSIONS = {'bmp', 'jpg', 'jpeg', 'pgm', 'png', 'ppm',
48
+ 'tif', 'tiff', 'webp'}
49
+
50
+ TEXT_EXTENSIONS = {'txt'}
51
+
52
+
53
+ class DummyDataset(Dataset):
54
+
55
+ FLAGS = ['img', 'txt']
56
+ def __init__(self, real_path, generated_path,
57
+ real_flag: str = 'img',
58
+ generated_flag: str = 'img',
59
+ transform = None,
60
+ tokenizer = None) -> None:
61
+ super().__init__()
62
+ assert real_flag in self.FLAGS and generated_flag in self.FLAGS, \
63
+ 'CLIP Score only support modality of {}. However, get {} and {}'.format(
64
+ self.FLAGS, real_flag, generated_flag
65
+ )
66
+ self.real_folder = self._combine_without_prefix(real_path)
67
+ self.real_flag = real_flag
68
+ self.fake_foler = self._combine_without_prefix(generated_path)
69
+ self.generated_flag = generated_flag
70
+ self.transform = transform
71
+ self.tokenizer = tokenizer
72
+ # assert self._check()
73
+
74
+ def __len__(self):
75
+ return len(self.real_folder)
76
+
77
+ def __getitem__(self, index):
78
+ if index >= len(self):
79
+ raise IndexError
80
+ real_path = self.real_folder[index]
81
+ generated_path = self.fake_foler[index]
82
+ real_data = self._load_modality(real_path, self.real_flag)
83
+ fake_data = self._load_modality(generated_path, self.generated_flag)
84
+
85
+ sample = dict(real=real_data, fake=fake_data)
86
+ return sample
87
+
88
+ def _load_modality(self, path, modality):
89
+ if modality == 'img':
90
+ data = self._load_img(path)
91
+ elif modality == 'txt':
92
+ data = self._load_txt(path)
93
+ else:
94
+ raise TypeError("Got unexpected modality: {}".format(modality))
95
+ return data
96
+
97
+ def _load_img(self, path):
98
+ img = Image.open(path)
99
+ if self.transform is not None:
100
+ img = self.transform(img)
101
+ return img
102
+
103
+ def _load_txt(self, path):
104
+ with open(path, 'r') as fp:
105
+ data = fp.read()
106
+ fp.close()
107
+ if self.tokenizer is not None:
108
+ data = self.tokenizer(data).squeeze()
109
+ return data
110
+
111
+ def _check(self):
112
+ for idx in range(len(self)):
113
+ real_name = self.real_folder[idx].split('.')
114
+ fake_name = self.fake_folder[idx].split('.')
115
+ if fake_name != real_name:
116
+ return False
117
+ return True
118
+
119
+ def _combine_without_prefix(self, folder_path, prefix='.'):
120
+ folder = []
121
+ for name in os.listdir(folder_path):
122
+ if name[0] == prefix:
123
+ continue
124
+ folder.append(osp.join(folder_path, name))
125
+ folder.sort()
126
+ return folder
127
+
128
+
129
+ @torch.no_grad()
130
+ def calculate_clip_score(dataloader, model, real_flag, generated_flag):
131
+ score_acc = 0.
132
+ sample_num = 0.
133
+ logit_scale = model.logit_scale.exp()
134
+ for batch_data in tqdm(dataloader):
135
+ real = batch_data['real']
136
+ real_features = forward_modality(model, real, real_flag)
137
+ fake = batch_data['fake']
138
+ fake_features = forward_modality(model, fake, generated_flag)
139
+
140
+ # normalize features
141
+ real_features = real_features / real_features.norm(dim=1, keepdim=True).to(torch.float32)
142
+ fake_features = fake_features / fake_features.norm(dim=1, keepdim=True).to(torch.float32)
143
+
144
+ # calculate scores
145
+ # score = logit_scale * real_features @ fake_features.t()
146
+ # score_acc += torch.diag(score).sum()
147
+ score = logit_scale * (fake_features * real_features).sum()
148
+ score_acc += score
149
+ sample_num += real.shape[0]
150
+
151
+ return score_acc / sample_num
152
+
153
+
154
+ def forward_modality(model, data, flag):
155
+ device = next(model.parameters()).device
156
+ if flag == 'img':
157
+ features = model.encode_image(data.to(device))
158
+ elif flag == 'txt':
159
+ features = model.encode_text(data.to(device))
160
+ else:
161
+ raise TypeError
162
+ return features
163
+
164
+
165
+ def main():
166
+ parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter)
167
+ parser.add_argument('--batch-size', type=int, default=50,
168
+ help='Batch size to use')
169
+ parser.add_argument('--clip-model', type=str, default='ViT-B/32',
170
+ help='CLIP model to use')
171
+ parser.add_argument('--num-workers', type=int, default=8,
172
+ help=('Number of processes to use for data loading. '
173
+ 'Defaults to `min(8, num_cpus)`'))
174
+ parser.add_argument('--device', type=str, default=None,
175
+ help='Device to use. Like cuda, cuda:0 or cpu')
176
+ parser.add_argument('--real_flag', type=str, default='img',
177
+ help=('The modality of real path. '
178
+ 'Default to img'))
179
+ parser.add_argument('--generated_flag', type=str, default='txt',
180
+ help=('The modality of generated path. '
181
+ 'Default to txt'))
182
+ parser.add_argument('--real_path', type=str,
183
+ help=('Paths to the real images or '
184
+ 'to .npz statistic files'))
185
+ parser.add_argument('--generated_path', type=str,
186
+ help=('Paths to the generated images or '
187
+ 'to .npz statistic files'))
188
+ args = parser.parse_args()
189
+
190
+ if args.device is None:
191
+ device = torch.device('cuda' if (torch.cuda.is_available()) else 'cpu')
192
+ else:
193
+ device = torch.device(args.device)
194
+
195
+ if args.num_workers is None:
196
+ try:
197
+ num_cpus = len(os.sched_getaffinity(0))
198
+ except AttributeError:
199
+ # os.sched_getaffinity is not available under Windows, use
200
+ # os.cpu_count instead (which may not return the *available* number
201
+ # of CPUs).
202
+ num_cpus = os.cpu_count()
203
+
204
+ num_workers = min(num_cpus, 8) if num_cpus is not None else 0
205
+ else:
206
+ num_workers = args.num_workers
207
+
208
+ print('Loading CLIP model: {}'.format(args.clip_model))
209
+ model, preprocess = clip.load(args.clip_model, device=device)
210
+
211
+ dataset = DummyDataset(args.real_path, args.generated_path,
212
+ args.real_flag, args.generated_flag,
213
+ transform=preprocess, tokenizer=clip.tokenize)
214
+ dataloader = DataLoader(dataset, args.batch_size,
215
+ num_workers=num_workers, pin_memory=True)
216
+
217
+ print('Calculating CLIP Score:')
218
+ clip_score = calculate_clip_score(dataloader, model,
219
+ args.real_flag, args.generated_flag)
220
+ clip_score = clip_score.cpu().item()
221
+ print('CLIP Score: ', clip_score)
222
+
223
+
224
+ if __name__ == '__main__':
225
+ main()
opensora/eval/eval_common_metric.py ADDED
@@ -0,0 +1,224 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Calculates the CLIP Scores
2
+
3
+ The CLIP model is a contrasitively learned language-image model. There is
4
+ an image encoder and a text encoder. It is believed that the CLIP model could
5
+ measure the similarity of cross modalities. Please find more information from
6
+ https://github.com/openai/CLIP.
7
+
8
+ The CLIP Score measures the Cosine Similarity between two embedded features.
9
+ This repository utilizes the pretrained CLIP Model to calculate
10
+ the mean average of cosine similarities.
11
+
12
+ See --help to see further details.
13
+
14
+ Code apapted from https://github.com/mseitzer/pytorch-fid and https://github.com/openai/CLIP.
15
+
16
+ Copyright 2023 The Hong Kong Polytechnic University
17
+
18
+ Licensed under the Apache License, Version 2.0 (the "License");
19
+ you may not use this file except in compliance with the License.
20
+ You may obtain a copy of the License at
21
+
22
+ http://www.apache.org/licenses/LICENSE-2.0
23
+
24
+ Unless required by applicable law or agreed to in writing, software
25
+ distributed under the License is distributed on an "AS IS" BASIS,
26
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
27
+ See the License for the specific language governing permissions and
28
+ limitations under the License.
29
+ """
30
+
31
+ import os
32
+ import os.path as osp
33
+ from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser
34
+ import numpy as np
35
+ import torch
36
+ from torch.utils.data import Dataset, DataLoader, Subset
37
+ from decord import VideoReader, cpu
38
+ import random
39
+ from pytorchvideo.transforms import ShortSideScale
40
+ from torchvision.io import read_video
41
+ from torchvision.transforms import Lambda, Compose
42
+ from torchvision.transforms._transforms_video import CenterCropVideo
43
+ import sys
44
+ sys.path.append(".")
45
+ from opensora.eval.cal_lpips import calculate_lpips
46
+ from opensora.eval.cal_fvd import calculate_fvd
47
+ from opensora.eval.cal_psnr import calculate_psnr
48
+ from opensora.eval.cal_flolpips import calculate_flolpips
49
+ from opensora.eval.cal_ssim import calculate_ssim
50
+
51
+ try:
52
+ from tqdm import tqdm
53
+ except ImportError:
54
+ # If tqdm is not available, provide a mock version of it
55
+ def tqdm(x):
56
+ return x
57
+
58
+ class VideoDataset(Dataset):
59
+ def __init__(self,
60
+ real_video_dir,
61
+ generated_video_dir,
62
+ num_frames,
63
+ sample_rate = 1,
64
+ crop_size=None,
65
+ resolution=128,
66
+ ) -> None:
67
+ super().__init__()
68
+ self.real_video_files = self._combine_without_prefix(real_video_dir)
69
+ self.generated_video_files = self._combine_without_prefix(generated_video_dir)
70
+ self.num_frames = num_frames
71
+ self.sample_rate = sample_rate
72
+ self.crop_size = crop_size
73
+ self.short_size = resolution
74
+
75
+
76
+ def __len__(self):
77
+ return len(self.real_video_files)
78
+
79
+ def __getitem__(self, index):
80
+ if index >= len(self):
81
+ raise IndexError
82
+ real_video_file = self.real_video_files[index]
83
+ generated_video_file = self.generated_video_files[index]
84
+ print(real_video_file, generated_video_file)
85
+ real_video_tensor = self._load_video(real_video_file)
86
+ generated_video_tensor = self._load_video(generated_video_file)
87
+ return {'real': real_video_tensor, 'generated':generated_video_tensor }
88
+
89
+
90
+ def _load_video(self, video_path):
91
+ num_frames = self.num_frames
92
+ sample_rate = self.sample_rate
93
+ decord_vr = VideoReader(video_path, ctx=cpu(0))
94
+ total_frames = len(decord_vr)
95
+ sample_frames_len = sample_rate * num_frames
96
+
97
+ if total_frames >= sample_frames_len:
98
+ s = 0
99
+ e = s + sample_frames_len
100
+ num_frames = num_frames
101
+ else:
102
+ s = 0
103
+ e = total_frames
104
+ num_frames = int(total_frames / sample_frames_len * num_frames)
105
+ print(f'sample_frames_len {sample_frames_len}, only can sample {num_frames * sample_rate}', video_path,
106
+ total_frames)
107
+
108
+
109
+ frame_id_list = np.linspace(s, e - 1, num_frames, dtype=int)
110
+ video_data = decord_vr.get_batch(frame_id_list).asnumpy()
111
+ video_data = torch.from_numpy(video_data)
112
+ video_data = video_data.permute(0, 3, 1, 2) # (T, H, W, C) -> (C, T, H, W)
113
+ return _preprocess(video_data, short_size=self.short_size, crop_size = self.crop_size)
114
+
115
+
116
+ def _combine_without_prefix(self, folder_path, prefix='.'):
117
+ folder = []
118
+ os.makedirs(folder_path, exist_ok=True)
119
+ for name in os.listdir(folder_path):
120
+ if name[0] == prefix:
121
+ continue
122
+ if osp.isfile(osp.join(folder_path, name)):
123
+ folder.append(osp.join(folder_path, name))
124
+ folder.sort()
125
+ return folder
126
+
127
+ def _preprocess(video_data, short_size=128, crop_size=None):
128
+ transform = Compose(
129
+ [
130
+ Lambda(lambda x: x / 255.0),
131
+ ShortSideScale(size=short_size),
132
+ CenterCropVideo(crop_size=crop_size),
133
+ ]
134
+ )
135
+ video_outputs = transform(video_data)
136
+ # video_outputs = torch.unsqueeze(video_outputs, 0) # (bz,c,t,h,w)
137
+ return video_outputs
138
+
139
+
140
+ def calculate_common_metric(args, dataloader, device):
141
+
142
+ score_list = []
143
+ for batch_data in tqdm(dataloader): # {'real': real_video_tensor, 'generated':generated_video_tensor }
144
+ real_videos = batch_data['real']
145
+ generated_videos = batch_data['generated']
146
+ assert real_videos.shape[2] == generated_videos.shape[2]
147
+ if args.metric == 'fvd':
148
+ tmp_list = list(calculate_fvd(real_videos, generated_videos, args.device, method=args.fvd_method)['value'].values())
149
+ elif args.metric == 'ssim':
150
+ tmp_list = list(calculate_ssim(real_videos, generated_videos)['value'].values())
151
+ elif args.metric == 'psnr':
152
+ tmp_list = list(calculate_psnr(real_videos, generated_videos)['value'].values())
153
+ elif args.metric == 'flolpips':
154
+ result = calculate_flolpips(real_videos, generated_videos, args.device)
155
+ tmp_list = list(result['value'].values())
156
+ else:
157
+ tmp_list = list(calculate_lpips(real_videos, generated_videos, args.device)['value'].values())
158
+ score_list += tmp_list
159
+ return np.mean(score_list)
160
+
161
+ def main():
162
+ parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter)
163
+ parser.add_argument('--batch_size', type=int, default=2,
164
+ help='Batch size to use')
165
+ parser.add_argument('--real_video_dir', type=str,
166
+ help=('the path of real videos`'))
167
+ parser.add_argument('--generated_video_dir', type=str,
168
+ help=('the path of generated videos`'))
169
+ parser.add_argument('--device', type=str, default=None,
170
+ help='Device to use. Like cuda, cuda:0 or cpu')
171
+ parser.add_argument('--num_workers', type=int, default=8,
172
+ help=('Number of processes to use for data loading. '
173
+ 'Defaults to `min(8, num_cpus)`'))
174
+ parser.add_argument('--sample_fps', type=int, default=30)
175
+ parser.add_argument('--resolution', type=int, default=336)
176
+ parser.add_argument('--crop_size', type=int, default=None)
177
+ parser.add_argument('--num_frames', type=int, default=100)
178
+ parser.add_argument('--sample_rate', type=int, default=1)
179
+ parser.add_argument('--subset_size', type=int, default=None)
180
+ parser.add_argument("--metric", type=str, default="fvd",choices=['fvd','psnr','ssim','lpips', 'flolpips'])
181
+ parser.add_argument("--fvd_method", type=str, default='styleganv',choices=['styleganv','videogpt'])
182
+
183
+
184
+ args = parser.parse_args()
185
+
186
+ if args.device is None:
187
+ device = torch.device('cuda' if (torch.cuda.is_available()) else 'cpu')
188
+ else:
189
+ device = torch.device(args.device)
190
+
191
+ if args.num_workers is None:
192
+ try:
193
+ num_cpus = len(os.sched_getaffinity(0))
194
+ except AttributeError:
195
+ # os.sched_getaffinity is not available under Windows, use
196
+ # os.cpu_count instead (which may not return the *available* number
197
+ # of CPUs).
198
+ num_cpus = os.cpu_count()
199
+
200
+ num_workers = min(num_cpus, 8) if num_cpus is not None else 0
201
+ else:
202
+ num_workers = args.num_workers
203
+
204
+
205
+ dataset = VideoDataset(args.real_video_dir,
206
+ args.generated_video_dir,
207
+ num_frames = args.num_frames,
208
+ sample_rate = args.sample_rate,
209
+ crop_size=args.crop_size,
210
+ resolution=args.resolution)
211
+
212
+ if args.subset_size:
213
+ indices = range(args.subset_size)
214
+ dataset = Subset(dataset, indices=indices)
215
+
216
+ dataloader = DataLoader(dataset, args.batch_size,
217
+ num_workers=num_workers, pin_memory=True)
218
+
219
+
220
+ metric_score = calculate_common_metric(args, dataloader,device)
221
+ print('metric: ', args.metric, " ",metric_score)
222
+
223
+ if __name__ == '__main__':
224
+ main()
opensora/eval/flolpips/correlation/correlation.py ADDED
@@ -0,0 +1,397 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+
3
+ import torch
4
+
5
+ import cupy
6
+ import re
7
+
8
+ kernel_Correlation_rearrange = '''
9
+ extern "C" __global__ void kernel_Correlation_rearrange(
10
+ const int n,
11
+ const float* input,
12
+ float* output
13
+ ) {
14
+ int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x;
15
+
16
+ if (intIndex >= n) {
17
+ return;
18
+ }
19
+
20
+ int intSample = blockIdx.z;
21
+ int intChannel = blockIdx.y;
22
+
23
+ float fltValue = input[(((intSample * SIZE_1(input)) + intChannel) * SIZE_2(input) * SIZE_3(input)) + intIndex];
24
+
25
+ __syncthreads();
26
+
27
+ int intPaddedY = (intIndex / SIZE_3(input)) + 4;
28
+ int intPaddedX = (intIndex % SIZE_3(input)) + 4;
29
+ int intRearrange = ((SIZE_3(input) + 8) * intPaddedY) + intPaddedX;
30
+
31
+ output[(((intSample * SIZE_1(output) * SIZE_2(output)) + intRearrange) * SIZE_1(input)) + intChannel] = fltValue;
32
+ }
33
+ '''
34
+
35
+ kernel_Correlation_updateOutput = '''
36
+ extern "C" __global__ void kernel_Correlation_updateOutput(
37
+ const int n,
38
+ const float* rbot0,
39
+ const float* rbot1,
40
+ float* top
41
+ ) {
42
+ extern __shared__ char patch_data_char[];
43
+
44
+ float *patch_data = (float *)patch_data_char;
45
+
46
+ // First (upper left) position of kernel upper-left corner in current center position of neighborhood in image 1
47
+ int x1 = blockIdx.x + 4;
48
+ int y1 = blockIdx.y + 4;
49
+ int item = blockIdx.z;
50
+ int ch_off = threadIdx.x;
51
+
52
+ // Load 3D patch into shared shared memory
53
+ for (int j = 0; j < 1; j++) { // HEIGHT
54
+ for (int i = 0; i < 1; i++) { // WIDTH
55
+ int ji_off = (j + i) * SIZE_3(rbot0);
56
+ for (int ch = ch_off; ch < SIZE_3(rbot0); ch += 32) { // CHANNELS
57
+ int idx1 = ((item * SIZE_1(rbot0) + y1+j) * SIZE_2(rbot0) + x1+i) * SIZE_3(rbot0) + ch;
58
+ int idxPatchData = ji_off + ch;
59
+ patch_data[idxPatchData] = rbot0[idx1];
60
+ }
61
+ }
62
+ }
63
+
64
+ __syncthreads();
65
+
66
+ __shared__ float sum[32];
67
+
68
+ // Compute correlation
69
+ for (int top_channel = 0; top_channel < SIZE_1(top); top_channel++) {
70
+ sum[ch_off] = 0;
71
+
72
+ int s2o = top_channel % 9 - 4;
73
+ int s2p = top_channel / 9 - 4;
74
+
75
+ for (int j = 0; j < 1; j++) { // HEIGHT
76
+ for (int i = 0; i < 1; i++) { // WIDTH
77
+ int ji_off = (j + i) * SIZE_3(rbot0);
78
+ for (int ch = ch_off; ch < SIZE_3(rbot0); ch += 32) { // CHANNELS
79
+ int x2 = x1 + s2o;
80
+ int y2 = y1 + s2p;
81
+
82
+ int idxPatchData = ji_off + ch;
83
+ int idx2 = ((item * SIZE_1(rbot0) + y2+j) * SIZE_2(rbot0) + x2+i) * SIZE_3(rbot0) + ch;
84
+
85
+ sum[ch_off] += patch_data[idxPatchData] * rbot1[idx2];
86
+ }
87
+ }
88
+ }
89
+
90
+ __syncthreads();
91
+
92
+ if (ch_off == 0) {
93
+ float total_sum = 0;
94
+ for (int idx = 0; idx < 32; idx++) {
95
+ total_sum += sum[idx];
96
+ }
97
+ const int sumelems = SIZE_3(rbot0);
98
+ const int index = ((top_channel*SIZE_2(top) + blockIdx.y)*SIZE_3(top))+blockIdx.x;
99
+ top[index + item*SIZE_1(top)*SIZE_2(top)*SIZE_3(top)] = total_sum / (float)sumelems;
100
+ }
101
+ }
102
+ }
103
+ '''
104
+
105
+ kernel_Correlation_updateGradFirst = '''
106
+ #define ROUND_OFF 50000
107
+
108
+ extern "C" __global__ void kernel_Correlation_updateGradFirst(
109
+ const int n,
110
+ const int intSample,
111
+ const float* rbot0,
112
+ const float* rbot1,
113
+ const float* gradOutput,
114
+ float* gradFirst,
115
+ float* gradSecond
116
+ ) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) {
117
+ int n = intIndex % SIZE_1(gradFirst); // channels
118
+ int l = (intIndex / SIZE_1(gradFirst)) % SIZE_3(gradFirst) + 4; // w-pos
119
+ int m = (intIndex / SIZE_1(gradFirst) / SIZE_3(gradFirst)) % SIZE_2(gradFirst) + 4; // h-pos
120
+
121
+ // round_off is a trick to enable integer division with ceil, even for negative numbers
122
+ // We use a large offset, for the inner part not to become negative.
123
+ const int round_off = ROUND_OFF;
124
+ const int round_off_s1 = round_off;
125
+
126
+ // We add round_off before_s1 the int division and subtract round_off after it, to ensure the formula matches ceil behavior:
127
+ int xmin = (l - 4 + round_off_s1 - 1) + 1 - round_off; // ceil (l - 4)
128
+ int ymin = (m - 4 + round_off_s1 - 1) + 1 - round_off; // ceil (l - 4)
129
+
130
+ // Same here:
131
+ int xmax = (l - 4 + round_off_s1) - round_off; // floor (l - 4)
132
+ int ymax = (m - 4 + round_off_s1) - round_off; // floor (m - 4)
133
+
134
+ float sum = 0;
135
+ if (xmax>=0 && ymax>=0 && (xmin<=SIZE_3(gradOutput)-1) && (ymin<=SIZE_2(gradOutput)-1)) {
136
+ xmin = max(0,xmin);
137
+ xmax = min(SIZE_3(gradOutput)-1,xmax);
138
+
139
+ ymin = max(0,ymin);
140
+ ymax = min(SIZE_2(gradOutput)-1,ymax);
141
+
142
+ for (int p = -4; p <= 4; p++) {
143
+ for (int o = -4; o <= 4; o++) {
144
+ // Get rbot1 data:
145
+ int s2o = o;
146
+ int s2p = p;
147
+ int idxbot1 = ((intSample * SIZE_1(rbot0) + (m+s2p)) * SIZE_2(rbot0) + (l+s2o)) * SIZE_3(rbot0) + n;
148
+ float bot1tmp = rbot1[idxbot1]; // rbot1[l+s2o,m+s2p,n]
149
+
150
+ // Index offset for gradOutput in following loops:
151
+ int op = (p+4) * 9 + (o+4); // index[o,p]
152
+ int idxopoffset = (intSample * SIZE_1(gradOutput) + op);
153
+
154
+ for (int y = ymin; y <= ymax; y++) {
155
+ for (int x = xmin; x <= xmax; x++) {
156
+ int idxgradOutput = (idxopoffset * SIZE_2(gradOutput) + y) * SIZE_3(gradOutput) + x; // gradOutput[x,y,o,p]
157
+ sum += gradOutput[idxgradOutput] * bot1tmp;
158
+ }
159
+ }
160
+ }
161
+ }
162
+ }
163
+ const int sumelems = SIZE_1(gradFirst);
164
+ const int bot0index = ((n * SIZE_2(gradFirst)) + (m-4)) * SIZE_3(gradFirst) + (l-4);
165
+ gradFirst[bot0index + intSample*SIZE_1(gradFirst)*SIZE_2(gradFirst)*SIZE_3(gradFirst)] = sum / (float)sumelems;
166
+ } }
167
+ '''
168
+
169
+ kernel_Correlation_updateGradSecond = '''
170
+ #define ROUND_OFF 50000
171
+
172
+ extern "C" __global__ void kernel_Correlation_updateGradSecond(
173
+ const int n,
174
+ const int intSample,
175
+ const float* rbot0,
176
+ const float* rbot1,
177
+ const float* gradOutput,
178
+ float* gradFirst,
179
+ float* gradSecond
180
+ ) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) {
181
+ int n = intIndex % SIZE_1(gradSecond); // channels
182
+ int l = (intIndex / SIZE_1(gradSecond)) % SIZE_3(gradSecond) + 4; // w-pos
183
+ int m = (intIndex / SIZE_1(gradSecond) / SIZE_3(gradSecond)) % SIZE_2(gradSecond) + 4; // h-pos
184
+
185
+ // round_off is a trick to enable integer division with ceil, even for negative numbers
186
+ // We use a large offset, for the inner part not to become negative.
187
+ const int round_off = ROUND_OFF;
188
+ const int round_off_s1 = round_off;
189
+
190
+ float sum = 0;
191
+ for (int p = -4; p <= 4; p++) {
192
+ for (int o = -4; o <= 4; o++) {
193
+ int s2o = o;
194
+ int s2p = p;
195
+
196
+ //Get X,Y ranges and clamp
197
+ // We add round_off before_s1 the int division and subtract round_off after it, to ensure the formula matches ceil behavior:
198
+ int xmin = (l - 4 - s2o + round_off_s1 - 1) + 1 - round_off; // ceil (l - 4 - s2o)
199
+ int ymin = (m - 4 - s2p + round_off_s1 - 1) + 1 - round_off; // ceil (l - 4 - s2o)
200
+
201
+ // Same here:
202
+ int xmax = (l - 4 - s2o + round_off_s1) - round_off; // floor (l - 4 - s2o)
203
+ int ymax = (m - 4 - s2p + round_off_s1) - round_off; // floor (m - 4 - s2p)
204
+
205
+ if (xmax>=0 && ymax>=0 && (xmin<=SIZE_3(gradOutput)-1) && (ymin<=SIZE_2(gradOutput)-1)) {
206
+ xmin = max(0,xmin);
207
+ xmax = min(SIZE_3(gradOutput)-1,xmax);
208
+
209
+ ymin = max(0,ymin);
210
+ ymax = min(SIZE_2(gradOutput)-1,ymax);
211
+
212
+ // Get rbot0 data:
213
+ int idxbot0 = ((intSample * SIZE_1(rbot0) + (m-s2p)) * SIZE_2(rbot0) + (l-s2o)) * SIZE_3(rbot0) + n;
214
+ float bot0tmp = rbot0[idxbot0]; // rbot1[l+s2o,m+s2p,n]
215
+
216
+ // Index offset for gradOutput in following loops:
217
+ int op = (p+4) * 9 + (o+4); // index[o,p]
218
+ int idxopoffset = (intSample * SIZE_1(gradOutput) + op);
219
+
220
+ for (int y = ymin; y <= ymax; y++) {
221
+ for (int x = xmin; x <= xmax; x++) {
222
+ int idxgradOutput = (idxopoffset * SIZE_2(gradOutput) + y) * SIZE_3(gradOutput) + x; // gradOutput[x,y,o,p]
223
+ sum += gradOutput[idxgradOutput] * bot0tmp;
224
+ }
225
+ }
226
+ }
227
+ }
228
+ }
229
+ const int sumelems = SIZE_1(gradSecond);
230
+ const int bot1index = ((n * SIZE_2(gradSecond)) + (m-4)) * SIZE_3(gradSecond) + (l-4);
231
+ gradSecond[bot1index + intSample*SIZE_1(gradSecond)*SIZE_2(gradSecond)*SIZE_3(gradSecond)] = sum / (float)sumelems;
232
+ } }
233
+ '''
234
+
235
+ def cupy_kernel(strFunction, objVariables):
236
+ strKernel = globals()[strFunction]
237
+
238
+ while True:
239
+ objMatch = re.search('(SIZE_)([0-4])(\()([^\)]*)(\))', strKernel)
240
+
241
+ if objMatch is None:
242
+ break
243
+ # end
244
+
245
+ intArg = int(objMatch.group(2))
246
+
247
+ strTensor = objMatch.group(4)
248
+ intSizes = objVariables[strTensor].size()
249
+
250
+ strKernel = strKernel.replace(objMatch.group(), str(intSizes[intArg]))
251
+ # end
252
+
253
+ while True:
254
+ objMatch = re.search('(VALUE_)([0-4])(\()([^\)]+)(\))', strKernel)
255
+
256
+ if objMatch is None:
257
+ break
258
+ # end
259
+
260
+ intArgs = int(objMatch.group(2))
261
+ strArgs = objMatch.group(4).split(',')
262
+
263
+ strTensor = strArgs[0]
264
+ intStrides = objVariables[strTensor].stride()
265
+ strIndex = [ '((' + strArgs[intArg + 1].replace('{', '(').replace('}', ')').strip() + ')*' + str(intStrides[intArg]) + ')' for intArg in range(intArgs) ]
266
+
267
+ strKernel = strKernel.replace(objMatch.group(0), strTensor + '[' + str.join('+', strIndex) + ']')
268
+ # end
269
+
270
+ return strKernel
271
+ # end
272
+
273
+ @cupy.memoize(for_each_device=True)
274
+ def cupy_launch(strFunction, strKernel):
275
+ return cupy.RawKernel(strKernel, strFunction)
276
+ # end
277
+
278
+ class _FunctionCorrelation(torch.autograd.Function):
279
+ @staticmethod
280
+ def forward(self, first, second):
281
+ rbot0 = first.new_zeros([ first.shape[0], first.shape[2] + 8, first.shape[3] + 8, first.shape[1] ])
282
+ rbot1 = first.new_zeros([ first.shape[0], first.shape[2] + 8, first.shape[3] + 8, first.shape[1] ])
283
+
284
+ self.save_for_backward(first, second, rbot0, rbot1)
285
+
286
+ first = first.contiguous(); assert(first.is_cuda == True)
287
+ second = second.contiguous(); assert(second.is_cuda == True)
288
+
289
+ output = first.new_zeros([ first.shape[0], 81, first.shape[2], first.shape[3] ])
290
+
291
+ if first.is_cuda == True:
292
+ n = first.shape[2] * first.shape[3]
293
+ cupy_launch('kernel_Correlation_rearrange', cupy_kernel('kernel_Correlation_rearrange', {
294
+ 'input': first,
295
+ 'output': rbot0
296
+ }))(
297
+ grid=tuple([ int((n + 16 - 1) / 16), first.shape[1], first.shape[0] ]),
298
+ block=tuple([ 16, 1, 1 ]),
299
+ args=[ n, first.data_ptr(), rbot0.data_ptr() ]
300
+ )
301
+
302
+ n = second.shape[2] * second.shape[3]
303
+ cupy_launch('kernel_Correlation_rearrange', cupy_kernel('kernel_Correlation_rearrange', {
304
+ 'input': second,
305
+ 'output': rbot1
306
+ }))(
307
+ grid=tuple([ int((n + 16 - 1) / 16), second.shape[1], second.shape[0] ]),
308
+ block=tuple([ 16, 1, 1 ]),
309
+ args=[ n, second.data_ptr(), rbot1.data_ptr() ]
310
+ )
311
+
312
+ n = output.shape[1] * output.shape[2] * output.shape[3]
313
+ cupy_launch('kernel_Correlation_updateOutput', cupy_kernel('kernel_Correlation_updateOutput', {
314
+ 'rbot0': rbot0,
315
+ 'rbot1': rbot1,
316
+ 'top': output
317
+ }))(
318
+ grid=tuple([ output.shape[3], output.shape[2], output.shape[0] ]),
319
+ block=tuple([ 32, 1, 1 ]),
320
+ shared_mem=first.shape[1] * 4,
321
+ args=[ n, rbot0.data_ptr(), rbot1.data_ptr(), output.data_ptr() ]
322
+ )
323
+
324
+ elif first.is_cuda == False:
325
+ raise NotImplementedError()
326
+
327
+ # end
328
+
329
+ return output
330
+ # end
331
+
332
+ @staticmethod
333
+ def backward(self, gradOutput):
334
+ first, second, rbot0, rbot1 = self.saved_tensors
335
+
336
+ gradOutput = gradOutput.contiguous(); assert(gradOutput.is_cuda == True)
337
+
338
+ gradFirst = first.new_zeros([ first.shape[0], first.shape[1], first.shape[2], first.shape[3] ]) if self.needs_input_grad[0] == True else None
339
+ gradSecond = first.new_zeros([ first.shape[0], first.shape[1], first.shape[2], first.shape[3] ]) if self.needs_input_grad[1] == True else None
340
+
341
+ if first.is_cuda == True:
342
+ if gradFirst is not None:
343
+ for intSample in range(first.shape[0]):
344
+ n = first.shape[1] * first.shape[2] * first.shape[3]
345
+ cupy_launch('kernel_Correlation_updateGradFirst', cupy_kernel('kernel_Correlation_updateGradFirst', {
346
+ 'rbot0': rbot0,
347
+ 'rbot1': rbot1,
348
+ 'gradOutput': gradOutput,
349
+ 'gradFirst': gradFirst,
350
+ 'gradSecond': None
351
+ }))(
352
+ grid=tuple([ int((n + 512 - 1) / 512), 1, 1 ]),
353
+ block=tuple([ 512, 1, 1 ]),
354
+ args=[ n, intSample, rbot0.data_ptr(), rbot1.data_ptr(), gradOutput.data_ptr(), gradFirst.data_ptr(), None ]
355
+ )
356
+ # end
357
+ # end
358
+
359
+ if gradSecond is not None:
360
+ for intSample in range(first.shape[0]):
361
+ n = first.shape[1] * first.shape[2] * first.shape[3]
362
+ cupy_launch('kernel_Correlation_updateGradSecond', cupy_kernel('kernel_Correlation_updateGradSecond', {
363
+ 'rbot0': rbot0,
364
+ 'rbot1': rbot1,
365
+ 'gradOutput': gradOutput,
366
+ 'gradFirst': None,
367
+ 'gradSecond': gradSecond
368
+ }))(
369
+ grid=tuple([ int((n + 512 - 1) / 512), 1, 1 ]),
370
+ block=tuple([ 512, 1, 1 ]),
371
+ args=[ n, intSample, rbot0.data_ptr(), rbot1.data_ptr(), gradOutput.data_ptr(), None, gradSecond.data_ptr() ]
372
+ )
373
+ # end
374
+ # end
375
+
376
+ elif first.is_cuda == False:
377
+ raise NotImplementedError()
378
+
379
+ # end
380
+
381
+ return gradFirst, gradSecond
382
+ # end
383
+ # end
384
+
385
+ def FunctionCorrelation(tenFirst, tenSecond):
386
+ return _FunctionCorrelation.apply(tenFirst, tenSecond)
387
+ # end
388
+
389
+ class ModuleCorrelation(torch.nn.Module):
390
+ def __init__(self):
391
+ super(ModuleCorrelation, self).__init__()
392
+ # end
393
+
394
+ def forward(self, tenFirst, tenSecond):
395
+ return _FunctionCorrelation.apply(tenFirst, tenSecond)
396
+ # end
397
+ # end
opensora/eval/flolpips/flolpips.py ADDED
@@ -0,0 +1,308 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ from __future__ import absolute_import
3
+ import os
4
+ import numpy as np
5
+ import torch
6
+ import torch.nn as nn
7
+ from torch.autograd import Variable
8
+ from .pretrained_networks import vgg16, alexnet, squeezenet
9
+ import torch.nn
10
+ import torch.nn.functional as F
11
+ import torchvision.transforms.functional as TF
12
+ import cv2
13
+
14
+ from .pwcnet import Network as PWCNet
15
+ from .utils import *
16
+
17
+ def spatial_average(in_tens, keepdim=True):
18
+ return in_tens.mean([2,3],keepdim=keepdim)
19
+
20
+ def mw_spatial_average(in_tens, flow, keepdim=True):
21
+ _,_,h,w = in_tens.shape
22
+ flow = F.interpolate(flow, (h,w), align_corners=False, mode='bilinear')
23
+ flow_mag = torch.sqrt(flow[:,0:1]**2 + flow[:,1:2]**2)
24
+ flow_mag = flow_mag / torch.sum(flow_mag, dim=[1,2,3], keepdim=True)
25
+ return torch.sum(in_tens*flow_mag, dim=[2,3],keepdim=keepdim)
26
+
27
+
28
+ def mtw_spatial_average(in_tens, flow, texture, keepdim=True):
29
+ _,_,h,w = in_tens.shape
30
+ flow = F.interpolate(flow, (h,w), align_corners=False, mode='bilinear')
31
+ texture = F.interpolate(texture, (h,w), align_corners=False, mode='bilinear')
32
+ flow_mag = torch.sqrt(flow[:,0:1]**2 + flow[:,1:2]**2)
33
+ flow_mag = (flow_mag - flow_mag.min()) / (flow_mag.max() - flow_mag.min()) + 1e-6
34
+ texture = (texture - texture.min()) / (texture.max() - texture.min()) + 1e-6
35
+ weight = flow_mag / texture
36
+ weight /= torch.sum(weight)
37
+ return torch.sum(in_tens*weight, dim=[2,3],keepdim=keepdim)
38
+
39
+
40
+
41
+ def m2w_spatial_average(in_tens, flow, keepdim=True):
42
+ _,_,h,w = in_tens.shape
43
+ flow = F.interpolate(flow, (h,w), align_corners=False, mode='bilinear')
44
+ flow_mag = flow[:,0:1]**2 + flow[:,1:2]**2 # B,1,H,W
45
+ flow_mag = flow_mag / torch.sum(flow_mag)
46
+ return torch.sum(in_tens*flow_mag, dim=[2,3],keepdim=keepdim)
47
+
48
+ def upsample(in_tens, out_HW=(64,64)): # assumes scale factor is same for H and W
49
+ in_H, in_W = in_tens.shape[2], in_tens.shape[3]
50
+ return nn.Upsample(size=out_HW, mode='bilinear', align_corners=False)(in_tens)
51
+
52
+ # Learned perceptual metric
53
+ class LPIPS(nn.Module):
54
+ def __init__(self, pretrained=True, net='alex', version='0.1', lpips=True, spatial=False,
55
+ pnet_rand=False, pnet_tune=False, use_dropout=True, model_path=None, eval_mode=True, verbose=False):
56
+ # lpips - [True] means with linear calibration on top of base network
57
+ # pretrained - [True] means load linear weights
58
+
59
+ super(LPIPS, self).__init__()
60
+ if(verbose):
61
+ print('Setting up [%s] perceptual loss: trunk [%s], v[%s], spatial [%s]'%
62
+ ('LPIPS' if lpips else 'baseline', net, version, 'on' if spatial else 'off'))
63
+
64
+ self.pnet_type = net
65
+ self.pnet_tune = pnet_tune
66
+ self.pnet_rand = pnet_rand
67
+ self.spatial = spatial
68
+ self.lpips = lpips # false means baseline of just averaging all layers
69
+ self.version = version
70
+ self.scaling_layer = ScalingLayer()
71
+
72
+ if(self.pnet_type in ['vgg','vgg16']):
73
+ net_type = vgg16
74
+ self.chns = [64,128,256,512,512]
75
+ elif(self.pnet_type=='alex'):
76
+ net_type = alexnet
77
+ self.chns = [64,192,384,256,256]
78
+ elif(self.pnet_type=='squeeze'):
79
+ net_type = squeezenet
80
+ self.chns = [64,128,256,384,384,512,512]
81
+ self.L = len(self.chns)
82
+
83
+ self.net = net_type(pretrained=not self.pnet_rand, requires_grad=self.pnet_tune)
84
+
85
+ if(lpips):
86
+ self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout)
87
+ self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout)
88
+ self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout)
89
+ self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout)
90
+ self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout)
91
+ self.lins = [self.lin0,self.lin1,self.lin2,self.lin3,self.lin4]
92
+ if(self.pnet_type=='squeeze'): # 7 layers for squeezenet
93
+ self.lin5 = NetLinLayer(self.chns[5], use_dropout=use_dropout)
94
+ self.lin6 = NetLinLayer(self.chns[6], use_dropout=use_dropout)
95
+ self.lins+=[self.lin5,self.lin6]
96
+ self.lins = nn.ModuleList(self.lins)
97
+
98
+ if(pretrained):
99
+ if(model_path is None):
100
+ import inspect
101
+ import os
102
+ model_path = os.path.abspath(os.path.join(inspect.getfile(self.__init__), '..', 'weights/v%s/%s.pth'%(version,net)))
103
+
104
+ if(verbose):
105
+ print('Loading model from: %s'%model_path)
106
+ self.load_state_dict(torch.load(model_path, map_location='cpu'), strict=False)
107
+
108
+ if(eval_mode):
109
+ self.eval()
110
+
111
+ def forward(self, in0, in1, retPerLayer=False, normalize=False):
112
+ if normalize: # turn on this flag if input is [0,1] so it can be adjusted to [-1, +1]
113
+ in0 = 2 * in0 - 1
114
+ in1 = 2 * in1 - 1
115
+
116
+ # v0.0 - original release had a bug, where input was not scaled
117
+ in0_input, in1_input = (self.scaling_layer(in0), self.scaling_layer(in1)) if self.version=='0.1' else (in0, in1)
118
+ outs0, outs1 = self.net.forward(in0_input), self.net.forward(in1_input)
119
+ feats0, feats1, diffs = {}, {}, {}
120
+
121
+ for kk in range(self.L):
122
+ feats0[kk], feats1[kk] = normalize_tensor(outs0[kk]), normalize_tensor(outs1[kk])
123
+ diffs[kk] = (feats0[kk]-feats1[kk])**2
124
+
125
+ if(self.lpips):
126
+ if(self.spatial):
127
+ res = [upsample(self.lins[kk](diffs[kk]), out_HW=in0.shape[2:]) for kk in range(self.L)]
128
+ else:
129
+ res = [spatial_average(self.lins[kk](diffs[kk]), keepdim=True) for kk in range(self.L)]
130
+ else:
131
+ if(self.spatial):
132
+ res = [upsample(diffs[kk].sum(dim=1,keepdim=True), out_HW=in0.shape[2:]) for kk in range(self.L)]
133
+ else:
134
+ res = [spatial_average(diffs[kk].sum(dim=1,keepdim=True), keepdim=True) for kk in range(self.L)]
135
+
136
+ # val = res[0]
137
+ # for l in range(1,self.L):
138
+ # val += res[l]
139
+ # print(val)
140
+
141
+ # a = spatial_average(self.lins[kk](diffs[kk]), keepdim=True)
142
+ # b = torch.max(self.lins[kk](feats0[kk]**2))
143
+ # for kk in range(self.L):
144
+ # a += spatial_average(self.lins[kk](diffs[kk]), keepdim=True)
145
+ # b = torch.max(b,torch.max(self.lins[kk](feats0[kk]**2)))
146
+ # a = a/self.L
147
+ # from IPython import embed
148
+ # embed()
149
+ # return 10*torch.log10(b/a)
150
+
151
+ # if(retPerLayer):
152
+ # return (val, res)
153
+ # else:
154
+ return torch.sum(torch.cat(res, 1), dim=(1,2,3), keepdims=False)
155
+
156
+
157
+ class ScalingLayer(nn.Module):
158
+ def __init__(self):
159
+ super(ScalingLayer, self).__init__()
160
+ self.register_buffer('shift', torch.Tensor([-.030,-.088,-.188])[None,:,None,None])
161
+ self.register_buffer('scale', torch.Tensor([.458,.448,.450])[None,:,None,None])
162
+
163
+ def forward(self, inp):
164
+ return (inp - self.shift) / self.scale
165
+
166
+
167
+ class NetLinLayer(nn.Module):
168
+ ''' A single linear layer which does a 1x1 conv '''
169
+ def __init__(self, chn_in, chn_out=1, use_dropout=False):
170
+ super(NetLinLayer, self).__init__()
171
+
172
+ layers = [nn.Dropout(),] if(use_dropout) else []
173
+ layers += [nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False),]
174
+ self.model = nn.Sequential(*layers)
175
+
176
+ def forward(self, x):
177
+ return self.model(x)
178
+
179
+ class Dist2LogitLayer(nn.Module):
180
+ ''' takes 2 distances, puts through fc layers, spits out value between [0,1] (if use_sigmoid is True) '''
181
+ def __init__(self, chn_mid=32, use_sigmoid=True):
182
+ super(Dist2LogitLayer, self).__init__()
183
+
184
+ layers = [nn.Conv2d(5, chn_mid, 1, stride=1, padding=0, bias=True),]
185
+ layers += [nn.LeakyReLU(0.2,True),]
186
+ layers += [nn.Conv2d(chn_mid, chn_mid, 1, stride=1, padding=0, bias=True),]
187
+ layers += [nn.LeakyReLU(0.2,True),]
188
+ layers += [nn.Conv2d(chn_mid, 1, 1, stride=1, padding=0, bias=True),]
189
+ if(use_sigmoid):
190
+ layers += [nn.Sigmoid(),]
191
+ self.model = nn.Sequential(*layers)
192
+
193
+ def forward(self,d0,d1,eps=0.1):
194
+ return self.model.forward(torch.cat((d0,d1,d0-d1,d0/(d1+eps),d1/(d0+eps)),dim=1))
195
+
196
+ class BCERankingLoss(nn.Module):
197
+ def __init__(self, chn_mid=32):
198
+ super(BCERankingLoss, self).__init__()
199
+ self.net = Dist2LogitLayer(chn_mid=chn_mid)
200
+ # self.parameters = list(self.net.parameters())
201
+ self.loss = torch.nn.BCELoss()
202
+
203
+ def forward(self, d0, d1, judge):
204
+ per = (judge+1.)/2.
205
+ self.logit = self.net.forward(d0,d1)
206
+ return self.loss(self.logit, per)
207
+
208
+ # L2, DSSIM metrics
209
+ class FakeNet(nn.Module):
210
+ def __init__(self, use_gpu=True, colorspace='Lab'):
211
+ super(FakeNet, self).__init__()
212
+ self.use_gpu = use_gpu
213
+ self.colorspace = colorspace
214
+
215
+ class L2(FakeNet):
216
+ def forward(self, in0, in1, retPerLayer=None):
217
+ assert(in0.size()[0]==1) # currently only supports batchSize 1
218
+
219
+ if(self.colorspace=='RGB'):
220
+ (N,C,X,Y) = in0.size()
221
+ value = torch.mean(torch.mean(torch.mean((in0-in1)**2,dim=1).view(N,1,X,Y),dim=2).view(N,1,1,Y),dim=3).view(N)
222
+ return value
223
+ elif(self.colorspace=='Lab'):
224
+ value = l2(tensor2np(tensor2tensorlab(in0.data,to_norm=False)),
225
+ tensor2np(tensor2tensorlab(in1.data,to_norm=False)), range=100.).astype('float')
226
+ ret_var = Variable( torch.Tensor((value,) ) )
227
+ if(self.use_gpu):
228
+ ret_var = ret_var.cuda()
229
+ return ret_var
230
+
231
+ class DSSIM(FakeNet):
232
+
233
+ def forward(self, in0, in1, retPerLayer=None):
234
+ assert(in0.size()[0]==1) # currently only supports batchSize 1
235
+
236
+ if(self.colorspace=='RGB'):
237
+ value = dssim(1.*tensor2im(in0.data), 1.*tensor2im(in1.data), range=255.).astype('float')
238
+ elif(self.colorspace=='Lab'):
239
+ value = dssim(tensor2np(tensor2tensorlab(in0.data,to_norm=False)),
240
+ tensor2np(tensor2tensorlab(in1.data,to_norm=False)), range=100.).astype('float')
241
+ ret_var = Variable( torch.Tensor((value,) ) )
242
+ if(self.use_gpu):
243
+ ret_var = ret_var.cuda()
244
+ return ret_var
245
+
246
+ def print_network(net):
247
+ num_params = 0
248
+ for param in net.parameters():
249
+ num_params += param.numel()
250
+ print('Network',net)
251
+ print('Total number of parameters: %d' % num_params)
252
+
253
+
254
+ class FloLPIPS(LPIPS):
255
+ def __init__(self, pretrained=True, net='alex', version='0.1', lpips=True, spatial=False, pnet_rand=False, pnet_tune=False, use_dropout=True, model_path=None, eval_mode=True, verbose=False):
256
+ super(FloLPIPS, self).__init__(pretrained, net, version, lpips, spatial, pnet_rand, pnet_tune, use_dropout, model_path, eval_mode, verbose)
257
+
258
+ def forward(self, in0, in1, flow, retPerLayer=False, normalize=False):
259
+ if normalize: # turn on this flag if input is [0,1] so it can be adjusted to [-1, +1]
260
+ in0 = 2 * in0 - 1
261
+ in1 = 2 * in1 - 1
262
+
263
+ in0_input, in1_input = (self.scaling_layer(in0), self.scaling_layer(in1)) if self.version=='0.1' else (in0, in1)
264
+ outs0, outs1 = self.net.forward(in0_input), self.net.forward(in1_input)
265
+ feats0, feats1, diffs = {}, {}, {}
266
+
267
+ for kk in range(self.L):
268
+ feats0[kk], feats1[kk] = normalize_tensor(outs0[kk]), normalize_tensor(outs1[kk])
269
+ diffs[kk] = (feats0[kk]-feats1[kk])**2
270
+
271
+ res = [mw_spatial_average(self.lins[kk](diffs[kk]), flow, keepdim=True) for kk in range(self.L)]
272
+
273
+ return torch.sum(torch.cat(res, 1), dim=(1,2,3), keepdims=False)
274
+
275
+
276
+
277
+
278
+
279
+ class Flolpips(nn.Module):
280
+ def __init__(self):
281
+ super(Flolpips, self).__init__()
282
+ self.loss_fn = FloLPIPS(net='alex',version='0.1')
283
+ self.flownet = PWCNet()
284
+
285
+ @torch.no_grad()
286
+ def forward(self, I0, I1, frame_dis, frame_ref):
287
+ """
288
+ args:
289
+ I0: first frame of the triplet, shape: [B, C, H, W]
290
+ I1: third frame of the triplet, shape: [B, C, H, W]
291
+ frame_dis: prediction of the intermediate frame, shape: [B, C, H, W]
292
+ frame_ref: ground-truth of the intermediate frame, shape: [B, C, H, W]
293
+ """
294
+ assert I0.size() == I1.size() == frame_dis.size() == frame_ref.size(), \
295
+ "the 4 input tensors should have same size"
296
+
297
+ flow_ref = self.flownet(frame_ref, I0)
298
+ flow_dis = self.flownet(frame_dis, I0)
299
+ flow_diff = flow_ref - flow_dis
300
+ flolpips_wrt_I0 = self.loss_fn.forward(frame_ref, frame_dis, flow_diff, normalize=True)
301
+
302
+ flow_ref = self.flownet(frame_ref, I1)
303
+ flow_dis = self.flownet(frame_dis, I1)
304
+ flow_diff = flow_ref - flow_dis
305
+ flolpips_wrt_I1 = self.loss_fn.forward(frame_ref, frame_dis, flow_diff, normalize=True)
306
+
307
+ flolpips = (flolpips_wrt_I0 + flolpips_wrt_I1) / 2
308
+ return flolpips
opensora/eval/flolpips/pretrained_networks.py ADDED
@@ -0,0 +1,180 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from collections import namedtuple
2
+ import torch
3
+ from torchvision import models as tv
4
+
5
+ class squeezenet(torch.nn.Module):
6
+ def __init__(self, requires_grad=False, pretrained=True):
7
+ super(squeezenet, self).__init__()
8
+ pretrained_features = tv.squeezenet1_1(pretrained=pretrained).features
9
+ self.slice1 = torch.nn.Sequential()
10
+ self.slice2 = torch.nn.Sequential()
11
+ self.slice3 = torch.nn.Sequential()
12
+ self.slice4 = torch.nn.Sequential()
13
+ self.slice5 = torch.nn.Sequential()
14
+ self.slice6 = torch.nn.Sequential()
15
+ self.slice7 = torch.nn.Sequential()
16
+ self.N_slices = 7
17
+ for x in range(2):
18
+ self.slice1.add_module(str(x), pretrained_features[x])
19
+ for x in range(2,5):
20
+ self.slice2.add_module(str(x), pretrained_features[x])
21
+ for x in range(5, 8):
22
+ self.slice3.add_module(str(x), pretrained_features[x])
23
+ for x in range(8, 10):
24
+ self.slice4.add_module(str(x), pretrained_features[x])
25
+ for x in range(10, 11):
26
+ self.slice5.add_module(str(x), pretrained_features[x])
27
+ for x in range(11, 12):
28
+ self.slice6.add_module(str(x), pretrained_features[x])
29
+ for x in range(12, 13):
30
+ self.slice7.add_module(str(x), pretrained_features[x])
31
+ if not requires_grad:
32
+ for param in self.parameters():
33
+ param.requires_grad = False
34
+
35
+ def forward(self, X):
36
+ h = self.slice1(X)
37
+ h_relu1 = h
38
+ h = self.slice2(h)
39
+ h_relu2 = h
40
+ h = self.slice3(h)
41
+ h_relu3 = h
42
+ h = self.slice4(h)
43
+ h_relu4 = h
44
+ h = self.slice5(h)
45
+ h_relu5 = h
46
+ h = self.slice6(h)
47
+ h_relu6 = h
48
+ h = self.slice7(h)
49
+ h_relu7 = h
50
+ vgg_outputs = namedtuple("SqueezeOutputs", ['relu1','relu2','relu3','relu4','relu5','relu6','relu7'])
51
+ out = vgg_outputs(h_relu1,h_relu2,h_relu3,h_relu4,h_relu5,h_relu6,h_relu7)
52
+
53
+ return out
54
+
55
+
56
+ class alexnet(torch.nn.Module):
57
+ def __init__(self, requires_grad=False, pretrained=True):
58
+ super(alexnet, self).__init__()
59
+ alexnet_pretrained_features = tv.alexnet(pretrained=pretrained).features
60
+ self.slice1 = torch.nn.Sequential()
61
+ self.slice2 = torch.nn.Sequential()
62
+ self.slice3 = torch.nn.Sequential()
63
+ self.slice4 = torch.nn.Sequential()
64
+ self.slice5 = torch.nn.Sequential()
65
+ self.N_slices = 5
66
+ for x in range(2):
67
+ self.slice1.add_module(str(x), alexnet_pretrained_features[x])
68
+ for x in range(2, 5):
69
+ self.slice2.add_module(str(x), alexnet_pretrained_features[x])
70
+ for x in range(5, 8):
71
+ self.slice3.add_module(str(x), alexnet_pretrained_features[x])
72
+ for x in range(8, 10):
73
+ self.slice4.add_module(str(x), alexnet_pretrained_features[x])
74
+ for x in range(10, 12):
75
+ self.slice5.add_module(str(x), alexnet_pretrained_features[x])
76
+ if not requires_grad:
77
+ for param in self.parameters():
78
+ param.requires_grad = False
79
+
80
+ def forward(self, X):
81
+ h = self.slice1(X)
82
+ h_relu1 = h
83
+ h = self.slice2(h)
84
+ h_relu2 = h
85
+ h = self.slice3(h)
86
+ h_relu3 = h
87
+ h = self.slice4(h)
88
+ h_relu4 = h
89
+ h = self.slice5(h)
90
+ h_relu5 = h
91
+ alexnet_outputs = namedtuple("AlexnetOutputs", ['relu1', 'relu2', 'relu3', 'relu4', 'relu5'])
92
+ out = alexnet_outputs(h_relu1, h_relu2, h_relu3, h_relu4, h_relu5)
93
+
94
+ return out
95
+
96
+ class vgg16(torch.nn.Module):
97
+ def __init__(self, requires_grad=False, pretrained=True):
98
+ super(vgg16, self).__init__()
99
+ vgg_pretrained_features = tv.vgg16(pretrained=pretrained).features
100
+ self.slice1 = torch.nn.Sequential()
101
+ self.slice2 = torch.nn.Sequential()
102
+ self.slice3 = torch.nn.Sequential()
103
+ self.slice4 = torch.nn.Sequential()
104
+ self.slice5 = torch.nn.Sequential()
105
+ self.N_slices = 5
106
+ for x in range(4):
107
+ self.slice1.add_module(str(x), vgg_pretrained_features[x])
108
+ for x in range(4, 9):
109
+ self.slice2.add_module(str(x), vgg_pretrained_features[x])
110
+ for x in range(9, 16):
111
+ self.slice3.add_module(str(x), vgg_pretrained_features[x])
112
+ for x in range(16, 23):
113
+ self.slice4.add_module(str(x), vgg_pretrained_features[x])
114
+ for x in range(23, 30):
115
+ self.slice5.add_module(str(x), vgg_pretrained_features[x])
116
+ if not requires_grad:
117
+ for param in self.parameters():
118
+ param.requires_grad = False
119
+
120
+ def forward(self, X):
121
+ h = self.slice1(X)
122
+ h_relu1_2 = h
123
+ h = self.slice2(h)
124
+ h_relu2_2 = h
125
+ h = self.slice3(h)
126
+ h_relu3_3 = h
127
+ h = self.slice4(h)
128
+ h_relu4_3 = h
129
+ h = self.slice5(h)
130
+ h_relu5_3 = h
131
+ vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3'])
132
+ out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3)
133
+
134
+ return out
135
+
136
+
137
+
138
+ class resnet(torch.nn.Module):
139
+ def __init__(self, requires_grad=False, pretrained=True, num=18):
140
+ super(resnet, self).__init__()
141
+ if(num==18):
142
+ self.net = tv.resnet18(pretrained=pretrained)
143
+ elif(num==34):
144
+ self.net = tv.resnet34(pretrained=pretrained)
145
+ elif(num==50):
146
+ self.net = tv.resnet50(pretrained=pretrained)
147
+ elif(num==101):
148
+ self.net = tv.resnet101(pretrained=pretrained)
149
+ elif(num==152):
150
+ self.net = tv.resnet152(pretrained=pretrained)
151
+ self.N_slices = 5
152
+
153
+ self.conv1 = self.net.conv1
154
+ self.bn1 = self.net.bn1
155
+ self.relu = self.net.relu
156
+ self.maxpool = self.net.maxpool
157
+ self.layer1 = self.net.layer1
158
+ self.layer2 = self.net.layer2
159
+ self.layer3 = self.net.layer3
160
+ self.layer4 = self.net.layer4
161
+
162
+ def forward(self, X):
163
+ h = self.conv1(X)
164
+ h = self.bn1(h)
165
+ h = self.relu(h)
166
+ h_relu1 = h
167
+ h = self.maxpool(h)
168
+ h = self.layer1(h)
169
+ h_conv2 = h
170
+ h = self.layer2(h)
171
+ h_conv3 = h
172
+ h = self.layer3(h)
173
+ h_conv4 = h
174
+ h = self.layer4(h)
175
+ h_conv5 = h
176
+
177
+ outputs = namedtuple("Outputs", ['relu1','conv2','conv3','conv4','conv5'])
178
+ out = outputs(h_relu1, h_conv2, h_conv3, h_conv4, h_conv5)
179
+
180
+ return out
opensora/eval/flolpips/pwcnet.py ADDED
@@ -0,0 +1,344 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+
3
+ import torch
4
+
5
+ import getopt
6
+ import math
7
+ import numpy
8
+ import os
9
+ import PIL
10
+ import PIL.Image
11
+ import sys
12
+
13
+ # try:
14
+ from .correlation import correlation # the custom cost volume layer
15
+ # except:
16
+ # sys.path.insert(0, './correlation'); import correlation # you should consider upgrading python
17
+ # end
18
+
19
+ ##########################################################
20
+
21
+ # assert(int(str('').join(torch.__version__.split('.')[0:2])) >= 13) # requires at least pytorch version 1.3.0
22
+
23
+ # torch.set_grad_enabled(False) # make sure to not compute gradients for computational performance
24
+
25
+ # torch.backends.cudnn.enabled = True # make sure to use cudnn for computational performance
26
+
27
+ # ##########################################################
28
+
29
+ # arguments_strModel = 'default' # 'default', or 'chairs-things'
30
+ # arguments_strFirst = './images/first.png'
31
+ # arguments_strSecond = './images/second.png'
32
+ # arguments_strOut = './out.flo'
33
+
34
+ # for strOption, strArgument in getopt.getopt(sys.argv[1:], '', [ strParameter[2:] + '=' for strParameter in sys.argv[1::2] ])[0]:
35
+ # if strOption == '--model' and strArgument != '': arguments_strModel = strArgument # which model to use
36
+ # if strOption == '--first' and strArgument != '': arguments_strFirst = strArgument # path to the first frame
37
+ # if strOption == '--second' and strArgument != '': arguments_strSecond = strArgument # path to the second frame
38
+ # if strOption == '--out' and strArgument != '': arguments_strOut = strArgument # path to where the output should be stored
39
+ # end
40
+
41
+ ##########################################################
42
+
43
+
44
+
45
+ def backwarp(tenInput, tenFlow):
46
+ backwarp_tenGrid = {}
47
+ backwarp_tenPartial = {}
48
+ if str(tenFlow.shape) not in backwarp_tenGrid:
49
+ tenHor = torch.linspace(-1.0 + (1.0 / tenFlow.shape[3]), 1.0 - (1.0 / tenFlow.shape[3]), tenFlow.shape[3]).view(1, 1, 1, -1).expand(-1, -1, tenFlow.shape[2], -1)
50
+ tenVer = torch.linspace(-1.0 + (1.0 / tenFlow.shape[2]), 1.0 - (1.0 / tenFlow.shape[2]), tenFlow.shape[2]).view(1, 1, -1, 1).expand(-1, -1, -1, tenFlow.shape[3])
51
+
52
+ backwarp_tenGrid[str(tenFlow.shape)] = torch.cat([ tenHor, tenVer ], 1).cuda()
53
+ # end
54
+
55
+ if str(tenFlow.shape) not in backwarp_tenPartial:
56
+ backwarp_tenPartial[str(tenFlow.shape)] = tenFlow.new_ones([ tenFlow.shape[0], 1, tenFlow.shape[2], tenFlow.shape[3] ])
57
+ # end
58
+
59
+ tenFlow = torch.cat([ tenFlow[:, 0:1, :, :] / ((tenInput.shape[3] - 1.0) / 2.0), tenFlow[:, 1:2, :, :] / ((tenInput.shape[2] - 1.0) / 2.0) ], 1)
60
+ tenInput = torch.cat([ tenInput, backwarp_tenPartial[str(tenFlow.shape)] ], 1)
61
+
62
+ tenOutput = torch.nn.functional.grid_sample(input=tenInput, grid=(backwarp_tenGrid[str(tenFlow.shape)] + tenFlow).permute(0, 2, 3, 1), mode='bilinear', padding_mode='zeros', align_corners=False)
63
+
64
+ tenMask = tenOutput[:, -1:, :, :]; tenMask[tenMask > 0.999] = 1.0; tenMask[tenMask < 1.0] = 0.0
65
+
66
+ return tenOutput[:, :-1, :, :] * tenMask
67
+ # end
68
+
69
+ ##########################################################
70
+
71
+ class Network(torch.nn.Module):
72
+ def __init__(self):
73
+ super(Network, self).__init__()
74
+
75
+ class Extractor(torch.nn.Module):
76
+ def __init__(self):
77
+ super(Extractor, self).__init__()
78
+
79
+ self.netOne = torch.nn.Sequential(
80
+ torch.nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, stride=2, padding=1),
81
+ torch.nn.LeakyReLU(inplace=False, negative_slope=0.1),
82
+ torch.nn.Conv2d(in_channels=16, out_channels=16, kernel_size=3, stride=1, padding=1),
83
+ torch.nn.LeakyReLU(inplace=False, negative_slope=0.1),
84
+ torch.nn.Conv2d(in_channels=16, out_channels=16, kernel_size=3, stride=1, padding=1),
85
+ torch.nn.LeakyReLU(inplace=False, negative_slope=0.1)
86
+ )
87
+
88
+ self.netTwo = torch.nn.Sequential(
89
+ torch.nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, stride=2, padding=1),
90
+ torch.nn.LeakyReLU(inplace=False, negative_slope=0.1),
91
+ torch.nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, stride=1, padding=1),
92
+ torch.nn.LeakyReLU(inplace=False, negative_slope=0.1),
93
+ torch.nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, stride=1, padding=1),
94
+ torch.nn.LeakyReLU(inplace=False, negative_slope=0.1)
95
+ )
96
+
97
+ self.netThr = torch.nn.Sequential(
98
+ torch.nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=2, padding=1),
99
+ torch.nn.LeakyReLU(inplace=False, negative_slope=0.1),
100
+ torch.nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1),
101
+ torch.nn.LeakyReLU(inplace=False, negative_slope=0.1),
102
+ torch.nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1),
103
+ torch.nn.LeakyReLU(inplace=False, negative_slope=0.1)
104
+ )
105
+
106
+ self.netFou = torch.nn.Sequential(
107
+ torch.nn.Conv2d(in_channels=64, out_channels=96, kernel_size=3, stride=2, padding=1),
108
+ torch.nn.LeakyReLU(inplace=False, negative_slope=0.1),
109
+ torch.nn.Conv2d(in_channels=96, out_channels=96, kernel_size=3, stride=1, padding=1),
110
+ torch.nn.LeakyReLU(inplace=False, negative_slope=0.1),
111
+ torch.nn.Conv2d(in_channels=96, out_channels=96, kernel_size=3, stride=1, padding=1),
112
+ torch.nn.LeakyReLU(inplace=False, negative_slope=0.1)
113
+ )
114
+
115
+ self.netFiv = torch.nn.Sequential(
116
+ torch.nn.Conv2d(in_channels=96, out_channels=128, kernel_size=3, stride=2, padding=1),
117
+ torch.nn.LeakyReLU(inplace=False, negative_slope=0.1),
118
+ torch.nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1),
119
+ torch.nn.LeakyReLU(inplace=False, negative_slope=0.1),
120
+ torch.nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1),
121
+ torch.nn.LeakyReLU(inplace=False, negative_slope=0.1)
122
+ )
123
+
124
+ self.netSix = torch.nn.Sequential(
125
+ torch.nn.Conv2d(in_channels=128, out_channels=196, kernel_size=3, stride=2, padding=1),
126
+ torch.nn.LeakyReLU(inplace=False, negative_slope=0.1),
127
+ torch.nn.Conv2d(in_channels=196, out_channels=196, kernel_size=3, stride=1, padding=1),
128
+ torch.nn.LeakyReLU(inplace=False, negative_slope=0.1),
129
+ torch.nn.Conv2d(in_channels=196, out_channels=196, kernel_size=3, stride=1, padding=1),
130
+ torch.nn.LeakyReLU(inplace=False, negative_slope=0.1)
131
+ )
132
+ # end
133
+
134
+ def forward(self, tenInput):
135
+ tenOne = self.netOne(tenInput)
136
+ tenTwo = self.netTwo(tenOne)
137
+ tenThr = self.netThr(tenTwo)
138
+ tenFou = self.netFou(tenThr)
139
+ tenFiv = self.netFiv(tenFou)
140
+ tenSix = self.netSix(tenFiv)
141
+
142
+ return [ tenOne, tenTwo, tenThr, tenFou, tenFiv, tenSix ]
143
+ # end
144
+ # end
145
+
146
+ class Decoder(torch.nn.Module):
147
+ def __init__(self, intLevel):
148
+ super(Decoder, self).__init__()
149
+
150
+ intPrevious = [ None, None, 81 + 32 + 2 + 2, 81 + 64 + 2 + 2, 81 + 96 + 2 + 2, 81 + 128 + 2 + 2, 81, None ][intLevel + 1]
151
+ intCurrent = [ None, None, 81 + 32 + 2 + 2, 81 + 64 + 2 + 2, 81 + 96 + 2 + 2, 81 + 128 + 2 + 2, 81, None ][intLevel + 0]
152
+
153
+ if intLevel < 6: self.netUpflow = torch.nn.ConvTranspose2d(in_channels=2, out_channels=2, kernel_size=4, stride=2, padding=1)
154
+ if intLevel < 6: self.netUpfeat = torch.nn.ConvTranspose2d(in_channels=intPrevious + 128 + 128 + 96 + 64 + 32, out_channels=2, kernel_size=4, stride=2, padding=1)
155
+ if intLevel < 6: self.fltBackwarp = [ None, None, None, 5.0, 2.5, 1.25, 0.625, None ][intLevel + 1]
156
+
157
+ self.netOne = torch.nn.Sequential(
158
+ torch.nn.Conv2d(in_channels=intCurrent, out_channels=128, kernel_size=3, stride=1, padding=1),
159
+ torch.nn.LeakyReLU(inplace=False, negative_slope=0.1)
160
+ )
161
+
162
+ self.netTwo = torch.nn.Sequential(
163
+ torch.nn.Conv2d(in_channels=intCurrent + 128, out_channels=128, kernel_size=3, stride=1, padding=1),
164
+ torch.nn.LeakyReLU(inplace=False, negative_slope=0.1)
165
+ )
166
+
167
+ self.netThr = torch.nn.Sequential(
168
+ torch.nn.Conv2d(in_channels=intCurrent + 128 + 128, out_channels=96, kernel_size=3, stride=1, padding=1),
169
+ torch.nn.LeakyReLU(inplace=False, negative_slope=0.1)
170
+ )
171
+
172
+ self.netFou = torch.nn.Sequential(
173
+ torch.nn.Conv2d(in_channels=intCurrent + 128 + 128 + 96, out_channels=64, kernel_size=3, stride=1, padding=1),
174
+ torch.nn.LeakyReLU(inplace=False, negative_slope=0.1)
175
+ )
176
+
177
+ self.netFiv = torch.nn.Sequential(
178
+ torch.nn.Conv2d(in_channels=intCurrent + 128 + 128 + 96 + 64, out_channels=32, kernel_size=3, stride=1, padding=1),
179
+ torch.nn.LeakyReLU(inplace=False, negative_slope=0.1)
180
+ )
181
+
182
+ self.netSix = torch.nn.Sequential(
183
+ torch.nn.Conv2d(in_channels=intCurrent + 128 + 128 + 96 + 64 + 32, out_channels=2, kernel_size=3, stride=1, padding=1)
184
+ )
185
+ # end
186
+
187
+ def forward(self, tenFirst, tenSecond, objPrevious):
188
+ tenFlow = None
189
+ tenFeat = None
190
+
191
+ if objPrevious is None:
192
+ tenFlow = None
193
+ tenFeat = None
194
+
195
+ tenVolume = torch.nn.functional.leaky_relu(input=correlation.FunctionCorrelation(tenFirst=tenFirst, tenSecond=tenSecond), negative_slope=0.1, inplace=False)
196
+
197
+ tenFeat = torch.cat([ tenVolume ], 1)
198
+
199
+ elif objPrevious is not None:
200
+ tenFlow = self.netUpflow(objPrevious['tenFlow'])
201
+ tenFeat = self.netUpfeat(objPrevious['tenFeat'])
202
+
203
+ tenVolume = torch.nn.functional.leaky_relu(input=correlation.FunctionCorrelation(tenFirst=tenFirst, tenSecond=backwarp(tenInput=tenSecond, tenFlow=tenFlow * self.fltBackwarp)), negative_slope=0.1, inplace=False)
204
+
205
+ tenFeat = torch.cat([ tenVolume, tenFirst, tenFlow, tenFeat ], 1)
206
+
207
+ # end
208
+
209
+ tenFeat = torch.cat([ self.netOne(tenFeat), tenFeat ], 1)
210
+ tenFeat = torch.cat([ self.netTwo(tenFeat), tenFeat ], 1)
211
+ tenFeat = torch.cat([ self.netThr(tenFeat), tenFeat ], 1)
212
+ tenFeat = torch.cat([ self.netFou(tenFeat), tenFeat ], 1)
213
+ tenFeat = torch.cat([ self.netFiv(tenFeat), tenFeat ], 1)
214
+
215
+ tenFlow = self.netSix(tenFeat)
216
+
217
+ return {
218
+ 'tenFlow': tenFlow,
219
+ 'tenFeat': tenFeat
220
+ }
221
+ # end
222
+ # end
223
+
224
+ class Refiner(torch.nn.Module):
225
+ def __init__(self):
226
+ super(Refiner, self).__init__()
227
+
228
+ self.netMain = torch.nn.Sequential(
229
+ torch.nn.Conv2d(in_channels=81 + 32 + 2 + 2 + 128 + 128 + 96 + 64 + 32, out_channels=128, kernel_size=3, stride=1, padding=1, dilation=1),
230
+ torch.nn.LeakyReLU(inplace=False, negative_slope=0.1),
231
+ torch.nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=2, dilation=2),
232
+ torch.nn.LeakyReLU(inplace=False, negative_slope=0.1),
233
+ torch.nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=4, dilation=4),
234
+ torch.nn.LeakyReLU(inplace=False, negative_slope=0.1),
235
+ torch.nn.Conv2d(in_channels=128, out_channels=96, kernel_size=3, stride=1, padding=8, dilation=8),
236
+ torch.nn.LeakyReLU(inplace=False, negative_slope=0.1),
237
+ torch.nn.Conv2d(in_channels=96, out_channels=64, kernel_size=3, stride=1, padding=16, dilation=16),
238
+ torch.nn.LeakyReLU(inplace=False, negative_slope=0.1),
239
+ torch.nn.Conv2d(in_channels=64, out_channels=32, kernel_size=3, stride=1, padding=1, dilation=1),
240
+ torch.nn.LeakyReLU(inplace=False, negative_slope=0.1),
241
+ torch.nn.Conv2d(in_channels=32, out_channels=2, kernel_size=3, stride=1, padding=1, dilation=1)
242
+ )
243
+ # end
244
+
245
+ def forward(self, tenInput):
246
+ return self.netMain(tenInput)
247
+ # end
248
+ # end
249
+
250
+ self.netExtractor = Extractor()
251
+
252
+ self.netTwo = Decoder(2)
253
+ self.netThr = Decoder(3)
254
+ self.netFou = Decoder(4)
255
+ self.netFiv = Decoder(5)
256
+ self.netSix = Decoder(6)
257
+
258
+ self.netRefiner = Refiner()
259
+
260
+ self.load_state_dict({ strKey.replace('module', 'net'): tenWeight for strKey, tenWeight in torch.hub.load_state_dict_from_url(url='http://content.sniklaus.com/github/pytorch-pwc/network-' + 'default' + '.pytorch').items() })
261
+ # end
262
+
263
+ def forward(self, tenFirst, tenSecond):
264
+ intWidth = tenFirst.shape[3]
265
+ intHeight = tenFirst.shape[2]
266
+
267
+ intPreprocessedWidth = int(math.floor(math.ceil(intWidth / 64.0) * 64.0))
268
+ intPreprocessedHeight = int(math.floor(math.ceil(intHeight / 64.0) * 64.0))
269
+
270
+ tenPreprocessedFirst = torch.nn.functional.interpolate(input=tenFirst, size=(intPreprocessedHeight, intPreprocessedWidth), mode='bilinear', align_corners=False)
271
+ tenPreprocessedSecond = torch.nn.functional.interpolate(input=tenSecond, size=(intPreprocessedHeight, intPreprocessedWidth), mode='bilinear', align_corners=False)
272
+
273
+ tenFirst = self.netExtractor(tenPreprocessedFirst)
274
+ tenSecond = self.netExtractor(tenPreprocessedSecond)
275
+
276
+
277
+ objEstimate = self.netSix(tenFirst[-1], tenSecond[-1], None)
278
+ objEstimate = self.netFiv(tenFirst[-2], tenSecond[-2], objEstimate)
279
+ objEstimate = self.netFou(tenFirst[-3], tenSecond[-3], objEstimate)
280
+ objEstimate = self.netThr(tenFirst[-4], tenSecond[-4], objEstimate)
281
+ objEstimate = self.netTwo(tenFirst[-5], tenSecond[-5], objEstimate)
282
+
283
+ tenFlow = objEstimate['tenFlow'] + self.netRefiner(objEstimate['tenFeat'])
284
+ tenFlow = 20.0 * torch.nn.functional.interpolate(input=tenFlow, size=(intHeight, intWidth), mode='bilinear', align_corners=False)
285
+ tenFlow[:, 0, :, :] *= float(intWidth) / float(intPreprocessedWidth)
286
+ tenFlow[:, 1, :, :] *= float(intHeight) / float(intPreprocessedHeight)
287
+
288
+ return tenFlow
289
+ # end
290
+ # end
291
+
292
+ netNetwork = None
293
+
294
+ ##########################################################
295
+
296
+ def estimate(tenFirst, tenSecond):
297
+ global netNetwork
298
+
299
+ if netNetwork is None:
300
+ netNetwork = Network().cuda().eval()
301
+ # end
302
+
303
+ assert(tenFirst.shape[1] == tenSecond.shape[1])
304
+ assert(tenFirst.shape[2] == tenSecond.shape[2])
305
+
306
+ intWidth = tenFirst.shape[2]
307
+ intHeight = tenFirst.shape[1]
308
+
309
+ assert(intWidth == 1024) # remember that there is no guarantee for correctness, comment this line out if you acknowledge this and want to continue
310
+ assert(intHeight == 436) # remember that there is no guarantee for correctness, comment this line out if you acknowledge this and want to continue
311
+
312
+ tenPreprocessedFirst = tenFirst.cuda().view(1, 3, intHeight, intWidth)
313
+ tenPreprocessedSecond = tenSecond.cuda().view(1, 3, intHeight, intWidth)
314
+
315
+ intPreprocessedWidth = int(math.floor(math.ceil(intWidth / 64.0) * 64.0))
316
+ intPreprocessedHeight = int(math.floor(math.ceil(intHeight / 64.0) * 64.0))
317
+
318
+ tenPreprocessedFirst = torch.nn.functional.interpolate(input=tenPreprocessedFirst, size=(intPreprocessedHeight, intPreprocessedWidth), mode='bilinear', align_corners=False)
319
+ tenPreprocessedSecond = torch.nn.functional.interpolate(input=tenPreprocessedSecond, size=(intPreprocessedHeight, intPreprocessedWidth), mode='bilinear', align_corners=False)
320
+
321
+ tenFlow = 20.0 * torch.nn.functional.interpolate(input=netNetwork(tenPreprocessedFirst, tenPreprocessedSecond), size=(intHeight, intWidth), mode='bilinear', align_corners=False)
322
+
323
+ tenFlow[:, 0, :, :] *= float(intWidth) / float(intPreprocessedWidth)
324
+ tenFlow[:, 1, :, :] *= float(intHeight) / float(intPreprocessedHeight)
325
+
326
+ return tenFlow[0, :, :, :].cpu()
327
+ # end
328
+
329
+ ##########################################################
330
+
331
+ # if __name__ == '__main__':
332
+ # tenFirst = torch.FloatTensor(numpy.ascontiguousarray(numpy.array(PIL.Image.open(arguments_strFirst))[:, :, ::-1].transpose(2, 0, 1).astype(numpy.float32) * (1.0 / 255.0)))
333
+ # tenSecond = torch.FloatTensor(numpy.ascontiguousarray(numpy.array(PIL.Image.open(arguments_strSecond))[:, :, ::-1].transpose(2, 0, 1).astype(numpy.float32) * (1.0 / 255.0)))
334
+
335
+ # tenOutput = estimate(tenFirst, tenSecond)
336
+
337
+ # objOutput = open(arguments_strOut, 'wb')
338
+
339
+ # numpy.array([ 80, 73, 69, 72 ], numpy.uint8).tofile(objOutput)
340
+ # numpy.array([ tenOutput.shape[2], tenOutput.shape[1] ], numpy.int32).tofile(objOutput)
341
+ # numpy.array(tenOutput.numpy().transpose(1, 2, 0), numpy.float32).tofile(objOutput)
342
+
343
+ # objOutput.close()
344
+ # end
opensora/eval/flolpips/utils.py ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import cv2
3
+ import torch
4
+
5
+
6
+ def normalize_tensor(in_feat,eps=1e-10):
7
+ norm_factor = torch.sqrt(torch.sum(in_feat**2,dim=1,keepdim=True))
8
+ return in_feat/(norm_factor+eps)
9
+
10
+ def l2(p0, p1, range=255.):
11
+ return .5*np.mean((p0 / range - p1 / range)**2)
12
+
13
+ def dssim(p0, p1, range=255.):
14
+ from skimage.measure import compare_ssim
15
+ return (1 - compare_ssim(p0, p1, data_range=range, multichannel=True)) / 2.
16
+
17
+ def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=255./2.):
18
+ image_numpy = image_tensor[0].cpu().float().numpy()
19
+ image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + cent) * factor
20
+ return image_numpy.astype(imtype)
21
+
22
+ def tensor2np(tensor_obj):
23
+ # change dimension of a tensor object into a numpy array
24
+ return tensor_obj[0].cpu().float().numpy().transpose((1,2,0))
25
+
26
+ def np2tensor(np_obj):
27
+ # change dimenion of np array into tensor array
28
+ return torch.Tensor(np_obj[:, :, :, np.newaxis].transpose((3, 2, 0, 1)))
29
+
30
+ def tensor2tensorlab(image_tensor,to_norm=True,mc_only=False):
31
+ # image tensor to lab tensor
32
+ from skimage import color
33
+
34
+ img = tensor2im(image_tensor)
35
+ img_lab = color.rgb2lab(img)
36
+ if(mc_only):
37
+ img_lab[:,:,0] = img_lab[:,:,0]-50
38
+ if(to_norm and not mc_only):
39
+ img_lab[:,:,0] = img_lab[:,:,0]-50
40
+ img_lab = img_lab/100.
41
+
42
+ return np2tensor(img_lab)
43
+
44
+ def read_frame_yuv2rgb(stream, width, height, iFrame, bit_depth, pix_fmt='420'):
45
+ if pix_fmt == '420':
46
+ multiplier = 1
47
+ uv_factor = 2
48
+ elif pix_fmt == '444':
49
+ multiplier = 2
50
+ uv_factor = 1
51
+ else:
52
+ print('Pixel format {} is not supported'.format(pix_fmt))
53
+ return
54
+
55
+ if bit_depth == 8:
56
+ datatype = np.uint8
57
+ stream.seek(iFrame*1.5*width*height*multiplier)
58
+ Y = np.fromfile(stream, dtype=datatype, count=width*height).reshape((height, width))
59
+
60
+ # read chroma samples and upsample since original is 4:2:0 sampling
61
+ U = np.fromfile(stream, dtype=datatype, count=(width//uv_factor)*(height//uv_factor)).\
62
+ reshape((height//uv_factor, width//uv_factor))
63
+ V = np.fromfile(stream, dtype=datatype, count=(width//uv_factor)*(height//uv_factor)).\
64
+ reshape((height//uv_factor, width//uv_factor))
65
+
66
+ else:
67
+ datatype = np.uint16
68
+ stream.seek(iFrame*3*width*height*multiplier)
69
+ Y = np.fromfile(stream, dtype=datatype, count=width*height).reshape((height, width))
70
+
71
+ U = np.fromfile(stream, dtype=datatype, count=(width//uv_factor)*(height//uv_factor)).\
72
+ reshape((height//uv_factor, width//uv_factor))
73
+ V = np.fromfile(stream, dtype=datatype, count=(width//uv_factor)*(height//uv_factor)).\
74
+ reshape((height//uv_factor, width//uv_factor))
75
+
76
+ if pix_fmt == '420':
77
+ yuv = np.empty((height*3//2, width), dtype=datatype)
78
+ yuv[0:height,:] = Y
79
+
80
+ yuv[height:height+height//4,:] = U.reshape(-1, width)
81
+ yuv[height+height//4:,:] = V.reshape(-1, width)
82
+
83
+ if bit_depth != 8:
84
+ yuv = (yuv/(2**bit_depth-1)*255).astype(np.uint8)
85
+
86
+ #convert to rgb
87
+ rgb = cv2.cvtColor(yuv, cv2.COLOR_YUV2RGB_I420)
88
+
89
+ else:
90
+ yvu = np.stack([Y,V,U],axis=2)
91
+ if bit_depth != 8:
92
+ yvu = (yvu/(2**bit_depth-1)*255).astype(np.uint8)
93
+ rgb = cv2.cvtColor(yvu, cv2.COLOR_YCrCb2RGB)
94
+
95
+ return rgb
opensora/eval/fvd/styleganv/fvd.py ADDED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import os
3
+ import math
4
+ import torch.nn.functional as F
5
+
6
+ # https://github.com/universome/fvd-comparison
7
+
8
+
9
+ def load_i3d_pretrained(device=torch.device('cpu')):
10
+ i3D_WEIGHTS_URL = "https://www.dropbox.com/s/ge9e5ujwgetktms/i3d_torchscript.pt"
11
+ filepath = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'i3d_torchscript.pt')
12
+ print(filepath)
13
+ if not os.path.exists(filepath):
14
+ print(f"preparing for download {i3D_WEIGHTS_URL}, you can download it by yourself.")
15
+ os.system(f"wget {i3D_WEIGHTS_URL} -O {filepath}")
16
+ i3d = torch.jit.load(filepath).eval().to(device)
17
+ i3d = torch.nn.DataParallel(i3d)
18
+ return i3d
19
+
20
+
21
+ def get_feats(videos, detector, device, bs=10):
22
+ # videos : torch.tensor BCTHW [0, 1]
23
+ detector_kwargs = dict(rescale=False, resize=False, return_features=True) # Return raw features before the softmax layer.
24
+ feats = np.empty((0, 400))
25
+ with torch.no_grad():
26
+ for i in range((len(videos)-1)//bs + 1):
27
+ feats = np.vstack([feats, detector(torch.stack([preprocess_single(video) for video in videos[i*bs:(i+1)*bs]]).to(device), **detector_kwargs).detach().cpu().numpy()])
28
+ return feats
29
+
30
+
31
+ def get_fvd_feats(videos, i3d, device, bs=10):
32
+ # videos in [0, 1] as torch tensor BCTHW
33
+ # videos = [preprocess_single(video) for video in videos]
34
+ embeddings = get_feats(videos, i3d, device, bs)
35
+ return embeddings
36
+
37
+
38
+ def preprocess_single(video, resolution=224, sequence_length=None):
39
+ # video: CTHW, [0, 1]
40
+ c, t, h, w = video.shape
41
+
42
+ # temporal crop
43
+ if sequence_length is not None:
44
+ assert sequence_length <= t
45
+ video = video[:, :sequence_length]
46
+
47
+ # scale shorter side to resolution
48
+ scale = resolution / min(h, w)
49
+ if h < w:
50
+ target_size = (resolution, math.ceil(w * scale))
51
+ else:
52
+ target_size = (math.ceil(h * scale), resolution)
53
+ video = F.interpolate(video, size=target_size, mode='bilinear', align_corners=False)
54
+
55
+ # center crop
56
+ c, t, h, w = video.shape
57
+ w_start = (w - resolution) // 2
58
+ h_start = (h - resolution) // 2
59
+ video = video[:, :, h_start:h_start + resolution, w_start:w_start + resolution]
60
+
61
+ # [0, 1] -> [-1, 1]
62
+ video = (video - 0.5) * 2
63
+
64
+ return video.contiguous()
65
+
66
+
67
+ """
68
+ Copy-pasted from https://github.com/cvpr2022-stylegan-v/stylegan-v/blob/main/src/metrics/frechet_video_distance.py
69
+ """
70
+ from typing import Tuple
71
+ from scipy.linalg import sqrtm
72
+ import numpy as np
73
+
74
+
75
+ def compute_stats(feats: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
76
+ mu = feats.mean(axis=0) # [d]
77
+ sigma = np.cov(feats, rowvar=False) # [d, d]
78
+ return mu, sigma
79
+
80
+
81
+ def frechet_distance(feats_fake: np.ndarray, feats_real: np.ndarray) -> float:
82
+ mu_gen, sigma_gen = compute_stats(feats_fake)
83
+ mu_real, sigma_real = compute_stats(feats_real)
84
+ m = np.square(mu_gen - mu_real).sum()
85
+ if feats_fake.shape[0]>1:
86
+ s, _ = sqrtm(np.dot(sigma_gen, sigma_real), disp=False) # pylint: disable=no-member
87
+ fid = np.real(m + np.trace(sigma_gen + sigma_real - s * 2))
88
+ else:
89
+ fid = np.real(m)
90
+ return float(fid)
opensora/eval/fvd/styleganv/i3d_torchscript.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:bec6519f66ea534e953026b4ae2c65553c17bf105611c746d904657e5860a5e2
3
+ size 51235320