dbaranchuk committed on
Commit
bf363c0
1 Parent(s): 21d434e

Main update

Files changed (7)
  1. README.md +1 -1
  2. app.py +364 -91
  3. generation.py +621 -0
  4. inversion.py +104 -0
  5. p2p.py +454 -0
  6. requirements.txt +3 -1
  7. seq_aligner.py +181 -0
README.md CHANGED
@@ -1,5 +1,5 @@
  ---
- title: ICD Image Editing
+ title: Demo App Editing
  emoji: 🖼
  colorFrom: purple
  colorTo: red
app.py CHANGED
@@ -1,52 +1,141 @@
 
  import gradio as gr
  import numpy as np
  import random
- from diffusers import DiffusionPipeline
  import torch

  device = "cuda" if torch.cuda.is_available() else "cpu"

- if torch.cuda.is_available():
- torch.cuda.max_memory_allocated(device=device)
- pipe = DiffusionPipeline.from_pretrained("stabilityai/sdxl-turbo", torch_dtype=torch.float16, variant="fp16", use_safetensors=True)
- pipe.enable_xformers_memory_efficient_attention()
- pipe = pipe.to(device)
- else:
- pipe = DiffusionPipeline.from_pretrained("stabilityai/sdxl-turbo", use_safetensors=True)
- pipe = pipe.to(device)

  MAX_SEED = np.iinfo(np.int32).max
  MAX_IMAGE_SIZE = 1024

- def infer(prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps):

- if randomize_seed:
- seed = random.randint(0, MAX_SEED)
-
- generator = torch.Generator().manual_seed(seed)
-
- image = pipe(
- prompt = prompt,
- negative_prompt = negative_prompt,
- guidance_scale = guidance_scale,
- num_inference_steps = num_inference_steps,
- width = width,
- height = height,
- generator = generator
- ).images[0]
-
- return image

- examples = [
- "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
- "An astronaut riding a green horse",
- "A delicious ceviche cheesecake slice",
- ]

  css="""
  #col-container {
  margin: 0 auto;
- max-width: 520px;
  }
  """
 
@@ -58,89 +147,273 @@ else:
  with gr.Blocks(css=css) as demo:

  with gr.Column(elem_id="col-container"):
- gr.Markdown(f"""
- # Text-to-Image Gradio Template
- Currently running on {power_device}.
- """)
-
  with gr.Row():

- prompt = gr.Text(
- label="Prompt",
- show_label=False,
  max_lines=1,
  placeholder="Enter your prompt",
- container=False,
  )
-
- run_button = gr.Button("Run", scale=0)
-
- result = gr.Image(label="Result", show_label=False)

- with gr.Accordion("Advanced Settings", open=False):
-
- negative_prompt = gr.Text(
- label="Negative prompt",
  max_lines=1,
- placeholder="Enter a negative prompt",
- visible=False,
  )

- seed = gr.Slider(
- label="Seed",
- minimum=0,
- maximum=MAX_SEED,
- step=1,
- value=0,
- )

- randomize_seed = gr.Checkbox(label="Randomize seed", value=True)

  with gr.Row():

- width = gr.Slider(
- label="Width",
- minimum=256,
- maximum=MAX_IMAGE_SIZE,
- step=32,
- value=512,
  )
-
- height = gr.Slider(
- label="Height",
- minimum=256,
- maximum=MAX_IMAGE_SIZE,
- step=32,
- value=512,
  )
-
  with gr.Row():

- guidance_scale = gr.Slider(
- label="Guidance scale",
  minimum=0.0,
- maximum=10.0,
  step=0.1,
- value=0.0,
  )
-
- num_inference_steps = gr.Slider(
- label="Number of inference steps",
- minimum=1,
- maximum=12,
- step=1,
- value=2,
  )
-
- gr.Examples(
- examples = examples,
- inputs = [prompt]
- )

  run_button.click(
  fn = infer,
- inputs = [prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps],
  outputs = [result]
  )

- demo.queue().launch()
 
1
+ import spaces
2
  import gradio as gr
3
  import numpy as np
4
  import random
 
5
  import torch
6
+ from diffusers import DDPMScheduler, StableDiffusionPipeline, DDIMScheduler, UNet2DConditionModel
7
+ import p2p, generation, inversion
8
 
9
+ model_id = 'runwayml/stable-diffusion-v1-5'
10
+ dtype=torch.float16
11
  device = "cuda" if torch.cuda.is_available() else "cpu"
12
 
13
+ # Reverse
14
+ # -----------------------------
15
+ pipe_reverse = StableDiffusionPipeline.from_pretrained(model_id,
16
+ scheduler=DDIMScheduler.from_pretrained(model_id,
17
+ subfolder="scheduler"),
18
+ ).to(device=device, dtype=dtype)
19
+ unet = UNet2DConditionModel.from_pretrained("dbaranchuk/sd15-cfg-distill-unet").to(device)
20
+ pipe_reverse.unet = unet
21
+ pipe_reverse.load_lora_weights("dbaranchuk/icd-lora-sd15",
22
+ weight_name='reverse-259-519-779-999.safetensors')
23
+ pipe_reverse.fuse_lora()
24
+ pipe_reverse.to(device)
25
+ # -----------------------------
26
+
27
+ # Forward
28
+ # -----------------------------
29
+ pipe_forward = StableDiffusionPipeline.from_pretrained(model_id,
30
+ scheduler=DDIMScheduler.from_pretrained(model_id,
31
+ subfolder="scheduler"),
32
+ ).to(device=device, dtype=dtype)
33
+ unet = UNet2DConditionModel.from_pretrained("dbaranchuk/sd15-cfg-distill-unet").to(device)
34
+ pipe_forward.unet = unet
35
+ pipe_forward.load_lora_weights("dbaranchuk/icd-lora-sd15",
36
+ weight_name='forward-19-259-519-779.safetensors')
37
+ pipe_forward.fuse_lora()
38
+ pipe_forward.to(device)
39
+ # -----------------------------
40
 
41
  MAX_SEED = np.iinfo(np.int32).max
42
  MAX_IMAGE_SIZE = 1024
43
 
44
+ @spaces.GPU(duration=30)
45
+ def infer(image_path, input_prompt, edited_prompt, guidance, tau,
46
+ crs, srs, amplify_factor, amplify_word,
47
+ blend_orig, blend_edited, is_replacement):
48
 
49
+ tokenizer = pipe_forward.tokenizer
50
+ noise_scheduler = DDPMScheduler.from_pretrained(
51
+ "runwayml/stable-diffusion-v1-5", subfolder="scheduler", )
52
+
53
+ NUM_REVERSE_CONS_STEPS = 4
54
+ REVERSE_TIMESTEPS = [259, 519, 779, 999]
55
+ NUM_FORWARD_CONS_STEPS = 4
56
+ FORWARD_TIMESTEPS = [19, 259, 519, 779]
57
+ NUM_DDIM_STEPS = 50
58
+
59
+ solver = generation.Generator(
60
+ model=pipe_forward,
61
+ noise_scheduler=noise_scheduler,
62
+ n_steps=NUM_DDIM_STEPS,
63
+ forward_cons_model=pipe_forward,
64
+ forward_timesteps=FORWARD_TIMESTEPS,
65
+ reverse_cons_model=pipe_reverse,
66
+ reverse_timesteps=REVERSE_TIMESTEPS,
67
+ num_endpoints=NUM_REVERSE_CONS_STEPS,
68
+ num_forward_endpoints=NUM_FORWARD_CONS_STEPS,
69
+ max_forward_timestep_index=49,
70
+ start_timestep=19)
71
+
72
+ p2p.NUM_DDIM_STEPS = NUM_DDIM_STEPS
73
+ p2p.tokenizer = tokenizer
74
+ p2p.device = 'cuda'
75
+
76
+ prompt = [input_prompt]
77
+
78
+ (image_gt, image_rec), ddim_latent, uncond_embeddings = inversion.invert(
79
+ # Playing params
80
+ image_path=image_path,
81
+ prompt=prompt,
82
+
83
+ # Fixed params
84
+ is_cons_inversion=True,
85
+ w_embed_dim=512,
86
+ inv_guidance_scale=0.0,
87
+ stop_step=50,
88
+ solver=solver,
89
+ seed=10500)
90
+
91
+ p2p.NUM_DDIM_STEPS = 4
92
+ p2p.tokenizer = tokenizer
93
+ p2p.device = 'cuda'
94
+
95
+ prompts = [input_prompt,
96
+ edited_prompt
97
+ ]
98
+
99
+ # Playing params
100
+ cross_replace_steps = {'default_': crs, }
101
+ self_replace_steps = srs
102
+ blend_word = (((blend_orig,), (blend_edited,)))
103
+ eq_params = {"words": (amplify_word,), "values": (amplify_factor,)}
104
+
105
+ controller = p2p.make_controller(prompts,
106
+ is_replacement, # (is_replacement) True if only one word is changed
107
+ cross_replace_steps,
108
+ self_replace_steps,
109
+ blend_word,
110
+ eq_params)
111
+
112
+ tau = tau
113
+ image, _ = generation.runner(
114
+ # Playing params
115
+ guidance_scale=guidance-1,
116
+ tau1=tau, # Dynamic guidance if tau < 1.0
117
+ tau2=tau,
118
 
119
+ # Fixed params
120
+ model=pipe_reverse,
121
+ is_cons_forward=True,
122
+ w_embed_dim=512,
123
+ solver=solver,
124
+ prompt=prompts,
125
+ controller=controller,
126
+ num_inference_steps=50,
127
+ generator=None,
128
+ latent=ddim_latent,
129
+ uncond_embeddings=uncond_embeddings,
130
+ return_type='image')
131
+
132
+ image = generation.to_pil_images(image[1, :, :, :])
133
+ return image
134
 
135
  css="""
136
  #col-container {
137
  margin: 0 auto;
138
+ max-width: 1024px;
139
  }
140
  """
141
 
 
147
  with gr.Blocks(css=css) as demo:
148
 
149
  with gr.Column(elem_id="col-container"):
150
+ gr.Markdown(
151
+ f"""
152
+ # Invertible Consistency Distillation ⚡
153
+ # ⚡ Text-guided image editing with 8-step iCD-SD1.5 ⚡
154
+ This is a demo for [Invertible Consistency Distillation](https://yandex-research.github.io/invertible-cd/),
155
+ a diffusion distillation method proposed in [Invertible Consistency Distillation for Text-Guided Image Editing in Around 7 Steps](https://arxiv.org/abs/2406.14539)
156
+ by [Yandex Research](https://github.com/yandex-research).
157
+ Currently running on {power_device}
158
+ """
159
+ )
160
+ gr.Markdown(
161
+ "**Please** check the examples to catch the intuition behind the hyperparameters, which are quite important for successful editing. A short description: <br />1. *Dynamic guidance tau*. Controls the interval where guidance is applied: if t < tau, then guidance is turned on for t < tau."
162
+ " Lower tau values provide better reference preservation. We commonly use tau=0.6 and tau=0.8. <br />"
163
+ "2. *Cross replace steps (crs)* and *self replace steps (srs)*. Controls the time step interval "
164
+ "where the cross- and self-attention maps are replaced. Higher values lead to better preservation of the reference image. "
165
+ "The optimal values depend on the particular image. "
166
+ "Mostly, we use crs and srs from 0.2 to 0.6. <br />"
167
+ "3. *Amplify word* and *Amplify factor*. Define the word that needs to be enhanced in the edited image. <br />"
168
+ "4. *Blended word*. Specifies the object used for making local edits. That is, edit only selected objects. <br />"
169
+ "5. *Is replacement*. You can set True, if you replace only one word in the original prompt. But False also works in these cases."
170
+ )
171
+ gr.Markdown(
172
+ "Feel free to check out our [image generation demo](https://huggingface.co/spaces/dbaranchuk/demo-app) as well."
173
+ )
174
+ gr.Markdown(
175
+ "If you enjoy the space, feel free to give a ⭐ to the <a href='https://github.com/yandex-research/invertible-cd' target='_blank'>Github Repo</a>. [![GitHub Stars](https://img.shields.io/github/stars/yandex-research/invertible-cd?style=social)](https://github.com/yandex-research/invertible-cd)"
176
+ )
177
  with gr.Row():
178
 
179
+ input_prompt = gr.Text(
+ label="Original prompt",
  max_lines=1,
  placeholder="Enter your prompt",
  )

+ prompt = gr.Text(
+ label="Edited prompt",
  max_lines=1,
+ placeholder="Enter your prompt",
  )

+
+ with gr.Row():

+ with gr.Column():
195
+ input_image = gr.Image(label="Input image", height=512, width=512, show_label=False)
196
+ with gr.Column():
197
+ result = gr.Image(label="Result", height=512, width=512, show_label=False)
198
+
199
+ with gr.Accordion("Advanced Settings", open=True):
200
 
201
  with gr.Row():
202
 
203
+ guidance_scale = gr.Slider(
204
+ label="Guidance scale",
205
+ minimum=1.0,
206
+ maximum=20.0,
207
+ step=1.0,
208
+ value=20.0,
209
  )
210
+
211
+ tau = gr.Slider(
212
+ label="Dynamic guidance tau",
213
+ minimum=0.0,
214
+ maximum=1.0,
215
+ step=0.2,
216
+ value=0.8,
217
  )
218
+
219
  with gr.Row():
220
 
221
+ crs = gr.Slider(
222
+ label="Cross replace steps",
223
  minimum=0.0,
224
+ maximum=1.0,
225
  step=0.1,
226
+ value=0.4
227
  )
228
+
229
+ srs = gr.Slider(
230
+ label="Self replace steps",
231
+ minimum=0.0,
232
+ maximum=1.0,
233
+ step=0.1,
234
+ value=0.4,
235
  )
236
+
237
+ with gr.Row():
238
+ amplify_word = gr.Text(
239
+ label="Amplify word",
240
+ max_lines=1,
241
+ placeholder="Enter your word",
242
+ )
243
+
244
+ amplify_factor = gr.Slider(
245
+ label="Amplify factor",
246
+ minimum=0.0,
247
+ maximum=30,
248
+ step=1.0,
249
+ value=1,
250
+ )
251
+ with gr.Row():
252
+
253
+ blend_orig = gr.Text(
254
+ label="Blended word 1",
255
+ max_lines=1,
256
+ placeholder="Enter your word",)
257
+
258
+ blend_edited = gr.Text(
259
+ label="Blended word 2",
260
+ max_lines=1,
261
+ placeholder="Enter your word",)
262
+
263
+ with gr.Row():
264
+
265
+ is_replacement = gr.Checkbox(label="Is replacement?", value=False)
266
+
267
+ with gr.Row():
268
+ run_button = gr.Button("Edit", scale=0)
269
+
270
+ with gr.Row():
271
+ examples = [
272
+ [
273
+ "examples/orig_3.jpg", #input_image
274
+ "a photo of a basket of apples", #src_prompt
275
+ "a photo of a basket of oranges", #tgt_prompt
276
+ 20, #guidance_scale
277
+ 0.6, #tau
278
+ 0.4, #crs
279
+ 0.6, #srs
280
+ 1, #amplify factor
281
+ 'oranges', # amplify word
282
+ '', #orig blend
283
+ 'oranges', #edited blend
284
+ False #replacement
285
+ ],
286
+ [
287
+ "examples/orig_3.jpg", #input_image
288
+ "a photo of a basket of apples", #src_prompt
289
+ "a photo of a basket of puppies", #tgt_prompt
290
+ 20, #guidance_scale
291
+ 0.6, #tau
292
+ 0.4, #crs
293
+ 0.1, #srs
294
+ 2, #amplify factor
295
+ 'puppies', # amplify word
296
+ '', #orig blend
297
+ 'puppies', #edited blend
298
+ True #replacement
299
+ ],
300
+ [
301
+ "examples/orig_3.jpg", #input_image
302
+ "a photo of a basket of apples", #src_prompt
303
+ "a photo of a basket of apples under snowfall", #tgt_prompt
304
+ 20, #guidance_scale
305
+ 0.6, #tau
306
+ 0.4, #crs
307
+ 0.4, #srs
308
+ 30, #amplify factor
309
+ 'snowfall', # amplify word
310
+ '', #orig blend
311
+ 'snowfall', #edited blend
312
+ False #replacement
313
+ ],
314
+ [
315
+ "examples/orig_1.jpg", #input_image
316
+ "a photo of an owl", #src_prompt
317
+ "a photo of an yellow owl", #tgt_prompt
318
+ 20, #guidance_scale
319
+ 0.6, #tau
320
+ 0.9, #crs
321
+ 0.9, #srs
322
+ 20, #amplify factor
323
+ 'yellow', # amplify word
324
+ 'owl', #orig blend
325
+ 'yellow', #edited blend
326
+ False #replacement
327
+ ],
328
+ [
329
+ "examples/orig_1.jpg", #input_image
330
+ "a photo of an owl", #src_prompt
331
+ "an anime-style painting of an owl", #tgt_prompt
332
+ 20, #guidance_scale
333
+ 0.8, #tau
334
+ 0.6, #crs
335
+ 0.3, #srs
336
+ 10, #amplify factor
337
+ 'anime-style', # amplify word
338
+ 'painting', #orig blend
339
+ 'anime-style', #edited blend
340
+ False #replacement
341
+ ],
342
+ [
343
+ "examples/orig_1.jpg", #input_image
344
+ "a photo of an owl", #src_prompt
345
+ "a photo of an owl underwater with many fishes nearby", #tgt_prompt
346
+ 20, #guidance_scale
347
+ 0.8, #tau
348
+ 0.4, #crs
349
+ 0.4, #srs
350
+ 18, #amplify factor
351
+ 'fishes', # amplify word
352
+ '', #orig blend
353
+ 'fishes', #edited blend
354
+ False #replacement
355
+ ],
356
+ [
357
+ "examples/orig_2.jpg", #input_image
358
+ "a photograph of a teddy bear sitting on a wall", #src_prompt
359
+ "a photograph of a teddy bear sitting on a wall surrounded by roses", #tgt_prompt
360
+ 20, #guidance_scale
361
+ 0.6, #tau
362
+ 0.4, #crs
363
+ 0.1, #srs
364
+ 25, #amplify factor
365
+ 'roses', # amplify word
366
+ '', #orig blend
367
+ 'roses', #edited blend
368
+ False #replacement
369
+ ],
370
+ [
371
+ "examples/orig_2.jpg", #input_image
372
+ "a photograph of a teddy bear sitting on a wall", #src_prompt
373
+ "a photograph of a wooden bear sitting on a wall", #tgt_prompt
374
+ 20, #guidance_scale
375
+ 0.8, #tau
376
+ 0.5, #crs
377
+ 0.5, #srs
378
+ 14, #amplify factor
379
+ 'wooden', # amplify word
380
+ '', #orig blend
381
+ 'wooden', #edited blend
382
+ True #replacement
383
+ ],
384
+ [
385
+ "examples/orig_2.jpg", #input_image
386
+ "a photograph of a teddy bear sitting on a wall", #src_prompt
387
+ "a photograph of a teddy rabbit sitting on a wall", #tgt_prompt
388
+ 20, #guidance_scale
389
+ 0.8, #tau
390
+ 0.4, #crs
391
+ 0.4, #srs
392
+ 3, #amplify factor
393
+ 'rabbit', # amplify word
394
+ '', #orig blend
395
+ 'rabbit', #edited blend
396
+ True #replacement
397
+ ],
398
+ ]
399
+
400
+ gr.Examples(
401
+ examples = examples,
402
+ inputs =[input_image, input_prompt, prompt,
403
+ guidance_scale, tau, crs, srs, amplify_factor, amplify_word,
404
+ blend_orig, blend_edited, is_replacement],
405
+ outputs=[
406
+ result
407
+ ],
408
+ fn=infer, cache_examples=True
409
+ )
410
 
411
  run_button.click(
412
  fn = infer,
413
+ inputs=[input_image, input_prompt, prompt,
414
+ guidance_scale, tau, crs, srs, amplify_factor, amplify_word,
415
+ blend_orig, blend_edited, is_replacement],
416
  outputs = [result]
417
  )
418
 
419
+ demo.queue().launch()
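
Note on the "Dynamic guidance tau" slider described in the Markdown above: app.py passes tau1 = tau2 = tau into generation.runner, so the linear ramp in generation.linear_schedule_old collapses into a hard switch. A minimal sketch of the resulting behavior (illustration only, not part of the commit; guidance_scale here is the value after the `guidance - 1` shift in infer):

    def effective_guidance(t, guidance_scale, tau):
        # linear_schedule_old with tau1 == tau2 == tau:
        # guidance is active only while t / 1000 <= tau
        return guidance_scale if t / 1000 <= tau else 0.0

    effective_guidance(999, 19, 0.8)  # -> 0.0  (no guidance at the noisiest steps)
    effective_guidance(259, 19, 0.8)  # -> 19.0 (guidance on once t drops below tau)
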
generation.py ADDED
@@ -0,0 +1,621 @@
1
+ import numpy as np
2
+ import torch
3
+ from PIL import Image, ImageDraw, ImageFont
4
+ from tqdm import tqdm
5
+ from typing import Union
6
+ from IPython.display import display
7
+ import p2p
8
+
9
+
10
+ # Main function to run
11
+ # ----------------------------------------------------------------------
12
+ @torch.no_grad()
13
+ def runner(
14
+ model,
15
+ prompt,
16
+ controller,
17
+ solver,
18
+ is_cons_forward=False,
19
+ num_inference_steps=50,
20
+ guidance_scale=7.5,
21
+ generator=None,
22
+ latent=None,
23
+ uncond_embeddings=None,
24
+ start_time=50,
25
+ return_type='image',
26
+ dynamic_guidance=False,
27
+ tau1=0.4,
28
+ tau2=0.6,
29
+ w_embed_dim=0,
30
+ ):
31
+ p2p.register_attention_control(model, controller)
32
+ height = width = 512
33
+ solver.init_prompt(prompt, None)
34
+ latent, latents = init_latent(latent, model, 512, 512, generator, len(prompt))
35
+ model.scheduler.set_timesteps(num_inference_steps)
36
+ dynamic_guidance = True if tau1 < 1.0 or tau2 < 1.0 else False
37
+
38
+ if not is_cons_forward:
39
+ latents = solver.ddim_loop(latents,
40
+ num_inference_steps,
41
+ is_forward=False,
42
+ guidance_scale=guidance_scale,
43
+ dynamic_guidance=dynamic_guidance,
44
+ tau1=tau1,
45
+ tau2=tau2,
46
+ w_embed_dim=w_embed_dim,
47
+ uncond_embeddings=uncond_embeddings if uncond_embeddings is not None else None,
48
+ controller=controller)
49
+ latents = latents[-1]
50
+ else:
51
+ latents = solver.cons_generation(
52
+ latents,
53
+ guidance_scale=guidance_scale,
54
+ w_embed_dim=w_embed_dim,
55
+ dynamic_guidance=dynamic_guidance,
56
+ tau1=tau1,
57
+ tau2=tau2,
58
+ controller=controller)
59
+ latents = latents[-1]
60
+
61
+ if return_type == 'image':
62
+ image = latent2image(model.vae, latents.to(model.vae.dtype))
63
+ else:
64
+ image = latents
65
+
66
+ return image, latent
67
+
68
+
69
+ # ----------------------------------------------------------------------
70
+
71
+
72
+ # Utils
73
+ # ----------------------------------------------------------------------
74
+ def linear_schedule_old(t, guidance_scale, tau1, tau2):
75
+ t = t / 1000
76
+ if t <= tau1:
77
+ gamma = 1.0
78
+ elif t >= tau2:
79
+ gamma = 0.0
80
+ else:
81
+ gamma = (tau2 - t) / (tau2 - tau1)
82
+ return gamma * guidance_scale
83
+
84
+
85
+ def linear_schedule(t, guidance_scale, tau1=0.4, tau2=0.8):
86
+ t = t / 1000
87
+ if t <= tau1:
88
+ return guidance_scale
89
+ if t >= tau2:
90
+ return 1.0
91
+ gamma = (tau2 - t) / (tau2 - tau1) * (guidance_scale - 1.0) + 1.0
92
+
93
+ return gamma
94
+
95
+
96
+ def guidance_scale_embedding(w, embedding_dim=512, dtype=torch.float32):
97
+ """
98
+ See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298
99
+
100
+ Args:
101
+ timesteps (`torch.Tensor`):
102
+ generate embedding vectors at these timesteps
103
+ embedding_dim (`int`, *optional*, defaults to 512):
104
+ dimension of the embeddings to generate
105
+ dtype:
106
+ data type of the generated embeddings
107
+
108
+ Returns:
109
+ `torch.FloatTensor`: Embedding vectors with shape `(len(timesteps), embedding_dim)`
110
+ """
111
+ assert len(w.shape) == 1
112
+ w = w * 1000.0
113
+
114
+ half_dim = embedding_dim // 2
115
+ emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
116
+ emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
117
+ emb = w.to(dtype)[:, None] * emb[None, :]
118
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
119
+ if embedding_dim % 2 == 1: # zero pad
120
+ emb = torch.nn.functional.pad(emb, (0, 1))
121
+ assert emb.shape == (w.shape[0], embedding_dim)
122
+ return emb
123
+
124
+
125
+ # ----------------------------------------------------------------------
126
+
127
+
128
+ # Diffusion step with scheduler from diffusers and controller for editing
129
+ # ----------------------------------------------------------------------
130
+ def extract_into_tensor(a, t, x_shape):
131
+ b, *_ = t.shape
132
+ out = a.gather(-1, t)
133
+ return out.reshape(b, *((1,) * (len(x_shape) - 1)))
134
+
135
+
136
+ def predicted_origin(model_output, timesteps, boundary_timesteps, sample, prediction_type, alphas, sigmas):
137
+ sigmas_s = extract_into_tensor(sigmas, boundary_timesteps, sample.shape)
138
+ alphas_s = extract_into_tensor(alphas, boundary_timesteps, sample.shape)
139
+
140
+ sigmas = extract_into_tensor(sigmas, timesteps, sample.shape)
141
+ alphas = extract_into_tensor(alphas, timesteps, sample.shape)
142
+
143
+ # Set hard boundaries to ensure equivalence with forward (direct) CD
144
+ alphas_s[boundary_timesteps == 0] = 1.0
145
+ sigmas_s[boundary_timesteps == 0] = 0.0
146
+
147
+ if prediction_type == "epsilon":
148
+ pred_x_0 = (sample - sigmas * model_output) / alphas # x0 prediction
149
+ pred_x_0 = alphas_s * pred_x_0 + sigmas_s * model_output # Euler step to the boundary step
150
+ elif prediction_type == "v_prediction":
151
+ assert boundary_timesteps == 0, "v_prediction does not support multiple endpoints at the moment"
152
+ pred_x_0 = alphas * sample - sigmas * model_output
153
+ else:
154
+ raise ValueError(f"Prediction type {prediction_type} currently not supported.")
155
+ return pred_x_0
156
+
157
+
158
+ def guided_step(noise_prediction_text,
159
+ noise_pred_uncond,
160
+ t,
161
+ guidance_scale,
162
+ dynamic_guidance=False,
163
+ tau1=0.4,
164
+ tau2=0.6):
165
+ if dynamic_guidance:
166
+ if not isinstance(t, int):
167
+ t = t.item()
168
+ new_guidance_scale = linear_schedule(t, guidance_scale, tau1=tau1, tau2=tau2)
169
+ else:
170
+ new_guidance_scale = guidance_scale
171
+
172
+ noise_pred = noise_pred_uncond + new_guidance_scale * (noise_prediction_text - noise_pred_uncond)
173
+ return noise_pred
174
+
175
+
176
+ # ----------------------------------------------------------------------
177
+
178
+
179
+ # DDIM scheduler with inversion
180
+ # ----------------------------------------------------------------------
181
+ class Generator:
182
+
183
+ def prev_step(self, model_output: Union[torch.FloatTensor, np.ndarray], timestep: int,
184
+ sample: Union[torch.FloatTensor, np.ndarray]):
185
+ prev_timestep = timestep - self.scheduler.config.num_train_timesteps // self.scheduler.num_inference_steps
186
+ alpha_prod_t = self.scheduler.alphas_cumprod[timestep]
187
+ alpha_prod_t_prev = self.scheduler.alphas_cumprod[
188
+ prev_timestep] if prev_timestep >= 0 else self.scheduler.final_alpha_cumprod
189
+ beta_prod_t = 1 - alpha_prod_t
190
+ pred_original_sample = (sample - beta_prod_t ** 0.5 * model_output) / alpha_prod_t ** 0.5
191
+ pred_sample_direction = (1 - alpha_prod_t_prev) ** 0.5 * model_output
192
+ prev_sample = alpha_prod_t_prev ** 0.5 * pred_original_sample + pred_sample_direction
193
+ return prev_sample
194
+
195
+ def next_step(self, model_output: Union[torch.FloatTensor, np.ndarray], timestep: int,
196
+ sample: Union[torch.FloatTensor, np.ndarray]):
197
+ timestep, next_timestep = min(
198
+ timestep - self.scheduler.config.num_train_timesteps // self.scheduler.num_inference_steps, 999), timestep
199
+ alpha_prod_t = self.scheduler.alphas_cumprod[timestep] if timestep >= 0 else self.scheduler.final_alpha_cumprod
200
+ alpha_prod_t_next = self.scheduler.alphas_cumprod[next_timestep]
201
+ beta_prod_t = 1 - alpha_prod_t
202
+ next_original_sample = (sample - beta_prod_t ** 0.5 * model_output) / alpha_prod_t ** 0.5
203
+ next_sample_direction = (1 - alpha_prod_t_next) ** 0.5 * model_output
204
+ next_sample = alpha_prod_t_next ** 0.5 * next_original_sample + next_sample_direction
205
+ return next_sample
206
+
207
+ def get_noise_pred_single(self, latents, t, context):
208
+ noise_pred = self.model.unet(latents, t, encoder_hidden_states=context)["sample"]
209
+ return noise_pred
210
+
211
+ def get_noise_pred(self,
212
+ model,
213
+ latent,
214
+ t,
215
+ guidance_scale=1,
216
+ context=None,
217
+ w_embed_dim=0,
218
+ dynamic_guidance=False,
219
+ tau1=0.4,
220
+ tau2=0.6):
221
+ latents_input = torch.cat([latent] * 2)
222
+ if context is None:
223
+ context = self.context
224
+
225
+ # w embed
226
+ # --------------------------------------
227
+ if w_embed_dim > 0:
228
+ if dynamic_guidance:
+ t_item = t.item() if not isinstance(t, int) else t  # avoid NameError when t is a plain int
+ guidance_scale = linear_schedule_old(t_item, guidance_scale, tau1=tau1, tau2=tau2)  # TODO UPDATE
232
+ if len(latents_input) == 4:
233
+ guidance_scale_tensor = torch.tensor([0.0, 0.0, 0.0, guidance_scale])
234
+ else:
235
+ guidance_scale_tensor = torch.tensor([guidance_scale] * len(latents_input))
236
+ w_embedding = guidance_scale_embedding(guidance_scale_tensor, embedding_dim=w_embed_dim)
237
+ w_embedding = w_embedding.to(device=latent.device, dtype=latent.dtype)
238
+ else:
239
+ w_embedding = None
240
+ # --------------------------------------
241
+ noise_pred = model.unet(latents_input.to(dtype=model.unet.dtype),
242
+ t,
243
+ timestep_cond=w_embedding.to(dtype=model.unet.dtype) if w_embed_dim > 0 else None,
244
+ encoder_hidden_states=context)["sample"]
245
+ noise_pred_uncond, noise_prediction_text = noise_pred.chunk(2)
246
+
247
+ if guidance_scale > 1 and w_embedding is None:
248
+ noise_pred = guided_step(noise_prediction_text, noise_pred_uncond, t, guidance_scale, dynamic_guidance,
249
+ tau1, tau2)
250
+ else:
251
+ noise_pred = noise_prediction_text
252
+
253
+ return noise_pred
254
+
255
+ @torch.no_grad()
256
+ def latent2image(self, latents, return_type='np'):
257
+ latents = 1 / 0.18215 * latents.detach()
258
+ image = self.model.vae.decode(latents.to(dtype=self.model.dtype))['sample']
259
+ if return_type == 'np':
260
+ image = (image / 2 + 0.5).clamp(0, 1)
261
+ image = image.cpu().permute(0, 2, 3, 1).numpy()[0]
262
+ image = (image * 255).astype(np.uint8)
263
+ return image
264
+
265
+ @torch.no_grad()
266
+ def image2latent(self, image):
267
+ with torch.no_grad():
268
+ if type(image) is Image:
269
+ image = np.array(image)
270
+ if type(image) is torch.Tensor and image.dim() == 4:
271
+ latents = image
272
+ elif type(image) is list:
273
+ image = [np.array(i).reshape(1, 512, 512, 3) for i in image]
274
+ image = np.concatenate(image)
275
+ image = torch.from_numpy(image).float() / 127.5 - 1
276
+ image = image.permute(0, 3, 1, 2).to(self.model.device, dtype=self.model.vae.dtype)
277
+ latents = self.model.vae.encode(image)['latent_dist'].mean
278
+ latents = latents * 0.18215
279
+ else:
280
+ image = torch.from_numpy(image).float() / 127.5 - 1
281
+ image = image.permute(2, 0, 1).unsqueeze(0).to(self.model.device, dtype=self.model.dtype)
282
+ latents = self.model.vae.encode(image)['latent_dist'].mean
283
+ latents = latents * 0.18215
284
+ return latents
285
+
286
+ @torch.no_grad()
287
+ def init_prompt(self, prompt, uncond_embeddings=None):
288
+ if uncond_embeddings is None:
289
+ uncond_input = self.model.tokenizer(
290
+ [""], padding="max_length", max_length=self.model.tokenizer.model_max_length,
291
+ return_tensors="pt"
292
+ )
293
+ uncond_embeddings = self.model.text_encoder(uncond_input.input_ids.to(self.model.device))[0]
294
+ text_input = self.model.tokenizer(
295
+ prompt,
296
+ padding="max_length",
297
+ max_length=self.model.tokenizer.model_max_length,
298
+ truncation=True,
299
+ return_tensors="pt",
300
+ )
301
+ text_embeddings = self.model.text_encoder(text_input.input_ids.to(self.model.device))[0]
302
+ self.context = torch.cat([uncond_embeddings.expand(*text_embeddings.shape), text_embeddings])
303
+ self.prompt = prompt
304
+
305
+ @torch.no_grad()
306
+ def ddim_loop(self,
307
+ latent,
308
+ n_steps,
309
+ is_forward=True,
310
+ guidance_scale=1,
311
+ dynamic_guidance=False,
312
+ tau1=0.4,
313
+ tau2=0.6,
314
+ w_embed_dim=0,
315
+ uncond_embeddings=None,
316
+ controller=None):
317
+ all_latent = [latent]
318
+ latent = latent.clone().detach()
319
+ for i in tqdm(range(n_steps)):
320
+ if uncond_embeddings is not None:
321
+ self.init_prompt(self.prompt, uncond_embeddings[i])
322
+ if is_forward:
323
+ t = self.model.scheduler.timesteps[len(self.model.scheduler.timesteps) - i - 1]
324
+ else:
325
+ t = self.model.scheduler.timesteps[i]
326
+ noise_pred = self.get_noise_pred(
327
+ model=self.model,
328
+ latent=latent,
329
+ t=t,
330
+ context=None,
331
+ guidance_scale=guidance_scale,
332
+ dynamic_guidance=dynamic_guidance,
333
+ w_embed_dim=w_embed_dim,
334
+ tau1=tau1,
335
+ tau2=tau2)
336
+ if is_forward:
337
+ latent = self.next_step(noise_pred, t, latent)
338
+ else:
339
+ latent = self.prev_step(noise_pred, t, latent)
340
+ if controller is not None:
341
+ latent = controller.step_callback(latent)
342
+ all_latent.append(latent)
343
+ return all_latent
344
+
345
+ @property
346
+ def scheduler(self):
347
+ return self.model.scheduler
348
+
349
+ @torch.no_grad()
350
+ def ddim_inversion(self,
351
+ image,
352
+ n_steps=None,
353
+ guidance_scale=1,
354
+ dynamic_guidance=False,
355
+ tau1=0.4,
356
+ tau2=0.6,
357
+ w_embed_dim=0):
358
+
359
+ if n_steps is None:
360
+ n_steps = self.n_steps
361
+ latent = self.image2latent(image)
362
+ image_rec = self.latent2image(latent)
363
+ ddim_latents = self.ddim_loop(latent,
364
+ is_forward=True,
365
+ guidance_scale=guidance_scale,
366
+ n_steps=n_steps,
367
+ dynamic_guidance=dynamic_guidance,
368
+ tau1=tau1,
369
+ tau2=tau2,
370
+ w_embed_dim=w_embed_dim)
371
+ return image_rec, ddim_latents
372
+
373
+ @torch.no_grad()
374
+ def cons_generation(self,
375
+ latent,
376
+ guidance_scale=1,
377
+ dynamic_guidance=False,
378
+ tau1=0.4,
379
+ tau2=0.6,
380
+ w_embed_dim=0,
381
+ controller=None, ):
382
+
383
+ all_latent = [latent]
384
+ latent = latent.clone().detach()
385
+ alpha_schedule = torch.sqrt(self.model.scheduler.alphas_cumprod).to(self.model.device)
386
+ sigma_schedule = torch.sqrt(1 - self.model.scheduler.alphas_cumprod).to(self.model.device)
387
+
388
+ for i, (t, s) in enumerate(tqdm(zip(self.reverse_timesteps, self.reverse_boundary_timesteps))):
389
+ noise_pred = self.get_noise_pred(
390
+ model=self.reverse_cons_model,
391
+ latent=latent,
392
+ t=t.to(self.model.device),
393
+ context=None,
394
+ tau1=tau1, tau2=tau2,
395
+ w_embed_dim=w_embed_dim,
396
+ guidance_scale=guidance_scale,
397
+ dynamic_guidance=dynamic_guidance)
398
+
399
+ latent = predicted_origin(
400
+ noise_pred,
401
+ torch.tensor([t] * len(latent), device=self.model.device),
402
+ torch.tensor([s] * len(latent), device=self.model.device),
403
+ latent,
404
+ self.model.scheduler.config.prediction_type,
405
+ alpha_schedule,
406
+ sigma_schedule,
407
+ )
408
+ if controller is not None:
409
+ latent = controller.step_callback(latent)
410
+ all_latent.append(latent)
411
+
412
+ return all_latent
413
+
414
+ @torch.no_grad()
415
+ def cons_inversion(self,
416
+ image,
417
+ guidance_scale=0.0,
418
+ w_embed_dim=0,
419
+ seed=0):
420
+ alpha_schedule = torch.sqrt(self.model.scheduler.alphas_cumprod).to(self.model.device)
421
+ sigma_schedule = torch.sqrt(1 - self.model.scheduler.alphas_cumprod).to(self.model.device)
422
+
423
+ # 5. Prepare latent variables
424
+ latent = self.image2latent(image)
425
+ generator = torch.Generator().manual_seed(seed)
426
+ noise = torch.randn(latent.shape, generator=generator).to(latent.device)
427
+ latent = self.noise_scheduler.add_noise(latent, noise, torch.tensor([self.start_timestep]))
428
+ image_rec = self.latent2image(latent)
429
+
430
+ for i, (t, s) in enumerate(tqdm(zip(self.forward_timesteps, self.forward_boundary_timesteps))):
431
+ # predict the noise residual
432
+ noise_pred = self.get_noise_pred(
433
+ model=self.forward_cons_model,
434
+ latent=latent,
435
+ t=t.to(self.model.device),
436
+ context=None,
437
+ guidance_scale=guidance_scale,
438
+ w_embed_dim=w_embed_dim,
439
+ dynamic_guidance=False)
440
+
441
+ latent = predicted_origin(
442
+ noise_pred,
443
+ torch.tensor([t] * len(latent), device=self.model.device),
444
+ torch.tensor([s] * len(latent), device=self.model.device),
445
+ latent,
446
+ self.model.scheduler.config.prediction_type,
447
+ alpha_schedule,
448
+ sigma_schedule,
449
+ )
450
+
451
+ return image_rec, [latent]
452
+
453
+ def _create_forward_inverse_timesteps(self,
454
+ num_endpoints,
455
+ n_steps,
456
+ max_inverse_timestep_index):
457
+ timestep_interval = n_steps // num_endpoints + int(n_steps % num_endpoints > 0)
458
+ endpoint_idxs = torch.arange(timestep_interval, n_steps, timestep_interval) - 1
459
+ inverse_endpoint_idxs = torch.arange(timestep_interval, n_steps, timestep_interval) - 1
460
+ inverse_endpoint_idxs = torch.tensor(inverse_endpoint_idxs.tolist() + [max_inverse_timestep_index])
461
+
462
+ endpoints = torch.tensor([0] + self.ddim_timesteps[endpoint_idxs].tolist())
463
+ inverse_endpoints = self.ddim_timesteps[inverse_endpoint_idxs]
464
+
465
+ return endpoints, inverse_endpoints
466
+
467
+ def __init__(self,
468
+ model,
469
+ n_steps,
470
+ noise_scheduler,
471
+ forward_cons_model=None,
472
+ reverse_cons_model=None,
473
+ num_endpoints=1,
474
+ num_forward_endpoints=1,
475
+ reverse_timesteps=None,
476
+ forward_timesteps=None,
477
+ max_forward_timestep_index=49,
478
+ start_timestep=19):
479
+
480
+ self.model = model
481
+ self.forward_cons_model = forward_cons_model
482
+ self.reverse_cons_model = reverse_cons_model
483
+ self.noise_scheduler = noise_scheduler
484
+
485
+ self.n_steps = n_steps
486
+ self.tokenizer = self.model.tokenizer
487
+ self.model.scheduler.set_timesteps(n_steps)
488
+ self.prompt = None
489
+ self.context = None
490
+ step_ratio = 1000 // n_steps
491
+ self.ddim_timesteps = (np.arange(1, n_steps + 1) * step_ratio).round().astype(np.int64) - 1
492
+ self.ddim_timesteps = torch.from_numpy(self.ddim_timesteps).long()
493
+ self.start_timestep = start_timestep
494
+
495
+ # Set endpoints for direct CTM
496
+ if reverse_timesteps is None or forward_timesteps is None:
497
+ endpoints, inverse_endpoints = self._create_forward_inverse_timesteps(num_endpoints, n_steps,
498
+ max_forward_timestep_index)
499
+ self.reverse_timesteps, self.reverse_boundary_timesteps = inverse_endpoints.flip(0), endpoints.flip(0)
500
+
501
+ # Set endpoints for forward CTM
502
+ endpoints, inverse_endpoints = self._create_forward_inverse_timesteps(num_forward_endpoints, n_steps,
503
+ max_forward_timestep_index)
504
+ self.forward_timesteps, self.forward_boundary_timesteps = endpoints, inverse_endpoints
505
+ self.forward_timesteps[0] = self.start_timestep
506
+ else:
507
+ self.reverse_timesteps, self.reverse_boundary_timesteps = reverse_timesteps, reverse_timesteps
508
+ self.reverse_timesteps.reverse()
509
+ self.reverse_boundary_timesteps = self.reverse_boundary_timesteps[1:] + [self.reverse_boundary_timesteps[0]]
510
+ self.reverse_boundary_timesteps[-1] = 0
511
+ self.reverse_timesteps, self.reverse_boundary_timesteps = torch.tensor(reverse_timesteps), torch.tensor(
512
+ self.reverse_boundary_timesteps)
513
+
514
+ self.forward_timesteps, self.forward_boundary_timesteps = forward_timesteps, forward_timesteps
515
+ self.forward_boundary_timesteps = self.forward_boundary_timesteps[1:] + [self.forward_boundary_timesteps[0]]
516
+ self.forward_boundary_timesteps[-1] = 999
517
+ self.forward_timesteps, self.forward_boundary_timesteps = torch.tensor(
518
+ self.forward_timesteps), torch.tensor(self.forward_boundary_timesteps)
519
+
520
+ print(f"Endpoints reverse CTM: {self.reverse_timesteps}, {self.reverse_boundary_timesteps}")
521
+ print(f"Endpoints forward CTM: {self.forward_timesteps}, {self.forward_boundary_timesteps}")
522
+
523
+ # ----------------------------------------------------------------------
524
+
525
+ # 3rd party utils
526
+ # ----------------------------------------------------------------------
527
+ def latent2image(vae, latents):
528
+ latents = 1 / 0.18215 * latents
529
+ image = vae.decode(latents)['sample']
530
+ image = (image / 2 + 0.5).clamp(0, 1)
531
+ image = image.cpu().permute(0, 2, 3, 1).numpy()
532
+ image = (image * 255).astype(np.uint8)
533
+ return image
534
+
535
+
536
+ def init_latent(latent, model, height, width, generator, batch_size):
537
+ if latent is None:
538
+ latent = torch.randn(
539
+ (1, model.unet.in_channels, height // 8, width // 8),
540
+ generator=generator,
541
+ )
542
+ latents = latent.expand(batch_size, model.unet.in_channels, height // 8, width // 8).to(model.device)
543
+ return latent, latents
544
+
545
+
546
+ def load_512(image_path, left=0, right=0, top=0, bottom=0):
547
+ # if type(image_path) is str:
548
+ # image = np.array(Image.open(image_path))[:, :, :3]
549
+ # else:
550
+ # image = image_path
551
+ # h, w, c = image.shape
552
+ # left = min(left, w - 1)
553
+ # right = min(right, w - left - 1)
554
+ # top = min(top, h - left - 1)
555
+ # bottom = min(bottom, h - top - 1)
556
+ # image = image[top:h - bottom, left:w - right]
557
+ # h, w, c = image.shape
558
+ # if h < w:
559
+ # offset = (w - h) // 2
560
+ # image = image[:, offset:offset + h]
561
+ # elif w < h:
562
+ # offset = (h - w) // 2
563
+ # image = image[offset:offset + w]
564
+ image = np.array(Image.open(image_path).convert('RGB'))[:, :, :3]
565
+ image = np.array(Image.fromarray(image).resize((512, 512)))
566
+ return image
567
+
568
+
569
+ def to_pil_images(images, num_rows=1, offset_ratio=0.02):
570
+ if type(images) is list:
571
+ num_empty = len(images) % num_rows
572
+ elif images.ndim == 4:
573
+ num_empty = images.shape[0] % num_rows
574
+ else:
575
+ images = [images]
576
+ num_empty = 0
577
+
578
+ empty_images = np.ones(images[0].shape, dtype=np.uint8) * 255
579
+ images = [image.astype(np.uint8) for image in images] + [empty_images] * num_empty
580
+ num_items = len(images)
581
+
582
+ h, w, c = images[0].shape
583
+ offset = int(h * offset_ratio)
584
+ num_cols = num_items // num_rows
585
+ image_ = np.ones((h * num_rows + offset * (num_rows - 1),
586
+ w * num_cols + offset * (num_cols - 1), 3), dtype=np.uint8) * 255
587
+ for i in range(num_rows):
588
+ for j in range(num_cols):
589
+ image_[i * (h + offset): i * (h + offset) + h:, j * (w + offset): j * (w + offset) + w] = images[
590
+ i * num_cols + j]
591
+
592
+ pil_img = Image.fromarray(image_)
593
+ return pil_img
594
+
595
+
596
+ def view_images(images, num_rows=1, offset_ratio=0.02):
597
+ if type(images) is list:
598
+ num_empty = len(images) % num_rows
599
+ elif images.ndim == 4:
600
+ num_empty = images.shape[0] % num_rows
601
+ else:
602
+ images = [images]
603
+ num_empty = 0
604
+
605
+ empty_images = np.ones(images[0].shape, dtype=np.uint8) * 255
606
+ images = [image.astype(np.uint8) for image in images] + [empty_images] * num_empty
607
+ num_items = len(images)
608
+
609
+ h, w, c = images[0].shape
610
+ offset = int(h * offset_ratio)
611
+ num_cols = num_items // num_rows
612
+ image_ = np.ones((h * num_rows + offset * (num_rows - 1),
613
+ w * num_cols + offset * (num_cols - 1), 3), dtype=np.uint8) * 255
614
+ for i in range(num_rows):
615
+ for j in range(num_cols):
616
+ image_[i * (h + offset): i * (h + offset) + h:, j * (w + offset): j * (w + offset) + w] = images[
617
+ i * num_cols + j]
618
+
619
+ pil_img = Image.fromarray(image_)
620
+ display(pil_img)
621
+ # ----------------------------------------------------------------------
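
Note: with the explicit timestep lists passed from app.py, Generator.__init__ pairs every consistency step with a boundary timestep, and cons_generation / cons_inversion jump from t to its boundary s through predicted_origin. A small sketch of the resulting schedules (illustration only, derived from the constructor above):

    # reverse (generation) path, noisy -> clean
    reverse_timesteps          = [999, 779, 519, 259]
    reverse_boundary_timesteps = [779, 519, 259, 0]
    # forward (inversion) path, clean -> noisy, starting from start_timestep = 19
    forward_timesteps          = [19, 259, 519, 779]
    forward_boundary_timesteps = [259, 519, 779, 999]

    for t, s in zip(reverse_timesteps, reverse_boundary_timesteps):
        print(f"reverse consistency step: x_{t} -> x_{s}")
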
inversion.py ADDED
@@ -0,0 +1,104 @@
1
+ import torch.nn.functional as nnf
2
+ import torch
3
+ import numpy as np
4
+
5
+ from tqdm import tqdm
6
+ from torch.optim.adam import Adam
7
+ from PIL import Image
8
+
9
+ from generation import load_512
10
+ from p2p import register_attention_control
11
+
12
+
13
+ def null_optimization(solver,
14
+ latents,
15
+ guidance_scale,
16
+ num_inner_steps,
17
+ epsilon):
18
+ uncond_embeddings, cond_embeddings = solver.context.chunk(2)
19
+ uncond_embeddings_list = []
20
+ latent_cur = latents[-1]
21
+ bar = tqdm(total=num_inner_steps * solver.n_steps)
22
+ for i in range(solver.n_steps):
23
+ uncond_embeddings = uncond_embeddings.clone().detach()
24
+ uncond_embeddings.requires_grad = True
25
+ optimizer = Adam([uncond_embeddings], lr=1e-2 * (1. - i / 100.))
26
+ latent_prev = latents[len(latents) - i - 2]
27
+ t = solver.model.scheduler.timesteps[i]
28
+ with torch.no_grad():
29
+ noise_pred_cond = solver.get_noise_pred_single(latent_cur, t, cond_embeddings)
30
+ for j in range(num_inner_steps):
31
+ noise_pred_uncond = solver.get_noise_pred_single(latent_cur, t, uncond_embeddings)
32
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)
33
+ latents_prev_rec = solver.prev_step(noise_pred, t, latent_cur)
34
+ loss = nnf.mse_loss(latents_prev_rec, latent_prev)
35
+ optimizer.zero_grad()
36
+ loss.backward()
37
+ optimizer.step()
38
+ loss_item = loss.item()
39
+ bar.update()
40
+ if loss_item < epsilon + i * 2e-5:
41
+ break
42
+ for j in range(j + 1, num_inner_steps):
43
+ bar.update()
44
+ uncond_embeddings_list.append(uncond_embeddings[:1].detach())
45
+ with torch.no_grad():
46
+ context = torch.cat([uncond_embeddings, cond_embeddings])
47
+ noise_pred = solver.get_noise_pred(solver.model, latent_cur, t, guidance_scale, context)
48
+ latent_cur = solver.prev_step(noise_pred, t, latent_cur)
49
+ bar.close()
50
+ return uncond_embeddings_list
51
+
52
+
53
+ def invert(solver,
54
+ stop_step,
55
+ is_cons_inversion=False,
56
+ inv_guidance_scale=1,
57
+ nti_guidance_scale=8,
58
+ dynamic_guidance=False,
59
+ tau1=0.4,
60
+ tau2=0.6,
61
+ w_embed_dim=0,
62
+ image_path=None,
63
+ prompt='',
64
+ offsets=(0, 0, 0, 0),
65
+ do_nti=False,
66
+ do_npi=False,
67
+ num_inner_steps=10,
68
+ early_stop_epsilon=1e-5,
69
+ seed=0,
70
+ ):
71
+ solver.init_prompt(prompt)
72
+ uncond_embeddings, cond_embeddings = solver.context.chunk(2)
73
+ register_attention_control(solver.model, None)
74
+ if isinstance(image_path, list):
75
+ image_gt = [load_512(path, *offsets) for path in image_path]
76
+ elif isinstance(image_path, str):
77
+ image_gt = load_512(image_path, *offsets)
78
+ else:
79
+ image_gt = np.array(Image.fromarray(image_path).resize((512, 512)))
80
+
81
+ if is_cons_inversion:
82
+ image_rec, ddim_latents = solver.cons_inversion(image_gt,
83
+ w_embed_dim=w_embed_dim,
84
+ guidance_scale=inv_guidance_scale,
85
+ seed=seed,)
86
+ else:
87
+ image_rec, ddim_latents = solver.ddim_inversion(image_gt,
88
+ n_steps=stop_step,
89
+ guidance_scale=inv_guidance_scale,
90
+ dynamic_guidance=dynamic_guidance,
91
+ tau1=tau1, tau2=tau2,
92
+ w_embed_dim=w_embed_dim)
93
+ if do_nti:
94
+ print("Null-text optimization...")
95
+ uncond_embeddings = null_optimization(solver,
96
+ ddim_latents,
97
+ nti_guidance_scale,
98
+ num_inner_steps,
99
+ early_stop_epsilon)
100
+ elif do_npi:
101
+ uncond_embeddings = [cond_embeddings] * solver.n_steps
102
+ else:
103
+ uncond_embeddings = None
104
+ return (image_gt, image_rec), ddim_latents[-1], uncond_embeddings
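
Usage sketch for invert() (illustration only; assumes a generation.Generator instance named `solver` built as in app.py and a local image path):

    (image_gt, image_rec), latent, uncond_embeddings = invert(
        solver=solver,
        stop_step=50,
        is_cons_inversion=True,   # consistency inversion, as used by the demo
        w_embed_dim=512,
        inv_guidance_scale=0.0,
        image_path="examples/orig_1.jpg",
        prompt=["a photo of an owl"],
        seed=10500,
    )
    # Setting is_cons_inversion=False with do_nti=True falls back to DDIM inversion
    # refined by null-text optimization (null_optimization above).
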
p2p.py ADDED
@@ -0,0 +1,454 @@
1
+ import torch.nn.functional as nnf
2
+ import torch
3
+ import abc
4
+ import numpy as np
5
+ import seq_aligner
6
+
7
+ from typing import Optional, Union, Tuple, List, Callable, Dict
8
+
9
+ MAX_NUM_WORDS = 77
10
+ LOW_RESOURCE = False
11
+ NUM_DDIM_STEPS = 50
12
+ device = 'cuda'
13
+ tokenizer = None
14
+
15
+
16
+ # Different attention controllers
17
+ # ----------------------------------------------------------------------
18
+ class LocalBlend:
19
+
20
+ def get_mask(self, maps, alpha, use_pool, x_t):
21
+ k = 1
22
+ maps = (maps * alpha).sum(-1).mean(1)
23
+ if use_pool:
24
+ maps = nnf.max_pool2d(maps, (k * 2 + 1, k * 2 + 1), (1, 1), padding=(k, k))
25
+ mask = nnf.interpolate(maps, size=(x_t.shape[2:]))
26
+ mask = mask / mask.max(2, keepdims=True)[0].max(3, keepdims=True)[0]
27
+ mask = mask.gt(self.th[1 - int(use_pool)])
28
+ mask = mask[:1] + mask
29
+ return mask
30
+
31
+ def __call__(self, x_t, attention_store):
32
+ self.counter += 1
33
+ if self.counter > self.start_blend:
34
+
35
+ maps = attention_store["down_cross"][2:4] + attention_store["up_cross"][:3]
36
+ maps = [item.reshape(self.alpha_layers.shape[0], -1, 1, 16, 16, MAX_NUM_WORDS) for item in maps]
37
+ maps = torch.cat(maps, dim=1)
38
+ mask = self.get_mask(maps, self.alpha_layers, True, x_t)
39
+ if self.substruct_layers is not None:
40
+ maps_sub = ~self.get_mask(maps, self.substruct_layers, False, x_t)
41
+ mask = mask * maps_sub
42
+ mask = mask.float()
43
+ x_t = x_t[:1] + mask * (x_t - x_t[:1])
44
+ return x_t
45
+
46
+ def __init__(self, prompts: List[str], words: [List[List[str]]], substruct_words=None, start_blend=0.2,
47
+ th=(.3, .3)):
48
+ alpha_layers = torch.zeros(len(prompts), 1, 1, 1, 1, MAX_NUM_WORDS)
49
+ for i, (prompt, words_) in enumerate(zip(prompts, words)):
50
+ if type(words_) is str:
51
+ words_ = [words_]
52
+ for word in words_:
53
+ ind = get_word_inds(prompt, word, tokenizer)
54
+ alpha_layers[i, :, :, :, :, ind] = 1
55
+
56
+ if substruct_words is not None:
57
+ substruct_layers = torch.zeros(len(prompts), 1, 1, 1, 1, MAX_NUM_WORDS)
58
+ for i, (prompt, words_) in enumerate(zip(prompts, substruct_words)):
59
+ if type(words_) is str:
60
+ words_ = [words_]
61
+ for word in words_:
62
+ ind = get_word_inds(prompt, word, tokenizer)
63
+ substruct_layers[i, :, :, :, :, ind] = 1
64
+ self.substruct_layers = substruct_layers.to(device)
65
+ else:
66
+ self.substruct_layers = None
67
+ self.alpha_layers = alpha_layers.to(device)
68
+ self.start_blend = int(start_blend * NUM_DDIM_STEPS)
69
+ self.counter = 0
70
+ self.th = th
71
+
72
+
73
+ class EmptyControl:
74
+
75
+ def step_callback(self, x_t):
76
+ return x_t
77
+
78
+ def between_steps(self):
79
+ return
80
+
81
+ def __call__(self, attn, is_cross: bool, place_in_unet: str):
82
+ return attn
83
+
84
+
85
+ class AttentionControl(abc.ABC):
86
+
87
+ def step_callback(self, x_t):
88
+ return x_t
89
+
90
+ def between_steps(self):
91
+ return
92
+
93
+ @property
94
+ def num_uncond_att_layers(self):
95
+ return self.num_att_layers if LOW_RESOURCE else 0
96
+
97
+ @abc.abstractmethod
98
+ def forward(self, attn, is_cross: bool, place_in_unet: str):
99
+ raise NotImplementedError
100
+
101
+ def __call__(self, attn, is_cross: bool, place_in_unet: str):
102
+ if self.cur_att_layer >= self.num_uncond_att_layers:
103
+ if LOW_RESOURCE:
104
+ attn = self.forward(attn, is_cross, place_in_unet)
105
+ else:
106
+ h = attn.shape[0]
107
+ attn[h // 2:] = self.forward(attn[h // 2:], is_cross, place_in_unet)
108
+ self.cur_att_layer += 1
109
+ if self.cur_att_layer == self.num_att_layers + self.num_uncond_att_layers:
110
+ self.cur_att_layer = 0
111
+ self.cur_step += 1
112
+ self.between_steps()
113
+ return attn
114
+
115
+ def reset(self):
116
+ self.cur_step = 0
117
+ self.cur_att_layer = 0
118
+
119
+ def __init__(self):
120
+ self.cur_step = 0
121
+ self.num_att_layers = -1
122
+ self.cur_att_layer = 0
123
+
124
+
125
+ class SpatialReplace(EmptyControl):
126
+
127
+ def step_callback(self, x_t):
128
+ if self.cur_step < self.stop_inject:
129
+ b = x_t.shape[0]
130
+ x_t = x_t[:1].expand(b, *x_t.shape[1:])
131
+ return x_t
132
+
133
+ def __init__(self, stop_inject: float):
134
+ super(SpatialReplace, self).__init__()
135
+ self.stop_inject = int((1 - stop_inject) * NUM_DDIM_STEPS)
136
+
137
+
138
+ class AttentionStore(AttentionControl):
139
+
140
+ @staticmethod
141
+ def get_empty_store():
142
+ return {"down_cross": [], "mid_cross": [], "up_cross": [],
143
+ "down_self": [], "mid_self": [], "up_self": []}
144
+
145
+ def forward(self, attn, is_cross: bool, place_in_unet: str):
146
+ key = f"{place_in_unet}_{'cross' if is_cross else 'self'}"
147
+ if attn.shape[1] <= 32 ** 2: # avoid memory overhead
148
+ self.step_store[key].append(attn)
149
+ return attn
150
+
151
+ def between_steps(self):
152
+ if len(self.attention_store) == 0:
153
+ self.attention_store = self.step_store
154
+ else:
155
+ for key in self.attention_store:
156
+ for i in range(len(self.attention_store[key])):
157
+ self.attention_store[key][i] += self.step_store[key][i]
158
+ self.step_store = self.get_empty_store()
159
+
160
+ def get_average_attention(self):
161
+ average_attention = {key: [item / self.cur_step for item in self.attention_store[key]] for key in
162
+ self.attention_store}
163
+ return average_attention
164
+
165
+ def reset(self):
166
+ super(AttentionStore, self).reset()
167
+ self.step_store = self.get_empty_store()
168
+ self.attention_store = {}
169
+
170
+ def __init__(self):
171
+ super(AttentionStore, self).__init__()
172
+ self.step_store = self.get_empty_store()
173
+ self.attention_store = {}
174
+
175
+
176
+ class AttentionControlEdit(AttentionStore, abc.ABC):
177
+
178
+ def step_callback(self, x_t):
179
+ if self.local_blend is not None:
180
+ x_t = self.local_blend(x_t, self.attention_store)
181
+ return x_t
182
+
183
+ def replace_self_attention(self, attn_base, att_replace, place_in_unet):
184
+ if att_replace.shape[2] <= 32 ** 2:
185
+ attn_base = attn_base.unsqueeze(0).expand(att_replace.shape[0], *attn_base.shape)
186
+ return attn_base
187
+ else:
188
+ return att_replace
189
+
190
+ @abc.abstractmethod
191
+ def replace_cross_attention(self, attn_base, att_replace):
192
+ raise NotImplementedError
193
+
194
+ def forward(self, attn, is_cross: bool, place_in_unet: str):
195
+ super(AttentionControlEdit, self).forward(attn, is_cross, place_in_unet)
196
+ if is_cross or (self.num_self_replace[0] <= self.cur_step < self.num_self_replace[1]):
197
+ h = attn.shape[0] // (self.batch_size)
198
+ attn = attn.reshape(self.batch_size, h, *attn.shape[1:])
199
+ attn_base, attn_repalce = attn[0], attn[1:]
200
+ if is_cross:
201
+ alpha_words = self.cross_replace_alpha[self.cur_step]
202
+ attn_repalce_new = self.replace_cross_attention(attn_base, attn_repalce) * alpha_words + (
203
+ 1 - alpha_words) * attn_repalce
204
+ attn[1:] = attn_repalce_new
205
+ else:
206
+ attn[1:] = self.replace_self_attention(attn_base, attn_repalce, place_in_unet)
207
+ attn = attn.reshape(self.batch_size * h, *attn.shape[2:])
208
+ return attn
209
+
210
+ def __init__(self, prompts, num_steps: int,
211
+ cross_replace_steps: Union[float, Tuple[float, float], Dict[str, Tuple[float, float]]],
212
+ self_replace_steps: Union[float, Tuple[float, float]],
213
+ local_blend: Optional[LocalBlend]):
214
+ super(AttentionControlEdit, self).__init__()
215
+ self.batch_size = len(prompts)
216
+ self.cross_replace_alpha = get_time_words_attention_alpha(prompts, num_steps, cross_replace_steps,
217
+ tokenizer).to(device)
218
+ if type(self_replace_steps) is float:
219
+ self_replace_steps = 0, self_replace_steps
220
+ self.num_self_replace = int(num_steps * self_replace_steps[0]), int(num_steps * self_replace_steps[1])
221
+ self.local_blend = local_blend
222
+
223
+
224
+ class AttentionReplace(AttentionControlEdit):
225
+
226
+ def replace_cross_attention(self, attn_base, att_replace):
227
+ return torch.einsum('hpw,bwn->bhpn', attn_base, self.mapper)
228
+
229
+ def __init__(self, prompts, num_steps: int, cross_replace_steps: float, self_replace_steps: float,
230
+ local_blend: Optional[LocalBlend] = None):
231
+ super(AttentionReplace, self).__init__(prompts, num_steps, cross_replace_steps, self_replace_steps, local_blend)
232
+ self.mapper = seq_aligner.get_replacement_mapper(prompts, tokenizer).to(device)
233
+
234
+
235
+ class AttentionRefine(AttentionControlEdit):
236
+
237
+ def replace_cross_attention(self, attn_base, att_replace):
238
+ attn_base_replace = attn_base[:, :, self.mapper].permute(2, 0, 1, 3)
239
+ attn_replace = attn_base_replace * self.alphas + att_replace * (1 - self.alphas)
240
+ # attn_replace = attn_replace / attn_replace.sum(-1, keepdims=True)
241
+ return attn_replace
242
+
243
+ def __init__(self, prompts, num_steps: int, cross_replace_steps: float, self_replace_steps: float,
244
+ local_blend: Optional[LocalBlend] = None):
245
+ super(AttentionRefine, self).__init__(prompts, num_steps, cross_replace_steps, self_replace_steps, local_blend)
246
+ self.mapper, alphas = seq_aligner.get_refinement_mapper(prompts, tokenizer)
247
+ self.mapper, alphas = self.mapper.to(device), alphas.to(device)
248
+ self.alphas = alphas.reshape(alphas.shape[0], 1, 1, alphas.shape[1])
249
+
250
+
251
+ class AttentionReweight(AttentionControlEdit):
252
+
253
+ def replace_cross_attention(self, attn_base, att_replace):
254
+ if self.prev_controller is not None:
255
+ attn_base = self.prev_controller.replace_cross_attention(attn_base, att_replace)
256
+ attn_replace = attn_base[None, :, :, :] * self.equalizer[:, None, None, :]
257
+ # attn_replace = attn_replace / attn_replace.sum(-1, keepdims=True)
258
+ return attn_replace
259
+
260
+ def __init__(self, prompts, num_steps: int, cross_replace_steps: float, self_replace_steps: float, equalizer,
261
+ local_blend: Optional[LocalBlend] = None, controller: Optional[AttentionControlEdit] = None):
262
+ super(AttentionReweight, self).__init__(prompts, num_steps, cross_replace_steps, self_replace_steps,
263
+ local_blend)
264
+ self.equalizer = equalizer.to(device)
265
+ self.prev_controller = controller
266
+ self.attn = []
267
+ # ----------------------------------------------------------------------
268
+
269
+
270
+ # Attention controller during sampling
+ # ----------------------------------------------------------------------
+ def make_controller(prompts: List[str], is_replace_controller: bool, cross_replace_steps: Dict[str, float],
+                     self_replace_steps: float, blend_words=None, equilizer_params=None) -> AttentionControlEdit:
+     if blend_words is None:
+         lb = None
+     else:
+         lb = LocalBlend(prompts, blend_words, start_blend=0.0, th=(0.3, 0.3))
+     if is_replace_controller:
+         controller = AttentionReplace(prompts, NUM_DDIM_STEPS, cross_replace_steps=cross_replace_steps,
+                                       self_replace_steps=self_replace_steps, local_blend=lb)
+     else:
+         controller = AttentionRefine(prompts, NUM_DDIM_STEPS, cross_replace_steps=cross_replace_steps,
+                                      self_replace_steps=self_replace_steps, local_blend=lb)
+     if equilizer_params is not None:
+         eq = get_equalizer(prompts[1], equilizer_params["words"], equilizer_params["values"])
+         controller = AttentionReweight(prompts, NUM_DDIM_STEPS, cross_replace_steps=cross_replace_steps,
+                                        self_replace_steps=self_replace_steps, equalizer=eq, local_blend=lb,
+                                        controller=controller)
+     return controller
+ 
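A minimal usage sketch (illustrative only) of the factory above for a one-word swap, assuming the module-level tokenizer, device and NUM_DDIM_STEPS referenced earlier in p2p.py are already set up:

    prompts = ["a cat sitting on a car", "a tiger sitting on a car"]   # equal word count -> replacement edit
    controller = make_controller(prompts,
                                 is_replace_controller=True,
                                 cross_replace_steps={"default_": 0.8},
                                 self_replace_steps=0.4,
                                 blend_words=None,
                                 equilizer_params=None)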
+ def register_attention_control(model, controller):
+     def ca_forward(self, place_in_unet):
+         to_out = self.to_out
+         if type(to_out) is torch.nn.modules.container.ModuleList:
+             to_out = self.to_out[0]
+         else:
+             to_out = self.to_out
+ 
+         def forward(hidden_states, encoder_hidden_states=None, attention_mask=None, temb=None, ):
+             is_cross = encoder_hidden_states is not None
+ 
+             residual = hidden_states
+ 
+             if self.spatial_norm is not None:
+                 hidden_states = self.spatial_norm(hidden_states, temb)
+ 
+             input_ndim = hidden_states.ndim
+ 
+             if input_ndim == 4:
+                 batch_size, channel, height, width = hidden_states.shape
+                 hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+ 
+             batch_size, sequence_length, _ = (
+                 hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+             )
+             attention_mask = self.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+ 
+             if self.group_norm is not None:
+                 hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+ 
+             query = self.to_q(hidden_states)
+ 
+             if encoder_hidden_states is None:
+                 encoder_hidden_states = hidden_states
+             elif self.norm_cross:
+                 encoder_hidden_states = self.norm_encoder_hidden_states(encoder_hidden_states)
+ 
+             key = self.to_k(encoder_hidden_states)
+             value = self.to_v(encoder_hidden_states)
+ 
+             query = self.head_to_batch_dim(query)
+             key = self.head_to_batch_dim(key)
+             value = self.head_to_batch_dim(value)
+ 
+             attention_probs = self.get_attention_scores(query, key, attention_mask)
+             attention_probs = controller(attention_probs, is_cross, place_in_unet)
+ 
+             hidden_states = torch.bmm(attention_probs, value)
+             hidden_states = self.batch_to_head_dim(hidden_states)
+ 
+             # linear proj
+             hidden_states = to_out(hidden_states)
+ 
+             if input_ndim == 4:
+                 hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+ 
+             if self.residual_connection:
+                 hidden_states = hidden_states + residual
+ 
+             hidden_states = hidden_states / self.rescale_output_factor
+ 
+             return hidden_states
+ 
+         return forward
+ 
+     class DummyController:
+ 
+         def __call__(self, *args):
+             return args[0]
+ 
+         def __init__(self):
+             self.num_att_layers = 0
+ 
+     if controller is None:
+         controller = DummyController()
+ 
+     def register_recr(net_, count, place_in_unet):
+         if net_.__class__.__name__ == 'Attention':
+             net_.forward = ca_forward(net_, place_in_unet)
+             return count + 1
+         elif hasattr(net_, 'children'):
+             for net__ in net_.children():
+                 count = register_recr(net__, count, place_in_unet)
+         return count
+ 
+     cross_att_count = 0
+     sub_nets = model.unet.named_children()
+     for net in sub_nets:
+         if "down" in net[0]:
+             cross_att_count += register_recr(net[1], 0, "down")
+         elif "up" in net[0]:
+             cross_att_count += register_recr(net[1], 0, "up")
+         elif "mid" in net[0]:
+             cross_att_count += register_recr(net[1], 0, "mid")
+ 
+     controller.num_att_layers = cross_att_count
+ # ----------------------------------------------------------------------
+ 
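A usage sketch (illustrative only) of hooking a controller into a loaded model; `ldm_stable` is a placeholder name for whatever pipeline object with a `.unet` attribute the app constructs:

    controller = make_controller(prompts, is_replace_controller=True,
                                 cross_replace_steps={"default_": 0.8}, self_replace_steps=0.4)
    register_attention_control(ldm_stable, controller)   # patches every Attention module's forward
    # during sampling, the controller now observes (and can rewrite) each attention map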
+ # Other
+ # ----------------------------------------------------------------------
+ def get_equalizer(text: str, word_select: Union[int, Tuple[int, ...]], values: Union[List[float],
+                   Tuple[float, ...]]):
+     if type(word_select) is int or type(word_select) is str:
+         word_select = (word_select,)
+     equalizer = torch.ones(1, 77)
+ 
+     for word, val in zip(word_select, values):
+         inds = get_word_inds(text, word, tokenizer)
+         equalizer[:, inds] = val
+     return equalizer
+ 
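A small sketch (illustrative only) of building an equalizer that amplifies one word of the edited prompt, assuming the module-level tokenizer:

    eq = get_equalizer("a tiger sitting on a car", word_select=("tiger",), values=(4.0,))
    print(eq.shape)   # torch.Size([1, 77]); all ones except 4.0 at the token positions of "tiger"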
+ def get_time_words_attention_alpha(prompts, num_steps,
+                                    cross_replace_steps: Union[float, Dict[str, Tuple[float, float]]],
+                                    tokenizer, max_num_words=77):
+     if type(cross_replace_steps) is not dict:
+         cross_replace_steps = {"default_": cross_replace_steps}
+     if "default_" not in cross_replace_steps:
+         cross_replace_steps["default_"] = (0., 1.)
+     alpha_time_words = torch.zeros(num_steps + 1, len(prompts) - 1, max_num_words)
+     for i in range(len(prompts) - 1):
+         alpha_time_words = update_alpha_time_word(alpha_time_words, cross_replace_steps["default_"],
+                                                   i)
+     for key, item in cross_replace_steps.items():
+         if key != "default_":
+             inds = [get_word_inds(prompts[i], key, tokenizer) for i in range(1, len(prompts))]
+             for i, ind in enumerate(inds):
+                 if len(ind) > 0:
+                     alpha_time_words = update_alpha_time_word(alpha_time_words, item, i, ind)
+     alpha_time_words = alpha_time_words.reshape(num_steps + 1, len(prompts) - 1, 1, 1, max_num_words)
+     return alpha_time_words
+ 
+ def get_word_inds(text: str, word_place: int, tokenizer):
+     split_text = text.split(" ")
+     if type(word_place) is str:
+         word_place = [i for i, word in enumerate(split_text) if word_place == word]
+     elif type(word_place) is int:
+         word_place = [word_place]
+     out = []
+     if len(word_place) > 0:
+         words_encode = [tokenizer.decode([item]).strip("#") for item in tokenizer.encode(text)][1:-1]
+         cur_len, ptr = 0, 0
+ 
+         for i in range(len(words_encode)):
+             cur_len += len(words_encode[i])
+             if ptr in word_place:
+                 out.append(i + 1)
+             if cur_len >= len(split_text[ptr]):
+                 ptr += 1
+                 cur_len = 0
+     return np.array(out)
+ 
+ 
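A sketch (illustrative only) of what get_word_inds returns; the CLIPTokenizer checkpoint here is an assumption, not necessarily the one app.py loads:

    from transformers import CLIPTokenizer
    tok = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
    print(get_word_inds("a tiger sitting on a car", "tiger", tok))
    # e.g. array([2]): token 0 is BOS, token 1 is "a", token 2 is "tiger"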
+ def update_alpha_time_word(alpha, bounds: Union[float, Tuple[float, float]], prompt_ind: int,
+                            word_inds: Optional[torch.Tensor] = None):
+     if type(bounds) is float:
+         bounds = 0, bounds
+     start, end = int(bounds[0] * alpha.shape[0]), int(bounds[1] * alpha.shape[0])
+     if word_inds is None:
+         word_inds = torch.arange(alpha.shape[2])
+     alpha[: start, prompt_ind, word_inds] = 0
+     alpha[start: end, prompt_ind, word_inds] = 1
+     alpha[end:, prompt_ind, word_inds] = 0
+     return alpha
+ # ----------------------------------------------------------------------
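A self-contained sketch (illustrative only) of the alpha schedule these two helpers produce: cross-attention injection active for the first 80% of a 10-step schedule, for a single edited prompt:

    alpha = torch.zeros(10 + 1, 1, 77)                 # (num_steps + 1, len(prompts) - 1, max_num_words)
    alpha = update_alpha_time_word(alpha, 0.8, prompt_ind=0)
    print(alpha[:, 0, 0])                              # tensor([1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 0.])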
requirements.txt CHANGED
@@ -2,5 +2,7 @@ accelerate
  diffusers
  invisible_watermark
  torch
+ peft
  transformers
- xformers
+ xformers
+ ipython
seq_aligner.py ADDED
@@ -0,0 +1,181 @@
+ import torch
+ import numpy as np
+ 
+ 
+ class ScoreParams:
+ 
+     def __init__(self, gap, match, mismatch):
+         self.gap = gap
+         self.match = match
+         self.mismatch = mismatch
+ 
+     def mis_match_char(self, x, y):
+         if x != y:
+             return self.mismatch
+         else:
+             return self.match
+ 
+ 
+ def get_matrix(size_x, size_y, gap):
+     matrix = []
+     for i in range(len(size_x) + 1):
+         sub_matrix = []
+         for j in range(len(size_y) + 1):
+             sub_matrix.append(0)
+         matrix.append(sub_matrix)
+     for j in range(1, len(size_y) + 1):
+         matrix[0][j] = j * gap
+     for i in range(1, len(size_x) + 1):
+         matrix[i][0] = i * gap
+     return matrix
+ 
+ 
+ def get_matrix(size_x, size_y, gap):
+     matrix = np.zeros((size_x + 1, size_y + 1), dtype=np.int32)
+     matrix[0, 1:] = (np.arange(size_y) + 1) * gap
+     matrix[1:, 0] = (np.arange(size_x) + 1) * gap
+     return matrix
+ 
+ 
+ def get_traceback_matrix(size_x, size_y):
+     matrix = np.zeros((size_x + 1, size_y + 1), dtype=np.int32)
+     matrix[0, 1:] = 1
+     matrix[1:, 0] = 2
+     matrix[0, 0] = 4
+     return matrix
+ 
+ 
+ def global_align(x, y, score):
+     matrix = get_matrix(len(x), len(y), score.gap)
+     trace_back = get_traceback_matrix(len(x), len(y))
+     for i in range(1, len(x) + 1):
+         for j in range(1, len(y) + 1):
+             left = matrix[i, j - 1] + score.gap
+             up = matrix[i - 1, j] + score.gap
+             diag = matrix[i - 1, j - 1] + score.mis_match_char(x[i - 1], y[j - 1])
+             matrix[i, j] = max(left, up, diag)
+             if matrix[i, j] == left:
+                 trace_back[i, j] = 1
+             elif matrix[i, j] == up:
+                 trace_back[i, j] = 2
+             else:
+                 trace_back[i, j] = 3
+     return matrix, trace_back
+ 
+ 
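Note that the second, NumPy-based get_matrix shadows the earlier list-based version, so global_align runs Needleman-Wunsch on NumPy matrices. A tiny self-contained run on made-up token ids (illustrative only):

    x = [101, 7, 8, 102]        # e.g. BOS, "cat", "car", EOS (ids are made up)
    y = [101, 7, 9, 8, 102]     # same sequence with one extra token inserted
    score = ScoreParams(gap=0, match=1, mismatch=-1)
    matrix, trace_back = global_align(x, y, score)
    print(matrix[-1, -1])       # 4: four matched positions; the inserted token costs nothing since gap=0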
+ def get_aligned_sequences(x, y, trace_back):
+     x_seq = []
+     y_seq = []
+     i = len(x)
+     j = len(y)
+     mapper_y_to_x = []
+     while i > 0 or j > 0:
+         if trace_back[i, j] == 3:
+             x_seq.append(x[i - 1])
+             y_seq.append(y[j - 1])
+             i = i - 1
+             j = j - 1
+             mapper_y_to_x.append((j, i))
+         elif trace_back[i][j] == 1:
+             x_seq.append('-')
+             y_seq.append(y[j - 1])
+             j = j - 1
+             mapper_y_to_x.append((j, -1))
+         elif trace_back[i][j] == 2:
+             x_seq.append(x[i - 1])
+             y_seq.append('-')
+             i = i - 1
+         elif trace_back[i][j] == 4:
+             break
+     mapper_y_to_x.reverse()
+     return x_seq, y_seq, torch.tensor(mapper_y_to_x, dtype=torch.int64)
+ 
+ 
+ def get_mapper(x: str, y: str, tokenizer, max_len=77):
+     x_seq = tokenizer.encode(x)
+     y_seq = tokenizer.encode(y)
+     score = ScoreParams(0, 1, -1)
+     matrix, trace_back = global_align(x_seq, y_seq, score)
+     mapper_base = get_aligned_sequences(x_seq, y_seq, trace_back)[-1]
+     alphas = torch.ones(max_len)
+     alphas[: mapper_base.shape[0]] = mapper_base[:, 1].ne(-1).float()
+     mapper = torch.zeros(max_len, dtype=torch.int64)
+     mapper[:mapper_base.shape[0]] = mapper_base[:, 1]
+     mapper[mapper_base.shape[0]:] = len(y_seq) + torch.arange(max_len - len(y_seq))
+     return mapper, alphas
+ 
+ 
+ def get_refinement_mapper(prompts, tokenizer, max_len=77):
+     x_seq = prompts[0]
+     mappers, alphas = [], []
+     for i in range(1, len(prompts)):
+         mapper, alpha = get_mapper(x_seq, prompts[i], tokenizer, max_len)
+         mappers.append(mapper)
+         alphas.append(alpha)
+     return torch.stack(mappers), torch.stack(alphas)
+ 
+ 
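A usage sketch (illustrative only) of the refinement mapper that AttentionRefine consumes; the tokenizer checkpoint is an assumption:

    from transformers import CLIPTokenizer
    tok = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
    prompts = ["a cat sitting on a car", "a fluffy cat sitting on a car"]   # word added -> refinement edit
    mapper, alphas = get_refinement_mapper(prompts, tok)
    print(mapper.shape, alphas.shape)   # torch.Size([1, 77]) torch.Size([1, 77])
    # alphas is 0 at tokens that exist only in the edited prompt ("fluffy"), 1 where the base prompt aligns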
+ def get_word_inds(text: str, word_place: int, tokenizer):
+     split_text = text.split(" ")
+     if type(word_place) is str:
+         word_place = [i for i, word in enumerate(split_text) if word_place == word]
+     elif type(word_place) is int:
+         word_place = [word_place]
+     out = []
+     if len(word_place) > 0:
+         words_encode = [tokenizer.decode([item]).strip("#") for item in tokenizer.encode(text)][1:-1]
+         cur_len, ptr = 0, 0
+ 
+         for i in range(len(words_encode)):
+             cur_len += len(words_encode[i])
+             if ptr in word_place:
+                 out.append(i + 1)
+             if cur_len >= len(split_text[ptr]):
+                 ptr += 1
+                 cur_len = 0
+     return np.array(out)
+ 
+ 
+ def get_replacement_mapper_(x: str, y: str, tokenizer, max_len=77):
+     words_x = x.split(' ')
+     words_y = y.split(' ')
+     if len(words_x) != len(words_y):
+         raise ValueError(f"attention replacement edit can only be applied on prompts with the same length"
+                          f" but prompt A has {len(words_x)} words and prompt B has {len(words_y)} words.")
+     inds_replace = [i for i in range(len(words_y)) if words_y[i] != words_x[i]]
+     inds_source = [get_word_inds(x, i, tokenizer) for i in inds_replace]
+     inds_target = [get_word_inds(y, i, tokenizer) for i in inds_replace]
+     mapper = np.zeros((max_len, max_len))
+     i = j = 0
+     cur_inds = 0
+     while i < max_len and j < max_len:
+         if cur_inds < len(inds_source) and inds_source[cur_inds][0] == i:
+             inds_source_, inds_target_ = inds_source[cur_inds], inds_target[cur_inds]
+             if len(inds_source_) == len(inds_target_):
+                 mapper[inds_source_, inds_target_] = 1
+             else:
+                 ratio = 1 / len(inds_target_)
+                 for i_t in inds_target_:
+                     mapper[inds_source_, i_t] = ratio
+             cur_inds += 1
+             i += len(inds_source_)
+             j += len(inds_target_)
+         elif cur_inds < len(inds_source):
+             mapper[i, j] = 1
+             i += 1
+             j += 1
+         else:
+             mapper[j, j] = 1
+             i += 1
+             j += 1
+ 
+     return torch.from_numpy(mapper).float()
+ 
+ 
+ def get_replacement_mapper(prompts, tokenizer, max_len=77):
+     x_seq = prompts[0]
+     mappers = []
+     for i in range(1, len(prompts)):
+         mapper = get_replacement_mapper_(x_seq, prompts[i], tokenizer, max_len)
+         mappers.append(mapper)
+     return torch.stack(mappers)
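A usage sketch (illustrative only) of the replacement mapper that AttentionReplace consumes, again assuming a CLIP tokenizer:

    from transformers import CLIPTokenizer
    tok = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
    prompts = ["a cat sitting on a car", "a tiger sitting on a car"]   # same word count, one word swapped
    mapper = get_replacement_mapper(prompts, tok)
    print(mapper.shape)   # torch.Size([1, 77, 77]): near-identity, with mass moved from "cat" to "tiger"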