liray committed on
Commit 2f72267 (1 parent: 3b04f0c)

Initial commit.

.gitignore ADDED
@@ -0,0 +1,2 @@
__pycache__/
ckpts/
app.py ADDED
@@ -0,0 +1,402 @@
from functools import partial
import os
from PIL import Image, ImageOps
import random

import cv2
from diffusers.models import AutoencoderKL
import gradio as gr
import numpy as np
from segment_anything import build_sam, SamPredictor
from tqdm import tqdm
from transformers import CLIPModel, AutoProcessor, CLIPVisionModel
import torch
from torchvision import transforms

from diffusion import create_diffusion
from model import UNet2DDragConditionModel


TITLE = '''DragAPart: Learning a Part-Level Motion Prior for Articulated Objects'''
DESCRIPTION = """
<div>
Try <a href='https://arxiv.org/abs/24xx.xxxxx'><b>DragAPart</b></a> yourself to manipulate your favorite articulated objects in 2 seconds!
</div>
"""
INSTRUCTION = '''
2 steps to get started:
- Upload an image of an articulated object.
- Add one or more drags on the object to specify the part-level interactions.

How to add drags:
- To add a drag, first click on the starting point of the drag, then click on the ending point of the drag, on the Input Image (leftmost).
- You can add up to 10 drags, but we suggest one drag per part.
- After every click, the drags will be visualized on the Image with Drags (second from left).
- If the last drag is not completed (you specified the starting point but not the ending point), it will simply be ignored.
- Have fun dragging!

Then, you will be prompted to verify the object segmentation. Once you confirm that the segmentation is decent, the output image will be generated in seconds!
'''
PREPROCESS_INSTRUCTION = '''
Segmentation is needed if it is not already provided through an alpha channel in the input image.
You don't need to tick this box if you have chosen one of the example images.
If you have uploaded one of your own images, it is very likely that you will need to tick this box.
You should verify that the preprocessed image is object-centric (i.e., clearly contains a single object) and has a white background.
'''

def center_and_square_image(pil_image_rgba, drags):
    image = pil_image_rgba
    alpha = np.array(image)[:, :, 3]  # Extract the alpha channel

    cy, cx = np.round(np.mean(np.nonzero(alpha), axis=1)).astype(int)
    side_length = max(image.width, image.height)
    padded_image = ImageOps.expand(
        image,
        (side_length // 2, side_length // 2, side_length // 2, side_length // 2),
        fill=(255, 255, 255, 255)
    )
    left, top = cx, cy
    new_drags = []
    for d in drags:
        x, y = d
        new_x, new_y = (x + side_length // 2 - cx) / side_length, (y + side_length // 2 - cy) / side_length
        new_drags.append((new_x, new_y))

    # Crop or pad the image as needed to make it centered around (cx, cy)
    image = padded_image.crop((left, top, left + side_length, top + side_length))
    # Resize the image to 256x256
    image = image.resize((256, 256), Image.Resampling.LANCZOS)
    return image, new_drags

def sam_init():
    sam_checkpoint = os.path.join(os.path.dirname(__file__), "ckpts", "sam_vit_h_4b8939.pth")
    predictor = SamPredictor(build_sam(checkpoint=sam_checkpoint).to("cuda"))
    return predictor

def model_init():
    model_checkpoint = os.path.join(os.path.dirname(__file__), "ckpts", "drag-a-part-final.pt")
    model = UNet2DDragConditionModel.from_pretrained_sd(
        os.path.join(os.path.dirname(__file__), "ckpts", "stable-diffusion-v1-5"),
        unet_additional_kwargs=dict(
            sample_size=32,
            flow_original_res=False,
            input_concat_dragging=False,
            attn_concat_dragging=True,
            use_drag_tokens=False,
            single_drag_token=False,
            one_sided_attn=True,
            flow_in_old_version=False,
        ),
        load=False,
    )
    model.load_state_dict(torch.load(model_checkpoint)["model"])
    model = model.to("cuda")
    return model

def sam_segment(predictor, input_image, drags, foreground_points=None):
    image = np.asarray(input_image)
    predictor.set_image(image)

    with torch.no_grad():
        masks_bbox, _, _ = predictor.predict(
            point_coords=foreground_points if foreground_points is not None else None,
            point_labels=np.ones(len(foreground_points)) if foreground_points is not None else None,
            multimask_output=True
        )

    # Attach the predicted mask as an alpha channel, then recenter and resize.
    out_image = np.zeros((image.shape[0], image.shape[1], 4), dtype=np.uint8)
    out_image[:, :, :3] = image
    out_image[:, :, 3] = masks_bbox[-1].astype(np.uint8) * 255
    torch.cuda.empty_cache()
    out_image, new_drags = center_and_square_image(Image.fromarray(out_image, mode="RGBA"), drags)

    return out_image, new_drags

def get_point(img, sel_pix, evt: gr.SelectData):
    sel_pix.append(evt.index)
    points = []
    img = np.array(img)
    height = img.shape[0]
    arrow_width_large = 7 * height // 256
    arrow_width_small = 3 * height // 256
    circle_size = 5 * height // 256

    with_alpha = img.shape[2] == 4
    for idx, point in enumerate(sel_pix):
        if idx % 2 == 1:
            cv2.circle(img, tuple(point), circle_size, (0, 0, 255, 255) if with_alpha else (0, 0, 255), -1)
        else:
            cv2.circle(img, tuple(point), circle_size, (255, 0, 0, 255) if with_alpha else (255, 0, 0), -1)
        points.append(tuple(point))
        if len(points) == 2:
            cv2.arrowedLine(img, points[0], points[1], (0, 0, 0, 255) if with_alpha else (0, 0, 0), arrow_width_large)
            cv2.arrowedLine(img, points[0], points[1], (255, 255, 0, 255) if with_alpha else (0, 0, 0), arrow_width_small)
            points = []
    return img if isinstance(img, np.ndarray) else np.array(img)

def clear_drag():
    return []

def preprocess_image(SAM_predictor, img, chk_group, drags):
    if img is None:
        gr.Warning("No image is specified. Please specify an image before preprocessing.")
        return None, drags

    if drags is None or len(drags) == 0:
        foreground_points = None
    else:
        foreground_points = np.array([drags[i] for i in range(0, len(drags), 2)])

    if len(drags) == 0:
        gr.Warning("No drags are specified. We recommend first specifying the drags before preprocessing.")

    new_drags = drags
    if "Preprocess with Segmentation" in chk_group:
        img_np = np.array(img)
        rgb_img = img_np[..., :3]
        img, new_drags = sam_segment(
            SAM_predictor,
            rgb_img,
            drags,
            foreground_points=foreground_points,
        )
    else:
        new_drags = [(d[0] / img.width, d[1] / img.height) for d in drags]

    # Composite the RGBA image onto a white background and resize to 256x256.
    img = np.array(img).astype(np.float32)
    processed_img = img[..., :3] * img[..., 3:] / 255. + 255. * (1 - img[..., 3:] / 255.)
    image_pil = Image.fromarray(processed_img.astype(np.uint8), mode="RGB")
    processed_img = image_pil.resize((256, 256), Image.LANCZOS)
    return processed_img, new_drags

def single_image_sample(
    model,
    diffusion,
    x_cond,
    x_cond_clip,
    rel,
    cfg_scale,
    x_cond_extra,
    drags,
    hidden_cls,
    num_steps=50,
):
    z = torch.randn(2, 4, 32, 32).to("cuda")

    # Prepare input for classifier-free guidance
    rel = torch.cat([rel, rel], dim=0)
    x_cond = torch.cat([x_cond, x_cond], dim=0)
    x_cond_clip = torch.cat([x_cond_clip, x_cond_clip], dim=0)
    x_cond_extra = torch.cat([x_cond_extra, x_cond_extra], dim=0)
    drags = torch.cat([drags, drags], dim=0)
    hidden_cls = torch.cat([hidden_cls, hidden_cls], dim=0)

    model_kwargs = dict(
        x_cond=x_cond,
        x_cond_extra=x_cond_extra,
        cfg_scale=cfg_scale,
        hidden_cls=hidden_cls,
        drags=drags,
    )

    # Denoising
    step_delta = diffusion.num_timesteps // num_steps
    for i in tqdm(range(num_steps)):
        with torch.no_grad():
            samples = diffusion.p_sample(
                model.forward_with_cfg,
                z,
                torch.Tensor([diffusion.num_timesteps - 1 - step_delta * i]).long().to("cuda").repeat(z.shape[0]),
                clip_denoised=False,
                model_kwargs=model_kwargs,
            )["pred_xstart"]
            if i != num_steps - 1:
                z = diffusion.q_sample(
                    samples,
                    torch.Tensor([diffusion.num_timesteps - 1 - step_delta * i]).long().to("cuda").repeat(z.shape[0])
                )

    samples, _ = samples.chunk(2, dim=0)
    return samples

def generate_image(model, image_processor, vae, clip_model, clip_vit, diffusion, img_cond, seed, cfg_scale, drags_list):
    if img_cond is None:
        gr.Warning("Please preprocess the image first.")
        return None

    with torch.no_grad():
        torch.manual_seed(seed)
        np.random.seed(seed)
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        random.seed(seed)

        pixels_cond = transforms.ToTensor()(img_cond.astype(np.float32) / 127.5 - 1).unsqueeze(0).to("cuda")

        cond_pixel_preprocessed_for_clip = image_processor(
            images=Image.fromarray(img_cond), return_tensors="pt"
        ).pixel_values.to("cuda")
        with torch.no_grad():
            x_cond = vae.encode(pixels_cond).latent_dist.sample().mul_(0.18215)
            cond_clip_features = clip_model.get_image_features(cond_pixel_preprocessed_for_clip)
            cls_embedding = torch.stack(
                clip_vit(pixel_values=cond_pixel_preprocessed_for_clip, output_hidden_states=True).hidden_states,
                dim=1
            )[:, :, 0]

        # dummies
        rel = torch.zeros(1, 4).to("cuda")
        x_cond_extra = torch.zeros(1, 3, 32, 32).to("cuda")

        drags = torch.zeros(1, 10, 4).to("cuda")
        for i in range(0, len(drags_list), 2):
            if i + 1 == len(drags_list):
                gr.Warning("The ending point of the last drag is not specified. The last drag is ignored.")
                break

            idx = i // 2
            drags[0, idx, 0], drags[0, idx, 1], drags[0, idx, 2], drags[0, idx, 3] = \
                drags_list[i][0], drags_list[i][1], drags_list[i + 1][0], drags_list[i + 1][1]

            if idx == 9:
                break

        samples = single_image_sample(
            model,
            diffusion,
            x_cond,
            cond_clip_features,
            rel,
            cfg_scale,
            x_cond_extra,
            drags,
            cls_embedding,
            num_steps=50,
        )

        with torch.no_grad():
            images = vae.decode(samples / 0.18215).sample
            images = ((images + 1)[0].permute(1, 2, 0) * 127.5).cpu().numpy().clip(0, 255).astype(np.uint8)
        return images


sam_predictor = sam_init()
model = model_init()

vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-ema").to('cuda')
clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to('cuda')
clip_vit = CLIPVisionModel.from_pretrained("openai/clip-vit-large-patch14").to('cuda')
image_processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32")
diffusion = create_diffusion(
    timestep_respacing="",
    learn_sigma=False,
)

with gr.Blocks(title=TITLE) as demo:
    gr.Markdown("# " + DESCRIPTION)

    with gr.Row():
        gr.Markdown(INSTRUCTION)

    drags = gr.State(value=[])

    with gr.Row(variant="panel"):
        with gr.Column(scale=1):
            input_image = gr.Image(
                interactive=True,
                type='pil',
                image_mode="RGBA",
                width=256,
                show_label=True,
                label="Input Image",
            )

            example_folder = os.path.join(os.path.dirname(__file__), "./example_images")
            example_fns = [os.path.join(example_folder, example) for example in sorted(os.listdir(example_folder))]
            gr.Examples(
                examples=example_fns,
                inputs=[input_image],
                cache_examples=False,
                label='Feel free to use one of our provided examples!',
                examples_per_page=30
            )

            input_image.change(
                fn=clear_drag,
                outputs=[drags],
            )

        with gr.Column(scale=1):
            drag_image = gr.Image(
                type="numpy",
                label="Image with Drags",
                interactive=False,
                width=256,
                image_mode="RGB",
            )

            input_image.select(
                fn=get_point,
                inputs=[input_image, drags],
                outputs=[drag_image],
            )

        with gr.Column(scale=1):
            processed_image = gr.Image(
                type='numpy',
                label="Processed Image",
                interactive=False,
                width=256,
                height=256,
                image_mode='RGB',
            )
            processed_image_highres = gr.Image(type='pil', image_mode='RGB', visible=False)

            with gr.Accordion('Advanced preprocessing options', open=True):
                with gr.Row():
                    with gr.Column():
                        preprocess_chk_group = gr.CheckboxGroup(
                            ['Preprocess with Segmentation'],
                            label='Segment',
                            info=PREPROCESS_INSTRUCTION
                        )

            preprocess_button = gr.Button(
                value="Preprocess Input Image",
            )
            preprocess_button.click(
                fn=partial(preprocess_image, sam_predictor),
                inputs=[input_image, preprocess_chk_group, drags],
                outputs=[processed_image, drags],
                queue=True,
            )

        with gr.Column(scale=1):
            generated_image = gr.Image(
                type="numpy",
                label="Generated Image",
                interactive=False,
                height=256,
                width=256,
                image_mode="RGB",
            )

            with gr.Accordion('Advanced generation options', open=True):
                with gr.Row():
                    with gr.Column():
                        seed = gr.Slider(label="seed", value=0, minimum=0, maximum=10000, step=1, randomize=False)
                        cfg_scale = gr.Slider(
                            label="classifier-free guidance weight",
                            value=5, minimum=1, maximum=10, step=0.1
                        )

            generate_button = gr.Button(
                value="Generate Image",
            )
            generate_button.click(
                fn=partial(generate_image, model, image_processor, vae, clip_model, clip_vit, diffusion),
                inputs=[processed_image, seed, cfg_scale, drags],
                outputs=[generated_image],
            )

demo.launch(share=True)
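For reference, a minimal, self-contained sketch of the drag encoding used above: the list of clicked points (alternating start/end, already normalized to [0, 1] by preprocess_image) is packed into the (1, 10, 4) tensor that generate_image passes to the UNet. The helper name drags_to_tensor is illustrative only and is not part of the committed code.

import torch

def drags_to_tensor(drags_list, max_drags=10):
    """Pack (start, end) click pairs, normalized to [0, 1], into a (1, max_drags, 4) tensor.
    An incomplete trailing drag is dropped, mirroring generate_image() above."""
    drags = torch.zeros(1, max_drags, 4)
    for slot, i in enumerate(range(0, len(drags_list) - 1, 2)):
        if slot == max_drags:
            break
        (x0, y0), (x1, y1) = drags_list[i], drags_list[i + 1]
        drags[0, slot] = torch.tensor([x0, y0, x1, y1])
    return drags

# e.g. one drag from (0.3, 0.4) to (0.5, 0.4) fills a single row; the rest stay zero.
print(drags_to_tensor([(0.3, 0.4), (0.5, 0.4)]))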
diffusion/__init__.py ADDED
@@ -0,0 +1,46 @@
# Modified from OpenAI's diffusion repos
# GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
# ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
# IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py

from . import gaussian_diffusion as gd
from .respace import SpacedDiffusion, space_timesteps


def create_diffusion(
    timestep_respacing,
    noise_schedule="linear",
    use_kl=False,
    sigma_small=False,
    predict_xstart=False,
    learn_sigma=True,
    rescale_learned_sigmas=False,
    diffusion_steps=1000
):
    betas = gd.get_named_beta_schedule(noise_schedule, diffusion_steps)
    if use_kl:
        loss_type = gd.LossType.RESCALED_KL
    elif rescale_learned_sigmas:
        loss_type = gd.LossType.RESCALED_MSE
    else:
        loss_type = gd.LossType.MSE
    if timestep_respacing is None or timestep_respacing == "":
        timestep_respacing = [diffusion_steps]
    return SpacedDiffusion(
        use_timesteps=space_timesteps(diffusion_steps, timestep_respacing),
        betas=betas,
        model_mean_type=(
            gd.ModelMeanType.EPSILON if not predict_xstart else gd.ModelMeanType.START_X
        ),
        model_var_type=(
            (
                gd.ModelVarType.FIXED_LARGE
                if not sigma_small
                else gd.ModelVarType.FIXED_SMALL
            )
            if not learn_sigma
            else gd.ModelVarType.LEARNED_RANGE
        ),
        loss_type=loss_type
        # rescale_timesteps=rescale_timesteps,
    )
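A short usage sketch, mirroring the call made in app.py and assuming this package is importable as diffusion from the repository root: with timestep_respacing="" and learn_sigma=False the factory returns an un-respaced 1000-step sampler with a linear beta schedule, epsilon prediction, fixed variance, and MSE loss.

from diffusion import create_diffusion

diffusion = create_diffusion(timestep_respacing="", learn_sigma=False)
print(diffusion.num_timesteps)   # 1000 (no respacing)
print(diffusion.betas.shape)     # (1000,)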
diffusion/diffusion_utils.py ADDED
@@ -0,0 +1,88 @@
# Modified from OpenAI's diffusion repos
# GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
# ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
# IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py

import torch as th
import numpy as np


def normal_kl(mean1, logvar1, mean2, logvar2):
    """
    Compute the KL divergence between two gaussians.
    Shapes are automatically broadcasted, so batches can be compared to
    scalars, among other use cases.
    """
    tensor = None
    for obj in (mean1, logvar1, mean2, logvar2):
        if isinstance(obj, th.Tensor):
            tensor = obj
            break
    assert tensor is not None, "at least one argument must be a Tensor"

    # Force variances to be Tensors. Broadcasting helps convert scalars to
    # Tensors, but it does not work for th.exp().
    logvar1, logvar2 = [
        x if isinstance(x, th.Tensor) else th.tensor(x).to(tensor)
        for x in (logvar1, logvar2)
    ]

    return 0.5 * (
        -1.0
        + logvar2
        - logvar1
        + th.exp(logvar1 - logvar2)
        + ((mean1 - mean2) ** 2) * th.exp(-logvar2)
    )


def approx_standard_normal_cdf(x):
    """
    A fast approximation of the cumulative distribution function of the
    standard normal.
    """
    return 0.5 * (1.0 + th.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * th.pow(x, 3))))


def continuous_gaussian_log_likelihood(x, *, means, log_scales):
    """
    Compute the log-likelihood of a continuous Gaussian distribution.
    :param x: the targets
    :param means: the Gaussian mean Tensor.
    :param log_scales: the Gaussian log stddev Tensor.
    :return: a tensor like x of log probabilities (in nats).
    """
    centered_x = x - means
    inv_stdv = th.exp(-log_scales)
    normalized_x = centered_x * inv_stdv
    log_probs = th.distributions.Normal(th.zeros_like(x), th.ones_like(x)).log_prob(normalized_x)
    return log_probs


def discretized_gaussian_log_likelihood(x, *, means, log_scales):
    """
    Compute the log-likelihood of a Gaussian distribution discretizing to a
    given image.
    :param x: the target images. It is assumed that this was uint8 values,
              rescaled to the range [-1, 1].
    :param means: the Gaussian mean Tensor.
    :param log_scales: the Gaussian log stddev Tensor.
    :return: a tensor like x of log probabilities (in nats).
    """
    assert x.shape == means.shape == log_scales.shape
    centered_x = x - means
    inv_stdv = th.exp(-log_scales)
    plus_in = inv_stdv * (centered_x + 1.0 / 255.0)
    cdf_plus = approx_standard_normal_cdf(plus_in)
    min_in = inv_stdv * (centered_x - 1.0 / 255.0)
    cdf_min = approx_standard_normal_cdf(min_in)
    log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12))
    log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12))
    cdf_delta = cdf_plus - cdf_min
    log_probs = th.where(
        x < -0.999,
        log_cdf_plus,
        th.where(x > 0.999, log_one_minus_cdf_min, th.log(cdf_delta.clamp(min=1e-12))),
    )
    assert log_probs.shape == x.shape
    return log_probs
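A quick numeric sanity check of normal_kl (assuming the repository root is on PYTHONPATH): for unit variances the formula reduces to 0.5 * (mean1 - mean2)^2 nats.

import torch as th
from diffusion.diffusion_utils import normal_kl

# KL(N(0,1) || N(0,1)) is zero, elementwise.
print(normal_kl(th.zeros(3), th.zeros(3), 0.0, 0.0))  # tensor([0., 0., 0.])

# KL(N(1,1) || N(0,1)) = 0.5 nats.
print(normal_kl(th.ones(3), th.zeros(3), 0.0, 0.0))   # tensor([0.5000, 0.5000, 0.5000])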
diffusion/gaussian_diffusion.py ADDED
@@ -0,0 +1,957 @@
1
+ # Modified from OpenAI's diffusion repos
2
+ # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
3
+ # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
4
+ # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
5
+
6
+
7
+ import math
8
+
9
+ import numpy as np
10
+ import torch as th
11
+ import torch.nn.functional as F
12
+ import enum
13
+
14
+ from .diffusion_utils import discretized_gaussian_log_likelihood, normal_kl
15
+
16
+
17
+ def mean_flat(tensor):
18
+ """
19
+ Take the mean over all non-batch dimensions.
20
+ """
21
+ return tensor.mean(dim=list(range(1, len(tensor.shape))))
22
+
23
+
24
+ class ModelMeanType(enum.Enum):
25
+ """
26
+ Which type of output the model predicts.
27
+ """
28
+
29
+ PREVIOUS_X = enum.auto() # the model predicts x_{t-1}
30
+ START_X = enum.auto() # the model predicts x_0
31
+ EPSILON = enum.auto() # the model predicts epsilon
32
+
33
+
34
+ class ModelVarType(enum.Enum):
35
+ """
36
+ What is used as the model's output variance.
37
+ The LEARNED_RANGE option has been added to allow the model to predict
38
+ values between FIXED_SMALL and FIXED_LARGE, making its job easier.
39
+ """
40
+
41
+ LEARNED = enum.auto()
42
+ FIXED_SMALL = enum.auto()
43
+ FIXED_LARGE = enum.auto()
44
+ LEARNED_RANGE = enum.auto()
45
+
46
+
47
+ class LossType(enum.Enum):
48
+ MSE = enum.auto() # use raw MSE loss (and KL when learning variances)
49
+ RESCALED_MSE = (
50
+ enum.auto()
51
+ ) # use raw MSE loss (with RESCALED_KL when learning variances)
52
+ KL = enum.auto() # use the variational lower-bound
53
+ RESCALED_KL = enum.auto() # like KL, but rescale to estimate the full VLB
54
+
55
+ def is_vb(self):
56
+ return self == LossType.KL or self == LossType.RESCALED_KL
57
+
58
+
59
+ def _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, warmup_frac):
60
+ betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64)
61
+ warmup_time = int(num_diffusion_timesteps * warmup_frac)
62
+ betas[:warmup_time] = np.linspace(beta_start, beta_end, warmup_time, dtype=np.float64)
63
+ return betas
64
+
65
+
66
+ def get_beta_schedule(beta_schedule, *, beta_start, beta_end, num_diffusion_timesteps):
67
+ """
68
+ This is the deprecated API for creating beta schedules.
69
+ See get_named_beta_schedule() for the new library of schedules.
70
+ """
71
+ if beta_schedule == "quad":
72
+ betas = (
73
+ np.linspace(
74
+ beta_start ** 0.5,
75
+ beta_end ** 0.5,
76
+ num_diffusion_timesteps,
77
+ dtype=np.float64,
78
+ )
79
+ ** 2
80
+ )
81
+ elif beta_schedule == "linear":
82
+ betas = np.linspace(beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64)
83
+ elif beta_schedule == "warmup10":
84
+ betas = _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, 0.1)
85
+ elif beta_schedule == "warmup50":
86
+ betas = _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, 0.5)
87
+ elif beta_schedule == "const":
88
+ betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64)
89
+ elif beta_schedule == "jsd": # 1/T, 1/(T-1), 1/(T-2), ..., 1
90
+ betas = 1.0 / np.linspace(
91
+ num_diffusion_timesteps, 1, num_diffusion_timesteps, dtype=np.float64
92
+ )
93
+ else:
94
+ raise NotImplementedError(beta_schedule)
95
+ assert betas.shape == (num_diffusion_timesteps,)
96
+ return betas
97
+
98
+
99
+ def get_named_beta_schedule(schedule_name, num_diffusion_timesteps):
100
+ """
101
+ Get a pre-defined beta schedule for the given name.
102
+ The beta schedule library consists of beta schedules which remain similar
103
+ in the limit of num_diffusion_timesteps.
104
+ Beta schedules may be added, but should not be removed or changed once
105
+ they are committed to maintain backwards compatibility.
106
+ """
107
+ if schedule_name == "linear":
108
+ # Linear schedule from Ho et al, extended to work for any number of
109
+ # diffusion steps.
110
+ scale = 1000 / num_diffusion_timesteps
111
+ return get_beta_schedule(
112
+ "linear",
113
+ beta_start=scale * 0.0001,
114
+ beta_end=scale * 0.02,
115
+ num_diffusion_timesteps=num_diffusion_timesteps,
116
+ )
117
+ elif schedule_name == "squaredcos_cap_v2":
118
+ return betas_for_alpha_bar(
119
+ num_diffusion_timesteps,
120
+ lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2,
121
+ )
122
+ else:
123
+ raise NotImplementedError(f"unknown beta schedule: {schedule_name}")
124
+
125
+
126
+ def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
127
+ """
128
+ Create a beta schedule that discretizes the given alpha_t_bar function,
129
+ which defines the cumulative product of (1-beta) over time from t = [0,1].
130
+ :param num_diffusion_timesteps: the number of betas to produce.
131
+ :param alpha_bar: a lambda that takes an argument t from 0 to 1 and
132
+ produces the cumulative product of (1-beta) up to that
133
+ part of the diffusion process.
134
+ :param max_beta: the maximum beta to use; use values lower than 1 to
135
+ prevent singularities.
136
+ """
137
+ betas = []
138
+ for i in range(num_diffusion_timesteps):
139
+ t1 = i / num_diffusion_timesteps
140
+ t2 = (i + 1) / num_diffusion_timesteps
141
+ betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
142
+ return np.array(betas)
143
+
144
+
145
+ class GaussianDiffusion:
146
+ """
147
+ Utilities for training and sampling diffusion models.
148
+ Original ported from this codebase:
149
+ https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py#L42
150
+ :param betas: a 1-D numpy array of betas for each diffusion timestep,
151
+ starting at T and going to 1.
152
+ """
153
+
154
+ def __init__(
155
+ self,
156
+ *,
157
+ betas,
158
+ model_mean_type,
159
+ model_var_type,
160
+ loss_type
161
+ ):
162
+
163
+ self.model_mean_type = model_mean_type
164
+ self.model_var_type = model_var_type
165
+ self.loss_type = loss_type
166
+
167
+ # Use float64 for accuracy.
168
+ betas = np.array(betas, dtype=np.float64)
169
+ self.betas = betas
170
+ assert len(betas.shape) == 1, "betas must be 1-D"
171
+ assert (betas > 0).all() and (betas <= 1).all()
172
+
173
+ self.num_timesteps = int(betas.shape[0])
174
+
175
+ alphas = 1.0 - betas
176
+ self.alphas = alphas
177
+ self.alphas_cumprod = np.cumprod(alphas, axis=0)
178
+ self.alphas_cumprod_prev = np.append(1.0, self.alphas_cumprod[:-1])
179
+ self.alphas_cumprod_next = np.append(self.alphas_cumprod[1:], 0.0)
180
+ assert self.alphas_cumprod_prev.shape == (self.num_timesteps,)
181
+
182
+ # calculations for diffusion q(x_t | x_{t-1}) and others
183
+ self.sqrt_alphas_cumprod = np.sqrt(self.alphas_cumprod)
184
+ self.sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - self.alphas_cumprod)
185
+ self.log_one_minus_alphas_cumprod = np.log(1.0 - self.alphas_cumprod)
186
+ self.sqrt_recip_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod)
187
+ self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod - 1)
188
+
189
+ # calculations for posterior q(x_{t-1} | x_t, x_0)
190
+ self.posterior_variance = (
191
+ betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
192
+ )
193
+ # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
194
+ self.posterior_log_variance_clipped = np.log(
195
+ np.append(self.posterior_variance[1], self.posterior_variance[1:])
196
+ ) if len(self.posterior_variance) > 1 else np.array([])
197
+
198
+ self.posterior_mean_coef1 = (
199
+ betas * np.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
200
+ )
201
+ self.posterior_mean_coef2 = (
202
+ (1.0 - self.alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - self.alphas_cumprod)
203
+ )
204
+
205
+ def q_mean_variance(self, x_start, t):
206
+ """
207
+ Get the distribution q(x_t | x_0).
208
+ :param x_start: the [N x C x ...] tensor of noiseless inputs.
209
+ :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
210
+ :return: A tuple (mean, variance, log_variance), all of x_start's shape.
211
+ """
212
+ mean = _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
213
+ variance = _extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape)
214
+ log_variance = _extract_into_tensor(self.log_one_minus_alphas_cumprod, t, x_start.shape)
215
+ return mean, variance, log_variance
216
+
217
+ def q_sample(self, x_start, t, noise=None):
218
+ """
219
+ Diffuse the data for a given number of diffusion steps.
220
+ In other words, sample from q(x_t | x_0).
221
+ :param x_start: the initial data batch.
222
+ :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
223
+ :param noise: if specified, the split-out normal noise.
224
+ :return: A noisy version of x_start.
225
+ """
226
+ if noise is None:
227
+ noise = th.randn_like(x_start)
228
+ assert noise.shape == x_start.shape
229
+ return (
230
+ _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
231
+ + _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
232
+ )
233
+
234
+ def q_posterior_mean_variance(self, x_start, x_t, t):
235
+ """
236
+ Compute the mean and variance of the diffusion posterior:
237
+ q(x_{t-1} | x_t, x_0)
238
+ """
239
+ assert x_start.shape == x_t.shape
240
+ posterior_mean = (
241
+ _extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start
242
+ + _extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t
243
+ )
244
+ posterior_variance = _extract_into_tensor(self.posterior_variance, t, x_t.shape)
245
+ posterior_log_variance_clipped = _extract_into_tensor(
246
+ self.posterior_log_variance_clipped, t, x_t.shape
247
+ )
248
+ assert (
249
+ posterior_mean.shape[0]
250
+ == posterior_variance.shape[0]
251
+ == posterior_log_variance_clipped.shape[0]
252
+ == x_start.shape[0]
253
+ )
254
+ return posterior_mean, posterior_variance, posterior_log_variance_clipped
255
+
256
+ def p_mean_variance(self, model, x, t, clip_denoised=True, denoised_fn=None, model_kwargs=None):
257
+ """
258
+ Apply the model to get p(x_{t-1} | x_t), as well as a prediction of
259
+ the initial x, x_0.
260
+ :param model: the model, which takes a signal and a batch of timesteps
261
+ as input.
262
+ :param x: the [N x C x ...] tensor at time t.
263
+ :param t: a 1-D Tensor of timesteps.
264
+ :param clip_denoised: if True, clip the denoised signal into [-1, 1].
265
+ :param denoised_fn: if not None, a function which applies to the
266
+ x_start prediction before it is used to sample. Applies before
267
+ clip_denoised.
268
+ :param model_kwargs: if not None, a dict of extra keyword arguments to
269
+ pass to the model. This can be used for conditioning.
270
+ :return: a dict with the following keys:
271
+ - 'mean': the model mean output.
272
+ - 'variance': the model variance output.
273
+ - 'log_variance': the log of 'variance'.
274
+ - 'pred_xstart': the prediction for x_0.
275
+ """
276
+ if model_kwargs is None:
277
+ model_kwargs = {}
278
+ elif callable(model_kwargs):
279
+ model_kwargs = model_kwargs()
280
+
281
+ B, C = x.shape[:2]
282
+ assert t.shape == (B,)
283
+ model_output = model(x, t, **model_kwargs)
284
+ if isinstance(model_output, tuple):
285
+ model_output, extra = model_output
286
+ else:
287
+ extra = None
288
+
289
+ if self.model_var_type in [ModelVarType.LEARNED, ModelVarType.LEARNED_RANGE]:
290
+ assert model_output.shape == (B, C * 2, *x.shape[2:])
291
+ model_output, model_var_values = th.split(model_output, C, dim=1)
292
+ min_log = _extract_into_tensor(self.posterior_log_variance_clipped, t, x.shape)
293
+ max_log = _extract_into_tensor(np.log(self.betas), t, x.shape)
294
+ # The model_var_values is [-1, 1] for [min_var, max_var].
295
+ frac = (model_var_values + 1) / 2
296
+ model_log_variance = frac * max_log + (1 - frac) * min_log
297
+ model_variance = th.exp(model_log_variance)
298
+ else:
299
+ model_variance, model_log_variance = {
300
+ # for fixedlarge, we set the initial (log-)variance like so
301
+ # to get a better decoder log likelihood.
302
+ ModelVarType.FIXED_LARGE: (
303
+ np.append(self.posterior_variance[1], self.betas[1:]),
304
+ np.log(np.append(self.posterior_variance[1], self.betas[1:])),
305
+ ),
306
+ ModelVarType.FIXED_SMALL: (
307
+ self.posterior_variance,
308
+ self.posterior_log_variance_clipped,
309
+ ),
310
+ }[self.model_var_type]
311
+ model_variance = _extract_into_tensor(model_variance, t, x.shape)
312
+ model_log_variance = _extract_into_tensor(model_log_variance, t, x.shape)
313
+
314
+ def process_xstart(x):
315
+ if denoised_fn is not None:
316
+ x = denoised_fn(x)
317
+ if clip_denoised:
318
+ return x.clamp(-1, 1)
319
+ return x
320
+
321
+ if self.model_mean_type == ModelMeanType.START_X:
322
+ pred_xstart = process_xstart(model_output)
323
+ else:
324
+ pred_xstart = process_xstart(
325
+ self._predict_xstart_from_eps(x_t=x, t=t, eps=model_output)
326
+ )
327
+ model_mean, _, _ = self.q_posterior_mean_variance(x_start=pred_xstart, x_t=x, t=t)
328
+
329
+ assert model_mean.shape == model_log_variance.shape == pred_xstart.shape == x.shape
330
+ return {
331
+ "mean": model_mean,
332
+ "variance": model_variance,
333
+ "log_variance": model_log_variance,
334
+ "pred_xstart": pred_xstart,
335
+ "extra": extra,
336
+ }
337
+
338
+ def _predict_xstart_from_eps(self, x_t, t, eps):
339
+ assert x_t.shape == eps.shape
340
+ return (
341
+ _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
342
+ - _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps
343
+ )
344
+
345
+ def _predict_eps_from_xstart(self, x_t, t, pred_xstart):
346
+ return (
347
+ _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart
348
+ ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
349
+
350
+ def condition_mean(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
351
+ """
352
+ Compute the mean for the previous step, given a function cond_fn that
353
+ computes the gradient of a conditional log probability with respect to
354
+ x. In particular, cond_fn computes grad(log(p(y|x))), and we want to
355
+ condition on y.
356
+ This uses the conditioning strategy from Sohl-Dickstein et al. (2015).
357
+ """
358
+ gradient = cond_fn(x, t, **model_kwargs)
359
+ new_mean = p_mean_var["mean"].float() + p_mean_var["variance"] * gradient.float()
360
+ return new_mean
361
+
362
+ def condition_score(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
363
+ """
364
+ Compute what the p_mean_variance output would have been, should the
365
+ model's score function be conditioned by cond_fn.
366
+ See condition_mean() for details on cond_fn.
367
+ Unlike condition_mean(), this instead uses the conditioning strategy
368
+ from Song et al (2020).
369
+ """
370
+ alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
371
+
372
+ eps = self._predict_eps_from_xstart(x, t, p_mean_var["pred_xstart"])
373
+ eps = eps - (1 - alpha_bar).sqrt() * cond_fn(x, t, **model_kwargs)
374
+
375
+ out = p_mean_var.copy()
376
+ out["pred_xstart"] = self._predict_xstart_from_eps(x, t, eps)
377
+ out["mean"], _, _ = self.q_posterior_mean_variance(x_start=out["pred_xstart"], x_t=x, t=t)
378
+ return out
379
+
380
+ def p_sample(
381
+ self,
382
+ model,
383
+ x,
384
+ t,
385
+ clip_denoised=True,
386
+ denoised_fn=None,
387
+ cond_fn=None,
388
+ model_kwargs=None,
389
+ keep_mask_region=None,
390
+ original_x=None,
391
+ resampling_steps: int = 20,
392
+ ):
393
+ """
394
+ Sample x_{t-1} from the model at the given timestep.
395
+ :param model: the model to sample from.
396
+ :param x: the current tensor at x_{t-1}.
397
+ :param t: the value of t, starting at 0 for the first diffusion step.
398
+ :param clip_denoised: if True, clip the x_start prediction to [-1, 1].
399
+ :param denoised_fn: if not None, a function which applies to the
400
+ x_start prediction before it is used to sample.
401
+ :param cond_fn: if not None, this is a gradient function that acts
402
+ similarly to the model.
403
+ :param model_kwargs: if not None, a dict of extra keyword arguments to
404
+ pass to the model. This can be used for conditioning.
405
+ :return: a dict containing the following keys:
406
+ - 'sample': a random sample from the model.
407
+ - 'pred_xstart': a prediction of x_0.
408
+ """
409
+ if keep_mask_region is None:
410
+ out = self.p_mean_variance(
411
+ model,
412
+ x,
413
+ t,
414
+ clip_denoised=clip_denoised,
415
+ denoised_fn=denoised_fn,
416
+ model_kwargs=model_kwargs,
417
+ )
418
+ noise = th.randn_like(x)
419
+ nonzero_mask = (
420
+ (t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
421
+ ) # no noise when t == 0
422
+ if cond_fn is not None:
423
+ out["mean"] = self.condition_mean(cond_fn, out, x, t, model_kwargs=model_kwargs)
424
+ sample = out["mean"] + nonzero_mask * th.exp(0.5 * out["log_variance"]) * noise
425
+ else:
426
+ assert original_x is not None
427
+ for _ in range(resampling_steps):
428
+ out = self.p_mean_variance(
429
+ model,
430
+ x,
431
+ t,
432
+ clip_denoised=clip_denoised,
433
+ denoised_fn=denoised_fn,
434
+ model_kwargs=model_kwargs,
435
+ )
436
+ noise = th.randn_like(x)
437
+ nonzero_mask = (
438
+ (t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
439
+ ) # no noise when t == 0
440
+ if cond_fn is not None:
441
+ out["mean"] = self.condition_mean(cond_fn, out, x, t, model_kwargs=model_kwargs)
442
+ sample = out["mean"] + nonzero_mask * th.exp(0.5 * out["log_variance"]) * noise
443
+ t_neq_0 = t.clone()
444
+ t_neq_0[t_neq_0 == 0] = 1
445
+ x_known_sample = (1 - nonzero_mask) * original_x + nonzero_mask * self.q_sample(original_x, t_neq_0)
446
+ sample = keep_mask_region * x_known_sample + (1 - keep_mask_region) * sample
447
+
448
+ n = th.randn_like(x)
449
+ x = th.sqrt(_extract_into_tensor(self.alphas, t, x.shape)) * sample + \
450
+ _extract_into_tensor(self.betas, t, x.shape) * n
451
+
452
+ return {"sample": sample, "pred_xstart": out["pred_xstart"]}
453
+
454
+ def p_sample_loop(
455
+ self,
456
+ model,
457
+ shape,
458
+ noise=None,
459
+ clip_denoised=True,
460
+ denoised_fn=None,
461
+ cond_fn=None,
462
+ model_kwargs=None,
463
+ device=None,
464
+ progress=False,
465
+ keep_mask_region=None,
466
+ original_x=None,
467
+ resampling_steps: int = 20,
468
+ ):
469
+ """
470
+ Generate samples from the model.
471
+ :param model: the model module.
472
+ :param shape: the shape of the samples, (N, C, H, W).
473
+ :param noise: if specified, the noise from the encoder to sample.
474
+ Should be of the same shape as `shape`.
475
+ :param clip_denoised: if True, clip x_start predictions to [-1, 1].
476
+ :param denoised_fn: if not None, a function which applies to the
477
+ x_start prediction before it is used to sample.
478
+ :param cond_fn: if not None, this is a gradient function that acts
479
+ similarly to the model.
480
+ :param model_kwargs: if not None, a dict of extra keyword arguments to
481
+ pass to the model. This can be used for conditioning.
482
+ :param device: if specified, the device to create the samples on.
483
+ If not specified, use a model parameter's device.
484
+ :param progress: if True, show a tqdm progress bar.
485
+ :return: a non-differentiable batch of samples.
486
+ """
487
+ final = None
488
+ for sample in self.p_sample_loop_progressive(
489
+ model,
490
+ shape,
491
+ noise=noise,
492
+ clip_denoised=clip_denoised,
493
+ denoised_fn=denoised_fn,
494
+ cond_fn=cond_fn,
495
+ model_kwargs=model_kwargs,
496
+ device=device,
497
+ progress=progress,
498
+ keep_mask_region=keep_mask_region,
499
+ original_x=original_x,
500
+ resampling_steps=resampling_steps,
501
+ ):
502
+ final = sample
503
+ return final["sample"]
504
+
505
+ def p_sample_loop_progressive(
506
+ self,
507
+ model,
508
+ shape,
509
+ noise=None,
510
+ clip_denoised=True,
511
+ denoised_fn=None,
512
+ cond_fn=None,
513
+ model_kwargs=None,
514
+ device=None,
515
+ progress=False,
516
+ keep_mask_region=None,
517
+ original_x=None,
518
+ resampling_steps: int = 20,
519
+ ):
520
+ """
521
+ Generate samples from the model and yield intermediate samples from
522
+ each timestep of diffusion.
523
+ Arguments are the same as p_sample_loop().
524
+ Returns a generator over dicts, where each dict is the return value of
525
+ p_sample().
526
+ """
527
+ if device is None:
528
+ device = next(model.parameters()).device
529
+ assert isinstance(shape, (tuple, list))
530
+ if noise is not None:
531
+ img = noise
532
+ else:
533
+ img = th.randn(*shape, device=device)
534
+ indices = list(range(self.num_timesteps))[::-1]
535
+
536
+ if progress:
537
+ # Lazy import so that we don't depend on tqdm.
538
+ from tqdm.auto import tqdm
539
+
540
+ indices = tqdm(indices)
541
+
542
+ for i in indices:
543
+ t = th.tensor([i] * shape[0], device=device)
544
+ with th.no_grad():
545
+ out = self.p_sample(
546
+ model,
547
+ img,
548
+ t,
549
+ clip_denoised=clip_denoised,
550
+ denoised_fn=denoised_fn,
551
+ cond_fn=cond_fn,
552
+ model_kwargs=model_kwargs,
553
+ keep_mask_region=keep_mask_region,
554
+ original_x=original_x,
555
+ resampling_steps=resampling_steps,
556
+ )
557
+ yield out
558
+ img = out["sample"]
559
+
560
+ def ddim_sample(
561
+ self,
562
+ model,
563
+ x,
564
+ t,
565
+ clip_denoised=True,
566
+ denoised_fn=None,
567
+ cond_fn=None,
568
+ model_kwargs=None,
569
+ eta=0.0,
570
+ ):
571
+ """
572
+ Sample x_{t-1} from the model using DDIM.
573
+ Same usage as p_sample().
574
+ """
575
+ out = self.p_mean_variance(
576
+ model,
577
+ x,
578
+ t,
579
+ clip_denoised=clip_denoised,
580
+ denoised_fn=denoised_fn,
581
+ model_kwargs=model_kwargs,
582
+ )
583
+ if cond_fn is not None:
584
+ out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs)
585
+
586
+ # Usually our model outputs epsilon, but we re-derive it
587
+ # in case we used x_start or x_prev prediction.
588
+ eps = self._predict_eps_from_xstart(x, t, out["pred_xstart"])
589
+
590
+ alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
591
+ alpha_bar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t, x.shape)
592
+ sigma = (
593
+ eta
594
+ * th.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar))
595
+ * th.sqrt(1 - alpha_bar / alpha_bar_prev)
596
+ )
597
+ # Equation 12.
598
+ noise = th.randn_like(x)
599
+ mean_pred = (
600
+ out["pred_xstart"] * th.sqrt(alpha_bar_prev)
601
+ + th.sqrt(1 - alpha_bar_prev - sigma ** 2) * eps
602
+ )
603
+ nonzero_mask = (
604
+ (t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
605
+ ) # no noise when t == 0
606
+ sample = mean_pred + nonzero_mask * sigma * noise
607
+ return {"sample": sample, "pred_xstart": out["pred_xstart"]}
608
+
609
+ def ddim_reverse_sample(
610
+ self,
611
+ model,
612
+ x,
613
+ t,
614
+ clip_denoised=True,
615
+ denoised_fn=None,
616
+ cond_fn=None,
617
+ model_kwargs=None,
618
+ eta=0.0,
619
+ ):
620
+ """
621
+ Sample x_{t+1} from the model using DDIM reverse ODE.
622
+ """
623
+ assert eta == 0.0, "Reverse ODE only for deterministic path"
624
+ out = self.p_mean_variance(
625
+ model,
626
+ x,
627
+ t,
628
+ clip_denoised=clip_denoised,
629
+ denoised_fn=denoised_fn,
630
+ model_kwargs=model_kwargs,
631
+ )
632
+ if cond_fn is not None:
633
+ out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs)
634
+ # Usually our model outputs epsilon, but we re-derive it
635
+ # in case we used x_start or x_prev prediction.
636
+ eps = (
637
+ _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x.shape) * x
638
+ - out["pred_xstart"]
639
+ ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x.shape)
640
+ alpha_bar_next = _extract_into_tensor(self.alphas_cumprod_next, t, x.shape)
641
+
642
+ # Equation 12. reversed
643
+ mean_pred = out["pred_xstart"] * th.sqrt(alpha_bar_next) + th.sqrt(1 - alpha_bar_next) * eps
644
+
645
+ return {"sample": mean_pred, "pred_xstart": out["pred_xstart"]}
646
+
647
+ def ddim_sample_loop(
648
+ self,
649
+ model,
650
+ shape,
651
+ noise=None,
652
+ clip_denoised=True,
653
+ denoised_fn=None,
654
+ cond_fn=None,
655
+ model_kwargs=None,
656
+ device=None,
657
+ progress=False,
658
+ eta=0.0,
659
+ ):
660
+ """
661
+ Generate samples from the model using DDIM.
662
+ Same usage as p_sample_loop().
663
+ """
664
+ final = None
665
+ for sample in self.ddim_sample_loop_progressive(
666
+ model,
667
+ shape,
668
+ noise=noise,
669
+ clip_denoised=clip_denoised,
670
+ denoised_fn=denoised_fn,
671
+ cond_fn=cond_fn,
672
+ model_kwargs=model_kwargs,
673
+ device=device,
674
+ progress=progress,
675
+ eta=eta,
676
+ ):
677
+ final = sample
678
+ return final["sample"]
679
+
680
+ def ddim_sample_loop_progressive(
681
+ self,
682
+ model,
683
+ shape,
684
+ noise=None,
685
+ clip_denoised=True,
686
+ denoised_fn=None,
687
+ cond_fn=None,
688
+ model_kwargs=None,
689
+ device=None,
690
+ progress=False,
691
+ eta=0.0,
692
+ ):
693
+ """
694
+ Use DDIM to sample from the model and yield intermediate samples from
695
+ each timestep of DDIM.
696
+ Same usage as p_sample_loop_progressive().
697
+ """
698
+ if device is None:
699
+ device = next(model.parameters()).device
700
+ assert isinstance(shape, (tuple, list))
701
+ if noise is not None:
702
+ img = noise
703
+ else:
704
+ img = th.randn(*shape, device=device)
705
+ indices = list(range(self.num_timesteps))[::-1]
706
+
707
+ if progress:
708
+ # Lazy import so that we don't depend on tqdm.
709
+ from tqdm.auto import tqdm
710
+
711
+ indices = tqdm(indices)
712
+
713
+ for i in indices:
714
+ t = th.tensor([i] * shape[0], device=device)
715
+ with th.no_grad():
716
+ out = self.ddim_sample(
717
+ model,
718
+ img,
719
+ t,
720
+ clip_denoised=clip_denoised,
721
+ denoised_fn=denoised_fn,
722
+ cond_fn=cond_fn,
723
+ model_kwargs=model_kwargs,
724
+ eta=eta,
725
+ )
726
+ yield out
727
+ img = out["sample"]
728
+
729
+ def _vb_terms_bpd(
730
+ self, model, x_start, x_t, t, clip_denoised=True, model_kwargs=None
731
+ ):
732
+ """
733
+ Get a term for the variational lower-bound.
734
+ The resulting units are bits (rather than nats, as one might expect).
735
+ This allows for comparison to other papers.
736
+ :return: a dict with the following keys:
737
+ - 'output': a shape [N] tensor of NLLs or KLs.
738
+ - 'pred_xstart': the x_0 predictions.
739
+ """
740
+ true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance(
741
+ x_start=x_start, x_t=x_t, t=t
742
+ )
743
+ out = self.p_mean_variance(
744
+ model, x_t, t, clip_denoised=clip_denoised, model_kwargs=model_kwargs
745
+ )
746
+ kl = normal_kl(
747
+ true_mean, true_log_variance_clipped, out["mean"], out["log_variance"]
748
+ )
749
+ kl = mean_flat(kl) / np.log(2.0)
750
+
751
+ decoder_nll = -discretized_gaussian_log_likelihood(
752
+ x_start, means=out["mean"], log_scales=0.5 * out["log_variance"]
753
+ )
754
+ assert decoder_nll.shape == x_start.shape
755
+ decoder_nll = mean_flat(decoder_nll) / np.log(2.0)
756
+
757
+ # At the first timestep return the decoder NLL,
758
+ # otherwise return KL(q(x_{t-1}|x_t,x_0) || p(x_{t-1}|x_t))
759
+ output = th.where((t == 0), decoder_nll, kl)
760
+ return {"output": output, "pred_xstart": out["pred_xstart"]}
761
+
762
+ def sds_losses(self, model, x_start, t, model_kwargs=None, noise=None):
763
+ if model_kwargs is None:
764
+ model_kwargs = {}
765
+ else:
766
+ model_kwargs = {
767
+ k: th.cat([v, v], dim=0) for k, v in model_kwargs.items()
768
+ }
769
+
770
+ if noise is None:
771
+ noise = th.randn_like(x_start)
772
+ x_t = self.q_sample(x_start, t, noise=noise)
773
+ x_t = th.cat([x_t, x_t], dim=0)
774
+ t = th.cat([t, t], dim=0)
775
+ model_output = model(x_t, t, **model_kwargs)
776
+ assert model_output.shape[0] % 2 == 0
777
+
778
+ B, C = x_t.shape[:2]
779
+ model_output = th.split(model_output, B // 2, dim=0)
780
+
781
+ target = {
782
+ ModelMeanType.PREVIOUS_X: self.q_posterior_mean_variance(
783
+ x_start=x_start, x_t=x_t, t=t
784
+ )[0],
785
+ ModelMeanType.START_X: x_start,
786
+ ModelMeanType.EPSILON: noise,
787
+ }[self.model_mean_type]
788
+
789
+ assert self.model_mean_type == ModelMeanType.EPSILON
790
+ grad = (model_output - target)
791
+ t = (x_start - grad).detach()
792
+
793
+ return 0.5 * F.mse_loss(x_start, t, reduction="sum") / B * 2
794
+
795
+
796
+ def training_losses(self, model, x_start, t, model_kwargs=None, noise=None):
797
+ """
798
+ Compute training losses for a single timestep.
799
+ :param model: the model to evaluate loss on.
800
+ :param x_start: the [N x C x ...] tensor of inputs.
801
+ :param t: a batch of timestep indices.
802
+ :param model_kwargs: if not None, a dict of extra keyword arguments to
803
+ pass to the model. This can be used for conditioning.
804
+ :param noise: if specified, the specific Gaussian noise to try to remove.
805
+ :return: a dict with the key "loss" containing a tensor of shape [N].
806
+ Some mean or variance settings may also have other keys.
807
+ """
808
+ if model_kwargs is None:
809
+ model_kwargs = {}
810
+ if noise is None:
811
+ noise = th.randn_like(x_start)
812
+ x_t = self.q_sample(x_start, t, noise=noise)
813
+
814
+ terms = {}
815
+
816
+ if self.loss_type == LossType.KL or self.loss_type == LossType.RESCALED_KL:
817
+ terms["loss"] = self._vb_terms_bpd(
818
+ model=model,
819
+ x_start=x_start,
820
+ x_t=x_t,
821
+ t=t,
822
+ clip_denoised=False,
823
+ model_kwargs=model_kwargs,
824
+ )["output"]
825
+ if self.loss_type == LossType.RESCALED_KL:
826
+ terms["loss"] *= self.num_timesteps
827
+ elif self.loss_type == LossType.MSE or self.loss_type == LossType.RESCALED_MSE:
828
+ model_output = model(x_t, t, **model_kwargs)
829
+
830
+ if self.model_var_type in [
831
+ ModelVarType.LEARNED,
832
+ ModelVarType.LEARNED_RANGE,
833
+ ]:
834
+ B, C = x_t.shape[:2]
835
+ assert model_output.shape == (B, C * 2, *x_t.shape[2:]), (
836
+ model_output.shape,
837
+ (B, C * 2, *x_t.shape[2:]),
838
+ )
839
+ model_output, model_var_values = th.split(model_output, C, dim=1)
840
+ # Learn the variance using the variational bound, but don't let
841
+ # it affect our mean prediction.
842
+ frozen_out = th.cat([model_output.detach(), model_var_values], dim=1)
843
+ terms["vb"] = self._vb_terms_bpd(
844
+ model=lambda *args, r=frozen_out: r,
845
+ x_start=x_start,
846
+ x_t=x_t,
847
+ t=t,
848
+ clip_denoised=False,
849
+ )["output"]
850
+ if self.loss_type == LossType.RESCALED_MSE:
851
+ # Divide by 1000 for equivalence with initial implementation.
852
+ # Without a factor of 1/1000, the VB term hurts the MSE term.
853
+ terms["vb"] *= self.num_timesteps / 1000.0
854
+
855
+ target = {
856
+ ModelMeanType.PREVIOUS_X: self.q_posterior_mean_variance(
857
+ x_start=x_start, x_t=x_t, t=t
858
+ )[0],
859
+ ModelMeanType.START_X: x_start,
860
+ ModelMeanType.EPSILON: noise,
861
+ }[self.model_mean_type]
862
+ assert model_output.shape == target.shape == x_start.shape
863
+ terms["mse"] = mean_flat((target - model_output) ** 2)
864
+ if "vb" in terms:
865
+                 terms["loss"] = terms["mse"] + terms["vb"]
+             else:
+                 terms["loss"] = terms["mse"]
+         else:
+             raise NotImplementedError(self.loss_type)
+
+         return terms
+
+     def _prior_bpd(self, x_start):
+         """
+         Get the prior KL term for the variational lower-bound, measured in
+         bits-per-dim.
+         This term can't be optimized, as it only depends on the encoder.
+         :param x_start: the [N x C x ...] tensor of inputs.
+         :return: a batch of [N] KL values (in bits), one per batch element.
+         """
+         batch_size = x_start.shape[0]
+         t = th.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device)
+         qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t)
+         kl_prior = normal_kl(
+             mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0
+         )
+         return mean_flat(kl_prior) / np.log(2.0)
+
+     def calc_bpd_loop(self, model, x_start, clip_denoised=True, model_kwargs=None):
+         """
+         Compute the entire variational lower-bound, measured in bits-per-dim,
+         as well as other related quantities.
+         :param model: the model to evaluate loss on.
+         :param x_start: the [N x C x ...] tensor of inputs.
+         :param clip_denoised: if True, clip denoised samples.
+         :param model_kwargs: if not None, a dict of extra keyword arguments to
+             pass to the model. This can be used for conditioning.
+         :return: a dict containing the following keys:
+                  - total_bpd: the total variational lower-bound, per batch element.
+                  - prior_bpd: the prior term in the lower-bound.
+                  - vb: an [N x T] tensor of terms in the lower-bound.
+                  - xstart_mse: an [N x T] tensor of x_0 MSEs for each timestep.
+                  - mse: an [N x T] tensor of epsilon MSEs for each timestep.
+         """
+         device = x_start.device
+         batch_size = x_start.shape[0]
+
+         vb = []
+         xstart_mse = []
+         mse = []
+         for t in list(range(self.num_timesteps))[::-1]:
+             t_batch = th.tensor([t] * batch_size, device=device)
+             noise = th.randn_like(x_start)
+             x_t = self.q_sample(x_start=x_start, t=t_batch, noise=noise)
+             # Calculate VLB term at the current timestep
+             with th.no_grad():
+                 out = self._vb_terms_bpd(
+                     model,
+                     x_start=x_start,
+                     x_t=x_t,
+                     t=t_batch,
+                     clip_denoised=clip_denoised,
+                     model_kwargs=model_kwargs,
+                 )
+             vb.append(out["output"])
+             xstart_mse.append(mean_flat((out["pred_xstart"] - x_start) ** 2))
+             eps = self._predict_eps_from_xstart(x_t, t_batch, out["pred_xstart"])
+             mse.append(mean_flat((eps - noise) ** 2))
+
+         vb = th.stack(vb, dim=1)
+         xstart_mse = th.stack(xstart_mse, dim=1)
+         mse = th.stack(mse, dim=1)
+
+         prior_bpd = self._prior_bpd(x_start)
+         total_bpd = vb.sum(dim=1) + prior_bpd
+         return {
+             "total_bpd": total_bpd,
+             "prior_bpd": prior_bpd,
+             "vb": vb,
+             "xstart_mse": xstart_mse,
+             "mse": mse,
+         }
+
+
+ def _extract_into_tensor(arr, timesteps, broadcast_shape):
+     """
+     Extract values from a 1-D numpy array for a batch of indices.
+     :param arr: the 1-D numpy array.
+     :param timesteps: a tensor of indices into the array to extract.
+     :param broadcast_shape: a larger shape of K dimensions with the batch
+         dimension equal to the length of timesteps.
+     :return: a tensor broadcast to `broadcast_shape`, holding one extracted
+         value per batch element (constant over the non-batch dimensions).
+     """
+     res = th.from_numpy(arr).to(device=timesteps.device)[timesteps].float()
+     while len(res.shape) < len(broadcast_shape):
+         res = res[..., None]
+     return res + th.zeros(broadcast_shape, device=timesteps.device)
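A minimal usage sketch for `_extract_into_tensor` (illustration only, not part of the commit; it assumes the package is importable as `diffusion`): the helper gathers one schedule coefficient per batch element and broadcasts it over the remaining dimensions.

    import numpy as np
    import torch as th

    from diffusion.gaussian_diffusion import _extract_into_tensor

    betas = np.linspace(1e-4, 0.02, 1000)           # toy 1-D schedule, one value per timestep
    t = th.tensor([0, 499, 999])                    # a batch of three timesteps
    x = th.randn(3, 4, 32, 32)                      # matching batch of latents
    coef = _extract_into_tensor(betas, t, x.shape)  # broadcast to x.shape, constant per sample
    assert coef.shape == x.shape
    scaled = coef * x                               # per-sample scaling, as used in q_sample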
diffusion/respace.py ADDED
@@ -0,0 +1,129 @@
+ # Modified from OpenAI's diffusion repos
+ # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
+ # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
+ # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
+
+ import numpy as np
+ import torch as th
+
+ from .gaussian_diffusion import GaussianDiffusion
+
+
+ def space_timesteps(num_timesteps, section_counts):
+     """
+     Create a list of timesteps to use from an original diffusion process,
+     given the number of timesteps we want to take from equally-sized portions
+     of the original process.
+     For example, if there are 300 timesteps and the section counts are [10,15,20]
+     then the first 100 timesteps are strided to be 10 timesteps, the second 100
+     are strided to be 15 timesteps, and the final 100 are strided to be 20.
+     If the stride is a string starting with "ddim", then the fixed striding
+     from the DDIM paper is used, and only one section is allowed.
+     :param num_timesteps: the number of diffusion steps in the original
+                           process to divide up.
+     :param section_counts: either a list of numbers, or a string containing
+                            comma-separated numbers, indicating the step count
+                            per section. As a special case, use "ddimN" where N
+                            is a number of steps to use the striding from the
+                            DDIM paper.
+     :return: a set of diffusion steps from the original process to use.
+     """
+     if isinstance(section_counts, str):
+         if section_counts.startswith("ddim"):
+             desired_count = int(section_counts[len("ddim") :])
+             for i in range(1, num_timesteps):
+                 if len(range(0, num_timesteps, i)) == desired_count:
+                     return set(range(0, num_timesteps, i))
+             raise ValueError(
+                 f"cannot create exactly {desired_count} steps with an integer stride"
+             )
+         section_counts = [int(x) for x in section_counts.split(",")]
+     size_per = num_timesteps // len(section_counts)
+     extra = num_timesteps % len(section_counts)
+     start_idx = 0
+     all_steps = []
+     for i, section_count in enumerate(section_counts):
+         size = size_per + (1 if i < extra else 0)
+         if size < section_count:
+             raise ValueError(
+                 f"cannot divide section of {size} steps into {section_count}"
+             )
+         if section_count <= 1:
+             frac_stride = 1
+         else:
+             frac_stride = (size - 1) / (section_count - 1)
+         cur_idx = 0.0
+         taken_steps = []
+         for _ in range(section_count):
+             taken_steps.append(start_idx + round(cur_idx))
+             cur_idx += frac_stride
+         all_steps += taken_steps
+         start_idx += size
+     return set(all_steps)
+
+
+ class SpacedDiffusion(GaussianDiffusion):
+     """
+     A diffusion process which can skip steps in a base diffusion process.
+     :param use_timesteps: a collection (sequence or set) of timesteps from the
+                           original diffusion process to retain.
+     :param kwargs: the kwargs to create the base diffusion process.
+     """
+
+     def __init__(self, use_timesteps, **kwargs):
+         self.use_timesteps = set(use_timesteps)
+         self.timestep_map = []
+         self.original_num_steps = len(kwargs["betas"])
+
+         base_diffusion = GaussianDiffusion(**kwargs)  # pylint: disable=missing-kwoa
+         last_alpha_cumprod = 1.0
+         new_betas = []
+         for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod):
+             if i in self.use_timesteps:
+                 new_betas.append(1 - alpha_cumprod / last_alpha_cumprod)
+                 last_alpha_cumprod = alpha_cumprod
+                 self.timestep_map.append(i)
+         kwargs["betas"] = np.array(new_betas)
+         super().__init__(**kwargs)
+
+     def p_mean_variance(
+         self, model, *args, **kwargs
+     ):  # pylint: disable=signature-differs
+         return super().p_mean_variance(self._wrap_model(model), *args, **kwargs)
+
+     def training_losses(
+         self, model, *args, **kwargs
+     ):  # pylint: disable=signature-differs
+         return super().training_losses(self._wrap_model(model), *args, **kwargs)
+
+     def condition_mean(self, cond_fn, *args, **kwargs):
+         return super().condition_mean(self._wrap_model(cond_fn), *args, **kwargs)
+
+     def condition_score(self, cond_fn, *args, **kwargs):
+         return super().condition_score(self._wrap_model(cond_fn), *args, **kwargs)
+
+     def _wrap_model(self, model):
+         if isinstance(model, _WrappedModel):
+             return model
+         return _WrappedModel(
+             model, self.timestep_map, self.original_num_steps
+         )
+
+     def _scale_timesteps(self, t):
+         # Scaling is done by the wrapped model.
+         return t
+
+
+ class _WrappedModel:
+     def __init__(self, model, timestep_map, original_num_steps):
+         self.model = model
+         self.timestep_map = timestep_map
+         # self.rescale_timesteps = rescale_timesteps
+         self.original_num_steps = original_num_steps
+
+     def __call__(self, x, ts, **kwargs):
+         map_tensor = th.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype)
+         new_ts = map_tensor[ts]
+         # if self.rescale_timesteps:
+         #     new_ts = new_ts.float() * (1000.0 / self.original_num_steps)
+         return self.model(x, new_ts, **kwargs)
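A small sketch of how the respacing above is typically used (illustration only, not part of the commit; it assumes the package is importable as `diffusion`):

    from diffusion.respace import space_timesteps

    # 50 evenly strided DDIM steps out of a 1000-step process
    assert len(space_timesteps(1000, "ddim50")) == 50

    # 10 + 15 + 20 = 45 steps drawn from three equal 100-step sections
    assert len(space_timesteps(300, [10, 15, 20])) == 45

`SpacedDiffusion` is then built with `use_timesteps=space_timesteps(...)` plus the usual `GaussianDiffusion` kwargs; at call time `_WrappedModel` maps the compressed step indices back to the original timesteps through `timestep_map`.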
diffusion/timestep_sampler.py ADDED
@@ -0,0 +1,150 @@
+ # Modified from OpenAI's diffusion repos
+ # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
+ # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
+ # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
+
+ from abc import ABC, abstractmethod
+
+ import numpy as np
+ import torch as th
+ import torch.distributed as dist
+
+
+ def create_named_schedule_sampler(name, diffusion):
+     """
+     Create a ScheduleSampler from a library of pre-defined samplers.
+     :param name: the name of the sampler.
+     :param diffusion: the diffusion object to sample for.
+     """
+     if name == "uniform":
+         return UniformSampler(diffusion)
+     elif name == "loss-second-moment":
+         return LossSecondMomentResampler(diffusion)
+     else:
+         raise NotImplementedError(f"unknown schedule sampler: {name}")
+
+
+ class ScheduleSampler(ABC):
+     """
+     A distribution over timesteps in the diffusion process, intended to reduce
+     variance of the objective.
+     By default, samplers perform unbiased importance sampling, in which the
+     objective's mean is unchanged.
+     However, subclasses may override sample() to change how the resampled
+     terms are reweighted, allowing for actual changes in the objective.
+     """
+
+     @abstractmethod
+     def weights(self):
+         """
+         Get a numpy array of weights, one per diffusion step.
+         The weights needn't be normalized, but must be positive.
+         """
+
+     def sample(self, batch_size, device):
+         """
+         Importance-sample timesteps for a batch.
+         :param batch_size: the number of timesteps to sample.
+         :param device: the torch device to put the sampled tensors on.
+         :return: a tuple (timesteps, weights):
+                  - timesteps: a tensor of timestep indices.
+                  - weights: a tensor of weights to scale the resulting losses.
+         """
+         w = self.weights()
+         p = w / np.sum(w)
+         indices_np = np.random.choice(len(p), size=(batch_size,), p=p)
+         indices = th.from_numpy(indices_np).long().to(device)
+         weights_np = 1 / (len(p) * p[indices_np])
+         weights = th.from_numpy(weights_np).float().to(device)
+         return indices, weights
+
+
+ class UniformSampler(ScheduleSampler):
+     def __init__(self, diffusion):
+         self.diffusion = diffusion
+         self._weights = np.ones([diffusion.num_timesteps])
+
+     def weights(self):
+         return self._weights
+
+
+ class LossAwareSampler(ScheduleSampler):
+     def update_with_local_losses(self, local_ts, local_losses):
+         """
+         Update the reweighting using losses from a model.
+         Call this method from each rank with a batch of timesteps and the
+         corresponding losses for each of those timesteps.
+         This method will perform synchronization to make sure all of the ranks
+         maintain the exact same reweighting.
+         :param local_ts: an integer Tensor of timesteps.
+         :param local_losses: a 1D Tensor of losses.
+         """
+         batch_sizes = [
+             th.tensor([0], dtype=th.int32, device=local_ts.device)
+             for _ in range(dist.get_world_size())
+         ]
+         dist.all_gather(
+             batch_sizes,
+             th.tensor([len(local_ts)], dtype=th.int32, device=local_ts.device),
+         )
+
+         # Pad all_gather batches to be the maximum batch size.
+         batch_sizes = [x.item() for x in batch_sizes]
+         max_bs = max(batch_sizes)
+
+         timestep_batches = [th.zeros(max_bs).to(local_ts) for bs in batch_sizes]
+         loss_batches = [th.zeros(max_bs).to(local_losses) for bs in batch_sizes]
+         dist.all_gather(timestep_batches, local_ts)
+         dist.all_gather(loss_batches, local_losses)
+         timesteps = [
+             x.item() for y, bs in zip(timestep_batches, batch_sizes) for x in y[:bs]
+         ]
+         losses = [x.item() for y, bs in zip(loss_batches, batch_sizes) for x in y[:bs]]
+         self.update_with_all_losses(timesteps, losses)
+
+     @abstractmethod
+     def update_with_all_losses(self, ts, losses):
+         """
+         Update the reweighting using losses from a model.
+         Sub-classes should override this method to update the reweighting
+         using losses from the model.
+         This method directly updates the reweighting without synchronizing
+         between workers. It is called by update_with_local_losses from all
+         ranks with identical arguments. Thus, it should have deterministic
+         behavior to maintain state across workers.
+         :param ts: a list of int timesteps.
+         :param losses: a list of float losses, one per timestep.
+         """
+
+
+ class LossSecondMomentResampler(LossAwareSampler):
+     def __init__(self, diffusion, history_per_term=10, uniform_prob=0.001):
+         self.diffusion = diffusion
+         self.history_per_term = history_per_term
+         self.uniform_prob = uniform_prob
+         self._loss_history = np.zeros(
+             [diffusion.num_timesteps, history_per_term], dtype=np.float64
+         )
+         self._loss_counts = np.zeros([diffusion.num_timesteps], dtype=np.int64)
+
+     def weights(self):
+         if not self._warmed_up():
+             return np.ones([self.diffusion.num_timesteps], dtype=np.float64)
+         weights = np.sqrt(np.mean(self._loss_history ** 2, axis=-1))
+         weights /= np.sum(weights)
+         weights *= 1 - self.uniform_prob
+         weights += self.uniform_prob / len(weights)
+         return weights
+
+     def update_with_all_losses(self, ts, losses):
+         for t, loss in zip(ts, losses):
+             if self._loss_counts[t] == self.history_per_term:
+                 # Shift out the oldest loss term.
+                 self._loss_history[t, :-1] = self._loss_history[t, 1:]
+                 self._loss_history[t, -1] = loss
+             else:
+                 self._loss_history[t, self._loss_counts[t]] = loss
+                 self._loss_counts[t] += 1
+
+     def _warmed_up(self):
+         return (self._loss_counts == self.history_per_term).all()
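A short sketch of a sampler in a training step (illustration only, not part of the commit; `diffusion`, `model`, `x_start`, and `model_kwargs` are placeholders for a diffusion object exposing `num_timesteps` and `training_losses`, a denoiser, a data batch, and its conditioning):

    from diffusion.timestep_sampler import UniformSampler

    sampler = UniformSampler(diffusion)
    t, weights = sampler.sample(batch_size=x_start.shape[0], device=x_start.device)
    losses = diffusion.training_losses(model, x_start, t, model_kwargs=model_kwargs)
    loss = (weights * losses["loss"]).mean()  # the weights keep the objective unbiased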
example_images/000.png ADDED
example_images/001.png ADDED
example_images/002.png ADDED
example_images/003.png ADDED
example_images/004.png ADDED
example_images/005.png ADDED
example_images/006.png ADDED
example_images/007.png ADDED
example_images/008.png ADDED
example_images/009.png ADDED
example_images/010.png ADDED
example_images/011.png ADDED
example_images/012.png ADDED
example_images/013.png ADDED
example_images/014.png ADDED
example_images/015.png ADDED
example_images/016.png ADDED
example_images/018.png ADDED
example_images/019.png ADDED
example_images/020.png ADDED
example_images/021.png ADDED
example_images/022.png ADDED
example_images/023.png ADDED
example_images/024.png ADDED
model.py ADDED
@@ -0,0 +1,1992 @@
1
+ from dataclasses import dataclass
2
+ from typing import Any, Dict, List, Optional, Tuple, Union
3
+ import os
4
+ import math
5
+ import json
6
+ from glob import glob
7
+ from functools import partial
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.utils.checkpoint
12
+ import torch.nn.functional as F
13
+ import numpy as np
14
+
15
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
16
+ from diffusers.loaders import UNet2DConditionLoadersMixin
17
+ from diffusers.utils import BaseOutput, logging, is_torch_version
18
+ from diffusers.models.modeling_utils import ModelMixin
19
+ from diffusers.models.unet_2d_blocks import (
20
+ CrossAttnDownBlock2D,
21
+ CrossAttnUpBlock2D,
22
+ DownBlock2D,
23
+ UNetMidBlock2DCrossAttn,
24
+ UNetMidBlock2DSimpleCrossAttn,
25
+ UpBlock2D,
26
+ get_down_block as gdb,
27
+ get_up_block as gub,
28
+ )
29
+ from diffusers.models.embeddings import (
30
+ GaussianFourierProjection,
31
+ ImageHintTimeEmbedding,
32
+ ImageProjection,
33
+ ImageTimeEmbedding,
34
+ TextImageProjection,
35
+ TextImageTimeEmbedding,
36
+ TextTimeEmbedding,
37
+ TimestepEmbedding,
38
+ Timesteps,
39
+ )
40
+ from diffusers.models.activations import get_activation
41
+ from diffusers.models.attention_processor import AttentionProcessor, AttnProcessor, Attention, AttnAddedKVProcessor, AttnAddedKVProcessor2_0
42
+ from diffusers.models.resnet import Downsample2D, FirDownsample2D, FirUpsample2D, KDownsample2D, KUpsample2D, ResnetBlock2D, Upsample2D
43
+ from diffusers.models.transformer_2d import Transformer2DModel
44
+ from diffusers.models.dual_transformer_2d import DualTransformer2DModel
45
+
46
+
47
+ class CrossAttnDownBlock2DWithFlow(nn.Module):
48
+ def __init__(
49
+ self,
50
+ in_channels: int,
51
+ out_channels: int,
52
+ temb_channels: int,
53
+ flow_channels: int, # Added
54
+ dropout: float = 0.0,
55
+ num_layers: int = 1,
56
+ transformer_layers_per_block: int = 1,
57
+ resnet_eps: float = 1e-6,
58
+ resnet_time_scale_shift: str = "default",
59
+ resnet_act_fn: str = "swish",
60
+ resnet_groups: int = 32,
61
+ resnet_pre_norm: bool = True,
62
+ num_attention_heads=1,
63
+ cross_attention_dim=1280,
64
+ output_scale_factor=1.0,
65
+ downsample_padding=1,
66
+ add_downsample=True,
67
+ dual_cross_attention=False,
68
+ use_linear_projection=False,
69
+ only_cross_attention=False,
70
+ upcast_attention=False,
71
+ ):
72
+ super().__init__()
73
+ resnets = []
74
+ attentions = []
75
+ flow_convs = []
76
+
77
+ self.has_cross_attention = True
78
+ self.num_attention_heads = num_attention_heads
79
+
80
+ for i in range(num_layers):
81
+ in_channels = in_channels if i == 0 else out_channels
82
+ resnets.append(
83
+ ResnetBlock2D(
84
+ in_channels=in_channels,
85
+ out_channels=out_channels,
86
+ temb_channels=temb_channels,
87
+ eps=resnet_eps,
88
+ groups=resnet_groups,
89
+ dropout=dropout,
90
+ time_embedding_norm=resnet_time_scale_shift,
91
+ non_linearity=resnet_act_fn,
92
+ output_scale_factor=output_scale_factor,
93
+ pre_norm=resnet_pre_norm,
94
+ )
95
+ )
96
+ if not dual_cross_attention:
97
+ attentions.append(
98
+ Transformer2DModel(
99
+ num_attention_heads,
100
+ out_channels // num_attention_heads,
101
+ in_channels=out_channels,
102
+ num_layers=transformer_layers_per_block,
103
+ cross_attention_dim=cross_attention_dim,
104
+ norm_num_groups=resnet_groups,
105
+ use_linear_projection=use_linear_projection,
106
+ only_cross_attention=only_cross_attention,
107
+ upcast_attention=upcast_attention,
108
+ )
109
+ )
110
+ else:
111
+ attentions.append(
112
+ DualTransformer2DModel(
113
+ num_attention_heads,
114
+ out_channels // num_attention_heads,
115
+ in_channels=out_channels,
116
+ num_layers=1,
117
+ cross_attention_dim=cross_attention_dim,
118
+ norm_num_groups=resnet_groups,
119
+ )
120
+ )
121
+ flow_convs.append(
122
+ nn.Conv2d(
123
+ flow_channels, out_channels, kernel_size=3, padding=1, bias=False,
124
+ )
125
+ )
126
+ self.attentions = nn.ModuleList(attentions)
127
+ self.resnets = nn.ModuleList(resnets)
128
+ self.flow_convs = nn.ModuleList(flow_convs)
129
+
130
+ if add_downsample:
131
+ self.downsamplers = nn.ModuleList(
132
+ [
133
+ Downsample2D(
134
+ out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
135
+ )
136
+ ]
137
+ )
138
+ else:
139
+ self.downsamplers = None
140
+
141
+ self.gradient_checkpointing = False
142
+
143
+ def forward(
144
+ self,
145
+ hidden_states: torch.FloatTensor,
146
+ temb: Optional[torch.FloatTensor] = None,
147
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
148
+ attention_mask: Optional[torch.FloatTensor] = None,
149
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
150
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
151
+ additional_residuals=None,
152
+ flow: Optional[torch.FloatTensor] = None, # Added
153
+ ):
154
+ output_states = ()
155
+
156
+ blocks = list(zip(self.resnets, self.attentions, self.flow_convs))
157
+
158
+ for i, (resnet, attn, flow_conv) in enumerate(blocks):
159
+ if self.training and self.gradient_checkpointing:
160
+
161
+ def create_custom_forward(module, return_dict=None):
162
+ def custom_forward(*inputs):
163
+ if return_dict is not None:
164
+ return module(*inputs, return_dict=return_dict)
165
+ else:
166
+ return module(*inputs)
167
+
168
+ return custom_forward
169
+
170
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
171
+ hidden_states = torch.utils.checkpoint.checkpoint(
172
+ create_custom_forward(resnet),
173
+ hidden_states,
174
+ temb,
175
+ **ckpt_kwargs,
176
+ )
177
+ hidden_states = torch.utils.checkpoint.checkpoint(
178
+ create_custom_forward(attn, return_dict=False),
179
+ hidden_states,
180
+ encoder_hidden_states,
181
+ None, # timestep
182
+ None, # class_labels
183
+ cross_attention_kwargs,
184
+ attention_mask,
185
+ encoder_attention_mask,
186
+ **ckpt_kwargs,
187
+ )[0]
188
+ else:
189
+ hidden_states = resnet(hidden_states, temb)
190
+ if flow is not None:
191
+ hidden_states = hidden_states + flow_conv(flow)
192
+ hidden_states = attn(
193
+ hidden_states,
194
+ encoder_hidden_states=encoder_hidden_states,
195
+ cross_attention_kwargs=cross_attention_kwargs,
196
+ attention_mask=attention_mask,
197
+ encoder_attention_mask=encoder_attention_mask,
198
+ return_dict=False,
199
+ )[0]
200
+
201
+ # apply additional residuals to the output of the last pair of resnet and attention blocks
202
+ if i == len(blocks) - 1 and additional_residuals is not None:
203
+ hidden_states = hidden_states + additional_residuals
204
+
205
+ output_states = output_states + (hidden_states,)
206
+
207
+ if self.downsamplers is not None:
208
+ for downsampler in self.downsamplers:
209
+ hidden_states = downsampler(hidden_states)
210
+
211
+ output_states = output_states + (hidden_states,)
212
+
213
+ return hidden_states, output_states
214
+
215
+
216
+ class UNetMidBlock2DCrossAttnWithFlow(nn.Module):
217
+ def __init__(
218
+ self,
219
+ in_channels: int,
220
+ temb_channels: int,
221
+ flow_channels: int, # Added
222
+ dropout: float = 0.0,
223
+ num_layers: int = 1,
224
+ transformer_layers_per_block: int = 1,
225
+ resnet_eps: float = 1e-6,
226
+ resnet_time_scale_shift: str = "default",
227
+ resnet_act_fn: str = "swish",
228
+ resnet_groups: int = 32,
229
+ resnet_pre_norm: bool = True,
230
+ num_attention_heads=1,
231
+ output_scale_factor=1.0,
232
+ cross_attention_dim=1280,
233
+ dual_cross_attention=False,
234
+ use_linear_projection=False,
235
+ upcast_attention=False,
236
+ ):
237
+ super().__init__()
238
+
239
+ self.has_cross_attention = True
240
+ self.num_attention_heads = num_attention_heads
241
+ resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
242
+
243
+ # there is always at least one resnet
244
+ resnets = [
245
+ ResnetBlock2D(
246
+ in_channels=in_channels,
247
+ out_channels=in_channels,
248
+ temb_channels=temb_channels,
249
+ eps=resnet_eps,
250
+ groups=resnet_groups,
251
+ dropout=dropout,
252
+ time_embedding_norm=resnet_time_scale_shift,
253
+ non_linearity=resnet_act_fn,
254
+ output_scale_factor=output_scale_factor,
255
+ pre_norm=resnet_pre_norm,
256
+ )
257
+ ]
258
+ flow_convs = [
259
+ nn.Conv2d(
260
+ flow_channels, in_channels, kernel_size=3, padding=1, bias=False,
261
+ )
262
+ ]
263
+ attentions = []
264
+
265
+ for _ in range(num_layers):
266
+ if not dual_cross_attention:
267
+ attentions.append(
268
+ Transformer2DModel(
269
+ num_attention_heads,
270
+ in_channels // num_attention_heads,
271
+ in_channels=in_channels,
272
+ num_layers=transformer_layers_per_block,
273
+ cross_attention_dim=cross_attention_dim,
274
+ norm_num_groups=resnet_groups,
275
+ use_linear_projection=use_linear_projection,
276
+ upcast_attention=upcast_attention,
277
+ )
278
+ )
279
+ else:
280
+ attentions.append(
281
+ DualTransformer2DModel(
282
+ num_attention_heads,
283
+ in_channels // num_attention_heads,
284
+ in_channels=in_channels,
285
+ num_layers=1,
286
+ cross_attention_dim=cross_attention_dim,
287
+ norm_num_groups=resnet_groups,
288
+ )
289
+ )
290
+ resnets.append(
291
+ ResnetBlock2D(
292
+ in_channels=in_channels,
293
+ out_channels=in_channels,
294
+ temb_channels=temb_channels,
295
+ eps=resnet_eps,
296
+ groups=resnet_groups,
297
+ dropout=dropout,
298
+ time_embedding_norm=resnet_time_scale_shift,
299
+ non_linearity=resnet_act_fn,
300
+ output_scale_factor=output_scale_factor,
301
+ pre_norm=resnet_pre_norm,
302
+ )
303
+ )
304
+ flow_convs.append(
305
+ nn.Conv2d(
306
+ flow_channels, in_channels, kernel_size=3, padding=1, bias=False,
307
+ )
308
+ )
309
+
310
+ self.attentions = nn.ModuleList(attentions)
311
+ self.resnets = nn.ModuleList(resnets)
312
+ self.flow_convs = nn.ModuleList(flow_convs)
313
+
314
+ def forward(
315
+ self,
316
+ hidden_states: torch.FloatTensor,
317
+ temb: Optional[torch.FloatTensor] = None,
318
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
319
+ attention_mask: Optional[torch.FloatTensor] = None,
320
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
321
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
322
+ flow: Optional[torch.FloatTensor] = None, # Added
323
+ ) -> torch.FloatTensor:
324
+ hidden_states = self.resnets[0](hidden_states, temb)
325
+ hidden_states = hidden_states + self.flow_convs[0](flow)
326
+ for attn, resnet, flow_conv in zip(self.attentions, self.resnets[1:], self.flow_convs[1:]):
327
+ hidden_states = attn(
328
+ hidden_states,
329
+ encoder_hidden_states=encoder_hidden_states,
330
+ cross_attention_kwargs=cross_attention_kwargs,
331
+ attention_mask=attention_mask,
332
+ encoder_attention_mask=encoder_attention_mask,
333
+ return_dict=False,
334
+ )[0]
335
+ hidden_states = resnet(hidden_states, temb)
336
+ hidden_states = hidden_states + flow_conv(flow)
337
+
338
+ return hidden_states
339
+
340
+
341
+ class CrossAttnUpBlock2DWithFlow(nn.Module):
342
+ def __init__(
343
+ self,
344
+ in_channels: int,
345
+ out_channels: int,
346
+ prev_output_channel: int,
347
+ temb_channels: int,
348
+ flow_channels: int, # Added
349
+ dropout: float = 0.0,
350
+ num_layers: int = 1,
351
+ transformer_layers_per_block: int = 1,
352
+ resnet_eps: float = 1e-6,
353
+ resnet_time_scale_shift: str = "default",
354
+ resnet_act_fn: str = "swish",
355
+ resnet_groups: int = 32,
356
+ resnet_pre_norm: bool = True,
357
+ num_attention_heads=1,
358
+ cross_attention_dim=1280,
359
+ output_scale_factor=1.0,
360
+ add_upsample=True,
361
+ dual_cross_attention=False,
362
+ use_linear_projection=False,
363
+ only_cross_attention=False,
364
+ upcast_attention=False,
365
+ ):
366
+ super().__init__()
367
+ resnets = []
368
+ attentions = []
369
+ flow_convs = []
370
+
371
+ self.has_cross_attention = True
372
+ self.num_attention_heads = num_attention_heads
373
+
374
+ for i in range(num_layers):
375
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
376
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
377
+
378
+ resnets.append(
379
+ ResnetBlock2D(
380
+ in_channels=resnet_in_channels + res_skip_channels,
381
+ out_channels=out_channels,
382
+ temb_channels=temb_channels,
383
+ eps=resnet_eps,
384
+ groups=resnet_groups,
385
+ dropout=dropout,
386
+ time_embedding_norm=resnet_time_scale_shift,
387
+ non_linearity=resnet_act_fn,
388
+ output_scale_factor=output_scale_factor,
389
+ pre_norm=resnet_pre_norm,
390
+ )
391
+ )
392
+ if not dual_cross_attention:
393
+ attentions.append(
394
+ Transformer2DModel(
395
+ num_attention_heads,
396
+ out_channels // num_attention_heads,
397
+ in_channels=out_channels,
398
+ num_layers=transformer_layers_per_block,
399
+ cross_attention_dim=cross_attention_dim,
400
+ norm_num_groups=resnet_groups,
401
+ use_linear_projection=use_linear_projection,
402
+ only_cross_attention=only_cross_attention,
403
+ upcast_attention=upcast_attention,
404
+ )
405
+ )
406
+ else:
407
+ attentions.append(
408
+ DualTransformer2DModel(
409
+ num_attention_heads,
410
+ out_channels // num_attention_heads,
411
+ in_channels=out_channels,
412
+ num_layers=1,
413
+ cross_attention_dim=cross_attention_dim,
414
+ norm_num_groups=resnet_groups,
415
+ )
416
+ )
417
+ flow_convs.append(
418
+ nn.Conv2d(
419
+ flow_channels, out_channels, kernel_size=3, padding=1, bias=False,
420
+ )
421
+ )
422
+ self.attentions = nn.ModuleList(attentions)
423
+ self.resnets = nn.ModuleList(resnets)
424
+ self.flow_convs = nn.ModuleList(flow_convs)
425
+
426
+ if add_upsample:
427
+ self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
428
+ else:
429
+ self.upsamplers = None
430
+
431
+ self.gradient_checkpointing = False
432
+
433
+ def forward(
434
+ self,
435
+ hidden_states: torch.FloatTensor,
436
+ res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
437
+ temb: Optional[torch.FloatTensor] = None,
438
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
439
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
440
+ upsample_size: Optional[int] = None,
441
+ attention_mask: Optional[torch.FloatTensor] = None,
442
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
443
+ flow: Optional[torch.FloatTensor] = None, # Added
444
+ ):
445
+ for resnet, attn, flow_conv in zip(self.resnets, self.attentions, self.flow_convs):
446
+ # pop res hidden states
447
+ res_hidden_states = res_hidden_states_tuple[-1]
448
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
449
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
450
+
451
+ if self.training and self.gradient_checkpointing:
452
+
453
+ def create_custom_forward(module, return_dict=None):
454
+ def custom_forward(*inputs):
455
+ if return_dict is not None:
456
+ return module(*inputs, return_dict=return_dict)
457
+ else:
458
+ return module(*inputs)
459
+
460
+ return custom_forward
461
+
462
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
463
+ hidden_states = torch.utils.checkpoint.checkpoint(
464
+ create_custom_forward(resnet),
465
+ hidden_states,
466
+ temb,
467
+ **ckpt_kwargs,
468
+ )
469
+ hidden_states = torch.utils.checkpoint.checkpoint(
470
+ create_custom_forward(attn, return_dict=False),
471
+ hidden_states,
472
+ encoder_hidden_states,
473
+ None, # timestep
474
+ None, # class_labels
475
+ cross_attention_kwargs,
476
+ attention_mask,
477
+ encoder_attention_mask,
478
+ **ckpt_kwargs,
479
+ )[0]
480
+ else:
481
+ hidden_states = resnet(hidden_states, temb)
482
+ hidden_states = hidden_states + flow_conv(flow)
483
+ hidden_states = attn(
484
+ hidden_states,
485
+ encoder_hidden_states=encoder_hidden_states,
486
+ cross_attention_kwargs=cross_attention_kwargs,
487
+ attention_mask=attention_mask,
488
+ encoder_attention_mask=encoder_attention_mask,
489
+ return_dict=False,
490
+ )[0]
491
+
492
+ if self.upsamplers is not None:
493
+ for upsampler in self.upsamplers:
494
+ hidden_states = upsampler(hidden_states, upsample_size)
495
+
496
+ return hidden_states
497
+
498
+
499
+
500
+ def get_sin_cos_pos_embed(embed_dim: int, x: torch.Tensor):
501
+ bsz, _ = x.shape
502
+ x = x.reshape(-1)[:, None]
503
+ div_term = torch.exp(torch.arange(0, embed_dim, 2) * (-math.log(10000.0) / embed_dim)).to(x.device)
504
+ pos = x * div_term
505
+ pos = torch.cat([torch.sin(pos), torch.cos(pos)], dim=-1).reshape(bsz, -1)
506
+ return pos
507
+
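# Illustrative sketch (comments only, not part of the committed file): `get_sin_cos_pos_embed`
# flattens a [batch, n] tensor of scalar coordinates and returns a [batch, n * embed_dim]
# sinusoidal encoding, with embed_dim/2 sine and embed_dim/2 cosine frequencies per scalar:
#
#   coords = torch.rand(2, 40)                # e.g. 10 drags x (x_start, y_start, x_end, y_end)
#   emb = get_sin_cos_pos_embed(256, coords)  # -> shape [2, 40 * 256]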
508
+
509
+ def get_down_block(
510
+ with_concatenated_flow: bool = False,
511
+ *args,
512
+ **kwargs,
513
+ ):
514
+ if not with_concatenated_flow or args[0] == "DownBlock2D":
515
+ kwargs.pop("flow_channels", None)
516
+ return gdb(*args, **kwargs)
517
+ elif args[0] == "CrossAttnDownBlock2D":
518
+ kwargs.pop("downsample_type", None)
519
+ kwargs.pop("attention_head_dim", None)
520
+ kwargs.pop("resnet_skip_time_act", None)
521
+ kwargs.pop("resnet_out_scale_factor", None)
522
+ kwargs.pop("cross_attention_norm", None)
523
+ return CrossAttnDownBlock2DWithFlow(*args[1:], **kwargs)
524
+ else:
525
+ raise ValueError(f"Unknown down block type: {args[0]}")
526
+
527
+
528
+ def get_up_block(
529
+ with_concatenated_flow: bool = False,
530
+ *args,
531
+ **kwargs,
532
+ ):
533
+ if not with_concatenated_flow or args[0] == "UpBlock2D":
534
+ kwargs.pop("flow_channels", None)
535
+ return gub(*args, **kwargs)
536
+ elif args[0] == "CrossAttnUpBlock2D":
537
+ kwargs.pop("upsample_type", None)
538
+ kwargs.pop("attention_head_dim", None)
539
+ kwargs.pop("resnet_skip_time_act", None)
540
+ kwargs.pop("resnet_out_scale_factor", None)
541
+ kwargs.pop("cross_attention_norm", None)
542
+ return CrossAttnUpBlock2DWithFlow(*args[1:], **kwargs)
543
+ else:
544
+ raise ValueError(f"Unknown up block type: {args[0]}")
545
+
546
+
547
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
548
+
549
+
550
+ def avg_pool_nd(dims, *args, **kwargs):
551
+ """
552
+ Create a 1D, 2D, or 3D average pooling module.
553
+ """
554
+ if dims == 1:
555
+ return nn.AvgPool1d(*args, **kwargs)
556
+ elif dims == 2:
557
+ return nn.AvgPool2d(*args, **kwargs)
558
+ elif dims == 3:
559
+ return nn.AvgPool3d(*args, **kwargs)
560
+ raise ValueError(f"unsupported dimensions: {dims}")
561
+
562
+
563
+ def conv_nd(dims, *args, **kwargs):
564
+ """
565
+ Create a 1D, 2D, or 3D convolution module.
566
+ """
567
+ if dims == 1:
568
+ return nn.Conv1d(*args, **kwargs)
569
+ elif dims == 2:
570
+ return nn.Conv2d(*args, **kwargs)
571
+ elif dims == 3:
572
+ return nn.Conv3d(*args, **kwargs)
573
+ raise ValueError(f"unsupported dimensions: {dims}")
574
+
575
+
576
+ class Downsample(nn.Module):
577
+ """
578
+ A downsampling layer with an optional convolution.
579
+ :param channels: channels in the inputs and outputs.
580
+ :param use_conv: a bool determining if a convolution is applied.
581
+ :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
582
+ downsampling occurs in the inner-two dimensions.
583
+ """
584
+
585
+ def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
586
+ super().__init__()
587
+ self.channels = channels
588
+ self.out_channels = out_channels or channels
589
+ self.use_conv = use_conv
590
+ self.dims = dims
591
+ stride = 2 if dims != 3 else (1, 2, 2)
592
+ if use_conv:
593
+ self.op = conv_nd(
594
+ dims, self.channels, self.out_channels, 3, stride=stride, padding=padding
595
+ )
596
+ else:
597
+ assert self.channels == self.out_channels
598
+ self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
599
+
600
+ def forward(self, x):
601
+ assert x.shape[1] == self.channels
602
+ return self.op(x)
603
+
604
+
605
+ class ResnetBlock(nn.Module):
606
+ def __init__(self, in_c, out_c, down, ksize=3, sk=False, use_conv=True):
607
+ super().__init__()
608
+ ps = ksize // 2
609
+ if in_c != out_c or sk == False:
610
+ self.in_conv = nn.Conv2d(in_c, out_c, ksize, 1, ps)
611
+ else:
612
+ # print('n_in')
613
+ self.in_conv = None
614
+ self.block1 = nn.Conv2d(out_c, out_c, 3, 1, 1)
615
+ self.act = nn.ReLU()
616
+ self.block2 = nn.Conv2d(out_c, out_c, ksize, 1, ps)
617
+ if sk == False:
618
+ # self.skep = nn.Conv2d(in_c, out_c, ksize, 1, ps) # edit by zhouxiawang
619
+ self.skep = nn.Conv2d(out_c, out_c, ksize, 1, ps)
620
+ else:
621
+ self.skep = None
622
+
623
+ self.down = down
624
+ if self.down == True:
625
+ self.down_opt = Downsample(in_c, use_conv=use_conv)
626
+
627
+ def forward(self, x):
628
+ if self.down == True:
629
+ x = self.down_opt(x)
630
+ if self.in_conv is not None: # edit
631
+ x = self.in_conv(x)
632
+
633
+ h = self.block1(x)
634
+ h = self.act(h)
635
+ h = self.block2(h)
636
+ if self.skep is not None:
637
+ return h + self.skep(x)
638
+ else:
639
+ return h + x
640
+
641
+
642
+ class Adapter(nn.Module):
643
+ def __init__(self, channels=[320, 640, 1280, 1280], nums_rb=3, cin=64, ksize=3, sk=False, use_conv=True):
644
+ super(Adapter, self).__init__()
645
+ self.unshuffle = nn.PixelUnshuffle(16)
646
+ self.channels = channels
647
+ self.nums_rb = nums_rb
648
+ self.body = []
649
+ for i in range(len(channels)):
650
+ for j in range(nums_rb):
651
+ if (i != 0) and (j == 0):
652
+ self.body.append(
653
+ ResnetBlock(channels[i - 1], channels[i], down=True, ksize=ksize, sk=sk, use_conv=use_conv))
654
+ else:
655
+ self.body.append(
656
+ ResnetBlock(channels[i], channels[i], down=False, ksize=ksize, sk=sk, use_conv=use_conv))
657
+ self.body = nn.ModuleList(self.body)
658
+ self.conv_in = nn.Conv2d(cin * 16 * 16, channels[0], 3, 1, 1)
659
+
660
+ def forward(self, x):
661
+ # unshuffle
662
+ x = self.unshuffle(x)
663
+ # extract features
664
+ features = []
665
+ x = self.conv_in(x)
666
+ for i in range(len(self.channels)):
667
+ for j in range(self.nums_rb):
668
+ idx = i * self.nums_rb + j
669
+ x = self.body[idx](x)
670
+ features.append(x)
671
+
672
+ return features
673
+
674
+
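# Illustrative sketch (comments only, not part of the committed file): the Adapter maps a dense
# drag/flow encoding to one feature map per UNet resolution. With the configuration used for
# `flow_adapter` below under the default block_out_channels (channels=[320, 320, 640, 1280],
# nums_rb=2, cin=3, sk=True, use_conv=False) and an assumed 512x512, 3-channel flow input:
#
#   adapter = Adapter(channels=[320, 320, 640, 1280], nums_rb=2, cin=3, sk=True, use_conv=False)
#   feats = adapter(torch.randn(1, 3, 512, 512))
#   # PixelUnshuffle(16): (1, 3, 512, 512) -> (1, 768, 32, 32); the four outputs are
#   # (1, 320, 32, 32), (1, 320, 16, 16), (1, 640, 8, 8), (1, 1280, 4, 4)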
675
+ class OneSidedAttnProcessor:
676
+ r"""
677
+ Processor for performing attention-related computations where the keys and values are always taken from the upper (first) half of the batch
678
+ """
679
+
680
+ def __call__(
681
+ self,
682
+ attn: Attention,
683
+ hidden_states,
684
+ encoder_hidden_states=None,
685
+ attention_mask=None,
686
+ temb=None,
687
+ ):
688
+ assert encoder_hidden_states is None
689
+ residual = hidden_states
690
+
691
+ if attn.spatial_norm is not None:
692
+ hidden_states = attn.spatial_norm(hidden_states, temb)
693
+
694
+ input_ndim = hidden_states.ndim
695
+
696
+ if input_ndim == 4:
697
+ batch_size, channel, height, width = hidden_states.shape
698
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
699
+
700
+ batch_size, sequence_length, _ = (
701
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
702
+ )
703
+
704
+ assert batch_size % 2 == 0, "batch size must be even"
705
+ half_batch_size = batch_size // 2
706
+ hidden_states_1, hidden_states_2 = hidden_states.chunk(2, dim=0)
707
+
708
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, half_batch_size)
709
+
710
+ if attn.group_norm is not None:
711
+ hidden_states_1 = attn.group_norm(hidden_states_1.transpose(1, 2)).transpose(1, 2)
712
+ hidden_states_2 = attn.group_norm(hidden_states_2.transpose(1, 2)).transpose(1, 2)
713
+
714
+ query_1 = attn.to_q(hidden_states_1)
715
+ query_2 = attn.to_q(hidden_states_2)
716
+ key = attn.to_k(hidden_states_1)
717
+ value = attn.to_v(hidden_states_1)
718
+
719
+ query = torch.cat([query_1, query_2], dim=0)
720
+ key = torch.cat([key, key], dim=0)
721
+ value = torch.cat([value, value], dim=0)
722
+
723
+ query = attn.head_to_batch_dim(query)
724
+ key = attn.head_to_batch_dim(key)
725
+ value = attn.head_to_batch_dim(value)
726
+
727
+ attention_probs = attn.get_attention_scores(query, key, attention_mask)
728
+ hidden_states = torch.bmm(attention_probs, value)
729
+ hidden_states = attn.batch_to_head_dim(hidden_states)
730
+
731
+ # linear proj
732
+ hidden_states = attn.to_out[0](hidden_states)
733
+ # dropout
734
+ hidden_states = attn.to_out[1](hidden_states)
735
+
736
+ if input_ndim == 4:
737
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
738
+
739
+ if attn.residual_connection:
740
+ hidden_states = hidden_states + residual
741
+
742
+ hidden_states = hidden_states / attn.rescale_output_factor
743
+
744
+ return hidden_states
745
+
746
+
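# Illustrative sketch (comments only, not part of the committed file): OneSidedAttnProcessor makes
# every query attend to keys/values computed from the first half of the batch, so the second half
# can read features from a reference branch. Assuming diffusers' Attention module:
#
#   attn = Attention(query_dim=320, heads=8, dim_head=40, residual_connection=True,
#                    processor=OneSidedAttnProcessor())
#   h = torch.randn(4, 320, 16, 16)   # even batch size: samples 0-1 form the reference half
#   out = attn(h)                     # queries from all samples; keys/values only from samples 0-1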
747
+ class UNet2DDragConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
748
+ r"""
749
+ A conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample
750
+ shaped output.
751
+
752
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
753
+ for all models (such as downloading or saving).
754
+
755
+ Parameters:
756
+ sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
757
+ Height and width of input/output sample.
758
+ in_channels (`int`, *optional*, defaults to 4): Number of channels in the input sample.
759
+ out_channels (`int`, *optional*, defaults to 4): Number of channels in the output.
760
+ center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
761
+ flip_sin_to_cos (`bool`, *optional*, defaults to `False`):
762
+ Whether to flip the sin to cos in the time embedding.
763
+ freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
764
+ down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
765
+ The tuple of downsample blocks to use.
766
+ mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`):
767
+ Block type for middle of UNet, it can be either `UNetMidBlock2DCrossAttn` or
768
+ `UNetMidBlock2DSimpleCrossAttn`. If `None`, the mid block layer is skipped.
769
+ up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")`):
770
+ The tuple of upsample blocks to use.
771
+ only_cross_attention(`bool` or `Tuple[bool]`, *optional*, default to `False`):
772
+ Whether to include self-attention in the basic transformer blocks, see
773
+ [`~models.attention.BasicTransformerBlock`].
774
+ block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
775
+ The tuple of output channels for each block.
776
+ layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.
777
+ downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution.
778
+ mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block.
779
+ act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
780
+ norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization.
781
+ If `None`, normalization and activation layers is skipped in post-processing.
782
+ norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
783
+ cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280):
784
+ The dimension of the cross attention features.
785
+ transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1):
786
+ The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
787
+ [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
788
+ [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
789
+ encoder_hid_dim (`int`, *optional*, defaults to None):
790
+ If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim`
791
+ dimension to `cross_attention_dim`.
792
+ encoder_hid_dim_type (`str`, *optional*, defaults to `None`):
793
+ If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text
794
+ embeddings of dimension `cross_attention` according to `encoder_hid_dim_type`.
795
+ attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
796
+ num_attention_heads (`int`, *optional*):
797
+ The number of attention heads. If not defined, defaults to `attention_head_dim`
798
+ resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config
799
+ for ResNet blocks (see [`~models.resnet.ResnetBlock2D`]). Choose from `default` or `scale_shift`.
800
+ class_embed_type (`str`, *optional*, defaults to `None`):
801
+ The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`,
802
+ `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
803
+ addition_embed_type (`str`, *optional*, defaults to `None`):
804
+ Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or
805
+ "text". "text" will use the `TextTimeEmbedding` layer.
806
+ addition_time_embed_dim: (`int`, *optional*, defaults to `None`):
807
+ Dimension for the timestep embeddings.
808
+ num_class_embeds (`int`, *optional*, defaults to `None`):
809
+ Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
810
+ class conditioning with `class_embed_type` equal to `None`.
811
+ time_embedding_type (`str`, *optional*, defaults to `positional`):
812
+ The type of position embedding to use for timesteps. Choose from `positional` or `fourier`.
813
+ time_embedding_dim (`int`, *optional*, defaults to `None`):
814
+ An optional override for the dimension of the projected time embedding.
815
+ time_embedding_act_fn (`str`, *optional*, defaults to `None`):
816
+ Optional activation function to use only once on the time embeddings before they are passed to the rest of
817
+ the UNet. Choose from `silu`, `mish`, `gelu`, and `swish`.
818
+ timestep_post_act (`str`, *optional*, defaults to `None`):
819
+ The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`.
820
+ time_cond_proj_dim (`int`, *optional*, defaults to `None`):
821
+ The dimension of `cond_proj` layer in the timestep embedding.
822
+ conv_in_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_in` layer.
823
+ conv_out_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_out` layer.
824
+ projection_class_embeddings_input_dim (`int`, *optional*): The dimension of the `class_labels` input when
825
+ `class_embed_type="projection"`. Required when `class_embed_type="projection"`.
826
+ class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time
827
+ embeddings with the class embeddings.
828
+ mid_block_only_cross_attention (`bool`, *optional*, defaults to `None`):
829
+ Whether to use cross attention with the mid block when using the `UNetMidBlock2DSimpleCrossAttn`. If
830
+ `only_cross_attention` is given as a single boolean and `mid_block_only_cross_attention` is `None`, the
831
+ `only_cross_attention` value is used as the value for `mid_block_only_cross_attention`. Default to `False`
832
+ otherwise.
833
+ """
834
+
835
+ _supports_gradient_checkpointing = True
836
+
837
+ @register_to_config
838
+ def __init__(
839
+ self,
840
+ sample_size: Optional[int] = None,
841
+ in_channels: int = 4,
842
+ flow_channels: int = 3,
843
+ out_channels: int = 4,
844
+ center_input_sample: bool = False,
845
+ flip_sin_to_cos: bool = True,
846
+ freq_shift: int = 0,
847
+ down_block_types: Tuple[str] = (
848
+ "CrossAttnDownBlock2D",
849
+ "CrossAttnDownBlock2D",
850
+ "CrossAttnDownBlock2D",
851
+ "DownBlock2D",
852
+ ),
853
+ mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
854
+ up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
855
+ only_cross_attention: Union[bool, Tuple[bool]] = False,
856
+ block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
857
+ layers_per_block: Union[int, Tuple[int]] = 2,
858
+ downsample_padding: int = 1,
859
+ mid_block_scale_factor: float = 1,
860
+ act_fn: str = "silu",
861
+ norm_num_groups: Optional[int] = 32,
862
+ norm_eps: float = 1e-5,
863
+ cross_attention_dim: Union[int, Tuple[int]] = 1280,
864
+ transformer_layers_per_block: Union[int, Tuple[int]] = 1,
865
+ encoder_hid_dim: Optional[int] = None,
866
+ encoder_hid_dim_type: Optional[str] = None,
867
+ attention_head_dim: Union[int, Tuple[int]] = 8,
868
+ num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
869
+ dual_cross_attention: bool = False,
870
+ use_linear_projection: bool = False,
871
+ class_embed_type: Optional[str] = None,
872
+ addition_embed_type: Optional[str] = None,
873
+ addition_time_embed_dim: Optional[int] = None,
874
+ num_class_embeds: Optional[int] = None,
875
+ upcast_attention: bool = False,
876
+ resnet_time_scale_shift: str = "default",
877
+ resnet_skip_time_act: bool = False,
878
+ resnet_out_scale_factor: int = 1.0,
879
+ time_embedding_type: str = "positional",
880
+ time_embedding_dim: Optional[int] = None,
881
+ time_embedding_act_fn: Optional[str] = None,
882
+ timestep_post_act: Optional[str] = None,
883
+ time_cond_proj_dim: Optional[int] = None,
884
+ conv_in_kernel: int = 3,
885
+ conv_out_kernel: int = 3,
886
+ projection_class_embeddings_input_dim: Optional[int] = None,
887
+ class_embeddings_concat: bool = False,
888
+ mid_block_only_cross_attention: Optional[bool] = None,
889
+ cross_attention_norm: Optional[str] = None,
890
+ addition_embed_type_num_heads=64,
891
+
892
+ # Added
893
+ clip_embedding_dim: int = 1024,
894
+ num_clip_in: int = 25,
895
+ dragging_embedding_dim: int = 256,
896
+ use_drag_tokens: bool = True,
897
+ single_drag_token: bool = False,
898
+ num_drags: int = 10,
899
+
900
+ class_dropout_prob: float = 0.1,
901
+
902
+ flow_original_res: bool = False,
903
+ flow_size: int = 512,
904
+
905
+ input_concat_dragging: bool = True,
906
+ attn_concat_dragging: bool = False,
907
+ flow_multi_resolution_conv: bool = False,
908
+
909
+ flow_in_old_version: bool = True,
910
+ ):
911
+ super().__init__()
912
+
913
+ assert input_concat_dragging or attn_concat_dragging or flow_multi_resolution_conv
914
+ if flow_multi_resolution_conv:
915
+ assert not attn_concat_dragging and not input_concat_dragging
916
+
917
+ self.sample_size = sample_size
918
+
919
+ self.drag_dropout_prob = class_dropout_prob
920
+
921
+ if num_attention_heads is not None:
922
+ raise ValueError(
923
+ "At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19."
924
+ )
925
+
926
+ # If `num_attention_heads` is not defined (which is the case for most models)
927
+ # it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
928
+ # The reason for this behavior is to correct for incorrectly named variables that were introduced
929
+ # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
930
+ # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
931
+ # which is why we correct for the naming here.
932
+ num_attention_heads = num_attention_heads or attention_head_dim
933
+
934
+ # Check inputs
935
+ if len(down_block_types) != len(up_block_types):
936
+ raise ValueError(
937
+ f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
938
+ )
939
+
940
+ if len(block_out_channels) != len(down_block_types):
941
+ raise ValueError(
942
+ f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
943
+ )
944
+
945
+ if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
946
+ raise ValueError(
947
+ f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
948
+ )
949
+
950
+ if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
951
+ raise ValueError(
952
+ f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
953
+ )
954
+
955
+ if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types):
956
+ raise ValueError(
957
+ f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}."
958
+ )
959
+
960
+ if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types):
961
+ raise ValueError(
962
+ f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}."
963
+ )
964
+
965
+ if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types):
966
+ raise ValueError(
967
+ f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}."
968
+ )
969
+
970
+ # input
971
+ conv_in_padding = (conv_in_kernel - 1) // 2
972
+
973
+ self.num_drags = num_drags
974
+
975
+ self.attn_concat_dragging = attn_concat_dragging
976
+ if self.attn_concat_dragging:
977
+ self.drag_extra_dim = 4 * self.num_drags
978
+
979
+ self.flow_multi_resolution_conv = flow_multi_resolution_conv
980
+ if self.flow_multi_resolution_conv:
981
+ self.flow_adapter = Adapter(
982
+ channels=block_out_channels[:1] + block_out_channels[:-1],
983
+ nums_rb=2,
984
+ cin=3,
985
+ sk=True,
986
+ use_conv=False,
987
+ )
988
+
989
+ self.input_concat_dragging = input_concat_dragging
990
+ self.flow_in_old_version = flow_in_old_version
991
+ if self.input_concat_dragging:
992
+ if self.flow_in_old_version:
993
+ self.conv_in_flow = nn.Conv2d(
994
+ in_channels + flow_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
995
+ )
996
+ else:
997
+ self.conv_in = nn.Conv2d(
998
+ in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
999
+ )
1000
+ self.conv_in_flow = nn.Conv2d(
1001
+ flow_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding, bias=False
1002
+ )
1003
+ else:
1004
+ self.conv_in = nn.Conv2d(
1005
+ in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
1006
+ )
1007
+
1008
+ self.flow_original_res = flow_original_res
1009
+ if flow_original_res and self.input_concat_dragging:
1010
+ self.num_flow_down_layers = 0
1011
+ cur_sample_size = sample_size
1012
+ while flow_size > cur_sample_size:
1013
+ assert flow_size % cur_sample_size == 0
1014
+ self.num_flow_down_layers += 1
1015
+ cur_sample_size *= 2
1016
+
1017
+ self.flow_preprocess = nn.ModuleList([])
1018
+ for _ in range(self.num_flow_down_layers):
1019
+ self.flow_preprocess.append(nn.Conv2d(
1020
+ flow_channels, flow_channels, kernel_size=3, padding=1
1021
+ ))
1022
+ self.flow_proj_act = get_activation(act_fn)
1023
+
1024
+ self.num_clip_in = num_clip_in
1025
+ self.clip_proj = nn.ModuleList([])
1026
+ for i in range(num_clip_in):
1027
+ self.clip_proj.append(nn.Linear(clip_embedding_dim, clip_embedding_dim))
1028
+ self.clip_final = nn.Linear(clip_embedding_dim, cross_attention_dim)
1029
+
1030
+ self.use_drag_tokens = use_drag_tokens
1031
+ self.single_drag_token = single_drag_token
1032
+ if use_drag_tokens:
1033
+ self.dragging_embedding_dim = dragging_embedding_dim
1034
+ self.drag_proj = nn.Linear(dragging_embedding_dim * 4, dragging_embedding_dim * 4)
1035
+ self.drag_final = nn.Linear(dragging_embedding_dim * 4, cross_attention_dim)
1036
+ self.proj_act = get_activation(act_fn)
1037
+
1038
+ # time
1039
+ if time_embedding_type == "fourier":
1040
+ time_embed_dim = time_embedding_dim or block_out_channels[0] * 2
1041
+ if time_embed_dim % 2 != 0:
1042
+ raise ValueError(f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}.")
1043
+ self.time_proj = GaussianFourierProjection(
1044
+ time_embed_dim // 2, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos
1045
+ )
1046
+ timestep_input_dim = time_embed_dim
1047
+ elif time_embedding_type == "positional":
1048
+ time_embed_dim = time_embedding_dim or block_out_channels[0] * 4
1049
+
1050
+ self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
1051
+ timestep_input_dim = block_out_channels[0]
1052
+ else:
1053
+ raise ValueError(
1054
+ f"{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`."
1055
+ )
1056
+
1057
+ self.time_embedding = TimestepEmbedding(
1058
+ timestep_input_dim,
1059
+ time_embed_dim,
1060
+ act_fn=act_fn,
1061
+ post_act_fn=timestep_post_act,
1062
+ cond_proj_dim=time_cond_proj_dim,
1063
+ )
1064
+
1065
+ if encoder_hid_dim_type is None and encoder_hid_dim is not None:
1066
+ encoder_hid_dim_type = "text_proj"
1067
+ self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)
1068
+ logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.")
1069
+
1070
+ if encoder_hid_dim is None and encoder_hid_dim_type is not None:
1071
+ raise ValueError(
1072
+ f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}."
1073
+ )
1074
+
1075
+ if encoder_hid_dim_type == "text_proj":
1076
+ self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim)
1077
+ elif encoder_hid_dim_type == "text_image_proj":
1078
+ # image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much
1079
+ # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
1080
+ # case when `addition_embed_type == "text_image_proj"` (Kandinsky 2.1).
1081
+ self.encoder_hid_proj = TextImageProjection(
1082
+ text_embed_dim=encoder_hid_dim,
1083
+ image_embed_dim=cross_attention_dim,
1084
+ cross_attention_dim=cross_attention_dim,
1085
+ )
1086
+ elif encoder_hid_dim_type == "image_proj":
1087
+ # Kandinsky 2.2
1088
+ self.encoder_hid_proj = ImageProjection(
1089
+ image_embed_dim=encoder_hid_dim,
1090
+ cross_attention_dim=cross_attention_dim,
1091
+ )
1092
+ elif encoder_hid_dim_type is not None:
1093
+ raise ValueError(
1094
+ f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'."
1095
+ )
1096
+ else:
1097
+ self.encoder_hid_proj = None
1098
+
1099
+ # class embedding
1100
+ if class_embed_type is None and num_class_embeds is not None:
1101
+ self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
1102
+ elif class_embed_type == "timestep":
1103
+ self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim, act_fn=act_fn)
1104
+ elif class_embed_type == "identity":
1105
+ self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
1106
+ elif class_embed_type == "projection":
1107
+ if projection_class_embeddings_input_dim is None:
1108
+ raise ValueError(
1109
+ "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
1110
+ )
1111
+ # The projection `class_embed_type` is the same as the timestep `class_embed_type` except
1112
+ # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
1113
+ # 2. it projects from an arbitrary input dimension.
1114
+ #
1115
+ # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
1116
+ # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
1117
+ # As a result, `TimestepEmbedding` can be passed arbitrary vectors.
1118
+ self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
1119
+ elif class_embed_type == "simple_projection":
1120
+ if projection_class_embeddings_input_dim is None:
1121
+ raise ValueError(
1122
+ "`class_embed_type`: 'simple_projection' requires `projection_class_embeddings_input_dim` be set"
1123
+ )
1124
+ self.class_embedding = nn.Linear(projection_class_embeddings_input_dim, time_embed_dim)
1125
+ else:
1126
+ self.class_embedding = None
1127
+
1128
+ if addition_embed_type == "text":
1129
+ if encoder_hid_dim is not None:
1130
+ text_time_embedding_from_dim = encoder_hid_dim
1131
+ else:
1132
+ text_time_embedding_from_dim = cross_attention_dim
1133
+
1134
+ self.add_embedding = TextTimeEmbedding(
1135
+ text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads
1136
+ )
1137
+ elif addition_embed_type == "text_image":
1138
+ # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much
1139
+ # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
1140
+ # case when `addition_embed_type == "text_image"` (Kandinsky 2.1).
1141
+ self.add_embedding = TextImageTimeEmbedding(
1142
+ text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim
1143
+ )
1144
+ elif addition_embed_type == "text_time":
1145
+ self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift)
1146
+ self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
1147
+ elif addition_embed_type == "image":
1148
+ # Kandinsky 2.2
1149
+ self.add_embedding = ImageTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
1150
+ elif addition_embed_type == "image_hint":
1151
+ # Kandinsky 2.2 ControlNet
1152
+ self.add_embedding = ImageHintTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
1153
+ elif addition_embed_type is not None:
1154
+ raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.")
1155
+
1156
+ if time_embedding_act_fn is None:
1157
+ self.time_embed_act = None
1158
+ else:
1159
+ self.time_embed_act = get_activation(time_embedding_act_fn)
1160
+
1161
+ self.down_blocks = nn.ModuleList([])
1162
+ self.up_blocks = nn.ModuleList([])
1163
+
1164
+ if isinstance(only_cross_attention, bool):
1165
+ if mid_block_only_cross_attention is None:
1166
+ mid_block_only_cross_attention = only_cross_attention
1167
+
1168
+ only_cross_attention = [only_cross_attention] * len(down_block_types)
1169
+
1170
+ if mid_block_only_cross_attention is None:
1171
+ mid_block_only_cross_attention = False
1172
+
1173
+ if isinstance(num_attention_heads, int):
1174
+ num_attention_heads = (num_attention_heads,) * len(down_block_types)
1175
+
1176
+ if isinstance(attention_head_dim, int):
1177
+ attention_head_dim = (attention_head_dim,) * len(down_block_types)
1178
+
1179
+ if isinstance(cross_attention_dim, int):
1180
+ cross_attention_dim = (cross_attention_dim,) * len(down_block_types)
1181
+
1182
+ if isinstance(layers_per_block, int):
1183
+ layers_per_block = [layers_per_block] * len(down_block_types)
1184
+
1185
+ if isinstance(transformer_layers_per_block, int):
1186
+ transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)
1187
+
1188
+ if class_embeddings_concat:
1189
+ # The time embeddings are concatenated with the class embeddings. The dimension of the
1190
+ # time embeddings passed to the down, middle, and up blocks is twice the dimension of the
1191
+ # regular time embeddings
1192
+ blocks_time_embed_dim = time_embed_dim * 2
1193
+ else:
1194
+ blocks_time_embed_dim = time_embed_dim
1195
+
1196
+ # down
1197
+ output_channel = block_out_channels[0]
1198
+ for i, down_block_type in enumerate(down_block_types):
1199
+ input_channel = output_channel
1200
+ output_channel = block_out_channels[i]
1201
+ is_final_block = i == len(block_out_channels) - 1
1202
+
1203
+ down_block = get_down_block(
1204
+ self.attn_concat_dragging,
1205
+ down_block_type,
1206
+ num_layers=layers_per_block[i],
1207
+ transformer_layers_per_block=transformer_layers_per_block[i],
1208
+ in_channels=input_channel,
1209
+ out_channels=output_channel,
1210
+ temb_channels=blocks_time_embed_dim,
1211
+ add_downsample=not is_final_block,
1212
+ resnet_eps=norm_eps,
1213
+ resnet_act_fn=act_fn,
1214
+ resnet_groups=norm_num_groups,
1215
+ cross_attention_dim=cross_attention_dim[i],
1216
+ num_attention_heads=num_attention_heads[i],
1217
+ downsample_padding=downsample_padding,
1218
+ dual_cross_attention=dual_cross_attention,
1219
+ use_linear_projection=use_linear_projection,
1220
+ only_cross_attention=only_cross_attention[i],
1221
+ upcast_attention=upcast_attention,
1222
+ resnet_time_scale_shift=resnet_time_scale_shift,
1223
+ resnet_skip_time_act=resnet_skip_time_act,
1224
+ resnet_out_scale_factor=resnet_out_scale_factor,
1225
+ cross_attention_norm=cross_attention_norm,
1226
+ attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
1227
+ flow_channels=self.drag_extra_dim if self.attn_concat_dragging else None,
1228
+ )
1229
+ self.down_blocks.append(down_block)
1230
+
1231
+ # mid
1232
+ if mid_block_type == "UNetMidBlock2DCrossAttn":
1233
+ mid_block_kwargs = dict(
1234
+ transformer_layers_per_block=transformer_layers_per_block[-1],
1235
+ in_channels=block_out_channels[-1],
1236
+ temb_channels=blocks_time_embed_dim,
1237
+ resnet_eps=norm_eps,
1238
+ resnet_act_fn=act_fn,
1239
+ output_scale_factor=mid_block_scale_factor,
1240
+ resnet_time_scale_shift=resnet_time_scale_shift,
1241
+ cross_attention_dim=cross_attention_dim[-1],
1242
+ num_attention_heads=num_attention_heads[-1],
1243
+ resnet_groups=norm_num_groups,
1244
+ dual_cross_attention=dual_cross_attention,
1245
+ use_linear_projection=use_linear_projection,
1246
+ upcast_attention=upcast_attention,
1247
+ )
1248
+
1249
+ if self.attn_concat_dragging:
1250
+ mid_block_kwargs["flow_channels"] = self.drag_extra_dim
1251
+ mid_block_type += "WithFlow"
1252
+
1253
+ self.mid_block = eval(mid_block_type)(
1254
+ **mid_block_kwargs
1255
+ )
1256
+ elif mid_block_type == "UNetMidBlock2DSimpleCrossAttn":
1257
+ raise NotImplementedError
1258
+ self.mid_block = UNetMidBlock2DSimpleCrossAttn(
1259
+ in_channels=block_out_channels[-1],
1260
+ temb_channels=blocks_time_embed_dim,
1261
+ resnet_eps=norm_eps,
1262
+ resnet_act_fn=act_fn,
1263
+ output_scale_factor=mid_block_scale_factor,
1264
+ cross_attention_dim=cross_attention_dim[-1],
1265
+ attention_head_dim=attention_head_dim[-1],
1266
+ resnet_groups=norm_num_groups,
1267
+ resnet_time_scale_shift=resnet_time_scale_shift,
1268
+ skip_time_act=resnet_skip_time_act,
1269
+ only_cross_attention=mid_block_only_cross_attention,
1270
+ cross_attention_norm=cross_attention_norm,
1271
+ )
1272
+ elif mid_block_type is None:
1273
+ self.mid_block = None
1274
+ else:
1275
+ raise ValueError(f"unknown mid_block_type : {mid_block_type}")
1276
+
1277
+ # count how many layers upsample the images
1278
+ self.num_upsamplers = 0
1279
+
1280
+ # up
1281
+ reversed_block_out_channels = list(reversed(block_out_channels))
1282
+ reversed_num_attention_heads = list(reversed(num_attention_heads))
1283
+ reversed_layers_per_block = list(reversed(layers_per_block))
1284
+ reversed_cross_attention_dim = list(reversed(cross_attention_dim))
1285
+ reversed_transformer_layers_per_block = list(reversed(transformer_layers_per_block))
1286
+ only_cross_attention = list(reversed(only_cross_attention))
1287
+
1288
+ output_channel = reversed_block_out_channels[0]
1289
+ for i, up_block_type in enumerate(up_block_types):
1290
+ is_final_block = i == len(block_out_channels) - 1
1291
+
1292
+ prev_output_channel = output_channel
1293
+ output_channel = reversed_block_out_channels[i]
1294
+ input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
1295
+
1296
+ # add upsample block for all BUT final layer
1297
+ if not is_final_block:
1298
+ add_upsample = True
1299
+ self.num_upsamplers += 1
1300
+ else:
1301
+ add_upsample = False
1302
+
1303
+ up_block = get_up_block(
1304
+ self.attn_concat_dragging,
1305
+ up_block_type,
1306
+ num_layers=reversed_layers_per_block[i] + 1,
1307
+ transformer_layers_per_block=reversed_transformer_layers_per_block[i],
1308
+ in_channels=input_channel,
1309
+ out_channels=output_channel,
1310
+ prev_output_channel=prev_output_channel,
1311
+ temb_channels=blocks_time_embed_dim,
1312
+ add_upsample=add_upsample,
1313
+ resnet_eps=norm_eps,
1314
+ resnet_act_fn=act_fn,
1315
+ resnet_groups=norm_num_groups,
1316
+ cross_attention_dim=reversed_cross_attention_dim[i],
1317
+ num_attention_heads=reversed_num_attention_heads[i],
1318
+ dual_cross_attention=dual_cross_attention,
1319
+ use_linear_projection=use_linear_projection,
1320
+ only_cross_attention=only_cross_attention[i],
1321
+ upcast_attention=upcast_attention,
1322
+ resnet_time_scale_shift=resnet_time_scale_shift,
1323
+ resnet_skip_time_act=resnet_skip_time_act,
1324
+ resnet_out_scale_factor=resnet_out_scale_factor,
1325
+ cross_attention_norm=cross_attention_norm,
1326
+ attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
1327
+ flow_channels=self.drag_extra_dim if self.attn_concat_dragging else None,
1328
+ )
1329
+ self.up_blocks.append(up_block)
1330
+ prev_output_channel = output_channel
1331
+
1332
+ # out
1333
+ if norm_num_groups is not None:
1334
+ self.conv_norm_out = nn.GroupNorm(
1335
+ num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps
1336
+ )
1337
+
1338
+ self.conv_act = get_activation(act_fn)
1339
+
1340
+ else:
1341
+ self.conv_norm_out = None
1342
+ self.conv_act = None
1343
+
1344
+ conv_out_padding = (conv_out_kernel - 1) // 2
1345
+ self.conv_out = nn.Conv2d(
1346
+ block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding
1347
+ )
1348
+
1349
+ @property
1350
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
1351
+ r"""
1352
+ Returns:
1353
+ `dict` of attention processors: A dictionary containing all attention processors used in the model,
1354
+ indexed by their weight names.
1355
+ """
1356
+ # set recursively
1357
+ processors = {}
1358
+
1359
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
1360
+ if hasattr(module, "set_processor"):
1361
+ processors[f"{name}.processor"] = module.processor
1362
+
1363
+ for sub_name, child in module.named_children():
1364
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
1365
+
1366
+ return processors
1367
+
1368
+ for name, module in self.named_children():
1369
+ fn_recursive_add_processors(name, module, processors)
1370
+
1371
+ return processors
1372
+
1373
+ def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
1374
+ r"""
1375
+ Sets the attention processor to use to compute attention.
1376
+
1377
+ Parameters:
1378
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
1379
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
1380
+ for **all** `Attention` layers.
1381
+
1382
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
1383
+ processor. This is strongly recommended when setting trainable attention processors.
1384
+
1385
+ """
1386
+ count = len(self.attn_processors.keys())
1387
+
1388
+ if isinstance(processor, dict) and len(processor) != count:
1389
+ raise ValueError(
1390
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
1391
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
1392
+ )
1393
+
1394
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
1395
+ if hasattr(module, "set_processor"):
1396
+ if not isinstance(processor, dict):
1397
+ module.set_processor(processor)
1398
+ else:
1399
+ module.set_processor(processor.pop(f"{name}.processor"))
1400
+
1401
+ for sub_name, child in module.named_children():
1402
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
1403
+
1404
+ for name, module in self.named_children():
1405
+ fn_recursive_attn_processor(name, module, processor)
1406
+
1407
+ def set_default_attn_processor(self):
1408
+ """
1409
+ Disables custom attention processors and sets the default attention implementation.
1410
+ """
1411
+ self.set_attn_processor(AttnProcessor())
1412
+
1413
+ def set_attention_slice(self, slice_size):
1414
+ r"""
1415
+ Enable sliced attention computation.
1416
+
1417
+ When this option is enabled, the attention module splits the input tensor in slices to compute attention in
1418
+ several steps. This is useful for saving some memory in exchange for a small decrease in speed.
1419
+
1420
+ Args:
1421
+ slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
1422
+ When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
1423
+ `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
1424
+ provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
1425
+ must be a multiple of `slice_size`.
1426
+ """
1427
+ sliceable_head_dims = []
1428
+
1429
+ def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
1430
+ if hasattr(module, "set_attention_slice"):
1431
+ sliceable_head_dims.append(module.sliceable_head_dim)
1432
+
1433
+ for child in module.children():
1434
+ fn_recursive_retrieve_sliceable_dims(child)
1435
+
1436
+ # retrieve number of attention layers
1437
+ for module in self.children():
1438
+ fn_recursive_retrieve_sliceable_dims(module)
1439
+
1440
+ num_sliceable_layers = len(sliceable_head_dims)
1441
+
1442
+ if slice_size == "auto":
1443
+ # half the attention head size is usually a good trade-off between
1444
+ # speed and memory
1445
+ slice_size = [dim // 2 for dim in sliceable_head_dims]
1446
+ elif slice_size == "max":
1447
+ # make smallest slice possible
1448
+ slice_size = num_sliceable_layers * [1]
1449
+
1450
+ slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
1451
+
1452
+ if len(slice_size) != len(sliceable_head_dims):
1453
+ raise ValueError(
1454
+ f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
1455
+ f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
1456
+ )
1457
+
1458
+ for i in range(len(slice_size)):
1459
+ size = slice_size[i]
1460
+ dim = sliceable_head_dims[i]
1461
+ if size is not None and size > dim:
1462
+ raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
1463
+
1464
+ # Recursively walk through all the children.
1465
+ # Any children which exposes the set_attention_slice method
1466
+ # gets the message
1467
+ def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
1468
+ if hasattr(module, "set_attention_slice"):
1469
+ module.set_attention_slice(slice_size.pop())
1470
+
1471
+ for child in module.children():
1472
+ fn_recursive_set_attention_slice(child, slice_size)
1473
+
1474
+ reversed_slice_size = list(reversed(slice_size))
1475
+ for module in self.children():
1476
+ fn_recursive_set_attention_slice(module, reversed_slice_size)
1477
+
1478
+ def _set_gradient_checkpointing(self, module, value=False):
1479
+ if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D, CrossAttnUpBlock2D, UpBlock2D)):
1480
+ module.gradient_checkpointing = value
1481
+
1482
+ def _convert_drag_to_concatting_image(self, drag: torch.Tensor, current_resolution: int) -> torch.Tensor:
1483
+ assert self.drag_extra_dim == 4 * self.num_drags
1484
+
1485
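+ # Rasterize the drags into a (4 * num_drags)-channel image: for each drag, channels 0-1 mark the
+ # source pixel and channels 2-3 mark the target pixel, storing the sub-pixel offsets
+ # (x_src, y_src) and (x_tgt, y_tgt); all other locations, and all-zero drags, stay at -1.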
+ bsz = drag.shape[0]
1486
+ concatting_image = -torch.ones(bsz, self.drag_extra_dim, current_resolution, current_resolution)
1487
+ concatting_image = concatting_image.to(drag.device)
1488
+
1489
+ not_all_zeros = drag.any(dim=-1).repeat_interleave(4, dim=1).unsqueeze(-1).unsqueeze(-1)
1490
+
1491
+ y_grid, x_grid = torch.meshgrid(torch.arange(current_resolution), torch.arange(current_resolution), indexing="ij")
1492
+ y_grid = y_grid.to(drag.device).unsqueeze(0).unsqueeze(0) # (1, 1, res, res)
1493
+ x_grid = x_grid.to(drag.device).unsqueeze(0).unsqueeze(0)
1494
+
1495
+ x0 = (drag[..., 0] * current_resolution - 0.5).round().clip(0, current_resolution - 1)
1496
+ x_src = (drag[..., 0] * current_resolution - x0).unsqueeze(-1).unsqueeze(-1) # (bsz, num_drags, 1, 1)
1497
+ x0 = x0.unsqueeze(-1).unsqueeze(-1)
1498
+ x0 = torch.stack([x0, x0, torch.zeros_like(x0) - 1, torch.zeros_like(x0) - 1], dim=2).view(bsz, 4 * self.num_drags, 1, 1)
1499
+
1500
+ y0 = (drag[..., 1] * current_resolution - 0.5).round().clip(0, current_resolution - 1)
1501
+ y_src = (drag[..., 1] * current_resolution - y0).unsqueeze(-1).unsqueeze(-1)
1502
+ y0 = y0.unsqueeze(-1).unsqueeze(-1)
1503
+ y0 = torch.stack([y0, y0, torch.zeros_like(y0) - 1, torch.zeros_like(y0) - 1], dim=2).view(bsz, 4 * self.num_drags, 1, 1)
1504
+
1505
+ x1 = (drag[..., 2] * current_resolution - 0.5).round().clip(0, current_resolution - 1)
1506
+ x_tgt = (drag[..., 2] * current_resolution - x1).unsqueeze(-1).unsqueeze(-1)
1507
+ x1 = x1.unsqueeze(-1).unsqueeze(-1)
1508
+ x1 = torch.stack([torch.zeros_like(x1) - 1, torch.zeros_like(x1) - 1, x1, x1], dim=2).view(bsz, 4 * self.num_drags, 1, 1)
1509
+
1510
+ y1 = (drag[..., 3] * current_resolution - 0.5).round().clip(0, current_resolution - 1)
1511
+ y_tgt = (drag[..., 3] * current_resolution - y1).unsqueeze(-1).unsqueeze(-1)
1512
+ y1 = y1.unsqueeze(-1).unsqueeze(-1)
1513
+ y1 = torch.stack([torch.zeros_like(y1) - 1, torch.zeros_like(y1) - 1, y1, y1], dim=2).view(bsz, 4 * self.num_drags, 1, 1)
1514
+
1515
+ # assert torch.all(x_src >= 0) and torch.all(x_src <= 1)
1516
+ # assert torch.all(y_src >= 0) and torch.all(y_src <= 1)
1517
+ # assert torch.all(x_tgt >= 0) and torch.all(x_tgt <= 1)
1518
+ # assert torch.all(y_tgt >= 0) and torch.all(y_tgt <= 1)
1519
+
1520
+ value_image = torch.stack([x_src, y_src, x_tgt, y_tgt], dim=2).view(bsz, 4 * self.num_drags, 1, 1)
1521
+ value_image = value_image.expand(bsz, 4 * self.num_drags, current_resolution, current_resolution)
1522
+
1523
+ concatting_image[(x_grid == x0) & (y_grid == y0) & not_all_zeros] = value_image[(x_grid == x0) & (y_grid == y0) & not_all_zeros]
1524
+ concatting_image[(x_grid == x1) & (y_grid == y1) & not_all_zeros] = value_image[(x_grid == x1) & (y_grid == y1) & not_all_zeros]
1525
+
1526
+ return concatting_image
1527
+
1528
+ def forward(
1529
+ self,
1530
+ # sample: torch.FloatTensor,
1531
+ # timestep: Union[torch.Tensor, float, int],
1532
+ # encoder_hidden_states: torch.Tensor,
1533
+ # class_labels: Optional[torch.Tensor] = None,
1534
+ # timestep_cond: Optional[torch.Tensor] = None,
1535
+ # attention_mask: Optional[torch.Tensor] = None,
1536
+ # cross_attention_kwargs: Optional[Dict[str, Any]] = None,
1537
+ # added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
1538
+ # down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
1539
+ # mid_block_additional_residual: Optional[torch.Tensor] = None,
1540
+ # encoder_attention_mask: Optional[torch.Tensor] = None,
1541
+ # return_dict: bool = True,
1542
+ x: torch.FloatTensor,
1543
+ t: torch.Tensor,
1544
+ x_cond: torch.FloatTensor,
1545
+ x_cond_extra: Optional[torch.Tensor] = None,
1546
+ force_drop_ids: Optional[torch.Tensor] = None,
1547
+ hidden_cls: Optional[torch.Tensor] = None,
1548
+ drags: Optional[torch.Tensor] = None,
1549
+ save_features: bool = False,
1550
+ ) -> torch.Tensor:
1551
+ r"""
1552
+ The [`UNet2DDragConditionModel`] forward method.
+
+ Args:
+ x (`torch.FloatTensor`):
+ The noisy input latents of shape `(batch, channel, height, width)`.
+ t (`torch.Tensor`):
+ The diffusion timestep(s) used to denoise the input.
+ x_cond (`torch.FloatTensor`):
+ Latents of the conditioning (reference) image, with the same shape as `x`.
+ x_cond_extra (`torch.Tensor`, *optional*):
+ The rasterized drag/flow conditioning that is concatenated to the input when `input_concat_dragging` is enabled.
+ force_drop_ids (`torch.Tensor`, *optional*):
+ Per-sample mask; entries equal to 1 have their drag conditioning zeroed out (used for classifier-free guidance).
+ hidden_cls (`torch.Tensor`, *optional*):
+ CLIP class-token hidden states of the conditioning image; the last `num_clip_in` of them are projected and averaged into the cross-attention context.
+ drags (`torch.Tensor`, *optional*):
+ Drag coordinates of shape `(batch, num_drags, 4)`, each row holding the normalized source and target locations `(x_src, y_src, x_tgt, y_tgt)`.
+ save_features (`bool`, *optional*, defaults to `False`):
+ Not used by this method.
+
+ Returns:
+ `torch.Tensor`: The prediction for the noised half of the internal double batch, with shape `(batch, out_channels, height, width)`.
+ """
1578
+ # By default, samples have to be at least a multiple of the overall upsampling factor.
1579
+ # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
1580
+ # However, the upsampling interpolation output size can be forced to fit any upsampling size
1581
+ # on the fly if necessary.
1582
+ default_overall_up_factor = 2 ** self.num_upsamplers
1583
+
1584
+ # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
1585
+ forward_upsample_size = False
1586
+ upsample_size = None
1587
+
1588
+ if any(s % default_overall_up_factor != 0 for s in x.shape[-2:]):
1589
+ logger.info("Forward upsample size to force interpolation output size.")
1590
+ forward_upsample_size = True
1591
+
1592
+ # ensure attention_mask is a bias, and give it a singleton query_tokens dimension
1593
+ # expects mask of shape:
1594
+ # [batch, key_tokens]
1595
+ # adds singleton query_tokens dimension:
1596
+ # [batch, 1, key_tokens]
1597
+ # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
1598
+ # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
1599
+ # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
1600
+ # if attention_mask is not None:
1601
+ # assume that mask is expressed as:
1602
+ # (1 = keep, 0 = discard)
1603
+ # convert mask into a bias that can be added to attention scores:
1604
+ # (keep = +0, discard = -10000.0)
1605
+ # attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
1606
+ # attention_mask = attention_mask.unsqueeze(1)
1607
+
1608
+ # convert encoder_attention_mask to a bias the same way we do for attention_mask
1609
+ # if encoder_attention_mask is not None:
1610
+ # encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0
1611
+ # encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
1612
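+ # If the flow/drag image comes in at the original image resolution, repeatedly convolve and
+ # max-pool it (one 2x reduction per layer) down to the latent resolution before concatenation.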
+ if self.flow_original_res and self.input_concat_dragging:
1613
+ for i in range(self.num_flow_down_layers):
1614
+ x_cond_extra = self.flow_preprocess[i](x_cond_extra)
1615
+ x_cond_extra = self.flow_proj_act(x_cond_extra)
1616
+ x_cond_extra = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)(x_cond_extra)
1617
+ if self.input_concat_dragging:
1618
+ assert x_cond_extra.shape[-1] == x.shape[-1], f"{x_cond_extra.shape} != {x.shape}"
1619
+
1620
+ bsz, num_drags, drag_dim = drags.shape
1621
+ assert num_drags == self.num_drags
1622
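+ # During training, drag conditioning is randomly zeroed with probability `drag_dropout_prob`
+ # so that classifier-free guidance can be used at sampling time; `force_drop_ids` overrides this.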
+ if (self.training and self.drag_dropout_prob > 0) or force_drop_ids is not None:
1623
+ if force_drop_ids is None:
1624
+ drop_ids = torch.rand(bsz, device=x_cond_extra.device) < self.drag_dropout_prob
1625
+ else:
1626
+ drop_ids = force_drop_ids == 1
1627
+ x_cond_extra = torch.where(
1628
+ drop_ids[:, None, None, None].expand_as(x_cond_extra),
1629
+ torch.zeros_like(x_cond_extra),
1630
+ x_cond_extra,
1631
+ )
1632
+ drags = torch.where(
1633
+ drop_ids[:, None, None].expand_as(drags),
1634
+ torch.zeros_like(drags),
1635
+ drags,
1636
+ )
1637
+
1638
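+ # Stack the clean conditioning branch (reference latents) and the noised branch along the batch
+ # dimension; only the noised half is returned at the end of this method.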
+ if not self.input_concat_dragging:
1639
+ sample = torch.cat([x_cond, x], dim=0)
1640
+ else:
1641
+ sample_noised = torch.cat([x, x_cond_extra], dim=1)
1642
+ sample_input = torch.cat([x_cond, torch.zeros_like(x_cond_extra)], dim=1)
1643
+ sample = torch.cat([sample_input, sample_noised], dim=0)
1644
+
1645
+ drags = torch.cat([torch.zeros_like(drags), drags], dim=0)
1646
+
1647
+ if self.flow_multi_resolution_conv:
1648
+ x_cond_extra = torch.cat([torch.zeros_like(x_cond_extra), x_cond_extra], dim=0)
1649
+ flow_multi_resolution_features = self.flow_adapter(x_cond_extra)
1650
+
1651
+ # -1. (new) get encoder_hidden_states
1652
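+ # Drags are optionally embedded with sinusoidal encodings and projected into drag tokens, and the
+ # CLIP class token(s) of the conditioning image are projected into a global token; together they
+ # form the cross-attention context.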
+ if self.use_drag_tokens:
1653
+ assert drag_dim == 4
1654
+ drags = drags.reshape(-1, 4)
1655
+ drags = get_sin_cos_pos_embed(embed_dim=self.dragging_embedding_dim, x=drags)
1656
+ drags = drags.reshape(-1, num_drags, self.dragging_embedding_dim * 4)
1657
+ drag_states = self.drag_proj(drags)
1658
+ drag_states = self.proj_act(drag_states)
1659
+ drag_states = self.drag_final(drag_states)
1660
+
1661
+ assert hidden_cls.shape[1] >= self.num_clip_in
1662
+ cls_proj = 0
1663
+ for i in range(self.num_clip_in):
1664
+ current_cls = hidden_cls[:, -(i+1), :]
1665
+ cls_proj += self.clip_proj[i](current_cls)
1666
+ cls_proj = cls_proj / self.num_clip_in
1667
+ cls_proj = self.proj_act(cls_proj)
1668
+ cls_proj = self.clip_final(cls_proj)
1669
+
1670
+ if self.use_drag_tokens:
1671
+ if not self.single_drag_token:
1672
+ encoder_hidden_states = torch.cat([drag_states, torch.concat([cls_proj[:, None, :], cls_proj[:, None, :]], dim=0)], dim=1)
1673
+ assert encoder_hidden_states.shape[1] == num_drags + 1
1674
+ else:
1675
+ encoder_hidden_states = torch.cat([torch.mean(drag_states, dim=1, keepdim=True), torch.concat([cls_proj[:, None, :], cls_proj[:, None, :]], dim=0)], dim=1)
1676
+ assert encoder_hidden_states.shape[1] == 2
1677
+ else:
1678
+ encoder_hidden_states = cls_proj[:, None, :]
1679
+ assert encoder_hidden_states.shape[1] == 1
1680
+ encoder_hidden_states = torch.cat([encoder_hidden_states, encoder_hidden_states], dim=0)
1681
+
1682
+ # 0. center input if necessary
1683
+ assert not self.config.center_input_sample, "center_input_sample is not supported yet."
1684
+ if self.config.center_input_sample:
1685
+ sample = 2 * sample - 1.0
1686
+
1687
+ # 1. time
1688
+ timesteps = t
1689
+ if len(timesteps.shape) == 0:
1690
+ timesteps = timesteps[None].to(sample.device)
1691
+
1692
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
1693
+ timesteps = torch.cat([timesteps, timesteps], dim=0)
1694
+
1695
+ t_emb = self.time_proj(timesteps)
1696
+
1697
+ # `Timesteps` does not contain any weights and will always return f32 tensors
1698
+ # but time_embedding might actually be running in fp16, so we need to cast here.
1699
+ # there might be better ways to encapsulate this.
1700
+ t_emb = t_emb.to(dtype=sample.dtype)
1701
+
1702
+ emb = self.time_embedding(t_emb, None)
1703
+ aug_emb = None
1704
+
1705
+ if self.class_embedding is not None:
1706
+ if class_labels is None:
1707
+ raise ValueError("class_labels should be provided when num_class_embeds > 0")
1708
+
1709
+ if self.config.class_embed_type == "timestep":
1710
+ class_labels = self.time_proj(class_labels)
1711
+
1712
+ # `Timesteps` does not contain any weights and will always return f32 tensors
1713
+ # there might be better ways to encapsulate this.
1714
+ class_labels = class_labels.to(dtype=sample.dtype)
1715
+
1716
+ class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype)
1717
+
1718
+ if self.config.class_embeddings_concat:
1719
+ emb = torch.cat([emb, class_emb], dim=-1)
1720
+ else:
1721
+ emb = emb + class_emb
1722
+
1723
+ if self.config.addition_embed_type == "text":
1724
+ aug_emb = self.add_embedding(encoder_hidden_states)
1725
+ elif self.config.addition_embed_type == "text_image":
1726
+ # Kandinsky 2.1 - style
1727
+ if "image_embeds" not in added_cond_kwargs:
1728
+ raise ValueError(
1729
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
1730
+ )
1731
+
1732
+ image_embs = added_cond_kwargs.get("image_embeds")
1733
+ text_embs = added_cond_kwargs.get("text_embeds", encoder_hidden_states)
1734
+ aug_emb = self.add_embedding(text_embs, image_embs)
1735
+ elif self.config.addition_embed_type == "text_time":
1736
+ # SDXL - style
1737
+ if "text_embeds" not in added_cond_kwargs:
1738
+ raise ValueError(
1739
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
1740
+ )
1741
+ text_embeds = added_cond_kwargs.get("text_embeds")
1742
+ if "time_ids" not in added_cond_kwargs:
1743
+ raise ValueError(
1744
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
1745
+ )
1746
+ time_ids = added_cond_kwargs.get("time_ids")
1747
+ time_embeds = self.add_time_proj(time_ids.flatten())
1748
+ time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
1749
+
1750
+ add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
1751
+ add_embeds = add_embeds.to(emb.dtype)
1752
+ aug_emb = self.add_embedding(add_embeds)
1753
+ elif self.config.addition_embed_type == "image":
1754
+ # Kandinsky 2.2 - style
1755
+ if "image_embeds" not in added_cond_kwargs:
1756
+ raise ValueError(
1757
+ f"{self.__class__} has the config param `addition_embed_type` set to 'image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
1758
+ )
1759
+ image_embs = added_cond_kwargs.get("image_embeds")
1760
+ aug_emb = self.add_embedding(image_embs)
1761
+ elif self.config.addition_embed_type == "image_hint":
1762
+ # Kandinsky 2.2 - style
1763
+ if "image_embeds" not in added_cond_kwargs or "hint" not in added_cond_kwargs:
1764
+ raise ValueError(
1765
+ f"{self.__class__} has the config param `addition_embed_type` set to 'image_hint' which requires the keyword arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`"
1766
+ )
1767
+ image_embs = added_cond_kwargs.get("image_embeds")
1768
+ hint = added_cond_kwargs.get("hint")
1769
+ aug_emb, hint = self.add_embedding(image_embs, hint)
1770
+ sample = torch.cat([sample, hint], dim=1)
1771
+
1772
+ emb = emb + aug_emb if aug_emb is not None else emb
1773
+
1774
+ if self.time_embed_act is not None:
1775
+ emb = self.time_embed_act(emb)
1776
+
1777
+ if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_proj":
1778
+ encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states)
1779
+ elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_image_proj":
1780
+ # Kandinsky 2.1 - style
1781
+ if "image_embeds" not in added_cond_kwargs:
1782
+ raise ValueError(
1783
+ f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
1784
+ )
1785
+
1786
+ image_embeds = added_cond_kwargs.get("image_embeds")
1787
+ encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states, image_embeds)
1788
+ elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "image_proj":
1789
+ # Kandinsky 2.2 - style
1790
+ if "image_embeds" not in added_cond_kwargs:
1791
+ raise ValueError(
1792
+ f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
1793
+ )
1794
+ image_embeds = added_cond_kwargs.get("image_embeds")
1795
+ encoder_hidden_states = self.encoder_hid_proj(image_embeds)
1796
+ # 2. pre-process
1797
+ if self.input_concat_dragging:
1798
+ if self.flow_in_old_version:
1799
+ sample = self.conv_in_flow(sample)
1800
+ else:
1801
+ sample_x, sample_flow = torch.split(sample, 4, dim=1)
1802
+ sample_x = self.conv_in(sample_x)
1803
+ sample_flow = self.conv_in_flow(sample_flow)
1804
+ sample = sample_x + sample_flow
1805
+ else:
1806
+ sample = self.conv_in(sample)
1807
+
1808
+ # 3. down
1809
+ down_block_res_samples = (sample,)
1810
+ for idx, downsample_block in enumerate(self.down_blocks):
1811
+ if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
1812
+ # For t2i-adapter CrossAttnDownBlock2D
1813
+ additional_residuals = {}
1814
+
1815
+ down_forward_kwargs = dict(
1816
+ hidden_states=sample if not self.flow_multi_resolution_conv else (sample + flow_multi_resolution_features[idx]),
1817
+ temb=emb,
1818
+ encoder_hidden_states=encoder_hidden_states,
1819
+ attention_mask=None,
1820
+ cross_attention_kwargs=None,
1821
+ encoder_attention_mask=None,
1822
+ **additional_residuals,
1823
+ )
1824
+
1825
+ if self.attn_concat_dragging:
1826
+ down_forward_kwargs["flow"] = self._convert_drag_to_concatting_image(drags, sample.shape[-1])
1827
+
1828
+ sample, res_samples = downsample_block(
1829
+ **down_forward_kwargs
1830
+ )
1831
+ else:
1832
+ sample, res_samples = downsample_block(
1833
+ hidden_states=sample if not self.flow_multi_resolution_conv else (sample + flow_multi_resolution_features[idx]),
1834
+ temb=emb
1835
+ )
1836
+
1837
+ down_block_res_samples += res_samples
1838
+
1839
+ # 4. mid
1840
+ if self.mid_block is not None:
1841
+ if self.attn_concat_dragging:
1842
+ sample = self.mid_block(
1843
+ sample,
1844
+ emb,
1845
+ encoder_hidden_states=encoder_hidden_states,
1846
+ attention_mask=None,
1847
+ cross_attention_kwargs=None,
1848
+ encoder_attention_mask=None,
1849
+ flow=self._convert_drag_to_concatting_image(drags, sample.shape[-1]),
1850
+ )
1851
+ else:
1852
+ sample = self.mid_block(
1853
+ sample,
1854
+ emb,
1855
+ encoder_hidden_states=encoder_hidden_states,
1856
+ attention_mask=None,
1857
+ cross_attention_kwargs=None,
1858
+ encoder_attention_mask=None,
1859
+ )
1860
+
1861
+ # 5. up
1862
+ for i, upsample_block in enumerate(self.up_blocks):
1863
+ is_final_block = i == len(self.up_blocks) - 1
1864
+
1865
+ res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
1866
+ down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
1867
+
1868
+ # if we have not reached the final block and need to forward the
1869
+ # upsample size, we do it here
1870
+ if not is_final_block and forward_upsample_size:
1871
+ upsample_size = down_block_res_samples[-1].shape[2:]
1872
+
1873
+ if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
1874
+ up_block_forward_kwargs = dict(
1875
+ hidden_states=sample,
1876
+ temb=emb,
1877
+ res_hidden_states_tuple=res_samples,
1878
+ encoder_hidden_states=encoder_hidden_states,
1879
+ attention_mask=None,
1880
+ cross_attention_kwargs=None,
1881
+ encoder_attention_mask=None,
1882
+ )
1883
+
1884
+ if self.attn_concat_dragging:
1885
+ up_block_forward_kwargs["flow"] = self._convert_drag_to_concatting_image(drags, sample.shape[-1])
1886
+
1887
+ sample = upsample_block(
1888
+ **up_block_forward_kwargs
1889
+ )
1890
+ else:
1891
+ sample = upsample_block(
1892
+ hidden_states=sample,
1893
+ temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size
1894
+ )
1895
+
1896
+ # 6. post-process
1897
+ if self.conv_norm_out:
1898
+ sample = self.conv_norm_out(sample)
1899
+ sample = self.conv_act(sample)
1900
+ sample = self.conv_out(sample)
1901
+
1902
+ return sample[bsz:]
1903
+
1904
+ def forward_with_cfg(
1905
+ self,
1906
+ x: torch.FloatTensor,
1907
+ t: torch.Tensor,
1908
+ x_cond: torch.FloatTensor,
1909
+ x_cond_extra: Optional[torch.Tensor] = None,
1910
+ hidden_cls: Optional[torch.Tensor] = None,
1911
+ drags: Optional[torch.Tensor] = None,
1912
+ cfg_scale: float = 1,
1913
+ ) -> torch.Tensor:
1914
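+ # `x` and all conditioning tensors are expected with their conditional/unconditional halves already
+ # stacked along the batch dimension; `force_drop_ids` marks the first half as drag-unconditional,
+ # so both halves still receive the same image conditioning.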
+ half = x[: len(x) // 2]
1915
+ combined = torch.cat([half, half], dim=0)
1916
+ force_drop_ids = torch.arange(len(combined), device=combined.device) < len(half)
1917
+ model_out = self.forward(combined, t, x_cond, x_cond_extra, force_drop_ids=force_drop_ids, hidden_cls=hidden_cls, drags=drags)
1918
+ # Classifier-free guidance is applied here to the first four (latent) channels; any remaining
1919
+ # output channels are passed through unchanged. The commented line below applies guidance to
1920
+ # only three channels and is kept for reference.
1921
+ # eps, rest = model_out[:, :3], model_out[:, 3:]
1922
+ eps, rest = model_out[:, :4], model_out[:, 4:]
1923
+ uncond_eps, cond_eps = torch.split(eps, len(eps) // 2, dim=0)
1924
+ half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps)
1925
+ eps = torch.cat([half_eps, half_eps], dim=0)
1926
+ return torch.cat([eps, rest], dim=1)
1927
+
1928
+ @classmethod
1929
+ def from_pretrained_sd(cls, pretrained_model_path, subfolder="unet", unet_additional_kwargs=None, load=True):
1930
+ if subfolder is not None:
1931
+ pretrained_model_path = os.path.join(pretrained_model_path, subfolder)
1932
+ print(f"loading unet's pretrained weights from {pretrained_model_path} ...")
1933
+
1934
+ config_file = os.path.join(pretrained_model_path, 'config.json')
1935
+ if not os.path.isfile(config_file):
1936
+ raise RuntimeError(f"{config_file} does not exist")
1937
+ with open(config_file, "r") as f:
1938
+ config = json.load(f)
1939
+ config["_class_name"] = cls.__name__
1940
+
1941
+ from diffusers.utils import WEIGHTS_NAME
1942
+ one_sided_attn = unet_additional_kwargs.pop("one_sided_attn", True) if unet_additional_kwargs is not None else True
1943
+ model = cls.from_config(config, **unet_additional_kwargs) if unet_additional_kwargs is not None else cls.from_config(config)
1944
+ model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME)
1945
+ if not os.path.isfile(model_file):
1946
+ raise RuntimeError(f"{model_file} does not exist")
1947
+
1948
+ if load:
1949
+ state_dict = torch.load(model_file, map_location="cpu")
1950
+ m, u = model.load_state_dict(state_dict, strict=False)
1951
+
1952
+ # Set the attention processor to always take k, v from the input (upper) branch
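+ # (only the self-attention `attn1` processors are replaced; cross-attention `attn2` keeps the default)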
1953
+ if one_sided_attn:
1954
+ attn_processors_dict={
1955
+ "down_blocks.0.attentions.0.transformer_blocks.0.attn1.processor": OneSidedAttnProcessor(),
1956
+ "down_blocks.0.attentions.0.transformer_blocks.0.attn2.processor": AttnProcessor(),
1957
+ "down_blocks.0.attentions.1.transformer_blocks.0.attn1.processor": OneSidedAttnProcessor(),
1958
+ "down_blocks.0.attentions.1.transformer_blocks.0.attn2.processor": AttnProcessor(),
1959
+ "down_blocks.1.attentions.0.transformer_blocks.0.attn1.processor": OneSidedAttnProcessor(),
1960
+ "down_blocks.1.attentions.0.transformer_blocks.0.attn2.processor": AttnProcessor(),
1961
+ "down_blocks.1.attentions.1.transformer_blocks.0.attn1.processor": OneSidedAttnProcessor(),
1962
+ "down_blocks.1.attentions.1.transformer_blocks.0.attn2.processor": AttnProcessor(),
1963
+ "down_blocks.2.attentions.0.transformer_blocks.0.attn1.processor": OneSidedAttnProcessor(),
1964
+ "down_blocks.2.attentions.0.transformer_blocks.0.attn2.processor": AttnProcessor(),
1965
+ "down_blocks.2.attentions.1.transformer_blocks.0.attn1.processor": OneSidedAttnProcessor(),
1966
+ "down_blocks.2.attentions.1.transformer_blocks.0.attn2.processor": AttnProcessor(),
1967
+
1968
+ "up_blocks.1.attentions.0.transformer_blocks.0.attn1.processor": OneSidedAttnProcessor(),
1969
+ "up_blocks.1.attentions.0.transformer_blocks.0.attn2.processor": AttnProcessor(),
1970
+ "up_blocks.1.attentions.1.transformer_blocks.0.attn1.processor": OneSidedAttnProcessor(),
1971
+ "up_blocks.1.attentions.1.transformer_blocks.0.attn2.processor": AttnProcessor(),
1972
+ "up_blocks.1.attentions.2.transformer_blocks.0.attn1.processor": OneSidedAttnProcessor(),
1973
+ "up_blocks.1.attentions.2.transformer_blocks.0.attn2.processor": AttnProcessor(),
1974
+ "up_blocks.2.attentions.0.transformer_blocks.0.attn1.processor": OneSidedAttnProcessor(),
1975
+ "up_blocks.2.attentions.0.transformer_blocks.0.attn2.processor": AttnProcessor(),
1976
+ "up_blocks.2.attentions.1.transformer_blocks.0.attn1.processor": OneSidedAttnProcessor(),
1977
+ "up_blocks.2.attentions.1.transformer_blocks.0.attn2.processor": AttnProcessor(),
1978
+ "up_blocks.2.attentions.2.transformer_blocks.0.attn1.processor": OneSidedAttnProcessor(),
1979
+ "up_blocks.2.attentions.2.transformer_blocks.0.attn2.processor": AttnProcessor(),
1980
+ "up_blocks.3.attentions.0.transformer_blocks.0.attn1.processor": OneSidedAttnProcessor(),
1981
+ "up_blocks.3.attentions.0.transformer_blocks.0.attn2.processor": AttnProcessor(),
1982
+ "up_blocks.3.attentions.1.transformer_blocks.0.attn1.processor": OneSidedAttnProcessor(),
1983
+ "up_blocks.3.attentions.1.transformer_blocks.0.attn2.processor": AttnProcessor(),
1984
+ "up_blocks.3.attentions.2.transformer_blocks.0.attn1.processor": OneSidedAttnProcessor(),
1985
+ "up_blocks.3.attentions.2.transformer_blocks.0.attn2.processor": AttnProcessor(),
1986
+
1987
+ "mid_block.attentions.0.transformer_blocks.0.attn1.processor": OneSidedAttnProcessor(),
1988
+ "mid_block.attentions.0.transformer_blocks.0.attn2.processor": AttnProcessor(),
1989
+ }
1990
+ model.set_attn_processor(attn_processors_dict)
1991
+
1992
+ return model
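
A minimal loading sketch (hedged: the checkpoint directory below is a placeholder, assumed to contain a diffusers-format UNet, i.e. a `config.json` and `diffusion_pytorch_model.bin`):

    from model import UNet2DDragConditionModel

    # Build the drag-conditioned UNet on top of a local Stable Diffusion UNet checkpoint.
    unet = UNet2DDragConditionModel.from_pretrained_sd("ckpts/sd-base", subfolder="unet")
    unet.eval()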
requirements.txt ADDED
@@ -0,0 +1,258 @@
1
+ absl-py==2.0.0
2
+ accelerate==0.24.1
3
+ addict==2.4.0
4
+ aiofiles==23.2.1
5
+ aiohttp==3.9.1
6
+ aiosignal==1.3.1
7
+ altair==5.2.0
8
+ annotated-types==0.6.0
9
+ antlr4-python3-runtime==4.9.3
10
+ anyio==3.7.1
11
+ appdirs==1.4.4
12
+ asttokens==2.4.1
13
+ async-timeout==4.0.3
14
+ attrs==23.1.0
15
+ backcall==0.2.0
16
+ backports.functools-lru-cache @ file:///home/conda/feedstock_root/build_artifacts/backports.functools_lru_cache_1687772187254/work
17
+ basicsr==1.4.2
18
+ blessed @ file:///home/conda/feedstock_root/build_artifacts/blessed_1666523113356/work
19
+ Brotli @ file:///tmp/abs_ecyw11_7ze/croots/recipe/brotli-split_1659616059936/work
20
+ cachetools==5.3.2
21
+ certifi @ file:///croot/certifi_1700501669400/work/certifi
22
+ cffi @ file:///croot/cffi_1670423208954/work
23
+ charset-normalizer @ file:///tmp/build/80754af9/charset-normalizer_1630003229654/work
24
+ clean-fid==0.1.35
25
+ click==8.1.7
26
+ clip-anytorch==2.6.0
27
+ cloudpickle==3.0.0
28
+ colorama @ file:///home/conda/feedstock_root/build_artifacts/colorama_1666700638685/work
29
+ coloredlogs==15.0.1
30
+ colorlog==6.8.2
31
+ ConfigArgParse==1.7
32
+ contourpy==1.1.1
33
+ controlnet-aux==0.0.7
34
+ cryptography @ file:///croot/cryptography_1694444244250/work
35
+ cycler==0.12.1
36
+ cypari==2.5.4
37
+ dctorch==0.1.2
38
+ decorator==4.4.2
39
+ diffusers==0.19.3
40
+ docker-pycreds==0.4.0
41
+ dotmap==1.3.30
42
+ einops==0.7.0
43
+ envlight @ git+https://github.com/ashawkey/envlight@ef492c03711c87287549a0283ee51199f45cbea4
44
+ exceptiongroup==1.2.0
45
+ executing==2.0.1
46
+ faiss==1.7.4
47
+ fastapi==0.105.0
48
+ fastcore==1.5.29
49
+ ffmpy==0.3.1
50
+ filelock @ file:///croot/filelock_1672387128942/work
51
+ flatbuffers==23.5.26
52
+ fonttools==4.45.1
53
+ frozenlist==1.4.0
54
+ fsspec==2023.10.0
55
+ ftfy==6.1.3
56
+ future==1.0.0
57
+ fvcore==0.1.5.post20210915
58
+ FXrays==1.3.5
59
+ gitdb==4.0.11
60
+ GitPython==3.1.40
61
+ gmpy2 @ file:///tmp/build/80754af9/gmpy2_1645455532332/work
62
+ google-auth==2.24.0
63
+ google-auth-oauthlib==1.0.0
64
+ gpustat @ file:///home/conda/feedstock_root/build_artifacts/gpustat_1692786716371/work
65
+ GPUtil==1.4.0
66
+ gradio==4.21.0
67
+ gradio_client==0.12.0
68
+ -e git+https://github.com/RuiningLi/Animal-Data-Engine.git@3b3e155f572ce797a28ca88766101e73369cba17#egg=groundingdino&subdirectory=utils/GroundingDINO
69
+ grpcio==1.59.3
70
+ h11==0.14.0
71
+ h5py==3.10.0
72
+ httpcore==1.0.2
73
+ httpx==0.25.2
74
+ huggingface-hub==0.19.4
75
+ humanfriendly==10.0
76
+ icecream==2.1.3
77
+ idna @ file:///croot/idna_1666125576474/work
78
+ imageio==2.32.0
79
+ imageio-ffmpeg==0.4.9
80
+ importlib-metadata==6.8.0
81
+ importlib-resources==6.1.1
82
+ iopath==0.1.9
83
+ ipython==8.12.3
84
+ jaxtyping==0.2.19
85
+ jedi==0.19.1
86
+ Jinja2 @ file:///croot/jinja2_1666908132255/work
87
+ joblib==1.3.2
88
+ jsonmerge==1.9.2
89
+ jsonschema==4.20.0
90
+ jsonschema-specifications==2023.11.2
91
+ k-diffusion==0.1.1.post1
92
+ kiwisolver==1.4.5
93
+ knot-floer-homology==1.2
94
+ kornia==0.7.1
95
+ lazy_loader==0.3
96
+ libigl==2.5.0
97
+ lightning==2.1.3
98
+ lightning-utilities==0.10.0
99
+ llvmlite==0.41.1
100
+ lmdb==1.4.1
101
+ loguru==0.7.2
102
+ lovely-numpy==0.2.10
103
+ lovely-tensors==0.1.15
104
+ low-index==1.2
105
+ lpips==0.1.4
106
+ Markdown==3.5.1
107
+ markdown-it-py==3.0.0
108
+ MarkupSafe @ file:///opt/conda/conda-bld/markupsafe_1654597864307/work
109
+ matplotlib==3.7.4
110
+ matplotlib-inline==0.1.6
111
+ mdurl==0.1.2
112
+ mkl-fft @ file:///croot/mkl_fft_1695058164594/work
113
+ mkl-random @ file:///croot/mkl_random_1695059800811/work
114
+ mkl-service==2.4.0
115
+ moviepy==1.0.3
116
+ mpmath @ file:///croot/mpmath_1690848262763/work
117
+ multidict==6.0.4
118
+ mypy-extensions==1.0.0
119
+ natsort==8.4.0
120
+ nerfacc==0.3.2
121
+ networkx @ file:///croot/networkx_1690561992265/work
122
+ ninja==1.11.1.1
123
+ numba==0.58.1
124
+ numpy @ file:///work/mkl/numpy_and_numpy_base_1682953417311/work
125
+ nvdiffrast @ git+https://github.com/NVlabs/nvdiffrast/@c5caf7bdb8a2448acc491a9faa47753972edd380
126
+ nvidia-ml-py @ file:///home/conda/feedstock_root/build_artifacts/nvidia-ml-py_1698947663801/work
127
+ oauthlib==3.2.2
128
+ objaverse==0.1.7
129
+ omegaconf==2.3.0
130
+ onnxruntime==1.16.3
131
+ opencv-python==4.8.1.78
132
+ opencv-python-headless==4.8.1.78
133
+ opt-einsum==3.3.0
134
+ orjson==3.9.10
135
+ packaging==23.2
136
+ pandas==1.3.0
137
+ parso==0.8.3
138
+ pexpect==4.8.0
139
+ pickleshare==0.7.5
140
+ Pillow @ file:///croot/pillow_1696580024257/work
141
+ pkgutil_resolve_name==1.3.10
142
+ platformdirs==4.0.0
143
+ plink==2.4.2
144
+ pooch==1.8.0
145
+ portalocker @ file:///home/conda/feedstock_root/build_artifacts/portalocker_1695662050140/work
146
+ proglog==0.1.10
147
+ prompt-toolkit==3.0.40
148
+ protobuf==4.25.1
149
+ psutil @ file:///opt/conda/conda-bld/psutil_1656431268089/work
150
+ ptyprocess==0.7.0
151
+ pure-eval==0.2.2
152
+ pyarrow==14.0.2
153
+ pyasn1==0.5.1
154
+ pyasn1-modules==0.3.0
155
+ pybind11==2.11.1
156
+ pycocotools==2.0.7
157
+ pycparser @ file:///tmp/build/80754af9/pycparser_1636541352034/work
158
+ pydantic==2.5.2
159
+ pydantic_core==2.14.5
160
+ pydensecrf @ git+https://github.com/lucasb-eyer/pydensecrf.git@dd070546eda51e21ab772ee6f14807c7f5b1548b
161
+ pyDeprecate==0.3.2
162
+ pydub==0.25.1
163
+ Pygments==2.16.1
164
+ pyhocon==0.3.57
165
+ PyMatting==1.1.12
166
+ PyMCubes==0.1.4
167
+ pynndescent==0.5.11
168
+ pyOpenSSL @ file:///croot/pyopenssl_1690223430423/work
169
+ pyparsing==3.1.1
170
+ pypng==0.20220715.0
171
+ PySocks @ file:///tmp/build/80754af9/pysocks_1605305779399/work
172
+ python-dateutil==2.8.2
173
+ python-multipart==0.0.9
174
+ pytorch-lightning==2.1.3
175
+ pytorch3d==0.7.5
176
+ pytz==2023.3.post1
177
+ PyWavelets==1.4.1
178
+ PyYAML @ file:///home/conda/feedstock_root/build_artifacts/pyyaml_1695373436676/work
179
+ referencing==0.31.1
180
+ regex==2023.10.3
181
+ rembg==2.0.52
182
+ requests @ file:///croot/requests_1690400202158/work
183
+ requests-oauthlib==1.3.1
184
+ rich==13.7.0
185
+ rpds-py==0.13.2
186
+ rsa==4.9
187
+ ruff==0.3.2
188
+ safetensors==0.4.0
189
+ scikit-image==0.21.0
190
+ scikit-learn==1.3.2
191
+ scipy==1.10.1
192
+ seaborn==0.13.1
193
+ -e git+https://github.com/RuiningLi/Animal-Data-Engine.git@3b3e155f572ce797a28ca88766101e73369cba17#egg=segment_anything&subdirectory=utils/segment_anything
194
+ semantic-version==2.10.0
195
+ sentry-sdk==1.35.0
196
+ setproctitle==1.3.3
197
+ shellingham==1.5.4
198
+ silence-tensorflow==1.2.1
199
+ six @ file:///home/conda/feedstock_root/build_artifacts/six_1620240208055/work
200
+ smmap==5.0.1
201
+ snappy==3.1.1
202
+ snappy-manifolds==1.2
203
+ sniffio==1.3.0
204
+ spherogram==2.2.1
205
+ stack-data==0.6.3
206
+ starlette==0.27.0
207
+ submitit==1.5.1
208
+ supervision==0.16.0
209
+ support-developer==1.0.5
210
+ sympy @ file:///croot/sympy_1668202399572/work
211
+ tabulate @ file:///home/conda/feedstock_root/build_artifacts/tabulate_1665138452165/work
212
+ taming-transformers==0.0.1
213
+ taming-transformers-rom1504==0.0.6
214
+ tap.py==3.1
215
+ tb-nightly==2.14.0a20230808
216
+ tensorboard==2.14.0
217
+ tensorboard-data-server==0.7.2
218
+ tensorboardX==2.6.2.2
219
+ termcolor @ file:///home/conda/feedstock_root/build_artifacts/termcolor_1682317048417/work
220
+ threadpoolctl==3.2.0
221
+ tifffile==2023.7.10
222
+ timm==0.9.10
223
+ tinycudann @ git+https://github.com/NVlabs/tiny-cuda-nn@212104156403bd87616c1a4f73a1c5f2c2e172a9#subdirectory=bindings/torch
224
+ tokenizers==0.15.0
225
+ tomli==2.0.1
226
+ tomlkit==0.12.0
227
+ toolz==0.12.0
228
+ torch==2.1.0+cu118
229
+ torch-efficient-distloss==0.1.3
230
+ torch-ema==0.3
231
+ torchaudio==2.1.0+cu118
232
+ torchdiffeq==0.2.3
233
+ torchmetrics==0.11.4
234
+ torchsde==0.2.6
235
+ torchvision==0.16.0+cu118
236
+ tqdm @ file:///home/conda/feedstock_root/build_artifacts/tqdm_1691671248568/work
237
+ traitlets==5.13.0
238
+ trampoline==0.1.2
239
+ transformers==4.35.2
240
+ trimesh==4.0.5
241
+ triton==2.1.0
242
+ typeguard==4.1.5
243
+ typer==0.9.0
244
+ typing-inspect==0.9.0
245
+ typing_extensions==4.9.0
246
+ tzdata==2023.3
247
+ umap-learn==0.5.5
248
+ urllib3 @ file:///croot/urllib3_1698257533958/work
249
+ uvicorn==0.24.0.post1
250
+ wandb==0.16.0
251
+ wcwidth==0.2.13
252
+ websockets==11.0.3
253
+ Werkzeug==3.0.1
254
+ xatlas==0.0.8
255
+ yacs @ file:///home/conda/feedstock_root/build_artifacts/yacs_1645705974477/work
256
+ yapf==0.40.2
257
+ yarl==1.9.3
258
+ zipp==3.17.0