fffiloni committed on
Commit 125ed31
1 Parent(s): f61459f

Create injection_main.py

Files changed (1)
  1. injection_main.py +739 -0
injection_main.py ADDED
@@ -0,0 +1,739 @@
+ # %%
+ import argparse, os
+
+
+ import torch
+ import requests
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from PIL import Image
+ from io import BytesIO
+ from tqdm.auto import tqdm
+ from matplotlib import pyplot as plt
+ from torchvision import transforms as tfms
+ from diffusers import (
+     StableDiffusionPipeline,
+     DDIMScheduler,
+     DiffusionPipeline,
+     StableDiffusionXLPipeline,
+ )
+ from diffusers.image_processor import VaeImageProcessor
+ import torchvision
+ import torchvision.transforms as transforms
+ from torchvision.utils import save_image, make_grid
+ import numpy
+ from models import attn_injection
+ from omegaconf import OmegaConf
+ from typing import List, Tuple
+
+ import omegaconf
+ import utils.exp_utils
+ import json
+
+ device = torch.device("cuda")
+
+
+ def _get_text_embeddings(prompt: str, tokenizer, text_encoder, device):
+     # Tokenize text and get embeddings
+     text_inputs = tokenizer(
+         prompt,
+         padding="max_length",
+         max_length=tokenizer.model_max_length,
+         truncation=True,
+         return_tensors="pt",
+     )
+     text_input_ids = text_inputs.input_ids
+
+     with torch.no_grad():
+         prompt_embeds = text_encoder(
+             text_input_ids.to(device),
+             output_hidden_states=True,
+         )
+
+     pooled_prompt_embeds = prompt_embeds[0]
+     prompt_embeds = prompt_embeds.hidden_states[-2]
+     if prompt == "":
+         negative_prompt_embeds = torch.zeros_like(prompt_embeds)
+         negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds)
+         return negative_prompt_embeds, negative_pooled_prompt_embeds
+     return prompt_embeds, pooled_prompt_embeds
+
+
+ def _encode_text_sdxl(model: StableDiffusionXLPipeline, prompt: List[str]):
+     device = model._execution_device
+     (
+         prompt_embeds,
+         pooled_prompt_embeds,
+     ) = _get_text_embeddings(prompt, model.tokenizer, model.text_encoder, device)
+     (
+         prompt_embeds_2,
+         pooled_prompt_embeds_2,
+     ) = _get_text_embeddings(prompt, model.tokenizer_2, model.text_encoder_2, device)
+     prompt_embeds = torch.cat((prompt_embeds, prompt_embeds_2), dim=-1)
+     text_encoder_projection_dim = model.text_encoder_2.config.projection_dim
+     add_time_ids = model._get_add_time_ids(
+         (1024, 1024), (0, 0), (1024, 1024), torch.float16, text_encoder_projection_dim
+     ).to(device)
+     # repeat the time ids for each prompt
+     add_time_ids = add_time_ids.repeat(len(prompt), 1)
+     added_cond_kwargs = {
+         "text_embeds": pooled_prompt_embeds_2,
+         "time_ids": add_time_ids,
+     }
+     return added_cond_kwargs, prompt_embeds
+
+
+ def _encode_text_sdxl_with_negative(
+     model: StableDiffusionXLPipeline, prompt: List[str]
+ ):
+
+     B = len(prompt)
+     added_cond_kwargs, prompt_embeds = _encode_text_sdxl(model, prompt)
+     added_cond_kwargs_uncond, prompt_embeds_uncond = _encode_text_sdxl(
+         model, ["" for _ in range(B)]
+     )
+     prompt_embeds = torch.cat(
+         (
+             prompt_embeds_uncond,
+             prompt_embeds,
+         )
+     )
+     added_cond_kwargs = {
+         "text_embeds": torch.cat(
+             (added_cond_kwargs_uncond["text_embeds"], added_cond_kwargs["text_embeds"])
+         ),
+         "time_ids": torch.cat(
+             (added_cond_kwargs_uncond["time_ids"], added_cond_kwargs["time_ids"])
+         ),
+     }
+     return added_cond_kwargs, prompt_embeds
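+ # Note: the unconditional embeddings are concatenated before the conditional
+ # ones, matching the (uncond, text) order assumed by `noise_pred.chunk(2)` in
+ # the classifier-free guidance code of `sample`, `sample_disentangled`, and
+ # `invert` below.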
+
+
+ # Sample function (regular DDIM)
+ @torch.no_grad()
+ def sample(
+     pipe,
+     prompt,
+     start_step=0,
+     start_latents=None,
+     intermediate_latents=None,
+     guidance_scale=3.5,
+     num_inference_steps=30,
+     num_images_per_prompt=1,
+     do_classifier_free_guidance=True,
+     negative_prompt="",
+     device=device,
+ ):
+     negative_prompt = [""] * len(prompt)
+     # Encode prompt
+     if isinstance(pipe, StableDiffusionPipeline):
+         text_embeddings = pipe._encode_prompt(
+             prompt,
+             device,
+             num_images_per_prompt,
+             do_classifier_free_guidance,
+             negative_prompt,
+         )
+         added_cond_kwargs = None
+     elif isinstance(pipe, StableDiffusionXLPipeline):
+         added_cond_kwargs, text_embeddings = _encode_text_sdxl_with_negative(
+             pipe, prompt
+         )
+
+     # Set num inference steps
+     pipe.scheduler.set_timesteps(num_inference_steps, device=device)
+
+     # Create a random starting point if we don't have one already
+     if start_latents is None:
+         start_latents = torch.randn(1, 4, 64, 64, device=device)
+         start_latents *= pipe.scheduler.init_noise_sigma
+
+     latents = start_latents.clone()
+
+     latents = latents.repeat(len(prompt), 1, 1, 1)
+     # assume that the first latent is used for reconstruction
+     for i in tqdm(range(start_step, num_inference_steps)):
+         latents[0] = intermediate_latents[(-i + 1)]
+         t = pipe.scheduler.timesteps[i]
+
+         # Expand the latents if we are doing classifier free guidance
+         latent_model_input = (
+             torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+         )
+         latent_model_input = pipe.scheduler.scale_model_input(latent_model_input, t)
+
+         # Predict the noise residual
+         noise_pred = pipe.unet(
+             latent_model_input,
+             t,
+             encoder_hidden_states=text_embeddings,
+             added_cond_kwargs=added_cond_kwargs,
+         ).sample
+
+         # Perform guidance
+         if do_classifier_free_guidance:
+             noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+             noise_pred = noise_pred_uncond + guidance_scale * (
+                 noise_pred_text - noise_pred_uncond
+             )
+         latents = pipe.scheduler.step(noise_pred, t, latents).prev_sample
+
+     # Post-processing
+     images = pipe.decode_latents(latents)
+     images = pipe.numpy_to_pil(images)
+
+     return images
+
+
+ # Sample function (regular DDIM), but disentangle the content and style
+ @torch.no_grad()
+ def sample_disentangled(
+     pipe,
+     prompt,
+     start_step=0,
+     start_latents=None,
+     intermediate_latents=None,
+     guidance_scale=3.5,
+     num_inference_steps=30,
+     num_images_per_prompt=1,
+     do_classifier_free_guidance=True,
+     use_content_anchor=True,
+     negative_prompt="",
+     device=device,
+ ):
+     negative_prompt = [""] * len(prompt)
+     vae_decoder = VaeImageProcessor(vae_scale_factor=pipe.vae.config.scaling_factor)
+     # Encode prompt
+     if isinstance(pipe, StableDiffusionPipeline):
+         text_embeddings = pipe._encode_prompt(
+             prompt,
+             device,
+             num_images_per_prompt,
+             do_classifier_free_guidance,
+             negative_prompt,
+         )
+         added_cond_kwargs = None
+     elif isinstance(pipe, StableDiffusionXLPipeline):
+         added_cond_kwargs, text_embeddings = _encode_text_sdxl_with_negative(
+             pipe, prompt
+         )
+
+     # Set num inference steps
+     pipe.scheduler.set_timesteps(num_inference_steps, device=device)
+
+     latent_shape = (
+         (1, 4, 64, 64) if isinstance(pipe, StableDiffusionPipeline) else (1, 4, 64, 64)
+     )
+     generative_latent = torch.randn(latent_shape, device=device)
+     generative_latent *= pipe.scheduler.init_noise_sigma
+
+     latents = start_latents.clone()
+
+     latents = latents.repeat(len(prompt), 1, 1, 1)
+     # randomly initialize the second latent, which is used for uncontrolled generation
+     latents[1] = generative_latent
+     # assume that the first latent is used for reconstruction
+     for i in tqdm(range(start_step, num_inference_steps), desc="Stylizing"):
+
+         if use_content_anchor:
+             latents[0] = intermediate_latents[(-i + 1)]
+         t = pipe.scheduler.timesteps[i]
+
+         # Expand the latents if we are doing classifier free guidance
+         latent_model_input = (
+             torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+         )
+         latent_model_input = pipe.scheduler.scale_model_input(latent_model_input, t)
+
+         # Predict the noise residual
+         noise_pred = pipe.unet(
+             latent_model_input,
+             t,
+             encoder_hidden_states=text_embeddings,
+             added_cond_kwargs=added_cond_kwargs,
+         ).sample
+
+         # Perform guidance
+         if do_classifier_free_guidance:
+             noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+             noise_pred = noise_pred_uncond + guidance_scale * (
+                 noise_pred_text - noise_pred_uncond
+             )
+
+         latents = pipe.scheduler.step(noise_pred, t, latents).prev_sample
+
+     # Post-processing
+     # images = vae_decoder.postprocess(latents)
+     pipe.vae.to(dtype=torch.float32)
+     latents = latents.to(next(iter(pipe.vae.post_quant_conv.parameters())).dtype)
+     latents = 1 / pipe.vae.config.scaling_factor * latents
+     images = pipe.vae.decode(latents, return_dict=False)[0]
+     images = (images / 2 + 0.5).clamp(0, 1)
+     # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
+     images = images.cpu().permute(0, 2, 3, 1).float().numpy()
+     images = pipe.numpy_to_pil(images)
+     if isinstance(pipe, StableDiffusionXLPipeline):
+         pipe.vae.to(dtype=torch.float16)
+
+     return images
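+ # Batch layout used by `sample_disentangled`: index 0 is the content
+ # reconstruction branch (overwritten with the inverted latents each step when
+ # use_content_anchor is True), index 1 is the freshly sampled latent that
+ # produces the uncontrolled style image, and index 2 is the stylized output
+ # (see the prompt lists and `out_name` in the __main__ block below).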
+
+
+ ## Inversion
+ @torch.no_grad()
+ def invert(
+     pipe,
+     start_latents,
+     prompt,
+     guidance_scale=3.5,
+     num_inference_steps=50,
+     num_images_per_prompt=1,
+     do_classifier_free_guidance=True,
+     negative_prompt="",
+     device=device,
+ ):
+
+     # Encode prompt
+     if isinstance(pipe, StableDiffusionPipeline):
+         text_embeddings = pipe._encode_prompt(
+             prompt,
+             device,
+             num_images_per_prompt,
+             do_classifier_free_guidance,
+             negative_prompt,
+         )
+         added_cond_kwargs = None
+         latents = start_latents.clone().detach()
+     elif isinstance(pipe, StableDiffusionXLPipeline):
+         added_cond_kwargs, text_embeddings = _encode_text_sdxl_with_negative(
+             pipe, [prompt]
+         )  # Latents are now the specified start latents
+         latents = start_latents.clone().detach().half()
+
+     # We'll keep a list of the inverted latents as the process goes on
+     intermediate_latents = []
+
+     # Set num inference steps
+     pipe.scheduler.set_timesteps(num_inference_steps, device=device)
+
+     # Reversed timesteps <<<<<<<<<<<<<<<<<<<<
+     timesteps = reversed(pipe.scheduler.timesteps)
+
+     for i in tqdm(
+         range(1, num_inference_steps),
+         total=num_inference_steps - 1,
+         desc="DDIM Inversion",
+     ):
+
+         # We'll skip the final iteration
+         if i >= num_inference_steps - 1:
+             continue
+
+         t = timesteps[i]
+
+         # Expand the latents if we are doing classifier free guidance
+         latent_model_input = (
+             torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+         )
+         latent_model_input = pipe.scheduler.scale_model_input(latent_model_input, t)
+
+         # Predict the noise residual
+         noise_pred = pipe.unet(
+             latent_model_input,
+             t,
+             encoder_hidden_states=text_embeddings,
+             added_cond_kwargs=added_cond_kwargs,
+         ).sample
+
+         # Perform guidance
+         if do_classifier_free_guidance:
+             noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+             noise_pred = noise_pred_uncond + guidance_scale * (
+                 noise_pred_text - noise_pred_uncond
+             )
+
+         current_t = max(0, t.item() - (1000 // num_inference_steps))  # t
+         next_t = t  # min(999, t.item() + (1000//num_inference_steps)) # t+1
+         alpha_t = pipe.scheduler.alphas_cumprod[current_t]
+         alpha_t_next = pipe.scheduler.alphas_cumprod[next_t]
+
+         # Inverted update step (re-arranging the update step to get x(t) (new latents)
+         # as a function of x(t-1) (current latents))
+         latents = (latents - (1 - alpha_t).sqrt() * noise_pred) * (
+             alpha_t_next.sqrt() / alpha_t.sqrt()
+         ) + (1 - alpha_t_next).sqrt() * noise_pred
+
+         # Store
+         intermediate_latents.append(latents)
+
+     return torch.cat(intermediate_latents)
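+ # The inverted update above rearranges the deterministic DDIM step: from the
+ # predicted noise eps at the current latent x_{t-1}, it recovers the implied
+ # clean-image estimate (x_{t-1} - sqrt(1 - alpha_t) * eps) / sqrt(alpha_t) and
+ # re-noises it to the next (higher) noise level,
+ #   x_t = sqrt(alpha_t_next) * x0_estimate + sqrt(1 - alpha_t_next) * eps.
+ # `invert` therefore returns num_inference_steps - 2 latents ordered from least
+ # to most noisy, which is why the sampling functions index the stack from the end.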
+
+
+ def style_image_with_inversion(
+     pipe,
+     input_image,
+     input_image_prompt,
+     style_prompt,
+     num_steps=100,
+     start_step=30,
+     guidance_scale=3.5,
+     disentangle=False,
+     share_attn=False,
+     share_cross_attn=False,
+     share_resnet_layers=[0, 1],
+     share_attn_layers=[],
+     c2s_layers=[0, 1],
+     share_key=True,
+     share_query=True,
+     share_value=False,
+     use_adain=True,
+     use_content_anchor=True,
+     output_dir: str = None,
+     resnet_mode: str = None,
+     return_intermediate=False,
+     intermediate_latents=None,
+ ):
+     with torch.no_grad():
+         pipe.vae.to(dtype=torch.float32)
+         latent = pipe.vae.encode(input_image.to(device) * 2 - 1)
+         # latent = pipe.vae.encode(input_image.to(device))
+         l = pipe.vae.config.scaling_factor * latent.latent_dist.sample()
+         if isinstance(pipe, StableDiffusionXLPipeline):
+             pipe.vae.to(dtype=torch.float16)
+     if intermediate_latents is None:
+         inverted_latents = invert(
+             pipe, l, input_image_prompt, num_inference_steps=num_steps
+         )
+     else:
+         inverted_latents = intermediate_latents
+
+     attn_injection.register_attention_processors(
+         pipe,
+         base_dir=output_dir,
+         resnet_mode=resnet_mode,
+         attn_mode="artist" if disentangle else "pnp",
+         disentangle=disentangle,
+         share_resblock=True,
+         share_attn=share_attn,
+         share_cross_attn=share_cross_attn,
+         share_resnet_layers=share_resnet_layers,
+         share_attn_layers=share_attn_layers,
+         share_key=share_key,
+         share_query=share_query,
+         share_value=share_value,
+         use_adain=use_adain,
+         c2s_layers=c2s_layers,
+     )
+
+     if disentangle:
+         final_im = sample_disentangled(
+             pipe,
+             style_prompt,
+             start_latents=inverted_latents[-(start_step + 1)][None],
+             intermediate_latents=inverted_latents,
+             start_step=start_step,
+             num_inference_steps=num_steps,
+             guidance_scale=guidance_scale,
+             use_content_anchor=use_content_anchor,
+         )
+     else:
+         final_im = sample(
+             pipe,
+             style_prompt,
+             start_latents=inverted_latents[-(start_step + 1)][None],
+             intermediate_latents=inverted_latents,
+             start_step=start_step,
+             num_inference_steps=num_steps,
+             guidance_scale=guidance_scale,
+         )
+
+     # unset the attention processors
+     attn_injection.unset_attention_processors(
+         pipe,
+         unset_share_attn=True,
+         unset_share_resblock=True,
+     )
+     if return_intermediate:
+         return final_im, inverted_latents
+     return final_im
+
+
+ if __name__ == "__main__":
+
+     # Load a pipeline
+     pipe = StableDiffusionPipeline.from_pretrained(
+         "stabilityai/stable-diffusion-2-1-base"
+     ).to(device)
+
+     # pipe = DiffusionPipeline.from_pretrained(
+     #     # "playgroundai/playground-v2-1024px-aesthetic",
+     #     torch_dtype=torch.float16,
+     #     use_safetensors=True,
+     #     add_watermarker=False,
+     #     variant="fp16",
+     # )
+     # pipe.to("cuda")
+
+     # Set up a DDIM scheduler
+     pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
+
+     parser = argparse.ArgumentParser(description="Stable Diffusion with OmegaConf")
+     parser.add_argument(
+         "--config", type=str, default="config.yaml", help="Path to the config file"
+     )
+     parser.add_argument(
+         "--mode",
+         type=str,
+         default="dataset",
+         choices=["dataset", "cli", "app"],
+         help="Run on an annotated dataset, on a single image from the CLI, or as a Gradio app",
+     )
+     parser.add_argument(
+         "--image_dir", type=str, default="test.png", help="Path to the image"
+     )
+     parser.add_argument(
+         "--prompt",
+         type=str,
+         default="an impressionist painting",
+         help="Stylization prompt",
+     )
+     # mode = "single_control_content"
+     args = parser.parse_args()
+     config_dir = args.config
+     mode = args.mode
+     # mode = "dataset"
+     out_name = ["content_delegation", "style_delegation", "style_out"]
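+     # The OmegaConf config consumed below is expected to provide the fields
+     # read in the branches that follow; a minimal illustrative config.yaml
+     # (values are examples, not shipped defaults) might look like:
+     #
+     #   exp_name: my_experiment          # dataset mode only
+     #   out_path: ./results              # dataset mode only
+     #   annotation: data/annotation.json # dataset mode only
+     #   seed: 0
+     #   num_steps: 50
+     #   start_step: 0
+     #   style_cfg_scale: 7.5
+     #   disentangle: true
+     #   resnet_mode: hidden
+     #   share_attn: true
+     #   share_cross_attn: true
+     #   share_resnet_layers: [0, 1, 2, 3]
+     #   share_attn_layers: [0, 1, 2, 3, 4, 5, 6, 7, 8]
+     #   share_key: true
+     #   share_query: true
+     #   share_value: false
+     #   use_content_anchor: true
+     #   use_adain: true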
+
+     if mode == "dataset":
+         cfg = OmegaConf.load(config_dir)
+
+         base_output_path = cfg.out_path
+         if not os.path.exists(cfg.out_path):
+             os.makedirs(cfg.out_path)
+         base_output_path = os.path.join(base_output_path, cfg.exp_name)
+
+         experiment_output_path = utils.exp_utils.make_unique_experiment_path(
+             base_output_path
+         )
+
+         # Save the experiment configuration
+         config_file_path = os.path.join(experiment_output_path, "config.yaml")
+         omegaconf.OmegaConf.save(cfg, config_file_path)
+
+         annotation = json.load(open(cfg.annotation))
+         with open(os.path.join(experiment_output_path, "annotation.json"), "w") as f:
+             json.dump(annotation, f)
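+         # Each annotation entry supplies the fields read in the loop below; a
+         # hypothetical annotation.json could look like:
+         #   [
+         #     {
+         #       "image_path": "data/example/house.png",
+         #       "source_prompt": "a photo of a house",
+         #       "target_prompt": "an impressionist painting"
+         #     }
+         #   ]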
+         for i, entry in enumerate(annotation):
+             utils.exp_utils.seed_all(cfg.seed)
+             image_path = entry["image_path"]
+             src_prompt = entry["source_prompt"]
+             tgt_prompt = entry["target_prompt"]
+             resolution = 512 if isinstance(pipe, StableDiffusionXLPipeline) else 512
+             input_image = utils.exp_utils.get_processed_image(
+                 image_path, device, resolution
+             )
+
+             prompt_in = [
+                 src_prompt,  # reconstruction
+                 tgt_prompt,  # uncontrolled style
+                 "",  # controlled style
+             ]
+
+             imgs = style_image_with_inversion(
+                 pipe,
+                 input_image,
+                 src_prompt,
+                 style_prompt=prompt_in,
+                 num_steps=cfg.num_steps,
+                 start_step=cfg.start_step,
+                 guidance_scale=cfg.style_cfg_scale,
+                 disentangle=cfg.disentangle,
+                 resnet_mode=cfg.resnet_mode,
+                 share_attn=cfg.share_attn,
+                 share_cross_attn=cfg.share_cross_attn,
+                 share_resnet_layers=cfg.share_resnet_layers,
+                 share_attn_layers=cfg.share_attn_layers,
+                 share_key=cfg.share_key,
+                 share_query=cfg.share_query,
+                 share_value=cfg.share_value,
+                 use_content_anchor=cfg.use_content_anchor,
+                 use_adain=cfg.use_adain,
+                 output_dir=experiment_output_path,
+             )
+
+             for j, img in enumerate(imgs):
+                 img.save(f"{experiment_output_path}/out_{i}_{out_name[j]}.png")
+                 print(
+                     f"Image saved as {experiment_output_path}/out_{i}_{out_name[j]}.png"
+                 )
+     elif mode == "cli":
+         cfg = OmegaConf.load(config_dir)
+         utils.exp_utils.seed_all(cfg.seed)
+         image = utils.exp_utils.get_processed_image(args.image_dir, device, 512)
+         tgt_prompt = args.prompt
+         src_prompt = ""
+         prompt_in = [
+             "",  # reconstruction
+             tgt_prompt,  # uncontrolled style
+             "",  # controlled style
+         ]
+         out_dir = "./out"
+         os.makedirs(out_dir, exist_ok=True)
+         imgs = style_image_with_inversion(
+             pipe,
+             image,
+             src_prompt,
+             style_prompt=prompt_in,
+             num_steps=cfg.num_steps,
+             start_step=cfg.start_step,
+             guidance_scale=cfg.style_cfg_scale,
+             disentangle=cfg.disentangle,
+             resnet_mode=cfg.resnet_mode,
+             share_attn=cfg.share_attn,
+             share_cross_attn=cfg.share_cross_attn,
+             share_resnet_layers=cfg.share_resnet_layers,
+             share_attn_layers=cfg.share_attn_layers,
+             share_key=cfg.share_key,
+             share_query=cfg.share_query,
+             share_value=cfg.share_value,
+             use_content_anchor=cfg.use_content_anchor,
+             use_adain=cfg.use_adain,
+             output_dir=out_dir,
+         )
+         image_base_name = os.path.basename(args.image_dir).split(".")[0]
+         for j, img in enumerate(imgs):
+             img.save(f"{out_dir}/{image_base_name}_out_{out_name[j]}.png")
+             print(f"Image saved as {out_dir}/{image_base_name}_out_{out_name[j]}.png")
+     elif mode == "app":
+         # gradio
+         import gradio as gr
+
+         def style_transfer_app(
+             prompt,
+             image,
+             cfg_scale=7.5,
+             num_content_layers=4,
+             num_style_layers=9,
+             seed=0,
+             progress=gr.Progress(track_tqdm=True),
+         ):
+             utils.exp_utils.seed_all(seed)
+             image = utils.exp_utils.process_image(image, device, 512)
+
+             tgt_prompt = prompt
+             src_prompt = ""
+             prompt_in = [
+                 "",  # reconstruction
+                 tgt_prompt,  # uncontrolled style
+                 "",  # controlled style
+             ]
+
+             share_resnet_layers = (
+                 list(range(num_content_layers)) if num_content_layers != 0 else None
+             )
+             share_attn_layers = (
+                 list(range(num_style_layers)) if num_style_layers != 0 else None
+             )
+             imgs = style_image_with_inversion(
+                 pipe,
+                 image,
+                 src_prompt,
+                 style_prompt=prompt_in,
+                 num_steps=50,
+                 start_step=0,
+                 guidance_scale=cfg_scale,
+                 disentangle=True,
+                 resnet_mode="hidden",
+                 share_attn=True,
+                 share_cross_attn=True,
+                 share_resnet_layers=share_resnet_layers,
+                 share_attn_layers=share_attn_layers,
+                 share_key=True,
+                 share_query=True,
+                 share_value=False,
+                 use_content_anchor=True,
+                 use_adain=True,
+                 output_dir="./",
+             )
+
+             return imgs[2]
+
+         # load examples
+         examples = []
+         annotation = json.load(open("data/example/annotation.json"))
+         for entry in annotation:
+             image = utils.exp_utils.get_processed_image(
+                 entry["image_path"], device, 512
+             )
+             image = transforms.ToPILImage()(image[0])
+
+             examples.append([entry["target_prompt"], image, None, None, None])
+
+         text_input = gr.Textbox(
+             value="An impressionist painting",
+             label="Text Prompt",
+             info="Describe the style you want to apply to the image; do not describe the image content itself",
+             lines=2,
+             placeholder="Enter a text prompt",
+         )
+         image_input = gr.Image(
+             height="80%",
+             width="80%",
+             label="Content image (will be resized to 512x512)",
+             interactive=True,
+         )
+         cfg_slider = gr.Slider(
+             0,
+             15,
+             value=7.5,
+             label="Classifier Free Guidance (CFG) Scale",
+             info="Higher values give more style; 7.5 should be good for most cases",
+         )
+         content_slider = gr.Slider(
+             0,
+             9,
+             value=4,
+             step=1,
+             label="Number of content control layers",
+             info="Higher values make the result more similar to the original image. Defaults to controlling the first 4 layers",
+         )
+         style_slider = gr.Slider(
+             0,
+             9,
+             value=9,
+             step=1,
+             label="Number of style control layers",
+             info="Higher values make the result more similar to the target style. Defaults to controlling the first 9 layers; usually not necessary to change",
+         )
+         seed_slider = gr.Slider(
+             0,
+             100,
+             value=0,
+             step=1,
+             label="Seed",
+             info="Random seed for the model",
+         )
+         app = gr.Interface(
+             fn=style_transfer_app,
+             inputs=[
+                 text_input,
+                 image_input,
+                 cfg_slider,
+                 content_slider,
+                 style_slider,
+                 seed_slider,
+             ],
+             outputs=["image"],
+             title="Artist Interactive Demo",
+             examples=examples,
+         )
+         app.launch()
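+         # Launch the demo with: python injection_main.py --mode app
+         # (app mode does not read --config; the stylization settings are hard-coded above).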