Linoy Tsaban committed
Commit 8cd26eb (1 parent: 406e2d8)

Update app.py

Files changed (1): app.py (+49, -56)
app.py CHANGED
@@ -92,77 +92,70 @@ def edit(input_image,
     # invert and retrieve noise maps and latent
     wt, zs, wts = invert(x0 =x0 , prompt_src=src_prompt, num_diffusion_steps=steps, cfg_scale_src=cfg_scale_src)
 
-    #
-    xT=wts[skip]
-    etas=1.0
-    prompts=[tar_prompt]
-    cfg_scales=[cfg_scale_tar]
-    prog_bar=False
-    zs=zs[skip:]
+    # #
+    # xT=wts[skip]
+    # etas=1.0
+    # prompts=[tar_prompt]
+    # cfg_scales=[cfg_scale_tar]
+    # prog_bar=False
+    # zs=zs[skip:]
 
 
-    batch_size = len(prompts)
+    # batch_size = len(prompts)
 
-    cfg_scales_tensor = torch.Tensor(cfg_scales).view(-1,1,1,1).to(sd_pipe.device)
+    # cfg_scales_tensor = torch.Tensor(cfg_scales).view(-1,1,1,1).to(sd_pipe.device)
 
-    text_embeddings = encode_text(sd_pipe, prompts)
-    uncond_embedding = encode_text(sd_pipe, [""] * batch_size)
+    # text_embeddings = encode_text(sd_pipe, prompts)
+    # uncond_embedding = encode_text(sd_pipe, [""] * batch_size)
 
-    if etas is None: etas = 0
-    if type(etas) in [int, float]: etas = [etas]*sd_pipe.scheduler.num_inference_steps
-    assert len(etas) == sd_pipe.scheduler.num_inference_steps
-    timesteps = sd_pipe.scheduler.timesteps.to(sd_pipe.device)
+    # if etas is None: etas = 0
+    # if type(etas) in [int, float]: etas = [etas]*sd_pipe.scheduler.num_inference_steps
+    # assert len(etas) == sd_pipe.scheduler.num_inference_steps
+    # timesteps = sd_pipe.scheduler.timesteps.to(sd_pipe.device)
 
-    xt = xT.expand(batch_size, -1, -1, -1)
-    op = tqdm(timesteps[-zs.shape[0]:]) if prog_bar else timesteps[-zs.shape[0]:]
+    # xt = xT.expand(batch_size, -1, -1, -1)
+    # op = tqdm(timesteps[-zs.shape[0]:]) if prog_bar else timesteps[-zs.shape[0]:]
 
-    t_to_idx = {int(v):k for k,v in enumerate(timesteps[-zs.shape[0]:])}
+    # t_to_idx = {int(v):k for k,v in enumerate(timesteps[-zs.shape[0]:])}
 
-    for t in op:
-        idx = t_to_idx[int(t)]
-        ## Unconditional embedding
-        with torch.no_grad():
-            uncond_out = sd_pipe.unet.forward(xt, timestep = t,
-                                              encoder_hidden_states = uncond_embedding)
+    # for t in op:
+    #     idx = t_to_idx[int(t)]
+    #     ## Unconditional embedding
+    #     with torch.no_grad():
+    #         uncond_out = sd_pipe.unet.forward(xt, timestep = t,
+    #                                           encoder_hidden_states = uncond_embedding)
 
-        ## Conditional embedding
-        if prompts:
-            with torch.no_grad():
-                cond_out = sd_pipe.unet.forward(xt, timestep = t,
-                                                encoder_hidden_states = text_embeddings)
+    #     ## Conditional embedding
+    #     if prompts:
+    #         with torch.no_grad():
+    #             cond_out = sd_pipe.unet.forward(xt, timestep = t,
+    #                                             encoder_hidden_states = text_embeddings)
 
 
-        z = zs[idx] if not zs is None else None
-        z = z.expand(batch_size, -1, -1, -1)
-        if prompts:
-            ## classifier free guidance
-            noise_pred = uncond_out.sample + cfg_scales_tensor * (cond_out.sample - uncond_out.sample)
-        else:
-            noise_pred = uncond_out.sample
-        # 2. compute less noisy image and set x_t -> x_t-1
-        xt = reverse_step(sd_pipe, noise_pred, t, xt, eta = etas[idx], variance_noise = z)
-
-        # interm denoised img
-        with autocast("cuda"), inference_mode():
-            x0_dec = sd_pipe.vae.decode(1 / 0.18215 * xt).sample
-        if x0_dec.dim()<4:
-            x0_dec = x0_dec[None,:,:,:]
-        interm_img = image_grid(x0_dec)
-        yield interm_img
+    #     z = zs[idx] if not zs is None else None
+    #     z = z.expand(batch_size, -1, -1, -1)
+    #     if prompts:
+    #         ## classifier free guidance
+    #         noise_pred = uncond_out.sample + cfg_scales_tensor * (cond_out.sample - uncond_out.sample)
+    #     else:
+    #         noise_pred = uncond_out.sample
+    #     # 2. compute less noisy image and set x_t -> x_t-1
+    #     xt = reverse_step(sd_pipe, noise_pred, t, xt, eta = etas[idx], variance_noise = z)
+    #
+    #     # interm denoised img
+    #     with autocast("cuda"), inference_mode():
+    #         x0_dec = sd_pipe.vae.decode(1 / 0.18215 * xt).sample
+    #     if x0_dec.dim()<4:
+    #         x0_dec = x0_dec[None,:,:,:]
+    #     interm_img = image_grid(x0_dec)
+    #     yield interm_img
 
-    yield interm_img
+    # yield interm_img
 
-    # # vae decode image
-    # with autocast("cuda"), inference_mode():
-    #     x0_dec = sd_pipe.vae.decode(1 / 0.18215 * w0).sample
-    # if x0_dec.dim()<4:
-    #     x0_dec = x0_dec[None,:,:,:]
-    # img = image_grid(x0_dec)
-    # return img
 
-    # output = sample(wt, zs, wts, prompt_tar=tar_prompt)
+    output = sample(wt, zs, wts, prompt_tar=tar_prompt, cfg_scale_tar=cfg_scale_tar, skip=skip)
 
-    # return output
+    return output
 
 
 
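The body of sample() is not part of this diff. As a rough sketch, assuming it simply packages the loop that the commit comments out above, it might look like the following; sd_pipe, encode_text, reverse_step, and image_grid are helpers defined elsewhere in app.py, and everything else is reconstructed from the removed lines rather than taken from the real implementation:

# Hypothetical sketch of sample(), reconstructed from the commented-out loop.
# Not the actual implementation from this repository. Assumes a module-level
# Stable Diffusion pipeline `sd_pipe` plus the encode_text, reverse_step, and
# image_grid helpers defined elsewhere in app.py.
import torch
from torch import autocast, inference_mode
from tqdm import tqdm

def sample(wt, zs, wts, prompt_tar, cfg_scale_tar, skip, eta=1.0, prog_bar=False):
    # Resume the reverse diffusion from the inverted latent at step `skip`,
    # using the noise maps zs recovered during inversion.
    xT = wts[skip]
    zs = zs[skip:]

    prompts = [prompt_tar]
    batch_size = len(prompts)
    cfg_scales_tensor = torch.Tensor([cfg_scale_tar]).view(-1, 1, 1, 1).to(sd_pipe.device)

    text_embeddings = encode_text(sd_pipe, prompts)
    uncond_embedding = encode_text(sd_pipe, [""] * batch_size)

    etas = [eta] * sd_pipe.scheduler.num_inference_steps
    timesteps = sd_pipe.scheduler.timesteps.to(sd_pipe.device)

    xt = xT.expand(batch_size, -1, -1, -1)
    op = tqdm(timesteps[-zs.shape[0]:]) if prog_bar else timesteps[-zs.shape[0]:]
    t_to_idx = {int(v): k for k, v in enumerate(timesteps[-zs.shape[0]:])}

    for t in op:
        idx = t_to_idx[int(t)]
        # Unconditional and conditional noise predictions from the UNet.
        with torch.no_grad():
            uncond_out = sd_pipe.unet(xt, timestep=t, encoder_hidden_states=uncond_embedding)
            cond_out = sd_pipe.unet(xt, timestep=t, encoder_hidden_states=text_embeddings)

        # Classifier-free guidance: shift the prediction toward the target prompt.
        noise_pred = uncond_out.sample + cfg_scales_tensor * (cond_out.sample - uncond_out.sample)

        # One reverse step x_t -> x_{t-1}, injecting the stored noise map.
        z = zs[idx].expand(batch_size, -1, -1, -1)
        xt = reverse_step(sd_pipe, noise_pred, t, xt, eta=etas[idx], variance_noise=z)

    # Decode the final latent with the VAE; unlike the removed loop, this
    # sketch returns only the final image instead of yielding intermediates.
    with autocast("cuda"), inference_mode():
        x0_dec = sd_pipe.vae.decode(1 / 0.18215 * xt).sample
    if x0_dec.dim() < 4:
        x0_dec = x0_dec[None, :, :, :]
    return image_grid(x0_dec)

Whatever its exact body, the refactor keeps edit() focused on inversion and delegates guided resampling to sample(); as the diff shows, it also trades the per-step yield of intermediate images for a single returned result.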