jayw committed on
Commit ff0b056
1 Parent(s): 75e41fe

initial commit

README.md CHANGED
@@ -1,10 +1,10 @@
 ---
-title: Show 1
-emoji: 🐢
-colorFrom: yellow
-colorTo: pink
+title: Show-1
+emoji: 🎬
+colorFrom: red
+colorTo: gray
 sdk: gradio
-sdk_version: 3.50.0
+sdk_version: 3.39.0
 app_file: app.py
 pinned: false
 ---
app.py ADDED
@@ -0,0 +1,316 @@
+import gradio as gr
+from share_btn import community_icon_html, loading_icon_html, share_js
+import torch
+from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler
+from diffusers.utils import export_to_video
+
+import os
+import imageio
+from PIL import Image
+
+import torch.nn.functional as F
+
+from diffusers import IFSuperResolutionPipeline, VideoToVideoSDPipeline
+from diffusers.utils import export_to_video
+from diffusers.utils.torch_utils import randn_tensor
+
+from showone.pipelines import TextToVideoIFPipeline, TextToVideoIFInterpPipeline, TextToVideoIFSuperResolutionPipeline
+from showone.pipelines.pipeline_t2v_base_pixel import tensor2vid
+from showone.pipelines.pipeline_t2v_sr_pixel_cond import TextToVideoIFSuperResolutionPipeline_Cond
+
+
+# Base Model
+pretrained_model_path = "showlab/show-1-base"
+pipe_base = TextToVideoIFPipeline.from_pretrained(
+    pretrained_model_path,
+    torch_dtype=torch.float16,
+    variant="fp16"
+)
+pipe_base.enable_model_cpu_offload()
+
+# Interpolation Model
+pretrained_model_path = "showlab/show-1-interpolation"
+pipe_interp_1 = TextToVideoIFInterpPipeline.from_pretrained(
+    pretrained_model_path,
+    torch_dtype=torch.float16,
+    variant="fp16"
+)
+pipe_interp_1.enable_model_cpu_offload()
+
+# Super-Resolution Model 1
+# Image super-resolution model from DeepFloyd https://huggingface.co/DeepFloyd/IF-II-L-v1.0
+pretrained_model_path = "DeepFloyd/IF-II-L-v1.0"
+pipe_sr_1_image = IFSuperResolutionPipeline.from_pretrained(
+    pretrained_model_path,
+    text_encoder=None,
+    torch_dtype=torch.float16,
+    variant="fp16"
+)
+pipe_sr_1_image.enable_model_cpu_offload()
+
+pretrained_model_path = "showlab/show-1-sr1"
+pipe_sr_1_cond = TextToVideoIFSuperResolutionPipeline_Cond.from_pretrained(
+    pretrained_model_path,
+    torch_dtype=torch.float16
+)
+pipe_sr_1_cond.enable_model_cpu_offload()
+
+# Super-Resolution Model 2
+pretrained_model_path = "showlab/show-1-sr2"
+pipe_sr_2 = VideoToVideoSDPipeline.from_pretrained(
+    pretrained_model_path,
+    torch_dtype=torch.float16
+)
+pipe_sr_2.enable_model_cpu_offload()
+pipe_sr_2.enable_vae_slicing()
+
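The four stages above (base keyframes, interpolation, two super-resolution passes) are loaded once at startup and chained inside `infer` below. As a quick sanity check of the heaviest piece, the keyframe stage can be exercised on its own; this is a minimal sketch layered on top of the committed code (it reuses `pipe_base` from above with the same 8×64×40 settings used later in `infer`, and a reduced step count assumed only for speed):

```python
# Sketch: smoke-test the keyframe stage alone, reusing `pipe_base` defined above.
prompt = "A panda taking a selfie"
prompt_embeds, negative_embeds = pipe_base.encode_prompt(prompt)
keyframes = pipe_base(
    prompt_embeds=prompt_embeds,
    negative_prompt_embeds=negative_embeds,
    num_frames=8,
    height=40,
    width=64,
    num_inference_steps=25,  # fewer steps than the 75 used in infer(), just for a fast check
    guidance_scale=9.0,
    output_type="pt",
).frames
# Per the unpacking in infer(), frames come back as (batch, channels, frames, height, width).
print(keyframes.shape)
```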
+def infer(prompt):
+    print(prompt)
+    negative_prompt = "low resolution, blur"
+
+    # Text embeds
+    prompt_embeds, negative_embeds = pipe_base.encode_prompt(prompt)
+
+    # Keyframes generation (8x64x40, 2fps)
+    video_frames = pipe_base(
+        prompt_embeds=prompt_embeds,
+        negative_prompt_embeds=negative_embeds,
+        num_frames=8,
+        height=40,
+        width=64,
+        num_inference_steps=75,
+        guidance_scale=9.0,
+        output_type="pt"
+    ).frames
+
+    # Frame interpolation (8x64x40, 2fps -> 29x64x40, 7.5fps)
+    bsz, channel, num_frames, height, width = video_frames.shape
+    new_num_frames = 3 * (num_frames - 1) + num_frames
+    new_video_frames = torch.zeros((bsz, channel, new_num_frames, height, width),
+                                   dtype=video_frames.dtype, device=video_frames.device)
+    new_video_frames[:, :, torch.arange(0, new_num_frames, 4), ...] = video_frames
+    init_noise = randn_tensor((bsz, channel, 5, height, width), dtype=video_frames.dtype,
+                              device=video_frames.device)
+
+    for i in range(num_frames - 1):
+        batch_i = torch.zeros((bsz, channel, 5, height, width), dtype=video_frames.dtype, device=video_frames.device)
+        batch_i[:, :, 0, ...] = video_frames[:, :, i, ...]
+        batch_i[:, :, -1, ...] = video_frames[:, :, i + 1, ...]
+        batch_i = pipe_interp_1(
+            pixel_values=batch_i,
+            prompt_embeds=prompt_embeds,
+            negative_prompt_embeds=negative_embeds,
+            num_frames=batch_i.shape[2],
+            height=40,
+            width=64,
+            num_inference_steps=50,
+            guidance_scale=4.0,
+            output_type="pt",
+            init_noise=init_noise,
+            cond_interpolation=True,
+        ).frames
+
+        new_video_frames[:, :, i * 4:i * 4 + 5, ...] = batch_i
+
+    video_frames = new_video_frames
+
+    # Super-resolution 1 (29x64x40 -> 29x256x160)
+    bsz, channel, num_frames, height, width = video_frames.shape
+    window_size, stride = 8, 7
+    new_video_frames = torch.zeros(
+        (bsz, channel, num_frames, height * 4, width * 4),
+        dtype=video_frames.dtype,
+        device=video_frames.device)
+    for i in range(0, num_frames - window_size + 1, stride):
+        batch_i = video_frames[:, :, i:i + window_size, ...]
+
+        if i == 0:
+            first_frame_cond = pipe_sr_1_image(
+                image=video_frames[:, :, 0, ...],
+                prompt_embeds=prompt_embeds,
+                negative_prompt_embeds=negative_embeds,
+                height=height * 4,
+                width=width * 4,
+                num_inference_steps=50,
+                guidance_scale=4.0,
+                noise_level=150,
+                output_type="pt"
+            ).images
+            first_frame_cond = first_frame_cond.unsqueeze(2)
+        else:
+            first_frame_cond = new_video_frames[:, :, i:i + 1, ...]
+
+        batch_i = pipe_sr_1_cond(
+            image=batch_i,
+            prompt_embeds=prompt_embeds,
+            negative_prompt_embeds=negative_embeds,
+            first_frame_cond=first_frame_cond,
+            height=height * 4,
+            width=width * 4,
+            num_inference_steps=50,
+            guidance_scale=7.0,
+            noise_level=250,
+            output_type="pt"
+        ).frames
+        new_video_frames[:, :, i:i + window_size, ...] = batch_i
+
+    video_frames = new_video_frames
+
+    # Super-resolution 2 (29x256x160 -> 29x576x320)
+    video_frames = [Image.fromarray(frame).resize((576, 320)) for frame in tensor2vid(video_frames.clone())]
+    video_frames = pipe_sr_2(
+        prompt,
+        negative_prompt=negative_prompt,
+        video=video_frames,
+        strength=0.8,
+        num_inference_steps=50,
+    ).frames
+
+    video_path = export_to_video(video_frames)
+    print(video_path)
+    return video_path, gr.Group.update(visible=True)
+
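The frame bookkeeping in `infer` is worth spelling out: with 8 keyframes, interpolation yields 3 × (8 − 1) + 8 = 29 frames, keyframes land at indices 0, 4, ..., 28, and each 5-frame window [4i, 4i + 4] shares its endpoints with its neighbours. The first super-resolution pass then slides an 8-frame window with stride 7, so windows start at 0, 7, 14, 21 and overlap by exactly one frame, which is why every window after the first reuses the previous window's last upscaled frame as `first_frame_cond`. A small check of that arithmetic (a sketch, not part of the file):

```python
# Sketch: verify the frame indexing used by infer() above.
num_frames = 8                                          # keyframes from the base stage
new_num_frames = 3 * (num_frames - 1) + num_frames
assert new_num_frames == 29

keyframe_slots = list(range(0, new_num_frames, 4))      # where keyframes are copied
assert keyframe_slots == [0, 4, 8, 12, 16, 20, 24, 28]

windows = [(4 * i, 4 * i + 4) for i in range(num_frames - 1)]       # 5-frame interpolation spans
assert all(w[1] == nxt[0] for w, nxt in zip(windows, windows[1:]))  # shared endpoints

window_size, stride = 8, 7                              # super-resolution 1 sliding window
starts = list(range(0, new_num_frames - window_size + 1, stride))
assert starts == [0, 7, 14, 21]
assert starts[-1] + window_size == new_num_frames       # last window ends exactly at frame 29
```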
+css = """
+#col-container {max-width: 510px; margin-left: auto; margin-right: auto;}
+a {text-decoration-line: underline; font-weight: 600;}
+.animate-spin {
+    animation: spin 1s linear infinite;
+}
+
+@keyframes spin {
+    from {
+        transform: rotate(0deg);
+    }
+    to {
+        transform: rotate(360deg);
+    }
+}
+
+#share-btn-container {
+    display: flex;
+    padding-left: 0.5rem !important;
+    padding-right: 0.5rem !important;
+    background-color: #000000;
+    justify-content: center;
+    align-items: center;
+    border-radius: 9999px !important;
+    max-width: 15rem;
+    height: 36px;
+}
+
+div#share-btn-container > div {
+    flex-direction: row;
+    background: black;
+    align-items: center;
+}
+
+#share-btn-container:hover {
+    background-color: #060606;
+}
+
+#share-btn {
+    all: initial;
+    color: #ffffff;
+    font-weight: 600;
+    cursor: pointer;
+    font-family: 'IBM Plex Sans', sans-serif;
+    margin-left: 0.5rem !important;
+    padding-top: 0.5rem !important;
+    padding-bottom: 0.5rem !important;
+    right: 0;
+}
+
+#share-btn * {
+    all: unset;
+}
+
+#share-btn-container div:nth-child(-n+2) {
+    width: auto !important;
+    min-height: 0px !important;
+}
+
+#share-btn-container .wrap {
+    display: none !important;
+}
+
+#share-btn-container.hidden {
+    display: none !important;
+}
+img[src*='#center'] {
+    display: inline-block;
+    margin: unset;
+}
+
+.footer {
+    margin-bottom: 45px;
+    margin-top: 10px;
+    text-align: center;
+    border-bottom: 1px solid #e5e5e5;
+}
+.footer > p {
+    font-size: .8rem;
+    display: inline-block;
+    padding: 0 10px;
+    transform: translateY(10px);
+    background: white;
+}
+.dark .footer {
+    border-color: #303030;
+}
+.dark .footer > p {
+    background: #0b0f19;
+}
+"""
+
+with gr.Blocks(css=css) as demo:
+    with gr.Column(elem_id="col-container"):
+        gr.Markdown(
+            """
+            <h1 style="text-align: center;">Show-1 Text-to-Video</h1>
+            <p style="text-align: center;">
+                A text-to-video generation model that marries the strengths and alleviates the weaknesses of pixel-based and latent-based VDMs. <br />
+            </p>
+
+            <p style="text-align: center;">
+                <a href="https://arxiv.org/abs/2309.15818" target="_blank">Paper</a> |
+                <a href="https://showlab.github.io/Show-1" target="_blank">Project Page</a> |
+                <a href="https://github.com/showlab/Show-1" target="_blank">Github</a>
+            </p>
+            """
+        )
+
+        prompt_in = gr.Textbox(label="Prompt", placeholder="A panda taking a selfie", elem_id="prompt-in")
+        #neg_prompt = gr.Textbox(label="Negative prompt", value="text, watermark, copyright, blurry, nsfw", elem_id="neg-prompt-in")
+        #inference_steps = gr.Slider(label="Inference Steps", minimum=10, maximum=100, step=1, value=40, interactive=False)
+        submit_btn = gr.Button("Submit")
+        video_result = gr.Video(label="Video Output", elem_id="video-output")
+
+        with gr.Row():
+            with gr.Group(elem_id="share-btn-container", visible=False) as share_group:
+                community_icon = gr.HTML(community_icon_html)
+                loading_icon = gr.HTML(loading_icon_html)
+                share_button = gr.Button("Share with Community", elem_id="share-btn")
+
+        gr.Markdown("""
+            [![Duplicate this Space](https://huggingface.co/datasets/huggingface/badges/raw/main/duplicate-this-space-lg.svg#center)](https://huggingface.co/spaces/showlab/Show-1?duplicate=true)
+        """)
+
+        gr.HTML("""
+            <div class="footer">
+                <p>
+                    Demo adapted from <a href="https://huggingface.co/spaces/fffiloni/zeroscope" target="_blank">zeroscope</a>
+                    by 🤗 <a href="https://twitter.com/fffiloni" target="_blank">Sylvain Filoni</a>
+                </p>
+            </div>
+        """)
+
+    submit_btn.click(fn=infer,
+                     inputs=[prompt_in],
+                     outputs=[video_result, share_group],
+                     api_name="show-1")
+
+    share_button.click(None, [], [], _js=share_js)
+
+demo.queue(max_size=12).launch(show_api=True, share=True)
+
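Because the click handler registers `api_name="show-1"` and the app launches with `show_api=True`, the Space can also be driven programmatically. A possible sketch with the `gradio_client` package (the Space id `showlab/Show-1` is taken from the links in the UI; the exact return shape depends on the Gradio version):

```python
# Sketch: call the endpoint registered above as api_name="show-1".
from gradio_client import Client  # pip install gradio_client

client = Client("showlab/Show-1")   # Space id assumed from the project links
result = client.predict(
    "A panda taking a selfie",      # the single `prompt_in` input of infer()
    api_name="/show-1",
)
# infer() returns (video_path, share-button visibility update); the downloaded
# video file path is the first (or only) value exposed through the API.
print(result)
```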
requirements.txt ADDED
@@ -0,0 +1,19 @@
+diffusers==0.19.3
+bitsandbytes==0.35.4
+decord==0.6.0
+transformers==4.29.1
+accelerate==0.18.0
+imageio==2.14.1
+torch==2.0.0
+torchvision==0.15.0
+beautifulsoup4
+tensorboard
+sentencepiece
+safetensors
+modelcards
+omegaconf
+pandas
+einops
+ftfy
+opencv-python
+
share_btn.py ADDED
@@ -0,0 +1,72 @@
1
+ community_icon_html = """<svg id="share-btn-share-icon" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32">
2
+ <path d="M20.6081 3C21.7684 3 22.8053 3.49196 23.5284 4.38415C23.9756 4.93678 24.4428 5.82749 24.4808 7.16133C24.9674 7.01707 25.4353 6.93643 25.8725 6.93643C26.9833 6.93643 27.9865 7.37587 28.696 8.17411C29.6075 9.19872 30.0124 10.4579 29.8361 11.7177C29.7523 12.3177 29.5581 12.8555 29.2678 13.3534C29.8798 13.8646 30.3306 14.5763 30.5485 15.4322C30.719 16.1032 30.8939 17.5006 29.9808 18.9403C30.0389 19.0342 30.0934 19.1319 30.1442 19.2318C30.6932 20.3074 30.7283 21.5229 30.2439 22.6548C29.5093 24.3704 27.6841 25.7219 24.1397 27.1727C21.9347 28.0753 19.9174 28.6523 19.8994 28.6575C16.9842 29.4379 14.3477 29.8345 12.0653 29.8345C7.87017 29.8345 4.8668 28.508 3.13831 25.8921C0.356375 21.6797 0.754104 17.8269 4.35369 14.1131C6.34591 12.058 7.67023 9.02782 7.94613 8.36275C8.50224 6.39343 9.97271 4.20438 12.4172 4.20438H12.4179C12.6236 4.20438 12.8314 4.2214 13.0364 4.25468C14.107 4.42854 15.0428 5.06476 15.7115 6.02205C16.4331 5.09583 17.134 4.359 17.7682 3.94323C18.7242 3.31737 19.6794 3 20.6081 3ZM20.6081 5.95917C20.2427 5.95917 19.7963 6.1197 19.3039 6.44225C17.7754 7.44319 14.8258 12.6772 13.7458 14.7131C13.3839 15.3952 12.7655 15.6837 12.2086 15.6837C11.1036 15.6837 10.2408 14.5497 12.1076 13.1085C14.9146 10.9402 13.9299 7.39584 12.5898 7.1776C12.5311 7.16799 12.4731 7.16355 12.4172 7.16355C11.1989 7.16355 10.6615 9.33114 10.6615 9.33114C10.6615 9.33114 9.0863 13.4148 6.38031 16.206C3.67434 18.998 3.5346 21.2388 5.50675 24.2246C6.85185 26.2606 9.42666 26.8753 12.0653 26.8753C14.8021 26.8753 17.6077 26.2139 19.1799 25.793C19.2574 25.7723 28.8193 22.984 27.6081 20.6107C27.4046 20.212 27.0693 20.0522 26.6471 20.0522C24.9416 20.0522 21.8393 22.6726 20.5057 22.6726C20.2076 22.6726 19.9976 22.5416 19.9116 22.222C19.3433 20.1173 28.552 19.2325 27.7758 16.1839C27.639 15.6445 27.2677 15.4256 26.746 15.4263C24.4923 15.4263 19.4358 19.5181 18.3759 19.5181C18.2949 19.5181 18.2368 19.4937 18.2053 19.4419C17.6743 18.557 17.9653 17.9394 21.7082 15.6009C25.4511 13.2617 28.0783 11.8545 26.5841 10.1752C26.4121 9.98141 26.1684 9.8956 25.8725 9.8956C23.6001 9.89634 18.2311 14.9403 18.2311 14.9403C18.2311 14.9403 16.7821 16.496 15.9057 16.496C15.7043 16.496 15.533 16.4139 15.4169 16.2112C14.7956 15.1296 21.1879 10.1286 21.5484 8.06535C21.7928 6.66715 21.3771 5.95917 20.6081 5.95917Z" fill="#FF9D00"></path>
3
+ <path d="M5.50686 24.2246C3.53472 21.2387 3.67446 18.9979 6.38043 16.206C9.08641 13.4147 10.6615 9.33111 10.6615 9.33111C10.6615 9.33111 11.2499 6.95933 12.59 7.17757C13.93 7.39581 14.9139 10.9401 12.1069 13.1084C9.29997 15.276 12.6659 16.7489 13.7459 14.713C14.8258 12.6772 17.7747 7.44316 19.304 6.44221C20.8326 5.44128 21.9089 6.00204 21.5484 8.06532C21.188 10.1286 14.795 15.1295 15.4171 16.2118C16.0391 17.2934 18.2312 14.9402 18.2312 14.9402C18.2312 14.9402 25.0907 8.49588 26.5842 10.1752C28.0776 11.8545 25.4512 13.2616 21.7082 15.6008C17.9646 17.9393 17.6744 18.557 18.2054 19.4418C18.7372 20.3266 26.9998 13.1351 27.7759 16.1838C28.5513 19.2324 19.3434 20.1173 19.9117 22.2219C20.48 24.3274 26.3979 18.2382 27.6082 20.6107C28.8193 22.9839 19.2574 25.7722 19.18 25.7929C16.0914 26.62 8.24723 28.3726 5.50686 24.2246Z" fill="#FFD21E"></path>
4
+ </svg>"""
5
+
6
+ loading_icon_html = """<svg id="share-btn-loading-icon" style="display:none;" class="animate-spin"
7
+ style="color: #ffffff;
8
+ "
9
+ xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" fill="none" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 24 24"><circle style="opacity: 0.25;" cx="12" cy="12" r="10" stroke="white" stroke-width="4"></circle><path style="opacity: 0.75;" fill="white" d="M4 12a8 8 0 018-8V0C5.373 0 0 5.373 0 12h4zm2 5.291A7.962 7.962 0 014 12H0c0 3.042 1.135 5.824 3 7.938l3-2.647z"></path></svg>"""
10
+
11
+ share_js = """async () => {
12
+ async function uploadFile(file){
13
+ const UPLOAD_URL = 'https://huggingface.co/uploads';
14
+ const response = await fetch(UPLOAD_URL, {
15
+ method: 'POST',
16
+ headers: {
17
+ 'Content-Type': file.type,
18
+ 'X-Requested-With': 'XMLHttpRequest',
19
+ },
20
+ body: file, /// <- File inherits from Blob
21
+ });
22
+ const url = await response.text();
23
+ return url;
24
+ }
25
+
26
+ async function getVideoBlobFile(videoEL){
27
+ const res = await fetch(videoEL.src);
28
+ const blob = await res.blob();
29
+ const videoId = Date.now() % 200;
30
+ const fileName = `vid-show1-${{videoId}}.mp4`;
31
+ const videoBlob = new File([blob], fileName, { type: 'video/mp4' });
32
+ console.log(videoBlob);
33
+ return videoBlob;
34
+ }
35
+
36
+ const gradioEl = document.querySelector("gradio-app").shadowRoot || document.querySelector('body > gradio-app');
37
+ const captionTxt = gradioEl.querySelector('#prompt-in textarea').value;
38
+ const outputVideo = gradioEl.querySelector('#video-output video');
39
+
40
+
41
+ const shareBtnEl = gradioEl.querySelector('#share-btn');
42
+ const shareIconEl = gradioEl.querySelector('#share-btn-share-icon');
43
+ const loadingIconEl = gradioEl.querySelector('#share-btn-loading-icon');
44
+ if(!outputVideo){
45
+ return;
46
+ };
47
+ shareBtnEl.style.pointerEvents = 'none';
48
+ shareIconEl.style.display = 'none';
49
+ loadingIconEl.style.removeProperty('display');
50
+
51
+
52
+ const videoOutFile = await getVideoBlobFile(outputVideo);
53
+ const dataOutputVid = await uploadFile(videoOutFile);
54
+
55
+ const descriptionMd = `
56
+ #### Prompt:
57
+ ${captionTxt}
58
+
59
+ #### Show-1 video result:
60
+ ${dataOutputVid}
61
+
62
+ `;
63
+ const params = new URLSearchParams({
64
+ title: captionTxt,
65
+ description: descriptionMd,
66
+ });
67
+ const paramsStr = params.toString();
68
+ window.open(`https://huggingface.co/spaces/showlab/Show-1/discussions/new?${paramsStr}`, '_blank');
69
+ shareBtnEl.style.removeProperty('pointer-events');
70
+ shareIconEl.style.removeProperty('display');
71
+ loadingIconEl.style.display = 'none';
72
+ }"""
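`share_js` uploads the rendered video to `https://huggingface.co/uploads` and then opens a pre-filled community discussion. The same flow expressed in Python for reference, as a sketch that assumes the endpoints behave exactly as the JavaScript above expects:

```python
# Sketch: Python mirror of the share flow implemented in share_js above.
import requests
from urllib.parse import urlencode

def build_share_url(video_file: str, prompt: str) -> str:
    with open(video_file, "rb") as f:
        resp = requests.post(
            "https://huggingface.co/uploads",
            headers={"Content-Type": "video/mp4", "X-Requested-With": "XMLHttpRequest"},
            data=f,
        )
    video_url = resp.text  # the upload endpoint answers with the hosted file URL

    description = f"#### Prompt:\n{prompt}\n\n#### Show-1 video result:\n{video_url}\n"
    params = urlencode({"title": prompt, "description": description})
    return f"https://huggingface.co/spaces/showlab/Show-1/discussions/new?{params}"
```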
showone/models/__init__.py ADDED
@@ -0,0 +1,15 @@
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .unet_3d_condition import UNet3DConditionModel
showone/models/transformer_temporal.py ADDED
@@ -0,0 +1,179 @@
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from dataclasses import dataclass
+from typing import Optional
+
+import torch
+from torch import nn
+
+from diffusers.configuration_utils import ConfigMixin, register_to_config
+from diffusers.utils import BaseOutput
+from diffusers.models.attention import BasicTransformerBlock
+from diffusers.models.modeling_utils import ModelMixin
+
+
+@dataclass
+class TransformerTemporalModelOutput(BaseOutput):
+    """
+    The output of [`TransformerTemporalModel`].
+
+    Args:
+        sample (`torch.FloatTensor` of shape `(batch_size x num_frames, num_channels, height, width)`):
+            The hidden states output conditioned on `encoder_hidden_states` input.
+    """
+
+    sample: torch.FloatTensor
+
+
+class TransformerTemporalModel(ModelMixin, ConfigMixin):
+    """
+    A Transformer model for video-like data.
+
+    Parameters:
+        num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
+        attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
+        in_channels (`int`, *optional*):
+            The number of channels in the input and output (specify if the input is **continuous**).
+        num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
+        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
+        cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
+        sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**).
+            This is fixed during training since it is used to learn a number of position embeddings.
+        activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward.
+        attention_bias (`bool`, *optional*):
+            Configure if the `TransformerBlock` attention should contain a bias parameter.
+        double_self_attention (`bool`, *optional*):
+            Configure if each `TransformerBlock` should contain two self-attention layers.
+    """
+
+    @register_to_config
+    def __init__(
+        self,
+        num_attention_heads: int = 16,
+        attention_head_dim: int = 88,
+        in_channels: Optional[int] = None,
+        out_channels: Optional[int] = None,
+        num_layers: int = 1,
+        dropout: float = 0.0,
+        norm_num_groups: int = 32,
+        cross_attention_dim: Optional[int] = None,
+        attention_bias: bool = False,
+        sample_size: Optional[int] = None,
+        activation_fn: str = "geglu",
+        norm_elementwise_affine: bool = True,
+        double_self_attention: bool = True,
+    ):
+        super().__init__()
+        self.num_attention_heads = num_attention_heads
+        self.attention_head_dim = attention_head_dim
+        inner_dim = num_attention_heads * attention_head_dim
+
+        self.in_channels = in_channels
+
+        self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
+        self.proj_in = nn.Linear(in_channels, inner_dim)
+
+        # 3. Define transformers blocks
+        self.transformer_blocks = nn.ModuleList(
+            [
+                BasicTransformerBlock(
+                    inner_dim,
+                    num_attention_heads,
+                    attention_head_dim,
+                    dropout=dropout,
+                    cross_attention_dim=cross_attention_dim,
+                    activation_fn=activation_fn,
+                    attention_bias=attention_bias,
+                    double_self_attention=double_self_attention,
+                    norm_elementwise_affine=norm_elementwise_affine,
+                )
+                for d in range(num_layers)
+            ]
+        )
+
+        self.proj_out = nn.Linear(inner_dim, in_channels)
+
+    def forward(
+        self,
+        hidden_states,
+        encoder_hidden_states=None,
+        timestep=None,
+        class_labels=None,
+        num_frames=1,
+        cross_attention_kwargs=None,
+        return_dict: bool = True,
+    ):
+        """
+        The [`TransformerTemporal`] forward method.
+
+        Args:
+            hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous):
+                Input hidden_states.
+            encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*):
+                Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
+                self-attention.
+            timestep ( `torch.long`, *optional*):
+                Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`.
+            class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
+                Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in
+                `AdaLayerZeroNorm`.
+            return_dict (`bool`, *optional*, defaults to `True`):
+                Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
+                tuple.
+
+        Returns:
+            [`~models.transformer_temporal.TransformerTemporalModelOutput`] or `tuple`:
+                If `return_dict` is True, an [`~models.transformer_temporal.TransformerTemporalModelOutput`] is
+                returned, otherwise a `tuple` where the first element is the sample tensor.
+        """
+        # 1. Input
+        batch_frames, channel, height, width = hidden_states.shape
+        batch_size = batch_frames // num_frames
+
+        residual = hidden_states
+
+        hidden_states = hidden_states[None, :].reshape(batch_size, num_frames, channel, height, width)
+        hidden_states = hidden_states.permute(0, 2, 1, 3, 4)
+
+        hidden_states = self.norm(hidden_states)
+        hidden_states = hidden_states.permute(0, 3, 4, 2, 1).reshape(batch_size * height * width, num_frames, channel)
+
+        hidden_states = self.proj_in(hidden_states)
+
+        # 2. Blocks
+        for block in self.transformer_blocks:
+            hidden_states = block(
+                hidden_states,
+                encoder_hidden_states=encoder_hidden_states,
+                timestep=timestep,
+                cross_attention_kwargs=cross_attention_kwargs,
+                class_labels=class_labels,
+            )
+
+        # 3. Output
+        hidden_states = self.proj_out(hidden_states)
+        hidden_states = (
+            hidden_states[None, None, :]
+            .reshape(batch_size, height, width, channel, num_frames)
+            .permute(0, 3, 4, 1, 2)
+            .contiguous()
+        )
+        hidden_states = hidden_states.reshape(batch_frames, channel, height, width)
+
+        output = hidden_states + residual
+
+        if not return_dict:
+            return (output,)
+
+        return TransformerTemporalModelOutput(sample=output)
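The forward pass folds the spatial grid into the batch dimension so that attention runs only along the frame axis, and the residual connection keeps input and output layouts identical: `(batch × frames, channels, height, width)` in, same shape out. A toy shape check (a sketch with assumed small dimensions, not part of the file):

```python
# Sketch: shape contract of TransformerTemporalModel with toy dimensions.
import torch
from showone.models.transformer_temporal import TransformerTemporalModel

batch, frames, channels, h, w = 2, 4, 64, 8, 8
model = TransformerTemporalModel(
    num_attention_heads=4,
    attention_head_dim=16,   # 4 * 16 == channels, so proj_in/proj_out dimensions line up
    in_channels=channels,
    norm_num_groups=8,       # must divide in_channels
)
x = torch.randn(batch * frames, channels, h, w)  # frames folded into the batch dim
out = model(x, num_frames=frames).sample
assert out.shape == x.shape                      # temporal attention preserves the layout
```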
showone/models/unet_3d_blocks.py ADDED
@@ -0,0 +1,1619 @@
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import Any, Dict, Optional, Tuple
15
+
16
+ import numpy as np
17
+ import torch
18
+ import torch.nn.functional as F
19
+ from torch import nn
20
+
21
+ from diffusers.utils import is_torch_version, logging
22
+ from diffusers.models.attention import AdaGroupNorm
23
+ from diffusers.models.attention_processor import Attention, AttnAddedKVProcessor, AttnAddedKVProcessor2_0
24
+ from diffusers.models.resnet import Downsample2D, ResnetBlock2D, TemporalConvLayer, Upsample2D
25
+ from diffusers.models.transformer_2d import Transformer2DModel
26
+ from diffusers.models.transformer_temporal import TransformerTemporalModel
27
+
28
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
29
+
30
+
31
+ def get_down_block(
32
+ down_block_type,
33
+ num_layers,
34
+ in_channels,
35
+ out_channels,
36
+ temb_channels,
37
+ add_downsample,
38
+ resnet_eps,
39
+ resnet_act_fn,
40
+ transformer_layers_per_block=1,
41
+ num_attention_heads=None,
42
+ resnet_groups=None,
43
+ cross_attention_dim=None,
44
+ downsample_padding=None,
45
+ dual_cross_attention=False,
46
+ use_linear_projection=False,
47
+ only_cross_attention=False,
48
+ upcast_attention=False,
49
+ resnet_time_scale_shift="default",
50
+ resnet_skip_time_act=False,
51
+ resnet_out_scale_factor=1.0,
52
+ cross_attention_norm=None,
53
+ attention_head_dim=None,
54
+ downsample_type=None,
55
+ ):
56
+ # If attn head dim is not defined, we default it to the number of heads
57
+ if attention_head_dim is None:
58
+ logger.warn(
59
+ f"It is recommended to provide `attention_head_dim` when calling `get_down_block`. Defaulting `attention_head_dim` to {num_attention_heads}."
60
+ )
61
+ attention_head_dim = num_attention_heads
62
+
63
+ if down_block_type == "DownBlock3D":
64
+ return DownBlock3D(
65
+ num_layers=num_layers,
66
+ in_channels=in_channels,
67
+ out_channels=out_channels,
68
+ temb_channels=temb_channels,
69
+ add_downsample=add_downsample,
70
+ resnet_eps=resnet_eps,
71
+ resnet_act_fn=resnet_act_fn,
72
+ resnet_groups=resnet_groups,
73
+ downsample_padding=downsample_padding,
74
+ resnet_time_scale_shift=resnet_time_scale_shift,
75
+ )
76
+ elif down_block_type == "CrossAttnDownBlock3D":
77
+ if cross_attention_dim is None:
78
+ raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlock3D")
79
+ return CrossAttnDownBlock3D(
80
+ num_layers=num_layers,
81
+ transformer_layers_per_block=transformer_layers_per_block,
82
+ in_channels=in_channels,
83
+ out_channels=out_channels,
84
+ temb_channels=temb_channels,
85
+ add_downsample=add_downsample,
86
+ resnet_eps=resnet_eps,
87
+ resnet_act_fn=resnet_act_fn,
88
+ resnet_groups=resnet_groups,
89
+ downsample_padding=downsample_padding,
90
+ cross_attention_dim=cross_attention_dim,
91
+ num_attention_heads=num_attention_heads,
92
+ dual_cross_attention=dual_cross_attention,
93
+ use_linear_projection=use_linear_projection,
94
+ only_cross_attention=only_cross_attention,
95
+ upcast_attention=upcast_attention,
96
+ resnet_time_scale_shift=resnet_time_scale_shift,
97
+ )
98
+ elif down_block_type == "SimpleCrossAttnDownBlock3D":
99
+ if cross_attention_dim is None:
100
+ raise ValueError("cross_attention_dim must be specified for SimpleCrossAttnDownBlock3D")
101
+ return SimpleCrossAttnDownBlock3D(
102
+ num_layers=num_layers,
103
+ in_channels=in_channels,
104
+ out_channels=out_channels,
105
+ temb_channels=temb_channels,
106
+ add_downsample=add_downsample,
107
+ resnet_eps=resnet_eps,
108
+ resnet_act_fn=resnet_act_fn,
109
+ resnet_groups=resnet_groups,
110
+ cross_attention_dim=cross_attention_dim,
111
+ attention_head_dim=attention_head_dim,
112
+ resnet_time_scale_shift=resnet_time_scale_shift,
113
+ skip_time_act=resnet_skip_time_act,
114
+ output_scale_factor=resnet_out_scale_factor,
115
+ only_cross_attention=only_cross_attention,
116
+ cross_attention_norm=cross_attention_norm,
117
+ )
118
+ elif down_block_type == "ResnetDownsampleBlock3D":
119
+ return ResnetDownsampleBlock3D(
120
+ num_layers=num_layers,
121
+ in_channels=in_channels,
122
+ out_channels=out_channels,
123
+ temb_channels=temb_channels,
124
+ add_downsample=add_downsample,
125
+ resnet_eps=resnet_eps,
126
+ resnet_act_fn=resnet_act_fn,
127
+ resnet_groups=resnet_groups,
128
+ resnet_time_scale_shift=resnet_time_scale_shift,
129
+ skip_time_act=resnet_skip_time_act,
130
+ output_scale_factor=resnet_out_scale_factor,
131
+ )
132
+ raise ValueError(f"{down_block_type} does not exist.")
133
+
134
+
135
+ def get_up_block(
136
+ up_block_type,
137
+ num_layers,
138
+ in_channels,
139
+ out_channels,
140
+ prev_output_channel,
141
+ temb_channels,
142
+ add_upsample,
143
+ resnet_eps,
144
+ resnet_act_fn,
145
+ transformer_layers_per_block=1,
146
+ num_attention_heads=None,
147
+ resnet_groups=None,
148
+ cross_attention_dim=None,
149
+ dual_cross_attention=False,
150
+ use_linear_projection=False,
151
+ only_cross_attention=False,
152
+ upcast_attention=False,
153
+ resnet_time_scale_shift="default",
154
+ resnet_skip_time_act=False,
155
+ resnet_out_scale_factor=1.0,
156
+ cross_attention_norm=None,
157
+ attention_head_dim=None,
158
+ upsample_type=None,
159
+ ):
160
+ # If attn head dim is not defined, we default it to the number of heads
161
+ if attention_head_dim is None:
162
+ logger.warn(
163
+ f"It is recommended to provide `attention_head_dim` when calling `get_up_block`. Defaulting `attention_head_dim` to {num_attention_heads}."
164
+ )
165
+ attention_head_dim = num_attention_heads
166
+
167
+ if up_block_type == "UpBlock3D":
168
+ return UpBlock3D(
169
+ num_layers=num_layers,
170
+ in_channels=in_channels,
171
+ out_channels=out_channels,
172
+ prev_output_channel=prev_output_channel,
173
+ temb_channels=temb_channels,
174
+ add_upsample=add_upsample,
175
+ resnet_eps=resnet_eps,
176
+ resnet_act_fn=resnet_act_fn,
177
+ resnet_groups=resnet_groups,
178
+ resnet_time_scale_shift=resnet_time_scale_shift,
179
+ )
180
+ elif up_block_type == "CrossAttnUpBlock3D":
181
+ if cross_attention_dim is None:
182
+ raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlock3D")
183
+ return CrossAttnUpBlock3D(
184
+ num_layers=num_layers,
185
+ transformer_layers_per_block=transformer_layers_per_block,
186
+ in_channels=in_channels,
187
+ out_channels=out_channels,
188
+ prev_output_channel=prev_output_channel,
189
+ temb_channels=temb_channels,
190
+ add_upsample=add_upsample,
191
+ resnet_eps=resnet_eps,
192
+ resnet_act_fn=resnet_act_fn,
193
+ resnet_groups=resnet_groups,
194
+ cross_attention_dim=cross_attention_dim,
195
+ num_attention_heads=num_attention_heads,
196
+ dual_cross_attention=dual_cross_attention,
197
+ use_linear_projection=use_linear_projection,
198
+ only_cross_attention=only_cross_attention,
199
+ upcast_attention=upcast_attention,
200
+ resnet_time_scale_shift=resnet_time_scale_shift,
201
+ )
202
+ elif up_block_type == "SimpleCrossAttnUpBlock3D":
203
+ if cross_attention_dim is None:
204
+ raise ValueError("cross_attention_dim must be specified for SimpleCrossAttnUpBlock3D")
205
+ return SimpleCrossAttnUpBlock3D(
206
+ num_layers=num_layers,
207
+ in_channels=in_channels,
208
+ out_channels=out_channels,
209
+ prev_output_channel=prev_output_channel,
210
+ temb_channels=temb_channels,
211
+ add_upsample=add_upsample,
212
+ resnet_eps=resnet_eps,
213
+ resnet_act_fn=resnet_act_fn,
214
+ resnet_groups=resnet_groups,
215
+ cross_attention_dim=cross_attention_dim,
216
+ attention_head_dim=attention_head_dim,
217
+ resnet_time_scale_shift=resnet_time_scale_shift,
218
+ skip_time_act=resnet_skip_time_act,
219
+ output_scale_factor=resnet_out_scale_factor,
220
+ only_cross_attention=only_cross_attention,
221
+ cross_attention_norm=cross_attention_norm,
222
+ )
223
+ elif up_block_type == "ResnetUpsampleBlock3D":
224
+ return ResnetUpsampleBlock3D(
225
+ num_layers=num_layers,
226
+ in_channels=in_channels,
227
+ out_channels=out_channels,
228
+ prev_output_channel=prev_output_channel,
229
+ temb_channels=temb_channels,
230
+ add_upsample=add_upsample,
231
+ resnet_eps=resnet_eps,
232
+ resnet_act_fn=resnet_act_fn,
233
+ resnet_groups=resnet_groups,
234
+ resnet_time_scale_shift=resnet_time_scale_shift,
235
+ skip_time_act=resnet_skip_time_act,
236
+ output_scale_factor=resnet_out_scale_factor,
237
+ )
238
+ raise ValueError(f"{up_block_type} does not exist.")
239
+
240
+
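`get_down_block` and `get_up_block` above are plain string-dispatch factories: the UNet configuration lists block-type names and these helpers turn each name into the corresponding 3D block. A minimal illustration of the dispatch with assumed toy channel sizes (a sketch, not part of the file):

```python
# Sketch: the factory dispatches on the block-type string from the UNet config.
from showone.models.unet_3d_blocks import get_down_block

block = get_down_block(
    "DownBlock3D",           # one of the names handled above
    num_layers=2,
    in_channels=64,
    out_channels=128,
    temb_channels=256,
    add_downsample=True,
    resnet_eps=1e-6,
    resnet_act_fn="silu",
    resnet_groups=32,
    downsample_padding=1,
    attention_head_dim=8,    # passing this explicitly avoids the logger.warn fallback
)
print(type(block).__name__)  # -> "DownBlock3D"
```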
241
+ class UNetMidBlock3DCrossAttn(nn.Module):
242
+ def __init__(
243
+ self,
244
+ in_channels: int,
245
+ temb_channels: int,
246
+ dropout: float = 0.0,
247
+ num_layers: int = 1,
248
+ transformer_layers_per_block: int = 1,
249
+ resnet_eps: float = 1e-6,
250
+ resnet_time_scale_shift: str = "default",
251
+ resnet_act_fn: str = "swish",
252
+ resnet_groups: int = 32,
253
+ resnet_pre_norm: bool = True,
254
+ num_attention_heads=1,
255
+ output_scale_factor=1.0,
256
+ cross_attention_dim=1280,
257
+ dual_cross_attention=False,
258
+ use_linear_projection=False,
259
+ upcast_attention=False,
260
+ ):
261
+ super().__init__()
262
+
263
+ self.has_cross_attention = True
264
+ self.num_attention_heads = num_attention_heads
265
+ resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
266
+
267
+ # there is always at least one resnet
268
+ resnets = [
269
+ ResnetBlock2D(
270
+ in_channels=in_channels,
271
+ out_channels=in_channels,
272
+ temb_channels=temb_channels,
273
+ eps=resnet_eps,
274
+ groups=resnet_groups,
275
+ dropout=dropout,
276
+ time_embedding_norm=resnet_time_scale_shift,
277
+ non_linearity=resnet_act_fn,
278
+ output_scale_factor=output_scale_factor,
279
+ pre_norm=resnet_pre_norm,
280
+ )
281
+ ]
282
+ temp_convs = [
283
+ TemporalConvLayer(
284
+ in_channels,
285
+ in_channels,
286
+ dropout=0.1,
287
+ )
288
+ ]
289
+ attentions = []
290
+ temp_attentions = []
291
+
292
+ for _ in range(num_layers):
293
+ attentions.append(
294
+ Transformer2DModel(
295
+ num_attention_heads,
296
+ in_channels // num_attention_heads,
297
+ in_channels=in_channels,
298
+ num_layers=transformer_layers_per_block,
299
+ cross_attention_dim=cross_attention_dim,
300
+ norm_num_groups=resnet_groups,
301
+ use_linear_projection=use_linear_projection,
302
+ upcast_attention=upcast_attention,
303
+ )
304
+ )
305
+ temp_attentions.append(
306
+ TransformerTemporalModel(
307
+ num_attention_heads,
308
+ in_channels // num_attention_heads,
309
+ in_channels=in_channels,
310
+ num_layers=1, #todo: transformer_layers_per_block?
311
+ cross_attention_dim=cross_attention_dim,
312
+ norm_num_groups=resnet_groups,
313
+ )
314
+ )
315
+ resnets.append(
316
+ ResnetBlock2D(
317
+ in_channels=in_channels,
318
+ out_channels=in_channels,
319
+ temb_channels=temb_channels,
320
+ eps=resnet_eps,
321
+ groups=resnet_groups,
322
+ dropout=dropout,
323
+ time_embedding_norm=resnet_time_scale_shift,
324
+ non_linearity=resnet_act_fn,
325
+ output_scale_factor=output_scale_factor,
326
+ pre_norm=resnet_pre_norm,
327
+ )
328
+ )
329
+ temp_convs.append(
330
+ TemporalConvLayer(
331
+ in_channels,
332
+ in_channels,
333
+ dropout=0.1,
334
+ )
335
+ )
336
+
337
+ self.resnets = nn.ModuleList(resnets)
338
+ self.temp_convs = nn.ModuleList(temp_convs)
339
+ self.attentions = nn.ModuleList(attentions)
340
+ self.temp_attentions = nn.ModuleList(temp_attentions)
341
+
342
+ def forward(
343
+ self,
344
+ hidden_states: torch.FloatTensor,
345
+ temb: Optional[torch.FloatTensor] = None,
346
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
347
+ attention_mask: Optional[torch.FloatTensor] = None,
348
+ num_frames: int = 1,
349
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
350
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
351
+ ) -> torch.FloatTensor:
352
+ hidden_states = self.resnets[0](hidden_states, temb)
353
+ hidden_states = self.temp_convs[0](hidden_states, num_frames=num_frames)
354
+ for attn, temp_attn, resnet, temp_conv in zip(
355
+ self.attentions, self.temp_attentions, self.resnets[1:], self.temp_convs[1:]
356
+ ):
357
+ hidden_states = attn(
358
+ hidden_states,
359
+ encoder_hidden_states=encoder_hidden_states,
360
+ cross_attention_kwargs=cross_attention_kwargs,
361
+ attention_mask=attention_mask,
362
+ encoder_attention_mask=encoder_attention_mask,
363
+ return_dict=False,
364
+ )[0]
365
+ hidden_states = temp_attn(
366
+ hidden_states, num_frames=num_frames, cross_attention_kwargs=cross_attention_kwargs
367
+ ).sample
368
+ hidden_states = resnet(hidden_states, temb)
369
+ hidden_states = temp_conv(hidden_states, num_frames=num_frames)
370
+
371
+ return hidden_states
372
+
373
+
374
+ class UNetMidBlock3DSimpleCrossAttn(nn.Module):
375
+ def __init__(
376
+ self,
377
+ in_channels: int,
378
+ temb_channels: int,
379
+ dropout: float = 0.0,
380
+ num_layers: int = 1,
381
+ resnet_eps: float = 1e-6,
382
+ resnet_time_scale_shift: str = "default",
383
+ resnet_act_fn: str = "swish",
384
+ resnet_groups: int = 32,
385
+ resnet_pre_norm: bool = True,
386
+ attention_head_dim=1,
387
+ output_scale_factor=1.0,
388
+ cross_attention_dim=1280,
389
+ skip_time_act=False,
390
+ only_cross_attention=False,
391
+ cross_attention_norm=None,
392
+ ):
393
+ super().__init__()
394
+
395
+ self.has_cross_attention = True
396
+
397
+ self.attention_head_dim = attention_head_dim
398
+ resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
399
+
400
+ self.num_heads = in_channels // self.attention_head_dim
401
+
402
+ # there is always at least one resnet
403
+ resnets = [
404
+ ResnetBlock2D(
405
+ in_channels=in_channels,
406
+ out_channels=in_channels,
407
+ temb_channels=temb_channels,
408
+ eps=resnet_eps,
409
+ groups=resnet_groups,
410
+ dropout=dropout,
411
+ time_embedding_norm=resnet_time_scale_shift,
412
+ non_linearity=resnet_act_fn,
413
+ output_scale_factor=output_scale_factor,
414
+ pre_norm=resnet_pre_norm,
415
+ skip_time_act=skip_time_act,
416
+ )
417
+ ]
418
+ temp_convs = [
419
+ TemporalConvLayer(
420
+ in_channels,
421
+ in_channels,
422
+ dropout=0.1,
423
+ )
424
+ ]
425
+ attentions = []
426
+ temp_attentions = []
427
+
428
+ for _ in range(num_layers):
429
+ processor = (
430
+ AttnAddedKVProcessor2_0() if hasattr(F, "scaled_dot_product_attention") else AttnAddedKVProcessor()
431
+ )
432
+
433
+ attentions.append(
434
+ Attention(
435
+ query_dim=in_channels,
436
+ cross_attention_dim=in_channels,
437
+ heads=self.num_heads,
438
+ dim_head=self.attention_head_dim,
439
+ added_kv_proj_dim=cross_attention_dim,
440
+ norm_num_groups=resnet_groups,
441
+ bias=True,
442
+ upcast_softmax=True,
443
+ only_cross_attention=only_cross_attention,
444
+ cross_attention_norm=cross_attention_norm,
445
+ processor=processor,
446
+ )
447
+ )
448
+ temp_attentions.append(
449
+ TransformerTemporalModel(
450
+ self.attention_head_dim,
451
+ in_channels // self.attention_head_dim,
452
+ in_channels=in_channels,
453
+ num_layers=1,
454
+ cross_attention_dim=cross_attention_dim,
455
+ norm_num_groups=resnet_groups,
456
+ )
457
+ )
458
+ resnets.append(
459
+ ResnetBlock2D(
460
+ in_channels=in_channels,
461
+ out_channels=in_channels,
462
+ temb_channels=temb_channels,
463
+ eps=resnet_eps,
464
+ groups=resnet_groups,
465
+ dropout=dropout,
466
+ time_embedding_norm=resnet_time_scale_shift,
467
+ non_linearity=resnet_act_fn,
468
+ output_scale_factor=output_scale_factor,
469
+ pre_norm=resnet_pre_norm,
470
+ skip_time_act=skip_time_act,
471
+ )
472
+ )
473
+ temp_convs.append(
474
+ TemporalConvLayer(
475
+ in_channels,
476
+ in_channels,
477
+ dropout=0.1,
478
+ )
479
+ )
480
+
481
+ self.resnets = nn.ModuleList(resnets)
482
+ self.temp_convs = nn.ModuleList(temp_convs)
483
+ self.attentions = nn.ModuleList(attentions)
484
+ self.temp_attentions = nn.ModuleList(temp_attentions)
485
+
486
+ def forward(
487
+ self,
488
+ hidden_states: torch.FloatTensor,
489
+ temb: Optional[torch.FloatTensor] = None,
490
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
491
+ attention_mask: Optional[torch.FloatTensor] = None,
492
+ num_frames: int = 1,
493
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
494
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
495
+ ):
496
+ cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
497
+
498
+ if attention_mask is None:
499
+ # if encoder_hidden_states is defined: we are doing cross-attn, so we should use cross-attn mask.
500
+ mask = None if encoder_hidden_states is None else encoder_attention_mask
501
+ else:
502
+ # when attention_mask is defined: we don't even check for encoder_attention_mask.
503
+ # this is to maintain compatibility with UnCLIP, which uses 'attention_mask' param for cross-attn masks.
504
+ # TODO: UnCLIP should express cross-attn mask via encoder_attention_mask param instead of via attention_mask.
505
+ # then we can simplify this whole if/else block to:
506
+ # mask = attention_mask if encoder_hidden_states is None else encoder_attention_mask
507
+ mask = attention_mask
508
+
509
+ hidden_states = self.resnets[0](hidden_states, temb)
510
+ hidden_states = self.temp_convs[0](hidden_states, num_frames=num_frames)
511
+ for attn, temp_attn, resnet, temp_conv in zip(
512
+ self.attentions, self.temp_attentions, self.resnets[1:], self.temp_convs[1:]
513
+ ):
514
+ hidden_states = attn(
515
+ hidden_states,
516
+ encoder_hidden_states=encoder_hidden_states,
517
+ attention_mask=mask,
518
+ **cross_attention_kwargs,
519
+ )
520
+ hidden_states = temp_attn(
521
+ hidden_states, num_frames=num_frames, cross_attention_kwargs=cross_attention_kwargs
522
+ ).sample
523
+ hidden_states = resnet(hidden_states, temb)
524
+ hidden_states = temp_conv(hidden_states, num_frames=num_frames)
525
+
526
+ return hidden_states
527
+
528
+
529
+ class CrossAttnDownBlock3D(nn.Module):
530
+ def __init__(
531
+ self,
532
+ in_channels: int,
533
+ out_channels: int,
534
+ temb_channels: int,
535
+ dropout: float = 0.0,
536
+ num_layers: int = 1,
537
+ transformer_layers_per_block: int = 1,
538
+ resnet_eps: float = 1e-6,
539
+ resnet_time_scale_shift: str = "default",
540
+ resnet_act_fn: str = "swish",
541
+ resnet_groups: int = 32,
542
+ resnet_pre_norm: bool = True,
543
+ num_attention_heads=1,
544
+ cross_attention_dim=1280,
545
+ output_scale_factor=1.0,
546
+ downsample_padding=1,
547
+ add_downsample=True,
548
+ dual_cross_attention=False,
549
+ use_linear_projection=False,
550
+ only_cross_attention=False,
551
+ upcast_attention=False,
552
+ ):
553
+ super().__init__()
554
+ resnets = []
555
+ attentions = []
556
+ temp_attentions = []
557
+ temp_convs = []
558
+
559
+ self.has_cross_attention = True
560
+ self.num_attention_heads = num_attention_heads
561
+
562
+ for i in range(num_layers):
563
+ in_channels = in_channels if i == 0 else out_channels
564
+ resnets.append(
565
+ ResnetBlock2D(
566
+ in_channels=in_channels,
567
+ out_channels=out_channels,
568
+ temb_channels=temb_channels,
569
+ eps=resnet_eps,
570
+ groups=resnet_groups,
571
+ dropout=dropout,
572
+ time_embedding_norm=resnet_time_scale_shift,
573
+ non_linearity=resnet_act_fn,
574
+ output_scale_factor=output_scale_factor,
575
+ pre_norm=resnet_pre_norm,
576
+ )
577
+ )
578
+ temp_convs.append(
579
+ TemporalConvLayer(
580
+ out_channels,
581
+ out_channels,
582
+ dropout=0.1,
583
+ )
584
+ )
585
+ attentions.append(
586
+ Transformer2DModel(
587
+ num_attention_heads,
588
+ out_channels // num_attention_heads,
589
+ in_channels=out_channels,
590
+ num_layers=transformer_layers_per_block,
591
+ cross_attention_dim=cross_attention_dim,
592
+ norm_num_groups=resnet_groups,
593
+ use_linear_projection=use_linear_projection,
594
+ only_cross_attention=only_cross_attention,
595
+ upcast_attention=upcast_attention,
596
+ )
597
+ )
598
+ temp_attentions.append(
599
+ TransformerTemporalModel(
600
+ num_attention_heads,
601
+ out_channels // num_attention_heads,
602
+ in_channels=out_channels,
603
+ num_layers=1,
604
+ cross_attention_dim=cross_attention_dim,
605
+ norm_num_groups=resnet_groups,
606
+ )
607
+ )
608
+ self.resnets = nn.ModuleList(resnets)
609
+ self.temp_convs = nn.ModuleList(temp_convs)
610
+ self.attentions = nn.ModuleList(attentions)
611
+ self.temp_attentions = nn.ModuleList(temp_attentions)
612
+
613
+ if add_downsample:
614
+ self.downsamplers = nn.ModuleList(
615
+ [
616
+ Downsample2D(
617
+ out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
618
+ )
619
+ ]
620
+ )
621
+ else:
622
+ self.downsamplers = None
623
+
624
+ self.gradient_checkpointing = False
625
+
626
+ def forward(
627
+ self,
628
+ hidden_states: torch.FloatTensor,
629
+ temb: Optional[torch.FloatTensor] = None,
630
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
631
+ attention_mask: Optional[torch.FloatTensor] = None,
632
+ num_frames: int = 1,
633
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
634
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
635
+ ):
636
+ output_states = ()
637
+
638
+ for resnet, temp_conv, attn, temp_attn in zip(
639
+ self.resnets, self.temp_convs, self.attentions, self.temp_attentions
640
+ ):
641
+ if self.training and self.gradient_checkpointing:
642
+
643
+ def create_custom_forward(module, return_dict=None):
644
+ def custom_forward(*inputs):
645
+ if return_dict is not None:
646
+ return module(*inputs, return_dict=return_dict)
647
+ else:
648
+ return module(*inputs)
649
+
650
+ return custom_forward
651
+
652
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
653
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb, **ckpt_kwargs,)
654
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(temp_conv), hidden_states, num_frames, **ckpt_kwargs,)
655
+ hidden_states = torch.utils.checkpoint.checkpoint(
656
+ create_custom_forward(attn, return_dict=False),
657
+ hidden_states,
658
+ encoder_hidden_states,
659
+ None, # timestep
660
+ None, # class_labels
661
+ cross_attention_kwargs,
662
+ attention_mask,
663
+ encoder_attention_mask,
664
+ **ckpt_kwargs,
665
+ )[0]
666
+ hidden_states = temp_attn(
667
+ hidden_states, num_frames=num_frames, cross_attention_kwargs=cross_attention_kwargs, **ckpt_kwargs,
668
+ ).sample
669
+ else:
670
+ hidden_states = resnet(hidden_states, temb)
671
+ hidden_states = temp_conv(hidden_states, num_frames=num_frames)
672
+ hidden_states = attn(
673
+ hidden_states,
674
+ encoder_hidden_states=encoder_hidden_states,
675
+ cross_attention_kwargs=cross_attention_kwargs,
676
+ attention_mask=attention_mask,
677
+ encoder_attention_mask=encoder_attention_mask,
678
+ return_dict=False,
679
+ )[0]
680
+ hidden_states = temp_attn(
681
+ hidden_states, num_frames=num_frames, cross_attention_kwargs=cross_attention_kwargs
682
+ ).sample
683
+
684
+ output_states = output_states + (hidden_states,)
685
+
686
+ if self.downsamplers is not None:
687
+ for downsampler in self.downsamplers:
688
+ hidden_states = downsampler(hidden_states)
689
+
690
+ output_states = output_states + (hidden_states,)
691
+
692
+ return hidden_states, output_states
693
+
694
+
695
+ class DownBlock3D(nn.Module):
696
+ def __init__(
697
+ self,
698
+ in_channels: int,
699
+ out_channels: int,
700
+ temb_channels: int,
701
+ dropout: float = 0.0,
702
+ num_layers: int = 1,
703
+ resnet_eps: float = 1e-6,
704
+ resnet_time_scale_shift: str = "default",
705
+ resnet_act_fn: str = "swish",
706
+ resnet_groups: int = 32,
707
+ resnet_pre_norm: bool = True,
708
+ output_scale_factor=1.0,
709
+ add_downsample=True,
710
+ downsample_padding=1,
711
+ ):
712
+ super().__init__()
713
+ resnets = []
714
+ temp_convs = []
715
+
716
+ for i in range(num_layers):
717
+ in_channels = in_channels if i == 0 else out_channels
718
+ resnets.append(
719
+ ResnetBlock2D(
720
+ in_channels=in_channels,
721
+ out_channels=out_channels,
722
+ temb_channels=temb_channels,
723
+ eps=resnet_eps,
724
+ groups=resnet_groups,
725
+ dropout=dropout,
726
+ time_embedding_norm=resnet_time_scale_shift,
727
+ non_linearity=resnet_act_fn,
728
+ output_scale_factor=output_scale_factor,
729
+ pre_norm=resnet_pre_norm,
730
+ )
731
+ )
732
+ temp_convs.append(
733
+ TemporalConvLayer(
734
+ out_channels,
735
+ out_channels,
736
+ dropout=0.1,
737
+ )
738
+ )
739
+
740
+ self.resnets = nn.ModuleList(resnets)
741
+ self.temp_convs = nn.ModuleList(temp_convs)
742
+
743
+ if add_downsample:
744
+ self.downsamplers = nn.ModuleList(
745
+ [
746
+ Downsample2D(
747
+ out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
748
+ )
749
+ ]
750
+ )
751
+ else:
752
+ self.downsamplers = None
753
+
754
+ self.gradient_checkpointing = False
755
+
756
+ def forward(self, hidden_states, temb=None, num_frames=1):
757
+ output_states = ()
758
+
759
+ for resnet, temp_conv in zip(self.resnets, self.temp_convs):
760
+ if self.training and self.gradient_checkpointing:
761
+
762
+ def create_custom_forward(module):
763
+ def custom_forward(*inputs):
764
+ return module(*inputs)
765
+
766
+ return custom_forward
767
+
768
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb, use_reentrant=False)
769
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(temp_conv), hidden_states, num_frames, use_reentrant=False)
770
+ else:
771
+ hidden_states = resnet(hidden_states, temb)
772
+ hidden_states = temp_conv(hidden_states, num_frames=num_frames)
773
+
774
+ output_states = output_states + (hidden_states,)
775
+
776
+ if self.downsamplers is not None:
777
+ for downsampler in self.downsamplers:
778
+ hidden_states = downsampler(hidden_states)
779
+
780
+ output_states = output_states + (hidden_states,)
781
+
782
+ return hidden_states, output_states
783
+
784
+
785
+ class ResnetDownsampleBlock3D(nn.Module):
786
+ def __init__(
787
+ self,
788
+ in_channels: int,
789
+ out_channels: int,
790
+ temb_channels: int,
791
+ dropout: float = 0.0,
792
+ num_layers: int = 1,
793
+ resnet_eps: float = 1e-6,
794
+ resnet_time_scale_shift: str = "default",
795
+ resnet_act_fn: str = "swish",
796
+ resnet_groups: int = 32,
797
+ resnet_pre_norm: bool = True,
798
+ output_scale_factor=1.0,
799
+ add_downsample=True,
800
+ skip_time_act=False,
801
+ ):
802
+ super().__init__()
803
+ resnets = []
804
+ temp_convs = []
805
+
806
+ for i in range(num_layers):
807
+ in_channels = in_channels if i == 0 else out_channels
808
+ resnets.append(
809
+ ResnetBlock2D(
810
+ in_channels=in_channels,
811
+ out_channels=out_channels,
812
+ temb_channels=temb_channels,
813
+ eps=resnet_eps,
814
+ groups=resnet_groups,
815
+ dropout=dropout,
816
+ time_embedding_norm=resnet_time_scale_shift,
817
+ non_linearity=resnet_act_fn,
818
+ output_scale_factor=output_scale_factor,
819
+ pre_norm=resnet_pre_norm,
820
+ skip_time_act=skip_time_act,
821
+ )
822
+ )
823
+ temp_convs.append(
824
+ TemporalConvLayer(
825
+ out_channels,
826
+ out_channels,
827
+ dropout=0.1,
828
+ )
829
+ )
830
+
831
+ self.resnets = nn.ModuleList(resnets)
832
+ self.temp_convs = nn.ModuleList(temp_convs)
833
+
834
+ if add_downsample:
835
+ self.downsamplers = nn.ModuleList(
836
+ [
837
+ ResnetBlock2D(
838
+ in_channels=out_channels,
839
+ out_channels=out_channels,
840
+ temb_channels=temb_channels,
841
+ eps=resnet_eps,
842
+ groups=resnet_groups,
843
+ dropout=dropout,
844
+ time_embedding_norm=resnet_time_scale_shift,
845
+ non_linearity=resnet_act_fn,
846
+ output_scale_factor=output_scale_factor,
847
+ pre_norm=resnet_pre_norm,
848
+ skip_time_act=skip_time_act,
849
+ down=True,
850
+ )
851
+ ]
852
+ )
853
+ else:
854
+ self.downsamplers = None
855
+
856
+ self.gradient_checkpointing = False
857
+
858
+ def forward(self, hidden_states, temb=None, num_frames=1):
859
+ output_states = ()
860
+
861
+ for resnet, temp_conv in zip(self.resnets, self.temp_convs):
862
+ if self.training and self.gradient_checkpointing:
863
+
864
+ def create_custom_forward(module):
865
+ def custom_forward(*inputs):
866
+ return module(*inputs)
867
+
868
+ return custom_forward
869
+
870
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb, use_reentrant=False)
871
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(temp_conv), hidden_states, num_frames, use_reentrant=False)
872
+ else:
873
+ hidden_states = resnet(hidden_states, temb)
874
+ hidden_states = temp_conv(hidden_states, num_frames=num_frames)
875
+
876
+ output_states = output_states + (hidden_states,)
877
+
878
+ if self.downsamplers is not None:
879
+ for downsampler in self.downsamplers:
880
+ hidden_states = downsampler(hidden_states, temb)
881
+
882
+ output_states = output_states + (hidden_states,)
883
+
884
+ return hidden_states, output_states
885
+
886
+
887
+ class SimpleCrossAttnDownBlock3D(nn.Module):
888
+ def __init__(
889
+ self,
890
+ in_channels: int,
891
+ out_channels: int,
892
+ temb_channels: int,
893
+ dropout: float = 0.0,
894
+ num_layers: int = 1,
895
+ resnet_eps: float = 1e-6,
896
+ resnet_time_scale_shift: str = "default",
897
+ resnet_act_fn: str = "swish",
898
+ resnet_groups: int = 32,
899
+ resnet_pre_norm: bool = True,
900
+ attention_head_dim=1,
901
+ cross_attention_dim=1280,
902
+ output_scale_factor=1.0,
903
+ add_downsample=True,
904
+ skip_time_act=False,
905
+ only_cross_attention=False,
906
+ cross_attention_norm=None,
907
+ ):
908
+ super().__init__()
909
+
910
+ self.has_cross_attention = True
911
+
912
+ resnets = []
913
+ attentions = []
914
+ temp_attentions = []
915
+ temp_convs = []
916
+
917
+ self.attention_head_dim = attention_head_dim
918
+ self.num_heads = out_channels // self.attention_head_dim
919
+
920
+ for i in range(num_layers):
921
+ in_channels = in_channels if i == 0 else out_channels
922
+ resnets.append(
923
+ ResnetBlock2D(
924
+ in_channels=in_channels,
925
+ out_channels=out_channels,
926
+ temb_channels=temb_channels,
927
+ eps=resnet_eps,
928
+ groups=resnet_groups,
929
+ dropout=dropout,
930
+ time_embedding_norm=resnet_time_scale_shift,
931
+ non_linearity=resnet_act_fn,
932
+ output_scale_factor=output_scale_factor,
933
+ pre_norm=resnet_pre_norm,
934
+ skip_time_act=skip_time_act,
935
+ )
936
+ )
937
+ temp_convs.append(
938
+ TemporalConvLayer(
939
+ out_channels,
940
+ out_channels,
941
+ dropout=0.1,
942
+ )
943
+ )
944
+ processor = (
945
+ AttnAddedKVProcessor2_0() if hasattr(F, "scaled_dot_product_attention") else AttnAddedKVProcessor()
946
+ )
947
+
948
+ attentions.append(
949
+ Attention(
950
+ query_dim=out_channels,
951
+ cross_attention_dim=out_channels,
952
+ heads=self.num_heads,
953
+ dim_head=attention_head_dim,
954
+ added_kv_proj_dim=cross_attention_dim,
955
+ norm_num_groups=resnet_groups,
956
+ bias=True,
957
+ upcast_softmax=True,
958
+ only_cross_attention=only_cross_attention,
959
+ cross_attention_norm=cross_attention_norm,
960
+ processor=processor,
961
+ )
962
+ )
963
+ temp_attentions.append(
964
+ TransformerTemporalModel(
965
+ attention_head_dim,
966
+ out_channels // attention_head_dim,
967
+ in_channels=out_channels,
968
+ num_layers=1,
969
+ cross_attention_dim=cross_attention_dim,
970
+ norm_num_groups=resnet_groups,
971
+ )
972
+ )
973
+ self.resnets = nn.ModuleList(resnets)
974
+ self.temp_convs = nn.ModuleList(temp_convs)
975
+ self.attentions = nn.ModuleList(attentions)
976
+ self.temp_attentions = nn.ModuleList(temp_attentions)
977
+
978
+ if add_downsample:
979
+ self.downsamplers = nn.ModuleList(
980
+ [
981
+ ResnetBlock2D(
982
+ in_channels=out_channels,
983
+ out_channels=out_channels,
984
+ temb_channels=temb_channels,
985
+ eps=resnet_eps,
986
+ groups=resnet_groups,
987
+ dropout=dropout,
988
+ time_embedding_norm=resnet_time_scale_shift,
989
+ non_linearity=resnet_act_fn,
990
+ output_scale_factor=output_scale_factor,
991
+ pre_norm=resnet_pre_norm,
992
+ skip_time_act=skip_time_act,
993
+ down=True,
994
+ )
995
+ ]
996
+ )
997
+ else:
998
+ self.downsamplers = None
999
+
1000
+ self.gradient_checkpointing = False
1001
+
1002
+ def forward(
1003
+ self,
1004
+ hidden_states: torch.FloatTensor,
1005
+ temb: Optional[torch.FloatTensor] = None,
1006
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
1007
+ attention_mask: Optional[torch.FloatTensor] = None,
1008
+ num_frames: int = 1,
1009
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
1010
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
1011
+ ):
1012
+ output_states = ()
1013
+ cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
1014
+
1015
+ if attention_mask is None:
1016
+ # if encoder_hidden_states is defined: we are doing cross-attn, so we should use cross-attn mask.
1017
+ mask = None if encoder_hidden_states is None else encoder_attention_mask
1018
+ else:
1019
+ # when attention_mask is defined: we don't even check for encoder_attention_mask.
1020
+ # this is to maintain compatibility with UnCLIP, which uses 'attention_mask' param for cross-attn masks.
1021
+ # TODO: UnCLIP should express cross-attn mask via encoder_attention_mask param instead of via attention_mask.
1022
+ # then we can simplify this whole if/else block to:
1023
+ # mask = attention_mask if encoder_hidden_states is None else encoder_attention_mask
1024
+ mask = attention_mask
1025
+
1026
+ for resnet, temp_conv, attn, temp_attn in zip(
1027
+ self.resnets, self.temp_convs, self.attentions, self.temp_attentions
1028
+ ):
1029
+ if self.training and self.gradient_checkpointing:
1030
+
1031
+ def create_custom_forward(module, return_dict=None):
1032
+ def custom_forward(*inputs):
1033
+ if return_dict is not None:
1034
+ return module(*inputs, return_dict=return_dict)
1035
+ else:
1036
+ return module(*inputs)
1037
+
1038
+ return custom_forward
1039
+
1040
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
1041
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(temp_conv), hidden_states, num_frames)
1042
+ hidden_states = torch.utils.checkpoint.checkpoint(
1043
+ create_custom_forward(attn, return_dict=False),
1044
+ hidden_states,
1045
+ encoder_hidden_states,
1046
+ mask,
1047
+ cross_attention_kwargs,
1048
+ )[0]
1049
+ hidden_states = temp_attn(
1050
+ hidden_states, num_frames=num_frames, cross_attention_kwargs=cross_attention_kwargs
1051
+ ).sample
1052
+ else:
1053
+ hidden_states = resnet(hidden_states, temb)
1054
+ hidden_states = temp_conv(hidden_states, num_frames=num_frames)
1055
+ hidden_states = attn(
1056
+ hidden_states,
1057
+ encoder_hidden_states=encoder_hidden_states,
1058
+ attention_mask=mask,
1059
+ **cross_attention_kwargs,
1060
+ )
1061
+ hidden_states = temp_attn(
1062
+ hidden_states, num_frames=num_frames, cross_attention_kwargs=cross_attention_kwargs
1063
+ ).sample
1064
+
1065
+ output_states = output_states + (hidden_states,)
1066
+
1067
+ if self.downsamplers is not None:
1068
+ for downsampler in self.downsamplers:
1069
+ hidden_states = downsampler(hidden_states, temb)
1070
+
1071
+ output_states = output_states + (hidden_states,)
1072
+
1073
+ return hidden_states, output_states
1074
+
1075
+
1076
+ class CrossAttnUpBlock3D(nn.Module):
1077
+ def __init__(
1078
+ self,
1079
+ in_channels: int,
1080
+ out_channels: int,
1081
+ prev_output_channel: int,
1082
+ temb_channels: int,
1083
+ dropout: float = 0.0,
1084
+ num_layers: int = 1,
1085
+ transformer_layers_per_block: int = 1,
1086
+ resnet_eps: float = 1e-6,
1087
+ resnet_time_scale_shift: str = "default",
1088
+ resnet_act_fn: str = "swish",
1089
+ resnet_groups: int = 32,
1090
+ resnet_pre_norm: bool = True,
1091
+ num_attention_heads=1,
1092
+ cross_attention_dim=1280,
1093
+ output_scale_factor=1.0,
1094
+ add_upsample=True,
1095
+ dual_cross_attention=False,
1096
+ use_linear_projection=False,
1097
+ only_cross_attention=False,
1098
+ upcast_attention=False,
1099
+ ):
1100
+ super().__init__()
1101
+ resnets = []
1102
+ temp_convs = []
1103
+ attentions = []
1104
+ temp_attentions = []
1105
+
1106
+ self.has_cross_attention = True
1107
+ self.num_attention_heads = num_attention_heads
1108
+
1109
+ for i in range(num_layers):
1110
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
1111
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
1112
+
1113
+ resnets.append(
1114
+ ResnetBlock2D(
1115
+ in_channels=resnet_in_channels + res_skip_channels,
1116
+ out_channels=out_channels,
1117
+ temb_channels=temb_channels,
1118
+ eps=resnet_eps,
1119
+ groups=resnet_groups,
1120
+ dropout=dropout,
1121
+ time_embedding_norm=resnet_time_scale_shift,
1122
+ non_linearity=resnet_act_fn,
1123
+ output_scale_factor=output_scale_factor,
1124
+ pre_norm=resnet_pre_norm,
1125
+ )
1126
+ )
1127
+ temp_convs.append(
1128
+ TemporalConvLayer(
1129
+ out_channels,
1130
+ out_channels,
1131
+ dropout=0.1,
1132
+ )
1133
+ )
1134
+ attentions.append(
1135
+ Transformer2DModel(
1136
+ num_attention_heads,
1137
+ out_channels // num_attention_heads,
1138
+ in_channels=out_channels,
1139
+ num_layers=transformer_layers_per_block,
1140
+ cross_attention_dim=cross_attention_dim,
1141
+ norm_num_groups=resnet_groups,
1142
+ use_linear_projection=use_linear_projection,
1143
+ only_cross_attention=only_cross_attention,
1144
+ upcast_attention=upcast_attention,
1145
+ )
1146
+ )
1147
+ temp_attentions.append(
1148
+ TransformerTemporalModel(
1149
+ num_attention_heads,
1150
+ out_channels // num_attention_heads,
1151
+ in_channels=out_channels,
1152
+ num_layers=1,
1153
+ cross_attention_dim=cross_attention_dim,
1154
+ norm_num_groups=resnet_groups,
1155
+ )
1156
+ )
1157
+ self.resnets = nn.ModuleList(resnets)
1158
+ self.temp_convs = nn.ModuleList(temp_convs)
1159
+ self.attentions = nn.ModuleList(attentions)
1160
+ self.temp_attentions = nn.ModuleList(temp_attentions)
1161
+
1162
+ if add_upsample:
1163
+ self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
1164
+ else:
1165
+ self.upsamplers = None
1166
+
1167
+ self.gradient_checkpointing = False
1168
+
1169
+ def forward(
1170
+ self,
1171
+ hidden_states: torch.FloatTensor,
1172
+ res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
1173
+ temb: Optional[torch.FloatTensor] = None,
1174
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
1175
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
1176
+ upsample_size: Optional[int] = None,
1177
+ num_frames: int = 1,
1178
+ attention_mask: Optional[torch.FloatTensor] = None,
1179
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
1180
+ ):
1181
+ for resnet, temp_conv, attn, temp_attn in zip(
1182
+ self.resnets, self.temp_convs, self.attentions, self.temp_attentions
1183
+ ):
1184
+ # pop res hidden states
1185
+ res_hidden_states = res_hidden_states_tuple[-1]
1186
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
1187
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
1188
+
1189
+ if self.training and self.gradient_checkpointing:
1190
+
1191
+ def create_custom_forward(module, return_dict=None):
1192
+ def custom_forward(*inputs):
1193
+ if return_dict is not None:
1194
+ return module(*inputs, return_dict=return_dict)
1195
+ else:
1196
+ return module(*inputs)
1197
+
1198
+ return custom_forward
1199
+
1200
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
1201
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb, **ckpt_kwargs,)
1202
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(temp_conv), hidden_states, num_frames, **ckpt_kwargs,)
1203
+ hidden_states = torch.utils.checkpoint.checkpoint(
1204
+ create_custom_forward(attn, return_dict=False),
1205
+ hidden_states,
1206
+ encoder_hidden_states,
1207
+ None, # timestep
1208
+ None, # class_labels
1209
+ cross_attention_kwargs,
1210
+ attention_mask,
1211
+ encoder_attention_mask,
1212
+ **ckpt_kwargs,
1213
+ )[0]
1214
+ hidden_states = temp_attn(
1215
+ hidden_states, num_frames=num_frames, cross_attention_kwargs=cross_attention_kwargs
1216
+ ).sample
1217
+ else:
1218
+ hidden_states = resnet(hidden_states, temb)
1219
+ hidden_states = temp_conv(hidden_states, num_frames=num_frames)
1220
+ hidden_states = attn(
1221
+ hidden_states,
1222
+ encoder_hidden_states=encoder_hidden_states,
1223
+ cross_attention_kwargs=cross_attention_kwargs,
1224
+ attention_mask=attention_mask,
1225
+ encoder_attention_mask=encoder_attention_mask,
1226
+ return_dict=False,
1227
+ )[0]
1228
+ hidden_states = temp_attn(
1229
+ hidden_states, num_frames=num_frames, cross_attention_kwargs=cross_attention_kwargs
1230
+ ).sample
1231
+
1232
+ if self.upsamplers is not None:
1233
+ for upsampler in self.upsamplers:
1234
+ hidden_states = upsampler(hidden_states, upsample_size)
1235
+
1236
+ return hidden_states
1237
+
1238
+
1239
+ class UpBlock3D(nn.Module):
1240
+ def __init__(
1241
+ self,
1242
+ in_channels: int,
1243
+ prev_output_channel: int,
1244
+ out_channels: int,
1245
+ temb_channels: int,
1246
+ dropout: float = 0.0,
1247
+ num_layers: int = 1,
1248
+ resnet_eps: float = 1e-6,
1249
+ resnet_time_scale_shift: str = "default",
1250
+ resnet_act_fn: str = "swish",
1251
+ resnet_groups: int = 32,
1252
+ resnet_pre_norm: bool = True,
1253
+ output_scale_factor=1.0,
1254
+ add_upsample=True,
1255
+ ):
1256
+ super().__init__()
1257
+ resnets = []
1258
+ temp_convs = []
1259
+
1260
+ for i in range(num_layers):
1261
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
1262
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
1263
+
1264
+ resnets.append(
1265
+ ResnetBlock2D(
1266
+ in_channels=resnet_in_channels + res_skip_channels,
1267
+ out_channels=out_channels,
1268
+ temb_channels=temb_channels,
1269
+ eps=resnet_eps,
1270
+ groups=resnet_groups,
1271
+ dropout=dropout,
1272
+ time_embedding_norm=resnet_time_scale_shift,
1273
+ non_linearity=resnet_act_fn,
1274
+ output_scale_factor=output_scale_factor,
1275
+ pre_norm=resnet_pre_norm,
1276
+ )
1277
+ )
1278
+ temp_convs.append(
1279
+ TemporalConvLayer(
1280
+ out_channels,
1281
+ out_channels,
1282
+ dropout=0.1,
1283
+ )
1284
+ )
1285
+
1286
+ self.resnets = nn.ModuleList(resnets)
1287
+ self.temp_convs = nn.ModuleList(temp_convs)
1288
+
1289
+ if add_upsample:
1290
+ self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
1291
+ else:
1292
+ self.upsamplers = None
1293
+
1294
+ self.gradient_checkpointing = False
1295
+
1296
+ def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None, num_frames=1):
1297
+ for resnet, temp_conv in zip(self.resnets, self.temp_convs):
1298
+ # pop res hidden states
1299
+ res_hidden_states = res_hidden_states_tuple[-1]
1300
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
1301
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
1302
+
1303
+ if self.training and self.gradient_checkpointing:
1304
+
1305
+ def create_custom_forward(module):
1306
+ def custom_forward(*inputs):
1307
+ return module(*inputs)
1308
+
1309
+ return custom_forward
1310
+
1311
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb, use_reentrant=False)
1312
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(temp_conv), hidden_states, num_frames, use_reentrant=False)
1313
+ else:
1314
+ hidden_states = resnet(hidden_states, temb)
1315
+ hidden_states = temp_conv(hidden_states, num_frames=num_frames)
1316
+
1317
+ if self.upsamplers is not None:
1318
+ for upsampler in self.upsamplers:
1319
+ hidden_states = upsampler(hidden_states, upsample_size)
1320
+
1321
+ return hidden_states
1322
+
1323
+
1324
+ class ResnetUpsampleBlock3D(nn.Module):
1325
+ def __init__(
1326
+ self,
1327
+ in_channels: int,
1328
+ prev_output_channel: int,
1329
+ out_channels: int,
1330
+ temb_channels: int,
1331
+ dropout: float = 0.0,
1332
+ num_layers: int = 1,
1333
+ resnet_eps: float = 1e-6,
1334
+ resnet_time_scale_shift: str = "default",
1335
+ resnet_act_fn: str = "swish",
1336
+ resnet_groups: int = 32,
1337
+ resnet_pre_norm: bool = True,
1338
+ output_scale_factor=1.0,
1339
+ add_upsample=True,
1340
+ skip_time_act=False,
1341
+ ):
1342
+ super().__init__()
1343
+ resnets = []
1344
+ temp_convs = []
1345
+
1346
+ for i in range(num_layers):
1347
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
1348
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
1349
+
1350
+ resnets.append(
1351
+ ResnetBlock2D(
1352
+ in_channels=resnet_in_channels + res_skip_channels,
1353
+ out_channels=out_channels,
1354
+ temb_channels=temb_channels,
1355
+ eps=resnet_eps,
1356
+ groups=resnet_groups,
1357
+ dropout=dropout,
1358
+ time_embedding_norm=resnet_time_scale_shift,
1359
+ non_linearity=resnet_act_fn,
1360
+ output_scale_factor=output_scale_factor,
1361
+ pre_norm=resnet_pre_norm,
1362
+ skip_time_act=skip_time_act,
1363
+ )
1364
+ )
1365
+ temp_convs.append(
1366
+ TemporalConvLayer(
1367
+ out_channels,
1368
+ out_channels,
1369
+ dropout=0.1,
1370
+ )
1371
+ )
1372
+
1373
+ self.resnets = nn.ModuleList(resnets)
1374
+ self.temp_convs = nn.ModuleList(temp_convs)
1375
+
1376
+ if add_upsample:
1377
+ self.upsamplers = nn.ModuleList(
1378
+ [
1379
+ ResnetBlock2D(
1380
+ in_channels=out_channels,
1381
+ out_channels=out_channels,
1382
+ temb_channels=temb_channels,
1383
+ eps=resnet_eps,
1384
+ groups=resnet_groups,
1385
+ dropout=dropout,
1386
+ time_embedding_norm=resnet_time_scale_shift,
1387
+ non_linearity=resnet_act_fn,
1388
+ output_scale_factor=output_scale_factor,
1389
+ pre_norm=resnet_pre_norm,
1390
+ skip_time_act=skip_time_act,
1391
+ up=True,
1392
+ )
1393
+ ]
1394
+ )
1395
+ else:
1396
+ self.upsamplers = None
1397
+
1398
+ self.gradient_checkpointing = False
1399
+
1400
+ def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None, num_frames=1):
1401
+ for resnet, temp_conv in zip(self.resnets, self.temp_convs):
1402
+ # pop res hidden states
1403
+ res_hidden_states = res_hidden_states_tuple[-1]
1404
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
1405
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
1406
+
1407
+ if self.training and self.gradient_checkpointing:
1408
+
1409
+ def create_custom_forward(module):
1410
+ def custom_forward(*inputs):
1411
+ return module(*inputs)
1412
+
1413
+ return custom_forward
1414
+
1415
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb, use_reentrant=False)
1416
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(temp_conv), hidden_states, num_frames, use_reentrant=False)
1417
+ else:
1418
+ hidden_states = resnet(hidden_states, temb)
1419
+ hidden_states = temp_conv(hidden_states, num_frames=num_frames)
1420
+
1421
+ if self.upsamplers is not None:
1422
+ for upsampler in self.upsamplers:
1423
+ hidden_states = upsampler(hidden_states, temb)
1424
+
1425
+ return hidden_states
1426
+
1427
+
1428
+ class SimpleCrossAttnUpBlock3D(nn.Module):
1429
+ def __init__(
1430
+ self,
1431
+ in_channels: int,
1432
+ out_channels: int,
1433
+ prev_output_channel: int,
1434
+ temb_channels: int,
1435
+ dropout: float = 0.0,
1436
+ num_layers: int = 1,
1437
+ resnet_eps: float = 1e-6,
1438
+ resnet_time_scale_shift: str = "default",
1439
+ resnet_act_fn: str = "swish",
1440
+ resnet_groups: int = 32,
1441
+ resnet_pre_norm: bool = True,
1442
+ attention_head_dim=1,
1443
+ cross_attention_dim=1280,
1444
+ output_scale_factor=1.0,
1445
+ add_upsample=True,
1446
+ skip_time_act=False,
1447
+ only_cross_attention=False,
1448
+ cross_attention_norm=None,
1449
+ ):
1450
+ super().__init__()
1451
+ resnets = []
1452
+ temp_convs = []
1453
+ attentions = []
1454
+ temp_attentions = []
1455
+
1456
+ self.has_cross_attention = True
1457
+ self.attention_head_dim = attention_head_dim
1458
+
1459
+ self.num_heads = out_channels // self.attention_head_dim
1460
+
1461
+ for i in range(num_layers):
1462
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
1463
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
1464
+
1465
+ resnets.append(
1466
+ ResnetBlock2D(
1467
+ in_channels=resnet_in_channels + res_skip_channels,
1468
+ out_channels=out_channels,
1469
+ temb_channels=temb_channels,
1470
+ eps=resnet_eps,
1471
+ groups=resnet_groups,
1472
+ dropout=dropout,
1473
+ time_embedding_norm=resnet_time_scale_shift,
1474
+ non_linearity=resnet_act_fn,
1475
+ output_scale_factor=output_scale_factor,
1476
+ pre_norm=resnet_pre_norm,
1477
+ skip_time_act=skip_time_act,
1478
+ )
1479
+ )
1480
+ temp_convs.append(
1481
+ TemporalConvLayer(
1482
+ out_channels,
1483
+ out_channels,
1484
+ dropout=0.1,
1485
+ )
1486
+ )
1487
+
1488
+ processor = (
1489
+ AttnAddedKVProcessor2_0() if hasattr(F, "scaled_dot_product_attention") else AttnAddedKVProcessor()
1490
+ )
1491
+
1492
+ attentions.append(
1493
+ Attention(
1494
+ query_dim=out_channels,
1495
+ cross_attention_dim=out_channels,
1496
+ heads=self.num_heads,
1497
+ dim_head=self.attention_head_dim,
1498
+ added_kv_proj_dim=cross_attention_dim,
1499
+ norm_num_groups=resnet_groups,
1500
+ bias=True,
1501
+ upcast_softmax=True,
1502
+ only_cross_attention=only_cross_attention,
1503
+ cross_attention_norm=cross_attention_norm,
1504
+ processor=processor,
1505
+ )
1506
+ )
1507
+ temp_attentions.append(
1508
+ TransformerTemporalModel(
1509
+ attention_head_dim,
1510
+ out_channels // attention_head_dim,
1511
+ in_channels=out_channels,
1512
+ num_layers=1,
1513
+ cross_attention_dim=cross_attention_dim,
1514
+ norm_num_groups=resnet_groups,
1515
+ )
1516
+ )
1517
+ self.resnets = nn.ModuleList(resnets)
1518
+ self.temp_convs = nn.ModuleList(temp_convs)
1519
+ self.attentions = nn.ModuleList(attentions)
1520
+ self.temp_attentions = nn.ModuleList(temp_attentions)
1521
+
1522
+ if add_upsample:
1523
+ self.upsamplers = nn.ModuleList(
1524
+ [
1525
+ ResnetBlock2D(
1526
+ in_channels=out_channels,
1527
+ out_channels=out_channels,
1528
+ temb_channels=temb_channels,
1529
+ eps=resnet_eps,
1530
+ groups=resnet_groups,
1531
+ dropout=dropout,
1532
+ time_embedding_norm=resnet_time_scale_shift,
1533
+ non_linearity=resnet_act_fn,
1534
+ output_scale_factor=output_scale_factor,
1535
+ pre_norm=resnet_pre_norm,
1536
+ skip_time_act=skip_time_act,
1537
+ up=True,
1538
+ )
1539
+ ]
1540
+ )
1541
+ else:
1542
+ self.upsamplers = None
1543
+
1544
+ self.gradient_checkpointing = False
1545
+
1546
+ def forward(
1547
+ self,
1548
+ hidden_states: torch.FloatTensor,
1549
+ res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
1550
+ temb: Optional[torch.FloatTensor] = None,
1551
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
1552
+ upsample_size: Optional[int] = None,
1553
+ num_frames: int = 1,
1554
+ attention_mask: Optional[torch.FloatTensor] = None,
1555
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
1556
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
1557
+ ):
1558
+ cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
1559
+
1560
+ if attention_mask is None:
1561
+ # if encoder_hidden_states is defined: we are doing cross-attn, so we should use cross-attn mask.
1562
+ mask = None if encoder_hidden_states is None else encoder_attention_mask
1563
+ else:
1564
+ # when attention_mask is defined: we don't even check for encoder_attention_mask.
1565
+ # this is to maintain compatibility with UnCLIP, which uses 'attention_mask' param for cross-attn masks.
1566
+ # TODO: UnCLIP should express cross-attn mask via encoder_attention_mask param instead of via attention_mask.
1567
+ # then we can simplify this whole if/else block to:
1568
+ # mask = attention_mask if encoder_hidden_states is None else encoder_attention_mask
1569
+ mask = attention_mask
1570
+
1571
+ for resnet, temp_conv, attn, temp_attn in zip(
1572
+ self.resnets, self.temp_convs, self.attentions, self.temp_attentions
1573
+ ):
1574
+ # pop res hidden states
1575
+ res_hidden_states = res_hidden_states_tuple[-1]
1576
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
1577
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
1578
+
1579
+ if self.training and self.gradient_checkpointing:
1580
+
1581
+ def create_custom_forward(module, return_dict=None):
1582
+ def custom_forward(*inputs):
1583
+ if return_dict is not None:
1584
+ return module(*inputs, return_dict=return_dict)
1585
+ else:
1586
+ return module(*inputs)
1587
+
1588
+ return custom_forward
1589
+
1590
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
1591
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(temp_conv), hidden_states, num_frames)
1592
+ hidden_states = torch.utils.checkpoint.checkpoint(
1593
+ create_custom_forward(attn, return_dict=False),
1594
+ hidden_states,
1595
+ encoder_hidden_states,
1596
+ mask,
1597
+ cross_attention_kwargs,
1598
+ )[0]
1599
+ hidden_states = temp_attn(
1600
+ hidden_states, num_frames=num_frames, cross_attention_kwargs=cross_attention_kwargs
1601
+ ).sample
1602
+ else:
1603
+ hidden_states = resnet(hidden_states, temb)
1604
+ hidden_states = temp_conv(hidden_states, num_frames=num_frames)
1605
+ hidden_states = attn(
1606
+ hidden_states,
1607
+ encoder_hidden_states=encoder_hidden_states,
1608
+ attention_mask=mask,
1609
+ **cross_attention_kwargs,
1610
+ )
1611
+ hidden_states = temp_attn(
1612
+ hidden_states, num_frames=num_frames, cross_attention_kwargs=cross_attention_kwargs
1613
+ ).sample
1614
+
1615
+ if self.upsamplers is not None:
1616
+ for upsampler in self.upsamplers:
1617
+ hidden_states = upsampler(hidden_states, temb)
1618
+
1619
+ return hidden_states
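
A quick way to sanity-check the wiring of the blocks defined above is a tiny down/up round trip. The sketch below is an editor's illustration, not part of this commit: it assumes the `showone.models.unet_3d_blocks` module from this diff is importable, that frames are folded into the batch dimension (tensors of shape `(batch * num_frames, channels, height, width)`), and that the diffusers `TemporalConvLayer` keeps its hard-coded 32-group GroupNorm, which is why the channel width is 32.

import torch
from showone.models.unet_3d_blocks import DownBlock3D, UpBlock3D

batch, num_frames, channels, height, width = 1, 4, 32, 16, 16
temb_channels = 128

# Frames are folded into the batch dimension before entering the blocks.
down = DownBlock3D(in_channels=channels, out_channels=channels,
                   temb_channels=temb_channels, num_layers=1, resnet_groups=8)
up = UpBlock3D(in_channels=channels, prev_output_channel=channels,
               out_channels=channels, temb_channels=temb_channels,
               num_layers=1, resnet_groups=8)

hidden_states = torch.randn(batch * num_frames, channels, height, width)
temb = torch.randn(batch * num_frames, temb_channels)

# Down pass: returns the downsampled features plus the per-layer skip states.
hidden_states, skips = down(hidden_states, temb=temb, num_frames=num_frames)

# Up pass: consumes skip states from the end of the tuple, mirroring UNet wiring.
out = up(hidden_states, res_hidden_states_tuple=skips[-1:], temb=temb,
         num_frames=num_frames)
print(out.shape)  # torch.Size([4, 32, 16, 16]) -- back at the input resolution
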
showone/models/unet_3d_condition.py ADDED
@@ -0,0 +1,985 @@
1
+ # Copyright 2023 Alibaba DAMO-VILAB and The HuggingFace Team. All rights reserved.
2
+ # Copyright 2023 The ModelScope Team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ from dataclasses import dataclass
16
+ from typing import Any, Dict, List, Optional, Tuple, Union
17
+
18
+ import torch
19
+ import torch.nn as nn
20
+ import torch.nn.functional as F
21
+ import torch.utils.checkpoint
22
+
23
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
24
+ from diffusers.loaders import UNet2DConditionLoadersMixin
25
+ from diffusers.utils import BaseOutput, logging
26
+ from diffusers.models.activations import get_activation
27
+ from diffusers.models.attention_processor import AttentionProcessor, AttnProcessor
28
+ from diffusers.models.embeddings import (
29
+ GaussianFourierProjection,
30
+ ImageHintTimeEmbedding,
31
+ ImageProjection,
32
+ ImageTimeEmbedding,
33
+ TextImageProjection,
34
+ TextImageTimeEmbedding,
35
+ TextTimeEmbedding,
36
+ TimestepEmbedding,
37
+ Timesteps,
38
+ )
39
+ from diffusers.models.modeling_utils import ModelMixin
40
+ # from diffusers.models.transformer_temporal import TransformerTemporalModel
41
+ from .transformer_temporal import TransformerTemporalModel
42
+ from .unet_3d_blocks import (
43
+ CrossAttnDownBlock3D,
44
+ CrossAttnUpBlock3D,
45
+ DownBlock3D,
46
+ UNetMidBlock3DCrossAttn,
47
+ UNetMidBlock3DSimpleCrossAttn,
48
+ UpBlock3D,
49
+ get_down_block,
50
+ get_up_block,
51
+ )
52
+
53
+
54
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
55
+
56
+
57
+ @dataclass
58
+ class UNet3DConditionOutput(BaseOutput):
59
+ """
60
+ Args:
61
+ sample (`torch.FloatTensor` of shape `(batch_size, num_frames, num_channels, height, width)`):
62
+ Hidden states conditioned on `encoder_hidden_states` input. Output of last layer of model.
63
+ """
64
+
65
+ sample: torch.FloatTensor
66
+
67
+
68
+ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
69
+ r"""
70
+ UNet3DConditionModel is a conditional 3D UNet model that takes a noisy sample, a conditioning state, and a timestep
71
+ and returns a sample-shaped output.
72
+
73
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library
74
+ implements for all the models (such as downloading or saving, etc.)
75
+
76
+ Parameters:
77
+ sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
78
+ Height and width of input/output sample.
79
+ in_channels (`int`, *optional*, defaults to 4): The number of channels in the input sample.
80
+ out_channels (`int`, *optional*, defaults to 4): The number of channels in the output.
81
+ down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock3D", "CrossAttnDownBlock3D", "CrossAttnDownBlock3D", "DownBlock3D")`):
82
+ The tuple of downsample blocks to use.
83
+ up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D")`):
84
+ The tuple of upsample blocks to use.
85
+ block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
86
+ The tuple of output channels for each block.
87
+ layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.
88
+ downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution.
89
+ mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block.
90
+ act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
91
+ norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization.
92
+ If `None`, it will skip the normalization and activation layers in post-processing
93
+ norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
94
+ cross_attention_dim (`int`, *optional*, defaults to 1280): The dimension of the cross attention features.
95
+ attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
96
+ """
97
+
98
+ _supports_gradient_checkpointing = True
99
+
100
+ @register_to_config
101
+ def __init__(
102
+ self,
103
+ sample_size: Optional[int] = None,
104
+ in_channels: int = 4,
105
+ out_channels: int = 4,
106
+ center_input_sample: bool = False,
107
+ flip_sin_to_cos: bool = True,
108
+ freq_shift: int = 0,
109
+ down_block_types: Tuple[str] = (
110
+ "CrossAttnDownBlock3D",
111
+ "CrossAttnDownBlock3D",
112
+ "CrossAttnDownBlock3D",
113
+ "DownBlock3D",
114
+ ),
115
+ mid_block_type: Optional[str] = "UNetMidBlock3DCrossAttn",
116
+ up_block_types: Tuple[str] = ("UpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D"),
117
+ only_cross_attention: Union[bool, Tuple[bool]] = False,
118
+ block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
119
+ layers_per_block: Union[int, Tuple[int]] = 2,
120
+ downsample_padding: int = 1,
121
+ mid_block_scale_factor: float = 1,
122
+ act_fn: str = "silu",
123
+ norm_num_groups: Optional[int] = 32,
124
+ norm_eps: float = 1e-5,
125
+ cross_attention_dim: Union[int, Tuple[int]] = 1280,
126
+ transformer_layers_per_block: Union[int, Tuple[int]] = 1,
127
+ encoder_hid_dim: Optional[int] = None,
128
+ encoder_hid_dim_type: Optional[str] = None,
129
+ attention_head_dim: Union[int, Tuple[int]] = 8,
130
+ num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
131
+ dual_cross_attention: bool = False,
132
+ use_linear_projection: bool = False,
133
+ class_embed_type: Optional[str] = None,
134
+ addition_embed_type: Optional[str] = None,
135
+ addition_time_embed_dim: Optional[int] = None,
136
+ num_class_embeds: Optional[int] = None,
137
+ upcast_attention: bool = False,
138
+ resnet_time_scale_shift: str = "default",
139
+ resnet_skip_time_act: bool = False,
140
+ resnet_out_scale_factor: int = 1.0,
141
+ time_embedding_type: str = "positional",
142
+ time_embedding_dim: Optional[int] = None,
143
+ time_embedding_act_fn: Optional[str] = None,
144
+ timestep_post_act: Optional[str] = None,
145
+ time_cond_proj_dim: Optional[int] = None,
146
+ conv_in_kernel: int = 3,
147
+ conv_out_kernel: int = 3,
148
+ projection_class_embeddings_input_dim: Optional[int] = None,
149
+ class_embeddings_concat: bool = False,
150
+ mid_block_only_cross_attention: Optional[bool] = None,
151
+ cross_attention_norm: Optional[str] = None,
152
+ addition_embed_type_num_heads=64,
153
+ transfromer_in_opt: bool = False,
154
+ ):
155
+ super().__init__()
156
+
157
+ self.sample_size = sample_size
158
+ self.transformer_in_opt = transfromer_in_opt
159
+
160
+ if num_attention_heads is not None:
161
+ raise ValueError(
162
+ "At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19."
163
+ )
164
+
165
+ # If `num_attention_heads` is not defined (which is the case for most models)
166
+ # it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
167
+ # The reason for this behavior is to correct for incorrectly named variables that were introduced
168
+ # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
169
+ # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
170
+ # which is why we correct for the naming here.
171
+ num_attention_heads = num_attention_heads or attention_head_dim
172
+
173
+ # Check inputs
174
+ if len(down_block_types) != len(up_block_types):
175
+ raise ValueError(
176
+ f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
177
+ )
178
+
179
+ if len(block_out_channels) != len(down_block_types):
180
+ raise ValueError(
181
+ f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
182
+ )
183
+
184
+ if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
185
+ raise ValueError(
186
+ f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
187
+ )
188
+
189
+ if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
190
+ raise ValueError(
191
+ f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
192
+ )
193
+
194
+ if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types):
195
+ raise ValueError(
196
+ f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}."
197
+ )
198
+
199
+ if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types):
200
+ raise ValueError(
201
+ f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}."
202
+ )
203
+
204
+ if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types):
205
+ raise ValueError(
206
+ f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}."
207
+ )
208
+
209
+ # input
210
+ conv_in_padding = (conv_in_kernel - 1) // 2
211
+ self.conv_in = nn.Conv2d(
212
+ in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
213
+ )
214
+
215
+ if self.transformer_in_opt:
216
+ self.transformer_in = TransformerTemporalModel(
217
+ num_attention_heads=8,
218
+ attention_head_dim=64,
219
+ in_channels=block_out_channels[0],
220
+ num_layers=1,
221
+ )
222
+
223
+
224
+ # time
225
+ if time_embedding_type == "fourier":
226
+ time_embed_dim = time_embedding_dim or block_out_channels[0] * 2
227
+ if time_embed_dim % 2 != 0:
228
+ raise ValueError(f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}.")
229
+ self.time_proj = GaussianFourierProjection(
230
+ time_embed_dim // 2, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos
231
+ )
232
+ timestep_input_dim = time_embed_dim
233
+ elif time_embedding_type == "positional":
234
+ time_embed_dim = time_embedding_dim or block_out_channels[0] * 4
235
+
236
+ self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
237
+ timestep_input_dim = block_out_channels[0]
238
+ else:
239
+ raise ValueError(
240
+ f"{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`."
241
+ )
242
+
243
+ self.time_embedding = TimestepEmbedding(
244
+ timestep_input_dim,
245
+ time_embed_dim,
246
+ act_fn=act_fn,
247
+ post_act_fn=timestep_post_act,
248
+ cond_proj_dim=time_cond_proj_dim,
249
+ )
250
+
251
+ if encoder_hid_dim_type is None and encoder_hid_dim is not None:
252
+ encoder_hid_dim_type = "text_proj"
253
+ self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)
254
+ logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.")
255
+
256
+ if encoder_hid_dim is None and encoder_hid_dim_type is not None:
257
+ raise ValueError(
258
+ f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}."
259
+ )
260
+
261
+ if encoder_hid_dim_type == "text_proj":
262
+ self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim)
263
+ elif encoder_hid_dim_type == "text_image_proj":
264
+ # image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much
265
+ # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
266
+ # case when `addition_embed_type == "text_image_proj"` (Kandinsky 2.1)
267
+ self.encoder_hid_proj = TextImageProjection(
268
+ text_embed_dim=encoder_hid_dim,
269
+ image_embed_dim=cross_attention_dim,
270
+ cross_attention_dim=cross_attention_dim,
271
+ )
272
+ elif encoder_hid_dim_type == "image_proj":
273
+ # Kandinsky 2.2
274
+ self.encoder_hid_proj = ImageProjection(
275
+ image_embed_dim=encoder_hid_dim,
276
+ cross_attention_dim=cross_attention_dim,
277
+ )
278
+ elif encoder_hid_dim_type is not None:
279
+ raise ValueError(
280
+ f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'."
281
+ )
282
+ else:
283
+ self.encoder_hid_proj = None
284
+
285
+ # class embedding
286
+ if class_embed_type is None and num_class_embeds is not None:
287
+ self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
288
+ elif class_embed_type == "timestep":
289
+ self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim, act_fn=act_fn)
290
+ elif class_embed_type == "identity":
291
+ self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
292
+ elif class_embed_type == "projection":
293
+ if projection_class_embeddings_input_dim is None:
294
+ raise ValueError(
295
+ "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
296
+ )
297
+ # The projection `class_embed_type` is the same as the timestep `class_embed_type` except
298
+ # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
299
+ # 2. it projects from an arbitrary input dimension.
300
+ #
301
+ # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
302
+ # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
303
+ # As a result, `TimestepEmbedding` can be passed arbitrary vectors.
304
+ self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
305
+ elif class_embed_type == "simple_projection":
306
+ if projection_class_embeddings_input_dim is None:
307
+ raise ValueError(
308
+ "`class_embed_type`: 'simple_projection' requires `projection_class_embeddings_input_dim` be set"
309
+ )
310
+ self.class_embedding = nn.Linear(projection_class_embeddings_input_dim, time_embed_dim)
311
+ else:
312
+ self.class_embedding = None
313
+
314
+ if addition_embed_type == "text":
315
+ if encoder_hid_dim is not None:
316
+ text_time_embedding_from_dim = encoder_hid_dim
317
+ else:
318
+ text_time_embedding_from_dim = cross_attention_dim
319
+
320
+ self.add_embedding = TextTimeEmbedding(
321
+ text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads
322
+ )
323
+ elif addition_embed_type == "text_image":
324
+ # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much
325
+ # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
326
+ # case when `addition_embed_type == "text_image"` (Kandinsky 2.1)
327
+ self.add_embedding = TextImageTimeEmbedding(
328
+ text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim
329
+ )
330
+ elif addition_embed_type == "text_time":
331
+ self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift)
332
+ self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
333
+ elif addition_embed_type == "image":
334
+ # Kandinsky 2.2
335
+ self.add_embedding = ImageTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
336
+ elif addition_embed_type == "image_hint":
337
+ # Kandinsky 2.2 ControlNet
338
+ self.add_embedding = ImageHintTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
339
+ elif addition_embed_type is not None:
340
+ raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.")
341
+
342
+ if time_embedding_act_fn is None:
343
+ self.time_embed_act = None
344
+ else:
345
+ self.time_embed_act = get_activation(time_embedding_act_fn)
346
+
347
+ self.down_blocks = nn.ModuleList([])
348
+ self.up_blocks = nn.ModuleList([])
349
+
350
+ if isinstance(only_cross_attention, bool):
351
+ if mid_block_only_cross_attention is None:
352
+ mid_block_only_cross_attention = only_cross_attention
353
+
354
+ only_cross_attention = [only_cross_attention] * len(down_block_types)
355
+
356
+ if mid_block_only_cross_attention is None:
357
+ mid_block_only_cross_attention = False
358
+
359
+ if isinstance(num_attention_heads, int):
360
+ num_attention_heads = (num_attention_heads,) * len(down_block_types)
361
+
362
+ if isinstance(attention_head_dim, int):
363
+ attention_head_dim = (attention_head_dim,) * len(down_block_types)
364
+
365
+ if isinstance(cross_attention_dim, int):
366
+ cross_attention_dim = (cross_attention_dim,) * len(down_block_types)
367
+
368
+ if isinstance(layers_per_block, int):
369
+ layers_per_block = [layers_per_block] * len(down_block_types)
370
+
371
+ if isinstance(transformer_layers_per_block, int):
372
+ transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)
373
+
374
+ if class_embeddings_concat:
375
+ # The time embeddings are concatenated with the class embeddings. The dimension of the
376
+ # time embeddings passed to the down, middle, and up blocks is twice the dimension of the
377
+ # regular time embeddings
378
+ blocks_time_embed_dim = time_embed_dim * 2
379
+ else:
380
+ blocks_time_embed_dim = time_embed_dim
381
+
382
+ # down
383
+ output_channel = block_out_channels[0]
384
+ for i, down_block_type in enumerate(down_block_types):
385
+ input_channel = output_channel
386
+ output_channel = block_out_channels[i]
387
+ is_final_block = i == len(block_out_channels) - 1
388
+
389
+ down_block = get_down_block(
390
+ down_block_type,
391
+ num_layers=layers_per_block[i],
392
+ transformer_layers_per_block=transformer_layers_per_block[i],
393
+ in_channels=input_channel,
394
+ out_channels=output_channel,
395
+ temb_channels=blocks_time_embed_dim,
396
+ add_downsample=not is_final_block,
397
+ resnet_eps=norm_eps,
398
+ resnet_act_fn=act_fn,
399
+ resnet_groups=norm_num_groups,
400
+ cross_attention_dim=cross_attention_dim[i],
401
+ num_attention_heads=num_attention_heads[i],
402
+ downsample_padding=downsample_padding,
403
+ dual_cross_attention=dual_cross_attention,
404
+ use_linear_projection=use_linear_projection,
405
+ only_cross_attention=only_cross_attention[i],
406
+ upcast_attention=upcast_attention,
407
+ resnet_time_scale_shift=resnet_time_scale_shift,
408
+ resnet_skip_time_act=resnet_skip_time_act,
409
+ resnet_out_scale_factor=resnet_out_scale_factor,
410
+ cross_attention_norm=cross_attention_norm,
411
+ attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
412
+ )
413
+ self.down_blocks.append(down_block)
414
+
415
+ # mid
416
+ if mid_block_type == "UNetMidBlock3DCrossAttn":
417
+ self.mid_block = UNetMidBlock3DCrossAttn(
418
+ transformer_layers_per_block=transformer_layers_per_block[-1],
419
+ in_channels=block_out_channels[-1],
420
+ temb_channels=blocks_time_embed_dim,
421
+ resnet_eps=norm_eps,
422
+ resnet_act_fn=act_fn,
423
+ output_scale_factor=mid_block_scale_factor,
424
+ resnet_time_scale_shift=resnet_time_scale_shift,
425
+ cross_attention_dim=cross_attention_dim[-1],
426
+ num_attention_heads=num_attention_heads[-1],
427
+ resnet_groups=norm_num_groups,
428
+ dual_cross_attention=dual_cross_attention,
429
+ use_linear_projection=use_linear_projection,
430
+ upcast_attention=upcast_attention,
431
+ )
432
+ elif mid_block_type == "UNetMidBlock3DSimpleCrossAttn":
433
+ self.mid_block = UNetMidBlock3DSimpleCrossAttn(
434
+ in_channels=block_out_channels[-1],
435
+ temb_channels=blocks_time_embed_dim,
436
+ resnet_eps=norm_eps,
437
+ resnet_act_fn=act_fn,
438
+ output_scale_factor=mid_block_scale_factor,
439
+ cross_attention_dim=cross_attention_dim[-1],
440
+ attention_head_dim=attention_head_dim[-1],
441
+ resnet_groups=norm_num_groups,
442
+ resnet_time_scale_shift=resnet_time_scale_shift,
443
+ skip_time_act=resnet_skip_time_act,
444
+ only_cross_attention=mid_block_only_cross_attention,
445
+ cross_attention_norm=cross_attention_norm,
446
+ )
447
+ elif mid_block_type is None:
448
+ self.mid_block = None
449
+ else:
450
+ raise ValueError(f"unknown mid_block_type : {mid_block_type}")
451
+
452
+ # count how many layers upsample the images
453
+ self.num_upsamplers = 0
454
+
455
+ # up
456
+ reversed_block_out_channels = list(reversed(block_out_channels))
457
+ reversed_num_attention_heads = list(reversed(num_attention_heads))
458
+ reversed_layers_per_block = list(reversed(layers_per_block))
459
+ reversed_cross_attention_dim = list(reversed(cross_attention_dim))
460
+ reversed_transformer_layers_per_block = list(reversed(transformer_layers_per_block))
461
+ only_cross_attention = list(reversed(only_cross_attention))
462
+
463
+ output_channel = reversed_block_out_channels[0]
464
+ for i, up_block_type in enumerate(up_block_types):
465
+ is_final_block = i == len(block_out_channels) - 1
466
+
467
+ prev_output_channel = output_channel
468
+ output_channel = reversed_block_out_channels[i]
469
+ input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
470
+
471
+ # add upsample block for all BUT final layer
472
+ if not is_final_block:
473
+ add_upsample = True
474
+ self.num_upsamplers += 1
475
+ else:
476
+ add_upsample = False
477
+
478
+ up_block = get_up_block(
479
+ up_block_type,
480
+ num_layers=reversed_layers_per_block[i] + 1,
481
+ transformer_layers_per_block=reversed_transformer_layers_per_block[i],
482
+ in_channels=input_channel,
483
+ out_channels=output_channel,
484
+ prev_output_channel=prev_output_channel,
485
+ temb_channels=blocks_time_embed_dim,
486
+ add_upsample=add_upsample,
487
+ resnet_eps=norm_eps,
488
+ resnet_act_fn=act_fn,
489
+ resnet_groups=norm_num_groups,
490
+ cross_attention_dim=reversed_cross_attention_dim[i],
491
+ num_attention_heads=reversed_num_attention_heads[i],
492
+ dual_cross_attention=dual_cross_attention,
493
+ use_linear_projection=use_linear_projection,
494
+ only_cross_attention=only_cross_attention[i],
495
+ upcast_attention=upcast_attention,
496
+ resnet_time_scale_shift=resnet_time_scale_shift,
497
+ resnet_skip_time_act=resnet_skip_time_act,
498
+ resnet_out_scale_factor=resnet_out_scale_factor,
499
+ cross_attention_norm=cross_attention_norm,
500
+ attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
501
+ )
502
+ self.up_blocks.append(up_block)
503
+ prev_output_channel = output_channel
504
+
505
+ # out
506
+ if norm_num_groups is not None:
507
+ self.conv_norm_out = nn.GroupNorm(
508
+ num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps
509
+ )
510
+
511
+ self.conv_act = get_activation(act_fn)
512
+
513
+ else:
514
+ self.conv_norm_out = None
515
+ self.conv_act = None
516
+
517
+ conv_out_padding = (conv_out_kernel - 1) // 2
518
+ self.conv_out = nn.Conv2d(
519
+ block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding
520
+ )
521
+
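
For reference, a minimal construction sketch (editorial, not part of the commit). It builds a deliberately tiny two-level variant of `UNet3DConditionModel` purely to exercise `__init__`; the real show-1 checkpoints use the default block layout and channel widths above. Inputs to `forward` are expected to follow the usual diffusers text-to-video convention of `(batch, channels, num_frames, height, width)`, though the forward body appears later in this file.

from showone.models.unet_3d_condition import UNet3DConditionModel

# Channel widths are multiples of 32 so the TemporalConvLayer GroupNorm (32 groups in
# diffusers) stays valid, and norm_num_groups divides every block width.
unet = UNet3DConditionModel(
    sample_size=32,
    in_channels=4,
    out_channels=4,
    down_block_types=("CrossAttnDownBlock3D", "DownBlock3D"),
    up_block_types=("UpBlock3D", "CrossAttnUpBlock3D"),
    block_out_channels=(32, 64),
    layers_per_block=1,
    norm_num_groups=8,
    cross_attention_dim=64,
    attention_head_dim=8,
)
print(sum(p.numel() for p in unet.parameters()) / 1e6, "M parameters")
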
522
+ @property
523
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
524
+ r"""
525
+ Returns:
526
+ `dict` of attention processors: A dictionary containing all attention processors used in the model with
527
+ indexed by its weight name.
528
+ """
529
+ # set recursively
530
+ processors = {}
531
+
532
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
533
+ if hasattr(module, "set_processor"):
534
+ processors[f"{name}.processor"] = module.processor
535
+
536
+ for sub_name, child in module.named_children():
537
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
538
+
539
+ return processors
540
+
541
+ for name, module in self.named_children():
542
+ fn_recursive_add_processors(name, module, processors)
543
+
544
+ return processors
545
+
546
+ def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
547
+ r"""
548
+ Sets the attention processor to use to compute attention.
549
+
550
+ Parameters:
551
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
552
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
553
+ for **all** `Attention` layers.
554
+
555
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
556
+ processor. This is strongly recommended when setting trainable attention processors.
557
+
558
+ """
559
+ # count = len(self.attn_processors.keys())
560
+ # ignore temporal attention
561
+ count = len({k: v for k, v in self.attn_processors.items() if "temp_" not in k}.keys())
562
+
563
+ if isinstance(processor, dict) and len(processor) != count:
564
+ raise ValueError(
565
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
566
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
567
+ )
568
+
569
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
570
+ if hasattr(module, "set_processor") and "temp_" not in name:
571
+ if not isinstance(processor, dict):
572
+ module.set_processor(processor)
573
+ else:
574
+ module.set_processor(processor.pop(f"{name}.processor"))
575
+
576
+ for sub_name, child in module.named_children():
577
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
578
+
579
+ for name, module in self.named_children():
580
+ fn_recursive_attn_processor(name, module, processor)
581
+
582
+ def set_default_attn_processor(self):
583
+ """
584
+ Disables custom attention processors and sets the default attention implementation.
585
+ """
586
+ self.set_attn_processor(AttnProcessor())
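
A hedged usage sketch for the two attention-processor hooks above (editorial, not part of the commit; `unet` is assumed to be an instantiated `UNet3DConditionModel` as in the construction sketch earlier). Note that `set_attn_processor` deliberately skips layers whose module path contains `temp_`, so a processor dict only needs to cover the spatial attention layers.

# Keys are module paths ending in ".processor"; temporal layers carry "temp_" in the path.
spatial = {k: v for k, v in unet.attn_processors.items() if "temp_" not in k}
print(f"{len(spatial)} spatial attention processors out of {len(unet.attn_processors)} total")

# A dict passed to set_attn_processor must contain exactly one entry per spatial layer;
# re-assigning the current processors is a harmless round trip that exercises both hooks.
unet.set_attn_processor(spatial)

# Resetting everything to the stock AttnProcessor is also possible, but layers built with
# added key/value projections (the "Simple" cross-attention blocks) normally pair with an
# AttnAddedKVProcessor, so the blanket reset is mostly a debugging aid.
# unet.set_default_attn_processor()
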
587
+
588
+ def set_attention_slice(self, slice_size):
589
+ r"""
590
+ Enable sliced attention computation.
591
+
592
+ When this option is enabled, the attention module splits the input tensor in slices to compute attention in
593
+ several steps. This is useful for saving some memory in exchange for a small decrease in speed.
594
+
595
+ Args:
596
+ slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
597
+ When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
598
+ `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
599
+ provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
600
+ must be a multiple of `slice_size`.
601
+ """
602
+ sliceable_head_dims = []
603
+
604
+ def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
605
+ if hasattr(module, "set_attention_slice"):
606
+ sliceable_head_dims.append(module.sliceable_head_dim)
607
+
608
+ for child in module.children():
609
+ fn_recursive_retrieve_sliceable_dims(child)
610
+
611
+ # retrieve number of attention layers
612
+ for module in self.children():
613
+ fn_recursive_retrieve_sliceable_dims(module)
614
+
615
+ num_sliceable_layers = len(sliceable_head_dims)
616
+
617
+ if slice_size == "auto":
618
+ # half the attention head size is usually a good trade-off between
619
+ # speed and memory
620
+ slice_size = [dim // 2 for dim in sliceable_head_dims]
621
+ elif slice_size == "max":
622
+ # make smallest slice possible
623
+ slice_size = num_sliceable_layers * [1]
624
+
625
+ slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
626
+
627
+ if len(slice_size) != len(sliceable_head_dims):
628
+ raise ValueError(
629
+ f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
630
+ f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
631
+ )
632
+
633
+ for i in range(len(slice_size)):
634
+ size = slice_size[i]
635
+ dim = sliceable_head_dims[i]
636
+ if size is not None and size > dim:
637
+ raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
638
+
639
+ # Recursively walk through all the children.
640
+ # Any children which exposes the set_attention_slice method
641
+ # gets the message
642
+ def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
643
+ if hasattr(module, "set_attention_slice"):
644
+ module.set_attention_slice(slice_size.pop())
645
+
646
+ for child in module.children():
647
+ fn_recursive_set_attention_slice(child, slice_size)
648
+
649
+ reversed_slice_size = list(reversed(slice_size))
650
+ for module in self.children():
651
+ fn_recursive_set_attention_slice(module, reversed_slice_size)
652
+
653
+ def _set_gradient_checkpointing(self, module, value=False):
654
+ if isinstance(module, (CrossAttnDownBlock3D, DownBlock3D, CrossAttnUpBlock3D, UpBlock3D)):
655
+ module.gradient_checkpointing = value
656
+
657
+ def forward(
658
+ self,
659
+ sample: torch.FloatTensor,
660
+ timestep: Union[torch.Tensor, float, int],
661
+ encoder_hidden_states: torch.Tensor,
662
+ class_labels: Optional[torch.Tensor] = None,
663
+ timestep_cond: Optional[torch.Tensor] = None,
664
+ attention_mask: Optional[torch.Tensor] = None,
665
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
666
+ added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
667
+ down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
668
+ mid_block_additional_residual: Optional[torch.Tensor] = None,
669
+ encoder_attention_mask: Optional[torch.Tensor] = None,
670
+ return_dict: bool = True,
671
+ ) -> Union[UNet3DConditionOutput, Tuple]:
672
+ r"""
673
+ Args:
674
+ sample (`torch.FloatTensor`): (batch, channel, num_frames, height, width) noisy inputs tensor
675
+ timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps
676
+ encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states
677
+ return_dict (`bool`, *optional*, defaults to `True`):
678
+ Whether or not to return a [`models.unet_3d_condition.UNet3DConditionOutput`] instead of a plain tuple.
679
+ cross_attention_kwargs (`dict`, *optional*):
680
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
681
+ `self.processor` in
682
+ [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
683
+
684
+ Returns:
685
+ [`~models.unet_3d_condition.UNet3DConditionOutput`] or `tuple`:
686
+ [`~models.unet_3d_condition.UNet3DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When
687
+ returning a tuple, the first element is the sample tensor.
688
+ """
689
+ # By default samples have to be at least a multiple of the overall upsampling factor.
690
+ # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
691
+ # However, the upsampling interpolation output size can be forced to fit any upsampling size
692
+ # on the fly if necessary.
693
+ default_overall_up_factor = 2**self.num_upsamplers
694
+
695
+ # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
696
+ forward_upsample_size = False
697
+ upsample_size = None
698
+
699
+
700
+ if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
701
+ logger.info("Forward upsample size to force interpolation output size.")
702
+ forward_upsample_size = True
703
+
704
+ # ensure attention_mask is a bias, and give it a singleton query_tokens dimension
705
+ # expects mask of shape:
706
+ # [batch, key_tokens]
707
+ # adds singleton query_tokens dimension:
708
+ # [batch, 1, key_tokens]
709
+ # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
710
+ # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
711
+ # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
712
+ if attention_mask is not None:
713
+ # assume that mask is expressed as:
714
+ # (1 = keep, 0 = discard)
715
+ # convert mask into a bias that can be added to attention scores:
716
+ # (keep = +0, discard = -10000.0)
717
+ attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
718
+ attention_mask = attention_mask.unsqueeze(1)
719
+
720
+ # convert encoder_attention_mask to a bias the same way we do for attention_mask
721
+ if encoder_attention_mask is not None:
722
+ encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0
723
+ encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
724
+
725
+ # 0. center input if necessary
726
+ if self.config.center_input_sample:
727
+ sample = 2 * sample - 1.0
728
+
729
+ # 1. time
730
+ timesteps = timestep
731
+ if not torch.is_tensor(timesteps):
732
+ # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
733
+ # This would be a good case for the `match` statement (Python 3.10+)
734
+ is_mps = sample.device.type == "mps"
735
+ if isinstance(timestep, float):
736
+ dtype = torch.float32 if is_mps else torch.float64
737
+ else:
738
+ dtype = torch.int32 if is_mps else torch.int64
739
+ timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
740
+ elif len(timesteps.shape) == 0:
741
+ timesteps = timesteps[None].to(sample.device)
742
+
743
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
744
+ num_frames = sample.shape[2]
745
+ timesteps = timesteps.expand(sample.shape[0])
746
+
747
+ t_emb = self.time_proj(timesteps)
748
+
749
+ # `Timesteps` does not contain any weights and will always return f32 tensors
750
+ # but time_embedding might actually be running in fp16. so we need to cast here.
751
+ # there might be better ways to encapsulate this.
752
+ t_emb = t_emb.to(dtype=sample.dtype)
753
+
754
+ emb = self.time_embedding(t_emb, timestep_cond)
755
+ aug_emb = None
756
+
757
+ if self.class_embedding is not None:
758
+ if class_labels is None:
759
+ raise ValueError("class_labels should be provided when num_class_embeds > 0")
760
+
761
+ if self.config.class_embed_type == "timestep":
762
+ class_labels = self.time_proj(class_labels)
763
+
764
+ # `Timesteps` does not contain any weights and will always return f32 tensors
765
+ # there might be better ways to encapsulate this.
766
+ class_labels = class_labels.to(dtype=sample.dtype)
767
+
768
+ class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype)
769
+
770
+ if self.config.class_embeddings_concat:
771
+ emb = torch.cat([emb, class_emb], dim=-1)
772
+ else:
773
+ emb = emb + class_emb
774
+
775
+ if self.config.addition_embed_type == "text":
776
+ aug_emb = self.add_embedding(encoder_hidden_states)
777
+ elif self.config.addition_embed_type == "text_image":
778
+ # Kandinsky 2.1 - style
779
+ if "image_embeds" not in added_cond_kwargs:
780
+ raise ValueError(
781
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
782
+ )
783
+
784
+ image_embs = added_cond_kwargs.get("image_embeds")
785
+ text_embs = added_cond_kwargs.get("text_embeds", encoder_hidden_states)
786
+ aug_emb = self.add_embedding(text_embs, image_embs)
787
+ elif self.config.addition_embed_type == "text_time":
788
+ if "text_embeds" not in added_cond_kwargs:
789
+ raise ValueError(
790
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
791
+ )
792
+ text_embeds = added_cond_kwargs.get("text_embeds")
793
+ if "time_ids" not in added_cond_kwargs:
794
+ raise ValueError(
795
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
796
+ )
797
+ time_ids = added_cond_kwargs.get("time_ids")
798
+ time_embeds = self.add_time_proj(time_ids.flatten())
799
+ time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
800
+
801
+ add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
802
+ add_embeds = add_embeds.to(emb.dtype)
803
+ aug_emb = self.add_embedding(add_embeds)
804
+ elif self.config.addition_embed_type == "image":
805
+ # Kandinsky 2.2 - style
806
+ if "image_embeds" not in added_cond_kwargs:
807
+ raise ValueError(
808
+ f"{self.__class__} has the config param `addition_embed_type` set to 'image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
809
+ )
810
+ image_embs = added_cond_kwargs.get("image_embeds")
811
+ aug_emb = self.add_embedding(image_embs)
812
+ elif self.config.addition_embed_type == "image_hint":
813
+ # Kandinsky 2.2 - style
814
+ if "image_embeds" not in added_cond_kwargs or "hint" not in added_cond_kwargs:
815
+ raise ValueError(
816
+ f"{self.__class__} has the config param `addition_embed_type` set to 'image_hint' which requires the keyword arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`"
817
+ )
818
+ image_embs = added_cond_kwargs.get("image_embeds")
819
+ hint = added_cond_kwargs.get("hint")
820
+ aug_emb, hint = self.add_embedding(image_embs, hint)
821
+ sample = torch.cat([sample, hint], dim=1)
822
+
823
+ emb = emb + aug_emb if aug_emb is not None else emb
824
+
825
+ if self.time_embed_act is not None:
826
+ emb = self.time_embed_act(emb)
827
+
828
+ if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_proj":
829
+ encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states)
830
+ elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_image_proj":
831
+ # Kandinsky 2.1 - style
832
+ if "image_embeds" not in added_cond_kwargs:
833
+ raise ValueError(
834
+ f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
835
+ )
836
+
837
+ image_embeds = added_cond_kwargs.get("image_embeds")
838
+ encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states, image_embeds)
839
+ elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "image_proj":
840
+ # Kandinsky 2.2 - style
841
+ if "image_embeds" not in added_cond_kwargs:
842
+ raise ValueError(
843
+ f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
844
+ )
845
+ image_embeds = added_cond_kwargs.get("image_embeds")
846
+ encoder_hidden_states = self.encoder_hid_proj(image_embeds)
847
+
848
+ emb = emb.repeat_interleave(repeats=num_frames, dim=0)
849
+ encoder_hidden_states = encoder_hidden_states.repeat_interleave(repeats=num_frames, dim=0)
850
+
851
+ # 2. pre-process
852
+ sample = sample.permute(0, 2, 1, 3, 4).reshape((sample.shape[0] * num_frames, -1) + sample.shape[3:])
853
+ sample = self.conv_in(sample)
854
+
855
+ if self.transformer_in_opt:
856
+
857
+ sample = self.transformer_in(
858
+ sample,
859
+ num_frames=num_frames,
860
+ cross_attention_kwargs=cross_attention_kwargs,
861
+ return_dict=False,
862
+ )[0]
863
+
864
+ # 3. down
865
+ down_block_res_samples = (sample,)
866
+ for downsample_block in self.down_blocks:
867
+ if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
868
+ sample, res_samples = downsample_block(
869
+ hidden_states=sample,
870
+ temb=emb,
871
+ encoder_hidden_states=encoder_hidden_states,
872
+ attention_mask=attention_mask,
873
+ num_frames=num_frames,
874
+ cross_attention_kwargs=cross_attention_kwargs,
875
+ encoder_attention_mask=encoder_attention_mask,
876
+ )
877
+ else:
878
+ sample, res_samples = downsample_block(hidden_states=sample, temb=emb, num_frames=num_frames)
879
+
880
+ down_block_res_samples += res_samples
881
+
882
+ if down_block_additional_residuals is not None:
883
+ new_down_block_res_samples = ()
884
+
885
+ for down_block_res_sample, down_block_additional_residual in zip(
886
+ down_block_res_samples, down_block_additional_residuals
887
+ ):
888
+ down_block_res_sample = down_block_res_sample + down_block_additional_residual
889
+ new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,)
890
+
891
+ down_block_res_samples = new_down_block_res_samples
892
+
893
+ # 4. mid
894
+ if self.mid_block is not None:
895
+ sample = self.mid_block(
896
+ sample,
897
+ emb,
898
+ encoder_hidden_states=encoder_hidden_states,
899
+ attention_mask=attention_mask,
900
+ num_frames=num_frames,
901
+ cross_attention_kwargs=cross_attention_kwargs,
902
+ encoder_attention_mask=encoder_attention_mask,
903
+ )
904
+
905
+ if mid_block_additional_residual is not None:
906
+ sample = sample + mid_block_additional_residual
907
+
908
+ # 5. up
909
+ for i, upsample_block in enumerate(self.up_blocks):
910
+ is_final_block = i == len(self.up_blocks) - 1
911
+
912
+ res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
913
+ down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
914
+
915
+ # if we have not reached the final block and need to forward the
916
+ # upsample size, we do it here
917
+ if not is_final_block and forward_upsample_size:
918
+ upsample_size = down_block_res_samples[-1].shape[2:]
919
+
920
+ if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
921
+ sample = upsample_block(
922
+ hidden_states=sample,
923
+ temb=emb,
924
+ res_hidden_states_tuple=res_samples,
925
+ encoder_hidden_states=encoder_hidden_states,
926
+ cross_attention_kwargs=cross_attention_kwargs,
927
+ upsample_size=upsample_size,
928
+ attention_mask=attention_mask,
929
+ num_frames=num_frames,
930
+ encoder_attention_mask=encoder_attention_mask,
931
+ )
932
+ else:
933
+ sample = upsample_block(
934
+ hidden_states=sample,
935
+ temb=emb,
936
+ res_hidden_states_tuple=res_samples,
937
+ upsample_size=upsample_size,
938
+ num_frames=num_frames,
939
+ )
940
+
941
+ # 6. post-process
942
+ if self.conv_norm_out:
943
+ sample = self.conv_norm_out(sample)
944
+ sample = self.conv_act(sample)
945
+ sample = self.conv_out(sample)
946
+
947
+ # reshape to (batch, channel, num_frames, height, width)
948
+ sample = sample[None, :].reshape((-1, num_frames) + sample.shape[1:]).permute(0, 2, 1, 3, 4)
949
+
950
+ if not return_dict:
951
+ return (sample,)
952
+
953
+ return UNet3DConditionOutput(sample=sample)
954
+
955
+ @classmethod
956
+ def from_pretrained_2d(cls, pretrained_model_path, subfolder=None):
957
+ import os, json
958
+ if subfolder is not None:
959
+ pretrained_model_path = os.path.join(pretrained_model_path, subfolder)
960
+
961
+ config_file = os.path.join(pretrained_model_path, 'config.json')
962
+ if not os.path.isfile(config_file):
963
+ raise RuntimeError(f"{config_file} does not exist")
964
+ with open(config_file, "r") as f:
965
+ config = json.load(f)
966
+ config["_class_name"] = cls.__name__
967
+
968
+ config["down_block_types"] = [x.replace("2D", "3D") for x in config["down_block_types"]]
969
+ if "mid_block_type" in config.keys():
970
+ config["mid_block_type"] = config["mid_block_type"].replace("2D", "3D")
971
+ config["up_block_types"] = [x.replace("2D", "3D") for x in config["up_block_types"]]
972
+
973
+ from diffusers.utils import WEIGHTS_NAME
974
+ model = cls.from_config(config)
975
+ model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME)
976
+ if not os.path.isfile(model_file):
977
+ raise RuntimeError(f"{model_file} does not exist")
978
+ state_dict = torch.load(model_file, map_location="cpu")
979
+ for k, v in model.state_dict().items():
980
+ if k not in state_dict:
981
+
982
+ state_dict.update({k: v})
983
+ model.load_state_dict(state_dict)
984
+
985
+ return model
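A minimal usage sketch for `from_pretrained_2d` above, assuming the `showone.models` export used by the pipelines below; the path is a placeholder for a local snapshot directory containing a 2D UNet `config.json` plus the weights file named by `diffusers.utils.WEIGHTS_NAME`:

from showone.models import UNet3DConditionModel

# Hypothetical local path to a 2D UNet checkpoint (e.g. a DeepFloyd IF snapshot).
unet3d = UNet3DConditionModel.from_pretrained_2d(
    "/path/to/2d-unet-checkpoint",
    subfolder="unet",
)
# 2D block types in the config are renamed to their 3D counterparts; parameters
# missing from the 2D state dict (the temporal layers) keep their fresh init.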
showone/pipelines/__init__.py ADDED
@@ -0,0 +1,37 @@
1
+ from dataclasses import dataclass
2
+ from typing import List, Union
3
+
4
+ import numpy as np
5
+ import torch
6
+
7
+ from diffusers.utils import BaseOutput, OptionalDependencyNotAvailable, is_torch_available, is_transformers_available
8
+
9
+
10
+ @dataclass
11
+ class TextToVideoPipelineOutput(BaseOutput):
12
+ """
13
+ Output class for text to video pipelines.
14
+
15
+ Args:
16
+ frames (`List[np.ndarray]` or `torch.FloatTensor`)
17
+ List of denoised frames (essentially images) as NumPy arrays of shape `(height, width, num_channels)` or as
18
+ a `torch` tensor. NumPy arrays represent the denoised images of the diffusion pipeline. The length of the list
19
+ denotes the video length, i.e., the number of frames.
20
+ """
21
+
22
+ frames: Union[List[np.ndarray], torch.FloatTensor]
23
+
24
+
25
+ try:
26
+ if not (is_transformers_available() and is_torch_available()):
27
+ raise OptionalDependencyNotAvailable()
28
+ except OptionalDependencyNotAvailable:
29
+ from diffusers.utils.dummy_torch_and_transformers_objects import * # noqa F403
30
+ else:
31
+ # from .pipeline_t2v_base_latent import TextToVideoSDPipeline # noqa: F401
32
+ # from .pipeline_t2v_base_latent_sdxl import TextToVideoSDXLPipeline
33
+ from .pipeline_t2v_base_pixel import TextToVideoIFPipeline
34
+ from .pipeline_t2v_interp_pixel import TextToVideoIFInterpPipeline
35
+ # from .pipeline_t2v_sr_latent import TextToVideoSDSuperResolutionPipeline
36
+ from .pipeline_t2v_sr_pixel import TextToVideoIFSuperResolutionPipeline
37
+ # from .pipeline_t2v_base_latent_controlnet import TextToVideoSDControlNetPipeline
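A small illustrative sketch of the `TextToVideoPipelineOutput` dataclass defined here; the frame count and resolution below are arbitrary placeholders:

import numpy as np
from showone.pipelines import TextToVideoPipelineOutput

# Eight dummy RGB frames of height 40 and width 64.
frames = [np.zeros((40, 64, 3), dtype=np.uint8) for _ in range(8)]
out = TextToVideoPipelineOutput(frames=frames)
print(len(out.frames), out.frames[0].shape)  # 8 (40, 64, 3)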
showone/pipelines/pipeline_t2v_base_pixel.py ADDED
@@ -0,0 +1,775 @@
1
+ import html
2
+ import inspect
3
+ import re
4
+ import urllib.parse as ul
5
+ from typing import Any, Callable, Dict, List, Optional, Union
6
+
7
+ import numpy as np
8
+ import torch
9
+ from transformers import CLIPImageProcessor, T5EncoderModel, T5Tokenizer
10
+
11
+ from diffusers.loaders import LoraLoaderMixin
12
+ from diffusers.schedulers import DDPMScheduler
13
+ from diffusers.utils import (
14
+ BACKENDS_MAPPING,
15
+ is_accelerate_available,
16
+ is_accelerate_version,
17
+ is_bs4_available,
18
+ is_ftfy_available,
19
+ logging,
20
+ randn_tensor,
21
+ )
22
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
23
+
24
+ from ..models import UNet3DConditionModel
25
+ from . import TextToVideoPipelineOutput
26
+
27
+
28
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
29
+
30
+ if is_bs4_available():
31
+ from bs4 import BeautifulSoup
32
+
33
+ if is_ftfy_available():
34
+ import ftfy
35
+
36
+
37
+ def tensor2vid(video: torch.Tensor, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) -> List[np.ndarray]:
38
+ # This code is copied from https://github.com/modelscope/modelscope/blob/1509fdb973e5871f37148a4b5e5964cafd43e64d/modelscope/pipelines/multi_modal/text_to_video_synthesis_pipeline.py#L78
39
+ # reshape to ncfhw
40
+ mean = torch.tensor(mean, device=video.device).reshape(1, -1, 1, 1, 1)
41
+ std = torch.tensor(std, device=video.device).reshape(1, -1, 1, 1, 1)
42
+ # unnormalize back to [0,1]
43
+ video = video.mul_(std).add_(mean)
44
+ video.clamp_(0, 1)
45
+ # prepare the final outputs
46
+ i, c, f, h, w = video.shape
47
+ images = video.permute(2, 3, 0, 4, 1).reshape(
48
+ f, h, i * w, c
49
+ ) # 1st (frames, h, batch_size, w, c) 2nd (frames, h, batch_size * w, c)
50
+ images = images.unbind(dim=0) # prepare a list of individual (consecutive) frames
51
+ images = [(image.cpu().numpy() * 255).astype("uint8") for image in images] # f h w c
52
+ return images
53
+
54
+
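+ # Shape walk-through for tensor2vid above (sizes are illustrative): a tensor of
+ # shape (batch=1, channels=3, frames=8, height=40, width=64) is de-normalized,
+ # permuted to (frames, height, batch, width, channels) and reshaped to
+ # (8, 40, 1 * 64, 3), so the returned list holds 8 uint8 arrays of shape
+ # (40, 64, 3), with batch items tiled along the width axis.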
55
+ class TextToVideoIFPipeline(DiffusionPipeline, LoraLoaderMixin):
56
+ tokenizer: T5Tokenizer
57
+ text_encoder: T5EncoderModel
58
+
59
+ unet: UNet3DConditionModel
60
+ scheduler: DDPMScheduler
61
+
62
+ feature_extractor: Optional[CLIPImageProcessor]
63
+ # safety_checker: Optional[IFSafetyChecker]
64
+
65
+ # watermarker: Optional[IFWatermarker]
66
+
67
+ bad_punct_regex = re.compile(
68
+ r"[" + "#®•©™&@·º½¾¿¡§~" + "\)" + "\(" + "\]" + "\[" + "\}" + "\{" + "\|" + "\\" + "\/" + "\*" + r"]{1,}"
69
+ ) # noqa
70
+
71
+ _optional_components = ["tokenizer", "text_encoder", "safety_checker", "feature_extractor", "watermarker"]
72
+
73
+ def __init__(
74
+ self,
75
+ tokenizer: T5Tokenizer,
76
+ text_encoder: T5EncoderModel,
77
+ unet: UNet3DConditionModel,
78
+ scheduler: DDPMScheduler,
79
+ feature_extractor: Optional[CLIPImageProcessor],
80
+ ):
81
+ super().__init__()
82
+
83
+ self.register_modules(
84
+ tokenizer=tokenizer,
85
+ text_encoder=text_encoder,
86
+ unet=unet,
87
+ scheduler=scheduler,
88
+ feature_extractor=feature_extractor,
89
+ )
90
+ self.safety_checker = None
91
+
92
+ def enable_sequential_cpu_offload(self, gpu_id=0):
93
+ r"""
94
+ Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, the pipeline's
95
+ models have their state dicts saved to CPU and then are moved to a `torch.device('meta')` and loaded to GPU only
96
+ when their specific submodule has its `forward` method called.
97
+ """
98
+ if is_accelerate_available():
99
+ from accelerate import cpu_offload
100
+ else:
101
+ raise ImportError("Please install accelerate via `pip install accelerate`")
102
+
103
+ device = torch.device(f"cuda:{gpu_id}")
104
+
105
+ models = [
106
+ self.text_encoder,
107
+ self.unet,
108
+ ]
109
+ for cpu_offloaded_model in models:
110
+ if cpu_offloaded_model is not None:
111
+ cpu_offload(cpu_offloaded_model, device)
112
+
113
+ if self.safety_checker is not None:
114
+ cpu_offload(self.safety_checker, execution_device=device, offload_buffers=True)
115
+
116
+ def enable_model_cpu_offload(self, gpu_id=0):
117
+ r"""
118
+ Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
119
+ to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
120
+ method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with
121
+ `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
122
+ """
123
+ if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
124
+ from accelerate import cpu_offload_with_hook
125
+ else:
126
+ raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")
127
+
128
+ device = torch.device(f"cuda:{gpu_id}")
129
+
130
+ self.unet.train()
131
+
132
+ if self.device.type != "cpu":
133
+ self.to("cpu", silence_dtype_warnings=True)
134
+ torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
135
+
136
+ hook = None
137
+
138
+ if self.text_encoder is not None:
139
+ _, hook = cpu_offload_with_hook(self.text_encoder, device, prev_module_hook=hook)
140
+
141
+ # Accelerate will move the next model to the device _before_ calling the offload hook of the
142
+ # previous model. This will cause both models to be present on the device at the same time.
143
+ # IF uses T5 for its text encoder which is really large. We can manually call the offload
144
+ # hook for the text encoder to ensure it's moved to the cpu before the unet is moved to
145
+ # the GPU.
146
+ self.text_encoder_offload_hook = hook
147
+
148
+ _, hook = cpu_offload_with_hook(self.unet, device, prev_module_hook=hook)
149
+
150
+ # if the safety checker isn't called, `unet_offload_hook` will have to be called to manually offload the unet
151
+ self.unet_offload_hook = hook
152
+
153
+ if self.safety_checker is not None:
154
+ _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook)
155
+
156
+ # We'll offload the last model manually.
157
+ self.final_offload_hook = hook
158
+
159
+ def remove_all_hooks(self):
160
+ if is_accelerate_available():
161
+ from accelerate.hooks import remove_hook_from_module
162
+ else:
163
+ raise ImportError("Please install accelerate via `pip install accelerate`")
164
+
165
+ for model in [self.text_encoder, self.unet, self.safety_checker]:
166
+ if model is not None:
167
+ remove_hook_from_module(model, recurse=True)
168
+
169
+ self.unet_offload_hook = None
170
+ self.text_encoder_offload_hook = None
171
+ self.final_offload_hook = None
172
+
173
+ @property
174
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device
175
+ def _execution_device(self):
176
+ r"""
177
+ Returns the device on which the pipeline's models will be executed. After calling
178
+ `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
179
+ hooks.
180
+ """
181
+ if not hasattr(self.unet, "_hf_hook"):
182
+ return self.device
183
+ for module in self.unet.modules():
184
+ if (
185
+ hasattr(module, "_hf_hook")
186
+ and hasattr(module._hf_hook, "execution_device")
187
+ and module._hf_hook.execution_device is not None
188
+ ):
189
+ return torch.device(module._hf_hook.execution_device)
190
+ return self.device
191
+
192
+ @torch.no_grad()
193
+ def encode_prompt(
194
+ self,
195
+ prompt,
196
+ do_classifier_free_guidance=True,
197
+ num_images_per_prompt=1,
198
+ device=None,
199
+ negative_prompt=None,
200
+ prompt_embeds: Optional[torch.FloatTensor] = None,
201
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
202
+ clean_caption: bool = False,
203
+ ):
204
+ r"""
205
+ Encodes the prompt into text encoder hidden states.
206
+
207
+ Args:
208
+ prompt (`str` or `List[str]`, *optional*):
209
+ prompt to be encoded
210
+ device: (`torch.device`, *optional*):
211
+ torch device to place the resulting embeddings on
212
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
213
+ number of images that should be generated per prompt
214
+ do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
215
+ whether to use classifier free guidance or not
216
+ negative_prompt (`str` or `List[str]`, *optional*):
217
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
218
+ `negative_prompt_embeds` instead.
219
+ Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`).
220
+ prompt_embeds (`torch.FloatTensor`, *optional*):
221
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
222
+ provided, text embeddings will be generated from `prompt` input argument.
223
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
224
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
225
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
226
+ argument.
227
+ """
228
+ if prompt is not None and negative_prompt is not None:
229
+ if type(prompt) is not type(negative_prompt):
230
+ raise TypeError(
231
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
232
+ f" {type(prompt)}."
233
+ )
234
+
235
+ if device is None:
236
+ device = self._execution_device
237
+
238
+ if prompt is not None and isinstance(prompt, str):
239
+ batch_size = 1
240
+ elif prompt is not None and isinstance(prompt, list):
241
+ batch_size = len(prompt)
242
+ else:
243
+ batch_size = prompt_embeds.shape[0]
244
+
245
+ # while T5 can handle much longer input sequences than 77, the text encoder was trained with a max length of 77 for IF
246
+ max_length = 77
247
+
248
+ if prompt_embeds is None:
249
+ prompt = self._text_preprocessing(prompt, clean_caption=clean_caption)
250
+ text_inputs = self.tokenizer(
251
+ prompt,
252
+ padding="max_length",
253
+ max_length=max_length,
254
+ truncation=True,
255
+ add_special_tokens=True,
256
+ return_tensors="pt",
257
+ )
258
+ text_input_ids = text_inputs.input_ids
259
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
260
+
261
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
262
+ text_input_ids, untruncated_ids
263
+ ):
264
+ removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_length - 1 : -1])
265
+ logger.warning(
266
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
267
+ f" {max_length} tokens: {removed_text}"
268
+ )
269
+
270
+ attention_mask = text_inputs.attention_mask.to(device)
271
+
272
+ prompt_embeds = self.text_encoder(
273
+ text_input_ids.to(device),
274
+ attention_mask=attention_mask,
275
+ )
276
+ prompt_embeds = prompt_embeds[0]
277
+
278
+ if self.text_encoder is not None:
279
+ dtype = self.text_encoder.dtype
280
+ elif self.unet is not None:
281
+ dtype = self.unet.dtype
282
+ else:
283
+ dtype = None
284
+
285
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
286
+
287
+ bs_embed, seq_len, _ = prompt_embeds.shape
288
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
289
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
290
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
291
+
292
+ # get unconditional embeddings for classifier free guidance
293
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
294
+ uncond_tokens: List[str]
295
+ if negative_prompt is None:
296
+ uncond_tokens = [""] * batch_size
297
+ elif isinstance(negative_prompt, str):
298
+ uncond_tokens = [negative_prompt]
299
+ elif batch_size != len(negative_prompt):
300
+ raise ValueError(
301
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
302
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
303
+ " the batch size of `prompt`."
304
+ )
305
+ else:
306
+ uncond_tokens = negative_prompt
307
+
308
+ uncond_tokens = self._text_preprocessing(uncond_tokens, clean_caption=clean_caption)
309
+ max_length = prompt_embeds.shape[1]
310
+ uncond_input = self.tokenizer(
311
+ uncond_tokens,
312
+ padding="max_length",
313
+ max_length=max_length,
314
+ truncation=True,
315
+ return_attention_mask=True,
316
+ add_special_tokens=True,
317
+ return_tensors="pt",
318
+ )
319
+ attention_mask = uncond_input.attention_mask.to(device)
320
+
321
+ negative_prompt_embeds = self.text_encoder(
322
+ uncond_input.input_ids.to(device),
323
+ attention_mask=attention_mask,
324
+ )
325
+ negative_prompt_embeds = negative_prompt_embeds[0]
326
+
327
+ if do_classifier_free_guidance:
328
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
329
+ seq_len = negative_prompt_embeds.shape[1]
330
+
331
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device)
332
+
333
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
334
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
335
+
336
+ # For classifier free guidance, we need to do two forward passes.
337
+ # Here we concatenate the unconditional and text embeddings into a single batch
338
+ # to avoid doing two forward passes
339
+ else:
340
+ negative_prompt_embeds = None
341
+
342
+ return prompt_embeds, negative_prompt_embeds
343
+
344
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
345
+ def prepare_extra_step_kwargs(self, generator, eta):
346
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
347
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
348
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
349
+ # and should be between [0, 1]
350
+
351
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
352
+ extra_step_kwargs = {}
353
+ if accepts_eta:
354
+ extra_step_kwargs["eta"] = eta
355
+
356
+ # check if the scheduler accepts generator
357
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
358
+ if accepts_generator:
359
+ extra_step_kwargs["generator"] = generator
360
+ return extra_step_kwargs
361
+
362
+ def check_inputs(
363
+ self,
364
+ prompt,
365
+ callback_steps,
366
+ negative_prompt=None,
367
+ prompt_embeds=None,
368
+ negative_prompt_embeds=None,
369
+ ):
370
+ if (callback_steps is None) or (
371
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
372
+ ):
373
+ raise ValueError(
374
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
375
+ f" {type(callback_steps)}."
376
+ )
377
+
378
+ if prompt is not None and prompt_embeds is not None:
379
+ raise ValueError(
380
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
381
+ " only forward one of the two."
382
+ )
383
+ elif prompt is None and prompt_embeds is None:
384
+ raise ValueError(
385
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
386
+ )
387
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
388
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
389
+
390
+ if negative_prompt is not None and negative_prompt_embeds is not None:
391
+ raise ValueError(
392
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
393
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
394
+ )
395
+
396
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
397
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
398
+ raise ValueError(
399
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
400
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
401
+ f" {negative_prompt_embeds.shape}."
402
+ )
403
+
404
+ def prepare_intermediate_images(self, batch_size, num_channels, num_frames, height, width, dtype, device, generator):
405
+ shape = (batch_size, num_channels, num_frames, height, width)
406
+ if isinstance(generator, list) and len(generator) != batch_size:
407
+ raise ValueError(
408
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
409
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
410
+ )
411
+
412
+ intermediate_images = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
413
+
414
+ # scale the initial noise by the standard deviation required by the scheduler
415
+ intermediate_images = intermediate_images * self.scheduler.init_noise_sigma
416
+ return intermediate_images
417
+
418
+ def _text_preprocessing(self, text, clean_caption=False):
419
+ if clean_caption and not is_bs4_available():
420
+ logger.warn(BACKENDS_MAPPING["bs4"][-1].format("Setting `clean_caption=True`"))
421
+ logger.warn("Setting `clean_caption` to False...")
422
+ clean_caption = False
423
+
424
+ if clean_caption and not is_ftfy_available():
425
+ logger.warn(BACKENDS_MAPPING["ftfy"][-1].format("Setting `clean_caption=True`"))
426
+ logger.warn("Setting `clean_caption` to False...")
427
+ clean_caption = False
428
+
429
+ if not isinstance(text, (tuple, list)):
430
+ text = [text]
431
+
432
+ def process(text: str):
433
+ if clean_caption:
434
+ text = self._clean_caption(text)
435
+ text = self._clean_caption(text)
436
+ else:
437
+ text = text.lower().strip()
438
+ return text
439
+
440
+ return [process(t) for t in text]
441
+
442
+ def _clean_caption(self, caption):
443
+ caption = str(caption)
444
+ caption = ul.unquote_plus(caption)
445
+ caption = caption.strip().lower()
446
+ caption = re.sub("<person>", "person", caption)
447
+ # urls:
448
+ caption = re.sub(
449
+ r"\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa
450
+ "",
451
+ caption,
452
+ ) # regex for urls
453
+ caption = re.sub(
454
+ r"\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa
455
+ "",
456
+ caption,
457
+ ) # regex for urls
458
+ # html:
459
+ caption = BeautifulSoup(caption, features="html.parser").text
460
+
461
+ # @<nickname>
462
+ caption = re.sub(r"@[\w\d]+\b", "", caption)
463
+
464
+ # 31C0—31EF CJK Strokes
465
+ # 31F0—31FF Katakana Phonetic Extensions
466
+ # 3200—32FF Enclosed CJK Letters and Months
467
+ # 3300—33FF CJK Compatibility
468
+ # 3400—4DBF CJK Unified Ideographs Extension A
469
+ # 4DC0—4DFF Yijing Hexagram Symbols
470
+ # 4E00—9FFF CJK Unified Ideographs
471
+ caption = re.sub(r"[\u31c0-\u31ef]+", "", caption)
472
+ caption = re.sub(r"[\u31f0-\u31ff]+", "", caption)
473
+ caption = re.sub(r"[\u3200-\u32ff]+", "", caption)
474
+ caption = re.sub(r"[\u3300-\u33ff]+", "", caption)
475
+ caption = re.sub(r"[\u3400-\u4dbf]+", "", caption)
476
+ caption = re.sub(r"[\u4dc0-\u4dff]+", "", caption)
477
+ caption = re.sub(r"[\u4e00-\u9fff]+", "", caption)
478
+ #######################################################
479
+
480
+ # все виды тире / all types of dash --> "-"
481
+ caption = re.sub(
482
+ r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+", # noqa
483
+ "-",
484
+ caption,
485
+ )
486
+
487
+ # кавычки к одному стандарту
488
+ caption = re.sub(r"[`´«»“”¨]", '"', caption)
489
+ caption = re.sub(r"[‘’]", "'", caption)
490
+
491
+ # &quot;
492
+ caption = re.sub(r"&quot;?", "", caption)
493
+ # &amp
494
+ caption = re.sub(r"&amp", "", caption)
495
+
496
+ # ip adresses:
497
+ caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption)
498
+
499
+ # article ids:
500
+ caption = re.sub(r"\d:\d\d\s+$", "", caption)
501
+
502
+ # \n
503
+ caption = re.sub(r"\\n", " ", caption)
504
+
505
+ # "#123"
506
+ caption = re.sub(r"#\d{1,3}\b", "", caption)
507
+ # "#12345.."
508
+ caption = re.sub(r"#\d{5,}\b", "", caption)
509
+ # "123456.."
510
+ caption = re.sub(r"\b\d{6,}\b", "", caption)
511
+ # filenames:
512
+ caption = re.sub(r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", caption)
513
+
514
+ #
515
+ caption = re.sub(r"[\"\']{2,}", r'"', caption) # """AUSVERKAUFT"""
516
+ caption = re.sub(r"[\.]{2,}", r" ", caption) # """AUSVERKAUFT"""
517
+
518
+ caption = re.sub(self.bad_punct_regex, r" ", caption) # ***AUSVERKAUFT***, #AUSVERKAUFT
519
+ caption = re.sub(r"\s+\.\s+", r" ", caption) # " . "
520
+
521
+ # this-is-my-cute-cat / this_is_my_cute_cat
522
+ regex2 = re.compile(r"(?:\-|\_)")
523
+ if len(re.findall(regex2, caption)) > 3:
524
+ caption = re.sub(regex2, " ", caption)
525
+
526
+ caption = ftfy.fix_text(caption)
527
+ caption = html.unescape(html.unescape(caption))
528
+
529
+ caption = re.sub(r"\b[a-zA-Z]{1,3}\d{3,15}\b", "", caption) # jc6640
530
+ caption = re.sub(r"\b[a-zA-Z]+\d+[a-zA-Z]+\b", "", caption) # jc6640vc
531
+ caption = re.sub(r"\b\d+[a-zA-Z]+\d+\b", "", caption) # 6640vc231
532
+
533
+ caption = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", caption)
534
+ caption = re.sub(r"(free\s)?download(\sfree)?", "", caption)
535
+ caption = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", caption)
536
+ caption = re.sub(r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", caption)
537
+ caption = re.sub(r"\bpage\s+\d+\b", "", caption)
538
+
539
+ caption = re.sub(r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", caption) # j2d1a2a...
540
+
541
+ caption = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", caption)
542
+
543
+ caption = re.sub(r"\b\s+\:\s+", r": ", caption)
544
+ caption = re.sub(r"(\D[,\./])\b", r"\1 ", caption)
545
+ caption = re.sub(r"\s+", " ", caption)
546
+
547
+ caption = caption.strip()
548
+
549
+ caption = re.sub(r"^[\"\']([\w\W]+)[\"\']$", r"\1", caption)
550
+ caption = re.sub(r"^[\'\_,\-\:;]", r"", caption)
551
+ caption = re.sub(r"[\'\_,\-\:\-\+]$", r"", caption)
552
+ caption = re.sub(r"^\.\S+$", "", caption)
553
+
554
+ return caption.strip()
555
+
556
+ @torch.no_grad()
557
+ def __call__(
558
+ self,
559
+ prompt: Union[str, List[str]] = None,
560
+ num_inference_steps: int = 100,
561
+ timesteps: List[int] = None,
562
+ guidance_scale: float = 7.0,
563
+ negative_prompt: Optional[Union[str, List[str]]] = None,
564
+ num_images_per_prompt: Optional[int] = 1,
565
+ height: Optional[int] = None,
566
+ width: Optional[int] = None,
567
+ num_frames: int = 16,
568
+ eta: float = 0.0,
569
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
570
+ prompt_embeds: Optional[torch.FloatTensor] = None,
571
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
572
+ output_type: Optional[str] = "np",
573
+ return_dict: bool = True,
574
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
575
+ callback_steps: int = 1,
576
+ clean_caption: bool = True,
577
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
578
+ ):
579
+ """
580
+ Function invoked when calling the pipeline for generation.
581
+
582
+ Args:
583
+ prompt (`str` or `List[str]`, *optional*):
584
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
585
+ instead.
586
+ num_inference_steps (`int`, *optional*, defaults to 100):
587
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
588
+ expense of slower inference.
589
+ timesteps (`List[int]`, *optional*):
590
+ Custom timesteps to use for the denoising process. If not defined, equal spaced `num_inference_steps`
591
+ timesteps are used. Must be in descending order.
592
+ guidance_scale (`float`, *optional*, defaults to 7.0):
593
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
594
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
595
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
596
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
597
+ usually at the expense of lower image quality.
598
+ negative_prompt (`str` or `List[str]`, *optional*):
599
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
600
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
601
+ less than `1`).
602
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
603
+ The number of images to generate per prompt.
604
+ height (`int`, *optional*, defaults to self.unet.config.sample_size):
605
+ The height in pixels of the generated image.
606
+ width (`int`, *optional*, defaults to self.unet.config.sample_size):
607
+ The width in pixels of the generated image.
608
+ eta (`float`, *optional*, defaults to 0.0):
609
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
610
+ [`schedulers.DDIMScheduler`], will be ignored for others.
611
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
612
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
613
+ to make generation deterministic.
614
+ prompt_embeds (`torch.FloatTensor`, *optional*):
615
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
616
+ provided, text embeddings will be generated from `prompt` input argument.
617
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
618
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
619
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
620
+ argument.
621
+ output_type (`str`, *optional*, defaults to `"np"`):
622
+ The output format of the generated video. Choose `"pt"` for the raw video tensor or any other
623
+ value for a list of NumPy frame arrays (via `tensor2vid`).
624
+ return_dict (`bool`, *optional*, defaults to `True`):
625
+ Whether or not to return a [`TextToVideoPipelineOutput`] instead of a plain tuple.
626
+ callback (`Callable`, *optional*):
627
+ A function that will be called every `callback_steps` steps during inference. The function will be
628
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
629
+ callback_steps (`int`, *optional*, defaults to 1):
630
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
631
+ called at every step.
632
+ clean_caption (`bool`, *optional*, defaults to `True`):
633
+ Whether or not to clean the caption before creating embeddings. Requires `beautifulsoup4` and `ftfy` to
634
+ be installed. If the dependencies are not installed, the embeddings will be created from the raw
635
+ prompt.
636
+ cross_attention_kwargs (`dict`, *optional*):
637
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
638
+ `self.processor` in
639
+ [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
640
+
641
+ Examples:
642
+
643
+ Returns:
644
+ [`TextToVideoPipelineOutput`] or `tuple`:
645
+ [`TextToVideoPipelineOutput`] if `return_dict` is True, otherwise a `tuple`. When returning a tuple,
646
+ the first element is the generated video, either a list of NumPy frame arrays or a `torch` tensor,
647
+ depending on `output_type`.
649
+ """
650
+ # 1. Check inputs. Raise error if not correct
651
+ self.check_inputs(prompt, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds)
652
+
653
+ # 2. Define call parameters
654
+ height = height or self.unet.config.sample_size
655
+ width = width or self.unet.config.sample_size
656
+
657
+ if prompt is not None and isinstance(prompt, str):
658
+ batch_size = 1
659
+ elif prompt is not None and isinstance(prompt, list):
660
+ batch_size = len(prompt)
661
+ else:
662
+ batch_size = prompt_embeds.shape[0]
663
+
664
+ device = self._execution_device
665
+
666
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
667
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
668
+ # corresponds to doing no classifier free guidance.
669
+ do_classifier_free_guidance = guidance_scale > 1.0
670
+
671
+ # 3. Encode input prompt
672
+ prompt_embeds, negative_prompt_embeds = self.encode_prompt(
673
+ prompt,
674
+ do_classifier_free_guidance,
675
+ num_images_per_prompt=num_images_per_prompt,
676
+ device=device,
677
+ negative_prompt=negative_prompt,
678
+ prompt_embeds=prompt_embeds,
679
+ negative_prompt_embeds=negative_prompt_embeds,
680
+ clean_caption=clean_caption,
681
+ )
682
+
683
+ if do_classifier_free_guidance:
684
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
685
+
686
+ # 4. Prepare timesteps
687
+ if timesteps is not None:
688
+ self.scheduler.set_timesteps(timesteps=timesteps, device=device)
689
+ timesteps = self.scheduler.timesteps
690
+ num_inference_steps = len(timesteps)
691
+ else:
692
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
693
+ timesteps = self.scheduler.timesteps
694
+
695
+ # 5. Prepare intermediate images
696
+ intermediate_images = self.prepare_intermediate_images(
697
+ batch_size * num_images_per_prompt,
698
+ self.unet.config.in_channels,
699
+ num_frames,
700
+ height,
701
+ width,
702
+ prompt_embeds.dtype,
703
+ device,
704
+ generator,
705
+ )
706
+
707
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
708
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
709
+
710
+ # HACK: see comment in `enable_model_cpu_offload`
711
+ if hasattr(self, "text_encoder_offload_hook") and self.text_encoder_offload_hook is not None:
712
+ self.text_encoder_offload_hook.offload()
713
+
714
+ # 7. Denoising loop
715
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
716
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
717
+ for i, t in enumerate(timesteps):
718
+ model_input = (
719
+ torch.cat([intermediate_images] * 2) if do_classifier_free_guidance else intermediate_images
720
+ )
721
+ model_input = self.scheduler.scale_model_input(model_input, t)
722
+
723
+ # predict the noise residual
724
+ noise_pred = self.unet(
725
+ model_input,
726
+ t,
727
+ encoder_hidden_states=prompt_embeds,
728
+ cross_attention_kwargs=cross_attention_kwargs,
729
+ ).sample
730
+
731
+ # perform guidance
732
+ if do_classifier_free_guidance:
733
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
734
+ noise_pred_uncond, _ = noise_pred_uncond.split(model_input.shape[1], dim=1)
735
+ noise_pred_text, predicted_variance = noise_pred_text.split(model_input.shape[1], dim=1)
736
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
737
+ noise_pred = torch.cat([noise_pred, predicted_variance], dim=1)
738
+
739
+ if self.scheduler.config.variance_type not in ["learned", "learned_range"]:
740
+ noise_pred, _ = noise_pred.split(model_input.shape[1], dim=1)
741
+
742
+ # reshape latents
743
+ bsz, channel, frames, height, width = intermediate_images.shape
744
+ intermediate_images = intermediate_images.permute(0, 2, 1, 3, 4).reshape(bsz * frames, channel, height, width)
745
+ noise_pred = noise_pred.permute(0, 2, 1, 3, 4).reshape(bsz * frames, -1, height, width)
746
+
747
+ # compute the previous noisy sample x_t -> x_t-1
748
+ intermediate_images = self.scheduler.step(
749
+ noise_pred, t, intermediate_images, **extra_step_kwargs
750
+ ).prev_sample
751
+
752
+ # reshape latents back
753
+ intermediate_images = intermediate_images[None, :].reshape(bsz, frames, channel, height, width).permute(0, 2, 1, 3, 4)
754
+
755
+ # call the callback, if provided
756
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
757
+ progress_bar.update()
758
+ if callback is not None and i % callback_steps == 0:
759
+ callback(i, t, intermediate_images)
760
+
761
+ video_tensor = intermediate_images
762
+
763
+ if output_type == "pt":
764
+ video = video_tensor
765
+ else:
766
+ video = tensor2vid(video_tensor)
767
+
768
+ # Offload last model to CPU
769
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
770
+ self.final_offload_hook.offload()
771
+
772
+ if not return_dict:
773
+ return (video,)
774
+
775
+ return TextToVideoPipelineOutput(frames=video)
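A minimal call sketch for the pipeline defined above, assuming `pipe` is an already-instantiated `TextToVideoIFPipeline`; every numeric setting below is an illustrative placeholder rather than a recommended configuration:

output = pipe(
    prompt="a panda playing a guitar on a beach",
    negative_prompt="blurry, watermark",
    num_frames=8,           # frame count of the generated clip
    height=40, width=64,    # pixel-space resolution of the base stage
    num_inference_steps=75,
    guidance_scale=9.0,
    output_type="pt",       # "pt" returns the raw video tensor; other values go through tensor2vid
)
video_tensor = output.frames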
showone/pipelines/pipeline_t2v_interp_pixel.py ADDED
@@ -0,0 +1,798 @@
1
+ import html
2
+ import inspect
3
+ import re
4
+ import urllib.parse as ul
5
+ from typing import Any, Callable, Dict, List, Optional, Union
6
+
7
+ import numpy as np
8
+ import torch
9
+ from transformers import CLIPImageProcessor, T5EncoderModel, T5Tokenizer
10
+
11
+ from diffusers.schedulers import DDPMScheduler
12
+ from diffusers.utils import (
13
+ BACKENDS_MAPPING,
14
+ is_accelerate_available,
15
+ is_accelerate_version,
16
+ is_bs4_available,
17
+ is_ftfy_available,
18
+ logging,
19
+ randn_tensor,
20
+ replace_example_docstring,
21
+ )
22
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
23
+
24
+ from ..models import UNet3DConditionModel
25
+ from . import TextToVideoPipelineOutput
26
+
27
+
28
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
29
+
30
+ if is_bs4_available():
31
+ from bs4 import BeautifulSoup
32
+
33
+ if is_ftfy_available():
34
+ import ftfy
35
+
36
+
37
+ def tensor2vid(video: torch.Tensor, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) -> List[np.ndarray]:
38
+ # This code is copied from https://github.com/modelscope/modelscope/blob/1509fdb973e5871f37148a4b5e5964cafd43e64d/modelscope/pipelines/multi_modal/text_to_video_synthesis_pipeline.py#L78
39
+ # reshape to ncfhw
40
+ mean = torch.tensor(mean, device=video.device).reshape(1, -1, 1, 1, 1)
41
+ std = torch.tensor(std, device=video.device).reshape(1, -1, 1, 1, 1)
42
+ # unnormalize back to [0,1]
43
+ video = video.mul_(std).add_(mean)
44
+ video.clamp_(0, 1)
45
+ # prepare the final outputs
46
+ i, c, f, h, w = video.shape
47
+ images = video.permute(2, 3, 0, 4, 1).reshape(
48
+ f, h, i * w, c
49
+ ) # 1st (frames, h, batch_size, w, c) 2nd (frames, h, batch_size * w, c)
50
+ images = images.unbind(dim=0) # prepare a list of individual (consecutive) frames
51
+ images = [(image.cpu().numpy() * 255).astype("uint8") for image in images] # f h w c
52
+ return images
53
+
54
+
55
+ class TextToVideoIFInterpPipeline(DiffusionPipeline):
56
+ tokenizer: T5Tokenizer
57
+ text_encoder: T5EncoderModel
58
+
59
+ unet: UNet3DConditionModel
60
+ scheduler: DDPMScheduler
61
+
62
+ feature_extractor: Optional[CLIPImageProcessor]
63
+ # safety_checker: Optional[IFSafetyChecker]
64
+
65
+ # watermarker: Optional[IFWatermarker]
66
+
67
+ bad_punct_regex = re.compile(
68
+ r"[" + "#®•©™&@·º½¾¿¡§~" + "\)" + "\(" + "\]" + "\[" + "\}" + "\{" + "\|" + "\\" + "\/" + "\*" + r"]{1,}"
69
+ ) # noqa
70
+
71
+ _optional_components = ["tokenizer", "text_encoder", "safety_checker", "feature_extractor", "watermarker"]
72
+
73
+ def __init__(
74
+ self,
75
+ tokenizer: T5Tokenizer,
76
+ text_encoder: T5EncoderModel,
77
+ unet: UNet3DConditionModel,
78
+ scheduler: DDPMScheduler,
79
+ feature_extractor: Optional[CLIPImageProcessor],
80
+ ):
81
+ super().__init__()
82
+
83
+ self.register_modules(
84
+ tokenizer=tokenizer,
85
+ text_encoder=text_encoder,
86
+ unet=unet,
87
+ scheduler=scheduler,
88
+ feature_extractor=feature_extractor,
89
+ )
90
+ self.safety_checker = None
91
+
92
+ def enable_sequential_cpu_offload(self, gpu_id=0):
93
+ r"""
94
+ Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, the pipeline's
95
+ models have their state dicts saved to CPU and then are moved to a `torch.device('meta')` and loaded to GPU only
96
+ when their specific submodule has its `forward` method called.
97
+ """
98
+ if is_accelerate_available():
99
+ from accelerate import cpu_offload
100
+ else:
101
+ raise ImportError("Please install accelerate via `pip install accelerate`")
102
+
103
+ device = torch.device(f"cuda:{gpu_id}")
104
+
105
+ models = [
106
+ self.text_encoder,
107
+ self.unet,
108
+ ]
109
+ for cpu_offloaded_model in models:
110
+ if cpu_offloaded_model is not None:
111
+ cpu_offload(cpu_offloaded_model, device)
112
+
113
+ if self.safety_checker is not None:
114
+ cpu_offload(self.safety_checker, execution_device=device, offload_buffers=True)
115
+
116
+ def enable_model_cpu_offload(self, gpu_id=0):
117
+ r"""
118
+ Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
119
+ to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
120
+ method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with
121
+ `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
122
+ """
123
+ if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
124
+ from accelerate import cpu_offload_with_hook
125
+ else:
126
+ raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")
127
+
128
+ device = torch.device(f"cuda:{gpu_id}")
129
+
130
+
131
+ if self.device.type != "cpu":
132
+ self.to("cpu", silence_dtype_warnings=True)
133
+ torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
134
+
135
+ hook = None
136
+
137
+ if self.text_encoder is not None:
138
+ _, hook = cpu_offload_with_hook(self.text_encoder, device, prev_module_hook=hook)
139
+
140
+ # Accelerate will move the next model to the device _before_ calling the offload hook of the
141
+ # previous model. This will cause both models to be present on the device at the same time.
142
+ # IF uses T5 for its text encoder which is really large. We can manually call the offload
143
+ # hook for the text encoder to ensure it's moved to the cpu before the unet is moved to
144
+ # the GPU.
145
+ self.text_encoder_offload_hook = hook
146
+
147
+ _, hook = cpu_offload_with_hook(self.unet, device, prev_module_hook=hook)
148
+
149
+ # if the safety checker isn't called, `unet_offload_hook` will have to be called to manually offload the unet
150
+ self.unet_offload_hook = hook
151
+
152
+ if self.safety_checker is not None:
153
+ _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook)
154
+
155
+ # We'll offload the last model manually.
156
+ self.final_offload_hook = hook
157
+
158
+ def remove_all_hooks(self):
159
+ if is_accelerate_available():
160
+ from accelerate.hooks import remove_hook_from_module
161
+ else:
162
+ raise ImportError("Please install accelerate via `pip install accelerate`")
163
+
164
+ for model in [self.text_encoder, self.unet, self.safety_checker]:
165
+ if model is not None:
166
+ remove_hook_from_module(model, recurse=True)
167
+
168
+ self.unet_offload_hook = None
169
+ self.text_encoder_offload_hook = None
170
+ self.final_offload_hook = None
171
+
172
+ @property
173
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device
174
+ def _execution_device(self):
175
+ r"""
176
+ Returns the device on which the pipeline's models will be executed. After calling
177
+ `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
178
+ hooks.
179
+ """
180
+ if not hasattr(self.unet, "_hf_hook"):
181
+ return self.device
182
+ for module in self.unet.modules():
183
+ if (
184
+ hasattr(module, "_hf_hook")
185
+ and hasattr(module._hf_hook, "execution_device")
186
+ and module._hf_hook.execution_device is not None
187
+ ):
188
+ return torch.device(module._hf_hook.execution_device)
189
+ return self.device
190
+
191
+ @torch.no_grad()
192
+ def encode_prompt(
193
+ self,
194
+ prompt,
195
+ do_classifier_free_guidance=True,
196
+ num_images_per_prompt=1,
197
+ device=None,
198
+ negative_prompt=None,
199
+ prompt_embeds: Optional[torch.FloatTensor] = None,
200
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
201
+ clean_caption: bool = False,
202
+ ):
203
+ r"""
204
+ Encodes the prompt into text encoder hidden states.
205
+
206
+ Args:
207
+ prompt (`str` or `List[str]`, *optional*):
208
+ prompt to be encoded
209
+ device: (`torch.device`, *optional*):
210
+ torch device to place the resulting embeddings on
211
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
212
+ number of images that should be generated per prompt
213
+ do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
214
+ whether to use classifier free guidance or not
215
+ negative_prompt (`str` or `List[str]`, *optional*):
216
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
217
+ `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead.
218
+ Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`).
219
+ prompt_embeds (`torch.FloatTensor`, *optional*):
220
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
221
+ provided, text embeddings will be generated from `prompt` input argument.
222
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
223
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
224
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
225
+ argument.
226
+ """
227
+ if prompt is not None and negative_prompt is not None:
228
+ if type(prompt) is not type(negative_prompt):
229
+ raise TypeError(
230
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
231
+ f" {type(prompt)}."
232
+ )
233
+
234
+ if device is None:
235
+ device = self._execution_device
236
+
237
+ if prompt is not None and isinstance(prompt, str):
238
+ batch_size = 1
239
+ elif prompt is not None and isinstance(prompt, list):
240
+ batch_size = len(prompt)
241
+ else:
242
+ batch_size = prompt_embeds.shape[0]
243
+
244
+ # while T5 can handle much longer input sequences than 77, the text encoder was trained with a max length of 77 for IF
245
+ max_length = 77
246
+
247
+ if prompt_embeds is None:
248
+ prompt = self._text_preprocessing(prompt, clean_caption=clean_caption)
249
+ text_inputs = self.tokenizer(
250
+ prompt,
251
+ padding="max_length",
252
+ max_length=max_length,
253
+ truncation=True,
254
+ add_special_tokens=True,
255
+ return_tensors="pt",
256
+ )
257
+ text_input_ids = text_inputs.input_ids
258
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
259
+
260
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
261
+ text_input_ids, untruncated_ids
262
+ ):
263
+ removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_length - 1 : -1])
264
+ logger.warning(
265
+ "The following part of your input was truncated because the text encoder can only handle sequences up to"
266
+ f" {max_length} tokens: {removed_text}"
267
+ )
268
+
269
+ attention_mask = text_inputs.attention_mask.to(device)
270
+
271
+ prompt_embeds = self.text_encoder(
272
+ text_input_ids.to(device),
273
+ attention_mask=attention_mask,
274
+ )
275
+ prompt_embeds = prompt_embeds[0]
276
+
277
+ if self.text_encoder is not None:
278
+ dtype = self.text_encoder.dtype
279
+ elif self.unet is not None:
280
+ dtype = self.unet.dtype
281
+ else:
282
+ dtype = None
283
+
284
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
285
+
286
+ bs_embed, seq_len, _ = prompt_embeds.shape
287
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
288
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
289
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
290
+
291
+ # get unconditional embeddings for classifier free guidance
292
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
293
+ uncond_tokens: List[str]
294
+ if negative_prompt is None:
295
+ uncond_tokens = [""] * batch_size
296
+ elif isinstance(negative_prompt, str):
297
+ uncond_tokens = [negative_prompt]
298
+ elif batch_size != len(negative_prompt):
299
+ raise ValueError(
300
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
301
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
302
+ " the batch size of `prompt`."
303
+ )
304
+ else:
305
+ uncond_tokens = negative_prompt
306
+
307
+ uncond_tokens = self._text_preprocessing(uncond_tokens, clean_caption=clean_caption)
308
+ max_length = prompt_embeds.shape[1]
309
+ uncond_input = self.tokenizer(
310
+ uncond_tokens,
311
+ padding="max_length",
312
+ max_length=max_length,
313
+ truncation=True,
314
+ return_attention_mask=True,
315
+ add_special_tokens=True,
316
+ return_tensors="pt",
317
+ )
318
+ attention_mask = uncond_input.attention_mask.to(device)
319
+
320
+ negative_prompt_embeds = self.text_encoder(
321
+ uncond_input.input_ids.to(device),
322
+ attention_mask=attention_mask,
323
+ )
324
+ negative_prompt_embeds = negative_prompt_embeds[0]
325
+
326
+ if do_classifier_free_guidance:
327
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
328
+ seq_len = negative_prompt_embeds.shape[1]
329
+
330
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device)
331
+
332
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
333
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
334
+
335
+ # For classifier free guidance, we need to do two forward passes.
336
+ # Here we concatenate the unconditional and text embeddings into a single batch
337
+ # to avoid doing two forward passes
338
+ else:
339
+ negative_prompt_embeds = None
340
+
341
+ return prompt_embeds, negative_prompt_embeds
342
+
343
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
344
+ def prepare_extra_step_kwargs(self, generator, eta):
345
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
346
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
347
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
348
+ # and should be between [0, 1]
349
+
350
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
351
+ extra_step_kwargs = {}
352
+ if accepts_eta:
353
+ extra_step_kwargs["eta"] = eta
354
+
355
+ # check if the scheduler accepts generator
356
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
357
+ if accepts_generator:
358
+ extra_step_kwargs["generator"] = generator
359
+ return extra_step_kwargs
360
+
361
+ def check_inputs(
362
+ self,
363
+ prompt,
364
+ callback_steps,
365
+ negative_prompt=None,
366
+ prompt_embeds=None,
367
+ negative_prompt_embeds=None,
368
+ ):
369
+ if (callback_steps is None) or (
370
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
371
+ ):
372
+ raise ValueError(
373
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
374
+ f" {type(callback_steps)}."
375
+ )
376
+
377
+ if prompt is not None and prompt_embeds is not None:
378
+ raise ValueError(
379
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
380
+ " only forward one of the two."
381
+ )
382
+ elif prompt is None and prompt_embeds is None:
383
+ raise ValueError(
384
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
385
+ )
386
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
387
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
388
+
389
+ if negative_prompt is not None and negative_prompt_embeds is not None:
390
+ raise ValueError(
391
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
392
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
393
+ )
394
+
395
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
396
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
397
+ raise ValueError(
398
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
399
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
400
+ f" {negative_prompt_embeds.shape}."
401
+ )
402
+
403
+ def prepare_intermediate_images(self, batch_size, num_channels, num_frames, height, width, dtype, device, generator):
404
+ shape = (batch_size, num_channels, num_frames, height, width)
405
+ if isinstance(generator, list) and len(generator) != batch_size:
406
+ raise ValueError(
407
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
408
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
409
+ )
410
+
411
+ intermediate_images = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
412
+
413
+ # scale the initial noise by the standard deviation required by the scheduler
414
+ intermediate_images = intermediate_images * self.scheduler.init_noise_sigma
415
+ return intermediate_images
416
+
417
+ def _text_preprocessing(self, text, clean_caption=False):
418
+ if clean_caption and not is_bs4_available():
419
+ logger.warn(BACKENDS_MAPPING["bs4"][-1].format("Setting `clean_caption=True`"))
420
+ logger.warn("Setting `clean_caption` to False...")
421
+ clean_caption = False
422
+
423
+ if clean_caption and not is_ftfy_available():
424
+ logger.warn(BACKENDS_MAPPING["ftfy"][-1].format("Setting `clean_caption=True`"))
425
+ logger.warn("Setting `clean_caption` to False...")
426
+ clean_caption = False
427
+
428
+ if not isinstance(text, (tuple, list)):
429
+ text = [text]
430
+
431
+ def process(text: str):
432
+ if clean_caption:
433
+ text = self._clean_caption(text)
434
+ text = self._clean_caption(text)
435
+ else:
436
+ text = text.lower().strip()
437
+ return text
438
+
439
+ return [process(t) for t in text]
440
+
441
+ def _clean_caption(self, caption):
442
+ caption = str(caption)
443
+ caption = ul.unquote_plus(caption)
444
+ caption = caption.strip().lower()
445
+ caption = re.sub("<person>", "person", caption)
446
+ # urls:
447
+ caption = re.sub(
448
+ r"\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa
449
+ "",
450
+ caption,
451
+ ) # regex for urls
452
+ caption = re.sub(
453
+ r"\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa
454
+ "",
455
+ caption,
456
+ ) # regex for urls
457
+ # html:
458
+ caption = BeautifulSoup(caption, features="html.parser").text
459
+
460
+ # @<nickname>
461
+ caption = re.sub(r"@[\w\d]+\b", "", caption)
462
+
463
+ # 31C0—31EF CJK Strokes
464
+ # 31F0—31FF Katakana Phonetic Extensions
465
+ # 3200—32FF Enclosed CJK Letters and Months
466
+ # 3300—33FF CJK Compatibility
467
+ # 3400—4DBF CJK Unified Ideographs Extension A
468
+ # 4DC0—4DFF Yijing Hexagram Symbols
469
+ # 4E00—9FFF CJK Unified Ideographs
470
+ caption = re.sub(r"[\u31c0-\u31ef]+", "", caption)
471
+ caption = re.sub(r"[\u31f0-\u31ff]+", "", caption)
472
+ caption = re.sub(r"[\u3200-\u32ff]+", "", caption)
473
+ caption = re.sub(r"[\u3300-\u33ff]+", "", caption)
474
+ caption = re.sub(r"[\u3400-\u4dbf]+", "", caption)
475
+ caption = re.sub(r"[\u4dc0-\u4dff]+", "", caption)
476
+ caption = re.sub(r"[\u4e00-\u9fff]+", "", caption)
477
+ #######################################################
478
+
479
+ # все виды тире / all types of dash --> "-"
480
+ caption = re.sub(
481
+ r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+", # noqa
482
+ "-",
483
+ caption,
484
+ )
485
+
486
+ # кавычки к одному стандарту
487
+ caption = re.sub(r"[`´«»“”¨]", '"', caption)
488
+ caption = re.sub(r"[‘’]", "'", caption)
489
+
490
+ # &quot;
491
+ caption = re.sub(r"&quot;?", "", caption)
492
+ # &amp
493
+ caption = re.sub(r"&amp", "", caption)
494
+
495
+ # ip addresses:
496
+ caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption)
497
+
498
+ # article ids:
499
+ caption = re.sub(r"\d:\d\d\s+$", "", caption)
500
+
501
+ # \n
502
+ caption = re.sub(r"\\n", " ", caption)
503
+
504
+ # "#123"
505
+ caption = re.sub(r"#\d{1,3}\b", "", caption)
506
+ # "#12345.."
507
+ caption = re.sub(r"#\d{5,}\b", "", caption)
508
+ # "123456.."
509
+ caption = re.sub(r"\b\d{6,}\b", "", caption)
510
+ # filenames:
511
+ caption = re.sub(r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", caption)
512
+
513
+ #
514
+ caption = re.sub(r"[\"\']{2,}", r'"', caption) # """AUSVERKAUFT"""
515
+ caption = re.sub(r"[\.]{2,}", r" ", caption) # """AUSVERKAUFT"""
516
+
517
+ caption = re.sub(self.bad_punct_regex, r" ", caption) # ***AUSVERKAUFT***, #AUSVERKAUFT
518
+ caption = re.sub(r"\s+\.\s+", r" ", caption) # " . "
519
+
520
+ # this-is-my-cute-cat / this_is_my_cute_cat
521
+ regex2 = re.compile(r"(?:\-|\_)")
522
+ if len(re.findall(regex2, caption)) > 3:
523
+ caption = re.sub(regex2, " ", caption)
524
+
525
+ caption = ftfy.fix_text(caption)
526
+ caption = html.unescape(html.unescape(caption))
527
+
528
+ caption = re.sub(r"\b[a-zA-Z]{1,3}\d{3,15}\b", "", caption) # jc6640
529
+ caption = re.sub(r"\b[a-zA-Z]+\d+[a-zA-Z]+\b", "", caption) # jc6640vc
530
+ caption = re.sub(r"\b\d+[a-zA-Z]+\d+\b", "", caption) # 6640vc231
531
+
532
+ caption = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", caption)
533
+ caption = re.sub(r"(free\s)?download(\sfree)?", "", caption)
534
+ caption = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", caption)
535
+ caption = re.sub(r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", caption)
536
+ caption = re.sub(r"\bpage\s+\d+\b", "", caption)
537
+
538
+ caption = re.sub(r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", caption) # j2d1a2a...
539
+
540
+ caption = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", caption)
541
+
542
+ caption = re.sub(r"\b\s+\:\s+", r": ", caption)
543
+ caption = re.sub(r"(\D[,\./])\b", r"\1 ", caption)
544
+ caption = re.sub(r"\s+", " ", caption)
545
+
546
+ caption = caption.strip()
547
+
548
+ caption = re.sub(r"^[\"\']([\w\W]+)[\"\']$", r"\1", caption)
549
+ caption = re.sub(r"^[\'\_,\-\:;]", r"", caption)
550
+ caption = re.sub(r"[\'\_,\-\:\-\+]$", r"", caption)
551
+ caption = re.sub(r"^\.\S+$", "", caption)
552
+
553
+ return caption.strip()
554
+
555
+ @torch.no_grad()
556
+ def __call__(
557
+ self,
558
+ pixel_values,
559
+ prompt: Union[str, List[str]] = None,
560
+ num_inference_steps: int = 100,
561
+ timesteps: List[int] = None,
562
+ guidance_scale: float = 7.0,
563
+ negative_prompt: Optional[Union[str, List[str]]] = None,
564
+ num_images_per_prompt: Optional[int] = 1,
565
+ height: Optional[int] = None,
566
+ width: Optional[int] = None,
567
+ num_frames: int = 16,
568
+ eta: float = 0.0,
569
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
570
+ prompt_embeds: Optional[torch.FloatTensor] = None,
571
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
572
+ output_type: Optional[str] = "np",
573
+ return_dict: bool = True,
574
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
575
+ callback_steps: int = 1,
576
+ clean_caption: bool = True,
577
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
578
+ init_noise=None,
579
+ cond_interpolation=False,
580
+ ):
581
+ """
582
+ Function invoked when calling the pipeline for generation.
583
+
584
+ Args:
585
+ prompt (`str` or `List[str]`, *optional*):
586
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
587
+ instead.
588
+ num_inference_steps (`int`, *optional*, defaults to 100):
589
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
590
+ expense of slower inference.
591
+ timesteps (`List[int]`, *optional*):
592
+ Custom timesteps to use for the denoising process. If not defined, equal spaced `num_inference_steps`
593
+ timesteps are used. Must be in descending order.
594
+ guidance_scale (`float`, *optional*, defaults to 7.0):
595
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
596
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
597
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
598
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
599
+ usually at the expense of lower image quality.
600
+ negative_prompt (`str` or `List[str]`, *optional*):
601
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
602
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
603
+ less than `1`).
604
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
605
+ The number of images to generate per prompt.
606
+ height (`int`, *optional*, defaults to self.unet.config.sample_size):
607
+ The height in pixels of the generated image.
608
+ width (`int`, *optional*, defaults to self.unet.config.sample_size):
609
+ The width in pixels of the generated image.
610
+ eta (`float`, *optional*, defaults to 0.0):
611
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
612
+ [`schedulers.DDIMScheduler`], will be ignored for others.
613
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
614
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
615
+ to make generation deterministic.
616
+ prompt_embeds (`torch.FloatTensor`, *optional*):
617
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
618
+ provided, text embeddings will be generated from `prompt` input argument.
619
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
620
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
621
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
622
+ argument.
623
+ output_type (`str`, *optional*, defaults to `"np"`):
624
+ The output format of the generated video. Choose between `"np"` (a list of `np.ndarray` frames)
625
+ and `"pt"` (a `torch.Tensor`).
626
+ return_dict (`bool`, *optional*, defaults to `True`):
627
+ Whether or not to return a [`~pipelines.stable_diffusion.IFPipelineOutput`] instead of a plain tuple.
628
+ callback (`Callable`, *optional*):
629
+ A function that will be called every `callback_steps` steps during inference. The function will be
630
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
631
+ callback_steps (`int`, *optional*, defaults to 1):
632
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
633
+ called at every step.
634
+ clean_caption (`bool`, *optional*, defaults to `True`):
635
+ Whether or not to clean the caption before creating embeddings. Requires `beautifulsoup4` and `ftfy` to
636
+ be installed. If the dependencies are not installed, the embeddings will be created from the raw
637
+ prompt.
638
+ cross_attention_kwargs (`dict`, *optional*):
639
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
640
+ `self.processor` in
641
+ [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
642
+
643
+ Examples:
644
+
645
+ Returns:
646
+ [`TextToVideoPipelineOutput`] or `tuple`:
647
+ [`TextToVideoPipelineOutput`] if `return_dict` is True, otherwise a `tuple`. When returning a tuple,
648
+ the first element is the generated video, either a list of `np.ndarray` frames (the default) or a
649
+ `torch.Tensor` when `output_type="pt"`.
650
+
651
+ """
652
+ # 1. Check inputs. Raise error if not correct
653
+ self.check_inputs(prompt, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds)
654
+
655
+ # 2. Define call parameters
656
+ height = height or self.unet.config.sample_size
657
+ width = width or self.unet.config.sample_size
658
+
659
+ if prompt is not None and isinstance(prompt, str):
660
+ batch_size = 1
661
+ elif prompt is not None and isinstance(prompt, list):
662
+ batch_size = len(prompt)
663
+ else:
664
+ batch_size = prompt_embeds.shape[0]
665
+
666
+ device = self._execution_device
667
+
668
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
669
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
670
+ # corresponds to doing no classifier free guidance.
671
+ do_classifier_free_guidance = guidance_scale > 1.0
672
+
673
+ # 3. Encode input prompt
674
+ prompt_embeds, negative_prompt_embeds = self.encode_prompt(
675
+ prompt,
676
+ do_classifier_free_guidance,
677
+ num_images_per_prompt=num_images_per_prompt,
678
+ device=device,
679
+ negative_prompt=negative_prompt,
680
+ prompt_embeds=prompt_embeds,
681
+ negative_prompt_embeds=negative_prompt_embeds,
682
+ clean_caption=clean_caption,
683
+ )
684
+
685
+ if do_classifier_free_guidance:
686
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
687
+
688
+ # 4. Prepare timesteps
689
+ if timesteps is not None:
690
+ self.scheduler.set_timesteps(timesteps=timesteps, device=device)
691
+ timesteps = self.scheduler.timesteps
692
+ num_inference_steps = len(timesteps)
693
+ else:
694
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
695
+ timesteps = self.scheduler.timesteps
696
+
697
+ # 5. Prepare intermediate images
698
+ pixel_values = pixel_values.to(device)
699
+ if init_noise is not None:
700
+ intermediate_images = init_noise
701
+ else:
702
+ intermediate_images = self.prepare_intermediate_images(
703
+ batch_size * num_images_per_prompt,
704
+ # self.unet.config.in_channels, # not used: in_channels also counts the conditioning and mask channels; the noise latents only need the image channels of pixel_values
705
+ pixel_values.shape[1],
706
+ num_frames,
707
+ height,
708
+ width,
709
+ prompt_embeds.dtype,
710
+ device,
711
+ generator,
712
+ )
713
+
714
+ bsz = intermediate_images.shape[0]
715
+ interp_mask = torch.zeros(bsz, 1, *intermediate_images.shape[2:], device=device, dtype=intermediate_images.dtype)
716
+ interp_mask[:, :, 0, :, :] = 1
717
+ interp_mask[:, :, -1, :, :] = 1
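+ # i.e. only the first and last frames are given as conditioning keyframes; the
+ # frames in between are the ones the interpolation model has to generate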
718
+
719
+ if cond_interpolation:
720
+ import torch.nn.functional as F
721
+ pixel_values = F.interpolate(pixel_values[:, :, [0, -1], ...], pixel_values.shape[2:],
722
+ mode="trilinear", align_corners=True)
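+ # i.e. the two keyframes are blended linearly across time (trilinear resize) to
+ # give a coarse, full-length conditioning clip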
723
+ else:
724
+ raise Exception("cond_interpolation=False is not supported here: `pixel_values` must first be masked to the keyframes")
725
+
726
+ # intermediate_images[:, :, 0, :, :] = pixel_values[:, :, 0, :, :]
727
+ # intermediate_images[:, :, -1, :, :] = pixel_values[:, :, -1, :, :]
728
+ pixel_values_condition = torch.cat((pixel_values, interp_mask), dim=1)
729
+
730
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
731
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
732
+
733
+ # HACK: see comment in `enable_model_cpu_offload`
734
+ if hasattr(self, "text_encoder_offload_hook") and self.text_encoder_offload_hook is not None:
735
+ self.text_encoder_offload_hook.offload()
736
+
737
+ # 7. Denoising loop
738
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
739
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
740
+ for i, t in enumerate(timesteps):
741
+ intermediate_images_input = torch.cat((intermediate_images, pixel_values_condition), dim=1)
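+ # the UNet sees the noisy video latents concatenated channel-wise with the
+ # conditioning clip and the keyframe mask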
742
+ model_input = (
743
+ torch.cat([intermediate_images_input] * 2) if do_classifier_free_guidance else intermediate_images_input
744
+ )
745
+ model_input = self.scheduler.scale_model_input(model_input, t)
746
+
747
+ # predict the noise residual
748
+ noise_pred = self.unet(
749
+ model_input,
750
+ t,
751
+ encoder_hidden_states=prompt_embeds,
752
+ cross_attention_kwargs=cross_attention_kwargs,
753
+ ).sample
754
+ # perform guidance
755
+ if do_classifier_free_guidance:
756
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
757
+ noise_pred_uncond, _ = noise_pred_uncond.split(intermediate_images.shape[1], dim=1)
758
+ noise_pred_text, predicted_variance = noise_pred_text.split(intermediate_images.shape[1], dim=1)
759
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
760
+ noise_pred = torch.cat([noise_pred, predicted_variance], dim=1)
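+ # guidance is applied to the predicted noise only; the learned-variance channels
+ # from the conditional branch are re-attached for the scheduler step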
761
+
762
+ if self.scheduler.config.variance_type not in ["learned", "learned_range"]:
763
+ noise_pred, _ = noise_pred.split(intermediate_images.shape[1], dim=1)
764
+
765
+ # reshape latents
766
+ bsz, channel, frames, width, height = intermediate_images.shape
767
+ intermediate_images = intermediate_images.permute(0, 2, 1, 3, 4).reshape(bsz * frames, channel, width, height)
768
+ noise_pred = noise_pred.permute(0, 2, 1, 3, 4).reshape(bsz * frames, -1, width, height)
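+ # frames are folded into the batch dimension so the image-space scheduler can
+ # update every frame independently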
769
+
770
+ # compute the previous noisy sample x_t -> x_t-1
771
+ intermediate_images = self.scheduler.step(
772
+ noise_pred, t, intermediate_images, **extra_step_kwargs
773
+ ).prev_sample
774
+
775
+ # reshape latents back
776
+ intermediate_images = intermediate_images[None, :].reshape(bsz, frames, channel, width, height).permute(0, 2, 1, 3, 4)
777
+
778
+ # call the callback, if provided
779
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
780
+ progress_bar.update()
781
+ if callback is not None and i % callback_steps == 0:
782
+ callback(i, t, intermediate_images)
783
+
784
+ video_tensor = intermediate_images
785
+
786
+ if output_type == "pt":
787
+ video = video_tensor
788
+ else:
789
+ video = tensor2vid(video_tensor)
790
+
791
+ # Offload last model to CPU
792
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
793
+ self.final_offload_hook.offload()
794
+
795
+ if not return_dict:
796
+ return (video,)
797
+
798
+ return TextToVideoPipelineOutput(frames=video)
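Note (illustration only, not part of this commit): a minimal usage sketch for the interpolation pipeline above. It assumes `pipe` is an already-loaded `TextToVideoIFInterpPipeline` and `keyframes` is a pixel-space video tensor of shape (1, 3, 16, 64, 64) with values in [-1, 1], already on the pipeline's device/dtype, whose first and last frames are the two keyframes to interpolate between:

import torch

frames = pipe(
    pixel_values=keyframes,                      # (b, 3, f, h, w) conditioning clip
    prompt="a panda playing a guitar",
    height=keyframes.shape[3],
    width=keyframes.shape[4],
    num_frames=keyframes.shape[2],
    num_inference_steps=50,
    guidance_scale=7.0,
    cond_interpolation=True,                     # blend first/last keyframes over time
    generator=torch.Generator().manual_seed(0),
    output_type="np",                            # list of uint8 frames via tensor2vid
).frames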
showone/pipelines/pipeline_t2v_sr_pixel.py ADDED
@@ -0,0 +1,877 @@
1
+ import html
2
+ import inspect
3
+ import re
4
+ import urllib.parse as ul
5
+ from typing import Any, Callable, Dict, List, Optional, Union
6
+
7
+ import numpy as np
8
+ from einops import rearrange
9
+ import PIL
10
+ import torch
11
+ import torch.nn.functional as F
12
+ from transformers import CLIPImageProcessor, T5EncoderModel, T5Tokenizer
13
+
14
+ from diffusers.loaders import LoraLoaderMixin
15
+ from diffusers.schedulers import DDPMScheduler
16
+ from diffusers.utils import (
17
+ BACKENDS_MAPPING,
18
+ is_accelerate_available,
19
+ is_accelerate_version,
20
+ is_bs4_available,
21
+ is_ftfy_available,
22
+ logging,
23
+ randn_tensor,
24
+ )
25
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
26
+
27
+ from ..models import UNet3DConditionModel
28
+ from . import TextToVideoPipelineOutput
29
+
30
+
31
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
32
+
33
+ if is_bs4_available():
34
+ from bs4 import BeautifulSoup
35
+
36
+ if is_ftfy_available():
37
+ import ftfy
38
+
39
+
40
+ def tensor2vid(video: torch.Tensor, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) -> List[np.ndarray]:
41
+ # This code is copied from https://github.com/modelscope/modelscope/blob/1509fdb973e5871f37148a4b5e5964cafd43e64d/modelscope/pipelines/multi_modal/text_to_video_synthesis_pipeline.py#L78
42
+ # reshape to ncfhw
43
+ mean = torch.tensor(mean, device=video.device).reshape(1, -1, 1, 1, 1)
44
+ std = torch.tensor(std, device=video.device).reshape(1, -1, 1, 1, 1)
45
+ # unnormalize back to [0,1]
46
+ video = video.mul_(std).add_(mean)
47
+ video.clamp_(0, 1)
48
+ # prepare the final outputs
49
+ i, c, f, h, w = video.shape
50
+ images = video.permute(2, 3, 0, 4, 1).reshape(
51
+ f, h, i * w, c
52
+ ) # 1st (frames, h, batch_size, w, c) 2nd (frames, h, batch_size * w, c)
53
+ images = images.unbind(dim=0) # prepare a list of individual (consecutive) frames
54
+ images = [(image.cpu().numpy() * 255).astype("uint8") for image in images] # f h w c
55
+ return images
56
+
57
+
58
+ class TextToVideoIFSuperResolutionPipeline(DiffusionPipeline, LoraLoaderMixin):
59
+ tokenizer: T5Tokenizer
60
+ text_encoder: T5EncoderModel
61
+
62
+ unet: UNet3DConditionModel
63
+ scheduler: DDPMScheduler
64
+ image_noising_scheduler: DDPMScheduler
65
+
66
+ feature_extractor: Optional[CLIPImageProcessor]
67
+ # safety_checker: Optional[IFSafetyChecker]
68
+
69
+ # watermarker: Optional[IFWatermarker]
70
+
71
+ bad_punct_regex = re.compile(
72
+ r"[" + "#®•©™&@·º½¾¿¡§~" + "\)" + "\(" + "\]" + "\[" + "\}" + "\{" + "\|" + "\\" + "\/" + "\*" + r"]{1,}"
73
+ ) # noqa
74
+
75
+ _optional_components = ["tokenizer", "text_encoder", "safety_checker", "feature_extractor", "watermarker"]
76
+
77
+ def __init__(
78
+ self,
79
+ tokenizer: T5Tokenizer,
80
+ text_encoder: T5EncoderModel,
81
+ unet: UNet3DConditionModel,
82
+ scheduler: DDPMScheduler,
83
+ image_noising_scheduler: DDPMScheduler,
84
+ feature_extractor: Optional[CLIPImageProcessor],
85
+ ):
86
+ super().__init__()
87
+
88
+ self.register_modules(
89
+ tokenizer=tokenizer,
90
+ text_encoder=text_encoder,
91
+ unet=unet,
92
+ scheduler=scheduler,
93
+ image_noising_scheduler=image_noising_scheduler,
94
+ feature_extractor=feature_extractor,
95
+ )
96
+ self.safety_checker = None
97
+
98
+ def enable_sequential_cpu_offload(self, gpu_id=0):
99
+ r"""
100
+ Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, the pipeline's
101
+ models have their state dicts saved to CPU and then are moved to a `torch.device('meta')` and loaded to GPU only
102
+ when their specific submodule has its `forward` method called.
103
+ """
104
+ if is_accelerate_available():
105
+ from accelerate import cpu_offload
106
+ else:
107
+ raise ImportError("Please install accelerate via `pip install accelerate`")
108
+
109
+ device = torch.device(f"cuda:{gpu_id}")
110
+
111
+ models = [
112
+ self.text_encoder,
113
+ self.unet,
114
+ ]
115
+ for cpu_offloaded_model in models:
116
+ if cpu_offloaded_model is not None:
117
+ cpu_offload(cpu_offloaded_model, device)
118
+
119
+ if self.safety_checker is not None:
120
+ cpu_offload(self.safety_checker, execution_device=device, offload_buffers=True)
121
+
122
+ def enable_model_cpu_offload(self, gpu_id=0):
123
+ r"""
124
+ Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
125
+ to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
126
+ method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with
127
+ `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
128
+ """
129
+ if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
130
+ from accelerate import cpu_offload_with_hook
131
+ else:
132
+ raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")
133
+
134
+ device = torch.device(f"cuda:{gpu_id}")
135
+
136
+ if self.device.type != "cpu":
137
+ self.to("cpu", silence_dtype_warnings=True)
138
+ torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
139
+
140
+ hook = None
141
+
142
+ if self.text_encoder is not None:
143
+ _, hook = cpu_offload_with_hook(self.text_encoder, device, prev_module_hook=hook)
144
+
145
+ # Accelerate will move the next model to the device _before_ calling the offload hook of the
146
+ # previous model. This will cause both models to be present on the device at the same time.
147
+ # IF uses T5 for its text encoder which is really large. We can manually call the offload
148
+ # hook for the text encoder to ensure it's moved to the cpu before the unet is moved to
149
+ # the GPU.
150
+ self.text_encoder_offload_hook = hook
151
+
152
+ _, hook = cpu_offload_with_hook(self.unet, device, prev_module_hook=hook)
153
+
154
+ # if the safety checker isn't called, `unet_offload_hook` will have to be called to manually offload the unet
155
+ self.unet_offload_hook = hook
156
+
157
+ if self.safety_checker is not None:
158
+ _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook)
159
+
160
+ # We'll offload the last model manually.
161
+ self.final_offload_hook = hook
162
+
163
+ def remove_all_hooks(self):
164
+ if is_accelerate_available():
165
+ from accelerate.hooks import remove_hook_from_module
166
+ else:
167
+ raise ImportError("Please install accelerate via `pip install accelerate`")
168
+
169
+ for model in [self.text_encoder, self.unet, self.safety_checker]:
170
+ if model is not None:
171
+ remove_hook_from_module(model, recurse=True)
172
+
173
+ self.unet_offload_hook = None
174
+ self.text_encoder_offload_hook = None
175
+ self.final_offload_hook = None
176
+
177
+ @property
178
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device
179
+ def _execution_device(self):
180
+ r"""
181
+ Returns the device on which the pipeline's models will be executed. After calling
182
+ `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
183
+ hooks.
184
+ """
185
+ if not hasattr(self.unet, "_hf_hook"):
186
+ return self.device
187
+ for module in self.unet.modules():
188
+ if (
189
+ hasattr(module, "_hf_hook")
190
+ and hasattr(module._hf_hook, "execution_device")
191
+ and module._hf_hook.execution_device is not None
192
+ ):
193
+ return torch.device(module._hf_hook.execution_device)
194
+ return self.device
195
+
196
+ @torch.no_grad()
197
+ def encode_prompt(
198
+ self,
199
+ prompt,
200
+ do_classifier_free_guidance=True,
201
+ num_images_per_prompt=1,
202
+ device=None,
203
+ negative_prompt=None,
204
+ prompt_embeds: Optional[torch.FloatTensor] = None,
205
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
206
+ clean_caption: bool = False,
207
+ ):
208
+ r"""
209
+ Encodes the prompt into text encoder hidden states.
210
+
211
+ Args:
212
+ prompt (`str` or `List[str]`, *optional*):
213
+ prompt to be encoded
214
+ device: (`torch.device`, *optional*):
215
+ torch device to place the resulting embeddings on
216
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
217
+ number of images that should be generated per prompt
218
+ do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
219
+ whether to use classifier free guidance or not
220
+ negative_prompt (`str` or `List[str]`, *optional*):
221
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
222
+ `negative_prompt_embeds` instead.
223
+ Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`).
224
+ prompt_embeds (`torch.FloatTensor`, *optional*):
225
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
226
+ provided, text embeddings will be generated from `prompt` input argument.
227
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
228
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
229
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
230
+ argument.
231
+ """
232
+ if prompt is not None and negative_prompt is not None:
233
+ if type(prompt) is not type(negative_prompt):
234
+ raise TypeError(
235
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
236
+ f" {type(prompt)}."
237
+ )
238
+
239
+ if device is None:
240
+ device = self._execution_device
241
+
242
+ if prompt is not None and isinstance(prompt, str):
243
+ batch_size = 1
244
+ elif prompt is not None and isinstance(prompt, list):
245
+ batch_size = len(prompt)
246
+ else:
247
+ batch_size = prompt_embeds.shape[0]
248
+
249
+ # while T5 can handle much longer input sequences than 77, the text encoder was trained with a max length of 77 for IF
250
+ max_length = 77
251
+
252
+ if prompt_embeds is None:
253
+ prompt = self._text_preprocessing(prompt, clean_caption=clean_caption)
254
+ text_inputs = self.tokenizer(
255
+ prompt,
256
+ padding="max_length",
257
+ max_length=max_length,
258
+ truncation=True,
259
+ add_special_tokens=True,
260
+ return_tensors="pt",
261
+ )
262
+ text_input_ids = text_inputs.input_ids
263
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
264
+
265
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
266
+ text_input_ids, untruncated_ids
267
+ ):
268
+ removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_length - 1 : -1])
269
+ logger.warning(
270
+ "The following part of your input was truncated because the text encoder can only handle sequences up to"
271
+ f" {max_length} tokens: {removed_text}"
272
+ )
273
+
274
+ attention_mask = text_inputs.attention_mask.to(device)
275
+
276
+ prompt_embeds = self.text_encoder(
277
+ text_input_ids.to(device),
278
+ attention_mask=attention_mask,
279
+ )
280
+ prompt_embeds = prompt_embeds[0]
281
+
282
+ if self.text_encoder is not None:
283
+ dtype = self.text_encoder.dtype
284
+ elif self.unet is not None:
285
+ dtype = self.unet.dtype
286
+ else:
287
+ dtype = None
288
+
289
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
290
+
291
+ bs_embed, seq_len, _ = prompt_embeds.shape
292
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
293
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
294
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
295
+
296
+ # get unconditional embeddings for classifier free guidance
297
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
298
+ uncond_tokens: List[str]
299
+ if negative_prompt is None:
300
+ uncond_tokens = [""] * batch_size
301
+ elif isinstance(negative_prompt, str):
302
+ uncond_tokens = [negative_prompt]
303
+ elif batch_size != len(negative_prompt):
304
+ raise ValueError(
305
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
306
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
307
+ " the batch size of `prompt`."
308
+ )
309
+ else:
310
+ uncond_tokens = negative_prompt
311
+
312
+ uncond_tokens = self._text_preprocessing(uncond_tokens, clean_caption=clean_caption)
313
+ max_length = prompt_embeds.shape[1]
314
+ uncond_input = self.tokenizer(
315
+ uncond_tokens,
316
+ padding="max_length",
317
+ max_length=max_length,
318
+ truncation=True,
319
+ return_attention_mask=True,
320
+ add_special_tokens=True,
321
+ return_tensors="pt",
322
+ )
323
+ attention_mask = uncond_input.attention_mask.to(device)
324
+
325
+ negative_prompt_embeds = self.text_encoder(
326
+ uncond_input.input_ids.to(device),
327
+ attention_mask=attention_mask,
328
+ )
329
+ negative_prompt_embeds = negative_prompt_embeds[0]
330
+
331
+ if do_classifier_free_guidance:
332
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
333
+ seq_len = negative_prompt_embeds.shape[1]
334
+
335
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device)
336
+
337
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
338
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
339
+
340
+ # For classifier free guidance, we need to do two forward passes.
341
+ # Here we concatenate the unconditional and text embeddings into a single batch
342
+ # to avoid doing two forward passes
343
+ else:
344
+ negative_prompt_embeds = None
345
+
346
+ return prompt_embeds, negative_prompt_embeds
347
+
348
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
349
+ def prepare_extra_step_kwargs(self, generator, eta):
350
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
351
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
352
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
353
+ # and should be between [0, 1]
354
+
355
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
356
+ extra_step_kwargs = {}
357
+ if accepts_eta:
358
+ extra_step_kwargs["eta"] = eta
359
+
360
+ # check if the scheduler accepts generator
361
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
362
+ if accepts_generator:
363
+ extra_step_kwargs["generator"] = generator
364
+ return extra_step_kwargs
365
+
366
+ def check_inputs(
367
+ self,
368
+ prompt,
369
+ image,
370
+ batch_size,
371
+ noise_level,
372
+ callback_steps,
373
+ negative_prompt=None,
374
+ prompt_embeds=None,
375
+ negative_prompt_embeds=None,
376
+ ):
377
+ if (callback_steps is None) or (
378
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
379
+ ):
380
+ raise ValueError(
381
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
382
+ f" {type(callback_steps)}."
383
+ )
384
+
385
+ if prompt is not None and prompt_embeds is not None:
386
+ raise ValueError(
387
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
388
+ " only forward one of the two."
389
+ )
390
+ elif prompt is None and prompt_embeds is None:
391
+ raise ValueError(
392
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
393
+ )
394
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
395
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
396
+
397
+ if negative_prompt is not None and negative_prompt_embeds is not None:
398
+ raise ValueError(
399
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
400
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
401
+ )
402
+
403
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
404
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
405
+ raise ValueError(
406
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
407
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
408
+ f" {negative_prompt_embeds.shape}."
409
+ )
410
+
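+ # noise_level must index a valid timestep of image_noising_scheduler, which noise-augments
+ # the low-resolution conditioning frames before super-resolution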
411
+ if noise_level < 0 or noise_level >= self.image_noising_scheduler.config.num_train_timesteps:
412
+ raise ValueError(
413
+ f"`noise_level`: {noise_level} must be a valid timestep in `self.noising_scheduler`, [0, {self.image_noising_scheduler.config.num_train_timesteps})"
414
+ )
415
+
416
+ if isinstance(image, list):
417
+ check_image_type = image[0]
418
+ else:
419
+ check_image_type = image
420
+
421
+ if (
422
+ not isinstance(check_image_type, torch.Tensor)
423
+ and not isinstance(check_image_type, PIL.Image.Image)
424
+ and not isinstance(check_image_type, np.ndarray)
425
+ ):
426
+ raise ValueError(
427
+ "`image` has to be of type `torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, or List[...] but is"
428
+ f" {type(check_image_type)}"
429
+ )
430
+
431
+ if isinstance(image, list):
432
+ image_batch_size = len(image)
433
+ elif isinstance(image, torch.Tensor):
434
+ image_batch_size = image.shape[0]
435
+ elif isinstance(image, PIL.Image.Image):
436
+ image_batch_size = 1
437
+ elif isinstance(image, np.ndarray):
438
+ image_batch_size = image.shape[0]
439
+ else:
440
+ assert False
441
+
442
+ if batch_size != image_batch_size:
443
+ raise ValueError(f"image batch size: {image_batch_size} must be same as prompt batch size {batch_size}")
444
+
445
+ def prepare_intermediate_images(self, batch_size, num_channels, num_frames, height, width, dtype, device, generator):
446
+ shape = (batch_size, num_channels, num_frames, height, width)
447
+ if isinstance(generator, list) and len(generator) != batch_size:
448
+ raise ValueError(
449
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
450
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
451
+ )
452
+
453
+ intermediate_images = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
454
+
455
+ # scale the initial noise by the standard deviation required by the scheduler
456
+ intermediate_images = intermediate_images * self.scheduler.init_noise_sigma
457
+ return intermediate_images
458
+
459
+ def preprocess_image(self, image, num_images_per_prompt, device):
460
+ if not isinstance(image, torch.Tensor) and not isinstance(image, list):
461
+ image = [image]
462
+
463
+ if isinstance(image[0], PIL.Image.Image):
464
+ image = [np.array(i).astype(np.float32) / 255.0 for i in image]
465
+
466
+ image = np.stack(image, axis=0) # to np
467
+ image = torch.from_numpy(image.transpose(0, 3, 1, 2))
468
+ elif isinstance(image[0], np.ndarray):
469
+ image = np.stack(image, axis=0) # to np
470
+ if image.ndim == 5:
471
+ image = image[0]
472
+
473
+ image = torch.from_numpy(image.transpose(0, 3, 1, 2))
474
+ elif isinstance(image, list) and isinstance(image[0], torch.Tensor):
475
+ dims = image[0].ndim
476
+
477
+ if dims == 3:
478
+ image = torch.stack(image, dim=0)
479
+ elif dims == 4:
480
+ image = torch.concat(image, dim=0)
481
+ else:
482
+ raise ValueError(f"Image must have 3 or 4 dimensions, instead got {dims}")
483
+
484
+ image = image.to(device=device, dtype=self.unet.dtype)
485
+
486
+ image = image.repeat_interleave(num_images_per_prompt, dim=0)
487
+
488
+ return image
489
+
490
+ def _text_preprocessing(self, text, clean_caption=False):
491
+ if clean_caption and not is_bs4_available():
492
+ logger.warn(BACKENDS_MAPPING["bs4"][-1].format("Setting `clean_caption=True`"))
493
+ logger.warn("Setting `clean_caption` to False...")
494
+ clean_caption = False
495
+
496
+ if clean_caption and not is_ftfy_available():
497
+ logger.warn(BACKENDS_MAPPING["ftfy"][-1].format("Setting `clean_caption=True`"))
498
+ logger.warn("Setting `clean_caption` to False...")
499
+ clean_caption = False
500
+
501
+ if not isinstance(text, (tuple, list)):
502
+ text = [text]
503
+
504
+ def process(text: str):
505
+ if clean_caption:
506
+ text = self._clean_caption(text)
507
+ text = self._clean_caption(text)
508
+ else:
509
+ text = text.lower().strip()
510
+ return text
511
+
512
+ return [process(t) for t in text]
513
+
514
+ def _clean_caption(self, caption):
515
+ caption = str(caption)
516
+ caption = ul.unquote_plus(caption)
517
+ caption = caption.strip().lower()
518
+ caption = re.sub("<person>", "person", caption)
519
+ # urls:
520
+ caption = re.sub(
521
+ r"\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa
522
+ "",
523
+ caption,
524
+ ) # regex for urls
525
+ caption = re.sub(
526
+ r"\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa
527
+ "",
528
+ caption,
529
+ ) # regex for urls
530
+ # html:
531
+ caption = BeautifulSoup(caption, features="html.parser").text
532
+
533
+ # @<nickname>
534
+ caption = re.sub(r"@[\w\d]+\b", "", caption)
535
+
536
+ # 31C0—31EF CJK Strokes
537
+ # 31F0—31FF Katakana Phonetic Extensions
538
+ # 3200—32FF Enclosed CJK Letters and Months
539
+ # 3300—33FF CJK Compatibility
540
+ # 3400—4DBF CJK Unified Ideographs Extension A
541
+ # 4DC0—4DFF Yijing Hexagram Symbols
542
+ # 4E00—9FFF CJK Unified Ideographs
543
+ caption = re.sub(r"[\u31c0-\u31ef]+", "", caption)
544
+ caption = re.sub(r"[\u31f0-\u31ff]+", "", caption)
545
+ caption = re.sub(r"[\u3200-\u32ff]+", "", caption)
546
+ caption = re.sub(r"[\u3300-\u33ff]+", "", caption)
547
+ caption = re.sub(r"[\u3400-\u4dbf]+", "", caption)
548
+ caption = re.sub(r"[\u4dc0-\u4dff]+", "", caption)
549
+ caption = re.sub(r"[\u4e00-\u9fff]+", "", caption)
550
+ #######################################################
551
+
552
+ # все виды тире / all types of dash --> "-"
553
+ caption = re.sub(
554
+ r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+", # noqa
555
+ "-",
556
+ caption,
557
+ )
558
+
559
+ # кавычки к одному стандарту
560
+ caption = re.sub(r"[`´«»“”¨]", '"', caption)
561
+ caption = re.sub(r"[‘’]", "'", caption)
562
+
563
+ # &quot;
564
+ caption = re.sub(r"&quot;?", "", caption)
565
+ # &amp
566
+ caption = re.sub(r"&amp", "", caption)
567
+
568
+ # ip addresses:
569
+ caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption)
570
+
571
+ # article ids:
572
+ caption = re.sub(r"\d:\d\d\s+$", "", caption)
573
+
574
+ # \n
575
+ caption = re.sub(r"\\n", " ", caption)
576
+
577
+ # "#123"
578
+ caption = re.sub(r"#\d{1,3}\b", "", caption)
579
+ # "#12345.."
580
+ caption = re.sub(r"#\d{5,}\b", "", caption)
581
+ # "123456.."
582
+ caption = re.sub(r"\b\d{6,}\b", "", caption)
583
+ # filenames:
584
+ caption = re.sub(r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", caption)
585
+
586
+ #
587
+ caption = re.sub(r"[\"\']{2,}", r'"', caption) # """AUSVERKAUFT"""
588
+ caption = re.sub(r"[\.]{2,}", r" ", caption) # """AUSVERKAUFT"""
589
+
590
+ caption = re.sub(self.bad_punct_regex, r" ", caption) # ***AUSVERKAUFT***, #AUSVERKAUFT
591
+ caption = re.sub(r"\s+\.\s+", r" ", caption) # " . "
592
+
593
+ # this-is-my-cute-cat / this_is_my_cute_cat
594
+ regex2 = re.compile(r"(?:\-|\_)")
595
+ if len(re.findall(regex2, caption)) > 3:
596
+ caption = re.sub(regex2, " ", caption)
597
+
598
+ caption = ftfy.fix_text(caption)
599
+ caption = html.unescape(html.unescape(caption))
600
+
601
+ caption = re.sub(r"\b[a-zA-Z]{1,3}\d{3,15}\b", "", caption) # jc6640
602
+ caption = re.sub(r"\b[a-zA-Z]+\d+[a-zA-Z]+\b", "", caption) # jc6640vc
603
+ caption = re.sub(r"\b\d+[a-zA-Z]+\d+\b", "", caption) # 6640vc231
604
+
605
+ caption = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", caption)
606
+ caption = re.sub(r"(free\s)?download(\sfree)?", "", caption)
607
+ caption = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", caption)
608
+ caption = re.sub(r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", caption)
609
+ caption = re.sub(r"\bpage\s+\d+\b", "", caption)
610
+
611
+ caption = re.sub(r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", caption) # j2d1a2a...
612
+
613
+ caption = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", caption)
614
+
615
+ caption = re.sub(r"\b\s+\:\s+", r": ", caption)
616
+ caption = re.sub(r"(\D[,\./])\b", r"\1 ", caption)
617
+ caption = re.sub(r"\s+", " ", caption)
618
+
619
+ caption = caption.strip()
620
+
621
+ caption = re.sub(r"^[\"\']([\w\W]+)[\"\']$", r"\1", caption)
622
+ caption = re.sub(r"^[\'\_,\-\:;]", r"", caption)
623
+ caption = re.sub(r"[\'\_,\-\:\-\+]$", r"", caption)
624
+ caption = re.sub(r"^\.\S+$", "", caption)
625
+
626
+ return caption.strip()
627
+
628
+ @torch.no_grad()
629
+ def __call__(
630
+ self,
631
+ prompt: Union[str, List[str]] = None,
632
+ height: Optional[int] = None,
633
+ width: Optional[int] = None,
634
+ image: Union[PIL.Image.Image, np.ndarray, torch.FloatTensor] = None,
635
+ num_inference_steps: int = 50,
636
+ timesteps: List[int] = None,
637
+ guidance_scale: float = 4.0,
638
+ negative_prompt: Optional[Union[str, List[str]]] = None,
639
+ num_images_per_prompt: Optional[int] = 1,
640
+ eta: float = 0.0,
641
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
642
+ prompt_embeds: Optional[torch.FloatTensor] = None,
643
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
644
+ output_type: Optional[str] = "np",
645
+ return_dict: bool = True,
646
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
647
+ callback_steps: int = 1,
648
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
649
+ noise_level: int = 20,
650
+ clean_caption: bool = True,
651
+ ):
652
+ """
653
+ Function invoked when calling the pipeline for generation.
654
+
655
+ Args:
656
+ prompt (`str` or `List[str]`, *optional*):
657
+ The prompt or prompts to guide the video generation. If not defined, one has to pass `prompt_embeds`
658
+ instead.
659
+ height (`int`, *optional*, defaults to self.unet.config.sample_size):
660
+ The height in pixels of the generated image.
661
+ width (`int`, *optional*, defaults to self.unet.config.sample_size):
662
+ The width in pixels of the generated image.
663
+ image (`PIL.Image.Image`, `np.ndarray`, `torch.FloatTensor`):
664
+ The low-resolution video tensor to be upscaled.
665
+ num_inference_steps (`int`, *optional*, defaults to 50):
666
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
667
+ expense of slower inference.
668
+ timesteps (`List[int]`, *optional*):
669
+ Custom timesteps to use for the denoising process. If not defined, equal spaced `num_inference_steps`
670
+ timesteps are used. Must be in descending order.
671
+ guidance_scale (`float`, *optional*, defaults to 4.0):
672
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
673
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
674
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
675
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
676
+ usually at the expense of lower image quality.
677
+ negative_prompt (`str` or `List[str]`, *optional*):
678
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
679
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
680
+ less than `1`).
681
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
682
+ The number of images to generate per prompt.
683
+ eta (`float`, *optional*, defaults to 0.0):
684
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
685
+ [`schedulers.DDIMScheduler`], will be ignored for others.
686
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
687
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
688
+ to make generation deterministic.
689
+ prompt_embeds (`torch.FloatTensor`, *optional*):
690
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
691
+ provided, text embeddings will be generated from `prompt` input argument.
692
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
693
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
694
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
695
+ argument.
696
+ output_type (`str`, *optional*, defaults to `"np"`):
697
+ The output format of the generate image. Choose between
698
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
699
+ return_dict (`bool`, *optional*, defaults to `True`):
700
+ Whether or not to return a [`~pipelines.stable_diffusion.IFPipelineOutput`] instead of a plain tuple.
701
+ callback (`Callable`, *optional*):
702
+ A function that will be called every `callback_steps` steps during inference. The function will be
703
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
704
+ callback_steps (`int`, *optional*, defaults to 1):
705
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
706
+ called at every step.
707
+ cross_attention_kwargs (`dict`, *optional*):
708
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
709
+ `self.processor` in
710
+ [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
711
+ noise_level (`int`, *optional*, defaults to 20):
712
+ The amount of noise to add to the upscaled image. Must be in the range `[0, 1000)`
713
+ clean_caption (`bool`, *optional*, defaults to `True`):
714
+ Whether or not to clean the caption before creating embeddings. Requires `beautifulsoup4` and `ftfy` to
715
+ be installed. If the dependencies are not installed, the embeddings will be created from the raw
716
+ prompt.
717
+
718
+ Examples:
719
+
720
+ Returns:
721
+ [`TextToVideoPipelineOutput`] or `tuple`:
722
+ [`TextToVideoPipelineOutput`] if `return_dict` is `True`, otherwise a `tuple` whose first
723
+ element is a list (or a tensor, when `output_type="pt"`) containing the generated video
724
+ frames. No nsfw or watermark flags are returned, since this pipeline does not apply a
725
+ safety checker to its output.
726
+ """
727
+ # 1. Check inputs. Raise error if not correct
728
+
729
+ if prompt is not None and isinstance(prompt, str):
730
+ batch_size = 1
731
+ elif prompt is not None and isinstance(prompt, list):
732
+ batch_size = len(prompt)
733
+ else:
734
+ batch_size = prompt_embeds.shape[0]
735
+
736
+ self.check_inputs(
737
+ prompt,
738
+ image,
739
+ batch_size,
740
+ noise_level,
741
+ callback_steps,
742
+ negative_prompt,
743
+ prompt_embeds,
744
+ negative_prompt_embeds,
745
+ )
746
+
747
+ # 2. Define call parameters
748
+
749
+ height = height or self.unet.config.sample_size
750
+ width = width or self.unet.config.sample_size
751
+ assert isinstance(image, torch.Tensor), f"{type(image)} is not supported."
752
+ num_frames = image.shape[2]
753
+
754
+ device = self._execution_device
755
+
756
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
757
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
758
+ # corresponds to doing no classifier free guidance.
759
+ do_classifier_free_guidance = guidance_scale > 1.0
760
+
761
+ # 3. Encode input prompt
762
+ prompt_embeds, negative_prompt_embeds = self.encode_prompt(
763
+ prompt,
764
+ do_classifier_free_guidance,
765
+ num_images_per_prompt=num_images_per_prompt,
766
+ device=device,
767
+ negative_prompt=negative_prompt,
768
+ prompt_embeds=prompt_embeds,
769
+ negative_prompt_embeds=negative_prompt_embeds,
770
+ clean_caption=clean_caption,
771
+ )
772
+
773
+ if do_classifier_free_guidance:
774
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
775
+
776
+ # 4. Prepare timesteps
777
+ if timesteps is not None:
778
+ self.scheduler.set_timesteps(timesteps=timesteps, device=device)
779
+ timesteps = self.scheduler.timesteps
780
+ num_inference_steps = len(timesteps)
781
+ else:
782
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
783
+ timesteps = self.scheduler.timesteps
784
+
785
+ # 5. Prepare intermediate images
786
+ num_channels = self.unet.config.in_channels // 2
787
+ intermediate_images = self.prepare_intermediate_images(
788
+ batch_size * num_images_per_prompt,
789
+ num_channels,
790
+ num_frames,
791
+ height,
792
+ width,
793
+ prompt_embeds.dtype,
794
+ device,
795
+ generator,
796
+ )
797
+
798
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
799
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
800
+
801
+ # 7. Prepare upscaled image and noise level
802
+ image = self.preprocess_image(image, num_images_per_prompt, device)
803
+ upscaled = rearrange(image, "b c f h w -> (b f) c h w")
804
+ upscaled = F.interpolate(upscaled, (height, width), mode="bilinear", align_corners=True)
805
+ upscaled = rearrange(upscaled, "(b f) c h w -> b c f h w", f=image.shape[2])
806
+
807
+ noise_level = torch.tensor([noise_level] * upscaled.shape[0], device=upscaled.device)
808
+ noise = randn_tensor(upscaled.shape, generator=generator, device=upscaled.device, dtype=upscaled.dtype)
809
+ upscaled = self.image_noising_scheduler.add_noise(upscaled, noise, timesteps=noise_level)
810
+
811
+ if do_classifier_free_guidance:
812
+ noise_level = torch.cat([noise_level] * 2)
813
+
814
+ # HACK: see comment in `enable_model_cpu_offload`
815
+ if hasattr(self, "text_encoder_offload_hook") and self.text_encoder_offload_hook is not None:
816
+ self.text_encoder_offload_hook.offload()
817
+
818
+ # 8. Denoising loop
819
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
820
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
821
+ for i, t in enumerate(timesteps):
822
+ model_input = torch.cat([intermediate_images, upscaled], dim=1)
823
+
824
+ model_input = torch.cat([model_input] * 2) if do_classifier_free_guidance else model_input
825
+ model_input = self.scheduler.scale_model_input(model_input, t)
826
+
827
+ # predict the noise residual
828
+ noise_pred = self.unet(
829
+ model_input,
830
+ t,
831
+ encoder_hidden_states=prompt_embeds,
832
+ class_labels=noise_level,
833
+ cross_attention_kwargs=cross_attention_kwargs,
834
+ ).sample
835
+
836
+ # perform guidance
837
+ if do_classifier_free_guidance:
838
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
839
+ noise_pred_uncond, _ = noise_pred_uncond.split(model_input.shape[1] // 2, dim=1)
840
+ noise_pred_text, predicted_variance = noise_pred_text.split(model_input.shape[1] // 2, dim=1)
841
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
842
+ noise_pred = torch.cat([noise_pred, predicted_variance], dim=1)
843
+
844
+ # reshape latents
845
+ bsz, channel, frames, height, width = intermediate_images.shape
846
+ intermediate_images = intermediate_images.permute(0, 2, 1, 3, 4).reshape(bsz * frames, channel, height, width)
847
+ noise_pred = noise_pred.permute(0, 2, 1, 3, 4).reshape(bsz * frames, -1, height, width)
848
+
849
+ # compute the previous noisy sample x_t -> x_t-1
850
+ intermediate_images = self.scheduler.step(
851
+ noise_pred, t, intermediate_images, **extra_step_kwargs
852
+ ).prev_sample
853
+
854
+ # reshape latents back
855
+ intermediate_images = intermediate_images[None, :].reshape(bsz, frames, channel, height, width).permute(0, 2, 1, 3, 4)
856
+
857
+ # call the callback, if provided
858
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
859
+ progress_bar.update()
860
+ if callback is not None and i % callback_steps == 0:
861
+ callback(i, t, intermediate_images)
862
+
863
+ video_tensor = intermediate_images
864
+
865
+ if output_type == "pt":
866
+ video = video_tensor
867
+ else:
868
+ video = tensor2vid(video_tensor)
869
+
870
+ # Offload last model to CPU
871
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
872
+ self.final_offload_hook.offload()
873
+
874
+ if not return_dict:
875
+ return (video,)
876
+
877
+ return TextToVideoPipelineOutput(frames=video)
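
Note: the denoising loop above interleaves classifier-free guidance with a temporal reshape, folding the 5-D video latent into a batch of frames so the image-level scheduler can step them, then restoring the layout for the 3D UNet. A minimal standalone sketch of that round trip follows; the tensor sizes are made up purely for illustration and are not taken from this commit.

```python
import torch

# Hypothetical sizes, for illustration only; the pipeline derives them from its inputs.
bsz, channel, frames, height, width = 1, 3, 8, 64, 64
latents = torch.randn(bsz, channel, frames, height, width)  # (b, c, f, h, w)

# Fold the frame axis into the batch axis so a per-image scheduler step can be applied.
flat = latents.permute(0, 2, 1, 3, 4).reshape(bsz * frames, channel, height, width)

# ... scheduler.step(noise_pred, t, flat) would run here in the actual loop ...

# Restore the (b, c, f, h, w) layout expected by the 3D UNet on the next iteration.
restored = flat.reshape(bsz, frames, channel, height, width).permute(0, 2, 1, 3, 4)
assert torch.equal(latents, restored)
```
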
showone/pipelines/pipeline_t2v_sr_pixel_cond.py ADDED
@@ -0,0 +1,890 @@
1
+ import html
2
+ import inspect
3
+ import re
4
+ import urllib.parse as ul
5
+ from typing import Any, Callable, Dict, List, Optional, Union
6
+
7
+ import numpy as np
8
+ import PIL
9
+ import torch
10
+ import torch.nn.functional as F
11
+ from transformers import CLIPImageProcessor, T5EncoderModel, T5Tokenizer
12
+ from einops import rearrange
13
+
14
+ from diffusers.loaders import LoraLoaderMixin
15
+ from diffusers.schedulers import DDPMScheduler
16
+ from diffusers.utils import (
17
+ BACKENDS_MAPPING,
18
+ is_accelerate_available,
19
+ is_accelerate_version,
20
+ is_bs4_available,
21
+ is_ftfy_available,
22
+ logging,
23
+ randn_tensor,
24
+ replace_example_docstring,
25
+ )
26
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
27
+
28
+ from ..models import UNet3DConditionModel
29
+ from . import TextToVideoPipelineOutput
30
+
31
+
32
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
33
+
34
+ if is_bs4_available():
35
+ from bs4 import BeautifulSoup
36
+
37
+ if is_ftfy_available():
38
+ import ftfy
39
+
40
+
41
+ def tensor2vid(video: torch.Tensor, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) -> List[np.ndarray]:
42
+ # This code is copied from https://github.com/modelscope/modelscope/blob/1509fdb973e5871f37148a4b5e5964cafd43e64d/modelscope/pipelines/multi_modal/text_to_video_synthesis_pipeline.py#L78
43
+ # reshape to ncfhw
44
+ mean = torch.tensor(mean, device=video.device).reshape(1, -1, 1, 1, 1)
45
+ std = torch.tensor(std, device=video.device).reshape(1, -1, 1, 1, 1)
46
+ # unnormalize back to [0,1]
47
+ video = video.mul_(std).add_(mean)
48
+ video.clamp_(0, 1)
49
+ # prepare the final outputs
50
+ i, c, f, h, w = video.shape
51
+ images = video.permute(2, 3, 0, 4, 1).reshape(
52
+ f, h, i * w, c
53
+ ) # 1st (frames, h, batch_size, w, c) 2nd (frames, h, batch_size * w, c)
54
+ images = images.unbind(dim=0) # prepare a list of individual (consecutive) frames
55
+ images = [(image.cpu().numpy() * 255).astype("uint8") for image in images] # f h w c
56
+ return images
57
+
58
+
59
+ class TextToVideoIFSuperResolutionPipeline_Cond(DiffusionPipeline, LoraLoaderMixin):
60
+ tokenizer: T5Tokenizer
61
+ text_encoder: T5EncoderModel
62
+
63
+ unet: UNet3DConditionModel
64
+ scheduler: DDPMScheduler
65
+ image_noising_scheduler: DDPMScheduler
66
+
67
+ feature_extractor: Optional[CLIPImageProcessor]
68
+ # safety_checker: Optional[IFSafetyChecker]
69
+
70
+ # watermarker: Optional[IFWatermarker]
71
+
72
+ bad_punct_regex = re.compile(
73
+ r"[" + "#®•©™&@·º½¾¿¡§~" + "\)" + "\(" + "\]" + "\[" + "\}" + "\{" + "\|" + "\\" + "\/" + "\*" + r"]{1,}"
74
+ ) # noqa
75
+
76
+ _optional_components = ["tokenizer", "text_encoder", "safety_checker", "feature_extractor", "watermarker"]
77
+
78
+ def __init__(
79
+ self,
80
+ tokenizer: T5Tokenizer,
81
+ text_encoder: T5EncoderModel,
82
+ unet: UNet3DConditionModel,
83
+ scheduler: DDPMScheduler,
84
+ image_noising_scheduler: DDPMScheduler,
85
+ feature_extractor: Optional[CLIPImageProcessor],
86
+ ):
87
+ super().__init__()
88
+
89
+ self.register_modules(
90
+ tokenizer=tokenizer,
91
+ text_encoder=text_encoder,
92
+ unet=unet,
93
+ scheduler=scheduler,
94
+ image_noising_scheduler=image_noising_scheduler,
95
+ feature_extractor=feature_extractor,
96
+ )
97
+ self.safety_checker = None
98
+
99
+ def enable_sequential_cpu_offload(self, gpu_id=0):
100
+ r"""
101
+ Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, the pipeline's
102
+ models have their state dicts saved to CPU and then are moved to a `torch.device('meta')` and loaded to the GPU only
103
+ when their specific submodule has its `forward` method called.
104
+ """
105
+ if is_accelerate_available():
106
+ from accelerate import cpu_offload
107
+ else:
108
+ raise ImportError("Please install accelerate via `pip install accelerate`")
109
+
110
+ device = torch.device(f"cuda:{gpu_id}")
111
+
112
+ models = [
113
+ self.text_encoder,
114
+ self.unet,
115
+ ]
116
+ for cpu_offloaded_model in models:
117
+ if cpu_offloaded_model is not None:
118
+ cpu_offload(cpu_offloaded_model, device)
119
+
120
+ if self.safety_checker is not None:
121
+ cpu_offload(self.safety_checker, execution_device=device, offload_buffers=True)
122
+
123
+ def enable_model_cpu_offload(self, gpu_id=0):
124
+ r"""
125
+ Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
126
+ to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
127
+ method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with
128
+ `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
129
+ """
130
+ if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
131
+ from accelerate import cpu_offload_with_hook
132
+ else:
133
+ raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")
134
+
135
+ device = torch.device(f"cuda:{gpu_id}")
136
+
137
+ if self.device.type != "cpu":
138
+ self.to("cpu", silence_dtype_warnings=True)
139
+ torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
140
+
141
+ hook = None
142
+
143
+ if self.text_encoder is not None:
144
+ _, hook = cpu_offload_with_hook(self.text_encoder, device, prev_module_hook=hook)
145
+
146
+ # Accelerate will move the next model to the device _before_ calling the offload hook of the
147
+ # previous model. This will cause both models to be present on the device at the same time.
148
+ # IF uses T5 for its text encoder which is really large. We can manually call the offload
149
+ # hook for the text encoder to ensure it's moved to the cpu before the unet is moved to
150
+ # the GPU.
151
+ self.text_encoder_offload_hook = hook
152
+
153
+ _, hook = cpu_offload_with_hook(self.unet, device, prev_module_hook=hook)
154
+
155
+ # if the safety checker isn't called, `unet_offload_hook` will have to be called to manually offload the unet
156
+ self.unet_offload_hook = hook
157
+
158
+ if self.safety_checker is not None:
159
+ _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook)
160
+
161
+ # We'll offload the last model manually.
162
+ self.final_offload_hook = hook
163
+
164
+ def remove_all_hooks(self):
165
+ if is_accelerate_available():
166
+ from accelerate.hooks import remove_hook_from_module
167
+ else:
168
+ raise ImportError("Please install accelerate via `pip install accelerate`")
169
+
170
+ for model in [self.text_encoder, self.unet, self.safety_checker]:
171
+ if model is not None:
172
+ remove_hook_from_module(model, recurse=True)
173
+
174
+ self.unet_offload_hook = None
175
+ self.text_encoder_offload_hook = None
176
+ self.final_offload_hook = None
177
+
178
+ @property
179
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device
180
+ def _execution_device(self):
181
+ r"""
182
+ Returns the device on which the pipeline's models will be executed. After calling
183
+ `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
184
+ hooks.
185
+ """
186
+ if not hasattr(self.unet, "_hf_hook"):
187
+ return self.device
188
+ for module in self.unet.modules():
189
+ if (
190
+ hasattr(module, "_hf_hook")
191
+ and hasattr(module._hf_hook, "execution_device")
192
+ and module._hf_hook.execution_device is not None
193
+ ):
194
+ return torch.device(module._hf_hook.execution_device)
195
+ return self.device
196
+
197
+ @torch.no_grad()
198
+ def encode_prompt(
199
+ self,
200
+ prompt,
201
+ do_classifier_free_guidance=True,
202
+ num_images_per_prompt=1,
203
+ device=None,
204
+ negative_prompt=None,
205
+ prompt_embeds: Optional[torch.FloatTensor] = None,
206
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
207
+ clean_caption: bool = False,
208
+ ):
209
+ r"""
210
+ Encodes the prompt into text encoder hidden states.
211
+
212
+ Args:
213
+ prompt (`str` or `List[str]`, *optional*):
214
+ prompt to be encoded
215
+ device: (`torch.device`, *optional*):
216
+ torch device to place the resulting embeddings on
217
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
218
+ number of images that should be generated per prompt
219
+ do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
220
+ whether to use classifier free guidance or not
221
+ negative_prompt (`str` or `List[str]`, *optional*):
222
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
223
+ `negative_prompt_embeds` instead.
224
+ Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`).
225
+ prompt_embeds (`torch.FloatTensor`, *optional*):
226
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
227
+ provided, text embeddings will be generated from `prompt` input argument.
228
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
229
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
230
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
231
+ argument.
232
+ """
233
+ if prompt is not None and negative_prompt is not None:
234
+ if type(prompt) is not type(negative_prompt):
235
+ raise TypeError(
236
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
237
+ f" {type(prompt)}."
238
+ )
239
+
240
+ if device is None:
241
+ device = self._execution_device
242
+
243
+ if prompt is not None and isinstance(prompt, str):
244
+ batch_size = 1
245
+ elif prompt is not None and isinstance(prompt, list):
246
+ batch_size = len(prompt)
247
+ else:
248
+ batch_size = prompt_embeds.shape[0]
249
+
250
+ # while T5 can handle much longer input sequences than 77, the text encoder was trained with a max length of 77 for IF
251
+ max_length = 77
252
+
253
+ if prompt_embeds is None:
254
+ prompt = self._text_preprocessing(prompt, clean_caption=clean_caption)
255
+ text_inputs = self.tokenizer(
256
+ prompt,
257
+ padding="max_length",
258
+ max_length=max_length,
259
+ truncation=True,
260
+ add_special_tokens=True,
261
+ return_tensors="pt",
262
+ )
263
+ text_input_ids = text_inputs.input_ids
264
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
265
+
266
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
267
+ text_input_ids, untruncated_ids
268
+ ):
269
+ removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_length - 1 : -1])
270
+ logger.warning(
271
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
272
+ f" {max_length} tokens: {removed_text}"
273
+ )
274
+
275
+ attention_mask = text_inputs.attention_mask.to(device)
276
+
277
+ prompt_embeds = self.text_encoder(
278
+ text_input_ids.to(device),
279
+ attention_mask=attention_mask,
280
+ )
281
+ prompt_embeds = prompt_embeds[0]
282
+
283
+ if self.text_encoder is not None:
284
+ dtype = self.text_encoder.dtype
285
+ elif self.unet is not None:
286
+ dtype = self.unet.dtype
287
+ else:
288
+ dtype = None
289
+
290
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
291
+
292
+ bs_embed, seq_len, _ = prompt_embeds.shape
293
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
294
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
295
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
296
+
297
+ # get unconditional embeddings for classifier free guidance
298
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
299
+ uncond_tokens: List[str]
300
+ if negative_prompt is None:
301
+ uncond_tokens = [""] * batch_size
302
+ elif isinstance(negative_prompt, str):
303
+ uncond_tokens = [negative_prompt]
304
+ elif batch_size != len(negative_prompt):
305
+ raise ValueError(
306
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
307
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
308
+ " the batch size of `prompt`."
309
+ )
310
+ else:
311
+ uncond_tokens = negative_prompt
312
+
313
+ uncond_tokens = self._text_preprocessing(uncond_tokens, clean_caption=clean_caption)
314
+ max_length = prompt_embeds.shape[1]
315
+ uncond_input = self.tokenizer(
316
+ uncond_tokens,
317
+ padding="max_length",
318
+ max_length=max_length,
319
+ truncation=True,
320
+ return_attention_mask=True,
321
+ add_special_tokens=True,
322
+ return_tensors="pt",
323
+ )
324
+ attention_mask = uncond_input.attention_mask.to(device)
325
+
326
+ negative_prompt_embeds = self.text_encoder(
327
+ uncond_input.input_ids.to(device),
328
+ attention_mask=attention_mask,
329
+ )
330
+ negative_prompt_embeds = negative_prompt_embeds[0]
331
+
332
+ if do_classifier_free_guidance:
333
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
334
+ seq_len = negative_prompt_embeds.shape[1]
335
+
336
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device)
337
+
338
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
339
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
340
+
341
+ # For classifier free guidance, we need to do two forward passes.
342
+ # Here we concatenate the unconditional and text embeddings into a single batch
343
+ # to avoid doing two forward passes
344
+ else:
345
+ negative_prompt_embeds = None
346
+
347
+ return prompt_embeds, negative_prompt_embeds
348
+
349
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
350
+ def prepare_extra_step_kwargs(self, generator, eta):
351
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
352
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
353
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
354
+ # and should be between [0, 1]
355
+
356
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
357
+ extra_step_kwargs = {}
358
+ if accepts_eta:
359
+ extra_step_kwargs["eta"] = eta
360
+
361
+ # check if the scheduler accepts generator
362
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
363
+ if accepts_generator:
364
+ extra_step_kwargs["generator"] = generator
365
+ return extra_step_kwargs
366
+
367
+ def check_inputs(
368
+ self,
369
+ prompt,
370
+ image,
371
+ batch_size,
372
+ noise_level,
373
+ callback_steps,
374
+ negative_prompt=None,
375
+ prompt_embeds=None,
376
+ negative_prompt_embeds=None,
377
+ ):
378
+ if (callback_steps is None) or (
379
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
380
+ ):
381
+ raise ValueError(
382
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
383
+ f" {type(callback_steps)}."
384
+ )
385
+
386
+ if prompt is not None and prompt_embeds is not None:
387
+ raise ValueError(
388
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
389
+ " only forward one of the two."
390
+ )
391
+ elif prompt is None and prompt_embeds is None:
392
+ raise ValueError(
393
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
394
+ )
395
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
396
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
397
+
398
+ if negative_prompt is not None and negative_prompt_embeds is not None:
399
+ raise ValueError(
400
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
401
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
402
+ )
403
+
404
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
405
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
406
+ raise ValueError(
407
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
408
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
409
+ f" {negative_prompt_embeds.shape}."
410
+ )
411
+
412
+ if noise_level < 0 or noise_level >= self.image_noising_scheduler.config.num_train_timesteps:
413
+ raise ValueError(
414
+ f"`noise_level`: {noise_level} must be a valid timestep in `self.noising_scheduler`, [0, {self.image_noising_scheduler.config.num_train_timesteps})"
415
+ )
416
+
417
+ if isinstance(image, list):
418
+ check_image_type = image[0]
419
+ else:
420
+ check_image_type = image
421
+
422
+ if (
423
+ not isinstance(check_image_type, torch.Tensor)
424
+ and not isinstance(check_image_type, PIL.Image.Image)
425
+ and not isinstance(check_image_type, np.ndarray)
426
+ ):
427
+ raise ValueError(
428
+ "`image` has to be of type `torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, or List[...] but is"
429
+ f" {type(check_image_type)}"
430
+ )
431
+
432
+ if isinstance(image, list):
433
+ image_batch_size = len(image)
434
+ elif isinstance(image, torch.Tensor):
435
+ image_batch_size = image.shape[0]
436
+ elif isinstance(image, PIL.Image.Image):
437
+ image_batch_size = 1
438
+ elif isinstance(image, np.ndarray):
439
+ image_batch_size = image.shape[0]
440
+ else:
441
+ assert False
442
+
443
+ if batch_size != image_batch_size:
444
+ raise ValueError(f"image batch size: {image_batch_size} must be same as prompt batch size {batch_size}")
445
+
446
+ def prepare_intermediate_images(self, batch_size, num_channels, num_frames, height, width, dtype, device, generator):
447
+ shape = (batch_size, num_channels, num_frames, height, width)
448
+ if isinstance(generator, list) and len(generator) != batch_size:
449
+ raise ValueError(
450
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
451
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
452
+ )
453
+
454
+ intermediate_images = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
455
+
456
+ # scale the initial noise by the standard deviation required by the scheduler
457
+ intermediate_images = intermediate_images * self.scheduler.init_noise_sigma
458
+ return intermediate_images
459
+
460
+ def preprocess_image(self, image, num_images_per_prompt, device):
461
+ if not isinstance(image, torch.Tensor) and not isinstance(image, list):
462
+ image = [image]
463
+
464
+ if isinstance(image[0], PIL.Image.Image):
465
+ image = [np.array(i).astype(np.float32) / 255.0 for i in image]
466
+
467
+ image = np.stack(image, axis=0) # to np
468
+ image = torch.from_numpy(image.transpose(0, 3, 1, 2))
469
+ elif isinstance(image[0], np.ndarray):
470
+ image = np.stack(image, axis=0) # to np
471
+ if image.ndim == 5:
472
+ image = image[0]
473
+
474
+ image = torch.from_numpy(image.transpose(0, 3, 1, 2))
475
+ elif isinstance(image, list) and isinstance(image[0], torch.Tensor):
476
+ dims = image[0].ndim
477
+
478
+ if dims == 3:
479
+ image = torch.stack(image, dim=0)
480
+ elif dims == 4:
481
+ image = torch.concat(image, dim=0)
482
+ else:
483
+ raise ValueError(f"Image must have 3 or 4 dimensions, instead got {dims}")
484
+
485
+ image = image.to(device=device, dtype=self.unet.dtype)
486
+
487
+ image = image.repeat_interleave(num_images_per_prompt, dim=0)
488
+
489
+ return image
490
+
491
+ def _text_preprocessing(self, text, clean_caption=False):
492
+ if clean_caption and not is_bs4_available():
493
+ logger.warn(BACKENDS_MAPPING["bs4"][-1].format("Setting `clean_caption=True`"))
494
+ logger.warn("Setting `clean_caption` to False...")
495
+ clean_caption = False
496
+
497
+ if clean_caption and not is_ftfy_available():
498
+ logger.warn(BACKENDS_MAPPING["ftfy"][-1].format("Setting `clean_caption=True`"))
499
+ logger.warn("Setting `clean_caption` to False...")
500
+ clean_caption = False
501
+
502
+ if not isinstance(text, (tuple, list)):
503
+ text = [text]
504
+
505
+ def process(text: str):
506
+ if clean_caption:
507
+ text = self._clean_caption(text)
508
+ text = self._clean_caption(text)
509
+ else:
510
+ text = text.lower().strip()
511
+ return text
512
+
513
+ return [process(t) for t in text]
514
+
515
+ def _clean_caption(self, caption):
516
+ caption = str(caption)
517
+ caption = ul.unquote_plus(caption)
518
+ caption = caption.strip().lower()
519
+ caption = re.sub("<person>", "person", caption)
520
+ # urls:
521
+ caption = re.sub(
522
+ r"\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa
523
+ "",
524
+ caption,
525
+ ) # regex for urls
526
+ caption = re.sub(
527
+ r"\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa
528
+ "",
529
+ caption,
530
+ ) # regex for urls
531
+ # html:
532
+ caption = BeautifulSoup(caption, features="html.parser").text
533
+
534
+ # @<nickname>
535
+ caption = re.sub(r"@[\w\d]+\b", "", caption)
536
+
537
+ # 31C0—31EF CJK Strokes
538
+ # 31F0—31FF Katakana Phonetic Extensions
539
+ # 3200—32FF Enclosed CJK Letters and Months
540
+ # 3300—33FF CJK Compatibility
541
+ # 3400—4DBF CJK Unified Ideographs Extension A
542
+ # 4DC0—4DFF Yijing Hexagram Symbols
543
+ # 4E00—9FFF CJK Unified Ideographs
544
+ caption = re.sub(r"[\u31c0-\u31ef]+", "", caption)
545
+ caption = re.sub(r"[\u31f0-\u31ff]+", "", caption)
546
+ caption = re.sub(r"[\u3200-\u32ff]+", "", caption)
547
+ caption = re.sub(r"[\u3300-\u33ff]+", "", caption)
548
+ caption = re.sub(r"[\u3400-\u4dbf]+", "", caption)
549
+ caption = re.sub(r"[\u4dc0-\u4dff]+", "", caption)
550
+ caption = re.sub(r"[\u4e00-\u9fff]+", "", caption)
551
+ #######################################################
552
+
553
+ # все виды тире / all types of dash --> "-"
554
+ caption = re.sub(
555
+ r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+", # noqa
556
+ "-",
557
+ caption,
558
+ )
559
+
560
+ # кавычки к одному стандарту
561
+ caption = re.sub(r"[`´«»“”¨]", '"', caption)
562
+ caption = re.sub(r"[‘’]", "'", caption)
563
+
564
+ # &quot;
565
+ caption = re.sub(r"&quot;?", "", caption)
566
+ # &amp
567
+ caption = re.sub(r"&amp", "", caption)
568
+
569
+ # ip addresses:
570
+ caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption)
571
+
572
+ # article ids:
573
+ caption = re.sub(r"\d:\d\d\s+$", "", caption)
574
+
575
+ # \n
576
+ caption = re.sub(r"\\n", " ", caption)
577
+
578
+ # "#123"
579
+ caption = re.sub(r"#\d{1,3}\b", "", caption)
580
+ # "#12345.."
581
+ caption = re.sub(r"#\d{5,}\b", "", caption)
582
+ # "123456.."
583
+ caption = re.sub(r"\b\d{6,}\b", "", caption)
584
+ # filenames:
585
+ caption = re.sub(r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", caption)
586
+
587
+ #
588
+ caption = re.sub(r"[\"\']{2,}", r'"', caption) # """AUSVERKAUFT"""
589
+ caption = re.sub(r"[\.]{2,}", r" ", caption) # """AUSVERKAUFT"""
590
+
591
+ caption = re.sub(self.bad_punct_regex, r" ", caption) # ***AUSVERKAUFT***, #AUSVERKAUFT
592
+ caption = re.sub(r"\s+\.\s+", r" ", caption) # " . "
593
+
594
+ # this-is-my-cute-cat / this_is_my_cute_cat
595
+ regex2 = re.compile(r"(?:\-|\_)")
596
+ if len(re.findall(regex2, caption)) > 3:
597
+ caption = re.sub(regex2, " ", caption)
598
+
599
+ caption = ftfy.fix_text(caption)
600
+ caption = html.unescape(html.unescape(caption))
601
+
602
+ caption = re.sub(r"\b[a-zA-Z]{1,3}\d{3,15}\b", "", caption) # jc6640
603
+ caption = re.sub(r"\b[a-zA-Z]+\d+[a-zA-Z]+\b", "", caption) # jc6640vc
604
+ caption = re.sub(r"\b\d+[a-zA-Z]+\d+\b", "", caption) # 6640vc231
605
+
606
+ caption = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", caption)
607
+ caption = re.sub(r"(free\s)?download(\sfree)?", "", caption)
608
+ caption = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", caption)
609
+ caption = re.sub(r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", caption)
610
+ caption = re.sub(r"\bpage\s+\d+\b", "", caption)
611
+
612
+ caption = re.sub(r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", caption) # j2d1a2a...
613
+
614
+ caption = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", caption)
615
+
616
+ caption = re.sub(r"\b\s+\:\s+", r": ", caption)
617
+ caption = re.sub(r"(\D[,\./])\b", r"\1 ", caption)
618
+ caption = re.sub(r"\s+", " ", caption)
619
+
620
+ caption = caption.strip()
621
+
622
+ caption = re.sub(r"^[\"\']([\w\W]+)[\"\']$", r"\1", caption)
623
+ caption = re.sub(r"^[\'\_,\-\:;]", r"", caption)
624
+ caption = re.sub(r"[\'\_,\-\:\-\+]$", r"", caption)
625
+ caption = re.sub(r"^\.\S+$", "", caption)
626
+
627
+ return caption.strip()
628
+
629
+ @torch.no_grad()
630
+ def __call__(
631
+ self,
632
+ prompt: Union[str, List[str]] = None,
633
+ height: Optional[int] = None,
634
+ width: Optional[int] = None,
635
+ image: Union[PIL.Image.Image, np.ndarray, torch.FloatTensor] = None,
636
+ first_frame_cond: Union[PIL.Image.Image, np.ndarray, torch.FloatTensor] = None,
637
+ all_frame_cond: Union[PIL.Image.Image, np.ndarray, torch.FloatTensor] = None,
638
+ num_inference_steps: int = 50,
639
+ timesteps: List[int] = None,
640
+ guidance_scale: float = 4.0,
641
+ negative_prompt: Optional[Union[str, List[str]]] = None,
642
+ num_images_per_prompt: Optional[int] = 1,
643
+ eta: float = 0.0,
644
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
645
+ prompt_embeds: Optional[torch.FloatTensor] = None,
646
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
647
+ output_type: Optional[str] = "np",
648
+ return_dict: bool = True,
649
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
650
+ callback_steps: int = 1,
651
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
652
+ noise_level: int = 250,
653
+ clean_caption: bool = True,
654
+ ):
655
+ """
656
+ Function invoked when calling the pipeline for generation.
657
+
658
+ Args:
659
+ prompt (`str` or `List[str]`, *optional*):
660
+ The prompt or prompts to guide the video generation. If not defined, one has to pass `prompt_embeds`
661
+ instead.
662
+ height (`int`, *optional*, defaults to self.unet.config.sample_size):
663
+ The height in pixels of the generated image.
664
+ width (`int`, *optional*, defaults to self.unet.config.sample_size):
665
+ The width in pixels of the generated image.
666
+ image (`PIL.Image.Image`, `np.ndarray`, `torch.FloatTensor`):
667
+ The low-resolution video tensor to be upscaled.
+ first_frame_cond (`torch.FloatTensor`, *optional*):
+ An already-upscaled first frame; if provided, it replaces the first frame of the noised
+ low-resolution condition before denoising.
+ all_frame_cond (`torch.FloatTensor`, *optional*):
+ Already-upscaled conditioning frames; if provided, they are used directly instead of
+ bilinearly upscaling `image`.
668
+ num_inference_steps (`int`, *optional*, defaults to 50):
669
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
670
+ expense of slower inference.
671
+ timesteps (`List[int]`, *optional*):
672
+ Custom timesteps to use for the denoising process. If not defined, equal spaced `num_inference_steps`
673
+ timesteps are used. Must be in descending order.
674
+ guidance_scale (`float`, *optional*, defaults to 4.0):
675
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
676
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
677
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
678
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
679
+ usually at the expense of lower image quality.
680
+ negative_prompt (`str` or `List[str]`, *optional*):
681
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
682
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
683
+ less than `1`).
684
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
685
+ The number of images to generate per prompt.
686
+ eta (`float`, *optional*, defaults to 0.0):
687
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
688
+ [`schedulers.DDIMScheduler`], will be ignored for others.
689
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
690
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
691
+ to make generation deterministic.
692
+ prompt_embeds (`torch.FloatTensor`, *optional*):
693
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
694
+ provided, text embeddings will be generated from `prompt` input argument.
695
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
696
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
697
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
698
+ argument.
699
+ output_type (`str`, *optional*, defaults to `"np"`):
700
+ The output format of the generate image. Choose between
701
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
702
+ return_dict (`bool`, *optional*, defaults to `True`):
703
+ Whether or not to return a [`~pipelines.stable_diffusion.IFPipelineOutput`] instead of a plain tuple.
704
+ callback (`Callable`, *optional*):
705
+ A function that will be called every `callback_steps` steps during inference. The function will be
706
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
707
+ callback_steps (`int`, *optional*, defaults to 1):
708
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
709
+ called at every step.
710
+ cross_attention_kwargs (`dict`, *optional*):
711
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
712
+ `self.processor` in
713
+ [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
714
+ noise_level (`int`, *optional*, defaults to 250):
715
+ The amount of noise to add to the upscaled image. Must be in the range `[0, 1000)`
716
+ clean_caption (`bool`, *optional*, defaults to `True`):
717
+ Whether or not to clean the caption before creating embeddings. Requires `beautifulsoup4` and `ftfy` to
718
+ be installed. If the dependencies are not installed, the embeddings will be created from the raw
719
+ prompt.
720
+
721
+ Examples:
722
+
723
+ Returns:
724
+ [`TextToVideoPipelineOutput`] or `tuple`:
725
+ [`TextToVideoPipelineOutput`] if `return_dict` is `True`, otherwise a `tuple` whose first
726
+ element is a list (or a tensor, when `output_type="pt"`) containing the generated video
727
+ frames. No nsfw or watermark flags are returned, since this pipeline does not apply a
728
+ safety checker to its output.
729
+ """
730
+ # 1. Check inputs. Raise error if not correct
731
+
732
+ if prompt is not None and isinstance(prompt, str):
733
+ batch_size = 1
734
+ elif prompt is not None and isinstance(prompt, list):
735
+ batch_size = len(prompt)
736
+ else:
737
+ batch_size = prompt_embeds.shape[0]
738
+
739
+ self.check_inputs(
740
+ prompt,
741
+ image,
742
+ batch_size,
743
+ noise_level,
744
+ callback_steps,
745
+ negative_prompt,
746
+ prompt_embeds,
747
+ negative_prompt_embeds,
748
+ )
749
+
750
+ # 2. Define call parameters
751
+
752
+ height = height or self.unet.config.sample_size
753
+ width = width or self.unet.config.sample_size
754
+ assert isinstance(image, torch.Tensor), f"{type(image)} is not supported."
755
+ num_frames = image.shape[2]
756
+
757
+ device = self._execution_device
758
+
759
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
760
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
761
+ # corresponds to doing no classifier free guidance.
762
+ do_classifier_free_guidance = guidance_scale > 1.0
763
+
764
+ # 3. Encode input prompt
765
+ prompt_embeds, negative_prompt_embeds = self.encode_prompt(
766
+ prompt,
767
+ do_classifier_free_guidance,
768
+ num_images_per_prompt=num_images_per_prompt,
769
+ device=device,
770
+ negative_prompt=negative_prompt,
771
+ prompt_embeds=prompt_embeds,
772
+ negative_prompt_embeds=negative_prompt_embeds,
773
+ clean_caption=clean_caption,
774
+ )
775
+
776
+ if do_classifier_free_guidance:
777
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
778
+
779
+ # 4. Prepare timesteps
780
+ if timesteps is not None:
781
+ self.scheduler.set_timesteps(timesteps=timesteps, device=device)
782
+ timesteps = self.scheduler.timesteps
783
+ num_inference_steps = len(timesteps)
784
+ else:
785
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
786
+ timesteps = self.scheduler.timesteps
787
+
788
+ # 5. Prepare intermediate images
789
+ num_channels = self.unet.config.in_channels // 2
790
+ intermediate_images = self.prepare_intermediate_images(
791
+ batch_size * num_images_per_prompt,
792
+ num_channels,
793
+ num_frames,
794
+ height,
795
+ width,
796
+ prompt_embeds.dtype,
797
+ device,
798
+ generator,
799
+ )
800
+
801
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
802
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
803
+
804
+ # 7. Prepare upscaled image and noise level
805
+ image = self.preprocess_image(image, num_images_per_prompt, device)
806
+ # upscaled = F.interpolate(image, (num_frames, height, width), mode="trilinear", align_corners=True)
807
+ if all_frame_cond is not None:
808
+ upscaled = all_frame_cond
809
+ else:
810
+ upscaled = rearrange(image, "b c f h w -> (b f) c h w")
811
+ upscaled = F.interpolate(upscaled, (height, width), mode="bilinear", align_corners=True)
812
+ upscaled = rearrange(upscaled, "(b f) c h w -> b c f h w", f=image.shape[2])
813
+
814
+ noise_level = torch.tensor([noise_level] * upscaled.shape[0], device=upscaled.device)
815
+ noise = randn_tensor(upscaled.shape, generator=generator, device=upscaled.device, dtype=upscaled.dtype)
816
+ upscaled = self.image_noising_scheduler.add_noise(upscaled, noise, timesteps=noise_level)
817
+ if first_frame_cond is not None:
818
+ first_frame_cond = first_frame_cond.to(device=device, dtype=self.unet.dtype)
819
+ upscaled[:,:,:1,:,:] = first_frame_cond
820
+
821
+ if do_classifier_free_guidance:
822
+ noise_level = torch.cat([noise_level] * 2)
823
+
824
+ # HACK: see comment in `enable_model_cpu_offload`
825
+ if hasattr(self, "text_encoder_offload_hook") and self.text_encoder_offload_hook is not None:
826
+ self.text_encoder_offload_hook.offload()
827
+
828
+ # 8. Denoising loop
829
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
830
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
831
+ for i, t in enumerate(timesteps):
832
+ model_input = torch.cat([intermediate_images, upscaled], dim=1)
833
+
834
+ model_input = torch.cat([model_input] * 2) if do_classifier_free_guidance else model_input
835
+ model_input = self.scheduler.scale_model_input(model_input, t)
836
+
837
+ # predict the noise residual
838
+ noise_pred = self.unet(
839
+ model_input,
840
+ t,
841
+ encoder_hidden_states=prompt_embeds,
842
+ class_labels=noise_level,
843
+ cross_attention_kwargs=cross_attention_kwargs,
844
+ ).sample
845
+
846
+ # perform guidance
847
+ if do_classifier_free_guidance:
848
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
849
+ noise_pred_uncond, _ = noise_pred_uncond.split(model_input.shape[1] // 2, dim=1)
850
+ noise_pred_text, predicted_variance = noise_pred_text.split(model_input.shape[1] // 2, dim=1)
851
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
852
+ noise_pred = torch.cat([noise_pred, predicted_variance], dim=1)
853
+
854
+ if self.scheduler.config.variance_type not in ["learned", "learned_range"]:
855
+ noise_pred, _ = noise_pred.split(intermediate_images.shape[1], dim=1)
856
+
857
+ # reshape latents
858
+ bsz, channel, frames, height, width = intermediate_images.shape
859
+ intermediate_images = intermediate_images.permute(0, 2, 1, 3, 4).reshape(bsz * frames, channel, height, width)
860
+ noise_pred = noise_pred.permute(0, 2, 1, 3, 4).reshape(bsz * frames, -1, height, width)
861
+
862
+ # compute the previous noisy sample x_t -> x_t-1
863
+ intermediate_images = self.scheduler.step(
864
+ noise_pred, t, intermediate_images, **extra_step_kwargs
865
+ ).prev_sample
866
+
867
+ # reshape latents back
868
+ intermediate_images = intermediate_images[None, :].reshape(bsz, frames, channel, height, width).permute(0, 2, 1, 3, 4)
869
+
870
+ # call the callback, if provided
871
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
872
+ progress_bar.update()
873
+ if callback is not None and i % callback_steps == 0:
874
+ callback(i, t, intermediate_images)
875
+
876
+ video_tensor = intermediate_images
877
+
878
+ if output_type == "pt":
879
+ video = video_tensor
880
+ else:
881
+ video = tensor2vid(video_tensor)
882
+
883
+ # Offload last model to CPU
884
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
885
+ self.final_offload_hook.offload()
886
+
887
+ if not return_dict:
888
+ return (video,)
889
+
890
+ return TextToVideoPipelineOutput(frames=video)
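
For reference, a hedged usage sketch of the conditional super-resolution pipeline added above. The checkpoint path, tensor shapes, and prompt below are placeholders rather than values taken from this commit; only the call signature follows `TextToVideoIFSuperResolutionPipeline_Cond.__call__` as written here.

```python
import torch
from showone.pipelines.pipeline_t2v_sr_pixel_cond import TextToVideoIFSuperResolutionPipeline_Cond

# Hypothetical checkpoint path; substitute a real pretrained Show-1 SR checkpoint.
pipe = TextToVideoIFSuperResolutionPipeline_Cond.from_pretrained(
    "path/to/show-1-sr1", torch_dtype=torch.float16
)
pipe.enable_model_cpu_offload()

# Placeholder tensors: a low-resolution video (b, c, f, h, w) and an already-upscaled first frame.
low_res_video = torch.randn(1, 3, 8, 64, 64, dtype=torch.float16)
first_frame = torch.randn(1, 3, 1, 256, 256, dtype=torch.float16)

frames = pipe(
    prompt="a panda surfing a wave",
    image=low_res_video,           # video to be upscaled; must be a torch.Tensor here
    first_frame_cond=first_frame,  # replaces the first frame of the noised condition
    height=256,
    width=256,
    noise_level=250,
    num_inference_steps=50,
    output_type="pt",
).frames
```

The `first_frame_cond` path mirrors what the `__call__` body does: the low-resolution video is bilinearly upscaled and noised, and the provided first frame then overwrites `upscaled[:, :, :1, :, :]` before the denoising loop starts.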