Upload 244 files
This view is limited to 50 files because the commit contains too many changes; see the raw diff for the full list.
- .gitattributes +16 -0
- LICENSE +21 -0
- assets/13 00_00_00-00_00_30.gif +3 -0
- assets/5 00_00_00-00_00_30.gif +3 -0
- assets/6 00_00_00-00_00_30.gif +3 -0
- assets/7 00_00_00-00_00_30.gif +3 -0
- assets/dpvj8-y3ubn.gif +3 -0
- assets/framework.jpg +0 -0
- assets/i1ude-11d4e.gif +3 -0
- assets/kntw7-iuluy.gif +3 -0
- assets/nr2a2-oe6qj.gif +3 -0
- assets/ns4et-xj8ax.gif +3 -0
- assets/open-sora-plan.png +3 -0
- assets/ozg76-g1aqh.gif +3 -0
- assets/pvvm5-5hm65.gif +3 -0
- assets/rrdqk-puoud.gif +3 -0
- assets/we_want_you.jpg +0 -0
- assets/y70q9-y5tip.gif +3 -0
- docker/LICENSE +21 -0
- docker/README.md +87 -0
- docker/build_docker.png +0 -0
- docker/docker_build.sh +8 -0
- docker/docker_run.sh +45 -0
- docker/dockerfile.base +24 -0
- docker/packages.txt +3 -0
- docker/ports.txt +1 -0
- docker/postinstallscript.sh +3 -0
- docker/requirements.txt +40 -0
- docker/run_docker.png +0 -0
- docker/setup_env.sh +11 -0
- docs/Contribution_Guidelines.md +87 -0
- docs/Data.md +35 -0
- docs/EVAL.md +110 -0
- docs/Report-v1.0.0.md +131 -0
- docs/VQVAE.md +57 -0
- examples/get_latents_std.py +38 -0
- examples/prompt_list_0.txt +16 -0
- examples/rec_image.py +57 -0
- examples/rec_imvi_vae.py +159 -0
- examples/rec_video.py +120 -0
- examples/rec_video_ae.py +120 -0
- examples/rec_video_vae.py +275 -0
- opensora/__init__.py +1 -0
- opensora/dataset/__init__.py +99 -0
- opensora/dataset/extract_feature_dataset.py +64 -0
- opensora/dataset/feature_datasets.py +213 -0
- opensora/dataset/landscope.py +90 -0
- opensora/dataset/sky_datasets.py +128 -0
- opensora/dataset/t2v_datasets.py +111 -0
- opensora/dataset/transform.py +489 -0
.gitattributes
CHANGED
@@ -33,3 +33,19 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+assets/13[[:space:]]00_00_00-00_00_30.gif filter=lfs diff=lfs merge=lfs -text
+assets/5[[:space:]]00_00_00-00_00_30.gif filter=lfs diff=lfs merge=lfs -text
+assets/6[[:space:]]00_00_00-00_00_30.gif filter=lfs diff=lfs merge=lfs -text
+assets/7[[:space:]]00_00_00-00_00_30.gif filter=lfs diff=lfs merge=lfs -text
+assets/dpvj8-y3ubn.gif filter=lfs diff=lfs merge=lfs -text
+assets/i1ude-11d4e.gif filter=lfs diff=lfs merge=lfs -text
+assets/kntw7-iuluy.gif filter=lfs diff=lfs merge=lfs -text
+assets/nr2a2-oe6qj.gif filter=lfs diff=lfs merge=lfs -text
+assets/ns4et-xj8ax.gif filter=lfs diff=lfs merge=lfs -text
+assets/open-sora-plan.png filter=lfs diff=lfs merge=lfs -text
+assets/ozg76-g1aqh.gif filter=lfs diff=lfs merge=lfs -text
+assets/pvvm5-5hm65.gif filter=lfs diff=lfs merge=lfs -text
+assets/rrdqk-puoud.gif filter=lfs diff=lfs merge=lfs -text
+assets/y70q9-y5tip.gif filter=lfs diff=lfs merge=lfs -text
+opensora/models/captioner/caption_refiner/dataset/test_videos/video1.gif filter=lfs diff=lfs merge=lfs -text
+opensora/models/captioner/caption_refiner/dataset/test_videos/video2.gif filter=lfs diff=lfs merge=lfs -text
LICENSE
ADDED

MIT License

Copyright (c) 2024 PKU-YUAN's Group (袁粒课题组-北大信工)

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
assets/13 00_00_00-00_00_30.gif
ADDED
Git LFS Details
assets/5 00_00_00-00_00_30.gif
ADDED
Git LFS Details
assets/6 00_00_00-00_00_30.gif
ADDED
Git LFS Details
assets/7 00_00_00-00_00_30.gif
ADDED
Git LFS Details
assets/dpvj8-y3ubn.gif
ADDED
Git LFS Details
assets/framework.jpg
ADDED
assets/i1ude-11d4e.gif
ADDED
Git LFS Details
assets/kntw7-iuluy.gif
ADDED
Git LFS Details
assets/nr2a2-oe6qj.gif
ADDED
Git LFS Details
assets/ns4et-xj8ax.gif
ADDED
Git LFS Details
assets/open-sora-plan.png
ADDED
Git LFS Details
assets/ozg76-g1aqh.gif
ADDED
Git LFS Details
assets/pvvm5-5hm65.gif
ADDED
Git LFS Details
assets/rrdqk-puoud.gif
ADDED
Git LFS Details
assets/we_want_you.jpg
ADDED
assets/y70q9-y5tip.gif
ADDED
Git LFS Details
docker/LICENSE
ADDED

MIT License

Copyright (c) 2024 SimonLee

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
docker/README.md
ADDED

# Docker4ML

Useful docker scripts for ML development.
[https://github.com/SimonLeeGit/Docker4ML](https://github.com/SimonLeeGit/Docker4ML)

## Build Docker Image

```bash
bash docker_build.sh
```

![build_docker](build_docker.png)

## Run Docker Container as Development Environment

```bash
bash docker_run.sh
```

![run_docker](run_docker.png)

## Custom Docker Config

### Config [setup_env.sh](./setup_env.sh)

You can modify this file to customize your settings.

```bash
TAG=ml:dev
BASE_TAG=nvcr.io/nvidia/pytorch:23.12-py3
```

#### TAG

The tag of the docker image you build; you can set it to whatever you want.

#### BASE_TAG

The base docker image tag used to build your image; here we use NVIDIA PyTorch images.
You can check the available tags at [https://catalog.ngc.nvidia.com/orgs/nvidia/containers/pytorch/tags](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/pytorch/tags)

You can also use another docker image as the base, such as [ubuntu](https://hub.docker.com/_/ubuntu/tags).

#### USER_NAME

The user name used inside the docker container.

#### USER_PASSWD

The user password used inside the docker container.

### Config [requirements.txt](./requirements.txt)

You can add the python libraries you want installed by default here.

```txt
transformers==4.27.1
```

The base image already has some libraries installed by default; you can check them at [https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes/rel-24-01.html](https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes/rel-24-01.html)

### Config [packages.txt](./packages.txt)

You can add the apt-get packages you want installed by default here.

```txt
wget
curl
git
```

### Config [ports.txt](./ports.txt)

You can add the ports to expose for the docker container here.

```txt
-p 6006:6006
-p 8080:8080
```

### Config [postinstallscript.sh](./postinstallscript.sh)

You can add a custom script to run when building the docker image.

## Q&A

If you have any problems using it, please contact <simonlee235@gmail.com>.
docker/build_docker.png
ADDED
docker/docker_build.sh
ADDED

```bash
#!/usr/bin/env bash

WORK_DIR=$(dirname "$(readlink -f "$0")")
cd $WORK_DIR

source setup_env.sh

docker build -t $TAG --build-arg BASE_TAG=$BASE_TAG --build-arg USER_NAME=$USER_NAME --build-arg USER_PASSWD=$USER_PASSWD . -f dockerfile.base
```
docker/docker_run.sh
ADDED

```bash
#!/usr/bin/env bash

WORK_DIR=$(dirname "$(readlink -f "$0")")
source $WORK_DIR/setup_env.sh

RUNNING_IDS="$(docker ps --filter ancestor=$TAG --format "{{.ID}}")"

if [ -n "$RUNNING_IDS" ]; then
    # Initialize an array to hold the container IDs
    declare -a container_ids=($RUNNING_IDS)

    # Get the first container ID using array indexing
    ID=${container_ids[0]}

    # Print the first container ID
    echo ' '
    echo "The running container ID is: $ID, enter it!"
else
    echo ' '
    echo "No running containers found, starting a new one!"

    # Run a new docker container instance
    ID=$(docker run \
        --rm \
        --gpus all \
        -itd \
        --ipc=host \
        --ulimit memlock=-1 \
        --ulimit stack=67108864 \
        -e DISPLAY=$DISPLAY \
        -v /tmp/.X11-unix/:/tmp/.X11-unix/ \
        -v $PWD:/home/$USER_NAME/workspace \
        -w /home/$USER_NAME/workspace \
        $(cat $WORK_DIR/ports.txt) \
        $TAG)
fi

docker logs $ID

echo ' '
echo ' '
echo '========================================='
echo ' '

docker exec -it $ID bash
```
docker/dockerfile.base
ADDED

```dockerfile
ARG BASE_TAG
FROM ${BASE_TAG}
ARG USER_NAME=myuser
ARG USER_PASSWD=111111
ARG DEBIAN_FRONTEND=noninteractive

# Pre-install packages, pip install requirements and run post install script.
COPY packages.txt .
COPY requirements.txt .
COPY postinstallscript.sh .
RUN apt-get update && apt-get install -y sudo $(cat packages.txt)
RUN pip install --no-cache-dir -r requirements.txt
RUN bash postinstallscript.sh

# Create a new user and group using the username argument
RUN groupadd -r ${USER_NAME} && useradd -r -m -g ${USER_NAME} ${USER_NAME}
RUN echo "${USER_NAME}:${USER_PASSWD}" | chpasswd
RUN usermod -aG sudo ${USER_NAME}
USER ${USER_NAME}
ENV USER=${USER_NAME}
WORKDIR /home/${USER_NAME}/workspace

# Set the prompt to highlight the username
RUN echo "export PS1='\[\033[01;32m\]\u\[\033[00m\]@\[\033[01;34m\]\h\[\033[00m\]:\[\033[01;36m\]\w\[\033[00m\]\$'" >> /home/${USER_NAME}/.bashrc
```
docker/packages.txt
ADDED

```txt
wget
curl
git
```
docker/ports.txt
ADDED

```txt
-p 6006:6006
```
docker/postinstallscript.sh
ADDED

```bash
#!/usr/bin/env bash
# This script runs when the docker image is built.
```
docker/requirements.txt
ADDED

```txt
setuptools>=61.0
torch==2.0.1
torchvision==0.15.2
transformers==4.32.0
albumentations==1.4.0
av==11.0.0
decord==0.6.0
einops==0.3.0
fastapi==0.110.0
accelerate==0.21.0
gdown==5.1.0
h5py==3.10.0
idna==3.6
imageio==2.34.0
matplotlib==3.7.5
numpy==1.24.4
omegaconf==2.1.1
opencv-python==4.9.0.80
opencv-python-headless==4.9.0.80
pandas==2.0.3
pillow==10.2.0
pydub==0.25.1
pytorch-lightning==1.4.2
pytorchvideo==0.1.5
PyYAML==6.0.1
regex==2023.12.25
requests==2.31.0
scikit-learn==1.3.2
scipy==1.10.1
six==1.16.0
tensorboard==2.14.0
test-tube==0.7.5
timm==0.9.16
torchdiffeq==0.2.3
torchmetrics==0.5.0
tqdm==4.66.2
urllib3==2.2.1
uvicorn==0.27.1
diffusers==0.24.0
scikit-video==1.1.11
```
docker/run_docker.png
ADDED
docker/setup_env.sh
ADDED

```bash
# Docker tag for new build image
TAG=open_sora_plan:dev

# Base docker image tag used by docker build
BASE_TAG=nvcr.io/nvidia/pytorch:23.05-py3

# User name used in docker container
USER_NAME=developer

# User password used in docker container
USER_PASSWD=666666
```
docs/Contribution_Guidelines.md
ADDED

# Contributing to the Open-Sora Plan Community

The Open-Sora Plan open-source community is a collaborative initiative driven by the community, emphasizing a commitment to being free and void of exploitation. Organized spontaneously by community members, we invite you to contribute to the Open-Sora Plan open-source community and help elevate it to new heights!

## Submitting a Pull Request (PR)

As a contributor, before submitting your request, kindly follow these guidelines:

1. Start by checking the [Open-Sora Plan GitHub](https://github.com/PKU-YuanGroup/Open-Sora-Plan/pulls) to see if there are any open or closed pull requests related to your intended submission. Avoid duplicating existing work.

2. [Fork](https://github.com/PKU-YuanGroup/Open-Sora-Plan/fork) the [Open-Sora Plan](https://github.com/PKU-YuanGroup/Open-Sora-Plan) repository and clone your fork to your local machine.

```bash
git clone [your-forked-repository-url]
```

3. Add the original Open-Sora Plan repository as a remote to sync with the latest updates:

```bash
git remote add upstream https://github.com/PKU-YuanGroup/Open-Sora-Plan
```

4. Sync the code from the main repository to your local machine, and then push it back to your forked remote repository.

```bash
# Pull the latest code from the upstream branch
git fetch upstream

# Switch to the main branch
git checkout main

# Merge the updates from the upstream branch into main, synchronizing the local main branch with the upstream
git merge upstream/main

# Additionally, sync the local main branch to the remote branch of your forked repository
git push origin main
```

> Note: Sync the code from the main repository before each submission.

5. Create a branch in your forked repository for your changes, ensuring the branch name is meaningful.

```bash
git checkout -b my-docs-branch main
```

6. While making modifications and committing changes, adhere to our [Commit Message Format](#commit-message-format).

```bash
git commit -m "[docs]: xxxx"
```

7. Push your changes to your GitHub repository.

```bash
git push origin my-docs-branch
```

8. Submit a pull request to `Open-Sora-Plan:main` on the GitHub repository page.

## Commit Message Format

Commit messages must include both `<type>` and `<summary>` sections.

```bash
[<type>]: <summary>
   │          │
   │          └─⫸ Briefly describe your changes, without ending with a period.
   │
   └─⫸ Commit Type: |docs|feat|fix|refactor|
```

### Type

* **docs**: Modify or add documents.
* **feat**: Introduce a new feature.
* **fix**: Fix a bug.
* **refactor**: Restructure code, excluding new features or bug fixes.

### Summary

Describe modifications in English, without ending with a period.

> e.g., git commit -m "[docs]: add a contributing.md file"

This guideline is borrowed from [minisora](https://github.com/mini-sora/minisora). We sincerely appreciate the MiniSora authors for their awesome template.
docs/Data.md
ADDED

**We need more datasets**; please refer to [open-sora-Dataset](https://github.com/shaodong233/open-sora-Dataset) for details.


## Sky

This is an unconditional dataset. [Link](https://drive.google.com/open?id=1xWLiU-MBGN7MrsFHQm4_yXmfHBsMbJQo)

```
sky_timelapse
├── readme
├── sky_test
├── sky_train
├── test_videofolder.py
└── video_folder.py
```

## UCF101

We test the code with the UCF-101 dataset. You can download the necessary files from [here](https://www.crcv.ucf.edu/data/UCF101.php). The code assumes a `ucf101` directory with the following structure:

```
UCF-101/
    ApplyEyeMakeup/
        v1.avi
        ...
    ...
    YoYo/
        v1.avi
        ...
```


## Offline feature extraction
Coming soon...
docs/EVAL.md
ADDED

# Evaluate the generated videos quality

You can easily calculate the following video quality metrics, with batch-wise processing supported:
- **CLIP-SCORE**: uses the pretrained CLIP model to measure the cosine similarity between two modalities (see the sketch below).
- **FVD**: Fréchet Video Distance
- **SSIM**: structural similarity index measure
- **LPIPS**: learned perceptual image patch similarity
- **PSNR**: peak signal-to-noise ratio

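As a rough illustration of the idea behind CLIP-SCORE (not the exact `clip-score` implementation, which additionally handles batching, frame sampling, and score scaling), the score is the cosine similarity between CLIP's image and text embeddings; the file path and caption below are placeholders:

```python
# Minimal CLIP-similarity sketch using the openai/CLIP package.
import torch
import clip
from PIL import Image

device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/32", device=device)

image = preprocess(Image.open("path/to/image/cat.png")).unsqueeze(0).to(device)
text = clip.tokenize(["a photo of a cat"]).to(device)

with torch.no_grad():
    image_features = model.encode_image(image)
    text_features = model.encode_text(text)

# Normalize the embeddings, then take their cosine similarity as the score.
image_features = image_features / image_features.norm(dim=-1, keepdim=True)
text_features = text_features / text_features.norm(dim=-1, keepdim=True)
score = (image_features * text_features).sum(dim=-1)
print(f"CLIP similarity: {score.item():.4f}")
```
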
# Requirement
## Environment
- install PyTorch (torch>=1.7.1)
- install CLIP
```
pip install git+https://github.com/openai/CLIP.git
```
- install clip-score from PyPI
```
pip install clip-score
```
- other packages
```
pip install lpips
pip install scipy (scipy==1.7.3/1.9.3; if you use 1.11.3, **you will calculate a WRONG FVD VALUE!!!**)
pip install numpy
pip install pillow
pip install torchvision>=0.8.2
pip install ftfy
pip install regex
pip install tqdm
```
## Pretrained model
- FVD
  Before you calculate FVD, you should first download the FVD pre-trained model. You can manually download either of the following and put it into the FVD folder.
  - `i3d_torchscript.pt` from [here](https://www.dropbox.com/s/ge9e5ujwgetktms/i3d_torchscript.pt)
  - `i3d_pretrained_400.pt` from [here](https://onedrive.live.com/download?cid=78EEF3EB6AE7DBCB&resid=78EEF3EB6AE7DBCB%21199&authkey=AApKdFHPXzWLNyI)


## Other Notices
1. Make sure the pixel values of the videos are in [0, 1].
2. We average SSIM when images have 3 channels; SSIM is the only metric that is extremely sensitive to gray being compared with black-and-white.
3. Because the I3D model downsamples in the time dimension, `frames_num` should be > 10 when calculating FVD, so the FVD calculation begins from the 10th frame, as in the example above.
4. For grayscale videos, we replicate them to 3 channels.
5. Data input specifications for clip_score:
> - Image files: all images should be stored in a single directory. The image files can be in either .png or .jpg format.
>
> - Text files: all text data should be contained in plain text files in a separate directory. These text files should have the extension .txt.
>
> Note: The number of files in the image directory should be exactly equal to the number of files in the text directory. Additionally, the files in the image directory and the text directory should be paired by file name. For instance, if there is a cat.png in the image directory, there should be a corresponding cat.txt in the text directory.
>
> Directory structure example:
> ```
> ├── path/to/image
> │   ├── cat.png
> │   ├── dog.png
> │   └── bird.jpg
> └── path/to/text
>     ├── cat.txt
>     ├── dog.txt
>     └── bird.txt
> ```

6. Data input specifications for fvd, psnr, ssim, lpips:

> Directory structure example:
> ```
> ├── path/to/generated_image
> │   ├── cat.mp4
> │   ├── dog.mp4
> │   └── bird.mp4
> └── path/to/real_image
>     ├── cat.mp4
>     ├── dog.mp4
>     └── bird.mp4
> ```


# Usage

```
# Change the file paths and set frame_num, resolution, etc.

# clip_score (cross-modality)
cd opensora/eval
bash script/cal_clip_score.sh


# fvd
cd opensora/eval
bash script/cal_fvd.sh

# psnr
cd opensora/eval
bash eval/script/cal_psnr.sh


# ssim
cd opensora/eval
bash eval/script/cal_ssim.sh


# lpips
cd opensora/eval
bash eval/script/cal_lpips.sh
```

# Acknowledgement
The evaluation codebase refers to [clip-score](https://github.com/Taited/clip-score) and [common_metrics](https://github.com/JunyaoHu/common_metrics_on_video_quality).
docs/Report-v1.0.0.md
ADDED

# Report v1.0.0

In March 2024, we launched a plan called Open-Sora-Plan, which aims to reproduce OpenAI's [Sora](https://openai.com/sora) through an open-source framework. As a foundational open-source framework, it enables training of video generation models, including Unconditioned Video Generation, Class Video Generation, and Text-to-Video Generation.

**Today, we are thrilled to present Open-Sora-Plan v1.0.0, which significantly enhances video generation quality and text control capabilities.**

Compared with previous video generation models, Open-Sora-Plan v1.0.0 has several improvements:

1. **Efficient training and inference with CausalVideoVAE**. We apply a spatial-temporal compression of 4×8×8 to the videos (see the worked example after this list).
2. **Joint image-video training for better quality**. Our CausalVideoVAE considers the first frame as an image, allowing for the simultaneous encoding of both images and videos in a natural manner. This allows the diffusion model to grasp more spatial-visual details to improve visual quality.

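As a worked example of the 4×8×8 compression (assuming, as the training sizes below suggest, that the first frame is encoded on its own and each subsequent group of 4 frames shares one latent frame), a 65×512×512 clip maps to a latent grid of

$$\left(\frac{65 - 1}{4} + 1\right) \times \frac{512}{8} \times \frac{512}{8} = 17 \times 64 \times 64.$$
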
### Open-Source Release
We open-source the Open-Sora-Plan to facilitate future development of video generation in the community. Code, data, and models will be made publicly available.
- Demo: Hugging Face demo [here](https://huggingface.co/spaces/LanguageBind/Open-Sora-Plan-v1.0.0). 🤝 Enjoy the [![Replicate demo and cloud API](https://replicate.com/camenduru/open-sora-plan-512x512/badge)](https://replicate.com/camenduru/open-sora-plan-512x512) and [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/camenduru/Open-Sora-Plan-jupyter/blob/main/Open_Sora_Plan_jupyter.ipynb), created by [@camenduru](https://github.com/camenduru), who generously supports our research!
- Code: all training scripts and sample scripts.
- Model: both the Diffusion Model and CausalVideoVAE [here](https://huggingface.co/LanguageBind/Open-Sora-Plan-v1.0.0).
- Data: both raw videos and captions [here](https://huggingface.co/datasets/LanguageBind/Open-Sora-Plan-v1.0.0).

## Gallery

Open-Sora-Plan v1.0.0 supports joint training of images and videos. Here, we present the capabilities of video/image reconstruction and generation:

### CausalVideoVAE Reconstruction

**Video reconstruction** at 720×1280. Since GitHub can't host large videos, we put them here: [1](https://streamable.com/gqojal), [2](https://streamable.com/6nu3j8).

https://github.com/PKU-YuanGroup/Open-Sora-Plan/assets/88202804/c100bb02-2420-48a3-9d7b-4608a41f14aa

https://github.com/PKU-YuanGroup/Open-Sora-Plan/assets/88202804/8aa8f587-d9f1-4e8b-8a82-d3bf9ba91d68

**Image reconstruction** at 1536×1024.

<img src="https://github.com/PKU-YuanGroup/Open-Sora-Plan/assets/88202804/1684c3ec-245d-4a60-865c-b8946d788eb9" width="45%"/> <img src="https://github.com/PKU-YuanGroup/Open-Sora-Plan/assets/88202804/46ef714e-3e5b-492c-aec4-3793cb2260b5" width="45%"/>

**Text-to-Video Generation** with 65×1024×1024

https://github.com/PKU-YuanGroup/Open-Sora-Plan/assets/62638829/2641a8aa-66ac-4cda-8279-86b2e6a6e011

**Text-to-Video Generation** with 65×512×512

https://github.com/PKU-YuanGroup/Open-Sora-Plan/assets/62638829/37e3107e-56b3-4b09-8920-fa1d8d144b9e


**Text-to-Image Generation** with 512×512

![download](https://github.com/PKU-YuanGroup/Open-Sora-Plan/assets/62638829/491d72bc-e762-48ff-bdcc-cc69350f56d6)

## Detailed Technical Report

### CausalVideoVAE

#### Model Structure

![image](https://github.com/PKU-YuanGroup/Open-Sora-Plan/assets/88202804/e3c8b35d-a217-4d96-b2e9-5c248a2859c8)

The CausalVideoVAE architecture inherits from the [Stable-Diffusion Image VAE](https://github.com/CompVis/stable-diffusion/tree/main). To ensure that the pretrained weights of the Image VAE can be seamlessly applied to the Video VAE, the model structure has been designed as follows:

1. **CausalConv3D**: Converting Conv2D to CausalConv3D enables joint training of image and video data. CausalConv3D applies a special treatment to the first frame, as it does not have access to subsequent frames. For more specific details, please refer to https://github.com/PKU-YuanGroup/Open-Sora-Plan/pull/145

2. **Initialization**: There are two common [methods](https://github.com/hassony2/inflated_convnets_pytorch/blob/master/src/inflate.py#L5) to expand Conv2D to Conv3D: average initialization and center initialization. We instead employ a specific initialization method (tail initialization). This initialization method ensures that, without any training, the model is capable of directly reconstructing images, and even videos. A minimal sketch of both ideas follows this list.

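The sketch below illustrates the two ideas above; it is not the repository's implementation (see the PR linked in item 1 for the real code). It shows a Conv3D whose temporal padding is applied only at the front, plus a tail-initialization helper that copies a pretrained 2D kernel into the last temporal slice of the 3D kernel.

```python
# Minimal CausalConv3D + tail-initialization sketch (illustrative only).
# Temporal padding is applied only at the front (here with zeros; the real
# implementation may replicate the first frame instead), so frame t never
# sees frames later than t, and a single image (T=1) is a valid input.
import torch
import torch.nn as nn
import torch.nn.functional as F

class CausalConv3d(nn.Module):
    def __init__(self, in_ch, out_ch, kernel_size=(3, 3, 3), stride=1):
        super().__init__()
        kt, kh, kw = kernel_size
        self.time_pad = kt - 1                               # pad only in front of time
        self.space_pad = (kw // 2, kw // 2, kh // 2, kh // 2)
        self.conv = nn.Conv3d(in_ch, out_ch, kernel_size, stride=stride)

    def forward(self, x):                                    # x: (B, C, T, H, W)
        x = F.pad(x, self.space_pad)                         # symmetric spatial padding
        x = F.pad(x, (0, 0, 0, 0, self.time_pad, 0))         # causal temporal padding
        return self.conv(x)

def tail_init_from_conv2d(conv3d: nn.Conv3d, conv2d: nn.Conv2d):
    """Copy a 2D kernel into the *last* temporal slice of a matching 3D kernel."""
    with torch.no_grad():
        conv3d.weight.zero_()
        conv3d.weight[:, :, -1] = conv2d.weight              # (out, in, kh, kw)
        if conv2d.bias is not None and conv3d.bias is not None:
            conv3d.bias.copy_(conv2d.bias)
```

With causal front padding of `kt - 1` frames, the last temporal tap of the kernel always multiplies the current frame, which is why placing the 2D weights there lets the untrained layer reproduce the Image VAE frame by frame.
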
#### Training Details

<img width="833" alt="image" src="https://github.com/PKU-YuanGroup/Open-Sora-Plan/assets/88202804/9ffb6dc4-23f6-4274-a066-bbebc7522a14">

We present the loss curves for the two initialization methods at 17×256×256. The yellow curve represents the loss with tail initialization, while the blue curve corresponds to the loss with center initialization. As shown in the graph, tail initialization performs better on the loss curve. Additionally, we found that center initialization leads to error accumulation, causing collapse over extended durations.

#### Inference Tricks
Although the VAE is frozen during diffusion training, we still find the cost of running CausalVideoVAE hard to afford. In our case, with 80 GB of GPU memory, we can only infer a video of either 256×512×512 or 32×1024×1024 resolution using half precision, which limits our ability to scale up to longer and higher-resolution videos. Therefore, we adopt tiled convolution, which allows us to infer videos of arbitrary duration or resolution with nearly constant memory usage.

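A simplified sketch of the idea behind tiled inference is shown below. It assumes a `vae.decode` that maps a latent tile of shape (B, C, t, h, w) to pixels with an 8× spatial upsampling factor, and it simply averages overlapping regions rather than using the smooth seam blending of the actual implementation.

```python
# Decode a large latent grid tile by tile so peak memory stays roughly constant.
import torch

def tiled_decode(vae, latents, tile=64, overlap=16, scale=8):
    b, c, t, h, w = latents.shape
    stride = tile - overlap
    out = None
    weight = None
    for i in range(0, h, stride):
        for j in range(0, w, stride):
            tile_lat = latents[:, :, :, i:i + tile, j:j + tile]
            dec = vae.decode(tile_lat)          # (B, 3, T_out, tile*scale, tile*scale)
            if out is None:
                out = torch.zeros(b, dec.shape[1], dec.shape[2], h * scale, w * scale,
                                  device=dec.device, dtype=dec.dtype)
                weight = torch.zeros_like(out)
            oi, oj = i * scale, j * scale
            out[:, :, :, oi:oi + dec.shape[-2], oj:oj + dec.shape[-1]] += dec
            weight[:, :, :, oi:oi + dec.shape[-2], oj:oj + dec.shape[-1]] += 1
    # Average wherever tiles overlap.
    return out / weight
```
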
### Data Construction
We define a high-quality video dataset based on two core principles: (1) no content-unrelated watermarks; (2) high-quality and dense captions.

**For principle 1**, we crawled approximately 40,000 videos from open-source websites under the CC0 license. Specifically, we obtained 1,244 videos from [mixkit](https://mixkit.co/), 7,408 videos from [pexels](https://www.pexels.com/), and 31,617 videos from [pixabay](https://pixabay.com/). These videos adhere to the principle of having no content-unrelated watermarks. Using the scene-transition and clipping script provided by [Panda70M](https://github.com/snap-research/Panda-70M/blob/main/splitting/README.md), we divided these videos into approximately 434,000 video clips. In fact, based on our clipping results, 99% of the videos obtained from these online sources contain a single scene. Additionally, we observed that over 60% of the crawled data comprises landscape videos.

**For principle 2**, it is challenging to directly crawl a large quantity of high-quality dense captions from the internet. Therefore, we utilize a mature image-captioning model to obtain high-quality dense captions. We conducted ablation experiments on two multimodal large models: [ShareGPT4V-Captioner-7B](https://github.com/InternLM/InternLM-XComposer/blob/main/projects/ShareGPT4V/README.md) and [LLaVA-1.6-34B](https://github.com/haotian-liu/LLaVA). The former is specifically designed for caption generation, while the latter is a general-purpose multimodal large model. Our ablation experiments found them comparable in performance. However, there is a significant difference in their inference speed on the A800 GPU: 40 s/it at a batch size of 12 for ShareGPT4V-Captioner-7B, versus 15 s/it at a batch size of 1 for LLaVA-1.6-34B. We open-source all annotations [here](https://huggingface.co/datasets/LanguageBind/Open-Sora-Plan-v1.0.0).

### Training Diffusion Model
Similar to previous work, we employ a multi-stage cascaded training approach, which consumes a total of 2,048 A800 GPU hours. We found that joint training with images significantly accelerates model convergence and enhances visual perception, aligning with the findings of [Latte](https://github.com/Vchitect/Latte). Below is our training card:

| Name | Stage 1 | Stage 2 | Stage 3 | Stage 4 |
|---|---|---|---|---|
| Training Video Size | 17×256×256 | 65×256×256 | 65×512×512 | 65×1024×1024 |
| Compute (#A800 GPU × #Hours) | 32 × 40 | 32 × 18 | 32 × 6 | Under training |
| Checkpoint | [HF](https://huggingface.co/LanguageBind/Open-Sora-Plan-v1.0.0/tree/main/17x256x256) | [HF](https://huggingface.co/LanguageBind/Open-Sora-Plan-v1.0.0/tree/main/65x256x256) | [HF](https://huggingface.co/LanguageBind/Open-Sora-Plan-v1.0.0/tree/main/65x512x512) | Under training |
| Log | [wandb](https://api.wandb.ai/links/linbin/p6n3evym) | [wandb](https://api.wandb.ai/links/linbin/t2g53sew) | [wandb](https://api.wandb.ai/links/linbin/uomr0xzb) | Under training |
| Training Data | ~40k videos | ~40k videos | ~40k videos | ~40k videos |

## Next Release Preview
### CausalVideoVAE
Currently, the released version of CausalVideoVAE (v1.0.0) has two main drawbacks: **motion blurring** and a **gridding effect**. We have made a series of improvements to CausalVideoVAE to reduce its inference cost and enhance its performance. We are currently referring to this enhanced version as the "preview version," which will be released in the next update. A preview reconstruction follows:

**1-minute video reconstruction at 720×1280.** Since GitHub can't host such a large video, we put it here: [original video](https://streamable.com/u4onbb), [reconstructed video](https://streamable.com/qt8ncc).

https://github.com/PKU-YuanGroup/Open-Sora-Plan/assets/88202804/cdcfa9a3-4de0-42d4-94c0-0669710e407b

We randomly selected 100 samples from the validation set of Kinetics-400 for evaluation, and the results are presented in the following table:

|  | SSIM↑ | LPIPS↓ | PSNR↑ | FLOLPIPS↓ |
|---|---|---|---|---|
| v1.0.0 | 0.829 | 0.106 | 27.171 | 0.119 |
| Preview | 0.877 | 0.064 | 29.695 | 0.070 |

#### Motion Blurring

| **v1.0.0** | **Preview** |
| --- | --- |
| ![6862cae0-b1b6-48d1-bd11-84348cf42b42](https://github.com/PKU-YuanGroup/Open-Sora-Plan/assets/88202804/f815636f-fb38-4891-918b-50b1f9aa086d) | ![9189da06-ef2c-42e6-ad34-bd702a6f538e](https://github.com/PKU-YuanGroup/Open-Sora-Plan/assets/88202804/1e413f50-a785-485a-9851-a1449f952f1c) |

#### Gridding Effect

| **v1.0.0** | **Preview** |
| --- | --- |
| ![img](https://github.com/PKU-YuanGroup/Open-Sora-Plan/assets/88202804/7fec5bed-3c83-4ee9-baef-4a3dacafc658) | ![img](https://github.com/PKU-YuanGroup/Open-Sora-Plan/assets/88202804/4f41b432-a3ef-484e-a492-8afd8a691bf7) |

### Data Construction

**Data source**. As mentioned earlier, over 60% of our dataset consists of landscape videos. This implies that our ability to generate videos in other domains is limited. However, most of the current large-scale open-source datasets are primarily obtained through web scraping from platforms like YouTube. While these datasets provide a vast quantity of videos, we have concerns about the quality of the videos themselves. Therefore, we will continue to collect high-quality datasets and also welcome recommendations from the open-source community.

**Caption generation pipeline**. As video duration increases, we need to consider more efficient methods for video caption generation instead of relying solely on large multimodal image models. We are currently developing a new video caption generation pipeline that provides robust support for long videos. We are excited to share more details with you in the near future. Stay tuned!

### Training Diffusion Model
Although v1.0.0 has shown promising results, we acknowledge that we still have a long way to go to reach the level of Sora. In our upcoming work, we will primarily focus on three aspects:

1. **Training support for dynamic resolution and duration**: We aim to develop techniques that enable training models with varying resolutions and durations, allowing for more flexible and adaptable training processes.

2. **Support for longer video generation**: We will explore methods to extend the generation capabilities of our models, enabling them to produce longer videos beyond the current limitations.

3. **Enhanced conditional control**: We seek to enhance the conditional control capabilities of our models, providing users with more options and control over the generated videos.

Furthermore, through careful observation of the generated videos, we have noticed the presence of some non-physiological speckles or abnormal flow. This can be attributed to the limited performance of CausalVideoVAE, as mentioned earlier. In future experiments, we plan to retrain a diffusion model using a more powerful version of CausalVideoVAE to address these issues.
docs/VQVAE.md
ADDED

# VQVAE Documentation

# Introduction

The Vector Quantized Variational AutoEncoder (VQ-VAE) is a type of autoencoder that uses a discrete latent representation. It is particularly useful for tasks that require discrete latent variables, such as text-to-speech and video generation.

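For readers unfamiliar with the mechanism, the sketch below shows the vector-quantization step at the heart of any VQ-VAE. It is a generic illustration, not the `VideoGPTVQVAE` code: each encoder output vector is snapped to its nearest codebook entry, and the straight-through estimator lets gradients flow back to the encoder.

```python
# Generic vector-quantization layer (illustrative sketch).
import torch
import torch.nn as nn

class VectorQuantizer(nn.Module):
    def __init__(self, num_codes=1024, dim=256):
        super().__init__()
        self.codebook = nn.Embedding(num_codes, dim)
        nn.init.uniform_(self.codebook.weight, -1 / num_codes, 1 / num_codes)

    def forward(self, z):                          # z: (N, dim) encoder outputs
        # Squared distances to every codebook entry: (N, num_codes)
        d = (z.pow(2).sum(1, keepdim=True)
             - 2 * z @ self.codebook.weight.t()
             + self.codebook.weight.pow(2).sum(1))
        codes = d.argmin(dim=1)                    # nearest-neighbour indices
        z_q = self.codebook(codes)                 # quantized vectors
        # Straight-through estimator: copy gradients from z_q back to z.
        z_q = z + (z_q - z).detach()
        return z_q, codes
```
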
# Usage

## Initialization

To initialize a VQVAE model, you can use the `VideoGPTVQVAE` class. This class is part of the `opensora.models.ae` module.

```python
from opensora.models.ae import VideoGPTVQVAE

vqvae = VideoGPTVQVAE()
```

### Training

To train the VQVAE model, you can use the `train_videogpt.sh` script. This script will train the model using the parameters specified in the script.

```bash
bash scripts/videogpt/train_videogpt.sh
```

### Loading Pretrained Models

You can load a pretrained model using the `download_and_load_model` method. This method will download the checkpoint file and load the model.

```python
vqvae = VideoGPTVQVAE.download_and_load_model("bair_stride4x2x2")
```

Alternatively, you can load a model from a checkpoint using the `load_from_checkpoint` method.

```python
vqvae = VQVAEModel.load_from_checkpoint("results/VQVAE/checkpoint-1000")
```

### Encoding and Decoding

You can encode a video using the `encode` method. This method returns the encodings and embeddings of the video.

```python
encodings, embeddings = vqvae.encode(x_vae, include_embeddings=True)
```

You can reconstruct a video from its encodings using the `decode` method.

```python
video_recon = vqvae.decode(encodings)
```

## Testing

You can test the VQVAE model by reconstructing a video. The `examples/rec_video.py` script provides an example of how to do this.
examples/get_latents_std.py
ADDED

```python
import torch
from torch.utils.data import DataLoader, Subset
import sys
sys.path.append(".")
from opensora.models.ae.videobase import CausalVAEModel, CausalVAEDataset

num_workers = 4
batch_size = 12

torch.manual_seed(0)
torch.set_grad_enabled(False)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

pretrained_model_name_or_path = 'results/causalvae/checkpoint-26000'
data_path = '/remote-home1/dataset/UCF-101'
video_num_frames = 17
resolution = 128
sample_rate = 10

vae = CausalVAEModel.load_from_checkpoint(pretrained_model_name_or_path)
vae.to(device)

dataset = CausalVAEDataset(data_path, sequence_length=video_num_frames, resolution=resolution, sample_rate=sample_rate)
subset_indices = list(range(1000))
subset_dataset = Subset(dataset, subset_indices)
loader = DataLoader(subset_dataset, batch_size=8, pin_memory=True)

# Encode a subset of clips and estimate the latent standard deviation,
# whose inverse is the scaling factor used to normalize latents.
all_latents = []
for video_data in loader:
    video_data = video_data['video'].to(device)
    latents = vae.encode(video_data).sample()
    all_latents.append(latents.cpu())

all_latents_tensor = torch.cat(all_latents)
std = all_latents_tensor.std().item()
normalizer = 1 / std
print(f'{normalizer = }')
```
examples/prompt_list_0.txt
ADDED

```txt
A quiet beach at dawn, the waves gently lapping at the shore and the sky painted in pastel hues.
A quiet beach at dawn, the waves softly lapping at the shore, pink and orange hues painting the sky, offering a moment of solitude and reflection.
The majestic beauty of a waterfall cascading down a cliff into a serene lake.
Sunset over the sea.
a cat wearing sunglasses and working as a lifeguard at pool.
Slow pan upward of blazing oak fire in an indoor fireplace.
Yellow and black tropical fish dart through the sea.
a serene winter scene in a forest. The forest is blanketed in a thick layer of snow, which has settled on the branches of the trees, creating a canopy of white. The trees, a mix of evergreens and deciduous, stand tall and silent, their forms partially obscured by the snow. The ground is a uniform white, with no visible tracks or signs of human activity. The sun is low in the sky, casting a warm glow that contrasts with the cool tones of the snow. The light filters through the trees, creating a soft, diffused illumination that highlights the texture of the snow and the contours of the trees. The overall style of the scene is naturalistic, with a focus on the tranquility and beauty of the winter landscape.
a dynamic interaction between the ocean and a large rock. The rock, with its rough texture and jagged edges, is partially submerged in the water, suggesting it is a natural feature of the coastline. The water around the rock is in motion, with white foam and waves crashing against the rock, indicating the force of the ocean's movement. The background is a vast expanse of the ocean, with small ripples and waves, suggesting a moderate sea state. The overall style of the scene is a realistic depiction of a natural landscape, with a focus on the interplay between the rock and the water.
A serene waterfall cascading down moss-covered rocks, its soothing sound creating a harmonious symphony with nature.
A soaring drone footage captures the majestic beauty of a coastal cliff, its red and yellow stratified rock faces rich in color and against the vibrant turquoise of the sea. Seabirds can be seen taking flight around the cliff's precipices. As the drone slowly moves from different angles, the changing sunlight casts shifting shadows that highlight the rugged textures of the cliff and the surrounding calm sea. The water gently laps at the rock base and the greenery that clings to the top of the cliff, and the scene gives a sense of peaceful isolation at the fringes of the ocean. The video captures the essence of pristine natural beauty untouched by human structures.
The video captures the majestic beauty of a waterfall cascading down a cliff into a serene lake. The waterfall, with its powerful flow, is the central focus of the video. The surrounding landscape is lush and green, with trees and foliage adding to the natural beauty of the scene. The camera angle provides a bird's eye view of the waterfall, allowing viewers to appreciate the full height and grandeur of the waterfall. The video is a stunning representation of nature's power and beauty.
A vibrant scene of a snowy mountain landscape. The sky is filled with a multitude of colorful hot air balloons, each floating at different heights, creating a dynamic and lively atmosphere. The balloons are scattered across the sky, some closer to the viewer, others further away, adding depth to the scene. Below, the mountainous terrain is blanketed in a thick layer of snow, with a few patches of bare earth visible here and there. The snow-covered mountains provide a stark contrast to the colorful balloons, enhancing the visual appeal of the scene.
A serene underwater scene featuring a sea turtle swimming through a coral reef. The turtle, with its greenish-brown shell, is the main focus of the video, swimming gracefully towards the right side of the frame. The coral reef, teeming with life, is visible in the background, providing a vibrant and colorful backdrop to the turtle's journey. Several small fish, darting around the turtle, add a sense of movement and dynamism to the scene.
A snowy forest landscape with a dirt road running through it. The road is flanked by trees covered in snow, and the ground is also covered in snow. The sun is shining, creating a bright and serene atmosphere. The road appears to be empty, and there are no people or animals visible in the video. The style of the video is a natural landscape shot, with a focus on the beauty of the snowy forest and the peacefulness of the road.
The dynamic movement of tall, wispy grasses swaying in the wind. The sky above is filled with clouds, creating a dramatic backdrop. The sunlight pierces through the clouds, casting a warm glow on the scene. The grasses are a mix of green and brown, indicating a change in seasons. The overall style of the video is naturalistic, capturing the beauty of the landscape in a realistic manner. The focus is on the grasses and their movement, with the sky serving as a secondary element. The video does not contain any human or animal elements.
```
examples/rec_image.py
ADDED

```python
import sys
sys.path.append(".")
from PIL import Image
import torch
from torchvision.transforms import ToTensor, Compose, Resize, Normalize
from torch.nn import functional as F
from opensora.models.ae.videobase import CausalVAEModel
import argparse
import numpy as np

def preprocess(video_data: torch.Tensor, short_size: int = 128) -> torch.Tensor:
    transform = Compose(
        [
            ToTensor(),
            Normalize((0.5), (0.5)),
            Resize(size=short_size),
        ]
    )
    outputs = transform(video_data)
    outputs = outputs.unsqueeze(0).unsqueeze(2)
    return outputs

def main(args: argparse.Namespace):
    image_path = args.image_path
    resolution = args.resolution
    device = args.device

    vqvae = CausalVAEModel.load_from_checkpoint(args.ckpt)
    vqvae.eval()
    vqvae = vqvae.to(device)

    with torch.no_grad():
        x_vae = preprocess(Image.open(image_path), resolution)
        x_vae = x_vae.to(device)
        latents = vqvae.encode(x_vae)
        recon = vqvae.decode(latents.sample())
        x = recon[0, :, 0, :, :]
        x = x.squeeze()
        x = x.detach().cpu().numpy()
        x = np.clip(x, -1, 1)
        x = (x + 1) / 2
        x = (255 * x).astype(np.uint8)
        x = x.transpose(1, 2, 0)
        image = Image.fromarray(x)
        image.save(args.rec_path)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--image-path', type=str, default='')
    parser.add_argument('--rec-path', type=str, default='')
    parser.add_argument('--ckpt', type=str, default='')
    parser.add_argument('--resolution', type=int, default=336)
    parser.add_argument('--device', type=str, default='cuda')

    args = parser.parse_args()
    main(args)
```
examples/rec_imvi_vae.py
ADDED

```python
import math
import random
import argparse
from typing import Optional

import cv2
import numpy as np
import numpy.typing as npt
import torch
from PIL import Image
from decord import VideoReader, cpu
from torch.nn import functional as F
from pytorchvideo.transforms import ShortSideScale
from torchvision.transforms import Lambda, Compose

import sys
sys.path.append(".")
from opensora.dataset.transform import CenterCropVideo, resize
from opensora.models.ae.videobase import CausalVAEModel


def array_to_video(image_array: npt.NDArray, fps: float = 30.0, output_file: str = 'output_video.mp4') -> None:
    height, width, channels = image_array[0].shape
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    video_writer = cv2.VideoWriter(output_file, fourcc, float(fps), (width, height))

    for image in image_array:
        image_rgb = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
        video_writer.write(image_rgb)

    video_writer.release()

def custom_to_video(x: torch.Tensor, fps: float = 2.0, output_file: str = 'output_video.mp4') -> None:
    x = x.detach().cpu()
    x = torch.clamp(x, -1, 1)
    x = (x + 1) / 2
    x = x.permute(1, 2, 3, 0).numpy()
    x = (255 * x).astype(np.uint8)
    array_to_video(x, fps=fps, output_file=output_file)
    return

def read_video(video_path: str, num_frames: int, sample_rate: int) -> torch.Tensor:
    decord_vr = VideoReader(video_path, ctx=cpu(0))
    total_frames = len(decord_vr)
    sample_frames_len = sample_rate * num_frames

    if total_frames > sample_frames_len:
        s = random.randint(0, total_frames - sample_frames_len - 1)
        s = 0  # always start from the first frame (overrides the random start above)
        e = s + sample_frames_len
        num_frames = num_frames
    else:
        s = 0
        e = total_frames
        num_frames = int(total_frames / sample_frames_len * num_frames)
        print(f'sample_frames_len {sample_frames_len}, only can sample {num_frames * sample_rate}', video_path,
              total_frames)

    frame_id_list = np.linspace(s, e - 1, num_frames, dtype=int)
    video_data = decord_vr.get_batch(frame_id_list).asnumpy()
    video_data = torch.from_numpy(video_data)
    video_data = video_data.permute(3, 0, 1, 2)  # (T, H, W, C) -> (C, T, H, W)
    return video_data


class ResizeVideo:
    def __init__(
        self,
        size,
        interpolation_mode="bilinear",
    ):
        self.size = size
        self.interpolation_mode = interpolation_mode

    def __call__(self, clip):
        _, _, h, w = clip.shape
        if w < h:
            new_h = int(math.floor((float(h) / w) * self.size))
            new_w = self.size
        else:
            new_h = self.size
            new_w = int(math.floor((float(w) / h) * self.size))
        return torch.nn.functional.interpolate(
            clip, size=(new_h, new_w), mode=self.interpolation_mode, align_corners=False, antialias=True
        )

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode}"

def preprocess(video_data: torch.Tensor, short_size: int = 128, crop_size: Optional[int] = None) -> torch.Tensor:
    transform = Compose(
        [
            Lambda(lambda x: ((x / 255.0) * 2 - 1)),
            ResizeVideo(size=short_size),
            CenterCropVideo(crop_size) if crop_size is not None else Lambda(lambda x: x),
        ]
    )

    video_outputs = transform(video_data)
    video_outputs = torch.unsqueeze(video_outputs, 0)

    return video_outputs


def main(args: argparse.Namespace):
    video_path = args.video_path
    num_frames = args.num_frames
    resolution = args.resolution
    crop_size = args.crop_size
    sample_fps = args.sample_fps
    sample_rate = args.sample_rate
    device = args.device
    vqvae = CausalVAEModel.from_pretrained(args.ckpt)
    if args.enable_tiling:
        vqvae.enable_tiling()
        vqvae.tile_overlap_factor = args.tile_overlap_factor
    vqvae.eval()
    vqvae = vqvae.to(device)
    vqvae = vqvae.to(torch.float16)

    with torch.no_grad():
        x_vae = preprocess(read_video(video_path, num_frames, sample_rate), resolution, crop_size)
        x_vae = x_vae.to(device)  # b c t h w
        x_vae = x_vae.to(torch.float16)
        latents = vqvae.encode(x_vae).sample().to(torch.float16)
        video_recon = vqvae.decode(latents)

    if video_recon.shape[2] == 1:
        x = video_recon[0, :, 0, :, :]
        x = x.squeeze()
        x = x.detach().cpu().numpy()
        x = np.clip(x, -1, 1)
        x = (x + 1) / 2
        x = (255 * x).astype(np.uint8)
        x = x.transpose(1, 2, 0)
        image = Image.fromarray(x)
        image.save(args.rec_path.replace('mp4', 'jpg'))
    else:
        custom_to_video(video_recon[0], fps=sample_fps / sample_rate, output_file=args.rec_path)

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--video-path', type=str, default='')
    parser.add_argument('--rec-path', type=str, default='')
    parser.add_argument('--ckpt', type=str, default='results/pretrained')
    parser.add_argument('--sample-fps', type=int, default=30)
    parser.add_argument('--resolution', type=int, default=336)
    parser.add_argument('--crop-size', type=int, default=None)
    parser.add_argument('--num-frames', type=int, default=100)
    parser.add_argument('--sample-rate', type=int, default=1)
    parser.add_argument('--device', type=str, default="cuda")
    parser.add_argument('--tile_overlap_factor', type=float, default=0.25)
    parser.add_argument('--enable_tiling', action='store_true')

    args = parser.parse_args()
    main(args)
```
examples/rec_video.py
ADDED
@@ -0,0 +1,120 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import random
|
2 |
+
import argparse
|
3 |
+
from typing import Optional
|
4 |
+
|
5 |
+
import cv2
|
6 |
+
import imageio
|
7 |
+
import numpy as np
|
8 |
+
import numpy.typing as npt
|
9 |
+
import torch
|
10 |
+
from decord import VideoReader, cpu
|
11 |
+
from torch.nn import functional as F
|
12 |
+
from pytorchvideo.transforms import ShortSideScale
|
13 |
+
from torchvision.transforms import Lambda, Compose
|
14 |
+
from torchvision.transforms._transforms_video import RandomCropVideo
|
15 |
+
|
16 |
+
import sys
|
17 |
+
sys.path.append(".")
|
18 |
+
from opensora.models.ae import VQVAEModel
|
19 |
+
|
20 |
+
|
21 |
+
def array_to_video(image_array: npt.NDArray, fps: float = 30.0, output_file: str = 'output_video.mp4') -> None:
|
22 |
+
height, width, channels = image_array[0].shape
|
23 |
+
fourcc = cv2.VideoWriter_fourcc(*'mp4v') # type: ignore
|
24 |
+
video_writer = cv2.VideoWriter(output_file, fourcc, float(fps), (width, height))
|
25 |
+
|
26 |
+
for image in image_array:
|
27 |
+
image_rgb = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
|
28 |
+
video_writer.write(image_rgb)
|
29 |
+
|
30 |
+
video_writer.release()
|
31 |
+
|
32 |
+
def custom_to_video(x: torch.Tensor, fps: float = 2.0, output_file: str = 'output_video.mp4') -> None:
|
33 |
+
x = x.detach().cpu()
|
34 |
+
x = torch.clamp(x, -0.5, 0.5)
|
35 |
+
x = (x + 0.5)
|
36 |
+
x = x.permute(1, 2, 3, 0).numpy() # (C, T, H, W) -> (T, H, W, C)
|
37 |
+
x = (255*x).astype(np.uint8)
|
38 |
+
# array_to_video(x, fps=fps, output_file=output_file)
|
39 |
+
imageio.mimwrite(output_file, x, fps=fps, quality=9)
|
40 |
+
return
|
41 |
+
|
42 |
+
def read_video(video_path: str, num_frames: int, sample_rate: int) -> torch.Tensor:
|
43 |
+
decord_vr = VideoReader(video_path, ctx=cpu(0))
|
44 |
+
total_frames = len(decord_vr)
|
45 |
+
sample_frames_len = sample_rate * num_frames
|
46 |
+
|
47 |
+
if total_frames > sample_frames_len:
|
48 |
+
s = random.randint(0, total_frames - sample_frames_len - 1)
|
49 |
+
e = s + sample_frames_len
|
50 |
+
num_frames = num_frames
|
51 |
+
else:
|
52 |
+
s = 0
|
53 |
+
e = total_frames
|
54 |
+
num_frames = int(total_frames / sample_frames_len * num_frames)
|
55 |
+
print(f'sample_frames_len {sample_frames_len}, only can sample {num_frames * sample_rate}', video_path,
|
56 |
+
total_frames)
|
57 |
+
|
58 |
+
|
59 |
+
frame_id_list = np.linspace(s, e - 1, num_frames, dtype=int)
|
60 |
+
video_data = decord_vr.get_batch(frame_id_list).asnumpy()
|
61 |
+
video_data = torch.from_numpy(video_data)
|
62 |
+
video_data = video_data.permute(3, 0, 1, 2) # (T, H, W, C) -> (C, T, H, W)
|
63 |
+
return video_data
|
64 |
+
|
65 |
+
def preprocess(video_data: torch.Tensor, short_size: int = 128, crop_size: Optional[int] = None) -> torch.Tensor:
|
66 |
+
|
67 |
+
transform = Compose(
|
68 |
+
[
|
69 |
+
# UniformTemporalSubsample(num_frames),
|
70 |
+
Lambda(lambda x: ((x / 255.0) - 0.5)),
|
71 |
+
# NormalizeVideo(mean=OPENAI_DATASET_MEAN, std=OPENAI_DATASET_STD),
|
72 |
+
ShortSideScale(size=short_size),
|
73 |
+
RandomCropVideo(size=crop_size) if crop_size is not None else Lambda(lambda x: x),
|
74 |
+
# RandomHorizontalFlipVideo(p=0.5),
|
75 |
+
]
|
76 |
+
)
|
77 |
+
|
78 |
+
video_outputs = transform(video_data)
|
79 |
+
video_outputs = torch.unsqueeze(video_outputs, 0)
|
80 |
+
|
81 |
+
return video_outputs
|
82 |
+
|
83 |
+
|
84 |
+
def main(args: argparse.Namespace):
|
85 |
+
video_path = args.video_path
|
86 |
+
num_frames = args.num_frames
|
87 |
+
resolution = args.resolution
|
88 |
+
crop_size = args.crop_size
|
89 |
+
sample_fps = args.sample_fps
|
90 |
+
sample_rate = args.sample_rate
|
91 |
+
device = torch.device('cuda')
|
92 |
+
if args.ckpt in ['bair_stride4x2x2', 'ucf101_stride4x4x4', 'kinetics_stride4x4x4', 'kinetics_stride2x4x4']:
|
93 |
+
vqvae = VQVAEModel.download_and_load_model(args.ckpt)
|
94 |
+
else:
|
95 |
+
vqvae = VQVAEModel.load_from_checkpoint(args.ckpt)
|
96 |
+
vqvae.eval()
|
97 |
+
vqvae = vqvae.to(device)
|
98 |
+
|
99 |
+
with torch.no_grad():
|
100 |
+
x_vae = preprocess(read_video(video_path, num_frames, sample_rate), resolution, crop_size)
|
101 |
+
x_vae = x_vae.to(device)
|
102 |
+
encodings, embeddings = vqvae.encode(x_vae, include_embeddings=True)
|
103 |
+
video_recon = vqvae.decode(encodings)
|
104 |
+
|
105 |
+
# custom_to_video(x_vae[0], fps=sample_fps/sample_rate, output_file='origin_input.mp4')
|
106 |
+
custom_to_video(video_recon[0], fps=sample_fps/sample_rate, output_file=args.rec_path)
|
107 |
+
|
108 |
+
|
109 |
+
if __name__ == '__main__':
|
110 |
+
parser = argparse.ArgumentParser()
|
111 |
+
parser.add_argument('--video-path', type=str, default='')
|
112 |
+
parser.add_argument('--rec-path', type=str, default='')
|
113 |
+
parser.add_argument('--ckpt', type=str, default='ucf101_stride4x4x4')
|
114 |
+
parser.add_argument('--sample-fps', type=int, default=30)
|
115 |
+
parser.add_argument('--resolution', type=int, default=336)
|
116 |
+
parser.add_argument('--crop-size', type=int, default=None)
|
117 |
+
parser.add_argument('--num-frames', type=int, default=100)
|
118 |
+
parser.add_argument('--sample-rate', type=int, default=1)
|
119 |
+
args = parser.parse_args()
|
120 |
+
main(args)
|
examples/rec_video_ae.py
ADDED
@@ -0,0 +1,120 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import random
|
2 |
+
import argparse
|
3 |
+
from typing import Optional
|
4 |
+
|
5 |
+
import cv2
|
6 |
+
import imageio
|
7 |
+
import numpy as np
|
8 |
+
import numpy.typing as npt
|
9 |
+
import torch
|
10 |
+
from decord import VideoReader, cpu
|
11 |
+
from torch.nn import functional as F
|
12 |
+
from pytorchvideo.transforms import ShortSideScale
|
13 |
+
from torchvision.transforms import Lambda, Compose
|
14 |
+
from torchvision.transforms._transforms_video import RandomCropVideo
|
15 |
+
|
16 |
+
import sys
|
17 |
+
sys.path.append(".")
|
18 |
+
from opensora.models.ae import VQVAEModel
|
19 |
+
|
20 |
+
|
21 |
+
def array_to_video(image_array: npt.NDArray, fps: float = 30.0, output_file: str = 'output_video.mp4') -> None:
|
22 |
+
height, width, channels = image_array[0].shape
|
23 |
+
fourcc = cv2.VideoWriter_fourcc(*'mp4v') # type: ignore
|
24 |
+
video_writer = cv2.VideoWriter(output_file, fourcc, float(fps), (width, height))
|
25 |
+
|
26 |
+
for image in image_array:
|
27 |
+
image_rgb = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
|
28 |
+
video_writer.write(image_rgb)
|
29 |
+
|
30 |
+
video_writer.release()
|
31 |
+
|
32 |
+
def custom_to_video(x: torch.Tensor, fps: float = 2.0, output_file: str = 'output_video.mp4') -> None:
|
33 |
+
x = x.detach().cpu()
|
34 |
+
x = torch.clamp(x, -0.5, 0.5)
|
35 |
+
x = (x + 0.5)
|
36 |
+
x = x.permute(1, 2, 3, 0).numpy() # (C, T, H, W) -> (T, H, W, C)
|
37 |
+
x = (255*x).astype(np.uint8)
|
38 |
+
# array_to_video(x, fps=fps, output_file=output_file)
|
39 |
+
imageio.mimwrite(output_file, x, fps=fps, quality=9)
|
40 |
+
return
|
41 |
+
|
42 |
+
def read_video(video_path: str, num_frames: int, sample_rate: int) -> torch.Tensor:
|
43 |
+
decord_vr = VideoReader(video_path, ctx=cpu(0))
|
44 |
+
total_frames = len(decord_vr)
|
45 |
+
sample_frames_len = sample_rate * num_frames
|
46 |
+
|
47 |
+
if total_frames > sample_frames_len:
|
48 |
+
s = random.randint(0, total_frames - sample_frames_len - 1)
|
49 |
+
e = s + sample_frames_len
|
50 |
+
num_frames = num_frames
|
51 |
+
else:
|
52 |
+
s = 0
|
53 |
+
e = total_frames
|
54 |
+
num_frames = int(total_frames / sample_frames_len * num_frames)
|
55 |
+
print(f'sample_frames_len {sample_frames_len}, only can sample {num_frames * sample_rate}', video_path,
|
56 |
+
total_frames)
|
57 |
+
|
58 |
+
|
59 |
+
frame_id_list = np.linspace(s, e - 1, num_frames, dtype=int)
|
60 |
+
video_data = decord_vr.get_batch(frame_id_list).asnumpy()
|
61 |
+
video_data = torch.from_numpy(video_data)
|
62 |
+
video_data = video_data.permute(3, 0, 1, 2) # (T, H, W, C) -> (C, T, H, W)
|
63 |
+
return video_data
|
64 |
+
|
65 |
+
def preprocess(video_data: torch.Tensor, short_size: int = 128, crop_size: Optional[int] = None) -> torch.Tensor:
|
66 |
+
|
67 |
+
transform = Compose(
|
68 |
+
[
|
69 |
+
# UniformTemporalSubsample(num_frames),
|
70 |
+
Lambda(lambda x: ((x / 255.0) - 0.5)),
|
71 |
+
# NormalizeVideo(mean=OPENAI_DATASET_MEAN, std=OPENAI_DATASET_STD),
|
72 |
+
ShortSideScale(size=short_size),
|
73 |
+
RandomCropVideo(size=crop_size) if crop_size is not None else Lambda(lambda x: x),
|
74 |
+
# RandomHorizontalFlipVideo(p=0.5),
|
75 |
+
]
|
76 |
+
)
|
77 |
+
|
78 |
+
video_outputs = transform(video_data)
|
79 |
+
video_outputs = torch.unsqueeze(video_outputs, 0)
|
80 |
+
|
81 |
+
return video_outputs
|
82 |
+
|
83 |
+
|
84 |
+
def main(args: argparse.Namespace):
|
85 |
+
video_path = args.video_path
|
86 |
+
num_frames = args.num_frames
|
87 |
+
resolution = args.resolution
|
88 |
+
crop_size = args.crop_size
|
89 |
+
sample_fps = args.sample_fps
|
90 |
+
sample_rate = args.sample_rate
|
91 |
+
device = torch.device('cuda')
|
92 |
+
if args.ckpt in ['bair_stride4x2x2', 'ucf101_stride4x4x4', 'kinetics_stride4x4x4', 'kinetics_stride2x4x4']:
|
93 |
+
vqvae = VQVAEModel.download_and_load_model(args.ckpt)
|
94 |
+
else:
|
95 |
+
vqvae = VQVAEModel.load_from_checkpoint(args.ckpt)
|
96 |
+
vqvae.eval()
|
97 |
+
vqvae = vqvae.to(device)
|
98 |
+
|
99 |
+
with torch.no_grad():
|
100 |
+
x_vae = preprocess(read_video(video_path, num_frames, sample_rate), resolution, crop_size)
|
101 |
+
x_vae = x_vae.to(device)
|
102 |
+
encodings, embeddings = vqvae.encode(x_vae, include_embeddings=True)
|
103 |
+
video_recon = vqvae.decode(encodings)
|
104 |
+
|
105 |
+
# custom_to_video(x_vae[0], fps=sample_fps/sample_rate, output_file='origin_input.mp4')
|
106 |
+
custom_to_video(video_recon[0], fps=sample_fps/sample_rate, output_file=args.rec_path)
|
107 |
+
|
108 |
+
|
109 |
+
if __name__ == '__main__':
|
110 |
+
parser = argparse.ArgumentParser()
|
111 |
+
parser.add_argument('--video-path', type=str, default='')
|
112 |
+
parser.add_argument('--rec-path', type=str, default='')
|
113 |
+
parser.add_argument('--ckpt', type=str, default='ucf101_stride4x4x4')
|
114 |
+
parser.add_argument('--sample-fps', type=int, default=30)
|
115 |
+
parser.add_argument('--resolution', type=int, default=336)
|
116 |
+
parser.add_argument('--crop-size', type=int, default=None)
|
117 |
+
parser.add_argument('--num-frames', type=int, default=100)
|
118 |
+
parser.add_argument('--sample-rate', type=int, default=1)
|
119 |
+
args = parser.parse_args()
|
120 |
+
main(args)
|
examples/rec_video_vae.py
ADDED
@@ -0,0 +1,275 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import random
|
2 |
+
import argparse
|
3 |
+
import cv2
|
4 |
+
from tqdm import tqdm
|
5 |
+
import numpy as np
|
6 |
+
import numpy.typing as npt
|
7 |
+
import torch
|
8 |
+
from decord import VideoReader, cpu
|
9 |
+
from torch.nn import functional as F
|
10 |
+
from pytorchvideo.transforms import ShortSideScale
|
11 |
+
from torchvision.transforms import Lambda, Compose
|
12 |
+
from torchvision.transforms._transforms_video import CenterCropVideo
|
13 |
+
import sys
|
14 |
+
from torch.utils.data import Dataset, DataLoader, Subset
|
15 |
+
import os
|
16 |
+
|
17 |
+
sys.path.append(".")
|
18 |
+
from opensora.models.ae.videobase import CausalVAEModel
|
19 |
+
import torch.nn as nn
|
20 |
+
|
21 |
+
|
22 |
+
def array_to_video(
|
23 |
+
image_array: npt.NDArray, fps: float = 30.0, output_file: str = "output_video.mp4"
|
24 |
+
) -> None:
|
25 |
+
height, width, channels = image_array[0].shape
|
26 |
+
fourcc = cv2.VideoWriter_fourcc(*"mp4v")
|
27 |
+
video_writer = cv2.VideoWriter(output_file, fourcc, float(fps), (width, height))
|
28 |
+
|
29 |
+
for image in image_array:
|
30 |
+
image_rgb = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
|
31 |
+
video_writer.write(image_rgb)
|
32 |
+
|
33 |
+
video_writer.release()
|
34 |
+
|
35 |
+
|
36 |
+
def custom_to_video(
|
37 |
+
x: torch.Tensor, fps: float = 2.0, output_file: str = "output_video.mp4"
|
38 |
+
) -> None:
|
39 |
+
x = x.detach().cpu()
|
40 |
+
x = torch.clamp(x, -1, 1)
|
41 |
+
x = (x + 1) / 2
|
42 |
+
x = x.permute(1, 2, 3, 0).float().numpy()
|
43 |
+
x = (255 * x).astype(np.uint8)
|
44 |
+
array_to_video(x, fps=fps, output_file=output_file)
|
45 |
+
return
|
46 |
+
|
47 |
+
|
48 |
+
def read_video(video_path: str, num_frames: int, sample_rate: int) -> torch.Tensor:
|
49 |
+
decord_vr = VideoReader(video_path, ctx=cpu(0), num_threads=8)
|
50 |
+
total_frames = len(decord_vr)
|
51 |
+
sample_frames_len = sample_rate * num_frames
|
52 |
+
|
53 |
+
if total_frames > sample_frames_len:
|
54 |
+
s = 0
|
55 |
+
e = s + sample_frames_len
|
56 |
+
num_frames = num_frames
|
57 |
+
else:
|
58 |
+
s = 0
|
59 |
+
e = total_frames
|
60 |
+
num_frames = int(total_frames / sample_frames_len * num_frames)
|
61 |
+
print(
|
62 |
+
f"sample_frames_len {sample_frames_len}, only can sample {num_frames * sample_rate}",
|
63 |
+
video_path,
|
64 |
+
total_frames,
|
65 |
+
)
|
66 |
+
|
67 |
+
frame_id_list = np.linspace(s, e - 1, num_frames, dtype=int)
|
68 |
+
video_data = decord_vr.get_batch(frame_id_list).asnumpy()
|
69 |
+
video_data = torch.from_numpy(video_data)
|
70 |
+
video_data = video_data.permute(3, 0, 1, 2) # (T, H, W, C) -> (C, T, H, W)
|
71 |
+
return video_data
|
72 |
+
|
73 |
+
|
74 |
+
class RealVideoDataset(Dataset):
|
75 |
+
def __init__(
|
76 |
+
self,
|
77 |
+
real_video_dir,
|
78 |
+
num_frames,
|
79 |
+
sample_rate=1,
|
80 |
+
crop_size=None,
|
81 |
+
resolution=128,
|
82 |
+
) -> None:
|
83 |
+
super().__init__()
|
84 |
+
self.real_video_files = self._combine_without_prefix(real_video_dir)
|
85 |
+
self.num_frames = num_frames
|
86 |
+
self.sample_rate = sample_rate
|
87 |
+
self.crop_size = crop_size
|
88 |
+
self.short_size = resolution
|
89 |
+
|
90 |
+
def __len__(self):
|
91 |
+
return len(self.real_video_files)
|
92 |
+
|
93 |
+
def __getitem__(self, index):
|
94 |
+
if index >= len(self):
|
95 |
+
raise IndexError
|
96 |
+
real_video_file = self.real_video_files[index]
|
97 |
+
real_video_tensor = self._load_video(real_video_file)
|
98 |
+
video_name = os.path.basename(real_video_file)
|
99 |
+
return {'video': real_video_tensor, 'file_name': video_name }
|
100 |
+
|
101 |
+
def _load_video(self, video_path):
|
102 |
+
num_frames = self.num_frames
|
103 |
+
sample_rate = self.sample_rate
|
104 |
+
decord_vr = VideoReader(video_path, ctx=cpu(0))
|
105 |
+
total_frames = len(decord_vr)
|
106 |
+
sample_frames_len = sample_rate * num_frames
|
107 |
+
|
108 |
+
if total_frames > sample_frames_len:
|
109 |
+
s = 0
|
110 |
+
e = s + sample_frames_len
|
111 |
+
num_frames = num_frames
|
112 |
+
else:
|
113 |
+
s = 0
|
114 |
+
e = total_frames
|
115 |
+
num_frames = int(total_frames / sample_frames_len * num_frames)
|
116 |
+
print(
|
117 |
+
f"sample_frames_len {sample_frames_len}, only can sample {num_frames * sample_rate}",
|
118 |
+
video_path,
|
119 |
+
total_frames,
|
120 |
+
)
|
121 |
+
|
122 |
+
frame_id_list = np.linspace(s, e - 1, num_frames, dtype=int)
|
123 |
+
video_data = decord_vr.get_batch(frame_id_list).asnumpy()
|
124 |
+
video_data = torch.from_numpy(video_data)
|
125 |
+
video_data = video_data.permute(3, 0, 1, 2)
|
126 |
+
return _preprocess(
|
127 |
+
video_data, short_size=self.short_size, crop_size=self.crop_size
|
128 |
+
)
|
129 |
+
|
130 |
+
def _combine_without_prefix(self, folder_path, prefix="."):
|
131 |
+
folder = []
|
132 |
+
for name in os.listdir(folder_path):
|
133 |
+
if name[0] == prefix:
|
134 |
+
continue
|
135 |
+
folder.append(os.path.join(folder_path, name))
|
136 |
+
folder.sort()
|
137 |
+
return folder
|
138 |
+
|
139 |
+
def resize(x, resolution):
|
140 |
+
height, width = x.shape[-2:]
|
141 |
+
aspect_ratio = width / height
|
142 |
+
if width <= height:
|
143 |
+
new_width = resolution
|
144 |
+
new_height = int(resolution / aspect_ratio)
|
145 |
+
else:
|
146 |
+
new_height = resolution
|
147 |
+
new_width = int(resolution * aspect_ratio)
|
148 |
+
resized_x = F.interpolate(x, size=(new_height, new_width), mode='bilinear', align_corners=True, antialias=True)
|
149 |
+
return resized_x
|
150 |
+
|
151 |
+
def _preprocess(video_data, short_size=128, crop_size=None):
|
152 |
+
transform = Compose(
|
153 |
+
[
|
154 |
+
Lambda(lambda x: ((x / 255.0) * 2 - 1)),
|
155 |
+
Lambda(lambda x: resize(x, short_size)),
|
156 |
+
(
|
157 |
+
CenterCropVideo(crop_size=crop_size)
|
158 |
+
if crop_size is not None
|
159 |
+
else Lambda(lambda x: x)
|
160 |
+
),
|
161 |
+
]
|
162 |
+
)
|
163 |
+
video_outputs = transform(video_data)
|
164 |
+
video_outputs = _format_video_shape(video_outputs)
|
165 |
+
return video_outputs
|
166 |
+
|
167 |
+
|
168 |
+
def _format_video_shape(video, time_compress=4, spatial_compress=8):
|
169 |
+
time = video.shape[1]
|
170 |
+
height = video.shape[2]
|
171 |
+
width = video.shape[3]
|
172 |
+
new_time = (
|
173 |
+
(time - (time - 1) % time_compress)
|
174 |
+
if (time - 1) % time_compress != 0
|
175 |
+
else time
|
176 |
+
)
|
177 |
+
new_height = (
|
178 |
+
(height - (height) % spatial_compress)
|
179 |
+
if height % spatial_compress != 0
|
180 |
+
else height
|
181 |
+
)
|
182 |
+
new_width = (
|
183 |
+
(width - (width) % spatial_compress) if width % spatial_compress != 0 else width
|
184 |
+
)
|
185 |
+
return video[:, :new_time, :new_height, :new_width]
|
186 |
+
|
187 |
+
|
188 |
+
@torch.no_grad()
|
189 |
+
def main(args: argparse.Namespace):
|
190 |
+
real_video_dir = args.real_video_dir
|
191 |
+
generated_video_dir = args.generated_video_dir
|
192 |
+
ckpt = args.ckpt
|
193 |
+
sample_rate = args.sample_rate
|
194 |
+
resolution = args.resolution
|
195 |
+
crop_size = args.crop_size
|
196 |
+
num_frames = args.num_frames
|
197 |
+
sample_rate = args.sample_rate
|
198 |
+
device = args.device
|
199 |
+
sample_fps = args.sample_fps
|
200 |
+
batch_size = args.batch_size
|
201 |
+
num_workers = args.num_workers
|
202 |
+
subset_size = args.subset_size
|
203 |
+
|
204 |
+
if not os.path.exists(args.generated_video_dir):
|
205 |
+
os.makedirs(args.generated_video_dir, exist_ok=True)
|
206 |
+
|
207 |
+
data_type = torch.bfloat16
|
208 |
+
|
209 |
+
# ---- Load Model ----
|
210 |
+
device = args.device
|
211 |
+
vqvae = CausalVAEModel.from_pretrained(args.ckpt)
|
212 |
+
vqvae = vqvae.to(device).to(data_type)
|
213 |
+
if args.enable_tiling:
|
214 |
+
vqvae.enable_tiling()
|
215 |
+
vqvae.tile_overlap_factor = args.tile_overlap_factor
|
216 |
+
# ---- Load Model ----
|
217 |
+
|
218 |
+
# ---- Prepare Dataset ----
|
219 |
+
dataset = RealVideoDataset(
|
220 |
+
real_video_dir=real_video_dir,
|
221 |
+
num_frames=num_frames,
|
222 |
+
sample_rate=sample_rate,
|
223 |
+
crop_size=crop_size,
|
224 |
+
resolution=resolution,
|
225 |
+
)
|
226 |
+
|
227 |
+
if subset_size:
|
228 |
+
indices = range(subset_size)
|
229 |
+
dataset = Subset(dataset, indices=indices)
|
230 |
+
|
231 |
+
dataloader = DataLoader(
|
232 |
+
dataset, batch_size=batch_size, pin_memory=True, num_workers=num_workers
|
233 |
+
)
|
234 |
+
# ---- Prepare Dataset
|
235 |
+
|
236 |
+
# ---- Inference ----
|
237 |
+
for batch in tqdm(dataloader):
|
238 |
+
x, file_names = batch['video'], batch['file_name']
|
239 |
+
x = x.to(device=device, dtype=data_type) # b c t h w
|
240 |
+
latents = vqvae.encode(x).sample().to(data_type)
|
241 |
+
video_recon = vqvae.decode(latents)
|
242 |
+
for idx, video in enumerate(video_recon):
|
243 |
+
output_path = os.path.join(generated_video_dir, file_names[idx])
|
244 |
+
if args.output_origin:
|
245 |
+
os.makedirs(os.path.join(generated_video_dir, "origin/"), exist_ok=True)
|
246 |
+
origin_output_path = os.path.join(generated_video_dir, "origin/", file_names[idx])
|
247 |
+
custom_to_video(
|
248 |
+
x[idx], fps=sample_fps / sample_rate, output_file=origin_output_path
|
249 |
+
)
|
250 |
+
custom_to_video(
|
251 |
+
video, fps=sample_fps / sample_rate, output_file=output_path
|
252 |
+
)
|
253 |
+
# ---- Inference ----
|
254 |
+
|
255 |
+
if __name__ == "__main__":
|
256 |
+
parser = argparse.ArgumentParser()
|
257 |
+
parser.add_argument("--real_video_dir", type=str, default="")
|
258 |
+
parser.add_argument("--generated_video_dir", type=str, default="")
|
259 |
+
parser.add_argument("--ckpt", type=str, default="")
|
260 |
+
parser.add_argument("--sample_fps", type=int, default=30)
|
261 |
+
parser.add_argument("--resolution", type=int, default=336)
|
262 |
+
parser.add_argument("--crop_size", type=int, default=None)
|
263 |
+
parser.add_argument("--num_frames", type=int, default=17)
|
264 |
+
parser.add_argument("--sample_rate", type=int, default=1)
|
265 |
+
parser.add_argument("--batch_size", type=int, default=1)
|
266 |
+
parser.add_argument("--num_workers", type=int, default=8)
|
267 |
+
parser.add_argument("--subset_size", type=int, default=None)
|
268 |
+
parser.add_argument("--tile_overlap_factor", type=float, default=0.25)
|
269 |
+
parser.add_argument('--enable_tiling', action='store_true')
|
270 |
+
parser.add_argument('--output_origin', action='store_true')
|
271 |
+
parser.add_argument("--device", type=str, default="cuda")
|
272 |
+
|
273 |
+
args = parser.parse_args()
|
274 |
+
main(args)
|
275 |
+
|
opensora/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
#
|
opensora/dataset/__init__.py
ADDED
@@ -0,0 +1,99 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from torchvision.transforms import Compose
|
2 |
+
from transformers import AutoTokenizer
|
3 |
+
|
4 |
+
from .feature_datasets import T2V_Feature_dataset, T2V_T5_Feature_dataset
|
5 |
+
from torchvision import transforms
|
6 |
+
from torchvision.transforms import Lambda
|
7 |
+
|
8 |
+
from .landscope import Landscope
|
9 |
+
from .t2v_datasets import T2V_dataset
|
10 |
+
from .transform import ToTensorVideo, TemporalRandomCrop, RandomHorizontalFlipVideo, CenterCropResizeVideo
|
11 |
+
from .ucf101 import UCF101
|
12 |
+
from .sky_datasets import Sky
|
13 |
+
|
14 |
+
ae_norm = {
|
15 |
+
'CausalVAEModel_4x8x8': Lambda(lambda x: 2. * x - 1.),
|
16 |
+
'CausalVQVAEModel_4x4x4': Lambda(lambda x: x - 0.5),
|
17 |
+
'CausalVQVAEModel_4x8x8': Lambda(lambda x: x - 0.5),
|
18 |
+
'VQVAEModel_4x4x4': Lambda(lambda x: x - 0.5),
|
19 |
+
'VQVAEModel_4x8x8': Lambda(lambda x: x - 0.5),
|
20 |
+
"bair_stride4x2x2": Lambda(lambda x: x - 0.5),
|
21 |
+
"ucf101_stride4x4x4": Lambda(lambda x: x - 0.5),
|
22 |
+
"kinetics_stride4x4x4": Lambda(lambda x: x - 0.5),
|
23 |
+
"kinetics_stride2x4x4": Lambda(lambda x: x - 0.5),
|
24 |
+
'stabilityai/sd-vae-ft-mse': transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
|
25 |
+
'stabilityai/sd-vae-ft-ema': transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
|
26 |
+
'vqgan_imagenet_f16_1024': Lambda(lambda x: 2. * x - 1.),
|
27 |
+
'vqgan_imagenet_f16_16384': Lambda(lambda x: 2. * x - 1.),
|
28 |
+
'vqgan_gumbel_f8': Lambda(lambda x: 2. * x - 1.),
|
29 |
+
|
30 |
+
}
|
31 |
+
ae_denorm = {
|
32 |
+
'CausalVAEModel_4x8x8': lambda x: (x + 1.) / 2.,
|
33 |
+
'CausalVQVAEModel_4x4x4': lambda x: x + 0.5,
|
34 |
+
'CausalVQVAEModel_4x8x8': lambda x: x + 0.5,
|
35 |
+
'VQVAEModel_4x4x4': lambda x: x + 0.5,
|
36 |
+
'VQVAEModel_4x8x8': lambda x: x + 0.5,
|
37 |
+
"bair_stride4x2x2": lambda x: x + 0.5,
|
38 |
+
"ucf101_stride4x4x4": lambda x: x + 0.5,
|
39 |
+
"kinetics_stride4x4x4": lambda x: x + 0.5,
|
40 |
+
"kinetics_stride2x4x4": lambda x: x + 0.5,
|
41 |
+
'stabilityai/sd-vae-ft-mse': lambda x: 0.5 * x + 0.5,
|
42 |
+
'stabilityai/sd-vae-ft-ema': lambda x: 0.5 * x + 0.5,
|
43 |
+
'vqgan_imagenet_f16_1024': lambda x: (x + 1.) / 2.,
|
44 |
+
'vqgan_imagenet_f16_16384': lambda x: (x + 1.) / 2.,
|
45 |
+
'vqgan_gumbel_f8': lambda x: (x + 1.) / 2.,
|
46 |
+
}
|
47 |
+
|
48 |
+
def getdataset(args):
|
49 |
+
temporal_sample = TemporalRandomCrop(args.num_frames * args.sample_rate) # 16 x
|
50 |
+
norm_fun = ae_norm[args.ae]
|
51 |
+
if args.dataset == 'ucf101':
|
52 |
+
transform = Compose(
|
53 |
+
[
|
54 |
+
ToTensorVideo(), # TCHW
|
55 |
+
CenterCropResizeVideo(size=args.max_image_size),
|
56 |
+
RandomHorizontalFlipVideo(p=0.5),
|
57 |
+
norm_fun,
|
58 |
+
]
|
59 |
+
)
|
60 |
+
return UCF101(args, transform=transform, temporal_sample=temporal_sample)
|
61 |
+
if args.dataset == 'landscope':
|
62 |
+
transform = Compose(
|
63 |
+
[
|
64 |
+
ToTensorVideo(), # TCHW
|
65 |
+
CenterCropResizeVideo(size=args.max_image_size),
|
66 |
+
RandomHorizontalFlipVideo(p=0.5),
|
67 |
+
norm_fun,
|
68 |
+
]
|
69 |
+
)
|
70 |
+
return Landscope(args, transform=transform, temporal_sample=temporal_sample)
|
71 |
+
elif args.dataset == 'sky':
|
72 |
+
transform = transforms.Compose([
|
73 |
+
ToTensorVideo(),
|
74 |
+
CenterCropResizeVideo(args.max_image_size),
|
75 |
+
RandomHorizontalFlipVideo(p=0.5),
|
76 |
+
norm_fun
|
77 |
+
])
|
78 |
+
return Sky(args, transform=transform, temporal_sample=temporal_sample)
|
79 |
+
elif args.dataset == 't2v':
|
80 |
+
transform = transforms.Compose([
|
81 |
+
ToTensorVideo(),
|
82 |
+
CenterCropResizeVideo(args.max_image_size),
|
83 |
+
RandomHorizontalFlipVideo(p=0.5),
|
84 |
+
norm_fun
|
85 |
+
])
|
86 |
+
tokenizer = AutoTokenizer.from_pretrained(args.text_encoder_name, cache_dir='./cache_dir')
|
87 |
+
return T2V_dataset(args, transform=transform, temporal_sample=temporal_sample, tokenizer=tokenizer)
|
88 |
+
elif args.dataset == 't2v_feature':
|
89 |
+
return T2V_Feature_dataset(args, temporal_sample)
|
90 |
+
elif args.dataset == 't2v_t5_feature':
|
91 |
+
transform = transforms.Compose([
|
92 |
+
ToTensorVideo(),
|
93 |
+
CenterCropResizeVideo(args.max_image_size),
|
94 |
+
RandomHorizontalFlipVideo(p=0.5),
|
95 |
+
norm_fun
|
96 |
+
])
|
97 |
+
return T2V_T5_Feature_dataset(args, transform, temporal_sample)
|
98 |
+
else:
|
99 |
+
raise NotImplementedError(args.dataset)
|
opensora/dataset/extract_feature_dataset.py
ADDED
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from glob import glob
|
3 |
+
|
4 |
+
import numpy as np
|
5 |
+
import torch
|
6 |
+
import torchvision
|
7 |
+
from PIL import Image
|
8 |
+
from torch.utils.data import Dataset
|
9 |
+
|
10 |
+
from opensora.utils.dataset_utils import DecordInit, is_image_file
|
11 |
+
|
12 |
+
|
13 |
+
class ExtractVideo2Feature(Dataset):
|
14 |
+
def __init__(self, args, transform):
|
15 |
+
self.data_path = args.data_path
|
16 |
+
self.transform = transform
|
17 |
+
self.v_decoder = DecordInit()
|
18 |
+
self.samples = list(glob(f'{self.data_path}'))
|
19 |
+
|
20 |
+
def __len__(self):
|
21 |
+
return len(self.samples)
|
22 |
+
|
23 |
+
def __getitem__(self, idx):
|
24 |
+
video_path = self.samples[idx]
|
25 |
+
video = self.decord_read(video_path)
|
26 |
+
video = self.transform(video) # T C H W -> T C H W
|
27 |
+
return video, video_path
|
28 |
+
|
29 |
+
def tv_read(self, path):
|
30 |
+
vframes, aframes, info = torchvision.io.read_video(filename=path, pts_unit='sec', output_format='TCHW')
|
31 |
+
total_frames = len(vframes)
|
32 |
+
frame_indice = list(range(total_frames))
|
33 |
+
video = vframes[frame_indice]
|
34 |
+
return video
|
35 |
+
|
36 |
+
def decord_read(self, path):
|
37 |
+
decord_vr = self.v_decoder(path)
|
38 |
+
total_frames = len(decord_vr)
|
39 |
+
frame_indice = list(range(total_frames))
|
40 |
+
video_data = decord_vr.get_batch(frame_indice).asnumpy()
|
41 |
+
video_data = torch.from_numpy(video_data)
|
42 |
+
video_data = video_data.permute(0, 3, 1, 2) # (T, H, W, C) -> (T C H W)
|
43 |
+
return video_data
|
44 |
+
|
45 |
+
|
46 |
+
|
47 |
+
class ExtractImage2Feature(Dataset):
|
48 |
+
def __init__(self, args, transform):
|
49 |
+
self.data_path = args.data_path
|
50 |
+
self.transform = transform
|
51 |
+
self.data_all = list(glob(f'{self.data_path}'))
|
52 |
+
|
53 |
+
def __len__(self):
|
54 |
+
return len(self.data_all)
|
55 |
+
|
56 |
+
def __getitem__(self, index):
|
57 |
+
path = self.data_all[index]
|
58 |
+
video_frame = torch.as_tensor(np.array(Image.open(path), dtype=np.uint8, copy=True)).unsqueeze(0)
|
59 |
+
video_frame = video_frame.permute(0, 3, 1, 2)
|
60 |
+
video_frame = self.transform(video_frame) # T C H W
|
61 |
+
# video_frame = video_frame.transpose(0, 1) # T C H W -> C T H W
|
62 |
+
|
63 |
+
return video_frame, path
|
64 |
+
|
opensora/dataset/feature_datasets.py
ADDED
@@ -0,0 +1,213 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import os
|
3 |
+
import torch
|
4 |
+
import random
|
5 |
+
import torch.utils.data as data
|
6 |
+
|
7 |
+
import numpy as np
|
8 |
+
from glob import glob
|
9 |
+
from PIL import Image
|
10 |
+
from torch.utils.data import Dataset
|
11 |
+
from tqdm import tqdm
|
12 |
+
|
13 |
+
from opensora.dataset.transform import center_crop, RandomCropVideo
|
14 |
+
from opensora.utils.dataset_utils import DecordInit
|
15 |
+
|
16 |
+
|
17 |
+
class T2V_Feature_dataset(Dataset):
|
18 |
+
def __init__(self, args, temporal_sample):
|
19 |
+
|
20 |
+
self.video_folder = args.video_folder
|
21 |
+
self.num_frames = args.video_length
|
22 |
+
self.temporal_sample = temporal_sample
|
23 |
+
|
24 |
+
print('Building dataset...')
|
25 |
+
if os.path.exists('samples_430k.json'):
|
26 |
+
with open('samples_430k.json', 'r') as f:
|
27 |
+
self.samples = json.load(f)
|
28 |
+
else:
|
29 |
+
self.samples = self._make_dataset()
|
30 |
+
with open('samples_430k.json', 'w') as f:
|
31 |
+
json.dump(self.samples, f, indent=2)
|
32 |
+
|
33 |
+
self.use_image_num = args.use_image_num
|
34 |
+
self.use_img_from_vid = args.use_img_from_vid
|
35 |
+
if self.use_image_num != 0 and not self.use_img_from_vid:
|
36 |
+
self.img_cap_list = self.get_img_cap_list()
|
37 |
+
|
38 |
+
def _make_dataset(self):
|
39 |
+
all_mp4 = list(glob(os.path.join(self.video_folder, '**', '*.mp4'), recursive=True))
|
40 |
+
# all_mp4 = all_mp4[:1000]
|
41 |
+
samples = []
|
42 |
+
for i in tqdm(all_mp4):
|
43 |
+
video_id = os.path.basename(i).split('.')[0]
|
44 |
+
ae = os.path.split(i)[0].replace('data_split_tt', 'lb_causalvideovae444_feature')
|
45 |
+
ae = os.path.join(ae, f'{video_id}_causalvideovae444.npy')
|
46 |
+
if not os.path.exists(ae):
|
47 |
+
continue
|
48 |
+
|
49 |
+
t5 = os.path.split(i)[0].replace('data_split_tt', 'lb_t5_feature')
|
50 |
+
cond_list = []
|
51 |
+
cond_llava = os.path.join(t5, f'{video_id}_t5_llava_fea.npy')
|
52 |
+
mask_llava = os.path.join(t5, f'{video_id}_t5_llava_mask.npy')
|
53 |
+
if os.path.exists(cond_llava) and os.path.exists(mask_llava):
|
54 |
+
llava = dict(cond=cond_llava, mask=mask_llava)
|
55 |
+
cond_list.append(llava)
|
56 |
+
cond_sharegpt4v = os.path.join(t5, f'{video_id}_t5_sharegpt4v_fea.npy')
|
57 |
+
mask_sharegpt4v = os.path.join(t5, f'{video_id}_t5_sharegpt4v_mask.npy')
|
58 |
+
if os.path.exists(cond_sharegpt4v) and os.path.exists(mask_sharegpt4v):
|
59 |
+
sharegpt4v = dict(cond=cond_sharegpt4v, mask=mask_sharegpt4v)
|
60 |
+
cond_list.append(sharegpt4v)
|
61 |
+
if len(cond_list) > 0:
|
62 |
+
sample = dict(ae=ae, t5=cond_list)
|
63 |
+
samples.append(sample)
|
64 |
+
return samples
|
65 |
+
|
66 |
+
def __len__(self):
|
67 |
+
return len(self.samples)
|
68 |
+
|
69 |
+
def __getitem__(self, idx):
|
70 |
+
# try:
|
71 |
+
sample = self.samples[idx]
|
72 |
+
ae, t5 = sample['ae'], sample['t5']
|
73 |
+
t5 = random.choice(t5)
|
74 |
+
video_origin = np.load(ae)[0] # C T H W
|
75 |
+
_, total_frames, _, _ = video_origin.shape
|
76 |
+
# Sampling video frames
|
77 |
+
start_frame_ind, end_frame_ind = self.temporal_sample(total_frames)
|
78 |
+
assert end_frame_ind - start_frame_ind >= self.num_frames
|
79 |
+
select_video_idx = np.linspace(start_frame_ind, end_frame_ind - 1, num=self.num_frames, dtype=int) # start, stop, num=50
|
80 |
+
# print('select_video_idx', total_frames, select_video_idx)
|
81 |
+
video = video_origin[:, select_video_idx] # C num_frames H W
|
82 |
+
video = torch.from_numpy(video)
|
83 |
+
|
84 |
+
cond = torch.from_numpy(np.load(t5['cond']))[0] # L
|
85 |
+
cond_mask = torch.from_numpy(np.load(t5['mask']))[0] # L D
|
86 |
+
|
87 |
+
if self.use_image_num != 0 and self.use_img_from_vid:
|
88 |
+
select_image_idx = np.random.randint(0, total_frames, self.use_image_num)
|
89 |
+
# print('select_image_idx', total_frames, self.use_image_num, select_image_idx)
|
90 |
+
images = video_origin[:, select_image_idx] # c, num_img, h, w
|
91 |
+
images = torch.from_numpy(images)
|
92 |
+
video = torch.cat([video, images], dim=1) # c, num_frame+num_img, h, w
|
93 |
+
cond = torch.stack([cond] * (1+self.use_image_num)) # 1+self.use_image_num, l
|
94 |
+
cond_mask = torch.stack([cond_mask] * (1+self.use_image_num)) # 1+self.use_image_num, l
|
95 |
+
elif self.use_image_num != 0 and not self.use_img_from_vid:
|
96 |
+
images, captions = self.img_cap_list[idx]
|
97 |
+
raise NotImplementedError
|
98 |
+
else:
|
99 |
+
pass
|
100 |
+
|
101 |
+
return video, cond, cond_mask
|
102 |
+
# except Exception as e:
|
103 |
+
# print(f'Error with {e}, {sample}')
|
104 |
+
# return self.__getitem__(random.randint(0, self.__len__() - 1))
|
105 |
+
|
106 |
+
def get_img_cap_list(self):
|
107 |
+
raise NotImplementedError
|
108 |
+
|
109 |
+
|
110 |
+
|
111 |
+
|
112 |
+
class T2V_T5_Feature_dataset(Dataset):
|
113 |
+
def __init__(self, args, transform, temporal_sample):
|
114 |
+
|
115 |
+
self.video_folder = args.video_folder
|
116 |
+
self.num_frames = args.num_frames
|
117 |
+
self.transform = transform
|
118 |
+
self.temporal_sample = temporal_sample
|
119 |
+
self.v_decoder = DecordInit()
|
120 |
+
|
121 |
+
print('Building dataset...')
|
122 |
+
if os.path.exists('samples_430k.json'):
|
123 |
+
with open('samples_430k.json', 'r') as f:
|
124 |
+
self.samples = json.load(f)
|
125 |
+
self.samples = [dict(ae=i['ae'].replace('lb_causalvideovae444_feature', 'data_split_1024').replace('_causalvideovae444.npy', '.mp4'), t5=i['t5']) for i in self.samples]
|
126 |
+
else:
|
127 |
+
self.samples = self._make_dataset()
|
128 |
+
with open('samples_430k.json', 'w') as f:
|
129 |
+
json.dump(self.samples, f, indent=2)
|
130 |
+
|
131 |
+
self.use_image_num = args.use_image_num
|
132 |
+
self.use_img_from_vid = args.use_img_from_vid
|
133 |
+
if self.use_image_num != 0 and not self.use_img_from_vid:
|
134 |
+
self.img_cap_list = self.get_img_cap_list()
|
135 |
+
|
136 |
+
def _make_dataset(self):
|
137 |
+
all_mp4 = list(glob(os.path.join(self.video_folder, '**', '*.mp4'), recursive=True))
|
138 |
+
# all_mp4 = all_mp4[:1000]
|
139 |
+
samples = []
|
140 |
+
for i in tqdm(all_mp4):
|
141 |
+
video_id = os.path.basename(i).split('.')[0]
|
142 |
+
# ae = os.path.split(i)[0].replace('data_split', 'lb_causalvideovae444_feature')
|
143 |
+
# ae = os.path.join(ae, f'{video_id}_causalvideovae444.npy')
|
144 |
+
ae = i
|
145 |
+
if not os.path.exists(ae):
|
146 |
+
continue
|
147 |
+
|
148 |
+
t5 = os.path.split(i)[0].replace('data_split_1024', 'lb_t5_feature')
|
149 |
+
cond_list = []
|
150 |
+
cond_llava = os.path.join(t5, f'{video_id}_t5_llava_fea.npy')
|
151 |
+
mask_llava = os.path.join(t5, f'{video_id}_t5_llava_mask.npy')
|
152 |
+
if os.path.exists(cond_llava) and os.path.exists(mask_llava):
|
153 |
+
llava = dict(cond=cond_llava, mask=mask_llava)
|
154 |
+
cond_list.append(llava)
|
155 |
+
cond_sharegpt4v = os.path.join(t5, f'{video_id}_t5_sharegpt4v_fea.npy')
|
156 |
+
mask_sharegpt4v = os.path.join(t5, f'{video_id}_t5_sharegpt4v_mask.npy')
|
157 |
+
if os.path.exists(cond_sharegpt4v) and os.path.exists(mask_sharegpt4v):
|
158 |
+
sharegpt4v = dict(cond=cond_sharegpt4v, mask=mask_sharegpt4v)
|
159 |
+
cond_list.append(sharegpt4v)
|
160 |
+
if len(cond_list) > 0:
|
161 |
+
sample = dict(ae=ae, t5=cond_list)
|
162 |
+
samples.append(sample)
|
163 |
+
return samples
|
164 |
+
|
165 |
+
def __len__(self):
|
166 |
+
return len(self.samples)
|
167 |
+
|
168 |
+
def __getitem__(self, idx):
|
169 |
+
try:
|
170 |
+
sample = self.samples[idx]
|
171 |
+
ae, t5 = sample['ae'], sample['t5']
|
172 |
+
t5 = random.choice(t5)
|
173 |
+
|
174 |
+
video = self.decord_read(ae)
|
175 |
+
video = self.transform(video) # T C H W -> T C H W
|
176 |
+
video = video.transpose(0, 1) # T C H W -> C T H W
|
177 |
+
total_frames = video.shape[1]
|
178 |
+
cond = torch.from_numpy(np.load(t5['cond']))[0] # L
|
179 |
+
cond_mask = torch.from_numpy(np.load(t5['mask']))[0] # L D
|
180 |
+
|
181 |
+
if self.use_image_num != 0 and self.use_img_from_vid:
|
182 |
+
select_image_idx = np.random.randint(0, total_frames, self.use_image_num)
|
183 |
+
# print('select_image_idx', total_frames, self.use_image_num, select_image_idx)
|
184 |
+
images = video.numpy()[:, select_image_idx] # c, num_img, h, w
|
185 |
+
images = torch.from_numpy(images)
|
186 |
+
video = torch.cat([video, images], dim=1) # c, num_frame+num_img, h, w
|
187 |
+
cond = torch.stack([cond] * (1+self.use_image_num)) # 1+self.use_image_num, l
|
188 |
+
cond_mask = torch.stack([cond_mask] * (1+self.use_image_num)) # 1+self.use_image_num, l
|
189 |
+
elif self.use_image_num != 0 and not self.use_img_from_vid:
|
190 |
+
images, captions = self.img_cap_list[idx]
|
191 |
+
raise NotImplementedError
|
192 |
+
else:
|
193 |
+
pass
|
194 |
+
|
195 |
+
return video, cond, cond_mask
|
196 |
+
except Exception as e:
|
197 |
+
print(f'Error with {e}, {sample}')
|
198 |
+
return self.__getitem__(random.randint(0, self.__len__() - 1))
|
199 |
+
|
200 |
+
def decord_read(self, path):
|
201 |
+
decord_vr = self.v_decoder(path)
|
202 |
+
total_frames = len(decord_vr)
|
203 |
+
# Sampling video frames
|
204 |
+
start_frame_ind, end_frame_ind = self.temporal_sample(total_frames)
|
205 |
+
# assert end_frame_ind - start_frame_ind >= self.num_frames
|
206 |
+
frame_indice = np.linspace(start_frame_ind, end_frame_ind - 1, self.num_frames, dtype=int)
|
207 |
+
video_data = decord_vr.get_batch(frame_indice).asnumpy()
|
208 |
+
video_data = torch.from_numpy(video_data)
|
209 |
+
video_data = video_data.permute(0, 3, 1, 2) # (T, H, W, C) -> (T C H W)
|
210 |
+
return video_data
|
211 |
+
|
212 |
+
def get_img_cap_list(self):
|
213 |
+
raise NotImplementedError
|
opensora/dataset/landscope.py
ADDED
@@ -0,0 +1,90 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
import os
|
3 |
+
from glob import glob
|
4 |
+
|
5 |
+
import decord
|
6 |
+
import numpy as np
|
7 |
+
import torch
|
8 |
+
import torchvision
|
9 |
+
from decord import VideoReader, cpu
|
10 |
+
from torch.utils.data import Dataset
|
11 |
+
from torchvision.transforms import Compose, Lambda, ToTensor
|
12 |
+
from torchvision.transforms._transforms_video import NormalizeVideo, RandomCropVideo, RandomHorizontalFlipVideo
|
13 |
+
from pytorchvideo.transforms import ApplyTransformToKey, ShortSideScale, UniformTemporalSubsample
|
14 |
+
from torch.nn import functional as F
|
15 |
+
import random
|
16 |
+
|
17 |
+
from opensora.utils.dataset_utils import DecordInit
|
18 |
+
|
19 |
+
|
20 |
+
class Landscope(Dataset):
|
21 |
+
def __init__(self, args, transform, temporal_sample):
|
22 |
+
self.data_path = args.data_path
|
23 |
+
self.num_frames = args.num_frames
|
24 |
+
self.transform = transform
|
25 |
+
self.temporal_sample = temporal_sample
|
26 |
+
self.v_decoder = DecordInit()
|
27 |
+
|
28 |
+
self.samples = self._make_dataset()
|
29 |
+
self.use_image_num = args.use_image_num
|
30 |
+
self.use_img_from_vid = args.use_img_from_vid
|
31 |
+
if self.use_image_num != 0 and not self.use_img_from_vid:
|
32 |
+
self.img_cap_list = self.get_img_cap_list()
|
33 |
+
|
34 |
+
|
35 |
+
def _make_dataset(self):
|
36 |
+
paths = list(glob(os.path.join(self.data_path, '**', '*.mp4'), recursive=True))
|
37 |
+
|
38 |
+
return paths
|
39 |
+
|
40 |
+
def __len__(self):
|
41 |
+
return len(self.samples)
|
42 |
+
|
43 |
+
def __getitem__(self, idx):
|
44 |
+
video_path = self.samples[idx]
|
45 |
+
try:
|
46 |
+
video = self.tv_read(video_path)
|
47 |
+
video = self.transform(video) # T C H W -> T C H W
|
48 |
+
video = video.transpose(0, 1) # T C H W -> C T H W
|
49 |
+
if self.use_image_num != 0 and self.use_img_from_vid:
|
50 |
+
select_image_idx = np.linspace(0, self.num_frames - 1, self.use_image_num, dtype=int)
|
51 |
+
assert self.num_frames >= self.use_image_num
|
52 |
+
images = video[:, select_image_idx] # c, num_img, h, w
|
53 |
+
video = torch.cat([video, images], dim=1) # c, num_frame+num_img, h, w
|
54 |
+
elif self.use_image_num != 0 and not self.use_img_from_vid:
|
55 |
+
images, captions = self.img_cap_list[idx]
|
56 |
+
raise NotImplementedError
|
57 |
+
else:
|
58 |
+
pass
|
59 |
+
return video, 1
|
60 |
+
except Exception as e:
|
61 |
+
print(f'Error with {e}, {video_path}')
|
62 |
+
return self.__getitem__(random.randint(0, self.__len__()-1))
|
63 |
+
|
64 |
+
def tv_read(self, path):
|
65 |
+
vframes, aframes, info = torchvision.io.read_video(filename=path, pts_unit='sec', output_format='TCHW')
|
66 |
+
total_frames = len(vframes)
|
67 |
+
|
68 |
+
# Sampling video frames
|
69 |
+
start_frame_ind, end_frame_ind = self.temporal_sample(total_frames)
|
70 |
+
# assert end_frame_ind - start_frame_ind >= self.num_frames
|
71 |
+
frame_indice = np.linspace(start_frame_ind, end_frame_ind - 1, self.num_frames, dtype=int)
|
72 |
+
video = vframes[frame_indice] # (T, C, H, W)
|
73 |
+
|
74 |
+
return video
|
75 |
+
|
76 |
+
def decord_read(self, path):
|
77 |
+
decord_vr = self.v_decoder(path)
|
78 |
+
total_frames = len(decord_vr)
|
79 |
+
# Sampling video frames
|
80 |
+
start_frame_ind, end_frame_ind = self.temporal_sample(total_frames)
|
81 |
+
# assert end_frame_ind - start_frame_ind >= self.num_frames
|
82 |
+
frame_indice = np.linspace(start_frame_ind, end_frame_ind - 1, self.num_frames, dtype=int)
|
83 |
+
|
84 |
+
video_data = decord_vr.get_batch(frame_indice).asnumpy()
|
85 |
+
video_data = torch.from_numpy(video_data)
|
86 |
+
video_data = video_data.permute(0, 3, 1, 2) # (T, H, W, C) -> (T C H W)
|
87 |
+
return video_data
|
88 |
+
|
89 |
+
def get_img_cap_list(self):
|
90 |
+
raise NotImplementedError
|
opensora/dataset/sky_datasets.py
ADDED
@@ -0,0 +1,128 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import torch
|
3 |
+
import random
|
4 |
+
import torch.utils.data as data
|
5 |
+
|
6 |
+
import numpy as np
|
7 |
+
|
8 |
+
from PIL import Image
|
9 |
+
|
10 |
+
from opensora.utils.dataset_utils import is_image_file
|
11 |
+
|
12 |
+
|
13 |
+
class Sky(data.Dataset):
|
14 |
+
def __init__(self, args, transform, temporal_sample=None, train=True):
|
15 |
+
|
16 |
+
self.args = args
|
17 |
+
self.data_path = args.data_path
|
18 |
+
self.transform = transform
|
19 |
+
self.temporal_sample = temporal_sample
|
20 |
+
self.num_frames = self.args.num_frames
|
21 |
+
self.sample_rate = self.args.sample_rate
|
22 |
+
self.data_all = self.load_video_frames(self.data_path)
|
23 |
+
self.use_image_num = args.use_image_num
|
24 |
+
self.use_img_from_vid = args.use_img_from_vid
|
25 |
+
if self.use_image_num != 0 and not self.use_img_from_vid:
|
26 |
+
self.img_cap_list = self.get_img_cap_list()
|
27 |
+
|
28 |
+
def __getitem__(self, index):
|
29 |
+
|
30 |
+
vframes = self.data_all[index]
|
31 |
+
total_frames = len(vframes)
|
32 |
+
|
33 |
+
# Sampling video frames
|
34 |
+
start_frame_ind, end_frame_ind = self.temporal_sample(total_frames)
|
35 |
+
assert end_frame_ind - start_frame_ind >= self.num_frames
|
36 |
+
frame_indice = np.linspace(start_frame_ind, end_frame_ind-1, num=self.num_frames, dtype=int) # start, stop, num=50
|
37 |
+
|
38 |
+
select_video_frames = vframes[frame_indice[0]: frame_indice[-1]+1: self.sample_rate]
|
39 |
+
|
40 |
+
video_frames = []
|
41 |
+
for path in select_video_frames:
|
42 |
+
video_frame = torch.as_tensor(np.array(Image.open(path), dtype=np.uint8, copy=True)).unsqueeze(0)
|
43 |
+
video_frames.append(video_frame)
|
44 |
+
video_clip = torch.cat(video_frames, dim=0).permute(0, 3, 1, 2)
|
45 |
+
video_clip = self.transform(video_clip)
|
46 |
+
video_clip = video_clip.transpose(0, 1) # T C H W -> C T H W
|
47 |
+
|
48 |
+
if self.use_image_num != 0 and self.use_img_from_vid:
|
49 |
+
select_image_idx = np.linspace(0, self.num_frames - 1, self.use_image_num, dtype=int)
|
50 |
+
assert self.num_frames >= self.use_image_num
|
51 |
+
images = video_clip[:, select_image_idx] # c, num_img, h, w
|
52 |
+
video_clip = torch.cat([video_clip, images], dim=1) # c, num_frame+num_img, h, w
|
53 |
+
elif self.use_image_num != 0 and not self.use_img_from_vid:
|
54 |
+
images, captions = self.img_cap_list[index]
|
55 |
+
raise NotImplementedError
|
56 |
+
else:
|
57 |
+
pass
|
58 |
+
|
59 |
+
return video_clip, 1
|
60 |
+
|
61 |
+
def __len__(self):
|
62 |
+
return self.video_num
|
63 |
+
|
64 |
+
def load_video_frames(self, dataroot):
|
65 |
+
data_all = []
|
66 |
+
frame_list = os.walk(dataroot)
|
67 |
+
for _, meta in enumerate(frame_list):
|
68 |
+
root = meta[0]
|
69 |
+
try:
|
70 |
+
frames = [i for i in meta[2] if is_image_file(i)]
|
71 |
+
frames = sorted(frames, key=lambda item: int(item.split('.')[0].split('_')[-1]))
|
72 |
+
except:
|
73 |
+
pass
|
74 |
+
# print(meta[0]) # root
|
75 |
+
# print(meta[2]) # files
|
76 |
+
frames = [os.path.join(root, item) for item in frames if is_image_file(item)]
|
77 |
+
if len(frames) > max(0, self.num_frames * self.sample_rate): # need all > (16 * frame-interval) videos
|
78 |
+
# if len(frames) >= max(0, self.target_video_len): # need all > 16 frames videos
|
79 |
+
data_all.append(frames)
|
80 |
+
self.video_num = len(data_all)
|
81 |
+
return data_all
|
82 |
+
|
83 |
+
def get_img_cap_list(self):
|
84 |
+
raise NotImplementedError
|
85 |
+
|
86 |
+
if __name__ == '__main__':
|
87 |
+
|
88 |
+
import argparse
|
89 |
+
import torchvision
|
90 |
+
import video_transforms
|
91 |
+
import torch.utils.data as data
|
92 |
+
|
93 |
+
from torchvision import transforms
|
94 |
+
from torchvision.utils import save_image
|
95 |
+
|
96 |
+
|
97 |
+
parser = argparse.ArgumentParser()
|
98 |
+
parser.add_argument("--num_frames", type=int, default=16)
|
99 |
+
parser.add_argument("--frame_interval", type=int, default=4)
|
100 |
+
parser.add_argument("--data-path", type=str, default="/path/to/datasets/sky_timelapse/sky_train/")
|
101 |
+
config = parser.parse_args()
|
102 |
+
|
103 |
+
|
104 |
+
target_video_len = config.num_frames
|
105 |
+
|
106 |
+
temporal_sample = video_transforms.TemporalRandomCrop(target_video_len * config.frame_interval)
|
107 |
+
trans = transforms.Compose([
|
108 |
+
video_transforms.ToTensorVideo(),
|
109 |
+
# video_transforms.CenterCropVideo(256),
|
110 |
+
video_transforms.CenterCropResizeVideo(256),
|
111 |
+
# video_transforms.RandomHorizontalFlipVideo(),
|
112 |
+
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)
|
113 |
+
])
|
114 |
+
|
115 |
+
taichi_dataset = Sky(config, transform=trans, temporal_sample=temporal_sample)
|
116 |
+
print(len(taichi_dataset))
|
117 |
+
taichi_dataloader = data.DataLoader(dataset=taichi_dataset, batch_size=1, shuffle=False, num_workers=1)
|
118 |
+
|
119 |
+
for i, video_data in enumerate(taichi_dataloader):
|
120 |
+
print(video_data['video'].shape)
|
121 |
+
|
122 |
+
# print(video_data.dtype)
|
123 |
+
# for i in range(target_video_len):
|
124 |
+
# save_image(video_data[0][i], os.path.join('./test_data', '%04d.png' % i), normalize=True, value_range=(-1, 1))
|
125 |
+
|
126 |
+
# video_ = ((video_data[0] * 0.5 + 0.5) * 255).add_(0.5).clamp_(0, 255).to(dtype=torch.uint8).cpu().permute(0, 2, 3, 1)
|
127 |
+
# torchvision.io.write_video('./test_data' + 'test.mp4', video_, fps=8)
|
128 |
+
# exit()
|
opensora/dataset/t2v_datasets.py
ADDED
@@ -0,0 +1,111 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import json
import os, io, csv, math, random
import numpy as np
import torchvision
from einops import rearrange
from decord import VideoReader

import torch
import torchvision.transforms as transforms
from torch.utils.data.dataset import Dataset
from tqdm import tqdm

from opensora.utils.dataset_utils import DecordInit
from opensora.utils.utils import text_preprocessing



class T2V_dataset(Dataset):
    def __init__(self, args, transform, temporal_sample, tokenizer):

        # with open(args.data_path, 'r') as csvfile:
        #     self.samples = list(csv.DictReader(csvfile))
        self.video_folder = args.video_folder
        self.num_frames = args.num_frames
        self.transform = transform
        self.temporal_sample = temporal_sample
        self.tokenizer = tokenizer
        self.model_max_length = args.model_max_length
        self.v_decoder = DecordInit()

        with open(args.data_path, 'r') as f:
            self.samples = json.load(f)
        self.use_image_num = args.use_image_num
        self.use_img_from_vid = args.use_img_from_vid
        if self.use_image_num != 0 and not self.use_img_from_vid:
            self.img_cap_list = self.get_img_cap_list()

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        try:
            # video = torch.randn(3, 16, 128, 128)
            # input_ids = torch.ones(1, 120).to(torch.long).squeeze(0)
            # cond_mask = torch.cat([torch.ones(1, 60).to(torch.long), torch.ones(1, 60).to(torch.long)], dim=1).squeeze(0)
            # return video, input_ids, cond_mask
            video_path = self.samples[idx]['path']
            video = self.decord_read(video_path)
            video = self.transform(video)  # T C H W -> T C H W
            video = video.transpose(0, 1)  # T C H W -> C T H W
            text = self.samples[idx]['cap'][0]

            text = text_preprocessing(text)
            text_tokens_and_mask = self.tokenizer(
                text,
                max_length=self.model_max_length,
                padding='max_length',
                truncation=True,
                return_attention_mask=True,
                add_special_tokens=True,
                return_tensors='pt'
            )
            input_ids = text_tokens_and_mask['input_ids'].squeeze(0)
            cond_mask = text_tokens_and_mask['attention_mask'].squeeze(0)

            if self.use_image_num != 0 and self.use_img_from_vid:
                select_image_idx = np.linspace(0, self.num_frames - 1, self.use_image_num, dtype=int)
                assert self.num_frames >= self.use_image_num
                images = video[:, select_image_idx]  # c, num_img, h, w
                video = torch.cat([video, images], dim=1)  # c, num_frame+num_img, h, w
                input_ids = torch.stack([input_ids] * (1 + self.use_image_num))  # 1+self.use_image_num, l
                cond_mask = torch.stack([cond_mask] * (1 + self.use_image_num))  # 1+self.use_image_num, l
            elif self.use_image_num != 0 and not self.use_img_from_vid:
                images, captions = self.img_cap_list[idx]
                raise NotImplementedError
            else:
                pass

            return video, input_ids, cond_mask
        except Exception as e:
            print(f'Error with {e}, {self.samples[idx]}')
            return self.__getitem__(random.randint(0, self.__len__() - 1))

    def tv_read(self, path):
        vframes, aframes, info = torchvision.io.read_video(filename=path, pts_unit='sec', output_format='TCHW')
        total_frames = len(vframes)

        # Sampling video frames
        start_frame_ind, end_frame_ind = self.temporal_sample(total_frames)
        # assert end_frame_ind - start_frame_ind >= self.num_frames
        frame_indice = np.linspace(start_frame_ind, end_frame_ind - 1, self.num_frames, dtype=int)

        video = vframes[frame_indice]  # (T, C, H, W)

        return video

    def decord_read(self, path):
        decord_vr = self.v_decoder(path)
        total_frames = len(decord_vr)
        # Sampling video frames
        start_frame_ind, end_frame_ind = self.temporal_sample(total_frames)
        # assert end_frame_ind - start_frame_ind >= self.num_frames
        frame_indice = np.linspace(start_frame_ind, end_frame_ind - 1, self.num_frames, dtype=int)

        video_data = decord_vr.get_batch(frame_indice).asnumpy()
        video_data = torch.from_numpy(video_data)
        video_data = video_data.permute(0, 3, 1, 2)  # (T, H, W, C) -> (T C H W)
        return video_data

    def get_img_cap_list(self):
        raise NotImplementedError
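
For orientation, here is a minimal sketch (not part of the uploaded files) of how T2V_dataset could be wired together with the transforms defined in opensora/dataset/transform.py below. The annotation path, video folder, frame count, and tokenizer checkpoint are placeholder values, and the repository's actual training scripts may construct these objects differently.

# Hypothetical wiring for T2V_dataset; paths and the tokenizer checkpoint are placeholders.
from argparse import Namespace
from torch.utils.data import DataLoader
from torchvision import transforms
from transformers import AutoTokenizer

from opensora.dataset.t2v_datasets import T2V_dataset
from opensora.dataset.transform import ToTensorVideo, CenterCropResizeVideo, TemporalRandomCrop

args = Namespace(
    data_path='anno.json',          # JSON list of {"path": ..., "cap": [...]} entries
    video_folder='videos/',
    num_frames=17,
    model_max_length=300,
    use_image_num=0,                # 0 -> video-only samples
    use_img_from_vid=False,
)

temporal_sample = TemporalRandomCrop(args.num_frames)   # picks a random frame window
transform = transforms.Compose([
    ToTensorVideo(),                                     # uint8 (T, C, H, W) -> float in [0, 1]
    CenterCropResizeVideo(512),                          # short-edge center crop, then resize
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
])
tokenizer = AutoTokenizer.from_pretrained('DeepFloyd/t5-v1_1-xxl')  # placeholder text encoder

dataset = T2V_dataset(args, transform, temporal_sample, tokenizer)
video, input_ids, cond_mask = dataset[0]                 # video: (C, T, H, W)
loader = DataLoader(dataset, batch_size=2, shuffle=True, num_workers=4)

Each item comes back as a (C, T, H, W) video tensor plus the padded input_ids and attention mask for its caption, so the default DataLoader collation works as long as num_frames and the crop size are fixed.
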
opensora/dataset/transform.py
ADDED
@@ -0,0 +1,489 @@
import torch
import random
import numbers
import numpy as np                 # required by center_crop_arr below
from PIL import Image              # required by center_crop_arr below
from torchvision.transforms import RandomCrop, RandomResizedCrop


def _is_tensor_video_clip(clip):
    if not torch.is_tensor(clip):
        raise TypeError("clip should be Tensor. Got %s" % type(clip))

    if not clip.ndimension() == 4:
        raise ValueError("clip should be 4D. Got %dD" % clip.dim())

    return True


def center_crop_arr(pil_image, image_size):
    """
    Center cropping implementation from ADM.
    https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126
    """
    while min(*pil_image.size) >= 2 * image_size:
        pil_image = pil_image.resize(
            tuple(x // 2 for x in pil_image.size), resample=Image.BOX
        )

    scale = image_size / min(*pil_image.size)
    pil_image = pil_image.resize(
        tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC
    )

    arr = np.array(pil_image)
    crop_y = (arr.shape[0] - image_size) // 2
    crop_x = (arr.shape[1] - image_size) // 2
    return Image.fromarray(arr[crop_y: crop_y + image_size, crop_x: crop_x + image_size])


def crop(clip, i, j, h, w):
    """
    Args:
        clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
    """
    if len(clip.size()) != 4:
        raise ValueError("clip should be a 4D tensor")
    return clip[..., i: i + h, j: j + w]


def resize(clip, target_size, interpolation_mode):
    if len(target_size) != 2:
        raise ValueError(f"target size should be tuple (height, width), instead got {target_size}")
    return torch.nn.functional.interpolate(clip, size=target_size, mode=interpolation_mode, align_corners=True, antialias=True)


def resize_scale(clip, target_size, interpolation_mode):
    if len(target_size) != 2:
        raise ValueError(f"target size should be tuple (height, width), instead got {target_size}")
    H, W = clip.size(-2), clip.size(-1)
    scale_ = target_size[0] / min(H, W)
    return torch.nn.functional.interpolate(clip, scale_factor=scale_, mode=interpolation_mode, align_corners=True, antialias=True)


def resized_crop(clip, i, j, h, w, size, interpolation_mode="bilinear"):
    """
    Do spatial cropping and resizing to the video clip
    Args:
        clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
        i (int): i in (i,j) i.e coordinates of the upper left corner.
        j (int): j in (i,j) i.e coordinates of the upper left corner.
        h (int): Height of the cropped region.
        w (int): Width of the cropped region.
        size (tuple(int, int)): height and width of resized clip
    Returns:
        clip (torch.tensor): Resized and cropped clip. Size is (T, C, H, W)
    """
    if not _is_tensor_video_clip(clip):
        raise ValueError("clip should be a 4D torch.tensor")
    clip = crop(clip, i, j, h, w)
    clip = resize(clip, size, interpolation_mode)
    return clip


def center_crop(clip, crop_size):
    if not _is_tensor_video_clip(clip):
        raise ValueError("clip should be a 4D torch.tensor")
    h, w = clip.size(-2), clip.size(-1)
    th, tw = crop_size
    if h < th or w < tw:
        raise ValueError("height and width must be no smaller than crop_size")

    i = int(round((h - th) / 2.0))
    j = int(round((w - tw) / 2.0))
    return crop(clip, i, j, th, tw)


def center_crop_using_short_edge(clip):
    if not _is_tensor_video_clip(clip):
        raise ValueError("clip should be a 4D torch.tensor")
    h, w = clip.size(-2), clip.size(-1)
    if h < w:
        th, tw = h, h
        i = 0
        j = int(round((w - tw) / 2.0))
    else:
        th, tw = w, w
        i = int(round((h - th) / 2.0))
        j = 0
    return crop(clip, i, j, th, tw)


def random_shift_crop(clip):
    '''
    Slide along the long edge, with the short edge as crop size
    '''
    if not _is_tensor_video_clip(clip):
        raise ValueError("clip should be a 4D torch.tensor")
    h, w = clip.size(-2), clip.size(-1)

    if h <= w:
        long_edge = w
        short_edge = h
    else:
        long_edge = h
        short_edge = w

    th, tw = short_edge, short_edge

    i = torch.randint(0, h - th + 1, size=(1,)).item()
    j = torch.randint(0, w - tw + 1, size=(1,)).item()
    return crop(clip, i, j, th, tw)


def to_tensor(clip):
    """
    Convert tensor data type from uint8 to float, divide value by 255.0 and
    permute the dimensions of clip tensor
    Args:
        clip (torch.tensor, dtype=torch.uint8): Size is (T, C, H, W)
    Return:
        clip (torch.tensor, dtype=torch.float): Size is (T, C, H, W)
    """
    _is_tensor_video_clip(clip)
    if not clip.dtype == torch.uint8:
        raise TypeError("clip tensor should have data type uint8. Got %s" % str(clip.dtype))
    # return clip.float().permute(3, 0, 1, 2) / 255.0
    return clip.float() / 255.0


def normalize(clip, mean, std, inplace=False):
    """
    Args:
        clip (torch.tensor): Video clip to be normalized. Size is (T, C, H, W)
        mean (tuple): pixel RGB mean. Size is (3)
        std (tuple): pixel standard deviation. Size is (3)
    Returns:
        normalized clip (torch.tensor): Size is (T, C, H, W)
    """
    if not _is_tensor_video_clip(clip):
        raise ValueError("clip should be a 4D torch.tensor")
    if not inplace:
        clip = clip.clone()
    mean = torch.as_tensor(mean, dtype=clip.dtype, device=clip.device)
    # print(mean)
    std = torch.as_tensor(std, dtype=clip.dtype, device=clip.device)
    clip.sub_(mean[:, None, None, None]).div_(std[:, None, None, None])
    return clip


def hflip(clip):
    """
    Args:
        clip (torch.tensor): Video clip to be normalized. Size is (T, C, H, W)
    Returns:
        flipped clip (torch.tensor): Size is (T, C, H, W)
    """
    if not _is_tensor_video_clip(clip):
        raise ValueError("clip should be a 4D torch.tensor")
    return clip.flip(-1)


class RandomCropVideo:
    def __init__(self, size):
        if isinstance(size, numbers.Number):
            self.size = (int(size), int(size))
        else:
            self.size = size

    def __call__(self, clip):
        """
        Args:
            clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
        Returns:
            torch.tensor: randomly cropped video clip.
                size is (T, C, OH, OW)
        """
        i, j, h, w = self.get_params(clip)
        return crop(clip, i, j, h, w)

    def get_params(self, clip):
        h, w = clip.shape[-2:]
        th, tw = self.size

        if h < th or w < tw:
            raise ValueError(f"Required crop size {(th, tw)} is larger than input image size {(h, w)}")

        if w == tw and h == th:
            return 0, 0, h, w

        i = torch.randint(0, h - th + 1, size=(1,)).item()
        j = torch.randint(0, w - tw + 1, size=(1,)).item()

        return i, j, th, tw

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(size={self.size})"


class CenterCropResizeVideo:
    '''
    First use the short side for cropping length,
    center crop video, then resize to the specified size
    '''

    def __init__(
        self,
        size,
        interpolation_mode="bilinear",
    ):
        if isinstance(size, tuple):
            if len(size) != 2:
                raise ValueError(f"size should be tuple (height, width), instead got {size}")
            self.size = size
        else:
            self.size = (size, size)

        self.interpolation_mode = interpolation_mode

    def __call__(self, clip):
        """
        Args:
            clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
        Returns:
            torch.tensor: scale resized / center cropped video clip.
                size is (T, C, crop_size, crop_size)
        """
        clip_center_crop = center_crop_using_short_edge(clip)
        clip_center_crop_resize = resize(clip_center_crop, target_size=self.size,
                                         interpolation_mode=self.interpolation_mode)
        return clip_center_crop_resize

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode}"


class UCFCenterCropVideo:
    '''
    First scale to the specified size in equal proportion to the short edge,
    then center cropping
    '''

    def __init__(
        self,
        size,
        interpolation_mode="bilinear",
    ):
        if isinstance(size, tuple):
            if len(size) != 2:
                raise ValueError(f"size should be tuple (height, width), instead got {size}")
            self.size = size
        else:
            self.size = (size, size)

        self.interpolation_mode = interpolation_mode

    def __call__(self, clip):
        """
        Args:
            clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
        Returns:
            torch.tensor: scale resized / center cropped video clip.
                size is (T, C, crop_size, crop_size)
        """
        clip_resize = resize_scale(clip=clip, target_size=self.size, interpolation_mode=self.interpolation_mode)
        clip_center_crop = center_crop(clip_resize, self.size)
        return clip_center_crop

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode}"


class KineticsRandomCropResizeVideo:
    '''
    Slide along the long edge, with the short edge as crop size. And resize to the desired size.
    '''

    def __init__(
        self,
        size,
        interpolation_mode="bilinear",
    ):
        if isinstance(size, tuple):
            if len(size) != 2:
                raise ValueError(f"size should be tuple (height, width), instead got {size}")
            self.size = size
        else:
            self.size = (size, size)

        self.interpolation_mode = interpolation_mode

    def __call__(self, clip):
        clip_random_crop = random_shift_crop(clip)
        clip_resize = resize(clip_random_crop, self.size, self.interpolation_mode)
        return clip_resize


class CenterCropVideo:
    def __init__(
        self,
        size,
        interpolation_mode="bilinear",
    ):
        if isinstance(size, tuple):
            if len(size) != 2:
                raise ValueError(f"size should be tuple (height, width), instead got {size}")
            self.size = size
        else:
            self.size = (size, size)

        self.interpolation_mode = interpolation_mode

    def __call__(self, clip):
        """
        Args:
            clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
        Returns:
            torch.tensor: center cropped video clip.
                size is (T, C, crop_size, crop_size)
        """
        clip_center_crop = center_crop(clip, self.size)
        return clip_center_crop

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode}"


class NormalizeVideo:
    """
    Normalize the video clip by mean subtraction and division by standard deviation
    Args:
        mean (3-tuple): pixel RGB mean
        std (3-tuple): pixel RGB standard deviation
        inplace (boolean): whether do in-place normalization
    """

    def __init__(self, mean, std, inplace=False):
        self.mean = mean
        self.std = std
        self.inplace = inplace

    def __call__(self, clip):
        """
        Args:
            clip (torch.tensor): video clip to be normalized. Size is (C, T, H, W)
        """
        return normalize(clip, self.mean, self.std, self.inplace)

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(mean={self.mean}, std={self.std}, inplace={self.inplace})"


class ToTensorVideo:
    """
    Convert tensor data type from uint8 to float, divide value by 255.0 and
    permute the dimensions of clip tensor
    """

    def __init__(self):
        pass

    def __call__(self, clip):
        """
        Args:
            clip (torch.tensor, dtype=torch.uint8): Size is (T, C, H, W)
        Return:
            clip (torch.tensor, dtype=torch.float): Size is (T, C, H, W)
        """
        return to_tensor(clip)

    def __repr__(self) -> str:
        return self.__class__.__name__


class RandomHorizontalFlipVideo:
    """
    Flip the video clip along the horizontal direction with a given probability
    Args:
        p (float): probability of the clip being flipped. Default value is 0.5
    """

    def __init__(self, p=0.5):
        self.p = p

    def __call__(self, clip):
        """
        Args:
            clip (torch.tensor): Size is (T, C, H, W)
        Return:
            clip (torch.tensor): Size is (T, C, H, W)
        """
        if random.random() < self.p:
            clip = hflip(clip)
        return clip

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(p={self.p})"


#  ------------------------------------------------------------
#  ---------------------  Sampling  ---------------------------
#  ------------------------------------------------------------
class TemporalRandomCrop(object):
    """Temporally crop the given frame indices at a random location.

    Args:
        size (int): Desired length of frames to be seen by the model.
    """

    def __init__(self, size):
        self.size = size

    def __call__(self, total_frames):
        rand_end = max(0, total_frames - self.size - 1)
        begin_index = random.randint(0, rand_end)
        end_index = min(begin_index + self.size, total_frames)
        return begin_index, end_index


if __name__ == '__main__':
    from torchvision import transforms
    import torchvision.io as io
    import numpy as np
    from torchvision.utils import save_image
    import os

    vframes, aframes, info = io.read_video(
        filename='./v_Archery_g01_c03.avi',
        pts_unit='sec',
        output_format='TCHW'
    )

    trans = transforms.Compose([
        ToTensorVideo(),
        RandomHorizontalFlipVideo(),
        UCFCenterCropVideo(512),
        # NormalizeVideo(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)
    ])

    target_video_len = 32
    frame_interval = 1
    total_frames = len(vframes)
    print(total_frames)

    temporal_sample = TemporalRandomCrop(target_video_len * frame_interval)

    # Sampling video frames
    start_frame_ind, end_frame_ind = temporal_sample(total_frames)
    # print(start_frame_ind)
    # print(end_frame_ind)
    assert end_frame_ind - start_frame_ind >= target_video_len
    frame_indice = np.linspace(start_frame_ind, end_frame_ind - 1, target_video_len, dtype=int)
    print(frame_indice)

    select_vframes = vframes[frame_indice]
    print(select_vframes.shape)
    print(select_vframes.dtype)

    select_vframes_trans = trans(select_vframes)
    print(select_vframes_trans.shape)
    print(select_vframes_trans.dtype)

    select_vframes_trans_int = ((select_vframes_trans * 0.5 + 0.5) * 255).to(dtype=torch.uint8)
    print(select_vframes_trans_int.dtype)
    print(select_vframes_trans_int.permute(0, 2, 3, 1).shape)

    io.write_video('./test.avi', select_vframes_trans_int.permute(0, 2, 3, 1), fps=8)

    for i in range(target_video_len):
        save_image(select_vframes_trans[i], os.path.join('./test000', '%04d.png' % i), normalize=True,
                   value_range=(-1, 1))
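
As a quick sanity check, here is a self-contained sketch (again, not part of the upload) that exercises TemporalRandomCrop and the clip transforms on a dummy uint8 tensor instead of a decoded video. The frame count and resolutions are arbitrary, and the frame-index selection mirrors what tv_read/decord_read do in t2v_datasets.py.

# Hypothetical end-to-end use of the transforms above on a fake clip; no video file needed.
import torch
import numpy as np
from torchvision import transforms

from opensora.dataset.transform import (
    ToTensorVideo, RandomHorizontalFlipVideo, CenterCropResizeVideo, TemporalRandomCrop,
)

num_frames = 16
total_frames = 120                                    # pretend the source video has 120 frames
temporal_sample = TemporalRandomCrop(num_frames)

# TemporalRandomCrop returns a window [begin, end); num_frames indices are then
# spread evenly across it, just as the dataset readers do.
start_ind, end_ind = temporal_sample(total_frames)
frame_indice = np.linspace(start_ind, end_ind - 1, num_frames, dtype=int)

vframes = torch.randint(0, 256, (total_frames, 3, 360, 640), dtype=torch.uint8)  # fake decoded video
clip = vframes[frame_indice]                          # (T, C, H, W), uint8

trans = transforms.Compose([
    ToTensorVideo(),                                  # uint8 -> float32 in [0, 1]
    RandomHorizontalFlipVideo(),
    CenterCropResizeVideo(256),                       # -> (T, C, 256, 256)
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
])
out = trans(clip)
print(out.shape, out.dtype)                           # torch.Size([16, 3, 256, 256]) torch.float32

TemporalRandomCrop only chooses the window; spreading num_frames indices across it with np.linspace is what realizes the effective sampling stride when the window is sized as num_frames * frame_interval.
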