Linoy Tsaban committed
Commit d754544
1 Parent(s): d988f61

Update app.py


testing interm yielding

Files changed (1)
  1. app.py +378 -4
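
The change turns the editing loop into a generator: after each denoising step the current latent is decoded with the VAE and yielded, so the UI can display the partially denoised image while sampling is still running. Below is a minimal, hypothetical sketch of how such a generator can be streamed to a Gradio image output; the `edit_stream` wrapper, its argument list, and the component labels are illustrative assumptions rather than code from this repo.

import gradio as gr

from app import edit  # the generator defined in the diff below

def edit_stream(image, src_prompt, tar_prompt):
    # The real `edit` in app.py takes more parameters (steps, cfg scales, skip);
    # a full wiring would pass them through from additional input components.
    for interm in edit(image, src_prompt, tar_prompt):
        yield interm  # each yielded PIL image replaces the output as it arrives

demo = gr.Interface(
    fn=edit_stream,
    inputs=[gr.Image(type="pil"),
            gr.Textbox(label="source prompt"),
            gr.Textbox(label="target prompt")],
    outputs=gr.Image(label="denoising progress"),
)
demo.queue()   # streaming generator outputs requires the request queue
demo.launch()

Gradio treats a generator event handler as a stream: every yield updates the bound output component, which is what the intermediate `yield interm_img` calls in the diff are intended for.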
app.py CHANGED
@@ -5,10 +5,317 @@ from io import BytesIO
  from diffusers import StableDiffusionPipeline
  from diffusers import DDIMScheduler
  from utils import *
- from inversion_utils import *
+ # from inversion_utils import *
  from torch import autocast, inference_mode
  import re

+
+ ############################################################################################################################################################################
+ import torch
+ import os
+ from tqdm import tqdm
+ from PIL import Image, ImageDraw, ImageFont
+ from matplotlib import pyplot as plt
+ import torchvision.transforms as T
+ import os
+ import yaml
+ import numpy as np
+ import gradio as gr
+
+
+ def load_512(image_path, left=0, right=0, top=0, bottom=0, device=None):
+     if type(image_path) is str:
+         image = np.array(Image.open(image_path).convert('RGB'))[:, :, :3]
+     else:
+         image = image_path
+     h, w, c = image.shape
+     left = min(left, w-1)
+     right = min(right, w - left - 1)
+     top = min(top, h - left - 1)
+     bottom = min(bottom, h - top - 1)
+     image = image[top:h-bottom, left:w-right]
+     h, w, c = image.shape
+     if h < w:
+         offset = (w - h) // 2
+         image = image[:, offset:offset + h]
+     elif w < h:
+         offset = (h - w) // 2
+         image = image[offset:offset + w]
+     image = np.array(Image.fromarray(image).resize((512, 512)))
+     image = torch.from_numpy(image).float() / 127.5 - 1
+     image = image.permute(2, 0, 1).unsqueeze(0).to(device)
+
+     return image
+
+
+ def load_real_image(folder = "data/", img_name = None, idx = 0, img_size=512, device='cuda'):
+     from PIL import Image
+     from glob import glob
+     if img_name is not None:
+         path = os.path.join(folder, img_name)
+     else:
+         path = glob(folder + "*")[idx]
+
+     img = Image.open(path).resize((img_size,
+                                    img_size))
+
+     img = pil_to_tensor(img).to(device)
+
+     if img.shape[1]== 4:
+         img = img[:,:3,:,:]
+     return img
+
+ def mu_tilde(model, xt,x0, timestep):
+     "mu_tilde(x_t, x_0) DDPM paper eq. 7"
+     prev_timestep = timestep - model.scheduler.config.num_train_timesteps // model.scheduler.num_inference_steps
+     alpha_prod_t_prev = model.scheduler.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else model.scheduler.final_alpha_cumprod
+     alpha_t = model.scheduler.alphas[timestep]
+     beta_t = 1 - alpha_t
+     alpha_bar = model.scheduler.alphas_cumprod[timestep]
+     return ((alpha_prod_t_prev ** 0.5 * beta_t) / (1-alpha_bar)) * x0 + ((alpha_t**0.5 *(1-alpha_prod_t_prev)) / (1- alpha_bar))*xt
+
+ def sample_xts_from_x0(model, x0, num_inference_steps=50):
+     """
+     Samples from P(x_1:T|x_0)
+     """
+     # torch.manual_seed(43256465436)
+     alpha_bar = model.scheduler.alphas_cumprod
+     sqrt_one_minus_alpha_bar = (1-alpha_bar) ** 0.5
+     alphas = model.scheduler.alphas
+     betas = 1 - alphas
+     variance_noise_shape = (
+         num_inference_steps,
+         model.unet.in_channels,
+         model.unet.sample_size,
+         model.unet.sample_size)
+
+     timesteps = model.scheduler.timesteps.to(model.device)
+     t_to_idx = {int(v):k for k,v in enumerate(timesteps)}
+     xts = torch.zeros(variance_noise_shape).to(x0.device)
+     for t in reversed(timesteps):
+         idx = t_to_idx[int(t)]
+         xts[idx] = x0 * (alpha_bar[t] ** 0.5) + torch.randn_like(x0) * sqrt_one_minus_alpha_bar[t]
+     xts = torch.cat([xts, x0 ],dim = 0)
+
+     return xts
+
+ def encode_text(model, prompts):
+     text_input = model.tokenizer(
+         prompts,
+         padding="max_length",
+         max_length=model.tokenizer.model_max_length,
+         truncation=True,
+         return_tensors="pt",
+     )
+     with torch.no_grad():
+         text_encoding = model.text_encoder(text_input.input_ids.to(model.device))[0]
+     return text_encoding
+
+ def forward_step(model, model_output, timestep, sample):
+     next_timestep = min(model.scheduler.config.num_train_timesteps - 2,
+                         timestep + model.scheduler.config.num_train_timesteps // model.scheduler.num_inference_steps)
+
+     # 2. compute alphas, betas
+     alpha_prod_t = model.scheduler.alphas_cumprod[timestep]
+     # alpha_prod_t_next = self.scheduler.alphas_cumprod[next_timestep] if next_ltimestep >= 0 else self.scheduler.final_alpha_cumprod
+
+     beta_prod_t = 1 - alpha_prod_t
+
+     # 3. compute predicted original sample from predicted noise also called
+     # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
+     pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
+
+     # 5. TODO: simple noising implementation
+     next_sample = model.scheduler.add_noise(pred_original_sample,
+                                             model_output,
+                                             torch.LongTensor([next_timestep]))
+     return next_sample
+
+
+ def get_variance(model, timestep): #, prev_timestep):
+     prev_timestep = timestep - model.scheduler.config.num_train_timesteps // model.scheduler.num_inference_steps
+     alpha_prod_t = model.scheduler.alphas_cumprod[timestep]
+     alpha_prod_t_prev = model.scheduler.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else model.scheduler.final_alpha_cumprod
+     beta_prod_t = 1 - alpha_prod_t
+     beta_prod_t_prev = 1 - alpha_prod_t_prev
+     variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)
+     return variance
+
+ def inversion_forward_process(model, x0,
+                               etas = None,
+                               prog_bar = False,
+                               prompt = "",
+                               cfg_scale = 3.5,
+                               num_inference_steps=50, eps = None
+                               ):
+
+     if not prompt=="":
+         text_embeddings = encode_text(model, prompt)
+     uncond_embedding = encode_text(model, "")
+     timesteps = model.scheduler.timesteps.to(model.device)
+     variance_noise_shape = (
+         num_inference_steps,
+         model.unet.in_channels,
+         model.unet.sample_size,
+         model.unet.sample_size)
+     if etas is None or (type(etas) in [int, float] and etas == 0):
+         eta_is_zero = True
+         zs = None
+     else:
+         eta_is_zero = False
+         if type(etas) in [int, float]: etas = [etas]*model.scheduler.num_inference_steps
+         xts = sample_xts_from_x0(model, x0, num_inference_steps=num_inference_steps)
+         alpha_bar = model.scheduler.alphas_cumprod
+         zs = torch.zeros(size=variance_noise_shape, device=model.device)
+
+     t_to_idx = {int(v):k for k,v in enumerate(timesteps)}
+     xt = x0
+     op = tqdm(reversed(timesteps)) if prog_bar else reversed(timesteps)
+
+     for t in op:
+         idx = t_to_idx[int(t)]
+         # 1. predict noise residual
+         if not eta_is_zero:
+             xt = xts[idx][None]
+
+         with torch.no_grad():
+             out = model.unet.forward(xt, timestep = t, encoder_hidden_states = uncond_embedding)
+             if not prompt=="":
+                 cond_out = model.unet.forward(xt, timestep=t, encoder_hidden_states = text_embeddings)
+
+         if not prompt=="":
+             ## classifier free guidance
+             noise_pred = out.sample + cfg_scale * (cond_out.sample - out.sample)
+         else:
+             noise_pred = out.sample
+
+         if eta_is_zero:
+             # 2. compute more noisy image and set x_t -> x_t+1
+             xt = forward_step(model, noise_pred, t, xt)
+
+         else:
+             xtm1 = xts[idx+1][None]
+             # pred of x0
+             pred_original_sample = (xt - (1-alpha_bar[t]) ** 0.5 * noise_pred ) / alpha_bar[t] ** 0.5
+
+             # direction to xt
+             prev_timestep = t - model.scheduler.config.num_train_timesteps // model.scheduler.num_inference_steps
+             alpha_prod_t_prev = model.scheduler.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else model.scheduler.final_alpha_cumprod
+
+             variance = get_variance(model, t)
+             pred_sample_direction = (1 - alpha_prod_t_prev - etas[idx] * variance ) ** (0.5) * noise_pred
+
+             mu_xt = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction
+
+             z = (xtm1 - mu_xt ) / ( etas[idx] * variance ** 0.5 )
+             zs[idx] = z
+
+             # correction to avoid error accumulation
+             xtm1 = mu_xt + ( etas[idx] * variance ** 0.5 )*z
+             xts[idx+1] = xtm1
+
+     if not zs is None:
+         zs[-1] = torch.zeros_like(zs[-1])
+
+     return xt, zs, xts
+
+
+ def reverse_step(model, model_output, timestep, sample, eta = 0, variance_noise=None):
+     # 1. get previous step value (=t-1)
+     prev_timestep = timestep - model.scheduler.config.num_train_timesteps // model.scheduler.num_inference_steps
+     # 2. compute alphas, betas
+     alpha_prod_t = model.scheduler.alphas_cumprod[timestep]
+     alpha_prod_t_prev = model.scheduler.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else model.scheduler.final_alpha_cumprod
+     beta_prod_t = 1 - alpha_prod_t
+     # 3. compute predicted original sample from predicted noise also called
+     # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
+     pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
+     # 5. compute variance: "sigma_t(η)" -> see formula (16)
+     # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1)
+     # variance = self.scheduler._get_variance(timestep, prev_timestep)
+     variance = get_variance(model, timestep) #, prev_timestep)
+     std_dev_t = eta * variance ** (0.5)
+     # Take care of asymmetric reverse process (asyrp)
+     model_output_direction = model_output
+     # 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
+     # pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * model_output_direction
+     pred_sample_direction = (1 - alpha_prod_t_prev - eta * variance) ** (0.5) * model_output_direction
+     # 7. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
+     prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction
+     # 8. Add noise if eta > 0
+     if eta > 0:
+         if variance_noise is None:
+             variance_noise = torch.randn(model_output.shape, device=model.device)
+         sigma_z = eta * variance ** (0.5) * variance_noise
+         prev_sample = prev_sample + sigma_z
+
+     return prev_sample
+
+ def inversion_reverse_process(model,
+                               xT,
+                               etas = 0,
+                               prompts = "",
+                               cfg_scales = None,
+                               prog_bar = False,
+                               zs = None,
+                               controller=None,
+                               asyrp = False
+                               ):
+
+     batch_size = len(prompts)
+
+     cfg_scales_tensor = torch.Tensor(cfg_scales).view(-1,1,1,1).to(model.device)
+
+     text_embeddings = encode_text(model, prompts)
+     uncond_embedding = encode_text(model, [""] * batch_size)
+
+     if etas is None: etas = 0
+     if type(etas) in [int, float]: etas = [etas]*model.scheduler.num_inference_steps
+     assert len(etas) == model.scheduler.num_inference_steps
+     timesteps = model.scheduler.timesteps.to(model.device)
+
+     xt = xT.expand(batch_size, -1, -1, -1)
+     op = tqdm(timesteps[-zs.shape[0]:]) if prog_bar else timesteps[-zs.shape[0]:]
+
+     t_to_idx = {int(v):k for k,v in enumerate(timesteps[-zs.shape[0]:])}
+
+     for t in op:
+         idx = t_to_idx[int(t)]
+         ## Unconditional embedding
+         with torch.no_grad():
+             uncond_out = model.unet.forward(xt, timestep = t,
+                                             encoder_hidden_states = uncond_embedding)
+
+         ## Conditional embedding
+         if prompts:
+             with torch.no_grad():
+                 cond_out = model.unet.forward(xt, timestep = t,
+                                               encoder_hidden_states = text_embeddings)
+
+
+         z = zs[idx] if not zs is None else None
+         z = z.expand(batch_size, -1, -1, -1)
+         if prompts:
+             ## classifier free guidance
+             noise_pred = uncond_out.sample + cfg_scales_tensor * (cond_out.sample - uncond_out.sample)
+         else:
+             noise_pred = uncond_out.sample
+         # 2. compute less noisy image and set x_t -> x_t-1
+         xt = reverse_step(model, noise_pred, t, xt, eta = etas[idx], variance_noise = z)
+
+         # interm denoised img
+         with autocast("cuda"), inference_mode():
+             x0_dec = sd_pipe.vae.decode(1 / 0.18215 * xt).sample
+         if x0_dec.dim()<4:
+             x0_dec = x0_dec[None,:,:,:]
+         interm_img = image_grid(x0_dec)
+         yield interm_img
+
+         if controller is not None:
+             xt = controller.step_callback(xt)
+     return xt, zs
+ ############################################################################################################################################################################
+
  def invert(x0, prompt_src="", num_diffusion_steps=100, cfg_scale_src = 3.5, eta = 1):

      # inverts a real image according to Algorithm 1 in https://arxiv.org/pdf/2304.06140.pdf,
@@ -34,7 +341,7 @@ def invert(x0, prompt_src="", num_diffusion_steps=100, cfg_scale_src = 3.5, eta
  def sample(wt, zs, wts, prompt_tar="", cfg_scale_tar=15, skip=36, eta = 1):

      # reverse process (via Zs and wT)
-     w0, _ = inversion_reverse_process(sd_pipe, xT=wts[skip], etas=eta, prompts=[prompt_tar], cfg_scales=[cfg_scale_tar], prog_bar=True, zs=zs[skip:])
+     w0, _ = inversion_reverse_process(sd_pipe, xT=wts[skip], etas=eta, prompts=[prompt_tar], cfg_scales=[cfg_scale_tar], prog_bar=False, zs=zs[skip:])

      # vae decode image
      with autocast("cuda"), inference_mode():
@@ -91,9 +398,76 @@ def edit(input_image,
      wt, zs, wts = invert(x0 =x0 , prompt_src=src_prompt, num_diffusion_steps=steps, cfg_scale_src=src_cfg_scale)

      #
-     output = sample(wt, zs, wts, prompt_tar=tar_prompt)
+     xT=wts[skip]
+     etas=eta
+     prompts=[prompt_tar]
+     cfg_scales=[cfg_scale_tar]
+     prog_bar=False
+     zs=zs[skip:]
+
+
+     batch_size = len(prompts)
+
+     cfg_scales_tensor = torch.Tensor(cfg_scales).view(-1,1,1,1).to(sd_pipe.device)
+
+     text_embeddings = encode_text(sd_pipe, prompts)
+     uncond_embedding = encode_text(sd_pipe, [""] * batch_size)
+
+     if etas is None: etas = 0
+     if type(etas) in [int, float]: etas = [etas]*sd_pipe.scheduler.num_inference_steps
+     assert len(etas) == sd_pipe.scheduler.num_inference_steps
+     timesteps = sd_pipe.scheduler.timesteps.to(sd_pipe.device)
+
+     xt = xT.expand(batch_size, -1, -1, -1)
+     op = tqdm(timesteps[-zs.shape[0]:]) if prog_bar else timesteps[-zs.shape[0]:]
+
+     t_to_idx = {int(v):k for k,v in enumerate(timesteps[-zs.shape[0]:])}
+
+     for t in op:
+         idx = t_to_idx[int(t)]
+         ## Unconditional embedding
+         with torch.no_grad():
+             uncond_out = sd_pipe.unet.forward(xt, timestep = t,
+                                               encoder_hidden_states = uncond_embedding)
+
+         ## Conditional embedding
+         if prompts:
+             with torch.no_grad():
+                 cond_out = sd_pipe.unet.forward(xt, timestep = t,
+                                                 encoder_hidden_states = text_embeddings)
+
+
+         z = zs[idx] if not zs is None else None
+         z = z.expand(batch_size, -1, -1, -1)
+         if prompts:
+             ## classifier free guidance
+             noise_pred = uncond_out.sample + cfg_scales_tensor * (cond_out.sample - uncond_out.sample)
+         else:
+             noise_pred = uncond_out.sample
+         # 2. compute less noisy image and set x_t -> x_t-1
+         xt = reverse_step(sd_pipe, noise_pred, t, xt, eta = etas[idx], variance_noise = z)
+
+         # interm denoised img
+         with autocast("cuda"), inference_mode():
+             x0_dec = sd_pipe.vae.decode(1 / 0.18215 * xt).sample
+         if x0_dec.dim()<4:
+             x0_dec = x0_dec[None,:,:,:]
+         interm_img = image_grid(x0_dec)
+         yield interm_img
+
+     return interm_img
+
+     # # vae decode image
+     # with autocast("cuda"), inference_mode():
+     #     x0_dec = sd_pipe.vae.decode(1 / 0.18215 * w0).sample
+     # if x0_dec.dim()<4:
+     #     x0_dec = x0_dec[None,:,:,:]
+     # img = image_grid(x0_dec)
+     # return img
+
+     # output = sample(wt, zs, wts, prompt_tar=tar_prompt)

-     return output
+     # return output
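
For reference, the update that `reverse_step` applies (and that `inversion_forward_process` solves for the noise maps `zs`) is the stochastic DDIM step, formula (12) of https://arxiv.org/pdf/2010.02502.pdf, with `get_variance` supplying the per-step variance; a sketch in that notation, using the eta = 1 default of this app:

\hat{x}_0 = \frac{x_t - \sqrt{1-\bar\alpha_t}\,\epsilon_\theta(x_t)}{\sqrt{\bar\alpha_t}},
\qquad
\sigma_t^2 = \frac{1-\bar\alpha_{t-1}}{1-\bar\alpha_t}\left(1 - \frac{\bar\alpha_t}{\bar\alpha_{t-1}}\right),
\qquad
x_{t-1} = \sqrt{\bar\alpha_{t-1}}\,\hat{x}_0 + \sqrt{1-\bar\alpha_{t-1}-\eta\,\sigma_t^2}\;\epsilon_\theta(x_t) + \eta\,\sigma_t\,z_t

During inversion, `sample_xts_from_x0` draws x_1..x_T directly from q(x_t | x_0) and the update is solved for z_t at every step; during editing, the same update is replayed with the stored z_t, and after each step this commit additionally decodes the current x_t with the VAE and yields it as an intermediate preview.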