Sylvain Filoni committed on
Commit
8502051
1 Parent(s): 869bd9e
.DS_Store ADDED
Binary file (8.2 kB).
 
GITHUB_README.md ADDED
@@ -0,0 +1,159 @@
1
+ # ControlVideo
2
+
3
+ Official PyTorch implementation of "ControlVideo: Training-free Controllable Text-to-Video Generation"
4
+
5
+ [![arXiv](https://img.shields.io/badge/arXiv-2305.13077-b31b1b.svg)](https://arxiv.org/abs/2305.13077)
6
+ ![visitors](https://visitor-badge.laobi.icu/badge?page_id=YBYBZhang/ControlVideo)
7
+ [![Replicate](https://replicate.com/cjwbw/controlvideo/badge)](https://replicate.com/cjwbw/controlvideo)
8
+
9
+ <p align="center">
10
+ <img src="assets/overview.png" width="1080px"/>
11
+ <br>
12
+ <em>ControlVideo adapts ControlNet to the video counterpart without any finetuning, aiming to directly inherit its high-quality and consistent generation.</em>
13
+ </p>
14
+
15
+ ## News
16
+
17
+ * [05/28/2023] Thanks to [chenxwh](https://github.com/chenxwh) for adding a [Replicate demo](https://replicate.com/cjwbw/controlvideo)!
18
+ * [05/25/2023] Code [ControlVideo](https://github.com/YBYBZhang/ControlVideo/) released!
19
+ * [05/23/2023] Paper [ControlVideo](https://arxiv.org/abs/2305.13077) released!
20
+
21
+ ## Setup
22
+
23
+ ### 1. Download Weights
24
+ All pre-trained weights should be downloaded to the `checkpoints/` directory, including [Stable Diffusion v1.5](https://huggingface.co/runwayml/stable-diffusion-v1-5) and ControlNet conditioned on [canny edges](https://huggingface.co/lllyasviel/sd-controlnet-canny), [depth maps](https://huggingface.co/lllyasviel/sd-controlnet-depth), and [human poses](https://huggingface.co/lllyasviel/sd-controlnet-openpose).
25
+ `flownet.pkl` contains the weights of [RIFE](https://github.com/megvii-research/ECCV2022-RIFE).
26
+ The final file tree should look like this:
27
+
28
+ ```none
29
+ checkpoints
30
+ ├── stable-diffusion-v1-5
31
+ ├── sd-controlnet-canny
32
+ ├── sd-controlnet-depth
33
+ ├── sd-controlnet-openpose
34
+ ├── flownet.pkl
35
+ ```
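+
+ One way to fetch these weights is with `git lfs` (a sketch under the assumption that `git lfs` is installed; `flownet.pkl` must be obtained separately from the RIFE repository):
+
+ ```shell
+ git lfs install
+ # clone each pre-trained model into checkpoints/
+ git clone https://huggingface.co/runwayml/stable-diffusion-v1-5 checkpoints/stable-diffusion-v1-5
+ git clone https://huggingface.co/lllyasviel/sd-controlnet-canny checkpoints/sd-controlnet-canny
+ git clone https://huggingface.co/lllyasviel/sd-controlnet-depth checkpoints/sd-controlnet-depth
+ git clone https://huggingface.co/lllyasviel/sd-controlnet-openpose checkpoints/sd-controlnet-openpose
+ # flownet.pkl: download from the RIFE repository and place it at checkpoints/flownet.pkl
+ ```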
36
+ ### 2. Requirements
37
+
38
+ ```shell
39
+ conda create -n controlvideo python=3.10
40
+ conda activate controlvideo
41
+ pip install -r requirements.txt
42
+ ```
43
+ `xformers` is recommended to save memory and reduce running time.
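+
+ A minimal install sketch (assuming a prebuilt `xformers` wheel matching your PyTorch/CUDA version is available):
+
+ ```shell
+ pip install xformers
+ ```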
44
+
45
+ ## Inference
46
+
47
+ To perform text-to-video generation, run the following command (also provided in `inference.sh`):
48
+ ```bash
49
+ python inference.py \
50
+ --prompt "A striking mallard floats effortlessly on the sparkling pond." \
51
+ --condition "depth" \
52
+ --video_path "data/mallard-water.mp4" \
53
+ --output_path "outputs/" \
54
+ --video_length 15 \
55
+ --smoother_steps 19 20 \
56
+ --width 512 \
57
+ --height 512 \
58
+ # --is_long_video
59
+ ```
60
+ where `--video_length` is the length of the synthesized video, `--condition` specifies the type of structure sequence,
61
+ `--smoother_steps` specifies the timesteps at which the interleaved-frame smoother is applied, and `--is_long_video` enables efficient long-video synthesis.
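+
+ For longer clips, a hedged example: increase `--video_length` and enable the hierarchical sampler (the frame count below is a placeholder, not a recommended setting):
+
+ ```bash
+ python inference.py \
+     --prompt "A striking mallard floats effortlessly on the sparkling pond." \
+     --condition "depth" \
+     --video_path "data/mallard-water.mp4" \
+     --output_path "outputs/" \
+     --video_length 100 \
+     --smoother_steps 19 20 \
+     --width 512 \
+     --height 512 \
+     --is_long_video
+ ```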
62
+
63
+ ## Visualizations
64
+
65
+ ### ControlVideo on depth maps
66
+
67
+ <table class="center">
68
+ <tr>
69
+ <td width=30% align="center"><img src="assets/depth/A_charming_flamingo_gracefully_wanders_in_the_calm_and_serene_water,_its_delicate_neck_curving_into_an_elegant_shape..gif" raw=true></td>
70
+ <td width=30% align="center"><img src="assets/depth/A_striking_mallard_floats_effortlessly_on_the_sparkling_pond..gif" raw=true></td>
71
+ <td width=30% align="center"><img src="assets/depth/A_gigantic_yellow_jeep_slowly_turns_on_a_wide,_smooth_road_in_the_city..gif" raw=true></td>
72
+ </tr>
73
+ <tr>
74
+ <td width=30% align="center">"A charming flamingo gracefully wanders in the calm and serene water, its delicate neck curving into an elegant shape."</td>
75
+ <td width=30% align="center">"A striking mallard floats effortlessly on the sparkling pond."</td>
76
+ <td width=30% align="center">"A gigantic yellow jeep slowly turns on a wide, smooth road in the city."</td>
77
+ </tr>
78
+ <tr>
79
+ <td width=30% align="center"><img src="assets/depth/A_sleek_boat_glides_effortlessly_through_the_shimmering_river,_van_gogh_style..gif" raw=true></td>
80
+ <td width=30% align="center"><img src="assets/depth/A_majestic_sailing_boat_cruises_along_the_vast,_azure_sea..gif" raw=true></td>
81
+ <td width=30% align="center"><img src="assets/depth/A_contented_cow_ambles_across_the_dewy,_verdant_pasture..gif" raw=true></td>
82
+ </tr>
83
+ <tr>
84
+ <td width=30% align="center">"A sleek boat glides effortlessly through the shimmering river, van gogh style."</td>
85
+ <td width=30% align="center">"A majestic sailing boat cruises along the vast, azure sea."</td>
86
+ <td width=30% align="center">"A contented cow ambles across the dewy, verdant pasture."</td>
87
+ </tr>
88
+ </table>
89
+
90
+ ### ControlVideo on canny edges
91
+
92
+ <table class="center">
93
+ <tr>
94
+ <td width=30% align="center"><img src="assets/canny/A_young_man_riding_a_sleek,_black_motorbike_through_the_winding_mountain_roads..gif" raw=true></td>
95
+ <td width=30% align="center"><img src="assets/canny/A_white_swan_moving_on_the_lake,_cartoon_style..gif" raw=true></td>
96
+ <td width=30% align="center"><img src="assets/canny/A_dusty_old_jeep_was_making_its_way_down_the_winding_forest_road,_creaking_and_groaning_with_each_bump_and_turn..gif" raw=true></td>
97
+ </tr>
98
+ <tr>
99
+ <td width=30% align="center">"A young man riding a sleek, black motorbike through the winding mountain roads."</td>
100
+ <td width=30% align="center">"A white swan moving on the lake, cartoon style."</td>
101
+ <td width=30% align="center">"A dusty old jeep was making its way down the winding forest road, creaking and groaning with each bump and turn."</td>
102
+ </tr>
103
+ <tr>
104
+ <td width=30% align="center"><img src="assets/canny/A_shiny_red_jeep_smoothly_turns_on_a_narrow,_winding_road_in_the_mountains..gif" raw=true></td>
105
+ <td width=30% align="center"><img src="assets/canny/A_majestic_camel_gracefully_strides_across_the_scorching_desert_sands..gif" raw=true></td>
106
+ <td width=30% align="center"><img src="assets/canny/A_fit_man_is_leisurely_hiking_through_a_lush_and_verdant_forest..gif" raw=true></td>
107
+ </tr>
108
+ <tr>
109
+ <td width=30% align="center">"A shiny red jeep smoothly turns on a narrow, winding road in the mountains."</td>
110
+ <td width=30% align="center">"A majestic camel gracefully strides across the scorching desert sands."</td>
111
+ <td width=30% align="center">"A fit man is leisurely hiking through a lush and verdant forest."</td>
112
+ </tr>
113
+ </table>
114
+
115
+
116
+ ### ControlVideo on human poses
117
+
118
+ <table class="center">
119
+ <tr>
120
+ <td width=25% align="center"><img src="assets/pose/James_bond_moonwalk_on_the_beach,_animation_style.gif" raw=true></td>
121
+ <td width=25% align="center"><img src="assets/pose/Goku_in_a_mountain_range,_surreal_style..gif" raw=true></td>
122
+ <td width=25% align="center"><img src="assets/pose/Hulk_is_jumping_on_the_street,_cartoon_style.gif" raw=true></td>
123
+ <td width=25% align="center"><img src="assets/pose/A_robot_dances_on_a_road,_animation_style.gif" raw=true></td>
124
+ </tr>
125
+ <tr>
126
+ <td width=25% align="center">"James bond moonwalk on the beach, animation style."</td>
127
+ <td width=25% align="center">"Goku in a mountain range, surreal style."</td>
128
+ <td width=25% align="center">"Hulk is jumping on the street, cartoon style."</td>
129
+ <td width=25% align="center">"A robot dances on a road, animation style."</td>
130
+ </tr></table>
131
+
132
+ ### Long video generation
133
+
134
+ <table class="center">
135
+ <tr>
136
+ <td width=60% align="center"><img src="assets/long/A_steamship_on_the_ocean,_at_sunset,_sketch_style.gif" raw=true></td>
137
+ <td width=40% align="center"><img src="assets/long/Hulk_is_dancing_on_the_beach,_cartoon_style.gif" raw=true></td>
138
+ </tr>
139
+ <tr>
140
+ <td width=60% align="center">"A steamship on the ocean, at sunset, sketch style."</td>
141
+ <td width=40% align="center">"Hulk is dancing on the beach, cartoon style."</td>
142
+ </tr>
143
+ </table>
144
+
145
+ ## Citation
146
+ If you make use of our work, please cite our paper.
147
+ ```bibtex
148
+ @article{zhang2023controlvideo,
149
+ title={ControlVideo: Training-free Controllable Text-to-Video Generation},
150
+ author={Zhang, Yabo and Wei, Yuxiang and Jiang, Dongsheng and Zhang, Xiaopeng and Zuo, Wangmeng and Tian, Qi},
151
+ journal={arXiv preprint arXiv:2305.13077},
152
+ year={2023}
153
+ }
154
+ ```
155
+
156
+ ## Acknowledgement
157
+ This repository borrows heavily from [Diffusers](https://github.com/huggingface/diffusers), [ControlNet](https://github.com/lllyasviel/ControlNet), [Tune-A-Video](https://github.com/showlab/Tune-A-Video), and [RIFE](https://github.com/megvii-research/ECCV2022-RIFE).
158
+
159
+ There are also many other interesting works on video generation: [Tune-A-Video](https://github.com/showlab/Tune-A-Video), [Text2Video-Zero](https://github.com/Picsart-AI-Research/Text2Video-Zero), [Follow-Your-Pose](https://github.com/mayuelala/FollowYourPose), [Control-A-Video](https://github.com/Weifeng-Chen/control-a-video), among others.
LICENSE ADDED
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2023 YaboZhang
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
assets/.DS_Store ADDED
Binary file (6.15 kB).
 
cog.yaml ADDED
@@ -0,0 +1,53 @@
1
+ # Configuration for Cog ⚙️
2
+ # Reference: https://github.com/replicate/cog/blob/main/docs/yaml.md
3
+
4
+ build:
5
+ gpu: true
6
+ cuda: "11.6"
7
+ system_packages:
8
+ - "libgl1-mesa-glx"
9
+ - "libglib2.0-0"
10
+ python_version: "3.8"
11
+ python_packages:
12
+ - "accelerate==0.17.1"
13
+ - "addict==2.4.0"
14
+ - "basicsr==1.4.2"
15
+ - "bitsandbytes==0.35.4"
16
+ - "clip==0.2.0"
17
+ - "cmake==3.25.2"
18
+ - "controlnet-aux==0.0.4"
19
+ - "decord==0.6.0"
20
+ - "deepspeed==0.8.0"
21
+ - "diffusers==0.14.0"
22
+ - "easydict==1.10"
23
+ - "einops==0.6.0"
24
+ - "ffmpy==0.3.0"
25
+ - "ftfy==6.1.1"
26
+ - "imageio==2.25.1"
27
+ - "imageio-ffmpeg==0.4.8"
28
+ - "moviepy==1.0.3"
29
+ - "numpy==1.24.2"
30
+ - "omegaconf==2.3.0"
31
+ - "opencv-python==4.7.0.68"
32
+ - "pandas==1.5.3"
33
+ - "pillow==9.4.0"
34
+ - "scikit-image==0.19.3"
35
+ - "scipy==1.10.1"
36
+ - "tensorboard==2.12.0"
37
+ - "tensorboard-data-server==0.7.0"
38
+ - "tensorboard-plugin-wit==1.8.1"
39
+ - "termcolor==2.2.0"
40
+ - "thinc==8.1.10"
41
+ - "timm==0.6.12"
42
+ - "tokenizers==0.13.2"
43
+ - "torch==1.13.1"
44
+ - "torchvision==0.14.1"
45
+ - "tqdm==4.64.1"
46
+ - "transformers==4.26.1"
47
+ - "wandb==0.13.10"
48
+ - "xformers==0.0.16"
49
+ - "positional_encodings==6.0.1"
50
+ - "mediapipe==0.10.0"
51
+ - "triton==2.0.0.post1"
52
+
53
+ predict: "predict.py:Predictor"
inference.py ADDED
@@ -0,0 +1,133 @@
1
+ import os
2
+ import numpy as np
3
+ import argparse
4
+ import imageio
5
+ import torch
6
+
7
+ from einops import rearrange
8
+ from diffusers import DDIMScheduler, AutoencoderKL
9
+ from transformers import CLIPTextModel, CLIPTokenizer
10
+ # from annotator.canny import CannyDetector
11
+ # from annotator.openpose import OpenposeDetector
12
+ # from annotator.midas import MidasDetector
13
+ # import sys
14
+ # sys.path.insert(0, ".")
15
+ from huggingface_hub import hf_hub_download
16
+ import controlnet_aux
17
+ from controlnet_aux import OpenposeDetector, CannyDetector, MidasDetector
18
+ from controlnet_aux.open_pose.body import Body
19
+
20
+ from models.pipeline_controlvideo import ControlVideoPipeline
21
+ from models.util import save_videos_grid, read_video, get_annotation
22
+ from models.unet import UNet3DConditionModel
23
+ from models.controlnet import ControlNetModel3D
24
+ from models.RIFE.IFNet_HDv3 import IFNet
25
+
26
+
27
+ device = "cuda"
28
+ sd_path = "checkpoints/stable-diffusion-v1-5"
29
+ inter_path = "checkpoints/flownet.pkl"
30
+ controlnet_dict = {
31
+ "pose": "checkpoints/sd-controlnet-openpose",
32
+ "depth": "checkpoints/sd-controlnet-depth",
33
+ "canny": "checkpoints/sd-controlnet-canny",
34
+ }
35
+
36
+ controlnet_parser_dict = {
37
+ "pose": OpenposeDetector,
38
+ "depth": MidasDetector,
39
+ "canny": CannyDetector,
40
+ }
41
+
42
+ POS_PROMPT = " ,best quality, extremely detailed, HD, ultra-realistic, 8K, HQ, masterpiece, trending on artstation, art, smooth"
43
+ NEG_PROMPT = "longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, deformed body, bloated, ugly, unrealistic"
44
+
45
+
46
+
47
+ def get_args():
48
+ parser = argparse.ArgumentParser()
49
+ parser.add_argument("--prompt", type=str, required=True, help="Text description of target video")
50
+ parser.add_argument("--video_path", type=str, required=True, help="Path to a source video")
51
+ parser.add_argument("--output_path", type=str, default="./outputs", help="Directory of output")
52
+ parser.add_argument("--condition", type=str, default="depth", help="Condition of structure sequence")
53
+ parser.add_argument("--video_length", type=int, default=15, help="Length of synthesized video")
54
+ parser.add_argument("--height", type=int, default=512, help="Height of synthesized video, and should be a multiple of 32")
55
+ parser.add_argument("--width", type=int, default=512, help="Width of synthesized video, and should be a multiple of 32")
56
+ parser.add_argument("--smoother_steps", nargs='+', default=[19, 20], type=int, help="Timesteps at which using interleaved-frame smoother")
57
+ parser.add_argument("--is_long_video", action='store_true', help="Whether to use hierarchical sampler to produce long video")
58
+ parser.add_argument("--seed", type=int, default=42, help="Random seed of generator")
59
+
60
+ args = parser.parse_args()
61
+ return args
62
+
63
+ if __name__ == "__main__":
64
+ args = get_args()
65
+ os.makedirs(args.output_path, exist_ok=True)
66
+
67
+ # Height and width should be a multiple of 32
68
+ args.height = (args.height // 32) * 32
69
+ args.width = (args.width // 32) * 32
70
+
71
+ if args.condition == "pose":
72
+ pretrained_model_or_path = "lllyasviel/ControlNet"
73
+ body_model_path = hf_hub_download(pretrained_model_or_path, "annotator/ckpts/body_pose_model.pth", cache_dir="checkpoints")
74
+ body_estimation = Body(body_model_path)
75
+ annotator = controlnet_parser_dict[args.condition](body_estimation)
76
+ else:
77
+ annotator = controlnet_parser_dict[args.condition]()
78
+
79
+ tokenizer = CLIPTokenizer.from_pretrained(sd_path, subfolder="tokenizer")
80
+ text_encoder = CLIPTextModel.from_pretrained(sd_path, subfolder="text_encoder").to(dtype=torch.float16)
81
+ vae = AutoencoderKL.from_pretrained(sd_path, subfolder="vae").to(dtype=torch.float16)
82
+ unet = UNet3DConditionModel.from_pretrained_2d(sd_path, subfolder="unet").to(dtype=torch.float16)
83
+ controlnet = ControlNetModel3D.from_pretrained_2d(controlnet_dict[args.condition]).to(dtype=torch.float16)
84
+ interpolater = IFNet(ckpt_path=inter_path).to(dtype=torch.float16)
85
+ scheduler=DDIMScheduler.from_pretrained(sd_path, subfolder="scheduler")
86
+
87
+ pipe = ControlVideoPipeline(
88
+ vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet,
89
+ controlnet=controlnet, interpolater=interpolater, scheduler=scheduler,
90
+ )
91
+ pipe.enable_vae_slicing()
92
+ pipe.enable_xformers_memory_efficient_attention()
93
+ pipe.to(device)
94
+
95
+ generator = torch.Generator(device="cuda")
96
+ generator.manual_seed(args.seed)
97
+
98
+ # Step 1. Read a video
99
+ video = read_video(video_path=args.video_path, video_length=args.video_length, width=args.width, height=args.height)
100
+
101
+ # Save source video
102
+ original_pixels = rearrange(video, "(b f) c h w -> b c f h w", b=1)
103
+ save_videos_grid(original_pixels, os.path.join(args.output_path, "source_video.mp4"), rescale=True)
104
+
105
+
106
+ # Step 2. Parse a video to conditional frames
107
+ pil_annotation = get_annotation(video, annotator)
108
+ if args.condition == "depth" and controlnet_aux.__version__ == '0.0.1':
109
+ pil_annotation = [pil_annot[0] for pil_annot in pil_annotation]
110
+
111
+ # Save condition video
112
+ video_cond = [np.array(p).astype(np.uint8) for p in pil_annotation]
113
+ imageio.mimsave(os.path.join(args.output_path, f"{args.condition}_condition.mp4"), video_cond, fps=8)
114
+
115
+ # Reduce memory (optional)
116
+ del annotator; torch.cuda.empty_cache()
117
+
118
+ # Step 3. Inference
119
+
120
+ if args.is_long_video:
121
+ window_size = int(np.sqrt(args.video_length))
122
+ sample = pipe.generate_long_video(args.prompt + POS_PROMPT, video_length=args.video_length, frames=pil_annotation,
123
+ num_inference_steps=50, smooth_steps=args.smoother_steps, window_size=window_size,
124
+ generator=generator, guidance_scale=12.5, negative_prompt=NEG_PROMPT,
125
+ width=args.width, height=args.height
126
+ ).videos
127
+ else:
128
+ sample = pipe(args.prompt + POS_PROMPT, video_length=args.video_length, frames=pil_annotation,
129
+ num_inference_steps=50, smooth_steps=args.smoother_steps,
130
+ generator=generator, guidance_scale=12.5, negative_prompt=NEG_PROMPT,
131
+ width=args.width, height=args.height
132
+ ).videos
133
+ save_videos_grid(sample, f"{args.output_path}/{args.prompt}.mp4")
inference.sh ADDED
@@ -0,0 +1,10 @@
1
+ python inference.py \
2
+ --prompt "A striking mallard floats effortlessly on the sparkling pond." \
3
+ --condition "depth" \
4
+ --video_path "data/mallard-water.mp4" \
5
+ --output_path "outputs/" \
6
+ --video_length 15 \
7
+ --smoother_steps 19 20 \
8
+ --width 512 \
9
+ --height 512 \
10
+ # --is_long_video
models/RIFE/IFNet_HDv3.py ADDED
@@ -0,0 +1,130 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from diffusers import ModelMixin
5
+
6
+ from .warplayer import warp
7
+
8
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
9
+
10
+ def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
11
+ return nn.Sequential(
12
+ nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride,
13
+ padding=padding, dilation=dilation, bias=True),
14
+ nn.PReLU(out_planes)
15
+ )
16
+
17
+ def conv_bn(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
18
+ return nn.Sequential(
19
+ nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride,
20
+ padding=padding, dilation=dilation, bias=False),
21
+ nn.BatchNorm2d(out_planes),
22
+ nn.PReLU(out_planes)
23
+ )
24
+
25
+ def convert(param):
26
+ return {
27
+ k.replace("module.", ""): v
28
+ for k, v in param.items()
29
+ if "module." in k
30
+ }
31
+
32
+ class IFBlock(nn.Module):
33
+ def __init__(self, in_planes, c=64):
34
+ super(IFBlock, self).__init__()
35
+ self.conv0 = nn.Sequential(
36
+ conv(in_planes, c//2, 3, 2, 1),
37
+ conv(c//2, c, 3, 2, 1),
38
+ )
39
+ self.convblock0 = nn.Sequential(
40
+ conv(c, c),
41
+ conv(c, c)
42
+ )
43
+ self.convblock1 = nn.Sequential(
44
+ conv(c, c),
45
+ conv(c, c)
46
+ )
47
+ self.convblock2 = nn.Sequential(
48
+ conv(c, c),
49
+ conv(c, c)
50
+ )
51
+ self.convblock3 = nn.Sequential(
52
+ conv(c, c),
53
+ conv(c, c)
54
+ )
55
+ self.conv1 = nn.Sequential(
56
+ nn.ConvTranspose2d(c, c//2, 4, 2, 1),
57
+ nn.PReLU(c//2),
58
+ nn.ConvTranspose2d(c//2, 4, 4, 2, 1),
59
+ )
60
+ self.conv2 = nn.Sequential(
61
+ nn.ConvTranspose2d(c, c//2, 4, 2, 1),
62
+ nn.PReLU(c//2),
63
+ nn.ConvTranspose2d(c//2, 1, 4, 2, 1),
64
+ )
65
+
66
+ def forward(self, x, flow, scale=1):
67
+ x = F.interpolate(x, scale_factor= 1. / scale, mode="bilinear", align_corners=False, recompute_scale_factor=False)
68
+ flow = F.interpolate(flow, scale_factor= 1. / scale, mode="bilinear", align_corners=False, recompute_scale_factor=False) * 1. / scale
69
+ feat = self.conv0(torch.cat((x, flow), 1))
70
+ feat = self.convblock0(feat) + feat
71
+ feat = self.convblock1(feat) + feat
72
+ feat = self.convblock2(feat) + feat
73
+ feat = self.convblock3(feat) + feat
74
+ flow = self.conv1(feat)
75
+ mask = self.conv2(feat)
76
+ flow = F.interpolate(flow, scale_factor=scale, mode="bilinear", align_corners=False, recompute_scale_factor=False) * scale
77
+ mask = F.interpolate(mask, scale_factor=scale, mode="bilinear", align_corners=False, recompute_scale_factor=False)
78
+ return flow, mask
79
+
80
+ class IFNet(ModelMixin):
81
+ def __init__(self, ckpt_path="checkpoints/flownet.pkl"):
82
+ super(IFNet, self).__init__()
83
+ self.block0 = IFBlock(7+4, c=90)
84
+ self.block1 = IFBlock(7+4, c=90)
85
+ self.block2 = IFBlock(7+4, c=90)
86
+ self.block_tea = IFBlock(10+4, c=90)
87
+ if ckpt_path is not None:
88
+ self.load_state_dict(convert(torch.load(ckpt_path, map_location ='cpu')))
89
+
90
+ def inference(self, img0, img1, scale=1.0):
91
+ imgs = torch.cat((img0, img1), 1)
92
+ scale_list = [4/scale, 2/scale, 1/scale]
93
+ flow, mask, merged = self.forward(imgs, scale_list)
94
+ return merged[2]
95
+
96
+ def forward(self, x, scale_list=[4, 2, 1], training=False):
97
+ if training == False:
98
+ channel = x.shape[1] // 2
99
+ img0 = x[:, :channel]
100
+ img1 = x[:, channel:]
101
+ flow_list = []
102
+ merged = []
103
+ mask_list = []
104
+ warped_img0 = img0
105
+ warped_img1 = img1
106
+ flow = (x[:, :4]).detach() * 0
107
+ mask = (x[:, :1]).detach() * 0
108
+ loss_cons = 0
109
+ block = [self.block0, self.block1, self.block2]
110
+ for i in range(3):
111
+ f0, m0 = block[i](torch.cat((warped_img0[:, :3], warped_img1[:, :3], mask), 1), flow, scale=scale_list[i])
112
+ f1, m1 = block[i](torch.cat((warped_img1[:, :3], warped_img0[:, :3], -mask), 1), torch.cat((flow[:, 2:4], flow[:, :2]), 1), scale=scale_list[i])
113
+ flow = flow + (f0 + torch.cat((f1[:, 2:4], f1[:, :2]), 1)) / 2
114
+ mask = mask + (m0 + (-m1)) / 2
115
+ mask_list.append(mask)
116
+ flow_list.append(flow)
117
+ warped_img0 = warp(img0, flow[:, :2])
118
+ warped_img1 = warp(img1, flow[:, 2:4])
119
+ merged.append((warped_img0, warped_img1))
120
+ '''
121
+ c0 = self.contextnet(img0, flow[:, :2])
122
+ c1 = self.contextnet(img1, flow[:, 2:4])
123
+ tmp = self.unet(img0, img1, warped_img0, warped_img1, mask, flow, c0, c1)
124
+ res = tmp[:, 1:4] * 2 - 1
125
+ '''
126
+ for i in range(3):
127
+ mask_list[i] = torch.sigmoid(mask_list[i])
128
+ merged[i] = merged[i][0] * mask_list[i] + merged[i][1] * (1 - mask_list[i])
129
+ # merged[i] = torch.clamp(merged[i] + res, 0, 1)
130
+ return flow_list, mask_list[2], merged
models/RIFE/warplayer.py ADDED
@@ -0,0 +1,22 @@
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
5
+ backwarp_tenGrid = {}
6
+
7
+
8
+ def warp(tenInput, tenFlow):
9
+ k = (str(tenFlow.device), str(tenFlow.size()))
10
+ if k not in backwarp_tenGrid:
11
+ tenHorizontal = torch.linspace(-1.0, 1.0, tenFlow.shape[3], device=device).view(
12
+ 1, 1, 1, tenFlow.shape[3]).expand(tenFlow.shape[0], -1, tenFlow.shape[2], -1)
13
+ tenVertical = torch.linspace(-1.0, 1.0, tenFlow.shape[2], device=device).view(
14
+ 1, 1, tenFlow.shape[2], 1).expand(tenFlow.shape[0], -1, -1, tenFlow.shape[3])
15
+ backwarp_tenGrid[k] = torch.cat(
16
+ [tenHorizontal, tenVertical], 1).to(device)
17
+
18
+ tenFlow = torch.cat([tenFlow[:, 0:1, :, :] / ((tenInput.shape[3] - 1.0) / 2.0),
19
+ tenFlow[:, 1:2, :, :] / ((tenInput.shape[2] - 1.0) / 2.0)], 1)
20
+
21
+ g = (backwarp_tenGrid[k] + tenFlow).permute(0, 2, 3, 1).to(dtype=tenInput.dtype)
22
+ return torch.nn.functional.grid_sample(input=tenInput, grid=g, mode='bilinear', padding_mode='border', align_corners=True)
models/attention.py ADDED
@@ -0,0 +1,478 @@
1
+ # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention.py
2
+
3
+ from dataclasses import dataclass
4
+ from typing import Optional, Callable
5
+ import math
6
+ import torch
7
+ import torch.nn.functional as F
8
+ from torch import nn
9
+ from positional_encodings.torch_encodings import PositionalEncoding2D
10
+
11
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
12
+ from diffusers import ModelMixin
13
+ from diffusers.utils import BaseOutput
14
+ from diffusers.utils.import_utils import is_xformers_available
15
+ from diffusers.models.attention import CrossAttention, FeedForward, AdaLayerNorm
16
+ from einops import rearrange, repeat
17
+
18
+
19
+ @dataclass
20
+ class Transformer3DModelOutput(BaseOutput):
21
+ sample: torch.FloatTensor
22
+
23
+
24
+ if is_xformers_available():
25
+ import xformers
26
+ import xformers.ops
27
+ else:
28
+ xformers = None
29
+
30
+
31
+ class Transformer3DModel(ModelMixin, ConfigMixin):
32
+ @register_to_config
33
+ def __init__(
34
+ self,
35
+ num_attention_heads: int = 16,
36
+ attention_head_dim: int = 88,
37
+ in_channels: Optional[int] = None,
38
+ num_layers: int = 1,
39
+ dropout: float = 0.0,
40
+ norm_num_groups: int = 32,
41
+ cross_attention_dim: Optional[int] = None,
42
+ attention_bias: bool = False,
43
+ activation_fn: str = "geglu",
44
+ num_embeds_ada_norm: Optional[int] = None,
45
+ use_linear_projection: bool = False,
46
+ only_cross_attention: bool = False,
47
+ upcast_attention: bool = False,
48
+ ):
49
+ super().__init__()
50
+ self.use_linear_projection = use_linear_projection
51
+ self.num_attention_heads = num_attention_heads
52
+ self.attention_head_dim = attention_head_dim
53
+ inner_dim = num_attention_heads * attention_head_dim
54
+
55
+ # Define input layers
56
+ self.in_channels = in_channels
57
+
58
+ self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
59
+ if use_linear_projection:
60
+ self.proj_in = nn.Linear(in_channels, inner_dim)
61
+ else:
62
+ self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
63
+
64
+ # Define transformers blocks
65
+ self.transformer_blocks = nn.ModuleList(
66
+ [
67
+ BasicTransformerBlock(
68
+ inner_dim,
69
+ num_attention_heads,
70
+ attention_head_dim,
71
+ dropout=dropout,
72
+ cross_attention_dim=cross_attention_dim,
73
+ activation_fn=activation_fn,
74
+ num_embeds_ada_norm=num_embeds_ada_norm,
75
+ attention_bias=attention_bias,
76
+ only_cross_attention=only_cross_attention,
77
+ upcast_attention=upcast_attention,
78
+ )
79
+ for d in range(num_layers)
80
+ ]
81
+ )
82
+
83
+ # 4. Define output layers
84
+ if use_linear_projection:
85
+ self.proj_out = nn.Linear(in_channels, inner_dim)
86
+ else:
87
+ self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
88
+
89
+ def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, return_dict: bool = True, \
90
+ inter_frame=False):
91
+ # Input
92
+ assert hidden_states.dim() == 5, f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}."
93
+ video_length = hidden_states.shape[2]
94
+ hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w")
95
+ encoder_hidden_states = repeat(encoder_hidden_states, 'b n c -> (b f) n c', f=video_length)
96
+
97
+ batch, channel, height, weight = hidden_states.shape
98
+ residual = hidden_states
99
+
100
+ hidden_states = self.norm(hidden_states)
101
+ if not self.use_linear_projection:
102
+ hidden_states = self.proj_in(hidden_states)
103
+ inner_dim = hidden_states.shape[1]
104
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
105
+ else:
106
+ inner_dim = hidden_states.shape[1]
107
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
108
+ hidden_states = self.proj_in(hidden_states)
109
+
110
+ # Blocks
111
+ for block in self.transformer_blocks:
112
+ hidden_states = block(
113
+ hidden_states,
114
+ encoder_hidden_states=encoder_hidden_states,
115
+ timestep=timestep,
116
+ video_length=video_length,
117
+ inter_frame=inter_frame
118
+ )
119
+
120
+ # Output
121
+ if not self.use_linear_projection:
122
+ hidden_states = (
123
+ hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous()
124
+ )
125
+ hidden_states = self.proj_out(hidden_states)
126
+ else:
127
+ hidden_states = self.proj_out(hidden_states)
128
+ hidden_states = (
129
+ hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous()
130
+ )
131
+
132
+ output = hidden_states + residual
133
+
134
+ output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length)
135
+ if not return_dict:
136
+ return (output,)
137
+
138
+ return Transformer3DModelOutput(sample=output)
139
+
140
+
141
+ class BasicTransformerBlock(nn.Module):
142
+ def __init__(
143
+ self,
144
+ dim: int,
145
+ num_attention_heads: int,
146
+ attention_head_dim: int,
147
+ dropout=0.0,
148
+ cross_attention_dim: Optional[int] = None,
149
+ activation_fn: str = "geglu",
150
+ num_embeds_ada_norm: Optional[int] = None,
151
+ attention_bias: bool = False,
152
+ only_cross_attention: bool = False,
153
+ upcast_attention: bool = False,
154
+ ):
155
+ super().__init__()
156
+ self.only_cross_attention = only_cross_attention
157
+ self.use_ada_layer_norm = num_embeds_ada_norm is not None
158
+
159
+ # Fully
160
+ self.attn1 = FullyFrameAttention(
161
+ query_dim=dim,
162
+ heads=num_attention_heads,
163
+ dim_head=attention_head_dim,
164
+ dropout=dropout,
165
+ bias=attention_bias,
166
+ cross_attention_dim=cross_attention_dim if only_cross_attention else None,
167
+ upcast_attention=upcast_attention,
168
+ )
169
+
170
+ self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
171
+
172
+ # Cross-Attn
173
+ if cross_attention_dim is not None:
174
+ self.attn2 = CrossAttention(
175
+ query_dim=dim,
176
+ cross_attention_dim=cross_attention_dim,
177
+ heads=num_attention_heads,
178
+ dim_head=attention_head_dim,
179
+ dropout=dropout,
180
+ bias=attention_bias,
181
+ upcast_attention=upcast_attention,
182
+ )
183
+ else:
184
+ self.attn2 = None
185
+
186
+ if cross_attention_dim is not None:
187
+ self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
188
+ else:
189
+ self.norm2 = None
190
+
191
+ # Feed-forward
192
+ self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn)
193
+ self.norm3 = nn.LayerNorm(dim)
194
+
195
+ def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None):
196
+ if not is_xformers_available():
197
+ print("Here is how to install it")
198
+ raise ModuleNotFoundError(
199
+ "Refer to https://github.com/facebookresearch/xformers for more information on how to install"
200
+ " xformers",
201
+ name="xformers",
202
+ )
203
+ elif not torch.cuda.is_available():
204
+ raise ValueError(
205
+ "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is only"
206
+ " available for GPU "
207
+ )
208
+ else:
209
+ try:
210
+ # Make sure we can run the memory efficient attention
211
+ _ = xformers.ops.memory_efficient_attention(
212
+ torch.randn((1, 2, 40), device="cuda"),
213
+ torch.randn((1, 2, 40), device="cuda"),
214
+ torch.randn((1, 2, 40), device="cuda"),
215
+ )
216
+ except Exception as e:
217
+ raise e
218
+ self.attn1._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
219
+ if self.attn2 is not None:
220
+ self.attn2._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
221
+
222
+ def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, attention_mask=None, video_length=None, \
223
+ inter_frame=False):
224
+ # SparseCausal-Attention
225
+ norm_hidden_states = (
226
+ self.norm1(hidden_states, timestep) if self.use_ada_layer_norm else self.norm1(hidden_states)
227
+ )
228
+
229
+ if self.only_cross_attention:
230
+ hidden_states = (
231
+ self.attn1(norm_hidden_states, encoder_hidden_states, attention_mask=attention_mask, inter_frame=inter_frame) + hidden_states
232
+ )
233
+ else:
234
+ hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask, video_length=video_length, inter_frame=inter_frame) + hidden_states
235
+
236
+ if self.attn2 is not None:
237
+ # Cross-Attention
238
+ norm_hidden_states = (
239
+ self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
240
+ )
241
+ hidden_states = (
242
+ self.attn2(
243
+ norm_hidden_states, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask
244
+ )
245
+ + hidden_states
246
+ )
247
+
248
+ # Feed-forward
249
+ hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states
250
+
251
+ return hidden_states
252
+
253
+ class FullyFrameAttention(nn.Module):
254
+ r"""
255
+ A cross attention layer.
256
+
257
+ Parameters:
258
+ query_dim (`int`): The number of channels in the query.
259
+ cross_attention_dim (`int`, *optional*):
260
+ The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`.
261
+ heads (`int`, *optional*, defaults to 8): The number of heads to use for multi-head attention.
262
+ dim_head (`int`, *optional*, defaults to 64): The number of channels in each head.
263
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
264
+ bias (`bool`, *optional*, defaults to False):
265
+ Set to `True` for the query, key, and value linear layers to contain a bias parameter.
266
+ """
267
+
268
+ def __init__(
269
+ self,
270
+ query_dim: int,
271
+ cross_attention_dim: Optional[int] = None,
272
+ heads: int = 8,
273
+ dim_head: int = 64,
274
+ dropout: float = 0.0,
275
+ bias=False,
276
+ upcast_attention: bool = False,
277
+ upcast_softmax: bool = False,
278
+ added_kv_proj_dim: Optional[int] = None,
279
+ norm_num_groups: Optional[int] = None,
280
+ ):
281
+ super().__init__()
282
+ inner_dim = dim_head * heads
283
+ cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
284
+ self.upcast_attention = upcast_attention
285
+ self.upcast_softmax = upcast_softmax
286
+
287
+ self.scale = dim_head**-0.5
288
+
289
+ self.heads = heads
290
+ # for slice_size > 0 the attention score computation
291
+ # is split across the batch axis to save memory
292
+ # You can set slice_size with `set_attention_slice`
293
+ self.sliceable_head_dim = heads
294
+ self._slice_size = None
295
+ self._use_memory_efficient_attention_xformers = False
296
+ self.added_kv_proj_dim = added_kv_proj_dim
297
+
298
+ if norm_num_groups is not None:
299
+ self.group_norm = nn.GroupNorm(num_channels=inner_dim, num_groups=norm_num_groups, eps=1e-5, affine=True)
300
+ else:
301
+ self.group_norm = None
302
+
303
+ self.to_q = nn.Linear(query_dim, inner_dim, bias=bias)
304
+ self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
305
+ self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
306
+
307
+ if self.added_kv_proj_dim is not None:
308
+ self.add_k_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim)
309
+ self.add_v_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim)
310
+
311
+ self.to_out = nn.ModuleList([])
312
+ self.to_out.append(nn.Linear(inner_dim, query_dim))
313
+ self.to_out.append(nn.Dropout(dropout))
314
+
315
+ def reshape_heads_to_batch_dim(self, tensor):
316
+ batch_size, seq_len, dim = tensor.shape
317
+ head_size = self.heads
318
+ tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
319
+ tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size)
320
+ return tensor
321
+
322
+ def reshape_batch_dim_to_heads(self, tensor):
323
+ batch_size, seq_len, dim = tensor.shape
324
+ head_size = self.heads
325
+ tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
326
+ tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
327
+ return tensor
328
+
329
+ def set_attention_slice(self, slice_size):
330
+ if slice_size is not None and slice_size > self.sliceable_head_dim:
331
+ raise ValueError(f"slice_size {slice_size} has to be smaller or equal to {self.sliceable_head_dim}.")
332
+
333
+ self._slice_size = slice_size
334
+
335
+ def _attention(self, query, key, value, attention_mask=None):
336
+ if self.upcast_attention:
337
+ query = query.float()
338
+ key = key.float()
339
+
340
+ attention_scores = torch.baddbmm(
341
+ torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device),
342
+ query,
343
+ key.transpose(-1, -2),
344
+ beta=0,
345
+ alpha=self.scale,
346
+ )
347
+ if attention_mask is not None:
348
+ attention_scores = attention_scores + attention_mask
349
+
350
+ if self.upcast_softmax:
351
+ attention_scores = attention_scores.float()
352
+
353
+ attention_probs = attention_scores.softmax(dim=-1)
354
+
355
+ # cast back to the original dtype
356
+ attention_probs = attention_probs.to(value.dtype)
357
+
358
+ # compute attention output
359
+ hidden_states = torch.bmm(attention_probs, value)
360
+
361
+ # reshape hidden_states
362
+ hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
363
+ return hidden_states
364
+
365
+ def _sliced_attention(self, query, key, value, sequence_length, dim, attention_mask):
366
+ batch_size_attention = query.shape[0]
367
+ hidden_states = torch.zeros(
368
+ (batch_size_attention, sequence_length, dim // self.heads), device=query.device, dtype=query.dtype
369
+ )
370
+ slice_size = self._slice_size if self._slice_size is not None else hidden_states.shape[0]
371
+ for i in range(hidden_states.shape[0] // slice_size):
372
+ start_idx = i * slice_size
373
+ end_idx = (i + 1) * slice_size
374
+
375
+ query_slice = query[start_idx:end_idx]
376
+ key_slice = key[start_idx:end_idx]
377
+
378
+ if self.upcast_attention:
379
+ query_slice = query_slice.float()
380
+ key_slice = key_slice.float()
381
+
382
+ attn_slice = torch.baddbmm(
383
+ torch.empty(slice_size, query.shape[1], key.shape[1], dtype=query_slice.dtype, device=query.device),
384
+ query_slice,
385
+ key_slice.transpose(-1, -2),
386
+ beta=0,
387
+ alpha=self.scale,
388
+ )
389
+
390
+ if attention_mask is not None:
391
+ attn_slice = attn_slice + attention_mask[start_idx:end_idx]
392
+
393
+ if self.upcast_softmax:
394
+ attn_slice = attn_slice.float()
395
+
396
+ attn_slice = attn_slice.softmax(dim=-1)
397
+
398
+ # cast back to the original dtype
399
+ attn_slice = attn_slice.to(value.dtype)
400
+ attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx])
401
+
402
+ hidden_states[start_idx:end_idx] = attn_slice
403
+
404
+ # reshape hidden_states
405
+ hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
406
+ return hidden_states
407
+
408
+ def _memory_efficient_attention_xformers(self, query, key, value, attention_mask):
409
+ # TODO attention_mask
410
+ query = query.contiguous()
411
+ key = key.contiguous()
412
+ value = value.contiguous()
413
+ hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask)
414
+ hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
415
+ return hidden_states
416
+
417
+ def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, video_length=None, inter_frame=False):
418
+ batch_size, sequence_length, _ = hidden_states.shape
419
+
420
+ encoder_hidden_states = encoder_hidden_states
421
+
422
+ if self.group_norm is not None:
423
+ hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
424
+
425
+ query = self.to_q(hidden_states) # (bf) x d(hw) x c
426
+ dim = query.shape[-1]
427
+
428
+ # All frames
429
+ query = rearrange(query, "(b f) d c -> b (f d) c", f=video_length)
430
+
431
+ query = self.reshape_heads_to_batch_dim(query)
432
+
433
+ if self.added_kv_proj_dim is not None:
434
+ raise NotImplementedError
435
+
436
+ encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
437
+ key = self.to_k(encoder_hidden_states)
438
+ value = self.to_v(encoder_hidden_states)
439
+
440
+ if inter_frame:
441
+ key = rearrange(key, "(b f) d c -> b f d c", f=video_length)[:, [0, -1]]
442
+ value = rearrange(value, "(b f) d c -> b f d c", f=video_length)[:, [0, -1]]
443
+ key = rearrange(key, "b f d c -> b (f d) c",)
444
+ value = rearrange(value, "b f d c -> b (f d) c")
445
+ else:
446
+ # All frames
447
+ key = rearrange(key, "(b f) d c -> b (f d) c", f=video_length)
448
+ value = rearrange(value, "(b f) d c -> b (f d) c", f=video_length)
449
+
450
+ key = self.reshape_heads_to_batch_dim(key)
451
+ value = self.reshape_heads_to_batch_dim(value)
452
+
453
+ if attention_mask is not None:
454
+ if attention_mask.shape[-1] != query.shape[1]:
455
+ target_length = query.shape[1]
456
+ attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
457
+ attention_mask = attention_mask.repeat_interleave(self.heads, dim=0)
458
+
459
+ # attention, what we cannot get enough of
460
+ if self._use_memory_efficient_attention_xformers:
461
+ hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask)
462
+ # Some versions of xformers return output in fp32, cast it back to the dtype of the input
463
+ hidden_states = hidden_states.to(query.dtype)
464
+ else:
465
+ if self._slice_size is None or query.shape[0] // self._slice_size == 1:
466
+ hidden_states = self._attention(query, key, value, attention_mask)
467
+ else:
468
+ hidden_states = self._sliced_attention(query, key, value, sequence_length, dim, attention_mask)
469
+
470
+ # linear proj
471
+ hidden_states = self.to_out[0](hidden_states)
472
+
473
+ # dropout
474
+ hidden_states = self.to_out[1](hidden_states)
475
+
476
+ # All frames
477
+ hidden_states = rearrange(hidden_states, "b (f d) c -> (b f) d c", f=video_length)
478
+ return hidden_states
models/controlnet.py ADDED
@@ -0,0 +1,605 @@
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from dataclasses import dataclass
15
+ from typing import Any, Dict, List, Optional, Tuple, Union
16
+ import os
17
+ import json
18
+
19
+ import torch
20
+ from torch import nn
21
+ from torch.nn import functional as F
22
+
23
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
24
+ from diffusers.utils import BaseOutput, logging
25
+ from diffusers.models.embeddings import TimestepEmbedding, Timesteps
26
+ from diffusers import ModelMixin
27
+ from .controlnet_unet_blocks import (
28
+ CrossAttnDownBlock3D,
29
+ DownBlock3D,
30
+ UNetMidBlock3DCrossAttn,
31
+ get_down_block,
32
+ )
33
+ from .resnet import InflatedConv3d
34
+
35
+ from diffusers.models.unet_2d_condition import UNet2DConditionModel
36
+ from diffusers.models.cross_attention import AttnProcessor
37
+
38
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
39
+
40
+
41
+ @dataclass
42
+ class ControlNetOutput(BaseOutput):
43
+ down_block_res_samples: Tuple[torch.Tensor]
44
+ mid_block_res_sample: torch.Tensor
45
+
46
+
47
+ class ControlNetConditioningEmbedding(nn.Module):
48
+ """
49
+ Quoting from https://arxiv.org/abs/2302.05543: "Stable Diffusion uses a pre-processing method similar to VQ-GAN
50
+ [11] to convert the entire dataset of 512 × 512 images into smaller 64 × 64 “latent images” for stabilized
51
+ training. This requires ControlNets to convert image-based conditions to 64 × 64 feature space to match the
52
+ convolution size. We use a tiny network E(·) of four convolution layers with 4 × 4 kernels and 2 × 2 strides
53
+ (activated by ReLU, channels are 16, 32, 64, 128, initialized with Gaussian weights, trained jointly with the full
54
+ model) to encode image-space conditions ... into feature maps ..."
55
+ """
56
+
57
+ def __init__(
58
+ self,
59
+ conditioning_embedding_channels: int,
60
+ conditioning_channels: int = 3,
61
+ block_out_channels: Tuple[int] = (16, 32, 96, 256),
62
+ ):
63
+ super().__init__()
64
+
65
+ self.conv_in = InflatedConv3d(conditioning_channels, block_out_channels[0], kernel_size=3, padding=1)
66
+
67
+ self.blocks = nn.ModuleList([])
68
+
69
+ for i in range(len(block_out_channels) - 1):
70
+ channel_in = block_out_channels[i]
71
+ channel_out = block_out_channels[i + 1]
72
+ self.blocks.append(InflatedConv3d(channel_in, channel_in, kernel_size=3, padding=1))
73
+ self.blocks.append(InflatedConv3d(channel_in, channel_out, kernel_size=3, padding=1, stride=2))
74
+
75
+ self.conv_out = zero_module(
76
+ InflatedConv3d(block_out_channels[-1], conditioning_embedding_channels, kernel_size=3, padding=1)
77
+ )
78
+
79
+ def forward(self, conditioning):
80
+ embedding = self.conv_in(conditioning)
81
+ embedding = F.silu(embedding)
82
+
83
+ for block in self.blocks:
84
+ embedding = block(embedding)
85
+ embedding = F.silu(embedding)
86
+
87
+ embedding = self.conv_out(embedding)
88
+
89
+ return embedding
90
+
91
+
92
+ class ControlNetModel3D(ModelMixin, ConfigMixin):
93
+ _supports_gradient_checkpointing = True
94
+
95
+ @register_to_config
96
+ def __init__(
97
+ self,
98
+ in_channels: int = 4,
99
+ flip_sin_to_cos: bool = True,
100
+ freq_shift: int = 0,
101
+ down_block_types: Tuple[str] = (
102
+ "CrossAttnDownBlock3D",
103
+ "CrossAttnDownBlock3D",
104
+ "CrossAttnDownBlock3D",
105
+ "DownBlock3D",
106
+ ),
107
+ only_cross_attention: Union[bool, Tuple[bool]] = False,
108
+ block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
109
+ layers_per_block: int = 2,
110
+ downsample_padding: int = 1,
111
+ mid_block_scale_factor: float = 1,
112
+ act_fn: str = "silu",
113
+ norm_num_groups: Optional[int] = 32,
114
+ norm_eps: float = 1e-5,
115
+ cross_attention_dim: int = 1280,
116
+ attention_head_dim: Union[int, Tuple[int]] = 8,
117
+ dual_cross_attention: bool = False,
118
+ use_linear_projection: bool = False,
119
+ class_embed_type: Optional[str] = None,
120
+ num_class_embeds: Optional[int] = None,
121
+ upcast_attention: bool = False,
122
+ resnet_time_scale_shift: str = "default",
123
+ projection_class_embeddings_input_dim: Optional[int] = None,
124
+ controlnet_conditioning_channel_order: str = "rgb",
125
+ conditioning_embedding_out_channels: Optional[Tuple[int]] = (16, 32, 96, 256),
126
+ ):
127
+ super().__init__()
128
+
129
+ # Check inputs
130
+ if len(block_out_channels) != len(down_block_types):
131
+ raise ValueError(
132
+ f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
133
+ )
134
+
135
+ if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
136
+ raise ValueError(
137
+ f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
138
+ )
139
+
140
+ if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types):
141
+ raise ValueError(
142
+ f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}."
143
+ )
144
+
145
+ # input
146
+ conv_in_kernel = 3
147
+ conv_in_padding = (conv_in_kernel - 1) // 2
148
+ self.conv_in = InflatedConv3d(
149
+ in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
150
+ )
151
+
152
+ # time
153
+ time_embed_dim = block_out_channels[0] * 4
154
+
155
+ self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
156
+ timestep_input_dim = block_out_channels[0]
157
+
158
+ self.time_embedding = TimestepEmbedding(
159
+ timestep_input_dim,
160
+ time_embed_dim,
161
+ act_fn=act_fn,
162
+ )
163
+
164
+ # class embedding
165
+ if class_embed_type is None and num_class_embeds is not None:
166
+ self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
167
+ elif class_embed_type == "timestep":
168
+ self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
169
+ elif class_embed_type == "identity":
170
+ self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
171
+ elif class_embed_type == "projection":
172
+ if projection_class_embeddings_input_dim is None:
173
+ raise ValueError(
174
+ "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
175
+ )
176
+ # The projection `class_embed_type` is the same as the timestep `class_embed_type` except
177
+ # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
178
+ # 2. it projects from an arbitrary input dimension.
179
+ #
180
+ # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
181
+ # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
182
+ # As a result, `TimestepEmbedding` can be passed arbitrary vectors.
183
+ self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
184
+ else:
185
+ self.class_embedding = None
186
+
187
+ # control net conditioning embedding
188
+ self.controlnet_cond_embedding = ControlNetConditioningEmbedding(
189
+ conditioning_embedding_channels=block_out_channels[0],
190
+ block_out_channels=conditioning_embedding_out_channels,
191
+ )
192
+
193
+ self.down_blocks = nn.ModuleList([])
194
+ self.controlnet_down_blocks = nn.ModuleList([])
195
+
196
+ if isinstance(only_cross_attention, bool):
197
+ only_cross_attention = [only_cross_attention] * len(down_block_types)
198
+
199
+ if isinstance(attention_head_dim, int):
200
+ attention_head_dim = (attention_head_dim,) * len(down_block_types)
201
+
202
+ # down
203
+ output_channel = block_out_channels[0]
204
+
205
+ controlnet_block = InflatedConv3d(output_channel, output_channel, kernel_size=1)
206
+ controlnet_block = zero_module(controlnet_block)
207
+ self.controlnet_down_blocks.append(controlnet_block)
208
+
209
+ for i, down_block_type in enumerate(down_block_types):
210
+ input_channel = output_channel
211
+ output_channel = block_out_channels[i]
212
+ is_final_block = i == len(block_out_channels) - 1
213
+
214
+ down_block = get_down_block(
215
+ down_block_type,
216
+ num_layers=layers_per_block,
217
+ in_channels=input_channel,
218
+ out_channels=output_channel,
219
+ temb_channels=time_embed_dim,
220
+ add_downsample=not is_final_block,
221
+ resnet_eps=norm_eps,
222
+ resnet_act_fn=act_fn,
223
+ resnet_groups=norm_num_groups,
224
+ cross_attention_dim=cross_attention_dim,
225
+ attn_num_head_channels=attention_head_dim[i],
226
+ downsample_padding=downsample_padding,
227
+ dual_cross_attention=dual_cross_attention,
228
+ use_linear_projection=use_linear_projection,
229
+ only_cross_attention=only_cross_attention[i],
230
+ upcast_attention=upcast_attention,
231
+ resnet_time_scale_shift=resnet_time_scale_shift,
232
+ )
233
+ self.down_blocks.append(down_block)
234
+
235
+ for _ in range(layers_per_block):
236
+ controlnet_block = InflatedConv3d(output_channel, output_channel, kernel_size=1)
237
+ controlnet_block = zero_module(controlnet_block)
238
+ self.controlnet_down_blocks.append(controlnet_block)
239
+
240
+ if not is_final_block:
241
+ controlnet_block = InflatedConv3d(output_channel, output_channel, kernel_size=1)
242
+ controlnet_block = zero_module(controlnet_block)
243
+ self.controlnet_down_blocks.append(controlnet_block)
244
+
245
+ # mid
246
+ mid_block_channel = block_out_channels[-1]
247
+
248
+ controlnet_block = InflatedConv3d(mid_block_channel, mid_block_channel, kernel_size=1)
249
+ controlnet_block = zero_module(controlnet_block)
250
+ self.controlnet_mid_block = controlnet_block
251
+
252
+ # mid
253
+ self.mid_block = UNetMidBlock3DCrossAttn(
254
+ in_channels=block_out_channels[-1],
255
+ temb_channels=time_embed_dim,
256
+ resnet_eps=norm_eps,
257
+ resnet_act_fn=act_fn,
258
+ output_scale_factor=mid_block_scale_factor,
259
+ resnet_time_scale_shift=resnet_time_scale_shift,
260
+ cross_attention_dim=cross_attention_dim,
261
+ attn_num_head_channels=attention_head_dim[-1],
262
+ resnet_groups=norm_num_groups,
263
+ dual_cross_attention=dual_cross_attention,
264
+ use_linear_projection=use_linear_projection,
265
+ upcast_attention=upcast_attention,
266
+ )
267
+
268
+ @classmethod
269
+ def from_unet(
270
+ cls,
271
+ unet: UNet2DConditionModel,
272
+ controlnet_conditioning_channel_order: str = "rgb",
273
+ conditioning_embedding_out_channels: Optional[Tuple[int]] = (16, 32, 96, 256),
274
+ load_weights_from_unet: bool = True,
275
+ ):
276
+ r"""
277
+ Instantiate a ControlNet from a `UNet2DConditionModel`.
278
+
279
+ Parameters:
280
+ unet (`UNet2DConditionModel`):
281
+ UNet model whose weights are copied to the ControlNet. Note that all configuration options are also
282
+ copied where applicable.
283
+ """
284
+ controlnet = cls(
285
+ in_channels=unet.config.in_channels,
286
+ flip_sin_to_cos=unet.config.flip_sin_to_cos,
287
+ freq_shift=unet.config.freq_shift,
288
+ down_block_types=unet.config.down_block_types,
289
+ only_cross_attention=unet.config.only_cross_attention,
290
+ block_out_channels=unet.config.block_out_channels,
291
+ layers_per_block=unet.config.layers_per_block,
292
+ downsample_padding=unet.config.downsample_padding,
293
+ mid_block_scale_factor=unet.config.mid_block_scale_factor,
294
+ act_fn=unet.config.act_fn,
295
+ norm_num_groups=unet.config.norm_num_groups,
296
+ norm_eps=unet.config.norm_eps,
297
+ cross_attention_dim=unet.config.cross_attention_dim,
298
+ attention_head_dim=unet.config.attention_head_dim,
299
+ use_linear_projection=unet.config.use_linear_projection,
300
+ class_embed_type=unet.config.class_embed_type,
301
+ num_class_embeds=unet.config.num_class_embeds,
302
+ upcast_attention=unet.config.upcast_attention,
303
+ resnet_time_scale_shift=unet.config.resnet_time_scale_shift,
304
+ projection_class_embeddings_input_dim=unet.config.projection_class_embeddings_input_dim,
305
+ controlnet_conditioning_channel_order=controlnet_conditioning_channel_order,
306
+ conditioning_embedding_out_channels=conditioning_embedding_out_channels,
307
+ )
308
+
309
+ if load_weights_from_unet:
310
+ controlnet.conv_in.load_state_dict(unet.conv_in.state_dict())
311
+ controlnet.time_proj.load_state_dict(unet.time_proj.state_dict())
312
+ controlnet.time_embedding.load_state_dict(unet.time_embedding.state_dict())
313
+
314
+ if controlnet.class_embedding:
315
+ controlnet.class_embedding.load_state_dict(unet.class_embedding.state_dict())
316
+
317
+ controlnet.down_blocks.load_state_dict(unet.down_blocks.state_dict())
318
+ controlnet.mid_block.load_state_dict(unet.mid_block.state_dict())
319
+
320
+ return controlnet
321
+
322
+ @property
323
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
324
+ def attn_processors(self) -> Dict[str, AttnProcessor]:
325
+ r"""
326
+ Returns:
327
+ `dict` of attention processors: A dictionary containing all attention processors used in the model,
328
+ indexed by their weight names.
329
+ """
330
+ # set recursively
331
+ processors = {}
332
+
333
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttnProcessor]):
334
+ if hasattr(module, "set_processor"):
335
+ processors[f"{name}.processor"] = module.processor
336
+
337
+ for sub_name, child in module.named_children():
338
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
339
+
340
+ return processors
341
+
342
+ for name, module in self.named_children():
343
+ fn_recursive_add_processors(name, module, processors)
344
+
345
+ return processors
346
+
347
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
348
+ def set_attn_processor(self, processor: Union[AttnProcessor, Dict[str, AttnProcessor]]):
349
+ r"""
350
+ Parameters:
351
+ processor (`dict` of `AttnProcessor` or `AttnProcessor`):
352
+ The instantiated processor class, or a dictionary of processor classes that will be set as the processor
353
+ of **all** `Attention` layers.
354
+ If `processor` is a dict, each key must define the path to the corresponding cross-attention processor. This is strongly recommended when setting trainable attention processors.
355
+
356
+ """
357
+ count = len(self.attn_processors.keys())
358
+
359
+ if isinstance(processor, dict) and len(processor) != count:
360
+ raise ValueError(
361
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
362
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
363
+ )
364
+
365
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
366
+ if hasattr(module, "set_processor"):
367
+ if not isinstance(processor, dict):
368
+ module.set_processor(processor)
369
+ else:
370
+ module.set_processor(processor.pop(f"{name}.processor"))
371
+
372
+ for sub_name, child in module.named_children():
373
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
374
+
375
+ for name, module in self.named_children():
376
+ fn_recursive_attn_processor(name, module, processor)
377
+
378
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attention_slice
379
+ def set_attention_slice(self, slice_size):
380
+ r"""
381
+ Enable sliced attention computation.
382
+
383
+ When this option is enabled, the attention module will split the input tensor in slices, to compute attention
384
+ in several steps. This is useful to save some memory in exchange for a small speed decrease.
385
+
386
+ Args:
387
+ slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
388
+ When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
389
+ `"max"`, maximum amount of memory will be saved by running only one slice at a time. If a number is
390
+ provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
391
+ must be a multiple of `slice_size`.
392
+ """
393
+ sliceable_head_dims = []
394
+
395
+ def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
396
+ if hasattr(module, "set_attention_slice"):
397
+ sliceable_head_dims.append(module.sliceable_head_dim)
398
+
399
+ for child in module.children():
400
+ fn_recursive_retrieve_sliceable_dims(child)
401
+
402
+ # retrieve number of attention layers
403
+ for module in self.children():
404
+ fn_recursive_retrieve_sliceable_dims(module)
405
+
406
+ num_sliceable_layers = len(sliceable_head_dims)
407
+
408
+ if slice_size == "auto":
409
+ # half the attention head size is usually a good trade-off between
410
+ # speed and memory
411
+ slice_size = [dim // 2 for dim in sliceable_head_dims]
412
+ elif slice_size == "max":
413
+ # make smallest slice possible
414
+ slice_size = num_sliceable_layers * [1]
415
+
416
+ slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
417
+
418
+ if len(slice_size) != len(sliceable_head_dims):
419
+ raise ValueError(
420
+ f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
421
+ f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
422
+ )
423
+
424
+ for i in range(len(slice_size)):
425
+ size = slice_size[i]
426
+ dim = sliceable_head_dims[i]
427
+ if size is not None and size > dim:
428
+ raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
429
+
430
+ # Recursively walk through all the children.
431
+ # Any children which exposes the set_attention_slice method
432
+ # gets the message
433
+ def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
434
+ if hasattr(module, "set_attention_slice"):
435
+ module.set_attention_slice(slice_size.pop())
436
+
437
+ for child in module.children():
438
+ fn_recursive_set_attention_slice(child, slice_size)
439
+
440
+ reversed_slice_size = list(reversed(slice_size))
441
+ for module in self.children():
442
+ fn_recursive_set_attention_slice(module, reversed_slice_size)
443
+
444
+ def _set_gradient_checkpointing(self, module, value=False):
445
+ if isinstance(module, (CrossAttnDownBlock3D, DownBlock3D)):
446
+ module.gradient_checkpointing = value
447
+
448
+ def forward(
449
+ self,
450
+ sample: torch.FloatTensor,
451
+ timestep: Union[torch.Tensor, float, int],
452
+ encoder_hidden_states: torch.Tensor,
453
+ controlnet_cond: torch.FloatTensor,
454
+ conditioning_scale: float = 1.0,
455
+ class_labels: Optional[torch.Tensor] = None,
456
+ timestep_cond: Optional[torch.Tensor] = None,
457
+ attention_mask: Optional[torch.Tensor] = None,
458
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
459
+ return_dict: bool = True,
460
+ ) -> Union[ControlNetOutput, Tuple]:
461
+ # check channel order
462
+ channel_order = self.config.controlnet_conditioning_channel_order
463
+
464
+ if channel_order == "rgb":
465
+ # in rgb order by default
466
+ ...
467
+ elif channel_order == "bgr":
468
+ controlnet_cond = torch.flip(controlnet_cond, dims=[1])
469
+ else:
470
+ raise ValueError(f"unknown `controlnet_conditioning_channel_order`: {channel_order}")
471
+
472
+ # prepare attention_mask
473
+ if attention_mask is not None:
474
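+ # convert a 0/1 keep-mask into an additive bias: kept positions -> 0, masked positions -> -10000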
+ attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
475
+ attention_mask = attention_mask.unsqueeze(1)
476
+
477
+ # 1. time
478
+ timesteps = timestep
479
+ if not torch.is_tensor(timesteps):
480
+ # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
481
+ # This would be a good case for the `match` statement (Python 3.10+)
482
+ is_mps = sample.device.type == "mps"
483
+ if isinstance(timestep, float):
484
+ dtype = torch.float32 if is_mps else torch.float64
485
+ else:
486
+ dtype = torch.int32 if is_mps else torch.int64
487
+ timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
488
+ elif len(timesteps.shape) == 0:
489
+ timesteps = timesteps[None].to(sample.device)
490
+
491
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
492
+ timesteps = timesteps.expand(sample.shape[0])
493
+
494
+ t_emb = self.time_proj(timesteps)
495
+
496
+ # timesteps does not contain any weights and will always return f32 tensors
497
+ # but time_embedding might actually be running in fp16. so we need to cast here.
498
+ # there might be better ways to encapsulate this.
499
+ t_emb = t_emb.to(dtype=self.dtype)
500
+
501
+ emb = self.time_embedding(t_emb, timestep_cond)
502
+
503
+ if self.class_embedding is not None:
504
+ if class_labels is None:
505
+ raise ValueError("class_labels should be provided when num_class_embeds > 0")
506
+
507
+ if self.config.class_embed_type == "timestep":
508
+ class_labels = self.time_proj(class_labels)
509
+
510
+ class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
511
+ emb = emb + class_emb
512
+
513
+ # 2. pre-process
514
+ sample = self.conv_in(sample)
515
+
516
+ controlnet_cond = self.controlnet_cond_embedding(controlnet_cond)
517
+
518
+ sample += controlnet_cond
519
+
520
+ # 3. down
521
+ down_block_res_samples = (sample,)
522
+ for downsample_block in self.down_blocks:
523
+ if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
524
+ sample, res_samples = downsample_block(
525
+ hidden_states=sample,
526
+ temb=emb,
527
+ encoder_hidden_states=encoder_hidden_states,
528
+ attention_mask=attention_mask,
529
+ cross_attention_kwargs=cross_attention_kwargs,
530
+ )
531
+ else:
532
+ sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
533
+
534
+ down_block_res_samples += res_samples
535
+
536
+ # 4. mid
537
+ if self.mid_block is not None:
538
+ sample = self.mid_block(
539
+ sample,
540
+ emb,
541
+ encoder_hidden_states=encoder_hidden_states,
542
+ attention_mask=attention_mask,
543
+ cross_attention_kwargs=cross_attention_kwargs,
544
+ )
545
+
546
+ # 5. Control net blocks
547
+
548
+ controlnet_down_block_res_samples = ()
549
+
550
+ for down_block_res_sample, controlnet_block in zip(down_block_res_samples, self.controlnet_down_blocks):
551
+ down_block_res_sample = controlnet_block(down_block_res_sample)
552
+ controlnet_down_block_res_samples += (down_block_res_sample,)
553
+
554
+ down_block_res_samples = controlnet_down_block_res_samples
555
+
556
+ mid_block_res_sample = self.controlnet_mid_block(sample)
557
+
558
+ # 6. scaling
559
+ down_block_res_samples = [sample * conditioning_scale for sample in down_block_res_samples]
560
+ mid_block_res_sample *= conditioning_scale
561
+
562
+ if not return_dict:
563
+ return (down_block_res_samples, mid_block_res_sample)
564
+
565
+ return ControlNetOutput(
566
+ down_block_res_samples=down_block_res_samples, mid_block_res_sample=mid_block_res_sample
567
+ )
568
+
569
+ @classmethod
570
+ def from_pretrained_2d(cls, pretrained_model_path, control_path=None):
571
+ config_file = os.path.join(pretrained_model_path, 'config.json')
572
+ if not os.path.isfile(config_file):
573
+ raise RuntimeError(f"{config_file} does not exist")
574
+ with open(config_file, "r") as f:
575
+ config = json.load(f)
576
+ config["_class_name"] = cls.__name__
577
+ config["down_block_types"] = [
578
+ "CrossAttnDownBlock3D",
579
+ "CrossAttnDownBlock3D",
580
+ "CrossAttnDownBlock3D",
581
+ "DownBlock3D"
582
+ ]
583
+
584
+ from diffusers.utils import WEIGHTS_NAME
585
+ model = cls.from_config(config)
586
+ if control_path is None:
587
+ model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME)
588
+ state_dict = torch.load(model_file, map_location="cpu")
589
+ else:
590
+ model_file = control_path
591
+ state_dict = torch.load(model_file, map_location="cpu")
592
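+ # strip the leading 14-character module prefix (e.g. "control_model.") from checkpoint keys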
+ state_dict = {k[14:]: v for k, v in state_dict.items()}
593
+
594
+
595
+ for k, v in model.state_dict().items():
596
+ if '_temp.' in k:
597
+ state_dict.update({k: v})
598
+ model.load_state_dict(state_dict)
599
+
600
+ return model
601
+
602
+ def zero_module(module):
603
+ for p in module.parameters():
604
+ nn.init.zeros_(p)
605
+ return module
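
A minimal usage sketch for this module. The import path `models.controlnet`, the `checkpoints/` paths, and all tensor shapes are assumptions (shapes are kept smaller than the README's 512x512 setup to stay light), not something defined in this file:

```python
import torch
from models.controlnet import ControlNetModel3D

# Inflate a 2D ControlNet checkpoint into the 3D variant defined above.
controlnet = ControlNetModel3D.from_pretrained_2d("checkpoints/sd-controlnet-depth")

sample = torch.randn(1, 4, 4, 32, 32)             # noisy latents: (batch, channels, frames, h, w)
timestep = torch.tensor([999])                    # diffusion timestep
encoder_hidden_states = torch.randn(1, 77, 768)   # CLIP text embeddings
controlnet_cond = torch.randn(1, 3, 4, 256, 256)  # per-frame depth maps, assumed in [0, 1]

out = controlnet(
    sample,
    timestep,
    encoder_hidden_states=encoder_hidden_states,
    controlnet_cond=controlnet_cond,
    conditioning_scale=1.0,
)
# out.down_block_res_samples: residuals added to the UNet's skip connections
# out.mid_block_res_sample: residual added to the UNet's mid block
```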
models/controlnet_attention.py ADDED
@@ -0,0 +1,483 @@
1
+ # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention.py
2
+
3
+ from dataclasses import dataclass
4
+ from typing import Optional, Callable
5
+ import math
6
+ import torch
7
+ import torch.nn.functional as F
8
+ from torch import nn
9
+ from positional_encodings.torch_encodings import PositionalEncoding2D
10
+
11
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
12
+ from diffusers import ModelMixin
13
+ from diffusers.utils import BaseOutput
14
+ from diffusers.utils.import_utils import is_xformers_available
15
+ from diffusers.models.attention import CrossAttention, FeedForward, AdaLayerNorm
16
+ from einops import rearrange, repeat
17
+
18
+
19
+ @dataclass
20
+ class Transformer3DModelOutput(BaseOutput):
21
+ sample: torch.FloatTensor
22
+
23
+
24
+ if is_xformers_available():
25
+ import xformers
26
+ import xformers.ops
27
+ else:
28
+ xformers = None
29
+
30
+
31
+ class Transformer3DModel(ModelMixin, ConfigMixin):
32
+ @register_to_config
33
+ def __init__(
34
+ self,
35
+ num_attention_heads: int = 16,
36
+ attention_head_dim: int = 88,
37
+ in_channels: Optional[int] = None,
38
+ num_layers: int = 1,
39
+ dropout: float = 0.0,
40
+ norm_num_groups: int = 32,
41
+ cross_attention_dim: Optional[int] = None,
42
+ attention_bias: bool = False,
43
+ activation_fn: str = "geglu",
44
+ num_embeds_ada_norm: Optional[int] = None,
45
+ use_linear_projection: bool = False,
46
+ only_cross_attention: bool = False,
47
+ upcast_attention: bool = False,
48
+ ):
49
+ super().__init__()
50
+ self.use_linear_projection = use_linear_projection
51
+ self.num_attention_heads = num_attention_heads
52
+ self.attention_head_dim = attention_head_dim
53
+ inner_dim = num_attention_heads * attention_head_dim
54
+
55
+ # Define input layers
56
+ self.in_channels = in_channels
57
+
58
+ self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
59
+ if use_linear_projection:
60
+ self.proj_in = nn.Linear(in_channels, inner_dim)
61
+ else:
62
+ self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
63
+
64
+ # Define transformers blocks
65
+ self.transformer_blocks = nn.ModuleList(
66
+ [
67
+ BasicTransformerBlock(
68
+ inner_dim,
69
+ num_attention_heads,
70
+ attention_head_dim,
71
+ dropout=dropout,
72
+ cross_attention_dim=cross_attention_dim,
73
+ activation_fn=activation_fn,
74
+ num_embeds_ada_norm=num_embeds_ada_norm,
75
+ attention_bias=attention_bias,
76
+ only_cross_attention=only_cross_attention,
77
+ upcast_attention=upcast_attention,
78
+ )
79
+ for d in range(num_layers)
80
+ ]
81
+ )
82
+
83
+ # Define output layers
84
+ if use_linear_projection:
85
+ self.proj_out = nn.Linear(inner_dim, in_channels)
86
+ else:
87
+ self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
88
+
89
+ def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, return_dict: bool = True):
90
+ # Input
91
+ assert hidden_states.dim() == 5, f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}."
92
+ video_length = hidden_states.shape[2]
93
+ hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w")
94
+ encoder_hidden_states = repeat(encoder_hidden_states, 'b n c -> (b f) n c', f=video_length)
95
+
96
+ batch, channel, height, weight = hidden_states.shape
97
+ residual = hidden_states
98
+
99
+ hidden_states = self.norm(hidden_states)
100
+ if not self.use_linear_projection:
101
+ hidden_states = self.proj_in(hidden_states)
102
+ inner_dim = hidden_states.shape[1]
103
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
104
+ else:
105
+ inner_dim = hidden_states.shape[1]
106
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
107
+ hidden_states = self.proj_in(hidden_states)
108
+
109
+ # Blocks
110
+ for block in self.transformer_blocks:
111
+ hidden_states = block(
112
+ hidden_states,
113
+ encoder_hidden_states=encoder_hidden_states,
114
+ timestep=timestep,
115
+ video_length=video_length
116
+ )
117
+
118
+ # Output
119
+ if not self.use_linear_projection:
120
+ hidden_states = (
121
+ hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous()
122
+ )
123
+ hidden_states = self.proj_out(hidden_states)
124
+ else:
125
+ hidden_states = self.proj_out(hidden_states)
126
+ hidden_states = (
127
+ hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous()
128
+ )
129
+
130
+ output = hidden_states + residual
131
+
132
+ output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length)
133
+ if not return_dict:
134
+ return (output,)
135
+
136
+ return Transformer3DModelOutput(sample=output)
137
+
138
+
139
+ class BasicTransformerBlock(nn.Module):
140
+ def __init__(
141
+ self,
142
+ dim: int,
143
+ num_attention_heads: int,
144
+ attention_head_dim: int,
145
+ dropout=0.0,
146
+ cross_attention_dim: Optional[int] = None,
147
+ activation_fn: str = "geglu",
148
+ num_embeds_ada_norm: Optional[int] = None,
149
+ attention_bias: bool = False,
150
+ only_cross_attention: bool = False,
151
+ upcast_attention: bool = False,
152
+ ):
153
+ super().__init__()
154
+ self.only_cross_attention = only_cross_attention
155
+ self.use_ada_layer_norm = num_embeds_ada_norm is not None
156
+
157
+ # Individual-Attn
158
+ self.attn1 = IndividualAttention(
159
+ query_dim=dim,
160
+ heads=num_attention_heads,
161
+ dim_head=attention_head_dim,
162
+ dropout=dropout,
163
+ bias=attention_bias,
164
+ cross_attention_dim=cross_attention_dim if only_cross_attention else None,
165
+ upcast_attention=upcast_attention,
166
+ )
167
+ self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
168
+
169
+ # Cross-Attn
170
+ if cross_attention_dim is not None:
171
+ self.attn2 = CrossAttention(
172
+ query_dim=dim,
173
+ cross_attention_dim=cross_attention_dim,
174
+ heads=num_attention_heads,
175
+ dim_head=attention_head_dim,
176
+ dropout=dropout,
177
+ bias=attention_bias,
178
+ upcast_attention=upcast_attention,
179
+ )
180
+ else:
181
+ self.attn2 = None
182
+
183
+ if cross_attention_dim is not None:
184
+ self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
185
+ else:
186
+ self.norm2 = None
187
+
188
+ # Feed-forward
189
+ self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn)
190
+ self.norm3 = nn.LayerNorm(dim)
191
+
192
+ self.norm_temp = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
193
+
194
+ def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None):
195
+ if not is_xformers_available():
196
+ print("Here is how to install it")
197
+ raise ModuleNotFoundError(
198
+ "Refer to https://github.com/facebookresearch/xformers for more information on how to install"
199
+ " xformers",
200
+ name="xformers",
201
+ )
202
+ elif not torch.cuda.is_available():
203
+ raise ValueError(
204
+ "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is only"
205
+ " available for GPU "
206
+ )
207
+ else:
208
+ try:
209
+ # Make sure we can run the memory efficient attention
210
+ _ = xformers.ops.memory_efficient_attention(
211
+ torch.randn((1, 2, 40), device="cuda"),
212
+ torch.randn((1, 2, 40), device="cuda"),
213
+ torch.randn((1, 2, 40), device="cuda"),
214
+ )
215
+ except Exception as e:
216
+ raise e
217
+ self.attn1._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
218
+ if self.attn2 is not None:
219
+ self.attn2._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
220
+ # self.attn_temp._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
221
+
222
+ def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, attention_mask=None, video_length=None):
223
+ # Individual-Attention
224
+ norm_hidden_states = (
225
+ self.norm1(hidden_states, timestep) if self.use_ada_layer_norm else self.norm1(hidden_states)
226
+ )
227
+
228
+ if self.only_cross_attention:
229
+ hidden_states = (
230
+ self.attn1(norm_hidden_states, encoder_hidden_states, attention_mask=attention_mask) + hidden_states
231
+ )
232
+ else:
233
+ hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask, video_length=video_length) + hidden_states
234
+
235
+ if self.attn2 is not None:
236
+ # Cross-Attention
237
+ norm_hidden_states = (
238
+ self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
239
+ )
240
+ hidden_states = (
241
+ self.attn2(
242
+ norm_hidden_states, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask
243
+ )
244
+ + hidden_states
245
+ )
246
+
247
+ # Feed-forward
248
+ hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states
249
+
250
+ # # Temporal-Attention
251
+ # d = hidden_states.shape[1]
252
+ # hidden_states = rearrange(hidden_states, "(b f) d c -> (b d) f c", f=video_length)
253
+ # norm_hidden_states = (
254
+ # self.norm_temp(hidden_states, timestep) if self.use_ada_layer_norm else self.norm_temp(hidden_states)
255
+ # )
256
+ # hidden_states = self.attn_temp(norm_hidden_states) + hidden_states
257
+ # hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d)
258
+
259
+ return hidden_states
260
+
261
+ class IndividualAttention(nn.Module):
262
+ r"""
263
+ A per-frame (individual) attention layer, adapted from diffusers' CrossAttention.
264
+
265
+ Parameters:
266
+ query_dim (`int`): The number of channels in the query.
267
+ cross_attention_dim (`int`, *optional*):
268
+ The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`.
269
+ heads (`int`, *optional*, defaults to 8): The number of heads to use for multi-head attention.
270
+ dim_head (`int`, *optional*, defaults to 64): The number of channels in each head.
271
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
272
+ bias (`bool`, *optional*, defaults to False):
273
+ Set to `True` for the query, key, and value linear layers to contain a bias parameter.
274
+ """
275
+
276
+ def __init__(
277
+ self,
278
+ query_dim: int,
279
+ cross_attention_dim: Optional[int] = None,
280
+ heads: int = 8,
281
+ dim_head: int = 64,
282
+ dropout: float = 0.0,
283
+ bias=False,
284
+ upcast_attention: bool = False,
285
+ upcast_softmax: bool = False,
286
+ added_kv_proj_dim: Optional[int] = None,
287
+ norm_num_groups: Optional[int] = None,
288
+ ):
289
+ super().__init__()
290
+ inner_dim = dim_head * heads
291
+ cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
292
+ self.upcast_attention = upcast_attention
293
+ self.upcast_softmax = upcast_softmax
294
+
295
+ self.scale = dim_head**-0.5
296
+
297
+ self.heads = heads
298
+ # for slice_size > 0 the attention score computation
299
+ # is split across the batch axis to save memory
300
+ # You can set slice_size with `set_attention_slice`
301
+ self.sliceable_head_dim = heads
302
+ self._slice_size = None
303
+ self._use_memory_efficient_attention_xformers = False
304
+ self.added_kv_proj_dim = added_kv_proj_dim
305
+
306
+ if norm_num_groups is not None:
307
+ self.group_norm = nn.GroupNorm(num_channels=inner_dim, num_groups=norm_num_groups, eps=1e-5, affine=True)
308
+ else:
309
+ self.group_norm = None
310
+
311
+ self.to_q = nn.Linear(query_dim, inner_dim, bias=bias)
312
+ self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
313
+ self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
314
+
315
+ if self.added_kv_proj_dim is not None:
316
+ self.add_k_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim)
317
+ self.add_v_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim)
318
+
319
+ self.to_out = nn.ModuleList([])
320
+ self.to_out.append(nn.Linear(inner_dim, query_dim))
321
+ self.to_out.append(nn.Dropout(dropout))
322
+
323
+ def reshape_heads_to_batch_dim(self, tensor):
324
+ batch_size, seq_len, dim = tensor.shape
325
+ head_size = self.heads
326
+ tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
327
+ tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size)
328
+ return tensor
329
+
330
+ def reshape_batch_dim_to_heads(self, tensor):
331
+ batch_size, seq_len, dim = tensor.shape
332
+ head_size = self.heads
333
+ tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
334
+ tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
335
+ return tensor
336
+
337
+ def set_attention_slice(self, slice_size):
338
+ if slice_size is not None and slice_size > self.sliceable_head_dim:
339
+ raise ValueError(f"slice_size {slice_size} has to be smaller or equal to {self.sliceable_head_dim}.")
340
+
341
+ self._slice_size = slice_size
342
+
343
+ def _attention(self, query, key, value, attention_mask=None):
344
+ if self.upcast_attention:
345
+ query = query.float()
346
+ key = key.float()
347
+
348
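+ # scaled dot-product scores: alpha * (Q @ K^T); the empty tensor only supplies shape/dtype since beta=0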
+ attention_scores = torch.baddbmm(
349
+ torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device),
350
+ query,
351
+ key.transpose(-1, -2),
352
+ beta=0,
353
+ alpha=self.scale,
354
+ )
355
+
356
+ if attention_mask is not None:
357
+ attention_scores = attention_scores + attention_mask
358
+
359
+ if self.upcast_softmax:
360
+ attention_scores = attention_scores.float()
361
+
362
+ attention_probs = attention_scores.softmax(dim=-1)
363
+
364
+ # cast back to the original dtype
365
+ attention_probs = attention_probs.to(value.dtype)
366
+
367
+ # compute attention output
368
+ hidden_states = torch.bmm(attention_probs, value)
369
+
370
+ # reshape hidden_states
371
+ hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
372
+ return hidden_states
373
+
374
+ def _sliced_attention(self, query, key, value, sequence_length, dim, attention_mask):
375
+ batch_size_attention = query.shape[0]
376
+ hidden_states = torch.zeros(
377
+ (batch_size_attention, sequence_length, dim // self.heads), device=query.device, dtype=query.dtype
378
+ )
379
+ slice_size = self._slice_size if self._slice_size is not None else hidden_states.shape[0]
380
+ for i in range(hidden_states.shape[0] // slice_size):
381
+ start_idx = i * slice_size
382
+ end_idx = (i + 1) * slice_size
383
+
384
+ query_slice = query[start_idx:end_idx]
385
+ key_slice = key[start_idx:end_idx]
386
+
387
+ if self.upcast_attention:
388
+ query_slice = query_slice.float()
389
+ key_slice = key_slice.float()
390
+
391
+ attn_slice = torch.baddbmm(
392
+ torch.empty(slice_size, query.shape[1], key.shape[1], dtype=query_slice.dtype, device=query.device),
393
+ query_slice,
394
+ key_slice.transpose(-1, -2),
395
+ beta=0,
396
+ alpha=self.scale,
397
+ )
398
+
399
+ if attention_mask is not None:
400
+ attn_slice = attn_slice + attention_mask[start_idx:end_idx]
401
+
402
+ if self.upcast_softmax:
403
+ attn_slice = attn_slice.float()
404
+
405
+ attn_slice = attn_slice.softmax(dim=-1)
406
+
407
+ # cast back to the original dtype
408
+ attn_slice = attn_slice.to(value.dtype)
409
+ attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx])
410
+
411
+ hidden_states[start_idx:end_idx] = attn_slice
412
+
413
+ # reshape hidden_states
414
+ hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
415
+ return hidden_states
416
+
417
+ def _memory_efficient_attention_xformers(self, query, key, value, attention_mask):
418
+ # TODO attention_mask
419
+ query = query.contiguous()
420
+ key = key.contiguous()
421
+ value = value.contiguous()
422
+ hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask)
423
+ hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
424
+ return hidden_states
425
+
426
+ def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, video_length=None):
427
+ batch_size, sequence_length, _ = hidden_states.shape
428
+
429
+ encoder_hidden_states = encoder_hidden_states
430
+
431
+ if self.group_norm is not None:
432
+ hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
433
+
434
+ query = self.to_q(hidden_states) # (bf) x d(hw) x c
435
+ dim = query.shape[-1]
436
+
437
+ query = self.reshape_heads_to_batch_dim(query)
438
+
439
+ if self.added_kv_proj_dim is not None:
440
+ raise NotImplementedError
441
+
442
+ encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
443
+ key = self.to_k(encoder_hidden_states)
444
+ value = self.to_v(encoder_hidden_states)
445
+
446
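+ # individual attention: frame i attends only to its own keys/values
+ # (the arange index below is an identity selection over the frame axis)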
+ curr_frame_index = torch.arange(video_length)
447
+
448
+ key = rearrange(key, "(b f) d c -> b f d c", f=video_length)
449
+
450
+ key = key[:, curr_frame_index]
451
+ key = rearrange(key, "b f d c -> (b f) d c")
452
+
453
+ value = rearrange(value, "(b f) d c -> b f d c", f=video_length)
454
+
455
+ value = value[:, curr_frame_index]
456
+ value = rearrange(value, "b f d c -> (b f) d c")
457
+
458
+ key = self.reshape_heads_to_batch_dim(key)
459
+ value = self.reshape_heads_to_batch_dim(value)
460
+
461
+ if attention_mask is not None:
462
+ if attention_mask.shape[-1] != query.shape[1]:
463
+ target_length = query.shape[1]
464
+ attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
465
+ attention_mask = attention_mask.repeat_interleave(self.heads, dim=0)
466
+
467
+ # attention, what we cannot get enough of
468
+ if self._use_memory_efficient_attention_xformers:
469
+ hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask)
470
+ # Some versions of xformers return output in fp32, cast it back to the dtype of the input
471
+ hidden_states = hidden_states.to(query.dtype)
472
+ else:
473
+ if self._slice_size is None or query.shape[0] // self._slice_size == 1:
474
+ hidden_states = self._attention(query, key, value, attention_mask)
475
+ else:
476
+ hidden_states = self._sliced_attention(query, key, value, sequence_length, dim, attention_mask)
477
+
478
+ # linear proj
479
+ hidden_states = self.to_out[0](hidden_states)
480
+
481
+ # dropout
482
+ hidden_states = self.to_out[1](hidden_states)
483
+ return hidden_states
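
A shape-level sketch of `Transformer3DModel`. The import path `models.controlnet_attention` and the hyperparameters (chosen to mirror the first SD v1.5 down block: 320 channels, 768-dim text states, deliberately small spatial/frame sizes) are assumptions, not values taken from this file:

```python
import torch
from models.controlnet_attention import Transformer3DModel

model = Transformer3DModel(
    num_attention_heads=8,
    attention_head_dim=40,   # 8 * 40 = 320 = in_channels
    in_channels=320,
    cross_attention_dim=768,
)

hidden_states = torch.randn(1, 320, 4, 32, 32)    # (batch, channels, frames, h, w)
encoder_hidden_states = torch.randn(1, 77, 768)   # text embeddings, repeated per frame inside forward()

out = model(hidden_states, encoder_hidden_states=encoder_hidden_states).sample
assert out.shape == hidden_states.shape           # residual connection keeps the 5-D shape
```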
models/controlnet_unet_blocks.py ADDED
@@ -0,0 +1,589 @@
1
+ # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unet_2d_blocks.py
2
+
3
+ import torch
4
+ from torch import nn
5
+
6
+ from .controlnet_attention import Transformer3DModel
7
+ from .resnet import Downsample3D, ResnetBlock3D, Upsample3D
8
+
9
+
10
+ def get_down_block(
11
+ down_block_type,
12
+ num_layers,
13
+ in_channels,
14
+ out_channels,
15
+ temb_channels,
16
+ add_downsample,
17
+ resnet_eps,
18
+ resnet_act_fn,
19
+ attn_num_head_channels,
20
+ resnet_groups=None,
21
+ cross_attention_dim=None,
22
+ downsample_padding=None,
23
+ dual_cross_attention=False,
24
+ use_linear_projection=False,
25
+ only_cross_attention=False,
26
+ upcast_attention=False,
27
+ resnet_time_scale_shift="default",
28
+ ):
29
+ down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type
30
+ if down_block_type == "DownBlock3D":
31
+ return DownBlock3D(
32
+ num_layers=num_layers,
33
+ in_channels=in_channels,
34
+ out_channels=out_channels,
35
+ temb_channels=temb_channels,
36
+ add_downsample=add_downsample,
37
+ resnet_eps=resnet_eps,
38
+ resnet_act_fn=resnet_act_fn,
39
+ resnet_groups=resnet_groups,
40
+ downsample_padding=downsample_padding,
41
+ resnet_time_scale_shift=resnet_time_scale_shift,
42
+ )
43
+ elif down_block_type == "CrossAttnDownBlock3D":
44
+ if cross_attention_dim is None:
45
+ raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlock3D")
46
+ return CrossAttnDownBlock3D(
47
+ num_layers=num_layers,
48
+ in_channels=in_channels,
49
+ out_channels=out_channels,
50
+ temb_channels=temb_channels,
51
+ add_downsample=add_downsample,
52
+ resnet_eps=resnet_eps,
53
+ resnet_act_fn=resnet_act_fn,
54
+ resnet_groups=resnet_groups,
55
+ downsample_padding=downsample_padding,
56
+ cross_attention_dim=cross_attention_dim,
57
+ attn_num_head_channels=attn_num_head_channels,
58
+ dual_cross_attention=dual_cross_attention,
59
+ use_linear_projection=use_linear_projection,
60
+ only_cross_attention=only_cross_attention,
61
+ upcast_attention=upcast_attention,
62
+ resnet_time_scale_shift=resnet_time_scale_shift,
63
+ )
64
+ raise ValueError(f"{down_block_type} does not exist.")
65
+
66
+
67
+ def get_up_block(
68
+ up_block_type,
69
+ num_layers,
70
+ in_channels,
71
+ out_channels,
72
+ prev_output_channel,
73
+ temb_channels,
74
+ add_upsample,
75
+ resnet_eps,
76
+ resnet_act_fn,
77
+ attn_num_head_channels,
78
+ resnet_groups=None,
79
+ cross_attention_dim=None,
80
+ dual_cross_attention=False,
81
+ use_linear_projection=False,
82
+ only_cross_attention=False,
83
+ upcast_attention=False,
84
+ resnet_time_scale_shift="default",
85
+ ):
86
+ up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
87
+ if up_block_type == "UpBlock3D":
88
+ return UpBlock3D(
89
+ num_layers=num_layers,
90
+ in_channels=in_channels,
91
+ out_channels=out_channels,
92
+ prev_output_channel=prev_output_channel,
93
+ temb_channels=temb_channels,
94
+ add_upsample=add_upsample,
95
+ resnet_eps=resnet_eps,
96
+ resnet_act_fn=resnet_act_fn,
97
+ resnet_groups=resnet_groups,
98
+ resnet_time_scale_shift=resnet_time_scale_shift,
99
+ )
100
+ elif up_block_type == "CrossAttnUpBlock3D":
101
+ if cross_attention_dim is None:
102
+ raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlock3D")
103
+ return CrossAttnUpBlock3D(
104
+ num_layers=num_layers,
105
+ in_channels=in_channels,
106
+ out_channels=out_channels,
107
+ prev_output_channel=prev_output_channel,
108
+ temb_channels=temb_channels,
109
+ add_upsample=add_upsample,
110
+ resnet_eps=resnet_eps,
111
+ resnet_act_fn=resnet_act_fn,
112
+ resnet_groups=resnet_groups,
113
+ cross_attention_dim=cross_attention_dim,
114
+ attn_num_head_channels=attn_num_head_channels,
115
+ dual_cross_attention=dual_cross_attention,
116
+ use_linear_projection=use_linear_projection,
117
+ only_cross_attention=only_cross_attention,
118
+ upcast_attention=upcast_attention,
119
+ resnet_time_scale_shift=resnet_time_scale_shift,
120
+ )
121
+ raise ValueError(f"{up_block_type} does not exist.")
122
+
123
+
124
+ class UNetMidBlock3DCrossAttn(nn.Module):
125
+ def __init__(
126
+ self,
127
+ in_channels: int,
128
+ temb_channels: int,
129
+ dropout: float = 0.0,
130
+ num_layers: int = 1,
131
+ resnet_eps: float = 1e-6,
132
+ resnet_time_scale_shift: str = "default",
133
+ resnet_act_fn: str = "swish",
134
+ resnet_groups: int = 32,
135
+ resnet_pre_norm: bool = True,
136
+ attn_num_head_channels=1,
137
+ output_scale_factor=1.0,
138
+ cross_attention_dim=1280,
139
+ dual_cross_attention=False,
140
+ use_linear_projection=False,
141
+ upcast_attention=False,
142
+ ):
143
+ super().__init__()
144
+
145
+ self.has_cross_attention = True
146
+ self.attn_num_head_channels = attn_num_head_channels
147
+ resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
148
+
149
+ # there is always at least one resnet
150
+ resnets = [
151
+ ResnetBlock3D(
152
+ in_channels=in_channels,
153
+ out_channels=in_channels,
154
+ temb_channels=temb_channels,
155
+ eps=resnet_eps,
156
+ groups=resnet_groups,
157
+ dropout=dropout,
158
+ time_embedding_norm=resnet_time_scale_shift,
159
+ non_linearity=resnet_act_fn,
160
+ output_scale_factor=output_scale_factor,
161
+ pre_norm=resnet_pre_norm,
162
+ )
163
+ ]
164
+ attentions = []
165
+
166
+ for _ in range(num_layers):
167
+ if dual_cross_attention:
168
+ raise NotImplementedError
169
+ attentions.append(
170
+ Transformer3DModel(
171
+ attn_num_head_channels,
172
+ in_channels // attn_num_head_channels,
173
+ in_channels=in_channels,
174
+ num_layers=1,
175
+ cross_attention_dim=cross_attention_dim,
176
+ norm_num_groups=resnet_groups,
177
+ use_linear_projection=use_linear_projection,
178
+ upcast_attention=upcast_attention,
179
+ )
180
+ )
181
+ resnets.append(
182
+ ResnetBlock3D(
183
+ in_channels=in_channels,
184
+ out_channels=in_channels,
185
+ temb_channels=temb_channels,
186
+ eps=resnet_eps,
187
+ groups=resnet_groups,
188
+ dropout=dropout,
189
+ time_embedding_norm=resnet_time_scale_shift,
190
+ non_linearity=resnet_act_fn,
191
+ output_scale_factor=output_scale_factor,
192
+ pre_norm=resnet_pre_norm,
193
+ )
194
+ )
195
+
196
+ self.attentions = nn.ModuleList(attentions)
197
+ self.resnets = nn.ModuleList(resnets)
198
+
199
+ def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None, cross_attention_kwargs=None):
200
+ hidden_states = self.resnets[0](hidden_states, temb)
201
+ for attn, resnet in zip(self.attentions, self.resnets[1:]):
202
+ hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample
203
+ hidden_states = resnet(hidden_states, temb)
204
+
205
+ return hidden_states
206
+
207
+
208
+ class CrossAttnDownBlock3D(nn.Module):
209
+ def __init__(
210
+ self,
211
+ in_channels: int,
212
+ out_channels: int,
213
+ temb_channels: int,
214
+ dropout: float = 0.0,
215
+ num_layers: int = 1,
216
+ resnet_eps: float = 1e-6,
217
+ resnet_time_scale_shift: str = "default",
218
+ resnet_act_fn: str = "swish",
219
+ resnet_groups: int = 32,
220
+ resnet_pre_norm: bool = True,
221
+ attn_num_head_channels=1,
222
+ cross_attention_dim=1280,
223
+ output_scale_factor=1.0,
224
+ downsample_padding=1,
225
+ add_downsample=True,
226
+ dual_cross_attention=False,
227
+ use_linear_projection=False,
228
+ only_cross_attention=False,
229
+ upcast_attention=False,
230
+ ):
231
+ super().__init__()
232
+ resnets = []
233
+ attentions = []
234
+
235
+ self.has_cross_attention = True
236
+ self.attn_num_head_channels = attn_num_head_channels
237
+
238
+ for i in range(num_layers):
239
+ in_channels = in_channels if i == 0 else out_channels
240
+ resnets.append(
241
+ ResnetBlock3D(
242
+ in_channels=in_channels,
243
+ out_channels=out_channels,
244
+ temb_channels=temb_channels,
245
+ eps=resnet_eps,
246
+ groups=resnet_groups,
247
+ dropout=dropout,
248
+ time_embedding_norm=resnet_time_scale_shift,
249
+ non_linearity=resnet_act_fn,
250
+ output_scale_factor=output_scale_factor,
251
+ pre_norm=resnet_pre_norm,
252
+ )
253
+ )
254
+ if dual_cross_attention:
255
+ raise NotImplementedError
256
+ attentions.append(
257
+ Transformer3DModel(
258
+ attn_num_head_channels,
259
+ out_channels // attn_num_head_channels,
260
+ in_channels=out_channels,
261
+ num_layers=1,
262
+ cross_attention_dim=cross_attention_dim,
263
+ norm_num_groups=resnet_groups,
264
+ use_linear_projection=use_linear_projection,
265
+ only_cross_attention=only_cross_attention,
266
+ upcast_attention=upcast_attention,
267
+ )
268
+ )
269
+ self.attentions = nn.ModuleList(attentions)
270
+ self.resnets = nn.ModuleList(resnets)
271
+
272
+ if add_downsample:
273
+ self.downsamplers = nn.ModuleList(
274
+ [
275
+ Downsample3D(
276
+ out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
277
+ )
278
+ ]
279
+ )
280
+ else:
281
+ self.downsamplers = None
282
+
283
+ self.gradient_checkpointing = False
284
+
285
+ def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None,cross_attention_kwargs=None):
286
+ output_states = ()
287
+
288
+ for resnet, attn in zip(self.resnets, self.attentions):
289
+ if self.training and self.gradient_checkpointing:
290
+
291
+ def create_custom_forward(module, return_dict=None):
292
+ def custom_forward(*inputs):
293
+ if return_dict is not None:
294
+ return module(*inputs, return_dict=return_dict)
295
+ else:
296
+ return module(*inputs)
297
+
298
+ return custom_forward
299
+
300
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
301
+ hidden_states = torch.utils.checkpoint.checkpoint(
302
+ create_custom_forward(attn, return_dict=False),
303
+ hidden_states,
304
+ encoder_hidden_states,
305
+ )[0]
306
+ else:
307
+ hidden_states = resnet(hidden_states, temb)
308
+ hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample
309
+
310
+ output_states += (hidden_states,)
311
+
312
+ if self.downsamplers is not None:
313
+ for downsampler in self.downsamplers:
314
+ hidden_states = downsampler(hidden_states)
315
+
316
+ output_states += (hidden_states,)
317
+
318
+ return hidden_states, output_states
319
+
320
+
321
+ class DownBlock3D(nn.Module):
322
+ def __init__(
323
+ self,
324
+ in_channels: int,
325
+ out_channels: int,
326
+ temb_channels: int,
327
+ dropout: float = 0.0,
328
+ num_layers: int = 1,
329
+ resnet_eps: float = 1e-6,
330
+ resnet_time_scale_shift: str = "default",
331
+ resnet_act_fn: str = "swish",
332
+ resnet_groups: int = 32,
333
+ resnet_pre_norm: bool = True,
334
+ output_scale_factor=1.0,
335
+ add_downsample=True,
336
+ downsample_padding=1,
337
+ ):
338
+ super().__init__()
339
+ resnets = []
340
+
341
+ for i in range(num_layers):
342
+ in_channels = in_channels if i == 0 else out_channels
343
+ resnets.append(
344
+ ResnetBlock3D(
345
+ in_channels=in_channels,
346
+ out_channels=out_channels,
347
+ temb_channels=temb_channels,
348
+ eps=resnet_eps,
349
+ groups=resnet_groups,
350
+ dropout=dropout,
351
+ time_embedding_norm=resnet_time_scale_shift,
352
+ non_linearity=resnet_act_fn,
353
+ output_scale_factor=output_scale_factor,
354
+ pre_norm=resnet_pre_norm,
355
+ )
356
+ )
357
+
358
+ self.resnets = nn.ModuleList(resnets)
359
+
360
+ if add_downsample:
361
+ self.downsamplers = nn.ModuleList(
362
+ [
363
+ Downsample3D(
364
+ out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
365
+ )
366
+ ]
367
+ )
368
+ else:
369
+ self.downsamplers = None
370
+
371
+ self.gradient_checkpointing = False
372
+
373
+ def forward(self, hidden_states, temb=None):
374
+ output_states = ()
375
+
376
+ for resnet in self.resnets:
377
+ if self.training and self.gradient_checkpointing:
378
+
379
+ def create_custom_forward(module):
380
+ def custom_forward(*inputs):
381
+ return module(*inputs)
382
+
383
+ return custom_forward
384
+
385
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
386
+ else:
387
+ hidden_states = resnet(hidden_states, temb)
388
+
389
+ output_states += (hidden_states,)
390
+
391
+ if self.downsamplers is not None:
392
+ for downsampler in self.downsamplers:
393
+ hidden_states = downsampler(hidden_states)
394
+
395
+ output_states += (hidden_states,)
396
+
397
+ return hidden_states, output_states
398
+
399
+
400
+ class CrossAttnUpBlock3D(nn.Module):
401
+ def __init__(
402
+ self,
403
+ in_channels: int,
404
+ out_channels: int,
405
+ prev_output_channel: int,
406
+ temb_channels: int,
407
+ dropout: float = 0.0,
408
+ num_layers: int = 1,
409
+ resnet_eps: float = 1e-6,
410
+ resnet_time_scale_shift: str = "default",
411
+ resnet_act_fn: str = "swish",
412
+ resnet_groups: int = 32,
413
+ resnet_pre_norm: bool = True,
414
+ attn_num_head_channels=1,
415
+ cross_attention_dim=1280,
416
+ output_scale_factor=1.0,
417
+ add_upsample=True,
418
+ dual_cross_attention=False,
419
+ use_linear_projection=False,
420
+ only_cross_attention=False,
421
+ upcast_attention=False,
422
+ ):
423
+ super().__init__()
424
+ resnets = []
425
+ attentions = []
426
+
427
+ self.has_cross_attention = True
428
+ self.attn_num_head_channels = attn_num_head_channels
429
+
430
+ for i in range(num_layers):
431
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
432
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
433
+
434
+ resnets.append(
435
+ ResnetBlock3D(
436
+ in_channels=resnet_in_channels + res_skip_channels,
437
+ out_channels=out_channels,
438
+ temb_channels=temb_channels,
439
+ eps=resnet_eps,
440
+ groups=resnet_groups,
441
+ dropout=dropout,
442
+ time_embedding_norm=resnet_time_scale_shift,
443
+ non_linearity=resnet_act_fn,
444
+ output_scale_factor=output_scale_factor,
445
+ pre_norm=resnet_pre_norm,
446
+ )
447
+ )
448
+ if dual_cross_attention:
449
+ raise NotImplementedError
450
+ attentions.append(
451
+ Transformer3DModel(
452
+ attn_num_head_channels,
453
+ out_channels // attn_num_head_channels,
454
+ in_channels=out_channels,
455
+ num_layers=1,
456
+ cross_attention_dim=cross_attention_dim,
457
+ norm_num_groups=resnet_groups,
458
+ use_linear_projection=use_linear_projection,
459
+ only_cross_attention=only_cross_attention,
460
+ upcast_attention=upcast_attention,
461
+ )
462
+ )
463
+
464
+ self.attentions = nn.ModuleList(attentions)
465
+ self.resnets = nn.ModuleList(resnets)
466
+
467
+ if add_upsample:
468
+ self.upsamplers = nn.ModuleList([Upsample3D(out_channels, use_conv=True, out_channels=out_channels)])
469
+ else:
470
+ self.upsamplers = None
471
+
472
+ self.gradient_checkpointing = False
473
+
474
+ def forward(
475
+ self,
476
+ hidden_states,
477
+ res_hidden_states_tuple,
478
+ temb=None,
479
+ encoder_hidden_states=None,
480
+ upsample_size=None,
481
+ attention_mask=None,
482
+ cross_attention_kwargs=None
483
+ ):
484
+ for resnet, attn in zip(self.resnets, self.attentions):
485
+ # pop res hidden states
486
+ res_hidden_states = res_hidden_states_tuple[-1]
487
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
488
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
489
+
490
+ if self.training and self.gradient_checkpointing:
491
+
492
+ def create_custom_forward(module, return_dict=None):
493
+ def custom_forward(*inputs):
494
+ if return_dict is not None:
495
+ return module(*inputs, return_dict=return_dict)
496
+ else:
497
+ return module(*inputs)
498
+
499
+ return custom_forward
500
+
501
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
502
+ hidden_states = torch.utils.checkpoint.checkpoint(
503
+ create_custom_forward(attn, return_dict=False),
504
+ hidden_states,
505
+ encoder_hidden_states,
506
+ )[0]
507
+ else:
508
+ hidden_states = resnet(hidden_states, temb)
509
+ hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample
510
+
511
+ if self.upsamplers is not None:
512
+ for upsampler in self.upsamplers:
513
+ hidden_states = upsampler(hidden_states, upsample_size)
514
+
515
+ return hidden_states
516
+
517
+
518
+ class UpBlock3D(nn.Module):
519
+ def __init__(
520
+ self,
521
+ in_channels: int,
522
+ prev_output_channel: int,
523
+ out_channels: int,
524
+ temb_channels: int,
525
+ dropout: float = 0.0,
526
+ num_layers: int = 1,
527
+ resnet_eps: float = 1e-6,
528
+ resnet_time_scale_shift: str = "default",
529
+ resnet_act_fn: str = "swish",
530
+ resnet_groups: int = 32,
531
+ resnet_pre_norm: bool = True,
532
+ output_scale_factor=1.0,
533
+ add_upsample=True,
534
+ ):
535
+ super().__init__()
536
+ resnets = []
537
+
538
+ for i in range(num_layers):
539
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
540
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
541
+
542
+ resnets.append(
543
+ ResnetBlock3D(
544
+ in_channels=resnet_in_channels + res_skip_channels,
545
+ out_channels=out_channels,
546
+ temb_channels=temb_channels,
547
+ eps=resnet_eps,
548
+ groups=resnet_groups,
549
+ dropout=dropout,
550
+ time_embedding_norm=resnet_time_scale_shift,
551
+ non_linearity=resnet_act_fn,
552
+ output_scale_factor=output_scale_factor,
553
+ pre_norm=resnet_pre_norm,
554
+ )
555
+ )
556
+
557
+ self.resnets = nn.ModuleList(resnets)
558
+
559
+ if add_upsample:
560
+ self.upsamplers = nn.ModuleList([Upsample3D(out_channels, use_conv=True, out_channels=out_channels)])
561
+ else:
562
+ self.upsamplers = None
563
+
564
+ self.gradient_checkpointing = False
565
+
566
+ def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None):
567
+ for resnet in self.resnets:
568
+ # pop res hidden states
569
+ res_hidden_states = res_hidden_states_tuple[-1]
570
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
571
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
572
+
573
+ if self.training and self.gradient_checkpointing:
574
+
575
+ def create_custom_forward(module):
576
+ def custom_forward(*inputs):
577
+ return module(*inputs)
578
+
579
+ return custom_forward
580
+
581
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
582
+ else:
583
+ hidden_states = resnet(hidden_states, temb)
584
+
585
+ if self.upsamplers is not None:
586
+ for upsampler in self.upsamplers:
587
+ hidden_states = upsampler(hidden_states, upsample_size)
588
+
589
+ return hidden_states
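
A construction-only sketch of how `get_down_block` maps SD v1.5's `down_block_types` onto the 3D blocks above. The import path and the channel/head/activation numbers are assumptions taken from the usual SD v1.5 config, not from this file:

```python
from models.controlnet_unet_blocks import get_down_block

down_block_types = ("CrossAttnDownBlock3D", "CrossAttnDownBlock3D", "CrossAttnDownBlock3D", "DownBlock3D")
block_out_channels = (320, 640, 1280, 1280)

blocks, output_channel = [], block_out_channels[0]
for i, block_type in enumerate(down_block_types):
    input_channel, output_channel = output_channel, block_out_channels[i]
    blocks.append(
        get_down_block(
            block_type,
            num_layers=2,
            in_channels=input_channel,
            out_channels=output_channel,
            temb_channels=1280,
            add_downsample=i < len(down_block_types) - 1,
            resnet_eps=1e-5,
            resnet_act_fn="silu",
            attn_num_head_channels=8,   # passed through as the number of attention heads
            resnet_groups=32,
            cross_attention_dim=768,
            downsample_padding=1,
        )
    )
# Each block's forward takes (batch, channels, frames, h, w) hidden states plus the time embedding.
```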
models/pipeline_controlvideo.py ADDED
@@ -0,0 +1,1351 @@
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ import inspect
17
+ import os
18
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
19
+ from dataclasses import dataclass
20
+
21
+ import numpy as np
22
+ import PIL.Image
23
+ import torch
24
+ from torch import nn
25
+ from transformers import CLIPTextModel, CLIPTokenizer
26
+
27
+ from diffusers.models import AutoencoderKL
28
+ from .controlnet import ControlNetOutput
29
+ from diffusers import ModelMixin
30
+ from diffusers.schedulers import DDIMScheduler
31
+ from diffusers.utils import (
32
+ PIL_INTERPOLATION,
33
+ is_accelerate_available,
34
+ is_accelerate_version,
35
+ logging,
36
+ randn_tensor,
37
+ BaseOutput
38
+ )
39
+ from diffusers.pipeline_utils import DiffusionPipeline
40
+
41
+ from einops import rearrange
42
+
43
+ from .unet import UNet3DConditionModel
44
+ from .controlnet import ControlNetModel3D
45
+ from .RIFE.IFNet_HDv3 import IFNet
46
+
47
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
48
+
49
+
50
+ @dataclass
51
+ class ControlVideoPipelineOutput(BaseOutput):
52
+ videos: Union[torch.Tensor, np.ndarray]
53
+
54
+
55
+ class MultiControlNetModel3D(ModelMixin):
56
+ r"""
57
+ Multiple `ControlNetModel` wrapper class for Multi-ControlNet
58
+
59
+ This module is a wrapper for multiple instances of the `ControlNetModel`. The `forward()` API is designed to be
60
+ compatible with `ControlNetModel`.
61
+
62
+ Args:
63
+ controlnets (`List[ControlNetModel]`):
64
+ Provides additional conditioning to the unet during the denoising process. You must set multiple
65
+ `ControlNetModel` as a list.
66
+ """
67
+
68
+ def __init__(self, controlnets: Union[List[ControlNetModel3D], Tuple[ControlNetModel3D]]):
69
+ super().__init__()
70
+ self.nets = nn.ModuleList(controlnets)
71
+
72
+ def forward(
73
+ self,
74
+ sample: torch.FloatTensor,
75
+ timestep: Union[torch.Tensor, float, int],
76
+ encoder_hidden_states: torch.Tensor,
77
+ controlnet_cond: List[List[torch.Tensor]],
78
+ conditioning_scale: List[float],
79
+ class_labels: Optional[torch.Tensor] = None,
80
+ timestep_cond: Optional[torch.Tensor] = None,
81
+ attention_mask: Optional[torch.Tensor] = None,
82
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
83
+ return_dict: bool = True,
84
+ ) -> Union[ControlNetOutput, Tuple]:
85
+ for i, (image, scale, controlnet) in enumerate(zip(controlnet_cond, conditioning_scale, self.nets)):
86
+ down_samples, mid_sample = controlnet(
87
+ sample,
88
+ timestep,
89
+ encoder_hidden_states,
90
+ torch.cat(image, dim=0),
91
+ scale,
92
+ class_labels,
93
+ timestep_cond,
94
+ attention_mask,
95
+ cross_attention_kwargs,
96
+ return_dict,
97
+ )
98
+
99
+ # merge samples
100
+ if i == 0:
101
+ down_block_res_samples, mid_block_res_sample = down_samples, mid_sample
102
+ else:
103
+ down_block_res_samples = [
104
+ samples_prev + samples_curr
105
+ for samples_prev, samples_curr in zip(down_block_res_samples, down_samples)
106
+ ]
107
+ mid_block_res_sample += mid_sample
108
+
109
+ return down_block_res_samples, mid_block_res_sample
110
+
111
+
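The wrapper above runs every ControlNet on the same latents and accumulates their per-block residuals. A minimal sketch of that merging logic with plain tensors (the shapes below are illustrative only, not the real feature-map sizes):

```python
# Minimal sketch (not the repository API): how MultiControlNetModel3D accumulates
# the outputs of several ControlNets. Each net yields a list of down-block
# residuals plus one mid-block residual; the wrapper sums them element-wise.
import torch

def merge_controlnet_outputs(per_net_outputs):
    """per_net_outputs: list of (down_samples: list[Tensor], mid_sample: Tensor)."""
    down_acc, mid_acc = None, None
    for down_samples, mid_sample in per_net_outputs:
        if down_acc is None:
            down_acc, mid_acc = list(down_samples), mid_sample
        else:
            down_acc = [prev + curr for prev, curr in zip(down_acc, down_samples)]
            mid_acc = mid_acc + mid_sample
    return down_acc, mid_acc

# Toy check with two fake ControlNets, three down-block residuals each.
fake_net = lambda: ([torch.randn(1, 4, 8, 8) for _ in range(3)], torch.randn(1, 4, 4, 4))
down, mid = merge_controlnet_outputs([fake_net(), fake_net()])
print(len(down), mid.shape)  # 3 torch.Size([1, 4, 4, 4])
```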
112
+ class ControlVideoPipeline(DiffusionPipeline):
113
+ r"""
114
+ Pipeline for text-to-video generation using Stable Diffusion with ControlNet guidance.
115
+
116
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
117
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
118
+
119
+ Args:
120
+ vae ([`AutoencoderKL`]):
121
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
122
+ text_encoder ([`CLIPTextModel`]):
123
+ Frozen text-encoder. Stable Diffusion uses the text portion of
124
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
125
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
126
+ tokenizer (`CLIPTokenizer`):
127
+ Tokenizer of class
128
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
129
+ unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
130
+ controlnet ([`ControlNetModel`] or `List[ControlNetModel]`):
131
+ Provides additional conditioning to the unet during the denoising process. If you set multiple ControlNets
132
+ as a list, the outputs from each ControlNet are added together to create one combined additional
133
+ conditioning.
134
+ scheduler ([`SchedulerMixin`]):
135
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
136
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
137
+ safety_checker ([`StableDiffusionSafetyChecker`]):
138
+ Classification module that estimates whether generated images could be considered offensive or harmful.
139
+ Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
140
+ feature_extractor ([`CLIPImageProcessor`]):
141
+ Model that extracts features from generated images to be used as inputs for the `safety_checker`.
142
+ """
143
+ _optional_components = ["safety_checker", "feature_extractor"]
144
+
145
+ def __init__(
146
+ self,
147
+ vae: AutoencoderKL,
148
+ text_encoder: CLIPTextModel,
149
+ tokenizer: CLIPTokenizer,
150
+ unet: UNet3DConditionModel,
151
+ controlnet: Union[ControlNetModel3D, List[ControlNetModel3D], Tuple[ControlNetModel3D], MultiControlNetModel3D],
152
+ scheduler: DDIMScheduler,
153
+ interpolater: IFNet,
154
+ ):
155
+ super().__init__()
156
+
157
+ if isinstance(controlnet, (list, tuple)):
158
+ controlnet = MultiControlNetModel3D(controlnet)
159
+
160
+ self.register_modules(
161
+ vae=vae,
162
+ text_encoder=text_encoder,
163
+ tokenizer=tokenizer,
164
+ unet=unet,
165
+ controlnet=controlnet,
166
+ scheduler=scheduler,
167
+ interpolater=interpolater,
168
+ )
169
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
170
+
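For context, a hedged sketch of how this pipeline might be assembled. Only the standard diffusers/transformers loaders below are certain; the loaders for the inflated 3D UNet/ControlNet and the RIFE network live elsewhere in this repository, so they are assumptions and kept commented out:

```python
# Hypothetical assembly sketch; the 3D-module loading is an assumption.
from transformers import CLIPTextModel, CLIPTokenizer
from diffusers import AutoencoderKL, DDIMScheduler

sd_path = "runwayml/stable-diffusion-v1-5"

tokenizer = CLIPTokenizer.from_pretrained(sd_path, subfolder="tokenizer")
text_encoder = CLIPTextModel.from_pretrained(sd_path, subfolder="text_encoder")
vae = AutoencoderKL.from_pretrained(sd_path, subfolder="vae")
scheduler = DDIMScheduler.from_pretrained(sd_path, subfolder="scheduler")

# unet         = <UNet3DConditionModel inflated from the 2D SD UNet>        (assumed loader)
# controlnet   = <ControlNetModel3D from "lllyasviel/sd-controlnet-depth">  (assumed loader)
# interpolater = <IFNet loaded with the RIFE flownet weights>               (assumed loader)
# pipe = ControlVideoPipeline(vae=vae, text_encoder=text_encoder, tokenizer=tokenizer,
#                             unet=unet, controlnet=controlnet,
#                             scheduler=scheduler, interpolater=interpolater)
```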
171
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing
172
+ def enable_vae_slicing(self):
173
+ r"""
174
+ Enable sliced VAE decoding.
175
+
176
+ When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several
177
+ steps. This is useful to save some memory and allow larger batch sizes.
178
+ """
179
+ self.vae.enable_slicing()
180
+
181
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing
182
+ def disable_vae_slicing(self):
183
+ r"""
184
+ Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to
185
+ computing decoding in one step.
186
+ """
187
+ self.vae.disable_slicing()
188
+
189
+ def enable_sequential_cpu_offload(self, gpu_id=0):
190
+ r"""
191
+ Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
192
+ text_encoder, vae, controlnet, and safety checker have their state dicts saved to CPU and then are moved to a
193
+ `torch.device('meta')` and loaded to GPU only when their specific submodule has its `forward` method called.
194
+ Note that offloading happens on a submodule basis. Memory savings are higher than with
195
+ `enable_model_cpu_offload`, but performance is lower.
196
+ """
197
+ if is_accelerate_available():
198
+ from accelerate import cpu_offload
199
+ else:
200
+ raise ImportError("Please install accelerate via `pip install accelerate`")
201
+
202
+ device = torch.device(f"cuda:{gpu_id}")
203
+
204
+ for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae, self.controlnet]:
205
+ cpu_offload(cpu_offloaded_model, device)
206
+
207
+ if self.safety_checker is not None:
208
+ cpu_offload(self.safety_checker, execution_device=device, offload_buffers=True)
209
+
210
+ def enable_model_cpu_offload(self, gpu_id=0):
211
+ r"""
212
+ Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
213
+ to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
214
+ method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with
215
+ `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
216
+ """
217
+ if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
218
+ from accelerate import cpu_offload_with_hook
219
+ else:
220
+ raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")
221
+
222
+ device = torch.device(f"cuda:{gpu_id}")
223
+
224
+ hook = None
225
+ for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]:
226
+ _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)
227
+
228
+ if self.safety_checker is not None:
229
+ # the safety checker can offload the vae again
230
+ _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook)
231
+
232
+ # the controlnet hook has to be offloaded manually as it alternates with the unet
233
+ cpu_offload_with_hook(self.controlnet, device)
234
+
235
+ # We'll offload the last model manually.
236
+ self.final_offload_hook = hook
237
+
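The two offloading modes above trade speed for memory; together with VAE slicing they are the main knobs for fitting longer clips on a single GPU. A brief hedged usage sketch (assuming an assembled `pipe` as in the earlier sketch):

```python
# Hedged usage sketch: `pipe` is an assembled ControlVideoPipeline instance.
def configure_memory(pipe, low_vram=True):
    pipe.enable_vae_slicing()                         # slice VAE decoding to cap peak memory
    if low_vram:
        pipe.enable_sequential_cpu_offload(gpu_id=0)  # smallest footprint, slowest
    else:
        pipe.enable_model_cpu_offload(gpu_id=0)       # whole-model swapping, much faster
    return pipe
```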
238
+ @property
239
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device
240
+ def _execution_device(self):
241
+ r"""
242
+ Returns the device on which the pipeline's models will be executed. After calling
243
+ `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
244
+ hooks.
245
+ """
246
+ if not hasattr(self.unet, "_hf_hook"):
247
+ return self.device
248
+ for module in self.unet.modules():
249
+ if (
250
+ hasattr(module, "_hf_hook")
251
+ and hasattr(module._hf_hook, "execution_device")
252
+ and module._hf_hook.execution_device is not None
253
+ ):
254
+ return torch.device(module._hf_hook.execution_device)
255
+ return self.device
256
+
257
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
258
+ def _encode_prompt(
259
+ self,
260
+ prompt,
261
+ device,
262
+ num_videos_per_prompt,
263
+ do_classifier_free_guidance,
264
+ negative_prompt=None,
265
+ prompt_embeds: Optional[torch.FloatTensor] = None,
266
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
267
+ ):
268
+ r"""
269
+ Encodes the prompt into text encoder hidden states.
270
+
271
+ Args:
272
+ prompt (`str` or `List[str]`, *optional*):
273
+ prompt to be encoded
274
+ device: (`torch.device`):
275
+ torch device
276
+ num_videos_per_prompt (`int`):
277
+ number of images that should be generated per prompt
278
+ do_classifier_free_guidance (`bool`):
279
+ whether to use classifier free guidance or not
280
+ negative_prompt (`str` or `List[str]`, *optional*):
281
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
282
+ `negative_prompt_embeds` instead.
283
+ Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`).
284
+ prompt_embeds (`torch.FloatTensor`, *optional*):
285
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
286
+ provided, text embeddings will be generated from `prompt` input argument.
287
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
288
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
289
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
290
+ argument.
291
+ """
292
+ if prompt is not None and isinstance(prompt, str):
293
+ batch_size = 1
294
+ elif prompt is not None and isinstance(prompt, list):
295
+ batch_size = len(prompt)
296
+ else:
297
+ batch_size = prompt_embeds.shape[0]
298
+
299
+ if prompt_embeds is None:
300
+ text_inputs = self.tokenizer(
301
+ prompt,
302
+ padding="max_length",
303
+ max_length=self.tokenizer.model_max_length,
304
+ truncation=True,
305
+ return_tensors="pt",
306
+ )
307
+ text_input_ids = text_inputs.input_ids
308
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
309
+
310
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
311
+ text_input_ids, untruncated_ids
312
+ ):
313
+ removed_text = self.tokenizer.batch_decode(
314
+ untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
315
+ )
316
+ logger.warning(
317
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
318
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
319
+ )
320
+
321
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
322
+ attention_mask = text_inputs.attention_mask.to(device)
323
+ else:
324
+ attention_mask = None
325
+
326
+ prompt_embeds = self.text_encoder(
327
+ text_input_ids.to(device),
328
+ attention_mask=attention_mask,
329
+ )
330
+ prompt_embeds = prompt_embeds[0]
331
+
332
+ prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
333
+
334
+ bs_embed, seq_len, _ = prompt_embeds.shape
335
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
336
+ prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
337
+ prompt_embeds = prompt_embeds.view(bs_embed * num_videos_per_prompt, seq_len, -1)
338
+
339
+ # get unconditional embeddings for classifier free guidance
340
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
341
+ uncond_tokens: List[str]
342
+ if negative_prompt is None:
343
+ uncond_tokens = [""] * batch_size
344
+ elif type(prompt) is not type(negative_prompt):
345
+ raise TypeError(
346
+ f"`negative_prompt` should be the same type as `prompt`, but got {type(negative_prompt)} !="
347
+ f" {type(prompt)}."
348
+ )
349
+ elif isinstance(negative_prompt, str):
350
+ uncond_tokens = [negative_prompt]
351
+ elif batch_size != len(negative_prompt):
352
+ raise ValueError(
353
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
354
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
355
+ " the batch size of `prompt`."
356
+ )
357
+ else:
358
+ uncond_tokens = negative_prompt
359
+
360
+ max_length = prompt_embeds.shape[1]
361
+ uncond_input = self.tokenizer(
362
+ uncond_tokens,
363
+ padding="max_length",
364
+ max_length=max_length,
365
+ truncation=True,
366
+ return_tensors="pt",
367
+ )
368
+
369
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
370
+ attention_mask = uncond_input.attention_mask.to(device)
371
+ else:
372
+ attention_mask = None
373
+
374
+ negative_prompt_embeds = self.text_encoder(
375
+ uncond_input.input_ids.to(device),
376
+ attention_mask=attention_mask,
377
+ )
378
+ negative_prompt_embeds = negative_prompt_embeds[0]
379
+
380
+ if do_classifier_free_guidance:
381
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
382
+ seq_len = negative_prompt_embeds.shape[1]
383
+
384
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
385
+
386
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_videos_per_prompt, 1)
387
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
388
+
389
+ # For classifier free guidance, we need to do two forward passes.
390
+ # Here we concatenate the unconditional and text embeddings into a single batch
391
+ # to avoid doing two forward passes
392
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
393
+
394
+ return prompt_embeds
395
+
396
+
397
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
398
+ def decode_latents(self, latents, return_tensor=False):
399
+ video_length = latents.shape[2]
400
+ latents = 1 / 0.18215 * latents
401
+ latents = rearrange(latents, "b c f h w -> (b f) c h w")
402
+ video = self.vae.decode(latents).sample
403
+ video = rearrange(video, "(b f) c h w -> b c f h w", f=video_length)
404
+ video = (video / 2 + 0.5).clamp(0, 1)
405
+ if return_tensor:
406
+ return video
407
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
408
+ video = video.cpu().float().numpy()
409
+ return video
410
+
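`decode_latents` folds the frame axis into the batch, undoes the Stable Diffusion latent scaling (1/0.18215), decodes, and maps the output from [-1, 1] to [0, 1]. A toy sketch of just that bookkeeping, with an identity function standing in for the real VAE decoder (an assumption for illustration):

```python
# Toy sketch of the decode_latents reshaping and value-range handling.
import torch
from einops import rearrange

def decode_latents_sketch(latents, fake_decode=lambda x: x):
    b, c, f, h, w = latents.shape
    x = latents / 0.18215                           # undo SD latent scaling
    x = rearrange(x, "b c f h w -> (b f) c h w")    # frames become batch entries
    x = fake_decode(x)                              # stand-in for vae.decode(...).sample
    x = rearrange(x, "(b f) c h w -> b c f h w", f=f)
    return (x / 2 + 0.5).clamp(0, 1)                # map [-1, 1] -> [0, 1]

video = decode_latents_sketch(torch.randn(1, 4, 15, 64, 64))
print(video.shape)  # torch.Size([1, 4, 15, 64, 64])
```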
411
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
412
+ def prepare_extra_step_kwargs(self, generator, eta):
413
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
414
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
415
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
416
+ # and should be between [0, 1]
417
+
418
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
419
+ extra_step_kwargs = {}
420
+ if accepts_eta:
421
+ extra_step_kwargs["eta"] = eta
422
+
423
+ # check if the scheduler accepts generator
424
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
425
+ if accepts_generator:
426
+ extra_step_kwargs["generator"] = generator
427
+ return extra_step_kwargs
428
+
429
+ def check_inputs(
430
+ self,
431
+ prompt,
432
+ # image,
433
+ height,
434
+ width,
435
+ callback_steps,
436
+ negative_prompt=None,
437
+ prompt_embeds=None,
438
+ negative_prompt_embeds=None,
439
+ controlnet_conditioning_scale=1.0,
440
+ ):
441
+ if height % 8 != 0 or width % 8 != 0:
442
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
443
+
444
+ if (callback_steps is None) or (
445
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
446
+ ):
447
+ raise ValueError(
448
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
449
+ f" {type(callback_steps)}."
450
+ )
451
+
452
+ if prompt is not None and prompt_embeds is not None:
453
+ raise ValueError(
454
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
455
+ " only forward one of the two."
456
+ )
457
+ elif prompt is None and prompt_embeds is None:
458
+ raise ValueError(
459
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
460
+ )
461
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
462
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
463
+
464
+ if negative_prompt is not None and negative_prompt_embeds is not None:
465
+ raise ValueError(
466
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
467
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
468
+ )
469
+
470
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
471
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
472
+ raise ValueError(
473
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
474
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
475
+ f" {negative_prompt_embeds.shape}."
476
+ )
477
+
478
+ # Check `image`
479
+
480
+ # if isinstance(self.controlnet, ControlNetModel):
481
+ # self.check_image(image, prompt, prompt_embeds)
482
+ # elif isinstance(self.controlnet, MultiControlNetModel):
483
+ # if not isinstance(image, list):
484
+ # raise TypeError("For multiple controlnets: `image` must be type `list`")
485
+
486
+ # if len(image) != len(self.controlnet.nets):
487
+ # raise ValueError(
488
+ # "For multiple controlnets: `image` must have the same length as the number of controlnets."
489
+ # )
490
+
491
+ # for image_ in image:
492
+ # self.check_image(image_, prompt, prompt_embeds)
493
+ # else:
494
+ # assert False
495
+
496
+ # Check `controlnet_conditioning_scale`
497
+
498
+ if isinstance(self.controlnet, ControlNetModel3D):
499
+ if not isinstance(controlnet_conditioning_scale, float):
500
+ raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.")
501
+ elif isinstance(self.controlnet, MultiControlNetModel3D):
502
+ if isinstance(controlnet_conditioning_scale, list) and len(controlnet_conditioning_scale) != len(
503
+ self.controlnet.nets
504
+ ):
505
+ raise ValueError(
506
+ "For multiple controlnets: When `controlnet_conditioning_scale` is specified as `list`, it must have"
507
+ " the same length as the number of controlnets"
508
+ )
509
+ else:
510
+ assert False
511
+
512
+ def check_image(self, image, prompt, prompt_embeds):
513
+ image_is_pil = isinstance(image, PIL.Image.Image)
514
+ image_is_tensor = isinstance(image, torch.Tensor)
515
+ image_is_pil_list = isinstance(image, list) and isinstance(image[0], PIL.Image.Image)
516
+ image_is_tensor_list = isinstance(image, list) and isinstance(image[0], torch.Tensor)
517
+
518
+ if not image_is_pil and not image_is_tensor and not image_is_pil_list and not image_is_tensor_list:
519
+ raise TypeError(
520
+ "image must be passed and be one of PIL image, torch tensor, list of PIL images, or list of torch tensors"
521
+ )
522
+
523
+ if image_is_pil:
524
+ image_batch_size = 1
525
+ elif image_is_tensor:
526
+ image_batch_size = image.shape[0]
527
+ elif image_is_pil_list:
528
+ image_batch_size = len(image)
529
+ elif image_is_tensor_list:
530
+ image_batch_size = len(image)
531
+
532
+ if prompt is not None and isinstance(prompt, str):
533
+ prompt_batch_size = 1
534
+ elif prompt is not None and isinstance(prompt, list):
535
+ prompt_batch_size = len(prompt)
536
+ elif prompt_embeds is not None:
537
+ prompt_batch_size = prompt_embeds.shape[0]
538
+
539
+ if image_batch_size != 1 and image_batch_size != prompt_batch_size:
540
+ raise ValueError(
541
+ f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}"
542
+ )
543
+
544
+ def prepare_image(
545
+ self, image, width, height, batch_size, num_videos_per_prompt, device, dtype, do_classifier_free_guidance
546
+ ):
547
+ if not isinstance(image, torch.Tensor):
548
+ if isinstance(image, PIL.Image.Image):
549
+ image = [image]
550
+
551
+ if isinstance(image[0], PIL.Image.Image):
552
+ images = []
553
+
554
+ for image_ in image:
555
+ image_ = image_.convert("RGB")
556
+ image_ = image_.resize((width, height), resample=PIL_INTERPOLATION["lanczos"])
557
+ image_ = np.array(image_)
558
+ image_ = image_[None, :]
559
+ images.append(image_)
560
+
561
+ image = images
562
+
563
+ image = np.concatenate(image, axis=0)
564
+ image = np.array(image).astype(np.float32) / 255.0
565
+ image = image.transpose(0, 3, 1, 2)
566
+ image = torch.from_numpy(image)
567
+ elif isinstance(image[0], torch.Tensor):
568
+ image = torch.cat(image, dim=0)
569
+
570
+ image_batch_size = image.shape[0]
571
+
572
+ if image_batch_size == 1:
573
+ repeat_by = batch_size
574
+ else:
575
+ # image batch size is the same as prompt batch size
576
+ repeat_by = num_videos_per_prompt
577
+
578
+ image = image.repeat_interleave(repeat_by, dim=0)
579
+
580
+ image = image.to(device=device, dtype=dtype)
581
+
582
+ if do_classifier_free_guidance:
583
+ image = torch.cat([image] * 2)
584
+
585
+ return image
586
+
587
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
588
+ def prepare_latents(self, batch_size, num_channels_latents, video_length, height, width, dtype, \
589
+ device, generator, latents=None, same_frame_noise=True):
590
+ if isinstance(generator, list) and len(generator) != batch_size:
591
+ raise ValueError(
592
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
593
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
594
+ )
595
+
596
+ if latents is None:
597
+ if same_frame_noise:
598
+ shape = (batch_size, num_channels_latents, 1, height // self.vae_scale_factor, width // self.vae_scale_factor)
599
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
600
+ latents = latents.repeat(1, 1, video_length, 1, 1)
601
+ else:
602
+ shape = (batch_size, num_channels_latents, video_length, height // self.vae_scale_factor, width // self.vae_scale_factor)
603
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
604
+ else:
605
+ shape = (batch_size, num_channels_latents, video_length, height // self.vae_scale_factor, width // self.vae_scale_factor)
606
+ if latents.shape != shape:
607
+ raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
608
+ latents = latents.to(device)
609
+
610
+ # scale the initial noise by the standard deviation required by the scheduler
611
+ latents = latents * self.scheduler.init_noise_sigma
612
+ return latents
613
+
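With `same_frame_noise=True` the initial noise is sampled once for a single frame and repeated across the whole clip, so every frame starts the reverse process from identical noise. A small self-contained sketch of that initialisation (a plain scalar stands in for the scheduler's `init_noise_sigma`):

```python
# Sketch of the shared-noise initialisation used when same_frame_noise=True.
import torch

def init_latents(batch=1, channels=4, frames=15, h=64, w=64,
                 same_frame_noise=True, init_noise_sigma=1.0, generator=None):
    if same_frame_noise:
        noise = torch.randn((batch, channels, 1, h, w), generator=generator)
        latents = noise.repeat(1, 1, frames, 1, 1)   # identical noise for every frame
    else:
        latents = torch.randn((batch, channels, frames, h, w), generator=generator)
    return latents * init_noise_sigma

lat = init_latents(generator=torch.Generator().manual_seed(0))
assert torch.equal(lat[:, :, 0], lat[:, :, 7])  # all frames share the same starting noise
```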
614
+ def _default_height_width(self, height, width, image):
615
+ # NOTE: It is possible that a list of images have different
616
+ # dimensions for each image, so just checking the first image
617
+ # is not _exactly_ correct, but it is simple.
618
+ while isinstance(image, list):
619
+ image = image[0]
620
+
621
+ if height is None:
622
+ if isinstance(image, PIL.Image.Image):
623
+ height = image.height
624
+ elif isinstance(image, torch.Tensor):
625
+ height = image.shape[3]
626
+
627
+ height = (height // 8) * 8 # round down to nearest multiple of 8
628
+
629
+ if width is None:
630
+ if isinstance(image, PIL.Image.Image):
631
+ width = image.width
632
+ elif isinstance(image, torch.Tensor):
633
+ width = image.shape[2]
634
+
635
+ width = (width // 8) * 8 # round down to nearest multiple of 8
636
+
637
+ return height, width
638
+
639
+ # override DiffusionPipeline
640
+ def save_pretrained(
641
+ self,
642
+ save_directory: Union[str, os.PathLike],
643
+ safe_serialization: bool = False,
644
+ variant: Optional[str] = None,
645
+ ):
646
+ if isinstance(self.controlnet, ControlNetModel3D):
647
+ super().save_pretrained(save_directory, safe_serialization, variant)
648
+ else:
649
+ raise NotImplementedError("Currently, `save_pretrained()` is not implemented for Multi-ControlNet.")
650
+
651
+ def get_alpha_prev(self, timestep):
652
+ prev_timestep = timestep - self.scheduler.config.num_train_timesteps // self.scheduler.num_inference_steps
653
+ alpha_prod_t_prev = self.scheduler.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.scheduler.final_alpha_cumprod
654
+ return alpha_prod_t_prev
655
+
656
+ def get_slide_window_indices(self, video_length, window_size):
657
+ assert window_size >= 3
658
+ key_frame_indices = np.arange(0, video_length, window_size-1).tolist()
659
+
660
+ # Append last index
661
+ if key_frame_indices[-1] != (video_length-1):
662
+ key_frame_indices.append(video_length-1)
663
+
664
+ slices = np.split(np.arange(video_length), key_frame_indices)
665
+ inter_frame_list = []
666
+ for s in slices:
667
+ if len(s) < 2:
668
+ continue
669
+ inter_frame_list.append(s[1:].tolist())
670
+ return key_frame_indices, inter_frame_list
671
+
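`get_slide_window_indices` splits a long clip into key frames spaced `window_size - 1` apart (plus the last frame) and groups the remaining indices per window. A quick worked example that reproduces the same logic for a 15-frame clip with `window_size=8`:

```python
# Worked example of the sliding-window split used for long videos. For 15 frames
# with window_size=8 this yields key frames [0, 7, 14] and the in-between groups
# [[1..6], [8..13]]; each group is later denoised together with its two key frames.
import numpy as np

def slide_window_indices(video_length, window_size):
    key = np.arange(0, video_length, window_size - 1).tolist()
    if key[-1] != video_length - 1:
        key.append(video_length - 1)
    inter = [s[1:].tolist() for s in np.split(np.arange(video_length), key) if len(s) >= 2]
    return key, inter

print(slide_window_indices(15, 8))
# ([0, 7, 14], [[1, 2, 3, 4, 5, 6], [8, 9, 10, 11, 12, 13]])
```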
672
+ @torch.no_grad()
673
+ def __call__(
674
+ self,
675
+ prompt: Union[str, List[str]] = None,
676
+ video_length: Optional[int] = 1,
677
+ frames: Union[List[torch.FloatTensor], List[PIL.Image.Image], List[List[torch.FloatTensor]], List[List[PIL.Image.Image]]] = None,
678
+ height: Optional[int] = None,
679
+ width: Optional[int] = None,
680
+ num_inference_steps: int = 50,
681
+ guidance_scale: float = 7.5,
682
+ negative_prompt: Optional[Union[str, List[str]]] = None,
683
+ num_videos_per_prompt: Optional[int] = 1,
684
+ eta: float = 0.0,
685
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
686
+ latents: Optional[torch.FloatTensor] = None,
687
+ prompt_embeds: Optional[torch.FloatTensor] = None,
688
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
689
+ output_type: Optional[str] = "tensor",
690
+ return_dict: bool = True,
691
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
692
+ callback_steps: int = 1,
693
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
694
+ controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
695
+ smooth_steps: List = [19, 20],
696
+ **kwargs,
697
+ ):
698
+ r"""
699
+ Function invoked when calling the pipeline for generation.
700
+
701
+ Args:
702
+ prompt (`str` or `List[str]`, *optional*):
703
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
704
+ instead.
705
+ frames (`List[torch.FloatTensor]`, `List[PIL.Image.Image]`,
706
+ `List[List[torch.FloatTensor]]`, or `List[List[PIL.Image.Image]]`):
707
+ The ControlVideo input condition. ControlVideo uses this input condition to generate guidance for the UNet. If
708
+ the type is specified as `torch.FloatTensor`, it is passed to ControlNet as is. `PIL.Image.Image` can
709
+ also be accepted as an image. The dimensions of the output image defaults to `image`'s dimensions. If
710
+ height and/or width are passed, `image` is resized according to them. If multiple ControlNets are
711
+ specified in init, images must be passed as a list such that each element of the list can be correctly
712
+ batched for input to a single controlnet.
713
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
714
+ The height in pixels of the generated image.
715
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
716
+ The width in pixels of the generated image.
717
+ num_inference_steps (`int`, *optional*, defaults to 50):
718
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
719
+ expense of slower inference.
720
+ guidance_scale (`float`, *optional*, defaults to 7.5):
721
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
722
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
723
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
724
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
725
+ usually at the expense of lower image quality.
726
+ negative_prompt (`str` or `List[str]`, *optional*):
727
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
728
+ `negative_prompt_embeds` instead.
729
+ Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`).
730
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
731
+ The number of images to generate per prompt.
732
+ eta (`float`, *optional*, defaults to 0.0):
733
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
734
+ [`schedulers.DDIMScheduler`], will be ignored for others.
735
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
736
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
737
+ to make generation deterministic.
738
+ latents (`torch.FloatTensor`, *optional*):
739
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
740
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
741
+ tensor will be generated by sampling using the supplied random `generator`.
742
+ prompt_embeds (`torch.FloatTensor`, *optional*):
743
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
744
+ provided, text embeddings will be generated from `prompt` input argument.
745
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
746
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
747
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
748
+ argument.
749
+ output_type (`str`, *optional*, defaults to `"tensor"`):
750
+ The output format of the generated video. Choose `"tensor"` to get a `torch.Tensor`,
751
+ or any other value to get a `np.ndarray`.
752
+ return_dict (`bool`, *optional*, defaults to `True`):
753
+ Whether or not to return a [`ControlVideoPipelineOutput`] instead of a
754
+ plain tuple.
755
+ callback (`Callable`, *optional*):
756
+ A function that will be called every `callback_steps` steps during inference. The function will be
757
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
758
+ callback_steps (`int`, *optional*, defaults to 1):
759
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
760
+ called at every step.
761
+ cross_attention_kwargs (`dict`, *optional*):
762
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
763
+ `self.processor` in
764
+ [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
765
+ controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0):
766
+ The outputs of the controlnet are multiplied by `controlnet_conditioning_scale` before they are added
767
+ to the residual in the original unet. If multiple ControlNets are specified in init, you can set the
768
+ corresponding scale as a list.
769
+ smooth_steps (`List[int]`):
770
+ Apply the RIFE smoother to the predicted RGB frames at these denoising steps.
771
+
772
+ Examples:
773
+
774
+ Returns:
775
+ [`ControlVideoPipelineOutput`] or `tuple`:
776
+ [`ControlVideoPipelineOutput`] if `return_dict` is `True`, otherwise the generated video itself.
777
+ When a plain video is returned, it is a `torch.Tensor` (for `output_type="tensor"`) or a
778
+ `np.ndarray` of shape `b x c x f x h x w` with values in `[0, 1]`, matching the `videos`
779
+ field of [`ControlVideoPipelineOutput`].
780
+ """
781
+ # 0. Default height and width to unet
782
+ height, width = self._default_height_width(height, width, frames)
783
+
784
+ # 1. Check inputs. Raise error if not correct
785
+ self.check_inputs(
786
+ prompt,
787
+ height,
788
+ width,
789
+ callback_steps,
790
+ negative_prompt,
791
+ prompt_embeds,
792
+ negative_prompt_embeds,
793
+ controlnet_conditioning_scale,
794
+ )
795
+
796
+ # 2. Define call parameters
797
+ if prompt is not None and isinstance(prompt, str):
798
+ batch_size = 1
799
+ elif prompt is not None and isinstance(prompt, list):
800
+ batch_size = len(prompt)
801
+ else:
802
+ batch_size = prompt_embeds.shape[0]
803
+
804
+ device = self._execution_device
805
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
806
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
807
+ # corresponds to doing no classifier free guidance.
808
+ do_classifier_free_guidance = guidance_scale > 1.0
809
+
810
+ if isinstance(self.controlnet, MultiControlNetModel3D) and isinstance(controlnet_conditioning_scale, float):
811
+ controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(self.controlnet.nets)
812
+
813
+ # 3. Encode input prompt
814
+ prompt_embeds = self._encode_prompt(
815
+ prompt,
816
+ device,
817
+ num_videos_per_prompt,
818
+ do_classifier_free_guidance,
819
+ negative_prompt,
820
+ prompt_embeds=prompt_embeds,
821
+ negative_prompt_embeds=negative_prompt_embeds,
822
+ )
823
+
824
+ # 4. Prepare image
825
+ if isinstance(self.controlnet, ControlNetModel3D):
826
+ images = []
827
+ for i_img in frames:
828
+ i_img = self.prepare_image(
829
+ image=i_img,
830
+ width=width,
831
+ height=height,
832
+ batch_size=batch_size * num_videos_per_prompt,
833
+ num_videos_per_prompt=num_videos_per_prompt,
834
+ device=device,
835
+ dtype=self.controlnet.dtype,
836
+ do_classifier_free_guidance=do_classifier_free_guidance,
837
+ )
838
+ images.append(i_img)
839
+ frames = torch.stack(images, dim=2) # b x c x f x h x w
840
+ elif isinstance(self.controlnet, MultiControlNetModel3D):
841
+ images = []
842
+ for i_img in frames:
843
+ i_images = []
844
+ for ii_img in i_img:
845
+ ii_img = self.prepare_image(
846
+ image=ii_img,
847
+ width=width,
848
+ height=height,
849
+ batch_size=batch_size * num_videos_per_prompt,
850
+ num_videos_per_prompt=num_videos_per_prompt,
851
+ device=device,
852
+ dtype=self.controlnet.dtype,
853
+ do_classifier_free_guidance=do_classifier_free_guidance,
854
+ )
855
+
856
+ i_images.append(ii_img)
857
+ images.append(torch.stack(i_images, dim=2))
858
+ frames = images
859
+ else:
860
+ assert False
861
+
862
+ # 5. Prepare timesteps
863
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
864
+ timesteps = self.scheduler.timesteps
865
+
866
+ # 6. Prepare latent variables
867
+ num_channels_latents = self.unet.in_channels
868
+ latents = self.prepare_latents(
869
+ batch_size * num_videos_per_prompt,
870
+ num_channels_latents,
871
+ video_length,
872
+ height,
873
+ width,
874
+ prompt_embeds.dtype,
875
+ device,
876
+ generator,
877
+ latents,
878
+ same_frame_noise=True,
879
+ )
880
+
881
+ # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
882
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
883
+
884
+
885
+ # Prepare video indices if performing smoothing
886
+ if len(smooth_steps) > 0:
887
+ video_indices = np.arange(video_length)
888
+ zero_indices = video_indices[0::2]
889
+ one_indices = video_indices[1::2]
890
+
891
+ # 8. Denoising loop
892
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
893
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
894
+ for i, t in enumerate(timesteps):
895
+ torch.cuda.empty_cache()
896
+
897
+ # expand the latents if we are doing classifier free guidance
898
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
899
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
900
+
901
+ # controlnet(s) inference
902
+ down_block_res_samples, mid_block_res_sample = self.controlnet(
903
+ latent_model_input,
904
+ t,
905
+ encoder_hidden_states=prompt_embeds,
906
+ controlnet_cond=frames,
907
+ conditioning_scale=controlnet_conditioning_scale,
908
+ return_dict=False,
909
+ )
910
+ # predict the noise residual
911
+ noise_pred = self.unet(
912
+ latent_model_input,
913
+ t,
914
+ encoder_hidden_states=prompt_embeds,
915
+ cross_attention_kwargs=cross_attention_kwargs,
916
+ down_block_additional_residuals=down_block_res_samples,
917
+ mid_block_additional_residual=mid_block_res_sample,
918
+ ).sample
919
+
920
+ # perform guidance
921
+ if do_classifier_free_guidance:
922
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
923
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
924
+
925
+ # compute the previous noisy sample x_t -> x_t-1
926
+ step_dict = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs)
927
+ latents = step_dict.prev_sample
928
+ pred_original_sample = step_dict.pred_original_sample
929
+
930
+ # Smooth videos
931
+ if (num_inference_steps - i) in smooth_steps:
932
+ pred_video = self.decode_latents(pred_original_sample, return_tensor=True) # b c f h w
933
+ pred_video = rearrange(pred_video, "b c f h w -> b f c h w")
934
+ for b_i in range(len(pred_video)):
935
+ if i % 2 == 0:
936
+ for v_i in range(len(zero_indices)-1):
937
+ s_frame = pred_video[b_i][zero_indices[v_i]].unsqueeze(0)
938
+ e_frame = pred_video[b_i][zero_indices[v_i+1]].unsqueeze(0)
939
+ pred_video[b_i][one_indices[v_i]] = self.interpolater.inference(s_frame, e_frame)[0]
940
+ else:
941
+ if video_length % 2 == 1:
942
+ tmp_one_indices = [0] + one_indices.tolist() + [video_length-1]
943
+ else:
944
+ tmp_one_indices = [0] + one_indices.tolist()
945
+
946
+ for v_i in range(len(tmp_one_indices)-1):
947
+ s_frame = pred_video[b_i][tmp_one_indices[v_i]].unsqueeze(0)
948
+ e_frame = pred_video[b_i][tmp_one_indices[v_i+1]].unsqueeze(0)
949
+ pred_video[b_i][zero_indices[v_i]] = self.interpolater.inference(s_frame, e_frame)[0]
950
+ pred_video = rearrange(pred_video, "b f c h w -> (b f) c h w")
951
+ pred_video = 2.0 * pred_video - 1.0
952
+ # ori_pred_original_sample = pred_original_sample
953
+ pred_original_sample = self.vae.encode(pred_video).latent_dist.sample(generator)
954
+ pred_original_sample *= self.vae.config.scaling_factor
955
+ pred_original_sample = rearrange(pred_original_sample, "(b f) c h w -> b c f h w", f=video_length)
956
+
957
+ # predict xt-1 with smoothed x0
958
+ alpha_prod_t_prev = self.get_alpha_prev(t)
959
+ # preserve more details
960
+ # pred_original_sample = ori_pred_original_sample * alpha_prod_t_prev + (1 - alpha_prod_t_prev) * pred_original_sample
961
+ # compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
962
+ pred_sample_direction = (1 - alpha_prod_t_prev) ** (0.5) * noise_pred
963
+ # compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
964
+ latents = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction
965
+
966
+
967
+ # call the callback, if provided
968
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
969
+ progress_bar.update()
970
+ if callback is not None and i % callback_steps == 0:
971
+ callback(i, t, latents)
972
+
973
+ # If we do sequential model offloading, let's offload unet and controlnet
974
+ # manually for max memory savings
975
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
976
+ self.unet.to("cpu")
977
+ self.controlnet.to("cpu")
978
+ torch.cuda.empty_cache()
979
+ # Post-processing
980
+ video = self.decode_latents(latents)
981
+
982
+ # Convert to tensor
983
+ if output_type == "tensor":
984
+ video = torch.from_numpy(video)
985
+
986
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
987
+ self.final_offload_hook.offload()
988
+
989
+ if not return_dict:
990
+ return video
991
+
992
+ return ControlVideoPipelineOutput(videos=video)
993
+
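The smoothing block inside the denoising loop above alternates which frames are re-predicted: on even steps the odd-indexed frames are interpolated from their even neighbours, on odd steps the roles swap, so over two smoothing steps every frame is replaced by an interpolation of its neighbours. A toy sketch of that index pattern, with a midpoint average standing in for the RIFE interpolater:

```python
# Sketch of the interleaved smoother's index pattern (not the real interpolater).
import numpy as np

def smooth_once(frames, step):
    frames = list(frames)
    idx = np.arange(len(frames))
    even, odd = idx[0::2], idx[1::2]
    if step % 2 == 0:
        # odd frames re-predicted from their even neighbours
        for i in range(len(even) - 1):
            frames[odd[i]] = 0.5 * (frames[even[i]] + frames[even[i + 1]])
    else:
        # even frames re-predicted from their odd neighbours (plus clip boundaries)
        anchors = [0] + odd.tolist() + ([len(frames) - 1] if len(frames) % 2 == 1 else [])
        for i in range(len(anchors) - 1):
            frames[even[i]] = 0.5 * (frames[anchors[i]] + frames[anchors[i + 1]])
    return frames

clip = [float(i) for i in range(15)]
print(smooth_once(smooth_once(clip, 0), 1))  # every frame touched after two passes
```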
994
+ @torch.no_grad()
995
+ def generate_long_video(
996
+ self,
997
+ prompt: Union[str, List[str]] = None,
998
+ video_length: Optional[int] = 1,
999
+ frames: Union[List[torch.FloatTensor], List[PIL.Image.Image], List[List[torch.FloatTensor]], List[List[PIL.Image.Image]]] = None,
1000
+ height: Optional[int] = None,
1001
+ width: Optional[int] = None,
1002
+ num_inference_steps: int = 50,
1003
+ guidance_scale: float = 7.5,
1004
+ negative_prompt: Optional[Union[str, List[str]]] = None,
1005
+ num_videos_per_prompt: Optional[int] = 1,
1006
+ eta: float = 0.0,
1007
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
1008
+ latents: Optional[torch.FloatTensor] = None,
1009
+ prompt_embeds: Optional[torch.FloatTensor] = None,
1010
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
1011
+ output_type: Optional[str] = "tensor",
1012
+ return_dict: bool = True,
1013
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
1014
+ callback_steps: int = 1,
1015
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
1016
+ controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
1017
+ smooth_steps: List = [19, 20],
1018
+ window_size: int = 8,
1019
+ **kwargs,
1020
+ ):
1021
+ r"""
1022
+ Function invoked when calling the pipeline for generation.
1023
+
1024
+ Args:
1025
+ prompt (`str` or `List[str]`, *optional*):
1026
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
1027
+ instead.
1028
+ frames (`List[torch.FloatTensor]`, `List[PIL.Image.Image]`,
1029
+ `List[List[torch.FloatTensor]]`, or `List[List[PIL.Image.Image]]`):
1030
+ The ControlVideo input condition. ControlVideo uses this input condition to generate guidance for the UNet. If
1031
+ the type is specified as `torch.FloatTensor`, it is passed to ControlNet as is. `PIL.Image.Image` can
1032
+ also be accepted as an image. The dimensions of the output image defaults to `image`'s dimensions. If
1033
+ height and/or width are passed, `image` is resized according to them. If multiple ControlNets are
1034
+ specified in init, images must be passed as a list such that each element of the list can be correctly
1035
+ batched for input to a single controlnet.
1036
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
1037
+ The height in pixels of the generated image.
1038
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
1039
+ The width in pixels of the generated image.
1040
+ num_inference_steps (`int`, *optional*, defaults to 50):
1041
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
1042
+ expense of slower inference.
1043
+ guidance_scale (`float`, *optional*, defaults to 7.5):
1044
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
1045
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
1046
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1047
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
1048
+ usually at the expense of lower image quality.
1049
+ negative_prompt (`str` or `List[str]`, *optional*):
1050
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
1051
+ `negative_prompt_embeds` instead.
1052
+ Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`).
1053
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
1054
+ The number of images to generate per prompt.
1055
+ eta (`float`, *optional*, defaults to 0.0):
1056
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
1057
+ [`schedulers.DDIMScheduler`], will be ignored for others.
1058
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
1059
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
1060
+ to make generation deterministic.
1061
+ latents (`torch.FloatTensor`, *optional*):
1062
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
1063
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
1064
+ tensor will ge generated by sampling using the supplied random `generator`.
1065
+ prompt_embeds (`torch.FloatTensor`, *optional*):
1066
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
1067
+ provided, text embeddings will be generated from `prompt` input argument.
1068
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
1069
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
1070
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
1071
+ argument.
1072
+ output_type (`str`, *optional*, defaults to `"tensor"`):
1073
+ The output format of the generated video. Choose `"tensor"` to get a `torch.Tensor`,
1074
+ or any other value to get a `np.ndarray`.
1075
+ return_dict (`bool`, *optional*, defaults to `True`):
1076
+ Whether or not to return a [`ControlVideoPipelineOutput`] instead of a
1077
+ plain tuple.
1078
+ callback (`Callable`, *optional*):
1079
+ A function that will be called every `callback_steps` steps during inference. The function will be
1080
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
1081
+ callback_steps (`int`, *optional*, defaults to 1):
1082
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
1083
+ called at every step.
1084
+ cross_attention_kwargs (`dict`, *optional*):
1085
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
1086
+ `self.processor` in
1087
+ [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
1088
+ controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0):
1089
+ The outputs of the controlnet are multiplied by `controlnet_conditioning_scale` before they are added
1090
+ to the residual in the original unet. If multiple ControlNets are specified in init, you can set the
1091
+ corresponding scale as a list.
1092
+ smooth_steps (`List[int]`):
1093
+ Apply the RIFE smoother to the predicted RGB frames at these denoising steps.
1094
+ window_size (`int`):
1095
+ The length of each short clip.
1096
+ Examples:
1097
+
1098
+ Returns:
1099
+ [`ControlVideoPipelineOutput`] or `tuple`:
1100
+ [`ControlVideoPipelineOutput`] if `return_dict` is `True`, otherwise the generated video itself.
1101
+ When a plain video is returned, it is a `torch.Tensor` (for `output_type="tensor"`) or a
1102
+ `np.ndarray` of shape `b x c x f x h x w` with values in `[0, 1]`, matching the `videos`
1103
+ field of [`ControlVideoPipelineOutput`].
1104
+ """
1105
+ # 0. Default height and width to unet
1106
+ height, width = self._default_height_width(height, width, frames)
1107
+
1108
+ # 1. Check inputs. Raise error if not correct
1109
+ self.check_inputs(
1110
+ prompt,
1111
+ height,
1112
+ width,
1113
+ callback_steps,
1114
+ negative_prompt,
1115
+ prompt_embeds,
1116
+ negative_prompt_embeds,
1117
+ controlnet_conditioning_scale,
1118
+ )
1119
+
1120
+ # 2. Define call parameters
1121
+ if prompt is not None and isinstance(prompt, str):
1122
+ batch_size = 1
1123
+ elif prompt is not None and isinstance(prompt, list):
1124
+ batch_size = len(prompt)
1125
+ else:
1126
+ batch_size = prompt_embeds.shape[0]
1127
+
1128
+ device = self._execution_device
1129
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
1130
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
1131
+ # corresponds to doing no classifier free guidance.
1132
+ do_classifier_free_guidance = guidance_scale > 1.0
1133
+
1134
+ if isinstance(self.controlnet, MultiControlNetModel3D) and isinstance(controlnet_conditioning_scale, float):
1135
+ controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(self.controlnet.nets)
1136
+
1137
+ # 3. Encode input prompt
1138
+ prompt_embeds = self._encode_prompt(
1139
+ prompt,
1140
+ device,
1141
+ num_videos_per_prompt,
1142
+ do_classifier_free_guidance,
1143
+ negative_prompt,
1144
+ prompt_embeds=prompt_embeds,
1145
+ negative_prompt_embeds=negative_prompt_embeds,
1146
+ )
1147
+
1148
+ # 4. Prepare image
1149
+ if isinstance(self.controlnet, ControlNetModel3D):
1150
+ images = []
1151
+ for i_img in frames:
1152
+ i_img = self.prepare_image(
1153
+ image=i_img,
1154
+ width=width,
1155
+ height=height,
1156
+ batch_size=batch_size * num_videos_per_prompt,
1157
+ num_videos_per_prompt=num_videos_per_prompt,
1158
+ device=device,
1159
+ dtype=self.controlnet.dtype,
1160
+ do_classifier_free_guidance=do_classifier_free_guidance,
1161
+ )
1162
+ images.append(i_img)
1163
+ frames = torch.stack(images, dim=2) # b x c x f x h x w
1164
+ elif isinstance(self.controlnet, MultiControlNetModel3D):
1165
+ images = []
1166
+ for i_img in frames:
1167
+ i_images = []
1168
+ for ii_img in i_img:
1169
+ ii_img = self.prepare_image(
1170
+ image=ii_img,
1171
+ width=width,
1172
+ height=height,
1173
+ batch_size=batch_size * num_videos_per_prompt,
1174
+ num_videos_per_prompt=num_videos_per_prompt,
1175
+ device=device,
1176
+ dtype=self.controlnet.dtype,
1177
+ do_classifier_free_guidance=do_classifier_free_guidance,
1178
+ )
1179
+
1180
+ i_images.append(ii_img)
1181
+ images.append(torch.stack(i_images, dim=2))
1182
+ frames = images
1183
+ else:
1184
+ assert False
1185
+
1186
+ # 5. Prepare timesteps
1187
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
1188
+ timesteps = self.scheduler.timesteps
1189
+
1190
+ # 6. Prepare latent variables
1191
+ num_channels_latents = self.unet.in_channels
1192
+ latents = self.prepare_latents(
1193
+ batch_size * num_videos_per_prompt,
1194
+ num_channels_latents,
1195
+ video_length,
1196
+ height,
1197
+ width,
1198
+ prompt_embeds.dtype,
1199
+ device,
1200
+ generator,
1201
+ latents,
1202
+ same_frame_noise=True,
1203
+ )
1204
+
1205
+ # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
1206
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
1207
+
1208
+ # Prepare indices of key frames and interval frames
1209
+ key_frame_indices, inter_frame_list = self.get_slide_window_indices(video_length, window_size)
1210
+
1211
+ # Prepare video indices if performing smoothing
1212
+ if len(smooth_steps) > 0:
1213
+ video_indices = np.arange(video_length)
1214
+ zero_indices = video_indices[0::2]
1215
+ one_indices = video_indices[1::2]
1216
+
1217
+ # 8. Denoising loop
1218
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
1219
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
1220
+ for i, t in enumerate(timesteps):
1221
+ torch.cuda.empty_cache()
1222
+ # expand the latents if we are doing classifier free guidance
1223
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
1224
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
1225
+ noise_pred = torch.zeros_like(latents)
1226
+ pred_original_sample = torch.zeros_like(latents)
1227
+
1228
+ # 8.1 Key frames
1229
+ # controlnet(s) inference
1230
+ key_down_block_res_samples, key_mid_block_res_sample = self.controlnet(
1231
+ latent_model_input[:, :, key_frame_indices],
1232
+ t,
1233
+ encoder_hidden_states=prompt_embeds,
1234
+ controlnet_cond=frames[:, :, key_frame_indices],
1235
+ conditioning_scale=controlnet_conditioning_scale,
1236
+ return_dict=False,
1237
+ )
1238
+ # predict the noise residual
1239
+ key_noise_pred = self.unet(
1240
+ latent_model_input[:, :, key_frame_indices],
1241
+ t,
1242
+ encoder_hidden_states=prompt_embeds,
1243
+ cross_attention_kwargs=cross_attention_kwargs,
1244
+ down_block_additional_residuals=key_down_block_res_samples,
1245
+ mid_block_additional_residual=key_mid_block_res_sample,
1246
+ inter_frame=False,
1247
+ ).sample
1248
+
1249
+ # perform guidance
1250
+ if do_classifier_free_guidance:
1251
+ noise_pred_uncond, noise_pred_text = key_noise_pred.chunk(2)
1252
+ noise_pred[:, :, key_frame_indices] = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
1253
+
1254
+ # compute the previous noisy sample x_t -> x_t-1
1255
+ key_step_dict = self.scheduler.step(noise_pred[:, :, key_frame_indices], t, latents[:, :, key_frame_indices], **extra_step_kwargs)
1256
+ latents[:, :, key_frame_indices] = key_step_dict.prev_sample
1257
+ pred_original_sample[:, :, key_frame_indices] = key_step_dict.pred_original_sample
1258
+
1259
+ # 8.2 compute interval frames
1260
+ for f_i, frame_ids in enumerate(inter_frame_list):
1261
+ input_frame_ids = key_frame_indices[f_i:f_i+2] + frame_ids
1262
+ # controlnet(s) inference
1263
+ inter_down_block_res_samples, inter_mid_block_res_sample = self.controlnet(
1264
+ latent_model_input[:, :, input_frame_ids],
1265
+ t,
1266
+ encoder_hidden_states=prompt_embeds,
1267
+ controlnet_cond=frames[:, :, input_frame_ids],
1268
+ conditioning_scale=controlnet_conditioning_scale,
1269
+ return_dict=False,
1270
+ )
1271
+ # predict the noise residual
1272
+ inter_noise_pred = self.unet(
1273
+ latent_model_input[:, :, input_frame_ids],
1274
+ t,
1275
+ encoder_hidden_states=prompt_embeds,
1276
+ cross_attention_kwargs=cross_attention_kwargs,
1277
+ down_block_additional_residuals=inter_down_block_res_samples,
1278
+ mid_block_additional_residual=inter_mid_block_res_sample,
1279
+ inter_frame=True,
1280
+ ).sample
1281
+
1282
+ # perform guidance
1283
+ if do_classifier_free_guidance:
1284
+ noise_pred_uncond, noise_pred_text = inter_noise_pred[:, :, 2:].chunk(2)
1285
+ noise_pred[:, :, frame_ids] = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
1286
+
1287
+ # compute the previous noisy sample x_t -> x_t-1
1288
+ step_dict = self.scheduler.step(noise_pred[:, :, frame_ids], t, latents[:, :, frame_ids], **extra_step_kwargs)
1289
+ latents[:, :, frame_ids] = step_dict.prev_sample
1290
+ pred_original_sample[:, :, frame_ids] = step_dict.pred_original_sample
1291
+
1292
+ # Smooth videos
1293
+ if (num_inference_steps - i) in smooth_steps:
1294
+ pred_video = self.decode_latents(pred_original_sample, return_tensor=True) # b c f h w
1295
+ pred_video = rearrange(pred_video, "b c f h w -> b f c h w")
1296
+ for b_i in range(len(pred_video)):
1297
+ if i % 2 == 0:
1298
+ for v_i in range(len(zero_indices)-1):
1299
+ s_frame = pred_video[b_i][zero_indices[v_i]].unsqueeze(0)
1300
+ e_frame = pred_video[b_i][zero_indices[v_i+1]].unsqueeze(0)
1301
+ pred_video[b_i][one_indices[v_i]] = self.interpolater.inference(s_frame, e_frame)[0]
1302
+ else:
1303
+ if video_length % 2 == 1:
1304
+ tmp_one_indices = [0] + one_indices.tolist() + [video_length-1]
1305
+ else:
1306
+ tmp_one_indices = [0] + one_indices.tolist()
1307
+ for v_i in range(len(tmp_one_indices)-1):
1308
+ s_frame = pred_video[b_i][tmp_one_indices[v_i]].unsqueeze(0)
1309
+ e_frame = pred_video[b_i][tmp_one_indices[v_i+1]].unsqueeze(0)
1310
+ pred_video[b_i][zero_indices[v_i]] = self.interpolater.inference(s_frame, e_frame)[0]
1311
+ pred_video = rearrange(pred_video, "b f c h w -> (b f) c h w")
1312
+ pred_video = 2.0 * pred_video - 1.0
1313
+ for v_i in range(len(pred_video)):
1314
+ pred_original_sample[:, :, v_i] = self.vae.encode(pred_video[v_i:v_i+1]).latent_dist.sample(generator)
1315
+ pred_original_sample[:, :, v_i] *= self.vae.config.scaling_factor
1316
+
1317
+
1318
+ # predict xt-1 with smoothed x0
1319
+ alpha_prod_t_prev =self.get_alpha_prev(t)
1320
+ # preserve more details
1321
+ pred_sample_direction = (1 - alpha_prod_t_prev) ** (0.5) * noise_pred
1322
+ # compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
1323
+ latents = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction
1324
+
1325
+
1326
+ # call the callback, if provided
1327
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
1328
+ progress_bar.update()
1329
+ if callback is not None and i % callback_steps == 0:
1330
+ callback(i, t, latents)
1331
+
1332
+ # If we do sequential model offloading, let's offload unet and controlnet
1333
+ # manually for max memory savings
1334
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
1335
+ self.unet.to("cpu")
1336
+ self.controlnet.to("cpu")
1337
+ torch.cuda.empty_cache()
1338
+ # Post-processing
1339
+ video = self.decode_latents(latents)
1340
+
1341
+ # Convert to tensor
1342
+ if output_type == "tensor":
1343
+ video = torch.from_numpy(video)
1344
+
1345
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
1346
+ self.final_offload_hook.offload()
1347
+
1348
+ if not return_dict:
1349
+ return video
1350
+
1351
+ return ControlVideoPipelineOutput(videos=video)
models/resnet.py ADDED
@@ -0,0 +1,217 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/resnet.py
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+
7
+ from einops import rearrange
8
+
9
+
10
+ class InflatedConv3d(nn.Conv2d):
11
+ def forward(self, x):
12
+ video_length = x.shape[2]
13
+
14
+ x = rearrange(x, "b c f h w -> (b f) c h w")
15
+ x = super().forward(x)
16
+ x = rearrange(x, "(b f) c h w -> b c f h w", f=video_length)
17
+
18
+ return x
19
+
20
+ class TemporalConv1d(nn.Conv1d):
21
+ def forward(self, x):
22
+ b, c, f, h, w = x.shape
23
+ y = rearrange(x.clone(), "b c f h w -> (b h w) c f")
24
+ y = super().forward(y)
25
+ y = rearrange(y, "(b h w) c f -> b c f h w", b=b, h=h, w=w)
26
+ return y
27
+
28
+
29
+ class Upsample3D(nn.Module):
30
+ def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"):
31
+ super().__init__()
32
+ self.channels = channels
33
+ self.out_channels = out_channels or channels
34
+ self.use_conv = use_conv
35
+ self.use_conv_transpose = use_conv_transpose
36
+ self.name = name
37
+
38
+ conv = None
39
+ if use_conv_transpose:
40
+ raise NotImplementedError
41
+ elif use_conv:
42
+ conv = InflatedConv3d(self.channels, self.out_channels, 3, padding=1)
43
+
44
+ if name == "conv":
45
+ self.conv = conv
46
+ else:
47
+ self.Conv2d_0 = conv
48
+
49
+ def forward(self, hidden_states, output_size=None):
50
+ assert hidden_states.shape[1] == self.channels
51
+
52
+ if self.use_conv_transpose:
53
+ raise NotImplementedError
54
+
55
+ # Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16
56
+ dtype = hidden_states.dtype
57
+ if dtype == torch.bfloat16:
58
+ hidden_states = hidden_states.to(torch.float32)
59
+
60
+ # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
61
+ if hidden_states.shape[0] >= 64:
62
+ hidden_states = hidden_states.contiguous()
63
+
64
+ # if `output_size` is passed we force the interpolation output
65
+ # size and do not make use of `scale_factor=2`
66
+ if output_size is None:
67
+ hidden_states = F.interpolate(hidden_states, scale_factor=[1.0, 2.0, 2.0], mode="nearest")
68
+ else:
69
+ hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest")
70
+
71
+ # If the input is bfloat16, we cast back to bfloat16
72
+ if dtype == torch.bfloat16:
73
+ hidden_states = hidden_states.to(dtype)
74
+
75
+ if self.use_conv:
76
+ if self.name == "conv":
77
+ hidden_states = self.conv(hidden_states)
78
+ else:
79
+ hidden_states = self.Conv2d_0(hidden_states)
80
+
81
+ return hidden_states
82
+
83
+
84
+ class Downsample3D(nn.Module):
85
+ def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"):
86
+ super().__init__()
87
+ self.channels = channels
88
+ self.out_channels = out_channels or channels
89
+ self.use_conv = use_conv
90
+ self.padding = padding
91
+ stride = 2
92
+ self.name = name
93
+
94
+ if use_conv:
95
+ conv = InflatedConv3d(self.channels, self.out_channels, 3, stride=stride, padding=padding)
96
+ else:
97
+ raise NotImplementedError
98
+
99
+ if name == "conv":
100
+ self.Conv2d_0 = conv
101
+ self.conv = conv
102
+ elif name == "Conv2d_0":
103
+ self.conv = conv
104
+ else:
105
+ self.conv = conv
106
+
107
+ def forward(self, hidden_states):
108
+ assert hidden_states.shape[1] == self.channels
109
+ if self.use_conv and self.padding == 0:
110
+ raise NotImplementedError
111
+
112
+ assert hidden_states.shape[1] == self.channels
113
+ hidden_states = self.conv(hidden_states)
114
+
115
+ return hidden_states
116
+
117
+
118
+ class ResnetBlock3D(nn.Module):
119
+ def __init__(
120
+ self,
121
+ *,
122
+ in_channels,
123
+ out_channels=None,
124
+ conv_shortcut=False,
125
+ dropout=0.0,
126
+ temb_channels=512,
127
+ groups=32,
128
+ groups_out=None,
129
+ pre_norm=True,
130
+ eps=1e-6,
131
+ non_linearity="swish",
132
+ time_embedding_norm="default",
133
+ output_scale_factor=1.0,
134
+ use_in_shortcut=None,
135
+ ):
136
+ super().__init__()
137
+ self.pre_norm = pre_norm
138
+ self.pre_norm = True
139
+ self.in_channels = in_channels
140
+ out_channels = in_channels if out_channels is None else out_channels
141
+ self.out_channels = out_channels
142
+ self.use_conv_shortcut = conv_shortcut
143
+ self.time_embedding_norm = time_embedding_norm
144
+ self.output_scale_factor = output_scale_factor
145
+
146
+ if groups_out is None:
147
+ groups_out = groups
148
+
149
+ self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
150
+
151
+ self.conv1 = InflatedConv3d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
152
+
153
+ if temb_channels is not None:
154
+ if self.time_embedding_norm == "default":
155
+ time_emb_proj_out_channels = out_channels
156
+ elif self.time_embedding_norm == "scale_shift":
157
+ time_emb_proj_out_channels = out_channels * 2
158
+ else:
159
+ raise ValueError(f"unknown time_embedding_norm : {self.time_embedding_norm} ")
160
+
161
+ self.time_emb_proj = torch.nn.Linear(temb_channels, time_emb_proj_out_channels)
162
+ else:
163
+ self.time_emb_proj = None
164
+
165
+ self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)
166
+ self.dropout = torch.nn.Dropout(dropout)
167
+ self.conv2 = InflatedConv3d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
168
+
169
+ if non_linearity == "swish":
170
+ self.nonlinearity = lambda x: F.silu(x)
171
+ elif non_linearity == "mish":
172
+ self.nonlinearity = Mish()
173
+ elif non_linearity == "silu":
174
+ self.nonlinearity = nn.SiLU()
175
+
176
+ self.use_in_shortcut = self.in_channels != self.out_channels if use_in_shortcut is None else use_in_shortcut
177
+
178
+ self.conv_shortcut = None
179
+ if self.use_in_shortcut:
180
+ self.conv_shortcut = InflatedConv3d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
181
+
182
+ def forward(self, input_tensor, temb):
183
+ hidden_states = input_tensor
184
+
185
+ hidden_states = self.norm1(hidden_states)
186
+ hidden_states = self.nonlinearity(hidden_states)
187
+
188
+ hidden_states = self.conv1(hidden_states)
189
+
190
+ if temb is not None:
191
+ temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None, None]
192
+
193
+ if temb is not None and self.time_embedding_norm == "default":
194
+ hidden_states = hidden_states + temb
195
+
196
+ hidden_states = self.norm2(hidden_states)
197
+
198
+ if temb is not None and self.time_embedding_norm == "scale_shift":
199
+ scale, shift = torch.chunk(temb, 2, dim=1)
200
+ hidden_states = hidden_states * (1 + scale) + shift
201
+
202
+ hidden_states = self.nonlinearity(hidden_states)
203
+
204
+ hidden_states = self.dropout(hidden_states)
205
+ hidden_states = self.conv2(hidden_states)
206
+
207
+ if self.conv_shortcut is not None:
208
+ input_tensor = self.conv_shortcut(input_tensor)
209
+
210
+ output_tensor = (input_tensor + hidden_states) / self.output_scale_factor
211
+
212
+ return output_tensor
213
+
214
+
215
+ class Mish(torch.nn.Module):
216
+ def forward(self, hidden_states):
217
+ return hidden_states * torch.tanh(torch.nn.functional.softplus(hidden_states))
models/unet.py ADDED
@@ -0,0 +1,472 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unet_2d_condition.py
2
+
3
+ from dataclasses import dataclass
4
+ from typing import List, Optional, Tuple, Union
5
+
6
+ import os
7
+ import json
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.utils.checkpoint
12
+
13
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
14
+ from diffusers import ModelMixin
15
+ from diffusers.utils import BaseOutput, logging
16
+ from diffusers.models.embeddings import TimestepEmbedding, Timesteps
17
+ from .unet_blocks import (
18
+ CrossAttnDownBlock3D,
19
+ CrossAttnUpBlock3D,
20
+ DownBlock3D,
21
+ UNetMidBlock3DCrossAttn,
22
+ UpBlock3D,
23
+ get_down_block,
24
+ get_up_block,
25
+ )
26
+ from .resnet import InflatedConv3d
27
+
28
+
29
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
30
+
31
+
32
+ @dataclass
33
+ class UNet3DConditionOutput(BaseOutput):
34
+ sample: torch.FloatTensor
35
+
36
+
37
+ class UNet3DConditionModel(ModelMixin, ConfigMixin):
38
+ _supports_gradient_checkpointing = True
39
+
40
+ @register_to_config
41
+ def __init__(
42
+ self,
43
+ sample_size: Optional[int] = None,
44
+ in_channels: int = 4,
45
+ out_channels: int = 4,
46
+ center_input_sample: bool = False,
47
+ flip_sin_to_cos: bool = True,
48
+ freq_shift: int = 0,
49
+ down_block_types: Tuple[str] = (
50
+ "CrossAttnDownBlock3D",
51
+ "CrossAttnDownBlock3D",
52
+ "CrossAttnDownBlock3D",
53
+ "DownBlock3D",
54
+ ),
55
+ mid_block_type: str = "UNetMidBlock3DCrossAttn",
56
+ up_block_types: Tuple[str] = (
57
+ "UpBlock3D",
58
+ "CrossAttnUpBlock3D",
59
+ "CrossAttnUpBlock3D",
60
+ "CrossAttnUpBlock3D"
61
+ ),
62
+ only_cross_attention: Union[bool, Tuple[bool]] = False,
63
+ block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
64
+ layers_per_block: int = 2,
65
+ downsample_padding: int = 1,
66
+ mid_block_scale_factor: float = 1,
67
+ act_fn: str = "silu",
68
+ norm_num_groups: int = 32,
69
+ norm_eps: float = 1e-5,
70
+ cross_attention_dim: int = 1280,
71
+ attention_head_dim: Union[int, Tuple[int]] = 8,
72
+ dual_cross_attention: bool = False,
73
+ use_linear_projection: bool = False,
74
+ class_embed_type: Optional[str] = None,
75
+ num_class_embeds: Optional[int] = None,
76
+ upcast_attention: bool = False,
77
+ resnet_time_scale_shift: str = "default",
78
+ ):
79
+ super().__init__()
80
+
81
+ self.sample_size = sample_size
82
+ time_embed_dim = block_out_channels[0] * 4
83
+
84
+ # input
85
+ self.conv_in = InflatedConv3d(in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1))
86
+
87
+ # time
88
+ self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
89
+ timestep_input_dim = block_out_channels[0]
90
+
91
+ self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
92
+
93
+ # class embedding
94
+ if class_embed_type is None and num_class_embeds is not None:
95
+ self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
96
+ elif class_embed_type == "timestep":
97
+ self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
98
+ elif class_embed_type == "identity":
99
+ self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
100
+ else:
101
+ self.class_embedding = None
102
+
103
+ self.down_blocks = nn.ModuleList([])
104
+ self.mid_block = None
105
+ self.up_blocks = nn.ModuleList([])
106
+
107
+ if isinstance(only_cross_attention, bool):
108
+ only_cross_attention = [only_cross_attention] * len(down_block_types)
109
+
110
+ if isinstance(attention_head_dim, int):
111
+ attention_head_dim = (attention_head_dim,) * len(down_block_types)
112
+
113
+ # down
114
+ output_channel = block_out_channels[0]
115
+ for i, down_block_type in enumerate(down_block_types):
116
+ input_channel = output_channel
117
+ output_channel = block_out_channels[i]
118
+ is_final_block = i == len(block_out_channels) - 1
119
+
120
+ down_block = get_down_block(
121
+ down_block_type,
122
+ num_layers=layers_per_block,
123
+ in_channels=input_channel,
124
+ out_channels=output_channel,
125
+ temb_channels=time_embed_dim,
126
+ add_downsample=not is_final_block,
127
+ resnet_eps=norm_eps,
128
+ resnet_act_fn=act_fn,
129
+ resnet_groups=norm_num_groups,
130
+ cross_attention_dim=cross_attention_dim,
131
+ attn_num_head_channels=attention_head_dim[i],
132
+ downsample_padding=downsample_padding,
133
+ dual_cross_attention=dual_cross_attention,
134
+ use_linear_projection=use_linear_projection,
135
+ only_cross_attention=only_cross_attention[i],
136
+ upcast_attention=upcast_attention,
137
+ resnet_time_scale_shift=resnet_time_scale_shift,
138
+ )
139
+ self.down_blocks.append(down_block)
140
+
141
+ # mid
142
+ if mid_block_type == "UNetMidBlock3DCrossAttn":
143
+ self.mid_block = UNetMidBlock3DCrossAttn(
144
+ in_channels=block_out_channels[-1],
145
+ temb_channels=time_embed_dim,
146
+ resnet_eps=norm_eps,
147
+ resnet_act_fn=act_fn,
148
+ output_scale_factor=mid_block_scale_factor,
149
+ resnet_time_scale_shift=resnet_time_scale_shift,
150
+ cross_attention_dim=cross_attention_dim,
151
+ attn_num_head_channels=attention_head_dim[-1],
152
+ resnet_groups=norm_num_groups,
153
+ dual_cross_attention=dual_cross_attention,
154
+ use_linear_projection=use_linear_projection,
155
+ upcast_attention=upcast_attention,
156
+ )
157
+ else:
158
+ raise ValueError(f"unknown mid_block_type : {mid_block_type}")
159
+
160
+ # count how many layers upsample the videos
161
+ self.num_upsamplers = 0
162
+
163
+ # up
164
+ reversed_block_out_channels = list(reversed(block_out_channels))
165
+ reversed_attention_head_dim = list(reversed(attention_head_dim))
166
+ only_cross_attention = list(reversed(only_cross_attention))
167
+ output_channel = reversed_block_out_channels[0]
168
+ for i, up_block_type in enumerate(up_block_types):
169
+ is_final_block = i == len(block_out_channels) - 1
170
+
171
+ prev_output_channel = output_channel
172
+ output_channel = reversed_block_out_channels[i]
173
+ input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
174
+
175
+ # add upsample block for all BUT final layer
176
+ if not is_final_block:
177
+ add_upsample = True
178
+ self.num_upsamplers += 1
179
+ else:
180
+ add_upsample = False
181
+
182
+ up_block = get_up_block(
183
+ up_block_type,
184
+ num_layers=layers_per_block + 1,
185
+ in_channels=input_channel,
186
+ out_channels=output_channel,
187
+ prev_output_channel=prev_output_channel,
188
+ temb_channels=time_embed_dim,
189
+ add_upsample=add_upsample,
190
+ resnet_eps=norm_eps,
191
+ resnet_act_fn=act_fn,
192
+ resnet_groups=norm_num_groups,
193
+ cross_attention_dim=cross_attention_dim,
194
+ attn_num_head_channels=reversed_attention_head_dim[i],
195
+ dual_cross_attention=dual_cross_attention,
196
+ use_linear_projection=use_linear_projection,
197
+ only_cross_attention=only_cross_attention[i],
198
+ upcast_attention=upcast_attention,
199
+ resnet_time_scale_shift=resnet_time_scale_shift,
200
+ )
201
+ self.up_blocks.append(up_block)
202
+ prev_output_channel = output_channel
203
+
204
+ # out
205
+ self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps)
206
+ self.conv_act = nn.SiLU()
207
+ self.conv_out = InflatedConv3d(block_out_channels[0], out_channels, kernel_size=3, padding=1)
208
+
209
+ def set_attention_slice(self, slice_size):
210
+ r"""
211
+ Enable sliced attention computation.
212
+
213
+ When this option is enabled, the attention module will split the input tensor in slices, to compute attention
214
+ in several steps. This is useful to save some memory in exchange for a small speed decrease.
215
+
216
+ Args:
217
+ slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
218
+ When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
219
+ `"max"`, maxium amount of memory will be saved by running only one slice at a time. If a number is
220
+ provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
221
+ must be a multiple of `slice_size`.
222
+ """
223
+ sliceable_head_dims = []
224
+
225
+ def fn_recursive_retrieve_slicable_dims(module: torch.nn.Module):
226
+ if hasattr(module, "set_attention_slice"):
227
+ sliceable_head_dims.append(module.sliceable_head_dim)
228
+
229
+ for child in module.children():
230
+ fn_recursive_retrieve_slicable_dims(child)
231
+
232
+ # retrieve number of attention layers
233
+ for module in self.children():
234
+ fn_recursive_retrieve_slicable_dims(module)
235
+
236
+ num_slicable_layers = len(sliceable_head_dims)
237
+
238
+ if slice_size == "auto":
239
+ # half the attention head size is usually a good trade-off between
240
+ # speed and memory
241
+ slice_size = [dim // 2 for dim in sliceable_head_dims]
242
+ elif slice_size == "max":
243
+ # make smallest slice possible
244
+ slice_size = num_slicable_layers * [1]
245
+
246
+ slice_size = num_slicable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
247
+
248
+ if len(slice_size) != len(sliceable_head_dims):
249
+ raise ValueError(
250
+ f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
251
+ f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
252
+ )
253
+
254
+ for i in range(len(slice_size)):
255
+ size = slice_size[i]
256
+ dim = sliceable_head_dims[i]
257
+ if size is not None and size > dim:
258
+ raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
259
+
260
+ # Recursively walk through all the children.
261
+ # Any children which exposes the set_attention_slice method
262
+ # gets the message
263
+ def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
264
+ if hasattr(module, "set_attention_slice"):
265
+ module.set_attention_slice(slice_size.pop())
266
+
267
+ for child in module.children():
268
+ fn_recursive_set_attention_slice(child, slice_size)
269
+
270
+ reversed_slice_size = list(reversed(slice_size))
271
+ for module in self.children():
272
+ fn_recursive_set_attention_slice(module, reversed_slice_size)
273
+
274
+ def _set_gradient_checkpointing(self, module, value=False):
275
+ if isinstance(module, (CrossAttnDownBlock3D, DownBlock3D, CrossAttnUpBlock3D, UpBlock3D)):
276
+ module.gradient_checkpointing = value
277
+
278
+ def forward(
279
+ self,
280
+ sample: torch.FloatTensor,
281
+ timestep: Union[torch.Tensor, float, int],
282
+ encoder_hidden_states: torch.Tensor,
283
+ class_labels: Optional[torch.Tensor] = None,
284
+ attention_mask: Optional[torch.Tensor] = None,
285
+ return_dict: bool = True,
286
+ cross_attention_kwargs = None,
287
+ down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
288
+ mid_block_additional_residual: Optional[torch.Tensor] = None,
289
+ inter_frame = False,
290
+ ) -> Union[UNet3DConditionOutput, Tuple]:
291
+ r"""
292
+ Args:
293
+ sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor
294
+ timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps
295
+ encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states
296
+ return_dict (`bool`, *optional*, defaults to `True`):
297
+ Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
298
+
299
+ Returns:
300
+ [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
301
+ [`~models.unet_2d_condition.UNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When
302
+ returning a tuple, the first element is the sample tensor.
303
+ """
304
+ # By default samples have to be AT least a multiple of the overall upsampling factor.
305
+ # The overall upsampling factor is equal to 2 ** (# num of upsampling layears).
306
+ # However, the upsampling interpolation output size can be forced to fit any upsampling size
307
+ # on the fly if necessary.
308
+ default_overall_up_factor = 2**self.num_upsamplers
309
+
310
+ # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
311
+ forward_upsample_size = False
312
+ upsample_size = None
313
+
314
+ if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
315
+ logger.info("Forward upsample size to force interpolation output size.")
316
+ forward_upsample_size = True
317
+
318
+ # prepare attention_mask
319
+ if attention_mask is not None:
320
+ attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
321
+ attention_mask = attention_mask.unsqueeze(1)
322
+
323
+ # center input if necessary
324
+ if self.config.center_input_sample:
325
+ sample = 2 * sample - 1.0
326
+
327
+ # time
328
+ timesteps = timestep
329
+ if not torch.is_tensor(timesteps):
330
+ # This would be a good case for the `match` statement (Python 3.10+)
331
+ is_mps = sample.device.type == "mps"
332
+ if isinstance(timestep, float):
333
+ dtype = torch.float32 if is_mps else torch.float64
334
+ else:
335
+ dtype = torch.int32 if is_mps else torch.int64
336
+ timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
337
+ elif len(timesteps.shape) == 0:
338
+ timesteps = timesteps[None].to(sample.device)
339
+
340
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
341
+ timesteps = timesteps.expand(sample.shape[0])
342
+
343
+ t_emb = self.time_proj(timesteps)
344
+
345
+ # timesteps does not contain any weights and will always return f32 tensors
346
+ # but time_embedding might actually be running in fp16. so we need to cast here.
347
+ # there might be better ways to encapsulate this.
348
+ t_emb = t_emb.to(dtype=self.dtype)
349
+ emb = self.time_embedding(t_emb)
350
+
351
+ if self.class_embedding is not None:
352
+ if class_labels is None:
353
+ raise ValueError("class_labels should be provided when num_class_embeds > 0")
354
+
355
+ if self.config.class_embed_type == "timestep":
356
+ class_labels = self.time_proj(class_labels)
357
+
358
+ class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
359
+ emb = emb + class_emb
360
+
361
+ # pre-process
362
+ sample = self.conv_in(sample)
363
+
364
+ # down
365
+ down_block_res_samples = (sample,)
366
+ for downsample_block in self.down_blocks:
367
+ if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
368
+ sample, res_samples = downsample_block(
369
+ hidden_states=sample,
370
+ temb=emb,
371
+ encoder_hidden_states=encoder_hidden_states,
372
+ attention_mask=attention_mask,
373
+ inter_frame=inter_frame
374
+ )
375
+ else:
376
+ sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
377
+
378
+ down_block_res_samples += res_samples
379
+
380
+ if down_block_additional_residuals is not None:
381
+ new_down_block_res_samples = ()
382
+
383
+ for down_block_res_sample, down_block_additional_residual in zip(
384
+ down_block_res_samples, down_block_additional_residuals
385
+ ):
386
+ down_block_res_sample += down_block_additional_residual
387
+ new_down_block_res_samples += (down_block_res_sample,)
388
+
389
+ down_block_res_samples = new_down_block_res_samples
390
+
391
+ # mid
392
+ sample = self.mid_block(
393
+ sample, emb, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask,
394
+ inter_frame=inter_frame
395
+
396
+ )
397
+
398
+ if mid_block_additional_residual is not None:
399
+ sample += mid_block_additional_residual
400
+
401
+ # up
402
+ for i, upsample_block in enumerate(self.up_blocks):
403
+ is_final_block = i == len(self.up_blocks) - 1
404
+
405
+ res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
406
+ down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
407
+
408
+ # if we have not reached the final block and need to forward the
409
+ # upsample size, we do it here
410
+ if not is_final_block and forward_upsample_size:
411
+ upsample_size = down_block_res_samples[-1].shape[2:]
412
+
413
+ if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
414
+ sample = upsample_block(
415
+ hidden_states=sample,
416
+ temb=emb,
417
+ res_hidden_states_tuple=res_samples,
418
+ encoder_hidden_states=encoder_hidden_states,
419
+ upsample_size=upsample_size,
420
+ attention_mask=attention_mask,
421
+ inter_frame=inter_frame
422
+ )
423
+ else:
424
+ sample = upsample_block(
425
+ hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size
426
+ )
427
+ # post-process
428
+ sample = self.conv_norm_out(sample)
429
+ sample = self.conv_act(sample)
430
+ sample = self.conv_out(sample)
431
+
432
+ if not return_dict:
433
+ return (sample,)
434
+
435
+ return UNet3DConditionOutput(sample=sample)
436
+
437
+ @classmethod
438
+ def from_pretrained_2d(cls, pretrained_model_path, subfolder=None):
439
+ if subfolder is not None:
440
+ pretrained_model_path = os.path.join(pretrained_model_path, subfolder)
441
+
442
+ config_file = os.path.join(pretrained_model_path, 'config.json')
443
+ if not os.path.isfile(config_file):
444
+ raise RuntimeError(f"{config_file} does not exist")
445
+ with open(config_file, "r") as f:
446
+ config = json.load(f)
447
+ config["_class_name"] = cls.__name__
448
+ config["down_block_types"] = [
449
+ "CrossAttnDownBlock3D",
450
+ "CrossAttnDownBlock3D",
451
+ "CrossAttnDownBlock3D",
452
+ "DownBlock3D"
453
+ ]
454
+ config["up_block_types"] = [
455
+ "UpBlock3D",
456
+ "CrossAttnUpBlock3D",
457
+ "CrossAttnUpBlock3D",
458
+ "CrossAttnUpBlock3D"
459
+ ]
460
+
461
+ from diffusers.utils import WEIGHTS_NAME
462
+ model = cls.from_config(config)
463
+ model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME)
464
+ if not os.path.isfile(model_file):
465
+ raise RuntimeError(f"{model_file} does not exist")
466
+ state_dict = torch.load(model_file, map_location="cpu")
467
+ # for k, v in model.state_dict().items():
468
+ # if '_temp.' in k:
469
+ # state_dict.update({k: v})
470
+ model.load_state_dict(state_dict, strict=False)
471
+
472
+ return model
models/unet_blocks.py ADDED
@@ -0,0 +1,588 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unet_2d_blocks.py
2
+
3
+ import torch
4
+ from torch import nn
5
+
6
+ from .attention import Transformer3DModel
7
+ from .resnet import Downsample3D, ResnetBlock3D, Upsample3D
8
+
9
+
10
+ def get_down_block(
11
+ down_block_type,
12
+ num_layers,
13
+ in_channels,
14
+ out_channels,
15
+ temb_channels,
16
+ add_downsample,
17
+ resnet_eps,
18
+ resnet_act_fn,
19
+ attn_num_head_channels,
20
+ resnet_groups=None,
21
+ cross_attention_dim=None,
22
+ downsample_padding=None,
23
+ dual_cross_attention=False,
24
+ use_linear_projection=False,
25
+ only_cross_attention=False,
26
+ upcast_attention=False,
27
+ resnet_time_scale_shift="default",
28
+ ):
29
+ down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type
30
+ if down_block_type == "DownBlock3D":
31
+ return DownBlock3D(
32
+ num_layers=num_layers,
33
+ in_channels=in_channels,
34
+ out_channels=out_channels,
35
+ temb_channels=temb_channels,
36
+ add_downsample=add_downsample,
37
+ resnet_eps=resnet_eps,
38
+ resnet_act_fn=resnet_act_fn,
39
+ resnet_groups=resnet_groups,
40
+ downsample_padding=downsample_padding,
41
+ resnet_time_scale_shift=resnet_time_scale_shift,
42
+ )
43
+ elif down_block_type == "CrossAttnDownBlock3D":
44
+ if cross_attention_dim is None:
45
+ raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlock3D")
46
+ return CrossAttnDownBlock3D(
47
+ num_layers=num_layers,
48
+ in_channels=in_channels,
49
+ out_channels=out_channels,
50
+ temb_channels=temb_channels,
51
+ add_downsample=add_downsample,
52
+ resnet_eps=resnet_eps,
53
+ resnet_act_fn=resnet_act_fn,
54
+ resnet_groups=resnet_groups,
55
+ downsample_padding=downsample_padding,
56
+ cross_attention_dim=cross_attention_dim,
57
+ attn_num_head_channels=attn_num_head_channels,
58
+ dual_cross_attention=dual_cross_attention,
59
+ use_linear_projection=use_linear_projection,
60
+ only_cross_attention=only_cross_attention,
61
+ upcast_attention=upcast_attention,
62
+ resnet_time_scale_shift=resnet_time_scale_shift,
63
+ )
64
+ raise ValueError(f"{down_block_type} does not exist.")
65
+
66
+
67
+ def get_up_block(
68
+ up_block_type,
69
+ num_layers,
70
+ in_channels,
71
+ out_channels,
72
+ prev_output_channel,
73
+ temb_channels,
74
+ add_upsample,
75
+ resnet_eps,
76
+ resnet_act_fn,
77
+ attn_num_head_channels,
78
+ resnet_groups=None,
79
+ cross_attention_dim=None,
80
+ dual_cross_attention=False,
81
+ use_linear_projection=False,
82
+ only_cross_attention=False,
83
+ upcast_attention=False,
84
+ resnet_time_scale_shift="default",
85
+ ):
86
+ up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
87
+ if up_block_type == "UpBlock3D":
88
+ return UpBlock3D(
89
+ num_layers=num_layers,
90
+ in_channels=in_channels,
91
+ out_channels=out_channels,
92
+ prev_output_channel=prev_output_channel,
93
+ temb_channels=temb_channels,
94
+ add_upsample=add_upsample,
95
+ resnet_eps=resnet_eps,
96
+ resnet_act_fn=resnet_act_fn,
97
+ resnet_groups=resnet_groups,
98
+ resnet_time_scale_shift=resnet_time_scale_shift,
99
+ )
100
+ elif up_block_type == "CrossAttnUpBlock3D":
101
+ if cross_attention_dim is None:
102
+ raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlock3D")
103
+ return CrossAttnUpBlock3D(
104
+ num_layers=num_layers,
105
+ in_channels=in_channels,
106
+ out_channels=out_channels,
107
+ prev_output_channel=prev_output_channel,
108
+ temb_channels=temb_channels,
109
+ add_upsample=add_upsample,
110
+ resnet_eps=resnet_eps,
111
+ resnet_act_fn=resnet_act_fn,
112
+ resnet_groups=resnet_groups,
113
+ cross_attention_dim=cross_attention_dim,
114
+ attn_num_head_channels=attn_num_head_channels,
115
+ dual_cross_attention=dual_cross_attention,
116
+ use_linear_projection=use_linear_projection,
117
+ only_cross_attention=only_cross_attention,
118
+ upcast_attention=upcast_attention,
119
+ resnet_time_scale_shift=resnet_time_scale_shift,
120
+ )
121
+ raise ValueError(f"{up_block_type} does not exist.")
122
+
123
+
124
+ class UNetMidBlock3DCrossAttn(nn.Module):
125
+ def __init__(
126
+ self,
127
+ in_channels: int,
128
+ temb_channels: int,
129
+ dropout: float = 0.0,
130
+ num_layers: int = 1,
131
+ resnet_eps: float = 1e-6,
132
+ resnet_time_scale_shift: str = "default",
133
+ resnet_act_fn: str = "swish",
134
+ resnet_groups: int = 32,
135
+ resnet_pre_norm: bool = True,
136
+ attn_num_head_channels=1,
137
+ output_scale_factor=1.0,
138
+ cross_attention_dim=1280,
139
+ dual_cross_attention=False,
140
+ use_linear_projection=False,
141
+ upcast_attention=False,
142
+ ):
143
+ super().__init__()
144
+
145
+ self.has_cross_attention = True
146
+ self.attn_num_head_channels = attn_num_head_channels
147
+ resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
148
+
149
+ # there is always at least one resnet
150
+ resnets = [
151
+ ResnetBlock3D(
152
+ in_channels=in_channels,
153
+ out_channels=in_channels,
154
+ temb_channels=temb_channels,
155
+ eps=resnet_eps,
156
+ groups=resnet_groups,
157
+ dropout=dropout,
158
+ time_embedding_norm=resnet_time_scale_shift,
159
+ non_linearity=resnet_act_fn,
160
+ output_scale_factor=output_scale_factor,
161
+ pre_norm=resnet_pre_norm,
162
+ )
163
+ ]
164
+ attentions = []
165
+
166
+ for _ in range(num_layers):
167
+ if dual_cross_attention:
168
+ raise NotImplementedError
169
+ attentions.append(
170
+ Transformer3DModel(
171
+ attn_num_head_channels,
172
+ in_channels // attn_num_head_channels,
173
+ in_channels=in_channels,
174
+ num_layers=1,
175
+ cross_attention_dim=cross_attention_dim,
176
+ norm_num_groups=resnet_groups,
177
+ use_linear_projection=use_linear_projection,
178
+ upcast_attention=upcast_attention,
179
+ )
180
+ )
181
+ resnets.append(
182
+ ResnetBlock3D(
183
+ in_channels=in_channels,
184
+ out_channels=in_channels,
185
+ temb_channels=temb_channels,
186
+ eps=resnet_eps,
187
+ groups=resnet_groups,
188
+ dropout=dropout,
189
+ time_embedding_norm=resnet_time_scale_shift,
190
+ non_linearity=resnet_act_fn,
191
+ output_scale_factor=output_scale_factor,
192
+ pre_norm=resnet_pre_norm,
193
+ )
194
+ )
195
+
196
+ self.attentions = nn.ModuleList(attentions)
197
+ self.resnets = nn.ModuleList(resnets)
198
+
199
+ def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None, inter_frame=False):
200
+ hidden_states = self.resnets[0](hidden_states, temb)
201
+ for attn, resnet in zip(self.attentions, self.resnets[1:]):
202
+ hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states, inter_frame=inter_frame).sample
203
+ hidden_states = resnet(hidden_states, temb)
204
+
205
+ return hidden_states
206
+
207
+
208
+ class CrossAttnDownBlock3D(nn.Module):
209
+ def __init__(
210
+ self,
211
+ in_channels: int,
212
+ out_channels: int,
213
+ temb_channels: int,
214
+ dropout: float = 0.0,
215
+ num_layers: int = 1,
216
+ resnet_eps: float = 1e-6,
217
+ resnet_time_scale_shift: str = "default",
218
+ resnet_act_fn: str = "swish",
219
+ resnet_groups: int = 32,
220
+ resnet_pre_norm: bool = True,
221
+ attn_num_head_channels=1,
222
+ cross_attention_dim=1280,
223
+ output_scale_factor=1.0,
224
+ downsample_padding=1,
225
+ add_downsample=True,
226
+ dual_cross_attention=False,
227
+ use_linear_projection=False,
228
+ only_cross_attention=False,
229
+ upcast_attention=False,
230
+ ):
231
+ super().__init__()
232
+ resnets = []
233
+ attentions = []
234
+
235
+ self.has_cross_attention = True
236
+ self.attn_num_head_channels = attn_num_head_channels
237
+
238
+ for i in range(num_layers):
239
+ in_channels = in_channels if i == 0 else out_channels
240
+ resnets.append(
241
+ ResnetBlock3D(
242
+ in_channels=in_channels,
243
+ out_channels=out_channels,
244
+ temb_channels=temb_channels,
245
+ eps=resnet_eps,
246
+ groups=resnet_groups,
247
+ dropout=dropout,
248
+ time_embedding_norm=resnet_time_scale_shift,
249
+ non_linearity=resnet_act_fn,
250
+ output_scale_factor=output_scale_factor,
251
+ pre_norm=resnet_pre_norm,
252
+ )
253
+ )
254
+ if dual_cross_attention:
255
+ raise NotImplementedError
256
+ attentions.append(
257
+ Transformer3DModel(
258
+ attn_num_head_channels,
259
+ out_channels // attn_num_head_channels,
260
+ in_channels=out_channels,
261
+ num_layers=1,
262
+ cross_attention_dim=cross_attention_dim,
263
+ norm_num_groups=resnet_groups,
264
+ use_linear_projection=use_linear_projection,
265
+ only_cross_attention=only_cross_attention,
266
+ upcast_attention=upcast_attention,
267
+ )
268
+ )
269
+ self.attentions = nn.ModuleList(attentions)
270
+ self.resnets = nn.ModuleList(resnets)
271
+
272
+ if add_downsample:
273
+ self.downsamplers = nn.ModuleList(
274
+ [
275
+ Downsample3D(
276
+ out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
277
+ )
278
+ ]
279
+ )
280
+ else:
281
+ self.downsamplers = None
282
+
283
+ self.gradient_checkpointing = False
284
+
285
+ def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None, inter_frame=False):
286
+ output_states = ()
287
+
288
+ for resnet, attn in zip(self.resnets, self.attentions):
289
+ if self.training and self.gradient_checkpointing:
290
+
291
+ def create_custom_forward(module, return_dict=None, inter_frame=None):
292
+ def custom_forward(*inputs):
293
+ if return_dict is not None:
294
+ return module(*inputs, return_dict=return_dict, inter_frame=inter_frame)
295
+ else:
296
+ return module(*inputs)
297
+
298
+ return custom_forward
299
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
300
+ hidden_states = torch.utils.checkpoint.checkpoint(
301
+ create_custom_forward(attn, return_dict=False, inter_frame=inter_frame),
302
+ hidden_states,
303
+ encoder_hidden_states,
304
+ )[0]
305
+ else:
306
+ hidden_states = resnet(hidden_states, temb)
307
+ hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states, inter_frame=inter_frame).sample
308
+
309
+ output_states += (hidden_states,)
310
+
311
+ if self.downsamplers is not None:
312
+ for downsampler in self.downsamplers:
313
+ hidden_states = downsampler(hidden_states)
314
+
315
+ output_states += (hidden_states,)
316
+
317
+ return hidden_states, output_states
318
+
319
+
320
+ class DownBlock3D(nn.Module):
321
+ def __init__(
322
+ self,
323
+ in_channels: int,
324
+ out_channels: int,
325
+ temb_channels: int,
326
+ dropout: float = 0.0,
327
+ num_layers: int = 1,
328
+ resnet_eps: float = 1e-6,
329
+ resnet_time_scale_shift: str = "default",
330
+ resnet_act_fn: str = "swish",
331
+ resnet_groups: int = 32,
332
+ resnet_pre_norm: bool = True,
333
+ output_scale_factor=1.0,
334
+ add_downsample=True,
335
+ downsample_padding=1,
336
+ ):
337
+ super().__init__()
338
+ resnets = []
339
+
340
+ for i in range(num_layers):
341
+ in_channels = in_channels if i == 0 else out_channels
342
+ resnets.append(
343
+ ResnetBlock3D(
344
+ in_channels=in_channels,
345
+ out_channels=out_channels,
346
+ temb_channels=temb_channels,
347
+ eps=resnet_eps,
348
+ groups=resnet_groups,
349
+ dropout=dropout,
350
+ time_embedding_norm=resnet_time_scale_shift,
351
+ non_linearity=resnet_act_fn,
352
+ output_scale_factor=output_scale_factor,
353
+ pre_norm=resnet_pre_norm,
354
+ )
355
+ )
356
+
357
+ self.resnets = nn.ModuleList(resnets)
358
+
359
+ if add_downsample:
360
+ self.downsamplers = nn.ModuleList(
361
+ [
362
+ Downsample3D(
363
+ out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
364
+ )
365
+ ]
366
+ )
367
+ else:
368
+ self.downsamplers = None
369
+
370
+ self.gradient_checkpointing = False
371
+
372
+ def forward(self, hidden_states, temb=None):
373
+ output_states = ()
374
+
375
+ for resnet in self.resnets:
376
+ if self.training and self.gradient_checkpointing:
377
+
378
+ def create_custom_forward(module):
379
+ def custom_forward(*inputs):
380
+ return module(*inputs)
381
+
382
+ return custom_forward
383
+
384
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
385
+ else:
386
+ hidden_states = resnet(hidden_states, temb)
387
+
388
+ output_states += (hidden_states,)
389
+
390
+ if self.downsamplers is not None:
391
+ for downsampler in self.downsamplers:
392
+ hidden_states = downsampler(hidden_states)
393
+
394
+ output_states += (hidden_states,)
395
+
396
+ return hidden_states, output_states
397
+
398
+
399
+ class CrossAttnUpBlock3D(nn.Module):
400
+ def __init__(
401
+ self,
402
+ in_channels: int,
403
+ out_channels: int,
404
+ prev_output_channel: int,
405
+ temb_channels: int,
406
+ dropout: float = 0.0,
407
+ num_layers: int = 1,
408
+ resnet_eps: float = 1e-6,
409
+ resnet_time_scale_shift: str = "default",
410
+ resnet_act_fn: str = "swish",
411
+ resnet_groups: int = 32,
412
+ resnet_pre_norm: bool = True,
413
+ attn_num_head_channels=1,
414
+ cross_attention_dim=1280,
415
+ output_scale_factor=1.0,
416
+ add_upsample=True,
417
+ dual_cross_attention=False,
418
+ use_linear_projection=False,
419
+ only_cross_attention=False,
420
+ upcast_attention=False,
421
+ ):
422
+ super().__init__()
423
+ resnets = []
424
+ attentions = []
425
+
426
+ self.has_cross_attention = True
427
+ self.attn_num_head_channels = attn_num_head_channels
428
+
429
+ for i in range(num_layers):
430
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
431
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
432
+
433
+ resnets.append(
434
+ ResnetBlock3D(
435
+ in_channels=resnet_in_channels + res_skip_channels,
436
+ out_channels=out_channels,
437
+ temb_channels=temb_channels,
438
+ eps=resnet_eps,
439
+ groups=resnet_groups,
440
+ dropout=dropout,
441
+ time_embedding_norm=resnet_time_scale_shift,
442
+ non_linearity=resnet_act_fn,
443
+ output_scale_factor=output_scale_factor,
444
+ pre_norm=resnet_pre_norm,
445
+ )
446
+ )
447
+ if dual_cross_attention:
448
+ raise NotImplementedError
449
+ attentions.append(
450
+ Transformer3DModel(
451
+ attn_num_head_channels,
452
+ out_channels // attn_num_head_channels,
453
+ in_channels=out_channels,
454
+ num_layers=1,
455
+ cross_attention_dim=cross_attention_dim,
456
+ norm_num_groups=resnet_groups,
457
+ use_linear_projection=use_linear_projection,
458
+ only_cross_attention=only_cross_attention,
459
+ upcast_attention=upcast_attention,
460
+ )
461
+ )
462
+
463
+ self.attentions = nn.ModuleList(attentions)
464
+ self.resnets = nn.ModuleList(resnets)
465
+
466
+ if add_upsample:
467
+ self.upsamplers = nn.ModuleList([Upsample3D(out_channels, use_conv=True, out_channels=out_channels)])
468
+ else:
469
+ self.upsamplers = None
470
+
471
+ self.gradient_checkpointing = False
472
+
473
+ def forward(
474
+ self,
475
+ hidden_states,
476
+ res_hidden_states_tuple,
477
+ temb=None,
478
+ encoder_hidden_states=None,
479
+ upsample_size=None,
480
+ attention_mask=None,
481
+ inter_frame=False
482
+ ):
483
+ for resnet, attn in zip(self.resnets, self.attentions):
484
+ # pop res hidden states
485
+ res_hidden_states = res_hidden_states_tuple[-1]
486
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
487
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
488
+
489
+ if self.training and self.gradient_checkpointing:
490
+
491
+ def create_custom_forward(module, return_dict=None, inter_frame=None):
492
+ def custom_forward(*inputs):
493
+ if return_dict is not None:
494
+ return module(*inputs, return_dict=return_dict, inter_frame=inter_frame)
495
+ else:
496
+ return module(*inputs)
497
+
498
+ return custom_forward
499
+
500
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
501
+ hidden_states = torch.utils.checkpoint.checkpoint(
502
+ create_custom_forward(attn, return_dict=False, inter_frame=inter_frame),
503
+ hidden_states,
504
+ encoder_hidden_states,
505
+ )[0]
506
+ else:
507
+ hidden_states = resnet(hidden_states, temb)
508
+ hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states, inter_frame=inter_frame).sample
509
+
510
+ if self.upsamplers is not None:
511
+ for upsampler in self.upsamplers:
512
+ hidden_states = upsampler(hidden_states, upsample_size)
513
+
514
+ return hidden_states
515
+
516
+
517
+ class UpBlock3D(nn.Module):
518
+ def __init__(
519
+ self,
520
+ in_channels: int,
521
+ prev_output_channel: int,
522
+ out_channels: int,
523
+ temb_channels: int,
524
+ dropout: float = 0.0,
525
+ num_layers: int = 1,
526
+ resnet_eps: float = 1e-6,
527
+ resnet_time_scale_shift: str = "default",
528
+ resnet_act_fn: str = "swish",
529
+ resnet_groups: int = 32,
530
+ resnet_pre_norm: bool = True,
531
+ output_scale_factor=1.0,
532
+ add_upsample=True,
533
+ ):
534
+ super().__init__()
535
+ resnets = []
536
+
537
+ for i in range(num_layers):
538
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
539
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
540
+
541
+ resnets.append(
542
+ ResnetBlock3D(
543
+ in_channels=resnet_in_channels + res_skip_channels,
544
+ out_channels=out_channels,
545
+ temb_channels=temb_channels,
546
+ eps=resnet_eps,
547
+ groups=resnet_groups,
548
+ dropout=dropout,
549
+ time_embedding_norm=resnet_time_scale_shift,
550
+ non_linearity=resnet_act_fn,
551
+ output_scale_factor=output_scale_factor,
552
+ pre_norm=resnet_pre_norm,
553
+ )
554
+ )
555
+
556
+ self.resnets = nn.ModuleList(resnets)
557
+
558
+ if add_upsample:
559
+ self.upsamplers = nn.ModuleList([Upsample3D(out_channels, use_conv=True, out_channels=out_channels)])
560
+ else:
561
+ self.upsamplers = None
562
+
563
+ self.gradient_checkpointing = False
564
+
565
+ def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None):
566
+ for resnet in self.resnets:
567
+ # pop res hidden states
568
+ res_hidden_states = res_hidden_states_tuple[-1]
569
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
570
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
571
+
572
+ if self.training and self.gradient_checkpointing:
573
+
574
+ def create_custom_forward(module):
575
+ def custom_forward(*inputs):
576
+ return module(*inputs)
577
+
578
+ return custom_forward
579
+
580
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
581
+ else:
582
+ hidden_states = resnet(hidden_states, temb)
583
+
584
+ if self.upsamplers is not None:
585
+ for upsampler in self.upsamplers:
586
+ hidden_states = upsampler(hidden_states, upsample_size)
587
+
588
+ return hidden_states
models/util.py ADDED
@@ -0,0 +1,120 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import imageio
3
+ import numpy as np
4
+ from typing import Union
5
+ import decord
6
+ decord.bridge.set_bridge('torch')
7
+ import torch
8
+ import torchvision
9
+ import PIL
10
+ from typing import List
11
+ from tqdm import tqdm
12
+ from einops import rearrange
13
+
14
+ from controlnet_aux import CannyDetector
15
+
16
+ def save_videos_grid(videos: torch.Tensor, path: str, rescale=False, n_rows=4, fps=8):
17
+ videos = rearrange(videos, "b c t h w -> t b c h w")
18
+ outputs = []
19
+ for x in videos:
20
+ x = torchvision.utils.make_grid(x, nrow=n_rows)
21
+ x = x.transpose(0, 1).transpose(1, 2).squeeze(-1)
22
+ if rescale:
23
+ x = (x + 1.0) / 2.0 # -1,1 -> 0,1
24
+ x = (x * 255).numpy().astype(np.uint8)
25
+ outputs.append(x)
26
+
27
+ os.makedirs(os.path.dirname(path), exist_ok=True)
28
+ imageio.mimsave(path, outputs, fps=fps)
29
+
30
+ def save_videos_grid_pil(videos: List[PIL.Image.Image], path: str, rescale=False, n_rows=4, fps=8):
31
+ videos = rearrange(videos, "b c t h w -> t b c h w")
32
+ outputs = []
33
+ for x in videos:
34
+ x = torchvision.utils.make_grid(x, nrow=n_rows)
35
+ x = x.transpose(0, 1).transpose(1, 2).squeeze(-1)
36
+ if rescale:
37
+ x = (x + 1.0) / 2.0 # -1,1 -> 0,1
38
+ x = (x * 255).numpy().astype(np.uint8)
39
+ outputs.append(x)
40
+
41
+ os.makedirs(os.path.dirname(path), exist_ok=True)
42
+ imageio.mimsave(path, outputs, fps=fps)
43
+
44
+ def read_video(video_path, video_length, width=512, height=512, frame_rate=2):
45
+ vr = decord.VideoReader(video_path, width=width, height=height)
46
+ sample_index = list(range(0, len(vr), frame_rate))[:video_length]
47
+ video = vr.get_batch(sample_index)
48
+ video = rearrange(video, "f h w c -> f c h w")
49
+ video = (video / 127.5 - 1.0)
50
+ return video
51
+
52
+
53
+ def get_annotation(video, annotator):
54
+ t2i_transform = torchvision.transforms.ToPILImage()
55
+ annotation = []
56
+ for frame in video:
57
+ pil_frame = t2i_transform(frame)
58
+ if isinstance(annotator, CannyDetector):
59
+ annotation.append(annotator(pil_frame, low_threshold=100, high_threshold=200))
60
+ else:
61
+ annotation.append(annotator(pil_frame))
62
+ return annotation
63
+
64
+ # DDIM Inversion
65
+ @torch.no_grad()
66
+ def init_prompt(prompt, pipeline):
67
+ uncond_input = pipeline.tokenizer(
68
+ [""], padding="max_length", max_length=pipeline.tokenizer.model_max_length,
69
+ return_tensors="pt"
70
+ )
71
+ uncond_embeddings = pipeline.text_encoder(uncond_input.input_ids.to(pipeline.device))[0]
72
+ text_input = pipeline.tokenizer(
73
+ [prompt],
74
+ padding="max_length",
75
+ max_length=pipeline.tokenizer.model_max_length,
76
+ truncation=True,
77
+ return_tensors="pt",
78
+ )
79
+ text_embeddings = pipeline.text_encoder(text_input.input_ids.to(pipeline.device))[0]
80
+ context = torch.cat([uncond_embeddings, text_embeddings])
81
+
82
+ return context
83
+
84
+
85
+ def next_step(model_output: Union[torch.FloatTensor, np.ndarray], timestep: int,
86
+ sample: Union[torch.FloatTensor, np.ndarray], ddim_scheduler):
87
+ timestep, next_timestep = min(
88
+ timestep - ddim_scheduler.config.num_train_timesteps // ddim_scheduler.num_inference_steps, 999), timestep
89
+ alpha_prod_t = ddim_scheduler.alphas_cumprod[timestep] if timestep >= 0 else ddim_scheduler.final_alpha_cumprod
90
+ alpha_prod_t_next = ddim_scheduler.alphas_cumprod[next_timestep]
91
+ beta_prod_t = 1 - alpha_prod_t
92
+ next_original_sample = (sample - beta_prod_t ** 0.5 * model_output) / alpha_prod_t ** 0.5
93
+ next_sample_direction = (1 - alpha_prod_t_next) ** 0.5 * model_output
94
+ next_sample = alpha_prod_t_next ** 0.5 * next_original_sample + next_sample_direction
95
+ return next_sample
96
+
97
+
98
+ def get_noise_pred_single(latents, t, context, unet):
99
+ noise_pred = unet(latents, t, encoder_hidden_states=context)["sample"]
100
+ return noise_pred
101
+
102
+
103
+ @torch.no_grad()
104
+ def ddim_loop(pipeline, ddim_scheduler, latent, num_inv_steps, prompt):
105
+ context = init_prompt(prompt, pipeline)
106
+ uncond_embeddings, cond_embeddings = context.chunk(2)
107
+ all_latent = [latent]
108
+ latent = latent.clone().detach()
109
+ for i in tqdm(range(num_inv_steps)):
110
+ t = ddim_scheduler.timesteps[len(ddim_scheduler.timesteps) - i - 1]
111
+ noise_pred = get_noise_pred_single(latent, t, cond_embeddings, pipeline.unet)
112
+ latent = next_step(noise_pred, t, latent, ddim_scheduler)
113
+ all_latent.append(latent)
114
+ return all_latent
115
+
116
+
117
+ @torch.no_grad()
118
+ def ddim_inversion(pipeline, ddim_scheduler, video_latent, num_inv_steps, prompt=""):
119
+ ddim_latents = ddim_loop(pipeline, ddim_scheduler, video_latent, num_inv_steps, prompt)
120
+ return ddim_latents
predict.py ADDED
@@ -0,0 +1,164 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
+ # Prediction interface for Cog ⚙️
+ # https://github.com/replicate/cog/blob/main/docs/python.md
+ import os
+ import numpy as np
+ import argparse
+ import imageio
+ import torch
+
+ from einops import rearrange
+ from diffusers import DDIMScheduler, AutoencoderKL
+ from transformers import CLIPTextModel, CLIPTokenizer
+ import controlnet_aux
+ from controlnet_aux import OpenposeDetector, CannyDetector, MidasDetector
+
+ from models.pipeline_controlvideo import ControlVideoPipeline
+ from models.util import save_videos_grid, read_video, get_annotation
+ from models.unet import UNet3DConditionModel
+ from models.controlnet import ControlNetModel3D
+ from models.RIFE.IFNet_HDv3 import IFNet
+ from cog import BasePredictor, Input, Path
+
+
+ sd_path = "checkpoints/stable-diffusion-v1-5"
+ inter_path = "checkpoints/flownet.pkl"
+ controlnet_dict = {
+     "pose": "checkpoints/sd-controlnet-openpose",
+     "depth": "checkpoints/sd-controlnet-depth",
+     "canny": "checkpoints/sd-controlnet-canny",
+ }
+
+ controlnet_parser_dict = {
+     "pose": OpenposeDetector,
+     "depth": MidasDetector,
+     "canny": CannyDetector,
+ }
+
+ POS_PROMPT = " ,best quality, extremely detailed, HD, ultra-realistic, 8K, HQ, masterpiece, trending on artstation, art, smooth"
+ NEG_PROMPT = "longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, deformed body, bloated, ugly, unrealistic"
+
+
+ class Predictor(BasePredictor):
+     def setup(self):
+         """Load the model into memory to make running multiple predictions efficient"""
+
+         self.tokenizer = CLIPTokenizer.from_pretrained(sd_path, subfolder="tokenizer")
+         self.text_encoder = CLIPTextModel.from_pretrained(
+             sd_path, subfolder="text_encoder"
+         ).to(dtype=torch.float16)
+         self.vae = AutoencoderKL.from_pretrained(sd_path, subfolder="vae").to(
+             dtype=torch.float16
+         )
+         self.unet = UNet3DConditionModel.from_pretrained_2d(
+             sd_path, subfolder="unet"
+         ).to(dtype=torch.float16)
+         self.interpolater = IFNet(ckpt_path=inter_path).to(dtype=torch.float16)
+         self.scheduler = DDIMScheduler.from_pretrained(sd_path, subfolder="scheduler")
+         self.controlnet = {
+             k: ControlNetModel3D.from_pretrained_2d(controlnet_dict[k]).to(
+                 dtype=torch.float16
+             )
+             for k in ["depth", "canny", "pose"]
+         }
+         self.annotator = {k: controlnet_parser_dict[k]() for k in ["depth", "canny"]}
+         self.annotator["pose"] = controlnet_parser_dict["pose"].from_pretrained(
+             "lllyasviel/ControlNet", cache_dir="checkpoints"
+         )
+
+     def predict(
+         self,
+         prompt: str = Input(
+             description="Text description of target video",
+             default="A striking mallard floats effortlessly on the sparkling pond.",
+         ),
+         video_path: Path = Input(description="Source video"),
+         condition: str = Input(
+             default="depth",
+             choices=["depth", "canny", "pose"],
+             description="Condition of structure sequence",
+         ),
+         video_length: int = Input(
+             default=15, description="Length of synthesized video"
+         ),
+         smoother_steps: str = Input(
+             default="19, 20",
+             description="Timesteps at which to apply the interleaved-frame smoother, separated by commas",
+         ),
+         is_long_video: bool = Input(
+             default=False,
+             description="Whether to use the hierarchical sampler to produce a long video",
+         ),
+         num_inference_steps: int = Input(
+             description="Number of denoising steps", default=50
+         ),
+         guidance_scale: float = Input(
+             description="Scale for classifier-free guidance", ge=1, le=20, default=12.5
+         ),
+         seed: str = Input(
+             default=None, description="Random seed. Leave blank to randomize the seed"
+         ),
+     ) -> Path:
+         """Run a single prediction on the model"""
+         if seed is None:
+             seed = int.from_bytes(os.urandom(2), "big")
+         else:
+             seed = int(seed)
+         print(f"Using seed: {seed}")
+
+         generator = torch.Generator(device="cuda")
+         generator.manual_seed(seed)
+
+         pipe = ControlVideoPipeline(
+             vae=self.vae,
+             text_encoder=self.text_encoder,
+             tokenizer=self.tokenizer,
+             unet=self.unet,
+             controlnet=self.controlnet[condition],
+             interpolater=self.interpolater,
+             scheduler=self.scheduler,
+         )
+
+         pipe.enable_vae_slicing()
+         pipe.enable_xformers_memory_efficient_attention()
+         pipe.to("cuda")
+
+         # Step 1. Read a video
+         video = read_video(video_path=str(video_path), video_length=video_length)
+
+         # Step 2. Parse a video to conditional frames
+         pil_annotation = get_annotation(video, self.annotator[condition])
+
+         # Step 3. Inference
+         smoother_steps = [int(s) for s in smoother_steps.split(",")]
+
+         if is_long_video:
+             window_size = int(np.sqrt(video_length))
+             sample = pipe.generate_long_video(
+                 prompt + POS_PROMPT,
+                 video_length=video_length,
+                 frames=pil_annotation,
+                 num_inference_steps=num_inference_steps,
+                 smooth_steps=smoother_steps,
+                 window_size=window_size,
+                 generator=generator,
+                 guidance_scale=guidance_scale,
+                 negative_prompt=NEG_PROMPT,
+             ).videos
+         else:
+             sample = pipe(
+                 prompt + POS_PROMPT,
+                 video_length=video_length,
+                 frames=pil_annotation,
+                 num_inference_steps=num_inference_steps,
+                 smooth_steps=smoother_steps,
+                 generator=generator,
+                 guidance_scale=guidance_scale,
+                 negative_prompt=NEG_PROMPT,
+             ).videos
+
+         out_path = "/tmp/out.mp4"
+         save_videos_grid(sample, out_path)
+         del pipe
+         torch.cuda.empty_cache()
+
+         return Path(out_path)
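
For reference, once the weights are in `checkpoints/` and [Cog](https://github.com/replicate/cog) is installed, the predictor above can be exercised locally along these lines (the input values are illustrative, not prescribed by the repo):

```shell
# Runs predict.py through Cog; `@` passes a local file for the Path input.
cog predict \
  -i prompt="A striking mallard floats effortlessly on the sparkling pond." \
  -i video_path=@data/mallard-water.mp4 \
  -i condition="depth" \
  -i video_length=15 \
  -i smoother_steps="19, 20"
```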
requirements.txt ADDED
@@ -0,0 +1,38 @@
+ accelerate==0.17.1
+ addict==2.4.0
+ basicsr==1.4.2
+ bitsandbytes==0.35.4
+ clip==1.0
+ cmake==3.25.2
+ controlnet-aux==0.0.4
+ decord==0.6.0
+ deepspeed==0.8.0
+ diffusers==0.14.0
+ easydict==1.10
+ einops==0.6.0
+ ffmpy==0.3.0
+ ftfy==6.1.1
+ imageio==2.25.1
+ imageio-ffmpeg==0.4.8
+ moviepy==1.0.3
+ numpy==1.24.2
+ omegaconf==2.3.0
+ opencv-python==4.7.0.68
+ pandas==1.5.3
+ pillow==9.4.0
+ scikit-image==0.19.3
+ scipy==1.10.1
+ tensorboard==2.12.0
+ tensorboard-data-server==0.7.0
+ tensorboard-plugin-wit==1.8.1
+ termcolor==2.2.0
+ thinc==8.1.10
+ timm==0.6.12
+ tokenizers==0.13.2
+ torch==1.13.1+cu116
+ torchvision==0.14.1+cu116
+ tqdm==4.64.1
+ transformers==4.26.1
+ wandb==0.13.10
+ xformers==0.0.16
+ modelcards
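
One caveat worth noting: the `+cu116` builds of `torch`/`torchvision` are hosted on the PyTorch wheel index rather than PyPI, so if installation fails those two pins can be pulled with an extra index URL (this assumes a CUDA 11.6 environment):

```shell
# Pulls the CUDA 11.6 wheels referenced in requirements.txt from the PyTorch index.
pip install torch==1.13.1+cu116 torchvision==0.14.1+cu116 \
  --extra-index-url https://download.pytorch.org/whl/cu116
```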