Tello2020 committed on
Commit fd5f698
1 Parent(s): b3ae437

Upload 14 files

Files changed (14)
  1. LICENSE +21 -0
  2. README.md +20 -12
  3. bucketing.py +32 -0
  4. cog.yaml +18 -0
  5. dataset.py +581 -0
  6. download-weights +48 -0
  7. inference.py +238 -0
  8. lama.py +350 -0
  9. lora.py +1312 -0
  10. predict.py +101 -0
  11. samples.py +57 -0
  12. train.py +998 -0
  13. unet_3d_blocks.py +836 -0
  14. unet_3d_condition.py +499 -0
LICENSE ADDED
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2023 ExponentialML

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
README.md CHANGED
@@ -1,12 +1,20 @@
- ---
- title: Text2video2024
- emoji: 🦀
- colorFrom: yellow
- colorTo: purple
- sdk: streamlit
- sdk_version: 1.25.0
- app_file: app.py
- pinned: false
- ---
-
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ # cog-text2video
+
+ A Cog implementation with txt2vid and vid2vid of:
+
+ - https://huggingface.co/cerspense/zeroscope_v2_XL
+ - https://huggingface.co/cerspense/zeroscope_v2_576w
+ - https://huggingface.co/camenduru/potat1
+
+ Deployed at https://replicate.com/anotherjesse/zeroscope-v2-xl
+
+ ## Shoutouts
+
+ - [Text-To-Video-Finetuning](https://github.com/camenduru/Text-To-Video-Finetuning) - Finetune ModelScope's Text To Video model using Diffusers
+ - [Showlab](https://github.com/showlab/Tune-A-Video) and [bryandlee](https://github.com/bryandlee/Tune-A-Video) for their Tune-A-Video contribution that made this much easier.
+ - [lucidrains](https://github.com/lucidrains) for their implementations around video diffusion.
+ - [cloneofsimo](https://github.com/cloneofsimo) for their diffusers implementation of LoRA.
+ - [kabachuha](https://github.com/kabachuha) for their conversion scripts, training ideas, and webui works.
+ - [JCBrouwer](https://github.com/JCBrouwer) for inference implementations.
+ - [sergiobr](https://github.com/sergiobr) for helpful ideas and bug fixes.
+ - [cjwbw/damo-text-to-video](https://replicate.com/cjwbw/damo-text-to-video) for the original [cog](https://github.com/replicate/cog) implementation
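As a minimal smoke-test sketch of the deployment linked above, the Replicate Python client can be used roughly as follows; the input name "prompt" is an assumption (the real schema lives in predict.py), and depending on the client version a pinned model version hash may be required.

import replicate

# Hypothetical client call; check predict.py for the actual input schema.
output = replicate.run(
    "anotherjesse/zeroscope-v2-xl",
    input={"prompt": "a corgi surfing a wave"},
)
print(output)  # URL(s) of the rendered mp4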
bucketing.py ADDED
@@ -0,0 +1,32 @@
from PIL import Image

# Clamp a bucket resolution so it never drops below the configured minimum.
def min_res(size, min_size): return max(size, min_size)

def up_down_bucket(m_size, in_size, direction):
    if direction == 'down': return abs(int(m_size - in_size))
    if direction == 'up': return abs(int(m_size + in_size))

def get_bucket_sizes(size, direction, min_size):
    multipliers = [64, 128]
    for i, m in enumerate(multipliers):
        res = up_down_bucket(m, size, direction)
        multipliers[i] = min_res(res, min_size=min_size)
    return multipliers

def closest_bucket(m_size, size, direction, min_size):
    lst = get_bucket_sizes(m_size, direction, min_size)
    return lst[min(range(len(lst)), key=lambda i: abs(lst[i] - size))]

def resolve_bucket(i, h, w): return (i / (h / w))

def sensible_buckets(m_width, m_height, w, h, min_size=192):
    # Keep the longer side at the requested size and snap the shorter side to a bucket.
    if h > w:
        w = resolve_bucket(m_width, h, w)
        w = closest_bucket(m_width, w, 'down', min_size=min_size)
        return w, m_height
    if h < w:
        h = resolve_bucket(m_height, w, h)
        h = closest_bucket(m_height, h, 'down', min_size=min_size)
        return m_width, h

    return m_width, m_height
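A short usage sketch of the helpers above: with 384x384 target buckets and a 1280x720 landscape clip, the width stays at the target and the height snaps down to a smaller bucket. The flat import path is an assumption; the datasets in this commit import the module as a package.

from bucketing import sensible_buckets

width, height = sensible_buckets(384, 384, w=1280, h=720)
print(width, height)  # -> 384 256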
cog.yaml ADDED
@@ -0,0 +1,18 @@
build:
  gpu: true
  python_version: "3.10"
  cuda: "11.7"
  python_packages:
    - "accelerate==0.20.3"
    - "diffusers==0.17.1"
    - "gradio==3.35.2"
    - "imageio[ffmpeg]==2.31.1"
    - "torch==2.0.1"
    - "torchvision==0.15.2"
    - "transformers==4.30.2"
    - "einops==0.6.1"
    - "omegaconf==2.3.0"
    - "opencv-python-headless==4.7.0.72"
    - "decord==0.6.0"

predict: "predict.py:Predictor"
dataset.py ADDED
@@ -0,0 +1,581 @@
import os
import decord
import numpy as np
import random
import json
import torchvision
import torchvision.transforms as T
import torch

from glob import glob
from PIL import Image
from itertools import islice
from pathlib import Path
from .bucketing import sensible_buckets

decord.bridge.set_bridge('torch')

from torch.utils.data import Dataset
from einops import rearrange, repeat

def get_prompt_ids(prompt, tokenizer):
    prompt_ids = tokenizer(
        prompt,
        truncation=True,
        padding="max_length",
        max_length=tokenizer.model_max_length,
        return_tensors="pt",
    ).input_ids

    return prompt_ids

def read_caption_file(caption_file):
    with open(caption_file, 'r', encoding="utf8") as t:
        return t.read()

def get_text_prompt(
    text_prompt: str = '',
    fallback_prompt: str = '',
    file_path: str = '',
    ext_types=['.mp4'],
    use_caption=False
):
    try:
        if use_caption:
            if len(text_prompt) > 1: return text_prompt
            caption_file = ''
            # Use caption on per-video basis (One caption PER video)
            for ext in ext_types:
                maybe_file = file_path.replace(ext, '.txt')
                if maybe_file.endswith(ext_types): continue
                if os.path.exists(maybe_file):
                    caption_file = maybe_file
                    break

            if os.path.exists(caption_file):
                return read_caption_file(caption_file)

            # Return fallback prompt if no conditions are met.
            return fallback_prompt

        return text_prompt
    except:
        print(f"Couldn't read prompt caption for {file_path}. Using fallback.")
        return fallback_prompt


def get_video_frames(vr, start_idx, sample_rate=1, max_frames=24):
    max_range = len(vr)
    frame_number = sorted((0, start_idx, max_range))[1]

    frame_range = range(frame_number, max_range, sample_rate)
    frame_range_indices = list(frame_range)[:max_frames]

    return frame_range_indices

def process_video(vid_path, use_bucketing, w, h, get_frame_buckets, get_frame_batch):
    if use_bucketing:
        vr = decord.VideoReader(vid_path)
        resize = get_frame_buckets(vr)
        video = get_frame_batch(vr, resize=resize)

    else:
        vr = decord.VideoReader(vid_path, width=w, height=h)
        video = get_frame_batch(vr)

    return video, vr

# https://github.com/ExponentialML/Video-BLIP2-Preprocessor
class VideoJsonDataset(Dataset):
    def __init__(
        self,
        tokenizer = None,
        width: int = 256,
        height: int = 256,
        n_sample_frames: int = 4,
        sample_start_idx: int = 1,
        frame_step: int = 1,
        json_path: str = "",
        json_data = None,
        vid_data_key: str = "video_path",
        preprocessed: bool = False,
        use_bucketing: bool = False,
        **kwargs
    ):
        self.vid_types = (".mp4", ".avi", ".mov", ".webm", ".flv", ".mjpeg")
        self.use_bucketing = use_bucketing
        self.tokenizer = tokenizer
        self.preprocessed = preprocessed

        self.vid_data_key = vid_data_key
        self.train_data = self.load_from_json(json_path, json_data)

        self.width = width
        self.height = height

        self.n_sample_frames = n_sample_frames
        self.sample_start_idx = sample_start_idx
        self.frame_step = frame_step

    def build_json(self, json_data):
        extended_data = []
        for data in json_data['data']:
            for nested_data in data['data']:
                self.build_json_dict(
                    data,
                    nested_data,
                    extended_data
                )
        json_data = extended_data
        return json_data

    def build_json_dict(self, data, nested_data, extended_data):
        clip_path = nested_data['clip_path'] if 'clip_path' in nested_data else None

        extended_data.append({
            self.vid_data_key: data[self.vid_data_key],
            'frame_index': nested_data['frame_index'],
            'prompt': nested_data['prompt'],
            'clip_path': clip_path
        })

    def load_from_json(self, path, json_data):
        try:
            with open(path) as jpath:
                print(f"Loading JSON from {path}")
                json_data = json.load(jpath)

            return self.build_json(json_data)

        except:
            self.train_data = []
            print("Non-existent JSON path. Skipping.")

    def validate_json(self, base_path, path):
        return os.path.exists(f"{base_path}/{path}")

    def get_frame_range(self, vr):
        return get_video_frames(
            vr,
            self.sample_start_idx,
            self.frame_step,
            self.n_sample_frames
        )

    def get_vid_idx(self, vr, vid_data=None):
        frames = self.n_sample_frames

        if vid_data is not None:
            idx = vid_data['frame_index']
        else:
            idx = self.sample_start_idx

        return idx

    def get_frame_buckets(self, vr):
        _, h, w = vr[0].shape
        width, height = sensible_buckets(self.width, self.height, h, w)
        resize = T.transforms.Resize((height, width), antialias=True)

        return resize

    def get_frame_batch(self, vr, resize=None):
        frame_range = self.get_frame_range(vr)
        frames = vr.get_batch(frame_range)
        video = rearrange(frames, "f h w c -> f c h w")

        if resize is not None: video = resize(video)
        return video

    def process_video_wrapper(self, vid_path):
        video, vr = process_video(
            vid_path,
            self.use_bucketing,
            self.width,
            self.height,
            self.get_frame_buckets,
            self.get_frame_batch
        )

        return video, vr

    def train_data_batch(self, index):

        # If we are training on individual clips.
        if 'clip_path' in self.train_data[index] and \
            self.train_data[index]['clip_path'] is not None:

            vid_data = self.train_data[index]

            clip_path = vid_data['clip_path']

            # Get video prompt
            prompt = vid_data['prompt']

            video, _ = self.process_video_wrapper(clip_path)

            prompt_ids = get_prompt_ids(prompt, self.tokenizer)

            return video, prompt, prompt_ids

        # Assign train data
        train_data = self.train_data[index]

        # Get the frame of the current index.
        self.sample_start_idx = train_data['frame_index']

        # Initialize resize
        resize = None

        video, vr = self.process_video_wrapper(train_data[self.vid_data_key])

        # Get video prompt
        prompt = train_data['prompt']
        vr.seek(0)

        prompt_ids = get_prompt_ids(prompt, self.tokenizer)

        return video, prompt, prompt_ids

    @staticmethod
    def __getname__(): return 'json'

    def __len__(self):
        if self.train_data is not None:
            return len(self.train_data)
        else:
            return 0

    def __getitem__(self, index):

        # Initialize variables
        video = None
        prompt = None
        prompt_ids = None

        # Use default JSON training
        if self.train_data is not None:
            video, prompt, prompt_ids = self.train_data_batch(index)

        example = {
            "pixel_values": (video / 127.5 - 1.0),
            "prompt_ids": prompt_ids[0],
            "text_prompt": prompt,
            'dataset': self.__getname__()
        }

        return example


class SingleVideoDataset(Dataset):
    def __init__(
        self,
        tokenizer = None,
        width: int = 256,
        height: int = 256,
        n_sample_frames: int = 4,
        frame_step: int = 1,
        single_video_path: str = "",
        single_video_prompt: str = "",
        use_caption: bool = False,
        use_bucketing: bool = False,
        **kwargs
    ):
        self.tokenizer = tokenizer
        self.use_bucketing = use_bucketing
        self.frames = []
        self.index = 1

        self.vid_types = (".mp4", ".avi", ".mov", ".webm", ".flv", ".mjpeg")
        self.n_sample_frames = n_sample_frames
        self.frame_step = frame_step

        self.single_video_path = single_video_path
        self.single_video_prompt = single_video_prompt

        self.width = width
        self.height = height

    def create_video_chunks(self):
        # Create a list of frames separated by sample frames
        # [(1,2,3), (4,5,6), ...]
        vr = decord.VideoReader(self.single_video_path)
        vr_range = range(1, len(vr), self.frame_step)

        self.frames = list(self.chunk(vr_range, self.n_sample_frames))

        # Delete any list that contains an out of range index.
        for i, inner_frame_nums in enumerate(self.frames):
            for frame_num in inner_frame_nums:
                if frame_num > len(vr):
                    print(f"Removing out of range index list at position: {i}...")
                    del self.frames[i]

        return self.frames

    def chunk(self, it, size):
        it = iter(it)
        return iter(lambda: tuple(islice(it, size)), ())

    def get_frame_batch(self, vr, resize=None):
        index = self.index
        frames = vr.get_batch(self.frames[self.index])
        video = rearrange(frames, "f h w c -> f c h w")

        if resize is not None: video = resize(video)
        return video

    def get_frame_buckets(self, vr):
        _, h, w = vr[0].shape
        width, height = sensible_buckets(self.width, self.height, h, w)
        resize = T.transforms.Resize((height, width), antialias=True)

        return resize

    def process_video_wrapper(self, vid_path):
        video, vr = process_video(
            vid_path,
            self.use_bucketing,
            self.width,
            self.height,
            self.get_frame_buckets,
            self.get_frame_batch
        )

        return video, vr

    def single_video_batch(self, index):
        train_data = self.single_video_path
        self.index = index

        if train_data.endswith(self.vid_types):
            video, _ = self.process_video_wrapper(train_data)

            prompt = self.single_video_prompt
            prompt_ids = get_prompt_ids(prompt, self.tokenizer)

            return video, prompt, prompt_ids
        else:
            raise ValueError(f"Single video is not a video type. Types: {self.vid_types}")

    @staticmethod
    def __getname__(): return 'single_video'

    def __len__(self):

        return len(self.create_video_chunks())

    def __getitem__(self, index):

        video, prompt, prompt_ids = self.single_video_batch(index)

        example = {
            "pixel_values": (video / 127.5 - 1.0),
            "prompt_ids": prompt_ids[0],
            "text_prompt": prompt,
            'dataset': self.__getname__()
        }

        return example

class ImageDataset(Dataset):

    def __init__(
        self,
        tokenizer = None,
        width: int = 256,
        height: int = 256,
        base_width: int = 256,
        base_height: int = 256,
        use_caption: bool = False,
        image_dir: str = '',
        single_img_prompt: str = '',
        use_bucketing: bool = False,
        fallback_prompt: str = '',
        **kwargs
    ):
        self.tokenizer = tokenizer
        self.img_types = (".png", ".jpg", ".jpeg", '.bmp')
        self.use_bucketing = use_bucketing

        self.image_dir = self.get_images_list(image_dir)
        self.fallback_prompt = fallback_prompt

        self.use_caption = use_caption
        self.single_img_prompt = single_img_prompt

        self.width = width
        self.height = height

    def get_images_list(self, image_dir):
        if os.path.exists(image_dir):
            imgs = [x for x in os.listdir(image_dir) if x.endswith(self.img_types)]
            full_img_dir = []

            for img in imgs:
                full_img_dir.append(f"{image_dir}/{img}")

            return sorted(full_img_dir)

        return ['']

    def image_batch(self, index):
        train_data = self.image_dir[index]
        img = train_data

        try:
            img = torchvision.io.read_image(img, mode=torchvision.io.ImageReadMode.RGB)
        except:
            img = T.transforms.PILToTensor()(Image.open(img).convert("RGB"))

        width = self.width
        height = self.height

        if self.use_bucketing:
            _, h, w = img.shape
            width, height = sensible_buckets(width, height, w, h)

        resize = T.transforms.Resize((height, width), antialias=True)

        img = resize(img)
        img = repeat(img, 'c h w -> f c h w', f=1)

        prompt = get_text_prompt(
            file_path=train_data,
            text_prompt=self.single_img_prompt,
            fallback_prompt=self.fallback_prompt,
            ext_types=self.img_types,
            use_caption=True
        )
        prompt_ids = get_prompt_ids(prompt, self.tokenizer)

        return img, prompt, prompt_ids

    @staticmethod
    def __getname__(): return 'image'

    def __len__(self):
        # Image directory
        if os.path.exists(self.image_dir[0]):
            return len(self.image_dir)
        else:
            return 0

    def __getitem__(self, index):
        img, prompt, prompt_ids = self.image_batch(index)
        example = {
            "pixel_values": (img / 127.5 - 1.0),
            "prompt_ids": prompt_ids[0],
            "text_prompt": prompt,
            'dataset': self.__getname__()
        }

        return example

class VideoFolderDataset(Dataset):
    def __init__(
        self,
        tokenizer=None,
        width: int = 256,
        height: int = 256,
        n_sample_frames: int = 16,
        fps: int = 8,
        path: str = "./data",
        fallback_prompt: str = "",
        use_bucketing: bool = False,
        **kwargs
    ):
        self.tokenizer = tokenizer
        self.use_bucketing = use_bucketing

        self.fallback_prompt = fallback_prompt

        self.video_files = glob(f"{path}/*.mp4")

        self.width = width
        self.height = height

        self.n_sample_frames = n_sample_frames
        self.fps = fps

    def get_frame_buckets(self, vr):
        _, h, w = vr[0].shape
        width, height = sensible_buckets(self.width, self.height, h, w)
        resize = T.transforms.Resize((height, width), antialias=True)

        return resize

    def get_frame_batch(self, vr, resize=None):
        n_sample_frames = self.n_sample_frames
        native_fps = vr.get_avg_fps()

        every_nth_frame = max(1, round(native_fps / self.fps))
        every_nth_frame = min(len(vr), every_nth_frame)

        effective_length = len(vr) // every_nth_frame
        if effective_length < n_sample_frames:
            n_sample_frames = effective_length

        effective_idx = random.randint(0, (effective_length - n_sample_frames))
        idxs = every_nth_frame * np.arange(effective_idx, effective_idx + n_sample_frames)

        video = vr.get_batch(idxs)
        video = rearrange(video, "f h w c -> f c h w")

        if resize is not None: video = resize(video)
        return video, vr

    def process_video_wrapper(self, vid_path):
        video, vr = process_video(
            vid_path,
            self.use_bucketing,
            self.width,
            self.height,
            self.get_frame_buckets,
            self.get_frame_batch
        )
        return video, vr

    def get_prompt_ids(self, prompt):
        return self.tokenizer(
            prompt,
            truncation=True,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            return_tensors="pt",
        ).input_ids

    @staticmethod
    def __getname__(): return 'folder'

    def __len__(self):
        return len(self.video_files)

    def __getitem__(self, index):

        video, _ = self.process_video_wrapper(self.video_files[index])

        if os.path.exists(self.video_files[index].replace(".mp4", ".txt")):
            with open(self.video_files[index].replace(".mp4", ".txt"), "r") as f:
                prompt = f.read()
        else:
            prompt = self.fallback_prompt

        prompt_ids = self.get_prompt_ids(prompt)

        return {"pixel_values": (video[0] / 127.5 - 1.0), "prompt_ids": prompt_ids[0], "text_prompt": prompt, 'dataset': self.__getname__()}

class CachedDataset(Dataset):
    def __init__(self, cache_dir: str = ''):
        self.cache_dir = cache_dir
        self.cached_data_list = self.get_files_list()

    def get_files_list(self):
        tensors_list = [f"{self.cache_dir}/{x}" for x in os.listdir(self.cache_dir) if x.endswith('.pt')]
        return sorted(tensors_list)

    def __len__(self):
        return len(self.cached_data_list)

    def __getitem__(self, index):
        cached_latent = torch.load(self.cached_data_list[index], map_location='cuda:0')
        return cached_latent
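A minimal sketch of wiring VideoFolderDataset into a DataLoader, assuming a ./data folder of .mp4 clips with optional same-named .txt captions; the CLIP tokenizer checkpoint is an assumption here, since the training script normally supplies its own.

from torch.utils.data import DataLoader
from transformers import CLIPTokenizer

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")  # assumed checkpoint
dataset = VideoFolderDataset(
    tokenizer=tokenizer,
    width=256,
    height=256,
    n_sample_frames=16,
    fps=8,
    path="./data",
    fallback_prompt="a video",
)
loader = DataLoader(dataset, batch_size=1, shuffle=True)
batch = next(iter(loader))
print(batch["pixel_values"].shape)  # (1, frames, channels, height, width), values in [-1, 1]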
download-weights ADDED
@@ -0,0 +1,48 @@
#!/usr/bin/env python


import os
import shutil
import torch
from diffusers import DiffusionPipeline

MODEL_CACHE = "model-cache"
TMP_CACHE = "tmp-cache"

if os.path.exists(MODEL_CACHE):
    shutil.rmtree(MODEL_CACHE)
os.makedirs(MODEL_CACHE, exist_ok=True)

pipe = DiffusionPipeline.from_pretrained(
    "cerspense/zeroscope_v2_XL",
    torch_dtype=torch.float16,
    cache_dir=TMP_CACHE,
)

pipe.save_pretrained(MODEL_CACHE + "/xl")

pipe = DiffusionPipeline.from_pretrained(
    "cerspense/zeroscope_v2_576w",
    torch_dtype=torch.float16,
    cache_dir=TMP_CACHE,
)

pipe.save_pretrained(MODEL_CACHE + "/576w")

pipe = DiffusionPipeline.from_pretrained(
    "camenduru/potat1",
    torch_dtype=torch.float16,
    cache_dir=TMP_CACHE,
)

pipe.save_pretrained(MODEL_CACHE + "/potat1")

pipe = DiffusionPipeline.from_pretrained(
    "strangeman3107/animov-512x",
    torch_dtype=torch.float16,
    cache_dir=TMP_CACHE,
)

pipe.save_pretrained(MODEL_CACHE + "/animov-512x")

shutil.rmtree(TMP_CACHE)
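The directories written out by this script can later be loaded straight from the local cache. A sketch of how a predictor would presumably pick one up, mirroring the half precision used above (the exact loading code lives in predict.py, which is not shown here):

import torch
from diffusers import DiffusionPipeline

# Load one of the locally cached pipelines saved by download-weights.
pipe = DiffusionPipeline.from_pretrained("model-cache/576w", torch_dtype=torch.float16)
pipe = pipe.to("cuda")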
inference.py ADDED
@@ -0,0 +1,238 @@
import argparse
import os
import warnings
from pathlib import Path
from uuid import uuid4
from utils.lora import inject_inferable_lora
import torch
from diffusers import DPMSolverMultistepScheduler, TextToVideoSDPipeline
from models.unet_3d_condition import UNet3DConditionModel
from einops import rearrange
from torch.nn.functional import interpolate
import imageio
import decord

from train import handle_memory_attention, load_primary_models
from utils.lama import inpaint_watermark


def initialize_pipeline(model, device="cuda", xformers=False, sdp=False):
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")

        scheduler, tokenizer, text_encoder, vae, _unet = load_primary_models(model)
        del _unet  # This is a no op
        unet = UNet3DConditionModel.from_pretrained(model, subfolder='unet')
        # unet.disable_gradient_checkpointing()

    pipeline = TextToVideoSDPipeline.from_pretrained(
        pretrained_model_name_or_path=model,
        scheduler=scheduler,
        tokenizer=tokenizer,
        text_encoder=text_encoder.to(device=device, dtype=torch.half),
        vae=vae.to(device=device, dtype=torch.half),
        unet=unet.to(device=device, dtype=torch.half),
    )
    pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
    unet._set_gradient_checkpointing(value=False)
    handle_memory_attention(xformers, sdp, unet)
    vae.enable_slicing()
    return pipeline


def vid2vid(
    pipeline, init_video, init_weight, prompt, negative_prompt, height, width, num_inference_steps, generator, guidance_scale
):
    num_frames = init_video.shape[2]
    init_video = rearrange(init_video, "b c f h w -> (b f) c h w")
    pipeline.generator = generator
    latents = pipeline.vae.encode(init_video).latent_dist.sample()
    latents = rearrange(latents, "(b f) c h w -> b c f h w", f=num_frames)
    latents = pipeline.scheduler.add_noise(
        original_samples=latents * 0.18215,
        noise=torch.randn_like(latents),
        timesteps=(torch.ones(latents.shape[0]) * pipeline.scheduler.num_train_timesteps * (1 - init_weight)).long(),
    )
    if latents.shape[0] != len(prompt):
        latents = latents.repeat(len(prompt), 1, 1, 1, 1)

    do_classifier_free_guidance = guidance_scale > 1.0

    prompt_embeds = pipeline._encode_prompt(
        prompt=prompt,
        negative_prompt=negative_prompt,
        device=latents.device,
        num_images_per_prompt=1,
        do_classifier_free_guidance=do_classifier_free_guidance,
    )

    pipeline.scheduler.set_timesteps(num_inference_steps, device=latents.device)
    timesteps = pipeline.scheduler.timesteps
    timesteps = timesteps[round(init_weight * len(timesteps)) :]

    with pipeline.progress_bar(total=len(timesteps)) as progress_bar:
        for t in timesteps:
            # expand the latents if we are doing classifier free guidance
            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
            latent_model_input = pipeline.scheduler.scale_model_input(latent_model_input, t)

            # predict the noise residual
            noise_pred = pipeline.unet(latent_model_input, t, encoder_hidden_states=prompt_embeds).sample

            # perform guidance
            if do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

            # reshape latents
            bsz, channel, frames, width, height = latents.shape
            latents = latents.permute(0, 2, 1, 3, 4).reshape(bsz * frames, channel, width, height)
            noise_pred = noise_pred.permute(0, 2, 1, 3, 4).reshape(bsz * frames, channel, width, height)

            # compute the previous noisy sample x_t -> x_t-1
            latents = pipeline.scheduler.step(noise_pred, t, latents).prev_sample

            # reshape latents back
            latents = latents[None, :].reshape(bsz, frames, channel, width, height).permute(0, 2, 1, 3, 4)

            progress_bar.update()

    video_tensor = pipeline.decode_latents(latents)

    return video_tensor


@torch.inference_mode()
def inference(
    model,
    prompt,
    negative_prompt=None,
    batch_size=1,
    num_frames=16,
    width=256,
    height=256,
    num_steps=50,
    guidance_scale=9,
    init_video=None,
    init_weight=0.5,
    device="cuda",
    xformers=False,
    sdp=False,
    lora_path='',
    lora_rank=64,
    seed=0,
):
    with torch.autocast(device, dtype=torch.half):
        pipeline = initialize_pipeline(model, device, xformers, sdp)
        inject_inferable_lora(pipeline, lora_path, r=lora_rank)
        prompt = [prompt] * batch_size
        negative_prompt = ([negative_prompt] * batch_size) if negative_prompt is not None else None

        if init_video is not None:
            g_cuda = torch.Generator(device='cuda')
            g_cuda.manual_seed(seed)
            g_cpu = torch.Generator()
            g_cpu.manual_seed(seed)
            videos = vid2vid(
                pipeline=pipeline,
                init_video=init_video.to(device=device, dtype=torch.half),
                init_weight=init_weight,
                prompt=prompt,
                negative_prompt=negative_prompt,
                height=height,
                width=width,
                num_inference_steps=num_steps,
                generator=g_cuda,
                guidance_scale=guidance_scale,
            )

        else:
            g_cuda = torch.Generator(device='cuda')
            g_cuda.manual_seed(seed)
            g_cpu = torch.Generator()
            g_cpu.manual_seed(seed)

            videos = pipeline(
                prompt=prompt,
                negative_prompt=negative_prompt,
                num_frames=num_frames,
                height=height,
                width=width,
                num_inference_steps=num_steps,
                generator=g_cuda,
                guidance_scale=guidance_scale,
                output_type="pt",
            ).frames

        return videos

def export_to_video(video_frames, output_video_path, fps):
    writer = imageio.get_writer(output_video_path, format="FFMPEG", fps=fps)
    for frame in video_frames:
        writer.append_data(frame)
    writer.close()


def run(**args):
    decord.bridge.set_bridge("torch")

    output_dir = args.pop("output_dir")
    fps = args.pop("fps")
    remove_watermark = args.pop("remove_watermark")

    init_video = args.get("init_video", None)
    if init_video is not None:
        vr = decord.VideoReader(init_video)
        init = rearrange(vr[:], "f h w c -> c f h w").div(127.5).sub(1).unsqueeze(0)
        init = interpolate(init, size=(args['num_frames'], args['height'], args['width']), mode="trilinear")
        args["init_video"] = init

    videos = inference(**args)

    os.makedirs(output_dir, exist_ok=True)

    for idx, video in enumerate(videos):
        if remove_watermark:
            video = rearrange(video, "c f h w -> f c h w").add(1).div(2)
            video = inpaint_watermark(video)
            video = rearrange(video, "f c h w -> f h w c").clamp(0, 1).mul(255)
        else:
            video = rearrange(video, "c f h w -> f h w c").clamp(-1, 1).add(1).mul(127.5)

        video = video.byte().cpu().numpy()

        filename = os.path.join(output_dir, f"output-{idx}.mp4")
        export_to_video(video, filename, fps)
        yield filename


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("-m", "--model", type=str, required=True)
    parser.add_argument("-p", "--prompt", type=str, required=True)
    parser.add_argument("-n", "--negative_prompt", type=str, default=None)
    parser.add_argument("-o", "--output_dir", type=str, default="./output")
    parser.add_argument("-B", "--batch_size", type=int, default=1)
    parser.add_argument("-T", "--num_frames", type=int, default=16)
    parser.add_argument("-W", "--width", type=int, default=256)
    parser.add_argument("-H", "--height", type=int, default=256)
    parser.add_argument("-s", "--num_steps", type=int, default=50)
    parser.add_argument("-g", "--guidance-scale", type=float, default=9)
    parser.add_argument("-i", "--init-video", type=str, default=None)
    parser.add_argument("-iw", "--init-weight", type=float, default=0.5)
    parser.add_argument("-f", "--fps", type=int, default=8)
    parser.add_argument("-d", "--device", type=str, default="cuda")
    parser.add_argument("-x", "--xformers", action="store_true")
    parser.add_argument("-S", "--sdp", action="store_true")
    parser.add_argument("-lP", "--lora_path", type=str, default="")
    parser.add_argument("-lR", "--lora_rank", type=int, default=64)
    parser.add_argument("-rw", "--remove-watermark", action="store_true")
    parser.add_argument("-seed", "--seed", type=int, default=0)
    args = vars(parser.parse_args())

    for filename in run(**args):
        print(filename)
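A programmatic sketch equivalent to the CLI entry point above; it assumes the surrounding Text-To-Video-Finetuning package layout (utils/, models/, train.py) is importable, and the weights path is a placeholder. The keyword names mirror the argparse destinations defined in the script.

from inference import run

args = dict(
    model="./model-cache/576w",   # placeholder path to Diffusers-format weights
    prompt="a corgi surfing a wave",
    negative_prompt=None,
    batch_size=1,
    num_frames=16,
    width=576,
    height=320,
    num_steps=30,
    guidance_scale=9,
    init_video=None,
    init_weight=0.5,
    fps=8,
    device="cuda",
    xformers=False,
    sdp=True,
    lora_path="",
    lora_rank=64,
    remove_watermark=False,
    seed=0,
    output_dir="./output",
)
for filename in run(**args):
    print(filename)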
lama.py ADDED
@@ -0,0 +1,350 @@
"""
Based on the implementation from:
https://huggingface.co/spaces/fffiloni/lama-video-watermark-remover/tree/main

Modules were adapted by Hans Brouwer to only support the final configuration of the model uploaded here:
https://huggingface.co/akhaliq/lama

Apache License 2.0: https://github.com/advimman/lama/blob/main/LICENSE

@article{suvorov2021resolution,
    title={Resolution-robust Large Mask Inpainting with Fourier Convolutions},
    author={Suvorov, Roman and Logacheva, Elizaveta and Mashikhin, Anton and Remizova, Anastasia and Ashukha, Arsenii and Silvestrov, Aleksei and Kong, Naejin and Goka, Harshith and Park, Kiwoong and Lempitsky, Victor},
    journal={arXiv preprint arXiv:2109.07161},
    year={2021}
}
"""

import os
import sys
from urllib.request import urlretrieve

import torch
from einops import rearrange
from PIL import Image
from torch import nn
from torch.nn import functional as F
from torchvision.transforms.functional import to_tensor
from tqdm import tqdm

from train import export_to_video


LAMA_URL = "https://huggingface.co/akhaliq/lama/resolve/main/best.ckpt"
LAMA_PATH = "models/lama.ckpt"


def download_progress(t):
    last_b = [0]

    def update_to(b=1, bsize=1, tsize=None):
        if tsize is not None:
            t.total = tsize
        t.update((b - last_b[0]) * bsize)
        last_b[0] = b

    return update_to


def download(url, path):
    with tqdm(unit="B", unit_scale=True, unit_divisor=1024, miniters=1, desc=path) as t:
        urlretrieve(url, filename=path, reporthook=download_progress(t), data=None)


class FourierUnit(nn.Module):
    def __init__(self, in_channels, out_channels, groups=1):
        super(FourierUnit, self).__init__()
        self.groups = groups
        self.conv_layer = torch.nn.Conv2d(
            in_channels=in_channels * 2,
            out_channels=out_channels * 2,
            kernel_size=1,
            stride=1,
            padding=0,
            groups=self.groups,
            bias=False,
        )
        self.bn = torch.nn.BatchNorm2d(out_channels * 2)
        self.relu = torch.nn.ReLU(inplace=True)

    def forward(self, x):
        batch = x.shape[0]

        # (batch, c, h, w/2+1, 2)
        fft_dim = (-2, -1)
        ffted = torch.fft.rfftn(x, dim=fft_dim, norm="ortho")
        ffted = torch.stack((ffted.real, ffted.imag), dim=-1)
        ffted = ffted.permute(0, 1, 4, 2, 3).contiguous()  # (batch, c, 2, h, w/2+1)
        ffted = ffted.view((batch, -1) + ffted.size()[3:])

        ffted = self.conv_layer(ffted)  # (batch, c*2, h, w/2+1)
        ffted = self.relu(self.bn(ffted))

        # (batch, c, t, h, w/2+1, 2)
        ffted = ffted.view((batch, -1, 2) + ffted.size()[2:]).permute(0, 1, 3, 4, 2).contiguous()
        ffted = torch.complex(ffted[..., 0], ffted[..., 1])

        ifft_shape_slice = x.shape[-2:]
        output = torch.fft.irfftn(ffted, s=ifft_shape_slice, dim=fft_dim, norm="ortho")

        return output


class SpectralTransform(nn.Module):
    def __init__(self, in_channels, out_channels, stride=1, groups=1):
        super(SpectralTransform, self).__init__()
        self.stride = stride
        if stride == 2:
            self.downsample = nn.AvgPool2d(kernel_size=(2, 2), stride=2)
        else:
            self.downsample = nn.Identity()

        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels, out_channels // 2, kernel_size=1, groups=groups, bias=False),
            nn.BatchNorm2d(out_channels // 2),
            nn.ReLU(inplace=True),
        )
        self.fu = FourierUnit(out_channels // 2, out_channels // 2, groups)
        self.conv2 = torch.nn.Conv2d(out_channels // 2, out_channels, kernel_size=1, groups=groups, bias=False)

    def forward(self, x):
        x = self.downsample(x)
        x = self.conv1(x)
        output = self.fu(x)
        output = self.conv2(x + output)
        return output


class FFC(nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        ratio_gin,
        ratio_gout,
        stride=1,
        padding=0,
        dilation=1,
        groups=1,
        bias=False,
        padding_type="reflect",
        gated=False,
    ):
        super(FFC, self).__init__()

        assert stride == 1 or stride == 2, "Stride should be 1 or 2."
        self.stride = stride

        in_cg = int(in_channels * ratio_gin)
        in_cl = in_channels - in_cg
        out_cg = int(out_channels * ratio_gout)
        out_cl = out_channels - out_cg

        self.ratio_gin = ratio_gin
        self.ratio_gout = ratio_gout
        self.global_in_num = in_cg

        module = nn.Identity if in_cl == 0 or out_cl == 0 else nn.Conv2d
        self.convl2l = module(
            in_cl, out_cl, kernel_size, stride, padding, dilation, groups, bias, padding_mode=padding_type
        )
        module = nn.Identity if in_cl == 0 or out_cg == 0 else nn.Conv2d
        self.convl2g = module(
            in_cl, out_cg, kernel_size, stride, padding, dilation, groups, bias, padding_mode=padding_type
        )
        module = nn.Identity if in_cg == 0 or out_cl == 0 else nn.Conv2d
        self.convg2l = module(
            in_cg, out_cl, kernel_size, stride, padding, dilation, groups, bias, padding_mode=padding_type
        )
        module = nn.Identity if in_cg == 0 or out_cg == 0 else SpectralTransform
        self.convg2g = module(in_cg, out_cg, stride, 1 if groups == 1 else groups // 2)

        self.gated = gated
        module = nn.Identity if in_cg == 0 or out_cl == 0 or not self.gated else nn.Conv2d
        self.gate = module(in_channels, 2, 1)

    def forward(self, x):
        x_l, x_g = x if type(x) is tuple else (x, 0)
        out_xl, out_xg = 0, 0

        if self.gated:
            total_input_parts = [x_l]
            if torch.is_tensor(x_g):
                total_input_parts.append(x_g)
            total_input = torch.cat(total_input_parts, dim=1)

            gates = torch.sigmoid(self.gate(total_input))
            g2l_gate, l2g_gate = gates.chunk(2, dim=1)
        else:
            g2l_gate, l2g_gate = 1, 1

        if self.ratio_gout != 1:
            out_xl = self.convl2l(x_l) + self.convg2l(x_g) * g2l_gate
        if self.ratio_gout != 0:
            out_xg = self.convl2g(x_l) * l2g_gate + self.convg2g(x_g)

        return out_xl, out_xg


class FFC_BN_ACT(nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        ratio_gin=0,
        ratio_gout=0,
        stride=1,
        padding=0,
        dilation=1,
        groups=1,
        bias=False,
        norm_layer=nn.BatchNorm2d,
        activation_layer=nn.ReLU,
    ):
        super(FFC_BN_ACT, self).__init__()
        self.ffc = FFC(
            in_channels, out_channels, kernel_size, ratio_gin, ratio_gout, stride, padding, dilation, groups, bias
        )
        lnorm = nn.Identity if ratio_gout == 1 else norm_layer
        gnorm = nn.Identity if ratio_gout == 0 else norm_layer
        global_channels = int(out_channels * ratio_gout)
        self.bn_l = lnorm(out_channels - global_channels)
        self.bn_g = gnorm(global_channels)

        lact = nn.Identity if ratio_gout == 1 else activation_layer
        gact = nn.Identity if ratio_gout == 0 else activation_layer
        self.act_l = lact(inplace=True)
        self.act_g = gact(inplace=True)

    def forward(self, x):
        x_l, x_g = self.ffc(x)
        x_l = self.act_l(self.bn_l(x_l))
        x_g = self.act_g(self.bn_g(x_g))
        return x_l, x_g


class FFCResnetBlock(nn.Module):
    def __init__(self, dim, ratio_gin, ratio_gout):
        super().__init__()
        self.conv1 = FFC_BN_ACT(
            dim, dim, kernel_size=3, padding=1, dilation=1, ratio_gin=ratio_gin, ratio_gout=ratio_gout
        )
        self.conv2 = FFC_BN_ACT(
            dim, dim, kernel_size=3, padding=1, dilation=1, ratio_gin=ratio_gin, ratio_gout=ratio_gout
        )

    def forward(self, x):
        x_l, x_g = x if type(x) is tuple else (x, 0)
        id_l, id_g = x_l, x_g
        x_l, x_g = self.conv1((x_l, x_g))
        x_l, x_g = self.conv2((x_l, x_g))
        x_l, x_g = id_l + x_l, id_g + x_g
        out = x_l, x_g
        return out


class ConcatTupleLayer(nn.Module):
    def forward(self, x):
        assert isinstance(x, tuple)
        x_l, x_g = x
        assert torch.is_tensor(x_l) or torch.is_tensor(x_g)
        if not torch.is_tensor(x_g):
            return x_l
        return torch.cat(x, dim=1)


class LargeMaskInpainting(nn.Module):
    def __init__(self, input_nc=4, output_nc=3, ngf=64, n_downsampling=3, n_blocks=18, max_features=1024):
        super().__init__()

        model = [nn.ReflectionPad2d(3), FFC_BN_ACT(input_nc, ngf, kernel_size=7)]

        ### downsample
        for i in range(n_downsampling):
            mult = 2**i
            model += [
                FFC_BN_ACT(
                    min(max_features, ngf * mult),
                    min(max_features, ngf * mult * 2),
                    kernel_size=3,
                    stride=2,
                    padding=1,
                    ratio_gout=0.75 if i == n_downsampling - 1 else 0,
                )
            ]

        ### resnet blocks
        for i in range(n_blocks):
            cur_resblock = FFCResnetBlock(min(max_features, ngf * 2**n_downsampling), ratio_gin=0.75, ratio_gout=0.75)
            model += [cur_resblock]

        model += [ConcatTupleLayer()]

        ### upsample
        for i in range(n_downsampling):
            mult = 2 ** (n_downsampling - i)
            model += [
                nn.ConvTranspose2d(
                    min(max_features, ngf * mult),
                    min(max_features, int(ngf * mult / 2)),
                    kernel_size=3,
                    stride=2,
                    padding=1,
                    output_padding=1,
                ),
                nn.BatchNorm2d(min(max_features, int(ngf * mult / 2))),
                nn.ReLU(True),
            ]

        model += [nn.ReflectionPad2d(3), nn.Conv2d(ngf, output_nc, kernel_size=7), nn.Sigmoid()]
        self.model = nn.Sequential(*model)

    def forward(self, img, mask):
        masked_img = img * (1 - mask)
        masked_img = torch.cat([masked_img, mask], dim=1)
        pred = self.model(masked_img)
        inpainted = mask * pred + (1 - mask) * img
        return inpainted


@torch.inference_mode()
def inpaint_watermark(imgs):
    if not os.path.exists(LAMA_PATH):
        download(LAMA_URL, LAMA_PATH)

    mask = to_tensor(Image.open("./utils/mask.png").convert("L")).unsqueeze(0).to(imgs.device)
    if mask.shape[-1] != imgs.shape[-1]:
        mask = F.interpolate(mask, size=(imgs.shape[2], imgs.shape[3]), mode="nearest")
    mask = mask.expand(imgs.shape[0], 1, mask.shape[2], mask.shape[3])

    model = LargeMaskInpainting().to(imgs.device)
    state_dict = torch.load(LAMA_PATH, map_location=imgs.device)["state_dict"]
    g_dict = {k.replace("generator.", ""): v for k, v in state_dict.items() if k.startswith("generator")}
    model.load_state_dict(g_dict)

    inpainted = model.forward(imgs, mask)

    return inpainted


if __name__ == "__main__":
    import decord

    decord.bridge.set_bridge("torch")

    if len(sys.argv) < 2:
        print("Usage: python -m utils.lama <path/to/video>")
        sys.exit(1)

    video_path = sys.argv[1]
    out_path = video_path.replace(".mp4", " inpainted.mp4")

    vr = decord.VideoReader(video_path)
    fps = vr.get_avg_fps()
    video = rearrange(vr[:], "f h w c -> f c h w").div(255)

    inpainted = inpaint_watermark(video)
    inpainted = rearrange(inpainted, "f c h w -> f h w c").clamp(0, 1).mul(255).byte().cpu().numpy()
    export_to_video(inpainted, out_path, fps)
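A sketch of calling the inpainting helper directly on a stack of frames scaled to [0, 1]; it assumes the watermark mask ships at ./utils/mask.png and that the LaMa checkpoint can be fetched to models/lama.ckpt on first use.

import torch

frames = torch.rand(16, 3, 320, 576)  # (frames, channels, height, width), values in [0, 1]
clean = inpaint_watermark(frames)     # same shape, masked region replaced by the LaMa prediction
print(clean.shape)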
lora.py ADDED
@@ -0,0 +1,1312 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import math
3
+ from itertools import groupby
4
+ import os
5
+ from typing import Callable, Dict, List, Optional, Set, Tuple, Type, Union
6
+
7
+ import numpy as np
8
+ import PIL
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+
13
+ try:
14
+ from safetensors.torch import safe_open
15
+ from safetensors.torch import save_file as safe_save
16
+
17
+ safetensors_available = True
18
+ except ImportError:
19
+ from .safe_open import safe_open
20
+
21
+ def safe_save(
22
+ tensors: Dict[str, torch.Tensor],
23
+ filename: str,
24
+ metadata: Optional[Dict[str, str]] = None,
25
+ ) -> None:
26
+ raise EnvironmentError(
27
+ "Saving safetensors requires the safetensors library. Please install with pip or similar."
28
+ )
29
+
30
+ safetensors_available = False
31
+
32
+
33
+ class LoraInjectedLinear(nn.Module):
34
+ def __init__(
35
+ self, in_features, out_features, bias=False, r=4, dropout_p=0.1, scale=1.0
36
+ ):
37
+ super().__init__()
38
+
39
+ if r > min(in_features, out_features):
40
+ #raise ValueError(
41
+ # f"LoRA rank {r} must be less or equal than {min(in_features, out_features)}"
42
+ #)
43
+ print(f"LoRA rank {r} is too large. setting to: {min(in_features, out_features)}")
44
+ r = min(in_features, out_features)
45
+
46
+ self.r = r
47
+ self.linear = nn.Linear(in_features, out_features, bias)
48
+ self.lora_down = nn.Linear(in_features, r, bias=False)
49
+ self.dropout = nn.Dropout(dropout_p)
50
+ self.lora_up = nn.Linear(r, out_features, bias=False)
51
+ self.scale = scale
52
+ self.selector = nn.Identity()
53
+
54
+ nn.init.normal_(self.lora_down.weight, std=1 / r)
55
+ nn.init.zeros_(self.lora_up.weight)
56
+
57
+ def forward(self, input):
58
+ return (
59
+ self.linear(input)
60
+ + self.dropout(self.lora_up(self.selector(self.lora_down(input))))
61
+ * self.scale
62
+ )
63
+
64
+ def realize_as_lora(self):
65
+ return self.lora_up.weight.data * self.scale, self.lora_down.weight.data
66
+
67
+ def set_selector_from_diag(self, diag: torch.Tensor):
68
+ # diag is a 1D tensor of size (r,)
69
+ assert diag.shape == (self.r,)
70
+ self.selector = nn.Linear(self.r, self.r, bias=False)
71
+ self.selector.weight.data = torch.diag(diag)
72
+ self.selector.weight.data = self.selector.weight.data.to(
73
+ self.lora_up.weight.device
74
+ ).to(self.lora_up.weight.dtype)
75
+
76
+
77
+ class LoraInjectedConv2d(nn.Module):
78
+ def __init__(
79
+ self,
80
+ in_channels: int,
81
+ out_channels: int,
82
+ kernel_size,
83
+ stride=1,
84
+ padding=0,
85
+ dilation=1,
86
+ groups: int = 1,
87
+ bias: bool = True,
88
+ r: int = 4,
89
+ dropout_p: float = 0.1,
90
+ scale: float = 1.0,
91
+ ):
92
+ super().__init__()
93
+ if r > min(in_channels, out_channels):
94
+ print(f"LoRA rank {r} is too large. setting to: {min(in_channels, out_channels)}")
95
+ r = min(in_channels, out_channels)
96
+
97
+ self.r = r
98
+ self.conv = nn.Conv2d(
99
+ in_channels=in_channels,
100
+ out_channels=out_channels,
101
+ kernel_size=kernel_size,
102
+ stride=stride,
103
+ padding=padding,
104
+ dilation=dilation,
105
+ groups=groups,
106
+ bias=bias,
107
+ )
108
+
109
+ self.lora_down = nn.Conv2d(
110
+ in_channels=in_channels,
111
+ out_channels=r,
112
+ kernel_size=kernel_size,
113
+ stride=stride,
114
+ padding=padding,
115
+ dilation=dilation,
116
+ groups=groups,
117
+ bias=False,
118
+ )
119
+ self.dropout = nn.Dropout(dropout_p)
120
+ self.lora_up = nn.Conv2d(
121
+ in_channels=r,
122
+ out_channels=out_channels,
123
+ kernel_size=1,
124
+ stride=1,
125
+ padding=0,
126
+ bias=False,
127
+ )
128
+ self.selector = nn.Identity()
129
+ self.scale = scale
130
+
131
+ nn.init.normal_(self.lora_down.weight, std=1 / r)
132
+ nn.init.zeros_(self.lora_up.weight)
133
+
134
+ def forward(self, input):
135
+ return (
136
+ self.conv(input)
137
+ + self.dropout(self.lora_up(self.selector(self.lora_down(input))))
138
+ * self.scale
139
+ )
140
+
141
+ def realize_as_lora(self):
142
+ return self.lora_up.weight.data * self.scale, self.lora_down.weight.data
143
+
144
+ def set_selector_from_diag(self, diag: torch.Tensor):
145
+ # diag is a 1D tensor of size (r,)
146
+ assert diag.shape == (self.r,)
147
+ self.selector = nn.Conv2d(
148
+ in_channels=self.r,
149
+ out_channels=self.r,
150
+ kernel_size=1,
151
+ stride=1,
152
+ padding=0,
153
+ bias=False,
154
+ )
155
+ self.selector.weight.data = torch.diag(diag)
156
+
157
+ # same device + dtype as lora_up
158
+ self.selector.weight.data = self.selector.weight.data.to(
159
+ self.lora_up.weight.device
160
+ ).to(self.lora_up.weight.dtype)
161
+
162
+ class LoraInjectedConv3d(nn.Module):
163
+ def __init__(
164
+ self,
165
+ in_channels: int,
166
+ out_channels: int,
167
+ kernel_size: (3, 1, 1),
168
+ padding: (1, 0, 0),
169
+ bias: bool = False,
170
+ r: int = 4,
171
+ dropout_p: float = 0,
172
+ scale: float = 1.0,
173
+ ):
174
+ super().__init__()
175
+ if r > min(in_channels, out_channels):
176
+ print(f"LoRA rank {r} is too large. setting to: {min(in_channels, out_channels)}")
177
+ r = min(in_channels, out_channels)
178
+
179
+ self.r = r
180
+ self.kernel_size = kernel_size
181
+ self.padding = padding
182
+ self.conv = nn.Conv3d(
183
+ in_channels=in_channels,
184
+ out_channels=out_channels,
185
+ kernel_size=kernel_size,
186
+ padding=padding,
187
+ )
188
+
189
+ self.lora_down = nn.Conv3d(
190
+ in_channels=in_channels,
191
+ out_channels=r,
192
+ kernel_size=kernel_size,
193
+ bias=False,
194
+ padding=padding
195
+ )
196
+ self.dropout = nn.Dropout(dropout_p)
197
+ self.lora_up = nn.Conv3d(
198
+ in_channels=r,
199
+ out_channels=out_channels,
200
+ kernel_size=1,
201
+ stride=1,
202
+ padding=0,
203
+ bias=False,
204
+ )
205
+ self.selector = nn.Identity()
206
+ self.scale = scale
207
+
208
+ nn.init.normal_(self.lora_down.weight, std=1 / r)
209
+ nn.init.zeros_(self.lora_up.weight)
210
+
211
+ def forward(self, input):
212
+ return (
213
+ self.conv(input)
214
+ + self.dropout(self.lora_up(self.selector(self.lora_down(input))))
215
+ * self.scale
216
+ )
217
+
218
+ def realize_as_lora(self):
219
+ return self.lora_up.weight.data * self.scale, self.lora_down.weight.data
220
+
221
+ def set_selector_from_diag(self, diag: torch.Tensor):
222
+ # diag is a 1D tensor of size (r,)
223
+ assert diag.shape == (self.r,)
224
+ self.selector = nn.Conv3d(
225
+ in_channels=self.r,
226
+ out_channels=self.r,
227
+ kernel_size=1,
228
+ stride=1,
229
+ padding=0,
230
+ bias=False,
231
+ )
232
+ self.selector.weight.data = torch.diag(diag)
233
+
234
+ # same device + dtype as lora_up
235
+ self.selector.weight.data = self.selector.weight.data.to(
236
+ self.lora_up.weight.device
237
+ ).to(self.lora_up.weight.dtype)
238
+
239
+ UNET_DEFAULT_TARGET_REPLACE = {"CrossAttention", "Attention", "GEGLU"}
240
+
241
+ UNET_EXTENDED_TARGET_REPLACE = {"ResnetBlock2D", "CrossAttention", "Attention", "GEGLU"}
242
+
243
+ TEXT_ENCODER_DEFAULT_TARGET_REPLACE = {"CLIPAttention"}
244
+
245
+ TEXT_ENCODER_EXTENDED_TARGET_REPLACE = {"CLIPAttention"}
246
+
247
+ DEFAULT_TARGET_REPLACE = UNET_DEFAULT_TARGET_REPLACE
248
+
249
+ EMBED_FLAG = "<embed>"
250
+
251
+
252
+ def _find_children(
253
+ model,
254
+ search_class: List[Type[nn.Module]] = [nn.Linear],
255
+ ):
256
+ """
257
+ Find all modules of a certain class (or union of classes).
258
+
259
+ Returns all matching modules, along with the parent of those moduless and the
260
+ names they are referenced by.
261
+ """
262
+ # For each target find every linear_class module that isn't a child of a LoraInjectedLinear
263
+ for parent in model.modules():
264
+ for name, module in parent.named_children():
265
+ if any([isinstance(module, _class) for _class in search_class]):
266
+ yield parent, name, module
267
+
268
+
269
+ def _find_modules_v2(
270
+ model,
271
+ ancestor_class: Optional[Set[str]] = None,
272
+ search_class: List[Type[nn.Module]] = [nn.Linear],
273
+ exclude_children_of: Optional[List[Type[nn.Module]]] = [
274
+ LoraInjectedLinear,
275
+ LoraInjectedConv2d,
276
+ LoraInjectedConv3d
277
+ ],
278
+ ):
279
+ """
280
+ Find all modules of a certain class (or union of classes) that are direct or
281
+ indirect descendants of other modules of a certain class (or union of classes).
282
+
283
+ Returns all matching modules, along with the parent of those moduless and the
284
+ names they are referenced by.
285
+ """
286
+
287
+ # Get the targets we should replace all linears under
288
+ if ancestor_class is not None:
289
+ ancestors = (
290
+ module
291
+ for module in model.modules()
292
+ if module.__class__.__name__ in ancestor_class
293
+ )
294
+ else:
295
+ # this is in case you want to naively iterate over all modules.
296
+ ancestors = [module for module in model.modules()]
297
+
298
+ # For each target find every linear_class module that isn't a child of a LoraInjectedLinear
299
+ for ancestor in ancestors:
300
+ for fullname, module in ancestor.named_modules():
301
+ if any([isinstance(module, _class) for _class in search_class]):
302
+ # Find the direct parent if this is a descendant, not a child, of target
303
+ *path, name = fullname.split(".")
304
+ parent = ancestor
305
+ while path:
306
+ parent = parent.get_submodule(path.pop(0))
307
+ # Skip this linear if it's a child of a LoraInjectedLinear
308
+ if exclude_children_of and any(
309
+ [isinstance(parent, _class) for _class in exclude_children_of]
310
+ ):
311
+ continue
312
+ # Otherwise, yield it
313
+ yield parent, name, module
314
+
315
+
316
+ def _find_modules_old(
317
+ model,
318
+ ancestor_class: Set[str] = DEFAULT_TARGET_REPLACE,
319
+ search_class: List[Type[nn.Module]] = [nn.Linear],
320
+ exclude_children_of: Optional[List[Type[nn.Module]]] = [LoraInjectedLinear],
321
+ ):
322
+ ret = []
323
+ for _module in model.modules():
324
+ if _module.__class__.__name__ in ancestor_class:
325
+
326
+ for name, _child_module in _module.named_modules():
327
+ if _child_module.__class__ in search_class:
328
+ ret.append((_module, name, _child_module))
329
+ print(ret)
330
+ return ret
331
+
332
+
333
+ _find_modules = _find_modules_v2
334
+
335
+
336
+ def inject_trainable_lora(
337
+ model: nn.Module,
338
+ target_replace_module: Set[str] = DEFAULT_TARGET_REPLACE,
339
+ r: int = 4,
340
+ loras=None, # path to lora .pt
341
+ verbose: bool = False,
342
+ dropout_p: float = 0.0,
343
+ scale: float = 1.0,
344
+ ):
345
+ """
346
+ inject lora into model, and returns lora parameter groups.
347
+ """
348
+
349
+ require_grad_params = []
350
+ names = []
351
+
352
+ if loras != None:
353
+ loras = torch.load(loras)
354
+
355
+ for _module, name, _child_module in _find_modules(
356
+ model, target_replace_module, search_class=[nn.Linear]
357
+ ):
358
+ weight = _child_module.weight
359
+ bias = _child_module.bias
360
+ if verbose:
361
+ print("LoRA Injection : injecting lora into ", name)
362
+ print("LoRA Injection : weight shape", weight.shape)
363
+ _tmp = LoraInjectedLinear(
364
+ _child_module.in_features,
365
+ _child_module.out_features,
366
+ _child_module.bias is not None,
367
+ r=r,
368
+ dropout_p=dropout_p,
369
+ scale=scale,
370
+ )
371
+ _tmp.linear.weight = weight
372
+ if bias is not None:
373
+ _tmp.linear.bias = bias
374
+
375
+ # switch the module
376
+ _tmp.to(_child_module.weight.device).to(_child_module.weight.dtype)
377
+ _module._modules[name] = _tmp
378
+
379
+ require_grad_params.append(_module._modules[name].lora_up.parameters())
380
+ require_grad_params.append(_module._modules[name].lora_down.parameters())
381
+
382
+ if loras != None:
383
+ _module._modules[name].lora_up.weight = loras.pop(0)
384
+ _module._modules[name].lora_down.weight = loras.pop(0)
385
+
386
+ _module._modules[name].lora_up.weight.requires_grad = True
387
+ _module._modules[name].lora_down.weight.requires_grad = True
388
+ names.append(name)
389
+
390
+ return require_grad_params, names
391
+
392
+
393
+ def inject_trainable_lora_extended(
394
+ model: nn.Module,
395
+ target_replace_module: Set[str] = UNET_EXTENDED_TARGET_REPLACE,
396
+ r: int = 4,
397
+ loras=None, # path to lora .pt
398
+ ):
399
+ """
400
+ inject lora into model, and returns lora parameter groups.
401
+ """
402
+
403
+ require_grad_params = []
404
+ names = []
405
+
406
+ if loras != None:
407
+ loras = torch.load(loras)
408
+
409
+ for _module, name, _child_module in _find_modules(
410
+ model, target_replace_module, search_class=[nn.Linear, nn.Conv2d, nn.Conv3d]
411
+ ):
412
+ if _child_module.__class__ == nn.Linear:
413
+ weight = _child_module.weight
414
+ bias = _child_module.bias
415
+ _tmp = LoraInjectedLinear(
416
+ _child_module.in_features,
417
+ _child_module.out_features,
418
+ _child_module.bias is not None,
419
+ r=r,
420
+ )
421
+ _tmp.linear.weight = weight
422
+ if bias is not None:
423
+ _tmp.linear.bias = bias
424
+ elif _child_module.__class__ == nn.Conv2d:
425
+ weight = _child_module.weight
426
+ bias = _child_module.bias
427
+ _tmp = LoraInjectedConv2d(
428
+ _child_module.in_channels,
429
+ _child_module.out_channels,
430
+ _child_module.kernel_size,
431
+ _child_module.stride,
432
+ _child_module.padding,
433
+ _child_module.dilation,
434
+ _child_module.groups,
435
+ _child_module.bias is not None,
436
+ r=r,
437
+ )
438
+
439
+ _tmp.conv.weight = weight
440
+ if bias is not None:
441
+ _tmp.conv.bias = bias
442
+
443
+ elif _child_module.__class__ == nn.Conv3d:
444
+ weight = _child_module.weight
445
+ bias = _child_module.bias
446
+ _tmp = LoraInjectedConv3d(
447
+ _child_module.in_channels,
448
+ _child_module.out_channels,
449
+ bias=_child_module.bias is not None,
450
+ kernel_size=_child_module.kernel_size,
451
+ padding=_child_module.padding,
452
+ r=r,
453
+ )
454
+
455
+ _tmp.conv.weight = weight
456
+ if bias is not None:
457
+ _tmp.conv.bias = bias
458
+ # switch the module
459
+ _tmp.to(_child_module.weight.device).to(_child_module.weight.dtype)
460
+ if bias is not None:
461
+ _tmp.to(_child_module.bias.device).to(_child_module.bias.dtype)
462
+
463
+ _module._modules[name] = _tmp
464
+ require_grad_params.append(_module._modules[name].lora_up.parameters())
465
+ require_grad_params.append(_module._modules[name].lora_down.parameters())
466
+
467
+ if loras != None:
468
+ _module._modules[name].lora_up.weight = loras.pop(0)
469
+ _module._modules[name].lora_down.weight = loras.pop(0)
470
+
471
+ _module._modules[name].lora_up.weight.requires_grad = True
472
+ _module._modules[name].lora_down.weight.requires_grad = True
473
+ names.append(name)
474
+
475
+ return require_grad_params, names
476
+
477
+
478
+ def inject_inferable_lora(
479
+ model,
480
+ lora_path='',
481
+ unet_replace_modules=["UNet3DConditionModel"],
482
+ text_encoder_replace_modules=["CLIPEncoderLayer"],
483
+ is_extended=False,
484
+ r=16
485
+ ):
486
+ from transformers.models.clip import CLIPTextModel
487
+ from diffusers import UNet3DConditionModel
488
+
489
+ def is_text_model(f): return 'text_encoder' in f and isinstance(model.text_encoder, CLIPTextModel)
490
+ def is_unet(f): return 'unet' in f and model.unet.__class__.__name__ == "UNet3DConditionModel"
491
+
492
+ if os.path.exists(lora_path):
493
+ try:
494
+ for f in os.listdir(lora_path):
495
+ if f.endswith('.pt'):
496
+ lora_file = os.path.join(lora_path, f)
497
+
498
+ if is_text_model(f):
499
+ monkeypatch_or_replace_lora(
500
+ model.text_encoder,
501
+ torch.load(lora_file),
502
+ target_replace_module=text_encoder_replace_modules,
503
+ r=r
504
+ )
505
+ print("Successfully loaded Text Encoder LoRa.")
506
+ continue
507
+
508
+ if is_unet(f):
509
+ monkeypatch_or_replace_lora_extended(
510
+ model.unet,
511
+ torch.load(lora_file),
512
+ target_replace_module=unet_replace_modules,
513
+ r=r
514
+ )
515
+ print("Successfully loaded UNET LoRa.")
516
+ continue
517
+
518
+ print("Found a .pt file, but doesn't have the correct name format. (unet.pt, text_encoder.pt)")
519
+
520
+ except Exception as e:
521
+ print(e)
522
+ print("Couldn't inject LoRA's due to an error.")
523
+
524
+ def extract_lora_ups_down(model, target_replace_module=DEFAULT_TARGET_REPLACE):
525
+
526
+ loras = []
527
+
528
+ for _m, _n, _child_module in _find_modules(
529
+ model,
530
+ target_replace_module,
531
+ search_class=[LoraInjectedLinear, LoraInjectedConv2d, LoraInjectedConv3d],
532
+ ):
533
+ loras.append((_child_module.lora_up, _child_module.lora_down))
534
+
535
+ if len(loras) == 0:
536
+ raise ValueError("No lora injected.")
537
+
538
+ return loras
539
+
540
+
541
+ def extract_lora_as_tensor(
542
+ model, target_replace_module=DEFAULT_TARGET_REPLACE, as_fp16=True
543
+ ):
544
+
545
+ loras = []
546
+
547
+ for _m, _n, _child_module in _find_modules(
548
+ model,
549
+ target_replace_module,
550
+ search_class=[LoraInjectedLinear, LoraInjectedConv2d, LoraInjectedConv3d],
551
+ ):
552
+ up, down = _child_module.realize_as_lora()
553
+ if as_fp16:
554
+ up = up.to(torch.float16)
555
+ down = down.to(torch.float16)
556
+
557
+ loras.append((up, down))
558
+
559
+ if len(loras) == 0:
560
+ raise ValueError("No lora injected.")
561
+
562
+ return loras
563
+
564
+
565
+ def save_lora_weight(
566
+ model,
567
+ path="./lora.pt",
568
+ target_replace_module=DEFAULT_TARGET_REPLACE,
569
+ ):
570
+ weights = []
571
+ for _up, _down in extract_lora_ups_down(
572
+ model, target_replace_module=target_replace_module
573
+ ):
574
+ weights.append(_up.weight.to("cpu").to(torch.float32))
575
+ weights.append(_down.weight.to("cpu").to(torch.float32))
576
+
577
+ torch.save(weights, path)
578
+
579
+
580
+ def save_lora_as_json(model, path="./lora.json"):
581
+ weights = []
582
+ for _up, _down in extract_lora_ups_down(model):
583
+ weights.append(_up.weight.detach().cpu().numpy().tolist())
584
+ weights.append(_down.weight.detach().cpu().numpy().tolist())
585
+
586
+ import json
587
+
588
+ with open(path, "w") as f:
589
+ json.dump(weights, f)
590
+
591
+
592
+ def save_safeloras_with_embeds(
593
+ modelmap: Dict[str, Tuple[nn.Module, Set[str]]] = {},
594
+ embeds: Dict[str, torch.Tensor] = {},
595
+ outpath="./lora.safetensors",
596
+ ):
597
+ """
598
+ Saves the Lora from multiple modules in a single safetensor file.
599
+
600
+ modelmap is a dictionary of {
601
+ "module name": (module, target_replace_module)
602
+ }
603
+ """
604
+ weights = {}
605
+ metadata = {}
606
+
607
+ for name, (model, target_replace_module) in modelmap.items():
608
+ metadata[name] = json.dumps(list(target_replace_module))
609
+
610
+ for i, (_up, _down) in enumerate(
611
+ extract_lora_as_tensor(model, target_replace_module)
612
+ ):
613
+ rank = _down.shape[0]
614
+
615
+ metadata[f"{name}:{i}:rank"] = str(rank)
616
+ weights[f"{name}:{i}:up"] = _up
617
+ weights[f"{name}:{i}:down"] = _down
618
+
619
+ for token, tensor in embeds.items():
620
+ metadata[token] = EMBED_FLAG
621
+ weights[token] = tensor
622
+
623
+ print(f"Saving weights to {outpath}")
624
+ safe_save(weights, outpath, metadata)
625
+
626
+
627
+ def save_safeloras(
628
+ modelmap: Dict[str, Tuple[nn.Module, Set[str]]] = {},
629
+ outpath="./lora.safetensors",
630
+ ):
631
+ return save_safeloras_with_embeds(modelmap=modelmap, outpath=outpath)
632
+
633
+
634
+ def convert_loras_to_safeloras_with_embeds(
635
+ modelmap: Dict[str, Tuple[str, Set[str], int]] = {},
636
+ embeds: Dict[str, torch.Tensor] = {},
637
+ outpath="./lora.safetensors",
638
+ ):
639
+ """
640
+ Converts the Lora from multiple pytorch .pt files into a single safetensor file.
641
+
642
+ modelmap is a dictionary of {
643
+ "module name": (pytorch_model_path, target_replace_module, rank)
644
+ }
645
+ """
646
+
647
+ weights = {}
648
+ metadata = {}
649
+
650
+ for name, (path, target_replace_module, r) in modelmap.items():
651
+ metadata[name] = json.dumps(list(target_replace_module))
652
+
653
+ lora = torch.load(path)
654
+ for i, weight in enumerate(lora):
655
+ is_up = i % 2 == 0
656
+ i = i // 2
657
+
658
+ if is_up:
659
+ metadata[f"{name}:{i}:rank"] = str(r)
660
+ weights[f"{name}:{i}:up"] = weight
661
+ else:
662
+ weights[f"{name}:{i}:down"] = weight
663
+
664
+ for token, tensor in embeds.items():
665
+ metadata[token] = EMBED_FLAG
666
+ weights[token] = tensor
667
+
668
+ print(f"Saving weights to {outpath}")
669
+ safe_save(weights, outpath, metadata)
670
+
671
+
672
+ def convert_loras_to_safeloras(
673
+ modelmap: Dict[str, Tuple[str, Set[str], int]] = {},
674
+ outpath="./lora.safetensors",
675
+ ):
676
+ convert_loras_to_safeloras_with_embeds(modelmap=modelmap, outpath=outpath)
677
+
678
+
679
+ def parse_safeloras(
680
+ safeloras,
681
+ ) -> Dict[str, Tuple[List[nn.parameter.Parameter], List[int], List[str]]]:
682
+ """
683
+ Converts a loaded safetensor file that contains a set of module Loras
684
+ into Parameters and other information
685
+
686
+ Output is a dictionary of {
687
+ "module name": (
688
+ [list of weights],
689
+ [list of ranks],
690
+ target_replacement_modules
691
+ )
692
+ }
693
+ """
694
+ loras = {}
695
+ metadata = safeloras.metadata()
696
+
697
+ get_name = lambda k: k.split(":")[0]
698
+
699
+ keys = list(safeloras.keys())
700
+ keys.sort(key=get_name)
701
+
702
+ for name, module_keys in groupby(keys, get_name):
703
+ info = metadata.get(name)
704
+
705
+ if not info:
706
+ raise ValueError(
707
+ f"Tensor {name} has no metadata - is this a Lora safetensor?"
708
+ )
709
+
710
+ # Skip Textual Inversion embeds
711
+ if info == EMBED_FLAG:
712
+ continue
713
+
714
+ # Handle Loras
715
+ # Extract the targets
716
+ target = json.loads(info)
717
+
718
+ # Build the result lists - Python needs us to preallocate lists to insert into them
719
+ module_keys = list(module_keys)
720
+ ranks = [4] * (len(module_keys) // 2)
721
+ weights = [None] * len(module_keys)
722
+
723
+ for key in module_keys:
724
+ # Split the model name and index out of the key
725
+ _, idx, direction = key.split(":")
726
+ idx = int(idx)
727
+
728
+ # Add the rank
729
+ ranks[idx] = int(metadata[f"{name}:{idx}:rank"])
730
+
731
+ # Insert the weight into the list
732
+ idx = idx * 2 + (1 if direction == "down" else 0)
733
+ weights[idx] = nn.parameter.Parameter(safeloras.get_tensor(key))
734
+
735
+ loras[name] = (weights, ranks, target)
736
+
737
+ return loras
738
+
739
+
740
+ def parse_safeloras_embeds(
741
+ safeloras,
742
+ ) -> Dict[str, torch.Tensor]:
743
+ """
744
+ Converts a loaded safetensor file that contains Textual Inversion embeds into
745
+ a dictionary of embed_token: Tensor
746
+ """
747
+ embeds = {}
748
+ metadata = safeloras.metadata()
749
+
750
+ for key in safeloras.keys():
751
+ # Only handle Textual Inversion embeds
752
+ meta = metadata.get(key)
753
+ if not meta or meta != EMBED_FLAG:
754
+ continue
755
+
756
+ embeds[key] = safeloras.get_tensor(key)
757
+
758
+ return embeds
759
+
760
+
761
+ def load_safeloras(path, device="cpu"):
762
+ safeloras = safe_open(path, framework="pt", device=device)
763
+ return parse_safeloras(safeloras)
764
+
765
+
766
+ def load_safeloras_embeds(path, device="cpu"):
767
+ safeloras = safe_open(path, framework="pt", device=device)
768
+ return parse_safeloras_embeds(safeloras)
769
+
770
+
771
+ def load_safeloras_both(path, device="cpu"):
772
+ safeloras = safe_open(path, framework="pt", device=device)
773
+ return parse_safeloras(safeloras), parse_safeloras_embeds(safeloras)
774
+
775
+
776
+ def collapse_lora(model, alpha=1.0):
777
+
778
+ for _module, name, _child_module in _find_modules(
779
+ model,
780
+ UNET_EXTENDED_TARGET_REPLACE | TEXT_ENCODER_EXTENDED_TARGET_REPLACE,
781
+ search_class=[LoraInjectedLinear, LoraInjectedConv2d, LoraInjectedConv3d],
782
+ ):
783
+
784
+ if isinstance(_child_module, LoraInjectedLinear):
785
+ print("Collapsing Lin Lora in", name)
786
+
787
+ _child_module.linear.weight = nn.Parameter(
788
+ _child_module.linear.weight.data
789
+ + alpha
790
+ * (
791
+ _child_module.lora_up.weight.data
792
+ @ _child_module.lora_down.weight.data
793
+ )
794
+ .type(_child_module.linear.weight.dtype)
795
+ .to(_child_module.linear.weight.device)
796
+ )
797
+
798
+ else:
799
+ print("Collapsing Conv Lora in", name)
800
+ _child_module.conv.weight = nn.Parameter(
801
+ _child_module.conv.weight.data
802
+ + alpha
803
+ * (
804
+ _child_module.lora_up.weight.data.flatten(start_dim=1)
805
+ @ _child_module.lora_down.weight.data.flatten(start_dim=1)
806
+ )
807
+ .reshape(_child_module.conv.weight.data.shape)
808
+ .type(_child_module.conv.weight.dtype)
809
+ .to(_child_module.conv.weight.device)
810
+ )
811
+
812
+
813
+ def monkeypatch_or_replace_lora(
814
+ model,
815
+ loras,
816
+ target_replace_module=DEFAULT_TARGET_REPLACE,
817
+ r: Union[int, List[int]] = 4,
818
+ ):
819
+ for _module, name, _child_module in _find_modules(
820
+ model, target_replace_module, search_class=[nn.Linear, LoraInjectedLinear]
821
+ ):
822
+ _source = (
823
+ _child_module.linear
824
+ if isinstance(_child_module, LoraInjectedLinear)
825
+ else _child_module
826
+ )
827
+
828
+ weight = _source.weight
829
+ bias = _source.bias
830
+ _tmp = LoraInjectedLinear(
831
+ _source.in_features,
832
+ _source.out_features,
833
+ _source.bias is not None,
834
+ r=r.pop(0) if isinstance(r, list) else r,
835
+ )
836
+ _tmp.linear.weight = weight
837
+
838
+ if bias is not None:
839
+ _tmp.linear.bias = bias
840
+
841
+ # switch the module
842
+ _module._modules[name] = _tmp
843
+
844
+ up_weight = loras.pop(0)
845
+ down_weight = loras.pop(0)
846
+
847
+ _module._modules[name].lora_up.weight = nn.Parameter(
848
+ up_weight.type(weight.dtype)
849
+ )
850
+ _module._modules[name].lora_down.weight = nn.Parameter(
851
+ down_weight.type(weight.dtype)
852
+ )
853
+
854
+ _module._modules[name].to(weight.device)
855
+
856
+
857
+ def monkeypatch_or_replace_lora_extended(
858
+ model,
859
+ loras,
860
+ target_replace_module=DEFAULT_TARGET_REPLACE,
861
+ r: Union[int, List[int]] = 4,
862
+ ):
863
+ for _module, name, _child_module in _find_modules(
864
+ model,
865
+ target_replace_module,
866
+ search_class=[
867
+ nn.Linear,
868
+ nn.Conv2d,
869
+ nn.Conv3d,
870
+ LoraInjectedLinear,
871
+ LoraInjectedConv2d,
872
+ LoraInjectedConv3d,
873
+ ],
874
+ ):
875
+
876
+ if (_child_module.__class__ == nn.Linear) or (
877
+ _child_module.__class__ == LoraInjectedLinear
878
+ ):
879
+ if len(loras[0].shape) != 2:
880
+ continue
881
+
882
+ _source = (
883
+ _child_module.linear
884
+ if isinstance(_child_module, LoraInjectedLinear)
885
+ else _child_module
886
+ )
887
+
888
+ weight = _source.weight
889
+ bias = _source.bias
890
+ _tmp = LoraInjectedLinear(
891
+ _source.in_features,
892
+ _source.out_features,
893
+ _source.bias is not None,
894
+ r=r.pop(0) if isinstance(r, list) else r,
895
+ )
896
+ _tmp.linear.weight = weight
897
+
898
+ if bias is not None:
899
+ _tmp.linear.bias = bias
900
+
901
+ elif (_child_module.__class__ == nn.Conv2d) or (
902
+ _child_module.__class__ == LoraInjectedConv2d
903
+ ):
904
+ if len(loras[0].shape) != 4:
905
+ continue
906
+ _source = (
907
+ _child_module.conv
908
+ if isinstance(_child_module, LoraInjectedConv2d)
909
+ else _child_module
910
+ )
911
+
912
+ weight = _source.weight
913
+ bias = _source.bias
914
+ _tmp = LoraInjectedConv2d(
915
+ _source.in_channels,
916
+ _source.out_channels,
917
+ _source.kernel_size,
918
+ _source.stride,
919
+ _source.padding,
920
+ _source.dilation,
921
+ _source.groups,
922
+ _source.bias is not None,
923
+ r=r.pop(0) if isinstance(r, list) else r,
924
+ )
925
+
926
+ _tmp.conv.weight = weight
927
+
928
+ if bias is not None:
929
+ _tmp.conv.bias = bias
930
+
931
+ elif _child_module.__class__ == nn.Conv3d or(
932
+ _child_module.__class__ == LoraInjectedConv3d
933
+ ):
934
+
935
+ if len(loras[0].shape) != 5:
936
+ continue
937
+
938
+ _source = (
939
+ _child_module.conv
940
+ if isinstance(_child_module, LoraInjectedConv3d)
941
+ else _child_module
942
+ )
943
+
944
+ weight = _source.weight
945
+ bias = _source.bias
946
+ _tmp = LoraInjectedConv3d(
947
+ _source.in_channels,
948
+ _source.out_channels,
949
+ bias=_source.bias is not None,
950
+ kernel_size=_source.kernel_size,
951
+ padding=_source.padding,
952
+ r=r.pop(0) if isinstance(r, list) else r,
953
+ )
954
+
955
+ _tmp.conv.weight = weight
956
+
957
+ if bias is not None:
958
+ _tmp.conv.bias = bias
959
+
960
+ # switch the module
961
+ _module._modules[name] = _tmp
962
+
963
+ up_weight = loras.pop(0)
964
+ down_weight = loras.pop(0)
965
+
966
+ _module._modules[name].lora_up.weight = nn.Parameter(
967
+ up_weight.type(weight.dtype)
968
+ )
969
+ _module._modules[name].lora_down.weight = nn.Parameter(
970
+ down_weight.type(weight.dtype)
971
+ )
972
+
973
+ _module._modules[name].to(weight.device)
974
+
975
+
976
+ def monkeypatch_or_replace_safeloras(models, safeloras):
977
+ loras = parse_safeloras(safeloras)
978
+
979
+ for name, (lora, ranks, target) in loras.items():
980
+ model = getattr(models, name, None)
981
+
982
+ if not model:
983
+ print(f"No model provided for {name}, contained in Lora")
984
+ continue
985
+
986
+ monkeypatch_or_replace_lora_extended(model, lora, target, ranks)
987
+
988
+
989
+ def monkeypatch_remove_lora(model):
990
+ for _module, name, _child_module in _find_modules(
991
+ model, search_class=[LoraInjectedLinear, LoraInjectedConv2d, LoraInjectedConv3d]
992
+ ):
993
+ if isinstance(_child_module, LoraInjectedLinear):
994
+ _source = _child_module.linear
995
+ weight, bias = _source.weight, _source.bias
996
+
997
+ _tmp = nn.Linear(
998
+ _source.in_features, _source.out_features, bias is not None
999
+ )
1000
+
1001
+ _tmp.weight = weight
1002
+ if bias is not None:
1003
+ _tmp.bias = bias
1004
+
1005
+ else:
1006
+ _source = _child_module.conv
1007
+ weight, bias = _source.weight, _source.bias
1008
+
1009
+ if isinstance(_source, nn.Conv2d):
1010
+ _tmp = nn.Conv2d(
1011
+ in_channels=_source.in_channels,
1012
+ out_channels=_source.out_channels,
1013
+ kernel_size=_source.kernel_size,
1014
+ stride=_source.stride,
1015
+ padding=_source.padding,
1016
+ dilation=_source.dilation,
1017
+ groups=_source.groups,
1018
+ bias=bias is not None,
1019
+ )
1020
+
1021
+ _tmp.weight = weight
1022
+ if bias is not None:
1023
+ _tmp.bias = bias
1024
+
1025
+ if isinstance(_source, nn.Conv3d):
1026
+ _tmp = nn.Conv3d(
1027
+ _source.in_channels,
1028
+ _source.out_channels,
1029
+ bias=_source.bias is not None,
1030
+ kernel_size=_source.kernel_size,
1031
+ padding=_source.padding,
1032
+ )
1033
+
1034
+ _tmp.weight = weight
1035
+ if bias is not None:
1036
+ _tmp.bias = bias
1037
+
1038
+ _module._modules[name] = _tmp
1039
+
1040
+
1041
+ def monkeypatch_add_lora(
1042
+ model,
1043
+ loras,
1044
+ target_replace_module=DEFAULT_TARGET_REPLACE,
1045
+ alpha: float = 1.0,
1046
+ beta: float = 1.0,
1047
+ ):
1048
+ for _module, name, _child_module in _find_modules(
1049
+ model, target_replace_module, search_class=[LoraInjectedLinear]
1050
+ ):
1051
+ weight = _child_module.linear.weight
1052
+
1053
+ up_weight = loras.pop(0)
1054
+ down_weight = loras.pop(0)
1055
+
1056
+ _module._modules[name].lora_up.weight = nn.Parameter(
1057
+ up_weight.type(weight.dtype).to(weight.device) * alpha
1058
+ + _module._modules[name].lora_up.weight.to(weight.device) * beta
1059
+ )
1060
+ _module._modules[name].lora_down.weight = nn.Parameter(
1061
+ down_weight.type(weight.dtype).to(weight.device) * alpha
1062
+ + _module._modules[name].lora_down.weight.to(weight.device) * beta
1063
+ )
1064
+
1065
+ _module._modules[name].to(weight.device)
1066
+
1067
+
1068
+ def tune_lora_scale(model, alpha: float = 1.0):
1069
+ for _module in model.modules():
1070
+ if _module.__class__.__name__ in ["LoraInjectedLinear", "LoraInjectedConv2d", "LoraInjectedConv3d"]:
1071
+ _module.scale = alpha
1072
+
1073
+
1074
+ def set_lora_diag(model, diag: torch.Tensor):
1075
+ for _module in model.modules():
1076
+ if _module.__class__.__name__ in ["LoraInjectedLinear", "LoraInjectedConv2d", "LoraInjectedConv3d"]:
1077
+ _module.set_selector_from_diag(diag)
1078
+
1079
+
1080
+ def _text_lora_path(path: str) -> str:
1081
+ assert path.endswith(".pt"), "Only .pt files are supported"
1082
+ return ".".join(path.split(".")[:-1] + ["text_encoder", "pt"])
1083
+
1084
+
1085
+ def _ti_lora_path(path: str) -> str:
1086
+ assert path.endswith(".pt"), "Only .pt files are supported"
1087
+ return ".".join(path.split(".")[:-1] + ["ti", "pt"])
1088
+
1089
+
1090
+ def apply_learned_embed_in_clip(
1091
+ learned_embeds,
1092
+ text_encoder,
1093
+ tokenizer,
1094
+ token: Optional[Union[str, List[str]]] = None,
1095
+ idempotent=False,
1096
+ ):
1097
+ if isinstance(token, str):
1098
+ trained_tokens = [token]
1099
+ elif isinstance(token, list):
1100
+ assert len(learned_embeds.keys()) == len(
1101
+ token
1102
+ ), "The number of tokens and the number of embeds should be the same"
1103
+ trained_tokens = token
1104
+ else:
1105
+ trained_tokens = list(learned_embeds.keys())
1106
+
1107
+ for token in trained_tokens:
1108
+ print(token)
1109
+ embeds = learned_embeds[token]
1110
+
1111
+ # cast to dtype of text_encoder
1112
+ dtype = text_encoder.get_input_embeddings().weight.dtype
1113
+ num_added_tokens = tokenizer.add_tokens(token)
1114
+
1115
+ i = 1
1116
+ if not idempotent:
1117
+ while num_added_tokens == 0:
1118
+ print(f"The tokenizer already contains the token {token}.")
1119
+ token = f"{token[:-1]}-{i}>"
1120
+ print(f"Attempting to add the token {token}.")
1121
+ num_added_tokens = tokenizer.add_tokens(token)
1122
+ i += 1
1123
+ elif num_added_tokens == 0 and idempotent:
1124
+ print(f"The tokenizer already contains the token {token}.")
1125
+ print(f"Replacing {token} embedding.")
1126
+
1127
+ # resize the token embeddings
1128
+ text_encoder.resize_token_embeddings(len(tokenizer))
1129
+
1130
+ # get the id for the token and assign the embeds
1131
+ token_id = tokenizer.convert_tokens_to_ids(token)
1132
+ text_encoder.get_input_embeddings().weight.data[token_id] = embeds
1133
+ return token
1134
+
1135
+
1136
+ def load_learned_embed_in_clip(
1137
+ learned_embeds_path,
1138
+ text_encoder,
1139
+ tokenizer,
1140
+ token: Optional[Union[str, List[str]]] = None,
1141
+ idempotent=False,
1142
+ ):
1143
+ learned_embeds = torch.load(learned_embeds_path)
1144
+ apply_learned_embed_in_clip(
1145
+ learned_embeds, text_encoder, tokenizer, token, idempotent
1146
+ )
1147
+
1148
+
1149
+ def patch_pipe(
1150
+ pipe,
1151
+ maybe_unet_path,
1152
+ token: Optional[str] = None,
1153
+ r: int = 4,
1154
+ patch_unet=True,
1155
+ patch_text=True,
1156
+ patch_ti=True,
1157
+ idempotent_token=True,
1158
+ unet_target_replace_module=DEFAULT_TARGET_REPLACE,
1159
+ text_target_replace_module=TEXT_ENCODER_DEFAULT_TARGET_REPLACE,
1160
+ ):
1161
+ if maybe_unet_path.endswith(".pt"):
1162
+ # torch format
1163
+
1164
+ if maybe_unet_path.endswith(".ti.pt"):
1165
+ unet_path = maybe_unet_path[:-6] + ".pt"
1166
+ elif maybe_unet_path.endswith(".text_encoder.pt"):
1167
+ unet_path = maybe_unet_path[:-16] + ".pt"
1168
+ else:
1169
+ unet_path = maybe_unet_path
1170
+
1171
+ ti_path = _ti_lora_path(unet_path)
1172
+ text_path = _text_lora_path(unet_path)
1173
+
1174
+ if patch_unet:
1175
+ print("LoRA : Patching Unet")
1176
+ monkeypatch_or_replace_lora(
1177
+ pipe.unet,
1178
+ torch.load(unet_path),
1179
+ r=r,
1180
+ target_replace_module=unet_target_replace_module,
1181
+ )
1182
+
1183
+ if patch_text:
1184
+ print("LoRA : Patching text encoder")
1185
+ monkeypatch_or_replace_lora(
1186
+ pipe.text_encoder,
1187
+ torch.load(text_path),
1188
+ target_replace_module=text_target_replace_module,
1189
+ r=r,
1190
+ )
1191
+ if patch_ti:
1192
+ print("LoRA : Patching token input")
1193
+ token = load_learned_embed_in_clip(
1194
+ ti_path,
1195
+ pipe.text_encoder,
1196
+ pipe.tokenizer,
1197
+ token=token,
1198
+ idempotent=idempotent_token,
1199
+ )
1200
+
1201
+ elif maybe_unet_path.endswith(".safetensors"):
1202
+ safeloras = safe_open(maybe_unet_path, framework="pt", device="cpu")
1203
+ monkeypatch_or_replace_safeloras(pipe, safeloras)
1204
+ tok_dict = parse_safeloras_embeds(safeloras)
1205
+ if patch_ti:
1206
+ apply_learned_embed_in_clip(
1207
+ tok_dict,
1208
+ pipe.text_encoder,
1209
+ pipe.tokenizer,
1210
+ token=token,
1211
+ idempotent=idempotent_token,
1212
+ )
1213
+ return tok_dict
1214
+
1215
+
1216
+ def train_patch_pipe(pipe, patch_unet, patch_text):
1217
+ if patch_unet:
1218
+ print("LoRA : Patching Unet")
1219
+ collapse_lora(pipe.unet)
1220
+ monkeypatch_remove_lora(pipe.unet)
1221
+
1222
+ if patch_text:
1223
+ print("LoRA : Patching text encoder")
1224
+
1225
+ collapse_lora(pipe.text_encoder)
1226
+ monkeypatch_remove_lora(pipe.text_encoder)
1227
+
1228
+ @torch.no_grad()
1229
+ def inspect_lora(model):
1230
+ moved = {}
1231
+
1232
+ for name, _module in model.named_modules():
1233
+ if _module.__class__.__name__ in ["LoraInjectedLinear", "LoraInjectedConv2d", "LoraInjectedConv3d"]:
1234
+ ups = _module.lora_up.weight.data.clone()
1235
+ downs = _module.lora_down.weight.data.clone()
1236
+
1237
+ wght: torch.Tensor = ups.flatten(1) @ downs.flatten(1)
1238
+
1239
+ dist = wght.flatten().abs().mean().item()
1240
+ if name in moved:
1241
+ moved[name].append(dist)
1242
+ else:
1243
+ moved[name] = [dist]
1244
+
1245
+ return moved
1246
+
1247
+
1248
+ def save_all(
1249
+ unet,
1250
+ text_encoder,
1251
+ save_path,
1252
+ placeholder_token_ids=None,
1253
+ placeholder_tokens=None,
1254
+ save_lora=True,
1255
+ save_ti=True,
1256
+ target_replace_module_text=TEXT_ENCODER_DEFAULT_TARGET_REPLACE,
1257
+ target_replace_module_unet=DEFAULT_TARGET_REPLACE,
1258
+ safe_form=True,
1259
+ ):
1260
+ if not safe_form:
1261
+ # save ti
1262
+ if save_ti:
1263
+ ti_path = _ti_lora_path(save_path)
1264
+ learned_embeds_dict = {}
1265
+ for tok, tok_id in zip(placeholder_tokens, placeholder_token_ids):
1266
+ learned_embeds = text_encoder.get_input_embeddings().weight[tok_id]
1267
+ print(
1268
+ f"Current Learned Embeddings for {tok}:, id {tok_id} ",
1269
+ learned_embeds[:4],
1270
+ )
1271
+ learned_embeds_dict[tok] = learned_embeds.detach().cpu()
1272
+
1273
+ torch.save(learned_embeds_dict, ti_path)
1274
+ print("Ti saved to ", ti_path)
1275
+
1276
+ # save text encoder
1277
+ if save_lora:
1278
+ save_lora_weight(
1279
+ unet, save_path, target_replace_module=target_replace_module_unet
1280
+ )
1281
+ print("Unet saved to ", save_path)
1282
+
1283
+ save_lora_weight(
1284
+ text_encoder,
1285
+ _text_lora_path(save_path),
1286
+ target_replace_module=target_replace_module_text,
1287
+ )
1288
+ print("Text Encoder saved to ", _text_lora_path(save_path))
1289
+
1290
+ else:
1291
+ assert save_path.endswith(
1292
+ ".safetensors"
1293
+ ), f"Save path : {save_path} should end with .safetensors"
1294
+
1295
+ loras = {}
1296
+ embeds = {}
1297
+
1298
+ if save_lora:
1299
+
1300
+ loras["unet"] = (unet, target_replace_module_unet)
1301
+ loras["text_encoder"] = (text_encoder, target_replace_module_text)
1302
+
1303
+ if save_ti:
1304
+ for tok, tok_id in zip(placeholder_tokens, placeholder_token_ids):
1305
+ learned_embeds = text_encoder.get_input_embeddings().weight[tok_id]
1306
+ print(
1307
+ f"Current Learned Embeddings for {tok}:, id {tok_id} ",
1308
+ learned_embeds[:4],
1309
+ )
1310
+ embeds[tok] = learned_embeds.detach().cpu()
1311
+
1312
+ save_safeloras_with_embeds(loras, embeds, save_path)
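
A quick usage sketch may help tie the helpers in lora.py together. The snippet below is a minimal, hypothetical example: it assumes this lora.py sits on the import path and that a text-to-video checkpoint has already been downloaded locally (the "model-cache/576w" path is illustrative). It injects trainable LoRA layers into the UNet and then saves only the low-rank weights.

from diffusers import UNet3DConditionModel

from lora import (
    inject_trainable_lora_extended,
    save_lora_weight,
    UNET_EXTENDED_TARGET_REPLACE,
)

# Load the 3D UNet from a locally downloaded checkpoint (illustrative path).
unet = UNet3DConditionModel.from_pretrained("model-cache/576w", subfolder="unet")
unet.requires_grad_(False)

# Swap Linear / Conv2d / Conv3d layers inside the target blocks for LoRA-injected versions.
lora_params, names = inject_trainable_lora_extended(
    unet, target_replace_module=UNET_EXTENDED_TARGET_REPLACE, r=16
)
print(f"Injected LoRA into {len(names)} modules")

# ... a training loop would optimize the parameters returned in lora_params ...

# Persist only the LoRA up/down weights; the base model weights are untouched.
save_lora_weight(unet, "unet_lora.pt", target_replace_module=UNET_EXTENDED_TARGET_REPLACE)
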
predict.py ADDED
@@ -0,0 +1,101 @@
1
+ import os
2
+ from typing import List
3
+ from cog import BasePredictor, Input, Path
4
+ import subprocess
5
+ import shutil
6
+
7
+ MODEL_CACHE = "model-cache"
8
+
9
+ class Predictor(BasePredictor):
10
+ def setup(self):
11
+ pass
12
+
13
+ def predict(
14
+ self,
15
+ prompt: str = Input(
16
+ description="Input prompt", default="An astronaut riding a horse"
17
+ ),
18
+ negative_prompt: str = Input(
19
+ description="Negative prompt", default=None
20
+ ),
21
+ init_video: Path = Input(
22
+ description="URL of the initial video (optional)", default=None
23
+ ),
24
+ init_weight: float = Input(
25
+ description="Strength of init_video", default=0.5
26
+ ),
27
+ num_frames: int = Input(
28
+ description="Number of frames for the output video", default=24
29
+ ),
30
+ num_inference_steps: int = Input(
31
+ description="Number of denoising steps", ge=1, le=500, default=50
32
+ ),
33
+ width: int = Input(
34
+ description="Width of the output video", ge=256, default=576
35
+ ),
36
+ height: int = Input(
37
+ description="Height of the output video", ge=256, default=320
38
+ ),
39
+ guidance_scale: float = Input(
40
+ description="Guidance scale", ge=1.0, le=100.0, default=7.5
41
+ ),
42
+ fps: int = Input(description="fps for the output video", default=8),
43
+ model: str = Input(
44
+ description="Model to use", default="xl", choices=["xl", "576w", "potat1", "animov-512x"]
45
+ ),
46
+ batch_size: int = Input(description="Batch size", default=1, ge=1),
47
+ remove_watermark: bool = Input(
48
+ description="Remove watermark", default=False
49
+ ),
50
+ seed: int = Input(
51
+ description="Random seed. Leave blank to randomize the seed", default=None
52
+ ),
53
+ ) -> List[Path]:
54
+ if seed is None:
55
+ seed = int.from_bytes(os.urandom(2), "big")
56
+ print(f"Using seed: {seed}")
57
+
58
+ shutil.rmtree("output", ignore_errors=True)
59
+ os.makedirs("output", exist_ok=True)
60
+
61
+ args = {
62
+ "prompt": prompt,
63
+ "negative_prompt": negative_prompt,
64
+ "batch_size": batch_size,
65
+ "num_frames": num_frames,
66
+ "num_steps": num_inference_steps,
67
+ "seed": seed,
68
+ "guidance-scale": guidance_scale,
69
+ "width": width,
70
+ "height": height,
71
+ "fps": fps,
72
+ "device": "cuda",
73
+ "output_dir": "output",
74
+ "remove-watermark": remove_watermark,
75
+ }
76
+
77
+ args['model'] = MODEL_CACHE + "/" + model
78
+
79
+ if init_video is not None:
80
+ # for some reason I need to copy the file to make it work
81
+ if os.path.exists("input.mp4"):
82
+ os.unlink("input.mp4")
83
+ shutil.copy(init_video, "input.mp4")
84
+
85
+ args["init-video"] = "input.mp4"
86
+ args["init-weight"] = init_weight
87
+ print("init video", os.stat("input.mp4").st_size)
88
+
89
+ cmd = ["python", "inference.py"]
90
+ for k, v in args.items():
91
+ if v is not None:
92
+ cmd.append(f"--{k}")
93
+ cmd.append(str(v))
94
+ subprocess.check_call(cmd)
95
+ # outputs = inference.run(**args)
96
+
97
+ outputs = []
98
+ for f in os.listdir("output"):
99
+ if f.endswith(".mp4"):
100
+ outputs.append(Path(os.path.join("output", f)))
101
+ return outputs
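
For local debugging outside of Cog, the same inference.py call that predict() assembles can be reproduced directly. This is a minimal sketch that reuses only the flag names appearing in the args dict above; the prompt, model path, and values are placeholders.

import subprocess

# Mirrors the command predict() builds; flag names are taken from the args dict above.
cmd = [
    "python", "inference.py",
    "--prompt", "An astronaut riding a horse",
    "--model", "model-cache/576w",
    "--num_frames", "24",
    "--num_steps", "30",
    "--guidance-scale", "12.5",
    "--width", "576",
    "--height", "320",
    "--fps", "8",
    "--device", "cuda",
    "--output_dir", "output",
]
subprocess.check_call(cmd)
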
samples.py ADDED
@@ -0,0 +1,57 @@
1
+ import base64
2
+ import requests
3
+ import sys
4
+ import os
5
+
6
+
7
+ def gen(output_fn, **kwargs):
8
+ if os.path.exists(output_fn):
9
+ print("Skipping", output_fn)
10
+ return
11
+
12
+ print("Generating", output_fn)
13
+ url = "http://localhost:5000/predictions"
14
+ response = requests.post(url, json={"input": kwargs})
15
+ data = response.json()
16
+
17
+ try:
18
+ datauri = data["output"][0]
19
+ base64_encoded_data = datauri.split(",")[1]
20
+ data = base64.b64decode(base64_encoded_data)
21
+ except:
22
+ print("Error!")
23
+ print("input:", kwargs)
24
+ print(data["logs"])
25
+ # sys.exit(1)
+ return  # nothing valid to write if the prediction failed
26
+
27
+ with open(output_fn, "wb") as f:
28
+ f.write(data)
29
+
30
+
31
+ def main():
32
+ gen(
33
+ "sample.mp4",
34
+ prompt="A deep sea video of a bioluminescent siphonophore, 8k, beautiful, award winning, close up",
35
+ seed=42,
36
+ num_frames=24,
37
+ model="potat1",
38
+ num_inference_steps=30,
39
+ guidance_scale=17.5,
40
+ fps=12,
41
+ )
42
+ gen(
43
+ "vid-sample.mp4",
44
+ prompt="A deep sea video of a bioluminescent siphonophore, 8k, beautiful, award winning, close up",
45
+ seed=42,
46
+ num_frames=24,
47
+ model="zeroscope_v2_XL",
48
+ num_inference_steps=30,
49
+ guidance_scale=17.5,
50
+ init_video="https://replicate.delivery/pbxt/qxacIWhXu0rFAZu6GMElrXrTL5Wx6ZqnjPqIoS7DgIftowkIA/out.mp4",
51
+ fps=12,
52
+ )
53
+
54
+
55
+
56
+ if __name__ == "__main__":
57
+ main()
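
The gen() helper above is also convenient for parameter sweeps. A small hypothetical extension, assuming the local prediction server used by samples.py is running, renders the same prompt at several guidance scales:

from samples import gen

PROMPT = "A deep sea video of a bioluminescent siphonophore, 8k, beautiful, close up"

# Sweep guidance_scale while keeping every other input fixed; gen() skips any
# output file that already exists, so the sweep can be re-run safely.
for cfg in (7.5, 12.5, 17.5):
    gen(
        f"sample_cfg{cfg}.mp4",
        prompt=PROMPT,
        seed=42,
        num_frames=24,
        model="576w",
        num_inference_steps=30,
        guidance_scale=cfg,
        fps=12,
    )
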
train.py ADDED
@@ -0,0 +1,998 @@
1
+ import argparse
2
+ import datetime
3
+ import logging
4
+ import inspect
5
+ import math
6
+ import os
7
+ import random
8
+ import gc
9
+ import copy
10
+
11
+ from typing import Dict, Optional, Tuple
12
+ from omegaconf import OmegaConf
13
+
14
+ import cv2
15
+ import torch
16
+ import torch.nn.functional as F
17
+ import torch.utils.checkpoint
18
+ import torchvision.transforms as T
19
+ import diffusers
20
+ import transformers
21
+
22
+ from torchvision import transforms
23
+ from tqdm.auto import tqdm
24
+
25
+ from accelerate import Accelerator
26
+ from accelerate.logging import get_logger
27
+ from accelerate.utils import set_seed
28
+
29
+ from models.unet_3d_condition import UNet3DConditionModel
30
+ from diffusers.models import AutoencoderKL
31
+ from diffusers import DPMSolverMultistepScheduler, DDPMScheduler, TextToVideoSDPipeline
32
+ from diffusers.optimization import get_scheduler
33
+ from diffusers.utils import check_min_version, export_to_video
34
+ from diffusers.utils.import_utils import is_xformers_available
35
+ from diffusers.models.attention_processor import AttnProcessor2_0, Attention
36
+ from diffusers.models.attention import BasicTransformerBlock
37
+
38
+ from transformers import CLIPTextModel, CLIPTokenizer
39
+ from transformers.models.clip.modeling_clip import CLIPEncoder
40
+ from utils.dataset import VideoJsonDataset, SingleVideoDataset, \
41
+ ImageDataset, VideoFolderDataset, CachedDataset
42
+ from einops import rearrange, repeat
43
+
44
+ from utils.lora import (
45
+ extract_lora_ups_down,
46
+ inject_trainable_lora,
47
+ inject_trainable_lora_extended,
48
+ save_lora_weight,
49
+ train_patch_pipe,
50
+ monkeypatch_or_replace_lora,
51
+ monkeypatch_or_replace_lora_extended
52
+ )
53
+
54
+
55
+ already_printed_trainables = False
56
+
57
+ # Will error if the minimal version of diffusers is not installed. Remove at your own risk.
58
+ check_min_version("0.10.0.dev0")
59
+
60
+ logger = get_logger(__name__, log_level="INFO")
61
+
62
+ def create_logging(logging, logger, accelerator):
63
+ logging.basicConfig(
64
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
65
+ datefmt="%m/%d/%Y %H:%M:%S",
66
+ level=logging.INFO,
67
+ )
68
+ logger.info(accelerator.state, main_process_only=False)
69
+
70
+ def accelerate_set_verbose(accelerator):
71
+ if accelerator.is_local_main_process:
72
+ transformers.utils.logging.set_verbosity_warning()
73
+ diffusers.utils.logging.set_verbosity_info()
74
+ else:
75
+ transformers.utils.logging.set_verbosity_error()
76
+ diffusers.utils.logging.set_verbosity_error()
77
+
78
+ def get_train_dataset(dataset_types, train_data, tokenizer):
79
+ train_datasets = []
80
+
81
+ # Loop through all available datasets, get the name, then add to list of data to process.
82
+ for DataSet in [VideoJsonDataset, SingleVideoDataset, ImageDataset, VideoFolderDataset]:
83
+ for dataset in dataset_types:
84
+ if dataset == DataSet.__getname__():
85
+ train_datasets.append(DataSet(**train_data, tokenizer=tokenizer))
86
+
87
+ if len(train_datasets) > 0:
88
+ return train_datasets
89
+ else:
90
+ raise ValueError("Dataset type not found: 'json', 'single_video', 'folder', 'image'")
91
+
92
+ def extend_datasets(datasets, dataset_items, extend=False):
93
+ biggest_data_len = max(x.__len__() for x in datasets)
94
+ extended = []
95
+ for dataset in datasets:
96
+ if dataset.__len__() == 0:
97
+ del dataset
98
+ continue
99
+ if dataset.__len__() < biggest_data_len:
100
+ for item in dataset_items:
101
+ if extend and item not in extended and hasattr(dataset, item):
102
+ print(f"Extending {item}")
103
+
104
+ value = getattr(dataset, item)
105
+ value *= biggest_data_len
106
+ value = value[:biggest_data_len]
107
+
108
+ setattr(dataset, item, value)
109
+
110
+ print(f"New {item} dataset length: {dataset.__len__()}")
111
+ extended.append(item)
112
+
113
+ def export_to_video(video_frames, output_video_path, fps):
114
+ fourcc = cv2.VideoWriter_fourcc(*"mp4v")
115
+ h, w, _ = video_frames[0].shape
116
+ video_writer = cv2.VideoWriter(output_video_path, fourcc, fps=fps, frameSize=(w, h))
117
+ for i in range(len(video_frames)):
118
+ img = cv2.cvtColor(video_frames[i], cv2.COLOR_RGB2BGR)
119
+ video_writer.write(img)
120
+
121
+ def create_output_folders(output_dir, config):
122
+ now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
123
+ out_dir = os.path.join(output_dir, f"train_{now}")
124
+
125
+ os.makedirs(out_dir, exist_ok=True)
126
+ os.makedirs(f"{out_dir}/samples", exist_ok=True)
127
+ OmegaConf.save(config, os.path.join(out_dir, 'config.yaml'))
128
+
129
+ return out_dir
130
+
131
+ def load_primary_models(pretrained_model_path):
132
+ noise_scheduler = DDPMScheduler.from_pretrained(pretrained_model_path, subfolder="scheduler")
133
+ tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_path, subfolder="tokenizer")
134
+ text_encoder = CLIPTextModel.from_pretrained(pretrained_model_path, subfolder="text_encoder")
135
+ vae = AutoencoderKL.from_pretrained(pretrained_model_path, subfolder="vae")
136
+ unet = UNet3DConditionModel.from_pretrained(pretrained_model_path, subfolder="unet")
137
+
138
+ return noise_scheduler, tokenizer, text_encoder, vae, unet
139
+
140
+ def unet_and_text_g_c(unet, text_encoder, unet_enable, text_enable):
141
+ unet._set_gradient_checkpointing(value=unet_enable)
142
+ text_encoder._set_gradient_checkpointing(CLIPEncoder, value=text_enable)
143
+
144
+ def freeze_models(models_to_freeze):
145
+ for model in models_to_freeze:
146
+ if model is not None: model.requires_grad_(False)
147
+
148
+ def is_attn(name):
149
+ return name.split('.')[-1] in ('attn1', 'attn2')
150
+
151
+ def set_processors(attentions):
152
+ for attn in attentions: attn.set_processor(AttnProcessor2_0())
153
+
154
+ def set_torch_2_attn(unet):
155
+ optim_count = 0
156
+
157
+ for name, module in unet.named_modules():
158
+ if is_attn(name):
159
+ if isinstance(module, torch.nn.ModuleList):
160
+ for m in module:
161
+ if isinstance(m, BasicTransformerBlock):
162
+ set_processors([m.attn1, m.attn2])
163
+ optim_count += 1
164
+ if optim_count > 0:
165
+ print(f"{optim_count} Attention layers using Scaled Dot Product Attention.")
166
+
167
+ def handle_memory_attention(enable_xformers_memory_efficient_attention, enable_torch_2_attn, unet):
168
+ try:
169
+ is_torch_2 = hasattr(F, 'scaled_dot_product_attention')
170
+
171
+ if enable_xformers_memory_efficient_attention and not is_torch_2:
172
+ if is_xformers_available():
173
+ from xformers.ops import MemoryEfficientAttentionFlashAttentionOp
174
+ unet.enable_xformers_memory_efficient_attention(attention_op=MemoryEfficientAttentionFlashAttentionOp)
175
+ else:
176
+ raise ValueError("xformers is not available. Make sure it is installed correctly")
177
+
178
+ if enable_torch_2_attn and is_torch_2:
179
+ set_torch_2_attn(unet)
180
+ except:
181
+ print("Could not enable memory efficient attention for xformers or Torch 2.0.")
182
+
183
+ def inject_lora(use_lora, model, replace_modules, is_extended=False, dropout=0.0, lora_path='', r=16):
184
+ injector = (
185
+ inject_trainable_lora if not is_extended
186
+ else
187
+ inject_trainable_lora_extended
188
+ )
189
+
190
+ params = None
191
+ negation = None
192
+
193
+ if os.path.exists(lora_path):
194
+ try:
195
+ for f in os.listdir(lora_path):
196
+ if f.endswith('.pt'):
197
+ lora_file = os.path.join(lora_path, f)
198
+
199
+ if 'text_encoder' in f and isinstance(model, CLIPTextModel):
200
+ monkeypatch_or_replace_lora(
201
+ model,
202
+ torch.load(lora_file),
203
+ target_replace_module=replace_modules,
204
+ r=r
205
+ )
206
+ print("Successfully loaded Text Encoder LoRa.")
207
+
208
+ if 'unet' in f and isinstance(model, UNet3DConditionModel):
209
+ monkeypatch_or_replace_lora_extended(
210
+ model,
211
+ torch.load(lora_file),
212
+ target_replace_module=replace_modules,
213
+ r=r
214
+ )
215
+ print("Successfully loaded UNET LoRa.")
216
+
217
+ except Exception as e:
218
+ print(e)
219
+ print("Could not load LoRAs. Injecting new ones instead...")
220
+
221
+ if use_lora:
222
+ REPLACE_MODULES = replace_modules
223
+ injector_args = {
224
+ "model": model,
225
+ "target_replace_module": REPLACE_MODULES,
226
+ "r": r
227
+ }
228
+ if not is_extended: injector_args['dropout_p'] = dropout
229
+
230
+ params, negation = injector(**injector_args)
231
+ for _up, _down in extract_lora_ups_down(
232
+ model,
233
+ target_replace_module=REPLACE_MODULES):
234
+
235
+ if all(x is not None for x in [_up, _down]):
236
+ print(f"Lora successfully injected into {model.__class__.__name__}.")
237
+
238
+ break
239
+
240
+ return params, negation
241
+
242
+ def save_lora(model, name, condition, replace_modules, step, save_path):
243
+ if condition and replace_modules is not None:
244
+ save_path = f"{save_path}/{step}_{name}.pt"
245
+ save_lora_weight(model, save_path, replace_modules)
246
+
247
+ def handle_lora_save(
248
+ use_unet_lora,
249
+ use_text_lora,
250
+ model,
251
+ save_path,
252
+ checkpoint_step,
253
+ unet_target_modules,
254
+ text_encoder_target_modules
255
+ ):
256
+
257
+ save_path = f"{save_path}/lora"
258
+ os.makedirs(save_path, exist_ok=True)
259
+
260
+ save_lora(
261
+ model.unet,
262
+ 'unet',
263
+ use_unet_lora,
264
+ unet_target_modules,
265
+ checkpoint_step,
266
+ save_path,
267
+ )
268
+ save_lora(
269
+ model.text_encoder,
270
+ 'text_encoder',
271
+ use_text_lora,
272
+ text_encoder_target_modules,
273
+ checkpoint_step,
274
+ save_path
275
+ )
276
+
277
+ train_patch_pipe(model, use_unet_lora, use_text_lora)
278
+
279
+ def param_optim(model, condition, extra_params=None, is_lora=False, negation=None):
280
+ return {
281
+ "model": model,
282
+ "condition": condition,
283
+ 'extra_params': extra_params,
284
+ 'is_lora': is_lora,
285
+ "negation": negation
286
+ }
287
+
288
+
289
+ def create_optim_params(name='param', params=None, lr=5e-6, extra_params=None):
290
+ params = {
291
+ "name": name,
292
+ "params": params,
293
+ "lr": lr
294
+ }
295
+
296
+ if extra_params is not None:
297
+ for k, v in extra_params.items():
298
+ params[k] = v
299
+
300
+ return params
301
+
302
+ def negate_params(name, negation):
303
+ # We have to do this if we are co-training with LoRA.
304
+ # This ensures that parameter groups aren't duplicated.
305
+ if negation is None: return False
306
+ for n in negation:
307
+ if n in name and 'temp' not in name:
308
+ return True
309
+ return False
310
+
311
+
312
+ def create_optimizer_params(model_list, lr):
313
+ import itertools
314
+ optimizer_params = []
315
+
316
+ for optim in model_list:
317
+ model, condition, extra_params, is_lora, negation = optim.values()
318
+ # Check if we are doing LoRA training.
319
+ if is_lora and condition:
320
+ params = create_optim_params(
321
+ params=itertools.chain(*model),
322
+ extra_params=extra_params
323
+ )
324
+ optimizer_params.append(params)
325
+ continue
326
+
327
+ # If this is true, we can train it.
328
+ if condition:
329
+ for n, p in model.named_parameters():
330
+ should_negate = 'lora' in n
331
+ if should_negate: continue
332
+
333
+ params = create_optim_params(n, p, lr, extra_params)
334
+ optimizer_params.append(params)
335
+
336
+ return optimizer_params
337
+
338
+ def get_optimizer(use_8bit_adam):
339
+ if use_8bit_adam:
340
+ try:
341
+ import bitsandbytes as bnb
342
+ except ImportError:
343
+ raise ImportError(
344
+ "Please install bitsandbytes to use 8-bit Adam. You can do so by running `pip install bitsandbytes`"
345
+ )
346
+
347
+ return bnb.optim.AdamW8bit
348
+ else:
349
+ return torch.optim.AdamW
350
+
351
+ def is_mixed_precision(accelerator):
352
+ weight_dtype = torch.float32
353
+
354
+ if accelerator.mixed_precision == "fp16":
355
+ weight_dtype = torch.float16
356
+
357
+ elif accelerator.mixed_precision == "bf16":
358
+ weight_dtype = torch.bfloat16
359
+
360
+ return weight_dtype
361
+
362
+ def cast_to_gpu_and_type(model_list, accelerator, weight_dtype):
363
+ for model in model_list:
364
+ if model is not None: model.to(accelerator.device, dtype=weight_dtype)
365
+
366
+ def handle_cache_latents(
367
+ should_cache,
368
+ output_dir,
369
+ train_dataloader,
370
+ train_batch_size,
371
+ vae,
372
+ cached_latent_dir=None
373
+ ):
374
+
375
+ # Cache latents by storing them in VRAM.
376
+ # Speeds up training and saves memory by not encoding during the train loop.
377
+ if not should_cache: return None
378
+ vae.to('cuda', dtype=torch.float16)
379
+ vae.enable_slicing()
380
+
381
+ cached_latent_dir = (
382
+ os.path.abspath(cached_latent_dir) if cached_latent_dir is not None else None
383
+ )
384
+
385
+ if cached_latent_dir is None:
386
+ cache_save_dir = f"{output_dir}/cached_latents"
387
+ os.makedirs(cache_save_dir, exist_ok=True)
388
+
389
+ for i, batch in enumerate(tqdm(train_dataloader, desc="Caching Latents.")):
390
+
391
+ save_name = f"cached_{i}"
392
+ full_out_path = f"{cache_save_dir}/{save_name}.pt"
393
+
394
+ pixel_values = batch['pixel_values'].to('cuda', dtype=torch.float16)
395
+ batch['pixel_values'] = tensor_to_vae_latent(pixel_values, vae)
396
+ for k, v in batch.items(): batch[k] = v[0]
397
+
398
+ torch.save(batch, full_out_path)
399
+ del pixel_values
400
+ del batch
401
+
402
+ # We do this to avoid fragmentation from casting latents between devices.
403
+ torch.cuda.empty_cache()
404
+ else:
405
+ cache_save_dir = cached_latent_dir
406
+
407
+
408
+ return torch.utils.data.DataLoader(
409
+ CachedDataset(cache_dir=cache_save_dir),
410
+ batch_size=train_batch_size,
411
+ shuffle=True,
412
+ num_workers=0
413
+ )
414
+
415
+ def handle_trainable_modules(model, trainable_modules=None, is_enabled=True, negation=None):
416
+ global already_printed_trainables
417
+
418
+ # This can most definitely be refactored :-)
419
+ unfrozen_params = 0
420
+ if trainable_modules is not None:
421
+ for name, module in model.named_modules():
422
+ for tm in tuple(trainable_modules):
423
+ if tm == 'all':
424
+ model.requires_grad_(is_enabled)
425
+ unfrozen_params = len(list(model.parameters()))
426
+ break
427
+
428
+ if tm in name and 'lora' not in name:
429
+ for m in module.parameters():
430
+ m.requires_grad_(is_enabled)
431
+ if is_enabled: unfrozen_params += 1
432
+
433
+ if unfrozen_params > 0 and not already_printed_trainables:
434
+ already_printed_trainables = True
435
+ print(f"{unfrozen_params} params have been unfrozen for training.")
436
+
437
+ def tensor_to_vae_latent(t, vae):
438
+ video_length = t.shape[1]
439
+
440
+ t = rearrange(t, "b f c h w -> (b f) c h w")
441
+ latents = vae.encode(t).latent_dist.sample()
442
+ latents = rearrange(latents, "(b f) c h w -> b c f h w", f=video_length)
443
+ latents = latents * 0.18215
444
+
445
+ return latents
446
+
447
+ def sample_noise(latents, noise_strength, use_offset_noise):
448
+ b ,c, f, *_ = latents.shape
449
+ noise_latents = torch.randn_like(latents, device=latents.device)
450
+ offset_noise = None
451
+
452
+ if use_offset_noise:
453
+ offset_noise = torch.randn(b, c, f, 1, 1, device=latents.device)
454
+ noise_latents = noise_latents + noise_strength * offset_noise
455
+
456
+ return noise_latents
457
+
458
+ def should_sample(global_step, validation_steps, validation_data):
459
+ return (global_step % validation_steps == 0 or global_step == 1) \
460
+ and validation_data.sample_preview
461
+
462
+ def save_pipe(
463
+ path,
464
+ global_step,
465
+ accelerator,
466
+ unet,
467
+ text_encoder,
468
+ vae,
469
+ output_dir,
470
+ use_unet_lora,
471
+ use_text_lora,
472
+ unet_target_replace_module=None,
473
+ text_target_replace_module=None,
474
+ is_checkpoint=False,
475
+ ):
476
+
477
+ if is_checkpoint:
478
+ save_path = os.path.join(output_dir, f"checkpoint-{global_step}")
479
+ os.makedirs(save_path, exist_ok=True)
480
+ else:
481
+ save_path = output_dir
482
+
483
+ # Save the dtypes so we can continue training at the same precision.
484
+ u_dtype, t_dtype, v_dtype = unet.dtype, text_encoder.dtype, vae.dtype
485
+
486
+ # Copy the model without creating a reference to it. This allows keeping the state of our lora training if enabled.
487
+ unet_out = copy.deepcopy(accelerator.unwrap_model(unet, keep_fp32_wrapper=False))
488
+ text_encoder_out = copy.deepcopy(accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=False))
489
+
490
+ pipeline = TextToVideoSDPipeline.from_pretrained(
491
+ path,
492
+ unet=unet_out,
493
+ text_encoder=text_encoder_out,
494
+ vae=vae,
495
+ ).to(torch_dtype=torch.float16)
496
+
497
+ handle_lora_save(
498
+ use_unet_lora,
499
+ use_text_lora,
500
+ pipeline,
501
+ output_dir,
502
+ global_step,
503
+ unet_target_replace_module,
504
+ text_target_replace_module
505
+ )
506
+
507
+ pipeline.save_pretrained(save_path)
508
+
509
+ if is_checkpoint:
510
+ unet, text_encoder = accelerator.prepare(unet, text_encoder)
511
+ models_to_cast_back = [(unet, u_dtype), (text_encoder, t_dtype), (vae, v_dtype)]
512
+ [x[0].to(accelerator.device, dtype=x[1]) for x in models_to_cast_back]
513
+
514
+ logger.info(f"Saved model at {save_path} on step {global_step}")
515
+
516
+ del pipeline
517
+ del unet_out
518
+ del text_encoder_out
519
+ torch.cuda.empty_cache()
520
+ gc.collect()
521
+
522
+
523
+ def replace_prompt(prompt, token, wlist):
524
+ for w in wlist:
525
+ if w in prompt: return prompt.replace(w, token)
526
+ return prompt
527
+
528
+ def main(
529
+ pretrained_model_path: str,
530
+ output_dir: str,
531
+ train_data: Dict,
532
+ validation_data: Dict,
533
+ dataset_types: Tuple[str] = ('json'),
534
+ validation_steps: int = 100,
535
+ trainable_modules: Tuple[str] = ("attn1", "attn2"),
536
+ trainable_text_modules: Tuple[str] = ("all"),
537
+ extra_unet_params = None,
538
+ extra_text_encoder_params = None,
539
+ train_batch_size: int = 1,
540
+ max_train_steps: int = 500,
541
+ learning_rate: float = 5e-5,
542
+ scale_lr: bool = False,
543
+ lr_scheduler: str = "constant",
544
+ lr_warmup_steps: int = 0,
545
+ adam_beta1: float = 0.9,
546
+ adam_beta2: float = 0.999,
547
+ adam_weight_decay: float = 1e-2,
548
+ adam_epsilon: float = 1e-08,
549
+ max_grad_norm: float = 1.0,
550
+ gradient_accumulation_steps: int = 1,
551
+ gradient_checkpointing: bool = False,
552
+ text_encoder_gradient_checkpointing: bool = False,
553
+ checkpointing_steps: int = 500,
554
+ resume_from_checkpoint: Optional[str] = None,
555
+ mixed_precision: Optional[str] = "fp16",
556
+ use_8bit_adam: bool = False,
557
+ enable_xformers_memory_efficient_attention: bool = True,
558
+ enable_torch_2_attn: bool = False,
559
+ seed: Optional[int] = None,
560
+ train_text_encoder: bool = False,
561
+ use_offset_noise: bool = False,
562
+ offset_noise_strength: float = 0.1,
563
+ extend_dataset: bool = False,
564
+ cache_latents: bool = False,
565
+ cached_latent_dir = None,
566
+ use_unet_lora: bool = False,
567
+ use_text_lora: bool = False,
568
+ unet_lora_modules: Tuple[str] = ["ResnetBlock2D"],
569
+ text_encoder_lora_modules: Tuple[str] = ["CLIPEncoderLayer"],
570
+ lora_rank: int = 16,
571
+ lora_path: str = '',
572
+ **kwargs
573
+ ):
574
+
575
+ *_, config = inspect.getargvalues(inspect.currentframe())
576
+
577
+ accelerator = Accelerator(
578
+ gradient_accumulation_steps=gradient_accumulation_steps,
579
+ mixed_precision=mixed_precision,
580
+ log_with="tensorboard",
581
+ logging_dir=output_dir
582
+ )
583
+
584
+ # Make one log on every process with the configuration for debugging.
585
+ create_logging(logging, logger, accelerator)
586
+
587
+ # Initialize accelerate, transformers, and diffusers warnings
588
+ accelerate_set_verbose(accelerator)
589
+
590
+ # If passed along, set the training seed now.
591
+ if seed is not None:
592
+ set_seed(seed)
593
+
594
+ # Handle the output folder creation
595
+ if accelerator.is_main_process:
596
+ output_dir = create_output_folders(output_dir, config)
597
+
598
+ # Load scheduler, tokenizer and models.
599
+ noise_scheduler, tokenizer, text_encoder, vae, unet = load_primary_models(pretrained_model_path)
600
+
601
+ # Freeze any necessary models
602
+ freeze_models([vae, text_encoder, unet])
603
+
604
+ # Enable xformers if available
605
+ handle_memory_attention(enable_xformers_memory_efficient_attention, enable_torch_2_attn, unet)
606
+
607
+ if scale_lr:
608
+ learning_rate = (
609
+ learning_rate * gradient_accumulation_steps * train_batch_size * accelerator.num_processes
610
+ )
611
+
612
+ # Initialize the optimizer
613
+ optimizer_cls = get_optimizer(use_8bit_adam)
614
+
615
+ # Use LoRA if enabled.
616
+ unet_lora_params, unet_negation = inject_lora(
617
+ use_unet_lora, unet, unet_lora_modules, is_extended=True,
618
+ r=lora_rank, lora_path=lora_path
619
+ )
620
+
621
+ text_encoder_lora_params, text_encoder_negation = inject_lora(
622
+ use_text_lora, text_encoder, text_encoder_lora_modules,
623
+ r=lora_rank, lora_path=lora_path
624
+ )
625
+
626
+ # Build the parameter groups to optimize, each gated by a condition (optimized only when that condition is true)
627
+ optim_params = [
628
+ param_optim(unet, trainable_modules is not None, extra_params=extra_unet_params, negation=unet_negation),
629
+ param_optim(text_encoder, train_text_encoder and not use_text_lora, extra_params=extra_text_encoder_params,
630
+ negation=text_encoder_negation
631
+ ),
632
+ param_optim(text_encoder_lora_params, use_text_lora, is_lora=True, extra_params={"lr": 1e-5}),
633
+ param_optim(unet_lora_params, use_unet_lora, is_lora=True, extra_params={"lr": 1e-5})
634
+ ]
635
+
636
+ params = create_optimizer_params(optim_params, learning_rate)
637
+
638
+ # Create Optimizer
639
+ optimizer = optimizer_cls(
640
+ params,
641
+ lr=learning_rate,
642
+ betas=(adam_beta1, adam_beta2),
643
+ weight_decay=adam_weight_decay,
644
+ eps=adam_epsilon,
645
+ )
646
+
647
+ # Scheduler
648
+ lr_scheduler = get_scheduler(
649
+ lr_scheduler,
650
+ optimizer=optimizer,
651
+ num_warmup_steps=lr_warmup_steps * gradient_accumulation_steps,
652
+ num_training_steps=max_train_steps * gradient_accumulation_steps,
653
+ )
654
+
655
+ # Get the training dataset based on types (json, single_video, image)
656
+ train_datasets = get_train_dataset(dataset_types, train_data, tokenizer)
657
+
658
+ # Extend datasets that are smaller than the largest one, so training is more balanced across them.
659
+ attrs = ['train_data', 'frames', 'image_dir', 'video_files']
660
+ extend_datasets(train_datasets, attrs, extend=extend_dataset)
661
+
662
+ # Process one dataset
663
+ if len(train_datasets) == 1:
664
+ train_dataset = train_datasets[0]
665
+
666
+ # Process many datasets
667
+ else:
668
+ train_dataset = torch.utils.data.ConcatDataset(train_datasets)
669
+
670
+ # DataLoaders creation:
671
+ train_dataloader = torch.utils.data.DataLoader(
672
+ train_dataset,
673
+ batch_size=train_batch_size,
674
+ shuffle=True
675
+ )
676
+
677
+ # Latents caching
678
+ cached_data_loader = handle_cache_latents(
679
+ cache_latents,
680
+ output_dir,
681
+ train_dataloader,
682
+ train_batch_size,
683
+ vae,
684
+ cached_latent_dir
685
+ )
686
+
687
+ if cached_data_loader is not None:
688
+ train_dataloader = cached_data_loader
689
+
690
+ # Prepare everything with our `accelerator`.
691
+ unet, optimizer, train_dataloader, lr_scheduler, text_encoder = accelerator.prepare(
692
+ unet,
693
+ optimizer,
694
+ train_dataloader,
695
+ lr_scheduler,
696
+ text_encoder
697
+ )
698
+
699
+ # Use Gradient Checkpointing if enabled.
700
+ unet_and_text_g_c(
701
+ unet,
702
+ text_encoder,
703
+ gradient_checkpointing,
704
+ text_encoder_gradient_checkpointing
705
+ )
706
+
707
+ # Enable VAE slicing to save memory.
708
+ vae.enable_slicing()
709
+
710
+ # For mixed precision training we cast the text_encoder and vae weights to half-precision
711
+ # as these models are only used for inference, keeping weights in full precision is not required.
712
+ weight_dtype = is_mixed_precision(accelerator)
713
+
714
+ # Move text encoders, and VAE to GPU
715
+ models_to_cast = [text_encoder, vae]
716
+ cast_to_gpu_and_type(models_to_cast, accelerator, weight_dtype)
717
+
718
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
719
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / gradient_accumulation_steps)
720
+
721
+ # Afterwards we recalculate our number of training epochs
722
+ num_train_epochs = math.ceil(max_train_steps / num_update_steps_per_epoch)
723
+
724
+ # We need to initialize the trackers we use, and also store our configuration.
725
+ # The trackers initialize automatically on the main process.
726
+ if accelerator.is_main_process:
727
+ accelerator.init_trackers("text2video-fine-tune")
728
+
729
+ # Train!
730
+ total_batch_size = train_batch_size * accelerator.num_processes * gradient_accumulation_steps
731
+
732
+ logger.info("***** Running training *****")
733
+ logger.info(f" Num examples = {len(train_dataset)}")
734
+ logger.info(f" Num Epochs = {num_train_epochs}")
735
+ logger.info(f" Instantaneous batch size per device = {train_batch_size}")
736
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
737
+ logger.info(f" Gradient Accumulation steps = {gradient_accumulation_steps}")
738
+ logger.info(f" Total optimization steps = {max_train_steps}")
739
+ global_step = 0
740
+ first_epoch = 0
741
+
742
+ # Only show the progress bar once on each machine.
743
+ progress_bar = tqdm(range(global_step, max_train_steps), disable=not accelerator.is_local_main_process)
744
+ progress_bar.set_description("Steps")
745
+
746
+ def finetune_unet(batch, train_encoder=False):
747
+
748
+ # Check if we are training the text encoder
749
+ text_trainable = (train_text_encoder or use_text_lora)
750
+
751
+ # Unfreeze UNET Layers
752
+ if global_step == 0:
753
+ already_printed_trainables = False
754
+ unet.train()
755
+ handle_trainable_modules(
756
+ unet,
757
+ trainable_modules,
758
+ is_enabled=True,
759
+ negation=unet_negation
760
+ )
761
+
762
+ # Convert videos to latent space
763
+ pixel_values = batch["pixel_values"]
764
+
765
+ if not cache_latents:
766
+ latents = tensor_to_vae_latent(pixel_values, vae)
767
+ else:
768
+ latents = pixel_values
769
+
770
+ # Get video length
771
+ video_length = latents.shape[2]
772
+
773
+ # Sample noise that we'll add to the latents
774
+ noise = sample_noise(latents, offset_noise_strength, use_offset_noise)
775
+ bsz = latents.shape[0]
776
+
777
+ # Sample a random timestep for each video
778
+ timesteps = torch.randint(0, noise_scheduler.num_train_timesteps, (bsz,), device=latents.device)
779
+ timesteps = timesteps.long()
780
+
781
+ # Add noise to the latents according to the noise magnitude at each timestep
782
+ # (this is the forward diffusion process)
783
+ noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
784
+
785
+ # Enable text encoder training
786
+ if text_trainable:
787
+ text_encoder.train()
788
+
789
+ if use_text_lora:
790
+ text_encoder.text_model.embeddings.requires_grad_(True)
791
+
792
+ if global_step == 0 and train_text_encoder:
793
+ handle_trainable_modules(
794
+ text_encoder,
795
+ trainable_modules=trainable_text_modules,
796
+ negation=text_encoder_negation
797
+ )
798
+ cast_to_gpu_and_type([text_encoder], accelerator, torch.float32)
799
+
800
+ # Fixes gradient checkpointing training.
801
+ # See: https://github.com/prigoyal/pytorch_memonger/blob/master/tutorial/Checkpointing_for_PyTorch_models.ipynb
802
+ if gradient_checkpointing or text_encoder_gradient_checkpointing:
803
+ unet.eval()
804
+ text_encoder.eval()
805
+
806
+ # Encode text embeddings
807
+ token_ids = batch['prompt_ids']
808
+ encoder_hidden_states = text_encoder(token_ids)[0]
809
+
810
+ # Get the target for loss depending on the prediction type
811
+ if noise_scheduler.prediction_type == "epsilon":
812
+ target = noise
813
+
814
+ elif noise_scheduler.prediction_type == "v_prediction":
815
+ target = noise_scheduler.get_velocity(latents, noise, timesteps)
816
+
817
+ else:
818
+ raise ValueError(f"Unknown prediction type {noise_scheduler.prediction_type}")
819
+
820
+
821
+ # Here we do two passes for video and text training.
822
+ # If we are on the second iteration of the loop, get one frame.
823
+ # This allows us to train text information only on the spatial layers.
824
+ losses = []
825
+ should_truncate_video = (video_length > 1 and text_trainable)
826
+
827
+ # We detach the encoder hidden states for the first pass (video frames > 1)
828
+ # Then we make a clone of the initial state to ensure we can train it in the loop.
829
+ detached_encoder_state = encoder_hidden_states.clone().detach()
830
+ trainable_encoder_state = encoder_hidden_states.clone()
831
+
832
+ for i in range(2):
833
+
834
+ should_detach = noisy_latents.shape[2] > 1 and i == 0
835
+
836
+ if should_truncate_video and i == 1:
837
+ noisy_latents = noisy_latents[:,:,1,:,:].unsqueeze(2)
838
+ target = target[:,:,1,:,:].unsqueeze(2)
839
+
840
+ encoder_hidden_states = (
841
+ detached_encoder_state if should_detach else trainable_encoder_state
842
+ )
843
+
844
+ model_pred = unet(noisy_latents, timesteps, encoder_hidden_states=encoder_hidden_states).sample
845
+ loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
846
+
847
+ losses.append(loss)
848
+
849
+ # This was most likely single frame training or a single image.
850
+ if video_length == 1 and i == 0: break
851
+
852
+ loss = losses[0] if len(losses) == 1 else losses[0] + losses[1]
853
+
854
+ return loss, latents
855
+
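To make the two-pass objective above concrete: the first pass runs the full clip with the text embedding detached, the second pass runs a single truncated frame with the trainable embedding, and the two MSE losses are summed. A sketch of the intent for the epsilon-prediction case (v-prediction swaps the target accordingly):

```latex
\mathcal{L}
= \underbrace{\bigl\lVert \epsilon_\theta(z_t, t, c_{\text{detached}}) - \epsilon \bigr\rVert_2^2}_{\text{pass 1: all frames, text gradients off}}
+ \underbrace{\bigl\lVert \epsilon_\theta(z_t^{(1)}, t, c) - \epsilon^{(1)} \bigr\rVert_2^2}_{\text{pass 2: one frame, text gradients on}}
```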
856
+ for epoch in range(first_epoch, num_train_epochs):
857
+ train_loss = 0.0
858
+
859
+ for step, batch in enumerate(train_dataloader):
860
+ # Skip steps until we reach the resumed step
861
+ if resume_from_checkpoint and epoch == first_epoch and step < resume_step:
862
+ if step % gradient_accumulation_steps == 0:
863
+ progress_bar.update(1)
864
+ continue
865
+
866
+ with accelerator.accumulate(unet), accelerator.accumulate(text_encoder):
867
+
868
+ text_prompt = batch['text_prompt'][0]
869
+
870
+ with accelerator.autocast():
871
+ loss, latents = finetune_unet(batch, train_encoder=train_text_encoder)
872
+
873
+ # Gather the losses across all processes for logging (if we use distributed training).
874
+ avg_loss = accelerator.gather(loss.repeat(train_batch_size)).mean()
875
+ train_loss += avg_loss.item() / gradient_accumulation_steps
876
+
877
+ # Backpropagate
878
+ try:
879
+ accelerator.backward(loss)
880
+ params_to_clip = (
881
+ unet.parameters() if not train_text_encoder
882
+ else
883
+ list(unet.parameters()) + list(text_encoder.parameters())
884
+ )
885
+ accelerator.clip_grad_norm_(params_to_clip, max_grad_norm)
886
+
887
+ optimizer.step()
888
+ lr_scheduler.step()
889
+ optimizer.zero_grad(set_to_none=True)
890
+
891
+ except Exception as e:
892
+ print(f"An error has occurred during backpropagation! {e}")
893
+ continue
894
+
895
+ # Checks if the accelerator has performed an optimization step behind the scenes
896
+ if accelerator.sync_gradients:
897
+ progress_bar.update(1)
898
+ global_step += 1
899
+ accelerator.log({"train_loss": train_loss}, step=global_step)
900
+ train_loss = 0.0
901
+
902
+ if global_step % checkpointing_steps == 0:
903
+ save_pipe(
904
+ pretrained_model_path,
905
+ global_step,
906
+ accelerator,
907
+ unet,
908
+ text_encoder,
909
+ vae,
910
+ output_dir,
911
+ use_unet_lora,
912
+ use_text_lora,
913
+ unet_lora_modules,
914
+ text_encoder_lora_modules,
915
+ is_checkpoint=True
916
+ )
917
+
918
+ if should_sample(global_step, validation_steps, validation_data):
919
+ if global_step == 1: print("Running validation prompt.")
920
+ if accelerator.is_main_process:
921
+
922
+ with accelerator.autocast():
923
+ unet.eval()
924
+ text_encoder.eval()
925
+ unet_and_text_g_c(unet, text_encoder, False, False)
926
+
927
+ pipeline = TextToVideoSDPipeline.from_pretrained(
928
+ pretrained_model_path,
929
+ text_encoder=text_encoder,
930
+ vae=vae,
931
+ unet=unet
932
+ )
933
+
934
+ diffusion_scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
935
+ pipeline.scheduler = diffusion_scheduler
936
+
937
+ prompt = text_prompt if len(validation_data.prompt) <= 0 else validation_data.prompt
938
+
939
+ curr_dataset_name = batch['dataset']
940
+ save_filename = f"{global_step}_dataset-{curr_dataset_name}_{prompt}"
941
+
942
+ out_file = f"{output_dir}/samples/{save_filename}.mp4"
943
+
944
+ with torch.no_grad():
945
+ video_frames = pipeline(
946
+ prompt,
947
+ width=validation_data.width,
948
+ height=validation_data.height,
949
+ num_frames=validation_data.num_frames,
950
+ num_inference_steps=validation_data.num_inference_steps,
951
+ guidance_scale=validation_data.guidance_scale
952
+ ).frames
953
+ export_to_video(video_frames, out_file, train_data.get('fps', 8))
954
+
955
+ del pipeline
956
+ torch.cuda.empty_cache()
957
+
958
+ logger.info(f"Saved a new sample to {out_file}")
959
+
960
+ unet_and_text_g_c(
961
+ unet,
962
+ text_encoder,
963
+ gradient_checkpointing,
964
+ text_encoder_gradient_checkpointing
965
+ )
966
+
967
+ logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
968
+ accelerator.log({"training_loss": loss.detach().item()}, step=step)
969
+ progress_bar.set_postfix(**logs)
970
+
971
+ if global_step >= max_train_steps:
972
+ break
973
+
974
+ # Create the pipeline using the trained modules and save it.
975
+ accelerator.wait_for_everyone()
976
+ if accelerator.is_main_process:
977
+ save_pipe(
978
+ pretrained_model_path,
979
+ global_step,
980
+ accelerator,
981
+ unet,
982
+ text_encoder,
983
+ vae,
984
+ output_dir,
985
+ use_unet_lora,
986
+ use_text_lora,
987
+ unet_lora_modules,
988
+ text_encoder_lora_modules,
989
+ is_checkpoint=False
990
+ )
991
+ accelerator.end_training()
992
+
993
+ if __name__ == "__main__":
994
+ parser = argparse.ArgumentParser()
995
+ parser.add_argument("--config", type=str, default="./configs/my_config.yaml")
996
+ args = parser.parse_args()
997
+
998
+ main(**OmegaConf.load(args.config))
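For orientation, a minimal sketch (not part of the commit) of driving `main()` from Python rather than through the `--config` flag. The top-level keys mirror `main()`'s signature, and the nested `validation_data`/`train_data` keys are the ones the loop above reads (`width`, `height`, `num_frames`, `num_inference_steps`, `guidance_scale`, `prompt`, `fps`); the checkpoint name and any dataset-specific keys are assumptions.

```python
# Hedged sketch: invoking train.py's main() programmatically.
# The model path is an example and train_data is deliberately incomplete;
# the real dataset keys depend on dataset.py.
from omegaconf import OmegaConf
from train import main

config = OmegaConf.create({
    "pretrained_model_path": "cerspense/zeroscope_v2_576w",  # assumed example checkpoint
    "output_dir": "./outputs/train_run",
    "dataset_types": ["json"],
    "train_data": {"fps": 8},            # dataset-specific keys omitted here
    "validation_data": {
        "prompt": "",                    # empty -> reuse the current batch's prompt
        "width": 576,
        "height": 320,
        "num_frames": 16,
        "num_inference_steps": 25,
        "guidance_scale": 9.0,
    },
    "max_train_steps": 500,
    "validation_steps": 100,
})

if __name__ == "__main__":
    main(**config)
```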
unet_3d_blocks.py ADDED
@@ -0,0 +1,836 @@
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import torch
16
+ import torch.utils.checkpoint as checkpoint
17
+ from torch import nn
18
+ from diffusers.models.resnet import Downsample2D, ResnetBlock2D, TemporalConvLayer, Upsample2D
19
+ from diffusers.models.transformer_2d import Transformer2DModel
20
+ from diffusers.models.transformer_temporal import TransformerTemporalModel
21
+
22
+ # Assign the gradient checkpointing function to a short alias for readability.
23
+ g_c = checkpoint.checkpoint
24
+
25
+ def use_temporal(module, num_frames, x):
26
+ if num_frames == 1:
27
+ if isinstance(module, TransformerTemporalModel):
28
+ return {"sample": x}
29
+ else:
30
+ return x
31
+
32
+ def custom_checkpoint(module, mode=None):
33
+ if mode is None: raise ValueError('Mode for gradient checkpointing cannot be None.')
34
+ custom_forward = None
35
+
36
+ if mode == 'resnet':
37
+ def custom_forward(hidden_states, temb):
38
+ inputs = module(hidden_states, temb)
39
+ return inputs
40
+
41
+ if mode == 'attn':
42
+ def custom_forward(
43
+ hidden_states,
44
+ encoder_hidden_states=None,
45
+ cross_attention_kwargs=None
46
+ ):
47
+ inputs = module(
48
+ hidden_states,
49
+ encoder_hidden_states,
50
+ cross_attention_kwargs
51
+ )
52
+ return inputs
53
+
54
+ if mode == 'temp':
55
+ def custom_forward(hidden_states, num_frames=None):
56
+ inputs = use_temporal(module, num_frames, hidden_states)
57
+ if inputs is None: inputs = module(
58
+ hidden_states,
59
+ num_frames=num_frames
60
+ )
61
+ return inputs
62
+
63
+ return custom_forward
64
+
65
+ def transformer_g_c(transformer, sample, num_frames):
66
+ sample = g_c(custom_checkpoint(transformer, mode='temp'),
67
+ sample, num_frames, use_reentrant=False
68
+ )['sample']
69
+
70
+ return sample
71
+
72
+ def cross_attn_g_c(
73
+ attn,
74
+ temp_attn,
75
+ resnet,
76
+ temp_conv,
77
+ hidden_states,
78
+ encoder_hidden_states,
79
+ cross_attention_kwargs,
80
+ temb,
81
+ num_frames,
82
+ inverse_temp=False
83
+ ):
84
+
85
+ def ordered_g_c(idx):
86
+
87
+ # Self and CrossAttention
88
+ if idx == 0: return g_c(custom_checkpoint(attn, mode='attn'),
89
+ hidden_states, encoder_hidden_states,cross_attention_kwargs, use_reentrant=False
90
+ )['sample']
91
+
92
+ # Temporal Self and CrossAttention
93
+ if idx == 1: return g_c(custom_checkpoint(temp_attn, mode='temp'),
94
+ hidden_states, num_frames, use_reentrant=False)['sample']
95
+
96
+ # Resnets
97
+ if idx == 2: return g_c(custom_checkpoint(resnet, mode='resnet'),
98
+ hidden_states, temb, use_reentrant=False)
99
+
100
+ # Temporal Convolutions
101
+ if idx == 3: return g_c(custom_checkpoint(temp_conv, mode='temp'),
102
+ hidden_states, num_frames, use_reentrant=False
103
+ )
104
+
105
+ # Here we call the function depending on the order in which they are called.
106
+ # For some layers, the orders are different, so we access the appropriate one by index.
107
+
108
+ if not inverse_temp:
109
+ for idx in [0,1,2,3]: hidden_states = ordered_g_c(idx)
110
+ else:
111
+ for idx in [2,3,0,1]: hidden_states = ordered_g_c(idx)
112
+
113
+ return hidden_states
114
+
115
+ def up_down_g_c(resnet, temp_conv, hidden_states, temb, num_frames):
116
+ hidden_states = g_c(custom_checkpoint(resnet, mode='resnet'), hidden_states, temb, use_reentrant=False)
117
+ hidden_states = g_c(custom_checkpoint(temp_conv, mode='temp'),
118
+ hidden_states, num_frames, use_reentrant=False
119
+ )
120
+ return hidden_states
121
+
122
+ def get_down_block(
123
+ down_block_type,
124
+ num_layers,
125
+ in_channels,
126
+ out_channels,
127
+ temb_channels,
128
+ add_downsample,
129
+ resnet_eps,
130
+ resnet_act_fn,
131
+ attn_num_head_channels,
132
+ resnet_groups=None,
133
+ cross_attention_dim=None,
134
+ downsample_padding=None,
135
+ dual_cross_attention=False,
136
+ use_linear_projection=True,
137
+ only_cross_attention=False,
138
+ upcast_attention=False,
139
+ resnet_time_scale_shift="default",
140
+ ):
141
+ if down_block_type == "DownBlock3D":
142
+ return DownBlock3D(
143
+ num_layers=num_layers,
144
+ in_channels=in_channels,
145
+ out_channels=out_channels,
146
+ temb_channels=temb_channels,
147
+ add_downsample=add_downsample,
148
+ resnet_eps=resnet_eps,
149
+ resnet_act_fn=resnet_act_fn,
150
+ resnet_groups=resnet_groups,
151
+ downsample_padding=downsample_padding,
152
+ resnet_time_scale_shift=resnet_time_scale_shift,
153
+ )
154
+ elif down_block_type == "CrossAttnDownBlock3D":
155
+ if cross_attention_dim is None:
156
+ raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlock3D")
157
+ return CrossAttnDownBlock3D(
158
+ num_layers=num_layers,
159
+ in_channels=in_channels,
160
+ out_channels=out_channels,
161
+ temb_channels=temb_channels,
162
+ add_downsample=add_downsample,
163
+ resnet_eps=resnet_eps,
164
+ resnet_act_fn=resnet_act_fn,
165
+ resnet_groups=resnet_groups,
166
+ downsample_padding=downsample_padding,
167
+ cross_attention_dim=cross_attention_dim,
168
+ attn_num_head_channels=attn_num_head_channels,
169
+ dual_cross_attention=dual_cross_attention,
170
+ use_linear_projection=use_linear_projection,
171
+ only_cross_attention=only_cross_attention,
172
+ upcast_attention=upcast_attention,
173
+ resnet_time_scale_shift=resnet_time_scale_shift,
174
+ )
175
+ raise ValueError(f"{down_block_type} does not exist.")
176
+
177
+
178
+ def get_up_block(
179
+ up_block_type,
180
+ num_layers,
181
+ in_channels,
182
+ out_channels,
183
+ prev_output_channel,
184
+ temb_channels,
185
+ add_upsample,
186
+ resnet_eps,
187
+ resnet_act_fn,
188
+ attn_num_head_channels,
189
+ resnet_groups=None,
190
+ cross_attention_dim=None,
191
+ dual_cross_attention=False,
192
+ use_linear_projection=True,
193
+ only_cross_attention=False,
194
+ upcast_attention=False,
195
+ resnet_time_scale_shift="default",
196
+ ):
197
+ if up_block_type == "UpBlock3D":
198
+ return UpBlock3D(
199
+ num_layers=num_layers,
200
+ in_channels=in_channels,
201
+ out_channels=out_channels,
202
+ prev_output_channel=prev_output_channel,
203
+ temb_channels=temb_channels,
204
+ add_upsample=add_upsample,
205
+ resnet_eps=resnet_eps,
206
+ resnet_act_fn=resnet_act_fn,
207
+ resnet_groups=resnet_groups,
208
+ resnet_time_scale_shift=resnet_time_scale_shift,
209
+ )
210
+ elif up_block_type == "CrossAttnUpBlock3D":
211
+ if cross_attention_dim is None:
212
+ raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlock3D")
213
+ return CrossAttnUpBlock3D(
214
+ num_layers=num_layers,
215
+ in_channels=in_channels,
216
+ out_channels=out_channels,
217
+ prev_output_channel=prev_output_channel,
218
+ temb_channels=temb_channels,
219
+ add_upsample=add_upsample,
220
+ resnet_eps=resnet_eps,
221
+ resnet_act_fn=resnet_act_fn,
222
+ resnet_groups=resnet_groups,
223
+ cross_attention_dim=cross_attention_dim,
224
+ attn_num_head_channels=attn_num_head_channels,
225
+ dual_cross_attention=dual_cross_attention,
226
+ use_linear_projection=use_linear_projection,
227
+ only_cross_attention=only_cross_attention,
228
+ upcast_attention=upcast_attention,
229
+ resnet_time_scale_shift=resnet_time_scale_shift,
230
+ )
231
+ raise ValueError(f"{up_block_type} does not exist.")
232
+
233
+
234
+ class UNetMidBlock3DCrossAttn(nn.Module):
235
+ def __init__(
236
+ self,
237
+ in_channels: int,
238
+ temb_channels: int,
239
+ dropout: float = 0.0,
240
+ num_layers: int = 1,
241
+ resnet_eps: float = 1e-6,
242
+ resnet_time_scale_shift: str = "default",
243
+ resnet_act_fn: str = "swish",
244
+ resnet_groups: int = 32,
245
+ resnet_pre_norm: bool = True,
246
+ attn_num_head_channels=1,
247
+ output_scale_factor=1.0,
248
+ cross_attention_dim=1280,
249
+ dual_cross_attention=False,
250
+ use_linear_projection=True,
251
+ upcast_attention=False,
252
+ ):
253
+ super().__init__()
254
+
255
+ self.gradient_checkpointing = False
256
+ self.has_cross_attention = True
257
+ self.attn_num_head_channels = attn_num_head_channels
258
+ resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
259
+
260
+ # there is always at least one resnet
261
+ resnets = [
262
+ ResnetBlock2D(
263
+ in_channels=in_channels,
264
+ out_channels=in_channels,
265
+ temb_channels=temb_channels,
266
+ eps=resnet_eps,
267
+ groups=resnet_groups,
268
+ dropout=dropout,
269
+ time_embedding_norm=resnet_time_scale_shift,
270
+ non_linearity=resnet_act_fn,
271
+ output_scale_factor=output_scale_factor,
272
+ pre_norm=resnet_pre_norm,
273
+ )
274
+ ]
275
+ temp_convs = [
276
+ TemporalConvLayer(
277
+ in_channels,
278
+ in_channels,
279
+ )
280
+ ]
281
+ attentions = []
282
+ temp_attentions = []
283
+
284
+ for _ in range(num_layers):
285
+ attentions.append(
286
+ Transformer2DModel(
287
+ in_channels // attn_num_head_channels,
288
+ attn_num_head_channels,
289
+ in_channels=in_channels,
290
+ num_layers=1,
291
+ cross_attention_dim=cross_attention_dim,
292
+ norm_num_groups=resnet_groups,
293
+ use_linear_projection=use_linear_projection,
294
+ upcast_attention=upcast_attention,
295
+ )
296
+ )
297
+ temp_attentions.append(
298
+ TransformerTemporalModel(
299
+ in_channels // attn_num_head_channels,
300
+ attn_num_head_channels,
301
+ in_channels=in_channels,
302
+ num_layers=1,
303
+ cross_attention_dim=cross_attention_dim,
304
+ norm_num_groups=resnet_groups,
305
+ )
306
+ )
307
+ resnets.append(
308
+ ResnetBlock2D(
309
+ in_channels=in_channels,
310
+ out_channels=in_channels,
311
+ temb_channels=temb_channels,
312
+ eps=resnet_eps,
313
+ groups=resnet_groups,
314
+ dropout=dropout,
315
+ time_embedding_norm=resnet_time_scale_shift,
316
+ non_linearity=resnet_act_fn,
317
+ output_scale_factor=output_scale_factor,
318
+ pre_norm=resnet_pre_norm,
319
+ )
320
+ )
321
+ temp_convs.append(
322
+ TemporalConvLayer(
323
+ in_channels,
324
+ in_channels,
325
+ )
326
+ )
327
+
328
+ self.resnets = nn.ModuleList(resnets)
329
+ self.temp_convs = nn.ModuleList(temp_convs)
330
+ self.attentions = nn.ModuleList(attentions)
331
+ self.temp_attentions = nn.ModuleList(temp_attentions)
332
+
333
+ def forward(
334
+ self,
335
+ hidden_states,
336
+ temb=None,
337
+ encoder_hidden_states=None,
338
+ attention_mask=None,
339
+ num_frames=1,
340
+ cross_attention_kwargs=None,
341
+ ):
342
+ if self.gradient_checkpointing:
343
+ hidden_states = up_down_g_c(
344
+ self.resnets[0],
345
+ self.temp_convs[0],
346
+ hidden_states,
347
+ temb,
348
+ num_frames
349
+ )
350
+ else:
351
+ hidden_states = self.resnets[0](hidden_states, temb)
352
+ hidden_states = self.temp_convs[0](hidden_states, num_frames=num_frames)
353
+
354
+ for attn, temp_attn, resnet, temp_conv in zip(
355
+ self.attentions, self.temp_attentions, self.resnets[1:], self.temp_convs[1:]
356
+ ):
357
+ if self.gradient_checkpointing:
358
+ hidden_states = cross_attn_g_c(
359
+ attn,
360
+ temp_attn,
361
+ resnet,
362
+ temp_conv,
363
+ hidden_states,
364
+ encoder_hidden_states,
365
+ cross_attention_kwargs,
366
+ temb,
367
+ num_frames
368
+ )
369
+ else:
370
+ hidden_states = attn(
371
+ hidden_states,
372
+ encoder_hidden_states=encoder_hidden_states,
373
+ cross_attention_kwargs=cross_attention_kwargs,
374
+ ).sample
375
+
376
+ if num_frames > 1:
377
+ hidden_states = temp_attn(hidden_states, num_frames=num_frames).sample
378
+
379
+ hidden_states = resnet(hidden_states, temb)
380
+
381
+ if num_frames > 1:
382
+ hidden_states = temp_conv(hidden_states, num_frames=num_frames)
383
+
384
+ return hidden_states
385
+
386
+
387
+ class CrossAttnDownBlock3D(nn.Module):
388
+ def __init__(
389
+ self,
390
+ in_channels: int,
391
+ out_channels: int,
392
+ temb_channels: int,
393
+ dropout: float = 0.0,
394
+ num_layers: int = 1,
395
+ resnet_eps: float = 1e-6,
396
+ resnet_time_scale_shift: str = "default",
397
+ resnet_act_fn: str = "swish",
398
+ resnet_groups: int = 32,
399
+ resnet_pre_norm: bool = True,
400
+ attn_num_head_channels=1,
401
+ cross_attention_dim=1280,
402
+ output_scale_factor=1.0,
403
+ downsample_padding=1,
404
+ add_downsample=True,
405
+ dual_cross_attention=False,
406
+ use_linear_projection=False,
407
+ only_cross_attention=False,
408
+ upcast_attention=False,
409
+ ):
410
+ super().__init__()
411
+ resnets = []
412
+ attentions = []
413
+ temp_attentions = []
414
+ temp_convs = []
415
+
416
+ self.gradient_checkpointing = False
417
+ self.has_cross_attention = True
418
+ self.attn_num_head_channels = attn_num_head_channels
419
+
420
+ for i in range(num_layers):
421
+ in_channels = in_channels if i == 0 else out_channels
422
+ resnets.append(
423
+ ResnetBlock2D(
424
+ in_channels=in_channels,
425
+ out_channels=out_channels,
426
+ temb_channels=temb_channels,
427
+ eps=resnet_eps,
428
+ groups=resnet_groups,
429
+ dropout=dropout,
430
+ time_embedding_norm=resnet_time_scale_shift,
431
+ non_linearity=resnet_act_fn,
432
+ output_scale_factor=output_scale_factor,
433
+ pre_norm=resnet_pre_norm,
434
+ )
435
+ )
436
+ temp_convs.append(
437
+ TemporalConvLayer(
438
+ out_channels,
439
+ out_channels,
440
+ )
441
+ )
442
+ attentions.append(
443
+ Transformer2DModel(
444
+ out_channels // attn_num_head_channels,
445
+ attn_num_head_channels,
446
+ in_channels=out_channels,
447
+ num_layers=1,
448
+ cross_attention_dim=cross_attention_dim,
449
+ norm_num_groups=resnet_groups,
450
+ use_linear_projection=use_linear_projection,
451
+ only_cross_attention=only_cross_attention,
452
+ upcast_attention=upcast_attention,
453
+ )
454
+ )
455
+ temp_attentions.append(
456
+ TransformerTemporalModel(
457
+ out_channels // attn_num_head_channels,
458
+ attn_num_head_channels,
459
+ in_channels=out_channels,
460
+ num_layers=1,
461
+ cross_attention_dim=cross_attention_dim,
462
+ norm_num_groups=resnet_groups,
463
+ )
464
+ )
465
+ self.resnets = nn.ModuleList(resnets)
466
+ self.temp_convs = nn.ModuleList(temp_convs)
467
+ self.attentions = nn.ModuleList(attentions)
468
+ self.temp_attentions = nn.ModuleList(temp_attentions)
469
+
470
+ if add_downsample:
471
+ self.downsamplers = nn.ModuleList(
472
+ [
473
+ Downsample2D(
474
+ out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
475
+ )
476
+ ]
477
+ )
478
+ else:
479
+ self.downsamplers = None
480
+
481
+ def forward(
482
+ self,
483
+ hidden_states,
484
+ temb=None,
485
+ encoder_hidden_states=None,
486
+ attention_mask=None,
487
+ num_frames=1,
488
+ cross_attention_kwargs=None,
489
+ ):
490
+ # TODO(Patrick, William) - attention mask is not used
491
+ output_states = ()
492
+
493
+ for resnet, temp_conv, attn, temp_attn in zip(
494
+ self.resnets, self.temp_convs, self.attentions, self.temp_attentions
495
+ ):
496
+
497
+ if self.gradient_checkpointing:
498
+ hidden_states = cross_attn_g_c(
499
+ attn,
500
+ temp_attn,
501
+ resnet,
502
+ temp_conv,
503
+ hidden_states,
504
+ encoder_hidden_states,
505
+ cross_attention_kwargs,
506
+ temb,
507
+ num_frames,
508
+ inverse_temp=True
509
+ )
510
+ else:
511
+ hidden_states = resnet(hidden_states, temb)
512
+
513
+ if num_frames > 1:
514
+ hidden_states = temp_conv(hidden_states, num_frames=num_frames)
515
+
516
+ hidden_states = attn(
517
+ hidden_states,
518
+ encoder_hidden_states=encoder_hidden_states,
519
+ cross_attention_kwargs=cross_attention_kwargs,
520
+ ).sample
521
+
522
+ if num_frames > 1:
523
+ hidden_states = temp_attn(hidden_states, num_frames=num_frames).sample
524
+
525
+ output_states += (hidden_states,)
526
+
527
+ if self.downsamplers is not None:
528
+ for downsampler in self.downsamplers:
529
+ hidden_states = downsampler(hidden_states)
530
+
531
+ output_states += (hidden_states,)
532
+
533
+ return hidden_states, output_states
534
+
535
+
536
+ class DownBlock3D(nn.Module):
537
+ def __init__(
538
+ self,
539
+ in_channels: int,
540
+ out_channels: int,
541
+ temb_channels: int,
542
+ dropout: float = 0.0,
543
+ num_layers: int = 1,
544
+ resnet_eps: float = 1e-6,
545
+ resnet_time_scale_shift: str = "default",
546
+ resnet_act_fn: str = "swish",
547
+ resnet_groups: int = 32,
548
+ resnet_pre_norm: bool = True,
549
+ output_scale_factor=1.0,
550
+ add_downsample=True,
551
+ downsample_padding=1,
552
+ ):
553
+ super().__init__()
554
+ resnets = []
555
+ temp_convs = []
556
+
557
+ self.gradient_checkpointing = False
558
+ for i in range(num_layers):
559
+ in_channels = in_channels if i == 0 else out_channels
560
+ resnets.append(
561
+ ResnetBlock2D(
562
+ in_channels=in_channels,
563
+ out_channels=out_channels,
564
+ temb_channels=temb_channels,
565
+ eps=resnet_eps,
566
+ groups=resnet_groups,
567
+ dropout=dropout,
568
+ time_embedding_norm=resnet_time_scale_shift,
569
+ non_linearity=resnet_act_fn,
570
+ output_scale_factor=output_scale_factor,
571
+ pre_norm=resnet_pre_norm,
572
+ )
573
+ )
574
+ temp_convs.append(
575
+ TemporalConvLayer(
576
+ out_channels,
577
+ out_channels,
578
+ )
579
+ )
580
+
581
+ self.resnets = nn.ModuleList(resnets)
582
+ self.temp_convs = nn.ModuleList(temp_convs)
583
+
584
+ if add_downsample:
585
+ self.downsamplers = nn.ModuleList(
586
+ [
587
+ Downsample2D(
588
+ out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
589
+ )
590
+ ]
591
+ )
592
+ else:
593
+ self.downsamplers = None
594
+
595
+ def forward(self, hidden_states, temb=None, num_frames=1):
596
+ output_states = ()
597
+
598
+ for resnet, temp_conv in zip(self.resnets, self.temp_convs):
599
+ if self.gradient_checkpointing:
600
+ hidden_states = up_down_g_c(resnet, temp_conv, hidden_states, temb, num_frames)
601
+ else:
602
+ hidden_states = resnet(hidden_states, temb)
603
+
604
+ if num_frames > 1:
605
+ hidden_states = temp_conv(hidden_states, num_frames=num_frames)
606
+
607
+ output_states += (hidden_states,)
608
+
609
+ if self.downsamplers is not None:
610
+ for downsampler in self.downsamplers:
611
+ hidden_states = downsampler(hidden_states)
612
+
613
+ output_states += (hidden_states,)
614
+
615
+ return hidden_states, output_states
616
+
617
+
618
+ class CrossAttnUpBlock3D(nn.Module):
619
+ def __init__(
620
+ self,
621
+ in_channels: int,
622
+ out_channels: int,
623
+ prev_output_channel: int,
624
+ temb_channels: int,
625
+ dropout: float = 0.0,
626
+ num_layers: int = 1,
627
+ resnet_eps: float = 1e-6,
628
+ resnet_time_scale_shift: str = "default",
629
+ resnet_act_fn: str = "swish",
630
+ resnet_groups: int = 32,
631
+ resnet_pre_norm: bool = True,
632
+ attn_num_head_channels=1,
633
+ cross_attention_dim=1280,
634
+ output_scale_factor=1.0,
635
+ add_upsample=True,
636
+ dual_cross_attention=False,
637
+ use_linear_projection=False,
638
+ only_cross_attention=False,
639
+ upcast_attention=False,
640
+ ):
641
+ super().__init__()
642
+ resnets = []
643
+ temp_convs = []
644
+ attentions = []
645
+ temp_attentions = []
646
+
647
+ self.gradient_checkpointing = False
648
+ self.has_cross_attention = True
649
+ self.attn_num_head_channels = attn_num_head_channels
650
+
651
+ for i in range(num_layers):
652
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
653
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
654
+
655
+ resnets.append(
656
+ ResnetBlock2D(
657
+ in_channels=resnet_in_channels + res_skip_channels,
658
+ out_channels=out_channels,
659
+ temb_channels=temb_channels,
660
+ eps=resnet_eps,
661
+ groups=resnet_groups,
662
+ dropout=dropout,
663
+ time_embedding_norm=resnet_time_scale_shift,
664
+ non_linearity=resnet_act_fn,
665
+ output_scale_factor=output_scale_factor,
666
+ pre_norm=resnet_pre_norm,
667
+ )
668
+ )
669
+ temp_convs.append(
670
+ TemporalConvLayer(
671
+ out_channels,
672
+ out_channels,
673
+ )
674
+ )
675
+ attentions.append(
676
+ Transformer2DModel(
677
+ out_channels // attn_num_head_channels,
678
+ attn_num_head_channels,
679
+ in_channels=out_channels,
680
+ num_layers=1,
681
+ cross_attention_dim=cross_attention_dim,
682
+ norm_num_groups=resnet_groups,
683
+ use_linear_projection=use_linear_projection,
684
+ only_cross_attention=only_cross_attention,
685
+ upcast_attention=upcast_attention,
686
+ )
687
+ )
688
+ temp_attentions.append(
689
+ TransformerTemporalModel(
690
+ out_channels // attn_num_head_channels,
691
+ attn_num_head_channels,
692
+ in_channels=out_channels,
693
+ num_layers=1,
694
+ cross_attention_dim=cross_attention_dim,
695
+ norm_num_groups=resnet_groups,
696
+ )
697
+ )
698
+ self.resnets = nn.ModuleList(resnets)
699
+ self.temp_convs = nn.ModuleList(temp_convs)
700
+ self.attentions = nn.ModuleList(attentions)
701
+ self.temp_attentions = nn.ModuleList(temp_attentions)
702
+
703
+ if add_upsample:
704
+ self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
705
+ else:
706
+ self.upsamplers = None
707
+
708
+ def forward(
709
+ self,
710
+ hidden_states,
711
+ res_hidden_states_tuple,
712
+ temb=None,
713
+ encoder_hidden_states=None,
714
+ upsample_size=None,
715
+ attention_mask=None,
716
+ num_frames=1,
717
+ cross_attention_kwargs=None,
718
+ ):
719
+ # TODO(Patrick, William) - attention mask is not used
720
+ for resnet, temp_conv, attn, temp_attn in zip(
721
+ self.resnets, self.temp_convs, self.attentions, self.temp_attentions
722
+ ):
723
+ # pop res hidden states
724
+ res_hidden_states = res_hidden_states_tuple[-1]
725
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
726
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
727
+
728
+ if self.gradient_checkpointing:
729
+ hidden_states = cross_attn_g_c(
730
+ attn,
731
+ temp_attn,
732
+ resnet,
733
+ temp_conv,
734
+ hidden_states,
735
+ encoder_hidden_states,
736
+ cross_attention_kwargs,
737
+ temb,
738
+ num_frames,
739
+ inverse_temp=True
740
+ )
741
+ else:
742
+ hidden_states = resnet(hidden_states, temb)
743
+
744
+ if num_frames > 1:
745
+ hidden_states = temp_conv(hidden_states, num_frames=num_frames)
746
+
747
+ hidden_states = attn(
748
+ hidden_states,
749
+ encoder_hidden_states=encoder_hidden_states,
750
+ cross_attention_kwargs=cross_attention_kwargs,
751
+ ).sample
752
+
753
+ if num_frames > 1:
754
+ hidden_states = temp_attn(hidden_states, num_frames=num_frames).sample
755
+
756
+ if self.upsamplers is not None:
757
+ for upsampler in self.upsamplers:
758
+ hidden_states = upsampler(hidden_states, upsample_size)
759
+
760
+ return hidden_states
761
+
762
+
763
+ class UpBlock3D(nn.Module):
764
+ def __init__(
765
+ self,
766
+ in_channels: int,
767
+ prev_output_channel: int,
768
+ out_channels: int,
769
+ temb_channels: int,
770
+ dropout: float = 0.0,
771
+ num_layers: int = 1,
772
+ resnet_eps: float = 1e-6,
773
+ resnet_time_scale_shift: str = "default",
774
+ resnet_act_fn: str = "swish",
775
+ resnet_groups: int = 32,
776
+ resnet_pre_norm: bool = True,
777
+ output_scale_factor=1.0,
778
+ add_upsample=True,
779
+ ):
780
+ super().__init__()
781
+ resnets = []
782
+ temp_convs = []
783
+ self.gradient_checkpointing = False
784
+ for i in range(num_layers):
785
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
786
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
787
+
788
+ resnets.append(
789
+ ResnetBlock2D(
790
+ in_channels=resnet_in_channels + res_skip_channels,
791
+ out_channels=out_channels,
792
+ temb_channels=temb_channels,
793
+ eps=resnet_eps,
794
+ groups=resnet_groups,
795
+ dropout=dropout,
796
+ time_embedding_norm=resnet_time_scale_shift,
797
+ non_linearity=resnet_act_fn,
798
+ output_scale_factor=output_scale_factor,
799
+ pre_norm=resnet_pre_norm,
800
+ )
801
+ )
802
+ temp_convs.append(
803
+ TemporalConvLayer(
804
+ out_channels,
805
+ out_channels,
806
+ )
807
+ )
808
+
809
+ self.resnets = nn.ModuleList(resnets)
810
+ self.temp_convs = nn.ModuleList(temp_convs)
811
+
812
+ if add_upsample:
813
+ self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
814
+ else:
815
+ self.upsamplers = None
816
+
817
+ def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None, num_frames=1):
818
+ for resnet, temp_conv in zip(self.resnets, self.temp_convs):
819
+ # pop res hidden states
820
+ res_hidden_states = res_hidden_states_tuple[-1]
821
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
822
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
823
+
824
+ if self.gradient_checkpointing:
825
+ hidden_states = up_down_g_c(resnet, temp_conv, hidden_states, temb, num_frames)
826
+ else:
827
+ hidden_states = resnet(hidden_states, temb)
828
+
829
+ if num_frames > 1:
830
+ hidden_states = temp_conv(hidden_states, num_frames=num_frames)
831
+
832
+ if self.upsamplers is not None:
833
+ for upsampler in self.upsamplers:
834
+ hidden_states = upsampler(hidden_states, upsample_size)
835
+
836
+ return hidden_states
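As a quick sanity check of the pseudo-3D layout these blocks expect (frames folded into the batch dimension, with `num_frames` passed separately), here is a hedged sketch; the channel counts, resolution, and frame count are illustrative assumptions, not values from this commit.

```python
# Hedged sketch: pushing a dummy tensor through DownBlock3D.
import torch
from unet_3d_blocks import DownBlock3D

block = DownBlock3D(in_channels=320, out_channels=320, temb_channels=1280, num_layers=1)

batch, num_frames = 1, 8
hidden = torch.randn(batch * num_frames, 320, 32, 32)  # frames are folded into the batch dim
temb = torch.randn(batch * num_frames, 1280)           # time embedding, repeated per frame

with torch.no_grad():
    hidden, skip_states = block(hidden, temb=temb, num_frames=num_frames)

print(hidden.shape)      # spatially halved by the trailing Downsample2D
print(len(skip_states))  # one skip per resnet layer, plus one after downsampling
```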
unet_3d_condition.py ADDED
@@ -0,0 +1,499 @@
1
+ # Copyright 2023 Alibaba DAMO-VILAB and The HuggingFace Team. All rights reserved.
2
+ # Copyright 2023 The ModelScope Team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ from dataclasses import dataclass
16
+ from typing import Any, Dict, List, Optional, Tuple, Union
17
+
18
+ import torch
19
+ import torch.nn as nn
20
+ import torch.utils.checkpoint
21
+
22
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
23
+ from diffusers.utils import BaseOutput, logging
24
+ from diffusers.models.embeddings import TimestepEmbedding, Timesteps
25
+ from diffusers.models.modeling_utils import ModelMixin
26
+ from diffusers.models.transformer_temporal import TransformerTemporalModel
27
+ from .unet_3d_blocks import (
28
+ CrossAttnDownBlock3D,
29
+ CrossAttnUpBlock3D,
30
+ DownBlock3D,
31
+ UNetMidBlock3DCrossAttn,
32
+ UpBlock3D,
33
+ get_down_block,
34
+ get_up_block,
35
+ transformer_g_c
36
+ )
37
+
38
+
39
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
40
+
41
+
42
+ @dataclass
43
+ class UNet3DConditionOutput(BaseOutput):
44
+ """
45
+ Args:
46
+ sample (`torch.FloatTensor` of shape `(batch_size, num_frames, num_channels, height, width)`):
47
+ Hidden states conditioned on `encoder_hidden_states` input. Output of last layer of model.
48
+ """
49
+
50
+ sample: torch.FloatTensor
51
+
52
+
53
+ class UNet3DConditionModel(ModelMixin, ConfigMixin):
54
+ r"""
55
+ UNet3DConditionModel is a conditional 2D UNet model that takes in a noisy sample, conditional state, and a timestep
56
+ and returns sample shaped output.
57
+
58
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library
59
+ implements for all the models (such as downloading or saving, etc.)
60
+
61
+ Parameters:
62
+ sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
63
+ Height and width of input/output sample.
64
+ in_channels (`int`, *optional*, defaults to 4): The number of channels in the input sample.
65
+ out_channels (`int`, *optional*, defaults to 4): The number of channels in the output.
66
+ down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock3D", "CrossAttnDownBlock3D", "CrossAttnDownBlock3D", "DownBlock3D")`):
67
+ The tuple of downsample blocks to use.
68
+ up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D")`):
69
+ The tuple of upsample blocks to use.
70
+ block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
71
+ The tuple of output channels for each block.
72
+ layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.
73
+ downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution.
74
+ mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block.
75
+ act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
76
+ norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization.
77
+ If `None`, it will skip the normalization and activation layers in post-processing
78
+ norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
79
+ cross_attention_dim (`int`, *optional*, defaults to 1024): The dimension of the cross attention features.
80
+ attention_head_dim (`int` or `Tuple[int]`, *optional*, defaults to 64): The dimension of the attention heads.
81
+ """
82
+
83
+ _supports_gradient_checkpointing = True
84
+
85
+ @register_to_config
86
+ def __init__(
87
+ self,
88
+ sample_size: Optional[int] = None,
89
+ in_channels: int = 4,
90
+ out_channels: int = 4,
91
+ down_block_types: Tuple[str] = (
92
+ "CrossAttnDownBlock3D",
93
+ "CrossAttnDownBlock3D",
94
+ "CrossAttnDownBlock3D",
95
+ "DownBlock3D",
96
+ ),
97
+ up_block_types: Tuple[str] = ("UpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D"),
98
+ block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
99
+ layers_per_block: int = 2,
100
+ downsample_padding: int = 1,
101
+ mid_block_scale_factor: float = 1,
102
+ act_fn: str = "silu",
103
+ norm_num_groups: Optional[int] = 32,
104
+ norm_eps: float = 1e-5,
105
+ cross_attention_dim: int = 1024,
106
+ attention_head_dim: Union[int, Tuple[int]] = 64,
107
+ ):
108
+ super().__init__()
109
+
110
+ self.sample_size = sample_size
111
+ self.gradient_checkpointing = False
112
+ # Check inputs
113
+ if len(down_block_types) != len(up_block_types):
114
+ raise ValueError(
115
+ f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
116
+ )
117
+
118
+ if len(block_out_channels) != len(down_block_types):
119
+ raise ValueError(
120
+ f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
121
+ )
122
+
123
+ if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types):
124
+ raise ValueError(
125
+ f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}."
126
+ )
127
+
128
+ # input
129
+ conv_in_kernel = 3
130
+ conv_out_kernel = 3
131
+ conv_in_padding = (conv_in_kernel - 1) // 2
132
+ self.conv_in = nn.Conv2d(
133
+ in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
134
+ )
135
+
136
+ # time
137
+ time_embed_dim = block_out_channels[0] * 4
138
+ self.time_proj = Timesteps(block_out_channels[0], True, 0)
139
+ timestep_input_dim = block_out_channels[0]
140
+
141
+ self.time_embedding = TimestepEmbedding(
142
+ timestep_input_dim,
143
+ time_embed_dim,
144
+ act_fn=act_fn,
145
+ )
146
+
147
+ self.transformer_in = TransformerTemporalModel(
148
+ num_attention_heads=8,
149
+ attention_head_dim=attention_head_dim,
150
+ in_channels=block_out_channels[0],
151
+ num_layers=1,
152
+ )
153
+
154
+ # class embedding
155
+ self.down_blocks = nn.ModuleList([])
156
+ self.up_blocks = nn.ModuleList([])
157
+
158
+ if isinstance(attention_head_dim, int):
159
+ attention_head_dim = (attention_head_dim,) * len(down_block_types)
160
+
161
+ # down
162
+ output_channel = block_out_channels[0]
163
+ for i, down_block_type in enumerate(down_block_types):
164
+ input_channel = output_channel
165
+ output_channel = block_out_channels[i]
166
+ is_final_block = i == len(block_out_channels) - 1
167
+
168
+ down_block = get_down_block(
169
+ down_block_type,
170
+ num_layers=layers_per_block,
171
+ in_channels=input_channel,
172
+ out_channels=output_channel,
173
+ temb_channels=time_embed_dim,
174
+ add_downsample=not is_final_block,
175
+ resnet_eps=norm_eps,
176
+ resnet_act_fn=act_fn,
177
+ resnet_groups=norm_num_groups,
178
+ cross_attention_dim=cross_attention_dim,
179
+ attn_num_head_channels=attention_head_dim[i],
180
+ downsample_padding=downsample_padding,
181
+ dual_cross_attention=False,
182
+ )
183
+ self.down_blocks.append(down_block)
184
+
185
+ # mid
186
+ self.mid_block = UNetMidBlock3DCrossAttn(
187
+ in_channels=block_out_channels[-1],
188
+ temb_channels=time_embed_dim,
189
+ resnet_eps=norm_eps,
190
+ resnet_act_fn=act_fn,
191
+ output_scale_factor=mid_block_scale_factor,
192
+ cross_attention_dim=cross_attention_dim,
193
+ attn_num_head_channels=attention_head_dim[-1],
194
+ resnet_groups=norm_num_groups,
195
+ dual_cross_attention=False,
196
+ )
197
+
198
+ # count how many layers upsample the images
199
+ self.num_upsamplers = 0
200
+
201
+ # up
202
+ reversed_block_out_channels = list(reversed(block_out_channels))
203
+ reversed_attention_head_dim = list(reversed(attention_head_dim))
204
+
205
+ output_channel = reversed_block_out_channels[0]
206
+ for i, up_block_type in enumerate(up_block_types):
207
+ is_final_block = i == len(block_out_channels) - 1
208
+
209
+ prev_output_channel = output_channel
210
+ output_channel = reversed_block_out_channels[i]
211
+ input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
212
+
213
+ # add upsample block for all BUT final layer
214
+ if not is_final_block:
215
+ add_upsample = True
216
+ self.num_upsamplers += 1
217
+ else:
218
+ add_upsample = False
219
+
220
+ up_block = get_up_block(
221
+ up_block_type,
222
+ num_layers=layers_per_block + 1,
223
+ in_channels=input_channel,
224
+ out_channels=output_channel,
225
+ prev_output_channel=prev_output_channel,
226
+ temb_channels=time_embed_dim,
227
+ add_upsample=add_upsample,
228
+ resnet_eps=norm_eps,
229
+ resnet_act_fn=act_fn,
230
+ resnet_groups=norm_num_groups,
231
+ cross_attention_dim=cross_attention_dim,
232
+ attn_num_head_channels=reversed_attention_head_dim[i],
233
+ dual_cross_attention=False,
234
+ )
235
+ self.up_blocks.append(up_block)
236
+ prev_output_channel = output_channel
237
+
238
+ # out
239
+ if norm_num_groups is not None:
240
+ self.conv_norm_out = nn.GroupNorm(
241
+ num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps
242
+ )
243
+ self.conv_act = nn.SiLU()
244
+ else:
245
+ self.conv_norm_out = None
246
+ self.conv_act = None
247
+
248
+ conv_out_padding = (conv_out_kernel - 1) // 2
249
+ self.conv_out = nn.Conv2d(
250
+ block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding
251
+ )
252
+
253
+ def set_attention_slice(self, slice_size):
254
+ r"""
255
+ Enable sliced attention computation.
256
+
257
+ When this option is enabled, the attention module will split the input tensor in slices, to compute attention
258
+ in several steps. This is useful to save some memory in exchange for a small speed decrease.
259
+
260
+ Args:
261
+ slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
262
+ When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
263
+ `"max"`, maximum amount of memory will be saved by running only one slice at a time. If a number is
264
+ provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
265
+ must be a multiple of `slice_size`.
266
+ """
267
+ sliceable_head_dims = []
268
+
269
+ def fn_recursive_retrieve_slicable_dims(module: torch.nn.Module):
270
+ if hasattr(module, "set_attention_slice"):
271
+ sliceable_head_dims.append(module.sliceable_head_dim)
272
+
273
+ for child in module.children():
274
+ fn_recursive_retrieve_slicable_dims(child)
275
+
276
+ # retrieve number of attention layers
277
+ for module in self.children():
278
+ fn_recursive_retrieve_slicable_dims(module)
279
+
280
+ num_slicable_layers = len(sliceable_head_dims)
281
+
282
+ if slice_size == "auto":
283
+ # half the attention head size is usually a good trade-off between
284
+ # speed and memory
285
+ slice_size = [dim // 2 for dim in sliceable_head_dims]
286
+ elif slice_size == "max":
287
+ # make smallest slice possible
288
+ slice_size = num_slicable_layers * [1]
289
+
290
+ slice_size = num_slicable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
291
+
292
+ if len(slice_size) != len(sliceable_head_dims):
293
+ raise ValueError(
294
+ f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
295
+ f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
296
+ )
297
+
298
+ for i in range(len(slice_size)):
299
+ size = slice_size[i]
300
+ dim = sliceable_head_dims[i]
301
+ if size is not None and size > dim:
302
+ raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
303
+
304
+ # Recursively walk through all the children.
305
+ # Any child module that exposes the set_attention_slice method
306
+ # gets the message
307
+ def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
308
+ if hasattr(module, "set_attention_slice"):
309
+ module.set_attention_slice(slice_size.pop())
310
+
311
+ for child in module.children():
312
+ fn_recursive_set_attention_slice(child, slice_size)
313
+
314
+ reversed_slice_size = list(reversed(slice_size))
315
+ for module in self.children():
316
+ fn_recursive_set_attention_slice(module, reversed_slice_size)
317
+
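A short, hedged usage sketch for the slicing API above; the checkpoint name is an example and assumes the UNet weights live in a `unet` subfolder, as with the standard diffusers text-to-video layouts.

```python
# Hedged sketch: loading the 3D UNet and enabling sliced attention to trade speed for memory.
from unet_3d_condition import UNet3DConditionModel

unet = UNet3DConditionModel.from_pretrained(
    "cerspense/zeroscope_v2_576w",  # assumed example checkpoint
    subfolder="unet",
)
unet.set_attention_slice("auto")  # "auto" halves each sliceable head dim, so attention runs in two steps
```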
318
+ def _set_gradient_checkpointing(self, value=False):
319
+ self.gradient_checkpointing = value
320
+ self.mid_block.gradient_checkpointing = value
321
+ for module in self.down_blocks + self.up_blocks:
322
+ if isinstance(module, (CrossAttnDownBlock3D, DownBlock3D, CrossAttnUpBlock3D, UpBlock3D)):
323
+ module.gradient_checkpointing = value
324
+
325
+     def forward(
+         self,
+         sample: torch.FloatTensor,
+         timestep: Union[torch.Tensor, float, int],
+         encoder_hidden_states: torch.Tensor,
+         class_labels: Optional[torch.Tensor] = None,
+         timestep_cond: Optional[torch.Tensor] = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+         down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
+         mid_block_additional_residual: Optional[torch.Tensor] = None,
+         return_dict: bool = True,
+     ) -> Union[UNet3DConditionOutput, Tuple]:
+         r"""
+         Args:
+             sample (`torch.FloatTensor`): (batch, channel, num_frames, height, width) noisy inputs tensor
+             timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps
+             encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states
+             return_dict (`bool`, *optional*, defaults to `True`):
+                 Whether or not to return a [`models.unet_3d_condition.UNet3DConditionOutput`] instead of a plain tuple.
+             cross_attention_kwargs (`dict`, *optional*):
+                 A kwargs dictionary that, if specified, is passed along to the `AttentionProcessor` as defined under
+                 `self.processor` in
+                 [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
+
+         Returns:
+             [`~models.unet_3d_condition.UNet3DConditionOutput`] or `tuple`:
+             [`~models.unet_3d_condition.UNet3DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When
+             returning a tuple, the first element is the sample tensor.
+         """
+         # By default samples have to be AT least a multiple of the overall upsampling factor.
+         # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
+         # However, the upsampling interpolation output size can be forced to fit any upsampling size
+         # on the fly if necessary.
+         default_overall_up_factor = 2**self.num_upsamplers
+
+         # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
+         forward_upsample_size = False
+         upsample_size = None
+
+         if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
+             logger.info("Forward upsample size to force interpolation output size.")
+             forward_upsample_size = True
+
+         # prepare attention_mask
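+         # positions with attention_mask == 0 are mapped to a large negative additive bias (-10000),
+         # so the attention softmax effectively ignores them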
+         if attention_mask is not None:
+             attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
+             attention_mask = attention_mask.unsqueeze(1)
+
+         # 1. time
+         timesteps = timestep
+         if not torch.is_tensor(timesteps):
+             # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
+             # This would be a good case for the `match` statement (Python 3.10+)
+             is_mps = sample.device.type == "mps"
+             if isinstance(timestep, float):
+                 dtype = torch.float32 if is_mps else torch.float64
+             else:
+                 dtype = torch.int32 if is_mps else torch.int64
+             timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
+         elif len(timesteps.shape) == 0:
+             timesteps = timesteps[None].to(sample.device)
+
+         # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+         num_frames = sample.shape[2]
+         timesteps = timesteps.expand(sample.shape[0])
+
+         t_emb = self.time_proj(timesteps)
+
+         # timesteps does not contain any weights and will always return f32 tensors
+         # but time_embedding might actually be running in fp16. so we need to cast here.
+         # there might be better ways to encapsulate this.
+         t_emb = t_emb.to(dtype=self.dtype)
+
+         emb = self.time_embedding(t_emb, timestep_cond)
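+         # frames are folded into the batch dimension below, so the time embedding and the
+         # text conditioning are duplicated once per frame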
+         emb = emb.repeat_interleave(repeats=num_frames, dim=0)
+         encoder_hidden_states = encoder_hidden_states.repeat_interleave(repeats=num_frames, dim=0)
+
+         # 2. pre-process
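+         # (batch, channel, num_frames, height, width) -> (batch * num_frames, channel, height, width),
+         # so the spatial layers treat every frame as an independent image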
+         sample = sample.permute(0, 2, 1, 3, 4).reshape((sample.shape[0] * num_frames, -1) + sample.shape[3:])
+         sample = self.conv_in(sample)
+
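+         # the temporal transformer (transformer_in) runs right after conv_in; when gradient checkpointing
+         # is enabled it is invoked through the transformer_g_c helper so activations can be recomputed
+         # during the backward pass instead of being stored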
+         if self.gradient_checkpointing:
+             sample = transformer_g_c(self.transformer_in, sample, num_frames)
+         else:
+             sample = self.transformer_in(sample, num_frames=num_frames).sample
+
+         # 3. down
+         down_block_res_samples = (sample,)
+         for downsample_block in self.down_blocks:
+             if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
+                 sample, res_samples = downsample_block(
+                     hidden_states=sample,
+                     temb=emb,
+                     encoder_hidden_states=encoder_hidden_states,
+                     attention_mask=attention_mask,
+                     num_frames=num_frames,
+                     cross_attention_kwargs=cross_attention_kwargs,
+                 )
+             else:
+                 sample, res_samples = downsample_block(hidden_states=sample, temb=emb, num_frames=num_frames)
+
+             down_block_res_samples += res_samples
+
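+         # optional residuals (e.g. from a ControlNet-style conditioning model) are added onto the
+         # skip connections before they reach the decoder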
+         if down_block_additional_residuals is not None:
+             new_down_block_res_samples = ()
+
+             for down_block_res_sample, down_block_additional_residual in zip(
+                 down_block_res_samples, down_block_additional_residuals
+             ):
+                 down_block_res_sample = down_block_res_sample + down_block_additional_residual
+                 new_down_block_res_samples += (down_block_res_sample,)
+
+             down_block_res_samples = new_down_block_res_samples
+
+         # 4. mid
+         if self.mid_block is not None:
+             sample = self.mid_block(
+                 sample,
+                 emb,
+                 encoder_hidden_states=encoder_hidden_states,
+                 attention_mask=attention_mask,
+                 num_frames=num_frames,
+                 cross_attention_kwargs=cross_attention_kwargs,
+             )
+
+         if mid_block_additional_residual is not None:
+             sample = sample + mid_block_additional_residual
+
+         # 5. up
+         for i, upsample_block in enumerate(self.up_blocks):
+             is_final_block = i == len(self.up_blocks) - 1
+
+             res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
+             down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
+
+             # if we have not reached the final block and need to forward the
+             # upsample size, we do it here
+             if not is_final_block and forward_upsample_size:
+                 upsample_size = down_block_res_samples[-1].shape[2:]
+
+             if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
+                 sample = upsample_block(
+                     hidden_states=sample,
+                     temb=emb,
+                     res_hidden_states_tuple=res_samples,
+                     encoder_hidden_states=encoder_hidden_states,
+                     upsample_size=upsample_size,
+                     attention_mask=attention_mask,
+                     num_frames=num_frames,
+                     cross_attention_kwargs=cross_attention_kwargs,
+                 )
+             else:
+                 sample = upsample_block(
+                     hidden_states=sample,
+                     temb=emb,
+                     res_hidden_states_tuple=res_samples,
+                     upsample_size=upsample_size,
+                     num_frames=num_frames,
+                 )
+
+         # 6. post-process
+         if self.conv_norm_out:
+             sample = self.conv_norm_out(sample)
+             sample = self.conv_act(sample)
+
+         sample = self.conv_out(sample)
+
+         # reshape back to (batch, channel, num_frames, height, width)
+         sample = sample[None, :].reshape((-1, num_frames) + sample.shape[1:]).permute(0, 2, 1, 3, 4)
+
+         if not return_dict:
+             return (sample,)
+
+         return UNet3DConditionOutput(sample=sample)
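
For orientation, here is a minimal sketch of how the attention slicing and forward pass above are typically exercised. It is not part of the uploaded files: the class name `UNet3DConditionModel`, the diffusers-style `unet` subfolder layout of the `cerspense/zeroscope_v2_576w` checkpoint, the latent size, and the 77x1024 text-embedding shape are assumptions for illustration.

```python
import torch

# Assumed import: this file is expected to expose a diffusers-style UNet3DConditionModel.
from unet_3d_condition import UNet3DConditionModel

# Assumed checkpoint layout: a diffusers repo with a "unet" subfolder.
unet = UNet3DConditionModel.from_pretrained(
    "cerspense/zeroscope_v2_576w", subfolder="unet", torch_dtype=torch.float16
).to("cuda")

# Attention slicing as documented above: "auto" halves each sliceable head dim,
# "max" would run one slice at a time for the lowest memory use.
unet.set_attention_slice("auto")

batch, frames = 1, 16
# Latent-space video in (batch, channel, num_frames, height, width) layout;
# 40x72 latents correspond to roughly 320x576 pixels with an 8x VAE downscale
# and are divisible by the model's overall upsampling factor.
sample = torch.randn(batch, unet.config.in_channels, frames, 40, 72,
                     dtype=torch.float16, device="cuda")
# Text conditioning in (batch, sequence_length, feature_dim) layout; 77x1024 is an
# assumed CLIP-style encoder output size.
encoder_hidden_states = torch.randn(batch, 77, 1024, dtype=torch.float16, device="cuda")

with torch.no_grad():
    out = unet(sample, timestep=50, encoder_hidden_states=encoder_hidden_states)

print(out.sample.shape)  # (batch, out_channels, num_frames, height, width)
```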