Yw22 commited on
Commit
d711508
1 Parent(s): 03ab2b7
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .DS_Store +0 -0
  2. ImageConductor_app.py +586 -0
  3. app.py +577 -0
  4. configs/.DS_Store +0 -0
  5. configs/inference/flow_condition.yaml +18 -0
  6. configs/inference/image_condition.yaml +18 -0
  7. configs/inference/inference.yaml +22 -0
  8. models/.DS_Store +0 -0
  9. modules/__pycache__/attention.cpython-310.pyc +0 -0
  10. modules/__pycache__/flow_controlnet.cpython-310.pyc +0 -0
  11. modules/__pycache__/image_controlnet.cpython-310.pyc +0 -0
  12. modules/__pycache__/motion_module.cpython-310.pyc +0 -0
  13. modules/__pycache__/resnet.cpython-310.pyc +0 -0
  14. modules/__pycache__/unet.cpython-310.pyc +0 -0
  15. modules/__pycache__/unet_blocks.cpython-310.pyc +0 -0
  16. modules/attention.py +396 -0
  17. modules/flow_controlnet.py +591 -0
  18. modules/image_controlnet.py +721 -0
  19. modules/motion_module.py +355 -0
  20. modules/resnet.py +261 -0
  21. modules/unet.py +591 -0
  22. modules/unet_blocks.py +866 -0
  23. peft/__init__.py +98 -0
  24. peft/__pycache__/__init__.cpython-310.pyc +0 -0
  25. peft/__pycache__/auto.cpython-310.pyc +0 -0
  26. peft/__pycache__/config.cpython-310.pyc +0 -0
  27. peft/__pycache__/import_utils.cpython-310.pyc +0 -0
  28. peft/__pycache__/mapping.cpython-310.pyc +0 -0
  29. peft/__pycache__/mixed_model.cpython-310.pyc +0 -0
  30. peft/__pycache__/peft_model.cpython-310.pyc +0 -0
  31. peft/auto.py +170 -0
  32. peft/config.py +270 -0
  33. peft/helpers.py +148 -0
  34. peft/import_utils.py +89 -0
  35. peft/mapping.py +181 -0
  36. peft/mixed_model.py +415 -0
  37. peft/peft_model.py +0 -0
  38. peft/py.typed +0 -0
  39. peft/tuners/__init__.py +35 -0
  40. peft/tuners/__pycache__/__init__.cpython-310.pyc +0 -0
  41. peft/tuners/__pycache__/lycoris_utils.cpython-310.pyc +0 -0
  42. peft/tuners/__pycache__/tuners_utils.cpython-310.pyc +0 -0
  43. peft/tuners/adalora/__init__.py +37 -0
  44. peft/tuners/adalora/__pycache__/__init__.cpython-310.pyc +0 -0
  45. peft/tuners/adalora/__pycache__/bnb.cpython-310.pyc +0 -0
  46. peft/tuners/adalora/__pycache__/config.cpython-310.pyc +0 -0
  47. peft/tuners/adalora/__pycache__/gptq.cpython-310.pyc +0 -0
  48. peft/tuners/adalora/__pycache__/layer.cpython-310.pyc +0 -0
  49. peft/tuners/adalora/__pycache__/model.cpython-310.pyc +0 -0
  50. peft/tuners/adalora/bnb.py +145 -0
.DS_Store ADDED
Binary file (6.15 kB). View file
 
ImageConductor_app.py ADDED
@@ -0,0 +1,586 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import gradio as gr
3
+ import numpy as np
4
+ import cv2
5
+ import uuid
6
+ import torch
7
+ import torchvision
8
+ import json
9
+
10
+ from PIL import Image
11
+ from omegaconf import OmegaConf
12
+ from einops import rearrange, repeat
13
+ from torchvision import transforms
14
+ from transformers import CLIPTextModel, CLIPTokenizer
15
+ from diffusers import AutoencoderKL, DDIMScheduler
16
+
17
+ from pipelines.pipeline_imagecoductor import ImageConductorPipeline
18
+ from modules.unet import UNet3DConditionFlowModel
19
+ from utils.gradio_utils import ensure_dirname, split_filename, visualize_drag, image2pil, image2arr
20
+ from utils.utils import create_image_controlnet, create_flow_controlnet, interpolate_trajectory, load_weights, load_model, bivariate_Gaussian
21
+ from utils.lora_utils import add_LoRA_to_controlnet
22
+ from utils.visualizer import Visualizer, vis_flow_to_video
23
+ #### Description ####
24
+ title = r"""<h1 align="center">CustomNet: Object Customization with Variable-Viewpoints in Text-to-Image Diffusion Models</h1>"""
25
+
26
+ head = r"""
27
+ <div style="text-align: center;">
28
+ <h1>Image Conductor: Precision Control for Interactive Video Synthesis</h1>
29
+ <div style="display: flex; justify-content: center; align-items: center; text-align: center;">
30
+ <a href=""></a>
31
+ <a href='https://liyaowei-stu.github.io/project/ImageConductor/'><img src='https://img.shields.io/badge/Project_Page-ImgaeConductor-green' alt='Project Page'></a>
32
+ <a href='https://arxiv.org/pdf/2406.15339'><img src='https://img.shields.io/badge/Paper-Arxiv-blue'></a>
33
+ <a href='https://github.com/liyaowei-stu/ImageConductor'><img src='https://img.shields.io/badge/Code-Github-orange'></a>
34
+
35
+
36
+ </div>
37
+ </br>
38
+ </div>
39
+ """
40
+
41
+
42
+
43
+ descriptions = r"""
44
+ Official Gradio Demo for <a href='https://github.com/liyaowei-stu/ImageConductor'><b>Image Conductor: Precision Control for Interactive Video Synthesis</b></a>.<br>
45
+ 🧙Image Conductor enables precise, fine-grained control for generating motion-controllable videos from images, advancing the practical application of interactive video synthesis.<br>
46
+ """
47
+
48
+
49
+ instructions = r"""
50
+ - ⭐️ <b>step1: </b>Upload or select one image from Example.
51
+ - ⭐️ <b>step2: </b>Click 'Add Drag' to draw some drags.
52
+ - ⭐️ <b>step3: </b>Input text prompt that complements the image (Necessary).
53
+ - ⭐️ <b>step4: </b>Select 'Drag Mode' to specify the control of camera transition or object movement.
54
+ - ⭐️ <b>step5: </b>Click 'Run' button to generate video assets.
55
+ - ⭐️ <b>others: </b>Click 'Delete last drag' to delete the whole lastest path. Click 'Delete last step' to delete the lastest clicked control point.
56
+ """
57
+
58
+ citation = r"""
59
+ If Image Conductor is helpful, please help to ⭐ the <a href='https://github.com/liyaowei-stu/ImageConductor' target='_blank'>Github Repo</a>. Thanks!
60
+ [![GitHub Stars](https://img.shields.io/github/stars/liyaowei-stu%2FImageConductor)](https://github.com/liyaowei-stu/ImageConductor)
61
+ ---
62
+
63
+ 📝 **Citation**
64
+ <br>
65
+ If our work is useful for your research, please consider citing:
66
+ ```bibtex
67
+ @misc{li2024imageconductor,
68
+ title={Image Conductor: Precision Control for Interactive Video Synthesis},
69
+ author={Li, Yaowei and Wang, Xintao and Zhang, Zhaoyang and Wang, Zhouxia and Yuan, Ziyang and Xie, Liangbin and Zou, Yuexian and Shan, Ying},
70
+ year={2024},
71
+ eprint={2406.15339},
72
+ archivePrefix={arXiv},
73
+ primaryClass={cs.CV}
74
+ }
75
+ ```
76
+
77
+ 📧 **Contact**
78
+ <br>
79
+ If you have any questions, please feel free to reach me out at <b>ywl@stu.pku.edu.cn</b>.
80
+
81
+ # """
82
+
83
+ os.makedirs("models/personalized")
84
+ os.system(f'wget https://huggingface.co/TencentARC/ImageConductor/blob/main/flow_controlnet.ckpt -P models/')
85
+ os.system(f'wget https://huggingface.co/TencentARC/ImageConductor/blob/main/image_controlnet.ckpt -P models/')
86
+ os.system(f'wget https://huggingface.co/TencentARC/ImageConductor/blob/main/unet.ckpt -P models/')
87
+ os.system(f'wget https://huggingface.co/TencentARC/ImageConductor/blob/main/helloobjects_V12c.safetensors -P models/personalized')
88
+ os.system(f'wget https://huggingface.co/TencentARC/ImageConductor/blob/main/TUSUN.safetensors -P models/personalized')
89
+
90
+
91
+
92
+
93
+ # - - - - - examples - - - - - #
94
+ image_examples = [
95
+ ["__asset__/images/object/turtle-1.jpg",
96
+ "a sea turtle gracefully swimming over a coral reef in the clear blue ocean.",
97
+ "object",
98
+ 11318446767408804497,
99
+ "",
100
+ json.load(open("__asset__/trajs/object/turtle-1.json")),
101
+ "__asset__/images/object/turtle-1.jpg",
102
+ ],
103
+
104
+ ["__asset__/images/object/rose-1.jpg",
105
+ "a red rose engulfed in flames.",
106
+ "object",
107
+ 6854275249656120509,
108
+ "",
109
+ json.load(open("__asset__/trajs/object/rose-1.json")),
110
+ "__asset__/images/object/rose-1.jpg",
111
+ ],
112
+
113
+ ["__asset__/images/object/jellyfish-1.jpg",
114
+ "intricate detailing,photorealism,hyperrealistic, glowing jellyfish mushroom, flying, starry sky, bokeh, golden ratio composition.",
115
+ "object",
116
+ 17966188172968903484,
117
+ "HelloObject",
118
+ json.load(open("__asset__/trajs/object/jellyfish-1.json")),
119
+ "__asset__/images/object/jellyfish-1.jpg",
120
+ ],
121
+
122
+
123
+ ["__asset__/images/camera/lush-1.jpg",
124
+ "detailed craftsmanship, photorealism, hyperrealistic, roaring waterfall, misty spray, lush greenery, vibrant rainbow, golden ratio composition.",
125
+ "camera",
126
+ 7970487946960948963,
127
+ "HelloObject",
128
+ json.load(open("__asset__/trajs/camera/lush-1.json")),
129
+ "__asset__/images/camera/lush-1.jpg",
130
+ ],
131
+
132
+ ["__asset__/images/camera/tusun-1.jpg",
133
+ "tusuncub with its mouth open, blurry, open mouth, fangs, photo background, looking at viewer, tongue, full body, solo, cute and lovely, Beautiful and realistic eye details, perfect anatomy, Nonsense, pure background, Centered-Shot, realistic photo, photograph, 4k, hyper detailed, DSLR, 24 Megapixels, 8mm Lens, Full Frame, film grain, Global Illumination, studio Lighting, Award Winning Photography, diffuse reflection, ray tracing.",
134
+ "camera",
135
+ 996953226890228361,
136
+ "TUSUN",
137
+ json.load(open("__asset__/trajs/camera/tusun-1.json")),
138
+ "__asset__/images/camera/tusun-1.jpg",
139
+ ],
140
+
141
+ ["__asset__/images/camera/painting-1.jpg",
142
+ "A oil painting.",
143
+ "camera",
144
+ 16867854766769816385,
145
+ "",
146
+ json.load(open("__asset__/trajs/camera/painting-1.json")),
147
+ "__asset__/images/camera/painting-1.jpg",
148
+ ],
149
+
150
+ ]
151
+
152
+
153
+ DREAM_BOOTH = {
154
+ 'HelloObject': 'models/personalized/helloobjects_V12c.safetensors',
155
+ }
156
+
157
+ LORA = {
158
+ 'TUSUN': 'models/personalized/TUSUN.safetensors',
159
+ }
160
+
161
+ LORA_ALPHA = {
162
+ 'TUSUN': 0.6,
163
+ }
164
+
165
+ NPROMPT = {
166
+ "HelloObject": 'FastNegativeV2,(bad-artist:1),(worst quality, low quality:1.4),(bad_prompt_version2:0.8),bad-hands-5,lowres,bad anatomy,bad hands,((text)),(watermark),error,missing fingers,extra digit,fewer digits,cropped,worst quality,low quality,normal quality,((username)),blurry,(extra limbs),bad-artist-anime,badhandv4,EasyNegative,ng_deepnegative_v1_75t,verybadimagenegative_v1.3,BadDream,(three hands:1.6),(three legs:1.2),(more than two hands:1.4),(more than two legs,:1.2)'
167
+ }
168
+
169
+ output_dir = "outputs"
170
+ ensure_dirname(output_dir)
171
+
172
+ def points_to_flows(track_points, model_length, height, width):
173
+ input_drag = np.zeros((model_length - 1, height, width, 2))
174
+ for splited_track in track_points:
175
+ if len(splited_track) == 1: # stationary point
176
+ displacement_point = tuple([splited_track[0][0] + 1, splited_track[0][1] + 1])
177
+ splited_track = tuple([splited_track[0], displacement_point])
178
+ # interpolate the track
179
+ splited_track = interpolate_trajectory(splited_track, model_length)
180
+ splited_track = splited_track[:model_length]
181
+ if len(splited_track) < model_length:
182
+ splited_track = splited_track + [splited_track[-1]] * (model_length -len(splited_track))
183
+ for i in range(model_length - 1):
184
+ start_point = splited_track[i]
185
+ end_point = splited_track[i+1]
186
+ input_drag[i][int(start_point[1])][int(start_point[0])][0] = end_point[0] - start_point[0]
187
+ input_drag[i][int(start_point[1])][int(start_point[0])][1] = end_point[1] - start_point[1]
188
+ return input_drag
189
+
190
+ class ImageConductor:
191
+ def __init__(self, device, unet_path, image_controlnet_path, flow_controlnet_path, height, width, model_length, lora_rank=64):
192
+ self.device = device
193
+ tokenizer = CLIPTokenizer.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="tokenizer")
194
+ text_encoder = CLIPTextModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="text_encoder").cuda()
195
+ vae = AutoencoderKL.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="vae").cuda()
196
+ inference_config = OmegaConf.load("configs/inference/inference.yaml")
197
+ unet = UNet3DConditionFlowModel.from_pretrained_2d("runwayml/stable-diffusion-v1-5", subfolder="unet", unet_additional_kwargs=OmegaConf.to_container(inference_config.unet_additional_kwargs))
198
+
199
+ self.vae = vae
200
+
201
+ ### >>> Initialize UNet module >>> ###
202
+ load_model(unet, unet_path)
203
+
204
+ ### >>> Initialize image controlnet module >>> ###
205
+ image_controlnet = create_image_controlnet("configs/inference/image_condition.yaml", unet)
206
+ load_model(image_controlnet, image_controlnet_path)
207
+ ### >>> Initialize flow controlnet module >>> ###
208
+ flow_controlnet = create_flow_controlnet("configs/inference/flow_condition.yaml", unet)
209
+ add_LoRA_to_controlnet(lora_rank, flow_controlnet)
210
+ load_model(flow_controlnet, flow_controlnet_path)
211
+
212
+ unet.eval().to(device)
213
+ image_controlnet.eval().to(device)
214
+ flow_controlnet.eval().to(device)
215
+
216
+ self.pipeline = ImageConductorPipeline(
217
+ unet=unet,
218
+ vae=vae,
219
+ tokenizer=tokenizer,
220
+ text_encoder=text_encoder,
221
+ scheduler=DDIMScheduler(**OmegaConf.to_container(inference_config.noise_scheduler_kwargs)),
222
+ image_controlnet=image_controlnet,
223
+ flow_controlnet=flow_controlnet,
224
+ ).to(device)
225
+
226
+
227
+ self.height = height
228
+ self.width = width
229
+ # _, model_step, _ = split_filename(model_path)
230
+ # self.ouput_prefix = f'{model_step}_{width}X{height}'
231
+ self.model_length = model_length
232
+
233
+ blur_kernel = bivariate_Gaussian(kernel_size=99, sig_x=10, sig_y=10, theta=0, grid=None, isotropic=True)
234
+
235
+ self.blur_kernel = blur_kernel
236
+
237
+ @torch.no_grad()
238
+ def run(self, first_frame_path, tracking_points, prompt, drag_mode, negative_prompt, seed, randomize_seed, guidance_scale, num_inference_steps, personalized):
239
+
240
+
241
+ original_width, original_height=384, 256
242
+ if isinstance(tracking_points, list):
243
+ input_all_points = tracking_points
244
+ else:
245
+ input_all_points = tracking_points.constructor_args['value']
246
+
247
+
248
+ resized_all_points = [tuple([tuple([float(e1[0]*self.width/original_width), float(e1[1]*self.height/original_height)]) for e1 in e]) for e in input_all_points]
249
+
250
+ dir, base, ext = split_filename(first_frame_path)
251
+ id = base.split('_')[-1]
252
+
253
+
254
+ with open(f'{output_dir}/points-{id}.json', 'w') as f:
255
+ json.dump(input_all_points, f)
256
+
257
+
258
+ visualized_drag, _ = visualize_drag(first_frame_path, resized_all_points, drag_mode, self.width, self.height, self.model_length)
259
+
260
+ ## image condition
261
+ image_transforms = transforms.Compose([
262
+ transforms.RandomResizedCrop(
263
+ (self.height, self.width), (1.0, 1.0),
264
+ ratio=(self.width/self.height, self.width/self.height)
265
+ ),
266
+ transforms.ToTensor(),
267
+ ])
268
+
269
+ image_norm = lambda x: x
270
+ image_paths = [first_frame_path]
271
+ controlnet_images = [image_norm(image_transforms(Image.open(path).convert("RGB"))) for path in image_paths]
272
+ controlnet_images = torch.stack(controlnet_images).unsqueeze(0).cuda()
273
+ controlnet_images = rearrange(controlnet_images, "b f c h w -> b c f h w")
274
+ num_controlnet_images = controlnet_images.shape[2]
275
+ controlnet_images = rearrange(controlnet_images, "b c f h w -> (b f) c h w")
276
+ controlnet_images = self.vae.encode(controlnet_images * 2. - 1.).latent_dist.sample() * 0.18215
277
+ controlnet_images = rearrange(controlnet_images, "(b f) c h w -> b c f h w", f=num_controlnet_images)
278
+
279
+ # flow condition
280
+ controlnet_flows = points_to_flows(resized_all_points, self.model_length, self.height, self.width)
281
+ for i in range(0, self.model_length-1):
282
+ controlnet_flows[i] = cv2.filter2D(controlnet_flows[i], -1, self.blur_kernel)
283
+ controlnet_flows = np.concatenate([np.zeros_like(controlnet_flows[0])[np.newaxis, ...], controlnet_flows], axis=0) # pad the first frame with zero flow
284
+ os.makedirs(os.path.join(output_dir, "control_flows"), exist_ok=True)
285
+ trajs_video = vis_flow_to_video(controlnet_flows, num_frames=self.model_length) # T-1 x H x W x 3
286
+ torchvision.io.write_video(f'{output_dir}/control_flows/sample-{id}-train_flow.mp4', trajs_video, fps=8, video_codec='h264', options={'crf': '10'})
287
+ controlnet_flows = torch.from_numpy(controlnet_flows)[None].to(controlnet_images)[:, :self.model_length, ...]
288
+ controlnet_flows = rearrange(controlnet_flows, "b f h w c-> b c f h w")
289
+
290
+ dreambooth_model_path = DREAM_BOOTH.get(personalized, '')
291
+ lora_model_path = LORA.get(personalized, '')
292
+ lora_alpha = LORA_ALPHA.get(personalized, 0.6)
293
+ self.pipeline = load_weights(
294
+ self.pipeline,
295
+ dreambooth_model_path = dreambooth_model_path,
296
+ lora_model_path = lora_model_path,
297
+ lora_alpha = lora_alpha,
298
+ ).to(device)
299
+
300
+ if NPROMPT.get(personalized, '') != '':
301
+ negative_prompt = NPROMPT.get(personalized)
302
+
303
+ if randomize_seed:
304
+ random_seed = torch.seed()
305
+ else:
306
+ seed = int(seed)
307
+ random_seed = seed
308
+ torch.manual_seed(random_seed)
309
+ torch.cuda.manual_seed_all(random_seed)
310
+ print(f"current seed: {torch.initial_seed()}")
311
+ sample = self.pipeline(
312
+ prompt,
313
+ negative_prompt = negative_prompt,
314
+ num_inference_steps = num_inference_steps,
315
+ guidance_scale = guidance_scale,
316
+ width = self.width,
317
+ height = self.height,
318
+ video_length = self.model_length,
319
+ controlnet_images = controlnet_images, # 1 4 1 32 48
320
+ controlnet_image_index = [0],
321
+ controlnet_flows = controlnet_flows,# [1, 2, 16, 256, 384]
322
+ control_mode = drag_mode,
323
+ eval_mode = True,
324
+ ).videos
325
+
326
+ outputs_path = os.path.join(output_dir, f'output_{i}_{id}.mp4')
327
+ vis_video = (rearrange(sample[0], 'c t h w -> t h w c') * 255.).clip(0, 255)
328
+ torchvision.io.write_video(outputs_path, vis_video, fps=8, video_codec='h264', options={'crf': '10'})
329
+
330
+ return visualized_drag, outputs_path
331
+
332
+
333
+ def reset_states(first_frame_path, tracking_points):
334
+ first_frame_path = gr.State()
335
+ tracking_points = gr.State([])
336
+ return None, first_frame_path, tracking_points
337
+
338
+
339
+ def preprocess_image(image):
340
+ image_pil = image2pil(image.name)
341
+ raw_w, raw_h = image_pil.size
342
+ resize_ratio = max(384/raw_w, 256/raw_h)
343
+ image_pil = image_pil.resize((int(raw_w * resize_ratio), int(raw_h * resize_ratio)), Image.BILINEAR)
344
+ image_pil = transforms.CenterCrop((256, 384))(image_pil.convert('RGB'))
345
+ id = str(uuid.uuid4())[:4]
346
+ first_frame_path = os.path.join(output_dir, f"first_frame_{id}.jpg")
347
+ image_pil.save(first_frame_path, quality=95)
348
+ return first_frame_path, first_frame_path, gr.State([])
349
+
350
+
351
+ def add_tracking_points(tracking_points, first_frame_path, drag_mode, evt: gr.SelectData): # SelectData is a subclass of EventData
352
+ if drag_mode=='object':
353
+ color = (255, 0, 0, 255)
354
+ elif drag_mode=='camera':
355
+ color = (0, 0, 255, 255)
356
+
357
+
358
+ print(f"You selected {evt.value} at {evt.index} from {evt.target}")
359
+ tracking_points.constructor_args['value'][-1].append(evt.index)
360
+ print(tracking_points.constructor_args)
361
+
362
+ transparent_background = Image.open(first_frame_path).convert('RGBA')
363
+ w, h = transparent_background.size
364
+ transparent_layer = np.zeros((h, w, 4))
365
+ for track in tracking_points.constructor_args['value']:
366
+ if len(track) > 1:
367
+ for i in range(len(track)-1):
368
+ start_point = track[i]
369
+ end_point = track[i+1]
370
+ vx = end_point[0] - start_point[0]
371
+ vy = end_point[1] - start_point[1]
372
+ arrow_length = np.sqrt(vx**2 + vy**2)
373
+ if i == len(track)-2:
374
+ cv2.arrowedLine(transparent_layer, tuple(start_point), tuple(end_point), color, 2, tipLength=8 / arrow_length)
375
+ else:
376
+ cv2.line(transparent_layer, tuple(start_point), tuple(end_point), color, 2,)
377
+ else:
378
+ cv2.circle(transparent_layer, tuple(track[0]), 5, color, -1)
379
+
380
+ transparent_layer = Image.fromarray(transparent_layer.astype(np.uint8))
381
+ trajectory_map = Image.alpha_composite(transparent_background, transparent_layer)
382
+ return tracking_points, trajectory_map
383
+
384
+
385
+ def add_drag(tracking_points):
386
+ tracking_points.constructor_args['value'].append([])
387
+ print(tracking_points.constructor_args)
388
+ return tracking_points
389
+
390
+
391
+ def delete_last_drag(tracking_points, first_frame_path, drag_mode):
392
+ if drag_mode=='object':
393
+ color = (255, 0, 0, 255)
394
+ elif drag_mode=='camera':
395
+ color = (0, 0, 255, 255)
396
+ tracking_points.constructor_args['value'].pop()
397
+ transparent_background = Image.open(first_frame_path).convert('RGBA')
398
+ w, h = transparent_background.size
399
+ transparent_layer = np.zeros((h, w, 4))
400
+ for track in tracking_points.constructor_args['value']:
401
+ if len(track) > 1:
402
+ for i in range(len(track)-1):
403
+ start_point = track[i]
404
+ end_point = track[i+1]
405
+ vx = end_point[0] - start_point[0]
406
+ vy = end_point[1] - start_point[1]
407
+ arrow_length = np.sqrt(vx**2 + vy**2)
408
+ if i == len(track)-2:
409
+ cv2.arrowedLine(transparent_layer, tuple(start_point), tuple(end_point), color, 2, tipLength=8 / arrow_length)
410
+ else:
411
+ cv2.line(transparent_layer, tuple(start_point), tuple(end_point), color, 2,)
412
+ else:
413
+ cv2.circle(transparent_layer, tuple(track[0]), 5, color, -1)
414
+
415
+ transparent_layer = Image.fromarray(transparent_layer.astype(np.uint8))
416
+ trajectory_map = Image.alpha_composite(transparent_background, transparent_layer)
417
+ return tracking_points, trajectory_map
418
+
419
+
420
+ def delete_last_step(tracking_points, first_frame_path, drag_mode):
421
+ if drag_mode=='object':
422
+ color = (255, 0, 0, 255)
423
+ elif drag_mode=='camera':
424
+ color = (0, 0, 255, 255)
425
+ tracking_points.constructor_args['value'][-1].pop()
426
+ transparent_background = Image.open(first_frame_path).convert('RGBA')
427
+ w, h = transparent_background.size
428
+ transparent_layer = np.zeros((h, w, 4))
429
+ for track in tracking_points.constructor_args['value']:
430
+ if len(track) > 1:
431
+ for i in range(len(track)-1):
432
+ start_point = track[i]
433
+ end_point = track[i+1]
434
+ vx = end_point[0] - start_point[0]
435
+ vy = end_point[1] - start_point[1]
436
+ arrow_length = np.sqrt(vx**2 + vy**2)
437
+ if i == len(track)-2:
438
+ cv2.arrowedLine(transparent_layer, tuple(start_point), tuple(end_point), color, 2, tipLength=8 / arrow_length)
439
+ else:
440
+ cv2.line(transparent_layer, tuple(start_point), tuple(end_point), color, 2,)
441
+ else:
442
+ cv2.circle(transparent_layer, tuple(track[0]), 5,color, -1)
443
+
444
+ transparent_layer = Image.fromarray(transparent_layer.astype(np.uint8))
445
+ trajectory_map = Image.alpha_composite(transparent_background, transparent_layer)
446
+ return tracking_points, trajectory_map
447
+
448
+
449
+ block = gr.Blocks(
450
+ theme=gr.themes.Soft(
451
+ radius_size=gr.themes.sizes.radius_none,
452
+ text_size=gr.themes.sizes.text_md
453
+ )
454
+ ).queue()
455
+ with block as demo:
456
+ with gr.Row():
457
+ with gr.Column():
458
+ gr.HTML(head)
459
+
460
+ gr.Markdown(descriptions)
461
+
462
+ with gr.Accordion(label="🛠️ Instructions:", open=True, elem_id="accordion"):
463
+ with gr.Row(equal_height=True):
464
+ gr.Markdown(instructions)
465
+
466
+
467
+ device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
468
+ unet_path = 'models/unet.ckpt'
469
+ image_controlnet_path = 'models/image_controlnet.ckpt'
470
+ flow_controlnet_path = 'models/flow_controlnet.ckpt'
471
+ ImageConductor_net = ImageConductor(device=device,
472
+ unet_path=unet_path,
473
+ image_controlnet_path=image_controlnet_path,
474
+ flow_controlnet_path=flow_controlnet_path,
475
+ height=256,
476
+ width=384,
477
+ model_length=16
478
+ )
479
+ first_frame_path = gr.State()
480
+ tracking_points = gr.State([])
481
+
482
+
483
+ with gr.Row():
484
+ with gr.Column(scale=1):
485
+ image_upload_button = gr.UploadButton(label="Upload Image",file_types=["image"])
486
+ add_drag_button = gr.Button(value="Add Drag")
487
+ reset_button = gr.Button(value="Reset")
488
+ delete_last_drag_button = gr.Button(value="Delete last drag")
489
+ delete_last_step_button = gr.Button(value="Delete last step")
490
+
491
+
492
+
493
+ with gr.Column(scale=7):
494
+ with gr.Row():
495
+ with gr.Column(scale=6):
496
+ input_image = gr.Image(label=None,
497
+ interactive=True,
498
+ height=256,
499
+ width=384,)
500
+ with gr.Column(scale=6):
501
+ output_image = gr.Image(label="Motion Path",
502
+ interactive=False,
503
+ height=256,
504
+ width=384,)
505
+ with gr.Row():
506
+ with gr.Column(scale=1):
507
+ prompt = gr.Textbox(value="a wonderful elf.", label="Prompt (highly-recommended)", interactive=True, visible=True)
508
+ negative_prompt = gr.Text(
509
+ label="Negative Prompt",
510
+ max_lines=5,
511
+ placeholder="Please input your negative prompt",
512
+ value='worst quality, low quality, letterboxed',lines=1
513
+ )
514
+ drag_mode = gr.Radio(['camera', 'object'], label='Drag mode: ', value='object', scale=2)
515
+ run_button = gr.Button(value="Run")
516
+
517
+ with gr.Accordion("More input params", open=False, elem_id="accordion1"):
518
+ with gr.Group():
519
+ seed = gr.Textbox(
520
+ label="Seed: ", value=561793204,
521
+ )
522
+ randomize_seed = gr.Checkbox(label="Randomize seed", value=False)
523
+
524
+ with gr.Group():
525
+ with gr.Row():
526
+ guidance_scale = gr.Slider(
527
+ label="Guidance scale",
528
+ minimum=1,
529
+ maximum=12,
530
+ step=0.1,
531
+ value=8.5,
532
+ )
533
+ num_inference_steps = gr.Slider(
534
+ label="Number of inference steps",
535
+ minimum=1,
536
+ maximum=50,
537
+ step=1,
538
+ value=25,
539
+ )
540
+
541
+ with gr.Group():
542
+ personalized = gr.Dropdown(label="Personalized template", choices=['HelloObject', 'TUSUN'], value="")
543
+
544
+ with gr.Column(scale=7):
545
+ output_video = gr.Video(value=None,
546
+ label="Output Video",
547
+ width=384,
548
+ height=256)
549
+
550
+
551
+ with gr.Row():
552
+ def process_example(input_image, prompt, drag_mode, seed, personalized, tracking_points, first_frame_path):
553
+
554
+ return input_image, prompt, drag_mode, seed, personalized, tracking_points, first_frame_path
555
+
556
+ example = gr.Examples(
557
+ label="Input Example",
558
+ examples=image_examples,
559
+ inputs=[input_image, prompt, drag_mode, seed, personalized, tracking_points, first_frame_path],
560
+ outputs=[input_image, prompt, drag_mode, seed, personalized, tracking_points, first_frame_path],
561
+ fn=process_example,
562
+ run_on_click=True,
563
+ examples_per_page=10
564
+ )
565
+
566
+ with gr.Row():
567
+ gr.Markdown(citation)
568
+
569
+
570
+ image_upload_button.upload(preprocess_image, image_upload_button, [input_image, first_frame_path, tracking_points])
571
+
572
+ add_drag_button.click(add_drag, tracking_points, tracking_points)
573
+
574
+ delete_last_drag_button.click(delete_last_drag, [tracking_points, first_frame_path, drag_mode], [tracking_points, input_image])
575
+
576
+ delete_last_step_button.click(delete_last_step, [tracking_points, first_frame_path, drag_mode], [tracking_points, input_image])
577
+
578
+ reset_button.click(reset_states, [first_frame_path, tracking_points], [input_image, first_frame_path, tracking_points])
579
+
580
+ input_image.select(add_tracking_points, [tracking_points, first_frame_path, drag_mode], [tracking_points, input_image])
581
+
582
+ run_button.click(ImageConductor_net.run, [first_frame_path, tracking_points, prompt, drag_mode,
583
+ negative_prompt, seed, randomize_seed, guidance_scale, num_inference_steps, personalized],
584
+ [output_image, output_video])
585
+
586
+ demo.launch(server_name="0.0.0.0", debug=True, server_port=12345)
app.py ADDED
@@ -0,0 +1,577 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import gradio as gr
3
+ import numpy as np
4
+ import cv2
5
+ import uuid
6
+ import torch
7
+ import torchvision
8
+ import json
9
+
10
+ from PIL import Image
11
+ from omegaconf import OmegaConf
12
+ from einops import rearrange, repeat
13
+ from torchvision import transforms
14
+ from transformers import CLIPTextModel, CLIPTokenizer
15
+ from diffusers import AutoencoderKL, DDIMScheduler
16
+
17
+ from pipelines.pipeline_imagecoductor import ImageConductorPipeline
18
+ from modules.unet import UNet3DConditionFlowModel
19
+ from utils.gradio_utils import ensure_dirname, split_filename, visualize_drag, image2pil, image2arr
20
+ from utils.utils import create_image_controlnet, create_flow_controlnet, interpolate_trajectory, load_weights, load_model, bivariate_Gaussian
21
+ from utils.lora_utils import add_LoRA_to_controlnet
22
+ from utils.visualizer import Visualizer, vis_flow_to_video
23
+ #### Description ####
24
+ title = r"""<h1 align="center">CustomNet: Object Customization with Variable-Viewpoints in Text-to-Image Diffusion Models</h1>"""
25
+
26
+ head = r"""
27
+ <div style="text-align: center;">
28
+ <h1>Image Conductor: Precision Control for Interactive Video Synthesis</h1>
29
+ <div style="display: flex; justify-content: center; align-items: center; text-align: center;">
30
+ <a href=""></a>
31
+ <a href='https://liyaowei-stu.github.io/project/ImageConductor/'><img src='https://img.shields.io/badge/Project_Page-ImgaeConductor-green' alt='Project Page'></a>
32
+ <a href='https://arxiv.org/pdf/2406.15339'><img src='https://img.shields.io/badge/Paper-Arxiv-blue'></a>
33
+ <a href='https://github.com/liyaowei-stu/ImageConductor'><img src='https://img.shields.io/badge/Code-Github-orange'></a>
34
+
35
+
36
+ </div>
37
+ </br>
38
+ </div>
39
+ """
40
+
41
+
42
+
43
+ descriptions = r"""
44
+ Official Gradio Demo for <a href='https://github.com/liyaowei-stu/ImageConductor'><b>Image Conductor: Precision Control for Interactive Video Synthesis</b></a>.<br>
45
+ 🧙Image Conductor enables precise, fine-grained control for generating motion-controllable videos from images, advancing the practical application of interactive video synthesis.<br>
46
+ """
47
+
48
+
49
+ instructions = r"""
50
+ - ⭐️ <b>step1: </b>Upload or select one image from Example.
51
+ - ⭐️ <b>step2: </b>Click 'Add Drag' to draw some drags.
52
+ - ⭐️ <b>step3: </b>Input text prompt that complements the image (Necessary).
53
+ - ⭐️ <b>step4: </b>Select 'Drag Mode' to specify the control of camera transition or object movement.
54
+ - ⭐️ <b>step5: </b>Click 'Run' button to generate video assets.
55
+ - ⭐️ <b>others: </b>Click 'Delete last drag' to delete the whole lastest path. Click 'Delete last step' to delete the lastest clicked control point.
56
+ """
57
+
58
+ citation = r"""
59
+ If Image Conductor is helpful, please help to ⭐ the <a href='https://github.com/liyaowei-stu/ImageConductor' target='_blank'>Github Repo</a>. Thanks!
60
+ [![GitHub Stars](https://img.shields.io/github/stars/liyaowei-stu%2FImageConductor)](https://github.com/liyaowei-stu/ImageConductor)
61
+ ---
62
+
63
+ 📝 **Citation**
64
+ <br>
65
+ If our work is useful for your research, please consider citing:
66
+ ```bibtex
67
+ @misc{li2024imageconductor,
68
+ title={Image Conductor: Precision Control for Interactive Video Synthesis},
69
+ author={Li, Yaowei and Wang, Xintao and Zhang, Zhaoyang and Wang, Zhouxia and Yuan, Ziyang and Xie, Liangbin and Zou, Yuexian and Shan, Ying},
70
+ year={2024},
71
+ eprint={2406.15339},
72
+ archivePrefix={arXiv},
73
+ primaryClass={cs.CV}
74
+ }
75
+ ```
76
+
77
+ 📧 **Contact**
78
+ <br>
79
+ If you have any questions, please feel free to reach me out at <b>ywl@stu.pku.edu.cn</b>.
80
+
81
+ # """
82
+
83
+
84
+ # - - - - - examples - - - - - #
85
+ image_examples = [
86
+ ["__asset__/images/object/turtle-1.jpg",
87
+ "a sea turtle gracefully swimming over a coral reef in the clear blue ocean.",
88
+ "object",
89
+ 11318446767408804497,
90
+ "",
91
+ json.load(open("__asset__/trajs/object/turtle-1.json")),
92
+ "__asset__/images/object/turtle-1.jpg",
93
+ ],
94
+
95
+ ["__asset__/images/object/rose-1.jpg",
96
+ "a red rose engulfed in flames.",
97
+ "object",
98
+ 6854275249656120509,
99
+ "",
100
+ json.load(open("__asset__/trajs/object/rose-1.json")),
101
+ "__asset__/images/object/rose-1.jpg",
102
+ ],
103
+
104
+ ["__asset__/images/object/jellyfish-1.jpg",
105
+ "intricate detailing,photorealism,hyperrealistic, glowing jellyfish mushroom, flying, starry sky, bokeh, golden ratio composition.",
106
+ "object",
107
+ 17966188172968903484,
108
+ "HelloObject",
109
+ json.load(open("__asset__/trajs/object/jellyfish-1.json")),
110
+ "__asset__/images/object/jellyfish-1.jpg",
111
+ ],
112
+
113
+
114
+ ["__asset__/images/camera/lush-1.jpg",
115
+ "detailed craftsmanship, photorealism, hyperrealistic, roaring waterfall, misty spray, lush greenery, vibrant rainbow, golden ratio composition.",
116
+ "camera",
117
+ 7970487946960948963,
118
+ "HelloObject",
119
+ json.load(open("__asset__/trajs/camera/lush-1.json")),
120
+ "__asset__/images/camera/lush-1.jpg",
121
+ ],
122
+
123
+ ["__asset__/images/camera/tusun-1.jpg",
124
+ "tusuncub with its mouth open, blurry, open mouth, fangs, photo background, looking at viewer, tongue, full body, solo, cute and lovely, Beautiful and realistic eye details, perfect anatomy, Nonsense, pure background, Centered-Shot, realistic photo, photograph, 4k, hyper detailed, DSLR, 24 Megapixels, 8mm Lens, Full Frame, film grain, Global Illumination, studio Lighting, Award Winning Photography, diffuse reflection, ray tracing.",
125
+ "camera",
126
+ 996953226890228361,
127
+ "TUSUN",
128
+ json.load(open("__asset__/trajs/camera/tusun-1.json")),
129
+ "__asset__/images/camera/tusun-1.jpg",
130
+ ],
131
+
132
+ ["__asset__/images/camera/painting-1.jpg",
133
+ "A oil painting.",
134
+ "camera",
135
+ 16867854766769816385,
136
+ "",
137
+ json.load(open("__asset__/trajs/camera/painting-1.json")),
138
+ "__asset__/images/camera/painting-1.jpg",
139
+ ],
140
+
141
+ ]
142
+
143
+
144
+ DREAM_BOOTH = {
145
+ 'HelloObject': 'models/personalized/helloobjects_V12c.safetensors',
146
+ }
147
+
148
+ LORA = {
149
+ 'TUSUN': 'models/personalized/TUSUN.safetensors',
150
+ }
151
+
152
+ LORA_ALPHA = {
153
+ 'TUSUN': 0.6,
154
+ }
155
+
156
+ NPROMPT = {
157
+ "HelloObject": 'FastNegativeV2,(bad-artist:1),(worst quality, low quality:1.4),(bad_prompt_version2:0.8),bad-hands-5,lowres,bad anatomy,bad hands,((text)),(watermark),error,missing fingers,extra digit,fewer digits,cropped,worst quality,low quality,normal quality,((username)),blurry,(extra limbs),bad-artist-anime,badhandv4,EasyNegative,ng_deepnegative_v1_75t,verybadimagenegative_v1.3,BadDream,(three hands:1.6),(three legs:1.2),(more than two hands:1.4),(more than two legs,:1.2)'
158
+ }
159
+
160
+ output_dir = "outputs"
161
+ ensure_dirname(output_dir)
162
+
163
+ def points_to_flows(track_points, model_length, height, width):
164
+ input_drag = np.zeros((model_length - 1, height, width, 2))
165
+ for splited_track in track_points:
166
+ if len(splited_track) == 1: # stationary point
167
+ displacement_point = tuple([splited_track[0][0] + 1, splited_track[0][1] + 1])
168
+ splited_track = tuple([splited_track[0], displacement_point])
169
+ # interpolate the track
170
+ splited_track = interpolate_trajectory(splited_track, model_length)
171
+ splited_track = splited_track[:model_length]
172
+ if len(splited_track) < model_length:
173
+ splited_track = splited_track + [splited_track[-1]] * (model_length -len(splited_track))
174
+ for i in range(model_length - 1):
175
+ start_point = splited_track[i]
176
+ end_point = splited_track[i+1]
177
+ input_drag[i][int(start_point[1])][int(start_point[0])][0] = end_point[0] - start_point[0]
178
+ input_drag[i][int(start_point[1])][int(start_point[0])][1] = end_point[1] - start_point[1]
179
+ return input_drag
180
+
181
+ class ImageConductor:
182
+ def __init__(self, device, unet_path, image_controlnet_path, flow_controlnet_path, height, width, model_length, lora_rank=64):
183
+ self.device = device
184
+ tokenizer = CLIPTokenizer.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="tokenizer")
185
+ text_encoder = CLIPTextModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="text_encoder").cuda()
186
+ vae = AutoencoderKL.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="vae").cuda()
187
+ inference_config = OmegaConf.load("configs/inference/inference.yaml")
188
+ unet = UNet3DConditionFlowModel.from_pretrained_2d("runwayml/stable-diffusion-v1-5", subfolder="unet", unet_additional_kwargs=OmegaConf.to_container(inference_config.unet_additional_kwargs))
189
+
190
+ self.vae = vae
191
+
192
+ ### >>> Initialize UNet module >>> ###
193
+ load_model(unet, unet_path)
194
+
195
+ ### >>> Initialize image controlnet module >>> ###
196
+ image_controlnet = create_image_controlnet("configs/inference/image_condition.yaml", unet)
197
+ load_model(image_controlnet, image_controlnet_path)
198
+ ### >>> Initialize flow controlnet module >>> ###
199
+ flow_controlnet = create_flow_controlnet("configs/inference/flow_condition.yaml", unet)
200
+ add_LoRA_to_controlnet(lora_rank, flow_controlnet)
201
+ load_model(flow_controlnet, flow_controlnet_path)
202
+
203
+ unet.eval().to(device)
204
+ image_controlnet.eval().to(device)
205
+ flow_controlnet.eval().to(device)
206
+
207
+ self.pipeline = ImageConductorPipeline(
208
+ unet=unet,
209
+ vae=vae,
210
+ tokenizer=tokenizer,
211
+ text_encoder=text_encoder,
212
+ scheduler=DDIMScheduler(**OmegaConf.to_container(inference_config.noise_scheduler_kwargs)),
213
+ image_controlnet=image_controlnet,
214
+ flow_controlnet=flow_controlnet,
215
+ ).to(device)
216
+
217
+
218
+ self.height = height
219
+ self.width = width
220
+ # _, model_step, _ = split_filename(model_path)
221
+ # self.ouput_prefix = f'{model_step}_{width}X{height}'
222
+ self.model_length = model_length
223
+
224
+ blur_kernel = bivariate_Gaussian(kernel_size=99, sig_x=10, sig_y=10, theta=0, grid=None, isotropic=True)
225
+
226
+ self.blur_kernel = blur_kernel
227
+
228
+ @torch.no_grad()
229
+ def run(self, first_frame_path, tracking_points, prompt, drag_mode, negative_prompt, seed, randomize_seed, guidance_scale, num_inference_steps, personalized):
230
+
231
+
232
+ original_width, original_height=384, 256
233
+ if isinstance(tracking_points, list):
234
+ input_all_points = tracking_points
235
+ else:
236
+ input_all_points = tracking_points.constructor_args['value']
237
+
238
+
239
+ resized_all_points = [tuple([tuple([float(e1[0]*self.width/original_width), float(e1[1]*self.height/original_height)]) for e1 in e]) for e in input_all_points]
240
+
241
+ dir, base, ext = split_filename(first_frame_path)
242
+ id = base.split('_')[-1]
243
+
244
+
245
+ with open(f'{output_dir}/points-{id}.json', 'w') as f:
246
+ json.dump(input_all_points, f)
247
+
248
+
249
+ visualized_drag, _ = visualize_drag(first_frame_path, resized_all_points, drag_mode, self.width, self.height, self.model_length)
250
+
251
+ ## image condition
252
+ image_transforms = transforms.Compose([
253
+ transforms.RandomResizedCrop(
254
+ (self.height, self.width), (1.0, 1.0),
255
+ ratio=(self.width/self.height, self.width/self.height)
256
+ ),
257
+ transforms.ToTensor(),
258
+ ])
259
+
260
+ image_norm = lambda x: x
261
+ image_paths = [first_frame_path]
262
+ controlnet_images = [image_norm(image_transforms(Image.open(path).convert("RGB"))) for path in image_paths]
263
+ controlnet_images = torch.stack(controlnet_images).unsqueeze(0).cuda()
264
+ controlnet_images = rearrange(controlnet_images, "b f c h w -> b c f h w")
265
+ num_controlnet_images = controlnet_images.shape[2]
266
+ controlnet_images = rearrange(controlnet_images, "b c f h w -> (b f) c h w")
267
+ controlnet_images = self.vae.encode(controlnet_images * 2. - 1.).latent_dist.sample() * 0.18215
268
+ controlnet_images = rearrange(controlnet_images, "(b f) c h w -> b c f h w", f=num_controlnet_images)
269
+
270
+ # flow condition
271
+ controlnet_flows = points_to_flows(resized_all_points, self.model_length, self.height, self.width)
272
+ for i in range(0, self.model_length-1):
273
+ controlnet_flows[i] = cv2.filter2D(controlnet_flows[i], -1, self.blur_kernel)
274
+ controlnet_flows = np.concatenate([np.zeros_like(controlnet_flows[0])[np.newaxis, ...], controlnet_flows], axis=0) # pad the first frame with zero flow
275
+ os.makedirs(os.path.join(output_dir, "control_flows"), exist_ok=True)
276
+ trajs_video = vis_flow_to_video(controlnet_flows, num_frames=self.model_length) # T-1 x H x W x 3
277
+ torchvision.io.write_video(f'{output_dir}/control_flows/sample-{id}-train_flow.mp4', trajs_video, fps=8, video_codec='h264', options={'crf': '10'})
278
+ controlnet_flows = torch.from_numpy(controlnet_flows)[None].to(controlnet_images)[:, :self.model_length, ...]
279
+ controlnet_flows = rearrange(controlnet_flows, "b f h w c-> b c f h w")
280
+
281
+ dreambooth_model_path = DREAM_BOOTH.get(personalized, '')
282
+ lora_model_path = LORA.get(personalized, '')
283
+ lora_alpha = LORA_ALPHA.get(personalized, 0.6)
284
+ self.pipeline = load_weights(
285
+ self.pipeline,
286
+ dreambooth_model_path = dreambooth_model_path,
287
+ lora_model_path = lora_model_path,
288
+ lora_alpha = lora_alpha,
289
+ ).to(device)
290
+
291
+ if NPROMPT.get(personalized, '') != '':
292
+ negative_prompt = NPROMPT.get(personalized)
293
+
294
+ if randomize_seed:
295
+ random_seed = torch.seed()
296
+ else:
297
+ seed = int(seed)
298
+ random_seed = seed
299
+ torch.manual_seed(random_seed)
300
+ torch.cuda.manual_seed_all(random_seed)
301
+ print(f"current seed: {torch.initial_seed()}")
302
+ sample = self.pipeline(
303
+ prompt,
304
+ negative_prompt = negative_prompt,
305
+ num_inference_steps = num_inference_steps,
306
+ guidance_scale = guidance_scale,
307
+ width = self.width,
308
+ height = self.height,
309
+ video_length = self.model_length,
310
+ controlnet_images = controlnet_images, # 1 4 1 32 48
311
+ controlnet_image_index = [0],
312
+ controlnet_flows = controlnet_flows,# [1, 2, 16, 256, 384]
313
+ control_mode = drag_mode,
314
+ eval_mode = True,
315
+ ).videos
316
+
317
+ outputs_path = os.path.join(output_dir, f'output_{i}_{id}.mp4')
318
+ vis_video = (rearrange(sample[0], 'c t h w -> t h w c') * 255.).clip(0, 255)
319
+ torchvision.io.write_video(outputs_path, vis_video, fps=8, video_codec='h264', options={'crf': '10'})
320
+
321
+ return visualized_drag, outputs_path
322
+
323
+
324
+ def reset_states(first_frame_path, tracking_points):
325
+ first_frame_path = gr.State()
326
+ tracking_points = gr.State([])
327
+ return None, first_frame_path, tracking_points
328
+
329
+
330
+ def preprocess_image(image):
331
+ image_pil = image2pil(image.name)
332
+ raw_w, raw_h = image_pil.size
333
+ resize_ratio = max(384/raw_w, 256/raw_h)
334
+ image_pil = image_pil.resize((int(raw_w * resize_ratio), int(raw_h * resize_ratio)), Image.BILINEAR)
335
+ image_pil = transforms.CenterCrop((256, 384))(image_pil.convert('RGB'))
336
+ id = str(uuid.uuid4())[:4]
337
+ first_frame_path = os.path.join(output_dir, f"first_frame_{id}.jpg")
338
+ image_pil.save(first_frame_path, quality=95)
339
+ return first_frame_path, first_frame_path, gr.State([])
340
+
341
+
342
+ def add_tracking_points(tracking_points, first_frame_path, drag_mode, evt: gr.SelectData): # SelectData is a subclass of EventData
343
+ if drag_mode=='object':
344
+ color = (255, 0, 0, 255)
345
+ elif drag_mode=='camera':
346
+ color = (0, 0, 255, 255)
347
+
348
+
349
+ print(f"You selected {evt.value} at {evt.index} from {evt.target}")
350
+ tracking_points.constructor_args['value'][-1].append(evt.index)
351
+ print(tracking_points.constructor_args)
352
+
353
+ transparent_background = Image.open(first_frame_path).convert('RGBA')
354
+ w, h = transparent_background.size
355
+ transparent_layer = np.zeros((h, w, 4))
356
+ for track in tracking_points.constructor_args['value']:
357
+ if len(track) > 1:
358
+ for i in range(len(track)-1):
359
+ start_point = track[i]
360
+ end_point = track[i+1]
361
+ vx = end_point[0] - start_point[0]
362
+ vy = end_point[1] - start_point[1]
363
+ arrow_length = np.sqrt(vx**2 + vy**2)
364
+ if i == len(track)-2:
365
+ cv2.arrowedLine(transparent_layer, tuple(start_point), tuple(end_point), color, 2, tipLength=8 / arrow_length)
366
+ else:
367
+ cv2.line(transparent_layer, tuple(start_point), tuple(end_point), color, 2,)
368
+ else:
369
+ cv2.circle(transparent_layer, tuple(track[0]), 5, color, -1)
370
+
371
+ transparent_layer = Image.fromarray(transparent_layer.astype(np.uint8))
372
+ trajectory_map = Image.alpha_composite(transparent_background, transparent_layer)
373
+ return tracking_points, trajectory_map
374
+
375
+
376
+ def add_drag(tracking_points):
377
+ tracking_points.constructor_args['value'].append([])
378
+ print(tracking_points.constructor_args)
379
+ return tracking_points
380
+
381
+
382
+ def delete_last_drag(tracking_points, first_frame_path, drag_mode):
383
+ if drag_mode=='object':
384
+ color = (255, 0, 0, 255)
385
+ elif drag_mode=='camera':
386
+ color = (0, 0, 255, 255)
387
+ tracking_points.constructor_args['value'].pop()
388
+ transparent_background = Image.open(first_frame_path).convert('RGBA')
389
+ w, h = transparent_background.size
390
+ transparent_layer = np.zeros((h, w, 4))
391
+ for track in tracking_points.constructor_args['value']:
392
+ if len(track) > 1:
393
+ for i in range(len(track)-1):
394
+ start_point = track[i]
395
+ end_point = track[i+1]
396
+ vx = end_point[0] - start_point[0]
397
+ vy = end_point[1] - start_point[1]
398
+ arrow_length = np.sqrt(vx**2 + vy**2)
399
+ if i == len(track)-2:
400
+ cv2.arrowedLine(transparent_layer, tuple(start_point), tuple(end_point), color, 2, tipLength=8 / arrow_length)
401
+ else:
402
+ cv2.line(transparent_layer, tuple(start_point), tuple(end_point), color, 2,)
403
+ else:
404
+ cv2.circle(transparent_layer, tuple(track[0]), 5, color, -1)
405
+
406
+ transparent_layer = Image.fromarray(transparent_layer.astype(np.uint8))
407
+ trajectory_map = Image.alpha_composite(transparent_background, transparent_layer)
408
+ return tracking_points, trajectory_map
409
+
410
+
411
+ def delete_last_step(tracking_points, first_frame_path, drag_mode):
412
+ if drag_mode=='object':
413
+ color = (255, 0, 0, 255)
414
+ elif drag_mode=='camera':
415
+ color = (0, 0, 255, 255)
416
+ tracking_points.constructor_args['value'][-1].pop()
417
+ transparent_background = Image.open(first_frame_path).convert('RGBA')
418
+ w, h = transparent_background.size
419
+ transparent_layer = np.zeros((h, w, 4))
420
+ for track in tracking_points.constructor_args['value']:
421
+ if len(track) > 1:
422
+ for i in range(len(track)-1):
423
+ start_point = track[i]
424
+ end_point = track[i+1]
425
+ vx = end_point[0] - start_point[0]
426
+ vy = end_point[1] - start_point[1]
427
+ arrow_length = np.sqrt(vx**2 + vy**2)
428
+ if i == len(track)-2:
429
+ cv2.arrowedLine(transparent_layer, tuple(start_point), tuple(end_point), color, 2, tipLength=8 / arrow_length)
430
+ else:
431
+ cv2.line(transparent_layer, tuple(start_point), tuple(end_point), color, 2,)
432
+ else:
433
+ cv2.circle(transparent_layer, tuple(track[0]), 5,color, -1)
434
+
435
+ transparent_layer = Image.fromarray(transparent_layer.astype(np.uint8))
436
+ trajectory_map = Image.alpha_composite(transparent_background, transparent_layer)
437
+ return tracking_points, trajectory_map
438
+
439
+
440
+ block = gr.Blocks(
441
+ theme=gr.themes.Soft(
442
+ radius_size=gr.themes.sizes.radius_none,
443
+ text_size=gr.themes.sizes.text_md
444
+ )
445
+ ).queue()
446
+ with block as demo:
447
+ with gr.Row():
448
+ with gr.Column():
449
+ gr.HTML(head)
450
+
451
+ gr.Markdown(descriptions)
452
+
453
+ with gr.Accordion(label="🛠️ Instructions:", open=True, elem_id="accordion"):
454
+ with gr.Row(equal_height=True):
455
+ gr.Markdown(instructions)
456
+
457
+
458
+ device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
459
+ unet_path = 'models/unet.ckpt'
460
+ image_controlnet_path = 'models/image_controlnet.ckpt'
461
+ flow_controlnet_path = 'models/flow_controlnet.ckpt'
462
+ ImageConductor_net = ImageConductor(device=device,
463
+ unet_path=unet_path,
464
+ image_controlnet_path=image_controlnet_path,
465
+ flow_controlnet_path=flow_controlnet_path,
466
+ height=256,
467
+ width=384,
468
+ model_length=16
469
+ )
470
+ first_frame_path = gr.State()
471
+ tracking_points = gr.State([])
472
+
473
+
474
+ with gr.Row():
475
+ with gr.Column(scale=1):
476
+ image_upload_button = gr.UploadButton(label="Upload Image",file_types=["image"])
477
+ add_drag_button = gr.Button(value="Add Drag")
478
+ reset_button = gr.Button(value="Reset")
479
+ delete_last_drag_button = gr.Button(value="Delete last drag")
480
+ delete_last_step_button = gr.Button(value="Delete last step")
481
+
482
+
483
+
484
+ with gr.Column(scale=7):
485
+ with gr.Row():
486
+ with gr.Column(scale=6):
487
+ input_image = gr.Image(label=None,
488
+ interactive=True,
489
+ height=256,
490
+ width=384,)
491
+ with gr.Column(scale=6):
492
+ output_image = gr.Image(label="Motion Path",
493
+ interactive=False,
494
+ height=256,
495
+ width=384,)
496
+ with gr.Row():
497
+ with gr.Column(scale=1):
498
+ prompt = gr.Textbox(value="a wonderful elf.", label="Prompt (highly-recommended)", interactive=True, visible=True)
499
+ negative_prompt = gr.Text(
500
+ label="Negative Prompt",
501
+ max_lines=5,
502
+ placeholder="Please input your negative prompt",
503
+ value='worst quality, low quality, letterboxed',lines=1
504
+ )
505
+ drag_mode = gr.Radio(['camera', 'object'], label='Drag mode: ', value='object', scale=2)
506
+ run_button = gr.Button(value="Run")
507
+
508
+ with gr.Accordion("More input params", open=False, elem_id="accordion1"):
509
+ with gr.Group():
510
+ seed = gr.Textbox(
511
+ label="Seed: ", value=561793204,
512
+ )
513
+ randomize_seed = gr.Checkbox(label="Randomize seed", value=False)
514
+
515
+ with gr.Group():
516
+ with gr.Row():
517
+ guidance_scale = gr.Slider(
518
+ label="Guidance scale",
519
+ minimum=1,
520
+ maximum=12,
521
+ step=0.1,
522
+ value=8.5,
523
+ )
524
+ num_inference_steps = gr.Slider(
525
+ label="Number of inference steps",
526
+ minimum=1,
527
+ maximum=50,
528
+ step=1,
529
+ value=25,
530
+ )
531
+
532
+ with gr.Group():
533
+ personalized = gr.Dropdown(label="Personalized template", choices=['HelloObject', 'TUSUN'], value="")
534
+
535
+ with gr.Column(scale=7):
536
+ output_video = gr.Video(value=None,
537
+ label="Output Video",
538
+ width=384,
539
+ height=256)
540
+
541
+
542
+ with gr.Row():
543
+ def process_example(input_image, prompt, drag_mode, seed, personalized, tracking_points, first_frame_path):
544
+
545
+ return input_image, prompt, drag_mode, seed, personalized, tracking_points, first_frame_path
546
+
547
+ example = gr.Examples(
548
+ label="Input Example",
549
+ examples=image_examples,
550
+ inputs=[input_image, prompt, drag_mode, seed, personalized, tracking_points, first_frame_path],
551
+ outputs=[input_image, prompt, drag_mode, seed, personalized, tracking_points, first_frame_path],
552
+ fn=process_example,
553
+ run_on_click=True,
554
+ examples_per_page=10
555
+ )
556
+
557
+ with gr.Row():
558
+ gr.Markdown(citation)
559
+
560
+
561
+ image_upload_button.upload(preprocess_image, image_upload_button, [input_image, first_frame_path, tracking_points])
562
+
563
+ add_drag_button.click(add_drag, tracking_points, tracking_points)
564
+
565
+ delete_last_drag_button.click(delete_last_drag, [tracking_points, first_frame_path, drag_mode], [tracking_points, input_image])
566
+
567
+ delete_last_step_button.click(delete_last_step, [tracking_points, first_frame_path, drag_mode], [tracking_points, input_image])
568
+
569
+ reset_button.click(reset_states, [first_frame_path, tracking_points], [input_image, first_frame_path, tracking_points])
570
+
571
+ input_image.select(add_tracking_points, [tracking_points, first_frame_path, drag_mode], [tracking_points, input_image])
572
+
573
+ run_button.click(ImageConductor_net.run, [first_frame_path, tracking_points, prompt, drag_mode,
574
+ negative_prompt, seed, randomize_seed, guidance_scale, num_inference_steps, personalized],
575
+ [output_image, output_video])
576
+
577
+ demo.launch(server_name="0.0.0.0", debug=True, server_port=12345)
configs/.DS_Store ADDED
Binary file (6.15 kB). View file
 
configs/inference/flow_condition.yaml ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ controlnet_additional_kwargs:
2
+ set_noisy_sample_input_to_zero: true
3
+ use_simplified_condition_embedding: true
4
+ conditioning_channels: 2
5
+ concate_conditioning_mask: false
6
+
7
+ use_motion_module: true
8
+ motion_module_resolutions: [1,2,4,8]
9
+ motion_module_mid_block: false
10
+ motion_module_type: "Vanilla"
11
+
12
+ motion_module_kwargs:
13
+ num_attention_heads: 8
14
+ num_transformer_block: 1
15
+ attention_block_types: [ "Temporal_Self" ]
16
+ temporal_position_encoding: true
17
+ temporal_position_encoding_max_len: 32
18
+ temporal_attention_dim_div: 1
configs/inference/image_condition.yaml ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ controlnet_additional_kwargs:
2
+ set_noisy_sample_input_to_zero: true
3
+ use_simplified_condition_embedding: true
4
+ conditioning_channels: 4
5
+ concate_conditioning_mask: true
6
+
7
+ use_motion_module: true
8
+ motion_module_resolutions: [1,2,4,8]
9
+ motion_module_mid_block: false
10
+ motion_module_type: "Vanilla"
11
+
12
+ motion_module_kwargs:
13
+ num_attention_heads: 8
14
+ num_transformer_block: 1
15
+ attention_block_types: [ "Temporal_Self" ]
16
+ temporal_position_encoding: true
17
+ temporal_position_encoding_max_len: 32
18
+ temporal_attention_dim_div: 1
configs/inference/inference.yaml ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ unet_additional_kwargs:
2
+ use_inflated_groupnorm: true
3
+ use_motion_module: true
4
+ motion_module_resolutions: [1,2,4,8]
5
+ motion_module_mid_block: false
6
+ motion_module_type: Vanilla
7
+
8
+ motion_module_kwargs:
9
+ num_attention_heads: 8
10
+ num_transformer_block: 1
11
+ attention_block_types: [ "Temporal_Self", "Temporal_Self" ]
12
+ temporal_position_encoding: true
13
+ temporal_position_encoding_max_len: 32
14
+ temporal_attention_dim_div: 1
15
+ zero_initialize: true
16
+
17
+ noise_scheduler_kwargs:
18
+ beta_start: 0.00085
19
+ beta_end: 0.012
20
+ beta_schedule: "linear"
21
+ steps_offset: 1
22
+ clip_sample: False
models/.DS_Store ADDED
Binary file (6.15 kB). View file
 
modules/__pycache__/attention.cpython-310.pyc ADDED
Binary file (6.61 kB). View file
 
modules/__pycache__/flow_controlnet.cpython-310.pyc ADDED
Binary file (14.5 kB). View file
 
modules/__pycache__/image_controlnet.cpython-310.pyc ADDED
Binary file (16.9 kB). View file
 
modules/__pycache__/motion_module.cpython-310.pyc ADDED
Binary file (8.54 kB). View file
 
modules/__pycache__/resnet.cpython-310.pyc ADDED
Binary file (5.89 kB). View file
 
modules/__pycache__/unet.cpython-310.pyc ADDED
Binary file (14.3 kB). View file
 
modules/__pycache__/unet_blocks.cpython-310.pyc ADDED
Binary file (13.9 kB). View file
 
modules/attention.py ADDED
@@ -0,0 +1,396 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention.py
2
+
3
+ import logging
4
+ from dataclasses import dataclass
5
+ from typing import Any, Dict, Optional
6
+
7
+ import torch
8
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
9
+ from diffusers.models import ModelMixin
10
+ from diffusers.models.attention import AdaLayerNorm, Attention, FeedForward
11
+ from diffusers.utils import BaseOutput
12
+ from diffusers.utils.torch_utils import maybe_allow_in_graph
13
+ from einops import rearrange, repeat
14
+ from torch import Tensor, nn
15
+
16
+ logger = logging.getLogger(__name__)
17
+
18
+
19
+ @dataclass
20
+ class Transformer3DModelOutput(BaseOutput):
21
+ sample: torch.FloatTensor
22
+
23
+
24
+ @maybe_allow_in_graph
25
+ class Transformer3DModel(ModelMixin, ConfigMixin):
26
+ @register_to_config
27
+ def __init__(
28
+ self,
29
+ num_attention_heads: int = 16,
30
+ attention_head_dim: int = 88,
31
+ in_channels: Optional[int] = None,
32
+ num_layers: int = 1,
33
+ dropout: float = 0.0,
34
+ norm_num_groups: int = 32,
35
+ cross_attention_dim: Optional[int] = None,
36
+ attention_bias: bool = False,
37
+ activation_fn: str = "geglu",
38
+ num_embeds_ada_norm: Optional[int] = None,
39
+ use_linear_projection: bool = False,
40
+ only_cross_attention: bool = False,
41
+ upcast_attention: bool = False,
42
+ unet_use_cross_frame_attention=None,
43
+ unet_use_temporal_attention=None,
44
+ ):
45
+ super().__init__()
46
+ self.use_linear_projection = use_linear_projection
47
+ self.num_attention_heads = num_attention_heads
48
+ self.attention_head_dim = attention_head_dim
49
+ inner_dim = num_attention_heads * attention_head_dim
50
+
51
+ # Define input layers
52
+ self.in_channels = in_channels
53
+
54
+ self.norm = torch.nn.GroupNorm(
55
+ num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True
56
+ )
57
+ if use_linear_projection:
58
+ self.proj_in = nn.Linear(in_channels, inner_dim)
59
+ else:
60
+ self.proj_in = nn.Conv2d(
61
+ in_channels, inner_dim, kernel_size=1, stride=1, padding=0
62
+ )
63
+
64
+ # Define transformers blocks
65
+ self.transformer_blocks = nn.ModuleList(
66
+ [
67
+ BasicTransformerBlock(
68
+ inner_dim,
69
+ num_attention_heads,
70
+ attention_head_dim,
71
+ dropout=dropout,
72
+ cross_attention_dim=cross_attention_dim,
73
+ activation_fn=activation_fn,
74
+ num_embeds_ada_norm=num_embeds_ada_norm,
75
+ attention_bias=attention_bias,
76
+ only_cross_attention=only_cross_attention,
77
+ upcast_attention=upcast_attention,
78
+ unet_use_cross_frame_attention=unet_use_cross_frame_attention,
79
+ unet_use_temporal_attention=unet_use_temporal_attention,
80
+ )
81
+ for d in range(num_layers)
82
+ ]
83
+ )
84
+
85
+ # 4. Define output layers
86
+ if use_linear_projection:
87
+ self.proj_out = nn.Linear(in_channels, inner_dim)
88
+ else:
89
+ self.proj_out = nn.Conv2d(
90
+ inner_dim, in_channels, kernel_size=1, stride=1, padding=0
91
+ )
92
+
93
+ def forward(
94
+ self,
95
+ hidden_states: torch.Tensor,
96
+ encoder_hidden_states: Optional[torch.Tensor] = None,
97
+ timestep: Optional[torch.LongTensor] = None,
98
+ cross_attention_kwargs: Dict[str, Any] = None,
99
+ attention_mask: Optional[torch.Tensor] = None,
100
+ encoder_attention_mask: Optional[torch.Tensor] = None,
101
+ return_dict: bool = True,
102
+ ):
103
+ # validate input dim
104
+ if hidden_states.dim() != 5:
105
+ raise ValueError(
106
+ f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}."
107
+ )
108
+
109
+ # ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
110
+ # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
111
+ # we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
112
+ # expects mask of shape:
113
+ # [batch, key_tokens]
114
+ # adds singleton query_tokens dimension:
115
+ # [batch, 1, key_tokens]
116
+ # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
117
+ # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
118
+ # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
119
+ if attention_mask is not None and attention_mask.ndim == 2:
120
+ # assume that mask is expressed as:
121
+ # (1 = keep, 0 = discard)
122
+ # convert mask into a bias that can be added to attention scores:
123
+ # (keep = +0, discard = -10000.0)
124
+ attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
125
+ attention_mask = attention_mask.unsqueeze(1)
126
+
127
+ # convert encoder_attention_mask to a bias the same way we do for attention_mask
128
+ if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
129
+ encoder_attention_mask = (
130
+ 1 - encoder_attention_mask.to(hidden_states.dtype)
131
+ ) * -10000.0
132
+ encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
133
+
134
+ # shenanigans for motion module
135
+ video_length = hidden_states.shape[2]
136
+ hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w")
137
+ encoder_hidden_states = repeat(
138
+ encoder_hidden_states, "b n c -> (b f) n c", f=video_length
139
+ )
140
+
141
+ # 1. Input
142
+ batch, _, height, width = hidden_states.shape
143
+ residual = hidden_states
144
+
145
+ hidden_states = self.norm(hidden_states)
146
+ if not self.use_linear_projection:
147
+ hidden_states = self.proj_in(hidden_states)
148
+ inner_dim = hidden_states.shape[1]
149
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(
150
+ batch, height * width, inner_dim
151
+ )
152
+ else:
153
+ inner_dim = hidden_states.shape[1]
154
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(
155
+ batch, height * width, inner_dim
156
+ )
157
+ hidden_states = self.proj_in(hidden_states)
158
+
159
+ # 2. Blocks
160
+ for block in self.transformer_blocks:
161
+ hidden_states = block(
162
+ hidden_states,
163
+ attention_mask=attention_mask,
164
+ encoder_hidden_states=encoder_hidden_states,
165
+ timestep=timestep,
166
+ video_length=video_length,
167
+ encoder_attention_mask=encoder_attention_mask,
168
+ cross_attention_kwargs=cross_attention_kwargs,
169
+ )
170
+
171
+ # 3. Output
172
+ if not self.use_linear_projection:
173
+ hidden_states = (
174
+ hidden_states.reshape(batch, height, width, inner_dim)
175
+ .permute(0, 3, 1, 2)
176
+ .contiguous()
177
+ )
178
+ hidden_states = self.proj_out(hidden_states)
179
+ else:
180
+ hidden_states = self.proj_out(hidden_states)
181
+ hidden_states = (
182
+ hidden_states.reshape(batch, height, width, inner_dim)
183
+ .permute(0, 3, 1, 2)
184
+ .contiguous()
185
+ )
186
+
187
+ output = hidden_states + residual
188
+
189
+ output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length)
190
+ if not return_dict:
191
+ return (output,)
192
+
193
+ return Transformer3DModelOutput(sample=output)
194
+
195
+
196
+ @maybe_allow_in_graph
197
+ class BasicTransformerBlock(nn.Module):
198
+ def __init__(
199
+ self,
200
+ dim: int,
201
+ num_attention_heads: int,
202
+ attention_head_dim: int,
203
+ dropout: float = 0.0,
204
+ cross_attention_dim: Optional[int] = None,
205
+ activation_fn: str = "geglu",
206
+ num_embeds_ada_norm: Optional[int] = None,
207
+ attention_bias: bool = False,
208
+ only_cross_attention: bool = False,
209
+ upcast_attention: bool = False,
210
+ norm_elementwise_affine: bool = True,
211
+ unet_use_cross_frame_attention: bool = False,
212
+ unet_use_temporal_attention: bool = False,
213
+ final_dropout: bool = False,
214
+ ):
215
+ super().__init__()
216
+ self.only_cross_attention = only_cross_attention
217
+ self.use_ada_layer_norm = num_embeds_ada_norm is not None
218
+ self.unet_use_cross_frame_attention = unet_use_cross_frame_attention
219
+ self.unet_use_temporal_attention = unet_use_temporal_attention
220
+
221
+ # Define 3 blocks. Each block has its own normalization layer.
222
+ # Self-Attn / SC-Attn
223
+ if self.use_ada_layer_norm:
224
+ self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
225
+ else:
226
+ self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
227
+
228
+ if unet_use_cross_frame_attention:
229
+ # this isn't actually implemented anywhere in the AnimateDiff codebase or in Diffusers...
230
+ raise NotImplementedError("SC-Attn is not implemented yet.")
231
+ else:
232
+ self.attn1 = Attention(
233
+ query_dim=dim,
234
+ cross_attention_dim=(
235
+ cross_attention_dim if only_cross_attention else None
236
+ ),
237
+ heads=num_attention_heads,
238
+ dim_head=attention_head_dim,
239
+ dropout=dropout,
240
+ bias=attention_bias,
241
+ upcast_attention=upcast_attention,
242
+ )
243
+
244
+ # 2. Cross-Attn
245
+ if cross_attention_dim is not None:
246
+ self.norm2 = (
247
+ AdaLayerNorm(dim, num_embeds_ada_norm)
248
+ if self.use_ada_layer_norm
249
+ else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
250
+ )
251
+ self.attn2 = Attention(
252
+ query_dim=dim,
253
+ cross_attention_dim=cross_attention_dim,
254
+ heads=num_attention_heads,
255
+ dim_head=attention_head_dim,
256
+ dropout=dropout,
257
+ bias=attention_bias,
258
+ upcast_attention=upcast_attention,
259
+ ) # is self-attn if encoder_hidden_states is none
260
+ else:
261
+ self.norm2 = None
262
+ self.attn2 = None
263
+
264
+ # 3. Feed-forward
265
+ self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
266
+ self.ff = FeedForward(
267
+ dim,
268
+ dropout=dropout,
269
+ activation_fn=activation_fn,
270
+ final_dropout=final_dropout,
271
+ )
272
+
273
+ # 4. Temporal Attn
274
+ assert unet_use_temporal_attention is not None
275
+ if unet_use_temporal_attention:
276
+ self.attn_temp = Attention(
277
+ query_dim=dim,
278
+ heads=num_attention_heads,
279
+ dim_head=attention_head_dim,
280
+ dropout=dropout,
281
+ bias=attention_bias,
282
+ upcast_attention=upcast_attention,
283
+ )
284
+ nn.init.zeros_(self.attn_temp.to_out[0].weight.data)
285
+ if self.use_ada_layer_norm:
286
+ self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
287
+ else:
288
+ self.norm1 = nn.LayerNorm(
289
+ dim, elementwise_affine=norm_elementwise_affine
290
+ )
291
+
292
+ def forward(
293
+ self,
294
+ hidden_states: torch.FloatTensor,
295
+ attention_mask: Optional[torch.FloatTensor] = None,
296
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
297
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
298
+ timestep: Optional[torch.LongTensor] = None,
299
+ cross_attention_kwargs: Dict[str, Any] = None,
300
+ video_length=None,
301
+ ):
302
+ # SparseCausal-Attention
303
+ # Notice that normalization is always applied before the real computation in the following blocks.
304
+ # 1. Self-Attention
305
+ if self.use_ada_layer_norm:
306
+ norm_hidden_states = self.norm1(hidden_states, timestep)
307
+ else:
308
+ norm_hidden_states = self.norm1(hidden_states)
309
+
310
+ cross_attention_kwargs = (
311
+ cross_attention_kwargs if cross_attention_kwargs is not None else {}
312
+ )
313
+ if self.unet_use_cross_frame_attention:
314
+ cross_attention_kwargs["video_length"] = video_length
315
+
316
+ attn_output = self.attn1(
317
+ norm_hidden_states,
318
+ encoder_hidden_states=(
319
+ encoder_hidden_states if self.only_cross_attention else None
320
+ ),
321
+ attention_mask=attention_mask,
322
+ **cross_attention_kwargs,
323
+ )
324
+
325
+ hidden_states = attn_output + hidden_states
326
+
327
+ # 2. Cross-Attention
328
+ if self.attn2 is not None:
329
+ norm_hidden_states = (
330
+ self.norm2(hidden_states, timestep)
331
+ if self.use_ada_layer_norm
332
+ else self.norm2(hidden_states)
333
+ )
334
+
335
+ attn_output = self.attn2(
336
+ norm_hidden_states,
337
+ encoder_hidden_states=encoder_hidden_states,
338
+ attention_mask=encoder_attention_mask,
339
+ **cross_attention_kwargs,
340
+ )
341
+ hidden_states = attn_output + hidden_states
342
+
343
+ # 3. Feed-forward
344
+ hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states
345
+
346
+ # 4. Temporal-Attention
347
+ if self.unet_use_temporal_attention:
348
+ d = hidden_states.shape[1]
349
+ hidden_states = rearrange(
350
+ hidden_states, "(b f) d c -> (b d) f c", f=video_length
351
+ )
352
+ norm_hidden_states = (
353
+ self.norm_temp(hidden_states, timestep)
354
+ if self.use_ada_layer_norm
355
+ else self.norm_temp(hidden_states)
356
+ )
357
+ hidden_states = self.attn_temp(norm_hidden_states) + hidden_states
358
+ hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d)
359
+
360
+ return hidden_states
361
+ hidden_states = attn_output + hidden_states
362
+
363
+ # 2. Cross-Attention
364
+ if self.attn2 is not None:
365
+ norm_hidden_states = (
366
+ self.norm2(hidden_states, timestep)
367
+ if self.use_ada_layer_norm
368
+ else self.norm2(hidden_states)
369
+ )
370
+
371
+ attn_output = self.attn2(
372
+ norm_hidden_states,
373
+ encoder_hidden_states=encoder_hidden_states,
374
+ attention_mask=encoder_attention_mask,
375
+ **cross_attention_kwargs,
376
+ )
377
+ hidden_states = attn_output + hidden_states
378
+
379
+ # 3. Feed-forward
380
+ hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states
381
+
382
+ # 4. Temporal-Attention
383
+ if self.unet_use_temporal_attention:
384
+ d = hidden_states.shape[1]
385
+ hidden_states = rearrange(
386
+ hidden_states, "(b f) d c -> (b d) f c", f=video_length
387
+ )
388
+ norm_hidden_states = (
389
+ self.norm_temp(hidden_states, timestep)
390
+ if self.use_ada_layer_norm
391
+ else self.norm_temp(hidden_states)
392
+ )
393
+ hidden_states = self.attn_temp(norm_hidden_states) + hidden_states
394
+ hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d)
395
+
396
+ return hidden_states
modules/flow_controlnet.py ADDED
@@ -0,0 +1,591 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ #
15
+ # Changes were made to this source code by Yuwei Guo.
16
+ from dataclasses import dataclass
17
+ from typing import Any, Dict, List, Optional, Tuple, Union
18
+
19
+ import torch
20
+ from diffusers import ModelMixin
21
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
22
+ from diffusers.models.attention_processor import AttentionProcessor
23
+ from diffusers.models.embeddings import TimestepEmbedding, Timesteps
24
+ from diffusers.models.unet_2d_condition import UNet2DConditionModel
25
+ from diffusers.loaders import UNet2DConditionLoadersMixin, PeftAdapterMixin
26
+ from diffusers.utils import BaseOutput, logging
27
+ from einops import rearrange, repeat
28
+ from torch import nn
29
+ from torch.nn import functional as F
30
+
31
+ from .resnet import InflatedConv3d
32
+ from .unet_blocks import (
33
+ CrossAttnDownBlock3D,
34
+ DownBlock3D,
35
+ UNetMidBlock3DCrossAttn,
36
+ get_down_block,
37
+ )
38
+
39
+
40
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
41
+
42
+
43
+ @dataclass
44
+ class FlowControlNetOutput(BaseOutput):
45
+ down_block_res_samples: Tuple[torch.Tensor]
46
+ mid_block_res_sample: torch.Tensor
47
+
48
+
49
+ class FlowControlNetConditioningEmbedding(nn.Module):
50
+ def __init__(
51
+ self,
52
+ conditioning_embedding_channels: int,
53
+ conditioning_channels: int = 3,
54
+ block_out_channels: Tuple[int] = (16, 32, 96, 256),
55
+ ):
56
+ super().__init__()
57
+
58
+ self.conv_in = InflatedConv3d(conditioning_channels, block_out_channels[0], kernel_size=3, padding=1)
59
+
60
+ self.blocks = nn.ModuleList([])
61
+
62
+ for i in range(len(block_out_channels) - 1):
63
+ channel_in = block_out_channels[i]
64
+ channel_out = block_out_channels[i + 1]
65
+ self.blocks.append(InflatedConv3d(channel_in, channel_in, kernel_size=3, padding=1))
66
+ self.blocks.append(InflatedConv3d(channel_in, channel_out, kernel_size=3, padding=1, stride=2))
67
+
68
+ self.conv_out = zero_module(
69
+ InflatedConv3d(block_out_channels[-1], conditioning_embedding_channels, kernel_size=3, padding=1)
70
+ )
71
+
72
+ def forward(self, conditioning):
73
+ embedding = self.conv_in(conditioning)
74
+ embedding = F.silu(embedding)
75
+
76
+ for block in self.blocks:
77
+ embedding = block(embedding)
78
+ embedding = F.silu(embedding)
79
+
80
+ embedding = self.conv_out(embedding)
81
+
82
+ return embedding
83
+
84
+
85
+ class FlowControlNetModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, PeftAdapterMixin):
86
+ _supports_gradient_checkpointing = True
87
+
88
+ @register_to_config
89
+ def __init__(
90
+ self,
91
+ in_channels: int = 4,
92
+ conditioning_channels: int = 3,
93
+ flip_sin_to_cos: bool = True,
94
+ freq_shift: int = 0,
95
+ down_block_types: Tuple[str] = (
96
+ "CrossAttnDownBlock2D",
97
+ "CrossAttnDownBlock2D",
98
+ "CrossAttnDownBlock2D",
99
+ "DownBlock2D",
100
+ ),
101
+ only_cross_attention: Union[bool, Tuple[bool]] = False,
102
+ block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
103
+ layers_per_block: int = 2,
104
+ downsample_padding: int = 1,
105
+ mid_block_scale_factor: float = 1,
106
+ act_fn: str = "silu",
107
+ norm_num_groups: Optional[int] = 32,
108
+ norm_eps: float = 1e-5,
109
+ cross_attention_dim: int = 1280,
110
+ attention_head_dim: Union[int, Tuple[int]] = 8,
111
+ num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
112
+ use_linear_projection: bool = False,
113
+ class_embed_type: Optional[str] = None,
114
+ num_class_embeds: Optional[int] = None,
115
+ upcast_attention: bool = False,
116
+ resnet_time_scale_shift: str = "default",
117
+ projection_class_embeddings_input_dim: Optional[int] = None,
118
+ controlnet_conditioning_channel_order: str = "rgb",
119
+ conditioning_embedding_out_channels: Optional[Tuple[int]] = (16, 32, 96, 256),
120
+ global_pool_conditions: bool = False,
121
+
122
+ use_motion_module = True,
123
+ motion_module_resolutions = ( 1,2,4,8 ),
124
+ motion_module_mid_block = False,
125
+ motion_module_type = "Vanilla",
126
+ motion_module_kwargs = {
127
+ "num_attention_heads": 8,
128
+ "num_transformer_block": 1,
129
+ "attention_block_types": ["Temporal_Self"],
130
+ "temporal_position_encoding": True,
131
+ "temporal_position_encoding_max_len": 32,
132
+ "temporal_attention_dim_div": 1,
133
+ "causal_temporal_attention": False,
134
+ },
135
+
136
+ concate_conditioning_mask: bool = True,
137
+ use_simplified_condition_embedding: bool = False,
138
+
139
+ set_noisy_sample_input_to_zero: bool = False,
140
+ ):
141
+ super().__init__()
142
+
143
+ # If `num_attention_heads` is not defined (which is the case for most models)
144
+ # it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
145
+ # The reason for this behavior is to correct for incorrectly named variables that were introduced
146
+ # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
147
+ # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
148
+ # which is why we correct for the naming here.
149
+ num_attention_heads = num_attention_heads or attention_head_dim
150
+
151
+ # Check inputs
152
+ if len(block_out_channels) != len(down_block_types):
153
+ raise ValueError(
154
+ f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
155
+ )
156
+
157
+ if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
158
+ raise ValueError(
159
+ f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
160
+ )
161
+
162
+ if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
163
+ raise ValueError(
164
+ f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
165
+ )
166
+
167
+ # input
168
+ self.set_noisy_sample_input_to_zero = set_noisy_sample_input_to_zero
169
+
170
+ conv_in_kernel = 3
171
+ conv_in_padding = (conv_in_kernel - 1) // 2
172
+ self.conv_in = InflatedConv3d(
173
+ in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
174
+ )
175
+ conditioning_channels = conditioning_channels * 8 * 8
176
+ if concate_conditioning_mask:
177
+ conditioning_channels = conditioning_channels + 1
178
+ self.concate_conditioning_mask = concate_conditioning_mask
179
+
180
+ # control net conditioning embedding
181
+ if use_simplified_condition_embedding:
182
+ self.controlnet_cond_embedding = zero_module(
183
+ InflatedConv3d(conditioning_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding)
184
+ )
185
+ else:
186
+ self.controlnet_cond_embedding = FlowControlNetConditioningEmbedding(
187
+ conditioning_embedding_channels=block_out_channels[0],
188
+ block_out_channels=conditioning_embedding_out_channels,
189
+ conditioning_channels=conditioning_channels,
190
+ )
191
+ self.use_simplified_condition_embedding = use_simplified_condition_embedding
192
+
193
+ # time
194
+ time_embed_dim = block_out_channels[0] * 4
195
+
196
+ self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
197
+ timestep_input_dim = block_out_channels[0]
198
+
199
+ self.time_embedding = TimestepEmbedding(
200
+ timestep_input_dim,
201
+ time_embed_dim,
202
+ act_fn=act_fn,
203
+ )
204
+
205
+ # class embedding
206
+ if class_embed_type is None and num_class_embeds is not None:
207
+ self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
208
+ elif class_embed_type == "timestep":
209
+ self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
210
+ elif class_embed_type == "identity":
211
+ self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
212
+ elif class_embed_type == "projection":
213
+ if projection_class_embeddings_input_dim is None:
214
+ raise ValueError(
215
+ "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
216
+ )
217
+ # The projection `class_embed_type` is the same as the timestep `class_embed_type` except
218
+ # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
219
+ # 2. it projects from an arbitrary input dimension.
220
+ #
221
+ # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
222
+ # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
223
+ # As a result, `TimestepEmbedding` can be passed arbitrary vectors.
224
+ self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
225
+ else:
226
+ self.class_embedding = None
227
+
228
+
229
+ self.down_blocks = nn.ModuleList([])
230
+ self.controlnet_down_blocks = nn.ModuleList([])
231
+
232
+ if isinstance(only_cross_attention, bool):
233
+ only_cross_attention = [only_cross_attention] * len(down_block_types)
234
+
235
+ if isinstance(attention_head_dim, int):
236
+ attention_head_dim = (attention_head_dim,) * len(down_block_types)
237
+
238
+ if isinstance(num_attention_heads, int):
239
+ num_attention_heads = (num_attention_heads,) * len(down_block_types)
240
+
241
+ # down
242
+ output_channel = block_out_channels[0]
243
+
244
+ controlnet_block = InflatedConv3d(output_channel, output_channel, kernel_size=1)
245
+ controlnet_block = zero_module(controlnet_block)
246
+ self.controlnet_down_blocks.append(controlnet_block)
247
+
248
+ for i, down_block_type in enumerate(down_block_types):
249
+ res = 2 ** i
250
+ input_channel = output_channel
251
+ output_channel = block_out_channels[i]
252
+ is_final_block = i == len(block_out_channels) - 1
253
+
254
+ down_block = get_down_block(
255
+ down_block_type,
256
+ num_layers=layers_per_block,
257
+ in_channels=input_channel,
258
+ out_channels=output_channel,
259
+ temb_channels=time_embed_dim,
260
+ add_downsample=not is_final_block,
261
+ resnet_eps=norm_eps,
262
+ resnet_act_fn=act_fn,
263
+ resnet_groups=norm_num_groups,
264
+ cross_attention_dim=cross_attention_dim,
265
+ attn_num_head_channels=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
266
+ downsample_padding=downsample_padding,
267
+ use_linear_projection=use_linear_projection,
268
+ only_cross_attention=only_cross_attention[i],
269
+ upcast_attention=upcast_attention,
270
+ resnet_time_scale_shift=resnet_time_scale_shift,
271
+
272
+ use_inflated_groupnorm=True,
273
+
274
+ use_motion_module=use_motion_module and (res in motion_module_resolutions),
275
+ motion_module_type=motion_module_type,
276
+ motion_module_kwargs=motion_module_kwargs,
277
+ )
278
+ self.down_blocks.append(down_block)
279
+
280
+ for _ in range(layers_per_block):
281
+ controlnet_block = InflatedConv3d(output_channel, output_channel, kernel_size=1)
282
+ controlnet_block = zero_module(controlnet_block)
283
+ self.controlnet_down_blocks.append(controlnet_block)
284
+
285
+ if not is_final_block:
286
+ controlnet_block = InflatedConv3d(output_channel, output_channel, kernel_size=1)
287
+ controlnet_block = zero_module(controlnet_block)
288
+ self.controlnet_down_blocks.append(controlnet_block)
289
+
290
+ # mid
291
+ mid_block_channel = block_out_channels[-1]
292
+
293
+ controlnet_block = InflatedConv3d(mid_block_channel, mid_block_channel, kernel_size=1)
294
+ controlnet_block = zero_module(controlnet_block)
295
+ self.controlnet_mid_block = controlnet_block
296
+
297
+ self.mid_block = UNetMidBlock3DCrossAttn(
298
+ in_channels=mid_block_channel,
299
+ temb_channels=time_embed_dim,
300
+ resnet_eps=norm_eps,
301
+ resnet_act_fn=act_fn,
302
+ output_scale_factor=mid_block_scale_factor,
303
+ resnet_time_scale_shift=resnet_time_scale_shift,
304
+ cross_attention_dim=cross_attention_dim,
305
+ attn_num_head_channels=num_attention_heads[-1],
306
+ resnet_groups=norm_num_groups,
307
+ use_linear_projection=use_linear_projection,
308
+ upcast_attention=upcast_attention,
309
+
310
+ use_inflated_groupnorm=True,
311
+ use_motion_module=use_motion_module and motion_module_mid_block,
312
+ motion_module_type=motion_module_type,
313
+ motion_module_kwargs=motion_module_kwargs,
314
+ )
315
+
316
+ @classmethod
317
+ def from_unet(
318
+ cls,
319
+ unet: UNet2DConditionModel,
320
+ controlnet_conditioning_channel_order: str = "rgb",
321
+ conditioning_embedding_out_channels: Optional[Tuple[int]] = (16, 32, 96, 256),
322
+ load_weights_from_unet: bool = True,
323
+
324
+ controlnet_additional_kwargs: dict = {},
325
+ ):
326
+ controlnet = cls(
327
+ in_channels=unet.config.in_channels,
328
+ flip_sin_to_cos=unet.config.flip_sin_to_cos,
329
+ freq_shift=unet.config.freq_shift,
330
+ down_block_types=unet.config.down_block_types,
331
+ only_cross_attention=unet.config.only_cross_attention,
332
+ block_out_channels=unet.config.block_out_channels,
333
+ layers_per_block=unet.config.layers_per_block,
334
+ downsample_padding=unet.config.downsample_padding,
335
+ mid_block_scale_factor=unet.config.mid_block_scale_factor,
336
+ act_fn=unet.config.act_fn,
337
+ norm_num_groups=unet.config.norm_num_groups,
338
+ norm_eps=unet.config.norm_eps,
339
+ cross_attention_dim=unet.config.cross_attention_dim,
340
+ attention_head_dim=unet.config.attention_head_dim,
341
+ num_attention_heads=unet.config.num_attention_heads,
342
+ use_linear_projection=unet.config.use_linear_projection,
343
+ class_embed_type=unet.config.class_embed_type,
344
+ num_class_embeds=unet.config.num_class_embeds,
345
+ upcast_attention=unet.config.upcast_attention,
346
+ resnet_time_scale_shift=unet.config.resnet_time_scale_shift,
347
+ projection_class_embeddings_input_dim=unet.config.projection_class_embeddings_input_dim,
348
+ controlnet_conditioning_channel_order=controlnet_conditioning_channel_order,
349
+ conditioning_embedding_out_channels=conditioning_embedding_out_channels,
350
+
351
+ **controlnet_additional_kwargs,
352
+ )
353
+ controlnet.unshuffle = nn.PixelUnshuffle(8)
354
+
355
+ if load_weights_from_unet:
356
+ m, u = controlnet.conv_in.load_state_dict(cls.image_layer_filter(unet.conv_in.state_dict()), strict=False)
357
+ assert len(u) == 0
358
+ m, u = controlnet.time_proj.load_state_dict(cls.image_layer_filter(unet.time_proj.state_dict()), strict=False)
359
+ assert len(u) == 0
360
+ m, u = controlnet.time_embedding.load_state_dict(cls.image_layer_filter(unet.time_embedding.state_dict()), strict=False)
361
+ assert len(u) == 0
362
+
363
+ if controlnet.class_embedding:
364
+ m, u = controlnet.class_embedding.load_state_dict(cls.image_layer_filter(unet.class_embedding.state_dict()), strict=False)
365
+ assert len(u) == 0
366
+ m, u = controlnet.down_blocks.load_state_dict(cls.image_layer_filter(unet.down_blocks.state_dict()), strict=False)
367
+ assert len(u) == 0
368
+ m, u = controlnet.mid_block.load_state_dict(cls.image_layer_filter(unet.mid_block.state_dict()), strict=False)
369
+ assert len(u) == 0
370
+
371
+
372
+ return controlnet
373
+
374
+ @staticmethod
375
+ def image_layer_filter(state_dict):
376
+ new_state_dict = {}
377
+ for name, param in state_dict.items():
378
+ if "motion_modules." in name or "lora" in name: continue
379
+ new_state_dict[name] = param
380
+ return new_state_dict
381
+
382
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attention_slice
383
+ def set_attention_slice(self, slice_size):
384
+ r"""
385
+ Enable sliced attention computation.
386
+
387
+ When this option is enabled, the attention module splits the input tensor in slices to compute attention in
388
+ several steps. This is useful for saving some memory in exchange for a small decrease in speed.
389
+
390
+ Args:
391
+ slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
392
+ When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
393
+ `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
394
+ provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
395
+ must be a multiple of `slice_size`.
396
+ """
397
+ sliceable_head_dims = []
398
+
399
+ def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
400
+ if hasattr(module, "set_attention_slice"):
401
+ sliceable_head_dims.append(module.sliceable_head_dim)
402
+
403
+ for child in module.children():
404
+ fn_recursive_retrieve_sliceable_dims(child)
405
+
406
+ # retrieve number of attention layers
407
+ for module in self.children():
408
+ fn_recursive_retrieve_sliceable_dims(module)
409
+
410
+ num_sliceable_layers = len(sliceable_head_dims)
411
+
412
+ if slice_size == "auto":
413
+ # half the attention head size is usually a good trade-off between
414
+ # speed and memory
415
+ slice_size = [dim // 2 for dim in sliceable_head_dims]
416
+ elif slice_size == "max":
417
+ # make smallest slice possible
418
+ slice_size = num_sliceable_layers * [1]
419
+
420
+ slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
421
+
422
+ if len(slice_size) != len(sliceable_head_dims):
423
+ raise ValueError(
424
+ f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
425
+ f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
426
+ )
427
+
428
+ for i in range(len(slice_size)):
429
+ size = slice_size[i]
430
+ dim = sliceable_head_dims[i]
431
+ if size is not None and size > dim:
432
+ raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
433
+
434
+ # Recursively walk through all the children.
435
+ # Any children which exposes the set_attention_slice method
436
+ # gets the message
437
+ def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
438
+ if hasattr(module, "set_attention_slice"):
439
+ module.set_attention_slice(slice_size.pop())
440
+
441
+ for child in module.children():
442
+ fn_recursive_set_attention_slice(child, slice_size)
443
+
444
+ reversed_slice_size = list(reversed(slice_size))
445
+ for module in self.children():
446
+ fn_recursive_set_attention_slice(module, reversed_slice_size)
447
+
448
+ def _set_gradient_checkpointing(self, module, value=False):
449
+ if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D)):
450
+ module.gradient_checkpointing = value
451
+
452
+ def forward(
453
+ self,
454
+ sample: torch.FloatTensor,
455
+ timestep: Union[torch.Tensor, float, int],
456
+ encoder_hidden_states: torch.Tensor,
457
+
458
+ controlnet_cond: torch.FloatTensor,
459
+ conditioning_mask: Optional[torch.FloatTensor] = None,
460
+
461
+ conditioning_scale: float = 1.0,
462
+ class_labels: Optional[torch.Tensor] = None,
463
+ attention_mask: Optional[torch.Tensor] = None,
464
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
465
+ guess_mode: bool = False,
466
+ return_dict: bool = True,
467
+ ) -> Union[FlowControlNetOutput, Tuple]:
468
+ # set input noise to zero
469
+ if self.set_noisy_sample_input_to_zero:
470
+ sample = torch.zeros_like(sample).to(sample.device)
471
+
472
+ # prepare attention_mask
473
+ if attention_mask is not None:
474
+ attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
475
+ attention_mask = attention_mask.unsqueeze(1)
476
+
477
+ # 1. time
478
+ timesteps = timestep
479
+ if not torch.is_tensor(timesteps):
480
+ # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
481
+ # This would be a good case for the `match` statement (Python 3.10+)
482
+ is_mps = sample.device.type == "mps"
483
+ if isinstance(timestep, float):
484
+ dtype = torch.float32 if is_mps else torch.float64
485
+ else:
486
+ dtype = torch.int32 if is_mps else torch.int64
487
+ timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
488
+ elif len(timesteps.shape) == 0:
489
+ timesteps = timesteps[None].to(sample.device)
490
+
491
+ timesteps = timesteps.repeat(sample.shape[0] // timesteps.shape[0])
492
+ encoder_hidden_states = encoder_hidden_states.repeat(sample.shape[0] // encoder_hidden_states.shape[0], 1, 1)
493
+
494
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
495
+ timesteps = timesteps.expand(sample.shape[0])
496
+
497
+ t_emb = self.time_proj(timesteps)
498
+
499
+ # timesteps does not contain any weights and will always return f32 tensors
500
+ # but time_embedding might actually be running in fp16. so we need to cast here.
501
+ # there might be better ways to encapsulate this.
502
+ t_emb = t_emb.to(dtype=self.dtype)
503
+ emb = self.time_embedding(t_emb)
504
+
505
+ if self.class_embedding is not None:
506
+ if class_labels is None:
507
+ raise ValueError("class_labels should be provided when num_class_embeds > 0")
508
+
509
+ if self.config.class_embed_type == "timestep":
510
+ class_labels = self.time_proj(class_labels)
511
+
512
+ class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
513
+ emb = emb + class_emb
514
+
515
+ # 2. pre-process
516
+ sample = self.conv_in(sample)
517
+
518
+
519
+ if self.concate_conditioning_mask:
520
+ controlnet_cond = torch.cat([controlnet_cond, conditioning_mask], dim=1)
521
+ controlnet_cond = self.unshuffle(controlnet_cond.permute(0,2,1,3,4))
522
+ controlnet_cond = controlnet_cond.contiguous().permute(0,2,1,3,4)
523
+ controlnet_cond = self.controlnet_cond_embedding(controlnet_cond)
524
+
525
+ sample = sample + controlnet_cond
526
+
527
+ # 3. down
528
+ down_block_res_samples = (sample,)
529
+ for downsample_block in self.down_blocks:
530
+ if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
531
+ sample, res_samples = downsample_block(
532
+ hidden_states=sample,
533
+ temb=emb,
534
+ encoder_hidden_states=encoder_hidden_states,
535
+ attention_mask=attention_mask,
536
+ # cross_attention_kwargs=cross_attention_kwargs,
537
+ )
538
+ else: sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
539
+
540
+ down_block_res_samples += res_samples
541
+
542
+ # 4. mid
543
+ if self.mid_block is not None:
544
+ sample = self.mid_block(
545
+ sample,
546
+ emb,
547
+ encoder_hidden_states=encoder_hidden_states,
548
+ attention_mask=attention_mask,
549
+ # cross_attention_kwargs=cross_attention_kwargs,
550
+ )
551
+
552
+ # 5. controlnet blocks
553
+ controlnet_down_block_res_samples = ()
554
+
555
+ for down_block_res_sample, controlnet_block in zip(down_block_res_samples, self.controlnet_down_blocks):
556
+ down_block_res_sample = controlnet_block(down_block_res_sample)
557
+ controlnet_down_block_res_samples = controlnet_down_block_res_samples + (down_block_res_sample,)
558
+
559
+ down_block_res_samples = controlnet_down_block_res_samples
560
+
561
+ mid_block_res_sample = self.controlnet_mid_block(sample)
562
+
563
+ # 6. scaling
564
+ if guess_mode and not self.config.global_pool_conditions:
565
+ scales = torch.logspace(-1, 0, len(down_block_res_samples) + 1, device=sample.device) # 0.1 to 1.0
566
+
567
+ scales = scales * conditioning_scale
568
+ down_block_res_samples = [sample * scale for sample, scale in zip(down_block_res_samples, scales)]
569
+ mid_block_res_sample = mid_block_res_sample * scales[-1] # last one
570
+ else:
571
+ down_block_res_samples = [sample * conditioning_scale for sample in down_block_res_samples]
572
+ mid_block_res_sample = mid_block_res_sample * conditioning_scale
573
+
574
+ if self.config.global_pool_conditions:
575
+ down_block_res_samples = [
576
+ torch.mean(sample, dim=(2, 3), keepdim=True) for sample in down_block_res_samples
577
+ ]
578
+ mid_block_res_sample = torch.mean(mid_block_res_sample, dim=(2, 3), keepdim=True)
579
+
580
+ if not return_dict:
581
+ return (down_block_res_samples, mid_block_res_sample)
582
+
583
+ return FlowControlNetOutput(
584
+ down_block_res_samples=down_block_res_samples, mid_block_res_sample=mid_block_res_sample
585
+ )
586
+
587
+
588
+ def zero_module(module):
589
+ for p in module.parameters():
590
+ nn.init.zeros_(p)
591
+ return module
modules/image_controlnet.py ADDED
@@ -0,0 +1,721 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ #
15
+ # Changes were made to this source code by Yuwei Guo.
16
+ from dataclasses import dataclass
17
+ from typing import Any, Dict, List, Optional, Tuple, Union
18
+
19
+ import torch
20
+ from diffusers import ModelMixin
21
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
22
+ from diffusers.models.attention_processor import AttentionProcessor
23
+ from diffusers.models.embeddings import TimestepEmbedding, Timesteps
24
+ from diffusers.models.unet_2d_condition import UNet2DConditionModel
25
+ from diffusers.utils import BaseOutput, logging
26
+ from einops import rearrange, repeat
27
+ from torch import nn
28
+ from torch.nn import functional as F
29
+
30
+ from .resnet import InflatedConv3d
31
+ from .unet_blocks import (
32
+ CrossAttnDownBlock3D,
33
+ DownBlock3D,
34
+ UNetMidBlock3DCrossAttn,
35
+ get_down_block,
36
+ )
37
+
38
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
39
+
40
+
41
+ @dataclass
42
+ class ImageControlNetOutput(BaseOutput):
43
+ down_block_res_samples: Tuple[torch.Tensor]
44
+ mid_block_res_sample: torch.Tensor
45
+
46
+
47
+ class ImageControlNetConditioningEmbedding(nn.Module):
48
+ def __init__(
49
+ self,
50
+ conditioning_embedding_channels: int,
51
+ conditioning_channels: int = 3,
52
+ block_out_channels: Tuple[int] = (16, 32, 96, 256),
53
+ ):
54
+ super().__init__()
55
+
56
+ self.conv_in = InflatedConv3d(
57
+ conditioning_channels, block_out_channels[0], kernel_size=3, padding=1
58
+ )
59
+
60
+ self.blocks = nn.ModuleList([])
61
+
62
+ for i in range(len(block_out_channels) - 1):
63
+ channel_in = block_out_channels[i]
64
+ channel_out = block_out_channels[i + 1]
65
+ self.blocks.append(
66
+ InflatedConv3d(channel_in, channel_in, kernel_size=3, padding=1)
67
+ )
68
+ self.blocks.append(
69
+ InflatedConv3d(
70
+ channel_in, channel_out, kernel_size=3, padding=1, stride=2
71
+ )
72
+ )
73
+
74
+ self.conv_out = zero_module(
75
+ InflatedConv3d(
76
+ block_out_channels[-1],
77
+ conditioning_embedding_channels,
78
+ kernel_size=3,
79
+ padding=1,
80
+ )
81
+ )
82
+
83
+ def forward(self, conditioning):
84
+ embedding = self.conv_in(conditioning)
85
+ embedding = F.silu(embedding)
86
+
87
+ for block in self.blocks:
88
+ embedding = block(embedding)
89
+ embedding = F.silu(embedding)
90
+
91
+ embedding = self.conv_out(embedding)
92
+
93
+ return embedding
94
+
95
+
96
+ class ImageControlNetModel(ModelMixin, ConfigMixin):
97
+ _supports_gradient_checkpointing = True
98
+
99
+ @register_to_config
100
+ def __init__(
101
+ self,
102
+ in_channels: int = 4,
103
+ conditioning_channels: int = 3,
104
+ flip_sin_to_cos: bool = True,
105
+ freq_shift: int = 0,
106
+ down_block_types: Tuple[str] = (
107
+ "CrossAttnDownBlock2D",
108
+ "CrossAttnDownBlock2D",
109
+ "CrossAttnDownBlock2D",
110
+ "DownBlock2D",
111
+ ),
112
+ only_cross_attention: Union[bool, Tuple[bool]] = False,
113
+ block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
114
+ layers_per_block: int = 2,
115
+ downsample_padding: int = 1,
116
+ mid_block_scale_factor: float = 1,
117
+ act_fn: str = "silu",
118
+ norm_num_groups: Optional[int] = 32,
119
+ norm_eps: float = 1e-5,
120
+ cross_attention_dim: int = 1280,
121
+ attention_head_dim: Union[int, Tuple[int]] = 8,
122
+ num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
123
+ use_linear_projection: bool = False,
124
+ class_embed_type: Optional[str] = None,
125
+ num_class_embeds: Optional[int] = None,
126
+ upcast_attention: bool = False,
127
+ resnet_time_scale_shift: str = "default",
128
+ projection_class_embeddings_input_dim: Optional[int] = None,
129
+ controlnet_conditioning_channel_order: str = "rgb",
130
+ conditioning_embedding_out_channels: Optional[Tuple[int]] = (16, 32, 96, 256),
131
+ global_pool_conditions: bool = False,
132
+ use_motion_module=True,
133
+ motion_module_resolutions=(1, 2, 4, 8),
134
+ motion_module_mid_block=False,
135
+ motion_module_type="Vanilla",
136
+ motion_module_kwargs={
137
+ "num_attention_heads": 8,
138
+ "num_transformer_block": 1,
139
+ "attention_block_types": ["Temporal_Self"],
140
+ "temporal_position_encoding": True,
141
+ "temporal_position_encoding_max_len": 32,
142
+ "temporal_attention_dim_div": 1,
143
+ "causal_temporal_attention": False,
144
+ },
145
+ concate_conditioning_mask: bool = True,
146
+ use_simplified_condition_embedding: bool = False,
147
+ set_noisy_sample_input_to_zero: bool = False,
148
+ ):
149
+ super().__init__()
150
+
151
+ # If `num_attention_heads` is not defined (which is the case for most models)
152
+ # it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
153
+ # The reason for this behavior is to correct for incorrectly named variables that were introduced
154
+ # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
155
+ # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
156
+ # which is why we correct for the naming here.
157
+ num_attention_heads = num_attention_heads or attention_head_dim
158
+
159
+ # Check inputs
160
+ if len(block_out_channels) != len(down_block_types):
161
+ raise ValueError(
162
+ f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
163
+ )
164
+
165
+ if not isinstance(only_cross_attention, bool) and len(
166
+ only_cross_attention
167
+ ) != len(down_block_types):
168
+ raise ValueError(
169
+ f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
170
+ )
171
+
172
+ if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(
173
+ down_block_types
174
+ ):
175
+ raise ValueError(
176
+ f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
177
+ )
178
+
179
+ # input
180
+ self.set_noisy_sample_input_to_zero = set_noisy_sample_input_to_zero
181
+
182
+ conv_in_kernel = 3
183
+ conv_in_padding = (conv_in_kernel - 1) // 2
184
+ self.conv_in = InflatedConv3d(
185
+ in_channels,
186
+ block_out_channels[0],
187
+ kernel_size=conv_in_kernel,
188
+ padding=conv_in_padding,
189
+ )
190
+
191
+ if concate_conditioning_mask:
192
+ conditioning_channels = conditioning_channels + 1
193
+ self.concate_conditioning_mask = concate_conditioning_mask
194
+
195
+ # control net conditioning embedding
196
+ if use_simplified_condition_embedding:
197
+ self.controlnet_cond_embedding = zero_module(
198
+ InflatedConv3d(
199
+ conditioning_channels,
200
+ block_out_channels[0],
201
+ kernel_size=conv_in_kernel,
202
+ padding=conv_in_padding,
203
+ )
204
+ )
205
+ else:
206
+ self.controlnet_cond_embedding = ImageControlNetConditioningEmbedding(
207
+ conditioning_embedding_channels=block_out_channels[0],
208
+ block_out_channels=conditioning_embedding_out_channels,
209
+ conditioning_channels=conditioning_channels,
210
+ )
211
+ self.use_simplified_condition_embedding = use_simplified_condition_embedding
212
+
213
+ # time
214
+ time_embed_dim = block_out_channels[0] * 4
215
+
216
+ self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
217
+ timestep_input_dim = block_out_channels[0]
218
+
219
+ self.time_embedding = TimestepEmbedding(
220
+ timestep_input_dim,
221
+ time_embed_dim,
222
+ act_fn=act_fn,
223
+ )
224
+
225
+ # class embedding
226
+ if class_embed_type is None and num_class_embeds is not None:
227
+ self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
228
+ elif class_embed_type == "timestep":
229
+ self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
230
+ elif class_embed_type == "identity":
231
+ self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
232
+ elif class_embed_type == "projection":
233
+ if projection_class_embeddings_input_dim is None:
234
+ raise ValueError(
235
+ "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
236
+ )
237
+ # The projection `class_embed_type` is the same as the timestep `class_embed_type` except
238
+ # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
239
+ # 2. it projects from an arbitrary input dimension.
240
+ #
241
+ # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
242
+ # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
243
+ # As a result, `TimestepEmbedding` can be passed arbitrary vectors.
244
+ self.class_embedding = TimestepEmbedding(
245
+ projection_class_embeddings_input_dim, time_embed_dim
246
+ )
247
+ else:
248
+ self.class_embedding = None
249
+
250
+ self.down_blocks = nn.ModuleList([])
251
+ self.controlnet_down_blocks = nn.ModuleList([])
252
+
253
+ if isinstance(only_cross_attention, bool):
254
+ only_cross_attention = [only_cross_attention] * len(down_block_types)
255
+
256
+ if isinstance(attention_head_dim, int):
257
+ attention_head_dim = (attention_head_dim,) * len(down_block_types)
258
+
259
+ if isinstance(num_attention_heads, int):
260
+ num_attention_heads = (num_attention_heads,) * len(down_block_types)
261
+
262
+ # down
263
+ output_channel = block_out_channels[0]
264
+
265
+ controlnet_block = InflatedConv3d(output_channel, output_channel, kernel_size=1)
266
+ controlnet_block = zero_module(controlnet_block)
267
+ self.controlnet_down_blocks.append(controlnet_block)
268
+
269
+ for i, down_block_type in enumerate(down_block_types):
270
+ res = 2**i
271
+ input_channel = output_channel
272
+ output_channel = block_out_channels[i]
273
+ is_final_block = i == len(block_out_channels) - 1
274
+
275
+ down_block = get_down_block(
276
+ down_block_type,
277
+ num_layers=layers_per_block,
278
+ in_channels=input_channel,
279
+ out_channels=output_channel,
280
+ temb_channels=time_embed_dim,
281
+ add_downsample=not is_final_block,
282
+ resnet_eps=norm_eps,
283
+ resnet_act_fn=act_fn,
284
+ resnet_groups=norm_num_groups,
285
+ cross_attention_dim=cross_attention_dim,
286
+ attn_num_head_channels=(
287
+ attention_head_dim[i]
288
+ if attention_head_dim[i] is not None
289
+ else output_channel
290
+ ),
291
+ downsample_padding=downsample_padding,
292
+ use_linear_projection=use_linear_projection,
293
+ only_cross_attention=only_cross_attention[i],
294
+ upcast_attention=upcast_attention,
295
+ resnet_time_scale_shift=resnet_time_scale_shift,
296
+ use_inflated_groupnorm=True,
297
+ use_motion_module=use_motion_module
298
+ and (res in motion_module_resolutions),
299
+ motion_module_type=motion_module_type,
300
+ motion_module_kwargs=motion_module_kwargs,
301
+ )
302
+ self.down_blocks.append(down_block)
303
+
304
+ for _ in range(layers_per_block):
305
+ controlnet_block = InflatedConv3d(
306
+ output_channel, output_channel, kernel_size=1
307
+ )
308
+ controlnet_block = zero_module(controlnet_block)
309
+ self.controlnet_down_blocks.append(controlnet_block)
310
+
311
+ if not is_final_block:
312
+ controlnet_block = InflatedConv3d(
313
+ output_channel, output_channel, kernel_size=1
314
+ )
315
+ controlnet_block = zero_module(controlnet_block)
316
+ self.controlnet_down_blocks.append(controlnet_block)
317
+
318
+ # mid
319
+ mid_block_channel = block_out_channels[-1]
320
+
321
+ controlnet_block = InflatedConv3d(
322
+ mid_block_channel, mid_block_channel, kernel_size=1
323
+ )
324
+ controlnet_block = zero_module(controlnet_block)
325
+ self.controlnet_mid_block = controlnet_block
326
+
327
+ self.mid_block = UNetMidBlock3DCrossAttn(
328
+ in_channels=mid_block_channel,
329
+ temb_channels=time_embed_dim,
330
+ resnet_eps=norm_eps,
331
+ resnet_act_fn=act_fn,
332
+ output_scale_factor=mid_block_scale_factor,
333
+ resnet_time_scale_shift=resnet_time_scale_shift,
334
+ cross_attention_dim=cross_attention_dim,
335
+ attn_num_head_channels=num_attention_heads[-1],
336
+ resnet_groups=norm_num_groups,
337
+ use_linear_projection=use_linear_projection,
338
+ upcast_attention=upcast_attention,
339
+ use_inflated_groupnorm=True,
340
+ use_motion_module=use_motion_module and motion_module_mid_block,
341
+ motion_module_type=motion_module_type,
342
+ motion_module_kwargs=motion_module_kwargs,
343
+ )
344
+
345
+ @classmethod
346
+ def from_unet(
347
+ cls,
348
+ unet: UNet2DConditionModel,
349
+ controlnet_conditioning_channel_order: str = "rgb",
350
+ conditioning_embedding_out_channels: Optional[Tuple[int]] = (16, 32, 96, 256),
351
+ load_weights_from_unet: bool = True,
352
+ controlnet_additional_kwargs: dict = {},
353
+ ):
354
+ controlnet = cls(
355
+ in_channels=unet.config.in_channels,
356
+ flip_sin_to_cos=unet.config.flip_sin_to_cos,
357
+ freq_shift=unet.config.freq_shift,
358
+ down_block_types=unet.config.down_block_types,
359
+ only_cross_attention=unet.config.only_cross_attention,
360
+ block_out_channels=unet.config.block_out_channels,
361
+ layers_per_block=unet.config.layers_per_block,
362
+ downsample_padding=unet.config.downsample_padding,
363
+ mid_block_scale_factor=unet.config.mid_block_scale_factor,
364
+ act_fn=unet.config.act_fn,
365
+ norm_num_groups=unet.config.norm_num_groups,
366
+ norm_eps=unet.config.norm_eps,
367
+ cross_attention_dim=unet.config.cross_attention_dim,
368
+ attention_head_dim=unet.config.attention_head_dim,
369
+ num_attention_heads=unet.config.num_attention_heads,
370
+ use_linear_projection=unet.config.use_linear_projection,
371
+ class_embed_type=unet.config.class_embed_type,
372
+ num_class_embeds=unet.config.num_class_embeds,
373
+ upcast_attention=unet.config.upcast_attention,
374
+ resnet_time_scale_shift=unet.config.resnet_time_scale_shift,
375
+ projection_class_embeddings_input_dim=unet.config.projection_class_embeddings_input_dim,
376
+ controlnet_conditioning_channel_order=controlnet_conditioning_channel_order,
377
+ conditioning_embedding_out_channels=conditioning_embedding_out_channels,
378
+ **controlnet_additional_kwargs,
379
+ )
380
+
381
+ if load_weights_from_unet:
382
+ m, u = controlnet.conv_in.load_state_dict(
383
+ cls.image_layer_filter(unet.conv_in.state_dict()), strict=False
384
+ )
385
+ assert len(u) == 0
386
+ m, u = controlnet.time_proj.load_state_dict(
387
+ cls.image_layer_filter(unet.time_proj.state_dict()), strict=False
388
+ )
389
+ assert len(u) == 0
390
+ m, u = controlnet.time_embedding.load_state_dict(
391
+ cls.image_layer_filter(unet.time_embedding.state_dict()), strict=False
392
+ )
393
+ assert len(u) == 0
394
+
395
+ if controlnet.class_embedding:
396
+ m, u = controlnet.class_embedding.load_state_dict(
397
+ cls.image_layer_filter(unet.class_embedding.state_dict()),
398
+ strict=False,
399
+ )
400
+ assert len(u) == 0
401
+ m, u = controlnet.down_blocks.load_state_dict(
402
+ cls.image_layer_filter(unet.down_blocks.state_dict()), strict=False
403
+ )
404
+ assert len(u) == 0
405
+ m, u = controlnet.mid_block.load_state_dict(
406
+ cls.image_layer_filter(unet.mid_block.state_dict()), strict=False
407
+ )
408
+ assert len(u) == 0
409
+
410
+ return controlnet
411
+
412
+ @staticmethod
413
+ def image_layer_filter(state_dict):
414
+ new_state_dict = {}
415
+ for name, param in state_dict.items():
416
+ if "motion_modules." in name or "lora" in name:
417
+ continue
418
+ new_state_dict[name] = param
419
+ return new_state_dict
420
+
421
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attention_slice
422
+ def set_attention_slice(self, slice_size):
423
+ r"""
424
+ Enable sliced attention computation.
425
+
426
+ When this option is enabled, the attention module splits the input tensor in slices to compute attention in
427
+ several steps. This is useful for saving some memory in exchange for a small decrease in speed.
428
+
429
+ Args:
430
+ slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
431
+ When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
432
+ `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
433
+ provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
434
+ must be a multiple of `slice_size`.
435
+ """
436
+ sliceable_head_dims = []
437
+
438
+ def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
439
+ if hasattr(module, "set_attention_slice"):
440
+ sliceable_head_dims.append(module.sliceable_head_dim)
441
+
442
+ for child in module.children():
443
+ fn_recursive_retrieve_sliceable_dims(child)
444
+
445
+ # retrieve number of attention layers
446
+ for module in self.children():
447
+ fn_recursive_retrieve_sliceable_dims(module)
448
+
449
+ num_sliceable_layers = len(sliceable_head_dims)
450
+
451
+ if slice_size == "auto":
452
+ # half the attention head size is usually a good trade-off between
453
+ # speed and memory
454
+ slice_size = [dim // 2 for dim in sliceable_head_dims]
455
+ elif slice_size == "max":
456
+ # make smallest slice possible
457
+ slice_size = num_sliceable_layers * [1]
458
+
459
+ slice_size = (
460
+ num_sliceable_layers * [slice_size]
461
+ if not isinstance(slice_size, list)
462
+ else slice_size
463
+ )
464
+
465
+ if len(slice_size) != len(sliceable_head_dims):
466
+ raise ValueError(
467
+ f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
468
+ f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
469
+ )
470
+
471
+ for i in range(len(slice_size)):
472
+ size = slice_size[i]
473
+ dim = sliceable_head_dims[i]
474
+ if size is not None and size > dim:
475
+ raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
476
+
477
+ # Recursively walk through all the children.
478
+ # Any children which exposes the set_attention_slice method
479
+ # gets the message
480
+ def fn_recursive_set_attention_slice(
481
+ module: torch.nn.Module, slice_size: List[int]
482
+ ):
483
+ if hasattr(module, "set_attention_slice"):
484
+ module.set_attention_slice(slice_size.pop())
485
+
486
+ for child in module.children():
487
+ fn_recursive_set_attention_slice(child, slice_size)
488
+
489
+ reversed_slice_size = list(reversed(slice_size))
490
+ for module in self.children():
491
+ fn_recursive_set_attention_slice(module, reversed_slice_size)
492
+
493
+ def _set_gradient_checkpointing(self, module, value=False):
494
+ if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D)):
495
+ module.gradient_checkpointing = value
496
+
497
+ def forward(
498
+ self,
499
+ sample: torch.FloatTensor,
500
+ timestep: Union[torch.Tensor, float, int],
501
+ encoder_hidden_states: torch.Tensor,
502
+ controlnet_cond: torch.FloatTensor,
503
+ conditioning_mask: Optional[torch.FloatTensor] = None,
504
+ conditioning_scale: float = 1.0,
505
+ class_labels: Optional[torch.Tensor] = None,
506
+ attention_mask: Optional[torch.Tensor] = None,
507
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
508
+ guess_mode: bool = False,
509
+ return_dict: bool = True,
510
+ ) -> Union[ImageControlNetOutput, Tuple]:
511
+
512
+ # set input noise to zero
513
+ if self.set_noisy_sample_input_to_zero:
514
+ sample = torch.zeros_like(sample).to(sample.device)
515
+
516
+ # prepare attention_mask
517
+ if attention_mask is not None:
518
+ attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
519
+ attention_mask = attention_mask.unsqueeze(1)
520
+
521
+ # 1. time
522
+ timesteps = timestep
523
+ if not torch.is_tensor(timesteps):
524
+ # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
525
+ # This would be a good case for the `match` statement (Python 3.10+)
526
+ is_mps = sample.device.type == "mps"
527
+ if isinstance(timestep, float):
528
+ dtype = torch.float32 if is_mps else torch.float64
529
+ else:
530
+ dtype = torch.int32 if is_mps else torch.int64
531
+ timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
532
+ elif len(timesteps.shape) == 0:
533
+ timesteps = timesteps[None].to(sample.device)
534
+
535
+ timesteps = timesteps.repeat(sample.shape[0] // timesteps.shape[0])
536
+ encoder_hidden_states = encoder_hidden_states.repeat(
537
+ sample.shape[0] // encoder_hidden_states.shape[0], 1, 1
538
+ )
539
+
540
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
541
+ timesteps = timesteps.expand(sample.shape[0])
542
+
543
+ t_emb = self.time_proj(timesteps)
544
+
545
+ # timesteps does not contain any weights and will always return f32 tensors
546
+ # but time_embedding might actually be running in fp16. so we need to cast here.
547
+ # there might be better ways to encapsulate this.
548
+ t_emb = t_emb.to(dtype=self.dtype)
549
+ emb = self.time_embedding(t_emb)
550
+
551
+ if self.class_embedding is not None:
552
+ if class_labels is None:
553
+ raise ValueError(
554
+ "class_labels should be provided when num_class_embeds > 0"
555
+ )
556
+
557
+ if self.config.class_embed_type == "timestep":
558
+ class_labels = self.time_proj(class_labels)
559
+
560
+ class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
561
+ emb = emb + class_emb
562
+
563
+ # 2. pre-process
564
+ sample = self.conv_in(sample)
565
+
566
+ if self.concate_conditioning_mask:
567
+ controlnet_cond = torch.cat([controlnet_cond, conditioning_mask], dim=1)
568
+ controlnet_cond = self.controlnet_cond_embedding(controlnet_cond)
569
+
570
+ sample = sample + controlnet_cond
571
+
572
+ # 3. down
573
+ down_block_res_samples = (sample,)
574
+ for downsample_block in self.down_blocks:
575
+ if (
576
+ hasattr(downsample_block, "has_cross_attention")
577
+ and downsample_block.has_cross_attention
578
+ ):
579
+ sample, res_samples = downsample_block(
580
+ hidden_states=sample,
581
+ temb=emb,
582
+ encoder_hidden_states=encoder_hidden_states,
583
+ attention_mask=attention_mask,
584
+ # cross_attention_kwargs=cross_attention_kwargs,
585
+ )
586
+ else:
587
+ sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
588
+
589
+ down_block_res_samples += res_samples
590
+
591
+ # 4. mid
592
+ if self.mid_block is not None:
593
+ sample = self.mid_block(
594
+ sample,
595
+ emb,
596
+ encoder_hidden_states=encoder_hidden_states,
597
+ attention_mask=attention_mask,
598
+ # cross_attention_kwargs=cross_attention_kwargs,
599
+ )
600
+
601
+ # 5. controlnet blocks
602
+ controlnet_down_block_res_samples = ()
603
+
604
+ for down_block_res_sample, controlnet_block in zip(
605
+ down_block_res_samples, self.controlnet_down_blocks
606
+ ):
607
+ down_block_res_sample = controlnet_block(down_block_res_sample)
608
+ controlnet_down_block_res_samples = controlnet_down_block_res_samples + (
609
+ down_block_res_sample,
610
+ )
611
+
612
+ down_block_res_samples = controlnet_down_block_res_samples
613
+
614
+ mid_block_res_sample = self.controlnet_mid_block(sample)
615
+
616
+ # 6. scaling
617
+ if guess_mode and not self.config.global_pool_conditions:
618
+ scales = torch.logspace(
619
+ -1, 0, len(down_block_res_samples) + 1, device=sample.device
620
+ ) # 0.1 to 1.0
621
+
622
+ scales = scales * conditioning_scale
623
+ down_block_res_samples = [
624
+ sample * scale for sample, scale in zip(down_block_res_samples, scales)
625
+ ]
626
+ mid_block_res_sample = mid_block_res_sample * scales[-1] # last one
627
+ else:
628
+ down_block_res_samples = [
629
+ sample * conditioning_scale for sample in down_block_res_samples
630
+ ]
631
+ mid_block_res_sample = mid_block_res_sample * conditioning_scale
632
+
633
+ if self.config.global_pool_conditions:
634
+ down_block_res_samples = [
635
+ torch.mean(sample, dim=(2, 3), keepdim=True)
636
+ for sample in down_block_res_samples
637
+ ]
638
+ mid_block_res_sample = torch.mean(
639
+ mid_block_res_sample, dim=(2, 3), keepdim=True
640
+ )
641
+
642
+ if not return_dict:
643
+ return (down_block_res_samples, mid_block_res_sample)
644
+
645
+ return ImageControlNetOutput(
646
+ down_block_res_samples=down_block_res_samples,
647
+ mid_block_res_sample=mid_block_res_sample,
648
+ )
649
+
650
+ @property
651
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
652
+ r"""
653
+ Returns:
654
+ `dict` of attention processors: A dictionary containing all attention processors used in the model with
655
+ indexed by its weight name.
656
+ """
657
+ # set recursively
658
+ processors = {}
659
+
660
+ def fn_recursive_add_processors(
661
+ name: str,
662
+ module: torch.nn.Module,
663
+ processors: Dict[str, AttentionProcessor],
664
+ ):
665
+ if hasattr(module, "set_processor"):
666
+ processors[f"{name}.processor"] = module.processor
667
+
668
+ for sub_name, child in module.named_children():
669
+ if "temporal_transformer" not in sub_name:
670
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
671
+
672
+ return processors
673
+
674
+ for name, module in self.named_children():
675
+ if "temporal_transformer" not in name:
676
+ fn_recursive_add_processors(name, module, processors)
677
+
678
+ return processors
679
+
680
+ def set_attn_processor(
681
+ self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]
682
+ ):
683
+ r"""
684
+ Sets the attention processor to use to compute attention.
685
+ Parameters:
686
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
687
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
688
+ for **all** `Attention` layers.
689
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
690
+ processor. This is strongly recommended when setting trainable attention processors.
691
+ """
692
+ count = len(self.attn_processors.keys())
693
+
694
+ if isinstance(processor, dict) and len(processor) != count:
695
+ raise ValueError(
696
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
697
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
698
+ )
699
+
700
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
701
+ if hasattr(module, "set_processor"):
702
+ if not isinstance(processor, dict):
703
+ module.set_processor(processor)
704
+ else:
705
+ module.set_processor(processor.pop(f"{name}.processor"))
706
+
707
+ for sub_name, child in module.named_children():
708
+ if "temporal_transformer" not in sub_name:
709
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
710
+
711
+ for name, module in self.named_children():
712
+ if "temporal_transformer" not in name:
713
+ fn_recursive_attn_processor(name, module, processor)
714
+
715
+
716
+ def zero_module(module):
717
+ for p in module.parameters():
718
+ nn.init.zeros_(p)
719
+ return module
720
+
721
+
modules/motion_module.py ADDED
@@ -0,0 +1,355 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from dataclasses import dataclass
3
+ from typing import Callable, Optional
4
+
5
+ import torch
6
+ import torch.nn.functional as F
7
+ from diffusers.models.attention import Attention, FeedForward
8
+ from diffusers.utils import BaseOutput
9
+ from diffusers.utils.torch_utils import maybe_allow_in_graph
10
+ from einops import rearrange, repeat
11
+ from torch import Tensor, nn
12
+
13
+
14
+ def zero_module(module):
15
+ # Zero out the parameters of a module and return it.
16
+ for p in module.parameters():
17
+ p.detach().zero_()
18
+ return module
19
+
20
+
21
+ @dataclass
22
+ class TemporalTransformer3DModelOutput(BaseOutput):
23
+ sample: torch.FloatTensor
24
+
25
+
26
+ def get_motion_module(in_channels, motion_module_type: str, motion_module_kwargs: dict):
27
+ if motion_module_type == "Vanilla":
28
+ return VanillaTemporalModule(
29
+ in_channels=in_channels,
30
+ **motion_module_kwargs,
31
+ )
32
+ else:
33
+ raise ValueError
34
+
35
+
36
+ class VanillaTemporalModule(nn.Module):
37
+ def __init__(
38
+ self,
39
+ in_channels,
40
+ num_attention_heads=8,
41
+ num_transformer_block=2,
42
+ attention_block_types=("Temporal_Self", "Temporal_Self"),
43
+ cross_frame_attention_mode=None,
44
+ temporal_position_encoding=False,
45
+ temporal_position_encoding_max_len=24,
46
+ temporal_attention_dim_div=1,
47
+ zero_initialize=True,
48
+ ):
49
+ super().__init__()
50
+
51
+ self.temporal_transformer = TemporalTransformer3DModel(
52
+ in_channels=in_channels,
53
+ num_attention_heads=num_attention_heads,
54
+ attention_head_dim=in_channels
55
+ // num_attention_heads
56
+ // temporal_attention_dim_div,
57
+ num_layers=num_transformer_block,
58
+ attention_block_types=attention_block_types,
59
+ cross_frame_attention_mode=cross_frame_attention_mode,
60
+ temporal_position_encoding=temporal_position_encoding,
61
+ temporal_position_encoding_max_len=temporal_position_encoding_max_len,
62
+ )
63
+
64
+ if zero_initialize:
65
+ self.temporal_transformer.proj_out = zero_module(
66
+ self.temporal_transformer.proj_out
67
+ )
68
+ self.skip_temporal_layers = False # Whether to skip temporal layer
69
+
70
+ def forward(
71
+ self,
72
+ input_tensor,
73
+ temb,
74
+ encoder_hidden_states,
75
+ attention_mask=None,
76
+ anchor_frame_idx=None,
77
+ ):
78
+ if self.skip_temporal_layers is True:
79
+ return input_tensor
80
+
81
+ hidden_states = input_tensor
82
+ hidden_states = self.temporal_transformer(
83
+ hidden_states, encoder_hidden_states, attention_mask
84
+ )
85
+
86
+ output = hidden_states
87
+ return output
88
+
89
+
90
+ @maybe_allow_in_graph
91
+ class TemporalTransformer3DModel(nn.Module):
92
+ def __init__(
93
+ self,
94
+ in_channels,
95
+ num_attention_heads,
96
+ attention_head_dim,
97
+ num_layers,
98
+ attention_block_types=(
99
+ "Temporal_Self",
100
+ "Temporal_Self",
101
+ ),
102
+ dropout=0.0,
103
+ norm_num_groups=32,
104
+ cross_attention_dim=768,
105
+ activation_fn="geglu",
106
+ attention_bias=False,
107
+ upcast_attention=False,
108
+ cross_frame_attention_mode=None,
109
+ temporal_position_encoding=False,
110
+ temporal_position_encoding_max_len=24,
111
+ ):
112
+ super().__init__()
113
+
114
+ inner_dim = num_attention_heads * attention_head_dim
115
+
116
+ self.norm = torch.nn.GroupNorm(
117
+ num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True
118
+ )
119
+ self.proj_in = nn.Linear(in_channels, inner_dim)
120
+
121
+ self.transformer_blocks = nn.ModuleList(
122
+ [
123
+ TemporalTransformerBlock(
124
+ dim=inner_dim,
125
+ num_attention_heads=num_attention_heads,
126
+ attention_head_dim=attention_head_dim,
127
+ attention_block_types=attention_block_types,
128
+ dropout=dropout,
129
+ norm_num_groups=norm_num_groups,
130
+ cross_attention_dim=cross_attention_dim,
131
+ activation_fn=activation_fn,
132
+ attention_bias=attention_bias,
133
+ upcast_attention=upcast_attention,
134
+ cross_frame_attention_mode=cross_frame_attention_mode,
135
+ temporal_position_encoding=temporal_position_encoding,
136
+ temporal_position_encoding_max_len=temporal_position_encoding_max_len,
137
+ )
138
+ for d in range(num_layers)
139
+ ]
140
+ )
141
+ self.proj_out = nn.Linear(inner_dim, in_channels)
142
+
143
+ def forward(
144
+ self,
145
+ hidden_states: Tensor,
146
+ encoder_hidden_states: Optional[Tensor] = None,
147
+ attention_mask: Optional[Tensor] = None,
148
+ ):
149
+ assert (
150
+ hidden_states.dim() == 5
151
+ ), f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}."
152
+ video_length = hidden_states.shape[2]
153
+ hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w")
154
+
155
+ batch, channel, height, weight = hidden_states.shape
156
+ residual = hidden_states
157
+
158
+ hidden_states = self.norm(hidden_states)
159
+ inner_dim = hidden_states.shape[1]
160
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(
161
+ batch, height * weight, inner_dim
162
+ )
163
+ hidden_states = self.proj_in(hidden_states)
164
+
165
+ # Transformer Blocks
166
+ for block in self.transformer_blocks:
167
+ hidden_states = block(
168
+ hidden_states,
169
+ encoder_hidden_states=encoder_hidden_states,
170
+ video_length=video_length,
171
+ )
172
+
173
+ # output
174
+ hidden_states = self.proj_out(hidden_states)
175
+ hidden_states = (
176
+ hidden_states.reshape(batch, height, weight, inner_dim)
177
+ .permute(0, 3, 1, 2)
178
+ .contiguous()
179
+ )
180
+
181
+ output = hidden_states + residual
182
+ output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length)
183
+
184
+ return output
185
+
186
+
187
+ @maybe_allow_in_graph
188
+ class TemporalTransformerBlock(nn.Module):
189
+ def __init__(
190
+ self,
191
+ dim: int,
192
+ num_attention_heads: int,
193
+ attention_head_dim: int,
194
+ attention_block_types=(
195
+ "Temporal_Self",
196
+ "Temporal_Self",
197
+ ),
198
+ dropout=0.0,
199
+ norm_num_groups: int = 32,
200
+ cross_attention_dim: int = 768,
201
+ activation_fn: str = "geglu",
202
+ attention_bias: bool = False,
203
+ upcast_attention: bool = False,
204
+ cross_frame_attention_mode=None,
205
+ temporal_position_encoding: bool = False,
206
+ temporal_position_encoding_max_len: int = 24,
207
+ ):
208
+ super().__init__()
209
+
210
+ attention_blocks = []
211
+ norms = []
212
+
213
+ for block_name in attention_block_types:
214
+ attention_blocks.append(
215
+ VersatileAttention(
216
+ attention_mode=block_name.split("_")[0],
217
+ cross_attention_dim=(
218
+ cross_attention_dim if block_name.endswith("_Cross") else None
219
+ ),
220
+ query_dim=dim,
221
+ heads=num_attention_heads,
222
+ dim_head=attention_head_dim,
223
+ dropout=dropout,
224
+ bias=attention_bias,
225
+ upcast_attention=upcast_attention,
226
+ cross_frame_attention_mode=cross_frame_attention_mode,
227
+ temporal_position_encoding=temporal_position_encoding,
228
+ temporal_position_encoding_max_len=temporal_position_encoding_max_len,
229
+ )
230
+ )
231
+ norms.append(nn.LayerNorm(dim))
232
+
233
+ self.attention_blocks = nn.ModuleList(attention_blocks)
234
+ self.norms = nn.ModuleList(norms)
235
+
236
+ self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn)
237
+ self.ff_norm = nn.LayerNorm(dim)
238
+
239
+ def forward(
240
+ self,
241
+ hidden_states,
242
+ encoder_hidden_states=None,
243
+ attention_mask=None,
244
+ video_length=None,
245
+ ):
246
+ for attention_block, norm in zip(self.attention_blocks, self.norms):
247
+ norm_hidden_states = norm(hidden_states)
248
+ hidden_states = (
249
+ attention_block(
250
+ norm_hidden_states,
251
+ encoder_hidden_states=(
252
+ encoder_hidden_states
253
+ if attention_block.is_cross_attention
254
+ else None
255
+ ),
256
+ video_length=video_length,
257
+ )
258
+ + hidden_states
259
+ )
260
+
261
+ hidden_states = self.ff(self.ff_norm(hidden_states)) + hidden_states
262
+
263
+ output = hidden_states
264
+ return output
265
+
266
+
267
+ class PositionalEncoding(nn.Module):
268
+ def __init__(self, d_model, dropout: float = 0.0, max_len: int = 24):
269
+ super().__init__()
270
+ self.dropout: nn.Module = nn.Dropout(p=dropout)
271
+ position = torch.arange(max_len).unsqueeze(1)
272
+ div_term = torch.exp(
273
+ torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model)
274
+ )
275
+ pe: Tensor = torch.zeros(1, max_len, d_model)
276
+ pe[0, :, 0::2] = torch.sin(position * div_term)
277
+ pe[0, :, 1::2] = torch.cos(position * div_term)
278
+ self.register_buffer("pe", pe)
279
+
280
+ def forward(self, x: Tensor):
281
+ x = x + self.pe[:, : x.size(1)]
282
+ return self.dropout(x)
283
+
284
+
285
+ @maybe_allow_in_graph
286
+ class VersatileAttention(Attention):
287
+ def __init__(
288
+ self,
289
+ attention_mode: str = None,
290
+ cross_frame_attention_mode: Optional[str] = None,
291
+ temporal_position_encoding: bool = False,
292
+ temporal_position_encoding_max_len: int = 24,
293
+ *args,
294
+ **kwargs,
295
+ ):
296
+ super().__init__(*args, **kwargs)
297
+ if attention_mode.lower() != "temporal":
298
+ raise ValueError(f"Attention mode {attention_mode} is not supported.")
299
+
300
+ self.attention_mode = attention_mode
301
+ self.is_cross_attention = kwargs["cross_attention_dim"] is not None
302
+
303
+ self.pos_encoder = (
304
+ PositionalEncoding(
305
+ kwargs["query_dim"],
306
+ dropout=0.0,
307
+ max_len=temporal_position_encoding_max_len,
308
+ )
309
+ if (temporal_position_encoding and attention_mode == "Temporal")
310
+ else None
311
+ )
312
+
313
+ def extra_repr(self):
314
+ return f"(Module Info) Attention_Mode: {self.attention_mode}, Is_Cross_Attention: {self.is_cross_attention}"
315
+
316
+ def forward(
317
+ self,
318
+ hidden_states: Tensor,
319
+ encoder_hidden_states=None,
320
+ attention_mask=None,
321
+ video_length=None,
322
+ ):
323
+ if self.attention_mode == "Temporal":
324
+ d = hidden_states.shape[1]
325
+ hidden_states = rearrange(
326
+ hidden_states, "(b f) d c -> (b d) f c", f=video_length
327
+ )
328
+
329
+ if self.pos_encoder is not None:
330
+ hidden_states = self.pos_encoder(hidden_states)
331
+
332
+ encoder_hidden_states = (
333
+ repeat(encoder_hidden_states, "b n c -> (b d) n c", d=d)
334
+ if encoder_hidden_states is not None
335
+ else encoder_hidden_states
336
+ )
337
+ else:
338
+ raise NotImplementedError
339
+
340
+ # attention processor makes this easy so that's nice
341
+ hidden_states = self.processor(
342
+ self, hidden_states, encoder_hidden_states, attention_mask
343
+ )
344
+
345
+ if self.attention_mode == "Temporal":
346
+ hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d)
347
+
348
+ return hidden_states
349
+
350
+ def set_use_memory_efficient_attention_xformers(
351
+ self,
352
+ use_memory_efficient_attention_xformers: bool,
353
+ attention_op: Optional[Callable] = None,
354
+ ):
355
+ return None
modules/resnet.py ADDED
@@ -0,0 +1,261 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/resnet.py
2
+
3
+ from typing import Optional
4
+
5
+ import torch
6
+ import torch.nn.functional as F
7
+ from einops import rearrange
8
+ from torch import Tensor, nn
9
+
10
+
11
+ class InflatedConv3d(nn.Conv2d):
12
+ def forward(self, x: Tensor) -> Tensor:
13
+ ori_dim = x.ndim
14
+ if ori_dim == 5:
15
+ frames = x.shape[2]
16
+ x = rearrange(x, "b c f h w -> (b f) c h w")
17
+ x = F.conv2d(
18
+ x,
19
+ self.weight,
20
+ self.bias,
21
+ self.stride,
22
+ self.padding,
23
+ self.dilation,
24
+ self.groups,
25
+ )
26
+ if ori_dim == 5:
27
+ x = rearrange(x, "(b f) c h w -> b c f h w", f=frames)
28
+ return x
29
+
30
+
31
+ class InflatedGroupNorm(nn.GroupNorm):
32
+ def forward(self, x):
33
+ video_length = x.shape[2]
34
+
35
+ x = rearrange(x, "b c f h w -> (b f) c h w")
36
+ x = super().forward(x)
37
+ x = rearrange(x, "(b f) c h w -> b c f h w", f=video_length)
38
+
39
+ return x
40
+
41
+
42
+ class Upsample3D(nn.Module):
43
+ def __init__(
44
+ self,
45
+ channels: int,
46
+ use_conv: bool = False,
47
+ use_conv_transpose: bool = False,
48
+ out_channels: Optional[int] = None,
49
+ name="conv",
50
+ ):
51
+ super().__init__()
52
+ self.channels = channels
53
+ self.out_channels = out_channels or channels
54
+ self.use_conv = use_conv
55
+ self.use_conv_transpose = use_conv_transpose
56
+ self.name = name
57
+
58
+ if use_conv_transpose:
59
+ raise NotImplementedError
60
+ elif use_conv:
61
+ self.conv = InflatedConv3d(self.channels, self.out_channels, 3, padding=1)
62
+
63
+ def forward(self, hidden_states: Tensor, output_size=None):
64
+ assert hidden_states.shape[1] == self.channels
65
+
66
+ if self.use_conv_transpose:
67
+ raise NotImplementedError
68
+
69
+ # Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16
70
+ dtype = hidden_states.dtype
71
+ if dtype == torch.bfloat16:
72
+ hidden_states = hidden_states.to(torch.float32)
73
+
74
+ # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
75
+ if hidden_states.shape[0] >= 64:
76
+ hidden_states = hidden_states.contiguous()
77
+
78
+ # if `output_size` is passed we force the interpolation output
79
+ # size and do not make use of `scale_factor=2`
80
+ if output_size is None:
81
+ hidden_states = F.interpolate(
82
+ hidden_states, scale_factor=[1.0, 2.0, 2.0], mode="nearest"
83
+ )
84
+ else:
85
+ hidden_states = F.interpolate(
86
+ hidden_states, size=output_size, mode="nearest"
87
+ )
88
+
89
+ # If the input is bfloat16, we cast back to bfloat16
90
+ if dtype == torch.bfloat16:
91
+ hidden_states = hidden_states.to(dtype)
92
+
93
+ hidden_states = self.conv(hidden_states)
94
+
95
+ return hidden_states
96
+
97
+
98
+ class Downsample3D(nn.Module):
99
+ def __init__(
100
+ self,
101
+ channels: int,
102
+ use_conv: bool = False,
103
+ out_channels: Optional[int] = None,
104
+ padding: int = 1,
105
+ name="conv",
106
+ ):
107
+ super().__init__()
108
+ self.channels = channels
109
+ self.out_channels = out_channels or channels
110
+ self.use_conv = use_conv
111
+ self.padding = padding
112
+ stride = 2
113
+ self.name = name
114
+
115
+ if use_conv:
116
+ self.conv = InflatedConv3d(
117
+ self.channels, self.out_channels, 3, stride=stride, padding=padding
118
+ )
119
+ else:
120
+ raise NotImplementedError
121
+
122
+ def forward(self, hidden_states):
123
+ assert hidden_states.shape[1] == self.channels
124
+ if self.use_conv and self.padding == 0:
125
+ raise NotImplementedError
126
+
127
+ assert hidden_states.shape[1] == self.channels
128
+ hidden_states = self.conv(hidden_states)
129
+
130
+ return hidden_states
131
+
132
+
133
+ class ResnetBlock3D(nn.Module):
134
+ def __init__(
135
+ self,
136
+ *,
137
+ in_channels,
138
+ out_channels=None,
139
+ conv_shortcut=False,
140
+ dropout=0.0,
141
+ temb_channels=512,
142
+ groups=32,
143
+ groups_out=None,
144
+ pre_norm=True,
145
+ eps=1e-6,
146
+ non_linearity="swish",
147
+ time_embedding_norm="default",
148
+ output_scale_factor=1.0,
149
+ use_in_shortcut=None,
150
+ use_inflated_groupnorm=None,
151
+ ):
152
+ super().__init__()
153
+ self.pre_norm = pre_norm
154
+ self.pre_norm = True
155
+ self.in_channels = in_channels
156
+ out_channels = in_channels if out_channels is None else out_channels
157
+ self.out_channels = out_channels
158
+ self.use_conv_shortcut = conv_shortcut
159
+ self.time_embedding_norm = time_embedding_norm
160
+ self.output_scale_factor = output_scale_factor
161
+
162
+ if groups_out is None:
163
+ groups_out = groups
164
+
165
+ assert use_inflated_groupnorm != None
166
+ if use_inflated_groupnorm:
167
+ self.norm1 = InflatedGroupNorm(
168
+ num_groups=groups, num_channels=in_channels, eps=eps, affine=True
169
+ )
170
+ else:
171
+ self.norm1 = nn.GroupNorm(
172
+ num_groups=groups, num_channels=in_channels, eps=eps, affine=True
173
+ )
174
+
175
+ self.conv1 = InflatedConv3d(
176
+ in_channels, out_channels, kernel_size=3, stride=1, padding=1
177
+ )
178
+
179
+ if temb_channels is not None:
180
+ if self.time_embedding_norm == "default":
181
+ time_emb_proj_out_channels = out_channels
182
+ elif self.time_embedding_norm == "scale_shift":
183
+ time_emb_proj_out_channels = out_channels * 2
184
+ else:
185
+ raise ValueError(
186
+ f"unknown time_embedding_norm : {self.time_embedding_norm} "
187
+ )
188
+
189
+ self.time_emb_proj = nn.Linear(temb_channels, time_emb_proj_out_channels)
190
+ else:
191
+ self.time_emb_proj = None
192
+
193
+ if use_inflated_groupnorm:
194
+ self.norm2 = InflatedGroupNorm(
195
+ num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True
196
+ )
197
+ else:
198
+ self.norm2 = nn.GroupNorm(
199
+ num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True
200
+ )
201
+
202
+ self.dropout = nn.Dropout(dropout)
203
+ self.conv2 = InflatedConv3d(
204
+ out_channels, out_channels, kernel_size=3, stride=1, padding=1
205
+ )
206
+
207
+ if non_linearity == "swish":
208
+ self.nonlinearity = lambda x: F.silu(x)
209
+ elif non_linearity == "mish":
210
+ self.nonlinearity = Mish()
211
+ elif non_linearity == "silu":
212
+ self.nonlinearity = nn.SiLU()
213
+
214
+ self.use_in_shortcut = (
215
+ self.in_channels != self.out_channels
216
+ if use_in_shortcut is None
217
+ else use_in_shortcut
218
+ )
219
+
220
+ self.conv_shortcut = None
221
+ if self.use_in_shortcut:
222
+ self.conv_shortcut = InflatedConv3d(
223
+ in_channels, out_channels, kernel_size=1, stride=1, padding=0
224
+ )
225
+
226
+ def forward(self, input_tensor, temb):
227
+ hidden_states = input_tensor
228
+
229
+ hidden_states = self.norm1(hidden_states)
230
+ hidden_states = self.nonlinearity(hidden_states)
231
+
232
+ hidden_states = self.conv1(hidden_states)
233
+
234
+ if temb is not None:
235
+ temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None, None]
236
+
237
+ if temb is not None and self.time_embedding_norm == "default":
238
+ hidden_states = hidden_states + temb
239
+
240
+ hidden_states = self.norm2(hidden_states)
241
+
242
+ if temb is not None and self.time_embedding_norm == "scale_shift":
243
+ scale, shift = torch.chunk(temb, 2, dim=1)
244
+ hidden_states = hidden_states * (1 + scale) + shift
245
+
246
+ hidden_states = self.nonlinearity(hidden_states)
247
+
248
+ hidden_states = self.dropout(hidden_states)
249
+ hidden_states = self.conv2(hidden_states)
250
+
251
+ if self.conv_shortcut is not None:
252
+ input_tensor = self.conv_shortcut(input_tensor)
253
+
254
+ output_tensor = (input_tensor + hidden_states) / self.output_scale_factor
255
+
256
+ return output_tensor
257
+
258
+
259
+ class Mish(nn.Module):
260
+ def forward(self, hidden_states):
261
+ return hidden_states * torch.tanh(torch.nn.functional.softplus(hidden_states))
modules/unet.py ADDED
@@ -0,0 +1,591 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unet_2d_condition.py
2
+
3
+ import json
4
+ import os
5
+ from dataclasses import dataclass
6
+ from os import PathLike
7
+ from pathlib import Path
8
+ from typing import Any, Dict, List, Optional, Tuple, Union
9
+
10
+ import torch
11
+ import torch.utils.checkpoint
12
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
13
+ from diffusers.loaders import UNet2DConditionLoadersMixin, PeftAdapterMixin
14
+ from diffusers.models import ModelMixin
15
+ from diffusers.models.attention_processor import AttentionProcessor
16
+ from diffusers.models.embeddings import TimestepEmbedding, Timesteps
17
+ from diffusers.utils import SAFETENSORS_WEIGHTS_NAME, WEIGHTS_NAME, BaseOutput, logging
18
+ from safetensors.torch import load_file
19
+ from torch import Tensor, nn
20
+
21
+ from .resnet import InflatedConv3d, InflatedGroupNorm
22
+ from .unet_blocks import (
23
+ CrossAttnDownBlock3D,
24
+ CrossAttnUpBlock3D,
25
+ DownBlock3D,
26
+ UNetMidBlock3DCrossAttn,
27
+ UpBlock3D,
28
+ get_down_block,
29
+ get_up_block,
30
+ )
31
+
32
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
33
+
34
+
35
+ @dataclass
36
+ class UNet3DConditionFlowModelOutput(BaseOutput):
37
+ sample: torch.FloatTensor
38
+
39
+
40
+
41
+ class UNet3DConditionFlowModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, PeftAdapterMixin):
42
+ _supports_gradient_checkpointing = True
43
+
44
+ @register_to_config
45
+ def __init__(
46
+ self,
47
+ sample_size: Optional[int] = None,
48
+ in_channels: int = 4,
49
+ out_channels: int = 4,
50
+ center_input_sample: bool = False,
51
+ flip_sin_to_cos: bool = True,
52
+ freq_shift: int = 0,
53
+ down_block_types: Tuple[str] = (
54
+ "CrossAttnDownBlock3D",
55
+ "CrossAttnDownBlock3D",
56
+ "CrossAttnDownBlock3D",
57
+ "DownBlock3D",
58
+ ),
59
+ mid_block_type: str = "UNetMidBlock3DCrossAttn",
60
+ up_block_types: Tuple[str] = (
61
+ "UpBlock3D",
62
+ "CrossAttnUpBlock3D",
63
+ "CrossAttnUpBlock3D",
64
+ "CrossAttnUpBlock3D"
65
+ ),
66
+ only_cross_attention: Union[bool, Tuple[bool]] = False,
67
+ block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
68
+ layers_per_block: int = 2,
69
+ downsample_padding: int = 1,
70
+ mid_block_scale_factor: float = 1,
71
+ act_fn: str = "silu",
72
+ norm_num_groups: int = 32,
73
+ norm_eps: float = 1e-5,
74
+ cross_attention_dim: int = 1280,
75
+ attention_head_dim: Union[int, Tuple[int]] = 8,
76
+ dual_cross_attention: bool = False,
77
+ use_linear_projection: bool = False,
78
+ class_embed_type: Optional[str] = None,
79
+ num_class_embeds: Optional[int] = None,
80
+ upcast_attention: bool = False,
81
+ resnet_time_scale_shift: str = "default",
82
+
83
+ use_inflated_groupnorm=False,
84
+
85
+ # Additional
86
+ use_motion_module = False,
87
+ motion_module_resolutions = ( 1,2,4,8 ),
88
+ motion_module_mid_block = False,
89
+ motion_module_decoder_only = False,
90
+ motion_module_type = None,
91
+ motion_module_kwargs = {},
92
+ unet_use_cross_frame_attention = False,
93
+ unet_use_temporal_attention = False,
94
+ ):
95
+ super().__init__()
96
+
97
+ self.sample_size = sample_size
98
+ time_embed_dim = block_out_channels[0] * 4
99
+
100
+ # input
101
+ self.conv_in = InflatedConv3d(in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1))
102
+
103
+ # time
104
+ self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
105
+ timestep_input_dim = block_out_channels[0]
106
+
107
+ self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
108
+
109
+ # class embedding
110
+ if class_embed_type is None and num_class_embeds is not None:
111
+ self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
112
+ elif class_embed_type == "timestep":
113
+ self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
114
+ elif class_embed_type == "identity":
115
+ self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
116
+ else:
117
+ self.class_embedding = None
118
+
119
+ self.down_blocks = nn.ModuleList([])
120
+ self.mid_block = None
121
+ self.up_blocks = nn.ModuleList([])
122
+
123
+ if isinstance(only_cross_attention, bool):
124
+ only_cross_attention = [only_cross_attention] * len(down_block_types)
125
+
126
+ if isinstance(attention_head_dim, int):
127
+ attention_head_dim = (attention_head_dim,) * len(down_block_types)
128
+
129
+ # down
130
+ output_channel = block_out_channels[0]
131
+ for i, down_block_type in enumerate(down_block_types):
132
+ res = 2 ** i
133
+ input_channel = output_channel
134
+ output_channel = block_out_channels[i]
135
+ is_final_block = i == len(block_out_channels) - 1
136
+
137
+ down_block = get_down_block(
138
+ down_block_type,
139
+ num_layers=layers_per_block,
140
+ in_channels=input_channel,
141
+ out_channels=output_channel,
142
+ temb_channels=time_embed_dim,
143
+ add_downsample=not is_final_block,
144
+ resnet_eps=norm_eps,
145
+ resnet_act_fn=act_fn,
146
+ resnet_groups=norm_num_groups,
147
+ cross_attention_dim=cross_attention_dim,
148
+ attn_num_head_channels=attention_head_dim[i],
149
+ downsample_padding=downsample_padding,
150
+ dual_cross_attention=dual_cross_attention,
151
+ use_linear_projection=use_linear_projection,
152
+ only_cross_attention=only_cross_attention[i],
153
+ upcast_attention=upcast_attention,
154
+ resnet_time_scale_shift=resnet_time_scale_shift,
155
+
156
+ unet_use_cross_frame_attention=unet_use_cross_frame_attention,
157
+ unet_use_temporal_attention=unet_use_temporal_attention,
158
+ use_inflated_groupnorm=use_inflated_groupnorm,
159
+
160
+ use_motion_module=use_motion_module and (res in motion_module_resolutions) and (not motion_module_decoder_only),
161
+ motion_module_type=motion_module_type,
162
+ motion_module_kwargs=motion_module_kwargs,
163
+ )
164
+ self.down_blocks.append(down_block)
165
+
166
+ # mid
167
+ if mid_block_type == "UNetMidBlock3DCrossAttn":
168
+ self.mid_block = UNetMidBlock3DCrossAttn(
169
+ in_channels=block_out_channels[-1],
170
+ temb_channels=time_embed_dim,
171
+ resnet_eps=norm_eps,
172
+ resnet_act_fn=act_fn,
173
+ output_scale_factor=mid_block_scale_factor,
174
+ resnet_time_scale_shift=resnet_time_scale_shift,
175
+ cross_attention_dim=cross_attention_dim,
176
+ attn_num_head_channels=attention_head_dim[-1],
177
+ resnet_groups=norm_num_groups,
178
+ dual_cross_attention=dual_cross_attention,
179
+ use_linear_projection=use_linear_projection,
180
+ upcast_attention=upcast_attention,
181
+
182
+ unet_use_cross_frame_attention=unet_use_cross_frame_attention,
183
+ unet_use_temporal_attention=unet_use_temporal_attention,
184
+ use_inflated_groupnorm=use_inflated_groupnorm,
185
+
186
+ use_motion_module=use_motion_module and motion_module_mid_block,
187
+ motion_module_type=motion_module_type,
188
+ motion_module_kwargs=motion_module_kwargs,
189
+ )
190
+ else:
191
+ raise ValueError(f"unknown mid_block_type : {mid_block_type}")
192
+
193
+ # count how many layers upsample the videos
194
+ self.num_upsamplers = 0
195
+
196
+ # up
197
+ reversed_block_out_channels = list(reversed(block_out_channels))
198
+ reversed_attention_head_dim = list(reversed(attention_head_dim))
199
+ only_cross_attention = list(reversed(only_cross_attention))
200
+ output_channel = reversed_block_out_channels[0]
201
+ for i, up_block_type in enumerate(up_block_types):
202
+ res = 2 ** (3 - i)
203
+ is_final_block = i == len(block_out_channels) - 1
204
+
205
+ prev_output_channel = output_channel
206
+ output_channel = reversed_block_out_channels[i]
207
+ input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
208
+
209
+ # add upsample block for all BUT final layer
210
+ if not is_final_block:
211
+ add_upsample = True
212
+ self.num_upsamplers += 1
213
+ else:
214
+ add_upsample = False
215
+
216
+ up_block = get_up_block(
217
+ up_block_type,
218
+ num_layers=layers_per_block + 1,
219
+ in_channels=input_channel,
220
+ out_channels=output_channel,
221
+ prev_output_channel=prev_output_channel,
222
+ temb_channels=time_embed_dim,
223
+ add_upsample=add_upsample,
224
+ resnet_eps=norm_eps,
225
+ resnet_act_fn=act_fn,
226
+ resnet_groups=norm_num_groups,
227
+ cross_attention_dim=cross_attention_dim,
228
+ attn_num_head_channels=reversed_attention_head_dim[i],
229
+ dual_cross_attention=dual_cross_attention,
230
+ use_linear_projection=use_linear_projection,
231
+ only_cross_attention=only_cross_attention[i],
232
+ upcast_attention=upcast_attention,
233
+ resnet_time_scale_shift=resnet_time_scale_shift,
234
+
235
+ unet_use_cross_frame_attention=unet_use_cross_frame_attention,
236
+ unet_use_temporal_attention=unet_use_temporal_attention,
237
+ use_inflated_groupnorm=use_inflated_groupnorm,
238
+
239
+ use_motion_module=use_motion_module and (res in motion_module_resolutions),
240
+ motion_module_type=motion_module_type,
241
+ motion_module_kwargs=motion_module_kwargs,
242
+ )
243
+ self.up_blocks.append(up_block)
244
+ prev_output_channel = output_channel
245
+
246
+ # out
247
+ if use_inflated_groupnorm:
248
+ self.conv_norm_out = InflatedGroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps)
249
+ else:
250
+ self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps)
251
+ self.conv_act = nn.SiLU()
252
+ self.conv_out = InflatedConv3d(block_out_channels[0], out_channels, kernel_size=3, padding=1)
253
+
254
+
255
+ def set_attention_slice(self, slice_size):
256
+ r"""
257
+ Enable sliced attention computation.
258
+
259
+ When this option is enabled, the attention module will split the input tensor in slices, to compute attention
260
+ in several steps. This is useful to save some memory in exchange for a small speed decrease.
261
+
262
+ Args:
263
+ slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
264
+ When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
265
+ `"max"`, maxium amount of memory will be saved by running only one slice at a time. If a number is
266
+ provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
267
+ must be a multiple of `slice_size`.
268
+ """
269
+ sliceable_head_dims = []
270
+
271
+ def fn_recursive_retrieve_slicable_dims(module: torch.nn.Module):
272
+ if hasattr(module, "set_attention_slice"):
273
+ sliceable_head_dims.append(module.sliceable_head_dim)
274
+
275
+ for child in module.children():
276
+ fn_recursive_retrieve_slicable_dims(child)
277
+
278
+ # retrieve number of attention layers
279
+ for module in self.children():
280
+ fn_recursive_retrieve_slicable_dims(module)
281
+
282
+ num_slicable_layers = len(sliceable_head_dims)
283
+
284
+ if slice_size == "auto":
285
+ # half the attention head size is usually a good trade-off between
286
+ # speed and memory
287
+ slice_size = [dim // 2 for dim in sliceable_head_dims]
288
+ elif slice_size == "max":
289
+ # make smallest slice possible
290
+ slice_size = num_slicable_layers * [1]
291
+
292
+ slice_size = num_slicable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
293
+
294
+ if len(slice_size) != len(sliceable_head_dims):
295
+ raise ValueError(
296
+ f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
297
+ f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
298
+ )
299
+
300
+ for i in range(len(slice_size)):
301
+ size = slice_size[i]
302
+ dim = sliceable_head_dims[i]
303
+ if size is not None and size > dim:
304
+ raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
305
+
306
+ # Recursively walk through all the children.
307
+ # Any children which exposes the set_attention_slice method
308
+ # gets the message
309
+ def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
310
+ if hasattr(module, "set_attention_slice"):
311
+ module.set_attention_slice(slice_size.pop())
312
+
313
+ for child in module.children():
314
+ fn_recursive_set_attention_slice(child, slice_size)
315
+
316
+ reversed_slice_size = list(reversed(slice_size))
317
+ for module in self.children():
318
+ fn_recursive_set_attention_slice(module, reversed_slice_size)
319
+
320
+
321
+ def _set_gradient_checkpointing(self, module, value=False):
322
+ if isinstance(module, (CrossAttnDownBlock3D, DownBlock3D, CrossAttnUpBlock3D, UpBlock3D)):
323
+ module.gradient_checkpointing = value
324
+
325
+
326
+ def get_image_controlnet(self, controlnet_noisy_latents, timesteps,
327
+ encoder_hidden_states=None,
328
+ controlnet_cond=None,
329
+ conditioning_mask=None,
330
+ conditioning_scale=None,
331
+ guess_mode=False,
332
+ return_dict=False,):
333
+ down_block_additional_residuals, mid_block_additional_residual = self.image_controlnet(
334
+ controlnet_noisy_latents, timesteps,
335
+ encoder_hidden_states=encoder_hidden_states,
336
+ controlnet_cond=controlnet_cond,
337
+ conditioning_mask=conditioning_mask,
338
+ conditioning_scale=conditioning_scale,
339
+ guess_mode=guess_mode,
340
+ return_dict=return_dict,
341
+ )
342
+ return down_block_additional_residuals, mid_block_additional_residual
343
+
344
+
345
+ def get_flow_controlnet(self, controlnet_noisy_latents, timesteps,
346
+ encoder_hidden_states=None,
347
+ controlnet_cond=None,
348
+ conditioning_mask=None,
349
+ conditioning_scale=None,
350
+ guess_mode=False,
351
+ return_dict=False,):
352
+ down_block_additional_residuals, mid_block_additional_residual = self.omcm_controlnet(
353
+ controlnet_noisy_latents, timesteps,
354
+ encoder_hidden_states=encoder_hidden_states,
355
+ controlnet_cond=controlnet_cond,
356
+ conditioning_mask=conditioning_mask,
357
+ conditioning_scale=conditioning_scale,
358
+ guess_mode=guess_mode,
359
+ return_dict=return_dict,
360
+ )
361
+ return down_block_additional_residuals, mid_block_additional_residual
362
+
363
+
364
+ def forward(
365
+ self,
366
+ sample: torch.FloatTensor,
367
+ timestep: Union[torch.Tensor, float, int],
368
+ encoder_hidden_states: torch.Tensor,
369
+ class_labels: Optional[torch.Tensor] = None,
370
+ attention_mask: Optional[torch.Tensor] = None,
371
+
372
+ # support image controlnet
373
+ down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
374
+ mid_block_additional_residual: Optional[torch.Tensor] = None,
375
+
376
+ # support flow controlnet
377
+ flow_down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
378
+ flow_mid_block_additional_residual: Optional[torch.Tensor] = None,
379
+
380
+ return_dict: bool = True,
381
+ ) -> Union[UNet3DConditionFlowModelOutput, Tuple]:
382
+ r"""
383
+ Args:
384
+ sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor
385
+ timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps
386
+ encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states
387
+ return_dict (`bool`, *optional*, defaults to `True`):
388
+ Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
389
+
390
+ Returns:
391
+ [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
392
+ [`~models.unet_2d_condition.UNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When
393
+ returning a tuple, the first element is the sample tensor.
394
+ """
395
+ # By default samples have to be AT least a multiple of the overall upsampling factor.
396
+ # The overall upsampling factor is equal to 2 ** (# num of upsampling layears).
397
+ # However, the upsampling interpolation output size can be forced to fit any upsampling size
398
+ # on the fly if necessary.
399
+
400
+ default_overall_up_factor = 2**self.num_upsamplers
401
+
402
+ # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
403
+ forward_upsample_size = False
404
+ upsample_size = None
405
+
406
+ if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
407
+ logger.info("Forward upsample size to force interpolation output size.")
408
+ forward_upsample_size = True
409
+
410
+ # prepare attention_mask
411
+ if attention_mask is not None:
412
+ attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
413
+ attention_mask = attention_mask.unsqueeze(1)
414
+
415
+ # center input if necessary
416
+ if self.config.center_input_sample:
417
+ sample = 2 * sample - 1.0
418
+
419
+ # time
420
+ timesteps = timestep
421
+ if not torch.is_tensor(timesteps):
422
+ # This would be a good case for the `match` statement (Python 3.10+)
423
+ is_mps = sample.device.type == "mps"
424
+ if isinstance(timestep, float):
425
+ dtype = torch.float32 if is_mps else torch.float64
426
+ else:
427
+ dtype = torch.int32 if is_mps else torch.int64
428
+ timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
429
+ elif len(timesteps.shape) == 0:
430
+ timesteps = timesteps[None].to(sample.device)
431
+
432
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
433
+ timesteps = timesteps.expand(sample.shape[0])
434
+
435
+ t_emb = self.time_proj(timesteps)
436
+
437
+ # timesteps does not contain any weights and will always return f32 tensors
438
+ # but time_embedding might actually be running in fp16. so we need to cast here.
439
+ # there might be better ways to encapsulate this.
440
+ t_emb = t_emb.to(dtype=self.dtype)
441
+ emb = self.time_embedding(t_emb)
442
+
443
+ if self.class_embedding is not None:
444
+ if class_labels is None:
445
+ raise ValueError("class_labels should be provided when num_class_embeds > 0")
446
+
447
+ if self.config.class_embed_type == "timestep":
448
+ class_labels = self.time_proj(class_labels)
449
+
450
+ class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
451
+ emb = emb + class_emb
452
+
453
+ # pre-process
454
+ sample = self.conv_in(sample)
455
+
456
+ # down
457
+ down_block_res_samples = (sample,)
458
+ for downsample_block in self.down_blocks:
459
+ if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
460
+ sample, res_samples = downsample_block(
461
+ hidden_states=sample,
462
+ temb=emb,
463
+ encoder_hidden_states=encoder_hidden_states,
464
+ attention_mask=attention_mask,
465
+ )
466
+ else:
467
+ sample, res_samples = downsample_block(hidden_states=sample, temb=emb, encoder_hidden_states=encoder_hidden_states)
468
+
469
+ down_block_res_samples += res_samples
470
+
471
+ # support controlnet
472
+ # image controlnet
473
+ down_block_res_samples = list(down_block_res_samples)
474
+ if down_block_additional_residuals is not None:
475
+ for i, down_block_additional_residual in enumerate(down_block_additional_residuals):
476
+ if down_block_additional_residual.dim() == 4: # boardcast
477
+ down_block_additional_residual = down_block_additional_residual.unsqueeze(2)
478
+ down_block_res_samples[i] = down_block_res_samples[i] + down_block_additional_residual
479
+
480
+ # flow controlnet
481
+ if flow_down_block_additional_residuals is not None:
482
+ for i, down_block_additional_residual in enumerate(flow_down_block_additional_residuals):
483
+ if down_block_additional_residual.dim() == 4: # boardcast
484
+ down_block_additional_residual = down_block_additional_residual.unsqueeze(2)
485
+ down_block_res_samples[i] = down_block_res_samples[i] + down_block_additional_residual
486
+
487
+
488
+ # mid
489
+ sample = self.mid_block(
490
+ sample, emb, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask
491
+ )
492
+
493
+ # support controlnet
494
+ # image controlnet
495
+ if mid_block_additional_residual is not None:
496
+ if mid_block_additional_residual.dim() == 4: # boardcast
497
+ mid_block_additional_residual = mid_block_additional_residual.unsqueeze(2)
498
+ sample = sample + mid_block_additional_residual
499
+
500
+ # flow controlnet
501
+ if flow_mid_block_additional_residual is not None:
502
+ if flow_mid_block_additional_residual.dim() == 4: # boardcast
503
+ flow_mid_block_additional_residual = flow_mid_block_additional_residual.unsqueeze(2)
504
+ sample = sample + flow_mid_block_additional_residual
505
+
506
+
507
+ # up
508
+ for i, upsample_block in enumerate(self.up_blocks):
509
+ is_final_block = i == len(self.up_blocks) - 1
510
+
511
+ res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
512
+ down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
513
+
514
+ # if we have not reached the final block and need to forward the
515
+ # upsample size, we do it here
516
+ if not is_final_block and forward_upsample_size:
517
+ upsample_size = down_block_res_samples[-1].shape[2:]
518
+
519
+ if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
520
+ sample = upsample_block(
521
+ hidden_states=sample,
522
+ temb=emb,
523
+ res_hidden_states_tuple=res_samples,
524
+ encoder_hidden_states=encoder_hidden_states,
525
+ upsample_size=upsample_size,
526
+ attention_mask=attention_mask,
527
+ )
528
+ else:
529
+ sample = upsample_block(
530
+ hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size, encoder_hidden_states=encoder_hidden_states,
531
+ )
532
+
533
+ # post-process
534
+ sample = self.conv_norm_out(sample)
535
+ sample = self.conv_act(sample)
536
+ sample = self.conv_out(sample)
537
+
538
+ if not return_dict:
539
+ return (sample,)
540
+
541
+ return UNet3DConditionFlowModelOutput(sample=sample)
542
+
543
+ @classmethod
544
+ def from_pretrained_2d(cls, pretrained_model_path, subfolder=None, unet_additional_kwargs=None):
545
+ if subfolder is not None:
546
+ pretrained_model_path = os.path.join(pretrained_model_path, subfolder)
547
+ print(f"loaded 3D unet's pretrained weights from {pretrained_model_path} ...")
548
+
549
+ config_file = os.path.join(pretrained_model_path, 'config.json')
550
+ if not os.path.isfile(config_file):
551
+ raise RuntimeError(f"{config_file} does not exist")
552
+ with open(config_file, "r") as f:
553
+ config = json.load(f)
554
+ config["_class_name"] = cls.__name__
555
+ config["down_block_types"] = [
556
+ "CrossAttnDownBlock3D",
557
+ "CrossAttnDownBlock3D",
558
+ "CrossAttnDownBlock3D",
559
+ "DownBlock3D"
560
+ ]
561
+ config["up_block_types"] = [
562
+ "UpBlock3D",
563
+ "CrossAttnUpBlock3D",
564
+ "CrossAttnUpBlock3D",
565
+ "CrossAttnUpBlock3D"
566
+ ]
567
+
568
+ from diffusers.utils import WEIGHTS_NAME
569
+ model = cls.from_config(config, **unet_additional_kwargs)
570
+ model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME)
571
+ if not os.path.isfile(model_file):
572
+ raise RuntimeError(f"{model_file} does not exist")
573
+ state_dict = torch.load(model_file, map_location="cpu")
574
+
575
+ m, u = model.load_state_dict(state_dict, strict=False)
576
+ print(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};")
577
+
578
+ motion_params = [p.numel() if "motion_modules." in n else 0 for n,p in model.named_parameters()]
579
+ motion_name = [n for n in model.state_dict().keys() if "motion_modules." in n]
580
+
581
+ print(f"### Motion Module Parameters: {sum(motion_params) / 1e6} M")
582
+ print(f"### Motion Module keys: {len(motion_name)}")
583
+
584
+ unnorlmal = []
585
+ for n in m:
586
+ if n not in motion_name:
587
+ unnorlmal.append(n)
588
+
589
+ return model
590
+
591
+ 'motion_modules.' in 'up_blocks.3.motion_modules.2.temporal_transformer.transformer_blocks.0.attention_blocks.1.pos_encoder.pe'
modules/unet_blocks.py ADDED
@@ -0,0 +1,866 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unet_2d_blocks.py
2
+
3
+ from typing import Any, Dict, Optional, Tuple, Union
4
+
5
+ import torch
6
+ from torch import nn
7
+
8
+ from .attention import Transformer3DModel
9
+ from .motion_module import get_motion_module
10
+ from .resnet import Downsample3D, ResnetBlock3D, Upsample3D
11
+
12
+
13
+ def get_down_block(
14
+ down_block_type,
15
+ num_layers,
16
+ in_channels,
17
+ out_channels,
18
+ temb_channels,
19
+ add_downsample,
20
+ resnet_eps,
21
+ resnet_act_fn,
22
+ attn_num_head_channels,
23
+ resnet_groups=None,
24
+ cross_attention_dim=None,
25
+ downsample_padding=None,
26
+ dual_cross_attention=False,
27
+ use_linear_projection=False,
28
+ only_cross_attention=False,
29
+ upcast_attention=False,
30
+ resnet_time_scale_shift="default",
31
+ unet_use_cross_frame_attention=False,
32
+ unet_use_temporal_attention=False,
33
+ use_inflated_groupnorm=False,
34
+ use_motion_module=None,
35
+ motion_module_type=None,
36
+ motion_module_kwargs=None,
37
+ ):
38
+ down_block_type = (
39
+ down_block_type[7:]
40
+ if down_block_type.startswith("UNetRes")
41
+ else down_block_type
42
+ )
43
+ if down_block_type == "DownBlock3D":
44
+ return DownBlock3D(
45
+ num_layers=num_layers,
46
+ in_channels=in_channels,
47
+ out_channels=out_channels,
48
+ temb_channels=temb_channels,
49
+ add_downsample=add_downsample,
50
+ resnet_eps=resnet_eps,
51
+ resnet_act_fn=resnet_act_fn,
52
+ resnet_groups=resnet_groups,
53
+ downsample_padding=downsample_padding,
54
+ resnet_time_scale_shift=resnet_time_scale_shift,
55
+ use_inflated_groupnorm=use_inflated_groupnorm,
56
+ use_motion_module=use_motion_module,
57
+ motion_module_type=motion_module_type,
58
+ motion_module_kwargs=motion_module_kwargs,
59
+ )
60
+ elif down_block_type == "CrossAttnDownBlock3D":
61
+ if cross_attention_dim is None:
62
+ raise ValueError(
63
+ "cross_attention_dim must be specified for CrossAttnDownBlock3D"
64
+ )
65
+ return CrossAttnDownBlock3D(
66
+ num_layers=num_layers,
67
+ in_channels=in_channels,
68
+ out_channels=out_channels,
69
+ temb_channels=temb_channels,
70
+ add_downsample=add_downsample,
71
+ resnet_eps=resnet_eps,
72
+ resnet_act_fn=resnet_act_fn,
73
+ resnet_groups=resnet_groups,
74
+ downsample_padding=downsample_padding,
75
+ cross_attention_dim=cross_attention_dim,
76
+ attn_num_head_channels=attn_num_head_channels,
77
+ dual_cross_attention=dual_cross_attention,
78
+ use_linear_projection=use_linear_projection,
79
+ only_cross_attention=only_cross_attention,
80
+ upcast_attention=upcast_attention,
81
+ resnet_time_scale_shift=resnet_time_scale_shift,
82
+ unet_use_cross_frame_attention=unet_use_cross_frame_attention,
83
+ unet_use_temporal_attention=unet_use_temporal_attention,
84
+ use_inflated_groupnorm=use_inflated_groupnorm,
85
+ use_motion_module=use_motion_module,
86
+ motion_module_type=motion_module_type,
87
+ motion_module_kwargs=motion_module_kwargs,
88
+ )
89
+ raise ValueError(f"{down_block_type} does not exist.")
90
+
91
+
92
+ def get_up_block(
93
+ up_block_type,
94
+ num_layers,
95
+ in_channels,
96
+ out_channels,
97
+ prev_output_channel,
98
+ temb_channels,
99
+ add_upsample,
100
+ resnet_eps,
101
+ resnet_act_fn,
102
+ attn_num_head_channels,
103
+ resnet_groups=None,
104
+ cross_attention_dim=None,
105
+ dual_cross_attention=False,
106
+ use_linear_projection=False,
107
+ only_cross_attention=False,
108
+ upcast_attention=False,
109
+ resnet_time_scale_shift="default",
110
+ unet_use_cross_frame_attention=False,
111
+ unet_use_temporal_attention=False,
112
+ use_inflated_groupnorm=False,
113
+ use_motion_module=None,
114
+ motion_module_type=None,
115
+ motion_module_kwargs=None,
116
+ ):
117
+ up_block_type = (
118
+ up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
119
+ )
120
+ if up_block_type == "UpBlock3D":
121
+ return UpBlock3D(
122
+ num_layers=num_layers,
123
+ in_channels=in_channels,
124
+ out_channels=out_channels,
125
+ prev_output_channel=prev_output_channel,
126
+ temb_channels=temb_channels,
127
+ add_upsample=add_upsample,
128
+ resnet_eps=resnet_eps,
129
+ resnet_act_fn=resnet_act_fn,
130
+ resnet_groups=resnet_groups,
131
+ resnet_time_scale_shift=resnet_time_scale_shift,
132
+ use_inflated_groupnorm=use_inflated_groupnorm,
133
+ use_motion_module=use_motion_module,
134
+ motion_module_type=motion_module_type,
135
+ motion_module_kwargs=motion_module_kwargs,
136
+ )
137
+ elif up_block_type == "CrossAttnUpBlock3D":
138
+ if cross_attention_dim is None:
139
+ raise ValueError(
140
+ "cross_attention_dim must be specified for CrossAttnUpBlock3D"
141
+ )
142
+ return CrossAttnUpBlock3D(
143
+ num_layers=num_layers,
144
+ in_channels=in_channels,
145
+ out_channels=out_channels,
146
+ prev_output_channel=prev_output_channel,
147
+ temb_channels=temb_channels,
148
+ add_upsample=add_upsample,
149
+ resnet_eps=resnet_eps,
150
+ resnet_act_fn=resnet_act_fn,
151
+ resnet_groups=resnet_groups,
152
+ cross_attention_dim=cross_attention_dim,
153
+ attn_num_head_channels=attn_num_head_channels,
154
+ dual_cross_attention=dual_cross_attention,
155
+ use_linear_projection=use_linear_projection,
156
+ only_cross_attention=only_cross_attention,
157
+ upcast_attention=upcast_attention,
158
+ resnet_time_scale_shift=resnet_time_scale_shift,
159
+ unet_use_cross_frame_attention=unet_use_cross_frame_attention,
160
+ unet_use_temporal_attention=unet_use_temporal_attention,
161
+ use_inflated_groupnorm=use_inflated_groupnorm,
162
+ use_motion_module=use_motion_module,
163
+ motion_module_type=motion_module_type,
164
+ motion_module_kwargs=motion_module_kwargs,
165
+ )
166
+ raise ValueError(f"{up_block_type} does not exist.")
167
+
168
+
169
+ class UNetMidBlock3DCrossAttn(nn.Module):
170
+
171
+ def __init__(
172
+ self,
173
+ in_channels: int,
174
+ temb_channels: int,
175
+ dropout: float = 0.0,
176
+ num_layers: int = 1,
177
+ resnet_eps: float = 1e-6,
178
+ resnet_time_scale_shift: str = "default",
179
+ resnet_act_fn: str = "swish",
180
+ resnet_groups: int = 32,
181
+ resnet_pre_norm: bool = True,
182
+ attn_num_head_channels=1,
183
+ output_scale_factor=1.0,
184
+ cross_attention_dim=1280,
185
+ dual_cross_attention=False,
186
+ use_linear_projection=False,
187
+ upcast_attention=False,
188
+ unet_use_cross_frame_attention=False,
189
+ unet_use_temporal_attention=False,
190
+ use_inflated_groupnorm=False,
191
+ use_motion_module=None,
192
+ motion_module_type=None,
193
+ motion_module_kwargs=None,
194
+ ):
195
+ super().__init__()
196
+
197
+ self.has_cross_attention = True
198
+ self.attn_num_head_channels = attn_num_head_channels
199
+ resnet_groups = (
200
+ resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
201
+ )
202
+
203
+ # there is always at least one resnet
204
+ resnets = [
205
+ ResnetBlock3D(
206
+ in_channels=in_channels,
207
+ out_channels=in_channels,
208
+ temb_channels=temb_channels,
209
+ eps=resnet_eps,
210
+ groups=resnet_groups,
211
+ dropout=dropout,
212
+ time_embedding_norm=resnet_time_scale_shift,
213
+ non_linearity=resnet_act_fn,
214
+ output_scale_factor=output_scale_factor,
215
+ pre_norm=resnet_pre_norm,
216
+ use_inflated_groupnorm=use_inflated_groupnorm,
217
+ )
218
+ ]
219
+ attentions = []
220
+ motion_modules = []
221
+
222
+ for _ in range(num_layers):
223
+ if dual_cross_attention:
224
+ raise NotImplementedError
225
+ attentions.append(
226
+ Transformer3DModel(
227
+ attn_num_head_channels,
228
+ in_channels // attn_num_head_channels,
229
+ in_channels=in_channels,
230
+ num_layers=1,
231
+ cross_attention_dim=cross_attention_dim,
232
+ norm_num_groups=resnet_groups,
233
+ use_linear_projection=use_linear_projection,
234
+ upcast_attention=upcast_attention,
235
+ unet_use_cross_frame_attention=unet_use_cross_frame_attention,
236
+ unet_use_temporal_attention=unet_use_temporal_attention,
237
+ )
238
+ )
239
+ motion_modules.append(
240
+ get_motion_module(
241
+ in_channels=in_channels,
242
+ motion_module_type=motion_module_type,
243
+ motion_module_kwargs=motion_module_kwargs,
244
+ )
245
+ if use_motion_module
246
+ else None
247
+ )
248
+ resnets.append(
249
+ ResnetBlock3D(
250
+ in_channels=in_channels,
251
+ out_channels=in_channels,
252
+ temb_channels=temb_channels,
253
+ eps=resnet_eps,
254
+ groups=resnet_groups,
255
+ dropout=dropout,
256
+ time_embedding_norm=resnet_time_scale_shift,
257
+ non_linearity=resnet_act_fn,
258
+ output_scale_factor=output_scale_factor,
259
+ pre_norm=resnet_pre_norm,
260
+ use_inflated_groupnorm=use_inflated_groupnorm,
261
+ )
262
+ )
263
+
264
+ self.attentions = nn.ModuleList(attentions)
265
+ self.resnets = nn.ModuleList(resnets)
266
+ self.motion_modules = nn.ModuleList(motion_modules)
267
+
268
+ def forward(
269
+ self,
270
+ hidden_states: torch.FloatTensor,
271
+ temb: Optional[torch.FloatTensor] = None,
272
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
273
+ attention_mask: Optional[torch.FloatTensor] = None,
274
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
275
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
276
+ ) -> torch.FloatTensor:
277
+ hidden_states = self.resnets[0](hidden_states, temb)
278
+ for attn, resnet, motion_module in zip(
279
+ self.attentions, self.resnets[1:], self.motion_modules
280
+ ):
281
+ hidden_states = attn(
282
+ hidden_states,
283
+ encoder_hidden_states=encoder_hidden_states,
284
+ cross_attention_kwargs=cross_attention_kwargs,
285
+ attention_mask=attention_mask,
286
+ encoder_attention_mask=encoder_attention_mask,
287
+ return_dict=False,
288
+ )[0]
289
+ if motion_module is not None:
290
+ hidden_states = motion_module(
291
+ hidden_states,
292
+ temb,
293
+ encoder_hidden_states=encoder_hidden_states,
294
+ )
295
+ hidden_states = resnet(hidden_states, temb)
296
+
297
+ return hidden_states
298
+
299
+
300
+ class CrossAttnDownBlock3D(nn.Module):
301
+
302
+ def __init__(
303
+ self,
304
+ in_channels: int,
305
+ out_channels: int,
306
+ temb_channels: int,
307
+ dropout: float = 0.0,
308
+ num_layers: int = 1,
309
+ transformer_layers_per_block: int = 1,
310
+ resnet_eps: float = 1e-6,
311
+ resnet_time_scale_shift: str = "default",
312
+ resnet_act_fn: str = "swish",
313
+ resnet_groups: int = 32,
314
+ resnet_pre_norm: bool = True,
315
+ attn_num_head_channels=1,
316
+ cross_attention_dim=1280,
317
+ output_scale_factor=1.0,
318
+ downsample_padding=1,
319
+ add_downsample=True,
320
+ dual_cross_attention=False,
321
+ use_linear_projection=False,
322
+ only_cross_attention=False,
323
+ upcast_attention=False,
324
+ unet_use_cross_frame_attention=False,
325
+ unet_use_temporal_attention=False,
326
+ use_inflated_groupnorm=False,
327
+ use_motion_module=None,
328
+ motion_module_type=None,
329
+ motion_module_kwargs=None,
330
+ ):
331
+ super().__init__()
332
+ resnets = []
333
+ attentions = []
334
+ motion_modules = []
335
+
336
+ self.has_cross_attention = True
337
+ self.attn_num_head_channels = attn_num_head_channels
338
+
339
+ for i in range(num_layers):
340
+ in_channels = in_channels if i == 0 else out_channels
341
+ resnets.append(
342
+ ResnetBlock3D(
343
+ in_channels=in_channels,
344
+ out_channels=out_channels,
345
+ temb_channels=temb_channels,
346
+ eps=resnet_eps,
347
+ groups=resnet_groups,
348
+ dropout=dropout,
349
+ time_embedding_norm=resnet_time_scale_shift,
350
+ non_linearity=resnet_act_fn,
351
+ output_scale_factor=output_scale_factor,
352
+ pre_norm=resnet_pre_norm,
353
+ use_inflated_groupnorm=use_inflated_groupnorm,
354
+ )
355
+ )
356
+ if dual_cross_attention:
357
+ raise NotImplementedError
358
+ attentions.append(
359
+ Transformer3DModel(
360
+ num_attention_heads=attn_num_head_channels,
361
+ attention_head_dim=out_channels // attn_num_head_channels,
362
+ in_channels=out_channels,
363
+ num_layers=transformer_layers_per_block,
364
+ cross_attention_dim=cross_attention_dim,
365
+ norm_num_groups=resnet_groups,
366
+ use_linear_projection=use_linear_projection,
367
+ only_cross_attention=only_cross_attention,
368
+ upcast_attention=upcast_attention,
369
+ unet_use_cross_frame_attention=unet_use_cross_frame_attention,
370
+ unet_use_temporal_attention=unet_use_temporal_attention,
371
+ )
372
+ )
373
+ motion_modules.append(
374
+ get_motion_module(
375
+ in_channels=out_channels,
376
+ motion_module_type=motion_module_type,
377
+ motion_module_kwargs=motion_module_kwargs,
378
+ )
379
+ if use_motion_module
380
+ else None
381
+ )
382
+
383
+ self.attentions = nn.ModuleList(attentions)
384
+ self.resnets = nn.ModuleList(resnets)
385
+ self.motion_modules = nn.ModuleList(motion_modules)
386
+
387
+ if add_downsample:
388
+ self.downsamplers = nn.ModuleList(
389
+ [
390
+ Downsample3D(
391
+ out_channels,
392
+ use_conv=True,
393
+ out_channels=out_channels,
394
+ padding=downsample_padding,
395
+ name="op",
396
+ )
397
+ ]
398
+ )
399
+ else:
400
+ self.downsamplers = None
401
+
402
+ self.gradient_checkpointing = False
403
+
404
+ def forward(
405
+ self,
406
+ hidden_states: torch.FloatTensor,
407
+ temb: Optional[torch.FloatTensor] = None,
408
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
409
+ attention_mask: Optional[torch.FloatTensor] = None,
410
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
411
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
412
+ ) -> torch.FloatTensor:
413
+ output_states = ()
414
+
415
+ for resnet, attn, motion_module in zip(
416
+ self.resnets, self.attentions, self.motion_modules
417
+ ):
418
+ if self.training and self.gradient_checkpointing:
419
+
420
+ def create_custom_forward(module, return_dict=None):
421
+ def custom_forward(*inputs):
422
+ if return_dict is not None:
423
+ return module(*inputs, return_dict=return_dict)
424
+ else:
425
+ return module(*inputs)
426
+
427
+ return custom_forward
428
+
429
+ hidden_states = torch.utils.checkpoint.checkpoint(
430
+ create_custom_forward(resnet), hidden_states, temb
431
+ )
432
+ hidden_states = torch.utils.checkpoint.checkpoint(
433
+ create_custom_forward(attn, return_dict=False),
434
+ hidden_states,
435
+ encoder_hidden_states,
436
+ )[0]
437
+ if motion_module is not None:
438
+ hidden_states = torch.utils.checkpoint.checkpoint(
439
+ create_custom_forward(motion_module),
440
+ hidden_states.requires_grad_(),
441
+ temb,
442
+ encoder_hidden_states,
443
+ )
444
+
445
+ else:
446
+ hidden_states = resnet(hidden_states, temb)
447
+ hidden_states = attn(
448
+ hidden_states,
449
+ encoder_hidden_states=encoder_hidden_states,
450
+ cross_attention_kwargs=cross_attention_kwargs,
451
+ attention_mask=attention_mask,
452
+ encoder_attention_mask=encoder_attention_mask,
453
+ return_dict=False,
454
+ )[0]
455
+ # add motion module
456
+ hidden_states = (
457
+ motion_module(
458
+ hidden_states, temb, encoder_hidden_states=encoder_hidden_states
459
+ )
460
+ if motion_module is not None
461
+ else hidden_states
462
+ )
463
+
464
+ output_states = output_states + (hidden_states,)
465
+
466
+ if self.downsamplers is not None:
467
+ for downsampler in self.downsamplers:
468
+ hidden_states = downsampler(hidden_states)
469
+
470
+ output_states = output_states + (hidden_states,)
471
+
472
+ return hidden_states, output_states
473
+
474
+
475
+ class DownBlock3D(nn.Module):
476
+ def __init__(
477
+ self,
478
+ in_channels: int,
479
+ out_channels: int,
480
+ temb_channels: int,
481
+ dropout: float = 0.0,
482
+ num_layers: int = 1,
483
+ resnet_eps: float = 1e-6,
484
+ resnet_time_scale_shift: str = "default",
485
+ resnet_act_fn: str = "swish",
486
+ resnet_groups: int = 32,
487
+ resnet_pre_norm: bool = True,
488
+ output_scale_factor=1.0,
489
+ add_downsample=True,
490
+ downsample_padding=1,
491
+ use_inflated_groupnorm=None,
492
+ use_motion_module=None,
493
+ motion_module_type=None,
494
+ motion_module_kwargs=None,
495
+ ):
496
+ super().__init__()
497
+ resnets = []
498
+ motion_modules = []
499
+
500
+ for i in range(num_layers):
501
+ in_channels = in_channels if i == 0 else out_channels
502
+ resnets.append(
503
+ ResnetBlock3D(
504
+ in_channels=in_channels,
505
+ out_channels=out_channels,
506
+ temb_channels=temb_channels,
507
+ eps=resnet_eps,
508
+ groups=resnet_groups,
509
+ dropout=dropout,
510
+ time_embedding_norm=resnet_time_scale_shift,
511
+ non_linearity=resnet_act_fn,
512
+ output_scale_factor=output_scale_factor,
513
+ pre_norm=resnet_pre_norm,
514
+ use_inflated_groupnorm=use_inflated_groupnorm,
515
+ )
516
+ )
517
+ motion_modules.append(
518
+ get_motion_module(
519
+ in_channels=out_channels,
520
+ motion_module_type=motion_module_type,
521
+ motion_module_kwargs=motion_module_kwargs,
522
+ )
523
+ if use_motion_module
524
+ else None
525
+ )
526
+
527
+ self.resnets = nn.ModuleList(resnets)
528
+ self.motion_modules = nn.ModuleList(motion_modules)
529
+
530
+ if add_downsample:
531
+ self.downsamplers = nn.ModuleList(
532
+ [
533
+ Downsample3D(
534
+ out_channels,
535
+ use_conv=True,
536
+ out_channels=out_channels,
537
+ padding=downsample_padding,
538
+ name="op",
539
+ )
540
+ ]
541
+ )
542
+ else:
543
+ self.downsamplers = None
544
+
545
+ self.gradient_checkpointing = False
546
+
547
+ def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
548
+ output_states = ()
549
+
550
+ for resnet, motion_module in zip(self.resnets, self.motion_modules):
551
+ if self.training and self.gradient_checkpointing:
552
+
553
+ def create_custom_forward(module):
554
+ def custom_forward(*inputs):
555
+ return module(*inputs)
556
+
557
+ return custom_forward
558
+
559
+ hidden_states = torch.utils.checkpoint.checkpoint(
560
+ create_custom_forward(resnet), hidden_states, temb
561
+ )
562
+ if motion_module is not None:
563
+ hidden_states = torch.utils.checkpoint.checkpoint(
564
+ create_custom_forward(motion_module),
565
+ hidden_states.requires_grad_(),
566
+ temb,
567
+ encoder_hidden_states,
568
+ )
569
+ else:
570
+ hidden_states = resnet(hidden_states, temb)
571
+
572
+ # add motion module
573
+ if motion_module:
574
+ hidden_states = motion_module(
575
+ hidden_states, temb, encoder_hidden_states=encoder_hidden_states
576
+ )
577
+
578
+ output_states = output_states + (hidden_states,)
579
+
580
+ if self.downsamplers is not None:
581
+ for downsampler in self.downsamplers:
582
+ hidden_states = downsampler(hidden_states)
583
+
584
+ output_states = output_states + (hidden_states,)
585
+
586
+ return hidden_states, output_states
587
+
588
+
589
+ class CrossAttnUpBlock3D(nn.Module):
590
+
591
+ def __init__(
592
+ self,
593
+ in_channels: int,
594
+ out_channels: int,
595
+ prev_output_channel: int,
596
+ temb_channels: int,
597
+ dropout: float = 0.0,
598
+ num_layers: int = 1,
599
+ transformer_layers_per_block: int = 1,
600
+ resnet_eps: float = 1e-6,
601
+ resnet_time_scale_shift: str = "default",
602
+ resnet_act_fn: str = "swish",
603
+ resnet_groups: int = 32,
604
+ resnet_pre_norm: bool = True,
605
+ attn_num_head_channels=1,
606
+ cross_attention_dim=1280,
607
+ output_scale_factor=1.0,
608
+ add_upsample=True,
609
+ dual_cross_attention=False,
610
+ use_linear_projection=False,
611
+ only_cross_attention=False,
612
+ upcast_attention=False,
613
+ unet_use_cross_frame_attention=False,
614
+ unet_use_temporal_attention=False,
615
+ use_inflated_groupnorm=False,
616
+ use_motion_module=None,
617
+ motion_module_type=None,
618
+ motion_module_kwargs=None,
619
+ ):
620
+ super().__init__()
621
+ resnets = []
622
+ attentions = []
623
+ motion_modules = []
624
+
625
+ self.has_cross_attention = True
626
+ self.attn_num_head_channels = attn_num_head_channels
627
+
628
+ for i in range(num_layers):
629
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
630
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
631
+
632
+ resnets.append(
633
+ ResnetBlock3D(
634
+ in_channels=resnet_in_channels + res_skip_channels,
635
+ out_channels=out_channels,
636
+ temb_channels=temb_channels,
637
+ eps=resnet_eps,
638
+ groups=resnet_groups,
639
+ dropout=dropout,
640
+ time_embedding_norm=resnet_time_scale_shift,
641
+ non_linearity=resnet_act_fn,
642
+ output_scale_factor=output_scale_factor,
643
+ pre_norm=resnet_pre_norm,
644
+ use_inflated_groupnorm=use_inflated_groupnorm,
645
+ )
646
+ )
647
+ if dual_cross_attention:
648
+ raise NotImplementedError
649
+ attentions.append(
650
+ Transformer3DModel(
651
+ attn_num_head_channels,
652
+ out_channels // attn_num_head_channels,
653
+ in_channels=out_channels,
654
+ num_layers=transformer_layers_per_block,
655
+ cross_attention_dim=cross_attention_dim,
656
+ norm_num_groups=resnet_groups,
657
+ use_linear_projection=use_linear_projection,
658
+ only_cross_attention=only_cross_attention,
659
+ upcast_attention=upcast_attention,
660
+ unet_use_cross_frame_attention=unet_use_cross_frame_attention,
661
+ unet_use_temporal_attention=unet_use_temporal_attention,
662
+ )
663
+ )
664
+ motion_modules.append(
665
+ get_motion_module(
666
+ in_channels=out_channels,
667
+ motion_module_type=motion_module_type,
668
+ motion_module_kwargs=motion_module_kwargs,
669
+ )
670
+ if use_motion_module
671
+ else None
672
+ )
673
+
674
+ self.attentions = nn.ModuleList(attentions)
675
+ self.resnets = nn.ModuleList(resnets)
676
+ self.motion_modules = nn.ModuleList(motion_modules)
677
+
678
+ if add_upsample:
679
+ self.upsamplers = nn.ModuleList(
680
+ [Upsample3D(out_channels, use_conv=True, out_channels=out_channels)]
681
+ )
682
+ else:
683
+ self.upsamplers = None
684
+
685
+ self.gradient_checkpointing = False
686
+
687
+ def forward(
688
+ self,
689
+ hidden_states: torch.FloatTensor,
690
+ res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
691
+ temb: Optional[torch.FloatTensor] = None,
692
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
693
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
694
+ upsample_size: Optional[int] = None,
695
+ attention_mask: Optional[torch.FloatTensor] = None,
696
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
697
+ ):
698
+ for resnet, attn, motion_module in zip(
699
+ self.resnets, self.attentions, self.motion_modules
700
+ ):
701
+ # pop res hidden states
702
+ res_hidden_states = res_hidden_states_tuple[-1]
703
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
704
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
705
+
706
+ if self.training and self.gradient_checkpointing:
707
+
708
+ def create_custom_forward(module, return_dict=None):
709
+ def custom_forward(*inputs):
710
+ if return_dict is not None:
711
+ return module(*inputs, return_dict=return_dict)
712
+ else:
713
+ return module(*inputs)
714
+
715
+ return custom_forward
716
+
717
+ hidden_states = torch.utils.checkpoint.checkpoint(
718
+ create_custom_forward(resnet), hidden_states, temb
719
+ )
720
+ hidden_states = torch.utils.checkpoint.checkpoint(
721
+ create_custom_forward(attn, return_dict=False),
722
+ hidden_states,
723
+ encoder_hidden_states,
724
+ )[0]
725
+ if motion_module is not None:
726
+ hidden_states = torch.utils.checkpoint.checkpoint(
727
+ create_custom_forward(motion_module),
728
+ hidden_states.requires_grad_(),
729
+ temb,
730
+ encoder_hidden_states,
731
+ )
732
+
733
+ else:
734
+ hidden_states = resnet(hidden_states, temb)
735
+ hidden_states = attn(
736
+ hidden_states,
737
+ encoder_hidden_states=encoder_hidden_states,
738
+ cross_attention_kwargs=cross_attention_kwargs,
739
+ attention_mask=attention_mask,
740
+ encoder_attention_mask=encoder_attention_mask,
741
+ return_dict=False,
742
+ )[0]
743
+
744
+ # add motion module
745
+ if motion_module:
746
+ hidden_states = motion_module(
747
+ hidden_states, temb, encoder_hidden_states=encoder_hidden_states
748
+ )
749
+
750
+ if self.upsamplers is not None:
751
+ for upsampler in self.upsamplers:
752
+ hidden_states = upsampler(hidden_states, upsample_size)
753
+
754
+ return hidden_states
755
+
756
+
757
+ class UpBlock3D(nn.Module):
758
+ def __init__(
759
+ self,
760
+ in_channels: int,
761
+ prev_output_channel: int,
762
+ out_channels: int,
763
+ temb_channels: int,
764
+ dropout: float = 0.0,
765
+ num_layers: int = 1,
766
+ resnet_eps: float = 1e-6,
767
+ resnet_time_scale_shift: str = "default",
768
+ resnet_act_fn: str = "swish",
769
+ resnet_groups: int = 32,
770
+ resnet_pre_norm: bool = True,
771
+ output_scale_factor=1.0,
772
+ add_upsample=True,
773
+ use_inflated_groupnorm=None,
774
+ use_motion_module=None,
775
+ motion_module_type=None,
776
+ motion_module_kwargs=None,
777
+ ):
778
+ super().__init__()
779
+ resnets = []
780
+ motion_modules = []
781
+
782
+ for i in range(num_layers):
783
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
784
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
785
+
786
+ resnets.append(
787
+ ResnetBlock3D(
788
+ in_channels=resnet_in_channels + res_skip_channels,
789
+ out_channels=out_channels,
790
+ temb_channels=temb_channels,
791
+ eps=resnet_eps,
792
+ groups=resnet_groups,
793
+ dropout=dropout,
794
+ time_embedding_norm=resnet_time_scale_shift,
795
+ non_linearity=resnet_act_fn,
796
+ output_scale_factor=output_scale_factor,
797
+ pre_norm=resnet_pre_norm,
798
+ use_inflated_groupnorm=use_inflated_groupnorm,
799
+ )
800
+ )
801
+ motion_modules.append(
802
+ get_motion_module(
803
+ in_channels=out_channels,
804
+ motion_module_type=motion_module_type,
805
+ motion_module_kwargs=motion_module_kwargs,
806
+ )
807
+ if use_motion_module
808
+ else None
809
+ )
810
+
811
+ self.resnets = nn.ModuleList(resnets)
812
+ self.motion_modules = nn.ModuleList(motion_modules)
813
+
814
+ if add_upsample:
815
+ self.upsamplers = nn.ModuleList(
816
+ [Upsample3D(out_channels, use_conv=True, out_channels=out_channels)]
817
+ )
818
+ else:
819
+ self.upsamplers = None
820
+
821
+ self.gradient_checkpointing = False
822
+
823
+ def forward(
824
+ self,
825
+ hidden_states,
826
+ res_hidden_states_tuple,
827
+ temb=None,
828
+ upsample_size=None,
829
+ encoder_hidden_states=None,
830
+ ):
831
+ for resnet, motion_module in zip(self.resnets, self.motion_modules):
832
+ # pop res hidden states
833
+ res_hidden_states = res_hidden_states_tuple[-1]
834
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
835
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
836
+
837
+ if self.training and self.gradient_checkpointing:
838
+
839
+ def create_custom_forward(module):
840
+ def custom_forward(*inputs):
841
+ return module(*inputs)
842
+
843
+ return custom_forward
844
+
845
+ hidden_states = torch.utils.checkpoint.checkpoint(
846
+ create_custom_forward(resnet), hidden_states, temb
847
+ )
848
+ if motion_module is not None:
849
+ hidden_states = torch.utils.checkpoint.checkpoint(
850
+ create_custom_forward(motion_module),
851
+ hidden_states.requires_grad_(),
852
+ temb,
853
+ encoder_hidden_states,
854
+ )
855
+ else:
856
+ hidden_states = resnet(hidden_states, temb)
857
+ if motion_module:
858
+ hidden_states = motion_module(
859
+ hidden_states, temb, encoder_hidden_states=encoder_hidden_states
860
+ )
861
+
862
+ if self.upsamplers is not None:
863
+ for upsampler in self.upsamplers:
864
+ hidden_states = upsampler(hidden_states, upsample_size)
865
+
866
+ return hidden_states
peft/__init__.py ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # flake8: noqa
2
+ # There's no way to ignore "F401 '...' imported but unused" warnings in this
3
+ # module, but to preserve other warnings. So, don't check this module at all.
4
+
5
+ # coding=utf-8
6
+ # Copyright 2023-present the HuggingFace Inc. team.
7
+ #
8
+ # Licensed under the Apache License, Version 2.0 (the "License");
9
+ # you may not use this file except in compliance with the License.
10
+ # You may obtain a copy of the License at
11
+ #
12
+ # http://www.apache.org/licenses/LICENSE-2.0
13
+ #
14
+ # Unless required by applicable law or agreed to in writing, software
15
+ # distributed under the License is distributed on an "AS IS" BASIS,
16
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17
+ # See the License for the specific language governing permissions and
18
+ # limitations under the License.
19
+
20
+ __version__ = "0.11.1"
21
+
22
+ from .auto import (
23
+ AutoPeftModel,
24
+ AutoPeftModelForCausalLM,
25
+ AutoPeftModelForSequenceClassification,
26
+ AutoPeftModelForSeq2SeqLM,
27
+ AutoPeftModelForTokenClassification,
28
+ AutoPeftModelForQuestionAnswering,
29
+ AutoPeftModelForFeatureExtraction,
30
+ )
31
+ from .mapping import (
32
+ MODEL_TYPE_TO_PEFT_MODEL_MAPPING,
33
+ PEFT_TYPE_TO_CONFIG_MAPPING,
34
+ get_peft_config,
35
+ get_peft_model,
36
+ inject_adapter_in_model,
37
+ )
38
+ from .mixed_model import PeftMixedModel
39
+ from .peft_model import (
40
+ PeftModel,
41
+ PeftModelForCausalLM,
42
+ PeftModelForSeq2SeqLM,
43
+ PeftModelForSequenceClassification,
44
+ PeftModelForTokenClassification,
45
+ PeftModelForQuestionAnswering,
46
+ PeftModelForFeatureExtraction,
47
+ get_layer_status,
48
+ get_model_status,
49
+ )
50
+ from .tuners import (
51
+ AdaptionPromptConfig,
52
+ AdaptionPromptModel,
53
+ LoraConfig,
54
+ LoftQConfig,
55
+ LoraModel,
56
+ LoHaConfig,
57
+ LoHaModel,
58
+ LoKrConfig,
59
+ LoKrModel,
60
+ IA3Config,
61
+ IA3Model,
62
+ AdaLoraConfig,
63
+ AdaLoraModel,
64
+ BOFTConfig,
65
+ BOFTModel,
66
+ PrefixEncoder,
67
+ PrefixTuningConfig,
68
+ PromptEmbedding,
69
+ PromptEncoder,
70
+ PromptEncoderConfig,
71
+ PromptEncoderReparameterizationType,
72
+ PromptTuningConfig,
73
+ PromptTuningInit,
74
+ MultitaskPromptTuningConfig,
75
+ MultitaskPromptTuningInit,
76
+ OFTConfig,
77
+ OFTModel,
78
+ PolyConfig,
79
+ PolyModel,
80
+ LNTuningConfig,
81
+ LNTuningModel,
82
+ VeraConfig,
83
+ VeraModel,
84
+ )
85
+ from .utils import (
86
+ TRANSFORMERS_MODELS_TO_PREFIX_TUNING_POSTPROCESS_MAPPING,
87
+ PeftType,
88
+ TaskType,
89
+ bloom_model_postprocess_past_key_value,
90
+ get_peft_model_state_dict,
91
+ prepare_model_for_kbit_training,
92
+ replace_lora_weights_loftq,
93
+ set_peft_model_state_dict,
94
+ shift_tokens_right,
95
+ load_peft_weights,
96
+ cast_mixed_precision_params,
97
+ )
98
+ from .config import PeftConfig, PromptLearningConfig
peft/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (2.37 kB). View file
 
peft/__pycache__/auto.cpython-310.pyc ADDED
Binary file (4.84 kB). View file
 
peft/__pycache__/config.cpython-310.pyc ADDED
Binary file (8.79 kB). View file
 
peft/__pycache__/import_utils.cpython-310.pyc ADDED
Binary file (2.18 kB). View file
 
peft/__pycache__/mapping.cpython-310.pyc ADDED
Binary file (4.98 kB). View file
 
peft/__pycache__/mixed_model.cpython-310.pyc ADDED
Binary file (14.8 kB). View file
 
peft/__pycache__/peft_model.cpython-310.pyc ADDED
Binary file (71.9 kB). View file
 
peft/auto.py ADDED
@@ -0,0 +1,170 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023-present the HuggingFace Inc. team.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from __future__ import annotations
16
+
17
+ import importlib
18
+ import os
19
+ from typing import Optional
20
+
21
+ from transformers import (
22
+ AutoModel,
23
+ AutoModelForCausalLM,
24
+ AutoModelForQuestionAnswering,
25
+ AutoModelForSeq2SeqLM,
26
+ AutoModelForSequenceClassification,
27
+ AutoModelForTokenClassification,
28
+ AutoTokenizer,
29
+ )
30
+
31
+ from .config import PeftConfig
32
+ from .mapping import MODEL_TYPE_TO_PEFT_MODEL_MAPPING
33
+ from .peft_model import (
34
+ PeftModel,
35
+ PeftModelForCausalLM,
36
+ PeftModelForFeatureExtraction,
37
+ PeftModelForQuestionAnswering,
38
+ PeftModelForSeq2SeqLM,
39
+ PeftModelForSequenceClassification,
40
+ PeftModelForTokenClassification,
41
+ )
42
+ from .utils.constants import TOKENIZER_CONFIG_NAME
43
+ from .utils.other import check_file_exists_on_hf_hub
44
+
45
+
46
+ class _BaseAutoPeftModel:
47
+ _target_class = None
48
+ _target_peft_class = None
49
+
50
+ def __init__(self, *args, **kwargs):
51
+ # For consistency with transformers: https://github.com/huggingface/transformers/blob/91d7df58b6537d385e90578dac40204cb550f706/src/transformers/models/auto/auto_factory.py#L400
52
+ raise EnvironmentError( # noqa: UP024
53
+ f"{self.__class__.__name__} is designed to be instantiated "
54
+ f"using the `{self.__class__.__name__}.from_pretrained(pretrained_model_name_or_path)` or "
55
+ f"`{self.__class__.__name__}.from_config(config)` methods."
56
+ )
57
+
58
+ @classmethod
59
+ def from_pretrained(
60
+ cls,
61
+ pretrained_model_name_or_path,
62
+ adapter_name: str = "default",
63
+ is_trainable: bool = False,
64
+ config: Optional[PeftConfig] = None,
65
+ **kwargs,
66
+ ):
67
+ r"""
68
+ A wrapper around all the preprocessing steps a user needs to perform in order to load a PEFT model. The kwargs
69
+ are passed along to `PeftConfig` that automatically takes care of filtering the kwargs of the Hub methods and
70
+ the config object init.
71
+ """
72
+ peft_config = PeftConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
73
+ base_model_path = peft_config.base_model_name_or_path
74
+
75
+ task_type = getattr(peft_config, "task_type", None)
76
+
77
+ if cls._target_class is not None:
78
+ target_class = cls._target_class
79
+ elif cls._target_class is None and task_type is not None:
80
+ # this is only in the case where we use `AutoPeftModel`
81
+ raise ValueError(
82
+ "Cannot use `AutoPeftModel` with a task type, please use a specific class for your task type. (e.g. `AutoPeftModelForCausalLM` for `task_type='CAUSAL_LM'`)"
83
+ )
84
+
85
+ if task_type is not None:
86
+ expected_target_class = MODEL_TYPE_TO_PEFT_MODEL_MAPPING[task_type]
87
+ if cls._target_peft_class.__name__ != expected_target_class.__name__:
88
+ raise ValueError(
89
+ f"Expected target PEFT class: {expected_target_class.__name__}, but you have asked for: {cls._target_peft_class.__name__ }"
90
+ " make sure that you are loading the correct model for your task type."
91
+ )
92
+ elif task_type is None and getattr(peft_config, "auto_mapping", None) is not None:
93
+ auto_mapping = getattr(peft_config, "auto_mapping", None)
94
+ base_model_class = auto_mapping["base_model_class"]
95
+ parent_library_name = auto_mapping["parent_library"]
96
+
97
+ parent_library = importlib.import_module(parent_library_name)
98
+ target_class = getattr(parent_library, base_model_class)
99
+ else:
100
+ raise ValueError(
101
+ "Cannot infer the auto class from the config, please make sure that you are loading the correct model for your task type."
102
+ )
103
+
104
+ base_model = target_class.from_pretrained(base_model_path, **kwargs)
105
+
106
+ tokenizer_exists = False
107
+ if os.path.exists(os.path.join(pretrained_model_name_or_path, TOKENIZER_CONFIG_NAME)):
108
+ tokenizer_exists = True
109
+ else:
110
+ token = kwargs.get("token", None)
111
+ if token is None:
112
+ token = kwargs.get("use_auth_token", None)
113
+
114
+ tokenizer_exists = check_file_exists_on_hf_hub(
115
+ repo_id=pretrained_model_name_or_path,
116
+ filename=TOKENIZER_CONFIG_NAME,
117
+ revision=kwargs.get("revision", None),
118
+ repo_type=kwargs.get("repo_type", None),
119
+ token=token,
120
+ )
121
+
122
+ if tokenizer_exists:
123
+ tokenizer = AutoTokenizer.from_pretrained(
124
+ pretrained_model_name_or_path, trust_remote_code=kwargs.get("trust_remote_code", False)
125
+ )
126
+ base_model.resize_token_embeddings(len(tokenizer))
127
+
128
+ return cls._target_peft_class.from_pretrained(
129
+ base_model,
130
+ pretrained_model_name_or_path,
131
+ adapter_name=adapter_name,
132
+ is_trainable=is_trainable,
133
+ config=config,
134
+ **kwargs,
135
+ )
136
+
137
+
138
+ class AutoPeftModel(_BaseAutoPeftModel):
139
+ _target_class = None
140
+ _target_peft_class = PeftModel
141
+
142
+
143
+ class AutoPeftModelForCausalLM(_BaseAutoPeftModel):
144
+ _target_class = AutoModelForCausalLM
145
+ _target_peft_class = PeftModelForCausalLM
146
+
147
+
148
+ class AutoPeftModelForSeq2SeqLM(_BaseAutoPeftModel):
149
+ _target_class = AutoModelForSeq2SeqLM
150
+ _target_peft_class = PeftModelForSeq2SeqLM
151
+
152
+
153
+ class AutoPeftModelForSequenceClassification(_BaseAutoPeftModel):
154
+ _target_class = AutoModelForSequenceClassification
155
+ _target_peft_class = PeftModelForSequenceClassification
156
+
157
+
158
+ class AutoPeftModelForTokenClassification(_BaseAutoPeftModel):
159
+ _target_class = AutoModelForTokenClassification
160
+ _target_peft_class = PeftModelForTokenClassification
161
+
162
+
163
+ class AutoPeftModelForQuestionAnswering(_BaseAutoPeftModel):
164
+ _target_class = AutoModelForQuestionAnswering
165
+ _target_peft_class = PeftModelForQuestionAnswering
166
+
167
+
168
+ class AutoPeftModelForFeatureExtraction(_BaseAutoPeftModel):
169
+ _target_class = AutoModel
170
+ _target_peft_class = PeftModelForFeatureExtraction
peft/config.py ADDED
@@ -0,0 +1,270 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023-present the HuggingFace Inc. team.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import inspect
15
+ import json
16
+ import os
17
+ from dataclasses import asdict, dataclass, field
18
+ from typing import Dict, Optional, Union
19
+
20
+ from huggingface_hub import hf_hub_download
21
+ from transformers.utils import PushToHubMixin
22
+
23
+ from .utils import CONFIG_NAME, PeftType, TaskType
24
+
25
+
26
+ @dataclass
27
+ class PeftConfigMixin(PushToHubMixin):
28
+ r"""
29
+ This is the base configuration class for PEFT adapter models. It contains all the methods that are common to all
30
+ PEFT adapter models. This class inherits from [`~transformers.utils.PushToHubMixin`] which contains the methods to
31
+ push your model to the Hub. The method `save_pretrained` will save the configuration of your adapter model in a
32
+ directory. The method `from_pretrained` will load the configuration of your adapter model from a directory.
33
+
34
+ Args:
35
+ peft_type (Union[[`~peft.utils.config.PeftType`], `str`]): The type of Peft method to use.
36
+ """
37
+
38
+ peft_type: Optional[PeftType] = field(default=None, metadata={"help": "The type of PEFT model."})
39
+ auto_mapping: Optional[dict] = field(
40
+ default=None, metadata={"help": "An auto mapping dict to help retrieve the base model class if needed."}
41
+ )
42
+
43
+ def to_dict(self) -> Dict:
44
+ r"""
45
+ Returns the configuration for your adapter model as a dictionary.
46
+ """
47
+ return asdict(self)
48
+
49
+ def save_pretrained(self, save_directory: str, **kwargs) -> None:
50
+ r"""
51
+ This method saves the configuration of your adapter model in a directory.
52
+
53
+ Args:
54
+ save_directory (`str`):
55
+ The directory where the configuration will be saved.
56
+ kwargs (additional keyword arguments, *optional*):
57
+ Additional keyword arguments passed along to the [`~transformers.utils.PushToHubMixin.push_to_hub`]
58
+ method.
59
+ """
60
+ if os.path.isfile(save_directory):
61
+ raise AssertionError(f"Provided path ({save_directory}) should be a directory, not a file")
62
+
63
+ os.makedirs(save_directory, exist_ok=True)
64
+ auto_mapping_dict = kwargs.pop("auto_mapping_dict", None)
65
+
66
+ output_dict = asdict(self)
67
+ # converting set type to list
68
+ for key, value in output_dict.items():
69
+ if isinstance(value, set):
70
+ output_dict[key] = list(value)
71
+
72
+ output_path = os.path.join(save_directory, CONFIG_NAME)
73
+
74
+ # Add auto mapping details for custom models.
75
+ if auto_mapping_dict is not None:
76
+ output_dict["auto_mapping"] = auto_mapping_dict
77
+
78
+ # save it
79
+ with open(output_path, "w") as writer:
80
+ writer.write(json.dumps(output_dict, indent=2, sort_keys=True))
81
+
82
+ @classmethod
83
+ def from_peft_type(cls, **kwargs):
84
+ r"""
85
+ This method loads the configuration of your adapter model from a set of kwargs.
86
+
87
+ The appropriate configuration type is determined by the `peft_type` argument. If `peft_type` is not provided,
88
+ the calling class type is instantiated.
89
+
90
+ Args:
91
+ kwargs (configuration keyword arguments):
92
+ Keyword arguments passed along to the configuration initialization.
93
+ """
94
+ # Avoid circular dependency .. TODO: fix this with a larger refactor
95
+ from peft.mapping import PEFT_TYPE_TO_CONFIG_MAPPING
96
+
97
+ # TODO: this hack is needed to fix the following issue (on commit 702f937):
98
+ # if someone saves a default config and loads it back with `PeftConfig` class it yields to
99
+ # not loading the correct config class.
100
+
101
+ # from peft import AdaLoraConfig, PeftConfig
102
+ # peft_config = AdaLoraConfig()
103
+ # print(peft_config)
104
+ # >>> AdaLoraConfig(peft_type=<PeftType.ADALORA: 'ADALORA'>, auto_mapping=None, base_model_name_or_path=None,
105
+ # revision=None, task_type=None, inference_mode=False, r=8, target_modules=None, lora_alpha=8, lora_dropout=0.0, ...
106
+ #
107
+ # peft_config.save_pretrained("./test_config")
108
+ # peft_config = PeftConfig.from_pretrained("./test_config")
109
+ # print(peft_config)
110
+ # >>> PeftConfig(peft_type='ADALORA', auto_mapping=None, base_model_name_or_path=None, revision=None, task_type=None, inference_mode=False)
111
+
112
+ if "peft_type" in kwargs:
113
+ peft_type = kwargs["peft_type"]
114
+ config_cls = PEFT_TYPE_TO_CONFIG_MAPPING[peft_type]
115
+ else:
116
+ config_cls = cls
117
+
118
+ return config_cls(**kwargs)
119
+
120
+ @classmethod
121
+ def from_pretrained(cls, pretrained_model_name_or_path: str, subfolder: Optional[str] = None, **kwargs):
122
+ r"""
123
+ This method loads the configuration of your adapter model from a directory.
124
+
125
+ Args:
126
+ pretrained_model_name_or_path (`str`):
127
+ The directory or the Hub repository id where the configuration is saved.
128
+ kwargs (additional keyword arguments, *optional*):
129
+ Additional keyword arguments passed along to the child class initialization.
130
+ """
131
+ path = (
132
+ os.path.join(pretrained_model_name_or_path, subfolder)
133
+ if subfolder is not None
134
+ else pretrained_model_name_or_path
135
+ )
136
+
137
+ hf_hub_download_kwargs, class_kwargs, _ = cls._split_kwargs(kwargs)
138
+
139
+ if os.path.isfile(os.path.join(path, CONFIG_NAME)):
140
+ config_file = os.path.join(path, CONFIG_NAME)
141
+ else:
142
+ try:
143
+ config_file = hf_hub_download(
144
+ pretrained_model_name_or_path, CONFIG_NAME, subfolder=subfolder, **hf_hub_download_kwargs
145
+ )
146
+ except Exception as exc:
147
+ raise ValueError(f"Can't find '{CONFIG_NAME}' at '{pretrained_model_name_or_path}'") from exc
148
+
149
+ loaded_attributes = cls.from_json_file(config_file)
150
+ kwargs = {**class_kwargs, **loaded_attributes}
151
+ return cls.from_peft_type(**kwargs)
152
+
153
+ @classmethod
154
+ def from_json_file(cls, path_json_file: str, **kwargs):
155
+ r"""
156
+ Loads a configuration file from a json file.
157
+
158
+ Args:
159
+ path_json_file (`str`):
160
+ The path to the json file.
161
+ """
162
+ with open(path_json_file) as file:
163
+ json_object = json.load(file)
164
+
165
+ return json_object
166
+
167
+ @classmethod
168
+ def _split_kwargs(cls, kwargs):
169
+ hf_hub_download_kwargs = {}
170
+ class_kwargs = {}
171
+ other_kwargs = {}
172
+
173
+ for key, value in kwargs.items():
174
+ if key in inspect.signature(hf_hub_download).parameters:
175
+ hf_hub_download_kwargs[key] = value
176
+ elif key in list(cls.__annotations__):
177
+ class_kwargs[key] = value
178
+ else:
179
+ other_kwargs[key] = value
180
+
181
+ return hf_hub_download_kwargs, class_kwargs, other_kwargs
182
+
183
+ @classmethod
184
+ def _get_peft_type(
185
+ cls,
186
+ model_id: str,
187
+ **hf_hub_download_kwargs,
188
+ ):
189
+ subfolder = hf_hub_download_kwargs.get("subfolder", None)
190
+
191
+ path = os.path.join(model_id, subfolder) if subfolder is not None else model_id
192
+
193
+ if os.path.isfile(os.path.join(path, CONFIG_NAME)):
194
+ config_file = os.path.join(path, CONFIG_NAME)
195
+ else:
196
+ try:
197
+ config_file = hf_hub_download(
198
+ model_id,
199
+ CONFIG_NAME,
200
+ **hf_hub_download_kwargs,
201
+ )
202
+ except Exception:
203
+ raise ValueError(f"Can't find '{CONFIG_NAME}' at '{model_id}'")
204
+
205
+ loaded_attributes = cls.from_json_file(config_file)
206
+ return loaded_attributes["peft_type"]
207
+
208
+ @property
209
+ def is_prompt_learning(self) -> bool:
210
+ r"""
211
+ Utility method to check if the configuration is for prompt learning.
212
+ """
213
+ return False
214
+
215
+ @property
216
+ def is_adaption_prompt(self) -> bool:
217
+ """Return True if this is an adaption prompt config."""
218
+ return False
219
+
220
+
221
+ @dataclass
222
+ class PeftConfig(PeftConfigMixin):
223
+ """
224
+ This is the base configuration class to store the configuration of a [`PeftModel`].
225
+
226
+ Args:
227
+ peft_type (Union[[`~peft.utils.config.PeftType`], `str`]): The type of Peft method to use.
228
+ task_type (Union[[`~peft.utils.config.TaskType`], `str`]): The type of task to perform.
229
+ inference_mode (`bool`, defaults to `False`): Whether to use the Peft model in inference mode.
230
+ """
231
+
232
+ base_model_name_or_path: Optional[str] = field(
233
+ default=None, metadata={"help": "The name of the base model to use."}
234
+ )
235
+ revision: Optional[str] = field(default=None, metadata={"help": "The specific model version to use."})
236
+ peft_type: Optional[Union[str, PeftType]] = field(default=None, metadata={"help": "Peft type"})
237
+ task_type: Optional[Union[str, TaskType]] = field(default=None, metadata={"help": "Task type"})
238
+ inference_mode: bool = field(default=False, metadata={"help": "Whether to use inference mode"})
239
+
240
+
241
+ @dataclass
242
+ class PromptLearningConfig(PeftConfig):
243
+ """
244
+ This is the base configuration class to store the configuration of [`PrefixTuning`], [`PromptEncoder`], or
245
+ [`PromptTuning`].
246
+
247
+ Args:
248
+ num_virtual_tokens (`int`): The number of virtual tokens to use.
249
+ token_dim (`int`): The hidden embedding dimension of the base transformer model.
250
+ num_transformer_submodules (`int`): The number of transformer submodules in the base transformer model.
251
+ num_attention_heads (`int`): The number of attention heads in the base transformer model.
252
+ num_layers (`int`): The number of layers in the base transformer model.
253
+ """
254
+
255
+ num_virtual_tokens: int = field(default=None, metadata={"help": "Number of virtual tokens"})
256
+ token_dim: int = field(
257
+ default=None, metadata={"help": "The hidden embedding dimension of the base transformer model"}
258
+ )
259
+ num_transformer_submodules: Optional[int] = field(
260
+ default=None, metadata={"help": "Number of transformer submodules"}
261
+ )
262
+ num_attention_heads: Optional[int] = field(default=None, metadata={"help": "Number of attention heads"})
263
+ num_layers: Optional[int] = field(default=None, metadata={"help": "Number of transformer layers"})
264
+
265
+ @property
266
+ def is_prompt_learning(self) -> bool:
267
+ r"""
268
+ Utility method to check if the configuration is for prompt learning.
269
+ """
270
+ return True
peft/helpers.py ADDED
@@ -0,0 +1,148 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023-present the HuggingFace Inc. team.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import inspect
16
+ from copy import deepcopy
17
+ from functools import update_wrapper
18
+ from types import MethodType
19
+
20
+ from .peft_model import PeftConfig, PeftModel
21
+
22
+
23
+ def update_forward_signature(model: PeftModel) -> None:
24
+ """
25
+ Updates the forward signature of the PeftModel to include parents class signature
26
+ model (`PeftModel`): Peft model to update the forward signature
27
+
28
+ Example:
29
+
30
+ ```python
31
+ >>> from transformers import WhisperForConditionalGeneration
32
+ >>> from peft import get_peft_model, LoraConfig, update_forward_signature
33
+
34
+ >>> model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en")
35
+ >>> peft_config = LoraConfig(r=8, lora_alpha=32, lora_dropout=0.1, target_modules=["q_proj", "v_proj"])
36
+
37
+ >>> peft_model = get_peft_model(model, peft_config)
38
+ >>> update_forward_signature(peft_model)
39
+ ```
40
+ """
41
+
42
+ # Only update signature when the current forward signature only has *args and **kwargs
43
+ current_signature = inspect.signature(model.forward)
44
+ if (
45
+ len(current_signature.parameters) == 2
46
+ and "args" in current_signature.parameters
47
+ and "kwargs" in current_signature.parameters
48
+ ):
49
+ forward = deepcopy(model.forward.__func__)
50
+ update_wrapper(
51
+ forward, type(model.get_base_model()).forward, assigned=("__doc__", "__name__", "__annotations__")
52
+ )
53
+ model.forward = MethodType(forward, model)
54
+
55
+
56
+ def update_generate_signature(model: PeftModel) -> None:
57
+ """
58
+ Updates the generate signature of a PeftModel with overriding generate to include parents class signature
59
+ model (`PeftModel`): Peft model to update the generate signature
60
+
61
+ Example:
62
+
63
+ ```python
64
+ >>> from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
65
+ >>> from peft import get_peft_model, LoraConfig, TaskType, update_generate_signature
66
+
67
+ >>> model_name_or_path = "bigscience/mt0-large"
68
+ >>> tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
69
+ >>> model = AutoModelForSeq2SeqLM.from_pretrained(model_name_or_path)
70
+
71
+ >>> peft_config = LoraConfig(
72
+ ... task_type=TaskType.SEQ_2_SEQ_LM, inference_mode=False, r=8, lora_alpha=32, lora_dropout=0.1
73
+ ... )
74
+ >>> peft_model = get_peft_model(model, peft_config)
75
+ >>> update_generate_signature(peft_model)
76
+ >>> help(peft_model.generate)
77
+ ```
78
+ """
79
+ if not hasattr(model, "generate"):
80
+ return
81
+ current_signature = inspect.signature(model.generate)
82
+ if (
83
+ len(current_signature.parameters) == 2
84
+ and "args" in current_signature.parameters
85
+ and "kwargs" in current_signature.parameters
86
+ ) or (len(current_signature.parameters) == 1 and "kwargs" in current_signature.parameters):
87
+ generate = deepcopy(model.generate.__func__)
88
+ update_wrapper(
89
+ generate,
90
+ type(model.get_base_model()).generate,
91
+ assigned=("__doc__", "__name__", "__annotations__"),
92
+ )
93
+ model.generate = MethodType(generate, model)
94
+
95
+
96
+ def update_signature(model: PeftModel, method: str = "all") -> None:
97
+ """
98
+ Updates the signature of a PeftModel include parents class signature for forward or generate method
99
+ model (`PeftModel`): Peft model to update generate or forward signature method (`str`): method to update
100
+ signature choose one of "forward", "generate", "all"
101
+
102
+ Example:
103
+ ```python
104
+ >>> from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
105
+ >>> from peft import get_peft_model, LoraConfig, TaskType, update_signature
106
+
107
+ >>> model_name_or_path = "bigscience/mt0-large"
108
+ >>> tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
109
+ >>> model = AutoModelForSeq2SeqLM.from_pretrained(model_name_or_path)
110
+
111
+ >>> peft_config = LoraConfig(
112
+ ... task_type=TaskType.SEQ_2_SEQ_LM, inference_mode=False, r=8, lora_alpha=32, lora_dropout=0.1
113
+ ... )
114
+ >>> peft_model = get_peft_model(model, peft_config)
115
+ >>> update_signature(peft_model)
116
+ >>> help(peft_model.generate)
117
+ ```
118
+ """
119
+ if method == "forward":
120
+ update_forward_signature(model)
121
+ elif method == "generate":
122
+ update_generate_signature(model)
123
+ elif method == "all":
124
+ update_forward_signature(model)
125
+ update_generate_signature(model)
126
+ else:
127
+ raise ValueError(f"method {method} is not supported please choose one of ['forward', 'generate', 'all']")
128
+
129
+
130
+ def check_if_peft_model(model_name_or_path: str) -> bool:
131
+ """
132
+ Check if the model is a PEFT model.
133
+
134
+ Args:
135
+ model_name_or_path (`str`):
136
+ Model id to check, can be local or on the Hugging Face Hub.
137
+
138
+ Returns:
139
+ `bool`: True if the model is a PEFT model, False otherwise.
140
+ """
141
+ is_peft_model = True
142
+ try:
143
+ PeftConfig.from_pretrained(model_name_or_path)
144
+ except Exception:
145
+ # allow broad exceptions so that this works even if new exceptions are added on HF Hub side
146
+ is_peft_model = False
147
+
148
+ return is_peft_model
peft/import_utils.py ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023-present the HuggingFace Inc. team.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import importlib
15
+ import importlib.metadata as importlib_metadata
16
+ from functools import lru_cache
17
+
18
+ import packaging.version
19
+
20
+
21
+ @lru_cache
22
+ def is_bnb_available() -> bool:
23
+ return importlib.util.find_spec("bitsandbytes") is not None
24
+
25
+
26
+ @lru_cache
27
+ def is_bnb_4bit_available() -> bool:
28
+ if not is_bnb_available():
29
+ return False
30
+
31
+ import bitsandbytes as bnb
32
+
33
+ return hasattr(bnb.nn, "Linear4bit")
34
+
35
+
36
+ @lru_cache
37
+ def is_auto_gptq_available():
38
+ if importlib.util.find_spec("auto_gptq") is not None:
39
+ AUTOGPTQ_MINIMUM_VERSION = packaging.version.parse("0.5.0")
40
+ version_autogptq = packaging.version.parse(importlib_metadata.version("auto_gptq"))
41
+ if AUTOGPTQ_MINIMUM_VERSION <= version_autogptq:
42
+ return True
43
+ else:
44
+ raise ImportError(
45
+ f"Found an incompatible version of auto-gptq. Found version {version_autogptq}, "
46
+ f"but only versions above {AUTOGPTQ_MINIMUM_VERSION} are supported"
47
+ )
48
+
49
+
50
+ @lru_cache
51
+ def is_optimum_available() -> bool:
52
+ return importlib.util.find_spec("optimum") is not None
53
+
54
+
55
+ @lru_cache
56
+ def is_torch_tpu_available(check_device=True):
57
+ "Checks if `torch_xla` is installed and potentially if a TPU is in the environment"
58
+ if importlib.util.find_spec("torch_xla") is not None:
59
+ if check_device:
60
+ # We need to check if `xla_device` can be found, will raise a RuntimeError if not
61
+ try:
62
+ import torch_xla.core.xla_model as xm
63
+
64
+ _ = xm.xla_device()
65
+ return True
66
+ except RuntimeError:
67
+ return False
68
+ return True
69
+ return False
70
+
71
+
72
+ @lru_cache
73
+ def is_aqlm_available():
74
+ return importlib.util.find_spec("aqlm") is not None
75
+
76
+
77
+ @lru_cache
78
+ def is_auto_awq_available():
79
+ return importlib.util.find_spec("awq") is not None
80
+
81
+
82
+ @lru_cache
83
+ def is_eetq_available():
84
+ return importlib.util.find_spec("eetq") is not None
85
+
86
+
87
+ @lru_cache
88
+ def is_hqq_available():
89
+ return importlib.util.find_spec("hqq") is not None
peft/mapping.py ADDED
@@ -0,0 +1,181 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023-present the HuggingFace Inc. team.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from __future__ import annotations
16
+
17
+ from typing import TYPE_CHECKING, Any
18
+
19
+ import torch
20
+
21
+ from .config import PeftConfig
22
+ from .mixed_model import PeftMixedModel
23
+ from .peft_model import (
24
+ PeftModel,
25
+ PeftModelForCausalLM,
26
+ PeftModelForFeatureExtraction,
27
+ PeftModelForQuestionAnswering,
28
+ PeftModelForSeq2SeqLM,
29
+ PeftModelForSequenceClassification,
30
+ PeftModelForTokenClassification,
31
+ )
32
+ from .tuners import (
33
+ AdaLoraConfig,
34
+ AdaLoraModel,
35
+ AdaptionPromptConfig,
36
+ BOFTConfig,
37
+ BOFTModel,
38
+ IA3Config,
39
+ IA3Model,
40
+ LNTuningConfig,
41
+ LNTuningModel,
42
+ LoHaConfig,
43
+ LoHaModel,
44
+ LoKrConfig,
45
+ LoKrModel,
46
+ LoraConfig,
47
+ LoraModel,
48
+ MultitaskPromptTuningConfig,
49
+ OFTConfig,
50
+ OFTModel,
51
+ PolyConfig,
52
+ PolyModel,
53
+ PrefixTuningConfig,
54
+ PromptEncoderConfig,
55
+ PromptTuningConfig,
56
+ VeraConfig,
57
+ VeraModel,
58
+ )
59
+ from .tuners.tuners_utils import BaseTuner as _BaseTuner
60
+ from .utils import _prepare_prompt_learning_config
61
+
62
+
63
+ if TYPE_CHECKING:
64
+ from transformers import PreTrainedModel
65
+
66
+
67
+ MODEL_TYPE_TO_PEFT_MODEL_MAPPING: dict[str, type[PeftModel]] = {
68
+ "SEQ_CLS": PeftModelForSequenceClassification,
69
+ "SEQ_2_SEQ_LM": PeftModelForSeq2SeqLM,
70
+ "CAUSAL_LM": PeftModelForCausalLM,
71
+ "TOKEN_CLS": PeftModelForTokenClassification,
72
+ "QUESTION_ANS": PeftModelForQuestionAnswering,
73
+ "FEATURE_EXTRACTION": PeftModelForFeatureExtraction,
74
+ }
75
+
76
+ PEFT_TYPE_TO_CONFIG_MAPPING: dict[str, type[PeftConfig]] = {
77
+ "ADAPTION_PROMPT": AdaptionPromptConfig,
78
+ "PROMPT_TUNING": PromptTuningConfig,
79
+ "PREFIX_TUNING": PrefixTuningConfig,
80
+ "P_TUNING": PromptEncoderConfig,
81
+ "LORA": LoraConfig,
82
+ "LOHA": LoHaConfig,
83
+ "LOKR": LoKrConfig,
84
+ "ADALORA": AdaLoraConfig,
85
+ "BOFT": BOFTConfig,
86
+ "IA3": IA3Config,
87
+ "MULTITASK_PROMPT_TUNING": MultitaskPromptTuningConfig,
88
+ "OFT": OFTConfig,
89
+ "POLY": PolyConfig,
90
+ "LN_TUNING": LNTuningConfig,
91
+ "VERA": VeraConfig,
92
+ }
93
+
94
+ PEFT_TYPE_TO_TUNER_MAPPING: dict[str, type[_BaseTuner]] = {
95
+ "LORA": LoraModel,
96
+ "LOHA": LoHaModel,
97
+ "LOKR": LoKrModel,
98
+ "ADALORA": AdaLoraModel,
99
+ "BOFT": BOFTModel,
100
+ "IA3": IA3Model,
101
+ "OFT": OFTModel,
102
+ "POLY": PolyModel,
103
+ "LN_TUNING": LNTuningModel,
104
+ "VERA": VeraModel,
105
+ }
106
+
107
+
108
+ def get_peft_config(config_dict: dict[str, Any]) -> PeftConfig:
109
+ """
110
+ Returns a Peft config object from a dictionary.
111
+
112
+ Args:
113
+ config_dict (`Dict[str, Any]`): Dictionary containing the configuration parameters.
114
+ """
115
+
116
+ return PEFT_TYPE_TO_CONFIG_MAPPING[config_dict["peft_type"]](**config_dict)
117
+
118
+
119
+ def get_peft_model(
120
+ model: PreTrainedModel, peft_config: PeftConfig, adapter_name: str = "default", mixed: bool = False
121
+ ) -> PeftModel | PeftMixedModel:
122
+ """
123
+ Returns a Peft model object from a model and a config.
124
+
125
+ Args:
126
+ model ([`transformers.PreTrainedModel`]):
127
+ Model to be wrapped.
128
+ peft_config ([`PeftConfig`]):
129
+ Configuration object containing the parameters of the Peft model.
130
+ adapter_name (`str`, `optional`, defaults to `"default"`):
131
+ The name of the adapter to be injected, if not provided, the default adapter name is used ("default").
132
+ mixed (`bool`, `optional`, defaults to `False`):
133
+ Whether to allow mixing different (compatible) adapter types.
134
+ """
135
+ model_config = getattr(model, "config", {"model_type": "custom"})
136
+ if hasattr(model_config, "to_dict"):
137
+ model_config = model_config.to_dict()
138
+
139
+ peft_config.base_model_name_or_path = model.__dict__.get("name_or_path", None)
140
+
141
+ if mixed:
142
+ return PeftMixedModel(model, peft_config, adapter_name=adapter_name)
143
+
144
+ if peft_config.task_type not in MODEL_TYPE_TO_PEFT_MODEL_MAPPING.keys() and not peft_config.is_prompt_learning:
145
+ return PeftModel(model, peft_config, adapter_name=adapter_name)
146
+
147
+ if peft_config.is_prompt_learning:
148
+ peft_config = _prepare_prompt_learning_config(peft_config, model_config)
149
+ return MODEL_TYPE_TO_PEFT_MODEL_MAPPING[peft_config.task_type](model, peft_config, adapter_name=adapter_name)
150
+
151
+
152
+ def inject_adapter_in_model(
153
+ peft_config: PeftConfig, model: torch.nn.Module, adapter_name: str = "default"
154
+ ) -> torch.nn.Module:
155
+ r"""
156
+ A simple API to create and inject adapter in-place into a model. Currently the API does not support prompt learning
157
+ methods and adaption prompt. Make sure to have the correct `target_names` set in the `peft_config` object. The API
158
+ calls `get_peft_model` under the hood but would be restricted only to non-prompt learning methods.
159
+
160
+ Args:
161
+ peft_config (`PeftConfig`):
162
+ Configuration object containing the parameters of the Peft model.
163
+ model (`torch.nn.Module`):
164
+ The input model where the adapter will be injected.
165
+ adapter_name (`str`, `optional`, defaults to `"default"`):
166
+ The name of the adapter to be injected, if not provided, the default adapter name is used ("default").
167
+ """
168
+ if peft_config.is_prompt_learning or peft_config.is_adaption_prompt:
169
+ raise ValueError("`create_and_replace` does not support prompt learning and adaption prompt yet.")
170
+
171
+ if peft_config.peft_type not in PEFT_TYPE_TO_TUNER_MAPPING.keys():
172
+ raise ValueError(
173
+ f"`inject_adapter_in_model` does not support {peft_config.peft_type} yet. Please use `get_peft_model`."
174
+ )
175
+
176
+ tuner_cls = PEFT_TYPE_TO_TUNER_MAPPING[peft_config.peft_type]
177
+
178
+ # By instantiating a peft model we are injecting randomly initialized LoRA layers into the model's modules.
179
+ peft_model = tuner_cls(model, peft_config, adapter_name=adapter_name)
180
+
181
+ return peft_model.model
peft/mixed_model.py ADDED
@@ -0,0 +1,415 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023-present the HuggingFace Inc. team.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from __future__ import annotations
16
+
17
+ import os
18
+ from contextlib import contextmanager
19
+ from typing import Any, Optional, Union
20
+
21
+ import torch
22
+ from accelerate.hooks import remove_hook_from_submodules
23
+ from torch import nn
24
+ from transformers.utils import PushToHubMixin
25
+
26
+ from peft.tuners.mixed import COMPATIBLE_TUNER_TYPES
27
+
28
+ from .config import PeftConfig
29
+ from .peft_model import PeftModel
30
+ from .tuners import (
31
+ AdaLoraModel,
32
+ IA3Model,
33
+ LoHaModel,
34
+ LoKrModel,
35
+ LoraModel,
36
+ MixedModel,
37
+ OFTModel,
38
+ )
39
+ from .utils import PeftType, _set_adapter, _set_trainable
40
+
41
+
42
+ PEFT_TYPE_TO_MODEL_MAPPING = {
43
+ PeftType.LORA: LoraModel,
44
+ PeftType.LOHA: LoHaModel,
45
+ PeftType.LOKR: LoKrModel,
46
+ PeftType.ADALORA: AdaLoraModel,
47
+ PeftType.IA3: IA3Model,
48
+ PeftType.OFT: OFTModel,
49
+ }
50
+
51
+
52
+ def _prepare_model_for_gradient_checkpointing(model: nn.Module) -> None:
53
+ r"""
54
+ Prepares the model for gradient checkpointing if necessary
55
+ """
56
+ # Note: same as PeftModel._prepare_model_for_gradient_checkpointing
57
+ if not getattr(model, "is_gradient_checkpointing", True):
58
+ return model
59
+
60
+ if not (
61
+ getattr(model, "is_loaded_in_8bit", False)
62
+ or getattr(model, "is_loaded_in_4bit", False)
63
+ or getattr(model, "is_quantized", False)
64
+ ):
65
+ if hasattr(model, "enable_input_require_grads"):
66
+ model.enable_input_require_grads()
67
+ elif hasattr(model, "get_input_embeddings"):
68
+
69
+ def make_inputs_require_grad(module, input, output):
70
+ output.requires_grad_(True)
71
+
72
+ model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
73
+
74
+
75
+ def _check_config_compatible(peft_config: PeftConfig) -> None:
76
+ if peft_config.peft_type not in COMPATIBLE_TUNER_TYPES:
77
+ raise ValueError(
78
+ f"The provided `peft_type` '{peft_config.peft_type.value}' is not compatible with the `PeftMixedModel`. "
79
+ f"Compatible types are: {COMPATIBLE_TUNER_TYPES}"
80
+ )
81
+
82
+
83
+ class PeftMixedModel(PushToHubMixin, torch.nn.Module):
84
+ """
85
+ PeftMixedModel for loading mixing different types of adapters for inference.
86
+
87
+ This class does not support loading/saving, and it shouldn't usually be initialized directly. Instead, use
88
+ `get_peft_model` with the argument `mixed=True`.
89
+
90
+ <Tip>
91
+
92
+ Read the [Mixed adapter types](https://huggingface.co/docs/peft/en/developer_guides/mixed_models) guide to learn
93
+ more about using different adapter types.
94
+
95
+ </Tip>
96
+
97
+ Example:
98
+
99
+ ```py
100
+ >>> from peft import get_peft_model
101
+
102
+ >>> base_model = ... # load the base model, e.g. from transformers
103
+ >>> peft_model = PeftMixedModel.from_pretrained(base_model, path_to_adapter1, "adapter1").eval()
104
+ >>> peft_model.load_adapter(path_to_adapter2, "adapter2")
105
+ >>> peft_model.set_adapter(["adapter1", "adapter2"]) # activate both adapters
106
+ >>> peft_model(data) # forward pass using both adapters
107
+ ```
108
+
109
+ Args:
110
+ model (`torch.nn.Module`):
111
+ The model to be tuned.
112
+ config (`PeftConfig`):
113
+ The config of the model to be tuned. The adapter type must be compatible.
114
+ adapter_name (`str`, `optional`, defaults to `"default"`):
115
+ The name of the first adapter.
116
+ """
117
+
118
+ def __init__(self, model: nn.Module, peft_config: PeftConfig, adapter_name: str = "default") -> None:
119
+ super().__init__()
120
+ _check_config_compatible(peft_config)
121
+ _prepare_model_for_gradient_checkpointing(model)
122
+ self.modules_to_save = None
123
+ self.base_model = MixedModel(model, {adapter_name: peft_config}, adapter_name)
124
+ self.set_modules_to_save(peft_config, adapter_name)
125
+
126
+ self.config = getattr(model, "config", {"model_type": "custom"})
127
+
128
+ # the `pretraining_tp` is set for some models to simulate Tensor Parallelism during inference to avoid
129
+ # numerical differences, https://github.com/pytorch/pytorch/issues/76232 - to avoid any unexpected
130
+ # behavior we disable that in this line.
131
+ if hasattr(self.base_model, "config") and hasattr(self.base_model.config, "pretraining_tp"):
132
+ self.base_model.config.pretraining_tp = 1
133
+
134
+ @property
135
+ def peft_config(self) -> dict[str, PeftConfig]:
136
+ return self.base_model.peft_config
137
+
138
+ @property
139
+ def active_adapter(self) -> str:
140
+ return self.base_model.active_adapter
141
+
142
+ @property
143
+ def active_adapters(self) -> list[str]:
144
+ return self.base_model.active_adapters
145
+
146
+ def get_nb_trainable_parameters(self):
147
+ r"""
148
+ Returns the number of trainable parameters and number of all parameters in the model.
149
+ """
150
+ # note: same as PeftModel.get_nb_trainable_parameters
151
+ trainable_params = 0
152
+ all_param = 0
153
+ for _, param in self.named_parameters():
154
+ num_params = param.numel()
155
+ # if using DS Zero 3 and the weights are initialized empty
156
+ if num_params == 0 and hasattr(param, "ds_numel"):
157
+ num_params = param.ds_numel
158
+
159
+ # Due to the design of 4bit linear layers from bitsandbytes
160
+ # one needs to multiply the number of parameters by 2 to get
161
+ # the correct number of parameters
162
+ if param.__class__.__name__ == "Params4bit":
163
+ num_params = num_params * 2
164
+
165
+ all_param += num_params
166
+ if param.requires_grad:
167
+ trainable_params += num_params
168
+
169
+ return trainable_params, all_param
170
+
171
+ def print_trainable_parameters(self):
172
+ """
173
+ Prints the number of trainable parameters in the model.
174
+
175
+ Note: print_trainable_parameters() uses get_nb_trainable_parameters() which is different from
176
+ num_parameters(only_trainable=True) from huggingface/transformers. get_nb_trainable_parameters() returns
177
+ (trainable parameters, all parameters) of the Peft Model which includes modified backbone transformer model.
178
+ For techniques like LoRA, the backbone transformer model is modified in place with LoRA modules. However, for
179
+ prompt tuning, the backbone transformer model is unmodified. num_parameters(only_trainable=True) returns number
180
+ of trainable parameters of the backbone transformer model which can be different.
181
+ """
182
+ # note: same as PeftModel.print_trainable_parameters
183
+ trainable_params, all_param = self.get_nb_trainable_parameters()
184
+
185
+ print(
186
+ f"trainable params: {trainable_params:,d} || "
187
+ f"all params: {all_param:,d} || "
188
+ f"trainable%: {100 * trainable_params / all_param:.4f}"
189
+ )
190
+
191
+ def __getattr__(self, name: str):
192
+ """Forward missing attributes to the wrapped module."""
193
+ try:
194
+ return super().__getattr__(name) # defer to nn.Module's logic
195
+ except AttributeError:
196
+ return getattr(self.base_model, name)
197
+
198
+ def forward(self, *args: Any, **kwargs: Any):
199
+ """
200
+ Forward pass of the model.
201
+ """
202
+ return self.base_model(*args, **kwargs)
203
+
204
+ def generate(self, *args: Any, **kwargs: Any):
205
+ """
206
+ Generate output.
207
+ """
208
+ return self.base_model.generate(*args, **kwargs)
209
+
210
+ @contextmanager
211
+ def disable_adapter(self):
212
+ """
213
+ Disables the adapter module.
214
+ """
215
+ try:
216
+ self.base_model.disable_adapter_layers()
217
+ yield
218
+ finally:
219
+ self.base_model.enable_adapter_layers()
220
+
221
+ def add_adapter(self, adapter_name: str, peft_config: PeftConfig):
222
+ _check_config_compatible(peft_config)
223
+
224
+ try:
225
+ self.peft_config[adapter_name] = peft_config
226
+ self.base_model.inject_adapter(self, adapter_name)
227
+ except Exception: # something went wrong, roll back
228
+ if adapter_name in self.peft_config:
229
+ del self.peft_config[adapter_name]
230
+ raise
231
+
232
+ self.set_modules_to_save(peft_config, adapter_name)
233
+
234
+ def set_modules_to_save(self, peft_config: PeftConfig, adapter_name: str) -> None:
235
+ if (modules_to_save := getattr(peft_config, "modules_to_save", None)) is None:
236
+ return
237
+
238
+ if self.modules_to_save is None:
239
+ self.modules_to_save = set(modules_to_save)
240
+ else:
241
+ self.modules_to_save.update(modules_to_save)
242
+ _set_trainable(self, adapter_name)
243
+
244
+ def set_adapter(self, adapter_name: Union[str, list[str]]) -> None:
245
+ """
246
+ Sets the active adapter(s) for the model.
247
+
248
+ Note that the order in which the adapters are applied during the forward pass may not be the same as the order
249
+ in which they are passed to this function. Instead, the order during the forward pass is determined by the
250
+ order in which the adapters were loaded into the model. The active adapters only determine which adapters are
251
+ active during the forward pass, but not the order in which they are applied.
252
+
253
+ Additionally, this function will set the specified adapters to trainable (i.e., requires_grad=True). If this is
254
+ not desired, use the following code.
255
+
256
+ ```py
257
+ >>> for name, param in model_peft.named_parameters():
258
+ ... if ...: # some check on name (ex. if 'lora' in name)
259
+ ... param.requires_grad = False
260
+ ```
261
+
262
+ Args:
263
+ adapter_name (`str` or `List[str]`):
264
+ The name of the adapter(s) to be activated.
265
+ """
266
+ if isinstance(adapter_name, str):
267
+ adapter_name = [adapter_name]
268
+
269
+ mismatched = set(adapter_name) - set(self.peft_config.keys())
270
+ if mismatched:
271
+ raise ValueError(
272
+ f"Adapter(s) {sorted(mismatched)} not found, available adapters: {sorted(self.peft_config.keys())}"
273
+ )
274
+
275
+ self.base_model.set_adapter(adapter_name)
276
+ _set_adapter(self, adapter_name)
277
+
278
+ def delete_adapter(self, adapter_name: Union[str, list[str]]) -> None:
279
+ if isinstance(adapter_name, str):
280
+ adapter_name = [adapter_name]
281
+
282
+ mismatched = set(adapter_name) - set(self.peft_config.keys())
283
+ if mismatched:
284
+ raise ValueError(
285
+ f"Adapter(s) {sorted(mismatched)} not found, available adapters: {sorted(self.peft_config.keys())}"
286
+ )
287
+
288
+ self.base_model.delete_adapter(adapter_name)
289
+
290
+ def merge_and_unload(self, *args: Any, **kwargs: Any):
291
+ r"""
292
+ This method merges the adapter layers into the base model. This is needed if someone wants to use the base
293
+ model as a standalone model.
294
+
295
+ Args:
296
+ progressbar (`bool`):
297
+ whether to show a progressbar indicating the unload and merge process
298
+ safe_merge (`bool`):
299
+ whether to activate the safe merging check to check if there is any potential Nan in the adapter
300
+ weights
301
+ adapter_names (`List[str]`, *optional*):
302
+ The list of adapter names that should be merged. If None, all active adapters will be merged. Defaults
303
+ to `None`.
304
+ """
305
+ return self.base_model.merge_and_unload(*args, **kwargs)
306
+
307
+ def unload(self, *args: Any, **kwargs: Any):
308
+ """
309
+ Gets back the base model by removing all the adapter modules without merging. This gives back the original base
310
+ model.
311
+ """
312
+ return self.base_model.unload(*args, **kwargs)
313
+
314
+ def get_layer_status(self):
315
+ raise TypeError(f"get_layer_status is not supported for {self.__class__.__name__}.")
316
+
317
+ def get_model_status(self):
318
+ raise TypeError(f"get_model_status is not supported for {self.__class__.__name__}.")
319
+
320
+ @classmethod
321
+ def _split_kwargs(cls, kwargs: dict[str, Any]):
322
+ return PeftModel._split_kwargs(kwargs)
323
+
324
+ def load_adapter(self, model_id: str, adapter_name: str, *args: Any, **kwargs: Any):
325
+ output = PeftModel.load_adapter(self, model_id, adapter_name, *args, **kwargs)
326
+ # TODO: not quite clear why this is necessary but tests fail without it
327
+ self.set_adapter(self.active_adapters)
328
+ return output
329
+
330
+ def create_or_update_model_card(self, output_dir: str):
331
+ raise NotImplementedError(f"Model card creation is not supported for {self.__class__.__name__} (yet).")
332
+
333
+ def save_pretrained(
334
+ self,
335
+ save_directory: str,
336
+ safe_serialization: bool = False,
337
+ selected_adapters: Optional[list[str]] = None,
338
+ **kwargs: Any,
339
+ ):
340
+ raise NotImplementedError(f"Saving is not supported for {self.__class__.__name__} (yet).")
341
+
342
+ @classmethod
343
+ def from_pretrained(
344
+ cls,
345
+ model: nn.Module,
346
+ model_id: str | os.PathLike,
347
+ adapter_name: str = "default",
348
+ is_trainable: bool = False,
349
+ config: Optional[PeftConfig] = None,
350
+ **kwargs: Any,
351
+ ):
352
+ r"""
353
+ Instantiate a PEFT mixed model from a pretrained model and loaded PEFT weights.
354
+
355
+ Note that the passed `model` may be modified inplace.
356
+
357
+ Args:
358
+ model (`nn.Module`):
359
+ The model to be adapted.
360
+ model_id (`str` or `os.PathLike`):
361
+ The name of the PEFT configuration to use. Can be either:
362
+ - A string, the `model id` of a PEFT configuration hosted inside a model repo on the Hugging Face
363
+ Hub.
364
+ - A path to a directory containing a PEFT configuration file saved using the `save_pretrained`
365
+ method (`./my_peft_config_directory/`).
366
+ adapter_name (`str`, *optional*, defaults to `"default"`):
367
+ The name of the adapter to be loaded. This is useful for loading multiple adapters.
368
+ is_trainable (`bool`, *optional*, defaults to `False`):
369
+ Whether the adapter should be trainable or not. If `False`, the adapter will be frozen and use for
370
+ inference
371
+ config ([`~peft.PeftConfig`], *optional*):
372
+ The configuration object to use instead of an automatically loaded configuration. This configuration
373
+ object is mutually exclusive with `model_id` and `kwargs`. This is useful when configuration is already
374
+ loaded before calling `from_pretrained`.
375
+ kwargs: (`optional`):
376
+ Additional keyword arguments passed along to the specific PEFT configuration class.
377
+ """
378
+ # note: adapted from PeftModel.from_pretrained
379
+ from .mapping import PEFT_TYPE_TO_CONFIG_MAPPING
380
+
381
+ # load the config
382
+ if config is None:
383
+ config = PEFT_TYPE_TO_CONFIG_MAPPING[
384
+ PeftConfig._get_peft_type(
385
+ model_id,
386
+ subfolder=kwargs.get("subfolder", None),
387
+ revision=kwargs.get("revision", None),
388
+ cache_dir=kwargs.get("cache_dir", None),
389
+ use_auth_token=kwargs.get("use_auth_token", None),
390
+ )
391
+ ].from_pretrained(model_id, **kwargs)
392
+ elif isinstance(config, PeftConfig):
393
+ config.inference_mode = not is_trainable
394
+ else:
395
+ raise ValueError(f"The input config must be a PeftConfig, got {config.__class__}")
396
+
397
+ # note: this is different from PeftModel.from_pretrained
398
+ if config.peft_type not in PEFT_TYPE_TO_MODEL_MAPPING:
399
+ raise ValueError(f"Adapter of type {config.peft_type} is not supported for mixed models.")
400
+
401
+ if (getattr(model, "hf_device_map", None) is not None) and len(
402
+ set(model.hf_device_map.values()).intersection({"cpu", "disk"})
403
+ ) > 0:
404
+ remove_hook_from_submodules(model)
405
+
406
+ if config.is_prompt_learning and is_trainable:
407
+ # note: should not be possible to reach, but just in case
408
+ raise ValueError("Cannot set a prompt learning adapter to trainable when loading pretrained adapter.")
409
+ else:
410
+ config.inference_mode = not is_trainable
411
+
412
+ # note: this is different from PeftModel.from_pretrained, we always return a PeftMixedModel
413
+ model = cls(model, config, adapter_name)
414
+ model.load_adapter(model_id, adapter_name, is_trainable=is_trainable, **kwargs)
415
+ return model
peft/peft_model.py ADDED
The diff for this file is too large to render. See raw diff
 
peft/py.typed ADDED
File without changes
peft/tuners/__init__.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # flake8: noqa
2
+ # There's no way to ignore "F401 '...' imported but unused" warnings in this
3
+ # module, but to preserve other warnings. So, don't check this module at all
4
+
5
+ # coding=utf-8
6
+ # Copyright 2023-present the HuggingFace Inc. team.
7
+ #
8
+ # Licensed under the Apache License, Version 2.0 (the "License");
9
+ # you may not use this file except in compliance with the License.
10
+ # You may obtain a copy of the License at
11
+ #
12
+ # http://www.apache.org/licenses/LICENSE-2.0
13
+ #
14
+ # Unless required by applicable law or agreed to in writing, software
15
+ # distributed under the License is distributed on an "AS IS" BASIS,
16
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17
+ # See the License for the specific language governing permissions and
18
+ # limitations under the License.
19
+
20
+ from .adaption_prompt import AdaptionPromptConfig, AdaptionPromptModel
21
+ from .lora import LoraConfig, LoraModel, LoftQConfig
22
+ from .loha import LoHaConfig, LoHaModel
23
+ from .lokr import LoKrConfig, LoKrModel
24
+ from .ia3 import IA3Config, IA3Model
25
+ from .adalora import AdaLoraConfig, AdaLoraModel
26
+ from .boft import BOFTConfig, BOFTModel
27
+ from .p_tuning import PromptEncoder, PromptEncoderConfig, PromptEncoderReparameterizationType
28
+ from .prefix_tuning import PrefixEncoder, PrefixTuningConfig
29
+ from .prompt_tuning import PromptEmbedding, PromptTuningConfig, PromptTuningInit
30
+ from .multitask_prompt_tuning import MultitaskPromptEmbedding, MultitaskPromptTuningConfig, MultitaskPromptTuningInit
31
+ from .oft import OFTConfig, OFTModel
32
+ from .mixed import MixedModel
33
+ from .poly import PolyConfig, PolyModel
34
+ from .ln_tuning import LNTuningConfig, LNTuningModel
35
+ from .vera import VeraConfig, VeraModel
peft/tuners/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (1.38 kB). View file
 
peft/tuners/__pycache__/lycoris_utils.cpython-310.pyc ADDED
Binary file (14.3 kB). View file
 
peft/tuners/__pycache__/tuners_utils.cpython-310.pyc ADDED
Binary file (27.1 kB). View file
 
peft/tuners/adalora/__init__.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023-present the HuggingFace Inc. team.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from peft.import_utils import is_bnb_4bit_available, is_bnb_available
16
+
17
+ from .config import AdaLoraConfig
18
+ from .gptq import SVDQuantLinear
19
+ from .layer import AdaLoraLayer, RankAllocator, SVDLinear
20
+ from .model import AdaLoraModel
21
+
22
+
23
+ __all__ = ["AdaLoraConfig", "AdaLoraLayer", "AdaLoraModel", "SVDLinear", "RankAllocator", "SVDQuantLinear"]
24
+
25
+
26
+ def __getattr__(name):
27
+ if (name == "SVDLinear8bitLt") and is_bnb_available():
28
+ from .bnb import SVDLinear8bitLt
29
+
30
+ return SVDLinear8bitLt
31
+
32
+ if (name == "SVDLinear4bit") and is_bnb_4bit_available():
33
+ from .bnb import SVDLinear4bit
34
+
35
+ return SVDLinear4bit
36
+
37
+ raise AttributeError(f"module {__name__} has no attribute {name}")
peft/tuners/adalora/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (867 Bytes). View file
 
peft/tuners/adalora/__pycache__/bnb.cpython-310.pyc ADDED
Binary file (3.18 kB). View file
 
peft/tuners/adalora/__pycache__/config.cpython-310.pyc ADDED
Binary file (2.85 kB). View file
 
peft/tuners/adalora/__pycache__/gptq.cpython-310.pyc ADDED
Binary file (1.6 kB). View file
 
peft/tuners/adalora/__pycache__/layer.cpython-310.pyc ADDED
Binary file (10.7 kB). View file
 
peft/tuners/adalora/__pycache__/model.cpython-310.pyc ADDED
Binary file (10.2 kB). View file
 
peft/tuners/adalora/bnb.py ADDED
@@ -0,0 +1,145 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023-present the HuggingFace Inc. team.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import Any
16
+
17
+ import torch
18
+
19
+ from peft.import_utils import is_bnb_4bit_available, is_bnb_available
20
+
21
+ from .layer import AdaLoraLayer
22
+
23
+
24
+ if is_bnb_available():
25
+
26
+ class SVDLinear8bitLt(torch.nn.Module, AdaLoraLayer):
27
+ # Low-rank matrix for SVD-based adaptation
28
+ def __init__(
29
+ self,
30
+ base_layer: torch.nn.Module,
31
+ adapter_name: str,
32
+ r: int = 0,
33
+ lora_alpha: int = 1,
34
+ lora_dropout: float = 0.0,
35
+ init_lora_weights: bool = True,
36
+ **kwargs,
37
+ ) -> None:
38
+ super().__init__()
39
+ AdaLoraLayer.__init__(self, base_layer)
40
+ # Freezing the pre-trained weight matrix
41
+ self.get_base_layer().weight.requires_grad = False
42
+
43
+ self._active_adapter = adapter_name
44
+ self.update_layer(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights)
45
+
46
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
47
+ # note: no check for self.merged because merging is not supported (yet)
48
+ result = self.base_layer(x)
49
+
50
+ if self.disable_adapters:
51
+ return result
52
+
53
+ for active_adapter in self.active_adapters:
54
+ if active_adapter not in self.lora_A.keys():
55
+ continue
56
+ requires_conversion = not torch.is_autocast_enabled()
57
+ if requires_conversion:
58
+ expected_dtype = result.dtype
59
+ if x.dtype != torch.float32:
60
+ x = x.float()
61
+
62
+ lora_A = self.lora_A[active_adapter]
63
+ lora_B = self.lora_B[active_adapter]
64
+ lora_E = self.lora_E[active_adapter]
65
+ dropout = self.lora_dropout[active_adapter]
66
+ scaling = self.scaling[active_adapter]
67
+ ranknum = self.ranknum[active_adapter] + 1e-5
68
+
69
+ output = dropout(x) @ (lora_A * lora_E).T @ lora_B.T
70
+ if requires_conversion:
71
+ output = output.to(expected_dtype)
72
+ output = output * scaling / ranknum
73
+ # inplace operation on view is forbidden for MatMul8bitLtBackward, so avoid it
74
+ result = result + output
75
+ return result
76
+
77
+ def __repr__(self) -> str:
78
+ rep = super().__repr__()
79
+ return "adalora." + rep
80
+
81
+
82
+ if is_bnb_4bit_available():
83
+
84
+ class SVDLinear4bit(torch.nn.Module, AdaLoraLayer):
85
+ # Low-rank matrix for SVD-based adaptation
86
+ def __init__(
87
+ self,
88
+ base_layer: torch.nn.Module,
89
+ adapter_name: str,
90
+ r: int = 0,
91
+ lora_alpha: int = 1,
92
+ lora_dropout: float = 0.0,
93
+ init_lora_weights: bool = True,
94
+ **kwargs,
95
+ ) -> None:
96
+ super().__init__()
97
+ AdaLoraLayer.__init__(self, base_layer)
98
+ # Freezing the pre-trained weight matrix
99
+ self.get_base_layer().weight.requires_grad = False
100
+
101
+ self._active_adapter = adapter_name
102
+ self.update_layer(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights)
103
+
104
+ def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
105
+ # note: no check for self.merged because merging is not supported (yet)
106
+ result = self.base_layer(x, *args, **kwargs)
107
+
108
+ if self.disable_adapters:
109
+ return result
110
+
111
+ # As per Tim Dettmers, for 4bit, we need to defensively clone here.
112
+ # The reason is that in some cases, an error can occur that backprop
113
+ # does not work on a manipulated view. This issue may be solved with
114
+ # newer PyTorch versions but this would need extensive testing to be
115
+ # sure.
116
+ result = result.clone()
117
+
118
+ for active_adapter in self.active_adapters:
119
+ if active_adapter not in self.lora_A.keys():
120
+ continue
121
+
122
+ lora_A = self.lora_A[active_adapter]
123
+ lora_B = self.lora_B[active_adapter]
124
+ lora_E = self.lora_E[active_adapter]
125
+ dropout = self.lora_dropout[active_adapter]
126
+ scaling = self.scaling[active_adapter]
127
+ ranknum = self.ranknum[active_adapter] + 1e-5
128
+
129
+ requires_conversion = not torch.is_autocast_enabled()
130
+ if requires_conversion:
131
+ expected_dtype = result.dtype
132
+ compute_dtype = lora_A.dtype
133
+ if x.dtype != compute_dtype:
134
+ x = x.to(compute_dtype)
135
+
136
+ output = dropout(x) @ (lora_A * lora_E).T @ lora_B.T
137
+ if requires_conversion:
138
+ output = output.to(expected_dtype)
139
+ output = output * scaling / ranknum
140
+ result += output
141
+ return result
142
+
143
+ def __repr__(self) -> str:
144
+ rep = super().__repr__()
145
+ return "adalora." + rep