Linoy Tsaban committed on
Commit
0d84727
1 Parent(s): c37a174

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +21 -83
app.py CHANGED
@@ -78,86 +78,28 @@ def edit(input_image,
78
  cfg_scale_src = 3.5,
79
  cfg_scale_tar = 15,
80
  skip=36,
81
- seed = 0,
82
- left = 0,
83
- right = 0,
84
- top = 0,
85
- bottom = 0
86
  ):
87
  torch.manual_seed(seed)
88
  # offsets=(0,0,0,0)
89
  x0 = load_512(input_image, left,right, top, bottom, device)
90
 
91
-
92
- # invert and retrieve noise maps and latent
93
- wt, zs, wts = invert(x0 =x0 , prompt_src=src_prompt, num_diffusion_steps=steps, cfg_scale_src=cfg_scale_src)
94
-
95
- # #
96
- # xT=wts[skip]
97
- # etas=1.0
98
- # prompts=[tar_prompt]
99
- # cfg_scales=[cfg_scale_tar]
100
- # prog_bar=False
101
- # zs=zs[skip:]
102
-
103
-
104
- # batch_size = len(prompts)
105
-
106
- # cfg_scales_tensor = torch.Tensor(cfg_scales).view(-1,1,1,1).to(sd_pipe.device)
107
-
108
- # text_embeddings = encode_text(sd_pipe, prompts)
109
- # uncond_embedding = encode_text(sd_pipe, [""] * batch_size)
110
-
111
- # if etas is None: etas = 0
112
- # if type(etas) in [int, float]: etas = [etas]*sd_pipe.scheduler.num_inference_steps
113
- # assert len(etas) == sd_pipe.scheduler.num_inference_steps
114
- # timesteps = sd_pipe.scheduler.timesteps.to(sd_pipe.device)
115
-
116
- # xt = xT.expand(batch_size, -1, -1, -1)
117
- # op = tqdm(timesteps[-zs.shape[0]:]) if prog_bar else timesteps[-zs.shape[0]:]
118
-
119
- # t_to_idx = {int(v):k for k,v in enumerate(timesteps[-zs.shape[0]:])}
120
-
121
- # for t in op:
122
- # idx = t_to_idx[int(t)]
123
- # ## Unconditional embedding
124
- # with torch.no_grad():
125
- # uncond_out = sd_pipe.unet.forward(xt, timestep = t,
126
- # encoder_hidden_states = uncond_embedding)
127
-
128
- # ## Conditional embedding
129
- # if prompts:
130
- # with torch.no_grad():
131
- # cond_out = sd_pipe.unet.forward(xt, timestep = t,
132
- # encoder_hidden_states = text_embeddings)
133
-
134
-
135
- # z = zs[idx] if not zs is None else None
136
- # z = z.expand(batch_size, -1, -1, -1)
137
- # if prompts:
138
- # ## classifier free guidance
139
- # noise_pred = uncond_out.sample + cfg_scales_tensor * (cond_out.sample - uncond_out.sample)
140
- # else:
141
- # noise_pred = uncond_out.sample
142
- # # 2. compute less noisy image and set x_t -> x_t-1
143
- # xt = reverse_step(sd_pipe, noise_pred, t, xt, eta = etas[idx], variance_noise = z)
144
-
145
- # # interm denoised img
146
- # with autocast("cuda"), inference_mode():
147
- # x0_dec = sd_pipe.vae.decode(1 / 0.18215 * xt).sample
148
- # if x0_dec.dim()<4:
149
- # x0_dec = x0_dec[None,:,:,:]
150
- # interm_img = image_grid(x0_dec)
151
- # yield interm_img
152
-
153
- # yield interm_img
154
 
155
-
156
  output = sample(wt, zs, wts, prompt_tar=tar_prompt, cfg_scale_tar=cfg_scale_tar, skip=skip)
157
 
158
  return output
159
 
160
 
 
 
 
 
161
 
162
 
163
 
@@ -180,7 +122,9 @@ For faster inference without waiting in queue, you may duplicate the space and u
180
  <p/>"""
181
  with gr.Blocks() as demo:
182
  gr.HTML(intro)
183
-
 
 
184
  with gr.Row():
185
  input_image = gr.Image(label="Input Image", interactive=True)
186
  input_image.style(height=512, width=512)
@@ -188,7 +132,7 @@ with gr.Blocks() as demo:
188
  # inverted_image.style(height=512, width=512)
189
  output_image = gr.Image(label=f"Edited Image", interactive=False)
190
  output_image.style(height=512, width=512)
191
-
192
 
193
  with gr.Row():
194
  # with gr.Column(scale=1, min_width=100):
@@ -214,14 +158,6 @@ with gr.Blocks() as demo:
214
  skip = gr.Slider(minimum=0, maximum=40, value=36, precision=0, label="Skip Steps", interactive=True)
215
  cfg_scale_tar = gr.Slider(minimum=7, maximum=18,value=15, label=f"Target Guidance Scale", interactive=True)
216
  seed = gr.Number(value=0, precision=0, label="Seed", interactive=True)
217
-
218
- #shift
219
- with gr.Column():
220
- left = gr.Number(value=0, precision=0, label="Left Shift", interactive=True)
221
- right = gr.Number(value=0, precision=0, label="Right Shift", interactive=True)
222
- top = gr.Number(value=0, precision=0, label="Top Shift", interactive=True)
223
- bottom = gr.Number(value=0, precision=0, label="Bottom Shift", interactive=True)
224
-
225
 
226
 
227
 
@@ -255,14 +191,16 @@ with gr.Blocks() as demo:
255
  cfg_scale_tar,
256
  skip,
257
  seed,
258
- left,
259
- right,
260
- top,
261
- bottom
262
  ],
263
  outputs=[output_image],
264
  )
265
 
 
 
 
 
266
 
267
  gr.Examples(
268
  label='Examples',
 
78
  cfg_scale_src = 3.5,
79
  cfg_scale_tar = 15,
80
  skip=36,
81
+ wt = None,
82
+ zs = None,
83
+ wts = None
84
+
 
85
  ):
86
  torch.manual_seed(seed)
87
  # offsets=(0,0,0,0)
88
  x0 = load_512(input_image, left,right, top, bottom, device)
89
 
90
+ if not wt:
91
+ # invert and retrieve noise maps and latent
92
+ wt, zs, wts = invert(x0 =x0 , prompt_src=src_prompt, num_diffusion_steps=steps, cfg_scale_src=cfg_scale_src)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
93
 
 
94
  output = sample(wt, zs, wts, prompt_tar=tar_prompt, cfg_scale_tar=cfg_scale_tar, skip=skip)
95
 
96
  return output
97
 
98
 
99
+ def reset_latents():
100
+ wt = gr.State(value=None)
101
+ zs = gr.State(value=None)
102
+ wts = gr.State(value=None)
103
 
104
 
105
 
 
122
  <p/>"""
123
  with gr.Blocks() as demo:
124
  gr.HTML(intro)
125
+ wt = gr.State(value=None)
126
+ zs = gr.State(value=None)
127
+ wts = gr.State(value=None)
128
  with gr.Row():
129
  input_image = gr.Image(label="Input Image", interactive=True)
130
  input_image.style(height=512, width=512)
 
132
  # inverted_image.style(height=512, width=512)
133
  output_image = gr.Image(label=f"Edited Image", interactive=False)
134
  output_image.style(height=512, width=512)
135
+
136
 
137
  with gr.Row():
138
  # with gr.Column(scale=1, min_width=100):
 
158
  skip = gr.Slider(minimum=0, maximum=40, value=36, precision=0, label="Skip Steps", interactive=True)
159
  cfg_scale_tar = gr.Slider(minimum=7, maximum=18,value=15, label=f"Target Guidance Scale", interactive=True)
160
  seed = gr.Number(value=0, precision=0, label="Seed", interactive=True)
 
 
 
 
 
 
 
 
161
 
162
 
163
 
 
191
  cfg_scale_tar,
192
  skip,
193
  seed,
194
+ new_inversion,
195
+
 
 
196
  ],
197
  outputs=[output_image],
198
  )
199
 
200
+ input_image.change(
201
+ fn = reset_latents
202
+ )
203
+
204
 
205
  gr.Examples(
206
  label='Examples',