twodgirl committed on
Commit 5c31d1f · verified · 1 Parent(s): bf6b364

Upload 27 files

images/bear_avocado__spatext.jpg ADDED
images/bedroom__sketch.jpg ADDED
images/cat__mesh.jpg ADDED
images/cat__point_cloud.jpg ADDED
images/dog__sketch.jpg ADDED
images/fruit_bowl.jpg ADDED
images/grapes.jpg ADDED
images/horse.jpg ADDED
images/horse__point_cloud.jpg ADDED
images/knight__humanoid.jpg ADDED
images/library__mesh.jpg ADDED
images/living_room__seg.jpg ADDED
images/living_room_modern.jpg ADDED
images/man_park.jpg ADDED
images/person__mesh.jpg ADDED
images/running__pose.jpg ADDED
images/squirrel.jpg ADDED
images/tiger.jpg ADDED
images/van_gogh.jpg ADDED
pipelines/__init__.py ADDED
File without changes
pipelines/pipeline_sdxl.py ADDED
@@ -0,0 +1,570 @@
1
+ from copy import deepcopy
2
+ from dataclasses import dataclass
3
+ from diffusers import StableDiffusionXLPipeline
4
+ from diffusers.image_processor import PipelineImageInput
5
+ from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img\
6
+ import rescale_noise_cfg, retrieve_latents, retrieve_timesteps
7
+ from diffusers.utils import BaseOutput
8
+ from diffusers.utils.torch_utils import randn_tensor
9
+ import numpy as np
10
+ from PIL import Image
11
+ import torch
12
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
13
+ from utils.utils import batch_dict_to_tensor, batch_tensor_to_dict, noise_prev, noise_t2t
14
+ from utils.sdxl import register_attr
15
+
16
+ ###
17
+ # Code from genforce/ctrl-x/ctrl_x/pipelines/pipeline_sdxl.py
18
+
19
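+ # Order in which the sub-batches are concatenated along the batch dimension for a single U-Net forward pass.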
+ BATCH_ORDER = [
20
+ "structure_uncond", "appearance_uncond", "uncond", "structure_cond", "appearance_cond", "cond",
21
+ ]
22
+
23
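+ # Last denoising step index at which structure / appearance control is still active, derived from the largest fraction in the control schedule.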
+ def get_last_control_i(control_schedule, num_inference_steps):
24
+ if control_schedule is None:
25
+ return num_inference_steps, num_inference_steps
26
+
27
+ def max_(l):
28
+ if len(l) == 0:
29
+ return 0.0
30
+ return max(l)
31
+
32
+ structure_max = 0.0
33
+ appearance_max = 0.0
34
+ for block in control_schedule.values():
35
+ if isinstance(block, list): # Handling mid_block
36
+ block = {0: block}
37
+ for layer in block.values():
38
+ structure_max = max(structure_max, max_(layer[0] + layer[1]))
39
+ appearance_max = max(appearance_max, max_(layer[2]))
40
+
41
+ structure_i = round(num_inference_steps * structure_max)
42
+ appearance_i = round(num_inference_steps * appearance_max)
43
+
44
+ return structure_i, appearance_i
45
+
46
+ @dataclass
47
+ class CtrlXStableDiffusionXLPipelineOutput(BaseOutput):
48
+ images: Union[List[Image.Image], np.ndarray]
49
+ structures: Union[List[Image.Image], np.ndarray]
50
+ appearances: Union[List[Image.Image], np.ndarray]
51
+
52
+ class CtrlXStableDiffusionXLPipeline(StableDiffusionXLPipeline):
53
+ def __call__(
54
+ self,
55
+ prompt: Union[str, List[str]] = None, # TODO: Support prompt_2 and negative_prompt_2
56
+ structure_prompt: Optional[Union[str, List[str]]] = None,
57
+ appearance_prompt: Optional[Union[str, List[str]]] = None,
58
+ structure_image: Optional[PipelineImageInput] = None,
59
+ appearance_image: Optional[PipelineImageInput] = None,
60
+ num_inference_steps: int = 50,
61
+ timesteps: List[int] = None,
62
+ negative_prompt: Optional[Union[str, List[str]]] = None,
63
+ positive_prompt: Optional[Union[str, List[str]]] = None,
64
+ height: Optional[int] = None,
65
+ width: Optional[int] = None,
66
+ guidance_scale: float = 5.0,
67
+ structure_guidance_scale: Optional[float] = None,
68
+ appearance_guidance_scale: Optional[float] = None,
69
+ num_images_per_prompt: Optional[int] = 1,
70
+ eta: float = 0.0,
71
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
72
+ latents: Optional[torch.Tensor] = None,
73
+ structure_latents: Optional[torch.Tensor] = None,
74
+ appearance_latents: Optional[torch.Tensor] = None,
75
+ prompt_embeds: Optional[torch.Tensor] = None, # Positive prompt is concatenated with prompt, so no embeddings
76
+ structure_prompt_embeds: Optional[torch.Tensor] = None,
77
+ appearance_prompt_embeds: Optional[torch.Tensor] = None,
78
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
79
+ pooled_prompt_embeds: Optional[torch.Tensor] = None,
80
+ structure_pooled_prompt_embeds: Optional[torch.Tensor] = None,
81
+ appearance_pooled_prompt_embeds: Optional[torch.Tensor] = None,
82
+ negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,
83
+ control_schedule: Optional[Dict] = None,
84
+ self_recurrence_schedule: Optional[List[int]] = None, # Per-step repeat counts, e.g. from get_self_recurrence_schedule
85
+ decode_structure: Optional[bool] = True,
86
+ decode_appearance: Optional[bool] = True,
87
+ output_type: Optional[str] = "pil",
88
+ return_dict: bool = True,
89
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
90
+ guidance_rescale: float = 0.0,
91
+ original_size: Tuple[int, int] = None,
92
+ crops_coords_top_left: Tuple[int, int] = (0, 0),
93
+ target_size: Tuple[int, int] = None,
94
+ clip_skip: Optional[int] = None,
95
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
96
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
97
+ **kwargs,
98
+ ):
99
+ callback = kwargs.pop("callback", None)
100
+ callback_steps = kwargs.pop("callback_steps", None)
101
+ self._guidance_scale = guidance_scale
+ self._structure_guidance_scale = structure_guidance_scale
+ self._appearance_guidance_scale = appearance_guidance_scale
102
+
103
+ # 0. Default height and width to U-Net
104
+ height = height or self.default_sample_size * self.vae_scale_factor
105
+ width = width or self.default_sample_size * self.vae_scale_factor
106
+ original_size = original_size or (height, width)
107
+ target_size = target_size or (height, width)
108
+
109
+ # 2. This pipeline only supports batch_size = 1
110
+ batch_size = 1
111
+ if isinstance(prompt, list):
112
+ assert len(prompt) == batch_size
113
+ if prompt_embeds is not None:
114
+ assert prompt_embeds.shape[0] == batch_size
115
+
116
+ device = self._execution_device
117
+
118
+ # 3. Encode input prompt
119
+ text_encoder_lora_scale = (
120
+ cross_attention_kwargs.get("scale", None)
121
+ if cross_attention_kwargs is not None else None
122
+ )
123
+
124
+ # 3-3.2 Encode input, structure, appearance prompt
126
+ # Prepare prompt data
127
+ prompts = [
128
+ (prompt, prompt_embeds, negative_prompt, pooled_prompt_embeds),
129
+ (structure_prompt, structure_prompt_embeds, negative_prompt if structure_image is None else "", structure_pooled_prompt_embeds),
130
+ (appearance_prompt, appearance_prompt_embeds, negative_prompt if appearance_image is None else "", appearance_pooled_prompt_embeds)
131
+ ]
132
+ prompt_embeds_list = []
133
+ add_text_embeds_list = []
134
+ for item in prompts:
135
+ prompt_text, prompt_embeds_temp, negative_prompt_temp, pooled_prompt_embeds_temp = item[:4] # Unpack relevant items
136
+
137
+ if prompt_text is not None and prompt_text != "":
138
+ (
139
+ prompt_embeds_,
140
+ negative_prompt_embeds,
141
+ pooled_prompt_embeds_,
142
+ negative_pooled_prompt_embeds,
143
+ ) = self.encode_prompt(
144
+ prompt=prompt_text,
145
+ prompt_2=None,
146
+ device=device,
147
+ num_images_per_prompt=num_images_per_prompt,
148
+ do_classifier_free_guidance=True,
149
+ negative_prompt=negative_prompt_temp,
150
+ negative_prompt_2=None,
151
+ prompt_embeds=prompt_embeds_temp,
152
+ negative_prompt_embeds=None,
153
+ pooled_prompt_embeds=pooled_prompt_embeds_temp,
154
+ negative_pooled_prompt_embeds=None,
155
+ lora_scale=text_encoder_lora_scale,
156
+ clip_skip=clip_skip,
157
+ )
158
+ prompt_embeds_list.append(torch.cat([negative_prompt_embeds, prompt_embeds_], dim=0).to(device))
159
+ add_text_embeds_list.append(torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds_], dim=0).to(device))
160
+ else:
161
+ prompt_embeds_list.append(prompt_embeds_list[0])
162
+ add_text_embeds_list.append(add_text_embeds_list[0])
163
+ # prompt_embeds, structure_prompt_embeds, appearance_prompt_embeds = prompt_embeds_list
164
+ # add_text_embeds, structure_add_text_embeds, appearance_add_text_embeds = add_text_embeds_list
165
+
166
+ # 3.3. Prepare added time ids & embeddings
167
+ if self.text_encoder_2 is None:
168
+ text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
169
+ else:
170
+ text_encoder_projection_dim = self.text_encoder_2.config.projection_dim
171
+
172
+ add_time_ids = self._get_add_time_ids(
173
+ original_size,
174
+ crops_coords_top_left,
175
+ target_size,
176
+ dtype=self.dtype,
177
+ text_encoder_projection_dim=text_encoder_projection_dim,
178
+ )
179
+ negative_add_time_ids = add_time_ids
180
+ add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0).to(device)
181
+
182
+ # 4. Prepare timesteps
183
+ timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
184
+
185
+ # 5. Prepare latent variables
186
+ num_channels_latents = self.unet.config.in_channels
187
+
188
+ # prepare_latents returns (noise, image latents); only the noise is needed here.
189
+ latents, _ = self.prepare_latents(
190
+ None, batch_size, num_images_per_prompt, num_channels_latents, height, width,
191
+ self.dtype, device, generator, latents
192
+ )
193
+ latents_ = [structure_latents, appearance_latents]
194
+ clean_latents_ = []
195
+ for image_index, image_ in enumerate([structure_image, appearance_image]):
196
+ if image_ is not None:
197
+ # Only the clean latents encoded from the image are needed here.
198
+ _, clean_latent = self.prepare_latents(
199
+ image_, batch_size, num_images_per_prompt, num_channels_latents, height, width,
200
+ self.dtype, device, generator, latents_[image_index]
201
+ )
202
+ clean_latents_.append(clean_latent)
203
+ else:
204
+ clean_latents_.append(None)
205
+ if latents_[image_index] is None:
206
+ latents_[image_index] = latents
207
+ latents_ = [latents] + latents_
208
+ # clean_structure_latents, clean_appearance_latents = clean_latents_
209
+
210
+ # 6. Prepare extra step kwargs
211
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
212
+
213
+ # 7. Denoising loop
214
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
215
+
216
+ # 7.1 Apply denoising_end
217
+ if hasattr(self, 'denoising_end') and self.denoising_end is not None and 0.0 < float(self.denoising_end) < 1.0:
218
+ discrete_timestep_cutoff = int(
219
+ round(
220
+ self.scheduler.config.num_train_timesteps
221
+ - (self.denoising_end * self.scheduler.config.num_train_timesteps)
222
+ )
223
+ )
224
+ num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)))
225
+ timesteps = timesteps[:num_inference_steps]
226
+
227
+ # 7.2 Optionally get guidance scale embedding
228
+ timestep_cond = None
229
+ assert self.unet.config.time_cond_proj_dim is None
230
+
231
+ # 7.3 Get batch order
232
+ batch_order = deepcopy(BATCH_ORDER)
233
+ if structure_image is not None: # If image is provided, not generating, so no CFG needed
234
+ batch_order.remove("structure_uncond")
235
+ if appearance_image is not None:
236
+ batch_order.remove("appearance_uncond")
237
+
238
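+ # Run the joint denoising loop over the structure, appearance, and output branches.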
+ baked_latents = self.cfg_loop(batch_order,
239
+ prompt_embeds_list,
240
+ add_text_embeds_list,
241
+ add_time_ids,
242
+ latents_,
243
+ clean_latents_,
244
+ num_inference_steps,
245
+ num_warmup_steps,
246
+ extra_step_kwargs,
247
+ timesteps,
248
+ timestep_cond=timestep_cond,
249
+ control_schedule=control_schedule,
250
+ self_recurrence_schedule=self_recurrence_schedule,
251
+ guidance_rescale=guidance_rescale,
252
+ callback=callback,
253
+ callback_steps=callback_steps,
254
+ cross_attention_kwargs=cross_attention_kwargs)
255
+ latents, structure_latents, appearance_latents = baked_latents
256
+
257
+ # For passing important information onto the refiner
258
+ self.refiner_args = {"latents": latents.detach(), "prompt": prompt, "negative_prompt": negative_prompt}
259
+
260
+ if not output_type == "latent":
261
+ # Make sure the VAE is in float32 mode, as it overflows in float16
262
+ if self.vae.config.force_upcast:
263
+ self.upcast_vae()
264
+ vae_dtype = next(iter(self.vae.post_quant_conv.parameters())).dtype
265
+ latents = latents.to(vae_dtype)
266
+ structure_latents = structure_latents.to(vae_dtype)
267
+ appearance_latents = appearance_latents.to(vae_dtype)
268
+
269
+ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
270
+ image = self.image_processor.postprocess(image, output_type=output_type)
271
+ if decode_structure:
272
+ structure = self.vae.decode(structure_latents / self.vae.config.scaling_factor, return_dict=False)[0]
273
+ structure = self.image_processor.postprocess(structure, output_type=output_type)
274
+ else:
275
+ structure = structure_latents
276
+ if decode_appearance:
277
+ appearance = self.vae.decode(appearance_latents / self.vae.config.scaling_factor, return_dict=False)[0]
278
+ appearance = self.image_processor.postprocess(appearance, output_type=output_type)
279
+ else:
280
+ appearance = appearance_latents
281
+
282
+ # Cast back to fp16 if needed
283
+ if self.vae.config.force_upcast:
284
+ self.vae.to(dtype=torch.float16)
285
+ else:
286
+ return CtrlXStableDiffusionXLPipelineOutput(
287
+ images=latents, structures=structure_latents, appearances=appearance_latents
288
+ )
289
+
290
+ # Offload all models
291
+ self.maybe_free_model_hooks()
292
+
293
+ if not return_dict:
294
+ return image, structure, appearance
295
+
296
+ return CtrlXStableDiffusionXLPipelineOutput(images=image, structures=structure, appearances=appearance)
297
+
298
+ def cfg_loop(self,
299
+ batch_order,
300
+ prompt_embeds_list,
301
+ add_text_embeds_list,
302
+ add_time_ids,
303
+ latents_,
304
+ clean_latents_,
305
+ num_inference_steps,
306
+ num_warmup_steps,
307
+ extra_step_kwargs,
308
+ timesteps,
309
+ timestep_cond=None,
310
+ control_schedule=None,
311
+ self_recurrence_schedule=None,
312
+ guidance_rescale=0.0,
313
+ callback=None,
314
+ callback_steps=None,
315
+ callback_on_step_end=None,
316
+ callback_on_step_end_tensor_inputs=None,
317
+ cross_attention_kwargs=None):
318
+ prompt_embeds, structure_prompt_embeds, appearance_prompt_embeds = prompt_embeds_list
319
+ add_text_embeds, structure_add_text_embeds, appearance_add_text_embeds = add_text_embeds_list
320
+ latents, structure_latents, appearance_latents = latents_
321
+ clean_structure_latents, clean_appearance_latents = clean_latents_
322
+ structure_control_stop_i, appearance_control_stop_i = get_last_control_i(control_schedule, num_inference_steps)
323
+
324
+ if not self_recurrence_schedule: # None or empty
325
+ self_recurrence_schedule = [0] * num_inference_steps
326
+
327
+ self._num_timesteps = len(timesteps)
328
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
329
+ for i, t in enumerate(timesteps):
330
+ if hasattr(self, 'interrupt') and self.interrupt:
331
+ continue
332
+
333
+ if i == structure_control_stop_i: # If not generating structure/appearance, drop after last control
334
+ if "structure_uncond" not in batch_order:
335
+ batch_order.remove("structure_cond")
336
+ if i == appearance_control_stop_i:
337
+ if "appearance_uncond" not in batch_order:
338
+ batch_order.remove("appearance_cond")
339
+
340
+ register_attr(self, t=t.item(), do_control=True, batch_order=batch_order)
341
+
342
+ # Scale the model inputs for the current timestep; the CFG batches are assembled below.
343
+ latent_model_input = self.scheduler.scale_model_input(latents, t)
344
+ structure_latent_model_input = self.scheduler.scale_model_input(structure_latents, t)
345
+ appearance_latent_model_input = self.scheduler.scale_model_input(appearance_latents, t)
346
+
348
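+ # One entry per batch type; with batch_size fixed to 1, each slice below selects a single sample.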
+ all_latent_model_input = {
349
+ "structure_uncond": structure_latent_model_input[0:1],
350
+ "appearance_uncond": appearance_latent_model_input[0:1],
351
+ "uncond": latent_model_input[0:1],
352
+ "structure_cond": structure_latent_model_input[0:1],
353
+ "appearance_cond": appearance_latent_model_input[0:1],
354
+ "cond": latent_model_input[0:1],
355
+ }
356
+ all_prompt_embeds = {
357
+ "structure_uncond": structure_prompt_embeds[0:1],
358
+ "appearance_uncond": appearance_prompt_embeds[0:1],
359
+ "uncond": prompt_embeds[0:1],
360
+ "structure_cond": structure_prompt_embeds[1:2],
361
+ "appearance_cond": appearance_prompt_embeds[1:2],
362
+ "cond": prompt_embeds[1:2],
363
+ }
364
+ all_add_text_embeds = {
365
+ "structure_uncond": structure_add_text_embeds[0:1],
366
+ "appearance_uncond": appearance_add_text_embeds[0:1],
367
+ "uncond": add_text_embeds[0:1],
368
+ "structure_cond": structure_add_text_embeds[1:2],
369
+ "appearance_cond": appearance_add_text_embeds[1:2],
370
+ "cond": add_text_embeds[1:2],
371
+ }
372
+ all_time_ids = {
373
+ "structure_uncond": add_time_ids[0:1],
374
+ "appearance_uncond": add_time_ids[0:1],
375
+ "uncond": add_time_ids[0:1],
376
+ "structure_cond": add_time_ids[1:2],
377
+ "appearance_cond": add_time_ids[1:2],
378
+ "cond": add_time_ids[1:2],
379
+ }
380
+
381
+ concat_latent_model_input = batch_dict_to_tensor(all_latent_model_input, batch_order)
382
+ concat_prompt_embeds = batch_dict_to_tensor(all_prompt_embeds, batch_order)
383
+ concat_add_text_embeds = batch_dict_to_tensor(all_add_text_embeds, batch_order)
384
+ concat_add_time_ids = batch_dict_to_tensor(all_time_ids, batch_order)
385
+
386
+ # Predict the noise residual
387
+ added_cond_kwargs = {"text_embeds": concat_add_text_embeds, "time_ids": concat_add_time_ids}
388
+
389
+ concat_noise_pred = self.unet(
390
+ concat_latent_model_input,
391
+ t,
392
+ encoder_hidden_states=concat_prompt_embeds,
393
+ timestep_cond=timestep_cond,
394
+ cross_attention_kwargs=cross_attention_kwargs,
395
+ added_cond_kwargs=added_cond_kwargs,
396
+ ).sample
397
+ all_noise_pred = batch_tensor_to_dict(concat_noise_pred, batch_order)
398
+
399
+ # Classifier-free guidance
400
+ noise_pred = all_noise_pred["uncond"] +\
401
+ self.guidance_scale * (all_noise_pred["cond"] - all_noise_pred["uncond"])
402
+
403
+ structure_noise_pred = all_noise_pred["structure_cond"]\
404
+ if "structure_cond" in batch_order else noise_pred
405
+ if "structure_uncond" in all_noise_pred:
406
+ structure_noise_pred = all_noise_pred["structure_uncond"] +\
407
+ self.structure_guidance_scale * (structure_noise_pred - all_noise_pred["structure_uncond"])
408
+
409
+ appearance_noise_pred = all_noise_pred["appearance_cond"]\
410
+ if "appearance_cond" in batch_order else noise_pred
411
+ if "appearance_uncond" in all_noise_pred:
412
+ appearance_noise_pred = all_noise_pred["appearance_uncond"] +\
413
+ self.appearance_guidance_scale * (appearance_noise_pred - all_noise_pred["appearance_uncond"])
414
+
415
+ if guidance_rescale > 0.0:
416
+ noise_pred = rescale_noise_cfg(
417
+ noise_pred, all_noise_pred["cond"], guidance_rescale=guidance_rescale
418
+ )
419
+ if "structure_uncond" in all_noise_pred:
420
+ structure_noise_pred = rescale_noise_cfg(
421
+ structure_noise_pred, all_noise_pred["structure_cond"],
422
+ guidance_rescale=guidance_rescale
423
+ )
424
+ if "appearance_uncond" in all_noise_pred:
425
+ appearance_noise_pred = rescale_noise_cfg(
426
+ appearance_noise_pred, all_noise_pred["appearance_cond"],
427
+ guidance_rescale=guidance_rescale
428
+ )
429
+
430
+ # Compute the previous noisy sample x_t -> x_t-1
431
+ concat_noise_pred = torch.cat(
432
+ [structure_noise_pred, appearance_noise_pred, noise_pred], dim=0,
433
+ )
434
+ concat_latents = torch.cat(
435
+ [structure_latents, appearance_latents, latents], dim=0,
436
+ )
437
+ structure_latents, appearance_latents, latents = self.scheduler.step(
438
+ concat_noise_pred, t, concat_latents, **extra_step_kwargs,
439
+ ).prev_sample.chunk(3)
440
+
441
+ if clean_structure_latents is not None:
442
+ structure_latents = noise_prev(self.scheduler, t, clean_structure_latents)
443
+ if clean_appearance_latents is not None:
444
+ appearance_latents = noise_prev(self.scheduler, t, clean_appearance_latents)
445
+
446
+ # Self-recurrence
447
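+ # Re-noise the sample back up to t and denoise it again without control, repeating as scheduled for this step.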
+ for _ in range(self_recurrence_schedule[i]):
448
+ if hasattr(self.scheduler, "_step_index"): # For fancier schedulers
449
+ self.scheduler._step_index -= 1 # TODO: Does this actually work?
450
+
451
+ t_prev = 0 if i + 1 >= num_inference_steps else timesteps[i + 1]
452
+ latents = noise_t2t(self.scheduler, t_prev, t, latents)
453
+ latent_model_input = torch.cat([latents] * 2)
454
+
455
+ register_attr(self, t=t.item(), do_control=False, batch_order=["uncond", "cond"])
456
+
457
+ # Predict the noise residual
458
+ added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
459
+ noise_pred_uncond, noise_pred_ = self.unet(
460
+ latent_model_input,
461
+ t,
462
+ encoder_hidden_states=prompt_embeds,
463
+ timestep_cond=timestep_cond,
464
+ cross_attention_kwargs=cross_attention_kwargs,
465
+ added_cond_kwargs=added_cond_kwargs,
466
+ ).sample.chunk(2)
467
+ noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_ - noise_pred_uncond)
468
+
469
+ if guidance_rescale > 0.0:
470
+ noise_pred = rescale_noise_cfg(noise_pred, noise_pred_, guidance_rescale=guidance_rescale)
471
+
472
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
473
+
474
+ # Callbacks
475
+ assert callback_on_step_end is None
476
+
477
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
478
+ progress_bar.update()
479
+ if callback is not None and i % callback_steps == 0:
480
+ step_idx = i // getattr(self.scheduler, "order", 1)
481
+ callback(step_idx, t, latents)
482
+
483
+ # "Reconstruction"
484
+ if clean_structure_latents is not None:
485
+ structure_latents = clean_structure_latents
486
+ if clean_appearance_latents is not None:
487
+ appearance_latents = clean_appearance_latents
488
+
489
+ return latents, structure_latents, appearance_latents
490
+
491
+ @property
492
+ def appearance_guidance_scale(self):
493
+ return self._guidance_scale if self._appearance_guidance_scale is None else self._appearance_guidance_scale
494
+
495
+ @property
496
+ def structure_guidance_scale(self):
497
+ return self._guidance_scale if self._structure_guidance_scale is None else self._structure_guidance_scale
498
+
499
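+ # Returns (initial noise, clean image latents); the second element is None when no image is given.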
+ def prepare_latents(self, image, batch_size, num_images_per_prompt, num_channels_latents, height, width,
500
+ dtype, device, generator=None, noise=None):
501
+ batch_size = batch_size * num_images_per_prompt
502
+
503
+ if noise is None:
504
+ shape = (
505
+ batch_size,
506
+ num_channels_latents,
507
+ height // self.vae_scale_factor,
508
+ width // self.vae_scale_factor
509
+ )
510
+ noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
511
+ noise = noise * self.scheduler.init_noise_sigma # Starting noise, need to scale
512
+ else:
513
+ noise = noise.to(device)
514
+
515
+ if image is None:
516
+ return noise, None
517
+
518
+ if not isinstance(image, (torch.Tensor, Image.Image, list)):
519
+ raise ValueError(
520
+ f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
521
+ )
522
+
523
+ # Offload text encoder if `enable_model_cpu_offload` was enabled
524
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
525
+ self.text_encoder_2.to("cpu")
526
+ torch.cuda.empty_cache()
527
+
528
+ image = image.to(device=device, dtype=dtype)
529
+
530
+ if image.shape[1] == 4: # Image already in latents form
531
+ init_latents = image
532
+
533
+ else:
534
+ # Make sure the VAE is in float32 mode, as it overflows in float16
535
+ if self.vae.config.force_upcast:
536
+ image = image.to(torch.float32)
537
+ self.vae.to(torch.float32)
538
+
539
+ if isinstance(generator, list) and len(generator) != batch_size:
540
+ raise ValueError(
541
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
542
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
543
+ )
544
+ elif isinstance(generator, list):
545
+ init_latents = [
546
+ retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
547
+ for i in range(batch_size)
548
+ ]
549
+ init_latents = torch.cat(init_latents, dim=0)
550
+ else:
551
+ init_latents = retrieve_latents(self.vae.encode(image), generator=generator)
552
+
553
+ if self.vae.config.force_upcast:
554
+ self.vae.to(dtype)
555
+
556
+ init_latents = init_latents.to(dtype)
557
+ init_latents = self.vae.config.scaling_factor * init_latents
558
+
559
+ if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0:
560
+ # Expand init_latents for batch_size
561
+ additional_image_per_prompt = batch_size // init_latents.shape[0]
562
+ init_latents = torch.cat([init_latents] * additional_image_per_prompt, dim=0)
563
+ elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0:
564
+ raise ValueError(
565
+ f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts."
566
+ )
567
+ else:
568
+ init_latents = torch.cat([init_latents], dim=0)
569
+
570
+ return noise, init_latents
run_ctrlx.py ADDED
@@ -0,0 +1,218 @@
1
+ from argparse import ArgumentParser
2
+ from datetime import datetime
3
+ from diffusers import DDIMScheduler, StableDiffusionXLImg2ImgPipeline
4
+ from diffusers.utils import load_image
5
+ from os import makedirs, path
6
+ from pipelines.pipeline_sdxl import CtrlXStableDiffusionXLPipeline
7
+ import torch
8
+ from time import time
9
+ from utils import *
10
+ from utils.media import preprocess
11
+ from utils.sdxl import *
12
+ import yaml
13
+
14
+
15
+ @torch.no_grad()
16
+ def inference(
17
+ pipe, refiner, device,
18
+ structure_image, appearance_image,
19
+ prompt, structure_prompt, appearance_prompt,
20
+ positive_prompt, negative_prompt,
21
+ guidance_scale, structure_guidance_scale, appearance_guidance_scale,
22
+ num_inference_steps, eta, seed,
23
+ width, height,
24
+ structure_schedule, appearance_schedule,
25
+ ):
26
+ seed_everything(seed)
27
+
28
+ # Process images.
29
+ # Moved from CtrlXStableDiffusionXLPipeline.__call__.
30
+ if structure_image is not None and isinstance(structure_image, str):
31
+ structure_image = load_image(structure_image)
32
+ structure_image = preprocess(structure_image, pipe.image_processor,
33
+ height=height, width=width, resize_mode="crop")
34
+ if appearance_image is not None:
35
+ appearance_image = load_image(appearance_image)
36
+ appearance_image = preprocess(appearance_image, pipe.image_processor,
37
+ height=height, width=width, resize_mode="crop")
38
+
39
+
40
+ # Scheduler.
41
+ pipe.scheduler.set_timesteps(num_inference_steps, device=device)
42
+ timesteps = pipe.scheduler.timesteps
43
+ control_config = get_control_config(structure_schedule, appearance_schedule)
44
+ print(f"\nUsing the following control config:\n{control_config}\n")
45
+ config = yaml.safe_load(control_config)
46
+ register_control(
47
+ model=pipe,
48
+ timesteps=timesteps,
49
+ control_schedule=config["control_schedule"],
50
+ control_target=config["control_target"],
51
+ )
52
+
53
+ # Pipe settings.
54
+ pipe.safety_checker = None
55
+ pipe.requires_safety_checker = False
56
+ self_recurrence_schedule = get_self_recurrence_schedule(config["self_recurrence_schedule"], num_inference_steps)
57
+ pipe.set_progress_bar_config(desc="Ctrl-X inference")
58
+
59
+ # Inference.
60
+ result, structure, appearance = pipe(
61
+ prompt=prompt,
62
+ structure_prompt=structure_prompt,
63
+ appearance_prompt=appearance_prompt,
64
+ structure_image=structure_image,
65
+ appearance_image=appearance_image,
66
+ num_inference_steps=num_inference_steps,
67
+ negative_prompt=negative_prompt,
68
+ positive_prompt=positive_prompt,
69
+ height=height,
70
+ width=width,
71
+ guidance_scale=guidance_scale,
72
+ structure_guidance_scale=structure_guidance_scale,
73
+ appearance_guidance_scale=appearance_guidance_scale,
74
+ eta=eta,
75
+ output_type="pil",
76
+ return_dict=False,
77
+ control_schedule=config["control_schedule"],
78
+ self_recurrence_schedule=self_recurrence_schedule,
79
+ )
80
+ result_refiner = [None]
81
+
82
+ del pipe.refiner_args
83
+
84
+ return result[0], result_refiner[0], structure[0], appearance[0]
85
+
86
+
87
+ @torch.no_grad()
88
+ def main(args):
89
+ torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
90
+
91
+ model_id_or_path = "/mnt/newhome/SSD-1B"
92
+ # refiner_id_or_path = "stabilityai/stable-diffusion-xl-refiner-1.0"
93
+ device = "cuda" if torch.cuda.is_available() else "cpu"
94
+ variant = "fp16" if device == "cuda" else "fp32"
95
+
96
+ scheduler = DDIMScheduler.from_config(model_id_or_path, subfolder="scheduler")
97
+
98
+ if args.model is None:
99
+ pipe = CtrlXStableDiffusionXLPipeline.from_pretrained(
100
+ model_id_or_path, scheduler=scheduler, torch_dtype=torch_dtype, variant=variant, use_safetensors=True,
101
+ )
102
+ else:
103
+ print(f"Using weights {args.model} for SDXL base model.")
104
+ pipe = CtrlXStableDiffusionXLPipeline.from_single_file(args.model, scheduler=scheduler, torch_dtype=torch_dtype)
105
+
106
+ if args.model_offload or args.sequential_offload:
107
+ try:
108
+ import accelerate # Checking if accelerate is installed for Model/CPU offloading
109
+ except ImportError:
110
+ raise ModuleNotFoundError("`accelerate` must be installed for Model/CPU offloading.")
111
+
112
+ if args.sequential_offload:
113
+ pipe.enable_sequential_cpu_offload()
114
+ elif args.model_offload:
115
+ pipe.enable_model_cpu_offload()
116
+ else:
117
+ pipe = pipe.to(device)
118
+
119
+ model_load_print = "Base model "
120
+ if not args.disable_refiner:
121
+ model_load_print += "+ refiner "
122
+ if args.sequential_offload:
123
+ model_load_print += "loaded with sequential CPU offloading."
124
+ elif args.model_offload:
125
+ model_load_print += "loaded with model CPU offloading."
126
+ else:
127
+ model_load_print += "loaded."
128
+ print(f"{model_load_print} Running on device: {device}.")
129
+
130
+ t = time()
131
+
132
+ result, result_refiner, structure, appearance = inference(
133
+ pipe=pipe,
134
+ refiner=None,
135
+ device=device,
136
+ structure_image=args.structure_image,
137
+ appearance_image=args.appearance_image,
138
+ prompt=args.prompt,
139
+ structure_prompt=args.structure_prompt,
140
+ appearance_prompt=args.appearance_prompt,
141
+ positive_prompt=args.positive_prompt,
142
+ negative_prompt=args.negative_prompt,
143
+ guidance_scale=args.guidance_scale,
144
+ structure_guidance_scale=args.structure_guidance_scale,
145
+ appearance_guidance_scale=args.appearance_guidance_scale,
146
+ num_inference_steps=args.num_inference_steps,
147
+ eta=args.eta,
148
+ seed=args.seed,
149
+ width=args.width,
150
+ height=args.height,
151
+ structure_schedule=args.structure_schedule,
152
+ appearance_schedule=args.appearance_schedule,
153
+ )
154
+
155
+ makedirs(args.output_folder, exist_ok=True)
156
+ prefix = "ctrlx__" + datetime.now().strftime("%Y%m%d_%H%M%S")
157
+ structure.save(path.join(args.output_folder, f"{prefix}__structure.jpg"), quality=JPEG_QUALITY)
158
+ appearance.save(path.join(args.output_folder, f"{prefix}__appearance.jpg"), quality=JPEG_QUALITY)
159
+ result.save(path.join(args.output_folder, f"{prefix}__result.jpg"), quality=JPEG_QUALITY)
160
+ if result_refiner is not None:
161
+ result_refiner.save(path.join(args.output_folder, f"{prefix}__result_refiner.jpg"), quality=JPEG_QUALITY)
162
+
163
+ if args.benchmark:
164
+ inference_time = time() - t
165
+ peak_memory_usage = torch.cuda.max_memory_reserved()
166
+ print(f"Inference time: {inference_time:.2f}s")
167
+ print(f"Peak memory usage: {peak_memory_usage / pow(1024, 3):.2f}GiB")
168
+
169
+ print("Done.")
170
+
171
+
172
+ if __name__ == "__main__":
173
+ parser = ArgumentParser()
174
+
175
+ parser.add_argument("--structure_image", "-si", type=str, default=None)
176
+ parser.add_argument("--appearance_image", "-ai", type=str, default=None)
177
+
178
+ parser.add_argument("--prompt", "-p", type=str, required=True)
179
+ parser.add_argument("--structure_prompt", "-sp", type=str, default="")
180
+ parser.add_argument("--appearance_prompt", "-ap", type=str, default="")
181
+
182
+ parser.add_argument("--positive_prompt", "-pp", type=str, default="high quality")
183
+ parser.add_argument("--negative_prompt", "-np", type=str, default="ugly, blurry, dark, low res, unrealistic")
184
+
185
+ parser.add_argument("--guidance_scale", "-g", type=float, default=5.0)
186
+ parser.add_argument("--structure_guidance_scale", "-sg", type=float, default=5.0)
187
+ parser.add_argument("--appearance_guidance_scale", "-ag", type=float, default=5.0)
188
+
189
+ parser.add_argument("--num_inference_steps", "-n", type=int, default=50)
190
+ parser.add_argument("--eta", "-e", type=float, default=1.0)
191
+ parser.add_argument("--seed", "-s", type=int, default=90095)
192
+
193
+ parser.add_argument("--width", "-W", type=int, default=1024)
194
+ parser.add_argument("--height", "-H", type=int, default=1024)
195
+
196
+ parser.add_argument("--structure_schedule", "-ss", type=float, default=0.6)
197
+ parser.add_argument("--appearance_schedule", "-as", type=float, default=0.6)
198
+
199
+ parser.add_argument("--output_folder", "-o", type=str, default="./results")
200
+
201
+ parser.add_argument(
202
+ "-mo", "--model_offload", action="store_true",
203
+ help="Model CPU offload, lowers memory usage with slight runtime increase. `accelerate` must be installed.",
204
+ )
205
+ parser.add_argument(
206
+ "-so", "--sequential_offload", action="store_true",
207
+ help=(
208
+ "Sequential layer CPU offload, significantly lowers memory usage with massive runtime increase."
209
+ "`accelerate` must be installed. If both model_offload and sequential_offload are set, then use the latter."
210
+ ),
211
+ )
212
+ parser.add_argument("-r", "--disable_refiner", action="store_true")
213
+ parser.add_argument("-m", "--model", type=str, default=None, help="Optionally, load model safetensors.")
214
+ parser.add_argument("-b", "--benchmark", action="store_true", help="Show inference time and max memory usage.")
215
+
216
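+ # Example invocation (hypothetical prompt, images from this repo): python run_ctrlx.py -si images/horse__point_cloud.jpg -ai images/horse.jpg -p "a photo of a horse" -o ./results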
+ args = parser.parse_args()
217
+ main(args)
218
+
utils/__init__.py ADDED
@@ -0,0 +1,3 @@
1
+ from .feature import *
2
+ from .media import *
3
+ from .utils import *
utils/feature.py ADDED
@@ -0,0 +1,73 @@
1
+ import math
2
+
3
+ import torch.nn.functional as F
4
+
5
+ from .utils import *
6
+
7
+
8
+ def get_schedule(timesteps, schedule):
9
+ end = round(len(timesteps) * schedule)
10
+ timesteps = timesteps[:end]
11
+ return timesteps
12
+
13
+
14
+ def get_elem(l, i, default=0.0):
15
+ if i >= len(l):
16
+ return default
17
+ return l[i]
18
+
19
+
20
+ def pad_list(l_1, l_2, pad=0.0):
21
+ max_len = max(len(l_1), len(l_2))
22
+ l_1 = l_1 + [pad] * (max_len - len(l_1))
23
+ l_2 = l_2 + [pad] * (max_len - len(l_2))
24
+ return l_1, l_2
25
+
26
+
27
+ def normalize(x, dim):
28
+ x_mean = x.mean(dim=dim, keepdim=True)
29
+ x_std = x.std(dim=dim, keepdim=True)
30
+ x_normalized = (x - x_mean) / x_std
31
+ return x_normalized
32
+
33
+
34
+ # https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
35
+ def appearance_mean_std(q_c_normed, k_s_normed, v_s): # c: content, s: style
36
+ q_c = q_c_normed # q_c and k_s must be projected from normalized features
37
+ k_s = k_s_normed
38
+ mean = F.scaled_dot_product_attention(q_c, k_s, v_s) # Use scaled_dot_product_attention for efficiency
39
+ std = (F.scaled_dot_product_attention(q_c, k_s, v_s.square()) - mean.square()).relu().sqrt()
40
+
41
+ return mean, std
42
+
43
+
44
+ def feature_injection(features, batch_order):
45
+ assert features.shape[0] % len(batch_order) == 0
46
+ features_dict = batch_tensor_to_dict(features, batch_order)
47
+ features_dict["cond"] = features_dict["structure_cond"]
48
+ features = batch_dict_to_tensor(features_dict, batch_order)
49
+ return features
50
+
51
+
52
+ def appearance_transfer(features, q_normed, k_normed, batch_order, v=None, reshape_fn=None):
53
+ assert features.shape[0] % len(batch_order) == 0
54
+
55
+ features_dict = batch_tensor_to_dict(features, batch_order)
56
+ q_normed_dict = batch_tensor_to_dict(q_normed, batch_order)
57
+ k_normed_dict = batch_tensor_to_dict(k_normed, batch_order)
58
+ v_dict = features_dict
59
+ if v is not None:
60
+ v_dict = batch_tensor_to_dict(v, batch_order)
61
+
62
+ mean_cond, std_cond = appearance_mean_std(
63
+ q_normed_dict["cond"], k_normed_dict["appearance_cond"], v_dict["appearance_cond"],
64
+ )
65
+
66
+ if reshape_fn is not None:
67
+ mean_cond = reshape_fn(mean_cond)
68
+ std_cond = reshape_fn(std_cond)
69
+
70
+ features_dict["cond"] = std_cond * normalize(features_dict["cond"], dim=-2) + mean_cond
71
+
72
+ features = batch_dict_to_tensor(features_dict, batch_order)
73
+ return features
utils/media.py ADDED
@@ -0,0 +1,21 @@
1
+ import numpy as np
2
+ import torch
3
+ import torchvision.transforms.functional as vF
4
+ import PIL
5
+
6
+
7
+ JPEG_QUALITY = 95
8
+
9
+
10
+ def preprocess(image, processor, **kwargs):
11
+ if isinstance(image, PIL.Image.Image):
12
+ pass
13
+ elif isinstance(image, np.ndarray):
14
+ image = PIL.Image.fromarray(image)
15
+ elif isinstance(image, torch.Tensor):
16
+ image = vF.to_pil_image(image)
17
+ else:
18
+ raise TypeError(f"Image must be of type PIL.Image, np.ndarray, or torch.Tensor, got {type(image)} instead.")
19
+
20
+ image = processor.preprocess(image, **kwargs)
21
+ return image
utils/sdxl.py ADDED
@@ -0,0 +1,302 @@
1
+ from types import MethodType
2
+ from typing import Optional
3
+
4
+ from diffusers.models.attention_processor import Attention
5
+ import torch
6
+ import torch.nn.functional as F
7
+
8
+ from .feature import *
9
+ from .utils import *
10
+
11
+
12
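+ # Build the YAML control config; each number is the fraction of denoising steps during which that layer's control stays active.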
+ def get_control_config(structure_schedule, appearance_schedule):
13
+ s = structure_schedule
14
+ a = appearance_schedule
15
+
16
+ control_config =\
17
+ f"""control_schedule:
18
+ # structure_conv structure_attn appearance_attn conv/attn
19
+ encoder: # (num layers)
20
+ 0: [[ ], [ ], [ ]] # 2/0
21
+ 1: [[ ], [ ], [{a}, {a} ]] # 2/2
22
+ 2: [[ ], [ ], [{a}, {a} ]] # 2/2
23
+ middle: [[ ], [ ], [ ]] # 2/1
24
+ decoder:
25
+ 0: [[{s} ], [{s}, {s}, {s}], [0.0, {a}, {a}]] # 3/3
26
+ 1: [[ ], [ ], [{a}, {a} ]] # 3/3
27
+ 2: [[ ], [ ], [ ]] # 3/0
28
+
29
+ control_target:
30
+ - [output_tensor] # structure_conv choices: {{hidden_states, output_tensor}}
31
+ - [query, key] # structure_attn choices: {{query, key, value}}
32
+ - [before] # appearance_attn choices: {{before, value, after}}
33
+
34
+ self_recurrence_schedule:
35
+ - [0.1, 0.5, 2] # format: [start, end, num_recurrence]"""
36
+
37
+ return control_config
38
+
39
+
40
+ def convolution_forward( # From <class 'diffusers.models.resnet.ResnetBlock2D'>, forward (diffusers==0.28.0)
41
+ self,
42
+ input_tensor: torch.Tensor,
43
+ temb: torch.Tensor,
44
+ *args,
45
+ **kwargs,
46
+ ) -> torch.Tensor:
47
+ do_structure_control = self.do_control and self.t in self.structure_schedule
48
+
49
+ hidden_states = input_tensor
50
+
51
+ hidden_states = self.norm1(hidden_states)
52
+ hidden_states = self.nonlinearity(hidden_states)
53
+
54
+ if self.upsample is not None:
55
+ # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
56
+ if hidden_states.shape[0] >= 64:
57
+ input_tensor = input_tensor.contiguous()
58
+ hidden_states = hidden_states.contiguous()
59
+ input_tensor = self.upsample(input_tensor)
60
+ hidden_states = self.upsample(hidden_states)
61
+ elif self.downsample is not None:
62
+ input_tensor = self.downsample(input_tensor)
63
+ hidden_states = self.downsample(hidden_states)
64
+
65
+ hidden_states = self.conv1(hidden_states)
66
+
67
+ if self.time_emb_proj is not None:
68
+ if not self.skip_time_act:
69
+ temb = self.nonlinearity(temb)
70
+ temb = self.time_emb_proj(temb)[:, :, None, None]
71
+
72
+ if self.time_embedding_norm == "default":
73
+ if temb is not None:
74
+ hidden_states = hidden_states + temb
75
+ hidden_states = self.norm2(hidden_states)
76
+ elif self.time_embedding_norm == "scale_shift":
77
+ if temb is None:
78
+ raise ValueError(
79
+ f" `temb` should not be None when `time_embedding_norm` is {self.time_embedding_norm}"
80
+ )
81
+ time_scale, time_shift = torch.chunk(temb, 2, dim=1)
82
+ hidden_states = self.norm2(hidden_states)
83
+ hidden_states = hidden_states * (1 + time_scale) + time_shift
84
+ else:
85
+ hidden_states = self.norm2(hidden_states)
86
+
87
+ hidden_states = self.nonlinearity(hidden_states)
88
+
89
+ hidden_states = self.dropout(hidden_states)
90
+ hidden_states = self.conv2(hidden_states)
91
+
92
+ # Feature injection and AdaIN (hidden_states)
93
+ if do_structure_control and "hidden_states" in self.structure_target:
94
+ hidden_states = feature_injection(hidden_states, batch_order=self.batch_order)
95
+
96
+ if self.conv_shortcut is not None:
97
+ input_tensor = self.conv_shortcut(input_tensor)
98
+
99
+ output_tensor = (input_tensor + hidden_states) / self.output_scale_factor
100
+
101
+ # Feature injection and AdaIN (output_tensor)
102
+ if do_structure_control and "output_tensor" in self.structure_target:
103
+ output_tensor = feature_injection(output_tensor, batch_order=self.batch_order)
104
+
105
+ return output_tensor
106
+
107
+
108
+ class AttnProcessor2_0: # From <class 'diffusers.models.attention_processor.AttnProcessor2_0'> (diffusers==0.28.0)
109
+
110
+ def __init__(self):
111
+ if not hasattr(F, "scaled_dot_product_attention"):
112
+ raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
113
+
114
+ def __call__(
115
+ self,
116
+ attn: Attention,
117
+ hidden_states: torch.FloatTensor,
118
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
119
+ attention_mask: Optional[torch.FloatTensor] = None,
120
+ temb: Optional[torch.FloatTensor] = None,
121
+ *args,
122
+ **kwargs,
123
+ ) -> torch.FloatTensor:
124
+ do_structure_control = attn.do_control and attn.t in attn.structure_schedule
125
+ do_appearance_control = attn.do_control and attn.t in attn.appearance_schedule
126
+
127
+ residual = hidden_states
128
+ if attn.spatial_norm is not None:
129
+ hidden_states = attn.spatial_norm(hidden_states, temb)
130
+
131
+ input_ndim = hidden_states.ndim
132
+
133
+ if input_ndim == 4:
134
+ batch_size, channel, height, width = hidden_states.shape
135
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
136
+
137
+ batch_size, sequence_length, _ = (
138
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
139
+ )
140
+
141
+ if attention_mask is not None:
142
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
143
+ # scaled_dot_product_attention expects attention_mask shape to be
144
+ # (batch, heads, source_length, target_length)
145
+ attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
146
+
147
+ if attn.group_norm is not None:
148
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
149
+
150
+ no_encoder_hidden_states = encoder_hidden_states is None
151
+ if no_encoder_hidden_states:
152
+ encoder_hidden_states = hidden_states
153
+ elif attn.norm_cross:
154
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
155
+
156
+ if do_appearance_control: # Assume we only have this for self attention
157
+ hidden_states_normed = normalize(hidden_states, dim=-2) # B H D C
158
+ encoder_hidden_states_normed = normalize(encoder_hidden_states, dim=-2)
159
+
160
+ query_normed = attn.to_q(hidden_states_normed)
161
+ key_normed = attn.to_k(encoder_hidden_states_normed)
162
+
163
+ inner_dim = key_normed.shape[-1]
164
+ head_dim = inner_dim // attn.heads
165
+ query_normed = query_normed.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
166
+ key_normed = key_normed.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
167
+
168
+ # Match query and key injection with structure injection (if injection is happening this layer)
169
+ if do_structure_control:
170
+ if "query" in attn.structure_target:
171
+ query_normed = feature_injection(query_normed, batch_order=attn.batch_order)
172
+ if "key" in attn.structure_target:
173
+ key_normed = feature_injection(key_normed, batch_order=attn.batch_order)
174
+
175
+ # Appearance transfer (before)
176
+ if do_appearance_control and "before" in attn.appearance_target:
177
+ hidden_states = hidden_states.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
178
+ hidden_states = appearance_transfer(hidden_states, query_normed, key_normed, batch_order=attn.batch_order)
179
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
180
+
181
+ if no_encoder_hidden_states:
182
+ encoder_hidden_states = hidden_states
183
+ elif attn.norm_cross:
184
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
185
+
186
+ query = attn.to_q(hidden_states)
187
+
188
+ key = attn.to_k(encoder_hidden_states)
189
+ value = attn.to_v(encoder_hidden_states)
190
+
191
+ inner_dim = key.shape[-1]
192
+ head_dim = inner_dim // attn.heads
193
+
194
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
195
+
196
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
197
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
198
+
199
+ # Feature injection (query, key, and/or value)
200
+ if do_structure_control:
201
+ if "query" in attn.structure_target:
202
+ query = feature_injection(query, batch_order=attn.batch_order)
203
+ if "key" in attn.structure_target:
204
+ key = feature_injection(key, batch_order=attn.batch_order)
205
+ if "value" in attn.structure_target:
206
+ value = feature_injection(value, batch_order=attn.batch_order)
207
+
208
+ # Appearance transfer (value)
209
+ if do_appearance_control and "value" in attn.appearance_target:
210
+ value = appearance_transfer(value, query_normed, key_normed, batch_order=attn.batch_order)
211
+
212
+ # The output of sdp = (batch, num_heads, seq_len, head_dim)
213
+ # TODO: add support for attn.scale when we move to Torch 2.1
214
+ hidden_states = F.scaled_dot_product_attention(
215
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
216
+ )
217
+
218
+ # Appearance transfer (after)
219
+ if do_appearance_control and "after" in attn.appearance_target:
220
+ hidden_states = appearance_transfer(hidden_states, query_normed, key_normed, batch_order=attn.batch_order)
221
+
222
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
223
+ hidden_states = hidden_states.to(query.dtype)
224
+
225
+ # Linear projection
226
+ hidden_states = attn.to_out[0](hidden_states, *args)
227
+ # Dropout
228
+ hidden_states = attn.to_out[1](hidden_states)
229
+
230
+ if input_ndim == 4:
231
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
232
+
233
+ if attn.residual_connection:
234
+ hidden_states = hidden_states + residual
235
+
236
+ hidden_states = hidden_states / attn.rescale_output_factor
237
+
238
+ return hidden_states
239
+
240
+
241
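+ # Patch the U-Net's ResNet and self-attention layers in place so they apply feature injection / appearance transfer on the given schedule.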
+ def register_control(
242
+ model,
243
+ timesteps,
244
+ control_schedule, # structure_conv, structure_attn, appearance_attn
245
+ control_target = [["output_tensor"], ["query", "key"], ["before"]],
246
+ ):
247
+ # Assume timesteps in reverse order (T -> 0)
248
+ for block_type in ["encoder", "decoder", "middle"]:
249
+ blocks = {
250
+ "encoder": model.unet.down_blocks,
251
+ "decoder": model.unet.up_blocks,
252
+ "middle": [model.unet.mid_block],
253
+ }[block_type]
254
+
255
+ control_schedule_block = control_schedule[block_type]
256
+ if block_type == "middle":
257
+ control_schedule_block = [control_schedule_block]
258
+
259
+ for layer in range(len(control_schedule_block)):
260
+ # Convolution
261
+ num_blocks = len(blocks[layer].resnets) if hasattr(blocks[layer], "resnets") else 0
262
+ for block in range(num_blocks):
263
+ convolution = blocks[layer].resnets[block]
264
+ convolution.structure_target = control_target[0]
265
+ convolution.structure_schedule = get_schedule(
266
+ timesteps, get_elem(control_schedule_block[layer][0], block)
267
+ )
268
+ convolution.forward = MethodType(convolution_forward, convolution)
269
+
270
+ # Self-attention
271
+ num_blocks = len(blocks[layer].attentions) if hasattr(blocks[layer], "attentions") else 0
272
+ for block in range(num_blocks):
273
+ for transformer_block in blocks[layer].attentions[block].transformer_blocks:
274
+ attention = transformer_block.attn1
275
+ attention.structure_target = control_target[1]
276
+ attention.structure_schedule = get_schedule(
277
+ timesteps, get_elem(control_schedule_block[layer][1], block)
278
+ )
279
+ attention.appearance_target = control_target[2]
280
+ attention.appearance_schedule = get_schedule(
281
+ timesteps, get_elem(control_schedule_block[layer][2], block)
282
+ )
283
+ attention.processor = AttnProcessor2_0()
284
+
285
+
286
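+ # Broadcast the current timestep, control flag, and batch order to every patched module before each U-Net call.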
+ def register_attr(model, t, do_control, batch_order):
287
+ for layer_type in ["encoder", "decoder", "middle"]:
288
+ blocks = {"encoder": model.unet.down_blocks, "decoder": model.unet.up_blocks,
289
+ "middle": [model.unet.mid_block]}[layer_type]
290
+ for layer in blocks:
291
+ # Convolution
292
+ for module in layer.resnets:
293
+ module.t = t
294
+ module.do_control = do_control
295
+ module.batch_order = batch_order
296
+ # Self-attention
297
+ if hasattr(layer, "attentions"):
298
+ for block in layer.attentions:
299
+ for module in block.transformer_blocks:
300
+ module.attn1.t = t
301
+ module.attn1.do_control = do_control
302
+ module.attn1.batch_order = batch_order
utils/utils.py ADDED
@@ -0,0 +1,101 @@
1
+ import random
2
+ from os import environ
3
+
4
+ import numpy as np
5
+ import torch
6
+
7
+
8
+ JPEG_QUALITY = 100
9
+
10
+
11
+ def seed_everything(seed):
12
+ random.seed(seed)
13
+ environ["PYTHONHASHSEED"] = str(seed)
14
+ np.random.seed(seed)
15
+ torch.manual_seed(seed)
16
+ torch.backends.cudnn.deterministic = True
17
+ torch.backends.cudnn.benchmark = False
18
+
19
+
20
+ def exists(x):
21
+ return x is not None
22
+
23
+
24
+ def get(x, default):
25
+ if exists(x):
26
+ return x
27
+ return default
28
+
29
+
30
+ def get_self_recurrence_schedule(schedule, num_inference_steps):
31
+ self_recurrence_schedule = [0] * num_inference_steps
32
+ for schedule_current in reversed(schedule):
33
+ if schedule_current is None or len(schedule_current) == 0:
34
+ continue
35
+ [start, end, repeat] = schedule_current
36
+ start_i = round(num_inference_steps * start)
37
+ end_i = round(num_inference_steps * end)
38
+ for i in range(start_i, end_i):
39
+ self_recurrence_schedule[i] = repeat
40
+ return self_recurrence_schedule
41
+
42
+
43
+ def batch_dict_to_tensor(batch_dict, batch_order):
44
+ batch_tensor = []
45
+ for batch_type in batch_order:
46
+ batch_tensor.append(batch_dict[batch_type])
47
+ batch_tensor = torch.cat(batch_tensor, dim=0)
48
+ return batch_tensor
49
+
50
+
51
+ def batch_tensor_to_dict(batch_tensor, batch_order):
52
+ batch_tensor_chunk = batch_tensor.chunk(len(batch_order))
53
+ batch_dict = {}
54
+ for i, batch_type in enumerate(batch_order):
55
+ batch_dict[batch_type] = batch_tensor_chunk[i]
56
+ return batch_dict
57
+
58
+
59
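+ # Produce x_{t-1} by adding fresh noise to the clean latents x_0 at the scheduler's previous timestep.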
+ def noise_prev(scheduler, timestep, x_0, noise=None):
60
+ if scheduler.num_inference_steps is None:
61
+ raise ValueError(
62
+ "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
63
+ )
64
+
65
+ if noise is None:
66
+ noise = torch.randn_like(x_0).to(x_0)
67
+
68
+ # From DDIMScheduler step function (hopefully this works)
69
+ timestep_i = (scheduler.timesteps == timestep).nonzero(as_tuple=True)[0][0].item()
70
+ if timestep_i + 1 >= scheduler.timesteps.shape[0]: # We are at t = 0 (ish)
71
+ return x_0
72
+ prev_timestep = scheduler.timesteps[timestep_i + 1:timestep_i + 2] # Make sure t is not 0-dim
73
+
74
+ x_t_prev = scheduler.add_noise(x_0, noise, prev_timestep)
75
+ return x_t_prev
76
+
77
+
78
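+ # Re-noise x_t from `timestep` up to the larger `timestep_target` using the ratio of cumulative alphas.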
+ def noise_t2t(scheduler, timestep, timestep_target, x_t, noise=None):
79
+ assert timestep_target >= timestep
80
+ if noise is None:
81
+ noise = torch.randn_like(x_t).to(x_t)
82
+
83
+ alphas_cumprod = scheduler.alphas_cumprod.to(device=x_t.device, dtype=x_t.dtype)
84
+
85
+ timestep = timestep.to(torch.long)
86
+ timestep_target = timestep_target.to(torch.long)
87
+
88
+ alpha_prod_t = alphas_cumprod[timestep]
89
+ alpha_prod_tt = alphas_cumprod[timestep_target]
90
+ alpha_prod = alpha_prod_tt / alpha_prod_t
91
+
92
+ sqrt_alpha_prod = (alpha_prod ** 0.5).flatten()
93
+ while len(sqrt_alpha_prod.shape) < len(x_t.shape):
94
+ sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
95
+
96
+ sqrt_one_minus_alpha_prod = ((1 - alpha_prod) ** 0.5).flatten()
97
+ while len(sqrt_one_minus_alpha_prod.shape) < len(x_t.shape):
98
+ sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
99
+
100
+ x_tt = sqrt_alpha_prod * x_t + sqrt_one_minus_alpha_prod * noise
101
+ return x_tt