MonsterMMORPG commited on
Commit
9445995
1 Parent(s): 88b7342

Upload 50 files

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +5 -0
  2. README.md +39 -0
  3. SECourses.py +393 -0
  4. assets/crop_example.jpeg +0 -0
  5. assets/demo/fix_face_weight.gif +3 -0
  6. assets/demo/gt_generate_compare.gif +3 -0
  7. assets/demo/talk_tys_fix_face_post_processing.gif +3 -0
  8. assets/demo/talk_tys_naive_retarget_post_processing.gif +3 -0
  9. assets/demo/talk_tys_offset_retarget_post_processing.gif +3 -0
  10. assets/global_framework.png +0 -0
  11. executed_command.txt +1 -0
  12. git +0 -0
  13. inference.py +291 -0
  14. inference_v2.yaml +35 -0
  15. modules/__init__.py +5 -0
  16. modules/__pycache__/__init__.cpython-310.pyc +0 -0
  17. modules/__pycache__/attention.cpython-310.pyc +0 -0
  18. modules/__pycache__/audio_projection.cpython-310.pyc +0 -0
  19. modules/__pycache__/motion_module.cpython-310.pyc +0 -0
  20. modules/__pycache__/mutual_self_attention.cpython-310.pyc +0 -0
  21. modules/__pycache__/resnet.cpython-310.pyc +0 -0
  22. modules/__pycache__/transformer_2d.cpython-310.pyc +0 -0
  23. modules/__pycache__/transformer_3d.cpython-310.pyc +0 -0
  24. modules/__pycache__/unet_2d_blocks.cpython-310.pyc +0 -0
  25. modules/__pycache__/unet_2d_condition.cpython-310.pyc +0 -0
  26. modules/__pycache__/unet_3d.cpython-310.pyc +0 -0
  27. modules/__pycache__/unet_3d_blocks.cpython-310.pyc +0 -0
  28. modules/__pycache__/v_kps_guider.cpython-310.pyc +0 -0
  29. modules/attention.py +626 -0
  30. modules/audio_projection.py +150 -0
  31. modules/motion_module.py +388 -0
  32. modules/mutual_self_attention.py +376 -0
  33. modules/resnet.py +256 -0
  34. modules/transformer_2d.py +401 -0
  35. modules/transformer_3d.py +169 -0
  36. modules/unet_2d_blocks.py +1072 -0
  37. modules/unet_2d_condition.py +1312 -0
  38. modules/unet_3d.py +698 -0
  39. modules/unet_3d_blocks.py +862 -0
  40. modules/v_kps_guider.py +45 -0
  41. pipelines/__init__.py +1 -0
  42. pipelines/__pycache__/__init__.cpython-310.pyc +0 -0
  43. pipelines/__pycache__/context.cpython-310.pyc +0 -0
  44. pipelines/__pycache__/utils.cpython-310.pyc +0 -0
  45. pipelines/__pycache__/v_express_pipeline.cpython-310.pyc +0 -0
  46. pipelines/context.py +78 -0
  47. pipelines/utils.py +218 -0
  48. pipelines/v_express_pipeline.py +586 -0
  49. requirements.txt +21 -0
  50. scripts/crop.py +108 -0
.gitattributes CHANGED
@@ -33,3 +33,8 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ assets/demo/fix_face_weight.gif filter=lfs diff=lfs merge=lfs -text
37
+ assets/demo/gt_generate_compare.gif filter=lfs diff=lfs merge=lfs -text
38
+ assets/demo/talk_tys_fix_face_post_processing.gif filter=lfs diff=lfs merge=lfs -text
39
+ assets/demo/talk_tys_naive_retarget_post_processing.gif filter=lfs diff=lfs merge=lfs -text
40
+ assets/demo/talk_tys_offset_retarget_post_processing.gif filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ tags:
3
+ - text-to-image
4
+ - stable-diffusion
5
+ - audio-to-video
6
+ license: apache-2.0
7
+ language:
8
+ - en
9
+ library_name: diffusers
10
+ ---
11
+
12
+ # V-Express Model Card
13
+
14
+ <div align="center">
15
+
16
+ [**Project Page**](https://tenvence.github.io/p/v-express/) **|** [**Paper**](https://arxiv.org/abs/2406.02511) **|** [**Code**](https://github.com/tencent-ailab/V-Express)
17
+
18
+ </div>
19
+
20
+ ---
21
+
22
+ ## Introduction
23
+
24
+ ## Models
25
+
26
+ ### Audio Encoder
27
+
28
+ - [model_ckpts/wav2vec2-base-960h](https://huggingface.co/tk93/V-Express/tree/main/model_ckpts/wav2vec2-base-960h). (It is also available from the original model card [facebook/wav2vec2-base-960h](https://huggingface.co/facebook/wav2vec2-base-960h))
29
+
30
+ ### Face Analysis
31
+
32
+ - [model_ckpts/insightface_models/models/buffalo_l](https://huggingface.co/tk93/V-Express/tree/main/model_ckpts/insightface_models/models/buffalo_l). (It is also available from the original repository [insightface/buffalo_l](https://github.com/deepinsight/insightface/releases/download/v0.7/buffalo_l.zip))
33
+
34
+ ### V-Express
35
+
36
+ - [model_ckpts/sd-vae-ft-mse](https://huggingface.co/tk93/V-Express/tree/main/model_ckpts/sd-vae-ft-mse). VAE encoder. (original model card [stabilityai/sd-vae-ft-mse](https://huggingface.co/stabilityai/sd-vae-ft-mse))
37
+ - [model_ckpts/stable-diffusion-v1-5](https://huggingface.co/tk93/V-Express/tree/main/model_ckpts/stable-diffusion-v1-5). Only the model configuration file for unet is needed here. (original model card [runwayml/stable-diffusion-v1-5](https://huggingface.co/runwayml/stable-diffusion-v1-5))
38
+ - [model_ckpts/v-express](https://huggingface.co/tk93/V-Express/tree/main/model_ckpts/v-express). The video generation model conditional on audio and V-kps we call V-Express.
39
+ - You should download and put all `.bin` model to `model_ckpts/v-express` directory, which includes `audio_projection.bin`, `denoising_unet.bin`, `motion_module.bin`, `reference_net.bin`, and `v_kps_guider.bin`.
SECourses.py ADDED
@@ -0,0 +1,393 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import subprocess
3
+ import gradio as gr
4
+ from retinaface import RetinaFace
5
+ from PIL import Image
6
+ import filetype
7
+ from datetime import datetime
8
+ import re
9
+ import sys
10
+ import torch
11
+ import argparse
12
+
13
+ import platform, os
14
+
15
+ def open_folder():
16
+ open_folder_path = os.path.abspath("outputs")
17
+ if platform.system() == "Windows":
18
+ os.startfile(open_folder_path)
19
+ elif platform.system() == "Linux":
20
+ os.system(f'xdg-open "{open_folder_path}"')
21
+
22
+
23
+ # Get the path to the currently activated Python executable
24
+ python_executable = sys.executable
25
+
26
+ def display_media(file):
27
+ # Determine the type of the uploaded file using filetype
28
+ if file is None:
29
+ return gr.update(visible=False), gr.update(visible=False)
30
+ kind = filetype.guess(file.name)
31
+
32
+ if kind is None:
33
+ return gr.update(visible=False), gr.update(visible=False)
34
+
35
+ if kind.mime.startswith('video'):
36
+ return gr.update(value=file.name, visible=True), gr.update(visible=False)
37
+ elif kind.mime.startswith('audio'):
38
+ return gr.update(visible=False), gr.update(value=file.name, visible=True)
39
+ else:
40
+ return gr.update(visible=False), gr.update(visible=False)
41
+
42
+
43
+ parser = argparse.ArgumentParser()
44
+ parser.add_argument("--share", type=str, default=False, help="Set to True to share the app publicly.")
45
+ args = parser.parse_args()
46
+
47
+
48
+ # Function to extract audio from video using FFmpeg
49
+ def extract_audio(video_path, audio_path):
50
+ command = [python_executable, "-m", "ffmpeg", "-i", video_path, "-vn", "-acodec", "libmp3lame", "-q:a", "2", audio_path]
51
+ subprocess.call(command)
52
+
53
+ # Function to convert audio to MP3 using FFmpeg
54
+ def convert_audio_to_mp3(audio_path, mp3_path):
55
+ command = ["ffmpeg", "-i", audio_path, "-acodec", "libmp3lame", "-q:a", "2", mp3_path]
56
+ subprocess.call(command)
57
+
58
+ def crop_and_save_image(image_path, auto_crop, crop_width, crop_height, crop_expansion):
59
+ cropped_image = auto_crop_image(image_path, crop_expansion, crop_size=(crop_width, crop_height))
60
+ if cropped_image is not None:
61
+ cropped_folder = os.path.join("outputs", "cropped_images")
62
+ os.makedirs(cropped_folder, exist_ok=True)
63
+
64
+ # Get the base name and extension of the image file
65
+ base_name, extension = os.path.splitext(os.path.basename(image_path))
66
+
67
+ # Initialize the counter for the image number
68
+ counter = 1
69
+
70
+ # Generate the new image name with the incremented number
71
+ new_image_name = f"{base_name}_{counter:04d}{extension}"
72
+ cropped_image_path = os.path.join(cropped_folder, new_image_name)
73
+
74
+ # Check if the image already exists and increment the counter until a unique name is found
75
+ while os.path.exists(cropped_image_path):
76
+ counter += 1
77
+ new_image_name = f"{base_name}_{counter:04d}{extension}"
78
+ cropped_image_path = os.path.join(cropped_folder, new_image_name)
79
+
80
+ # Save the cropped image with the new name
81
+ cropped_image.save(cropped_image_path, format='PNG')
82
+ return cropped_image_path
83
+ return None
84
+
85
+ # Function to generate kps sequence and audio from video
86
+ def generate_kps_sequence_and_audio(video_path, kps_sequence_save_path, audio_save_path):
87
+ command = [python_executable, "scripts/extract_kps_sequence_and_audio.py", "--video_path", video_path, "--kps_sequence_save_path", kps_sequence_save_path, "--audio_save_path", audio_save_path]
88
+ subprocess.call(command)
89
+
90
+ def auto_crop_image(image_path, expand_percent, crop_size=(512, 512)):
91
+ # Check if CUDA is available
92
+ if torch.cuda.is_available():
93
+ device = 'cuda'
94
+ print("Using GPU for RetinaFace detection.")
95
+ else:
96
+ device = 'cpu'
97
+ print("Using CPU for RetinaFace detection.")
98
+
99
+ # Load image
100
+ img = Image.open(image_path)
101
+
102
+ # Perform face detection
103
+ faces = RetinaFace.detect_faces(image_path)
104
+
105
+ if not faces:
106
+ print("No faces detected.")
107
+ return None
108
+
109
+ # Assuming 'faces' is a dictionary of detected faces
110
+ # Pick the first face detected
111
+ face = list(faces.values())[0]
112
+ landmarks = face['landmarks']
113
+
114
+ # Extract the landmarks
115
+ right_eye = landmarks['right_eye']
116
+ left_eye = landmarks['left_eye']
117
+ right_mouth = landmarks['mouth_right']
118
+ left_mouth = landmarks['mouth_left']
119
+
120
+ # Calculate the distance between the eyes
121
+ eye_distance = abs(right_eye[0] - left_eye[0])
122
+
123
+ # Estimate the head width and height
124
+ head_width = eye_distance * 4.5 # Increase the width multiplier
125
+ head_height = eye_distance * 6.5 # Increase the height multiplier
126
+
127
+ # Calculate the center point between the eyes
128
+ eye_center_x = (right_eye[0] + left_eye[0]) // 2
129
+ eye_center_y = (right_eye[1] + left_eye[1]) // 2
130
+
131
+ # Calculate the top-left and bottom-right coordinates of the assumed head region
132
+ head_left = max(0, int(eye_center_x - head_width // 2))
133
+ head_top = max(0, int(eye_center_y - head_height // 2)) # Adjust the top coordinate
134
+ head_right = min(img.width, int(eye_center_x + head_width // 2))
135
+ head_bottom = min(img.height, int(eye_center_y + head_height // 2)) # Adjust the bottom coordinate
136
+
137
+ # Save the assumed head image
138
+ assumed_head_img = img.crop((head_left, head_top, head_right, head_bottom))
139
+ assumed_head_img.save("assumed_head.png", format='PNG')
140
+
141
+ # Calculate the expansion in pixels and the new dimensions
142
+ expanded_w = int(head_width * (1 + expand_percent))
143
+ expanded_h = int(head_height * (1 + expand_percent))
144
+
145
+ # Calculate the top-left and bottom-right points of the expanded box
146
+ center_x, center_y = head_left + head_width // 2, head_top + head_height // 2
147
+ left = max(0, center_x - expanded_w // 2)
148
+ right = min(img.width, center_x + expanded_w // 2)
149
+ top = max(0, center_y - expanded_h // 2)
150
+ bottom = min(img.height, center_y + expanded_h // 2)
151
+
152
+ # Crop the image with the expanded boundaries
153
+ cropped_img = img.crop((left, top, right, bottom))
154
+ cropped_img.save("expanded_face.png", format='PNG')
155
+
156
+ # Calculate the aspect ratio of the cropped image
157
+ cropped_width, cropped_height = cropped_img.size
158
+ aspect_ratio = cropped_width / cropped_height
159
+
160
+ # Calculate the target dimensions based on the desired crop size
161
+ target_width = crop_size[0]
162
+ target_height = crop_size[1]
163
+
164
+ # Adjust the crop to match the desired aspect ratio
165
+ if aspect_ratio > target_width / target_height:
166
+ # Crop from left and right
167
+ new_width = int(cropped_height * target_width / target_height)
168
+ left_crop = (cropped_width - new_width) // 2
169
+ right_crop = left_crop + new_width
170
+ top_crop = 0
171
+ bottom_crop = cropped_height
172
+ else:
173
+ # Crop from top and bottom
174
+ new_height = int(cropped_width * target_height / target_width)
175
+ top_crop = (cropped_height - new_height) // 2
176
+ bottom_crop = top_crop + new_height
177
+ left_crop = 0
178
+ right_crop = cropped_width
179
+
180
+ # Crop the image with the adjusted boundaries
181
+ final_cropped_img = cropped_img.crop((left_crop, top_crop, right_crop, bottom_crop))
182
+ final_cropped_img.save("final_cropped_img.png", format='PNG')
183
+
184
+ # Resize the cropped image to the desired size (512x512 by default) with best quality
185
+ resized_img = final_cropped_img.resize(crop_size, resample=Image.LANCZOS)
186
+
187
+ # Save the resized image as PNG
188
+ resized_img.save(image_path, format='PNG')
189
+ return resized_img
190
+
191
+
192
+ def generate_output_video(reference_image_path, audio_path, kps_path, output_path, retarget_strategy, num_inference_steps, reference_attention_weight, audio_attention_weight, auto_crop, crop_width, crop_height, crop_expansion,image_width,image_height, low_vram):
193
+ print("auto cropping...")
194
+ if auto_crop:
195
+ auto_crop_image(reference_image_path,crop_expansion, crop_size=(crop_width, crop_height))
196
+
197
+ print("starting inference...")
198
+ command = [
199
+ python_executable, "inference.py",
200
+ "--reference_image_path", reference_image_path,
201
+ "--audio_path", audio_path,
202
+ "--kps_path", kps_path,
203
+ "--output_path", output_path,
204
+ "--retarget_strategy", retarget_strategy,
205
+ "--num_inference_steps", str(num_inference_steps),
206
+ "--reference_attention_weight", str(reference_attention_weight),
207
+ "--audio_attention_weight", str(audio_attention_weight),
208
+ "--image_width", str(image_width),
209
+ "--image_height", str(image_height)
210
+ ]
211
+
212
+ if low_vram: # Add the --save_gpu_memory flag if Low VRAM is checked
213
+ command.append("--save_gpu_memory")
214
+
215
+ with open("executed_command.txt", "w") as file:
216
+ file.write(" ".join(command))
217
+
218
+ subprocess.call(command)
219
+ return output_path, reference_image_path
220
+
221
+ def sanitize_folder_name(name):
222
+ # Define a regex pattern to match invalid characters for both Linux and Windows
223
+ invalid_chars = r'[<>:"/\\|?*\x00-\x1F]'
224
+ # Replace invalid characters with an underscore
225
+ sanitized_name = re.sub(invalid_chars, '_', name)
226
+ return sanitized_name
227
+
228
+ # Function to handle the input and generate the output
229
+ def process_input(reference_image, target_input, retarget_strategy, num_inference_steps, reference_attention_weight, audio_attention_weight, auto_crop, crop_width, crop_height, crop_expansion,image_width,image_height,low_vram):
230
+ # Create temp_process directory for intermediate files
231
+ temp_process_dir = "temp_process"
232
+ os.makedirs(temp_process_dir, exist_ok=True)
233
+
234
+ input_file_name = os.path.splitext(os.path.basename(reference_image))[0]
235
+ input_file_name=sanitize_folder_name(input_file_name)
236
+ timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
237
+ temp_dir = os.path.join(temp_process_dir, f"{input_file_name}_{timestamp}")
238
+ os.makedirs(temp_dir, exist_ok=True)
239
+
240
+ kind = filetype.guess(target_input)
241
+ if not kind:
242
+ raise ValueError("Cannot determine file type. Please provide a valid video or audio file.")
243
+
244
+ mime_type = kind.mime
245
+
246
+ if mime_type.startswith("video/"): # Video input
247
+ audio_path = os.path.join(temp_dir, "target_audio.mp3")
248
+ kps_path = os.path.join(temp_dir, "kps.pth")
249
+ print("generating generate_kps_sequence_and_audio...")
250
+ generate_kps_sequence_and_audio(target_input, kps_path, audio_path)
251
+ elif mime_type.startswith("audio/"): # Audio input
252
+ audio_path = target_input
253
+ if mime_type != "audio/mpeg":
254
+ mp3_path = os.path.join(temp_dir, "target_audio_converted.mp3")
255
+ convert_audio_to_mp3(target_input, mp3_path)
256
+ audio_path = mp3_path
257
+ kps_path = ""
258
+ else:
259
+ raise ValueError("Unsupported file type. Please provide a video or audio file.")
260
+
261
+ output_dir = "outputs"
262
+ os.makedirs(output_dir, exist_ok=True)
263
+ output_file_name = f"{input_file_name}_result_"
264
+ output_file_name=sanitize_folder_name(output_file_name)
265
+ output_file_ext = ".mp4"
266
+ output_file_count = 1
267
+ while os.path.exists(os.path.join(output_dir, f"{output_file_name}{output_file_count:04d}{output_file_ext}")):
268
+ output_file_count += 1
269
+ output_path = os.path.join(output_dir, f"{output_file_name}{output_file_count:04d}{output_file_ext}")
270
+
271
+
272
+ output_video_path, cropped_image_path = generate_output_video(reference_image, audio_path, kps_path, output_path, retarget_strategy, num_inference_steps, reference_attention_weight, audio_attention_weight, auto_crop,crop_width,crop_height, crop_expansion,image_width,image_height,low_vram)
273
+
274
+ return output_video_path, cropped_image_path
275
+
276
+ def launch_interface():
277
+ retarget_strategies = ["fix_face", "no_retarget", "offset_retarget", "naive_retarget"]
278
+
279
+ with gr.Blocks() as demo:
280
+ gr.Markdown("# Tencent AI Lab - V-Express Image to Animation V4 : https://www.patreon.com/posts/105251204")
281
+ with gr.Row():
282
+ with gr.Column():
283
+ input_image = gr.Image(label="Reference Image", format="png", type="filepath", height=512)
284
+ generate_button = gr.Button("Generate Talking Video")
285
+ low_vram = gr.Checkbox(label="Low VRAM - Greatly reduces VRAM usage but takes longer", value=False,visible=False)
286
+ crop_button = gr.Button("Crop Image")
287
+ with gr.Row():
288
+
289
+ with gr.Column(min_width=0):
290
+ image_width = gr.Number(label="Target Video Width", value=512)
291
+
292
+ with gr.Column(min_width=0):
293
+ image_height = gr.Number(label="Target Video Height", value=512)
294
+
295
+ with gr.Row():
296
+ with gr.Column(min_width=0):
297
+ retarget_strategy = gr.Dropdown(retarget_strategies, label="Retarget Strategy", value="fix_face")
298
+ with gr.Column(min_width=0):
299
+ inference_steps = gr.Slider(10, 90, step=1, label="Number of Inference Steps", value=30)
300
+
301
+ with gr.Row():
302
+ with gr.Column(min_width=0):
303
+ reference_attention = gr.Slider(0.80, 1.1, step=0.01, label="Reference Attention Weight", value=0.95)
304
+ with gr.Column(min_width=0):
305
+ audio_attention = gr.Slider(1.0, 5.0, step=0.1, label="Audio Attention Weight", value=3.0)
306
+
307
+ with gr.Row(visible=True) as crop_size_row:
308
+ with gr.Column(min_width=0):
309
+ auto_crop = gr.Checkbox(label="Auto Crop Image", value=True)
310
+ with gr.Column(min_width=0):
311
+ crop_expansion = gr.Slider(0.0, 1.0, step=0.01, label="Face Focus Expansion Percent", value=0.15)
312
+ with gr.Row():
313
+ with gr.Column(min_width=0):
314
+ crop_width = gr.Number(label="Crop Width", value=512)
315
+ with gr.Column(min_width=0):
316
+ crop_height = gr.Number(label="Crop Height", value=512)
317
+
318
+ with gr.Column():
319
+ input_video = gr.File(
320
+ label="Target Input (Image or Video)",
321
+ type="filepath",
322
+ file_count="single",
323
+ file_types=[
324
+ ".mp4", ".avi", ".mov", ".wmv", ".flv", ".mkv", ".webm", # Video extensions
325
+ ".3gp", ".m4v", ".mpg", ".mpeg", ".m2v", ".m4v", ".mts", # More video extensions
326
+ ".mp3", ".wav", ".aac", ".flac", ".m4a", ".wma", ".ogg" # Audio extensions
327
+ ],
328
+ height=512 )
329
+ video_output = gr.Video(visible=False)
330
+ audio_output = gr.Audio(visible=False)
331
+
332
+ input_video.change(display_media, inputs=input_video, outputs=[video_output, audio_output])
333
+ btn_open_outputs = gr.Button("Open Outputs Folder")
334
+ btn_open_outputs.click(fn=open_folder)
335
+ gr.Markdown("""
336
+
337
+ Retarget Strategies
338
+
339
+ Only target audio : fix_face
340
+
341
+ Input picture and target video (same person - best practice) select : no_retarget
342
+
343
+ Input picture and target video (different person) select : offset_retarget or naive_retarget
344
+
345
+ Please look examples in Tests folder to see which settings you like most. I feel like offset_retarget is best
346
+
347
+ You can turn up reference_attention_weight to make the model maintain higher character consistency, and turn down audio_attention_weight to reduce mouth artifacts. E.g. setting both values to 1.0
348
+ """)
349
+
350
+
351
+
352
+ with gr.Column():
353
+ output_video = gr.Video(label="Generated Video", height=512)
354
+ output_image = gr.Image(label="Cropped Image")
355
+
356
+
357
+ generate_button.click(
358
+ fn=process_input,
359
+ inputs=[
360
+ input_image,
361
+ input_video,
362
+ retarget_strategy,
363
+ inference_steps,
364
+ reference_attention,
365
+ audio_attention,
366
+ auto_crop,
367
+ crop_width,
368
+ crop_height,
369
+ crop_expansion,
370
+ image_width,
371
+ image_height,
372
+ low_vram
373
+ ],
374
+ outputs=[output_video, output_image]
375
+ )
376
+
377
+ crop_button.click(
378
+ fn=crop_and_save_image,
379
+ inputs=[
380
+ input_image,
381
+ auto_crop,
382
+ crop_width,
383
+ crop_height,
384
+ crop_expansion
385
+ ],
386
+ outputs=output_image
387
+ )
388
+
389
+ demo.queue()
390
+ demo.launch(inbrowser=True,share=args.share)
391
+
392
+ # Run the Gradio interface
393
+ launch_interface()
assets/crop_example.jpeg ADDED
assets/demo/fix_face_weight.gif ADDED

Git LFS Details

  • SHA256: ff0a3eca118822dd6676b1856854be99ad7a11a571f08e735a2777c13613db5f
  • Pointer size: 133 Bytes
  • Size of remote file: 41.7 MB
assets/demo/gt_generate_compare.gif ADDED

Git LFS Details

  • SHA256: 943a0b94df873eeffe1334148fb8935251e3a200cf4bd7631aae89c718aab5e3
  • Pointer size: 133 Bytes
  • Size of remote file: 13.3 MB
assets/demo/talk_tys_fix_face_post_processing.gif ADDED

Git LFS Details

  • SHA256: 20924e3ab2e7e7d32fe5ab0adc70d4918fefac11e3fa2da77f265a5cc63296cb
  • Pointer size: 132 Bytes
  • Size of remote file: 3.96 MB
assets/demo/talk_tys_naive_retarget_post_processing.gif ADDED

Git LFS Details

  • SHA256: 5d476f224799620c161ac199eb3748312e385c5f169435db9f677c3cdb10709e
  • Pointer size: 132 Bytes
  • Size of remote file: 7.67 MB
assets/demo/talk_tys_offset_retarget_post_processing.gif ADDED

Git LFS Details

  • SHA256: 0132d055bd25875a0e70dd4828afc0273f754293c083cbd0051f6895d6722c41
  • Pointer size: 132 Bytes
  • Size of remote file: 7.32 MB
assets/global_framework.png ADDED
executed_command.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ R:\V_Express_Installers_V6\V-Express\venv\Scripts\python.exe inference.py --reference_image_path C:\Users\King\AppData\Local\Temp\gradio\e57288fe0b82a668f63c481f82c0fc65eaa84618\Biden_Photo_Big.png --audio_path temp_process\Biden_Photo_Big_20240620112047\target_audio.mp3 --kps_path temp_process\Biden_Photo_Big_20240620112047\kps.pth --output_path outputs\Biden_Photo_Big_result_0006.mp4 --retarget_strategy offset_retarget --num_inference_steps 30 --reference_attention_weight 0.95 --audio_attention_weight 3 --image_width 512 --image_height 512
git ADDED
File without changes
inference.py ADDED
@@ -0,0 +1,291 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import os
3
+ import time
4
+
5
+ import accelerate
6
+ import cv2
7
+ import numpy as np
8
+ import torch
9
+ import torchaudio.functional
10
+ import torchvision.io
11
+ from PIL import Image
12
+ from diffusers import AutoencoderKL, DDIMScheduler
13
+ from diffusers.utils.import_utils import is_xformers_available
14
+ from insightface.app import FaceAnalysis
15
+ from omegaconf import OmegaConf
16
+ from transformers import Wav2Vec2Model, Wav2Vec2Processor
17
+
18
+ from modules import UNet2DConditionModel, UNet3DConditionModel, VKpsGuider, AudioProjection
19
+ from pipelines import VExpressPipeline
20
+ from pipelines.utils import draw_kps_image, save_video
21
+ from pipelines.utils import retarget_kps
22
+
23
+
24
+ def parse_args():
25
+ parser = argparse.ArgumentParser()
26
+
27
+ parser.add_argument('--unet_config_path', type=str, default='./model_ckpts/stable-diffusion-v1-5/unet/config.json')
28
+ parser.add_argument('--vae_path', type=str, default='./model_ckpts/sd-vae-ft-mse/')
29
+ parser.add_argument('--audio_encoder_path', type=str, default='./model_ckpts/wav2vec2-base-960h/')
30
+ parser.add_argument('--insightface_model_path', type=str, default='./model_ckpts/insightface_models/')
31
+
32
+ parser.add_argument('--denoising_unet_path', type=str, default='./model_ckpts/v-express/denoising_unet.bin')
33
+ parser.add_argument('--reference_net_path', type=str, default='./model_ckpts/v-express/reference_net.bin')
34
+ parser.add_argument('--v_kps_guider_path', type=str, default='./model_ckpts/v-express/v_kps_guider.bin')
35
+ parser.add_argument('--audio_projection_path', type=str, default='./model_ckpts/v-express/audio_projection.bin')
36
+ parser.add_argument('--motion_module_path', type=str, default='./model_ckpts/v-express/motion_module.bin')
37
+
38
+ parser.add_argument('--retarget_strategy', type=str, default='fix_face',
39
+ help='{fix_face, no_retarget, offset_retarget, naive_retarget}')
40
+
41
+ parser.add_argument('--dtype', type=str, default='fp16')
42
+ parser.add_argument('--device', type=str, default='cuda')
43
+ parser.add_argument('--gpu_id', type=int, default=0)
44
+ parser.add_argument('--do_multi_devices_inference', action='store_true')
45
+ parser.add_argument('--save_gpu_memory', action='store_true')
46
+
47
+ parser.add_argument('--num_pad_audio_frames', type=int, default=2)
48
+ parser.add_argument('--standard_audio_sampling_rate', type=int, default=16000)
49
+
50
+ parser.add_argument('--reference_image_path', type=str, default='./test_samples/emo/talk_emotion/ref.jpg')
51
+ parser.add_argument('--audio_path', type=str, default='./test_samples/emo/talk_emotion/aud.mp3')
52
+ parser.add_argument('--kps_path', type=str, default='./test_samples/emo/talk_emotion/kps.pth')
53
+ parser.add_argument('--output_path', type=str, default='./output/emo/talk_emotion.mp4')
54
+
55
+ parser.add_argument('--image_width', type=int, default=512)
56
+ parser.add_argument('--image_height', type=int, default=512)
57
+ parser.add_argument('--fps', type=float, default=30.0)
58
+ parser.add_argument('--seed', type=int, default=42)
59
+ parser.add_argument('--num_inference_steps', type=int, default=25)
60
+ parser.add_argument('--guidance_scale', type=float, default=3.5)
61
+ parser.add_argument('--context_frames', type=int, default=12)
62
+ parser.add_argument('--context_overlap', type=int, default=4)
63
+ parser.add_argument('--reference_attention_weight', default=0.95, type=float)
64
+ parser.add_argument('--audio_attention_weight', default=3., type=float)
65
+
66
+ args = parser.parse_args()
67
+
68
+ return args
69
+
70
+
71
+ def load_reference_net(unet_config_path, reference_net_path, dtype, device):
72
+ reference_net = UNet2DConditionModel.from_config(unet_config_path).to(dtype=dtype, device=device)
73
+ reference_net.load_state_dict(torch.load(reference_net_path, map_location="cpu"), strict=False)
74
+ print(f'Loaded weights of Reference Net from {reference_net_path}.')
75
+ return reference_net
76
+
77
+
78
+ def load_denoising_unet(inf_config_path, unet_config_path, denoising_unet_path, motion_module_path, dtype, device):
79
+ inference_config = OmegaConf.load(inf_config_path)
80
+ denoising_unet = UNet3DConditionModel.from_config_2d(
81
+ unet_config_path,
82
+ unet_additional_kwargs=inference_config.unet_additional_kwargs,
83
+ ).to(dtype=dtype, device=device)
84
+ denoising_unet.load_state_dict(torch.load(denoising_unet_path, map_location="cpu"), strict=False)
85
+ print(f'Loaded weights of Denoising U-Net from {denoising_unet_path}.')
86
+
87
+ denoising_unet.load_state_dict(torch.load(motion_module_path, map_location="cpu"), strict=False)
88
+ print(f'Loaded weights of Denoising U-Net Motion Module from {motion_module_path}.')
89
+
90
+ return denoising_unet
91
+
92
+
93
+ def load_v_kps_guider(v_kps_guider_path, dtype, device):
94
+ v_kps_guider = VKpsGuider(320, block_out_channels=(16, 32, 96, 256)).to(dtype=dtype, device=device)
95
+ v_kps_guider.load_state_dict(torch.load(v_kps_guider_path, map_location="cpu"))
96
+ print(f'Loaded weights of V-Kps Guider from {v_kps_guider_path}.')
97
+ return v_kps_guider
98
+
99
+
100
+ def load_audio_projection(
101
+ audio_projection_path,
102
+ dtype,
103
+ device,
104
+ inp_dim: int,
105
+ mid_dim: int,
106
+ out_dim: int,
107
+ inp_seq_len: int,
108
+ out_seq_len: int,
109
+ ):
110
+ audio_projection = AudioProjection(
111
+ dim=mid_dim,
112
+ depth=4,
113
+ dim_head=64,
114
+ heads=12,
115
+ num_queries=out_seq_len,
116
+ embedding_dim=inp_dim,
117
+ output_dim=out_dim,
118
+ ff_mult=4,
119
+ max_seq_len=inp_seq_len,
120
+ ).to(dtype=dtype, device=device)
121
+ audio_projection.load_state_dict(torch.load(audio_projection_path, map_location='cpu'))
122
+ print(f'Loaded weights of Audio Projection from {audio_projection_path}.')
123
+ return audio_projection
124
+
125
+
126
+ def get_scheduler(inference_config_path):
127
+ inference_config = OmegaConf.load(inference_config_path)
128
+ scheduler_kwargs = OmegaConf.to_container(inference_config.noise_scheduler_kwargs)
129
+ scheduler = DDIMScheduler(**scheduler_kwargs)
130
+ return scheduler
131
+
132
+
133
+ def main():
134
+ args = parse_args()
135
+ start_time = time.time()
136
+
137
+ if not args.do_multi_devices_inference:
138
+ # TODO
139
+ accelerator = None
140
+ device = torch.device(f'{args.device}:{args.gpu_id}' if args.device == 'cuda' else args.device)
141
+ else:
142
+ accelerator = accelerate.Accelerator()
143
+ device = torch.device(f'cuda:{accelerator.process_index}')
144
+ dtype = torch.float16 if args.dtype == 'fp16' else torch.float32
145
+
146
+ vae_path = args.vae_path
147
+ audio_encoder_path = args.audio_encoder_path
148
+
149
+ vae = AutoencoderKL.from_pretrained(vae_path).to(dtype=dtype, device=device)
150
+ audio_encoder = Wav2Vec2Model.from_pretrained(audio_encoder_path).to(dtype=dtype, device=device)
151
+ audio_processor = Wav2Vec2Processor.from_pretrained(audio_encoder_path)
152
+
153
+ unet_config_path = args.unet_config_path
154
+ reference_net_path = args.reference_net_path
155
+ denoising_unet_path = args.denoising_unet_path
156
+ v_kps_guider_path = args.v_kps_guider_path
157
+ audio_projection_path = args.audio_projection_path
158
+ motion_module_path = args.motion_module_path
159
+
160
+ inference_config_path = './inference_v2.yaml'
161
+ scheduler = get_scheduler(inference_config_path)
162
+ reference_net = load_reference_net(unet_config_path, reference_net_path, dtype, device)
163
+ denoising_unet = load_denoising_unet(
164
+ inference_config_path, unet_config_path, denoising_unet_path, motion_module_path,
165
+ dtype, device
166
+ )
167
+ v_kps_guider = load_v_kps_guider(v_kps_guider_path, dtype, device)
168
+ audio_projection = load_audio_projection(
169
+ audio_projection_path,
170
+ dtype,
171
+ device,
172
+ inp_dim=denoising_unet.config.cross_attention_dim,
173
+ mid_dim=denoising_unet.config.cross_attention_dim,
174
+ out_dim=denoising_unet.config.cross_attention_dim,
175
+ inp_seq_len=2 * (2 * args.num_pad_audio_frames + 1),
176
+ out_seq_len=2 * args.num_pad_audio_frames + 1,
177
+ )
178
+
179
+ if is_xformers_available():
180
+ reference_net.enable_xformers_memory_efficient_attention()
181
+ denoising_unet.enable_xformers_memory_efficient_attention()
182
+ else:
183
+ raise ValueError("xformers is not available. Make sure it is installed correctly")
184
+
185
+ generator = torch.manual_seed(args.seed)
186
+ pipeline = VExpressPipeline(
187
+ vae=vae,
188
+ reference_net=reference_net,
189
+ denoising_unet=denoising_unet,
190
+ v_kps_guider=v_kps_guider,
191
+ audio_processor=audio_processor,
192
+ audio_encoder=audio_encoder,
193
+ audio_projection=audio_projection,
194
+ scheduler=scheduler,
195
+ ).to(dtype=dtype, device=device)
196
+
197
+ app = FaceAnalysis(
198
+ providers=['CUDAExecutionProvider' if args.device == 'cuda' else 'CPUExecutionProvider'],
199
+ provider_options=[{'device_id': args.gpu_id}] if args.device == 'cuda' else [],
200
+ root=args.insightface_model_path,
201
+ )
202
+ app.prepare(ctx_id=0, det_size=(args.image_height, args.image_width))
203
+
204
+ reference_image = Image.open(args.reference_image_path).convert('RGB')
205
+ reference_image = reference_image.resize((args.image_height, args.image_width))
206
+
207
+ reference_image_for_kps = cv2.imread(args.reference_image_path)
208
+ reference_image_for_kps = cv2.resize(reference_image_for_kps, (args.image_width, args.image_height))
209
+ reference_kps = app.get(reference_image_for_kps)[0].kps[:3]
210
+ if args.save_gpu_memory:
211
+ del app
212
+ torch.cuda.empty_cache()
213
+
214
+ _, audio_waveform, meta_info = torchvision.io.read_video(os.path.join(os.path.dirname(args.audio_path), os.path.basename(args.audio_path)), pts_unit='sec')
215
+ audio_sampling_rate = meta_info['audio_fps']
216
+ print(f'Length of audio is {audio_waveform.shape[1]} with the sampling rate of {audio_sampling_rate}.')
217
+ if audio_sampling_rate != args.standard_audio_sampling_rate:
218
+ audio_waveform = torchaudio.functional.resample(
219
+ audio_waveform,
220
+ orig_freq=audio_sampling_rate,
221
+ new_freq=args.standard_audio_sampling_rate,
222
+ )
223
+ audio_waveform = audio_waveform.mean(dim=0)
224
+
225
+ duration = audio_waveform.shape[0] / args.standard_audio_sampling_rate
226
+ init_video_length = int(duration * args.fps)
227
+ num_contexts = np.around((init_video_length + args.context_overlap) / args.context_frames)
228
+ video_length = int(num_contexts * args.context_frames - args.context_overlap)
229
+ fps = video_length / duration
230
+ print(f'The corresponding video length is {video_length}.')
231
+
232
+ kps_sequence = None
233
+ if args.kps_path != "":
234
+ assert os.path.exists(args.kps_path), f'{args.kps_path} does not exist'
235
+ kps_sequence = torch.tensor(torch.load(args.kps_path)) # [len, 3, 2]
236
+ print(f'The original length of kps sequence is {kps_sequence.shape[0]}.')
237
+
238
+ if kps_sequence.shape[0] > video_length:
239
+ kps_sequence = kps_sequence[:video_length, :, :]
240
+
241
+ kps_sequence = torch.nn.functional.interpolate(kps_sequence.permute(1, 2, 0), size=video_length, mode='linear')
242
+ kps_sequence = kps_sequence.permute(2, 0, 1)
243
+ print(f'The interpolated length of kps sequence is {kps_sequence.shape[0]}.')
244
+
245
+ retarget_strategy = args.retarget_strategy
246
+ if retarget_strategy == 'fix_face':
247
+ kps_sequence = torch.tensor([reference_kps] * video_length)
248
+ elif retarget_strategy == 'no_retarget':
249
+ kps_sequence = kps_sequence
250
+ elif retarget_strategy == 'offset_retarget':
251
+ kps_sequence = retarget_kps(reference_kps, kps_sequence, only_offset=True)
252
+ elif retarget_strategy == 'naive_retarget':
253
+ kps_sequence = retarget_kps(reference_kps, kps_sequence, only_offset=False)
254
+ else:
255
+ raise ValueError(f'The retarget strategy {retarget_strategy} is not supported.')
256
+
257
+ kps_images = []
258
+ for i in range(video_length):
259
+ kps_image = draw_kps_image(args.image_height, args.image_width, kps_sequence[i])
260
+ kps_images.append(Image.fromarray(kps_image))
261
+
262
+ video_tensor = pipeline(
263
+ reference_image=reference_image,
264
+ kps_images=kps_images,
265
+ audio_waveform=audio_waveform,
266
+ width=args.image_width,
267
+ height=args.image_height,
268
+ video_length=video_length,
269
+ num_inference_steps=args.num_inference_steps,
270
+ guidance_scale=args.guidance_scale,
271
+ context_frames=args.context_frames,
272
+ context_overlap=args.context_overlap,
273
+ reference_attention_weight=args.reference_attention_weight,
274
+ audio_attention_weight=args.audio_attention_weight,
275
+ num_pad_audio_frames=args.num_pad_audio_frames,
276
+ generator=generator,
277
+ do_multi_devices_inference=args.do_multi_devices_inference,
278
+ save_gpu_memory=args.save_gpu_memory,
279
+ )
280
+
281
+ if accelerator is None or accelerator.is_main_process:
282
+ save_video(video_tensor, args.audio_path, args.output_path, device, fps)
283
+ consumed_time = time.time() - start_time
284
+ generation_fps = video_tensor.shape[2] / consumed_time
285
+ print(f'The generated video has been saved at {args.output_path}. '
286
+ f'The generation time is {consumed_time:.1f} seconds. '
287
+ f'The generation FPS is {generation_fps:.2f}.')
288
+
289
+
290
+ if __name__ == '__main__':
291
+ main()
inference_v2.yaml ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ unet_additional_kwargs:
2
+ use_inflated_groupnorm: true
3
+ unet_use_cross_frame_attention: false
4
+ unet_use_temporal_attention: false
5
+ use_motion_module: true
6
+ motion_module_resolutions:
7
+ - 1
8
+ - 2
9
+ - 4
10
+ - 8
11
+ motion_module_mid_block: true
12
+ motion_module_decoder_only: false
13
+ motion_module_type: Vanilla
14
+ motion_module_kwargs:
15
+ num_attention_heads: 8
16
+ num_transformer_block: 1
17
+ attention_block_types:
18
+ - Temporal_Self
19
+ - Temporal_Self
20
+ temporal_position_encoding: true
21
+ temporal_position_encoding_max_len: 32
22
+ temporal_attention_dim_div: 1
23
+
24
+ noise_scheduler_kwargs:
25
+ beta_start: 0.00085
26
+ beta_end: 0.012
27
+ beta_schedule: "scaled_linear"
28
+ clip_sample: false
29
+ steps_offset: 1
30
+ ### Zero-SNR params
31
+ prediction_type: "v_prediction"
32
+ rescale_betas_zero_snr: True
33
+ timestep_spacing: "trailing"
34
+
35
+ sampler: DDIM
modules/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ from .unet_2d_condition import UNet2DConditionModel
2
+ from .unet_3d import UNet3DConditionModel
3
+ from .v_kps_guider import VKpsGuider
4
+ from .audio_projection import AudioProjection
5
+ from .mutual_self_attention import ReferenceAttentionControl
modules/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (437 Bytes). View file
 
modules/__pycache__/attention.cpython-310.pyc ADDED
Binary file (10.9 kB). View file
 
modules/__pycache__/audio_projection.cpython-310.pyc ADDED
Binary file (4.29 kB). View file
 
modules/__pycache__/motion_module.cpython-310.pyc ADDED
Binary file (8.67 kB). View file
 
modules/__pycache__/mutual_self_attention.cpython-310.pyc ADDED
Binary file (7.2 kB). View file
 
modules/__pycache__/resnet.cpython-310.pyc ADDED
Binary file (5.44 kB). View file
 
modules/__pycache__/transformer_2d.cpython-310.pyc ADDED
Binary file (12.4 kB). View file
 
modules/__pycache__/transformer_3d.cpython-310.pyc ADDED
Binary file (4.07 kB). View file
 
modules/__pycache__/unet_2d_blocks.cpython-310.pyc ADDED
Binary file (21.7 kB). View file
 
modules/__pycache__/unet_2d_condition.cpython-310.pyc ADDED
Binary file (37.6 kB). View file
 
modules/__pycache__/unet_3d.cpython-310.pyc ADDED
Binary file (16.6 kB). View file
 
modules/__pycache__/unet_3d_blocks.cpython-310.pyc ADDED
Binary file (13 kB). View file
 
modules/__pycache__/v_kps_guider.cpython-310.pyc ADDED
Binary file (1.54 kB). View file
 
modules/attention.py ADDED
@@ -0,0 +1,626 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention.py
2
+
3
+ from typing import Any, Dict, Optional
4
+
5
+ import torch
6
+ from diffusers.models.attention import AdaLayerNorm, AdaLayerNormZero, Attention, FeedForward, GatedSelfAttentionDense
7
+ from diffusers.models.embeddings import SinusoidalPositionalEmbedding
8
+ from einops import rearrange
9
+ from torch import nn
10
+
11
+
12
+ class BasicTransformerBlock(nn.Module):
13
+ r"""
14
+ A basic Transformer block.
15
+
16
+ Parameters:
17
+ dim (`int`): The number of channels in the input and output.
18
+ num_attention_heads (`int`): The number of heads to use for multi-head attention.
19
+ attention_head_dim (`int`): The number of channels in each head.
20
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
21
+ cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
22
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
23
+ num_embeds_ada_norm (:
24
+ obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`.
25
+ attention_bias (:
26
+ obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
27
+ only_cross_attention (`bool`, *optional*):
28
+ Whether to use only cross-attention layers. In this case two cross attention layers are used.
29
+ double_self_attention (`bool`, *optional*):
30
+ Whether to use two self-attention layers. In this case no cross attention layers are used.
31
+ upcast_attention (`bool`, *optional*):
32
+ Whether to upcast the attention computation to float32. This is useful for mixed precision training.
33
+ norm_elementwise_affine (`bool`, *optional*, defaults to `True`):
34
+ Whether to use learnable elementwise affine parameters for normalization.
35
+ norm_type (`str`, *optional*, defaults to `"layer_norm"`):
36
+ The normalization layer to use. Can be `"layer_norm"`, `"ada_norm"` or `"ada_norm_zero"`.
37
+ final_dropout (`bool` *optional*, defaults to False):
38
+ Whether to apply a final dropout after the last feed-forward layer.
39
+ attention_type (`str`, *optional*, defaults to `"default"`):
40
+ The type of attention to use. Can be `"default"` or `"gated"` or `"gated-text-image"`.
41
+ positional_embeddings (`str`, *optional*, defaults to `None`):
42
+ The type of positional embeddings to apply to.
43
+ num_positional_embeddings (`int`, *optional*, defaults to `None`):
44
+ The maximum number of positional embeddings to apply.
45
+ """
46
+
47
+ def __init__(
48
+ self,
49
+ dim: int,
50
+ num_attention_heads: int,
51
+ attention_head_dim: int,
52
+ dropout=0.0,
53
+ cross_attention_dim: Optional[int] = None,
54
+ activation_fn: str = "geglu",
55
+ num_embeds_ada_norm: Optional[int] = None,
56
+ attention_bias: bool = False,
57
+ only_cross_attention: bool = False,
58
+ double_self_attention: bool = False,
59
+ upcast_attention: bool = False,
60
+ norm_elementwise_affine: bool = True,
61
+ norm_type: str = "layer_norm", # 'layer_norm', 'ada_norm', 'ada_norm_zero', 'ada_norm_single'
62
+ norm_eps: float = 1e-5,
63
+ final_dropout: bool = False,
64
+ attention_type: str = "default",
65
+ positional_embeddings: Optional[str] = None,
66
+ num_positional_embeddings: Optional[int] = None,
67
+ ):
68
+ super().__init__()
69
+ self.only_cross_attention = only_cross_attention
70
+
71
+ self.use_ada_layer_norm_zero = (
72
+ num_embeds_ada_norm is not None
73
+ ) and norm_type == "ada_norm_zero"
74
+ self.use_ada_layer_norm = (
75
+ num_embeds_ada_norm is not None
76
+ ) and norm_type == "ada_norm"
77
+ self.use_ada_layer_norm_single = norm_type == "ada_norm_single"
78
+ self.use_layer_norm = norm_type == "layer_norm"
79
+
80
+ if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
81
+ raise ValueError(
82
+ f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
83
+ f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
84
+ )
85
+
86
+ if positional_embeddings and (num_positional_embeddings is None):
87
+ raise ValueError(
88
+ "If `positional_embedding` type is defined, `num_positition_embeddings` must also be defined."
89
+ )
90
+
91
+ if positional_embeddings == "sinusoidal":
92
+ self.pos_embed = SinusoidalPositionalEmbedding(
93
+ dim, max_seq_length=num_positional_embeddings
94
+ )
95
+ else:
96
+ self.pos_embed = None
97
+
98
+ # Define 3 blocks. Each block has its own normalization layer.
99
+ # 1. Self-Attn
100
+ if self.use_ada_layer_norm:
101
+ self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
102
+ elif self.use_ada_layer_norm_zero:
103
+ self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
104
+ else:
105
+ self.norm1 = nn.LayerNorm(
106
+ dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps
107
+ )
108
+
109
+ self.attn1 = Attention(
110
+ query_dim=dim,
111
+ heads=num_attention_heads,
112
+ dim_head=attention_head_dim,
113
+ dropout=dropout,
114
+ bias=attention_bias,
115
+ cross_attention_dim=cross_attention_dim if only_cross_attention else None,
116
+ upcast_attention=upcast_attention,
117
+ )
118
+
119
+ # 2. Cross-Attn
120
+ if cross_attention_dim is not None or double_self_attention:
121
+ # We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
122
+ # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
123
+ # the second cross attention block.
124
+ self.norm2 = (
125
+ AdaLayerNorm(dim, num_embeds_ada_norm)
126
+ if self.use_ada_layer_norm
127
+ else nn.LayerNorm(
128
+ dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps
129
+ )
130
+ )
131
+ self.attn2 = Attention(
132
+ query_dim=dim,
133
+ cross_attention_dim=cross_attention_dim
134
+ if not double_self_attention
135
+ else None,
136
+ heads=num_attention_heads,
137
+ dim_head=attention_head_dim,
138
+ dropout=dropout,
139
+ bias=attention_bias,
140
+ upcast_attention=upcast_attention,
141
+ ) # is self-attn if encoder_hidden_states is none
142
+ else:
143
+ self.norm2 = None
144
+ self.attn2 = None
145
+
146
+ # 3. Feed-forward
147
+ if not self.use_ada_layer_norm_single:
148
+ self.norm3 = nn.LayerNorm(
149
+ dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps
150
+ )
151
+
152
+ self.ff = FeedForward(
153
+ dim,
154
+ dropout=dropout,
155
+ activation_fn=activation_fn,
156
+ final_dropout=final_dropout,
157
+ )
158
+
159
+ # 4. Fuser
160
+ if attention_type == "gated" or attention_type == "gated-text-image":
161
+ self.fuser = GatedSelfAttentionDense(
162
+ dim, cross_attention_dim, num_attention_heads, attention_head_dim
163
+ )
164
+
165
+ # 5. Scale-shift for PixArt-Alpha.
166
+ if self.use_ada_layer_norm_single:
167
+ self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5)
168
+
169
+ # let chunk size default to None
170
+ self._chunk_size = None
171
+ self._chunk_dim = 0
172
+
173
+ def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0):
174
+ # Sets chunk feed-forward
175
+ self._chunk_size = chunk_size
176
+ self._chunk_dim = dim
177
+
178
+ def forward(
179
+ self,
180
+ hidden_states: torch.FloatTensor,
181
+ attention_mask: Optional[torch.FloatTensor] = None,
182
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
183
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
184
+ timestep: Optional[torch.LongTensor] = None,
185
+ cross_attention_kwargs: Dict[str, Any] = None,
186
+ class_labels: Optional[torch.LongTensor] = None,
187
+ ) -> torch.FloatTensor:
188
+ # Notice that normalization is always applied before the real computation in the following blocks.
189
+ # 0. Self-Attention
190
+ batch_size = hidden_states.shape[0]
191
+
192
+ if self.use_ada_layer_norm:
193
+ norm_hidden_states = self.norm1(hidden_states, timestep)
194
+ elif self.use_ada_layer_norm_zero:
195
+ norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
196
+ hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
197
+ )
198
+ elif self.use_layer_norm:
199
+ norm_hidden_states = self.norm1(hidden_states)
200
+ elif self.use_ada_layer_norm_single:
201
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
202
+ self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1)
203
+ ).chunk(6, dim=1)
204
+ norm_hidden_states = self.norm1(hidden_states)
205
+ norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
206
+ norm_hidden_states = norm_hidden_states.squeeze(1)
207
+ else:
208
+ raise ValueError("Incorrect norm used")
209
+
210
+ if self.pos_embed is not None:
211
+ norm_hidden_states = self.pos_embed(norm_hidden_states)
212
+
213
+ # 1. Retrieve lora scale.
214
+ lora_scale = (
215
+ cross_attention_kwargs.get("scale", 1.0)
216
+ if cross_attention_kwargs is not None
217
+ else 1.0
218
+ )
219
+
220
+ # 2. Prepare GLIGEN inputs
221
+ cross_attention_kwargs = (
222
+ cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
223
+ )
224
+ gligen_kwargs = cross_attention_kwargs.pop("gligen", None)
225
+
226
+ attn_output = self.attn1(
227
+ norm_hidden_states,
228
+ encoder_hidden_states=encoder_hidden_states
229
+ if self.only_cross_attention
230
+ else None,
231
+ attention_mask=attention_mask,
232
+ **cross_attention_kwargs,
233
+ )
234
+ if self.use_ada_layer_norm_zero:
235
+ attn_output = gate_msa.unsqueeze(1) * attn_output
236
+ elif self.use_ada_layer_norm_single:
237
+ attn_output = gate_msa * attn_output
238
+
239
+ hidden_states = attn_output + hidden_states
240
+ if hidden_states.ndim == 4:
241
+ hidden_states = hidden_states.squeeze(1)
242
+
243
+ # 2.5 GLIGEN Control
244
+ if gligen_kwargs is not None:
245
+ hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"])
246
+
247
+ # 3. Cross-Attention
248
+ if self.attn2 is not None:
249
+ if self.use_ada_layer_norm:
250
+ norm_hidden_states = self.norm2(hidden_states, timestep)
251
+ elif self.use_ada_layer_norm_zero or self.use_layer_norm:
252
+ norm_hidden_states = self.norm2(hidden_states)
253
+ elif self.use_ada_layer_norm_single:
254
+ # For PixArt norm2 isn't applied here:
255
+ # https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L70C1-L76C103
256
+ norm_hidden_states = hidden_states
257
+ else:
258
+ raise ValueError("Incorrect norm")
259
+
260
+ if self.pos_embed is not None and self.use_ada_layer_norm_single is False:
261
+ norm_hidden_states = self.pos_embed(norm_hidden_states)
262
+
263
+ attn_output = self.attn2(
264
+ norm_hidden_states,
265
+ encoder_hidden_states=encoder_hidden_states,
266
+ attention_mask=encoder_attention_mask,
267
+ **cross_attention_kwargs,
268
+ )
269
+ hidden_states = attn_output + hidden_states
270
+
271
+ # 4. Feed-forward
272
+ if not self.use_ada_layer_norm_single:
273
+ norm_hidden_states = self.norm3(hidden_states)
274
+
275
+ if self.use_ada_layer_norm_zero:
276
+ norm_hidden_states = (
277
+ norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
278
+ )
279
+
280
+ if self.use_ada_layer_norm_single:
281
+ norm_hidden_states = self.norm2(hidden_states)
282
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp
283
+
284
+ ff_output = self.ff(norm_hidden_states, scale=lora_scale)
285
+
286
+ if self.use_ada_layer_norm_zero:
287
+ ff_output = gate_mlp.unsqueeze(1) * ff_output
288
+ elif self.use_ada_layer_norm_single:
289
+ ff_output = gate_mlp * ff_output
290
+
291
+ hidden_states = ff_output + hidden_states
292
+ if hidden_states.ndim == 4:
293
+ hidden_states = hidden_states.squeeze(1)
294
+
295
+ return hidden_states
296
+
297
+
298
+ class TemporalBasicTransformerBlock(nn.Module):
299
+ def __init__(
300
+ self,
301
+ dim: int,
302
+ num_attention_heads: int,
303
+ attention_head_dim: int,
304
+ dropout=0.0,
305
+ cross_attention_dim: Optional[int] = None,
306
+ activation_fn: str = "geglu",
307
+ num_embeds_ada_norm: Optional[int] = None,
308
+ attention_bias: bool = False,
309
+ only_cross_attention: bool = False,
310
+ upcast_attention: bool = False,
311
+ unet_use_cross_frame_attention=None,
312
+ unet_use_temporal_attention=None,
313
+ ):
314
+ super().__init__()
315
+ self.only_cross_attention = only_cross_attention
316
+ self.use_ada_layer_norm = num_embeds_ada_norm is not None
317
+ self.unet_use_cross_frame_attention = unet_use_cross_frame_attention
318
+ self.unet_use_temporal_attention = unet_use_temporal_attention
319
+
320
+ # old self attention layer for only self-attention
321
+ self.attn1 = Attention(
322
+ query_dim=dim,
323
+ heads=num_attention_heads,
324
+ dim_head=attention_head_dim,
325
+ dropout=dropout,
326
+ bias=attention_bias,
327
+ upcast_attention=upcast_attention,
328
+ )
329
+ self.norm1 = (
330
+ AdaLayerNorm(dim, num_embeds_ada_norm)
331
+ if self.use_ada_layer_norm
332
+ else nn.LayerNorm(dim)
333
+ )
334
+
335
+ # new self attention layer for reference features
336
+ self.attn1_5 = Attention(
337
+ query_dim=dim,
338
+ heads=num_attention_heads,
339
+ dim_head=attention_head_dim,
340
+ dropout=dropout,
341
+ bias=attention_bias,
342
+ upcast_attention=upcast_attention,
343
+ )
344
+ self.norm1_5 = (
345
+ AdaLayerNorm(dim, num_embeds_ada_norm)
346
+ if self.use_ada_layer_norm
347
+ else nn.LayerNorm(dim)
348
+ )
349
+
350
+ # Cross-Attn
351
+ if cross_attention_dim is not None:
352
+ self.attn2 = Attention(
353
+ query_dim=dim,
354
+ cross_attention_dim=cross_attention_dim,
355
+ heads=num_attention_heads,
356
+ dim_head=attention_head_dim,
357
+ dropout=dropout,
358
+ bias=attention_bias,
359
+ upcast_attention=upcast_attention,
360
+ )
361
+ else:
362
+ self.attn2 = None
363
+
364
+ if cross_attention_dim is not None:
365
+ self.norm2 = (
366
+ AdaLayerNorm(dim, num_embeds_ada_norm)
367
+ if self.use_ada_layer_norm
368
+ else nn.LayerNorm(dim)
369
+ )
370
+ else:
371
+ self.norm2 = None
372
+
373
+ # Feed-forward
374
+ self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn)
375
+ self.norm3 = nn.LayerNorm(dim)
376
+ self.use_ada_layer_norm_zero = False
377
+
378
+ # Temp-Attn
379
+ assert unet_use_temporal_attention is not None
380
+ if unet_use_temporal_attention:
381
+ self.attn_temp = Attention(
382
+ query_dim=dim,
383
+ heads=num_attention_heads,
384
+ dim_head=attention_head_dim,
385
+ dropout=dropout,
386
+ bias=attention_bias,
387
+ upcast_attention=upcast_attention,
388
+ )
389
+ nn.init.zeros_(self.attn_temp.to_out[0].weight.data)
390
+ self.norm_temp = (
391
+ AdaLayerNorm(dim, num_embeds_ada_norm)
392
+ if self.use_ada_layer_norm
393
+ else nn.LayerNorm(dim)
394
+ )
395
+
396
+ def forward(
397
+ self,
398
+ hidden_states,
399
+ encoder_hidden_states=None,
400
+ timestep=None,
401
+ attention_mask=None,
402
+ video_length=None,
403
+ ):
404
+ norm_hidden_states = (
405
+ self.norm1(hidden_states, timestep)
406
+ if self.use_ada_layer_norm
407
+ else self.norm1(hidden_states)
408
+ )
409
+
410
+ if self.unet_use_cross_frame_attention:
411
+ hidden_states = (
412
+ self.attn1(
413
+ norm_hidden_states,
414
+ attention_mask=attention_mask,
415
+ video_length=video_length,
416
+ )
417
+ + hidden_states
418
+ )
419
+ else:
420
+ hidden_states = (
421
+ self.attn1(norm_hidden_states, attention_mask=attention_mask)
422
+ + hidden_states
423
+ )
424
+
425
+ norm_hidden_states = (
426
+ self.norm1_5(hidden_states, timestep)
427
+ if self.use_ada_layer_norm
428
+ else self.norm1_5(hidden_states)
429
+ )
430
+
431
+ if self.unet_use_cross_frame_attention:
432
+ hidden_states = (
433
+ self.attn1_5(
434
+ norm_hidden_states,
435
+ attention_mask=attention_mask,
436
+ video_length=video_length,
437
+ )
438
+ + hidden_states
439
+ )
440
+ else:
441
+ hidden_states = (
442
+ self.attn1_5(norm_hidden_states, attention_mask=attention_mask)
443
+ + hidden_states
444
+ )
445
+
446
+ if self.attn2 is not None:
447
+ # Cross-Attention
448
+ norm_hidden_states = (
449
+ self.norm2(hidden_states, timestep)
450
+ if self.use_ada_layer_norm
451
+ else self.norm2(hidden_states)
452
+ )
453
+ hidden_states = (
454
+ self.attn2(
455
+ norm_hidden_states,
456
+ encoder_hidden_states=encoder_hidden_states,
457
+ attention_mask=attention_mask,
458
+ )
459
+ + hidden_states
460
+ )
461
+
462
+ # Feed-forward
463
+ hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states
464
+
465
+ # Temporal-Attention
466
+ if self.unet_use_temporal_attention:
467
+ d = hidden_states.shape[1]
468
+ hidden_states = rearrange(
469
+ hidden_states, "(b f) d c -> (b d) f c", f=video_length
470
+ )
471
+ norm_hidden_states = (
472
+ self.norm_temp(hidden_states, timestep)
473
+ if self.use_ada_layer_norm
474
+ else self.norm_temp(hidden_states)
475
+ )
476
+ hidden_states = self.attn_temp(norm_hidden_states) + hidden_states
477
+ hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d)
478
+
479
+ return hidden_states
480
+
481
+ class TemporalBasicTransformerBlockOld(nn.Module):
482
+ def __init__(
483
+ self,
484
+ dim: int,
485
+ num_attention_heads: int,
486
+ attention_head_dim: int,
487
+ dropout=0.0,
488
+ cross_attention_dim: Optional[int] = None,
489
+ activation_fn: str = "geglu",
490
+ num_embeds_ada_norm: Optional[int] = None,
491
+ attention_bias: bool = False,
492
+ only_cross_attention: bool = False,
493
+ upcast_attention: bool = False,
494
+ unet_use_cross_frame_attention=None,
495
+ unet_use_temporal_attention=None,
496
+ ):
497
+ super().__init__()
498
+ self.only_cross_attention = only_cross_attention
499
+ self.use_ada_layer_norm = num_embeds_ada_norm is not None
500
+ self.unet_use_cross_frame_attention = unet_use_cross_frame_attention
501
+ self.unet_use_temporal_attention = unet_use_temporal_attention
502
+
503
+ # SC-Attn
504
+ self.attn1 = Attention(
505
+ query_dim=dim,
506
+ heads=num_attention_heads,
507
+ dim_head=attention_head_dim,
508
+ dropout=dropout,
509
+ bias=attention_bias,
510
+ upcast_attention=upcast_attention,
511
+ )
512
+ self.norm1 = (
513
+ AdaLayerNorm(dim, num_embeds_ada_norm)
514
+ if self.use_ada_layer_norm
515
+ else nn.LayerNorm(dim)
516
+ )
517
+
518
+ # Cross-Attn
519
+ if cross_attention_dim is not None:
520
+ self.attn2 = Attention(
521
+ query_dim=dim,
522
+ cross_attention_dim=cross_attention_dim,
523
+ heads=num_attention_heads,
524
+ dim_head=attention_head_dim,
525
+ dropout=dropout,
526
+ bias=attention_bias,
527
+ upcast_attention=upcast_attention,
528
+ )
529
+ else:
530
+ self.attn2 = None
531
+
532
+ if cross_attention_dim is not None:
533
+ self.norm2 = (
534
+ AdaLayerNorm(dim, num_embeds_ada_norm)
535
+ if self.use_ada_layer_norm
536
+ else nn.LayerNorm(dim)
537
+ )
538
+ else:
539
+ self.norm2 = None
540
+
541
+ # Feed-forward
542
+ self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn)
543
+ self.norm3 = nn.LayerNorm(dim)
544
+ self.use_ada_layer_norm_zero = False
545
+
546
+ # Temp-Attn
547
+ assert unet_use_temporal_attention is not None
548
+ if unet_use_temporal_attention:
549
+ self.attn_temp = Attention(
550
+ query_dim=dim,
551
+ heads=num_attention_heads,
552
+ dim_head=attention_head_dim,
553
+ dropout=dropout,
554
+ bias=attention_bias,
555
+ upcast_attention=upcast_attention,
556
+ )
557
+ nn.init.zeros_(self.attn_temp.to_out[0].weight.data)
558
+ self.norm_temp = (
559
+ AdaLayerNorm(dim, num_embeds_ada_norm)
560
+ if self.use_ada_layer_norm
561
+ else nn.LayerNorm(dim)
562
+ )
563
+
564
+ def forward(
565
+ self,
566
+ hidden_states,
567
+ encoder_hidden_states=None,
568
+ timestep=None,
569
+ attention_mask=None,
570
+ video_length=None,
571
+ ):
572
+ norm_hidden_states = (
573
+ self.norm1(hidden_states, timestep)
574
+ if self.use_ada_layer_norm
575
+ else self.norm1(hidden_states)
576
+ )
577
+
578
+ if self.unet_use_cross_frame_attention:
579
+ hidden_states = (
580
+ self.attn1(
581
+ norm_hidden_states,
582
+ attention_mask=attention_mask,
583
+ video_length=video_length,
584
+ )
585
+ + hidden_states
586
+ )
587
+ else:
588
+ hidden_states = (
589
+ self.attn1(norm_hidden_states, attention_mask=attention_mask)
590
+ + hidden_states
591
+ )
592
+
593
+ if self.attn2 is not None:
594
+ # Cross-Attention
595
+ norm_hidden_states = (
596
+ self.norm2(hidden_states, timestep)
597
+ if self.use_ada_layer_norm
598
+ else self.norm2(hidden_states)
599
+ )
600
+ hidden_states = (
601
+ self.attn2(
602
+ norm_hidden_states,
603
+ encoder_hidden_states=encoder_hidden_states,
604
+ attention_mask=attention_mask,
605
+ )
606
+ + hidden_states
607
+ )
608
+
609
+ # Feed-forward
610
+ hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states
611
+
612
+ # Temporal-Attention
613
+ if self.unet_use_temporal_attention:
614
+ d = hidden_states.shape[1]
615
+ hidden_states = rearrange(
616
+ hidden_states, "(b f) d c -> (b d) f c", f=video_length
617
+ )
618
+ norm_hidden_states = (
619
+ self.norm_temp(hidden_states, timestep)
620
+ if self.use_ada_layer_norm
621
+ else self.norm_temp(hidden_states)
622
+ )
623
+ hidden_states = self.attn_temp(norm_hidden_states) + hidden_states
624
+ hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d)
625
+
626
+ return hidden_states
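
For reference, here is a minimal smoke test of the `TemporalBasicTransformerBlock` defined above. The channel widths, token count, and context width are illustrative assumptions, and the import assumes this repository's `modules/` package is importable from the working directory.

```python
# Hypothetical smoke test for TemporalBasicTransformerBlock (all sizes are assumptions).
import torch
from modules.attention import TemporalBasicTransformerBlock

block = TemporalBasicTransformerBlock(
    dim=320,                              # token channel width
    num_attention_heads=8,
    attention_head_dim=40,
    cross_attention_dim=768,              # width of the audio/text context tokens
    unet_use_cross_frame_attention=False,
    unet_use_temporal_attention=False,    # True would also build attn_temp over frames
)

batch, frames, tokens = 1, 4, 64          # tokens = H * W of the latent feature map
hidden = torch.randn(batch * frames, tokens, 320)
context = torch.randn(batch * frames, 77, 768)

out = block(hidden, encoder_hidden_states=context, video_length=frames)
print(out.shape)                          # torch.Size([4, 64, 320])
```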
modules/audio_projection.py ADDED
@@ -0,0 +1,150 @@
1
+ import math
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ from diffusers.models.modeling_utils import ModelMixin
6
+ from einops import rearrange
7
+ from einops.layers.torch import Rearrange
8
+
9
+
10
+ def reshape_tensor(x, heads):
11
+ bs, length, width = x.shape
12
+ # (bs, length, width) --> (bs, length, n_heads, dim_per_head)
13
+ x = x.view(bs, length, heads, -1)
14
+ # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head)
15
+ x = x.transpose(1, 2)
16
+ # (bs, n_heads, length, dim_per_head) --> (bs*n_heads, length, dim_per_head)
17
+ x = x.reshape(bs, heads, length, -1)
18
+ return x
19
+
20
+
21
+ def masked_mean(t, *, dim, mask=None):
22
+ if mask is None:
23
+ return t.mean(dim=dim)
24
+
25
+ denom = mask.sum(dim=dim, keepdim=True)
26
+ mask = rearrange(mask, "b n -> b n 1")
27
+ masked_t = t.masked_fill(~mask, 0.0)
28
+
29
+ return masked_t.sum(dim=dim) / denom.clamp(min=1e-5)
30
+
31
+
32
+ class PerceiverAttention(nn.Module):
33
+ def __init__(self, *, dim, dim_head=64, heads=8):
34
+ super().__init__()
35
+ self.scale = dim_head ** -0.5
36
+ self.dim_head = dim_head
37
+ self.heads = heads
38
+ inner_dim = dim_head * heads
39
+
40
+ self.norm1 = nn.LayerNorm(dim)
41
+ self.norm2 = nn.LayerNorm(dim)
42
+
43
+ self.to_q = nn.Linear(dim, inner_dim, bias=False)
44
+ self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
45
+ self.to_out = nn.Linear(inner_dim, dim, bias=False)
46
+
47
+ def forward(self, x, latents):
48
+ """
49
+ Args:
50
+ x (torch.Tensor): image features
51
+ shape (b, n1, D)
52
+ latent (torch.Tensor): latent features
53
+ shape (b, n2, D)
54
+ """
55
+ x = self.norm1(x)
56
+ latents = self.norm2(latents)
57
+
58
+ b, l, _ = latents.shape
59
+
60
+ q = self.to_q(latents)
61
+ kv_input = torch.cat((x, latents), dim=-2)
62
+ k, v = self.to_kv(kv_input).chunk(2, dim=-1)
63
+
64
+ q = reshape_tensor(q, self.heads)
65
+ k = reshape_tensor(k, self.heads)
66
+ v = reshape_tensor(v, self.heads)
67
+
68
+ # attention
69
+ scale = 1 / math.sqrt(math.sqrt(self.dim_head))
70
+ weight = (q * scale) @ (k * scale).transpose(-2, -1) # More stable with f16 than dividing afterwards
71
+ weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
72
+ out = weight @ v
73
+
74
+ out = out.permute(0, 2, 1, 3).reshape(b, l, -1)
75
+
76
+ return self.to_out(out)
77
+
78
+
79
+ def FeedForward(dim, mult=4):
80
+ inner_dim = int(dim * mult)
81
+ return nn.Sequential(
82
+ nn.LayerNorm(dim),
83
+ nn.Linear(dim, inner_dim, bias=False),
84
+ nn.GELU(),
85
+ nn.Linear(inner_dim, dim, bias=False),
86
+ )
87
+
88
+
89
+ class AudioProjection(ModelMixin):
90
+ def __init__(
91
+ self,
92
+ dim=1024,
93
+ depth=8,
94
+ dim_head=64,
95
+ heads=16,
96
+ num_queries=8,
97
+ embedding_dim=768,
98
+ output_dim=1024,
99
+ ff_mult=4,
100
+ max_seq_len: int = 257,
101
+ num_latents_mean_pooled: int = 0,
102
+ ):
103
+ super().__init__()
104
+
105
+ self.pos_emb = nn.Embedding(max_seq_len, embedding_dim)
106
+ self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim ** 0.5)
107
+
108
+ self.proj_in = nn.Linear(embedding_dim, dim)
109
+
110
+ self.proj_out = nn.Linear(dim, output_dim)
111
+ self.norm_out = nn.LayerNorm(output_dim)
112
+
113
+ self.to_latents_from_mean_pooled_seq = (
114
+ nn.Sequential(
115
+ nn.LayerNorm(dim),
116
+ nn.Linear(dim, dim * num_latents_mean_pooled),
117
+ Rearrange("b (n d) -> b n d", n=num_latents_mean_pooled),
118
+ )
119
+ if num_latents_mean_pooled > 0
120
+ else None
121
+ )
122
+
123
+ self.layers = nn.ModuleList([])
124
+ for _ in range(depth):
125
+ self.layers.append(nn.ModuleList([
126
+ PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
127
+ FeedForward(dim=dim, mult=ff_mult),
128
+ ]))
129
+
130
+ def forward(self, x):
131
+ if self.pos_emb is not None:
132
+ n, device = x.shape[1], x.device
133
+ pos_emb = self.pos_emb(torch.arange(n, device=device))
134
+ x = x + pos_emb
135
+
136
+ latents = self.latents.repeat(x.size(0), 1, 1)
137
+
138
+ x = self.proj_in(x)
139
+
140
+ if self.to_latents_from_mean_pooled_seq:
141
+ meanpooled_seq = masked_mean(x, dim=1, mask=torch.ones(x.shape[:2], device=x.device, dtype=torch.bool))
142
+ meanpooled_latents = self.to_latents_from_mean_pooled_seq(meanpooled_seq)
143
+ latents = torch.cat((meanpooled_latents, latents), dim=-2)
144
+
145
+ for attn, ff in self.layers:
146
+ latents = attn(x, latents) + latents
147
+ latents = ff(latents) + latents
148
+
149
+ latents = self.proj_out(latents)
150
+ return self.norm_out(latents)
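
As a quick sanity check, the projection above can be exercised on dummy audio embeddings. The batch size, sequence length, `embedding_dim`, and `output_dim` below are illustrative assumptions; in practice `embedding_dim` matches the audio encoder's hidden size and `output_dim` matches the UNet's `cross_attention_dim`.

```python
# Hypothetical usage of AudioProjection on dummy audio-encoder features.
import torch
from modules.audio_projection import AudioProjection

proj = AudioProjection(
    dim=1024,
    depth=4,               # fewer perceiver layers than the default, just for a quick test
    num_queries=8,         # number of context tokens produced per clip
    embedding_dim=768,     # hidden size of the upstream audio encoder (assumption)
    output_dim=768,        # should match the UNet cross_attention_dim (assumption)
)

audio_feats = torch.randn(2, 50, 768)    # (batch, audio frames, embedding_dim)
context = proj(audio_feats)
print(context.shape)                      # torch.Size([2, 8, 768])
```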
modules/motion_module.py ADDED
@@ -0,0 +1,388 @@
1
+ # Adapted from https://github.com/guoyww/AnimateDiff/blob/main/animatediff/models/motion_module.py
2
+ import math
3
+ from dataclasses import dataclass
4
+ from typing import Callable, Optional
5
+
6
+ import torch
7
+ from diffusers.models.attention import FeedForward
8
+ from diffusers.models.attention_processor import Attention, AttnProcessor
9
+ from diffusers.utils import BaseOutput
10
+ from diffusers.utils.import_utils import is_xformers_available
11
+ from einops import rearrange, repeat
12
+ from torch import nn
13
+
14
+
15
+ def zero_module(module):
16
+ # Zero out the parameters of a module and return it.
17
+ for p in module.parameters():
18
+ p.detach().zero_()
19
+ return module
20
+
21
+
22
+ @dataclass
23
+ class TemporalTransformer3DModelOutput(BaseOutput):
24
+ sample: torch.FloatTensor
25
+
26
+
27
+ if is_xformers_available():
28
+ import xformers
29
+ import xformers.ops
30
+ else:
31
+ xformers = None
32
+
33
+
34
+ def get_motion_module(in_channels, motion_module_type: str, motion_module_kwargs: dict):
35
+ if motion_module_type == "Vanilla":
36
+ return VanillaTemporalModule(
37
+ in_channels=in_channels,
38
+ **motion_module_kwargs,
39
+ )
40
+ else:
41
+ raise ValueError
42
+
43
+
44
+ class VanillaTemporalModule(nn.Module):
45
+ def __init__(
46
+ self,
47
+ in_channels,
48
+ num_attention_heads=8,
49
+ num_transformer_block=2,
50
+ attention_block_types=("Temporal_Self", "Temporal_Self"),
51
+ cross_frame_attention_mode=None,
52
+ temporal_position_encoding=False,
53
+ temporal_position_encoding_max_len=24,
54
+ temporal_attention_dim_div=1,
55
+ zero_initialize=True,
56
+ ):
57
+ super().__init__()
58
+
59
+ self.temporal_transformer = TemporalTransformer3DModel(
60
+ in_channels=in_channels,
61
+ num_attention_heads=num_attention_heads,
62
+ attention_head_dim=in_channels
63
+ // num_attention_heads
64
+ // temporal_attention_dim_div,
65
+ num_layers=num_transformer_block,
66
+ attention_block_types=attention_block_types,
67
+ cross_frame_attention_mode=cross_frame_attention_mode,
68
+ temporal_position_encoding=temporal_position_encoding,
69
+ temporal_position_encoding_max_len=temporal_position_encoding_max_len,
70
+ )
71
+
72
+ if zero_initialize:
73
+ self.temporal_transformer.proj_out = zero_module(
74
+ self.temporal_transformer.proj_out
75
+ )
76
+
77
+ def forward(
78
+ self,
79
+ input_tensor,
80
+ temb,
81
+ encoder_hidden_states,
82
+ attention_mask=None,
83
+ anchor_frame_idx=None,
84
+ ):
85
+ hidden_states = input_tensor
86
+ hidden_states = self.temporal_transformer(
87
+ hidden_states, encoder_hidden_states, attention_mask
88
+ )
89
+
90
+ output = hidden_states
91
+ return output
92
+
93
+
94
+ class TemporalTransformer3DModel(nn.Module):
95
+ def __init__(
96
+ self,
97
+ in_channels,
98
+ num_attention_heads,
99
+ attention_head_dim,
100
+ num_layers,
101
+ attention_block_types=(
102
+ "Temporal_Self",
103
+ "Temporal_Self",
104
+ ),
105
+ dropout=0.0,
106
+ norm_num_groups=32,
107
+ cross_attention_dim=768,
108
+ activation_fn="geglu",
109
+ attention_bias=False,
110
+ upcast_attention=False,
111
+ cross_frame_attention_mode=None,
112
+ temporal_position_encoding=False,
113
+ temporal_position_encoding_max_len=24,
114
+ ):
115
+ super().__init__()
116
+
117
+ inner_dim = num_attention_heads * attention_head_dim
118
+
119
+ self.norm = torch.nn.GroupNorm(
120
+ num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True
121
+ )
122
+ self.proj_in = nn.Linear(in_channels, inner_dim)
123
+
124
+ self.transformer_blocks = nn.ModuleList(
125
+ [
126
+ TemporalTransformerBlock(
127
+ dim=inner_dim,
128
+ num_attention_heads=num_attention_heads,
129
+ attention_head_dim=attention_head_dim,
130
+ attention_block_types=attention_block_types,
131
+ dropout=dropout,
132
+ norm_num_groups=norm_num_groups,
133
+ cross_attention_dim=cross_attention_dim,
134
+ activation_fn=activation_fn,
135
+ attention_bias=attention_bias,
136
+ upcast_attention=upcast_attention,
137
+ cross_frame_attention_mode=cross_frame_attention_mode,
138
+ temporal_position_encoding=temporal_position_encoding,
139
+ temporal_position_encoding_max_len=temporal_position_encoding_max_len,
140
+ )
141
+ for d in range(num_layers)
142
+ ]
143
+ )
144
+ self.proj_out = nn.Linear(inner_dim, in_channels)
145
+
146
+ def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None):
147
+ assert (
148
+ hidden_states.dim() == 5
149
+ ), f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}."
150
+ video_length = hidden_states.shape[2]
151
+ hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w")
152
+
153
+ batch, channel, height, weight = hidden_states.shape
154
+ residual = hidden_states
155
+
156
+ hidden_states = self.norm(hidden_states)
157
+ inner_dim = hidden_states.shape[1]
158
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(
159
+ batch, height * weight, inner_dim
160
+ )
161
+ hidden_states = self.proj_in(hidden_states)
162
+
163
+ # Transformer Blocks
164
+ for block in self.transformer_blocks:
165
+ hidden_states = block(
166
+ hidden_states,
167
+ encoder_hidden_states=encoder_hidden_states,
168
+ video_length=video_length,
169
+ )
170
+
171
+ # output
172
+ hidden_states = self.proj_out(hidden_states)
173
+ hidden_states = (
174
+ hidden_states.reshape(batch, height, weight, inner_dim)
175
+ .permute(0, 3, 1, 2)
176
+ .contiguous()
177
+ )
178
+
179
+ output = hidden_states + residual
180
+ output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length)
181
+
182
+ return output
183
+
184
+
185
+ class TemporalTransformerBlock(nn.Module):
186
+ def __init__(
187
+ self,
188
+ dim,
189
+ num_attention_heads,
190
+ attention_head_dim,
191
+ attention_block_types=(
192
+ "Temporal_Self",
193
+ "Temporal_Self",
194
+ ),
195
+ dropout=0.0,
196
+ norm_num_groups=32,
197
+ cross_attention_dim=768,
198
+ activation_fn="geglu",
199
+ attention_bias=False,
200
+ upcast_attention=False,
201
+ cross_frame_attention_mode=None,
202
+ temporal_position_encoding=False,
203
+ temporal_position_encoding_max_len=24,
204
+ ):
205
+ super().__init__()
206
+
207
+ attention_blocks = []
208
+ norms = []
209
+
210
+ for block_name in attention_block_types:
211
+ attention_blocks.append(
212
+ VersatileAttention(
213
+ attention_mode=block_name.split("_")[0],
214
+ cross_attention_dim=cross_attention_dim
215
+ if block_name.endswith("_Cross")
216
+ else None,
217
+ query_dim=dim,
218
+ heads=num_attention_heads,
219
+ dim_head=attention_head_dim,
220
+ dropout=dropout,
221
+ bias=attention_bias,
222
+ upcast_attention=upcast_attention,
223
+ cross_frame_attention_mode=cross_frame_attention_mode,
224
+ temporal_position_encoding=temporal_position_encoding,
225
+ temporal_position_encoding_max_len=temporal_position_encoding_max_len,
226
+ )
227
+ )
228
+ norms.append(nn.LayerNorm(dim))
229
+
230
+ self.attention_blocks = nn.ModuleList(attention_blocks)
231
+ self.norms = nn.ModuleList(norms)
232
+
233
+ self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn)
234
+ self.ff_norm = nn.LayerNorm(dim)
235
+
236
+ def forward(
237
+ self,
238
+ hidden_states,
239
+ encoder_hidden_states=None,
240
+ attention_mask=None,
241
+ video_length=None,
242
+ ):
243
+ for attention_block, norm in zip(self.attention_blocks, self.norms):
244
+ norm_hidden_states = norm(hidden_states)
245
+ hidden_states = (
246
+ attention_block(
247
+ norm_hidden_states,
248
+ encoder_hidden_states=encoder_hidden_states
249
+ if attention_block.is_cross_attention
250
+ else None,
251
+ video_length=video_length,
252
+ )
253
+ + hidden_states
254
+ )
255
+
256
+ hidden_states = self.ff(self.ff_norm(hidden_states)) + hidden_states
257
+
258
+ output = hidden_states
259
+ return output
260
+
261
+
262
+ class PositionalEncoding(nn.Module):
263
+ def __init__(self, d_model, dropout=0.0, max_len=24):
264
+ super().__init__()
265
+ self.dropout = nn.Dropout(p=dropout)
266
+ position = torch.arange(max_len).unsqueeze(1)
267
+ div_term = torch.exp(
268
+ torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model)
269
+ )
270
+ pe = torch.zeros(1, max_len, d_model)
271
+ pe[0, :, 0::2] = torch.sin(position * div_term)
272
+ pe[0, :, 1::2] = torch.cos(position * div_term)
273
+ self.register_buffer("pe", pe)
274
+
275
+ def forward(self, x):
276
+ x = x + self.pe[:, : x.size(1)]
277
+ return self.dropout(x)
278
+
279
+
280
+ class VersatileAttention(Attention):
281
+ def __init__(
282
+ self,
283
+ attention_mode=None,
284
+ cross_frame_attention_mode=None,
285
+ temporal_position_encoding=False,
286
+ temporal_position_encoding_max_len=24,
287
+ *args,
288
+ **kwargs,
289
+ ):
290
+ super().__init__(*args, **kwargs)
291
+ assert attention_mode == "Temporal"
292
+
293
+ self.attention_mode = attention_mode
294
+ self.is_cross_attention = kwargs["cross_attention_dim"] is not None
295
+
296
+ self.pos_encoder = (
297
+ PositionalEncoding(
298
+ kwargs["query_dim"],
299
+ dropout=0.0,
300
+ max_len=temporal_position_encoding_max_len,
301
+ )
302
+ if (temporal_position_encoding and attention_mode == "Temporal")
303
+ else None
304
+ )
305
+
306
+ def extra_repr(self):
307
+ return f"(Module Info) Attention_Mode: {self.attention_mode}, Is_Cross_Attention: {self.is_cross_attention}"
308
+
309
+ def set_use_memory_efficient_attention_xformers(
310
+ self,
311
+ use_memory_efficient_attention_xformers: bool,
312
+ attention_op: Optional[Callable] = None,
313
+ ):
314
+ if use_memory_efficient_attention_xformers:
315
+ if not is_xformers_available():
316
+ raise ModuleNotFoundError(
317
+ (
318
+ "Refer to https://github.com/facebookresearch/xformers for more information on how to install"
319
+ " xformers"
320
+ ),
321
+ name="xformers",
322
+ )
323
+ elif not torch.cuda.is_available():
324
+ raise ValueError(
325
+ "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is"
326
+ " only available for GPU "
327
+ )
328
+ else:
329
+ try:
330
+ # Make sure we can run the memory efficient attention
331
+ _ = xformers.ops.memory_efficient_attention(
332
+ torch.randn((1, 2, 40), device="cuda"),
333
+ torch.randn((1, 2, 40), device="cuda"),
334
+ torch.randn((1, 2, 40), device="cuda"),
335
+ )
336
+ except Exception as e:
337
+ raise e
338
+
339
+ # XFormersAttnProcessor corrupts video generation and only works with PyTorch 1.13.
340
+ # PyTorch 2.0.1's AttnProcessor behaves the same as XFormersAttnProcessor does in PyTorch 1.13.
341
+ # You don't need XFormersAttnProcessor here.
342
+ # processor = XFormersAttnProcessor(
343
+ # attention_op=attention_op,
344
+ # )
345
+ processor = AttnProcessor()
346
+ else:
347
+ processor = AttnProcessor()
348
+
349
+ self.set_processor(processor)
350
+
351
+ def forward(
352
+ self,
353
+ hidden_states,
354
+ encoder_hidden_states=None,
355
+ attention_mask=None,
356
+ video_length=None,
357
+ **cross_attention_kwargs,
358
+ ):
359
+ if self.attention_mode == "Temporal":
360
+ d = hidden_states.shape[1] # d means HxW
361
+ hidden_states = rearrange(
362
+ hidden_states, "(b f) d c -> (b d) f c", f=video_length
363
+ )
364
+
365
+ if self.pos_encoder is not None:
366
+ hidden_states = self.pos_encoder(hidden_states)
367
+
368
+ encoder_hidden_states = (
369
+ repeat(encoder_hidden_states, "b n c -> (b d) n c", d=d)
370
+ if encoder_hidden_states is not None
371
+ else encoder_hidden_states
372
+ )
373
+
374
+ else:
375
+ raise NotImplementedError
376
+
377
+ hidden_states = self.processor(
378
+ self,
379
+ hidden_states,
380
+ encoder_hidden_states=encoder_hidden_states,
381
+ attention_mask=attention_mask,
382
+ **cross_attention_kwargs,
383
+ )
384
+
385
+ if self.attention_mode == "Temporal":
386
+ hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d)
387
+
388
+ return hidden_states
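
A small shape check for the motion-module factory defined above; the channel count, frame count, and kwargs are illustrative assumptions. With `zero_initialize=True` the block starts as an identity mapping, so only the output shape is informative here.

```python
# Hypothetical shape check for the Vanilla temporal motion module.
import torch
from modules.motion_module import get_motion_module

motion = get_motion_module(
    in_channels=320,                       # must be divisible by norm_num_groups (32)
    motion_module_type="Vanilla",
    motion_module_kwargs={
        "num_attention_heads": 8,
        "num_transformer_block": 1,
        "temporal_position_encoding": True,
        "temporal_position_encoding_max_len": 24,
    },
)

feats = torch.randn(1, 320, 8, 32, 32)     # (batch, channels, frames, height, width)
out = motion(feats, temb=None, encoder_hidden_states=None)
print(out.shape)                            # torch.Size([1, 320, 8, 32, 32])
```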
modules/mutual_self_attention.py ADDED
@@ -0,0 +1,376 @@
1
+ # Adapted from https://github.com/magic-research/magic-animate/blob/main/magicanimate/models/mutual_self_attention.py
2
+ from typing import Any, Dict, Optional
3
+
4
+ import torch
5
+ from einops import rearrange
6
+
7
+ from .attention import BasicTransformerBlock
8
+ from .attention import TemporalBasicTransformerBlock
9
+
10
+
11
+ def torch_dfs(model: torch.nn.Module):
12
+ result = [model]
13
+ for child in model.children():
14
+ result += torch_dfs(child)
15
+ return result
16
+
17
+
18
+ class ReferenceAttentionControl:
19
+ def __init__(
20
+ self,
21
+ unet,
22
+ mode="write",
23
+ do_classifier_free_guidance=False,
24
+ attention_auto_machine_weight=float("inf"),
25
+ gn_auto_machine_weight=1.0,
26
+ style_fidelity=1.0,
27
+ reference_attn=True,
28
+ reference_adain=False,
29
+ fusion_blocks="midup",
30
+ batch_size=1,
31
+ reference_attention_weight=1.,
32
+ audio_attention_weight=1.,
33
+ ) -> None:
34
+ # 10. Modify self attention and group norm
35
+ self.unet = unet
36
+ assert mode in ["read", "write"]
37
+ assert fusion_blocks in ["midup", "full"]
38
+ self.reference_attn = reference_attn
39
+ self.reference_adain = reference_adain
40
+ self.fusion_blocks = fusion_blocks
41
+ self.reference_attention_weight = reference_attention_weight
42
+ self.audio_attention_weight = audio_attention_weight
43
+ self.register_reference_hooks(
44
+ mode,
45
+ do_classifier_free_guidance,
46
+ attention_auto_machine_weight,
47
+ gn_auto_machine_weight,
48
+ style_fidelity,
49
+ reference_attn,
50
+ reference_adain,
51
+ fusion_blocks,
52
+ batch_size=batch_size,
53
+ )
54
+
55
+ def register_reference_hooks(
56
+ self,
57
+ mode,
58
+ do_classifier_free_guidance,
59
+ attention_auto_machine_weight,
60
+ gn_auto_machine_weight,
61
+ style_fidelity,
62
+ reference_attn,
63
+ reference_adain,
64
+ dtype=torch.float16,
65
+ batch_size=1,
66
+ num_images_per_prompt=1,
67
+ device=torch.device("cpu"),
68
+ fusion_blocks="midup",
69
+ ):
70
+ MODE = mode
71
+ do_classifier_free_guidance = do_classifier_free_guidance
72
+ attention_auto_machine_weight = attention_auto_machine_weight
73
+ gn_auto_machine_weight = gn_auto_machine_weight
74
+ style_fidelity = style_fidelity
75
+ reference_attn = reference_attn
76
+ reference_adain = reference_adain
77
+ fusion_blocks = fusion_blocks
78
+ num_images_per_prompt = num_images_per_prompt
79
+ reference_attention_weight = self.reference_attention_weight
80
+ audio_attention_weight = self.audio_attention_weight
81
+ dtype = dtype
82
+ if do_classifier_free_guidance:
83
+ uc_mask = (
84
+ torch.Tensor(
85
+ [1] * batch_size * num_images_per_prompt * 16
86
+ + [0] * batch_size * num_images_per_prompt * 16
87
+ )
88
+ .to(device)
89
+ .bool()
90
+ )
91
+ else:
92
+ uc_mask = (
93
+ torch.Tensor([0] * batch_size * num_images_per_prompt * 2)
94
+ .to(device)
95
+ .bool()
96
+ )
97
+
98
+ def hacked_basic_transformer_inner_forward(
99
+ self,
100
+ hidden_states: torch.FloatTensor,
101
+ attention_mask: Optional[torch.FloatTensor] = None,
102
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
103
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
104
+ timestep: Optional[torch.LongTensor] = None,
105
+ cross_attention_kwargs: Dict[str, Any] = None,
106
+ class_labels: Optional[torch.LongTensor] = None,
107
+ video_length=None,
108
+ ):
109
+ if self.use_ada_layer_norm: # False
110
+ norm_hidden_states = self.norm1(hidden_states, timestep)
111
+ elif self.use_ada_layer_norm_zero:
112
+ (
113
+ norm_hidden_states,
114
+ gate_msa,
115
+ shift_mlp,
116
+ scale_mlp,
117
+ gate_mlp,
118
+ ) = self.norm1(
119
+ hidden_states,
120
+ timestep,
121
+ class_labels,
122
+ hidden_dtype=hidden_states.dtype,
123
+ )
124
+ else:
125
+ norm_hidden_states = self.norm1(hidden_states)
126
+
127
+ # 1. Self-Attention
128
+ # self.only_cross_attention = False
129
+ cross_attention_kwargs = (
130
+ cross_attention_kwargs if cross_attention_kwargs is not None else {}
131
+ )
132
+ if self.only_cross_attention:
133
+ attn_output = self.attn1(
134
+ norm_hidden_states,
135
+ encoder_hidden_states=encoder_hidden_states
136
+ if self.only_cross_attention
137
+ else None,
138
+ attention_mask=attention_mask,
139
+ **cross_attention_kwargs,
140
+ )
141
+ else:
142
+ if MODE == "write":
143
+ attn_output = self.attn1(
144
+ norm_hidden_states,
145
+ encoder_hidden_states=encoder_hidden_states
146
+ if self.only_cross_attention
147
+ else None,
148
+ attention_mask=attention_mask,
149
+ **cross_attention_kwargs,
150
+ )
151
+
152
+ if self.use_ada_layer_norm_zero:
153
+ attn_output = gate_msa.unsqueeze(1) * attn_output
154
+ hidden_states = attn_output + hidden_states
155
+
156
+ if self.attn2 is not None:
157
+ norm_hidden_states = (
158
+ self.norm2(hidden_states, timestep)
159
+ if self.use_ada_layer_norm
160
+ else self.norm2(hidden_states)
161
+ )
162
+ self.bank.append(norm_hidden_states.clone())
163
+
164
+ # 2. Cross-Attention
165
+ attn_output = self.attn2(
166
+ norm_hidden_states,
167
+ encoder_hidden_states=encoder_hidden_states,
168
+ attention_mask=encoder_attention_mask,
169
+ **cross_attention_kwargs,
170
+ )
171
+ hidden_states = attn_output + hidden_states
172
+
173
+ if MODE == "read":
174
+ hidden_states = (
175
+ self.attn1(
176
+ norm_hidden_states,
177
+ encoder_hidden_states=norm_hidden_states,
178
+ attention_mask=attention_mask,
179
+ )
180
+ + hidden_states
181
+ )
182
+
183
+ if self.use_ada_layer_norm: # False
184
+ norm_hidden_states = self.norm1_5(hidden_states, timestep)
185
+ elif self.use_ada_layer_norm_zero:
186
+ (
187
+ norm_hidden_states,
188
+ gate_msa,
189
+ shift_mlp,
190
+ scale_mlp,
191
+ gate_mlp,
192
+ ) = self.norm1_5(
193
+ hidden_states,
194
+ timestep,
195
+ class_labels,
196
+ hidden_dtype=hidden_states.dtype,
197
+ )
198
+ else:
199
+ norm_hidden_states = self.norm1_5(hidden_states)
200
+
201
+ bank_fea = []
202
+ for d in self.bank:
203
+ if len(d.shape) == 3:
204
+ d = d.unsqueeze(1).repeat(1, video_length, 1, 1)
205
+ bank_fea.append(rearrange(d, "b t l c -> (b t) l c"))
206
+
207
+ attn_hidden_states = self.attn1_5(
208
+ norm_hidden_states,
209
+ encoder_hidden_states=bank_fea[0],
210
+ attention_mask=attention_mask,
211
+ )
212
+
213
+ if reference_attention_weight != 1.:
214
+ attn_hidden_states *= reference_attention_weight
215
+
216
+ hidden_states = (attn_hidden_states + hidden_states)
217
+
218
+ # self.bank.clear()
219
+ if self.attn2 is not None:
220
+ # Cross-Attention
221
+ norm_hidden_states = (
222
+ self.norm2(hidden_states, timestep)
223
+ if self.use_ada_layer_norm
224
+ else self.norm2(hidden_states)
225
+ )
226
+
227
+ attn_hidden_states = self.attn2(
228
+ norm_hidden_states,
229
+ encoder_hidden_states=encoder_hidden_states,
230
+ attention_mask=attention_mask,
231
+ )
232
+
233
+ if audio_attention_weight != 1.:
234
+ attn_hidden_states *= audio_attention_weight
235
+
236
+ hidden_states = (attn_hidden_states + hidden_states)
237
+
238
+ # Feed-forward
239
+ hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states
240
+
241
+ # Temporal-Attention
242
+ if self.unet_use_temporal_attention:
243
+ d = hidden_states.shape[1]
244
+ hidden_states = rearrange(
245
+ hidden_states, "(b f) d c -> (b d) f c", f=video_length
246
+ )
247
+ norm_hidden_states = (
248
+ self.norm_temp(hidden_states, timestep)
249
+ if self.use_ada_layer_norm
250
+ else self.norm_temp(hidden_states)
251
+ )
252
+ hidden_states = (
253
+ self.attn_temp(norm_hidden_states) + hidden_states
254
+ )
255
+ hidden_states = rearrange(
256
+ hidden_states, "(b d) f c -> (b f) d c", d=d
257
+ )
258
+
259
+ return hidden_states
260
+
261
+ # 3. Feed-forward
262
+ norm_hidden_states = self.norm3(hidden_states)
263
+
264
+ if self.use_ada_layer_norm_zero:
265
+ norm_hidden_states = (
266
+ norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
267
+ )
268
+
269
+ ff_output = self.ff(norm_hidden_states)
270
+
271
+ if self.use_ada_layer_norm_zero:
272
+ ff_output = gate_mlp.unsqueeze(1) * ff_output
273
+
274
+ hidden_states = ff_output + hidden_states
275
+
276
+ return hidden_states
277
+
278
+ if self.reference_attn:
279
+ if self.fusion_blocks == "midup":
280
+ attn_modules = [
281
+ module
282
+ for module in (
283
+ torch_dfs(self.unet.mid_block) + torch_dfs(self.unet.up_blocks)
284
+ )
285
+ if isinstance(module, BasicTransformerBlock)
286
+ or isinstance(module, TemporalBasicTransformerBlock)
287
+ ]
288
+ elif self.fusion_blocks == "full":
289
+ attn_modules = [
290
+ module
291
+ for module in torch_dfs(self.unet)
292
+ if isinstance(module, BasicTransformerBlock)
293
+ or isinstance(module, TemporalBasicTransformerBlock)
294
+ ]
295
+ attn_modules = sorted(
296
+ attn_modules, key=lambda x: -x.norm1.normalized_shape[0]
297
+ )
298
+
299
+ for i, module in enumerate(attn_modules):
300
+ module._original_inner_forward = module.forward
301
+ if isinstance(module, BasicTransformerBlock):
302
+ module.forward = hacked_basic_transformer_inner_forward.__get__(
303
+ module, BasicTransformerBlock
304
+ )
305
+ if isinstance(module, TemporalBasicTransformerBlock):
306
+ module.forward = hacked_basic_transformer_inner_forward.__get__(
307
+ module, TemporalBasicTransformerBlock
308
+ )
309
+
310
+ module.bank = []
311
+ module.attn_weight = float(i) / float(len(attn_modules))
312
+
313
+ def update(
314
+ self,
315
+ writer,
316
+ do_classifier_free_guidance=True,
317
+ dtype=torch.float16,
318
+ ):
319
+ if self.reference_attn:
320
+ if self.fusion_blocks == "midup":
321
+ reader_attn_modules = [
322
+ module
323
+ for module in (torch_dfs(self.unet.mid_block) + torch_dfs(self.unet.up_blocks))
324
+ if isinstance(module, TemporalBasicTransformerBlock)
325
+ ]
326
+ writer_attn_modules = [
327
+ module
328
+ for module in (torch_dfs(writer.unet.mid_block) + torch_dfs(writer.unet.up_blocks))
329
+ if isinstance(module, BasicTransformerBlock)
330
+ ]
331
+ elif self.fusion_blocks == "full":
332
+ reader_attn_modules = [
333
+ module
334
+ for module in torch_dfs(self.unet)
335
+ if isinstance(module, TemporalBasicTransformerBlock)
336
+ ]
337
+ writer_attn_modules = [
338
+ module
339
+ for module in torch_dfs(writer.unet)
340
+ if isinstance(module, BasicTransformerBlock)
341
+ ]
342
+ reader_attn_modules = sorted(
343
+ reader_attn_modules, key=lambda x: -x.norm1.normalized_shape[0]
344
+ )
345
+ writer_attn_modules = sorted(
346
+ writer_attn_modules, key=lambda x: -x.norm1.normalized_shape[0]
347
+ )
348
+ for r, w in zip(reader_attn_modules, writer_attn_modules):
349
+ if do_classifier_free_guidance:
350
+ r.bank = [torch.cat([torch.zeros_like(v), v]).to(dtype) for v in w.bank]
351
+ else:
352
+ r.bank = [v.clone().to(dtype) for v in w.bank]
353
+
354
+ def clear(self):
355
+ if self.reference_attn:
356
+ if self.fusion_blocks == "midup":
357
+ reader_attn_modules = [
358
+ module
359
+ for module in (
360
+ torch_dfs(self.unet.mid_block) + torch_dfs(self.unet.up_blocks)
361
+ )
362
+ if isinstance(module, BasicTransformerBlock)
363
+ or isinstance(module, TemporalBasicTransformerBlock)
364
+ ]
365
+ elif self.fusion_blocks == "full":
366
+ reader_attn_modules = [
367
+ module
368
+ for module in torch_dfs(self.unet)
369
+ if isinstance(module, BasicTransformerBlock)
370
+ or isinstance(module, TemporalBasicTransformerBlock)
371
+ ]
372
+ reader_attn_modules = sorted(
373
+ reader_attn_modules, key=lambda x: -x.norm1.normalized_shape[0]
374
+ )
375
+ for r in reader_attn_modules:
376
+ r.bank.clear()
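
The reference-attention mechanism above follows a writer/reader pattern: a `write` instance hooked onto the 2-D reference UNet caches each block's normalized hidden states in `bank`, `update()` copies those caches into the matching blocks of the 3-D denoising UNet, and `clear()` empties them after each clip. The snippet below only exercises the `torch_dfs` helper that both hooks rely on to enumerate transformer blocks; the lifecycle at the end is an outline with hypothetical variable names, not runnable pipeline code.

```python
# Standalone check of the torch_dfs helper used to collect transformer blocks.
import torch.nn as nn
from modules.mutual_self_attention import torch_dfs

net = nn.Sequential(
    nn.Linear(4, 4),
    nn.Sequential(nn.ReLU(), nn.Linear(4, 2)),
)
print([type(m).__name__ for m in torch_dfs(net)])
# ['Sequential', 'Linear', 'Sequential', 'ReLU', 'Linear']

# Typical lifecycle (names are hypothetical; see pipelines/v_express_pipeline.py):
#   writer = ReferenceAttentionControl(reference_net, mode="write", fusion_blocks="full")
#   reader = ReferenceAttentionControl(denoising_unet, mode="read", fusion_blocks="full")
#   ... run reference_net once so each block appends to its bank ...
#   reader.update(writer)   # copy cached reference features into the 3-D UNet
#   ... denoising loop ...
#   reader.clear()          # drop the cached features before the next sample
```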
modules/resnet.py ADDED
@@ -0,0 +1,256 @@
1
+ # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/resnet.py
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ from einops import rearrange
7
+
8
+
9
+ class InflatedConv3d(nn.Conv2d):
10
+ def forward(self, x):
11
+ video_length = x.shape[2]
12
+
13
+ x = rearrange(x, "b c f h w -> (b f) c h w")
14
+ x = super().forward(x)
15
+ x = rearrange(x, "(b f) c h w -> b c f h w", f=video_length)
16
+
17
+ return x
18
+
19
+
20
+ class InflatedGroupNorm(nn.GroupNorm):
21
+ def forward(self, x):
22
+ video_length = x.shape[2]
23
+
24
+ x = rearrange(x, "b c f h w -> (b f) c h w")
25
+ x = super().forward(x)
26
+ x = rearrange(x, "(b f) c h w -> b c f h w", f=video_length)
27
+
28
+ return x
29
+
30
+
31
+ class Upsample3D(nn.Module):
32
+ def __init__(
33
+ self,
34
+ channels,
35
+ use_conv=False,
36
+ use_conv_transpose=False,
37
+ out_channels=None,
38
+ name="conv",
39
+ ):
40
+ super().__init__()
41
+ self.channels = channels
42
+ self.out_channels = out_channels or channels
43
+ self.use_conv = use_conv
44
+ self.use_conv_transpose = use_conv_transpose
45
+ self.name = name
46
+
47
+ conv = None
48
+ if use_conv_transpose:
49
+ raise NotImplementedError
50
+ elif use_conv:
51
+ self.conv = InflatedConv3d(self.channels, self.out_channels, 3, padding=1)
52
+
53
+ def forward(self, hidden_states, output_size=None):
54
+ assert hidden_states.shape[1] == self.channels
55
+
56
+ if self.use_conv_transpose:
57
+ raise NotImplementedError
58
+
59
+ # Cast to float32 because the 'upsample_nearest2d_out_frame' op does not support bfloat16
60
+ dtype = hidden_states.dtype
61
+ if dtype == torch.bfloat16:
62
+ hidden_states = hidden_states.to(torch.float32)
63
+
64
+ # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
65
+ if hidden_states.shape[0] >= 64:
66
+ hidden_states = hidden_states.contiguous()
67
+
68
+ # if `output_size` is passed we force the interpolation output
69
+ # size and do not make use of `scale_factor=2`
70
+ if output_size is None:
71
+ hidden_states = F.interpolate(
72
+ hidden_states, scale_factor=[1.0, 2.0, 2.0], mode="nearest"
73
+ )
74
+ else:
75
+ hidden_states = F.interpolate(
76
+ hidden_states, size=output_size, mode="nearest"
77
+ )
78
+
79
+ # If the input is bfloat16, we cast back to bfloat16
80
+ if dtype == torch.bfloat16:
81
+ hidden_states = hidden_states.to(dtype)
82
+
83
+ # if self.use_conv:
84
+ # if self.name == "conv":
85
+ # hidden_states = self.conv(hidden_states)
86
+ # else:
87
+ # hidden_states = self.Conv2d_0(hidden_states)
88
+ hidden_states = self.conv(hidden_states)
89
+
90
+ return hidden_states
91
+
92
+
93
+ class Downsample3D(nn.Module):
94
+ def __init__(
95
+ self, channels, use_conv=False, out_channels=None, padding=1, name="conv"
96
+ ):
97
+ super().__init__()
98
+ self.channels = channels
99
+ self.out_channels = out_channels or channels
100
+ self.use_conv = use_conv
101
+ self.padding = padding
102
+ stride = 2
103
+ self.name = name
104
+
105
+ if use_conv:
106
+ self.conv = InflatedConv3d(
107
+ self.channels, self.out_channels, 3, stride=stride, padding=padding
108
+ )
109
+ else:
110
+ raise NotImplementedError
111
+
112
+ def forward(self, hidden_states):
113
+ assert hidden_states.shape[1] == self.channels
114
+ if self.use_conv and self.padding == 0:
115
+ raise NotImplementedError
116
+
117
+ assert hidden_states.shape[1] == self.channels
118
+ hidden_states = self.conv(hidden_states)
119
+
120
+ return hidden_states
121
+
122
+
123
+ class ResnetBlock3D(nn.Module):
124
+ def __init__(
125
+ self,
126
+ *,
127
+ in_channels,
128
+ out_channels=None,
129
+ conv_shortcut=False,
130
+ dropout=0.0,
131
+ temb_channels=512,
132
+ groups=32,
133
+ groups_out=None,
134
+ pre_norm=True,
135
+ eps=1e-6,
136
+ non_linearity="swish",
137
+ time_embedding_norm="default",
138
+ output_scale_factor=1.0,
139
+ use_in_shortcut=None,
140
+ use_inflated_groupnorm=None,
141
+ ):
142
+ super().__init__()
143
+ self.pre_norm = pre_norm
144
+ self.pre_norm = True
145
+ self.in_channels = in_channels
146
+ out_channels = in_channels if out_channels is None else out_channels
147
+ self.out_channels = out_channels
148
+ self.use_conv_shortcut = conv_shortcut
149
+ self.time_embedding_norm = time_embedding_norm
150
+ self.output_scale_factor = output_scale_factor
151
+
152
+ if groups_out is None:
153
+ groups_out = groups
154
+
155
+ assert use_inflated_groupnorm is not None
156
+ if use_inflated_groupnorm:
157
+ self.norm1 = InflatedGroupNorm(
158
+ num_groups=groups, num_channels=in_channels, eps=eps, affine=True
159
+ )
160
+ else:
161
+ self.norm1 = torch.nn.GroupNorm(
162
+ num_groups=groups, num_channels=in_channels, eps=eps, affine=True
163
+ )
164
+
165
+ self.conv1 = InflatedConv3d(
166
+ in_channels, out_channels, kernel_size=3, stride=1, padding=1
167
+ )
168
+
169
+ if temb_channels is not None:
170
+ if self.time_embedding_norm == "default":
171
+ time_emb_proj_out_channels = out_channels
172
+ elif self.time_embedding_norm == "scale_shift":
173
+ time_emb_proj_out_channels = out_channels * 2
174
+ else:
175
+ raise ValueError(
176
+ f"unknown time_embedding_norm : {self.time_embedding_norm} "
177
+ )
178
+
179
+ self.time_emb_proj = torch.nn.Linear(
180
+ temb_channels, time_emb_proj_out_channels
181
+ )
182
+ else:
183
+ self.time_emb_proj = None
184
+
185
+ if use_inflated_groupnorm:
186
+ self.norm2 = InflatedGroupNorm(
187
+ num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True
188
+ )
189
+ else:
190
+ self.norm2 = torch.nn.GroupNorm(
191
+ num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True
192
+ )
193
+ self.dropout = torch.nn.Dropout(dropout)
194
+ self.conv2 = InflatedConv3d(
195
+ out_channels, out_channels, kernel_size=3, stride=1, padding=1
196
+ )
197
+
198
+ if non_linearity == "swish":
199
+ self.nonlinearity = lambda x: F.silu(x)
200
+ elif non_linearity == "mish":
201
+ self.nonlinearity = Mish()
202
+ elif non_linearity == "silu":
203
+ self.nonlinearity = nn.SiLU()
204
+
205
+ self.use_in_shortcut = (
206
+ self.in_channels != self.out_channels
207
+ if use_in_shortcut is None
208
+ else use_in_shortcut
209
+ )
210
+
211
+ self.conv_shortcut = None
212
+ if self.use_in_shortcut:
213
+ self.conv_shortcut = InflatedConv3d(
214
+ in_channels, out_channels, kernel_size=1, stride=1, padding=0
215
+ )
216
+
217
+ def forward(self, input_tensor, temb):
218
+ hidden_states = input_tensor
219
+
220
+ hidden_states = self.norm1(hidden_states)
221
+ hidden_states = self.nonlinearity(hidden_states)
222
+
223
+ hidden_states = self.conv1(hidden_states)
224
+
225
+ if temb is not None:
226
+ temb = self.time_emb_proj(self.nonlinearity(temb))
227
+ if len(temb.shape) == 2:
228
+ temb = temb[:, :, None, None, None]
229
+ elif len(temb.shape) == 3:
230
+ temb = temb[:, :, :, None, None].permute(0, 2, 1, 3, 4)
231
+
232
+ if temb is not None and self.time_embedding_norm == "default":
233
+ hidden_states = hidden_states + temb
234
+
235
+ hidden_states = self.norm2(hidden_states)
236
+
237
+ if temb is not None and self.time_embedding_norm == "scale_shift":
238
+ scale, shift = torch.chunk(temb, 2, dim=1)
239
+ hidden_states = hidden_states * (1 + scale) + shift
240
+
241
+ hidden_states = self.nonlinearity(hidden_states)
242
+
243
+ hidden_states = self.dropout(hidden_states)
244
+ hidden_states = self.conv2(hidden_states)
245
+
246
+ if self.conv_shortcut is not None:
247
+ input_tensor = self.conv_shortcut(input_tensor)
248
+
249
+ output_tensor = (input_tensor + hidden_states) / self.output_scale_factor
250
+
251
+ return output_tensor
252
+
253
+
254
+ class Mish(torch.nn.Module):
255
+ def forward(self, hidden_states):
256
+ return hidden_states * torch.tanh(torch.nn.functional.softplus(hidden_states))
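
A short, self-contained shape check for the inflated ResNet block above; the channel and timestep-embedding widths are illustrative assumptions. Note that `use_inflated_groupnorm` must be set explicitly because of the assertion in `__init__`.

```python
# Hypothetical shape check for ResnetBlock3D on a (b, c, f, h, w) feature map.
import torch
from modules.resnet import ResnetBlock3D

block = ResnetBlock3D(
    in_channels=320,
    out_channels=640,
    temb_channels=1280,             # width of the timestep embedding (assumption)
    use_inflated_groupnorm=True,    # required: the constructor asserts it is not None
)

x = torch.randn(1, 320, 8, 32, 32)   # (batch, channels, frames, height, width)
temb = torch.randn(1, 1280)          # per-sample timestep embedding
out = block(x, temb)
print(out.shape)                      # torch.Size([1, 640, 8, 32, 32])
```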
modules/transformer_2d.py ADDED
@@ -0,0 +1,401 @@
1
+ # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/transformer_2d.py
2
+ from dataclasses import dataclass
3
+ from typing import Any, Dict, Optional
4
+
5
+ import torch
6
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
7
+
8
+ try:
9
+ from diffusers.models.embeddings import CaptionProjection
10
+ except ImportError:
11
+ from diffusers.models.embeddings import PixArtAlphaTextProjection as CaptionProjection
12
+
13
+ from diffusers.models.lora import LoRACompatibleConv, LoRACompatibleLinear
14
+ from diffusers.models.modeling_utils import ModelMixin
15
+ from diffusers.models.normalization import AdaLayerNormSingle
16
+ from diffusers.utils import USE_PEFT_BACKEND, BaseOutput, deprecate, is_torch_version
17
+ from torch import nn
18
+
19
+ from .attention import BasicTransformerBlock
20
+
21
+
22
+ @dataclass
23
+ class Transformer2DModelOutput(BaseOutput):
24
+ """
25
+ The output of [`Transformer2DModel`].
26
+
27
+ Args:
28
+ sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete):
29
+ The hidden states output conditioned on the `encoder_hidden_states` input. If discrete, returns probability
30
+ distributions for the unnoised latent pixels.
31
+ """
32
+
33
+ sample: torch.FloatTensor
34
+ ref_feature: torch.FloatTensor
35
+
36
+
37
+ class Transformer2DModel(ModelMixin, ConfigMixin):
38
+ """
39
+ A 2D Transformer model for image-like data.
40
+
41
+ Parameters:
42
+ num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
43
+ attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
44
+ in_channels (`int`, *optional*):
45
+ The number of channels in the input and output (specify if the input is **continuous**).
46
+ num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
47
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
48
+ cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
49
+ sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**).
50
+ This is fixed during training since it is used to learn a number of position embeddings.
51
+ num_vector_embeds (`int`, *optional*):
52
+ The number of classes of the vector embeddings of the latent pixels (specify if the input is **discrete**).
53
+ Includes the class for the masked latent pixel.
54
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward.
55
+ num_embeds_ada_norm ( `int`, *optional*):
56
+ The number of diffusion steps used during training. Pass if at least one of the norm_layers is
57
+ `AdaLayerNorm`. This is fixed during training since it is used to learn a number of embeddings that are
58
+ added to the hidden states.
59
+
60
+ During inference, you can denoise for up to but not more steps than `num_embeds_ada_norm`.
61
+ attention_bias (`bool`, *optional*):
62
+ Configure if the `TransformerBlocks` attention should contain a bias parameter.
63
+ """
64
+
65
+ _supports_gradient_checkpointing = True
66
+
67
+ @register_to_config
68
+ def __init__(
69
+ self,
70
+ num_attention_heads: int = 16,
71
+ attention_head_dim: int = 88,
72
+ in_channels: Optional[int] = None,
73
+ out_channels: Optional[int] = None,
74
+ num_layers: int = 1,
75
+ dropout: float = 0.0,
76
+ norm_num_groups: int = 32,
77
+ cross_attention_dim: Optional[int] = None,
78
+ attention_bias: bool = False,
79
+ sample_size: Optional[int] = None,
80
+ num_vector_embeds: Optional[int] = None,
81
+ patch_size: Optional[int] = None,
82
+ activation_fn: str = "geglu",
83
+ num_embeds_ada_norm: Optional[int] = None,
84
+ use_linear_projection: bool = False,
85
+ only_cross_attention: bool = False,
86
+ double_self_attention: bool = False,
87
+ upcast_attention: bool = False,
88
+ norm_type: str = "layer_norm",
89
+ norm_elementwise_affine: bool = True,
90
+ norm_eps: float = 1e-5,
91
+ attention_type: str = "default",
92
+ caption_channels: int = None,
93
+ ):
94
+ super().__init__()
95
+ self.use_linear_projection = use_linear_projection
96
+ self.num_attention_heads = num_attention_heads
97
+ self.attention_head_dim = attention_head_dim
98
+ inner_dim = num_attention_heads * attention_head_dim
99
+
100
+ conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv
101
+ linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear
102
+
103
+ # 1. Transformer2DModel can process both standard continuous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)`
104
+ # Define whether input is continuous or discrete depending on configuration
105
+ self.is_input_continuous = (in_channels is not None) and (patch_size is None)
106
+ self.is_input_vectorized = num_vector_embeds is not None
107
+ self.is_input_patches = in_channels is not None and patch_size is not None
108
+
109
+ if norm_type == "layer_norm" and num_embeds_ada_norm is not None:
110
+ deprecation_message = (
111
+ f"The configuration file of this model: {self.__class__} is outdated. `norm_type` is either not set or"
112
+ " incorrectly set to `'layer_norm'`. Make sure to set `norm_type` to `'ada_norm'` in the config."
113
+ " Please make sure to update the config accordingly, as leaving `norm_type` unset might lead to incorrect"
114
+ " results in future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it"
115
+ " would be very nice if you could open a Pull request for the `transformer/config.json` file"
116
+ )
117
+ deprecate(
118
+ "norm_type!=num_embeds_ada_norm",
119
+ "1.0.0",
120
+ deprecation_message,
121
+ standard_warn=False,
122
+ )
123
+ norm_type = "ada_norm"
124
+
125
+ if self.is_input_continuous and self.is_input_vectorized:
126
+ raise ValueError(
127
+ f"Cannot define both `in_channels`: {in_channels} and `num_vector_embeds`: {num_vector_embeds}. Make"
128
+ " sure that either `in_channels` or `num_vector_embeds` is None."
129
+ )
130
+ elif self.is_input_vectorized and self.is_input_patches:
131
+ raise ValueError(
132
+ f"Cannot define both `num_vector_embeds`: {num_vector_embeds} and `patch_size`: {patch_size}. Make"
133
+ " sure that either `num_vector_embeds` or `num_patches` is None."
134
+ )
135
+ elif (
136
+ not self.is_input_continuous
137
+ and not self.is_input_vectorized
138
+ and not self.is_input_patches
139
+ ):
140
+ raise ValueError(
141
+ f"Has to define `in_channels`: {in_channels}, `num_vector_embeds`: {num_vector_embeds}, or patch_size:"
142
+ f" {patch_size}. Make sure that `in_channels`, `num_vector_embeds` or `num_patches` is not None."
143
+ )
144
+
145
+ # 2. Define input layers
146
+ self.in_channels = in_channels
147
+
148
+ self.norm = torch.nn.GroupNorm(
149
+ num_groups=norm_num_groups,
150
+ num_channels=in_channels,
151
+ eps=1e-6,
152
+ affine=True,
153
+ )
154
+ if use_linear_projection:
155
+ self.proj_in = linear_cls(in_channels, inner_dim)
156
+ else:
157
+ self.proj_in = conv_cls(
158
+ in_channels, inner_dim, kernel_size=1, stride=1, padding=0
159
+ )
160
+
161
+ # 3. Define transformers blocks
162
+ self.transformer_blocks = nn.ModuleList(
163
+ [
164
+ BasicTransformerBlock(
165
+ inner_dim,
166
+ num_attention_heads,
167
+ attention_head_dim,
168
+ dropout=dropout,
169
+ cross_attention_dim=cross_attention_dim,
170
+ activation_fn=activation_fn,
171
+ num_embeds_ada_norm=num_embeds_ada_norm,
172
+ attention_bias=attention_bias,
173
+ only_cross_attention=only_cross_attention,
174
+ double_self_attention=double_self_attention,
175
+ upcast_attention=upcast_attention,
176
+ norm_type=norm_type,
177
+ norm_elementwise_affine=norm_elementwise_affine,
178
+ norm_eps=norm_eps,
179
+ attention_type=attention_type,
180
+ )
181
+ for d in range(num_layers)
182
+ ]
183
+ )
184
+
185
+ # 4. Define output layers
186
+ self.out_channels = in_channels if out_channels is None else out_channels
187
+ # TODO: should use out_channels for continuous projections
188
+ if use_linear_projection:
189
+ self.proj_out = linear_cls(inner_dim, in_channels)
190
+ else:
191
+ self.proj_out = conv_cls(
192
+ inner_dim, in_channels, kernel_size=1, stride=1, padding=0
193
+ )
194
+
195
+ # 5. PixArt-Alpha blocks.
196
+ self.adaln_single = None
197
+ self.use_additional_conditions = False
198
+ if norm_type == "ada_norm_single":
199
+ self.use_additional_conditions = self.config.sample_size == 128
200
+ # TODO(Sayak, PVP) clean this, for now we use sample size to determine whether to use
201
+ # additional conditions until we find better name
202
+ self.adaln_single = AdaLayerNormSingle(
203
+ inner_dim, use_additional_conditions=self.use_additional_conditions
204
+ )
205
+
206
+ self.caption_projection = None
207
+ if caption_channels is not None:
208
+ self.caption_projection = CaptionProjection(
209
+ in_features=caption_channels, hidden_size=inner_dim
210
+ )
211
+
212
+ self.gradient_checkpointing = False
213
+
214
+ def _set_gradient_checkpointing(self, module, value=False):
215
+ if hasattr(module, "gradient_checkpointing"):
216
+ module.gradient_checkpointing = value
217
+
218
+ def forward(
219
+ self,
220
+ hidden_states: torch.Tensor,
221
+ encoder_hidden_states: Optional[torch.Tensor] = None,
222
+ timestep: Optional[torch.LongTensor] = None,
223
+ added_cond_kwargs: Dict[str, torch.Tensor] = None,
224
+ class_labels: Optional[torch.LongTensor] = None,
225
+ cross_attention_kwargs: Dict[str, Any] = None,
226
+ attention_mask: Optional[torch.Tensor] = None,
227
+ encoder_attention_mask: Optional[torch.Tensor] = None,
228
+ return_dict: bool = True,
229
+ ):
230
+ """
231
+ The [`Transformer2DModel`] forward method.
232
+
233
+ Args:
234
+ hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous):
235
+ Input `hidden_states`.
236
+ encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
237
+ Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
238
+ self-attention.
239
+ timestep ( `torch.LongTensor`, *optional*):
240
+ Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`.
241
+ class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
242
+ Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in
243
+ `AdaLayerZeroNorm`.
244
+ cross_attention_kwargs ( `Dict[str, Any]`, *optional*):
245
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
246
+ `self.processor` in
247
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
248
+ attention_mask ( `torch.Tensor`, *optional*):
249
+ An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
250
+ is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
251
+ negative values to the attention scores corresponding to "discard" tokens.
252
+ encoder_attention_mask ( `torch.Tensor`, *optional*):
253
+ Cross-attention mask applied to `encoder_hidden_states`. Two formats supported:
254
+
255
+ * Mask `(batch, sequence_length)` True = keep, False = discard.
256
+ * Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard.
257
+
258
+ If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format
259
+ above. This bias will be added to the cross-attention scores.
260
+ return_dict (`bool`, *optional*, defaults to `True`):
261
+ Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
262
+ tuple.
263
+
264
+ Returns:
265
+ If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
266
+ `tuple` where the first element is the sample tensor.
267
+ """
268
+ # ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
269
+ # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
270
+ # we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
271
+ # expects mask of shape:
272
+ # [batch, key_tokens]
273
+ # adds singleton query_tokens dimension:
274
+ # [batch, 1, key_tokens]
275
+ # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
276
+ # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
277
+ # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
278
+ if attention_mask is not None and attention_mask.ndim == 2:
279
+ # assume that mask is expressed as:
280
+ # (1 = keep, 0 = discard)
281
+ # convert mask into a bias that can be added to attention scores:
282
+ # (keep = +0, discard = -10000.0)
283
+ attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
284
+ attention_mask = attention_mask.unsqueeze(1)
285
+
286
+ # convert encoder_attention_mask to a bias the same way we do for attention_mask
287
+ if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
288
+ encoder_attention_mask = (
289
+ 1 - encoder_attention_mask.to(hidden_states.dtype)
290
+ ) * -10000.0
291
+ encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
292
+
293
+ # Retrieve lora scale.
294
+ lora_scale = (
295
+ cross_attention_kwargs.get("scale", 1.0)
296
+ if cross_attention_kwargs is not None
297
+ else 1.0
298
+ )
299
+
300
+ # 1. Input
301
+ batch, _, height, width = hidden_states.shape
302
+ residual = hidden_states
303
+
304
+ hidden_states = self.norm(hidden_states)
305
+ if not self.use_linear_projection:
306
+ hidden_states = (
307
+ self.proj_in(hidden_states, scale=lora_scale)
308
+ if not USE_PEFT_BACKEND
309
+ else self.proj_in(hidden_states)
310
+ )
311
+ inner_dim = hidden_states.shape[1]
312
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(
313
+ batch, height * width, inner_dim
314
+ )
315
+ else:
316
+ inner_dim = hidden_states.shape[1]
317
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(
318
+ batch, height * width, inner_dim
319
+ )
320
+ hidden_states = (
321
+ self.proj_in(hidden_states, scale=lora_scale)
322
+ if not USE_PEFT_BACKEND
323
+ else self.proj_in(hidden_states)
324
+ )
325
+
326
+ # 2. Blocks
327
+ if self.caption_projection is not None:
328
+ batch_size = hidden_states.shape[0]
329
+ encoder_hidden_states = self.caption_projection(encoder_hidden_states)
330
+ encoder_hidden_states = encoder_hidden_states.view(
331
+ batch_size, -1, hidden_states.shape[-1]
332
+ )
333
+
334
+ ref_feature = hidden_states.reshape(batch, height, width, inner_dim)
335
+ for block in self.transformer_blocks:
336
+ if self.training and self.gradient_checkpointing:
337
+
338
+ def create_custom_forward(module, return_dict=None):
339
+ def custom_forward(*inputs):
340
+ if return_dict is not None:
341
+ return module(*inputs, return_dict=return_dict)
342
+ else:
343
+ return module(*inputs)
344
+
345
+ return custom_forward
346
+
347
+ ckpt_kwargs: Dict[str, Any] = (
348
+ {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
349
+ )
350
+ hidden_states = torch.utils.checkpoint.checkpoint(
351
+ create_custom_forward(block),
352
+ hidden_states,
353
+ attention_mask,
354
+ encoder_hidden_states,
355
+ encoder_attention_mask,
356
+ timestep,
357
+ cross_attention_kwargs,
358
+ class_labels,
359
+ **ckpt_kwargs,
360
+ )
361
+ else:
362
+ hidden_states = block(
363
+ hidden_states,
364
+ attention_mask=attention_mask,
365
+ encoder_hidden_states=encoder_hidden_states,
366
+ encoder_attention_mask=encoder_attention_mask,
367
+ timestep=timestep,
368
+ cross_attention_kwargs=cross_attention_kwargs,
369
+ class_labels=class_labels,
370
+ )
371
+
372
+ # 3. Output
373
+ if self.is_input_continuous:
374
+ if not self.use_linear_projection:
375
+ hidden_states = (
376
+ hidden_states.reshape(batch, height, width, inner_dim)
377
+ .permute(0, 3, 1, 2)
378
+ .contiguous()
379
+ )
380
+ hidden_states = (
381
+ self.proj_out(hidden_states, scale=lora_scale)
382
+ if not USE_PEFT_BACKEND
383
+ else self.proj_out(hidden_states)
384
+ )
385
+ else:
386
+ hidden_states = (
387
+ self.proj_out(hidden_states, scale=lora_scale)
388
+ if not USE_PEFT_BACKEND
389
+ else self.proj_out(hidden_states)
390
+ )
391
+ hidden_states = (
392
+ hidden_states.reshape(batch, height, width, inner_dim)
393
+ .permute(0, 3, 1, 2)
394
+ .contiguous()
395
+ )
396
+
397
+ output = hidden_states + residual
398
+ if not return_dict:
399
+ return (output, ref_feature)
400
+
401
+ return Transformer2DModelOutput(sample=output, ref_feature=ref_feature)
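
The forward method above returns both the residual-added sample and `ref_feature`, the spatial feature map captured before the transformer blocks (presumably consumed by the reference-attention modules elsewhere in this upload). The following minimal usage sketch is not part of the uploaded files; the channel counts, tensor shapes, and the `modules.transformer_2d` import path are illustrative assumptions.

# Minimal usage sketch (illustration only; shapes and sizes are assumed values).
import torch
from modules.transformer_2d import Transformer2DModel

model = Transformer2DModel(
    num_attention_heads=8,
    attention_head_dim=40,            # inner_dim = 8 * 40 = 320
    in_channels=320,
    num_layers=1,
    cross_attention_dim=768,
    norm_num_groups=32,
)
latents = torch.randn(2, 320, 32, 32)        # (batch, channels, height, width)
text_embeds = torch.randn(2, 77, 768)        # (batch, tokens, cross_attention_dim)
sample, ref_feature = model(
    latents,
    encoder_hidden_states=text_embeds,
    return_dict=False,
)
print(sample.shape)        # torch.Size([2, 320, 32, 32])
print(ref_feature.shape)   # torch.Size([2, 32, 32, 320])  (batch, height, width, inner_dim)
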
modules/transformer_3d.py ADDED
@@ -0,0 +1,169 @@
1
+ from dataclasses import dataclass
2
+ from typing import Optional
3
+
4
+ import torch
5
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
6
+ from diffusers.models import ModelMixin
7
+ from diffusers.utils import BaseOutput
8
+ from diffusers.utils.import_utils import is_xformers_available
9
+ from einops import rearrange, repeat
10
+ from torch import nn
11
+
12
+ from .attention import TemporalBasicTransformerBlock
13
+
14
+
15
+ @dataclass
16
+ class Transformer3DModelOutput(BaseOutput):
17
+ sample: torch.FloatTensor
18
+
19
+
20
+ if is_xformers_available():
21
+ import xformers
22
+ import xformers.ops
23
+ else:
24
+ xformers = None
25
+
26
+
27
+ class Transformer3DModel(ModelMixin, ConfigMixin):
28
+ _supports_gradient_checkpointing = True
29
+
30
+ @register_to_config
31
+ def __init__(
32
+ self,
33
+ num_attention_heads: int = 16,
34
+ attention_head_dim: int = 88,
35
+ in_channels: Optional[int] = None,
36
+ num_layers: int = 1,
37
+ dropout: float = 0.0,
38
+ norm_num_groups: int = 32,
39
+ cross_attention_dim: Optional[int] = None,
40
+ attention_bias: bool = False,
41
+ activation_fn: str = "geglu",
42
+ num_embeds_ada_norm: Optional[int] = None,
43
+ use_linear_projection: bool = False,
44
+ only_cross_attention: bool = False,
45
+ upcast_attention: bool = False,
46
+ unet_use_cross_frame_attention=None,
47
+ unet_use_temporal_attention=None,
48
+ ):
49
+ super().__init__()
50
+ self.use_linear_projection = use_linear_projection
51
+ self.num_attention_heads = num_attention_heads
52
+ self.attention_head_dim = attention_head_dim
53
+ inner_dim = num_attention_heads * attention_head_dim
54
+
55
+ # Define input layers
56
+ self.in_channels = in_channels
57
+
58
+ self.norm = torch.nn.GroupNorm(
59
+ num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True
60
+ )
61
+ if use_linear_projection:
62
+ self.proj_in = nn.Linear(in_channels, inner_dim)
63
+ else:
64
+ self.proj_in = nn.Conv2d(
65
+ in_channels, inner_dim, kernel_size=1, stride=1, padding=0
66
+ )
67
+
68
+ # Define transformers blocks
69
+ self.transformer_blocks = nn.ModuleList(
70
+ [
71
+ TemporalBasicTransformerBlock(
72
+ inner_dim,
73
+ num_attention_heads,
74
+ attention_head_dim,
75
+ dropout=dropout,
76
+ cross_attention_dim=cross_attention_dim,
77
+ activation_fn=activation_fn,
78
+ num_embeds_ada_norm=num_embeds_ada_norm,
79
+ attention_bias=attention_bias,
80
+ only_cross_attention=only_cross_attention,
81
+ upcast_attention=upcast_attention,
82
+ unet_use_cross_frame_attention=unet_use_cross_frame_attention,
83
+ unet_use_temporal_attention=unet_use_temporal_attention,
84
+ )
85
+ for d in range(num_layers)
86
+ ]
87
+ )
88
+
89
+ # 4. Define output layers
90
+ if use_linear_projection:
91
+ self.proj_out = nn.Linear(in_channels, inner_dim)
92
+ else:
93
+ self.proj_out = nn.Conv2d(
94
+ inner_dim, in_channels, kernel_size=1, stride=1, padding=0
95
+ )
96
+
97
+ self.gradient_checkpointing = False
98
+
99
+ def _set_gradient_checkpointing(self, module, value=False):
100
+ if hasattr(module, "gradient_checkpointing"):
101
+ module.gradient_checkpointing = value
102
+
103
+ def forward(
104
+ self,
105
+ hidden_states,
106
+ encoder_hidden_states=None,
107
+ timestep=None,
108
+ return_dict: bool = True,
109
+ ):
110
+ # Input
111
+ assert (
112
+ hidden_states.dim() == 5
113
+ ), f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}."
114
+ video_length = hidden_states.shape[2]
115
+ hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w")
116
+ if encoder_hidden_states.shape[0] != hidden_states.shape[0]:
117
+ encoder_hidden_states = repeat(
118
+ encoder_hidden_states, "b n c -> (b f) n c", f=video_length
119
+ )
120
+
121
+ batch, channel, height, width = hidden_states.shape
122
+ residual = hidden_states
123
+
124
+ hidden_states = self.norm(hidden_states)
125
+ if not self.use_linear_projection:
126
+ hidden_states = self.proj_in(hidden_states)
127
+ inner_dim = hidden_states.shape[1]
128
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(
129
+ batch, height * width, inner_dim
130
+ )
131
+ else:
132
+ inner_dim = hidden_states.shape[1]
133
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(
134
+ batch, height * width, inner_dim
135
+ )
136
+ hidden_states = self.proj_in(hidden_states)
137
+
138
+ # Blocks
139
+ for i, block in enumerate(self.transformer_blocks):
140
+ hidden_states = block(
141
+ hidden_states,
142
+ encoder_hidden_states=encoder_hidden_states,
143
+ timestep=timestep,
144
+ video_length=video_length,
145
+ )
146
+
147
+ # Output
148
+ if not self.use_linear_projection:
149
+ hidden_states = (
150
+ hidden_states.reshape(batch, height, width, inner_dim)
151
+ .permute(0, 3, 1, 2)
152
+ .contiguous()
153
+ )
154
+ hidden_states = self.proj_out(hidden_states)
155
+ else:
156
+ hidden_states = self.proj_out(hidden_states)
157
+ hidden_states = (
158
+ hidden_states.reshape(batch, height, width, inner_dim)
159
+ .permute(0, 3, 1, 2)
160
+ .contiguous()
161
+ )
162
+
163
+ output = hidden_states + residual
164
+
165
+ output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length)
166
+ if not return_dict:
167
+ return (output,)
168
+
169
+ return Transformer3DModelOutput(sample=output)
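
Transformer3DModel folds the frame axis into the batch axis (`b c f h w -> (b f) c h w`), repeats the conditioning embeddings once per frame, runs the temporal transformer blocks, and restores the 5-D layout before returning. A hedged sketch of that round trip follows; the 16-frame latent shape, channel counts, and import path are assumptions chosen for illustration, not part of the upload.

# Illustrative sketch only; shapes and the import path are assumed values.
import torch
from modules.transformer_3d import Transformer3DModel

model = Transformer3DModel(
    num_attention_heads=8,
    attention_head_dim=40,        # inner_dim = 320
    in_channels=320,
    num_layers=1,
    cross_attention_dim=768,
)
video_latents = torch.randn(1, 320, 16, 32, 32)   # (batch, channels, frames, height, width)
text_embeds = torch.randn(1, 77, 768)             # repeated per frame inside forward()
out = model(video_latents, encoder_hidden_states=text_embeds)
print(out.sample.shape)   # torch.Size([1, 320, 16, 32, 32])
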
modules/unet_2d_blocks.py ADDED
@@ -0,0 +1,1072 @@
1
+ # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unet_2d_blocks.py
2
+ from typing import Any, Dict, Optional, Tuple, Union
3
+
4
+ import torch
5
+ from diffusers.models.activations import get_activation
6
+ from diffusers.models.attention_processor import Attention
7
+ from diffusers.models.dual_transformer_2d import DualTransformer2DModel
8
+ from diffusers.models.resnet import Downsample2D, ResnetBlock2D, Upsample2D
9
+ from diffusers.utils import is_torch_version, logging
10
+ from diffusers.utils.torch_utils import apply_freeu
11
+ from torch import nn
12
+
13
+ from .transformer_2d import Transformer2DModel
14
+
15
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
16
+
17
+
18
+ def get_down_block(
19
+ down_block_type: str,
20
+ num_layers: int,
21
+ in_channels: int,
22
+ out_channels: int,
23
+ temb_channels: int,
24
+ add_downsample: bool,
25
+ resnet_eps: float,
26
+ resnet_act_fn: str,
27
+ transformer_layers_per_block: int = 1,
28
+ num_attention_heads: Optional[int] = None,
29
+ resnet_groups: Optional[int] = None,
30
+ cross_attention_dim: Optional[int] = None,
31
+ downsample_padding: Optional[int] = None,
32
+ dual_cross_attention: bool = False,
33
+ use_linear_projection: bool = False,
34
+ only_cross_attention: bool = False,
35
+ upcast_attention: bool = False,
36
+ resnet_time_scale_shift: str = "default",
37
+ attention_type: str = "default",
38
+ resnet_skip_time_act: bool = False,
39
+ resnet_out_scale_factor: float = 1.0,
40
+ cross_attention_norm: Optional[str] = None,
41
+ attention_head_dim: Optional[int] = None,
42
+ downsample_type: Optional[str] = None,
43
+ dropout: float = 0.0,
44
+ ):
45
+ # If attn head dim is not defined, we default it to the number of heads
46
+ if attention_head_dim is None:
47
+ logger.warning(
48
+ f"It is recommended to provide `attention_head_dim` when calling `get_down_block`. Defaulting `attention_head_dim` to {num_attention_heads}."
49
+ )
50
+ attention_head_dim = num_attention_heads
51
+
52
+ down_block_type = (
53
+ down_block_type[7:]
54
+ if down_block_type.startswith("UNetRes")
55
+ else down_block_type
56
+ )
57
+ if down_block_type == "DownBlock2D":
58
+ return DownBlock2D(
59
+ num_layers=num_layers,
60
+ in_channels=in_channels,
61
+ out_channels=out_channels,
62
+ temb_channels=temb_channels,
63
+ dropout=dropout,
64
+ add_downsample=add_downsample,
65
+ resnet_eps=resnet_eps,
66
+ resnet_act_fn=resnet_act_fn,
67
+ resnet_groups=resnet_groups,
68
+ downsample_padding=downsample_padding,
69
+ resnet_time_scale_shift=resnet_time_scale_shift,
70
+ )
71
+ elif down_block_type == "CrossAttnDownBlock2D":
72
+ if cross_attention_dim is None:
73
+ raise ValueError(
74
+ "cross_attention_dim must be specified for CrossAttnDownBlock2D"
75
+ )
76
+ return CrossAttnDownBlock2D(
77
+ num_layers=num_layers,
78
+ transformer_layers_per_block=transformer_layers_per_block,
79
+ in_channels=in_channels,
80
+ out_channels=out_channels,
81
+ temb_channels=temb_channels,
82
+ dropout=dropout,
83
+ add_downsample=add_downsample,
84
+ resnet_eps=resnet_eps,
85
+ resnet_act_fn=resnet_act_fn,
86
+ resnet_groups=resnet_groups,
87
+ downsample_padding=downsample_padding,
88
+ cross_attention_dim=cross_attention_dim,
89
+ num_attention_heads=num_attention_heads,
90
+ dual_cross_attention=dual_cross_attention,
91
+ use_linear_projection=use_linear_projection,
92
+ only_cross_attention=only_cross_attention,
93
+ upcast_attention=upcast_attention,
94
+ resnet_time_scale_shift=resnet_time_scale_shift,
95
+ attention_type=attention_type,
96
+ )
97
+ raise ValueError(f"{down_block_type} does not exist.")
98
+
99
+
100
+ def get_up_block(
101
+ up_block_type: str,
102
+ num_layers: int,
103
+ in_channels: int,
104
+ out_channels: int,
105
+ prev_output_channel: int,
106
+ temb_channels: int,
107
+ add_upsample: bool,
108
+ resnet_eps: float,
109
+ resnet_act_fn: str,
110
+ resolution_idx: Optional[int] = None,
111
+ transformer_layers_per_block: int = 1,
112
+ num_attention_heads: Optional[int] = None,
113
+ resnet_groups: Optional[int] = None,
114
+ cross_attention_dim: Optional[int] = None,
115
+ dual_cross_attention: bool = False,
116
+ use_linear_projection: bool = False,
117
+ only_cross_attention: bool = False,
118
+ upcast_attention: bool = False,
119
+ resnet_time_scale_shift: str = "default",
120
+ attention_type: str = "default",
121
+ resnet_skip_time_act: bool = False,
122
+ resnet_out_scale_factor: float = 1.0,
123
+ cross_attention_norm: Optional[str] = None,
124
+ attention_head_dim: Optional[int] = None,
125
+ upsample_type: Optional[str] = None,
126
+ dropout: float = 0.0,
127
+ ) -> nn.Module:
128
+ # If attn head dim is not defined, we default it to the number of heads
129
+ if attention_head_dim is None:
130
+ logger.warning(
131
+ f"It is recommended to provide `attention_head_dim` when calling `get_up_block`. Defaulting `attention_head_dim` to {num_attention_heads}."
132
+ )
133
+ attention_head_dim = num_attention_heads
134
+
135
+ up_block_type = (
136
+ up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
137
+ )
138
+ if up_block_type == "UpBlock2D":
139
+ return UpBlock2D(
140
+ num_layers=num_layers,
141
+ in_channels=in_channels,
142
+ out_channels=out_channels,
143
+ prev_output_channel=prev_output_channel,
144
+ temb_channels=temb_channels,
145
+ resolution_idx=resolution_idx,
146
+ dropout=dropout,
147
+ add_upsample=add_upsample,
148
+ resnet_eps=resnet_eps,
149
+ resnet_act_fn=resnet_act_fn,
150
+ resnet_groups=resnet_groups,
151
+ resnet_time_scale_shift=resnet_time_scale_shift,
152
+ )
153
+ elif up_block_type == "CrossAttnUpBlock2D":
154
+ if cross_attention_dim is None:
155
+ raise ValueError(
156
+ "cross_attention_dim must be specified for CrossAttnUpBlock2D"
157
+ )
158
+ return CrossAttnUpBlock2D(
159
+ num_layers=num_layers,
160
+ transformer_layers_per_block=transformer_layers_per_block,
161
+ in_channels=in_channels,
162
+ out_channels=out_channels,
163
+ prev_output_channel=prev_output_channel,
164
+ temb_channels=temb_channels,
165
+ resolution_idx=resolution_idx,
166
+ dropout=dropout,
167
+ add_upsample=add_upsample,
168
+ resnet_eps=resnet_eps,
169
+ resnet_act_fn=resnet_act_fn,
170
+ resnet_groups=resnet_groups,
171
+ cross_attention_dim=cross_attention_dim,
172
+ num_attention_heads=num_attention_heads,
173
+ dual_cross_attention=dual_cross_attention,
174
+ use_linear_projection=use_linear_projection,
175
+ only_cross_attention=only_cross_attention,
176
+ upcast_attention=upcast_attention,
177
+ resnet_time_scale_shift=resnet_time_scale_shift,
178
+ attention_type=attention_type,
179
+ )
180
+
181
+ raise ValueError(f"{up_block_type} does not exist.")
182
+
183
+
184
+ class AutoencoderTinyBlock(nn.Module):
185
+ """
186
+ Tiny Autoencoder block used in [`AutoencoderTiny`]. It is a mini residual module consisting of plain conv + ReLU
187
+ blocks.
188
+
189
+ Args:
190
+ in_channels (`int`): The number of input channels.
191
+ out_channels (`int`): The number of output channels.
192
+ act_fn (`str`):
193
+ The activation function to use. Supported values are `"swish"`, `"mish"`, `"gelu"`, and `"relu"`.
194
+
195
+ Returns:
196
+ `torch.FloatTensor`: A tensor with the same shape as the input tensor, but with the number of channels equal to
197
+ `out_channels`.
198
+ """
199
+
200
+ def __init__(self, in_channels: int, out_channels: int, act_fn: str):
201
+ super().__init__()
202
+ act_fn = get_activation(act_fn)
203
+ self.conv = nn.Sequential(
204
+ nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
205
+ act_fn,
206
+ nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
207
+ act_fn,
208
+ nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
209
+ )
210
+ self.skip = (
211
+ nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False)
212
+ if in_channels != out_channels
213
+ else nn.Identity()
214
+ )
215
+ self.fuse = nn.ReLU()
216
+
217
+ def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
218
+ return self.fuse(self.conv(x) + self.skip(x))
219
+
220
+
221
+ class UNetMidBlock2D(nn.Module):
222
+ """
223
+ A 2D UNet mid-block [`UNetMidBlock2D`] with multiple residual blocks and optional attention blocks.
224
+
225
+ Args:
226
+ in_channels (`int`): The number of input channels.
227
+ temb_channels (`int`): The number of temporal embedding channels.
228
+ dropout (`float`, *optional*, defaults to 0.0): The dropout rate.
229
+ num_layers (`int`, *optional*, defaults to 1): The number of residual blocks.
230
+ resnet_eps (`float`, *optional*, defaults to 1e-6): The epsilon value for the resnet blocks.
231
+ resnet_time_scale_shift (`str`, *optional*, defaults to `default`):
232
+ The type of normalization to apply to the time embeddings. This can help to improve the performance of the
233
+ model on tasks with long-range temporal dependencies.
234
+ resnet_act_fn (`str`, *optional*, defaults to `swish`): The activation function for the resnet blocks.
235
+ resnet_groups (`int`, *optional*, defaults to 32):
236
+ The number of groups to use in the group normalization layers of the resnet blocks.
237
+ attn_groups (`Optional[int]`, *optional*, defaults to None): The number of groups for the attention blocks.
238
+ resnet_pre_norm (`bool`, *optional*, defaults to `True`):
239
+ Whether to use pre-normalization for the resnet blocks.
240
+ add_attention (`bool`, *optional*, defaults to `True`): Whether to add attention blocks.
241
+ attention_head_dim (`int`, *optional*, defaults to 1):
242
+ Dimension of a single attention head. The number of attention heads is determined based on this value and
243
+ the number of input channels.
244
+ output_scale_factor (`float`, *optional*, defaults to 1.0): The output scale factor.
245
+
246
+ Returns:
247
+ `torch.FloatTensor`: The output of the last residual block, which is a tensor of shape `(batch_size,
248
+ in_channels, height, width)`.
249
+
250
+ """
251
+
252
+ def __init__(
253
+ self,
254
+ in_channels: int,
255
+ temb_channels: int,
256
+ dropout: float = 0.0,
257
+ num_layers: int = 1,
258
+ resnet_eps: float = 1e-6,
259
+ resnet_time_scale_shift: str = "default", # default, spatial
260
+ resnet_act_fn: str = "swish",
261
+ resnet_groups: int = 32,
262
+ attn_groups: Optional[int] = None,
263
+ resnet_pre_norm: bool = True,
264
+ add_attention: bool = True,
265
+ attention_head_dim: int = 1,
266
+ output_scale_factor: float = 1.0,
267
+ ):
268
+ super().__init__()
269
+ resnet_groups = (
270
+ resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
271
+ )
272
+ self.add_attention = add_attention
273
+
274
+ if attn_groups is None:
275
+ attn_groups = (
276
+ resnet_groups if resnet_time_scale_shift == "default" else None
277
+ )
278
+
279
+ # there is always at least one resnet
280
+ resnets = [
281
+ ResnetBlock2D(
282
+ in_channels=in_channels,
283
+ out_channels=in_channels,
284
+ temb_channels=temb_channels,
285
+ eps=resnet_eps,
286
+ groups=resnet_groups,
287
+ dropout=dropout,
288
+ time_embedding_norm=resnet_time_scale_shift,
289
+ non_linearity=resnet_act_fn,
290
+ output_scale_factor=output_scale_factor,
291
+ pre_norm=resnet_pre_norm,
292
+ )
293
+ ]
294
+ attentions = []
295
+
296
+ if attention_head_dim is None:
297
+ logger.warning(
298
+ f"It is not recommended to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `in_channels`: {in_channels}."
299
+ )
300
+ attention_head_dim = in_channels
301
+
302
+ for _ in range(num_layers):
303
+ if self.add_attention:
304
+ attentions.append(
305
+ Attention(
306
+ in_channels,
307
+ heads=in_channels // attention_head_dim,
308
+ dim_head=attention_head_dim,
309
+ rescale_output_factor=output_scale_factor,
310
+ eps=resnet_eps,
311
+ norm_num_groups=attn_groups,
312
+ spatial_norm_dim=temb_channels
313
+ if resnet_time_scale_shift == "spatial"
314
+ else None,
315
+ residual_connection=True,
316
+ bias=True,
317
+ upcast_softmax=True,
318
+ _from_deprecated_attn_block=True,
319
+ )
320
+ )
321
+ else:
322
+ attentions.append(None)
323
+
324
+ resnets.append(
325
+ ResnetBlock2D(
326
+ in_channels=in_channels,
327
+ out_channels=in_channels,
328
+ temb_channels=temb_channels,
329
+ eps=resnet_eps,
330
+ groups=resnet_groups,
331
+ dropout=dropout,
332
+ time_embedding_norm=resnet_time_scale_shift,
333
+ non_linearity=resnet_act_fn,
334
+ output_scale_factor=output_scale_factor,
335
+ pre_norm=resnet_pre_norm,
336
+ )
337
+ )
338
+
339
+ self.attentions = nn.ModuleList(attentions)
340
+ self.resnets = nn.ModuleList(resnets)
341
+
342
+ def forward(
343
+ self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None
344
+ ) -> torch.FloatTensor:
345
+ hidden_states = self.resnets[0](hidden_states, temb)
346
+ for attn, resnet in zip(self.attentions, self.resnets[1:]):
347
+ if attn is not None:
348
+ hidden_states = attn(hidden_states, temb=temb)
349
+ hidden_states = resnet(hidden_states, temb)
350
+
351
+ return hidden_states
352
+
353
+
354
+ class UNetMidBlock2DCrossAttn(nn.Module):
355
+ def __init__(
356
+ self,
357
+ in_channels: int,
358
+ temb_channels: int,
359
+ dropout: float = 0.0,
360
+ num_layers: int = 1,
361
+ transformer_layers_per_block: Union[int, Tuple[int]] = 1,
362
+ resnet_eps: float = 1e-6,
363
+ resnet_time_scale_shift: str = "default",
364
+ resnet_act_fn: str = "swish",
365
+ resnet_groups: int = 32,
366
+ resnet_pre_norm: bool = True,
367
+ num_attention_heads: int = 1,
368
+ output_scale_factor: float = 1.0,
369
+ cross_attention_dim: int = 1280,
370
+ dual_cross_attention: bool = False,
371
+ use_linear_projection: bool = False,
372
+ upcast_attention: bool = False,
373
+ attention_type: str = "default",
374
+ ):
375
+ super().__init__()
376
+
377
+ self.has_cross_attention = True
378
+ self.num_attention_heads = num_attention_heads
379
+ resnet_groups = (
380
+ resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
381
+ )
382
+
383
+ # support for variable transformer layers per block
384
+ if isinstance(transformer_layers_per_block, int):
385
+ transformer_layers_per_block = [transformer_layers_per_block] * num_layers
386
+
387
+ # there is always at least one resnet
388
+ resnets = [
389
+ ResnetBlock2D(
390
+ in_channels=in_channels,
391
+ out_channels=in_channels,
392
+ temb_channels=temb_channels,
393
+ eps=resnet_eps,
394
+ groups=resnet_groups,
395
+ dropout=dropout,
396
+ time_embedding_norm=resnet_time_scale_shift,
397
+ non_linearity=resnet_act_fn,
398
+ output_scale_factor=output_scale_factor,
399
+ pre_norm=resnet_pre_norm,
400
+ )
401
+ ]
402
+ attentions = []
403
+
404
+ for i in range(num_layers):
405
+ if not dual_cross_attention:
406
+ attentions.append(
407
+ Transformer2DModel(
408
+ num_attention_heads,
409
+ in_channels // num_attention_heads,
410
+ in_channels=in_channels,
411
+ num_layers=transformer_layers_per_block[i],
412
+ cross_attention_dim=cross_attention_dim,
413
+ norm_num_groups=resnet_groups,
414
+ use_linear_projection=use_linear_projection,
415
+ upcast_attention=upcast_attention,
416
+ attention_type=attention_type,
417
+ )
418
+ )
419
+ else:
420
+ attentions.append(
421
+ DualTransformer2DModel(
422
+ num_attention_heads,
423
+ in_channels // num_attention_heads,
424
+ in_channels=in_channels,
425
+ num_layers=1,
426
+ cross_attention_dim=cross_attention_dim,
427
+ norm_num_groups=resnet_groups,
428
+ )
429
+ )
430
+ resnets.append(
431
+ ResnetBlock2D(
432
+ in_channels=in_channels,
433
+ out_channels=in_channels,
434
+ temb_channels=temb_channels,
435
+ eps=resnet_eps,
436
+ groups=resnet_groups,
437
+ dropout=dropout,
438
+ time_embedding_norm=resnet_time_scale_shift,
439
+ non_linearity=resnet_act_fn,
440
+ output_scale_factor=output_scale_factor,
441
+ pre_norm=resnet_pre_norm,
442
+ )
443
+ )
444
+
445
+ self.attentions = nn.ModuleList(attentions)
446
+ self.resnets = nn.ModuleList(resnets)
447
+
448
+ self.gradient_checkpointing = False
449
+
450
+ def forward(
451
+ self,
452
+ hidden_states: torch.FloatTensor,
453
+ temb: Optional[torch.FloatTensor] = None,
454
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
455
+ attention_mask: Optional[torch.FloatTensor] = None,
456
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
457
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
458
+ ) -> torch.FloatTensor:
459
+ lora_scale = (
460
+ cross_attention_kwargs.get("scale", 1.0)
461
+ if cross_attention_kwargs is not None
462
+ else 1.0
463
+ )
464
+ hidden_states = self.resnets[0](hidden_states, temb, scale=lora_scale)
465
+ for attn, resnet in zip(self.attentions, self.resnets[1:]):
466
+ if self.training and self.gradient_checkpointing:
467
+
468
+ def create_custom_forward(module, return_dict=None):
469
+ def custom_forward(*inputs):
470
+ if return_dict is not None:
471
+ return module(*inputs, return_dict=return_dict)
472
+ else:
473
+ return module(*inputs)
474
+
475
+ return custom_forward
476
+
477
+ ckpt_kwargs: Dict[str, Any] = (
478
+ {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
479
+ )
480
+ hidden_states, ref_feature = attn(
481
+ hidden_states,
482
+ encoder_hidden_states=encoder_hidden_states,
483
+ cross_attention_kwargs=cross_attention_kwargs,
484
+ attention_mask=attention_mask,
485
+ encoder_attention_mask=encoder_attention_mask,
486
+ return_dict=False,
487
+ )
488
+ hidden_states = torch.utils.checkpoint.checkpoint(
489
+ create_custom_forward(resnet),
490
+ hidden_states,
491
+ temb,
492
+ **ckpt_kwargs,
493
+ )
494
+ else:
495
+ hidden_states, ref_feature = attn(
496
+ hidden_states,
497
+ encoder_hidden_states=encoder_hidden_states,
498
+ cross_attention_kwargs=cross_attention_kwargs,
499
+ attention_mask=attention_mask,
500
+ encoder_attention_mask=encoder_attention_mask,
501
+ return_dict=False,
502
+ )
503
+ hidden_states = resnet(hidden_states, temb, scale=lora_scale)
504
+
505
+ return hidden_states
506
+
507
+
508
+ class CrossAttnDownBlock2D(nn.Module):
509
+ def __init__(
510
+ self,
511
+ in_channels: int,
512
+ out_channels: int,
513
+ temb_channels: int,
514
+ dropout: float = 0.0,
515
+ num_layers: int = 1,
516
+ transformer_layers_per_block: Union[int, Tuple[int]] = 1,
517
+ resnet_eps: float = 1e-6,
518
+ resnet_time_scale_shift: str = "default",
519
+ resnet_act_fn: str = "swish",
520
+ resnet_groups: int = 32,
521
+ resnet_pre_norm: bool = True,
522
+ num_attention_heads: int = 1,
523
+ cross_attention_dim: int = 1280,
524
+ output_scale_factor: float = 1.0,
525
+ downsample_padding: int = 1,
526
+ add_downsample: bool = True,
527
+ dual_cross_attention: bool = False,
528
+ use_linear_projection: bool = False,
529
+ only_cross_attention: bool = False,
530
+ upcast_attention: bool = False,
531
+ attention_type: str = "default",
532
+ ):
533
+ super().__init__()
534
+ resnets = []
535
+ attentions = []
536
+
537
+ self.has_cross_attention = True
538
+ self.num_attention_heads = num_attention_heads
539
+ if isinstance(transformer_layers_per_block, int):
540
+ transformer_layers_per_block = [transformer_layers_per_block] * num_layers
541
+
542
+ for i in range(num_layers):
543
+ in_channels = in_channels if i == 0 else out_channels
544
+ resnets.append(
545
+ ResnetBlock2D(
546
+ in_channels=in_channels,
547
+ out_channels=out_channels,
548
+ temb_channels=temb_channels,
549
+ eps=resnet_eps,
550
+ groups=resnet_groups,
551
+ dropout=dropout,
552
+ time_embedding_norm=resnet_time_scale_shift,
553
+ non_linearity=resnet_act_fn,
554
+ output_scale_factor=output_scale_factor,
555
+ pre_norm=resnet_pre_norm,
556
+ )
557
+ )
558
+ if not dual_cross_attention:
559
+ attentions.append(
560
+ Transformer2DModel(
561
+ num_attention_heads,
562
+ out_channels // num_attention_heads,
563
+ in_channels=out_channels,
564
+ num_layers=transformer_layers_per_block[i],
565
+ cross_attention_dim=cross_attention_dim,
566
+ norm_num_groups=resnet_groups,
567
+ use_linear_projection=use_linear_projection,
568
+ only_cross_attention=only_cross_attention,
569
+ upcast_attention=upcast_attention,
570
+ attention_type=attention_type,
571
+ )
572
+ )
573
+ else:
574
+ attentions.append(
575
+ DualTransformer2DModel(
576
+ num_attention_heads,
577
+ out_channels // num_attention_heads,
578
+ in_channels=out_channels,
579
+ num_layers=1,
580
+ cross_attention_dim=cross_attention_dim,
581
+ norm_num_groups=resnet_groups,
582
+ )
583
+ )
584
+ self.attentions = nn.ModuleList(attentions)
585
+ self.resnets = nn.ModuleList(resnets)
586
+
587
+ if add_downsample:
588
+ self.downsamplers = nn.ModuleList(
589
+ [
590
+ Downsample2D(
591
+ out_channels,
592
+ use_conv=True,
593
+ out_channels=out_channels,
594
+ padding=downsample_padding,
595
+ name="op",
596
+ )
597
+ ]
598
+ )
599
+ else:
600
+ self.downsamplers = None
601
+
602
+ self.gradient_checkpointing = False
603
+
604
+ def forward(
605
+ self,
606
+ hidden_states: torch.FloatTensor,
607
+ temb: Optional[torch.FloatTensor] = None,
608
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
609
+ attention_mask: Optional[torch.FloatTensor] = None,
610
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
611
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
612
+ additional_residuals: Optional[torch.FloatTensor] = None,
613
+ ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
614
+ output_states = ()
615
+
616
+ lora_scale = (
617
+ cross_attention_kwargs.get("scale", 1.0)
618
+ if cross_attention_kwargs is not None
619
+ else 1.0
620
+ )
621
+
622
+ blocks = list(zip(self.resnets, self.attentions))
623
+
624
+ for i, (resnet, attn) in enumerate(blocks):
625
+ if self.training and self.gradient_checkpointing:
626
+
627
+ def create_custom_forward(module, return_dict=None):
628
+ def custom_forward(*inputs):
629
+ if return_dict is not None:
630
+ return module(*inputs, return_dict=return_dict)
631
+ else:
632
+ return module(*inputs)
633
+
634
+ return custom_forward
635
+
636
+ ckpt_kwargs: Dict[str, Any] = (
637
+ {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
638
+ )
639
+ hidden_states = torch.utils.checkpoint.checkpoint(
640
+ create_custom_forward(resnet),
641
+ hidden_states,
642
+ temb,
643
+ **ckpt_kwargs,
644
+ )
645
+ hidden_states, ref_feature = attn(
646
+ hidden_states,
647
+ encoder_hidden_states=encoder_hidden_states,
648
+ cross_attention_kwargs=cross_attention_kwargs,
649
+ attention_mask=attention_mask,
650
+ encoder_attention_mask=encoder_attention_mask,
651
+ return_dict=False,
652
+ )
653
+ else:
654
+ hidden_states = resnet(hidden_states, temb, scale=lora_scale)
655
+ hidden_states, ref_feature = attn(
656
+ hidden_states,
657
+ encoder_hidden_states=encoder_hidden_states,
658
+ cross_attention_kwargs=cross_attention_kwargs,
659
+ attention_mask=attention_mask,
660
+ encoder_attention_mask=encoder_attention_mask,
661
+ return_dict=False,
662
+ )
663
+
664
+ # apply additional residuals to the output of the last pair of resnet and attention blocks
665
+ if i == len(blocks) - 1 and additional_residuals is not None:
666
+ hidden_states = hidden_states + additional_residuals
667
+
668
+ output_states = output_states + (hidden_states,)
669
+
670
+ if self.downsamplers is not None:
671
+ for downsampler in self.downsamplers:
672
+ hidden_states = downsampler(hidden_states, scale=lora_scale)
673
+
674
+ output_states = output_states + (hidden_states,)
675
+
676
+ return hidden_states, output_states
677
+
678
+
679
+ class DownBlock2D(nn.Module):
680
+ def __init__(
681
+ self,
682
+ in_channels: int,
683
+ out_channels: int,
684
+ temb_channels: int,
685
+ dropout: float = 0.0,
686
+ num_layers: int = 1,
687
+ resnet_eps: float = 1e-6,
688
+ resnet_time_scale_shift: str = "default",
689
+ resnet_act_fn: str = "swish",
690
+ resnet_groups: int = 32,
691
+ resnet_pre_norm: bool = True,
692
+ output_scale_factor: float = 1.0,
693
+ add_downsample: bool = True,
694
+ downsample_padding: int = 1,
695
+ ):
696
+ super().__init__()
697
+ resnets = []
698
+
699
+ for i in range(num_layers):
700
+ in_channels = in_channels if i == 0 else out_channels
701
+ resnets.append(
702
+ ResnetBlock2D(
703
+ in_channels=in_channels,
704
+ out_channels=out_channels,
705
+ temb_channels=temb_channels,
706
+ eps=resnet_eps,
707
+ groups=resnet_groups,
708
+ dropout=dropout,
709
+ time_embedding_norm=resnet_time_scale_shift,
710
+ non_linearity=resnet_act_fn,
711
+ output_scale_factor=output_scale_factor,
712
+ pre_norm=resnet_pre_norm,
713
+ )
714
+ )
715
+
716
+ self.resnets = nn.ModuleList(resnets)
717
+
718
+ if add_downsample:
719
+ self.downsamplers = nn.ModuleList(
720
+ [
721
+ Downsample2D(
722
+ out_channels,
723
+ use_conv=True,
724
+ out_channels=out_channels,
725
+ padding=downsample_padding,
726
+ name="op",
727
+ )
728
+ ]
729
+ )
730
+ else:
731
+ self.downsamplers = None
732
+
733
+ self.gradient_checkpointing = False
734
+
735
+ def forward(
736
+ self,
737
+ hidden_states: torch.FloatTensor,
738
+ temb: Optional[torch.FloatTensor] = None,
739
+ scale: float = 1.0,
740
+ ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
741
+ output_states = ()
742
+
743
+ for resnet in self.resnets:
744
+ if self.training and self.gradient_checkpointing:
745
+
746
+ def create_custom_forward(module):
747
+ def custom_forward(*inputs):
748
+ return module(*inputs)
749
+
750
+ return custom_forward
751
+
752
+ if is_torch_version(">=", "1.11.0"):
753
+ hidden_states = torch.utils.checkpoint.checkpoint(
754
+ create_custom_forward(resnet),
755
+ hidden_states,
756
+ temb,
757
+ use_reentrant=False,
758
+ )
759
+ else:
760
+ hidden_states = torch.utils.checkpoint.checkpoint(
761
+ create_custom_forward(resnet), hidden_states, temb
762
+ )
763
+ else:
764
+ hidden_states = resnet(hidden_states, temb, scale=scale)
765
+
766
+ output_states = output_states + (hidden_states,)
767
+
768
+ if self.downsamplers is not None:
769
+ for downsampler in self.downsamplers:
770
+ hidden_states = downsampler(hidden_states, scale=scale)
771
+
772
+ output_states = output_states + (hidden_states,)
773
+
774
+ return hidden_states, output_states
775
+
776
+
777
+ class CrossAttnUpBlock2D(nn.Module):
778
+ def __init__(
779
+ self,
780
+ in_channels: int,
781
+ out_channels: int,
782
+ prev_output_channel: int,
783
+ temb_channels: int,
784
+ resolution_idx: Optional[int] = None,
785
+ dropout: float = 0.0,
786
+ num_layers: int = 1,
787
+ transformer_layers_per_block: Union[int, Tuple[int]] = 1,
788
+ resnet_eps: float = 1e-6,
789
+ resnet_time_scale_shift: str = "default",
790
+ resnet_act_fn: str = "swish",
791
+ resnet_groups: int = 32,
792
+ resnet_pre_norm: bool = True,
793
+ num_attention_heads: int = 1,
794
+ cross_attention_dim: int = 1280,
795
+ output_scale_factor: float = 1.0,
796
+ add_upsample: bool = True,
797
+ dual_cross_attention: bool = False,
798
+ use_linear_projection: bool = False,
799
+ only_cross_attention: bool = False,
800
+ upcast_attention: bool = False,
801
+ attention_type: str = "default",
802
+ ):
803
+ super().__init__()
804
+ resnets = []
805
+ attentions = []
806
+
807
+ self.has_cross_attention = True
808
+ self.num_attention_heads = num_attention_heads
809
+
810
+ if isinstance(transformer_layers_per_block, int):
811
+ transformer_layers_per_block = [transformer_layers_per_block] * num_layers
812
+
813
+ for i in range(num_layers):
814
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
815
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
816
+
817
+ resnets.append(
818
+ ResnetBlock2D(
819
+ in_channels=resnet_in_channels + res_skip_channels,
820
+ out_channels=out_channels,
821
+ temb_channels=temb_channels,
822
+ eps=resnet_eps,
823
+ groups=resnet_groups,
824
+ dropout=dropout,
825
+ time_embedding_norm=resnet_time_scale_shift,
826
+ non_linearity=resnet_act_fn,
827
+ output_scale_factor=output_scale_factor,
828
+ pre_norm=resnet_pre_norm,
829
+ )
830
+ )
831
+ if not dual_cross_attention:
832
+ attentions.append(
833
+ Transformer2DModel(
834
+ num_attention_heads,
835
+ out_channels // num_attention_heads,
836
+ in_channels=out_channels,
837
+ num_layers=transformer_layers_per_block[i],
838
+ cross_attention_dim=cross_attention_dim,
839
+ norm_num_groups=resnet_groups,
840
+ use_linear_projection=use_linear_projection,
841
+ only_cross_attention=only_cross_attention,
842
+ upcast_attention=upcast_attention,
843
+ attention_type=attention_type,
844
+ )
845
+ )
846
+ else:
847
+ attentions.append(
848
+ DualTransformer2DModel(
849
+ num_attention_heads,
850
+ out_channels // num_attention_heads,
851
+ in_channels=out_channels,
852
+ num_layers=1,
853
+ cross_attention_dim=cross_attention_dim,
854
+ norm_num_groups=resnet_groups,
855
+ )
856
+ )
857
+ self.attentions = nn.ModuleList(attentions)
858
+ self.resnets = nn.ModuleList(resnets)
859
+
860
+ if add_upsample:
861
+ self.upsamplers = nn.ModuleList(
862
+ [Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]
863
+ )
864
+ else:
865
+ self.upsamplers = None
866
+
867
+ self.gradient_checkpointing = False
868
+ self.resolution_idx = resolution_idx
869
+
870
+ def forward(
871
+ self,
872
+ hidden_states: torch.FloatTensor,
873
+ res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
874
+ temb: Optional[torch.FloatTensor] = None,
875
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
876
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
877
+ upsample_size: Optional[int] = None,
878
+ attention_mask: Optional[torch.FloatTensor] = None,
879
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
880
+ ) -> torch.FloatTensor:
881
+ lora_scale = (
882
+ cross_attention_kwargs.get("scale", 1.0)
883
+ if cross_attention_kwargs is not None
884
+ else 1.0
885
+ )
886
+ is_freeu_enabled = (
887
+ getattr(self, "s1", None)
888
+ and getattr(self, "s2", None)
889
+ and getattr(self, "b1", None)
890
+ and getattr(self, "b2", None)
891
+ )
892
+
893
+ for resnet, attn in zip(self.resnets, self.attentions):
894
+ # pop res hidden states
895
+ res_hidden_states = res_hidden_states_tuple[-1]
896
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
897
+
898
+ # FreeU: Only operate on the first two stages
899
+ if is_freeu_enabled:
900
+ hidden_states, res_hidden_states = apply_freeu(
901
+ self.resolution_idx,
902
+ hidden_states,
903
+ res_hidden_states,
904
+ s1=self.s1,
905
+ s2=self.s2,
906
+ b1=self.b1,
907
+ b2=self.b2,
908
+ )
909
+
910
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
911
+
912
+ if self.training and self.gradient_checkpointing:
913
+
914
+ def create_custom_forward(module, return_dict=None):
915
+ def custom_forward(*inputs):
916
+ if return_dict is not None:
917
+ return module(*inputs, return_dict=return_dict)
918
+ else:
919
+ return module(*inputs)
920
+
921
+ return custom_forward
922
+
923
+ ckpt_kwargs: Dict[str, Any] = (
924
+ {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
925
+ )
926
+ hidden_states = torch.utils.checkpoint.checkpoint(
927
+ create_custom_forward(resnet),
928
+ hidden_states,
929
+ temb,
930
+ **ckpt_kwargs,
931
+ )
932
+ hidden_states, ref_feature = attn(
933
+ hidden_states,
934
+ encoder_hidden_states=encoder_hidden_states,
935
+ cross_attention_kwargs=cross_attention_kwargs,
936
+ attention_mask=attention_mask,
937
+ encoder_attention_mask=encoder_attention_mask,
938
+ return_dict=False,
939
+ )
940
+ else:
941
+ hidden_states = resnet(hidden_states, temb, scale=lora_scale)
942
+ hidden_states, ref_feature = attn(
943
+ hidden_states,
944
+ encoder_hidden_states=encoder_hidden_states,
945
+ cross_attention_kwargs=cross_attention_kwargs,
946
+ attention_mask=attention_mask,
947
+ encoder_attention_mask=encoder_attention_mask,
948
+ return_dict=False,
949
+ )
950
+
951
+ if self.upsamplers is not None:
952
+ for upsampler in self.upsamplers:
953
+ hidden_states = upsampler(
954
+ hidden_states, upsample_size, scale=lora_scale
955
+ )
956
+
957
+ return hidden_states
958
+
959
+
960
+ class UpBlock2D(nn.Module):
961
+ def __init__(
962
+ self,
963
+ in_channels: int,
964
+ prev_output_channel: int,
965
+ out_channels: int,
966
+ temb_channels: int,
967
+ resolution_idx: Optional[int] = None,
968
+ dropout: float = 0.0,
969
+ num_layers: int = 1,
970
+ resnet_eps: float = 1e-6,
971
+ resnet_time_scale_shift: str = "default",
972
+ resnet_act_fn: str = "swish",
973
+ resnet_groups: int = 32,
974
+ resnet_pre_norm: bool = True,
975
+ output_scale_factor: float = 1.0,
976
+ add_upsample: bool = True,
977
+ ):
978
+ super().__init__()
979
+ resnets = []
980
+
981
+ for i in range(num_layers):
982
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
983
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
984
+
985
+ resnets.append(
986
+ ResnetBlock2D(
987
+ in_channels=resnet_in_channels + res_skip_channels,
988
+ out_channels=out_channels,
989
+ temb_channels=temb_channels,
990
+ eps=resnet_eps,
991
+ groups=resnet_groups,
992
+ dropout=dropout,
993
+ time_embedding_norm=resnet_time_scale_shift,
994
+ non_linearity=resnet_act_fn,
995
+ output_scale_factor=output_scale_factor,
996
+ pre_norm=resnet_pre_norm,
997
+ )
998
+ )
999
+
1000
+ self.resnets = nn.ModuleList(resnets)
1001
+
1002
+ if add_upsample:
1003
+ self.upsamplers = nn.ModuleList(
1004
+ [Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]
1005
+ )
1006
+ else:
1007
+ self.upsamplers = None
1008
+
1009
+ self.gradient_checkpointing = False
1010
+ self.resolution_idx = resolution_idx
1011
+
1012
+ def forward(
1013
+ self,
1014
+ hidden_states: torch.FloatTensor,
1015
+ res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
1016
+ temb: Optional[torch.FloatTensor] = None,
1017
+ upsample_size: Optional[int] = None,
1018
+ scale: float = 1.0,
1019
+ ) -> torch.FloatTensor:
1020
+ is_freeu_enabled = (
1021
+ getattr(self, "s1", None)
1022
+ and getattr(self, "s2", None)
1023
+ and getattr(self, "b1", None)
1024
+ and getattr(self, "b2", None)
1025
+ )
1026
+
1027
+ for resnet in self.resnets:
1028
+ # pop res hidden states
1029
+ res_hidden_states = res_hidden_states_tuple[-1]
1030
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
1031
+
1032
+ # FreeU: Only operate on the first two stages
1033
+ if is_freeu_enabled:
1034
+ hidden_states, res_hidden_states = apply_freeu(
1035
+ self.resolution_idx,
1036
+ hidden_states,
1037
+ res_hidden_states,
1038
+ s1=self.s1,
1039
+ s2=self.s2,
1040
+ b1=self.b1,
1041
+ b2=self.b2,
1042
+ )
1043
+
1044
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
1045
+
1046
+ if self.training and self.gradient_checkpointing:
1047
+
1048
+ def create_custom_forward(module):
1049
+ def custom_forward(*inputs):
1050
+ return module(*inputs)
1051
+
1052
+ return custom_forward
1053
+
1054
+ if is_torch_version(">=", "1.11.0"):
1055
+ hidden_states = torch.utils.checkpoint.checkpoint(
1056
+ create_custom_forward(resnet),
1057
+ hidden_states,
1058
+ temb,
1059
+ use_reentrant=False,
1060
+ )
1061
+ else:
1062
+ hidden_states = torch.utils.checkpoint.checkpoint(
1063
+ create_custom_forward(resnet), hidden_states, temb
1064
+ )
1065
+ else:
1066
+ hidden_states = resnet(hidden_states, temb, scale=scale)
1067
+
1068
+ if self.upsamplers is not None:
1069
+ for upsampler in self.upsamplers:
1070
+ hidden_states = upsampler(hidden_states, upsample_size, scale=scale)
1071
+
1072
+ return hidden_states
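
Note that the `get_down_block` / `get_up_block` factories above only handle `DownBlock2D`/`CrossAttnDownBlock2D` and `UpBlock2D`/`CrossAttnUpBlock2D`, a trimmed-down set compared with the upstream diffusers file they are adapted from. A hedged construction sketch follows; the channel counts, embedding sizes, and the `modules.unet_2d_blocks` import path are assumed values for illustration only.

# Illustrative sketch; parameter values are assumptions, not project defaults.
import torch
from modules.unet_2d_blocks import get_down_block

down_block = get_down_block(
    "CrossAttnDownBlock2D",
    num_layers=2,
    in_channels=320,
    out_channels=640,
    temb_channels=1280,
    add_downsample=True,
    resnet_eps=1e-5,
    resnet_act_fn="silu",
    resnet_groups=32,
    downsample_padding=1,
    cross_attention_dim=768,
    num_attention_heads=8,
    attention_head_dim=80,        # avoids the attention_head_dim fallback warning above
)
sample = torch.randn(1, 320, 32, 32)
temb = torch.randn(1, 1280)
text_embeds = torch.randn(1, 77, 768)
hidden, skip_states = down_block(sample, temb=temb, encoder_hidden_states=text_embeds)
print(hidden.shape)       # torch.Size([1, 640, 16, 16]) after the downsampler
print(len(skip_states))   # num_layers + 1 skip connections for the up path
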
modules/unet_2d_condition.py ADDED
@@ -0,0 +1,1312 @@
1
+ # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unet_2d_condition.py
2
+ from dataclasses import dataclass
3
+ from typing import Any, Dict, List, Optional, Tuple, Union
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.utils.checkpoint
8
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
9
+ from diffusers.loaders import UNet2DConditionLoadersMixin
10
+ from diffusers.models.activations import get_activation
11
+ from diffusers.models.attention_processor import (
12
+ ADDED_KV_ATTENTION_PROCESSORS,
13
+ CROSS_ATTENTION_PROCESSORS,
14
+ AttentionProcessor,
15
+ AttnAddedKVProcessor,
16
+ AttnProcessor,
17
+ )
18
+ from diffusers.models.embeddings import (
19
+ GaussianFourierProjection,
20
+ ImageHintTimeEmbedding,
21
+ ImageProjection,
22
+ ImageTimeEmbedding,
23
+ TextImageProjection,
24
+ TextImageTimeEmbedding,
25
+ TextTimeEmbedding,
26
+ TimestepEmbedding,
27
+ Timesteps,
28
+ )
29
+ try:
30
+ from diffusers.models.embeddings import PositionNet
31
+ except ImportError:
32
+ from diffusers.models.embeddings import GLIGENTextBoundingboxProjection as PositionNet
33
+
34
+ from diffusers.models.modeling_utils import ModelMixin
35
+ from diffusers.utils import (
36
+ USE_PEFT_BACKEND,
37
+ BaseOutput,
38
+ deprecate,
39
+ logging,
40
+ scale_lora_layers,
41
+ unscale_lora_layers,
42
+ )
43
+
44
+ from .unet_2d_blocks import (
45
+ UNetMidBlock2D,
46
+ UNetMidBlock2DCrossAttn,
47
+ get_down_block,
48
+ get_up_block,
49
+ )
50
+
51
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
52
+
53
+
54
+ @dataclass
55
+ class UNet2DConditionOutput(BaseOutput):
56
+ """
57
+ The output of [`UNet2DConditionModel`].
58
+
59
+ Args:
60
+ sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
61
+ The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model.
62
+ """
63
+
64
+ sample: torch.FloatTensor = None
65
+ ref_features: Tuple[torch.FloatTensor] = None
66
+
67
+
68
+ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
69
+ r"""
70
+ A conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample
71
+ shaped output.
72
+
73
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
74
+ for all models (such as downloading or saving).
75
+
76
+ Parameters:
77
+ sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
78
+ Height and width of input/output sample.
79
+ in_channels (`int`, *optional*, defaults to 4): Number of channels in the input sample.
80
+ out_channels (`int`, *optional*, defaults to 4): Number of channels in the output.
81
+ center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
82
+ flip_sin_to_cos (`bool`, *optional*, defaults to `False`):
83
+ Whether to flip the sin to cos in the time embedding.
84
+ freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
85
+ down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
86
+ The tuple of downsample blocks to use.
87
+ mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`):
88
+ Block type for middle of UNet, it can be one of `UNetMidBlock2DCrossAttn`, `UNetMidBlock2D`, or
89
+ `UNetMidBlock2DSimpleCrossAttn`. If `None`, the mid block layer is skipped.
90
+ up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")`):
91
+ The tuple of upsample blocks to use.
92
+ only_cross_attention(`bool` or `Tuple[bool]`, *optional*, default to `False`):
93
+ Whether to include self-attention in the basic transformer blocks, see
94
+ [`~models.attention.BasicTransformerBlock`].
95
+ block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
96
+ The tuple of output channels for each block.
97
+ layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.
98
+ downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution.
99
+ mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block.
100
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
101
+ act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
102
+ norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization.
103
+ If `None`, normalization and activation layers are skipped in post-processing.
104
+ norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
105
+ cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280):
106
+ The dimension of the cross attention features.
107
+ transformer_layers_per_block (`int`, `Tuple[int]`, or `Tuple[Tuple]` , *optional*, defaults to 1):
108
+ The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
109
+ [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
110
+ [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
111
+ reverse_transformer_layers_per_block : (`Tuple[Tuple]`, *optional*, defaults to None):
112
+ The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`], in the upsampling
113
+ blocks of the U-Net. Only relevant if `transformer_layers_per_block` is of type `Tuple[Tuple]` and for
114
+ [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
115
+ [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
116
+ encoder_hid_dim (`int`, *optional*, defaults to None):
117
+ If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim`
118
+ dimension to `cross_attention_dim`.
119
+ encoder_hid_dim_type (`str`, *optional*, defaults to `None`):
120
+ If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text
121
+ embeddings of dimension `cross_attention_dim` according to `encoder_hid_dim_type`.
122
+ attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
123
+ num_attention_heads (`int`, *optional*):
124
+ The number of attention heads. If not defined, defaults to `attention_head_dim`.
125
+ resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config
126
+ for ResNet blocks (see [`~models.resnet.ResnetBlock2D`]). Choose from `default` or `scale_shift`.
127
+ class_embed_type (`str`, *optional*, defaults to `None`):
128
+ The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`,
129
+ `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
130
+ addition_embed_type (`str`, *optional*, defaults to `None`):
131
+ Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or
132
+ "text". "text" will use the `TextTimeEmbedding` layer.
133
+ addition_time_embed_dim: (`int`, *optional*, defaults to `None`):
134
+ Dimension for the timestep embeddings.
135
+ num_class_embeds (`int`, *optional*, defaults to `None`):
136
+ Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
137
+ class conditioning with `class_embed_type` equal to `None`.
138
+ time_embedding_type (`str`, *optional*, defaults to `positional`):
139
+ The type of position embedding to use for timesteps. Choose from `positional` or `fourier`.
140
+ time_embedding_dim (`int`, *optional*, defaults to `None`):
141
+ An optional override for the dimension of the projected time embedding.
142
+ time_embedding_act_fn (`str`, *optional*, defaults to `None`):
143
+ Optional activation function to use only once on the time embeddings before they are passed to the rest of
144
+ the UNet. Choose from `silu`, `mish`, `gelu`, and `swish`.
145
+ timestep_post_act (`str`, *optional*, defaults to `None`):
146
+ The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`.
147
+ time_cond_proj_dim (`int`, *optional*, defaults to `None`):
148
+ The dimension of `cond_proj` layer in the timestep embedding.
149
+ conv_in_kernel (`int`, *optional*, defaults to `3`): The kernel size of the `conv_in` layer.
150
+ conv_out_kernel (`int`, *optional*, defaults to `3`): The kernel size of the `conv_out` layer.
151
+ projection_class_embeddings_input_dim (`int`, *optional*): The dimension of the `class_labels` input when
152
+ `class_embed_type="projection"`. Required when `class_embed_type="projection"`.
153
+ class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time
154
+ embeddings with the class embeddings.
155
+ mid_block_only_cross_attention (`bool`, *optional*, defaults to `None`):
156
+ Whether to use cross attention with the mid block when using the `UNetMidBlock2DSimpleCrossAttn`. If
157
+ `only_cross_attention` is given as a single boolean and `mid_block_only_cross_attention` is `None`, the
158
+ `only_cross_attention` value is used as the value for `mid_block_only_cross_attention`. Default to `False`
159
+ otherwise.
160
+ """
161
+
162
+ _supports_gradient_checkpointing = True
163
+
164
+ @register_to_config
165
+ def __init__(
166
+ self,
167
+ sample_size: Optional[int] = None,
168
+ in_channels: int = 4,
169
+ out_channels: int = 4,
170
+ center_input_sample: bool = False,
171
+ flip_sin_to_cos: bool = True,
172
+ freq_shift: int = 0,
173
+ down_block_types: Tuple[str] = (
174
+ "CrossAttnDownBlock2D",
175
+ "CrossAttnDownBlock2D",
176
+ "CrossAttnDownBlock2D",
177
+ "DownBlock2D",
178
+ ),
179
+ mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
180
+ up_block_types: Tuple[str] = (
181
+ "UpBlock2D",
182
+ "CrossAttnUpBlock2D",
183
+ "CrossAttnUpBlock2D",
184
+ "CrossAttnUpBlock2D",
185
+ ),
186
+ only_cross_attention: Union[bool, Tuple[bool]] = False,
187
+ block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
188
+ layers_per_block: Union[int, Tuple[int]] = 2,
189
+ downsample_padding: int = 1,
190
+ mid_block_scale_factor: float = 1,
191
+ dropout: float = 0.0,
192
+ act_fn: str = "silu",
193
+ norm_num_groups: Optional[int] = 32,
194
+ norm_eps: float = 1e-5,
195
+ cross_attention_dim: Union[int, Tuple[int]] = 1280,
196
+ transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1,
197
+ reverse_transformer_layers_per_block: Optional[Tuple[Tuple[int]]] = None,
198
+ encoder_hid_dim: Optional[int] = None,
199
+ encoder_hid_dim_type: Optional[str] = None,
200
+ attention_head_dim: Union[int, Tuple[int]] = 8,
201
+ num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
202
+ dual_cross_attention: bool = False,
203
+ use_linear_projection: bool = False,
204
+ class_embed_type: Optional[str] = None,
205
+ addition_embed_type: Optional[str] = None,
206
+ addition_time_embed_dim: Optional[int] = None,
207
+ num_class_embeds: Optional[int] = None,
208
+ upcast_attention: bool = False,
209
+ resnet_time_scale_shift: str = "default",
210
+ resnet_skip_time_act: bool = False,
211
+ resnet_out_scale_factor: int = 1.0,
212
+ time_embedding_type: str = "positional",
213
+ time_embedding_dim: Optional[int] = None,
214
+ time_embedding_act_fn: Optional[str] = None,
215
+ timestep_post_act: Optional[str] = None,
216
+ time_cond_proj_dim: Optional[int] = None,
217
+ conv_in_kernel: int = 3,
218
+ conv_out_kernel: int = 3,
219
+ projection_class_embeddings_input_dim: Optional[int] = None,
220
+ attention_type: str = "default",
221
+ class_embeddings_concat: bool = False,
222
+ mid_block_only_cross_attention: Optional[bool] = None,
223
+ cross_attention_norm: Optional[str] = None,
224
+ addition_embed_type_num_heads=64,
225
+ ):
226
+ super().__init__()
227
+
228
+ self.sample_size = sample_size
229
+
230
+ if num_attention_heads is not None:
231
+ raise ValueError(
232
+ "At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19."
233
+ )
234
+
235
+ # If `num_attention_heads` is not defined (which is the case for most models)
236
+ # it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
237
+ # The reason for this behavior is to correct for incorrectly named variables that were introduced
238
+ # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
239
+ # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
240
+ # which is why we correct for the naming here.
241
+ num_attention_heads = num_attention_heads or attention_head_dim
242
+
243
+ # Check inputs
244
+ if len(down_block_types) != len(up_block_types):
245
+ raise ValueError(
246
+ f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
247
+ )
248
+
249
+ if len(block_out_channels) != len(down_block_types):
250
+ raise ValueError(
251
+ f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
252
+ )
253
+
254
+ if not isinstance(only_cross_attention, bool) and len(
255
+ only_cross_attention
256
+ ) != len(down_block_types):
257
+ raise ValueError(
258
+ f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
259
+ )
260
+
261
+ if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(
262
+ down_block_types
263
+ ):
264
+ raise ValueError(
265
+ f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
266
+ )
267
+
268
+ if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(
269
+ down_block_types
270
+ ):
271
+ raise ValueError(
272
+ f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}."
273
+ )
274
+
275
+ if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(
276
+ down_block_types
277
+ ):
278
+ raise ValueError(
279
+ f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}."
280
+ )
281
+
282
+ if not isinstance(layers_per_block, int) and len(layers_per_block) != len(
283
+ down_block_types
284
+ ):
285
+ raise ValueError(
286
+ f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}."
287
+ )
288
+ if (
289
+ isinstance(transformer_layers_per_block, list)
290
+ and reverse_transformer_layers_per_block is None
291
+ ):
292
+ for layer_number_per_block in transformer_layers_per_block:
293
+ if isinstance(layer_number_per_block, list):
294
+ raise ValueError(
295
+ "Must provide 'reverse_transformer_layers_per_block` if using asymmetrical UNet."
296
+ )
297
+
298
+ # input
299
+ conv_in_padding = (conv_in_kernel - 1) // 2
300
+ self.conv_in = nn.Conv2d(
301
+ in_channels,
302
+ block_out_channels[0],
303
+ kernel_size=conv_in_kernel,
304
+ padding=conv_in_padding,
305
+ )
306
+
307
+ # time
308
+ if time_embedding_type == "fourier":
309
+ time_embed_dim = time_embedding_dim or block_out_channels[0] * 2
310
+ if time_embed_dim % 2 != 0:
311
+ raise ValueError(
312
+ f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}."
313
+ )
314
+ self.time_proj = GaussianFourierProjection(
315
+ time_embed_dim // 2,
316
+ set_W_to_weight=False,
317
+ log=False,
318
+ flip_sin_to_cos=flip_sin_to_cos,
319
+ )
320
+ timestep_input_dim = time_embed_dim
321
+ elif time_embedding_type == "positional":
322
+ time_embed_dim = time_embedding_dim or block_out_channels[0] * 4
323
+
324
+ self.time_proj = Timesteps(
325
+ block_out_channels[0], flip_sin_to_cos, freq_shift
326
+ )
327
+ timestep_input_dim = block_out_channels[0]
328
+ else:
329
+ raise ValueError(
330
+ f"{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`."
331
+ )
332
+
333
+ self.time_embedding = TimestepEmbedding(
334
+ timestep_input_dim,
335
+ time_embed_dim,
336
+ act_fn=act_fn,
337
+ post_act_fn=timestep_post_act,
338
+ cond_proj_dim=time_cond_proj_dim,
339
+ )
340
+
341
+ if encoder_hid_dim_type is None and encoder_hid_dim is not None:
342
+ encoder_hid_dim_type = "text_proj"
343
+ self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)
344
+ logger.info(
345
+ "encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined."
346
+ )
347
+
348
+ if encoder_hid_dim is None and encoder_hid_dim_type is not None:
349
+ raise ValueError(
350
+ f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}."
351
+ )
352
+
353
+ if encoder_hid_dim_type == "text_proj":
354
+ self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim)
355
+ elif encoder_hid_dim_type == "text_image_proj":
356
+ # image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much
357
+ # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
358
+ # case when `addition_embed_type == "text_image_proj"` (Kandinsky 2.1).
359
+ self.encoder_hid_proj = TextImageProjection(
360
+ text_embed_dim=encoder_hid_dim,
361
+ image_embed_dim=cross_attention_dim,
362
+ cross_attention_dim=cross_attention_dim,
363
+ )
364
+ elif encoder_hid_dim_type == "image_proj":
365
+ # Kandinsky 2.2
366
+ self.encoder_hid_proj = ImageProjection(
367
+ image_embed_dim=encoder_hid_dim,
368
+ cross_attention_dim=cross_attention_dim,
369
+ )
370
+ elif encoder_hid_dim_type is not None:
371
+ raise ValueError(
372
+ f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'."
373
+ )
374
+ else:
375
+ self.encoder_hid_proj = None
376
+
377
+ # class embedding
378
+ if class_embed_type is None and num_class_embeds is not None:
379
+ self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
380
+ elif class_embed_type == "timestep":
381
+ self.class_embedding = TimestepEmbedding(
382
+ timestep_input_dim, time_embed_dim, act_fn=act_fn
383
+ )
384
+ elif class_embed_type == "identity":
385
+ self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
386
+ elif class_embed_type == "projection":
387
+ if projection_class_embeddings_input_dim is None:
388
+ raise ValueError(
389
+ "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
390
+ )
391
+ # The projection `class_embed_type` is the same as the timestep `class_embed_type` except
392
+ # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
393
+ # 2. it projects from an arbitrary input dimension.
394
+ #
395
+ # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
396
+ # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
397
+ # As a result, `TimestepEmbedding` can be passed arbitrary vectors.
398
+ self.class_embedding = TimestepEmbedding(
399
+ projection_class_embeddings_input_dim, time_embed_dim
400
+ )
401
+ elif class_embed_type == "simple_projection":
402
+ if projection_class_embeddings_input_dim is None:
403
+ raise ValueError(
404
+ "`class_embed_type`: 'simple_projection' requires `projection_class_embeddings_input_dim` be set"
405
+ )
406
+ self.class_embedding = nn.Linear(
407
+ projection_class_embeddings_input_dim, time_embed_dim
408
+ )
409
+ else:
410
+ self.class_embedding = None
411
+
412
+ if addition_embed_type == "text":
413
+ if encoder_hid_dim is not None:
414
+ text_time_embedding_from_dim = encoder_hid_dim
415
+ else:
416
+ text_time_embedding_from_dim = cross_attention_dim
417
+
418
+ self.add_embedding = TextTimeEmbedding(
419
+ text_time_embedding_from_dim,
420
+ time_embed_dim,
421
+ num_heads=addition_embed_type_num_heads,
422
+ )
423
+ elif addition_embed_type == "text_image":
424
+ # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much
425
+ # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
426
+ # case when `addition_embed_type == "text_image"` (Kandinsky 2.1).
427
+ self.add_embedding = TextImageTimeEmbedding(
428
+ text_embed_dim=cross_attention_dim,
429
+ image_embed_dim=cross_attention_dim,
430
+ time_embed_dim=time_embed_dim,
431
+ )
432
+ elif addition_embed_type == "text_time":
433
+ self.add_time_proj = Timesteps(
434
+ addition_time_embed_dim, flip_sin_to_cos, freq_shift
435
+ )
436
+ self.add_embedding = TimestepEmbedding(
437
+ projection_class_embeddings_input_dim, time_embed_dim
438
+ )
439
+ elif addition_embed_type == "image":
440
+ # Kandinsky 2.2
441
+ self.add_embedding = ImageTimeEmbedding(
442
+ image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim
443
+ )
444
+ elif addition_embed_type == "image_hint":
445
+ # Kandinsky 2.2 ControlNet
446
+ self.add_embedding = ImageHintTimeEmbedding(
447
+ image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim
448
+ )
449
+ elif addition_embed_type is not None:
450
+ raise ValueError(
451
+ f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'."
452
+ )
453
+
454
+ if time_embedding_act_fn is None:
455
+ self.time_embed_act = None
456
+ else:
457
+ self.time_embed_act = get_activation(time_embedding_act_fn)
458
+
459
+ self.down_blocks = nn.ModuleList([])
460
+ self.up_blocks = nn.ModuleList([])
461
+
462
+ if isinstance(only_cross_attention, bool):
463
+ if mid_block_only_cross_attention is None:
464
+ mid_block_only_cross_attention = only_cross_attention
465
+
466
+ only_cross_attention = [only_cross_attention] * len(down_block_types)
467
+
468
+ if mid_block_only_cross_attention is None:
469
+ mid_block_only_cross_attention = False
470
+
471
+ if isinstance(num_attention_heads, int):
472
+ num_attention_heads = (num_attention_heads,) * len(down_block_types)
473
+
474
+ if isinstance(attention_head_dim, int):
475
+ attention_head_dim = (attention_head_dim,) * len(down_block_types)
476
+
477
+ if isinstance(cross_attention_dim, int):
478
+ cross_attention_dim = (cross_attention_dim,) * len(down_block_types)
479
+
480
+ if isinstance(layers_per_block, int):
481
+ layers_per_block = [layers_per_block] * len(down_block_types)
482
+
483
+ if isinstance(transformer_layers_per_block, int):
484
+ transformer_layers_per_block = [transformer_layers_per_block] * len(
485
+ down_block_types
486
+ )
487
+
488
+ if class_embeddings_concat:
489
+ # The time embeddings are concatenated with the class embeddings. The dimension of the
490
+ # time embeddings passed to the down, middle, and up blocks is twice the dimension of the
491
+ # regular time embeddings
492
+ blocks_time_embed_dim = time_embed_dim * 2
493
+ else:
494
+ blocks_time_embed_dim = time_embed_dim
495
+
496
+ # down
497
+ output_channel = block_out_channels[0]
498
+ for i, down_block_type in enumerate(down_block_types):
499
+ input_channel = output_channel
500
+ output_channel = block_out_channels[i]
501
+ is_final_block = i == len(block_out_channels) - 1
502
+
503
+ down_block = get_down_block(
504
+ down_block_type,
505
+ num_layers=layers_per_block[i],
506
+ transformer_layers_per_block=transformer_layers_per_block[i],
507
+ in_channels=input_channel,
508
+ out_channels=output_channel,
509
+ temb_channels=blocks_time_embed_dim,
510
+ add_downsample=not is_final_block,
511
+ resnet_eps=norm_eps,
512
+ resnet_act_fn=act_fn,
513
+ resnet_groups=norm_num_groups,
514
+ cross_attention_dim=cross_attention_dim[i],
515
+ num_attention_heads=num_attention_heads[i],
516
+ downsample_padding=downsample_padding,
517
+ dual_cross_attention=dual_cross_attention,
518
+ use_linear_projection=use_linear_projection,
519
+ only_cross_attention=only_cross_attention[i],
520
+ upcast_attention=upcast_attention,
521
+ resnet_time_scale_shift=resnet_time_scale_shift,
522
+ attention_type=attention_type,
523
+ resnet_skip_time_act=resnet_skip_time_act,
524
+ resnet_out_scale_factor=resnet_out_scale_factor,
525
+ cross_attention_norm=cross_attention_norm,
526
+ attention_head_dim=attention_head_dim[i]
527
+ if attention_head_dim[i] is not None
528
+ else output_channel,
529
+ dropout=dropout,
530
+ )
531
+ self.down_blocks.append(down_block)
532
+
533
+ # mid
534
+ if mid_block_type == "UNetMidBlock2DCrossAttn":
535
+ self.mid_block = UNetMidBlock2DCrossAttn(
536
+ transformer_layers_per_block=transformer_layers_per_block[-1],
537
+ in_channels=block_out_channels[-1],
538
+ temb_channels=blocks_time_embed_dim,
539
+ dropout=dropout,
540
+ resnet_eps=norm_eps,
541
+ resnet_act_fn=act_fn,
542
+ output_scale_factor=mid_block_scale_factor,
543
+ resnet_time_scale_shift=resnet_time_scale_shift,
544
+ cross_attention_dim=cross_attention_dim[-1],
545
+ num_attention_heads=num_attention_heads[-1],
546
+ resnet_groups=norm_num_groups,
547
+ dual_cross_attention=dual_cross_attention,
548
+ use_linear_projection=use_linear_projection,
549
+ upcast_attention=upcast_attention,
550
+ attention_type=attention_type,
551
+ )
552
+ elif mid_block_type == "UNetMidBlock2DSimpleCrossAttn":
553
+ raise NotImplementedError(f"Unsupport mid_block_type: {mid_block_type}")
554
+ elif mid_block_type == "UNetMidBlock2D":
555
+ self.mid_block = UNetMidBlock2D(
556
+ in_channels=block_out_channels[-1],
557
+ temb_channels=blocks_time_embed_dim,
558
+ dropout=dropout,
559
+ num_layers=0,
560
+ resnet_eps=norm_eps,
561
+ resnet_act_fn=act_fn,
562
+ output_scale_factor=mid_block_scale_factor,
563
+ resnet_groups=norm_num_groups,
564
+ resnet_time_scale_shift=resnet_time_scale_shift,
565
+ add_attention=False,
566
+ )
567
+ elif mid_block_type is None:
568
+ self.mid_block = None
569
+ else:
570
+ raise ValueError(f"unknown mid_block_type : {mid_block_type}")
571
+
572
+ # count how many layers upsample the images
573
+ self.num_upsamplers = 0
574
+
575
+ # up
576
+ reversed_block_out_channels = list(reversed(block_out_channels))
577
+ reversed_num_attention_heads = list(reversed(num_attention_heads))
578
+ reversed_layers_per_block = list(reversed(layers_per_block))
579
+ reversed_cross_attention_dim = list(reversed(cross_attention_dim))
580
+ reversed_transformer_layers_per_block = (
581
+ list(reversed(transformer_layers_per_block))
582
+ if reverse_transformer_layers_per_block is None
583
+ else reverse_transformer_layers_per_block
584
+ )
585
+ only_cross_attention = list(reversed(only_cross_attention))
586
+
587
+ output_channel = reversed_block_out_channels[0]
588
+ for i, up_block_type in enumerate(up_block_types):
589
+ is_final_block = i == len(block_out_channels) - 1
590
+
591
+ prev_output_channel = output_channel
592
+ output_channel = reversed_block_out_channels[i]
593
+ input_channel = reversed_block_out_channels[
594
+ min(i + 1, len(block_out_channels) - 1)
595
+ ]
596
+
597
+ # add upsample block for all BUT final layer
598
+ if not is_final_block:
599
+ add_upsample = True
600
+ self.num_upsamplers += 1
601
+ else:
602
+ add_upsample = False
603
+
604
+ up_block = get_up_block(
605
+ up_block_type,
606
+ num_layers=reversed_layers_per_block[i] + 1,
607
+ transformer_layers_per_block=reversed_transformer_layers_per_block[i],
608
+ in_channels=input_channel,
609
+ out_channels=output_channel,
610
+ prev_output_channel=prev_output_channel,
611
+ temb_channels=blocks_time_embed_dim,
612
+ add_upsample=add_upsample,
613
+ resnet_eps=norm_eps,
614
+ resnet_act_fn=act_fn,
615
+ resolution_idx=i,
616
+ resnet_groups=norm_num_groups,
617
+ cross_attention_dim=reversed_cross_attention_dim[i],
618
+ num_attention_heads=reversed_num_attention_heads[i],
619
+ dual_cross_attention=dual_cross_attention,
620
+ use_linear_projection=use_linear_projection,
621
+ only_cross_attention=only_cross_attention[i],
622
+ upcast_attention=upcast_attention,
623
+ resnet_time_scale_shift=resnet_time_scale_shift,
624
+ attention_type=attention_type,
625
+ resnet_skip_time_act=resnet_skip_time_act,
626
+ resnet_out_scale_factor=resnet_out_scale_factor,
627
+ cross_attention_norm=cross_attention_norm,
628
+ attention_head_dim=attention_head_dim[i]
629
+ if attention_head_dim[i] is not None
630
+ else output_channel,
631
+ dropout=dropout,
632
+ )
633
+ self.up_blocks.append(up_block)
634
+ prev_output_channel = output_channel
635
+
636
+ # out
637
+ if norm_num_groups is not None:
638
+ self.conv_norm_out = nn.GroupNorm(
639
+ num_channels=block_out_channels[0],
640
+ num_groups=norm_num_groups,
641
+ eps=norm_eps,
642
+ )
643
+
644
+ self.conv_act = get_activation(act_fn)
645
+
646
+ else:
647
+ self.conv_norm_out = None
648
+ self.conv_act = None
649
+ self.conv_norm_out = None
650
+
651
+ conv_out_padding = (conv_out_kernel - 1) // 2
652
+ self.conv_out = nn.Conv2d(
653
+ block_out_channels[0],
654
+ out_channels,
655
+ kernel_size=conv_out_kernel,
656
+ padding=conv_out_padding,
657
+ )
658
+
659
+ if attention_type in ["gated", "gated-text-image"]:
660
+ positive_len = 768
661
+ if isinstance(cross_attention_dim, int):
662
+ positive_len = cross_attention_dim
663
+ elif isinstance(cross_attention_dim, tuple) or isinstance(
664
+ cross_attention_dim, list
665
+ ):
666
+ positive_len = cross_attention_dim[0]
667
+
668
+ feature_type = "text-only" if attention_type == "gated" else "text-image"
669
+ self.position_net = PositionNet(
670
+ positive_len=positive_len,
671
+ out_dim=cross_attention_dim,
672
+ feature_type=feature_type,
673
+ )
674
+
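A minimal instantiation sketch based on the constructor defaults documented above; all values below are illustrative assumptions, and real weights would be loaded through `ModelMixin.from_pretrained` rather than built from scratch.

```python
import torch

# Hypothetical config mirroring the documented defaults; not the V-Express checkpoint config.
unet = UNet2DConditionModel(
    sample_size=64,                              # latent height/width
    in_channels=4,
    out_channels=4,
    block_out_channels=(320, 640, 1280, 1280),
    cross_attention_dim=768,                     # assumed conditioning width
    attention_head_dim=8,
)
print(sum(p.numel() for p in unet.parameters()))  # rough parameter count of the random-init model
```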
675
+ @property
676
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
677
+ r"""
678
+ Returns:
679
+ `dict` of attention processors: A dictionary containing all attention processors used in the model,
680
+ indexed by their weight names.
681
+ """
682
+ # set recursively
683
+ processors = {}
684
+
685
+ def fn_recursive_add_processors(
686
+ name: str,
687
+ module: torch.nn.Module,
688
+ processors: Dict[str, AttentionProcessor],
689
+ ):
690
+ if hasattr(module, "get_processor"):
691
+ processors[f"{name}.processor"] = module.get_processor(
692
+ return_deprecated_lora=True
693
+ )
694
+
695
+ for sub_name, child in module.named_children():
696
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
697
+
698
+ return processors
699
+
700
+ for name, module in self.named_children():
701
+ fn_recursive_add_processors(name, module, processors)
702
+
703
+ return processors
704
+
705
+ def set_attn_processor(
706
+ self,
707
+ processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]],
708
+ _remove_lora=False,
709
+ ):
710
+ r"""
711
+ Sets the attention processor to use to compute attention.
712
+
713
+ Parameters:
714
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
715
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
716
+ for **all** `Attention` layers.
717
+
718
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
719
+ processor. This is strongly recommended when setting trainable attention processors.
720
+
721
+ """
722
+ count = len(self.attn_processors.keys())
723
+
724
+ if isinstance(processor, dict) and len(processor) != count:
725
+ raise ValueError(
726
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
727
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
728
+ )
729
+
730
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
731
+ if hasattr(module, "set_processor"):
732
+ if not isinstance(processor, dict):
733
+ module.set_processor(processor, _remove_lora=_remove_lora)
734
+ else:
735
+ module.set_processor(
736
+ processor.pop(f"{name}.processor"), _remove_lora=_remove_lora
737
+ )
738
+
739
+ for sub_name, child in module.named_children():
740
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
741
+
742
+ for name, module in self.named_children():
743
+ fn_recursive_attn_processor(name, module, processor)
744
+
745
+ def set_default_attn_processor(self):
746
+ """
747
+ Disables custom attention processors and sets the default attention implementation.
748
+ """
749
+ if all(
750
+ proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS
751
+ for proc in self.attn_processors.values()
752
+ ):
753
+ processor = AttnAddedKVProcessor()
754
+ elif all(
755
+ proc.__class__ in CROSS_ATTENTION_PROCESSORS
756
+ for proc in self.attn_processors.values()
757
+ ):
758
+ processor = AttnProcessor()
759
+ else:
760
+ raise ValueError(
761
+ f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
762
+ )
763
+
764
+ self.set_attn_processor(processor, _remove_lora=True)
765
+
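A short, hedged usage sketch for the processor helpers above, assuming the `unet` instance from the earlier sketch; `AttnProcessor2_0` is the standard diffusers processor class and stands in for any custom processor.

```python
from diffusers.models.attention_processor import AttnProcessor2_0

# Keys are module paths ending in ".processor", exactly as built by `attn_processors`.
for name in list(unet.attn_processors)[:3]:
    print(name)

# Set one processor instance for every attention layer...
unet.set_attn_processor(AttnProcessor2_0())

# ...or pass a dict keyed like `unet.attn_processors` to target layers individually.
unet.set_attn_processor({name: AttnProcessor2_0() for name in unet.attn_processors})

# Restore the library defaults.
unet.set_default_attn_processor()
```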
766
+ def set_attention_slice(self, slice_size):
767
+ r"""
768
+ Enable sliced attention computation.
769
+
770
+ When this option is enabled, the attention module splits the input tensor in slices to compute attention in
771
+ several steps. This is useful for saving some memory in exchange for a small decrease in speed.
772
+
773
+ Args:
774
+ slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
775
+ When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
776
+ `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
777
+ provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
778
+ must be a multiple of `slice_size`.
779
+ """
780
+ sliceable_head_dims = []
781
+
782
+ def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
783
+ if hasattr(module, "set_attention_slice"):
784
+ sliceable_head_dims.append(module.sliceable_head_dim)
785
+
786
+ for child in module.children():
787
+ fn_recursive_retrieve_sliceable_dims(child)
788
+
789
+ # retrieve number of attention layers
790
+ for module in self.children():
791
+ fn_recursive_retrieve_sliceable_dims(module)
792
+
793
+ num_sliceable_layers = len(sliceable_head_dims)
794
+
795
+ if slice_size == "auto":
796
+ # half the attention head size is usually a good trade-off between
797
+ # speed and memory
798
+ slice_size = [dim // 2 for dim in sliceable_head_dims]
799
+ elif slice_size == "max":
800
+ # make smallest slice possible
801
+ slice_size = num_sliceable_layers * [1]
802
+
803
+ slice_size = (
804
+ num_sliceable_layers * [slice_size]
805
+ if not isinstance(slice_size, list)
806
+ else slice_size
807
+ )
808
+
809
+ if len(slice_size) != len(sliceable_head_dims):
810
+ raise ValueError(
811
+ f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
812
+ f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
813
+ )
814
+
815
+ for i in range(len(slice_size)):
816
+ size = slice_size[i]
817
+ dim = sliceable_head_dims[i]
818
+ if size is not None and size > dim:
819
+ raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
820
+
821
+ # Recursively walk through all the children.
822
+ # Any children which exposes the set_attention_slice method
823
+ # gets the message
824
+ def fn_recursive_set_attention_slice(
825
+ module: torch.nn.Module, slice_size: List[int]
826
+ ):
827
+ if hasattr(module, "set_attention_slice"):
828
+ module.set_attention_slice(slice_size.pop())
829
+
830
+ for child in module.children():
831
+ fn_recursive_set_attention_slice(child, slice_size)
832
+
833
+ reversed_slice_size = list(reversed(slice_size))
834
+ for module in self.children():
835
+ fn_recursive_set_attention_slice(module, reversed_slice_size)
836
+
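A hedged example of the slicing options documented above, again assuming the `unet` instance from the earlier sketch.

```python
# "auto" halves each sliceable head dimension; "max" runs one slice at a time.
unet.set_attention_slice("auto")
unet.set_attention_slice("max")

# An int is broadcast to every sliceable layer and must not exceed any layer's head count.
unet.set_attention_slice(4)
```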
837
+ def _set_gradient_checkpointing(self, module, value=False):
838
+ if hasattr(module, "gradient_checkpointing"):
839
+ module.gradient_checkpointing = value
840
+
841
+ def enable_freeu(self, s1, s2, b1, b2):
842
+ r"""Enables the FreeU mechanism from https://arxiv.org/abs/2309.11497.
843
+
844
+ The suffixes after the scaling factors represent the stage blocks where they are being applied.
845
+
846
+ Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of values that
847
+ are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL.
848
+
849
+ Args:
850
+ s1 (`float`):
851
+ Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to
852
+ mitigate the "oversmoothing effect" in the enhanced denoising process.
853
+ s2 (`float`):
854
+ Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to
855
+ mitigate the "oversmoothing effect" in the enhanced denoising process.
856
+ b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features.
857
+ b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features.
858
+ """
859
+ for i, upsample_block in enumerate(self.up_blocks):
860
+ setattr(upsample_block, "s1", s1)
861
+ setattr(upsample_block, "s2", s2)
862
+ setattr(upsample_block, "b1", b1)
863
+ setattr(upsample_block, "b2", b2)
864
+
865
+ def disable_freeu(self):
866
+ """Disables the FreeU mechanism."""
867
+ freeu_keys = {"s1", "s2", "b1", "b2"}
868
+ for i, upsample_block in enumerate(self.up_blocks):
869
+ for k in freeu_keys:
870
+ if (
871
+ hasattr(upsample_block, k)
872
+ or getattr(upsample_block, k, None) is not None
873
+ ):
874
+ setattr(upsample_block, k, None)
875
+
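A brief sketch of the two toggles above, assuming the `unet` instance from the earlier sketch. The FreeU values are the ones commonly cited for SD 1.x in the FreeU repository and are quoted here as an assumption, not a recommendation from this project.

```python
# Gradient checkpointing (inherited from ModelMixin) flips the flag handled by
# `_set_gradient_checkpointing` on every block that supports it.
unet.enable_gradient_checkpointing()

# FreeU: b1/b2 amplify backbone features, s1/s2 damp skip features in the first two up blocks.
unet.enable_freeu(s1=0.9, s2=0.2, b1=1.2, b2=1.4)
unet.disable_freeu()
```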
876
+ def forward(
877
+ self,
878
+ sample: torch.FloatTensor,
879
+ timestep: Union[torch.Tensor, float, int],
880
+ encoder_hidden_states: torch.Tensor,
881
+ class_labels: Optional[torch.Tensor] = None,
882
+ timestep_cond: Optional[torch.Tensor] = None,
883
+ attention_mask: Optional[torch.Tensor] = None,
884
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
885
+ added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
886
+ down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
887
+ mid_block_additional_residual: Optional[torch.Tensor] = None,
888
+ down_intrablock_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
889
+ encoder_attention_mask: Optional[torch.Tensor] = None,
890
+ return_dict: bool = True,
891
+ ) -> Union[UNet2DConditionOutput, Tuple]:
892
+ r"""
893
+ The [`UNet2DConditionModel`] forward method.
894
+
895
+ Args:
896
+ sample (`torch.FloatTensor`):
897
+ The noisy input tensor with the following shape `(batch, channel, height, width)`.
898
+ timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
899
+ encoder_hidden_states (`torch.FloatTensor`):
900
+ The encoder hidden states with shape `(batch, sequence_length, feature_dim)`.
901
+ class_labels (`torch.Tensor`, *optional*, defaults to `None`):
902
+ Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
903
+ timestep_cond: (`torch.Tensor`, *optional*, defaults to `None`):
904
+ Conditional embeddings for timestep. If provided, the embeddings will be summed with the samples passed
905
+ through the `self.time_embedding` layer to obtain the timestep embeddings.
906
+ attention_mask (`torch.Tensor`, *optional*, defaults to `None`):
907
+ An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
908
+ is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
909
+ negative values to the attention scores corresponding to "discard" tokens.
910
+ cross_attention_kwargs (`dict`, *optional*):
911
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
912
+ `self.processor` in
913
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
914
+ added_cond_kwargs: (`dict`, *optional*):
915
+ A kwargs dictionary containing additional embeddings that if specified are added to the embeddings that
916
+ are passed along to the UNet blocks.
917
+ down_block_additional_residuals: (`tuple` of `torch.Tensor`, *optional*):
918
+ A tuple of tensors that if specified are added to the residuals of down unet blocks.
919
+ mid_block_additional_residual: (`torch.Tensor`, *optional*):
920
+ A tensor that if specified is added to the residual of the middle unet block.
921
+ encoder_attention_mask (`torch.Tensor`):
922
+ A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If
923
+ `True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias,
924
+ which adds large negative values to the attention scores corresponding to "discard" tokens.
925
+ return_dict (`bool`, *optional*, defaults to `True`):
926
+ Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
927
+ tuple.
928
+ cross_attention_kwargs (`dict`, *optional*):
929
+ A kwargs dictionary that if specified is passed along to the [`AttnProcessor`].
930
+ added_cond_kwargs: (`dict`, *optional*):
931
+ A kwargs dictionary containing additional embeddings that if specified are added to the embeddings that
932
+ are passed along to the UNet blocks.
933
+ down_block_additional_residuals (`tuple` of `torch.Tensor`, *optional*):
934
+ additional residuals to be added to UNet long skip connections from down blocks to up blocks for
935
+ example from ControlNet side model(s)
936
+ mid_block_additional_residual (`torch.Tensor`, *optional*):
937
+ additional residual to be added to UNet mid block output, for example from ControlNet side model
938
+ down_intrablock_additional_residuals (`tuple` of `torch.Tensor`, *optional*):
939
+ additional residuals to be added within UNet down blocks, for example from T2I-Adapter side model(s)
940
+
941
+ Returns:
942
+ [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
943
+ If `return_dict` is True, an [`~models.unet_2d_condition.UNet2DConditionOutput`] is returned, otherwise
944
+ a `tuple` is returned where the first element is the sample tensor.
945
+ """
946
+ # By default samples have to be at least a multiple of the overall upsampling factor.
947
+ # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
948
+ # However, the upsampling interpolation output size can be forced to fit any upsampling size
949
+ # on the fly if necessary.
950
+ default_overall_up_factor = 2 ** self.num_upsamplers
951
+
952
+ # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
953
+ forward_upsample_size = False
954
+ upsample_size = None
955
+
956
+ for dim in sample.shape[-2:]:
957
+ if dim % default_overall_up_factor != 0:
958
+ # Forward upsample size to force interpolation output size.
959
+ forward_upsample_size = True
960
+ break
961
+
962
+ # ensure attention_mask is a bias, and give it a singleton query_tokens dimension
963
+ # expects mask of shape:
964
+ # [batch, key_tokens]
965
+ # adds singleton query_tokens dimension:
966
+ # [batch, 1, key_tokens]
967
+ # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
968
+ # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
969
+ # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
970
+ if attention_mask is not None:
971
+ # assume that mask is expressed as:
972
+ # (1 = keep, 0 = discard)
973
+ # convert mask into a bias that can be added to attention scores:
974
+ # (keep = +0, discard = -10000.0)
975
+ attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
976
+ attention_mask = attention_mask.unsqueeze(1)
977
+
978
+ # convert encoder_attention_mask to a bias the same way we do for attention_mask
979
+ if encoder_attention_mask is not None:
980
+ encoder_attention_mask = (
981
+ 1 - encoder_attention_mask.to(sample.dtype)
982
+ ) * -10000.0
983
+ encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
984
+
985
+ # 0. center input if necessary
986
+ if self.config.center_input_sample:
987
+ sample = 2 * sample - 1.0
988
+
989
+ # 1. time
990
+ timesteps = timestep
991
+ if not torch.is_tensor(timesteps):
992
+ # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
993
+ # This would be a good case for the `match` statement (Python 3.10+)
994
+ is_mps = sample.device.type == "mps"
995
+ if isinstance(timestep, float):
996
+ dtype = torch.float32 if is_mps else torch.float64
997
+ else:
998
+ dtype = torch.int32 if is_mps else torch.int64
999
+ timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
1000
+ elif len(timesteps.shape) == 0:
1001
+ timesteps = timesteps[None].to(sample.device)
1002
+
1003
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
1004
+ timesteps = timesteps.expand(sample.shape[0])
1005
+
1006
+ t_emb = self.time_proj(timesteps)
1007
+
1008
+ # `Timesteps` does not contain any weights and will always return f32 tensors
1009
+ # but time_embedding might actually be running in fp16. so we need to cast here.
1010
+ # there might be better ways to encapsulate this.
1011
+ t_emb = t_emb.to(dtype=sample.dtype)
1012
+
1013
+ emb = self.time_embedding(t_emb, timestep_cond)
1014
+ aug_emb = None
1015
+
1016
+ if self.class_embedding is not None:
1017
+ if class_labels is None:
1018
+ raise ValueError(
1019
+ "class_labels should be provided when num_class_embeds > 0"
1020
+ )
1021
+
1022
+ if self.config.class_embed_type == "timestep":
1023
+ class_labels = self.time_proj(class_labels)
1024
+
1025
+ # `Timesteps` does not contain any weights and will always return f32 tensors
1026
+ # there might be better ways to encapsulate this.
1027
+ class_labels = class_labels.to(dtype=sample.dtype)
1028
+
1029
+ class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype)
1030
+
1031
+ if self.config.class_embeddings_concat:
1032
+ emb = torch.cat([emb, class_emb], dim=-1)
1033
+ else:
1034
+ emb = emb + class_emb
1035
+
1036
+ if self.config.addition_embed_type == "text":
1037
+ aug_emb = self.add_embedding(encoder_hidden_states)
1038
+ elif self.config.addition_embed_type == "text_image":
1039
+ # Kandinsky 2.1 - style
1040
+ if "image_embeds" not in added_cond_kwargs:
1041
+ raise ValueError(
1042
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
1043
+ )
1044
+
1045
+ image_embs = added_cond_kwargs.get("image_embeds")
1046
+ text_embs = added_cond_kwargs.get("text_embeds", encoder_hidden_states)
1047
+ aug_emb = self.add_embedding(text_embs, image_embs)
1048
+ elif self.config.addition_embed_type == "text_time":
1049
+ # SDXL - style
1050
+ if "text_embeds" not in added_cond_kwargs:
1051
+ raise ValueError(
1052
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
1053
+ )
1054
+ text_embeds = added_cond_kwargs.get("text_embeds")
1055
+ if "time_ids" not in added_cond_kwargs:
1056
+ raise ValueError(
1057
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
1058
+ )
1059
+ time_ids = added_cond_kwargs.get("time_ids")
1060
+ time_embeds = self.add_time_proj(time_ids.flatten())
1061
+ time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
1062
+ add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
1063
+ add_embeds = add_embeds.to(emb.dtype)
1064
+ aug_emb = self.add_embedding(add_embeds)
1065
+ elif self.config.addition_embed_type == "image":
1066
+ # Kandinsky 2.2 - style
1067
+ if "image_embeds" not in added_cond_kwargs:
1068
+ raise ValueError(
1069
+ f"{self.__class__} has the config param `addition_embed_type` set to 'image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
1070
+ )
1071
+ image_embs = added_cond_kwargs.get("image_embeds")
1072
+ aug_emb = self.add_embedding(image_embs)
1073
+ elif self.config.addition_embed_type == "image_hint":
1074
+ # Kandinsky 2.2 - style
1075
+ if (
1076
+ "image_embeds" not in added_cond_kwargs
1077
+ or "hint" not in added_cond_kwargs
1078
+ ):
1079
+ raise ValueError(
1080
+ f"{self.__class__} has the config param `addition_embed_type` set to 'image_hint' which requires the keyword arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`"
1081
+ )
1082
+ image_embs = added_cond_kwargs.get("image_embeds")
1083
+ hint = added_cond_kwargs.get("hint")
1084
+ aug_emb, hint = self.add_embedding(image_embs, hint)
1085
+ sample = torch.cat([sample, hint], dim=1)
1086
+
1087
+ emb = emb + aug_emb if aug_emb is not None else emb
1088
+
1089
+ if self.time_embed_act is not None:
1090
+ emb = self.time_embed_act(emb)
1091
+
1092
+ if (
1093
+ self.encoder_hid_proj is not None
1094
+ and self.config.encoder_hid_dim_type == "text_proj"
1095
+ ):
1096
+ encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states)
1097
+ elif (
1098
+ self.encoder_hid_proj is not None
1099
+ and self.config.encoder_hid_dim_type == "text_image_proj"
1100
+ ):
1101
+ # Kandinsky 2.1 - style
1102
+ if "image_embeds" not in added_cond_kwargs:
1103
+ raise ValueError(
1104
+ f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
1105
+ )
1106
+
1107
+ image_embeds = added_cond_kwargs.get("image_embeds")
1108
+ encoder_hidden_states = self.encoder_hid_proj(
1109
+ encoder_hidden_states, image_embeds
1110
+ )
1111
+ elif (
1112
+ self.encoder_hid_proj is not None
1113
+ and self.config.encoder_hid_dim_type == "image_proj"
1114
+ ):
1115
+ # Kandinsky 2.2 - style
1116
+ if "image_embeds" not in added_cond_kwargs:
1117
+ raise ValueError(
1118
+ f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
1119
+ )
1120
+ image_embeds = added_cond_kwargs.get("image_embeds")
1121
+ encoder_hidden_states = self.encoder_hid_proj(image_embeds)
1122
+ elif (
1123
+ self.encoder_hid_proj is not None
1124
+ and self.config.encoder_hid_dim_type == "ip_image_proj"
1125
+ ):
1126
+ if "image_embeds" not in added_cond_kwargs:
1127
+ raise ValueError(
1128
+ f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'ip_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
1129
+ )
1130
+ image_embeds = added_cond_kwargs.get("image_embeds")
1131
+ image_embeds = self.encoder_hid_proj(image_embeds).to(
1132
+ encoder_hidden_states.dtype
1133
+ )
1134
+ encoder_hidden_states = torch.cat(
1135
+ [encoder_hidden_states, image_embeds], dim=1
1136
+ )
1137
+
1138
+ # 2. pre-process
1139
+ sample = self.conv_in(sample)
1140
+
1141
+ # 2.5 GLIGEN position net
1142
+ if (
1143
+ cross_attention_kwargs is not None
1144
+ and cross_attention_kwargs.get("gligen", None) is not None
1145
+ ):
1146
+ cross_attention_kwargs = cross_attention_kwargs.copy()
1147
+ gligen_args = cross_attention_kwargs.pop("gligen")
1148
+ cross_attention_kwargs["gligen"] = {
1149
+ "objs": self.position_net(**gligen_args)
1150
+ }
1151
+
1152
+ # 3. down
1153
+ lora_scale = (
1154
+ cross_attention_kwargs.get("scale", 1.0)
1155
+ if cross_attention_kwargs is not None
1156
+ else 1.0
1157
+ )
1158
+ if USE_PEFT_BACKEND:
1159
+ # weight the lora layers by setting `lora_scale` for each PEFT layer
1160
+ scale_lora_layers(self, lora_scale)
1161
+
1162
+ is_controlnet = (
1163
+ mid_block_additional_residual is not None
1164
+ and down_block_additional_residuals is not None
1165
+ )
1166
+ # using new arg down_intrablock_additional_residuals for T2I-Adapters, to distinguish from controlnets
1167
+ is_adapter = down_intrablock_additional_residuals is not None
1168
+ # maintain backward compatibility for legacy usage, where
1169
+ # T2I-Adapter and ControlNet both use down_block_additional_residuals arg
1170
+ # but can only use one or the other
1171
+ if (
1172
+ not is_adapter
1173
+ and mid_block_additional_residual is None
1174
+ and down_block_additional_residuals is not None
1175
+ ):
1176
+ deprecate(
1177
+ "T2I should not use down_block_additional_residuals",
1178
+ "1.3.0",
1179
+ "Passing intrablock residual connections with `down_block_additional_residuals` is deprecated \
1180
+ and will be removed in diffusers 1.3.0. `down_block_additional_residuals` should only be used \
1181
+ for ControlNet. Please make sure to use `down_intrablock_additional_residuals` instead.",
1182
+ standard_warn=False,
1183
+ )
1184
+ down_intrablock_additional_residuals = down_block_additional_residuals
1185
+ is_adapter = True
1186
+
1187
+ down_block_res_samples = (sample,)
1188
+ tot_referece_features = ()
1189
+ for downsample_block in self.down_blocks:
1190
+ if (
1191
+ hasattr(downsample_block, "has_cross_attention")
1192
+ and downsample_block.has_cross_attention
1193
+ ):
1194
+ # For t2i-adapter CrossAttnDownBlock2D
1195
+ additional_residuals = {}
1196
+ if is_adapter and len(down_intrablock_additional_residuals) > 0:
1197
+ additional_residuals[
1198
+ "additional_residuals"
1199
+ ] = down_intrablock_additional_residuals.pop(0)
1200
+
1201
+ sample, res_samples = downsample_block(
1202
+ hidden_states=sample,
1203
+ temb=emb,
1204
+ encoder_hidden_states=encoder_hidden_states,
1205
+ attention_mask=attention_mask,
1206
+ cross_attention_kwargs=cross_attention_kwargs,
1207
+ encoder_attention_mask=encoder_attention_mask,
1208
+ **additional_residuals,
1209
+ )
1210
+ else:
1211
+ sample, res_samples = downsample_block(
1212
+ hidden_states=sample, temb=emb, scale=lora_scale
1213
+ )
1214
+ if is_adapter and len(down_intrablock_additional_residuals) > 0:
1215
+ sample += down_intrablock_additional_residuals.pop(0)
1216
+
1217
+ down_block_res_samples += res_samples
1218
+
1219
+ if is_controlnet:
1220
+ new_down_block_res_samples = ()
1221
+
1222
+ for down_block_res_sample, down_block_additional_residual in zip(
1223
+ down_block_res_samples, down_block_additional_residuals
1224
+ ):
1225
+ down_block_res_sample = (
1226
+ down_block_res_sample + down_block_additional_residual
1227
+ )
1228
+ new_down_block_res_samples = new_down_block_res_samples + (
1229
+ down_block_res_sample,
1230
+ )
1231
+
1232
+ down_block_res_samples = new_down_block_res_samples
1233
+
1234
+ # 4. mid
1235
+ if self.mid_block is not None:
1236
+ if (
1237
+ hasattr(self.mid_block, "has_cross_attention")
1238
+ and self.mid_block.has_cross_attention
1239
+ ):
1240
+ sample = self.mid_block(
1241
+ sample,
1242
+ emb,
1243
+ encoder_hidden_states=encoder_hidden_states,
1244
+ attention_mask=attention_mask,
1245
+ cross_attention_kwargs=cross_attention_kwargs,
1246
+ encoder_attention_mask=encoder_attention_mask,
1247
+ )
1248
+ else:
1249
+ sample = self.mid_block(sample, emb)
1250
+
1251
+ # To support T2I-Adapter-XL
1252
+ if (
1253
+ is_adapter
1254
+ and len(down_intrablock_additional_residuals) > 0
1255
+ and sample.shape == down_intrablock_additional_residuals[0].shape
1256
+ ):
1257
+ sample += down_intrablock_additional_residuals.pop(0)
1258
+
1259
+ if is_controlnet:
1260
+ sample = sample + mid_block_additional_residual
1261
+
1262
+ # 5. up
1263
+ for i, upsample_block in enumerate(self.up_blocks):
1264
+ is_final_block = i == len(self.up_blocks) - 1
1265
+
1266
+ res_samples = down_block_res_samples[-len(upsample_block.resnets):]
1267
+ down_block_res_samples = down_block_res_samples[
1268
+ : -len(upsample_block.resnets)
1269
+ ]
1270
+
1271
+ # if we have not reached the final block and need to forward the
1272
+ # upsample size, we do it here
1273
+ if not is_final_block and forward_upsample_size:
1274
+ upsample_size = down_block_res_samples[-1].shape[2:]
1275
+
1276
+ if (
1277
+ hasattr(upsample_block, "has_cross_attention")
1278
+ and upsample_block.has_cross_attention
1279
+ ):
1280
+ sample = upsample_block(
1281
+ hidden_states=sample,
1282
+ temb=emb,
1283
+ res_hidden_states_tuple=res_samples,
1284
+ encoder_hidden_states=encoder_hidden_states,
1285
+ cross_attention_kwargs=cross_attention_kwargs,
1286
+ upsample_size=upsample_size,
1287
+ attention_mask=attention_mask,
1288
+ encoder_attention_mask=encoder_attention_mask,
1289
+ )
1290
+ else:
1291
+ sample = upsample_block(
1292
+ hidden_states=sample,
1293
+ temb=emb,
1294
+ res_hidden_states_tuple=res_samples,
1295
+ upsample_size=upsample_size,
1296
+ scale=lora_scale,
1297
+ )
1298
+
1299
+ # 6. post-process
1300
+ if self.conv_norm_out:
1301
+ sample = self.conv_norm_out(sample)
1302
+ sample = self.conv_act(sample)
1303
+ sample = self.conv_out(sample)
1304
+
1305
+ if USE_PEFT_BACKEND:
1306
+ # remove `lora_scale` from each PEFT layer
1307
+ unscale_lora_layers(self, lora_scale)
1308
+
1309
+ if not return_dict:
1310
+ return (sample,)
1311
+
1312
+ return UNet2DConditionOutput(sample=sample)
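Before the next file, a compact forward-pass sketch matching the signature implemented above. The tensors are dummies with assumed shapes; the actual V-Express pipeline wires in reference, keypoint, and audio conditioning elsewhere in this repository.

```python
import torch

unet = unet.eval()                                # instance from the earlier sketch
latents = torch.randn(1, 4, 64, 64)               # (batch, in_channels, H, W)
timestep = torch.tensor([999])
encoder_hidden_states = torch.randn(1, 77, 768)   # (batch, seq_len, cross_attention_dim)

with torch.no_grad():
    out = unet(latents, timestep, encoder_hidden_states)  # UNet2DConditionOutput
    noise_pred = out.sample
    # return_dict=False yields a plain tuple whose first element is the sample.
    (noise_pred_tuple,) = unet(latents, timestep, encoder_hidden_states, return_dict=False)
```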
modules/unet_3d.py ADDED
@@ -0,0 +1,698 @@
1
+ # Adapted from https://github.com/guoyww/AnimateDiff/blob/main/animatediff/models/unet_blocks.py
2
+
3
+ from collections import OrderedDict
4
+ from dataclasses import dataclass
5
+ from os import PathLike
6
+ from pathlib import Path
7
+ from typing import Dict, List, Optional, Tuple, Union
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.utils.checkpoint
12
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
13
+ from diffusers.models.attention_processor import AttentionProcessor
14
+ from diffusers.models.embeddings import TimestepEmbedding, Timesteps
15
+ from diffusers.models.modeling_utils import ModelMixin
16
+ from diffusers.utils import SAFETENSORS_WEIGHTS_NAME, WEIGHTS_NAME, BaseOutput, logging
17
+ from safetensors.torch import load_file
18
+
19
+ from .resnet import InflatedConv3d, InflatedGroupNorm
20
+ from .unet_3d_blocks import UNetMidBlock3DCrossAttn, get_down_block, get_up_block
21
+
22
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
23
+
24
+
25
+ @dataclass
26
+ class UNet3DConditionOutput(BaseOutput):
27
+ sample: torch.FloatTensor
28
+
29
+
30
+ class UNet3DConditionModel(ModelMixin, ConfigMixin):
31
+ _supports_gradient_checkpointing = True
32
+
33
+ @register_to_config
34
+ def __init__(
35
+ self,
36
+ sample_size: Optional[int] = None,
37
+ in_channels: int = 4,
38
+ out_channels: int = 4,
39
+ center_input_sample: bool = False,
40
+ flip_sin_to_cos: bool = True,
41
+ freq_shift: int = 0,
42
+ down_block_types: Tuple[str] = (
43
+ "CrossAttnDownBlock3D",
44
+ "CrossAttnDownBlock3D",
45
+ "CrossAttnDownBlock3D",
46
+ "DownBlock3D",
47
+ ),
48
+ mid_block_type: str = "UNetMidBlock3DCrossAttn",
49
+ up_block_types: Tuple[str] = (
50
+ "UpBlock3D",
51
+ "CrossAttnUpBlock3D",
52
+ "CrossAttnUpBlock3D",
53
+ "CrossAttnUpBlock3D",
54
+ ),
55
+ only_cross_attention: Union[bool, Tuple[bool]] = False,
56
+ block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
57
+ layers_per_block: int = 2,
58
+ downsample_padding: int = 1,
59
+ mid_block_scale_factor: float = 1,
60
+ act_fn: str = "silu",
61
+ norm_num_groups: int = 32,
62
+ norm_eps: float = 1e-5,
63
+ cross_attention_dim: int = 1280,
64
+ attention_head_dim: Union[int, Tuple[int]] = 8,
65
+ dual_cross_attention: bool = False,
66
+ use_linear_projection: bool = False,
67
+ class_embed_type: Optional[str] = None,
68
+ num_class_embeds: Optional[int] = None,
69
+ upcast_attention: bool = False,
70
+ resnet_time_scale_shift: str = "default",
71
+ use_inflated_groupnorm=False,
72
+ # Additional
73
+ use_motion_module=False,
74
+ motion_module_resolutions=(1, 2, 4, 8),
75
+ motion_module_mid_block=False,
76
+ motion_module_decoder_only=False,
77
+ motion_module_type=None,
78
+ motion_module_kwargs={},
79
+ unet_use_cross_frame_attention=None,
80
+ unet_use_temporal_attention=None,
81
+ ):
82
+ super().__init__()
83
+
84
+ self.sample_size = sample_size
85
+ time_embed_dim = block_out_channels[0] * 4
86
+
87
+ # input
88
+ self.conv_in = InflatedConv3d(
89
+ in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1)
90
+ )
91
+
92
+ # time
93
+ self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
94
+ timestep_input_dim = block_out_channels[0]
95
+
96
+ self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
97
+
98
+ # class embedding
99
+ if class_embed_type is None and num_class_embeds is not None:
100
+ self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
101
+ elif class_embed_type == "timestep":
102
+ self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
103
+ elif class_embed_type == "identity":
104
+ self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
105
+ else:
106
+ self.class_embedding = None
107
+
108
+ self.down_blocks = nn.ModuleList([])
109
+ self.mid_block = None
110
+ self.up_blocks = nn.ModuleList([])
111
+
112
+ if isinstance(only_cross_attention, bool):
113
+ only_cross_attention = [only_cross_attention] * len(down_block_types)
114
+
115
+ if isinstance(attention_head_dim, int):
116
+ attention_head_dim = (attention_head_dim,) * len(down_block_types)
117
+
118
+ # down
119
+ output_channel = block_out_channels[0]
120
+ for i, down_block_type in enumerate(down_block_types):
121
+ res = 2 ** i
122
+ input_channel = output_channel
123
+ output_channel = block_out_channels[i]
124
+ is_final_block = i == len(block_out_channels) - 1
125
+
126
+ down_block = get_down_block(
127
+ down_block_type,
128
+ num_layers=layers_per_block,
129
+ in_channels=input_channel,
130
+ out_channels=output_channel,
131
+ temb_channels=time_embed_dim,
132
+ add_downsample=not is_final_block,
133
+ resnet_eps=norm_eps,
134
+ resnet_act_fn=act_fn,
135
+ resnet_groups=norm_num_groups,
136
+ cross_attention_dim=cross_attention_dim,
137
+ attn_num_head_channels=attention_head_dim[i],
138
+ downsample_padding=downsample_padding,
139
+ dual_cross_attention=dual_cross_attention,
140
+ use_linear_projection=use_linear_projection,
141
+ only_cross_attention=only_cross_attention[i],
142
+ upcast_attention=upcast_attention,
143
+ resnet_time_scale_shift=resnet_time_scale_shift,
144
+ unet_use_cross_frame_attention=unet_use_cross_frame_attention,
145
+ unet_use_temporal_attention=unet_use_temporal_attention,
146
+ use_inflated_groupnorm=use_inflated_groupnorm,
147
+ use_motion_module=use_motion_module
148
+ and (res in motion_module_resolutions)
149
+ and (not motion_module_decoder_only),
150
+ motion_module_type=motion_module_type,
151
+ motion_module_kwargs=motion_module_kwargs,
152
+ )
153
+ self.down_blocks.append(down_block)
154
+
155
+ # mid
156
+ if mid_block_type == "UNetMidBlock3DCrossAttn":
157
+ self.mid_block = UNetMidBlock3DCrossAttn(
158
+ in_channels=block_out_channels[-1],
159
+ temb_channels=time_embed_dim,
160
+ resnet_eps=norm_eps,
161
+ resnet_act_fn=act_fn,
162
+ output_scale_factor=mid_block_scale_factor,
163
+ resnet_time_scale_shift=resnet_time_scale_shift,
164
+ cross_attention_dim=cross_attention_dim,
165
+ attn_num_head_channels=attention_head_dim[-1],
166
+ resnet_groups=norm_num_groups,
167
+ dual_cross_attention=dual_cross_attention,
168
+ use_linear_projection=use_linear_projection,
169
+ upcast_attention=upcast_attention,
170
+ unet_use_cross_frame_attention=unet_use_cross_frame_attention,
171
+ unet_use_temporal_attention=unet_use_temporal_attention,
172
+ use_inflated_groupnorm=use_inflated_groupnorm,
173
+ use_motion_module=use_motion_module and motion_module_mid_block,
174
+ motion_module_type=motion_module_type,
175
+ motion_module_kwargs=motion_module_kwargs,
176
+ )
177
+ else:
178
+ raise ValueError(f"unknown mid_block_type : {mid_block_type}")
179
+
180
+ # count how many layers upsample the videos
181
+ self.num_upsamplers = 0
182
+
183
+ # up
184
+ reversed_block_out_channels = list(reversed(block_out_channels))
185
+ reversed_attention_head_dim = list(reversed(attention_head_dim))
186
+ only_cross_attention = list(reversed(only_cross_attention))
187
+ output_channel = reversed_block_out_channels[0]
188
+ for i, up_block_type in enumerate(up_block_types):
189
+ res = 2 ** (3 - i)
190
+ is_final_block = i == len(block_out_channels) - 1
191
+
192
+ prev_output_channel = output_channel
193
+ output_channel = reversed_block_out_channels[i]
194
+ input_channel = reversed_block_out_channels[
195
+ min(i + 1, len(block_out_channels) - 1)
196
+ ]
197
+
198
+ # add upsample block for all BUT final layer
199
+ if not is_final_block:
200
+ add_upsample = True
201
+ self.num_upsamplers += 1
202
+ else:
203
+ add_upsample = False
204
+
205
+ up_block = get_up_block(
206
+ up_block_type,
207
+ num_layers=layers_per_block + 1,
208
+ in_channels=input_channel,
209
+ out_channels=output_channel,
210
+ prev_output_channel=prev_output_channel,
211
+ temb_channels=time_embed_dim,
212
+ add_upsample=add_upsample,
213
+ resnet_eps=norm_eps,
214
+ resnet_act_fn=act_fn,
215
+ resnet_groups=norm_num_groups,
216
+ cross_attention_dim=cross_attention_dim,
217
+ attn_num_head_channels=reversed_attention_head_dim[i],
218
+ dual_cross_attention=dual_cross_attention,
219
+ use_linear_projection=use_linear_projection,
220
+ only_cross_attention=only_cross_attention[i],
221
+ upcast_attention=upcast_attention,
222
+ resnet_time_scale_shift=resnet_time_scale_shift,
223
+ unet_use_cross_frame_attention=unet_use_cross_frame_attention,
224
+ unet_use_temporal_attention=unet_use_temporal_attention,
225
+ use_inflated_groupnorm=use_inflated_groupnorm,
226
+ use_motion_module=use_motion_module
227
+ and (res in motion_module_resolutions),
228
+ motion_module_type=motion_module_type,
229
+ motion_module_kwargs=motion_module_kwargs,
230
+ )
231
+ self.up_blocks.append(up_block)
232
+ prev_output_channel = output_channel
233
+
234
+ # out
235
+ if use_inflated_groupnorm:
236
+ self.conv_norm_out = InflatedGroupNorm(
237
+ num_channels=block_out_channels[0],
238
+ num_groups=norm_num_groups,
239
+ eps=norm_eps,
240
+ )
241
+ else:
242
+ self.conv_norm_out = nn.GroupNorm(
243
+ num_channels=block_out_channels[0],
244
+ num_groups=norm_num_groups,
245
+ eps=norm_eps,
246
+ )
247
+ self.conv_act = nn.SiLU()
248
+ self.conv_out = InflatedConv3d(
249
+ block_out_channels[0], out_channels, kernel_size=3, padding=1
250
+ )
251
+
252
+ @property
253
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
254
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
255
+ r"""
256
+ Returns:
257
+ `dict` of attention processors: A dictionary containing all attention processors used in the model,
258
+ indexed by its weight name.
259
+ """
260
+ # set recursively
261
+ processors = {}
262
+
263
+ def fn_recursive_add_processors(
264
+ name: str,
265
+ module: torch.nn.Module,
266
+ processors: Dict[str, AttentionProcessor],
267
+ ):
268
+ # if hasattr(module, "set_processor"):
269
+ # processors[f"{name}.processor"] = module.processor
270
+
271
+ if hasattr(module, "get_processor") or hasattr(module, "set_processor"):
272
+ processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
273
+
274
+ for sub_name, child in module.named_children():
275
+ if "temporal_transformer" not in sub_name:
276
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
277
+
278
+ return processors
279
+
280
+ for name, module in self.named_children():
281
+ if "temporal_transformer" not in name:
282
+ fn_recursive_add_processors(name, module, processors)
283
+
284
+ return processors
285
+
286
+ def set_attention_slice(self, slice_size):
287
+ r"""
288
+ Enable sliced attention computation.
289
+
290
+ When this option is enabled, the attention module will split the input tensor in slices, to compute attention
291
+ in several steps. This is useful to save some memory in exchange for a small speed decrease.
292
+
293
+ Args:
294
+ slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
295
+ When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
296
+ `"max"`, maximum amount of memory will be saved by running only one slice at a time. If a number is
297
+ provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
298
+ must be a multiple of `slice_size`.
299
+ """
300
+ sliceable_head_dims = []
301
+
302
+ def fn_recursive_retrieve_slicable_dims(module: torch.nn.Module):
303
+ if hasattr(module, "set_attention_slice"):
304
+ sliceable_head_dims.append(module.sliceable_head_dim)
305
+
306
+ for child in module.children():
307
+ fn_recursive_retrieve_slicable_dims(child)
308
+
309
+ # retrieve number of attention layers
310
+ for module in self.children():
311
+ fn_recursive_retrieve_slicable_dims(module)
312
+
313
+ num_slicable_layers = len(sliceable_head_dims)
314
+
315
+ if slice_size == "auto":
316
+ # half the attention head size is usually a good trade-off between
317
+ # speed and memory
318
+ slice_size = [dim // 2 for dim in sliceable_head_dims]
319
+ elif slice_size == "max":
320
+ # make smallest slice possible
321
+ slice_size = num_slicable_layers * [1]
322
+
323
+ slice_size = (
324
+ num_slicable_layers * [slice_size]
325
+ if not isinstance(slice_size, list)
326
+ else slice_size
327
+ )
328
+
329
+ if len(slice_size) != len(sliceable_head_dims):
330
+ raise ValueError(
331
+ f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
332
+ f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
333
+ )
334
+
335
+ for i in range(len(slice_size)):
336
+ size = slice_size[i]
337
+ dim = sliceable_head_dims[i]
338
+ if size is not None and size > dim:
339
+ raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
340
+
341
+ # Recursively walk through all the children.
342
+ # Any children which exposes the set_attention_slice method
343
+ # gets the message
344
+ def fn_recursive_set_attention_slice(
345
+ module: torch.nn.Module, slice_size: List[int]
346
+ ):
347
+ if hasattr(module, "set_attention_slice"):
348
+ module.set_attention_slice(slice_size.pop())
349
+
350
+ for child in module.children():
351
+ fn_recursive_set_attention_slice(child, slice_size)
352
+
353
+ reversed_slice_size = list(reversed(slice_size))
354
+ for module in self.children():
355
+ fn_recursive_set_attention_slice(module, reversed_slice_size)
356
+
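# Editor's note (not part of the committed file): once the model is built,
# sliced attention can be enabled exactly as in diffusers, e.g.
#
#   unet.set_attention_slice("auto")   # or "max", or an explicit list of slice sizes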
357
+ def _set_gradient_checkpointing(self, module, value=False):
358
+ if hasattr(module, "gradient_checkpointing"):
359
+ module.gradient_checkpointing = value
360
+
361
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
362
+ def set_attn_processor(
363
+ self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]
364
+ ):
365
+ r"""
366
+ Sets the attention processor to use to compute attention.
367
+
368
+ Parameters:
369
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
370
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
371
+ for **all** `Attention` layers.
372
+
373
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
374
+ processor. This is strongly recommended when setting trainable attention processors.
375
+
376
+ """
377
+ count = len(self.attn_processors.keys())
378
+
379
+ if isinstance(processor, dict) and len(processor) != count:
380
+ raise ValueError(
381
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
382
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
383
+ )
384
+
385
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
386
+ if hasattr(module, "set_processor"):
387
+ if not isinstance(processor, dict):
388
+ module.set_processor(processor)
389
+ else:
390
+ module.set_processor(processor.pop(f"{name}.processor"))
391
+
392
+ for sub_name, child in module.named_children():
393
+ if "temporal_transformer" not in sub_name:
394
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
395
+
396
+ for name, module in self.named_children():
397
+ if "temporal_transformer" not in name:
398
+ fn_recursive_attn_processor(name, module, processor)
399
+
400
+ def forward(
401
+ self,
402
+ sample: torch.FloatTensor,
403
+ timestep: Union[torch.Tensor, float, int],
404
+ encoder_hidden_states: torch.Tensor,
405
+ class_labels: Optional[torch.Tensor] = None,
406
+ kps_features: Optional[torch.Tensor] = None,
407
+ attention_mask: Optional[torch.Tensor] = None,
408
+ down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
409
+ mid_block_additional_residual: Optional[torch.Tensor] = None,
410
+ return_dict: bool = True,
411
+ ) -> Union[UNet3DConditionOutput, Tuple]:
412
+ r"""
413
+ Args:
414
+ sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor
415
+ timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps
416
+ encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states
417
+ return_dict (`bool`, *optional*, defaults to `True`):
418
+ Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
419
+
420
+ Returns:
421
+ [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
422
+ [`~models.unet_2d_condition.UNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When
423
+ returning a tuple, the first element is the sample tensor.
424
+ """
425
+ # By default samples have to be at least a multiple of the overall upsampling factor.
427
+ # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
427
+ # However, the upsampling interpolation output size can be forced to fit any upsampling size
428
+ # on the fly if necessary.
429
+ default_overall_up_factor = 2 ** self.num_upsamplers
430
+
431
+ # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
432
+ forward_upsample_size = False
433
+ upsample_size = None
434
+
435
+ if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
436
+ logger.info("Forward upsample size to force interpolation output size.")
437
+ forward_upsample_size = True
438
+
439
+ # prepare attention_mask
440
+ if attention_mask is not None:
441
+ attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
442
+ attention_mask = attention_mask.unsqueeze(1)
443
+
444
+ # center input if necessary
445
+ if self.config.center_input_sample:
446
+ sample = 2 * sample - 1.0
447
+
448
+ # time
449
+ timesteps = timestep
450
+ if not torch.is_tensor(timesteps):
451
+ # This would be a good case for the `match` statement (Python 3.10+)
452
+ is_mps = sample.device.type == "mps"
453
+ if isinstance(timestep, float):
454
+ dtype = torch.float32 if is_mps else torch.float64
455
+ else:
456
+ dtype = torch.int32 if is_mps else torch.int64
457
+ timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
458
+ elif len(timesteps.shape) == 0:
459
+ timesteps = timesteps[None].to(sample.device)
460
+
461
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
462
+ timesteps = timesteps.expand(sample.shape[0])
463
+
464
+ t_emb = self.time_proj(timesteps)
465
+
466
+ # timesteps does not contain any weights and will always return f32 tensors
467
+ # but time_embedding might actually be running in fp16. so we need to cast here.
468
+ # there might be better ways to encapsulate this.
469
+ t_emb = t_emb.to(dtype=self.dtype)
470
+ emb = self.time_embedding(t_emb)
471
+
472
+ if self.class_embedding is not None:
473
+ if class_labels is None:
474
+ raise ValueError(
475
+ "class_labels should be provided when num_class_embeds > 0"
476
+ )
477
+
478
+ if self.config.class_embed_type == "timestep":
479
+ class_labels = self.time_proj(class_labels)
480
+
481
+ class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
482
+ emb = emb + class_emb
483
+
484
+ # pre-process
485
+ sample = self.conv_in(sample)
486
+ if kps_features is not None:
487
+ sample = sample + kps_features
488
+
489
+ # down
490
+ down_block_res_samples = (sample,)
491
+ for downsample_block in self.down_blocks:
492
+ if (
493
+ hasattr(downsample_block, "has_cross_attention")
494
+ and downsample_block.has_cross_attention
495
+ ):
496
+ sample, res_samples = downsample_block(
497
+ hidden_states=sample,
498
+ temb=emb,
499
+ encoder_hidden_states=encoder_hidden_states,
500
+ attention_mask=attention_mask,
501
+ )
502
+ else:
503
+ sample, res_samples = downsample_block(
504
+ hidden_states=sample,
505
+ temb=emb,
506
+ encoder_hidden_states=encoder_hidden_states,
507
+ )
508
+
509
+ down_block_res_samples += res_samples
510
+
511
+ if down_block_additional_residuals is not None:
512
+ new_down_block_res_samples = ()
513
+
514
+ for down_block_res_sample, down_block_additional_residual in zip(
515
+ down_block_res_samples, down_block_additional_residuals
516
+ ):
517
+ down_block_res_sample = (
518
+ down_block_res_sample + down_block_additional_residual
519
+ )
520
+ new_down_block_res_samples += (down_block_res_sample,)
521
+
522
+ down_block_res_samples = new_down_block_res_samples
523
+
524
+ # mid
525
+ sample = self.mid_block(
526
+ sample,
527
+ emb,
528
+ encoder_hidden_states=encoder_hidden_states,
529
+ attention_mask=attention_mask,
530
+ )
531
+
532
+ if mid_block_additional_residual is not None:
533
+ sample = sample + mid_block_additional_residual
534
+
535
+ # up
536
+ for i, upsample_block in enumerate(self.up_blocks):
537
+ is_final_block = i == len(self.up_blocks) - 1
538
+
539
+ res_samples = down_block_res_samples[-len(upsample_block.resnets):]
540
+ down_block_res_samples = down_block_res_samples[
541
+ : -len(upsample_block.resnets)
542
+ ]
543
+
544
+ # if we have not reached the final block and need to forward the
545
+ # upsample size, we do it here
546
+ if not is_final_block and forward_upsample_size:
547
+ upsample_size = down_block_res_samples[-1].shape[2:]
548
+
549
+ if (
550
+ hasattr(upsample_block, "has_cross_attention")
551
+ and upsample_block.has_cross_attention
552
+ ):
553
+ sample = upsample_block(
554
+ hidden_states=sample,
555
+ temb=emb,
556
+ res_hidden_states_tuple=res_samples,
557
+ encoder_hidden_states=encoder_hidden_states,
558
+ upsample_size=upsample_size,
559
+ attention_mask=attention_mask,
560
+ )
561
+ else:
562
+ sample = upsample_block(
563
+ hidden_states=sample,
564
+ temb=emb,
565
+ res_hidden_states_tuple=res_samples,
566
+ upsample_size=upsample_size,
567
+ encoder_hidden_states=encoder_hidden_states,
568
+ )
569
+
570
+ # post-process
571
+ sample = self.conv_norm_out(sample)
572
+ sample = self.conv_act(sample)
573
+ sample = self.conv_out(sample)
574
+
575
+ if not return_dict:
576
+ return (sample,)
577
+
578
+ return UNet3DConditionOutput(sample=sample)
579
+
580
+ @classmethod
581
+ def from_pretrained_2d(
582
+ cls,
583
+ pretrained_model_path: PathLike,
584
+ motion_module_path: PathLike,
585
+ subfolder=None,
586
+ unet_additional_kwargs=None,
587
+ mm_zero_proj_out=False,
588
+ ):
589
+ pretrained_model_path = Path(pretrained_model_path)
590
+ motion_module_path = Path(motion_module_path)
591
+ if subfolder is not None:
592
+ pretrained_model_path = pretrained_model_path.joinpath(subfolder)
593
+ logger.info(
594
+ f"loading temporal unet's pretrained weights from {pretrained_model_path} ..."
595
+ )
596
+
597
+ config_file = pretrained_model_path / "config.json"
598
+ if not (config_file.exists() and config_file.is_file()):
599
+ raise RuntimeError(f"{config_file} does not exist or is not a file")
600
+
601
+ unet_config = cls.load_config(config_file)
602
+ unet_config["_class_name"] = cls.__name__
603
+ unet_config["down_block_types"] = [
604
+ "CrossAttnDownBlock3D",
605
+ "CrossAttnDownBlock3D",
606
+ "CrossAttnDownBlock3D",
607
+ "DownBlock3D",
608
+ ]
609
+ unet_config["up_block_types"] = [
610
+ "UpBlock3D",
611
+ "CrossAttnUpBlock3D",
612
+ "CrossAttnUpBlock3D",
613
+ "CrossAttnUpBlock3D",
614
+ ]
615
+ unet_config["mid_block_type"] = "UNetMidBlock3DCrossAttn"
616
+
617
+ model = cls.from_config(unet_config, **(unet_additional_kwargs or {}))  # guard against the None default
618
+ # load the vanilla weights
619
+ if pretrained_model_path.joinpath(SAFETENSORS_WEIGHTS_NAME).exists():
620
+ logger.debug(
621
+ f"loading safeTensors weights from {pretrained_model_path} ..."
622
+ )
623
+ state_dict = load_file(
624
+ pretrained_model_path.joinpath(SAFETENSORS_WEIGHTS_NAME), device="cpu"
625
+ )
626
+
627
+ elif pretrained_model_path.joinpath(WEIGHTS_NAME).exists():
628
+ logger.debug(f"loading weights from {pretrained_model_path} ...")
629
+ state_dict = torch.load(
630
+ pretrained_model_path.joinpath(WEIGHTS_NAME),
631
+ map_location="cpu",
632
+ weights_only=True,
633
+ )
634
+ else:
635
+ raise FileNotFoundError(f"no weights file found in {pretrained_model_path}")
636
+
637
+ # load the motion module weights
638
+ if motion_module_path.exists() and motion_module_path.is_file():
639
+ if motion_module_path.suffix.lower() in [".pth", ".pt", ".ckpt", ".bin"]:
640
+ logger.info(f"Load motion module params from {motion_module_path}")
641
+ motion_state_dict = torch.load(
642
+ motion_module_path, map_location="cpu", weights_only=True
643
+ )
644
+ elif motion_module_path.suffix.lower() == ".safetensors":
645
+ motion_state_dict = load_file(motion_module_path, device="cpu")
646
+ else:
647
+ raise RuntimeError(
648
+ f"unknown file format for motion module weights: {motion_module_path.suffix}"
649
+ )
650
+ if mm_zero_proj_out:
651
+ logger.info("Zero-initializing proj_out layers in motion module...")
652
+ new_motion_state_dict = OrderedDict()
653
+ for k in motion_state_dict:
654
+ if "proj_out" in k:
655
+ continue
656
+ new_motion_state_dict[k] = motion_state_dict[k]
657
+ motion_state_dict = new_motion_state_dict
658
+
659
+ # merge the state dicts
660
+ state_dict.update(motion_state_dict)
661
+
662
+ # load the weights into the model
663
+ m, u = model.load_state_dict(state_dict, strict=False)
664
+ logger.debug(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};")
665
+
666
+ params = [
667
+ p.numel() if "temporal" in n else 0 for n, p in model.named_parameters()
668
+ ]
669
+ logger.info(f"Loaded {sum(params) / 1e6:.2f}M-parameter motion module")
670
+
671
+ return model
672
+
673
+ @classmethod
674
+ def from_config_2d(
675
+ cls,
676
+ unet_config_path: PathLike,
677
+ unet_additional_kwargs=None,
678
+ ):
679
+ config_file = unet_config_path
680
+
681
+ unet_config = cls.load_config(config_file)
682
+ unet_config["_class_name"] = cls.__name__
683
+ unet_config["down_block_types"] = [
684
+ "CrossAttnDownBlock3D",
685
+ "CrossAttnDownBlock3D",
686
+ "CrossAttnDownBlock3D",
687
+ "DownBlock3D",
688
+ ]
689
+ unet_config["up_block_types"] = [
690
+ "UpBlock3D",
691
+ "CrossAttnUpBlock3D",
692
+ "CrossAttnUpBlock3D",
693
+ "CrossAttnUpBlock3D",
694
+ ]
695
+ unet_config["mid_block_type"] = "UNetMidBlock3DCrossAttn"
696
+
697
+ model = cls.from_config(unet_config, **(unet_additional_kwargs or {}))  # guard against the None default
698
+ return model
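A minimal usage sketch for this UNet3DConditionModel (an editor's illustration, not part of the upload): the paths below are hypothetical placeholders, and it assumes `omegaconf` is installed and that `inference_v2.yaml` from this repository exposes an `unet_additional_kwargs` section, as the inference script is expected to rely on.

    import torch
    from omegaconf import OmegaConf
    from modules.unet_3d import UNet3DConditionModel

    cfg = OmegaConf.load("./inference_v2.yaml")
    unet = UNet3DConditionModel.from_pretrained_2d(
        "./model_ckpts/stable-diffusion-v1-5",        # assumed SD 1.5 checkout with unet/config.json
        "./model_ckpts/v-express/motion_module.bin",  # assumed motion-module weights
        subfolder="unet",
        unet_additional_kwargs=OmegaConf.to_container(cfg.unet_additional_kwargs),
    ).eval()

    # Noisy latents for an 8-frame, 512x512 clip: (batch, 4, frames, H/8, W/8).
    sample = torch.randn(1, 4, 8, 64, 64)
    timestep = torch.tensor([10])
    context = torch.randn(1, 77, 768)  # cross-attention context at SD 1.5's width
    with torch.no_grad():
        pred = unet(sample, timestep, context).sample  # same shape as `sample`

Note that `forward` also accepts `kps_features`, which are added right after `conv_in` and therefore must match `block_out_channels[0]` (320 by default) at latent resolution; `VKpsGuider` below produces exactly that.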
modules/unet_3d_blocks.py ADDED
@@ -0,0 +1,862 @@
1
+ # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unet_2d_blocks.py
2
+
3
+ import pdb
4
+
5
+ import torch
6
+ from torch import nn
7
+
8
+ from .motion_module import get_motion_module
9
+
10
+ # from .motion_module import get_motion_module
11
+ from .resnet import Downsample3D, ResnetBlock3D, Upsample3D
12
+ from .transformer_3d import Transformer3DModel
13
+
14
+
15
+ def get_down_block(
16
+ down_block_type,
17
+ num_layers,
18
+ in_channels,
19
+ out_channels,
20
+ temb_channels,
21
+ add_downsample,
22
+ resnet_eps,
23
+ resnet_act_fn,
24
+ attn_num_head_channels,
25
+ resnet_groups=None,
26
+ cross_attention_dim=None,
27
+ downsample_padding=None,
28
+ dual_cross_attention=False,
29
+ use_linear_projection=False,
30
+ only_cross_attention=False,
31
+ upcast_attention=False,
32
+ resnet_time_scale_shift="default",
33
+ unet_use_cross_frame_attention=None,
34
+ unet_use_temporal_attention=None,
35
+ use_inflated_groupnorm=None,
36
+ use_motion_module=None,
37
+ motion_module_type=None,
38
+ motion_module_kwargs=None,
39
+ ):
40
+ down_block_type = (
41
+ down_block_type[7:]
42
+ if down_block_type.startswith("UNetRes")
43
+ else down_block_type
44
+ )
45
+ if down_block_type == "DownBlock3D":
46
+ return DownBlock3D(
47
+ num_layers=num_layers,
48
+ in_channels=in_channels,
49
+ out_channels=out_channels,
50
+ temb_channels=temb_channels,
51
+ add_downsample=add_downsample,
52
+ resnet_eps=resnet_eps,
53
+ resnet_act_fn=resnet_act_fn,
54
+ resnet_groups=resnet_groups,
55
+ downsample_padding=downsample_padding,
56
+ resnet_time_scale_shift=resnet_time_scale_shift,
57
+ use_inflated_groupnorm=use_inflated_groupnorm,
58
+ use_motion_module=use_motion_module,
59
+ motion_module_type=motion_module_type,
60
+ motion_module_kwargs=motion_module_kwargs,
61
+ )
62
+ elif down_block_type == "CrossAttnDownBlock3D":
63
+ if cross_attention_dim is None:
64
+ raise ValueError(
65
+ "cross_attention_dim must be specified for CrossAttnDownBlock3D"
66
+ )
67
+ return CrossAttnDownBlock3D(
68
+ num_layers=num_layers,
69
+ in_channels=in_channels,
70
+ out_channels=out_channels,
71
+ temb_channels=temb_channels,
72
+ add_downsample=add_downsample,
73
+ resnet_eps=resnet_eps,
74
+ resnet_act_fn=resnet_act_fn,
75
+ resnet_groups=resnet_groups,
76
+ downsample_padding=downsample_padding,
77
+ cross_attention_dim=cross_attention_dim,
78
+ attn_num_head_channels=attn_num_head_channels,
79
+ dual_cross_attention=dual_cross_attention,
80
+ use_linear_projection=use_linear_projection,
81
+ only_cross_attention=only_cross_attention,
82
+ upcast_attention=upcast_attention,
83
+ resnet_time_scale_shift=resnet_time_scale_shift,
84
+ unet_use_cross_frame_attention=unet_use_cross_frame_attention,
85
+ unet_use_temporal_attention=unet_use_temporal_attention,
86
+ use_inflated_groupnorm=use_inflated_groupnorm,
87
+ use_motion_module=use_motion_module,
88
+ motion_module_type=motion_module_type,
89
+ motion_module_kwargs=motion_module_kwargs,
90
+ )
91
+ raise ValueError(f"{down_block_type} does not exist.")
92
+
93
+
94
+ def get_up_block(
95
+ up_block_type,
96
+ num_layers,
97
+ in_channels,
98
+ out_channels,
99
+ prev_output_channel,
100
+ temb_channels,
101
+ add_upsample,
102
+ resnet_eps,
103
+ resnet_act_fn,
104
+ attn_num_head_channels,
105
+ resnet_groups=None,
106
+ cross_attention_dim=None,
107
+ dual_cross_attention=False,
108
+ use_linear_projection=False,
109
+ only_cross_attention=False,
110
+ upcast_attention=False,
111
+ resnet_time_scale_shift="default",
112
+ unet_use_cross_frame_attention=None,
113
+ unet_use_temporal_attention=None,
114
+ use_inflated_groupnorm=None,
115
+ use_motion_module=None,
116
+ motion_module_type=None,
117
+ motion_module_kwargs=None,
118
+ ):
119
+ up_block_type = (
120
+ up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
121
+ )
122
+ if up_block_type == "UpBlock3D":
123
+ return UpBlock3D(
124
+ num_layers=num_layers,
125
+ in_channels=in_channels,
126
+ out_channels=out_channels,
127
+ prev_output_channel=prev_output_channel,
128
+ temb_channels=temb_channels,
129
+ add_upsample=add_upsample,
130
+ resnet_eps=resnet_eps,
131
+ resnet_act_fn=resnet_act_fn,
132
+ resnet_groups=resnet_groups,
133
+ resnet_time_scale_shift=resnet_time_scale_shift,
134
+ use_inflated_groupnorm=use_inflated_groupnorm,
135
+ use_motion_module=use_motion_module,
136
+ motion_module_type=motion_module_type,
137
+ motion_module_kwargs=motion_module_kwargs,
138
+ )
139
+ elif up_block_type == "CrossAttnUpBlock3D":
140
+ if cross_attention_dim is None:
141
+ raise ValueError(
142
+ "cross_attention_dim must be specified for CrossAttnUpBlock3D"
143
+ )
144
+ return CrossAttnUpBlock3D(
145
+ num_layers=num_layers,
146
+ in_channels=in_channels,
147
+ out_channels=out_channels,
148
+ prev_output_channel=prev_output_channel,
149
+ temb_channels=temb_channels,
150
+ add_upsample=add_upsample,
151
+ resnet_eps=resnet_eps,
152
+ resnet_act_fn=resnet_act_fn,
153
+ resnet_groups=resnet_groups,
154
+ cross_attention_dim=cross_attention_dim,
155
+ attn_num_head_channels=attn_num_head_channels,
156
+ dual_cross_attention=dual_cross_attention,
157
+ use_linear_projection=use_linear_projection,
158
+ only_cross_attention=only_cross_attention,
159
+ upcast_attention=upcast_attention,
160
+ resnet_time_scale_shift=resnet_time_scale_shift,
161
+ unet_use_cross_frame_attention=unet_use_cross_frame_attention,
162
+ unet_use_temporal_attention=unet_use_temporal_attention,
163
+ use_inflated_groupnorm=use_inflated_groupnorm,
164
+ use_motion_module=use_motion_module,
165
+ motion_module_type=motion_module_type,
166
+ motion_module_kwargs=motion_module_kwargs,
167
+ )
168
+ raise ValueError(f"{up_block_type} does not exist.")
169
+
170
+
171
+ class UNetMidBlock3DCrossAttn(nn.Module):
172
+ def __init__(
173
+ self,
174
+ in_channels: int,
175
+ temb_channels: int,
176
+ dropout: float = 0.0,
177
+ num_layers: int = 1,
178
+ resnet_eps: float = 1e-6,
179
+ resnet_time_scale_shift: str = "default",
180
+ resnet_act_fn: str = "swish",
181
+ resnet_groups: int = 32,
182
+ resnet_pre_norm: bool = True,
183
+ attn_num_head_channels=1,
184
+ output_scale_factor=1.0,
185
+ cross_attention_dim=1280,
186
+ dual_cross_attention=False,
187
+ use_linear_projection=False,
188
+ upcast_attention=False,
189
+ unet_use_cross_frame_attention=None,
190
+ unet_use_temporal_attention=None,
191
+ use_inflated_groupnorm=None,
192
+ use_motion_module=None,
193
+ motion_module_type=None,
194
+ motion_module_kwargs=None,
195
+ ):
196
+ super().__init__()
197
+
198
+ self.has_cross_attention = True
199
+ self.attn_num_head_channels = attn_num_head_channels
200
+ resnet_groups = (
201
+ resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
202
+ )
203
+
204
+ # there is always at least one resnet
205
+ resnets = [
206
+ ResnetBlock3D(
207
+ in_channels=in_channels,
208
+ out_channels=in_channels,
209
+ temb_channels=temb_channels,
210
+ eps=resnet_eps,
211
+ groups=resnet_groups,
212
+ dropout=dropout,
213
+ time_embedding_norm=resnet_time_scale_shift,
214
+ non_linearity=resnet_act_fn,
215
+ output_scale_factor=output_scale_factor,
216
+ pre_norm=resnet_pre_norm,
217
+ use_inflated_groupnorm=use_inflated_groupnorm,
218
+ )
219
+ ]
220
+ attentions = []
221
+ motion_modules = []
222
+
223
+ for _ in range(num_layers):
224
+ if dual_cross_attention:
225
+ raise NotImplementedError
226
+ attentions.append(
227
+ Transformer3DModel(
228
+ attn_num_head_channels,
229
+ in_channels // attn_num_head_channels,
230
+ in_channels=in_channels,
231
+ num_layers=1,
232
+ cross_attention_dim=cross_attention_dim,
233
+ norm_num_groups=resnet_groups,
234
+ use_linear_projection=use_linear_projection,
235
+ upcast_attention=upcast_attention,
236
+ unet_use_cross_frame_attention=unet_use_cross_frame_attention,
237
+ unet_use_temporal_attention=unet_use_temporal_attention,
238
+ )
239
+ )
240
+ motion_modules.append(
241
+ get_motion_module(
242
+ in_channels=in_channels,
243
+ motion_module_type=motion_module_type,
244
+ motion_module_kwargs=motion_module_kwargs,
245
+ )
246
+ if use_motion_module
247
+ else None
248
+ )
249
+ resnets.append(
250
+ ResnetBlock3D(
251
+ in_channels=in_channels,
252
+ out_channels=in_channels,
253
+ temb_channels=temb_channels,
254
+ eps=resnet_eps,
255
+ groups=resnet_groups,
256
+ dropout=dropout,
257
+ time_embedding_norm=resnet_time_scale_shift,
258
+ non_linearity=resnet_act_fn,
259
+ output_scale_factor=output_scale_factor,
260
+ pre_norm=resnet_pre_norm,
261
+ use_inflated_groupnorm=use_inflated_groupnorm,
262
+ )
263
+ )
264
+
265
+ self.attentions = nn.ModuleList(attentions)
266
+ self.resnets = nn.ModuleList(resnets)
267
+ self.motion_modules = nn.ModuleList(motion_modules)
268
+
269
+ def forward(
270
+ self,
271
+ hidden_states,
272
+ temb=None,
273
+ encoder_hidden_states=None,
274
+ attention_mask=None,
275
+ ):
276
+ hidden_states = self.resnets[0](hidden_states, temb)
277
+ for attn, resnet, motion_module in zip(
278
+ self.attentions, self.resnets[1:], self.motion_modules
279
+ ):
280
+ hidden_states = attn(
281
+ hidden_states,
282
+ encoder_hidden_states=encoder_hidden_states,
283
+ ).sample
284
+ hidden_states = (
285
+ motion_module(
286
+ hidden_states, temb, encoder_hidden_states=encoder_hidden_states
287
+ )
288
+ if motion_module is not None
289
+ else hidden_states
290
+ )
291
+ hidden_states = resnet(hidden_states, temb)
292
+
293
+ return hidden_states
294
+
295
+
296
+ class CrossAttnDownBlock3D(nn.Module):
297
+ def __init__(
298
+ self,
299
+ in_channels: int,
300
+ out_channels: int,
301
+ temb_channels: int,
302
+ dropout: float = 0.0,
303
+ num_layers: int = 1,
304
+ resnet_eps: float = 1e-6,
305
+ resnet_time_scale_shift: str = "default",
306
+ resnet_act_fn: str = "swish",
307
+ resnet_groups: int = 32,
308
+ resnet_pre_norm: bool = True,
309
+ attn_num_head_channels=1,
310
+ cross_attention_dim=1280,
311
+ output_scale_factor=1.0,
312
+ downsample_padding=1,
313
+ add_downsample=True,
314
+ dual_cross_attention=False,
315
+ use_linear_projection=False,
316
+ only_cross_attention=False,
317
+ upcast_attention=False,
318
+ unet_use_cross_frame_attention=None,
319
+ unet_use_temporal_attention=None,
320
+ use_inflated_groupnorm=None,
321
+ use_motion_module=None,
322
+ motion_module_type=None,
323
+ motion_module_kwargs=None,
324
+ ):
325
+ super().__init__()
326
+ resnets = []
327
+ attentions = []
328
+ motion_modules = []
329
+
330
+ self.has_cross_attention = True
331
+ self.attn_num_head_channels = attn_num_head_channels
332
+
333
+ for i in range(num_layers):
334
+ in_channels = in_channels if i == 0 else out_channels
335
+ resnets.append(
336
+ ResnetBlock3D(
337
+ in_channels=in_channels,
338
+ out_channels=out_channels,
339
+ temb_channels=temb_channels,
340
+ eps=resnet_eps,
341
+ groups=resnet_groups,
342
+ dropout=dropout,
343
+ time_embedding_norm=resnet_time_scale_shift,
344
+ non_linearity=resnet_act_fn,
345
+ output_scale_factor=output_scale_factor,
346
+ pre_norm=resnet_pre_norm,
347
+ use_inflated_groupnorm=use_inflated_groupnorm,
348
+ )
349
+ )
350
+ if dual_cross_attention:
351
+ raise NotImplementedError
352
+ attentions.append(
353
+ Transformer3DModel(
354
+ attn_num_head_channels,
355
+ out_channels // attn_num_head_channels,
356
+ in_channels=out_channels,
357
+ num_layers=1,
358
+ cross_attention_dim=cross_attention_dim,
359
+ norm_num_groups=resnet_groups,
360
+ use_linear_projection=use_linear_projection,
361
+ only_cross_attention=only_cross_attention,
362
+ upcast_attention=upcast_attention,
363
+ unet_use_cross_frame_attention=unet_use_cross_frame_attention,
364
+ unet_use_temporal_attention=unet_use_temporal_attention,
365
+ )
366
+ )
367
+ motion_modules.append(
368
+ get_motion_module(
369
+ in_channels=out_channels,
370
+ motion_module_type=motion_module_type,
371
+ motion_module_kwargs=motion_module_kwargs,
372
+ )
373
+ if use_motion_module
374
+ else None
375
+ )
376
+
377
+ self.attentions = nn.ModuleList(attentions)
378
+ self.resnets = nn.ModuleList(resnets)
379
+ self.motion_modules = nn.ModuleList(motion_modules)
380
+
381
+ if add_downsample:
382
+ self.downsamplers = nn.ModuleList(
383
+ [
384
+ Downsample3D(
385
+ out_channels,
386
+ use_conv=True,
387
+ out_channels=out_channels,
388
+ padding=downsample_padding,
389
+ name="op",
390
+ )
391
+ ]
392
+ )
393
+ else:
394
+ self.downsamplers = None
395
+
396
+ self.gradient_checkpointing = False
397
+
398
+ def forward(
399
+ self,
400
+ hidden_states,
401
+ temb=None,
402
+ encoder_hidden_states=None,
403
+ attention_mask=None,
404
+ ):
405
+ output_states = ()
406
+
407
+ for i, (resnet, attn, motion_module) in enumerate(
408
+ zip(self.resnets, self.attentions, self.motion_modules)
409
+ ):
410
+ # self.gradient_checkpointing = False
411
+ if self.training and self.gradient_checkpointing:
412
+
413
+ def create_custom_forward(module, return_dict=None):
414
+ def custom_forward(*inputs):
415
+ if return_dict is not None:
416
+ return module(*inputs, return_dict=return_dict)
417
+ else:
418
+ return module(*inputs)
419
+
420
+ return custom_forward
421
+
422
+ hidden_states = torch.utils.checkpoint.checkpoint(
423
+ create_custom_forward(resnet), hidden_states, temb
424
+ )
425
+ hidden_states = torch.utils.checkpoint.checkpoint(
426
+ create_custom_forward(attn, return_dict=False),
427
+ hidden_states,
428
+ encoder_hidden_states,
429
+ )[0]
430
+
431
+ # add motion module
432
+ hidden_states = (
433
+ motion_module(
434
+ hidden_states, temb, encoder_hidden_states=encoder_hidden_states
435
+ )
436
+ if motion_module is not None
437
+ else hidden_states
438
+ )
439
+
440
+ else:
441
+ hidden_states = resnet(hidden_states, temb)
442
+ hidden_states = attn(
443
+ hidden_states,
444
+ encoder_hidden_states=encoder_hidden_states,
445
+ ).sample
446
+
447
+ # add motion module
448
+ hidden_states = (
449
+ motion_module(
450
+ hidden_states, temb, encoder_hidden_states=encoder_hidden_states
451
+ )
452
+ if motion_module is not None
453
+ else hidden_states
454
+ )
455
+
456
+ output_states += (hidden_states,)
457
+
458
+ if self.downsamplers is not None:
459
+ for downsampler in self.downsamplers:
460
+ hidden_states = downsampler(hidden_states)
461
+
462
+ output_states += (hidden_states,)
463
+
464
+ return hidden_states, output_states
465
+
466
+
467
+ class DownBlock3D(nn.Module):
468
+ def __init__(
469
+ self,
470
+ in_channels: int,
471
+ out_channels: int,
472
+ temb_channels: int,
473
+ dropout: float = 0.0,
474
+ num_layers: int = 1,
475
+ resnet_eps: float = 1e-6,
476
+ resnet_time_scale_shift: str = "default",
477
+ resnet_act_fn: str = "swish",
478
+ resnet_groups: int = 32,
479
+ resnet_pre_norm: bool = True,
480
+ output_scale_factor=1.0,
481
+ add_downsample=True,
482
+ downsample_padding=1,
483
+ use_inflated_groupnorm=None,
484
+ use_motion_module=None,
485
+ motion_module_type=None,
486
+ motion_module_kwargs=None,
487
+ ):
488
+ super().__init__()
489
+ resnets = []
490
+ motion_modules = []
491
+
492
+ # use_motion_module = False
493
+ for i in range(num_layers):
494
+ in_channels = in_channels if i == 0 else out_channels
495
+ resnets.append(
496
+ ResnetBlock3D(
497
+ in_channels=in_channels,
498
+ out_channels=out_channels,
499
+ temb_channels=temb_channels,
500
+ eps=resnet_eps,
501
+ groups=resnet_groups,
502
+ dropout=dropout,
503
+ time_embedding_norm=resnet_time_scale_shift,
504
+ non_linearity=resnet_act_fn,
505
+ output_scale_factor=output_scale_factor,
506
+ pre_norm=resnet_pre_norm,
507
+ use_inflated_groupnorm=use_inflated_groupnorm,
508
+ )
509
+ )
510
+ motion_modules.append(
511
+ get_motion_module(
512
+ in_channels=out_channels,
513
+ motion_module_type=motion_module_type,
514
+ motion_module_kwargs=motion_module_kwargs,
515
+ )
516
+ if use_motion_module
517
+ else None
518
+ )
519
+
520
+ self.resnets = nn.ModuleList(resnets)
521
+ self.motion_modules = nn.ModuleList(motion_modules)
522
+
523
+ if add_downsample:
524
+ self.downsamplers = nn.ModuleList(
525
+ [
526
+ Downsample3D(
527
+ out_channels,
528
+ use_conv=True,
529
+ out_channels=out_channels,
530
+ padding=downsample_padding,
531
+ name="op",
532
+ )
533
+ ]
534
+ )
535
+ else:
536
+ self.downsamplers = None
537
+
538
+ self.gradient_checkpointing = False
539
+
540
+ def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
541
+ output_states = ()
542
+
543
+ for resnet, motion_module in zip(self.resnets, self.motion_modules):
544
+ # print(f"DownBlock3D {self.gradient_checkpointing = }")
545
+ if self.training and self.gradient_checkpointing:
546
+
547
+ def create_custom_forward(module):
548
+ def custom_forward(*inputs):
549
+ return module(*inputs)
550
+
551
+ return custom_forward
552
+
553
+ hidden_states = torch.utils.checkpoint.checkpoint(
554
+ create_custom_forward(resnet), hidden_states, temb
555
+ )
556
+ if motion_module is not None:
557
+ hidden_states = torch.utils.checkpoint.checkpoint(
558
+ create_custom_forward(motion_module),
559
+ hidden_states.requires_grad_(),
560
+ temb,
561
+ encoder_hidden_states,
562
+ )
563
+ else:
564
+ hidden_states = resnet(hidden_states, temb)
565
+
566
+ # add motion module
567
+ hidden_states = (
568
+ motion_module(
569
+ hidden_states, temb, encoder_hidden_states=encoder_hidden_states
570
+ )
571
+ if motion_module is not None
572
+ else hidden_states
573
+ )
574
+
575
+ output_states += (hidden_states,)
576
+
577
+ if self.downsamplers is not None:
578
+ for downsampler in self.downsamplers:
579
+ hidden_states = downsampler(hidden_states)
580
+
581
+ output_states += (hidden_states,)
582
+
583
+ return hidden_states, output_states
584
+
585
+
586
+ class CrossAttnUpBlock3D(nn.Module):
587
+ def __init__(
588
+ self,
589
+ in_channels: int,
590
+ out_channels: int,
591
+ prev_output_channel: int,
592
+ temb_channels: int,
593
+ dropout: float = 0.0,
594
+ num_layers: int = 1,
595
+ resnet_eps: float = 1e-6,
596
+ resnet_time_scale_shift: str = "default",
597
+ resnet_act_fn: str = "swish",
598
+ resnet_groups: int = 32,
599
+ resnet_pre_norm: bool = True,
600
+ attn_num_head_channels=1,
601
+ cross_attention_dim=1280,
602
+ output_scale_factor=1.0,
603
+ add_upsample=True,
604
+ dual_cross_attention=False,
605
+ use_linear_projection=False,
606
+ only_cross_attention=False,
607
+ upcast_attention=False,
608
+ unet_use_cross_frame_attention=None,
609
+ unet_use_temporal_attention=None,
610
+ use_motion_module=None,
611
+ use_inflated_groupnorm=None,
612
+ motion_module_type=None,
613
+ motion_module_kwargs=None,
614
+ ):
615
+ super().__init__()
616
+ resnets = []
617
+ attentions = []
618
+ motion_modules = []
619
+
620
+ self.has_cross_attention = True
621
+ self.attn_num_head_channels = attn_num_head_channels
622
+
623
+ for i in range(num_layers):
624
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
625
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
626
+
627
+ resnets.append(
628
+ ResnetBlock3D(
629
+ in_channels=resnet_in_channels + res_skip_channels,
630
+ out_channels=out_channels,
631
+ temb_channels=temb_channels,
632
+ eps=resnet_eps,
633
+ groups=resnet_groups,
634
+ dropout=dropout,
635
+ time_embedding_norm=resnet_time_scale_shift,
636
+ non_linearity=resnet_act_fn,
637
+ output_scale_factor=output_scale_factor,
638
+ pre_norm=resnet_pre_norm,
639
+ use_inflated_groupnorm=use_inflated_groupnorm,
640
+ )
641
+ )
642
+ if dual_cross_attention:
643
+ raise NotImplementedError
644
+ attentions.append(
645
+ Transformer3DModel(
646
+ attn_num_head_channels,
647
+ out_channels // attn_num_head_channels,
648
+ in_channels=out_channels,
649
+ num_layers=1,
650
+ cross_attention_dim=cross_attention_dim,
651
+ norm_num_groups=resnet_groups,
652
+ use_linear_projection=use_linear_projection,
653
+ only_cross_attention=only_cross_attention,
654
+ upcast_attention=upcast_attention,
655
+ unet_use_cross_frame_attention=unet_use_cross_frame_attention,
656
+ unet_use_temporal_attention=unet_use_temporal_attention,
657
+ )
658
+ )
659
+ motion_modules.append(
660
+ get_motion_module(
661
+ in_channels=out_channels,
662
+ motion_module_type=motion_module_type,
663
+ motion_module_kwargs=motion_module_kwargs,
664
+ )
665
+ if use_motion_module
666
+ else None
667
+ )
668
+
669
+ self.attentions = nn.ModuleList(attentions)
670
+ self.resnets = nn.ModuleList(resnets)
671
+ self.motion_modules = nn.ModuleList(motion_modules)
672
+
673
+ if add_upsample:
674
+ self.upsamplers = nn.ModuleList(
675
+ [Upsample3D(out_channels, use_conv=True, out_channels=out_channels)]
676
+ )
677
+ else:
678
+ self.upsamplers = None
679
+
680
+ self.gradient_checkpointing = False
681
+
682
+ def forward(
683
+ self,
684
+ hidden_states,
685
+ res_hidden_states_tuple,
686
+ temb=None,
687
+ encoder_hidden_states=None,
688
+ upsample_size=None,
689
+ attention_mask=None,
690
+ ):
691
+ for i, (resnet, attn, motion_module) in enumerate(
692
+ zip(self.resnets, self.attentions, self.motion_modules)
693
+ ):
694
+ # pop res hidden states
695
+ res_hidden_states = res_hidden_states_tuple[-1]
696
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
697
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
698
+
699
+ if self.training and self.gradient_checkpointing:
700
+
701
+ def create_custom_forward(module, return_dict=None):
702
+ def custom_forward(*inputs):
703
+ if return_dict is not None:
704
+ return module(*inputs, return_dict=return_dict)
705
+ else:
706
+ return module(*inputs)
707
+
708
+ return custom_forward
709
+
710
+ hidden_states = torch.utils.checkpoint.checkpoint(
711
+ create_custom_forward(resnet), hidden_states, temb
712
+ )
713
+ hidden_states = attn(
714
+ hidden_states,
715
+ encoder_hidden_states=encoder_hidden_states,
716
+ ).sample
717
+ if motion_module is not None:
718
+ hidden_states = torch.utils.checkpoint.checkpoint(
719
+ create_custom_forward(motion_module),
720
+ hidden_states.requires_grad_(),
721
+ temb,
722
+ encoder_hidden_states,
723
+ )
724
+
725
+ else:
726
+ hidden_states = resnet(hidden_states, temb)
727
+ hidden_states = attn(
728
+ hidden_states,
729
+ encoder_hidden_states=encoder_hidden_states,
730
+ ).sample
731
+
732
+ # add motion module
733
+ hidden_states = (
734
+ motion_module(
735
+ hidden_states, temb, encoder_hidden_states=encoder_hidden_states
736
+ )
737
+ if motion_module is not None
738
+ else hidden_states
739
+ )
740
+
741
+ if self.upsamplers is not None:
742
+ for upsampler in self.upsamplers:
743
+ hidden_states = upsampler(hidden_states, upsample_size)
744
+
745
+ return hidden_states
746
+
747
+
748
+ class UpBlock3D(nn.Module):
749
+ def __init__(
750
+ self,
751
+ in_channels: int,
752
+ prev_output_channel: int,
753
+ out_channels: int,
754
+ temb_channels: int,
755
+ dropout: float = 0.0,
756
+ num_layers: int = 1,
757
+ resnet_eps: float = 1e-6,
758
+ resnet_time_scale_shift: str = "default",
759
+ resnet_act_fn: str = "swish",
760
+ resnet_groups: int = 32,
761
+ resnet_pre_norm: bool = True,
762
+ output_scale_factor=1.0,
763
+ add_upsample=True,
764
+ use_inflated_groupnorm=None,
765
+ use_motion_module=None,
766
+ motion_module_type=None,
767
+ motion_module_kwargs=None,
768
+ ):
769
+ super().__init__()
770
+ resnets = []
771
+ motion_modules = []
772
+
773
+ # use_motion_module = False
774
+ for i in range(num_layers):
775
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
776
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
777
+
778
+ resnets.append(
779
+ ResnetBlock3D(
780
+ in_channels=resnet_in_channels + res_skip_channels,
781
+ out_channels=out_channels,
782
+ temb_channels=temb_channels,
783
+ eps=resnet_eps,
784
+ groups=resnet_groups,
785
+ dropout=dropout,
786
+ time_embedding_norm=resnet_time_scale_shift,
787
+ non_linearity=resnet_act_fn,
788
+ output_scale_factor=output_scale_factor,
789
+ pre_norm=resnet_pre_norm,
790
+ use_inflated_groupnorm=use_inflated_groupnorm,
791
+ )
792
+ )
793
+ motion_modules.append(
794
+ get_motion_module(
795
+ in_channels=out_channels,
796
+ motion_module_type=motion_module_type,
797
+ motion_module_kwargs=motion_module_kwargs,
798
+ )
799
+ if use_motion_module
800
+ else None
801
+ )
802
+
803
+ self.resnets = nn.ModuleList(resnets)
804
+ self.motion_modules = nn.ModuleList(motion_modules)
805
+
806
+ if add_upsample:
807
+ self.upsamplers = nn.ModuleList(
808
+ [Upsample3D(out_channels, use_conv=True, out_channels=out_channels)]
809
+ )
810
+ else:
811
+ self.upsamplers = None
812
+
813
+ self.gradient_checkpointing = False
814
+
815
+ def forward(
816
+ self,
817
+ hidden_states,
818
+ res_hidden_states_tuple,
819
+ temb=None,
820
+ upsample_size=None,
821
+ encoder_hidden_states=None,
822
+ ):
823
+ for resnet, motion_module in zip(self.resnets, self.motion_modules):
824
+ # pop res hidden states
825
+ res_hidden_states = res_hidden_states_tuple[-1]
826
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
827
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
828
+
829
+ # print(f"UpBlock3D {self.gradient_checkpointing = }")
830
+ if self.training and self.gradient_checkpointing:
831
+
832
+ def create_custom_forward(module):
833
+ def custom_forward(*inputs):
834
+ return module(*inputs)
835
+
836
+ return custom_forward
837
+
838
+ hidden_states = torch.utils.checkpoint.checkpoint(
839
+ create_custom_forward(resnet), hidden_states, temb
840
+ )
841
+ if motion_module is not None:
842
+ hidden_states = torch.utils.checkpoint.checkpoint(
843
+ create_custom_forward(motion_module),
844
+ hidden_states.requires_grad_(),
845
+ temb,
846
+ encoder_hidden_states,
847
+ )
848
+ else:
849
+ hidden_states = resnet(hidden_states, temb)
850
+ hidden_states = (
851
+ motion_module(
852
+ hidden_states, temb, encoder_hidden_states=encoder_hidden_states
853
+ )
854
+ if motion_module is not None
855
+ else hidden_states
856
+ )
857
+
858
+ if self.upsamplers is not None:
859
+ for upsampler in self.upsamplers:
860
+ hidden_states = upsampler(hidden_states, upsample_size)
861
+
862
+ return hidden_states
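A quick sanity-check sketch for the block factory above (editor's illustration, not from the upload); the settings mirror how UNet3DConditionModel wires its first down block, and the shapes in the comments are what the block is expected to produce.

    import torch
    from modules.unet_3d_blocks import get_down_block

    block = get_down_block(
        "CrossAttnDownBlock3D",
        num_layers=2, in_channels=320, out_channels=320, temb_channels=1280,
        add_downsample=True, resnet_eps=1e-5, resnet_act_fn="silu",
        attn_num_head_channels=8, resnet_groups=32, cross_attention_dim=768,
        downsample_padding=1, use_inflated_groupnorm=True,
    )

    hidden = torch.randn(1, 320, 8, 64, 64)   # (batch, channels, frames, height, width)
    temb = torch.randn(1, 1280)               # timestep embedding
    context = torch.randn(1, 77, 768)         # cross-attention context
    out, skips = block(hidden, temb=temb, encoder_hidden_states=context)
    # `out` is expected to be (1, 320, 8, 32, 32); `skips` collects one tensor per
    # resnet layer plus the downsampled output, which the up path consumes later.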
modules/v_kps_guider.py ADDED
@@ -0,0 +1,45 @@
1
+ from typing import Tuple
2
+
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ from diffusers.models.modeling_utils import ModelMixin
6
+ from .motion_module import zero_module
7
+ from .resnet import InflatedConv3d
8
+
9
+
10
+ class VKpsGuider(ModelMixin):
11
+ def __init__(
12
+ self,
13
+ conditioning_embedding_channels: int,
14
+ conditioning_channels: int = 3,
15
+ block_out_channels: Tuple[int] = (16, 32, 64, 128),
16
+ ):
17
+ super().__init__()
18
+ self.conv_in = InflatedConv3d(conditioning_channels, block_out_channels[0], kernel_size=3, padding=1)
19
+
20
+ self.blocks = nn.ModuleList([])
21
+
22
+ for i in range(len(block_out_channels) - 1):
23
+ channel_in = block_out_channels[i]
24
+ channel_out = block_out_channels[i + 1]
25
+ self.blocks.append(InflatedConv3d(channel_in, channel_in, kernel_size=3, padding=1))
26
+ self.blocks.append(InflatedConv3d(channel_in, channel_out, kernel_size=3, padding=1, stride=2))
27
+
28
+ self.conv_out = zero_module(InflatedConv3d(
29
+ block_out_channels[-1],
30
+ conditioning_embedding_channels,
31
+ kernel_size=3,
32
+ padding=1,
33
+ ))
34
+
35
+ def forward(self, conditioning):
36
+ embedding = self.conv_in(conditioning)
37
+ embedding = F.silu(embedding)
38
+
39
+ for block in self.blocks:
40
+ embedding = block(embedding)
41
+ embedding = F.silu(embedding)
42
+
43
+ embedding = self.conv_out(embedding)
44
+
45
+ return embedding
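The guider is a small convolutional encoder whose output is added to the UNet activations right after `conv_in` (the `kps_features` argument in `unet_3d.py`). A minimal sketch (editor's illustration, not from the upload), assuming 512x512 keypoint images and the default `block_out_channels`:

    import torch
    from modules.v_kps_guider import VKpsGuider

    guider = VKpsGuider(conditioning_embedding_channels=320).eval()
    kps_images = torch.randn(1, 3, 8, 512, 512)   # (batch, 3, frames, height, width)
    with torch.no_grad():
        kps_features = guider(kps_images)
    # Three stride-2 convolutions downsample by 8, so the output is expected to be
    # (1, 320, 8, 64, 64) -- matching the UNet's conv_in activations. conv_out is
    # zero-initialized, so the features start out as an identity-preserving no-op.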
pipelines/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from .v_express_pipeline import VExpressPipeline
pipelines/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (210 Bytes). View file
 
pipelines/__pycache__/context.cpython-310.pyc ADDED
Binary file (1.96 kB). View file
 
pipelines/__pycache__/utils.cpython-310.pyc ADDED
Binary file (6.6 kB). View file
 
pipelines/__pycache__/v_express_pipeline.cpython-310.pyc ADDED
Binary file (14.1 kB). View file
 
pipelines/context.py ADDED
@@ -0,0 +1,78 @@
1
+ # TODO: Adapted from cli
2
+ from typing import Callable, List, Optional
3
+
4
+ import numpy as np
5
+
6
+
7
+ def ordered_halving(val):
8
+ bin_str = f"{val:064b}"
9
+ bin_flip = bin_str[::-1]
10
+ as_int = int(bin_flip, 2)
11
+
12
+ return as_int / (1 << 64)
13
+
14
+
15
+ def uniform(
16
+ step: int = ...,
17
+ num_frames: int = ...,
18
+ context_size: Optional[int] = None,
19
+ context_stride: int = 3,
20
+ context_overlap: int = 4,
21
+ closed_loop: bool = True,
22
+ ):
23
+ if num_frames <= context_size:
24
+ yield list(range(num_frames))
25
+ return
26
+
27
+ context_stride = min(
28
+ context_stride, int(np.ceil(np.log2(num_frames / context_size))) + 1
29
+ )
30
+
31
+ for context_step in 1 << np.arange(context_stride):
32
+ pad = int(round(num_frames * ordered_halving(step)))
33
+ for j in range(
34
+ int(ordered_halving(step) * context_step) + pad,
35
+ num_frames + pad + (0 if closed_loop else -context_overlap),
36
+ (context_size * context_step - context_overlap),
37
+ ):
38
+ next_itr = []
39
+ for e in range(j, j + context_size * context_step, context_step):
40
+ if e >= num_frames:
41
+ e = num_frames - 2 - e % num_frames
42
+ next_itr.append(e)
43
+
44
+ yield next_itr
45
+
46
+
47
+ def get_context_scheduler(name: str) -> Callable:
48
+ if name == "uniform":
49
+ return uniform
50
+ else:
51
+ raise ValueError(f"Unknown context scheduler: {name}")
52
+
53
+
54
+ def get_total_steps(
55
+ scheduler,
56
+ timesteps: List[int],
57
+ num_steps: Optional[int] = None,
58
+ num_frames: int = ...,
59
+ context_size: Optional[int] = None,
60
+ context_stride: int = 3,
61
+ context_overlap: int = 4,
62
+ closed_loop: bool = True,
63
+ ):
64
+ return sum(
65
+ len(
66
+ list(
67
+ scheduler(
68
+ i,
69
+ num_steps,
70
+ num_frames,
71
+ context_size,
72
+ context_stride,
73
+ context_overlap,
74
+ )
75
+ )
76
+ )
77
+ for i in range(len(timesteps))
78
+ )
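Because the scheduling arithmetic above is a little opaque, here is a self-contained sketch (editor's illustration) of the frame windows `uniform` yields for a single denoising step; note that `context_size` must be passed explicitly, since the `None` default is only a placeholder.

    from pipelines.context import get_context_scheduler

    scheduler = get_context_scheduler("uniform")
    # 24 latent frames processed in windows of 12 frames with an overlap of 4.
    for window in scheduler(0, num_frames=24, context_size=12,
                            context_stride=1, context_overlap=4):
        print(window)
    # Expected: overlapping index windows such as [0..11] and [8..19], plus a final
    # window that wraps past the last frame, so every frame is covered and
    # neighbouring windows can be blended by the pipeline.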
pipelines/utils.py ADDED
@@ -0,0 +1,218 @@
1
+ import math
2
+ import os
3
+ import pathlib
4
+
5
+ import cv2
6
+ import numpy as np
7
+ import torch
8
+ import torch.nn.functional as func
9
+ import tqdm
10
+ from imageio_ffmpeg import get_ffmpeg_exe
11
+
12
+ tensor_interpolation = None
13
+
14
+
15
+ def get_tensor_interpolation_method():
16
+ return tensor_interpolation
17
+
18
+
19
+ def set_tensor_interpolation_method(is_slerp):
20
+ global tensor_interpolation
21
+ tensor_interpolation = slerp if is_slerp else linear
22
+
23
+
24
+ def linear(v1, v2, t):
25
+ return (1.0 - t) * v1 + t * v2
26
+
27
+
28
+ def slerp(v0: torch.Tensor, v1: torch.Tensor, t: float, DOT_THRESHOLD: float = 0.9995) -> torch.Tensor:
29
+ u0 = v0 / v0.norm()
30
+ u1 = v1 / v1.norm()
31
+ dot = (u0 * u1).sum()
32
+ if dot.abs() > DOT_THRESHOLD:
33
+ # logger.info(f'warning: v0 and v1 close to parallel, using linear interpolation instead.')
34
+ return (1.0 - t) * v0 + t * v1
35
+ omega = dot.acos()
36
+ return (((1.0 - t) * omega).sin() * v0 + (t * omega).sin() * v1) / omega.sin()
37
+
38
+
39
+ def draw_kps_image(height, width, kps, color_list=[(255, 0, 0), (0, 255, 0), (0, 0, 255)]):
40
+ stick_width = 4
41
+ limb_seq = np.array([[0, 2], [1, 2]])
42
+ kps = np.array(kps)
43
+
44
+ canvas = np.zeros((height, width, 3), dtype=np.uint8)
45
+
46
+ for i in range(len(limb_seq)):
47
+ index = limb_seq[i]
48
+ color = color_list[index[0]]
49
+
50
+ x = kps[index][:, 0]
51
+ y = kps[index][:, 1]
52
+ length = ((x[0] - x[1]) ** 2 + (y[0] - y[1]) ** 2) ** 0.5
53
+ angle = int(math.degrees(math.atan2(y[0] - y[1], x[0] - x[1])))
54
+ polygon = cv2.ellipse2Poly((int(np.mean(x)), int(np.mean(y))), (int(length / 2), stick_width), angle, 0, 360, 1)
55
+ cv2.fillConvexPoly(canvas, polygon, [int(float(c) * 0.6) for c in color])
56
+
57
+ for idx_kp, kp in enumerate(kps):
58
+ color = color_list[idx_kp]
59
+ x, y = kp
60
+ cv2.circle(canvas, (int(x), int(y)), 4, color, -1)
61
+
62
+ return canvas
63
+
64
+
65
+
66
+ import os
67
+ import pathlib
68
+ import shutil
69
+ import cv2
70
+ import numpy as np
71
+ from scipy.ndimage.filters import median_filter
72
+
73
+ def get_ffmpeg_exe():
74
+ if os.name == 'nt': # Windows
75
+ return 'ffmpeg'
76
+ else: # Ubuntu and other Unix-based systems
77
+ return 'ffmpeg'
78
+
79
+
80
+ def median_filter_3d(video_tensor, kernel_size, device):
81
+ _, video_length, height, width = video_tensor.shape
82
+
83
+ pad_size = kernel_size // 2
84
+ video_tensor = func.pad(video_tensor, (pad_size, pad_size, pad_size, pad_size, pad_size, pad_size), mode='reflect')
85
+
86
+ filtered_video_tensor = []
87
+ for i in tqdm.tqdm(range(video_length), desc='Median Filtering'):
88
+ video_segment = video_tensor[:, i:i + kernel_size, ...].to(device)
89
+ video_segment = video_segment.unfold(dimension=2, size=kernel_size, step=1)
90
+ video_segment = video_segment.unfold(dimension=3, size=kernel_size, step=1)
91
+ video_segment = video_segment.permute(0, 2, 3, 1, 4, 5).reshape(3, height, width, -1)
92
+ filtered_video_frame = torch.median(video_segment, dim=-1)[0]
93
+ filtered_video_tensor.append(filtered_video_frame.cpu())
94
+ filtered_video_tensor = torch.stack(filtered_video_tensor, dim=1)
95
+ return filtered_video_tensor
96
+
97
+
98
+ def save_video(video_tensor, audio_path, output_path, device, fps=30.0):
99
+ pathlib.Path(output_path).parent.mkdir(exist_ok=True, parents=True)
100
+
101
+ video_tensor = video_tensor[0, ...]
102
+ _, num_frames, height, width = video_tensor.shape
103
+
104
+ video_tensor = median_filter_3d(video_tensor, kernel_size=3, device=device)
105
+ video_tensor = video_tensor.permute(1, 2, 3, 0)
106
+ video_frames = (video_tensor * 255).numpy().astype(np.uint8)
107
+
108
+ output_name = pathlib.Path(output_path).stem
109
+ temp_output_path = output_path.replace(output_name, output_name + '-temp')
110
+ video_writer = cv2.VideoWriter(temp_output_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (width, height))
111
+
112
+ for i in tqdm.tqdm(range(num_frames), 'Writing frames into file'):
113
+ frame_image = video_frames[i, ...]
114
+ frame_image = cv2.cvtColor(frame_image, cv2.COLOR_RGB2BGR)
115
+ video_writer.write(frame_image)
116
+ video_writer.release()
117
+
118
+ cmd = (f'{get_ffmpeg_exe()} -i "{temp_output_path}" -i "{audio_path}" '
119
+ f'-map 0:v -map 1:a -c:v h264 -shortest -y "{output_path}" -loglevel quiet')
120
+ os.system(cmd)
121
+
122
+ os.remove(temp_output_path)
123
+
124
+
125
+
126
+ def compute_dist(x1, y1, x2, y2):
127
+ return math.sqrt((x1 - x2) ** 2 + (y1 - y2) ** 2)
128
+
129
+
130
+ def compute_ratio(kps):
131
+ l_eye_x, l_eye_y = kps[0][0], kps[0][1]
132
+ r_eye_x, r_eye_y = kps[1][0], kps[1][1]
133
+ nose_x, nose_y = kps[2][0], kps[2][1]
134
+ d_left = compute_dist(l_eye_x, l_eye_y, nose_x, nose_y)
135
+ d_right = compute_dist(r_eye_x, r_eye_y, nose_x, nose_y)
136
+ ratio = d_left / (d_right + 1e-6)
137
+ return ratio
138
+
139
+
140
+ def point_to_line_dist(point, line_points):
141
+ point = np.array(point)
142
+ line_points = np.array(line_points)
143
+ line_vec = line_points[1] - line_points[0]
144
+ point_vec = point - line_points[0]
145
+ line_norm = line_vec / np.sqrt(np.sum(line_vec ** 2))
146
+ point_vec_scaled = point_vec * 1.0 / np.sqrt(np.sum(line_vec ** 2))
147
+ t = np.dot(line_norm, point_vec_scaled)
148
+ if t < 0.0:
149
+ t = 0.0
150
+ elif t > 1.0:
151
+ t = 1.0
152
+ nearest = line_points[0] + t * line_vec
153
+ dist = np.sqrt(np.sum((point - nearest) ** 2))
154
+ return dist
155
+
156
+
157
+ def get_face_size(kps):
158
+ # 0: left eye, 1: right eye, 2: nose
159
+ A = kps[0, :]
160
+ B = kps[1, :]
161
+ C = kps[2, :]
162
+
163
+ AB_dist = math.sqrt((A[0] - B[0]) ** 2 + (A[1] - B[1]) ** 2)
164
+ C_AB_dist = point_to_line_dist(C, [A, B])
165
+ return AB_dist, C_AB_dist
166
+
167
+
168
+ def get_rescale_params(kps_ref, kps_target):
169
+ kps_ref = np.array(kps_ref)
170
+ kps_target = np.array(kps_target)
171
+
172
+ ref_AB_dist, ref_C_AB_dist = get_face_size(kps_ref)
173
+ target_AB_dist, target_C_AB_dist = get_face_size(kps_target)
174
+
175
+ scale_width = ref_AB_dist / target_AB_dist
176
+ scale_height = ref_C_AB_dist / target_C_AB_dist
177
+
178
+ return scale_width, scale_height
179
+
180
+
181
+ def retarget_kps(ref_kps, tgt_kps_list, only_offset=True):
182
+ ref_kps = np.array(ref_kps)
183
+ tgt_kps_list = np.array(tgt_kps_list)
184
+
185
+ ref_ratio = compute_ratio(ref_kps)
186
+
187
+ ratio_delta = 10000
188
+ selected_tgt_kps_idx = None
189
+ for idx, tgt_kps in enumerate(tgt_kps_list):
190
+ tgt_ratio = compute_ratio(tgt_kps)
191
+ if math.fabs(tgt_ratio - ref_ratio) < ratio_delta:
192
+ selected_tgt_kps_idx = idx
193
+ ratio_delta = math.fabs(tgt_ratio - ref_ratio)  # keep the smallest ratio difference, not the raw target ratio
194
+
195
+ scale_width, scale_height = get_rescale_params(
196
+ kps_ref=ref_kps,
197
+ kps_target=tgt_kps_list[selected_tgt_kps_idx],
198
+ )
199
+
200
+ rescaled_tgt_kps_list = np.array(tgt_kps_list)
201
+ rescaled_tgt_kps_list[:, :, 0] *= scale_width
202
+ rescaled_tgt_kps_list[:, :, 1] *= scale_height
203
+
204
+ if only_offset:
205
+ nose_offset = rescaled_tgt_kps_list[:, 2, :] - rescaled_tgt_kps_list[0, 2, :]
206
+ nose_offset = nose_offset[:, np.newaxis, :]
207
+ ref_kps_repeat = np.tile(ref_kps, (tgt_kps_list.shape[0], 1, 1))
208
+
209
+ ref_kps_repeat[:, :, :] -= (nose_offset / 2.0)
210
+ rescaled_tgt_kps_list = ref_kps_repeat
211
+ else:
212
+ nose_offset_x = rescaled_tgt_kps_list[0, 2, 0] - ref_kps[2][0]
213
+ nose_offset_y = rescaled_tgt_kps_list[0, 2, 1] - ref_kps[2][1]
214
+
215
+ rescaled_tgt_kps_list[:, :, 0] -= nose_offset_x
216
+ rescaled_tgt_kps_list[:, :, 1] -= nose_offset_y
217
+
218
+ return rescaled_tgt_kps_list
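
A minimal sketch of the keypoint helpers above, using three made-up float coordinates for left eye, right eye and nose (float inputs mirror what the face detector produces); run from the repository root with the pinned requirements installed:

# Draw a V-Express style keypoint image and retarget a tiny synthetic keypoint
# sequence onto the same reference face.
import cv2
from pipelines.utils import draw_kps_image, retarget_kps

ref_kps = [[200.0, 180.0], [310.0, 180.0], [255.0, 260.0]]  # left eye, right eye, nose
canvas = draw_kps_image(height=512, width=512, kps=ref_kps)
cv2.imwrite("kps_preview.png", canvas)

tgt_kps_list = [
    [[190.0, 175.0], [300.0, 175.0], [250.0, 255.0]],
    [[192.0, 176.0], [302.0, 176.0], [252.0, 257.0]],
]
retargeted = retarget_kps(ref_kps, tgt_kps_list, only_offset=True)
print(retargeted.shape)  # (2, 3, 2): two frames of retargeted (eye, eye, nose) points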
pipelines/v_express_pipeline.py ADDED
@@ -0,0 +1,586 @@
1
+ # Adapted from https://github.com/magic-research/magic-animate/blob/main/magicanimate/pipelines/pipeline_animation.py
2
+ import inspect
3
+ import math
4
+ from typing import Callable, List, Optional, Union
5
+
6
+ import torch
7
+ from diffusers import DiffusionPipeline
8
+ from diffusers.image_processor import VaeImageProcessor
9
+ from diffusers.schedulers import (
10
+ DDIMScheduler,
11
+ DPMSolverMultistepScheduler,
12
+ EulerAncestralDiscreteScheduler,
13
+ EulerDiscreteScheduler,
14
+ LMSDiscreteScheduler,
15
+ PNDMScheduler,
16
+ )
17
+ from diffusers.utils import is_accelerate_available
18
+ from diffusers.utils.torch_utils import randn_tensor
19
+ from einops import rearrange
20
+ from tqdm import tqdm
21
+ from transformers import CLIPImageProcessor
22
+
23
+ from modules import ReferenceAttentionControl
24
+ from .context import get_context_scheduler
25
+
26
+
27
+ def retrieve_timesteps(
28
+ scheduler,
29
+ num_inference_steps: Optional[int] = None,
30
+ device: Optional[Union[str, torch.device]] = None,
31
+ timesteps: Optional[List[int]] = None,
32
+ **kwargs,
33
+ ):
34
+ """
35
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
36
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
37
+
38
+ Args:
39
+ scheduler (`SchedulerMixin`):
40
+ The scheduler to get timesteps from.
41
+ num_inference_steps (`int`):
42
+ The number of diffusion steps used when generating samples with a pre-trained model. If used,
43
+ `timesteps` must be `None`.
44
+ device (`str` or `torch.device`, *optional*):
45
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
46
+ timesteps (`List[int]`, *optional*):
47
+ Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default
48
+ timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps`
49
+ must be `None`.
50
+
51
+ Returns:
52
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
53
+ second element is the number of inference steps.
54
+ """
55
+ if timesteps is not None:
56
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
57
+ if not accepts_timesteps:
58
+ raise ValueError(
59
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
60
+ f" timestep schedules. Please check whether you are using the correct scheduler."
61
+ )
62
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
63
+ timesteps = scheduler.timesteps
64
+ num_inference_steps = len(timesteps)
65
+ else:
66
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
67
+ timesteps = scheduler.timesteps
68
+ return timesteps, num_inference_steps
69
+
70
+
71
+ class VExpressPipeline(DiffusionPipeline):
72
+ _optional_components = []
73
+
74
+ def __init__(
75
+ self,
76
+ vae,
77
+ reference_net,
78
+ denoising_unet,
79
+ v_kps_guider,
80
+ audio_processor,
81
+ audio_encoder,
82
+ audio_projection,
83
+ scheduler: Union[
84
+ DDIMScheduler,
85
+ PNDMScheduler,
86
+ LMSDiscreteScheduler,
87
+ EulerDiscreteScheduler,
88
+ EulerAncestralDiscreteScheduler,
89
+ DPMSolverMultistepScheduler,
90
+ ],
91
+ image_proj_model=None,
92
+ tokenizer=None,
93
+ text_encoder=None,
94
+ ):
95
+ super().__init__()
96
+
97
+ self.register_modules(
98
+ vae=vae,
99
+ reference_net=reference_net,
100
+ denoising_unet=denoising_unet,
101
+ v_kps_guider=v_kps_guider,
102
+ audio_processor=audio_processor,
103
+ audio_encoder=audio_encoder,
104
+ audio_projection=audio_projection,
105
+ scheduler=scheduler,
106
+ image_proj_model=image_proj_model,
107
+ tokenizer=tokenizer,
108
+ text_encoder=text_encoder,
109
+ )
110
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
111
+ self.clip_image_processor = CLIPImageProcessor()
112
+ self.reference_image_processor = VaeImageProcessor(
113
+ vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True
114
+ )
115
+ self.condition_image_processor = VaeImageProcessor(
116
+ vae_scale_factor=self.vae_scale_factor,
117
+ do_convert_rgb=True,
118
+ do_normalize=False,
119
+ )
120
+
121
+ def enable_vae_slicing(self):
122
+ self.vae.enable_slicing()
123
+
124
+ def disable_vae_slicing(self):
125
+ self.vae.disable_slicing()
126
+
127
+ def enable_sequential_cpu_offload(self, gpu_id=0):
128
+ if is_accelerate_available():
129
+ from accelerate import cpu_offload
130
+ else:
131
+ raise ImportError("Please install accelerate via `pip install accelerate`")
132
+
133
+ device = torch.device(f"cuda:{gpu_id}")
134
+
135
+ for cpu_offloaded_model in [self.denoising_unet, self.text_encoder, self.vae]:  # this pipeline registers denoising_unet, not unet
136
+ if cpu_offloaded_model is not None:
137
+ cpu_offload(cpu_offloaded_model, device)
138
+
139
+ @property
140
+ def _execution_device(self):
141
+ if self.device != torch.device("meta") or not hasattr(self.denoising_unet, "_hf_hook"):
142
+ return self.device
143
+ for module in self.denoising_unet.modules():
144
+ if (
145
+ hasattr(module, "_hf_hook")
146
+ and hasattr(module._hf_hook, "execution_device")
147
+ and module._hf_hook.execution_device is not None
148
+ ):
149
+ return torch.device(module._hf_hook.execution_device)
150
+ return self.device
151
+
152
+ @torch.no_grad()
153
+ def decode_latents(self, latents):
154
+ video_length = latents.shape[2]
155
+ latents = 1 / 0.18215 * latents
156
+ latents = rearrange(latents, "b c f h w -> (b f) c h w")
157
+ video = []
158
+ for frame_idx in tqdm(range(latents.shape[0]), desc='Decoding latents into frames'):
159
+ image = self.vae.decode(latents[frame_idx: frame_idx + 1].to(self.vae.device)).sample
160
+ image = (image / 2 + 0.5).clamp(0, 1)
161
+ image = image.cpu().float()
162
+ video.append(image)
163
+ video = torch.cat(video)
164
+ video = rearrange(video, "(b f) c h w -> b c f h w", f=video_length)
165
+
166
+ return video
167
+
168
+ def prepare_extra_step_kwargs(self, generator, eta):
169
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
170
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
171
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
172
+ # and should be between [0, 1]
173
+
174
+ accepts_eta = "eta" in set(
175
+ inspect.signature(self.scheduler.step).parameters.keys()
176
+ )
177
+ extra_step_kwargs = {}
178
+ if accepts_eta:
179
+ extra_step_kwargs["eta"] = eta
180
+
181
+ # check if the scheduler accepts generator
182
+ accepts_generator = "generator" in set(
183
+ inspect.signature(self.scheduler.step).parameters.keys()
184
+ )
185
+ if accepts_generator:
186
+ extra_step_kwargs["generator"] = generator
187
+ return extra_step_kwargs
188
+
189
+ def prepare_latents(
190
+ self,
191
+ batch_size,
192
+ num_channels_latents,
193
+ width,
194
+ height,
195
+ video_length,
196
+ dtype,
197
+ device,
198
+ generator,
199
+ latents=None
200
+ ):
201
+ shape = (
202
+ batch_size,
203
+ num_channels_latents,
204
+ video_length,
205
+ height // self.vae_scale_factor,
206
+ width // self.vae_scale_factor,
207
+ )
208
+ if isinstance(generator, list) and len(generator) != batch_size:
209
+ raise ValueError(
210
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
211
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
212
+ )
213
+
214
+ if latents is None:
215
+ latents = randn_tensor(
216
+ shape, generator=generator, device=device, dtype=dtype
217
+ )
218
+
219
+ else:
220
+ latents = latents.to(device)
221
+
222
+ # scale the initial noise by the standard deviation required by the scheduler
223
+ latents = latents * self.scheduler.init_noise_sigma
224
+ return latents
225
+
226
+ def _encode_prompt(
227
+ self,
228
+ prompt,
229
+ device,
230
+ num_videos_per_prompt,
231
+ do_classifier_free_guidance,
232
+ negative_prompt,
233
+ ):
234
+ batch_size = len(prompt) if isinstance(prompt, list) else 1
235
+
236
+ text_inputs = self.tokenizer(
237
+ prompt,
238
+ padding="max_length",
239
+ max_length=self.tokenizer.model_max_length,
240
+ truncation=True,
241
+ return_tensors="pt",
242
+ )
243
+ text_input_ids = text_inputs.input_ids
244
+ untruncated_ids = self.tokenizer(
245
+ prompt, padding="longest", return_tensors="pt"
246
+ ).input_ids
247
+
248
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
249
+ text_input_ids, untruncated_ids
250
+ ):
251
+ removed_text = self.tokenizer.batch_decode(
252
+ untruncated_ids[:, self.tokenizer.model_max_length - 1: -1]
253
+ )
254
+
255
+ if (
256
+ hasattr(self.text_encoder.config, "use_attention_mask")
257
+ and self.text_encoder.config.use_attention_mask
258
+ ):
259
+ attention_mask = text_inputs.attention_mask.to(device)
260
+ else:
261
+ attention_mask = None
262
+
263
+ text_embeddings = self.text_encoder(
264
+ text_input_ids.to(device),
265
+ attention_mask=attention_mask,
266
+ )
267
+ text_embeddings = text_embeddings[0]
268
+
269
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
270
+ bs_embed, seq_len, _ = text_embeddings.shape
271
+ text_embeddings = text_embeddings.repeat(1, num_videos_per_prompt, 1)
272
+ text_embeddings = text_embeddings.view(
273
+ bs_embed * num_videos_per_prompt, seq_len, -1
274
+ )
275
+
276
+ # get unconditional embeddings for classifier free guidance
277
+ if do_classifier_free_guidance:
278
+ uncond_tokens: List[str]
279
+ if negative_prompt is None:
280
+ uncond_tokens = [""] * batch_size
281
+ elif type(prompt) is not type(negative_prompt):
282
+ raise TypeError(
283
+ f"`negative_prompt` should be the same type as `prompt`, but got {type(negative_prompt)} !="
284
+ f" {type(prompt)}."
285
+ )
286
+ elif isinstance(negative_prompt, str):
287
+ uncond_tokens = [negative_prompt]
288
+ elif batch_size != len(negative_prompt):
289
+ raise ValueError(
290
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
291
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
292
+ " the batch size of `prompt`."
293
+ )
294
+ else:
295
+ uncond_tokens = negative_prompt
296
+
297
+ max_length = text_input_ids.shape[-1]
298
+ uncond_input = self.tokenizer(
299
+ uncond_tokens,
300
+ padding="max_length",
301
+ max_length=max_length,
302
+ truncation=True,
303
+ return_tensors="pt",
304
+ )
305
+
306
+ if (
307
+ hasattr(self.text_encoder.config, "use_attention_mask")
308
+ and self.text_encoder.config.use_attention_mask
309
+ ):
310
+ attention_mask = uncond_input.attention_mask.to(device)
311
+ else:
312
+ attention_mask = None
313
+
314
+ uncond_embeddings = self.text_encoder(
315
+ uncond_input.input_ids.to(device),
316
+ attention_mask=attention_mask,
317
+ )
318
+ uncond_embeddings = uncond_embeddings[0]
319
+
320
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
321
+ seq_len = uncond_embeddings.shape[1]
322
+ uncond_embeddings = uncond_embeddings.repeat(1, num_videos_per_prompt, 1)
323
+ uncond_embeddings = uncond_embeddings.view(
324
+ batch_size * num_videos_per_prompt, seq_len, -1
325
+ )
326
+
327
+ # For classifier free guidance, we need to do two forward passes.
328
+ # Here we concatenate the unconditional and text embeddings into a single batch
329
+ # to avoid doing two forward passes
330
+ text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
331
+
332
+ return text_embeddings
333
+
334
+ def get_timesteps(self, num_inference_steps, strength, device):
335
+ # get the original timestep using init_timestep
336
+ init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
337
+
338
+ t_start = max(num_inference_steps - init_timestep, 0)
339
+ timesteps = self.scheduler.timesteps[t_start * self.scheduler.order:]
340
+
341
+ return timesteps, num_inference_steps - t_start
342
+
343
+ def prepare_reference_latent(self, reference_image, height, width):
344
+ reference_image_tensor = self.reference_image_processor.preprocess(reference_image, height=height, width=width)
345
+ reference_image_tensor = reference_image_tensor.to(dtype=self.dtype, device=self.device)
346
+ reference_image_latents = self.vae.encode(reference_image_tensor).latent_dist.mean
347
+ reference_image_latents = reference_image_latents * 0.18215
348
+ return reference_image_latents
349
+
350
+ def prepare_kps_feature(self, kps_images, height, width, do_classifier_free_guidance):
351
+ kps_image_tensors = []
352
+ for idx, kps_image in enumerate(kps_images):
353
+ kps_image_tensor = self.condition_image_processor.preprocess(kps_image, height=height, width=width)
354
+ kps_image_tensor = kps_image_tensor.unsqueeze(2) # [bs, c, 1, h, w]
355
+ kps_image_tensors.append(kps_image_tensor)
356
+ kps_images_tensor = torch.cat(kps_image_tensors, dim=2) # [bs, c, t, h, w]
357
+
358
+ bs = 16
359
+ num_forward = math.ceil(kps_images_tensor.shape[2] / bs)
360
+ kps_feature = []
361
+ for i in range(num_forward):
362
+ tensor = kps_images_tensor[:, :, i * bs:(i + 1) * bs, ...].to(device=self.device, dtype=self.dtype)
363
+ feature = self.v_kps_guider(tensor).cpu()
364
+ kps_feature.append(feature)
365
+ torch.cuda.empty_cache()
366
+ kps_feature = torch.cat(kps_feature, dim=2)
367
+
368
+ if do_classifier_free_guidance:
369
+ uc_kps_feature = torch.zeros_like(kps_feature)
370
+ kps_feature = torch.cat([uc_kps_feature, kps_feature], dim=0)
371
+
372
+ return kps_feature
373
+
374
+ def prepare_audio_embeddings(self, audio_waveform, video_length, num_pad_audio_frames, do_classifier_free_guidance):
375
+ audio_waveform = self.audio_processor(audio_waveform, return_tensors="pt", sampling_rate=16000)['input_values']
376
+ audio_waveform = audio_waveform.to(self.device, self.dtype)
377
+ audio_embeddings = self.audio_encoder(audio_waveform).last_hidden_state # [1, num_embeds, d]
378
+
379
+ audio_embeddings = torch.nn.functional.interpolate(
380
+ audio_embeddings.permute(0, 2, 1),
381
+ size=2 * video_length,
382
+ mode='linear',
383
+ )[0, :, :].permute(1, 0) # [2*vid_len, dim]
384
+
385
+ audio_embeddings = torch.cat([
386
+ torch.zeros_like(audio_embeddings)[:2 * num_pad_audio_frames, :],
387
+ audio_embeddings,
388
+ torch.zeros_like(audio_embeddings)[:2 * num_pad_audio_frames, :],
389
+ ], dim=0) # [2*num_pad+2*vid_len+2*num_pad, dim]
390
+
391
+ frame_audio_embeddings = []
392
+ for frame_idx in range(video_length):
393
+ start_sample = frame_idx
394
+ end_sample = frame_idx + 2 * num_pad_audio_frames
395
+
396
+ frame_audio_embedding = audio_embeddings[2 * start_sample:2 * (end_sample + 1), :] # [2*num_pad+1, dim]
397
+ frame_audio_embeddings.append(frame_audio_embedding)
398
+ audio_embeddings = torch.stack(frame_audio_embeddings, dim=0) # [vid_len, 2*num_pad+1, dim]
399
+
400
+ audio_embeddings = self.audio_projection(audio_embeddings).unsqueeze(0)
401
+ if do_classifier_free_guidance:
402
+ uc_audio_embeddings = torch.zeros_like(audio_embeddings)
403
+ audio_embeddings = torch.cat([uc_audio_embeddings, audio_embeddings], dim=0)
404
+ return audio_embeddings
405
+
406
+ @torch.no_grad()
407
+ def __call__(
408
+ self,
409
+ reference_image,
410
+ kps_images,
411
+ audio_waveform,
412
+ width,
413
+ height,
414
+ video_length,
415
+ num_inference_steps,
416
+ guidance_scale,
417
+ strength=1.,
418
+ num_images_per_prompt=1,
419
+ eta: float = 0.0,
420
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
421
+ output_type: Optional[str] = "tensor",
422
+ return_dict: bool = True,
423
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
424
+ callback_steps: Optional[int] = 1,
425
+ context_schedule="uniform",
426
+ context_frames=24,
427
+ context_overlap=4,
428
+ reference_attention_weight=1.,
429
+ audio_attention_weight=1.,
430
+ num_pad_audio_frames=2,
431
+ do_multi_devices_inference=False,
432
+ save_gpu_memory=False,
433
+ **kwargs,
434
+ ):
435
+ # Default height and width to unet
436
+ height = height or self.denoising_unet.config.sample_size * self.vae_scale_factor
437
+ width = width or self.denoising_unet.config.sample_size * self.vae_scale_factor
438
+
439
+ device = self._execution_device
440
+
441
+ do_classifier_free_guidance = guidance_scale > 1.0
442
+ batch_size = 1
443
+
444
+ # Prepare timesteps
445
+ timesteps = None
446
+ timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
447
+ timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
448
+
449
+ reference_control_writer = ReferenceAttentionControl(
450
+ self.reference_net,
451
+ do_classifier_free_guidance=do_classifier_free_guidance,
452
+ mode="write",
453
+ batch_size=batch_size,
454
+ fusion_blocks="full",
455
+ )
456
+ reference_control_reader = ReferenceAttentionControl(
457
+ self.denoising_unet,
458
+ do_classifier_free_guidance=do_classifier_free_guidance,
459
+ mode="read",
460
+ batch_size=batch_size,
461
+ fusion_blocks="full",
462
+ reference_attention_weight=reference_attention_weight,
463
+ audio_attention_weight=audio_attention_weight,
464
+ )
465
+
466
+ num_channels_latents = self.denoising_unet.in_channels
467
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
468
+
469
+ reference_image_latents = self.prepare_reference_latent(reference_image, height, width)
470
+ kps_feature = self.prepare_kps_feature(kps_images, height, width, do_classifier_free_guidance)
471
+ if save_gpu_memory:
472
+ del self.v_kps_guider
473
+ torch.cuda.empty_cache()
474
+ audio_embeddings = self.prepare_audio_embeddings(
475
+ audio_waveform,
476
+ video_length,
477
+ num_pad_audio_frames,
478
+ do_classifier_free_guidance,
479
+ )
480
+ if save_gpu_memory:
481
+ del self.audio_processor, self.audio_encoder, self.audio_projection
482
+ torch.cuda.empty_cache()
483
+
484
+ context_scheduler = get_context_scheduler(context_schedule)
485
+ context_queue = list(
486
+ context_scheduler(
487
+ step=0,
488
+ num_frames=video_length,
489
+ context_size=context_frames,
490
+ context_stride=1,
491
+ context_overlap=context_overlap,
492
+ closed_loop=False,
493
+ )
494
+ )
495
+
496
+ num_frame_context = torch.zeros(video_length, device=device, dtype=torch.long)
497
+ for context in context_queue:
498
+ num_frame_context[context] += 1
499
+
500
+ encoder_hidden_states = torch.zeros((1, 1, 768), dtype=self.dtype, device=self.device)
501
+ self.reference_net(
502
+ reference_image_latents,
503
+ timestep=0,
504
+ encoder_hidden_states=encoder_hidden_states,
505
+ return_dict=False,
506
+ )
507
+ reference_control_reader.update(reference_control_writer, do_classifier_free_guidance)
508
+ if save_gpu_memory:
509
+ del self.reference_net
510
+ torch.cuda.empty_cache()
511
+
512
+ latents = self.prepare_latents(
513
+ batch_size * num_images_per_prompt,
514
+ num_channels_latents,
515
+ width,
516
+ height,
517
+ video_length,
518
+ self.dtype,
519
+ torch.device('cpu'),
520
+ generator,
521
+ )
522
+
523
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
524
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
525
+ for i, t in enumerate(timesteps):
526
+ context_counter = torch.zeros(video_length, device=device, dtype=torch.long)
527
+ noise_preds = [None] * video_length
528
+ for context_idx, context in enumerate(context_queue):
529
+ latent_kps_feature = kps_feature[:, :, context].to(device, self.dtype)
530
+
531
+ latent_audio_embeddings = audio_embeddings[:, context, ...]
532
+ _, _, num_tokens, dim = latent_audio_embeddings.shape
533
+ latent_audio_embeddings = latent_audio_embeddings.reshape(-1, num_tokens, dim)
534
+
535
+ input_latents = latents[:, :, context, ...].to(device)
536
+ input_latents = input_latents.repeat(2 if do_classifier_free_guidance else 1, 1, 1, 1, 1)
537
+ input_latents = self.scheduler.scale_model_input(input_latents, t)
538
+ noise_pred = self.denoising_unet(
539
+ input_latents,
540
+ t,
541
+ encoder_hidden_states=latent_audio_embeddings.reshape(-1, num_tokens, dim),
542
+ kps_features=latent_kps_feature,
543
+ return_dict=False,
544
+ )[0]
545
+ if do_classifier_free_guidance:
546
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
547
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
548
+
549
+ context_counter[context] += 1
550
+ noise_pred /= num_frame_context[context][None, None, :, None, None]
551
+ step_frame_ids = []
552
+ step_noise_preds = []
553
+ for latent_idx, frame_idx in enumerate(context):
554
+ if noise_preds[frame_idx] is None:
555
+ noise_preds[frame_idx] = noise_pred[:, :, latent_idx, ...]
556
+ else:
557
+ noise_preds[frame_idx] += noise_pred[:, :, latent_idx, ...]
558
+ if context_counter[frame_idx] == num_frame_context[frame_idx]:
559
+ step_frame_ids.append(frame_idx)
560
+ step_noise_preds.append(noise_preds[frame_idx])
561
+ noise_preds[frame_idx] = None
562
+ step_noise_preds = torch.stack(step_noise_preds, dim=2)
563
+ output_latents = self.scheduler.step(
564
+ step_noise_preds,
565
+ t,
566
+ latents[:, :, step_frame_ids, ...].to(device),
567
+ **extra_step_kwargs,
568
+ ).prev_sample
569
+ latents[:, :, step_frame_ids, ...] = output_latents.cpu()
570
+
571
+ progress_bar.set_description(
572
+ f'Denoising Step Index: {i + 1} / {len(timesteps)}, '
573
+ f'Context Index: {context_idx + 1} / {len(context_queue)}'
574
+ )
575
+
576
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
577
+ progress_bar.update()
578
+ if callback is not None and i % callback_steps == 0:
579
+ step_idx = i // getattr(self.scheduler, "order", 1)
580
+ callback(step_idx, t, latents)
581
+
582
+ reference_control_reader.clear()
583
+ reference_control_writer.clear()
584
+
585
+ video_tensor = self.decode_latents(latents)
586
+ return video_tensor
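
To make the call signature above easier to scan, here is a schematic invocation sketch. Everything not defined in this commit is a placeholder: `load_models()` stands in for the model construction done in inference.py, and `reference_image`, `kps_images` and `audio_waveform` stand in for the prepared inputs.

import torch
from pipelines import VExpressPipeline

# Placeholder: build/load vae, reference_net, denoising_unet, v_kps_guider,
# audio_processor, audio_encoder, audio_projection and the scheduler the same
# way inference.py does; load_models() is NOT an API shipped in this commit.
(vae, reference_net, denoising_unet, v_kps_guider,
 audio_processor, audio_encoder, audio_projection, scheduler) = load_models()

pipeline = VExpressPipeline(
    vae=vae,
    reference_net=reference_net,
    denoising_unet=denoising_unet,
    v_kps_guider=v_kps_guider,
    audio_processor=audio_processor,
    audio_encoder=audio_encoder,
    audio_projection=audio_projection,
    scheduler=scheduler,
).to("cuda", torch.float16)

video_tensor = pipeline(
    reference_image=reference_image,   # placeholder: PIL image of the reference face
    kps_images=kps_images,             # placeholder: list of per-frame keypoint images
    audio_waveform=audio_waveform,     # placeholder: mono waveform resampled to 16 kHz
    width=512,
    height=512,
    video_length=len(kps_images),
    num_inference_steps=25,
    guidance_scale=3.5,
    context_frames=24,
    context_overlap=4,
)  # returns a (1, C, num_frames, H, W) tensor with values in [0, 1]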
requirements.txt ADDED
@@ -0,0 +1,21 @@
1
+ diffusers==0.24.0
2
+ imageio-ffmpeg==0.4.9
3
+ insightface==0.7.3
4
+ omegaconf==2.2.3
5
+ onnx==1.14.0
6
+ onnxruntime-gpu==1.16.3
7
+ safetensors==0.4.2
8
+ torch==2.0.1
9
+ torchaudio==2.0.2
10
+ torchvision==0.15.2
11
+ transformers==4.41.1
12
+ einops==0.4.1
13
+ tqdm==4.66.4
14
+ xformers==0.0.22
15
+ av==11.0.0
16
+ gradio
17
+ retina-face
18
+ tf-keras
19
+ filetype
20
+ bitsandbytes==0.43.0
21
+ accelerate
scripts/crop.py ADDED
@@ -0,0 +1,108 @@
1
+ from retinaface import RetinaFace
2
+ from PIL import Image
3
+ import torch
4
+
5
+
6
+
7
+ def auto_crop_image(image_path=r"F:\V_Express_V1\Material\Biden_Photo_Big.png", expand_percent=0.15, crop_size=(512, 512)):
8
+ # Check if CUDA is available (informational only: `device` is never passed to RetinaFace below)
9
+ if torch.cuda.is_available():
10
+ device = 'cuda'
11
+ print("Using GPU for RetinaFace detection.")
12
+ else:
13
+ device = 'cpu'
14
+ print("Using CPU for RetinaFace detection.")
15
+
16
+ # Load image
17
+ img = Image.open(image_path)
18
+
19
+ # Perform face detection
20
+ faces = RetinaFace.detect_faces(image_path)
21
+
22
+ if not faces:
23
+ print("No faces detected.")
24
+ return None
25
+
26
+ # Assuming 'faces' is a dictionary of detected faces
27
+ # Pick the first face detected
28
+ face = list(faces.values())[0]
29
+ landmarks = face['landmarks']
30
+
31
+ # Extract the landmarks
32
+ right_eye = landmarks['right_eye']
33
+ left_eye = landmarks['left_eye']
34
+ right_mouth = landmarks['mouth_right']
35
+ left_mouth = landmarks['mouth_left']
36
+
37
+ # Calculate the distance between the eyes
38
+ eye_distance = abs(right_eye[0] - left_eye[0])
39
+
40
+ # Estimate the head width and height
41
+ head_width = eye_distance * 4.5 # Increase the width multiplier
42
+ head_height = eye_distance * 6.5 # Increase the height multiplier
43
+
44
+ # Calculate the center point between the eyes
45
+ eye_center_x = (right_eye[0] + left_eye[0]) // 2
46
+ eye_center_y = (right_eye[1] + left_eye[1]) // 2
47
+
48
+ # Calculate the top-left and bottom-right coordinates of the assumed head region
49
+ head_left = max(0, int(eye_center_x - head_width // 2))
50
+ head_top = max(0, int(eye_center_y - head_height // 2)) # Adjust the top coordinate
51
+ head_right = min(img.width, int(eye_center_x + head_width // 2))
52
+ head_bottom = min(img.height, int(eye_center_y + head_height // 2)) # Adjust the bottom coordinate
53
+
54
+ # Save the assumed head image
55
+ assumed_head_img = img.crop((head_left, head_top, head_right, head_bottom))
56
+ assumed_head_img.save("assumed_head.png", format='PNG')
57
+
58
+ # Calculate the expansion in pixels and the new dimensions
59
+ expanded_w = int(head_width * (1 + expand_percent))
60
+ expanded_h = int(head_height * (1 + expand_percent))
61
+
62
+ # Calculate the top-left and bottom-right points of the expanded box
63
+ center_x, center_y = head_left + head_width // 2, head_top + head_height // 2
64
+ left = max(0, center_x - expanded_w // 2)
65
+ right = min(img.width, center_x + expanded_w // 2)
66
+ top = max(0, center_y - expanded_h // 2)
67
+ bottom = min(img.height, center_y + expanded_h // 2)
68
+
69
+ # Crop the image with the expanded boundaries
70
+ cropped_img = img.crop((left, top, right, bottom))
71
+ cropped_img.save("expanded_face.png", format='PNG')
72
+
73
+ # Calculate the aspect ratio of the cropped image
74
+ cropped_width, cropped_height = cropped_img.size
75
+ aspect_ratio = cropped_width / cropped_height
76
+
77
+ # Calculate the target dimensions based on the desired crop size
78
+ target_width = crop_size[0]
79
+ target_height = crop_size[1]
80
+
81
+ # Adjust the crop to match the desired aspect ratio
82
+ if aspect_ratio > target_width / target_height:
83
+ # Crop from left and right
84
+ new_width = int(cropped_height * target_width / target_height)
85
+ left_crop = (cropped_width - new_width) // 2
86
+ right_crop = left_crop + new_width
87
+ top_crop = 0
88
+ bottom_crop = cropped_height
89
+ else:
90
+ # Crop from top and bottom
91
+ new_height = int(cropped_width * target_height / target_width)
92
+ top_crop = (cropped_height - new_height) // 2
93
+ bottom_crop = top_crop + new_height
94
+ left_crop = 0
95
+ right_crop = cropped_width
96
+
97
+ # Crop the image with the adjusted boundaries
98
+ final_cropped_img = cropped_img.crop((left_crop, top_crop, right_crop, bottom_crop))
99
+ final_cropped_img.save("final_cropped_img.png", format='PNG')
100
+
101
+ # Resize the cropped image to the desired size (512x512 by default) with best quality
102
+ resized_img = final_cropped_img.resize(crop_size, resample=Image.LANCZOS)
103
+
104
+ # Save the resized image as PNG
105
+ resized_img_path = image_path.rsplit('.', 1)[0] + '_cropped.png'  # intended non-overwriting name; note the save below still writes to the fixed "resized_img.png"
106
+ resized_img.save("resized_img.png", format='PNG')
107
+
108
+ auto_crop_image()
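
A minimal usage sketch for the helper above. Note that the module calls `auto_crop_image()` at import time with its hard-coded default path, so in practice the file is either run directly after editing that default, or the function is imported with the trailing call removed. The path below is a placeholder, not a file shipped with this commit.

# Crop a portrait to the 512x512 head crop V-Express expects; the intermediate
# crops (assumed_head.png, expanded_face.png, final_cropped_img.png,
# resized_img.png) are written to the current working directory.
from scripts.crop import auto_crop_image  # assumes the trailing module-level call was removed first

auto_crop_image(
    image_path="my_portrait.png",  # placeholder input photo
    expand_percent=0.15,           # grow the detected head box by 15%
    crop_size=(512, 512),
)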