diff --git a/Video-P2P-Beta b/Video-P2P-Beta deleted file mode 160000 index 7a8fa7a8b8d81bbba367865f47b7894cdc4efafb..0000000000000000000000000000000000000000 --- a/Video-P2P-Beta +++ /dev/null @@ -1 +0,0 @@ -Subproject commit 7a8fa7a8b8d81bbba367865f47b7894cdc4efafb diff --git a/Video-P2P/.DS_Store b/Video-P2P/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..5008ddfcf53c02e82d7eee2e57c38e5672ef89f6 Binary files /dev/null and b/Video-P2P/.DS_Store differ diff --git a/Video-P2P/.gitignore b/Video-P2P/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..ca7513e3b359f2d7618466be3a4be9af33824c1b --- /dev/null +++ b/Video-P2P/.gitignore @@ -0,0 +1,3 @@ +*.pyc +*.pt +*.gif \ No newline at end of file diff --git a/Video-P2P/README.md b/Video-P2P/README.md new file mode 100644 index 0000000000000000000000000000000000000000..002d37e6240c1fd7666e58d7562af48f536f6d71 --- /dev/null +++ b/Video-P2P/README.md @@ -0,0 +1,99 @@ +# Video-P2P: Video Editing with Cross-attention Control +The official implementation of [Video-P2P](https://video-p2p.github.io/). + +[Shaoteng Liu](https://www.shaotengliu.com/), [Yuechen Zhang](https://julianjuaner.github.io/), [Wenbo Li](https://fenglinglwb.github.io/), [Zhe Lin](https://sites.google.com/site/zhelin625/), [Jiaya Jia](https://jiaya.me/) + +[![Project Website](https://img.shields.io/badge/Project-Website-orange)](https://video-p2p.github.io/) +[![arXiv](https://img.shields.io/badge/arXiv-2303.04761-b31b1b.svg)](https://arxiv.org/abs/2303.04761) + +![Teaser](./docs/teaser.png) + +## Changelog + +- 2023.03.20 Release Gradio Demo. +- 2023.03.19 Release Code. +- 2023.03.09 Paper preprint on arxiv. + +## Todo + +- [x] Release the code with 6 examples. +- [x] Update a faster version. +- [x] Release all data. +- [ ] Release the Gradio Demo. +- [ ] Release more configs and new applications. + +## Setup + +``` bash +pip install -r requirements.txt +``` + +The code was tested on both Tesla V100 32GB and RTX3090 24GB. + +The environment is similar to [Tune-A-video](https://github.com/showlab/Tune-A-Video) and [prompt-to-prompt](https://github.com/google/prompt-to-prompt/). + +[xformers](https://github.com/facebookresearch/xformers) on 3090 may meet this [issue](https://github.com/bryandlee/Tune-A-Video/issues/4). + +## Quickstart + +Please replace ``pretrained_model_path'' with the path to your stable-diffusion. + +``` bash +# You can minimize the tuning epochs to speed up. +python run_tuning.py --config="configs/rabbit-jump-tune.yaml" # Tuning to do model initialization. + +# We develop a faster mode (1 min on V100): +python run_videop2p.py --config="configs/rabbit-jump-p2p.yaml" --fast + +# The official mode (10 mins on V100, more stable): +python run_videop2p.py --config="configs/rabbit-jump-p2p.yaml" +``` + +## Dataset + +We release our dataset [here](). +Download them under ./data and explore your creativity! + +## Results + + + + + + + + + + + + + + + + + + + + + + + + + + +
Configs for the released example results:

- configs/rabbit-jump-p2p.yaml
- configs/penguin-run-p2p.yaml
- configs/man-motor-p2p.yaml
- configs/car-drive-p2p.yaml
- configs/tiger-forest-p2p.yaml
- configs/bird-forest-p2p.yaml
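Each example can also be run programmatically instead of through the CLI. A minimal sketch, assuming `run_videop2p.py` is importable from the repository root and that `run_tuning.py` has already produced the tuned checkpoint referenced by the config:

``` python
# Minimal sketch: run a Video-P2P edit from Python rather than the CLI wrapper.
# Assumes the tuning step for this example has been run, so the model at
# `pretrained_model_path` in the config exists.
from omegaconf import OmegaConf
from run_videop2p import main  # main() as defined in run_videop2p.py

config = OmegaConf.load("configs/rabbit-jump-p2p.yaml")
main(**config, fast=True)  # fast=True mirrors the --fast CLI flag
```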
+ +## Citation +``` +@misc{liu2023videop2p, + author={Liu, Shaoteng and Zhang, Yuechen and Li, Wenbo and Lin, Zhe and Jia, Jiaya}, + title={Video-P2P: Video Editing with Cross-attention Control}, + journal={arXiv:2303.04761}, + year={2023}, +} +``` + +## References +* prompt-to-prompt: https://github.com/google/prompt-to-prompt +* Tune-A-Video: https://github.com/showlab/Tune-A-Video +* diffusers: https://github.com/huggingface/diffusers \ No newline at end of file diff --git a/Video-P2P/configs/.DS_Store b/Video-P2P/configs/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..5008ddfcf53c02e82d7eee2e57c38e5672ef89f6 Binary files /dev/null and b/Video-P2P/configs/.DS_Store differ diff --git a/Video-P2P/configs/bird-forest-p2p.yaml b/Video-P2P/configs/bird-forest-p2p.yaml new file mode 100644 index 0000000000000000000000000000000000000000..7b92ab58ed3ff785e4368dad74e99c8275201c08 --- /dev/null +++ b/Video-P2P/configs/bird-forest-p2p.yaml @@ -0,0 +1,17 @@ +pretrained_model_path: "./outputs/bird-forest" +image_path: "./data/bird_forest" +prompt: "a bird flying in the forest" +prompts: + - "a bird flying in the forest" + - "children drawing of a bird flying in the forest" +eq_params: + words: + - "children" + - "drawing" + values: + - 5 + - 2 +save_name: "children" +is_word_swap: False +cross_replace_steps: 0.8 +self_replace_steps: 0.7 \ No newline at end of file diff --git a/Video-P2P/configs/bird-forest-tune.yaml b/Video-P2P/configs/bird-forest-tune.yaml new file mode 100644 index 0000000000000000000000000000000000000000..b91e054fc06e635b50246a2a5f7d358f8a477742 --- /dev/null +++ b/Video-P2P/configs/bird-forest-tune.yaml @@ -0,0 +1,38 @@ +pretrained_model_path: "/data/stable-diffusion/stable-diffusion-v1-5" +output_dir: "./outputs/bird-forest" + +train_data: + video_path: "./data/bird_forest" + prompt: "a bird flying in the forest" + n_sample_frames: 8 + width: 512 + height: 512 + sample_start_idx: 0 + sample_frame_rate: 1 + +validation_data: + prompts: + - "a bird flying in the forest" + video_length: 8 + width: 512 + height: 512 + num_inference_steps: 50 + guidance_scale: 12.5 + use_inv_latent: True + num_inv_steps: 50 + +learning_rate: 3e-5 +train_batch_size: 1 +max_train_steps: 500 +checkpointing_steps: 1000 +validation_steps: 600 +trainable_modules: + - "attn1.to_q" + - "attn2.to_q" + - "attn_temp" + +seed: 33 +mixed_precision: fp16 +use_8bit_adam: False +gradient_checkpointing: True +enable_xformers_memory_efficient_attention: True diff --git a/Video-P2P/configs/car-drive-p2p.yaml b/Video-P2P/configs/car-drive-p2p.yaml new file mode 100644 index 0000000000000000000000000000000000000000..16496be6f956f6b1b90bde1e6a03cb854985f093 --- /dev/null +++ b/Video-P2P/configs/car-drive-p2p.yaml @@ -0,0 +1,16 @@ +pretrained_model_path: "./outputs/car-drive" +image_path: "./data/car" +prompt: "a car is driving on the road" +prompts: + - "a car is driving on the road" + - "a car is driving on the railway" +blend_word: + - 'road' + - 'railway' +eq_params: + words: + - "railway" + values: + - 2 +save_name: "railway" +is_word_swap: True \ No newline at end of file diff --git a/Video-P2P/configs/car-drive-tune.yaml b/Video-P2P/configs/car-drive-tune.yaml new file mode 100644 index 0000000000000000000000000000000000000000..be8b9bdad12effc45b28595181c869616e03b476 --- /dev/null +++ b/Video-P2P/configs/car-drive-tune.yaml @@ -0,0 +1,38 @@ +pretrained_model_path: "/data/stable-diffusion/stable-diffusion-v1-5" +output_dir: "./outputs/car-drive" + +train_data: + video_path: 
"./data/car" + prompt: "a car is driving on the road" + n_sample_frames: 8 + width: 512 + height: 512 + sample_start_idx: 0 + sample_frame_rate: 1 + +validation_data: + prompts: + - "a car is driving on the railway" + video_length: 8 + width: 512 + height: 512 + num_inference_steps: 50 + guidance_scale: 12.5 + use_inv_latent: True + num_inv_steps: 50 + +learning_rate: 3e-5 +train_batch_size: 1 +max_train_steps: 300 +checkpointing_steps: 1000 +validation_steps: 300 +trainable_modules: + - "attn1.to_q" + - "attn2.to_q" + - "attn_temp" + +seed: 33 +mixed_precision: fp16 +use_8bit_adam: False +gradient_checkpointing: True +enable_xformers_memory_efficient_attention: True diff --git a/Video-P2P/configs/man-motor-p2p.yaml b/Video-P2P/configs/man-motor-p2p.yaml new file mode 100644 index 0000000000000000000000000000000000000000..773222a9963877c5e2bf7ea2194846963cb9ba83 --- /dev/null +++ b/Video-P2P/configs/man-motor-p2p.yaml @@ -0,0 +1,16 @@ +pretrained_model_path: "./outputs/man-motor" +image_path: "./data/motorbike" +prompt: "a man is driving a motorbike in the forest" +prompts: + - "a man is driving a motorbike in the forest" + - "a Spider-Man is driving a motorbike in the forest" +blend_word: + - 'man' + - 'Spider-Man' +eq_params: + words: + - "Spider-Man" + values: + - 4 +save_name: "spider" +is_word_swap: True \ No newline at end of file diff --git a/Video-P2P/configs/man-motor-tune.yaml b/Video-P2P/configs/man-motor-tune.yaml new file mode 100644 index 0000000000000000000000000000000000000000..24748f6098780987f395b297aec665c605f7d598 --- /dev/null +++ b/Video-P2P/configs/man-motor-tune.yaml @@ -0,0 +1,38 @@ +pretrained_model_path: "/data/stable-diffusion/stable-diffusion-v1-5" +output_dir: "./outputs/man-motor" + +train_data: + video_path: "./data/motorbike" + prompt: "a man is driving a motorbike in the forest" + n_sample_frames: 8 + width: 512 + height: 512 + sample_start_idx: 0 + sample_frame_rate: 1 + +validation_data: + prompts: + - "a Spider-Man is driving a motorbike in the forest" + video_length: 8 + width: 512 + height: 512 + num_inference_steps: 50 + guidance_scale: 12.5 + use_inv_latent: True + num_inv_steps: 50 + +learning_rate: 3e-5 +train_batch_size: 1 +max_train_steps: 500 +checkpointing_steps: 1000 +validation_steps: 500 +trainable_modules: + - "attn1.to_q" + - "attn2.to_q" + - "attn_temp" + +seed: 33 +mixed_precision: fp16 +use_8bit_adam: False +gradient_checkpointing: True +enable_xformers_memory_efficient_attention: True diff --git a/Video-P2P/configs/man-surfing-tune.yaml b/Video-P2P/configs/man-surfing-tune.yaml new file mode 100644 index 0000000000000000000000000000000000000000..5c352c215636e5929fcb5991f0cecc269d896a78 --- /dev/null +++ b/Video-P2P/configs/man-surfing-tune.yaml @@ -0,0 +1,38 @@ +pretrained_model_path: "./checkpoints/stable-diffusion-v1-4" +output_dir: "./outputs/man-surfing" + +train_data: + video_path: "data/man-surfing.mp4" + prompt: "a man is surfing" + n_sample_frames: 8 + width: 512 + height: 512 + sample_start_idx: 0 + sample_frame_rate: 1 + +validation_data: + prompts: + - "a panda is surfing" + video_length: 8 + width: 512 + height: 512 + num_inference_steps: 50 + guidance_scale: 12.5 + use_inv_latent: True + num_inv_steps: 50 + +learning_rate: 3e-5 +train_batch_size: 1 +max_train_steps: 500 +checkpointing_steps: 1000 +validation_steps: 500 +trainable_modules: + - "attn1.to_q" + - "attn2.to_q" + - "attn_temp" + +seed: 33 +mixed_precision: fp16 +use_8bit_adam: False +gradient_checkpointing: True +enable_xformers_memory_efficient_attention: True 
diff --git a/Video-P2P/configs/penguin-run-p2p.yaml b/Video-P2P/configs/penguin-run-p2p.yaml new file mode 100644 index 0000000000000000000000000000000000000000..f72ee8c0cb120ed7f9f82079f995b0a66d6d491b --- /dev/null +++ b/Video-P2P/configs/penguin-run-p2p.yaml @@ -0,0 +1,16 @@ +pretrained_model_path: "./outputs/penguin-run" +image_path: "./data/penguin_ice" +prompt: "a penguin is running on the ice" +prompts: + - "a penguin is running on the ice" + - "a crochet penguin is running on the ice" +blend_word: + - 'penguin' + - 'penguin' +eq_params: + words: + - "crochet" + values: + - 4 +save_name: "crochet" +is_word_swap: False \ No newline at end of file diff --git a/Video-P2P/configs/penguin-run-tune.yaml b/Video-P2P/configs/penguin-run-tune.yaml new file mode 100644 index 0000000000000000000000000000000000000000..3d38d1f096f6b69ad0c384f360533a290a1adcfc --- /dev/null +++ b/Video-P2P/configs/penguin-run-tune.yaml @@ -0,0 +1,38 @@ +pretrained_model_path: "/data/stable-diffusion/stable-diffusion-v1-5" +output_dir: "./outputs/penguin-run" + +train_data: + video_path: "./data/penguin_ice" + prompt: "a penguin is running on the ice" + n_sample_frames: 8 + width: 512 + height: 512 + sample_start_idx: 0 + sample_frame_rate: 1 + +validation_data: + prompts: + - "a crochet penguin is running on the ice" + video_length: 8 + width: 512 + height: 512 + num_inference_steps: 50 + guidance_scale: 12.5 + use_inv_latent: True + num_inv_steps: 50 + +learning_rate: 3e-5 +train_batch_size: 1 +max_train_steps: 300 +checkpointing_steps: 1000 +validation_steps: 300 +trainable_modules: + - "attn1.to_q" + - "attn2.to_q" + - "attn_temp" + +seed: 33 +mixed_precision: fp16 +use_8bit_adam: False +gradient_checkpointing: True +enable_xformers_memory_efficient_attention: True diff --git a/Video-P2P/configs/rabbit-jump-p2p.yaml b/Video-P2P/configs/rabbit-jump-p2p.yaml new file mode 100644 index 0000000000000000000000000000000000000000..79250dd8aae50df88e2c8aad86dcfc7159e2e60f --- /dev/null +++ b/Video-P2P/configs/rabbit-jump-p2p.yaml @@ -0,0 +1,16 @@ +pretrained_model_path: "./outputs/rabbit-jump" +image_path: "./data/rabbit" +prompt: "a rabbit is jumping on the grass" +prompts: + - "a rabbit is jumping on the grass" + - "a origami rabbit is jumping on the grass" +blend_word: + - 'rabbit' + - 'rabbit' +eq_params: + words: + - "origami" + values: + - 2 +save_name: "origami" +is_word_swap: False \ No newline at end of file diff --git a/Video-P2P/configs/rabbit-jump-tune.yaml b/Video-P2P/configs/rabbit-jump-tune.yaml new file mode 100644 index 0000000000000000000000000000000000000000..b614b8046d819bc9680b05252b551324a61e0e8d --- /dev/null +++ b/Video-P2P/configs/rabbit-jump-tune.yaml @@ -0,0 +1,38 @@ +pretrained_model_path: "/data/stable-diffusion/stable-diffusion-v1-5" +output_dir: "./outputs/rabbit-jump" + +train_data: + video_path: "./data/rabbit" + prompt: "a rabbit is jumping on the grass" + n_sample_frames: 8 + width: 512 + height: 512 + sample_start_idx: 0 + sample_frame_rate: 1 + +validation_data: + prompts: + - "a origami rabbit is jumping on the grass" + video_length: 8 + width: 512 + height: 512 + num_inference_steps: 50 + guidance_scale: 12.5 + use_inv_latent: True + num_inv_steps: 50 + +learning_rate: 3e-5 +train_batch_size: 1 +max_train_steps: 500 +checkpointing_steps: 1000 +validation_steps: 500 +trainable_modules: + - "attn1.to_q" + - "attn2.to_q" + - "attn_temp" + +seed: 33 +mixed_precision: fp16 +use_8bit_adam: False +gradient_checkpointing: True +enable_xformers_memory_efficient_attention: True \ No 
newline at end of file diff --git a/Video-P2P/configs/tiger-forest-p2p.yaml b/Video-P2P/configs/tiger-forest-p2p.yaml new file mode 100644 index 0000000000000000000000000000000000000000..779d400ea2163029cfd5c29a2a5a9501b1327d08 --- /dev/null +++ b/Video-P2P/configs/tiger-forest-p2p.yaml @@ -0,0 +1,16 @@ +pretrained_model_path: "./outputs/tiger-forest" +image_path: "./data/tiger" +prompt: "a tiger is walking in the forest" +prompts: + - "a tiger is walking in the forest" + - "a Lego tiger is walking in the forest" +blend_word: + - 'tiger' + - 'tiger' +eq_params: + words: + - "Lego" + values: + - 2 +save_name: "lego" +is_word_swap: False \ No newline at end of file diff --git a/Video-P2P/configs/tiger-forest-tune.yaml b/Video-P2P/configs/tiger-forest-tune.yaml new file mode 100644 index 0000000000000000000000000000000000000000..2d2922dde7c174ae10c381323904b5cc7184339d --- /dev/null +++ b/Video-P2P/configs/tiger-forest-tune.yaml @@ -0,0 +1,38 @@ +pretrained_model_path: "/data/stable-diffusion/stable-diffusion-v1-5" +output_dir: "./outputs/tiger-forest" + +train_data: + video_path: "./data/tiger" + prompt: "a tiger is walking in the forest" + n_sample_frames: 8 + width: 512 + height: 512 + sample_start_idx: 0 + sample_frame_rate: 1 + +validation_data: + prompts: + - "a Lego tiger is walking in the forest" + video_length: 8 + width: 512 + height: 512 + num_inference_steps: 50 + guidance_scale: 12.5 + use_inv_latent: True + num_inv_steps: 50 + +learning_rate: 3e-5 +train_batch_size: 1 +max_train_steps: 500 +checkpointing_steps: 1000 +validation_steps: 500 +trainable_modules: + - "attn1.to_q" + - "attn2.to_q" + - "attn_temp" + +seed: 33 +mixed_precision: fp16 +use_8bit_adam: False +gradient_checkpointing: True +enable_xformers_memory_efficient_attention: True diff --git a/Video-P2P/data/.DS_Store b/Video-P2P/data/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..9d73c58a2a6fdf6ffdd4821c2ccf45362871baf5 Binary files /dev/null and b/Video-P2P/data/.DS_Store differ diff --git a/Video-P2P/data/car/.DS_Store b/Video-P2P/data/car/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..5008ddfcf53c02e82d7eee2e57c38e5672ef89f6 Binary files /dev/null and b/Video-P2P/data/car/.DS_Store differ diff --git a/Video-P2P/data/car/1.jpg b/Video-P2P/data/car/1.jpg new file mode 100644 index 0000000000000000000000000000000000000000..a23b7bf38f444569168ba3df93ff13f70a3c50ca Binary files /dev/null and b/Video-P2P/data/car/1.jpg differ diff --git a/Video-P2P/data/car/2.jpg b/Video-P2P/data/car/2.jpg new file mode 100644 index 0000000000000000000000000000000000000000..bad4a4b4465749a013602536f0b1e440133194ea Binary files /dev/null and b/Video-P2P/data/car/2.jpg differ diff --git a/Video-P2P/data/car/3.jpg b/Video-P2P/data/car/3.jpg new file mode 100644 index 0000000000000000000000000000000000000000..5c1fd480dce449d18b21d3773ad6ecfc055f24ac Binary files /dev/null and b/Video-P2P/data/car/3.jpg differ diff --git a/Video-P2P/data/car/4.jpg b/Video-P2P/data/car/4.jpg new file mode 100644 index 0000000000000000000000000000000000000000..1d8a953522da74f2679e71f709172e2894171aec Binary files /dev/null and b/Video-P2P/data/car/4.jpg differ diff --git a/Video-P2P/data/car/5.jpg b/Video-P2P/data/car/5.jpg new file mode 100644 index 0000000000000000000000000000000000000000..b26e80259deef8d62da8f9c18592bed40271a1c3 Binary files /dev/null and b/Video-P2P/data/car/5.jpg differ diff --git a/Video-P2P/data/car/6.jpg b/Video-P2P/data/car/6.jpg new file mode 100644 index 
0000000000000000000000000000000000000000..bb16a747d96ee17ea70655a75c6e8cbb7d4d48ed Binary files /dev/null and b/Video-P2P/data/car/6.jpg differ diff --git a/Video-P2P/data/car/7.jpg b/Video-P2P/data/car/7.jpg new file mode 100644 index 0000000000000000000000000000000000000000..8e33e8d5b0c3fbc1fe2dbc31f346348b87224a9a Binary files /dev/null and b/Video-P2P/data/car/7.jpg differ diff --git a/Video-P2P/data/car/8.jpg b/Video-P2P/data/car/8.jpg new file mode 100644 index 0000000000000000000000000000000000000000..d4e3599129cce6e99d1b5fe166912865dd532678 Binary files /dev/null and b/Video-P2P/data/car/8.jpg differ diff --git a/Video-P2P/data/motorbike/1.jpg b/Video-P2P/data/motorbike/1.jpg new file mode 100644 index 0000000000000000000000000000000000000000..576f9b86785cd48199c2e600187787f825f532d2 Binary files /dev/null and b/Video-P2P/data/motorbike/1.jpg differ diff --git a/Video-P2P/data/motorbike/2.jpg b/Video-P2P/data/motorbike/2.jpg new file mode 100644 index 0000000000000000000000000000000000000000..7373af59cfedf00be49feecdd00dfe9cf265c14e Binary files /dev/null and b/Video-P2P/data/motorbike/2.jpg differ diff --git a/Video-P2P/data/motorbike/3.jpg b/Video-P2P/data/motorbike/3.jpg new file mode 100644 index 0000000000000000000000000000000000000000..dc59c351f75517a12742694a160d5813f73fd7b8 Binary files /dev/null and b/Video-P2P/data/motorbike/3.jpg differ diff --git a/Video-P2P/data/motorbike/4.jpg b/Video-P2P/data/motorbike/4.jpg new file mode 100644 index 0000000000000000000000000000000000000000..e76312317bc0a35c196915210926acf4dbc39fed Binary files /dev/null and b/Video-P2P/data/motorbike/4.jpg differ diff --git a/Video-P2P/data/motorbike/5.jpg b/Video-P2P/data/motorbike/5.jpg new file mode 100644 index 0000000000000000000000000000000000000000..d937146774caf8b796795c1cabbd2ba426068f78 Binary files /dev/null and b/Video-P2P/data/motorbike/5.jpg differ diff --git a/Video-P2P/data/motorbike/6.jpg b/Video-P2P/data/motorbike/6.jpg new file mode 100644 index 0000000000000000000000000000000000000000..d0991d2ae09b5c732a1c25acda7a6278409f7c10 Binary files /dev/null and b/Video-P2P/data/motorbike/6.jpg differ diff --git a/Video-P2P/data/motorbike/7.jpg b/Video-P2P/data/motorbike/7.jpg new file mode 100644 index 0000000000000000000000000000000000000000..e857f9e3f3535b098223ded490ba4b7a05ac5604 Binary files /dev/null and b/Video-P2P/data/motorbike/7.jpg differ diff --git a/Video-P2P/data/motorbike/8.jpg b/Video-P2P/data/motorbike/8.jpg new file mode 100644 index 0000000000000000000000000000000000000000..1d1ec5efc6f66c05d0139a8f6b298490c6083d81 Binary files /dev/null and b/Video-P2P/data/motorbike/8.jpg differ diff --git a/Video-P2P/data/penguin_ice/1.jpg b/Video-P2P/data/penguin_ice/1.jpg new file mode 100644 index 0000000000000000000000000000000000000000..2d0cbc2b0f1644eb07a6fca8ad398835b29826d1 Binary files /dev/null and b/Video-P2P/data/penguin_ice/1.jpg differ diff --git a/Video-P2P/data/penguin_ice/2.jpg b/Video-P2P/data/penguin_ice/2.jpg new file mode 100644 index 0000000000000000000000000000000000000000..013a63c9a88594fa7a31db7e0e137f14d03b611a Binary files /dev/null and b/Video-P2P/data/penguin_ice/2.jpg differ diff --git a/Video-P2P/data/penguin_ice/3.jpg b/Video-P2P/data/penguin_ice/3.jpg new file mode 100644 index 0000000000000000000000000000000000000000..e743293b7f3c9d881e0bd8bed72201a465bd2ab3 Binary files /dev/null and b/Video-P2P/data/penguin_ice/3.jpg differ diff --git a/Video-P2P/data/penguin_ice/4.jpg b/Video-P2P/data/penguin_ice/4.jpg new file mode 100644 index 
0000000000000000000000000000000000000000..6b5354b4cd7f1cf32fc4916a2c1252b04431bbe4 Binary files /dev/null and b/Video-P2P/data/penguin_ice/4.jpg differ diff --git a/Video-P2P/data/penguin_ice/5.jpg b/Video-P2P/data/penguin_ice/5.jpg new file mode 100644 index 0000000000000000000000000000000000000000..4f63a0126c10adb69575d78057e93c683b455907 Binary files /dev/null and b/Video-P2P/data/penguin_ice/5.jpg differ diff --git a/Video-P2P/data/penguin_ice/6.jpg b/Video-P2P/data/penguin_ice/6.jpg new file mode 100644 index 0000000000000000000000000000000000000000..a81a617a8850493c86157f65a06b9f3ad09550e0 Binary files /dev/null and b/Video-P2P/data/penguin_ice/6.jpg differ diff --git a/Video-P2P/data/penguin_ice/7.jpg b/Video-P2P/data/penguin_ice/7.jpg new file mode 100644 index 0000000000000000000000000000000000000000..90f8fa773284d22a3ff1a652aed59a4331fb2bee Binary files /dev/null and b/Video-P2P/data/penguin_ice/7.jpg differ diff --git a/Video-P2P/data/penguin_ice/8.jpg b/Video-P2P/data/penguin_ice/8.jpg new file mode 100644 index 0000000000000000000000000000000000000000..3a79af8157bb18f594dc2b90568cc0d24f8c7e9f Binary files /dev/null and b/Video-P2P/data/penguin_ice/8.jpg differ diff --git a/Video-P2P/data/rabbit/1.jpg b/Video-P2P/data/rabbit/1.jpg new file mode 100644 index 0000000000000000000000000000000000000000..0597d1ba2b38816ac9da2af656c41c5ce0fe77e3 Binary files /dev/null and b/Video-P2P/data/rabbit/1.jpg differ diff --git a/Video-P2P/data/rabbit/2.jpg b/Video-P2P/data/rabbit/2.jpg new file mode 100644 index 0000000000000000000000000000000000000000..27db513bf15cef1e0678e417eb8943fd59f69af0 Binary files /dev/null and b/Video-P2P/data/rabbit/2.jpg differ diff --git a/Video-P2P/data/rabbit/3.jpg b/Video-P2P/data/rabbit/3.jpg new file mode 100644 index 0000000000000000000000000000000000000000..d882e7ff058b9a2df256354fc32d9cecab18ff4e Binary files /dev/null and b/Video-P2P/data/rabbit/3.jpg differ diff --git a/Video-P2P/data/rabbit/4.jpg b/Video-P2P/data/rabbit/4.jpg new file mode 100644 index 0000000000000000000000000000000000000000..61663bbfc03afb34052c6df75dec1e5a18aaf3d6 Binary files /dev/null and b/Video-P2P/data/rabbit/4.jpg differ diff --git a/Video-P2P/data/rabbit/5.jpg b/Video-P2P/data/rabbit/5.jpg new file mode 100644 index 0000000000000000000000000000000000000000..d46d1b77365532531b41fe5ff236e6ec437d560e Binary files /dev/null and b/Video-P2P/data/rabbit/5.jpg differ diff --git a/Video-P2P/data/rabbit/6.jpg b/Video-P2P/data/rabbit/6.jpg new file mode 100644 index 0000000000000000000000000000000000000000..adabd6eafcbe80b621ab0464693ae0e67f27a639 Binary files /dev/null and b/Video-P2P/data/rabbit/6.jpg differ diff --git a/Video-P2P/data/rabbit/7.jpg b/Video-P2P/data/rabbit/7.jpg new file mode 100644 index 0000000000000000000000000000000000000000..08c16fb984a244e815eaa7b7d09c4d11b620a206 Binary files /dev/null and b/Video-P2P/data/rabbit/7.jpg differ diff --git a/Video-P2P/data/rabbit/8.jpg b/Video-P2P/data/rabbit/8.jpg new file mode 100644 index 0000000000000000000000000000000000000000..75f9546a75adc6dbeb40265680683fb5c85e3596 Binary files /dev/null and b/Video-P2P/data/rabbit/8.jpg differ diff --git a/Video-P2P/ptp_utils.py b/Video-P2P/ptp_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..85725e7e7e146f6f4ea478abf9672632d9d68bee --- /dev/null +++ b/Video-P2P/ptp_utils.py @@ -0,0 +1,311 @@ +# From https://github.com/google/prompt-to-prompt/: + +# Copyright 2022 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not 
use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import torch +from PIL import Image, ImageDraw, ImageFont +import cv2 +from typing import Optional, Union, Tuple, List, Callable, Dict +from IPython.display import display +from tqdm.notebook import tqdm + + +def text_under_image(image: np.ndarray, text: str, text_color: Tuple[int, int, int] = (0, 0, 0)): + h, w, c = image.shape + offset = int(h * .2) + img = np.ones((h + offset, w, c), dtype=np.uint8) * 255 + font = cv2.FONT_HERSHEY_SIMPLEX + img[:h] = image + textsize = cv2.getTextSize(text, font, 1, 2)[0] + text_x, text_y = (w - textsize[0]) // 2, h + offset - textsize[1] // 2 + cv2.putText(img, text, (text_x, text_y ), font, 1, text_color, 2) + return img + + +def view_images(images, num_rows=1, offset_ratio=0.02): + if type(images) is list: + num_empty = len(images) % num_rows + elif images.ndim == 4: + num_empty = images.shape[0] % num_rows + else: + images = [images] + num_empty = 0 + + empty_images = np.ones(images[0].shape, dtype=np.uint8) * 255 + images = [image.astype(np.uint8) for image in images] + [empty_images] * num_empty + num_items = len(images) + + h, w, c = images[0].shape + offset = int(h * offset_ratio) + num_cols = num_items // num_rows + image_ = np.ones((h * num_rows + offset * (num_rows - 1), + w * num_cols + offset * (num_cols - 1), 3), dtype=np.uint8) * 255 + for i in range(num_rows): + for j in range(num_cols): + image_[i * (h + offset): i * (h + offset) + h:, j * (w + offset): j * (w + offset) + w] = images[ + i * num_cols + j] + + pil_img = Image.fromarray(image_) + display(pil_img) + + +def diffusion_step(model, controller, latents, context, t, guidance_scale, low_resource=False, simple=False): + if low_resource: + noise_pred_uncond = model.unet(latents, t, encoder_hidden_states=context[0])["sample"] + noise_prediction_text = model.unet(latents, t, encoder_hidden_states=context[1])["sample"] + else: + latents_input = torch.cat([latents] * 2) + noise_pred = model.unet(latents_input, t, encoder_hidden_states=context)["sample"] + noise_pred_uncond, noise_prediction_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_prediction_text - noise_pred_uncond) + if simple: + noise_pred[0] = noise_prediction_text[0] + latents = model.scheduler.step(noise_pred, t, latents)["prev_sample"] + # first latents: torch.Size([1, 4, 4, 64, 64]) + latents = controller.step_callback(latents) + return latents + + +def latent2image(vae, latents): + latents = 1 / 0.18215 * latents + image = vae.decode(latents)['sample'] + image = (image / 2 + 0.5).clamp(0, 1) + image = image.cpu().permute(0, 2, 3, 1).numpy() + image = (image * 255).astype(np.uint8) + return image + + +@torch.no_grad() +def latent2image_video(vae, latents): + latents = 1 / 0.18215 * latents + latents = latents[0].permute(1, 0, 2, 3) + image = vae.decode(latents)['sample'] + image = (image / 2 + 0.5).clamp(0, 1) + image = image.cpu().permute(0, 2, 3, 1).numpy() + image = (image * 255).astype(np.uint8) + return image + + +def init_latent(latent, model, height, width, generator, batch_size): + if latent is 
None: + latent = torch.randn( + (1, model.unet.in_channels, height // 8, width // 8), + generator=generator, + ) + latents = latent.expand(batch_size, model.unet.in_channels, height // 8, width // 8).to(model.device) + return latent, latents + + +@torch.no_grad() +def text2image_ldm( + model, + prompt: List[str], + controller, + num_inference_steps: int = 50, + guidance_scale: Optional[float] = 7., + generator: Optional[torch.Generator] = None, + latent: Optional[torch.FloatTensor] = None, +): + register_attention_control(model, controller) + height = width = 256 + batch_size = len(prompt) + + uncond_input = model.tokenizer([""] * batch_size, padding="max_length", max_length=77, return_tensors="pt") + uncond_embeddings = model.bert(uncond_input.input_ids.to(model.device))[0] + + text_input = model.tokenizer(prompt, padding="max_length", max_length=77, return_tensors="pt") + text_embeddings = model.bert(text_input.input_ids.to(model.device))[0] + latent, latents = init_latent(latent, model, height, width, generator, batch_size) + context = torch.cat([uncond_embeddings, text_embeddings]) + + model.scheduler.set_timesteps(num_inference_steps) + for t in tqdm(model.scheduler.timesteps): + latents = diffusion_step(model, controller, latents, context, t, guidance_scale) + + image = latent2image(model.vqvae, latents) + + return image, latent + + +@torch.no_grad() +def text2image_ldm_stable( + model, + prompt: List[str], + controller, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + generator: Optional[torch.Generator] = None, + latent: Optional[torch.FloatTensor] = None, + low_resource: bool = False, +): + register_attention_control(model, controller) + height = width = 512 + batch_size = len(prompt) + + text_input = model.tokenizer( + prompt, + padding="max_length", + max_length=model.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_embeddings = model.text_encoder(text_input.input_ids.to(model.device))[0] + max_length = text_input.input_ids.shape[-1] + uncond_input = model.tokenizer( + [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt" + ) + uncond_embeddings = model.text_encoder(uncond_input.input_ids.to(model.device))[0] + + context = [uncond_embeddings, text_embeddings] + if not low_resource: + context = torch.cat(context) + + latent, latents = init_latent(latent, model, height, width, generator, batch_size) + + # set timesteps + extra_set_kwargs = {"offset": 1} + model.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs) + for t in tqdm(model.scheduler.timesteps): + latents = diffusion_step(model, controller, latents, context, t, guidance_scale, low_resource) + + image = latent2image(model.vae, latents) + + return image, latent + + +def register_attention_control(model, controller): + def ca_forward(self, place_in_unet): + to_out = self.to_out + if type(to_out) is torch.nn.modules.container.ModuleList: + to_out = self.to_out[0] + else: + to_out = self.to_out + + def forward(x, encoder_hidden_states=None, attention_mask=None): + context = encoder_hidden_states + mask = attention_mask + batch_size, sequence_length, dim = x.shape + h = self.heads + q = self.to_q(x) + is_cross = context is not None + context = context if is_cross else x + k = self.to_k(context) + v = self.to_v(context) + q = self.reshape_heads_to_batch_dim(q) + k = self.reshape_heads_to_batch_dim(k) + v = self.reshape_heads_to_batch_dim(v) + sim = torch.einsum("b i d, b j d -> b i j", q, k) * self.scale # q: torch.Size([128, 4096, 40]); 
k: torch.Size([64, 77, 40]) + + if mask is not None: + mask = mask.reshape(batch_size, -1) + max_neg_value = -torch.finfo(sim.dtype).max + mask = mask[:, None, :].repeat(h, 1, 1) + sim.masked_fill_(~mask, max_neg_value) + + attn = torch.exp(sim-torch.max(sim)) / torch.sum(torch.exp(sim-torch.max(sim)), axis=-1).unsqueeze(-1) + attn = controller(attn, is_cross, place_in_unet) + out = torch.einsum("b i j, b j d -> b i d", attn, v) + out = self.reshape_batch_dim_to_heads(out) + return to_out(out) + + return forward + + class DummyController: + + def __call__(self, *args): + return args[0] + + def __init__(self): + self.num_att_layers = 0 + + if controller is None: + controller = DummyController() + + def register_recr(net_, count, place_in_unet): + if net_.__class__.__name__ == 'CrossAttention': + net_.forward = ca_forward(net_, place_in_unet) + return count + 1 + elif hasattr(net_, 'children'): + for net__ in net_.children(): + count = register_recr(net__, count, place_in_unet) + return count + + cross_att_count = 0 + sub_nets = model.unet.named_children() + for net in sub_nets: + if "down" in net[0]: + cross_att_count += register_recr(net[1], 0, "down") + elif "up" in net[0]: + cross_att_count += register_recr(net[1], 0, "up") + elif "mid" in net[0]: + cross_att_count += register_recr(net[1], 0, "mid") + + controller.num_att_layers = cross_att_count + + +def get_word_inds(text: str, word_place: int, tokenizer): + split_text = text.split(" ") + if type(word_place) is str: + word_place = [i for i, word in enumerate(split_text) if word_place == word] + elif type(word_place) is int: + word_place = [word_place] + out = [] + if len(word_place) > 0: + words_encode = [tokenizer.decode([item]).strip("#") for item in tokenizer.encode(text)][1:-1] + cur_len, ptr = 0, 0 + + for i in range(len(words_encode)): + cur_len += len(words_encode[i]) + if ptr in word_place: + out.append(i + 1) + if cur_len >= len(split_text[ptr]): + ptr += 1 + cur_len = 0 + return np.array(out) + + +def update_alpha_time_word(alpha, bounds: Union[float, Tuple[float, float]], prompt_ind: int, + word_inds: Optional[torch.Tensor]=None): + if type(bounds) is float: + bounds = 0, bounds + start, end = int(bounds[0] * alpha.shape[0]), int(bounds[1] * alpha.shape[0]) + if word_inds is None: + word_inds = torch.arange(alpha.shape[2]) + alpha[: start, prompt_ind, word_inds] = 0 + alpha[start: end, prompt_ind, word_inds] = 1 + alpha[end:, prompt_ind, word_inds] = 0 + return alpha + + +def get_time_words_attention_alpha(prompts, num_steps, + cross_replace_steps: Union[float, Dict[str, Tuple[float, float]]], + tokenizer, max_num_words=77): + if type(cross_replace_steps) is not dict: + cross_replace_steps = {"default_": cross_replace_steps} + if "default_" not in cross_replace_steps: + cross_replace_steps["default_"] = (0., 1.) 
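+    # alpha_time_words[t, p, w]: 1 if the cross-attention for word w of edited prompt p+1 is replaced at denoising step t (i.e. t falls inside the configured step range), else 0.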
+ alpha_time_words = torch.zeros(num_steps + 1, len(prompts) - 1, max_num_words) + for i in range(len(prompts) - 1): # 2 + alpha_time_words = update_alpha_time_word(alpha_time_words, cross_replace_steps["default_"], # {'default_': 0.8} + i) + for key, item in cross_replace_steps.items(): + if key != "default_": + inds = [get_word_inds(prompts[i], key, tokenizer) for i in range(1, len(prompts))] + for i, ind in enumerate(inds): + if len(ind) > 0: + alpha_time_words = update_alpha_time_word(alpha_time_words, item, i, ind) + alpha_time_words = alpha_time_words.reshape(num_steps + 1, len(prompts) - 1, 1, 1, max_num_words) + return alpha_time_words diff --git a/Video-P2P/requirements.txt b/Video-P2P/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..fbaa5d73d3c11116aa7d1703c54111f81571bc63 --- /dev/null +++ b/Video-P2P/requirements.txt @@ -0,0 +1,15 @@ +torch==1.12.1 +torchvision==0.13.1 +diffusers[torch]==0.11.1 +transformers>=4.25.1 +bitsandbytes==0.35.4 +decord==0.6.0 +accelerate +tensorboard +modelcards +omegaconf +einops +imageio +ftfy +opencv-python +ipywidgets \ No newline at end of file diff --git a/Video-P2P/run_tuning.py b/Video-P2P/run_tuning.py new file mode 100644 index 0000000000000000000000000000000000000000..e917ec564c4519920e6e7f6cf6d064e8b4f509c3 --- /dev/null +++ b/Video-P2P/run_tuning.py @@ -0,0 +1,367 @@ +# From https://github.com/showlab/Tune-A-Video/blob/main/train_tuneavideo.py + +import argparse +import datetime +import logging +import inspect +import math +import os +from typing import Dict, Optional, Tuple +from omegaconf import OmegaConf + +import torch +import torch.nn.functional as F +import torch.utils.checkpoint + +import diffusers +import transformers +from accelerate import Accelerator +from accelerate.logging import get_logger +from accelerate.utils import set_seed +from diffusers import AutoencoderKL, DDPMScheduler, DDIMScheduler +from diffusers.optimization import get_scheduler +from diffusers.utils import check_min_version +from diffusers.utils.import_utils import is_xformers_available +from tqdm.auto import tqdm +from transformers import CLIPTextModel, CLIPTokenizer + +from tuneavideo.models.unet import UNet3DConditionModel +from tuneavideo.data.dataset import TuneAVideoDataset +from tuneavideo.pipelines.pipeline_tuneavideo import TuneAVideoPipeline +from tuneavideo.util import save_videos_grid, ddim_inversion +from einops import rearrange + + +# Will error if the minimal version of diffusers is not installed. Remove at your own risks. 
+check_min_version("0.10.0.dev0") + +logger = get_logger(__name__, log_level="INFO") + + +def main( + pretrained_model_path: str, + output_dir: str, + train_data: Dict, + validation_data: Dict, + validation_steps: int = 100, + trainable_modules: Tuple[str] = ( + "attn1.to_q", + "attn2.to_q", + "attn_temp", + ), + train_batch_size: int = 1, + max_train_steps: int = 500, + learning_rate: float = 3e-5, + scale_lr: bool = False, + lr_scheduler: str = "constant", + lr_warmup_steps: int = 0, + adam_beta1: float = 0.9, + adam_beta2: float = 0.999, + adam_weight_decay: float = 1e-2, + adam_epsilon: float = 1e-08, + max_grad_norm: float = 1.0, + gradient_accumulation_steps: int = 1, + gradient_checkpointing: bool = True, + checkpointing_steps: int = 500, + resume_from_checkpoint: Optional[str] = None, + mixed_precision: Optional[str] = "fp16", + use_8bit_adam: bool = False, + enable_xformers_memory_efficient_attention: bool = True, + seed: Optional[int] = None, +): + *_, config = inspect.getargvalues(inspect.currentframe()) + + accelerator = Accelerator( + gradient_accumulation_steps=gradient_accumulation_steps, + mixed_precision=mixed_precision, + ) + + # Make one log on every process with the configuration for debugging. + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, + ) + logger.info(accelerator.state, main_process_only=False) + if accelerator.is_local_main_process: + transformers.utils.logging.set_verbosity_warning() + diffusers.utils.logging.set_verbosity_info() + else: + transformers.utils.logging.set_verbosity_error() + diffusers.utils.logging.set_verbosity_error() + + # If passed along, set the training seed now. + if seed is not None: + set_seed(seed) + + # Handle the output folder creation + if accelerator.is_main_process: + os.makedirs(output_dir, exist_ok=True) + os.makedirs(f"{output_dir}/samples", exist_ok=True) + os.makedirs(f"{output_dir}/inv_latents", exist_ok=True) + OmegaConf.save(config, os.path.join(output_dir, 'config.yaml')) + + # Load scheduler, tokenizer and models. + noise_scheduler = DDPMScheduler.from_pretrained(pretrained_model_path, subfolder="scheduler") + tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_path, subfolder="tokenizer") + text_encoder = CLIPTextModel.from_pretrained(pretrained_model_path, subfolder="text_encoder") + vae = AutoencoderKL.from_pretrained(pretrained_model_path, subfolder="vae") + unet = UNet3DConditionModel.from_pretrained_2d(pretrained_model_path, subfolder="unet") + + # Freeze vae and text_encoder + vae.requires_grad_(False) + text_encoder.requires_grad_(False) + + unet.requires_grad_(False) + for name, module in unet.named_modules(): + if name.endswith(tuple(trainable_modules)): + for params in module.parameters(): + params.requires_grad = True + + if enable_xformers_memory_efficient_attention: + if is_xformers_available(): + unet.enable_xformers_memory_efficient_attention() + else: + raise ValueError("xformers is not available. Make sure it is installed correctly") + + if gradient_checkpointing: + unet.enable_gradient_checkpointing() + + if scale_lr: + learning_rate = ( + learning_rate * gradient_accumulation_steps * train_batch_size * accelerator.num_processes + ) + + # Initialize the optimizer + if use_8bit_adam: + try: + import bitsandbytes as bnb + except ImportError: + raise ImportError( + "Please install bitsandbytes to use 8-bit Adam. 
You can do so by running `pip install bitsandbytes`" + ) + + optimizer_cls = bnb.optim.AdamW8bit + else: + optimizer_cls = torch.optim.AdamW + + optimizer = optimizer_cls( + unet.parameters(), + lr=learning_rate, + betas=(adam_beta1, adam_beta2), + weight_decay=adam_weight_decay, + eps=adam_epsilon, + ) + + # Get the training dataset + train_dataset = TuneAVideoDataset(**train_data) + + # Preprocessing the dataset + train_dataset.prompt_ids = tokenizer( + train_dataset.prompt, max_length=tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt" + ).input_ids[0] + + # DataLoaders creation: + train_dataloader = torch.utils.data.DataLoader( + train_dataset, batch_size=train_batch_size + ) + + # Get the validation pipeline + validation_pipeline = TuneAVideoPipeline( + vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet, + scheduler=DDIMScheduler.from_pretrained(pretrained_model_path, subfolder="scheduler") + ) + validation_pipeline.enable_vae_slicing() + ddim_inv_scheduler = DDIMScheduler.from_pretrained(pretrained_model_path, subfolder='scheduler') + ddim_inv_scheduler.set_timesteps(validation_data.num_inv_steps) + + # Scheduler + lr_scheduler = get_scheduler( + lr_scheduler, + optimizer=optimizer, + num_warmup_steps=lr_warmup_steps * gradient_accumulation_steps, + num_training_steps=max_train_steps * gradient_accumulation_steps, + ) + + # Prepare everything with our `accelerator`. + unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + unet, optimizer, train_dataloader, lr_scheduler + ) + + # For mixed precision training we cast the text_encoder and vae weights to half-precision + # as these models are only used for inference, keeping weights in full precision is not required. + weight_dtype = torch.float32 + if accelerator.mixed_precision == "fp16": + weight_dtype = torch.float16 + elif accelerator.mixed_precision == "bf16": + weight_dtype = torch.bfloat16 + + # Move text_encode and vae to gpu and cast to weight_dtype + text_encoder.to(accelerator.device, dtype=weight_dtype) + vae.to(accelerator.device, dtype=weight_dtype) + + # We need to recalculate our total training steps as the size of the training dataloader may have changed. + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / gradient_accumulation_steps) + # Afterwards we recalculate our number of training epochs + num_train_epochs = math.ceil(max_train_steps / num_update_steps_per_epoch) + + # We need to initialize the trackers we use, and also store our configuration. + # The trackers initializes automatically on the main process. + if accelerator.is_main_process: + accelerator.init_trackers("text2video-fine-tune") + + # Train! + total_batch_size = train_batch_size * accelerator.num_processes * gradient_accumulation_steps + + logger.info("***** Running training *****") + logger.info(f" Num examples = {len(train_dataset)}") + logger.info(f" Num Epochs = {num_train_epochs}") + logger.info(f" Instantaneous batch size per device = {train_batch_size}") + logger.info(f" Total train batch size (w. 
parallel, distributed & accumulation) = {total_batch_size}") + logger.info(f" Gradient Accumulation steps = {gradient_accumulation_steps}") + logger.info(f" Total optimization steps = {max_train_steps}") + global_step = 0 + first_epoch = 0 + + # Potentially load in the weights and states from a previous save + if resume_from_checkpoint: + if resume_from_checkpoint != "latest": + path = os.path.basename(resume_from_checkpoint) + else: + # Get the most recent checkpoint + dirs = os.listdir(output_dir) + dirs = [d for d in dirs if d.startswith("checkpoint")] + dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) + path = dirs[-1] + accelerator.print(f"Resuming from checkpoint {path}") + accelerator.load_state(os.path.join(output_dir, path)) + global_step = int(path.split("-")[1]) + + first_epoch = global_step // num_update_steps_per_epoch + resume_step = global_step % num_update_steps_per_epoch + + # Only show the progress bar once on each machine. + progress_bar = tqdm(range(global_step, max_train_steps), disable=not accelerator.is_local_main_process) + progress_bar.set_description("Steps") + + for epoch in range(first_epoch, num_train_epochs): + unet.train() + train_loss = 0.0 + for step, batch in enumerate(train_dataloader): + # Skip steps until we reach the resumed step + if resume_from_checkpoint and epoch == first_epoch and step < resume_step: + if step % gradient_accumulation_steps == 0: + progress_bar.update(1) + continue + + with accelerator.accumulate(unet): + # Convert videos to latent space + pixel_values = batch["pixel_values"].to(weight_dtype) + video_length = pixel_values.shape[1] + pixel_values = rearrange(pixel_values, "b f c h w -> (b f) c h w") + latents = vae.encode(pixel_values).latent_dist.sample() + latents = rearrange(latents, "(b f) c h w -> b c f h w", f=video_length) + latents = latents * 0.18215 + + # Sample noise that we'll add to the latents + noise = torch.randn_like(latents) + bsz = latents.shape[0] + # Sample a random timestep for each video + timesteps = torch.randint(0, noise_scheduler.num_train_timesteps, (bsz,), device=latents.device) + timesteps = timesteps.long() + + # Add noise to the latents according to the noise magnitude at each timestep + # (this is the forward diffusion process) + noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) + + # Get the text embedding for conditioning + encoder_hidden_states = text_encoder(batch["prompt_ids"])[0] + + # Get the target for loss depending on the prediction type + if noise_scheduler.prediction_type == "epsilon": + target = noise + elif noise_scheduler.prediction_type == "v_prediction": + target = noise_scheduler.get_velocity(latents, noise, timesteps) + else: + raise ValueError(f"Unknown prediction type {noise_scheduler.prediction_type}") + + # Predict the noise residual and compute loss + model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample + loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") + + # Gather the losses across all processes for logging (if we use distributed training). 
+ avg_loss = accelerator.gather(loss.repeat(train_batch_size)).mean() + train_loss += avg_loss.item() / gradient_accumulation_steps + + # Backpropagate + accelerator.backward(loss) + if accelerator.sync_gradients: + accelerator.clip_grad_norm_(unet.parameters(), max_grad_norm) + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad() + + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + progress_bar.update(1) + global_step += 1 + accelerator.log({"train_loss": train_loss}, step=global_step) + train_loss = 0.0 + + if global_step % checkpointing_steps == 0: + if accelerator.is_main_process: + save_path = os.path.join(output_dir, f"checkpoint-{global_step}") + accelerator.save_state(save_path) + logger.info(f"Saved state to {save_path}") + + if global_step % validation_steps == 0: + if accelerator.is_main_process: + samples = [] + generator = torch.Generator(device=latents.device) + generator.manual_seed(seed) + + ddim_inv_latent = None + if validation_data.use_inv_latent: + inv_latents_path = os.path.join(output_dir, f"inv_latents/ddim_latent-{global_step}.pt") + ddim_inv_latent = ddim_inversion( + validation_pipeline, ddim_inv_scheduler, video_latent=latents, + num_inv_steps=validation_data.num_inv_steps, prompt="")[-1].to(weight_dtype) + torch.save(ddim_inv_latent, inv_latents_path) + + for idx, prompt in enumerate(validation_data.prompts): + sample = validation_pipeline(prompt, generator=generator, latents=ddim_inv_latent, + **validation_data).videos + save_videos_grid(sample, f"{output_dir}/samples/sample-{global_step}/{prompt}.gif") + samples.append(sample) + samples = torch.concat(samples) + save_path = f"{output_dir}/samples/sample-{global_step}.gif" + save_videos_grid(samples, save_path) + logger.info(f"Saved samples to {save_path}") + + logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} + progress_bar.set_postfix(**logs) + + if global_step >= max_train_steps: + break + + # Create the pipeline using the trained modules and save it. 
+ accelerator.wait_for_everyone() + if accelerator.is_main_process: + unet = accelerator.unwrap_model(unet) + pipeline = TuneAVideoPipeline.from_pretrained( + pretrained_model_path, + text_encoder=text_encoder, + vae=vae, + unet=unet, + ) + pipeline.save_pretrained(output_dir) + + accelerator.end_training() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--config", type=str, default="./configs/tuneavideo.yaml") + args = parser.parse_args() + + main(**OmegaConf.load(args.config)) diff --git a/Video-P2P/run_videop2p.py b/Video-P2P/run_videop2p.py new file mode 100644 index 0000000000000000000000000000000000000000..a108e039a0b2c8dcf0e2a1e02c8e8f870f749276 --- /dev/null +++ b/Video-P2P/run_videop2p.py @@ -0,0 +1,664 @@ +# Adapted from https://github.com/google/prompt-to-prompt/blob/main/null_text_w_ptp.ipynb + +import os +from typing import Optional, Union, Tuple, List, Callable, Dict +from tqdm.notebook import tqdm +import torch +from diffusers import StableDiffusionPipeline, DDIMScheduler, AutoencoderKL +import torch.nn.functional as nnf +import numpy as np +import abc +import ptp_utils +import seq_aligner +import shutil +from torch.optim.adam import Adam +from PIL import Image +from transformers import AutoTokenizer, CLIPTextModel, CLIPTokenizer +from einops import rearrange + +from tuneavideo.models.unet import UNet3DConditionModel +from tuneavideo.pipelines.pipeline_tuneavideo import TuneAVideoPipeline + +import cv2 +import argparse +from omegaconf import OmegaConf + +scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False, set_alpha_to_one=False) +MY_TOKEN = '' +LOW_RESOURCE = False +NUM_DDIM_STEPS = 50 +GUIDANCE_SCALE = 7.5 +MAX_NUM_WORDS = 77 +device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu') + +# need to adjust sometimes +mask_th = (.3, .3) + +def main( + pretrained_model_path: str, + image_path: str, + prompt: str, + prompts: Tuple[str], + eq_params: Dict, + save_name: str, + is_word_swap: bool, + blend_word: Tuple[str] = None, + cross_replace_steps: float = 0.2, + self_replace_steps: float = 0.5, + video_len: int = 8, + fast: bool = False, + mixed_precision: str = 'fp32', +): + output_folder = os.path.join(pretrained_model_path, 'results') + if fast: + save_name_1 = os.path.join(output_folder, 'inversion_fast.gif') + save_name_2 = os.path.join(output_folder, '{}_fast.gif'.format(save_name)) + else: + save_name_1 = os.path.join(output_folder, 'inversion.gif') + save_name_2 = os.path.join(output_folder, '{}.gif'.format(save_name)) + if blend_word: + blend_word = (((blend_word[0],), (blend_word[1],))) + eq_params = dict(eq_params) + prompts = list(prompts) + cross_replace_steps = {'default_': cross_replace_steps,} + + weight_dtype = torch.float32 + if mixed_precision == "fp16": + weight_dtype = torch.float16 + elif mixed_precision == "bf16": + weight_dtype = torch.bfloat16 + + if not os.path.exists(output_folder): + os.makedirs(output_folder) + + # Load the tokenizer + tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_path, subfolder="tokenizer") + # Load models and create wrapper for stable diffusion + text_encoder = CLIPTextModel.from_pretrained( + pretrained_model_path, + subfolder="text_encoder", + ).to(device, dtype=weight_dtype) + vae = AutoencoderKL.from_pretrained( + pretrained_model_path, + subfolder="vae", + ).to(device, dtype=weight_dtype) + unet = UNet3DConditionModel.from_pretrained( + pretrained_model_path, subfolder="unet" + 
).to(device) + ldm_stable = TuneAVideoPipeline( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + ).to(device) + + try: + ldm_stable.disable_xformers_memory_efficient_attention() + except AttributeError: + print("Attribute disable_xformers_memory_efficient_attention() is missing") + tokenizer = ldm_stable.tokenizer # Tokenizer of class: [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer) + # A tokenizer breaks a stream of text into tokens, usually by looking for whitespace (tabs, spaces, new lines). + + class LocalBlend: + + def get_mask(self, maps, alpha, use_pool): + k = 1 + maps = (maps * alpha).sum(-1).mean(2) + if use_pool: + maps = nnf.max_pool2d(maps, (k * 2 + 1, k * 2 +1), (1, 1), padding=(k, k)) + mask = nnf.interpolate(maps, size=(x_t.shape[3:])) + mask = mask / mask.max(2, keepdims=True)[0].max(3, keepdims=True)[0] + mask = mask.gt(self.th[1-int(use_pool)]) + mask = mask[:1] + mask + return mask + + def __call__(self, x_t, attention_store, step): + self.counter += 1 + if self.counter > self.start_blend: + maps = attention_store["down_cross"][2:4] + attention_store["up_cross"][:3] + maps = [item.reshape(self.alpha_layers.shape[0], -1, 8, 16, 16, MAX_NUM_WORDS) for item in maps] + maps = torch.cat(maps, dim=2) + mask = self.get_mask(maps, self.alpha_layers, True) + if self.substruct_layers is not None: + maps_sub = ~self.get_mask(maps, self.substruct_layers, False) + mask = mask * maps_sub + mask = mask.float() + mask = mask.reshape(-1, 1, mask.shape[-3], mask.shape[-2], mask.shape[-1]) + x_t = x_t[:1] + mask * (x_t - x_t[:1]) + return x_t + + def __init__(self, prompts: List[str], words: [List[List[str]]], substruct_words=None, start_blend=0.2, th=(.3, .3)): + alpha_layers = torch.zeros(len(prompts), 1, 1, 1, 1, MAX_NUM_WORDS) + for i, (prompt, words_) in enumerate(zip(prompts, words)): + if type(words_) is str: + words_ = [words_] + for word in words_: + ind = ptp_utils.get_word_inds(prompt, word, tokenizer) + alpha_layers[i, :, :, :, :, ind] = 1 + + if substruct_words is not None: + substruct_layers = torch.zeros(len(prompts), 1, 1, 1, 1, MAX_NUM_WORDS) + for i, (prompt, words_) in enumerate(zip(prompts, substruct_words)): + if type(words_) is str: + words_ = [words_] + for word in words_: + ind = ptp_utils.get_word_inds(prompt, word, tokenizer) + substruct_layers[i, :, :, :, :, ind] = 1 + self.substruct_layers = substruct_layers.to(device) + else: + self.substruct_layers = None + self.alpha_layers = alpha_layers.to(device) + self.start_blend = int(start_blend * NUM_DDIM_STEPS) + self.counter = 0 + self.th=th + + + class EmptyControl: + + + def step_callback(self, x_t): + return x_t + + def between_steps(self): + return + + def __call__(self, attn, is_cross: bool, place_in_unet: str): + return attn + + + class AttentionControl(abc.ABC): + + def step_callback(self, x_t): + return x_t + + def between_steps(self): + return + + @property + def num_uncond_att_layers(self): + return self.num_att_layers if LOW_RESOURCE else 0 + + @abc.abstractmethod + def forward (self, attn, is_cross: bool, place_in_unet: str): + raise NotImplementedError + + def __call__(self, attn, is_cross: bool, place_in_unet: str): + if self.cur_att_layer >= self.num_uncond_att_layers: + if LOW_RESOURCE: + attn = self.forward(attn, is_cross, place_in_unet) + else: + h = attn.shape[0] + attn[h // 2:] = self.forward(attn[h // 2:], is_cross, place_in_unet) + self.cur_att_layer += 1 + if self.cur_att_layer 
== self.num_att_layers + self.num_uncond_att_layers: + self.cur_att_layer = 0 + self.cur_step += 1 + self.between_steps() + return attn + + def reset(self): + self.cur_step = 0 + self.cur_att_layer = 0 + + def __init__(self): + self.cur_step = 0 + self.num_att_layers = -1 + self.cur_att_layer = 0 + + class SpatialReplace(EmptyControl): + + def step_callback(self, x_t): + if self.cur_step < self.stop_inject: + b = x_t.shape[0] + x_t = x_t[:1].expand(b, *x_t.shape[1:]) + return x_t + + def __init__(self, stop_inject: float): + super(SpatialReplace, self).__init__() + self.stop_inject = int((1 - stop_inject) * NUM_DDIM_STEPS) + + + class AttentionStore(AttentionControl): + + @staticmethod + def get_empty_store(): + return {"down_cross": [], "mid_cross": [], "up_cross": [], + "down_self": [], "mid_self": [], "up_self": []} + + def forward(self, attn, is_cross: bool, place_in_unet: str): + key = f"{place_in_unet}_{'cross' if is_cross else 'self'}" + if attn.shape[1] <= 32 ** 2: + self.step_store[key].append(attn) + return attn + + def between_steps(self): + if len(self.attention_store) == 0: + self.attention_store = self.step_store + else: + for key in self.attention_store: + for i in range(len(self.attention_store[key])): + self.attention_store[key][i] += self.step_store[key][i] + self.step_store = self.get_empty_store() + + def get_average_attention(self): + average_attention = {key: [item / self.cur_step for item in self.attention_store[key]] for key in self.attention_store} + return average_attention + + + def reset(self): + super(AttentionStore, self).reset() + self.step_store = self.get_empty_store() + self.attention_store = {} + + def __init__(self): + super(AttentionStore, self).__init__() + self.step_store = self.get_empty_store() + self.attention_store = {} + + + class AttentionControlEdit(AttentionStore, abc.ABC): + + def step_callback(self, x_t): + if self.local_blend is not None: + x_t = self.local_blend(x_t, self.attention_store, self.cur_step) + return x_t + + def replace_self_attention(self, attn_base, att_replace, place_in_unet): + if att_replace.shape[2] <= 32 ** 2: + attn_base = attn_base.unsqueeze(0).expand(att_replace.shape[0], *attn_base.shape) + return attn_base + else: + return att_replace + + @abc.abstractmethod + def replace_cross_attention(self, attn_base, att_replace): + raise NotImplementedError + + def forward(self, attn, is_cross: bool, place_in_unet: str): + super(AttentionControlEdit, self).forward(attn, is_cross, place_in_unet) + if is_cross or (self.num_self_replace[0] <= self.cur_step < self.num_self_replace[1]): + h = attn.shape[0] // (self.batch_size) + attn = attn.reshape(self.batch_size, h, *attn.shape[1:]) + attn_base, attn_repalce = attn[0], attn[1:] + if is_cross: + alpha_words = self.cross_replace_alpha[self.cur_step] + attn_repalce_new = self.replace_cross_attention(attn_base, attn_repalce) * alpha_words + (1 - alpha_words) * attn_repalce + attn[1:] = attn_repalce_new + else: + attn[1:] = self.replace_self_attention(attn_base, attn_repalce, place_in_unet) + attn = attn.reshape(self.batch_size * h, *attn.shape[2:]) + return attn + + def __init__(self, prompts, num_steps: int, + cross_replace_steps: Union[float, Tuple[float, float], Dict[str, Tuple[float, float]]], + self_replace_steps: Union[float, Tuple[float, float]], + local_blend: Optional[LocalBlend]): + super(AttentionControlEdit, self).__init__() + self.batch_size = len(prompts) + self.cross_replace_alpha = ptp_utils.get_time_words_attention_alpha(prompts, num_steps, cross_replace_steps, 
tokenizer).to(device) + if type(self_replace_steps) is float: + self_replace_steps = 0, self_replace_steps + self.num_self_replace = int(num_steps * self_replace_steps[0]), int(num_steps * self_replace_steps[1]) + self.local_blend = local_blend + + class AttentionReplace(AttentionControlEdit): + + def replace_cross_attention(self, attn_base, att_replace): + return torch.einsum('hpw,bwn->bhpn', attn_base, self.mapper) + + def __init__(self, prompts, num_steps: int, cross_replace_steps: float, self_replace_steps: float, + local_blend: Optional[LocalBlend] = None): + super(AttentionReplace, self).__init__(prompts, num_steps, cross_replace_steps, self_replace_steps, local_blend) + self.mapper = seq_aligner.get_replacement_mapper(prompts, tokenizer).to(device) + + + class AttentionRefine(AttentionControlEdit): + + def replace_cross_attention(self, attn_base, att_replace): + attn_base_replace = attn_base[:, :, self.mapper].permute(2, 0, 1, 3) + attn_replace = attn_base_replace * self.alphas + att_replace * (1 - self.alphas) + return attn_replace + + def __init__(self, prompts, num_steps: int, cross_replace_steps: float, self_replace_steps: float, + local_blend: Optional[LocalBlend] = None): + super(AttentionRefine, self).__init__(prompts, num_steps, cross_replace_steps, self_replace_steps, local_blend) + self.mapper, alphas = seq_aligner.get_refinement_mapper(prompts, tokenizer) + self.mapper, alphas = self.mapper.to(device), alphas.to(device) + self.alphas = alphas.reshape(alphas.shape[0], 1, 1, alphas.shape[1]) + + + class AttentionReweight(AttentionControlEdit): + + def replace_cross_attention(self, attn_base, att_replace): + if self.prev_controller is not None: + attn_base = self.prev_controller.replace_cross_attention(attn_base, att_replace) + attn_replace = attn_base[None, :, :, :] * self.equalizer[:, None, None, :] + return attn_replace + + def __init__(self, prompts, num_steps: int, cross_replace_steps: float, self_replace_steps: float, equalizer, + local_blend: Optional[LocalBlend] = None, controller: Optional[AttentionControlEdit] = None): + super(AttentionReweight, self).__init__(prompts, num_steps, cross_replace_steps, self_replace_steps, local_blend) + self.equalizer = equalizer.to(device) + self.prev_controller = controller + + + def get_equalizer(text: str, word_select: Union[int, Tuple[int, ...]], values: Union[List[float], + Tuple[float, ...]]): + if type(word_select) is int or type(word_select) is str: + word_select = (word_select,) + equalizer = torch.ones(1, 77) + + for word, val in zip(word_select, values): + inds = ptp_utils.get_word_inds(text, word, tokenizer) + equalizer[:, inds] = val + return equalizer + + def aggregate_attention(attention_store: AttentionStore, res: int, from_where: List[str], is_cross: bool, select: int): + out = [] + attention_maps = attention_store.get_average_attention() + num_pixels = res ** 2 + for location in from_where: + for item in attention_maps[f"{location}_{'cross' if is_cross else 'self'}"]: + if item.shape[1] == num_pixels: + cross_maps = item.reshape(8, 8, res, res, item.shape[-1]) + out.append(cross_maps) + out = torch.cat(out, dim=1) + out = out.sum(1) / out.shape[1] + return out.cpu() + + + def make_controller(prompts: List[str], is_replace_controller: bool, cross_replace_steps: Dict[str, float], self_replace_steps: float, blend_words=None, equilizer_params=None, mask_th=(.3,.3)) -> AttentionControlEdit: + if blend_words is None: + lb = None + else: + lb = LocalBlend(prompts, blend_word, th=mask_th) + if is_replace_controller: + 
controller = AttentionReplace(prompts, NUM_DDIM_STEPS, cross_replace_steps=cross_replace_steps, self_replace_steps=self_replace_steps, local_blend=lb) + else: + controller = AttentionRefine(prompts, NUM_DDIM_STEPS, cross_replace_steps=cross_replace_steps, self_replace_steps=self_replace_steps, local_blend=lb) + if equilizer_params is not None: + eq = get_equalizer(prompts[1], equilizer_params["words"], equilizer_params["values"]) + controller = AttentionReweight(prompts, NUM_DDIM_STEPS, cross_replace_steps=cross_replace_steps, + self_replace_steps=self_replace_steps, equalizer=eq, local_blend=lb, controller=controller) + return controller + + + def load_512_seq(image_path, left=0, right=0, top=0, bottom=0, n_sample_frame=video_len, sampling_rate=1): + images = [] + for file in sorted(os.listdir(image_path)): + images.append(file) + n_images = len(images) + sequence_length = (n_sample_frame - 1) * sampling_rate + 1 + if n_images < sequence_length: + raise ValueError + frames = [] + for index in range(n_sample_frame): + p = os.path.join(image_path, images[index]) + image = np.array(Image.open(p).convert("RGB")) + h, w, c = image.shape + left = min(left, w-1) + right = min(right, w - left - 1) + top = min(top, h - left - 1) + bottom = min(bottom, h - top - 1) + image = image[top:h-bottom, left:w-right] + h, w, c = image.shape + if h < w: + offset = (w - h) // 2 + image = image[:, offset:offset + h] + elif w < h: + offset = (h - w) // 2 + image = image[offset:offset + w] + image = np.array(Image.fromarray(image).resize((512, 512))) + frames.append(image) + return np.stack(frames) + + + class NullInversion: + + def prev_step(self, model_output: Union[torch.FloatTensor, np.ndarray], timestep: int, sample: Union[torch.FloatTensor, np.ndarray]): + prev_timestep = timestep - self.scheduler.config.num_train_timesteps // self.scheduler.num_inference_steps + alpha_prod_t = self.scheduler.alphas_cumprod[timestep] + alpha_prod_t_prev = self.scheduler.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.scheduler.final_alpha_cumprod + beta_prod_t = 1 - alpha_prod_t + pred_original_sample = (sample - beta_prod_t ** 0.5 * model_output) / alpha_prod_t ** 0.5 + pred_sample_direction = (1 - alpha_prod_t_prev) ** 0.5 * model_output + prev_sample = alpha_prod_t_prev ** 0.5 * pred_original_sample + pred_sample_direction + return prev_sample + + def next_step(self, model_output: Union[torch.FloatTensor, np.ndarray], timestep: int, sample: Union[torch.FloatTensor, np.ndarray]): + timestep, next_timestep = min(timestep - self.scheduler.config.num_train_timesteps // self.scheduler.num_inference_steps, 999), timestep + alpha_prod_t = self.scheduler.alphas_cumprod[timestep] if timestep >= 0 else self.scheduler.final_alpha_cumprod + alpha_prod_t_next = self.scheduler.alphas_cumprod[next_timestep] + beta_prod_t = 1 - alpha_prod_t + next_original_sample = (sample - beta_prod_t ** 0.5 * model_output) / alpha_prod_t ** 0.5 + next_sample_direction = (1 - alpha_prod_t_next) ** 0.5 * model_output + next_sample = alpha_prod_t_next ** 0.5 * next_original_sample + next_sample_direction + return next_sample + + def get_noise_pred_single(self, latents, t, context): + noise_pred = self.model.unet(latents, t, encoder_hidden_states=context)["sample"] + return noise_pred + + def get_noise_pred(self, latents, t, is_forward=True, context=None): + latents_input = torch.cat([latents] * 2) + if context is None: + context = self.context + guidance_scale = 1 if is_forward else GUIDANCE_SCALE + noise_pred = 
self.model.unet(latents_input, t, encoder_hidden_states=context)["sample"] + noise_pred_uncond, noise_prediction_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_prediction_text - noise_pred_uncond) + if is_forward: + latents = self.next_step(noise_pred, t, latents) + else: + latents = self.prev_step(noise_pred, t, latents) + return latents + + @torch.no_grad() + def latent2image(self, latents, return_type='np'): + latents = 1 / 0.18215 * latents.detach() + image = self.model.vae.decode(latents)['sample'] + if return_type == 'np': + image = (image / 2 + 0.5).clamp(0, 1) + image = image.cpu().permute(0, 2, 3, 1).numpy()[0] + image = (image * 255).astype(np.uint8) + return image + + @torch.no_grad() + def latent2image_video(self, latents, return_type='np'): + latents = 1 / 0.18215 * latents.detach() + latents = latents[0].permute(1, 0, 2, 3) + image = self.model.vae.decode(latents)['sample'] + if return_type == 'np': + image = (image / 2 + 0.5).clamp(0, 1) + image = image.cpu().permute(0, 2, 3, 1).numpy() + image = (image * 255).astype(np.uint8) + return image + + @torch.no_grad() + def image2latent(self, image): + with torch.no_grad(): + if type(image) is Image: + image = np.array(image) + if type(image) is torch.Tensor and image.dim() == 4: + latents = image + else: + image = torch.from_numpy(image).float() / 127.5 - 1 + image = image.permute(2, 0, 1).unsqueeze(0).to(device, dtype=weight_dtype) + latents = self.model.vae.encode(image)['latent_dist'].mean + latents = latents * 0.18215 + return latents + + @torch.no_grad() + def image2latent_video(self, image): + with torch.no_grad(): + image = torch.from_numpy(image).float() / 127.5 - 1 + image = image.permute(0, 3, 1, 2).to(device).to(device, dtype=weight_dtype) + latents = self.model.vae.encode(image)['latent_dist'].mean + latents = rearrange(latents, "(b f) c h w -> b c f h w", b=1) + latents = latents * 0.18215 + return latents + + @torch.no_grad() + def init_prompt(self, prompt: str): + uncond_input = self.model.tokenizer( + [""], padding="max_length", max_length=self.model.tokenizer.model_max_length, + return_tensors="pt" + ) + uncond_embeddings = self.model.text_encoder(uncond_input.input_ids.to(self.model.device))[0] + text_input = self.model.tokenizer( + [prompt], + padding="max_length", + max_length=self.model.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_embeddings = self.model.text_encoder(text_input.input_ids.to(self.model.device))[0] + self.context = torch.cat([uncond_embeddings, text_embeddings]) + self.prompt = prompt + + @torch.no_grad() + def ddim_loop(self, latent): + uncond_embeddings, cond_embeddings = self.context.chunk(2) + all_latent = [latent] + latent = latent.clone().detach() + for i in range(NUM_DDIM_STEPS): + t = self.model.scheduler.timesteps[len(self.model.scheduler.timesteps) - i - 1] + noise_pred = self.get_noise_pred_single(latent, t, cond_embeddings) + latent = self.next_step(noise_pred, t, latent) + all_latent.append(latent) + return all_latent + + @property + def scheduler(self): + return self.model.scheduler + + @torch.no_grad() + def ddim_inversion(self, image): + latent = self.image2latent_video(image) + image_rec = self.latent2image_video(latent) + ddim_latents = self.ddim_loop(latent) + return image_rec, ddim_latents + + def null_optimization(self, latents, num_inner_steps, epsilon): + uncond_embeddings, cond_embeddings = self.context.chunk(2) + uncond_embeddings_list = [] + latent_cur = latents[-1] + bar = 
tqdm(total=num_inner_steps * NUM_DDIM_STEPS) + for i in range(NUM_DDIM_STEPS): + uncond_embeddings = uncond_embeddings.clone().detach() + uncond_embeddings.requires_grad = True + optimizer = Adam([uncond_embeddings], lr=1e-2 * (1. - i / 100.)) + latent_prev = latents[len(latents) - i - 2] + t = self.model.scheduler.timesteps[i] + with torch.no_grad(): + noise_pred_cond = self.get_noise_pred_single(latent_cur, t, cond_embeddings) + for j in range(num_inner_steps): + noise_pred_uncond = self.get_noise_pred_single(latent_cur, t, uncond_embeddings) + noise_pred = noise_pred_uncond + GUIDANCE_SCALE * (noise_pred_cond - noise_pred_uncond) + latents_prev_rec = self.prev_step(noise_pred, t, latent_cur) + loss = nnf.mse_loss(latents_prev_rec, latent_prev) + optimizer.zero_grad() + loss.backward() + optimizer.step() + loss_item = loss.item() + bar.update() + if loss_item < epsilon + i * 2e-5: + break + for j in range(j + 1, num_inner_steps): + bar.update() + uncond_embeddings_list.append(uncond_embeddings[:1].detach()) + with torch.no_grad(): + context = torch.cat([uncond_embeddings, cond_embeddings]) + latent_cur = self.get_noise_pred(latent_cur, t, False, context) + bar.close() + return uncond_embeddings_list + + def invert(self, image_path: str, prompt: str, offsets=(0,0,0,0), num_inner_steps=10, early_stop_epsilon=1e-5, verbose=False): + self.init_prompt(prompt) + ptp_utils.register_attention_control(self.model, None) + image_gt = load_512_seq(image_path, *offsets) + if verbose: + print("DDIM inversion...") + image_rec, ddim_latents = self.ddim_inversion(image_gt) + if verbose: + print("Null-text optimization...") + uncond_embeddings = self.null_optimization(ddim_latents, num_inner_steps, early_stop_epsilon) + return (image_gt, image_rec), ddim_latents[-1], uncond_embeddings + + def invert_(self, image_path: str, prompt: str, offsets=(0,0,0,0), num_inner_steps=10, early_stop_epsilon=1e-5, verbose=False): + self.init_prompt(prompt) + ptp_utils.register_attention_control(self.model, None) + image_gt = load_512_seq(image_path, *offsets) + if verbose: + print("DDIM inversion...") + image_rec, ddim_latents = self.ddim_inversion(image_gt) + if verbose: + print("Null-text optimization...") + return (image_gt, image_rec), ddim_latents[-1], None + + def __init__(self, model): + scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False, + set_alpha_to_one=False) + self.model = model + self.tokenizer = self.model.tokenizer + self.model.scheduler.set_timesteps(NUM_DDIM_STEPS) + self.prompt = None + self.context = None + + null_inversion = NullInversion(ldm_stable) + + ############### + # Custom APIs: + + ldm_stable.enable_xformers_memory_efficient_attention() + + if fast: + (image_gt, image_enc), x_t, uncond_embeddings = null_inversion.invert_(image_path, prompt, offsets=(0,0,0,0), verbose=True) + else: + (image_gt, image_enc), x_t, uncond_embeddings = null_inversion.invert(image_path, prompt, offsets=(0,0,0,0), verbose=True) + + ##### load uncond ##### + # uncond_embeddings_load = np.load(uncond_embeddings_path) + # uncond_embeddings = [] + # for i in range(uncond_embeddings_load.shape[0]): + # uncond_embeddings.append(torch.from_numpy(uncond_embeddings_load[i]).to(device)) + ####################### + + ##### save uncond ##### + # uncond_embeddings = torch.cat(uncond_embeddings) + # uncond_embeddings = uncond_embeddings.cpu().numpy() + ####################### + + print("Start Video-P2P!") + controller = make_controller(prompts, is_word_swap, 
cross_replace_steps, self_replace_steps, blend_word, eq_params, mask_th=mask_th) + ptp_utils.register_attention_control(ldm_stable, controller) + generator = torch.Generator(device=device) + with torch.no_grad(): + sequence = ldm_stable( + prompts, + generator=generator, + latents=x_t, + uncond_embeddings_pre=uncond_embeddings, + controller = controller, + video_length=video_len, + fast=fast, + ).videos + sequence1 = rearrange(sequence[0], "c t h w -> t h w c") + sequence2 = rearrange(sequence[1], "c t h w -> t h w c") + inversion = [] + videop2p = [] + for i in range(sequence1.shape[0]): + inversion.append( Image.fromarray((sequence1[i] * 255).numpy().astype(np.uint8)) ) + videop2p.append( Image.fromarray((sequence2[i] * 255).numpy().astype(np.uint8)) ) + + inversion[0].save(save_name_1, save_all=True, append_images=inversion[1:], optimize=False, loop=0, duration=250) + videop2p[0].save(save_name_2, save_all=True, append_images=videop2p[1:], optimize=False, loop=0, duration=250) + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--config", type=str, default="./configs/videop2p.yaml") + parser.add_argument("--fast", action='store_true') + args = parser.parse_args() + + main(**OmegaConf.load(args.config), fast=args.fast) diff --git a/Video-P2P/script.sh b/Video-P2P/script.sh new file mode 100644 index 0000000000000000000000000000000000000000..b6c298ed8dd4cd84b732b38825ed2f077f1f1c29 --- /dev/null +++ b/Video-P2P/script.sh @@ -0,0 +1,23 @@ +# python run_tuning.py --config="configs/rabbit-jump-tune.yaml" + +# python run_videop2p.py --config="configs/rabbit-jump-p2p.yaml" --fast + +# python run_tuning.py --config="configs/man-motor-tune.yaml" + +# python run_videop2p.py --config="configs/man-motor-p2p.yaml" + +# python run_tuning.py --config="configs/penguin-run-tune.yaml" + +# python run_videop2p.py --config="configs/penguin-run-p2p.yaml" + +# python run_tuning.py --config="configs/tiger-forest-tune.yaml" + +# python run_videop2p.py --config="configs/tiger-forest-p2p.yaml" --fast + +# python run_tuning.py --config="configs/car-drive-tune.yaml" + +python run_videop2p.py --config="configs/car-drive-p2p.yaml" --fast + +python run_tuning.py --config="configs/bird-forest-tune.yaml" + +python run_videop2p.py --config="configs/bird-forest-p2p.yaml" --fast \ No newline at end of file diff --git a/Video-P2P/seq_aligner.py b/Video-P2P/seq_aligner.py new file mode 100644 index 0000000000000000000000000000000000000000..0f93bd3189607b38bf2f905498b0e0944f36ee5e --- /dev/null +++ b/Video-P2P/seq_aligner.py @@ -0,0 +1,198 @@ +# From https://github.com/google/prompt-to-prompt/: + +# Copyright 2022 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
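+# Token-alignment helpers: a Needleman-Wunsch style global alignment between the source and edited
+# prompts yields the per-token mappers/alphas consumed by AttentionReplace / AttentionRefine above.
+# Example (sketch, hypothetical prompts of equal word count):
+#   mapper = get_replacement_mapper(["a cat sitting on the grass",
+#                                    "a dog sitting on the grass"], tokenizer)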
+import torch +import numpy as np + + +class ScoreParams: + + def __init__(self, gap, match, mismatch): + self.gap = gap + self.match = match + self.mismatch = mismatch + + def mis_match_char(self, x, y): + if x != y: + return self.mismatch + else: + return self.match + + +def get_matrix(size_x, size_y, gap): + matrix = [] + for i in range(len(size_x) + 1): + sub_matrix = [] + for j in range(len(size_y) + 1): + sub_matrix.append(0) + matrix.append(sub_matrix) + for j in range(1, len(size_y) + 1): + matrix[0][j] = j*gap + for i in range(1, len(size_x) + 1): + matrix[i][0] = i*gap + return matrix + + +def get_matrix(size_x, size_y, gap): + matrix = np.zeros((size_x + 1, size_y + 1), dtype=np.int32) + matrix[0, 1:] = (np.arange(size_y) + 1) * gap + matrix[1:, 0] = (np.arange(size_x) + 1) * gap + return matrix + + +def get_traceback_matrix(size_x, size_y): + matrix = np.zeros((size_x + 1, size_y +1), dtype=np.int32) + matrix[0, 1:] = 1 + matrix[1:, 0] = 2 + matrix[0, 0] = 4 + return matrix + + +def global_align(x, y, score): + matrix = get_matrix(len(x), len(y), score.gap) + trace_back = get_traceback_matrix(len(x), len(y)) + for i in range(1, len(x) + 1): + for j in range(1, len(y) + 1): + left = matrix[i, j - 1] + score.gap + up = matrix[i - 1, j] + score.gap + diag = matrix[i - 1, j - 1] + score.mis_match_char(x[i - 1], y[j - 1]) + matrix[i, j] = max(left, up, diag) + if matrix[i, j] == left: + trace_back[i, j] = 1 + elif matrix[i, j] == up: + trace_back[i, j] = 2 + else: + trace_back[i, j] = 3 + return matrix, trace_back + + +def get_aligned_sequences(x, y, trace_back): + x_seq = [] + y_seq = [] + i = len(x) + j = len(y) + mapper_y_to_x = [] + while i > 0 or j > 0: + if trace_back[i, j] == 3: + x_seq.append(x[i-1]) + y_seq.append(y[j-1]) + i = i-1 + j = j-1 + mapper_y_to_x.append((j, i)) + elif trace_back[i][j] == 1: + x_seq.append('-') + y_seq.append(y[j-1]) + j = j-1 + mapper_y_to_x.append((j, -1)) + elif trace_back[i][j] == 2: + x_seq.append(x[i-1]) + y_seq.append('-') + i = i-1 + elif trace_back[i][j] == 4: + break + mapper_y_to_x.reverse() + return x_seq, y_seq, torch.tensor(mapper_y_to_x, dtype=torch.int64) + + +def get_mapper(x: str, y: str, tokenizer, max_len=77): + x_seq = tokenizer.encode(x) + y_seq = tokenizer.encode(y) + score = ScoreParams(0, 1, -1) + matrix, trace_back = global_align(x_seq, y_seq, score) + mapper_base = get_aligned_sequences(x_seq, y_seq, trace_back)[-1] + alphas = torch.ones(max_len) + alphas[: mapper_base.shape[0]] = mapper_base[:, 1].ne(-1).float() + mapper = torch.zeros(max_len, dtype=torch.int64) + mapper[:mapper_base.shape[0]] = mapper_base[:, 1] + mapper[mapper_base.shape[0]:] = len(y_seq) + torch.arange(max_len - len(y_seq)) + return mapper, alphas + + +def get_refinement_mapper(prompts, tokenizer, max_len=77): + x_seq = prompts[0] + mappers, alphas = [], [] + for i in range(1, len(prompts)): + mapper, alpha = get_mapper(x_seq, prompts[i], tokenizer, max_len) + mappers.append(mapper) + alphas.append(alpha) + return torch.stack(mappers), torch.stack(alphas) + + +def get_word_inds(text: str, word_place: int, tokenizer): + split_text = text.split(" ") + if type(word_place) is str: + word_place = [i for i, word in enumerate(split_text) if word_place == word] + elif type(word_place) is int: + word_place = [word_place] + out = [] + if len(word_place) > 0: + words_encode = [tokenizer.decode([item]).strip("#") for item in tokenizer.encode(text)][1:-1] + cur_len, ptr = 0, 0 + + for i in range(len(words_encode)): + cur_len += len(words_encode[i]) + if ptr in 
word_place: + out.append(i + 1) + if cur_len >= len(split_text[ptr]): + ptr += 1 + cur_len = 0 + return np.array(out) + + +def get_replacement_mapper_(x: str, y: str, tokenizer, max_len=77): + words_x = x.split(' ') + words_y = y.split(' ') + if len(words_x) != len(words_y): + raise ValueError(f"attention replacement edit can only be applied on prompts with the same length" + f" but prompt A has {len(words_x)} words and prompt B has {len(words_y)} words.") + inds_replace = [i for i in range(len(words_y)) if words_y[i] != words_x[i]] + inds_source = [get_word_inds(x, i, tokenizer) for i in inds_replace] + inds_target = [get_word_inds(y, i, tokenizer) for i in inds_replace] + mapper = np.zeros((max_len, max_len)) + i = j = 0 + cur_inds = 0 + while i < max_len and j < max_len: + if cur_inds < len(inds_source) and inds_source[cur_inds][0] == i: + inds_source_, inds_target_ = inds_source[cur_inds], inds_target[cur_inds] + if len(inds_source_) == len(inds_target_): + mapper[inds_source_, inds_target_] = 1 + else: + ratio = 1 / len(inds_target_) + for i_t in inds_target_: + mapper[inds_source_, i_t] = ratio + cur_inds += 1 + i += len(inds_source_) + j += len(inds_target_) + elif cur_inds < len(inds_source): + mapper[i, j] = 1 + i += 1 + j += 1 + else: + mapper[j, j] = 1 + i += 1 + j += 1 + + return torch.from_numpy(mapper).float() + + + +def get_replacement_mapper(prompts, tokenizer, max_len=77): + x_seq = prompts[0] + mappers = [] + for i in range(1, len(prompts)): + mapper = get_replacement_mapper_(x_seq, prompts[i], tokenizer, max_len) + mappers.append(mapper) + return torch.stack(mappers) + diff --git a/Video-P2P/tuneavideo/data/dataset.py b/Video-P2P/tuneavideo/data/dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..199f246633d1a6ba2d106bbb98815a3d8fb9eaad --- /dev/null +++ b/Video-P2P/tuneavideo/data/dataset.py @@ -0,0 +1,57 @@ +import decord +decord.bridge.set_bridge('torch') + +from torch.utils.data import Dataset +from einops import rearrange +import os +from PIL import Image +import numpy as np + +class TuneAVideoDataset(Dataset): + def __init__( + self, + video_path: str, + prompt: str, + width: int = 512, + height: int = 512, + n_sample_frames: int = 8, + sample_start_idx: int = 0, + sample_frame_rate: int = 1, + ): + self.video_path = video_path + self.prompt = prompt + self.prompt_ids = None + self.uncond_prompt_ids = None + + self.width = width + self.height = height + self.n_sample_frames = n_sample_frames + self.sample_start_idx = sample_start_idx + self.sample_frame_rate = sample_frame_rate + + if 'mp4' not in self.video_path: + self.images = [] + for file in sorted(os.listdir(self.video_path), key=lambda x: int(x[:-4])): + if file.endswith('jpg'): + self.images.append(np.asarray(Image.open(os.path.join(self.video_path, file)).convert('RGB').resize((self.width, self.height)))) + self.images = np.stack(self.images) + + def __len__(self): + return 1 + + def __getitem__(self, index): + # load and sample video frames + if 'mp4' in self.video_path: + vr = decord.VideoReader(self.video_path, width=self.width, height=self.height) + sample_index = list(range(self.sample_start_idx, len(vr), self.sample_frame_rate))[:self.n_sample_frames] + video = vr.get_batch(sample_index) + else: + video = self.images[:self.n_sample_frames] + video = rearrange(video, "f h w c -> f c h w") + + example = { + "pixel_values": (video / 127.5 - 1.0), + "prompt_ids": self.prompt_ids, + } + + return example diff --git a/Video-P2P/tuneavideo/models/attention.py 
b/Video-P2P/tuneavideo/models/attention.py new file mode 100644 index 0000000000000000000000000000000000000000..a90347820b79efb30c1a94fc85111ea739e12e56 --- /dev/null +++ b/Video-P2P/tuneavideo/models/attention.py @@ -0,0 +1,329 @@ +# Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention.py +# https://github.com/showlab/Tune-A-Video/blob/main/tuneavideo/models/attention.py + +from dataclasses import dataclass +from typing import Optional + +import torch +import torch.nn.functional as F +from torch import nn + +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.modeling_utils import ModelMixin +from diffusers.utils import BaseOutput +from diffusers.utils.import_utils import is_xformers_available +from diffusers.models.attention import CrossAttention, FeedForward, AdaLayerNorm + +from einops import rearrange, repeat + + +@dataclass +class Transformer3DModelOutput(BaseOutput): + sample: torch.FloatTensor + + +if is_xformers_available(): + import xformers + import xformers.ops +else: + xformers = None + + +class Transformer3DModel(ModelMixin, ConfigMixin): + @register_to_config + def __init__( + self, + num_attention_heads: int = 16, + attention_head_dim: int = 88, + in_channels: Optional[int] = None, + num_layers: int = 1, + dropout: float = 0.0, + norm_num_groups: int = 32, + cross_attention_dim: Optional[int] = None, + attention_bias: bool = False, + activation_fn: str = "geglu", + num_embeds_ada_norm: Optional[int] = None, + use_linear_projection: bool = False, + only_cross_attention: bool = False, + upcast_attention: bool = False, + ): + super().__init__() + self.use_linear_projection = use_linear_projection + self.num_attention_heads = num_attention_heads + self.attention_head_dim = attention_head_dim + inner_dim = num_attention_heads * attention_head_dim + + # Define input layers + self.in_channels = in_channels + + self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True) + if use_linear_projection: + self.proj_in = nn.Linear(in_channels, inner_dim) + else: + self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0) + + # Define transformers blocks + self.transformer_blocks = nn.ModuleList( + [ + BasicTransformerBlock( + inner_dim, + num_attention_heads, + attention_head_dim, + dropout=dropout, + cross_attention_dim=cross_attention_dim, + activation_fn=activation_fn, + num_embeds_ada_norm=num_embeds_ada_norm, + attention_bias=attention_bias, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + ) + for d in range(num_layers) + ] + ) + + # 4. Define output layers + if use_linear_projection: + self.proj_out = nn.Linear(in_channels, inner_dim) + else: + self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0) + + def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, return_dict: bool = True): + # Input + assert hidden_states.dim() == 5, f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}." 
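+        # Fold the frame axis into the batch so the pretrained 2D spatial layers below run on each frame independently.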
+ video_length = hidden_states.shape[2] + hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w") + encoder_hidden_states = repeat(encoder_hidden_states, 'b n c -> (b f) n c', f=video_length) + + batch, channel, height, weight = hidden_states.shape + residual = hidden_states + + hidden_states = self.norm(hidden_states) + if not self.use_linear_projection: + hidden_states = self.proj_in(hidden_states) + inner_dim = hidden_states.shape[1] + hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim) + else: + inner_dim = hidden_states.shape[1] + hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim) + hidden_states = self.proj_in(hidden_states) + + # Blocks + for block in self.transformer_blocks: + hidden_states = block( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + timestep=timestep, + video_length=video_length + ) + + # Output + if not self.use_linear_projection: + hidden_states = ( + hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous() + ) + hidden_states = self.proj_out(hidden_states) + else: + hidden_states = self.proj_out(hidden_states) + hidden_states = ( + hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous() + ) + + output = hidden_states + residual + + output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length) + if not return_dict: + return (output,) + + return Transformer3DModelOutput(sample=output) + + +class BasicTransformerBlock(nn.Module): + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + dropout=0.0, + cross_attention_dim: Optional[int] = None, + activation_fn: str = "geglu", + num_embeds_ada_norm: Optional[int] = None, + attention_bias: bool = False, + only_cross_attention: bool = False, + upcast_attention: bool = False, + ): + super().__init__() + self.only_cross_attention = only_cross_attention + self.use_ada_layer_norm = num_embeds_ada_norm is not None + + # SC-Attn + self.attn1 = FrameAttention( + query_dim=dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + cross_attention_dim=cross_attention_dim if only_cross_attention else None, + upcast_attention=upcast_attention, + ) + self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim) + + # Cross-Attn + if cross_attention_dim is not None: + self.attn2 = CrossAttention( + query_dim=dim, + cross_attention_dim=cross_attention_dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + upcast_attention=upcast_attention, + ) + else: + self.attn2 = None + + if cross_attention_dim is not None: + self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim) + else: + self.norm2 = None + + # Feed-forward + self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn) + self.norm3 = nn.LayerNorm(dim) + + # Temp-Attn + self.attn_temp = CrossAttention( + query_dim=dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + upcast_attention=upcast_attention, + ) + nn.init.zeros_(self.attn_temp.to_out[0].weight.data) + self.norm_temp = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim) + + def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool): + if not is_xformers_available(): + 
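            # xformers is an optional dependency; memory-efficient attention cannot be enabled without it.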
print("Here is how to install it") + raise ModuleNotFoundError( + "Refer to https://github.com/facebookresearch/xformers for more information on how to install" + " xformers", + name="xformers", + ) + elif not torch.cuda.is_available(): + raise ValueError( + "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is only" + " available for GPU " + ) + else: + try: + # Make sure we can run the memory efficient attention + _ = xformers.ops.memory_efficient_attention( + torch.randn((1, 2, 40), device="cuda"), + torch.randn((1, 2, 40), device="cuda"), + torch.randn((1, 2, 40), device="cuda"), + ) + except Exception as e: + raise e + self.attn1._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers + if self.attn2 is not None: + self.attn2._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers + # self.attn_temp._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers + + def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, attention_mask=None, video_length=None): + # SparseCausal-Attention + norm_hidden_states = ( + self.norm1(hidden_states, timestep) if self.use_ada_layer_norm else self.norm1(hidden_states) + ) + + if self.only_cross_attention: + hidden_states = ( + self.attn1(norm_hidden_states, encoder_hidden_states, attention_mask=attention_mask) + hidden_states + ) + else: + hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask, video_length=video_length) + hidden_states + + if self.attn2 is not None: + # Cross-Attention + norm_hidden_states = ( + self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states) + ) + hidden_states = ( + self.attn2( + norm_hidden_states, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask + ) + + hidden_states + ) + + # Feed-forward + hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states + + # Temporal-Attention + d = hidden_states.shape[1] + hidden_states = rearrange(hidden_states, "(b f) d c -> (b d) f c", f=video_length) + norm_hidden_states = ( + self.norm_temp(hidden_states, timestep) if self.use_ada_layer_norm else self.norm_temp(hidden_states) + ) + hidden_states = self.attn_temp(norm_hidden_states) + hidden_states + hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d) + + return hidden_states + + +class FrameAttention(CrossAttention): + def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, video_length=None): + batch_size, sequence_length, _ = hidden_states.shape + + encoder_hidden_states = encoder_hidden_states + + if self.group_norm is not None: + hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = self.to_q(hidden_states) + dim = query.shape[-1] + query = self.reshape_heads_to_batch_dim(query) + + if self.added_kv_proj_dim is not None: + raise NotImplementedError + + encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states + key = self.to_k(encoder_hidden_states) + value = self.to_v(encoder_hidden_states) + + former_frame_index = torch.arange(video_length) - 1 + former_frame_index[0] = 0 + + key = rearrange(key, "(b f) d c -> b f d c", f=video_length) + key = key[:, [0] * video_length] + key = rearrange(key, "b f d c -> (b f) d c") + + value = rearrange(value, "(b f) d c -> b f d c", f=video_length) + value = value[:, [0] * video_length] + value = rearrange(value, "b f d c -> (b f) d c") + + 
key = self.reshape_heads_to_batch_dim(key) + value = self.reshape_heads_to_batch_dim(value) + + if attention_mask is not None: + if attention_mask.shape[-1] != query.shape[1]: + target_length = query.shape[1] + attention_mask = F.pad(attention_mask, (0, target_length), value=0.0) + attention_mask = attention_mask.repeat_interleave(self.heads, dim=0) + + # attention, what we cannot get enough of + if self._use_memory_efficient_attention_xformers: + hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask) + # Some versions of xformers return output in fp32, cast it back to the dtype of the input + hidden_states = hidden_states.to(query.dtype) + else: + if self._slice_size is None or query.shape[0] // self._slice_size == 1: + hidden_states = self._attention(query, key, value, attention_mask) + else: + hidden_states = self._sliced_attention(query, key, value, sequence_length, dim, attention_mask) + + # linear proj + hidden_states = self.to_out[0](hidden_states) + + # dropout + hidden_states = self.to_out[1](hidden_states) + return hidden_states \ No newline at end of file diff --git a/Video-P2P/tuneavideo/models/resnet.py b/Video-P2P/tuneavideo/models/resnet.py new file mode 100644 index 0000000000000000000000000000000000000000..be0eeae1511ae13fa128be31338aaed0752fd4bd --- /dev/null +++ b/Video-P2P/tuneavideo/models/resnet.py @@ -0,0 +1,210 @@ +# Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/resnet.py +# https://github.com/showlab/Tune-A-Video/blob/main/tuneavideo/models/resnet.py + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from einops import rearrange + + +class InflatedConv3d(nn.Conv2d): + def forward(self, x): + video_length = x.shape[2] + + x = rearrange(x, "b c f h w -> (b f) c h w") + x = super().forward(x) + x = rearrange(x, "(b f) c h w -> b c f h w", f=video_length) + + return x + + +class Upsample3D(nn.Module): + def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.use_conv_transpose = use_conv_transpose + self.name = name + + conv = None + if use_conv_transpose: + raise NotImplementedError + elif use_conv: + conv = InflatedConv3d(self.channels, self.out_channels, 3, padding=1) + + if name == "conv": + self.conv = conv + else: + self.Conv2d_0 = conv + + def forward(self, hidden_states, output_size=None): + assert hidden_states.shape[1] == self.channels + + if self.use_conv_transpose: + raise NotImplementedError + + # Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16 + dtype = hidden_states.dtype + if dtype == torch.bfloat16: + hidden_states = hidden_states.to(torch.float32) + + # upsample_nearest_nhwc fails with large batch sizes. 
see https://github.com/huggingface/diffusers/issues/984 + if hidden_states.shape[0] >= 64: + hidden_states = hidden_states.contiguous() + + # if `output_size` is passed we force the interpolation output + # size and do not make use of `scale_factor=2` + if output_size is None: + hidden_states = F.interpolate(hidden_states, scale_factor=[1.0, 2.0, 2.0], mode="nearest") + else: + hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest") + + # If the input is bfloat16, we cast back to bfloat16 + if dtype == torch.bfloat16: + hidden_states = hidden_states.to(dtype) + + if self.use_conv: + if self.name == "conv": + hidden_states = self.conv(hidden_states) + else: + hidden_states = self.Conv2d_0(hidden_states) + + return hidden_states + + +class Downsample3D(nn.Module): + def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.padding = padding + stride = 2 + self.name = name + + if use_conv: + conv = InflatedConv3d(self.channels, self.out_channels, 3, stride=stride, padding=padding) + else: + raise NotImplementedError + + if name == "conv": + self.Conv2d_0 = conv + self.conv = conv + elif name == "Conv2d_0": + self.conv = conv + else: + self.conv = conv + + def forward(self, hidden_states): + assert hidden_states.shape[1] == self.channels + if self.use_conv and self.padding == 0: + raise NotImplementedError + + assert hidden_states.shape[1] == self.channels + hidden_states = self.conv(hidden_states) + + return hidden_states + + +class ResnetBlock3D(nn.Module): + def __init__( + self, + *, + in_channels, + out_channels=None, + conv_shortcut=False, + dropout=0.0, + temb_channels=512, + groups=32, + groups_out=None, + pre_norm=True, + eps=1e-6, + non_linearity="swish", + time_embedding_norm="default", + output_scale_factor=1.0, + use_in_shortcut=None, + ): + super().__init__() + self.pre_norm = pre_norm + self.pre_norm = True + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + self.use_conv_shortcut = conv_shortcut + self.time_embedding_norm = time_embedding_norm + self.output_scale_factor = output_scale_factor + + if groups_out is None: + groups_out = groups + + self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True) + + self.conv1 = InflatedConv3d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) + + if temb_channels is not None: + if self.time_embedding_norm == "default": + time_emb_proj_out_channels = out_channels + elif self.time_embedding_norm == "scale_shift": + time_emb_proj_out_channels = out_channels * 2 + else: + raise ValueError(f"unknown time_embedding_norm : {self.time_embedding_norm} ") + + self.time_emb_proj = torch.nn.Linear(temb_channels, time_emb_proj_out_channels) + else: + self.time_emb_proj = None + + self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True) + self.dropout = torch.nn.Dropout(dropout) + self.conv2 = InflatedConv3d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) + + if non_linearity == "swish": + self.nonlinearity = lambda x: F.silu(x) + elif non_linearity == "mish": + self.nonlinearity = Mish() + elif non_linearity == "silu": + self.nonlinearity = nn.SiLU() + + self.use_in_shortcut = self.in_channels != self.out_channels if use_in_shortcut is None else use_in_shortcut + + 
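        # Optional 1x1 inflated conv to match channel counts on the skip connection when in/out channels differ.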
self.conv_shortcut = None + if self.use_in_shortcut: + self.conv_shortcut = InflatedConv3d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) + + def forward(self, input_tensor, temb): + hidden_states = input_tensor + + hidden_states = self.norm1(hidden_states) + hidden_states = self.nonlinearity(hidden_states) + + hidden_states = self.conv1(hidden_states) + + if temb is not None: + temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None, None] + + if temb is not None and self.time_embedding_norm == "default": + hidden_states = hidden_states + temb + + hidden_states = self.norm2(hidden_states) + + if temb is not None and self.time_embedding_norm == "scale_shift": + scale, shift = torch.chunk(temb, 2, dim=1) + hidden_states = hidden_states * (1 + scale) + shift + + hidden_states = self.nonlinearity(hidden_states) + + hidden_states = self.dropout(hidden_states) + hidden_states = self.conv2(hidden_states) + + if self.conv_shortcut is not None: + input_tensor = self.conv_shortcut(input_tensor) + + output_tensor = (input_tensor + hidden_states) / self.output_scale_factor + + return output_tensor + + +class Mish(torch.nn.Module): + def forward(self, hidden_states): + return hidden_states * torch.tanh(torch.nn.functional.softplus(hidden_states)) \ No newline at end of file diff --git a/Video-P2P/tuneavideo/models/unet.py b/Video-P2P/tuneavideo/models/unet.py new file mode 100644 index 0000000000000000000000000000000000000000..f3c3cab4c28f6a6dbfc0cfdd13b1cc8ff313e589 --- /dev/null +++ b/Video-P2P/tuneavideo/models/unet.py @@ -0,0 +1,451 @@ +# Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unet_2d_condition.py +# https://github.com/showlab/Tune-A-Video/blob/main/tuneavideo/models/unet.py + +from dataclasses import dataclass +from typing import List, Optional, Tuple, Union + +import os +import json + +import torch +import torch.nn as nn +import torch.utils.checkpoint + +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.modeling_utils import ModelMixin +from diffusers.utils import BaseOutput, logging +from diffusers.models.embeddings import TimestepEmbedding, Timesteps +from .unet_blocks import ( + CrossAttnDownBlock3D, + CrossAttnUpBlock3D, + DownBlock3D, + UNetMidBlock3DCrossAttn, + UpBlock3D, + get_down_block, + get_up_block, +) +from .resnet import InflatedConv3d + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +@dataclass +class UNet3DConditionOutput(BaseOutput): + sample: torch.FloatTensor + + +class UNet3DConditionModel(ModelMixin, ConfigMixin): + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + sample_size: Optional[int] = None, + in_channels: int = 4, + out_channels: int = 4, + center_input_sample: bool = False, + flip_sin_to_cos: bool = True, + freq_shift: int = 0, + down_block_types: Tuple[str] = ( + "CrossAttnDownBlock3D", + "CrossAttnDownBlock3D", + "CrossAttnDownBlock3D", + "DownBlock3D", + ), + mid_block_type: str = "UNetMidBlock3DCrossAttn", + up_block_types: Tuple[str] = ( + "UpBlock3D", + "CrossAttnUpBlock3D", + "CrossAttnUpBlock3D", + "CrossAttnUpBlock3D" + ), + only_cross_attention: Union[bool, Tuple[bool]] = False, + block_out_channels: Tuple[int] = (320, 640, 1280, 1280), + layers_per_block: int = 2, + downsample_padding: int = 1, + mid_block_scale_factor: float = 1, + act_fn: str = "silu", + norm_num_groups: int = 32, + norm_eps: float = 1e-5, + cross_attention_dim: int = 1280, + attention_head_dim: 
Union[int, Tuple[int]] = 8, + dual_cross_attention: bool = False, + use_linear_projection: bool = False, + class_embed_type: Optional[str] = None, + num_class_embeds: Optional[int] = None, + upcast_attention: bool = False, + resnet_time_scale_shift: str = "default", + ): + super().__init__() + + self.sample_size = sample_size + time_embed_dim = block_out_channels[0] * 4 + + # input + self.conv_in = InflatedConv3d(in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1)) + + # time + self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift) + timestep_input_dim = block_out_channels[0] + + self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim) + + # class embedding + if class_embed_type is None and num_class_embeds is not None: + self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim) + elif class_embed_type == "timestep": + self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim) + elif class_embed_type == "identity": + self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim) + else: + self.class_embedding = None + + self.down_blocks = nn.ModuleList([]) + self.mid_block = None + self.up_blocks = nn.ModuleList([]) + + if isinstance(only_cross_attention, bool): + only_cross_attention = [only_cross_attention] * len(down_block_types) + + if isinstance(attention_head_dim, int): + attention_head_dim = (attention_head_dim,) * len(down_block_types) + + # down + output_channel = block_out_channels[0] + for i, down_block_type in enumerate(down_block_types): + input_channel = output_channel + output_channel = block_out_channels[i] + is_final_block = i == len(block_out_channels) - 1 + + down_block = get_down_block( + down_block_type, + num_layers=layers_per_block, + in_channels=input_channel, + out_channels=output_channel, + temb_channels=time_embed_dim, + add_downsample=not is_final_block, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=attention_head_dim[i], + downsample_padding=downsample_padding, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention[i], + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + self.down_blocks.append(down_block) + + # mid + if mid_block_type == "UNetMidBlock3DCrossAttn": + self.mid_block = UNetMidBlock3DCrossAttn( + in_channels=block_out_channels[-1], + temb_channels=time_embed_dim, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + output_scale_factor=mid_block_scale_factor, + resnet_time_scale_shift=resnet_time_scale_shift, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=attention_head_dim[-1], + resnet_groups=norm_num_groups, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + upcast_attention=upcast_attention, + ) + else: + raise ValueError(f"unknown mid_block_type : {mid_block_type}") + + # count how many layers upsample the videos + self.num_upsamplers = 0 + + # up + reversed_block_out_channels = list(reversed(block_out_channels)) + reversed_attention_head_dim = list(reversed(attention_head_dim)) + only_cross_attention = list(reversed(only_cross_attention)) + output_channel = reversed_block_out_channels[0] + for i, up_block_type in enumerate(up_block_types): + is_final_block = i == len(block_out_channels) - 1 + + prev_output_channel = output_channel + output_channel = 
reversed_block_out_channels[i] + input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)] + + # add upsample block for all BUT final layer + if not is_final_block: + add_upsample = True + self.num_upsamplers += 1 + else: + add_upsample = False + + up_block = get_up_block( + up_block_type, + num_layers=layers_per_block + 1, + in_channels=input_channel, + out_channels=output_channel, + prev_output_channel=prev_output_channel, + temb_channels=time_embed_dim, + add_upsample=add_upsample, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=reversed_attention_head_dim[i], + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention[i], + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + self.up_blocks.append(up_block) + prev_output_channel = output_channel + + # out + self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps) + self.conv_act = nn.SiLU() + self.conv_out = InflatedConv3d(block_out_channels[0], out_channels, kernel_size=3, padding=1) + + def set_attention_slice(self, slice_size): + r""" + Enable sliced attention computation. + + When this option is enabled, the attention module will split the input tensor in slices, to compute attention + in several steps. This is useful to save some memory in exchange for a small speed decrease. + + Args: + slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`): + When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If + `"max"`, maxium amount of memory will be saved by running only one slice at a time. If a number is + provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim` + must be a multiple of `slice_size`. + """ + sliceable_head_dims = [] + + def fn_recursive_retrieve_slicable_dims(module: torch.nn.Module): + if hasattr(module, "set_attention_slice"): + sliceable_head_dims.append(module.sliceable_head_dim) + + for child in module.children(): + fn_recursive_retrieve_slicable_dims(child) + + # retrieve number of attention layers + for module in self.children(): + fn_recursive_retrieve_slicable_dims(module) + + num_slicable_layers = len(sliceable_head_dims) + + if slice_size == "auto": + # half the attention head size is usually a good trade-off between + # speed and memory + slice_size = [dim // 2 for dim in sliceable_head_dims] + elif slice_size == "max": + # make smallest slice possible + slice_size = num_slicable_layers * [1] + + slice_size = num_slicable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size + + if len(slice_size) != len(sliceable_head_dims): + raise ValueError( + f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different" + f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}." + ) + + for i in range(len(slice_size)): + size = slice_size[i] + dim = sliceable_head_dims[i] + if size is not None and size > dim: + raise ValueError(f"size {size} has to be smaller or equal to {dim}.") + + # Recursively walk through all the children. 
+ # Any children which exposes the set_attention_slice method + # gets the message + def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]): + if hasattr(module, "set_attention_slice"): + module.set_attention_slice(slice_size.pop()) + + for child in module.children(): + fn_recursive_set_attention_slice(child, slice_size) + + reversed_slice_size = list(reversed(slice_size)) + for module in self.children(): + fn_recursive_set_attention_slice(module, reversed_slice_size) + + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, (CrossAttnDownBlock3D, DownBlock3D, CrossAttnUpBlock3D, UpBlock3D)): + module.gradient_checkpointing = value + + def forward( + self, + sample: torch.FloatTensor, + timestep: Union[torch.Tensor, float, int], + encoder_hidden_states: torch.Tensor, + class_labels: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + return_dict: bool = True, + ) -> Union[UNet3DConditionOutput, Tuple]: + r""" + Args: + sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor + timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps + encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple. + + Returns: + [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`: + [`~models.unet_2d_condition.UNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When + returning a tuple, the first element is the sample tensor. + """ + # By default samples have to be AT least a multiple of the overall upsampling factor. + # The overall upsampling factor is equal to 2 ** (# num of upsampling layears). + # However, the upsampling interpolation output size can be forced to fit any upsampling size + # on the fly if necessary. + default_overall_up_factor = 2**self.num_upsamplers + + # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor` + forward_upsample_size = False + upsample_size = None + + if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]): + logger.info("Forward upsample size to force interpolation output size.") + forward_upsample_size = True + + # prepare attention_mask + if attention_mask is not None: + attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 + attention_mask = attention_mask.unsqueeze(1) + + # center input if necessary + if self.config.center_input_sample: + sample = 2 * sample - 1.0 + + # time + timesteps = timestep + if not torch.is_tensor(timesteps): + # This would be a good case for the `match` statement (Python 3.10+) + is_mps = sample.device.type == "mps" + if isinstance(timestep, float): + dtype = torch.float32 if is_mps else torch.float64 + else: + dtype = torch.int32 if is_mps else torch.int64 + timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) + elif len(timesteps.shape) == 0: + timesteps = timesteps[None].to(sample.device) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timesteps = timesteps.expand(sample.shape[0]) + + t_emb = self.time_proj(timesteps) + + # timesteps does not contain any weights and will always return f32 tensors + # but time_embedding might actually be running in fp16. so we need to cast here. + # there might be better ways to encapsulate this. 
+ t_emb = t_emb.to(dtype=self.dtype) + emb = self.time_embedding(t_emb) + + if self.class_embedding is not None: + if class_labels is None: + raise ValueError("class_labels should be provided when num_class_embeds > 0") + + if self.config.class_embed_type == "timestep": + class_labels = self.time_proj(class_labels) + + class_emb = self.class_embedding(class_labels).to(dtype=self.dtype) + emb = emb + class_emb + + # pre-process + sample = self.conv_in(sample) + + # down + down_block_res_samples = (sample,) + for downsample_block in self.down_blocks: + if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: + sample, res_samples = downsample_block( + hidden_states=sample, + temb=emb, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + ) + else: + sample, res_samples = downsample_block(hidden_states=sample, temb=emb) + + down_block_res_samples += res_samples + + # mid + sample = self.mid_block( + sample, emb, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask + ) + + # up + for i, upsample_block in enumerate(self.up_blocks): + is_final_block = i == len(self.up_blocks) - 1 + + res_samples = down_block_res_samples[-len(upsample_block.resnets) :] + down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] + + # if we have not reached the final block and need to forward the + # upsample size, we do it here + if not is_final_block and forward_upsample_size: + upsample_size = down_block_res_samples[-1].shape[2:] + + if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention: + sample = upsample_block( + hidden_states=sample, + temb=emb, + res_hidden_states_tuple=res_samples, + encoder_hidden_states=encoder_hidden_states, + upsample_size=upsample_size, + attention_mask=attention_mask, + ) + else: + sample = upsample_block( + hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size + ) + # post-process + sample = self.conv_norm_out(sample) + sample = self.conv_act(sample) + sample = self.conv_out(sample) + + if not return_dict: + return (sample,) + + return UNet3DConditionOutput(sample=sample) + + @classmethod + def from_pretrained_2d(cls, pretrained_model_path, subfolder=None): + if subfolder is not None: + pretrained_model_path = os.path.join(pretrained_model_path, subfolder) + + config_file = os.path.join(pretrained_model_path, 'config.json') + if not os.path.isfile(config_file): + raise RuntimeError(f"{config_file} does not exist") + with open(config_file, "r") as f: + config = json.load(f) + config["_class_name"] = cls.__name__ + config["down_block_types"] = [ + "CrossAttnDownBlock3D", + "CrossAttnDownBlock3D", + "CrossAttnDownBlock3D", + "DownBlock3D" + ] + config["up_block_types"] = [ + "UpBlock3D", + "CrossAttnUpBlock3D", + "CrossAttnUpBlock3D", + "CrossAttnUpBlock3D" + ] + + from diffusers.utils import WEIGHTS_NAME + model = cls.from_config(config) + model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME) + if not os.path.isfile(model_file): + raise RuntimeError(f"{model_file} does not exist") + state_dict = torch.load(model_file, map_location="cpu") + for k, v in model.state_dict().items(): + if '_temp.' 
in k: + state_dict.update({k: v}) + model.load_state_dict(state_dict) + + return model \ No newline at end of file diff --git a/Video-P2P/tuneavideo/models/unet_blocks.py b/Video-P2P/tuneavideo/models/unet_blocks.py new file mode 100644 index 0000000000000000000000000000000000000000..2d21edbc03d0098800a963e74fcf5ef0e29593bd --- /dev/null +++ b/Video-P2P/tuneavideo/models/unet_blocks.py @@ -0,0 +1,589 @@ +# Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unet_2d_blocks.py +# https://github.com/showlab/Tune-A-Video/blob/main/tuneavideo/models/unet_blocks.py + +import torch +from torch import nn + +from .attention import Transformer3DModel +from .resnet import Downsample3D, ResnetBlock3D, Upsample3D + + +def get_down_block( + down_block_type, + num_layers, + in_channels, + out_channels, + temb_channels, + add_downsample, + resnet_eps, + resnet_act_fn, + attn_num_head_channels, + resnet_groups=None, + cross_attention_dim=None, + downsample_padding=None, + dual_cross_attention=False, + use_linear_projection=False, + only_cross_attention=False, + upcast_attention=False, + resnet_time_scale_shift="default", +): + down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type + if down_block_type == "DownBlock3D": + return DownBlock3D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + downsample_padding=downsample_padding, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + elif down_block_type == "CrossAttnDownBlock3D": + if cross_attention_dim is None: + raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlock3D") + return CrossAttnDownBlock3D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + downsample_padding=downsample_padding, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=attn_num_head_channels, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + raise ValueError(f"{down_block_type} does not exist.") + + +def get_up_block( + up_block_type, + num_layers, + in_channels, + out_channels, + prev_output_channel, + temb_channels, + add_upsample, + resnet_eps, + resnet_act_fn, + attn_num_head_channels, + resnet_groups=None, + cross_attention_dim=None, + dual_cross_attention=False, + use_linear_projection=False, + only_cross_attention=False, + upcast_attention=False, + resnet_time_scale_shift="default", +): + up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type + if up_block_type == "UpBlock3D": + return UpBlock3D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + elif up_block_type == "CrossAttnUpBlock3D": + if cross_attention_dim is None: + raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlock3D") + return 
CrossAttnUpBlock3D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=attn_num_head_channels, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + raise ValueError(f"{up_block_type} does not exist.") + + +class UNetMidBlock3DCrossAttn(nn.Module): + def __init__( + self, + in_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + attn_num_head_channels=1, + output_scale_factor=1.0, + cross_attention_dim=1280, + dual_cross_attention=False, + use_linear_projection=False, + upcast_attention=False, + ): + super().__init__() + + self.has_cross_attention = True + self.attn_num_head_channels = attn_num_head_channels + resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) + + # there is always at least one resnet + resnets = [ + ResnetBlock3D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ] + attentions = [] + + for _ in range(num_layers): + if dual_cross_attention: + raise NotImplementedError + attentions.append( + Transformer3DModel( + attn_num_head_channels, + in_channels // attn_num_head_channels, + in_channels=in_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + upcast_attention=upcast_attention, + ) + ) + resnets.append( + ResnetBlock3D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None): + hidden_states = self.resnets[0](hidden_states, temb) + for attn, resnet in zip(self.attentions, self.resnets[1:]): + hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample + hidden_states = resnet(hidden_states, temb) + + return hidden_states + + +class CrossAttnDownBlock3D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + attn_num_head_channels=1, + cross_attention_dim=1280, + output_scale_factor=1.0, + downsample_padding=1, + add_downsample=True, + dual_cross_attention=False, + use_linear_projection=False, + only_cross_attention=False, + 
upcast_attention=False, + ): + super().__init__() + resnets = [] + attentions = [] + + self.has_cross_attention = True + self.attn_num_head_channels = attn_num_head_channels + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlock3D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + if dual_cross_attention: + raise NotImplementedError + attentions.append( + Transformer3DModel( + attn_num_head_channels, + out_channels // attn_num_head_channels, + in_channels=out_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + ) + ) + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + if add_downsample: + self.downsamplers = nn.ModuleList( + [ + Downsample3D( + out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op" + ) + ] + ) + else: + self.downsamplers = None + + self.gradient_checkpointing = False + + def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None): + output_states = () + + for resnet, attn in zip(self.resnets, self.attentions): + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(attn, return_dict=False), + hidden_states, + encoder_hidden_states, + )[0] + else: + hidden_states = resnet(hidden_states, temb) + hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample + + output_states += (hidden_states,) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + output_states += (hidden_states,) + + return hidden_states, output_states + + +class DownBlock3D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + output_scale_factor=1.0, + add_downsample=True, + downsample_padding=1, + ): + super().__init__() + resnets = [] + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlock3D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + + self.resnets = nn.ModuleList(resnets) + + if add_downsample: + self.downsamplers = nn.ModuleList( + [ + Downsample3D( + out_channels, use_conv=True, out_channels=out_channels, 
padding=downsample_padding, name="op" + ) + ] + ) + else: + self.downsamplers = None + + self.gradient_checkpointing = False + + def forward(self, hidden_states, temb=None): + output_states = () + + for resnet in self.resnets: + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) + else: + hidden_states = resnet(hidden_states, temb) + + output_states += (hidden_states,) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + output_states += (hidden_states,) + + return hidden_states, output_states + + +class CrossAttnUpBlock3D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + prev_output_channel: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + attn_num_head_channels=1, + cross_attention_dim=1280, + output_scale_factor=1.0, + add_upsample=True, + dual_cross_attention=False, + use_linear_projection=False, + only_cross_attention=False, + upcast_attention=False, + ): + super().__init__() + resnets = [] + attentions = [] + + self.has_cross_attention = True + self.attn_num_head_channels = attn_num_head_channels + + for i in range(num_layers): + res_skip_channels = in_channels if (i == num_layers - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + ResnetBlock3D( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + if dual_cross_attention: + raise NotImplementedError + attentions.append( + Transformer3DModel( + attn_num_head_channels, + out_channels // attn_num_head_channels, + in_channels=out_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + ) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + if add_upsample: + self.upsamplers = nn.ModuleList([Upsample3D(out_channels, use_conv=True, out_channels=out_channels)]) + else: + self.upsamplers = None + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states, + res_hidden_states_tuple, + temb=None, + encoder_hidden_states=None, + upsample_size=None, + attention_mask=None, + ): + for resnet, attn in zip(self.resnets, self.attentions): + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + hidden_states = 
torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(attn, return_dict=False), + hidden_states, + encoder_hidden_states, + )[0] + else: + hidden_states = resnet(hidden_states, temb) + hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, upsample_size) + + return hidden_states + + +class UpBlock3D(nn.Module): + def __init__( + self, + in_channels: int, + prev_output_channel: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + output_scale_factor=1.0, + add_upsample=True, + ): + super().__init__() + resnets = [] + + for i in range(num_layers): + res_skip_channels = in_channels if (i == num_layers - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + ResnetBlock3D( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + + self.resnets = nn.ModuleList(resnets) + + if add_upsample: + self.upsamplers = nn.ModuleList([Upsample3D(out_channels, use_conv=True, out_channels=out_channels)]) + else: + self.upsamplers = None + + self.gradient_checkpointing = False + + def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None): + for resnet in self.resnets: + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) + else: + hidden_states = resnet(hidden_states, temb) + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, upsample_size) + + return hidden_states diff --git a/Video-P2P/tuneavideo/pipelines/pipeline_tuneavideo.py b/Video-P2P/tuneavideo/pipelines/pipeline_tuneavideo.py new file mode 100644 index 0000000000000000000000000000000000000000..9619b2fe7022ebe5467b8968dc1e7004bcb2eae7 --- /dev/null +++ b/Video-P2P/tuneavideo/pipelines/pipeline_tuneavideo.py @@ -0,0 +1,437 @@ +# Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +# https://github.com/showlab/Tune-A-Video/blob/main/tuneavideo/pipelines/pipeline_tuneavideo.py + +import inspect +from typing import Callable, List, Optional, Union +from dataclasses import dataclass + +import numpy as np +import torch + +from diffusers.utils import is_accelerate_available +from packaging import version +from transformers import CLIPTextModel, CLIPTokenizer + +from diffusers.configuration_utils import FrozenDict +from diffusers.models import AutoencoderKL +from 
diffusers.pipeline_utils import DiffusionPipeline +from diffusers.schedulers import ( + DDIMScheduler, + DPMSolverMultistepScheduler, + EulerAncestralDiscreteScheduler, + EulerDiscreteScheduler, + LMSDiscreteScheduler, + PNDMScheduler, +) +from diffusers.utils import deprecate, logging, BaseOutput + +from einops import rearrange, repeat + +from ..models.unet import UNet3DConditionModel + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +@dataclass +class TuneAVideoPipelineOutput(BaseOutput): + videos: Union[torch.Tensor, np.ndarray] + + +class TuneAVideoPipeline(DiffusionPipeline): + _optional_components = [] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet3DConditionModel, + scheduler: Union[ + DDIMScheduler, + PNDMScheduler, + LMSDiscreteScheduler, + EulerDiscreteScheduler, + EulerAncestralDiscreteScheduler, + DPMSolverMultistepScheduler, + ], + ): + super().__init__() + + if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" + f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " + "to update the config accordingly as leaving `steps_offset` might led to incorrect results" + " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub," + " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" + " file" + ) + deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["steps_offset"] = 1 + scheduler._internal_dict = FrozenDict(new_config) + + if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`." + " `clip_sample` should be set to False in the configuration file. Please make sure to update the" + " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in" + " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very" + " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file" + ) + deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["clip_sample"] = False + scheduler._internal_dict = FrozenDict(new_config) + + is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse( + version.parse(unet.config._diffusers_version).base_version + ) < version.parse("0.9.0.dev0") + is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 + if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: + deprecation_message = ( + "The configuration file of the unet has set the default `sample_size` to smaller than" + " 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the" + " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-" + " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5" + " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the" + " configuration file. 
Please make sure to update the config accordingly as leaving `sample_size=32`" + " in the config might lead to incorrect results in future versions. If you have downloaded this" + " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for" + " the `unet/config.json` file" + ) + deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(unet.config) + new_config["sample_size"] = 64 + unet._internal_dict = FrozenDict(new_config) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + + def enable_vae_slicing(self): + self.vae.enable_slicing() + + def disable_vae_slicing(self): + self.vae.disable_slicing() + + def enable_sequential_cpu_offload(self, gpu_id=0): + if is_accelerate_available(): + from accelerate import cpu_offload + else: + raise ImportError("Please install accelerate via `pip install accelerate`") + + device = torch.device(f"cuda:{gpu_id}") + + for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]: + if cpu_offloaded_model is not None: + cpu_offload(cpu_offloaded_model, device) + + + @property + def _execution_device(self): + if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"): + return self.device + for module in self.unet.modules(): + if ( + hasattr(module, "_hf_hook") + and hasattr(module._hf_hook, "execution_device") + and module._hf_hook.execution_device is not None + ): + return torch.device(module._hf_hook.execution_device) + return self.device + + def _encode_prompt(self, prompt, device, num_videos_per_prompt, do_classifier_free_guidance, negative_prompt): + batch_size = len(prompt) if isinstance(prompt, list) else 1 + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + text_embeddings = self.text_encoder( + text_input_ids.to(device), + attention_mask=attention_mask, + ) + text_embeddings = text_embeddings[0] + + # duplicate text embeddings for each generation per prompt, using mps friendly method + bs_embed, seq_len, _ = text_embeddings.shape + text_embeddings = text_embeddings.repeat(1, num_videos_per_prompt, 1) + text_embeddings = text_embeddings.view(bs_embed * num_videos_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." 
+ ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + max_length = text_input_ids.shape[-1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + uncond_embeddings = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + uncond_embeddings = uncond_embeddings[0] + + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = uncond_embeddings.shape[1] + uncond_embeddings = uncond_embeddings.repeat(1, num_videos_per_prompt, 1) + uncond_embeddings = uncond_embeddings.view(batch_size * num_videos_per_prompt, seq_len, -1) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) + + return text_embeddings + + def decode_latents(self, latents): + video_length = latents.shape[2] + latents = 1 / 0.18215 * latents + latents = rearrange(latents, "b c f h w -> (b f) c h w") + bs = 4 + video_list = [] + for i in range(max(latents.shape[0]//bs, 1)): + video = self.vae.decode(latents[i*bs:min((i+1)*bs, latents.shape[0])]).sample + video = (video / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 + video = video.cpu().float().numpy() + video_list.append(video) + if len(video_list) > 1: + video = np.concatenate(video_list, axis=0) + else: + video = video_list[0] + video = rearrange(video, "(b f) c h w -> b c f h w", f=video_length) + return video + + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. 
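+        # eta=0.0, the default used by this pipeline, gives deterministic DDIM sampling;
+        # larger values add DDPM-like stochasticity.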
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs(self, prompt, height, width, callback_steps): + if not isinstance(prompt, str) and not isinstance(prompt, list): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + def prepare_latents(self, batch_size, num_channels_latents, video_length, height, width, dtype, device, generator, latents=None): + shape = (batch_size, num_channels_latents, video_length, height // self.vae_scale_factor, width // self.vae_scale_factor) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + rand_device = "cpu" if device.type == "mps" else device + + if isinstance(generator, list): + shape = (1,) + shape[1:] + latents = [ + torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype) + for i in range(batch_size) + ] + latents = torch.cat(latents, dim=0).to(device) + else: + latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype).to(device) + else: + if latents.shape != shape: + # raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") + latents = latents.expand(shape) + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + @torch.no_grad() + def __call__( + self, + prompt: Union[str, List[str]], + video_length: Optional[int], + height: Optional[int] = 512, + width: Optional[int] = 512, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_videos_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "tensor", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: Optional[int] = 1, + uncond_embeddings_pre=None, + controller=None, + multi=False, + fast=False, + **kwargs, + ): + # Default height and width to unet + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor + + # Check inputs. 
Raise an error if they are not valid.
+        self.check_inputs(prompt, height, width, callback_steps)
+
+        # Define call parameters
+        batch_size = 1 if isinstance(prompt, str) else len(prompt)
+        device = self._execution_device
+        # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
+        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf. `guidance_scale = 1`
+        # corresponds to doing no classifier-free guidance.
+        do_classifier_free_guidance = guidance_scale > 1.0
+
+        # Encode input prompt
+        text_embeddings = self._encode_prompt(
+            prompt, device, num_videos_per_prompt, do_classifier_free_guidance, negative_prompt
+        )
+        if multi:
+            text_embeddings = repeat(text_embeddings, 'b n c -> (b f) n c', f=video_length)
+
+        # Prepare timesteps
+        self.scheduler.set_timesteps(num_inference_steps, device=device)
+        timesteps = self.scheduler.timesteps
+
+        # Prepare latent variables
+        num_channels_latents = self.unet.in_channels
+        latents = self.prepare_latents(
+            batch_size * num_videos_per_prompt,
+            num_channels_latents,
+            video_length,
+            height,
+            width,
+            text_embeddings.dtype,
+            device,
+            generator,
+            latents,
+        )
+        latents_dtype = latents.dtype
+
+        # Prepare extra step kwargs.
+        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+        # Denoising loop
+        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+        with self.progress_bar(total=num_inference_steps) as progress_bar:
+            for i, t in enumerate(timesteps):
+                # expand the latents if we are doing classifier-free guidance
+                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+                if uncond_embeddings_pre is not None:
+                    if multi:
+                        text_embeddings[:video_length] = uncond_embeddings_pre[i]
+                    else:
+                        text_embeddings[0] = uncond_embeddings_pre[i]
+
+                # predict the noise residual
+                noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample.to(dtype=latents_dtype)
+
+                # perform guidance
+                if do_classifier_free_guidance:
+                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+                    if fast:  # not using classifier-free guidance
+                        # after tuning, DDIM inversion without CFG can recover the video most of the time;
+                        # use null-text to pursue more stable performance
+                        noise_pred[0] = noise_pred_text[0]
+
+                # compute the previous noisy sample x_t -> x_t-1
+                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+                weight_type = latents.dtype
+
+                if controller is not None:
+                    latents = controller.step_callback(latents).to(device, dtype=weight_type)
+
+                # call the callback, if provided
+                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+                    progress_bar.update()
+                    if callback is not None and i % callback_steps == 0:
+                        callback(i, t, latents)
+
+        # Post-processing
+        video = self.decode_latents(latents)
+
+        # Convert to tensor
+        if output_type == "tensor":
+            video = torch.from_numpy(video)
+
+        if not return_dict:
+            return video
+
+        return TuneAVideoPipelineOutput(videos=video)
\ No newline at end of file
diff --git a/Video-P2P/tuneavideo/util.py b/Video-P2P/tuneavideo/util.py
new file mode 100644
index 0000000000000000000000000000000000000000..5f488e0071ec378730a1ccda89c13fbd6cc1a5b5
--- /dev/null
+++ b/Video-P2P/tuneavideo/util.py
@@ -0,0 +1,86 @@
+# 
https://github.com/showlab/Tune-A-Video/blob/main/tuneavideo/util.py + +import os +import imageio +import numpy as np +from typing import Union + +import torch +import torchvision + +from tqdm import tqdm +from einops import rearrange + + +def save_videos_grid(videos: torch.Tensor, path: str, rescale=False, n_rows=4, fps=8): + videos = rearrange(videos, "b c t h w -> t b c h w") + outputs = [] + for x in videos: + x = torchvision.utils.make_grid(x, nrow=n_rows) + x = x.transpose(0, 1).transpose(1, 2).squeeze(-1) + if rescale: + x = (x + 1.0) / 2.0 # -1,1 -> 0,1 + x = (x * 255).numpy().astype(np.uint8) + outputs.append(x) + + os.makedirs(os.path.dirname(path), exist_ok=True) + imageio.mimsave(path, outputs, fps=fps) + + +# DDIM Inversion +@torch.no_grad() +def init_prompt(prompt, pipeline): + uncond_input = pipeline.tokenizer( + [""], padding="max_length", max_length=pipeline.tokenizer.model_max_length, + return_tensors="pt" + ) + uncond_embeddings = pipeline.text_encoder(uncond_input.input_ids.to(pipeline.device))[0] + text_input = pipeline.tokenizer( + [prompt], + padding="max_length", + max_length=pipeline.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_embeddings = pipeline.text_encoder(text_input.input_ids.to(pipeline.device))[0] + context = torch.cat([uncond_embeddings, text_embeddings]) + + return context + + +def next_step(model_output: Union[torch.FloatTensor, np.ndarray], timestep: int, + sample: Union[torch.FloatTensor, np.ndarray], ddim_scheduler): + timestep, next_timestep = min( + timestep - ddim_scheduler.config.num_train_timesteps // ddim_scheduler.num_inference_steps, 999), timestep + alpha_prod_t = ddim_scheduler.alphas_cumprod[timestep] if timestep >= 0 else ddim_scheduler.final_alpha_cumprod + alpha_prod_t_next = ddim_scheduler.alphas_cumprod[next_timestep] + beta_prod_t = 1 - alpha_prod_t + next_original_sample = (sample - beta_prod_t ** 0.5 * model_output) / alpha_prod_t ** 0.5 + next_sample_direction = (1 - alpha_prod_t_next) ** 0.5 * model_output + next_sample = alpha_prod_t_next ** 0.5 * next_original_sample + next_sample_direction + return next_sample + + +def get_noise_pred_single(latents, t, context, unet): + noise_pred = unet(latents, t, encoder_hidden_states=context)["sample"] + return noise_pred + + +@torch.no_grad() +def ddim_loop(pipeline, ddim_scheduler, latent, num_inv_steps, prompt): + context = init_prompt(prompt, pipeline) + uncond_embeddings, cond_embeddings = context.chunk(2) + all_latent = [latent] + latent = latent.clone().detach() + for i in tqdm(range(num_inv_steps)): + t = ddim_scheduler.timesteps[len(ddim_scheduler.timesteps) - i - 1] + noise_pred = get_noise_pred_single(latent, t, cond_embeddings, pipeline.unet) + latent = next_step(noise_pred, t, latent, ddim_scheduler) + all_latent.append(latent) + return all_latent + + +@torch.no_grad() +def ddim_inversion(pipeline, ddim_scheduler, video_latent, num_inv_steps, prompt=""): + ddim_latents = ddim_loop(pipeline, ddim_scheduler, video_latent, num_inv_steps, prompt) + return ddim_latents diff --git a/app.py b/app.py index 82e03a5dc174394511d23e071622c24224975cce..6f2b441e38b276f332ad3c63b41c7b3c4f8b6dbd 100755 --- a/app.py +++ b/app.py @@ -70,13 +70,13 @@ with gr.Blocks(css='style.css') as demo: with gr.Tabs(): with gr.TabItem('Train'): create_training_demo(trainer, pipe) - with gr.TabItem('Run'): - create_inference_demo(pipe, HF_TOKEN) - with gr.TabItem('Upload'): - gr.Markdown(''' - - You can use this tab to upload models later if you choose not to upload models 
in training time or if upload in training time failed. - ''') - create_upload_demo(HF_TOKEN) + # with gr.TabItem('Run'): + # create_inference_demo(pipe, HF_TOKEN) + # with gr.TabItem('Upload'): + # gr.Markdown(''' + # - You can use this tab to upload models later if you choose not to upload models in training time or if upload in training time failed. + # ''') + # create_upload_demo(HF_TOKEN) if not HF_TOKEN: show_warning(HF_TOKEN_NOT_SPECIFIED_WARNING) diff --git a/app_inference.py b/app_inference.py old mode 100755 new mode 100644 diff --git a/app_upload.py b/app_upload.py old mode 100755 new mode 100644 index f672f555512b456d95d8f674fa832b1c9bf34309..f839c0c33c1ab8a43bc269ede0af920e61ef76cc --- a/app_upload.py +++ b/app_upload.py @@ -75,7 +75,7 @@ def create_upload_demo(hf_token: str | None) -> gr.Blocks: visible=False if hf_token else True) upload_button = gr.Button('Upload') gr.Markdown(f''' - - You can upload your trained model to your personal profile (i.e. https://huggingface.co/{{your_username}}/{{model_name}}) or to the public [Tune-A-Video Library](https://huggingface.co/{MODEL_LIBRARY_ORG_NAME}) (i.e. https://huggingface.co/{MODEL_LIBRARY_ORG_NAME}/{{model_name}}). + - You can upload your trained model to your personal profile (i.e. https://huggingface.co/{{your_username}}/{{model_name}}) or to the public [Video-P2P Library](https://huggingface.co/{MODEL_LIBRARY_ORG_NAME}) (i.e. https://huggingface.co/{MODEL_LIBRARY_ORG_NAME}/{{model_name}}). ''') with gr.Box(): gr.Markdown('Output message') diff --git a/constants.py b/constants.py index 9fb6e1f7ea852e729e950861e4e5beb4e1e38b75..ae35fb9773b4d6c926ca7d3535681758d521c872 100644 --- a/constants.py +++ b/constants.py @@ -3,8 +3,8 @@ import enum class UploadTarget(enum.Enum): PERSONAL_PROFILE = 'Personal Profile' - MODEL_LIBRARY = 'Tune-A-Video Library' + MODEL_LIBRARY = 'Video-P2P Library' -MODEL_LIBRARY_ORG_NAME = 'Tune-A-Video-library' -SAMPLE_MODEL_REPO = 'Tune-A-Video-library/a-man-is-surfing' +MODEL_LIBRARY_ORG_NAME = 'Video-P2P-library' +SAMPLE_MODEL_REPO = 'Video-P2P-library/a-man-is-surfing' diff --git a/trainer.py b/trainer.py index 5d61d35b6f32d50c0770898e7314d2d68565f793..57de6edd0fdcbd8e03d59375123333ba0ad35192 100644 --- a/trainer.py +++ b/trainer.py @@ -103,7 +103,7 @@ class Trainer: self.join_model_library_org( self.hf_token if self.hf_token else input_token) - config = OmegaConf.load('Tune-A-Video/configs/man-surfing.yaml') + config = OmegaConf.load('Video-P2P/configs/man-surfing-tune.yaml') config.pretrained_model_path = self.download_base_model(base_model) config.output_dir = output_dir.as_posix() config.train_data.video_path = training_video.name # type: ignore @@ -133,7 +133,7 @@ class Trainer: with open(config_path, 'w') as f: OmegaConf.save(config, f) - command = f'accelerate launch Tune-A-Video/train_tuneavideo.py --config {config_path}' + command = f'accelerate launch Video-P2P/run_tuning.py --config {config_path}' subprocess.run(shlex.split(command)) save_model_card(save_dir=output_dir, base_model=base_model,