Linoy Tsaban committed
Commit 3e5c24d
1 Parent(s): d754544

Update app.py

testing interm yielding

Files changed (1)
  1. app.py +8 -314
app.py CHANGED
@@ -5,317 +5,11 @@ from io import BytesIO
 from diffusers import StableDiffusionPipeline
 from diffusers import DDIMScheduler
 from utils import *
-# from inversion_utils import *
+from inversion_utils import *
 from torch import autocast, inference_mode
 import re
 
 
-############################################################################################################################################################################
-import torch
-import os
-from tqdm import tqdm
-from PIL import Image, ImageDraw, ImageFont
-from matplotlib import pyplot as plt
-import torchvision.transforms as T
-import os
-import yaml
-import numpy as np
-import gradio as gr
-
-
-def load_512(image_path, left=0, right=0, top=0, bottom=0, device=None):
-    if type(image_path) is str:
-        image = np.array(Image.open(image_path).convert('RGB'))[:, :, :3]
-    else:
-        image = image_path
-    h, w, c = image.shape
-    left = min(left, w-1)
-    right = min(right, w - left - 1)
-    top = min(top, h - left - 1)
-    bottom = min(bottom, h - top - 1)
-    image = image[top:h-bottom, left:w-right]
-    h, w, c = image.shape
-    if h < w:
-        offset = (w - h) // 2
-        image = image[:, offset:offset + h]
-    elif w < h:
-        offset = (h - w) // 2
-        image = image[offset:offset + w]
-    image = np.array(Image.fromarray(image).resize((512, 512)))
-    image = torch.from_numpy(image).float() / 127.5 - 1
-    image = image.permute(2, 0, 1).unsqueeze(0).to(device)
-
-    return image
-
-
-def load_real_image(folder = "data/", img_name = None, idx = 0, img_size=512, device='cuda'):
-    from PIL import Image
-    from glob import glob
-    if img_name is not None:
-        path = os.path.join(folder, img_name)
-    else:
-        path = glob(folder + "*")[idx]
-
-    img = Image.open(path).resize((img_size,
-                                   img_size))
-
-    img = pil_to_tensor(img).to(device)
-
-    if img.shape[1] == 4:
-        img = img[:, :3, :, :]
-    return img
-
-def mu_tilde(model, xt, x0, timestep):
-    "mu_tilde(x_t, x_0) DDPM paper eq. 7"
-    prev_timestep = timestep - model.scheduler.config.num_train_timesteps // model.scheduler.num_inference_steps
-    alpha_prod_t_prev = model.scheduler.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else model.scheduler.final_alpha_cumprod
-    alpha_t = model.scheduler.alphas[timestep]
-    beta_t = 1 - alpha_t
-    alpha_bar = model.scheduler.alphas_cumprod[timestep]
-    return ((alpha_prod_t_prev ** 0.5 * beta_t) / (1-alpha_bar)) * x0 + ((alpha_t**0.5 *(1-alpha_prod_t_prev)) / (1- alpha_bar))*xt
-
-def sample_xts_from_x0(model, x0, num_inference_steps=50):
-    """
-    Samples from P(x_1:T|x_0)
-    """
-    # torch.manual_seed(43256465436)
-    alpha_bar = model.scheduler.alphas_cumprod
-    sqrt_one_minus_alpha_bar = (1-alpha_bar) ** 0.5
-    alphas = model.scheduler.alphas
-    betas = 1 - alphas
-    variance_noise_shape = (
-        num_inference_steps,
-        model.unet.in_channels,
-        model.unet.sample_size,
-        model.unet.sample_size)
-
-    timesteps = model.scheduler.timesteps.to(model.device)
-    t_to_idx = {int(v):k for k,v in enumerate(timesteps)}
-    xts = torch.zeros(variance_noise_shape).to(x0.device)
-    for t in reversed(timesteps):
-        idx = t_to_idx[int(t)]
-        xts[idx] = x0 * (alpha_bar[t] ** 0.5) + torch.randn_like(x0) * sqrt_one_minus_alpha_bar[t]
-    xts = torch.cat([xts, x0], dim=0)
-
-    return xts
-
-def encode_text(model, prompts):
-    text_input = model.tokenizer(
-        prompts,
-        padding="max_length",
-        max_length=model.tokenizer.model_max_length,
-        truncation=True,
-        return_tensors="pt",
-    )
-    with torch.no_grad():
-        text_encoding = model.text_encoder(text_input.input_ids.to(model.device))[0]
-    return text_encoding
-
-def forward_step(model, model_output, timestep, sample):
-    next_timestep = min(model.scheduler.config.num_train_timesteps - 2,
-                        timestep + model.scheduler.config.num_train_timesteps // model.scheduler.num_inference_steps)
-
-    # 2. compute alphas, betas
-    alpha_prod_t = model.scheduler.alphas_cumprod[timestep]
-    # alpha_prod_t_next = self.scheduler.alphas_cumprod[next_timestep] if next_timestep >= 0 else self.scheduler.final_alpha_cumprod
-
-    beta_prod_t = 1 - alpha_prod_t
-
-    # 3. compute predicted original sample from predicted noise also called
-    # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
-    pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
-
-    # 5. TODO: simple noising implementation
-    next_sample = model.scheduler.add_noise(pred_original_sample,
-                                            model_output,
-                                            torch.LongTensor([next_timestep]))
-    return next_sample
-
-
-def get_variance(model, timestep): #, prev_timestep):
-    prev_timestep = timestep - model.scheduler.config.num_train_timesteps // model.scheduler.num_inference_steps
-    alpha_prod_t = model.scheduler.alphas_cumprod[timestep]
-    alpha_prod_t_prev = model.scheduler.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else model.scheduler.final_alpha_cumprod
-    beta_prod_t = 1 - alpha_prod_t
-    beta_prod_t_prev = 1 - alpha_prod_t_prev
-    variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)
-    return variance
-
-def inversion_forward_process(model, x0,
-                              etas = None,
-                              prog_bar = False,
-                              prompt = "",
-                              cfg_scale = 3.5,
-                              num_inference_steps=50, eps = None
-                              ):
-
-    if not prompt=="":
-        text_embeddings = encode_text(model, prompt)
-    uncond_embedding = encode_text(model, "")
-    timesteps = model.scheduler.timesteps.to(model.device)
-    variance_noise_shape = (
-        num_inference_steps,
-        model.unet.in_channels,
-        model.unet.sample_size,
-        model.unet.sample_size)
-    if etas is None or (type(etas) in [int, float] and etas == 0):
-        eta_is_zero = True
-        zs = None
-    else:
-        eta_is_zero = False
-        if type(etas) in [int, float]: etas = [etas]*model.scheduler.num_inference_steps
-        xts = sample_xts_from_x0(model, x0, num_inference_steps=num_inference_steps)
-        alpha_bar = model.scheduler.alphas_cumprod
-        zs = torch.zeros(size=variance_noise_shape, device=model.device)
-
-    t_to_idx = {int(v):k for k,v in enumerate(timesteps)}
-    xt = x0
-    op = tqdm(reversed(timesteps)) if prog_bar else reversed(timesteps)
-
-    for t in op:
-        idx = t_to_idx[int(t)]
-        # 1. predict noise residual
-        if not eta_is_zero:
-            xt = xts[idx][None]
-
-        with torch.no_grad():
-            out = model.unet.forward(xt, timestep = t, encoder_hidden_states = uncond_embedding)
-            if not prompt=="":
-                cond_out = model.unet.forward(xt, timestep=t, encoder_hidden_states = text_embeddings)
-
-        if not prompt=="":
-            ## classifier free guidance
-            noise_pred = out.sample + cfg_scale * (cond_out.sample - out.sample)
-        else:
-            noise_pred = out.sample
-
-        if eta_is_zero:
-            # 2. compute more noisy image and set x_t -> x_t+1
-            xt = forward_step(model, noise_pred, t, xt)
-
-        else:
-            xtm1 = xts[idx+1][None]
-            # pred of x0
-            pred_original_sample = (xt - (1-alpha_bar[t]) ** 0.5 * noise_pred) / alpha_bar[t] ** 0.5
-
-            # direction to xt
-            prev_timestep = t - model.scheduler.config.num_train_timesteps // model.scheduler.num_inference_steps
-            alpha_prod_t_prev = model.scheduler.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else model.scheduler.final_alpha_cumprod
-
-            variance = get_variance(model, t)
-            pred_sample_direction = (1 - alpha_prod_t_prev - etas[idx] * variance) ** (0.5) * noise_pred
-
-            mu_xt = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction
-
-            z = (xtm1 - mu_xt) / (etas[idx] * variance ** 0.5)
-            zs[idx] = z
-
-            # correction to avoid error accumulation
-            xtm1 = mu_xt + (etas[idx] * variance ** 0.5) * z
-            xts[idx+1] = xtm1
-
-    if not zs is None:
-        zs[-1] = torch.zeros_like(zs[-1])
-
-    return xt, zs, xts
-
-
-def reverse_step(model, model_output, timestep, sample, eta = 0, variance_noise=None):
-    # 1. get previous step value (=t-1)
-    prev_timestep = timestep - model.scheduler.config.num_train_timesteps // model.scheduler.num_inference_steps
-    # 2. compute alphas, betas
-    alpha_prod_t = model.scheduler.alphas_cumprod[timestep]
-    alpha_prod_t_prev = model.scheduler.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else model.scheduler.final_alpha_cumprod
-    beta_prod_t = 1 - alpha_prod_t
-    # 3. compute predicted original sample from predicted noise also called
-    # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
-    pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
-    # 5. compute variance: "sigma_t(η)" -> see formula (16)
-    # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1)
-    # variance = self.scheduler._get_variance(timestep, prev_timestep)
-    variance = get_variance(model, timestep) #, prev_timestep)
-    std_dev_t = eta * variance ** (0.5)
-    # Take care of asymmetric reverse process (asyrp)
-    model_output_direction = model_output
-    # 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
-    # pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * model_output_direction
-    pred_sample_direction = (1 - alpha_prod_t_prev - eta * variance) ** (0.5) * model_output_direction
-    # 7. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
-    prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction
-    # 8. Add noise if eta > 0
-    if eta > 0:
-        if variance_noise is None:
-            variance_noise = torch.randn(model_output.shape, device=model.device)
-        sigma_z = eta * variance ** (0.5) * variance_noise
-        prev_sample = prev_sample + sigma_z
-
-    return prev_sample
-
-def inversion_reverse_process(model,
-                              xT,
-                              etas = 0,
-                              prompts = "",
-                              cfg_scales = None,
-                              prog_bar = False,
-                              zs = None,
-                              controller=None,
-                              asyrp = False
-                              ):
-
-    batch_size = len(prompts)
-
-    cfg_scales_tensor = torch.Tensor(cfg_scales).view(-1,1,1,1).to(model.device)
-
-    text_embeddings = encode_text(model, prompts)
-    uncond_embedding = encode_text(model, [""] * batch_size)
-
-    if etas is None: etas = 0
-    if type(etas) in [int, float]: etas = [etas]*model.scheduler.num_inference_steps
-    assert len(etas) == model.scheduler.num_inference_steps
-    timesteps = model.scheduler.timesteps.to(model.device)
-
-    xt = xT.expand(batch_size, -1, -1, -1)
-    op = tqdm(timesteps[-zs.shape[0]:]) if prog_bar else timesteps[-zs.shape[0]:]
-
-    t_to_idx = {int(v):k for k,v in enumerate(timesteps[-zs.shape[0]:])}
-
-    for t in op:
-        idx = t_to_idx[int(t)]
-        ## Unconditional embedding
-        with torch.no_grad():
-            uncond_out = model.unet.forward(xt, timestep = t,
-                                            encoder_hidden_states = uncond_embedding)
-
-        ## Conditional embedding
-        if prompts:
-            with torch.no_grad():
-                cond_out = model.unet.forward(xt, timestep = t,
-                                              encoder_hidden_states = text_embeddings)
-
-        z = zs[idx] if not zs is None else None
-        z = z.expand(batch_size, -1, -1, -1)
-        if prompts:
-            ## classifier free guidance
-            noise_pred = uncond_out.sample + cfg_scales_tensor * (cond_out.sample - uncond_out.sample)
-        else:
-            noise_pred = uncond_out.sample
-        # 2. compute less noisy image and set x_t -> x_t-1
-        xt = reverse_step(model, noise_pred, t, xt, eta = etas[idx], variance_noise = z)
-
-        # interm denoised img
-        with autocast("cuda"), inference_mode():
-            x0_dec = sd_pipe.vae.decode(1 / 0.18215 * xt).sample
-        if x0_dec.dim() < 4:
-            x0_dec = x0_dec[None, :, :, :]
-        interm_img = image_grid(x0_dec)
-        yield interm_img
-
-        if controller is not None:
-            xt = controller.step_callback(xt)
-    return xt, zs
-############################################################################################################################################################################
-
 def invert(x0, prompt_src="", num_diffusion_steps=100, cfg_scale_src = 3.5, eta = 1):
 
     # inverts a real image according to Algorithm 1 in https://arxiv.org/pdf/2304.06140.pdf,
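Aside: both denoising loops in the removed block combine the unconditional and conditional U-Net predictions via classifier-free guidance, extrapolating from the former toward the latter. A minimal, self-contained illustration of just that combination, with toy tensors standing in for the real U-Net outputs (the names `uncond` and `cond` are placeholders, not from app.py):

import torch

uncond = torch.zeros(1, 4, 64, 64)  # stand-in for uncond_out.sample
cond = torch.ones(1, 4, 64, 64)     # stand-in for cond_out.sample
cfg_scale = 3.5

# classifier-free guidance: start from the unconditional prediction and
# move toward (and past) the conditional one, scaled by cfg_scale
noise_pred = uncond + cfg_scale * (cond - uncond)
assert torch.allclose(noise_pred, torch.full_like(cond, cfg_scale))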
@@ -408,15 +102,15 @@ def edit(input_image,
 
     batch_size = len(prompts)
 
-    cfg_scales_tensor = torch.Tensor(cfg_scales).view(-1,1,1,1).to(model.device)
+    cfg_scales_tensor = torch.Tensor(cfg_scales).view(-1,1,1,1).to(sd_pipe.device)
 
     text_embeddings = encode_text(model, prompts)
     uncond_embedding = encode_text(model, [""] * batch_size)
 
     if etas is None: etas = 0
-    if type(etas) in [int, float]: etas = [etas]*model.scheduler.num_inference_steps
-    assert len(etas) == model.scheduler.num_inference_steps
-    timesteps = model.scheduler.timesteps.to(model.device)
+    if type(etas) in [int, float]: etas = [etas]*sd_pipe.scheduler.num_inference_steps
+    assert len(etas) == sd_pipe.scheduler.num_inference_steps
+    timesteps = sd_pipe.scheduler.timesteps.to(sd_pipe.device)
 
     xt = xT.expand(batch_size, -1, -1, -1)
     op = tqdm(timesteps[-zs.shape[0]:]) if prog_bar else timesteps[-zs.shape[0]:]
@@ -427,13 +121,13 @@ def edit(input_image,
         idx = t_to_idx[int(t)]
         ## Unconditional embedding
         with torch.no_grad():
-            uncond_out = model.unet.forward(xt, timestep = t,
+            uncond_out = sd_pipe.unet.forward(xt, timestep = t,
                                             encoder_hidden_states = uncond_embedding)
 
         ## Conditional embedding
         if prompts:
             with torch.no_grad():
-                cond_out = model.unet.forward(xt, timestep = t,
+                cond_out = sd_pipe.unet.forward(xt, timestep = t,
                                               encoder_hidden_states = text_embeddings)
 
 
@@ -445,7 +139,7 @@ def edit(input_image,
         else:
             noise_pred = uncond_out.sample
         # 2. compute less noisy image and set x_t -> x_t-1
-        xt = reverse_step(model, noise_pred, t, xt, eta = etas[idx], variance_noise = z)
+        xt = reverse_step(sd_pipe, noise_pred, t, xt, eta = etas[idx], variance_noise = z)
 
         # interm denoised img
         with autocast("cuda"), inference_mode():
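For context on the commit message "testing interm yielding": the `yield interm_img` inside the denoising loop turns the edit function into a Python generator, which Gradio can consume to stream each intermediate decode to the UI instead of waiting for the final image. A minimal sketch of that pattern, assuming a Gradio version with generator support; the names here (`generate`, the text stand-in for the decoded frame) are illustrative, not from app.py:

import time
import gradio as gr

def generate(prompt, steps=5):
    # Stand-in for the denoising loop: yield a preview after each step
    # instead of returning only the final result.
    for i in range(steps):
        time.sleep(0.5)  # pretend to run one reverse_step + VAE decode
        yield f"{prompt}: step {i + 1}/{steps}"  # app.py yields a PIL image here

# Gradio re-renders the output component on every yield when fn is a generator.
demo = gr.Interface(fn=generate, inputs="text", outputs="text")
demo.launch()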
 