wrdias kadirnar commited on
Commit
2cc4443
0 Parent(s):

Duplicate from ArtGAN/Video-Diffusion-WebUI

Browse files

Co-authored-by: Kadir Nar <kadirnar@users.noreply.huggingface.co>

.gitattributes ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tflite filter=lfs diff=lfs merge=lfs -text
29
+ *.tgz filter=lfs diff=lfs merge=lfs -text
30
+ *.wasm filter=lfs diff=lfs merge=lfs -text
31
+ *.xz filter=lfs diff=lfs merge=lfs -text
32
+ *.zip filter=lfs diff=lfs merge=lfs -text
33
+ *.zst filter=lfs diff=lfs merge=lfs -text
34
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: Video Diffusion WebUI
3
+ emoji: 🏃
4
+ colorFrom: gray
5
+ colorTo: yellow
6
+ sdk: gradio
7
+ sdk_version: 3.19.0
8
+ app_file: app.py
9
+ pinned: false
10
+ license: apache-2.0
11
+ tags:
12
+ - making-demos
13
+ duplicated_from: ArtGAN/Video-Diffusion-WebUI
14
+ ---
15
+
16
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+
3
+ from video_diffusion.damo.damo_text2_video import DamoText2VideoGenerator
4
+ from video_diffusion.inpaint_zoom.zoom_in_app import StableDiffusionZoomIn
5
+ from video_diffusion.inpaint_zoom.zoom_out_app import StableDiffusionZoomOut
6
+ from video_diffusion.stable_diffusion_video.stable_video_text2video import StableDiffusionText2VideoGenerator
7
+ from video_diffusion.tuneavideo.tuneavideo_text2video import TunaVideoText2VideoGenerator
8
+ from video_diffusion.zero_shot.zero_shot_text2video import ZeroShotText2VideoGenerator
9
+
10
+
11
+ def diffusion_app():
12
+ app = gr.Blocks()
13
+ with app:
14
+ gr.HTML(
15
+ """
16
+ <h1 style='text-align: center'>
17
+ Video Diffusion WebUI
18
+ </h1>
19
+ """
20
+ )
21
+ gr.HTML(
22
+ """
23
+ <h3 style='text-align: center'>
24
+ Follow me for more!
25
+ <a href='https://twitter.com/kadirnar_ai' target='_blank'>Twitter</a> | <a href='https://github.com/kadirnar' target='_blank'>Github</a> | <a href='https://www.linkedin.com/in/kadir-nar/' target='_blank'>Linkedin</a>
26
+ </h3>
27
+ """
28
+ )
29
+ with gr.Row():
30
+ with gr.Column():
31
+ with gr.Tab("Stable Diffusion Video"):
32
+ StableDiffusionText2VideoGenerator.app()
33
+ with gr.Tab("Tune-a-Video"):
34
+ TunaVideoText2VideoGenerator.app()
35
+ with gr.Tab("Stable Infinite Zoom"):
36
+ with gr.Tab("Zoom In"):
37
+ StableDiffusionZoomIn.app()
38
+ with gr.Tab("Zoom Out"):
39
+ StableDiffusionZoomOut.app()
40
+ with gr.Tab("Damo Text2Video"):
41
+ DamoText2VideoGenerator.app()
42
+ with gr.Tab("Zero Shot Text2Video"):
43
+ ZeroShotText2VideoGenerator.app()
44
+
45
+ app.queue(concurrency_count=1)
46
+ app.launch(debug=True, enable_queue=True)
47
+
48
+
49
+ if __name__ == "__main__":
50
+ diffusion_app()
requirements.txt ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ torch==2.0.0
2
+ git+https://github.com/huggingface/diffusers
3
+ transformers
4
+ accelerate
5
+ opencv-python
6
+ realesrgan==0.2.5.0
7
+ librosa
8
+ xformers
9
+ einops
10
+ av<10.0.0
11
+ imageio==2.9.0
12
+ imageio-ffmpeg==0.4.2
video_diffusion/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ __version__ = "0.0.1"
video_diffusion/damo/damo_text2_video.py ADDED
@@ -0,0 +1,126 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler
4
+ from diffusers.utils import export_to_video
5
+
6
+ from video_diffusion.utils.scheduler_list import diff_scheduler_list, get_scheduler_list
7
+
8
+ stable_model_list =["damo-vilab/text-to-video-ms-1.7b","cerspense/zeroscope_v2_576w"]
9
+
10
+ class DamoText2VideoGenerator:
11
+ def __init__(self):
12
+ self.pipe = None
13
+
14
+ def load_model(self, stable_model, scheduler):
15
+ if self.pipe is None:
16
+ self.pipe = DiffusionPipeline.from_pretrained(
17
+ stable_model, torch_dtype=torch.float16, variant="fp16"
18
+ )
19
+ self.pipe = get_scheduler_list(pipe=self.pipe, scheduler=scheduler)
20
+ self.pipe.to("cuda")
21
+ self.pipe.enable_xformers_memory_efficient_attention()
22
+ return self.pipe
23
+
24
+ def generate_video(
25
+ self,
26
+ prompt: str,
27
+ negative_prompt: str,
28
+ stable_model:str,
29
+ num_frames: int,
30
+ num_inference_steps: int,
31
+ guidance_scale: int,
32
+ height: int,
33
+ width: int,
34
+ scheduler: str,
35
+ ):
36
+ pipe = self.load_model(stable_model=stable_model, scheduler=scheduler)
37
+ video = pipe(
38
+ prompt,
39
+ negative_prompt=negative_prompt,
40
+ num_frames=int(num_frames),
41
+ height=height,
42
+ width=width,
43
+ num_inference_steps=num_inference_steps,
44
+ guidance_scale=guidance_scale,
45
+ ).frames
46
+
47
+ video_path = export_to_video(video)
48
+ return video_path
49
+
50
+ def app():
51
+ with gr.Blocks():
52
+ with gr.Row():
53
+ with gr.Column():
54
+ dano_text2video_prompt = gr.Textbox(lines=1, placeholder="Prompt", show_label=False)
55
+ dano_text2video_negative_prompt = gr.Textbox(
56
+ lines=1, placeholder="Negative Prompt", show_label=False
57
+ )
58
+ with gr.Row():
59
+ with gr.Column():
60
+ dano_text2video_model_list = gr.Dropdown(
61
+ choices=stable_model_list,
62
+ label="Model List",
63
+ value=stable_model_list[0],
64
+ )
65
+
66
+ dano_text2video_num_inference_steps = gr.Slider(
67
+ minimum=1,
68
+ maximum=100,
69
+ value=50,
70
+ step=1,
71
+ label="Inference Steps",
72
+ )
73
+ dano_text2video_guidance_scale = gr.Slider(
74
+ minimum=1,
75
+ maximum=15,
76
+ value=7,
77
+ step=1,
78
+ label="Guidance Scale",
79
+ )
80
+ dano_text2video_num_frames = gr.Slider(
81
+ minimum=1,
82
+ maximum=50,
83
+ value=16,
84
+ step=1,
85
+ label="Number of Frames",
86
+ )
87
+ with gr.Row():
88
+ with gr.Column():
89
+ dano_text2video_height = gr.Slider(
90
+ minimum=128,
91
+ maximum=1280,
92
+ value=512,
93
+ step=32,
94
+ label="Height",
95
+ )
96
+ dano_text2video_width = gr.Slider(
97
+ minimum=128,
98
+ maximum=1280,
99
+ value=512,
100
+ step=32,
101
+ label="Width",
102
+ )
103
+ damo_text2video_scheduler = gr.Dropdown(
104
+ choices=diff_scheduler_list,
105
+ label="Scheduler",
106
+ value=diff_scheduler_list[6],
107
+ )
108
+ dano_text2video_generate = gr.Button(value="Generator")
109
+ with gr.Column():
110
+ dano_output = gr.Video(label="Output")
111
+
112
+ dano_text2video_generate.click(
113
+ fn=DamoText2VideoGenerator().generate_video,
114
+ inputs=[
115
+ dano_text2video_prompt,
116
+ dano_text2video_negative_prompt,
117
+ dano_text2video_model_list,
118
+ dano_text2video_num_frames,
119
+ dano_text2video_num_inference_steps,
120
+ dano_text2video_guidance_scale,
121
+ dano_text2video_height,
122
+ dano_text2video_width,
123
+ damo_text2video_scheduler,
124
+ ],
125
+ outputs=dano_output,
126
+ )
video_diffusion/inpaint_zoom/__init__.py ADDED
File without changes
video_diffusion/inpaint_zoom/utils/__init__.py ADDED
File without changes
video_diffusion/inpaint_zoom/utils/zoom_in_utils.py ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ import cv2
4
+ import numpy as np
5
+ from PIL import Image
6
+
7
+ os.environ["CUDA_VISIBLE_DEVICES"] = "0"
8
+
9
+
10
+ def write_video(file_path, frames, fps, reversed=True):
11
+ """
12
+ Writes frames to an mp4 video file
13
+ :param file_path: Path to output video, must end with .mp4
14
+ :param frames: List of PIL.Image objects
15
+ :param fps: Desired frame rate
16
+ :param reversed: if order of images to be reversed (default = True)
17
+ """
18
+ if reversed == True:
19
+ frames.reverse()
20
+
21
+ w, h = frames[0].size
22
+ fourcc = cv2.VideoWriter_fourcc("m", "p", "4", "v")
23
+ # fourcc = cv2.VideoWriter_fourcc(*'avc1')
24
+ writer = cv2.VideoWriter(file_path, fourcc, fps, (w, h))
25
+
26
+ for frame in frames:
27
+ np_frame = np.array(frame.convert("RGB"))
28
+ cv_frame = cv2.cvtColor(np_frame, cv2.COLOR_RGB2BGR)
29
+ writer.write(cv_frame)
30
+
31
+ writer.release()
32
+
33
+
34
+ def image_grid(imgs, rows, cols):
35
+ assert len(imgs) == rows * cols
36
+
37
+ w, h = imgs[0].size
38
+ grid = Image.new("RGB", size=(cols * w, rows * h))
39
+ grid_w, grid_h = grid.size
40
+
41
+ for i, img in enumerate(imgs):
42
+ grid.paste(img, box=(i % cols * w, i // cols * h))
43
+ return grid
44
+
45
+
46
+ def shrink_and_paste_on_blank(current_image, mask_width):
47
+ """
48
+ Decreases size of current_image by mask_width pixels from each side,
49
+ then adds a mask_width width transparent frame,
50
+ so that the image the function returns is the same size as the input.
51
+ :param current_image: input image to transform
52
+ :param mask_width: width in pixels to shrink from each side
53
+ """
54
+
55
+ height = current_image.height
56
+ width = current_image.width
57
+
58
+ # shrink down by mask_width
59
+ prev_image = current_image.resize((height - 2 * mask_width, width - 2 * mask_width))
60
+ prev_image = prev_image.convert("RGBA")
61
+ prev_image = np.array(prev_image)
62
+
63
+ # create blank non-transparent image
64
+ blank_image = np.array(current_image.convert("RGBA")) * 0
65
+ blank_image[:, :, 3] = 1
66
+
67
+ # paste shrinked onto blank
68
+ blank_image[mask_width : height - mask_width, mask_width : width - mask_width, :] = prev_image
69
+ prev_image = Image.fromarray(blank_image)
70
+
71
+ return prev_image
72
+
73
+
74
+ def dummy(images, **kwargs):
75
+ return images, False
video_diffusion/inpaint_zoom/utils/zoom_out_utils.py ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import numpy as np
3
+ from PIL import Image
4
+
5
+
6
+ def write_video(file_path, frames, fps):
7
+ """
8
+ Writes frames to an mp4 video file
9
+ :param file_path: Path to output video, must end with .mp4
10
+ :param frames: List of PIL.Image objects
11
+ :param fps: Desired frame rate
12
+ """
13
+
14
+ w, h = frames[0].size
15
+ fourcc = cv2.VideoWriter_fourcc("m", "p", "4", "v")
16
+ writer = cv2.VideoWriter(file_path, fourcc, fps, (w, h))
17
+
18
+ for frame in frames:
19
+ np_frame = np.array(frame.convert("RGB"))
20
+ cv_frame = cv2.cvtColor(np_frame, cv2.COLOR_RGB2BGR)
21
+ writer.write(cv_frame)
22
+
23
+ writer.release()
24
+
25
+
26
+ def dummy(images, **kwargs):
27
+ return images, False
28
+
29
+
30
+ def preprocess_image(current_image, steps, image_size):
31
+ next_image = np.array(current_image.convert("RGBA")) * 0
32
+ prev_image = current_image.resize((image_size - 2 * steps, image_size - 2 * steps))
33
+ prev_image = prev_image.convert("RGBA")
34
+ prev_image = np.array(prev_image)
35
+ next_image[:, :, 3] = 1
36
+ next_image[steps : image_size - steps, steps : image_size - steps, :] = prev_image
37
+ prev_image = Image.fromarray(next_image)
38
+
39
+ return prev_image
40
+
41
+
42
+ def preprocess_mask_image(current_image):
43
+ mask_image = np.array(current_image)[:, :, 3] # assume image has alpha mask (use .mode to check for "RGBA")
44
+ mask_image = Image.fromarray(255 - mask_image).convert("RGB")
45
+ current_image = current_image.convert("RGB")
46
+
47
+ return current_image, mask_image
video_diffusion/inpaint_zoom/zoom_in_app.py ADDED
@@ -0,0 +1,186 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ import gradio as gr
4
+ import numpy as np
5
+ import torch
6
+ from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler
7
+ from PIL import Image
8
+
9
+ from video_diffusion.inpaint_zoom.utils.zoom_in_utils import dummy, image_grid, shrink_and_paste_on_blank, write_video
10
+
11
+ os.environ["CUDA_VISIBLE_DEVICES"] = "0"
12
+
13
+
14
+ stable_paint_model_list = ["stabilityai/stable-diffusion-2-inpainting", "runwayml/stable-diffusion-inpainting"]
15
+
16
+ stable_paint_prompt_list = [
17
+ "children running in the forest , sunny, bright, by studio ghibli painting, superior quality, masterpiece, traditional Japanese colors, by Grzegorz Rutkowski, concept art",
18
+ "A beautiful landscape of a mountain range with a lake in the foreground",
19
+ ]
20
+
21
+ stable_paint_negative_prompt_list = [
22
+ "lurry, bad art, blurred, text, watermark",
23
+ ]
24
+
25
+
26
+ class StableDiffusionZoomIn:
27
+ def __init__(self):
28
+ self.pipe = None
29
+
30
+ def load_model(self, model_id):
31
+ if self.pipe is None:
32
+ self.pipe = DiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16, revision="fp16")
33
+ self.pipe.scheduler = DPMSolverMultistepScheduler.from_config(self.pipe.scheduler.config)
34
+ self.pipe = self.pipe.to("cuda")
35
+ self.pipe.safety_checker = dummy
36
+ self.pipe.enable_attention_slicing()
37
+ self.pipe.enable_xformers_memory_efficient_attention()
38
+ self.g_cuda = torch.Generator(device="cuda")
39
+
40
+ return self.pipe
41
+
42
+ def generate_video(
43
+ self,
44
+ model_id,
45
+ prompt,
46
+ negative_prompt,
47
+ guidance_scale,
48
+ num_inference_steps,
49
+ ):
50
+ pipe = self.load_model(model_id)
51
+
52
+ num_init_images = 2
53
+ seed = 42
54
+ height = 512
55
+ width = height
56
+
57
+ current_image = Image.new(mode="RGBA", size=(height, width))
58
+ mask_image = np.array(current_image)[:, :, 3]
59
+ mask_image = Image.fromarray(255 - mask_image).convert("RGB")
60
+ current_image = current_image.convert("RGB")
61
+
62
+ init_images = pipe(
63
+ prompt=[prompt] * num_init_images,
64
+ negative_prompt=[negative_prompt] * num_init_images,
65
+ image=current_image,
66
+ guidance_scale=guidance_scale,
67
+ height=height,
68
+ width=width,
69
+ generator=self.g_cuda.manual_seed(seed),
70
+ mask_image=mask_image,
71
+ num_inference_steps=num_inference_steps,
72
+ )[0]
73
+
74
+ image_grid(init_images, rows=1, cols=num_init_images)
75
+
76
+ init_image_selected = 1 # @param
77
+ if num_init_images == 1:
78
+ init_image_selected = 0
79
+ else:
80
+ init_image_selected = init_image_selected - 1
81
+
82
+ num_outpainting_steps = 20 # @param
83
+ mask_width = 128 # @param
84
+ num_interpol_frames = 30 # @param
85
+
86
+ current_image = init_images[init_image_selected]
87
+ all_frames = []
88
+ all_frames.append(current_image)
89
+
90
+ for i in range(num_outpainting_steps):
91
+ print("Generating image: " + str(i + 1) + " / " + str(num_outpainting_steps))
92
+
93
+ prev_image_fix = current_image
94
+
95
+ prev_image = shrink_and_paste_on_blank(current_image, mask_width)
96
+
97
+ current_image = prev_image
98
+
99
+ # create mask (black image with white mask_width width edges)
100
+ mask_image = np.array(current_image)[:, :, 3]
101
+ mask_image = Image.fromarray(255 - mask_image).convert("RGB")
102
+
103
+ # inpainting step
104
+ current_image = current_image.convert("RGB")
105
+ images = pipe(
106
+ prompt=prompt,
107
+ negative_prompt=negative_prompt,
108
+ image=current_image,
109
+ guidance_scale=guidance_scale,
110
+ height=height,
111
+ width=width,
112
+ # this can make the whole thing deterministic but the output less exciting
113
+ # generator = g_cuda.manual_seed(seed),
114
+ mask_image=mask_image,
115
+ num_inference_steps=num_inference_steps,
116
+ )[0]
117
+ current_image = images[0]
118
+ current_image.paste(prev_image, mask=prev_image)
119
+
120
+ # interpolation steps bewteen 2 inpainted images (=sequential zoom and crop)
121
+ for j in range(num_interpol_frames - 1):
122
+ interpol_image = current_image
123
+ interpol_width = round(
124
+ (1 - (1 - 2 * mask_width / height) ** (1 - (j + 1) / num_interpol_frames)) * height / 2
125
+ )
126
+ interpol_image = interpol_image.crop(
127
+ (interpol_width, interpol_width, width - interpol_width, height - interpol_width)
128
+ )
129
+
130
+ interpol_image = interpol_image.resize((height, width))
131
+
132
+ # paste the higher resolution previous image in the middle to avoid drop in quality caused by zooming
133
+ interpol_width2 = round((1 - (height - 2 * mask_width) / (height - 2 * interpol_width)) / 2 * height)
134
+ prev_image_fix_crop = shrink_and_paste_on_blank(prev_image_fix, interpol_width2)
135
+ interpol_image.paste(prev_image_fix_crop, mask=prev_image_fix_crop)
136
+
137
+ all_frames.append(interpol_image)
138
+
139
+ all_frames.append(current_image)
140
+
141
+ video_file_name = "infinite_zoom_out"
142
+ fps = 30
143
+ save_path = video_file_name + ".mp4"
144
+ write_video(save_path, all_frames, fps)
145
+ return save_path
146
+
147
+ def app():
148
+ with gr.Blocks():
149
+ with gr.Row():
150
+ with gr.Column():
151
+ text2image_in_model_path = gr.Dropdown(
152
+ choices=stable_paint_model_list, value=stable_paint_model_list[0], label="Text-Image Model Id"
153
+ )
154
+
155
+ text2image_in_prompt = gr.Textbox(lines=2, value=stable_paint_prompt_list[0], label="Prompt")
156
+
157
+ text2image_in_negative_prompt = gr.Textbox(
158
+ lines=1, value=stable_paint_negative_prompt_list[0], label="Negative Prompt"
159
+ )
160
+
161
+ with gr.Row():
162
+ with gr.Column():
163
+ text2image_in_guidance_scale = gr.Slider(
164
+ minimum=0.1, maximum=15, step=0.1, value=7.5, label="Guidance Scale"
165
+ )
166
+
167
+ text2image_in_num_inference_step = gr.Slider(
168
+ minimum=1, maximum=100, step=1, value=50, label="Num Inference Step"
169
+ )
170
+
171
+ text2image_in_predict = gr.Button(value="Generator")
172
+
173
+ with gr.Column():
174
+ output_image = gr.Video(label="Output")
175
+
176
+ text2image_in_predict.click(
177
+ fn=StableDiffusionZoomIn().generate_video,
178
+ inputs=[
179
+ text2image_in_model_path,
180
+ text2image_in_prompt,
181
+ text2image_in_negative_prompt,
182
+ text2image_in_guidance_scale,
183
+ text2image_in_num_inference_step,
184
+ ],
185
+ outputs=output_image,
186
+ )
video_diffusion/inpaint_zoom/zoom_out_app.py ADDED
@@ -0,0 +1,140 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ import gradio as gr
4
+ import torch
5
+ from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler
6
+ from PIL import Image
7
+
8
+ from video_diffusion.inpaint_zoom.utils.zoom_out_utils import (
9
+ dummy,
10
+ preprocess_image,
11
+ preprocess_mask_image,
12
+ write_video,
13
+ )
14
+
15
+ os.environ["CUDA_VISIBLE_DEVICES"] = "0"
16
+
17
+
18
+ stable_paint_model_list = ["stabilityai/stable-diffusion-2-inpainting", "runwayml/stable-diffusion-inpainting"]
19
+
20
+ stable_paint_prompt_list = [
21
+ "children running in the forest , sunny, bright, by studio ghibli painting, superior quality, masterpiece, traditional Japanese colors, by Grzegorz Rutkowski, concept art",
22
+ "A beautiful landscape of a mountain range with a lake in the foreground",
23
+ ]
24
+
25
+ stable_paint_negative_prompt_list = [
26
+ "lurry, bad art, blurred, text, watermark",
27
+ ]
28
+
29
+
30
+ class StableDiffusionZoomOut:
31
+ def __init__(self):
32
+ self.pipe = None
33
+
34
+ def load_model(self, model_id):
35
+ if self.pipe is None:
36
+ self.pipe = DiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16)
37
+ self.pipe.set_use_memory_efficient_attention_xformers(True)
38
+ self.pipe.scheduler = DPMSolverMultistepScheduler.from_config(self.pipe.scheduler.config)
39
+ self.pipe = self.pipe.to("cuda")
40
+ self.pipe.safety_checker = dummy
41
+ self.g_cuda = torch.Generator(device="cuda")
42
+
43
+ return self.pipe
44
+
45
+ def generate_video(
46
+ self,
47
+ model_id,
48
+ prompt,
49
+ negative_prompt,
50
+ guidance_scale,
51
+ num_inference_steps,
52
+ num_frames,
53
+ step_size,
54
+ ):
55
+ pipe = self.load_model(model_id)
56
+
57
+ new_image = Image.new(mode="RGBA", size=(512, 512))
58
+ current_image, mask_image = preprocess_mask_image(new_image)
59
+
60
+ current_image = pipe(
61
+ prompt=[prompt],
62
+ negative_prompt=[negative_prompt],
63
+ image=current_image,
64
+ mask_image=mask_image,
65
+ num_inference_steps=num_inference_steps,
66
+ guidance_scale=guidance_scale,
67
+ ).images[0]
68
+
69
+ all_frames = []
70
+ all_frames.append(current_image)
71
+
72
+ for i in range(num_frames):
73
+ prev_image = preprocess_image(current_image, step_size, 512)
74
+ current_image = prev_image
75
+ current_image, mask_image = preprocess_mask_image(current_image)
76
+ current_image = pipe(
77
+ prompt=[prompt],
78
+ negative_prompt=[negative_prompt],
79
+ image=current_image,
80
+ mask_image=mask_image,
81
+ num_inference_steps=num_inference_steps,
82
+ ).images[0]
83
+ current_image.paste(prev_image, mask=prev_image)
84
+ all_frames.append(current_image)
85
+
86
+ save_path = "output.mp4"
87
+ write_video(save_path, all_frames, fps=30)
88
+ return save_path
89
+
90
+ def app():
91
+ with gr.Blocks():
92
+ with gr.Row():
93
+ with gr.Column():
94
+ text2image_out_model_path = gr.Dropdown(
95
+ choices=stable_paint_model_list, value=stable_paint_model_list[0], label="Text-Image Model Id"
96
+ )
97
+
98
+ text2image_out_prompt = gr.Textbox(lines=2, value=stable_paint_prompt_list[0], label="Prompt")
99
+
100
+ text2image_out_negative_prompt = gr.Textbox(
101
+ lines=1, value=stable_paint_negative_prompt_list[0], label="Negative Prompt"
102
+ )
103
+
104
+ with gr.Row():
105
+ with gr.Column():
106
+ text2image_out_guidance_scale = gr.Slider(
107
+ minimum=0.1, maximum=15, step=0.1, value=7.5, label="Guidance Scale"
108
+ )
109
+
110
+ text2image_out_num_inference_step = gr.Slider(
111
+ minimum=1, maximum=100, step=1, value=50, label="Num Inference Step"
112
+ )
113
+ with gr.Row():
114
+ with gr.Column():
115
+ text2image_out_step_size = gr.Slider(
116
+ minimum=1, maximum=100, step=1, value=10, label="Step Size"
117
+ )
118
+
119
+ text2image_out_num_frames = gr.Slider(
120
+ minimum=1, maximum=100, step=1, value=10, label="Frames"
121
+ )
122
+
123
+ text2image_out_predict = gr.Button(value="Generator")
124
+
125
+ with gr.Column():
126
+ output_image = gr.Video(label="Output")
127
+
128
+ text2image_out_predict.click(
129
+ fn=StableDiffusionZoomOut().generate_video,
130
+ inputs=[
131
+ text2image_out_model_path,
132
+ text2image_out_prompt,
133
+ text2image_out_negative_prompt,
134
+ text2image_out_guidance_scale,
135
+ text2image_out_num_inference_step,
136
+ text2image_out_step_size,
137
+ text2image_out_num_frames,
138
+ ],
139
+ outputs=output_image,
140
+ )
video_diffusion/stable_diffusion_video/__init__.py ADDED
File without changes
video_diffusion/stable_diffusion_video/image_generation.py ADDED
@@ -0,0 +1,363 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import math
3
+ import random
4
+ import time
5
+ from pathlib import Path
6
+ from uuid import uuid4
7
+
8
+ import torch
9
+ from diffusers import __version__ as diffusers_version
10
+ from huggingface_hub import CommitOperationAdd, create_commit, create_repo
11
+
12
+ from .upsampling import RealESRGANModel
13
+ from .utils import pad_along_axis
14
+
15
+
16
+ def get_all_files(root: Path):
17
+ dirs = [root]
18
+ while len(dirs) > 0:
19
+ dir = dirs.pop()
20
+ for candidate in dir.iterdir():
21
+ if candidate.is_file():
22
+ yield candidate
23
+ if candidate.is_dir():
24
+ dirs.append(candidate)
25
+
26
+
27
+ def get_groups_of_n(n: int, iterator):
28
+ assert n > 1
29
+ buffer = []
30
+ for elt in iterator:
31
+ if len(buffer) == n:
32
+ yield buffer
33
+ buffer = []
34
+ buffer.append(elt)
35
+ if len(buffer) != 0:
36
+ yield buffer
37
+
38
+
39
+ def upload_folder_chunked(
40
+ repo_id: str,
41
+ upload_dir: Path,
42
+ n: int = 100,
43
+ private: bool = False,
44
+ create_pr: bool = False,
45
+ ):
46
+ """Upload a folder to the Hugging Face Hub in chunks of n files at a time.
47
+ Args:
48
+ repo_id (str): The repo id to upload to.
49
+ upload_dir (Path): The directory to upload.
50
+ n (int, *optional*, defaults to 100): The number of files to upload at a time.
51
+ private (bool, *optional*): Whether to upload the repo as private.
52
+ create_pr (bool, *optional*): Whether to create a PR after uploading instead of commiting directly.
53
+ """
54
+
55
+ url = create_repo(repo_id, exist_ok=True, private=private, repo_type="dataset")
56
+ print(f"Uploading files to: {url}")
57
+
58
+ root = Path(upload_dir)
59
+ if not root.exists():
60
+ raise ValueError(f"Upload directory {root} does not exist.")
61
+
62
+ for i, file_paths in enumerate(get_groups_of_n(n, get_all_files(root))):
63
+ print(f"Committing {file_paths}")
64
+ operations = [
65
+ CommitOperationAdd(
66
+ path_in_repo=f"{file_path.parent.name}/{file_path.name}",
67
+ path_or_fileobj=str(file_path),
68
+ )
69
+ for file_path in file_paths
70
+ ]
71
+ create_commit(
72
+ repo_id=repo_id,
73
+ operations=operations,
74
+ commit_message=f"Upload part {i}",
75
+ repo_type="dataset",
76
+ create_pr=create_pr,
77
+ )
78
+
79
+
80
+ def generate_input_batches(pipeline, prompts, seeds, batch_size, height, width):
81
+ if len(prompts) != len(seeds):
82
+ raise ValueError("Number of prompts and seeds must be equal.")
83
+
84
+ embeds_batch, noise_batch = None, None
85
+ batch_idx = 0
86
+ for i, (prompt, seed) in enumerate(zip(prompts, seeds)):
87
+ embeds = pipeline.embed_text(prompt)
88
+ noise = torch.randn(
89
+ (1, pipeline.unet.in_channels, height // 8, width // 8),
90
+ device=pipeline.device,
91
+ generator=torch.Generator(device="cpu" if pipeline.device.type == "mps" else pipeline.device).manual_seed(
92
+ seed
93
+ ),
94
+ )
95
+ embeds_batch = embeds if embeds_batch is None else torch.cat([embeds_batch, embeds])
96
+ noise_batch = noise if noise_batch is None else torch.cat([noise_batch, noise])
97
+ batch_is_ready = embeds_batch.shape[0] == batch_size or i + 1 == len(prompts)
98
+ if not batch_is_ready:
99
+ continue
100
+ yield batch_idx, embeds_batch.type(torch.cuda.HalfTensor), noise_batch.type(torch.cuda.HalfTensor)
101
+ batch_idx += 1
102
+ del embeds_batch, noise_batch
103
+ torch.cuda.empty_cache()
104
+ embeds_batch, noise_batch = None, None
105
+
106
+
107
+ def generate_images(
108
+ pipeline,
109
+ prompt,
110
+ batch_size=1,
111
+ num_batches=1,
112
+ seeds=None,
113
+ num_inference_steps=50,
114
+ guidance_scale=7.5,
115
+ output_dir="./images",
116
+ image_file_ext=".jpg",
117
+ upsample=False,
118
+ height=512,
119
+ width=512,
120
+ eta=0.0,
121
+ push_to_hub=False,
122
+ repo_id=None,
123
+ private=False,
124
+ create_pr=False,
125
+ name=None,
126
+ ):
127
+ """Generate images using the StableDiffusion pipeline.
128
+ Args:
129
+ pipeline (StableDiffusionWalkPipeline): The StableDiffusion pipeline instance.
130
+ prompt (str): The prompt to use for the image generation.
131
+ batch_size (int, *optional*, defaults to 1): The batch size to use for image generation.
132
+ num_batches (int, *optional*, defaults to 1): The number of batches to generate.
133
+ seeds (list[int], *optional*): The seeds to use for the image generation.
134
+ num_inference_steps (int, *optional*, defaults to 50): The number of inference steps to take.
135
+ guidance_scale (float, *optional*, defaults to 7.5): The guidance scale to use for image generation.
136
+ output_dir (str, *optional*, defaults to "./images"): The output directory to save the images to.
137
+ image_file_ext (str, *optional*, defaults to '.jpg'): The image file extension to use.
138
+ upsample (bool, *optional*, defaults to False): Whether to upsample the images.
139
+ height (int, *optional*, defaults to 512): The height of the images to generate.
140
+ width (int, *optional*, defaults to 512): The width of the images to generate.
141
+ eta (float, *optional*, defaults to 0.0): The eta parameter to use for image generation.
142
+ push_to_hub (bool, *optional*, defaults to False): Whether to push the generated images to the Hugging Face Hub.
143
+ repo_id (str, *optional*): The repo id to push the images to.
144
+ private (bool, *optional*): Whether to push the repo as private.
145
+ create_pr (bool, *optional*): Whether to create a PR after pushing instead of commiting directly.
146
+ name (str, *optional*, defaults to current timestamp str): The name of the sub-directory of
147
+ output_dir to save the images to.
148
+ """
149
+ if push_to_hub:
150
+ if repo_id is None:
151
+ raise ValueError("Must provide repo_id if push_to_hub is True.")
152
+
153
+ name = name or time.strftime("%Y%m%d-%H%M%S")
154
+ save_path = Path(output_dir) / name
155
+ save_path.mkdir(exist_ok=False, parents=True)
156
+ prompt_config_path = save_path / "prompt_config.json"
157
+
158
+ num_images = batch_size * num_batches
159
+ seeds = seeds or [random.choice(list(range(0, 9999999))) for _ in range(num_images)]
160
+ if len(seeds) != num_images:
161
+ raise ValueError("Number of seeds must be equal to batch_size * num_batches.")
162
+
163
+ if upsample:
164
+ if getattr(pipeline, "upsampler", None) is None:
165
+ pipeline.upsampler = RealESRGANModel.from_pretrained("nateraw/real-esrgan")
166
+ pipeline.upsampler.to(pipeline.device)
167
+
168
+ cfg = dict(
169
+ prompt=prompt,
170
+ guidance_scale=guidance_scale,
171
+ eta=eta,
172
+ num_inference_steps=num_inference_steps,
173
+ upsample=upsample,
174
+ height=height,
175
+ width=width,
176
+ scheduler=dict(pipeline.scheduler.config),
177
+ tiled=pipeline.tiled,
178
+ diffusers_version=diffusers_version,
179
+ device_name=torch.cuda.get_device_name(0) if torch.cuda.is_available() else "unknown",
180
+ )
181
+ prompt_config_path.write_text(json.dumps(cfg, indent=2, sort_keys=False))
182
+
183
+ frame_index = 0
184
+ frame_filepaths = []
185
+ for batch_idx, embeds, noise in generate_input_batches(
186
+ pipeline, [prompt] * num_images, seeds, batch_size, height, width
187
+ ):
188
+ print(f"Generating batch {batch_idx}")
189
+
190
+ outputs = pipeline(
191
+ text_embeddings=embeds,
192
+ latents=noise,
193
+ num_inference_steps=num_inference_steps,
194
+ guidance_scale=guidance_scale,
195
+ eta=eta,
196
+ height=height,
197
+ width=width,
198
+ output_type="pil" if not upsample else "numpy",
199
+ )["images"]
200
+ if upsample:
201
+ images = []
202
+ for output in outputs:
203
+ images.append(pipeline.upsampler(output))
204
+ else:
205
+ images = outputs
206
+
207
+ for image in images:
208
+ frame_filepath = save_path / f"{seeds[frame_index]}{image_file_ext}"
209
+ image.save(frame_filepath)
210
+ frame_filepaths.append(str(frame_filepath))
211
+ frame_index += 1
212
+
213
+ return frame_filepaths
214
+
215
+ if push_to_hub:
216
+ upload_folder_chunked(repo_id, save_path, private=private, create_pr=create_pr)
217
+
218
+
219
+ def generate_images_flax(
220
+ pipeline,
221
+ params,
222
+ prompt,
223
+ batch_size=1,
224
+ num_batches=1,
225
+ seeds=None,
226
+ num_inference_steps=50,
227
+ guidance_scale=7.5,
228
+ output_dir="./images",
229
+ image_file_ext=".jpg",
230
+ upsample=False,
231
+ height=512,
232
+ width=512,
233
+ push_to_hub=False,
234
+ repo_id=None,
235
+ private=False,
236
+ create_pr=False,
237
+ name=None,
238
+ ):
239
+ import jax
240
+ from flax.training.common_utils import shard
241
+
242
+ """Generate images using the StableDiffusion pipeline.
243
+ Args:
244
+ pipeline (StableDiffusionWalkPipeline): The StableDiffusion pipeline instance.
245
+ params (`Union[Dict, FrozenDict]`): The model parameters.
246
+ prompt (str): The prompt to use for the image generation.
247
+ batch_size (int, *optional*, defaults to 1): The batch size to use for image generation.
248
+ num_batches (int, *optional*, defaults to 1): The number of batches to generate.
249
+ seeds (int, *optional*): The seed to use for the image generation.
250
+ num_inference_steps (int, *optional*, defaults to 50): The number of inference steps to take.
251
+ guidance_scale (float, *optional*, defaults to 7.5): The guidance scale to use for image generation.
252
+ output_dir (str, *optional*, defaults to "./images"): The output directory to save the images to.
253
+ image_file_ext (str, *optional*, defaults to '.jpg'): The image file extension to use.
254
+ upsample (bool, *optional*, defaults to False): Whether to upsample the images.
255
+ height (int, *optional*, defaults to 512): The height of the images to generate.
256
+ width (int, *optional*, defaults to 512): The width of the images to generate.
257
+ push_to_hub (bool, *optional*, defaults to False): Whether to push the generated images to the Hugging Face Hub.
258
+ repo_id (str, *optional*): The repo id to push the images to.
259
+ private (bool, *optional*): Whether to push the repo as private.
260
+ create_pr (bool, *optional*): Whether to create a PR after pushing instead of commiting directly.
261
+ name (str, *optional*, defaults to current timestamp str): The name of the sub-directory of
262
+ output_dir to save the images to.
263
+ """
264
+ if push_to_hub:
265
+ if repo_id is None:
266
+ raise ValueError("Must provide repo_id if push_to_hub is True.")
267
+
268
+ name = name or time.strftime("%Y%m%d-%H%M%S")
269
+ save_path = Path(output_dir) / name
270
+ save_path.mkdir(exist_ok=False, parents=True)
271
+ prompt_config_path = save_path / "prompt_config.json"
272
+
273
+ num_images = batch_size * num_batches
274
+ seeds = seeds or random.choice(list(range(0, 9999999)))
275
+ prng_seed = jax.random.PRNGKey(seeds)
276
+
277
+ if upsample:
278
+ if getattr(pipeline, "upsampler", None) is None:
279
+ pipeline.upsampler = RealESRGANModel.from_pretrained("nateraw/real-esrgan")
280
+ if not torch.cuda.is_available():
281
+ print("Upsampling is recommended to be done on a GPU, as it is very slow on CPU")
282
+ else:
283
+ pipeline.upsampler = pipeline.upsampler.cuda()
284
+
285
+ cfg = dict(
286
+ prompt=prompt,
287
+ guidance_scale=guidance_scale,
288
+ num_inference_steps=num_inference_steps,
289
+ upsample=upsample,
290
+ height=height,
291
+ width=width,
292
+ scheduler=dict(pipeline.scheduler.config),
293
+ # tiled=pipeline.tiled,
294
+ diffusers_version=diffusers_version,
295
+ device_name=torch.cuda.get_device_name(0) if torch.cuda.is_available() else "unknown",
296
+ )
297
+ prompt_config_path.write_text(json.dumps(cfg, indent=2, sort_keys=False))
298
+
299
+ NUM_TPU_CORES = jax.device_count()
300
+ jit = True # force jit, assume params are already sharded
301
+ batch_size_total = NUM_TPU_CORES * batch_size if jit else batch_size
302
+
303
+ def generate_input_batches(prompts, batch_size):
304
+ prompt_batch = None
305
+ for batch_idx in range(math.ceil(len(prompts) / batch_size)):
306
+ prompt_batch = prompts[batch_idx * batch_size : (batch_idx + 1) * batch_size]
307
+ yield batch_idx, prompt_batch
308
+
309
+ frame_index = 0
310
+ frame_filepaths = []
311
+ for batch_idx, prompt_batch in generate_input_batches([prompt] * num_images, batch_size_total):
312
+ # This batch size correspond to each TPU core, so we are generating batch_size * NUM_TPU_CORES images
313
+ print(f"Generating batches: {batch_idx*NUM_TPU_CORES} - {min((batch_idx+1)*NUM_TPU_CORES, num_batches)}")
314
+ prompt_ids_batch = pipeline.prepare_inputs(prompt_batch)
315
+ prng_seed_batch = prng_seed
316
+
317
+ if jit:
318
+ padded = False
319
+ # Check if len of prompt_batch is multiple of NUM_TPU_CORES, if not pad its ids
320
+ if len(prompt_batch) % NUM_TPU_CORES != 0:
321
+ padded = True
322
+ pad_size = NUM_TPU_CORES - (len(prompt_batch) % NUM_TPU_CORES)
323
+ # Pad embeds_batch and noise_batch with zeros in batch dimension
324
+ prompt_ids_batch = pad_along_axis(prompt_ids_batch, pad_size, axis=0)
325
+
326
+ prompt_ids_batch = shard(prompt_ids_batch)
327
+ prng_seed_batch = jax.random.split(prng_seed, jax.device_count())
328
+
329
+ outputs = pipeline(
330
+ params,
331
+ prng_seed=prng_seed_batch,
332
+ prompt_ids=prompt_ids_batch,
333
+ height=height,
334
+ width=width,
335
+ guidance_scale=guidance_scale,
336
+ num_inference_steps=num_inference_steps,
337
+ output_type="pil" if not upsample else "numpy",
338
+ jit=jit,
339
+ )["images"]
340
+
341
+ if jit:
342
+ # check if we padded and remove that padding from outputs
343
+ if padded:
344
+ outputs = outputs[:-pad_size]
345
+
346
+ if upsample:
347
+ images = []
348
+ for output in outputs:
349
+ images.append(pipeline.upsampler(output))
350
+ else:
351
+ images = outputs
352
+
353
+ for image in images:
354
+ uuid = str(uuid4())
355
+ frame_filepath = save_path / f"{uuid}{image_file_ext}"
356
+ image.save(frame_filepath)
357
+ frame_filepaths.append(str(frame_filepath))
358
+ frame_index += 1
359
+
360
+ return frame_filepaths
361
+
362
+ if push_to_hub:
363
+ upload_folder_chunked(repo_id, save_path, private=private, create_pr=create_pr)
video_diffusion/stable_diffusion_video/stable_diffusion_pipeline.py ADDED
@@ -0,0 +1,848 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import inspect
2
+ import json
3
+ import math
4
+ import time
5
+ from pathlib import Path
6
+ from typing import Callable, List, Optional, Tuple, Union
7
+
8
+ import numpy as np
9
+ import torch
10
+ from diffusers.configuration_utils import FrozenDict
11
+ from diffusers.models import AutoencoderKL, UNet2DConditionModel
12
+ from diffusers.pipeline_utils import DiffusionPipeline
13
+ from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
14
+ from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
15
+ from diffusers.schedulers import (
16
+ DDIMScheduler,
17
+ DPMSolverMultistepScheduler,
18
+ EulerAncestralDiscreteScheduler,
19
+ EulerDiscreteScheduler,
20
+ LMSDiscreteScheduler,
21
+ PNDMScheduler,
22
+ )
23
+ from diffusers.utils import deprecate, logging
24
+ from packaging import version
25
+ from torch import nn
26
+ from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
27
+
28
+ from .upsampling import RealESRGANModel
29
+ from .utils import get_timesteps_arr, make_video_pyav, slerp
30
+
31
+ logging.set_verbosity_info()
32
+ logger = logging.get_logger(__name__)
33
+
34
+
35
+ class StableDiffusionWalkPipeline(DiffusionPipeline):
36
+ r"""
37
+ Pipeline for generating videos by interpolating Stable Diffusion's latent space.
38
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
39
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
40
+ Args:
41
+ vae ([`AutoencoderKL`]):
42
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
43
+ text_encoder ([`CLIPTextModel`]):
44
+ Frozen text-encoder. Stable Diffusion uses the text portion of
45
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
46
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
47
+ tokenizer (`CLIPTokenizer`):
48
+ Tokenizer of class
49
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
50
+ unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
51
+ scheduler ([`SchedulerMixin`]):
52
+ A scheduler to be used in combination with `unet` to denoise the encoded image latens. Can be one of
53
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
54
+ safety_checker ([`StableDiffusionSafetyChecker`]):
55
+ Classification module that estimates whether generated images could be considered offensive or harmful.
56
+ Please, refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details.
57
+ feature_extractor ([`CLIPFeatureExtractor`]):
58
+ Model that extracts features from generated images to be used as inputs for the `safety_checker`.
59
+ """
60
+ _optional_components = ["safety_checker", "feature_extractor"]
61
+
62
+ def __init__(
63
+ self,
64
+ vae: AutoencoderKL,
65
+ text_encoder: CLIPTextModel,
66
+ tokenizer: CLIPTokenizer,
67
+ unet: UNet2DConditionModel,
68
+ scheduler: Union[
69
+ DDIMScheduler,
70
+ PNDMScheduler,
71
+ LMSDiscreteScheduler,
72
+ EulerDiscreteScheduler,
73
+ EulerAncestralDiscreteScheduler,
74
+ DPMSolverMultistepScheduler,
75
+ ],
76
+ safety_checker: StableDiffusionSafetyChecker,
77
+ feature_extractor: CLIPFeatureExtractor,
78
+ requires_safety_checker: bool = True,
79
+ ):
80
+ super().__init__()
81
+
82
+ if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
83
+ deprecation_message = (
84
+ f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
85
+ f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
86
+ "to update the config accordingly as leaving `steps_offset` might led to incorrect results"
87
+ " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
88
+ " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
89
+ " file"
90
+ )
91
+ deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
92
+ new_config = dict(scheduler.config)
93
+ new_config["steps_offset"] = 1
94
+ scheduler._internal_dict = FrozenDict(new_config)
95
+
96
+ if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
97
+ deprecation_message = (
98
+ f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
99
+ " `clip_sample` should be set to False in the configuration file. Please make sure to update the"
100
+ " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
101
+ " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
102
+ " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
103
+ )
104
+ deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
105
+ new_config = dict(scheduler.config)
106
+ new_config["clip_sample"] = False
107
+ scheduler._internal_dict = FrozenDict(new_config)
108
+
109
+ if safety_checker is None and requires_safety_checker:
110
+ logger.warning(
111
+ f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
112
+ " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
113
+ " results in services or applications open to the public. Both the diffusers team and Hugging Face"
114
+ " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
115
+ " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
116
+ " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
117
+ )
118
+
119
+ if safety_checker is not None and feature_extractor is None:
120
+ raise ValueError(
121
+ "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
122
+ " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
123
+ )
124
+
125
+ is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
126
+ version.parse(unet.config._diffusers_version).base_version
127
+ ) < version.parse("0.9.0.dev0")
128
+ is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
129
+ if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
130
+ deprecation_message = (
131
+ "The configuration file of the unet has set the default `sample_size` to smaller than"
132
+ " 64 which seems highly unlikely .If you're checkpoint is a fine-tuned version of any of the"
133
+ " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
134
+ " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
135
+ " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
136
+ " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
137
+ " in the config might lead to incorrect results in future versions. If you have downloaded this"
138
+ " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
139
+ " the `unet/config.json` file"
140
+ )
141
+ deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
142
+ new_config = dict(unet.config)
143
+ new_config["sample_size"] = 64
144
+ unet._internal_dict = FrozenDict(new_config)
145
+
146
+ self.register_modules(
147
+ vae=vae,
148
+ text_encoder=text_encoder,
149
+ tokenizer=tokenizer,
150
+ unet=unet,
151
+ scheduler=scheduler,
152
+ safety_checker=safety_checker,
153
+ feature_extractor=feature_extractor,
154
+ )
155
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
156
+ self.register_to_config(requires_safety_checker=requires_safety_checker)
157
+
158
+ def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
159
+ r"""
160
+ Enable sliced attention computation.
161
+ When this option is enabled, the attention module will split the input tensor in slices, to compute attention
162
+ in several steps. This is useful to save some memory in exchange for a small speed decrease.
163
+ Args:
164
+ slice_size (`str` or `int`, *optional*, defaults to `"auto"`):
165
+ When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
166
+ a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case,
167
+ `attention_head_dim` must be a multiple of `slice_size`.
168
+ """
169
+ if slice_size == "auto":
170
+ if isinstance(self.unet.config.attention_head_dim, int):
171
+ # half the attention head size is usually a good trade-off between
172
+ # speed and memory
173
+ slice_size = self.unet.config.attention_head_dim // 2
174
+ else:
175
+ # if `attention_head_dim` is a list, take the smallest head size
176
+ slice_size = min(self.unet.config.attention_head_dim)
177
+
178
+ self.unet.set_attention_slice(slice_size)
179
+
180
+ def disable_attention_slicing(self):
181
+ r"""
182
+ Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go
183
+ back to computing attention in one step.
184
+ """
185
+ # set slice_size = `None` to disable `attention slicing`
186
+ self.enable_attention_slicing(None)
187
+
188
+ @torch.no_grad()
189
+ def __call__(
190
+ self,
191
+ prompt: Optional[Union[str, List[str]]] = None,
192
+ height: Optional[int] = None,
193
+ width: Optional[int] = None,
194
+ num_inference_steps: int = 50,
195
+ guidance_scale: float = 7.5,
196
+ negative_prompt: Optional[Union[str, List[str]]] = None,
197
+ num_images_per_prompt: Optional[int] = 1,
198
+ eta: float = 0.0,
199
+ generator: Optional[torch.Generator] = None,
200
+ latents: Optional[torch.FloatTensor] = None,
201
+ output_type: Optional[str] = "pil",
202
+ return_dict: bool = True,
203
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
204
+ callback_steps: Optional[int] = 1,
205
+ text_embeddings: Optional[torch.FloatTensor] = None,
206
+ **kwargs,
207
+ ):
208
+ r"""
209
+ Function invoked when calling the pipeline for generation.
210
+ Args:
211
+ prompt (`str` or `List[str]`, *optional*, defaults to `None`):
212
+ The prompt or prompts to guide the image generation. If not provided, `text_embeddings` is required.
213
+ height (`int`, *optional*, defaults to 512):
214
+ The height in pixels of the generated image.
215
+ width (`int`, *optional*, defaults to 512):
216
+ The width in pixels of the generated image.
217
+ num_inference_steps (`int`, *optional*, defaults to 50):
218
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
219
+ expense of slower inference.
220
+ guidance_scale (`float`, *optional*, defaults to 7.5):
221
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
222
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
223
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
224
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
225
+ usually at the expense of lower image quality.
226
+ negative_prompt (`str` or `List[str]`, *optional*):
227
+ The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
228
+ if `guidance_scale` is less than `1`).
229
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
230
+ The number of images to generate per prompt.
231
+ eta (`float`, *optional*, defaults to 0.0):
232
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
233
+ [`schedulers.DDIMScheduler`], will be ignored for others.
234
+ generator (`torch.Generator`, *optional*):
235
+ A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
236
+ deterministic.
237
+ latents (`torch.FloatTensor`, *optional*):
238
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
239
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
240
+ tensor will ge generated by sampling using the supplied random `generator`.
241
+ output_type (`str`, *optional*, defaults to `"pil"`):
242
+ The output format of the generate image. Choose between
243
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
244
+ return_dict (`bool`, *optional*, defaults to `True`):
245
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
246
+ plain tuple.
247
+ callback (`Callable`, *optional*):
248
+ A function that will be called every `callback_steps` steps during inference. The function will be
249
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
250
+ callback_steps (`int`, *optional*, defaults to 1):
251
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
252
+ called at every step.
253
+ text_embeddings (`torch.FloatTensor`, *optional*, defaults to `None`):
254
+ Pre-generated text embeddings to be used as inputs for image generation. Can be used in place of
255
+ `prompt` to avoid re-computing the embeddings. If not provided, the embeddings will be generated from
256
+ the supplied `prompt`.
257
+ Returns:
258
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
259
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple.
260
+ When returning a tuple, the first element is a list with the generated images, and the second element is a
261
+ list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
262
+ (nsfw) content, according to the `safety_checker`.
263
+ """
264
+ # 0. Default height and width to unet
265
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
266
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
267
+
268
+ if height % 8 != 0 or width % 8 != 0:
269
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
270
+
271
+ if (callback_steps is None) or (
272
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
273
+ ):
274
+ raise ValueError(
275
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
276
+ f" {type(callback_steps)}."
277
+ )
278
+
279
+ if text_embeddings is None:
280
+ if isinstance(prompt, str):
281
+ batch_size = 1
282
+ elif isinstance(prompt, list):
283
+ batch_size = len(prompt)
284
+ else:
285
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
286
+
287
+ # get prompt text embeddings
288
+ text_inputs = self.tokenizer(
289
+ prompt,
290
+ padding="max_length",
291
+ max_length=self.tokenizer.model_max_length,
292
+ return_tensors="pt",
293
+ )
294
+ text_input_ids = text_inputs.input_ids
295
+
296
+ if text_input_ids.shape[-1] > self.tokenizer.model_max_length:
297
+ removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :])
298
+ print(
299
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
300
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
301
+ )
302
+ text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
303
+ text_embeddings = self.text_encoder(text_input_ids.to(self.device))[0]
304
+ else:
305
+ batch_size = text_embeddings.shape[0]
306
+
307
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
308
+ bs_embed, seq_len, _ = text_embeddings.shape
309
+ text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1)
310
+ text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
311
+
312
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
313
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
314
+ # corresponds to doing no classifier free guidance.
315
+ do_classifier_free_guidance = guidance_scale > 1.0
316
+ # get unconditional embeddings for classifier free guidance
317
+ if do_classifier_free_guidance:
318
+ uncond_tokens: List[str]
319
+ if negative_prompt is None:
320
+ uncond_tokens = [""]
321
+ elif text_embeddings is None and type(prompt) is not type(negative_prompt):
322
+ raise TypeError(
323
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
324
+ f" {type(prompt)}."
325
+ )
326
+ elif isinstance(negative_prompt, str):
327
+ uncond_tokens = [negative_prompt]
328
+ elif batch_size != len(negative_prompt):
329
+ raise ValueError(
330
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
331
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
332
+ " the batch size of `prompt`."
333
+ )
334
+ else:
335
+ uncond_tokens = negative_prompt
336
+
337
+ max_length = self.tokenizer.model_max_length
338
+ uncond_input = self.tokenizer(
339
+ uncond_tokens,
340
+ padding="max_length",
341
+ max_length=max_length,
342
+ truncation=True,
343
+ return_tensors="pt",
344
+ )
345
+ uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]
346
+
347
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
348
+ seq_len = uncond_embeddings.shape[1]
349
+ uncond_embeddings = uncond_embeddings.repeat(batch_size, num_images_per_prompt, 1)
350
+ uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1)
351
+
352
+ # For classifier free guidance, we need to do two forward passes.
353
+ # Here we concatenate the unconditional and text embeddings into a single batch
354
+ # to avoid doing two forward passes
355
+ text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
356
+
357
+ # get the initial random noise unless the user supplied it
358
+
359
+ # Unlike in other pipelines, latents need to be generated in the target device
360
+ # for 1-to-1 results reproducibility with the CompVis implementation.
361
+ # However this currently doesn't work in `mps`.
362
+ latents_shape = (
363
+ batch_size * num_images_per_prompt,
364
+ self.unet.in_channels,
365
+ height // 8,
366
+ width // 8,
367
+ )
368
+ latents_dtype = text_embeddings.dtype
369
+ if latents is None:
370
+ if self.device.type == "mps":
371
+ # randn does not exist on mps
372
+ latents = torch.randn(
373
+ latents_shape,
374
+ generator=generator,
375
+ device="cpu",
376
+ dtype=latents_dtype,
377
+ ).to(self.device)
378
+ else:
379
+ latents = torch.randn(
380
+ latents_shape,
381
+ generator=generator,
382
+ device=self.device,
383
+ dtype=latents_dtype,
384
+ )
385
+ else:
386
+ if latents.shape != latents_shape:
387
+ raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
388
+ latents = latents.to(self.device)
389
+
390
+ # set timesteps
391
+ self.scheduler.set_timesteps(num_inference_steps)
392
+
393
+ # Some schedulers like PNDM have timesteps as arrays
394
+ # It's more optimized to move all timesteps to correct device beforehand
395
+ timesteps_tensor = self.scheduler.timesteps.to(self.device)
396
+
397
+ # scale the initial noise by the standard deviation required by the scheduler
398
+ latents = latents * self.scheduler.init_noise_sigma
399
+
400
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
401
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
402
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
403
+ # and should be between [0, 1]
404
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
405
+ extra_step_kwargs = {}
406
+ if accepts_eta:
407
+ extra_step_kwargs["eta"] = eta
408
+
409
+ for i, t in enumerate(self.progress_bar(timesteps_tensor)):
410
+ # expand the latents if we are doing classifier free guidance
411
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
412
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
413
+
414
+ # predict the noise residual
415
+ noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
416
+
417
+ # perform guidance
418
+ if do_classifier_free_guidance:
419
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
420
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
421
+
422
+ # compute the previous noisy sample x_t -> x_t-1
423
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
424
+
425
+ # call the callback, if provided
426
+ if callback is not None and i % callback_steps == 0:
427
+ callback(i, t, latents)
428
+
429
+ latents = 1 / 0.18215 * latents
430
+ image = self.vae.decode(latents).sample
431
+
432
+ image = (image / 2 + 0.5).clamp(0, 1)
433
+
434
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16
435
+ image = image.cpu().permute(0, 2, 3, 1).float().numpy()
436
+
437
+ if self.safety_checker is not None:
438
+ safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(self.device)
439
+ image, has_nsfw_concept = self.safety_checker(
440
+ images=image,
441
+ clip_input=safety_checker_input.pixel_values.to(text_embeddings.dtype),
442
+ )
443
+ else:
444
+ has_nsfw_concept = None
445
+
446
+ if output_type == "pil":
447
+ image = self.numpy_to_pil(image)
448
+
449
+ if not return_dict:
450
+ return (image, has_nsfw_concept)
451
+
452
+ return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
453
+
454
+ def generate_inputs(self, prompt_a, prompt_b, seed_a, seed_b, noise_shape, T, batch_size):
455
+ embeds_a = self.embed_text(prompt_a)
456
+ embeds_b = self.embed_text(prompt_b)
457
+ latents_dtype = embeds_a.dtype
458
+ latents_a = self.init_noise(seed_a, noise_shape, latents_dtype)
459
+ latents_b = self.init_noise(seed_b, noise_shape, latents_dtype)
460
+
461
+ batch_idx = 0
462
+ embeds_batch, noise_batch = None, None
463
+ for i, t in enumerate(T):
464
+ embeds = torch.lerp(embeds_a, embeds_b, t)
465
+ noise = slerp(float(t), latents_a, latents_b)
466
+
467
+ embeds_batch = embeds if embeds_batch is None else torch.cat([embeds_batch, embeds])
468
+ noise_batch = noise if noise_batch is None else torch.cat([noise_batch, noise])
469
+ batch_is_ready = embeds_batch.shape[0] == batch_size or i + 1 == T.shape[0]
470
+ if not batch_is_ready:
471
+ continue
472
+ yield batch_idx, embeds_batch, noise_batch
473
+ batch_idx += 1
474
+ del embeds_batch, noise_batch
475
+ torch.cuda.empty_cache()
476
+ embeds_batch, noise_batch = None, None
477
+
478
+ def make_clip_frames(
479
+ self,
480
+ prompt_a: str,
481
+ prompt_b: str,
482
+ seed_a: int,
483
+ seed_b: int,
484
+ num_interpolation_steps: int = 5,
485
+ save_path: Union[str, Path] = "outputs/",
486
+ num_inference_steps: int = 50,
487
+ guidance_scale: float = 7.5,
488
+ eta: float = 0.0,
489
+ height: Optional[int] = None,
490
+ width: Optional[int] = None,
491
+ upsample: bool = False,
492
+ batch_size: int = 1,
493
+ image_file_ext: str = ".png",
494
+ T: np.ndarray = None,
495
+ skip: int = 0,
496
+ negative_prompt: str = None,
497
+ step: Optional[Tuple[int, int]] = None,
498
+ ):
499
+ # 0. Default height and width to unet
500
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
501
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
502
+
503
+ save_path = Path(save_path)
504
+ save_path.mkdir(parents=True, exist_ok=True)
505
+
506
+ T = T if T is not None else np.linspace(0.0, 1.0, num_interpolation_steps)
507
+ if T.shape[0] != num_interpolation_steps:
508
+ raise ValueError(f"Unexpected T shape, got {T.shape}, expected dim 0 to be {num_interpolation_steps}")
509
+
510
+ if upsample:
511
+ if getattr(self, "upsampler", None) is None:
512
+ self.upsampler = RealESRGANModel.from_pretrained("nateraw/real-esrgan")
513
+ self.upsampler.to(self.device)
514
+
515
+ batch_generator = self.generate_inputs(
516
+ prompt_a,
517
+ prompt_b,
518
+ seed_a,
519
+ seed_b,
520
+ (1, self.unet.in_channels, height // 8, width // 8),
521
+ T[skip:],
522
+ batch_size,
523
+ )
524
+ num_batches = math.ceil(num_interpolation_steps / batch_size)
525
+
526
+ log_prefix = "" if step is None else f"[{step[0]}/{step[1]}] "
527
+
528
+ frame_index = skip
529
+ for batch_idx, embeds_batch, noise_batch in batch_generator:
530
+ if batch_size == 1:
531
+ msg = f"Generating frame {frame_index}"
532
+ else:
533
+ msg = f"Generating frames {frame_index}-{frame_index+embeds_batch.shape[0]-1}"
534
+ logger.info(f"{log_prefix}[{batch_idx}/{num_batches}] {msg}")
535
+ outputs = self(
536
+ latents=noise_batch,
537
+ text_embeddings=embeds_batch,
538
+ height=height,
539
+ width=width,
540
+ guidance_scale=guidance_scale,
541
+ eta=eta,
542
+ num_inference_steps=num_inference_steps,
543
+ output_type="pil" if not upsample else "numpy",
544
+ negative_prompt=negative_prompt,
545
+ )["images"]
546
+
547
+ for image in outputs:
548
+ frame_filepath = save_path / (f"frame%06d{image_file_ext}" % frame_index)
549
+ image = image if not upsample else self.upsampler(image)
550
+ image.save(frame_filepath)
551
+ frame_index += 1
552
+
553
+ def walk(
554
+ self,
555
+ prompts: Optional[List[str]] = None,
556
+ seeds: Optional[List[int]] = None,
557
+ num_interpolation_steps: Optional[Union[int, List[int]]] = 5, # int or list of int
558
+ output_dir: Optional[str] = "./dreams",
559
+ name: Optional[str] = None,
560
+ image_file_ext: Optional[str] = ".png",
561
+ fps: Optional[int] = 30,
562
+ num_inference_steps: Optional[int] = 50,
563
+ guidance_scale: Optional[float] = 7.5,
564
+ eta: Optional[float] = 0.0,
565
+ height: Optional[int] = None,
566
+ width: Optional[int] = None,
567
+ upsample: Optional[bool] = False,
568
+ batch_size: Optional[int] = 1,
569
+ resume: Optional[bool] = False,
570
+ audio_filepath: str = None,
571
+ audio_start_sec: Optional[Union[int, float]] = None,
572
+ margin: Optional[float] = 1.0,
573
+ smooth: Optional[float] = 0.0,
574
+ negative_prompt: Optional[str] = None,
575
+ make_video: Optional[bool] = True,
576
+ ):
577
+ """Generate a video from a sequence of prompts and seeds. Optionally, add audio to the
578
+ video to interpolate to the intensity of the audio.
579
+ Args:
580
+ prompts (Optional[List[str]], optional):
581
+ list of text prompts. Defaults to None.
582
+ seeds (Optional[List[int]], optional):
583
+ list of random seeds corresponding to prompts. Defaults to None.
584
+ num_interpolation_steps (Union[int, List[int]], *optional*):
585
+ How many interpolation steps between each prompt. Defaults to None.
586
+ output_dir (Optional[str], optional):
587
+ Where to save the video. Defaults to './dreams'.
588
+ name (Optional[str], optional):
589
+ Name of the subdirectory of output_dir. Defaults to None.
590
+ image_file_ext (Optional[str], *optional*, defaults to '.png'):
591
+ The extension to use when writing video frames.
592
+ fps (Optional[int], *optional*, defaults to 30):
593
+ The frames per second in the resulting output videos.
594
+ num_inference_steps (Optional[int], *optional*, defaults to 50):
595
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
596
+ expense of slower inference.
597
+ guidance_scale (Optional[float], *optional*, defaults to 7.5):
598
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
599
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
600
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
601
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
602
+ usually at the expense of lower image quality.
603
+ eta (Optional[float], *optional*, defaults to 0.0):
604
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
605
+ [`schedulers.DDIMScheduler`], will be ignored for others.
606
+ height (Optional[int], *optional*, defaults to None):
607
+ height of the images to generate.
608
+ width (Optional[int], *optional*, defaults to None):
609
+ width of the images to generate.
610
+ upsample (Optional[bool], *optional*, defaults to False):
611
+ When True, upsamples images with realesrgan.
612
+ batch_size (Optional[int], *optional*, defaults to 1):
613
+ Number of images to generate at once.
614
+ resume (Optional[bool], *optional*, defaults to False):
615
+ When True, resumes from the last frame in the output directory based
616
+ on available prompt config. Requires you to provide the `name` argument.
617
+ audio_filepath (str, *optional*, defaults to None):
618
+ Optional path to an audio file to influence the interpolation rate.
619
+ audio_start_sec (Optional[Union[int, float]], *optional*, defaults to 0):
620
+ Global start time of the provided audio_filepath.
621
+ margin (Optional[float], *optional*, defaults to 1.0):
622
+ Margin from librosa hpss to use for audio interpolation.
623
+ smooth (Optional[float], *optional*, defaults to 0.0):
624
+ Smoothness of the audio interpolation. 1.0 means linear interpolation.
625
+ negative_prompt (Optional[str], *optional*, defaults to None):
626
+ Optional negative prompt to use. Same across all prompts.
627
+ make_video (Optional[bool], *optional*, defaults to True):
628
+ When True, makes a video from the generated frames. If False, only
629
+ generates the frames.
630
+ This function will create sub directories for each prompt and seed pair.
631
+ For example, if you provide the following prompts and seeds:
632
+ ```
633
+ prompts = ['a dog', 'a cat', 'a bird']
634
+ seeds = [1, 2, 3]
635
+ num_interpolation_steps = 5
636
+ output_dir = 'output_dir'
637
+ name = 'name'
638
+ fps = 5
639
+ ```
640
+ Then the following directories will be created:
641
+ ```
642
+ output_dir
643
+ ├── name
644
+ │ ├── name_000000
645
+ │ │ ├── frame000000.png
646
+ │ │ ├── ...
647
+ │ │ ├── frame000004.png
648
+ │ │ ├── name_000000.mp4
649
+ │ ├── name_000001
650
+ │ │ ├── frame000000.png
651
+ │ │ ├── ...
652
+ │ │ ├── frame000004.png
653
+ │ │ ├── name_000001.mp4
654
+ │ ├── ...
655
+ │ ├── name.mp4
656
+ | |── prompt_config.json
657
+ ```
658
+ Returns:
659
+ str: The resulting video filepath. This video includes all sub directories' video clips.
660
+ """
661
+ # 0. Default height and width to unet
662
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
663
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
664
+
665
+ output_path = Path(output_dir)
666
+
667
+ name = name or time.strftime("%Y%m%d-%H%M%S")
668
+ save_path_root = output_path / name
669
+ save_path_root.mkdir(parents=True, exist_ok=True)
670
+
671
+ # Where the final video of all the clips combined will be saved
672
+ output_filepath = save_path_root / f"{name}.mp4"
673
+
674
+ # If using same number of interpolation steps between, we turn into list
675
+ if not resume and isinstance(num_interpolation_steps, int):
676
+ num_interpolation_steps = [num_interpolation_steps] * (len(prompts) - 1)
677
+
678
+ if not resume:
679
+ audio_start_sec = audio_start_sec or 0
680
+
681
+ # Save/reload prompt config
682
+ prompt_config_path = save_path_root / "prompt_config.json"
683
+ if not resume:
684
+ prompt_config_path.write_text(
685
+ json.dumps(
686
+ dict(
687
+ prompts=prompts,
688
+ seeds=seeds,
689
+ num_interpolation_steps=num_interpolation_steps,
690
+ fps=fps,
691
+ num_inference_steps=num_inference_steps,
692
+ guidance_scale=guidance_scale,
693
+ eta=eta,
694
+ upsample=upsample,
695
+ height=height,
696
+ width=width,
697
+ audio_filepath=audio_filepath,
698
+ audio_start_sec=audio_start_sec,
699
+ negative_prompt=negative_prompt,
700
+ ),
701
+ indent=2,
702
+ sort_keys=False,
703
+ )
704
+ )
705
+ else:
706
+ data = json.load(open(prompt_config_path))
707
+ prompts = data["prompts"]
708
+ seeds = data["seeds"]
709
+ num_interpolation_steps = data["num_interpolation_steps"]
710
+ fps = data["fps"]
711
+ num_inference_steps = data["num_inference_steps"]
712
+ guidance_scale = data["guidance_scale"]
713
+ eta = data["eta"]
714
+ upsample = data["upsample"]
715
+ height = data["height"]
716
+ width = data["width"]
717
+ audio_filepath = data["audio_filepath"]
718
+ audio_start_sec = data["audio_start_sec"]
719
+ negative_prompt = data.get("negative_prompt", None)
720
+
721
+ for i, (prompt_a, prompt_b, seed_a, seed_b, num_step) in enumerate(
722
+ zip(prompts, prompts[1:], seeds, seeds[1:], num_interpolation_steps)
723
+ ):
724
+ # {name}_000000 / {name}_000001 / ...
725
+ save_path = save_path_root / f"{name}_{i:06d}"
726
+
727
+ # Where the individual clips will be saved
728
+ step_output_filepath = save_path / f"{name}_{i:06d}.mp4"
729
+
730
+ # Determine if we need to resume from a previous run
731
+ skip = 0
732
+ if resume:
733
+ if step_output_filepath.exists():
734
+ print(f"Skipping {save_path} because frames already exist")
735
+ continue
736
+
737
+ existing_frames = sorted(save_path.glob(f"*{image_file_ext}"))
738
+ if existing_frames:
739
+ skip = int(existing_frames[-1].stem[-6:]) + 1
740
+ if skip + 1 >= num_step:
741
+ print(f"Skipping {save_path} because frames already exist")
742
+ continue
743
+ print(f"Resuming {save_path.name} from frame {skip}")
744
+
745
+ audio_offset = audio_start_sec + sum(num_interpolation_steps[:i]) / fps
746
+ audio_duration = num_step / fps
747
+
748
+ self.make_clip_frames(
749
+ prompt_a,
750
+ prompt_b,
751
+ seed_a,
752
+ seed_b,
753
+ num_interpolation_steps=num_step,
754
+ save_path=save_path,
755
+ num_inference_steps=num_inference_steps,
756
+ guidance_scale=guidance_scale,
757
+ eta=eta,
758
+ height=height,
759
+ width=width,
760
+ upsample=upsample,
761
+ batch_size=batch_size,
762
+ T=get_timesteps_arr(
763
+ audio_filepath,
764
+ offset=audio_offset,
765
+ duration=audio_duration,
766
+ fps=fps,
767
+ margin=margin,
768
+ smooth=smooth,
769
+ )
770
+ if audio_filepath
771
+ else None,
772
+ skip=skip,
773
+ negative_prompt=negative_prompt,
774
+ step=(i, len(prompts) - 1),
775
+ )
776
+ if make_video:
777
+ make_video_pyav(
778
+ save_path,
779
+ audio_filepath=audio_filepath,
780
+ fps=fps,
781
+ output_filepath=step_output_filepath,
782
+ glob_pattern=f"*{image_file_ext}",
783
+ audio_offset=audio_offset,
784
+ audio_duration=audio_duration,
785
+ sr=44100,
786
+ )
787
+ if make_video:
788
+ return make_video_pyav(
789
+ save_path_root,
790
+ audio_filepath=audio_filepath,
791
+ fps=fps,
792
+ audio_offset=audio_start_sec,
793
+ audio_duration=sum(num_interpolation_steps) / fps,
794
+ output_filepath=output_filepath,
795
+ glob_pattern=f"**/*{image_file_ext}",
796
+ sr=44100,
797
+ )
798
+
799
+ def embed_text(self, text, negative_prompt=None):
800
+ """Helper to embed some text"""
801
+ text_input = self.tokenizer(
802
+ text,
803
+ padding="max_length",
804
+ max_length=self.tokenizer.model_max_length,
805
+ truncation=True,
806
+ return_tensors="pt",
807
+ )
808
+ with torch.no_grad():
809
+ embed = self.text_encoder(text_input.input_ids.to(self.device))[0]
810
+ return embed
811
+
812
+ def init_noise(self, seed, noise_shape, dtype):
813
+ """Helper to initialize noise"""
814
+ # randn does not exist on mps, so we create noise on CPU here and move it to the device after initialization
815
+ if self.device.type == "mps":
816
+ noise = torch.randn(
817
+ noise_shape,
818
+ device="cpu",
819
+ generator=torch.Generator(device="cpu").manual_seed(seed),
820
+ ).to(self.device)
821
+ else:
822
+ noise = torch.randn(
823
+ noise_shape,
824
+ device=self.device,
825
+ generator=torch.Generator(device=self.device).manual_seed(seed),
826
+ dtype=dtype,
827
+ )
828
+ return noise
829
+
830
+ @classmethod
831
+ def from_pretrained(cls, *args, tiled=False, **kwargs):
832
+ """Same as diffusers `from_pretrained` but with tiled option, which makes images tilable"""
833
+ if tiled:
834
+
835
+ def patch_conv(**patch):
836
+ cls = nn.Conv2d
837
+ init = cls.__init__
838
+
839
+ def __init__(self, *args, **kwargs):
840
+ return init(self, *args, **kwargs, **patch)
841
+
842
+ cls.__init__ = __init__
843
+
844
+ patch_conv(padding_mode="circular")
845
+
846
+ pipeline = super().from_pretrained(*args, **kwargs)
847
+ pipeline.tiled = tiled
848
+ return pipeline
video_diffusion/stable_diffusion_video/stable_video_text2video.py ADDED
@@ -0,0 +1,158 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import numpy as np
3
+ import torch
4
+
5
+ from video_diffusion.stable_diffusion_video.stable_diffusion_pipeline import StableDiffusionWalkPipeline
6
+ from video_diffusion.utils.model_list import stable_model_list
7
+
8
+
9
+ class StableDiffusionText2VideoGenerator:
10
+ def __init__(self):
11
+ self.pipe = None
12
+
13
+ def load_model(
14
+ self,
15
+ model_path,
16
+ ):
17
+ if self.pipe is None:
18
+ self.pipe = StableDiffusionWalkPipeline.from_pretrained(
19
+ model_path,
20
+ torch_dtype=torch.float16,
21
+ revision="fp16",
22
+ )
23
+
24
+ self.pipe.to("cuda")
25
+ self.pipe.enable_xformers_memory_efficient_attention()
26
+ self.pipe.enable_attention_slicing()
27
+
28
+ return self.pipe
29
+
30
+ def generate_video(
31
+ self,
32
+ model_path: str,
33
+ first_prompts: str,
34
+ second_prompts: str,
35
+ negative_prompt: str,
36
+ num_interpolation_steps: int,
37
+ guidance_scale: int,
38
+ num_inference_step: int,
39
+ height: int,
40
+ width: int,
41
+ upsample: bool,
42
+ fps=int,
43
+ ):
44
+ first_seed = np.random.randint(0, 100000)
45
+ second_seed = np.random.randint(0, 100000)
46
+ seeds = [first_seed, second_seed]
47
+ prompts = [first_prompts, second_prompts]
48
+ pipe = self.load_model(model_path=model_path)
49
+
50
+ output_video = pipe.walk(
51
+ prompts=prompts,
52
+ num_interpolation_steps=int(num_interpolation_steps),
53
+ height=height,
54
+ width=width,
55
+ guidance_scale=guidance_scale,
56
+ num_inference_steps=num_inference_step,
57
+ negative_prompt=negative_prompt,
58
+ seeds=seeds,
59
+ upsample=upsample,
60
+ fps=fps,
61
+ )
62
+
63
+ return output_video
64
+
65
+ def app():
66
+ with gr.Blocks():
67
+ with gr.Row():
68
+ with gr.Column():
69
+ stable_text2video_first_prompt = gr.Textbox(
70
+ lines=1,
71
+ placeholder="First Prompt",
72
+ show_label=False,
73
+ )
74
+ stable_text2video_second_prompt = gr.Textbox(
75
+ lines=1,
76
+ placeholder="Second Prompt",
77
+ show_label=False,
78
+ )
79
+ stable_text2video_negative_prompt = gr.Textbox(
80
+ lines=1,
81
+ placeholder="Negative Prompt ",
82
+ show_label=False,
83
+ )
84
+ with gr.Row():
85
+ with gr.Column():
86
+ stable_text2video_model_path = gr.Dropdown(
87
+ choices=stable_model_list,
88
+ label="Stable Model List",
89
+ value=stable_model_list[0],
90
+ )
91
+ stable_text2video_guidance_scale = gr.Slider(
92
+ minimum=0,
93
+ maximum=15,
94
+ step=1,
95
+ value=8.5,
96
+ label="Guidance Scale",
97
+ )
98
+ stable_text2video_num_inference_steps = gr.Slider(
99
+ minimum=1,
100
+ maximum=100,
101
+ step=1,
102
+ value=30,
103
+ label="Number of Inference Steps",
104
+ )
105
+ stable_text2video_fps = gr.Slider(
106
+ minimum=1,
107
+ maximum=60,
108
+ step=1,
109
+ value=10,
110
+ label="Fps",
111
+ )
112
+ with gr.Row():
113
+ with gr.Column():
114
+ stable_text2video_num_interpolation_steps = gr.Number(
115
+ value=10,
116
+ label="Number of Interpolation Steps",
117
+ )
118
+ stable_text2video_height = gr.Slider(
119
+ minimum=1,
120
+ maximum=1000,
121
+ step=1,
122
+ value=512,
123
+ label="Height",
124
+ )
125
+ stable_text2video_width = gr.Slider(
126
+ minimum=1,
127
+ maximum=1000,
128
+ step=1,
129
+ value=512,
130
+ label="Width",
131
+ )
132
+ stable_text2video_upsample = gr.Checkbox(
133
+ label="Upsample",
134
+ default=False,
135
+ )
136
+
137
+ text2video_generate = gr.Button(value="Generator")
138
+
139
+ with gr.Column():
140
+ text2video_output = gr.Video(label="Output")
141
+
142
+ text2video_generate.click(
143
+ fn=StableDiffusionText2VideoGenerator().generate_video,
144
+ inputs=[
145
+ stable_text2video_model_path,
146
+ stable_text2video_first_prompt,
147
+ stable_text2video_second_prompt,
148
+ stable_text2video_negative_prompt,
149
+ stable_text2video_num_interpolation_steps,
150
+ stable_text2video_guidance_scale,
151
+ stable_text2video_num_inference_steps,
152
+ stable_text2video_height,
153
+ stable_text2video_width,
154
+ stable_text2video_upsample,
155
+ stable_text2video_fps,
156
+ ],
157
+ outputs=text2video_output,
158
+ )
video_diffusion/stable_diffusion_video/upsampling.py ADDED
@@ -0,0 +1,104 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pathlib import Path
2
+
3
+ import cv2
4
+ from diffusers.utils import logging
5
+ from huggingface_hub import hf_hub_download
6
+ from PIL import Image
7
+ from torch import nn
8
+
9
+ try:
10
+ from basicsr.archs.rrdbnet_arch import RRDBNet
11
+ from realesrgan import RealESRGANer
12
+ except ImportError as e:
13
+ raise ImportError(
14
+ "You tried to import realesrgan without having it installed properly. To install Real-ESRGAN, run:\n\n"
15
+ "pip install realesrgan"
16
+ )
17
+
18
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
19
+
20
+
21
+ class RealESRGANModel(nn.Module):
22
+ def __init__(self, model_path, tile=0, tile_pad=10, pre_pad=0, fp32=False):
23
+ super().__init__()
24
+ try:
25
+ from basicsr.archs.rrdbnet_arch import RRDBNet
26
+ from realesrgan import RealESRGANer
27
+ except ImportError as e:
28
+ raise ImportError(
29
+ "You tried to import realesrgan without having it installed properly. To install Real-ESRGAN, run:\n\n"
30
+ "pip install realesrgan"
31
+ )
32
+
33
+ model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
34
+ self.upsampler = RealESRGANer(
35
+ scale=4, model_path=model_path, model=model, tile=tile, tile_pad=tile_pad, pre_pad=pre_pad, half=not fp32
36
+ )
37
+
38
+ def forward(self, image, outscale=4, convert_to_pil=True):
39
+ """Upsample an image array or path.
40
+ Args:
41
+ image (Union[np.ndarray, str]): Either a np array or an image path. np array is assumed to be in RGB format,
42
+ and we convert it to BGR.
43
+ outscale (int, optional): Amount to upscale the image. Defaults to 4.
44
+ convert_to_pil (bool, optional): If True, return PIL image. Otherwise, return numpy array (BGR). Defaults to True.
45
+ Returns:
46
+ Union[np.ndarray, PIL.Image.Image]: An upsampled version of the input image.
47
+ """
48
+ if isinstance(image, (str, Path)):
49
+ img = cv2.imread(image, cv2.IMREAD_UNCHANGED)
50
+ else:
51
+ img = image
52
+ img = (img * 255).round().astype("uint8")
53
+ img = img[:, :, ::-1]
54
+
55
+ image, _ = self.upsampler.enhance(img, outscale=outscale)
56
+
57
+ if convert_to_pil:
58
+ image = Image.fromarray(image[:, :, ::-1])
59
+
60
+ return image
61
+
62
+ @classmethod
63
+ def from_pretrained(cls, model_name_or_path="nateraw/real-esrgan"):
64
+ """Initialize a pretrained Real-ESRGAN upsampler.
65
+ Example:
66
+ ```python
67
+ >>> from stable_diffusion_videos import PipelineRealESRGAN
68
+ >>> pipe = PipelineRealESRGAN.from_pretrained('nateraw/real-esrgan')
69
+ >>> im_out = pipe('input_img.jpg')
70
+ ```
71
+ Args:
72
+ model_name_or_path (str, optional): The Hugging Face repo ID or path to local model. Defaults to 'nateraw/real-esrgan'.
73
+ Returns:
74
+ stable_diffusion_videos.PipelineRealESRGAN: An instance of `PipelineRealESRGAN` instantiated from pretrained model.
75
+ """
76
+ # reuploaded form official ones mentioned here:
77
+ # https://github.com/xinntao/Real-ESRGAN
78
+ if Path(model_name_or_path).exists():
79
+ file = model_name_or_path
80
+ else:
81
+ file = hf_hub_download(model_name_or_path, "RealESRGAN_x4plus.pth")
82
+ return cls(file)
83
+
84
+ def upsample_imagefolder(self, in_dir, out_dir, suffix="out", outfile_ext=".png", recursive=False, force=False):
85
+ in_dir, out_dir = Path(in_dir), Path(out_dir)
86
+ if not in_dir.exists():
87
+ raise FileNotFoundError(f"Provided input directory {in_dir} does not exist")
88
+
89
+ out_dir.mkdir(exist_ok=True, parents=True)
90
+
91
+ generator = in_dir.rglob("*") if recursive else in_dir.glob("*")
92
+ image_paths = [x for x in generator if x.suffix.lower() in [".png", ".jpg", ".jpeg"]]
93
+ n_img = len(image_paths)
94
+ for i, image in enumerate(image_paths):
95
+ out_filepath = out_dir / (str(image.relative_to(in_dir).with_suffix("")) + suffix + outfile_ext)
96
+ if not force and out_filepath.exists():
97
+ logger.info(
98
+ f"[{i}/{n_img}] {out_filepath} already exists, skipping. To avoid skipping, pass force=True."
99
+ )
100
+ continue
101
+ logger.info(f"[{i}/{n_img}] upscaling {image}")
102
+ im = self(str(image))
103
+ out_filepath.parent.mkdir(parents=True, exist_ok=True)
104
+ im.save(out_filepath)
video_diffusion/stable_diffusion_video/utils.py ADDED
@@ -0,0 +1,135 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pathlib import Path
2
+ from typing import Union
3
+
4
+ import librosa
5
+ import numpy as np
6
+ import torch
7
+ from PIL import Image
8
+ from torchvision.io import write_video
9
+ from torchvision.transforms.functional import pil_to_tensor
10
+
11
+
12
+ def get_timesteps_arr(audio_filepath, offset, duration, fps=30, margin=1.0, smooth=0.0):
13
+ y, sr = librosa.load(audio_filepath, offset=offset, duration=duration)
14
+
15
+ # librosa.stft hardcoded defaults...
16
+ # n_fft defaults to 2048
17
+ # hop length is win_length // 4
18
+ # win_length defaults to n_fft
19
+ D = librosa.stft(y, n_fft=2048, hop_length=2048 // 4, win_length=2048)
20
+
21
+ # Extract percussive elements
22
+ D_harmonic, D_percussive = librosa.decompose.hpss(D, margin=margin)
23
+ y_percussive = librosa.istft(D_percussive, length=len(y))
24
+
25
+ # Get normalized melspectrogram
26
+ spec_raw = librosa.feature.melspectrogram(y=y_percussive, sr=sr)
27
+ spec_max = np.amax(spec_raw, axis=0)
28
+ spec_norm = (spec_max - np.min(spec_max)) / np.ptp(spec_max)
29
+
30
+ # Resize cumsum of spec norm to our desired number of interpolation frames
31
+ x_norm = np.linspace(0, spec_norm.shape[-1], spec_norm.shape[-1])
32
+ y_norm = np.cumsum(spec_norm)
33
+ y_norm /= y_norm[-1]
34
+ x_resize = np.linspace(0, y_norm.shape[-1], int(duration * fps))
35
+
36
+ T = np.interp(x_resize, x_norm, y_norm)
37
+
38
+ # Apply smoothing
39
+ return T * (1 - smooth) + np.linspace(0.0, 1.0, T.shape[0]) * smooth
40
+
41
+
42
+ def slerp(t, v0, v1, DOT_THRESHOLD=0.9995):
43
+ """helper function to spherically interpolate two arrays v1 v2"""
44
+
45
+ inputs_are_torch = isinstance(v0, torch.Tensor)
46
+ if inputs_are_torch:
47
+ input_device = v0.device
48
+ v0 = v0.cpu().numpy()
49
+ v1 = v1.cpu().numpy()
50
+
51
+ dot = np.sum(v0 * v1 / (np.linalg.norm(v0) * np.linalg.norm(v1)))
52
+ if np.abs(dot) > DOT_THRESHOLD:
53
+ v2 = (1 - t) * v0 + t * v1
54
+ else:
55
+ theta_0 = np.arccos(dot)
56
+ sin_theta_0 = np.sin(theta_0)
57
+ theta_t = theta_0 * t
58
+ sin_theta_t = np.sin(theta_t)
59
+ s0 = np.sin(theta_0 - theta_t) / sin_theta_0
60
+ s1 = sin_theta_t / sin_theta_0
61
+ v2 = s0 * v0 + s1 * v1
62
+
63
+ if inputs_are_torch:
64
+ v2 = torch.from_numpy(v2).to(input_device)
65
+
66
+ return v2
67
+
68
+
69
+ def make_video_pyav(
70
+ frames_or_frame_dir: Union[str, Path, torch.Tensor],
71
+ audio_filepath: Union[str, Path] = None,
72
+ fps: int = 30,
73
+ audio_offset: int = 0,
74
+ audio_duration: int = 2,
75
+ sr: int = 22050,
76
+ output_filepath: Union[str, Path] = "output.mp4",
77
+ glob_pattern: str = "*.png",
78
+ ):
79
+ """
80
+ TODO - docstring here
81
+ frames_or_frame_dir: (Union[str, Path, torch.Tensor]):
82
+ Either a directory of images, or a tensor of shape (T, C, H, W) in range [0, 255].
83
+ """
84
+
85
+ # Torchvision write_video doesn't support pathlib paths
86
+ output_filepath = str(output_filepath)
87
+
88
+ if isinstance(frames_or_frame_dir, (str, Path)):
89
+ frames = None
90
+ for img in sorted(Path(frames_or_frame_dir).glob(glob_pattern)):
91
+ frame = pil_to_tensor(Image.open(img)).unsqueeze(0)
92
+ frames = frame if frames is None else torch.cat([frames, frame])
93
+ else:
94
+ frames = frames_or_frame_dir
95
+
96
+ # TCHW -> THWC
97
+ frames = frames.permute(0, 2, 3, 1)
98
+
99
+ if audio_filepath:
100
+ # Read audio, convert to tensor
101
+ audio, sr = librosa.load(
102
+ audio_filepath,
103
+ sr=sr,
104
+ mono=True,
105
+ offset=audio_offset,
106
+ duration=audio_duration,
107
+ )
108
+ audio_tensor = torch.tensor(audio).unsqueeze(0)
109
+
110
+ write_video(
111
+ output_filepath,
112
+ frames,
113
+ fps=fps,
114
+ audio_array=audio_tensor,
115
+ audio_fps=sr,
116
+ audio_codec="aac",
117
+ options={"crf": "10", "pix_fmt": "yuv420p"},
118
+ )
119
+ else:
120
+ write_video(
121
+ output_filepath,
122
+ frames,
123
+ fps=fps,
124
+ options={"crf": "10", "pix_fmt": "yuv420p"},
125
+ )
126
+
127
+ return output_filepath
128
+
129
+
130
+ def pad_along_axis(array: np.ndarray, pad_size: int, axis: int = 0) -> np.ndarray:
131
+ if pad_size <= 0:
132
+ return array
133
+ npad = [(0, 0)] * array.ndim
134
+ npad[axis] = (0, pad_size)
135
+ return np.pad(array, pad_width=npad, mode="constant", constant_values=0)
video_diffusion/tuneavideo/models/attention.py ADDED
@@ -0,0 +1,322 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention.py
2
+
3
+ from dataclasses import dataclass
4
+ from typing import Optional
5
+
6
+ import torch
7
+ import torch.nn.functional as F
8
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
9
+ from diffusers.models.attention import AdaLayerNorm, FeedForward
10
+ from diffusers.models.cross_attention import CrossAttention
11
+ from diffusers.models.modeling_utils import ModelMixin
12
+ from diffusers.utils import BaseOutput
13
+ from diffusers.utils.import_utils import is_xformers_available
14
+ from einops import rearrange, repeat
15
+ from torch import nn
16
+
17
+
18
+ @dataclass
19
+ class Transformer3DModelOutput(BaseOutput):
20
+ sample: torch.FloatTensor
21
+
22
+
23
+ if is_xformers_available():
24
+ import xformers
25
+ import xformers.ops
26
+ else:
27
+ xformers = None
28
+
29
+
30
+ class Transformer3DModel(ModelMixin, ConfigMixin):
31
+ @register_to_config
32
+ def __init__(
33
+ self,
34
+ num_attention_heads: int = 16,
35
+ attention_head_dim: int = 88,
36
+ in_channels: Optional[int] = None,
37
+ num_layers: int = 1,
38
+ dropout: float = 0.0,
39
+ norm_num_groups: int = 32,
40
+ cross_attention_dim: Optional[int] = None,
41
+ attention_bias: bool = False,
42
+ activation_fn: str = "geglu",
43
+ num_embeds_ada_norm: Optional[int] = None,
44
+ use_linear_projection: bool = False,
45
+ only_cross_attention: bool = False,
46
+ upcast_attention: bool = False,
47
+ ):
48
+ super().__init__()
49
+ self.use_linear_projection = use_linear_projection
50
+ self.num_attention_heads = num_attention_heads
51
+ self.attention_head_dim = attention_head_dim
52
+ inner_dim = num_attention_heads * attention_head_dim
53
+
54
+ # Define input layers
55
+ self.in_channels = in_channels
56
+
57
+ self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
58
+ if use_linear_projection:
59
+ self.proj_in = nn.Linear(in_channels, inner_dim)
60
+ else:
61
+ self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
62
+
63
+ # Define transformers blocks
64
+ self.transformer_blocks = nn.ModuleList(
65
+ [
66
+ BasicTransformerBlock(
67
+ inner_dim,
68
+ num_attention_heads,
69
+ attention_head_dim,
70
+ dropout=dropout,
71
+ cross_attention_dim=cross_attention_dim,
72
+ activation_fn=activation_fn,
73
+ num_embeds_ada_norm=num_embeds_ada_norm,
74
+ attention_bias=attention_bias,
75
+ only_cross_attention=only_cross_attention,
76
+ upcast_attention=upcast_attention,
77
+ )
78
+ for d in range(num_layers)
79
+ ]
80
+ )
81
+
82
+ # 4. Define output layers
83
+ if use_linear_projection:
84
+ self.proj_out = nn.Linear(in_channels, inner_dim)
85
+ else:
86
+ self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
87
+
88
+ def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, return_dict: bool = True):
89
+ # Input
90
+ assert hidden_states.dim() == 5, f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}."
91
+ video_length = hidden_states.shape[2]
92
+ hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w")
93
+ encoder_hidden_states = repeat(encoder_hidden_states, "b n c -> (b f) n c", f=video_length)
94
+
95
+ batch, channel, height, weight = hidden_states.shape
96
+ residual = hidden_states
97
+
98
+ hidden_states = self.norm(hidden_states)
99
+ if not self.use_linear_projection:
100
+ hidden_states = self.proj_in(hidden_states)
101
+ inner_dim = hidden_states.shape[1]
102
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
103
+ else:
104
+ inner_dim = hidden_states.shape[1]
105
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
106
+ hidden_states = self.proj_in(hidden_states)
107
+
108
+ # Blocks
109
+ for block in self.transformer_blocks:
110
+ hidden_states = block(
111
+ hidden_states, encoder_hidden_states=encoder_hidden_states, timestep=timestep, video_length=video_length
112
+ )
113
+
114
+ # Output
115
+ if not self.use_linear_projection:
116
+ hidden_states = hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous()
117
+ hidden_states = self.proj_out(hidden_states)
118
+ else:
119
+ hidden_states = self.proj_out(hidden_states)
120
+ hidden_states = hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous()
121
+
122
+ output = hidden_states + residual
123
+
124
+ output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length)
125
+ if not return_dict:
126
+ return (output,)
127
+
128
+ return Transformer3DModelOutput(sample=output)
129
+
130
+
131
+ class BasicTransformerBlock(nn.Module):
132
+ def __init__(
133
+ self,
134
+ dim: int,
135
+ num_attention_heads: int,
136
+ attention_head_dim: int,
137
+ dropout=0.0,
138
+ cross_attention_dim: Optional[int] = None,
139
+ activation_fn: str = "geglu",
140
+ num_embeds_ada_norm: Optional[int] = None,
141
+ attention_bias: bool = False,
142
+ only_cross_attention: bool = False,
143
+ upcast_attention: bool = False,
144
+ ):
145
+ super().__init__()
146
+ self.only_cross_attention = only_cross_attention
147
+ self.use_ada_layer_norm = num_embeds_ada_norm is not None
148
+
149
+ # SC-Attn
150
+ self.attn1 = SparseCausalAttention(
151
+ query_dim=dim,
152
+ heads=num_attention_heads,
153
+ dim_head=attention_head_dim,
154
+ dropout=dropout,
155
+ bias=attention_bias,
156
+ cross_attention_dim=cross_attention_dim if only_cross_attention else None,
157
+ upcast_attention=upcast_attention,
158
+ )
159
+ self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
160
+
161
+ # Cross-Attn
162
+ if cross_attention_dim is not None:
163
+ self.attn2 = CrossAttention(
164
+ query_dim=dim,
165
+ cross_attention_dim=cross_attention_dim,
166
+ heads=num_attention_heads,
167
+ dim_head=attention_head_dim,
168
+ dropout=dropout,
169
+ bias=attention_bias,
170
+ upcast_attention=upcast_attention,
171
+ )
172
+ else:
173
+ self.attn2 = None
174
+
175
+ if cross_attention_dim is not None:
176
+ self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
177
+ else:
178
+ self.norm2 = None
179
+
180
+ # Feed-forward
181
+ self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn)
182
+ self.norm3 = nn.LayerNorm(dim)
183
+
184
+ # Temp-Attn
185
+ self.attn_temp = CrossAttention(
186
+ query_dim=dim,
187
+ heads=num_attention_heads,
188
+ dim_head=attention_head_dim,
189
+ dropout=dropout,
190
+ bias=attention_bias,
191
+ upcast_attention=upcast_attention,
192
+ )
193
+ nn.init.zeros_(self.attn_temp.to_out[0].weight.data)
194
+ self.norm_temp = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
195
+
196
+ def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
197
+ if not is_xformers_available():
198
+ print("Here is how to install it")
199
+ raise ModuleNotFoundError(
200
+ "Refer to https://github.com/facebookresearch/xformers for more information on how to install"
201
+ " xformers",
202
+ name="xformers",
203
+ )
204
+ elif not torch.cuda.is_available():
205
+ raise ValueError(
206
+ "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is only"
207
+ " available for GPU "
208
+ )
209
+ else:
210
+ try:
211
+ # Make sure we can run the memory efficient attention
212
+ _ = xformers.ops.memory_efficient_attention(
213
+ torch.randn((1, 2, 40), device="cuda"),
214
+ torch.randn((1, 2, 40), device="cuda"),
215
+ torch.randn((1, 2, 40), device="cuda"),
216
+ )
217
+ except Exception as e:
218
+ raise e
219
+ self.attn1._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
220
+ if self.attn2 is not None:
221
+ self.attn2._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
222
+ # self.attn_temp._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
223
+
224
+ def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, attention_mask=None, video_length=None):
225
+ # SparseCausal-Attention
226
+ norm_hidden_states = (
227
+ self.norm1(hidden_states, timestep) if self.use_ada_layer_norm else self.norm1(hidden_states)
228
+ )
229
+
230
+ if self.only_cross_attention:
231
+ hidden_states = (
232
+ self.attn1(norm_hidden_states, encoder_hidden_states, attention_mask=attention_mask) + hidden_states
233
+ )
234
+ else:
235
+ hidden_states = (
236
+ self.attn1(norm_hidden_states, attention_mask=attention_mask, video_length=video_length) + hidden_states
237
+ )
238
+
239
+ if self.attn2 is not None:
240
+ # Cross-Attention
241
+ norm_hidden_states = (
242
+ self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
243
+ )
244
+ hidden_states = (
245
+ self.attn2(
246
+ norm_hidden_states, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask
247
+ )
248
+ + hidden_states
249
+ )
250
+
251
+ # Feed-forward
252
+ hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states
253
+
254
+ # Temporal-Attention
255
+ d = hidden_states.shape[1]
256
+ hidden_states = rearrange(hidden_states, "(b f) d c -> (b d) f c", f=video_length)
257
+ norm_hidden_states = (
258
+ self.norm_temp(hidden_states, timestep) if self.use_ada_layer_norm else self.norm_temp(hidden_states)
259
+ )
260
+ hidden_states = self.attn_temp(norm_hidden_states) + hidden_states
261
+ hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d)
262
+
263
+ return hidden_states
264
+
265
+
266
+ class SparseCausalAttention(CrossAttention):
267
+ def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, video_length=None):
268
+ batch_size, sequence_length, _ = hidden_states.shape
269
+
270
+ encoder_hidden_states = encoder_hidden_states
271
+
272
+ if self.group_norm is not None:
273
+ hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
274
+
275
+ query = self.to_q(hidden_states)
276
+ dim = query.shape[-1]
277
+ query = self.reshape_heads_to_batch_dim(query)
278
+
279
+ if self.added_kv_proj_dim is not None:
280
+ raise NotImplementedError
281
+
282
+ encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
283
+ key = self.to_k(encoder_hidden_states)
284
+ value = self.to_v(encoder_hidden_states)
285
+
286
+ former_frame_index = torch.arange(video_length) - 1
287
+ former_frame_index[0] = 0
288
+
289
+ key = rearrange(key, "(b f) d c -> b f d c", f=video_length)
290
+ key = torch.cat([key[:, [0] * video_length], key[:, former_frame_index]], dim=2)
291
+ key = rearrange(key, "b f d c -> (b f) d c")
292
+
293
+ value = rearrange(value, "(b f) d c -> b f d c", f=video_length)
294
+ value = torch.cat([value[:, [0] * video_length], value[:, former_frame_index]], dim=2)
295
+ value = rearrange(value, "b f d c -> (b f) d c")
296
+
297
+ key = self.reshape_heads_to_batch_dim(key)
298
+ value = self.reshape_heads_to_batch_dim(value)
299
+
300
+ if attention_mask is not None:
301
+ if attention_mask.shape[-1] != query.shape[1]:
302
+ target_length = query.shape[1]
303
+ attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
304
+ attention_mask = attention_mask.repeat_interleave(self.heads, dim=0)
305
+
306
+ # attention, what we cannot get enough of
307
+ if self._use_memory_efficient_attention_xformers:
308
+ hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask)
309
+ # Some versions of xformers return output in fp32, cast it back to the dtype of the input
310
+ hidden_states = hidden_states.to(query.dtype)
311
+ else:
312
+ if self._slice_size is None or query.shape[0] // self._slice_size == 1:
313
+ hidden_states = self._attention(query, key, value, attention_mask)
314
+ else:
315
+ hidden_states = self._sliced_attention(query, key, value, sequence_length, dim, attention_mask)
316
+
317
+ # linear proj
318
+ hidden_states = self.to_out[0](hidden_states)
319
+
320
+ # dropout
321
+ hidden_states = self.to_out[1](hidden_states)
322
+ return hidden_states
video_diffusion/tuneavideo/models/resnet.py ADDED
@@ -0,0 +1,208 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/resnet.py
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ from einops import rearrange
7
+
8
+
9
+ class InflatedConv3d(nn.Conv2d):
10
+ def forward(self, x):
11
+ video_length = x.shape[2]
12
+
13
+ x = rearrange(x, "b c f h w -> (b f) c h w")
14
+ x = super().forward(x)
15
+ x = rearrange(x, "(b f) c h w -> b c f h w", f=video_length)
16
+
17
+ return x
18
+
19
+
20
+ class Upsample3D(nn.Module):
21
+ def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"):
22
+ super().__init__()
23
+ self.channels = channels
24
+ self.out_channels = out_channels or channels
25
+ self.use_conv = use_conv
26
+ self.use_conv_transpose = use_conv_transpose
27
+ self.name = name
28
+
29
+ conv = None
30
+ if use_conv_transpose:
31
+ raise NotImplementedError
32
+ elif use_conv:
33
+ conv = InflatedConv3d(self.channels, self.out_channels, 3, padding=1)
34
+
35
+ if name == "conv":
36
+ self.conv = conv
37
+ else:
38
+ self.Conv2d_0 = conv
39
+
40
+ def forward(self, hidden_states, output_size=None):
41
+ assert hidden_states.shape[1] == self.channels
42
+
43
+ if self.use_conv_transpose:
44
+ raise NotImplementedError
45
+
46
+ # Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16
47
+ dtype = hidden_states.dtype
48
+ if dtype == torch.bfloat16:
49
+ hidden_states = hidden_states.to(torch.float32)
50
+
51
+ # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
52
+ if hidden_states.shape[0] >= 64:
53
+ hidden_states = hidden_states.contiguous()
54
+
55
+ # if `output_size` is passed we force the interpolation output
56
+ # size and do not make use of `scale_factor=2`
57
+ if output_size is None:
58
+ hidden_states = F.interpolate(hidden_states, scale_factor=[1.0, 2.0, 2.0], mode="nearest")
59
+ else:
60
+ hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest")
61
+
62
+ # If the input is bfloat16, we cast back to bfloat16
63
+ if dtype == torch.bfloat16:
64
+ hidden_states = hidden_states.to(dtype)
65
+
66
+ if self.use_conv:
67
+ if self.name == "conv":
68
+ hidden_states = self.conv(hidden_states)
69
+ else:
70
+ hidden_states = self.Conv2d_0(hidden_states)
71
+
72
+ return hidden_states
73
+
74
+
75
+ class Downsample3D(nn.Module):
76
+ def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"):
77
+ super().__init__()
78
+ self.channels = channels
79
+ self.out_channels = out_channels or channels
80
+ self.use_conv = use_conv
81
+ self.padding = padding
82
+ stride = 2
83
+ self.name = name
84
+
85
+ if use_conv:
86
+ conv = InflatedConv3d(self.channels, self.out_channels, 3, stride=stride, padding=padding)
87
+ else:
88
+ raise NotImplementedError
89
+
90
+ if name == "conv":
91
+ self.Conv2d_0 = conv
92
+ self.conv = conv
93
+ elif name == "Conv2d_0":
94
+ self.conv = conv
95
+ else:
96
+ self.conv = conv
97
+
98
+ def forward(self, hidden_states):
99
+ assert hidden_states.shape[1] == self.channels
100
+ if self.use_conv and self.padding == 0:
101
+ raise NotImplementedError
102
+
103
+ assert hidden_states.shape[1] == self.channels
104
+ hidden_states = self.conv(hidden_states)
105
+
106
+ return hidden_states
107
+
108
+
109
+ class ResnetBlock3D(nn.Module):
110
+ def __init__(
111
+ self,
112
+ *,
113
+ in_channels,
114
+ out_channels=None,
115
+ conv_shortcut=False,
116
+ dropout=0.0,
117
+ temb_channels=512,
118
+ groups=32,
119
+ groups_out=None,
120
+ pre_norm=True,
121
+ eps=1e-6,
122
+ non_linearity="swish",
123
+ time_embedding_norm="default",
124
+ output_scale_factor=1.0,
125
+ use_in_shortcut=None,
126
+ ):
127
+ super().__init__()
128
+ self.pre_norm = pre_norm
129
+ self.pre_norm = True
130
+ self.in_channels = in_channels
131
+ out_channels = in_channels if out_channels is None else out_channels
132
+ self.out_channels = out_channels
133
+ self.use_conv_shortcut = conv_shortcut
134
+ self.time_embedding_norm = time_embedding_norm
135
+ self.output_scale_factor = output_scale_factor
136
+
137
+ if groups_out is None:
138
+ groups_out = groups
139
+
140
+ self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
141
+
142
+ self.conv1 = InflatedConv3d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
143
+
144
+ if temb_channels is not None:
145
+ if self.time_embedding_norm == "default":
146
+ time_emb_proj_out_channels = out_channels
147
+ elif self.time_embedding_norm == "scale_shift":
148
+ time_emb_proj_out_channels = out_channels * 2
149
+ else:
150
+ raise ValueError(f"unknown time_embedding_norm : {self.time_embedding_norm} ")
151
+
152
+ self.time_emb_proj = torch.nn.Linear(temb_channels, time_emb_proj_out_channels)
153
+ else:
154
+ self.time_emb_proj = None
155
+
156
+ self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)
157
+ self.dropout = torch.nn.Dropout(dropout)
158
+ self.conv2 = InflatedConv3d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
159
+
160
+ if non_linearity == "swish":
161
+ self.nonlinearity = lambda x: F.silu(x)
162
+ elif non_linearity == "mish":
163
+ self.nonlinearity = Mish()
164
+ elif non_linearity == "silu":
165
+ self.nonlinearity = nn.SiLU()
166
+
167
+ self.use_in_shortcut = self.in_channels != self.out_channels if use_in_shortcut is None else use_in_shortcut
168
+
169
+ self.conv_shortcut = None
170
+ if self.use_in_shortcut:
171
+ self.conv_shortcut = InflatedConv3d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
172
+
173
+ def forward(self, input_tensor, temb):
174
+ hidden_states = input_tensor
175
+
176
+ hidden_states = self.norm1(hidden_states)
177
+ hidden_states = self.nonlinearity(hidden_states)
178
+
179
+ hidden_states = self.conv1(hidden_states)
180
+
181
+ if temb is not None:
182
+ temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None, None]
183
+
184
+ if temb is not None and self.time_embedding_norm == "default":
185
+ hidden_states = hidden_states + temb
186
+
187
+ hidden_states = self.norm2(hidden_states)
188
+
189
+ if temb is not None and self.time_embedding_norm == "scale_shift":
190
+ scale, shift = torch.chunk(temb, 2, dim=1)
191
+ hidden_states = hidden_states * (1 + scale) + shift
192
+
193
+ hidden_states = self.nonlinearity(hidden_states)
194
+
195
+ hidden_states = self.dropout(hidden_states)
196
+ hidden_states = self.conv2(hidden_states)
197
+
198
+ if self.conv_shortcut is not None:
199
+ input_tensor = self.conv_shortcut(input_tensor)
200
+
201
+ output_tensor = (input_tensor + hidden_states) / self.output_scale_factor
202
+
203
+ return output_tensor
204
+
205
+
206
+ class Mish(torch.nn.Module):
207
+ def forward(self, hidden_states):
208
+ return hidden_states * torch.tanh(torch.nn.functional.softplus(hidden_states))
video_diffusion/tuneavideo/models/unet.py ADDED
@@ -0,0 +1,437 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unet_2d_condition.py
2
+
3
+ import json
4
+ import os
5
+ from dataclasses import dataclass
6
+ from typing import List, Optional, Tuple, Union
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.utils.checkpoint
11
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
12
+ from diffusers.models.embeddings import TimestepEmbedding, Timesteps
13
+ from diffusers.models.modeling_utils import ModelMixin
14
+ from diffusers.utils import BaseOutput, logging
15
+
16
+ from .resnet import InflatedConv3d
17
+ from .unet_blocks import (
18
+ CrossAttnDownBlock3D,
19
+ CrossAttnUpBlock3D,
20
+ DownBlock3D,
21
+ UNetMidBlock3DCrossAttn,
22
+ UpBlock3D,
23
+ get_down_block,
24
+ get_up_block,
25
+ )
26
+
27
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
28
+
29
+
30
+ @dataclass
31
+ class UNet3DConditionOutput(BaseOutput):
32
+ sample: torch.FloatTensor
33
+
34
+
35
+ class UNet3DConditionModel(ModelMixin, ConfigMixin):
36
+ _supports_gradient_checkpointing = True
37
+
38
+ @register_to_config
39
+ def __init__(
40
+ self,
41
+ sample_size: Optional[int] = None,
42
+ in_channels: int = 4,
43
+ out_channels: int = 4,
44
+ center_input_sample: bool = False,
45
+ flip_sin_to_cos: bool = True,
46
+ freq_shift: int = 0,
47
+ down_block_types: Tuple[str] = (
48
+ "CrossAttnDownBlock3D",
49
+ "CrossAttnDownBlock3D",
50
+ "CrossAttnDownBlock3D",
51
+ "DownBlock3D",
52
+ ),
53
+ mid_block_type: str = "UNetMidBlock3DCrossAttn",
54
+ up_block_types: Tuple[str] = ("UpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D"),
55
+ only_cross_attention: Union[bool, Tuple[bool]] = False,
56
+ block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
57
+ layers_per_block: int = 2,
58
+ downsample_padding: int = 1,
59
+ mid_block_scale_factor: float = 1,
60
+ act_fn: str = "silu",
61
+ norm_num_groups: int = 32,
62
+ norm_eps: float = 1e-5,
63
+ cross_attention_dim: int = 1280,
64
+ attention_head_dim: Union[int, Tuple[int]] = 8,
65
+ dual_cross_attention: bool = False,
66
+ use_linear_projection: bool = False,
67
+ class_embed_type: Optional[str] = None,
68
+ num_class_embeds: Optional[int] = None,
69
+ upcast_attention: bool = False,
70
+ resnet_time_scale_shift: str = "default",
71
+ ):
72
+ super().__init__()
73
+
74
+ self.sample_size = sample_size
75
+ time_embed_dim = block_out_channels[0] * 4
76
+
77
+ # input
78
+ self.conv_in = InflatedConv3d(in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1))
79
+
80
+ # time
81
+ self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
82
+ timestep_input_dim = block_out_channels[0]
83
+
84
+ self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
85
+
86
+ # class embedding
87
+ if class_embed_type is None and num_class_embeds is not None:
88
+ self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
89
+ elif class_embed_type == "timestep":
90
+ self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
91
+ elif class_embed_type == "identity":
92
+ self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
93
+ else:
94
+ self.class_embedding = None
95
+
96
+ self.down_blocks = nn.ModuleList([])
97
+ self.mid_block = None
98
+ self.up_blocks = nn.ModuleList([])
99
+
100
+ if isinstance(only_cross_attention, bool):
101
+ only_cross_attention = [only_cross_attention] * len(down_block_types)
102
+
103
+ if isinstance(attention_head_dim, int):
104
+ attention_head_dim = (attention_head_dim,) * len(down_block_types)
105
+
106
+ # down
107
+ output_channel = block_out_channels[0]
108
+ for i, down_block_type in enumerate(down_block_types):
109
+ input_channel = output_channel
110
+ output_channel = block_out_channels[i]
111
+ is_final_block = i == len(block_out_channels) - 1
112
+
113
+ down_block = get_down_block(
114
+ down_block_type,
115
+ num_layers=layers_per_block,
116
+ in_channels=input_channel,
117
+ out_channels=output_channel,
118
+ temb_channels=time_embed_dim,
119
+ add_downsample=not is_final_block,
120
+ resnet_eps=norm_eps,
121
+ resnet_act_fn=act_fn,
122
+ resnet_groups=norm_num_groups,
123
+ cross_attention_dim=cross_attention_dim,
124
+ attn_num_head_channels=attention_head_dim[i],
125
+ downsample_padding=downsample_padding,
126
+ dual_cross_attention=dual_cross_attention,
127
+ use_linear_projection=use_linear_projection,
128
+ only_cross_attention=only_cross_attention[i],
129
+ upcast_attention=upcast_attention,
130
+ resnet_time_scale_shift=resnet_time_scale_shift,
131
+ )
132
+ self.down_blocks.append(down_block)
133
+
134
+ # mid
135
+ if mid_block_type == "UNetMidBlock3DCrossAttn":
136
+ self.mid_block = UNetMidBlock3DCrossAttn(
137
+ in_channels=block_out_channels[-1],
138
+ temb_channels=time_embed_dim,
139
+ resnet_eps=norm_eps,
140
+ resnet_act_fn=act_fn,
141
+ output_scale_factor=mid_block_scale_factor,
142
+ resnet_time_scale_shift=resnet_time_scale_shift,
143
+ cross_attention_dim=cross_attention_dim,
144
+ attn_num_head_channels=attention_head_dim[-1],
145
+ resnet_groups=norm_num_groups,
146
+ dual_cross_attention=dual_cross_attention,
147
+ use_linear_projection=use_linear_projection,
148
+ upcast_attention=upcast_attention,
149
+ )
150
+ else:
151
+ raise ValueError(f"unknown mid_block_type : {mid_block_type}")
152
+
153
+ # count how many layers upsample the videos
154
+ self.num_upsamplers = 0
155
+
156
+ # up
157
+ reversed_block_out_channels = list(reversed(block_out_channels))
158
+ reversed_attention_head_dim = list(reversed(attention_head_dim))
159
+ only_cross_attention = list(reversed(only_cross_attention))
160
+ output_channel = reversed_block_out_channels[0]
161
+ for i, up_block_type in enumerate(up_block_types):
162
+ is_final_block = i == len(block_out_channels) - 1
163
+
164
+ prev_output_channel = output_channel
165
+ output_channel = reversed_block_out_channels[i]
166
+ input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
167
+
168
+ # add upsample block for all BUT final layer
169
+ if not is_final_block:
170
+ add_upsample = True
171
+ self.num_upsamplers += 1
172
+ else:
173
+ add_upsample = False
174
+
175
+ up_block = get_up_block(
176
+ up_block_type,
177
+ num_layers=layers_per_block + 1,
178
+ in_channels=input_channel,
179
+ out_channels=output_channel,
180
+ prev_output_channel=prev_output_channel,
181
+ temb_channels=time_embed_dim,
182
+ add_upsample=add_upsample,
183
+ resnet_eps=norm_eps,
184
+ resnet_act_fn=act_fn,
185
+ resnet_groups=norm_num_groups,
186
+ cross_attention_dim=cross_attention_dim,
187
+ attn_num_head_channels=reversed_attention_head_dim[i],
188
+ dual_cross_attention=dual_cross_attention,
189
+ use_linear_projection=use_linear_projection,
190
+ only_cross_attention=only_cross_attention[i],
191
+ upcast_attention=upcast_attention,
192
+ resnet_time_scale_shift=resnet_time_scale_shift,
193
+ )
194
+ self.up_blocks.append(up_block)
195
+ prev_output_channel = output_channel
196
+
197
+ # out
198
+ self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps)
199
+ self.conv_act = nn.SiLU()
200
+ self.conv_out = InflatedConv3d(block_out_channels[0], out_channels, kernel_size=3, padding=1)
201
+
202
+ def set_attention_slice(self, slice_size):
203
+ r"""
204
+ Enable sliced attention computation.
205
+
206
+ When this option is enabled, the attention module will split the input tensor in slices, to compute attention
207
+ in several steps. This is useful to save some memory in exchange for a small speed decrease.
208
+
209
+ Args:
210
+ slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
211
+ When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
212
+ `"max"`, maxium amount of memory will be saved by running only one slice at a time. If a number is
213
+ provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
214
+ must be a multiple of `slice_size`.
215
+ """
216
+ sliceable_head_dims = []
217
+
218
+ def fn_recursive_retrieve_slicable_dims(module: torch.nn.Module):
219
+ if hasattr(module, "set_attention_slice"):
220
+ sliceable_head_dims.append(module.sliceable_head_dim)
221
+
222
+ for child in module.children():
223
+ fn_recursive_retrieve_slicable_dims(child)
224
+
225
+ # retrieve number of attention layers
226
+ for module in self.children():
227
+ fn_recursive_retrieve_slicable_dims(module)
228
+
229
+ num_slicable_layers = len(sliceable_head_dims)
230
+
231
+ if slice_size == "auto":
232
+ # half the attention head size is usually a good trade-off between
233
+ # speed and memory
234
+ slice_size = [dim // 2 for dim in sliceable_head_dims]
235
+ elif slice_size == "max":
236
+ # make smallest slice possible
237
+ slice_size = num_slicable_layers * [1]
238
+
239
+ slice_size = num_slicable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
240
+
241
+ if len(slice_size) != len(sliceable_head_dims):
242
+ raise ValueError(
243
+ f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
244
+ f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
245
+ )
246
+
247
+ for i in range(len(slice_size)):
248
+ size = slice_size[i]
249
+ dim = sliceable_head_dims[i]
250
+ if size is not None and size > dim:
251
+ raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
252
+
253
+ # Recursively walk through all the children.
254
+ # Any children which exposes the set_attention_slice method
255
+ # gets the message
256
+ def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
257
+ if hasattr(module, "set_attention_slice"):
258
+ module.set_attention_slice(slice_size.pop())
259
+
260
+ for child in module.children():
261
+ fn_recursive_set_attention_slice(child, slice_size)
262
+
263
+ reversed_slice_size = list(reversed(slice_size))
264
+ for module in self.children():
265
+ fn_recursive_set_attention_slice(module, reversed_slice_size)
266
+
267
+ def _set_gradient_checkpointing(self, module, value=False):
268
+ if isinstance(module, (CrossAttnDownBlock3D, DownBlock3D, CrossAttnUpBlock3D, UpBlock3D)):
269
+ module.gradient_checkpointing = value
270
+
271
+ def forward(
272
+ self,
273
+ sample: torch.FloatTensor,
274
+ timestep: Union[torch.Tensor, float, int],
275
+ encoder_hidden_states: torch.Tensor,
276
+ class_labels: Optional[torch.Tensor] = None,
277
+ attention_mask: Optional[torch.Tensor] = None,
278
+ return_dict: bool = True,
279
+ ) -> Union[UNet3DConditionOutput, Tuple]:
280
+ r"""
281
+ Args:
282
+ sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor
283
+ timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps
284
+ encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states
285
+ return_dict (`bool`, *optional*, defaults to `True`):
286
+ Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
287
+
288
+ Returns:
289
+ [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
290
+ [`~models.unet_2d_condition.UNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When
291
+ returning a tuple, the first element is the sample tensor.
292
+ """
293
+ # By default samples have to be AT least a multiple of the overall upsampling factor.
294
+ # The overall upsampling factor is equal to 2 ** (# num of upsampling layears).
295
+ # However, the upsampling interpolation output size can be forced to fit any upsampling size
296
+ # on the fly if necessary.
297
+ default_overall_up_factor = 2**self.num_upsamplers
298
+
299
+ # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
300
+ forward_upsample_size = False
301
+ upsample_size = None
302
+
303
+ if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
304
+ logger.info("Forward upsample size to force interpolation output size.")
305
+ forward_upsample_size = True
306
+
307
+ # prepare attention_mask
308
+ if attention_mask is not None:
309
+ attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
310
+ attention_mask = attention_mask.unsqueeze(1)
311
+
312
+ # center input if necessary
313
+ if self.config.center_input_sample:
314
+ sample = 2 * sample - 1.0
315
+
316
+ # time
317
+ timesteps = timestep
318
+ if not torch.is_tensor(timesteps):
319
+ # This would be a good case for the `match` statement (Python 3.10+)
320
+ is_mps = sample.device.type == "mps"
321
+ if isinstance(timestep, float):
322
+ dtype = torch.float32 if is_mps else torch.float64
323
+ else:
324
+ dtype = torch.int32 if is_mps else torch.int64
325
+ timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
326
+ elif len(timesteps.shape) == 0:
327
+ timesteps = timesteps[None].to(sample.device)
328
+
329
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
330
+ timesteps = timesteps.expand(sample.shape[0])
331
+
332
+ t_emb = self.time_proj(timesteps)
333
+
334
+ # timesteps does not contain any weights and will always return f32 tensors
335
+ # but time_embedding might actually be running in fp16. so we need to cast here.
336
+ # there might be better ways to encapsulate this.
337
+ t_emb = t_emb.to(dtype=self.dtype)
338
+ emb = self.time_embedding(t_emb)
339
+
340
+ if self.class_embedding is not None:
341
+ if class_labels is None:
342
+ raise ValueError("class_labels should be provided when num_class_embeds > 0")
343
+
344
+ if self.config.class_embed_type == "timestep":
345
+ class_labels = self.time_proj(class_labels)
346
+
347
+ class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
348
+ emb = emb + class_emb
349
+
350
+ # pre-process
351
+ sample = self.conv_in(sample)
352
+
353
+ # down
354
+ down_block_res_samples = (sample,)
355
+ for downsample_block in self.down_blocks:
356
+ if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
357
+ sample, res_samples = downsample_block(
358
+ hidden_states=sample,
359
+ temb=emb,
360
+ encoder_hidden_states=encoder_hidden_states,
361
+ attention_mask=attention_mask,
362
+ )
363
+ else:
364
+ sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
365
+
366
+ down_block_res_samples += res_samples
367
+
368
+ # mid
369
+ sample = self.mid_block(sample, emb, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask)
370
+
371
+ # up
372
+ for i, upsample_block in enumerate(self.up_blocks):
373
+ is_final_block = i == len(self.up_blocks) - 1
374
+
375
+ res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
376
+ down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
377
+
378
+ # if we have not reached the final block and need to forward the
379
+ # upsample size, we do it here
380
+ if not is_final_block and forward_upsample_size:
381
+ upsample_size = down_block_res_samples[-1].shape[2:]
382
+
383
+ if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
384
+ sample = upsample_block(
385
+ hidden_states=sample,
386
+ temb=emb,
387
+ res_hidden_states_tuple=res_samples,
388
+ encoder_hidden_states=encoder_hidden_states,
389
+ upsample_size=upsample_size,
390
+ attention_mask=attention_mask,
391
+ )
392
+ else:
393
+ sample = upsample_block(
394
+ hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size
395
+ )
396
+ # post-process
397
+ sample = self.conv_norm_out(sample)
398
+ sample = self.conv_act(sample)
399
+ sample = self.conv_out(sample)
400
+
401
+ if not return_dict:
402
+ return (sample,)
403
+
404
+ return UNet3DConditionOutput(sample=sample)
405
+
406
+ @classmethod
407
+ def from_pretrained_2d(cls, pretrained_model_path, subfolder=None):
408
+ if subfolder is not None:
409
+ pretrained_model_path = os.path.join(pretrained_model_path, subfolder)
410
+
411
+ config_file = os.path.join(pretrained_model_path, "config.json")
412
+ if not os.path.isfile(config_file):
413
+ raise RuntimeError(f"{config_file} does not exist")
414
+ with open(config_file, "r") as f:
415
+ config = json.load(f)
416
+ config["_class_name"] = cls.__name__
417
+ config["down_block_types"] = [
418
+ "CrossAttnDownBlock3D",
419
+ "CrossAttnDownBlock3D",
420
+ "CrossAttnDownBlock3D",
421
+ "DownBlock3D",
422
+ ]
423
+ config["up_block_types"] = ["UpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D"]
424
+
425
+ from diffusers.utils import WEIGHTS_NAME
426
+
427
+ model = cls.from_config(config)
428
+ model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME)
429
+ if not os.path.isfile(model_file):
430
+ raise RuntimeError(f"{model_file} does not exist")
431
+ state_dict = torch.load(model_file, map_location="cpu")
432
+ for k, v in model.state_dict().items():
433
+ if "_temp." in k:
434
+ state_dict.update({k: v})
435
+ model.load_state_dict(state_dict)
436
+
437
+ return model
video_diffusion/tuneavideo/models/unet_blocks.py ADDED
@@ -0,0 +1,588 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unet_2d_blocks.py
2
+
3
+ import torch
4
+ from torch import nn
5
+
6
+ from .attention import Transformer3DModel
7
+ from .resnet import Downsample3D, ResnetBlock3D, Upsample3D
8
+
9
+
10
+ def get_down_block(
11
+ down_block_type,
12
+ num_layers,
13
+ in_channels,
14
+ out_channels,
15
+ temb_channels,
16
+ add_downsample,
17
+ resnet_eps,
18
+ resnet_act_fn,
19
+ attn_num_head_channels,
20
+ resnet_groups=None,
21
+ cross_attention_dim=None,
22
+ downsample_padding=None,
23
+ dual_cross_attention=False,
24
+ use_linear_projection=False,
25
+ only_cross_attention=False,
26
+ upcast_attention=False,
27
+ resnet_time_scale_shift="default",
28
+ ):
29
+ down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type
30
+ if down_block_type == "DownBlock3D":
31
+ return DownBlock3D(
32
+ num_layers=num_layers,
33
+ in_channels=in_channels,
34
+ out_channels=out_channels,
35
+ temb_channels=temb_channels,
36
+ add_downsample=add_downsample,
37
+ resnet_eps=resnet_eps,
38
+ resnet_act_fn=resnet_act_fn,
39
+ resnet_groups=resnet_groups,
40
+ downsample_padding=downsample_padding,
41
+ resnet_time_scale_shift=resnet_time_scale_shift,
42
+ )
43
+ elif down_block_type == "CrossAttnDownBlock3D":
44
+ if cross_attention_dim is None:
45
+ raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlock3D")
46
+ return CrossAttnDownBlock3D(
47
+ num_layers=num_layers,
48
+ in_channels=in_channels,
49
+ out_channels=out_channels,
50
+ temb_channels=temb_channels,
51
+ add_downsample=add_downsample,
52
+ resnet_eps=resnet_eps,
53
+ resnet_act_fn=resnet_act_fn,
54
+ resnet_groups=resnet_groups,
55
+ downsample_padding=downsample_padding,
56
+ cross_attention_dim=cross_attention_dim,
57
+ attn_num_head_channels=attn_num_head_channels,
58
+ dual_cross_attention=dual_cross_attention,
59
+ use_linear_projection=use_linear_projection,
60
+ only_cross_attention=only_cross_attention,
61
+ upcast_attention=upcast_attention,
62
+ resnet_time_scale_shift=resnet_time_scale_shift,
63
+ )
64
+ raise ValueError(f"{down_block_type} does not exist.")
65
+
66
+
67
+ def get_up_block(
68
+ up_block_type,
69
+ num_layers,
70
+ in_channels,
71
+ out_channels,
72
+ prev_output_channel,
73
+ temb_channels,
74
+ add_upsample,
75
+ resnet_eps,
76
+ resnet_act_fn,
77
+ attn_num_head_channels,
78
+ resnet_groups=None,
79
+ cross_attention_dim=None,
80
+ dual_cross_attention=False,
81
+ use_linear_projection=False,
82
+ only_cross_attention=False,
83
+ upcast_attention=False,
84
+ resnet_time_scale_shift="default",
85
+ ):
86
+ up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
87
+ if up_block_type == "UpBlock3D":
88
+ return UpBlock3D(
89
+ num_layers=num_layers,
90
+ in_channels=in_channels,
91
+ out_channels=out_channels,
92
+ prev_output_channel=prev_output_channel,
93
+ temb_channels=temb_channels,
94
+ add_upsample=add_upsample,
95
+ resnet_eps=resnet_eps,
96
+ resnet_act_fn=resnet_act_fn,
97
+ resnet_groups=resnet_groups,
98
+ resnet_time_scale_shift=resnet_time_scale_shift,
99
+ )
100
+ elif up_block_type == "CrossAttnUpBlock3D":
101
+ if cross_attention_dim is None:
102
+ raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlock3D")
103
+ return CrossAttnUpBlock3D(
104
+ num_layers=num_layers,
105
+ in_channels=in_channels,
106
+ out_channels=out_channels,
107
+ prev_output_channel=prev_output_channel,
108
+ temb_channels=temb_channels,
109
+ add_upsample=add_upsample,
110
+ resnet_eps=resnet_eps,
111
+ resnet_act_fn=resnet_act_fn,
112
+ resnet_groups=resnet_groups,
113
+ cross_attention_dim=cross_attention_dim,
114
+ attn_num_head_channels=attn_num_head_channels,
115
+ dual_cross_attention=dual_cross_attention,
116
+ use_linear_projection=use_linear_projection,
117
+ only_cross_attention=only_cross_attention,
118
+ upcast_attention=upcast_attention,
119
+ resnet_time_scale_shift=resnet_time_scale_shift,
120
+ )
121
+ raise ValueError(f"{up_block_type} does not exist.")
122
+
123
+
124
+ class UNetMidBlock3DCrossAttn(nn.Module):
125
+ def __init__(
126
+ self,
127
+ in_channels: int,
128
+ temb_channels: int,
129
+ dropout: float = 0.0,
130
+ num_layers: int = 1,
131
+ resnet_eps: float = 1e-6,
132
+ resnet_time_scale_shift: str = "default",
133
+ resnet_act_fn: str = "swish",
134
+ resnet_groups: int = 32,
135
+ resnet_pre_norm: bool = True,
136
+ attn_num_head_channels=1,
137
+ output_scale_factor=1.0,
138
+ cross_attention_dim=1280,
139
+ dual_cross_attention=False,
140
+ use_linear_projection=False,
141
+ upcast_attention=False,
142
+ ):
143
+ super().__init__()
144
+
145
+ self.has_cross_attention = True
146
+ self.attn_num_head_channels = attn_num_head_channels
147
+ resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
148
+
149
+ # there is always at least one resnet
150
+ resnets = [
151
+ ResnetBlock3D(
152
+ in_channels=in_channels,
153
+ out_channels=in_channels,
154
+ temb_channels=temb_channels,
155
+ eps=resnet_eps,
156
+ groups=resnet_groups,
157
+ dropout=dropout,
158
+ time_embedding_norm=resnet_time_scale_shift,
159
+ non_linearity=resnet_act_fn,
160
+ output_scale_factor=output_scale_factor,
161
+ pre_norm=resnet_pre_norm,
162
+ )
163
+ ]
164
+ attentions = []
165
+
166
+ for _ in range(num_layers):
167
+ if dual_cross_attention:
168
+ raise NotImplementedError
169
+ attentions.append(
170
+ Transformer3DModel(
171
+ attn_num_head_channels,
172
+ in_channels // attn_num_head_channels,
173
+ in_channels=in_channels,
174
+ num_layers=1,
175
+ cross_attention_dim=cross_attention_dim,
176
+ norm_num_groups=resnet_groups,
177
+ use_linear_projection=use_linear_projection,
178
+ upcast_attention=upcast_attention,
179
+ )
180
+ )
181
+ resnets.append(
182
+ ResnetBlock3D(
183
+ in_channels=in_channels,
184
+ out_channels=in_channels,
185
+ temb_channels=temb_channels,
186
+ eps=resnet_eps,
187
+ groups=resnet_groups,
188
+ dropout=dropout,
189
+ time_embedding_norm=resnet_time_scale_shift,
190
+ non_linearity=resnet_act_fn,
191
+ output_scale_factor=output_scale_factor,
192
+ pre_norm=resnet_pre_norm,
193
+ )
194
+ )
195
+
196
+ self.attentions = nn.ModuleList(attentions)
197
+ self.resnets = nn.ModuleList(resnets)
198
+
199
+ def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None):
200
+ hidden_states = self.resnets[0](hidden_states, temb)
201
+ for attn, resnet in zip(self.attentions, self.resnets[1:]):
202
+ hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample
203
+ hidden_states = resnet(hidden_states, temb)
204
+
205
+ return hidden_states
206
+
207
+
208
+ class CrossAttnDownBlock3D(nn.Module):
209
+ def __init__(
210
+ self,
211
+ in_channels: int,
212
+ out_channels: int,
213
+ temb_channels: int,
214
+ dropout: float = 0.0,
215
+ num_layers: int = 1,
216
+ resnet_eps: float = 1e-6,
217
+ resnet_time_scale_shift: str = "default",
218
+ resnet_act_fn: str = "swish",
219
+ resnet_groups: int = 32,
220
+ resnet_pre_norm: bool = True,
221
+ attn_num_head_channels=1,
222
+ cross_attention_dim=1280,
223
+ output_scale_factor=1.0,
224
+ downsample_padding=1,
225
+ add_downsample=True,
226
+ dual_cross_attention=False,
227
+ use_linear_projection=False,
228
+ only_cross_attention=False,
229
+ upcast_attention=False,
230
+ ):
231
+ super().__init__()
232
+ resnets = []
233
+ attentions = []
234
+
235
+ self.has_cross_attention = True
236
+ self.attn_num_head_channels = attn_num_head_channels
237
+
238
+ for i in range(num_layers):
239
+ in_channels = in_channels if i == 0 else out_channels
240
+ resnets.append(
241
+ ResnetBlock3D(
242
+ in_channels=in_channels,
243
+ out_channels=out_channels,
244
+ temb_channels=temb_channels,
245
+ eps=resnet_eps,
246
+ groups=resnet_groups,
247
+ dropout=dropout,
248
+ time_embedding_norm=resnet_time_scale_shift,
249
+ non_linearity=resnet_act_fn,
250
+ output_scale_factor=output_scale_factor,
251
+ pre_norm=resnet_pre_norm,
252
+ )
253
+ )
254
+ if dual_cross_attention:
255
+ raise NotImplementedError
256
+ attentions.append(
257
+ Transformer3DModel(
258
+ attn_num_head_channels,
259
+ out_channels // attn_num_head_channels,
260
+ in_channels=out_channels,
261
+ num_layers=1,
262
+ cross_attention_dim=cross_attention_dim,
263
+ norm_num_groups=resnet_groups,
264
+ use_linear_projection=use_linear_projection,
265
+ only_cross_attention=only_cross_attention,
266
+ upcast_attention=upcast_attention,
267
+ )
268
+ )
269
+ self.attentions = nn.ModuleList(attentions)
270
+ self.resnets = nn.ModuleList(resnets)
271
+
272
+ if add_downsample:
273
+ self.downsamplers = nn.ModuleList(
274
+ [
275
+ Downsample3D(
276
+ out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
277
+ )
278
+ ]
279
+ )
280
+ else:
281
+ self.downsamplers = None
282
+
283
+ self.gradient_checkpointing = False
284
+
285
+ def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None):
286
+ output_states = ()
287
+
288
+ for resnet, attn in zip(self.resnets, self.attentions):
289
+ if self.training and self.gradient_checkpointing:
290
+
291
+ def create_custom_forward(module, return_dict=None):
292
+ def custom_forward(*inputs):
293
+ if return_dict is not None:
294
+ return module(*inputs, return_dict=return_dict)
295
+ else:
296
+ return module(*inputs)
297
+
298
+ return custom_forward
299
+
300
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
301
+ hidden_states = torch.utils.checkpoint.checkpoint(
302
+ create_custom_forward(attn, return_dict=False),
303
+ hidden_states,
304
+ encoder_hidden_states,
305
+ )[0]
306
+ else:
307
+ hidden_states = resnet(hidden_states, temb)
308
+ hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample
309
+
310
+ output_states += (hidden_states,)
311
+
312
+ if self.downsamplers is not None:
313
+ for downsampler in self.downsamplers:
314
+ hidden_states = downsampler(hidden_states)
315
+
316
+ output_states += (hidden_states,)
317
+
318
+ return hidden_states, output_states
319
+
320
+
321
+ class DownBlock3D(nn.Module):
322
+ def __init__(
323
+ self,
324
+ in_channels: int,
325
+ out_channels: int,
326
+ temb_channels: int,
327
+ dropout: float = 0.0,
328
+ num_layers: int = 1,
329
+ resnet_eps: float = 1e-6,
330
+ resnet_time_scale_shift: str = "default",
331
+ resnet_act_fn: str = "swish",
332
+ resnet_groups: int = 32,
333
+ resnet_pre_norm: bool = True,
334
+ output_scale_factor=1.0,
335
+ add_downsample=True,
336
+ downsample_padding=1,
337
+ ):
338
+ super().__init__()
339
+ resnets = []
340
+
341
+ for i in range(num_layers):
342
+ in_channels = in_channels if i == 0 else out_channels
343
+ resnets.append(
344
+ ResnetBlock3D(
345
+ in_channels=in_channels,
346
+ out_channels=out_channels,
347
+ temb_channels=temb_channels,
348
+ eps=resnet_eps,
349
+ groups=resnet_groups,
350
+ dropout=dropout,
351
+ time_embedding_norm=resnet_time_scale_shift,
352
+ non_linearity=resnet_act_fn,
353
+ output_scale_factor=output_scale_factor,
354
+ pre_norm=resnet_pre_norm,
355
+ )
356
+ )
357
+
358
+ self.resnets = nn.ModuleList(resnets)
359
+
360
+ if add_downsample:
361
+ self.downsamplers = nn.ModuleList(
362
+ [
363
+ Downsample3D(
364
+ out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
365
+ )
366
+ ]
367
+ )
368
+ else:
369
+ self.downsamplers = None
370
+
371
+ self.gradient_checkpointing = False
372
+
373
+ def forward(self, hidden_states, temb=None):
374
+ output_states = ()
375
+
376
+ for resnet in self.resnets:
377
+ if self.training and self.gradient_checkpointing:
378
+
379
+ def create_custom_forward(module):
380
+ def custom_forward(*inputs):
381
+ return module(*inputs)
382
+
383
+ return custom_forward
384
+
385
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
386
+ else:
387
+ hidden_states = resnet(hidden_states, temb)
388
+
389
+ output_states += (hidden_states,)
390
+
391
+ if self.downsamplers is not None:
392
+ for downsampler in self.downsamplers:
393
+ hidden_states = downsampler(hidden_states)
394
+
395
+ output_states += (hidden_states,)
396
+
397
+ return hidden_states, output_states
398
+
399
+
400
+ class CrossAttnUpBlock3D(nn.Module):
401
+ def __init__(
402
+ self,
403
+ in_channels: int,
404
+ out_channels: int,
405
+ prev_output_channel: int,
406
+ temb_channels: int,
407
+ dropout: float = 0.0,
408
+ num_layers: int = 1,
409
+ resnet_eps: float = 1e-6,
410
+ resnet_time_scale_shift: str = "default",
411
+ resnet_act_fn: str = "swish",
412
+ resnet_groups: int = 32,
413
+ resnet_pre_norm: bool = True,
414
+ attn_num_head_channels=1,
415
+ cross_attention_dim=1280,
416
+ output_scale_factor=1.0,
417
+ add_upsample=True,
418
+ dual_cross_attention=False,
419
+ use_linear_projection=False,
420
+ only_cross_attention=False,
421
+ upcast_attention=False,
422
+ ):
423
+ super().__init__()
424
+ resnets = []
425
+ attentions = []
426
+
427
+ self.has_cross_attention = True
428
+ self.attn_num_head_channels = attn_num_head_channels
429
+
430
+ for i in range(num_layers):
431
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
432
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
433
+
434
+ resnets.append(
435
+ ResnetBlock3D(
436
+ in_channels=resnet_in_channels + res_skip_channels,
437
+ out_channels=out_channels,
438
+ temb_channels=temb_channels,
439
+ eps=resnet_eps,
440
+ groups=resnet_groups,
441
+ dropout=dropout,
442
+ time_embedding_norm=resnet_time_scale_shift,
443
+ non_linearity=resnet_act_fn,
444
+ output_scale_factor=output_scale_factor,
445
+ pre_norm=resnet_pre_norm,
446
+ )
447
+ )
448
+ if dual_cross_attention:
449
+ raise NotImplementedError
450
+ attentions.append(
451
+ Transformer3DModel(
452
+ attn_num_head_channels,
453
+ out_channels // attn_num_head_channels,
454
+ in_channels=out_channels,
455
+ num_layers=1,
456
+ cross_attention_dim=cross_attention_dim,
457
+ norm_num_groups=resnet_groups,
458
+ use_linear_projection=use_linear_projection,
459
+ only_cross_attention=only_cross_attention,
460
+ upcast_attention=upcast_attention,
461
+ )
462
+ )
463
+
464
+ self.attentions = nn.ModuleList(attentions)
465
+ self.resnets = nn.ModuleList(resnets)
466
+
467
+ if add_upsample:
468
+ self.upsamplers = nn.ModuleList([Upsample3D(out_channels, use_conv=True, out_channels=out_channels)])
469
+ else:
470
+ self.upsamplers = None
471
+
472
+ self.gradient_checkpointing = False
473
+
474
+ def forward(
475
+ self,
476
+ hidden_states,
477
+ res_hidden_states_tuple,
478
+ temb=None,
479
+ encoder_hidden_states=None,
480
+ upsample_size=None,
481
+ attention_mask=None,
482
+ ):
483
+ for resnet, attn in zip(self.resnets, self.attentions):
484
+ # pop res hidden states
485
+ res_hidden_states = res_hidden_states_tuple[-1]
486
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
487
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
488
+
489
+ if self.training and self.gradient_checkpointing:
490
+
491
+ def create_custom_forward(module, return_dict=None):
492
+ def custom_forward(*inputs):
493
+ if return_dict is not None:
494
+ return module(*inputs, return_dict=return_dict)
495
+ else:
496
+ return module(*inputs)
497
+
498
+ return custom_forward
499
+
500
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
501
+ hidden_states = torch.utils.checkpoint.checkpoint(
502
+ create_custom_forward(attn, return_dict=False),
503
+ hidden_states,
504
+ encoder_hidden_states,
505
+ )[0]
506
+ else:
507
+ hidden_states = resnet(hidden_states, temb)
508
+ hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample
509
+
510
+ if self.upsamplers is not None:
511
+ for upsampler in self.upsamplers:
512
+ hidden_states = upsampler(hidden_states, upsample_size)
513
+
514
+ return hidden_states
515
+
516
+
517
+ class UpBlock3D(nn.Module):
518
+ def __init__(
519
+ self,
520
+ in_channels: int,
521
+ prev_output_channel: int,
522
+ out_channels: int,
523
+ temb_channels: int,
524
+ dropout: float = 0.0,
525
+ num_layers: int = 1,
526
+ resnet_eps: float = 1e-6,
527
+ resnet_time_scale_shift: str = "default",
528
+ resnet_act_fn: str = "swish",
529
+ resnet_groups: int = 32,
530
+ resnet_pre_norm: bool = True,
531
+ output_scale_factor=1.0,
532
+ add_upsample=True,
533
+ ):
534
+ super().__init__()
535
+ resnets = []
536
+
537
+ for i in range(num_layers):
538
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
539
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
540
+
541
+ resnets.append(
542
+ ResnetBlock3D(
543
+ in_channels=resnet_in_channels + res_skip_channels,
544
+ out_channels=out_channels,
545
+ temb_channels=temb_channels,
546
+ eps=resnet_eps,
547
+ groups=resnet_groups,
548
+ dropout=dropout,
549
+ time_embedding_norm=resnet_time_scale_shift,
550
+ non_linearity=resnet_act_fn,
551
+ output_scale_factor=output_scale_factor,
552
+ pre_norm=resnet_pre_norm,
553
+ )
554
+ )
555
+
556
+ self.resnets = nn.ModuleList(resnets)
557
+
558
+ if add_upsample:
559
+ self.upsamplers = nn.ModuleList([Upsample3D(out_channels, use_conv=True, out_channels=out_channels)])
560
+ else:
561
+ self.upsamplers = None
562
+
563
+ self.gradient_checkpointing = False
564
+
565
+ def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None):
566
+ for resnet in self.resnets:
567
+ # pop res hidden states
568
+ res_hidden_states = res_hidden_states_tuple[-1]
569
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
570
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
571
+
572
+ if self.training and self.gradient_checkpointing:
573
+
574
+ def create_custom_forward(module):
575
+ def custom_forward(*inputs):
576
+ return module(*inputs)
577
+
578
+ return custom_forward
579
+
580
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
581
+ else:
582
+ hidden_states = resnet(hidden_states, temb)
583
+
584
+ if self.upsamplers is not None:
585
+ for upsampler in self.upsamplers:
586
+ hidden_states = upsampler(hidden_states, upsample_size)
587
+
588
+ return hidden_states
video_diffusion/tuneavideo/pipelines/pipeline_tuneavideo.py ADDED
@@ -0,0 +1,411 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
2
+
3
+ import inspect
4
+ from dataclasses import dataclass
5
+ from typing import Callable, List, Optional, Union
6
+
7
+ import numpy as np
8
+ import torch
9
+ from diffusers.configuration_utils import FrozenDict
10
+ from diffusers.models import AutoencoderKL
11
+ from diffusers.pipeline_utils import DiffusionPipeline
12
+ from diffusers.schedulers import (
13
+ DDIMScheduler,
14
+ DPMSolverMultistepScheduler,
15
+ EulerAncestralDiscreteScheduler,
16
+ EulerDiscreteScheduler,
17
+ LMSDiscreteScheduler,
18
+ PNDMScheduler,
19
+ )
20
+ from diffusers.utils import BaseOutput, deprecate, is_accelerate_available, logging
21
+ from einops import rearrange
22
+ from packaging import version
23
+ from transformers import CLIPTextModel, CLIPTokenizer
24
+
25
+ from ..models.unet import UNet3DConditionModel
26
+
27
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
28
+
29
+
30
+ @dataclass
31
+ class TuneAVideoPipelineOutput(BaseOutput):
32
+ videos: Union[torch.Tensor, np.ndarray]
33
+
34
+
35
+ class TuneAVideoPipeline(DiffusionPipeline):
36
+ _optional_components = []
37
+
38
+ def __init__(
39
+ self,
40
+ vae: AutoencoderKL,
41
+ text_encoder: CLIPTextModel,
42
+ tokenizer: CLIPTokenizer,
43
+ unet: UNet3DConditionModel,
44
+ scheduler: Union[
45
+ DDIMScheduler,
46
+ PNDMScheduler,
47
+ LMSDiscreteScheduler,
48
+ EulerDiscreteScheduler,
49
+ EulerAncestralDiscreteScheduler,
50
+ DPMSolverMultistepScheduler,
51
+ ],
52
+ ):
53
+ super().__init__()
54
+
55
+ if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
56
+ deprecation_message = (
57
+ f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
58
+ f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
59
+ "to update the config accordingly as leaving `steps_offset` might led to incorrect results"
60
+ " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
61
+ " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
62
+ " file"
63
+ )
64
+ deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
65
+ new_config = dict(scheduler.config)
66
+ new_config["steps_offset"] = 1
67
+ scheduler._internal_dict = FrozenDict(new_config)
68
+
69
+ if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
70
+ deprecation_message = (
71
+ f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
72
+ " `clip_sample` should be set to False in the configuration file. Please make sure to update the"
73
+ " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
74
+ " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
75
+ " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
76
+ )
77
+ deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
78
+ new_config = dict(scheduler.config)
79
+ new_config["clip_sample"] = False
80
+ scheduler._internal_dict = FrozenDict(new_config)
81
+
82
+ is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
83
+ version.parse(unet.config._diffusers_version).base_version
84
+ ) < version.parse("0.9.0.dev0")
85
+ is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
86
+ if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
87
+ deprecation_message = (
88
+ "The configuration file of the unet has set the default `sample_size` to smaller than"
89
+ " 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
90
+ " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
91
+ " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
92
+ " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
93
+ " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
94
+ " in the config might lead to incorrect results in future versions. If you have downloaded this"
95
+ " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
96
+ " the `unet/config.json` file"
97
+ )
98
+ deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
99
+ new_config = dict(unet.config)
100
+ new_config["sample_size"] = 64
101
+ unet._internal_dict = FrozenDict(new_config)
102
+
103
+ self.register_modules(
104
+ vae=vae,
105
+ text_encoder=text_encoder,
106
+ tokenizer=tokenizer,
107
+ unet=unet,
108
+ scheduler=scheduler,
109
+ )
110
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
111
+
112
+ def enable_vae_slicing(self):
113
+ self.vae.enable_slicing()
114
+
115
+ def disable_vae_slicing(self):
116
+ self.vae.disable_slicing()
117
+
118
+ def enable_sequential_cpu_offload(self, gpu_id=0):
119
+ if is_accelerate_available():
120
+ from accelerate import cpu_offload
121
+ else:
122
+ raise ImportError("Please install accelerate via `pip install accelerate`")
123
+
124
+ device = torch.device(f"cuda:{gpu_id}")
125
+
126
+ for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
127
+ if cpu_offloaded_model is not None:
128
+ cpu_offload(cpu_offloaded_model, device)
129
+
130
+ @property
131
+ def _execution_device(self):
132
+ if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"):
133
+ return self.device
134
+ for module in self.unet.modules():
135
+ if (
136
+ hasattr(module, "_hf_hook")
137
+ and hasattr(module._hf_hook, "execution_device")
138
+ and module._hf_hook.execution_device is not None
139
+ ):
140
+ return torch.device(module._hf_hook.execution_device)
141
+ return self.device
142
+
143
+ def _encode_prompt(self, prompt, device, num_videos_per_prompt, do_classifier_free_guidance, negative_prompt):
144
+ batch_size = len(prompt) if isinstance(prompt, list) else 1
145
+
146
+ text_inputs = self.tokenizer(
147
+ prompt,
148
+ padding="max_length",
149
+ max_length=self.tokenizer.model_max_length,
150
+ truncation=True,
151
+ return_tensors="pt",
152
+ )
153
+ text_input_ids = text_inputs.input_ids
154
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
155
+
156
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
157
+ removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
158
+ logger.warning(
159
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
160
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
161
+ )
162
+
163
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
164
+ attention_mask = text_inputs.attention_mask.to(device)
165
+ else:
166
+ attention_mask = None
167
+
168
+ text_embeddings = self.text_encoder(
169
+ text_input_ids.to(device),
170
+ attention_mask=attention_mask,
171
+ )
172
+ text_embeddings = text_embeddings[0]
173
+
174
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
175
+ bs_embed, seq_len, _ = text_embeddings.shape
176
+ text_embeddings = text_embeddings.repeat(1, num_videos_per_prompt, 1)
177
+ text_embeddings = text_embeddings.view(bs_embed * num_videos_per_prompt, seq_len, -1)
178
+
179
+ # get unconditional embeddings for classifier free guidance
180
+ if do_classifier_free_guidance:
181
+ uncond_tokens: List[str]
182
+ if negative_prompt is None:
183
+ uncond_tokens = [""] * batch_size
184
+ elif type(prompt) is not type(negative_prompt):
185
+ raise TypeError(
186
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
187
+ f" {type(prompt)}."
188
+ )
189
+ elif isinstance(negative_prompt, str):
190
+ uncond_tokens = [negative_prompt]
191
+ elif batch_size != len(negative_prompt):
192
+ raise ValueError(
193
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
194
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
195
+ " the batch size of `prompt`."
196
+ )
197
+ else:
198
+ uncond_tokens = negative_prompt
199
+
200
+ max_length = text_input_ids.shape[-1]
201
+ uncond_input = self.tokenizer(
202
+ uncond_tokens,
203
+ padding="max_length",
204
+ max_length=max_length,
205
+ truncation=True,
206
+ return_tensors="pt",
207
+ )
208
+
209
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
210
+ attention_mask = uncond_input.attention_mask.to(device)
211
+ else:
212
+ attention_mask = None
213
+
214
+ uncond_embeddings = self.text_encoder(
215
+ uncond_input.input_ids.to(device),
216
+ attention_mask=attention_mask,
217
+ )
218
+ uncond_embeddings = uncond_embeddings[0]
219
+
220
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
221
+ seq_len = uncond_embeddings.shape[1]
222
+ uncond_embeddings = uncond_embeddings.repeat(1, num_videos_per_prompt, 1)
223
+ uncond_embeddings = uncond_embeddings.view(batch_size * num_videos_per_prompt, seq_len, -1)
224
+
225
+ # For classifier free guidance, we need to do two forward passes.
226
+ # Here we concatenate the unconditional and text embeddings into a single batch
227
+ # to avoid doing two forward passes
228
+ text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
229
+
230
+ return text_embeddings
231
+
232
+ def decode_latents(self, latents):
233
+ video_length = latents.shape[2]
234
+ latents = 1 / 0.18215 * latents
235
+ latents = rearrange(latents, "b c f h w -> (b f) c h w")
236
+ video = self.vae.decode(latents).sample
237
+ video = rearrange(video, "(b f) c h w -> b c f h w", f=video_length)
238
+ video = (video / 2 + 0.5).clamp(0, 1)
239
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16
240
+ video = video.cpu().float().numpy()
241
+ return video
242
+
243
+ def prepare_extra_step_kwargs(self, generator, eta):
244
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
245
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
246
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
247
+ # and should be between [0, 1]
248
+
249
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
250
+ extra_step_kwargs = {}
251
+ if accepts_eta:
252
+ extra_step_kwargs["eta"] = eta
253
+
254
+ # check if the scheduler accepts generator
255
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
256
+ if accepts_generator:
257
+ extra_step_kwargs["generator"] = generator
258
+ return extra_step_kwargs
259
+
260
+ def check_inputs(self, prompt, height, width, callback_steps):
261
+ if not isinstance(prompt, str) and not isinstance(prompt, list):
262
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
263
+
264
+ if height % 8 != 0 or width % 8 != 0:
265
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
266
+
267
+ if (callback_steps is None) or (
268
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
269
+ ):
270
+ raise ValueError(
271
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
272
+ f" {type(callback_steps)}."
273
+ )
274
+
275
+ def prepare_latents(
276
+ self, batch_size, num_channels_latents, video_length, height, width, dtype, device, generator, latents=None
277
+ ):
278
+ shape = (
279
+ batch_size,
280
+ num_channels_latents,
281
+ video_length,
282
+ height // self.vae_scale_factor,
283
+ width // self.vae_scale_factor,
284
+ )
285
+ if isinstance(generator, list) and len(generator) != batch_size:
286
+ raise ValueError(
287
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
288
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
289
+ )
290
+
291
+ if latents is None:
292
+ rand_device = "cpu" if device.type == "mps" else device
293
+
294
+ if isinstance(generator, list):
295
+ shape = (1,) + shape[1:]
296
+ latents = [
297
+ torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype)
298
+ for i in range(batch_size)
299
+ ]
300
+ latents = torch.cat(latents, dim=0).to(device)
301
+ else:
302
+ latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype).to(device)
303
+ else:
304
+ if latents.shape != shape:
305
+ raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
306
+ latents = latents.to(device)
307
+
308
+ # scale the initial noise by the standard deviation required by the scheduler
309
+ latents = latents * self.scheduler.init_noise_sigma
310
+ return latents
311
+
312
+ @torch.no_grad()
313
+ def __call__(
314
+ self,
315
+ prompt: Union[str, List[str]],
316
+ video_length: Optional[int],
317
+ height: Optional[int] = None,
318
+ width: Optional[int] = None,
319
+ num_inference_steps: int = 50,
320
+ guidance_scale: float = 7.5,
321
+ negative_prompt: Optional[Union[str, List[str]]] = None,
322
+ num_videos_per_prompt: Optional[int] = 1,
323
+ eta: float = 0.0,
324
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
325
+ latents: Optional[torch.FloatTensor] = None,
326
+ output_type: Optional[str] = "tensor",
327
+ return_dict: bool = True,
328
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
329
+ callback_steps: Optional[int] = 1,
330
+ **kwargs,
331
+ ):
332
+ # Default height and width to unet
333
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
334
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
335
+
336
+ # Check inputs. Raise error if not correct
337
+ self.check_inputs(prompt, height, width, callback_steps)
338
+
339
+ # Define call parameters
340
+ batch_size = 1 if isinstance(prompt, str) else len(prompt)
341
+ device = self._execution_device
342
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
343
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
344
+ # corresponds to doing no classifier free guidance.
345
+ do_classifier_free_guidance = guidance_scale > 1.0
346
+
347
+ # Encode input prompt
348
+ text_embeddings = self._encode_prompt(
349
+ prompt, device, num_videos_per_prompt, do_classifier_free_guidance, negative_prompt
350
+ )
351
+
352
+ # Prepare timesteps
353
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
354
+ timesteps = self.scheduler.timesteps
355
+
356
+ # Prepare latent variables
357
+ num_channels_latents = self.unet.in_channels
358
+ latents = self.prepare_latents(
359
+ batch_size * num_videos_per_prompt,
360
+ num_channels_latents,
361
+ video_length,
362
+ height,
363
+ width,
364
+ text_embeddings.dtype,
365
+ device,
366
+ generator,
367
+ latents,
368
+ )
369
+ latents_dtype = latents.dtype
370
+
371
+ # Prepare extra step kwargs.
372
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
373
+
374
+ # Denoising loop
375
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
376
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
377
+ for i, t in enumerate(timesteps):
378
+ # expand the latents if we are doing classifier free guidance
379
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
380
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
381
+
382
+ # predict the noise residual
383
+ noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample.to(
384
+ dtype=latents_dtype
385
+ )
386
+
387
+ # perform guidance
388
+ if do_classifier_free_guidance:
389
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
390
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
391
+
392
+ # compute the previous noisy sample x_t -> x_t-1
393
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
394
+
395
+ # call the callback, if provided
396
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
397
+ progress_bar.update()
398
+ if callback is not None and i % callback_steps == 0:
399
+ callback(i, t, latents)
400
+
401
+ # Post-processing
402
+ video = self.decode_latents(latents)
403
+
404
+ # Convert to tensor
405
+ if output_type == "tensor":
406
+ video = torch.from_numpy(video)
407
+
408
+ if not return_dict:
409
+ return video
410
+
411
+ return TuneAVideoPipelineOutput(videos=video)
video_diffusion/tuneavideo/tuneavideo_text2video.py ADDED
@@ -0,0 +1,153 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+
4
+ from video_diffusion.tuneavideo.models.unet import UNet3DConditionModel
5
+ from video_diffusion.tuneavideo.pipelines.pipeline_tuneavideo import TuneAVideoPipeline
6
+ from video_diffusion.tuneavideo.util import save_videos_grid
7
+ from video_diffusion.utils.model_list import stable_model_list
8
+
9
+ video_diffusion_model_list = [
10
+ "Tune-A-Video-library/a-man-is-surfing",
11
+ "Tune-A-Video-library/mo-di-bear-guitar",
12
+ "Tune-A-Video-library/redshift-man-skiing",
13
+ ]
14
+
15
+
16
+ class TunaVideoText2VideoGenerator:
17
+ def __init__(self):
18
+ self.pipe = None
19
+ self.unet = None
20
+
21
+ def load_model(self, video_diffusion_model_list, stable_model_list):
22
+ if self.pipe is None:
23
+ if self.unet is None:
24
+ self.unet = UNet3DConditionModel.from_pretrained(
25
+ video_diffusion_model_list, subfolder="unet", torch_dtype=torch.float16
26
+ ).to("cuda")
27
+
28
+ self.pipe = TuneAVideoPipeline.from_pretrained(
29
+ stable_model_list, unet=self.unet, torch_dtype=torch.float16
30
+ )
31
+ self.pipe.to("cuda")
32
+ self.pipe.enable_xformers_memory_efficient_attention()
33
+
34
+ return self.pipe
35
+
36
+ def generate_video(
37
+ self,
38
+ video_diffusion_model: str,
39
+ stable_model_list: str,
40
+ prompt: str,
41
+ negative_prompt: str,
42
+ video_length: int,
43
+ height: int,
44
+ width: int,
45
+ num_inference_steps: int,
46
+ guidance_scale: int,
47
+ fps: int,
48
+ ):
49
+ pipe = self.load_model(video_diffusion_model, stable_model_list)
50
+ video = pipe(
51
+ prompt,
52
+ negative_prompt=negative_prompt,
53
+ video_length=video_length,
54
+ height=height,
55
+ width=width,
56
+ num_inference_steps=num_inference_steps,
57
+ guidance_scale=guidance_scale,
58
+ ).videos
59
+
60
+ save_videos_grid(videos=video, path="output.gif", fps=fps)
61
+ return "output.gif"
62
+
63
+ def app():
64
+ with gr.Blocks():
65
+ with gr.Row():
66
+ with gr.Column():
67
+ tunevideo_video_diffusion_model_list = gr.Dropdown(
68
+ choices=video_diffusion_model_list,
69
+ label="Video Diffusion Model",
70
+ value=video_diffusion_model_list[0],
71
+ )
72
+ tunevideo_stable_model_list = gr.Dropdown(
73
+ choices=stable_model_list,
74
+ label="Stable Model List",
75
+ value=stable_model_list[0],
76
+ )
77
+ with gr.Row():
78
+ with gr.Column():
79
+ tunevideo_prompt = gr.Textbox(
80
+ lines=1,
81
+ placeholder="Prompt",
82
+ show_label=False,
83
+ )
84
+ tunevideo_video_length = gr.Slider(
85
+ minimum=1,
86
+ maximum=100,
87
+ step=1,
88
+ value=10,
89
+ label="Video Length",
90
+ )
91
+ tunevideo_num_inference_steps = gr.Slider(
92
+ minimum=1,
93
+ maximum=100,
94
+ step=1,
95
+ value=50,
96
+ label="Num Inference Steps",
97
+ )
98
+ tunevideo_fps = gr.Slider(
99
+ minimum=1,
100
+ maximum=60,
101
+ step=1,
102
+ value=5,
103
+ label="Fps",
104
+ )
105
+ with gr.Row():
106
+ with gr.Column():
107
+ tunevideo_negative_prompt = gr.Textbox(
108
+ lines=1,
109
+ placeholder="Negative Prompt",
110
+ show_label=False,
111
+ )
112
+ tunevideo_guidance_scale = gr.Slider(
113
+ minimum=1,
114
+ maximum=15,
115
+ step=1,
116
+ value=7.5,
117
+ label="Guidance Scale",
118
+ )
119
+ tunevideo_height = gr.Slider(
120
+ minimum=1,
121
+ maximum=1280,
122
+ step=32,
123
+ value=512,
124
+ label="Height",
125
+ )
126
+ tunevideo_width = gr.Slider(
127
+ minimum=1,
128
+ maximum=1280,
129
+ step=32,
130
+ value=512,
131
+ label="Width",
132
+ )
133
+ tunevideo_generate = gr.Button(value="Generator")
134
+
135
+ with gr.Column():
136
+ tunevideo_output = gr.Video(label="Output")
137
+
138
+ tunevideo_generate.click(
139
+ fn=TunaVideoText2VideoGenerator().generate_video,
140
+ inputs=[
141
+ tunevideo_video_diffusion_model_list,
142
+ tunevideo_stable_model_list,
143
+ tunevideo_prompt,
144
+ tunevideo_negative_prompt,
145
+ tunevideo_video_length,
146
+ tunevideo_height,
147
+ tunevideo_width,
148
+ tunevideo_num_inference_steps,
149
+ tunevideo_guidance_scale,
150
+ tunevideo_fps,
151
+ ],
152
+ outputs=tunevideo_output,
153
+ )
video_diffusion/tuneavideo/util.py ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from typing import Union
3
+
4
+ import imageio
5
+ import numpy as np
6
+ import torch
7
+ import torchvision
8
+ from einops import rearrange
9
+ from tqdm import tqdm
10
+
11
+
12
+ def save_videos_grid(
13
+ videos: torch.Tensor, save_path: str = "output", path: str = "output.gif", rescale=False, n_rows=4, fps=3
14
+ ):
15
+ videos = rearrange(videos, "b c t h w -> t b c h w")
16
+ outputs = []
17
+ for x in videos:
18
+ x = torchvision.utils.make_grid(x, nrow=n_rows)
19
+ x = x.transpose(0, 1).transpose(1, 2).squeeze(-1)
20
+ if rescale:
21
+ x = (x + 1.0) / 2.0 # -1,1 -> 0,1
22
+ x = (x * 255).numpy().astype(np.uint8)
23
+ outputs.append(x)
24
+
25
+ if not os.path.exists(save_path):
26
+ os.makedirs(save_path)
27
+
28
+ imageio.mimsave(os.path.join(save_path, path), outputs, fps=fps)
29
+ return os.path.join(save_path, path)
30
+
31
+
32
+ # DDIM Inversion
33
+ @torch.no_grad()
34
+ def init_prompt(prompt, pipeline):
35
+ uncond_input = pipeline.tokenizer(
36
+ [""], padding="max_length", max_length=pipeline.tokenizer.model_max_length, return_tensors="pt"
37
+ )
38
+ uncond_embeddings = pipeline.text_encoder(uncond_input.input_ids.to(pipeline.device))[0]
39
+ text_input = pipeline.tokenizer(
40
+ [prompt],
41
+ padding="max_length",
42
+ max_length=pipeline.tokenizer.model_max_length,
43
+ truncation=True,
44
+ return_tensors="pt",
45
+ )
46
+ text_embeddings = pipeline.text_encoder(text_input.input_ids.to(pipeline.device))[0]
47
+ context = torch.cat([uncond_embeddings, text_embeddings])
48
+
49
+ return context
50
+
51
+
52
+ def next_step(
53
+ model_output: Union[torch.FloatTensor, np.ndarray],
54
+ timestep: int,
55
+ sample: Union[torch.FloatTensor, np.ndarray],
56
+ ddim_scheduler,
57
+ ):
58
+ timestep, next_timestep = (
59
+ min(timestep - ddim_scheduler.config.num_train_timesteps // ddim_scheduler.num_inference_steps, 999),
60
+ timestep,
61
+ )
62
+ alpha_prod_t = ddim_scheduler.alphas_cumprod[timestep] if timestep >= 0 else ddim_scheduler.final_alpha_cumprod
63
+ alpha_prod_t_next = ddim_scheduler.alphas_cumprod[next_timestep]
64
+ beta_prod_t = 1 - alpha_prod_t
65
+ next_original_sample = (sample - beta_prod_t**0.5 * model_output) / alpha_prod_t**0.5
66
+ next_sample_direction = (1 - alpha_prod_t_next) ** 0.5 * model_output
67
+ next_sample = alpha_prod_t_next**0.5 * next_original_sample + next_sample_direction
68
+ return next_sample
69
+
70
+
71
+ def get_noise_pred_single(latents, t, context, unet):
72
+ noise_pred = unet(latents, t, encoder_hidden_states=context)["sample"]
73
+ return noise_pred
74
+
75
+
76
+ @torch.no_grad()
77
+ def ddim_loop(pipeline, ddim_scheduler, latent, num_inv_steps, prompt):
78
+ context = init_prompt(prompt, pipeline)
79
+ uncond_embeddings, cond_embeddings = context.chunk(2)
80
+ all_latent = [latent]
81
+ latent = latent.clone().detach()
82
+ for i in tqdm(range(num_inv_steps)):
83
+ t = ddim_scheduler.timesteps[len(ddim_scheduler.timesteps) - i - 1]
84
+ noise_pred = get_noise_pred_single(latent, t, cond_embeddings, pipeline.unet)
85
+ latent = next_step(noise_pred, t, latent, ddim_scheduler)
86
+ all_latent.append(latent)
87
+ return all_latent
88
+
89
+
90
+ @torch.no_grad()
91
+ def ddim_inversion(pipeline, ddim_scheduler, video_latent, num_inv_steps, prompt=""):
92
+ ddim_latents = ddim_loop(pipeline, ddim_scheduler, video_latent, num_inv_steps, prompt)
93
+ return ddim_latents
video_diffusion/utils/__init__.py ADDED
File without changes
video_diffusion/utils/model_list.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ stable_model_list = [
2
+ "runwayml/stable-diffusion-v1-5",
3
+ "stabilityai/stable-diffusion-2-1",
4
+ # "prompthero/openjourney-v4",
5
+ "cerspense/zeroscope_v2_576w"
6
+ ]
video_diffusion/utils/scheduler_list.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from diffusers import (
2
+ DDIMScheduler,
3
+ DPMSolverMultistepScheduler,
4
+ EulerAncestralDiscreteScheduler,
5
+ EulerDiscreteScheduler,
6
+ HeunDiscreteScheduler,
7
+ LMSDiscreteScheduler,
8
+ )
9
+
10
+ diff_scheduler_list = ["DDIM", "EulerA", "Euler", "LMS", "Heun", "UniPC", "DPMSolver"]
11
+
12
+
13
+ def get_scheduler_list(pipe, scheduler):
14
+ if scheduler == diff_scheduler_list[0]:
15
+ pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
16
+
17
+ elif scheduler == diff_scheduler_list[1]:
18
+ pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config)
19
+
20
+ elif scheduler == diff_scheduler_list[2]:
21
+ pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config)
22
+
23
+ elif scheduler == diff_scheduler_list[3]:
24
+ pipe.scheduler = LMSDiscreteScheduler.from_config(pipe.scheduler.config)
25
+
26
+ elif scheduler == diff_scheduler_list[4]:
27
+ pipe.scheduler = HeunDiscreteScheduler.from_config(pipe.scheduler.config)
28
+
29
+ elif scheduler == diff_scheduler_list[5]:
30
+ pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
31
+
32
+ return pipe
video_diffusion/zero_shot/zero_shot_text2video.py ADDED
@@ -0,0 +1,164 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import imageio
3
+ import torch
4
+ from diffusers import TextToVideoZeroPipeline
5
+
6
+ from video_diffusion.tuneavideo.util import save_videos_grid
7
+ from video_diffusion.utils.model_list import stable_model_list
8
+
9
+
10
+ class ZeroShotText2VideoGenerator:
11
+ def __init__(self):
12
+ self.pipe = None
13
+
14
+ def load_model(self, model_id):
15
+ if self.pipe is None:
16
+ self.pipe = TextToVideoZeroPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda")
17
+ self.pipe.to("cuda")
18
+ self.pipe.enable_xformers_memory_efficient_attention()
19
+ self.pipe.enable_attention_slicing()
20
+
21
+ return self.pipe
22
+
23
+ def generate_video(
24
+ self,
25
+ prompt,
26
+ negative_prompt,
27
+ model_id,
28
+ height,
29
+ width,
30
+ video_length,
31
+ guidance_scale,
32
+ fps,
33
+ t0,
34
+ t1,
35
+ motion_field_strength_x,
36
+ motion_field_strength_y,
37
+ ):
38
+ pipe = self.load_model(model_id)
39
+ result = pipe(
40
+ prompt=prompt,
41
+ negative_prompt=negative_prompt,
42
+ height=height,
43
+ width=width,
44
+ video_length=video_length,
45
+ guidance_scale=guidance_scale,
46
+ t0=t0,
47
+ t1=t1,
48
+ motion_field_strength_x=motion_field_strength_x,
49
+ motion_field_strength_y=motion_field_strength_y,
50
+ ).images
51
+
52
+ result = [(r * 255).astype("uint8") for r in result]
53
+ imageio.mimsave("video.mp4", result, fps=fps)
54
+ return "video.mp4"
55
+
56
+ def app():
57
+ with gr.Blocks():
58
+ with gr.Row():
59
+ with gr.Column():
60
+ zero_shot_text2video_prompt = gr.Textbox(
61
+ lines=1,
62
+ placeholder="Prompt",
63
+ show_label=False,
64
+ )
65
+ zero_shot_text2video_negative_prompt = gr.Textbox(
66
+ lines=1,
67
+ placeholder="Negative Prompt",
68
+ show_label=False,
69
+ )
70
+ zero_shot_text2video_model_id = gr.Dropdown(
71
+ choices=stable_model_list,
72
+ label="Stable Model List",
73
+ value=stable_model_list[0],
74
+ )
75
+ with gr.Row():
76
+ with gr.Column():
77
+ zero_shot_text2video_guidance_scale = gr.Slider(
78
+ label="Guidance Scale",
79
+ minimum=1,
80
+ maximum=15,
81
+ step=1,
82
+ value=7.5,
83
+ )
84
+ zero_shot_text2video_video_length = gr.Slider(
85
+ label="Video Length",
86
+ minimum=1,
87
+ maximum=100,
88
+ step=1,
89
+ value=10,
90
+ )
91
+ zero_shot_text2video_t0 = gr.Slider(
92
+ label="Timestep T0",
93
+ minimum=0,
94
+ maximum=100,
95
+ step=1,
96
+ value=44,
97
+ )
98
+ zero_shot_text2video_motion_field_strength_x = gr.Slider(
99
+ label="Motion Field Strength X",
100
+ minimum=0,
101
+ maximum=100,
102
+ step=1,
103
+ value=12,
104
+ )
105
+ zero_shot_text2video_fps = gr.Slider(
106
+ label="Fps",
107
+ minimum=1,
108
+ maximum=60,
109
+ step=1,
110
+ value=10,
111
+ )
112
+ with gr.Row():
113
+ with gr.Column():
114
+ zero_shot_text2video_height = gr.Slider(
115
+ label="Height",
116
+ minimum=128,
117
+ maximum=1280,
118
+ step=32,
119
+ value=512,
120
+ )
121
+ zero_shot_text2video_width = gr.Slider(
122
+ label="Width",
123
+ minimum=128,
124
+ maximum=1280,
125
+ step=32,
126
+ value=512,
127
+ )
128
+ zero_shot_text2video_t1 = gr.Slider(
129
+ label="Timestep T1",
130
+ minimum=0,
131
+ maximum=100,
132
+ step=1,
133
+ value=47,
134
+ )
135
+ zero_shot_text2video_motion_field_strength_y = gr.Slider(
136
+ label="Motion Field Strength Y",
137
+ minimum=0,
138
+ maximum=100,
139
+ step=1,
140
+ value=12,
141
+ )
142
+ zero_shot_text2video_button = gr.Button(value="Generator")
143
+
144
+ with gr.Column():
145
+ zero_shot_text2video_output = gr.Video(label="Output")
146
+
147
+ zero_shot_text2video_button.click(
148
+ fn=ZeroShotText2VideoGenerator().generate_video,
149
+ inputs=[
150
+ zero_shot_text2video_prompt,
151
+ zero_shot_text2video_negative_prompt,
152
+ zero_shot_text2video_model_id,
153
+ zero_shot_text2video_height,
154
+ zero_shot_text2video_width,
155
+ zero_shot_text2video_video_length,
156
+ zero_shot_text2video_guidance_scale,
157
+ zero_shot_text2video_fps,
158
+ zero_shot_text2video_t0,
159
+ zero_shot_text2video_t1,
160
+ zero_shot_text2video_motion_field_strength_x,
161
+ zero_shot_text2video_motion_field_strength_y,
162
+ ],
163
+ outputs=zero_shot_text2video_output,
164
+ )