multimodalart HF staff commited on
Commit
a722e19
1 Parent(s): f34dfad

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -10
app.py CHANGED
@@ -119,8 +119,8 @@ def load_and_invert(
119
  skip=skip,
120
  eta=1.0,
121
  )
122
- wts = gr.State(value=wts_tensor)
123
- zs = gr.State(value=zs_tensor)
124
  do_inversion = False
125
 
126
  return wts, zs, do_inversion, gr.update(visible=False)
@@ -173,8 +173,8 @@ def edit(input_image,
173
  skip = skip,
174
  eta = 1.0,
175
  )
176
- wts = gr.State(value=wts_tensor)
177
- zs = gr.State(value=zs_tensor)
178
  do_inversion = False
179
 
180
  if image_caption.lower() == tar_prompt.lower(): # if image caption was not changed, run pure sega
@@ -194,7 +194,7 @@ def edit(input_image,
194
  use_intersect_mask=use_intersect_mask
195
  )
196
 
197
- latnets = wts.value[-1].expand(1, -1, -1, -1)
198
  sega_out = pipe(prompt=tar_prompt,
199
  init_latents=latnets,
200
  guidance_scale = tar_cfg_scale,
@@ -202,7 +202,7 @@ def edit(input_image,
202
  # num_inference_steps=steps,
203
  # use_ddpm=True,
204
  # wts=wts.value,
205
- zs=zs.value, **editing_args)
206
 
207
  return sega_out.images[0], gr.update(visible=True), do_reconstruction, reconstruction, wts, zs, do_inversion, show_share_button
208
 
@@ -210,12 +210,12 @@ def edit(input_image,
210
  else: # if sega concepts were not added, performs regular ddpm sampling
211
 
212
  if do_reconstruction: # if ddpm sampling wasn't computed
213
- pure_ddpm_img = sample(zs.value, wts.value, prompt_tar=tar_prompt, skip=skip, cfg_scale_tar=tar_cfg_scale)
214
- reconstruction = gr.State(value=pure_ddpm_img)
215
  do_reconstruction = False
216
  return pure_ddpm_img, gr.update(visible=False), do_reconstruction, reconstruction, wts, zs, do_inversion, show_share_button
217
 
218
- return reconstruction.value, gr.update(visible=False), do_reconstruction, reconstruction, wts, zs, do_inversion, show_share_button
219
 
220
 
221
  def randomize_seed_fn(seed, is_random):
@@ -872,6 +872,5 @@ with gr.Blocks(css="style.css") as demo:
872
  cache_examples=True
873
  )
874
 
875
-
876
  demo.queue()
877
  demo.launch()
 
119
  skip=skip,
120
  eta=1.0,
121
  )
122
+ wts = wts_tensor
123
+ zs = zs_tensor
124
  do_inversion = False
125
 
126
  return wts, zs, do_inversion, gr.update(visible=False)
 
173
  skip = skip,
174
  eta = 1.0,
175
  )
176
+ wts = wts_tensor
177
+ zs = zs_tensor
178
  do_inversion = False
179
 
180
  if image_caption.lower() == tar_prompt.lower(): # if image caption was not changed, run pure sega
 
194
  use_intersect_mask=use_intersect_mask
195
  )
196
 
197
+ latnets = wts[-1].expand(1, -1, -1, -1)
198
  sega_out = pipe(prompt=tar_prompt,
199
  init_latents=latnets,
200
  guidance_scale = tar_cfg_scale,
 
202
  # num_inference_steps=steps,
203
  # use_ddpm=True,
204
  # wts=wts.value,
205
+ zs=zs, **editing_args)
206
 
207
  return sega_out.images[0], gr.update(visible=True), do_reconstruction, reconstruction, wts, zs, do_inversion, show_share_button
208
 
 
210
  else: # if sega concepts were not added, performs regular ddpm sampling
211
 
212
  if do_reconstruction: # if ddpm sampling wasn't computed
213
+ pure_ddpm_img = sample(zs, wts, prompt_tar=tar_prompt, skip=skip, cfg_scale_tar=tar_cfg_scale)
214
+ reconstruction = pure_ddpm_img
215
  do_reconstruction = False
216
  return pure_ddpm_img, gr.update(visible=False), do_reconstruction, reconstruction, wts, zs, do_inversion, show_share_button
217
 
218
+ return reconstruction, gr.update(visible=False), do_reconstruction, reconstruction, wts, zs, do_inversion, show_share_button
219
 
220
 
221
  def randomize_seed_fn(seed, is_random):
 
872
  cache_examples=True
873
  )
874
 
 
875
  demo.queue()
876
  demo.launch()